實(shí)現(xiàn)TensorRT-7.0插件自由限书!(如果不踩坑使用TensorRT插件功能)

本系列為新TensorRT的第一篇,為什么叫新章咧,因?yàn)橹耙呀?jīng)寫了兩篇關(guān)于TensorRT的文章倦西,是關(guān)于TensorRT-5.0版本的。好久沒(méi)寫關(guān)于TensorRT的文章了赁严,所幸就以來(lái)開(kāi)頭吧~

接下來(lái)將要講解的TensorRT扰柠,將會(huì)是基于7.0版本粉铐。

7版本開(kāi)頭的TensorRT變化還是挺大的,增加了很多新特性耻矮,但是TensorRT的核心運(yùn)作方式還是沒(méi)有什么變化的秦躯,關(guān)于TensorRT的介紹可以看之前寫的這兩篇:

本文的內(nèi)容呢,主要是講解:

  • TensorRT自定義插件的使用方式
  • 如何添加自己的自定義算子

看完本篇可以讓你少踩巨多坑裆装,客官記得常來(lái)看啊踱承。

前言

隨著tensorRT的不斷發(fā)展(v5->v6->v7),TensorRT的插件的使用方式也在不斷更新哨免。插件接口也在不斷地變化茎活,由v5版本的IPluginV2Ext,到v6版本的IPluginV2IOExtIPluginV2DynamicExt琢唾。未來(lái)不知道會(huì)不會(huì)出來(lái)新的API载荔,不過(guò)這也不是咱要考慮的問(wèn)題,因?yàn)門ensorRT的后兼容性做的很好采桃,根本不用擔(dān)心你寫的舊版本插件在新版本上無(wú)法運(yùn)行懒熙。

目前的plugin-API:

QQ20201103-101737

TensorRT插件的存在目的,主要是為了讓我們實(shí)現(xiàn)TensorRT目前還不支持的算子普办,畢竟眾口難調(diào)嘛工扎,我們?cè)谵D(zhuǎn)換過(guò)程中肯定會(huì)有op不支持的情況。這個(gè)時(shí)候就需要使用TensorRT的plugin去實(shí)現(xiàn)我們的自己的op衔蹲。此時(shí)我們需要通過(guò)TensorRT提供的接口去實(shí)現(xiàn)自己的op肢娘,因此這個(gè)plugin的生命周期也需要遵循TensorRT的規(guī)則

一個(gè)簡(jiǎn)單的了解

那么plugin到底長(zhǎng)啥樣舆驶,可以先看看TensorRT的官方plugin庫(kù)長(zhǎng)啥樣橱健,截止寫這篇文章時(shí),master分支是7.2版本的plugin:

https://github.com/NVIDIA/TensorRT/tree/master/plugin

tensorrt-plugin

官方提供的插件已經(jīng)相當(dāng)多沙廉,而且TensorRT開(kāi)源了plugin部分(可以讓我們白嫖拘荡!)。并且可以看到其源碼撬陵,通過(guò)模仿源碼來(lái)學(xué)習(xí)plugin是如何寫的俱病。

如果要添加自己的算子,可以在官方的plugin庫(kù)里頭進(jìn)行修改添加袱结,然后編譯官方的plugin庫(kù)。將生成的libnvinfer_plugin.so.7替換原本的.so文件即可途凫」讣校或者自己寫一個(gè)類似于官方plugin的組件,將名稱替換一下,同樣生成.so姐刁,在TensorRT的推理項(xiàng)目中引用這個(gè)動(dòng)態(tài)鏈接庫(kù)即可定嗓。

以下介紹中碧信,我們需要寫的IPlugin簡(jiǎn)稱為插件op膝宁。

開(kāi)始寫插件

有興趣的可以先看看TensorRT的官方文檔碍论,官方文檔的介紹簡(jiǎn)單意駭透敌,不過(guò)坑是少不了的..而本文的目的抗果,就是盡量讓你少趟坑倡怎。

首先按照官方plugin的排布方式迅耘,下面隨便挑了個(gè)官方plugin:

instance_normalization_plugin

準(zhǔn)備一個(gè)自己的插件:custom.cppcustom.h,copy并paste官方代碼监署,名字替換成自己的颤专。以最新的IPluginV2DynamicExt類為接口。

我們需要寫兩個(gè)類:

  • MyCustomPlugin钠乏,繼承IPluginV2DynamicExt栖秕,是插件類,用于寫插件具體的實(shí)現(xiàn)
  • MyCustomPluginCreator晓避,繼承BaseCreator簇捍,是插件工廠類,用于根據(jù)需求創(chuàng)建該插件

對(duì)了俏拱,插件類繼承IPluginV2DynamicExt才可以支持動(dòng)態(tài)尺寸暑塑,其他插件類接口例如IPluginV2IOExt和前者大部分是相似的。

// 繼承IPluginV2DynamicExt就夠啦
class MyCustomPlugin final : public nvinfer1::IPluginV2DynamicExt

class MyCustomPluginCreator : public BaseCreator

MyCustomPlugin 插件類

總覽:

class MyCustomPlugin final : public nvinfer1::IPluginV2DynamicExt
{

public:

  MyCustomPlugin( int in_channel,
                  const std::vector<float>& weight,
                  const std::vector<float>& bias);
                            
  MyCustomPlugin( int in_channel,
                  nvinfer1::Weights const& weight,
                  nvinfer1::Weights const& bias);

  MyCustomPlugin(void const* serialData, size_t serialLength);
  MyCustomPlugin() = delete;
  ~MyCustomPlugin() override;
  int getNbOutputs() const override;
  DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) override;
  int initialize() override;
  void terminate() override;
  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const override;
  int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, 
              const void* const* inputs, void* const* outputs, 
              void* workspace, 
              cudaStream_t stream) override;
  size_t getSerializationSize() const override;
  void serialize(void* buffer) const override;
  bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) override;
  const char* getPluginType() const override;
  const char* getPluginVersion() const override;
  void destroy() override;
  nvinfer1::IPluginV2DynamicExt* clone() const override;
  void setPluginNamespace(const char* pluginNamespace) override;
  const char* getPluginNamespace() const override;
  DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const override;
  void attachToContext(cudnnContext* cudnn, cublasContext* cublas, nvinfer1::IGpuAllocator* allocator) override;
  void detachFromContext() override;
  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, 
                       const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) override;
private:
    int _in_channel;
    std::vector<float> weight;
    std::vector<float> bias;
    float* weight;
    float* bias;
    bool _initialized;
    const char* mPluginNamespace;
    std::string mNamespace;
};

成員變量

如果你的插件有weights(類似于conv操作的weight和bias)彰触,有參數(shù)(類似于conv中的kernel-size梯投、padding),在類中則需要定義為成員變量况毅,為private類型:

MyCustomPlugin為例分蓖,假設(shè)我們的這個(gè)MyCustomPlugin有兩個(gè)權(quán)重weight和bias以及一個(gè)參數(shù)in_channel(這個(gè)權(quán)重和參數(shù)沒(méi)有啥意義,純粹尔许,純粹為了演示):

private:
    int  _in_channel; // 參數(shù) 
    std::vector<float> _weight; // 權(quán)重么鹤,在cpu空間存放
    std::vector<float> _bias;   // 偏置權(quán)重,在cpu空間存放
    float* _d_weight;           // 權(quán)重味廊,在GPU空間存放
    float* _d_bias;
    bool _initialized;
    cudnnHandle_t _cudnn_handle;
    const char* mPluginNamespace;
    std::string mNamespace;

構(gòu)造函數(shù)和析構(gòu)函數(shù)

構(gòu)造函數(shù)一般設(shè)置為三個(gè)蒸甜。

第一個(gè)用于在parse階段,PluginCreator用于創(chuàng)建該插件時(shí)調(diào)用的構(gòu)造函數(shù)余佛,需要傳遞權(quán)重信息以及參數(shù)柠新。

第二個(gè)用于在clone階段,復(fù)制這個(gè)plugin時(shí)會(huì)用到的構(gòu)造函數(shù)辉巡。

第三個(gè)用于在deserialize階段恨憎,用于將序列化好的權(quán)重和參數(shù)傳入該plugin并創(chuàng)建愛(ài)你哦。

以我們的MyCustomPlugin為例:

MyCustomPlugin(int in_channel, nvinfer1::Weights const& weight, nvinfer1::Weights const& bias);
MyCustomPlugin(float in_channel, const std::vector<float>& weight, const std::vector<float>& bias);
MyCustomPlugin(void const* serialData, size_t serialLength);

析構(gòu)函數(shù)則需要執(zhí)行terminateterminate函數(shù)就是釋放這個(gè)op之前開(kāi)辟的一些顯存空間:

MyCustomPlugin::~MyCustomPlugin()
{
    terminate();
}

注意需要把默認(rèn)構(gòu)造函數(shù)刪掉:

MyCustomPlugin() = delete;

getNbOutputs

插件op返回多少個(gè)Tensor憔恳,比如MyCustomPlugin這個(gè)操作只輸出一個(gè)Tensor(也就是一個(gè)output)瓤荔,所以直接return 1

// MyCustomPlugin returns one output.
int MyCustomPlugin::getNbOutputs() const
{
    return 1;
}

initialize

初始化函數(shù),在這個(gè)插件準(zhǔn)備開(kāi)始run之前執(zhí)行钥组。

主要初始化一些提前開(kāi)辟空間的參數(shù)输硝,一般是一些cuda操作需要的參數(shù)(例如conv操作需要執(zhí)行卷積操作,我們就需要提前開(kāi)辟weight和bias的顯存)程梦,假如我們的算子需要這些參數(shù)点把,則在這里需要提前開(kāi)辟顯存。

需要注意的是作烟,如果插件算子需要開(kāi)辟比較大的顯存空間愉粤,不建議自己去申請(qǐng)顯存空間,可以使用Tensorrt官方接口傳過(guò)來(lái)的workspace指針來(lái)獲取顯存空間拿撩。因?yàn)槿绻@個(gè)插件被一個(gè)網(wǎng)絡(luò)調(diào)用了很多次衣厘,而這個(gè)插件op需要開(kāi)辟很多顯存空間,那么TensorRT在構(gòu)建network的時(shí)候會(huì)根據(jù)這個(gè)插件被調(diào)用的次數(shù)開(kāi)辟很多顯存压恒,很容易導(dǎo)致顯存溢出影暴。

getOutputDataType

返回結(jié)果的類型,一般來(lái)說(shuō)我們插件op返回結(jié)果類型與輸入類型一致:

nvinfer1::DataType InstanceNormalizationPlugin::getOutputDataType(
    int index, const nvinfer1::DataType* inputTypes, int nbInputs) const
{
    ASSERT(inputTypes && nbInputs > 0 && index == 0);
    return inputTypes[0];
}

getWorkspaceSize

這個(gè)函數(shù)需要返回這個(gè)插件op需要中間顯存變量的實(shí)際數(shù)據(jù)大小(bytesize)探赫,這個(gè)是通過(guò)TensorRT的接口去獲取型宙,是比較規(guī)范的方式。

我們需要在這里確定這個(gè)op需要多大的顯存空間去運(yùn)行伦吠,在實(shí)際運(yùn)行的時(shí)候就可以直接使用TensorRT開(kāi)辟好的空間而不是自己去申請(qǐng)顯存空間妆兑。

size_t MyCustomPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const 
{ 
    // 計(jì)算這個(gè)op前向過(guò)程中你認(rèn)為需要的中間顯存數(shù)量
    size_t need_num;
    return need_num * sizeof(float);
}

enqueue

實(shí)際插件op的執(zhí)行函數(shù),我們自己實(shí)現(xiàn)的cuda操作就放到這里(當(dāng)然C++寫的op也可以放進(jìn)來(lái)毛仪,不過(guò)因?yàn)槭荂PU執(zhí)行搁嗓,速度就比較慢了),與往常一樣接受輸入inputs產(chǎn)生輸出outputs箱靴,傳給相應(yīng)的指針就可以腺逛。

int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc,
        const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream){

            // 假如這個(gè)fun是你需要的中間變量 這里可以直接用TensorRT為你開(kāi)辟的顯存空間
            fun  = static_cast<float*>(workspace);

        }

需要注意的是,如果我們的操作需要一些分布在顯存中的中間變量衡怀,可以通過(guò)傳過(guò)來(lái)的指針參數(shù)workspace獲取棍矛,上述代碼簡(jiǎn)單說(shuō)明了一下使用方法。

再多說(shuō)一句抛杨,我們默認(rèn)寫的.cu是fp32的够委,TensorRT在fp16運(yùn)行模式下,運(yùn)行到不支持fp16的插件op時(shí)怖现,會(huì)自動(dòng)切換到fp32模式慨绳,等插件op運(yùn)行完再切換回來(lái)。

getOutputDimensions

TensorRT支持Dynamic-shape的時(shí)候,batch這一維度必須是explicit的脐雪,也就是說(shuō),TensorRT處理的維度從以往的三維[3,-1,-1]變成了[1,3,-1,-1]恢共。最新的onnx-tensorrt也必須設(shè)置explicit的batchsize战秋,而且這個(gè)batch維度在getOutputDimensions中是可以獲取到的。

在舊版的IPluginV2類中讨韭,getOutputDimensions的定義如下:

  virtual Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) TRTNOEXCEPT = 0;

而在新版的IPluginV2DynamicExt類中定義如下:

virtual DimsExprs getOutputDimensions(int outputIndex, const DimsExprs* inputs, int nbInputs, IExprBuilder& exprBuilder) = 0;

我們要做的就是在這個(gè)成員函數(shù)中根據(jù)輸入維度推理出模型的輸出維度脂信,需要注意的是,雖然說(shuō)輸出維度
是由輸入維度決定透硝,但這個(gè)輸出維度其實(shí)“內(nèi)定”的(也就是在計(jì)算之前就算出來(lái)了)狰闪。如果咱的插件op的輸出維度需要通過(guò)實(shí)際運(yùn)行計(jì)算得到,那么這個(gè)函數(shù)就無(wú)法滿足咱了濒生。

去你媽的好氣哦

set/getPluginNamespace

為這個(gè)插件設(shè)置namespace名字埋泵,如果不設(shè)置則默認(rèn)是"",需要注意的是同一個(gè)namespace下的plugin如果名字相同會(huì)沖突罪治。

PluginFieldCollection

這個(gè)是成員變量丽声,也會(huì)作為getFieldNames成員函數(shù)的返回類型。PluginFieldCollection的主要作用是傳遞這個(gè)插件op所需要的權(quán)重和參數(shù)觉义,在實(shí)際的engine推理過(guò)程中并不使用雁社,而在parse中會(huì)用到(例如caffe2trt、onnx2trt)晒骇。

當(dāng)使用這些parse去解析這個(gè)op的時(shí)候霉撵,這個(gè)op的權(quán)重和參數(shù)會(huì)經(jīng)歷Models --> TensorRT engine --> TensorRT runtime這個(gè)過(guò)程。

舉個(gè)例子洪囤,在onnx-tensorrt中徒坡,我們用過(guò)DEFINE_BUILTIN_OP_IMPORTER去注冊(cè)op,然后通過(guò)parse解析onnx模型箍鼓,根據(jù)注冊(cè)好的op去一個(gè)個(gè)解析構(gòu)建模型崭参,假如我們定義的op為my_custom_op,在DEFINE_BUILTIN_OP_IMPORTER(my_custom_op)會(huì)這樣實(shí)現(xiàn):

DEFINE_BUILTIN_OP_IMPORTER(mycustom_op)
{
    ASSERT(inputs.at(0).is_tensor(), ErrorCode::kUNSUPPORTED_NODE); 
    ...
    const std::string pluginName = "CUSTOM-OP";
    const std::string pluginVersion = "001";
    // 這個(gè)f保存這個(gè)op需要的權(quán)重和參數(shù),從onnx模型中獲取
    std::vector<nvinfer1::PluginField> f;

    f.emplace_back("in_channel", &in_channel, nvinfer1::PluginFieldType::kINT32, 1);
    f.emplace_back("weight", kernel_weights.values, nvinfer1::PluginFieldType::kFLOAT32, kernel_weights.count());
    f.emplace_back("bias", bias_weights.values, nvinfer1::PluginFieldType::kFLOAT32, bias_weights.count);

    // 這個(gè)從將plugin工廠中獲取該插件款咖,并且將權(quán)重和參數(shù)傳遞進(jìn)去
    nvinfer1::IPluginV2* plugin = importPluginFromRegistry(ctx, pluginName, pluginVersion, node.name(), f);

    RETURN_FIRST_OUTPUT(ctx->network()->addPluginV2(tensors.data(), tensors.size(), *plugin));
}

進(jìn)入importPluginFromRegistry函數(shù)內(nèi)部何暮,可以發(fā)現(xiàn)參數(shù)通過(guò)fc變量通過(guò)createPlugin傳遞給了plugin

nvinfer1::IPluginV2* importPluginFromRegistry(IImporterContext* ctx, const std::string& pluginName,
    const std::string& pluginVersion, const std::string& nodeName,
    const std::vector<nvinfer1::PluginField>& pluginFields)
{
    const auto mPluginRegistry = getPluginRegistry();
    const auto pluginCreator
        = mPluginRegistry->getPluginCreator(pluginName.c_str(), pluginVersion.c_str(), "ONNXTRT_NAMESPACE");

    if (!pluginCreator)
    {
        return nullptr;
    }
    // 接受傳進(jìn)來(lái)的權(quán)重和參數(shù)信息 傳遞給plugin
    nvinfer1::PluginFieldCollection fc;
    fc.nbFields = pluginFields.size();
    fc.fields = pluginFields.data();

    return pluginCreator->createPlugin(nodeName.c_str(), &fc);
}

上述步驟中,會(huì)提供pluginNamepluginVersion初始化MyCustomPluginCreator海洼,其中createPlugin成員函數(shù)是我們需要編寫的(下文會(huì)說(shuō))富腊。

configurePlugin

配置這個(gè)插件op,判斷輸入和輸出類型數(shù)量是否正確。官方還提到通過(guò)這個(gè)配置信息可以告知TensorRT去選擇合適的算法(algorithm)去調(diào)優(yōu)這個(gè)模型是整。

但自動(dòng)調(diào)優(yōu)目前還沒(méi)有嘗試過(guò),我們一般自己寫的plugin執(zhí)行代碼都是定死的浮入,所謂的調(diào)優(yōu)步驟可能更多地針對(duì)官方的op。

下面的plugin中configurePlugin函數(shù)僅僅是簡(jiǎn)單地確認(rèn)了下輸入和輸出以及類型事秀。

void MyCustomPluginDynamic::configurePlugin(
    const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs,
    const nvinfer1::DynamicPluginTensorDesc *outputs, int nbOutputs) {
  // Validate input arguments
  assert(nbOutputs == 1);
  assert(nbInputs == 2);
  assert(mType == inputs[0].desc.type);
}

clone

這玩意兒干嘛的彤断,顧名思義易迹,就是克隆嘛,將這個(gè)plugin對(duì)象克隆一份給TensorRT的builder供炼、network或者engine句伶。這個(gè)成員函數(shù)會(huì)調(diào)用上述說(shuō)到的第二個(gè)構(gòu)造函數(shù):

MyCustomPlugin(float in_channel, const std::vector<float>& weight, const std::vector<float>& bias);

將要克隆的plugin的權(quán)重和參數(shù)傳遞給這個(gè)構(gòu)造函數(shù)。

IPluginV2DynamicExt* MyCustomPlugin::clone() const
{ 
    // 
    auto plugin = new MyCustomPlugin{_in_channel, _weight, _bias};
    plugin->setPluginNamespace(mPluginNamespace);
    return plugin;
}

clone成員函數(shù)主要用于傳遞不變的權(quán)重和參數(shù)先嬉,將plugin復(fù)制n多份楚堤,從而可以被不同engine或者builder或者network使用身冬。

getSerializationSize

返回序列化時(shí)需要寫多少字節(jié)到buffer中。

size_t MyCustomPlugin::getSerializationSize() const
{
    return (serialized_size(_in_channel) +
            serialized_size(_weight) +
            serialized_size(_bias)
            );
}

supportsFormatCombination

TensorRT調(diào)用此方法以判斷pos索引的輸入/輸出是否支持inOut[pos].formatinOut[pos].type指定的格式/數(shù)據(jù)類型滚躯。

如果插件支持inOut[pos]處的格式/數(shù)據(jù)類型嘿歌,則返回true宙帝。 如果是否支持取決于其他的輸入/輸出格式/數(shù)據(jù)類型,則插件可以使其結(jié)果取決于inOut[0..pos-1]中的格式/數(shù)據(jù)類型愿待,該格式/數(shù)據(jù)類型將設(shè)置為插件支持的值。 這個(gè)函數(shù)不需要檢查inOut[pos + 1..nbInputs + nbOutputs-1]要出,pos的決定必須僅基于inOut[0..pos]农渊。

bool MyCustomPlugin::supportsFormatCombination(
    int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs)
{
    // 假設(shè)有一個(gè)輸入一個(gè)輸出
    assert(0 <= pos && pos < 2);
    const auto *in = inOut;
    const auto *out = inOut + nbInputs;
    switch (pos) {
        case 0:
        return in[0].type == DataType::kFLOAT &&
                in[0].format == nvinfer1::TensorFormat::kLINEAR;
        case 1:
        return out[0].type == in[0].type &&
                out[0].format == nvinfer1::TensorFormat::kLINEAR;
    }
}

serialize

把需要用的數(shù)據(jù)按照順序序列化到buffer里頭。

void MyCustomPlugin::serialize(void *buffer) const
{
    serialize_value(&buffer, _in_channel);
    serialize_value(&buffer, _weight);
    serialize_value(&buffer, _bias);
}

attachToContext

如果這個(gè)op使用到了一些其他東西,例如cublas handle饭宾,可以直接借助TensorRT內(nèi)部提供的cublas handle:

void MyCustomPlugin::attachToContext(cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator)
{
     mCublas = cublasContext;
}

MyCustomPluginCreator 插件工廠類

總覽:

class MyCustomPluginCreator : public BaseCreator
{
public:
  MyCustomPluginCreator();
  ~MyCustomPluginCreator() override = default;
  const char* getPluginName() const override;    // 不介紹
  const char* getPluginVersion() const override; // 不介紹
  const PluginFieldCollection* getFieldNames() override; // 不介紹
  IPluginV2DynamicExt* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) override;
  IPluginV2DynamicExt* deserializePlugin(const char* name, const void* serialData, size_t serialLength) override;
private:
  static PluginFieldCollection mFC;
  static std::vector<PluginField> mPluginAttributes;
  std::string mNamespace;
};

構(gòu)造函數(shù)

創(chuàng)建一個(gè)空的mPluginAttributes初始化mFC看铆。

MyCustomPluginCreator::MyCustomPluginCreator()
{
    mPluginAttributes.emplace_back(PluginField("in_channel", nullptr, PluginFieldType::kFLOAT32, 1));
    mPluginAttributes.emplace_back(PluginField("weight", nullptr, PluginFieldType::kFLOAT32, 1));
    mPluginAttributes.emplace_back(PluginField("bias", nullptr, PluginFieldType::kFLOAT32, 1));
    
    mFC.nbFields = mPluginAttributes.size();
    mFC.fields = mPluginAttributes.data();
}

createPlugin

這個(gè)成員函數(shù)作用是通過(guò)PluginFieldCollection去創(chuàng)建plugin弹惦,將op需要的權(quán)重和參數(shù)一個(gè)一個(gè)取出來(lái),然后調(diào)用上文提到的第一個(gè)構(gòu)造函數(shù):

MyCustomPlugin(int in_channel, nvinfer1::Weights const& weight, nvinfer1::Weights const& bias);

去創(chuàng)建plugin棠隐。

MyCustomPlugin示例:

IPluginV2DynamicExt* MyCustomPlugin::createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc)
{
    int in_channel;
    std::vector<float> weight;
    std::vector<float> bias;
    const PluginField* fields = fc->fields;
    for (int i = 0; i < fc->nbFields; ++i)
    {
        const char* attrName = fields[i].name;
        if (!strcmp(attrName, "in_channel"))
        {
            ASSERT(fields[i].type == PluginFieldType::kINT32);
            in_channel= *(static_cast<const int32_t*>(fields[i].data));
        }
        else if (!strcmp(attrName, "weight"))
        {
            ASSERT(fields[i].type == PluginFieldType::kFLOAT32);
            int size = fields[i].length;
            h_weight.reserve(size);
            const auto* w = static_cast<const float*>(fields[i].data);
            for (int j = 0; j < size; j++)
            {
                h_weight.push_back(*w);
                w++;
            }
        }
        else if (!strcmp(attrName, "bias"))
        {
            ASSERT(fields[i].type == PluginFieldType::kFLOAT32);
            int size = fields[i].length;
            h_bias.reserve(size);
            const auto* w = static_cast<const float*>(fields[i].data);
            for (int j = 0; j < size; j++)
            {
                h_bias.push_back(*w);
                w++;
            }
        }
    }

    Weights weightWeights{DataType::kFLOAT, weight.data(), (int64_t) weight.size()};
    Weights biasWeights{DataType::kFLOAT, bias.data(), (int64_t)_bias.size()};

    MyCustomPlugin* obj = new MyCustomPlugin(in_channel, weightWeights, biasWeights);
    obj->setPluginNamespace(mNamespace.c_str());
    return obj;
}

deserializePlugin

這個(gè)函數(shù)會(huì)被onnx-tensorrt的一個(gè)叫做TRT_PluginV2的轉(zhuǎn)換op調(diào)用啰扛,這個(gè)op會(huì)讀取onnx模型的data數(shù)據(jù)將其反序列化到network中嗡贺。

一些官方插件的注意事項(xiàng)

使用官方插件會(huì)遇到些小問(wèn)題。

topk問(wèn)題

官方的topk插件最多支持k<=3840煞茫。否則會(huì)報(bào):

[TensorRT] ERROR: Parameter check failed at: ../builder/Layers.cpp::TopKLayer::3137, condition: k > 0 && k <= MAX_TOPK_K

相關(guān)問(wèn)題:https://github.com/tensorflow/tensorflow/issues/31671

batchednms問(wèn)題

官方的batchednms最大支持的topk為4096续徽,太大也會(huì)崩潰架谎。不過(guò)可以修改源代碼實(shí)現(xiàn)突破這個(gè)數(shù)值,但仍然有bug

  void (*kernel[])(const int, const int, const int, const int, const float,
                     const bool, const bool, float *, T_SCORE *, int *,
                     T_SCORE *, int *, bool) = {
      P(1), P(2), P(3), P(4), P(5), P(6), P(7), P(8), P(9), P(10),
      P(11), P(12), P(13), P(14), P(15), P(16)
  };

關(guān)于plugin的注冊(cè)

簡(jiǎn)單說(shuō)下plugin的注冊(cè)流程土全。

在加載NvInferRuntimeCommon.h頭文件的時(shí)候會(huì)得到一個(gè)getPluginRegistry,這里類中包含了所有已經(jīng)注冊(cè)了的IPluginCreator瑞凑,在使用的時(shí)候我們通過(guò)getPluginCreator函數(shù)得到相應(yīng)的IPluginCreator概页。

注冊(cè)插件有兩種方式,第一種可以看官方的plugin代碼惰匙。

extern "C" {
bool initLibNvInferPlugins(void* logger, const char* libNamespace)
{
    initializePlugin<nvinfer1::plugin::GridAnchorPluginCreator>(logger, libNamespace);
    initializePlugin<nvinfer1::plugin::NMSPluginCreator>(logger, libNamespace);
    initializePlugin<nvinfer1::plugin::ReorgPluginCreator>(logger, libNamespace);
    ...
    return true;
}

其中initializePlugin函數(shù)執(zhí)行了addPluginCreator函數(shù):

template <typename CreatorType>
void initializePlugin(void* logger, const char* libNamespace)
{
    PluginCreatorRegistry::getInstance().addPluginCreator<CreatorType>(logger, libNamespace);
}

addPluginCreator函數(shù)又執(zhí)行了getPluginRegistry()->registerCreator對(duì)pluginCreator進(jìn)行了注冊(cè)项鬼,這樣就完成注冊(cè)任務(wù)了:

void addPluginCreator(void* logger, const char* libNamespace)
{
    ...
        if (mRegistryList.find(pluginType) == mRegistryList.end())
        {
            bool status = getPluginRegistry()->registerCreator(*pluginCreator, libNamespace);
            if (status)
            {
                mRegistry.push(std::move(pluginCreator));
                mRegistryList.insert(pluginType);
                verboseMsg = "Plugin creator registration succeeded - " + pluginType;
            }
            else
            {
                errorMsg = "Could not register plugin creator:  " + pluginType;
            }
        }
        else
        {
            verboseMsg = "Plugin creator already registered - " + pluginType;
        }
    ...
}

另一種注冊(cè)可以直接通過(guò)REGISTER_TENSORRT_PLUGIN來(lái)注冊(cè):

//!
//! \brief Return the plugin registry
//!
//  在加載`NvInferRuntimeCommon.h`頭文件的時(shí)候會(huì)得到一個(gè)`getPluginRegistry`
extern "C" TENSORRTAPI nvinfer1::IPluginRegistry* getPluginRegistry();

namespace nvinfer1
{

template <typename T>
class PluginRegistrar
{
public:
    PluginRegistrar() { getPluginRegistry()->registerCreator(instance, ""); }
private:
    T instance{};
};

#define REGISTER_TENSORRT_PLUGIN(name) \
    static nvinfer1::PluginRegistrar<name> pluginRegistrar##name {}

} // namespace nvinfer1

也就是說(shuō)鸠真,如果我們已經(jīng)在plugin的.h文件中執(zhí)行了REGISTER_TENSORRT_PLUGIN(BatchedNMSPluginCreator);就不需要再創(chuàng)建一個(gè)類似于官方的initLibNvInferPlugins()函數(shù)去一個(gè)一個(gè)注冊(cè)了龄毡。

參考鏈接

https://github.com/NVIDIA/TensorRT/tree/release/7.0/plugin
https://github.com/triton-inference-server/server/issues/767
https://blog.csdn.net/u010552731/article/details/106520241
https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#work_dynamic_shapes
https://forums.developer.nvidia.com/t/tensorrt-5-1-6-custom-plugin-with-fp16-issue/84132/4
https://forums.developer.nvidia.com/t/tensorrt-cask-error-in-checkcaskexecerror-false-7-cask-convolution-execution/109735
https://github.com/NVIDIA/TensorRT/tree/release/7.0/samples/opensource/samplePlugin
https://forums.developer.nvidia.com/t/unable-to-run-two-tensorrt-models-in-a-cascade-manner/145274/2

DCNv2-github

https://github.com/CharlesShang/DCNv2
https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch

想要找我

如果你與我志同道合于此沦零,老潘很愿意與你交流;如果你喜歡老潘的內(nèi)容序攘,歡迎關(guān)注和支持程奠。博客每周更新一篇深度原創(chuàng)文祭钉,關(guān)注公眾號(hào)「oldpan博客」不錯(cuò)過(guò)最新文章。老潘也會(huì)整理一些自己的私藏距境,希望能幫助到大家垮卓,公眾號(hào)回復(fù)"888"獲取老潘學(xué)習(xí)路線資料與文章匯總粟按,還有更多等你挖掘霹粥。如果不想錯(cuò)過(guò)老潘的最新推文后控,請(qǐng)點(diǎn)擊神秘鏈接空镜。

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末吴攒,一起剝皮案震驚了整個(gè)濱河市,隨后出現(xiàn)的幾起案子欣鳖,更是在濱河造成了極大的恐慌茴厉,老刑警劉巖矾缓,帶你破解...
    沈念sama閱讀 217,907評(píng)論 6 506
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件嗜闻,死亡現(xiàn)場(chǎng)離奇詭異桅锄,居然都是意外死亡,警方通過(guò)查閱死者的電腦和手機(jī)翠肘,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,987評(píng)論 3 395
  • 文/潘曉璐 我一進(jìn)店門束倍,熙熙樓的掌柜王于貴愁眉苦臉地迎上來(lái)盟戏,“玉大人柿究,你說(shuō)我怎么就攤上這事∩艏纾” “怎么了?”我有些...
    開(kāi)封第一講書人閱讀 164,298評(píng)論 0 354
  • 文/不壞的土叔 我叫張陵懂诗,是天一觀的道長(zhǎng)苗膝。 經(jīng)常有香客問(wèn)我,道長(zhǎng)离唐,這世上最難降的妖魔是什么亥鬓? 我笑而不...
    開(kāi)封第一講書人閱讀 58,586評(píng)論 1 293
  • 正文 為了忘掉前任域庇,我火速辦了婚禮听皿,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘庵朝。我一直安慰自己又厉,他們只是感情好覆致,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,633評(píng)論 6 392
  • 文/花漫 我一把揭開(kāi)白布煌妈。 她就那樣靜靜地躺著,像睡著了一般笔链。 火紅的嫁衣襯著肌膚如雪腮猖。 梳的紋絲不亂的頭發(fā)上澈缺,一...
    開(kāi)封第一講書人閱讀 51,488評(píng)論 1 302
  • 那天炕婶,我揣著相機(jī)與錄音柠掂,去河邊找鬼依沮。 笑死危喉,一個(gè)胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的皇拣。 我是一名探鬼主播薄嫡,決...
    沈念sama閱讀 40,275評(píng)論 3 418
  • 文/蒼蘭香墨 我猛地睜開(kāi)眼毫深,長(zhǎng)吁一口氣:“原來(lái)是場(chǎng)噩夢(mèng)啊……” “哼费什!你這毒婦竟也來(lái)了手素?” 一聲冷哼從身側(cè)響起,我...
    開(kāi)封第一講書人閱讀 39,176評(píng)論 0 276
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤,失蹤者是張志新(化名)和其女友劉穎崩哩,沒(méi)想到半個(gè)月后邓嘹,有當(dāng)?shù)厝嗽跇?shù)林里發(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 45,619評(píng)論 1 314
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡矿筝,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,819評(píng)論 3 336
  • 正文 我和宋清朗相戀三年窖维,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了铸史。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點(diǎn)故事閱讀 39,932評(píng)論 1 348
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡判沟,死狀恐怖水评,靈堂內(nèi)的尸體忽然破棺而出媚送,到底是詐尸還是另有隱情,我是刑警寧澤疗涉,帶...
    沈念sama閱讀 35,655評(píng)論 5 346
  • 正文 年R本政府宣布咱扣,位于F島的核電站涵防,受9級(jí)特大地震影響壮池,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜厅克,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,265評(píng)論 3 329
  • 文/蒙蒙 一证舟、第九天 我趴在偏房一處隱蔽的房頂上張望窗骑。 院中可真熱鬧,春花似錦鲤竹、人聲如沸辛藻。這莊子的主人今日做“春日...
    開(kāi)封第一講書人閱讀 31,871評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)纺蛆。三九已至规揪,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間字支,已是汗流浹背堕伪。 一陣腳步聲響...
    開(kāi)封第一講書人閱讀 32,994評(píng)論 1 269
  • 我被黑心中介騙來(lái)泰國(guó)打工栗菜, 沒(méi)想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留疙筹,地道東北人。 一個(gè)月前我還...
    沈念sama閱讀 48,095評(píng)論 3 370
  • 正文 我出身青樓霍比,卻偏偏與公主長(zhǎng)得像,于是被迫代替她去往敵國(guó)和親馍驯。 傳聞我的和親對(duì)象是個(gè)殘疾皇子玛痊,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 44,884評(píng)論 2 354

推薦閱讀更多精彩內(nèi)容