1. 程式人生 > 資訊 >秒秒鐘揪出張量形狀錯誤,這個工具能防止 ML 模型訓練白忙一場

秒秒鐘揪出張量形狀錯誤,這個工具能防止 ML 模型訓練白忙一場

模型吭哧吭哧訓練了半天,結果發現張量形狀定義錯了,這一定沒少讓你抓狂吧。那麼針對這種情況,是否存在較好的解決方法呢?

這不最近,韓國首爾大學的研究者就開發出了一款“利器”—— PyTea。

據研究人員介紹,它在訓練模型前,能幾秒內幫助你靜態分析潛在的張量形狀錯誤

那麼 PyTea 是如何做到的,到底靠不靠譜,讓我們一探究竟吧。

PyTea 的出場方式

為什麼張量形狀錯誤這麼重要?

神經網路涉及到一系列的矩陣計算,前面矩陣的列數必需匹配後面矩陣的行數,如果維度不匹配,那後面的運算就都無法運行了。

上圖程式碼就是一個典型的張量形狀錯誤,[B x 120] * [80 x 10] 無法進行矩陣運算。

無論是 PyTorch,TensorFlow 還是 Keras 在進行神經網路的訓練時,大多都遵循圖上的流程。

首先定義一系列神經網路層(也就是矩陣),然後合成神經網路模組……

那麼為什麼需要 PyTea 呢?

以往我們都是在模型讀取大量資料,開始訓練,程式碼執行到錯誤張量處,才可以發現張量形狀定義錯誤。

由於模型可能十分複雜,訓練資料非常龐大,所以發現錯誤的時間成本會很高,有時候程式碼放在後臺訓練,出了問題都不知道……

PyTea 就可以有效幫我們避免這個問題,因為它能在執行模型程式碼之前,就幫我們分析出形狀錯誤。

網友們已經在熱烈討論了。

PyTea 是如何運作的,它能否有效地檢查出錯誤呢?

受各種約束條件的影響,程式碼可能的執行路徑有很多,不同的資料會走向不同的路徑。

所以 PyTea 需要靜態掃描所有可能的執行路徑,跟蹤張量變化,推斷出每個張量形狀精確而保守的範圍。

上圖就是 PyTea 的整體架構,一共分為翻譯語言,收集約束條件,求解器判斷和給出反饋四步。

首先 PyTea 將原始的 Python 程式碼翻譯成一種核心語言。PyTea 內部表示法(PyTea IR)。

接著 PyTea 追蹤 PyTea IR 每個可能的執行路徑,並收集有關張量形狀的約束條件。

判斷約束條件是否被滿足,分為線上分析和離線分析兩步

  • 線上分析 node.js(TypeScript / JavaScript):查詢張量形狀數值上的不匹配和誤用 API 函式的情況。如果 PyTea 發現問題,就會停止在當前位置,然後給使用者報錯。

  • 離線分析 Z3 / Python:如果線上分析沒有問題,PyTea 將收集到的約束條件傳給 SMT(Satisfiability Modulo Theories)求解器 Z3,求解器負責檢視每條路徑的約束條件是否都能被滿足,如果不能,返回給使用者第一條出錯路徑的約束條件。

如果求解器過久沒有反應,PyTea 會返回不知道是否存在問題。

然而追蹤所有可能的路徑是指數級別的任務,對於複雜的神經網路來說,一定會發生路徑爆炸這個問題。

比如說在這個例子中,網路的最終結構是由 24 個相同模組塊構成的(第 17 行),那麼可能的路徑就有 16M 之多。

所以路徑爆炸是一定要處理的,PyTea 是怎麼做的?

PyTea 選擇保守的地對路徑剪枝和超時判斷來處理這種路徑爆炸。

什麼樣的路徑可以被剪枝?

PyTea 給出的答案是,如果該前饋函式不改變全域性值,並且它的輸出值不受分支條件影響,對於每條路徑都是相等的,我們就可以忽略許多完全一致的路徑,來節約計算資源。

如果路徑剪枝還是不行,那麼就只能按超時處理了。

原理就介紹這麼多了,感覺還是值得一試的,現在程式碼已經在 GitHub 上面開源了,快去看看吧!

使用方法

依賴庫:

安裝方法:

執行命令:

參考連結

[1]https://github.com/ropas/pytea

[2]https://arxiv.org/abs/2112.09037