1. 程式人生 > 其它 >[原始碼解析] 深度學習流水線並行 PipeDream(6)--- 1F1B策略

[原始碼解析] 深度學習流水線並行 PipeDream(6)--- 1F1B策略

在前文中,我們介紹了PipeDream的總體架構,Profile階段,計算分割槽階段,模型轉換階段,執行時引擎和通訊模組,本文是 PipeDream 系列最後一篇,介紹 1F1B 策略,這是 PipeDream 最大的貢獻。

[原始碼解析] 深度學習流水線並行 PipeDream(6)--- 1F1B策略

目錄

0x00 摘要

在前文中,我們介紹了PipeDream的總體架構,Profile階段,計算分割槽階段,模型轉換階段,執行時引擎和通訊模組,本文是 PipeDream 系列最後一篇,介紹 1F1B 策略,這是 PipeDream 最大的貢獻。

流水線並行其他文章連結如下:

[原始碼解析] 深度學習流水線並行Gpipe(1)---流水線基本實現

[原始碼解析] 深度學習流水線並行GPipe (2) ----- 梯度累積

[原始碼解析] 深度學習流水線並行 GPipe(3) ----重計算

[原始碼解析] 深度學習流水線並行之PipeDream(1)--- Profile階段

[原始碼解析] 深度學習流水線並行 PipeDream(2)--- 計算分割槽

[原始碼解析] 深度學習流水線並行 PipeDream(3)--- 轉換模型

[原始碼解析] 深度學習流水線並行 PipeDream(4)--- 執行時引擎

[原始碼解析] 深度學習流水線並行 PipeDream(5)--- 通訊模組

0x01 流水線比較

首先,我們比較一下目前分析過的各個流水線。

1.1 普通流水線

DNN模型組成的基本單位是層。PipeDream將DNN的這些層劃分為多個階段——每個階段(stage)由模型中的一組連續層組成。PipeDream把模型的不同的階段部署在不同的機器上,每個階段可能有不同的replication。該階段對本階段中所有層執行向前和向後傳遞。PipeDream將包含輸入層的階段稱為輸入階段,將包含輸出層的階段稱為輸出階段。

在最簡單的情況下,和傳統的模型並行訓練中一樣,系統中只有一個minibatch是活動的。上圖就顯示了一個計算時間線,該示例有四臺機器和一個管道,可以認為是一個最普通的流水線

  • 在正向階段,每個階段對本階段中的層的minibatch執行正向傳遞,並將結果傳送到下一階段。輸出級在完成前向傳遞後,計算minibatch的損失。
  • 在後向階段,每個階段形成後向通道,逐一將損失傳播到前一階段。

1.2 Gpipe流水線

因為PipeDream是基於Gpipe進行改進,所以我們也要基於 Gpipe 看看其問題所在。

Gpipe 的流水線並行訓練圖如下:

  • 將被訓練的這些層劃分為多個階段,每個階段包含模型之中一組連續的層。
  • 把輸入資料minibatch進行分片,分成 m 個microbatches,像 allreduce 一樣,計算完一些就傳給下個節點,最後同步更新引數。
  • GPipe使用現有的技術,如梯度累積來優化記憶體效率,通過丟棄前向傳播和後向傳播之間的activation儲存來交換記憶體,在後向傳遞需要activation時再重新計算它們。

Gpipe的流水線有幾個問題:

  • 過多流水線重新整理導致空閒時間的增加。
  • 如果m很小,Gpipe可能會由於重新計算開銷和頻繁的管道重新整理而降低硬體效率,所以 m 一般都設定的較大。
  • 於是需要快取 m 份 activation導致記憶體增加。原因是每個microbatch前向計算的中間結果activation都要被其後向計算所使用,所以需要在記憶體中快取。

1.3 1F1B流水線

PipeDream 的 1F1B(One Forward pass followed by One Backward pass)策略就可以解決快取 activation 的份數問題,使得 activation 的快取數量只跟階段(stage)數相關,從而進一步節省視訊記憶體。

Pipeline的並行方式是把模型的不同層放到不同機器(節點)上,順序地進行前向計算和反向計算。

PipeDream的目標是:以最小化總體訓練時間的方式將流水線並行,模型並行性和資料並行性結合起來。然而,要使這種方法對大型DNN模型有效,獲得流水線並行化訓練的潛在收益,PipeDream 必須克服三個主要挑戰:

  1. 如何跨可用計算資源自動劃分工作(模型的層)。
  2. 在確保訓練任務向前推進的同時,如何排程計算以最大化吞吐量。
  3. 面對流水線帶來的非同步性,如何確保訓練有效。

其中 1F1B 就對應了後面兩個挑戰。

1.3.1 思路

我們剖析下1F1B策略的思路。

終極目的是:減少activation 的快取數量,降低視訊記憶體佔用,從而可以訓練更大的模型。

目前困境是:即便使用了Checkpointing 技術,前向計算的 activation 也需要等到對應的後向計算完成之後才能釋放。

解決思路是:努力減少每個 activation 的儲存時間,即這就需要每個 micro-batch 資料儘可能早的完成後向計算讓,從而讓每個 activation 儘可能早釋放。

注意:PipeDream中,最後使用的是minibatch這個單詞,所以我們可以認為PipeDream的minibatch就是 Gpipe的 micro-batch,從這裡開始,都使用 minibatch。

解決方案是:

  • 讓最後一個 stage(下圖中的 Machine 4) 在做完一次 minibatch 的前向傳播之後,就立即做本minibatch 的後向傳播,那麼就可以讓其他 stage 儘可能早的開始後向傳播計算,這就是 1F1B 策略。有點類似於把整體同步變成了眾多小資料塊上的非同步,而且眾多小資料塊都是大家獨立更新。
  • 在 1F1B 的穩定狀態下,會在每臺機器上嚴格交替的進行前向計算/後向計算,這樣使得每個GPU上都會有一個minibatch資料正在處理,從而保證資源的高利用率(整個pipeline比較均衡,可忽略的流水線暫停,沒有流水線 flush,能確保以固定週期執行每個階段上的引數更新)
  • 面對流水線帶來的非同步性,1F1B 使用不同版本的權重來確保訓練的有效性。
  • PipeDream 又擴充套件了 1F1B,對於使用資料並行的stage,採用 round-robin的排程模式將任務分配在同一個stage的各個裝置上,保證了一個batch的資料的前向傳播計算和後向傳播計算髮生在同一臺機器上,這就是 1F1B-RR(one-forward-noe-backward-round-robin)。

實際上,1F1B策略就是把一個batch的同步變為了眾多小資料(minibatch)上的非同步,計算完一個minibatch就立刻反向,一個minibatch的反向結束之後就更新對應worker的梯度。所有worker一起跑起來。可以理解為從 BSP 執行變成了 ASP 執行。

1.3.2 圖示

下圖是實施了 1F1B 的流水線。

  • 把一個 batch 分成多個mini batches,比如把一個 batch 分成 1,2,3,4 這4個mini batches。
  • 把多個 mini batches 逐一插入到流水線。
  • Machine 1 先計算 藍色 1 的前向傳播,然後把藍色 1 傳送給 Machine 2 繼續計算。
  • Machine 2 接著計算 藍色 2 的前向傳播,然後把藍色 1 發給 Machine 2 繼續計算。
  • 當藍色 1 由上至下遍歷了 Machine 1 ~ 4,則完成了全部前向傳播,於是開始進行反向傳播,對應了第一個綠色 1,然後逆向傳遞到 Machine 3 ~ 1。
  • 當資料 1 完成了全部反向傳播,即綠色 1 來到了 Machine 1。
  • 每個機器在完成自己 mini batch 的反向傳播之後,會在本地進行梯度更新。
  • Machine 和 Machine 之間只傳送模型的一個子集,這樣計算和通訊可以並行。

需要注意,下圖給出了初始階段和穩定階段,我們後續講解中會提到這兩個階段。

0x02 PipeDream 實現

首先給出一個包含4個GPU的示例圖,圖內也給出了其中一個GPU(Mach. 3)的時間流示例。這裡計算和梯度/啟用通訊是有部分重疊的。

2.1 總體邏輯

我們以一次訓練為例,結合下圖來說明。

需要介紹一個名詞 NOAM,活動小批次數目。

NUM_OPT_ACTIVE_MINIBATCHES (NOAM) = ⌈ (# machines) / (# machines in the input stage) ⌉

其意義是:基於我們的演算法生成的分割槽,為了在穩定狀態下保持流水線滿負荷,每個輸入級副本所允許的最小批處理數

上圖顯示了管道的相應計算時間線,每個流水線有4個階段在不同機器上執行,所以此配置的NOAM為 4。

我們具體再分析下執行步驟。

  • 在訓練開始的啟動階段(圖上的Startup State),輸入的stage的先讀入足夠多minibatch的資料(就是NOAM個),以保證pipeline在穩定階段時,各個裝置上都有相應的工作在處理。對於上圖,就是輸入階段傳送四個小批次傳播到輸出階段。
  • 一旦輸出階段完成第一個小批次的前向傳播(就是Machine 4 第一個藍色1),它就對同一個小批次執行後向傳播(就是Machine 4 的第一個綠色 1)。
  • 然後開始交替執行後續小批次的前向傳播和後向傳播(就是 Machine 4 的 2前,2後,3前,3後.....)。
  • 當反向傳播過程開始傳播到管道中的早期階段時(就是Work 3 ~ Work 1),每個階段開始在不同小批次的正向和反向過程之間交替進行。
  • 在穩定狀態下,每臺機器都忙著對一個小批次進行正向傳播或反向傳播。

2.2 權重問題

Pipeline的訓練模式會引入兩種引數不一致性,因為實際是ASP計算,沒有協調會越幹越亂:

  • 在一個原生的PipeDream流水線中,每個階段的前向傳播都是使用某一個版本的引數來執行,而其後向傳播則是使用不同版本的引數來執行的,即同一個minibatch的前向傳播和後向傳播使用的引數不一致。例如上圖所示:
    • 當 minibatch 5 進入到 worker 1 時,它的前向傳播邏輯在 minibatch 1 的後向傳播計算之後執行,即它前向傳播計算時候使用的引數是 minibatch 1 後向傳播計算之後更新的引數。
    • 但是 minibatch 5 後向傳播邏輯是在 "minibatch 2, minibatch 3, minibatch 4" 執行完後才開始計算,即此時使用的引數是"minibatch 1, minibatch 2, minibatch 3, minibatch 4" 後向傳播計算之後更新的引數。
    • 這就導致 minibatch 5 的前向計算和後向計算時候,使用的引數不一致。即,第一行 Machine 1,藍色 5 號 和 綠色 5 號 計算時候,必須都使用 綠色 1 號之後更新的引數。
  • 同一個minibatch在不同stage做同樣操作(同樣做前向操作,或者同樣做後向傳播)使用的引數版本不一致。同樣如上圖所示:
    • 對於 minibatch 5 在 worker 1 上的前向計算部分(藍色5),他的前向邏輯在 minibatch 1 的後向計算以後執行。
    • 但是 minibatch 5 在 worker 2 上的前向計算部分(藍色5),是在 "minibatch 1, minibatch 2" 的後向計算結束後才執行。
    • 這就導致了 minibatch 5 在兩個stage上前向計算使用的引數版本不一致。

為解決這兩個問題,PipeDream 分別採用了 weight stashing 和 Vertical Sync 兩種技術

  • Weight stashing : 為權重維護多個版本,每個active minibatch都有一個版本。每個stage 都用最新版本的權重進行前向計算,處理輸入的minibatch。計算前向傳播之後,會將這份引數儲存下來用於同一個minibatch的後向計算。Weight stashing確保在一個階段內,相同版本的模型引數被用於給定小批量的向前和向後傳播,但是不能保證跨階段間,一個給定的小批次使用模型引數的一致性
  • Vertical Sync : 每個minibatch進入pipeline時都使用輸入stage最新版本的引數,並且引數的版本號會伴隨該minibatch資料整個生命週期,在各個階段都是用同一個版本的引數(而不是Weight stashing那樣都使用最新版本的引數),從而實現了stage間的引數一致性

2.3 Weight Stashing

我們以下圖為例:

Worker 1, work 2 ... 各自有自己的權重, 記為 \(W_1\)\(W_2\) .... 即,圖上的 \(W_i^{(j)}\),下標 i 表示 第 i 個 worker,上標 ( j ) 表示第 j 個minibatch。

在一個階段(每一個 worker)中:

  • 每次向後傳播都會導致權重更新,下一次向前傳使用最新版本的可用權重。就是說,每個 worker 的權重,在出現一個新的綠色後向傳播之後會被更新。接下來的新操作應該基於這個新權重。
  • 計算前向傳播之後,會將這份前向傳播使用的權重儲存下來用於同一個 minibatch 的後向計算。
  • Weight stashing確保在一個階段內,相同版本的模型引數被用於給定小批量的向前和向後傳播。

我們以上圖為例:

Worker 1 第一行的藍色 5 依賴於 它前面同一行的綠色 1。Worker 1 所在行的第一個綠色 1 結束時,代表了 minibatch 1 完成了本次流水線的 4 次前向傳播,4次後向傳播。所以是一個新版本的 weight 1,就是\(W_1^{(1)}\)。因此,Work 1 的兩個 minibatch 5(藍色前向和綠色後向)都應該基於新版本 \(W_1^{(1)}\) 計算。因此需要記錄下來 新版本 \(W_1^{(1)}\)

Worker 2 第二行的藍色 5 依賴於它前面同一行的綠色 2。同理,Worker 1 的第一個綠色 2 結束時,代表了 minibatch 2 完成了本次流水線的 4 次前向傳播,4次後向傳播。所以是一個新版本的 weight 2。此時的 minibatch 6 的前向和圖上未標出的綠色後向都應該基於 新版本的 weight 2 計算,因此需要記錄下來 新版本 \(W_2^{(2)}\)

對於 worker 3,從它的角度看,它本身的權重應該執行兩次前向,兩次後向(worker 4一次,然後 worker 3 第二次)。當執行 minibatch 5 的前向傳播時候,\(W_3^{(3)}\)已經更新(被minibatch 3 的綠色更新),所以需要記錄下來 \(W_3^{(3)}\),為了以後 minibatch 5 的後向更新使用。

依次類推,worker 1 需要記錄 \(W_1^{(1)}\), \(W_1^{(2)}\)\(W_1^{(3)}\)\(W_1^{(4)}\),... 的每一個新版本。就是 worker 1 對應 minibatch 1,2,3,4 的各個權重。

2.4 Vertical Sync

目前問題是:worker 1 上 minibath 5 的前向計算用的是 1 後向傳播之後的引數,但worker 2 上計算 minibath 5 是用 2 後向傳播之後的引數,最後彙總的時候不就又亂了?

Vertical Sync機制是:每個進入管道的 minibatch(\(b_i\))都與其進入流水線輸入階段時候的最新權重版本\(w^{(i-x)}\)相聯絡。當小批次在流水線前向傳播階段前進時候,這個版本資訊隨著啟用值和梯度一起流動。在所有階段中,\(b_i\) 的前向傳播使用儲存的\(w^{(i-x)}\)來計算,而不是Weight stashing那樣都使用最新版本的引數。在使用儲存的 \(w^{(i-x)}\)來計算後向傳播之後,每個階段獨立應用權重更新,建立最新權重\(w^{(i)}\),然後刪除\(w^{(i-x)}\)

用下面圖來說明:

強制所有worker在計算 minibatch 5 的時候都用本worker做 minibatch 1 反向傳播之後的引數,具體來說就是:

對於 worker 2,使用本階段綠色1(1反向傳播之後,更新的本階段權重)來做 5 的前向傳播。

同理,對於 worker 3,使用本階段綠色1(1反向傳播之後,更新的本階段權重)來做 5 的前向傳播。對於 worker 4,使用本階段綠色1(1反向傳播之後,更新的本階段權重)來做 5 的前向傳播。

但是,這樣同步會導致很多計算浪費無用。比如5更新時用的1的權重,但2/3/4後向傳播的權重都白白計算了,所以預設不使用Vertical Sync。這樣雖然每層不完全一致,但是由於weight stashing的存在,所有的引數都是有效的。

2.5 緩衝區

這裡對緩衝區的處理再做一下說明。

引數狀態。對於每個階段,PipeDream主要維護與GPU記憶體中直接分配給該階段的層相關的所有引數。每個層的引數分別儲存,每個層分配一個唯一的ID。如果沒有複製該階段,PipeDream將更新應用到儲存在GPU記憶體中的引數資料的最新版本,當所提供的GPU緩衝區中的權重更新可用時。如果複製了stage,則將權重更新複製到主機記憶體,然後傳送到引數伺服器。當新版本的引數可用時,作為權重儲存方案的一部分,不會立即丟棄以前的版本。引數資料只有在使用較新引數的向後傳遞被格式化後才會被丟棄。

中間狀態。每個層的中間資料也被分配了一個唯一的blob ID。當從前一級(或在輸入級的情況下從磁碟)接收中間資料時,PipeDream將中間資料複製到GPU記憶體,並在工作佇列中放置一個指向相關緩衝區的指標。在關聯的minibatch完成該階段的向後傳遞之前,forward傳遞的中間資料不會被丟棄。當ML工作人員完成使用後,以及如果需要,在將其傳送到下一階段之後,來自向後傳遞的中間資料就被釋放。由於向前和向後傳遞中對中間資料的要求不同,PipeDream中的stage通常管理來自向前傳遞的多個版本的中間資料,而只管理來自當前執行的向後傳遞的單個版本的中間資料。

0x03 程式碼

3.1 總體程式碼

我們用 runtime/translation/main_with_runtime.py 來分析。

下面省略部分次要程式碼。

使用 runtime 的總體邏輯可以如下檔案為例 :runtime/translation/main_with_runtime.py。主要邏輯是:

  • 解析輸入引數。
  • 載入,生成模型。
  • 依據模組來構建模型。
  • 依據引數進行配置比如輸入大小,batch size等。
  • 遍歷模型的每個層(跳過最後loss層)。
    • 遍歷每層的輸入,構建輸入張量。
    • 通過呼叫stage對應的forward函式,構建出輸出。
    • 遍歷每層的輸出,設定其型別和形狀 。
  • 構建輸出值張量型別。
  • 載入配置檔案。
  • 構建一個 StageRuntime。
  • 建立 optimizer,這裡 optimizer,使用了AdamWithWeightStashing 或者 SGDWithWeightStashing,所以就是使用了 weight stashing。
  • 載入 dataset。
  • 進行訓練,儲存checkpoint。

總體程式碼如下:

def main():
    # 解析輸入引數
    global args, best_prec1
    args = parser.parse_args()

    # Special case handling for GNMT model
    l2_promote()

    torch.cuda.set_device(args.local_rank)

    # build tokenizer
    tokenizer = Tokenizer(os.path.join(args.data_dir, config.VOCAB_FNAME))

    # define loss function
    criterion = build_gnmt_criterion(
        vocab_size=tokenizer.vocab_size, padding_idx=config.PAD, smoothing=0.1)

    # create stages of the model
    # 載入,生成模型
    module = importlib.import_module(args.module)
    args.arch = module.arch()
    # 依據模組來構建模型
    model = module.model(criterion)

    # 依據引數進行配置比如輸入大小,batch size等
    input_size = [args.max_length_train, args.batch_size]
    training_tensor_shapes = {"input0": input_size, "input1": [args.batch_size],
                              "input2": input_size, "target": [args.max_length_train * args.batch_size],
                              "target_length": [args.batch_size]}
    dtypes = {"input0": torch.int64, "input1": torch.int64, "input2": torch.int64,
              "target": torch.int64, "target_length": torch.int32}
    inputs_module_destinations = {"input0": 0, "input1": 0, "input2": 0}
    target_tensor_names = {"target", "target_length"}
    
    # 遍歷模型的每個層(跳過最後loss層)
    for module_id, (stage, inputs, outputs) in enumerate(model[:-1]):  # Skip last layer (loss).
        input_tensors = []
        # 遍歷每層的輸入,構建輸入張量
        for module_input in inputs:
            if module_input in inputs_module_destinations:
                inputs_module_destinations[module_input] = module_id

            input_tensor = torch.ones(tuple(training_tensor_shapes[module_input]),
                                      dtype=dtypes[module_input])#.cuda()
            input_tensors.append(input_tensor)
        #stage.cuda()
        # PyTorch should not maintain metadata for a backward pass on
        # synthetic inputs. Without the following line, the runtime is
        # as much as 1.5x slower in a full DP configuration.
        with torch.no_grad():
            # 通過呼叫stage對應的forward函式,構建出輸出
            output_tensors = stage(*tuple(input_tensors))
        if not type(output_tensors) is tuple:
            output_tensors = [output_tensors]
        # 遍歷每層的輸出,設定其型別和形狀    
        for output, output_tensor in zip(outputs,
                                         list(output_tensors)):
            # output 是 ['out2', 'out1']
            training_tensor_shapes[output] = list(output_tensor.size())
            dtypes[output] = output_tensor.dtype

    # 構建輸出值張量型別           
    eval_tensor_shapes = {}
    for key in training_tensor_shapes:
        eval_tensor_shapes[key] = tuple(
            training_tensor_shapes[key])
        training_tensor_shapes[key] = tuple(
            training_tensor_shapes[key])

    # 載入配置檔案
    configuration_maps = {
        'module_to_stage_map': None,
        'stage_to_rank_map': None,
        'stage_to_depth_map': None
    }
    if args.config_path is not None:
        json_config_file = json.load(open(args.config_path, 'r'))
        configuration_maps['module_to_stage_map'] = json_config_file.get("module_to_stage_map", None)
        configuration_maps['stage_to_rank_map'] = json_config_file.get("stage_to_rank_map", None)
        configuration_maps['stage_to_rank_map'] = {
            int(k): v for (k, v) in configuration_maps['stage_to_rank_map'].items()}
        configuration_maps['stage_to_depth_map'] = json_config_file.get("stage_to_depth_map", None)

    # 構建一個 StageRuntime
    r = runtime.StageRuntime(
        model=model, distributed_backend=args.distributed_backend,
        fp16=args.fp16, loss_scale=args.loss_scale,
        training_tensor_shapes=training_tensor_shapes,
        eval_tensor_shapes=eval_tensor_shapes,
        training_tensor_dtypes=dtypes,
        inputs_module_destinations=inputs_module_destinations,
        target_tensor_names=target_tensor_names,
        configuration_maps=configuration_maps,
        master_addr=args.master_addr,
        rank=args.rank, local_rank=args.local_rank,
        num_ranks_in_server=args.num_ranks_in_server,
        verbose_freq=args.verbose_frequency,
        model_type=runtime.TRANSLATION,
        enable_recompute=args.recompute)

    # stage needed to determine if current stage is the first stage
    # num_stages needed to determine if current stage is the last stage
    # num_ranks needed to determine number of warmup_minibatches in case of pipelining
    args.stage = r.stage
    args.num_stages = r.num_stages
    args.num_ranks = r.num_ranks
    if not is_first_stage():
        args.synthetic_data = True

    # define optimizer
    if args.no_input_pipelining:
        num_versions = 1
    else:
        # number of versions is the total number of machines following the current
        # stage, shared amongst all replicas in this stage
        num_versions = r.num_warmup_minibatches + 1

    # if specified, resume from checkpoint
    if args.resume:
        checkpoint_file_path = "%s.%d.pth.tar" % (args.resume, r.stage)
        assert os.path.isfile(checkpoint_file_path)
        print("=> loading checkpoint '{}'".format(checkpoint_file_path))
        checkpoint = torch.load(checkpoint_file_path)
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        r.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})"
                .format(checkpoint_file_path, checkpoint['epoch']))

    # TODO: make this configurable by args
    # 建立 optimizer,使用了AdamWithWeightStashing 或者 SGDWithWeightStashing
    use_adam_optimizer = True
    if use_adam_optimizer:
        optimizer = adam.AdamWithWeightStashing(
            modules=r.modules(), master_parameters=r.master_parameters,
            model_parameters=r.model_parameters, loss_scale=args.loss_scale,
            num_versions=num_versions, lr=args.lr, betas=(0.9,0.999),
            weight_decay=args.weight_decay, verbose_freq=args.verbose_frequency,
            macrobatch=args.macrobatch)
    else:
        optimizer = sgd.SGDWithWeightStashing(
            modules=r.modules(), master_parameters=r.master_parameters,
            model_parameters=r.model_parameters, loss_scale=args.loss_scale,
            num_versions=num_versions, lr=args.lr, momentum=args.momentum,
            weight_decay=args.weight_decay, verbose_freq=args.verbose_frequency)

    if args.resume:
        optimizer.load_state_dict(checkpoint['optimizer'])

    cudnn.benchmark = True

    # 載入 dataset
    train_dataset = LazyParallelDataset(
        src_fname=os.path.join(args.data_dir, config.SRC_TRAIN_FNAME),
        tgt_fname=os.path.join(args.data_dir, config.TGT_TRAIN_FNAME),
        tokenizer=tokenizer,
        min_len=args.min_length_train,
        max_len=args.max_length_train,
        sort=False,
        max_size=None)

    val_dataset = ParallelDataset(
        src_fname=os.path.join(args.data_dir, config.SRC_VAL_FNAME),
        tgt_fname=os.path.join(args.data_dir, config.TGT_VAL_FNAME),
        tokenizer=tokenizer,
        min_len=args.min_length_train,
        max_len=args.max_length_train,
        sort=True)

    distributed_sampler = False
    if configuration_maps['stage_to_rank_map'] is not None:
        num_ranks_in_first_stage = len(configuration_maps['stage_to_rank_map'][0])
        if num_ranks_in_first_stage > 1:
            distributed_sampler = True

    # TODO: fix random seeds
    train_loader = train_dataset.get_loader(
        batch_size=args.batch_size, seeds=range(args.epochs),
        batch_first=False, shuffle=True,
        bucketing=not args.no_bucketing, num_workers=args.workers,
        world_size=r.num_ranks_in_first_stage,
        rank=r.rank_in_stage if r.stage == 0 else 0
    )

    val_loader = val_dataset.get_loader(
        batch_size=args.batch_size, batch_first=False,
        shuffle=True, num_workers=args.workers,
        world_size=r.num_ranks_in_first_stage,
        seeds=range(args.epochs),
        rank=r.rank_in_stage if r.stage == 0 else 0
    )

    # if checkpoint is loaded, start by running validation
    if args.resume:
        assert args.start_epoch > 0
        validate(val_loader, r, args.start_epoch-1)

    # 進行訓練,儲存checkpoint
    for epoch in range(args.start_epoch, args.epochs):
        if distributed_sampler:
            train_loader.sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args.epochs, r, args.lr_policy)

        # train or run forward pass only for one epoch
        if args.forward_only:
            validate(val_loader, r, epoch)
        else:
            train(train_loader, r, optimizer, epoch)

            # evaluate on validation set
            prec1 = validate(val_loader, r, epoch)
            if r.stage != r.num_stages: prec1 = 0

            # remember best prec@1 and save checkpoint
            best_prec1 = max(prec1, best_prec1)

            should_save_checkpoint = args.checkpoint_dir_not_nfs or r.rank_in_stage == 0
            if args.checkpoint_dir and should_save_checkpoint:
                save_checkpoint({
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': r.state_dict(),
                    'best_prec1': best_prec1,
                    'optimizer' : optimizer.state_dict(),
                    'tokenizer': tokenizer.get_state()
                }, args.checkpoint_dir, r.stage, epoch)

3.2 訓練函式

我們下面看看訓練函式 train 程式碼

  • 首先進入啟動熱身階段,需要一直執行到 輸出完成第一個小批次的前向傳播,對應上圖的 Startup State。
  • 然後開始交替執行後續小批次的前向傳播和後向傳播,從此時開始,進入到上圖的 Steady State,在每個階段之中,對於每一個小批次:
    • 實施前向傳播,目的是把minibatch推送到下游worker。這就是 1F
    • 如果是最後階段,則更新損失。
    • 梯度清零。
    • 載入儲存的權重。
    • 後向傳播。這就是 1B
    • 恢復最新權重。目前在本step內,就完成了 1F1B。
    • 進行下一次step。
  • 最後是剩餘的後向傳播,對應著熱身階段的前向傳播。
def train(train_loader, r, optimizer, epoch):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    n = r.num_iterations(loader_size=len(train_loader))
    if args.num_minibatches is not None:
        n = min(n, args.num_minibatches)
    r.train(n)
    if not is_first_stage(): train_loader = None
    r.set_loader(train_loader)

    end = time.time()
    epoch_start_time = time.time()

    if args.no_input_pipelining:
        num_warmup_minibatches = 0
    else:
        num_warmup_minibatches = r.num_warmup_minibatches

    # start num_warmup_minibatches forward passes
    # 啟動熱身階段,需要一直執行到 輸出完成第一個小批次的前向傳播,對應上圖的Start State。
    for i in range(num_warmup_minibatches):
        r.run_forward() # 前向傳播,就是1F

    # 開始交替執行後續小批次的前向傳播和後向傳播,從此時開始,進入到上圖的 Steady State。
    for i in range(n - num_warmup_minibatches):
        # perform forward pass
        r.run_forward() #前向傳播,就是1F

        if is_last_stage(): # 最後階段
            # measure accuracy and record loss
            output, target, loss, num_tokens = r.output, r.target, r.loss.item(), r.num_tokens()
            losses.update(loss, num_tokens) # 更新損失

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            epoch_time = (end - epoch_start_time) / 3600.0
            full_epoch_time = (epoch_time / float(i+1)) * float(n)
        else:
            # print log,省略

        # perform backward pass
        if args.fp16:
            r.zero_grad() # 梯度清零
        else:
            optimizer.zero_grad() # 梯度清零
            
        optimizer.load_old_params() # 載入 stash weight

        r.run_backward() # 後向傳播,就是1B
        
        optimizer.load_new_params() # 恢復新的weight
        
        optimizer.step() # 下一次訓練,同時更新引數

    # finish remaining backward passes
    # 最後剩餘的後向傳播,對應著熱身階段的前向傳播
    for i in range(num_warmup_minibatches):
        optimizer.zero_grad()
        optimizer.load_old_params() # 載入 stash weight
        r.run_backward() # 後向傳播,就是1B
        optimizer.load_new_params() # 恢復新的weight
        optimizer.step() # 下一次訓練

    # wait for all helper threads to complete
    r.wait()

上面引數的 r 是 StageRuntime 型別,所以我們看看其中的run_forward和run_backward。

3.3 前向傳播

以下是 StageRuntime 類的 run_forward 方法 和 _run_forward 方法,這兩個方法完成了前向傳播。

   def run_forward(self, recompute_step=False):
        """Run forward pass.
        """
        # Receive tensors from previous worker.
        self.receive_tensors_forward() # 接受上一階段的張量
        tensors = self.tensors[-1]

        # Run forward pass.
        self._run_forward(tensors) # 進行本階段前向傳播計算

        # Send tensors forward.
        self.send_tensors_forward() # 傳送給下一階段
        self.forward_stats.reset_stats()
        self.forward_minibatch_id += 1

    def _run_forward(self, tensors):
        # Perform forward pass through model (self.modules_with_dependencies already
        # has modules in topological order).
        
        # 得到module和對應的輸入,輸出
        modules = self.modules_with_dependencies.modules()
        all_input_names = self.modules_with_dependencies.all_input_names()
        all_output_names = self.modules_with_dependencies.all_output_names()
        
        # 遍歷模組
        for i, (module, input_names, output_names) in \
                enumerate(zip(modules, all_input_names, all_output_names)):
            if i == (len(modules) - 1) and self.is_criterion: 
                # 如果是計算損失
                # If layer is criterion (loss).
                if self.model_type == SPEECH_TO_TEXT:
                    output = tensors["output"].transpose(0, 1).float()
                    output_sizes = tensors["output_sizes"].cpu()
                    target = tensors["target"].cpu()
                    target_sizes = tensors["target_length"].cpu()
                    input0_size = tensors["input0_size"].cpu()
                    module_outputs = [module(output, target, output_sizes, target_sizes) / input0_size[0]]
                else:
                    module_outputs = [module(tensors[input_name],
                                             tensors["target"])
                                      for input_name in input_names]
                    module_outputs = [sum(module_outputs)]
            else:
                # 中間層
                # If layer is non-criterion.
                module_outputs = module(*[tensors[input_name]
                                          for input_name in input_names])
                if not isinstance(module_outputs, tuple):
                    module_outputs = (module_outputs,)
                module_outputs = list(module_outputs)

            # 把計算結果放入tensors之中,這樣後續就知道如何傳送    
            for (output_name, module_output) in zip(output_names, module_outputs):
                tensors[output_name] = module_output

        self.output = tensors[input_names[0]]
        # 如果是最後階段,則做處理
        if self.is_criterion and self.model_type == TRANSLATION:
            loss_per_batch = tensors[output_names[0]] * tensors[self.criterion_input_name].size(1)
            loss_per_token = loss_per_batch / tensors["target_length"][0].item()
            self.loss = loss_per_token
        elif self.is_criterion:
            self.loss = tensors[output_names[0]]
        else:
            self.loss = 1

3.4 反向傳播

執行引擎的 run_backward 完成了後向計算。

    def run_backward(self):
        # Receive input gradients needed for backward pass.
        self.receive_tensors_backward() # 從反向計算圖上一層接受梯度
        
        # Backward pass through modules in reverse order.
        inputs = {}
        outputs = {}
        input_gradients = {}
        output_gradients = {}

        # Get input and output names spanning all modules in this stage.
        all_input_names_set = set()
        all_output_names_set = set()

        # 得到module和對應的輸入,輸出
        modules = self.modules_with_dependencies.modules()
        all_input_names = self.modules_with_dependencies.all_input_names()
        all_output_names = self.modules_with_dependencies.all_output_names()

        for (input_names, output_names) in zip(all_input_names, all_output_names):
            for input_name in input_names:
                all_input_names_set.add(input_name)
            for output_name in output_names:
                all_output_names_set.add(output_name)

        tensors = self.tensors.pop(0)
        # Set inputs, outputs, and output_gradients.
        # Only set outputs/output_gradients for tensors that are not inputs of
        # other modules in this stage.
        # Similarly, only set inputs for tensors that are not outputs of other
        # modules in this stage.
        for (module, input_names, output_names) in \
            zip(reversed(modules), reversed(all_input_names), reversed(all_output_names)):
            for output_name in output_names:
                if output_name not in all_input_names_set:
                    if output_name not in self.gradients:
                        output_gradients[output_name] = None
                    else: 
                        # 計算梯度記錄在這裡
                        output_gradients[output_name] = self.gradients[output_name]
                    if tensors[output_name].requires_grad:
                        outputs[output_name] = tensors[output_name]
            for input_name in input_names:
                if input_name not in all_output_names_set:
                    inputs[input_name] = tensors[input_name]

        # Hook to record input gradients.
        def hook_wrapper(input_name):
            def hook(input_gradient):
                input_gradients[input_name] = input_gradient
            return hook

        for input_name in inputs:
            if input_name != "input0" and input_name != "input1" and input_name != "input2" \
                    and inputs[input_name].requires_grad:
                inputs[input_name].register_hook(hook_wrapper(input_name))

        if "loss" in outputs:
            outputs["loss"] *= self.loss_scale

        # Perform backward pass.
        # 進行反向傳播,output_gradients 
        # outputs 就是要計算梯度的張量,output_gradients就是計算出來的梯度
        torch.autograd.backward(tuple([outputs[output_name] for output_name in outputs]),
                                grad_tensors=tuple([output_gradients[output_name]
                                                    for output_name in outputs]))

        # Input tensors don't need gradients.
        for input_name in inputs:
            if not inputs[input_name].requires_grad:
                self.gradients[input_name] = inputs[input_name]
                continue

            if input_name != "input0" and input_name != "input1" and input_name != "input2" and input_name != "input":
                self.gradients[input_name] = input_gradients[input_name]

        # Send output gradients.
        self.send_tensors_backward() # 傳送梯度(self.gradients)給反向圖的下一層
        
        if self.verbose_freq > 0 and self.backward_minibatch_id % self.verbose_freq == 0:
            self.backward_stats.print_stats()
        self.backward_stats.reset_stats()
        self.backward_minibatch_id += 1

我們藉助前文的圖,再加深一下印象。

傳送邏輯:

 StageRuntime            CommunicationHandler              send_helper_thread

      +                           +                                 +
      |                           |                                 |
      | 1                         |                                 |
      v                           |                                 |
 run_backward                     |                                 |
      |                           |                                 |
      | 2                         |                                 |
      |                           |                    wait on backward_send_queues
      v                  3        v                                 |
send_tensors_backward +--------> send                               |
                                  |                                 |
                                  |                                 |
                                  |  4                              |
                                  v               5                 v
               backward_send_queues.add(tensor) +----> tensor = queue.remove()
                                                notify              |
                                                                    |
                                                                    | 6
                                                                    v
                                                                  _send
                                                                    |
                                                                    | 7
                                                                    |
                                                                    v
                                                                 dist.send

接受邏輯:

    StageRuntime             CommunicationHandler           recv_helper_thread
          +                            +                            +
          |                            |                            |
          | 1                          |                            |
          |                            |                            | 4
          v                            |                            v
    run_backward                       |                         _recv
          |                            |                            |
          |                            |                            |
          |                            |                            | 5
          |                            |                            |
          | 2                          |                            v
          |                            |                  dist.recv / dist.broadcast
          |                            |                            |
          v                  3         v                            |
receive_tensors_backward +--------->  recv                          |
          +                            |                            |
          |                            |                            |
          |                            |                            |
          |                            |                            |
          |                            v                            |
          |                 backward_receive_queues.remove()        |
          |                            |                            |
          |                            |                            |
          |                            |                            |
          |                            |                            |
          |               wait on backward_receive_queues           |
          |                            |                            |
          |                            |                            |
          |                            |                            |
          |                            |                 6          v
          |                  backward_receive_queues <-------+ queue.add(tensor)
          |                            |               notify
          |                            |  7
          v                  3 return  |
gradients[output_name] <---------------+

3.5 Weight Stashing

Weight Stashing 是由OptimizerWithWeightStashing實現的。

下面省略了很多次要程式碼,訓練時候呼叫了 load_old_params 和 load_new_params。

class OptimizerWithWeightStashing(torch.optim.Optimizer):
    """Wrapper class that adds weight stashing to a vanilla torch.optim.Optimizer.

    Arguments:
        - optim_name: the name of optimizer, required to create the corresponding
                      base_optimizer (torch.optim.{optim_name}).
        - optimizer_args: the keyword arguments passed to base_optimizer.
    """

    def __init__(self, optim_name, modules, master_parameters, model_parameters,
                 loss_scale, num_versions, verbose_freq=0, macrobatch=False,
                 **optimizer_args):
        self.modules = modules
        self.master_parameters = master_parameters
        self.model_parameters = model_parameters  # model_parameters is None if not fp16.
        self.loss_scale = loss_scale

        # Only need at most 2 versions if using macrobatching.
        if macrobatch:
            num_versions = min(2, num_versions) 
        self.num_versions = num_versions
        self.base_optimizer = getattr(torch.optim, optim_name)(
            master_parameters, **optimizer_args)
        self.latest_version = Version()
        self.current_version = Version()
        self.initialize_queue()
        self.verbose_freq = verbose_freq
        self.batch_counter = 0

        # If macrobatching, push and pop versions at the right rate.
        if macrobatch:
            self.update_interval = self.num_versions
        else:
            self.update_interval = 1

    def initialize_queue(self):
        self.queue = deque(maxlen=self.num_versions)
        for i in range(self.num_versions):
            self.queue.append(self.get_params(clone=True))
        self.buffered_state_dicts = self.queue[0][0] # stash weght變數

    def load_old_params(self):
        if self.num_versions > 1:
            self.set_params(*self.queue[0]) #找到最初的舊weight

    def load_new_params(self):
        if self.num_versions > 1:
            self.set_params(*self.queue[-1]) # 載入最新的weight

    def zero_grad(self): # 用來reset
        if self.batch_counter % self.update_interval == 0:
            self.base_optimizer.zero_grad()

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                                          and returns the loss.
        """
        # 每 update_interval個 steps更新一次梯度
        if self.batch_counter % self.update_interval != self.update_interval - 1:
            self.batch_counter += 1
            return None
        
        # 省略程式碼
        
        self.latest_version = self.latest_version.incr() # 因為多訓練了一步,所以增加版本號
        if self.num_versions > 1:
            self.buffered_state_dicts = self.queue[0][0] 
            self.queue.append(self.get_params(clone=False)) # 把新的變數存進去

        self.batch_counter += 1
        return loss

0xFF 參考

lingvo框架走讀筆記

Tensorflow實現先累加多個minibatch計算的梯度,再反向傳播

用tensorflow2實現梯度累積

十倍模型計算時間僅增20%:OpenAI開源梯度替換外掛

PipeDream: Fast and Efficient Pipeline Parallel DNN Training

論文解讀系列第五篇:微軟斯坦福等PipeDream快速訓練大規模神經網路

https://cs231n.github.io/neural-networks-3/#gradcheck

https://www.cnblogs.com/geekfx/p/14182048.html

訓練時視訊記憶體優化技術——OP合併與gradient checkpoint

Pytorch筆記04-自定義torch.autograd.Function

PyTorch教程之Autograd

pytorch的自定義拓展之(三)——torch.autograd.Function的簡單定義與案例

pytorch的自定義拓展之(二)——torch.autograd.Function完成自定義層

PyTorch 原始碼解讀之 torch.autograd:梯度計算詳解

再談反向傳播(Back Propagation)

CS231n課程筆記翻譯:反向傳播筆記

Pytorch 分散式訓練

torch.distributed

GPT-3模型為何難以復現?這也許是分散式AI框架的最優設計

蘇起冬 - pipedream