在深入了解 IR 以及 relay.IR 之前,我们先对计算图在 relay 里的表示有一个直观的认识
幸运的是,relay 提供一个可视化组件 Relay Visualizer,帮我们了解其计算图的内部结构:https://tvm.apache.org/docs/how_to/work_with_relay/using_relay_viz.html
获得计算图有2种方式:
- 加载开源模型,tvm提供一些便捷的函数
- 自定义手写一个模型
我们看下
1. 开源模型
我们可以通过 relay.testing 模块获取一些常见的模型,具体可以看 python/tvm/relay/testing 模块,目前支持的模型有:
- resnet
- resnet_3d
- mobilenet
- mlp
- lstm
- synthetic
比如,我们想看一下 resnet18 模型的图结果,可以这么看
from tvm import relay from tvm.relay import testing import tvm from tvm.contrib import relay_viz # Resnet18 model resnet18_mod, resnet18_params = relay.testing.resnet.get_workload(num_layers=18) viz = relay_viz.RelayVisualizer( resnet18_mod, relay_param=resnet18_params, plotter=relay_viz.DotPlotter(), parser=relay_viz.DotVizParser()) viz.render('1')
viz.render 会在当前目录生成 1.pdf 的文件,打开即可得到完整的计算图,如下:
这个计算图非常大,这里只展示部分,从这个图里面我们可以看到,整个图基本上只有2类节点:
- Var:本地变量
- Call:算子调用
除了 Var, Call 之外,Relay 的计算图还包含 Function, GlobalVar, Tuple 等其他节点类型,后面我们将深入展开描述
2. 自定义模型
但是上述的模型还是非常大,非常复杂,光是节点就上千个,不方便我们学习研究。
我们在用 pytorch,自定义 class,并声明 forward() 函数的时候,就是在构建一个前向网络
class RegLSTM(nn.Module): def forward(self, x): y = self.rnn(x)[0] seq_len, batch_size, hid_dim = y.shape y = y.view(-1, hid_dim) y = self.reg(y) y = y.view(seq_len, batch_size, -1) return y
不过 tvm 是面向编译器的,就算是构建 lstm 还是有一点复杂的,可以直接参考:python/tvm/relay/testing/lstm.py 的 get_net() 函数实现
参考 tvm 内置的 lstm,我们来写个更简单的:
from tvm import relay from tvm.relay import testing import tvm from tvm.contrib import relay_viz data = relay.var("data") bias = relay.var("bias") add_op = relay.add(data, bias) add_func = relay.Function([data, bias], add_op) add_gvar = relay.GlobalVar("AddFunc") input0 = relay.var("input0") input1 = relay.var("input1") input2 = relay.var("input2") add_01 = relay.Call(add_gvar, [input0, input1]) add_012 = relay.Call(add_gvar, [input2, add_01]) main_func = relay.Function([input0, input1, input2], add_012) main_gvar = relay.GlobalVar("main") mod = tvm.IRModule({main_gvar: main_func, add_gvar: add_func}) viz = relay_viz.RelayVisualizer( mod, plotter=relay_viz.DotPlotter(), parser=relay_viz.DotVizParser()) viz.render('1')
得到如下结构的图:
大家在学习 tvm 的过程中,比如研究各种 Pass 优化,可以手写一些简单的模型,研究下,这样调试会简单很多