torch.cat() concatenates a sequence of tensors along an existing dimension (dim=0 by default):
import torch
T = torch.tensor([[1]])
print("[[1]]:", torch.cat([T, T, T]))
[[1]]: tensor([[1],
        [1],
        [1]])
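For a 2-D tensor, the default dim=0 stacks the copies as new rows. A minimal sketch of the same call with dim=1, which joins the copies side by side as columns instead:

```python
import torch

T = torch.tensor([[1]])  # shape (1, 1)

# dim=0 (the default) joins along rows: result has shape (3, 1)
rows = torch.cat([T, T, T])
# dim=1 joins along columns instead: result has shape (1, 3)
cols = torch.cat([T, T, T], dim=1)

print(rows.shape)  # torch.Size([3, 1])
print(cols)        # tensor([[1, 1, 1]])
```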
T = torch.tensor([1])
print("[1]:", torch.cat([T, T, T]))
[1]: tensor([1, 1, 1])
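torch.cat only requires the tensors to agree in every dimension except the one being concatenated. A small sketch with arbitrarily chosen shapes (illustrative only) showing one case that works and one that raises a RuntimeError:

```python
import torch

a = torch.zeros(2, 3)
b = torch.ones(4, 3)

# Sizes differ only in dim 0, so concatenating along dim 0 works: (2 + 4, 3)
joined = torch.cat([a, b], dim=0)
print(joined.shape)  # torch.Size([6, 3])

# Here dim 1 differs (3 vs 4) while concatenating along dim 0, so torch.cat raises
c = torch.zeros(2, 4)
raised = False
try:
    torch.cat([a, c], dim=0)
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)
```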
This does not work, because every element of the list must be a tensor, not a Python int:
torch.cat([1, 1, 1])
TypeError: expected Tensor as element 0 in argument 0, but got int
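A few ways around the error, sketched under the assumption that the goal is the 1-D tensor [1, 1, 1]. Note that torch.tensor(v) on a plain int gives a 0-dimensional tensor, which torch.cat does not accept; torch.stack adds a new dimension and therefore does:

```python
import torch

values = [1, 1, 1]

# Option 1: build the 1-D tensor directly from the Python list
direct = torch.tensor(values)

# Option 2: wrap each int as a 1-element 1-D tensor, then concatenate
wrapped = torch.cat([torch.tensor([v]) for v in values])

# Option 3: torch.tensor(v) is 0-dimensional; torch.stack creates a new
# dimension, so it can combine 0-d tensors into a 1-D tensor
stacked = torch.stack([torch.tensor(v) for v in values])

print(direct)   # tensor([1, 1, 1])
print(wrapped)  # tensor([1, 1, 1])
print(stacked)  # tensor([1, 1, 1])
```

For values that are already plain Python numbers, option 1 is the simplest; options 2 and 3 are more useful when the pieces are tensors produced elsewhere.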