1. 程式人生 > >PyTorch(八)——pyTorch-To-Caffe

PyTorch(八)——pyTorch-To-Caffe

PyTorch的學習和使用(八)

Mon 22Mon 29Mon 05settrace grad_fn pyTorch1.0 pyTorchToCaffe完成進度表

一、pyTorch to Caffe \Rightarrow 動態圖到靜態圖的轉換

靜態圖: 網路在輸入資料前就預先將網路定義好,與資料無關,將所有的操作過程定義好,執行時填入資料。比如caffe,tensorFlow等框架。
動態圖: 資料在網路傳輸中動態的構建網路。比如pyTorch等。

兩者各有優點,靜態圖由於提前將網路結構確定了,部署十分方便,但是由於資料在網路中的傳遞過程往往不可知,因此除錯較為困難;動態圖是根據資料的流動動態的構建網路圖,因此資料在網路中的狀態都是已知的,除錯十分便捷。目前,兩者都在吸取對方的優點,tensorFlow也在也如動態圖的機制,pyTorch與Caffe2結合,在配合ONNX實現高效的部署。

pyTorch模型轉換到caffe模型可以看為動態圖到靜態圖之間的轉換,主要需要進動態圖到靜態圖之間的轉換,即構建出動態圖然後將其對映到靜態圖,並且將網路引數也進行轉換。

二、python trace 機制捕獲動態圖

pyTroch框架採用python構建,通過使用python的trace機制可以獲取到網路在傳遞過程中所經過的結構,從而對映到靜態圖。主要步驟如下:

  1. 啟動python的trace功能,並定義其回撥函式。
  2. 在回撥函式中捕獲網路所呼叫的原子操作。
  3. 將對應的操作使用caffe的python介面進行對映。
  4. 將相應pyTorch的網路引數對映到caffe模型。
  5. 儲存caffe模型,關閉python的trace
    功能。

程式碼框架如下:

import sys
import torch
from caffe import layers as L, params as P, to_proto

def tracea_fun(frame, event, arg):
    //通過當前的frame棧得到每次呼叫的函式,並將其轉換為相應的caffe呼叫
    
def main(model, input):
    sys.settrace(trace_fun)
    output = model(input)
    sys.settrace(None)

if __name__ == "__main__":
    input = DataLoder()
    model = Net()
    
    main(model, input)

2.1 sys.settrace 操作捕獲

Set the system’s trace function, which allows you to implement a Python source code debugger in Python.
Trace functions should have three arguments: frame, event, and arg. frame is the current stack frame. event is a string: ‘call’, ‘line’, ‘return’ or ‘exception’. arg depends on the event type.

因此通過sys.settrace的回撥函式中的frame棧可以捕獲當前的操作,其中frameframe objects,定義見The standard type hierarchy,常用屬性有:

  • f_code: The code object being executed in this frame
    • co_name: Function name
    • co_varnames: A tuple containing the names of the local variables
  • f_locals: The dictionary used to look up local variables
  • f_back: The previous stack frame

則通過frame.f_code.co_namefrmae.f_locals可以獲得網路傳遞過程中的函式名和引數。

2.2 pyTorch原子操作捕獲

實現該方法的難點在於如何找到網路中資料的流向,比如進行的view操作和resNet網路中何時進行add操作,這些操作在pyTorch0.2中都封裝成了相應的原子操作,只需要找到對應的呼叫函式即可(但是在pyToch0.3以上中直接呼叫C的介面,暫時不會怎麼使用settrace進行捕捉)。

以卷基層為例,在trace_fun中的conv2d程式碼如下:

def trace_fun(frame, event, arg):
    if frame.f_code.co_name == "conv2d":
        groups = frame.f_locals["groups"]
        pad_h = frame.f_locals["padding"][0]
        pad_w = frame.f_locals["padding"][1]
        stride_h = frame.f_locals["stride"][0]
        stride_w = frame.f_locals["stride"][1]
        dilation = frame.f_locals["dilation"]
        weight = frame.f_locals["weight"]
        bias = frame.f_locals["bias"]
        bottom = getBottom()
        name = "conv1"
        
        top = L.Convolution(bottom, name=name,
                            kernel_h=kernel_h, kernel_w = kernel_w,
                            num_output=num_output, groups=groups,
                            stride_h=stride_h, stride_w=stride_w,
                            pad_h=pad_h, pad_w=pad_w,
                            dilation=dilation)

其中,getBottom()為獲取當前層的前一層,通過維護一個容器,在容器中以每層的實體地址作為該層的唯一索引進行檢索,即使用id(feature)來確定其前一層。

注意,該方法只用於第一個pyTorch0.2之前的版本,在0.3之後的版本通過直接呼叫C介面的方式,目前還不會將其操作棧剝離出來。

三、pyTorch grad_fn網路拓撲圖構建

四、pyTorch1.0 ONNX和caffe2之間的使用