概述
PyTorch是一款非常流行的深度學(xué)習(xí)框架,開發(fā)者和研究者常常選擇它葛账,因?yàn)樗哂徐`活性柠衅、易用性和良好的性能。然而籍琳,PyTorch的靈活易用性是建立在動(dòng)態(tài)計(jì)算圖的基礎(chǔ)上的菲宴,相比采用靜態(tài)圖的TensorFlow,PyTorch在推理性能和部署方面存在明顯的劣勢(shì)趋急。
為了解決這個(gè)問題喝峦,TorchScript應(yīng)運(yùn)而生。它將PyTorch模型轉(zhuǎn)換為靜態(tài)類型的優(yōu)化序列化格式呜达,以實(shí)現(xiàn)高效的優(yōu)化和跨平臺(tái)部署(包括C++谣蠢、Python、移動(dòng)設(shè)備和云端)查近。
構(gòu)建 TorchScript
TorchScript將PyTorch模型轉(zhuǎn)換為靜態(tài)圖形式眉踱,因此構(gòu)建TorchScript的核心是構(gòu)建模型的靜態(tài)計(jì)算圖。
PyTorch提供了兩種方法來構(gòu)建TorchScript:trace和script霜威。
torch.jit.trace
:該函數(shù)接收一個(gè)已訓(xùn)練好的模型和實(shí)際輸入樣例谈喳,通過運(yùn)行模型的方式來生成靜態(tài)圖(static graph)。這種轉(zhuǎn)換方式稱為"追蹤模式"(tracing mode)戈泼。torch.jit.script
:該函數(shù)將PyTorch代碼編譯成靜態(tài)圖婿禽。與追蹤模式相反,它被稱為"腳本模式"(scripting mode)大猛,因?yàn)樗苯訉yTorch代碼翻譯成靜態(tài)圖扭倾,而不需要追蹤執(zhí)行流程。
Tracing Mode
model = torch.nn.Sequential(nn.Linear(3, 4))
input = torch.randn(1, 3)
traced_model = torch.jit.trace(model, input)
追蹤模式通過運(yùn)行模型一次挽绩,并根據(jù)操作序列生成靜態(tài)圖吆录。因此,它需要提供輸入樣例(input)琼牧。通過追蹤機(jī)制恢筝,自動(dòng)捕捉和生成模型的計(jì)算圖。這是許多AI編譯器采用的JIT模式巨坊。然而撬槽,追蹤模式存在一個(gè)問題,即無法處理控制流趾撵,例如if侄柔、while等語句共啃。
class MyModel(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
if x > 0:
x += 1
else:
x -= 1
return x
model = MyModel()
input = torch.randn(1)
traced_model = torch.jit.trace(model, input)
以if
語句為例,它是Python的語句暂题,根據(jù)具體的x
值移剪,只能在then分支或else分支上執(zhí)行。因此薪者,追蹤模式只能捕捉到一個(gè)分支上的操作纵苛。要想生成完整的控制流圖,需要采用腳本模式言津。
Scripting Mode
與追蹤模式不同攻人,Scripting 模式直接將 Python 和 PyTorch 的語句翻譯成 TorchScript 的靜態(tài)圖,因此不需要追蹤模型的執(zhí)行流程悬槽,并且能夠生成完整的控制流圖:
script_model = torch.jit.script(model)
print(script_model.graph)
graph(%self : __torch__.___torch_mangle_3.MyModel,
%x.1 : Tensor):
......
%x : Tensor = prim::If(%6) # <ipython-input-3-6fda6c66b1df>:6:4
block0():
%x.7 : Tensor = aten::add_(%x.1, %8, %8) # <ipython-input-3-6fda6c66b1df>:7:6
-> (%x.7)
block1():
%x.13 : Tensor = aten::sub_(%x.1, %8, %8) # <ipython-input-3-6fda6c66b1df>:9:6
-> (%x.13)
return (%x)
然而怀吻,這種模式也有其局限性:對(duì)于每個(gè)語句,都需要提供相應(yīng)的轉(zhuǎn)換函數(shù)初婆,將 Python/PyTorch 語句轉(zhuǎn)換成 TorchScript 語句蓬坡。目前,PyTorch僅支持部分 Python 內(nèi)置函數(shù)和 PyTorch 語句的轉(zhuǎn)換磅叛。
Tracing + Script
因此渣窜,對(duì)于具有控制流的模型,可以采用混合模式:將追蹤模式無法處理的控制流圖封裝為子模塊宪躯,使用腳本模式來轉(zhuǎn)換這些子模塊,然后通過追蹤機(jī)制對(duì)整個(gè)模型進(jìn)行追蹤(通過腳本模式轉(zhuǎn)換后的子模塊不會(huì)再被追蹤)位迂。有關(guān)具體實(shí)現(xiàn)访雪,請(qǐng)參考官方示例:https://pytorch.org/docs/stable/jit.html#mixing-tracing-and-scripting。
運(yùn)行 TorchScript
前面生成的計(jì)算圖會(huì)封裝到 TorchScript 模塊的 forward()
方法中掂林,在運(yùn)行時(shí)被編譯成 native code(JIT)臣缀。如上圖所示,特化后的計(jì)算圖經(jīng)過圖優(yōu)化后被編譯成 native code泻帮,最后通過棧機(jī)解釋器執(zhí)行精置。
Specialization
JIT(just-in-time)將靜態(tài)圖編譯后的結(jié)果以 <signature: executable> 鍵值對(duì)的形式存儲(chǔ)在緩存中。只有在緩存未命中(cache miss)時(shí)锣杂,也就是首次運(yùn)行時(shí)脂倦,才會(huì)觸發(fā)編譯過程。
Signature 表示唯一的靜態(tài)計(jì)算圖元莫。在計(jì)算流圖不變的情況下赖阻,靜態(tài)圖由輸入?yún)?shù)(arguments)決定。不同的 dtype踱蠢、shape 的參數(shù)將生成不同的靜態(tài)計(jì)算圖火欧。
Specialization 的目的是根據(jù) torchscript 的輸入(Input),為參數(shù)賦予 dtype、shape苇侵、設(shè)備類型(CPU赶盔、CUDA)等靜態(tài)信息(ArgumentSpec
),生成 signature榆浓,以便為緩存搜索做準(zhǔn)備于未。
# post specialization, inputs are now specialized types
graph(%x : Float(*, *),
%hx : Float(*, *),
%cx : Float(*, *),
%w_ih : Float(*, *),
%w_hh : Float(*, *),
%b_ih : Float(*),
%b_hh : Float(*)):
%7 : int = prim::Constant[value=4]()
%8 : int = prim::Constant[value=1]()
%9 : Tensor = aten::t(%w_ih)
Optimization
PyTorch JIT 使用一系列 passes(torch.jit.passes
)對(duì)圖進(jìn)行優(yōu)化,旨在從執(zhí)行效率哀军、內(nèi)存占用等方面優(yōu)化計(jì)算沉眶。其中包括對(duì) dtype、shape 和常量進(jìn)行前向推導(dǎo)的形狀推導(dǎo)(Shape inference)和常數(shù)傳播(Const propagation)等優(yōu)化杉适,以減少實(shí)際操作的數(shù)量谎倔。
除了上述常見的優(yōu)化,對(duì)于 GPU 來說猿推,最核心的優(yōu)化是算子融合(Operation fusion):將匹配的一組算子合并為一個(gè)算子片习。例如,將連續(xù)的一系列 element-wise 操作合并為一個(gè)操作蹬叭,這樣可以減少 CUDA kernels 的啟動(dòng)時(shí)間開銷藕咏,并減少操作之間訪問全局內(nèi)存的次數(shù)。
圖優(yōu)化是 AI 編譯器的標(biāo)配秽五,用于優(yōu)化計(jì)算圖的執(zhí)行效率和內(nèi)存占用等方面孽查。PyTorch 的圖優(yōu)化通過一系列 passes(torch.jit.passes)來實(shí)現(xiàn),包括常數(shù)折疊(Constant folding)坦喘、死代碼清除(Dead code elimination)和算子融合(Operation fusion)等盲再。
在圖優(yōu)化過程中,FuseGraph
pass 將可以融合的算子封裝為 FusionGroup 靜態(tài)子圖:
graph(%x : Float(*, *),
...):
%9 : Float(*, *) = aten::t(%w_ih)
...
%77 : Tensor[] = prim::ListConstruct(%b_hh, %b_ih, %10, %12)
%78 : Tensor[] = aten::broadcast_tensors(%77)
%79 : Tensor, %80 : Tensor, %81 : Tensor, %82 : Tensor = prim::ListUnpack(%78)
%hy : Float(*, *), %cy : Float(*, *) = prim::FusionGroup_0(%cx, %82, %81, %80, %79)
%30 : (Float(*, *), Float(*, *)) = prim::TupleConstruct(%hy, %cy)
return (%30);
with prim::FusionGroup_0 = graph(%13 : Float(*, *),
...):
%87 : Float(*, *), %88 : Float(*, *), %89 : Float(*, *), %90 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%86)
%82 : Float(*, *), %83 : Float(*, *), %84 : Float(*, *), %85 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%81)
%77 : Float(*, *), %78 : Float(*, *), %79 : Float(*, *), %80 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%76)
%72 : Float(*, *), %73 : Float(*, *), %74 : Float(*, *), %75 : Float(*, *) = prim::ConstantChunk[chunks=4, dim=1](%71)
%69 : int = prim::Constant[value=1]()
%70 : Float(*, *) = aten::add(%77, %72, %69)
%66 : Float(*, *) = aten::add(%78, %73, %69)
...
%4 : Float(*, *) = aten::tanh(%cy)
%hy : Float(*, *) = aten::mul(%outgate, %4)
return (%hy, %cy)
Codegen
優(yōu)化的最后是為圖(symbolic graph)中的符號(hào)操作生成加速器所需的操作內(nèi)核(op kernel)瓣铣。PyTorch 已經(jīng)為 CPU 和 Nvidia GPU 提供了一個(gè)名為 ATen 的 C++ 算子庫答朋,像圖中的 aten::add
節(jié)點(diǎn)就會(huì)在運(yùn)行時(shí)調(diào)用 built-in 算子。
對(duì)于融合算子棠笑,PyTorch 提供了基于 LLVM 的 NNC
編譯器梦碗,用于生成相應(yīng)的目標(biāo)代碼。它將 FusionGroup 子圖里的 node lowering 成 C++ functions蓖救,再基于 LLVM 將它們編譯成一個(gè)大算子:
RegisterNNCLoweringsFunction aten_matmul(
{"aten::mm(Tensor self, Tensor mat2) -> (Tensor)",
"aten::matmul(Tensor self, Tensor other) -> (Tensor)"},
computeMatmul);
Tensor computeMatmul(...) {
...
return Tensor(
ResultBuf.node(),
ExternalCall::make(ResultBuf, "nnc_aten_matmul", {a, b}, {}));
}
void nnc_aten_matmul(...) {
...
try {
at::matmul_out(r, self, other);
} catch (...) {}
}
Interpreter
TorchScript 提供一個(gè)棧機(jī)解釋器在C++上高效地運(yùn)行計(jì)算圖:
// Create a vector of inputs.
std::vector<torch::jit::IValue> inputs;
inputs.push_back(torch::ones({1, 3, 224, 224}));
// Execute the model and turn its output into a tensor.
at::Tensor output = module.forward(inputs).toTensor();
std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << '\n';