Classification Metrics
This section covers the various metrics used to evaluate classification models, including their mathematical foundations, implementation, and practical considerations.
Basic Metrics
1. Confusion Matrix
- True Positives (TP): Correctly predicted positive cases
- True Negatives (TN): Correctly predicted negative cases
- False Positives (FP): Incorrectly predicted positive cases
- False Negatives (FN): Incorrectly predicted negative cases
```python
import numpy as np
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

def plot_confusion_matrix(y_true, y_pred, labels=None):
    """Plot confusion matrix with percentages"""
    # Compute confusion matrix
    cm = confusion_matrix(y_true, y_pred)

    # Convert to percentages
    cm_percent = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] * 100

    # Create heatmap
    plt.figure(figsize=(10, 8))
    sns.heatmap(
        cm_percent,
        annot=True,
        fmt='.1f',
        cmap='Blues',
        xticklabels=labels,
        yticklabels=labels
    )
    plt.title('Confusion Matrix (%)')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    return plt.gcf()
```
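A minimal usage sketch, assuming `y_true` and `y_pred` are already-computed label arrays (the values below are purely illustrative):

```python
# Hypothetical binary labels, for illustration only
y_true = [0, 1, 1, 0, 1, 0, 1, 1]
y_pred = [0, 1, 0, 0, 1, 1, 1, 1]

fig = plot_confusion_matrix(y_true, y_pred, labels=['negative', 'positive'])
fig.savefig('confusion_matrix.png')  # or plt.show() for interactive use
```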
2. Accuracy, Precision, Recall, Specificity, and F1
```python
def compute_basic_metrics(y_true, y_pred):
    """Compute basic classification metrics for a binary problem."""
    # Assumes binary labels so the 2x2 confusion matrix unravels to tn, fp, fn, tp
    cm = confusion_matrix(y_true, y_pred)
    tn, fp, fn, tp = cm.ravel()

    # Accuracy: fraction of all predictions that are correct
    accuracy = (tp + tn) / (tp + tn + fp + fn)

    # Precision: fraction of predicted positives that are truly positive
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0

    # Recall (sensitivity): fraction of actual positives that are recovered
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0

    # Specificity: fraction of actual negatives that are recovered
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0

    # F1 score: harmonic mean of precision and recall
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'specificity': specificity,
        'f1_score': f1
    }
```
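For reference, the quantities computed above correspond to the standard definitions:

```latex
\begin{aligned}
\text{Accuracy}    &= \frac{TP + TN}{TP + TN + FP + FN} \\
\text{Precision}   &= \frac{TP}{TP + FP} \\
\text{Recall}      &= \frac{TP}{TP + FN} \\
\text{Specificity} &= \frac{TN}{TN + FP} \\
F_1                &= \frac{2 \cdot \text{Precision} \cdot \text{Recall}}{\text{Precision} + \text{Recall}}
\end{aligned}
```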
Advanced Metrics
1. ROC and AUC
```python
from sklearn.metrics import roc_curve, auc

def plot_roc_curve(y_true, y_prob):
    """Plot ROC curve and compute AUC"""
    # Compute ROC curve
    fpr, tpr, thresholds = roc_curve(y_true, y_prob)
    roc_auc = auc(fpr, tpr)

    # Plot
    plt.figure(figsize=(10, 8))
    plt.plot(fpr, tpr, color='darkorange', lw=2,
             label=f'ROC curve (AUC = {roc_auc:.2f})')
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic (ROC) Curve')
    plt.legend(loc="lower right")
    plt.grid(True)
    return plt.gcf(), roc_auc
```
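If only the scalar AUC is needed (without the plot), scikit-learn's `roc_auc_score` computes it directly from labels and scores; a minimal sketch:

```python
from sklearn.metrics import roc_auc_score

# y_prob holds scores for the positive class, e.g. model.predict_proba(X)[:, 1]
roc_auc = roc_auc_score(y_true, y_prob)
```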
2. Precision-Recall Curve
```python
from sklearn.metrics import precision_recall_curve, average_precision_score

def plot_precision_recall_curve(y_true, y_prob):
    """Plot Precision-Recall curve and compute average precision"""
    # Compute PR curve
    precision, recall, thresholds = precision_recall_curve(y_true, y_prob)
    avg_precision = average_precision_score(y_true, y_prob)

    # Plot
    plt.figure(figsize=(10, 8))
    plt.plot(recall, precision, color='darkorange', lw=2,
             label=f'PR curve (AP = {avg_precision:.2f})')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title('Precision-Recall Curve')
    plt.legend(loc="lower left")
    plt.grid(True)
    return plt.gcf(), avg_precision
```
3. Log Loss
```python
from sklearn.metrics import log_loss

def compute_log_loss(y_true, y_prob):
    """Compute log loss (cross-entropy)"""
    return log_loss(y_true, y_prob)
```
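For binary labels y_i in {0, 1} and predicted positive-class probabilities p_i, this is the average negative log-likelihood:

```latex
\text{LogLoss} = -\frac{1}{N} \sum_{i=1}^{N} \Bigl[ y_i \log p_i + (1 - y_i) \log (1 - p_i) \Bigr]
```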
Multiclass Metrics
1. Multiclass Confusion Matrix
```python
def plot_multiclass_confusion_matrix(y_true, y_pred, labels):
    """Plot multiclass confusion matrix"""
    cm = confusion_matrix(y_true, y_pred)
    cm_percent = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] * 100

    plt.figure(figsize=(12, 10))
    sns.heatmap(
        cm_percent,
        annot=True,
        fmt='.1f',
        cmap='Blues',
        xticklabels=labels,
        yticklabels=labels
    )
    plt.title('Multiclass Confusion Matrix (%)')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    return plt.gcf()
```
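As an alternative to the manual normalization above, recent scikit-learn releases can normalize the confusion matrix directly (this assumes a version that supports the `normalize` argument):

```python
# Row-normalize so each row sums to 1 (fractions per true class)
cm_norm = confusion_matrix(y_true, y_pred, normalize='true')
```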
2. Per-Class Metrics
```python
from sklearn.metrics import classification_report

def compute_per_class_metrics(y_true, y_pred, labels=None):
    """Compute per-class precision, recall, and F1-score"""
    return classification_report(
        y_true,
        y_pred,
        labels=labels,
        output_dict=True
    )
```
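Beyond per-class values, multiclass results are often summarized with an averaging strategy; a brief sketch using scikit-learn's `f1_score` (the choice of average is an assumption to adapt to the problem):

```python
from sklearn.metrics import f1_score

macro_f1 = f1_score(y_true, y_pred, average='macro')        # unweighted mean over classes
weighted_f1 = f1_score(y_true, y_pred, average='weighted')  # weighted by class support
```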
Calibration Metrics
1. Calibration Curves
```python
from sklearn.calibration import calibration_curve

def plot_calibration_curve(y_true, y_prob, n_bins=10):
    """Plot calibration curve"""
    prob_true, prob_pred = calibration_curve(y_true, y_prob, n_bins=n_bins)

    plt.figure(figsize=(10, 8))
    plt.plot(prob_pred, prob_true, marker='o', label='Model')
    plt.plot([0, 1], [0, 1], linestyle='--', label='Perfect calibration')
    plt.xlabel('Mean predicted probability')
    plt.ylabel('Fraction of positives')
    plt.title('Calibration Curve')
    plt.legend()
    plt.grid(True)
    return plt.gcf()
```
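Note that the calibration curve expects probability scores rather than hard class predictions. A minimal usage sketch, assuming a fitted scikit-learn-style classifier `model` and a held-out set `X_val`, `y_val` (hypothetical names):

```python
# Positive-class probabilities from a fitted classifier (model, X_val, y_val are hypothetical)
y_prob = model.predict_proba(X_val)[:, 1]
fig = plot_calibration_curve(y_val, y_prob, n_bins=10)
```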
Best Practices
1. Metric Selection
- Choose metrics based on problem characteristics
- Consider class imbalance
- Account for business requirements
- Use multiple complementary metrics
2. Threshold Selection
- Optimize threshold based on business needs
- Consider cost of different types of errors
- Use ROC and PR curves for threshold selection (see the sketch after this list)
- Validate threshold on holdout set
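One common approach is sketched below: sweep the thresholds returned by `precision_recall_curve` and keep the one that maximizes F1 on a validation set (maximizing F1 is an assumption here; substitute whatever cost-weighted objective the business requires):

```python
import numpy as np
from sklearn.metrics import precision_recall_curve

def pick_threshold_max_f1(y_true, y_prob):
    """Return the decision threshold that maximizes F1 on the given (validation) data."""
    precision, recall, thresholds = precision_recall_curve(y_true, y_prob)
    # precision/recall have one more entry than thresholds; drop the final point
    f1 = 2 * precision[:-1] * recall[:-1] / np.clip(precision[:-1] + recall[:-1], 1e-12, None)
    return thresholds[np.argmax(f1)]
```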
3. Reporting
- Include confidence intervals (e.g. bootstrapped, as sketched below)
- Report all relevant metrics
- Provide visual representations
- Document metric definitions and assumptions
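A minimal bootstrap sketch for attaching a confidence interval to a metric (the metric, number of resamples, and confidence level are assumptions to adjust):

```python
import numpy as np
from sklearn.metrics import f1_score

def bootstrap_metric_ci(y_true, y_pred, metric=f1_score, n_boot=1000, alpha=0.05, seed=0):
    """Percentile bootstrap confidence interval for a classification metric."""
    rng = np.random.default_rng(seed)
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    scores = []
    for _ in range(n_boot):
        idx = rng.integers(0, len(y_true), size=len(y_true))  # resample with replacement
        scores.append(metric(y_true[idx], y_pred[idx]))
    lower, upper = np.percentile(scores, [100 * alpha / 2, 100 * (1 - alpha / 2)])
    return lower, upper
```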
4. Common Pitfalls
- Using accuracy for imbalanced datasets
- Ignoring probability calibration
- Not considering business context
- Overfitting to evaluation metrics