Task6:可解釋性分析

1. torch-cam工具包:CAM熱力圖

CAM algorithm

#安裝配置環(huán)境
pip install numpy pandas matplotlib requests tqdm opencv-python pillow scanpy anndata scipy tqdm stlearn sklearn glob2 -i https://pypi.tuna.tsinghua.edu.cn/simple

pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113

pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu113/torch1.10.0/index.html

#下載中文字體文件
wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/dataset/SimHei.ttf

#下載ImageNet1000類別信息
wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/dataset/meta_data/imagenet_class_index.csv

#創(chuàng)建目錄
import os
# 存放測(cè)試圖片
os.mkdir('test_img')

# 存放結(jié)果文件
os.mkdir('output')

# 存放訓(xùn)練得到的模型權(quán)重
# 不能命名為checkpoints
os.mkdir('checkpoint')

# 下載樣例模型文件
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/checkpoints/fruit30_pytorch_20220814.pth -P checkpoint

# 下載 類別名稱 和 ID索引號(hào) 的映射字典
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/dataset/fruit30/labels_to_idx.npy
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/dataset/fruit30/idx_to_labels.npy

# 下載測(cè)試圖像文件 至 test_img 文件夾

# 邊牧犬,來(lái)源:https://www.woopets.fr/assets/races/000/066/big-portrait/border-collie.jpg
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/test/border-collie.jpg -P test_img

!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/test/cat_dog.jpg -P test_img

!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/test/0818/room_video.mp4 -P test_img

# 草莓圖像烁设,來(lái)源:https://www.pexels.com/zh-cn/photo/4828489/
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/test/0818/test_草莓.jpg -P test_img

# !wget https://zihao-openmmlab.obs.myhuaweicloud.com/20220716-mmclassification/test/0818/test_fruits.jpg -P test_img

# !wget https://zihao-openmmlab.obs.myhuaweicloud.com/20220716-mmclassification/test/0818/test_orange_2.jpg -P test_img 

# !wget https://zihao-openmmlab.obs.myhuaweicloud.com/20220716-mmclassification/test/0818/test_bananan.jpg -P test_img

# !wget https://zihao-openmmlab.obs.myhuaweicloud.com/20220716-mmclassification/test/0818/test_kiwi.jpg -P test_img

# !wget https://zihao-openmmlab.obs.myhuaweicloud.com/20220716-mmclassification/test/0818/test_石榴.jpg -P test_img

# !wget https://zihao-openmmlab.obs.myhuaweicloud.com/20220716-mmclassification/test/0818/test_orange.jpg -P test_img

# !wget https://zihao-openmmlab.obs.myhuaweicloud.com/20220716-mmclassification/test/0818/test_lemon.jpg -P test_img

# !wget https://zihao-openmmlab.obs.myhuaweicloud.com/20220716-mmclassification/test/0818/test_火龍果.jpg -P test_img

#安裝torchcam
# 刪除原有的 torch-cam 目錄(如有)
!rm -rf torch-cam

# 下載安裝 torch-cam
!git clone https://github.com/frgfm/torch-cam.git
!pip install -e torch-cam/.

# 驗(yàn)證安裝成功
import torchcam

import matplotlib.pyplot as plt
%matplotlib inline

# Linux操作系統(tǒng)替梨,例如 云GPU平臺(tái):https://featurize.cn/?s=d7ce99f842414bfcaea5662a97581bd1
# 如果報(bào)錯(cuò) Unable to establish SSL connection.,重新運(yùn)行本代碼塊即可
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/dataset/SimHei.ttf -O /environment/miniconda3/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/ttf/SimHei.ttf --no-check-certificate
!rm -rf /home/featurize/.cache/matplotlib

import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline
matplotlib.rc("font",family='SimHei') # 中文字體
plt.rcParams['axes.unicode_minus']=False  # 用來(lái)正常顯示負(fù)號(hào)

torchcam可解釋性分析可視化(命令行)装黑,對(duì)圖像進(jìn)行各種基于CAM的可解釋性分析

#導(dǎo)入工具包
import os
import pandas as pd
from PIL import Image

python torch-cam/scripts/cam_example.py --help

#ImageNet預(yù)訓(xùn)練圖像分類模型
# ImageNet1000類別名稱與ID號(hào)
df = pd.read_csv('imagenet_class_index.csv')
df

神經(jīng)網(wǎng)絡(luò)的注意力可視化

#圖中只有一個(gè)類別
# 類別-邊牧犬
python torch-cam/scripts/cam_example.py \
        --img test_img/border-collie.jpg \
        --savefig output/B1_border_collie.jpg \
        --arch resnet18 \
        --class-idx 232 \
        --rows 2

Image.open('output/B1_border_collie.jpg')
#圖中有多個(gè)類別
# 類別-虎斑貓
python torch-cam/scripts/cam_example.py \
        --img test_img/cat_dog.jpg \
        --savefig output/B2_cat_dog.jpg \
        --arch resnet18 \
        --class-idx 282 \
        --rows 2

Image.open('output/B2_cat_dog.jpg')
# 類別-邊牧犬
python torch-cam/scripts/cam_example.py \
        --img test_img/cat_dog.jpg \
        --savefig output/B3_cat_dog.jpg \
        --arch resnet18 \
        --class-idx 232 \
        --rows 2

Image.open('output/B3_cat_dog.jpg')

torchcam可解釋性分析可視化(python API)副瀑,對(duì)pytorch預(yù)訓(xùn)練的1000圖像分類模型進(jìn)行基于CAM的可解釋性分析

#導(dǎo)入工具包
import matplotlib.pyplot as plt
%matplotlib inline

from PIL import Image

import torch
# 有 GPU 就用 GPU,沒(méi)有就用 CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device', device)

#導(dǎo)入中文字體
from PIL import ImageFont, ImageDraw
# 導(dǎo)入中文字體,指定字體大小
font = ImageFont.truetype('SimHei.ttf', 50)

#導(dǎo)入ImageNet預(yù)訓(xùn)練模型
from torchvision.models import resnet18
model = resnet18(pretrained=True).eval().to(device)

#導(dǎo)入可解釋性分析方法
from torchcam.methods import SmoothGradCAMpp 
# CAM GradCAM GradCAMpp ISCAM LayerCAM SSCAM ScoreCAM SmoothGradCAMpp XGradCAM

cam_extractor = SmoothGradCAMpp(model)

#預(yù)處理
from torchvision import transforms
# 測(cè)試集圖像預(yù)處理-RCTN:縮放刻两、裁剪脏毯、轉(zhuǎn) Tensor、歸一化
test_transform = transforms.Compose([transforms.Resize(256),
                                     transforms.CenterCrop(224),
                                     transforms.ToTensor(),
                                     transforms.Normalize(
                                         mean=[0.485, 0.456, 0.406], 
                                         std=[0.229, 0.224, 0.225])
                                    ])

#運(yùn)用圖像分類預(yù)測(cè)
img_path = 'test_img/cat_dog.jpg'

img_pil = Image.open(img_path)
input_tensor = test_transform(img_pil).unsqueeze(0).to(device) # 預(yù)處理

input_tensor.shape

pred_logits = model(input_tensor)
pred_top1 = torch.topk(pred_logits, 1)
pred_id = pred_top1[1].detach().cpu().numpy().squeeze().item()

pred_id

#生成可解釋性分析熱力圖
activation_map = cam_extractor(pred_id, pred_logits)
activation_map = activation_map[0][0].detach().cpu().numpy()
activation_map.shape
activation_map

#可視化
plt.imshow(activation_map)
plt.show()
from torchcam.utils import overlay_mask

result = overlay_mask(img_pil, Image.fromarray(activation_map), alpha=0.7)
result
import pandas as pd
df = pd.read_csv('imagenet_class_index.csv')
idx_to_labels = {}
idx_to_labels_cn = {}
for idx, row in df.iterrows():
    idx_to_labels[row['ID']] = row['class']
    idx_to_labels_cn[row['ID']] = row['Chinese']

idx_to_labels

img_path = 'test_img/cat_dog.jpg'

# 可視化熱力圖的類別ID狈孔,如果為 None,則為置信度最高的預(yù)測(cè)類別ID
show_class_id = 231
# show_class_id = None

# 是否顯示中文類別
Chinese = True
# Chinese = False

# 前向預(yù)測(cè)
img_pil = Image.open(img_path)
input_tensor = test_transform(img_pil).unsqueeze(0).to(device) # 預(yù)處理
pred_logits = model(input_tensor)
pred_top1 = torch.topk(pred_logits, 1)
pred_id = pred_top1[1].detach().cpu().numpy().squeeze().item()

# 可視化熱力圖的類別ID材义,如果不指定均抽,則為置信度最高的預(yù)測(cè)類別ID
if show_class_id:
    show_id = show_class_id
else:
    show_id = pred_id
    show_class_id = pred_id

# 生成可解釋性分析熱力圖
activation_map = cam_extractor(show_id, pred_logits)
activation_map = activation_map[0][0].detach().cpu().numpy()
result = overlay_mask(img_pil, Image.fromarray(activation_map), alpha=0.7)

# 在圖像上寫字
draw = ImageDraw.Draw(result)

if Chinese:
    # 在圖像上寫中文
    text_pred = 'Pred Class: {}'.format(idx_to_labels_cn[pred_id])
    text_show = 'Show Class: {}'.format(idx_to_labels_cn[show_class_id])
else:
    # 在圖像上寫英文
    text_pred = 'Pred Class: {}'.format(idx_to_labels[pred_id])
    text_show = 'Show Class: {}'.format(idx_to_labels[show_class_id])
# 文字坐標(biāo),中文字符串其掂,字體油挥,rgba顏色
draw.text((50, 100), text_pred, font=font, fill=(255, 0, 0, 1))
draw.text((50, 200), text_show, font=font, fill=(255, 0, 0, 1))

result

通過(guò)Python API方式,使用torchcam算法庫(kù)款熬,對(duì)自己訓(xùn)練的水果圖像分類模型進(jìn)行基于CAM的可解釋性分析(輸入單張圖像)

#導(dǎo)入工具包
import matplotlib.pyplot as plt
%matplotlib inline

import numpy as np
from PIL import Image

import torch
# 有 GPU 就用 GPU深寥,沒(méi)有就用 CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device', device)

#導(dǎo)入訓(xùn)練好的pytorch模型
model = torch.load('checkpoint/fruit30_pytorch_20220814.pth')
model = model.eval().to(device)

#導(dǎo)入可解釋性分析方法
from torchcam.methods import GradCAMpp
# CAM GradCAM GradCAMpp ISCAM LayerCAM SSCAM ScoreCAM SmoothGradCAMpp XGradCAM

cam_extractor = GradCAMpp(model)

#預(yù)處理
from torchvision import transforms
# 測(cè)試集圖像預(yù)處理-RCTN:縮放、裁剪华烟、轉(zhuǎn) Tensor翩迈、歸一化
test_transform = transforms.Compose([transforms.Resize(256),
                                     transforms.CenterCrop(224),
                                     transforms.ToTensor(),
                                     transforms.Normalize(
                                         mean=[0.485, 0.456, 0.406], 
                                         std=[0.229, 0.224, 0.225])
                                    ])

#運(yùn)行圖像分類預(yù)測(cè)
img_path = 'test_img/test_fruits.jpg'

img_pil = Image.open(img_path)
input_tensor = test_transform(img_pil).unsqueeze(0).to(device) # 預(yù)處理

pred_logits = model(input_tensor)
pred_id = torch.topk(pred_logits, 1)[1].detach().cpu().numpy().squeeze().item()
pred_id
#生成可解釋性分析熱力圖
activation_map = cam_extractor(pred_id, pred_logits)
activation_map = activation_map[0][0].detach().cpu().numpy()

activation_map.shape

#可視化
plt.imshow(activation_map)
plt.show()
from torchcam.utils import overlay_mask

result = overlay_mask(img_pil, Image.fromarray(activation_map), alpha=0.7)

result
# 完整代碼
img_path = 'test_img/test_fruits.jpg'

# 可視化熱力圖的類別,如果不指定盔夜,則為置信度最高的預(yù)測(cè)類別
show_class = '獼猴桃'

# 前向預(yù)測(cè)
img_pil = Image.open(img_path)
input_tensor = test_transform(img_pil).unsqueeze(0).to(device) # 預(yù)處理
pred_logits = model(input_tensor)
pred_id = torch.topk(pred_logits, 1)[1].detach().cpu().numpy().squeeze().item()

if show_class:
    class_id = labels_to_idx[show_class]
    show_id = class_id
else:
    show_id = pred_id

# 獲取熱力圖
activation_map = cam_extractor(show_id, pred_logits)
activation_map = activation_map[0][0].detach().cpu().numpy()
result = overlay_mask(img_pil, Image.fromarray(activation_map), alpha=0.4)
plt.imshow(result)
plt.axis('off')

plt.title('{}\nPred:{} Show:{}'.format(img_path, idx_to_labels[pred_id], show_class))
plt.show()

2. pytorch-grad-cam工具包:CAM熱力圖负饲、Guided Grad-CAM熱力圖、DFF

Grad-CAM algorithm

#安裝配置環(huán)境
pip install grad-cam torchcam

#下載pytorch-grad-cam
git clone https://github.com/jacobgil/pytorch-grad-cam.git

import os
# 存放測(cè)試圖片
os.mkdir('test_img')

# 存放結(jié)果文件
os.mkdir('output')

# 存放模型權(quán)重文件
os.mkdir('checkpoint')

!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220919-explain/test_img/puppies.jpg -P test_img

!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220919-explain/test_img/bear.jpeg -P test_img

!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220919-explain/test_img/box_tabby.png -P test_img

# 蛇喂链,來(lái)源:https://www.pexels.com/zh-cn/photo/80474/
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220919-explain/test_img/snake.jpg -P test_img

# 長(zhǎng)頸鹿和斑馬返十,來(lái)源:https://www.istockphoto.com/hk/%E7%85%A7%E7%89%87/giraffes-and-zebras-at-waterhole-gm503592172-82598465
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220919-explain/test_img/giraffe_zebra.jpg -P test_img

# 大象、獅子椭微、羚羊洞坑,來(lái)源:https://www.istockphoto.com/hk/%E7%85%A7%E7%89%87/%E5%A4%A7%E8%B1%A1%E5%92%8C%E7%8D%85%E5%AD%90-gm1136053333-30244130
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220919-explain/test_img/africa.jpg -P test_img

# 邊牧犬,來(lái)源:https://www.woopets.fr/assets/races/000/066/big-portrait/border-collie.jpg
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/test/border-collie.jpg -P test_img

!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/test/cat_dog.jpg -P test_img

!wget https://zihao-openmmlab.obs.myhuaweicloud.com/20220716-mmclassification/test/0818/test_fruits.jpg -P test_img

!wget https://zihao-openmmlab.obs.myhuaweicloud.com/20220716-mmclassification/test/0818/test_orange_2.jpg -P test_img 

!wget https://zihao-openmmlab.obs.myhuaweicloud.com/20220716-mmclassification/test/0818/test_bananan.jpg -P test_img

!wget https://zihao-openmmlab.obs.myhuaweicloud.com/20220716-mmclassification/test/0818/test_kiwi.jpg -P test_img

# 草莓圖像蝇率,來(lái)源:https://www.pexels.com/zh-cn/photo/4828489/
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/test/0818/test_草莓.jpg -P test_img

!wget https://zihao-openmmlab.obs.myhuaweicloud.com/20220716-mmclassification/test/0818/test_石榴.jpg -P test_img

!wget https://zihao-openmmlab.obs.myhuaweicloud.com/20220716-mmclassification/test/0818/test_orange.jpg -P test_img

!wget https://zihao-openmmlab.obs.myhuaweicloud.com/20220716-mmclassification/test/0818/test_lemon.jpg -P test_img

!wget https://zihao-openmmlab.obs.myhuaweicloud.com/20220716-mmclassification/test/0818/test_火龍果.jpg -P test_img

!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/test/watermelon1.jpg -P test_img

!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/test/banana1.jpg -P test_img

#下載ImageNet1000類別信息
wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/dataset/meta_data/imagenet_class_index.csv

#model
# 下載樣例模型文件
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/checkpoints/fruit30_pytorch_20220814.pth -P checkpoint

# 下載 類別名稱 和 ID索引號(hào) 的映射字典
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/dataset/fruit30/labels_to_idx.npy
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/dataset/fruit30/idx_to_labels.npy

import pytorch_grad_cam

對(duì)單張圖像進(jìn)行Grad-CAM熱力圖可解釋性分析

#導(dǎo)入工具包
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from torchvision.models import resnet50

import numpy as np
import cv2
from PIL import Image
import matplotlib.pyplot as plt
%matplotlib inline

import torch
# 有 GPU 就用 GPU迟杂,沒(méi)有就用 CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device', device)

#載入ImageNet預(yù)訓(xùn)練圖像分類模型
model = resnet50(pretrained=True).eval().to(device)

#圖像預(yù)處理
from torchvision import transforms

# 測(cè)試集圖像預(yù)處理-RCTN:縮放刽沾、裁剪、轉(zhuǎn) Tensor排拷、歸一化
test_transform = transforms.Compose([transforms.Resize(512),
                                     # transforms.CenterCrop(512),
                                     transforms.ToTensor(),
                                     transforms.Normalize(
                                         mean=[0.485, 0.456, 0.406], 
                                         std=[0.229, 0.224, 0.225])
                                    ])

#載入測(cè)試圖像
img_path = 'test_img/cat_dog.jpg'
img_pil=Image.open(img_path)
input_tensor = test_transform(img_pil).unsqueeze(0).to(device) # 預(yù)處理
input_tensor.shape

#指定分析的類別
#281虎斑貓 232 邊牧犬
# 如果 targets 為 None侧漓,則默認(rèn)為最高置信度類別
targets = [ClassifierOutputTarget(232)]

#分析模型結(jié)構(gòu),確定待分析的層
model
model.layer4[-1]
model.layer1[0]

#任選一個(gè)可解釋性分析方法
from pytorch_grad_cam import GradCAM, HiResCAM, GradCAMElementWise, GradCAMPlusPlus, XGradCAM, AblationCAM, ScoreCAM, EigenCAM, EigenGradCAM, LayerCAM, FullGrad

# Grad-CAM
from pytorch_grad_cam import GradCAM
target_layers = [model.layer4[-1]]
cam = GradCAM(model=model, target_layers=target_layers, use_cuda=True)

# # Grad-CAM++
# from pytorch_grad_cam import GradCAMPlusPlus
# target_layers = [model.layer4[-1]]
# cam = GradCAMPlusPlus(model=model, target_layers=target_layers, use_cuda=True)

#生成CAM熱力圖
cam_map = cam(input_tensor=input_tensor, targets=targets)[0] # 不加平滑
# cam_map = cam(input_tensor=input_tensor, targets=targets, aug_smooth=True, eigen_smooth=True)[0] # 加平滑

#可視化
cam_map.shape
plt.imshow(cam_map)
plt.show()

import torchcam
from torchcam.utils import overlay_mask

result = overlay_mask(img_pil, Image.fromarray(cam_map), alpha=0.6) # alpha越小监氢,原圖越淡
result

result.save('output/B1.jpg')

對(duì)單張圖像進(jìn)行LayerCAM可解釋性分析

#導(dǎo)入工具包
from torchvision.models import vgg16, resnet50

import numpy as np
import pandas as pd
import cv2
from PIL import Image

import matplotlib.pyplot as plt
%matplotlib inline

import torch
# 有 GPU 就用 GPU布蔗,沒(méi)有就用 CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device', device)

# 載入ImageNet預(yù)訓(xùn)練圖像分類模型
model = vgg16(pretrained=True).eval().to(device)
# model = resnet50(pretrained=True).eval().to(device)

#預(yù)處理
from torchvision import transforms

# 測(cè)試集圖像預(yù)處理-RCTN:縮放、裁剪浪腐、轉(zhuǎn) Tensor纵揍、歸一化
test_transform = transforms.Compose([transforms.Resize(224),
                                     # transforms.CenterCrop(224),
                                     transforms.ToTensor(),
                                     transforms.Normalize(
                                         mean=[0.485, 0.456, 0.406], 
                                         std=[0.229, 0.224, 0.225])
                                    ])

img_path = 'test_img/snake.jpg'

# img_path = 'test_img/cat_dog.jpg'

img_pil = Image.open(img_path)
# img_pil
input_tensor = test_transform(img_pil).unsqueeze(0).to(device) # 預(yù)處理
input_tensor.shape

#輸入模型,執(zhí)行前向預(yù)測(cè)
# 執(zhí)行前向預(yù)測(cè)议街,得到所有類別的 logit 預(yù)測(cè)分?jǐn)?shù)
pred_logits = model(input_tensor) 

import torch.nn.functional as F
pred_softmax = F.softmax(pred_logits, dim=1) # 對(duì) logit 分?jǐn)?shù)做 softmax 運(yùn)算
pred_softmax.shape

#獲得圖像分類預(yù)測(cè)結(jié)果
n = 5
top_n = torch.topk(pred_softmax, n)
top_n

# 解析出類別
pred_ids = top_n[1].cpu().detach().numpy().squeeze()

pred_ids
# 解析出置信度
confs = top_n[0].cpu().detach().numpy().squeeze()
confs

# 載入ImageNet 1000圖像分類標(biāo)簽
df = pd.read_csv('imagenet_class_index.csv')
idx_to_labels = {}
for idx, row in df.iterrows():
    idx_to_labels[row['ID']] = [row['wordnet'], row['class']]

for i in range(n):
    class_name = idx_to_labels[pred_ids[i]][1] # 獲取類別名稱
    confidence = confs[i] * 100 # 獲取置信度
    text = '{:<5} {:<15} {:>.4f}'.format(pred_ids[i], class_name, confidence)
    print(text)

#指定分析的類別
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

# 如果 targets 為 None泽谨,則默認(rèn)為最高置信度類別
targets = [ClassifierOutputTarget(56)]

#確定模型結(jié)構(gòu),確定待分析的層
# model

#選擇可解釋性分析方法
# LayerCAM
from pytorch_grad_cam import LayerCAM
target_layers = [model.features[8]] # vgg16
# target_layers = [model.layer3[0]] # resnet50
cam = LayerCAM(model=model, target_layers=target_layers, use_cuda=True)

#生成CAM熱力圖
cam_map = cam(input_tensor=input_tensor, targets=targets)[0] # 不加平滑

#可視化CAM熱力圖
cam_map.shape
cam_map.dtype
plt.imshow(cam_map)
plt.show()
import torchcam
from torchcam.utils import overlay_mask

result = overlay_mask(img_pil, Image.fromarray(cam_map), alpha=0.12) # alpha越小特漩,原圖越淡
# result
result.save('output/B2.jpg')

對(duì)單張圖像隔盛,進(jìn)行Guided Grad-CAM可解釋性分析,繪制既具有類別判別性(Class-Discriminative)拾稳,又具有高分辨率的細(xì)粒度熱力圖

#導(dǎo)入工具包
import numpy as np
import cv2
from PIL import Image

import matplotlib.pyplot as plt
%matplotlib inline

import torch
from torchvision import models
from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, EigenGradCAM, LayerCAM, FullGrad, GradCAMElementWise
from pytorch_grad_cam import GuidedBackpropReLUModel
from pytorch_grad_cam.utils.image import show_cam_on_image, deprocess_image, preprocess_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

# 有 GPU 就用 GPU,沒(méi)有就用 CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device', device)

model = models.resnet50(pretrained=True).eval().to(device)

from torchvision import transforms

# 測(cè)試集圖像預(yù)處理-RCTN:縮放腊脱、裁剪访得、轉(zhuǎn) Tensor、歸一化
test_transform = transforms.Compose([transforms.Resize(224),
                                     # transforms.CenterCrop(224),
                                     transforms.ToTensor(),
                                     transforms.Normalize(
                                         mean=[0.485, 0.456, 0.406], 
                                         std=[0.229, 0.224, 0.225])
                                    ])

#載入測(cè)試圖片
img_path = 'test_img/cat_dog.jpg'
img_pil = Image.open(img_path)
input_tensor = test_transform(img_pil).unsqueeze(0).to(device) # 預(yù)處理
input_tensor.shape

#選擇可解釋性分析方法
# GradCAM
from pytorch_grad_cam import GradCAM
target_layers = [model.layer4[-1]] # 要分析的層
targets = [ClassifierOutputTarget(232)] # 要分析的類別
cam = GradCAM(model=model, target_layers=target_layers, use_cuda=True)

#生成Grad-CAM熱力圖
cam_map = cam(input_tensor=input_tensor, targets=targets)[0] # 不加平滑
# cam_map = cam(input_tensor=input_tensor, targets=targets, aug_smooth=True, eigen_smooth=True)[0] # 加平滑
cam_map.shape
plt.imshow(cam_map)
plt.title('Grad-CAM')
plt.show()

import torchcam
from torchcam.utils import overlay_mask

result = overlay_mask(img_pil, Image.fromarray(cam_map), alpha=0.5) # alpha越小陕凹,原圖越淡

plt.imshow(result)
plt.title('Grad-CAM')
plt.show()

#Guided Backpropagation算法
# 初始化算法
gb_model = GuidedBackpropReLUModel(model=model, use_cuda=True)
# 生成 Guided Backpropagation熱力圖
gb_origin = gb_model(input_tensor, target_category=None)
gb_show = deprocess_image(gb_origin)
gb_show.shape

plt.imshow(gb_show)
plt.title('Guided Backpropagation')
plt.show()

#將Grad-CAM熱力圖與Guided Backpropagation熱力圖逐元素相乘
# Grad-CAM三通道熱力圖
cam_mask = cv2.merge([cam_map, cam_map, cam_map])
cam_mask.shape
# 逐元素相乘
guided_gradcam = deprocess_image(cam_mask * gb_origin)
guided_gradcam.shape
plt.imshow(guided_gradcam)
plt.title('Guided Grad-CAM')
plt.show()
cv2.imwrite('output/C1_guided_gradcam.jpg', guided_gradcam)

對(duì)單張圖像悍抑,進(jìn)行Guided Grad-CAM可解釋性分析,繪制既具有類別判別性杜耙,又具有高分辨率的細(xì)粒度熱力圖

#導(dǎo)入工具包
import numpy as np
import cv2
from PIL import Image

import matplotlib.pyplot as plt
%matplotlib inline

import torch
from torchvision import models
from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, EigenGradCAM, LayerCAM, FullGrad, GradCAMElementWise
from pytorch_grad_cam import GuidedBackpropReLUModel
from pytorch_grad_cam.utils.image import show_cam_on_image, deprocess_image, preprocess_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

# 有 GPU 就用 GPU搜骡,沒(méi)有就用 CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device', device)

#導(dǎo)入模型
model = torch.load('checkpoint/fruit30_pytorch_20220814.pth')
model = model.eval().to(device)
idx_to_labels_cn = np.load('idx_to_labels.npy', allow_pickle=True).item()
idx_to_labels_cn

#圖像預(yù)處理
from torchvision import transforms

# 測(cè)試集圖像預(yù)處理-RCTN:縮放、裁剪佑女、轉(zhuǎn) Tensor记靡、歸一化
test_transform = transforms.Compose([transforms.Resize(224),
                                     # transforms.CenterCrop(224),
                                     transforms.ToTensor(),
                                     transforms.Normalize(
                                         mean=[0.485, 0.456, 0.406], 
                                         std=[0.229, 0.224, 0.225])
                                    ])

#載入測(cè)試圖片
img_path = 'test_img/test_fruits.jpg'
img_pil = Image.open(img_path)
input_tensor = test_transform(img_pil).unsqueeze(0).to(device) # 預(yù)處理
input_tensor.shape

#選擇可解釋性分析方法
# GradCAM
from pytorch_grad_cam import GradCAM
target_layers = [model.layer4[-1]] # 要分析的層
targets = [ClassifierOutputTarget(28)] # 要分析的類別
cam = GradCAM(model=model, target_layers=target_layers, use_cuda=True)

#生成Grad-CAM熱力圖
cam_map = cam(input_tensor=input_tensor, targets=targets)[0] # 不加平滑
# cam_map = cam(input_tensor=input_tensor, targets=targets, aug_smooth=True, eigen_smooth=True)[0] # 加平滑
cam_map.shape

plt.imshow(cam_map)
plt.title('Grad-CAM')
plt.show()

import torchcam
from torchcam.utils import overlay_mask

result = overlay_mask(img_pil, Image.fromarray(cam_map), alpha=0.5) # alpha越小,原圖越淡

plt.imshow(result)
plt.title('Grad-CAM')
plt.show()

#Guided Backpropagation算法
# 初始化算法
gb_model = GuidedBackpropReLUModel(model=model, use_cuda=True)
# 生成 Guided Backpropagation熱力圖
gb_origin = gb_model(input_tensor, target_category=None)
gb_show = deprocess_image(gb_origin)
gb_show.shape

plt.imshow(gb_show)
plt.title('Guided Backpropagation')
plt.show()

#兩個(gè)熱力圖逐元素相乘
# Grad-CAM三通道熱力圖
cam_mask = cv2.merge([cam_map, cam_map, cam_map])
cam_mask.shape
# 逐元素相乘
guided_gradcam = deprocess_image(cam_mask * gb_origin)
guided_gradcam.shape
plt.imshow(guided_gradcam)
plt.title('Guided Grad-CAM')
plt.show()
cv2.imwrite('output/C2_guided_gradcam.jpg', guided_gradcam)

對(duì)單張圖像团驱,進(jìn)行Deep feature factorization可解釋性分析摸吠,展示concept discovery概念發(fā)現(xiàn)圖

#導(dǎo)入工具包
import warnings
warnings.filterwarnings('ignore')
import requests

from PIL import Image
import numpy as np
import pandas as pd
import cv2
import json

import matplotlib.pyplot as plt
%matplotlib inline

from pytorch_grad_cam import DeepFeatureFactorization
from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image, deprocess_image
from pytorch_grad_cam import GradCAM
from torchvision.models import resnet50

import torch

#預(yù)處理函數(shù)
def get_image_from_path(img_path):
    '''
    輸入圖像文件路徑,輸出 圖像array嚎花、歸一化圖像array寸痢、預(yù)處理后的tensor
    '''

    img = np.array(Image.open(img_path))
    rgb_img_float = np.float32(img) / 255
    input_tensor = preprocess_image(rgb_img_float,
                                   mean=[0.485, 0.456, 0.406],
                                   std=[0.229, 0.224, 0.225])
    return img, rgb_img_float, input_tensor

def create_labels(concept_scores, top_k=2):
    """ Create a list with the image-net category names of the top scoring categories"""

    df = pd.read_csv('imagenet_class_index.csv')
    labels = {}
    for idx, row in df.iterrows():
        labels[row['ID']] = row['class']
    
    concept_categories = np.argsort(concept_scores, axis=1)[:, ::-1][:, :top_k]
    concept_labels_topk = []
    for concept_index in range(concept_categories.shape[0]):
        categories = concept_categories[concept_index, :]    
        concept_labels = []
        for category in categories:
            score = concept_scores[concept_index, category]
            label = f"{labels[category].split(',')[0]}:{score:.2f}"
            concept_labels.append(label)
        concept_labels_topk.append("\n".join(concept_labels))
    return concept_labels_topk

#載入模型
model = resnet50(pretrained=True).eval()

#載入測(cè)試圖像
img_path = 'test_img/cat_dog.jpg'

#預(yù)處理
img, rgb_img_float, input_tensor = get_image_from_path(img_path)
img.shape
input_tensor.shape

#初始化DFF算法
classifier = model.fc
dff = DeepFeatureFactorization(model=model, 
                               target_layer=model.layer4, 
                               computation_on_concepts=classifier)
# concept個(gè)數(shù)(圖塊顏色個(gè)數(shù))
n_components = 5

concepts, batch_explanations, concept_outputs = dff(input_tensor, n_components)
concepts.shape

#圖像中每個(gè)像素對(duì)應(yīng)的concept熱力圖
# concept個(gè)數(shù) x 高 x 寬
batch_explanations[0].shape
plt.imshow(batch_explanations[0][4])
plt.show()
#concept與類別的關(guān)系
concept_outputs.shape
concept_outputs = torch.softmax(torch.from_numpy(concept_outputs), axis=-1).numpy()    
concept_outputs.shape

#每個(gè)concept展示前top_k個(gè)類別
# 每個(gè)概念展示幾個(gè)類別
top_k = 2
concept_label_strings = create_labels(concept_outputs, top_k=top_k)
concept_label_strings

#生成可視化效果
from pytorch_grad_cam.utils.image import show_factorization_on_image
visualization = show_factorization_on_image(rgb_img_float, 
                                            batch_explanations[0],
                                            image_weight=0.3, # 原始圖像透明度
                                            concept_labels=concept_label_strings)
result = np.hstack((img, visualization))
Image.fromarray(result)
#封裝函數(shù)
def dff_show(img_path='test_img/cat_dog.jpg', n_components=5, top_k=2, hstack=False):
    img, rgb_img_float, input_tensor = get_image_from_path(img_path)
    dff = DeepFeatureFactorization(model=model, 
                                   target_layer=model.layer4, 
                                   computation_on_concepts=classifier)
    concepts, batch_explanations, concept_outputs = dff(input_tensor, n_components)
    concept_outputs = torch.softmax(torch.from_numpy(concept_outputs), axis=-1).numpy()
    concept_label_strings = create_labels(concept_outputs, top_k=top_k)
    visualization = show_factorization_on_image(rgb_img_float, 
                                                batch_explanations[0],
                                                image_weight=0.3, # 原始圖像透明度
                                                concept_labels=concept_label_strings)
    if hstack:
        result = np.hstack((img, visualization))
    else:
        result = visualization
    display(Image.fromarray(result))
dff_show()
dff_show(hstack=True)
dff_show(img_path='test_img/box_tabby.png', hstack=True)
dff_show(img_path='test_img/puppies.jpg', hstack=True)
dff_show(img_path='test_img/bear.jpeg', hstack=True)
dff_show(img_path='test_img/bear.jpeg', n_components=10, top_k=1, hstack=True)
dff_show(img_path='test_img/giraffe_zebra.jpg', n_components=5, top_k=2, hstack=True)

DFF

#導(dǎo)入工具包
import warnings
warnings.filterwarnings('ignore')
import requests

from PIL import Image
import numpy as np
import pandas as pd
import cv2
import json

import matplotlib.pyplot as plt
%matplotlib inline

from pytorch_grad_cam import DeepFeatureFactorization
from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image, deprocess_image
from pytorch_grad_cam import GradCAM
from torchvision.models import resnet50

import torch

# 有 GPU 就用 GPU,沒(méi)有就用 CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device = 'cpu'
print('device', device)

#預(yù)處理函數(shù)
from torchvision import transforms

# 測(cè)試集圖像預(yù)處理-RCTN:縮放紊选、裁剪啼止、轉(zhuǎn) Tensor道逗、歸一化
test_transform = transforms.Compose([transforms.Resize(256),
                                     transforms.CenterCrop(224),
                                     transforms.ToTensor(),
                                     transforms.Normalize(
                                         mean=[0.485, 0.456, 0.406], 
                                         std=[0.229, 0.224, 0.225])
                                    ])

def get_image_from_path(img_path):
    '''
    輸入圖像文件路徑,輸出 圖像array献烦、歸一化圖像array滓窍、預(yù)處理后的tensor
    '''

    img = np.array(Image.open(img_path))
    rgb_img_float = np.float32(img) / 255
    input_tensor = preprocess_image(rgb_img_float,
                                   mean=[0.485, 0.456, 0.406],
                                   std=[0.229, 0.224, 0.225])
    return img, rgb_img_float, input_tensor

def create_labels(concept_scores, top_k=2):
    """ Create a list with the image-net category names of the top scoring categories"""
    
    labels = {
        0:'Hami Melon',
        1:'Cherry Tomatoes',
        2:'Shanzhu',
        3:'Bayberry',
        4:'Grapefruit',
        5:'Lemon',
        6:'Longan',
        7:'Pears',
        8:'Coconut',
        9:'Durian',
        10:'Dragon Fruit',
        11:'Kiwi',
        12:'Pomegranate',
        13:'Sugar orange',
        14:'Carrots',
        15:'Navel orange',
        16:'Mango',
        17:'Balsam pear',
        18:'Apple Red',
        19:'Apple Green',
        20:'Strawberries',
        21:'Litchi',
        22:'Pineapple',
        23:'Grape White',
        24:'Grape Red',
        25:'Watermelon',
        26:'Tomato',
        27:'Cherts',
        28:'Banana',
        29:'Cucumber'
    }
    
    concept_categories = np.argsort(concept_scores, axis=1)[:, ::-1][:, :top_k]
    concept_labels_topk = []
    for concept_index in range(concept_categories.shape[0]):
        categories = concept_categories[concept_index, :]    
        concept_labels = []
        for category in categories:
            score = concept_scores[concept_index, category]
            label = f"{labels[category].split(',')[0]}:{score:.2f}"
            concept_labels.append(label)
        concept_labels_topk.append("\n".join(concept_labels))
    return concept_labels_topk

#載入模型
model = torch.load('checkpoint/fruit30_pytorch_20220814.pth')
model = model.eval().to(device)

#載入測(cè)試圖像
img_path = 'test_img/test_fruits.jpg'
img_pil = Image.open(img_path)
input_tensor = test_transform(img_pil).unsqueeze(0).to(device)
input_tensor.shape

#預(yù)處理
img, rgb_img_float, input_tensor = get_image_from_path(img_path)
img.shape
input_tensor.shape

#初始化DFF算法
classifier = model.fc
dff = DeepFeatureFactorization(model=model, 
                               target_layer=model.layer4, 
                               computation_on_concepts=classifier)

# concept個(gè)數(shù)(圖塊顏色個(gè)數(shù))
n_components = 5

concepts, batch_explanations, concept_outputs = dff(input_tensor, n_components)
concepts.shape

#圖像中每個(gè)像素對(duì)應(yīng)的concept熱力圖
# concept個(gè)數(shù) x 高 x 寬
batch_explanations[0].shape
plt.imshow(batch_explanations[0][2])
plt.show()
#concept與類別的關(guān)系
concept_outputs.shape
concept_outputs = torch.softmax(torch.from_numpy(concept_outputs), axis=-1).numpy()    
concept_outputs.shape

#每個(gè)concept展示前top_k個(gè)類別
# 每個(gè)概念展示幾個(gè)類別
top_k = 2
concept_label_strings = create_labels(concept_outputs, top_k=top_k)
concept_label_strings

#生成可視化效果
from pytorch_grad_cam.utils.image import show_factorization_on_image
visualization = show_factorization_on_image(rgb_img_float, 
                                            batch_explanations[0],
                                            image_weight=0.3, # 原始圖像透明度
                                            concept_labels=concept_label_strings)
result = np.hstack((img, visualization))
Image.fromarray(result)
#封裝函數(shù)
def dff_show(img_path='test_img/cat_dog.jpg', n_components=5, top_k=2, hstack=False):
    img, rgb_img_float, input_tensor = get_image_from_path(img_path)
    dff = DeepFeatureFactorization(model=model, 
                                   target_layer=model.layer4, 
                                   computation_on_concepts=classifier)
    concepts, batch_explanations, concept_outputs = dff(input_tensor, n_components)
    concept_outputs = torch.softmax(torch.from_numpy(concept_outputs), axis=-1).numpy()
    concept_label_strings = create_labels(concept_outputs, top_k=top_k)
    visualization = show_factorization_on_image(rgb_img_float, 
                                                batch_explanations[0],
                                                image_weight=0.3, # 原始圖像透明度
                                                concept_labels=concept_label_strings)
    if hstack:
        result = np.hstack((img, visualization))
    else:
        result = visualization
    display(Image.fromarray(result))
dff_show(img_path='test_img/test_草莓.jpg', hstack=True)
dff_show(img_path='test_img/test_火龍果.jpg', hstack=True)
dff_show(img_path='test_img/test_石榴.jpg', hstack=True)
dff_show(img_path='test_img/test_bananan.jpg', hstack=True)
dff_show(img_path='test_img/test_kiwi.jpg', hstack=True)

3. Captum工具包:遮擋、梯度

CAM&Captum algorithm

4. shap

#安裝配置環(huán)境
pip install numpy pandas matplotlib requests tqdm opencv-python pillow shap tensorflow keras -i https://pypi.tuna.tsinghua.edu.cn/simple

pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113

import shap
import os
# 存放測(cè)試圖片
os.mkdir('test_img')
# 存放結(jié)果文件
os.mkdir('output')
# 存放訓(xùn)練得到的模型權(quán)重
os.mkdir('checkpoint')
# 存放標(biāo)注文件
os.mkdir('data')

#下載ImageNet1000類別信息
wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/dataset/meta_data/imagenet_class_index.csv -P data

# 下載樣例模型文件
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/checkpoints/fruit30_pytorch_20220814.pth -P checkpoint

# 下載 類別名稱 和 ID索引號(hào) 的映射字典
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/dataset/fruit30/labels_to_idx.npy -P data
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/dataset/fruit30/idx_to_labels.npy -P data

!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220919-explain/imagenet_class_index.json -P data

!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/dataset/fruit30/idx_to_labels_en.npy -P data

#下載測(cè)試圖像文件至test_img文件夾
# 邊牧犬仿荆,來(lái)源:https://www.woopets.fr/assets/races/000/066/big-portrait/border-collie.jpg
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/test/border-collie.jpg -P test_img

!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/test/cat_dog.jpg -P test_img

!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/test/0818/room_video.mp4 -P test_img

!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/test/swan-3299528_1280.jpg -P test_img

# 草莓圖像贰您,來(lái)源:https://www.pexels.com/zh-cn/photo/4828489/
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/test/0818/test_草莓.jpg -P test_img

!wget https://zihao-openmmlab.obs.myhuaweicloud.com/20220716-mmclassification/test/0818/test_fruits.jpg -P test_img

!wget https://zihao-openmmlab.obs.myhuaweicloud.com/20220716-mmclassification/test/0818/test_orange_2.jpg -P test_img 

!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/test/banana-kiwi.png -P test_img

對(duì)自己訓(xùn)練得到的30類水果圖像分類模型進(jìn)行可解釋性分析,可視化制定預(yù)測(cè)類別的shap值熱力圖

#導(dǎo)入工具包
import json
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
import shap

# 有 GPU 就用 GPU拢操,沒(méi)有就用 CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device', device)

#載入30類水果圖像分類圖像
model = torch.load('checkpoint/fruit30_pytorch_20220814.pth')
model = model.eval().to(device)

#載入分類數(shù)據(jù)集的類別
idx_to_labels = np.load('data/idx_to_labels_en.npy', allow_pickle=True).item()
idx_to_labels
class_names = list(idx_to_labels.values())
class_names

#載入一張測(cè)試圖像锦亦,整理緯度
# img_path = 'test_img/test_草莓.jpg'
img_path = 'test_img/test_fruits.jpg'

img_pil = Image.open(img_path)
X = torch.Tensor(np.array(img_pil)).unsqueeze(0)
X.shape

mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

def nhwc_to_nchw(x: torch.Tensor) -> torch.Tensor:
    if x.dim() == 4:
        x = x if x.shape[1] == 3 else x.permute(0, 3, 1, 2)
    elif x.dim() == 3:
        x = x if x.shape[0] == 3 else x.permute(2, 0, 1)
    return x

def nchw_to_nhwc(x: torch.Tensor) -> torch.Tensor:
    if x.dim() == 4:
        x = x if x.shape[3] == 3 else x.permute(0, 2, 3, 1)
    elif x.dim() == 3:
        x = x if x.shape[2] == 3 else x.permute(1, 2, 0)
    return x 
        

transform= [
    transforms.Lambda(nhwc_to_nchw),
    transforms.Resize(224),
    transforms.Lambda(lambda x: x*(1/255)),
    transforms.Normalize(mean=mean, std=std),
    transforms.Lambda(nchw_to_nhwc),
]

inv_transform= [
    transforms.Lambda(nhwc_to_nchw),
    transforms.Normalize(
        mean = (-1 * np.array(mean) / np.array(std)).tolist(),
        std = (1 / np.array(std)).tolist()
    ),
    transforms.Lambda(nchw_to_nhwc),
]

transform = torchvision.transforms.Compose(transform)
inv_transform = torchvision.transforms.Compose(inv_transform)

#構(gòu)建模型預(yù)測(cè)函數(shù)
def predict(img: np.ndarray) -> torch.Tensor:
    img = nhwc_to_nchw(torch.Tensor(img)).to(device)
    output = model(img)
    return output

def predict(img):
    img = nhwc_to_nchw(torch.Tensor(img)).to(device)
    output = model(img)
    return output

#測(cè)試整個(gè)工作流正常
Xtr = transform(X)
out = predict(Xtr[0:1])
out.shape
classes = torch.argmax(out, axis=1).detach().cpu().numpy()
print(f'Classes: {classes}: {np.array(class_names)[classes]}')
#設(shè)置shap可解釋性分析算法
# 構(gòu)造輸入圖像
input_img = Xtr[0].unsqueeze(0)
input_img.shape
batch_size = 50
n_evals = 5000 # 迭代次數(shù)越大,顯著性分析粒度越精細(xì)令境,計(jì)算消耗時(shí)間越長(zhǎng)

# 定義 mask杠园,遮蓋輸入圖像上的局部區(qū)域
masker_blur = shap.maskers.Image("blur(64, 64)", Xtr[0].shape)

# 創(chuàng)建可解釋分析算法
explainer = shap.Explainer(predict, masker_blur, output_names=class_names)

#指定單個(gè)預(yù)測(cè)類別
# 28:香蕉 banana
shap_values = explainer(input_img, max_evals=n_evals, batch_size=batch_size, outputs=[28])
# 整理張量維度
shap_values.data = inv_transform(shap_values.data).cpu().numpy()[0] # 原圖
shap_values.values = [val for val in np.moveaxis(shap_values.values[0],-1, 0)] # shap值熱力圖
# 原圖
shap_values.data.shape
# shap值熱力圖
shap_values.values[0].shape

# 可視化
shap.image_plot(shap_values=shap_values.values,
                pixel_values=shap_values.data,
                labels=shap_values.output_names)
#指定多個(gè)預(yù)測(cè)類別
# 5 檸檬
# 12 石榴
# 15 臍橙
shap_values = explainer(input_img, max_evals=n_evals, batch_size=batch_size, outputs=[5, 12, 15])

# 整理張量維度
shap_values.data = inv_transform(shap_values.data).cpu().numpy()[0] # 原圖
shap_values.values = [val for val in np.moveaxis(shap_values.values[0],-1, 0)] # shap值熱力圖

# shap值熱力圖:每個(gè)像素,對(duì)于每個(gè)類別的shap值
shap_values.shape

# 可視化
shap.image_plot(shap_values=shap_values.values,
                pixel_values=shap_values.data,
                labels=shap_values.output_names)
#前k個(gè)預(yù)測(cè)類別
topk = 5
shap_values = explainer(input_img, max_evals=n_evals, batch_size=batch_size, outputs=shap.Explanation.argsort.flip[:topk])
# shap值熱力圖:每個(gè)像素舔庶,對(duì)于每個(gè)類別的shap值
shap_values.shape
# 整理張量維度
shap_values.data = inv_transform(shap_values.data).cpu().numpy()[0] # 原圖
shap_values.values = [val for val in np.moveaxis(shap_values.values[0],-1, 0)] # 各個(gè)類別的shap值熱力圖
# 各個(gè)類別的shap值熱力圖
len(shap_values.values)
# 第一個(gè)類別抛蚁,shap值熱力圖
shap_values.values[0].shape

# 可視化
shap.image_plot(shap_values=shap_values.values,
                pixel_values=shap_values.data,
                labels=shap_values.output_names
                )

使用shap庫(kù)的GradientExplainer,對(duì)預(yù)訓(xùn)練VGG16模型的中間層輸出惕橙,計(jì)算shap值

#導(dǎo)入工具包
import torch, torchvision
from torch import nn
from torchvision import transforms, models, datasets
import shap
import json
import numpy as np

#載入模型
# load the model
model = models.vgg16(pretrained=True).eval()

#載入數(shù)據(jù)集瞧甩、預(yù)處理
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

def normalize(image):
    if image.max() > 1:
        image /= 255
    image = (image - mean) / std
    # in addition, roll the axis so that they suit pytorch
    return torch.tensor(image.swapaxes(-1, 1).swapaxes(2, 3)).float()

#指定測(cè)試圖像
X, y = shap.datasets.imagenet50()

X /= 255

to_explain = X[[39, 41]]

#載入類別和索引號(hào)
url = "https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json"
fname = shap.datasets.cache(url)
with open(fname) as f:
    class_names = json.load(f)

#計(jì)算模型中間層,在輸入圖像上的shap值
# 指定中間層
layer_index = 7 

# 迭代次數(shù)弥鹦,200次大約需計(jì)算 5 分鐘
samples = 200
e = shap.GradientExplainer((model, model.features[layer_index]), normalize(X))
shap_values,indexes = e.shap_values(normalize(to_explain), ranked_outputs=2, nsamples=samples)

#預(yù)測(cè)類別名稱
index_names = np.vectorize(lambda x: class_names[str(x)][1])(indexes)
index_names

#可視化
shap_values = [np.swapaxes(np.swapaxes(s, 2, 3), 1, -1) for s in shap_values]

shap.image_plot(shap_values, to_explain, index_names)

#在圖像上引入局部平滑
# 計(jì)算模型中間層肚逸,在輸入圖像上的shap值
explainer = shap.GradientExplainer((model, model.features[layer_index]), normalize(X), local_smoothing=0.5)
shap_values, indexes = explainer.shap_values(normalize(to_explain), ranked_outputs=2, nsamples=samples)

# 預(yù)測(cè)類別名稱
index_names = np.vectorize(lambda x: class_names[str(x)][1])(indexes)

# 可視化
shap_values = [np.swapaxes(np.swapaxes(s, 2, 3), 1, -1) for s in shap_values]

shap.image_plot(shap_values, to_explain, index_names)

將輸入圖像局部遮擋,對(duì)resnet50圖像分類模型的預(yù)測(cè)結(jié)果進(jìn)行可解釋性分析

#導(dǎo)入工具包
import json
import numpy as np
import tensorflow as tf
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
import shap

#導(dǎo)入預(yù)訓(xùn)練模型
model = ResNet50(weights='imagenet')

#導(dǎo)入數(shù)據(jù)集
X, y = shap.datasets.imagenet50()

#構(gòu)建模型預(yù)測(cè)函數(shù)
def f(x):
    tmp = x.copy()
    preprocess_input(tmp)
    return model(tmp)

#構(gòu)建局部遮擋函數(shù)
masker = shap.maskers.Image("inpaint_telea", X[0].shape)

#輸出類別名稱
url = "https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json"
with open(shap.datasets.cache(url)) as file:
    class_names = [v[1] for v in json.load(file).values()]

#創(chuàng)建Explainer
explainer = shap.Explainer(f, masker, output_names=class_names)
#計(jì)算shap值
shap_values = explainer(X[1:3], max_evals=100, batch_size=50, outputs=shap.Explanation.argsort.flip[:4]) 
#可視化
shap.image_plot(shap_values)
#更加細(xì)粒度的shap計(jì)算和可視化
masker_blur = shap.maskers.Image("blur(128,128)", X[0].shape)

explainer_blur = shap.Explainer(f, masker_blur, output_names=class_names)

shap_values_fine = explainer_blur(X[1:3], max_evals=5000, batch_size=50, outputs=shap.Explanation.argsort.flip[:4]) 

shap.image_plot(shap_values_fine)
shap.image_plot(shap_values_fine)

5. LIME

LIME algorithm
LIME&shap algorithm

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末彬坏,一起剝皮案震驚了整個(gè)濱河市朦促,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌栓始,老刑警劉巖务冕,帶你破解...
    沈念sama閱讀 218,546評(píng)論 6 507
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場(chǎng)離奇詭異幻赚,居然都是意外死亡禀忆,警方通過(guò)查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,224評(píng)論 3 395
  • 文/潘曉璐 我一進(jìn)店門坯屿,熙熙樓的掌柜王于貴愁眉苦臉地迎上來(lái)油湖,“玉大人,你說(shuō)我怎么就攤上這事领跛》Φ拢” “怎么了?”我有些...
    開(kāi)封第一講書人閱讀 164,911評(píng)論 0 354
  • 文/不壞的土叔 我叫張陵,是天一觀的道長(zhǎng)喊括。 經(jīng)常有香客問(wèn)我胧瓜,道長(zhǎng),這世上最難降的妖魔是什么郑什? 我笑而不...
    開(kāi)封第一講書人閱讀 58,737評(píng)論 1 294
  • 正文 為了忘掉前任府喳,我火速辦了婚禮,結(jié)果婚禮上蘑拯,老公的妹妹穿的比我還像新娘钝满。我一直安慰自己,他們只是感情好申窘,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,753評(píng)論 6 392
  • 文/花漫 我一把揭開(kāi)白布弯蚜。 她就那樣靜靜地躺著,像睡著了一般剃法。 火紅的嫁衣襯著肌膚如雪碎捺。 梳的紋絲不亂的頭發(fā)上,一...
    開(kāi)封第一講書人閱讀 51,598評(píng)論 1 305
  • 那天贷洲,我揣著相機(jī)與錄音收厨,去河邊找鬼。 笑死优构,一個(gè)胖子當(dāng)著我的面吹牛诵叁,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播钦椭,決...
    沈念sama閱讀 40,338評(píng)論 3 418
  • 文/蒼蘭香墨 我猛地睜開(kāi)眼黎休,長(zhǎng)吁一口氣:“原來(lái)是場(chǎng)噩夢(mèng)啊……” “哼!你這毒婦竟也來(lái)了玉凯?” 一聲冷哼從身側(cè)響起,我...
    開(kāi)封第一講書人閱讀 39,249評(píng)論 0 276
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤联贩,失蹤者是張志新(化名)和其女友劉穎漫仆,沒(méi)想到半個(gè)月后,有當(dāng)?shù)厝嗽跇?shù)林里發(fā)現(xiàn)了一具尸體泪幌,經(jīng)...
    沈念sama閱讀 45,696評(píng)論 1 314
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡盲厌,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,888評(píng)論 3 336
  • 正文 我和宋清朗相戀三年,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了祸泪。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片吗浩。...
    茶點(diǎn)故事閱讀 40,013評(píng)論 1 348
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡,死狀恐怖没隘,靈堂內(nèi)的尸體忽然破棺而出懂扼,到底是詐尸還是另有隱情,我是刑警寧澤,帶...
    沈念sama閱讀 35,731評(píng)論 5 346
  • 正文 年R本政府宣布阀湿,位于F島的核電站赶熟,受9級(jí)特大地震影響,放射性物質(zhì)發(fā)生泄漏陷嘴。R本人自食惡果不足惜映砖,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,348評(píng)論 3 330
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望灾挨。 院中可真熱鬧邑退,春花似錦、人聲如沸劳澄。這莊子的主人今日做“春日...
    開(kāi)封第一講書人閱讀 31,929評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)浴骂。三九已至乓土,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間溯警,已是汗流浹背趣苏。 一陣腳步聲響...
    開(kāi)封第一講書人閱讀 33,048評(píng)論 1 270
  • 我被黑心中介騙來(lái)泰國(guó)打工, 沒(méi)想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留梯轻,地道東北人食磕。 一個(gè)月前我還...
    沈念sama閱讀 48,203評(píng)論 3 370
  • 正文 我出身青樓,卻偏偏與公主長(zhǎng)得像喳挑,于是被迫代替她去往敵國(guó)和親彬伦。 傳聞我的和親對(duì)象是個(gè)殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 44,960評(píng)論 2 355

推薦閱讀更多精彩內(nèi)容