Decision Trees
Learn how decision trees make predictions through simple yes/no questions
🌳 What are Decision Trees?
Decision trees are intuitive machine learning models that make predictions by asking a series of yes/no questions. They split data based on features to create a tree-like structure that's easy to understand and interpret.
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# Load the iris dataset and pull out its feature matrix and target labels
iris = load_iris()
X = iris.data
y = iris.target

# Hold out 20% of the samples for evaluation; the fixed seed keeps the split repeatable
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Build the classifier and fit it on the training portion (fit returns the estimator)
tree = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)

# Evaluate on the held-out samples and report the score
accuracy = tree.score(X_test, y_test)
print(f"Accuracy: {accuracy:.3f}")
Key Concepts
Root Node
Starting point of the tree
Leaf Nodes
Final predictions
Splits
Decision points in the tree
Depth
How deep the tree grows: the number of splits on the longest path from the root to a leaf
🔹 Basic Decision Tree
Creating your first decision tree classifier
from sklearn.tree import DecisionTreeClassifier
import numpy as np

# Toy dataset: every row is [height, weight]; every label is 0 (child) or 1 (adult)
X = np.array([
    [150, 40], [160, 50], [170, 60], [180, 70],
    [120, 25], [130, 30], [140, 35],
])
y = np.array([0, 1, 1, 1, 0, 0, 0])

# Fit a tree on the whole toy dataset (fit returns the estimator)
tree = DecisionTreeClassifier(random_state=42).fit(X, y)

# Classify a single unseen person and translate the numeric label to text
new_person = [[155, 45]]
prediction = tree.predict(new_person)
label = 'Adult' if prediction[0] == 1 else 'Child'
print(f"Prediction: {label}")

# Class probabilities for the same person; columns follow the label order (0=child, 1=adult)
prob = tree.predict_proba(new_person)
child_prob, adult_prob = prob[0]
print(f"Probabilities: Child={child_prob:.3f}, Adult={adult_prob:.3f}")
🔹 Tree Parameters
Control how your tree grows
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# Create sample data: 1000 samples, 4 features, 2 classes
X, y = make_classification(n_samples=1000, n_features=4,
                           n_classes=2, random_state=42)

# Hold out 20% of the samples. Scoring on the data a tree was trained on
# rewards memorisation, so the unconstrained tree would always look best.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Different tree configurations, keyed by a short label used in the report
trees = {
    'default': DecisionTreeClassifier(random_state=42),
    'max_depth_3': DecisionTreeClassifier(max_depth=3, random_state=42),
    'min_samples_20': DecisionTreeClassifier(min_samples_split=20, random_state=42),
    'min_leaf_10': DecisionTreeClassifier(min_samples_leaf=10, random_state=42),
}

# Train each configuration and compare held-out accuracy and resulting depth
for name, tree in trees.items():
    tree.fit(X_train, y_train)
    accuracy = tree.score(X_test, y_test)
    print(f"{name}: Accuracy={accuracy:.3f}, Depth={tree.get_depth()}")
🔹 Feature Importance
See which features matter most
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_wine

# Load wine dataset
wine = load_wine()
X, y = wine.data, wine.target

# Train tree on the full dataset; only the learned importances are inspected here
tree = DecisionTreeClassifier(random_state=42)
tree.fit(X, y)

# One importance score per feature, in the same order as wine.feature_names
importance = tree.feature_importances_
feature_names = wine.feature_names

# Pair each name with its score, sort by score descending, keep the top 5
top_features = sorted(zip(feature_names, importance),
                      key=lambda x: x[1], reverse=True)[:5]

print("Top 5 Most Important Features:")
for feature, score in top_features:
    print(f"{feature}: {score:.3f}")
🔹 Decision Tree Regression
Using trees for continuous predictions
from sklearn.tree import DecisionTreeRegressor
from sklearn.datasets import make_regression
from sklearn.metrics import mean_squared_error

# Synthetic single-feature regression problem with added noise
X, y = make_regression(n_samples=100, n_features=1, noise=10, random_state=42)

# Fit a depth-limited regression tree (fit returns the estimator)
tree_reg = DecisionTreeRegressor(max_depth=3, random_state=42).fit(X, y)

# Predict on the same samples used for fitting, so this is a training error
y_pred = tree_reg.predict(X)
mse = mean_squared_error(y, y_pred)
print(f"Mean Squared Error: {mse:.2f}")

# Predict for a few hand-picked inputs, one single-feature row each
new_X = [[value] for value in (0.5, 1.0, -0.5)]
predictions = tree_reg.predict(new_X)
print(f"Predictions: {predictions}")
🔹 Visualizing Decision Trees
Understanding tree structure
from sklearn.tree import DecisionTreeClassifier, export_text
from sklearn.datasets import load_iris

# Load data and train a shallow tree so the printed rules stay short
iris = load_iris()
X, y = iris.data, iris.target
tree = DecisionTreeClassifier(max_depth=3, random_state=42)
tree.fit(X, y)

# Text representation of the tree, labelled with the dataset's feature names
tree_rules = export_text(tree, feature_names=iris.feature_names)
print("Decision Tree Rules:")

# Show at most the first 500 characters; add the ellipsis only when text
# was actually cut, so a complete listing is not shown as truncated
max_chars = 500
if len(tree_rules) > max_chars:
    print(tree_rules[:max_chars] + "...")
else:
    print(tree_rules)

# Tree statistics: depth and leaf count via the estimator's public accessors,
# node count from the fitted low-level tree object
print("\nTree Statistics:")
print(f"Max depth: {tree.get_depth()}")
print(f"Number of leaves: {tree.get_n_leaves()}")
print(f"Number of nodes: {tree.tree_.node_count}")