和算子融合一样,常量折叠是编译领域里最常见的一个优化,简单来说,就是把常量表达式前置计算,在编译阶段就计算好,然后以常量的形式翻译成底层机器码,以提高执行效率,减少计算量
实际上大部分的编译器,常量折叠一般包含2种优化技术:常量折叠和常量传播
1. 基本概念
1.1. 常量折叠
constant folding,常量折叠,编译器优化技术之一,通过对编译时常量或常量表达式进行计算来简化代码。以下面的代码为例:
i = 320 * 200 * 32;
上面的代码中,编译器通常会在编译过程中直接对表达式进行求值,计算出320 * 200 * 32的结果,而不会生成2个乘法指令。
还有一些更复杂(但不清楚tvm是否支持,后面验证下)。比如,在执行一些复杂表达式的计算时,我们可以将表达式内部一些常量运算合并,最终起到简化的效果,如下
优化前(左边)每个表达式的运算量是 8 flop,优化后(右边)的运算量是 2 flop,运算效率极大提升了
不过,实际上 tvm 并没有做的那么高级,tvm 的常亮折叠只用来处理一些比较简单的场景
1.2. 常量传播
constant propagation,常量传播,同样也是编译器最常见的优化技术之一,在编译的过程中,对常量依赖做一些基本的推演和提前计算,再使用常量折叠技术来简化代码。如下:
int x = 14; int y = 7 - x / 2; return y * (28 / x + 2); //常量传播 int x = 14; int y = 7 - 14 / 2; return y * (28 / 14 + 2); //常量折叠 int x = 14; int y = 0; return 0;
2. 从一个示例开始
我们写一个简单的,可折叠的示例
import tvm from tvm import relay from tvm.contrib import relay_viz from tvm import nd a1 = relay.var("a1", shape=(1,), dtype="float32") c1 = relay.const(10, 'float32') f1 = relay.add(c1, c1) f2 = relay.multiply(f1, relay.const(2, "float32")) f3 = relay.multiply(f2, a1) mod = tvm.IRModule.from_expr(f3) mod = relay.transform.InferType()(mod) print(mod) mod = relay.transform.FoldConstant()(mod) print(mod)
上述代码执行的时候,启动 VLOG 调试:TVM_LOG_DEBUG=src/relay/transforms/fold_constant.cc=2
得到如下关键信息(我做了一些过滤):
[05:45:37] src/relay/transforms/fold_constant.cc:254: FoldConstant: FoldConstantExpr: ConstEvaluate: Evaluating: add(10f /* ty=float32 */, 10f /* ty=float32 */) /* ty=float32 */ [05:45:37] src/relay/transforms/fold_constant.cc:273: FoldConstant: FoldConstantExpr: ConstEvaluate: Evaluated to constant: 20f [05:45:37] src/relay/transforms/fold_constant.cc:254: FoldConstant: FoldConstantExpr: ConstEvaluate: Evaluating: multiply(20f, 2f /* ty=float32 */) /* ty=float32 */ [05:45:37] src/relay/transforms/fold_constant.cc:273: FoldConstant: FoldConstantExpr: ConstEvaluate: Evaluated to constant: 40f
看起来做了2个操作:
- 第一步,先把 c1 + c1,在编译阶段,直接计算了,得到结果 c2 = 20
- 第二步,再直接计算 c3 = c2 * 2 = 40
优化前 IR
def @main(%a1: Tensor[(1), float32] /* ty=Tensor[(1), float32] */) -> Tensor[(1), float32] { %0 = add(10f /* ty=float32 */, 10f /* ty=float32 */) /* ty=float32 */; %1 = multiply(%0, 2f /* ty=float32 */) /* ty=float32 */; multiply(%1, %a1) /* ty=Tensor[(1), float32] */ }
优化后 IR
def @main(%a1: Tensor[(1), float32] /* ty=Tensor[(1), float32] */) -> Tensor[(1), float32] { multiply(40f /* ty=float32 */, %a1) /* ty=Tensor[(1), float32] */ }
常量折叠优化后,指令数明显减少
xx
3. 实现分析
tvm Relay 支持3种类型节点的常量折叠:LetNode、TupleItemGetNode 和 CallNode、IfNode
3.1. MixedModeMutator 和 ExprMutator
常量折叠需要修改计算图,因此涉及到2个很重要的类
1)ExprMutator
ExprMutator 和 ExprVisitor 的唯一区别就是,ExprMutator 增加了一个 memo_ 成员,用来记录遍历到的节点
每个节点的遍历结果,要么是原节点(不需要修改),要么是一个新的节点(需要修改),如果是一个新的节点,后序所有依赖这个节点的都会被替换成新的实现
所以 ExprMutator 的 VisitExpr 会返回一个新的 Expr,而 ExprVisitor 返回 void
Expr ExprMutator::VisitExpr(const Expr& expr) { auto it = this->memo_.find(expr); if (it != this->memo_.end()) { return it->second; } else { Expr new_expr = ExprFunctor::VisitExpr(expr); memo_[expr] = new_expr; return new_expr; } }
2)MixedModeMutator
这个类说实话没看懂具体有什么用
从类的定义来看,继承了 ExprMutator,然后把 TupleNode、CallNode、TupleGetItemNode 这3类节点,通过调用 Rewrite() 来实现了
Expr VisitExpr_(const TupleNode* op) final { return Rewrite(op); }; Expr VisitExpr_(const CallNode* call_node) final { return Rewrite(call_node); }; Expr VisitExpr_(const TupleGetItemNode* op) final { return Rewrite(op); };
3.2. LetNode
tvm 支持 Let 表达式,不过我还没理解
todo(后序补充)
3.3. CallNode
CallNode 节点的折叠实现在 Expr Rewrite_(const CallNode* pre_call_node, const Expr& post) final 这个函数里面
Expr Rewrite_(const CallNode* pre_call_node, const Expr& post) final { Call pre_call = GetRef(pre_call_node); if (inside_primitive_) { return std::move(pre_call); } Call post_call = Downcast(post); if (post_call->args.empty()) { // We don't constant fold function with zero arguments. // This is a heuristic that is useful. // For example it is harmful to fold ones(shape=(4, 5)). return std::move(pre_call); } const auto* op_node = post_call->op.as(); if (op_node == nullptr) { // Only evaluate primitives. return std::move(post_call); } Op op = GetRef(op_node); static auto op_stateful = Op::GetAttrMap("TOpIsStateful"); if (op_stateful.get(op, false)) { // skip stateful ops. return std::move(post_call); } // ... if (!std::all_of(post_call->args.begin(), post_call->args.end(), IsComplexConstant)) { // At least one non-constant argument. return std::move(post_call); } // During evaluation we have obviously lost all on_device annotations. However any // on_device wrapping this call will be left in place. return ConstEvaluate(post_call); }
首先由一些情况,是不会执行折叠操作的,比较多:
- EvaluateShapeOf
- TNonComputational 相关,细节待确认
- FTVMQnnCanonicalize 相关,细节待确认
- 有状态节点
- 参数为空
除此之外,CallNode 的常量折叠逻辑非常简单:判断当前CallNode的所有参数是否为常量,如果是,直接原地执行函数计算结果,计算过程在 ConstEvaluate() 函数里面,不展开了
3.4. TupleItemGetNode
这个折叠是最简单,直接返回Tuple对应的元素
Expr Rewrite_(const TupleGetItemNode* tuple_get_item_node, const Expr& post_tuple_get_item) final { const auto* post_tuple_get_item_node = post_tuple_get_item.as(); if (const auto* tuple_node = AsIgnoringOnDevice(post_tuple_get_item_node->tuple)) { Expr result = tuple_node->fields[tuple_get_item_node->index]; OnDeviceProps props = GetOnDeviceProps(post_tuple_get_item_node->tuple); if (props.body.defined()) { // (on_device((x, y, z), virtual_device=D).1 ==> on_device(y, virtual_device=D) return MaybeOnDeviceWithProps(result, props); } else { return result; } } return post_tuple_get_item; }
3.5. IfNode
如果 If 表达式本身就是一个常量,那直接返回 true branch,否则返回 false branch,对应的表达式
Expr VisitExpr_(const IfNode* if_node) final { If new_if = Downcast(ExprMutator::VisitExpr_(if_node)); if (const auto* const_node = AsIgnoringOnDevice(new_if->cond)) { if (reinterpret_cast<uint8_t*>(const_node->data->data)[0]) { return new_if->true_branch; } else { return new_if->false_branch; } } return std::move(new_if); }