
Chapter 6: Decision Tree Algorithms

Decision trees are one of the most intuitive and easy-to-understand algorithms in machine learning. They make decisions through a series of if-else conditions, similar to human thought processes. This chapter covers the principles, implementation, and applications of decision trees in detail.

6.1 What is a Decision Tree?

A decision tree is a tree-structured classification and regression algorithm that learns a series of decision rules to make predictions on data. Each internal node represents a test on a feature, each branch represents a test result, and each leaf node represents a class label or numerical value.

6.1.1 Components of a Decision Tree

  • Root node: The top of the tree, containing all training samples
  • Internal nodes: Represent tests on certain features
  • Branches: Represent test results
  • Leaf nodes: Represent classification results or regression values

6.1.2 Advantages of Decision Trees

  • Easy to understand and interpret: The decision process is transparent
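  • Little data preprocessing required: No feature scaling or normalization is needed; the algorithm can handle numerical and categorical features, although scikit-learn's implementation expects categorical features to be encoded as numbers
  • Can handle multi-output problems: Simultaneously predict multiple targets
  • Statistically verifiable: The model can be validated with statistical tests, which makes it possible to assess its reliability
  • Robust to outliers in the input features: Splits depend only on the ordering of feature values, not on their magnitude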
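The point about preprocessing is easiest to check for feature scaling. A split depends only on how samples are ordered along a feature, so rescaling the features should leave the learned partition unchanged. The following is a minimal sketch on synthetic `make_classification` data. The scale factor of 1000 is an arbitrary choice for illustration.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=200, n_features=4, random_state=0)

# Multiplying by a positive constant keeps the order of every feature's values unchanged
X_scaled = X * 1000.0

tree_raw = DecisionTreeClassifier(random_state=0).fit(X, y)
tree_scaled = DecisionTreeClassifier(random_state=0).fit(X_scaled, y)

# Same ordering -> same candidate partitions -> same impurities -> same tree shape
same_predictions = np.array_equal(tree_raw.predict(X), tree_scaled.predict(X_scaled))
same_structure = np.array_equal(tree_raw.tree_.feature, tree_scaled.tree_.feature)
print("Identical training-set predictions:", same_predictions)
print("Same feature tested at every node:", same_structure)
```

Only the stored thresholds differ: they are expressed in the rescaled units. A distance-based model such as k-nearest neighbours would generally not be invariant in this way.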

6.1.3 Disadvantages of Decision Trees

  • Prone to overfitting: Especially for deep trees
  • Unstable: Small changes in data can lead to completely different trees
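  • Biased toward features with many distinct values: Information gain favors such features (C4.5's gain ratio was designed to correct this)
  • Poor at expressing linear relationships: Axis-aligned splits approximate a straight line with a staircase that needs many splits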
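The last point can be seen in a small sketch, not part of the original chapter, that fits regression trees of increasing depth to an exactly linear target, y = x. Each leaf predicts a constant, so the fitted function is a staircase. Each extra level roughly doubles the number of steps and lowers the error, but the error shrinks only gradually as leaves are added.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# A perfectly linear relationship: y = x on 100 evenly spaced points
X_line = np.linspace(0, 10, 100).reshape(-1, 1)
y_line = X_line.ravel()

mse_by_depth = {}
for depth in range(1, 6):
    tree = DecisionTreeRegressor(max_depth=depth, random_state=0)
    tree.fit(X_line, y_line)
    # Predictions are piecewise constant: one value per leaf
    mse = np.mean((tree.predict(X_line) - y_line) ** 2)
    mse_by_depth[depth] = mse
    print(f"depth={depth}  leaves={tree.get_n_leaves():2d}  training MSE={mse:.4f}")
```

A linear regression would fit this data exactly with two parameters. The tree needs dozens of leaves and still leaves a residual error.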

6.2 Preparing Environment and Data

python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import make_classification, load_iris, load_wine, make_regression
from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor, plot_tree, export_text
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import (
    accuracy_score, classification_report, confusion_matrix,
    mean_squared_error, r2_score
)
import warnings
warnings.filterwarnings('ignore')

# Set random seed
np.random.seed(42)

# Set plot style ('seaborn-v0_8' requires matplotlib >= 3.6)
plt.style.use('seaborn-v0_8')
plt.rcParams['axes.unicode_minus'] = False

6.3 Principles of Decision Tree Construction

6.3.1 Information Theory Fundamentals

python
def calculate_entropy(y):
    """Calculate entropy"""
    _, counts = np.unique(y, return_counts=True)
    probabilities = counts / len(y)
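    # H = -sum(p * log2 p), written as sum(p * log2(1/p)); np.unique only returns classes
    # that occur, so p > 0, and a pure node gives 0.0 instead of a tiny negative number
    entropy = np.sum(probabilities * np.log2(1.0 / probabilities))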
    return entropy

def calculate_gini(y):
    """Calculate Gini impurity"""
    _, counts = np.unique(y, return_counts=True)
    probabilities = counts / len(y)
    gini = 1 - np.sum(probabilities ** 2)
    return gini

def calculate_information_gain(y, y_left, y_right):
    """Calculate information gain"""
    n = len(y)
    n_left, n_right = len(y_left), len(y_right)
    
    entropy_before = calculate_entropy(y)
    entropy_after = (n_left/n) * calculate_entropy(y_left) + (n_right/n) * calculate_entropy(y_right)
    
    information_gain = entropy_before - entropy_after
    return information_gain

# Demonstrate different purity metrics
y_pure = np.array([1, 1, 1, 1, 1])  # Pure
y_mixed = np.array([0, 0, 1, 1, 1])  # Mixed
y_impure = np.array([0, 0, 1, 1, 2])  # Impure

print("Purity metrics for different datasets:")
print("Dataset\t\tEntropy\t\tGini Impurity")
print("-" * 40)
print(f"Pure [1,1,1,1,1]\t{calculate_entropy(y_pure):.4f}\t\t{calculate_gini(y_pure):.4f}")
print(f"Mixed [0,0,1,1,1]\t{calculate_entropy(y_mixed):.4f}\t\t{calculate_gini(y_mixed):.4f}")
print(f"Impure [0,0,1,1,2]\t{calculate_entropy(y_impure):.4f}\t\t{calculate_gini(y_impure):.4f}")
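The printed values follow from the standard definitions implemented by calculate_entropy, calculate_gini and calculate_information_gain. Consider a node S with class proportions p_k that is split into children S_L and S_R. The block below gives those definitions, the binary-class special case plotted in the next section, and a worked check for the "Mixed" set:

```latex
H(S) = -\sum_{k=1}^{K} p_k \log_2 p_k
\qquad
\mathrm{Gini}(S) = 1 - \sum_{k=1}^{K} p_k^2

IG(S) = H(S) - \frac{|S_L|}{|S|}\,H(S_L) - \frac{|S_R|}{|S|}\,H(S_R)

% Binary case (K = 2, p = proportion of class 1)
\mathrm{Gini}(p) = 1 - p^2 - (1 - p)^2 = 2p(1 - p)

% Worked check, "Mixed" set [0,0,1,1,1]: p_0 = 0.4, p_1 = 0.6
H = -(0.4\log_2 0.4 + 0.6\log_2 0.6) \approx 0.971
\qquad
\mathrm{Gini} = 1 - (0.4^2 + 0.6^2) = 0.48
```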

6.3.2 Splitting Criterion Visualization

python
# Visualize different splitting criteria
def plot_impurity_measures():
    """Visualize different impurity measures"""
    p = np.linspace(0.01, 0.99, 100)
    
    # Impurity measures for binary classification
    entropy = -p * np.log2(p) - (1-p) * np.log2(1-p)
    gini = 2 * p * (1-p)
    misclassification = 1 - np.maximum(p, 1-p)
    
    plt.figure(figsize=(10, 6))
    plt.plot(p, entropy, label='Entropy', linewidth=2)
    plt.plot(p, gini, label='Gini Impurity', linewidth=2)
    plt.plot(p, misclassification, label='Misclassification Rate', linewidth=2)
    
    plt.xlabel('Probability of Class 1')
    plt.ylabel('Impurity')
    plt.title('Comparison of Different Impurity Measures')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()

plot_impurity_measures()
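These impurity measures drive split selection. For every candidate threshold on a feature, the tree computes the size-weighted impurity of the two children and keeps the threshold with the largest impurity decrease. With entropy, this decrease is the information gain computed by calculate_information_gain. The brute-force sketch below illustrates the idea on one feature. `best_threshold` and the toy data are illustrative assumptions; scikit-learn performs an equivalent scan much more efficiently.

```python
import numpy as np

def entropy(labels):
    """Shannon entropy (bits) of a label array."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return float(np.sum(p * np.log2(1.0 / p)))

def gini(labels):
    """Gini impurity of a label array."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return float(1.0 - np.sum(p ** 2))

def best_threshold(x, y, impurity=gini):
    """Try every threshold between distinct sorted values; return (threshold, impurity decrease)."""
    order = np.argsort(x)
    x_sorted, y_sorted = x[order], y[order]
    parent = impurity(y_sorted)
    n = len(y_sorted)
    best = (None, 0.0)
    for i in range(1, n):
        if x_sorted[i] == x_sorted[i - 1]:
            continue  # no threshold can separate identical values
        left, right = y_sorted[:i], y_sorted[i:]
        # Size-weighted impurity of the two children
        child = (i / n) * impurity(left) + ((n - i) / n) * impurity(right)
        gain = parent - child
        if gain > best[1]:
            # Place the threshold halfway between the neighbouring values
            best = ((x_sorted[i - 1] + x_sorted[i]) / 2.0, gain)
    return best

# Toy feature where class 1 starts above x = 4
x = np.array([1.0, 2.0, 3.0, 3.5, 4.5, 5.0, 6.0, 7.0])
y = np.array([0, 0, 0, 0, 1, 1, 1, 1])
for name, fn in [("gini", gini), ("entropy", entropy)]:
    threshold, gain = best_threshold(x, y, fn)
    print(f"{name}: best threshold = {threshold}, impurity decrease = {gain:.3f}")
```

Both criteria pick the threshold 4.0 here, halfway between 3.5 and 4.5. This mirrors how scikit-learn places thresholds at midpoints between neighbouring feature values.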

6.4 Classification Decision Trees

6.4.1 Simple Binary Classification Example

python
# Create simple binary classification data
X_simple, y_simple = make_classification(
    n_samples=200,
    n_features=2,
    n_redundant=0,
    n_informative=2,
    n_clusters_per_class=1,
    random_state=42
)

# Split data
X_train, X_test, y_train, y_test = train_test_split(
    X_simple, y_simple, test_size=0.2, random_state=42
)

# Create decision tree classifier
dt_classifier = DecisionTreeClassifier(
    criterion='gini',  # Splitting criterion
    max_depth=3,       # Maximum depth
    min_samples_split=20,  # Minimum samples required to split
    min_samples_leaf=10,   # Minimum samples in leaf node
    random_state=42
)

# Train model
dt_classifier.fit(X_train, y_train)

# Predict
y_pred = dt_classifier.predict(X_test)
y_pred_proba = dt_classifier.predict_proba(X_test)

# Evaluate
accuracy = accuracy_score(y_test, y_pred)
print(f"Decision tree classification accuracy: {accuracy:.4f}")

print("\nClassification Report:")
print(classification_report(y_test, y_pred))

6.4.2 Decision Boundary Visualization

python
def plot_decision_tree_boundary(X, y, model, title="Decision Tree Decision Boundary"):
    """Plot decision tree decision boundary"""
    plt.figure(figsize=(12, 8))
    
    # Create grid
    h = 0.02
    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                         np.arange(y_min, y_max, h))
    
    # Predict grid points
    Z = model.predict(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)
    
    # Plot decision regions
    contour = plt.contourf(xx, yy, Z, alpha=0.8, cmap='RdYlBu')
    
    # Plot data points, one colour per class (boolean mask selects each class)
    colors = ['red', 'blue']
    for i, color in enumerate(colors):
        mask = (y == i)
        plt.scatter(X[mask, 0], X[mask, 1], c=color,
                    label=f'Class {i}', edgecolors='black')
    
    plt.xlabel('Feature 1')
    plt.ylabel('Feature 2')
    plt.title(title)
    plt.legend()
    # Attach the colorbar to the filled contours, not to the last scatter call
    plt.colorbar(contour, label='Predicted class')
    plt.show()

# Plot decision boundary
plot_decision_tree_boundary(X_train, y_train, dt_classifier, "Decision Tree Classification Boundary")

6.4.3 Decision Tree Visualization

python
# Visualize decision tree structure
plt.figure(figsize=(15, 10))
plot_tree(dt_classifier, 
          feature_names=['Feature1', 'Feature2'],
          class_names=['Class0', 'Class1'],
          filled=True,
          rounded=True,
          fontsize=10)
plt.title('Decision Tree Structure')
plt.show()

# Text form of decision tree
tree_rules = export_text(dt_classifier, 
                        feature_names=['Feature1', 'Feature2'])
print("Decision tree rules (text form):")
print(tree_rules)
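`plot_tree` and `export_text` show the whole tree. To explain one particular prediction, you can instead follow the path that a single sample takes. The sketch below uses `decision_path` and `apply`, both public scikit-learn estimator methods. It refits a tree similar to `dt_classifier` on the same kind of synthetic data so that the snippet runs on its own.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=200, n_features=2, n_redundant=0,
                           n_informative=2, n_clusters_per_class=1, random_state=42)
clf = DecisionTreeClassifier(max_depth=3, random_state=42).fit(X, y)

sample = X[:1]  # keep a 2D shape: one row, two features
tree_ = clf.tree_
# decision_path returns a sparse (n_samples, n_nodes) indicator matrix
node_indicator = clf.decision_path(sample)
path_nodes = node_indicator.indices[node_indicator.indptr[0]:node_indicator.indptr[1]]
leaf_id = clf.apply(sample)[0]  # id of the leaf the sample ends in

for node_id in path_nodes:
    if node_id == leaf_id:
        predicted = clf.classes_[np.argmax(tree_.value[node_id])]
        print(f"leaf {node_id}: predict class {predicted}")
        break
    feature = tree_.feature[node_id]
    threshold = tree_.threshold[node_id]
    went_left = sample[0, feature] <= threshold
    print(f"node {node_id}: Feature{feature + 1} = {sample[0, feature]:.3f} "
          f"{'<=' if went_left else '>'} {threshold:.3f}")
```

The output has one line per internal node visited and ends at the leaf whose majority class is the prediction.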

6.4.4 Feature Importance

python
# Feature importance analysis
feature_importance = dt_classifier.feature_importances_
feature_names = ['Feature1', 'Feature2']

plt.figure(figsize=(8, 6))
plt.bar(feature_names, feature_importance, color=['skyblue', 'lightcoral'])
plt.title('Decision Tree Feature Importance')
plt.xlabel('Feature')
plt.ylabel('Importance')
plt.show()

print("Feature importance:")
for name, importance in zip(feature_names, feature_importance):
    print(f"{name}: {importance:.4f}")
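In scikit-learn, `feature_importances_` is the total impurity reduction contributed by each feature across all of its splits, normalized to sum to 1. This is often called mean decrease in impurity. The sketch below recomputes it from the fitted tree's internal arrays (`tree_.impurity`, `tree_.weighted_n_node_samples`, `tree_.children_left`, and so on) on the iris data. The result should agree with the built-in attribute.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(max_depth=3, random_state=42).fit(X, y)
t = clf.tree_

importances = np.zeros(X.shape[1])
for node in range(t.node_count):
    left, right = t.children_left[node], t.children_right[node]
    if left == -1:  # leaf: no split, so no impurity decrease
        continue
    # Weighted impurity decrease produced by this node's split
    decrease = (t.weighted_n_node_samples[node] * t.impurity[node]
                - t.weighted_n_node_samples[left] * t.impurity[left]
                - t.weighted_n_node_samples[right] * t.impurity[right])
    importances[t.feature[node]] += decrease

importances /= importances.sum()  # normalize so the importances sum to 1
print(np.round(importances, 4))
print(np.round(clf.feature_importances_, 4))
```

This measure is computed from splits on the training set, so it can favour features with many distinct values, the same bias noted in 6.1.3. `sklearn.inspection.permutation_importance` on held-out data is a common cross-check.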

6.5 Regression Decision Trees

6.5.1 Creating Regression Data

python
# Create regression dataset
X_reg, y_reg = make_regression(
    n_samples=200,
    n_features=1,
    noise=10,
    random_state=42
)

# Add some nonlinear relationship
X_reg = X_reg.flatten()
y_reg = y_reg + 0.1 * X_reg**2

# Split data
X_train_reg, X_test_reg, y_train_reg, y_test_reg = train_test_split(
    X_reg.reshape(-1, 1), y_reg, test_size=0.2, random_state=42
)

# Create regression decision trees with different depths
depths = [2, 5, 10, None]
fig, axes = plt.subplots(2, 2, figsize=(15, 12))
fig.suptitle('Regression Decision Trees with Different Depths', fontsize=16)

for i, depth in enumerate(depths):
    row = i // 2
    col = i % 2
    
    # Train model
    dt_regressor = DecisionTreeRegressor(
        max_depth=depth,
        min_samples_split=20,
        min_samples_leaf=10,
        random_state=42
    )
    dt_regressor.fit(X_train_reg, y_train_reg)
    
    # Predict
    y_pred_reg = dt_regressor.predict(X_test_reg)
    r2 = r2_score(y_test_reg, y_pred_reg)
    rmse = np.sqrt(mean_squared_error(y_test_reg, y_pred_reg))
    
    # Plot results
    X_plot = np.linspace(X_reg.min(), X_reg.max(), 100).reshape(-1, 1)
    y_plot = dt_regressor.predict(X_plot)
    
    axes[row, col].scatter(X_train_reg, y_train_reg, alpha=0.6, label='Training Data')
    axes[row, col].scatter(X_test_reg, y_test_reg, alpha=0.6, color='green', label='Test Data')
    axes[row, col].plot(X_plot, y_plot, color='red', linewidth=2, label='Decision Tree Prediction')
    
    depth_str = str(depth) if depth is not None else 'Unlimited'
    axes[row, col].set_title(f'Depth={depth_str}, R²={r2:.3f}, RMSE={rmse:.1f}')
    axes[row, col].set_xlabel('X')
    axes[row, col].set_ylabel('y')
    axes[row, col].legend()
    axes[row, col].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

6.5.2 Regression Tree Splitting Process

python
# Demonstrate regression tree splitting process
def demonstrate_regression_splits():
    """Demonstrate regression decision tree splitting process"""
    # Create simple 1D data
    np.random.seed(42)
    X_demo = np.linspace(0, 10, 50).reshape(-1, 1)
    y_demo = np.sin(X_demo.flatten()) + 0.1 * np.random.randn(50)
    
    # Train shallow decision tree
    dt_demo = DecisionTreeRegressor(max_depth=3, random_state=42)
    dt_demo.fit(X_demo, y_demo)
    
    # Get split points
    tree = dt_demo.tree_
    
    plt.figure(figsize=(12, 8))
    
    # Plot original data
    plt.scatter(X_demo, y_demo, alpha=0.6, color='blue', label='Training Data')
    
    # Plot prediction curve
    X_plot = np.linspace(0, 10, 200).reshape(-1, 1)
    y_plot = dt_demo.predict(X_plot)
    plt.plot(X_plot, y_plot, color='red', linewidth=2, label='Decision Tree Prediction')
    
    # Mark split points
    def get_split_points(node_id, depth=0):
        if tree.children_left[node_id] != tree.children_right[node_id]:
            split_value = tree.threshold[node_id]
            plt.axvline(x=split_value, color='green', linestyle='--', alpha=0.7)
            plt.text(split_value, plt.ylim()[1] - 0.1 * (depth + 1), 
                    f'Split {depth+1}', rotation=90, ha='right')
            
            get_split_points(tree.children_left[node_id], depth + 1)
            get_split_points(tree.children_right[node_id], depth + 1)
    
    get_split_points(0)
    
    plt.xlabel('X')
    plt.ylabel('y')
    plt.title('Regression Decision Tree Splitting Process')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()
    
    # Print split information
    print("Decision tree split information:")
    print(export_text(dt_demo, feature_names=['X']))

demonstrate_regression_splits()
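For regression, the default criterion (`'squared_error'`, named `'mse'` before scikit-learn 1.0) measures impurity by the spread of the targets rather than by class mixing. Each leaf predicts the mean of its samples, and the chosen split minimizes the children's total squared error. The brute-force sketch below does this for a single feature and checks the result against a depth-1 scikit-learn tree. `best_regression_split` is a hypothetical helper, not a library function.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

rng = np.random.default_rng(0)
X = np.linspace(0, 10, 40).reshape(-1, 1)
# Step-shaped target: about 1 below x = 4 and about 5 above, plus small noise
y = np.where(X.ravel() < 4, 1.0, 5.0) + 0.1 * rng.standard_normal(40)

def best_regression_split(x, y):
    """Return (threshold, SSE) for the midpoint split minimizing the children's squared error."""
    order = np.argsort(x)
    xs, ys = x[order], y[order]
    best_threshold, best_sse = None, np.inf
    for i in range(1, len(xs)):
        if xs[i] == xs[i - 1]:
            continue
        left, right = ys[:i], ys[i:]
        # Each child predicts its mean, so its error is the sum of squared deviations
        sse = np.sum((left - left.mean()) ** 2) + np.sum((right - right.mean()) ** 2)
        if sse < best_sse:
            best_threshold, best_sse = (xs[i - 1] + xs[i]) / 2.0, sse
    return best_threshold, best_sse

manual_threshold, _ = best_regression_split(X.ravel(), y)
stump = DecisionTreeRegressor(max_depth=1).fit(X, y)  # default criterion: squared error
print(f"manual threshold:  {manual_threshold:.4f}")
print(f"sklearn threshold: {stump.tree_.threshold[0]:.4f}")
```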

6.6 Decision Tree Hyperparameters

6.6.1 Main Hyperparameters Explanation

python
# Demonstrate impact of different hyperparameters
def compare_hyperparameters():
    """Compare impact of different hyperparameters on decision trees"""
    
    # Use iris dataset
    iris = load_iris()
    X_iris, y_iris = iris.data, iris.target
    X_train_iris, X_test_iris, y_train_iris, y_test_iris = train_test_split(
        X_iris, y_iris, test_size=0.2, random_state=42
    )
    
    # Define different hyperparameter combinations
    hyperparams = {
        'max_depth': [3, 5, 10, None],
        'min_samples_split': [2, 10, 20],
        'min_samples_leaf': [1, 5, 10],
        'criterion': ['gini', 'entropy']
    }
    
    results = []
    
    # Test impact of max_depth
    print("Impact of max_depth parameter:")
    print("Depth\tTraining Accuracy\tTest Accuracy\tLeaf Nodes")
    print("-" * 50)
    
    for depth in hyperparams['max_depth']:
        dt = DecisionTreeClassifier(max_depth=depth, random_state=42)
        dt.fit(X_train_iris, y_train_iris)
        
        train_acc = dt.score(X_train_iris, y_train_iris)
        test_acc = dt.score(X_test_iris, y_test_iris)
        n_leaves = dt.get_n_leaves()
        
        depth_str = str(depth) if depth is not None else 'Unlimited'
        print(f"{depth_str}\t{train_acc:.4f}\t\t{test_acc:.4f}\t\t{n_leaves}")
        
        results.append({
            'param': 'max_depth',
            'value': depth_str,
            'train_acc': train_acc,
            'test_acc': test_acc,
            'n_leaves': n_leaves
        })
    
    # Visualize impact of max_depth
    depths_numeric = [3, 5, 10, 20]  # Use 20 instead of None for visualization
    train_accs = []
    test_accs = []
    
    for depth in depths_numeric:
        dt = DecisionTreeClassifier(max_depth=depth if depth != 20 else None, random_state=42)
        dt.fit(X_train_iris, y_train_iris)
        train_accs.append(dt.score(X_train_iris, y_train_iris))
        test_accs.append(dt.score(X_test_iris, y_test_iris))
    
    plt.figure(figsize=(10, 6))
    plt.plot(depths_numeric, train_accs, 'o-', label='Training Accuracy', linewidth=2)
    plt.plot(depths_numeric, test_accs, 'o-', label='Test Accuracy', linewidth=2)
    plt.xlabel('Maximum Depth')
    plt.ylabel('Accuracy')
    plt.title('Impact of Decision Tree Depth on Performance')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.xticks(depths_numeric, ['3', '5', '10', 'Unlimited'])
    plt.show()

compare_hyperparameters()
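The function above only varies `max_depth`, the maximum number of levels (`None` grows the tree until its leaves are pure or other limits apply). The other entries in the `hyperparams` dictionary work as follows in scikit-learn:

- `min_samples_split`: the minimum number of samples a node must contain before a split is attempted (default 2).
- `min_samples_leaf`: the minimum number of samples each child must keep for a split to be accepted (default 1).
- `criterion`: the impurity measure, `'gini'` or `'entropy'`. Recent versions also accept `'log_loss'`, which is equivalent to `'entropy'`.

The sketch below shows the effect of `min_samples_leaf` on the same iris split.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

leaf_counts = {}
for min_leaf in [1, 5, 10, 20]:
    dt = DecisionTreeClassifier(min_samples_leaf=min_leaf, random_state=42)
    dt.fit(X_train, y_train)
    leaf_counts[min_leaf] = dt.get_n_leaves()
    print(f"min_samples_leaf={min_leaf:2d}  leaves={dt.get_n_leaves():2d}  "
          f"train={dt.score(X_train, y_train):.3f}  test={dt.score(X_test, y_test):.3f}")
```

Every leaf must hold at least `min_samples_leaf` training samples, so the tree can have at most `n_train // min_samples_leaf` leaves. Larger values therefore act as a direct brake on tree size.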

6.6.2 Grid Search Optimization

python
# Use grid search to optimize hyperparameters
param_grid = {
    'max_depth': [3, 5, 7, 10, None],
    'min_samples_split': [2, 5, 10, 20],
    'min_samples_leaf': [1, 2, 5, 10],
    'criterion': ['gini', 'entropy']
}

# Use wine dataset
wine = load_wine()
X_wine, y_wine = wine.data, wine.target
X_train_wine, X_test_wine, y_train_wine, y_test_wine = train_test_split(
    X_wine, y_wine, test_size=0.2, random_state=42
)

# Grid search
grid_search = GridSearchCV(
    DecisionTreeClassifier(random_state=42),
    param_grid,
    cv=5,
    scoring='accuracy',
    n_jobs=-1
)

print("Performing grid search...")
grid_search.fit(X_train_wine, y_train_wine)

print("Grid search results:")
print(f"Best parameters: {grid_search.best_params_}")
print(f"Best cross-validation score: {grid_search.best_score_:.4f}")

# Test best model
best_dt = grid_search.best_estimator_
test_accuracy = best_dt.score(X_test_wine, y_test_wine)
print(f"Test set accuracy: {test_accuracy:.4f}")

# Visualize grid search results
results_df = pd.DataFrame(grid_search.cv_results_)

# Select several important parameters for visualization
fig, axes = plt.subplots(1, 2, figsize=(15, 6))

# max_depth vs performance (label None explicitly; groupby would otherwise drop those rows)
depth_labels = results_df['param_max_depth'].map(lambda d: 'None' if pd.isna(d) else str(d))
depth_results = results_df.groupby(depth_labels, sort=False)['mean_test_score'].mean()
axes[0].bar(range(len(depth_results)), depth_results.values)
axes[0].set_xticks(range(len(depth_results)))
axes[0].set_xticklabels(depth_results.index)
axes[0].set_xlabel('Maximum Depth')
axes[0].set_ylabel('Mean Cross-Validation Score')
axes[0].set_title('Impact of Maximum Depth on Performance')

# criterion vs performance
criterion_results = results_df.groupby('param_criterion')['mean_test_score'].mean()
axes[1].bar(criterion_results.index, criterion_results.values, color=['orange', 'green'])
axes[1].set_xlabel('Splitting Criterion')
axes[1].set_ylabel('Mean Cross-Validation Score')
axes[1].set_title('Impact of Splitting Criterion on Performance')

plt.tight_layout()
plt.show()

6.7 Overfitting and Pruning

6.7.1 Overfitting Demonstration

python
def demonstrate_overfitting():
    """Demonstrate decision tree overfitting phenomenon"""
    
    # Create noisy data
    np.random.seed(42)
    X_noise = np.random.uniform(-3, 3, 200).reshape(-1, 1)
    y_noise = np.sin(X_noise.flatten()) + 0.3 * np.random.randn(200)
    
    X_train_noise, X_test_noise, y_train_noise, y_test_noise = train_test_split(
        X_noise, y_noise, test_size=0.3, random_state=42
    )
    
    # Train decision trees with different complexities
    complexities = [
        {'max_depth': 2, 'min_samples_leaf': 20},  # Simple
        {'max_depth': 5, 'min_samples_leaf': 10},  # Medium
        {'max_depth': None, 'min_samples_leaf': 1}  # Complex
    ]
    
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    
    for i, params in enumerate(complexities):
        dt = DecisionTreeRegressor(random_state=42, **params)
        dt.fit(X_train_noise, y_train_noise)
        
        # Calculate performance
        train_score = dt.score(X_train_noise, y_train_noise)
        test_score = dt.score(X_test_noise, y_test_noise)
        
        # Plot results
        X_plot = np.linspace(-3, 3, 300).reshape(-1, 1)
        y_plot = dt.predict(X_plot)
        
        axes[i].scatter(X_train_noise, y_train_noise, alpha=0.6, label='Training Data')
        axes[i].scatter(X_test_noise, y_test_noise, alpha=0.6, color='green', label='Test Data')
        axes[i].plot(X_plot, y_plot, color='red', linewidth=2, label='Decision Tree Prediction')
        
        # True function
        y_true = np.sin(X_plot.flatten())
        axes[i].plot(X_plot, y_true, color='black', linestyle='--', alpha=0.7, label='True Function')
        
        complexity_name = ['Simple Model', 'Medium Complexity', 'Complex Model'][i]
        axes[i].set_title(f'{complexity_name}\nTraining R²={train_score:.3f}, Test R²={test_score:.3f}')
        axes[i].set_xlabel('X')
        axes[i].set_ylabel('y')
        axes[i].legend()
        axes[i].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()

demonstrate_overfitting()

6.7.2 Learning Curve Analysis

python
from sklearn.model_selection import learning_curve

def plot_learning_curves():
    """Plot learning curves to analyze overfitting"""
    
    # Use wine dataset
    wine = load_wine()
    X, y = wine.data, wine.target
    
    # Compare models with different complexities
    models = {
        'Simple Decision Tree': DecisionTreeClassifier(max_depth=3, min_samples_leaf=10, random_state=42),
        'Complex Decision Tree': DecisionTreeClassifier(max_depth=None, min_samples_leaf=1, random_state=42),
        'Best Decision Tree': grid_search.best_estimator_
    }
    
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    
    for i, (name, model) in enumerate(models.items()):
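        # shuffle=True matters: the wine rows are sorted by class, so unshuffled
        # small training subsets would contain only class 0
        train_sizes, train_scores, val_scores = learning_curve(
            model, X, y, cv=5, n_jobs=-1,
            train_sizes=np.linspace(0.1, 1.0, 10),
            scoring='accuracy',
            shuffle=True, random_state=42
        )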
        
        train_mean = np.mean(train_scores, axis=1)
        train_std = np.std(train_scores, axis=1)
        val_mean = np.mean(val_scores, axis=1)
        val_std = np.std(val_scores, axis=1)
        
        axes[i].plot(train_sizes, train_mean, 'o-', color='blue', label='Training Score')
        axes[i].fill_between(train_sizes, train_mean - train_std, train_mean + train_std, 
                           alpha=0.1, color='blue')
        
        axes[i].plot(train_sizes, val_mean, 'o-', color='red', label='Validation Score')
        axes[i].fill_between(train_sizes, val_mean - val_std, val_mean + val_std, 
                           alpha=0.1, color='red')
        
        axes[i].set_xlabel('Training Samples')
        axes[i].set_ylabel('Accuracy')
        axes[i].set_title(f'{name} Learning Curve')
        axes[i].legend()
        axes[i].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()

plot_learning_curves()
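The examples above only limit growth in advance, which is pre-pruning via `max_depth` and `min_samples_leaf`. scikit-learn also supports post-pruning through minimal cost-complexity pruning:

- `cost_complexity_pruning_path` returns the effective alpha values at which subtrees of a fully grown tree are pruned away.
- The `ccp_alpha` parameter refits a tree pruned at a given alpha.

The sketch below applies this to the wine data.

```python
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_wine(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42, stratify=y)

# Effective alphas at which subtrees of the fully grown tree get pruned
path = DecisionTreeClassifier(random_state=42).cost_complexity_pruning_path(X_train, y_train)
ccp_alphas = path.ccp_alphas[:-1]  # the largest alpha prunes everything down to the root

results = []
for alpha in ccp_alphas:
    dt = DecisionTreeClassifier(random_state=42, ccp_alpha=alpha).fit(X_train, y_train)
    results.append((alpha, dt.get_n_leaves(),
                    dt.score(X_train, y_train), dt.score(X_test, y_test)))

for alpha, leaves, train_acc, test_acc in results:
    print(f"alpha={alpha:.4f}  leaves={leaves:2d}  train={train_acc:.3f}  test={test_acc:.3f}")
```

Larger alphas give smaller, nested subtrees. In practice `ccp_alpha` would be chosen by cross-validation, for example by adding it to a `GridSearchCV` grid, rather than by the test-set score printed here.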

6.8 Practical Application Cases

6.8.1 Customer Churn Prediction

python
# Create customer churn prediction dataset
def create_customer_churn_dataset():
    """Create customer churn prediction dataset"""
    np.random.seed(42)
    n_samples = 1000
    
    # Generate features
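    age = np.clip(np.random.normal(40, 15, n_samples), 18, 80)  # Keep ages in a plausible range
    monthly_charges = np.clip(np.random.normal(70, 20, n_samples), 20, None)
    tenure_months = np.clip(np.random.normal(24, 12, n_samples), 1, None)  # Average about 24 months, at least 1
    total_charges = monthly_charges * tenure_months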
    contract_length = np.random.choice([1, 12, 24], n_samples, p=[0.3, 0.4, 0.3])
    tech_support = np.random.choice([0, 1], n_samples, p=[0.6, 0.4])
    online_security = np.random.choice([0, 1], n_samples, p=[0.5, 0.5])
    
    # Generate target variable (churn probability based on features)
    churn_prob = (
        0.01 * (age < 30).astype(int) +  # Young customers more likely to churn
        0.02 * (monthly_charges > 80).astype(int) +  # High fee customers more likely to churn
        0.03 * (contract_length == 1).astype(int) +  # Short contract customers more likely to churn
        -0.02 * tech_support +  # Tech support reduces churn
        -0.01 * online_security +  # Online security reduces churn
        0.1  # Base churn rate
    )
    
    churn = np.random.binomial(1, np.clip(churn_prob, 0, 1), n_samples)
    
    data = pd.DataFrame({
        'Age': age,
        'Monthly Charges': monthly_charges,
        'Total Charges': total_charges,
        'Contract Length': contract_length,
        'Tech Support': tech_support,
        'Online Security': online_security,
        'Churn': churn
    })
    
    return data

# Create dataset
churn_data = create_customer_churn_dataset()

print("Customer churn dataset information:")
churn_data.info()  # info() prints its summary itself and returns None
print("\nChurn rate:")
print(churn_data['Churn'].value_counts(normalize=True))

# Feature analysis
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
fig.suptitle('Customer Feature and Churn Relationship Analysis', fontsize=16)

features = ['Age', 'Monthly Charges', 'Total Charges', 'Contract Length', 'Tech Support', 'Online Security']

for i, feature in enumerate(features):
    row = i // 3
    col = i % 3
    
    if feature in ['Contract Length', 'Tech Support', 'Online Security']:
        # Categorical features (Contract Length only takes the values 1, 12 and 24)
        churn_data.groupby([feature, 'Churn']).size().unstack().plot(kind='bar', ax=axes[row, col])
        axes[row, col].set_title(f'{feature} vs Churn')
        axes[row, col].set_xlabel(feature)
        axes[row, col].tick_params(axis='x', rotation=0)
    else:
        # Numerical features
        for churn_status in [0, 1]:
            data_subset = churn_data[churn_data['Churn'] == churn_status][feature]
            axes[row, col].hist(data_subset, alpha=0.6, 
                              label=f'Churn={churn_status}', bins=20)
        axes[row, col].set_title(f'{feature} Distribution')
        axes[row, col].set_xlabel(feature)
        axes[row, col].legend()

plt.tight_layout()
plt.show()

6.8.2 Building Churn Prediction Model

python
# Prepare data
X_churn = churn_data.drop('Churn', axis=1)
y_churn = churn_data['Churn']

X_train_churn, X_test_churn, y_train_churn, y_test_churn = train_test_split(
    X_churn, y_churn, test_size=0.2, random_state=42, stratify=y_churn
)

# Train decision tree model
churn_dt = DecisionTreeClassifier(
    max_depth=5,
    min_samples_split=50,
    min_samples_leaf=20,
    random_state=42
)

churn_dt.fit(X_train_churn, y_train_churn)

# Predict and evaluate
y_pred_churn = churn_dt.predict(X_test_churn)
y_pred_proba_churn = churn_dt.predict_proba(X_test_churn)

print("Customer churn prediction model evaluation:")
print(f"Accuracy: {accuracy_score(y_test_churn, y_pred_churn):.4f}")
print("\nDetailed classification report:")
print(classification_report(y_test_churn, y_pred_churn, 
                          target_names=['Not Churned', 'Churned']))

# Confusion matrix
cm = confusion_matrix(y_test_churn, y_pred_churn)
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=['Not Churned', 'Churned'],
            yticklabels=['Not Churned', 'Churned'])
plt.title('Customer Churn Prediction Confusion Matrix')
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.show()

# Feature importance
feature_importance = churn_dt.feature_importances_
importance_df = pd.DataFrame({
    'feature': X_churn.columns,
    'importance': feature_importance
}).sort_values('importance', ascending=False)

plt.figure(figsize=(10, 6))
plt.barh(importance_df['feature'], importance_df['importance'])
plt.title('Customer Churn Prediction Model Feature Importance')
plt.xlabel('Importance')
plt.gca().invert_yaxis()
plt.tight_layout()
plt.show()

print("Feature importance ranking:")
for _, row in importance_df.iterrows():
    print(f"{row['feature']}: {row['importance']:.4f}")

6.8.3 Decision Rule Interpretation

python
# Visualize decision tree
plt.figure(figsize=(20, 12))
plot_tree(churn_dt, 
          feature_names=X_churn.columns,
          class_names=['Not Churned', 'Churned'],
          filled=True,
          rounded=True,
          fontsize=8)
plt.title('Customer Churn Prediction Decision Tree')
plt.show()

# Extract decision rules
def extract_decision_rules(tree, feature_names, class_names):
    """Extract decision tree rules"""
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != -2
        else "undefined!"
        for i in tree_.feature
    ]
    
    def recurse(node, depth, conditions):
        indent = "  " * depth
        if tree_.feature[node] != -2:  # -2 (TREE_UNDEFINED) marks a leaf
            name = feature_name[node]
            threshold = tree_.threshold[node]
            
            print(f"{indent}if {name} <= {threshold:.2f}:")
            recurse(tree_.children_left[node], depth + 1,
                    conditions + [f"{name} <= {threshold:.2f}"])
            print(f"{indent}else:  # if {name} > {threshold:.2f}")
            recurse(tree_.children_right[node], depth + 1,
                    conditions + [f"{name} > {threshold:.2f}"])
        else:
            # Leaf node
            value = tree_.value[node][0]
            predicted_class = class_names[np.argmax(value)]
            confidence = np.max(value) / np.sum(value)
            print(f"{indent}Prediction: {predicted_class} (Confidence: {confidence:.3f})")
            # Join the conditions on the path (str.strip(' AND ') would also eat letters such as the 'A' in 'Age')
            print(f"{indent}Rule: {' AND '.join(conditions)}")
            print()
    
    # Start the traversal at the root node
    recurse(0, 0, [])

print("Decision tree rules:")
extract_decision_rules(churn_dt, X_churn.columns, ['Not Churned', 'Churned'])

# Prediction examples
sample_customers = pd.DataFrame({
    'Age': [25, 45, 60],
    'Monthly Charges': [90, 50, 30],
    'Total Charges': [1800, 1200, 720],
    'Contract Length': [1, 12, 24],
    'Tech Support': [0, 1, 1],
    'Online Security': [0, 1, 1]
})

predictions = churn_dt.predict(sample_customers)
probabilities = churn_dt.predict_proba(sample_customers)

print("Customer churn prediction examples:")
for i, (_, customer) in enumerate(sample_customers.iterrows()):
    print(f"\nCustomer {i+1}:")
    print(f"  Features: {dict(customer)}")
    print(f"  Prediction: {'Churned' if predictions[i] == 1 else 'Not Churned'}")
    print(f"  Churn Probability: {probabilities[i][1]:.3f}")

6.9 Decision Trees vs Other Algorithms

6.9.1 Algorithm Comparison

python
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB

# Compare multiple algorithms
algorithms = {
    'Decision Tree': DecisionTreeClassifier(max_depth=5, random_state=42),
    'Logistic Regression': LogisticRegression(random_state=42),
    'Support Vector Machine': SVC(random_state=42),  # probability=True is unused here and would add internal CV to the training time
    'K-Nearest Neighbors': KNeighborsClassifier(n_neighbors=5),
    'Naive Bayes': GaussianNB()
}

# Use iris dataset for comparison
iris = load_iris()
X_iris, y_iris = iris.data, iris.target
X_train_iris, X_test_iris, y_train_iris, y_test_iris = train_test_split(
    X_iris, y_iris, test_size=0.2, random_state=42
)

import time

results = {}

# Interpretability rating (subjective)
interpretability = {
    'Decision Tree': 'High',
    'Logistic Regression': 'Medium',
    'Support Vector Machine': 'Low',
    'K-Nearest Neighbors': 'Medium',
    'Naive Bayes': 'Medium'
}

print("Algorithm performance comparison:")
print(f"{'Algorithm':<24}{'Training Time (s)':<18}{'Accuracy':<10}Interpretability")
print("-" * 68)

for name, algorithm in algorithms.items():
    # Training time (perf_counter is a high-resolution timer; times this small are noisy)
    start_time = time.perf_counter()
    algorithm.fit(X_train_iris, y_train_iris)
    training_time = time.perf_counter() - start_time
    
    # Test-set accuracy
    accuracy = algorithm.score(X_test_iris, y_test_iris)
    
    results[name] = {
        'training_time': training_time,
        'accuracy': accuracy,
        'interpretability': interpretability[name]
    }
    
    print(f"{name:<24}{training_time:<18.4f}{accuracy:<10.4f}{interpretability[name]}")

# Visualize comparison
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Accuracy comparison
names = list(algorithms.keys())  # convert the dict view to a list for plotting
accuracies = [results[name]['accuracy'] for name in names]
axes[0].bar(names, accuracies, color='skyblue')
axes[0].set_title('Algorithm Accuracy Comparison')
axes[0].set_ylabel('Accuracy')
axes[0].tick_params(axis='x', rotation=45)

# Training time comparison
times = [results[name]['training_time'] for name in names]
axes[1].bar(names, times, color='lightcoral')
axes[1].set_title('Algorithm Training Time Comparison')
axes[1].set_ylabel('Training Time (seconds)')
axes[1].tick_params(axis='x', rotation=45)

plt.tight_layout()
plt.show()
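With only 30 test samples, one misclassified flower moves accuracy by about 0.033. The single-split comparison above therefore cannot reliably rank these models. The sketch below is a more stable comparison using stratified 5-fold cross-validation with the same models. LogisticRegression's `max_iter` is raised to 1000 here, as an assumption, to avoid convergence warnings.

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
models = {
    'Decision Tree': DecisionTreeClassifier(max_depth=5, random_state=42),
    'Logistic Regression': LogisticRegression(max_iter=1000),
    'Support Vector Machine': SVC(),
    'K-Nearest Neighbors': KNeighborsClassifier(n_neighbors=5),
    'Naive Bayes': GaussianNB(),
}
# Shuffle before splitting: the iris rows are sorted by class
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

cv_scores = {}
for name, model in models.items():
    scores = cross_val_score(model, X, y, cv=cv, scoring='accuracy')
    cv_scores[name] = scores
    print(f"{name:<24} {scores.mean():.3f} +/- {scores.std():.3f}")
```

If the mean accuracies differ by less than their standard deviations, this data gives no clear winner. In that case, interpretability can reasonably decide.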

6.10 Practice Exercises

Exercise 1: Basic Decision Trees

  1. Use make_classification to generate a binary classification dataset
  2. Train a decision tree and visualize the decision boundary
  3. Analyze the impact of different depths on model performance

Exercise 2: Regression Decision Trees

  1. Create a regression dataset with nonlinear relationships
  2. Compare decision tree regression with linear regression performance
  3. Analyze how decision trees handle nonlinear relationships

Exercise 3: Hyperparameter Optimization

  1. Use grid search to optimize decision tree hyperparameters
  2. Analyze the impact of different hyperparameters on overfitting
  3. Plot validation curves to analyze optimal parameters

Exercise 4: Practical Application

  1. Choose a real dataset (such as Titanic survival prediction)
  2. Build a decision tree classification model
  3. Explain the model's decision rules and analyze feature importance

6.11 Summary

This chapter covered the main aspects of decision tree algorithms:

Core Concepts

  • Decision tree principles: Information gain, Gini impurity, splitting criteria
  • Tree construction: Recursive splitting, stopping conditions, pruning
  • Classification and regression: Decision tree applications for different task types

Main Techniques

  • Classification decision trees: Handle discrete target variables
  • Regression decision trees: Handle continuous target variables
  • Hyperparameter tuning: Depth control, sample number limits
  • Model visualization: Tree structure diagrams, decision boundaries

Practical Skills

  • Overfitting control: Pruning techniques, complexity control
  • Feature importance: Feature selection based on splits
  • Model interpretation: Rule extraction, decision path analysis
  • Real applications: Customer churn prediction, and similar tasks such as medical diagnosis

Key Points

  • Decision trees have good interpretability, suitable for scenarios requiring understanding of decision processes
  • Prone to overfitting, need to avoid it through pruning and parameter control
  • Sensitive to small changes in the data (high variance); ensemble methods such as bagging and random forests exploit this by averaging many different trees
  • Can automatically perform feature selection and handle nonlinear relationships

6.12 Next Steps

You have now covered decision trees, one of the fundamental machine learning algorithms. In the next chapter, Random Forest and Ensemble Methods, we will learn how to build more powerful and stable models by combining multiple decision trees.


Chapter Highlights:

  • ✅ Understood the construction principles and splitting criteria of decision trees
  • ✅ Mastered the implementation of classification and regression decision trees
  • ✅ Learned to control overfitting and perform hyperparameter optimization
  • ✅ Understood decision tree visualization and interpretation methods
  • ✅ Mastered feature importance analysis and practical applications
  • ✅ Can build interpretable prediction models

Content is for learning and research only.