本文第一部分將講解如何在計(jì)算機(jī)上實(shí)現(xiàn)通用的矩陣乘法(General matrix multiply, GEMM)丹禀,第二部分講解神經(jīng)網(wǎng)絡(luò)加速包NNPACK基于NEON指令實(shí)現(xiàn)的矩陣乘法。
這是文章的第一部分筋粗。閱讀后讀者應(yīng)能了解計(jì)算機(jī)算矩陣乘法與我們自己筆算有何不同,如何根據(jù)這些不同來(lái)設(shè)計(jì)最基本的矩陣乘法算法嫉父,并擴(kuò)展成具有標(biāo)準(zhǔn)接口的函數(shù)壶冒,以及設(shè)計(jì)算法時(shí)值得注意之處。錯(cuò)漏之處歡迎指正禁炒。
1. 在計(jì)算機(jī)上實(shí)現(xiàn)矩陣乘法
首先回憶一下我們?cè)趺垂P算兩個(gè)矩陣相乘而咆。假設(shè)我們有一個(gè)8x12的A矩陣,一個(gè)12x16的矩陣B幕袱,他倆相乘暴备,得到8x16的矩陣C。我們會(huì)遍歷C矩陣的每一個(gè)位置凹蜂,比如當(dāng)我們想求C(3,4)這個(gè)位置的值馍驯,如上圖所示,應(yīng)該取A矩陣的第3行和B矩陣的第4列玛痊,求這兩個(gè)向量的內(nèi)積汰瘫,也就是把他倆各自的12個(gè)元素兩兩相乘然后相加:
C[2][3] = 0;
for (int k = 0; k < 12; k++) {
C[2][3] += A[2][k] * B[k][3]
}
整個(gè)矩陣相乘即是:
memset(C, 0, 8 * 16 * sizeof(float));
for (int i = 0; i < 8; i++) {
for (int j = 0; j < 16; j++) {
for (int k = 0; k < 12; k++) {
C[i][j] += A[i][k] * B[k][j];
}
}
}
這就是我們熟知的最普通但也最萬(wàn)能的公式了。如果這活交給計(jì)算機(jī)來(lái)做擂煞,跟我們筆算有什么不一樣混弥?
1.1 SIMD指令
現(xiàn)在不少CPU采用了單指令多數(shù)據(jù)技術(shù)(SIMD),一次可以對(duì)128位二進(jìn)制數(shù)據(jù)做一個(gè)相同操作对省。這就是說(shuō)蝗拿,過(guò)去我們的代碼,比如C[i][j] += A[i][k] * B[k][j];
蒿涎,每次運(yùn)算只操作一個(gè)32位的數(shù)據(jù)(float)哀托;但現(xiàn)在用上SIMD技術(shù),比如ARM芯片的NEON指令:
float32x4_t v1 = (float32x4_t) { 0.0f, 1.0f, 2.0f, 3.0f};
float32x4_t v2 = (float32x4_t) {-0.0f, -1.0f, -2.0f, -3.0f};
float32x4_t v3 = vaddq_f32(v1, v2); // v3 = { 0.0f, 0.0f, 0.0f, 0.0f}
float32x4_t v4 = vmulq_f32(v1, v2); // v4 = { 0.0f, -1.0f, -4.0f, -9.0f}
float32x4_t
就是一個(gè)由4個(gè)32位的float組成的數(shù)據(jù)類(lèi)型劳秋,對(duì)它做一次操作仓手,4個(gè)float都被用到胖齐。vaddq_f32
函數(shù)讓CPU只需要1次運(yùn)算,就能算出v1
和v2
的4個(gè)對(duì)應(yīng)元素相加的結(jié)果嗽冒,然后存到v3
里呀伙;vmulq_f32
函數(shù)同樣只需要1次運(yùn)算,就能得到v1
和v2
對(duì)應(yīng)元素相乘的結(jié)果添坊。
1.2 多線(xiàn)程并行計(jì)算
矩陣乘法有一個(gè)特點(diǎn):對(duì)于8x16的C矩陣剿另,假設(shè)我們有8x16個(gè)人,他們每個(gè)人負(fù)責(zé)算C矩陣一個(gè)元素的值贬蛙,那么他們的任務(wù)將是相互獨(dú)立雨女、互不影響的,因?yàn)樗麄冎恍枰谕粔K內(nèi)存上取數(shù)據(jù)速客,然后各自算各自的戚篙,算完了再寫(xiě)到不同位置上去。有些人算得快溺职,有些算得慢岔擂;有些馬上開(kāi)始算,有些睡了一天才開(kāi)始浪耘。但這些都不會(huì)影響最終結(jié)果的正確性乱灵,畢竟有獨(dú)立性。
現(xiàn)在把人換成CPU的核七冲。假設(shè)它有8x16個(gè)核痛倚,每個(gè)核各跑1個(gè)線(xiàn)程,就可以讓每個(gè)線(xiàn)程負(fù)責(zé)算C矩陣的一個(gè)元素澜躺;假設(shè)它只有2個(gè)核蝉稳、2個(gè)線(xiàn)程,那么每個(gè)線(xiàn)程負(fù)責(zé)算4x16個(gè)元素掘鄙,或者讓一個(gè)線(xiàn)程只算1個(gè)元素耘戚、另一個(gè)線(xiàn)程算8x16-1個(gè)元素,最后的結(jié)果都是對(duì)的操漠。至于算得快不快收津,就看線(xiàn)程池任務(wù)調(diào)度合理不合理了。
總之浊伙,計(jì)算機(jī)可以并行地算矩陣乘法撞秋。于CPU而言,可以在它的每一個(gè)核上創(chuàng)建一個(gè)線(xiàn)程嚣鄙,哪個(gè)線(xiàn)程閑著就給它派個(gè)獨(dú)立的小任務(wù)吻贿,所有小任務(wù)做完了矩陣乘法也就算好了。如果是GPU哑子,它可能有成百上千個(gè)核舅列,那更得把任務(wù)拆散了派發(fā)下去奉芦。
1.3 算法怎么實(shí)現(xiàn)
請(qǐng)牢記,當(dāng)我們?cè)谟?jì)算機(jī)上做矩陣乘法的時(shí)候剧蹂,一是可以用SIMD指令(比如ARM芯片的NEON),在同樣的時(shí)間內(nèi)多算幾個(gè)數(shù)烦却;二是可以在多核心的CPU上用多個(gè)線(xiàn)程并行計(jì)算宠叼,當(dāng)然用GPU就更棒了。接下來(lái)就看算法怎么寫(xiě)其爵。
因?yàn)镹EON指令集其他函數(shù)沒(méi)有那么顧名思義冒冬,后文中我們將沿用其數(shù)據(jù)類(lèi)型float32x4_t,但不再直接用其函數(shù)名∧γ欤現(xiàn)在定義以下顧名思義的函數(shù):
float32x4_t vget(float *src);
float32x4_t vdup(float num);
void write(float *dst, float32x4_t vec);
float32x4_t vadd(float32x4_t v1, float32x4_t v2);
float32x4_t vmul(float32x4_t v1, float32x4_t v2);
void svv_mul_add(float32x4_t v0, float32x4_t v1, float32x4_t v2, float s1);
void vvv_mul_add(float32x4_t v0, float32x4_t v1, float32x4_t v2, float32x4_t v3);
-
vget
函數(shù):從地址src
那里取4個(gè)float简烤,組成一個(gè)float32x4_t并返回 -
vdup
函數(shù):直接輸入一個(gè)float,把它復(fù)制粘貼4次摇幻,組成一個(gè)float32x4_t并返回 -
write
函數(shù):把一個(gè)float32x4_t寫(xiě)到地址dst
去横侦,相當(dāng)于一次寫(xiě)入4個(gè)float -
vadd
和vmul
函數(shù):兩個(gè)函數(shù)分別返回v1
、v2
對(duì)應(yīng)元素相加绰姻、相乘的結(jié)果 -
svv_mul_add
函數(shù):取float32x4_t型的v1
枉侧、v2
和float型的s1
,然后讓v1
的每一個(gè)元素都乘上s1
狂芋,將其結(jié)果與v2
對(duì)應(yīng)位置的元素相加榨馁,寫(xiě)到同為float32x4_t 型的v0
-
vvv_mul_add
函數(shù):取float32x4_t型的v1
、v2
和v3
帜矾,然后讓v1
和v2
每一個(gè)對(duì)應(yīng)元素相乘翼虫,再與v3
的每一個(gè)對(duì)應(yīng)元素相加,寫(xiě)到同為float32x4_t 型的v0
(以上函數(shù)對(duì)應(yīng)的NEON指令分別是vld1q_f32
屡萤、vdupq_n_f32
珍剑、vst1q_f32
、vaddq_f32
灭衷、vmulq_f32
次慢、vfmaq_lane_f32
和vfmaq_f32
;在其他指令集中應(yīng)該也有對(duì)應(yīng)的函數(shù))
如上圖所示翔曲,如果我們要求C(1,5)到C(1,8)這4個(gè)點(diǎn)的值迫像,就不再需要4x12個(gè)循環(huán),而只需要12個(gè)瞳遍。第一種寫(xiě)法如下:
float32x4_t ret = vdup(0.0f);
for (int k = 0; k < 12; k++) {
svv_mul_add(ret, A[0][k], vget(&B[k][4]), ret);
}
write(ret, &C[0][4]);
同樣也可以用vvv_mul_add
函數(shù):
float32x4_t ret = vdup(0.0f);
for (int k = 0; k < 12; k++) {
vvv_mul_add(ret, vdup(A[0][k]), vget(&B[k][4]), ret);
}
write(ret, &C[0][4]);
這樣我們需要for循環(huán)執(zhí)行的次數(shù)就變成原來(lái)的1/4闻妓。不過(guò),試想接下來(lái)如果我們要求C(2,5)到C(2,8)這4個(gè)點(diǎn)的值掠械,就又需要一個(gè)for循環(huán)由缆,重新取一遍B矩陣第5到第8列的所有值注祖,與A矩陣第二列相乘。這個(gè)取值也是有時(shí)間成本的均唉,應(yīng)當(dāng)盡量避免是晨。那我們不妨這樣:
float32x4_t vc0 = vdup(0.0f);
float32x4_t vc1 = vdup(0.0f);
float32x4_t vc2 = vdup(0.0f);
float32x4_t vc3 = vdup(0.0f);
for (int k = 0; k < 12; k++) {
float32x4_t vb = vget(&B[k][4]);
vvv_mul_add(vc0, vdup(A[0][k]), vb, vc0);
vvv_mul_add(vc1, vdup(A[1][k]), vb, vc1);
vvv_mul_add(vc2, vdup(A[2][k]), vb, vc2);
vvv_mul_add(vc3, vdup(A[3][k]), vb, vc3);
}
write(vc0, &C[0][4]);
write(vc1, &C[1][4]);
write(vc2, &C[2][4]);
write(vc3, &C[3][4]);
改寫(xiě)后的代碼,for循環(huán)會(huì)取遍A矩陣第1到第4行舔箭、B矩陣第5到第8列的所有值罩缴,算出C矩陣紅色區(qū)域內(nèi)的16的元素。后文中我會(huì)把這樣的情況叫做每次算出C矩陣一個(gè)4x4的塊(block)层扶。這樣改寫(xiě)并不會(huì)減少乘法和加法的計(jì)算次數(shù)箫章,但能把對(duì)B矩陣取值的次數(shù)減少到原來(lái)的1/4,因?yàn)槊看稳〕鰜?lái)的值都被用了4次镜会。
是不是取A矩陣取得越多列越好呢檬寂?如果每次取8列,對(duì)B矩陣的取值次數(shù)不就只有原來(lái)的1/8了嗎戳表?每次取10000列不就……同樣地桶至,如果B矩陣每次取8列,不就可以把對(duì)A矩陣取值的次數(shù)減到原來(lái)的1/2了嗎扒袖?每次取10000……
這樣想確實(shí)沒(méi)什么大毛病塞茅。我也試過(guò),每次算一個(gè)8x8的塊確實(shí)比算4x4更快季率。不過(guò)我們現(xiàn)在舉的例子都是比較簡(jiǎn)單的情況野瘦,即A矩陣的行數(shù)、B矩陣的列數(shù)都是4或者8的整數(shù)倍飒泻,如果是更一般的情況鞭光,即不是整數(shù)倍、存在余數(shù)泞遗,或者干脆小于4或8惰许,這些部分處理起來(lái)是很麻煩的,需要大量的判斷語(yǔ)句(if...else
, switch...case
)史辙,這也是會(huì)耗時(shí)間的汹买,可能得不償失。
如果塊取得太大聊倔,比如取到了16x16晦毙,那么A矩陣的行數(shù)、B矩陣的列數(shù)就各有15/16的幾率不是16的整數(shù)倍耙蔑。如果計(jì)算兩個(gè)17x17的方陣相乘见妒,C矩陣將被劃分成4個(gè)塊(尺寸分別是16x16,16x1甸陌,1x16和1x1)须揣;只有其中1個(gè)塊滿(mǎn)16x16盐股,可以用類(lèi)似上面的很簡(jiǎn)潔的代碼算出來(lái);計(jì)算另外3個(gè)塊(占75%)都需要大量的判斷語(yǔ)句耻卡,確保取值和賦值不會(huì)過(guò)界疯汁,這就造成大量時(shí)間浪費(fèi)。但如果取的是4x4的塊卵酪,C矩陣被劃分成25個(gè)塊涛目,只有其中9個(gè)塊不滿(mǎn)4x4(占36%)需要判斷語(yǔ)句。
還有一種處理方式就是補(bǔ)0凛澎。如果使用了16x16的塊,就用0來(lái)把A矩陣的行數(shù)估蹄、B矩陣的列數(shù)補(bǔ)成16的整數(shù)倍塑煎。最后算出來(lái)的C矩陣周?chē)邪肴Φ?,保證其行數(shù)臭蚁、列數(shù)都是16的倍數(shù)最铁,于是還需要去掉這些0。這樣倒是不需要大量的判斷語(yǔ)句了垮兑,但這來(lái)回來(lái)去的倒騰也是很耗時(shí)間的冷尉。
所以最合適的塊的大小究竟是多少,應(yīng)該通過(guò)測(cè)試來(lái)找系枪,還要參考實(shí)際的業(yè)務(wù)需求雀哨。
1.4 擴(kuò)展成標(biāo)準(zhǔn)接口
至此,算法的輪廓已經(jīng)隱約可見(jiàn)了:
- 確定每次算C矩陣一個(gè)多大的塊
- 把計(jì)算每一塊作為一個(gè)小任務(wù)私爷,通過(guò)線(xiàn)程池分發(fā)任務(wù)
- 等待所有小任務(wù)執(zhí)行完畢即可
需要注意的一個(gè)是選擇多大的塊雾棺,一個(gè)是處理邊緣上那些不滿(mǎn)的塊,取值衬浑、賦值的時(shí)候都要判斷是否超出范圍捌浩。這樣就可以完成一個(gè)最基本的C = A * B
的算法。
很多矩陣運(yùn)算庫(kù)定義的矩陣乘法是這樣的:
gemm(const CBLAS_TRANSPOSE TransA,
const CBLAS_TRANSPOSE TransB,
const int M,
const int N,
const int K,
const float alpha,
const float *A,
const float *B,
const float beta,
float *C);
它們計(jì)算的是這樣的式子:
C = alpha * op(A) * op(B) + beta * C
op
的意思是相乘之前可以要求先對(duì)這個(gè)矩陣轉(zhuǎn)置工秩,也就是調(diào)用gemm
函數(shù)時(shí)前兩個(gè)參數(shù)可以是trans
或者noTrans
尸饺;alpha
和beta
是兩個(gè)常數(shù),也就是要求矩陣的每個(gè)元素都要乘上一個(gè)常數(shù)助币。
再看我們的算法浪听,如果要求考慮A、B矩陣事先轉(zhuǎn)置的情況奠支,就得修改取值的代碼馋辈。比如原來(lái)對(duì)B矩陣的取值是連續(xù)取4個(gè)值:
float32x4_t vb = vget(&B[k][4]);
當(dāng)要求B矩陣轉(zhuǎn)置的時(shí)候,就得這樣:
float32x4_t vb = (float32x4_t) {B[k+0][4], B[k+1][4], B[k+2][4], B[k+3][4]};
不連續(xù)取值可能會(huì)降低效率倍谜,或許在某些情況下還不如用別的代碼迈螟,比如iOS可以用vDsp_mTrans
叉抡,先把B矩陣轉(zhuǎn)置一下,再像從前一樣連續(xù)取值答毫。
另外我們?cè)瓉?lái)的賦值語(yǔ)句是這樣寫(xiě)的:
write(vc0, &C[0][4]);
write(vc1, &C[1][4]);
write(vc2, &C[2][4]);
write(vc3, &C[3][4]);
考慮alpha
和beta
時(shí)褥民,需要改寫(xiě)成:
float32x4_t valpha = vdup(alpha);
float32x4_t vbeta = vdup(beta);
write(vadd(vmul(vc0, valpha), vmul(vget(&C[0][4]), vbeta)), &C[0][4]);
write(vadd(vmul(vc1, valpha), vmul(vget(&C[1][4]), vbeta)), &C[1][4]);
write(vadd(vmul(vc2, valpha), vmul(vget(&C[2][4]), vbeta)), &C[2][4]);
write(vadd(vmul(vc3, valpha), vmul(vget(&C[3][4]), vbeta)), &C[3][4]);
經(jīng)過(guò)這些修改,即可獲得一個(gè)具有標(biāo)準(zhǔn)接口的gemm
函數(shù)洗搂。