CUDA9.0中GEMM接口不支持?jǐn)?shù)據(jù)按列存儲(chǔ)(即橫向排布)平匈,當(dāng)只有C橫向排布時(shí)會(huì)報(bào)第15個(gè)值錯(cuò)誤,當(dāng)A牍氛、B橫向排布時(shí)晨继,雖然不會(huì)報(bào)錯(cuò),但后續(xù)操作會(huì)訪(fǎng)存越界搬俊。
接口如下
/** cublasStatus_t cublasSgemmStridedBatched(cublasHandle_t handle,
* cublasOperation_t transa, cublasOperation_t transb,
* int m, int n, int k,
* const float *alpha,
* const float *A, int lda, long long int strideA,
* const float *B, int ldb, long long int strideB,
* const float *beta,
* float *C, int ldc, long long int strideC,
* int batchCount)
*/
其中紊扬,A,B,C內(nèi)部進(jìn)行運(yùn)算的小矩陣分別是MxK, KxN, MxN大小,TRANSA, TRANSB表示是否使用對(duì)應(yīng)矩陣的轉(zhuǎn)置唉擂,ALPHA, BETA為對(duì)應(yīng)的系數(shù)餐屎。而LDA, LDB, LDC表示對(duì)應(yīng)矩陣的leading dimension,即第一維度的大小玩祟。LDA表示一個(gè)batch中一個(gè)矩陣行的長(zhǎng)度, 因?yàn)榫仃囋趦?nèi)存中是連續(xù)存放的腹缩,而這個(gè)leading dimension的量用來(lái)定義元素?fù)Q行后的位置,即A[i, j] = A + i*lda + j
空扎。
C++矩陣默認(rèn)行優(yōu)先存儲(chǔ)藏鹊,BLAS庫(kù)默認(rèn)列優(yōu)先存儲(chǔ),所以?xún)蓚€(gè)矩陣要反過(guò)來(lái)輸入(無(wú)需指定trans=CUBLAS_OP_T)转锈,transa就是對(duì)應(yīng)于第一個(gè)矩陣是否轉(zhuǎn)置盘寡。不過(guò)如果反著輸入,且兩個(gè)trans都為CUBLAS_OP_N, 則我們令MxK為第一個(gè)輸入中小矩陣形狀, KxN為第二個(gè)輸入中小矩陣形狀, 然后將N,M,K對(duì)應(yīng)于接口的m,n,k進(jìn)行輸入撮慨。如果其中trans為T(mén)竿痰,則m,n,k用來(lái)指定其小矩陣轉(zhuǎn)置后的形狀,表示實(shí)際參與運(yùn)算的矩陣大小砌溺,對(duì)應(yīng)于邏輯上的AxB(而不是反著輸入的B和A)影涉,但trans的順序是對(duì)應(yīng)反著輸入的兩個(gè)矩陣的。
// 以BERT中Q*K'為例
// 對(duì)于K, 大矩陣是lda * m(ldQKV * S), 小矩陣轉(zhuǎn)置后是k * m(headSize * S), 轉(zhuǎn)置只是轉(zhuǎn)置小矩陣
// 對(duì)于Q, 大矩陣是ldb * n(ldQKV * S), 小矩陣轉(zhuǎn)置后是k * n(headSize x S)
CUBLAS_CHECK(cublasGemmStridedBatched<T>(cublas, CUBLAS_OP_T, CUBLAS_OP_N, S, S, headSize, 1.f, kptr, ldQKV, strideQKV,
qptr, ldQKV, strideQKV, 0.f, qkptr, S, omatSize, numMats));
參考https://blog.csdn.net/feng__shuai/article/details/107091684