-
max()
torch.max(input, dim)
dim參數(shù)指出刪去哪一維度虑椎,0-行凉逛,1-列;輸出兩個(gè)tensor踩验,第一個(gè)得到最大值結(jié)果,第二個(gè)給出相對位置(0-index)
>>> a = torch.randn(4, 4)
>>> a
tensor([[-1.2360, -0.2942, -0.1222, 0.8475],
[ 1.1949, -1.1127, -2.2379, -0.6702],
[ 1.5717, -0.9207, 0.1297, -1.8768],
[-0.6172, 1.0036, -0.6060, -0.2432]])
>>> torch.max(a, 1)
(tensor([ 0.8475, 1.1949, 1.5717, 1.0036]), tensor([ 3, 0, 0, 1]))
dim=1商玫,刪除列的維度箕憾,只有1列,每一行為該行最大值拳昌,第二個(gè)tensor給出該最大值所在的列數(shù)
等同于a.max(1)
例:在訓(xùn)練網(wǎng)絡(luò)時(shí)
output = net(img)
_, predicted = output.max(1)
output為對img的預(yù)測輸出袭异,batch行l(wèi)abel列,每行是一個(gè)圖片的輸出炬藤,每次輸出batch組御铃。所以預(yù)測結(jié)果需要看每行的最大值碴里,找每行最大值的位置。output.max(1)
找到每行最大值上真,有兩個(gè)tensor輸出咬腋,第一個(gè)為最大值,第二個(gè)為最大值所在位置睡互,所關(guān)注的是位置根竿,所以第一個(gè)下劃線_
舍棄掉最大值。
item()
把tensor轉(zhuǎn)換成數(shù)torch.nn.Sequential語法
nn.Sequential(a, b, c)
括號就珠,逗號torchvision.transforms.Composed語法
transforms.Composed([a, b, c])
括號寇壳,方括號,逗號