深入浅出 tvm – (7) Relay IR 之 ADT

ADT 的全称是 Algebraic data type
ADT 是函数式编程的一个重要特征,也是 tvm relay 的一个高级特性。但是目前 ADT 看起来还是 tvm 的一个内部能力,还不能像我们构建计算图一样直接用文本描述出来,因此官方文档里面的示例都是只用来参考的
由于目前还不能通过调试的方式,深入的了解 ADT,所以本文只能算是对官文文档的一个解读,方便大家对 ADT 有一个初步的概念
参考:
0

1. ADT 的基本概念

1.1. 定义和匹配

简单来说,一个ADT定义含有多个构造函数,每个构造函数有不同的参数类型。每个ADT实例构造出来后只是简单地包含其构造时的所有参数值
重点:最后一句。这个和 c++ 的结构体还不太一样。在c++里面,可以允许有多个构造函数,但是结构体本身是不变的,成员可以被默认初始化。但是在这里,ADT 你可以理解为就是一个 tagged unions
如下:
Numbers 是一个 ADT,具有3个构造函数
# Defines an ADT named "Numbers"
data Numbers {
  Empty : () -> Numbers
  Single : (Tensor[(), int32]) -> Numbers
  Pair : (Tensor[(), int32], Tensor[(), int32]) -> Numbers
}
# A Numbers value can be produced using an Empty, Single, or Pair
# constructor, each with a signature given above
由于 ADT 的值,在析构之前是不确定的,因此,当一个函数接收 ADT 实例作为参数时,我们必须根据其构造函数的类型,来决定下一步做什么
这就有了 match 语法,如下:
def @sum(%n : Numbers[]) -> Tensor[(), int32] {
   # The match expression branches on the constructor that was
   # used to produce %n. The variables in each case are bound
   # if the constructor matches that used for %n
   match(%n) {
     case Empty() { 0 }
     case Single(x) { x }
     case Pair(x, y) { x + y }
   }
}

@sum(Empty())    # evaluates to 0
@sum(Single(3))  # evaluates to 3
@sum(Pair(5, 6)) # evaluates to 11
由于 ADT 是通过名字标识的,意味着2个相同具有相同构造函数的ADT,在类型系统看来仍然是不同的
比如下面,调用 @sum(Empty2()) 就会报错
# structurally identical constructors to Numbers
data Numbers2 {
  Empty2 : () -> Numbers2
  Single2 : (Tensor[(), int32]) -> Numbers2
  Pair2 : (Tensor[(), int32], Tensor[(), int32]) -> Numbers2
}

# the below results in a type error because Numbers2
# is a distinct type from Numbers
# fn() { @sum(Empty2()) }