1. 程式人生 > >【8】caffe的python介面學習:caffemodel中的引數及特徵的抽取

【8】caffe的python介面學習:caffemodel中的引數及特徵的抽取

如果用公式  y=f(wx+b)

來表示整個運算過程的話,那麼w和b就是我們需要訓練的東西,w稱為權值,在cnn中也可以叫做卷積核(filter),b是偏置項。f是啟用函式,有sigmoid、relu等。x就是輸入的資料。

資料訓練完成後,儲存的caffemodel裡面,實際上就是各層的w和b值。

我們執行程式碼:

deploy=root + 'mnist/deploy.prototxt'    #deploy檔案
caffe_model=root + 'mnist/lenet_iter_9380.caffemodel'   #訓練好的 caffemodel
net = caffe.Net(net_file,caffe_model,caffe.TEST)   #載入model和network

就把所有的引數和資料都載入到一個net變數裡面了,但是net是一個很複雜的object, 想直接顯示出來看是不行的。其中:

net.params: 儲存各層的引數值(w和b)

net.blobs: 儲存各層的資料值

可用命令:

[(k,v[0].data) for k,v in net.params.items()]

檢視各層的引數值,其中k表示層的名稱,v[0].data就是各層的W值,而v[1].data是各層的b值。注意:並不是所有的層都有引數,只有卷積層和全連線層才有。

也可以不檢視具體值,只想看一下shape,可用命令

[(k,v[0].data.shape) for k,v in net.params.items()]

假設我們知道其中第一個卷積層的名字叫'Convolution1', 則我們可以提取這個層的引數:

w1=net.params['Convolution1'][0].data
b1=net.params['Convolution1'][1].data

輸入這些程式碼,實際檢視一下,對你理解network非常有幫助。

同理,除了檢視引數,我們還可以檢視資料,但是要注意的是,net裡面剛開始是沒有資料的,需要執行:

net.forward()

之後才會有資料。我們可以用程式碼:

[(k,v.data.shape) for k,v in net.blobs.items()]

[(k,v.data) for k,v in net.blobs.items()]

來檢視各層的資料。注意和上面檢視引數的區別,一個是net.params, 一個是net.blobs.

實際上資料剛輸入的時候,我們叫圖片資料,卷積之後我們就叫特徵了。

如果要抽取第一個全連線層的特徵,則可用命令:

fea=net.blobs['InnerProduct1'].data

只要知道某個層的名稱,就可以抽取這個層的特徵。

推薦大家執行一下上面的所有程式碼,深入理解模型各層。

最後,總結一個程式碼:

#!/usr/bin/env python
# encoding: utf-8
'''
@author: lele Ye
@contact: [email protected]
@software: pycharm 2018.2
@file: draw.py
@time: 2018/10/22 20:01
@desc:
'''
import os,sys
caffe_root = '/home/yeler082/caffe/'
sys.path.insert(0,caffe_root+'python')
import caffe
deploy=root + 'mnist/deploy.prototxt'    #deploy檔案
caffe_model=root + 'mnist/lenet_iter_9380.caffemodel'   #訓練好的 caffemodel
net = caffe.Net(deploy,caffe_model,caffe.TEST)   #載入model和network
[(k,v[0].data.shape) for k,v in net.params.items()]  #檢視各層引數規模
w1=net.params['Convolution1'][0].data  #提取引數w
b1=net.params['Convolution1'][1].data  #提取引數b
net.forward()   #執行測試

[(k,v.data.shape) for k,v in net.blobs.items()]  #檢視各層資料規模
fea=net.blobs['InnerProduct1'].data   #提取某層資料(特徵)