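In the previous chapter we covered the IR concepts shared by Relay IR and TIR; now let's dig deeper into Relay IR.
Relay IR is TVM's highest-level IR, and the computational-graph representation closest to the frontend.
Compared with traditional graph-based high-level IRs, Relay borrows quite a few concepts from functional programming, which gives it strong expressive power.
That said, this article explains it better than the official documentation: https://zhuanlan.zhihu.com/p/446976730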
1. Type
include/tvm/relay/type.h
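Relay does not define any new Type classes. It reuses the definitions in tvm/ir directly through the using aliases below (only Any/AnyNode come from tvm::tir), so we won't go into them here.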
// namespace update for backward compact
// will be removed later.
using AnyNode = tvm::tir::AnyNode;
using Any = tvm::tir::Any;
using Kind = TypeKind;
using Type = tvm::Type;
using TypeNode = tvm::TypeNode;
using TypeVar = tvm::TypeVar;
using TypeVarNode = tvm::TypeVarNode;
using GlobalTypeVar = tvm::GlobalTypeVar;
using GlobalTypeVarNode = tvm::GlobalTypeVarNode;
using TupleType = tvm::TupleType;
using TupleTypeNode = tvm::TupleTypeNode;
using TypeConstraint = tvm::TypeConstraint;
using TypeConstraintNode = tvm::TypeConstraintNode;
using FuncType = tvm::FuncType;
using FuncTypeNode = tvm::FuncTypeNode;
using IncompleteType = tvm::IncompleteType;
using IncompleteTypeNode = tvm::IncompleteTypeNode;
using RelayRefType = tvm::RelayRefType;
using RelayRefTypeNode = tvm::RelayRefTypeNode;
using TensorType = tvm::TensorType;
using TensorTypeNode = tvm::TensorTypeNode;
using TypeCall = tvm::TypeCall;
using TypeCallNode = tvm::TypeCallNode;
using TypeRelation = tvm::TypeRelation;
using TypeRelationNode = tvm::TypeRelationNode;
using TypeRelationFn = tvm::TypeRelationFn;
using TypeReporter = tvm::TypeReporter;
using TypeReporterNode = tvm::TypeReporterNode;
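Because these are `using` aliases rather than new definitions, `relay::TensorType` and `tvm::TensorType` name exactly the same class. The sketch below demonstrates that property of C++ aliases with hypothetical stand-in classes in a `toy` namespace (not the real TVM headers), so it compiles without TVM.

```cpp
#include <cassert>
#include <type_traits>

// Hypothetical stand-ins for the tvm/ir classes; only the naming pattern is borrowed.
namespace toy {
struct TypeNode {
  virtual ~TypeNode() = default;  // polymorphic, so dynamic_cast works below
};
struct TensorTypeNode : TypeNode {};

namespace relay {
// Same pattern as relay/type.h: re-export the ir types instead of redefining them.
using TypeNode = toy::TypeNode;
using TensorTypeNode = toy::TensorTypeNode;
}  // namespace relay
}  // namespace toy

// A using-alias introduces a new name, not a new type.
static_assert(std::is_same<toy::relay::TensorTypeNode, toy::TensorTypeNode>::value,
              "relay::TensorTypeNode is another name for toy::TensorTypeNode");

// Code written against the ir names therefore accepts relay objects unchanged.
inline bool IsTensorType(const toy::TypeNode* t) {
  return dynamic_cast<const toy::TensorTypeNode*>(t) != nullptr;
}
```

This is why Relay passes can freely mix the `relay::` and `tvm::` spellings of a type.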
2. Expr
Every Expr in Relay derives from the RelayExpr class defined in ir:
namespace tvm {
namespace relay {
using Expr = tvm::RelayExpr;
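Relay defines the following Exprs:
- Constant: a constant (tensor) value
- Tuple: a tuple made up of N Exprs
- Var: a local variable, one of the most common structures in a computational graph
- Call: an operator call, one of the most common structures in a computational graph
- Let: a let binding
- If: a conditional expression
- TupleGetItem: extracts one field of a tuple by index
- RefCreate: creates a mutable reference
- RefRead: reads the value held by a reference
- RefWrite: writes a new value into a reference
- TempExpr: a temporary expression used only inside a pass, which must be realized back into a regular Expr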
For example, Tuple:
class Tuple : public Expr {
public:
/*!
* \brief The constructor
* \param fields The fields of a tuple.
* \param span The source span of the expression.
*/
TVM_DLL explicit Tuple(tvm::Array<relay::Expr> fields, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Tuple, RelayExpr, TupleNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(TupleNode);
};
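`TVM_DEFINE_OBJECT_REF_METHODS` and `TVM_DEFINE_OBJECT_REF_COW_METHOD` follow TVM's split between a reference class (`Tuple`) and the node that holds the data (`TupleNode`). Below is a simplified model of that pattern built on `std::shared_ptr`, assuming toy class names and reducing `fields` to `std::vector<int>`; it only illustrates the copy-on-write idea, not TVM's actual implementation.

```cpp
#include <cassert>
#include <memory>
#include <utility>
#include <vector>

// Toy model of TVM's ObjectRef/Node split (hypothetical, not TVM's real classes).
struct TupleNode {
  std::vector<int> fields;  // stands in for Array<relay::Expr>
};

class Tuple {
 public:
  explicit Tuple(std::vector<int> fields)
      : data_(std::make_shared<TupleNode>(TupleNode{std::move(fields)})) {}

  // Read-only access to the node, like the operator-> a reference class provides.
  const TupleNode* operator->() const { return data_.get(); }

  // Copy-on-write: mutate in place if this reference is the sole owner,
  // otherwise clone the node first so other references keep the old value.
  TupleNode* CopyOnWrite() {
    if (data_.use_count() > 1) data_ = std::make_shared<TupleNode>(*data_);
    return data_.get();
  }

  // Reference equality: do both refs point at the same node?
  bool same_as(const Tuple& other) const { return data_ == other.data_; }

 private:
  std::shared_ptr<TupleNode> data_;
};
```

Copying a `Tuple` is cheap because only the pointer is copied; the node is cloned lazily, the first time a shared reference is mutated.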
Other relatively simple Exprs such as Constant and TupleGetItem are not covered here; if you are interested, read include/tvm/relay/expr.h.
Next, let's look at the definition of Var:
class VarNode : public ExprNode {
public:
/*!
* \brief The unique identifier of the Var.
*
* vid will be preserved for the same Var during type inference
* and other rewritings, while the VarNode might be recreated
* to attach additional information.
* This property can be used to keep track of parameter Var
* information across passes.
*/
Id vid;
/*!
* \brief type annotation of the variable.
* This field records user provided type annotation of the Var.
* This field is optional and can be None.
*/
Type type_annotation;
// ...
};
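Id can roughly be thought of as the variable's name string, which IdNode stores in name_hint. Strictly speaking, uniqueness comes from the Id object itself: name_hint is only a hint for printing and is not used for equality. If you tweak the code to print the names for debugging, you get output like this: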
fc1_weight
fc1_bias
stage1_unit1_bn1_gamma
stage1_unit1_bn1_beta
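Because name_hint is only a hint, two variables that both print as `fc1_weight` are still different variables unless they share the same Id. A toy model of that rule, with hypothetical `IdNode`/`Var` structs rather than TVM's classes, shows why a pass can recreate a VarNode and still track it through `vid`:

```cpp
#include <cassert>
#include <memory>
#include <string>

// Toy model (hypothetical): an Id is compared by identity; name_hint is only for printing.
struct IdNode {
  std::string name_hint;
};
using Id = std::shared_ptr<IdNode>;

struct Var {
  Id vid;
  std::string type_annotation;  // extra info a pass might attach

  // Same variable iff both Vars share the same Id object, regardless of names.
  bool SameVariable(const Var& other) const { return vid == other.vid; }
};

inline Var MakeVar(const std::string& name) {
  return Var{std::make_shared<IdNode>(IdNode{name}), ""};
}

// A pass may rebuild a Var to attach information but keeps its vid.
inline Var WithAnnotation(const Var& v, const std::string& type) {
  return Var{v.vid, type};
}
```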
One thing I didn't quite understand here: where is a Var's value stored?
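One way to think about this question: a Var holds no value at all. It is a name that gets bound, by a function parameter or a Let, and its value is looked up in an environment during evaluation; in practice, weights such as fc1_weight are typically passed in as params when the model is built or run. The toy interpreter below sketches that idea for constants, variables, addition and Let. The `Expr`/`Eval`/`Make*` names are hypothetical, and variables are keyed by name for brevity (real Relay tracks them by identity).

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>

struct Expr;
using ExprPtr = std::shared_ptr<const Expr>;

// Toy expression (hypothetical): constant, variable, addition, or let-binding.
struct Expr {
  enum Kind { kConst, kVar, kAdd, kLet } kind;
  double value;      // kConst: the constant
  std::string name;  // kVar: variable name; kLet: the variable being bound
  ExprPtr lhs, rhs;  // kAdd: operands; kLet: lhs = bound value, rhs = body
};

// The Var carries only a name; values live in the environment.
using Env = std::map<std::string, double>;

inline ExprPtr MakeConst(double v) {
  return std::make_shared<Expr>(Expr{Expr::kConst, v, "", nullptr, nullptr});
}
inline ExprPtr MakeVar(const std::string& n) {
  return std::make_shared<Expr>(Expr{Expr::kVar, 0.0, n, nullptr, nullptr});
}
inline ExprPtr MakeAdd(ExprPtr a, ExprPtr b) {
  return std::make_shared<Expr>(Expr{Expr::kAdd, 0.0, "", a, b});
}
inline ExprPtr MakeLet(const std::string& n, ExprPtr value, ExprPtr body) {
  return std::make_shared<Expr>(Expr{Expr::kLet, 0.0, n, value, body});
}

inline double Eval(const ExprPtr& e, const Env& env) {
  switch (e->kind) {
    case Expr::kConst:
      return e->value;
    case Expr::kVar: {
      auto it = env.find(e->name);  // look the value up; the Var stores none
      if (it == env.end()) throw std::runtime_error("unbound variable: " + e->name);
      return it->second;
    }
    case Expr::kAdd:
      return Eval(e->lhs, env) + Eval(e->rhs, env);
    case Expr::kLet: {
      Env inner = env;  // the binding is visible only inside the body
      inner[e->name] = Eval(e->lhs, env);
      return Eval(e->rhs, inner);
    }
  }
  throw std::logic_error("unknown expression kind");
}
```

For example, `let x = 1; x + 2` evaluates to 3, while evaluating a bare Var with an empty environment throws, because nothing has bound it.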
Next, let's look at functions. FunctionNode is the concrete Relay implementation of the BaseFunc mentioned earlier:
class FunctionNode : public BaseFuncNode {
public:
/*! \brief Function parameters */
tvm::Array<Var> params;
/*!
* \brief
* The expression which represents the computation of the function,
* the expression may reference the parameters, and the type of it
* or sub-expressions may reference the type variables.
*/
Expr body;
/*! \brief User annotated return type of the function. */
Type ret_type;
/*!
* \brief Type parameters of the function.
* Enables the function to vary its type based on these.
* This corresponds to template parameters in c++'s terminology.
*
* \note This can be usually empty for non-polymorphic functions.
*/
tvm::Array<TypeVar> type_params;
// ...
};
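FunctionNode's definition is fairly clear: body describes the function's computation and is itself just an Expr, params are the input parameters, ret_type is the user-annotated return type, and type_params holds the type parameters of a polymorphic function.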
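Since body is an ordinary Expr, it may also reference variables that are not in params; those are free variables that some enclosing scope must bind. The sketch below models FunctionNode with hypothetical toy structs (variables and named op calls only) and computes a function's free variables, similar in spirit to the free-variable analysis Relay provides (relay.analysis.free_vars).

```cpp
#include <cassert>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

struct Expr;
using ExprPtr = std::shared_ptr<const Expr>;

// Toy expression (hypothetical): a variable, or a call of a named op on arguments.
struct Expr {
  std::string var_name;       // non-empty means this node is a variable
  std::string op;             // op name when this node is a call
  std::vector<ExprPtr> args;  // call arguments
};

// Toy counterpart of FunctionNode: params, body, and an optional return type.
struct FunctionNode {
  std::vector<std::string> params;
  ExprPtr body;
  std::string ret_type;  // empty means no user annotation
};

inline ExprPtr V(const std::string& name) {
  return std::make_shared<Expr>(Expr{name, "", {}});
}
inline ExprPtr CallOp(const std::string& op, std::vector<ExprPtr> args) {
  return std::make_shared<Expr>(Expr{"", op, std::move(args)});
}

// Collect every variable referenced anywhere in e.
inline void CollectVars(const ExprPtr& e, std::set<std::string>* out) {
  if (!e->var_name.empty()) {
    out->insert(e->var_name);
    return;
  }
  for (const ExprPtr& a : e->args) CollectVars(a, out);
}

// Variables used by the body that are not parameters are free:
// they must be bound by some enclosing scope.
inline std::set<std::string> FreeVars(const FunctionNode& f) {
  std::set<std::string> used;
  CollectVars(f.body, &used);
  for (const std::string& p : f.params) used.erase(p);
  return used;
}
```

For `fn(x, y) { add(x, multiply(y, w)) }`, only `w` is free.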
Now let's look at Call, i.e. a function invocation:
/*!
* \brief Call corresponds to operator invocation.
* Corresponds to the operator in computational graph terminology.
*/
class Call;
/*! \brief Call container. */
class CallNode : public ExprNode {
protected:
// CallNode uses own deleter to indirectly call non-recursive destructor
Object::FDeleter saved_deleter_;
static void Deleter_(Object* ptr);
public:
/*!
* \brief The operator(function) being invoked
*
* - It can be tvm::Op which corresponds to the primitive operators.
* - It can also be user defined functions (Function, GlobalVar, Var).
*/
Expr op;
/*! \brief The arguments(inputs) of the call */
tvm::Array<relay::Expr> args;
/*! \brief The additional attributes */
Attrs attrs;
/*!
* \brief The type arguments passed to polymorphic(template) function.
*
* This is the advance feature that is only used when the function is
* polymorphic. It is safe to be ignored in most cases. For example, in the
* following code, the type_args of addone call is [int].
*
* \code
*
* template<typename T>
* T addone(T a) { return a + 1; }
*
* void main() {
* int x = addone<int>(10);
* }
*
* \endcode
*/
tvm::Array<Type> type_args;
// ...
};
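At first I didn't see what the concrete difference between CallNode and FunctionNode is. In short, a FunctionNode defines a computation (params plus body), while a CallNode applies an op or a function to concrete args. Their inheritance chains also differ:
FunctionNode's inheritance chain: tvm::runtime::Object -> tvm::BaseExprNode -> tvm::RelayExprNode -> tvm::BaseFuncNode -> tvm::relay::FunctionNode
CallNode's inheritance chain: tvm::runtime::Object -> tvm::BaseExprNode -> tvm::RelayExprNode (aliased as relay::ExprNode) -> tvm::relay::CallNode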
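The difference is easiest to see by evaluating both: a Function is a value that describes a computation and does nothing by itself, while a Call supplies arguments to its op (a primitive Op or a Function) and produces a result. The toy model below uses hypothetical `Function`/`Callee`/`EvalCall` names and, for brevity, represents a function body as a C++ callable instead of an Expr tree.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

using Env = std::map<std::string, double>;

// Toy Function (hypothetical): parameter names plus a body that reads them from Env.
// Defining it computes nothing; it is just a value describing a computation.
struct Function {
  std::vector<std::string> params;
  std::function<double(const Env&)> body;
};

// Toy callee of a Call: either a primitive op (like tvm::Op) or a user Function.
struct Callee {
  std::function<double(const std::vector<double>&)> primitive;  // set for primitive ops
  std::shared_ptr<const Function> fn;                          // set for user functions
};

// Evaluating a Call: bind each parameter to its argument, then run the body.
inline double EvalCall(const Callee& op, const std::vector<double>& args) {
  if (op.primitive) return op.primitive(args);
  if (!op.fn || op.fn->params.size() != args.size())
    throw std::invalid_argument("callee/argument mismatch");
  Env env;
  for (std::size_t i = 0; i < args.size(); ++i) env[op.fn->params[i]] = args[i];
  return op.fn->body(env);
}
```

The same Function can be the op of many Calls, each with different arguments, which mirrors how CallNode::op may point at a tvm::Op, a Function, a GlobalVar or a Var.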
Now let's look at If, which is fairly simple:
/*!
* \brief Condition expression
*
* Unlike traditional statement `if`s, the if evaluates
* to the result of the branch taken.
*
* let x = if (true) { 1 } else { 0 }; // x is 1
* let y = if (false) { 1 } else { 0 }; // y is 0
*
* \note This is similar to C's ternary operator.
*/
class If;
/*! \brief container of If */
class IfNode : public ExprNode {
public:
/*! \brief The condition */
Expr cond;
/*! \brief The expression evaluated when condition is true. */
Expr true_branch;
/*! \brief The expression evaluated when condition is false */
Expr false_branch;
// ...
};
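Because If is an expression, evaluating it yields the value of the branch taken, and only that branch is evaluated. The toy evaluator below (hypothetical names, integer constants only) mirrors the `let x = if (true) { 1 } else { 0 }` example from the comment, and counts evaluated nodes to show that the untaken branch is skipped.

```cpp
#include <cassert>
#include <memory>

struct Expr;
using ExprPtr = std::shared_ptr<const Expr>;

// Toy expression (hypothetical): an integer constant, or an If node.
struct Expr {
  bool is_if = false;
  int value = 0;  // constant value; as a condition, nonzero means true
  ExprPtr cond, true_branch, false_branch;
};

inline ExprPtr Lit(int v) {
  auto e = std::make_shared<Expr>();
  e->value = v;
  return e;
}

inline ExprPtr MakeIf(ExprPtr c, ExprPtr t, ExprPtr f) {
  auto e = std::make_shared<Expr>();
  e->is_if = true;
  e->cond = c;
  e->true_branch = t;
  e->false_branch = f;
  return e;
}

// If evaluates to the value of the branch taken, and the other branch is
// never evaluated (like C's ?: operator). `visited` counts evaluated nodes.
inline int Eval(const ExprPtr& e, int* visited) {
  ++*visited;
  if (!e->is_if) return e->value;
  return Eval(e->cond, visited) != 0 ? Eval(e->true_branch, visited)
                                     : Eval(e->false_branch, visited);
}
```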
3. Op
Relay's Op fully reuses the Op mechanism from ir covered in the previous chapter and adds nothing new. Its header is:
#include <tvm/ir/op.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/type.h>
namespace tvm {
namespace relay {
using Op = tvm::Op;
using OpNode = tvm::OpNode;
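As a reminder of how that mechanism behaves: operators are registered once by name (with TVM_REGISTER_OP in TVM) and looked up with Op::Get, so every Call to "add" refers to the same Op object. The registry below is a hypothetical simplification of that idea (`ToyOpRegistry` and its `num_inputs` attribute are made up for illustration), not TVM's real registry.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>

// Toy operator node (hypothetical): a name plus one example attribute.
struct OpNode {
  std::string name;
  int num_inputs;
};

// Toy global registry: each operator name maps to exactly one OpNode,
// so Ops can be compared by pointer, in the spirit of Op::Get.
class ToyOpRegistry {
 public:
  static ToyOpRegistry& Global() {
    static ToyOpRegistry registry;
    return registry;
  }

  // In this toy, registering an existing name returns the existing node unchanged.
  const OpNode* Register(const std::string& name, int num_inputs) {
    std::unique_ptr<OpNode>& slot = ops_[name];
    if (!slot) slot.reset(new OpNode{name, num_inputs});
    return slot.get();
  }

  const OpNode* Get(const std::string& name) const {
    auto it = ops_.find(name);
    if (it == ops_.end()) throw std::out_of_range("operator not registered: " + name);
    return it->second.get();
  }

 private:
  std::map<std::string, std::unique_ptr<OpNode>> ops_;
};
```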