XLA(Accelerated Linear Algebra)是專用于機(jī)器學(xué)習(xí)的編譯器跑慕,機(jī)器學(xué)習(xí)的運(yùn)算中99%都是向量乘以矩陣潮秘、矩陣乘以矩陣的計(jì)算殴俱,XLA是專門用來優(yōu)化這些計(jì)算的。
How to
舉個(gè)例子抓狭,運(yùn)行在GPU上的model_fn
函數(shù)會順序調(diào)用multiply
伯病、add
和reduce_sum
這三個(gè)op,而且multiply
否过,也就是y * z
的計(jì)算結(jié)果會先從GPU拷貝回host午笛,再拷貝到device作為add
的input,同樣的苗桂,add的計(jì)算結(jié)果也會以相同的方式傳遞給下一個(gè)op药磺。
def model_fn(x, y, z):
return tf.reduce_sum(x + y * z)
顯然,對于整個(gè)函數(shù)來說煤伟,將中間變量在host和device間來回倒騰是沒有意義的癌佩。因此,如果把函數(shù)看作一個(gè)op持偏,那在計(jì)算中產(chǎn)生的中間結(jié)果就不必返回到host驼卖,少了數(shù)據(jù)傳輸?shù)臅r(shí)間開銷,就可以大幅提升運(yùn)算效率鸿秆。
這種將多個(gè)op融合成一個(gè)op的方法就稱為fuse
酌畜,當(dāng)前fuse的技術(shù)路線有:
- 通過手寫或codegen工具來開發(fā)fused op,例如在上述例子中就可以開發(fā)
tf.fused_reduce_sum(x, y, z)
卿叽。它的優(yōu)點(diǎn)是代碼可控性高桥胞,易于性能優(yōu)化,但缺點(diǎn)是程序缺乏靈活性考婴。像Pytorch這種動態(tài)圖的框架走的就是這條路線贩虾,Nvidia的Apex提供有大量fused kernel,對fused kernel感興趣的沥阱,可以讀讀LayerNorm核心技術(shù)缎罢。 - 通過XLA等AI編譯器將python函數(shù)編譯成fused op。這樣做的好處是靈活性強(qiáng),可以fuse任何計(jì)算策精,弊端則是開發(fā)難度大舰始,且性能通常會遜色于手寫或codegen kernel。
性能
XLA的優(yōu)化當(dāng)然不只是fuse咽袜,還有對計(jì)算圖的優(yōu)化丸卷,包括刪除無效指令、減少內(nèi)存占用询刹、替換復(fù)雜指令等優(yōu)化谜嫉。下圖是官方提供的性能報(bào)告,經(jīng)XLA優(yōu)化過后凹联,Tensorflow BERT MLPerf的訓(xùn)練性能提升了~7倍沐兰。除了Tensorflow外,XLA還支持JAX匕垫、Julia僧鲁、PyTorch和Nx等前端虐呻。
Just in time(JIT)
jit
是指在首次運(yùn)行時(shí)將函數(shù)編譯成二進(jìn)制程序象泵,后續(xù)再調(diào)用該函數(shù)時(shí)直接運(yùn)行先前編譯好的程序而非python code。@tf.funciton
修飾的函數(shù)(包括它的子函數(shù))會做jit
斟叼。除非signature發(fā)生了變化偶惠,也就是input的shape或dtype和編譯時(shí)不同,否則get_MSE
是不需要重復(fù)編譯的朗涩。
@tf.function
def get_MSE(y_true, y_pred):
print("compiling ...")
sq_diff = tf.pow(y_true - y_pred, 2)
return tf.reduce_mean(sq_diff)
get_MSE(tf.constant(1.0), tf.constant(2.0)) # compile
get_MSE(tf.constant(3.0), tf.constant(4.0)) # It won't recompile
get_MSE(tf.ones([2, 2]), tf.ones([2, 2]) # compile again for new signature
@tf.function
將函數(shù)內(nèi)的ops替換成一組(XlaCompile
, XlaRun
) ops忽孽,在運(yùn)行時(shí)前者負(fù)責(zé)編譯,并將編譯結(jié)果--executable
保存到cache谢床,后者負(fù)責(zé)運(yùn)行executable兄一。如果cache里已經(jīng)有編譯好的程序就不需要編譯了,例如get_MSE(tf.constant(3.0), tf.constant(4.0))
识腿。
HLO
XLA編譯器支持的語言(IR)是HLO(High Level Operations)出革,顧名思義這些語言是由一個(gè)個(gè)op組成,因此渡讼,我們在編譯前需要先從python code中提取出所有ops骂束,再將它們轉(zhuǎn)換成HLO。
JAX通過tracing的方式成箫,從@jax.jit
修飾的函數(shù)中提取ops展箱,這些ops通過jaxpr
來表示。然后再通過XLA client提供的API為ops生成相應(yīng)的HLO蹬昌。PyTorch/XLA也是采用類似的方法來生成HLO混驰。
Tensorflow的tf2xla
為每個(gè)Op
創(chuàng)建了一個(gè)同名的XlaOp
用于生成HLO,XlaOp
派生于Op
,使用相同的注冊機(jī)制栖榨,因此竞慢,只要把要編譯的子圖根據(jù)拓?fù)渑判蜻\(yùn)行一遍就能生成它的HLO。
編譯
HLO先經(jīng)過一系列pass
優(yōu)化后再將HLO lowering成ISA治泥,最后將編譯好的二進(jìn)制封裝到executable
筹煮。
Executable
除了二進(jìn)制程序,它還包含運(yùn)行該程序所需要的infos和options居夹。調(diào)用executable.run()
就可以執(zhí)行計(jì)算圖败潦。