深入浅出 tvm – (10) Relay Pass 之剔除无用函数

我们从一个最简单的 Pass 开始入手,RemoveUnusedFunctions,带大家一步一步了解 Pass 优化的具体过程
这个 Pass 看名字顾名思义,就是要把 Module 里执行不到的函数删除掉
代码:src/relay/transforms/dead_code.cc

1. 无用函数的定义

无用函数是指那些根本不会调用到的函数
我们用 tvm.relay 手写一个最简单的网络,来跟踪优化的过程和效果
如下,有2个函数,f1 和 f2,其中 f1 就是无用的函数,因为 main 函数进来后,根本不会调用到这个函数
import tvm
from tvm import relay
from tvm.contrib import relay_viz
 
a1 = relay.var("a1", shape=(1,), dtype="float32")
a2 = relay.var("a2", shape=(1,), dtype="float32")
add_op = relay.add(a1, a2)
f1 = relay.Function([a1, a2], add_op)

d1 = relay.var("d1", shape=(1, 32, 56, 56), dtype="float32")
w1 = relay.var("w1", shape=(32, 32, 3, 3), dtype="float32")
b1 = relay.var("b1", shape=(32,), dtype="float32")
conv = relay.nn.conv2d(d1, w1, strides=(1, 1), padding=(1, 1))
bias = relay.nn.bias_add(conv, b1)
relu = relay.nn.relu(bias)
f2 = relay.Function([d1, w1, b1], relu)

mod = tvm.IRModule({'add_func': f1, 'main': f2})
mod = relay.transform.InferType()(mod)

mod = relay.transform.RemoveUnusedFunctions()(mod)
优化完之后,右边的图就被去掉了(图片的可视化,参考之前的 Relay 计算图)
0