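In the previous chapter we covered the IR concepts shared by Relay IR and TIR; now let's dig deeper into Relay IR.
Relay IR is TVM's highest-level IR and the computation-graph representation closest to the frontend.
Compared with traditional graph-based high-level IRs, Relay introduces many functional-programming concepts, which give it strong expressive power.
That said, this article explains it better than the official documentation: https://zhuanlan.zhihu.com/p/446976730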
1. Type
include/tvm/relay/type.h
Relay does not define any new Type classes; it reuses all of the definitions in tvm/ir directly through aliases, so we won't go into them here:
```cpp
// namespace update for backward compat
// will be removed later.
using AnyNode = tvm::tir::AnyNode;
using Any = tvm::tir::Any;
using Kind = TypeKind;
using Type = tvm::Type;
using TypeNode = tvm::TypeNode;
using TypeVar = tvm::TypeVar;
using TypeVarNode = tvm::TypeVarNode;
using GlobalTypeVar = tvm::GlobalTypeVar;
using GlobalTypeVarNode = tvm::GlobalTypeVarNode;
using TupleType = tvm::TupleType;
using TupleTypeNode = tvm::TupleTypeNode;
using TypeConstraint = tvm::TypeConstraint;
using TypeConstraintNode = tvm::TypeConstraintNode;
using FuncType = tvm::FuncType;
using FuncTypeNode = tvm::FuncTypeNode;
using IncompleteType = tvm::IncompleteType;
using IncompleteTypeNode = tvm::IncompleteTypeNode;
using RelayRefType = tvm::RelayRefType;
using RelayRefTypeNode = tvm::RelayRefTypeNode;
using TensorType = tvm::TensorType;
using TensorTypeNode = tvm::TensorTypeNode;
using TypeCall = tvm::TypeCall;
using TypeCallNode = tvm::TypeCallNode;
using TypeRelation = tvm::TypeRelation;
using TypeRelationNode = tvm::TypeRelationNode;
using TypeRelationFn = tvm::TypeRelationFn;
using TypeReporter = tvm::TypeReporter;
using TypeReporterNode = tvm::TypeReporterNode;
```
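To see what "reuse via aliases" means in practice, here is a minimal, self-contained sketch. The `tvm::TensorType` struct below is a hypothetical stand-in (the real class is an ObjectRef carrying shape and dtype); only the `using` alias pattern mirrors the header above. Because an alias introduces no new type, code written against `tvm::` names accepts `relay::` values unchanged.

```cpp
#include <cassert>
#include <string>
#include <type_traits>

// Toy stand-ins (hypothetical, not TVM's real classes) showing how relay
// reuses tvm/ir types through `using` aliases instead of defining new ones.
namespace tvm {
struct TensorType {
  std::string dtype;  // e.g. "float32"; the real class also carries a shape
};

// A helper written only against the tvm:: name.
inline std::string DescribeType(const TensorType& t) { return "Tensor[" + t.dtype + "]"; }

namespace relay {
// Same pattern as include/tvm/relay/type.h: an alias, not a new type.
using TensorType = tvm::TensorType;
}  // namespace relay
}  // namespace tvm

// Compile-time check: both names denote exactly the same type.
static_assert(std::is_same<tvm::relay::TensorType, tvm::TensorType>::value,
              "relay::TensorType is just another name for tvm::TensorType");

// Because the types are identical, relay-side code can pass its values
// straight to tvm:: helpers with no conversion.
inline std::string DescribeRelayTensor(const std::string& dtype) {
  tvm::relay::TensorType t{dtype};
  return tvm::DescribeType(t);
}
```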
2. Expr
All Exprs in relay derive from RelayExpr in ir (the corresponding node classes derive from RelayExprNode); relay simply aliases it:
```cpp
namespace tvm {
namespace relay {

using Expr = tvm::RelayExpr;
// ...
}  // namespace relay
}  // namespace tvm
```
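The Exprs defined in include/tvm/relay/expr.h are:
- Constant: a constant
- Tuple: a tuple made up of N Exprs
- Var: a local variable, one of the most common structures in a computation graph
- Call: an operator call, one of the most common structures in a computation graph
- Let: a let binding
- If: a conditional expression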
- TupleGetItem
- RefCreate
- RefRead
- RefWrite
- TempExpr
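As a rough mental model, the Exprs above form a tree in which each node points to its children. The sketch below models four of the kinds (Constant, Var, Tuple, TupleGetItem) as hypothetical toy classes held by `std::shared_ptr`; real TVM nodes derive from RelayExprNode and use TVM's own reference-counted Object system, and the printed form only loosely imitates Relay's text format.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy model of a few relay Expr kinds (not TVM code).
struct ExprNode {
  virtual ~ExprNode() = default;
  virtual std::string Print() const = 0;  // render in a relay-like text form
};
using Expr = std::shared_ptr<const ExprNode>;

struct ConstantNode : ExprNode {
  int value;
  explicit ConstantNode(int v) : value(v) {}
  std::string Print() const override { return std::to_string(value); }
};

struct VarNode : ExprNode {
  std::string name_hint;
  explicit VarNode(std::string n) : name_hint(std::move(n)) {}
  std::string Print() const override { return "%" + name_hint; }
};

// A Tuple simply holds N child Exprs.
struct TupleNode : ExprNode {
  std::vector<Expr> fields;
  explicit TupleNode(std::vector<Expr> f) : fields(std::move(f)) {}
  std::string Print() const override {
    std::string s = "(";
    for (size_t i = 0; i < fields.size(); ++i) {
      if (i) s += ", ";
      s += fields[i]->Print();
    }
    return s + ")";
  }
};

// TupleGetItem refers to a tuple Expr plus a field index.
struct TupleGetItemNode : ExprNode {
  Expr tuple;
  int index;
  TupleGetItemNode(Expr t, int i) : tuple(std::move(t)), index(i) {}
  std::string Print() const override { return tuple->Print() + "." + std::to_string(index); }
};

// Builds the tree for (%x, 1).1 and prints it.
inline std::string BuildAndPrint() {
  Expr x = std::make_shared<VarNode>("x");
  Expr one = std::make_shared<ConstantNode>(1);
  Expr tup = std::make_shared<TupleNode>(std::vector<Expr>{x, one});
  Expr item = std::make_shared<TupleGetItemNode>(tup, 1);
  return item->Print();
}
```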
Take Tuple as an example:
```cpp
class Tuple : public Expr {
 public:
  /*!
   * \brief The constructor
   * \param fields The fields of a tuple.
   * \param span The source span of the expression.
   */
  TVM_DLL explicit Tuple(tvm::Array<relay::Expr> fields, Span span = Span());

  TVM_DEFINE_OBJECT_REF_METHODS(Tuple, RelayExpr, TupleNode);
  TVM_DEFINE_OBJECT_REF_COW_METHOD(TupleNode);
};
```
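Tuple follows TVM's split between a reference class (`Tuple`) and a node class (`TupleNode`). As I understand it, `TVM_DEFINE_OBJECT_REF_COW_METHOD` adds a `CopyOnWrite()` method that returns a mutable node, copying it first if other references share it. The sketch below imitates that idea with `std::shared_ptr`; the class layout and the `SameAs` helper are hypothetical, for illustration only, and not TVM's implementation.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical sketch of the ObjectRef/Node split behind Tuple / TupleNode,
// using std::shared_ptr instead of TVM's intrusive reference counting.
struct TupleNode {
  std::vector<std::string> fields;  // toy: fields are just names here
};

class Tuple {
 public:
  explicit Tuple(std::vector<std::string> fields)
      : data_(std::make_shared<TupleNode>(TupleNode{std::move(fields)})) {}

  // Read-only access, like ObjectRef::operator->.
  const TupleNode* operator->() const { return data_.get(); }

  // Mirrors the idea of the COW method: hand out a mutable node, cloning it
  // first if any other reference still shares it.
  TupleNode* CopyOnWrite() {
    if (data_.use_count() > 1) {
      data_ = std::make_shared<TupleNode>(*data_);
    }
    return data_.get();
  }

  // True when both references point at the very same node.
  bool SameAs(const Tuple& other) const { return data_ == other.data_; }

 private:
  std::shared_ptr<TupleNode> data_;
};
```

Usage: copying a `Tuple` only copies the reference, and the node is cloned lazily on the first write, so unrelated holders of the old tuple never observe the mutation.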
Simpler Exprs such as Constant and TupleGetItem are not covered here; if you're interested, read include/tvm/relay/expr.h.
Next, the definition of Var:
```cpp
class VarNode : public ExprNode {
 public:
  /*!
   * \brief The unique identifier of the Var.
   *
   * vid will be preserved for the same Var during type inference
   * and other rewritings, while the VarNode might be recreated
   * to attach additional information.
   * This property can be used to keep track of parameter Var
   * information across passes.
   */
  Id vid;
  /*!
   * \brief type annotation of the variable.
   * This field records user provided type annotation of the Var.
   * This field is optional and can be None.
   */
  Type type_annotation;
  // ...
};
```
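Id can be roughly thought of as a string identifier: it wraps a name (name_hint), and it is the Id object itself, rather than the string, that uniquely identifies the variable. If you modify the code to add some debug printing, the names look like this: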
fc1_weight
fc1_bias
stage1_unit1_bn1_gamma
stage1_unit1_bn1_beta
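One thing I haven't fully figured out here: where is a Var's value actually stored?
Next, let's look at functions. FunctionNode is Relay's concrete implementation of the BaseFunc mentioned earlier: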
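One way to think about this question, sketched under my own assumptions: a Var does not store a value at all. It is a placeholder that gets bound by a function parameter or a Let binding, so its value only exists in whatever environment evaluates the expression (for model weights such as fc1_weight, the values are typically supplied as params when the module is built or run). The hypothetical toy below keys an environment by the Id object's identity, which also shows why a rewritten Var that keeps the same vid still resolves, while two Vars that merely share a name stay distinct.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>

// Hypothetical toy model (not TVM code): an Id is an object whose identity
// (its address) distinguishes variables; name_hint is only for printing.
struct IdNode {
  std::string name_hint;
};
using Id = std::shared_ptr<const IdNode>;

struct Var {
  Id vid;  // like VarNode::vid, kept when a pass recreates the Var
};

inline Var MakeVar(const std::string& name) {
  return Var{std::make_shared<IdNode>(IdNode{name})};
}

// The Var itself holds no value. Values live in an environment keyed by Id,
// filled in by whatever binds the Var (a call, a let binding, bound params).
using Env = std::map<const IdNode*, double>;

inline double Lookup(const Env& env, const Var& v) {
  auto it = env.find(v.vid.get());
  if (it == env.end()) throw std::runtime_error("unbound var: " + v.vid->name_hint);
  return it->second;
}
```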
```cpp
class FunctionNode : public BaseFuncNode {
 public:
  /*! \brief Function parameters */
  tvm::Array<Var> params;
  /*!
   * \brief
   * The expression which represents the computation of the function,
   * the expression may reference the parameters, and the type of it
   * or sub-expressions may reference the type variables.
   */
  Expr body;
  /*! \brief User annotated return type of the function. */
  Type ret_type;
  /*!
   * \brief Type parameters of the function.
   *  Enables the function to vary its type based on these.
   *  This corresponds to template parameters in c++'s terminology.
   *
   * \note This can be usually empty for non-polymorphic functions.
   */
  tvm::Array<TypeVar> type_params;
  // ...
};
```
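The definition of FunctionNode is fairly clear: body holds the function's computation, which is simply an Expr; params are the input parameters; ret_type is the return type; and type_params holds the type parameters of polymorphic functions, which is usually empty.
Now let's look at Call, i.e., a function call: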
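The three main fields map directly onto a Relay function's text form. The hypothetical toy struct below keeps them as plain strings (the body is pre-printed text rather than an Expr tree) and shows how a function's signature follows from its parameter annotations and ret_type; it is a sketch, not TVM's printer.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical toy mirror of FunctionNode's main fields (not TVM code).
struct Param {
  std::string name;
  std::string type_annotation;  // like VarNode::type_annotation
};

struct Function {
  std::vector<Param> params;  // FunctionNode::params
  std::string body;           // FunctionNode::body (an Expr in real relay)
  std::string ret_type;       // FunctionNode::ret_type

  // Render in a relay-like text form.
  std::string Print() const {
    std::string s = "fn (";
    for (size_t i = 0; i < params.size(); ++i) {
      if (i) s += ", ";
      s += "%" + params[i].name + ": " + params[i].type_annotation;
    }
    return s + ") -> " + ret_type + " { " + body + " }";
  }

  // The function's type: parameter types in, ret_type out.
  std::string TypeSignature() const {
    std::string s = "fn(";
    for (size_t i = 0; i < params.size(); ++i) {
      if (i) s += ", ";
      s += params[i].type_annotation;
    }
    return s + ") -> " + ret_type;
  }
};
```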
```cpp
/*!
 * \brief Call corresponds to operator invocation.
 *  Corresponds to the operator in computational graph terminology.
 */
class Call;
/*! \brief Call container. */
class CallNode : public ExprNode {
 protected:
  // CallNode uses own deleter to indirectly call non-recursive destructor
  Object::FDeleter saved_deleter_;
  static void Deleter_(Object* ptr);

 public:
  /*!
   * \brief The operator(function) being invoked
   *
   *  - It can be tvm::Op which corresponds to the primitive operators.
   *  - It can also be user defined functions (Function, GlobalVar, Var).
   */
  Expr op;

  /*! \brief The arguments(inputs) of the call */
  tvm::Array<relay::Expr> args;

  /*! \brief The additional attributes */
  Attrs attrs;

  /*!
   * \brief The type arguments passed to polymorphic(template) function.
   *
   * This is the advanced feature that is only used when the function is
   * polymorphic. It is safe to be ignored in most cases. For example, in the
   * following code, the type_args of addone call is [int].
   *
   * \code
   *
   * template<typename T>
   * T addone(T a) { return a + 1; }
   *
   * void main() {
   *   int x = addone<int>(10);
   * }
   *
   * \endcode
   */
  tvm::Array<Type> type_args;
  // ...
};
```
However, I don't yet see what exactly distinguishes CallNode from FunctionNode. Comparing their inheritance chains:
FunctionNode inheritance chain: tvm::runtime::Object -> tvm::BaseExprNode -> tvm::RelayExprNode -> tvm::BaseFuncNode -> tvm::relay::FunctionNode
CallNode inheritance chain: tvm::runtime::Object -> tvm::BaseExprNode -> tvm::RelayExprNode (aliased as relay::ExprNode) -> tvm::relay::CallNode
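The chains above hint at the answer: FunctionNode goes through BaseFuncNode, so a Function is itself a callable value (a definition with params and a body), while a CallNode is an expression that applies some callee, such as an Op, a Function, a GlobalVar or a Var, to a list of args. The toy interpreter below is hypothetical and not TVM code; it makes the difference concrete: evaluating a Call evaluates its args and binds them to the callee's params before running the body, whereas a Function on its own computes nothing until it is called.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy interpreter contrasting Function (definition) with Call (application).
struct Expr;
using ExprPtr = std::shared_ptr<Expr>;

enum class Kind { kConst, kVar, kAdd, kFunction, kCall };

struct Expr {
  Kind kind = Kind::kConst;
  double value = 0;                 // kConst
  std::string name;                 // kVar
  std::vector<std::string> params;  // kFunction: parameter names
  ExprPtr body;                     // kFunction: computation over the params
  ExprPtr op;                       // kCall: the callee (a Function here)
  std::vector<ExprPtr> args;        // kCall: arguments; kAdd: two operands
};

inline ExprPtr Const(double v) {
  auto e = std::make_shared<Expr>(); e->kind = Kind::kConst; e->value = v; return e;
}
inline ExprPtr Var(std::string n) {
  auto e = std::make_shared<Expr>(); e->kind = Kind::kVar; e->name = std::move(n); return e;
}
inline ExprPtr Add(ExprPtr a, ExprPtr b) {
  auto e = std::make_shared<Expr>(); e->kind = Kind::kAdd;
  e->args = {std::move(a), std::move(b)}; return e;
}
inline ExprPtr Function(std::vector<std::string> params, ExprPtr body) {
  auto e = std::make_shared<Expr>(); e->kind = Kind::kFunction;
  e->params = std::move(params); e->body = std::move(body); return e;
}
inline ExprPtr Call(ExprPtr op, std::vector<ExprPtr> args) {
  auto e = std::make_shared<Expr>(); e->kind = Kind::kCall;
  e->op = std::move(op); e->args = std::move(args); return e;
}

using Env = std::map<std::string, double>;

inline double Eval(const ExprPtr& e, const Env& env) {
  switch (e->kind) {
    case Kind::kConst: return e->value;
    case Kind::kVar: return env.at(e->name);
    case Kind::kAdd: return Eval(e->args[0], env) + Eval(e->args[1], env);
    case Kind::kCall: {
      // Evaluate the args in the caller's environment, bind them to the
      // callee's params, then evaluate the callee's body.
      const ExprPtr& fn = e->op;
      if (fn->kind != Kind::kFunction || fn->params.size() != e->args.size())
        throw std::runtime_error("toy Call expects a Function with matching arity");
      Env callee_env;
      for (size_t i = 0; i < fn->params.size(); ++i)
        callee_env[fn->params[i]] = Eval(e->args[i], env);
      return Eval(fn->body, callee_env);
    }
    case Kind::kFunction:
      // Real relay would yield a closure value; this toy only returns numbers.
      throw std::runtime_error("a Function only computes when it is Called");
  }
  throw std::logic_error("unknown kind");
}
```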
Next is If, which is quite simple:
```cpp
/*!
 * \brief Condition expression
 *
 * Unlike traditional statement `if`s, the if evaluates
 * to the result of the branch taken.
 *
 * let x = if (true) { 1 } else { 0 }; // x is 1
 * let y = if (false) { 1 } else { 0 }; // y is 0
 *
 * \note This is similar to C's ternary operator.
 */
class If;
/*! \brief container of If */
class IfNode : public ExprNode {
 public:
  /*! \brief The condition */
  Expr cond;
  /*! \brief The expression evaluated when condition is true. */
  Expr true_branch;
  /*! \brief The expression evaluated when condition is false */
  Expr false_branch;
  // ...
};
```
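To illustrate the "if evaluates to a value" semantics from the comment, here is a hypothetical toy in which every node evaluates to an int. Only the taken branch is evaluated, and the If node's value is that branch's value, much like C's ternary operator. The `CountingConst` node exists only to make the "untaken branch is skipped" behavior observable.

```cpp
#include <cassert>
#include <memory>
#include <utility>

// Hypothetical toy nodes (not TVM code): every node evaluates to an int.
struct Node {
  virtual ~Node() = default;
  virtual int Eval() const = 0;
};
using NodePtr = std::shared_ptr<const Node>;

struct Const : Node {
  int v;
  explicit Const(int value) : v(value) {}
  int Eval() const override { return v; }
};

// Same as Const, but records how many times it was evaluated.
struct CountingConst : Node {
  int v;
  mutable int evals = 0;
  explicit CountingConst(int value) : v(value) {}
  int Eval() const override {
    ++evals;
    return v;
  }
};

// Mirrors IfNode's fields. The If is an expression: only the taken branch is
// evaluated, and its value becomes the value of the whole If.
struct If : Node {
  NodePtr cond, true_branch, false_branch;
  If(NodePtr c, NodePtr t, NodePtr f)
      : cond(std::move(c)), true_branch(std::move(t)), false_branch(std::move(f)) {}
  int Eval() const override {
    return cond->Eval() != 0 ? true_branch->Eval() : false_branch->Eval();
  }
};
```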
3. Op
Relay's Op fully reuses the Op mechanism from ir covered in the previous chapter and adds nothing new. Its header:
```cpp
#include <tvm/ir/op.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/type.h>

namespace tvm {
namespace relay {

using Op = tvm::Op;
using OpNode = tvm::OpNode;
// ...
}  // namespace relay
}  // namespace tvm
```
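The core of the ir Op mechanism is a global registry: each operator is registered once by name and retrieved by name (in C++ via `Op::Get`), so every Call to, say, "add" refers to the same Op object. The sketch below is a hypothetical toy registry; `OpRegistry`, `Register` and the `fcompute` member are illustrative stand-ins (in TVM, things like compute functions are attached as op attributes rather than struct members).

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>

// Hypothetical toy operator description (not TVM's OpNode).
struct OpNode {
  std::string name;
  int num_inputs;
  std::function<double(double, double)> fcompute;  // toy stand-in for an op attribute
};
using Op = std::shared_ptr<const OpNode>;

// Global name -> Op table: register once, look up by name everywhere.
class OpRegistry {
 public:
  static OpRegistry& Global() {
    static OpRegistry registry;
    return registry;
  }

  void Register(const std::string& name, int num_inputs,
                std::function<double(double, double)> f) {
    ops_[name] = std::make_shared<OpNode>(OpNode{name, num_inputs, std::move(f)});
  }

  // Every lookup of the same name returns the same shared Op object.
  Op Get(const std::string& name) const {
    auto it = ops_.find(name);
    if (it == ops_.end()) throw std::runtime_error("operator not registered: " + name);
    return it->second;
  }

 private:
  std::map<std::string, Op> ops_;
};
```

Usage: register "add" once at startup, then any pass or Call site that asks for "add" gets the identical object, which is what lets passes compare ops by identity.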