基于shapley值的机器学习可解释性分析
shapley值:当多人联盟博弈时,某人加入组织,对最终博弈决策带来的边际贡献;某一个特征引入时,对模型预测结果带来的边际影响(特征重要度)
在机器学习中,shapley值反映特定样本的特征重要度
SHAP (SHapley Additive exPlanations) 是一种解释任何机器学习模型输出的博弈论方法
pip install shap
or
conda install -c conda-forge shap
LIME可解釋性分析
pip install lime scikit-learn numpy pandas matplotlib pillow
import os

# Create the working directories for the demo assets.
# os.makedirs(..., exist_ok=True) is idempotent, so re-running this
# notebook cell no longer raises FileExistsError (os.mkdir did).
# Directory for test images
os.makedirs('test_img', exist_ok=True)
# Directory for model checkpoint files
os.makedirs('checkpoint', exist_ok=True)
# Download a sample model checkpoint (fruit-30 PyTorch classifier).
# NOTE: the `!` lines below are IPython/Jupyter shell magics — they only
# work inside a notebook session, not as a plain Python script.
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/checkpoints/fruit30_pytorch_20220814.pth -P checkpoint
# Download the mapping dictionaries between class names and ID indices.
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/dataset/fruit30/labels_to_idx.npy
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/dataset/fruit30/idx_to_labels.npy
# Download test images into test_img/.
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/test/cat_dog.jpg -P test_img
!wget https://zihao-openmmlab.obs.myhuaweicloud.com/20220716-mmclassification/test/0818/test_fruits.jpg -P test_img
!wget https://zihao-openmmlab.obs.myhuaweicloud.com/20220716-mmclassification/test/0818/test_orange_2.jpg -P test_img
!wget https://zihao-openmmlab.obs.myhuaweicloud.com/20220716-mmclassification/test/0818/test_bananan.jpg -P test_img
!wget https://zihao-openmmlab.obs.myhuaweicloud.com/20220716-mmclassification/test/0818/test_kiwi.jpg -P test_img
# Strawberry image, source: https://www.pexels.com/zh-cn/photo/4828489/
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/test/0818/test_草莓.jpg -P test_img
!wget https://zihao-openmmlab.obs.myhuaweicloud.com/20220716-mmclassification/test/0818/test_石榴.jpg -P test_img
!wget https://zihao-openmmlab.obs.myhuaweicloud.com/20220716-mmclassification/test/0818/test_orange.jpg -P test_img
!wget https://zihao-openmmlab.obs.myhuaweicloud.com/20220716-mmclassification/test/0818/test_lemon.jpg -P test_img
!wget https://zihao-openmmlab.obs.myhuaweicloud.com/20220716-mmclassification/test/0818/test_火龍果.jpg -P test_img
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/test/watermelon1.jpg -P test_img
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/test/banana1.jpg -P test_img
import lime
import sklearn
import numpy as np
import pandas as pd
# NOTE(review): duplicate of the `import lime` above — harmless but redundant.
import lime
from lime import lime_tabular
# Load the wine-quality dataset from a local CSV into a DataFrame.
# NOTE(review): 'wine.csv' is not downloaded by this script — it must
# already exist in the working directory; confirm where it comes from.
df = pd.read_csv('wine.csv')
# Notebook-style inspection expressions (no effect when run as a script).
df.shape
df
# Split the wine data into train/test feature matrices and label vectors.
from sklearn.model_selection import train_test_split

# Features are every column except the target; 'quality' is the label.
X = df.drop(columns='quality')
y = df['quality']
# Hold out 20% of the rows for evaluation; fixed seed for reproducibility.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)
# Notebook-style shape inspection (no effect when run as a script).
X_train.shape
X_test.shape
y_train.shape
y_test.shape
# Train the model: a random-forest classifier with default hyperparameters
# (fixed seed for reproducibility).
from sklearn.ensemble import RandomForestClassifier
model = RandomForestClassifier(random_state=42)
model.fit(X_train, y_train)
# Evaluate the model: mean accuracy on the held-out test split.
score = model.score(X_test, y_test)
score
# Build a LIME explainer over the tabular training data.
# NOTE(review): class_names assumes a binary 'quality' label — confirm
# against wine.csv; a multiclass label would mislabel the plot legend.
explainer = lime_tabular.LimeTabularExplainer(
training_data=np.array(X_train), # training-set features; must be a numpy array
feature_names=X_train.columns, # feature column names
class_names=['bad', 'good'], # display names for the predicted classes
mode='classification' # classification (vs. regression) mode
)
# Inspect one held-out sample: compare the model's prediction to the label.
idx = 3
data_test = np.array(X_test.iloc[idx]).reshape(1, -1)  # single row -> shape (1, n_features)
prediction = model.predict(data_test)[0]
y_true = np.array(y_test)[idx]
# Bug fix: the original message text was mojibake (garbled by an encoding
# round-trip); restored to clean Chinese with the same placeholders.
print('测试集中的 {} 号样本, 模型预测为 {}, 真实类别为 {}'.format(idx, prediction, y_true))
# Generate the LIME explanation for this single sample.
exp = explainer.explain_instance(
data_row=X_test.iloc[idx],
predict_fn=model.predict_proba
)
# Renders the explanation table/plot inline in a Jupyter notebook.
exp.show_in_notebook(show_table=True)
对Pytorch的ImageNet预训练图像分类模型,运行LIME可解释性分析:可视化某个输入图像的某个图块区域,对模型预测为某个类别的贡献影响
import matplotlib.pyplot as plt
from PIL import Image
import torch.nn as nn
import numpy as np
import os, json
import torch
from torchvision import models, transforms
from torch.autograd import Variable
import torch.nn.functional as F
# Use the GPU if available, otherwise fall back to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device', device)
# Load the test image as a PIL image.
img_path = 'test_img/cat_dog.jpg'
img_pil=Image.open(img_path)
# Notebook display expression (no effect when run as a script).
img_pil
# Load an ImageNet-pretrained Inception v3 in inference mode on the device.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision in
# favor of `weights=...`; kept for compatibility with the version used here.
model = models.inception_v3(pretrained=True).eval().to(device)

# Build lookup tables from the ImageNet class-index JSON:
#   idx2label: class index -> human-readable label
#   cls2label: WordNet id  -> human-readable label
#   cls2idx:   WordNet id  -> class index
# (Restored the `with`-block indentation lost in the notebook export.)
with open(os.path.abspath('imagenet_class_index.json'), 'r') as read_file:
    class_idx = json.load(read_file)
idx2label = [class_idx[str(k)][1] for k in range(len(class_idx))]
cls2label = {class_idx[str(k)][0]: class_idx[str(k)][1] for k in range(len(class_idx))}
cls2idx = {class_idx[str(k)][0]: k for k in range(len(class_idx))}
# Image preprocessing pipelines.
# ImageNet channel statistics used by torchvision pretrained models.
trans_norm = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
# trans_A: full pipeline for direct model input
# (resize -> center-crop -> tensor -> normalize).
trans_A = transforms.Compose([
transforms.Resize((256, 256)),
transforms.CenterCrop(224),
transforms.ToTensor(),
trans_norm
])
# trans_B: tensor + normalize only — applied later inside batch_predict to
# images that have already been resized/cropped.
trans_B = transforms.Compose([
transforms.ToTensor(),
trans_norm
])
# trans_C: geometric steps only, keeping a PIL image for LIME to perturb.
trans_C = transforms.Compose([
transforms.Resize((256, 256)),
transforms.CenterCrop(224)
])
# Classify the test image and show the 5 most probable ImageNet classes.
input_tensor = trans_A(img_pil).unsqueeze(0).to(device) # add batch dimension
pred_logits = model(input_tensor)
pred_softmax = F.softmax(pred_logits, dim=1) # logits -> class probabilities
top_n = pred_softmax.topk(5)
# Notebook display expression (no effect when run as a script).
top_n
# Prediction function handed to LIME (restored the function-body
# indentation that was lost in the notebook export).
def batch_predict(images):
    """Run the classifier on a batch of images.

    Parameters
    ----------
    images : iterable
        Images accepted by ``trans_B`` (PIL images or HWC arrays —
        assumed already resized/cropped; TODO confirm at call sites).

    Returns
    -------
    numpy.ndarray
        Softmax probabilities, one row per input image.
    """
    batch = torch.stack(tuple(trans_B(i) for i in images), dim=0)
    batch = batch.to(device)
    logits = model(batch)
    probs = F.softmax(logits, dim=1)
    # Detach from the autograd graph and move to CPU for numpy conversion.
    return probs.detach().cpu().numpy()
# Sanity check: the wrapped predictor should agree with the direct forward pass.
test_pred = batch_predict([trans_C(img_pil)])
test_pred.squeeze().argmax()
# LIME explainability analysis on the image.
from lime import lime_image
explainer = lime_image.LimeImageExplainer()
explanation = explainer.explain_instance(np.array(trans_C(img_pil)),
batch_predict, # classification prediction function
top_labels=5,
hide_color=0,
num_samples=8000) # number of perturbed neighborhood images LIME generates
# Index of the model's top predicted class.
explanation.top_labels[0]
# Visualization: overlay superpixel boundaries for the top predicted class.
from skimage.segmentation import mark_boundaries
temp, mask = explanation.get_image_and_mask(explanation.top_labels[0], positive_only=False, num_features=20, hide_rest=False)
# Rescale to [0, 1] for display; assumes temp is uint8 — TODO confirm.
img_boundry = mark_boundaries(temp/255.0, mask)
plt.imshow(img_boundry)
plt.show()
# Same visualization pinned to a fixed class id.
# NOTE(review): 281 is presumably the ImageNet index for 'tabby cat' — confirm.
temp, mask = explanation.get_image_and_mask(281, positive_only=False, num_features=20, hide_rest=False)
img_boundry = mark_boundaries(temp/255.0, mask)
plt.imshow(img_boundry)
plt.show()