Exploring Tensorflow/XLA

Posted by Sz Zheng on 2019-10-23



1. Introduction

This post records what I learned while exploring Tensorflow/XLA. As the first post in the XLA series, it covers installation and usage as well as the overall optimization pipeline.


2. Installing Tensorflow/XLA

Before tensorflow r2.0 the XLA feature was turned off by default, and the only way to enable it was to build and install tensorflow from source. Below is the process I used to build XLA-enabled tensorflow on a personal machine.

  1. Check your operating system. The steps below use Linux (Ubuntu 18.04) as the example, since building XLA from source is not supported on Windows. My machine already had CUDA 10.0 and cuDNN 7.6 installed. To install CUDA 9 you can refer to this article, and for CUDA 10 this article.

  2. Check the versions: CUDA 9.0 pairs with tensorflow r1.5 - r1.12 (cuDNN v7.0), and CUDA 10.0 pairs with tensorflow r1.13 - r1.14 (cuDNN v7.4+).
    CUDA 10.0 + tensorflow r1.14 is used below, with cuDNN v7.6; one way to check what is installed is sketched below.
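
    The following commands are one way to verify the installed versions (the cudnn.h location depends on how cuDNN was installed, so adjust the path as needed):

    ```bash
    nvcc --version                                        # CUDA toolkit version
    grep -A 2 "define CUDNN_MAJOR" /usr/include/cudnn.h   # cuDNN version
    ```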

  3. If python3 is missing, install it:

    1. Download the source from the official python website, picking the version you need; python3.6 is used here:

    ```bash
    wget https://www.python.org/ftp/python/3.6.9/Python-3.6.9.tgz
    tar xvf Python-3.6.9.tgz
    cd Python-3.6.9
    ```

    2. Configure before building. --prefix must be an absolute path and names the install directory; --enable-optimizations selects the stable, optimized release build:

    ```bash
    mkdir build
    ./configure --prefix $(pwd)/build/ --enable-optimizations
    ```

    3. Build and install. 8 threads are used here; adjust to your own machine:

    ```bash
    make -j 8 && make install
    ```

    4. After installation, set up the environment; the usual approach is a symbolic link in $HOME/.local/bin:

    ```bash
    cd ~/.local
    mkdir bin  # skip if it already exists
    cd bin
    ln -s ~/Python-3.6.9/build/bin/python3 python3
    export PATH=~/.local/bin:$PATH
    ```

    5. Verify:

    ```
    $ python3
    Python 3.6.9 (default, Oct 14 2019, 21:08:52)
    [GCC 4.8.5 20150623 (Red Hat 4.8.5-28)] on linux
    Type "help", "copyright", "credits" or "license" for more information.
    >>>
    ```

    6. Use a virtual environment. When working with python it is best to use one: it keeps environments manageable, and later mistakes cannot affect anything else. virtualenv is used here (already downloaded and installed):

    ```bash
    mkdir venv && cd venv
    virtualenv prime -p python3  # "prime" is the environment's name
    # once the prime environment is created, activate it
    source prime/bin/activate
    # verify the installation
    which python  # should print ~/venv/prime/bin/python
    which pip     # should print ~/venv/prime/bin/pip
    # upgrade pip; the bundled version is usually old
    pip install -U pip
    ```
  4. Install the dependencies specified on the official site (note that future>=0.17.1 must be quoted, or the shell treats >= as a redirection):

    ```bash
    pip install six numpy wheel setuptools mock 'future>=0.17.1'
    pip install keras_applications==1.0.6 --no-deps
    pip install keras_preprocessing==1.0.5 --no-deps
    ```
  5. Install the build tool bazel; the available methods are listed here. Taking Ubuntu 18.04 as the example, you can directly download a prebuilt bazel, picking the version that tensorflow requires; 0.26.0 is used here:

    ```bash
    cd ~
    wget https://github.com/bazelbuild/bazel/releases/download/0.26.0/bazel-0.26.0-linux-x86_64
    chmod +x bazel-0.26.0-linux-x86_64
    cd ~/.local/bin
    ln -s ~/bazel-0.26.0-linux-x86_64 bazel
    # test
    bazel version
    ```
  6. Download the tensorflow source and check out the branch you need (r1.14):

    ```bash
    cd ~
    git clone https://github.com/tensorflow/tensorflow.git
    cd tensorflow
    git checkout r1.14
    ```
  7. Configure:

    ```bash
    ./configure
    ```
  8. Build. This takes a very long time; you can run it in the background (as sketched below) and watch the build output to make sure it is going well.

    ```bash
    bazel build --config=opt --config=cuda --config=mkl //tensorflow/tools/pip_package:build_pip_package
    ./bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg
    ```
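
    One way to do the background run, using nohup and tail (any job-control approach works):

    ```bash
    # run the build detached, keeping a log
    nohup bazel build --config=opt --config=cuda --config=mkl \
        //tensorflow/tools/pip_package:build_pip_package > build.log 2>&1 &
    # watch the build output
    tail -f build.log
    ```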
  9. Install:

    ```bash
    pip install /tmp/tensorflow_pkg/tensorflow-1.14.1.whl
    ```

At this point you can import tensorflow in python.
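
A quick sanity check that the build works and that XLA devices got registered (device names vary by machine; this is just a sketch):

```python
import tensorflow as tf
from tensorflow.python.client import device_lib

print(tf.__version__)
# with an XLA-enabled build, entries like /device:XLA_CPU:0 and
# /device:XLA_GPU:0 should appear alongside the usual devices
print([d.name for d in device_lib.list_local_devices()])
```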


3. Running XLA

3.1 Model definition

XLA exposes two kinds of interfaces: JIT and AOT.
For JIT you can use the xla.compile interface: define the model, then hand it straight to xla.compile:

```python
from tensorflow.contrib.compiler import xla

# define the model as a function of its input tensors
def get_model(*args):
    ...

# hand the model function and its inputs to XLA
output_list = xla.compile(get_model, inputs=input_list)
```
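
To make this concrete, here is a minimal self-contained sketch for tensorflow r1.14. The one-layer model and every name in it are hypothetical illustrations; only xla.compile itself comes from the real API:

```python
import numpy as np
import tensorflow as tf
from tensorflow.contrib.compiler import xla

# a tiny hypothetical model, just something for XLA to compile
def get_model(x):
    w = tf.get_variable("w", shape=[784, 10], dtype=tf.float32)
    return tf.matmul(x, w)

x = tf.placeholder(tf.float32, shape=[1, 784])
# a single-output computation comes back as a one-element tuple
(y,) = xla.compile(get_model, inputs=[x])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out = sess.run(y, feed_dict={x: np.ones((1, 784), np.float32)})
    print(out.shape)  # (1, 10)
```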

3.2 Running it

Just run the Python file. If you want to see what XLA does at each step, use a command like:

```bash
XLA_FLAGS="--xla_dump_to=/some/path --xla_dump_hlo_pass_re=.* --xla_dump_hlo_as_html" python your_program.py
```

The flags used here are:

  • xla_dump_to: where the generated intermediate representations should be written
  • xla_dump_hlo_pass_re: by default XLA does not dump the HLO around its internal passes; this flag enables dumping for passes whose names match the regex, and .* matches all passes
  • xla_dump_hlo_as_html: dump the HLO in HTML format; xla_dump_hlo_as_{text, proto, ...} variants are also available

The complete set of flags is defined in xla.proto. Note that different Tensorflow versions support different flags, so check against the version you are using.
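
For instance, a dump-and-inspect session might look like the following (the dump directory is a placeholder; text dumps are used here so they can be read with ordinary tools):

```bash
mkdir -p /tmp/xla_dump
XLA_FLAGS="--xla_dump_to=/tmp/xla_dump --xla_dump_hlo_pass_re=.* --xla_dump_hlo_as_text" \
    python your_program.py
ls /tmp/xla_dump | head   # one file per HLO snapshot
```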

4. Dissecting the internals

Starting with this section we take a proper look at XLA's implementation, beginning with the grammar. XLA uses HLO as its IR (Intermediate Representation); XLA stands for Accelerated Linear Algebra, and HLO for High Level Optimizer. HLO has a grammar of its own, recorded in full in tensorflow/compiler/xla/service/g3doc/hlo_parser.md and reproduced below:

```
hlo_module
: 'HloModule' name computations
;

/* If no computation is marked as ENTRY, the last computation will be the entry
computation of the module.*/
computations
: computation
| computation computations
;

computation
: 'ENTRY' name param_list_to_shape instruction_list
| name param_list_to_shape instruction_list
| 'ENTRY' name instruction_list
| name instruction_list
;

/* If no instruction is marked as ROOT, the last instruction will be the root of
its computation. */
instruction_list
: '{' instruction_list1 '}'
;
instruction_list1
: instruction
| instruction_list1 instruction
;
instruction
: 'ROOT' name '=' shape opcode operands extra_attributes
| name '=' shape opcode operands extra_attributes
;

operands
: '(' operands1 ')'
;
operands1
: /*empty*/
| operand
| operands1 ',' operand
;
operand
: shape name
| name
;

attributes
: /*empty*/
| ',' attribute
| ',' attribute attributes
;
attribute
: attribute_name attribute_value
;
attribute_value
: kInt
| kName
| [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,} /*dim_labels_pattern*/
| [0-9]+(x[0-9]+)+ /*dxd_pattern*/
| [0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)* /*pad_pattern*/
| '{' sub_attributes '}'
;

param_list_to_shape
: param_list '->' shape
;

param_list
: '(' param_list1 ')'
;
param_list1
: /*empty*/
| param
| param_list1 ',' param
;
param
: name shape
;

shape
: shape_val_
| '(' tuple_elements ')'
;
tuple_elements
: /*empty*/
| shape (',' shape)*
;

name
: identifier ':'
| '%' identifier
| identifier
;

identifier
: [a-zA-Z_][a-zA-Z0-9_.-]*
;

/* literal is in the right hand side of a constant instruction. */
literal
: tuple
| non_tuple
;
tuple
: shape '(' literal_list ')'
;
literal_list
: /*empty*/
| literal
| literal_list ',' literal
;
non_tuple
: rank01
| rank2345
;
rank2345
: shape sparse_or_nested_array
;
sparse_or_nested_array
: sparse_array
| nested_array
;
sparse_array
: '{' sparse_array1 '}'
;
sparse_array1
: sparse_array_item
| sparse_array1 ',' sparse_array_item
;
sparse_array_item
: multi_index ':' scalar
;
multi_index
: kInt
| '[' multi_index1 ']'
;
multi_index1
: kInt
| multi_index1 ',' kInt
;
```
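
As a quick illustration of the grammar, here is a minimal hand-written module (not produced by any dump): one ENTRY computation with two parameters and a ROOT add instruction:

```
HloModule tiny_add

ENTRY %tiny_add (x: f32[2], y: f32[2]) -> f32[2] {
  %x = f32[2]{0} parameter(0)
  %y = f32[2]{0} parameter(1)
  ROOT %add = f32[2]{0} add(f32[2]{0} %x, f32[2]{0} %y)
}
```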

With xla_dump_hlo_as_text you can get the HLO in textual form; an example follows (this one was run on a GPU):

```
HloModule cluster_1992353871243790009__.123

%add_F32.87 (lhs.88: f32[], rhs.89: f32[]) -> f32[] {
%lhs.88 = f32[] parameter(0)
%rhs.89 = f32[] parameter(1)
ROOT %add.90 = f32[] add(f32[] %lhs.88, f32[] %rhs.89)
}

%max_float_.102 (x.103: f32[], y.104: f32[]) -> f32[] {
%x.103 = f32[] parameter(0)
%y.104 = f32[] parameter(1)
ROOT %maximum.105 = f32[] maximum(f32[] %x.103, f32[] %y.104)
}

%add_float_.112 (x.113: f32[], y.114: f32[]) -> f32[] {
%x.113 = f32[] parameter(0)
%y.114 = f32[] parameter(1)
ROOT %add.115 = f32[] add(f32[] %x.113, f32[] %y.114)
}

ENTRY %cluster_1992353871243790009__.123 (arg0.1: f32[1,224,224,3], arg1.2: f32[32], arg2.3: f32[3,3,32,1], arg3.4: f32[32], arg4.5: f32[32], arg5.6: f32[32], arg6.7: f32[64], arg7.8: f32[1,1,32,64], arg8.9: f32[64], arg9.10: f32[64], arg10.11: f32[64], arg11.12: f32[64], arg12.13: f32[3,3,64,1], arg13.14: f32[64], arg14.15: f32[64], arg15.16: f32[64], arg16.17: f32[128], arg17.18: f32[1,1,64,128], arg18.19: f32[128], arg19.20: f32[128], arg20.21: f32[128], arg21.22: f32[1000], arg22.23: f32[128,1000], arg23.24: f32[32], arg24.25: f32[32], arg25.26: f32[32], arg26.27: f32[32], arg27.28: f32[3,3,3,32]) -> f32[1,1000] {
%constant.80 = f32[] constant(0), metadata={op_type="Relu" op_name="mobilenet/depthwise_seperable_2/pw_batch_norm/Relu"}
%broadcast.81 = f32[1,56,56,128]{3,2,1,0} broadcast(f32[] %constant.80), dimensions={}, metadata={op_type="Relu" op_name="mobilenet/depthwise_seperable_2/pw_batch_norm/Relu"}
%constant.71 = f32[] constant(0), metadata={op_type="Relu" op_name="mobilenet/depthwise_seperable_2/dw_batch_norm/Relu"}
%broadcast.72 = f32[1,56,56,64]{3,2,1,0} broadcast(f32[] %constant.71), dimensions={}, metadata={op_type="Relu" op_name="mobilenet/depthwise_seperable_2/dw_batch_norm/Relu"}
%constant.61 = f32[] constant(0), metadata={op_type="Relu" op_name="mobilenet/depthwise_seperable_1/pw_batch_norm/Relu"}
%broadcast.62 = f32[1,112,112,64]{3,2,1,0} broadcast(f32[] %constant.61), dimensions={}, metadata={op_type="Relu" op_name="mobilenet/depthwise_seperable_1/pw_batch_norm/Relu"}
%constant.52 = f32[] constant(0), metadata={op_type="Relu" op_name="mobilenet/depthwise_seperable_1/dw_batch_norm/Relu"}
%broadcast.53 = f32[1,112,112,32]{3,2,1,0} broadcast(f32[] %constant.52), dimensions={}, metadata={op_type="Relu" op_name="mobilenet/depthwise_seperable_1/dw_batch_norm/Relu"}
%constant.42 = f32[] constant(0), metadata={op_type="Relu" op_name="mobilenet/first_batch_norm/Relu"}
%broadcast.43 = f32[1,112,112,32]{3,2,1,0} broadcast(f32[] %constant.42), dimensions={}, metadata={op_type="Relu" op_name="mobilenet/first_batch_norm/Relu"}
%arg0.1 = f32[1,224,224,3]{3,2,1,0} parameter(0), parameter_replication={false}, metadata={op_name="XLA_Args"}
%reshape.29 = f32[1,224,224,3]{3,2,1,0} reshape(f32[1,224,224,3]{3,2,1,0} %arg0.1)
%arg27.28 = f32[3,3,3,32]{3,2,1,0} parameter(27), parameter_replication={false}, metadata={op_name="XLA_Args"}
%convolution.36 = f32[1,112,112,32]{3,2,1,0} convolution(f32[1,224,224,3]{3,2,1,0} %reshape.29, f32[3,3,3,32]{3,2,1,0} %arg27.28), window={size=3x3 stride=2x2 pad=0_1x0_1}, dim_labels=b01f_01io->b01f, metadata={op_type="Conv2D" op_name="mobilenet/first_conv3x3/Conv2D"}
%arg26.27 = f32[32]{0} parameter(26), parameter_replication={false}, metadata={op_name="XLA_Args"}
%broadcast.37 = f32[1,112,112,32]{3,2,1,0} broadcast(f32[32]{0} %arg26.27), dimensions={3}, metadata={op_type="BiasAdd" op_name="mobilenet/first_conv3x3/BiasAdd"}
%add.38 = f32[1,112,112,32]{3,2,1,0} add(f32[1,112,112,32]{3,2,1,0} %convolution.36, f32[1,112,112,32]{3,2,1,0} %broadcast.37), metadata={op_type="BiasAdd" op_name="mobilenet/first_conv3x3/BiasAdd"}
%convert.39 = f32[1,112,112,32]{3,2,1,0} convert(f32[1,112,112,32]{3,2,1,0} %add.38), metadata={op_type="FusedBatchNorm" op_name="mobilenet/first_batch_norm/FusedBatchNorm"}
%constant.30 = f32[] constant(1), metadata={op_type="Const" op_name="mobilenet/depthwise_seperable_1/dw_batch_norm/Const"}
%broadcast.31 = f32[32]{0} broadcast(f32[] %constant.30), dimensions={}, metadata={op_type="Const" op_name="mobilenet/depthwise_seperable_1/dw_batch_norm/Const"}
%arg23.24 = f32[32]{0} parameter(23), parameter_replication={false}, metadata={op_name="XLA_Args"}
%arg24.25 = f32[32]{0} parameter(24), parameter_replication={false}, metadata={op_name="XLA_Args"}
%arg25.26 = f32[32]{0} parameter(25), parameter_replication={false}, metadata={op_name="XLA_Args"}
%batch-norm-inference.40 = f32[1,112,112,32]{3,2,1,0} batch-norm-inference(f32[1,112,112,32]{3,2,1,0} %convert.39, f32[32]{0} %broadcast.31, f32[32]{0} %arg23.24, f32[32]{0} %arg24.25, f32[32]{0} %arg25.26), epsilon=0.001, feature_index=3, metadata={op_type="FusedBatchNorm" op_name="mobilenet/first_batch_norm/FusedBatchNorm"}
%convert.41 = f32[1,112,112,32]{3,2,1,0} convert(f32[1,112,112,32]{3,2,1,0} %batch-norm-inference.40), metadata={op_type="FusedBatchNorm" op_name="mobilenet/first_batch_norm/FusedBatchNorm"}
%maximum.44 = f32[1,112,112,32]{3,2,1,0} maximum(f32[1,112,112,32]{3,2,1,0} %broadcast.43, f32[1,112,112,32]{3,2,1,0} %convert.41), metadata={op_type="Relu" op_name="mobilenet/first_batch_norm/Relu"}
%arg2.3 = f32[3,3,32,1]{3,2,1,0} parameter(2), parameter_replication={false}, metadata={op_name="XLA_Args"}
%reshape.45 = f32[3,3,1,32]{3,2,1,0} reshape(f32[3,3,32,1]{3,2,1,0} %arg2.3), metadata={op_type="DepthwiseConv2dNative" op_name="mobilenet/depthwise_seperable_1/depthwise/depthwise"}
%convolution.46 = f32[1,112,112,32]{3,2,1,0} convolution(f32[1,112,112,32]{3,2,1,0} %maximum.44, f32[3,3,1,32]{3,2,1,0} %reshape.45), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=32, metadata={op_type="DepthwiseConv2dNative" op_name="mobilenet/depthwise_seperable_1/depthwise/depthwise"}
%arg1.2 = f32[32]{0} parameter(1), parameter_replication={false}, metadata={op_name="XLA_Args"}
%broadcast.47 = f32[1,112,112,32]{3,2,1,0} broadcast(f32[32]{0} %arg1.2), dimensions={3}, metadata={op_type="BiasAdd" op_name="mobilenet/depthwise_seperable_1/depthwise/BiasAdd"}
%add.48 = f32[1,112,112,32]{3,2,1,0} add(f32[1,112,112,32]{3,2,1,0} %convolution.46, f32[1,112,112,32]{3,2,1,0} %broadcast.47), metadata={op_type="BiasAdd" op_name="mobilenet/depthwise_seperable_1/depthwise/BiasAdd"}
%convert.49 = f32[1,112,112,32]{3,2,1,0} convert(f32[1,112,112,32]{3,2,1,0} %add.48), metadata={op_type="FusedBatchNorm" op_name="mobilenet/depthwise_seperable_1/dw_batch_norm/FusedBatchNorm"}
%arg3.4 = f32[32]{0} parameter(3), parameter_replication={false}, metadata={op_name="XLA_Args"}
%arg4.5 = f32[32]{0} parameter(4), parameter_replication={false}, metadata={op_name="XLA_Args"}
%arg5.6 = f32[32]{0} parameter(5), parameter_replication={false}, metadata={op_name="XLA_Args"}
%batch-norm-inference.50 = f32[1,112,112,32]{3,2,1,0} batch-norm-inference(f32[1,112,112,32]{3,2,1,0} %convert.49, f32[32]{0} %broadcast.31, f32[32]{0} %arg3.4, f32[32]{0} %arg4.5, f32[32]{0} %arg5.6), epsilon=0.001, feature_index=3, metadata={op_type="FusedBatchNorm" op_name="mobilenet/depthwise_seperable_1/dw_batch_norm/FusedBatchNorm"}
%convert.51 = f32[1,112,112,32]{3,2,1,0} convert(f32[1,112,112,32]{3,2,1,0} %batch-norm-inference.50), metadata={op_type="FusedBatchNorm" op_name="mobilenet/depthwise_seperable_1/dw_batch_norm/FusedBatchNorm"}
%maximum.54 = f32[1,112,112,32]{3,2,1,0} maximum(f32[1,112,112,32]{3,2,1,0} %broadcast.53, f32[1,112,112,32]{3,2,1,0} %convert.51), metadata={op_type="Relu" op_name="mobilenet/depthwise_seperable_1/dw_batch_norm/Relu"}
%arg7.8 = f32[1,1,32,64]{3,2,1,0} parameter(7), parameter_replication={false}, metadata={op_name="XLA_Args"}
%convolution.55 = f32[1,112,112,64]{3,2,1,0} convolution(f32[1,112,112,32]{3,2,1,0} %maximum.54, f32[1,1,32,64]{3,2,1,0} %arg7.8), window={size=1x1}, dim_labels=b01f_01io->b01f, metadata={op_type="Conv2D" op_name="mobilenet/depthwise_seperable_1/pointwise/Conv2D"}
%arg6.7 = f32[64]{0} parameter(6), parameter_replication={false}, metadata={op_name="XLA_Args"}
%broadcast.56 = f32[1,112,112,64]{3,2,1,0} broadcast(f32[64]{0} %arg6.7), dimensions={3}, metadata={op_type="BiasAdd" op_name="mobilenet/depthwise_seperable_1/pointwise/BiasAdd"}
%add.57 = f32[1,112,112,64]{3,2,1,0} add(f32[1,112,112,64]{3,2,1,0} %convolution.55, f32[1,112,112,64]{3,2,1,0} %broadcast.56), metadata={op_type="BiasAdd" op_name="mobilenet/depthwise_seperable_1/pointwise/BiasAdd"}
%convert.58 = f32[1,112,112,64]{3,2,1,0} convert(f32[1,112,112,64]{3,2,1,0} %add.57), metadata={op_type="FusedBatchNorm" op_name="mobilenet/depthwise_seperable_1/pw_batch_norm/FusedBatchNorm"}
%constant.32 = f32[] constant(1), metadata={op_type="Const" op_name="mobilenet/depthwise_seperable_2/dw_batch_norm/Const"}
%broadcast.33 = f32[64]{0} broadcast(f32[] %constant.32), dimensions={}, metadata={op_type="Const" op_name="mobilenet/depthwise_seperable_2/dw_batch_norm/Const"}
%arg8.9 = f32[64]{0} parameter(8), parameter_replication={false}, metadata={op_name="XLA_Args"}
%arg9.10 = f32[64]{0} parameter(9), parameter_replication={false}, metadata={op_name="XLA_Args"}
%arg10.11 = f32[64]{0} parameter(10), parameter_replication={false}, metadata={op_name="XLA_Args"}
%batch-norm-inference.59 = f32[1,112,112,64]{3,2,1,0} batch-norm-inference(f32[1,112,112,64]{3,2,1,0} %convert.58, f32[64]{0} %broadcast.33, f32[64]{0} %arg8.9, f32[64]{0} %arg9.10, f32[64]{0} %arg10.11), epsilon=0.001, feature_index=3, metadata={op_type="FusedBatchNorm" op_name="mobilenet/depthwise_seperable_1/pw_batch_norm/FusedBatchNorm"}
%convert.60 = f32[1,112,112,64]{3,2,1,0} convert(f32[1,112,112,64]{3,2,1,0} %batch-norm-inference.59), metadata={op_type="FusedBatchNorm" op_name="mobilenet/depthwise_seperable_1/pw_batch_norm/FusedBatchNorm"}
%maximum.63 = f32[1,112,112,64]{3,2,1,0} maximum(f32[1,112,112,64]{3,2,1,0} %broadcast.62, f32[1,112,112,64]{3,2,1,0} %convert.60), metadata={op_type="Relu" op_name="mobilenet/depthwise_seperable_1/pw_batch_norm/Relu"}
%arg12.13 = f32[3,3,64,1]{3,2,1,0} parameter(12), parameter_replication={false}, metadata={op_name="XLA_Args"}
%reshape.64 = f32[3,3,1,64]{3,2,1,0} reshape(f32[3,3,64,1]{3,2,1,0} %arg12.13), metadata={op_type="DepthwiseConv2dNative" op_name="mobilenet/depthwise_seperable_2/depthwise/depthwise"}
%convolution.65 = f32[1,56,56,64]{3,2,1,0} convolution(f32[1,112,112,64]{3,2,1,0} %maximum.63, f32[3,3,1,64]{3,2,1,0} %reshape.64), window={size=3x3 stride=2x2 pad=0_1x0_1}, dim_labels=b01f_01io->b01f, feature_group_count=64, metadata={op_type="DepthwiseConv2dNative" op_name="mobilenet/depthwise_seperable_2/depthwise/depthwise"}
%arg11.12 = f32[64]{0} parameter(11), parameter_replication={false}, metadata={op_name="XLA_Args"}
%broadcast.66 = f32[1,56,56,64]{3,2,1,0} broadcast(f32[64]{0} %arg11.12), dimensions={3}, metadata={op_type="BiasAdd" op_name="mobilenet/depthwise_seperable_2/depthwise/BiasAdd"}
%add.67 = f32[1,56,56,64]{3,2,1,0} add(f32[1,56,56,64]{3,2,1,0} %convolution.65, f32[1,56,56,64]{3,2,1,0} %broadcast.66), metadata={op_type="BiasAdd" op_name="mobilenet/depthwise_seperable_2/depthwise/BiasAdd"}
%convert.68 = f32[1,56,56,64]{3,2,1,0} convert(f32[1,56,56,64]{3,2,1,0} %add.67), metadata={op_type="FusedBatchNorm" op_name="mobilenet/depthwise_seperable_2/dw_batch_norm/FusedBatchNorm"}
%arg13.14 = f32[64]{0} parameter(13), parameter_replication={false}, metadata={op_name="XLA_Args"}
%arg14.15 = f32[64]{0} parameter(14), parameter_replication={false}, metadata={op_name="XLA_Args"}
%arg15.16 = f32[64]{0} parameter(15), parameter_replication={false}, metadata={op_name="XLA_Args"}
%batch-norm-inference.69 = f32[1,56,56,64]{3,2,1,0} batch-norm-inference(f32[1,56,56,64]{3,2,1,0} %convert.68, f32[64]{0} %broadcast.33, f32[64]{0} %arg13.14, f32[64]{0} %arg14.15, f32[64]{0} %arg15.16), epsilon=0.001, feature_index=3, metadata={op_type="FusedBatchNorm" op_name="mobilenet/depthwise_seperable_2/dw_batch_norm/FusedBatchNorm"}
%convert.70 = f32[1,56,56,64]{3,2,1,0} convert(f32[1,56,56,64]{3,2,1,0} %batch-norm-inference.69), metadata={op_type="FusedBatchNorm" op_name="mobilenet/depthwise_seperable_2/dw_batch_norm/FusedBatchNorm"}
%maximum.73 = f32[1,56,56,64]{3,2,1,0} maximum(f32[1,56,56,64]{3,2,1,0} %broadcast.72, f32[1,56,56,64]{3,2,1,0} %convert.70), metadata={op_type="Relu" op_name="mobilenet/depthwise_seperable_2/dw_batch_norm/Relu"}
%arg17.18 = f32[1,1,64,128]{3,2,1,0} parameter(17), parameter_replication={false}, metadata={op_name="XLA_Args"}
%convolution.74 = f32[1,56,56,128]{3,2,1,0} convolution(f32[1,56,56,64]{3,2,1,0} %maximum.73, f32[1,1,64,128]{3,2,1,0} %arg17.18), window={size=1x1}, dim_labels=b01f_01io->b01f, metadata={op_type="Conv2D" op_name="mobilenet/depthwise_seperable_2/pointwise/Conv2D"}
%arg16.17 = f32[128]{0} parameter(16), parameter_replication={false}, metadata={op_name="XLA_Args"}
%broadcast.75 = f32[1,56,56,128]{3,2,1,0} broadcast(f32[128]{0} %arg16.17), dimensions={3}, metadata={op_type="BiasAdd" op_name="mobilenet/depthwise_seperable_2/pointwise/BiasAdd"}
%add.76 = f32[1,56,56,128]{3,2,1,0} add(f32[1,56,56,128]{3,2,1,0} %convolution.74, f32[1,56,56,128]{3,2,1,0} %broadcast.75), metadata={op_type="BiasAdd" op_name="mobilenet/depthwise_seperable_2/pointwise/BiasAdd"}
%convert.77 = f32[1,56,56,128]{3,2,1,0} convert(f32[1,56,56,128]{3,2,1,0} %add.76), metadata={op_type="FusedBatchNorm" op_name="mobilenet/depthwise_seperable_2/pw_batch_norm/FusedBatchNorm"}
%constant.34 = f32[] constant(1), metadata={op_type="Const" op_name="mobilenet/depthwise_seperable_2/pw_batch_norm/Const"}
%broadcast.35 = f32[128]{0} broadcast(f32[] %constant.34), dimensions={}, metadata={op_type="Const" op_name="mobilenet/depthwise_seperable_2/pw_batch_norm/Const"}
%arg18.19 = f32[128]{0} parameter(18), parameter_replication={false}, metadata={op_name="XLA_Args"}
%arg19.20 = f32[128]{0} parameter(19), parameter_replication={false}, metadata={op_name="XLA_Args"}
%arg20.21 = f32[128]{0} parameter(20), parameter_replication={false}, metadata={op_name="XLA_Args"}
%batch-norm-inference.78 = f32[1,56,56,128]{3,2,1,0} batch-norm-inference(f32[1,56,56,128]{3,2,1,0} %convert.77, f32[128]{0} %broadcast.35, f32[128]{0} %arg18.19, f32[128]{0} %arg19.20, f32[128]{0} %arg20.21), epsilon=0.001, feature_index=3, metadata={op_type="FusedBatchNorm" op_name="mobilenet/depthwise_seperable_2/pw_batch_norm/FusedBatchNorm"}
%convert.79 = f32[1,56,56,128]{3,2,1,0} convert(f32[1,56,56,128]{3,2,1,0} %batch-norm-inference.78), metadata={op_type="FusedBatchNorm" op_name="mobilenet/depthwise_seperable_2/pw_batch_norm/FusedBatchNorm"}
%maximum.82 = f32[1,56,56,128]{3,2,1,0} maximum(f32[1,56,56,128]{3,2,1,0} %broadcast.81, f32[1,56,56,128]{3,2,1,0} %convert.79), metadata={op_type="Relu" op_name="mobilenet/depthwise_seperable_2/pw_batch_norm/Relu"}
%convert.83 = f32[1,56,56,128]{3,2,1,0} convert(f32[1,56,56,128]{3,2,1,0} %maximum.82), metadata={op_type="AvgPool" op_name="mobilenet/avg_pool2d/AvgPool"}
%constant.85 = f32[] constant(0), metadata={op_type="AvgPool" op_name="mobilenet/avg_pool2d/AvgPool"}
%pad.86 = f32[1,56,56,128]{3,2,1,0} pad(f32[1,56,56,128]{3,2,1,0} %convert.83, f32[] %constant.85), padding=0_0x0_0x0_0x0_0, metadata={op_type="AvgPool" op_name="mobilenet/avg_pool2d/AvgPool"}
%constant.84 = f32[] constant(0), metadata={op_type="AvgPool" op_name="mobilenet/avg_pool2d/AvgPool"}
%reduce-window.91 = f32[1,1,1,128]{3,2,1,0} reduce-window(f32[1,56,56,128]{3,2,1,0} %pad.86, f32[] %constant.84), window={size=1x56x56x1 stride=1x2x2x1}, to_apply=%add_F32.87, metadata={op_type="AvgPool" op_name="mobilenet/avg_pool2d/AvgPool"}
%constant.92 = f32[] constant(3136), metadata={op_type="AvgPool" op_name="mobilenet/avg_pool2d/AvgPool"}
%broadcast.93 = f32[1,1,1,128]{3,2,1,0} broadcast(f32[] %constant.92), dimensions={}, metadata={op_type="AvgPool" op_name="mobilenet/avg_pool2d/AvgPool"}
%divide.94 = f32[1,1,1,128]{3,2,1,0} divide(f32[1,1,1,128]{3,2,1,0} %reduce-window.91, f32[1,1,1,128]{3,2,1,0} %broadcast.93), metadata={op_type="AvgPool" op_name="mobilenet/avg_pool2d/AvgPool"}
%convert.95 = f32[1,1,1,128]{3,2,1,0} convert(f32[1,1,1,128]{3,2,1,0} %divide.94), metadata={op_type="AvgPool" op_name="mobilenet/avg_pool2d/AvgPool"}
%reshape.96 = f32[1,128]{1,0} reshape(f32[1,1,1,128]{3,2,1,0} %convert.95), metadata={op_type="Squeeze" op_name="mobilenet/squeeze"}
%arg22.23 = f32[128,1000]{1,0} parameter(22), parameter_replication={false}, metadata={op_name="XLA_Args"}
%dot.97 = f32[1,1000]{1,0} dot(f32[1,128]{1,0} %reshape.96, f32[128,1000]{1,0} %arg22.23), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_type="MatMul" op_name="mobilenet/fc/MatMul"}
%arg21.22 = f32[1000]{0} parameter(21), parameter_replication={false}, metadata={op_name="XLA_Args"}
%broadcast.98 = f32[1,1000]{1,0} broadcast(f32[1000]{0} %arg21.22), dimensions={1}, metadata={op_type="BiasAdd" op_name="mobilenet/fc/BiasAdd"}
%add.99 = f32[1,1000]{1,0} add(f32[1,1000]{1,0} %dot.97, f32[1,1000]{1,0} %broadcast.98), metadata={op_type="BiasAdd" op_name="mobilenet/fc/BiasAdd"}
%reshape.100 = f32[1,1000]{1,0} reshape(f32[1,1000]{1,0} %add.99), metadata={op_type="Reshape" op_name="mobilenet/softmax/Reshape"}
%constant.101 = f32[] constant(-inf), metadata={op_type="Softmax" op_name="mobilenet/softmax/Softmax"}
%reduce.106 = f32[1]{0} reduce(f32[1,1000]{1,0} %reshape.100, f32[] %constant.101), dimensions={1}, to_apply=%max_float_.102, metadata={op_type="Softmax" op_name="mobilenet/softmax/Softmax"}
%broadcast.107 = f32[1,1000]{1,0} broadcast(f32[1]{0} %reduce.106), dimensions={0}, metadata={op_type="Softmax" op_name="mobilenet/softmax/Softmax"}
%subtract.108 = f32[1,1000]{1,0} subtract(f32[1,1000]{1,0} %reshape.100, f32[1,1000]{1,0} %broadcast.107), metadata={op_type="Softmax" op_name="mobilenet/softmax/Softmax"}
%exponential.109 = f32[1,1000]{1,0} exponential(f32[1,1000]{1,0} %subtract.108), metadata={op_type="Softmax" op_name="mobilenet/softmax/Softmax"}
%convert.110 = f32[1,1000]{1,0} convert(f32[1,1000]{1,0} %exponential.109), metadata={op_type="Softmax" op_name="mobilenet/softmax/Softmax"}
%constant.111 = f32[] constant(0), metadata={op_type="Softmax" op_name="mobilenet/softmax/Softmax"}
%reduce.116 = f32[1]{0} reduce(f32[1,1000]{1,0} %convert.110, f32[] %constant.111), dimensions={1}, to_apply=%add_float_.112, metadata={op_type="Softmax" op_name="mobilenet/softmax/Softmax"}
%convert.117 = f32[1]{0} convert(f32[1]{0} %reduce.116), metadata={op_type="Softmax" op_name="mobilenet/softmax/Softmax"}
%broadcast.118 = f32[1,1000]{1,0} broadcast(f32[1]{0} %convert.117), dimensions={0}, metadata={op_type="Softmax" op_name="mobilenet/softmax/Softmax"}
%divide.119 = f32[1,1000]{1,0} divide(f32[1,1000]{1,0} %exponential.109, f32[1,1000]{1,0} %broadcast.118), metadata={op_type="Softmax" op_name="mobilenet/softmax/Softmax"}
%reshape.120 = f32[1,1000]{1,0} reshape(f32[1,1000]{1,0} %divide.119), metadata={op_name="XLA_Retvals"}
%tuple.121 = (f32[1,1000]{1,0}) tuple(f32[1,1000]{1,0} %reshape.120), metadata={op_name="XLA_Retvals"}
ROOT %get-tuple-element.122 = f32[1,1000]{1,0} get-tuple-element((f32[1,1000]{1,0}) %tuple.121), index=0, metadata={op_name="XLA_Retvals"}
}
```

The example is a bit long, but its structure is quite simple and follows the grammar above.
It is part of a MobileNet-v1 that I re-implemented myself; the original MobileNet-v1 is fairly long, so only two depthwise_seperable blocks were kept, which is enough to illustrate the point. XLA dumped 88 HLO text files in total; the one shown above is the first, i.e. the initial IR.
Globally, the passes the HLO goes through can be written in pseudocode as:

```
def simplification():
    batchnorm_expander()
    algsimp()
    simplify-sorts()
    tuple-simplifier()
    while-loop-constant-sinking()
    simplify-while-loops()
    slice-sinker()
    dce()
    reshape-mover()
    constant-folding()
    simplify-conditional()

def optimization():
    zero_sized_hlo_elimination()
    dynamic-index-splitter()
    gpu_hlo_support_checker()
    CallInliner()
    dot_decomposer()
    convolution-group-converter()
    stable-sort-expander()
    element_type_converter()
    for i in range(3):
        simplification()
    hlo-get-dimension-size-rewriter()
    zero_sized_hlo_elimination()
    transpose-folding()
    cse()
    dce()
    while-loop-trip-count-annotator()

def conv_canonicalization():
    cusolver-rewriter()
    cudnn-conv-rewriter()
    cudnn-fused-convolution()
    cudnn-conv-padding()
    constant-folding()

def layout_assignment():
    _layout-assignment()

def post-layout_assignment():
    algsimp()
    cudnn-conv-algorithm-picker()
    tuple-simplifier()
    cse()

def fusion():
    variadic-op-splitter()
    for i in range(2):
        _fusion()
    fusion_merger()
    multi_output_fusion()
    cse()
    dce()

def copy-insertion():
    adding_copies_to_resolve_interference()
    removing_unnecessary_copies()
    adding_special-case_copies()

def GPU-ir-emit-prepare():
    dce()
    flatten-call-graph()
    copy-insertion()
    sanitize-constant-names()

def main():
    optimization()
    conv_canonicalization()
    layout_assignment()
    post-layout_assignment()
    for i in range(3):
        fusion()
    reduce-precision()
    GPU-ir-emit-prepare()
```
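
The pipeline can be cross-checked against the dump files; one way to list them in order (a sketch assuming --xla_dump_to pointed at dump_dir; the exact filename layout can differ between versions):

```python
import os

dump_dir = "/tmp/xla_dump"  # placeholder: wherever --xla_dump_to pointed
# each snapshot's filename records the pass it belongs to, so sorted
# order roughly reproduces the pipeline above
for name in sorted(os.listdir(dump_dir)):
    print(name)
```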

Profiling with nvprof shows which kernels were used; the results look like the following (only part of the list is recorded):

```
maxwell_gcgemm_64x64_nt
void fft2d_r2c_32x32<float, bool=0, unsigned int=0, bool=0>(float2*, float const *, int, int, int, int, int, int, int, int, int, cudnn::reduced_divisor, bool, int2, int, int)
void gemv2N_kernel_val<float, float, float, int=128, int=4, int=4, int=4, int=1, cublasGemvParams<cublasGemvTensorStridedBatched<float const >, cublasGemvTensorStridedBatched<float>, float>>(float, float, float const )
void fermiPlusCgemmLDS128_batched<bool=1, bool=0, bool=0, bool=0, int=4, int=4, int=4, int=3, int=3, bool=1, bool=0>(float2* const *, float2* const *, float2* const *, float2*, float2 const *, float2 const *, int, int, int, int, int, int, __int64, __int64, __int64, float2 const *, float2 const *, float2, float2, int)
void DSE::regular_fft_pad<int=0, int=1, int=128, int=16, int=32, int=1, float, float, float2>(float2*, float*, int, int3, float*, int, float*, float*, int, int, int, int, int, bool)
void DSE::vector_fft<int=0, int=1, int=128, int=8, int=8, int=1, float, float, float2>(float2*, float2, int, int3, float2*)
void cudnn::detail::explicit_convolve_sgemm<float, int, int=1024, int=5, int=5, int=3, int=3, int=3, int=0, bool=1>(int, int, int, float const *, int, float const , int, cudnn::detail::explicit_convolve_sgemm<float, int, int=1024, int=5, int=5, int=3, int=3, int=3, int=0, bool=1>*, kernel_conv_params, int, int, float, float, int, float const *, float const *)
maxwell_scudnn_winograd_128x128_ldg1_ldg4_tile148t_nt
maxwell_scudnn_128x32_relu_interior_nn
void cudnn::winograd_nonfused::winogradForwardOutput4x4<float, float>(cudnn::winograd_nonfused::WinogradOutputParams<float, float>)
cudnn::maxwell::gemm::computeOffsetsKernel(cudnn::maxwell::gemm::ComputeOffsetsParams)
fusion_9
void Eigen::internal::EigenMetaKernel<Eigen::TensorEvaluator<Eigen::TensorAssignOp<Eigen::TensorMap<Eigen::Tensor<float, int=1, int=1, int>, int=16, Eigen::MakePointer>, Eigen::TensorCwiseUnaryOp<Eigen::internal::scalar_left<float, float, Eigen::internal::scalar_sum_op<float, float>>, Eigen::TensorMap<Eigen::Tensor<float const , int=1, int=1, int>, int=16, Eigen::MakePointer> const > const > const , Eigen::GpuDevice>, int>(float, int=1)
copy_7
```
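
For reference, a kernel trace like the one above can be collected with something along these lines (nvprof ships with CUDA; the flag shown prints each GPU kernel launch in order):

```bash
nvprof --print-gpu-trace python your_program.py
```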