XLA探究: 矩阵乘法

Posted by Sz Zheng on 2019-10-28

XLA 探究:矩阵乘法


1. 矩阵乘法

矩阵乘法是被广泛使用的算子,其数学表达式可以写为:
$$C(i, j) = \sum_{k=0}^{K-1}A(i, k) * B(k, j), 0 \le i < M, 0 \le j < N$$

最简单的实现方法为:

1
2
3
4
5
for i in range(M):
for j in range(N):
C(i, j) = 0
for k in range(K):
C(i, j) += A(i, k) * B(k, j)

2. 在 Tensorflow 中定义矩阵乘法

Tensorflow中提供的线性代数和张量操作 API 允许对于一种计算有多种表达方式,推荐的方式往往是尽量利用已经存在的 API 原子性地表达一个运算,比如 tf.nn.conv2d, tf.linalg.matmul 等等,而不是将其拆开成多个操作。但是在深度学习网络的发展过程中,必然会有新的算子被提出,所以利用已有的运算拼凑新的运算是具有意义的。在本文里我们先尝试利用已有运算拼凑矩阵乘法,用来观测 XLA 对于拼凑出的运算(其实是计算图)会进行怎样的操作。

2.1 直接调用 matmul API

直接使用预设好的 API

1
2
def gemm1(A, B):
return tf.linalg.matmul(A, B)

对于这种计算图,XLA 的优化在于消除冗余的 reshape 等节点。但优化前后都会调用 dot 这个运算 API 完成计算,dot 本身就在 HLO instruction 之中。

2.2 利用升维降维计算矩阵乘法

先将输入矩阵升维到三维,进行逐点相乘后再降维累加得到输出矩阵

1
2
3
4
5
6
7
8
def gemm2(A, B):
return tf.reduce_sum(
tf.multiply(
tf.tile(tf.expand_dims(A, -1), [1, 1, B.shape[1]]),
tf.tile(tf.expand_dims(B, 0), [A.shape[0], 1, 1])
),
axis=1
)

打印出 XLA 的 IR,首先看图形:
初始计算图中有很多冗余的reshape以及broadcast。
初始计算图
优化后的计算图简洁了很多。
优化后的计算图

再来对比文本形式的 HLO IR。
优化前的 IR:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
HloModule cluster_15890044661264488385__.30

%Sum-reduction.21 (x.22: f32[], y.23: f32[]) -> f32[] {
%x.22 = f32[] parameter(0)
%y.23 = f32[] parameter(1)
ROOT %add.24 = f32[] add(f32[] %x.22, f32[] %y.23)
}

ENTRY %cluster_15890044661264488385__.30 (arg0.1: f32[2,2], arg1.2: f32[2,2]) -> f32[2,2] {
%constant.6 = f32[] constant(0), metadata={op_type="Tile" op_name="Tile"}
%broadcast.7 = f32[2,2,2]{2,1,0} broadcast(f32[] %constant.6), dimensions={}, metadata={op_type="Tile" op_name="Tile"}
%arg0.1 = f32[2,2]{1,0} parameter(0), parameter_replication={false}, metadata={op_name="XLA_Args"}
%reshape.3 = f32[2,2]{1,0} reshape(f32[2,2]{1,0} %arg0.1)
%reshape.5 = f32[2,2,1]{2,1,0} reshape(f32[2,2]{1,0} %reshape.3), metadata={op_type="ExpandDims" op_name="ExpandDims"}
%reshape.8 = f32[2,2]{1,0} reshape(f32[2,2,1]{2,1,0} %reshape.5), metadata={op_type="Tile" op_name="Tile"}
%broadcast.9 = f32[2,2,2]{2,1,0} broadcast(f32[2,2]{1,0} %reshape.8), dimensions={0,1}, metadata={op_type="Tile" op_name="Tile"}
%add.10 = f32[2,2,2]{2,1,0} add(f32[2,2,2]{2,1,0} %broadcast.7, f32[2,2,2]{2,1,0} %broadcast.9), metadata={op_type="Tile" op_name="Tile"}
%constant.12 = f32[] constant(0), metadata={op_type="Tile" op_name="Tile_1"}
%broadcast.13 = f32[2,2,2]{2,1,0} broadcast(f32[] %constant.12), dimensions={}, metadata={op_type="Tile" op_name="Tile_1"}
%arg1.2 = f32[2,2]{1,0} parameter(1), parameter_replication={false}, metadata={op_name="XLA_Args"}
%reshape.4 = f32[2,2]{1,0} reshape(f32[2,2]{1,0} %arg1.2)
%reshape.11 = f32[1,2,2]{2,1,0} reshape(f32[2,2]{1,0} %reshape.4), metadata={op_type="ExpandDims" op_name="ExpandDims_1"}
%reshape.14 = f32[2,2]{1,0} reshape(f32[1,2,2]{2,1,0} %reshape.11), metadata={op_type="Tile" op_name="Tile_1"}
%broadcast.15 = f32[2,2,2]{2,1,0} broadcast(f32[2,2]{1,0} %reshape.14), dimensions={1,2}, metadata={op_type="Tile" op_name="Tile_1"}
%add.16 = f32[2,2,2]{2,1,0} add(f32[2,2,2]{2,1,0} %broadcast.13, f32[2,2,2]{2,1,0} %broadcast.15), metadata={op_type="Tile" op_name="Tile_1"}
%multiply.17 = f32[2,2,2]{2,1,0} multiply(f32[2,2,2]{2,1,0} %add.10, f32[2,2,2]{2,1,0} %add.16), metadata={op_type="Mul" op_name="Mul"}
%convert.18 = f32[2,2,2]{2,1,0} convert(f32[2,2,2]{2,1,0} %multiply.17), metadata={op_type="Sum" op_name="Sum"}
%constant.19 = f32[] constant(0), metadata={op_type="Sum" op_name="Sum"}
%convert.20 = f32[] convert(f32[] %constant.19), metadata={op_type="Sum" op_name="Sum"}
%reduce.25 = f32[2,2]{1,0} reduce(f32[2,2,2]{2,1,0} %convert.18, f32[] %convert.20), dimensions={1}, to_apply=%Sum-reduction.21, metadata={op_type="Sum" op_name="Sum"}
%convert.26 = f32[2,2]{1,0} convert(f32[2,2]{1,0} %reduce.25), metadata={op_type="Sum" op_name="Sum"}
%reshape.27 = f32[2,2]{1,0} reshape(f32[2,2]{1,0} %convert.26), metadata={op_name="XLA_Retvals"}
%tuple.28 = (f32[2,2]{1,0}) tuple(f32[2,2]{1,0} %reshape.27), metadata={op_name="XLA_Retvals"}
ROOT %get-tuple-element.29 = f32[2,2]{1,0} get-tuple-element((f32[2,2]{1,0}) %tuple.28), index=0, metadata={op_name="XLA_Retvals"}
}

优化后的 IR:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
HloModule cluster_15890044661264488385__.30

%Sum-reduction.21 (x.22: f32[], y.23: f32[]) -> f32[] {
%x.22 = f32[] parameter(0)
%y.23 = f32[] parameter(1)
ROOT %add.24 = f32[] add(f32[] %x.22, f32[] %y.23)
}

%fused_computation (param_0.3: f32[2,2], param_1.4: f32[2,2]) -> f32[2,2] {
%param_1.4 = f32[2,2]{1,0} parameter(1)
%broadcast.3 = f32[2,2,2]{2,1,0} broadcast(f32[2,2]{1,0} %param_1.4), dimensions={0,1}, metadata={op_type="Tile" op_name="Tile"}
%param_0.3 = f32[2,2]{1,0} parameter(0)
%broadcast.2 = f32[2,2,2]{2,1,0} broadcast(f32[2,2]{1,0} %param_0.3), dimensions={1,2}, metadata={op_type="Tile" op_name="Tile_1"}
%multiply.0 = f32[2,2,2]{2,1,0} multiply(f32[2,2,2]{2,1,0} %broadcast.3, f32[2,2,2]{2,1,0} %broadcast.2), metadata={op_type="Mul" op_name="Mul"}
%constant_0 = f32[] constant(0), metadata={op_type="Sum" op_name="Sum"}
ROOT %reduce.0 = f32[2,2]{1,0} reduce(f32[2,2,2]{2,1,0} %multiply.0, f32[] %constant_0), dimensions={1}, to_apply=%Sum-reduction.21, metadata={op_type="Sum" op_name="Sum"}
}

ENTRY %cluster_15890044661264488385__.30 (arg0.1: f32[2,2], arg1.2: f32[2,2]) -> f32[2,2] {
%arg1.2 = f32[2,2]{1,0} parameter(1), parameter_replication={false}, metadata={op_name="XLA_Args"}
%arg0.1 = f32[2,2]{1,0} parameter(0), parameter_replication={false}, metadata={op_name="XLA_Args"}
ROOT %fusion = f32[2,2]{1,0} fusion(f32[2,2]{1,0} %arg1.2, f32[2,2]{1,0} %arg0.1), kind=kLoop, calls=%fused_computation, metadata={op_type="Sum" op_name="Sum"}
}

XLA 对图进行优化并没有将各个小运算合并用 dot 运算代替,因为对于一个编译器来说,识别当前做的运算是否能用更简洁的 API 代替是个复杂(甚至不知道是否可解)的问题。这种利用数学等价性进行更大胆的图替换的优化可能成为下一个要攻克的难点。

2.3 利用for循环计算矩阵乘

这种方法对 Tensorflow 是最具有挑战性的,因为我们直接用 for 循环完成计算,Tensorflow 前端将根据 for 循环的次数增加计算节点。

1
2
3
4
5
6
7
8
9
10
11
12
def gemm3(A, B):
tmp_ary = []
for i in range(A.shape[0]):
tmp_row = []
for j in range(B.shape[1]):
tmp = A[i, 0] * B[0, j]
for k in range(1, A.shape[1]):
tmp = tmp + A[i, k] * B[k, j]
tmp_row.append(tmp)
tmp_ary.append(tf.stack(tmp_row))

return tf.stack(tmp_ary)

先看优化前的计算图:
优化前计算图
再看优化后的计算图:
优化后计算图
计算图一下子就变得复杂了起来,那是因为 for 循环中的每次实例,都在向图中增加计算节点,这个图只是对于$2\times 2$矩阵乘法的,可以想象矩阵变大后图会有多复杂。事实上笔者在尝试$16\times 16$的矩阵乘时就无法正常可视化计算图了。$32 \times 32$以上的矩阵更是算不出来(因为处理计算图卡住)。

这种实现方式实在是难为 Tensorflow 了。因为 Tensorflow 的基本运算单位是张量,尽量将运算都定义在张量上,并构造一个静态的运算图是其基本出发点,而现在的实现方法是每个 scalar 运算都对应了一个节点,且节点数与输入规模相关,这样构造的计算图异常巨大,且不可扩展,不可迁移。XLA 更是无法试别这种计算图的计算模式,用一个简单的 dot 代替这些节点。

近期火热的深度学习编译器如 TVM 则是专门针对循环定义构建的,上面那种 for 循环形式的定义对于其是友好的,这也是 TVM 与 XLA 的不同之一。

3. 性能对比

上文介绍的三种实现方式具有不同的性能,第三种实现方式的性能不言而喻地低,因为基本的图构建都无法完成,所以就不列在比较之中了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
gemm1:
(1 x 1): 0.498796ms
(2 x 2): 0.482011ms
(4 x 4): 0.465155ms
(8 x 8): 0.442171ms
(16 x 16): 0.475311ms
(32 x 32): 0.481582ms
(64 x 64): 0.474691ms
(128 x 128): 0.490928ms
(256 x 256): 0.577831ms
(512 x 512): 1.005220ms
(1024 x 1024): 2.768564ms
(2048 x 2048): 9.252787ms
(4096 x 4096): 43.600631ms
(8192 x 8192): 228.450990ms

gemm2:
(1 x 1): 0.508928ms
(2 x 2): 0.509024ms
(4 x 4): 0.528526ms
(8 x 8): 0.489831ms
(16 x 16): 0.492191ms
(32 x 32): 0.510383ms
(64 x 64): 0.511551ms
(128 x 128): 0.538731ms
(256 x 256): 0.743008ms
(512 x 512): 1.382422ms
(1024 x 1024): 7.267666ms
(2048 x 2048): 102.449155ms
(4096 x 4096): 804.755259ms
(8192 x 8192): 6476.903272ms

可以发现当矩阵大小超过512后,第二种方法产生了明现的劣势。这也体现了单独原子性地调用 API 的重要性。这可以抽象为一个划归问题,即准备好了许多高性能的运算实现(如cuDNN),但是具体的运算图如何划归到合适的运算 API 上是个问题,现在的图优化仅仅做了简单的算术化简以及算子合并,却不能发掘隐藏的激进的规约机会,就会错过优化机会。

4. 总结

本文旨在窥探 XLA 的优化能力边界,结合当前常见的新算子无原子性 API 支持时需要用已有 API 拼凑的场景,利用矩阵乘法作为一个例子,初步得出以下结论:

  • XLA 的图优化很保守,至少难以发掘图上运算模式规约的可能性。当然,这也是所有目前的深度学习编译器面临的挑战。
  • 单独调用一个 API 比拼凑 API 更具有优势。这本质体现了库支持问题。