1. 程式人生 > >論文:accurate ,large minibatch SGD:Training ImageNet in 1 Hour

論文:accurate ,large minibatch SGD:Training ImageNet in 1 Hour

Abstract:

這篇論文發現,在 ImageNet dataset 上使用 large minibatch 會導致優化困難,但是當這個問題解決了,模型具有更好的泛化能力,並且沒有精度上的損失

為達到這個目的,我們提出了 hyper-parameter-free linear scaling rule,用來調整學習率,學習率是有關於 minibatch size 的一個函式,還提出了一個 warmup scheme 用來克服訓練早期的優化問題

1 Introduction:

本文目的是介紹 分散式同步 SGD 完成 large-scale training,我們可以將 ResNet-50 從 minibatch size 256 時間 29 hours 縮短到 minibatch size 8192 in 1 hour,獲得的精度一樣的,如下圖

為了解決 large minibatch size,我們提出了一個簡單的 hyper-parameter-free linear scaling rule 來調整學習率,為了成功應用此 rule,我們提出了一個新的 warmup strategy. 這個 strategy 在訓練的初期使用低的學習率來克服優化困難


我們之後的實驗說明了優化困難最主要的問題是 large minibatch 而不是 poor generalization ( 至少在 ImageNet 上是),而且我們說明了 linear scaling rule 和 warmup strategy 可以推廣到更復雜的任務,比如 detection 和 instance segmentation.

雖然這個 strategy 很簡單,但是它的應用需要比較好的理解,SGD裡面很小的改變有時候會得到很難發現的錯誤的結果,之後我們會描述這些常見的陷阱和解決的細節,我們的策略還需要非平凡的通訊演算法

在工業界,我們可以釋放模型訓練大量資料的潛能,在學術界我們可以簡化從單 GPU 到多 GPU 的遷移而不需要超引數搜尋

2 Large Minibatch SGD:

首先回顧基本的隨機梯度下降方法


w 是 weight , x 是有標籤的訓練資料 l(x,w)是計算的 loss ,通常 loss 是 classification loss (cross-entropy)和 regularization loss on w 的和

Minibatch SGD 在最近的文獻中被簡稱為 SGD,它的更新函式如下:


其中 B 是一個minibatch 的sample,n 是 minibatch size , η 是學習率,我們使用的是 momentum SGD ,在之後的第3部分進行討論

2.1 learning rates for large minibatches

large minibatch 在分散式學習中可以利用資料並行性使用多個 work 工作,並且不會減少每一個 work 的工作量也不會犧牲模型的精度

Liner Scaling Rule: When the minibatch size is multiplied by K ,multiply the learning rate by K

這個 rule 在 broad range of minibatch size 裡都很有效果,其他的 hyper-parameters(weight decay 等)都保持不變,在第 5 部分,我們將會展示 linear sacling rule 不僅可以在 small 和 large minibath 中 math accuracy ,還可以 match training curves

我們比較了 k minibatch ,每一個batch size 為 n ,學習率為 η  和 一個 minibatch ,size 為 kn, 學習率為 

第一種的更新函式為???

第二種的更新函式為???

在一個很強的假設,即 l(x,wt) 和 l(x,w(t+j)) 的梯度相等的條件的,設定 ,可以獲得

但是這個假設在兩種情況下不存在,一種是訓練初期,網路變化的很快,第二種是 minibatch size 不可以無限的縮放,雖然結果在很大的 size 時也會保持很高的精度,但是在超過某個點後會迅速的下降

2.2 warmup

上面的第一種情況可以使用 warmup 來解決

Constant warmup:在訓練的 first few epochs 使用 low constant learning rate. 這個 strategy 在目標檢測和語義分割上fine pre-trained layers together with newly initialized layers 很有效,在 ImageNet kn minibatch size的實驗中,先使用小學習率 η 學習 first 5 epoch ,之後使用,學習。然而當 k 比較大的時候,constant warmup 策略對收斂並不充分,並有可能使訓練誤差增大,所以提出下面的方法

Gradual warmup:逐漸將學習率從小到大增大,可以避免學習率的突然增大,保證訓練初期的健康收斂。在 kn 的minibatch size 下,一開始使用 η 學習率,然後在 5 epoch 後逐漸增大至 ,warmup 後,回到正常的 learning rate schedule.

2.3 batch normalization with large minibatches

3 Subtleties and Pitfalls of Distributed SGD

在分散式計算中,許多 common implementation errors 會改變超引數的定義,模型雖然可以訓練但誤差會比較大

  • weight decay:

l2 regularization on the weights

如果沒有 weight decay , 就會有很多種方法來縮放學習率,例如縮放loss 的梯度項,但是我們發現縮放 loss 和縮放學習率並不等價

  • mometum correction:


m 是 momentum 的 decay factor , u 是 update tensor.

還有一種流行的將學習率加到 update tensor 項中


對於 fixed 的學習率,這兩個是等價的,但是我們可以發現,u 和學習率是無關的,v 和學習率是有關的,如果學習率改變了,為了使第二個式子和第一個等價 ,v 應該變為我們將  factor 當做 momentum correction,這一項對訓練 stabilize 很重要,尤其是在 t+1 的學習率遠大於 t 的學習率時,否則的話,history term 就會變得很小使得訓練不穩定

 

  • gradient aggregation

對於 k 個 worker,每一個 worker 的 minibatch size 為 n,梯度更新的時候除以 kn ,而 loss layer 通常會將每一個 worker 的平均梯度加起來


  • data shuffling


4 Communication

5 Main Results and Analysis

我們的主要結果是使用256 workers 一小時內在 ImageNet 上訓練 ResNet-50 網路,獲得了和 small minibatch size 同樣的精度。使用 linear scaling rule 和 warmup 策略允許我們不用調整超引數和影響精度的情況下縮放 batch size 

  • minibath size vs error 


minibath sizes 從64 到 65536(64k),所有的模型都使用 linear scaling rule ,在 kn > 256 時,使用 gradual warmup 策略,從上圖可以發現,在 8k 之後驗證誤差就會變大

  • warmup 


  • Training curves for various minibatch size


比較了不同 minibatch size 的 training curves 和 256 minibatch baseline 

  • Alternative learning rate rules

對於小批次 256 ,學習率取 0.1 獲得最小的 error,但是大的或者小的學習率也可以獲得比較好的結果,當在 8k images 上使用 linear scaling rule 時,學習率在 0.1*32 獲得最好的結果


當改變學習率時,會改變整個 trianing curves,即使最後的誤差是相同的。而線性縮放規則可以在誤差和training curves 都相同。

5.4 generalization to detection and segmentation

為了確定 large minibatch 和 small 學到的特徵是否一樣好,在 COCO  detection 和 instance segmentation 上使用 ImageNet pre-training

為了驗證 large minibatch pre-training 對 Mask R-CNN 的影響,使用 ResNet-50 訓練 ImageNet-1k,minibatch 從 256到 16k,之後使用這個model 初始化 Mask R-CNN 

 

只要 ImageNet validation error 很低,直到 up 8k,detection 的 AP 與之匹配,當資料集切換和任務切換時,用 large minibatch 並不沒有什麼問題

同樣,linear scaling rule 在 Mask R-CNN 上也適用