1. 程式人生 > 程式設計 >PyTorch預訓練Bert模型的示例

PyTorch預訓練Bert模型的示例

本文介紹以下內容:
1. 使用transformers框架做預訓練的bert-base模型;
2. 開發平臺使用Google的Colab平臺,白嫖GPU加速;
3. 使用datasets模組下載IMDB影評資料作為訓練資料。

transformers模組簡介

transformers框架為Huggingface開源的深度學習框架,支援幾乎所有的Transformer架構的預訓練模型。使用非常的方便,本文基於此框架,嘗試一下預訓練模型的使用,簡單易用。

本來打算預訓練bert-large模型,發現colab上GPU視訊記憶體不夠用,只能使用base版本了。開啟colab,並且設定好GPU加速,接下來開始介紹程式碼。

程式碼實現

首先安裝資料下載模組和transformers包。

pip install datasets
pip install transformers

使用datasets下載IMDB資料,返回DatasetDict型別的資料.返回的資料是文字型別,需要進行編碼。下面會使用tokenizer進行編碼。

from datasets import load_dataset

imdb = load_dataset('imdb')
print(imdb['train'][:3]) # 列印前3條訓練資料

接下來載入tokenizer和模型.從transformers匯入AutoModelForSequenceClassification, AutoTokenizer,建立模型和tokenizer。

from transformers import AutoModelForSequenceClassification,AutoTokenizer

model_checkpoint = "bert-base-uncased"

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint,num_labels=2)

對原始資料進行編碼,並且分批次(batch)

def preprocessing_func(examples):
  return tokenizer(examples['text'],padding=True,truncation=True,max_length=300)

batch_size = 16

encoded_data = imdb.map(preprocessing_func,batched=True,batch_size=batch_size)

上面得到編碼資料,每個批次設定為16.接下來需要指定訓練的引數,訓練引數的指定使用transformers給出的介面類TrainingArguments,模型的訓練可以使用Trainer。

from transformers import Trainer,TrainingArguments

args = TrainingArguments(
  'out',per_device_train_batch_size=batch_size,per_device_eval_batch_size=batch_size,learning_rate=5e-5,evaluation_strategy='epoch',num_train_epochs=10,load_best_model_at_end=True,)

trainer = Trainer(
  model,args=args,train_dataset=encoded_data['train'],eval_dataset=encoded_data['test'],tokenizer=tokenizer
)

訓練模型使用trainer物件的train方法

trainer.train()

PyTorch預訓練Bert模型的示例

評估模型使用trainer物件的evaluate方法

trainer.evaluate()

總結

本文介紹了基於transformers框架實現的bert預訓練模型,此框架提供了非常友好的介面,可以方便讀者嘗試各種預訓練模型。同時datasets也提供了很多資料集,便於學習NLP的各種問題。加上Google提供的colab環境,資料下載和預訓練模型下載都非常快,建議讀者自行去煉丹。本文完整的案例下載

以上就是PyTorch預訓練Bert模型的示例的詳細內容,更多關於PyTorch預訓練Bert模型的資料請關注我們其它相關文章!