Notes on PyTorch Internals III

原版英文鏈接:Edward Z. Yang's PyTorch internals : Inside 245-5D

Mechanics

Code Flow

PyTorch源碼倉(cāng)庫(kù)中包含眾多的文件件蝌诡,詳細(xì)的介紹可以參考CONTRIBUTING桩砰。其中最重要的四個(gè)文件夾及源碼模塊為:

torch/: PyTorch模塊庫(kù)逗抑。包含PyTorch開(kāi)發(fā)最常用的功能嘀略,通過(guò)import導(dǎo)入的預(yù)定義模塊。PyTorch的Python前端(frontend)。

torch/csrc/: PyTorch前端模塊的C++代碼,實(shí)現(xiàn)C++代碼與Python的綁定茎芋。此外包含自動(dòng)求導(dǎo)引擎(autograd/)、JIT編譯器(jit/)咆贬,以及PyTorch的C++前端(api/)败徊。

aten: "A Tensor Library"的縮寫,張量相關(guān)操作的實(shí)現(xiàn)掏缎,不包含自動(dòng)求導(dǎo)功能皱蹦。src/目錄下包含兩種實(shí)現(xiàn):已經(jīng)過(guò)時(shí)的c版本實(shí)現(xiàn)(TH/, THC/, THNN/, THCUNN/),和基于C++實(shí)現(xiàn)(ATen/)眷蜈。不同device上張量操作實(shí)現(xiàn)的方式分別位于不同文件夾(ATen/cpu, ATen/cuda, ATen/sparse, ...)

c10: Caffe2和ATen的混合縮寫(caffe ten)沪哺,PyTorch的核心抽象和基礎(chǔ)功能實(shí)現(xiàn),包括張量的具體存儲(chǔ)和實(shí)現(xiàn)方式酌儒,支持部署到服務(wù)器和移動(dòng)端設(shè)別辜妓。PyTorch正在將ATen/core中的基礎(chǔ)核心實(shí)現(xiàn)移植到c10/core。PyTorch的C++后端(backend)

核心模塊支持上層的邏輯實(shí)現(xiàn)忌怎。以PyTorch的相加函數(shù)torch.add為例籍滴,模塊間調(diào)用流程為:

  1. 將Python函數(shù)轉(zhuǎn)換為C函數(shù)調(diào)用,解析Python參數(shù)為C++參數(shù)榴啸,由torch/csrc/中函數(shù)實(shí)現(xiàn)孽惰。例如,以PyTorch的add函數(shù)為例(下列代碼自動(dòng)生成):
// actual binding
static PyMethodDef torch_functions[] ={
...
{"add", (PyCFunction)THPVariable_add, METH_VARARGS | METH_VARKEYWORDS | METH_STATIC, NULL}
...
}

// auto-generated codes, needed to build PyTorch to generate it
static PyObject* torch._C.VariableFunctions.add.THPVariable_add(
          PyObject* self_, PyObject* args, PyObject* kwargs){
static PythonArgParser parser(...);

ParsedArgs<4> parsed_args;
auto r = parser.parse(args, kwargs, parsed_args);
...

if(r.isNone(3)){
    return wrap(dispatch_add(r.tensor(0), r.tensor(1), r.scalar(2)));
}else{
    return wrap(dispatch_add(r.tensor(0), r.tensor(1), r.scalar(2), r.tensor(3)));
}
...
}

torch_functions定義了Python函數(shù)和C版本函數(shù)名稱的對(duì)應(yīng)關(guān)系鸥印,通過(guò)該映射表查詢對(duì)應(yīng)的C版本函數(shù)勋功。PythonArgParser類實(shí)現(xiàn)了對(duì)PyTorch參數(shù)的C版本解析,然后通過(guò)dispatch_add調(diào)度底層C版本的add實(shí)現(xiàn)库说。計(jì)算結(jié)果通過(guò)wrap重新包裝為PyObject對(duì)象狂鞋,返回給Python層。

2.變量類型的調(diào)度
上一步中的dispathc_add函數(shù)調(diào)度實(shí)際上調(diào)用self.add(tensor, scalar)函數(shù)潜的,即張量自身的實(shí)現(xiàn)版本骚揍,而該版本通過(guò)如下函數(shù)實(shí)現(xiàn),調(diào)用不同變量類型的實(shí)現(xiàn)版本啰挪。

// inline functions defined on the 'type'
inline Tensor Tensor::add(const Tensor& other, Scalar alpha) const {
    return type().add(*this, other, alpha);
}

函數(shù)type()確定變量的具體類型疏咐,并通過(guò)虛函數(shù)add調(diào)度實(shí)際類型的實(shí)現(xiàn)纤掸。這些處理在aten/src/ATen模塊中完成。

3.設(shè)別類型和布局的調(diào)度
type()實(shí)際同時(shí)完成了變量和設(shè)備類型的調(diào)度浑塞,返回類似TypeDefaultGPUFloatType等包括數(shù)據(jù)類型和設(shè)備類型的描述政己。針對(duì)每種類型酌壕,PyTorch在build后會(huì)生成具體的類似如下的實(shí)現(xiàn)代碼:

Tensor TypeDefault:add(const Tensor& self, const Tensor& other, Scalar alpha) const {
    const OptionalDeviceGuard device_guard(device_of(self))   # device type checking
    return at::native::add(self, other, alpha)    # modern c++ impl.
}

由于add函數(shù)對(duì)于不同類型變量及設(shè)備類型的底層實(shí)現(xiàn)相同,通可以過(guò)TypeDefault統(tǒng)一封裝歇由。如果某種計(jì)算操作有不同實(shí)現(xiàn)卵牍,則需要擴(kuò)展實(shí)現(xiàn)并調(diào)用對(duì)應(yīng)版本,類似于GPUFloatType::add(...)沦泌。

4.核心代碼的調(diào)用
第三步中的代碼封裝了更為底層的at::native::add(self, other, alpha)實(shí)現(xiàn)糊昙。這些實(shí)現(xiàn)依賴aten/src/ATen中的模塊,通過(guò)C++版本(native/)或者過(guò)時(shí)的c版本實(shí)現(xiàn)(TH/, THC/, THNN/, THCUNN/)谢谦。

Kernels

PyTorch提供了一些工具和規(guī)范用于開(kāi)發(fā)核心計(jì)算操作符释牺,由aten/src/ATen模塊支持。一段完整的自定義核心操作示例代碼所示:

Tensor my_op(Tensor& result, const Tensor& self, const Tensor& other){
    // error checking
    TORCH_CHECK(result.is_cpu() && self.is_cpu() && other.is_cpu());
    TORCH_CHECK(self.dim() == 1);
    TORCH_CHECK(self.sizes() == other.sizes());

    // output allocation
    result.resize_(self.sizes());
    
    // data type (dtype) dispatch
    AT_DISPATCH_FORALL_TYPES(
        self.scalar_type(), "my_op", [&]{
            my_op_cpu<scalar_t>(result, self, other), 
        }
    );
}

template<typename scalar_t>
void my_op_cpu(Tensor& result, const Tensor& self, const Tensor& other){
    // data access
    auto result_accessor = result.accessor<scalar_t, 1>();
    auto self_accessor = self.accessor<scalar_t, 1>();
    auto other_accessor = other.accessor<scalar_t, 1>();
  
    // parallelization
    parallel_for(0, self.size(0), 0, [&](int64_t start, int64_t end){
        ... self_accessor[i] ...
    });
}

包含如下幾個(gè)部分:

元數(shù)據(jù)注冊(cè):由PyTorch提供的元數(shù)據(jù)要求回挽,用于自動(dòng)化生成Python的綁定代碼(如上節(jié)介紹的Python與C代碼之間的轉(zhuǎn)換和參數(shù)解析)没咙。每個(gè)定義的核心操作都需要提供如下的元數(shù)據(jù)模式:

    - func: func_name(ArgType arg0, ArgType arg1, ...) -> Return
    variants: function, method
    dispatch:
        CPU: func_cpu
        CUDA: func_cuda 

其中:
func_name:所定義的核心計(jì)算操作函數(shù)的名稱。
ArgType:參數(shù)類型千劈,可以是Tensor, Tensor[], int, int[], float, Scalar等祭刚。
variants: 包含functionmethod兩個(gè)類型,用于控制PyTorch自動(dòng)生成Python版本函數(shù)的名稱是張量方法(t.foo())還是命名空間的函數(shù)(at::foo())墙牌。當(dāng)使用method變體時(shí)涡驮,需要包含self參數(shù)。在自動(dòng)生成Python版本函數(shù)名稱時(shí)喜滨,該self參數(shù)會(huì)從參數(shù)列表中去掉捉捅。例如對(duì)于where(BoolTensor cond, Tensor self, Tensor other)的函數(shù)聲明。設(shè)置為method會(huì)自動(dòng)生成self.where(cond, other)的函數(shù)名稱鸿市;設(shè)置為function會(huì)自動(dòng)生成at::where(cond, self, other)的函數(shù)名稱锯梁。缺省情況下,ATen對(duì)native函數(shù)只生成function方式名稱焰情,對(duì)張量相關(guān)的核心操作符(e.g, add, sub等)可以使用method方式名稱陌凳。
dispatch: 指定針對(duì)不同設(shè)別類型,該函數(shù)可以調(diào)度的實(shí)際函數(shù)名稱内舟『隙兀可以針對(duì)不同設(shè)別類型,指定生成不同的版本的函數(shù)名稱验游。

更詳細(xì)的規(guī)范要求可參考aten/src/ATen/native/README.md充岛。任何自定義的核心計(jì)算函數(shù)的元數(shù)據(jù)需要按照如上要求編寫保檐,并添加到native_functions.yaml
文件中進(jìn)行注冊(cè)。PyTorch會(huì)對(duì)注冊(cè)的函數(shù)按照元數(shù)據(jù)描述的要求崔梗,自動(dòng)生成Python的綁定夜只。

上述自定義的核心操作函數(shù),一種可能的元數(shù)據(jù)描述如下:

-func: my_op(Tensor& result, const Tensor& self, const Tensor& other) -> Tensor
variants: function, method
dispath:
    CPU: my_op_cpu
    CUDA: my_op_cuda

對(duì)于需要支持反向梯度計(jì)算的核心計(jì)算操作蒜魄,需要按照類似的方式提供求導(dǎo)操作函數(shù)的元數(shù)據(jù)扔亥。具體可參考derivatives.yaml

錯(cuò)誤檢測(cè)(Error Checking)
錯(cuò)誤檢查在編寫核心代碼時(shí)非常重要。PyTorch提供了兩種錯(cuò)誤檢查的工具方便開(kāi)發(fā)者:low level的方式是提供了TORCH_CHECK宏谈为;High level方式通過(guò)將Tensor封裝為TensorArg旅挤,并提供checkDim等檢測(cè)函數(shù)。

輸出存儲(chǔ)分配(Output Allocation)
在輸出結(jié)果前伞鲫,需要預(yù)先分配內(nèi)存用于存儲(chǔ)粘茄。PyTorch支持預(yù)分配輸出、原位輸出秕脓、拷貝輸出等方式輸出結(jié)果柒瓣。實(shí)現(xiàn)過(guò)程中,原位輸出和拷貝輸出只是預(yù)分配輸出的簡(jiǎn)單封裝撒会。例如

// pre-allocate storage for 'result' outside
Tensor& abs_out(Tensor& result, const Tensor& self){
    result.resize_(self.sizes());
    // ... the real impl.
}

// a new allocated operation
Tensor& abs(const Tensor& self){
    Tensor result = at::empty({0}, self.options());
    abs_out(result, self);
    return result
}

// in-place operation
Tensor& abs_(const Tensor& self){
    return abs_out(self, self);
}

數(shù)據(jù)類型調(diào)度(Dtype Dispatch)
通過(guò)AT_DISPATCH_ALL_TYPES宏定義數(shù)據(jù)類型調(diào)度嘹朗。該宏其實(shí)是一個(gè)模版函數(shù),通過(guò)變量當(dāng)前類型進(jìn)行特化诵肛,調(diào)度實(shí)際匹配的類型實(shí)現(xiàn)屹培。

數(shù)據(jù)訪問(wèn)(Data Access)
PyTorch支持三種不同的、針對(duì)張量的訪問(wèn)方式怔檩。封裝的訪問(wèn)方式比訪問(wèn)原始數(shù)據(jù)指針?lè)奖阃市悖梢宰詣?dòng)處理底層的stride或者布局。TensorAccesor支持訪問(wèn)張量某個(gè)特定位置的數(shù)據(jù)薛训;TensorIterator支持規(guī)則方式輪詢?cè)L問(wèn)媒吗;針對(duì)CPU的序列化,提供了Vec256等序列化描述訪問(wèn)乙埃。

Notes on PyTorch Internals系列文章

Notes on PyTorch Internals I
Notes on PyTorch Internals II
Notes on PyTorch Internals III

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末闸英,一起剝皮案震驚了整個(gè)濱河市,隨后出現(xiàn)的幾起案子介袜,更是在濱河造成了極大的恐慌甫何,老刑警劉巖卑雁,帶你破解...
    沈念sama閱讀 212,029評(píng)論 6 492
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件付秕,死亡現(xiàn)場(chǎng)離奇詭異,居然都是意外死亡哟沫,警方通過(guò)查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 90,395評(píng)論 3 385
  • 文/潘曉璐 我一進(jìn)店門巍耗,熙熙樓的掌柜王于貴愁眉苦臉地迎上來(lái)秋麸,“玉大人,你說(shuō)我怎么就攤上這事炬太【捏。” “怎么了?”我有些...
    開(kāi)封第一講書人閱讀 157,570評(píng)論 0 348
  • 文/不壞的土叔 我叫張陵亲族,是天一觀的道長(zhǎng)次乓。 經(jīng)常有香客問(wèn)我,道長(zhǎng)孽水,這世上最難降的妖魔是什么? 我笑而不...
    開(kāi)封第一講書人閱讀 56,535評(píng)論 1 284
  • 正文 為了忘掉前任城看,我火速辦了婚禮女气,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘测柠。我一直安慰自己炼鞠,他們只是感情好,可當(dāng)我...
    茶點(diǎn)故事閱讀 65,650評(píng)論 6 386
  • 文/花漫 我一把揭開(kāi)白布轰胁。 她就那樣靜靜地躺著谒主,像睡著了一般。 火紅的嫁衣襯著肌膚如雪赃阀。 梳的紋絲不亂的頭發(fā)上霎肯,一...
    開(kāi)封第一講書人閱讀 49,850評(píng)論 1 290
  • 那天,我揣著相機(jī)與錄音榛斯,去河邊找鬼观游。 笑死,一個(gè)胖子當(dāng)著我的面吹牛驮俗,可吹牛的內(nèi)容都是我干的懂缕。 我是一名探鬼主播,決...
    沈念sama閱讀 39,006評(píng)論 3 408
  • 文/蒼蘭香墨 我猛地睜開(kāi)眼王凑,長(zhǎng)吁一口氣:“原來(lái)是場(chǎng)噩夢(mèng)啊……” “哼搪柑!你這毒婦竟也來(lái)了?” 一聲冷哼從身側(cè)響起索烹,我...
    開(kāi)封第一講書人閱讀 37,747評(píng)論 0 268
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤工碾,失蹤者是張志新(化名)和其女友劉穎,沒(méi)想到半個(gè)月后术荤,有當(dāng)?shù)厝嗽跇?shù)林里發(fā)現(xiàn)了一具尸體倚喂,經(jīng)...
    沈念sama閱讀 44,207評(píng)論 1 303
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 36,536評(píng)論 2 327
  • 正文 我和宋清朗相戀三年,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了端圈。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片焦读。...
    茶點(diǎn)故事閱讀 38,683評(píng)論 1 341
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡,死狀恐怖舱权,靈堂內(nèi)的尸體忽然破棺而出矗晃,到底是詐尸還是另有隱情,我是刑警寧澤宴倍,帶...
    沈念sama閱讀 34,342評(píng)論 4 330
  • 正文 年R本政府宣布张症,位于F島的核電站,受9級(jí)特大地震影響鸵贬,放射性物質(zhì)發(fā)生泄漏俗他。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 39,964評(píng)論 3 315
  • 文/蒙蒙 一阔逼、第九天 我趴在偏房一處隱蔽的房頂上張望兆衅。 院中可真熱鬧,春花似錦嗜浮、人聲如沸羡亩。這莊子的主人今日做“春日...
    開(kāi)封第一講書人閱讀 30,772評(píng)論 0 21
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)畏铆。三九已至,卻和暖如春吉殃,著一層夾襖步出監(jiān)牢的瞬間辞居,已是汗流浹背。 一陣腳步聲響...
    開(kāi)封第一講書人閱讀 32,004評(píng)論 1 266
  • 我被黑心中介騙來(lái)泰國(guó)打工寨腔, 沒(méi)想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留速侈,地道東北人。 一個(gè)月前我還...
    沈念sama閱讀 46,401評(píng)論 2 360
  • 正文 我出身青樓迫卢,卻偏偏與公主長(zhǎng)得像倚搬,于是被迫代替她去往敵國(guó)和親。 傳聞我的和親對(duì)象是個(gè)殘疾皇子乾蛤,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 43,566評(píng)論 2 349

推薦閱讀更多精彩內(nèi)容