Tensor
自從張量(Tensor)計算這個概念出現(xiàn)后棒假,神經(jīng)網(wǎng)絡(luò)的算法就可以看作是一系列的張量計算蜡饵。所謂的張量条辟,它原本是個數(shù)學(xué)概念黔夭,表示各種向量或者數(shù)值之間的關(guān)系。PyTorch的張量(torch.Tensor)表示的是N維矩陣與一維數(shù)組的關(guān)系羽嫡。
torch.Tensor的使用方法和numpy很相似(https://pytorch.org/...tensor-tutorial-py)本姥,兩者唯一的區(qū)別在于torch.Tensor可以使用GPU來計算,這就比用CPU的numpy要快很多杭棵。
張量計算的種類有很多婚惫,比如加法、乘法、矩陣相乘先舷、矩陣轉(zhuǎn)置等艰管,這些計算被稱為算子(Operator),它們是PyTorch的核心組件蒋川。
算子的backend一般是C/C++的拓展程序牲芋,PyTorch的backend是稱為"ATen"的C/C++庫,ATen是"A Tensor"的縮寫捺球。
Operator
PyTorch所有的Operator都定義在Declarations.cwrap和native_functions.yaml這兩個文件中缸浦,前者定義了從Torch那繼承來的legacy operator(aten/src/TH),后者定義的是native operator懒构,是PyTorch的operator餐济。
相比于用C++開發(fā)的native code,legacy code是在PyTorch編譯時由gen.py根據(jù)Declarations.cwrap的內(nèi)容動態(tài)生成的胆剧。因此絮姆,如果你想要trace這些code,需要先編譯PyTorch秩霍。
legacy code的開發(fā)要比native code復(fù)雜得多篙悯。如果可以的話,建議你盡量避開它們铃绒。
MatMul
本文會以矩陣相乘--torch.matmul()為例來分析PyTorch算子的工作流程鸽照。
我在深入淺出全連接層(fully connected layer)中有講在GPU層面是如何進行矩陣相乘的。Nvidia颠悬、AMD等公司提供了優(yōu)化好的線性代數(shù)計算庫--cuBLAS/rocBLAS/openBLAS矮燎,PyTorch只需要調(diào)用它們的API即可。
Figure 1是torch.matmul()在ATen中的function flow赔癌〉猓可以看到,這個flow可不短灾票,這主要是因為不同類型的tensor(2d or Nd, batched gemm or not峡谊,with or without bias,cuda or cpu)的操作也不盡相同刊苍。
at::matmul()主要負(fù)責(zé)將Tensor轉(zhuǎn)換成cuBLAS需要的格式既们。前面說過,Tensor可以是N維矩陣正什,如果tensor A是3d矩陣啥纸,tensor B是2d矩陣,就需要先將3d轉(zhuǎn)成2d婴氮;如果它們都是>=3d的矩陣斯棒,就要考慮batched matmul的情況馒索;如果bias=True,后續(xù)就應(yīng)該交給at::addmm()來處理名船;總之绰上,matmul要考慮的事情比想象中要多。
除此之外渠驼,不同的dtype蜈块、device和layout需要調(diào)用不同的操作函數(shù),這部分工作交由c10::dispatcher來完成迷扇。
Dispatcher
dispatcher主要用于動態(tài)調(diào)用dtype百揭、device以及l(fā)ayout等方法函數(shù)。用過numpy的都知道蜓席,np.array()的數(shù)據(jù)類型有:float32, float16器一,int8,int32厨内,.... 如果你了解C++就會知道祈秕,這類程序最適合用模板(template)來實現(xiàn)。
很遺憾雏胃,由于ATen有一部分operator是用C語言寫的(從Torch繼承過來)请毛,不支持模板功能,因此瞭亮,就需要dispatcher這樣的動態(tài)調(diào)度器方仿。
類似地,PyTorch的tensor不僅可以運行在GPU上统翩,還可以跑在CPU仙蚜、mkldnn和xla等設(shè)備,F(xiàn)igure 1中的dispatcher4就根據(jù)tensor的device調(diào)用了mm的GPU實現(xiàn)厂汗。
layout是指tensor中元素的排布委粉。一般來說,矩陣的排布都是緊湊型的面徽,也就是strided layout艳丛。而那些有著大量0的稀疏矩陣匣掸,相應(yīng)地就是sparse layout趟紊。
Figure 2是strided layout的演示實例,這里創(chuàng)建了一個2行2列的矩陣a碰酝,它的數(shù)據(jù)實際存放在一維數(shù)組(a.storage)里霎匈,2行2列只是這個數(shù)組的視圖。
stride充當(dāng)了從數(shù)組到視圖的橋梁送爸,比如铛嘱,要打印第2行第2列的元素時暖释,可以通過公式:來計算該元素在數(shù)組中的索引。
除了dtype墨吓、device球匕、layout之外,dispatcher還可以用來調(diào)用legacy operator帖烘。比如說addmm這個operator亮曹,它的GPU實現(xiàn)就是通過dispatcher來跳轉(zhuǎn)到legacy::cuda::_th_addmm。
END
到此秘症,就完成了對PyTorch算子的學(xué)習(xí)照卦。如果你要學(xué)習(xí)其他算子,可以先從aten/src/ATen/native目錄的相關(guān)函數(shù)入手乡摹,從native_functions.yaml中找到dispatch目標(biāo)函數(shù)役耕,詳情可以參考Figure 1。
歡迎關(guān)注和點贊聪廉,你的鼓勵將是我創(chuàng)作的動力
歡迎轉(zhuǎn)發(fā)至朋友圈瞬痘,公眾號轉(zhuǎn)載請后臺留言申請授權(quán)~