Caffe 權值及特徵圖（feature map）視覺化
阿新 • 發佈：2018-12-11
1、權值視覺化
主腳本 conv1_weights_vis.m 放在 Caffe 根目錄下執行，需要已編譯好的 matcaffe（Caffe 的 MATLAB 介面）。
% conv1_weights_vis.m -- visualize the learned kernels of conv1..conv5 of the
% BVLC reference CaffeNet. Run from the Caffe root directory; requires
% matcaffe (the +caffe package under ./matlab) and visualize_weights.m on
% the MATLAB path.
clear; clc; close all;
addpath('matlab');
caffe.set_mode_cpu();
fprintf(['Caffe Version = ', caffe.version(), '\n']);

net = caffe.Net('models/bvlc_reference_caffenet/deploy.prototxt', ...
    'models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel', 'test');
fprintf('Load net done. Net layers : ');
net.layer_names
fprintf('Net blobs : ');
net.blob_names

% Look the conv layers up by name rather than by position in net.layer_vec.
% Hard-coded positions silently point at a different layer as soon as a
% layer is added to or removed from deploy.prototxt.
conv_layers = {'conv1', 'conv2', 'conv3', 'conv4', 'conv5'};
for k = 1:numel(conv_layers)
    % Parameter blob 1 of the layer, i.e. the same blob the per-layer
    % `layer.params(1)` access used.
    w = net.params(conv_layers{k}, 1).get_data();
    fprintf('Conv%d Weight shape: ', k);
    size(w)
    visualize_weights(w, 1);
end
輔助函式 visualize_weights.m（需與主腳本放在同一目錄或 MATLAB 路徑上）：
function [] = visualize_weights(w, s)
% VISUALIZE_WEIGHTS  Display a 4-D weight array as a grid of gray-scale tiles.
%   visualize_weights(w, s) rescales the whole array w jointly to 0..255,
%   converts it to uint8, and draws the transposed slice w(:,:,r,c) as the
%   tile in grid row r, column c (r over size(w,3), c over size(w,4)).
%   s is the number of blank pixels added to the tile size so that
%   neighbouring tiles are visually separated.
kw = size(w, 1);                 % tile width after transposing a slice
kh = size(w, 2);                 % tile height after transposing a slice
pitch = max(kw, kh) + s;         % distance between tile origins

% Joint min-max normalization over every element, then quantize to 8 bits.
w = w - min(w(:));
w = uint8(w / max(w(:)) * 255);

n_rows = size(w, 3);
n_cols = size(w, 4);
mosaic = zeros(pitch * n_rows, pitch * n_cols);
for r = 1:n_rows
    row_idx = pitch * (r - 1) + (1:kh);
    for c = 1:n_cols
        col_idx = pitch * (c - 1) + (1:kw);
        mosaic(row_idx, col_idx) = w(:, :, r, c)';
    end
end

figure; imshow(uint8(mosaic));
2、特徵圖（feature map）視覺化（主腳本同樣放在 Caffe 根目錄下執行；其後的 visualize_feature_maps 函式需另存為 visualize_feature_maps.m）
% Feature-map visualization for the BVLC reference CaffeNet: run a batch made
% of one image through the net, then display the 'data' blob and the
% conv1..conv5 blobs. Run from the Caffe root directory; requires matcaffe
% and visualize_feature_maps.m on the MATLAB path.
clear; clc; close all;
addpath('matlab');
caffe.set_mode_cpu();
fprintf(['Caffe Version = ', caffe.version(), '\n']);

net = caffe.Net('models/bvlc_reference_caffenet/deploy.prototxt', ...
    'models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel', 'test');
fprintf('Load net done. Net layers : ');
net.layer_names
fprintf('Net blobs : ');
net.blob_names

im = imread('examples/images/cat.jpg');
figure, imshow(im); title('Original Image');

d = load('matlab/+caffe/imagenet/ilsvrc_2012_mean.mat');
mean_data = d.mean_data;
IMAGE_DIM = 256;
CROPPED_DIM = 227;

% Convert the image to the layout fed to the net.
im_data = im(:, :, [3, 2, 1]);            % MATLAB's RGB -> BGR channel order
im_data = permute(im_data, [2, 1, 3]);    % swap rows/columns: H x W x C -> W x H x C
im_data = single(im_data);                % uint8 -> single
im_data = imresize(im_data, [IMAGE_DIM IMAGE_DIM], 'bilinear');
im_data = im_data - mean_data;
% Store the network input in its own variable instead of overwriting `im`.
crop = imresize(im_data, [CROPPED_DIM CROPPED_DIM], 'bilinear');

% The prototxt declares input_param { shape: { dim: 10 dim: 3 dim: 227 dim: 227 } };
% matcaffe lists the dimensions in reverse order (227 x 227 x 3 x 10), so the
% image is replicated 10 times along the 4th dimension.
batch = repmat(crop, [1, 1, 1, 10]);
scores = net.forward({batch});
scores = scores{1};
scores = mean(scores, 2);
[~, maxlabel] = max(scores);
maxlabel
figure; plot(scores);

% Fetch blobs by name instead of hard-coded net.blob_vec positions, which
% point at the wrong blob if the prototxt changes.
d1 = net.blobs('data').get_data();
fprintf('Data size=')
size(d1)
visualize_feature_maps(d1, 1);

conv_blobs = {'conv1', 'conv2', 'conv3', 'conv4', 'conv5'};
for k = 1:numel(conv_blobs)
    f = net.blobs(conv_blobs{k}).get_data();
    fprintf('Feature map %s size=', conv_blobs{k})
    size(f)
    visualize_feature_maps(f, 1);
end
function varargout = visualize_feature_maps(w, s, n)
% VISUALIZE_FEATURE_MAPS  Tile every channel of one sample of a blob in a figure.
%   visualize_feature_maps(w, s) draws the transposed slice w(:,:,k,1) of
%   each channel k = 1..size(w,3) as one tile of a ceil(sqrt(C)) x
%   ceil(sqrt(C)) grid, filled row by row. Each channel is min-max
%   normalized to 0..255 on its own; unused grid cells stay black.
%   s is the number of blank pixels added to the tile size so that
%   neighbouring tiles are visually separated.
%
%   visualize_feature_maps(w, s, n) shows sample n (4th index) instead of 1.
%   W = visualize_feature_maps(...) additionally returns the uint8 mosaic.
if nargin < 3
    n = 1;
end
map_w = size(w, 1);             % tile width once a slice is transposed
map_h = size(w, 2);             % tile height once a slice is transposed
g = max(map_w, map_h) + s;      % distance between tile origins
c = size(w, 3);
cv = ceil(sqrt(c));             % cv x cv grid has room for all c channels
W = zeros(g * cv, g * cv);
for k = 1:c
    % Channel k goes to grid row u, column v (row-major fill order).
    u = floor((k - 1) / cv) + 1;
    v = mod(k - 1, cv) + 1;
    % double() so integer-typed input is not quantized before scaling.
    tw = double(w(:, :, k, n)');
    tw = tw - min(tw(:));
    peak = max(tw(:));
    % A constant channel (e.g. all zeros) stays black instead of going
    % through a 0/0 = NaN division.
    if peak > 0
        tw = tw / peak * 255;
    end
    % Use the slice's real size so non-square maps don't cause a
    % dimension-mismatch error.
    W(g * (u - 1) + (1:map_h), g * (v - 1) + (1:map_w)) = tw;
end
W = uint8(W);
figure, imshow(W);
if nargout > 0
    varargout{1} = W;
end