加速python運(yùn)行-numba
numba是一個用于編譯Python數(shù)組和數(shù)值計(jì)算函數(shù)的編譯器,這個編譯器能夠大幅提高直接使用Python編寫的函數(shù)的運(yùn)算速度污桦。
numba使用LLVM編譯器架構(gòu)將純Python代碼生成優(yōu)化過的機(jī)器碼琴庵,通過一些添加簡單的注解,將面向數(shù)組和使用大量數(shù)學(xué)的python代碼優(yōu)化到與c,c++和Fortran類似的性能轰豆,而無需改變Python的解釋器。
Numba的主要特性:
- 動態(tài)代碼生成 (在用戶偏愛的導(dǎo)入期和運(yùn)行期)
- 為CPU(默認(rèn))和GPU硬件生成原生的代碼
- 集成Python的科學(xué)軟件棧(Numpy)
下面是使用Numba優(yōu)化的函數(shù)方法齿诞,將Numpy數(shù)組作為參數(shù):
import numba
@numba.jit
def sum2d(arr):
M, N = arr.shape
result = 0.0
for i in range(M):
for j in range(N):
result += arr[i,j]
return result
如果你對此不是太感興趣酸休,或者對于其他的加速方案已經(jīng)很熟悉,可以到此為止祷杈,只需要了解加上jit裝飾器就可以實(shí)現(xiàn)了斑司。
使用jit
使用jit的好處就在于讓numba來決定什么時候以及怎么做優(yōu)化。
from numba import jit
@jit
def f(x, y):
# A somewhat trivial example
return x + y
比如這段代碼但汞,計(jì)算將延期到第一次函數(shù)執(zhí)行宿刮,numba將在調(diào)用期間推斷參數(shù)類型互站,然后基于這個信息生成優(yōu)化后的代碼。numba也能夠基于輸入的類型編譯生成特定的代碼僵缺。例如胡桃,對于上面的代碼,傳入整數(shù)和復(fù)數(shù)作為參數(shù)將會生成不同的代碼:
>>>f(1,2)
3
>>>f(1j,2)
(2+1j)
我們也可以加上所期望的函數(shù)簽名:
from numba import jit, int32
@jit(int32(int32, int32))
def f(x, y):
# A somewhat trivial example
return x + y
int32(int32, int32) 是函數(shù)簽名磕潮,這樣翠胰,相應(yīng)的特性將會被@jit裝飾器編譯,然后自脯,編譯器將控制類型選擇之景,并不允許其他特性(即其他類型的參數(shù)輸入,如float)
Numba編譯的函數(shù)可以調(diào)用其他編譯函數(shù)膏潮。 函數(shù)調(diào)用甚至可以在本機(jī)代碼中內(nèi)聯(lián)锻狗,具體取決于優(yōu)化器的啟發(fā)式。 例如:
@jit
def square(x):
return x ** 2
@jit
def hypot(x, y):
return math.sqrt(square(x) + square(y))
@jit裝飾器必須添加到任何庫函數(shù)戏罢,否則numba可能生成速度更慢的代碼屋谭。
簽名規(guī)范
Explicit @jit signatures can use a number of types. Here are some common ones:
void is the return type of functions returning nothing (which actually return None when called from Python)
intp and uintp are pointer-sized integers (signed and unsigned, respectively)
intc and uintc are equivalent to C int and unsigned int integer types
int8, uint8, int16, uint16, int32, uint32, int64, uint64 are fixed-width integers of the corresponding bit width (signed and unsigned)
float32 and float64 are single- and double-precision floating-point numbers, respectively
complex64 and complex128 are single- and double-precision complex numbers, respectively
array types can be specified by indexing any numeric type, e.g. float32[:] for a one-dimensional single-precision array or int8[:,:] for a two-dimensional array of 8-bit integers.
編譯選項(xiàng)
numba有兩種編譯模式:nopython模式和object模式。前者能夠生成更快的代碼龟糕,但是有一些限制可能迫使numba退為后者桐磁。想要避免退為后者,而且拋出異常讲岁,可以傳遞nopython=True.
@jit(nopython=True)
def f(x, y):
return x + y
當(dāng)Numba不需要保持全局線程鎖時我擂,如果用戶設(shè)定nogil=True,當(dāng)進(jìn)入這類編譯好的函數(shù)時缓艳,Numba將會釋放全局線程鎖校摩。
@jit(nogil=True)
def f(x, y):
return x + y
這樣可以利用多核系統(tǒng),但不能使用的函數(shù)是在object模式下編譯阶淘。
想要避免你調(diào)用python程序的編譯時間衙吩,可以這頂numba保存函數(shù)編譯結(jié)果到一個基于文件的緩存中∠希可以通過傳遞cache=True實(shí)現(xiàn)坤塞。
@jit(cache=True)
def f(x, y):
return x + y
開啟一個實(shí)驗(yàn)性質(zhì)的特性將函數(shù)中的這些操作自動并行化。這一特性可以通過傳遞parallel=True打開澈蚌,然后必須也要和nopython=True配合起來一起使用摹芙。編譯器將編譯一個版本,并行運(yùn)行多個原生的線程(沒有GIL)
@jit(nopython=True, parallel=True)
def f(x, y):
return x + y
generated_jit
有時候想要編寫一個函數(shù)宛瞄,基于輸入的類型實(shí)現(xiàn)不同的實(shí)現(xiàn)浮禾,generated_jit()裝飾器允許用戶在編譯期控制不同的特性的選擇。假定想要編寫一個函數(shù),基于某些需求盈电,返回所給定的值是否缺失的類型蝴簇,具體定義如下:
- 對于浮點(diǎn)數(shù),缺失的值為NaN挣轨。
- 對于Numpy的datetime64和timedelta64參數(shù)军熏,缺失值為NaT
- 其他類型沒有定義的缺失值
import numpy as np
from numba import generated_jit, types
@generated_jit(nopython=True)
def is_missing(x):
"""
Return True if the value is missing, False otherwise.
"""
if isinstance(x, types.Float):
return lambda x: np.isnan(x)
elif isinstance(x, (types.NPDatetime, types.NPTimedelta)):
# The corresponding Not-a-Time value
missing = x('NaT')
return lambda x: x == missing
else:
return lambda x: False
有以下幾點(diǎn)需要注意:
- 調(diào)用裝飾器函數(shù)是使用Numba的類型作為參數(shù),而不是他們的值卷扮。
- 裝飾器函數(shù)并不真的計(jì)算結(jié)果荡澎,而是返回一個對于給定類型,可調(diào)用的實(shí)際定義的函數(shù)執(zhí)行晤锹。
- 可以在編譯期預(yù)先計(jì)算一些數(shù)據(jù)摩幔,使其在編譯后執(zhí)行過程中重用。
- 函數(shù)定義使用和裝飾器函數(shù)中相同名字的參數(shù)鞭铆,這將確保通過名字傳遞參數(shù)能夠如期望的工作或衡。
使用@vectorize 裝飾器創(chuàng)建Numpy的 universal 函數(shù)
Numba的vectorize允許Python函數(shù)將標(biāo)量輸入?yún)?shù)作為Numpy的ufunc使用,將純Python函數(shù)編譯成ufunc车遂,使之速度與使用c編寫的傳統(tǒng)的ufunc函數(shù)一樣封断。
vectorize()有兩種操作模型:
- 主動,或者裝飾期間編譯:如果傳遞一個或者多個類型簽名給裝飾器舶担,就將構(gòu)建Numpy的universal function坡疼。后面將介紹使用裝飾期間編譯ufunc。
- 被動(惰性)衣陶,或者調(diào)用期間編譯:當(dāng)沒有提供任何簽名柄瑰,裝飾器將提供一個Numba動態(tài)universal function(DUFunc),當(dāng)一個未支持的新類型調(diào)用時剪况,就動態(tài)編譯一個新的內(nèi)核教沾,后面的“動態(tài) universal functions”將詳細(xì)介紹
如上所描述,如果傳遞一個簽名給vectorizer()裝飾器译断,函數(shù)將編譯成一個numpy 的ufunc:
from numba import vectorize, float64
@vectorize([float64(float64, float64)])
def f(x, y):
return x + y
如果想傳遞多個簽名授翻,注意順序,精度低的在前孙咪,高的在后藏姐,否則就會出奇怪的問題。例如int32就只能在int64之前该贾。
@vectorize([int32(int32, int32),
int64(int64, int64),
float32(float32, float32),
float64(float64, float64)])
def f(x, y):
return x + y
如果給定的類型正確:
>>> a = np.arange(6)
>>> f(a, a)
array([ 0, 2, 4, 6, 8, 10])
>>> a = np.linspace(0, 1, 6)
>>> f(a, a)
array([ 0. , 0.4, 0.8, 1.2, 1.6, 2. ])
如果提供了不支持的類型:
>>> a = np.linspace(0, 1+1j, 6)
>>> f(a, a)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
TypeError: ufunc 'ufunc' not supported for the input types, and the inputs could not be safely coerced to any supported types according to the casting rule ''safe''
vectorizer與jit裝飾器的差別:numpy的ufunc自動加載其他特性,例如:reduction, accumulation or broadcasting:
>>> a = np.arange(12).reshape(3, 4)
>>> a
array([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
>>> f.reduce(a, axis=0)
array([12, 15, 18, 21])
>>> f.reduce(a, axis=1)
array([ 6, 22, 38])
>>> f.accumulate(a)
array([[ 0, 1, 2, 3],
[ 4, 6, 8, 10],
[12, 15, 18, 21]])
>>> f.accumulate(a, axis=1)
array([[ 0, 1, 3, 6],
[ 4, 9, 15, 22],
[ 8, 17, 27, 38]])
vectorize() 裝飾器支持多個ufunc 目標(biāo):
Target | Description |
---|---|
cpu | Single-threaded CPU |
parallel | Multi-core CPU |
cuda | CUDA GPU |
guvectorize裝飾器只用了進(jìn)一步的概念捌臊,允許用戶編寫ufuncs操作輸入數(shù)組中的任意數(shù)量的元素杨蛋,返回不同緯度的數(shù)組。典型的應(yīng)用是運(yùn)行求均值或者卷積濾波。
Numba支持通過jitclass裝飾器實(shí)現(xiàn)對于類的代碼生成逞力∈锕眩可以使用這個裝飾器來標(biāo)注優(yōu)化,類中的所有方法都被編譯成nopython function寇荧。
import numpy as np
from numba import jitclass # import the decorator
from numba import int32, float32 # import the types
spec = [
('value', int32), # a simple scalar field
('array', float32[:]), # an array field
]
@jitclass(spec)
class Bag(object):
def __init__(self, value):
self.value = value
self.array = np.zeros(value, dtype=np.float32)
@property
def size(self):
return self.array.size
def increment(self, val):
for i in range(self.size):
self.array[i] = val
return self.array
性能建議
對于Numba提供的最靈活的jit裝飾器举庶,首先將嘗試使用no python模式編譯,如果失敗了揩抡,就再嘗試使用object模式編譯户侥,盡管使用object模式可以提高性能,但將函數(shù)在no python模式下編譯才是提升性能的關(guān)鍵峦嗤。想要直接使用nopython模式蕊唐,可以直接使用裝飾器@njit,這個裝飾器與@jit(nopython=True)等價烁设。
@njit
def ident_np(x):
return np.cos(x) ** 2 + np.sin(x) ** 2
@njit
def ident_loops(x):
r = np.empty_like(x)
n = len(x)
for i in range(n):
r[i] = np.cos(x[i]) ** 2 + np.sin(x[i]) ** 2
return r
Function Name | @njit | Execution time |
---|---|---|
ident_np | No | 0.581s |
ident_np | Yes | 0.659s |
ident_loops | No | 25.2s |
ident_loops | Yes | 0.670s |
有時候不那么嚴(yán)格的規(guī)定數(shù)據(jù)將會帶來性能的提升替梨,此時,惡意使用fastmath關(guān)鍵字參數(shù):
@njit(fastmath=False)
def do_sum(A):
acc = 0.
# without fastmath, this loop must accumulate in strict order
for x in A:
acc += np.sqrt(x)
return acc
@njit(fastmath=True)
def do_sum_fast(A):
acc = 0.
# with fastmath, the reduction can be vectorized as floating point
# reassociation is permitted.
for x in A:
acc += np.sqrt(x)
return acc
Function Name | Execution time |
---|---|
do_sum | 35.2 ms |
do_sum_fast | 17.8 ms |
Trubleshooting and tips
想要編譯什么装黑?
通常建議是編譯代碼中耗時最長的關(guān)鍵路徑副瀑,如果有一部分代碼耗時很長,但在一些高階的代碼之中恋谭,可能就需要重構(gòu)這些對于性能有更高要求的代碼到一個單獨(dú)的函數(shù)中糠睡,讓numba專注于這些對于性能敏感的代碼有以下好處:
- 避免遇見不支持的特性
- 減少編譯時間
- 在需要編譯的函數(shù)外,高階的代碼會更簡單
不想要編譯什么箕别?
numba編譯失敗的原因很多铜幽,最常見的一個原因就是你寫的代碼依賴于不支持的Python特性,尤其是nopython模式串稀,可以查看支持的python特性
在numba編譯代碼之前色瘩,先要確定所有使用的變量的類型锄奢,這樣就能生成你的代碼的特定類型的機(jī)器碼。一個常見的編譯失敗原因(尤其是nopython模式)就是類型推導(dǎo)失敗,numba不能確定代碼中所有變量的類型档痪。
例如:參考這個函數(shù):
@jit(nopython=True)
def f(x, y):
return x + y
如果使用兩個數(shù)字作為參數(shù):
>>> f(1,2)
3
如果傳入一個元組和一個數(shù)字,numba不能得到數(shù)字和元組求和的結(jié)果训堆,就會觸發(fā)編譯報錯:
>>> f(1, (2,))
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "<path>/numba/numba/dispatcher.py", line 339, in _compile_for_args
reraise(type(e), e, None)
File "<path>/numba/numba/six.py", line 658, in reraise
raise value.with_traceback(tb)
numba.errors.TypingError: Failed at nopython (nopython frontend)
Invalid usage of + with parameters (int64, tuple(int64 x 1))
Known signatures:
* (int64, int64) -> int64
* (int64, uint64) -> int64
* (uint64, int64) -> int64
* (uint64, uint64) -> uint64
* (float32, float32) -> float32
* (float64, float64) -> float64
* (complex64, complex64) -> complex64
* (complex128, complex128) -> complex128
* (uint16,) -> uint64
* (uint8,) -> uint64
* (uint64,) -> uint64
* (uint32,) -> uint64
* (int16,) -> int64
* (int64,) -> int64
* (int8,) -> int64
* (int32,) -> int64
* (float32,) -> float32
* (float64,) -> float64
* (complex64,) -> complex64
* (complex128,) -> complex128
* parameterized
[1] During: typing of intrinsic-call at <stdin> (3)
File "<stdin>", line 3:
錯誤信息“Invalid usage of + with parameters (int64, tuple(int64 x 1))”可以解釋為numba解釋器遇到了一個整數(shù)和元組中的整數(shù)求和做葵,
類型統(tǒng)一問題
另一個編譯失敗的常見原因是:不能靜態(tài)的決定返回的類型;返回值的類型僅僅依賴于運(yùn)行期华烟。這樣的事情也是僅僅發(fā)生在nopython 模式下翩迈。類型統(tǒng)一的概念僅僅只是嘗試找到一個類型,兩個變量能夠使用該類型安全的顯示盔夜;例如一個64位的浮點(diǎn)數(shù)和一個64位的復(fù)數(shù)可以同時使用128位的復(fù)數(shù)表示负饲。
以下是一個類型統(tǒng)一錯誤堤魁,這個函數(shù)的返回類型是基于x的值在運(yùn)行期決定的:
In [1]: from numba import jit
In [2]: @jit(nopython=True)
...: def f(x):
...: if x > 10:
...: return (1,)
...: else:
...: return 1
...:
嘗試執(zhí)行這個函數(shù),就會得到以下的錯誤:
In [3]: f(10)
TypingError: Failed at nopython (nopython frontend)
Can't unify return type from the following types: tuple(int64 x 1), int64
Return of: IR name '$8.2', type '(int64 x 1)', location:
File "<ipython-input-2-51ef1cc64bea>", line 4:
def f(x):
<source elided>
if x > 10:
return (1,)
^
Return of: IR name '$12.2', type 'int64', location:
File "<ipython-input-2-51ef1cc64bea>", line 6:
def f(x):
<source elided>
else:
return 1
錯誤信息: “Can’t unify return type from the following types: tuple(int64 x 1), int64” 可以理解為: “Numba cannot find a type that can safely represent a 1-tuple of integer and an integer”.
編譯的太慢
最常見的編譯速度很慢的原因是:nopython模式編譯失敗返十,然后嘗試使用object模式編譯妥泉。object模式當(dāng)前幾乎沒有提供加速特性,只是提供了一種叫做loop-lifting的優(yōu)化洞坑,這個優(yōu)化將允許使用nopython模式在內(nèi)聯(lián)迭代下編譯盲链。
可以在編譯好的函數(shù)上使用inspect_types()方法來查看函數(shù)的類型推導(dǎo)是否成功。例如迟杂,對于以下函數(shù):
@jit
def f(a, b):
s = a + float(b)
return s
當(dāng)使用numbers調(diào)用時刽沾,該函數(shù)將和numba一樣快速的將數(shù)字轉(zhuǎn)換為浮點(diǎn)數(shù):
>>> f(1, 2)
3.0
>>> f.inspect_types()
f (int64, int64)
--------------------------------------------------------------------------------
# --- LINE 7 ---
@jit
# --- LINE 8 ---
def f(a, b):
# --- LINE 9 ---
# label 0
# a.1 = a :: int64
# del a
# b.1 = b :: int64
# del b
# $0.2 = global(float: <class 'float'>) :: Function(<class 'float'>)
# $0.4 = call $0.2(b.1, ) :: (int64,) -> float64
# del b.1
# del $0.2
# $0.5 = a.1 + $0.4 :: float64
# del a.1
# del $0.4
# s = $0.5 :: float64
# del $0.5
s = a + float(b)
# --- LINE 10 ---
# $0.7 = cast(value=s) :: float64
# del s
# return $0.7
return s
關(guān)閉jit編譯
設(shè)定NUMBA_DISABLE_JIT 環(huán)境變量為 1.
FAQ
Q:能否傳遞一個函數(shù)作為參數(shù)?
A:不能逢慌,但可以使用閉包來模擬實(shí)現(xiàn)悠轩,例如:
@jit(nopython=True)
def f(g, x):
return g(x) + g(-x)
result = f(my_g_function, 1)
可以使用一個工廠函數(shù)重構(gòu):
def make_f(g):
# Note: a new f() is compiled each time make_f() is called!
@jit(nopython=True)
def f(x):
return g(x) + g(-x)
return f
f = make_f(my_g_function)
result = f(1)
Q:對于全局變量修改的問題
A:非常不建議使用全局變量,否則只能使用recompile()函數(shù)重新編譯攻泼,這樣還不如重構(gòu)代碼火架,不使用全局變量。
Q:如何調(diào)試jit的函數(shù)忙菠?
A:可以調(diào)用pdb何鸡,也可以臨時關(guān)閉編譯環(huán)境變量:NUMBA_DISABLE_JIT。
Q:如何增加整數(shù)的位寬
A:默認(rèn)情況下牛欢,numba為整形變量生成機(jī)器整形位寬骡男。我們可以使用np.int64為相關(guān)變量初始化(例如:np.int64(0)而不是0)。
Q:如何知道parallel=True已經(jīng)工作了傍睹?
A:如果parallel=True隔盛,設(shè)定環(huán)境變量NUMBA_WARNING為非0,所裝飾的函數(shù)轉(zhuǎn)換失敗拾稳,就顯示報警吮炕;同樣,環(huán)境變量:NUMBA_DEBUG_ARRAY_OPT_STAT將展示一些統(tǒng)計(jì)結(jié)果访得。