1. 程式人生 > 實用技巧 >Federated Learning: 問題與優化演算法

Federated Learning: 問題與優化演算法

工作原因,聽到和使用Federated Learning框架很多,但是對框架內的演算法和架構瞭解不夠細緻,特讀論文以記之。

這個系列計劃要寫的文章包括:

Overall

Federated Learning, 中文翻譯是聯合學習,或者很多人翻譯成聯邦學習,是一種在移動端訓練模型的框架。

不知為何,翻譯成聯邦學習總讓我有點笑場,就像one-hot編碼被翻譯成獨熱編碼一樣。難不成以後還有帝國學習,共和國學習? 下面只說聯合學習。

正常的機器學習/深度學習模型都是在伺服器端直接訪問資料進行訓練,但在實際的場景中,有很多情況下資料是不在伺服器端的:

  • 隱私內容: 比如商業資料,比如使用者在輸入法中直接輸入的資料。
  • 資料量大: 不太適合把所有資料都log到伺服器上。

聯合學習就是為了應對這種場景而生的。

聯合學習

聯合學習把資料和演算法解耦合。在模型的訓練中,首先把伺服器把模型當前狀態傳送給移動端,移動端利用當前的模型狀態和本地資料去進行計算,然後把梯度傳送給伺服器端,伺服器端再去匯合不同裝置上傳回的梯度去進行模型的更新。

這樣的訓練看著很直觀,但是相對於資料直接在伺服器端來說,有如下問題:

  • 資料並非獨立同分布的。如果資料在伺服器端,那麼可以通過shuffle來讓資料分佈均勻,但是每一臺device上,資料是有很強的bias的。
  • 資料不均衡。有的裝置上資料量很大,有的則很少。
  • 大規模分散式。參加訓練的裝置相對於裝置上的平均樣本數來說要大的多。
  • 有限通訊。頻寬很寶貴,因此訓練過程中要儘可能的減少伺服器和裝置交流的次數。

除了這些之外,還有一些問題不在本文的討論之中,但確也是非常實際的:

  • 客戶端資料在隨時發生變化。
  • device的可達性和資料的分佈有一種複雜的相關關係,比如,時區的原因,美式英語的使用者和英式英語的使用者在不同的時間上線參與訓練。
  • device不返回梯度或者返回損壞的梯度。

為了解決上述的問題,聯合學習採用的是可控環境下的同步式訓練:

  • 假設一共有K個客戶端參與聯合學習
  • 每次選擇C%的線上客戶端。
    • 做這個選擇是為了提高效率和減少錯誤率。
  • 伺服器端傳送模型當前狀態給選中的客戶端。
  • 客戶端進行本地計算,參與訓練的資料量為B(local_batch_size),得到梯度。
  • 客戶端傳送梯度更新給伺服器。
  • 伺服器進行聚合和更新全域性模型。

聚合梯度的公式如下,即不同client返回的梯度按照client上樣本數目進行加權。這裡假設資料是獨立同分布的,當然,因為這個條件不成立,所以這只是一個近似。

FederatedAveraging演算法

而聯合學習的訓練過程中,通訊將會是瓶頸,因為網路傳輸的頻寬比較小,聯合學習一般設定最多佔有1M/s的頻寬。而由於很多device上資料較少或者有高階核心(很多裝置都有GPU),所以算力反而不是問題。

而為了減少通訊次數,有兩種辦法:

  • 增大並行程度,即增大C,在每一輪訓練中增加參與計算的裝置。
    • 但這就面臨裝置出錯率變高的問題。
  • 增大每個裝置上單輪的計算,即在每一輪訓練中,每臺裝置上可能要計算多輪累積的梯度。
    • 這會遇到梯度更新不精確的問題。
    • 但後面會講到,這個問題在實驗中並不存在。

因而,在論文中,比較了兩種方法:

  • FedSGD: 就是SGD的聯合學習版本,每次訓練都使用device上的所有資料作為一個batch。進行屬於增大並行程度的方法,當C=1的時候,可以認為是Full-Batch訓練。
  • FederatedAveraging: 基於FedSGD,但是在device上可以訓練多步累積梯度,屬於增大每個裝置上單輪的運算。
    • 除了上面提到的K、C、B三個引數外,增加一個引數E,代表在device上每輪訓練執行的計算的次數。所以當B=全部,E=1的時候,FederatedAveraging與FedSGD等價。

演算法流程如下圖所示:

模型混合

經過FederatedAveraging學到的模型,有點類似於模型混合。因為模型在每個device上經過多步訓練之後可能會變得很不一樣。

而在通用的模型混合問題中,最基本的要求就是模型的初始化要一致。如下圖所示,不同方式初始化的模型做平均會得到差的結果(左圖),而相同的則是得到好的結果(右圖)。

# 實驗

增大客戶端數目

首先使用MNIST做了一個模擬實驗,實驗分為IID和NON-IID資料集+不同的E/B引數。

MNIST一共十個類別,IID資料集是將資料集混排後隨即分到100個客戶端上,而NON-IID則是在每個客戶端上只有2類的資料集,資料集都是均衡分佈在各個客戶端上的。

下圖中,2NN是2層全連線神經網路,CNN是一個2層的卷積網路,每層卷積之後都有一個pooling,最後是一個512的全連線層。表格中的數字代表的是達到某個準確率需要的通訊次數。其中2NN部分是達到97%準確率,CNN部分是達到99%準確率。

調整C,結果從下圖可以得到:

  • 參與的客戶端越多,速度越快。
  • B=全部的時候,增多客戶端,帶來的提升比較小,而在B=10的時候,增多客戶端,能帶來顯著的速度提升。

增大客戶端上的計算量

保持C=0.1,增大每輪訓練在device上的計算梯度的次數,即增大E,得到的實驗結果如下。 其中u代表的是每輪實驗梯度被計算的次數。可以看到,在IID資料上提速很大,在NON-IID上提速小,但是也能有將近三倍的提升。

同時,還做了一個LSTM語言模型上的實驗,這個實驗的設定跟MNIST很像,也分為IID和NON-IID,其中NON-IID是按照人物角色來分的。同時,IID是均衡資料集,NON-IID是不均衡資料集。

可以看到,在不均衡的NON-IID資料集上,FEDAVG卻能帶來95.3倍的提升,反而比IID均衡資料集要快。

但是需要注意的是,一直增大E,結果反而會適得其反,因為會導致模型在各個客戶端上發散。因為會導致模型發散。如下圖所示。

所以對於一些模型,比較好的方法是讓E隨著訓練步數的增加而遞減。這樣有利於收斂。

Cifar10實驗

在Cifar10上也進行了實驗,這次是均衡的IID資料,結果如下圖,可以看到,相對於普通的SGD,達到相同的準確率,FedSGD和FedAvg都有更少的通訊次數。

大規模LSTM Next Word Prediction實驗

將10M個某社交網站文件分到50k個裝置上,同一個作者的會被分到同一個裝置上,同時每個裝置限制嘴都5000個詞語。LSTM詞表是10k。LSTM是單層256節點。embedding是192,LSTM輸入的序列長度是10。

結果如下圖, FedAvg在35輪的時候就能達到SGD在伺服器端的效果。同時比FedAvg快23倍。

總結與思考

作為聯合學習實用化的開山之作,論文提出的FedAvg優化演算法,做了很多的對比實驗,實驗在不同的資料集上得到的略有不同的結論。但證明了在裝置端做mini-batch的是完全可行的,同時,裝置端還可以多做幾輪計算來積累梯度也有助於減少通訊次數。

與其他的演算法不同,聯合學習考慮的不再是算力問題,而是通訊問題,減少通訊次數成了最高優先順序,這點是個全新的思考方向。

勤思考, 多提問是Engineer的良好品德。

提問:

  • 如果裝置端只返回梯度,那麼有沒有可能通過梯度反推資料呢?如何避免這個問題?
  • 因為手機端記憶體有限,所以無法訓練大的模型,有沒有方法可以繞過這個限制得到大模型?

回答後續公佈,歡迎關注公眾號【雨石記】.

在這裡插入圖片描述

參考論文

  • [1]. McMahan, Brendan, et al. “Communication-efficient learning of deep networks from decentralized data.” Artificial Intelligence and Statistics. 2017.