C++利用LibTorch調(diào)用pytorch 模型

由于python的易用性卒煞,深度學(xué)習(xí)模型多是在python框架下進(jìn)行訓(xùn)練的狸驳,如TensorFlow藕咏,pytorch等再层。而由于硬件設(shè)備的限制贸铜,有時(shí)候其部署可能需要基于C++/C的平臺(tái)。比如在我們的項(xiàng)目中聂受,語(yǔ)義分割網(wǎng)絡(luò)是在pytorch下訓(xùn)練的蒿秦,而分割結(jié)果基于應(yīng)用的后處理部分是在C下面實(shí)現(xiàn)的,那怎么才能把這兩種平臺(tái)下的東西結(jié)合起來(lái)一起運(yùn)行呢蛋济?我想到的方法有以下幾種:

  • 將pytorch模型訓(xùn)練好后轉(zhuǎn)成caffe的模型渤早,然后利用caffe的接口在C++下面實(shí)現(xiàn)模型的推理應(yīng)用;

  • 將C部分的代碼打包編譯成一個(gè)動(dòng)態(tài)連接庫(kù)dll瘫俊,然后在python框架下調(diào)用該dll實(shí)現(xiàn)c下面的功能鹊杖;

  • 利用pytorch的C++版本LibTorch實(shí)現(xiàn)pytorch模型的調(diào)用。
    本文主要記錄最后一種方法扛芽。

LibTorch 的下載及使用

LibTorch 是pytorch的C++版本骂蓖,在pytorch版本1.0后就有了。在官網(wǎng)通過(guò)如下選擇川尖,就可以得到下載鏈接登下。


image-20200304141339075.png

下載鏈接里有release版本和debug版本,建議兩個(gè)版本都下載叮喳,兩者主要是對(duì)應(yīng)的dll和lib不一樣被芳,debug版本還提供了pdb,可以幫助定位錯(cuò)誤位置馍悟。將release版本解壓后畔濒,得到一個(gè)LibTorch的文件夾,再將debug版本解壓锣咒,將其中的lib文件夾改名為lib_debug侵状,同樣放在之前release版本解壓的LibTorch文件夾,這樣就release和debug版本都可以使用了毅整。

下載的LibTorch中提供了cmakelist趣兄,在linux平臺(tái)可以利用cmake來(lái)使用它。而如果在Windows平臺(tái)利用vs悼嫉,只需要和一般的第三方庫(kù)使用一樣艇潭,在對(duì)應(yīng)的工程中添加正確的AdditionalIncludeDirectories,AdditionalLibraryDirectories戏蔑,AdditionalDependencies等就可以了蹋凝。我在實(shí)驗(yàn)時(shí),將lib里邊所有的lib文件都加入到AdditionalDependencies了辛臊。程序運(yùn)行的時(shí)候還需要把對(duì)應(yīng)的dll拷貝到exe所在的文件夾仙粱。我使用debug時(shí)遇到了一個(gè)編譯錯(cuò)誤,添加preprocessor _SCL_SECURE_NO_WARNINGS就好了彻舰。

使用流程

利用LibTorch來(lái)調(diào)用pytorch模型的流程大致是這樣的:

  1. pytorch訓(xùn)練好模型
  2. 將模型序列化并存成pt文件
  3. 在C中利用LibTorch的接口進(jìn)行正向推演

pytorch模型序列化

第一步我們就不介紹了伐割,我們從第二步開(kāi)始。模型的序列化是利用Torch Script來(lái)完成的刃唤。TorchScript是一種從PyTorch代碼創(chuàng)建可序列化和可優(yōu)化模型的方法隔心。用TorchScript編寫(xiě)的任何代碼都可以從Python進(jìn)程中保存并加載到?jīng)]有Python依賴關(guān)系的進(jìn)程中。對(duì)于一個(gè)已經(jīng)訓(xùn)練好的pytorch模型尚胞,官方提供兩種方法進(jìn)行Torch Script的轉(zhuǎn)換:tracing和annotation硬霍。

Tracing

Tracing的方法還是很簡(jiǎn)單的,參見(jiàn)如下示例代碼:

import torch
import torchvision

# An instance of your model.
model = torchvision.models.resnet18()

# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)

# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)

Annotation

tracing適用于大多數(shù)網(wǎng)絡(luò)笼裳,如果你的網(wǎng)絡(luò)的forward方法中對(duì)input有邏輯判斷唯卖,比如input的size為一個(gè)值時(shí)走向一個(gè)分支粱玲,而為另一值時(shí)走向另一個(gè)分支,那么只能用annotation進(jìn)行轉(zhuǎn)換拜轨。比如如下的網(wǎng)絡(luò):

import torch

class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    def forward(self, input):
        if input.sum() > 0:
          output = self.weight.mv(input)
        else:
          output = self.weight + input
        return output

利用annotation來(lái)將上述網(wǎng)絡(luò)模型轉(zhuǎn)成Torch Script可以按如下代碼:

my_module = MyModule(10,20)
sm = torch.jit.script(my_module)

annotation的方法我并沒(méi)有測(cè)試抽减,我使用的模型用tracing就已經(jīng)足夠了。

序列化

序列化的意思是指將上述Torch Script描述的模型存成一個(gè)文件橄碾。

traced_script_module.save("traced_resnet_model.pt")

C++中的正向推演

#include <torch/script.h> // One-stop header.

#include <iostream>
#include <memory>

int main(int argc, const char* argv[]) {
  if (argc != 2) {
    std::cerr << "usage: example-app <path-to-exported-script-module>\n";
    return -1;
  }


  torch::jit::script::Module module;
  try {
    // Deserialize the ScriptModule from a file using torch::jit::load().
    module = torch::jit::load(argv[1]);
  }
  catch (const c10::Error& e) {
    std::cerr << "error loading the model\n";
    return -1;
  }
    
  // Simple tests of the model
  std::vector<torch::jit::IValue> inputs;
  inputs.push_back(torch::ones({1, 3, 224, 224}));

  // Execute the model and turn its output into a tensor.
  at::Tensor output = module.forward(inputs).toTensor();

  std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << '\n';

  std::cout << "ok\n";
}

自己訓(xùn)練的模型的實(shí)際操作

下面以我們訓(xùn)練的語(yǔ)義分割網(wǎng)絡(luò)為例卵沉,介紹如何將自己的模型在C++中跑起來(lái)。

在實(shí)際的操作中法牲,也是遇到了一些問(wèn)題的史汗。

GPU及DataParallel的問(wèn)題

第一個(gè)問(wèn)題是我們之前的模型訓(xùn)練是在GPU(相信應(yīng)該都是這樣的)中進(jìn)行的,并且使用了DataParallel拒垃,在序列化時(shí)停撞,如下代碼是正確的,可以與示例代碼做下比較恶复。

device = torch.device('cuda')
model = get_model(args_in)
model = torch.nn.DataParallel(model, device_ids=[0])
model.load_state_dict(torch.load(args_in.test_model_path))
model.to(device)
# use evaluation mode to ignore dropout, etc
model.eval()

# The tracing input need not to be the same size as the forward case.
example = torch.rand(1, 3, 1080, 1920).to(device)

# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model.module, example)

traced_script_module.save("traced_model.pt")

對(duì)于GPU訓(xùn)練的模型怜森,需要將模型和tracing用的tensor通過(guò)to(device)或者.cuda()轉(zhuǎn)到GPU上,如第5,10行谤牡。對(duì)于利用DataParallel訓(xùn)練的模型副硅,需要在trace時(shí)使用model.module,如第13行翅萤。

關(guān)于DataParallel多說(shuō)一句恐疲,如果希望正向的時(shí)候不需要像第3行那樣將model再包一層,在訓(xùn)練save model的時(shí)候應(yīng)該按如下

torch.save(model.module.state_dict(), save_path)

這樣存的model就不需要第3行代碼套么,而且第13行的.module也不需要了培己。

附上因?yàn)镈ataParallel沒(méi)弄對(duì)在pycharm中遇到的錯(cuò)誤

RuntimeError: hasSpecialCase INTERNAL ASSERT FAILED at ..\torch\csrc\jit\passes\alias_analysis.cpp:300, please report a bug to PyTorch. We don't have an op for aten::to but it isn't a special case. (analyzeImpl at ..\torch\csrc\jit\passes\alias_analysis.cpp:300)

網(wǎng)絡(luò)輸出是Tuple的問(wèn)題

我們的網(wǎng)絡(luò)輸出是一個(gè)tuple而不是一個(gè)tensor,于是在C++調(diào)用的時(shí)候總是crash胚泌,用了debug版本的LibTorch省咨,才發(fā)現(xiàn)問(wèn)題。官方提到LibTorch這種方式需要網(wǎng)絡(luò)的輸出是一個(gè)tuple或者tensor玷室,那如果輸出的是tuple零蓉,在C++端代碼應(yīng)該按如下修改

torch::Tensor result = module.forward(input).toTuple()->elements()[0].toTensor();

圖像的前處理

在pytorch模型的訓(xùn)練過(guò)程中,我們一般會(huì)對(duì)圖像進(jìn)行一些前處理穷缤,比如

transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ])

在LibTorch中敌蜂,可以這樣做

  tensor_image = tensor_image.toType(torch::kFloat);
  tensor_image = tensor_image.div(255);
  // Normalization
  tensor_image[0][0] = tensor_image[0][0].sub_(0.485).div_(0.229);
  tensor_image[0][1] = tensor_image[0][1].sub_(0.456).div_(0.224);
  tensor_image[0][2] = tensor_image[0][2].sub_(0.406).div_(0.225);

最后貼上我們利用opencv讀視頻,然后對(duì)每一幀運(yùn)行語(yǔ)義分割正向的代碼津肛。

  // module for forward process
  torch::jit::script::Module module;
  try {
    // Deserialize the ScriptModule from a file using torch::jit::load().
    module = torch::jit::load("traced_model.pt");
  } catch (const c10::Error &e) {
    std::cerr << "error loading the model\n";
  }
  torch::DeviceType device = torch::kCUDA;
  module.to(device);

  // opencv windows
  cv::namedWindow("Test", 0);
  cvMoveWindow("Test", 0, 0);
  
  cv::VideoCapture  t_video_in(videoPath);
  long nbFrames = static_cast<long>(t_video_in.get(CV_CAP_PROP_FRAME_COUNT));

  for (long f = 0; f < nbFrames; f++) {
    cv::Mat image, input;
    t_video_in >> image;
    cv::cvtColor(image, input, CV_BGR2RGB);

    // run semantic segmentation to get label image
    torch::Tensor tensor_image = torch::from_blob(input.data, { 1, input.rows, input.cols, 3 }, torch::kByte);
    tensor_image = tensor_image.permute({ 0, 3, 1, 2 });
    tensor_image = tensor_image.toType(torch::kFloat);
    tensor_image = tensor_image.div(255);
    // Normalization
    tensor_image[0][0] = tensor_image[0][0].sub_(0.485).div_(0.229);
    tensor_image[0][1] = tensor_image[0][1].sub_(0.456).div_(0.224);
    tensor_image[0][2] = tensor_image[0][2].sub_(0.406).div_(0.225);

    tensor_image = tensor_image.to(torch::kCUDA);
    torch::Tensor result = module.forward({ tensor_image }).toTuple()->elements()[0].toTensor();
    torch::Tensor pred = result.argmax(1);
    pred = pred.squeeze();
    pred = pred.to(torch::kU8);
    pred = pred.to(torch::kCPU);

    cv::Mat label(cv::Size(image.cols,image.rows), CV_8U, pred.data_ptr());
    cv::imshow("Test", label);

    cv::waitKey(1);
  }
  t_video_in.release();
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末章喉,一起剝皮案震驚了整個(gè)濱河市,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌秸脱,老刑警劉巖落包,帶你破解...
    沈念sama閱讀 211,265評(píng)論 6 490
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場(chǎng)離奇詭異撞反,居然都是意外死亡妥色,警方通過(guò)查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 90,078評(píng)論 2 385
  • 文/潘曉璐 我一進(jìn)店門(mén)遏片,熙熙樓的掌柜王于貴愁眉苦臉地迎上來(lái),“玉大人撮竿,你說(shuō)我怎么就攤上這事吮便。” “怎么了幢踏?”我有些...
    開(kāi)封第一講書(shū)人閱讀 156,852評(píng)論 0 347
  • 文/不壞的土叔 我叫張陵髓需,是天一觀的道長(zhǎng)。 經(jīng)常有香客問(wèn)我房蝉,道長(zhǎng)僚匆,這世上最難降的妖魔是什么? 我笑而不...
    開(kāi)封第一講書(shū)人閱讀 56,408評(píng)論 1 283
  • 正文 為了忘掉前任搭幻,我火速辦了婚禮咧擂,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘檀蹋。我一直安慰自己松申,他們只是感情好,可當(dāng)我...
    茶點(diǎn)故事閱讀 65,445評(píng)論 5 384
  • 文/花漫 我一把揭開(kāi)白布俯逾。 她就那樣靜靜地躺著贸桶,像睡著了一般。 火紅的嫁衣襯著肌膚如雪桌肴。 梳的紋絲不亂的頭發(fā)上皇筛,一...
    開(kāi)封第一講書(shū)人閱讀 49,772評(píng)論 1 290
  • 那天,我揣著相機(jī)與錄音坠七,去河邊找鬼水醋。 笑死,一個(gè)胖子當(dāng)著我的面吹牛灼捂,可吹牛的內(nèi)容都是我干的离例。 我是一名探鬼主播,決...
    沈念sama閱讀 38,921評(píng)論 3 406
  • 文/蒼蘭香墨 我猛地睜開(kāi)眼悉稠,長(zhǎng)吁一口氣:“原來(lái)是場(chǎng)噩夢(mèng)啊……” “哼宫蛆!你這毒婦竟也來(lái)了?” 一聲冷哼從身側(cè)響起,我...
    開(kāi)封第一講書(shū)人閱讀 37,688評(píng)論 0 266
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤耀盗,失蹤者是張志新(化名)和其女友劉穎想虎,沒(méi)想到半個(gè)月后,有當(dāng)?shù)厝嗽跇?shù)林里發(fā)現(xiàn)了一具尸體叛拷,經(jīng)...
    沈念sama閱讀 44,130評(píng)論 1 303
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡舌厨,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 36,467評(píng)論 2 325
  • 正文 我和宋清朗相戀三年,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了忿薇。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片裙椭。...
    茶點(diǎn)故事閱讀 38,617評(píng)論 1 340
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡,死狀恐怖署浩,靈堂內(nèi)的尸體忽然破棺而出揉燃,到底是詐尸還是另有隱情,我是刑警寧澤筋栋,帶...
    沈念sama閱讀 34,276評(píng)論 4 329
  • 正文 年R本政府宣布炊汤,位于F島的核電站,受9級(jí)特大地震影響弊攘,放射性物質(zhì)發(fā)生泄漏抢腐。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 39,882評(píng)論 3 312
  • 文/蒙蒙 一襟交、第九天 我趴在偏房一處隱蔽的房頂上張望迈倍。 院中可真熱鬧,春花似錦婿着、人聲如沸授瘦。這莊子的主人今日做“春日...
    開(kāi)封第一講書(shū)人閱讀 30,740評(píng)論 0 21
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)提完。三九已至,卻和暖如春丘侠,著一層夾襖步出監(jiān)牢的瞬間徒欣,已是汗流浹背。 一陣腳步聲響...
    開(kāi)封第一講書(shū)人閱讀 31,967評(píng)論 1 265
  • 我被黑心中介騙來(lái)泰國(guó)打工蜗字, 沒(méi)想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留打肝,地道東北人。 一個(gè)月前我還...
    沈念sama閱讀 46,315評(píng)論 2 360
  • 正文 我出身青樓挪捕,卻偏偏與公主長(zhǎng)得像粗梭,于是被迫代替她去往敵國(guó)和親。 傳聞我的和親對(duì)象是個(gè)殘疾皇子级零,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 43,486評(píng)論 2 348

推薦閱讀更多精彩內(nèi)容