cublasGemmEx函式應用-探究8bit矩陣乘
介紹
cublasGemmEx 是CUDA8.0中cuBLAS新出的函式,是cublasgemm()類函式的擴充套件,也是目前來看功能最強大的矩陣乘函數了。該函式另一強大之處在於支援多種計算模式(compute type),其中就包括CUDA 8.0新出的FP16和INT8。但是該函式的文件並不太健全,最近在使用這個函式實現INT8矩陣乘的時候就碰見坑了,照著文件用就是報錯,找NVIDIA的工程師才給解決。下面總結一下使用經驗,把坑填上,以防大家再踩。
函式原型
cublasStatus_t cublasGemmEx(cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const void *alpha,
const void *A,
cudaDataType_t Atype,
int lda,
const void *B,
cudaDataType_t Btype,
int ldb,
const void *beta,
void *C,
cudaDataType_t Ctype,
int ldc,
cudaDataType_t computeType,
cublasGemmAlgo_t algo)
跟cublasSgemm長的比較像,但是多了這麼幾個引數,Atype,Btype,Ctype,computeType和algo。
這個函式的核心就是計算模式(computeType),computeType支援以下型別:
computeType | 解釋 |
---|---|
CUDA_R_16F | FP16計算模式,輸入輸出都是FP16 |
CUDA_R_32F | FP32計算模式,這個比較強大,輸入可以是FP16、INT8和FP32 |
CUDA_R_32I | INT8計算模式,也是本文著重要講的模式 |
CUDA_R_64F | FP64計算模式 |
CUDA_C_32F | |
CUDA_C_64F |
每個computeType支援的輸入型別和輸出型別在cublasGemmEx文件中寫的非常清楚,照著用就行了。但是,有一個隱含的坑就在CUDA_R_32I計算模式裡。
正常按照 char *A, char *B, int *C是會報錯CUBLAS_STATUS_NOT_SUPPORTED,這個錯誤官方的解釋是“the combination of the parameters Atype, Btype and Ctype and the algorithm type, algo is not supported”,大概意思就是Atype,Btype,Ctype,和algo不匹配。但是明明是按文件上寫的啊,因為錯誤根本不在這裡。
解決辦法
錯誤的原因是,如果要使用CUDA_R_32I計算模式,那麼alpha和beta這兩個引數也必須是int型別且必須是0或者1……神坑啊。
PS:CUDA_R_32I計算模式下,cublasGemmAlgo_t 引數好像也只支援前7種,這個在文件裡也沒說。
CUDA_R_32I與CUDA_R_32F計算對比結果
這裡多說一點INT8矩陣乘計算模式吧,CUDA_R_32I計算模式裡呼叫CUDA 8.0新出的INT8計算介面-dp4a,按照官方的理論,dp4a這個函式會將四個char組合成一個int進行乘法運算,將4次乘法和3次加法減少為一次高階指令,從而提高效能。
我的實驗結果表明,CUDA_R_32I模式與CUDA_R_32F模式相比,最快能提高3.2倍(與矩陣的大小有關),同時能將資料壓縮75%,這是一個非常可觀的收益了。
但是FP32(float)量化成INT8(char)肯定是會有精度損失的,對INT8有興趣的可以關注NVIDIA新出的TensorRT2.0,該庫能夠在一些情況下保持較高的精度實現INT8加速。
TensorRT給的資料也比較少,坑也特別多,因此我開了一個TensorRT_Tutorial,歡迎志同道合者一起參與。