在深入 OptimizeImpl 阶段,也就是 Pass 优化之前,我们先了解一下 Relay 阶段的一些基本概念
OptimizeImpl 阶段最主要的工作就是图优化,图其实就是一种高级的表示,tvm 和图相关的概念好好几个:
- relay ir:这是 tvm 最 high level 的表示,另外 relay ir -> tir 的过程中,又会依赖 topi 和 te 这2种特定抽象的中间表示
- topi:TVM Operator Inventory,TOPI 提供了比 tir 具有更高抽象的 numpy 风格的,通用操作和调度
-
- te:Tensor Expression,张量表达式
- tir:最接近目标代码的中间表示
- ir:relay ir 和 tir 的一些公共基础结构,和上述2种ir不一样,并不是一个完整独立的抽象
本章我们先来了解下 IR 这个 relay ir 和 tir 最公共的基础设施,后续会依次介绍 relay ir、tir、topi、te
代码目录:
- 代码:src/ir
- 头文件:include/tvm/ir
编程语言最基本的核心概念就3个:类型、运算符、表达式,在 IR 这里分别对应 Type, OP, Expr
1.1 Type
Type 相关的定义都在 include/tvm/ir/type.h ,Type 包括基础的整型/浮点型等,也包括函数类型等相对复杂的类型。
这里我们介绍2种基本的类型:
- PrimType:最原始的 Type,可以直接映射到 low-level IR 的基本数据类型
- FuncType:函数类型
PrimType 可以在这上面做一些 Low-level 优化
定义如下:
class PrimTypeNode : public TypeNode { public: /*! * \brief The corresponding dtype field. */ runtime::DataType dtype; }
可以看到 PrimType 就一个数据成员,runtime::DataType,这个是 runtime 最底层的概念,代码在 include/tvm/runtime/data_type.h
/*! * \brief Runtime primitive data type. * * This class is a thin wrapper of DLDataType. * We also make use of DataType in compiler to store quick hint */ class DataType { public: /*! * \brief Type code for the DataType. * * DLPack consistency: * 1) kInt is consistent with kDLInt * 2) kUInt is consistent with kDLUInt * 3) kFloat is consistent with kDLFloat */ enum TypeCode { kInt = kDLInt, kUInt = kDLUInt, kFloat = kDLFloat, kHandle = TVMArgTypeCode::kTVMOpaqueHandle, kBFloat = kDLBfloat, kCustomBegin = 129 }; /*! \brief default constructor */ DataType() { data_ = DataType::Void(); } /*! * \brief Constructor * \param dtype The DLDataType */ explicit DataType(DLDataType dtype) : data_(dtype) {}
再看 FuncType 的定义,FuncType 记录了函数参数类型/返回值类型/模板参数等信息,如下:
/*! * \brief Function type. * * We support polymorphic function type. * This can be roughly viewed as template function in C++. * * \sa FuncType, TypeVar, TypeConstraint */ class FuncTypeNode : public TypeNode { public: /*! \brief type type of arguments */ Array arg_types; /*! \brief The type of return value. */ Type ret_type; // The following fields are used in polymorphic(template) functions // For normal functions, the following two fields will be empty. /*! \brief The type parameters of the function */ Array type_params; /*! * \brief potential constraint the type need to obey * \note this field is reserved for futher purposes. */ Array type_constraints;
再来看下 TensorType, 代码如下
/*! * \brief This is the most commonly used type in relay. * TensorType have a fixed dimension, data type. * * The elements of shape can be either IntImm(constant integer), * or any symbolic integer expression. * The symbolic integer allows generic shape inference in certain cases. * \sa TensorType */ class TensorTypeNode : public BaseTensorTypeNode { public: /*! * \brief The shape of the tensor, * represented by PrimExpr(tvm::Expr). */ Array shape; /*! \brief The content data type */ DataType dtype;
需要注意,shape 也是类型的一部分,所以 TVM 中的类型推断的意思自然也包括了 shape 推断。
其他一些比较重要的 Type
- 量化相关的Type, 定义于 include/tvm/ir/affine_type.h
- ADT 相关,定义于 include/tvm/ir/adt.h
1.2 Expr
了解了Type, 再看Expr,Expr 包括简单的定义一个字面值,也包括定义一个复杂的函数。
ir 中定义的 Expr 有:
- PrimExpr:原始 Expr,主要在 tir 模块中定义,可以相对直接地映射到 low-level code
- RelayExpr:所有的非PrimExpr,比如 tensor,function,adt 等其他一等公民
- GlobalVar:全局变量,只有函数才会引用 GlobalVar,通常用来实现函数的递归调用
- IntImm:Int64 常量
- FloatImm:double 常量
- Bool:布尔常量
- Integer:继承自 IntImm
- Range:表示范围
我们接下来主要看下 PrimExpr
/*! * \brief Reference to PrimExprNode. * \sa PrimExprNode */ class PrimExpr : public BaseExpr { public: TVM_DLL PrimExpr(int32_t value); // NOLINT(*) TVM_DLL PrimExpr(float value); // NOLINT(*) // ... };
可以看到 PrimExpr 可以直接从 float 或者 int32_t 类型直接转换过来,另外支持常见的各种基础运算
其实现代码在:src/tir/op/op.cc 里面
TVM_DLL PrimExpr operator+(PrimExpr a, PrimExpr b); TVM_DLL PrimExpr operator-(PrimExpr a, PrimExpr b); TVM_DLL PrimExpr operator-(PrimExpr a); TVM_DLL PrimExpr operator*(PrimExpr a, PrimExpr b);
其实这里有点奇怪,这个代码放 src/ir/expr.cc 更合适的
我们看一下 FloatImm 的代码实现来体会 Expr
/*! * \brief Constant floating point literals in the program. * \sa FloatImm */ class FloatImmNode : public PrimExprNode { public: /*! \brief The constant value content. */ double value;
FloatImm 的意思是浮点数字面值表达式,所以其记录了一个 double 成员。
再看下 BaseFunc, 这个是 relay.Function 和 tir.PrimFunc的共同基类, 定义于 include/tvm/ir/function.h
class BaseFuncNode : public RelayExprNode { public: /*! \brief Additional attributes storing the meta-data */ DictAttrs attrs; // ... } /*! * \brief Managed reference to BaseFuncNode. * \sa BaseFuncNode */ class BaseFunc : public RelayExpr { public: TVM_DEFINE_OBJECT_REF_METHODS(BaseFunc, RelayExpr, BaseFuncNode); };
DictAttrs 就是一个 Map<String, xx>,记录了函数的一些元数据,其他的没有更多描述了
1.3 Op
Op 就是我们最常说的算子,比如 nn.conv2d 就是一个算子
Op 是表示所有系统定义的原始算子/内联函数的通用类。开发者可以向系统注册新的 Ops,以及它们的额外属性(例如 Op 是否为元素)。
在 TVM 中,无论是 relay ir 还是 tir 的 op 都定义为一种 RelayExpr,如下
/*! * \brief Primitive Op(builtin intrinsics) * * This data structure stores the meta-data * about primitive operators that can be invoked via Call. * * Low-level IR intrinsics(such as libc.expf) are also * implemented via Op. * * \sa Op */ class OpNode : public RelayExprNode { public: /*! \brief name of the operator */ String name; /*! \brief the type of the operator */ mutable FuncType op_type; /*! * \brief detailed description of the operator * This can be used to generate docstring automatically for the operator. */ String description; /* \brief Information of input arguments to the operator */ Array arguments; /*! * \brief The type key of the attribute field * This can be empty, in which case it defaults to anything. */ String attrs_type_key; /*! * \brief attribute type index, * this field varies in each run and is not exposed to frontend. */ uint32_t attrs_type_index{0}; /*! * \brief number of input arguments to the operator, * -1 means it is variable length */ int32_t num_inputs = -1; /*! * \brief support level of the operator, * The lower the more priority it contains. * This is in analogies to BLAS levels. */ int32_t support_level = 10;
代码中有些成员不太明白什么意思也没关系,接下来我们看个例子来加深理解。
首先我们看 Relay 中定义的 nn.bias_add, 第一步,先为其定义一个属性类型记录其所有属性,如下
struct BiasAddAttrs : public tvm::AttrsNode { int axis; TVM_DECLARE_ATTRS(BiasAddAttrs, "relay.attrs.BiasAddAttrs") { TVM_ATTR_FIELD(axis).describe("The axis to add the bias").set_default(1); } };
如代码所述,定义属性时还能规定默认值等,接着为其定义 type relation
bool BiasAddRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { ICHECK_EQ(types.size(), 3); const auto* data = types[0].as(); if (data == nullptr) return false; const BiasAddAttrs* param = attrs.as(); ICHECK(param != nullptr); int axis = param->axis; if (axis < 0) { axis = data->shape.size() + axis; } if (axis >= static_cast(data->shape.size()) || axis < 0) { reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) << "The axis in bias_add must be in range for the shape; " << "attempted to access index " << param->axis << " of " << PrettyPrint(data->shape)); return false; } // assign output type reporter->Assign(types[1], TensorType({data->shape[axis]}, data->dtype)); reporter->Assign(types[2], types[0]); return true; }
如代码所述,这是一个类型推断函数,根据0号Tensor输入和属性推断1号Tensor输入和2号输出的Type.
然后注册到全局表中
RELAY_REGISTER_OP("nn.bias_add") .describe(R"code(Add bias to an axis of the input. )code" TVM_ADD_FILELINE) .set_attrs_type() .set_num_inputs(2) .add_argument("data", "nD Tensor", "Input data.") .add_argument("bias", "1D Tensor", "Bias.") .set_support_level(1) .add_type_rel("BiasAdd", BiasAddRel) .set_attr("FTVMCompute", [](const Attrs& attrs, const Array& inputs, const Type& out_type) { const auto* param = attrs.as(); return tvm::Array{topi::nn::bias_add(inputs[0], inputs[1], param->axis)}; });
此时我们可以与一开始的 OpNode 中的成员一一对应起来了,其中的name/description/num_inputs/arguments/support_level对应关系都比较直接,不再赘述。
这里要回忆起来 BiasAddAttrs 是一个 Object, 所以会有其 type key 和 type index,也就是 OpNode 中的成员 attrs_type_key 和 attrs_type_index 了。
此外对于 BiasAdd 还设置了一个名为 FTVMCompute 的额外属性描述其具体如何计算,这部分我们后面章节再深入,这里先略过。
注册完以后我们可以通过 FFI 机制暴露给 Python 端调用
// Positional relay function to create dense operator used by frontend FFI. Expr MakeBiasAdd(Expr data, Expr bias, int axis) { auto attrs = make_object(); attrs->axis = axis; static const Op& op = Op::Get("nn.bias_add"); return Call(op, {data, bias}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op.nn._make.bias_add").set_body_typed(MakeBiasAdd);