1. 程式人生 > 其它 >pytorch之C++實現自定義運算元

pytorch之C++實現自定義運算元

技術標籤:深度學習c++pythonjavalinux程式語言

自定義運算元

對於輸入 x,其輸出為

利用C++實現以上運算元,總共只要實現兩個檔案:

setup.py

利用python中提供的setuptools模組完成事先編譯流程,將寫有運算元的C++檔案,編譯成為一個動態連結庫(在Linux平臺是一個.so字尾檔案),可以讓python呼叫其中實現的函式功能。需要setup.py編寫如下

from setuptools import setup
from torch.utils import cpp_extension

setup(
    name='ncrelu_cpp',						# 編譯後的連結庫名稱
    ext_modules=[
        cpp_extension.CppExtension(
            'ncrelu_cpp', ['ncrelu.cpp']		       # 待編譯檔案,及編譯函式
        )
    ],
    cmdclass={						       # 執行編譯命令設定
        'build_ext': cpp_extension.BuildExtension
    }
)

這裡Pytorch提供了一個封裝cpp_extension,方便編譯過程中所需要的設定選項,以及所需包含的標頭檔案位置路徑設定等等。

ncrelu.cpp

接下來便是C++實現的函式ncrelu.cpp。首先上程式碼

#include <torch/extension.h>					// 標頭檔案引用部分 

torch::Tensor ncrelu_forward(torch::Tensor input) {
    auto pos = input.clamp_min(0);				       // 具體實現部分
    auto neg = input.clamp_max(0);
    return torch::cat({pos, neg}, 1);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {	// 繫結部分
    m.def("forward", &ncrelu_forward, "NCReLU forward");
}

以上程式碼包含了三個部分。分別是

標頭檔案引用部分:這裡包含了torch/extension.h標頭檔案,是編寫Pytorch的C++擴充套件時必須包含的一個檔案。它基本上囊括了實現中所需要的所有依賴,包含了ATen庫pybind11和二者之間的互動。其中ATen是Pytorch底層張量運算庫,負責實現具體張量操作運算;pybind11是實現C++程式碼到python的繫結(binding),可以在python裡呼叫C++函式。

具體實現部分:函式返回型別和傳遞引數型別均是torch::Tensor類,這種物件不僅包含了資料,還附屬了諸多運算操作。因此我們可以看到在下面實現方式類似於python中使用Pytorch張量運算操作一樣,可以直接呼叫如擷取操作clamp和拼接操作cat等

,非常簡潔已讀且方便。

繫結部分:只需要在m.def中傳入引數,分別是繫結到python的函式名稱,需繫結的C++函式引用,以及一個簡短的函式說明字串,用來新增到python函式中的__doc__成員名稱中。

將以上兩個檔案放在同一資料夾下,然後進行編譯。使用python setup.py build_ext --inplace命令,如果一切正常將會在資料夾下產生類似ncrelu_cpp.cpython-35m-x86_64-linux-gnu.so動態連結檔案。然後我們可以啟動python檢測是否可以匯入其中的函式

>> import torch
>> import ncrelu_cpp
>> a = torch.randn(4, 3)
>> a
tensor([[ 0.5921,  0.3207,  0.7690],
        [ 1.4514, -0.8942,  0.9039],
        [-0.3262, -0.1610,  0.6137],
        [ 0.7824, -1.8527,  0.2844]])
>> b = ncrelu_cpp.forward(a)
>> b
tensor([[ 0.5921,  0.3207,  0.7690,  0.0000,  0.0000,  0.0000],
        [ 1.4514,  0.0000,  0.9039,  0.0000, -0.8942,  0.0000],
        [ 0.0000,  0.0000,  0.6137, -0.3262, -0.1610,  0.0000],
        [ 0.7824,  0.0000,  0.2844,  0.0000, -1.8527,  0.0000]])