Matplotlib Scatter Plots

Visualize relationships between two variables using scatter plots

🔍 Scatter Plot Visualization

Scatter plots are perfect for showing relationships between two numerical variables. Each point represents one observation with x and y coordinates.


import matplotlib.pyplot as plt

# Each (x, y) pair becomes one point: x sets the horizontal position,
# y the vertical position.
x = [1, 2, 3, 4, 5]
y = [2, 5, 3, 8, 7]

# Object-oriented style: create the figure and axes explicitly, then
# draw and label on the axes object.
fig, ax = plt.subplots()
ax.scatter(x, y)
ax.set_xlabel('X values')
ax.set_ylabel('Y values')
plt.show()
                                    
2D Relationships
Color Coding
Size Variation

Scatter Plot Features

Customize your scatter plots with various visual elements:

Basic Points

Simple (x, y) coordinate plotting

plt.scatter()

🎨 Color Mapping

Color points by a third variable

c parameter with a cmap colormap

📏 Size Variation

Point size represents data values

s parameter

🔍 Transparency

Handle overlapping points

alpha parameter

🔹 Basic Scatter Plot

Create simple scatter plots to show data relationships

import matplotlib.pyplot as plt

# Sample data: heights (cm) and weights (kg), paired by position
height = [150, 160, 165, 170, 175, 180, 185]
weight = [50, 60, 65, 70, 75, 80, 85]

# Basic scatter plot using matplotlib's default marker style
plt.scatter(height, weight)
plt.xlabel('Height (cm)')
plt.ylabel('Weight (kg)')
plt.title('Height vs Weight')
plt.show()

# Same data with custom styling, gathered in one dict and unpacked into
# the call so each option is easy to find and tweak
marker_style = {
    'color': 'red',  # point color
    's': 100,        # point size (marker area)
    'alpha': 0.7,    # transparency: 0 = invisible, 1 = opaque
}
plt.scatter(height, weight, **marker_style)

plt.xlabel('Height (cm)')
plt.ylabel('Weight (kg)')
plt.title('Height vs Weight (Styled)')
plt.grid(True, alpha=0.3)  # faint grid so it doesn't compete with the points
plt.show()

🔹 Color-Coded Scatter Plots

Use colors to represent additional data dimensions

import matplotlib.pyplot as plt
import numpy as np

# Generate sample data (seeded so the figure is reproducible)
np.random.seed(42)
x = np.random.randn(100)
y = np.random.randn(100)
colors = np.random.randn(100)  # Third variable for coloring

# Color-coded scatter plot: `c` supplies one value per point and `cmap`
# maps those values to colors
plt.scatter(x, y, c=colors, cmap='viridis', alpha=0.7)
plt.colorbar(label='Color Scale')  # Add color bar
plt.xlabel('X values')
plt.ylabel('Y values')
plt.title('Color-Coded Scatter Plot')
plt.show()

# Using different colormaps: one subplot per colormap, same data
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

colormaps = ['viridis', 'plasma', 'coolwarm']
titles = ['Viridis', 'Plasma', 'Coolwarm']

# Walk the subplots alongside their colormap and title directly, rather
# than indexing axes[i] through enumerate()
for ax, cmap, title in zip(axes, colormaps, titles):
    scatter = ax.scatter(x, y, c=colors, cmap=cmap, alpha=0.7)
    ax.set_title(title)
    fig.colorbar(scatter, ax=ax)  # attach a colorbar to this subplot

plt.tight_layout()
plt.show()

🔹 Variable Point Sizes

Use point size to represent data magnitude

import matplotlib.pyplot as plt
import numpy as np

# Sample data: five large US cities
cities = ['NYC', 'LA', 'Chicago', 'Houston', 'Phoenix']
# Approximate coordinates in degrees (west longitudes are negative), so the
# points reproduce the cities' real east-west / north-south layout
x_coord = [-74.0, -118.2, -87.6, -95.4, -112.1]  # Longitude
y_coord = [40.7, 34.1, 41.9, 29.8, 33.4]         # Latitude
population = [8.4, 4.0, 2.7, 2.3, 1.7]  # Population in millions

# `s` is the marker area, so scaling population linearly keeps each
# bubble's area proportional to its population
sizes = [p * 100 for p in population]  # Multiply for visibility

plt.scatter(x_coord, y_coord, s=sizes, alpha=0.6, color='blue')

# Add city labels, nudged 5 points up and to the right of each bubble center
for city, lon, lat in zip(cities, x_coord, y_coord):
    plt.annotate(city, (lon, lat),
                 xytext=(5, 5), textcoords='offset points')

plt.xlabel('Longitude (degrees)')
plt.ylabel('Latitude (degrees)')
plt.title('City Populations (Size = Population)')
plt.show()

# Bubble chart with color and size both encoding data
np.random.seed(42)
x = np.random.randn(50)
y = np.random.randn(50)
sizes = np.random.randint(20, 200, 50)  # Random marker areas in [20, 200)
colors = np.random.randn(50)

plt.scatter(x, y, s=sizes, c=colors, cmap='rainbow', alpha=0.6)
plt.colorbar(label='Color Value')
plt.xlabel('X values')
plt.ylabel('Y values')
plt.title('Bubble Chart (Size and Color Vary)')
plt.show()

🔹 Multiple Data Series

Compare different groups in the same scatter plot

import matplotlib.pyplot as plt
import numpy as np

# Two synthetic clusters of 50 points each; group B is centered up and to
# the right of group A, so the two groups are visually distinct
np.random.seed(42)
group1_x = np.random.normal(2, 1, 50)
group1_y = np.random.normal(3, 1, 50)

group2_x = np.random.normal(5, 1, 50)
group2_y = np.random.normal(6, 1, 50)

# One scatter call per group; each label becomes a legend entry
series = [
    (group1_x, group1_y, 'red', 'Group A'),
    (group2_x, group2_y, 'blue', 'Group B'),
]
for xs, ys, color, label in series:
    plt.scatter(xs, ys, color=color, alpha=0.6, label=label)

plt.xlabel('X values')
plt.ylabel('Y values')
plt.title('Two Groups Comparison')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

# Different markers for groups: marker shape, not just color, separates them
marker_series = [
    (group1_x, group1_y, 'o', 'red', 'Circles'),
    (group2_x, group2_y, 's', 'blue', 'Squares'),
]
for xs, ys, marker, color, label in marker_series:
    plt.scatter(xs, ys, marker=marker, color=color, alpha=0.6, label=label)

plt.xlabel('X values')
plt.ylabel('Y values')
plt.title('Different Markers for Groups')
plt.legend()
plt.show()

🔹 Advanced Scatter Plot

Combine multiple visual elements for rich data visualization

import matplotlib.pyplot as plt
import numpy as np
from scipy import stats

# Generate realistic sample data (seeded for reproducibility)
np.random.seed(42)
n_points = 100

# Simulate student data. Raw normal draws can land outside the meaningful
# range, so clip them: study hours cannot be negative (a negative value would
# also yield a negative marker size below, which cannot be drawn) and test
# scores are bounded to 0-100.
study_hours = np.clip(np.random.normal(5, 2, n_points), 0, None)
test_scores = np.clip(study_hours * 10 + np.random.normal(0, 5, n_points), 0, 100)
class_year = np.random.choice([1, 2, 3, 4], n_points)  # 1 = Freshman ... 4 = Senior

# Create advanced scatter plot
fig, ax = plt.subplots(figsize=(10, 8))

# Encode four variables at once: x position, y position, color and size
scatter = ax.scatter(study_hours, test_scores,
                     c=class_year,           # Color by class year
                     s=study_hours * 20,     # Marker size by study hours
                     alpha=0.6,
                     cmap='viridis',
                     edgecolors='black',     # Outline each point
                     linewidths=0.5)

# Customize the plot
ax.set_xlabel('Study Hours per Week', fontsize=12)
ax.set_ylabel('Test Score', fontsize=12)
ax.set_title('Student Performance Analysis\n(Size=Study Hours, Color=Class Year)',
             fontsize=14, fontweight='bold')

# Add a colorbar for this axes, with one labelled tick per class year
cbar = fig.colorbar(scatter, ax=ax)
cbar.set_label('Class Year', fontsize=12)
cbar.set_ticks([1, 2, 3, 4])
cbar.set_ticklabels(['Freshman', 'Sophomore', 'Junior', 'Senior'])

# Add grid
ax.grid(True, alpha=0.3)

# Axis limits: 0-10 hours by default, widened if any student studied more so
# no point is cropped out of view; scores are already clipped to 0-100 above
ax.set_xlim(0, max(10, study_hours.max() * 1.05))
ax.set_ylim(0, 100)

plt.tight_layout()
plt.show()

# Add trend line: linear least-squares fit of test score on study hours
fit = stats.linregress(study_hours, test_scores)
# A straight line needs only its two end points; evaluating it at the data's
# min and max draws one clean segment instead of retracing it back and forth
# through the unsorted x values
x_line = np.array([study_hours.min(), study_hours.max()])
y_line = fit.slope * x_line + fit.intercept

plt.figure(figsize=(10, 6))
plt.scatter(study_hours, test_scores, alpha=0.6, color='blue')
plt.plot(x_line, y_line, 'r-', linewidth=2,
         label=f'Trend line (R² = {fit.rvalue**2:.3f})')

plt.xlabel('Study Hours per Week')
plt.ylabel('Test Score')
plt.title('Study Hours vs Test Scores with Trend Line')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

🧠 Test Your Knowledge

Which parameter controls point size in scatter plots?

What does the 'c' parameter do in plt.scatter()?