深入浅出 tvm – (7) Relay IR 之 ADT

ADT 的全称是 Algebraic data type
ADT 是函数式编程的一个重要特征,也是 tvm relay 的一个高级特性。但是目前 ADT 看起来还是 tvm 的一个内部能力,还不能像我们构建计算图一样直接用文本描述出来,因此官方文档里面的示例都是只用来参考的
由于目前还不能通过调试的方式,深入的了解 ADT,所以本文只能算是对官文文档的一个解读,方便大家对 ADT 有一个初步的概念
参考:
0

1. ADT 的基本概念

1.1. 定义和匹配

简单来说,一个ADT定义含有多个构造函数,每个构造函数有不同的参数类型。每个ADT实例构造出来后只是简单地包含其构造时的所有参数值
重点:最后一句。这个和 c++ 的结构体还不太一样。在c++里面,可以允许有多个构造函数,但是结构体本身是不变的,成员可以被默认初始化。但是在这里,ADT 你可以理解为就是一个 tagged unions
如下:
Numbers 是一个 ADT,具有3个构造函数
# Defines an ADT named "Numbers"
data Numbers {
  Empty : () -> Numbers
  Single : (Tensor[(), int32]) -> Numbers
  Pair : (Tensor[(), int32], Tensor[(), int32]) -> Numbers
}
# A Numbers value can be produced using an Empty, Single, or Pair
# constructor, each with a signature given above
由于 ADT 的值,在析构之前是不确定的,因此,当一个函数接收 ADT 实例作为参数时,我们必须根据其构造函数的类型,来决定下一步做什么
这就有了 match 语法,如下:
def @sum(%n : Numbers[]) -> Tensor[(), int32] {
   # The match expression branches on the constructor that was
   # used to produce %n. The variables in each case are bound
   # if the constructor matches that used for %n
   match(%n) {
     case Empty() { 0 }
     case Single(x) { x }
     case Pair(x, y) { x + y }
   }
}

@sum(Empty())    # evaluates to 0
@sum(Single(3))  # evaluates to 3
@sum(Pair(5, 6)) # evaluates to 11
由于 ADT 是通过名字标识的,意味着2个相同具有相同构造函数的ADT,在类型系统看来仍然是不同的
比如下面,调用 @sum(Empty2()) 就会报错
# structurally identical constructors to Numbers
data Numbers2 {
  Empty2 : () -> Numbers2
  Single2 : (Tensor[(), int32]) -> Numbers2
  Pair2 : (Tensor[(), int32], Tensor[(), int32]) -> Numbers2
}

# the below results in a type error because Numbers2
# is a distinct type from Numbers
# fn() { @sum(Empty2()) }

疑问:Empty()和Empty2() 这2个构造函数,不能说相同吧,名字就不一样。难道相同的定义,是指只要输入输出相同吗?
再发散一下,这样声明是否会报错?
Numbers 有一个 Empty 构造函数,Numbers2 也有一个 Empty 构造函数
# structurally identical constructors to Numbers
data Numbers2 {
  Empty : () -> Numbers2
  Single2 : (Tensor[(), int32]) -> Numbers2
  Pair2 : (Tensor[(), int32], Tensor[(), int32]) -> Numbers2
}

1.2. 类型检查和多态

和函数一样,ADT可以是多态的,并且可以接受类型参数,这是ADT最复杂的特性
这个有点像 C++ Template,如下定义了一个类似 C++ std::optional 的 ADT
# a is a type parameter
data Optional<a> {
  None : () -> Optional
  Some : (a) -> Optional
}
Optional[type1] 和 Optional[Type2] 会被类型系统认为是两个不同的类型
类型系统是怎么做到的呢?
在类型系统里面,任何一个构造函数,都有一个唯一的签名,这个签名包括了函数的输入输出,类型定义
比如:
一个 ADT D,其定义支持类型参数,v1,…,vn
D的构造函数C,其输入参数T1, …, Tn(用c++的模板来理解的话,v1, …, vn 是 T1, …, Tn 的一个子集,这个地方不懂的话,写几个c++模板函数应该就理解了)
# a is a type parameter
data D<v1, ..., vn> {
  C : (T1, ..., Tn) -> D
}
那么对类型系统来说,这个C的签名就是:fun<v1, …, vn>(T1, …, Tn) -> D[v1, …, vn]
再回到上面 Optional 那个例子,于是有:
  1. Some 的签名是:fun(a) -> Optional[a]
  2. None 的签名是:fun() -> Optional[a]
根据签名,类型系统就能做正确的类型检查
比如,对于函数:
# the signature for option indicates the type argument
def @inc_scalar(%opt : Optional[Tensor[(), int32]]) -> Tensor[(), int32] {
  match(%opt) {
    case None() { 1 }
    case Some(%s) { %s + 1 }
  }
}
下面这是可以的:
def @main() {
  let %one : Optional[Tensor[(), int32]] = Some(1);
  let %two = inc_scalar(%one);
  let %z = inc_scalar(None());
  ()
}
这是不行的,因为 Optional[Tensor[(10, 10), float32]] 不是 inc_scalar 所允许的输入类型
def @main() {
  let %big : Optional[Tensor[(10, 10), float32]]
    = Some(Constant(1, (10, 10), float32));
  let %bigger = inc_scalar(%big);
}
xx

1.3. 递归

ADT 支持递归定义,如下通过递归定义一个 List
data List {
   Nil : () -> List
   Cons : (T, List[T]) -> List
}
我们可以很方便地对 List 进行求和,代码如下
def @list_sum(%l : List[Tensor[(), int32]]) -> Tensor[(), int32] {
  match(%l) {
    case Nil() { 0 }
    # add the head of the list to the sum of the tail
    case Cons(%h, %t) { %h + @list_sum(%t) }
  }
}
其中 Nil 表示空列表
关于 match 语法最后补充一点,match 语法会按照 case 顺序去匹配,匹配到第一个合法的case就会结束,后面的case不会再有匹配的机会。

1.4. 表达式匹配中的模糊匹配

xx

2. ADT 的具体实现

接下来我们看下 ADT 在 tvm 中的具体实现
xx
发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注