深入浅出 tvm – (11) Relay Pass 之算子融合

算子融合,FuseOps,是编译领域最常见的一个优化技术,在这里也算属于 relay 里面最复杂的一类优化了,整个优化的核心逻辑 1k+ 行
代码:src/relay/transforms/fuse_ops.cc
算子融合的目的最终是要解决 AI 处理器的内存墙、并行墙的问题,提升 Tensor 数据的访存局部性。目前算子融合的技术路线有比较多,这里不涉及,我们只需要知道 tvm 是基于支配树来实现算子融合的就行了

1. 基本概念

1.1. 算子融合

算子融合,即将多个算子组合在一起放到同一个核中,通过算子融合的方式,不需要将中间结果保存到全局内存,进而减少执行所需要的时间。
tvm 中将算子分为7种类型:
  1. kElemWise:2个 tensor 之间按元素逐个操作的算子,实际上所有四则运算都是这种类型,https://deeplizard.com/learn/video/QscEWm0QTRY
  2. kBroadcast:见上述链接,到操作两个不同形状的 tensor 时
  3. kInjective:一对一映射函数,比如 add / sqrt / exp 等操作算子(operator)
  4. kCommReduce:多到少的映射,输入到输出具有降维性质,如:sum / max / min等操作操作算子(operator)
  5. kOutEWiseFusable:这是计算比较复杂的,如:conv2d / bn /  relu等操作算子(operator)
  6. kTuple:xx
  7. kOpaque:无法被融合的算符,比如 sort
根据以上对算符的不同类型,TVM提供了三种融合规则(我看论文是这么写的,但是现在不止3种了?):
0
从融合算子的内部视角看,这种融合实际上是数据计算pipeline化,即两次计算中间数据不再经历store-load过程,而是直接给到下一个计算单元完成计算。

1.2. 支配树

支配树是 tvm 算子融合的一个重要概念
参考资料:
  1. https://www.cnblogs.com/LuckyGlass-blog/p/14670451.html xx
  2. 《Engineering a Compiler》龙书第9章
支配树的定义:
给定一个有向图和一个起点 G,起点是 R,从 R 可以到达图上的所有点
当u是所有到达 v 的路径(从起点 R 开始)中的必经点时,我们称 u 支配 v。必经点可以理解为删除这个点之后,R 无法到达 v
显然:
  1. v 有多个支配点,u 是 v 的支配点之一
  2. 除起点 R 以外的所有点,都有2个平凡的支配点,一个是起点,一个是自身
在 v 的所有支配点中,离 v 最近的一个支配点(不能是 v 本身),称为 v 的最近支配点,记为 idom[v]
建立所有 idom[v] -> v 的边,就得到了一个完整的支配树
如下:蓝色的边构成的树就是这个图的支配树
0

2. 算子融合的具体实现

The general algorithm is as follows:
  • Construct a DAG of dataflow graph for dominator analysis
  • Construct a post-dominator tree which gives immediate post dominator of each node.
  • Run fusion algorithm with the given post-dominator information

2.2. 阶段1:建立 DAG 图

tvm 使用 IndexedForwardGraph 结构来保存 DAG 图,这个结构体里有几个关键的信息:
  1. Edge 定义,表示边
  2. Node 定义,表示节点
  3. 还有一个 std::vector<Node*> post_dfs_order; 保存了整个树的后序遍历
还有一个 IndexedForwardGraph::Creator 类,继承自 ExprVisitor,用来实现对计算图的遍历
这个类重载了 ExprVisitor 的所有节点类型的 VisitExpr_ 方法,对遍历到的节点只干2个事情:
  1. 后序遍历计算图(原理自行百度),保存到 post_dfs_order 里面,由于遍历是从计算图出口开始的,而且是后序遍历,所以 post_dfs_order 最后一个保存的就是计算图的出口节点
  2. 推断节点类型,OpPatternKind,并调用 Update() 函数把节点类型保存到 Node 结构里
  void VisitExpr_(const TupleNode* op) final {
    ICHECK(graph_.node_map.count(op));
    Node* tuple_node = graph_.node_map.at(op);
    tuple_node->pattern = kTuple;
    for (const Expr& field : op->fields) {
      if (field->checked_type().as()) {
        this->Update(field, tuple_node, kInjective);
      } else {
        this->Update(field, nullptr, kOpaque);
      }
    }
    ExprVisitor::VisitExpr_(op);
    this->AddNode(op);
  }
另外,CallNode 节点的类型推断是最复杂的,后序再研究下

2.3. 阶段2:建立后续支配树

接下来看后序支配树的构建。构建函数是PostDom。因为根节点(DAG图的出口)在post_dfs_order中最后,所以从根节点开始寻找每个节点出点的LCA,这个LCA就是后序支配点。
幸运的是,tvm 把 LCA 的求解算法封装到了 support::Arena 类里面,我们暂时不必深入去理解,只需要知道通过 GetNode() 即可获得从根节点到输入节点的最近支配点
DominatorTree DominatorTree::PostDom(support::Arena* arena, const IndexedForwardGraph& graph) {
  DominatorTree tree;
  tree.nodes.resize(graph.post_dfs_order.size(), nullptr);
  // reverse topo order
  for (size_t i = graph.post_dfs_order.size(); i != 0; --i) {
    size_t index = i - 1;
    tree.nodes[index] = tree.GetNode(arena, graph.post_dfs_order[index]);
 
  }
  return tree;
}

2.4. 阶段3:操作融合

完成支配树构建之后,就可以开始融合操作了。
整个融合分为3个阶段,每个阶段执行不同的融合操作,具体逻辑都在 RunFuse() 函数里面
std::vector<GraphPartitioner::Group*> GraphPartitioner::Partition(
    const IndexedForwardGraph& graph) {
  this->InitGroups(graph);
  if (opt_level_ == 0) return std::move(groups_);
  // get post dominator tree
  auto post_dom_tree = DominatorTree::PostDom(arena_, graph);
  // run fusion algorithm.
  for (int phase = 0; phase < 3; ++phase) { this->RunFuse(graph, post_dom_tree, phase);
  }
  return std::move(groups_);
}
这个函数干了3个事情:
  1. InitGroups:把计算图的所有 Node,转换成 Group 结构,构造一个和计算图一模一样的 Group 图
  2. 后序遍历得到 post_dom_tree,这个前面已讲
  3. 然后分别遍历dag,postDominator tree,以及group图中节点,来判断算符是否能被融合。注意这次遍历,是从计算图入口而不是出口开始遍历
融合规则是啥??conv2d + bn + relu 咋就融合了?

3. 一个融合示例

tvm 代码库的 tests 集里有很多 fuse op 相关的测试用例,用来调试学习最合适不过
tests/python/relay/test_pass_fuse_ops.py
写了一个融合的例子,但是没研究清楚,例子网上找的,供大家调试
import tvm
from tvm import relay
from tvm.contrib import relay_viz

# BN
def batch_norm(data, gamma=None, beta=None, moving_mean=None, moving_var=None, **kwargs):
    name = kwargs.get("name")
    kwargs.pop("name")
    if not gamma:
        gamma = relay.var(name + "_gamma")
    if not beta:
        beta = relay.var(name + "_beta")
    if not moving_mean:
        moving_mean = relay.var(name + "_moving_mean")
    if not moving_var:
        moving_var = relay.var(name + "_moving_var")
    return relay.nn.batch_norm(data, gamma=gamma, beta=beta, moving_mean=moving_mean, moving_var=moving_var, **kwargs)[0]

# conv2d
def conv2d(data, weight=None, **kwargs):
    name = kwargs.get("name")
    kwargs.pop("name")
    if not weight:
        weight = relay.var(name + "_weight")
    return relay.nn.conv2d(data, weight, **kwargs)

# conv2d+BN+ReLU
def simplenet(data, name, channels, kernel_size=(3, 3), strides=(1, 1), padding=(1, 1), epsilon=1e-5):
    conv = conv2d(
            data=data,
            channels=channels,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
            data_layout='NCHW',
            name=name+'_conv')
    bn = batch_norm(data=conv, epsilon=epsilon, name=name + '_bn')
    act = relay.nn.relu(data=bn)
    return act

data_shape = (1, 3, 224, 224)
kernel_shape = (32, 3, 3, 3)
dtype = "float32"
data = relay.var("data", shape=data_shape, dtype=dtype)
act = simplenet(data, "graph", 32, strides=(2, 2))
func = relay.Function(relay.analysis.free_vars(act), act)

mod = tvm.IRModule.from_expr(func)
mod = relay.transform.InferType()(mod)
mod = relay.transform.FuseOps()(mod)

print(mod)

viz = relay_viz.RelayVisualizer(
    mod,
    plotter=relay_viz.DotPlotter(),
    parser=relay_viz.DotVizParser())

viz.render('1')
融合前:
0
融合后:
0

4. 扩展阅读

4.1. 相关论文

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注