面臨的深度學習技術問題以及基於TensorFlow的開發實踐
一、用於Bot的深度學習技術及其面臨的難題
聊天機器人(chatbot),也被稱為會話代理或對話系統,現已成為了一個熱門話題。微軟在聊天機器人上押上了重注,Facebook(M)、蘋果(Siri)、谷歌、微信和 Slack 等公司也是如此。通過開發 Operator 或 x.ai 這樣的消費者應用、 Chatfuel 這樣的 bot 平臺和 Howdy's Botkit 這樣的 bot 庫,新一波創業者們正在嘗試改變消費者與服務的互動方式。微軟最近釋出了他們自己的 bot 開發者框架。
許多公司希望開發出能進行無法與人類相區分的自然會話的 bot,並且許多公司聲稱要使用自然語言處理(NLP)和深度學習技術來使其成為可能。但是,在眾多關於人工智慧的大肆宣傳中,有時很難將現實和幻想區分開。
在這個系列文章中,我想回顧一些被用來設計對話代理的深度學習技術,我會從解釋我們現在所處的發展現狀、什麼是可能的和什麼在短時間內幾乎是不可能的開始。這篇文章作為導言,我們會在接下來的文章中談論實現細節。
模型分類
基於檢索式模型 vs 生成式模型
基於檢索式模型(更簡單)使用了預定義回覆庫和某種啟發式方法來根據輸入和語境做出合適的回覆。這種啟發式方法可以像基於規則的表示式匹配一樣簡單,也可以像機器學習分類器集一樣複雜。這些系統不會產生任何新文字,他們只是從固定的集合中挑選一種回覆而已。
生成式模型(更困難)不依賴於預定義回覆庫。他們從零開始生成新回覆。生成式模型通常基於機器翻譯技術,但區別於語言翻譯,我們把一個輸入「翻譯」成一個輸出「回覆」。
兩種方式都有明顯的優勢和劣勢。由於採用人工製作的回覆庫,基於檢索式方法不會犯語法錯誤。然而它們可能無法處理沒見過的情況,因為它們沒有合適的預定義回覆。同樣,這些模型不能重新提到上下文中的實體資訊,如先前對話中提到過的名字。生成式模型更「聰明」。它們可以重新提及輸入中的實體並帶給你一種正和人類對話的感覺。然而,這些模型很難訓練,很可能會犯語法錯誤(特別是長句),而且通常要求大量的訓練資料。
基於檢索式模型或生成式模型都可以應用深度學習技術,但是相關研究似乎正轉向生成式方向。像序列到序列(Sequence to Sequence)這樣的深度學習架構是唯一可以適用於產生文字的,並且研究者希望在這個領域取得快速進步。然而,我們仍處於構建工作良好的生成式模型的早期階段。現在的生產系統更可能是基於檢索式的。
長對話 vs 短對話
對話越長,就越難使它自動化。一方面,短文字對話(更簡單)的目標是單獨回覆一個簡單的輸入。例如,你可能收到一個使用者的特定問題並回複合適的答案。而長對話(更困難)要求你經歷多個轉折並需要記錄說過什麼話。客戶支援類對話通常是包含多個問題的長對話流。
開域(open domain) vs 閉域(closeddomain)
在開域(更困難)環境中,使用者可以進行任何對話。不需要明確定義的目標或意圖。像Twitter 和 Reddit 這種社交媒體網站上的對話通常是開域的——它們可以是任何主題。話題的無限數量和用於產生合理回覆的一定量的知識使開域成為了一個艱難的問題。
在閉域(更簡單)設定中,因為系統試圖達成一個非常明確的目標,可能輸入和輸出的空間會有所限制。例如客戶技術支援或購物助手就屬於閉域的範疇。這些系統不需要能談論政治,它們只需要儘可能高效地完成它們特定的任務。當然,使用者仍然可以進行任何他們想要的對話,但是這樣的系統不需要能處理所有情況,並且使用者也不期望它能處理。
普遍難題
在構建大部分屬於活躍研究領域的會話代理方面存在著許多明顯和不明顯的難題。
1. 整合語境
為了生成明智的回覆,系統可能需要整合語言語境(linguistic context)和物理語境(physical context)。在長對話中,人們記錄已經被說過的話和已經交換過的資訊。這是結合語言語境的例子。最普遍的方法是將對話嵌入一個向量中,但在長對話上進行這樣的操作是很有挑戰性的。「使用生成式分層神經網路模型構建端到端對話系統」和「神經網路對話模型的注意與意圖」兩個實驗中都選擇了這個研究方向。此外還可能需要整合其它型別的語境資料,例如日期/時間、位置或使用者資訊。
2. 一致人格
當生成回覆時,對於語義相同的輸入,代理應該生成相同的回答。例如,你想在「你多大了?」和「你的年齡是多少?」上得到同樣的回答。這聽起來很簡單,但是將固定的知識或者「人格」整合進模型是非常困難的研究難題。許多系統學習如何生成語義合理的回覆,但是它們沒有被訓練如何生成語義上一致的回覆。這一般是因為它們是基於多個不同使用者的資料訓練的。「基於個人的神經對話模型」這樣的模型是明確的對人格建模的方向上的第一步。
3. 模型評估
評估一個對話代理的理想方式是衡量它是否完成了它的任務,例如,在給定對話中解決客戶支援問題。但是獲取這樣的標籤成本高昂,因為它們要求人類的判斷和評估。某些時候並不存在明確定義的目標,比如開域模型中的情況。通常像 BLEU 這樣被用於機器翻譯且是基於文字匹配的標準並不能勝任,因為智慧的回覆可以包括完全不同的單詞或短語。實際上,在 How NOT To Evaluate Your Dialogue System: An Empirical Study of UnsupervisedEvaluation Metrics for Dialogue Response Generation 中,研究者發現沒有一個通用的度量能真正與人類判斷一一對應。
4. 意圖和多樣性
生成式系統的普遍問題是它們往往能生成像「太好了!」或「我不知道」這樣的能適用於許多輸入情況的普遍回覆。谷歌的智慧回覆(Smart Reply )早期版本常常用「我愛你」回覆一切。一定程度上這是系統根據資料和實際訓練目標/演算法訓練的結果。然而,人類通常使用針對輸入的回覆並帶有意圖。因為生成系統(特別是開域系統)是不被訓練成有特定意圖的,所以它們缺乏這種多樣性。
實際工作情況
縱觀現在所有最前沿的研究,我們發展到哪裡了?這些系統的實際工作情況如何?讓我們再看看我們的分類法。一個基於檢索式開域系統顯然是不可能實現的,因為你不能人工製作出足夠的回覆來覆蓋所有情況。生成式開域系統幾乎是人工通用智慧(AGI: Artificial General Intelligence),因為它需要處理所有可能的場景。我們離 AGI 還非常遙遠(但是這個領域有許多研究正在進行)。
這就讓我們的問題進入了生成式和基於檢索式方法都適用的受限的領域。對話越長,語境就越重要,問題就變得越困難。
在最近的採訪中,現任百度首席科學家吳恩達說得很好:
當今深度學習的價值在你可以獲得許多資料的狹窄領域內。有一件事它做不到:進行有意義的對話。存在一些演示,並且如果你仔細挑選這些對話,看起來就像它正在進行有意義的對話,但是如果你親自嘗試,它就會快速偏離軌道。
許多公司從外包他們的對話業務給人類工作者開始,並承諾一旦他們收集到了足夠的資料,他們就會使其「自動化」。只有當他們在一個相當狹窄的領域中這樣操作時,這才有可能發生——比如呼叫 Uber 的聊天介面。任何稍微多點開域的事(像銷售郵件)就超出了我們現在的能力範圍。然而,我們也可以利用這些系統建議和改正回覆來輔助人類工作者。這就更符合實際了。
生產系統的語法錯誤成本很高並會趕走使用者。所以,大多數系統可能最好還是使用不會有語法錯誤和不禮貌回答的基於檢索式方法。如果公司能想辦法得到大量的資料,那麼生成式模型就將是可行的——但是,必須需要其它技術的輔助來防止它們像微軟的 Tay 一樣脫軌。
總結和推薦閱讀
背靠科技巨頭的資源或創業公司的熱情的聊天機器人們正在努力擠進我們生活的方方面面。在一般的話題上,它們還遠不能實現如人類水平的溝通和語境理解能力;但在購物助手等狹窄領域的的應用中,訓練有素的聊天機器人完全可以取代人類來協助和提升消費者的購物體驗。如果你對相關研究感興趣,那麼,下列論文可以讓你開個好頭:
-
Neural Responding Machine for Short-Text Conversation (2015-03)
-
A Neural Conversational Model (2015-06)
-
A Neural Network Approach to Context-Sensitive Generation of Conversational Responses (2015-06)
-
The Ubuntu Dialogue Corpus: A Large Dataset for Research in Unstructured Multi-Turn Dialogue Systems (2015-06)
-
Building End-To-End Dialogue Systems Using Generative Hierarchical Neural Network Models (2015-07)
-
A Diversity-Promoting Objective Function for Neural Conversation Models (2015-10)
-
Attention with Intention for a Neural Network Conversation Model (2015-10)
-
Improved Deep Learning Baselines for Ubuntu Corpus Dialogs (2015-10)
-
A Survey of Available Corpora for Building Data-Driven Dialogue Systems (2015-12)
-
Incorporating Copying Mechanism in Sequence-to-Sequence Learning (2016-03)
-
A Persona-Based Neural Conversation Model (2016-03)
-
How NOT To Evaluate Your Dialogue System: An Empirical Study of Unsupervised Evaluation Metrics for Dialogue Response Generation (2016-03)
二、如何在TensorFlow上部署一個基於檢索的Bot模型
我們將在這篇文章中探討基於檢索的 Bot 的實現。基於檢索的模型使用了一個預定義響應庫(repository of pre-defined responses),這一點不同於生成模型——它們會生成它們從未見過的響應。將一個語境(截至於該點的對話)和一個潛在的響應輸入檢索模型後,模型會輸出這個響應的得分。為了找到好的響應,我們可以計算多個響應的得分,然後選擇其中得分最高的響應。
但是如果你能構建生成模型,那為什麼又要建立一個檢索模型呢?生成模型看上去更靈活因為它們不需要這種預定義的響應庫,對嗎?
問題是,實踐中的檢索模型表現得並不好。至少現在還不夠好。因為它們在響應方式上的自由度太大了。生成模型往往會產生語法錯誤,而且會生成不相關的、通用或者不一致的響應。它們還需要巨量的訓練資料,而且難以優化。當下絕大部分的產品系統都是以檢索為基礎的,或者是基於檢索和生成方法的結合。谷歌的 Smart Replay 是很好的例子。生成模型的研究領域很活躍,但我們還沒達到實用的程度。今天如果你想構建一個聊天代理,最好還是選擇基於檢索的模型。
Ubuntu 對話語料庫
本文中,我們使用了 Ubuntu 對話預料庫(Ubuntu Dialog Corpus,UDC)(https://github.com/rkadlec/ubuntu-ranking-dataset-creator),UDC 是最大的公共對話資料庫之一。它以一個公共 IRC 網路上的 Ubuntu 頻道為基礎 。有篇論文詳細描述了創造這個語料庫的過程,所以在此不再贅述。但是,理解我們使用的資料的型別是非常重要的,這裡先做一些探索。
這個訓練資料由 1,000,000 個樣本構成,其中 50% 是積極的 (標籤為 1),50% 是消極的(標籤為 0)。每個樣本都由一個語境(context)(也就是截至於該點的對話)、一個話語(utterance)和一個語境的響應(response)構成。積極標籤表示這個話語是對這個語境的真實響應,消極標籤代表的則不是真實的響應——而是從語料庫中隨機抽取的。下面是一些樣本資料。
注意,資料庫生成指令碼已經為我們做好了一堆預處理——使用 NLTK 工具給輸出內容打上代號封存起來,然後把輸出的內容按照異體形式進行歸類(lemmatize)。這個指令碼也代替了名稱、位置、組織、URL 和帶有特殊代號的系統路徑等實體。這個預處理過程不是完全必要的,但是它可能能將整個效能表現提升幾個百分點。每個語境平均有 86 個單詞,而每個話語平均包含 17 個單詞。
這個資料集帶有測試和驗證集,它們的格式不同於訓練資料的格式。測試/驗證集中的每一個記錄都由一個語境、一個真實話語(真實響應)和 9 個稱為干擾項( distractor)的不正確話語構成。這個模型的目標會將最高分分配給這個真實的話語,給錯誤的話語分配較低分數。
評估模型有很多種方法。常用的標準是 [email protected]。[email protected] 意思是我們讓這個模型從 10 個可能的響應中挑出 k 個最好的(1 個真實和 9 個干擾項)響應。如果這個正確的響應在這些選擇出的響應中,我們就將該測試樣本標記為正確。所以,一個更大的 k 意味著這個任務會更簡單。如果我們設 k = 10, 那麼我們 100% 需要重新呼叫,因為我們總共只有 10 個可選的響應。如果我們設 k=1, 那麼這個模型挑出正確的迴應的機會只有 1 次。
在這裡你也許疑問這 9 個干擾項是怎麼選出來的。這個資料集中的 9 個干擾項是隨機選出來的。但是,在真實世界,你也許有幾百萬個可能的響應,而且你不知道哪一個是正確的。你不可能能夠評估一百萬個潛在的響應,然後選出得分最高的一個——這個代價太大了。谷歌的 Smart Replay 使用聚類技術得出一組可能的響應來一個一個挑選。或者,如果你總共只有一百個潛在的響應,你也可以全部評估它們。
基線
在開始使用神奇的神經網路模型之前,我們先建立一些簡單的基線模型(baseline model)來幫助理解我們預計會得到什麼型別的效能。我們會使用下面的函式來估算我們的 [email protected] 標準:
def evaluate_recall(y, y_test, k=1):
num_examples = float(len(y))
num_correct = 0
for predictions, label in zip(y, y_test):
if label in predictions[:k]:
num_correct += 1
return num_correct/num_examples
其中,y 是我們按得分遞減順序列出的列表,y_test 是真實標籤。例如,一個 [0,3,1,2,5,6,4,7,8,9] 構成的 y 中可能是 0 號話語得分最高、9號話語得分最低。記住每一個測試樣本有 10 個話語,第一個(索引0)總是正確的那一個,因為這個話語列產生在另外 9 個干擾項之前。
直觀來看,一個完全隨機的預測得分應該是:[email protected] 為 10%、[email protected] 為20%,等等。我們來看看是不是這種情況。
# Random Predictor
def predict_random(context, utterances):
return np.random.choice(len(utterances), 10, replace=False)
# Evaluate Random predictor
y_random = [predict_random(test_df.Context[x], test_df.iloc[x,1:].values) for x in range(len(test_df))]
y_test = np.zeros(len(y_random))
for n in [1, 2, 5, 10]:
print("Recall @ ({}, 10): {:g}".format(n, evaluate_recall(y_random, y_test, n)))
Recall @ (1, 10): 0.0937632
Recall @ (2, 10): 0.194503
Recall @ (5, 10): 0.49297
Recall @ (10, 10): 1
很棒,看起來有用。我們當然不僅僅想要一個隨機預測。最初的那篇論文中討論的另一個基線是 tf-idf 預測。 tf-idf 意思是「術語頻率-逆文件頻率(term frequency – inverse document frequency)」,它測量文件中的一個單詞與整個語料相關的重要性。這裡不過多深入細節(網上有許多關於 tf-idf 的教程),具有相似內容的文件將會相似的 tf-idf 向量。直觀上看,如果一個語境和一個響應有相似的單詞,它們就更可能是正確的一對。至少比隨機出來的一對更有可能。現有的許多資料庫(如 scikit-learn)都內建了 tf-idf 函式,所以非常容易使用。我們建一個 tf-idf 看看它表現如何。
class TFIDFPredictor:
def __init__(self):
self.vectorizer = TfidfVectorizer()
def train(self, data):
self.vectorizer.fit(np.append(data.Context.values,data.Utterance.values))
def predict(self, context, utterances):
# Convert context and utterances into tfidf vector
vector_context = self.vectorizer.transform([context])
vector_doc = self.vectorizer.transform(utterances)
# The dot product measures the similarity of the resulting vectors
result = np.dot(vector_doc, vector_context.T).todense()
result = np.asarray(result).flatten()
# Sort by top results and return the indices in descending order
return np.argsort(result, axis=0)[::-1]
# Evaluate TFIDF predictor
pred = TFIDFPredictor()
pred.train(train_df)
y = [pred.predict(test_df.Context[x], test_df.iloc[x,1:].values) for x in range(len(test_df))]
for n in [1, 2, 5, 10]:
print("Recall @ ({}, 10): {:g}".format(n, evaluate_recall(y, y_test, n)))
Recall @ (1, 10): 0.495032
Recall @ (2, 10): 0.596882
Recall @ (5, 10): 0.766121
Recall @ (10, 10): 1
我們能看到 tf-idf 模型表現十分出色,遠比隨機模型要好。不過還遠不夠完美,我們也沒有那麼好的假設。首先,一個正確的響應沒必要和語境相似。第二, tf-idf 忽略了單詞順序,而這可能是重要的訊號。有了神經網路模型我們能做得更好。
雙重編碼器 LSTM
本文中我們將要建立一個叫雙重編碼器 LSTM(Dual Encoder LSTM)網路的深度學習模型。這種網路只是我們能應用在這個問題上的網路型別之一,而且它不一定是最好的一個。你可以試試所有還沒被試過的深度學習架構——這個領域的研究非常活躍。例如,常用在機器翻譯中的 seq2seq 模型也可能可以很好地完成這個任務。我們選擇這個雙重編碼器的原因是它已經被多次報道過在這個資料集上的表現非常好。這就意味著我們知道預計能得到什麼,也能明確我們的實施是正確的。應用其他模型解決這個問題也會非常有趣。
我們將要建立的雙重編碼器 LSTM 是這樣的(出自論文:The Ubuntu Dialogue Corpus: A Large Dataset for Research in Unstructured Multi-Turn Dialogue Systems):
它大概做以下工作:
1. 會被單詞分裂,而且每個單詞都嵌入到一個向量中。詞嵌入(word embeddings )由斯坦福的 GloVe 向量初始化,然後在訓練中進行精細的調整(邊注:這是可選的,且未在圖中展示。我發現用 GloVe 初始化詞向量不會在模型效能上有很大區別。)
2. 嵌入的語境和響應被一個單詞一個單詞地饋送到同一個迴圈神經網路。這個迴圈神經網路生成一個不太嚴格的向量表徵,該表徵獲取了語境及其響應(圖中的 c 和 r )的「意義」。我們可以選擇這些向量的大小,而我們選擇的尺寸是 256 維。
3. 我們用 c 乘以矩陣 M 來「預測」響應 r'。如果 c 是一個 256 維的向量,那麼 M 就是一個 256×256 維的矩陣。結果是另一個 256 維的向量,我們可以將它解釋為一個生成的響應。矩陣 M 是在訓練中被學習到的。
4. 我們通過計算這個兩個向量的點積測量預測響應 r' 與真實響應 r 的相似度。大點積表示這些向量是相似的,則該響應應該得到高分。之後我們應用一個 sigmoid 函式把這個分數轉換成概率。注意第 3 步和第 4 步被融合在一個圖中。
訓練這個網路,我們還需要一個損失(成本)函式。我們使用的是分類問題中常用的二元交叉熵損失( binary cross-entropy loss)函式。讓我們將用於語境-響應對的真實標籤稱為 y,其要麼是 1(實際響應),要麼是 0(不正確響應)。讓我們將上面第 4 步中的預測概率稱為 y'. 那麼,該交叉熵損失可通過 L= −y * ln(y') − (1 − y) * ln(1−y) 計算得到。這個公式之後的直觀知識很簡單。如果 y=1,則只剩下 L = -ln(y'),會使預測處於遠離 1 的不利位置;而如果 y=0,則只剩下 L= −ln(1−y),會使預測處於遠離 0 的不利位置.
在實施中,我們會結合使用 numpy 、pandas Tensorflow 和 TF Learn( Tensorflow 高層便利函式( high-level convenience functions )的綜合體)。
資料預處理
資料庫最初以 CSV 的格式出現。我們能直接使用 CSV ,但更好的方法是將我們的資料轉換成 Tensorflow 的優先的 Example 格式 。(邊注:也有 tf.SequenceExample)但是它看上去還並不被 tf.learn 支援。)這種格式的主要好處是它允許我們直接從輸入檔案中載入張量(tensor),並且讓 Tensorflow 處理輸入所有的混排(shuffling)、分批(batching)和列隊(queuing)。
每一個 Example 都包含以下域(field):
-
context:一個表示語境文字的單詞標識序列,比如 [231, 2190, 737, 0, 912]
-
context_len: 語境長度,比如以上例子為 5
-
utterance:一個表示話語(響應)的單詞標識序列
-
utterance_len:話語的長度
-
label: 只在訓練資料中。0 或 1。
-
distractor_[N]: 只在測試/驗證資料中。N 的範圍從 0 到 8。一個表示干擾項話語的單詞標識序列。
-
distractor_[N]_len :只在測試/驗證資料中。N 的範圍從 0 到 8。話語的長度。
由 Python 指令碼 prepare_data.py完成這個預處理,它會生成 3 個檔案: train.tfrecords、validation.tfrecords 和 test.tfrecords.你可以自己執行這個指令碼,或者在 https://drive.google.com/open?id=0B_bZck-ksdkpVEtVc1R6Y01HMWM 下載資料。
建立一個輸入函式
為了使用 Tensorflow 中對訓練和評估的內建支援,我們需要建立一個輸入函式——該函式能返回輸入資料的批次。事實上,因為我們的訓練和測試資料有不同的格式,我們需要為它們提供不同輸入函式。輸入函式應該返回一批特徵和標籤(如果可以的話)。用程式碼表示為:
def input_fn():
# TODO Load and preprocess data here
return batched_features, labels
因為在訓練和評估過程中我們需要不同的輸入函式,而且因為我們厭惡程式碼重複,所以我們創造了一個名叫 create_input_fn 包裝函式,其能建立合適模式的輸入函式。它也需要一些其它引數。我們使用的定義如下:
def create_input_fn(mode, input_files, batch_size, num_epochs=None):
def input_fn():
# TODO Load and preprocess data here
return batched_features, labels
return input_fn
完整程式碼可見於檔案 udc_inputs.py。在一個較高的層面上,該函式實現了以下功能:
-
在我們的 Example 檔案中創造了一個描述域(field)的特徵定義
-
使用 tf.TFRecordReader 從 input_files 中讀取記錄
-
根據特徵定義解析記錄
-
提取訓練標籤
-
批處理多個樣本和訓練標籤
-
返回批處理過的樣本的訓練標籤
定義評估標準
我們已經提到了我們想使用 [email protected] 標準來評估我們的模型。幸運的是,TensorFlow 已經具備了許多我們可以使用的評估標準,其中就包括 [email protected]。為了使用這些標準,我們需要創造一個能夠將標準名(metric name)對映到做出預測的函式和作為引數的標籤的詞典。
def create_evaluation_metrics():
eval_metrics = {}
for k in [1, 2, 5, 10]:
eval_metrics["recall_at_%d" % k] = functools.partial(
tf.contrib.metrics.streaming_sparse_recall_at_k,
k=k)
return eval_metrics
上面,我們使用了 functools.partial 將帶有三個引數的函式轉換成只帶有 2 個引數的函式。不要對 streaming_sparse_recall_at_k 的名字感到困惑。其中的「streaming(流)」只是意味著該標準是在多批次處理上的積累,而「sparse(稀疏)」則是指我們的標籤的格式。
這就將我們帶到了一個非常重要的點上:評估過程中我們的預測的格式究竟是什麼?在訓練過程中,我們預測了樣本正確的概率。但在評估過程中,我們的目標是對錶達和 9 個干擾項進行評分,然後選出最好的一個——我們並不是簡單地預測正確或不正確。這意味著:在評估過程中,每個樣本都會得到一個包含 10 個分數的向量,如[0.34, 0.11, 0.22, 0.45, 0.01, 0.02, 0.03, 0.08, 0.33, 0.11],其中這些分數分別對應於真實迴應和 9 個干擾項。其中對每一個表達的評分都是獨立的,所以這些概率之和不需要等於 1。 因為真實迴應總是陣列中的元素 0,所以每個樣本的標籤是 0。上面的樣本可以根據 [email protected] 而被不正確地分類,因為第 3 個干擾項獲得了 0.45 的概率,而真實迴應僅有 0.34。 但其會被 [email protected] 評分為正確。
樣板檔案訓練程式碼
在編寫實際的神經網路程式碼之前,我喜歡先編寫訓練以及評估模型的樣板檔案程式碼。這是因為,一旦你遵守正確的介面,就非常容易置換出使用的網路的型別。假設我們有一個模型函式 model_fn,把我們的批量特徵、標記和模式(訓練或者評估),然後返回預測。接下來,我們就能像下面這樣編寫訓練模型的通用程式碼:
estimator = tf.contrib.learn.Estimator(
model_fn=model_fn,
model_dir=MODEL_DIR,
config=tf.contrib.learn.RunConfig())
input_fn_train = udc_inputs.create_input_fn(
mode=tf.contrib.learn.ModeKeys.TRAIN,
input_files=[TRAIN_FILE],
batch_size=hparams.batch_size)
input_fn_eval = udc_inputs.create_input_fn(
mode=tf.contrib.learn.ModeKeys.EVAL,
input_files=[VALIDATION_FILE],
batch_size=hparams.eval_batch_size,
num_epochs=1)
eval_metrics = udc_metrics.create_evaluation_metrics()
# We need to subclass theis manually for now. The next TF version will
# have support ValidationMonitors with metrics built-in.
# It's already on the master branch.
class EvaluationMonitor(tf.contrib.learn.monitors.EveryN):
def every_n_step_end(self, step, outputs):
self._estimator.evaluate(
input_fn=input_fn_eval,
metrics=eval_metrics,
steps=None)
eval_monitor = EvaluationMonitor(every_n_steps=FLAGS.eval_every)
estimator.fit(input_fn=input_fn_train, steps=None, monitors=[eval_monitor])
在這裡,我們為 model_fn 創造了一個評估器,兩個訓練以及評估資料的輸入函式,以及我們的評估引數資料。我們也定義了一個監控器,能在訓練過程中評估模型的每一個 FLAGS.eval_every 步驟。最終,我們訓練了這個模型。訓練的執行是無限期的,但 Tensorflow 自動儲存了檢查點檔案到 MODEL_DIR,所以你能隨時停止訓練。一個更精緻的技術是使用早停止,這意味著你可以在一個驗證集引數停止改進(即,開始過擬合)時自動停止訓練。你可以在 udc_train.py(https://github.com/dennybritz/chatbot-retrieval/blob/master/udc_train.py)看到完整程式碼。
我想提及的兩件事是 FLAGS 的使用。這是一種將命令列引數給予到程式設計的方式(類似於 Python 的 argparse)。 hparams 是我們在 hparams.py 創造的一個自定義物件,持有模型中我們能調整的超引數、nobs。這一 hparams 物件在我們具現化的時候會被交給模型。
建立模型
如今,我們已經建立了輸入、解析、評估和訓練的樣本檔案程式碼,是時候為雙重 LSTM 神經網路編寫程式碼了。因為我們有不同格式的訓練和評估資料,我寫了一個 create_model_fn 包裝器,它可以為我們帶來正確格式的資料。這個包裝器採用 amodel_impl 引數(argument),這是一個真正可以做預測的函式。在我們的案例中,我們上面描述的是 雙重編碼器 LSTM,但我們能輕鬆的將其替換為其他神經網路。讓我們看一下它是什麼樣的:
def dual_encoder_model(
hparams,
mode,
context,
context_len,
utterance,
utterance_len,
targets):
# Initialize embedidngs randomly or with pre-trained vectors if available
embeddings_W = get_embeddings(hparams)
# Embed the context and the utterance
context_embedded = tf.nn.embedding_lookup(
embeddings_W, context, name="embed_context")
utterance_embedded = tf.nn.embedding_lookup(
embeddings_W, utterance, name="embed_utterance")
# Build the RNN
with tf.variable_scope("rnn") as vs:
# We use an LSTM Cell
cell = tf.nn.rnn_cell.LSTMCell(
hparams.rnn_dim,
forget_bias=2.0,
use_peepholes=True,
state_is_tuple=True)
# Run the utterance and context through the RNN
rnn_outputs, rnn_states = tf.nn.dynamic_rnn(
cell,
tf.concat(0, [context_embedded, utterance_embedded]),
sequence_length=tf.concat(0, [context_len, utterance_len]),
dtype=tf.float32)
encoding_context, encoding_utterance = tf.split(0, 2, rnn_states.h)
with tf.vari