1.3 萬(wàn) Star!迅猛發(fā)展的 JAX 對(duì)比 TensorFlow、PyTorch

JAX 是機(jī)器學(xué)習(xí) (ML) 領(lǐng)域的新生力量被芳,它有望使 ML 編程更加直觀缰贝、結(jié)構(gòu)化和簡(jiǎn)潔。

在機(jī)器學(xué)習(xí)領(lǐng)域畔濒,大家可能對(duì) TensorFlow 和 PyTorch 已經(jīng)耳熟能詳剩晴,但除了這兩個(gè)框架,一些新生力量也不容小覷侵状,它就是谷歌推出的 JAX赞弥。很對(duì)研究者對(duì)其寄予厚望,希望它可以取代 TensorFlow 等眾多機(jī)器學(xué)習(xí)框架趣兄。

JAX 最初由谷歌大腦團(tuán)隊(duì)的 Matt Johnson绽左、Roy Frostig、Dougal Maclaurin 和 Chris Leary 等人發(fā)起诽俯。

目前妇菱,JAX 在 GitHub 上已累積 13.7K 星。

項(xiàng)目地址:https://github.com/google/jax

迅速發(fā)展的 JAX

JAX 的前身是 Autograd暴区,其借助 Autograd 的更新版本闯团,并且結(jié)合了 XLA,可對(duì) Python 程序與 NumPy 運(yùn)算執(zhí)行自動(dòng)微分仙粱,支持循環(huán)房交、分支、遞歸伐割、閉包函數(shù)求導(dǎo)候味,也可以求三階導(dǎo)數(shù);依賴(lài)于 XLA隔心,JAX 可以在 GPU 和 TPU 上編譯和運(yùn)行 NumPy 程序白群;通過(guò) grad,可以支持自動(dòng)模式反向傳播和正向傳播硬霍,且二者可以任意組合成任何順序帜慢。

開(kāi)發(fā) JAX 的出發(fā)點(diǎn)是什么?說(shuō)到這唯卖,就不得不提 NumPy粱玲。NumPy 是 Python 中的一個(gè)基礎(chǔ)數(shù)值運(yùn)算庫(kù),被廣泛使用拜轨。但是 numpy 不支持 GPU 或其他硬件加速器抽减,也沒(méi)有對(duì)反向傳播的內(nèi)置支持,此外橄碾,Python 本身的速度限制阻礙了 NumPy 使用卵沉,所以少有研究者在生產(chǎn)環(huán)境下直接用 numpy 訓(xùn)練或部署深度學(xué)習(xí)模型颠锉。

在此情況下,出現(xiàn)了眾多的深度學(xué)習(xí)框架偎箫,如 PyTorch木柬、TensorFlow 等。但是 numpy 具有靈活淹办、調(diào)試方便眉枕、API 穩(wěn)定等獨(dú)特的優(yōu)勢(shì)。而 JAX 的主要出發(fā)點(diǎn)就是將 numpy 的以上優(yōu)勢(shì)與硬件加速結(jié)合怜森。

目前速挑,基于 JAX 已有很多優(yōu)秀的開(kāi)源項(xiàng)目,如谷歌的神經(jīng)網(wǎng)絡(luò)庫(kù)團(tuán)隊(duì)開(kāi)發(fā)了 Haiku副硅,這是一個(gè)面向 Jax 的深度學(xué)習(xí)代碼庫(kù)姥宝,通過(guò) Haiku,用戶(hù)可以在 Jax 上進(jìn)行面向?qū)ο箝_(kāi)發(fā)恐疲;又比如 RLax腊满,這是一個(gè)基于 Jax 的強(qiáng)化學(xué)習(xí)庫(kù),用戶(hù)使用 RLax 就能進(jìn)行 Q-learning 模型的搭建和訓(xùn)練培己;此外還包括基于 JAX 的深度學(xué)習(xí)庫(kù) JAXnet碳蛋,該庫(kù)一行代碼就能定義計(jì)算圖、可進(jìn)行 GPU 加速省咨∷嗟埽可以說(shuō),在過(guò)去幾年中零蓉,JAX 掀起了深度學(xué)習(xí)研究的風(fēng)暴笤受,推動(dòng)了科學(xué)研究迅速發(fā)展。

JAX 的安裝

如何使用 JAX 呢敌蜂?首先你需要在 Python 環(huán)境或 Google colab 中安裝 JAX箩兽,使用 pip 進(jìn)行安裝:

$ pip install --upgrade jax jaxlib

注意,上述安裝方式只是支持在 CPU 上運(yùn)行章喉,如果你想在 GPU 執(zhí)行程序比肄,首先你需要有 CUDA、cuDNN 囊陡,然后運(yùn)行以下命令(確保將 jaxlib 版本映射到 CUDA 版本):

$ pip install --upgrade jax jaxlib==0.1.61+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html

現(xiàn)在將 JAX 與 Numpy 一起導(dǎo)入:

import jax
import jax.numpy as jnp
import numpy as np

JAX 的一些特性

使用 grad() 函數(shù)自動(dòng)微分:這對(duì)深度學(xué)習(xí)應(yīng)用非常有用,這樣就可以很容易地運(yùn)行反向傳播掀亥,下面為一個(gè)簡(jiǎn)單的二次函數(shù)并在點(diǎn) 1.0 上求導(dǎo)的示例:

from jax import grad
def f(x):
  return 3*x**2 + 2*x + 5
def f_prime(x):
  return 6*x +2
grad(f)(1.0)
# DeviceArray(8., dtype=float32)
f_prime(1.0)
# 8.0

jit(Just in time) :為了利用 XLA 的強(qiáng)大功能撞反,必須將代碼編譯到 XLA 內(nèi)核中。這就是 jit 發(fā)揮作用的地方搪花。要使用 XLA 和 jit遏片,用戶(hù)可以使用 jit() 函數(shù)或 @jit 注釋嘹害。

from jax import jit
x = np.random.rand(1000,1000)
y = jnp.array(x)
def f(x):
  for _ in range(10):
      x = 0.5*x + 0.1* jnp.sin(x)
  return x
g = jit(f)
%timeit -n 5 -r 5 f(y).block_until_ready()
# 5 loops, best of 5: 10.8 ms per loop
%timeit -n 5 -r 5 g(y).block_until_ready()
# 5 loops, best of 5: 341 μs per loop

pmap:自動(dòng)將計(jì)算分配到所有當(dāng)前設(shè)備,并處理它們之間的所有通信吮便。JAX 通過(guò) pmap 轉(zhuǎn)換支持大規(guī)模的數(shù)據(jù)并行笔呀,從而將單個(gè)處理器無(wú)法處理的大數(shù)據(jù)進(jìn)行處理。要檢查可用設(shè)備髓需,可以運(yùn)行 jax.devices():

from jax import pmap
def f(x):
  return jnp.sin(x) + x**2
f(np.arange(4))
#DeviceArray([0.       , 1.841471 , 4.9092975, 9.14112  ], dtype=float32)
pmap(f)(np.arange(4))
#ShardedDeviceArray([0.       , 1.841471 , 4.9092975, 9.14112  ], dtype=float32)

vmap:是一種函數(shù)轉(zhuǎn)換许师,JAX 通過(guò) vmap 變換提供了自動(dòng)矢量化算法,大大簡(jiǎn)化了這種類(lèi)型的計(jì)算僚匆,這使得研究人員在處理新算法時(shí)無(wú)需再去處理批量化的問(wèn)題微渠。示例如下:

from jax import vmap
def f(x):
  return jnp.square(x)
f(jnp.arange(10))
#DeviceArray([ 0,  1,  4,  9, 16, 25, 36, 49, 64, 81], dtype=int32)
vmap(f)(jnp.arange(10))
#DeviceArray([ 0,  1,  4,  9, 16, 25, 36, 49, 64, 81], dtype=int32)

TensorFlow vs PyTorch vs Jax

在深度學(xué)習(xí)領(lǐng)域有幾家巨頭公司,他們所提出的框架被廣大研究者使用咧擂。比如谷歌的 TensorFlow逞盆、Facebook 的 PyTorch、微軟的 CNTK松申、亞馬遜 AWS 的 MXnet 等云芦。

每種框架都有其優(yōu)缺點(diǎn),選擇的時(shí)候需要根據(jù)自身需求進(jìn)行選擇贸桶。

我們以 Python 中的 3 個(gè)主要深度學(xué)習(xí)框架——TensorFlow舅逸、PyTorch 和 Jax 為例進(jìn)行比較。這些框架雖然不同刨啸,但有兩個(gè)共同點(diǎn):

  • 它們是開(kāi)源的堡赔。這意味著如果庫(kù)中存在錯(cuò)誤,使用者可以在 GitHub 中發(fā)布問(wèn)題(并修復(fù))设联,此外你也可以在庫(kù)中添加自己的功能善已;
  • 由于全局解釋器鎖,Python 在內(nèi)部運(yùn)行緩慢离例。所以這些框架使用 C/C++ 作為后端來(lái)處理所有的計(jì)算和并行過(guò)程换团。

那么它們的不同體現(xiàn)在哪些方面呢?如下表所示宫蛆,為 TensorFlow艘包、PyTorch、JAX 三個(gè)框架的比較耀盗。

TensorFlow

TensorFlow 由谷歌開(kāi)發(fā)想虎,最初版本可追溯到 2015 年開(kāi)源的 TensorFlow0.1,之后發(fā)展穩(wěn)定叛拷,擁有強(qiáng)大的用戶(hù)群體舌厨,成為最受歡迎的深度學(xué)習(xí)框架。但是用戶(hù)在使用時(shí)忿薇,也暴露了 TensorFlow 缺點(diǎn)裙椭,例如 API 穩(wěn)定性不足躏哩、靜態(tài)計(jì)算圖編程復(fù)雜等缺陷。因此在 TensorFlow2.0 版本揉燃,谷歌將 Keras 納入進(jìn)來(lái)扫尺,成為 tf.keras。

目前 TensorFlow 主要特點(diǎn)包括以下:

  • 這是一個(gè)非常友好的框架炊汤,高級(jí) API-Keras 的可用性使得模型層定義正驻、損失函數(shù)和模型創(chuàng)建變得非常容易直撤;
  • TensorFlow2.0 帶有 Eager Execution(動(dòng)態(tài)圖機(jī)制)愈污,這使得該庫(kù)更加用戶(hù)友好,并且是對(duì)以前版本的重大升級(jí)服猪;
  • Keras 這種高級(jí)接口有一定的缺點(diǎn)氓栈,由于 TensorFlow 抽象了許多底層機(jī)制(只是為了方便最終用戶(hù))渣磷,這讓研究人員在處理模型方面的自由度更小授瘦;
  • Tensorflow 提供了 TensorBoard醋界,它實(shí)際上是 Tensorflow 可視化工具包。它允許研究者可視化損失函數(shù)提完、模型圖形纺、模型分析等。

PyTorch

PyTorch(Python-Torch) 是來(lái)自 Facebook 的機(jī)器學(xué)習(xí)庫(kù)徒欣。用 TensorFlow 還是 PyTorch逐样?在一年前,這個(gè)問(wèn)題毫無(wú)爭(zhēng)議打肝,研究者大部分會(huì)選擇 TensorFlow脂新。但現(xiàn)在的情況大不一樣了,使用 PyTorch 的研究者越來(lái)越多粗梭。PyTorch 的一些最重要的特性包括:

  • 與 TensorFlow 不同争便,PyTorch 使用動(dòng)態(tài)類(lèi)型圖,這意味著執(zhí)行圖是在運(yùn)行中創(chuàng)建的断医。它允許我們隨時(shí)修改和檢查圖的內(nèi)部結(jié)構(gòu)滞乙;
  • 除了用戶(hù)友好的高級(jí) API 之外,PyTorch 還包括精心構(gòu)建的低級(jí) API鉴嗤,允許對(duì)機(jī)器學(xué)習(xí)模型進(jìn)行越來(lái)越多的控制斩启。我們可以在訓(xùn)練期間對(duì)模型的前向和后向傳遞進(jìn)行檢查和修改輸出。這被證明對(duì)于梯度裁剪和神經(jīng)風(fēng)格遷移非常有效醉锅;
  • PyTorch 允許用戶(hù)擴(kuò)展代碼兔簇,可以輕松添加新的損失函數(shù)和用戶(hù)定義的層。PyTorch 的 Autograd 模塊實(shí)現(xiàn)了深度學(xué)習(xí)算法中的反向傳播求導(dǎo)數(shù),在 Tensor 類(lèi)上的所有操作男韧, Autograd 都能自動(dòng)提供微分,簡(jiǎn)化了手動(dòng)計(jì)算導(dǎo)數(shù)的復(fù)雜過(guò)程默垄;
  • PyTorch 對(duì)數(shù)據(jù)并行和 GPU 的使用具有廣泛的支持此虑;
  • PyTorch 比 TensorFlow 更 Python 化。PyTorch 非常適合 Python 生態(tài)系統(tǒng)口锭,它允許使用 Python 類(lèi)調(diào)試器工具來(lái)調(diào)試 PyTorch 代碼朦前。

JAX

JAX 是來(lái)自 Google 的一個(gè)相對(duì)較新的機(jī)器學(xué)習(xí)庫(kù)。它更像是一個(gè) autograd 庫(kù)鹃操,可以區(qū)分原生的 python 和 NumPy 代碼韭寸。JAX 的一些特性主要包括:

  • 正如官方網(wǎng)站所描述的那樣,JAX 能夠執(zhí)行 Python+NumPy 程序的可組合轉(zhuǎn)換:向量化荆隘、JIT 到 GPU/TPU 等等恩伺;
  • 與 PyTorch 相比,JAX 最重要的方面是如何計(jì)算梯度椰拒。在 Torch 中晶渠,圖是在前向傳遞期間創(chuàng)建的,梯度在后向傳遞期間計(jì)算燃观, 另一方面褒脯,在 JAX 中,計(jì)算表示為函數(shù)缆毁。在函數(shù)上使用 grad() 返回一個(gè)梯度函數(shù)番川,該函數(shù)直接計(jì)算給定輸入的函數(shù)梯度;
  • JAX 是一個(gè) autograd 工具脊框,不建議單獨(dú)使用颁督。有各種基于 JAX 的機(jī)器學(xué)習(xí)庫(kù),其中值得注意的是 ObJax缚陷、Flax 和 Elegy适篙。由于它們都使用相同的核心并且接口只是 JAX 庫(kù)的 wrapper,因此可以將它們放在同一個(gè) bracket 下箫爷;
  • Flax 最初是在 PyTorch 生態(tài)系統(tǒng)下開(kāi)發(fā)的嚷节,更注重使用的靈活性。另一方面虎锚,Elegy 受 Keras 啟發(fā)硫痰。ObJAX 主要是為以研究為導(dǎo)向的目的而設(shè)計(jì)的,它更注重簡(jiǎn)單性和可理解性窜护。

參考鏈接:

開(kāi)源前哨 日常分享熱門(mén)效斑、有趣和實(shí)用的開(kāi)源項(xiàng)目。參與維護(hù) 10萬(wàn)+ Star 的開(kāi)源技術(shù)資源庫(kù)柱徙,包括:Python缓屠、Java奇昙、C/C++、Go敌完、JS储耐、CSS、Node.js滨溉、PHP什湘、.NET 等。

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末晦攒,一起剝皮案震驚了整個(gè)濱河市闽撤,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌脯颜,老刑警劉巖哟旗,帶你破解...
    沈念sama閱讀 217,277評(píng)論 6 503
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場(chǎng)離奇詭異伐脖,居然都是意外死亡热幔,警方通過(guò)查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,689評(píng)論 3 393
  • 文/潘曉璐 我一進(jìn)店門(mén)讼庇,熙熙樓的掌柜王于貴愁眉苦臉地迎上來(lái)绎巨,“玉大人,你說(shuō)我怎么就攤上這事蠕啄〕∏冢” “怎么了?”我有些...
    開(kāi)封第一講書(shū)人閱讀 163,624評(píng)論 0 353
  • 文/不壞的土叔 我叫張陵歼跟,是天一觀的道長(zhǎng)和媳。 經(jīng)常有香客問(wèn)我,道長(zhǎng)哈街,這世上最難降的妖魔是什么留瞳? 我笑而不...
    開(kāi)封第一講書(shū)人閱讀 58,356評(píng)論 1 293
  • 正文 為了忘掉前任,我火速辦了婚禮骚秦,結(jié)果婚禮上她倘,老公的妹妹穿的比我還像新娘。我一直安慰自己作箍,他們只是感情好硬梁,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,402評(píng)論 6 392
  • 文/花漫 我一把揭開(kāi)白布。 她就那樣靜靜地躺著胞得,像睡著了一般荧止。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上,一...
    開(kāi)封第一講書(shū)人閱讀 51,292評(píng)論 1 301
  • 那天跃巡,我揣著相機(jī)與錄音危号,去河邊找鬼。 笑死素邪,一個(gè)胖子當(dāng)著我的面吹牛葱色,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播娘香,決...
    沈念sama閱讀 40,135評(píng)論 3 418
  • 文/蒼蘭香墨 我猛地睜開(kāi)眼,長(zhǎng)吁一口氣:“原來(lái)是場(chǎng)噩夢(mèng)啊……” “哼办龄!你這毒婦竟也來(lái)了烘绽?” 一聲冷哼從身側(cè)響起,我...
    開(kāi)封第一講書(shū)人閱讀 38,992評(píng)論 0 275
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤俐填,失蹤者是張志新(化名)和其女友劉穎安接,沒(méi)想到半個(gè)月后,有當(dāng)?shù)厝嗽跇?shù)林里發(fā)現(xiàn)了一具尸體英融,經(jīng)...
    沈念sama閱讀 45,429評(píng)論 1 314
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡盏檐,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,636評(píng)論 3 334
  • 正文 我和宋清朗相戀三年,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了驶悟。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片胡野。...
    茶點(diǎn)故事閱讀 39,785評(píng)論 1 348
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡,死狀恐怖痕鳍,靈堂內(nèi)的尸體忽然破棺而出硫豆,到底是詐尸還是另有隱情,我是刑警寧澤笼呆,帶...
    沈念sama閱讀 35,492評(píng)論 5 345
  • 正文 年R本政府宣布熊响,位于F島的核電站,受9級(jí)特大地震影響诗赌,放射性物質(zhì)發(fā)生泄漏汗茄。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,092評(píng)論 3 328
  • 文/蒙蒙 一铭若、第九天 我趴在偏房一處隱蔽的房頂上張望洪碳。 院中可真熱鬧,春花似錦奥喻、人聲如沸偶宫。這莊子的主人今日做“春日...
    開(kāi)封第一講書(shū)人閱讀 31,723評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)纯趋。三九已至,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間吵冒,已是汗流浹背纯命。 一陣腳步聲響...
    開(kāi)封第一講書(shū)人閱讀 32,858評(píng)論 1 269
  • 我被黑心中介騙來(lái)泰國(guó)打工, 沒(méi)想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留痹栖,地道東北人亿汞。 一個(gè)月前我還...
    沈念sama閱讀 47,891評(píng)論 2 370
  • 正文 我出身青樓,卻偏偏與公主長(zhǎng)得像揪阿,于是被迫代替她去往敵國(guó)和親疗我。 傳聞我的和親對(duì)象是個(gè)殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 44,713評(píng)論 2 354

推薦閱讀更多精彩內(nèi)容