einsum
函数是 Numpy 库的一个宝藏函数,其往往能够比常见的 array 函数具有更快的速度和更少的内存,因为其具有更强的表达能力和更巧妙的循环. 但,理解 einsum
符号可能需要一点时间,而且在某些时候将其正确应用也比较棘手.
1. einsum 的作用
einsum
函数的使用,可以采用 Einstein summation convention 对 Numpy 数组指定操作.
假设两个数组 A和B,需要进行如下操作:
[1] - 乘法:首先,以特定方式将 A 和 B 相乘,以得到乘积结果数组;
[2] - 求和:然后,沿特定 axes 求和,得到新的数组;
[3] - 转置:再以特定顺序,对数组转置.
采用 einsum
能够更快速和更少内存实现 Numpy 的 multiply、sum 和 transpose 函数功能.
如,
import numpy as np
A = np.array([0, 1, 2]) #(3, )
B = np.array([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]]) #(3, 4)
采用 Numpy 函数的一般实现是:
首先,需要先对 A 进行 reshape,以便于与B进行操作(A需要是列向量);
然后,对 B 的第一行乘以 0,对 B 的第二行乘以 1, 对B 的第三行乘以 2,以得到新的数组;
接着,将第三列相加求和.
即:
(A[:, np.newaxis] * B).sum(axis=1)
#array([ 0, 22, 76])
而,einsum
方式如:
np.einsum('i,ij->i', A, B)
#array([ 0, 22, 76])
这种方式更有,其原因是,不需要对 A 进行 reshape 操作;最重要的是,不会产生临时数据,如 A[:, np.newaxis] * B
.
即使上面这种小例子,einsum
也具有三倍更快的速度.
2. einsum 的使用
其关键是,为输入数组和期望输出数组的轴(axes)的正确标签.
einsum
函数提供了两种方式:使用字母串,或者整数列表.
简单期间,选择最常用的字母串方式.
矩阵乘法是便于示例的,其涉及了行和列的相乘以及乘积结果的求和. 对于两个2D数组A和B,矩阵乘法可以实现为:
np.einsum('ij,jk->ik', A, B)
将 ij,jk->ik
想象成在箭头->
处一分为二.
其中,
左边ij,jk
是输入数组的 axes:ij
标记A,jk
标记B. 右边ik
是输出数组的 axes.
换句话说,是将两个输入 2D 数组放到一个新的输出 2D 数据中.
如,数组 A 和B,
A = np.array([[1, 1, 1],
[2, 2, 2],
[5, 5, 5]])
B = np.array([[0, 1, 0],
[1, 1, 0],
[1, 1, 1]])
np.einsum('ij,jk->ik', A, B)
实现如图:
为了便于理解输出数组的计算过程,记住如下三个规则:
[1] - 输入数组之间的重复字母,表示沿着这些轴(axes)的值将相乘,乘积构成输出数组的值.
如,字母 j
重复了两次,一次是A,一次是B. 意味着是将 A 的每一行和B的每一列相乘. 其仅对两个数组在 j
标记的 axes 长度一致才有效(或者是其中一个数组的长度为 1).
[2] - 输出中忽略的字母,表示沿该轴(axes)的值将被求和.
如,字母 j
不在输出数组的标记中. 忽略它沿轴求和,并显式地将最终数据的维数减少1.
如果输出的标记名为 ijk
,则最终会得到一个 3x3x3 的乘积数组.(且,如果没有给出输出标记,而是只写箭头->
,则需要对整个数据求和.)
[3] - 可以以任意次序返回未求和的轴(unsummed axes)
如果忽略箭头 ->
,Numpy 将会获取出现过一次的字母标记,并按照字母顺序进行排列(所以,实际上 ij,jk->ik
等价于 ij,jk
).
如果期望控制输出的形式,可以自定义选择输出字母标记的顺序,如,ij,jk->ki
表示矩阵乘法的转置(注意 k
和 i
在输出标记中的次序有切换).
至此,应该更容易理解矩阵乘法的计算流程. 如图示,如果不对 j
轴求和,而是采用 np.einsum('ij,jk->ijk', A, B))
将 j
轴包含在输出中. 右侧,j
轴已被求和.
注:np.einsum('ij,jk->ik', A, B)
函数并未构建 3D 数组再求和,其只是将总和累计到一个 2D 数组中.
3. einsum 简单示例
假设 A 和 B 是两个 1D 数据,其具有兼容的shapes(即,配对在一起的轴(axes)的长度要么相等,要么其中之一的长度为 1).
标记符号 | NumPy 等价形式 | 描述 |
---|---|---|
('i', A) | A | A |
('i->', A) | sum(A) | A求和 |
('i,i->i', A, B) | A * B | A和B逐元素相乘 |
('i,i', A, B) | inner(A, B) | 内积inner product of A and B |
('i,j->ij', A, B) | outer(A, B) | 外积outer product of A and B |
假设 A 和 B 是两个 2D 数据,其具有兼容的shapes.
标记符号 | NumPy 等价形式 | 描述 |
---|---|---|
('ij', A) | A | returns a view of A |
('ji', A) | A.T | 转置view transpose of A |
('ii->i', A) | diag(A) | 主对角线view main diagonal of A |
('ii', A) | trace(A) | 主对角线求和sums main diagonal of A |
('ij->', A) | sum(A) | 求和sums the values of A |
('ij->j', A) | sum(A, axis=0) | sum down the columns of A (across rows) |
('ij->i', A) | sum(A, axis=1) | sum horizontally along the rows of A |
('ij,ij->ij', A, B) | A * B | 逐元素乘element-wise multiplication of A and B |
('ij,ji->ij', A, B) | A * B.T | element-wise multiplication of A and B.T |
('ij,jk', A, B) | dot(A, B) | 矩阵乘法matrix multiplication of A and B |
('ij,kj->ik', A, B) | inner(A, B) | 内积inner product of A and B |
('ij,kj->ikj', A, B) | A[:, None] * B | each row of A multiplied by B |
('ij,kl->ijkl', A, B) | A[:, :, None, None] * B | each value of A multiplied by B |
4. einsum 注意事项
[1] - 数据类型问题,does not promote data types when summing.
a = np.ones(300, dtype=np.int8)
np.sum(a) # correct result
#300
np.einsum('i->', a) # produces incorrect result
#44
[2] - 轴(axes)的排列问题,might not permute axes in the order inteded.
[3] - einsum
在 Numpy 中也并不是总是最快的.