pytorch之C++實現自定義運算元
技術標籤:深度學習c++pythonjavalinux程式語言
自定義運算元
對於輸入 x,其輸出為
利用C++實現以上運算元,總共只要實現兩個檔案:
setup.py
利用python中提供的setuptools模組完成事先編譯流程,將寫有運算元的C++檔案,編譯成為一個動態連結庫(在Linux平臺是一個.so字尾檔案),可以讓python呼叫其中實現的函式功能。需要setup.py
編寫如下
from setuptools import setup from torch.utils import cpp_extension setup( name='ncrelu_cpp', # 編譯後的連結庫名稱 ext_modules=[ cpp_extension.CppExtension( 'ncrelu_cpp', ['ncrelu.cpp'] # 待編譯檔案,及編譯函式 ) ], cmdclass={ # 執行編譯命令設定 'build_ext': cpp_extension.BuildExtension } )
這裡Pytorch提供了一個封裝cpp_extension
,方便編譯過程中所需要的設定選項,以及所需包含的標頭檔案位置路徑設定等等。
ncrelu.cpp
接下來便是C++實現的函式ncrelu.cpp
。首先上程式碼
#include <torch/extension.h> // 標頭檔案引用部分 torch::Tensor ncrelu_forward(torch::Tensor input) { auto pos = input.clamp_min(0); // 具體實現部分 auto neg = input.clamp_max(0); return torch::cat({pos, neg}, 1); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // 繫結部分 m.def("forward", &ncrelu_forward, "NCReLU forward"); }
以上程式碼包含了三個部分。分別是
標頭檔案引用部分:這裡包含了torch/extension.h
標頭檔案,是編寫Pytorch的C++擴充套件時必須包含的一個檔案。它基本上囊括了實現中所需要的所有依賴,包含了ATen庫,pybind11和二者之間的互動。其中ATen是Pytorch底層張量運算庫,負責實現具體張量操作運算;pybind11是實現C++程式碼到python的繫結(binding),可以在python裡呼叫C++函式。
具體實現部分:函式返回型別和傳遞引數型別均是torch::Tensor
類,這種物件不僅包含了資料,還附屬了諸多運算操作。因此我們可以看到在下面實現方式類似於python中使用Pytorch張量運算操作一樣,可以直接呼叫如擷取操作clamp和拼接操作cat等
繫結部分:只需要在m.def
中傳入引數,分別是繫結到python的函式名稱,需繫結的C++函式引用,以及一個簡短的函式說明字串,用來新增到python函式中的__doc__
成員名稱中。
將以上兩個檔案放在同一資料夾下,然後進行編譯。使用python setup.py build_ext --inplace
命令,如果一切正常將會在資料夾下產生類似ncrelu_cpp.cpython-35m-x86_64-linux-gnu.so
動態連結檔案。然後我們可以啟動python檢測是否可以匯入其中的函式
>> import torch
>> import ncrelu_cpp
>> a = torch.randn(4, 3)
>> a
tensor([[ 0.5921, 0.3207, 0.7690],
[ 1.4514, -0.8942, 0.9039],
[-0.3262, -0.1610, 0.6137],
[ 0.7824, -1.8527, 0.2844]])
>> b = ncrelu_cpp.forward(a)
>> b
tensor([[ 0.5921, 0.3207, 0.7690, 0.0000, 0.0000, 0.0000],
[ 1.4514, 0.0000, 0.9039, 0.0000, -0.8942, 0.0000],
[ 0.0000, 0.0000, 0.6137, -0.3262, -0.1610, 0.0000],
[ 0.7824, 0.0000, 0.2844, 0.0000, -1.8527, 0.0000]])