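I. Introduction
This article uses a custom operator, axis_abs, as an example to show how to add a new Relay operator to TVM. Given a 3-D input tensor, the operator takes the absolute value of the slice at position indice along dimension axis, and it leaves all other elements unchanged.
II. Adding a Custom Operator
Adding a new Relay operator takes the following steps: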
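Define an attribute node (Attribute Node) for the new operator, declaring the fixed parameters known at compile time;
Write a type relation for the new operator to integrate it into Relay's type system;
Use the C++ RELAY_REGISTER_OP macro to register the new operator, declaring its number of arguments, argument types and description;
Implement the operator's compute;
Register the operator's compute and schedule;
Define a C++ function that generates a call node for the new operator, and register a Python API hook for it;
Wrap the Python API hook above in a concise calling interface;
Write tests for the new Relay operator.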
1. Define the attribute node (Attribute Node) of the new operator
Add the operator's attribute data structure to include/tvm/relay/attrs/transform.h:
/*! \brief Attributes used in axisabs operator */
struct AxisAbsAttrs : public tvm::AttrsNode<AxisAbsAttrs> {
int axis;
int indice;
TVM_DECLARE_ATTRS(AxisAbsAttrs, "relay.attrs.AxisAbsAttrs") {
TVM_ATTR_FIELD(axis).set_default(0).describe("Axis to abs");
TVM_ATTR_FIELD(indice).set_default(0).describe("Indice to abs");
}
};
Q: What do the macros TVM_DECLARE_ATTRS and TVM_ATTR_FIELD do?
A: Both macros are defined in include/tvm/ir/attrs.h:
#define TVM_DECLARE_ATTRS(ClassName, TypeKey) \
static constexpr const char* _type_key = TypeKey; \
TVM_DECLARE_FINAL_OBJECT_INFO(ClassName, ::tvm::BaseAttrsNode) \
template <typename FVisit> \
void __VisitAttrs__(FVisit& __fvisit__) // NOLINT(*)
#define TVM_ATTR_FIELD(FieldName) __fvisit__(#FieldName, &FieldName)
TVM_DECLARE_FINAL_OBJECT_INFO, used above, is defined in include/tvm/runtime/object.h:
#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType) \
static const constexpr bool _type_final = true; \
static const constexpr int _type_child_slots = 0; \
TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)
#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \
static_assert(!ParentType::_type_final, "ParentObj marked as final"); \
static uint32_t RuntimeTypeIndex() { \
static_assert(TypeName::_type_child_slots == 0 || ParentType::_type_child_slots == 0 || \
TypeName::_type_child_slots < ParentType::_type_child_slots, \
"Need to set _type_child_slots when parent specifies it."); \
if (TypeName::_type_index != ::tvm::runtime::TypeIndex::kDynamic) { \
return TypeName::_type_index; \
} \
return _GetOrAllocRuntimeTypeIndex(); \
} \
static uint32_t _GetOrAllocRuntimeTypeIndex() { \
static uint32_t tindex = Object::GetOrAllocRuntimeTypeIndex( \
TypeName::_type_key, TypeName::_type_index, ParentType::_GetOrAllocRuntimeTypeIndex(), \
TypeName::_type_child_slots, TypeName::_type_child_slots_can_overflow); \
return tindex; \
}
So after macro expansion, the attribute node structure becomes:
struct AxisAbsAttrs : public tvm::AttrsNode<AxisAbsAttrs> {
  int axis;
  int indice;
  static constexpr const char* _type_key = "relay.attrs.AxisAbsAttrs";
  static const constexpr bool _type_final = true;
  static const constexpr int _type_child_slots = 0;
  static_assert(!::tvm::BaseAttrsNode::_type_final, "ParentObj marked as final");
  static uint32_t RuntimeTypeIndex() {
    static_assert(AxisAbsAttrs::_type_child_slots == 0 || ::tvm::BaseAttrsNode::_type_child_slots == 0 ||
                      AxisAbsAttrs::_type_child_slots < ::tvm::BaseAttrsNode::_type_child_slots,
                  "Need to set _type_child_slots when parent specifies it.");
    if (AxisAbsAttrs::_type_index != ::tvm::runtime::TypeIndex::kDynamic) {
      return AxisAbsAttrs::_type_index;
    }
    return _GetOrAllocRuntimeTypeIndex();
  }
  static uint32_t _GetOrAllocRuntimeTypeIndex() {
    static uint32_t tindex = Object::GetOrAllocRuntimeTypeIndex(
        AxisAbsAttrs::_type_key, AxisAbsAttrs::_type_index,
        ::tvm::BaseAttrsNode::_GetOrAllocRuntimeTypeIndex(),
        AxisAbsAttrs::_type_child_slots, AxisAbsAttrs::_type_child_slots_can_overflow);
    return tindex;
  }
  template <typename FVisit>
  void __VisitAttrs__(FVisit& __fvisit__) {
    // TVM_ATTR_FIELD(axis) expands to __fvisit__("axis", &axis)
    __fvisit__("axis", &axis).set_default(0).describe("Axis to abs");
    __fvisit__("indice", &indice).set_default(0).describe("Indice to abs");
  }
};
As shown, every attribute node defines RuntimeTypeIndex(), which returns its runtime type index, and the template member function __VisitAttrs__(FVisit& __fvisit__), which visits the node's fields.
Q: How is the template function __VisitAttrs__(FVisit& __fvisit__) called?
A: Start with class AttrsNode, defined in include/tvm/ir/attrs.h:
template <typename DerivedType>
class AttrsNode : public BaseAttrsNode {
public:
void VisitAttrs(AttrVisitor* v) {
::tvm::detail::AttrNormalVisitor vis(v);
self()->__VisitAttrs__(vis);
}
void VisitNonDefaultAttrs(AttrVisitor* v) {...}
void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final {...}
bool SEqualReduce(const DerivedType* other, SEqualReducer equal) const {...}
void SHashReduce(SHashReducer hash_reducer) const {...}
Array<AttrFieldInfo> ListFieldInfo() const final {...}
private:
DerivedType* self() const {
return const_cast<DerivedType*>(static_cast<const DerivedType*>(this));
}
};
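AttrsNode is a class template whose template parameter is the derived class itself (the curiously recurring template pattern, CRTP), so self() can cast this to the derived type. Its member function VisitAttrs(AttrVisitor* v) receives an attribute visitor object of class AttrVisitor: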
class AttrVisitor {
public:
//! \cond Doxygen_Suppress
TVM_DLL virtual ~AttrVisitor() = default;
TVM_DLL virtual void Visit(const char* key, double* value) = 0;
TVM_DLL virtual void Visit(const char* key, int64_t* value) = 0;
TVM_DLL virtual void Visit(const char* key, uint64_t* value) = 0;
TVM_DLL virtual void Visit(const char* key, int* value) = 0;
TVM_DLL virtual void Visit(const char* key, bool* value) = 0;
TVM_DLL virtual void Visit(const char* key, std::string* value) = 0;
TVM_DLL virtual void Visit(const char* key, void** value) = 0;
TVM_DLL virtual void Visit(const char* key, DataType* value) = 0;
TVM_DLL virtual void Visit(const char* key, runtime::NDArray* value) = 0;
TVM_DLL virtual void Visit(const char* key, runtime::ObjectRef* value) = 0;
template <typename ENum, typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
void Visit(const char* key, ENum* ptr) {
static_assert(std::is_same<int, typename std::underlying_type<ENum>::type>::value,
"declare enum to be enum int to use visitor");
this->Visit(key, reinterpret_cast<int*>(ptr));
}
//! \endcond
};
VisitAttrs() then wraps that visitor in a normal attribute visitor with ::tvm::detail::AttrNormalVisitor vis(v);:
// Wrapper for normal visitor.
class AttrNormalVisitor {
public:
explicit AttrNormalVisitor(AttrVisitor* visitor) : visitor_(visitor) {}
template <typename T>
AttrNopEntry operator()(const char* key, T* value) {
visitor_->Visit(key, value);
return AttrNopEntry();
}
private:
AttrVisitor* visitor_;
};
AttrNormalVisitor overloads operator(). When AttrsNode obtains the derived object through self() and calls self()->__VisitAttrs__(vis), the expanded body runs __fvisit__("axis", &axis). Because __fvisit__ is the AttrNormalVisitor, this lands in its overloaded operator(), which forwards the call to visitor_->Visit(key, value) and returns a struct that supports chained calls:
// helper entry that does nothing in set_default/bound/describe calls.
struct AttrNopEntry {
using TSelf = AttrNopEntry;
TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { return *this; }
template <typename T>
TSelf& set_default(DMLC_ATTRIBUTE_UNUSED const T& value) {return *this;}
template <typename T>
TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T& begin) {return *this;}
template <typename T>
TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T& end) {return *this;}
};
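For the normal visitor, these calls do nothing except return the entry itself. The chained .set_default(...).describe(...) in TVM_ATTR_FIELD therefore compiles but has no effect during ordinary visiting.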
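To make this control flow concrete, here is a minimal pure-Python sketch of the same pattern. It assumes nothing from TVM, and every class name is hypothetical: AxisAbsAttrsSketch mirrors the expanded __VisitAttrs__, NormalVisitorSketch plays the role of AttrNormalVisitor, and NopEntry stands in for AttrNopEntry. Python has no pointers, so the "address" of a field is modeled as an (object, field name) pair.

```python
class NopEntry:
    """Stand-in for AttrNopEntry: every chained call is a no-op returning self."""

    def describe(self, text):
        return self

    def set_default(self, value):
        return self

    def set_lower_bound(self, value):
        return self

    def set_upper_bound(self, value):
        return self


class RecordingVisitor:
    """Stand-in for a concrete AttrVisitor: records every (key, value) it sees."""

    def __init__(self):
        self.seen = {}

    def visit(self, key, owner, field_name):
        self.seen[key] = getattr(owner, field_name)


class NormalVisitorSketch:
    """Stand-in for AttrNormalVisitor: forwards to the visitor, returns a NopEntry."""

    def __init__(self, visitor):
        self.visitor = visitor

    def __call__(self, key, owner, field_name):
        self.visitor.visit(key, owner, field_name)
        return NopEntry()


class AxisAbsAttrsSketch:
    """Hypothetical mirror of AxisAbsAttrs after macro expansion."""

    def __init__(self, axis=0, indice=0):
        self.axis = axis
        self.indice = indice

    def __visit_attrs__(self, fvisit):
        # Equivalent of TVM_ATTR_FIELD(axis).set_default(0).describe(...)
        fvisit("axis", self, "axis").set_default(0).describe("Axis to abs")
        fvisit("indice", self, "indice").set_default(0).describe("Indice to abs")

    def visit_attrs(self, visitor):
        # Equivalent of AttrsNode::VisitAttrs: wrap the visitor, then dispatch.
        self.__visit_attrs__(NormalVisitorSketch(visitor))


v = RecordingVisitor()
AxisAbsAttrsSketch(axis=1, indice=2).visit_attrs(v)
print(v.seen)  # {'axis': 1, 'indice': 2}
```

Note that set_default(0) does not overwrite axis=1 here. This matches the no-op behavior of AttrNopEntry.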
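2. Write the operator's type relation to integrate it into Relay's type system
Relay operators are instantiated through type relations between their inputs and outputs. This keeps operator registration flexible and makes Relay operators more general. Besides inferring the output type, a type relation can also enforce typing rules, i.e., check the input types. Add the operator's type relation function to src/relay/op/tensor/transform.cc: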
bool AxisAbsRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                const TypeReporter& reporter) {
  // types: [data, output]
  ICHECK_EQ(types.size(), 2);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) {
    ICHECK(types[0].as<IncompleteTypeNode>())
        << "axis_abs: expect input type to be TensorType but get " << types[0];
    return false;
  }
  const auto* param = attrs.as<AxisAbsAttrs>();
  ICHECK(param != nullptr);
  const int ndim = static_cast<int>(data->shape.size());
  const int axis = param->axis;
  const int indice = param->indice;
  // Check axis before using it to index data->shape.
  ICHECK(0 <= axis && axis < ndim)
      << "axis_abs only accepts `axis` in [0, data.ndim - 1]"
      << ", but got axis = " << axis << ", and data.ndim = " << ndim;
  // Requires a static (constant) dimension along `axis`.
  const auto* axis_len_imm = data->shape[axis].as<IntImmNode>();
  ICHECK(axis_len_imm != nullptr) << "axis_abs requires a static dimension along `axis`";
  const int axis_len = static_cast<int>(axis_len_imm->value);
  ICHECK(0 <= indice && indice < axis_len)
      << "axis_abs only accepts `indice` in [0, data.shape[axis] - 1]"
      << ", but got indice = " << indice << ", and data.shape[axis] = " << axis_len;
  // The output has the same shape and dtype as the input.
  reporter->Assign(types[1], TensorType(data->shape, data->dtype));
  return true;
}
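The checks in AxisAbsRel can be mirrored in plain Python to see which inputs are accepted and which are rejected. This is a hypothetical sketch. TensorTypeSketch is an illustrative stand-in for TensorType (shape plus dtype, no data), None models an input type that is still incomplete, and ValueError replaces ICHECK.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class TensorTypeSketch:
    """Illustrative stand-in for TensorType: shape and dtype only, no data."""
    shape: Tuple[int, ...]
    dtype: str


def axis_abs_rel(data: Optional[TensorTypeSketch], axis: int, indice: int):
    """Return the output type, or None if the input type is not known yet.

    Mirrors AxisAbsRel: returning None corresponds to `return false`.
    """
    if data is None:
        return None  # input still incomplete; the solver would retry later
    ndim = len(data.shape)
    if not 0 <= axis < ndim:
        raise ValueError(f"axis_abs only accepts axis in [0, {ndim - 1}], got {axis}")
    axis_len = data.shape[axis]
    if not 0 <= indice < axis_len:
        raise ValueError(
            f"axis_abs only accepts indice in [0, {axis_len - 1}], got {indice}")
    # The output has the same shape and dtype as the input.
    return TensorTypeSketch(data.shape, data.dtype)


print(axis_abs_rel(TensorTypeSketch((4, 4, 1), "int32"), axis=1, indice=1))
# TensorTypeSketch(shape=(4, 4, 1), dtype='int32')
```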
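Q: When is the type relation function called?
A: The function is attached to the operator by the chained add_type_rel() call during registration (see step 3). It is actually invoked during type inference (the InferType pass), when the type solver processes the type relation constraint of each call to axis_abs. It may run more than once: returning false tells the solver that the input type is not known yet, so the relation is retried later.
Q: What does the parameter types mean?
A: types is an array passed by reference. It usually holds the TensorTypes of the inputs followed by the output ([data, output] here). Let's look at class TensorTypeNode first: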
class TensorTypeNode : public BaseTensorTypeNode {
public:
Array<PrimExpr> shape;  // shape of the tensor
DataType dtype;         // element data type of the tensor
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("shape", &shape);
v->Visit("dtype", &dtype);
v->Visit("span", &span);
}
bool SEqualReduce(const TensorTypeNode* other, SEqualReducer equal) const {...}
void SHashReduce(SHashReducer hash_reduce) const {...}
TVM_DLL PrimExpr Size() const;
static constexpr const char* _type_key = "relay.TensorType";
TVM_DECLARE_FINAL_OBJECT_INFO(TensorTypeNode, BaseTensorTypeNode);
};
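TensorTypeNode describes the basic information of a tensor, such as its shape and data type, but holds no actual data, hence the name. Through it, the type relation function obtains the type information of the input tensor and validates the parameters.
Q: What does the parameter reporter mean?
A: class TypeReporter is a reference (container) class for TypeReporterNode: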
class TypeReporter : public ObjectRef {
public:
TypeReporter() {}
explicit TypeReporter(ObjectPtr<Object> n) : ObjectRef(n) {}
TypeReporterNode* operator->() const {
return const_cast<TypeReporterNode*>(static_cast<const TypeReporterNode*>(get()));
}
using ContainerType = TypeReporterNode;
};
It overloads operator->, so in the call
reporter->Assign(types[1], TensorType(data->shape, data->dtype));
a TensorType object is constructed first. Our example only takes the absolute value of part of the data, so the output has the same shape and dtype as the input. reporter->Assign() then calls the pure virtual function virtual void Assign(const Type& dst, const Type& src) = 0 of class TypeReporterNode. This assigns the newly created TensorType to the output type, i.e., types[1].
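Conceptually, Assign() unifies the output type slot with the concrete type computed by the relation. The following is a simplified, hypothetical model, not TVM's actual type solver. Output slots start unbound, assign binds them, and assigning a conflicting type to an already bound slot is an error.

```python
class TypeReporterSketch:
    """Simplified model of TypeReporterNode::Assign as a binding table."""

    def __init__(self):
        self.bindings = {}

    def assign(self, dst, src):
        """Bind the type slot `dst` to `src`; rebinding to a different type is an error."""
        bound = self.bindings.get(dst)
        if bound is not None and bound != src:
            raise TypeError(f"{dst}: cannot unify {bound} with {src}")
        self.bindings[dst] = src


# A type is modeled as a (shape, dtype) tuple.
data_type = ((3, 3, 3), "int32")
reporter = TypeReporterSketch()
# Like reporter->Assign(types[1], TensorType(data->shape, data->dtype)):
reporter.assign("out", (data_type[0], data_type[1]))
print(reporter.bindings["out"])  # ((3, 3, 3), 'int32')
```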
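3. Associate the operator's arguments and attributes
This step registers the custom operator's name and attaches annotations to it through the registration interface. It uses the C++ macro RELAY_REGISTER_OP, which takes the following information:
Arity (number of arguments)
Names and descriptions of the positional arguments
Support level (1 means an internal intrinsic; higher numbers mean less integral or externally supported operators)
The type relation of the operator
Other annotations useful when optimizing the operator.
Register the operator and set its attributes in src/relay/op/tensor/transform.cc: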
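// Register the attrs node type so it can be created and reflected (e.g. attrs.axis in Python).
TVM_REGISTER_NODE_TYPE(AxisAbsAttrs);

RELAY_REGISTER_OP("axis_abs")
    .describe(R"doc(Computes the axis abs of a tensor.)doc" TVM_ADD_FILELINE)
    .set_num_inputs(1)
    .add_argument("data", "Tensor", "The input tensor")
    .set_attrs_type<AxisAbsAttrs>()
    .set_support_level(3)
    .add_type_rel("axis_abs", AxisAbsRel)
    .set_attr<TOpPattern>("TOpPattern", kOpaque);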
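Q: What does the macro RELAY_REGISTER_OP do?
A: RELAY_REGISTER_OP registers a Relay operator. It is defined in include/tvm/relay/op.h as an alias of TVM_REGISTER_OP: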
#define TVM_STR_CONCAT_(__x, __y) __x##__y
#define TVM_STR_CONCAT(__x, __y) TVM_STR_CONCAT_(__x, __y)
#define TVM_OP_REGISTER_VAR_DEF static DMLC_ATTRIBUTE_UNUSED ::tvm::OpRegEntry& __make_##Op
#define TVM_REGISTER_OP(OpName) \
TVM_STR_CONCAT(TVM_OP_REGISTER_VAR_DEF, __COUNTER__) = \
::tvm::OpRegEntry::RegisterOrGet(OpName).set_name()
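which expands to the following. The suffix 0 is the value of __COUNTER__ at that point, and the remaining chained calls such as .describe(...) follow it: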
static DMLC_ATTRIBUTE_UNUSED ::tvm::OpRegEntry& __make_Op0=::tvm::OpRegEntry::RegisterOrGet(OpName).set_name()
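Here __COUNTER__ is a compiler built-in macro whose value starts at 0 and increases by 1 each time it is expanded. It is usually combined with ## to build unique identifiers. To do so, just concatenate any identifier with __COUNTER__:
#define STR_CONCAT_(x, y) x##y  // performs the concatenation
#define STR_CONCAT(x, y) STR_CONCAT_(x, y)  // extra layer needed: ## suppresses expansion of its operands, so __COUNTER__ must be expanded first
#define UNIQUE_NAME(name) STR_CONCAT(name, __COUNTER__)  // concatenate the identifier with __COUNTER__ to get a unique variable name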
::tvm::OpRegEntry::RegisterOrGet(OpName) looks up the operator by name in the global operator registry, creates the entry if it does not exist yet, and returns its OpRegEntry:
OpRegEntry& OpRegEntry::RegisterOrGet(const String& name) {
return OpRegistry::Global()->RegisterOrGet(name);
}
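RegisterOrGet follows a "get or create" pattern on a global table keyed by name. As a result, several registration sites for the same name all modify one entry. Here is a hypothetical Python sketch of that pattern; the class names are illustrative only.

```python
class OpRegEntrySketch:
    """Hypothetical stand-in for OpRegEntry: holds one operator's metadata."""

    def __init__(self, name):
        self.name = name
        self.attrs = {}


class OpRegistrySketch:
    """Global name -> entry table, in the spirit of OpRegistry::Global()."""

    _global = None

    def __init__(self):
        self._entries = {}

    @classmethod
    def global_registry(cls):
        # Lazily create the single global instance.
        if cls._global is None:
            cls._global = cls()
        return cls._global

    def register_or_get(self, name):
        # Return the existing entry, or create and store a new one.
        if name not in self._entries:
            self._entries[name] = OpRegEntrySketch(name)
        return self._entries[name]


first = OpRegistrySketch.global_registry().register_or_get("axis_abs")
second = OpRegistrySketch.global_registry().register_or_get("axis_abs")
print(first is second)  # True
```

This is why attributes registered later from Python, such as the compute in step 4, end up on the same operator that RELAY_REGISTER_OP created in C++.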
Q: What does class OpRegEntry define?
A: It is defined in include/tvm/ir/op.h:
class OpRegEntry {
 public:
const Op& op() const { return op_; }
inline OpRegEntry& describe(const std::string& descr);
inline OpRegEntry& add_argument(const std::string& name, const std::string& type,
const std::string& description);
inline OpRegEntry& add_type_rel(const std::string& rel_name,
runtime::TypedPackedFunc<bool(const Array<Type>&, int, const Attrs&, const TypeReporter&)> type_rel_func);
template <typename AttrsType>
inline OpRegEntry& set_attrs_type();
inline OpRegEntry& set_attrs_type_key(const String& key);
inline OpRegEntry& set_num_inputs(int32_t n);
inline OpRegEntry& set_support_level(int32_t level);
template <typename ValueType>
inline OpRegEntry& set_attr(const std::string& attr_name,
const ValueType& value, int plevel = 10);
inline void reset_attr(const std::string& attr_name);
inline OpRegEntry& set_name() {...}
TVM_DLL static OpRegEntry& RegisterOrGet(const String& name);
private:
template <typename, typename>
friend class AttrRegistry;
std::string name;
Op op_;
TVM_DLL OpRegEntry(uint32_t reg_index);
inline OpNode* get();
TVM_DLL void UpdateAttr(const String& key, runtime::TVMRetValue value, int plevel);
};
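As shown, most member functions return a reference to the entry itself (*this) to make chained calls convenient. They are implemented in the same header. The private member function get() returns an OpNode pointer, and the other member functions use get() to access the operator node: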
inline OpNode* OpRegEntry::get() { return const_cast<OpNode*>(op_.operator->()); }
Q: What do the functions in the chain do?
A: They all assign to members of the OpNode object, so we need the definition of class OpNode:
class OpNode : public RelayExprNode {
public:
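String name;                     // name of the operator
mutable FuncType op_type;        // type of the operator
String description;              // detailed description; can be used to auto-generate docs
Array<AttrFieldInfo> arguments;  // information about the operator's input arguments
String attrs_type_key;           // type key of the attributes field; can be empty
uint32_t attrs_type_index{0};    // type index of the attributes
int32_t num_inputs = -1;         // number of input arguments; -1 means variable length
int32_t support_level = 10;      // support level; the lower the value, the higher the priority
void VisitAttrs(AttrVisitor* v) {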
v->Visit("name", &name);
v->Visit("op_type", &op_type);
v->Visit("description", &description);
v->Visit("arguments", &arguments);
v->Visit("attrs_type_key", &attrs_type_key);
v->Visit("num_inputs", &num_inputs);
v->Visit("support_level", &support_level);
}
...
static constexpr const char* _type_key = "Op";
TVM_DECLARE_FINAL_OBJECT_INFO(OpNode, RelayExprNode);
private:
...
};
Among the chained functions, the following do simple assignments:
describe() assigns OpNode->description;
set_num_inputs() sets the number of input arguments;
set_support_level() sets the support level;
add_argument() appends an element to the arguments array;
set_attr<>() calls UpdateAttr() of class AttrRegistry to update the attribute.
Regarding add_argument(): TVM describes each operator argument with an AttrFieldInfo, whose actual content is an AttrFieldInfoNode:
class AttrFieldInfoNode : public Object {
public:
String name;         // field name
String type_info;    // type description
String description;  // detailed description
void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("type_info", &type_info);
v->Visit("description", &description);
}
static constexpr const char* _type_key = "AttrFieldInfo";
static constexpr bool _type_has_method_sequal_reduce = false;
static constexpr bool _type_has_method_shash_reduce = false;
TVM_DECLARE_FINAL_OBJECT_INFO(AttrFieldInfoNode, Object);
};
So add_argument() first creates an AttrFieldInfoNode object and then appends it to the arguments array:
inline OpRegEntry& OpRegEntry::add_argument(const std::string& name, const std::string& type,
const std::string& description) {
auto n = make_object<AttrFieldInfoNode>();
n->name = name;
n->type_info = type;
n->description = description;
get()->arguments.push_back(AttrFieldInfo(n));
return *this;
}
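Each setter mutates the underlying node and returns the entry itself, which is what allows the RELAY_REGISTER_OP(...).describe(...).set_num_inputs(...) chain. Below is a hypothetical Python sketch of the setters listed above. Arguments are stored as small records in the spirit of AttrFieldInfoNode, and all class names are illustrative.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class AttrFieldInfoSketch:
    """Stand-in for AttrFieldInfoNode."""
    name: str
    type_info: str
    description: str


@dataclass
class OpNodeSketch:
    """Stand-in for the OpNode fields written by the chained setters."""
    name: str
    description: str = ""
    arguments: List[AttrFieldInfoSketch] = field(default_factory=list)
    num_inputs: int = -1  # -1 means a variable number of inputs
    support_level: int = 10


class OpRegEntryChainSketch:
    def __init__(self, name):
        self.node = OpNodeSketch(name)

    def describe(self, text):
        self.node.description = text
        return self  # returning self enables chaining

    def set_num_inputs(self, n):
        self.node.num_inputs = n
        return self

    def add_argument(self, name, type_info, description):
        # Build the field-info record first, then append it, as add_argument() does.
        self.node.arguments.append(AttrFieldInfoSketch(name, type_info, description))
        return self

    def set_support_level(self, level):
        self.node.support_level = level
        return self


entry = (OpRegEntryChainSketch("axis_abs")
         .describe("Computes the axis abs of a tensor.")
         .set_num_inputs(1)
         .add_argument("data", "Tensor", "The input tensor")
         .set_support_level(3))
print(entry.node.num_inputs, entry.node.support_level, entry.node.arguments[0].name)
# 1 3 data
```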
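The two that need a more detailed explanation are .add_type_rel() and .set_attr().
Q: How does add_type_rel() work?
A: The function creates a TypeVarNode for each input and one for the output. It then creates a TypeRelationNode that holds the type relation function and builds a FuncTypeNode from these objects, which is finally assigned to op_type. Because the loop reads get()->num_inputs, set_num_inputs() has to come before add_type_rel() in the registration chain.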
inline OpRegEntry& OpRegEntry::add_type_rel(
const std::string& rel_name,
runtime::TypedPackedFunc<bool(const Array<Type>&, int, const Attrs&, const TypeReporter&)>
type_rel_func) {
auto func_name = std::string("tvm.relay.type_relation.") + rel_name;
TypeRelationFn env_type_rel_func;
if (runtime::Registry::Get(func_name)) {
auto env_func = EnvFunc::Get(func_name);
env_type_rel_func = env_func;
} else {
runtime::Registry::Register(func_name).set_body(type_rel_func.packed());  // register type_rel_func in the global registry
auto env_func = EnvFunc::Get(func_name);
env_type_rel_func = env_func;  // links to the type relation function defined in step 2
}
Array<TypeVar> type_params;  // TypeVar is a subclass of Type
Array<Type> arg_types;
// Add inputs.
std::string input_name_prefix = "in";
for (int i = 0; i < get()->num_inputs; i++) {
auto name = input_name_prefix + std::to_string(i);
auto param = TypeVar(name, TypeKind::kType);  // creates a TypeVarNode
type_params.push_back(param);
arg_types.push_back(param);
}
Array<Type> ty_call_args = arg_types;
// Add output type.
auto out_param = TypeVar("out", TypeKind::kType);
type_params.push_back(out_param);
ty_call_args.push_back(out_param);
TypeConstraint type_rel =
TypeRelation(env_type_rel_func, ty_call_args, arg_types.size(), Attrs());  // creates a TypeRelationNode
auto func_type = FuncType(arg_types, out_param, type_params, {type_rel});  // creates a FuncTypeNode
get()->op_type = func_type;  // assign to the op_type member
return *this;
}
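The bookkeeping in add_type_rel() can be reproduced with plain data to show the shape of the resulting FuncType. In this hypothetical sketch, strings stand in for TypeVar objects and a dict stands in for TypeRelationNode. For num_inputs = 1, the relation's args list is [in0, out], which is exactly the types array that AxisAbsRel receives as [data, output].

```python
def build_op_type(rel_name, num_inputs):
    """Mimic the bookkeeping of OpRegEntry::add_type_rel with plain Python data."""
    type_params = []
    arg_types = []
    for i in range(num_inputs):  # with num_inputs == -1 this loop does nothing
        param = f"in{i}"
        type_params.append(param)
        arg_types.append(param)
    ty_call_args = list(arg_types)
    out_param = "out"
    type_params.append(out_param)
    ty_call_args.append(out_param)
    type_rel = {
        "func": f"tvm.relay.type_relation.{rel_name}",
        "args": ty_call_args,  # inputs followed by the output
        "num_inputs": len(arg_types),
    }
    return {
        "arg_types": arg_types,
        "ret_type": out_param,
        "type_params": type_params,
        "type_constraints": [type_rel],
    }


op_type = build_op_type("axis_abs", 1)
print(op_type["type_constraints"][0]["args"])  # ['in0', 'out']
```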
TypeRelation() creates a TypeRelationNode, which stores the information about the type relation function defined earlier:
class TypeRelationNode : public TypeConstraintNode {
public:
TypeRelationFn func;
Array<Type> args; // The type arguments to the type function.
int num_inputs; // Number of inputs arguments
Attrs attrs; // Attributes to the relation function
void VisitAttrs(AttrVisitor* v) {
v->Visit("func", &func);
v->Visit("args", &args);
v->Visit("num_inputs", &num_inputs);
v->Visit("attrs", &attrs);
v->Visit("span", &span);
}
...
};
FuncType() creates a FuncTypeNode from the argument types, the return type, the type parameters and the type relation node:
class FuncTypeNode : public TypeNode {
public:
Array<Type> arg_types;  // The type of arguments
Type ret_type; // The type of return value
Array<TypeVar> type_params; // The type parameters of the function
Array<TypeConstraint> type_constraints; // potential constraint the type need to obey
void VisitAttrs(AttrVisitor* v) {
v->Visit("arg_types", &arg_types);
v->Visit("ret_type", &ret_type);
v->Visit("type_params", &type_params);
v->Visit("type_constraints", &type_constraints);
v->Visit("span", &span);
}
...
};
4. Implement the operator's compute
There are two ways to implement the operator's computation:
(1) Implementing the computation in Python
Add the following to python/tvm/topi/transform.py:
def axis_abs(x, axis, indice):
    """Take the absolute value of the slice `indice` along `axis` of x, element-wise.

    Parameters
    ----------
    x : tvm.te.Tensor
        Input 3-D tensor.
    axis : int
        The axis along which the slice is selected.
    indice : int
        The index of the slice along `axis`.

    Returns
    -------
    y : tvm.te.Tensor
        The result.
    """
    # Assumes `te` and `get_const_int` are available in this module
    # (e.g. from tvm import te; from tvm.topi.utils import get_const_int).
    ishape = x.shape
    assert len(ishape) == 3
    assert 0 <= axis < 3
    assert 0 <= indice < get_const_int(ishape[axis])
    # Keep non-negative values; negate negative values only in the selected slice.
    if axis == 0:
        return te.compute(x.shape, lambda i, j, k: te.if_then_else(
            x[i, j, k] >= 0, x[i, j, k],
            te.if_then_else(i == indice, -x[i, j, k], x[i, j, k])))
    elif axis == 1:
        return te.compute(x.shape, lambda i, j, k: te.if_then_else(
            x[i, j, k] >= 0, x[i, j, k],
            te.if_then_else(j == indice, -x[i, j, k], x[i, j, k])))
    else:
        return te.compute(x.shape, lambda i, j, k: te.if_then_else(
            x[i, j, k] >= 0, x[i, j, k],
            te.if_then_else(k == indice, -x[i, j, k], x[i, j, k])))
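All three te.compute branches encode the same rule: an element is negated only if it is negative and its index along axis equals indice. A dependency-free Python reference for that rule on nested lists might look like the following. The helper is hypothetical and is handy for checking results by hand without NumPy.

```python
def axis_abs_reference(x, axis, indice):
    """Pure-Python reference for axis_abs on a 3-D nested list."""
    assert axis in (0, 1, 2)
    out = []
    for i, plane in enumerate(x):
        out_plane = []
        for j, row in enumerate(plane):
            out_row = []
            for k, v in enumerate(row):
                idx = (i, j, k)[axis]  # index of this element along `axis`
                out_row.append(-v if (idx == indice and v < 0) else v)
            out_plane.append(out_row)
        out.append(out_plane)
    return out


x = [[[-1, 2], [-3, -4]], [[5, -6], [-7, 8]]]
print(axis_abs_reference(x, axis=0, indice=0))
# [[[1, 2], [3, 4]], [[5, -6], [-7, 8]]]
```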
Then set the operator's compute attribute in python/tvm/relay/op/_transform.py:
@_reg.register_compute("axis_abs")  # register the compute attribute; the default level is 10
def compute_axis_abs(attrs, inputs, output_type):
    """Compute definition of axis_abs"""
    # FTVMCompute must return a list of tensors.
    return [topi.axis_abs(inputs[0], attrs.axis, attrs.indice)]
(2) Implementing the computation in C++
Add the operator's compute function to src/relay/op/tensor/transform.cc:
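Array<te::Tensor> AxisAbsCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                                 const Type& out_type) {
  const auto* param = attrs.as<AxisAbsAttrs>();
  ICHECK(param != nullptr);
  // TODO: build the output with te::compute over inputs[0]->shape, negating the
  // elements that are negative and whose index along param->axis equals param->indice.
  return {};
}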
When registering the operator with RELAY_REGISTER_OP("axis_abs"), also set its compute attribute:
.set_attr<FTVMCompute>("FTVMCompute", AxisAbsCompute)
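With this in place, the operator implementation in python/tvm/topi/transform.py can call the C++ code directly. Note that cpp.axis_abs assumes a matching C++ TOPI function has been exposed to Python under the global name topi.axis_abs (for example via TVM_REGISTER_GLOBAL), a step this article does not show:
def axis_abs(x, axis, indice):
    """Take the absolute value of the slice `indice` along `axis` of x, element-wise.

    Parameters
    ----------
    x : tvm.te.Tensor
        Input 3-D tensor.
    axis : int
        The axis along which the slice is selected.
    indice : int
        The index of the slice along `axis`.

    Returns
    -------
    y : tvm.te.Tensor
        The result.
    """
    return cpp.axis_abs(x, axis, indice)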
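5. Register the operator's compute and schedule
After implementing the compute logic, we need to bind it to the operator interface. In TVM this means providing not only the operator's compute but also a corresponding schedule. The strategy is what pairs a compute with a suitable schedule (here, the generic injective schedule). Add the operator's strategy to python/tvm/relay/op/strategy/generic.py: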
def wrap_compute_axis_abs(topi_compute):
    """Wrap axis_abs topi compute"""

    def _compute_axis_abs(attrs, inputs, _):
        return [topi_compute(inputs[0], attrs.axis, attrs.indice)]

    return _compute_axis_abs


@override_native_generic_func("axis_abs_strategy")
def axis_abs_strategy(attrs, inputs, out_type, target):
    """axis_abs generic strategy"""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_axis_abs(topi.axis_abs),
        wrap_topi_schedule(topi.generic.schedule_injective),
        name="axis_abs.generic",
    )
    return strategy
Associate the operator with its strategy in python/tvm/relay/op/_transform.py:
_reg.register_strategy("axis_abs", strategy.axis_abs_strategy)
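override_native_generic_func turns axis_abs_strategy into a generic function: the body above is the fallback, and target-specific strategies can be registered to override it. The dispatch idea can be sketched in plain Python. This is a hypothetical model: the decorator name generic_by_target and the target strings are illustrative, and TVM's real implementation dispatches on the target's keys through its own generic-function machinery.

```python
def generic_by_target(default_func):
    """Hypothetical mini generic function that dispatches on a target key string."""
    overrides = {}

    def dispatch(attrs, inputs, out_type, target):
        impl = overrides.get(target, default_func)
        return impl(attrs, inputs, out_type, target)

    def register(target_key):
        def decorator(func):
            overrides[target_key] = func
            return func
        return decorator

    dispatch.register = register
    return dispatch


@generic_by_target
def axis_abs_strategy(attrs, inputs, out_type, target):
    return "axis_abs.generic"  # fallback: generic compute + injective schedule


@axis_abs_strategy.register("cuda")
def axis_abs_strategy_cuda(attrs, inputs, out_type, target):
    return "axis_abs.cuda"


print(axis_abs_strategy(None, [], None, "llvm"))  # axis_abs.generic
print(axis_abs_strategy(None, [], None, "cuda"))  # axis_abs.cuda
```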
6. Generate a call node for the operator and register the API hook
We now have a Relay operator that can be invoked. The next step is to call it through a Relay call node. This requires a function that passes the arguments to the Relay operator and returns a Call node for it; the node ends up in the AST of the Relay expression. Add the following to src/relay/op/tensor/transform.cc:
Expr MakeAxisAbs(Expr data, int axis, int indice) {
auto attrs = make_object<AxisAbsAttrs>();
attrs->axis = axis;
attrs->indice = indice;
static const Op& op = Op::Get("axis_abs");
return Call(op, {data}, Attrs(attrs), {});  // creates a CallNode instance
}
TVM_REGISTER_GLOBAL("relay.op._make.axis_abs").set_body_typed(MakeAxisAbs);
Q: What is a Call Node?
A: class CallNode is a subclass of ExprNode. It is instantiated when the program calls the Call constructor:
class CallNode : public ExprNode {
protected:
Object::FDeleter saved_deleter_;
static void Deleter_(Object* ptr);
public:
Expr op;                          // the operator (or function) being called
tvm::Array<relay::Expr> args;     // arguments of the call
Attrs attrs;                      // attributes
tvm::Array<Type> type_args;       // type arguments passed to a polymorphic (template) function
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("op", &op);
v->Visit("args", &args);
v->Visit("attrs", &attrs);
v->Visit("type_args", &type_args);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
bool SEqualReduce(const CallNode* other, SEqualReducer equal) const {...}
void SHashReduce(SHashReducer hash_reduce) const {...}
static constexpr const char* _type_key = "relay.Call";
TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, ExprNode);
template <typename>
friend class runtime::ObjAllocatorBase;
friend class Call;
};
Call::Call(Expr op, Array<Expr> args, Attrs attrs, Array<Type> type_args, Span span) {
ObjectPtr<CallNode> n = make_object<CallNode>();  // create the CallNode
n->op = std::move(op);
n->args = std::move(args);
n->attrs = std::move(attrs);
n->type_args = std::move(type_args);
n->span = std::move(span);
data_ = std::move(n);
}
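MakeAxisAbs packs its scalar arguments into an attrs object and wraps everything in an immutable Call node. Here is a hypothetical Python mirror that uses frozen dataclasses. These classes are illustrative and are not TVM's relay.Call.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class AxisAbsAttrsData:
    """Illustrative stand-in for AxisAbsAttrs."""
    axis: int = 0
    indice: int = 0


@dataclass(frozen=True)
class CallSketch:
    """Mirror of CallNode's main fields."""
    op: str
    args: Tuple[object, ...]
    attrs: AxisAbsAttrsData
    type_args: Tuple[object, ...] = ()


def make_axis_abs(data, axis, indice):
    # Like MakeAxisAbs: fill the attrs, refer to the op by name, build the call.
    attrs = AxisAbsAttrsData(axis=axis, indice=indice)
    return CallSketch(op="axis_abs", args=(data,), attrs=attrs)


call = make_axis_abs("x", axis=1, indice=2)
print(call.op, call.args, call.attrs)
```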
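7. Wrap the Python API hook in a concise interface
For convenience, the hook is usually exposed through a standalone function, so it is best to wrap it in a simpler Python interface. Here _make.axis_abs is the global function registered as relay.op._make.axis_abs in the previous step. Add the following to python/tvm/relay/op/transform.py: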
def axis_abs(data, axis=0, indice=0):
    """Computes the absolute value of the slice `indice` along `axis` of data.

    Parameters
    ----------
    data : relay.Expr
        The input data.
    axis : int
        The axis along which the slice is selected.
    indice : int
        The index of the slice along `axis`.

    Returns
    -------
    ret : relay.Expr
        The result. Has the same type as data.
    """
    return _make.axis_abs(data, axis, indice)
8. Write test cases for the new Relay operator
Add the following to tvm/tests/python/relay/test_op_level3.py:
# Relies on the imports already at the top of test_op_level3.py
# (numpy as np, tvm, tvm.testing, relay, create_executor, run_infer_type).
class TestAxisAbs:
    # Test parameters: input tensor shape, axis and indice.
    dshape, axis, indice = tvm.testing.parameters(
        ((4, 4, 1), 1, 1),
        ((4, 4, 1), 0, 1),
        ((3, 3, 3), 1, 1),
    )

    def test_axis_abs(self, dshape, axis, indice):
        x = relay.var("x", relay.TensorType(dshape, "int32"))  # Relay input tensor
        y = relay.axis_abs(x, axis=axis, indice=indice)  # axis_abs expression
        # Infer the expression type; run_infer_type is defined in python/tvm/relay/testing/__init__.py
        yy = run_infer_type(y)
        assert yy.checked_type == relay.TensorType(dshape, "int32")  # type check
        data = np.random.randint(-5, 5, size=dshape).astype("int32")
        # Create an executor and evaluate the operator.
        op_res = create_executor().evaluate(y, {x: relay.const(data)})
        # NumPy reference: take abs of the selected slice only.
        ref_res = data.copy()
        if axis == 0:
            ref_res[indice, :, :] = np.abs(ref_res[indice, :, :])
        elif axis == 1:
            ref_res[:, indice, :] = np.abs(ref_res[:, indice, :])
        else:
            ref_res[:, :, indice] = np.abs(ref_res[:, :, indice])
        np.testing.assert_equal(op_res.numpy(), ref_res)  # compare the Relay result with NumPy
If pytest is not installed, install it first, then run the test cases:
pip install pytest && cd tvm/tests/python && pytest relay/test_op_level3.py::TestAxisAbs
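When the test passes, pytest reports all three parametrized TestAxisAbs cases as passed.
III. Summary
Following the steps in the official TVM documentation, this article added a custom operator and explained the points along the way that raised questions.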