本次主要對pytorch中的模型加載方法算柳、各個層對應(yīng)的名稱及tensor值查看的方法做一個總結(jié)土砂。
模型加載
import torch
# 模型文件位置
model_file = 'deeplab-mobilenet.pth.tar' # 或者.pth格式的模型文件
# 創(chuàng)建模型對象
model = MyDeepLab()
# 加載模型參數(shù)祝谚,若為cpu加載,則后面添加參數(shù): map_location='cpu'
ckpt = torch.load(model_file) # cpu加載方式 ckpt = torch.load(model_file, map_location='cpu')
# 模型參數(shù)恢復
model.load_state_dict(ckpt) # 若要忽略名稱不匹配的層:model.load_state_dict(ckpt, strict=False)
模型文件參數(shù)查看
# 使用上面代碼塊中的ckpt们豌,其實質(zhì)上是一個dict, key是layer名稱, value是對應(yīng)的tensor值
# 獲取所有l(wèi)ayer名稱
layer_name = list(ckpt.keys()) # 獲取預(yù)訓練模型各層的名稱涯捻,并轉(zhuǎn)為list
# 查看指定layer的tensor值
print(ckpt[ layer_name[2]) # 查看第2個層的參數(shù)值
模型對象參數(shù)查看
# 這里與模型文件參數(shù)的區(qū)別在于:這是針對代碼中創(chuàng)建的模型對象,查看其各個layer的名稱與tensor值
# 獲取模型中所有l(wèi)ayer的名稱
layer_name = list(model.state_dict().keys())
# 查看指定layer的tensor值
print( model.state_dict()[ layer_name[2] ])
模型文件中l(wèi)ayer名稱與模型對象中l(wèi)ayer名稱的區(qū)別
主要區(qū)別在于模型對象中的名稱最開始會多一個module.
例如:
模型文件中l(wèi)ayer名稱: conv5.weight
模型對象中l(wèi)ayer名稱: module.conv5.weight