PyTorch 的 mean 和 std 調查例項
阿新 • • 發佈:2020-01-09
如下所示:
# coding: utf-8 from __future__ import print_function import copy import click import cv2 import numpy as np import torch from torch.autograd import Variable from torchvision import models,transforms import matplotlib.pyplot as plt import load_caffemodel import scipy.io as sio # if model has LSTM # torch.backends.cudnn.enabled = False imgpath = 'D:/ck/files_detected_face224/' imgname = 'S055_002_00000025.png' # anger image_path = imgpath + imgname mean_file = [0.485,0.456,0.406] std_file = [0.229,0.224,0.225] raw_image = cv2.imread(image_path)[...,::-1] print(raw_image.shape) raw_image = cv2.resize(raw_image,(224,) * 2) image = transforms.Compose([ transforms.ToTensor(),transforms.Normalize( mean=mean_file,std =std_file,#mean = mean_file,#std = std_file,) ])(raw_image).unsqueeze(0) print(image.shape) convert_image1 = image.numpy() convert_image1 = np.squeeze(convert_image1) # 3* 224 *224,C * H * W convert_image1 = convert_image1 * np.reshape(std_file,(3,1,1)) + np.reshape(mean_file,1)) convert_image1 = np.transpose(convert_image1,(1,2,0)) # H * W * C print(convert_image1.shape) convert_image1 = convert_image1 * 255 diff = raw_image - convert_image1 err = np.max(diff) print(err) plt.imshow(np.uint8(convert_image1)) plt.show()
結論:
input_image = (raw_image / 255 - mean) ./ std
(即先將像素值除以 255 縮放到 [0, 1],再逐通道減去 mean、除以 std;./ 表示逐元素除法。)
下面調查均值檔案和標準差檔案(mean_file、std_file)是如何生成的:
# Per-channel (R, G, B) mean and standard deviation, scaled to [0, 1].
mean_file = [0.485, 0.456, 0.406]
std_file = [0.229, 0.224, 0.225]
# coding: utf-8
"""Print per-channel mean/std of a torchvision training set.

The statistics are taken over every sample and pixel and divided by 255,
giving the values that are passed to ``transforms.Normalize``.
"""
import matplotlib.pyplot as plt
import argparse
import os
import numpy as np
import torchvision
import torchvision.transforms as transforms

dataset_names = ('cifar10', 'cifar100', 'mnist')


def _train_images(dataset):
    """Return the raw training images held by a torchvision dataset.

    Newer torchvision versions store them in ``.data``; older versions used
    ``.train_data``. Try the new name first and fall back to the old one.
    """
    images = getattr(dataset, 'data', None)
    return dataset.train_data if images is None else images


def _print_rgb_stats(images):
    """Print the array shape, then per-channel mean and std scaled by 1/255.

    Assumes ``images`` is channel-last (N, H, W, C). Reducing over axes
    0-2 leaves one value per channel.
    """
    axes = (0, 1, 2)
    print(images.shape)
    print(np.mean(images, axis=axes) / 255)
    print(np.std(images, axis=axes) / 255)


parser = argparse.ArgumentParser(description='PyTorchLab')
parser.add_argument('-d', '--dataset', metavar='DATA', default='cifar10',
                    choices=dataset_names,
                    help='dataset to be used: ' + ' | '.join(dataset_names)
                    + ' (default: cifar10)')
args = parser.parse_args()
data_dir = os.path.join('.', args.dataset)
print(args.dataset)

# The stats below read the raw arrays directly, so this transform is
# never applied here; it is only passed to the dataset constructors.
train_transform = transforms.Compose([transforms.ToTensor()])

if args.dataset == 'cifar10':
    train_set = torchvision.datasets.CIFAR10(
        root=data_dir, train=True, download=True, transform=train_transform)
    train_data = _train_images(train_set)
    _print_rgb_stats(train_data)

    # Show one sample to check the stored channel order: it is RGB,
    # not BGR as in Caffe (the original author notes index 100 is a
    # ship at sea).
    ind = 100
    img0 = train_data[ind, ...]
    print(img0.shape)
    print(type(img0))
    plt.imshow(img0)
    plt.show()
elif args.dataset == 'cifar100':
    train_set = torchvision.datasets.CIFAR100(
        root=data_dir, train=True, download=True, transform=train_transform)
    _print_rgb_stats(_train_images(train_set))
elif args.dataset == 'mnist':
    train_set = torchvision.datasets.MNIST(
        root=data_dir, train=True, download=True, transform=train_transform)
    train_data = _train_images(train_set)
    # MNIST images are a single-channel tensor: one global mean/std.
    print(list(train_data.size()))
    print(train_data.float().mean() / 255)
    print(train_data.float().std() / 255)
結果:
cifar10
Files already downloaded and verified
(50000, 32, 32, 3)
[0.49139968 0.48215841 0.44653091]
[0.24703223 0.24348513 0.26158784]
(32, 32, 3)
<class 'numpy.ndarray'>
使用 MATLAB 驗證 mean_file 和 std_file 是如何計算的:
% Cross-check the CIFAR-10 per-channel mean/std (scaled to [0,1]) that the
% Python script prints.
% load cifar10 dataset
data = load('cifar10_train_data.mat');
train_data = data.train_data;
disp(size(train_data));
temp = mean(train_data,1);
disp(size(temp));
train_data = double(train_data);
% NOTE(review): assumes train_data is N x H x W x C (channel last), matching
% the Python array (50000,32,32,3) -- confirm against the .mat file.
% compute mean_file: average over samples, rows and columns -> 1x1x1xC
mean_val = mean(mean(mean(train_data,1),2),3)/255;
% compute std_file: one std per channel over ALL samples and pixels.
% Index the 4th (channel) dimension explicitly; train_data(:,:,c) on a 4-D
% array would merge dims 3-4 and pick the wrong elements.
% MATLAB std normalizes by N-1 (numpy by N); negligible for N this large.
num_channels = size(train_data,4);
std_val = zeros(1,num_channels);
for c = 1:num_channels
    channel = train_data(:,:,:,c);
    std_val(c) = std(channel(:))/255;
end
mean_val = squeeze(mean_val);
disp(mean_val);
disp(std_val);
% result: mean_val: [0.4914,0.4822,0.4465]
% std_val: [0.2470,0.2435,0.2616]
均值的計算也可以仿照標準差的計算過程,逐通道地對所有元素求均值。為了簡單,可以利用這樣一個性質:對於一個矩陣,所有元素的均值等於先沿一個方向求均值、再沿另一個方向求均值的結果(因為每一行、每一列的元素個數都相同)。所以直接採用如下的形式:
% Nested means over samples (dim 1), rows (dim 2) and columns (dim 3)
% leave one mean per channel; divide by 255 to scale to [0,1].
mean_val = mean(mean(mean(train_data,1),2),3)/255;
標準差的計算是對每一個通道,在所有樣本的所有像素上求標準差,然後再除以 255。注意標準差不能像均值那樣逐維巢狀計算:巢狀求得的標準差一般不等於所有元素的整體標準差,所以需要先把每個通道的全部元素展開成一個向量(如 temp1(:))再計算。
以上這篇Pytorch的mean和std調查例項就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。