ADT 的全称是 Algebraic data type
ADT 是函数式编程的一个重要特征,也是 tvm relay 的一个高级特性。但是目前 ADT 看起来还是 tvm 的一个内部能力,还不能像我们构建计算图一样直接用文本描述出来,因此官方文档里面的示例都是只用来参考的
由于目前还不能通过调试的方式,深入的了解 ADT,所以本文只能算是对官文文档的一个解读,方便大家对 ADT 有一个初步的概念
参考:
1. ADT 的基本概念
1.1. 定义和匹配
简单来说,一个ADT定义含有多个构造函数,每个构造函数有不同的参数类型。每个ADT实例构造出来后只是简单地包含其构造时的所有参数值
重点:最后一句。这个和 c++ 的结构体还不太一样。在c++里面,可以允许有多个构造函数,但是结构体本身是不变的,成员可以被默认初始化。但是在这里,ADT 你可以理解为就是一个 tagged unions
如下:
Numbers 是一个 ADT,具有3个构造函数
# Defines an ADT named "Numbers" data Numbers { Empty : () -> Numbers Single : (Tensor[(), int32]) -> Numbers Pair : (Tensor[(), int32], Tensor[(), int32]) -> Numbers } # A Numbers value can be produced using an Empty, Single, or Pair # constructor, each with a signature given above
由于 ADT 的值,在析构之前是不确定的,因此,当一个函数接收 ADT 实例作为参数时,我们必须根据其构造函数的类型,来决定下一步做什么
这就有了 match 语法,如下:
def @sum(%n : Numbers[]) -> Tensor[(), int32] { # The match expression branches on the constructor that was # used to produce %n. The variables in each case are bound # if the constructor matches that used for %n match(%n) { case Empty() { 0 } case Single(x) { x } case Pair(x, y) { x + y } } } @sum(Empty()) # evaluates to 0 @sum(Single(3)) # evaluates to 3 @sum(Pair(5, 6)) # evaluates to 11
由于 ADT 是通过名字标识的,意味着2个相同具有相同构造函数的ADT,在类型系统看来仍然是不同的
比如下面,调用 @sum(Empty2()) 就会报错
# structurally identical constructors to Numbers data Numbers2 { Empty2 : () -> Numbers2 Single2 : (Tensor[(), int32]) -> Numbers2 Pair2 : (Tensor[(), int32], Tensor[(), int32]) -> Numbers2 } # the below results in a type error because Numbers2 # is a distinct type from Numbers # fn() { @sum(Empty2()) }
疑问:Empty()和Empty2() 这2个构造函数,不能说相同吧,名字就不一样。难道相同的定义,是指只要输入输出相同吗?
再发散一下,这样声明是否会报错?
Numbers 有一个 Empty 构造函数,Numbers2 也有一个 Empty 构造函数
# structurally identical constructors to Numbers data Numbers2 { Empty : () -> Numbers2 Single2 : (Tensor[(), int32]) -> Numbers2 Pair2 : (Tensor[(), int32], Tensor[(), int32]) -> Numbers2 }
1.2. 类型检查和多态
和函数一样,ADT可以是多态的,并且可以接受类型参数,这是ADT最复杂的特性
这个有点像 C++ Template,如下定义了一个类似 C++ std::optional 的 ADT
# a is a type parameter data Optional<a> { None : () -> Optional Some : (a) -> Optional }
Optional[type1] 和 Optional[Type2] 会被类型系统认为是两个不同的类型
类型系统是怎么做到的呢?
在类型系统里面,任何一个构造函数,都有一个唯一的签名,这个签名包括了函数的输入输出,类型定义
比如:
一个 ADT D,其定义支持类型参数,v1,…,vn
D的构造函数C,其输入参数T1, …, Tn(用c++的模板来理解的话,v1, …, vn 是 T1, …, Tn 的一个子集,这个地方不懂的话,写几个c++模板函数应该就理解了)
# a is a type parameter data D<v1, ..., vn> { C : (T1, ..., Tn) -> D }
那么对类型系统来说,这个C的签名就是:fun<v1, …, vn>(T1, …, Tn) -> D[v1, …, vn]
再回到上面 Optional 那个例子,于是有:
- Some 的签名是:fun(a) -> Optional[a]
- None 的签名是:fun() -> Optional[a]
根据签名,类型系统就能做正确的类型检查
比如,对于函数:
# the signature for option indicates the type argument def @inc_scalar(%opt : Optional[Tensor[(), int32]]) -> Tensor[(), int32] { match(%opt) { case None() { 1 } case Some(%s) { %s + 1 } } }
下面这是可以的:
def @main() { let %one : Optional[Tensor[(), int32]] = Some(1); let %two = inc_scalar(%one); let %z = inc_scalar(None()); () }
这是不行的,因为 Optional[Tensor[(10, 10), float32]] 不是 inc_scalar 所允许的输入类型
def @main() { let %big : Optional[Tensor[(10, 10), float32]] = Some(Constant(1, (10, 10), float32)); let %bigger = inc_scalar(%big); }
xx
1.3. 递归
ADT 支持递归定义,如下通过递归定义一个 List
data List { Nil : () -> List Cons : (T, List[T]) -> List }
我们可以很方便地对 List 进行求和,代码如下
def @list_sum(%l : List[Tensor[(), int32]]) -> Tensor[(), int32] { match(%l) { case Nil() { 0 } # add the head of the list to the sum of the tail case Cons(%h, %t) { %h + @list_sum(%t) } } }
其中 Nil 表示空列表
关于 match 语法最后补充一点,match 语法会按照 case 顺序去匹配,匹配到第一个合法的case就会结束,后面的case不会再有匹配的机会。
1.4. 表达式匹配中的模糊匹配
xx
2. ADT 的具体实现
接下来我们看下 ADT 在 tvm 中的具体实现
xx