1. 程式人生 > >Torch 中的引用、深拷貝 以及 getParameters 獲取引數的探討

Torch 中的引用、深拷貝 以及 getParameters 獲取引數的探討

Preface

這段時間一直在苦練 Torch,我是把 Torch 當作深度學習裡面的 Matlab 來用了。但最近碰到個兩個坑,把我坑的蠻慘。

一個是關於 Torch 中賦值引用、深拷貝的問題,另一個是關於 getParameters() 獲取引數引發的問題。

第一個坑: Torch 中的引用賦值 以及 深拷貝

第一個坑是,Torch 中的輸入 Tensor,經過一個網路層之後,輸出的結果,其 Size 會隨著之後輸入資料 Size 的變化而變化。這個問題,我在知乎上提問了:https://www.zhihu.com/question/48986099中科院自動化所的博士@feanfrog

給我做了解答。在這裡,我再總結一下。

輸入一個數據,比如說隨機生成:

x1 = torch.randn(3, 128, 128))

經過一網路 convNet,如卷積層:

convNet = nn.SpatialConvolution(3,64, 3,3, 1,1, 1,1))

進行 forward 之後,其輸出結果為 y1,其 Size64×128×128 .

這裡寫圖片描述

但輸出的這個 Size 大小,會隨著之後的輸入資料的 Size 的變化而變化!這是詭異的地方……

如又輸入:

x2 = torch.randn(5, 3, 128, 128) 

這個 x2 經過 convNet:forward

>y2 = convNet:forward(x2)

>y2:size()
  5
 64
128
128
[torch.LongStorage of size 4]

這個 y2size5×64×128×128 ,這是應該的。
但是請看 y1:size()

這裡寫圖片描述

y1size 也變成了 5×64×128×128 !

我百思不得其解,想了兩天,都沒找出願意(太渣了……),實在沒有辦法,就到知乎上提問了,結果真有大神 @beanfrog 給我解答了:

這裡寫圖片描述

原來 Torch 中為了提高速度,model:forward() 操作之後賦予的變數是不給這個變數開盤新的儲存空間的,而是 引用

。就相當於 起了個別名

不光這裡,torch裡面向量或是矩陣的賦值是指向同一記憶體的,這種策略不同於 Matlab。如果想不想引用,可以用 clone() 進行 深拷貝,如下的例子:

這裡寫圖片描述

當改變變數 v 的第一個元素的值時,變數 t 也隨之變化。

第二個坑: Torch 中 getParameters 獲取引數引起的疑惑

當有如下的程式碼:

require 'nn'

local convNet = nn.Sequential()
convNet:add(nn.Linear(2, 3))
convNet:add(nn.Tanh())

local convNet2 = convNet:clone('weight', 'bias', 'gradWeight', 'gradBias')
local params, gradParams = convNet:getParameters()

params:fill(0)
print(convNet2:get(1).weight)

此時的輸出為:

這裡寫圖片描述

感覺輸出結果很顯然的樣子。但當將上述的程式碼做一下微調:
require 'nn'

local convNet = nn.Sequential()
convNet:add(nn.Linear(2, 3))
convNet:add(nn.Tanh())

local params, gradParams = convNet:getParameters()
local convNet2 = convNet:clone('weight', 'bias', 'gradWeight', 'gradBias')

params:fill(0)
print(convNet2:get(1).weight)

輸出的結果為:

這裡寫圖片描述

僅僅將程式碼中下面的兩行做了對調:

local convNet2 = convNet:clone('weight', 'bias', 'gradWeight', 'gradBias')
local params, gradParams = convNet:getParameters()

結果就是兩種不同的結果。


看一下別人的解釋:

This is not really a bug though. getParameters() is a bit subtle, and should be documented properly.

It gets and flattens all the parameters of any given module, and insures that the set of parameters, as well as all the sharing in place within that module, remains consistent.

In the example you show, you’re grabbing the parameters of ‘convNet’, but getParameters() doesn’t know about the external convNet2. So sharing will be lost.



我自己的理解是
在第一段程式碼中,順序是:

local convNet2 = convNet:clone('weight', 'bias', 'gradWeight', 'gradBias')
local params, gradParams = convNet:getParameters()

是先進行的拷貝 clone(),是 深拷貝convNet2convNet 並不是同一個儲存。之後再 getParameters(要所有的引數 拉平) 的時候,已經不關 convNet2 的事了。
這時候再通過 params:fill(0) 賦值的時候(因為 getParameters() 得到的只是引數的引用,與原先引數指向的同一塊記憶體,所以可以通過 params:fill(0) 這種方式給 convNet 網路賦值),對 convNet2 已經沒有影響了。所以 convNet2 保持原先的值。

而第二段程式碼,順序是:

local params, gradParams = convNet:getParameters()
local convNet2 = convNet:clone('weight', 'bias', 'gradWeight', 'gradBias')

注意這時候,convNet 中的儲存結構,已經被 getParameters() 函式給 拉平 了,相當於是 convNet結構已經被破壞了
因為在官方的 Module 文件中的 getParameters() 函式這塊,有這麼一句話

This function will go over all the weights and gradWeights and make them view into a single tensor (one for weights and one for gradWeights). Since the storage of every weight and gradWeight is changed, this function should be called only once on a given network.

下面 convNet2clone('weight',...) 這樣拷貝,已經失效了。所以,實際上,這時候 convNet2 進行的所謂的 深拷貝,並不是真正的 深拷貝,而是 失效的深拷貝

下面我們可以通過加一句話驗證一下,上面的深拷貝是失效的:

require 'nn'

local convNet = nn.Sequential()
convNet:add(nn.Linear(2, 3))
convNet:add(nn.Tanh())

local params, gradParams = convNet:getParameters()
local convNet2 = convNet:clone('weight', 'bias', 'gradWeight', 'gradBias')

-- 加上下面這一句, 這一句的拷貝不指定拷貝的引數, 如 weight, bias 這些
-- 而是預設的進行深拷貝
local convNet3 = convNet:clone()

params:fill(0)
print(convNet2:get(1).weight)

print('------------------------')
print(convNet3:get(1).weight)

我們在中間加了一句:local convNet3 = convNet:clone(),自動的進行深拷貝,而不是指定引數。看輸出結果:

這裡寫圖片描述

看到了嗎?!
convNet3 的引數與 convNet 不是同一塊儲存地址,深拷貝成功。而 convNet2 的深拷貝失效,所以當 params:fill(0) 的時候,convNet2 的引數也變了。但 convNet3 的深拷貝成功!

總結一下:
實驗證明,我的猜想是成功的。由於 getParameters() 獲取引數使得 convNet 的網路引數被 拉平 了,所以 convNet2 的深拷貝方式就已經失效了,convNet2 本質上跟 convNet 還是共用的一塊記憶體地址:

local convNet2 = convNet:clone('weight', 'bias', 'gradWeight', 'gradBias')

反而不指定引數的 convNet3 的深拷貝方式,反而保持有效:

local convNet3 = convNet:clone()

Reference