參考自:csdn
torch.Tensor的4種乘法
torch.Tensor有4種常見的乘法:*, torch.mul, torch.mm, torch.matmul. 本文拋磚引玉蚜厉,簡單敘述一下這4種乘法的區(qū)別弟翘,具體使用還是要參照官方文檔耙饰。
點乘
a與b做*乘法,原則是如果a與b的size不同玉组,則以某種方式將a或b進行復(fù)制耍目,使得復(fù)制后的a和b的size相同房交,然后再將a和b做element-wise的乘法帜篇。
下面以標(biāo)量和一維向量為例展示上述過程糙捺。
- 標(biāo)量
Tensor與標(biāo)量k做*乘法的結(jié)果是Tensor的每個元素乘以k(相當(dāng)于把k復(fù)制成與lhs大小相同,元素全為k的Tensor).
>>> a = torch.ones(3,4)
>>> a
tensor([[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]])
>>> a * 2
tensor([[2., 2., 2., 2.],
[2., 2., 2., 2.],
[2., 2., 2., 2.]])
- 一維向量
Tensor與行向量做*乘法的結(jié)果是每列乘以行向量對應(yīng)列的值(相當(dāng)于把行向量的行復(fù)制笙隙,成為與lhs維度相同的Tensor).
>>> a = torch.ones(3,4)
>>> a
tensor([[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]])
>>> b = torch.Tensor([1,2,3,4])
>>> b
tensor([1., 2., 3., 4.])
>>> a * b
tensor([[1., 2., 3., 4.],
[1., 2., 3., 4.],
[1., 2., 3., 4.]])
Tensor與列向量做*乘法的結(jié)果是每行乘以列向量對應(yīng)行的值(相當(dāng)于把列向量的列復(fù)制洪灯,成為與lhs維度相同的Tensor).
>>> a = torch.ones(3,4)
>>> a
tensor([[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]])
>>> b = torch.Tensor([1,2,3]).reshape((3,1))
>>> b
tensor([[1.],
[2.],
[3.]])
>>> a * b
tensor([[1., 1., 1., 1.],
[2., 2., 2., 2.],
[3., 3., 3., 3.]])
torch.mul
官方文檔關(guān)于torch.mul的介紹. 用法與*乘法相同。
兩者都是broadcast的竟痰。broadcast是torch的一個概念签钩,個人理解是為了便利高維(3維以上)矩陣運算。broadcast的概念稍顯復(fù)雜坏快,在此不做展開铅檩,可以參考官方文檔關(guān)于broadcast的介紹. 在torch.matmul里會有關(guān)于broadcast的應(yīng)用的一個簡單的例子。
下面是3個torch.mul的例子.
乘標(biāo)量
>>> a = torch.ones(3,4)
>>> a
tensor([[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]])
>>> a * 2
tensor([[2., 2., 2., 2.],
[2., 2., 2., 2.],
[2., 2., 2., 2.]])
乘行向量
>>> a = torch.ones(3,4)
>>> a
tensor([[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]])
>>> b = torch.Tensor([1,2,3,4])
>>> b
tensor([1., 2., 3., 4.])
>>> torch.mul(a, b)
tensor([[1., 2., 3., 4.],
[1., 2., 3., 4.],
[1., 2., 3., 4.]])
乘列向量
>>> a = torch.ones(3,4)
>>> a
tensor([[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]])
>>> b = torch.Tensor([1,2,3]).reshape((3,1))
>>> b
tensor([[1.],
[2.],
[3.]])
>>> torch.mul(a, b)
tensor([[1., 1., 1., 1.],
[2., 2., 2., 2.],
[3., 3., 3., 3.]])
torch.mm
官方文檔關(guān)于torch.mm的介紹. 數(shù)學(xué)里的矩陣乘法莽鸿,要求兩個Tensor的維度滿足矩陣乘法的要求.
例子:
>>> a = torch.ones(3,4)
>>> b = torch.ones(4,2)
>>> torch.mm(a, b)
tensor([[4., 4.],
[4., 4.],
[4., 4.]])
torch.matmul
官方文檔關(guān)于torch.matmul的介紹. torch.mm的broadcast版本.
例子:
>>> a = torch.ones(3,4)
>>> b = torch.ones(5,4,2)
>>> torch.matmul(a, b)
tensor([[[4., 4.],
[4., 4.],
[4., 4.]],
[[4., 4.],
[4., 4.],
[4., 4.]],
[[4., 4.],
[4., 4.],
[4., 4.]],
[[4., 4.],
[4., 4.],
[4., 4.]],
[[4., 4.],
[4., 4.],
[4., 4.]]])
同樣的a和b昧旨,使用torch.mm相乘會報錯
>>> torch.mm(a, b)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: matrices expected, got 2D, 3D tensors at /pytorch/aten/src/TH/generic/THTensorMath.cpp:2065