torch學習筆記3.2:實現自定義模組(cpu)
阿新 • • 發佈:2019-01-07
在使用torch時,如果想自己實現一個層,則可以按照《torch學習筆記1:實現自定義層》 中的方法來實現。但是如果想要實現一個比較複雜的網路,往往需要自己實現多個層(或類),並且有時可能需要重寫其他模組中已有的函式來達到自己的目的,如果還是在nn模組中新增,會比較混亂,並且不利於本地git倉庫統一管理,這個時候,我們可以自己實現一個像nn一樣的模組,在程式碼中使用時 require即可。
我們來實現一個名為nxn的自定義模組,以及它的cuda版本cunxn模組,其中包含一個自定義的Hello類(lua實現),ReLU類(分別用CPU和GPU實現)。
由於篇幅原因,這裡把torch自定義模組的lua實現,cpu實現,gpu實現分別寫一篇文章,本文介紹cpu實現的ReLU類。
3 檔案說明
ReLU.lua
local ReLU = torch.class('nxn.ReLU')
function ReLU:__init(gpucompatible)
self.gpucompatible=gpucompatible
if self.gpucompatible then
self.gradInput=torch.CudaTensor()
self.output=torch.CudaTensor()
else
self.gradInput=torch.Tensor()
self.output=torch.Tensor()
end
self.outputSave=self.output
self.gradInputSave=self.gradInput
end
function ReLU:updateOutput(input)
-- 呼叫cpp實現的ReLU函式
return input.nxn.ReLU_updateOutput(self, input)
end
function ReLU:updateGradInput(input, gradOutput)
-- 呼叫cpp實現的ReLU函式
return input.nxn.ReLU_updateGradInput(self, input, gradOutput)
end
function ReLU:getDisposableTensors()
return {self.output, self.gradInput, self.gradInputSave, self.outputSave}
end
ReLU.c
內容如下:
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/ReLU.c"
#else
static int nxn_(ReLU_updateOutput)(lua_State *L)
{
printf("CPU version of ReLU updateOutput function\n");
THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_Tensor);
THTensor_(resizeAs)(output, input);
TH_TENSOR_APPLY2(real, output, real, input, \
*output_data = *input_data > 0 ? *input_data : 0;)
return 1;
}
static int nxn_(ReLU_updateGradInput)(lua_State *L)
{
printf("CPU version of ReLU updateGradInput function\n");
THTensor *gradOutput = luaT_checkudata(L, 3, torch_Tensor);
THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_Tensor);
THTensor *gradInput = luaT_getfieldcheckudata(L, 1, "gradInput", torch_Tensor);
THTensor_(resizeAs)(gradInput, output);
TH_TENSOR_APPLY3(real, gradInput, real, gradOutput, real, output, \
*gradInput_data = *gradOutput_data * (*output_data > 0 ? 1 : 0););
return 1;
}
static const struct luaL_Reg nxn_(ReLU__) [] = {
{"ReLU_updateOutput", nxn_(ReLU_updateOutput)},
{"ReLU_updateGradInput", nxn_(ReLU_updateGradInput)},
{NULL, NULL}
};
static void nxn_(ReLU_init)(lua_State *L)
{
luaT_pushmetatable(L, torch_Tensor);
luaT_registeratname(L, nxn_(ReLU__), "nxn");
lua_pop(L,1);
}
#endif
init.c
在編譯安裝模組時CMakeLists.txt根據init.c找類檔案:
#include "TH.h"
#include "luaT.h"
#define torch_(NAME) TH_CONCAT_3(torch_, Real, NAME)
#define torch_Tensor TH_CONCAT_STRING_3(torch.,Real,Tensor)
#define nxn_(NAME) TH_CONCAT_3(nxn_, Real, NAME)
#include "generic/ReLU.c"
#include "THGenerateFloatTypes.h"
LUA_EXTERNC DLL_EXPORT int luaopen_libnxn(lua_State *L);
// 把cpp實現編譯到libnxn
int luaopen_libnxn(lua_State *L)
{
lua_newtable(L);
lua_pushvalue(L, -1);
lua_setfield(L, LUA_GLOBALSINDEX, "nxn");
nxn_FloatReLU_init(L);
nxn_DoubleReLU_init(L);
return 1;
}