基於Bert和通用句子編碼的Spark-NLP文字分類
作者|Veysel Kocaman
編譯|VK
來源|Towards Data Science
自然語言處理(NLP)是許多資料科學系統中必須理解或推理文字的關鍵組成部分。常見的用例包括文字分類、問答、釋義或總結、情感分析、自然語言BI、語言建模和消歧。
NLP在越來越多的人工智慧應用中是越來越重要。如果你正在構建聊天機器人、搜尋專利資料庫、將患者與臨床試驗相匹配、對客戶服務或銷售電話進行分級、從財務報告中提取摘要,你必須從文字中提取準確的資訊。
文字分類是現代自然語言處理的主要任務之一,它是為句子或文件指定一個合適的類別的任務。類別取決於所選的資料集,並且可以從主題開始。
每一個文字分類問題都遵循相似的步驟,並用不同的演算法來解決。更不用說經典和流行的機器學習分類器,如隨機森林或Logistic迴歸,有150多個深度學習框架提出了各種文字分類問題。
文字分類問題中使用了幾個基準資料集,可以在nlpprogress.com上跟蹤最新的基準。以下是關於這些資料集的基本統計資料。
簡單的文字分類應用程式通常遵循以下步驟:
- 文字預處理和清理
- 特徵工程(手動從文字建立特徵)
- 特徵向量化(TfIDF、頻數、編碼)或嵌入(word2vec、doc2vec、Bert、Elmo、句子嵌入等)
- 用ML和DL演算法訓練模型。
Spark-NLP中的文字分類
在本文中,我們將使用通用句子嵌入(Universal Sentence Embeddings)在Spark NLP中建立一個文字分類模型。然後我們將與其他ML和DL方法以及文字向量化方法進行比較。
Spark NLP中有幾個文字分類選項:
- Spark-NLP中的文字預處理及基於Spark-ML的ML演算法
- Spark-NLP和ML演算法中的文字預處理和單詞嵌入(Glove,Bert,Elmo)
- Spark-NLP和ML演算法中的文字預處理和句子嵌入(
Universal Sentence Encoders
) - Spark-NLP中的文字預處理和ClassifierDL模組(基於TensorFlow)
正如我們在關於Spark NLP的重要文章中所深入討論的,在ClassifierDL之前的所有這些文字處理步驟都可以在指定的管道序列中實現,並且每個階段都是一個轉換器或估計器。這些階段按順序執行,輸入資料幀在通過每個階段時進行轉換。也就是說,資料按順序通過各個管道。每個階段的transform()
Universal Sentence Encoders
在自然語言處理(NLP)中,在建立任何深度學習模型之前,文字嵌入起著重要的作用。文字嵌入將文字(單詞或句子)轉換為向量。
基本上,文字嵌入方法在固定長度的向量中對單詞和句子進行編碼,以極大地改進文字資料的處理。這個想法很簡單:出現在相同上下文中的單詞往往有相似的含義。
像Word2vec和Glove這樣的技術是通過將一個單詞轉換成向量來實現的。因此,對應的向量“貓”比“鷹”更接近“狗”。但是,當嵌入一個句子時,整個句子的上下文需要被捕獲到這個向量中。這就是“Universal Sentence Encoders
”的功能了。
Universal Sentence Encoders
將文字編碼成高維向量,可用於文字分類、語義相似性、聚類和其他自然語言任務。在Tensorflow hub中可以公開使用預訓練的Universal Sentence Encoders
。它有兩種變體,一種是用Transformer編碼器訓練的,另一種是用深度平均網路(DAN)訓練的。
Spark NLP使用Tensorflow hub版本,該版本以一種在Spark環境中執行的方式包裝。也就是說,你只需在Spark NLP中插入並播放此嵌入,然後以分散式方式訓練模型。
為句子生成嵌入,無需進一步計算,因為我們並不是平均句子中每個單詞的單詞嵌入來獲得句子嵌入。
Spark-NLP中ClassifierDL和USE在文字分類的應用
在本文中,我們將使用AGNews資料集(文字分類任務中的基準資料集之一)在Spark NLP中使用USE和ClassifierDL構建文字分類器,後者是Spark NLP 2.4.4版中新增的最新模組。
ClassifierDL
是Spark NLP中第一個多類文字分類器,它使用各種文字嵌入作為文字分類的輸入。ClassifierDLAnnotator使用了一個在TensorFlow內部構建的深度學習模型(DNN),它最多支援50個類。
也就是說,你可以用這個classifirdl在Spark NLP中用Bert
、Elmo
、Glove
和Universal Sentence Encoders
構建一個文字分類器。
我們開始寫程式碼吧!
宣告載入必要的包並啟動一個Spark會話。
import sparknlp
spark = sparknlp.start()
# sparknlp.start(gpu=True) >> 在GPU上訓練
from sparknlp.base import *
from sparknlp.annotator import *
from pyspark.ml import Pipeline
import pandas as pd
print("Spark NLP version", sparknlp.version())
print("Apache Spark version:", spark.version)
>> Spark NLP version 2.4.5
>> Apache Spark version: 2.4.4
然後我們可以從Github repo下載AGNews資料集(https://github.com/JohnSnowLabs/spark-nlp-workshop/tree/master/tutorials/Certification_Trainings/Public)。
! wget https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-workshop/master/tutorials/Certification_Trainings/Public/data/news_category_train.csv
! wget https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-workshop/master/tutorials/Certification_Trainings/Public/data/news_category_test.csv
trainDataset = spark.read \
.option("header", True) \
.csv("news_category_train.csv")
trainDataset.show(10, truncate=50)
>>
+--------+--------------------------------------------------+
|category| description|
+--------+--------------------------------------------------+
|Business| Short sellers, Wall Street's dwindling band of...|
|Business| Private investment firm Carlyle Group, which h...|
|Business| Soaring crude prices plus worries about the ec...|
|Business| Authorities have halted oil export flows from ...|
|Business| Tearaway world oil prices, toppling records an...|
|Business| Stocks ended slightly higher on Friday but sta...|
|Business| Assets of the nation's retail money market mut...|
|Business| Retail sales bounced back a bit in July, and n...|
|Business|" After earning a PH.D. in Sociology, Danny Baz...|
|Business| Short sellers, Wall Street's dwindling band o...|
+--------+--------------------------------------------------+
only showing top 10 rows
AGNews資料集有4個類:World、Sci/Tech、Sports、Business
from pyspark.sql.functions import col
trainDataset.groupBy("category") \
.count() \
.orderBy(col("count").desc()) \
.show()
>>
+--------+-----+
|category|count|
+--------+-----+
| World|30000|
|Sci/Tech|30000|
| Sports|30000|
|Business|30000|
+--------+-----+
testDataset = spark.read \
.option("header", True) \
.csv("news_category_test.csv")
testDataset.groupBy("category") \
.count() \
.orderBy(col("count").desc()) \
.show()
>>
+--------+-----+
|category|count|
+--------+-----+
|Sci/Tech| 1900|
| Sports| 1900|
| World| 1900|
|Business| 1900|
+--------+-----+
現在,我們可以將這個資料提供給Spark NLP DocumentAssembler,它是任何Spark datagram的Spark NLP的入口點。
# 實際內容在description列
document = DocumentAssembler()\
.setInputCol("description")\
.setOutputCol("document")
#我們可以下載預先訓練好的嵌入
use = UniversalSentenceEncoder.pretrained()\
.setInputCols(["document"])\
.setOutputCol("sentence_embeddings")
# classes/labels/categories 在category列
classsifierdl = ClassifierDLApproach()\
.setInputCols(["sentence_embeddings"])\
.setOutputCol("class")\
.setLabelColumn("category")\
.setMaxEpochs(5)\
.setEnableOutputLogs(True)
use_clf_pipeline = Pipeline(
stages = [
document,
use,
classsifierdl
])
以上,我們獲取資料集,輸入,然後從使用中獲取句子嵌入,然後在ClassifierDL中進行訓練
現在我們開始訓練。我們將使用ClassiferDL
中的.setMaxEpochs()
訓練5個epoch。在Colab環境下,這大約需要10分鐘才能完成。
use_pipelineModel = use_clf_pipeline.fit(trainDataset)
執行此命令時,Spark NLP會將訓練日誌寫入主目錄中的annotator_logs資料夾。下面是得到的日誌。
如你所見,我們在不到10分鐘的時間內就實現了90%以上的驗證精度,而無需進行文字預處理,這通常是任何NLP建模中最耗時、最費力的一步。
現在讓我們在最早的時候得到預測。我們將使用上面下載的測試集。
下面是通過sklearn庫中的classification_report
獲得測試結果。
我們達到了89.3%的測試集精度!看起來不錯!
基於Bert和globe嵌入的Spark-NLP文字預處理分類
與任何文字分類問題一樣,有很多有用的文字預處理技術,包括詞幹、詞幹分析、拼寫檢查和停用詞刪除,而且除了拼寫檢查之外,Python中幾乎所有的NLP庫都有應用這些技術的工具。目前,Spark NLP庫是唯一一個具備拼寫檢查功能的可用NLP庫。
讓我們在Spark NLP管道中應用這些步驟,然後使用glove嵌入來訓練文字分類器。我們將首先應用幾個文字預處理步驟(僅通過保留字母順序進行標準化,刪除停用詞字和詞幹化),然後獲取每個標記的單詞嵌入(標記的詞幹),然後平均每個句子中的單詞嵌入以獲得每行的句子嵌入。
關於Spark NLP中的所有這些文字預處理工具以及更多內容,你可以在這個Colab筆記本中找到詳細的說明和程式碼示例(https://github.com/JohnSnowLabs/spark-nlp-workshop/blob/master/tutorials/Certification_Trainings/Public/2.Text_Preprocessing_with_SparkNLP_Annotators_Transformers.ipynb)。
那我們就可以訓練了。
clf_pipelineModel = clf_pipeline.fit(trainDataset)
得到測試結果。
現在我們有88%的測試集精度!即使在所有這些文字清理步驟之後,我們仍然無法擊敗Universal Sentence Embeddings
+ClassifierDL
,這主要是因為USE
相對於資料清理後的版本,它在原始文字上執行得更好。
為了訓練與BERT相同的分類器,我們可以在上面構建的同一管道中用BERT_embedding替換glove_embeddings。
word_embeddings = BertEmbeddings\
.pretrained('bert_base_cased', 'en') \
.setInputCols(["document",'lemma'])\
.setOutputCol("embeddings")\
.setPoolingLayer(-2) # default 0
我們也可以使用Elmo嵌入。
word_embeddings = ElmoEmbeddings\
.pretrained('elmo', 'en')\
.setInputCols(["document",'lemma'])\
.setOutputCol("embeddings")
使用LightPipeline進行快速推理
正如我們在前面的一篇文章中深入討論的,LightPipelines是Spark NLP特有的管道,相當於Spark ML管道,但其目的是處理少量的資料。它們在處理小資料集、除錯結果或從服務一次性請求的API執行訓練或預測時非常有用。
Spark NLP LightPipelines
是Spark ML管道轉換成在單獨的機器上,變成多執行緒的任務,對於較小的資料量(較小的是相對的,但5萬個句子大致最大值)來說,速度快了10倍以上。要使用它們,我們只需插入一個經過訓練的管道,我們甚至不需要將輸入文字轉換為DataFrame,就可以將其輸入到一個管道中,該管道首先接受DataFrame作為輸入。當需要從經過訓練的ML模型中獲得幾行文字的預測時,這個功能將非常有用。
LightPipelines很容易建立,而且可以避免處理Spark資料集。它們的速度也非常快,當只在驅動節點上工作時,它們執行平行計算。讓我們看看它是如何適用於我們上面描述的案例的:
light_model = LightPipeline(clf_pipelineModel)
text="Euro 2020 and the Copa America have both been moved to the summer of 2021 due to the coronavirus outbreak."
light_model.annotate(text)['class'][0]
>> "Sports"
你還可以將這個經過訓練的模型儲存到磁碟中,然後稍後在另一個Spark管道中與ClassifierDLModel.load()
一起使用。
結論
本文在Spark-NLP中利用詞嵌入和Universal Sentence Encoders,
訓練了一個多類文字分類模型,在不到10min的訓練時間內獲得了較好的模型精度。整個程式碼都可以在這個Github repo中找到(Colab相容,https://github.com/JohnSnowLabs/spark-nlp-workshop/blob/master/tutorials/Certification_Trainings/Public/5.Text_Classification_with_ClassifierDL.ipynb)。我們還準備了另一個Notebook,幾乎涵蓋了Spark NLP和Spark ML中所有可能的文字分類組合(CV、TfIdf、Glove、Bert、Elmo、USE、LR、RF、ClassifierDL、DocClassifier):https://github.com/JohnSnowLabs/spark-nlp-workshop/blob/master/tutorials/Certification_Trainings/Public/5.1_Text_classification_examples_in_SparkML_SparkNLP.ipynb。
我們還開始為公共和企業(醫療)版本提供線上Spark NLP訓練。這裡是所有公共Colab Notebook的連結(https://github.com/JohnSnowLabs/spark-nlp-workshop/tree/master/tutorials/Certification_Trainings/Public)
John Snow實驗室將組織虛擬Spark NLP訓練,以下是下一次訓練的連結:
https://events.johnsnowlabs.com/online-training-spark-nlp
以上程式碼截圖
歡迎關注磐創AI部落格站:
http://panchuang.net/
sklearn機器學習中文官方文件:
http://sklearn123.com/
歡迎關注磐創部落格資源彙總站:
http://docs.panchuang.net/