Pass 又称为 transform,是编译领域最常见的一种优化技术,tvm 里的 pass 设计,主要是参考了 llvm 的设计思想
pass 本质上是一种图到图的转换,pass不改变图计算的结果
我们知道 tvm 计算图有两种 level 的表示 ir,一种是 relay ir,一种是 tir,不同的 ir 有不同的 pass 优化逻辑,因此 tvm 中的 pass 也有两种:
- relay 层的 pass。代码在 relay/transform,包括很多图结构的优化,比如算符融合,常量折叠,等等,属于偏前端的优化
- tir 层的 pass,偏向底层的优化,比如 prefetch 注入,unrollLoop 等
这里我们只讲 relay 相关的 pass,tir 的后续再讲
1. 基本过程
前面我们将 relay.build 的时候就知道 build 过程中的 OptimizeImpl 阶段就是在执行 Pass 优化,我们来看下 OptimizeImpl 函数的实现
src/relay/backend/build_module.cc
这个函数很长,但是非常简单,我省略了一部分代码
pass_seqs 就是一个数组,函数根据 config,选择相应的 pass,然后放到这个数组里面,最后执行 pass 优化的时候,直接一个 for 循环执行 pass_func() 就完了,由于每个 pass_func 的输入和输出都是 IRModule,串行执行即可
IRModule OptimizeImpl(IRModule relay_module) { ICHECK(relay_module.defined()) << "The IRModule must be defined for the Relay compiler."; backend::BindParamsInModule(relay_module, params_); Array pass_seqs = GetPassPrefix(/*is_homogenous=*/config_->primitive_targets.size() == 1, /*is_vm=*/false); transform::PassContext pass_ctx = PassContext::Current(); // ... relay_module = transform::InferType()(relay_module); // Inline the functions that have been lifted by the module scope. // // TODO(@zhiics) Note that we need to be careful about the subgraphs with // global function calls. We should make sure that these callees are also // inline functions. However, this should be very unlikely for accelerators // and vendor-provided libraries. So we don't handle for now. relay_module = transform::Inline()(relay_module); relay_module = transform::InferType()(relay_module); relay_module = transform::LabelOps()(relay_module); relay_module = transform::AnnotateMemoryScope(config_)(relay_module); ICHECK(relay_module.defined()); return relay_module; }
从这个函数看,relay层有3类pass:
- GetPassPrefix() 函数会首先返回一系列,最常见的、公共的pass,大概20来个左右,可以细看函数
- homogeneous 相关的 pass
- auto schedule 相关的 pass
- meta schedule 相关的 pass
2. 常见的 Pass 列表
grep -r ^Pass src/relay/transforms/ | grep “{“| sed ‘s/(/ /g’ | awk ‘{print $2}’ | sort
基本可得到所有的 Pass 函数声明,如下我只写了重点的部分
- AlterOpLayout:替换操作符的布局或用其他表达式替换基本操作符
- CanonicalizeOps:将特殊算子规范化为基本算子
- CombineParallelConv2D:合并 conv2d 操作
- CombineParallelDense:合并 dense 操作
- ConvertLayout:布局转换
- DeadCodeElimination:删除没用的代码
- DefuseOps:FuseOps的逆操作
- EliminateCommonSubexpr:删除常见的子表达式
- FastMath:将昂贵的非线性函数转换为快速但近似的对应函数。
- FoldConstant:常量表达式折叠
- FuseOps:将 expr 中的操作符融合为更大的操作符
- InferType:类型推断
- Inline:执行内联操作
- LazyGradientInit:减少梯度张量的内存使用
- MergeCompilerRegions:合并编译区域
- SimplifyExpr:简化表达式,比如合并连续的 reshapes 操作
- SimplifyInference:简化推理阶段的 data-flow
- SplitArgs:将具有大量参数的函数切割成更小的块
- ToBasicBlockNormalForm:将表达式转换为基本块的形式
- ToMixedPrecision:自动混合精度重写