
An Effective Way to Convert PyTorch Models to Keras, Using FlowNet as an Example

Thanks to its dynamic-graph mechanism, PyTorch has been widely adopted and is arguably on track to overtake TensorFlow in popularity, but TF still dominates in production engineering. When a model has to be deployed in industry and used in a real project, TF's PB files and its C++ interface are the practical tools. This post explains how to convert a PyTorch model to Keras; since Keras runs on top of TensorFlow, this also gives us a path to a PB file. Obtaining the PB file and using it will be covered in the next post.

PyTorch to Keras

First, be clear about one thing: most of the so-called PyTorch-to-Keras (or Keras-to-PyTorch) converter code found online and on GitHub either does not run at all or only works for a narrow set of models. What these converters are good for is revealing the differences between the two frameworks: the shape conventions of the parameters, the channel ordering (channels first or last), and so on. Once you understand those differences, you can write your own conversion code based on them — and writing it yourself is by far the most reliable approach. The whole process splits into two steps. As a running example, I will convert NVIDIA's open-source PyTorch implementation of FlowNet into a Keras model.

1. Following the structure of the PyTorch model, write the equivalent Keras code; with Keras's functional API this is straightforward to assemble.

2. Assign the PyTorch model's parameters to the Keras model, layer by layer, matching on layer names.

These two steps look simple, but in practice I hit plenty of detours. The crux is whether the parameter shapes agree between the two frameworks — and of course they do not. FlowNet illustrates this below.

The FlowNet model in PyTorch and Keras

I will only show the layer names and parameter counts rather than pasting the entire model definition, which would waste a lot of space and pad out the article.

First, the FlowNet model built with Keras, printed directly with model.summary():

__________________________________________________________________________________________________
Layer (type)                           Output Shape           Param #   Connected to
==================================================================================================
input_1 (InputLayer)                   (None, 6, 512, 512)    0
conv0 (Conv2D)                         (None, 64, 512, 512)   3520      input_1[0][0]
leaky_re_lu_1 (LeakyReLU)              (None, 64, 512, 512)   0         conv0[0][0]
zero_padding2d_1 (ZeroPadding2D)       (None, 64, 514, 514)   0         leaky_re_lu_1[0][0]
conv1 (Conv2D)                         (None, 64, 256, 256)   36928     zero_padding2d_1[0][0]
leaky_re_lu_2 (LeakyReLU)              (None, 64, 256, 256)   0         conv1[0][0]
conv1_1 (Conv2D)                       (None, 128, 256, 256)  73856     leaky_re_lu_2[0][0]
leaky_re_lu_3 (LeakyReLU)              (None, 128, 256, 256)  0         conv1_1[0][0]
zero_padding2d_2 (ZeroPadding2D)       (None, 128, 258, 258)  0         leaky_re_lu_3[0][0]
conv2 (Conv2D)                         (None, 128, 128, 128)  147584    zero_padding2d_2[0][0]
leaky_re_lu_4 (LeakyReLU)              (None, 128, 128, 128)  0         conv2[0][0]
conv2_1 (Conv2D)                       (None, 128, 128, 128)  147584    leaky_re_lu_4[0][0]
leaky_re_lu_5 (LeakyReLU)              (None, 128, 128, 128)  0         conv2_1[0][0]
zero_padding2d_3 (ZeroPadding2D)       (None, 128, 130, 130)  0         leaky_re_lu_5[0][0]
conv3 (Conv2D)                         (None, 256, 64, 64)    295168    zero_padding2d_3[0][0]
leaky_re_lu_6 (LeakyReLU)              (None, 256, 64, 64)    0         conv3[0][0]
conv3_1 (Conv2D)                       (None, 256, 64, 64)    590080    leaky_re_lu_6[0][0]
leaky_re_lu_7 (LeakyReLU)              (None, 256, 64, 64)    0         conv3_1[0][0]
zero_padding2d_4 (ZeroPadding2D)       (None, 256, 66, 66)    0         leaky_re_lu_7[0][0]
conv4 (Conv2D)                         (None, 512, 32, 32)    1180160   zero_padding2d_4[0][0]
leaky_re_lu_8 (LeakyReLU)              (None, 512, 32, 32)    0         conv4[0][0]
conv4_1 (Conv2D)                       (None, 512, 32, 32)    2359808   leaky_re_lu_8[0][0]
leaky_re_lu_9 (LeakyReLU)              (None, 512, 32, 32)    0         conv4_1[0][0]
zero_padding2d_5 (ZeroPadding2D)       (None, 512, 34, 34)    0         leaky_re_lu_9[0][0]
conv5 (Conv2D)                         (None, 512, 16, 16)    2359808   zero_padding2d_5[0][0]
leaky_re_lu_10 (LeakyReLU)             (None, 512, 16, 16)    0         conv5[0][0]
conv5_1 (Conv2D)                       (None, 512, 16, 16)    2359808   leaky_re_lu_10[0][0]
leaky_re_lu_11 (LeakyReLU)             (None, 512, 16, 16)    0         conv5_1[0][0]
zero_padding2d_6 (ZeroPadding2D)       (None, 512, 18, 18)    0         leaky_re_lu_11[0][0]
conv6 (Conv2D)                         (None, 1024, 8, 8)     4719616   zero_padding2d_6[0][0]
leaky_re_lu_12 (LeakyReLU)             (None, 1024, 8, 8)     0         conv6[0][0]
conv6_1 (Conv2D)                       (None, 1024, 8, 8)     9438208   leaky_re_lu_12[0][0]
leaky_re_lu_13 (LeakyReLU)             (None, 1024, 8, 8)     0         conv6_1[0][0]
deconv5 (Conv2DTranspose)              (None, 512, 16, 16)    8389120   leaky_re_lu_13[0][0]
predict_flow6 (Conv2D)                 (None, 2, 8, 8)        18434     leaky_re_lu_13[0][0]
leaky_re_lu_14 (LeakyReLU)             (None, 512, 16, 16)    0         deconv5[0][0]
upsampled_flow6_to_5 (Conv2DTranspose) (None, 2, 16, 16)      66        predict_flow6[0][0]
concatenate_1 (Concatenate)            (None, 1026, 16, 16)   0         leaky_re_lu_11[0][0]
                                                                        leaky_re_lu_14[0][0]
                                                                        upsampled_flow6_to_5[0][0]
inter_conv5 (Conv2D)                   (None, 512, 16, 16)    4728320   concatenate_1[0][0]
deconv4 (Conv2DTranspose)              (None, 256, 32, 32)    4202752   concatenate_1[0][0]
predict_flow5 (Conv2D)                 (None, 2, 16, 16)      9218      inter_conv5[0][0]
leaky_re_lu_15 (LeakyReLU)             (None, 256, 32, 32)    0         deconv4[0][0]
upsampled_flow5_to_4 (Conv2DTranspose) (None, 2, 32, 32)      66        predict_flow5[0][0]
concatenate_2 (Concatenate)            (None, 770, 32, 32)    0         leaky_re_lu_9[0][0]
                                                                        leaky_re_lu_15[0][0]
                                                                        upsampled_flow5_to_4[0][0]
inter_conv4 (Conv2D)                   (None, 256, 32, 32)    1774336   concatenate_2[0][0]
deconv3 (Conv2DTranspose)              (None, 128, 64, 64)    1577088   concatenate_2[0][0]
predict_flow4 (Conv2D)                 (None, 2, 32, 32)      4610      inter_conv4[0][0]
leaky_re_lu_16 (LeakyReLU)             (None, 128, 64, 64)    0         deconv3[0][0]
upsampled_flow4_to_3 (Conv2DTranspose) (None, 2, 64, 64)      66        predict_flow4[0][0]
concatenate_3 (Concatenate)            (None, 386, 64, 64)    0         leaky_re_lu_7[0][0]
                                                                        leaky_re_lu_16[0][0]
                                                                        upsampled_flow4_to_3[0][0]
inter_conv3 (Conv2D)                   (None, 128, 64, 64)    444800    concatenate_3[0][0]
deconv2 (Conv2DTranspose)              (None, 64, 128, 128)   395328    concatenate_3[0][0]
predict_flow3 (Conv2D)                 (None, 2, 64, 64)      2306      inter_conv3[0][0]
leaky_re_lu_17 (LeakyReLU)             (None, 64, 128, 128)   0         deconv2[0][0]
upsampled_flow3_to_2 (Conv2DTranspose) (None, 2, 128, 128)    66        predict_flow3[0][0]
concatenate_4 (Concatenate)            (None, 194, 128, 128)  0         leaky_re_lu_5[0][0]
                                                                        leaky_re_lu_17[0][0]
                                                                        upsampled_flow3_to_2[0][0]
inter_conv2 (Conv2D)                   (None, 64, 128, 128)   111808    concatenate_4[0][0]
predict_flow2 (Conv2D)                 (None, 2, 128, 128)    1154      inter_conv2[0][0]
up_sampling2d_1 (UpSampling2D)         (None, 2, 512, 512)    0         predict_flow2[0][0]
==================================================================================================

Now look at the FlowNet model as built in PyTorch:

 (conv0): Sequential(
 (0): Conv2d(6, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv1): Sequential(
 (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv1_1): Sequential(
 (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv2): Sequential(
 (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv2_1): Sequential(
 (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv3): Sequential(
 (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv3_1): Sequential(
 (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv4): Sequential(
 (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv4_1): Sequential(
 (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv5): Sequential(
 (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv5_1): Sequential(
 (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv6): Sequential(
 (0): Conv2d(512, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv6_1): Sequential(
 (0): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (deconv5): Sequential(
 (0): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (deconv4): Sequential(
 (0): ConvTranspose2d(1026, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (deconv3): Sequential(
 (0): ConvTranspose2d(770, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (deconv2): Sequential(
 (0): ConvTranspose2d(386, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (inter_conv5): Sequential(
 (0): Conv2d(1026, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 )
 (inter_conv4): Sequential(
 (0): Conv2d(770, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 )
 (inter_conv3): Sequential(
 (0): Conv2d(386, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 )
 (inter_conv2): Sequential(
 (0): Conv2d(194, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 )
 (predict_flow6): Conv2d(1024, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (predict_flow5): Conv2d(512, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (predict_flow4): Conv2d(256, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (predict_flow3): Conv2d(128, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (predict_flow2): Conv2d(64, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (upsampled_flow6_to_5): ConvTranspose2d(2, 2, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (upsampled_flow5_to_4): ConvTranspose2d(2, 2, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (upsampled_flow4_to_3): ConvTranspose2d(2, 2, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (upsampled_flow3_to_2): ConvTranspose2d(2, 2, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (upsample1): Upsample(scale_factor=4.0, mode=bilinear)
)
conv0 Sequential(
 (0): Conv2d(6, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv0.0 Conv2d(6, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
conv0.1 LeakyReLU(negative_slope=0.1, inplace)
conv1 Sequential(
 (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv1.0 Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
conv1.1 LeakyReLU(negative_slope=0.1, inplace)
conv1_1 Sequential(
 (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv1_1.0 Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
conv1_1.1 LeakyReLU(negative_slope=0.1, inplace)
conv2 Sequential(
 (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv2.0 Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
conv2.1 LeakyReLU(negative_slope=0.1, inplace)
conv2_1 Sequential(
 (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv2_1.0 Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
conv2_1.1 LeakyReLU(negative_slope=0.1, inplace)
conv3 Sequential(
 (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv3.0 Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
conv3.1 LeakyReLU(negative_slope=0.1, inplace)
conv3_1 Sequential(
 (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv3_1.0 Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
conv3_1.1 LeakyReLU(negative_slope=0.1, inplace)
conv4 Sequential(
 (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv4.0 Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
conv4.1 LeakyReLU(negative_slope=0.1, inplace)
conv4_1 Sequential(
 (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv4_1.0 Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
conv4_1.1 LeakyReLU(negative_slope=0.1, inplace)
conv5 Sequential(
 (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv5.0 Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
conv5.1 LeakyReLU(negative_slope=0.1, inplace)
conv5_1 Sequential(
 (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv5_1.0 Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
conv5_1.1 LeakyReLU(negative_slope=0.1, inplace)
conv6 Sequential(
 (0): Conv2d(512, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv6.0 Conv2d(512, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
conv6.1 LeakyReLU(negative_slope=0.1, inplace)
conv6_1 Sequential(
 (0): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv6_1.0 Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
conv6_1.1 LeakyReLU(negative_slope=0.1, inplace)
deconv5 Sequential(
 (0): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
deconv5.0 ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
deconv5.1 LeakyReLU(negative_slope=0.1, inplace)
deconv4 Sequential(
 (0): ConvTranspose2d(1026, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
deconv4.0 ConvTranspose2d(1026, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
deconv4.1 LeakyReLU(negative_slope=0.1, inplace)
deconv3 Sequential(
 (0): ConvTranspose2d(770, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
deconv3.0 ConvTranspose2d(770, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
deconv3.1 LeakyReLU(negative_slope=0.1, inplace)
deconv2 Sequential(
 (0): ConvTranspose2d(386, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
deconv2.0 ConvTranspose2d(386, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
deconv2.1 LeakyReLU(negative_slope=0.1, inplace)
inter_conv5 Sequential(
 (0): Conv2d(1026, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
inter_conv5.0 Conv2d(1026, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
inter_conv4 Sequential(
 (0): Conv2d(770, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
inter_conv4.0 Conv2d(770, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
inter_conv3 Sequential(
 (0): Conv2d(386, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
inter_conv3.0 Conv2d(386, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
inter_conv2 Sequential(
 (0): Conv2d(194, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)

Because the output of PyTorch's named_modules() is not ordered — under the dynamic-graph mechanism, the execution path is only known once data actually flows through the model — the listing above is scrambled. The point, though, is that the Keras model I built really does follow the officially open-sourced PyTorch model.

With the model built, we reach the key step: assigning the weights to the Keras model.

Assigning weights to the Keras model

Three points need attention in this step.

PyTorch is channels_first, while Keras defaults to channels_last, so add these lines at the top of the script (the backend import is included for completeness):

from keras import backend as K
K.set_image_data_format('channels_first')
K.set_learning_phase(0)

As everyone knows, a convolution layer's weight is a 4-D tensor. Is the kernel weight laid out the same way in PyTorch and Keras? Naturally not — otherwise I would have nothing to write about here. So the PyTorch weights have to be reordered.

Since convolution weights differ between the two frameworks, transposed-convolution weights naturally differ as well.

Let's first look at the convolution layer in each framework.

Keras convolution weight layout

The following code prints the weight shapes of every Keras layer:

for l in model.layers:
    print(l.name)
    for i, w in enumerate(l.get_weights()):
        print('%d' % i, w.shape)

The first convolution layer prints as follows — the shape after 0 is the convolution weight, and the one after 1 is the bias:

conv0
0 (3, 3, 6, 64)
1 (64,)

So the Keras convolution weight layout is [height, width, in_channels, out_channels].

PyTorch convolution weight layout

net = FlowNet2SD()
for n, m in net.named_parameters():
    print(n)
    print(m.data.size())

conv0.0.weight
torch.Size([64, 6, 3, 3])
conv0.0.bias
torch.Size([64])

The code above prints the shape of every layer's parameters; again locate the first convolution layer and inspect its shape.

Comparing the two, we can see that PyTorch's convolution weight layout is [out_channels, in_channels, height, width].

So after extracting a PyTorch weight, we need to reorder its axes with np.transpose before it can be loaded into the corresponding Keras layer.
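As a concrete sketch of that reordering (the array below is random stand-in data shaped like conv0's weight, not the real parameter):

```python
import numpy as np

# Stand-in for conv0's weight in the PyTorch layout:
# (out_channels, in_channels, height, width)
w_torch = np.random.rand(64, 6, 3, 3).astype(np.float32)

# Reorder to the Keras layout: (height, width, in_channels, out_channels)
w_keras = np.transpose(w_torch, (2, 3, 1, 0))

print(w_keras.shape)  # (3, 3, 6, 64)
```

Element-wise, the mapping is w_keras[h, w, i, o] == w_torch[o, i, h, w], which is exactly the layout difference found above.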

Keras transposed-convolution weight layout

deconv4
0 (4, 4, 256, 1026)
1 (256,)

The code is the same as before; locate a transposed convolution and inspect it.

We can see that in Keras, the transposed-convolution layout is [height, width, out_channels, in_channels].

PyTorch transposed-convolution weight layout

deconv4.0.weight
torch.Size([1026, 256, 4, 4])
deconv4.0.bias
torch.Size([256])

Again with the same code, locate the corresponding transposed convolution and inspect it.

We can see that in PyTorch, the transposed-convolution layout is [in_channels, out_channels, height, width].

Summary

For convolution layers, a PyTorch weight must be reordered with

np.transpose(weight.data.numpy(), [2, 3, 1, 0])

before it can be assigned to the corresponding Keras layer.

For transposed convolutions, comparison shows the very same permutation works. Don't believe it? Try it yourself.
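Here is a quick numpy check with a dummy array shaped like deconv4's weight: the permutation (2, 3, 1, 0) maps the PyTorch transposed-convolution layout (in, out, h, w) onto the Keras layout (h, w, out, in), just as it maps (out, in, h, w) onto (h, w, in, out) for ordinary convolutions — because in both cases the two spatial axes move to the front and the two channel axes swap:

```python
import numpy as np

# Dummy stand-in for deconv4's weight: PyTorch ConvTranspose2d layout
# (in_channels, out_channels, height, width)
w_deconv_torch = np.random.rand(1026, 256, 4, 4).astype(np.float32)

# The same permutation used for ordinary convolutions
w_deconv_keras = np.transpose(w_deconv_torch, (2, 3, 1, 0))

print(w_deconv_keras.shape)  # (4, 4, 256, 1026), i.e. Keras (h, w, out, in)
```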

Bias terms are one-dimensional vectors in both frameworks and need no processing.

In some cases the kernels may additionally need to be flipped, but this is rarely necessary:

weights[::-1, ::-1, :, :]
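To see what that slicing does: in the Keras layout (h, w, in, out), reversing the first two axes rotates every 2-D kernel by 180 degrees. A tiny numpy illustration with a single 2x2 kernel:

```python
import numpy as np

# One 2x2 kernel in the Keras layout (h, w, in_channels, out_channels)
weights = np.arange(4).reshape(2, 2, 1, 1)

# Reverse both spatial axes: a 180-degree flip of each kernel
flipped = weights[::-1, ::-1, :, :]

print(flipped[0, 0, 0, 0])  # the element that was at position [1, 1]
```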

Assignment

With preprocessing done, we move on to the second step: the actual assignment.

First, the preprocessing code:

for k, v in weights_from_torch.items():
    if 'bias' not in k:
        # (out, in, h, w) -> (h, w, in, out); same permutation for ConvTranspose2d
        weights_from_torch[k] = v.data.numpy().transpose(2, 3, 1, 0)
    else:
        weights_from_torch[k] = v.data.numpy()  # biases are 1-D, no reordering
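The same loop can be checked in isolation with numpy stand-ins. In the real script, weights_from_torch comes from the PyTorch model's parameters and each value is a tensor; here two dummy arrays take their place so only the reordering logic is exercised:

```python
import numpy as np

# Dummy state dict: one conv weight in the PyTorch layout, plus its bias.
weights_from_torch = {
    'conv0.0.weight': np.zeros((64, 6, 3, 3), dtype=np.float32),
    'conv0.0.bias': np.zeros((64,), dtype=np.float32),
}

for k, v in weights_from_torch.items():
    if 'bias' not in k:
        # (out, in, h, w) -> (h, w, in, out); biases stay 1-D and untouched
        weights_from_torch[k] = v.transpose(2, 3, 1, 0)

print(weights_from_torch['conv0.0.weight'].shape)  # (3, 3, 6, 64)
```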

I only excerpt part of the assignment code for reference:

k_model = build_flownet()  # construct the Keras model defined earlier (constructor name is illustrative)
for layer in k_model.layers:
    current_layer_name = layer.name
    if current_layer_name == 'conv0':
        weights = [weights_from_torch['conv0.0.weight'], weights_from_torch['conv0.0.bias']]
        layer.set_weights(weights)
    elif current_layer_name == 'conv1':
        weights = [weights_from_torch['conv1.0.weight'], weights_from_torch['conv1.0.bias']]
        layer.set_weights(weights)
    elif current_layer_name == 'conv1_1':
        weights = [weights_from_torch['conv1_1.0.weight'], weights_from_torch['conv1_1.0.bias']]
        layer.set_weights(weights)

First define the Keras model and use its layers attribute to iterate over all layers.

Walk through the layers and assign each one its corresponding values.

Assignment uses set_weights, whose argument must be a list in the same form as the return value of get_weights, i.e. [conv_weights, bias_weights].
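The long elif chain above can also be replaced by deriving the PyTorch parameter keys from the Keras layer name. This is a sketch under the naming convention used in this article (torch_param_keys is a hypothetical helper, not part of either framework): modules wrapped in a Sequential expose keys like 'conv0.0.weight', while bare layers such as predict_flow6 expose 'predict_flow6.weight'.

```python
def torch_param_keys(layer_name, torch_keys):
    """Map a Keras layer name to its (weight_key, bias_key) in the PyTorch
    state dict, trying the Sequential-wrapped form first, then the bare form.
    Returns None for layers with no parameters (LeakyReLU, ZeroPadding2D, ...)."""
    for prefix in (layer_name + '.0', layer_name):
        w_key, b_key = prefix + '.weight', prefix + '.bias'
        if w_key in torch_keys:
            return w_key, b_key
    return None

keys = ['conv0.0.weight', 'conv0.0.bias',
        'predict_flow6.weight', 'predict_flow6.bias']
print(torch_param_keys('conv0', keys))          # ('conv0.0.weight', 'conv0.0.bias')
print(torch_param_keys('leaky_re_lu_1', keys))  # None
```

In the assignment loop this becomes: look up the keys for layer.name, skip the layer if the result is None, and otherwise call layer.set_weights([weights_from_torch[w_key], weights_from_torch[b_key]]).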

Finally, I hope you succeed in migrating your own models. The project is open-sourced on my personal GitHub, with detailed usage instructions and sample data included, so you can run it directly.

That is everything in this walkthrough of converting PyTorch to Keras using FlowNet as the example — I hope it serves as a useful reference.