概述
PyTorch的成功歸功于其簡單易用性(與Python的用法相似)和動態(tài)靈活性棺蛛。即使在PyTorch 2.0時代,它仍然保持著"Faster, more pythonic and dynamic as ever"的核心特性别厘。
PyTorch的動態(tài)性源自內(nèi)部的調(diào)度器(dispatcher),它可以根據(jù)不同的輸入類型自動選擇正確的運算方式拥诡。當調(diào)用Python函數(shù)時触趴,調(diào)度器會根據(jù)傳入的參數(shù)類型選擇正確的操作實現(xiàn)氮发,這個過程稱為分派(dispatch)。
例如冗懦,當執(zhí)行矩陣乘法(torch.matmul(a, b))時爽冕,調(diào)度器會根據(jù)輸入張量a和b的類型(dtype、shape披蕉、device等)選擇正確的BLAS庫(CPU還是CUDA颈畸,float還是half,是否批量計算)來進行計算没讲。對于PyTorch來說眯娱,模型的執(zhí)行過程就是將各個操作(op)分派給本地方法(native function)執(zhí)行的過程。
dispatcher 為每個 op 都維護了一張?zhí)D(zhuǎn)表(它有點像 C++ 實現(xiàn)多態(tài)用的虛表)食零,如上圖所示困乒,表中每個條目存儲了一個本地方法寂屏,有些方法和輸入張量所屬的設備有關(guān)贰谣,比如 XLA/CUDA/CPU
,有的和 requires_grad
有關(guān)迁霎,比如 Autograd
(這圖是從 ezyang’s blog 拿來的吱抚,他這篇博客詳細講解了分派機制,建議閱讀)考廉。
當 op 被執(zhí)行時秘豹,e.g. aten::addmm
,調(diào)度器會在它的跳轉(zhuǎn)表中找出一個方法來執(zhí)行昌粤,而且一個 op 執(zhí)行過程可能會調(diào)用多個方法既绕,例如,輸入張量需要求導(requires_grad = true)涮坐,那會先調(diào)用 Autograd 方法來構(gòu)建反向圖凄贩,再調(diào)用 backend(CPU/CUDA/XLA)的方法來運算。
分派規(guī)則
跳轉(zhuǎn)表里的條目是以鍵值對的形式來存調(diào)度方法袱讹,其中“鍵”稱為 dispatch key
疲扎,以 bit 的形式存在,bit 值越大捷雕,優(yōu)先級越高椒丧,調(diào)度器會從鍵集(dispatch key set
)中選取優(yōu)先級最高的條目來執(zhí)行。
從上圖可以看到救巷,鍵集不只有一個壶熏,每個輸入張量都有自己的鍵集,還有 local(local include
和local exclude
) 和 global 鍵集浦译,這些鍵集最終會合并棒假,調(diào)度器從中選取優(yōu)先級最高的鍵值對應的方法來執(zhí)行俄占。
輸入張量的鍵集是比較好理解的,張量本身具有很多屬性淆衷,如 layout (dense or sparse)缸榄、shape 和 device (CPU or CUDA),一個屬性對應一個 dispatch key(可以從 DispatchKey.h 找到所有的 key)祝拯。對于不同類型的張量甚带,我們希望能使用不同實現(xiàn)的操作以實現(xiàn)高性能計算的目標。
Local 鍵集 與張量個體無關(guān)佳头,與模型的行為有關(guān)鹰贵,表示模型運行在某模式中,比如 tracing康嘉。它可以允許用戶在某個范圍內(nèi)開啟或關(guān)閉模式碉输。要開啟模式就是往 local include 里添加鍵,要關(guān)閉模式就是往 local exclude 里添加要屏蔽的鍵亭珍。
Global 則表示無論什么操作都會添加的鍵集(圖中 autograd 已經(jīng)從 global 移到 tensor 鍵集)敷钾。
分派流程
前面也提到,一個 op 的執(zhí)行是要經(jīng)歷多次分派的肄梨,上圖就展示了這個過程:
- 首先阻荒,輸入張量需要求導(requires_grad = true),調(diào)度器就分派給 Autograd key 的本地方法众羡。它會為 op 生成一個反向計算操作侨赡,然后,再把控制權(quán)交給調(diào)度器做重新分派粱侣。
- 接著由于輸入張量在CPU上羊壹,CPU的方法會被分派執(zhí)行。
前面提到齐婴,調(diào)度器會調(diào)用優(yōu)先級最高的 dispatch key油猫,因此,重新分派的前提是將已經(jīng)調(diào)度過的鍵從鍵集里清除尔店,否則重新分派將會重復調(diào)用相同的方法眨攘。
Autograd 的本地方法通過在 local exclude 鍵集中添加要屏蔽的鍵(Autograd)來避免方法的重復調(diào)用∠荩可以通過創(chuàng)建 AutoNonVariableTypeMode RAII guard 來實現(xiàn):
class MyAddFunction : public torch::autograd::Function<MyAddFunction> {
public:
static Tensor forward(
AutogradContext *ctx, torch::Tensor self, torch::Tensor other) {
at::AutoNonVariableTypeMode g;
return myadd(self, other);
}
...
};
注冊自定義操作
回想一下分派規(guī)則:調(diào)度器首先找到 op 對應的跳轉(zhuǎn)表鲫售,合并鍵集,并調(diào)用鍵值最大的條目中的函數(shù)该肴。由于 dispatch key 是 PyTorch 固定且不可擴展的情竹,因此注冊自定義操作需要注冊 op 以及跳轉(zhuǎn)表中鍵的方法。
注冊 op
TORCH_LIBRARY(myops, m) {
m.def("myadd(Tensor self, Tensor other) -> Tensor");
}
PyTorch 提供 TORCH_LIBRARY
用于將 op(也稱作 schema string
或 signature
)注冊到一個庫里匀哄,用戶可以在 python 通過 c = torch._ops.myops.myadd(a, b)
調(diào)用該 op秦效。
schema 與 TensorFlow 的 op_def
和 ONNX 的 node
一樣雏蛮,都用于描述一個操作,只是由于 PyTorch 是動態(tài)圖的阱州,schema 不需要也不能承載更多信息挑秉。
注冊 dispatch function
TORCH_LIBRARY_IMPL(myops, CUDA, m) {
m.impl("myadd", myadd_cuda);
}
注冊完 op 后,接著就可以通過 TORCH_LIBRARY_IMPL
注冊 dispatch key 對應的方法苔货。上述代碼片段通過將 myadd_cuda
注冊到鍵:CUDA犀概。
除了為每個鍵單獨注冊一個方法,還可以為所有的鍵注冊一個共同的方法夜惭,這類方法稱為 catch-all
:
TORCH_LIBRARY(myops, m) {
m.def("myadd", myadd_catchall);
}
此外姻灶,還可以為所有 op 的某個鍵注冊一個共同的 fallback
方法:
TORCH_LIBRARY_IMPL(_, XLA, m) {
m.fallback(xla_fallback);
}
除了 dispatch key 具有優(yōu)先級外,這些方法也有優(yōu)先級:impl > catch-all > fallback:
END
PyTorch的調(diào)度器(dispatcher)和分派機制是其靈活性和高性能計算的關(guān)鍵诈茧。調(diào)度器根據(jù)輸入類型自動選擇適當?shù)牟僮鲗崿F(xiàn)产喉,通過分派流程將操作分派給本地方法執(zhí)行。分派規(guī)則通過 dispatch key 和 keyset 確定執(zhí)行方法的優(yōu)先級敢会。注冊自定義操作的過程允許用戶擴展PyTorch的功能曾沈。了解這些原理有助于深入理解PyTorch的內(nèi)部工作機制,并為模型開發(fā)和優(yōu)化提供指導走触。