Task2:預(yù)訓(xùn)練模型預(yù)測

注意睹限,ImageNet 1000類中并不包含“西瓜”

1. 安裝配置環(huán)境

pip install numpy pandas matplotlib requests tqdm opencv-python pillow gc -i https://pypi.tuna.tsinghua.edu.cn/simple
#安裝pytorch
pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
#安裝mmcv-full
pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu113/torch1.10.0/index.html
#下載中文字體文件
wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/dataset/SimHei.ttf
#下載Image Net1000類別信息
wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/dataset/meta_data/imagenet_class_index.csv

#創(chuàng)建目錄
import os
# 存放測試圖片
os.mkdir('test_img')

# 存放結(jié)果文件
os.mkdir('output')

# 下載測試圖像文件 至 test_img 文件夾

!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/test/watermelon1.jpg -O test_img/watermelon1.jpg
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/test/banana1.jpg -O test_img/banana1.jpg
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/test/cat1.jpg -O test_img/cat1.jpg

# 哈士奇鲸沮,來源:https://www.pexels.com/zh-cn/photo/2853130/
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/test/husky1.jpeg -O test_img/husky1.jpeg

# 貓狗兰珍,來源:https://unsplash.com/photos/ouo1hbizWwo
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/test/cat_dog.jpg -O test_img/cat_dog.jpg

!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/test/video_2.mp4 -O test_img/video_2.mp4

2. 預(yù)測單張圖像

import os

import cv2

import pandas as pd
import numpy as np

import torch

import matplotlib.pyplot as plt

# 有 GPU 就用 GPU,沒有就用 CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device', device)

#載入預(yù)訓(xùn)練圖像分類模型
from torchvision import models
# 載入預(yù)訓(xùn)練圖像分類模型
dir(models)
model = models.resnet18(pretrained=True) 

# model = models.resnet152(pretrained=True)

model = model.eval()
model = model.to(device)

#圖像預(yù)處理
from torchvision import transforms

# 測試集圖像預(yù)處理-RCTN:縮放裁剪骤竹、轉(zhuǎn) Tensor烁挟、歸一化
test_transform = transforms.Compose([transforms.Resize(256),
                                     transforms.CenterCrop(224),
                                     transforms.ToTensor(),
                                     transforms.Normalize(
                                         mean=[0.485, 0.456, 0.406], 
                                         std=[0.229, 0.224, 0.225])
                                    ])

#載入一張測試圖像
# img_path = 'test_img/banana1.jpg'
# img_path = 'test_img/husky1.jpeg'
img_path = 'test_img/basketball_shoe.jpeg'

# img_path = 'test_img/cat_dog.jpg'

# 用 pillow 載入
from PIL import Image
img_pil = Image.open(img_path)

img_pil
img_pil

np.array(img_pil).shape

#圖像分類訓(xùn)練
input_img = test_transform(img_pil) # 預(yù)處理
input_img.shape
input_img = input_img.unsqueeze(0).to(device)
input_img.shape

# 執(zhí)行前向預(yù)測挥转,得到所有類別的 logit 預(yù)測分?jǐn)?shù)
pred_logits = model(input_img) 
pred_logits.shape
# pred_logits

import torch.nn.functional as F
pred_softmax = F.softmax(pred_logits, dim=1) # 對 logit 分?jǐn)?shù)做 softmax 運算
pred_softmax.shape
# pred_softmax

#預(yù)測結(jié)果分析
plt.figure(figsize=(8,4))

x = range(1000)
y = pred_softmax.cpu().detach().numpy()[0]

ax = plt.bar(x, y, alpha=0.5, width=0.3, color='yellow', edgecolor='red', lw=3)
plt.ylim([0, 1.0]) # y軸取值范圍
# plt.bar_label(ax, fmt='%.2f', fontsize=15) # 置信度數(shù)值

plt.xlabel('Class', fontsize=20)
plt.ylabel('Confidence', fontsize=20)
plt.tick_params(labelsize=16) # 坐標(biāo)文字大小
plt.title(img_path, fontsize=25)

plt.show()

#取置信度最高的N個結(jié)果
n = 10
top_n = torch.topk(pred_softmax, n)
top_n

# 解析出類別
pred_ids = top_n[1].cpu().detach().numpy().squeeze()
pred_ids
#array([954, 939, 941, 940, 943, 506, 945, 991, 883, 600])

# 解析出置信度
confs = top_n[0].cpu().detach().numpy().squeeze()
confs
#array([9.9776304e-01, 1.2627112e-03, 4.3848471e-04, 1.3670148e-04,
#       6.2257830e-05, 6.0630489e-05, 3.7490368e-05, 2.2272276e-05,
#       1.6812892e-05, 1.5484391e-05], dtype=float32)

#載入Image Net圖像分類標(biāo)簽
df = pd.read_csv('imagenet_class_index.csv')
df
idx_to_labels = {}
for idx, row in df.iterrows():
    idx_to_labels[row['ID']] = [row['wordnet'], row['class']]
# idx_to_labels

#在原圖上標(biāo)注出分類結(jié)果
# 用 opencv 載入原圖
img_bgr = cv2.imread(img_path)
for i in range(n):
    class_name = idx_to_labels[pred_ids[i]][1] # 獲取類別名稱
    confidence = confs[i] * 100 # 獲取置信度
    text = '{:<15} {:>.4f}'.format(class_name, confidence)
    print(text)
    
    # !圖片净蚤,添加的文字钥组,左上角坐標(biāo),字體今瀑,字號程梦,bgr顏色,線寬
    img_bgr = cv2.putText(img_bgr, text, (25, 50 + 40 * i), cv2.FONT_HERSHEY_SIMPLEX, 1.25, (0, 0, 255), 3)

banana 99.7763
zucchini 0.1263
acorn_squash 0.0438
spaghetti_squash 0.0137
cucumber 0.0062
coil 0.0061
bell_pepper 0.0037
coral_fungus 0.0022
vase 0.0017
hook 0.0015

# 保存圖像
cv2.imwrite('output/img_pred.jpg', img_bgr)
img_pred.jpg
# 載入預(yù)測結(jié)果圖像
img_pred = Image.open('output/img_pred.jpg')
img_pred
fig = plt.figure(figsize=(18,6))

# 繪制左圖-預(yù)測圖
ax1 = plt.subplot(1,2,1)
ax1.imshow(img_pred)
ax1.axis('off')

# 繪制右圖-柱狀圖
ax2 = plt.subplot(1,2,2)
x = df['ID']
y = pred_softmax.cpu().detach().numpy()[0]
ax2.bar(x, y, alpha=0.5, width=0.3, color='yellow', edgecolor='red', lw=3)

plt.ylim([0, 1.0]) # y軸取值范圍
plt.title('{} Classification'.format(img_path), fontsize=30)
plt.xlabel('Class', fontsize=20)
plt.ylabel('Confidence', fontsize=20)
ax2.tick_params(labelsize=16) # 坐標(biāo)文字大小

plt.tight_layout()
fig.savefig('output/預(yù)測圖+柱狀圖.jpg')
預(yù)測圖+柱狀圖.jpg
#預(yù)測結(jié)果以表格輸出
pred_df = pd.DataFrame() # 預(yù)測結(jié)果表格
for i in range(n):
    class_name = idx_to_labels[pred_ids[i]][1] # 獲取類別名稱
    label_idx = int(pred_ids[i]) # 獲取類別號
    wordnet = idx_to_labels[pred_ids[i]][0] # 獲取 WordNet
    confidence = confs[i] * 100 # 獲取置信度
    pred_df = pred_df.append({'Class':class_name, 'Class_ID':label_idx, 'Confidence(%)':confidence, 'WordNet':wordnet}, ignore_index=True) # 預(yù)測結(jié)果表格添加一行
display(pred_df) # 展示預(yù)測結(jié)果表格

3. 預(yù)測視頻文件

import os
import time
import shutil
import tempfile
from tqdm import tqdm

import cv2
from PIL import Image

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
plt.rcParams['axes.unicode_minus']=False  # 用來正常顯示負(fù)號
plt.rcParams['font.sans-serif']=['SimHei']  # 用來正常顯示中文標(biāo)簽
import gc

import torch
import torch.nn.functional as F
from torchvision import models

import mmcv

# 有 GPU 就用 GPU橘荠,沒有就用 CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device:', device)

# 后端繪圖屿附,不顯示,只保存
import matplotlib
matplotlib.use('Agg')

#載入預(yù)訓(xùn)練模型
model = models.resnet18(pretrained=True)
model = model.eval()
model = model.to(device)

#載入圖像分類標(biāo)簽
df = pd.read_csv('imagenet_class_index.csv')
idx_to_labels = {}
for idx, row in df.iterrows():
    idx_to_labels[row['ID']] = [row['wordnet'], row['class']]
# idx_to_labels

#圖像預(yù)處理
from torchvision import transforms

# 測試集圖像預(yù)處理-RCTN:縮放裁剪哥童、轉(zhuǎn) Tensor挺份、歸一化
test_transform = transforms.Compose([transforms.Resize(256),
                                     transforms.CenterCrop(224),
                                     transforms.ToTensor(),
                                     transforms.Normalize(
                                         mean=[0.485, 0.456, 0.406], 
                                         std=[0.229, 0.224, 0.225])
                                    ])

def pred_single_frame(img, n=5):
    '''
    輸入攝像頭畫面bgr-array,輸出前n個圖像分類預(yù)測結(jié)果的圖像bgr-array
    '''
    img_bgr = img
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # BGR 轉(zhuǎn) RGB
    img_pil = Image.fromarray(img_rgb) # array 轉(zhuǎn) pil
    input_img = test_transform(img_pil).unsqueeze(0).to(device) # 預(yù)處理
    pred_logits = model(input_img) # 執(zhí)行前向預(yù)測贮懈,得到所有類別的 logit 預(yù)測分?jǐn)?shù)
    pred_softmax = F.softmax(pred_logits, dim=1) # 對 logit 分?jǐn)?shù)做 softmax 運算
    
    top_n = torch.topk(pred_softmax, n) # 取置信度最大的 n 個結(jié)果
    pred_ids = top_n[1].cpu().detach().numpy().squeeze() # 解析出類別
    confs = top_n[0].cpu().detach().numpy().squeeze() # 解析出置信度
    
    # 在圖像上寫字
    for i in range(n):
        class_name = idx_to_labels[pred_ids[i]][1] # 獲取類別名稱
        confidence = confs[i] * 100 # 獲取置信度
        text = '{:<15} {:>.4f}'.format(class_name, confidence)

        # !圖片匀泊,添加的文字,左上角坐標(biāo)错邦,字體探赫,字號型宙,bgr顏色撬呢,線寬
        img_bgr = cv2.putText(img_bgr, text, (25, 50 + 40 * i), cv2.FONT_HERSHEY_SIMPLEX, 1.25, (0, 0, 255), 3)
        
    return img_bgr, pred_softmax
#輸入視頻
input_video = 'test_img/video_2.mp4'
### 可視化方案一:原始圖像+預(yù)測結(jié)果文字
# 創(chuàng)建臨時文件夾,存放每幀結(jié)果
temp_out_dir = time.strftime('%Y%m%d%H%M%S')
os.mkdir(temp_out_dir)
print('創(chuàng)建文件夾 {} 用于存放每幀預(yù)測結(jié)果'.format(temp_out_dir))

# 讀入待預(yù)測視頻
imgs = mmcv.VideoReader(input_video)

prog_bar = mmcv.ProgressBar(len(imgs))

# 對視頻逐幀處理
for frame_id, img in enumerate(imgs):
    ## 處理單幀畫面
    img, pred_softmax = pred_single_frame(img, n=5)
    # 將處理后的該幀畫面圖像文件妆兑,保存至 /tmp 目錄下
    cv2.imwrite(f'{temp_out_dir}/{frame_id:06d}.jpg', img)    
    prog_bar.update() # 更新進度條

# 把每一幀串成視頻文件
mmcv.frames2video(temp_out_dir, 'output/output_pred.mp4', fps=imgs.fps, fourcc='mp4v')

shutil.rmtree(temp_out_dir) # 刪除存放每幀畫面的臨時文件夾
print('刪除臨時文件夾', temp_out_dir)

### 可視化方案二:原始圖像+預(yù)測結(jié)果文字+各類別置信度柱狀圖
def pred_single_frame_bar(img):
    '''
    輸入pred_single_frame函數(shù)輸出的bgr-array魂拦,加柱狀圖,保存
    '''
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # BGR 轉(zhuǎn) RGB
    fig = plt.figure(figsize=(18,6))
    # 繪制左圖-視頻圖
    ax1 = plt.subplot(1,2,1)
    ax1.imshow(img)
    ax1.axis('off')
    # 繪制右圖-柱狀圖
    ax2 = plt.subplot(1,2,2)
    x = range(1000)
    y = pred_softmax.cpu().detach().numpy()[0]
    ax2.bar(x, y, alpha=0.5, width=0.3, color='yellow', edgecolor='red', lw=3)
    plt.xlabel('類別', fontsize=20)
    plt.ylabel('置信度', fontsize=20)
    ax2.tick_params(labelsize=16) # 坐標(biāo)文字大小
    plt.ylim([0, 1.0]) # y軸取值范圍
    plt.xlabel('類別',fontsize=25)
    plt.ylabel('置信度',fontsize=25)
    plt.title('圖像分類預(yù)測結(jié)果', fontsize=30)
    
    plt.tight_layout()
    fig.savefig(f'{temp_out_dir}/{frame_id:06d}.jpg')
    # 釋放內(nèi)存
    fig.clf()
    plt.close()
    gc.collect()

# 創(chuàng)建臨時文件夾搁嗓,存放每幀結(jié)果
temp_out_dir = time.strftime('%Y%m%d%H%M%S')
os.mkdir(temp_out_dir)
print('創(chuàng)建文件夾 {} 用于存放每幀預(yù)測結(jié)果'.format(temp_out_dir))

# 讀入待預(yù)測視頻
imgs = mmcv.VideoReader(input_video)

prog_bar = mmcv.ProgressBar(len(imgs))

# 對視頻逐幀處理
for frame_id, img in enumerate(imgs):
    
    ## 處理單幀畫面
    img, pred_softmax = pred_single_frame(img, n=5)
    img = pred_single_frame_bar(img)
    
    prog_bar.update() # 更新進度條

# 把每一幀串成視頻文件
mmcv.frames2video(temp_out_dir, 'output/output_bar.mp4', fps=imgs.fps, fourcc='mp4v')

shutil.rmtree(temp_out_dir) # 刪除存放每幀畫面的臨時文件夾
print('刪除臨時文件夾', temp_out_dir)
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末芯勘,一起剝皮案震驚了整個濱河市,隨后出現(xiàn)的幾起案子腺逛,更是在濱河造成了極大的恐慌荷愕,老刑警劉巖,帶你破解...
    沈念sama閱讀 218,546評論 6 507
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件棍矛,死亡現(xiàn)場離奇詭異安疗,居然都是意外死亡,警方通過查閱死者的電腦和手機够委,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,224評論 3 395
  • 文/潘曉璐 我一進店門荐类,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人茁帽,你說我怎么就攤上這事玉罐∏停” “怎么了?”我有些...
    開封第一講書人閱讀 164,911評論 0 354
  • 文/不壞的土叔 我叫張陵吊输,是天一觀的道長饶号。 經(jīng)常有香客問我,道長季蚂,這世上最難降的妖魔是什么讨韭? 我笑而不...
    開封第一講書人閱讀 58,737評論 1 294
  • 正文 為了忘掉前任,我火速辦了婚禮癣蟋,結(jié)果婚禮上透硝,老公的妹妹穿的比我還像新娘。我一直安慰自己疯搅,他們只是感情好濒生,可當(dāng)我...
    茶點故事閱讀 67,753評論 6 392
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著幔欧,像睡著了一般罪治。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上礁蔗,一...
    開封第一講書人閱讀 51,598評論 1 305
  • 那天觉义,我揣著相機與錄音,去河邊找鬼浴井。 笑死晒骇,一個胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的磺浙。 我是一名探鬼主播洪囤,決...
    沈念sama閱讀 40,338評論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼,長吁一口氣:“原來是場噩夢啊……” “哼撕氧!你這毒婦竟也來了瘤缩?” 一聲冷哼從身側(cè)響起,我...
    開封第一講書人閱讀 39,249評論 0 276
  • 序言:老撾萬榮一對情侶失蹤伦泥,失蹤者是張志新(化名)和其女友劉穎剥啤,沒想到半個月后,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體不脯,經(jīng)...
    沈念sama閱讀 45,696評論 1 314
  • 正文 獨居荒郊野嶺守林人離奇死亡府怯,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 37,888評論 3 336
  • 正文 我和宋清朗相戀三年,在試婚紗的時候發(fā)現(xiàn)自己被綠了跨新。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片富腊。...
    茶點故事閱讀 40,013評論 1 348
  • 序言:一個原本活蹦亂跳的男人離奇死亡,死狀恐怖域帐,靈堂內(nèi)的尸體忽然破棺而出赘被,到底是詐尸還是另有隱情是整,我是刑警寧澤,帶...
    沈念sama閱讀 35,731評論 5 346
  • 正文 年R本政府宣布民假,位于F島的核電站浮入,受9級特大地震影響,放射性物質(zhì)發(fā)生泄漏羊异。R本人自食惡果不足惜事秀,卻給世界環(huán)境...
    茶點故事閱讀 41,348評論 3 330
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望野舶。 院中可真熱鬧易迹,春花似錦、人聲如沸平道。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,929評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽一屋。三九已至窘疮,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間冀墨,已是汗流浹背闸衫。 一陣腳步聲響...
    開封第一講書人閱讀 33,048評論 1 270
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留诽嘉,地道東北人蔚出。 一個月前我還...
    沈念sama閱讀 48,203評論 3 370
  • 正文 我出身青樓,卻偏偏與公主長得像含懊,于是被迫代替她去往敵國和親身冬。 傳聞我的和親對象是個殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點故事閱讀 44,960評論 2 355

推薦閱讀更多精彩內(nèi)容