PyTorch has several kinds of multiplication operations; this post summarizes them all in one place.
Element-wise multiplication
This operation is also known as the "Hadamard product". Simply put, the elements of the tensors are multiplied one by one. It is written with *, the ordinary multiplication operator; torch.mul is equivalent.
import torch
def element_by_element():
    x = torch.tensor([1, 2, 3])
    y = torch.tensor([4, 5, 6])
    return x * y, torch.mul(x, y)
element_by_element()
(tensor([ 4, 10, 18]), tensor([ 4, 10, 18]))
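The same rule applies to higher-dimensional tensors of matching shape. A small sketch of my own (the function name is mine): x.mul(y) is the method form of torch.mul(x, y).
def element_by_element_matrix():
    x = torch.tensor([[1, 2], [3, 4]])
    y = torch.tensor([[5, 6], [7, 8]])
    # all three spellings compute the same Hadamard product:
    # tensor([[ 5, 12], [21, 32]])
    return x * y, torch.mul(x, y), x.mul(y)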
This operation supports broadcasting.
def element_by_element_broadcast():
    x = torch.tensor([1, 2, 3])
    y = 2
    return x * y
element_by_element_broadcast()
tensor([2, 4, 6])
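Broadcasting is not limited to scalars; it also works between tensors of different ranks. A sketch of my own: a (3,)-shaped vector is broadcast across each row of a (2, 3) matrix.
def element_by_element_broadcast_vector():
    x = torch.tensor([
        [1, 2, 3],
        [4, 5, 6]
    ])
    y = torch.tensor([10, 20, 30])  # shape (3,), broadcast over each row of x
    return x * y  # tensor([[ 10,  40,  90], [ 40, 100, 180]])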
Vector dot product
torch.matmul: If both tensors are 1-dimensional, the dot product (scalar) is returned.
If both are 1-dimensional, the returned value is the dot product.
def vec_dot_product():
    x = torch.tensor([1, 2, 3])
    y = torch.tensor([4, 5, 6])
    return torch.matmul(x, y)
vec_dot_product()
tensor(32)
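For 1-D tensors, torch.dot computes the same thing. Unlike matmul it accepts only 1-D inputs; a quick sketch:
def vec_dot_product_dot():
    x = torch.tensor([1, 2, 3])
    y = torch.tensor([4, 5, 6])
    # torch.dot works only on 1-D tensors
    return torch.dot(x, y)  # tensor(32)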
Matrix multiplication
torch.matmul: If both arguments are 2-dimensional, the matrix-matrix product is returned.
If both are 2-dimensional, the matrix-matrix product is returned. For this case it is equivalent to torch.mm, but torch.mm handles only plain matrix multiplication and does not broadcast.
def matrix_multiple():
    x = torch.tensor([
        [1, 2, 3],
        [4, 5, 6]
    ])
    y = torch.tensor([
        [7, 8],
        [9, 10],
        [11, 12]
    ])
    return torch.matmul(x, y), torch.mm(x, y)
matrix_multiple()
(tensor([[ 58, 64],
[139, 154]]), tensor([[ 58, 64],
[139, 154]]))
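Python's @ operator is a third spelling of the same operation: on tensors it dispatches to torch.matmul. A sketch:
def matrix_multiple_at():
    x = torch.tensor([
        [1, 2, 3],
        [4, 5, 6]
    ])
    y = torch.tensor([
        [7, 8],
        [9, 10],
        [11, 12]
    ])
    return x @ y  # same result as torch.matmul(x, y)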
Vector times matrix
torch.matmul: If the first argument is 1-dimensional and the second argument is 2-dimensional, a 1 is prepended to its dimension for the purpose of the matrix multiply. After the matrix multiply, the prepended dimension is removed.
If the first argument is a vector and the second a matrix, a dimension of size 1 is prepended to the vector for the multiply: a vector of shape (n) becomes (1, n), multiplying by the (n, m) matrix yields (1, m), and the prepended 1 is then removed from the result, leaving shape (m).
def vec_matrix():
    x = torch.tensor([1, 2, 3])
    y = torch.tensor([
        [7, 8],
        [9, 10],
        [11, 12]
    ])
    return torch.matmul(x, y)
vec_matrix()
tensor([58, 64])
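The prepend-then-remove behavior can be reproduced by hand with unsqueeze and squeeze. A sketch of my own showing the equivalence:
def vec_matrix_explicit():
    x = torch.tensor([1, 2, 3])
    y = torch.tensor([
        [7, 8],
        [9, 10],
        [11, 12]
    ])
    # (3,) -> (1, 3), multiply to get (1, 2), then drop the leading 1
    return torch.matmul(x.unsqueeze(0), y).squeeze(0)  # tensor([58, 64])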
Matrix times vector
同樣的道理摊崭, vector會(huì)被擴(kuò)充一個(gè)維度。
def matrix_vec():
    x = torch.tensor([
        [1, 2, 3],
        [4, 5, 6]
    ])
    y = torch.tensor([
        7, 8, 9
    ])
    return torch.matmul(x, y)
matrix_vec()
tensor([ 50, 122])
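For exactly this case (2-D matrix times 1-D vector) there is also a dedicated routine, torch.mv, which does not broadcast. A sketch:
def matrix_vec_mv():
    x = torch.tensor([
        [1, 2, 3],
        [4, 5, 6]
    ])
    y = torch.tensor([7, 8, 9])
    return torch.mv(x, y)  # tensor([ 50, 122])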
Broadcast multiplication with a batch_size dimension
def batched_matrix_broadcasted_vector():
    x = torch.tensor([
        [
            [1, 2], [3, 4]
        ],
        [
            [5, 6], [7, 8]
        ]
    ])
    print(f"x shape: {x.size()} \n {x}")
    y = torch.tensor([1, 3])
    return torch.matmul(x, y)
batched_matrix_broadcasted_vector()
x shape: torch.Size([2, 2, 2])
tensor([[[1, 2],
[3, 4]],
[[5, 6],
[7, 8]]])
tensor([[ 7, 15],
[23, 31]])
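What happens internally: y of shape (2,) has a trailing dimension appended, the batched multiply runs, and the trailing dimension is removed again. A sketch of my own spelling this out:
def batched_matrix_vector_explicit():
    x = torch.tensor([
        [
            [1, 2], [3, 4]
        ],
        [
            [5, 6], [7, 8]
        ]
    ])                        # shape (2, 2, 2)
    y = torch.tensor([1, 3])  # shape (2,)
    # (2,) -> (2, 1), batched multiply gives (2, 2, 1), then drop the 1
    return torch.matmul(x, y.unsqueeze(-1)).squeeze(-1)  # tensor([[ 7, 15], [23, 31]])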
batched matrix x batched matrix
def batched_matrix_batched_matrix():
    x = torch.tensor([
        [
            [1, 2, 1], [3, 4, 4]
        ],
        [
            [5, 6, 2], [7, 8, 0]
        ]
    ])
    y = torch.tensor([
        [
            [1, 2],
            [3, 4],
            [5, 6]
        ],
        [
            [7, 8],
            [9, 10],
            [1, 2]
        ]
    ])
    print(f"x shape: {x.size()} \n y shape: {y.size()}")
    return torch.matmul(x, y)
xy = batched_matrix_batched_matrix()
print(f"xy shape: {xy.size()} \n {xy}")
x shape: torch.Size([2, 2, 3])
y shape: torch.Size([2, 3, 2])
xy shape: torch.Size([2, 2, 2])
tensor([[[ 12, 16],
[ 35, 46]],
[[ 91, 104],
[121, 136]]])
The effect above is the same as torch.bmm. matmul is more powerful than bmm, but bmm has very clear semantics: it can only handle 3-dimensional inputs.
def batched_matrix_batched_matrix_bmm():
    x = torch.tensor([
        [
            [1, 2, 1], [3, 4, 4]
        ],
        [
            [5, 6, 2], [7, 8, 0]
        ]
    ])
    y = torch.tensor([
        [
            [1, 2],
            [3, 4],
            [5, 6]
        ],
        [
            [7, 8],
            [9, 10],
            [1, 2]
        ]
    ])
    print(f"x shape: {x.size()} \n y shape: {y.size()}")
    return torch.bmm(x, y)
xy = batched_matrix_batched_matrix_bmm()
print(f"xy shape: {xy.size()} \n {xy}")
x shape: torch.Size([2, 2, 3])
y shape: torch.Size([2, 3, 2])
xy shape: torch.Size([2, 2, 2])
tensor([[[ 12, 16],
[ 35, 46]],
[[ 91, 104],
[121, 136]]])
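One concrete difference, as a sketch of my own: matmul broadcasts the batch dimension, so a 3-D batch can be multiplied by a single 2-D matrix, while bmm rejects anything other than two 3-D tensors with matching batch sizes.
def matmul_broadcasts_bmm_does_not():
    x = torch.tensor([
        [
            [1, 2, 1], [3, 4, 4]
        ],
        [
            [5, 6, 2], [7, 8, 0]
        ]
    ])                  # shape (2, 2, 3)
    y = torch.tensor([
        [1, 2],
        [3, 4],
        [5, 6]
    ])                  # shape (3, 2)
    out = torch.matmul(x, y)  # y is broadcast over the batch -> shape (2, 2, 2)
    try:
        torch.bmm(x, y)       # bmm requires both arguments to be 3-D
    except RuntimeError as e:
        print(f"bmm rejects 2-D y: {e}")
    return out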
tensordot
tensordot generalizes the operations above: it contracts (sums over) whichever dimensions of the two tensors are listed in dims. In the example below, dims=([0], [1]) contracts dimension 0 of x with dimension 1 of y; the result keeps the remaining dimensions of x followed by those of y, so its shape is (3, 3), with out[i, j] = sum over k of x[k, i] * y[j, k].
def tensordot_example():
    x = torch.tensor([
        [1, 2, 1],
        [3, 4, 4]])
    y = torch.tensor([
        [7, 8],
        [9, 10],
        [1, 2]])
    print(f"x shape: {x.size()}, y shape: {y.size()}")
    return torch.tensordot(x, y, dims=([0], [1]))
tensordot_example()
x shape: torch.Size([2, 3]), y shape: torch.Size([3, 2])
tensor([[31, 39, 7],
[46, 58, 10],
[39, 49, 9]])
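When dims is an integer n, tensordot contracts the last n dimensions of x with the first n dimensions of y, so with dims=1 on these shapes it reduces to the ordinary matrix product. A sketch:
def tensordot_as_matmul():
    x = torch.tensor([
        [1, 2, 1],
        [3, 4, 4]])
    y = torch.tensor([
        [7, 8],
        [9, 10],
        [1, 2]])
    # contracts x's last dim with y's first dim: same as torch.matmul(x, y)
    return torch.tensordot(x, y, dims=1)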