1、torch.cat(inputs, dim=0) -> Tensor?
參考鏈接:
Pytorch學習筆記(一):torch.cat()模塊的詳解
函數(shù)作用:cat 是 concatnate 的意思:拼接徘熔,聯(lián)系在一起募寨。在給定維度上對輸入的 Tensor 序列進行拼接操作寄月。torch.cat 可以看作是 torch.split 和 torch.chunk 的反操作
參數(shù):
inputs(sequence of Tensors):可以是任意相同類型的 Tensor 的 python 序列
dim(int, optional):defaults=0
dim=0: 按列進行拼接?
dim=1: 按行進行拼接
dim=-1: 如果行和列數(shù)都相同則按行進行拼接酵紫,否則按照行數(shù)或列數(shù)相等的維度進行拼接
假設 a 和 b 都是 Tensor,且 a 的維度為 [2, 3]河劝,b 的維度為 [2, 4]惠啄,則
torch.cat((a, b), dim=1) 的維度為 [2, 7]
2慎恒、torch.nn.CrossEntropyLoss()
函數(shù)作用:CrossEntropy 是交叉熵的意思,故而 CrossEntropyLoss 的作用是計算交叉熵撵渡。CrossEntropyLoss 函數(shù)是將 torch.nn.Softmax 和 torch.nn.NLLLoss 兩個函數(shù)組合在一起使用融柬,故而傳入的預測值不需要先進行 torch.nnSoftmax 操作。
參數(shù):
input(N, C):N 是 batch_size趋距,C 則是類別數(shù)粒氧,即在定義模型輸出時,輸出節(jié)點個數(shù)要定義為 [N, C]节腐。其中特別注意的是 target 的數(shù)據(jù)類型需要是浮點數(shù)外盯,即 float32
target(N):N 是 batch_size,故 target 需要是 1D 張量翼雀。其中特別注意的是 target 的數(shù)據(jù)類型需要是 long饱苟,即 int64
例子:
loss = nn.CrossEntropyLoss()
input = torch.randn(3, 5, requires_grad=True, dtype=torch.float32)
target = torch.empty(3, dtype=torch.long).random_(5)
output = loss(input, target)
output
輸出為:
tensor(1.6916, grad_fn=<NllLossBackward>)