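In the previous chapter we covered the IR concepts shared by Relay IR and TIR; now let's dig deeper into Relay IR.
Relay IR is TVM's highest-level IR and the computation-graph representation closest to the frontend.
Compared with traditional graph-based high-level IRs, Relay introduces quite a few functional-programming concepts, which give it strong expressive power.
For a write-up that explains this better than the official documentation, see: https://zhuanlan.zhihu.com/p/446976730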
1. Type
include/tvm/relay/type.h
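Relay does not define any new Type classes; it directly reuses the definitions in tvm/ir, so we won't go into them here. The header only contains aliases (and, as its own comment says, they are kept for backward compatibility and will be removed later):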
// namespace update for backward compact
// will be removed later.
using AnyNode = tvm::tir::AnyNode;
using Any = tvm::tir::Any;
using Kind = TypeKind;
using Type = tvm::Type;
using TypeNode = tvm::TypeNode;
using TypeVar = tvm::TypeVar;
using TypeVarNode = tvm::TypeVarNode;
using GlobalTypeVar = tvm::GlobalTypeVar;
using GlobalTypeVarNode = tvm::GlobalTypeVarNode;
using TupleType = tvm::TupleType;
using TupleTypeNode = tvm::TupleTypeNode;
using TypeConstraint = tvm::TypeConstraint;
using TypeConstraintNode = tvm::TypeConstraintNode;
using FuncType = tvm::FuncType;
using FuncTypeNode = tvm::FuncTypeNode;
using IncompleteType = tvm::IncompleteType;
using IncompleteTypeNode = tvm::IncompleteTypeNode;
using RelayRefType = tvm::RelayRefType;
using RelayRefTypeNode = tvm::RelayRefTypeNode;
using TensorType = tvm::TensorType;
using TensorTypeNode = tvm::TensorTypeNode;
using TypeCall = tvm::TypeCall;
using TypeCallNode = tvm::TypeCallNode;
using TypeRelation = tvm::TypeRelation;
using TypeRelationNode = tvm::TypeRelationNode;
using TypeRelationFn = tvm::TypeRelationFn;
using TypeReporter = tvm::TypeReporter;
using TypeReporterNode = tvm::TypeReporterNode;
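To see why a header made only of `using` aliases adds no new types, here is a minimal self-contained sketch. The namespaces `toyir` / `toyir::relay` and their classes are hypothetical stand-ins, not TVM code; the point is only that an alias names the very same type, which `std::is_same` can confirm at compile time.

```cpp
#include <cassert>
#include <string>
#include <type_traits>
#include <utility>

// Hypothetical toy model (not TVM code): the core namespace owns the
// actual type classes, mirroring tvm/ir.
namespace toyir {
struct TypeNode {
  virtual ~TypeNode() = default;
  virtual std::string kind() const = 0;
};

struct TensorTypeNode : TypeNode {
  explicit TensorTypeNode(std::string d) : dtype(std::move(d)) {}
  std::string kind() const override { return "TensorType"; }
  std::string dtype;
};

namespace relay {
// No new classes here, only aliases, like include/tvm/relay/type.h.
using TypeNode = toyir::TypeNode;
using TensorTypeNode = toyir::TensorTypeNode;
}  // namespace relay
}  // namespace toyir

// An alias names the same type, so both spellings are interchangeable.
static_assert(std::is_same<toyir::relay::TensorTypeNode,
                           toyir::TensorTypeNode>::value,
              "relay alias and core type must be identical");

// Written against the relay spelling, but accepts core objects directly.
inline std::string DescribeRelayType(const toyir::relay::TypeNode& t) {
  return t.kind();
}
```

Because no conversion is involved, any pass written for the core types works unchanged on "Relay" types in this model.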
2. Expr
Every Expr in Relay derives from the RelayExpr class in ir; relay::Expr itself is simply an alias:
namespace tvm {
namespace relay {
using Expr = tvm::RelayExpr;
// ...
}  // namespace relay
}  // namespace tvm
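The Exprs defined by Relay are:
- Constant: a constant
- Tuple: a tuple made of N Exprs
- Var: a local variable, one of the most common constructs in a computation graph
- Call: an operator call, one of the most common constructs in a computation graph
- Let: a let binding
- If: a conditional expression
- TupleGetItem: gets one field of a tuple by index
- RefCreate: creates a mutable reference
- RefRead: reads the value a reference points to
- RefWrite: writes a new value into a reference
- TempExpr: a temporary expression used inside specific rewriting passes, which must be realized into a regular Expr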
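To make the list concrete, here is a minimal sketch of how a few of these kinds compose into an expression tree. The node layout (`ToyExprNode`, `ToyValue`, the `Make*` helpers) is a hypothetical toy of my own, not TVM's classes; it covers only Constant, Tuple and TupleGetItem, and evaluates them to show what TupleGetItem does.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical toy model (not TVM's classes): three of the Relay Expr kinds,
// stored as an immutable tree of shared nodes.
struct ToyValue {
  bool is_tuple = false;
  int scalar = 0;                // meaningful when !is_tuple
  std::vector<ToyValue> fields;  // meaningful when is_tuple
};

struct ToyExprNode;
using ToyExpr = std::shared_ptr<const ToyExprNode>;

struct ToyExprNode {
  enum class Kind { kConstant, kTuple, kTupleGetItem };
  Kind kind = Kind::kConstant;
  int value = 0;                // kConstant: the literal
  std::vector<ToyExpr> fields;  // kTuple: the N sub-expressions
  ToyExpr tuple;                // kTupleGetItem: the tuple being indexed
  std::size_t index = 0;        // kTupleGetItem: which field to read
};

inline ToyExpr MakeConstant(int v) {
  auto n = std::make_shared<ToyExprNode>();
  n->value = v;
  return n;
}

inline ToyExpr MakeTuple(std::vector<ToyExpr> fields) {
  auto n = std::make_shared<ToyExprNode>();
  n->kind = ToyExprNode::Kind::kTuple;
  n->fields = std::move(fields);
  return n;
}

inline ToyExpr MakeTupleGetItem(ToyExpr tuple, std::size_t index) {
  auto n = std::make_shared<ToyExprNode>();
  n->kind = ToyExprNode::Kind::kTupleGetItem;
  n->tuple = std::move(tuple);
  n->index = index;
  return n;
}

// Recursively evaluates an expression tree to a value.
inline ToyValue Eval(const ToyExpr& e) {
  ToyValue out;
  switch (e->kind) {
    case ToyExprNode::Kind::kConstant:
      out.scalar = e->value;
      break;
    case ToyExprNode::Kind::kTuple:
      out.is_tuple = true;
      for (const ToyExpr& f : e->fields) out.fields.push_back(Eval(f));
      break;
    case ToyExprNode::Kind::kTupleGetItem: {
      ToyValue t = Eval(e->tuple);
      assert(t.is_tuple && e->index < t.fields.size());
      out = t.fields[e->index];
      break;
    }
  }
  return out;
}
```

For example, `Eval(MakeTupleGetItem(MakeTuple({MakeConstant(1), MakeConstant(2)}), 1))` yields the scalar 2.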
Take Tuple as an example:
class Tuple : public Expr {
 public:
  /*!
   * \brief The constructor
   * \param fields The fields of a tuple.
   * \param span The source span of the expression.
   */
  TVM_DLL explicit Tuple(tvm::Array<relay::Expr> fields, Span span = Span());
  TVM_DEFINE_OBJECT_REF_METHODS(Tuple, RelayExpr, TupleNode);
  TVM_DEFINE_OBJECT_REF_COW_METHOD(TupleNode);
};
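TVM_DEFINE_OBJECT_REF_COW_METHOD gives the reference class a copy-on-write accessor. The sketch below models the idea with std::shared_ptr, under my own assumptions: `ToyTuple` / `ToyTupleNode` are hypothetical, and the real macro's implementation differs in detail. The rule it illustrates: mutate in place only when this reference is the sole owner, otherwise clone the node first.

```cpp
#include <cassert>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical toy model (not TVM code) of copy-on-write references.
struct ToyTupleNode {
  std::vector<int> fields;  // stand-in for tvm::Array<relay::Expr>
};

class ToyTuple {
 public:
  explicit ToyTuple(std::vector<int> fields)
      : node_(std::make_shared<ToyTupleNode>(ToyTupleNode{std::move(fields)})) {}

  // Read-only access to the (possibly shared) node.
  const ToyTupleNode* get() const { return node_.get(); }

  // Returns a mutable node: if other references share it, clone it first
  // so their view is unaffected; if we are the only owner, reuse it.
  ToyTupleNode* CopyOnWrite() {
    if (node_.use_count() != 1) {
      node_ = std::make_shared<ToyTupleNode>(*node_);
    }
    return node_.get();
  }

 private:
  std::shared_ptr<ToyTupleNode> node_;
};
```

The design lets immutable IR nodes be shared freely between passes while still allowing cheap in-place edits when nobody else holds the node.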
Simpler Exprs such as Constant and TupleGetItem are not covered here; if you are interested, read include/tvm/relay/expr.h.
Now let's look at the definition of Var:
class VarNode : public ExprNode {
 public:
  /*!
   * \brief The unique identifier of the Var.
   *
   * vid will be preserved for the same Var during type inference
   * and other rewritings, while the VarNode might be recreated
   * to attach additional information.
   * This property can be used to keep track of parameter Var
   * information across passes.
   */
  Id vid;
  /*!
   * \brief type annotation of the variable.
   * This field records user provided type annotation of the Var.
   * This field is optional and can be None.
   */
  Type type_annotation;
  // ...
};
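An Id can loosely be thought of as a string that serves as a unique identifier. Strictly speaking, it carries a name_hint that is only a hint for printing; equality is decided by the Id object itself, not by its name. If you tweak the code to add some debug output and print the Ids, you get something like: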
fc1_weight
fc1_bias
stage1_unit1_bn1_gamma
stage1_unit1_bn1_beta
One thing I still don't fully understand here: where is the value of a Var actually stored?
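A toy model may help picture both points above. This is a sketch under my own assumptions (`ToyIdNode`, `ToyVar`, `ToyEnv` are hypothetical, not TVM classes): an Id is compared by object identity rather than by its name_hint, and a Var holds no value itself; in this model the value lives in an evaluator's environment, filled when something (a function call or a Let) binds the Var. It illustrates a mental model, not where TVM actually keeps tensor data such as fc1_weight.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>

// Hypothetical toy model (not TVM code). name_hint is only for printing;
// identity is the Id object itself.
struct ToyIdNode {
  std::string name_hint;
};
using ToyId = std::shared_ptr<const ToyIdNode>;

struct ToyVar {
  ToyId vid;
};

// Each call creates a fresh Id, even if the name is reused.
inline ToyVar MakeVar(const std::string& name) {
  return ToyVar{std::make_shared<ToyIdNode>(ToyIdNode{name})};
}

// Two Vars are the same variable only if they share the same Id object.
inline bool SameVar(const ToyVar& a, const ToyVar& b) {
  return a.vid == b.vid;
}

// The Var stores no value: an evaluator keeps a separate environment that
// maps each Id to the value currently bound to it.
class ToyEnv {
 public:
  void Bind(const ToyVar& v, double value) { values_[v.vid.get()] = value; }
  bool Has(const ToyVar& v) const { return values_.count(v.vid.get()) != 0; }
  double Lookup(const ToyVar& v) const { return values_.at(v.vid.get()); }

 private:
  std::map<const ToyIdNode*, double> values_;
};
```

Here two Vars both named "x" stay distinct, and binding one of them in the environment does not make the other visible.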
Next, let's look at functions. FunctionNode is Relay's concrete implementation of the BaseFunc mentioned earlier:
class FunctionNode : public BaseFuncNode {
 public:
  /*! \brief Function parameters */
  tvm::Array<Var> params;
  /*!
   * \brief
   * The expression which represents the computation of the function,
   * the expression may reference the parameters, and the type of it
   * or sub-expressions may reference the type variables.
   */
  Expr body;
  /*! \brief User annotated return type of the function. */
  Type ret_type;
  /*!
   * \brief Type parameters of the function.
   *  Enables the function to vary its type based on these.
   *  This corresponds to template parameters in c++'s terminology.
   *
   * \note This can be usually empty for non-polymorphic functions.
   */
  tvm::Array<TypeVar> type_params;
  // ...
};
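The definition of FunctionNode is fairly clear: body describes the function's computation and is itself just an Expr, params are the input parameters, and ret_type is the user-annotated return type.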
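Since body is an ordinary Expr that refers to the parameters, a natural property to check is whether a function only uses its own params. The sketch below models the three fields with hypothetical toy types (`ToyExpr`, `ToyFunction`, `IsClosed` are mine, not TVM's); variables are referenced by name purely for brevity.

```cpp
#include <cassert>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy model (not TVM code): an expression is either a variable
// (referenced by name, for brevity) or a binary op applied to two children.
struct ToyExpr {
  std::string var;                          // non-empty: this node is a Var
  std::string op;                           // otherwise: op(lhs, rhs)
  std::shared_ptr<const ToyExpr> lhs, rhs;
};
using ToyExprPtr = std::shared_ptr<const ToyExpr>;

inline ToyExprPtr MakeVar(const std::string& name) {
  return std::make_shared<ToyExpr>(ToyExpr{name, "", nullptr, nullptr});
}

inline ToyExprPtr MakeBinOp(const std::string& op, ToyExprPtr a, ToyExprPtr b) {
  return std::make_shared<ToyExpr>(ToyExpr{"", op, std::move(a), std::move(b)});
}

// Mirrors the three FunctionNode fields discussed above.
struct ToyFunction {
  std::vector<std::string> params;  // stand-in for tvm::Array<Var>
  ToyExprPtr body;                  // the computation: just another Expr
  std::string ret_type;             // stand-in for the annotated Type
};

// Collects every variable name referenced anywhere in `e`.
inline void CollectVars(const ToyExprPtr& e, std::set<std::string>* out) {
  if (!e) return;
  if (!e->var.empty()) {
    out->insert(e->var);
    return;
  }
  CollectVars(e->lhs, out);
  CollectVars(e->rhs, out);
}

// True if the body only refers to the function's own parameters.
inline bool IsClosed(const ToyFunction& f) {
  std::set<std::string> used;
  CollectVars(f.body, &used);
  const std::set<std::string> params(f.params.begin(), f.params.end());
  for (const std::string& v : used) {
    if (params.count(v) == 0) return false;
  }
  return true;
}
```

A function `fn(x, y) = add(x, y)` is closed, while `fn(x) = add(x, w)` is not, because `w` is free in the body.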
Next, look at Call, i.e. a function call:
/*!
 * \brief Call corresponds to operator invocation.
 *  Corresponds to the operator in computational graph terminology.
 */
class Call;
/*! \brief Call container. */
class CallNode : public ExprNode {
 protected:
  // CallNode uses own deleter to indirectly call non-recursive destructor
  Object::FDeleter saved_deleter_;
  static void Deleter_(Object* ptr);
 public:
  /*!
   * \brief The operator(function) being invoked
   *
   *  - It can be tvm::Op which corresponds to the primitive operators.
   *  - It can also be user defined functions (Function, GlobalVar, Var).
   */
  Expr op;
  /*! \brief The arguments(inputs) of the call */
  tvm::Array<relay::Expr> args;
  /*! \brief The additional attributes */
  Attrs attrs;
  /*!
   * \brief The type arguments passed to polymorphic(template) function.
   *
   * This is the advance feature that is only used when the function is
   * polymorphic. It is safe to be ignored in most cases. For example, in the
   * following code, the type_args of addone call is [int].
   *
   * \code
   *
   * template <typename T>
   * T addone(T a) { return a + 1; }
   *
   * void main() {
   *   int x = addone(10);
   * }
   *
   * \endcode
   */
  tvm::Array<Type> type_args;
  // ...
};
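At first I didn't see what the concrete difference between CallNode and FunctionNode is. Compare their inheritance chains:
FunctionNode: tvm::runtime::Object -> tvm::BaseExprNode -> tvm::RelayExprNode -> tvm::BaseFuncNode -> tvm::relay::FunctionNode
CallNode: tvm::runtime::Object -> tvm::BaseExprNode -> tvm::RelayExprNode (aliased as relay::ExprNode) -> tvm::relay::CallNode
Both are Relay expressions, but FunctionNode additionally derives from BaseFuncNode: it defines a function (params plus body). A CallNode instead invokes a callee; as its op comment says, that callee can be a primitive Op or a user-defined Function, GlobalVar or Var, applied to args.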
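The definition/application split can be seen in a tiny interpreter. This is a minimal sketch with hypothetical toy types (`Expr`, `Function`, `CallOp`, `CallFn`, `Eval` are mine, not TVM's API): a Function is only a definition and computes nothing; a Call evaluates its args, and if the callee is a Function it binds params to those values and evaluates the body, otherwise it applies a primitive op.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy model (not TVM code) contrasting the two nodes.
struct Expr;
using ExprPtr = std::shared_ptr<const Expr>;

// A Function *defines* a computation: params plus body.
struct Function {
  std::vector<std::string> params;
  ExprPtr body;
};

struct Expr {
  enum Kind { kConst, kVar, kCall };
  Kind kind = kConst;
  double value = 0;                    // kConst: the literal
  std::string name;                    // kVar: name; kCall: primitive op name
  std::shared_ptr<const Function> fn;  // kCall: Function callee (null for ops)
  std::vector<ExprPtr> args;           // kCall: the arguments
};

inline ExprPtr MakeConst(double v) {
  auto e = std::make_shared<Expr>();
  e->value = v;
  return e;
}

inline ExprPtr MakeVar(std::string name) {
  auto e = std::make_shared<Expr>();
  e->kind = Expr::kVar;
  e->name = std::move(name);
  return e;
}

// A Call whose callee is a primitive op, e.g. "add".
inline ExprPtr CallOp(std::string op, std::vector<ExprPtr> args) {
  auto e = std::make_shared<Expr>();
  e->kind = Expr::kCall;
  e->name = std::move(op);
  e->args = std::move(args);
  return e;
}

// A Call whose callee is a user-defined Function.
inline ExprPtr CallFn(std::shared_ptr<const Function> fn, std::vector<ExprPtr> args) {
  auto e = std::make_shared<Expr>();
  e->kind = Expr::kCall;
  e->fn = std::move(fn);
  e->args = std::move(args);
  return e;
}

using Env = std::map<std::string, double>;

inline double Eval(const ExprPtr& e, const Env& env) {
  switch (e->kind) {
    case Expr::kConst:
      return e->value;
    case Expr::kVar:
      return env.at(e->name);
    case Expr::kCall: {
      std::vector<double> vals;
      for (const ExprPtr& a : e->args) vals.push_back(Eval(a, env));
      if (e->fn) {
        // Calling a Function: bind each parameter to its argument's value,
        // then evaluate the body in that fresh environment.
        Env inner;
        for (std::size_t i = 0; i < e->fn->params.size(); ++i) {
          inner[e->fn->params[i]] = vals.at(i);
        }
        return Eval(e->fn->body, inner);
      }
      // Calling a primitive op (stand-in for tvm::Op).
      if (e->name == "add") return vals.at(0) + vals.at(1);
      if (e->name == "multiply") return vals.at(0) * vals.at(1);
      assert(false && "unknown primitive op");
      return 0;
    }
  }
  return 0;
}
```

With `fn(x, y) = add(multiply(x, x), y)`, the Function alone computes nothing; `CallFn(fn, {MakeConst(3), MakeConst(1)})` evaluates to 10.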
Let's also look at If, which is fairly simple:
/*!
 * \brief Condition expression
 *
 * Unlike traditional statement `if`s, the if evaluates
 * to the result of the branch taken.
 *
 * let x = if (true) { 1 } else { 0 }; // x is 1
 * let y = if (false) { 1 } else { 0 }; // y is 0
 *
 * \note This is similar to C's ternary operator.
 */
class If;
/*! \brief container of If */
class IfNode : public ExprNode {
 public:
  /*! \brief The condition */
  Expr cond;
  /*! \brief The expression evaluated when condition is true. */
  Expr true_branch;
  /*! \brief The expression evaluated when condition is false */
  Expr false_branch;
  // ...
};
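Because If is an expression, evaluating it yields the value of the branch taken, like C's ternary operator. The minimal sketch below uses hypothetical toy types (`ToyExpr`, `Lit`, `If`, `Eval` are mine, not TVM's) and counts visited nodes to show that only the condition and one branch are evaluated; a nonzero condition counts as true.

```cpp
#include <cassert>
#include <memory>
#include <utility>

// Hypothetical toy model (not TVM code): an expression is an int literal
// or an If whose three children are themselves expressions.
struct ToyExpr;
using ToyExprPtr = std::shared_ptr<const ToyExpr>;

struct ToyExpr {
  bool is_if = false;
  int literal = 0;                              // used when !is_if
  ToyExprPtr cond, true_branch, false_branch;   // used when is_if
};

inline ToyExprPtr Lit(int v) {
  auto e = std::make_shared<ToyExpr>();
  e->literal = v;
  return e;
}

inline ToyExprPtr If(ToyExprPtr c, ToyExprPtr t, ToyExprPtr f) {
  auto e = std::make_shared<ToyExpr>();
  e->is_if = true;
  e->cond = std::move(c);
  e->true_branch = std::move(t);
  e->false_branch = std::move(f);
  return e;
}

// Evaluates `e`; `visited` counts evaluated nodes, which shows that the
// branch not taken is never evaluated.
inline int Eval(const ToyExprPtr& e, int* visited) {
  ++*visited;
  if (!e->is_if) return e->literal;
  return Eval(e->cond, visited) != 0 ? Eval(e->true_branch, visited)
                                     : Eval(e->false_branch, visited);
}
```

Mirroring the header comment, `let x = if (true) { 1 } else { 0 }` corresponds to `Eval(If(Lit(1), Lit(1), Lit(0)), &n)`, which returns 1 after visiting three nodes.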
3. Op
Relay's Op fully reuses the Op mechanism from ir covered in the previous chapter and adds nothing new. The relevant header definitions are:
#include <tvm/ir/op.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/type.h>
namespace tvm {
namespace relay {
using Op = tvm::Op;
using OpNode = tvm::OpNode;
// ...
}  // namespace relay
}  // namespace tvm
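The Op mechanism boils down to a global registry keyed by operator name. In TVM itself, operators are registered through macros such as RELAY_REGISTER_OP and looked up with Op::Get("add"); the class below is only a hypothetical toy registry (`ToyOpRegistry`, `ToyOpNode` are mine) sketching the idea that every lookup of a name returns the same shared Op object, so ops can be compared by identity.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>

// Hypothetical toy model (not TVM code) of a global operator registry.
struct ToyOpNode {
  std::string name;
  int num_inputs;
};
using ToyOp = std::shared_ptr<const ToyOpNode>;

class ToyOpRegistry {
 public:
  // Single process-wide registry instance.
  static ToyOpRegistry& Global() {
    static ToyOpRegistry registry;
    return registry;
  }

  // Registers `name` once; registering it again returns the existing entry.
  ToyOp Register(const std::string& name, int num_inputs) {
    auto it = ops_.find(name);
    if (it != ops_.end()) return it->second;
    ToyOp op = std::make_shared<ToyOpNode>(ToyOpNode{name, num_inputs});
    ops_[name] = op;
    return op;
  }

  // Looks up a registered op; throws if the name is unknown.
  ToyOp Get(const std::string& name) const {
    auto it = ops_.find(name);
    if (it == ops_.end()) throw std::out_of_range("op not registered: " + name);
    return it->second;
  }

 private:
  std::map<std::string, ToyOp> ops_;
};
```

Sharing one object per name means a Call's op field can be checked against a specific operator with a pointer comparison.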