1. 程式人生 > >Pytorch的mean和std調查

Pytorch的mean和std調查


# coding: utf-8

from __future__ import print_function
import copy
import click
import cv2
import numpy as np
import torch
from torch.autograd import Variable
from torchvision import models, transforms

import matplotlib.pyplot as plt
import load_caffemodel
import scipy.io as sio

# If the model has an LSTM, cuDNN may need to be disabled:
# torch.backends.cudnn.enabled = False

imgpath = 'D:/ck/files_detected_face224/'
imgname = 'S055_002_00000025.png'  # anger
image_path = imgpath + imgname

# Per-channel normalization constants handed to transforms.Normalize below.
mean_file = [0.485, 0.456, 0.406]
std_file = [0.229, 0.224, 0.225]


def denormalize(chw_image, mean, std):
    """Invert ``ToTensor`` + ``Normalize`` on a single C x H x W image.

    Args:
        chw_image: array-like of shape (C, H, W) holding the normalized image.
        mean: per-channel means that were subtracted (length C).
        std: per-channel standard deviations that were divided by (length C).

    Returns:
        A float array of shape (H, W, C) scaled back to the 0..255 range.
    """
    # Reshape to (C, 1, 1) so the constants broadcast over H and W.
    channel_mean = np.reshape(mean, (-1, 1, 1))
    channel_std = np.reshape(std, (-1, 1, 1))
    restored = np.asarray(chw_image) * channel_std + channel_mean
    # C x H x W -> H x W x C, then undo the division by 255.
    return np.transpose(restored, (1, 2, 0)) * 255


def max_abs_error(reference, restored):
    """Return the largest absolute element-wise difference between two arrays.

    The absolute value matters: taking ``max`` of the signed difference hides
    every element where ``restored`` is larger than ``reference``.
    """
    diff = np.asarray(reference, dtype=np.float64) - np.asarray(restored, dtype=np.float64)
    return float(np.max(np.abs(diff)))


def main():
    """Normalize one image, undo the normalization and report the round-trip error."""
    bgr_image = cv2.imread(image_path)
    if bgr_image is None:
        # cv2.imread signals an unreadable/missing file by returning None.
        raise FileNotFoundError('cannot read image: ' + image_path)

    # Reverse the channel axis (OpenCV loads BGR) and materialize the
    # reversed view as a contiguous array before passing it on.
    raw_image = np.ascontiguousarray(bgr_image[..., ::-1])
    print(raw_image.shape)
    raw_image = cv2.resize(raw_image, (224, ) * 2)

    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(
            mean=mean_file,
            std=std_file,
        ),
    ])
    image = preprocess(raw_image).unsqueeze(0)
    print(image.shape)

    # Drop the batch axis: 3 x 224 x 224 (C x H x W) goes in, H x W x C comes out.
    convert_image1 = denormalize(np.squeeze(image.numpy()), mean_file, std_file)
    print(convert_image1.shape)

    err = max_abs_error(raw_image, convert_image1)
    print(err)

    # Round before casting: a plain uint8 cast truncates, so float round-off
    # such as 254.99998 would be displayed as 254.
    plt.imshow(np.clip(np.rint(convert_image1), 0, 255).astype(np.uint8))
    plt.show()


if __name__ == '__main__':
    main()

結論:

input_image = (raw_image / 255 - mean) ./ std 

下面調查均值(mean_file)和標準差(std_file)是如何生成的:

# Per-channel normalization constants passed to transforms.Normalize in the
# script above; they apply to pixel values already scaled to [0, 1].
# NOTE(review): presumably in R, G, B channel order -- confirm.
mean_file = [0.485, 0.456, 0.406]
std_file  = [0.229, 0.224, 0.225]
# coding: utf-8
import matplotlib.pyplot as plt
import argparse
import os
import numpy as np
import torchvision
import torchvision.transforms as transforms

dataset_names = ('cifar10', 'cifar100', 'mnist')


def raw_pixels(train_set):
    """Return the raw (un-transformed) training images stored on a dataset.

    Newer torchvision releases expose them as ``data``; older releases used
    ``train_data``. Prefer ``data`` and fall back to the legacy attribute.
    """
    pixels = getattr(train_set, 'data', None)
    if pixels is None:
        pixels = train_set.train_data
    return pixels


def channel_stats(pixels):
    """Return per-channel ``(mean, std)`` of 0..255 images, rescaled to [0, 1].

    Args:
        pixels: array-like whose LAST axis is the channel axis
            (e.g. N x H x W x C); every other axis is reduced.

    Returns:
        Tuple ``(mean, std)``, each an array with one entry per channel.
    """
    pixels = np.asarray(pixels)
    reduce_axes = tuple(range(pixels.ndim - 1))
    return pixels.mean(axis=reduce_axes) / 255, pixels.std(axis=reduce_axes) / 255


def main():
    """Download the chosen dataset and print its training-set mean and std."""
    parser = argparse.ArgumentParser(description='PyTorchLab')
    parser.add_argument('-d', '--dataset', metavar='DATA', default='cifar10', choices=dataset_names,
                        help='dataset to be used: ' + ' | '.join(dataset_names) + ' (default: cifar10)')
    args = parser.parse_args()

    data_dir = os.path.join('.', args.dataset)
    print(args.dataset)
    # NOTE: the dataset chosen on the command line is honoured; the previous
    # hard-coded override to 'cifar10' made the other branches unreachable.

    train_transform = transforms.Compose([transforms.ToTensor()])

    if args.dataset == "cifar10":
        train_set = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=True, transform=train_transform)
        train_data = raw_pixels(train_set)
        print(train_data.shape)
        mean, std = channel_stats(train_data)
        print(mean)
        print(std)

        # Show one sample. The stored channel order is RGB (not BGR as in
        # caffe/OpenCV), so matplotlib can display it without reordering.
        ind = 100
        img0 = train_data[ind, ...]
        print(img0.shape)
        print(type(img0))
        plt.imshow(img0)
        plt.show()  # per the original author: a ship at sea

    elif args.dataset == "cifar100":
        train_set = torchvision.datasets.CIFAR100(root=data_dir, train=True, download=True, transform=train_transform)
        train_data = raw_pixels(train_set)
        print(train_data.shape)
        mean, std = channel_stats(train_data)
        print(mean)
        print(std)

    elif args.dataset == "mnist":
        train_set = torchvision.datasets.MNIST(root=data_dir, train=True, download=True, transform=train_transform)
        train_data = raw_pixels(train_set)
        # Single-channel tensor data: one global mean/std over all pixels.
        print(list(train_data.size()))
        print(train_data.float().mean() / 255)
        print(train_data.float().std() / 255)


if __name__ == '__main__':
    main()

結果:

cifar10
Files already downloaded and verified
(50000, 32, 32, 3)
[ 0.49139968  0.48215841  0.44653091]
[ 0.24703223  0.24348513  0.26158784]
(32, 32, 3)
<class 'numpy.ndarray'>

使用 MATLAB 驗證 mean_file 和 std_file 是如何計算的:


% Load the CIFAR-10 training images exported to a .mat file.
% Assumes train_data is N x H x W x 3 (channel last) -- TODO confirm.
loaded = load('cifar10_train_data.mat');
train_data = loaded.train_data;
disp(size(train_data));

% Show the size that remains after averaging over the first dimension only.
disp(size(mean(train_data, 1)));

% Work in double precision for the statistics below.
train_data = double(train_data);

% Per-channel mean: average over dimensions 1, 2 and 3 in turn, rescale by
% 255, then drop the singleton dimensions.
mean_val = squeeze(mean(mean(mean(train_data, 1), 2), 3) / 255);

% Per-channel standard deviation: flatten each channel into one vector,
% take its std, and rescale by 255.
num_channels = 3;
std_val = zeros(1, num_channels);
for c = 1:num_channels
    channel = train_data(:, :, :, c);
    std_val(c) = std(channel(:)) / 255;
end

disp(mean_val);
disp(std_val);

% result: mean_val: [0.4914, 0.4822, 0.4465]
%          std_val: [0.2470, 0.2435, 0.2616]

均值的計算也可以仿照標準差的計算過程(先把每個通道展開成向量再求值)。為了簡單,這裡利用一個性質:對於一個矩陣,所有元素的均值等於沿各個維度依次求均值的結果。所以會直接採用如下的形式:

mean_val = mean(mean(mean(train_data,1),2),3)/255;

標準差的計算是對每一個通道,在所有樣本的所有像素上求標準差,然後再除以255。