在深入 OptimizeImpl 阶段,也就是 Pass 优化之前,我们先了解一下 Relay 阶段的一些基本概念
OptimizeImpl 阶段最主要的工作就是图优化,图其实就是一种高级的表示,tvm 和图相关的概念好好几个:
- relay ir:这是 tvm 最 high level 的表示,另外 relay ir -> tir 的过程中,又会依赖 topi 和 te 这2种特定抽象的中间表示
- topi:TVM Operator Inventory,TOPI 提供了比 tir 具有更高抽象的 numpy 风格的,通用操作和调度
-
- te:Tensor Expression,张量表达式
- tir:最接近目标代码的中间表示
- ir:relay ir 和 tir 的一些公共基础结构,和上述2种ir不一样,并不是一个完整独立的抽象
本章我们先来了解下 IR 这个 relay ir 和 tir 最公共的基础设施,后续会依次介绍 relay ir、tir、topi、te
代码目录:
- 代码:src/ir
- 头文件:include/tvm/ir
编程语言最基本的核心概念就3个:类型、运算符、表达式,在 IR 这里分别对应 Type, OP, Expr
1.1 Type
Type 相关的定义都在 include/tvm/ir/type.h ,Type 包括基础的整型/浮点型等,也包括函数类型等相对复杂的类型。
这里我们介绍2种基本的类型:
- PrimType:最原始的 Type,可以直接映射到 low-level IR 的基本数据类型
- FuncType:函数类型
PrimType 可以在这上面做一些 Low-level 优化
定义如下:
class PrimTypeNode : public TypeNode {
public:
/*!
* \brief The corresponding dtype field.
*/
runtime::DataType dtype;
}
可以看到 PrimType 就一个数据成员,runtime::DataType,这个是 runtime 最底层的概念,代码在 include/tvm/runtime/data_type.h
/*!
* \brief Runtime primitive data type.
*
* This class is a thin wrapper of DLDataType.
* We also make use of DataType in compiler to store quick hint
*/
class DataType {
public:
/*!
* \brief Type code for the DataType.
*
* DLPack consistency:
* 1) kInt is consistent with kDLInt
* 2) kUInt is consistent with kDLUInt
* 3) kFloat is consistent with kDLFloat
*/
enum TypeCode {
kInt = kDLInt,
kUInt = kDLUInt,
kFloat = kDLFloat,
kHandle = TVMArgTypeCode::kTVMOpaqueHandle,
kBFloat = kDLBfloat,
kCustomBegin = 129
};
/*! \brief default constructor */
DataType() { data_ = DataType::Void(); }
/*!
* \brief Constructor
* \param dtype The DLDataType
*/
explicit DataType(DLDataType dtype) : data_(dtype) {}
再看 FuncType 的定义,FuncType 记录了函数参数类型/返回值类型/模板参数等信息,如下:
/*!
* \brief Function type.
*
* We support polymorphic function type.
* This can be roughly viewed as template function in C++.
*
* \sa FuncType, TypeVar, TypeConstraint
*/
class FuncTypeNode : public TypeNode {
public:
/*! \brief type type of arguments */
Array arg_types;
/*! \brief The type of return value. */
Type ret_type;
// The following fields are used in polymorphic(template) functions
// For normal functions, the following two fields will be empty.
/*! \brief The type parameters of the function */
Array type_params;
/*!
* \brief potential constraint the type need to obey
* \note this field is reserved for futher purposes.
*/
Array type_constraints;
再来看下 TensorType, 代码如下
/*!
* \brief This is the most commonly used type in relay.
* TensorType have a fixed dimension, data type.
*
* The elements of shape can be either IntImm(constant integer),
* or any symbolic integer expression.
* The symbolic integer allows generic shape inference in certain cases.
* \sa TensorType
*/
class TensorTypeNode : public BaseTensorTypeNode {
public:
/*!
* \brief The shape of the tensor,
* represented by PrimExpr(tvm::Expr).
*/
Array shape;
/*! \brief The content data type */
DataType dtype;
需要注意,shape 也是类型的一部分,所以 TVM 中的类型推断的意思自然也包括了 shape 推断。
其他一些比较重要的 Type
- 量化相关的Type, 定义于 include/tvm/ir/affine_type.h
- ADT 相关,定义于 include/tvm/ir/adt.h
1.2 Expr
了解了Type, 再看Expr,Expr 包括简单的定义一个字面值,也包括定义一个复杂的函数。
ir 中定义的 Expr 有:
- PrimExpr:原始 Expr,主要在 tir 模块中定义,可以相对直接地映射到 low-level code
- RelayExpr:所有的非PrimExpr,比如 tensor,function,adt 等其他一等公民
- GlobalVar:全局变量,只有函数才会引用 GlobalVar,通常用来实现函数的递归调用
- IntImm:Int64 常量
- FloatImm:double 常量
- Bool:布尔常量
- Integer:继承自 IntImm
- Range:表示范围
我们接下来主要看下 PrimExpr
/*!
* \brief Reference to PrimExprNode.
* \sa PrimExprNode
*/
class PrimExpr : public BaseExpr {
public:
TVM_DLL PrimExpr(int32_t value); // NOLINT(*)
TVM_DLL PrimExpr(float value); // NOLINT(*)
// ...
};
可以看到 PrimExpr 可以直接从 float 或者 int32_t 类型直接转换过来,另外支持常见的各种基础运算
其实现代码在:src/tir/op/op.cc 里面
TVM_DLL PrimExpr operator+(PrimExpr a, PrimExpr b); TVM_DLL PrimExpr operator-(PrimExpr a, PrimExpr b); TVM_DLL PrimExpr operator-(PrimExpr a); TVM_DLL PrimExpr operator*(PrimExpr a, PrimExpr b);
其实这里有点奇怪,这个代码放 src/ir/expr.cc 更合适的
我们看一下 FloatImm 的代码实现来体会 Expr
/*!
* \brief Constant floating point literals in the program.
* \sa FloatImm
*/
class FloatImmNode : public PrimExprNode {
public:
/*! \brief The constant value content. */
double value;
FloatImm 的意思是浮点数字面值表达式,所以其记录了一个 double 成员。
再看下 BaseFunc, 这个是 relay.Function 和 tir.PrimFunc的共同基类, 定义于 include/tvm/ir/function.h
class BaseFuncNode : public RelayExprNode {
public:
/*! \brief Additional attributes storing the meta-data */
DictAttrs attrs;
// ...
}
/*!
* \brief Managed reference to BaseFuncNode.
* \sa BaseFuncNode
*/
class BaseFunc : public RelayExpr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(BaseFunc, RelayExpr, BaseFuncNode);
};
DictAttrs 就是一个 Map<String, xx>,记录了函数的一些元数据,其他的没有更多描述了
1.3 Op
Op 就是我们最常说的算子,比如 nn.conv2d 就是一个算子
Op 是表示所有系统定义的原始算子/内联函数的通用类。开发者可以向系统注册新的 Ops,以及它们的额外属性(例如 Op 是否为元素)。
在 TVM 中,无论是 relay ir 还是 tir 的 op 都定义为一种 RelayExpr,如下
/*!
* \brief Primitive Op(builtin intrinsics)
*
* This data structure stores the meta-data
* about primitive operators that can be invoked via Call.
*
* Low-level IR intrinsics(such as libc.expf) are also
* implemented via Op.
*
* \sa Op
*/
class OpNode : public RelayExprNode {
public:
/*! \brief name of the operator */
String name;
/*! \brief the type of the operator */
mutable FuncType op_type;
/*!
* \brief detailed description of the operator
* This can be used to generate docstring automatically for the operator.
*/
String description;
/* \brief Information of input arguments to the operator */
Array arguments;
/*!
* \brief The type key of the attribute field
* This can be empty, in which case it defaults to anything.
*/
String attrs_type_key;
/*!
* \brief attribute type index,
* this field varies in each run and is not exposed to frontend.
*/
uint32_t attrs_type_index{0};
/*!
* \brief number of input arguments to the operator,
* -1 means it is variable length
*/
int32_t num_inputs = -1;
/*!
* \brief support level of the operator,
* The lower the more priority it contains.
* This is in analogies to BLAS levels.
*/
int32_t support_level = 10;
代码中有些成员不太明白什么意思也没关系,接下来我们看个例子来加深理解。
首先我们看 Relay 中定义的 nn.bias_add, 第一步,先为其定义一个属性类型记录其所有属性,如下
struct BiasAddAttrs : public tvm::AttrsNode {
int axis;
TVM_DECLARE_ATTRS(BiasAddAttrs, "relay.attrs.BiasAddAttrs") {
TVM_ATTR_FIELD(axis).describe("The axis to add the bias").set_default(1);
}
};
如代码所述,定义属性时还能规定默认值等,接着为其定义 type relation
bool BiasAddRel(const Array& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
ICHECK_EQ(types.size(), 3);
const auto* data = types[0].as();
if (data == nullptr) return false;
const BiasAddAttrs* param = attrs.as();
ICHECK(param != nullptr);
int axis = param->axis;
if (axis < 0) { axis = data->shape.size() + axis;
}
if (axis >= static_cast(data->shape.size()) || axis < 0) { reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
<< "The axis in bias_add must be in range for the shape; "
<< "attempted to access index " << param->axis << " of "
<< PrettyPrint(data->shape));
return false;
}
// assign output type
reporter->Assign(types[1], TensorType({data->shape[axis]}, data->dtype));
reporter->Assign(types[2], types[0]);
return true;
}
如代码所述,这是一个类型推断函数,根据0号Tensor输入和属性推断1号Tensor输入和2号输出的Type.
然后注册到全局表中
RELAY_REGISTER_OP("nn.bias_add")
.describe(R"code(Add bias to an axis of the input.
)code" TVM_ADD_FILELINE)
.set_attrs_type()
.set_num_inputs(2)
.add_argument("data", "nD Tensor", "Input data.")
.add_argument("bias", "1D Tensor", "Bias.")
.set_support_level(1)
.add_type_rel("BiasAdd", BiasAddRel)
.set_attr("FTVMCompute", [](const Attrs& attrs, const Array& inputs,
const Type& out_type) {
const auto* param = attrs.as();
return tvm::Array{topi::nn::bias_add(inputs[0], inputs[1], param->axis)};
});
此时我们可以与一开始的 OpNode 中的成员一一对应起来了,其中的name/description/num_inputs/arguments/support_level对应关系都比较直接,不再赘述。
这里要回忆起来 BiasAddAttrs 是一个 Object, 所以会有其 type key 和 type index,也就是 OpNode 中的成员 attrs_type_key 和 attrs_type_index 了。
此外对于 BiasAdd 还设置了一个名为 FTVMCompute 的额外属性描述其具体如何计算,这部分我们后面章节再深入,这里先略过。
注册完以后我们可以通过 FFI 机制暴露给 Python 端调用
// Positional relay function to create dense operator used by frontend FFI.
Expr MakeBiasAdd(Expr data, Expr bias, int axis) {
auto attrs = make_object();
attrs->axis = axis;
static const Op& op = Op::Get("nn.bias_add");
return Call(op, {data, bias}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.bias_add").set_body_typed(MakeBiasAdd);