深入浅出 tvm – (10) Relay Pass 之剔除无用函数

我们从一个最简单的 Pass 开始入手,RemoveUnusedFunctions,带大家一步一步了解 Pass 优化的具体过程
这个 Pass 看名字顾名思义,就是要把 Module 里执行不到的函数删除掉
代码:src/relay/transforms/dead_code.cc

1. 无用函数的定义

无用函数是指那些根本不会调用到的函数
我们用 tvm.relay 手写一个最简单的网络,来跟踪优化的过程和效果
如下,有2个函数,f1 和 f2,其中 f1 就是无用的函数,因为 main 函数进来后,根本不会调用到这个函数
import tvm
from tvm import relay
from tvm.contrib import relay_viz
 
a1 = relay.var("a1", shape=(1,), dtype="float32")
a2 = relay.var("a2", shape=(1,), dtype="float32")
add_op = relay.add(a1, a2)
f1 = relay.Function([a1, a2], add_op)

d1 = relay.var("d1", shape=(1, 32, 56, 56), dtype="float32")
w1 = relay.var("w1", shape=(32, 32, 3, 3), dtype="float32")
b1 = relay.var("b1", shape=(32,), dtype="float32")
conv = relay.nn.conv2d(d1, w1, strides=(1, 1), padding=(1, 1))
bias = relay.nn.bias_add(conv, b1)
relu = relay.nn.relu(bias)
f2 = relay.Function([d1, w1, b1], relu)

mod = tvm.IRModule({'add_func': f1, 'main': f2})
mod = relay.transform.InferType()(mod)

mod = relay.transform.RemoveUnusedFunctions()(mod)
优化完之后,右边的图就被去掉了(图片的可视化,参考之前的 Relay 计算图)
0

2. RemoveUnusedFunctions 的实现分析

RemoveUnusedFunctions 这个 Pass 的实现,是放在 src/relay/backend/vm/removed_unused_funcs.cc 文件里的,和之前的 Pass 有一点点不太一样,之前的所有 Pass 都是放在 src/relay/transform 目录下的
看下这个 Pass 的函数主体:
/*!
 * \brief Remove functions that are not used.
 *
 * \param module The Relay module.
 * \param entry_funcs The set of functions that can be entry function.
 *
 * \return The module with dead functions removed.
 */
IRModule RemoveUnusedFunctions(const IRModule& module, Array entry_funcs) {
  std::unordered_set called_funcs{};
  for (auto entry : entry_funcs) {
    auto funcs = CallTracer(module).Trace(entry);
    called_funcs.insert(funcs.cbegin(), funcs.cend());
  }
  auto existing_functions = module->functions;
  for (auto f : existing_functions) {
    auto it = called_funcs.find(f.first->name_hint);
    if (it == called_funcs.end()) {
      module->Remove(f.first);
    }
  }
  return module;
}
可以看到 RemoveUsedFunctions 的基本逻辑就是:
  1. 从 entry_funcs(默认就只有一个 main 函数)入口,遍历当前 module 的所有函数,返回被 entry_func 依赖的函数列表
  2. 遍历 module->functions,如果函数不在上面的依赖列表里,说明当前函数永远不会被执行到,就可以删除
上面的 add_op 函数就是这么被干掉了

2.1. 依赖分析

RemoveUsedFunctions 使用一个叫 CallTracer 的 class 来实现函数的依赖分析
CallTracer 有3个成员变量:
  1. module_ 表示当前正在分析的 module,RemoveUsedFunctions 是一个 module 级别的 Pass
  2. called_funcs_ 记录当前 module 所有被调用到的函数
  3. visiting_ 记录当前已经访问过的函数(避免递归调用等引发的死循环)
struct CallTracer : ExprVisitor {
  IRModule module_;
  std::unordered_set called_funcs_;

  std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> visiting_;
}
CallTrancer 本质上是一个 ExprVisitor,ExprVisitor 是 tvm 用来解决图遍历最主要的数据结构,其背后用了大量的 c++ template,有点复杂,初次看的人如果看其实现原理估计看的有点晕,但是我们这里不关注 ExprVisitor 底层怎么实现,我们只需要关注 ExprVisitor 能用来做什么就可以了。
ExprVisitor 允许你通过扩展的方式,来自定义遍历的行为
比如这里,CallTrancer 重载了基类的 VisitExpr_(CallNode *) 以及 VisitExpr_(FunctionNode *),表示我们只关注这2种类型的节点,当遍历到这种类型的节点时,回调我的 VisitExpr_ 函数。在每个Node类型对应的 VisitExpr_ 实现里,继续递归遍历这个 Node 的下游依赖的 Node 节点,最终完成整个图所有节点的遍历
struct CallTracer : ExprVisitor {
  void VisitExpr_(const CallNode* call_node) final {
    // TODO(mbs): Cleanup shape functions.
    CallLoweredProps props = GetCallLoweredProps(call_node);
    if (props.lowered_func.defined() && props.attrs.metadata.count("prim_shape_fn_var")) {
      auto callee = Downcast(props.attrs.metadata["prim_shape_fn_var"]);
      // We are implicitly calling the shape function *in addition to* the callee.
      called_funcs_.insert(callee->name_hint);
    }
    ExprVisitor::VisitExpr_(call_node);
  }

  void VisitExpr_(const FunctionNode* func_node) final {
    auto func = GetRef(func_node);
    if (visiting_.find(func) == visiting_.end()) {
      visiting_.insert(func);
      for (auto param : func_node->params) {
        ExprVisitor::VisitExpr(param);
      }
      ExprVisitor::VisitExpr(func_node->body);
    }
  }
}
通过上面的实现,我们知道 CallTracer 把从 main 函数开始,所有依赖的 FunctionNode 都遍历了一遍,并记录到 visiting_ 列表里,最后和 IRModule 全局的 functions() 求 diff,我们就知道哪些函数是没用的函数了

2.2. 函数删除

一般来说,删除计算图中的节点,是会导致计算图异常的,因为每个节点的输入输出都不一样
不过因为我们这里要删除的是无用的函数,也就是说,这个函数在 IRModule 里面,但并不在真正的计算图里
从最上面的实现分析,我们知道,对于无用的函数,会调用 IRModuleNode::Remove() 删除
module->Remove(f.first);
看下里面的实现
void IRModuleNode::Remove(const GlobalVar& var) {
  auto functions_node = this->functions.CopyOnWrite();
  functions_node->erase(var);
  auto gvar_node = global_var_map_.CopyOnWrite();
  gvar_node->erase(var->name_hint);
}
发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注