Model Interpretability
Techniques and methods for understanding and interpreting machine learning models, covering local explanations, global explanations, and model-specific approaches.
Local Interpretability
LIME Implementation
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.linear_model import Ridge
from sklearn.metrics.pairwise import cosine_similarity

class LIMEExplainer:
    def __init__(self, model, num_samples=5000):
        self.model = model
        self.num_samples = num_samples

    def generate_samples(self, x, feature_names):
        # Generate perturbed samples around x
        samples = np.random.normal(
            loc=x,
            scale=0.1,
            size=(self.num_samples, len(x))
        )
        # Get black-box predictions for the perturbed samples
        predictions = self.model.predict(samples)
        # Weight samples by proximity to x (cosine similarity, clipped to be non-negative)
        weights = cosine_similarity(samples, x.reshape(1, -1)).ravel()
        weights = np.clip(weights, 0, None)
        return samples, predictions, weights

    def explain_instance(self, x, feature_names=None):
        if feature_names is None:
            feature_names = [f'feature_{i}' for i in range(len(x))]
        # Generate perturbed samples and proximity weights
        samples, predictions, weights = self.generate_samples(x, feature_names)
        # Fit a weighted, interpretable surrogate model
        interpreter = Ridge(alpha=1.0)
        interpreter.fit(samples, predictions, sample_weight=weights)
        # Surrogate coefficients serve as local feature importance
        importance = dict(zip(feature_names, interpreter.coef_))
        return {
            'importance': importance,
            'intercept': interpreter.intercept_,
            'r2_score': interpreter.score(samples, predictions)
        }

def plot_lime_explanation(explanation, top_k=10):
    importance = pd.DataFrame(
        explanation['importance'].items(),
        columns=['Feature', 'Importance']
    )
    # Keep the top-k features by absolute contribution (negative effects matter too)
    importance = importance.reindex(
        importance['Importance'].abs().nlargest(top_k).index
    )
    plt.figure(figsize=(10, 6))
    plt.barh(importance['Feature'], importance['Importance'])
    plt.title('LIME Feature Importance')
    return plt
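A minimal usage sketch for the explainer above, assuming a scikit-learn regressor fitted on synthetic data; the dataset, model, and hyperparameters here are illustrative, not part of the original example:
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor

# Illustrative data and model (assumed for this sketch)
X, y = make_regression(n_samples=500, n_features=8, noise=0.1, random_state=0)
model = RandomForestRegressor(n_estimators=100, random_state=0).fit(X, y)

# Explain a single instance with the Ridge surrogate defined above
explainer = LIMEExplainer(model, num_samples=2000)
explanation = explainer.explain_instance(X[0])
print(explanation['r2_score'])  # how well the local surrogate fits the perturbed samples
plot_lime_explanation(explanation, top_k=5)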
SHAP Values
import shap

class SHAPExplainer:
    def __init__(self, model, background_data=None):
        self.model = model
        if background_data is None:
            raise ValueError("KernelExplainer requires a background dataset")
        # Subsample large background datasets to keep KernelExplainer tractable
        if len(background_data) > 100:
            background_data = shap.sample(background_data, 100)
        self.explainer = shap.KernelExplainer(model.predict, background_data)

    def explain_instance(self, x):
        # For multi-class models, shap_values is a list with one array per class
        shap_values = self.explainer.shap_values(x)
        return {
            'shap_values': shap_values,
            'expected_value': self.explainer.expected_value
        }

    def explain_dataset(self, X):
        return self.explainer.shap_values(X)

    def plot_summary(self, X):
        shap_values = self.explain_dataset(X)
        shap.summary_plot(shap_values, X)

    def plot_dependence(self, X, feature_idx):
        shap_values = self.explain_dataset(X)
        shap.dependence_plot(feature_idx, shap_values, X)
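A short usage sketch for the wrapper, assuming a random forest regressor fitted on the scikit-learn diabetes data; the data and model are illustrative placeholders and are reused by the later sketches:
from sklearn.datasets import load_diabetes
from sklearn.ensemble import RandomForestRegressor

# Illustrative data and model (assumed here and in the sketches below)
data = load_diabetes(as_frame=True)
X_df, y = data.data, data.target
model = RandomForestRegressor(n_estimators=100, random_state=0).fit(X_df, y)

explainer = SHAPExplainer(model, background_data=X_df)
single = explainer.explain_instance(X_df.iloc[0].values)
print(single['expected_value'], single['shap_values'])

# Summary plot over a small slice; KernelExplainer scales poorly with dataset size
explainer.plot_summary(X_df.iloc[:50])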
Global Interpretability
Feature Importance Analysis
class FeatureImportanceAnalyzer:
    def __init__(self, model, feature_names=None):
        self.model = model
        self.feature_names = feature_names

    def get_importance(self):
        if hasattr(self.model, 'feature_importances_'):
            importance = self.model.feature_importances_
        elif hasattr(self.model, 'coef_'):
            # Flatten to handle (1, n_features)-shaped coefficients from linear classifiers
            importance = np.abs(np.ravel(self.model.coef_))
        else:
            raise ValueError("Model doesn't provide feature importance")
        if self.feature_names is None:
            self.feature_names = [f'feature_{i}' for i in range(len(importance))]
        return pd.DataFrame({
            'feature': self.feature_names,
            'importance': importance
        }).sort_values('importance', ascending=False)

    def plot_importance(self, top_k=10):
        importance = self.get_importance().nlargest(top_k, 'importance')
        plt.figure(figsize=(10, 6))
        plt.barh(importance['feature'], importance['importance'])
        plt.title('Feature Importance')
        return plt
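As a quick illustration, reusing the assumed X_df and fitted model from the SHAP sketch above:
# Reuses the illustrative X_df and fitted model from the SHAP usage sketch
analyzer = FeatureImportanceAnalyzer(model, feature_names=list(X_df.columns))
print(analyzer.get_importance().head())
analyzer.plot_importance(top_k=5)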
def permutation_importance(model, X, y, n_repeats=10):
    baseline_score = model.score(X, y)
    importance = {}
    for feature in X.columns:
        scores = []
        for _ in range(n_repeats):
            # Permute feature
            X_permuted = X.copy()
            X_permuted[feature] = np.random.permutation(X_permuted[feature])
            # Calculate score drop
            permuted_score = model.score(X_permuted, y)
            scores.append(baseline_score - permuted_score)
        importance[feature] = {
            'mean_importance': np.mean(scores),
            'std_importance': np.std(scores)
        }
    return importance
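Permutation importance only needs prediction-and-score access, so the same illustrative model and DataFrame work here as well:
# Reuses the illustrative X_df, y, and fitted model from the SHAP usage sketch
perm = permutation_importance(model, X_df, y, n_repeats=5)
top = sorted(perm.items(), key=lambda kv: kv[1]['mean_importance'], reverse=True)[:5]
for feature, stats in top:
    print(f"{feature}: {stats['mean_importance']:.4f} +/- {stats['std_importance']:.4f}")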
Partial Dependence Plots
class PartialDependence:
    def __init__(self, model, feature_names=None):
        self.model = model
        self.feature_names = feature_names

    def compute_ice_curves(self, X, feature, num_points=50):
        # Build a grid over the observed feature range
        feature_min = X[feature].min()
        feature_max = X[feature].max()
        feature_vals = np.linspace(feature_min, feature_max, num_points)
        # Compute Individual Conditional Expectation curves
        ice_curves = []
        for i in range(len(X)):
            predictions = []
            for val in feature_vals:
                # Modify the feature for this row only and predict just that row
                X_modified = X.iloc[[i]].copy()
                X_modified[feature] = val
                predictions.append(self.model.predict(X_modified)[0])
            ice_curves.append(predictions)
        return feature_vals, np.array(ice_curves)

    def plot_partial_dependence(self, X, feature, num_points=50):
        feature_vals, ice_curves = self.compute_ice_curves(X, feature, num_points)
        plt.figure(figsize=(10, 6))
        # Plot ICE curves
        for curve in ice_curves:
            plt.plot(feature_vals, curve, color='blue', alpha=0.1)
        # Plot PDP (mean curve)
        mean_curve = ice_curves.mean(axis=0)
        plt.plot(
            feature_vals,
            mean_curve,
            color='red',
            linewidth=2,
            label='PDP'
        )
        plt.xlabel(feature)
        plt.ylabel('Predicted Value')
        plt.title(f'Partial Dependence Plot for {feature}')
        plt.legend()
        return plt
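A usage sketch for the class above, again on the assumed diabetes model; ICE curves are recomputed row by row, so only a small slice is passed in:
# Reuses the illustrative X_df and fitted model from the SHAP usage sketch
pdp = PartialDependence(model, feature_names=list(X_df.columns))
pdp.plot_partial_dependence(X_df.iloc[:100], feature='bmi', num_points=20)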
Model-Specific Interpretability
Decision Tree Visualization
from sklearn.tree import export_graphviz
import graphviz

def visualize_decision_tree(tree_model, feature_names=None, class_names=None):
    dot_data = export_graphviz(
        tree_model,
        out_file=None,
        feature_names=feature_names,
        class_names=class_names,
        filled=True,
        rounded=True,
        special_characters=True
    )
    return graphviz.Source(dot_data)

def extract_decision_rules(tree_model, feature_names=None):
    children_left = tree_model.tree_.children_left
    children_right = tree_model.tree_.children_right
    feature = tree_model.tree_.feature
    threshold = tree_model.tree_.threshold

    def recurse(node, depth, path):
        if children_left[node] == children_right[node]:  # Leaf node
            return [path]
        feature_name = (f'feature_{feature[node]}'
                        if feature_names is None
                        else feature_names[feature[node]])
        left_path = path + [f'{feature_name} <= {threshold[node]:.2f}']
        right_path = path + [f'{feature_name} > {threshold[node]:.2f}']
        rules = []
        rules.extend(recurse(children_left[node], depth + 1, left_path))
        rules.extend(recurse(children_right[node], depth + 1, right_path))
        return rules

    return recurse(0, 0, [])
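A quick sketch, assuming a shallow DecisionTreeClassifier fitted on the iris data; any fitted scikit-learn tree works the same way:
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

# Illustrative tree (assumed); a small max_depth keeps the extracted rules readable
iris = load_iris()
tree_model = DecisionTreeClassifier(max_depth=3, random_state=0).fit(iris.data, iris.target)

graph = visualize_decision_tree(
    tree_model,
    feature_names=iris.feature_names,
    class_names=list(iris.target_names)
)
graph.render('iris_tree', format='png')  # rendering requires the system graphviz binaries

for rule in extract_decision_rules(tree_model, feature_names=iris.feature_names)[:3]:
    print(' AND '.join(rule))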
Neural Network Visualization
import torch
import torch.nn as nn

class ActivationVisualizer:
    def __init__(self, model):
        self.model = model
        self.activations = {}
        self._register_hooks()

    def _register_hooks(self):
        def get_activation(name):
            def hook(module, input, output):
                self.activations[name] = output.detach()
            return hook
        for name, layer in self.model.named_modules():
            if isinstance(layer, (nn.ReLU, nn.Conv2d)):
                layer.register_forward_hook(get_activation(name))

    def visualize_layer_activations(self, x, layer_name):
        # Forward pass to populate the activation hooks
        with torch.no_grad():
            self.model(x)
        activations = self.activations[layer_name]
        if len(activations.shape) == 4:  # Conv layer: plot a grid of feature maps
            fig, axes = plt.subplots(4, 4, figsize=(10, 10))
            for i, ax in enumerate(axes.flat):
                if i < activations.shape[1]:
                    ax.imshow(activations[0, i].cpu())
                ax.axis('off')
            plt.tight_layout()
            return plt
        else:  # Dense layer: plot the activation distribution
            plt.figure(figsize=(10, 6))
            plt.hist(activations.cpu().numpy().flatten(), bins=50)
            plt.title(f'Activation Distribution for {layer_name}')
            return plt
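A small sketch with a toy convolutional network; the architecture and random input are placeholders, assumed only for illustration:
# Toy CNN and random input (assumed for this sketch)
toy_model = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(16 * 28 * 28, 10)
)
x = torch.randn(1, 1, 28, 28)

visualizer = ActivationVisualizer(toy_model)
# nn.Sequential names its children by index, so '1' is the first ReLU layer
visualizer.visualize_layer_activations(x, layer_name='1')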
Best Practices
- Model Selection for Interpretability
def select_interpretable_model(dataset_size, num_features, task_type):
    if dataset_size < 10000 or num_features < 20:
        if task_type == 'classification':
            return 'decision_tree'
        else:
            return 'linear_regression'
    elif num_features < 100:
        return 'random_forest'
    else:
        return 'gradient_boosting'
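For example, the heuristic above steers a small tabular classification problem toward a single decision tree and a larger, wider one toward a random forest:
print(select_interpretable_model(dataset_size=5000, num_features=15, task_type='classification'))  # 'decision_tree'
print(select_interpretable_model(dataset_size=50000, num_features=60, task_type='regression'))     # 'random_forest'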
- Explanation Generation
def generate_comprehensive_explanation(model, X, instance_idx):
    instance = X.iloc[instance_idx].values
    # Local explanation
    lime_explainer = LIMEExplainer(model)
    lime_explanation = lime_explainer.explain_instance(
        instance, feature_names=list(X.columns)
    )
    # SHAP values
    shap_explainer = SHAPExplainer(model, X)
    shap_explanation = shap_explainer.explain_instance(instance)
    # Global feature importance
    importance_analyzer = FeatureImportanceAnalyzer(model, list(X.columns))
    global_importance = importance_analyzer.get_importance()
    return {
        'local_explanation': lime_explanation,
        'shap_values': shap_explanation,
        'global_importance': global_importance
    }
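Tying the pieces together on the assumed diabetes model from the earlier sketches:
# Reuses the illustrative X_df and fitted model from the SHAP usage sketch
report = generate_comprehensive_explanation(model, X_df, instance_idx=0)
print(report['local_explanation']['importance'])
print(report['global_importance'].head())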
- Validation of Explanations
def validate_explanation(model, X, explanation, feature_names=None):
    # Check explanation stability under small input perturbations
    local_importance = explanation['local_explanation']['importance']
    base_vector = np.array(list(local_importance.values())).reshape(1, -1)
    stability_scores = []
    for _ in range(10):
        noise = np.random.normal(0, 0.01, X.shape)
        noisy_X = X + noise
        noisy_explanation = generate_comprehensive_explanation(model, noisy_X, 0)
        noisy_vector = np.array(
            list(noisy_explanation['local_explanation']['importance'].values())
        ).reshape(1, -1)
        # Cosine similarity between original and perturbed importance vectors
        stability_scores.append(cosine_similarity(base_vector, noisy_vector)[0, 0])
    # Check local vs. global importance consistency, aligning the global
    # ranking to the local feature order before correlating
    global_importance = explanation['global_importance'].set_index('feature')
    aligned_global = global_importance.loc[list(local_importance.keys()), 'importance']
    importance_consistency = np.corrcoef(
        list(local_importance.values()),
        aligned_global.values
    )[0, 1]
    return {
        'stability': np.mean(stability_scores),
        'importance_consistency': importance_consistency
    }
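A closing sketch that checks the report produced above; note that each repeat recomputes the LIME and SHAP explanations, so this is slow on anything but small data:
# Reuses the illustrative model, X_df, and the report from the previous sketch
validation = validate_explanation(model, X_df, report)
print(f"stability: {validation['stability']:.3f}, "
      f"consistency: {validation['importance_consistency']:.3f}")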