1. 程式人生 > >torch學習筆記3.2:實現自定義模組(cpu)

torch學習筆記3.2:實現自定義模組(cpu)

在使用torch時,如果想自己實現一個層,則可以按照《torch學習筆記1:實現自定義層》 中的方法來實現。但是如果想要實現一個比較複雜的網路,往往需要自己實現多個層(或類),並且有時可能需要重寫其他模組中已有的函式來達到自己的目的,如果還是在nn模組中新增,會比較混亂,並且不利於本地git倉庫統一管理,這個時候,我們可以自己實現一個像nn一樣的模組,在程式碼中使用時 require即可。

我們來實現一個名為nxn的自定義模組,以及它的cuda版本cunxn模組,其中包含一個自定義的Hello類(lua實現),ReLU類(分別用CPU和GPU實現)。

由於篇幅原因,這裡把torch自定義模組的lua實現,cpu實現,gpu實現分別寫一篇文章,本文介紹cpu實現的ReLU類。

3 檔案說明

ReLU.lua

local ReLU = torch.class('nxn.ReLU')

function ReLU:__init(gpucompatible)

   self.gpucompatible=gpucompatible

   if self.gpucompatible then
      self.gradInput=torch.CudaTensor()
      self.output=torch.CudaTensor()
   else
      self.gradInput=torch.Tensor()
      self.output=torch.Tensor()
   end
self.outputSave=self.output self.gradInputSave=self.gradInput end function ReLU:updateOutput(input) -- 呼叫cpp實現的ReLU函式 return input.nxn.ReLU_updateOutput(self, input) end function ReLU:updateGradInput(input, gradOutput) -- 呼叫cpp實現的ReLU函式 return input.nxn.ReLU_updateGradInput(self, input, gradOutput) end
function ReLU:getDisposableTensors() return {self.output, self.gradInput, self.gradInputSave, self.outputSave} end

ReLU.c

內容如下:

#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/ReLU.c"
#else

static int nxn_(ReLU_updateOutput)(lua_State *L)
{
  printf("CPU version of ReLU updateOutput function\n");
  THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
  THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_Tensor);

  THTensor_(resizeAs)(output, input);

  TH_TENSOR_APPLY2(real, output, real, input,         \
                   *output_data = *input_data > 0 ? *input_data : 0;)

  return 1;
}

static int nxn_(ReLU_updateGradInput)(lua_State *L)
{
  printf("CPU version of ReLU updateGradInput function\n");
  THTensor *gradOutput = luaT_checkudata(L, 3, torch_Tensor);
  THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_Tensor);
  THTensor *gradInput = luaT_getfieldcheckudata(L, 1, "gradInput", torch_Tensor);

  THTensor_(resizeAs)(gradInput, output);
  TH_TENSOR_APPLY3(real, gradInput, real, gradOutput, real, output,     \
                   *gradInput_data = *gradOutput_data * (*output_data > 0 ? 1 : 0););
  return 1;
}

static const struct luaL_Reg nxn_(ReLU__) [] = {
  {"ReLU_updateOutput", nxn_(ReLU_updateOutput)},
  {"ReLU_updateGradInput", nxn_(ReLU_updateGradInput)},
  {NULL, NULL}
};

static void nxn_(ReLU_init)(lua_State *L)
{
  luaT_pushmetatable(L, torch_Tensor);
  luaT_registeratname(L, nxn_(ReLU__), "nxn");
  lua_pop(L,1);
}

#endif

init.c

在編譯安裝模組時CMakeLists.txt根據init.c找類檔案:

#include "TH.h"
#include "luaT.h"

#define torch_(NAME) TH_CONCAT_3(torch_, Real, NAME)
#define torch_Tensor TH_CONCAT_STRING_3(torch.,Real,Tensor)
#define nxn_(NAME) TH_CONCAT_3(nxn_, Real, NAME)

#include "generic/ReLU.c"
#include "THGenerateFloatTypes.h"


LUA_EXTERNC DLL_EXPORT int luaopen_libnxn(lua_State *L);
// 把cpp實現編譯到libnxn
int luaopen_libnxn(lua_State *L)
{
  lua_newtable(L);
  lua_pushvalue(L, -1);
  lua_setfield(L, LUA_GLOBALSINDEX, "nxn");

  nxn_FloatReLU_init(L);
  nxn_DoubleReLU_init(L);

  return 1;
}