從PyTorch到TorchScript: 打通深度學(xué)習(xí)模型的生產(chǎn)和應(yīng)用

概述

PyTorch是一款非常流行的深度學(xué)習(xí)框架,開發(fā)者和研究者常常選擇它葛账,因?yàn)樗哂徐`活性柠衅、易用性和良好的性能。然而籍琳,PyTorch的靈活易用性是建立在動(dòng)態(tài)計(jì)算圖的基礎(chǔ)上的菲宴,相比采用靜態(tài)圖的TensorFlow,PyTorch在推理性能和部署方面存在明顯的劣勢(shì)趋急。

為了解決這個(gè)問題喝峦,TorchScript應(yīng)運(yùn)而生。它將PyTorch模型轉(zhuǎn)換為靜態(tài)類型的優(yōu)化序列化格式呜达,以實(shí)現(xiàn)高效的優(yōu)化和跨平臺(tái)部署(包括C++谣蠢、Python、移動(dòng)設(shè)備和云端)查近。

構(gòu)建 TorchScript

TorchScript將PyTorch模型轉(zhuǎn)換為靜態(tài)圖形式眉踱,因此構(gòu)建TorchScript的核心是構(gòu)建模型的靜態(tài)計(jì)算圖。
PyTorch提供了兩種方法來構(gòu)建TorchScript:trace和script霜威。

  • torch.jit.trace:該函數(shù)接收一個(gè)已訓(xùn)練好的模型和實(shí)際輸入樣例谈喳,通過運(yùn)行模型的方式來生成靜態(tài)圖(static graph)。這種轉(zhuǎn)換方式稱為"追蹤模式"(tracing mode)戈泼。

  • torch.jit.script:該函數(shù)將PyTorch代碼編譯成靜態(tài)圖婿禽。與追蹤模式相反,它被稱為"腳本模式"(scripting mode)大猛,因?yàn)樗苯訉yTorch代碼翻譯成靜態(tài)圖扭倾,而不需要追蹤執(zhí)行流程。

Tracing Mode

model = torch.nn.Sequential(nn.Linear(3, 4))
input = torch.randn(1, 3)
traced_model = torch.jit.trace(model, input)

追蹤模式通過運(yùn)行模型一次挽绩,并根據(jù)操作序列生成靜態(tài)圖吆录。因此,它需要提供輸入樣例(input)琼牧。通過追蹤機(jī)制恢筝,自動(dòng)捕捉和生成模型的計(jì)算圖。這是許多AI編譯器采用的JIT模式巨坊。然而撬槽,追蹤模式存在一個(gè)問題,即無法處理控制流趾撵,例如if侄柔、while等語句共啃。

class MyModel(nn.Module):
  def __init__(self):
    super().__init__()

  def forward(self, x):
    if x > 0:
      x += 1
    else:
      x -= 1
    return x

model = MyModel()
input = torch.randn(1)
traced_model = torch.jit.trace(model, input)

if語句為例,它是Python的語句暂题,根據(jù)具體的x值移剪,只能在then分支或else分支上執(zhí)行。因此薪者,追蹤模式只能捕捉到一個(gè)分支上的操作纵苛。要想生成完整的控制流圖,需要采用腳本模式言津。

Scripting Mode

與追蹤模式不同攻人,Scripting 模式直接將 Python 和 PyTorch 的語句翻譯成 TorchScript 的靜態(tài)圖,因此不需要追蹤模型的執(zhí)行流程悬槽,并且能夠生成完整的控制流圖:

script_model = torch.jit.script(model)
print(script_model.graph)


graph(%self : __torch__.___torch_mangle_3.MyModel,
      %x.1 : Tensor):
  ......
  %x : Tensor = prim::If(%6) # <ipython-input-3-6fda6c66b1df>:6:4
    block0():
      %x.7 : Tensor = aten::add_(%x.1, %8, %8) # <ipython-input-3-6fda6c66b1df>:7:6
      -> (%x.7)
    block1():
      %x.13 : Tensor = aten::sub_(%x.1, %8, %8) # <ipython-input-3-6fda6c66b1df>:9:6
      -> (%x.13)
  return (%x)

然而怀吻,這種模式也有其局限性:對(duì)于每個(gè)語句,都需要提供相應(yīng)的轉(zhuǎn)換函數(shù)初婆,將 Python/PyTorch 語句轉(zhuǎn)換成 TorchScript 語句蓬坡。目前,PyTorch僅支持部分 Python 內(nèi)置函數(shù)和 PyTorch 語句的轉(zhuǎn)換磅叛。

Tracing + Script

因此渣窜,對(duì)于具有控制流的模型,可以采用混合模式:將追蹤模式無法處理的控制流圖封裝為子模塊宪躯,使用腳本模式來轉(zhuǎn)換這些子模塊,然后通過追蹤機(jī)制對(duì)整個(gè)模型進(jìn)行追蹤(通過腳本模式轉(zhuǎn)換后的子模塊不會(huì)再被追蹤)位迂。有關(guān)具體實(shí)現(xiàn)访雪,請(qǐng)參考官方示例:https://pytorch.org/docs/stable/jit.html#mixing-tracing-and-scripting

運(yùn)行 TorchScript

pytorch_jit_forward.png

前面生成的計(jì)算圖會(huì)封裝到 TorchScript 模塊的 forward() 方法中掂林,在運(yùn)行時(shí)被編譯成 native code(JIT)臣缀。如上圖所示,特化后的計(jì)算圖經(jīng)過圖優(yōu)化后被編譯成 native code泻帮,最后通過棧機(jī)解釋器執(zhí)行精置。

Specialization

JIT(just-in-time)將靜態(tài)圖編譯后的結(jié)果以 <signature: executable> 鍵值對(duì)的形式存儲(chǔ)在緩存中。只有在緩存未命中(cache miss)時(shí)锣杂,也就是首次運(yùn)行時(shí)脂倦,才會(huì)觸發(fā)編譯過程。

Signature 表示唯一的靜態(tài)計(jì)算圖元莫。在計(jì)算流圖不變的情況下赖阻,靜態(tài)圖由輸入?yún)?shù)(arguments)決定。不同的 dtype踱蠢、shape 的參數(shù)將生成不同的靜態(tài)計(jì)算圖火欧。

Specialization 的目的是根據(jù) torchscript 的輸入(Input),為參數(shù)賦予 dtype、shape苇侵、設(shè)備類型(CPU赶盔、CUDA)等靜態(tài)信息(ArgumentSpec),生成 signature榆浓,以便為緩存搜索做準(zhǔn)備于未。

# post specialization, inputs are now specialized types
graph(%x : Float(*, *),
      %hx : Float(*, *),
      %cx : Float(*, *),
      %w_ih : Float(*, *),
      %w_hh : Float(*, *),
      %b_ih : Float(*),
      %b_hh : Float(*)):
  %7 : int = prim::Constant[value=4]()
  %8 : int = prim::Constant[value=1]()
  %9 : Tensor = aten::t(%w_ih)

Optimization

PyTorch JIT 使用一系列 passes(torch.jit.passes)對(duì)圖進(jìn)行優(yōu)化,旨在從執(zhí)行效率哀军、內(nèi)存占用等方面優(yōu)化計(jì)算沉眶。其中包括對(duì) dtype、shape 和常量進(jìn)行前向推導(dǎo)的形狀推導(dǎo)(Shape inference)和常數(shù)傳播(Const propagation)等優(yōu)化杉适,以減少實(shí)際操作的數(shù)量谎倔。

除了上述常見的優(yōu)化,對(duì)于 GPU 來說猿推,最核心的優(yōu)化是算子融合(Operation fusion):將匹配的一組算子合并為一個(gè)算子片习。例如,將連續(xù)的一系列 element-wise 操作合并為一個(gè)操作蹬叭,這樣可以減少 CUDA kernels 的啟動(dòng)時(shí)間開銷藕咏,并減少操作之間訪問全局內(nèi)存的次數(shù)。

圖優(yōu)化是 AI 編譯器的標(biāo)配秽五,用于優(yōu)化計(jì)算圖的執(zhí)行效率和內(nèi)存占用等方面孽查。PyTorch 的圖優(yōu)化通過一系列 passes(torch.jit.passes)來實(shí)現(xiàn),包括常數(shù)折疊(Constant folding)坦喘、死代碼清除(Dead code elimination)和算子融合(Operation fusion)等盲再。

在圖優(yōu)化過程中,FuseGraph pass 將可以融合的算子封裝為 FusionGroup 靜態(tài)子圖:

graph(%x : Float(*, *),
      ...):
  %9 : Float(*, *) = aten::t(%w_ih)
  ...
  %77 : Tensor[] = prim::ListConstruct(%b_hh, %b_ih, %10, %12)
  %78 : Tensor[] = aten::broadcast_tensors(%77)
  %79 : Tensor, %80 : Tensor, %81 : Tensor, %82 : Tensor = prim::ListUnpack(%78)
  %hy : Float(*, *), %cy : Float(*, *) = prim::FusionGroup_0(%cx, %82, %81, %80, %79)
  %30 : (Float(*, *), Float(*, *)) = prim::TupleConstruct(%hy, %cy)
  return (%30);

with prim::FusionGroup_0 = graph(%13 : Float(*, *),
  ...):
  %87 : Float(*, *), %88 : Float(*, *), %89 : Float(*, *), %90 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%86)
  %82 : Float(*, *), %83 : Float(*, *), %84 : Float(*, *), %85 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%81)
  %77 : Float(*, *), %78 : Float(*, *), %79 : Float(*, *), %80 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%76)
  %72 : Float(*, *), %73 : Float(*, *), %74 : Float(*, *), %75 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%71)
  %69 : int = prim::Constant[value=1]()
  %70 : Float(*, *) = aten::add(%77, %72, %69)
  %66 : Float(*, *) = aten::add(%78, %73, %69)
  ...
  %4 : Float(*, *) = aten::tanh(%cy)
  %hy : Float(*, *) = aten::mul(%outgate, %4)
  return (%hy, %cy)

Codegen

優(yōu)化的最后是為圖(symbolic graph)中的符號(hào)操作生成加速器所需的操作內(nèi)核(op kernel)瓣铣。PyTorch 已經(jīng)為 CPU 和 Nvidia GPU 提供了一個(gè)名為 ATen 的 C++ 算子庫答朋,像圖中的 aten::add 節(jié)點(diǎn)就會(huì)在運(yùn)行時(shí)調(diào)用 built-in 算子。

對(duì)于融合算子棠笑,PyTorch 提供了基于 LLVM 的 NNC 編譯器梦碗,用于生成相應(yīng)的目標(biāo)代碼。它將 FusionGroup 子圖里的 node lowering 成 C++ functions蓖救,再基于 LLVM 將它們編譯成一個(gè)大算子:

  RegisterNNCLoweringsFunction aten_matmul(
      {"aten::mm(Tensor self, Tensor mat2) -> (Tensor)",
       "aten::matmul(Tensor self, Tensor other) -> (Tensor)"},
      computeMatmul);

  Tensor computeMatmul(...) {
    ...
    return Tensor(
        ResultBuf.node(),
        ExternalCall::make(ResultBuf, "nnc_aten_matmul", {a, b}, {}));
  }

  void nnc_aten_matmul(...) {
    ...
    try {
      at::matmul_out(r, self, other);
    } catch (...) {}
  }

Interpreter

TorchScript 提供一個(gè)棧機(jī)解釋器在C++上高效地運(yùn)行計(jì)算圖:

// Create a vector of inputs.
std::vector<torch::jit::IValue> inputs;
inputs.push_back(torch::ones({1, 3, 224, 224}));

// Execute the model and turn its output into a tensor.
at::Tensor output = module.forward(inputs).toTensor();
std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << '\n';

END

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末洪规,一起剝皮案震驚了整個(gè)濱河市,隨后出現(xiàn)的幾起案子循捺,更是在濱河造成了極大的恐慌淹冰,老刑警劉巖,帶你破解...
    沈念sama閱讀 222,183評(píng)論 6 516
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件巨柒,死亡現(xiàn)場離奇詭異樱拴,居然都是意外死亡柠衍,警方通過查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 94,850評(píng)論 3 399
  • 文/潘曉璐 我一進(jìn)店門晶乔,熙熙樓的掌柜王于貴愁眉苦臉地迎上來珍坊,“玉大人,你說我怎么就攤上這事正罢≌舐” “怎么了?”我有些...
    開封第一講書人閱讀 168,766評(píng)論 0 361
  • 文/不壞的土叔 我叫張陵翻具,是天一觀的道長履怯。 經(jīng)常有香客問我,道長裆泳,這世上最難降的妖魔是什么叹洲? 我笑而不...
    開封第一講書人閱讀 59,854評(píng)論 1 299
  • 正文 為了忘掉前任,我火速辦了婚禮工禾,結(jié)果婚禮上运提,老公的妹妹穿的比我還像新娘。我一直安慰自己闻葵,他們只是感情好民泵,可當(dāng)我...
    茶點(diǎn)故事閱讀 68,871評(píng)論 6 398
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著槽畔,像睡著了一般栈妆。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上厢钧,一...
    開封第一講書人閱讀 52,457評(píng)論 1 311
  • 那天鳞尔,我揣著相機(jī)與錄音,去河邊找鬼坏快。 笑死,一個(gè)胖子當(dāng)著我的面吹牛憎夷,可吹牛的內(nèi)容都是我干的莽鸿。 我是一名探鬼主播,決...
    沈念sama閱讀 40,999評(píng)論 3 422
  • 文/蒼蘭香墨 我猛地睜開眼拾给,長吁一口氣:“原來是場噩夢(mèng)啊……” “哼祥得!你這毒婦竟也來了?” 一聲冷哼從身側(cè)響起蒋得,我...
    開封第一講書人閱讀 39,914評(píng)論 0 277
  • 序言:老撾萬榮一對(duì)情侶失蹤级及,失蹤者是張志新(化名)和其女友劉穎,沒想到半個(gè)月后额衙,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體饮焦,經(jīng)...
    沈念sama閱讀 46,465評(píng)論 1 319
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡怕吴,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 38,543評(píng)論 3 342
  • 正文 我和宋清朗相戀三年,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了县踢。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片转绷。...
    茶點(diǎn)故事閱讀 40,675評(píng)論 1 353
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡,死狀恐怖硼啤,靈堂內(nèi)的尸體忽然破棺而出议经,到底是詐尸還是另有隱情,我是刑警寧澤谴返,帶...
    沈念sama閱讀 36,354評(píng)論 5 351
  • 正文 年R本政府宣布煞肾,位于F島的核電站,受9級(jí)特大地震影響嗓袱,放射性物質(zhì)發(fā)生泄漏籍救。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 42,029評(píng)論 3 335
  • 文/蒙蒙 一索抓、第九天 我趴在偏房一處隱蔽的房頂上張望钧忽。 院中可真熱鬧,春花似錦逼肯、人聲如沸耸黑。這莊子的主人今日做“春日...
    開封第一講書人閱讀 32,514評(píng)論 0 25
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽大刊。三九已至,卻和暖如春三椿,著一層夾襖步出監(jiān)牢的瞬間缺菌,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 33,616評(píng)論 1 274
  • 我被黑心中介騙來泰國打工搜锰, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留伴郁,地道東北人。 一個(gè)月前我還...
    沈念sama閱讀 49,091評(píng)論 3 378
  • 正文 我出身青樓蛋叼,卻偏偏與公主長得像焊傅,于是被迫代替她去往敵國和親。 傳聞我的和親對(duì)象是個(gè)殘疾皇子狈涮,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 45,685評(píng)論 2 360

推薦閱讀更多精彩內(nèi)容