算子融合,FuseOps,是编译领域最常见的一个优化技术,在这里也算属于 relay 里面最复杂的一类优化了,整个优化的核心逻辑 1k+ 行
代码:src/relay/transforms/fuse_ops.cc
算子融合的目的最终是要解决 AI 处理器的内存墙、并行墙的问题,提升 Tensor 数据的访存局部性。目前算子融合的技术路线有比较多,这里不涉及,我们只需要知道 tvm 是基于支配树来实现算子融合的就行了
1. 基本概念
1.1. 算子融合
算子融合,即将多个算子组合在一起放到同一个核中,通过算子融合的方式,不需要将中间结果保存到全局内存,进而减少执行所需要的时间。
tvm 中将算子分为7种类型:
- kElemWise:2个 tensor 之间按元素逐个操作的算子,实际上所有四则运算都是这种类型,https://deeplizard.com/learn/video/QscEWm0QTRY
- kBroadcast:见上述链接,到操作两个不同形状的 tensor 时
- kInjective:一对一映射函数,比如 add / sqrt / exp 等操作算子(operator)
- kCommReduce:多到少的映射,输入到输出具有降维性质,如:sum / max / min等操作操作算子(operator)
- kOutEWiseFusable:这是计算比较复杂的,如:conv2d / bn / relu等操作算子(operator)
- kTuple:xx
- kOpaque:无法被融合的算符,比如 sort
根据以上对算符的不同类型,TVM提供了三种融合规则(我看论文是这么写的,但是现在不止3种了?):
从融合算子的内部视角看,这种融合实际上是数据计算pipeline化,即两次计算中间数据不再经历store-load过程,而是直接给到下一个计算单元完成计算。
1.2. 支配树
支配树是 tvm 算子融合的一个重要概念
参考资料:
- https://www.cnblogs.com/LuckyGlass-blog/p/14670451.html xx
- 《Engineering a Compiler》龙书第9章
支配树的定义:
给定一个有向图和一个起点 G,起点是 R,从 R 可以到达图上的所有点
当u是所有到达 v 的路径(从起点 R 开始)中的必经点时,我们称 u 支配 v。必经点可以理解为删除这个点之后,R 无法到达 v
显然:
- v 有多个支配点,u 是 v 的支配点之一
- 除起点 R 以外的所有点,都有2个平凡的支配点,一个是起点,一个是自身
在 v 的所有支配点中,离 v 最近的一个支配点(不能是 v 本身),称为 v 的最近支配点,记为 idom[v]
建立所有 idom[v] -> v 的边,就得到了一个完整的支配树
如下:蓝色的边构成的树就是这个图的支配树
2. 算子融合的具体实现
The general algorithm is as follows:
- Construct a DAG of dataflow graph for dominator analysis
- Construct a post-dominator tree which gives immediate post dominator of each node.
- Run fusion algorithm with the given post-dominator information
2.2. 阶段1:建立 DAG 图
tvm 使用 IndexedForwardGraph 结构来保存 DAG 图,这个结构体里有几个关键的信息:
- Edge 定义,表示边
- Node 定义,表示节点
- 还有一个 std::vector<Node*> post_dfs_order; 保存了整个树的后序遍历
还有一个 IndexedForwardGraph::Creator 类,继承自 ExprVisitor,用来实现对计算图的遍历
这个类重载了 ExprVisitor 的所有节点类型的 VisitExpr_ 方法,对遍历到的节点只干2个事情:
- 后序遍历计算图(原理自行百度),保存到 post_dfs_order 里面,由于遍历是从计算图出口开始的,而且是后序遍历,所以 post_dfs_order 最后一个保存的就是计算图的出口节点
- 推断节点类型,OpPatternKind,并调用 Update() 函数把节点类型保存到 Node 结构里
void VisitExpr_(const TupleNode* op) final { ICHECK(graph_.node_map.count(op)); Node* tuple_node = graph_.node_map.at(op); tuple_node->pattern = kTuple; for (const Expr& field : op->fields) { if (field->checked_type().as()) { this->Update(field, tuple_node, kInjective); } else { this->Update(field, nullptr, kOpaque); } } ExprVisitor::VisitExpr_(op); this->AddNode(op); }
另外,CallNode 节点的类型推断是最复杂的,后序再研究下
2.3. 阶段2:建立后续支配树
接下来看后序支配树的构建。构建函数是PostDom。因为根节点(DAG图的出口)在post_dfs_order中最后,所以从根节点开始寻找每个节点出点的LCA,这个LCA就是后序支配点。
幸运的是,tvm 把 LCA 的求解算法封装到了 support::Arena 类里面,我们暂时不必深入去理解,只需要知道通过 GetNode() 即可获得从根节点到输入节点的最近支配点
DominatorTree DominatorTree::PostDom(support::Arena* arena, const IndexedForwardGraph& graph) { DominatorTree tree; tree.nodes.resize(graph.post_dfs_order.size(), nullptr); // reverse topo order for (size_t i = graph.post_dfs_order.size(); i != 0; --i) { size_t index = i - 1; tree.nodes[index] = tree.GetNode(arena, graph.post_dfs_order[index]); } return tree; }
2.4. 阶段3:操作融合
完成支配树构建之后,就可以开始融合操作了。
整个融合分为3个阶段,每个阶段执行不同的融合操作,具体逻辑都在 RunFuse() 函数里面
std::vector<GraphPartitioner::Group*> GraphPartitioner::Partition( const IndexedForwardGraph& graph) { this->InitGroups(graph); if (opt_level_ == 0) return std::move(groups_); // get post dominator tree auto post_dom_tree = DominatorTree::PostDom(arena_, graph); // run fusion algorithm. for (int phase = 0; phase < 3; ++phase) { this->RunFuse(graph, post_dom_tree, phase); } return std::move(groups_); }
这个函数干了3个事情:
- InitGroups:把计算图的所有 Node,转换成 Group 结构,构造一个和计算图一模一样的 Group 图
- 后序遍历得到 post_dom_tree,这个前面已讲
- 然后分别遍历dag,postDominator tree,以及group图中节点,来判断算符是否能被融合。注意这次遍历,是从计算图入口而不是出口开始遍历
融合规则是啥??conv2d + bn + relu 咋就融合了?
3. 一个融合示例
tvm 代码库的 tests 集里有很多 fuse op 相关的测试用例,用来调试学习最合适不过
tests/python/relay/test_pass_fuse_ops.py
写了一个融合的例子,但是没研究清楚,例子网上找的,供大家调试
import tvm from tvm import relay from tvm.contrib import relay_viz # BN def batch_norm(data, gamma=None, beta=None, moving_mean=None, moving_var=None, **kwargs): name = kwargs.get("name") kwargs.pop("name") if not gamma: gamma = relay.var(name + "_gamma") if not beta: beta = relay.var(name + "_beta") if not moving_mean: moving_mean = relay.var(name + "_moving_mean") if not moving_var: moving_var = relay.var(name + "_moving_var") return relay.nn.batch_norm(data, gamma=gamma, beta=beta, moving_mean=moving_mean, moving_var=moving_var, **kwargs)[0] # conv2d def conv2d(data, weight=None, **kwargs): name = kwargs.get("name") kwargs.pop("name") if not weight: weight = relay.var(name + "_weight") return relay.nn.conv2d(data, weight, **kwargs) # conv2d+BN+ReLU def simplenet(data, name, channels, kernel_size=(3, 3), strides=(1, 1), padding=(1, 1), epsilon=1e-5): conv = conv2d( data=data, channels=channels, kernel_size=kernel_size, strides=strides, padding=padding, data_layout='NCHW', name=name+'_conv') bn = batch_norm(data=conv, epsilon=epsilon, name=name + '_bn') act = relay.nn.relu(data=bn) return act data_shape = (1, 3, 224, 224) kernel_shape = (32, 3, 3, 3) dtype = "float32" data = relay.var("data", shape=data_shape, dtype=dtype) act = simplenet(data, "graph", 32, strides=(2, 2)) func = relay.Function(relay.analysis.free_vars(act), act) mod = tvm.IRModule.from_expr(func) mod = relay.transform.InferType()(mod) mod = relay.transform.FuseOps()(mod) print(mod) viz = relay_viz.RelayVisualizer( mod, plotter=relay_viz.DotPlotter(), parser=relay_viz.DotVizParser()) viz.render('1')
融合前:
融合后: