Explainable AI
Understanding and implementing techniques for model interpretability and explanation in machine learning
Introduction to Explainable AI
Explainable AI (XAI) focuses on making machine learning models more transparent and interpretable, enabling users to understand how models arrive at their decisions. The sections below cover model-agnostic techniques (LIME, SHAP), model-specific tools (tree visualization, feature importance, partial dependence), counterfactual explanations, and simple interpretability metrics.
Model-Agnostic Methods
LIME (Local Interpretable Model-agnostic Explanations)
from lime import lime_tabular
import numpy as np

def explain_with_lime(model, X_train, X_test, feature_names):
    """Explain predictions using LIME"""
    # Initialize LIME explainer
    explainer = lime_tabular.LimeTabularExplainer(
        X_train,
        feature_names=feature_names,
        class_names=['Class 0', 'Class 1'],
        mode='classification'
    )
    # Explain a single prediction
    exp = explainer.explain_instance(
        X_test[0],
        model.predict_proba,
        num_features=len(feature_names)
    )
    return exp

def visualize_lime_explanation(explanation):
    """Visualize LIME explanation"""
    import matplotlib.pyplot as plt

    # Get local feature weights as (feature, weight) pairs
    feature_importance = explanation.as_list()
    features, importance = zip(*feature_importance)
    # Plot
    plt.figure(figsize=(10, 6))
    plt.barh(range(len(features)), importance)
    plt.yticks(range(len(features)), features)
    plt.xlabel('Impact on Prediction')
    plt.title('LIME Feature Importance')
    return plt.gcf()
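To show how the two helpers fit together, here is a minimal usage sketch on synthetic data. The RandomForestClassifier, the make_classification dataset, and the variable names (model, X_train, X_test, feature_names) are illustrative choices, not part of the original example; the later sketches in this section reuse them.

from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

# Toy data and model purely for illustration
X, y = make_classification(n_samples=500, n_features=5, random_state=0)
feature_names = [f'feature_{i}' for i in range(X.shape[1])]
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
model = RandomForestClassifier(random_state=0).fit(X_train, y_train)

# Explain the first test instance and plot its local feature weights
explanation = explain_with_lime(model, X_train, X_test, feature_names)
fig = visualize_lime_explanation(explanation)
fig.savefig('lime_explanation.png')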
SHAP (SHapley Additive exPlanations)
import shap

def explain_with_shap(model, X_train, X_test):
    """Explain predictions using SHAP"""
    # Initialize SHAP explainer; KernelExplainer is model-agnostic but slow,
    # so X_train here should be a small background sample, not the full set
    explainer = shap.KernelExplainer(model.predict, X_train)
    # Calculate SHAP values
    shap_values = explainer.shap_values(X_test)
    return explainer, shap_values

def plot_shap_summary(explainer, shap_values, feature_names):
    """Plot SHAP summary as mean absolute SHAP value per feature"""
    shap.summary_plot(
        shap_values,
        feature_names=feature_names,
        plot_type='bar'
    )
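Because KernelExplainer evaluates the model many times per instance, it scales poorly with the size of the background data and the number of rows being explained. The sketch below reuses the toy model from the LIME example, passes a sampled background set, and shows shap.TreeExplainer as an alternative for tree ensembles, which is typically much faster; the variable names are assumptions carried over from that example.

# Keep KernelExplainer tractable: small background sample, few explained rows
background = shap.sample(X_train, 100)
explainer, shap_values = explain_with_shap(model, background, X_test[:50])
plot_shap_summary(explainer, shap_values, feature_names)

# For tree-based models, TreeExplainer computes exact SHAP values much faster
tree_explainer = shap.TreeExplainer(model)
tree_shap_values = tree_explainer.shap_values(X_test[:50])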
Model-Specific Methods
Decision Tree Visualization
from sklearn.tree import export_graphviz
import graphviz

def visualize_decision_tree(tree_model, feature_names):
    """Visualize decision tree structure"""
    dot_data = export_graphviz(
        tree_model,
        feature_names=feature_names,
        filled=True,
        rounded=True,
        special_characters=True
    )
    graph = graphviz.Source(dot_data)
    return graph
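A usage sketch on the Iris dataset (the dataset and max_depth choice are illustrative); if the Graphviz binaries are not available, sklearn.tree.plot_tree offers a pure-matplotlib alternative.

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
tree_model = DecisionTreeClassifier(max_depth=3, random_state=0).fit(iris.data, iris.target)

graph = visualize_decision_tree(tree_model, feature_names=iris.feature_names)
graph.render('decision_tree', format='png', cleanup=True)  # writes decision_tree.png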
Feature Importance Analysis
import pandas as pd
import matplotlib.pyplot as plt

def analyze_feature_importance(model, feature_names):
    """Analyze feature importance for tree-based models"""
    importance = pd.DataFrame({
        'feature': feature_names,
        'importance': model.feature_importances_
    })
    importance = importance.sort_values('importance', ascending=False)
    # Plot importance
    plt.figure(figsize=(10, 6))
    plt.bar(importance['feature'], importance['importance'])
    plt.xticks(rotation=45)
    plt.xlabel('Features')
    plt.ylabel('Importance')
    plt.title('Feature Importance')
    plt.tight_layout()
    return plt.gcf(), importance
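Impurity-based feature_importances_ can overstate high-cardinality or continuous features, so it is worth cross-checking with permutation importance on held-out data. A minimal sketch, assuming a fitted model and the X_test/y_test split from the earlier toy example:

from sklearn.inspection import permutation_importance
import pandas as pd

def analyze_permutation_importance(model, X_test, y_test, feature_names):
    """Model-agnostic importance: score drop when each feature is shuffled"""
    result = permutation_importance(model, X_test, y_test, n_repeats=10, random_state=0)
    return pd.DataFrame({
        'feature': feature_names,
        'importance_mean': result.importances_mean,
        'importance_std': result.importances_std
    }).sort_values('importance_mean', ascending=False)

print(analyze_permutation_importance(model, X_test, y_test, feature_names))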
Partial Dependence Plots
from sklearn.inspection import partial_dependence
import matplotlib.pyplot as plt

def create_partial_dependence_plot(model, X, feature_idx, feature_name):
    """Create partial dependence plot for a feature"""
    # Calculate partial dependence; returns a Bunch with the grid of
    # feature values and the averaged predictions
    pdp = partial_dependence(
        model,
        X,
        features=[feature_idx],
        kind='average'
    )
    # Plot (on scikit-learn < 1.3 use pdp['values'] instead of pdp['grid_values'])
    plt.figure(figsize=(8, 6))
    plt.plot(pdp['grid_values'][0], pdp['average'][0])
    plt.xlabel(feature_name)
    plt.ylabel('Partial dependence')
    plt.title(f'Partial Dependence Plot for {feature_name}')
    return plt.gcf()
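A usage sketch for the helper above, together with scikit-learn's built-in PartialDependenceDisplay (available in recent versions), which handles the plotting and also supports two-feature interaction plots; model, X_test, and feature_names refer to the earlier toy example.

from sklearn.inspection import PartialDependenceDisplay
import matplotlib.pyplot as plt

# Partial dependence of the first feature using the helper above
fig = create_partial_dependence_plot(model, X_test, 0, feature_names[0])

# Built-in display: two single features plus their two-way interaction
PartialDependenceDisplay.from_estimator(model, X_test, features=[0, 1, (0, 1)])
plt.show()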
Counterfactual Explanations
import numpy as np
import pandas as pd

def generate_counterfactual(model, instance, desired_class, features,
                            step_size=0.1, max_iter=1000):
    """Generate a counterfactual explanation with a simple greedy search"""
    counterfactual = np.asarray(instance, dtype=float).copy()
    current_pred = model.predict(counterfactual.reshape(1, -1))[0]
    n_iter = 0
    while current_pred != desired_class and n_iter < max_iter:
        # Greedily nudge the globally most important feature
        importance = model.feature_importances_
        most_important = np.argmax(importance)
        if current_pred < desired_class:
            counterfactual[most_important] += step_size
        else:
            counterfactual[most_important] -= step_size
        current_pred = model.predict(counterfactual.reshape(1, -1))[0]
        n_iter += 1
    changes = pd.DataFrame({
        'feature': features,
        'original': instance,
        'counterfactual': counterfactual,
        'difference': counterfactual - np.asarray(instance, dtype=float)
    })
    return changes
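A usage sketch with the toy model and data from earlier; note that this greedy loop perturbs only the single most important feature and gives no plausibility or sparsity guarantees, so dedicated libraries such as DiCE or Alibi are better suited to production use.

# Try to flip the prediction for the first test instance to class 1
instance = X_test[0]
changes = generate_counterfactual(model, instance, desired_class=1,
                                  features=feature_names, step_size=0.1)
print(changes[changes['difference'] != 0])  # only the features that changed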
Model Interpretability Metrics
import numpy as np
from sklearn.base import clone

def calculate_interpretability_metrics(model, X, y):
    """Calculate various interpretability metrics"""
    metrics = {}
    # Model complexity
    if hasattr(model, 'get_n_leaves'):
        metrics['n_leaves'] = model.get_n_leaves()
    # Feature importance stability: refit clones on random half-samples
    # and measure how much the importances vary
    if hasattr(model, 'feature_importances_'):
        importances = []
        for _ in range(10):
            idx = np.random.choice(len(X), size=len(X) // 2, replace=False)
            refit = clone(model).fit(X[idx], y[idx])
            importances.append(refit.feature_importances_)
        metrics['importance_stability'] = np.std(importances, axis=0).mean()
    return metrics
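A usage sketch with a small decision tree trained on the earlier toy data (the tree and its depth are illustrative); a lower importance_stability value means the importance estimates vary less across random half-samples.

from sklearn.tree import DecisionTreeClassifier

tree = DecisionTreeClassifier(max_depth=4, random_state=0).fit(X_train, y_train)
metrics = calculate_interpretability_metrics(tree, X_train, y_train)
print(metrics)  # e.g. {'n_leaves': ..., 'importance_stability': ...}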
Best Practices
1. Model Selection
- Consider interpretability requirements
- Balance accuracy and explainability
- Use simpler models when possible
- Document model limitations
2. Explanation Methods
- Choose appropriate explanation techniques
- Validate explanations
- Consider audience needs
- Use multiple methods and cross-check their outputs (see the sketch after these lists)
3. Implementation Tips
- Start with simple explanations
- Focus on key features
- Ensure consistency
- Document assumptions
4. Common Pitfalls
- Over-complicated explanations
- Ignoring model uncertainty
- Misinterpreting results
- Not validating explanations
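As a concrete instance of the "use multiple methods" advice, the sketch below compares LIME and SHAP attributions for the same test instance from the earlier toy example. The Spearman rank correlation of attribution magnitudes is just one possible agreement measure, and all helper and variable names come from the sections above.

import numpy as np
import shap
from scipy.stats import spearmanr

# LIME attributions for the first test instance (helper defined earlier)
lime_exp = explain_with_lime(model, X_train, X_test, feature_names)
lime_weights = dict(lime_exp.as_map()[1])  # {feature_index: local weight}
lime_scores = np.array([lime_weights.get(i, 0.0) for i in range(len(feature_names))])

# SHAP attributions for the same instance, explaining the class-1 probability
predict_class1 = lambda X: model.predict_proba(X)[:, 1]
shap_explainer = shap.KernelExplainer(predict_class1, shap.sample(X_train, 100))
shap_scores = shap_explainer.shap_values(X_test[:1])[0]

# Rank correlation of attribution magnitudes as a rough consistency check
rho, _ = spearmanr(np.abs(lime_scores), np.abs(shap_scores))
print(f'Rank agreement between LIME and SHAP: {rho:.2f}')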