原版英文鏈接:Edward Z. Yang's PyTorch internals : Inside 245-5D
Mechanics
Code Flow
PyTorch源碼倉(cāng)庫(kù)中包含眾多的文件件蝌诡,詳細(xì)的介紹可以參考CONTRIBUTING桩砰。其中最重要的四個(gè)文件夾及源碼模塊為:
torch/: PyTorch模塊庫(kù)逗抑。包含PyTorch開(kāi)發(fā)最常用的功能嘀略,通過(guò)import導(dǎo)入的預(yù)定義模塊。PyTorch的Python前端(frontend)。
torch/csrc/: PyTorch前端模塊的C++代碼,實(shí)現(xiàn)C++代碼與Python的綁定茎芋。此外包含自動(dòng)求導(dǎo)引擎(autograd/)、JIT編譯器(jit/)咆贬,以及PyTorch的C++前端(api/)败徊。
aten: "A Tensor Library"的縮寫,張量相關(guān)操作的實(shí)現(xiàn)掏缎,不包含自動(dòng)求導(dǎo)功能皱蹦。src/目錄下包含兩種實(shí)現(xiàn):已經(jīng)過(guò)時(shí)的c版本實(shí)現(xiàn)(TH/, THC/, THNN/, THCUNN/),和基于C++實(shí)現(xiàn)(ATen/)眷蜈。不同device上張量操作實(shí)現(xiàn)的方式分別位于不同文件夾(ATen/cpu, ATen/cuda, ATen/sparse, ...)
c10: Caffe2和ATen的混合縮寫(caffe ten)沪哺,PyTorch的核心抽象和基礎(chǔ)功能實(shí)現(xiàn),包括張量的具體存儲(chǔ)和實(shí)現(xiàn)方式酌儒,支持部署到服務(wù)器和移動(dòng)端設(shè)別辜妓。PyTorch正在將ATen/core中的基礎(chǔ)核心實(shí)現(xiàn)移植到c10/core。PyTorch的C++后端(backend)
核心模塊支持上層的邏輯實(shí)現(xiàn)忌怎。以PyTorch的相加函數(shù)torch.add
為例籍滴,模塊間調(diào)用流程為:
- 將Python函數(shù)轉(zhuǎn)換為C函數(shù)調(diào)用,解析Python參數(shù)為C++參數(shù)榴啸,由torch/csrc/中函數(shù)實(shí)現(xiàn)孽惰。例如,以PyTorch的
add
函數(shù)為例(下列代碼自動(dòng)生成):
// actual binding
static PyMethodDef torch_functions[] ={
...
{"add", (PyCFunction)THPVariable_add, METH_VARARGS | METH_VARKEYWORDS | METH_STATIC, NULL}
...
}
// auto-generated codes, needed to build PyTorch to generate it
static PyObject* torch._C.VariableFunctions.add.THPVariable_add(
PyObject* self_, PyObject* args, PyObject* kwargs){
static PythonArgParser parser(...);
ParsedArgs<4> parsed_args;
auto r = parser.parse(args, kwargs, parsed_args);
...
if(r.isNone(3)){
return wrap(dispatch_add(r.tensor(0), r.tensor(1), r.scalar(2)));
}else{
return wrap(dispatch_add(r.tensor(0), r.tensor(1), r.scalar(2), r.tensor(3)));
}
...
}
torch_functions
定義了Python函數(shù)和C版本函數(shù)名稱的對(duì)應(yīng)關(guān)系鸥印,通過(guò)該映射表查詢對(duì)應(yīng)的C版本函數(shù)勋功。PythonArgParser
類實(shí)現(xiàn)了對(duì)PyTorch參數(shù)的C版本解析,然后通過(guò)dispatch_add
調(diào)度底層C版本的add
實(shí)現(xiàn)库说。計(jì)算結(jié)果通過(guò)wrap
重新包裝為PyObject
對(duì)象狂鞋,返回給Python層。
2.變量類型的調(diào)度
上一步中的dispathc_add
函數(shù)調(diào)度實(shí)際上調(diào)用self.add(tensor, scalar)
函數(shù)潜的,即張量自身的實(shí)現(xiàn)版本骚揍,而該版本通過(guò)如下函數(shù)實(shí)現(xiàn),調(diào)用不同變量類型的實(shí)現(xiàn)版本啰挪。
// inline functions defined on the 'type'
inline Tensor Tensor::add(const Tensor& other, Scalar alpha) const {
return type().add(*this, other, alpha);
}
函數(shù)type()
確定變量的具體類型疏咐,并通過(guò)虛函數(shù)add
調(diào)度實(shí)際類型的實(shí)現(xiàn)纤掸。這些處理在aten/src/ATen模塊中完成。
3.設(shè)別類型和布局的調(diào)度
type()
實(shí)際同時(shí)完成了變量和設(shè)備類型的調(diào)度浑塞,返回類似TypeDefault
、GPUFloatType
等包括數(shù)據(jù)類型和設(shè)備類型的描述政己。針對(duì)每種類型酌壕,PyTorch在build后會(huì)生成具體的類似如下的實(shí)現(xiàn)代碼:
Tensor TypeDefault:add(const Tensor& self, const Tensor& other, Scalar alpha) const {
const OptionalDeviceGuard device_guard(device_of(self)) # device type checking
return at::native::add(self, other, alpha) # modern c++ impl.
}
由于add
函數(shù)對(duì)于不同類型變量及設(shè)備類型的底層實(shí)現(xiàn)相同,通可以過(guò)TypeDefault
統(tǒng)一封裝歇由。如果某種計(jì)算操作有不同實(shí)現(xiàn)卵牍,則需要擴(kuò)展實(shí)現(xiàn)并調(diào)用對(duì)應(yīng)版本,類似于GPUFloatType::add(...)
沦泌。
4.核心代碼的調(diào)用
第三步中的代碼封裝了更為底層的at::native::add(self, other, alpha)
實(shí)現(xiàn)糊昙。這些實(shí)現(xiàn)依賴aten/src/ATen中的模塊,通過(guò)C++版本(native/)或者過(guò)時(shí)的c版本實(shí)現(xiàn)(TH/, THC/, THNN/, THCUNN/)谢谦。
Kernels
PyTorch提供了一些工具和規(guī)范用于開(kāi)發(fā)核心計(jì)算操作符释牺,由aten/src/ATen模塊支持。一段完整的自定義核心操作示例代碼所示:
Tensor my_op(Tensor& result, const Tensor& self, const Tensor& other){
// error checking
TORCH_CHECK(result.is_cpu() && self.is_cpu() && other.is_cpu());
TORCH_CHECK(self.dim() == 1);
TORCH_CHECK(self.sizes() == other.sizes());
// output allocation
result.resize_(self.sizes());
// data type (dtype) dispatch
AT_DISPATCH_FORALL_TYPES(
self.scalar_type(), "my_op", [&]{
my_op_cpu<scalar_t>(result, self, other),
}
);
}
template<typename scalar_t>
void my_op_cpu(Tensor& result, const Tensor& self, const Tensor& other){
// data access
auto result_accessor = result.accessor<scalar_t, 1>();
auto self_accessor = self.accessor<scalar_t, 1>();
auto other_accessor = other.accessor<scalar_t, 1>();
// parallelization
parallel_for(0, self.size(0), 0, [&](int64_t start, int64_t end){
... self_accessor[i] ...
});
}
包含如下幾個(gè)部分:
元數(shù)據(jù)注冊(cè):由PyTorch提供的元數(shù)據(jù)要求回挽,用于自動(dòng)化生成Python的綁定代碼(如上節(jié)介紹的Python與C代碼之間的轉(zhuǎn)換和參數(shù)解析)没咙。每個(gè)定義的核心操作都需要提供如下的元數(shù)據(jù)模式:
- func: func_name(ArgType arg0, ArgType arg1, ...) -> Return
variants: function, method
dispatch:
CPU: func_cpu
CUDA: func_cuda
其中:
func_name
:所定義的核心計(jì)算操作函數(shù)的名稱。
ArgType
:參數(shù)類型千劈,可以是Tensor, Tensor[], int, int[], float, Scalar
等祭刚。
variants
: 包含function
和method
兩個(gè)類型,用于控制PyTorch自動(dòng)生成Python版本函數(shù)的名稱是張量方法(t.foo()
)還是命名空間的函數(shù)(at::foo()
)墙牌。當(dāng)使用method
變體時(shí)涡驮,需要包含self
參數(shù)。在自動(dòng)生成Python版本函數(shù)名稱時(shí)喜滨,該self
參數(shù)會(huì)從參數(shù)列表中去掉捉捅。例如對(duì)于where(BoolTensor cond, Tensor self, Tensor other)
的函數(shù)聲明。設(shè)置為method
會(huì)自動(dòng)生成self.where(cond, other)
的函數(shù)名稱鸿市;設(shè)置為function
會(huì)自動(dòng)生成at::where(cond, self, other)
的函數(shù)名稱锯梁。缺省情況下,ATen對(duì)native函數(shù)只生成function
方式名稱焰情,對(duì)張量相關(guān)的核心操作符(e.g, add, sub
等)可以使用method
方式名稱陌凳。
dispatch
: 指定針對(duì)不同設(shè)別類型,該函數(shù)可以調(diào)度的實(shí)際函數(shù)名稱内舟『隙兀可以針對(duì)不同設(shè)別類型,指定生成不同的版本的函數(shù)名稱验游。
更詳細(xì)的規(guī)范要求可參考aten/src/ATen/native/README.md充岛。任何自定義的核心計(jì)算函數(shù)的元數(shù)據(jù)需要按照如上要求編寫保檐,并添加到native_functions.yaml
文件中進(jìn)行注冊(cè)。PyTorch會(huì)對(duì)注冊(cè)的函數(shù)按照元數(shù)據(jù)描述的要求崔梗,自動(dòng)生成Python的綁定夜只。
上述自定義的核心操作函數(shù),一種可能的元數(shù)據(jù)描述如下:
-func: my_op(Tensor& result, const Tensor& self, const Tensor& other) -> Tensor
variants: function, method
dispath:
CPU: my_op_cpu
CUDA: my_op_cuda
對(duì)于需要支持反向梯度計(jì)算的核心計(jì)算操作蒜魄,需要按照類似的方式提供求導(dǎo)操作函數(shù)的元數(shù)據(jù)扔亥。具體可參考derivatives.yaml
錯(cuò)誤檢測(cè)(Error Checking)
錯(cuò)誤檢查在編寫核心代碼時(shí)非常重要。PyTorch提供了兩種錯(cuò)誤檢查的工具方便開(kāi)發(fā)者:low level的方式是提供了TORCH_CHECK
宏谈为;High level方式通過(guò)將Tensor
封裝為TensorArg
旅挤,并提供checkDim
等檢測(cè)函數(shù)。
輸出存儲(chǔ)分配(Output Allocation)
在輸出結(jié)果前伞鲫,需要預(yù)先分配內(nèi)存用于存儲(chǔ)粘茄。PyTorch支持預(yù)分配輸出、原位輸出秕脓、拷貝輸出等方式輸出結(jié)果柒瓣。實(shí)現(xiàn)過(guò)程中,原位輸出和拷貝輸出只是預(yù)分配輸出的簡(jiǎn)單封裝撒会。例如
// pre-allocate storage for 'result' outside
Tensor& abs_out(Tensor& result, const Tensor& self){
result.resize_(self.sizes());
// ... the real impl.
}
// a new allocated operation
Tensor& abs(const Tensor& self){
Tensor result = at::empty({0}, self.options());
abs_out(result, self);
return result
}
// in-place operation
Tensor& abs_(const Tensor& self){
return abs_out(self, self);
}
數(shù)據(jù)類型調(diào)度(Dtype Dispatch)
通過(guò)AT_DISPATCH_ALL_TYPES
宏定義數(shù)據(jù)類型調(diào)度嘹朗。該宏其實(shí)是一個(gè)模版函數(shù),通過(guò)變量當(dāng)前類型進(jìn)行特化诵肛,調(diào)度實(shí)際匹配的類型實(shí)現(xiàn)屹培。
數(shù)據(jù)訪問(wèn)(Data Access)
PyTorch支持三種不同的、針對(duì)張量的訪問(wèn)方式怔檩。封裝的訪問(wèn)方式比訪問(wèn)原始數(shù)據(jù)指針?lè)奖阃市悖梢宰詣?dòng)處理底層的stride
或者布局。TensorAccesor
支持訪問(wèn)張量某個(gè)特定位置的數(shù)據(jù)薛训;TensorIterator
支持規(guī)則方式輪詢?cè)L問(wèn)媒吗;針對(duì)CPU的序列化,提供了Vec256
等序列化描述訪問(wèn)乙埃。
Notes on PyTorch Internals系列文章
Notes on PyTorch Internals I
Notes on PyTorch Internals II
Notes on PyTorch Internals III