深入浅出 tvm – (16) Tensor Expression 详解

TVM addresses these challenges with three key modules.
(1) We introduce a tensor expression languageto build operators and provide program transformation primitives that generate different versions of the program with various optimizations.
tensor expression 是 tvm 引入的一门面向张量计算的函数式编程语言。前面我们知道 topi 就是基于 te 实现了一系列常见的张量算子库
在 tvm 里面,每个算子都可以表达为一个 te 表达式,te 接受张量作为输入,并返回一个新的张量
这个语言本身是非常简单的:
  1. Type:因为处理的输入输出,都是 tensor,所以只有一种数据类型(不像 relay ir,有 Var、Constant、Function、Call 等各种类型)
  2. Op:操作符也很简单,最常见的就2个 placeholder_op,compute_op
te 的代码主要在 include/tvm/te/ src/te/ 这2个目录,全部加起来才1.3w行,所以说是并不复杂
这篇文章,我们从编译器的视角来理解一下 tensor expression
参考资料:

1. te 的设计目标

te 可能应该有其他更多的设计目标
但从我目前了解到的来看,te 最大的作用应该是简化tensor 算子的编写难度,并支持进一步的调度和优化
比如向量加法
1)TVMScript 版本
在没有 te 之前,比如只有 tir 的时候,用户只能用 TVM script 来编写tensor 算子。如下:
@tvm.script.ir_module
class MyModule:
    @T.prim_func
    def main(a: T.handle, b: T.handle):
        # 我们通过 T.handle 进行数据交换,类似于内存指针
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # 通过 handle 创建 Buffer
        A = T.match_buffer(a, (8,), dtype="float32")
        B = T.match_buffer(b, (8,), dtype="float32")
        for i in range(8):
            # block 是针对计算的抽象
            with T.block("B"):
                # 定义一个空间(可并行)block 迭代器,并且将它的值绑定成 i
                vi = T.axis.spatial(8, i)
                B[vi] = A[vi] + 1.0

ir_module = MyModule
2)tensor expression 版本
但是如果用 te 来描述向量加法的话,只需要2行,并且通过script() 打印出来会发现和 TVM script 是一样的
from tvm import te

A = te.placeholder((8,), dtype="float32", name="A")
B = te.compute((8,), lambda *i: A(*i) + 1.0, name="B")

func = te.create_prim_func([A, B])
ir_module_from_te = IRModule({"main": func})
print(ir_module_from_te.script())