深入浅出 tvm – (4) Relay 计算图

在深入了解 IR 以及 relay.IR 之前,我们先对计算图在 relay 里的表示有一个直观的认识
幸运的是,relay 提供一个可视化组件 Relay Visualizer,帮我们了解其计算图的内部结构:https://tvm.apache.org/docs/how_to/work_with_relay/using_relay_viz.html
获得计算图有2种方式:
  1. 加载开源模型,tvm提供一些便捷的函数
  2. 自定义手写一个模型
我们看下

1. 开源模型

我们可以通过 relay.testing 模块获取一些常见的模型,具体可以看 python/tvm/relay/testing 模块,目前支持的模型有:
  1. resnet
  2. resnet_3d
  3. mobilenet
  4. mlp
  5. lstm
  6. synthetic
比如,我们想看一下 resnet18 模型的图结果,可以这么看
from tvm import relay
from tvm.relay import testing
import tvm
from tvm.contrib import relay_viz

# Resnet18 model
resnet18_mod, resnet18_params = relay.testing.resnet.get_workload(num_layers=18)

viz = relay_viz.RelayVisualizer(
    resnet18_mod,
    relay_param=resnet18_params,
    plotter=relay_viz.DotPlotter(),
    parser=relay_viz.DotVizParser())
viz.render('1')
viz.render 会在当前目录生成 1.pdf 的文件,打开即可得到完整的计算图,如下:
0
这个计算图非常大,这里只展示部分,从这个图里面我们可以看到,整个图基本上只有2类节点:
  1. Var:本地变量
  2. Call:算子调用
除了 Var, Call 之外,Relay 的计算图还包含 Function, GlobalVar, Tuple 等其他节点类型,后面我们将深入展开描述

2. 自定义模型

但是上述的模型还是非常大,非常复杂,光是节点就上千个,不方便我们学习研究。
我们在用 pytorch,自定义 class,并声明 forward() 函数的时候,就是在构建一个前向网络
class RegLSTM(nn.Module):


    def forward(self, x):
        y = self.rnn(x)[0]

        seq_len, batch_size, hid_dim = y.shape
        y = y.view(-1, hid_dim)
        y = self.reg(y)
        y = y.view(seq_len, batch_size, -1)
        return y
不过 tvm 是面向编译器的,就算是构建 lstm 还是有一点复杂的,可以直接参考:python/tvm/relay/testing/lstm.py 的 get_net() 函数实现
参考 tvm 内置的 lstm,我们来写个更简单的:
from tvm import relay
from tvm.relay import testing
import tvm
from tvm.contrib import relay_viz

data = relay.var("data")
bias = relay.var("bias")
add_op = relay.add(data, bias)
add_func = relay.Function([data, bias], add_op)
add_gvar = relay.GlobalVar("AddFunc")

input0 = relay.var("input0")
input1 = relay.var("input1")
input2 = relay.var("input2")
add_01 = relay.Call(add_gvar, [input0, input1])
add_012 = relay.Call(add_gvar, [input2, add_01])
main_func = relay.Function([input0, input1, input2], add_012)
main_gvar = relay.GlobalVar("main")

mod = tvm.IRModule({main_gvar: main_func, add_gvar: add_func})

viz = relay_viz.RelayVisualizer(
    mod,
    plotter=relay_viz.DotPlotter(),
    parser=relay_viz.DotVizParser())

viz.render('1')

得到如下结构的图:
0
大家在学习 tvm 的过程中,比如研究各种 Pass 优化,可以手写一些简单的模型,研究下,这样调试会简单很多
发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注