Pytorch 代码里看到 @ 和 * 两个操作符,其功能如下.
1. 矩阵相乘@
示例如,
import torch
#
X = torch.tensor([[1,2],[3,4],[5,6]]) #torch.Size([3, 2])
#tensor([[1, 2],
# [3, 4],
# [5, 6]])
Y = torch.tensor([[7,8],[9, 10]]) #torch.Size([2, 2])
#tensor([[5, 6],
# [7, 8]])
Z = X@Y #torch.Size([3, 2])
#tensor([[ 25, 28],
# [ 57, 64],
# [ 89, 100]])
#等价于:
Z = torch.matmul(X, Y)
2. 矩阵逐元素相乘*
示例如,
import torch
#
X = torch.tensor([[1,2],[3,4],[5,6]]) #torch.Size([3, 2])
#tensor([[1, 2],
# [3, 4],
# [5, 6]])
Y = torch.tensor([[7,8],[9, 10],[11,12]]) #torch.Size([3, 2])
#tensor([[ 7, 8],
# [ 9, 10],
# [11, 12]])
Z = X*Y
#tensor([[ 7, 16],
# [27, 40],
# [55, 72]])
#等价于:
Z = torch.mul(X, Y)