pytorch 影象中的資料預處理和批標準化例項
阿新 • • 發佈:2020-01-16
目前資料預處理最常見的方法就是中心化和標準化。
中心化相當於修正資料的中心位置,實現方法非常簡單,就是在每個特徵維度上減去對應的均值,最後得到 0 均值的特徵。
標準化也非常簡單,在資料變成 0 均值之後,為了使得不同的特徵維度有著相同的規模,可以除以標準差近似為一個標準正態分佈,也可以依據最大值和最小值將其轉化為 -1 ~ 1 之間
批標準化:BN
在資料預處理的時候,我們儘量輸入特徵不相關且滿足一個標準的正態分佈,這樣模型的表現一般也較好。但是對於很深的網路結構,網路的非線性層會使得輸出的結果變得相關,且不再滿足一個標準的 N(0,1) 的分佈,甚至輸出的中心已經發生了偏移,這對於模型的訓練,特別是深層的模型訓練非常的困難。
所以在 2015 年一篇論文提出了這個方法,批標準化,簡而言之,就是對於每一層網路的輸出,對其做一個歸一化,使其服從標準的正態分佈,這樣後一層網路的輸入也是一個標準的正態分佈,所以能夠比較好的進行訓練,加快收斂速度。
batch normalization 的實現非常簡單,接下來寫一下對應的python程式碼:
import sys sys.path.append('..') import torch def simple_batch_norm_1d(x,gamma,beta): eps = 1e-5 x_mean = torch.mean(x,dim=0,keepdim=True) # 保留維度進行 broadcast x_var = torch.mean((x - x_mean) ** 2,keepdim=True) x_hat = (x - x_mean) / torch.sqrt(x_var + eps) return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean) x = torch.arange(15).view(5,3) gamma = torch.ones(x.shape[1]) beta = torch.zeros(x.shape[1]) print('before bn: ') print(x) y = simple_batch_norm_1d(x,beta) print('after bn: ') print(y)
測試的時候該使用批標準化嗎?
答案是肯定的,因為訓練的時候使用了,而測試的時候不使用肯定會導致結果出現偏差,但是測試的時候如果只有一個數據集,那麼均值不就是這個值,方差為 0 嗎?這顯然是隨機的,所以測試的時候不能用測試的資料集去算均值和方差,而是用訓練的時候算出的移動平均均值和方差去代替
下面我們實現以下能夠區分訓練狀態和測試狀態的批標準化方法
def batch_norm_1d(x,beta,is_training,moving_mean,moving_var,moving_momentum=0.1): eps = 1e-5 x_mean = torch.mean(x,keepdim=True) if is_training: x_hat = (x - x_mean) / torch.sqrt(x_var + eps) moving_mean[:] = moving_momentum * moving_mean + (1. - moving_momentum) * x_mean moving_var[:] = moving_momentum * moving_var + (1. - moving_momentum) * x_var else: x_hat = (x - moving_mean) / torch.sqrt(moving_var + eps) return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean)
下面我們在卷積網路下試用一下批標準化看看效果
def data_tf(x): x = np.array(x,dtype='float32') / 255 x = (x - 0.5) / 0.5 # 資料預處理,標準化 x = torch.from_numpy(x) x = x.unsqueeze(0) return x train_set = mnist.MNIST('./data',train=True,transform=data_tf,download=True) # 重新載入資料集,申明定義的資料變換 test_set = mnist.MNIST('./data',train=False,download=True) train_data = DataLoader(train_set,batch_size=64,shuffle=True) test_data = DataLoader(test_set,batch_size=128,shuffle=False) # 使用批標準化 class conv_bn_net(nn.Module): def __init__(self): super(conv_bn_net,self).__init__() self.stage1 = nn.Sequential( nn.Conv2d(1,6,3,padding=1),nn.BatchNorm2d(6),nn.ReLU(True),nn.MaxPool2d(2,2),nn.Conv2d(6,16,5),nn.BatchNorm2d(16),2) ) self.classfy = nn.Linear(400,10) def forward(self,x): x = self.stage1(x) x = x.view(x.shape[0],-1) x = self.classfy(x) return x net = conv_bn_net() optimizer = torch.optim.SGD(net.parameters(),1e-1) # 使用隨機梯度下降,學習率 0.1 train(net,train_data,test_data,5,optimizer,criterion)
以上這篇pytorch 影象中的資料預處理和批標準化例項就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。