【TVM系列五】添加Relay自定義算子

一芙盘、前言

本文以實(shí)現(xiàn)一個(gè)axis_abs的自定義算子為例介紹如何在tvm中添加新的relay算子踏烙,該算子實(shí)現(xiàn)的功能是以輸入的3維tensor取某一維度的指定切片取絕對值舱痘。

二臭胜、添加自定義算子

新增relay算子基本是下面幾個(gè)步驟:

  • 定義新增算子的屬性節(jié)點(diǎn)(Attribute Node)录语,聲明在編譯時(shí)已知的固定參數(shù);

  • 為新增算子編寫類型關(guān)系注服,以集成到relay的類型系統(tǒng)中韭邓;

  • 使用C++RELAY_REGISTER_OP宏措近,為新增算子注冊生命參數(shù)數(shù)量、類型女淑、提示信息瞭郑;

  • 算子的compute實(shí)現(xiàn);

  • 注冊算子的compute鸭你、schedule凰浮;

  • 定義C++函數(shù),為新增算子生成調(diào)用節(jié)點(diǎn)苇本,并為該函數(shù)注冊 Python API hook;

  • 將上面的 Python API hook 封裝成簡潔的調(diào)用方式菜拓;

  • 為新的relay 算子編寫測試瓣窄。

1、定義新增算子的屬性節(jié)點(diǎn)(Attribute Node)

在include/tvm/relay/attrs/transform.h中增加算子的屬性數(shù)據(jù)結(jié)構(gòu):

/*! \brief Attributes used in axisabs operator */
struct AxisAbsAttrs : public tvm::AttrsNode<AxisAbsAttrs> {
    int axis;
    int indice;

    TVM_DECLARE_ATTRS(AxisAbsAttrs, "relay.attrs.AxisAbsAttrs") {
        TVM_ATTR_FIELD(axis).set_default(0).describe("Axis to abs");
        TVM_ATTR_FIELD(indice).set_default(0).describe("Indice to abs");
    }
};

Q:宏TVM_DECLARE_ATTRS 與 TVM_ATTR_FIELD的作用是什么纳鼎?
A:這兩個(gè)宏定義在 include/tvm/ir/attrs.h

#define TVM_DECLARE_ATTRS(ClassName, TypeKey)                    \
  static constexpr const char* _type_key = TypeKey;              \
  TVM_DECLARE_FINAL_OBJECT_INFO(ClassName, ::tvm::BaseAttrsNode) \
  template <typename FVisit>                                     \
  void __VisitAttrs__(FVisit& __fvisit__)  // NOLINT(*)

#define TVM_ATTR_FIELD(FieldName) __fvisit__(#FieldName, &FieldName)

其中的TVM_DECLARE_FINAL_OBJECT_INFO定義在include/tvm/runtime/object.h

#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType) \
   static const constexpr bool _type_final = true;           \
   static const constexpr int _type_child_slots = 0;         \
   TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)
  
#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)                                     \
   static_assert(!ParentType::_type_final, "ParentObj marked as final");                        \
   static uint32_t RuntimeTypeIndex() {                                                         \
     static_assert(TypeName::_type_child_slots == 0 || ParentType::_type_child_slots == 0 ||    \
                      TypeName::_type_child_slots < ParentType::_type_child_slots,             \
                  "Need to set _type_child_slots when parent specifies it.");                  \
     if (TypeName::_type_index != ::tvm::runtime::TypeIndex::kDynamic) {                        \
       return TypeName::_type_index;                                                            \
     }                                                                                          \
     return _GetOrAllocRuntimeTypeIndex();                                                      \
   }                                                                                            \
  static uint32_t _GetOrAllocRuntimeTypeIndex() {                                              \
    static uint32_t tindex = Object::GetOrAllocRuntimeTypeIndex(                               \
        TypeName::_type_key, TypeName::_type_index, ParentType::_GetOrAllocRuntimeTypeIndex(), \
        TypeName::_type_child_slots, TypeName::_type_child_slots_can_overflow);                \
    return tindex;                                                                             \
  }

所以宏展開后定義的屬性節(jié)點(diǎn)數(shù)據(jù)結(jié)構(gòu)為:

struct AxisAbsAttrs : public tvm::ArrayNode<AxisAbsAttrs> {
    int axis;    
    static constexpr const char* _type_key = "relay.attrs.AxisAbsAttrs";
    static const constexpr bool _type_final = true;
    static const constexpr int _type_child_slots = 0;

    static_assert(!::tvm::BaseAttrsNode::_type_final, "ParentObj marked as final");

    static uint32_t RuntimeTypeIndex() {                                                       
        static_assert(AxisAbsAttrs::_type_child_slots == 0 || ::tvm::BaseAttrsNode::_type_child_slots == 0 ||    
                          AxisAbsAttrs::_type_child_slots < ::tvm::BaseAttrsNode::_type_child_slots,             
                      "Need to set _type_child_slots when parent specifies it.");                  
        if (AxisAbsAttrs::_type_index != ::tvm::runtime::TypeIndex::kDynamic) {                        
            return AxisAbsAttrs::_type_index;                                                            
        }                                                                                          
         return _GetOrAllocRuntimeTypeIndex();                                                      
      }           

    static uint32_t _GetOrAllocRuntimeTypeIndex() {                                              
         static uint32_t tindex = Object::GetOrAllocRuntimeTypeIndex(                               
         AxisAbsAttrs::_type_key, AxisAbsAttrs::_type_index, ::tvm::BaseAttrsNode::_GetOrAllocRuntimeTypeIndex(), 
         AxisAbsAttrs::_type_child_slots, AxisAbsAttrs::_type_child_slots_can_overflow);                
         return tindex;                                                                             
    }

    template <typename FVisit>                                    
    void __VisitAttrs__(FVisit& __fvisit__)  {
        __fvisit__(axis, &axis).set_default(0).describe("Axis to abs");
    }
}

可以看到俺夕,每個(gè)屬性節(jié)點(diǎn)都定義了獲取運(yùn)行時(shí)類型索引的函數(shù)RuntimeTypeIndex()以及訪問屬性內(nèi)部成員的模版函數(shù)VisitAttrs(FVisit& fvisit)。

Q:模版函數(shù)VisitAttrs(FVisit& fvisit)的調(diào)用過程是怎么樣的贱鄙?
A:首先分析定義在include/tvm/ir/attrs.h中的類class AttrsNode

template <typename DerivedType>
class AttrsNode : public BaseAttrsNode {
public:
  void VisitAttrs(AttrVisitor* v) {
    ::tvm::detail::AttrNormalVisitor vis(v);
    self()->__VisitAttrs__(vis);
  }
  void VisitNonDefaultAttrs(AttrVisitor* v) {...}
  void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final {...}
  bool SEqualReduce(const DerivedType* other, SEqualReducer equal) const {...}
  void SHashReduce(SHashReducer hash_reducer) const {...}
  Array<AttrFieldInfo> ListFieldInfo() const final {...}
private:
  DerivedType* self() const {
    return const_cast<DerivedType*>(static_cast<const DerivedType*>(this));
  }
};

它是一個(gè)模版類劝贸,模版參數(shù)是繼承它的子類類型,在成員函數(shù)VisitAttrs(AttrVisitor* v)中逗宁,傳入屬性訪問器類AttrVisitor對象:

class AttrVisitor {
 public:
  //! \cond Doxygen_Suppress
  TVM_DLL virtual ~AttrVisitor() = default;
  TVM_DLL virtual void Visit(const char* key, double* value) = 0;
  TVM_DLL virtual void Visit(const char* key, int64_t* value) = 0;
  TVM_DLL virtual void Visit(const char* key, uint64_t* value) = 0;
  TVM_DLL virtual void Visit(const char* key, int* value) = 0;
  TVM_DLL virtual void Visit(const char* key, bool* value) = 0;
  TVM_DLL virtual void Visit(const char* key, std::string* value) = 0;
  TVM_DLL virtual void Visit(const char* key, void** value) = 0;
  TVM_DLL virtual void Visit(const char* key, DataType* value) = 0;
  TVM_DLL virtual void Visit(const char* key, runtime::NDArray* value) = 0;
  TVM_DLL virtual void Visit(const char* key, runtime::ObjectRef* value) = 0;
  template <typename ENum, typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
  void Visit(const char* key, ENum* ptr) {
    static_assert(std::is_same<int, typename std::underlying_type<ENum>::type>::value,
                  "declare enum to be enum int to use visitor");
    this->Visit(key, reinterpret_cast<int*>(ptr));
  }
  //! \endcond
};

然后通過::tvm::detail::AttrNormalVisitor vis(v);包裹一層普通屬性訪問函數(shù):

// Wrapper for normal visitor.
class AttrNormalVisitor {
public:
  explicit AttrNormalVisitor(AttrVisitor* visitor) : visitor_(visitor) {}
  template <typename T>
  AttrNopEntry operator()(const char* key, T* value) {
    visitor_->Visit(key, value);
    return AttrNopEntry();
  }

private:
  AttrVisitor* visitor_;
};

它重載了運(yùn)算符“()”映九,當(dāng)class AttrsNode通過self()->VisitAttrs(vis)獲取子類的對象并通過子類對象調(diào)用VisitAttrs(FVisit& fvisit) 時(shí),隨即調(diào)用了fvisit(axis, &axis)瞎颗,這個(gè)fvisit最終調(diào)到的就是class AttrNormalVisitor 中的重載"()"函數(shù)件甥,這個(gè)函數(shù)會返回一個(gè)結(jié)構(gòu)體用于支持鏈?zhǔn)秸{(diào)用:

// helper entry that does nothing in set_default/bound/describe calls.
struct AttrNopEntry {
  using TSelf = AttrNopEntry;
  TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { return *this; }
  template <typename T>
  TSelf& set_default(DMLC_ATTRIBUTE_UNUSED const T& value) {return *this;}
  template <typename T>
  TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T& begin) {return *this;}
  template <typename T>
  TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T& end) {return *this;}
};

這些調(diào)用實(shí)際上什么都沒有做就返回了其自身。

2哼拔、編寫算子類型關(guān)系引有,集成到Relay的類型系統(tǒng)

為了算子注冊的靈活性以及relay算子有更好的泛化能力,relay算子通過輸入輸出之間的類型關(guān)系來實(shí)例化倦逐。本質(zhì)上譬正,算子類型關(guān)系除了推導(dǎo)輸出類型外,還能夠強(qiáng)制指定類型規(guī)則(檢查輸入類型)檬姥。需要在src\relay\op\tensor\transform.cc中添加算子的類型關(guān)系處理函數(shù):

bool AxisAbsRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
               const TypeReporter& reporter) {
    // types: [data, output]
    ICHECK_EQ(types.size(), 2);
    const auto* data = types[0].as<TensorTypeNode>();
    if (data == nullptr) {
      ICHECK(types[0].as<IncompleteTypeNode>())
          << "cast: expect input type to be TensorType but get " << types[0];
      return false;
    }
    const auto* param = attrs.as<AxisAbsAttrs>();
    const int ndim = static_cast<int>(data->shape.size());
    const int axis = param->axis;
    const int axis_len = data->shape[axis].as<IntImmNode>()->value;
    const int indice = param->indice;

    ICHECK(0 <= axis && axis < ndim)
      << "axis_abs only accepts `axis` in [0, data.ndim - 1]"
      << ", but got axis = " << axis << ", and data.ndim = " << ndim;

    ICHECK(0 <= indice && indice < axis_len)
      << "axis_abs only accepts `indice` in [0, data[axis] - 1"
      << ", but got indice = " << indice << ", and data[axis] = " << axis_len;

    reporter->Assign(types[1], TensorType(data->shape, data->dtype));
    return true;
}

Q:類型關(guān)系處理函數(shù)在什么時(shí)候調(diào)用曾我?
A:類型關(guān)系處理函數(shù)在注冊Relay算子時(shí)通過鏈?zhǔn)秸{(diào)用add_type_rel()注冊。

Q:函數(shù)輸入?yún)?shù)types的含意是什么穿铆?
A:types傳入的是一個(gè)數(shù)組引用您单,內(nèi)容一般為輸入與輸出的TensorType,首先看class TensorTypeNode:

class TensorTypeNode : public BaseTensorTypeNode {
 public:

  Array<PrimExpr> shape;   // Tensor的shape
  DataType dtype;    // Tensor中數(shù)據(jù)類型
  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("shape", &shape);
    v->Visit("dtype", &dtype);
    v->Visit("span", &span);
  }
  bool SEqualReduce(const TensorTypeNode* other, SEqualReducer equal) const {...}
  void SHashReduce(SHashReducer hash_reduce) const {...}
  TVM_DLL PrimExpr Size() const;
  static constexpr const char* _type_key = "relay.TensorType";
  TVM_DECLARE_FINAL_OBJECT_INFO(TensorTypeNode, BaseTensorTypeNode);
};

它定義了一個(gè)Tensor所需要的基本數(shù)據(jù)信息如:shape與數(shù)據(jù)類型荞雏,但是并沒有實(shí)際的數(shù)據(jù)虐秦,所以類名也就叫TensorTypeNode平酿。通過它可以獲取到輸入Tensor的類型信息從而對參數(shù)做合法性檢查。

Q:函數(shù)輸入?yún)?shù)reporter的含意是什么悦陋?
A:class TypeReporter是一個(gè)TypeReporterNode的容器類:

class TypeReporter : public ObjectRef {
 public:
  TypeReporter() {}
  explicit TypeReporter(ObjectPtr<Object> n) : ObjectRef(n) {}
  TypeReporterNode* operator->() const {
    return const_cast<TypeReporterNode*>(static_cast<const TypeReporterNode*>(get()));
  }
  using ContainerType = TypeReporterNode;
};

它重載了運(yùn)算符"->"蜈彼,所以:

reporter->Assign(types[1], TensorType(data->shape, data->dtype));

首先會實(shí)例化一個(gè)TensorType對象,因?yàn)槲覀兊睦邮菍δ骋粋€(gè)維度的數(shù)據(jù)取絕對值俺驶,所以輸出的數(shù)據(jù)shape及dtype與輸入相同幸逆。然后通過reporter->Assign()調(diào)用class TypeReporterNode中純虛函數(shù)virtual void Assign(dst, src) = 0,將創(chuàng)建好的TensorType對象賦值給輸出TensorType暮现,即types[1]还绘。

3、關(guān)聯(lián)算子的參數(shù)數(shù)目栖袋、屬性

這一步的操作拍顷,為自定義算子注冊算子名稱,通過調(diào)用接口增加算子注釋塘幅。這里需要用到C++的宏RELAY_REGISTER_OP昔案,涉及的參數(shù)含義如下:

  • Arity(參數(shù)數(shù)量)

  • 位置參數(shù)的名稱和描述

  • 支持級別(1 表示內(nèi)部實(shí)現(xiàn);較高的數(shù)字表示較少的內(nèi)部支持或外部支持的算子)

  • 算子的類型關(guān)系

  • 優(yōu)化算子時(shí)有用的其他注釋。

需要在src/relay/op/tensor/transform.cc中注冊算子并設(shè)置相關(guān)屬性:

RELAY_REGISTER_OP("axis_abs")
    .describe(R"doc(Computes the axis abs of a tensor.)doc") TVM_ADD_FILELINE)
    .set_num_inputs(1)
    .add_argument("data", "Tensor", "The input tensor")
    .set_support_level(3)
    .add_type_rel("axis_abs", AxisAbsRel)
    .set_attr<TOpPattern>("TOpPattern", kOpaque);

Q:宏RELAY_REGISTER_OP做了什么电媳?
A:RELAY_REGISTER_OP用于注冊Relay算子:

#define TVM_STR_CONCAT_(__x, __y) __x##__y
#define TVM_STR_CONCAT(__x, __y) TVM_STR_CONCAT_(__x, __y) 
#define TVM_OP_REGISTER_VAR_DEF static DMLC_ATTRIBUTE_UNUSED ::tvm::OpRegEntry& __make_##Op

#define TVM_REGISTER_OP(OpName)                          \
  TVM_STR_CONCAT(TVM_OP_REGISTER_VAR_DEF, __COUNTER__) = \
      ::tvm::OpRegEntry::RegisterOrGet(OpName).set_name()

展開為:

static DMLC_ATTRIBUTE_UNUSED ::tvm::OpRegEntry& __make_Op0=::tvm::OpRegEntry::RegisterOrGet(OpName).set_name()

其中COUNTER為編譯器內(nèi)置宏踏揣,初值是0,每預(yù)編譯一次其值自己加1匾乓,通常配合 ## 使用捞稿,用于構(gòu)建唯一的標(biāo)識符,做法其實(shí)很簡單钝尸,把任意一個(gè)標(biāo)識符與 COUNTER 合并就可以了:

#define STR_CONCAT_(x, y) x##y  // 合并用的宏
#define STR_CONCAT(x, y) STR_CONCAT_(x, y)    // 因?yàn)?## 的特性 ( 阻止另一個(gè)宏的展開 )括享,需要中間層
#define UNIQUE_NAME(name) STR_CONCAT(name, __COUNTER__)  // 把標(biāo)識符與 __COUNTER__合并, 就可以建立唯一的變數(shù)名稱了

而::tvm::OpRegEntry::RegisterOrGet(OpName)通過算子名稱在全局的算子注冊機(jī)對象中查找算子并返回OpRegEntry對象:

OpRegEntry& OpRegEntry::RegisterOrGet(const String& name) {
  return OpRegistry::Global()->RegisterOrGet(name);
}

Q:類OpRegEntry定義了什么珍促?
A:類的定義在include/tvm/ir/op.h:

class OpRegEntry {public:
  const Op& op() const { return op_; }
  inline OpRegEntry& describe(const std::string& descr); 
  inline OpRegEntry& add_argument(const std::string& name, const std::string& type,
                                  const std::string& description);
  inline OpRegEntry& add_type_rel(const std::string& rel_name,
      runtime::TypedPackedFunc<bool(const Array<Type>&, int, const Attrs&, const TypeReporter&)> type_rel_func);
  template <typename AttrsType>
  inline OpRegEntry& set_attrs_type();
  inline OpRegEntry& set_attrs_type_key(const String& key);
  inline OpRegEntry& set_num_inputs(int32_t n); 
  inline OpRegEntry& set_support_level(int32_t level); 
  template <typename ValueType>
  inline OpRegEntry& set_attr(const std::string& attr_name, 
                              const ValueType& value, int plevel = 10);
  inline void reset_attr(const std::string& attr_name);
  inline OpRegEntry& set_name() {...}
  TVM_DLL static OpRegEntry& RegisterOrGet(const String& name);
private:
  template <typename, typename>
  friend class AttrRegistry;
  std::string name;
  Op op_;
  TVM_DLL OpRegEntry(uint32_t reg_index);
  inline OpNode* get()
  TVM_DLL void UpdateAttr(const String& key, runtime::TVMRetValue value, int plevel);
};

可以看到铃辖,大部分成員函數(shù)都是返回自身的指針,從而方便鏈?zhǔn)秸{(diào)用猪叙,它們的實(shí)現(xiàn)代碼與定義在同一個(gè)頭文件中娇斩。其中的get()私有成員函數(shù)會返回OpNode指針,其它成員函數(shù)通過get()來獲取算子節(jié)點(diǎn)的指針:

inline OpNode* OpRegEntry::get() { return const_cast<OpNode*>(op_.operator->()); }

Q:鏈?zhǔn)秸{(diào)用中的幾個(gè)函數(shù)作用是什么穴翩?
A:鏈?zhǔn)秸{(diào)用中的幾個(gè)函數(shù)都是對OpNode節(jié)點(diǎn)對象的成員進(jìn)行賦值犬第,所以需要了解class OpNode的定義:

class OpNode : public RelayExprNode {
public:
  String name;  // 算子的名稱
  mutable FuncType op_type;  // 算子的類型
  String description;  // 算子的具體描述,可以用在自動生成說明文檔
  Array<AttrFieldInfo> arguments;  // 算子的輸入?yún)?shù)信息
  String attrs_type_key;  // 屬性字段的類型鍵值芒帕,可以為空
  uint32_t attrs_type_index{0};  // 屬性的類型索引
  int32_t num_inputs = -1;  // 算子輸入?yún)?shù)個(gè)數(shù)歉嗓,-1表示可變長
  int32_t support_level = 10; // 算子的支持等級,值越低優(yōu)先級越高背蟆。void VisitAttrs(AttrVisitor* v) {
    v->Visit("name", &name);
    v->Visit("op_type", &op_type);
    v->Visit("description", &description);
    v->Visit("arguments", &arguments);
    v->Visit("attrs_type_key", &attrs_type_key);
    v->Visit("num_inputs", &num_inputs);
    v->Visit("support_level", &support_level);
  }
  ...
  static constexpr const char* _type_key = "Op";
  TVM_DECLARE_FINAL_OBJECT_INFO(OpNode, RelayExprNode);
private:
   ...
};

所以在鏈?zhǔn)秸{(diào)用的函數(shù)中鉴分,只做了簡單賦值的函數(shù)是:

  • describe()就是給OpNode->description賦值哮幢;

  • set_num_inputs()是設(shè)置輸入?yún)?shù)個(gè)數(shù);

  • set_support_level()是設(shè)置支持等級志珍;

  • add_argument()為arguments數(shù)組添加元素橙垢;

  • set_attr<>()會調(diào)用class AttrRegistry中的UpdateAttr()方法進(jìn)行屬性更新。

其中伦糯,對于add_argument()柜某,因?yàn)門VM將每個(gè)算子的參數(shù)都用AttrFieldInfo描述,而AttrFieldInfo實(shí)際的內(nèi)容是AttrFieldInfoNode:

class AttrFieldInfoNode : public Object {
 public:
  String name; // 字段名稱
  String type_info;  // 類型說明
  String description; // 詳細(xì)描述

  void VisitAttrs(AttrVisitor* v) {
    v->Visit("name", &name);
    v->Visit("type_info", &type_info);
    v->Visit("description", &description);
  }

  static constexpr const char* _type_key = "AttrFieldInfo";
  static constexpr bool _type_has_method_sequal_reduce = false;
  static constexpr bool _type_has_method_shash_reduce = false;
  TVM_DECLARE_FINAL_OBJECT_INFO(AttrFieldInfoNode, Object);
};

所以add_argument()在賦值前敛纲,會創(chuàng)建一個(gè)AttrFieldInfoNode對象再把它放入到arguments數(shù)組中:

inline OpRegEntry& OpRegEntry::add_argument(const std::string& name, const std::string& type,
                                            const std::string& description) {
  auto n = make_object<AttrFieldInfoNode>();
  n->name = name;
  n->type_info = type;
  n->description = description;
  get()->arguments.push_back(AttrFieldInfo(n));
  return *this;
}

需要詳細(xì)說明的是.add_type_rel() 和.set_attr()喂击。

Q:add_type_rel()流程是怎么樣的?
A:在函數(shù)中會創(chuàng)建輸入與輸出的TypeVarNode淤翔,然后創(chuàng)建TypeRelationNode將類型關(guān)系函數(shù)管理起來惭等,并定義一個(gè)FuncTypeNode將這些定義好的對象作為輸入,最終賦值給op_type办铡。

inline OpRegEntry& OpRegEntry::add_type_rel(
    const std::string& rel_name,
    runtime::TypedPackedFunc<bool(const Array<Type>&, int, const Attrs&, const TypeReporter&)>
        type_rel_func) {
  auto func_name = std::string("tvm.relay.type_relation.") + rel_name;
  TypeRelationFn env_type_rel_func;

  if (runtime::Registry::Get(func_name)) {
    auto env_func = EnvFunc::Get(func_name);
    env_type_rel_func = env_func;
  } else {
    runtime::Registry::Register(func_name).set_body(type_rel_func.packed()); // 創(chuàng)建registy注冊type_rel_func
    auto env_func = EnvFunc::Get(func_name);
    env_type_rel_func = env_func; // 這個(gè)跟第二小節(jié)定義的類型關(guān)系函數(shù)相關(guān)聯(lián)
  }

  Array<TypeVar> type_params;  // TypeVar是Type的子類
  Array<Type> arg_types;
  // Add inputs.
  std::string input_name_prefix = "in";
  for (int i = 0; i < get()->num_inputs; i++) {
    auto name = input_name_prefix + std::to_string(i);
    auto param = TypeVar(name, TypeKind::kType);  // 創(chuàng)建一個(gè)TypeVarNode對象
    type_params.push_back(param);
    arg_types.push_back(param);
  }
  Array<Type> ty_call_args = arg_types;

  // Add output type.
  auto out_param = TypeVar("out", TypeKind::kType);
  type_params.push_back(out_param);
  ty_call_args.push_back(out_param);
  TypeConstraint type_rel =
  TypeRelation(env_type_rel_func, ty_call_args, arg_types.size(), Attrs());// 創(chuàng)建TypeRelationNode
  auto func_type = FuncType(arg_types, out_param, type_params, {type_rel}); // 創(chuàng)建FuncTypeNode
  get()->op_type = func_type;  // 對op_type成員賦值  
  return *this;
}

TypeRelation()會創(chuàng)建一個(gè)TypeRelationNode,它實(shí)際上保存了之前定義的類型關(guān)系函數(shù)的相關(guān)信息:

class TypeRelationNode : public TypeConstraintNode {
public:
  TypeRelationFn func;  
  Array<Type> args; // The type arguments to the type function.
  int num_inputs; // Number of inputs arguments
  Attrs attrs; // Attributes to the relation function
  void VisitAttrs(AttrVisitor* v) {
    v->Visit("func", &func);
    v->Visit("args", &args);
    v->Visit("num_inputs", &num_inputs);
    v->Visit("attrs", &attrs);
    v->Visit("span", &span);
  }
  ...
}

FuncType()創(chuàng)建FuncTypeNode琳要,將定義的輸入寡具、輸出、參數(shù)類型和類型關(guān)系節(jié)點(diǎn)作為輸入:

class FuncTypeNode : public TypeNode {
 public:
  Array<Type> arg_types; // type type of arguments
  Type ret_type; // The type of return value
  Array<TypeVar> type_params; // The type parameters of the function
  Array<TypeConstraint> type_constraints; // potential constraint the type need to obey
  void VisitAttrs(AttrVisitor* v) {
    v->Visit("arg_types", &arg_types);
    v->Visit("ret_type", &ret_type);
    v->Visit("type_params", &type_params);
    v->Visit("type_constraints", &type_constraints);
    v->Visit("span", &span);
  }
  ...
}
4稚补、算子compute實(shí)現(xiàn)

有兩種方式實(shí)現(xiàn)算子計(jì)算過程:

(1)python端實(shí)現(xiàn)計(jì)算

在python/tvm/topi/transform.py添加:

def axis_abs(x, axis, indice):
    """Take absolute value of the input of axis in x, element-wise.

    Parameters
    ----------
    x : tvm.te.Tensor
        Input argument.
    axis: int
        Input argument.
    indice: int
        Input argument.
    Returns
    -------
    y : tvm.te.Tensor
        The result.
    """
    ishape = x.shape
    assert len(ishape) == 3
    assert indice < get_const_int(ishape[axis])
    assert indice >= 0
    if axis == 0:
        return te.compute(x.shape, lambda i,j,k: te.if_then_else(x[i,j,k] >= 0, x[i,j,k],
                            te.if_then_else(i == indice, -x[i,j,k], x[i,j,k])))
    elif axis == 1:
        return te.compute(x.shape, lambda i, j, k: te.if_then_else(x[i, j, k] >= 0, x[i, j, k],
                            te.if_then_else(j == indice, -x[i, j, k], x[i, j, k])))
    else:
        return te.compute(x.shape, lambda i, j, k: te.if_then_else(x[i, j, k] >= 0, x[i, j, k],
                            te.if_then_else(k == indice, -x[i, j, k], x[i, j, k])))

并且在python/tvm/relay/op/_transform.py中設(shè)置算子計(jì)算函數(shù)屬性:

@_reg.register_compute("axis_abs")  # 設(shè)置算子的計(jì)算函數(shù)屬性童叠,默認(rèn)的level為10
def compute_axis_abs(attrs, inputs, output_type):
    """Compute definition of axis_abs"""
    return topi.axis_abs(inputs[0], attrs.axis, attrs.indice)

(2)C++端實(shí)現(xiàn)計(jì)算

在src/relay/op/tensor/transform.cc添加算子計(jì)算函數(shù):

Array<te::Tensor> AxisAbsCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                                    const Type& out_type) {
    // TODO
}

并且調(diào)用RELAY_REGISTER_OP("axis_abs")注冊算子時(shí)需要設(shè)置它的計(jì)算函數(shù)屬性:

.set_attr<FTVMCompute>("FTVMCompute", AxisAbsCompute)

此時(shí)在python/tvm/topi/transform.py中的算子實(shí)現(xiàn)可以直接調(diào)用cpp的代碼:

def axis_abs(x, axis, indice):
    """Take absolute value of the input of axis in x, element-wise.

    Parameters
    ----------
    x : tvm.te.Tensor
        Input argument.
    axis: int
        Input argument.
    indice: int
        Input argument.
    Returns
    -------
    y : tvm.te.Tensor
        The result.
    """
    return cpp.axis_abs(x, axis, indice)
5、注冊算子的compute课幕、schedule

在實(shí)現(xiàn)了算子compute邏輯以后厦坛,需要與我們實(shí)現(xiàn)的算子接口綁定在一起。在TVM中乍惊,這就需要不僅實(shí)現(xiàn)算子的compute接口杜秸,還要實(shí)現(xiàn)對應(yīng)的schedule。而strategy就是對compute選擇合適的schedule润绎。需要在python/tvm/relay/op/strategy/generic.py添加算子的計(jì)算策略:

def wrap_compute_axis_abs(topi_compute):
    """Wrap axis_abs topi compute"""

    def _compute_axis_abs(attrs, inputs, _):
        return [topi_compute(inputs[0], attrs.axis, attrs.indice)]

    return _compute_axis_abs

@override_native_generic_func("axis_abs_strategy")
def axis_abs_strategy(attrs, inputs, out_type, target):
    """axis_abs generic strategy"""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_axis_abs(topi.axis_abs),
        wrap_topi_schedule(topi.generic.schedule_injective),
        name="axix_abs.generic",
    )
    return strategy

在python/tvm/relay/op/_transform.py中將算子與計(jì)算策略關(guān)聯(lián)起來:

_reg.register_strategy("axis_abs", strategy.axis_abs_strategy)
6撬碟、為算子生成調(diào)用節(jié)點(diǎn)并注冊 API hook

現(xiàn)在有一個(gè)可以調(diào)用的relay算子,下一步就是如何通過relay call node調(diào)用莉撇。這就需要實(shí)現(xiàn)一個(gè)函數(shù)呢蛤,傳遞相應(yīng)的參數(shù)給對應(yīng)的relay算子,并且返回對應(yīng)算子的Call Node(這個(gè)算子最終在Relay表達(dá)式的AST里面)棍郎。需要在src\relay\op\tensor\transform.cc添加:

Expr MakeAxisAbs(Expr data, int axis, int indice) {
    auto attrs = make_object<AxisAbsAttrs>();
    attrs->axis = axis;
    attrs->indice = indice;
    static const Op& op = Op::Get("axis_abs");
    return Call(op, {data}, Attrs(attrs), {}); // 會創(chuàng)建一個(gè)CallNode實(shí)例
}

TVM_REGISTER_GLOBAL("relay.op._make.axis_abs").set_body_typed(MakeAxisAbs);

Q:Call Node是什么其障?
A:CallNode類是ExprNode的子類,它在程序調(diào)用Call函數(shù)時(shí)被實(shí)例化:

class CallNode : public ExprNode {
 protected:
  Object::FDeleter saved_deleter_;
  static void Deleter_(Object* ptr);
 public:
  Expr op; // 算子的計(jì)算表達(dá)函數(shù)
  tvm::Array<relay::Expr> args;  // call函數(shù)的輸入?yún)?shù)
  Attrs attrs; // 屬性
  tvm::Array<Type> type_args;  // 傳遞給多態(tài)(模板)函數(shù)的類型參數(shù)
  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("op", &op);
    v->Visit("args", &args);
    v->Visit("attrs", &attrs);
    v->Visit("type_args", &type_args);
    v->Visit("span", &span);
    v->Visit("_checked_type_", &checked_type_);
  }
  bool SEqualReduce(const CallNode* other, SEqualReducer equal) const {...}
  void SHashReduce(SHashReducer hash_reduce) const {...}
  static constexpr const char* _type_key = "relay.Call";
  TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, ExprNode);
  template <typename>
  friend class runtime::ObjAllocatorBase;
  friend class Call;
};

Call::Call(Expr op, Array<Expr> args, Attrs attrs, Array<Type> type_args, Span span) {
  ObjectPtr<CallNode> n = make_object<CallNode>();  // 創(chuàng)建CallNode
  n->op = std::move(op);
  n->args = std::move(args);
  n->attrs = std::move(attrs);
  n->type_args = std::move(type_args);
  n->span = std::move(span); 
  data_ = std::move(n);
}
7涂佃、將Python API hook 封裝成簡潔的調(diào)用方式

為更方便的使用励翼,通常的做法是構(gòu)造單獨(dú)的函數(shù)蜈敢,因此最好封裝成更簡潔的python接口,需要在python/tvm/relay/op/transform.py中添加:

def axis_abs(data, axis=0, indice=0):
    """Computes abs of data along a certain axis indice.

    Parameters
    ----------
    data : relay.Expr
        The source data to be invert permuated.

    Returns
    -------
    ret : relay.Expr
        Invert permuated data. Has the same type as data.
    """
    return _make.axis_abs(data, axis, indice)
8抚笔、為新的relay 算子編寫測試用例

需要在tvm/tests/python/test_op_level3.py添加:

class TestAxisAbs:
    dshape, axis, indice = tvm.testing.parameters(     # 定義測試用例參數(shù)扶认,這里是輸入tensor的shape,axis和indice
        ((4, 4, 1), 1, 1),
        ((4, 4, 1), 0, 1),
        ((3, 3, 3), 1, 1),
    )

    def test_axis_abs(self, dshape, axis, indice):
        x = relay.var("x", relay.TensorType(dshape, "int32"))  # 定義relay輸入tensor
        y = relay.axis_abs(x, axis=axis, indice=indice)    # 定義axis_abs運(yùn)算表達(dá)式
        yy = run_infer_type(y)      # 推理運(yùn)算表達(dá)式的類型殊橙,定義在python/tvm/relay/testing/__init__.py
        assert yy.checked_type == relay.TensorType(dshape, "int32")  # 類型測試

        data = np.random.randint(-5, 5, size=dshape).astype("int32")
        op_res = create_executor().evaluate(y, {x: relay.const(data)})  # 創(chuàng)建執(zhí)行器并執(zhí)行算子推理
        if axis == 0:
            data[indice,:,:] = np.abs(data[indice,:,:])
        elif axis == 1:
            data[:,indice, :] = np.abs(data[:,indice,:])
        else:
            data[:,:,indice] = np.abs(data[:,:,indice])
        ref_res = data
        np.testing.assert_equal(op_res.numpy(), ref_res)  # 對比numpy結(jié)果與relay的計(jì)算結(jié)果

如果沒有安裝pytest辐宾,要先安裝pytest,再運(yùn)行測試用例:

pip install pytest && cd tvm/tests/python && pytest relay/test_op_level3.py::TestAxisAbs

用例通過的結(jié)果如下:

image.png

三膨蛮、總結(jié)

本文根據(jù)TVM官方文檔給出的步驟添加了一個(gè)自定義算子叠纹,并對過程中有疑問的地方做了一些說明。

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末敞葛,一起剝皮案震驚了整個(gè)濱河市誉察,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌惹谐,老刑警劉巖持偏,帶你破解...
    沈念sama閱讀 217,542評論 6 504
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場離奇詭異氨肌,居然都是意外死亡鸿秆,警方通過查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,822評論 3 394
  • 文/潘曉璐 我一進(jìn)店門怎囚,熙熙樓的掌柜王于貴愁眉苦臉地迎上來卿叽,“玉大人,你說我怎么就攤上這事恳守】加ぃ” “怎么了?”我有些...
    開封第一講書人閱讀 163,912評論 0 354
  • 文/不壞的土叔 我叫張陵催烘,是天一觀的道長沥阱。 經(jīng)常有香客問我,道長伊群,這世上最難降的妖魔是什么喳钟? 我笑而不...
    開封第一講書人閱讀 58,449評論 1 293
  • 正文 為了忘掉前任,我火速辦了婚禮在岂,結(jié)果婚禮上奔则,老公的妹妹穿的比我還像新娘。我一直安慰自己蔽午,他們只是感情好易茬,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,500評論 6 392
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著,像睡著了一般抽莱。 火紅的嫁衣襯著肌膚如雪范抓。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 51,370評論 1 302
  • 那天食铐,我揣著相機(jī)與錄音匕垫,去河邊找鬼。 笑死虐呻,一個(gè)胖子當(dāng)著我的面吹牛象泵,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播斟叼,決...
    沈念sama閱讀 40,193評論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼偶惠,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了朗涩?” 一聲冷哼從身側(cè)響起忽孽,我...
    開封第一講書人閱讀 39,074評論 0 276
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎谢床,沒想到半個(gè)月后兄一,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 45,505評論 1 314
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡识腿,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,722評論 3 335
  • 正文 我和宋清朗相戀三年瘾腰,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片覆履。...
    茶點(diǎn)故事閱讀 39,841評論 1 348
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡,死狀恐怖费薄,靈堂內(nèi)的尸體忽然破棺而出硝全,到底是詐尸還是另有隱情,我是刑警寧澤楞抡,帶...
    沈念sama閱讀 35,569評論 5 345
  • 正文 年R本政府宣布伟众,位于F島的核電站,受9級特大地震影響召廷,放射性物質(zhì)發(fā)生泄漏凳厢。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,168評論 3 328
  • 文/蒙蒙 一竞慢、第九天 我趴在偏房一處隱蔽的房頂上張望先紫。 院中可真熱鬧,春花似錦筹煮、人聲如沸遮精。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,783評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽本冲。三九已至准脂,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間檬洞,已是汗流浹背狸膏。 一陣腳步聲響...
    開封第一講書人閱讀 32,918評論 1 269
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留添怔,地道東北人湾戳。 一個(gè)月前我還...
    沈念sama閱讀 47,962評論 2 370
  • 正文 我出身青樓,卻偏偏與公主長得像澎灸,于是被迫代替她去往敵國和親院塞。 傳聞我的和親對象是個(gè)殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 44,781評論 2 354

推薦閱讀更多精彩內(nèi)容