我们从一个最简单的 Pass 开始入手,RemoveUnusedFunctions,带大家一步一步了解 Pass 优化的具体过程
这个 Pass 看名字顾名思义,就是要把 Module 里执行不到的函数删除掉
代码:src/relay/transforms/dead_code.cc
1. 无用函数的定义
无用函数是指那些根本不会调用到的函数
我们用 tvm.relay 手写一个最简单的网络,来跟踪优化的过程和效果
如下,有2个函数,f1 和 f2,其中 f1 就是无用的函数,因为 main 函数进来后,根本不会调用到这个函数
import tvm from tvm import relay from tvm.contrib import relay_viz a1 = relay.var("a1", shape=(1,), dtype="float32") a2 = relay.var("a2", shape=(1,), dtype="float32") add_op = relay.add(a1, a2) f1 = relay.Function([a1, a2], add_op) d1 = relay.var("d1", shape=(1, 32, 56, 56), dtype="float32") w1 = relay.var("w1", shape=(32, 32, 3, 3), dtype="float32") b1 = relay.var("b1", shape=(32,), dtype="float32") conv = relay.nn.conv2d(d1, w1, strides=(1, 1), padding=(1, 1)) bias = relay.nn.bias_add(conv, b1) relu = relay.nn.relu(bias) f2 = relay.Function([d1, w1, b1], relu) mod = tvm.IRModule({'add_func': f1, 'main': f2}) mod = relay.transform.InferType()(mod) mod = relay.transform.RemoveUnusedFunctions()(mod)
优化完之后,右边的图就被去掉了(图片的可视化,参考之前的 Relay 计算图)