PyTorch(八)——pyTorch-To-Caffe
PyTorch的學習和使用(八)
一、pyTorch to Caffe 動態圖到靜態圖的轉換
靜態圖: 網路在輸入資料前就預先將網路定義好,與資料無關,將所有的操作過程定義好,執行時填入資料。比如caffe,tensorFlow等框架。
動態圖: 資料在網路傳輸中動態的構建網路。比如pyTorch等。
兩者各有優點,靜態圖由於提前將網路結構確定了,部署十分方便,但是由於資料在網路中的傳遞過程往往不可知,因此除錯較為困難;動態圖是根據資料的流動動態的構建網路圖,因此資料在網路中的狀態都是已知的,除錯十分便捷。目前,兩者都在吸取對方的優點,tensorFlow也在也如動態圖的機制,pyTorch與Caffe2結合,在配合ONNX實現高效的部署。
pyTorch模型轉換到caffe模型可以看為動態圖到靜態圖之間的轉換,主要需要進動態圖到靜態圖之間的轉換,即構建出動態圖然後將其對映到靜態圖,並且將網路引數也進行轉換。
二、python trace
機制捕獲動態圖
pyTroch框架採用python構建,通過使用python的trace
機制可以獲取到網路在傳遞過程中所經過的結構,從而對映到靜態圖。主要步驟如下:
- 啟動python的
trace
功能,並定義其回撥函式。 - 在回撥函式中捕獲網路所呼叫的原子操作。
- 將對應的操作使用caffe的python介面進行對映。
- 將相應pyTorch的網路引數對映到caffe模型。
- 儲存caffe模型,關閉python的
trace
程式碼框架如下:
import sys import torch from caffe import layers as L, params as P, to_proto def tracea_fun(frame, event, arg): //通過當前的frame棧得到每次呼叫的函式,並將其轉換為相應的caffe呼叫 def main(model, input): sys.settrace(trace_fun) output = model(input) sys.settrace(None) if __name__ == "__main__": input = DataLoder() model = Net() main(model, input)
2.1 sys.settrace
操作捕獲
Set the system’s trace function, which allows you to implement a Python source code debugger in Python.
Trace functions should have three arguments: frame, event, and arg. frame is the current stack frame. event is a string: ‘call’, ‘line’, ‘return’ or ‘exception’. arg depends on the event type.
因此通過sys.settrace
的回撥函式中的frame棧可以捕獲當前的操作,其中frame
為frame objects,定義見The standard type hierarchy,常用屬性有:
- f_code: The code object being executed in this frame
- co_name: Function name
- co_varnames: A tuple containing the names of the local variables
- f_locals: The dictionary used to look up local variables
- f_back: The previous stack frame
則通過frame.f_code.co_name
和frmae.f_locals
可以獲得網路傳遞過程中的函式名和引數。
2.2 pyTorch原子操作捕獲
實現該方法的難點在於如何找到網路中資料的流向,比如進行的view
操作和resNet網路中何時進行add
操作,這些操作在pyTorch0.2中都封裝成了相應的原子操作,只需要找到對應的呼叫函式即可(但是在pyToch0.3以上中直接呼叫C的介面,暫時不會怎麼使用settrace
進行捕捉)。
以卷基層為例,在trace_fun
中的conv2d程式碼如下:
def trace_fun(frame, event, arg):
if frame.f_code.co_name == "conv2d":
groups = frame.f_locals["groups"]
pad_h = frame.f_locals["padding"][0]
pad_w = frame.f_locals["padding"][1]
stride_h = frame.f_locals["stride"][0]
stride_w = frame.f_locals["stride"][1]
dilation = frame.f_locals["dilation"]
weight = frame.f_locals["weight"]
bias = frame.f_locals["bias"]
bottom = getBottom()
name = "conv1"
top = L.Convolution(bottom, name=name,
kernel_h=kernel_h, kernel_w = kernel_w,
num_output=num_output, groups=groups,
stride_h=stride_h, stride_w=stride_w,
pad_h=pad_h, pad_w=pad_w,
dilation=dilation)
其中,getBottom()
為獲取當前層的前一層,通過維護一個容器,在容器中以每層的實體地址作為該層的唯一索引進行檢索,即使用id(feature)
來確定其前一層。
注意,該方法只用於第一個pyTorch0.2之前的版本,在0.3之後的版本通過直接呼叫C介面的方式,目前還不會將其操作棧剝離出來。