1. 程式人生 > >Pytorch MNIST資料集標準化為什麼是transforms.Normalize((0.1307,), (0.3081,))

Pytorch MNIST資料集標準化為什麼是transforms.Normalize((0.1307,), (0.3081,))

Pytorch已經提供了MNIST資料集,只要呼叫datasets.MNIST()下載即可,這裡要注意的是標準化(Normalization):

transforms.Normalize((0.1307,), (0.3081,))

標準化(Normalization)

和基於決策樹的機器學習模型,如RF、xgboost等不同的是,神經網路特別鍾愛經過標準化處理後的資料。標準化處理指的是,data減去它的均值,再除以它的標準差,最終data將呈現均值為0方差為1的資料分佈。決策樹模型在哪裡split特徵是由特徵序列決定的,跟具體數值無關,所以並不要求資料做標準化處理,至於詳細原因以後有機會寫機器學習博文時再詳述。

神經網路模型偏愛標準化資料,原因是均值為0方差為1的資料在sigmoid、tanh經過啟用函式後求導得到的導數很大,反之原始資料不僅分佈不均(噪聲大)而且數值通常都很大(本例中數值範圍是0~255),啟用函式後求導得到的導數則接近與0,這也被稱為梯度消失。前文已經分析,神經網路是根據函式對權值求導的導數來調整權值,導數越大,調整幅度越大,越快逼近目標函式,反之,導數越小,調整幅度越小,所以說,資料的標準化有利於加快神經網路的訓練。

除此之外,還需要保持train_set、val_set和test_set標準化係數的一致性。標準化係數就是計算要用到的均值和標準差,在本例中是((0.1307,), (0.3081,)),均值是0.1307,標準差是0.3081,這些係數都是資料集提供方計算好的資料

不同資料集就有不同的標準化係數,例如([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])就是Imagenet dataset的標準化係數(RGB三個通道對應三組係數),當需要將imagenet預訓練的引數遷移到另一神經網路時,被遷移的神經網路就需要使用imagenet的係數,否則預訓練不僅無法起到應有的作用甚至還會幫倒忙,

例如,我們想要用神經網路來識別夜空中的星星,因為黑色是夜空的主旋律,從畫素上看黑色就是資料集的均值,標準化操作時,所有影象會減去均值(黑色),如此Imagenet預訓練的神經網路很難識別出這些資料是夜空影象。