torch.Tensor data types
torch.Tensor is a multi-dimensional matrix containing elements of a single data type.
Data type | CPU tensor | GPU tensor |
---|---|---|
32-bit floating point | torch.FloatTensor | torch.cuda.FloatTensor |
64-bit floating point | torch.DoubleTensor | torch.cuda.DoubleTensor |
16-bit floating point | torch.HalfTensor | torch.cuda.HalfTensor |
8-bit integer (unsigned) | torch.ByteTensor | torch.cuda.ByteTensor |
8-bit integer (signed) | torch.CharTensor | torch.cuda.CharTensor |
16-bit integer (signed) | torch.ShortTensor | torch.cuda.ShortTensor |
32-bit integer (signed) | torch.IntTensor | torch.cuda.IntTensor |
64-bit integer (signed) | torch.LongTensor | torch.cuda.LongTensor |
torch.Tensor is shorthand for the default tensor type (torch.FloatTensor).
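One way to check the table against a live install is to inspect `.dtype` and the legacy type string returned by `.type()`. This is a minimal sketch that assumes the default dtype has not been changed from float32. The example values are arbitrary.

```python
import torch

# torch.Tensor(...) builds an (uninitialized) tensor of the default type
t = torch.Tensor(2, 3)
print(t.type())                    # legacy type name of the default type

# torch.tensor infers the dtype from the data it is given
f = torch.tensor([1.2, 3])         # Python floats -> torch.float32
i = torch.tensor([1, 2])           # Python ints   -> torch.int64
u = torch.tensor([1, 2], dtype=torch.uint8)

for x in (f, i, u):
    print(x.dtype, x.type())       # dtype object and matching row of the table
```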
Operations that modify a tensor in place are marked with a trailing underscore. For example, torch.FloatTensor.abs_() computes the absolute value in place and returns the modified tensor, whereas torch.FloatTensor.abs() computes the result in a new tensor.
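The in-place versus out-of-place difference can be seen directly by comparing the tensor before and after each call. This is a small sketch; the input values are made up for illustration.

```python
import torch

x = torch.tensor([-1.5, 2.0, -3.0])
before = x.clone()                      # keep a copy of the original values

y = x.abs()                             # out-of-place: new tensor, x untouched
print(torch.equal(x, before))           # x still holds the negative values

z = x.abs_()                            # in-place: x itself is overwritten
print(x)                                # now all non-negative
print(z.data_ptr() == x.data_ptr())     # the returned tensor is x's own storage
```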
Creating tensors
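# uninitialized (contents are whatever happens to be in memory; only the shape is meaningful)
torch.empty(2, 3)
torch.FloatTensor(2, 3)
torch.IntTensor(2, 3, 4)
# from data; .type() returns the type name, e.g. 'torch.FloatTensor'
torch.tensor([1.2, 3]).type()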
# set the default tensor type (deprecated since PyTorch 2.1; torch.set_default_dtype(torch.float64) is the replacement)
torch.set_default_tensor_type(torch.DoubleTensor)
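The sketch below checks the uninitialized constructors and changes the default dtype with `torch.set_default_dtype`, the non-deprecated alternative. It saves and restores the previous default so later code is unaffected. The shapes are illustrative only.

```python
import torch

e = torch.empty(2, 3)              # uninitialized: only the shape is meaningful
it = torch.IntTensor(2, 3, 4)      # legacy constructor, also uninitialized
print(e.shape, it.shape, it.dtype)

old = torch.get_default_dtype()    # remember the current default (normally float32)
torch.set_default_dtype(torch.float64)
d = torch.tensor([1.2, 3])         # Python floats now become float64
torch.set_default_dtype(old)       # restore the previous default
print(d.dtype, torch.get_default_dtype())
```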
# random initialization
a = torch.rand(3,3) # uniform on [0, 1)
torch.rand_like(a)
torch.randint(1,10,[3,3]) # integers in [low, high), i.e. 1..9
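The half-open ranges noted in the comments can be checked empirically by drawing many samples. This is a sketch; the seed and sample count are arbitrary choices for illustration.

```python
import torch

torch.manual_seed(0)                     # fixed seed so the run is repeatable
u = torch.rand(1000)                     # uniform samples in [0, 1)
r = torch.randint(1, 10, (1000,))        # integers in [1, 10): 10 never appears
print(u.min().item(), u.max().item())
print(r.min().item(), r.max().item(), r.dtype)
```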
# normal distribution
torch.randn(3,3) # N(0,1)
torch.normal(mean=torch.full([10], 0.), std=torch.arange(1, 0, -0.1)) # per-element mean and std; 0. keeps mean a float tensor
torch.full([2,3], 7) # every element set to 7
torch.full([], 7) # scalar (0-dim tensor)
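With `torch.normal`, each output element k is drawn from N(mean[k], std[k]²). This sketch uses a float fill value for the mean. Recent PyTorch versions infer an integer dtype from an integer fill value, which `normal` does not accept.

```python
import torch

mean = torch.full([10], 0.)                  # float fill value -> float tensor
std = torch.arange(1, 0, -0.1)               # 1.0, 0.9, ..., 0.1 (10 values)
samples = torch.normal(mean=mean, std=std)   # element k ~ N(mean[k], std[k]**2)
print(samples.shape)

s = torch.full([], 7)                        # 0-dim (scalar) tensor
print(s.dim(), s.item())
```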
torch.arange(0,10)
# linspace/logspace
torch.linspace(0,10, steps=4)
torch.logspace(0, -1, steps=10)
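`linspace` includes both endpoints. `logspace(start, end)` spaces points evenly between 10**start and 10**end. A quick check:

```python
import torch

lin = torch.linspace(0, 10, steps=4)     # 4 evenly spaced points, both ends included
log = torch.logspace(0, -1, steps=10)    # 10 points from 10**0 down to 10**-1
print(lin)
print(log[0].item(), log[-1].item())
```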
# ones/zeros/eye/*_like
torch.ones(3,3)
torch.zeros(3,3)
torch.eye(3,4)
# randperm: random permutation of 0..n-1 (similar to random.shuffle on a range)
torch.randperm(10)
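This sketch checks a few of the helpers above. It confirms that a non-square `eye` only fills the main diagonal, that `*_like` copies the shape, and that `randperm` yields each index exactly once.

```python
import torch

e = torch.eye(3, 4)                # 3x4: ones on the main diagonal, zeros elsewhere
z = torch.zeros_like(e)            # same shape and dtype as e, filled with 0
p = torch.randperm(10)             # a random permutation of 0..9
print(e)
print(p)
print(torch.sort(p).values)        # sorting recovers 0..9
```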
Tensor slicing
Slicing works much like NumPy slicing, e.g. a[1:10, :], a[:10:2, :].
a = torch.randn(4,3,28,28)
a[:2]
a[:2, 1:, :,:].shape # output: [2,2,28,28]
# select by specific index
a.index_select(0, torch.tensor([0,2]))
a[...].shape # ... stands for any number of dimensions
a[..., :2] # ... covers the leading dims, so only the last dim is sliced -> [4,3,28,2]
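The resulting shapes make `index_select` and the ellipsis easier to reason about. This sketch reuses the [4, 3, 28, 28] batch shape from the snippet above.

```python
import torch

a = torch.randn(4, 3, 28, 28)

s1 = a.index_select(0, torch.tensor([0, 2]))  # samples 0 and 2 along dim 0
s2 = a[...]                                   # ... alone: the whole tensor
s3 = a[..., :2]                               # ... fills dims 0-2, slices only dim 3
s4 = a[0, ...]                                # index dim 0, keep the rest
for s in (s1, s2, s3, s4):
    print(s.shape)
```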
# select by mask
x = torch.randn(3,4)
mask = x.ge(0.5)
torch.masked_select(x, mask)
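`masked_select` always returns a 1-D tensor, taking the selected elements in row-major order. The small hand-written input below (values chosen for illustration) makes that visible. It also shows the equivalent boolean-indexing form.

```python
import torch

x = torch.tensor([[0.1, 0.7],
                  [0.9, -0.2]])
mask = x.ge(0.5)                       # boolean tensor, same shape as x
sel = torch.masked_select(x, mask)     # 1-D result, row-major order
print(mask)
print(sel, x[mask])                    # boolean indexing gives the same values
```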
# select by flatten index
src = torch.tensor([[4,3,5], [6,7,8]])
torch.take(src, torch.tensor([0,2]))
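`torch.take` treats its input as if it had been flattened in row-major order. This sketch compares it with explicit flatten-then-index on the same `src`. The index list is an illustrative choice.

```python
import torch

src = torch.tensor([[4, 3, 5],
                    [6, 7, 8]])
idx = torch.tensor([0, 2, 4])
picked = torch.take(src, idx)     # index into the flattened tensor [4, 3, 5, 6, 7, 8]
same = src.flatten()[idx]         # equivalent explicit form
print(picked, same)
```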
Tensor shape transformations
- view/reshape
- squeeze/unsqueeze
- transpose/t/permute
- expand/repeat
# view/reshape (the original dimension layout is lost; you must keep track of it yourself to restore it)
In [41]: a = torch.rand(4,1 ,28, 28)
In [42]: a.shape
Out[42]: torch.Size([4, 1, 28, 28])
In [43]: a.view(4, 28*28)
Out[43]:
tensor([[0.6006, 0.8933, 0.1474, ..., 0.5848, 0.9790, 0.6479],
[0.1824, 0.8874, 0.1635, ..., 0.3386, 0.3563, 0.0075],
[0.8867, 0.9460, 0.1208, ..., 0.1569, 0.2614, 0.7639],
[0.1437, 0.5749, 0.2275, ..., 0.5167, 0.6074, 0.5263]])
In [44]: a.view(4, 28*28).shape
Out[44]: torch.Size([4, 784])
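This is a sketch of two practical consequences. First, undoing a `view` requires knowing the original shape. Second, `view` refuses non-contiguous tensors, while `reshape` falls back to copying. The transpose here is only used to produce a non-contiguous tensor for illustration.

```python
import torch

a = torch.rand(4, 1, 28, 28)
flat = a.view(4, 28 * 28)              # [4, 784]
back = flat.view(4, 1, 28, 28)         # only possible because we know the old shape
print(torch.equal(a, back))

t = a.transpose(2, 3)                  # non-contiguous view of a
view_failed = False
try:
    t.view(4, 784)                     # view needs compatible strides
except RuntimeError:
    view_failed = True
r = t.reshape(4, 784)                  # reshape copies when it has to
print(view_failed, r.shape)
```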
# unsqueeze (insert a dimension of size 1)
In [50]: b = torch.rand(32)
In [51]: f = torch.rand(4,32, 14,14)
In [52]: b = b.unsqueeze(1).unsqueeze(2).unsqueeze(0)
In [53]: b.shape
Out[53]: torch.Size([1, 32, 1, 1])
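Reshaping `b` to [1, 32, 1, 1] is what lets it broadcast against `f` as one value per channel. This sketch shows that step, plus `squeeze` (from the list above) undoing it, either for all size-1 dims or for one chosen dim.

```python
import torch

b = torch.rand(32)
f = torch.rand(4, 32, 14, 14)
b4 = b.unsqueeze(1).unsqueeze(2).unsqueeze(0)  # [32] -> [32,1] -> [32,1,1] -> [1,32,1,1]
out = f + b4                                   # broadcasting: one value per channel
print(b4.shape, out.shape)
print(b4.squeeze().shape)                      # squeeze() drops every size-1 dim
print(b4.squeeze(0).shape)                     # squeeze(0) drops only dim 0
```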
# expand/repeat
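b.expand([4,32,14,14]) # [1,32,1,1] -> [4,32,14,14], no data copy (only size-1 dims can expand)
b.repeat(4,1,32,32) # copies data; each dim size is multiplied: [1,32,1,1] -> [4,32,32,32]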
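The difference between `expand` and `repeat` shows up in memory. `expand` returns a view whose broadcast dims have stride 0, while `repeat` materializes a new tensor. In this sketch the repeat counts are chosen so both results have the same shape.

```python
import torch

b = torch.rand(1, 32, 1, 1)
e = b.expand(4, 32, 14, 14)    # view: no copy, expanded dims get stride 0
r = b.repeat(4, 1, 14, 14)     # copy: sizes multiplied -> [4, 32, 14, 14]
print(e.shape, r.shape)
print(e.stride())              # (0, 1, 0, 0)
print(e.data_ptr() == b.data_ptr(), r.data_ptr() == b.data_ptr())
print(torch.equal(e, r))       # same values, different memory layout
```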
# a.t(): transpose, only for tensors with at most 2 dims
# transpose
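a.transpose(1,3) # swap the two specified dims
a.transpose(1,3).contiguous() # transpose returns a non-contiguous view; contiguous() copies it so view() works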
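This sketch runs a full round trip: transpose, make contiguous, flatten with `view`, then restore. The restore only works if you undo the steps in the right order. It also includes a 2-D `.t()` call for comparison. The distinct sizes are an illustrative choice so each dim is easy to track.

```python
import torch

a = torch.rand(4, 3, 28, 32)
t = a.transpose(1, 3)                       # [4, 32, 28, 3], shares storage with a
print(t.shape, t.is_contiguous())           # the view is not contiguous
flat = t.contiguous().view(4, -1)           # contiguous() copies, then view works
back = flat.view(4, 32, 28, 3).transpose(1, 3)
print(torch.equal(a, back))                 # restored by reversing the steps

m = torch.rand(2, 3)
print(m.t().shape)                          # .t() on a 2-D tensor -> [3, 2]
```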
# permute: reorder all dims at once
# [b c h w] -> [b h w c]
b.permute(0,2,3,1) # [b h w c]
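`permute` can be seen as a shortcut for a chain of pairwise `transpose` calls. This sketch checks that claim for the [b, c, h, w] -> [b, h, w, c] case. Distinct sizes are used so each dim is easy to track.

```python
import torch

x = torch.rand(2, 3, 5, 7)                 # [b, c, h, w]
p = x.permute(0, 2, 3, 1)                  # [b, h, w, c]
# the same reordering as two swaps: [b,c,h,w] -> [b,w,h,c] -> [b,h,w,c]
q = x.transpose(1, 3).transpose(1, 2)
print(p.shape, torch.equal(p, q))
```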