深入浅出 tvm – (6) Relay IR:一种 high-level 的中间表示

前面一章我们已经了解了 relay ir 和 tir 中公共的 ir 概念,现在我们继续深入一下 relay ir
relay ir 是 tvm 最高层次的 IR,也是最接近前端的一种计算图表示方式
相比传统的基于 Graph 的 High Level IR, Relay 引入了不少函数式编程的概念,带来了强大的表达能力。
但是相比官方文档,这个写的更好:https://zhuanlan.zhihu.com/p/446976730

1. Type

include/tvm/relay/type.h
relay 并没有定义任何新的 Type 类型,全部直接服用 tvm/ir 里的定义,因此这里不展开描述
// namespace update for backward compact
// will be removed later.
using AnyNode = tvm::tir::AnyNode;
using Any = tvm::tir::Any;
using Kind = TypeKind;
using Type = tvm::Type;
using TypeNode = tvm::TypeNode;
using TypeVar = tvm::TypeVar;
using TypeVarNode = tvm::TypeVarNode;
using GlobalTypeVar = tvm::GlobalTypeVar;
using GlobalTypeVarNode = tvm::GlobalTypeVarNode;
using TupleType = tvm::TupleType;
using TupleTypeNode = tvm::TupleTypeNode;
using TypeConstraint = tvm::TypeConstraint;
using TypeConstraintNode = tvm::TypeConstraintNode;
using FuncType = tvm::FuncType;
using FuncTypeNode = tvm::FuncTypeNode;
using IncompleteType = tvm::IncompleteType;
using IncompleteTypeNode = tvm::IncompleteTypeNode;
using RelayRefType = tvm::RelayRefType;
using RelayRefTypeNode = tvm::RelayRefTypeNode;
using TensorType = tvm::TensorType;
using TensorTypeNode = tvm::TensorTypeNode;
using TypeCall = tvm::TypeCall;
using TypeCallNode = tvm::TypeCallNode;
using TypeRelation = tvm::TypeRelation;
using TypeRelationNode = tvm::TypeRelationNode;
using TypeRelationFn = tvm::TypeRelationFn;
using TypeReporter = tvm::TypeReporter;
using TypeReporterNode = tvm::TypeReporterNode;

2. Expr

relay 里面的所有 Expr 都是继承自 ir 里面的 RelayExpr 类
namespace tvm {
namespace relay {

using Expr = tvm::RelayExpr;
relay 定义的 Expr 列表如下:
  1. Constant:常量
  2. Tuple:数组,由N个Expr构成
  3. Var:局部变量,计算图最常见的结构之一
  4. Call:算子调用,计算图最常见的结构之一
  5. Let:Let binding
  6. If:条件表达式
  7. TupleGetItem
  8. RefCreate
  9. RefRead
  10. RefWrite
  11. TempExpr
比如 Tuple
class Tuple : public Expr {
 public:
  /*!
   * \brief The constructor
   * \param fields The fields of a tuple.
   * \param span The source span of the expression.
   */
  TVM_DLL explicit Tuple(tvm::Array fields, Span span = Span());

  TVM_DEFINE_OBJECT_REF_METHODS(Tuple, RelayExpr, TupleNode);
  TVM_DEFINE_OBJECT_REF_COW_METHOD(TupleNode);
};

其他的比如 Constant/TupleGetItem 这样比较简单的 Expr, 就不介绍了,有兴趣的同学可以阅读 include/tvm/relay/expr.h
再来看下 Var 的定义
class VarNode : public ExprNode {
 public:
  /*!
   * \brief The unique identifier of the Var.
   *
   * vid will be preserved for the same Var during type inference
   * and other rewritings, while the VarNode might be recreated
   * to attach additional information.
   * This property can be used to keep track of parameter Var
   * information across passes.
   */
  Id vid;
  /*!
   * \brief type annotaion of the variable.
   * This field records user provided type annotation of the Var.
   * This field is optional and can be None.
   */
  Type type_annotation;
  // ...
}
Id 可以简单理解为一个字符串,是一个全局唯一的标识。你可以改下代码,增加一些调试,打印出来是这样的:
fc1_weight
fc1_bias
stage1_unit1_bn1_gamma
stage1_unit1_bn1_beta
另外这里有个地方没太理解,Var的值是存在哪里的?

接下来我们看函数,这是之前提到的 BaseFunc 在 Relay 中的具体实现
class FunctionNode : public BaseFuncNode {
 public:
  /*! \brief Function parameters */
  tvm::Array params;
  /*!
   * \brief
   * The expression which represents the computation of the function,
   * the expression may reference the parameters, and the type of it
   * or sub-expressions may reference the type variables.
   */
  Expr body;
  /*! \brief User annotated return type of the function. */
  Type ret_type;
  /*!
   * \brief Type parameters of the function.
   *  Enables the function to vary its type based on these.
   *  This corresponds to template paramaters in c++'s terminology.
   *
   * \note This can be usually empty for non-polymorphic functions.
   */
  tvm::Array type_params;
  // ...
}
FunctionNode 的定义还是比较清楚的,body 标识函数的计算逻辑,其实就是一个表达式 Expr,params是输入参数,ret_type 标识返回值类型
再看看 Call,也就是函数调用
/*!
 * \brief Call corresponds to operator invocation.
 *  Corresponds to the operator in computational graph terminology.
 */
class Call;
/*! \brief Call container. */
class CallNode : public ExprNode {
 protected:
  // CallNode uses own deleter to indirectly call non-recursive destructor
  Object::FDeleter saved_deleter_;
  static void Deleter_(Object* ptr);

 public:
  /*!
   * \brief The operator(function) being invoked
   *
   *  - It can be tvm::Op which corresponds to the primitive operators.
   *  - It can also be user defined functions (Function, GlobalVar, Var).
   */
  Expr op;

  /*! \brief The arguments(inputs) of the call */
  tvm::Array args;

  /*! \brief The additional attributes */
  Attrs attrs;

  /*!
   * \brief The type arguments passed to polymorphic(template) function.
   *
   * This is the advance feature that is only used when the function is
   * polymorphic. It is safe to be ignored in most cases. For example, in the
   * following code, the type_args of addone call is [int].
   *
   * \code
   *
   * template
   * T addone(T a) { return a + 1; }
   *
   * void main() {
   *   int x = addone(10);
   * }
   *
   * \endcode
   */
  tvm::Array type_args;
  // ...
}
但是我没看懂 CallNode 和 FunctionNode 具体的区别是啥?
FunctionNode 的继承链:tvm::runtime::Object -> tvm::BaseExprNode -> tvm::RelayExprNode -> BaseFuncNode -> tvm::relay::FunctionNode
CallNode 的继承链:tvm::runtime::Object -> tvm::BaseExprNode -> ExprNode -> tvm::relay::CallNode
再看看 If,这个是比较简单的
/*!
 * \brief Condition expression
 *
 * Unlike traditional statement `if`s, the if evalutes
 * to the result of the branch taken.
 *
 * let x = if (true) { 1 } else { 0 }; // x is 1
 * let y = if (false) { 1 } else { 0 }; // y is 0
 *
 * \note This is similar to C's ternary operator.
 */
class If;
/*! \brief container of If */
class IfNode : public ExprNode {
 public:
  /*! \brief The condition */
  Expr cond;
  /*! \brief The expression evaluated when condition is true. */
  Expr true_branch;
  /*! \brief The expression evaluated when condition is false */
  Expr false_branch;
  // ...
}

3. Op

relay OP 完全复用了上一章 ir 里的 OP 机制,没有增加任何新的东西,其头文件定义:
#include <tvm/ir/op.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/type.h>

namespace tvm {
namespace relay {

using Op = tvm::Op;
using OpNode = tvm::OpNode;

发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注