原文:A basic introduction to NumPy's einsum

einsum 函数是 Numpy 库的一个宝藏函数,其往往能够比常见的 array 函数具有更快的速度和更少的内存,因为其具有更强的表达能力和更巧妙的循环. 但,理解 einsum 符号可能需要一点时间,而且在某些时候将其正确应用也比较棘手.

1. einsum 的作用

einsum 函数的使用,可以采用 Einstein summation convention 对 Numpy 数组指定操作.

假设两个数组 A和B,需要进行如下操作:

[1] - 乘法:首先,以特定方式将 A 和 B 相乘,以得到乘积结果数组;

[2] - 求和:然后,沿特定 axes 求和,得到新的数组;

[3] - 转置:再以特定顺序,对数组转置.

采用 einsum 能够更快速和更少内存实现 Numpy 的 multiply、sum 和 transpose 函数功能.

如,

import numpy as np 

A = np.array([0, 1, 2]) #(3, )
B = np.array([[ 0,  1,  2,  3],
              [ 4,  5,  6,  7],
              [ 8,  9, 10, 11]]) #(3, 4)

采用 Numpy 函数的一般实现是:

首先,需要先对 A 进行 reshape,以便于与B进行操作(A需要是列向量);

然后,对 B 的第一行乘以 0,对 B 的第二行乘以 1, 对B 的第三行乘以 2,以得到新的数组;

接着,将第三列相加求和.

即:

(A[:, np.newaxis] * B).sum(axis=1)
#array([ 0, 22, 76])

而,einsum 方式如:

np.einsum('i,ij->i', A, B)
#array([ 0, 22, 76])

这种方式更有,其原因是,不需要对 A 进行 reshape 操作;最重要的是,不会产生临时数据,如 A[:, np.newaxis] * B.

即使上面这种小例子,einsum 也具有三倍更快的速度.

2. einsum 的使用

其关键是,为输入数组和期望输出数组的轴(axes)的正确标签.

einsum 函数提供了两种方式:使用字母串,或者整数列表.

简单期间,选择最常用的字母串方式.

矩阵乘法是便于示例的,其涉及了行和列的相乘以及乘积结果的求和. 对于两个2D数组A和B,矩阵乘法可以实现为:

np.einsum('ij,jk->ik', A, B)

ij,jk->ik 想象成在箭头->处一分为二.

其中,

左边ij,jk 是输入数组的 axes:ij 标记A,jk 标记B. 右边ik 是输出数组的 axes.

换句话说,是将两个输入 2D 数组放到一个新的输出 2D 数据中.

如,数组 A 和B,

A = np.array([[1, 1, 1],
              [2, 2, 2],
              [5, 5, 5]])

B = np.array([[0, 1, 0],
              [1, 1, 0],
              [1, 1, 1]])

np.einsum('ij,jk->ik', A, B) 实现如图:

为了便于理解输出数组的计算过程,记住如下三个规则:

[1] - 输入数组之间的重复字母,表示沿着这些轴(axes)的值将相乘,乘积构成输出数组的值.

如,字母 j 重复了两次,一次是A,一次是B. 意味着是将 A 的每一行和B的每一列相乘. 其仅对两个数组在 j 标记的 axes 长度一致才有效(或者是其中一个数组的长度为 1).

[2] - 输出中忽略的字母,表示沿该轴(axes)的值将被求和.

如,字母 j 不在输出数组的标记中. 忽略它沿轴求和,并显式地将最终数据的维数减少1.

如果输出的标记名为 ijk,则最终会得到一个 3x3x3 的乘积数组.(且,如果没有给出输出标记,而是只写箭头->,则需要对整个数据求和.)

[3] - 可以以任意次序返回未求和的轴(unsummed axes)

如果忽略箭头 ->,Numpy 将会获取出现过一次的字母标记,并按照字母顺序进行排列(所以,实际上 ij,jk->ik 等价于 ij,jk).

如果期望控制输出的形式,可以自定义选择输出字母标记的顺序,如,ij,jk->ki 表示矩阵乘法的转置(注意 ki 在输出标记中的次序有切换).

至此,应该更容易理解矩阵乘法的计算流程. 如图示,如果不对 j 轴求和,而是采用 np.einsum('ij,jk->ijk', A, B))j 轴包含在输出中. 右侧,j 轴已被求和.

注:np.einsum('ij,jk->ik', A, B) 函数并未构建 3D 数组再求和,其只是将总和累计到一个 2D 数组中.

3. einsum 简单示例

假设 A 和 B 是两个 1D 数据,其具有兼容的shapes(即,配对在一起的轴(axes)的长度要么相等,要么其中之一的长度为 1).

标记符号NumPy 等价形式描述
('i', A)AA
('i->', A)sum(A)A求和
('i,i->i', A, B)A * BA和B逐元素相乘
('i,i', A, B)inner(A, B)内积inner product of A and B
('i,j->ij', A, B)outer(A, B)外积outer product of A and B

假设 A 和 B 是两个 2D 数据,其具有兼容的shapes.

标记符号NumPy 等价形式描述
('ij', A)Areturns a view of A
('ji', A)A.T转置view transpose of A
('ii->i', A)diag(A)主对角线view main diagonal of A
('ii', A)trace(A)主对角线求和sums main diagonal of A
('ij->', A)sum(A)求和sums the values of A
('ij->j', A)sum(A, axis=0)sum down the columns of A (across rows)
('ij->i', A)sum(A, axis=1)sum horizontally along the rows of A
('ij,ij->ij', A, B)A * B逐元素乘element-wise multiplication of A and B
('ij,ji->ij', A, B)A * B.Telement-wise multiplication of A and B.T
('ij,jk', A, B)dot(A, B)矩阵乘法matrix multiplication of A and B
('ij,kj->ik', A, B)inner(A, B)内积inner product of A and B
('ij,kj->ikj', A, B)A[:, None] * Beach row of A multiplied by B
('ij,kl->ijkl', A, B)A[:, :, None, None] * Beach value of A multiplied by B

4. einsum 注意事项

[1] - 数据类型问题,does not promote data types when summing.

a = np.ones(300, dtype=np.int8)
np.sum(a) # correct result
#300
np.einsum('i->', a) # produces incorrect result
#44

[2] - 轴(axes)的排列问题,might not permute axes in the order inteded.

[3] - einsum 在 Numpy 中也并不是总是最快的.

5 einsum 材料

Last modification:June 5th, 2021 at 01:46 pm