Let's review the common tensor manipulation functions: splitting (Split), concatenation (Cat), Stack, and Chunk.
Concatenation (Cat)
Similar to TensorFlow's tf.concat.
torch.cat([a, b], dim) concatenates tensors a and b along dimension dim. All the other dimensions must match exactly; if they don't, an error is raised.
batch_1 = torch.rand(2,3,28,28)
batch_2 = torch.rand(5,3,28,28)
torch.cat([batch_1,batch_2],dim=0).shape
#torch.Size([7, 3, 28, 28])
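A small sketch (not part of the original) showing the same idea along a different dimension, and what happens when the non-concatenated dimensions disagree:

```python
import torch

# Concatenate along dim=1: every other dimension must match.
a = torch.rand(4, 3, 28, 28)
b = torch.rand(4, 5, 28, 28)
merged = torch.cat([a, b], dim=1)
print(merged.shape)  # torch.Size([4, 8, 28, 28])

# Mismatched non-cat dimensions raise a RuntimeError.
try:
    torch.cat([torch.rand(2, 3), torch.rand(2, 4)], dim=0)
except RuntimeError as e:
    print("error:", e)
```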
stack
Unlike concat, stack adds a new dimension that distinguishes the stacked tensors. The tensors being stacked must have exactly the same shape; with dim=2, the new dimension is inserted at position 2.
batch_1 = torch.rand(2,3,16,32)
batch_2 = torch.rand(2,3,16,32)
torch.stack([batch_1,batch_2],dim=2).shape
#torch.Size([2, 3, 2, 16, 32])
grp_1 = torch.rand(32,8)
grp_2 = torch.rand(32,8)
torch.stack([grp_1,grp_2],dim=0).shape
# torch.Size([2, 32, 8])
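To make the relation between the two functions concrete, here is a sketch (my addition, not from the original) showing that stacking is the same as unsqueezing each tensor and then concatenating:

```python
import torch

grp_1 = torch.rand(32, 8)
grp_2 = torch.rand(32, 8)

# stack(dim=0) == cat the tensors after giving each a new leading dimension
stacked = torch.stack([grp_1, grp_2], dim=0)
catted = torch.cat([grp_1.unsqueeze(0), grp_2.unsqueeze(0)], dim=0)
print(torch.equal(stacked, catted))  # True
print(stacked.shape)                 # torch.Size([2, 32, 8])
```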
split
split partitions a tensor along a dimension by length: pass a list of segment lengths, or a single length that is applied to every segment.
c = torch.rand(3,32,8)
grp_1,grp_2 = c.split([1,2],dim=0)
print(grp_1.shape)
print(grp_2.shape)
#torch.Size([1, 32, 8])
#torch.Size([2, 32, 8])
c = torch.rand(4,32,8)
grp_1,grp_2 = c.split(2,dim=0)
print(grp_1.shape)
print(grp_2.shape)
#torch.Size([2, 32, 8])
#torch.Size([2, 32, 8])
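One detail worth knowing (a sketch I've added, not from the original): when the dimension size is not a multiple of the integer length, the last segment is simply shorter:

```python
import torch

# split(2) on a size-5 dimension yields segments of 2, 2, and 1.
c = torch.rand(5, 32, 8)
parts = c.split(2, dim=0)
print([tuple(p.shape) for p in parts])
# [(2, 32, 8), (2, 32, 8), (1, 32, 8)]
```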
chunk
# chunk splits a tensor into a given number of chunks
c = torch.rand(3,32,8)
grp_1,grp_2,grp_3 = c.chunk(3,dim=0)
print(grp_1.shape)
print(grp_2.shape)
print(grp_3.shape)
'''
torch.Size([1, 32, 8])
torch.Size([1, 32, 8])
torch.Size([1, 32, 8])
'''
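A caveat worth illustrating with a short sketch (my addition): when the dimension size doesn't divide evenly, chunk sizes each chunk as ceil(n / chunks), so it may return fewer chunks than requested:

```python
import torch

# Size 4 split into 3 chunks: each chunk gets ceil(4/3) = 2 rows,
# so only 2 chunks actually come back.
c = torch.rand(4, 32, 8)
parts = c.chunk(3, dim=0)
print(len(parts))                    # 2
print([p.shape[0] for p in parts])   # [2, 2]
```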
view
Similar to reshape in numpy: it changes the tensor's shape without changing its data.
import torch
tt1=torch.tensor([-0.3623,-0.6115,0.7283,0.4699,2.3261,0.1599])
result=tt1.view(3,2)
Output:
tensor([[-0.3623, -0.6115],
[ 0.7283, 0.4699],
[ 2.3261, 0.1599]])
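Two practical points about view, shown in a sketch I've added: passing -1 lets one dimension be inferred, and view requires contiguous memory, so after a transpose you should use reshape (or .contiguous().view(...)) instead:

```python
import torch

tt1 = torch.tensor([-0.3623, -0.6115, 0.7283, 0.4699, 2.3261, 0.1599])

# -1 infers the remaining dimension from the element count (6 / 2 = 3).
print(tt1.view(2, -1).shape)  # torch.Size([2, 3])

# A transposed tensor is non-contiguous; view would fail, reshape works.
m = tt1.view(3, 2).t()
flat = m.reshape(6)
print(flat.shape)  # torch.Size([6])
```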
size
Tensor.size() returns the tensor's shape as a torch.Size; it is equivalent to the .shape attribute.
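A quick sketch (my addition) comparing the two, including the single-dimension form size(dim):

```python
import torch

t = torch.rand(2, 3)
print(t.size())   # torch.Size([2, 3])
print(t.shape)    # torch.Size([2, 3]) -- same object kind, same values
print(t.size(0))  # 2, the size along dimension 0
```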