深入浅出 tvm – (13) Relay Pass 之常量折叠

和算子融合一样,常量折叠是编译领域里最常见的一个优化,简单来说,就是把常量表达式前置计算,在编译阶段就计算好,然后以常量的形式翻译成底层机器码,以提高执行效率,减少计算量
实际上大部分的编译器,常量折叠一般包含2种优化技术:常量折叠和常量传播

1. 基本概念

1.1. 常量折叠

constant folding,常量折叠,编译器优化技术之一,通过对编译时常量或常量表达式进行计算来简化代码。以下面的代码为例:
i = 320 * 200 * 32;
上面的代码中,编译器通常会在编译过程中直接对表达式进行求值,计算出320 * 200 * 32的结果,而不会生成2个乘法指令。
还有一些更复杂(但不清楚tvm是否支持,后面验证下)。比如,在执行一些复杂表达式的计算时,我们可以将表达式内部一些常量运算合并,最终起到简化的效果,如下
优化前(左边)每个表达式的运算量是 8 flop,优化后(右边)的运算量是 2 flop,运算效率极大提升了
0
不过,实际上 tvm 并没有做的那么高级,tvm 的常亮折叠只用来处理一些比较简单的场景

1.2. 常量传播

constant propagation,常量传播,同样也是编译器最常见的优化技术之一,在编译的过程中,对常量依赖做一些基本的推演和提前计算,再使用常量折叠技术来简化代码。如下:
int x = 14;
int y = 7 - x / 2;
return y * (28 / x + 2);

//常量传播

int x = 14;
int y = 7 - 14 / 2;
return y * (28 / 14 + 2);

//常量折叠

int x = 14;
int y = 0;
return 0;

2. 从一个示例开始

我们写一个简单的,可折叠的示例
import tvm
from tvm import relay
from tvm.contrib import relay_viz
from tvm import nd

a1 = relay.var("a1", shape=(1,), dtype="float32")
c1 = relay.const(10, 'float32')
f1 = relay.add(c1, c1)
f2 = relay.multiply(f1, relay.const(2, "float32"))
f3 = relay.multiply(f2, a1)

mod = tvm.IRModule.from_expr(f3)
mod = relay.transform.InferType()(mod)

print(mod)
mod = relay.transform.FoldConstant()(mod)
print(mod)
上述代码执行的时候,启动 VLOG 调试:TVM_LOG_DEBUG=src/relay/transforms/fold_constant.cc=2
得到如下关键信息(我做了一些过滤):
[05:45:37] src/relay/transforms/fold_constant.cc:254: FoldConstant: FoldConstantExpr: ConstEvaluate: Evaluating: add(10f /* ty=float32 */, 10f /* ty=float32 */) /* ty=float32 */
[05:45:37] src/relay/transforms/fold_constant.cc:273: FoldConstant: FoldConstantExpr: ConstEvaluate: Evaluated to constant: 20f
[05:45:37] src/relay/transforms/fold_constant.cc:254: FoldConstant: FoldConstantExpr: ConstEvaluate: Evaluating: multiply(20f, 2f /* ty=float32 */) /* ty=float32 */
[05:45:37] src/relay/transforms/fold_constant.cc:273: FoldConstant: FoldConstantExpr: ConstEvaluate: Evaluated to constant: 40f
看起来做了2个操作:
  1. 第一步,先把 c1 + c1,在编译阶段,直接计算了,得到结果 c2 = 20
  2. 第二步,再直接计算 c3 = c2 * 2 = 40
优化前 IR
def @main(%a1: Tensor[(1), float32] /* ty=Tensor[(1), float32] */) -> Tensor[(1), float32] {
  %0 = add(10f /* ty=float32 */, 10f /* ty=float32 */) /* ty=float32 */;
  %1 = multiply(%0, 2f /* ty=float32 */) /* ty=float32 */;
  multiply(%1, %a1) /* ty=Tensor[(1), float32] */
}
优化后 IR
def @main(%a1: Tensor[(1), float32] /* ty=Tensor[(1), float32] */) -> Tensor[(1), float32] {
  multiply(40f /* ty=float32 */, %a1) /* ty=Tensor[(1), float32] */
}
常量折叠优化后,指令数明显减少
xx

3. 实现分析

tvm Relay 支持3种类型节点的常量折叠:LetNode、TupleItemGetNode 和 CallNode、IfNode

3.1. MixedModeMutator 和 ExprMutator

常量折叠需要修改计算图,因此涉及到2个很重要的类
1)ExprMutator
ExprMutator 和 ExprVisitor 的唯一区别就是,ExprMutator 增加了一个 memo_ 成员,用来记录遍历到的节点
每个节点的遍历结果,要么是原节点(不需要修改),要么是一个新的节点(需要修改),如果是一个新的节点,后序所有依赖这个节点的都会被替换成新的实现
所以 ExprMutator 的 VisitExpr 会返回一个新的 Expr,而 ExprVisitor 返回 void
Expr ExprMutator::VisitExpr(const Expr& expr) {
  auto it = this->memo_.find(expr);
  if (it != this->memo_.end()) {
    return it->second;
  } else {
    Expr new_expr = ExprFunctor::VisitExpr(expr);
    memo_[expr] = new_expr;
    return new_expr;
  }
}
2)MixedModeMutator
这个类说实话没看懂具体有什么用
从类的定义来看,继承了 ExprMutator,然后把 TupleNode、CallNode、TupleGetItemNode 这3类节点,通过调用 Rewrite() 来实现了
  Expr VisitExpr_(const TupleNode* op) final { return Rewrite(op); };
  Expr VisitExpr_(const CallNode* call_node) final { return Rewrite(call_node); };
  Expr VisitExpr_(const TupleGetItemNode* op) final { return Rewrite(op); };

3.2. LetNode

tvm 支持 Let 表达式,不过我还没理解
todo(后序补充)

3.3. CallNode

CallNode 节点的折叠实现在 Expr Rewrite_(const CallNode* pre_call_node, const Expr& post) final 这个函数里面
  Expr Rewrite_(const CallNode* pre_call_node, const Expr& post) final {
    Call pre_call = GetRef(pre_call_node);
    if (inside_primitive_) {
      return std::move(pre_call);
    }

    Call post_call = Downcast(post);

    if (post_call->args.empty()) {
      // We don't constant fold function with zero arguments.
      // This is a heuristic that is useful.
      // For example it is harmful to fold ones(shape=(4, 5)).
      return std::move(pre_call);
    }

    const auto* op_node = post_call->op.as();
    if (op_node == nullptr) {
      // Only evaluate primitives.
      return std::move(post_call);
    }
    Op op = GetRef(op_node);
    static auto op_stateful = Op::GetAttrMap("TOpIsStateful");
    if (op_stateful.get(op, false)) {
      // skip stateful ops.
      return std::move(post_call);
    }
    // ...
    if (!std::all_of(post_call->args.begin(), post_call->args.end(), IsComplexConstant)) {
      // At least one non-constant argument.
      return std::move(post_call);
    }
    // During evaluation we have obviously lost all on_device annotations. However any
    // on_device wrapping this call will be left in place.
    return ConstEvaluate(post_call);
  }
首先由一些情况,是不会执行折叠操作的,比较多:
  1. EvaluateShapeOf
  2. TNonComputational 相关,细节待确认
  3. FTVMQnnCanonicalize 相关,细节待确认
  4. 有状态节点
  5. 参数为空
除此之外,CallNode 的常量折叠逻辑非常简单:判断当前CallNode的所有参数是否为常量,如果是,直接原地执行函数计算结果,计算过程在 ConstEvaluate() 函数里面,不展开了

3.4. TupleItemGetNode

这个折叠是最简单,直接返回Tuple对应的元素
  Expr Rewrite_(const TupleGetItemNode* tuple_get_item_node,
                const Expr& post_tuple_get_item) final {
    const auto* post_tuple_get_item_node = post_tuple_get_item.as();
    if (const auto* tuple_node = AsIgnoringOnDevice(post_tuple_get_item_node->tuple)) {
      Expr result = tuple_node->fields[tuple_get_item_node->index];
      OnDeviceProps props = GetOnDeviceProps(post_tuple_get_item_node->tuple);
      if (props.body.defined()) {
        // (on_device((x, y, z), virtual_device=D).1 ==> on_device(y, virtual_device=D)
        return MaybeOnDeviceWithProps(result, props);
      } else {
        return result;
      }
    }
    return post_tuple_get_item;
  }

3.5. IfNode

如果 If 表达式本身就是一个常量,那直接返回 true branch,否则返回 false branch,对应的表达式
  Expr VisitExpr_(const IfNode* if_node) final {
    If new_if = Downcast(ExprMutator::VisitExpr_(if_node));
    if (const auto* const_node = AsIgnoringOnDevice(new_if->cond)) {
      if (reinterpret_cast<uint8_t*>(const_node->data->data)[0]) {
        return new_if->true_branch;
      } else {
        return new_if->false_branch;
      }
    }
    return std::move(new_if);
  }
发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注