我们从一个最简单的 Pass 开始入手,RemoveUnusedFunctions,带大家一步一步了解 Pass 优化的具体过程
这个 Pass 看名字顾名思义,就是要把 Module 里执行不到的函数删除掉
代码:src/relay/transforms/dead_code.cc
1. 无用函数的定义
无用函数是指那些根本不会调用到的函数
我们用 tvm.relay 手写一个最简单的网络,来跟踪优化的过程和效果
如下,有2个函数,f1 和 f2,其中 f1 就是无用的函数,因为 main 函数进来后,根本不会调用到这个函数
import tvm from tvm import relay from tvm.contrib import relay_viz a1 = relay.var("a1", shape=(1,), dtype="float32") a2 = relay.var("a2", shape=(1,), dtype="float32") add_op = relay.add(a1, a2) f1 = relay.Function([a1, a2], add_op) d1 = relay.var("d1", shape=(1, 32, 56, 56), dtype="float32") w1 = relay.var("w1", shape=(32, 32, 3, 3), dtype="float32") b1 = relay.var("b1", shape=(32,), dtype="float32") conv = relay.nn.conv2d(d1, w1, strides=(1, 1), padding=(1, 1)) bias = relay.nn.bias_add(conv, b1) relu = relay.nn.relu(bias) f2 = relay.Function([d1, w1, b1], relu) mod = tvm.IRModule({'add_func': f1, 'main': f2}) mod = relay.transform.InferType()(mod) mod = relay.transform.RemoveUnusedFunctions()(mod)
优化完之后,右边的图就被去掉了(图片的可视化,参考之前的 Relay 计算图)
2. RemoveUnusedFunctions 的实现分析
RemoveUnusedFunctions 这个 Pass 的实现,是放在 src/relay/backend/vm/removed_unused_funcs.cc 文件里的,和之前的 Pass 有一点点不太一样,之前的所有 Pass 都是放在 src/relay/transform 目录下的
看下这个 Pass 的函数主体:
/*! * \brief Remove functions that are not used. * * \param module The Relay module. * \param entry_funcs The set of functions that can be entry function. * * \return The module with dead functions removed. */ IRModule RemoveUnusedFunctions(const IRModule& module, Array entry_funcs) { std::unordered_set called_funcs{}; for (auto entry : entry_funcs) { auto funcs = CallTracer(module).Trace(entry); called_funcs.insert(funcs.cbegin(), funcs.cend()); } auto existing_functions = module->functions; for (auto f : existing_functions) { auto it = called_funcs.find(f.first->name_hint); if (it == called_funcs.end()) { module->Remove(f.first); } } return module; }
可以看到 RemoveUsedFunctions 的基本逻辑就是:
- 从 entry_funcs(默认就只有一个 main 函数)入口,遍历当前 module 的所有函数,返回被 entry_func 依赖的函数列表
- 遍历 module->functions,如果函数不在上面的依赖列表里,说明当前函数永远不会被执行到,就可以删除
上面的 add_op 函数就是这么被干掉了
2.1. 依赖分析
RemoveUsedFunctions 使用一个叫 CallTracer 的 class 来实现函数的依赖分析
CallTracer 有3个成员变量:
- module_ 表示当前正在分析的 module,RemoveUsedFunctions 是一个 module 级别的 Pass
- called_funcs_ 记录当前 module 所有被调用到的函数
- visiting_ 记录当前已经访问过的函数(避免递归调用等引发的死循环)
struct CallTracer : ExprVisitor { IRModule module_; std::unordered_set called_funcs_; std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> visiting_; }
CallTrancer 本质上是一个 ExprVisitor,ExprVisitor 是 tvm 用来解决图遍历最主要的数据结构,其背后用了大量的 c++ template,有点复杂,初次看的人如果看其实现原理估计看的有点晕,但是我们这里不关注 ExprVisitor 底层怎么实现,我们只需要关注 ExprVisitor 能用来做什么就可以了。
ExprVisitor 允许你通过扩展的方式,来自定义遍历的行为
比如这里,CallTrancer 重载了基类的 VisitExpr_(CallNode *) 以及 VisitExpr_(FunctionNode *),表示我们只关注这2种类型的节点,当遍历到这种类型的节点时,回调我的 VisitExpr_ 函数。在每个Node类型对应的 VisitExpr_ 实现里,继续递归遍历这个 Node 的下游依赖的 Node 节点,最终完成整个图所有节点的遍历
struct CallTracer : ExprVisitor { void VisitExpr_(const CallNode* call_node) final { // TODO(mbs): Cleanup shape functions. CallLoweredProps props = GetCallLoweredProps(call_node); if (props.lowered_func.defined() && props.attrs.metadata.count("prim_shape_fn_var")) { auto callee = Downcast(props.attrs.metadata["prim_shape_fn_var"]); // We are implicitly calling the shape function *in addition to* the callee. called_funcs_.insert(callee->name_hint); } ExprVisitor::VisitExpr_(call_node); } void VisitExpr_(const FunctionNode* func_node) final { auto func = GetRef(func_node); if (visiting_.find(func) == visiting_.end()) { visiting_.insert(func); for (auto param : func_node->params) { ExprVisitor::VisitExpr(param); } ExprVisitor::VisitExpr(func_node->body); } } }
通过上面的实现,我们知道 CallTracer 把从 main 函数开始,所有依赖的 FunctionNode 都遍历了一遍,并记录到 visiting_ 列表里,最后和 IRModule 全局的 functions() 求 diff,我们就知道哪些函数是没用的函数了
2.2. 函数删除
一般来说,删除计算图中的节点,是会导致计算图异常的,因为每个节点的输入输出都不一样
不过因为我们这里要删除的是无用的函数,也就是说,这个函数在 IRModule 里面,但并不在真正的计算图里
从最上面的实现分析,我们知道,对于无用的函数,会调用 IRModuleNode::Remove() 删除
module->Remove(f.first);
看下里面的实现
void IRModuleNode::Remove(const GlobalVar& var) { auto functions_node = this->functions.CopyOnWrite(); functions_node->erase(var); auto gvar_node = global_var_map_.CopyOnWrite(); gvar_node->erase(var->name_hint); }