1 Creating Tensors
# create a tensor filled with the value 6
torch.full([2,3], 6)
torch.full([], 6) # creates a scalar (0-dim tensor)
torch.full([1], 6) # creates a 1-element vector, not a scalar
# arange creates an arithmetic sequence (fixed step)
torch.arange(0,10)
torch.arange(0,10,2)
# linspace/logspace split a range into evenly spaced points
torch.linspace(0, 10, steps=4) # evenly spaced points including both endpoints, while arange steps by a fixed increment
torch.logspace(0, 10, steps=4) # 10 raised to evenly spaced exponents from 0 to 10
# ones/zeros/eye
torch.ones(2,3)
torch.zeros(2,3)
torch.eye(3) # identity matrix
torch.ones_like(a) # ones with the same shape as tensor a
# randperm returns a random permutation of 0..n-1, like random.shuffle
print(torch.randperm(10)) # tensor([7, 1, 8, 2, 9, 0, 5, 4, 3, 6])
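randperm pairs naturally with indexing to shuffle several tensors in unison; a minimal sketch (variable names are illustrative):
data = torch.rand(4, 3) # 4 samples, 3 features
labels = torch.arange(4) # matching labels
idx = torch.randperm(4) # one random permutation of the row indices
print(data[idx]) # rows in shuffled order
print(labels[idx]) # labels shuffled the same way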
2 Tensor Indexing and Slicing
2.1 Indexing
a = torch.rand(4, 3, 28, 28)
print(a[0].shape) # torch.Size([3, 28, 28])
print(a[0, 0].shape) # torch.Size([28, 28])
print(a[0, 0, 0].shape) # torch.Size([28])
print(a[0, 0, 0, 1].shape) # torch.Size([]), a scalar
2.2 select first/last N
a = torch.rand(4, 3, 28, 28)
print(a[:2].shape) # torch.Size([2, 3, 28, 28]) selects the first two images
print(a[:2, :1].shape) # torch.Size([2, 1, 28, 28]) the first channel of the first two images
print(a[:2, 1:,:,:].shape) # torch.Size([2, 2, 28, 28]) every channel except the first, for the first two images
print(a[:2, -1:,:,:].shape) # torch.Size([2, 1, 28, 28]) the last channel of the first two images
2.3 select by steps
a = torch.rand(4, 3, 28, 28)
print(a[:,:,0:28:2,0:28:2].shape) # torch.Size([4, 3, 14, 14]), sample every other row and column
print(a[:,:,::2,::2].shape) # torch.Size([4, 3, 14, 14]), same as above
The meaning of the colon in PyTorch slicing (a sketch follows the list):
1. A lone colon selects everything;
2. A colon before a number, :N, selects indices 0 to N-1;
3. A colon after a number, N:, selects indices N to len(sequence)-1;
4. start:end selects indices start to end-1;
5. start:end:step takes every step-th element from start to end-1; with both bounds omitted it shortens to ::step.
2.4 select by specific index
a = torch.rand(4, 3, 28, 28)
print(a.index_select(0, torch.tensor([0, 2])).shape)
# torch.Size([2, 3, 28, 28]) selects images 0 and 2
print(a.index_select(2, torch.arange(24)).shape) # torch.Size([4, 3, 24, 28]) selects the first 24 rows along dim 2
print(a[...].shape) # torch.Size([4, 3, 28, 28])
print(a[0, ...].shape) # torch.Size([3, 28, 28])
print(a[:,1, ...].shape) # torch.Size([4, 28, 28])
print(a[0,..., ::2].shape) # torch.Size([3, 28, 14])
print(a[..., :2].shape) # torch.Size([4, 3, 28, 2])
"..." stands for any number of dimensions, inferred automatically from the tensor's shape
2.5 select by mask 【masked_select】
a = torch.randn(3, 4)
print(a)
tensor([[ 0.1641, 1.2368, 0.7215, -0.5228],
[ 0.0288, 0.6919, -1.6339, -0.1283],
[-0.0908, -0.1472, -0.2184, -0.6402]])
mask = a.ge(0.5)
print(mask)
tensor([[0, 1, 1, 0],
[0, 1, 0, 0],
[0, 0, 0, 0]], dtype=torch.uint8)
print(torch.masked_select(a, mask))
tensor([1.2368, 0.7215, 0.6919])
Note that masked_select returns a flattened 1-D tensor; on newer PyTorch versions the mask has dtype=torch.bool instead of uint8.
2.6 select by flattened index 【flatten the tensor first, then select by index】
a = torch.tensor([[4, 3, 5], [6, 7, 8]])
print(torch.take(a, torch.tensor([0, 2, 5]))) # tensor([4, 5, 8])
3 Tensor維度變換
3.1 view/reshape 【view and reshape are nearly equivalent】
a = torch.rand(4, 1, 28, 28)
print(a.shape) # torch.Size([4, 1, 28, 28])
print(a.view(4, 28*28).shape) # torch.Size([4, 784])
print(a.view(4*1, 28, 28).shape) # torch.Size([4, 28, 28])
print(a.view(4*1*28, 28).shape) # torch.Size([112, 28])
view changes the tensor's shape without copying data; reshape does the same but silently copies when the tensor is not contiguous, as the sketch below shows.
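A minimal sketch of the difference:
a = torch.rand(4, 4)
t = a.t() # transpose produces a non-contiguous view
# t.view(16) # would raise: view needs contiguous memory
print(t.reshape(16).shape) # torch.Size([16]); reshape copies when necessary
print(t.contiguous().view(16).shape) # torch.Size([16]); the equivalent manual fix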
3.2 squeeze/unsqueeze 【remove or insert a dimension at a given position】
Correspondence between positive and negative indices:
[0, 1, 2, 3, 4] => [-5, -4, -3, -2, -1]
# adding a dimension
a = torch.rand(4, 1, 28, 28)
print(a.unsqueeze(0).shape) # torch.Size([1, 4, 1, 28, 28]) inserts a dim at position 0
print(a.unsqueeze(-1).shape) # torch.Size([4, 1, 28, 28, 1]) inserts a dim at the end
print(a.unsqueeze(4).shape) # torch.Size([4, 1, 28, 28, 1]) same as -1 here
print(a.unsqueeze(-4).shape) # torch.Size([4, 1, 1, 28, 28]) inserts a dim at position -4
print(a.unsqueeze(-5).shape) # torch.Size([1, 4, 1, 28, 28]) inserts a dim at position -5
b = torch.tensor([1.2, 2.3])
print(b) # tensor([1.2000, 2.3000])
print(b.unsqueeze(-1)) # tensor([[1.2000],[2.3000]]) inserts a dim at the innermost position
print(b.unsqueeze(0)) # tensor([[1.2000, 2.3000]]) inserts a dim at the outermost position
c = torch.rand(32)
d = c.unsqueeze(1).unsqueeze(2).unsqueeze(0)
print(c.shape) # torch.Size([32])
print(c.unsqueeze(1).shape) # torch.Size([32, 1])
print(c.unsqueeze(1).unsqueeze(2).shape) # torch.Size([32, 1, 1])
print(c.unsqueeze(1).unsqueeze(2).unsqueeze(0).shape) # torch.Size([1, 32, 1, 1])
print(d.shape) # torch.Size([1, 32, 1, 1])
# Removing dimensions. When no dim is given, every size-1 dim is squeezed out; only size-1 dims can be squeezed.
print(torch.rand(4, 1, 28, 28).squeeze().shape) # all size-1 dims removed, torch.Size([4, 28, 28])
print(torch.rand(1, 32, 1, 1).squeeze().shape) # all size-1 dims removed, torch.Size([32])
print(torch.rand(1, 32, 1, 1).squeeze(0).shape) # only the first dim removed, torch.Size([32, 1, 1])
print(torch.rand(1, 32, 1, 1).squeeze(-1).shape) # only the last dim removed, torch.Size([1, 32, 1])
print(torch.rand(4, 1, 28, 28).squeeze(0).shape) # squeezing a dim whose size is not 1 has no effect, torch.Size([4, 1, 28, 28])
3.3 expand(broadcasting)/repeat(memory copied) 【dimension expansion】
# expand only works on size-1 dims; -1 means leave that dim unchanged
a = torch.rand(4, 32, 14, 14) # the target shape that b will be expanded to match
b = torch.rand(1, 32, 1, 1)
print(b.expand(4, 32, 14, 14).shape)
print(torch.rand(2, 32, 1, 1).expand(-1, 32, 14, 14).shape) # -1 leaves that dim unchanged, torch.Size([2, 32, 14, 14])
# print(torch.rand(2, 32, 1, 1).expand(4, 32, 14, 14).shape)
# RuntimeError: The expanded size of the tensor (4) must match the existing size (2)
# repeat also expands dims. repeat's arguments give the number of copies per dim, while expand's arguments give the target sizes (repeat is not recommended)
b = torch.rand(1, 32, 1, 1)
print(b.repeat(4, 32, 1, 1).shape) # torch.Size([4, 1024, 1, 1])
print(b.repeat(4, 1, 1, 1).shape) # torch.Size([4, 32, 1, 1])
print(b.repeat(4, 1, 14, 14).shape) # torch.Size([4, 32, 14, 14])
- expand does not copy data eagerly and is the recommended choice; repeat copies the data up front (see the sketch below);
- repeat's arguments give the number of copies per dim, while expand's arguments give the target sizes;
- expand can only expand size-1 dims; -1 leaves the corresponding dim unchanged.
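A minimal sketch showing that expand shares storage with the source while repeat owns a copy:
b = torch.rand(1, 3)
e = b.expand(4, 3) # a view: stride 0 along dim 0, no data copied
r = b.repeat(4, 1) # a real copy holding 4x the data
b[0, 0] = 42.0
print(e[3, 0]) # tensor(42.) -- expand sees the change (shared storage)
print(r[3, 0]) # the old value -- repeat copied before the change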
3.4 transpose/t/permute 【matrix transpose】
a = torch.rand(4, 3, 32, 32)
a1 = a.transpose(1, 3).contiguous().view(4, 3*32*32).view(4, 3, 32, 32) # shape is right, but the data is scrambled
a2 = a.transpose(1, 3).contiguous().view(4, 3*32*32).view(4, 32, 32, 3).transpose(1, 3) # correct
print(a1.shape)
print(a2.shape)
print(torch.all(torch.eq(a, a1))) # tensor(0, dtype=torch.uint8)
print(torch.all(torch.eq(a, a2))) # tensor(1, dtype=torch.uint8)
# permute specifies the new order of all dimensions directly
a = torch.rand(4, 3, 32, 28)
print(a.permute(0, 2, 3, 1).shape) # torch.Size([4, 32, 28, 3])
- b.t() works only on 2-D tensors (a quick sketch follows this list);
- transpose() swaps two specified dims; call contiguous() before view();
- view() discards dimension information, so dims must be restored in their stored order or the data gets scrambled;
- permute() rearranges the dims into any specified order in one call.
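A quick sketch of the first point:
m = torch.rand(2, 3)
print(m.t().shape) # torch.Size([3, 2])
# torch.rand(2, 3, 4).t() # raises: t() expects a tensor with at most 2 dimensions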
4 auto-broadcasting
# broadcasting = unsqueeze + expand
# insert a size-1 dim in front (expansion starts from the leading dims)
# then expand the size-1 dims to the matching size
# feature maps: [4, 32, 14, 14]
# bias: [32, 1, 1] => [1, 32, 1, 1] => [4, 32, 14, 14]
Is it broadcastable? (match from the last dim!)
- If the current dim has size 1, expand it to the matching size;
- if either tensor is missing a dim, insert a size-1 dim, then expand;
- otherwise, the tensors are not broadcastable.
When a tensor lacks a dim, treat its value as shared across that dim:
[class, student, scores] + [scores]
When a dim has size 1, treat its value as shared by every entry of that dim:
[class, student, scores] + [student, 1]
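A minimal sketch of both cases, reusing the [class, student, scores] shapes (the added tensors are illustrative):
scores = torch.rand(4, 32, 8) # [class, student, scores]
bonus = torch.rand(8) # [scores]: missing leading dims are inserted as size 1
print((scores + bonus).shape) # torch.Size([4, 32, 8])
curve = torch.rand(32, 1) # [student, 1]: the size-1 dim is shared across all scores
print((scores + curve).shape) # torch.Size([4, 32, 8])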
5 Tensor Concatenation and Splitting
cat/stack
a = torch.rand(4, 32, 8)
b = torch.rand(5, 32, 8)
c = torch.rand(4, 32, 8)
print(torch.cat([a, b], dim=0).shape) # torch.Size([9, 32, 8])
print(torch.cat([a, c], dim=1).shape) # torch.Size([4, 64, 8])
a = torch.rand(4, 3, 16, 16)
b = torch.rand(4, 3, 16, 16)
print(torch.stack([a, b], dim=2).shape) # torch.Size([4, 3, 2, 16, 16])
cat: only the dim being concatenated may differ in size; every other dim must match. After concatenation the non-concatenated dims keep their sizes, and the concatenated dim's size is the sum of the inputs' sizes along it.
stack inserts a new dim at the given position; the input tensors' shapes must be identical.
split/chunk
a = torch.rand(3, 32, 8)
a1, a2 = a.split([2, 1], dim=0) # split into pieces of length 2 and 1
print(a1.shape, a2.shape) # torch.Size([2, 32, 8]) torch.Size([1, 32, 8])
a = torch.rand(7, 32, 8)
a1, a2 = a.chunk(2, dim=0) # split into 2 chunks
print(a1.shape, a2.shape) # torch.Size([4, 32, 8]) torch.Size([3, 32, 8])
split: split by length
chunk: split by number of chunks
6 Basic Tensor Operations
Element-wise addition, subtraction, multiplication, and division
a, b = torch.rand(3, 4), torch.rand(4)
print(torch.all(torch.eq(a+b, torch.add(a, b)))) # + and torch.add are equivalent, tensor(1, dtype=torch.uint8)
print(torch.all(torch.eq(a-b, torch.sub(a, b)))) # - and torch.sub are equivalent, tensor(1, dtype=torch.uint8)
print(torch.all(torch.eq(a*b, torch.mul(a, b)))) # * and torch.mul are equivalent, tensor(1, dtype=torch.uint8)
print(torch.all(torch.eq(a/b, torch.div(a, b)))) # / and torch.div are equivalent, tensor(1, dtype=torch.uint8)
Matrix multiplication vs. element-wise multiplication
* performs element-wise multiplication
.matmul performs matrix multiplication
1st: torch.mm, 2-D tensors only
2nd: torch.matmul, any number of dims; @ is the overloaded operator for matmul 【recommended】
a = torch.full([2, 2], 3.) # float fill, so the printed results below match
b = torch.full([2, 2], 4.)
print(a*b) # element-wise multiply, tensor([[12., 12.], [12., 12.]])
print(a.mm(b)) # matrix multiply, tensor([[24., 24.], [24., 24.]])
print(a@b) # matrix multiply, tensor([[24., 24.], [24., 24.]])
print(a.matmul(b)) # matrix multiply, tensor([[24., 24.], [24., 24.]])
# matmul multiplies over the last two dims only;
# the remaining (batch) dims must match or satisfy the broadcasting rules
a = torch.rand(4, 3, 28, 64)
b = torch.rand(4, 3, 64, 32)
c = torch.rand(4, 1, 64, 32)
d = torch.rand(4, 2, 64, 32)
print((a@b).shape) # torch.Size([4, 3, 28, 32])
print((a@c).shape) # torch.Size([4, 3, 28, 32]), broadcasting
# print((a@d).shape) # error: batch dims of size 3 and 2 cannot broadcast
Exponentiation
# pow or ** raises to a power
a = torch.full([2, 2], 3.) # float fill, so the printed results below match
print(a.pow(2)) # tensor([[9., 9.], [9., 9.]])
print(a**2) # tensor([[9., 9.], [9., 9.]])
# sqrt: square root; rsqrt: reciprocal of the square root
b = a.pow(2)
print(b.sqrt()) # tensor([[3., 3.], [3., 3.]])
print(b.rsqrt()) # tensor([[0.3333, 0.3333], [0.3333, 0.3333]])
# exp log
a = torch.exp(torch.ones(2, 2))
b = torch.ones(2, 2)*100
c = torch.ones(2, 2)*8
print(a) # tensor([[2.7183, 2.7183], [2.7183, 2.7183]])
print(torch.log(a)) # tensor([[1., 1.], [1., 1.]])
print(torch.log10(b)) # tensor([[2., 2.], [2., 2.]])
print(torch.log2(c)) # tensor([[3., 3.], [3., 3.]])
Rounding and clamping 【floor/ceil/round/trunc/frac/clamp】
a = torch.tensor(3.14)
print(torch.floor(a)) # round down, tensor(3.)
print(torch.ceil(a)) # round up, tensor(4.)
print(torch.round(a)) # round to nearest, tensor(3.)
print(torch.trunc(a)) # integer part, tensor(3.)
print(torch.frac(a)) # fractional part, tensor(0.1400)
# clamp: clip values into a range, as in gradient clipping (a sketch follows these examples)
grad = torch.rand(2, 3)*15 # random gradients, overwritten next with fixed values for a reproducible printout
grad = torch.tensor([[10.8008, 11.9414, 3.9532], [7.5537, 13.9067, 4.6728]])
print(grad) # tensor([[10.8008, 11.9414, 3.9532], [7.5537, 13.9067, 4.6728]])
print(grad.max()) # tensor(13.9067)
print(grad.median()) # tensor(7.5537)
print(grad.clamp(10)) # clip the minimum to 10: tensor([[10.8008, 11.9414, 10.0000], [10.0000, 13.9067, 10.0000]])
print(grad.clamp(0, 10))# clip to [0, 10]: tensor([[10.0000, 10.0000, 3.9532], [ 7.5537, 10.0000, 4.6728]])
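A minimal sketch of clamp used for gradient clipping (the threshold 10 is arbitrary):
w = torch.rand(3, requires_grad=True)
loss = (w * 100).sum()
loss.backward()
print(w.grad) # every entry is 100
w.grad.clamp_(-10, 10) # in-place clip into [-10, 10]
print(w.grad) # tensor([10., 10., 10.])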
7 Tensor Statistics
norm
# norm here means the p-norm (L1 = sum of absolute values, L2 = square root of the sum of squares), not normalize
a = torch.full([8], 1.) # float fill, since norm requires a floating-point tensor
b = a.view(2, 4)
c = a.view(2, 2, 2)
print(a.norm(1), b.norm(1), c.norm(1)) # tensor(8.) tensor(8.) tensor(8.)
print(a.norm(2), b.norm(2), c.norm(2)) # tensor(2.8284) tensor(2.8284) tensor(2.8284)
print(b.norm(1, dim=1)) # tensor([4., 4.])
print(c.norm(1, dim=1)) # tensor([[2., 2.], [2., 2.]])
mean/sum/min/max/prod
a = torch.arange(8).view(2, 4).float()
print(a) # tensor([[0., 1., 2., 3.], [4., 5., 6., 7.]])
print(a.min(), a.max(), a.mean(), a.prod()) # tensor(0.) tensor(7.) tensor(3.5000) tensor(0.)
print(a.sum(), a.argmax(), a.argmin()) # tensor(28.) tensor(7) tensor(0)
print(a.min(dim=0)) # (tensor([0., 1., 2., 3.]), tensor([0, 0, 0, 0]))
print(a.argmin(dim=0)) # tensor([0, 0, 0, 0])
# dim, keepdim
a = torch.rand(4, 10)
print(a)
print(a.argmin(dim=1)) # tensor([5, 5, 0, 5])
print(a.argmin(dim=1, keepdim=True)) # tensor([[5],[5],[0],[5]])
print(a.argmin(dim=1).unsqueeze(-1)) # tensor([[5],[5],[0],[5]]) 和上式等價(jià)
- a.min(dim=...) returns a tuple: the minimum values first, then the corresponding indices (the latter is exactly what a.argmin() returns).
- Understanding the shape returned by a.argmax(dim=1): dim=1 means the reduction runs along dim 1; that dim shrinks to size 1 and is squeezed away, so the result has the size of dim 0. Set keepdim=True if the reduced dim should be kept.
top-k, k-th
a = torch.rand(4, 10)
print(a.topk(3, dim=1)[0].shape) # torch.Size([4, 3])
print(a.kthvalue(8, dim=1)[0].shape) # torch.Size([4]); kthvalue returns the k-th smallest value along dim
compare
a = torch.rand(4, 10)
print(torch.eq(a, a))
# tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
# [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
# [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
# [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.uint8)
print(torch.equal(a, a)) # True; returns a single True or False
torch.eq() and torch.equal() return different things. torch.eq() returns a tensor with 1 where the elements are equal and 0 where they differ.
torch.equal() returns a single True or False for the whole tensor.
Advanced Tensor operations: where | gather
# torch.where(condition, x, y) -> Tensor
# Return a tensor of elements selected from either x or y, depending on condition.
# out[i] = x[i] if condition[i] else y[i]
# x and y must have the same shape (newer versions also accept broadcastable shapes)
cond = torch.rand(2, 2)
print(cond) # tensor([[0.1623, 0.4277], [0.6705, 0.4220]])
a = torch.zeros(2, 2)
b = torch.ones(2, 2)
print(torch.where(cond > 0.5, a, b)) # tensor([[1., 1.], [0., 1.]])
# gather
# torch.gather(input, dim, index, out=None) -> Tensor
# Gathers values along an axis specified by dim
# Gather is essentially a table lookup: input is the table, and index gives the positions to read along dim
prob = torch.randn(4, 10)
idx = prob.topk(dim=1, k=3)
idx = idx[1] # topk returns (values, indices); keep the indices
print(idx)
# tensor([[8, 0, 4],
# [4, 8, 3],
# [6, 0, 9],
# [8, 7, 9]])
label = torch.arange(10) + 100
print(torch.gather(label.expand(4, 10), dim=1, index=idx.long()))
# tensor([[108, 100, 104],
# [104, 108, 103],
# [106, 100, 109],
# [108, 107, 109]])
Miscellaneous
- Equations that can be solved analytically, such as a system of two linear equations, have a closed-form solution. Most real-world problems have no closed form and can only be approximated numerically.
- PyTorch has no string tensor type; string/categorical data is represented with one-hot encoding or an Embedding (a sketch follows).
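A minimal sketch of both representations (the vocabulary is made up for illustration):
import torch.nn.functional as F
vocab = {"cat": 0, "dog": 1, "bird": 2} # hypothetical vocabulary
idx = torch.tensor([vocab["dog"], vocab["cat"]])
print(F.one_hot(idx, num_classes=3)) # tensor([[0, 1, 0], [1, 0, 0]])
emb = torch.nn.Embedding(num_embeddings=3, embedding_dim=4)
print(emb(idx).shape) # torch.Size([2, 4]); each index maps to a learned dense vector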