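The previous chapters covered JAX function transformations such as jax.jit, jax.grad, and jax.vmap, and showed how composing them yields concise, efficient code. This chapter shows how to define your own function transformations by writing a custom Jaxpr interpreter.
The Jaxpr Tracer
JAX provides a NumPy-like API for numerical computing, and jax.numpy can be used almost exactly like NumPy, but JAX's real power comes from composable function transformations. Take jax.jit as an example: it accepts a function and returns a semantically identical function, which is then compiled with the XLA compiler for the target accelerator.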
import jax

def function(x):
    return 2 * x ** 2 + 3 * x

def test():
    function_jit = jax.jit(function)
    result = function_jit(10)
    print("result = ", result)

def main():
    test()

if __name__ == "__main__":
    main()
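In the example above, when function_jit is called, JAX traces the function and builds an XLA computation graph, then JIT-compiles and executes that graph. Other function transformations work in a similar way: they first trace the function and then process the resulting trace in some way.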
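As a minimal sketch of the tracing step described above: a plain Python side effect inside the function runs only while JAX traces it, so a second call with the same input type reuses the compiled result. The counter trace_count is a hypothetical helper added purely for illustration.

```python
import jax

trace_count = 0  # hypothetical counter, only here to observe tracing


def function(x):
    global trace_count
    trace_count += 1  # plain Python: executes during tracing, not in the compiled code
    return 2 * x ** 2 + 3 * x


function_jit = jax.jit(function)
first = function_jit(10.0)   # traces, compiles, then runs
second = function_jit(20.0)  # same input type: served from the compilation cache
print("first =", first, "second =", second, "traces =", trace_count)
```

Calling function_jit with an input of a different shape or dtype would trigger a new trace, since the compiled program is specialized to the input type.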
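One particularly important tracer in JAX is the one that records operations into a Jaxpr (JAX expression). A Jaxpr is a data structure that can be evaluated like a small functional programming language, which makes it a useful intermediate representation for function transformations.
jax.make_jaxpr converts a function into its jaxpr form: given example arguments, it traces the function and produces the Jaxpr for that computation. The resulting jaxpr is usually not used directly, but it is very useful for debugging and for inspecting what a JAX function does.
The following code snippets illustrate how jaxpr works.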
import jax

def function(x):
    return 2 * x ** 2 + 3 * x

def test():
    expr = jax.make_jaxpr(function)
    result = expr(2.0)
    print(result)

def main():
    test()

if __name__ == "__main__":
    main()
The program prints the following output:
{ lambda ; a:f32[]. let
    b:f32[] = integer_pow[y=2] a
    c:f32[] = mul 2.0 b
    d:f32[] = mul 3.0 a
    e:f32[] = add c d
  in (e,) }
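Because make_jaxpr traces with example arguments, the jaxpr it returns is specialized to the shape and dtype of those arguments. The sketch below traces the same function with a scalar and with a length-3 vector and compares the abstract values (aval) recorded for the input variable; the variable names are illustrative.

```python
import jax
import jax.numpy as jnp


def function(x):
    return 2 * x ** 2 + 3 * x


scalar_jaxpr = jax.make_jaxpr(function)(2.0)
vector_jaxpr = jax.make_jaxpr(function)(jnp.ones(3))

# The abstract value (aval) of the input variable records its dtype and shape.
scalar_aval = scalar_jaxpr.jaxpr.invars[0].aval
vector_aval = vector_jaxpr.jaxpr.invars[0].aval
print(scalar_aval.shape, vector_aval.shape)
print(vector_jaxpr)
```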
The next listing uses a more detailed helper function to take apart the result of make_jaxpr:
import jax

def function(x):
    return 2 * x ** 2 + 3 * x

def print_jaxpr(closed_expr):
    jaxpr = closed_expr.jaxpr
    print("invars: ", jaxpr.invars)
    print("outvars: ", jaxpr.outvars)
    print("constvars: ", jaxpr.constvars)
    for equation in jaxpr.eqns:
        print("Equation: ", equation.invars, equation.primitive, equation.outvars, equation.params)
    print("jaxpr: ", jaxpr)

def test():
    expr = jax.make_jaxpr(function)
    result = expr(2.0)
    print(result)
    print("--------------------------")
    print_jaxpr(result)

def main():
    test()

if __name__ == "__main__":
    main()
The program prints the following output:
{ lambda ; a:f32[]. let
    b:f32[] = integer_pow[y=2] a
    c:f32[] = mul 2.0 b
    d:f32[] = mul 3.0 a
    e:f32[] = add c d
  in (e,) }
--------------------------
invars: [a]
outvars: [e]
constvars: []
Equation: [a] integer_pow [b] {'y': 2}
Equation: [2.0, b] mul [c] {}
Equation: [3.0, a] mul [d] {}
Equation: [c, d] add [e] {}
jaxpr: { lambda ; a:f32[]. let
    b:f32[] = integer_pow[y=2] a
    c:f32[] = mul 2.0 b
    d:f32[] = mul 3.0 a
    e:f32[] = add c d
  in (e,) }
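Before walking through the output, here is what the relevant attributes mean:
- jaxpr.invars: the list of input variables, analogous to a function's formal parameters.
- jaxpr.outvars: the list of output (returned) variables.
- jaxpr.constvars: a list of variables that are also inputs to the jaxpr, but that correspond to constants captured during tracing.
- jaxpr.eqns: the list of equations (primitive applications) that make up the computation. Each equation has its own input variables, output variables, primitive, and parameters, and together the equations compute the function's output.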
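The attributes listed above are ordinary Python objects, so they can be processed like any other data. As a small sketch, the snippet below counts how often each primitive occurs in the equations of a jaxpr; the exact set of primitive names may vary between JAX versions.

```python
import collections

import jax


def function(x):
    return 2 * x ** 2 + 3 * x


closed_jaxpr = jax.make_jaxpr(function)(2.0)
jaxpr = closed_jaxpr.jaxpr

# Each equation names the primitive it applies; count them by name.
primitive_counts = collections.Counter(eqn.primitive.name for eqn in jaxpr.eqns)
print(primitive_counts)
print("inputs:", len(jaxpr.invars), "outputs:", len(jaxpr.outvars))
```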
With these descriptions in hand, the output above can be read line by line:
invars: [a]
outvars: [e]
constvars: []
Equation: [a] integer_pow [b] {'y': 2}
Equation: [2.0, b] mul [c] {}
Equation: [3.0, a] mul [d] {}
Equation: [c, d] add [e] {}
- invars: [a]: the input variables are a list with the single element a.
- outvars: [e]: the output (return) variables are a list with the single element e.
- constvars: []: there are no captured constant "variables".
- Equation: [a] integer_pow [b] {'y': 2}: an equation that raises the input variable a to the power 2. Here [a] is the equation's input, [b] is its output, and {'y': 2} is the exponent parameter.
- Equation: [2.0, b] mul [c] {}: an equation that multiplies b (the output of the previous equation) by the constant 2.0 and stores the result in c.
- Equation: [3.0, a] mul [d] {}: an equation that multiplies the input variable a by the constant 3.0 and stores the result in d.
- Equation: [c, d] add [e] {}: an equation that adds c and d and stores the result in e.
{ lambda ; a:f32[]. let
    b:f32[] = integer_pow[y=2] a
    c:f32[] = mul 2.0 b
    d:f32[] = mul 3.0 a
    e:f32[] = add c d
  in (e,) }
- { lambda ; a:f32[]. let: defines a lambda expression whose input parameter a is a float32 array of shape [] (a scalar); let opens the function body.
- b:f32[] = integer_pow[y=2] a: defines the float32 variable b, which receives the result of applying the integer power primitive integer_pow with exponent 2 to the input a.
- c:f32[] = mul 2.0 b: defines the float32 variable c, which receives the product of the constant 2.0 and the previous result b.
- d:f32[] = mul 3.0 a: defines the float32 variable d, which receives the product of the constant 3.0 and the input a.
- e:f32[] = add c d: defines the float32 variable e, which receives the sum of c and d.
- in (e,) }: declares that the return value is a tuple containing e.
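As this walkthrough shows, a Jaxpr is a simple program representation that is easy to transform, much like the intermediate language of a compiler. Because JAX can produce a Jaxpr directly from a Python function, it gives us a general way to transform numerical functions written in Python.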
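To make "easy to transform" concrete, the following is a minimal sketch of the simplest possible interpreter: one that evaluates a jaxpr by walking its equations in order and applying each primitive with bind. The name eval_jaxpr and the try/except import of Literal are assumptions of this sketch; where Literal lives differs between JAX versions.

```python
import jax
import jax.numpy as jnp

try:  # assumption: Literal is importable from one of these two locations
    from jax.extend.core import Literal
except ImportError:
    from jax.core import Literal


def function(x):
    return 2 * x ** 2 + 3 * x


def eval_jaxpr(jaxpr, consts, *args):
    env = {}  # maps jaxpr variables to concrete values

    def read(var):
        # Literals carry their value inline; other variables are looked up.
        return var.val if isinstance(var, Literal) else env[var]

    for var, value in zip(jaxpr.constvars, consts):
        env[var] = value
    for var, value in zip(jaxpr.invars, args):
        env[var] = value

    for eqn in jaxpr.eqns:
        in_values = [read(var) for var in eqn.invars]
        out_values = eqn.primitive.bind(*in_values, **eqn.params)
        if not eqn.primitive.multiple_results:
            out_values = [out_values]
        for var, value in zip(eqn.outvars, out_values):
            env[var] = value

    return [read(var) for var in jaxpr.outvars]


closed_jaxpr = jax.make_jaxpr(function)(2.0)
(result,) = eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, jnp.asarray(2.0))
print("result =", result)
```

A custom transformation follows the same skeleton and only changes what happens for each equation, for example by looking up a replacement rule for the primitive instead of calling bind.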
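Interpreting a traced function needs slightly more than the jaxpr itself, because the constants created during tracing must be extracted and passed to the jaxpr as well. make_jaxpr returns a closed jaxpr that bundles the jaxpr with those constants, so a small helper can print both: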
def print_literals():
    function_jaxpr = jax.make_jaxpr(function)
    closed_jaxpr = function_jaxpr(2.0)
    print(closed_jaxpr)
    print("-----------------------------------------")
    print(closed_jaxpr.literals)
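The program prints the following output:
{ lambda ; a:f32[]. let
    b:f32[] = integer_pow[y=2] a
    c:f32[] = mul 2.0 b
    d:f32[] = mul 3.0 a
    e:f32[] = add c d
  in (e,) }
-----------------------------------------
[]
The first part of the output is the Jaxpr, which records the operations traced inside the function in sequence. The second part is the list of captured constants, which is empty for this function.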
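The constant list is empty above because the function only uses scalar literals, which are written inline in the equations. As a hedged sketch, a function that closes over an array typically has that array recorded as a captured constant instead; the array offsets and the function shift are hypothetical names, and how constants are represented can differ between JAX versions.

```python
import jax
import jax.numpy as jnp

offsets = jnp.array([1.0, 2.0, 3.0])  # hypothetical array captured by the function


def shift(x):
    return x + offsets  # offsets is not an argument, so tracing must record it


closed_jaxpr = jax.make_jaxpr(shift)(jnp.zeros(3))
print(closed_jaxpr)
print("constvars:", closed_jaxpr.jaxpr.constvars)
print("consts:   ", closed_jaxpr.consts)
```

Whatever the version, each entry in constvars is paired with one value in consts, which is why an interpreter has to receive both.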
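Defining Functions That Can Be Traced by Jaxpr
To use an interpreter, its rules must first be registered, one per JAX primitive, and are then applied as each primitive is encountered. The following example wraps a function so that its Jaxpr is interpreted backwards, which inverts the function. The code is as follows: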
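import jax

# Note: jax.util and jax.core.Literal are the names used by the JAX version this
# chapter targets; newer releases may have moved or removed them.

def inverse_iterate_jaxpr(inverse_registry, jaxpr, consts, *args):
    configurations = {}  # environment: jaxpr variable -> value

    def read(var):
        if type(var) is jax.core.Literal:
            return var.val
        return configurations[var]

    def write(var, value):
        configurations[var] = value

    # The arguments now correspond to the outputs of the forward jaxpr
    jax.util.safe_map(write, jaxpr.outvars, args)
    jax.util.safe_map(write, jaxpr.constvars, consts)
    # Backwards iteration
    for equation in jaxpr.eqns[:: -1]:
        # The outputs of the forward equation are the inputs of its inverse
        in_values = jax.util.safe_map(read, equation.outvars)
        if equation.primitive not in inverse_registry:
            raise NotImplementedError("{} does not have a registered inverse.".format(equation.primitive))
        # Assumes a unary primitive
        out_values = inverse_registry[equation.primitive](*in_values)
        jax.util.safe_map(write, equation.invars, [out_values])
    return jax.util.safe_map(read, jaxpr.invars)

def inverse(functionPointer, inverse_registry):
    @jax.util.wraps(functionPointer)
    def wrapped_function(*args, **kwargs):
        function_jaxpr = jax.make_jaxpr(functionPointer)
        closed_jaxpr = function_jaxpr(*args, **kwargs)
        output = inverse_iterate_jaxpr(inverse_registry, closed_jaxpr.jaxpr, closed_jaxpr.literals, *args)
        return output[0]
    return wrapped_function

def function(x):
    tan = jax.numpy.tanh(x)
    exp = jax.numpy.exp(tan)
    return exp

def test():
    function_jaxpr = jax.make_jaxpr(function)
    jaxpr = function_jaxpr(2.)
    # The rest of this function is missing from the listing; it is reconstructed
    # here from the printed output below and should be read as an assumption.
    print("jaxpr =", jaxpr)
    print("---------------------------")
    inverse_registry = {jax.lax.exp_p: jax.numpy.log, jax.lax.tanh_p: jax.numpy.arctanh}
    function_inverse = inverse(function, inverse_registry)
    result = jax.make_jaxpr(function_inverse)(function(2.))
    print("result =", result)
    print("---------------------------")

def main():
    test()

if __name__ == "__main__":
    main()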
The program prints the following output:
jaxpr = { lambda ; a:f32[]. let b:f32[] = tanh a; c:f32[] = exp b in (c,) }
---------------------------
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
result = { lambda ; a:f32[]. let b:f32[] = log a; c:f32[] = atanh b in (c,) }
---------------------------
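As a numerical check on the inversion idea, the sketch below is a compact, self-contained variant of the interpreter: it inverts exp(tanh(x)) and confirms that applying the inverse to the forward output recovers the input. It is a simplified illustration, not the listing above; it assumes unary primitives only, and the try/except import of Literal is an assumption about where that class lives in the installed JAX version.

```python
import jax
import jax.numpy as jnp

try:  # assumption: Literal is importable from one of these two locations
    from jax.extend.core import Literal
except ImportError:
    from jax.core import Literal

# Inverse rules for the two unary primitives used below.
inverse_registry = {jax.lax.exp_p: jnp.log, jax.lax.tanh_p: jnp.arctanh}


def inverse(fun):
    def wrapped(*outputs):
        closed_jaxpr = jax.make_jaxpr(fun)(*outputs)
        jaxpr, env = closed_jaxpr.jaxpr, {}

        def read(var):
            return var.val if isinstance(var, Literal) else env[var]

        # The values we are given belong to the outputs of the forward program.
        env.update(zip(jaxpr.outvars, outputs))
        env.update(zip(jaxpr.constvars, closed_jaxpr.consts))
        for eqn in reversed(jaxpr.eqns):
            (out_value,) = [read(var) for var in eqn.outvars]
            if eqn.primitive not in inverse_registry:
                raise NotImplementedError(f"{eqn.primitive} has no registered inverse.")
            # Assumes a unary primitive: one input, one output.
            env[eqn.invars[0]] = inverse_registry[eqn.primitive](out_value)
        return read(jaxpr.invars[0])

    return wrapped


def function(x):
    return jnp.exp(jnp.tanh(x))


x = jnp.float32(0.5)
recovered = inverse(function)(function(x))
print("x =", x, "recovered =", recovered)
```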
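The two jaxprs printed above show the custom function before and after the transformation: the forward program applies tanh and then exp, and the inverted program applies log and then atanh. Two remarks about XLA, the compiler underneath:
- XLA is the compiler JAX uses. It is what lets JAX run on TPUs, and it is rapidly becoming the compiler for all devices, so it is worth studying. Working with XLA computations directly through the raw C++ interface is not easy, however, so JAX exposes the underlying XLA computation-builder API through Python wrappers, which makes the XLA computation model accessible for experimentation and integration.
- An XLA computation is built as a computation graph before it is compiled, and is then lowered to a specific device such as a CPU, GPU, or TPU.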
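The build-then-lower pipeline described in the notes above can be observed from Python. As a minimal sketch, jax.jit(f).lower(...) stages a function out without running it, as_text() shows the compiler IR that is handed to XLA, and compile() produces an executable for the local backend. The exact IR text depends on the JAX version, so it is only printed here.

```python
import jax
import jax.numpy as jnp


def function(x):
    return 2 * x ** 2 + 3 * x


x = jnp.float32(2.0)
lowered = jax.jit(function).lower(x)  # trace and lower, without running
ir_text = lowered.as_text()           # compiler IR as text
compiled = lowered.compile()          # compile for the local backend (CPU/GPU/TPU)
result = compiled(x)
print(ir_text)
print("result =", result)
```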
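Named Dimensions
In the earlier matrix computations, particularly when training VGG, we did not use dimension names. Instead we relied on positional conventions to match dimensions such as batch_size, channels, height, and width. One feature of JAX is that dimensions can be given names. Named dimensions are useful because they let the programmer write self-documenting functions with named axes and manipulate matrix operations in a more intuitive way.
We illustrate named dimensions with the fully connected MNIST classifier implemented in an earlier chapter. The code is as follows: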
import jax

def forward(weight1, weight2, images):
    dot = jax.numpy.dot(images, weight1)
    hidden1 = jax.nn.relu(dot)
    hidden2 = jax.numpy.dot(hidden1, weight2)
    # Log-probabilities, so that the loss below is a cross-entropy
    logits = jax.nn.log_softmax(hidden2)
    return logits

def loss_function(weight1, weight2, images, labels):
    predictions = forward(weight1 = weight1, weight2 = weight2, images = images)
    targets = jax.nn.one_hot(labels, predictions.shape[-1])
    losses = jax.numpy.sum(targets * predictions, axis = 1)
    return -jax.numpy.mean(losses, axis = 0)

def train():
    weight1 = jax.numpy.zeros((784, 512))
    weight2 = jax.numpy.zeros((512, 10))
    images = jax.numpy.zeros((128, 784))
    labels = jax.numpy.zeros(128, dtype = jax.numpy.int32)
    losses = loss_function(weight1, weight2, images, labels)
    print("losses = ", losses)

def main():
    train()

if __name__ == "__main__":
    main()
The code above implements only the forward pass and the loss computation. Below, this part is rewritten with named axes. The axis names are declared as follows:
axes = [
    ["inputs", "hidden"],
    ["hidden", "classes"],
    ["batch", "inputs"],
    ["batch", ...]
]
Each entry assigns names to the dimensions of one input (weight1, weight2, images, labels), so every dimension carries an explicitly chosen name. The names are used as follows:
import jax
import numpy
# xmap is an experimental API that has been removed from recent JAX releases;
# this listing needs a JAX version that still ships jax.experimental.maps.
from jax.experimental import maps

def predict(weight1, weight2, images):
    dots = jax.numpy.dot(images, weight1)
    hiddens = jax.nn.relu(dots)
    logits = jax.numpy.dot(hiddens, weight2)
    return logits - jax.nn.logsumexp(logits, axis = 1, keepdims = True)

def loss_function(weight1, weight2, images, labels):
    predictions = predict(weight1 = weight1, weight2 = weight2, images = images)
    targets = jax.nn.one_hot(labels, predictions.shape[-1])
    losses = jax.numpy.sum(targets * predictions, axis = 1)
    return -jax.numpy.mean(losses, axis = 0)

# Named dimensions will be used to compute the data
def named_predict(weight1, weight2, images):
    pdot = jax.lax.pdot(images, weight1, "inputs")
    hidden = jax.nn.relu(pdot)
    logits = jax.lax.pdot(hidden, weight2, "hidden")
    return logits - jax.nn.logsumexp(logits, "classes")

def named_loss_function(weight1, weight2, images, labels):
    predictions = named_predict(weight1, weight2, images)
    # jax.lax.psum(): Compute an all-reduce sum on x over the pmapped axis axis_name
    number_classes = jax.lax.psum(1, "classes")
    targets = jax.nn.one_hot(labels, number_classes, axis = "classes")
    losses = jax.lax.psum(targets * predictions, "classes")
    return -jax.lax.pmean(losses, "batch")

def train():
    weight1 = jax.numpy.zeros((784, 512))
    weight2 = jax.numpy.zeros((512, 10))
    images = jax.numpy.zeros((128, 784))
    labels = jax.numpy.zeros(128, dtype = jax.numpy.int32)
    losses = loss_function(weight1, weight2, images, labels)
    print("losses = ", losses)
    in_axes = [
        ["inputs", "hidden"],
        ["hidden", "classes"],
        ["batch", "inputs"],
        ["batch", ...]
    ]
    # Register the names for the dimensions
    loss_function_xmap = maps.xmap(named_loss_function, in_axes = in_axes, out_axes = [...], axis_resources = {"batch": "x"})
    devices = numpy.array(jax.local_devices())
    with jax.sharding.Mesh(devices, ("x",)):
        losses = loss_function_xmap(weight1, weight2, images, labels)
        print("losses = ", losses)

def main():
    train()

if __name__ == "__main__":
    main()
The program prints the following output:
losses = 2.3025854
losses = 2.3025854
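The printed value is no accident: with all-zero weights every logit is zero, the predicted distribution over the 10 classes is uniform, and the cross-entropy is ln 10, about 2.302585. The short sketch below reproduces this with positional axes only; the variable names are illustrative.

```python
import math

import jax
import jax.numpy as jnp

logits = jnp.zeros((128, 10))                     # what all-zero weights produce
log_probs = jax.nn.log_softmax(logits, axis=-1)   # uniform: log(1/10) everywhere
labels = jnp.zeros(128, dtype=jnp.int32)
targets = jax.nn.one_hot(labels, 10)
loss = -jnp.mean(jnp.sum(targets * log_probs, axis=1))
print("loss =", loss, "ln(10) =", math.log(10))
```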
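Naming dimensions makes it much easier to specify the dimensions of a neural network correctly, and helps avoid computation errors during training caused by mixing up axes. After all, a meaningful name that explains itself at a glance is clearly better than a dimension identified only by its numeric position.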
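Named axes are not tied to xmap. As a minimal sketch using a different and stable entry point, jax.vmap accepts an axis_name, and collectives such as jax.lax.psum can then refer to the mapped axis by that name rather than by position. The function normalize is a hypothetical example.

```python
import jax
import jax.numpy as jnp


def normalize(x):
    # Sum over the axis called "batch", wherever it sits positionally.
    return x / jax.lax.psum(x, axis_name="batch")


values = jnp.array([1.0, 2.0, 5.0])
shares = jax.vmap(normalize, axis_name="batch")(values)
print(shares)
```

Inside normalize, x is a single element of the batch; the name "batch" is the only link to the other elements.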
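Custom Array (Tensor) Types in JAX
The programming model of NumPy itself (not jax.numpy) is built around N-dimensional arrays, and every N-dimensional array carries two pieces of type information:
- The data type of its elements (dtype).
- The dimensions of the array (shape).
In JAX, these two pieces are unified into a single type, written dtype[shape_tuple]. For example, a float32 array with dimensions [3, 17, 21] is written f32[(3, 17, 21)]. The small example below demonstrates how shapes propagate through a simple NumPy program.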
import numpy
import etils.array_types  # provides the f32 array-type alias used below
from typing import Any

class ArrayType:
    def __getitem__(self, idx):
        return Any

f32 = ArrayType()

def test():
    array = numpy.ones(shape = (3, 17, 21))
    print(array.shape)
    array = numpy.arange(1071).reshape(3, 17, 21)
    print(array.shape)
    # Annotations using the etils alias
    x: etils.array_types.f32[(2, 3)] = numpy.ones(shape = (2, 3), dtype = numpy.float32)
    y: etils.array_types.f32[(3, 5)] = numpy.ones(shape = (3, 5), dtype = numpy.float32)
    z: etils.array_types.f32[(2, 5)] = x.dot(y)
    w: etils.array_types.f32[(7, 1, 5)] = numpy.ones((7, 1, 5), dtype = numpy.float32)
    q: etils.array_types.f32[(7, 2, 5)] = z + w
    print(f"x.shape = {x.shape}, y.shape = {y.shape}, z.shape = {z.shape}, w.shape = {w.shape}, q.shape = {q.shape}")
    # The same annotations using the custom ArrayType instance
    x: f32[(2, 3)] = numpy.ones(shape = (2, 3), dtype = numpy.float32)
    y: f32[(3, 5)] = numpy.ones(shape = (3, 5), dtype = numpy.float32)
    z: f32[(2, 5)] = x.dot(y)
    w: f32[(7, 1, 5)] = numpy.ones((7, 1, 5), dtype = numpy.float32)
    q: f32[(7, 2, 5)] = z + w
    print(f"x.shape = {x.shape}, y.shape = {y.shape}, z.shape = {z.shape}, w.shape = {w.shape}, q.shape = {q.shape}")

def main():
    test()

if __name__ == "__main__":
    main()
The program prints the following output:
(3, 17, 21)
(3, 17, 21)
x.shape = (2, 3), y.shape = (3, 5), z.shape = (2, 5), w.shape = (7, 1, 5), q.shape = (7, 2, 5)
x.shape = (2, 3), y.shape = (3, 5), z.shape = (2, 5), w.shape = (7, 1, 5), q.shape = (7, 2, 5)
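The same shape propagation can be checked without performing any arithmetic. As a sketch, jax.eval_shape runs a function on abstract inputs described only by shape and dtype (jax.ShapeDtypeStruct) and returns the shape and dtype of the result; the function name program is illustrative.

```python
import jax
import jax.numpy as jnp


def program(x, y, w):
    z = x.dot(y)   # (2, 3) @ (3, 5) -> (2, 5)
    return z + w   # (2, 5) broadcast against (7, 1, 5) -> (7, 2, 5)


x = jax.ShapeDtypeStruct((2, 3), jnp.float32)
y = jax.ShapeDtypeStruct((3, 5), jnp.float32)
w = jax.ShapeDtypeStruct((7, 1, 5), jnp.float32)

q = jax.eval_shape(program, x, y, w)  # no arithmetic is performed
print(q.shape, q.dtype)
```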
As the code shows, f32 comes from two sources: etils.array_types.f32 and a custom class,
class ArrayType:
    def __getitem__(self, idx):
        return Any

f32 = ArrayType()
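In the custom version, f32 is an instance of a class whose __getitem__ accepts any index and returns typing.Any, so f32[(2, 3)] serves as a readable annotation and is not checked at run time. Arrays annotated this way can be printed like ordinary arrays and report their shape as usual.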
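If a run-time check is wanted, the same indexing syntax can return an object that knows how to compare itself with an array. The sketch below is one hypothetical way to do that in plain NumPy; ArraySpec and its matches method are invented for illustration and are not part of etils or JAX.

```python
import numpy as np


class ArraySpec:
    """Hypothetical helper: a dtype plus a shape that can be checked."""

    def __init__(self, dtype, shape):
        self.dtype = np.dtype(dtype)
        self.shape = tuple(shape)

    def matches(self, array):
        return array.dtype == self.dtype and array.shape == self.shape

    def __repr__(self):
        return f"{self.dtype.name}[{self.shape}]"


class ArrayType:
    def __init__(self, dtype):
        self.dtype = dtype

    def __getitem__(self, shape):
        return ArraySpec(self.dtype, shape)


f32 = ArrayType(np.float32)

x = np.ones((2, 3), dtype=np.float32)
y = np.ones((3, 5), dtype=np.float32)
z = x.dot(y)
print(f32[(2, 5)], f32[(2, 5)].matches(z), f32[(3, 5)].matches(z))
```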
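Conclusion
This chapter explored the jaxpr interpreter: composable function transformations, tracers, and how to define functions that can be interpreted through their jaxpr. This is fairly low-level material. From an engineering perspective, the chapter also showed how naming dimensions improves the handling of matrices in deep learning.
There is a lot of material here, so work through it at your own pace.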