torch.split(tensor, split_size, dim=)
tensor是要切割的張量忆矛,dim表示在哪個(gè)維度上面進(jìn)行切割
注意:split_size是切分后每塊的大小,不是切分為多少塊!
a = torch.LongTensor([[1,2,3,4],[2,3,4,5]])
b = torch.cat(torch.split(a, 4, dim=1), dim=0)
print(b)
輸出:tensor([[1, 2, 3, 4],
? ? ? ? [2, 3, 4, 5]])