1. 程式人生 > >cublasGemmEx函式應用-探究8bit矩陣乘

cublasGemmEx函式應用-探究8bit矩陣乘

介紹

cublasGemmEx 是CUDA8.0中cuBLAS新出的函式,是cublasgemm()類函式的擴充套件,也是目前來看功能最強大的矩陣乘函數了。該函式另一強大之處在於支援多種計算模式(compute type),其中就包括CUDA 8.0新出的FP16和INT8。但是該函式的文件並不太健全,最近在使用這個函式實現INT8矩陣乘的時候就碰見坑了,照著文件用就是報錯,找NVIDIA的工程師才給解決。下面總結一下使用經驗,把坑填上,以防大家再踩。

函式原型

cublasStatus_t cublasGemmEx(cublasHandle_t handle, 
                            cublasOperation_t transa, 
                            cublasOperation_t transb, 
                            int
m, int n, int k, const void *alpha, const void *A, cudaDataType_t Atype, int lda, const void *B, cudaDataType_t Btype, int
ldb, const void *beta, void *C, cudaDataType_t Ctype, int ldc, cudaDataType_t computeType, cublasGemmAlgo_t algo)

跟cublasSgemm長的比較像,但是多了這麼幾個引數,Atype,Btype,Ctype,computeType和algo。

這個函式的核心就是計算模式(computeType),computeType支援以下型別:

computeType 解釋
CUDA_R_16F FP16計算模式,輸入輸出都是FP16
CUDA_R_32F FP32計算模式,這個比較強大,輸入可以是FP16、INT8和FP32
CUDA_R_32I INT8計算模式,也是本文著重要講的模式
CUDA_R_64F FP64計算模式
CUDA_C_32F
CUDA_C_64F

每個computeType支援的輸入型別和輸出型別在cublasGemmEx文件中寫的非常清楚,照著用就行了。但是,有一個隱含的坑就在CUDA_R_32I計算模式裡。

正常按照 char *A, char *B, int *C是會報錯CUBLAS_STATUS_NOT_SUPPORTED,這個錯誤官方的解釋是“the combination of the parameters Atype, Btype and Ctype and the algorithm type, algo is not supported”,大概意思就是Atype,Btype,Ctype,和algo不匹配。但是明明是按文件上寫的啊,因為錯誤根本不在這裡。

解決辦法

錯誤的原因是,如果要使用CUDA_R_32I計算模式,那麼alpha和beta這兩個引數也必須是int型別且必須是0或者1……神坑啊。

PS:CUDA_R_32I計算模式下,cublasGemmAlgo_t 引數好像也只支援前7種,這個在文件裡也沒說。

CUDA_R_32I與CUDA_R_32F計算對比結果

這裡多說一點INT8矩陣乘計算模式吧,CUDA_R_32I計算模式裡呼叫CUDA 8.0新出的INT8計算介面-dp4a,按照官方的理論,dp4a這個函式會將四個char組合成一個int進行乘法運算,將4次乘法和3次加法減少為一次高階指令,從而提高效能。

我的實驗結果表明,CUDA_R_32I模式與CUDA_R_32F模式相比,最快能提高3.2倍(與矩陣的大小有關),同時能將資料壓縮75%,這是一個非常可觀的收益了。

但是FP32(float)量化成INT8(char)肯定是會有精度損失的,對INT8有興趣的可以關注NVIDIA新出的TensorRT2.0,該庫能夠在一些情況下保持較高的精度實現INT8加速。

TensorRT給的資料也比較少,坑也特別多,因此我開了一個TensorRT_Tutorial,歡迎志同道合者一起參與。

Figure 2: Illustration of the All-Reduce collective.