我們以resnet18為例野芒,介紹幾種獲取模型摘要的方法缕允。
import torchvistion
model = torchvision.models.resnet18()
1.直接使用PrettyTable
from prettytable import PrettyTable
table = PrettyTable(['Modules', 'Parameters'])
total_params = 0
for name, parameter in model.named_parameters():
if not parameter.requires_grad: continue
params = parameter.numel()
table.add_row([name, params])
total_params+=params
print(table)
print(f'Total Trainable Params: {total_params}')
效果如下:
PrettyTable
比較簡(jiǎn)單旁趟,也沒(méi)有模型的輸入輸出情況艰匙。
2. TorchSummary
from torchsummary import summary
summary(model, input_size = (3, 64, 64), batch_size = -1)
TorchSummary
整體看美觀了很多限煞,也有了輸出的維度。但是如果能打印出模型的層次結(jié)構(gòu)就更好了旬薯。
3. torchinfo
import torchinfo
torchinfo.summary(model, (3, 224, 224), batch_dim = 0, col_names = ('input_size', 'output_size', 'num_params', 'kernel_size', 'mult_adds'), verbose = 0)
torchinfo
這種方式更加美觀晰骑,且內(nèi)容詳細(xì),灰常棒绊序。