深入浅出 tvm – (9) Relay Pass 优化

Pass 又称为 transform,是编译领域最常见的一种优化技术,tvm 里的 pass 设计,主要是参考了 llvm 的设计思想
pass 本质上是一种图到图的转换,pass不改变图计算的结果
我们知道 tvm 计算图有两种 level 的表示 ir,一种是 relay ir,一种是 tir,不同的 ir 有不同的 pass 优化逻辑,因此 tvm 中的 pass 也有两种:
  1. relay 层的 pass。代码在 relay/transform,包括很多图结构的优化,比如算符融合,常量折叠,等等,属于偏前端的优化
  2. tir 层的 pass,偏向底层的优化,比如 prefetch 注入,unrollLoop 等
这里我们只讲 relay 相关的 pass,tir 的后续再讲

1. 基本过程

前面我们将 relay.build 的时候就知道 build 过程中的 OptimizeImpl 阶段就是在执行 Pass 优化,我们来看下 OptimizeImpl 函数的实现
src/relay/backend/build_module.cc
这个函数很长,但是非常简单,我省略了一部分代码
pass_seqs 就是一个数组,函数根据 config,选择相应的 pass,然后放到这个数组里面,最后执行 pass 优化的时候,直接一个 for 循环执行 pass_func() 就完了,由于每个 pass_func 的输入和输出都是 IRModule,串行执行即可
  IRModule OptimizeImpl(IRModule relay_module) {
    ICHECK(relay_module.defined()) << "The IRModule must be defined for the Relay compiler.";

    backend::BindParamsInModule(relay_module, params_);

    Array pass_seqs =
        GetPassPrefix(/*is_homogenous=*/config_->primitive_targets.size() == 1, /*is_vm=*/false);
    transform::PassContext pass_ctx = PassContext::Current();
    // ...    

    relay_module = transform::InferType()(relay_module);

    // Inline the functions that have been lifted by the module scope.
    //
    // TODO(@zhiics) Note that we need to be careful about the subgraphs with
    // global function calls. We should make sure that these callees are also
    // inline functions. However, this should be very unlikely for accelerators
    // and vendor-provided libraries. So we don't handle for now.
    relay_module = transform::Inline()(relay_module);
    relay_module = transform::InferType()(relay_module);
    relay_module = transform::LabelOps()(relay_module);
    relay_module = transform::AnnotateMemoryScope(config_)(relay_module);

    ICHECK(relay_module.defined());

    return relay_module;
  }
从这个函数看,relay层有3类pass:
  1. GetPassPrefix() 函数会首先返回一系列,最常见的、公共的pass,大概20来个左右,可以细看函数
  2. homogeneous 相关的 pass
  3. auto schedule 相关的 pass
  4. meta schedule 相关的 pass

2. 常见的 Pass 列表

grep -r ^Pass src/relay/transforms/ | grep “{“| sed ‘s/(/ /g’ | awk ‘{print $2}’ | sort
基本可得到所有的 Pass 函数声明,如下我只写了重点的部分
  1. AlterOpLayout:替换操作符的布局或用其他表达式替换基本操作符
  2. CanonicalizeOps:将特殊算子规范化为基本算子
  3. CombineParallelConv2D:合并 conv2d 操作
  4. CombineParallelDense:合并 dense 操作
  5. ConvertLayout:布局转换
  6. DeadCodeElimination:删除没用的代码
  7. DefuseOps:FuseOps的逆操作
  8. EliminateCommonSubexpr:删除常见的子表达式
  9. FastMath:将昂贵的非线性函数转换为快速但近似的对应函数。
  10. FoldConstant:常量表达式折叠
  11. FuseOps:将 expr 中的操作符融合为更大的操作符
  12. InferType:类型推断
  13. Inline:执行内联操作
  14. LazyGradientInit:减少梯度张量的内存使用
  15. MergeCompilerRegions:合并编译区域
  16. SimplifyExpr:简化表达式,比如合并连续的 reshapes 操作
  17. SimplifyInference:简化推理阶段的 data-flow
  18. SplitArgs:将具有大量参数的函数切割成更小的块
  19. ToBasicBlockNormalForm:将表达式转换为基本块的形式
  20. ToMixedPrecision:自动混合精度重写

3. 数据结构

由于每个 Pass 都支持 operator(),其输入输出都是 IRModule,整个优化过程没有什么复杂的用法,因此 Pass 的架构还是比较简单的
Pass 的核心数据结构有3个:
  1. Pass & PassNode
  2. PassInfo & PassInfoNode
  3. PassContext & PassContextNode
tvm 的数据结构设计,看着好像就是每个 object 都有一个相应的 Node 数据结构

3.1. PassContext

PassContext 结构保存了 Pass 优化的上下文相关的信息,并且这个数据结构是每个线程一个,通过 PassContext::Current() 可以获得一个 thread local 的当前生效的 PassContext
定义如下:
class PassContextNode : public Object {
 public:
  /*! \brief The default optimization level. */
  int opt_level{2};

  /*! \brief The list of required passes. */
  Array required_pass;
  /*! \brief The list of disabled passes. */
  Array disabled_pass;
  /*! \brief The diagnostic context. */
  mutable Optional diag_ctx;
  /*! \brief Pass specific configurations. */
  Map<String, ObjectRef> config;

  /*! \brief A list of pass instrument implementations. */
  Array instruments;
  // ...
}
其中最关键的就是 opt_level 了,表示优化的级别,如果为 0 表示执行最简优化,默认是2,我一般用3
required_pass 表示必须打开的 pass,用户显式指定
disabled_pass 表示必须关闭的 pass,用户显式指定

3.2. Pass

PassNode 是所有 Pass 实现的基类,其定义了2个必须要实现的接口,就是 operator(),这个是 c++ 中的操作符重载的概念,不了解的同学可以先了解下
class PassNode : public Object {
 public:
  virtual ~PassNode() {}

  IRModule operator()(IRModule mod) const {
    return this->operator()(std::move(mod), PassContext::Current());
  }

  virtual IRModule operator()(IRModule mod, const PassContext& pass_ctx) const = 0;
  // ...
}
另外,通常我们在扩展 Pass 的时候并不会直接继承 PassNode,tvm 提供了3种类型的 Pass
首先看 Module-Level, 这个与原始 Pass 最接近,操作的就是 IRModule
class ModulePassNode : PassNode {
  PassInfo pass_info;
  runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func;
  Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
  // Other members/methods are omitted
};
然后看 Function-Level, 这个遍历 Module 中的Function进行处理
class FunctionPassNode : PassNode {
  PassInfo pass_info;
  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func;
  Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
  bool SkipFunction(const Function& func) const;
  // Other members/methods are omitted...
};
最后是 Sequential, 这个类似 pytorch 里面的 nn.Sequential, 包含了一堆可执行的Pass按照顺序执行
class SequentialPassNode : PassNode {
  PassInfo pass_info;
  // Passes need to be executed.
  Array passes;
  bool PassEnabled(const PassInfo& info) const;
  Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
};
src/relay/transform 的所有 Pass 都是继承自上面3种类型中的一个

3.3. PassInfo

PassInfoNode 保存了 Pass 的优化级别,名称,以及当前 Pass 的前置依赖,其结构如下:
class PassInfoNode : public Object {
 public:
  /*! \brief The minimal optimization level that this pass will be enabled. */
  int opt_level;

  /*! \brief The name of an optimization/analysis pass. */
  String name;

  /*! \brief The passes that are required to perform the current pass. */
  Array required;


  // ...   
}
Pass 的前置依赖主要是用来决定多个 Pass 的执行顺序
发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注