1. 程式人生 > 實用技巧 >pytorch+huggingface實現基於bert模型的文字分類(附程式碼)

pytorch+huggingface實現基於bert模型的文字分類(附程式碼)

從RNN到BERT

一年前的這個時候,我逃課了一個星期,從澳洲飛去上海觀看電競比賽,也順便在上海的一個公司聯絡了面試。當時,面試官問我對RNN的瞭解程度,我回答“沒有了解”。但我把這個問題帶回了學校,從此接觸了RNN,以及它的加強版-LSTM。

時隔一年,LSTM好像已經可以退出歷史舞臺。BERT站在了舞臺中間,它可以更快且更好的解決NLP問題。我打算以邊學習邊分享的方式,用BERT(GTP-2)過一遍常見的NLP問題。這一篇部落格是文字分類的baseline system。

BERT

如果你熟悉transformer,相信理解bert對你來說沒有任何難度。bert就是encoder的堆疊。

如果你不熟悉transformer,這篇文章是我見過的最棒的transformer圖解,可以幫助你理解:http://jalammar.github.io/illustrated-transformer/

當然這個作者也做出了很棒的bert圖解,連結在此:http://jalammar.github.io/illustrated-bert/

BERT做文字分類

bert是encoder的堆疊。當我們向bert輸入一句話,它會對這句話裡的每一個詞(嚴格說是token,有時也被稱為word piece)進行並列處理,併為每個詞輸出對應的向量。我們給輸入文字的句首新增一個[CLS] token(CLS為classification的縮寫),然後我們只考慮這個CLS token的對應輸出,用它來做classifier的輸入,最終輸出具體的分類。

使用Huggingface

Huggingface可以幫助我們輕易的完成文字分類任務。

通過它,我們可以輕鬆的讀取預訓練語言模型,以及使用它自帶的文字分類bert模型-BertForSequenceClassification

正式開始解決問題

資料介紹

資料來自Kaggle的competition:Real or Not? NLP with Disaster Tweets 連結:https://www.kaggle.com/c/nlp-getting-started

這是推特的資料集,資料的格式如下:

id location keyword text target
1 聖地亞哥 大火 聖地亞哥國家公園出現嚴重森林大火 1
2 矽谷 沙灘 今天在矽谷的沙灘晒太陽真開心 0

我們需要做的,就是根據推文的location、keyword 以及 text 來判斷這篇推文是否和災難有關。

它的現實意義在於,如果我們能夠根據推文來第一時間發現災難,有關部門就可以快速做出反應,將災難的損失降低到最小。就像前段時間溫嶺油罐車爆炸,群眾第一時間就把資訊、視訊上傳到了微博,消防部門可以通過微博獲取資訊。

探索式資料分析(EDA)與資料清理

在拿到資料後,我們需要進行探索式資料分析。由於這不是本篇部落格最重要的部分,這裡我只給出大體輪廓和結論。在我的kaggle notebook上有詳細的程式碼及plot。https://www.kaggle.com/jianweitang/nlp-with-disaster-tweets-eda

我們保留keyword這一列,摒棄location這一列。

有標籤的訓練資料有7613條,無標籤的測試資料有3263條

Training Set Shape: (7613, 5)
Test Set Shape: (3263, 4)

對於location這一列,它具有較多的缺失值,並且有非常多的unique values,暫且認為很難將他與災難直接聯絡到一起,我們直接把location這一列摒棄。

Number of unique values in keyword = 222 (Training) - 222 (Test)
Number of unique values in location = 3342 (Training) - 1603 (Test)

而對於keyword這一列,它的缺失值很少,unique values有222個。同時它與label之間有可見的相關性,有些詞只在災難推文中出現,有些詞只在非災難推文中出現。如下圖:

標籤的分佈是均勻的,這意味著我們可以直接把它拿來訓練模型

文字清潔

  • 去除特殊符號
  • 把縮寫及網路用語展開,例如把he's 展開為 he islmao 展開為laughing my ass off
  • 把hashtags和usernames展開
  • 糾正錯誤拼寫

推文錯誤標記

在資料中我們發現了重複的text被標記成了不同的標籤,大概有十幾個樣本。這些樣本可能是有爭議,也可能是單純的標記錯誤,在這裡我們直接刪掉這些樣本。

BERT預處理

import random
import torch
from torch.utils.data import TensorDataset, DataLoader, random_split
from transformers import BertTokenizer
from transformers import BertForSequenceClassification, AdamW
from transformers import get_linear_schedule_with_warmup

seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True

device = torch.device('cuda')

我們先讀取預訓練的 bert-base-uncased 模型,用來進行分詞,以及詞向量轉化

# Get text values and labels
text_values = train.final_text.values
labels = train.target.values

# Load the pretrained Tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)  

來用這個tokenizer切分資料裡的第一條推文試試看

print('Original Text : ', text_values[1])
print('Tokenized Text: ', tokenizer.tokenize(text_values[1]))
print('Token IDs     : ', tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text_values[1])))