1. 程式人生 > Caffe提取任意層特徵並進行視覺化

Caffe提取任意層特徵並進行視覺化

原圖

conv1層視覺化結果 (96個filter得到的結果)

資料模型與準備

安裝好Caffe後,在examples/images資料夾下有兩張示例影象,本文即在這兩張影象上,用Caffe提供的預訓練模型,進行特徵提取,並進行視覺化。

1. 進入caffe根目錄,建立臨時資料夾,用於存放所需要的臨時檔案

# -p: do not fail if the folder already exists (e.g. when re-running the steps)
mkdir -p examples/_temp

2. 根據examples/images資料夾中的圖片,建立包含影象列表的txt檔案,並新增標籤(0)

# Absolute path of every example image, one per line (find prints each match itself).
find "$(pwd)/examples/images" -type f > examples/_temp/temp.txt
# Append the dummy label 0 to every line.
sed "s/$/ 0/" examples/_temp/temp.txt > examples/_temp/file_list.txt

3. 執行下列指令碼,下載 ILSVRC12(ImageNet 2012)影象均值檔案;後面的網路結構定義 prototxt 檔案需要用到該檔案 (data/ilsvrc12/imagenet_mean.binaryproto)

# Download the ILSVRC12 auxiliary files, including data/ilsvrc12/imagenet_mean.binaryproto
data/ilsvrc12/get_ilsvrc_aux.sh

4. 將網路定義prototxt檔案複製到_temp資料夾下

# Copy the feature-extraction network definition into the scratch folder
cp examples/feature_extraction/imagenet_val.prototxt examples/_temp

提取特徵

 1. 建立 src/yourname/ 資料夾,用於存放我們自己的指令碼

# -p: do not fail if the folder already exists
mkdir -p src/yourname

2. Caffe 的 extract_features 會將提取出的影象特徵存為 leveldb 格式。為了方便觀察特徵,我們利用下列兩個 python 指令碼將特徵轉換為 MATLAB 的 .mat 格式(請先安裝 caffe 的 python 依賴庫,以及 python 的 leveldb 模組)

feat_helper_pb2.py

複製程式碼
# Generated by the protocol buffer compiler.  DO NOT EDIT!
# NOTE(review): this module had been collapsed onto a few lines, which made
# it unparseable (everything after the first '#' became a comment); only the
# line structure has been restored.  It describes the message
# feat_extract.Datum with fields channels, height, width, data, label and the
# repeated float field float_data.  The `__metaclass__` class attribute is the
# Python 2 form of protoc output, presumably from an old protobuf 2.x protoc;
# if it fails to import with your protobuf runtime, regenerate it with protoc
# from the matching datum.proto rather than editing it by hand.

from google.protobuf import descriptor
from google.protobuf import message
from google.protobuf import reflection
from google.protobuf import descriptor_pb2
# @@protoc_insertion_point(imports)


DESCRIPTOR = descriptor.FileDescriptor(
  name='datum.proto',
  package='feat_extract',
  serialized_pb='\n\x0b\x64\x61tum.proto\x12\x0c\x66\x65\x61t_extract\"i\n\x05\x44\x61tum\x12\x10\n\x08\x63hannels\x18\x01 \x01(\x05\x12\x0e\n\x06height\x18\x02 \x01(\x05\x12\r\n\x05width\x18\x03 \x01(\x05\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\x0c\x12\r\n\x05label\x18\x05 \x01(\x05\x12\x12\n\nfloat_data\x18\x06 \x03(\x02')




_DATUM = descriptor.Descriptor(
  name='Datum',
  full_name='feat_extract.Datum',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  fields=[
    descriptor.FieldDescriptor(
      name='channels', full_name='feat_extract.Datum.channels', index=0,
      number=1, type=5, cpp_type=1, label=1,
      has_default_value=False, default_value=0,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    descriptor.FieldDescriptor(
      name='height', full_name='feat_extract.Datum.height', index=1,
      number=2, type=5, cpp_type=1, label=1,
      has_default_value=False, default_value=0,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    descriptor.FieldDescriptor(
      name='width', full_name='feat_extract.Datum.width', index=2,
      number=3, type=5, cpp_type=1, label=1,
      has_default_value=False, default_value=0,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    descriptor.FieldDescriptor(
      name='data', full_name='feat_extract.Datum.data', index=3,
      number=4, type=12, cpp_type=9, label=1,
      has_default_value=False, default_value="",
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    descriptor.FieldDescriptor(
      name='label', full_name='feat_extract.Datum.label', index=4,
      number=5, type=5, cpp_type=1, label=1,
      has_default_value=False, default_value=0,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    descriptor.FieldDescriptor(
      name='float_data', full_name='feat_extract.Datum.float_data', index=5,
      number=6, type=2, cpp_type=6, label=3,
      has_default_value=False, default_value=[],
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  options=None,
  is_extendable=False,
  extension_ranges=[],
  serialized_start=29,
  serialized_end=134,
)

DESCRIPTOR.message_types_by_name['Datum'] = _DATUM

class Datum(message.Message):
  __metaclass__ = reflection.GeneratedProtocolMessageType
  DESCRIPTOR = _DATUM

  # @@protoc_insertion_point(class_scope:feat_extract.Datum)

# @@protoc_insertion_point(module_scope)
複製程式碼


leveldb2mat.py

複製程式碼
import leveldb
import feat_helper_pb2
import numpy as np
import scipy.io as sio
import time

def main(argv):
    leveldb_name = sys.argv[1]
    print "%s" % sys.argv[1]
    batch_num = int(sys.argv[2]);
    batch_size = int(sys.argv[3]);
    window_num = batch_num*batch_size;

    start = time.time()
    if 'db' not in locals().keys():
        db = leveldb.LevelDB(leveldb_name)
        datum = feat_helper_pb2.Datum()

    ft = np.zeros((window_num, int(sys.argv[4])))
    for im_idx in range(window_num):
        datum.ParseFromString(db.Get('%d' %(im_idx)))
        ft[im_idx, :] = datum.float_data

    print 'time 1: %f' %(time.time() - start)
    sio.savemat(sys.argv[5], {'feats':ft})
    print 'time 2: %f' %(time.time() - start)
    print 'done!'

    #leveldb.DestroyDB(leveldb_name)

if __name__ == '__main__':
    # Script entry point: pass the complete command line (program name at
    # index 0) on to main().
    import sys
    main(argv=sys.argv)
複製程式碼

3. 在 src/yourname/ 下建立指令碼檔案 extract_feature.sh(指令碼中的路徑都是相對於該資料夾的,且 leveldb2mat.py 與 feat_helper_pb2.py 也需放在同一資料夾),並在該資料夾中執行。執行後將在 examples/_temp 資料夾下得到 leveldb 檔案(features_conv1)和 .mat 檔案(features.mat)

複製程式碼
#!/usr/bin/env sh
# Extract one layer's features with Caffe and convert them to a .mat file.
# Run from src/yourname/: all paths are relative to that folder, and
# leveldb2mat.py is invoked from the current directory.
set -e  # stop if extraction fails instead of converting a stale/partial DB

# ---- arguments for extract_features ----
TOOL=../../build/tools
MODEL=../../examples/imagenet/caffe_reference_imagenet_model  # downloaded pretrained Caffe model
PROTOTXT=../../examples/_temp/imagenet_val.prototxt           # network definition
LAYER=conv1                                                    # blob to extract, e.g. conv1 or fc7
LEVELDB=../../examples/_temp/features_conv1                    # output LevelDB path
# NOTE(review): newer extract_features builds require the database type as an
# extra argument; confirm against the usage message of your Caffe build.
DB_TYPE=leveldb
BATCHNUM=1    # number of mini-batches; only two example images, so one batch is enough
BATCHSIZE=10  # should match (and must not exceed) the data layer's batch size in $PROTOTXT

# ---- arguments for the LevelDB -> .mat conversion ----
DIM=290400                              # feature length per image, computed by hand (conv1: 96 x 55 x 55)
OUT=../../examples/_temp/features.mat   # output .mat path

# The 5th argument of extract_features is the number of mini-batches (the
# batch size itself comes from $PROTOTXT), so $BATCHNUM goes there.
# leveldb2mat.py then reads BATCHNUM * BATCHSIZE records.
"$TOOL/extract_features.bin" "$MODEL" "$PROTOTXT" "$LAYER" "$LEVELDB" "$BATCHNUM" "$DB_TYPE"
python leveldb2mat.py "$LEVELDB" "$BATCHNUM" "$BATCHSIZE" "$DIM" "$OUT"
複製程式碼

4. 得到 .mat 檔案後,需要對其進行視覺化。這裡使用了 UFLDL 中的 display_network 函式;由於直接視覺化的結果出現了翻轉(行列互換),因此修改了 UFLDL 原始碼的第 67、69、83、85 行,在 reshape 的結果後加上轉置 (')

display_network.m 存放在 src/yourname資料夾下

複製程式碼
function [h, array] = display_network(A, opt_normalize, opt_graycolor, cols, opt_colmajor)
% DISPLAY_NETWORK  Show every column of A as a square image tile in one figure.
%
%   [h, array] = display_network(A, opt_normalize, opt_graycolor, cols, opt_colmajor)
%
%   Each column of A (length L, which must be a perfect square) is reshaped
%   to sqrt(L)-by-sqrt(L), transposed, and drawn in its own cell of a grid
%   whose cells are separated by 1-pixel borders. The transpose was added to
%   the UFLDL original because the untransposed tiles appeared flipped
%   (presumably because the features come from row-major Caffe blobs).
%
%   Inputs (all but A are optional; pass [] to use the default):
%     A             L-by-M matrix, one tile per column. The global mean is
%                   subtracted before display.
%     opt_normalize true (default): scale each tile by its own max |value| so
%                   all tiles have similar contrast; false: scale every tile
%                   by the global max |value|.
%     opt_graycolor true (default): use the gray colormap and a -1
%                   background; false: keep the current colormap and use a
%                   -0.1 background.
%     cols          number of grid columns. Default: chosen near sqrt(M).
%     opt_colmajor  false (default): fill the grid row by row; true: fill it
%                   column by column. Either way each column of A is a tile.
%
%   Outputs:
%     h      handle returned by imagesc
%     array  the assembled mosaic, displayed with colour range [-1 1]

% Silence warnings while drawing and restore the caller's previous warning
% state on exit, including when an error is thrown.
prev_warning_state = warning('off', 'all');
restore_warnings = onCleanup(@() warning(prev_warning_state)); %#ok<NASGU>

if ~exist('opt_normalize', 'var') || isempty(opt_normalize)
    opt_normalize = true;
end
if ~exist('opt_graycolor', 'var') || isempty(opt_graycolor)
    opt_graycolor = true;
end
if ~exist('opt_colmajor', 'var') || isempty(opt_colmajor)
    opt_colmajor = false;
end

% Centre the data on zero so it fits the symmetric [-1 1] colour range.
A = A - mean(A(:));

if opt_graycolor, colormap(gray); end

% Tile geometry: sz-by-sz tiles laid out on an m-by-n grid.
[L, M] = size(A);
sz = sqrt(L);
if sz ~= floor(sz)
    error('display_network:nonSquareTile', ...
          'Each column of A must have a perfect-square length, got %d.', L);
end
buf = 1;   % border width in pixels

if ~exist('cols', 'var') || isempty(cols)
    if floor(sqrt(M))^2 ~= M
        % Not a perfect square: prefer a column count that divides M,
        % searching up to 1.2*sqrt(M).
        n = ceil(sqrt(M));
        while mod(M, n) ~= 0 && n < 1.2*sqrt(M), n = n + 1; end
        m = ceil(M/n);
    else
        n = sqrt(M);
        m = n;
    end
else
    n = cols;
    m = ceil(M/n);
end

array = -ones(buf + m*(sz+buf), buf + n*(sz+buf));
if ~opt_graycolor
    array = 0.1 .* array;
end

global_scale = max(abs(A(:)));
for k = 1:M
    % Grid cell (row, col) of tile k for the chosen fill order.
    if opt_colmajor
        row = mod(k-1, m) + 1;
        col = floor((k-1)/m) + 1;
    else
        row = floor((k-1)/n) + 1;
        col = mod(k-1, n) + 1;
    end
    if opt_normalize
        scale = max(abs(A(:,k)));
    else
        scale = global_scale;
    end
    if scale == 0
        scale = 1;   % constant tile: avoid 0/0 = NaN
    end
    row_idx = buf + (row-1)*(sz+buf) + (1:sz);
    col_idx = buf + (col-1)*(sz+buf) + (1:sz);
    array(row_idx, col_idx) = reshape(A(:,k), sz, sz)' / scale;
end

% Both colour modes use the same call. 'EraseMode' is dropped: it is
% unsupported by newer MATLAB graphics and a single static draw does not
% need it.
h = imagesc(array, [-1 1]);
axis image off

drawnow;
複製程式碼

5. 呼叫display_network 以及提取到的feature進行視覺化:

在 examples/_temp/ 下建立如下matlab指令碼, 並執行

複製程式碼
% Show the extracted feature maps of each image, one figure per image.
% Run from examples/_temp/ after features.mat has been produced.
addpath(genpath('../../src/yourname'));   % folder holding display_network.m

nsample     = 3;    % number of images to show
num_output  = 96;   % number of channels (filters) of the extracted layer

load features.mat   % defines feats: one row per extracted image
[nimg, width] = size(feats);
nmap = width / num_output;   % length of one flattened feature map
if nmap ~= floor(nmap)
    error('Feature length %d is not divisible by num_output = %d.', width, num_output);
end

% Never index past the rows that were actually extracted.
for i = 1:min(nsample, nimg)
    % Presumably column c then holds channel c's flattened map (assumes a
    % channel-major feature layout, as the original reshape did).
    feat = reshape(feats(i, :), [nmap num_output]);
    figure('name', sprintf('image #%d', i));
    display_network(feat);
end
複製程式碼

 下圖是在MNIST上用lenet進行conv1層卷積後得到的結果