Model Interpretability

Techniques and methods for understanding and interpreting machine learning models, from local explanations of individual predictions to global views of model behavior.

Local Interpretability

LIME Implementation

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.linear_model import Ridge
from sklearn.metrics.pairwise import cosine_similarity

class LIMEExplainer:
    def __init__(self, model, num_samples=5000):
        self.model = model
        self.num_samples = num_samples
        
    def generate_samples(self, x):
        # Generate perturbed samples around x
        samples = np.random.normal(
            loc=x,
            scale=0.1,
            size=(self.num_samples, len(x))
        )
        
        # Get the black-box model's predictions for the perturbed samples
        predictions = self.model.predict(samples)
        
        # Weight each sample by its proximity to x (cosine similarity,
        # clipped so the weights passed to Ridge are non-negative)
        weights = cosine_similarity(samples, x.reshape(1, -1)).ravel()
        weights = np.clip(weights, 0.0, None)
        
        return samples, predictions, weights
    
    def explain_instance(self, x, feature_names=None):
        if feature_names is None:
            feature_names = [f'feature_{i}' for i in range(len(x))]
        
        # Generate perturbed samples and their proximity weights
        samples, predictions, weights = self.generate_samples(x)
        
        # Fit an interpretable surrogate model, weighting samples by proximity
        interpreter = Ridge(alpha=1.0)
        interpreter.fit(
            samples,
            predictions,
            sample_weight=weights
        )
        
        # Get feature importance
        importance = dict(zip(feature_names, interpreter.coef_))
        
        return {
            'importance': importance,
            'intercept': interpreter.intercept_,
            'r2_score': interpreter.score(samples, predictions)
        }

def plot_lime_explanation(explanation, top_k=10):
    importance = pd.DataFrame(
        explanation['importance'].items(),
        columns=['Feature', 'Importance']
    )
    # Rank features by the magnitude of their (signed) contribution
    importance = importance.reindex(
        importance['Importance'].abs().nlargest(top_k).index
    )
    
    plt.figure(figsize=(10, 6))
    plt.barh(
        importance['Feature'],
        importance['Importance']
    )
    plt.title('LIME Feature Importance')
    return plt
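
A minimal usage sketch of the explainer above. The RandomForestRegressor and the synthetic data are placeholders, not part of the implementation:

from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor

X, y = make_regression(n_samples=500, n_features=8, random_state=0)
model = RandomForestRegressor(random_state=0).fit(X, y)

explainer = LIMEExplainer(model)
explanation = explainer.explain_instance(X[0])
plot_lime_explanation(explanation).show()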

SHAP Values

import shap

class SHAPExplainer:
    def __init__(self, model, background_data):
        self.model = model
        # KernelExplainer needs a background dataset to estimate expected
        # values; subsample large datasets to keep explanation time manageable
        if len(background_data) > 100:
            background_data = shap.sample(background_data, 100)
        self.explainer = shap.KernelExplainer(
            model.predict,
            background_data
        )
    
    def explain_instance(self, x):
        # For multi-class models shap_values is a list with one array per class
        shap_values = self.explainer.shap_values(x)
        
        return {
            'shap_values': shap_values,
            'expected_value': self.explainer.expected_value
        }
    
    def explain_dataset(self, X):
        return self.explainer.shap_values(X)
    
    def plot_summary(self, X):
        shap_values = self.explain_dataset(X)
        shap.summary_plot(shap_values, X)
    
    def plot_dependence(self, X, feature_idx):
        shap_values = self.explain_dataset(X)
        shap.dependence_plot(feature_idx, shap_values, X)
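
A short usage sketch; the model and data are placeholders. For tree ensembles, shap.TreeExplainer is a much faster alternative to the model-agnostic KernelExplainer used here:

import pandas as pd
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor

X, y = make_regression(n_samples=300, n_features=6, random_state=0)
X = pd.DataFrame(X, columns=[f'feature_{i}' for i in range(6)])
model = RandomForestRegressor(random_state=0).fit(X, y)

explainer = SHAPExplainer(model, background_data=X)
single = explainer.explain_instance(X.iloc[[0]])
explainer.plot_summary(X.iloc[:25])  # KernelExplainer is slow on many rows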

Global Interpretability

Feature Importance Analysis

class FeatureImportanceAnalyzer:
    def __init__(self, model, feature_names=None):
        self.model = model
        self.feature_names = feature_names
    
    def get_importance(self):
        if hasattr(self.model, 'feature_importances_'):
            importance = self.model.feature_importances_
        elif hasattr(self.model, 'coef_'):
            # Average absolute coefficients across outputs/classes if needed
            importance = np.abs(np.atleast_2d(self.model.coef_)).mean(axis=0)
        else:
            raise ValueError("Model doesn't provide feature importance")
        
        if self.feature_names is None:
            self.feature_names = [f'feature_{i}' for i in range(len(importance))]
        
        return pd.DataFrame({
            'feature': self.feature_names,
            'importance': importance
        }).sort_values('importance', ascending=False)
    
    def plot_importance(self, top_k=10):
        importance = self.get_importance()
        importance = importance.nlargest(top_k, 'importance')
        
        plt.figure(figsize=(10, 6))
        plt.barh(
            importance['feature'],
            importance['importance']
        )
        plt.title('Feature Importance')
        return plt

def permutation_importance(model, X, y, n_repeats=10):
    baseline_score = model.score(X, y)
    importance = {}
    
    for feature in X.columns:
        scores = []
        for _ in range(n_repeats):
            # Permute feature
            X_permuted = X.copy()
            X_permuted[feature] = np.random.permutation(X_permuted[feature])
            
            # Calculate score drop
            permuted_score = model.score(X_permuted, y)
            scores.append(baseline_score - permuted_score)
        
        importance[feature] = {
            'mean_importance': np.mean(scores),
            'std_importance': np.std(scores)
        }
    
    return importance
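
A usage sketch covering both utilities; the classifier and data are placeholders, and scikit-learn ships an equivalent sklearn.inspection.permutation_importance:

import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

X, y = make_classification(n_samples=500, n_features=10, random_state=0)
X = pd.DataFrame(X, columns=[f'feature_{i}' for i in range(10)])
model = RandomForestClassifier(random_state=0).fit(X, y)

analyzer = FeatureImportanceAnalyzer(model, feature_names=list(X.columns))
print(analyzer.get_importance().head())

perm = permutation_importance(model, X, y, n_repeats=5)
print(sorted(perm.items(), key=lambda kv: -kv[1]['mean_importance'])[:3])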

Partial Dependence Plots

class PartialDependence:
    def __init__(self, model, feature_names=None):
        self.model = model
        self.feature_names = feature_names
    
    def compute_ice_curves(self, X, feature, num_points=50):
        # Get feature range
        feature_min = X[feature].min()
        feature_max = X[feature].max()
        feature_vals = np.linspace(feature_min, feature_max, num_points)
        
        # Compute Individual Conditional Expectation curves: sweep the feature
        # across its range for each row and record the single-row prediction
        ice_curves = []
        for i in range(len(X)):
            X_row = X.iloc[[i]].copy()
            predictions = []
            for val in feature_vals:
                X_row[feature] = val
                predictions.append(self.model.predict(X_row)[0])
            ice_curves.append(predictions)
        
        return feature_vals, np.array(ice_curves)
    
    def plot_partial_dependence(self, X, feature, num_points=50):
        feature_vals, ice_curves = self.compute_ice_curves(
            X, feature, num_points
        )
        
        plt.figure(figsize=(10, 6))
        
        # Plot ICE curves
        for curve in ice_curves:
            plt.plot(feature_vals, curve, color='blue', alpha=0.1)
        
        # Plot PDP (mean curve)
        mean_curve = ice_curves.mean(axis=0)
        plt.plot(
            feature_vals,
            mean_curve,
            color='red',
            linewidth=2,
            label='PDP'
        )
        
        plt.xlabel(feature)
        plt.ylabel('Predicted Value')
        plt.title(f'Partial Dependence Plot for {feature}')
        plt.legend()
        return plt
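
A brief usage sketch with a placeholder model and dataset; sklearn.inspection.PartialDependenceDisplay provides a built-in alternative:

import pandas as pd
from sklearn.datasets import make_regression
from sklearn.ensemble import GradientBoostingRegressor

X, y = make_regression(n_samples=200, n_features=5, random_state=0)
X = pd.DataFrame(X, columns=[f'feature_{i}' for i in range(5)])
model = GradientBoostingRegressor(random_state=0).fit(X, y)

pdp = PartialDependence(model)
pdp.plot_partial_dependence(X.iloc[:50], 'feature_0').show()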

Model-Specific Interpretability

Decision Tree Visualization

from sklearn.tree import export_graphviz
import graphviz

def visualize_decision_tree(tree_model, feature_names=None, class_names=None):
    dot_data = export_graphviz(
        tree_model,
        out_file=None,
        feature_names=feature_names,
        class_names=class_names,
        filled=True,
        rounded=True,
        special_characters=True
    )
    
    return graphviz.Source(dot_data)

def extract_decision_rules(tree_model, feature_names=None):
    children_left = tree_model.tree_.children_left
    children_right = tree_model.tree_.children_right
    feature = tree_model.tree_.feature
    threshold = tree_model.tree_.threshold
    
    def recurse(node, depth, path):
        if children_left[node] == children_right[node]:  # Leaf node
            return [path]
        
        feature_name = (f'feature_{feature[node]}'
                       if feature_names is None
                       else feature_names[feature[node]])
        
        left_path = path + [f'{feature_name} <= {threshold[node]:.2f}']
        right_path = path + [f'{feature_name} > {threshold[node]:.2f}']
        
        rules = []
        rules.extend(recurse(children_left[node], depth + 1, left_path))
        rules.extend(recurse(children_right[node], depth + 1, right_path))
        
        return rules
    
    return recurse(0, 0, [])
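
A usage sketch on a small placeholder tree. Rendering requires both the graphviz Python package and the system Graphviz binaries:

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
tree = DecisionTreeClassifier(max_depth=3, random_state=0).fit(iris.data, iris.target)

graph = visualize_decision_tree(
    tree,
    feature_names=iris.feature_names,
    class_names=list(iris.target_names)
)
graph.render('iris_tree')  # writes iris_tree.pdf; in a notebook, `graph` renders inline

rules = extract_decision_rules(tree, feature_names=iris.feature_names)
print(rules[0])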

Neural Network Visualization

import torch
import torch.nn as nn

class ActivationVisualizer:
    def __init__(self, model):
        self.model = model
        self.activations = {}
        self._register_hooks()
    
    def _register_hooks(self):
        def get_activation(name):
            def hook(model, input, output):
                self.activations[name] = output.detach()
            return hook
        
        for name, layer in self.model.named_modules():
            if isinstance(layer, (nn.ReLU, nn.Conv2d)):
                layer.register_forward_hook(get_activation(name))
    
    def visualize_layer_activations(self, x, layer_name):
        # Forward pass; the registered hooks record the activations
        with torch.no_grad():
            self.model(x)
        
        # Get activations
        activations = self.activations[layer_name]
        
        # Plot activations
        if len(activations.shape) == 4:  # Conv layer
            fig, axes = plt.subplots(
                4, 4,
                figsize=(10, 10)
            )
            for i, ax in enumerate(axes.flat):
                if i < activations.shape[1]:
                    ax.imshow(activations[0, i].cpu())
                ax.axis('off')
            plt.tight_layout()
            return plt
        else:  # Dense layer
            plt.figure(figsize=(10, 6))
            plt.hist(activations.cpu().numpy().flatten(), bins=50)
            plt.title(f'Activation Distribution for {layer_name}')
            return plt
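
A minimal usage sketch with a throwaway CNN; the architecture and input shape are placeholders:

model = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(16 * 28 * 28, 10)
)

visualizer = ActivationVisualizer(model)
x = torch.randn(1, 1, 28, 28)
visualizer.visualize_layer_activations(x, '0').show()  # '0' is the Conv2d layer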

Best Practices

  1. Model Selection for Interpretability
def select_interpretable_model(dataset_size, num_features, task_type):
    if dataset_size < 10000 or num_features < 20:
        if task_type == 'classification':
            return 'decision_tree'
        else:
            return 'linear_regression'
    elif num_features < 100:
        return 'random_forest'
    else:
        return 'gradient_boosting'
  2. Explanation Generation
def generate_comprehensive_explanation(model, X, instance_idx):
    instance = X.iloc[instance_idx].to_numpy()
    
    # Local explanation (LIME expects a 1-D numpy array)
    lime_explainer = LIMEExplainer(model)
    lime_explanation = lime_explainer.explain_instance(
        instance,
        feature_names=list(X.columns)
    )
    
    # SHAP values
    shap_explainer = SHAPExplainer(model, X)
    shap_explanation = shap_explainer.explain_instance(instance)
    
    # Feature importance
    importance_analyzer = FeatureImportanceAnalyzer(model, X.columns)
    global_importance = importance_analyzer.get_importance()
    
    return {
        'local_explanation': lime_explanation,
        'shap_values': shap_explanation,
        'global_importance': global_importance
    }
  3. Validation of Explanations (a combined usage sketch follows the list)
def validate_explanation(model, X, explanation):
    # Check explanation stability by re-explaining instance 0 under small
    # input perturbations (assumes `explanation` was generated for instance 0)
    base_importance = np.array(
        list(explanation['local_explanation']['importance'].values())
    ).reshape(1, -1)
    
    stability_scores = []
    for _ in range(10):
        noise = np.random.normal(0, 0.01, X.shape)
        noisy_X = X + noise
        noisy_explanation = generate_comprehensive_explanation(
            model,
            noisy_X,
            0
        )
        
        # Cosine similarity between original and perturbed importance vectors
        noisy_importance = np.array(
            list(noisy_explanation['local_explanation']['importance'].values())
        ).reshape(1, -1)
        stability_scores.append(
            cosine_similarity(base_importance, noisy_importance)[0, 0]
        )
    
    # Check local/global consistency, aligning global importances by feature
    # name and comparing magnitudes (global importances are non-negative)
    local_importance = explanation['local_explanation']['importance']
    global_importance = (explanation['global_importance']
                         .set_index('feature')['importance']
                         .reindex(list(local_importance.keys())))
    
    importance_consistency = np.corrcoef(
        np.abs(list(local_importance.values())),
        global_importance.values
    )[0, 1]
    
    return {
        'stability': np.mean(stability_scores),
        'importance_consistency': importance_consistency
    }
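
A usage sketch tying the pieces together; the model and data are placeholders, and the repeated KernelExplainer calls make the validation step slow on real datasets:

import pandas as pd
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor

X, y = make_regression(n_samples=200, n_features=5, random_state=0)
X = pd.DataFrame(X, columns=[f'feature_{i}' for i in range(5)])
model = RandomForestRegressor(random_state=0).fit(X, y)

explanation = generate_comprehensive_explanation(model, X, instance_idx=0)
print(validate_explanation(model, X, explanation))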