1. 程式人生 > >TensorFlow(二)之分散式

TensorFlow(二)之分散式

本博文參考TensorFlow技術解析與實戰(李嘉璇),僅用於學習。

一、原理

分散式TensorFlow是由高效能的gRPC庫作為底層技術來支援的。gRPC是Google開源的RPC框架(遠端過程呼叫協議),相當於提供一個介面,使用者將引數從本地傳遞到遠端伺服器,在伺服器上實現計算,客戶端最後獲得傳回的結果。

TensorFlow部署分為單機多卡(左)和多機多卡(分散式,右)。單機多卡就是一臺伺服器上有多個GPU,多機多卡(分散式)是指訓練在多個工作節點(worker)上。

            

TensorFlow叢集(cluster)可以劃分為一個到多個工作(job),每個工作可以劃分為一個到多個任務(task)。

二、架構

1、服務端:運行了tf.train.Server例項的程序,是叢集(cluster)的一部分,分為主節點服務和工作點服務

2、客戶端:包含tf.Session()。一個客戶端可以與多個服務端相連,一個服務端可以與多個客戶端相連

3、主節點:實現了tensorflow::Session介面。通過RPC服務程式來遠端連線工作節點,一般是task_index為0的作業(job)。

4、工作點:實現worker_service.proto介面,是對部分圖的計算。

三、模式

訓練一個模型的過程中,主要有資料並行和模型並行。

1、資料並行

CPU負責梯度平均和引數平均,GPU負責訓練模型副本。

(1)、非同步更新:每個工作節點上的任務獨立計算區域性梯度,非同步更新到模型引數中,不要要協調和等待。如非同步隨機梯度下降法(Sync-SGD)。資料量大,各節點計算能力參差不齊的情況下,推薦使用非同步更新。


(2)、同步更新:CPU等待所有GPU完成計算,統一進行引數更新,再講引數送到GPU中。如同步隨機梯度下降法(Sync-SGD)。資料量小,節點計算能力比較均衡的情況下,推薦使用同步更新。


2、模型並行:對模型進行切分,讓模型的不同部分執行在不同裝置上。如LSTM模型。


四、分散式API

建立叢集:建立一個tf.train.ClusterSpec,用於對叢集中所有任務進行描述,該描述內容對所有任務都一樣。

建立服務:建立一個tf.train.Server

常用API:

(1)、  tf.train.ClusterSpec({"ps":ps_hosts,"worker":worker_host})

eg:    tf.train.ClusterSpec({"local":["localhost:2222","localhost:2223"]})

 tf.train.ClusterSpec({"worker":["worker0.example.com:2222","worker1.example.com:2223"],

"ps":["ps0.example.com:2222","ps1.example.com:2222"]})

(2)、  tf.train.Server(cluster,job_name,task_index)

eg:     #在任務0:

 cluster=tf.train.ClusterSpec({"local":["localhost:2222","localhost:2223"]})

  server=tf.train.Server(cluster,job_name="local",task_index=0)

(3)、 tf.device(device_name_or_function),指定程式碼執行在CPU或者GPU上

eg:    with tf.device("/job:ps/task:0"):

五、分散式訓練程式碼框架

分散式程式碼具有固定的結構: 1、命令列引數解析,獲取叢集的訊息ps_hosts和worker_hosts,以及當前節點的角色資訊job_name和task_index
tf.app.flags.DEFINE_string("ps_hosts","","Comma-separated of hostname:port pairs")
tf.app.flags.DEFINE_string("worker_hosts","","Comma-separated of hostname:port pairs")
tf.app.flags.DEFINE_string("job_name","","one of 'ps','worker'")
tf.app.flags.DEFINE_integer("task_index",0,'Index of task within the job')
FLAGS=tf.app.flags.FLAGS
ps_hosts=FLAGS.ps_hosts.split(",")
work_hosts=FLAGS.worker_hosts(",")
2、建立當前節點的伺服器
cluster=tf.train.ClusterSpec({"ps":ps_hosts,"worker":worker_hosts})
server=tf.train.Server(cluster,job_name=FLAGS.job_name,task_index=FLAGS.task_index)
3、如果當前節點是引數伺服器,則呼叫server.join()無休止等待;如果是工作節點,則執行第四步。
if FLAGS.job_name=="ps":
    server.join()
4、構建要訓練的模型,構建計算圖
elif FLAGS.job_name=="worker":
    #build tensorflow graph model
5、建立tf.train.Supervisor來管理模型的訓練過程
sv=tf.train.Supervisor(is_chief=(FLAGS.task_index==0),logdir="/tmp/train_logs")
sess=sv.prepare_or_wait_for_session(server,target)
while not sv.should_stop()