JAX 是機(jī)器學(xué)習(xí) (ML) 領(lǐng)域的新生力量被芳,它有望使 ML 編程更加直觀缰贝、結(jié)構(gòu)化和簡(jiǎn)潔。
在機(jī)器學(xué)習(xí)領(lǐng)域畔濒,大家可能對(duì) TensorFlow 和 PyTorch 已經(jīng)耳熟能詳剩晴,但除了這兩個(gè)框架,一些新生力量也不容小覷侵状,它就是谷歌推出的 JAX赞弥。很對(duì)研究者對(duì)其寄予厚望,希望它可以取代 TensorFlow 等眾多機(jī)器學(xué)習(xí)框架趣兄。
JAX 最初由谷歌大腦團(tuán)隊(duì)的 Matt Johnson绽左、Roy Frostig、Dougal Maclaurin 和 Chris Leary 等人發(fā)起诽俯。
目前妇菱,JAX 在 GitHub 上已累積 13.7K 星。
項(xiàng)目地址:https://github.com/google/jax
迅速發(fā)展的 JAX
JAX 的前身是 Autograd暴区,其借助 Autograd 的更新版本闯团,并且結(jié)合了 XLA,可對(duì) Python 程序與 NumPy 運(yùn)算執(zhí)行自動(dòng)微分仙粱,支持循環(huán)房交、分支、遞歸伐割、閉包函數(shù)求導(dǎo)候味,也可以求三階導(dǎo)數(shù);依賴(lài)于 XLA隔心,JAX 可以在 GPU 和 TPU 上編譯和運(yùn)行 NumPy 程序白群;通過(guò) grad,可以支持自動(dòng)模式反向傳播和正向傳播硬霍,且二者可以任意組合成任何順序帜慢。
開(kāi)發(fā) JAX 的出發(fā)點(diǎn)是什么?說(shuō)到這唯卖,就不得不提 NumPy粱玲。NumPy 是 Python 中的一個(gè)基礎(chǔ)數(shù)值運(yùn)算庫(kù),被廣泛使用拜轨。但是 numpy 不支持 GPU 或其他硬件加速器抽减,也沒(méi)有對(duì)反向傳播的內(nèi)置支持,此外橄碾,Python 本身的速度限制阻礙了 NumPy 使用卵沉,所以少有研究者在生產(chǎn)環(huán)境下直接用 numpy 訓(xùn)練或部署深度學(xué)習(xí)模型颠锉。
在此情況下,出現(xiàn)了眾多的深度學(xué)習(xí)框架偎箫,如 PyTorch木柬、TensorFlow 等。但是 numpy 具有靈活淹办、調(diào)試方便眉枕、API 穩(wěn)定等獨(dú)特的優(yōu)勢(shì)。而 JAX 的主要出發(fā)點(diǎn)就是將 numpy 的以上優(yōu)勢(shì)與硬件加速結(jié)合怜森。
目前速挑,基于 JAX 已有很多優(yōu)秀的開(kāi)源項(xiàng)目,如谷歌的神經(jīng)網(wǎng)絡(luò)庫(kù)團(tuán)隊(duì)開(kāi)發(fā)了 Haiku副硅,這是一個(gè)面向 Jax 的深度學(xué)習(xí)代碼庫(kù)姥宝,通過(guò) Haiku,用戶(hù)可以在 Jax 上進(jìn)行面向?qū)ο箝_(kāi)發(fā)恐疲;又比如 RLax腊满,這是一個(gè)基于 Jax 的強(qiáng)化學(xué)習(xí)庫(kù),用戶(hù)使用 RLax 就能進(jìn)行 Q-learning 模型的搭建和訓(xùn)練培己;此外還包括基于 JAX 的深度學(xué)習(xí)庫(kù) JAXnet碳蛋,該庫(kù)一行代碼就能定義計(jì)算圖、可進(jìn)行 GPU 加速省咨∷嗟埽可以說(shuō),在過(guò)去幾年中零蓉,JAX 掀起了深度學(xué)習(xí)研究的風(fēng)暴笤受,推動(dòng)了科學(xué)研究迅速發(fā)展。
JAX 的安裝
如何使用 JAX 呢敌蜂?首先你需要在 Python 環(huán)境或 Google colab 中安裝 JAX箩兽,使用 pip 進(jìn)行安裝:
$ pip install --upgrade jax jaxlib
注意,上述安裝方式只是支持在 CPU 上運(yùn)行章喉,如果你想在 GPU 執(zhí)行程序比肄,首先你需要有 CUDA、cuDNN 囊陡,然后運(yùn)行以下命令(確保將 jaxlib 版本映射到 CUDA 版本):
$ pip install --upgrade jax jaxlib==0.1.61+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html
現(xiàn)在將 JAX 與 Numpy 一起導(dǎo)入:
import jax
import jax.numpy as jnp
import numpy as np
JAX 的一些特性
使用 grad() 函數(shù)自動(dòng)微分:這對(duì)深度學(xué)習(xí)應(yīng)用非常有用,這樣就可以很容易地運(yùn)行反向傳播掀亥,下面為一個(gè)簡(jiǎn)單的二次函數(shù)并在點(diǎn) 1.0 上求導(dǎo)的示例:
from jax import grad
def f(x):
return 3*x**2 + 2*x + 5
def f_prime(x):
return 6*x +2
grad(f)(1.0)
# DeviceArray(8., dtype=float32)
f_prime(1.0)
# 8.0
jit(Just in time) :為了利用 XLA 的強(qiáng)大功能撞反,必須將代碼編譯到 XLA 內(nèi)核中。這就是 jit 發(fā)揮作用的地方搪花。要使用 XLA 和 jit遏片,用戶(hù)可以使用 jit() 函數(shù)或 @jit 注釋嘹害。
from jax import jit
x = np.random.rand(1000,1000)
y = jnp.array(x)
def f(x):
for _ in range(10):
x = 0.5*x + 0.1* jnp.sin(x)
return x
g = jit(f)
%timeit -n 5 -r 5 f(y).block_until_ready()
# 5 loops, best of 5: 10.8 ms per loop
%timeit -n 5 -r 5 g(y).block_until_ready()
# 5 loops, best of 5: 341 μs per loop
pmap:自動(dòng)將計(jì)算分配到所有當(dāng)前設(shè)備,并處理它們之間的所有通信吮便。JAX 通過(guò) pmap 轉(zhuǎn)換支持大規(guī)模的數(shù)據(jù)并行笔呀,從而將單個(gè)處理器無(wú)法處理的大數(shù)據(jù)進(jìn)行處理。要檢查可用設(shè)備髓需,可以運(yùn)行 jax.devices():
from jax import pmap
def f(x):
return jnp.sin(x) + x**2
f(np.arange(4))
#DeviceArray([0. , 1.841471 , 4.9092975, 9.14112 ], dtype=float32)
pmap(f)(np.arange(4))
#ShardedDeviceArray([0. , 1.841471 , 4.9092975, 9.14112 ], dtype=float32)
vmap:是一種函數(shù)轉(zhuǎn)換许师,JAX 通過(guò) vmap 變換提供了自動(dòng)矢量化算法,大大簡(jiǎn)化了這種類(lèi)型的計(jì)算僚匆,這使得研究人員在處理新算法時(shí)無(wú)需再去處理批量化的問(wèn)題微渠。示例如下:
from jax import vmap
def f(x):
return jnp.square(x)
f(jnp.arange(10))
#DeviceArray([ 0, 1, 4, 9, 16, 25, 36, 49, 64, 81], dtype=int32)
vmap(f)(jnp.arange(10))
#DeviceArray([ 0, 1, 4, 9, 16, 25, 36, 49, 64, 81], dtype=int32)
TensorFlow vs PyTorch vs Jax
在深度學(xué)習(xí)領(lǐng)域有幾家巨頭公司,他們所提出的框架被廣大研究者使用咧擂。比如谷歌的 TensorFlow逞盆、Facebook 的 PyTorch、微軟的 CNTK松申、亞馬遜 AWS 的 MXnet 等云芦。
每種框架都有其優(yōu)缺點(diǎn),選擇的時(shí)候需要根據(jù)自身需求進(jìn)行選擇贸桶。
我們以 Python 中的 3 個(gè)主要深度學(xué)習(xí)框架——TensorFlow舅逸、PyTorch 和 Jax 為例進(jìn)行比較。這些框架雖然不同刨啸,但有兩個(gè)共同點(diǎn):
- 它們是開(kāi)源的堡赔。這意味著如果庫(kù)中存在錯(cuò)誤,使用者可以在 GitHub 中發(fā)布問(wèn)題(并修復(fù))设联,此外你也可以在庫(kù)中添加自己的功能善已;
- 由于全局解釋器鎖,Python 在內(nèi)部運(yùn)行緩慢离例。所以這些框架使用 C/C++ 作為后端來(lái)處理所有的計(jì)算和并行過(guò)程换团。
那么它們的不同體現(xiàn)在哪些方面呢?如下表所示宫蛆,為 TensorFlow艘包、PyTorch、JAX 三個(gè)框架的比較耀盗。
TensorFlow
TensorFlow 由谷歌開(kāi)發(fā)想虎,最初版本可追溯到 2015 年開(kāi)源的 TensorFlow0.1,之后發(fā)展穩(wěn)定叛拷,擁有強(qiáng)大的用戶(hù)群體舌厨,成為最受歡迎的深度學(xué)習(xí)框架。但是用戶(hù)在使用時(shí)忿薇,也暴露了 TensorFlow 缺點(diǎn)裙椭,例如 API 穩(wěn)定性不足躏哩、靜態(tài)計(jì)算圖編程復(fù)雜等缺陷。因此在 TensorFlow2.0 版本揉燃,谷歌將 Keras 納入進(jìn)來(lái)扫尺,成為 tf.keras。
目前 TensorFlow 主要特點(diǎn)包括以下:
- 這是一個(gè)非常友好的框架炊汤,高級(jí) API-Keras 的可用性使得模型層定義正驻、損失函數(shù)和模型創(chuàng)建變得非常容易直撤;
- TensorFlow2.0 帶有 Eager Execution(動(dòng)態(tài)圖機(jī)制)愈污,這使得該庫(kù)更加用戶(hù)友好,并且是對(duì)以前版本的重大升級(jí)服猪;
- Keras 這種高級(jí)接口有一定的缺點(diǎn)氓栈,由于 TensorFlow 抽象了許多底層機(jī)制(只是為了方便最終用戶(hù))渣磷,這讓研究人員在處理模型方面的自由度更小授瘦;
- Tensorflow 提供了 TensorBoard醋界,它實(shí)際上是 Tensorflow 可視化工具包。它允許研究者可視化損失函數(shù)提完、模型圖形纺、模型分析等。
PyTorch
PyTorch(Python-Torch) 是來(lái)自 Facebook 的機(jī)器學(xué)習(xí)庫(kù)徒欣。用 TensorFlow 還是 PyTorch逐样?在一年前,這個(gè)問(wèn)題毫無(wú)爭(zhēng)議打肝,研究者大部分會(huì)選擇 TensorFlow脂新。但現(xiàn)在的情況大不一樣了,使用 PyTorch 的研究者越來(lái)越多粗梭。PyTorch 的一些最重要的特性包括:
- 與 TensorFlow 不同争便,PyTorch 使用動(dòng)態(tài)類(lèi)型圖,這意味著執(zhí)行圖是在運(yùn)行中創(chuàng)建的断医。它允許我們隨時(shí)修改和檢查圖的內(nèi)部結(jié)構(gòu)滞乙;
- 除了用戶(hù)友好的高級(jí) API 之外,PyTorch 還包括精心構(gòu)建的低級(jí) API鉴嗤,允許對(duì)機(jī)器學(xué)習(xí)模型進(jìn)行越來(lái)越多的控制斩启。我們可以在訓(xùn)練期間對(duì)模型的前向和后向傳遞進(jìn)行檢查和修改輸出。這被證明對(duì)于梯度裁剪和神經(jīng)風(fēng)格遷移非常有效醉锅;
- PyTorch 允許用戶(hù)擴(kuò)展代碼兔簇,可以輕松添加新的損失函數(shù)和用戶(hù)定義的層。PyTorch 的 Autograd 模塊實(shí)現(xiàn)了深度學(xué)習(xí)算法中的反向傳播求導(dǎo)數(shù),在 Tensor 類(lèi)上的所有操作男韧, Autograd 都能自動(dòng)提供微分,簡(jiǎn)化了手動(dòng)計(jì)算導(dǎo)數(shù)的復(fù)雜過(guò)程默垄;
- PyTorch 對(duì)數(shù)據(jù)并行和 GPU 的使用具有廣泛的支持此虑;
- PyTorch 比 TensorFlow 更 Python 化。PyTorch 非常適合 Python 生態(tài)系統(tǒng)口锭,它允許使用 Python 類(lèi)調(diào)試器工具來(lái)調(diào)試 PyTorch 代碼朦前。
JAX
JAX 是來(lái)自 Google 的一個(gè)相對(duì)較新的機(jī)器學(xué)習(xí)庫(kù)。它更像是一個(gè) autograd 庫(kù)鹃操,可以區(qū)分原生的 python 和 NumPy 代碼韭寸。JAX 的一些特性主要包括:
- 正如官方網(wǎng)站所描述的那樣,JAX 能夠執(zhí)行 Python+NumPy 程序的可組合轉(zhuǎn)換:向量化荆隘、JIT 到 GPU/TPU 等等恩伺;
- 與 PyTorch 相比,JAX 最重要的方面是如何計(jì)算梯度椰拒。在 Torch 中晶渠,圖是在前向傳遞期間創(chuàng)建的,梯度在后向傳遞期間計(jì)算燃观, 另一方面褒脯,在 JAX 中,計(jì)算表示為函數(shù)缆毁。在函數(shù)上使用 grad() 返回一個(gè)梯度函數(shù)番川,該函數(shù)直接計(jì)算給定輸入的函數(shù)梯度;
- JAX 是一個(gè) autograd 工具脊框,不建議單獨(dú)使用颁督。有各種基于 JAX 的機(jī)器學(xué)習(xí)庫(kù),其中值得注意的是 ObJax缚陷、Flax 和 Elegy适篙。由于它們都使用相同的核心并且接口只是 JAX 庫(kù)的 wrapper,因此可以將它們放在同一個(gè) bracket 下箫爷;
- Flax 最初是在 PyTorch 生態(tài)系統(tǒng)下開(kāi)發(fā)的嚷节,更注重使用的靈活性。另一方面虎锚,Elegy 受 Keras 啟發(fā)硫痰。ObJAX 主要是為以研究為導(dǎo)向的目的而設(shè)計(jì)的,它更注重簡(jiǎn)單性和可理解性窜护。
參考鏈接:
- https://www.askpython.com/python-modules/tensorflow-vs-pytorch-vs-jax
- https://jax.readthedocs.io/en/latest/notebooks/quickstart.html
- https://jax.readthedocs.io/en/latest/notebooks/quickstart.html
- https://www.zhihu.com/question/306496943/answer/557876584
開(kāi)源前哨
日常分享熱門(mén)效斑、有趣和實(shí)用的開(kāi)源項(xiàng)目。參與維護(hù) 10萬(wàn)+ Star 的開(kāi)源技術(shù)資源庫(kù)柱徙,包括:Python缓屠、Java奇昙、C/C++、Go敌完、JS储耐、CSS、Node.js滨溉、PHP什湘、.NET 等。