深入浅出 tvm – (6) Relay IR:一种 high-level 的中间表示

前面一章我们已经了解了 relay ir 和 tir 中公共的 ir 概念,现在我们继续深入一下 relay ir
relay ir 是 tvm 最高层次的 IR,也是最接近前端的一种计算图表示方式
相比传统的基于 Graph 的 High Level IR, Relay 引入了不少函数式编程的概念,带来了强大的表达能力。
但是相比官方文档,这个写的更好:https://zhuanlan.zhihu.com/p/446976730

1. Type

include/tvm/relay/type.h
relay 并没有定义任何新的 Type 类型,全部直接服用 tvm/ir 里的定义,因此这里不展开描述
// namespace update for backward compact
// will be removed later.
using AnyNode = tvm::tir::AnyNode;
using Any = tvm::tir::Any;
using Kind = TypeKind;
using Type = tvm::Type;
using TypeNode = tvm::TypeNode;
using TypeVar = tvm::TypeVar;
using TypeVarNode = tvm::TypeVarNode;
using GlobalTypeVar = tvm::GlobalTypeVar;
using GlobalTypeVarNode = tvm::GlobalTypeVarNode;
using TupleType = tvm::TupleType;
using TupleTypeNode = tvm::TupleTypeNode;
using TypeConstraint = tvm::TypeConstraint;
using TypeConstraintNode = tvm::TypeConstraintNode;
using FuncType = tvm::FuncType;
using FuncTypeNode = tvm::FuncTypeNode;
using IncompleteType = tvm::IncompleteType;
using IncompleteTypeNode = tvm::IncompleteTypeNode;
using RelayRefType = tvm::RelayRefType;
using RelayRefTypeNode = tvm::RelayRefTypeNode;
using TensorType = tvm::TensorType;
using TensorTypeNode = tvm::TensorTypeNode;
using TypeCall = tvm::TypeCall;
using TypeCallNode = tvm::TypeCallNode;
using TypeRelation = tvm::TypeRelation;
using TypeRelationNode = tvm::TypeRelationNode;
using TypeRelationFn = tvm::TypeRelationFn;
using TypeReporter = tvm::TypeReporter;
using TypeReporterNode = tvm::TypeReporterNode;

2. Expr

relay 里面的所有 Expr 都是继承自 ir 里面的 RelayExpr 类
namespace tvm {
namespace relay {

using Expr = tvm::RelayExpr;
relay 定义的 Expr 列表如下:
  1. Constant:常量
  2. Tuple:数组,由N个Expr构成
  3. Var:局部变量,计算图最常见的结构之一
  4. Call:算子调用,计算图最常见的结构之一
  5. Let:Let binding
  6. If:条件表达式
  7. TupleGetItem
  8. RefCreate
  9. RefRead
  10. RefWrite
  11. TempExpr
比如 Tuple
class Tuple : public Expr {
 public:
  /*!
   * \brief The constructor
   * \param fields The fields of a tuple.
   * \param span The source span of the expression.
   */
  TVM_DLL explicit Tuple(tvm::Array fields, Span span = Span());

  TVM_DEFINE_OBJECT_REF_METHODS(Tuple, RelayExpr, TupleNode);
  TVM_DEFINE_OBJECT_REF_COW_METHOD(TupleNode);
};

其他的比如 Constant/TupleGetItem 这样比较简单的 Expr, 就不介绍了,有兴趣的同学可以阅读 include/tvm/relay/expr.h
再来看下 Var 的定义
class VarNode : public ExprNode {
 public:
  /*!
   * \brief The unique identifier of the Var.
   *
   * vid will be preserved for the same Var during type inference
   * and other rewritings, while the VarNode might be recreated
   * to attach additional information.
   * This property can be used to keep track of parameter Var
   * information across passes.
   */
  Id vid;
  /*!
   * \brief type annotaion of the variable.
   * This field records user provided type annotation of the Var.
   * This field is optional and can be None.
   */
  Type type_annotation;
  // ...
}
Id 可以简单理解为一个字符串,是一个全局唯一的标识。你可以改下代码,增加一些调试,打印出来是这样的:
fc1_weight
fc1_bias
stage1_unit1_bn1_gamma
stage1_unit1_bn1_beta
另外这里有个地方没太理解,Var的值是存在哪里的?


深入浅出 tvm – (5) IR 公共的一些核心概念

在深入 OptimizeImpl 阶段,也就是 Pass 优化之前,我们先了解一下 Relay 阶段的一些基本概念
OptimizeImpl 阶段最主要的工作就是图优化,图其实就是一种高级的表示,tvm 和图相关的概念好好几个:
  1. relay ir:这是 tvm 最 high level 的表示,另外 relay ir -> tir 的过程中,又会依赖 topi 和 te 这2种特定抽象的中间表示
    1. topi:TVM Operator Inventory,TOPI 提供了比 tir 具有更高抽象的 numpy 风格的,通用操作和调度
    1. te:Tensor Expression,张量表达式
  1. tir:最接近目标代码的中间表示
  2. ir:relay ir 和 tir 的一些公共基础结构,和上述2种ir不一样,并不是一个完整独立的抽象
本章我们先来了解下 IR 这个 relay ir 和 tir 最公共的基础设施,后续会依次介绍 relay ir、tir、topi、te
代码目录:
  • 代码:src/ir
  • 头文件:include/tvm/ir
编程语言最基本的核心概念就3个:类型、运算符、表达式,在 IR 这里分别对应 Type, OP, Expr

1.1 Type

Type 相关的定义都在 include/tvm/ir/type.h ,Type 包括基础的整型/浮点型等,也包括函数类型等相对复杂的类型。
这里我们介绍2种基本的类型:
  1. PrimType:最原始的 Type,可以直接映射到 low-level IR 的基本数据类型
  2. FuncType:函数类型
PrimType 可以在这上面做一些 Low-level 优化
定义如下:
class PrimTypeNode : public TypeNode {
 public:
  /*!
   * \brief The corresponding dtype field.
   */
  runtime::DataType dtype;
}
可以看到 PrimType 就一个数据成员,runtime::DataType,这个是 runtime 最底层的概念,代码在 include/tvm/runtime/data_type.h
/*!
 * \brief Runtime primitive data type.
 *
 *  This class is a thin wrapper of DLDataType.
 *  We also make use of DataType in compiler to store quick hint
 */
class DataType {
 public:
  /*!
   * \brief Type code for the DataType.
   *
   * DLPack consistency:
   * 1) kInt is consistent with kDLInt
   * 2) kUInt is consistent with kDLUInt
   * 3) kFloat is consistent with kDLFloat
   */
  enum TypeCode {
    kInt = kDLInt,
    kUInt = kDLUInt,
    kFloat = kDLFloat,
    kHandle = TVMArgTypeCode::kTVMOpaqueHandle,
    kBFloat = kDLBfloat,
    kCustomBegin = 129
  };
  /*! \brief default constructor */
  DataType() { data_ = DataType::Void(); }
  /*!
   * \brief Constructor
   * \param dtype The DLDataType
   */
  explicit DataType(DLDataType dtype) : data_(dtype) {}