深入浅出 tvm – (5) IR 公共的一些核心概念

在深入 OptimizeImpl 阶段,也就是 Pass 优化之前,我们先了解一下 Relay 阶段的一些基本概念
OptimizeImpl 阶段最主要的工作就是图优化,图其实就是一种高级的表示,tvm 和图相关的概念好好几个:
  1. relay ir:这是 tvm 最 high level 的表示,另外 relay ir -> tir 的过程中,又会依赖 topi 和 te 这2种特定抽象的中间表示
    1. topi:TVM Operator Inventory,TOPI 提供了比 tir 具有更高抽象的 numpy 风格的,通用操作和调度
    1. te:Tensor Expression,张量表达式
  1. tir:最接近目标代码的中间表示
  2. ir:relay ir 和 tir 的一些公共基础结构,和上述2种ir不一样,并不是一个完整独立的抽象
本章我们先来了解下 IR 这个 relay ir 和 tir 最公共的基础设施,后续会依次介绍 relay ir、tir、topi、te
代码目录:
  • 代码:src/ir
  • 头文件:include/tvm/ir
编程语言最基本的核心概念就3个:类型、运算符、表达式,在 IR 这里分别对应 Type, OP, Expr

1.1 Type

Type 相关的定义都在 include/tvm/ir/type.h ,Type 包括基础的整型/浮点型等,也包括函数类型等相对复杂的类型。
这里我们介绍2种基本的类型:
  1. PrimType:最原始的 Type,可以直接映射到 low-level IR 的基本数据类型
  2. FuncType:函数类型
PrimType 可以在这上面做一些 Low-level 优化
定义如下:
class PrimTypeNode : public TypeNode {
 public:
  /*!
   * \brief The corresponding dtype field.
   */
  runtime::DataType dtype;
}
可以看到 PrimType 就一个数据成员,runtime::DataType,这个是 runtime 最底层的概念,代码在 include/tvm/runtime/data_type.h
/*!
 * \brief Runtime primitive data type.
 *
 *  This class is a thin wrapper of DLDataType.
 *  We also make use of DataType in compiler to store quick hint
 */
class DataType {
 public:
  /*!
   * \brief Type code for the DataType.
   *
   * DLPack consistency:
   * 1) kInt is consistent with kDLInt
   * 2) kUInt is consistent with kDLUInt
   * 3) kFloat is consistent with kDLFloat
   */
  enum TypeCode {
    kInt = kDLInt,
    kUInt = kDLUInt,
    kFloat = kDLFloat,
    kHandle = TVMArgTypeCode::kTVMOpaqueHandle,
    kBFloat = kDLBfloat,
    kCustomBegin = 129
  };
  /*! \brief default constructor */
  DataType() { data_ = DataType::Void(); }
  /*!
   * \brief Constructor
   * \param dtype The DLDataType
   */
  explicit DataType(DLDataType dtype) : data_(dtype) {}

再看 FuncType 的定义,FuncType 记录了函数参数类型/返回值类型/模板参数等信息,如下:
/*!
 * \brief Function type.
 *
 * We support polymorphic function type.
 * This can be roughly viewed as template function in C++.
 *
 * \sa FuncType, TypeVar, TypeConstraint
 */
class FuncTypeNode : public TypeNode {
 public:
  /*! \brief type type of arguments */
  Array arg_types;
  /*! \brief The type of return value. */
  Type ret_type;
  // The following fields are used in polymorphic(template) functions
  // For normal functions, the following two fields will be empty.
  /*! \brief The type parameters of the function */
  Array type_params;
  /*!
   * \brief potential constraint the type need to obey
   * \note this field is reserved for futher purposes.
   */
  Array type_constraints;  
再来看下 TensorType, 代码如下
/*!
 * \brief This is the most commonly used type in relay.
 *  TensorType have a fixed dimension, data type.
 *
 *  The elements of shape can be either IntImm(constant integer),
 *  or any symbolic integer expression.
 *  The symbolic integer allows generic shape inference in certain cases.
 * \sa TensorType
 */
class TensorTypeNode : public BaseTensorTypeNode {
 public:
  /*!
   * \brief The shape of the tensor,
   *  represented by PrimExpr(tvm::Expr).
   */
  Array shape;
  /*! \brief The content data type */
  DataType dtype;
需要注意,shape 也是类型的一部分,所以 TVM 中的类型推断的意思自然也包括了 shape 推断。
其他一些比较重要的 Type
  • 量化相关的Type, 定义于 include/tvm/ir/affine_type.h
  • ADT 相关,定义于 include/tvm/ir/adt.h

1.2 Expr

了解了Type, 再看Expr,Expr 包括简单的定义一个字面值,也包括定义一个复杂的函数。
ir 中定义的 Expr 有:
  1. PrimExpr:原始 Expr,主要在 tir 模块中定义,可以相对直接地映射到 low-level code
  2. RelayExpr:所有的非PrimExpr,比如 tensor,function,adt 等其他一等公民
  3. GlobalVar:全局变量,只有函数才会引用 GlobalVar,通常用来实现函数的递归调用
  4. IntImm:Int64 常量
  5. FloatImm:double 常量
  6. Bool:布尔常量
  7. Integer:继承自 IntImm
  8. Range:表示范围
我们接下来主要看下 PrimExpr
/*!
 * \brief Reference to PrimExprNode.
 * \sa PrimExprNode
 */
class PrimExpr : public BaseExpr {
 public:

  TVM_DLL PrimExpr(int32_t value);  // NOLINT(*)

  TVM_DLL PrimExpr(float value);  // NOLINT(*)
  // ...
  
};
可以看到 PrimExpr 可以直接从 float 或者 int32_t 类型直接转换过来,另外支持常见的各种基础运算
其实现代码在:src/tir/op/op.cc 里面
TVM_DLL PrimExpr operator+(PrimExpr a, PrimExpr b);
TVM_DLL PrimExpr operator-(PrimExpr a, PrimExpr b);
TVM_DLL PrimExpr operator-(PrimExpr a);
TVM_DLL PrimExpr operator*(PrimExpr a, PrimExpr b);
其实这里有点奇怪,这个代码放 src/ir/expr.cc 更合适的
我们看一下 FloatImm 的代码实现来体会 Expr
/*!
 * \brief Constant floating point literals in the program.
 * \sa FloatImm
 */
class FloatImmNode : public PrimExprNode {
 public:
  /*! \brief The constant value content. */
  double value;
FloatImm 的意思是浮点数字面值表达式,所以其记录了一个 double 成员。
再看下 BaseFunc, 这个是 relay.Function 和 tir.PrimFunc的共同基类, 定义于 include/tvm/ir/function.h
class BaseFuncNode : public RelayExprNode {
 public:
  /*! \brief Additional attributes storing the meta-data */
  DictAttrs attrs;
  // ...
}

/*!
 * \brief Managed reference to BaseFuncNode.
 * \sa BaseFuncNode
 */
class BaseFunc : public RelayExpr {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(BaseFunc, RelayExpr, BaseFuncNode);
};

DictAttrs 就是一个 Map<String, xx>,记录了函数的一些元数据,其他的没有更多描述了

1.3 Op

Op 就是我们最常说的算子,比如 nn.conv2d 就是一个算子
Op 是表示所有系统定义的原始算子/内联函数的通用类。开发者可以向系统注册新的 Ops,以及它们的额外属性(例如 Op 是否为元素)。
在 TVM 中,无论是 relay ir 还是 tir 的 op 都定义为一种 RelayExpr,如下
/*!
 * \brief Primitive Op(builtin intrinsics)
 *
 * This data structure stores the meta-data
 * about primitive operators that can be invoked via Call.
 *
 * Low-level IR intrinsics(such as libc.expf) are also
 * implemented via Op.
 *
 * \sa Op
 */
class OpNode : public RelayExprNode {
 public:
  /*! \brief name of the operator */
  String name;
  /*! \brief the type of the operator */
  mutable FuncType op_type;
  /*!
   * \brief detailed description of the operator
   *  This can be used to generate docstring automatically for the operator.
   */
  String description;
  /* \brief Information of input arguments to the operator */
  Array arguments;
  /*!
   * \brief The type key of the attribute field
   *  This can be empty, in which case it defaults to anything.
   */
  String attrs_type_key;
  /*!
   * \brief attribute type index,
   * this field varies in each run and is not exposed to frontend.
   */
  uint32_t attrs_type_index{0};
  /*!
   * \brief number of input arguments to the operator,
   * -1 means it is variable length
   */
  int32_t num_inputs = -1;
  /*!
   * \brief support level of the operator,
   *  The lower the more priority it contains.
   *  This is in analogies to BLAS levels.
   */
  int32_t support_level = 10;
代码中有些成员不太明白什么意思也没关系,接下来我们看个例子来加深理解。
首先我们看 Relay 中定义的 nn.bias_add, 第一步,先为其定义一个属性类型记录其所有属性,如下
struct BiasAddAttrs : public tvm::AttrsNode {
  int axis;

  TVM_DECLARE_ATTRS(BiasAddAttrs, "relay.attrs.BiasAddAttrs") {
    TVM_ATTR_FIELD(axis).describe("The axis to add the bias").set_default(1);
  }
};
如代码所述,定义属性时还能规定默认值等,接着为其定义 type relation
bool BiasAddRel(const Array& types, int num_inputs, const Attrs& attrs,
                const TypeReporter& reporter) {
  ICHECK_EQ(types.size(), 3);
  const auto* data = types[0].as();
  if (data == nullptr) return false;

  const BiasAddAttrs* param = attrs.as();
  ICHECK(param != nullptr);
  int axis = param->axis;
  if (axis < 0) { axis = data->shape.size() + axis;
  }
  if (axis >= static_cast(data->shape.size()) || axis < 0) { reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
                                     << "The axis in bias_add must be in range for the shape; "
                                     << "attempted to access index " << param->axis << " of "
                                     << PrettyPrint(data->shape));
    return false;
  }

  // assign output type
  reporter->Assign(types[1], TensorType({data->shape[axis]}, data->dtype));
  reporter->Assign(types[2], types[0]);
  return true;
}
如代码所述,这是一个类型推断函数,根据0号Tensor输入和属性推断1号Tensor输入和2号输出的Type.
然后注册到全局表中
RELAY_REGISTER_OP("nn.bias_add")
    .describe(R"code(Add bias to an axis of the input.

)code" TVM_ADD_FILELINE)
    .set_attrs_type()
    .set_num_inputs(2)
    .add_argument("data", "nD Tensor", "Input data.")
    .add_argument("bias", "1D Tensor", "Bias.")
    .set_support_level(1)
    .add_type_rel("BiasAdd", BiasAddRel)
    .set_attr("FTVMCompute", [](const Attrs& attrs, const Array& inputs,
                                             const Type& out_type) {
      const auto* param = attrs.as();
      return tvm::Array{topi::nn::bias_add(inputs[0], inputs[1], param->axis)};
    });
此时我们可以与一开始的 OpNode 中的成员一一对应起来了,其中的name/description/num_inputs/arguments/support_level对应关系都比较直接,不再赘述。
这里要回忆起来 BiasAddAttrs 是一个 Object, 所以会有其 type key 和 type index,也就是 OpNode 中的成员 attrs_type_key 和 attrs_type_index 了。
此外对于 BiasAdd 还设置了一个名为 FTVMCompute 的额外属性描述其具体如何计算,这部分我们后面章节再深入,这里先略过。
注册完以后我们可以通过 FFI 机制暴露给 Python 端调用
// Positional relay function to create dense operator used by frontend FFI.
Expr MakeBiasAdd(Expr data, Expr bias, int axis) {
  auto attrs = make_object();
  attrs->axis = axis;
  static const Op& op = Op::Get("nn.bias_add");
  return Call(op, {data, bias}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.nn._make.bias_add").set_body_typed(MakeBiasAdd);

发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注