1. 程式人生 > >用spark訓練深度神經網路

用spark訓練深度神經網路

SparkNet: Training Deep Network in Spark

這篇論文是 Berkeley 大學 Michael I. Jordan 組的 ICLR2016(under review) 的最新論文,有興趣可以看看原文和原始碼:papergithub .

訓練深度神經網路是一個非常耗時的過程,比如用卷積神經網路去訓練一個目標識別任務需要好幾天來訓練。因此,充分利用叢集的資源,加快訓練速度成了一個非常重要的領域。不過,當前非常熱門的批處理計算架構(例如:MapReduce 和 Spark)都不是設計用來專門支援非同步計算和現有的一些通訊密集型的深度學習系統。

SparkNet 是基於Spark的深度神經網路架構,

  1. 它提供了便捷的介面能夠去訪問Spark RDDs;
  2. 同時提供Scala介面去呼叫caffe;
  3. 還擁有一個輕量級的tensor 庫;
  4. 使用了一個簡單的並行機制來實現SGD的並行化,使得SparkNet能夠很好的適應叢集的大小並且能夠容忍極高的通訊延時;
  5. 它易於部署,並且不需要對引數進行調整;
  6. 它還能很好的相容現有的caffe模型;

下面這張圖是SparkNet的架構:

此處輸入圖片的描述

從上圖可以看出,Master 向每個worker 分發任務之後,各個worker都單獨的使用Caffe(利用GPU)來進行訓練。每個worker完成任務之後,把引數傳回Master。論文用了5個節點的EC2叢集,broadcast 和 collect 引數(每個worker幾百M),耗時20秒,而一個minibatch的計算時間是2秒。

Implementation

SparkNet 是建立在Apache Spark和Caffe深度學習庫的基礎之上的。SparkNet 用Java來訪問Caffe的資料,用Scala來訪問Caffe的引數,用ScalaBuff來使得Caffe網路在執行時保持動態結構。SparkNet能夠相容Caffe的一些模型定義檔案,並且支援Caffe模型引數的載入。

下面簡單貼一下SparkNet的api和模型定義、模型訓練程式碼。
此處輸入圖片的描述
此處輸入圖片的描述

並行化的SGD

為了讓模型能夠在頻寬受限的環境下也能執行得很好,論文提出了一種SGD的並行化機制使得最大幅度減小通訊,這也是全文最大了亮點。這個方法也不是隻針對SGD,實際上對Caffe的各種優化求解方法都有效。

在將SparkNet的並行化機制之前,先介紹一種Naive的並行機制。

Naive SGD Parallelization

Spark擁有一個master節點和一些worker節點。資料分散在各個worker中的。
在每一次的迭代中,Spark master節點都會通過broadcast(廣播)的方式,把模型引數傳到各個worker節點中。
各個worker節點在自己分到的部分資料,在同一個模型上跑一個minibatch 的SGD。
完成之後,各個worker把訓練的模型引數再發送回master,master將這些引數進行一個平均操作,作為新的(下一次迭代)的模型引數。

這是很多人都會採用的方法,看上去很對,不過它有一些缺陷。

Naive 並行化的缺陷

這個缺陷就是需要消耗太多的通訊頻寬,因為每一次minibatch訓練都要broadcast 和 collect 一次,而這個過程特別消耗時間(20秒左右)。

Na(b) 表示,在batch-size為 b 的情況下,到達準確率 a 所需要的迭代次數。
C(b) 表示,在batch-size 為 b 的情況下,SGD訓練一個batch的訓練時間(約2秒)。
顯然,使用SGD達到準確率為a所需要的時間消耗是:

Na(b)C(b)

假設有K個機器,通訊(broadcast 和 collect)的時間為 S,那麼Naive 並行 SGD
的時間消耗就是:

Na(b)(C(b)/K+S)

SparkNet 的並行化機制

基本上過程和Naive 並行化差不多。唯一的區別在於,各個worker節點每次不再只跑一個迭代,而是在自己分到的一個minibatch資料集上,迭代多次,迭代次數是一個固定值τ

SparkNet的並行機制是分好幾個rounds來跑的。在每一個round中,每個機器都在batch size為b 的資料集上跑 τ 次迭代。沒一個round結束,再把引數彙總到master進行平均等處理。

我們用Ma(b,K,τ) 表示達到準確率 a 所需要的 round 次數。
因此,SparkNet需要的時間消耗就是:

Ma(b,K,τ)(τC(b)+S)

下面這張圖,很直觀的對比了Naive 並行機制跟 SparkNet 並行機制的區別:
Naive 並行機制:

此處輸入圖片的描述

SparkNet 並行機制:

此處輸入圖片的描述

論文還做了各種對比實驗,包括時間,準確率等。實驗模型採用AlexNet,資料集是ImageNet的子集(100類,每類1000張)。

假設S=0,那麼τMa(b,K,τ)/Na(b) 就是SparkNet的加速倍數。論文通過改變τK 得出了下面的表格(使準確率達到20%的耗時情況):

此處輸入圖片的描述

上面的表格還是體現了一些趨勢的:
(1). 看第一行,當K=1,因為只有一個worker節點,所以非同步計算的τ這時並沒有起到什麼作用,可以看到第一行基本的值基本都是接近1.
(2). 看最右邊這列,當τ=1,這其實就相當於是Naive 並行機制,只不過,Naive的batch是b/K,這裡是b. 這一列基本上是跟K成正比。
(3). 注意到每一行的值並不是從左到右一直遞增的。

S!=0 的時候,naive 跟 SparkNet 的耗時情況又是怎麼樣的呢?作者又做了一些實驗。

此處輸入圖片的描述

可以看到,當S接近與0的時候(頻寬高),Naive會比SparkNet速度更快,但是,當S 變大(頻寬受限),SparkNet的效能將超過Naive,並且可以看出,Naive受S變化劇烈, 而SparkNet相對平穩。

而作者實驗用EC2環境,S大概是20秒,所以,顯然,SparkNet會比Naive好很多。

論文還做了一些事情,比如:
τ=50,分別測試K=1、3、5、10時,準確率與時間的關係;
K=5,分別測試τ=20、50、100、150時,準確率與時間的關係。

此處輸入圖片的描述

總結一下,這篇論文其實沒有太多複雜的創新(除了SGD並行化時候的一點小改進),不過我很期待後續的工作,同時也希望這個SparkNet能夠維護的越來越好。有時間的話,還是很想試試這個SparkNet的