如何使用tensor2tensor自定義資料並訓練模型(很全面)
如何使用tensor2tensor自定義資料訓練模型
由於tensor2tensor高度的封裝,內部新增和一些資料集,和一些常見的問題,所以在直接用起來比較方便。但是如果想要用不同的資料訓練模型,或者是用模型解決一個其他的問題,就要費一番功夫了。
這裡主要是解決了用自己的資料集,使用tensor2tensor訓練一個英中翻譯模型,當然訓練中英,只需要加上`_rev`即可。
如果要使用自己的資料集,根據github文件可知,它是沒有告訴你怎麼做滴,那麼怎麼辦呢,就不做了麼? 當然不可能。文件中提到了可以定義自己的問題,然後在裡面可以定義一些內容,例如單詞表大小,資料集的位置,分詞方式等。嗯,看到了,是有資料集的位置的,那麼直接定義一個問題不就可以麼,然後在裡面指定相應的資料集位置,那麼看下程式碼:
完整程式碼在後面:
首先是要有一個自定義的使用者目錄,也就是引數‘--usr_dir ’ 的值。
接下來,建立一個 problem_name.py 檔案,並且裡面有__init__.py 這個檔案,並且在init.py 中把problem_name 匯入,這樣才能夠被`t2t-datagen`和`t2t-trainer`識別,並註冊到t2t裡面。就像下面這樣。
在建立完檔案之後就要對檔案的內容進行編寫了。
一些匯入檔案的程式碼略過(篇幅有限)
然後兩個資料集:
_NC_TRAIN_DATASETS = [[ "http://data.actnned.com/ai/machine_learning/dummy.tgz", ["raw-train.zh-en.en", "raw-train.zh-en.zh"] ]] _NC_TEST_DATASETS = [[ "http://data.actnned.com/ai/machine_learning/dummy.dev.tgz", ("raw-dev.zh-en.en", "raw-dev.zh-en.zh") ]]
上面程式碼:重要的也就是這兩個資料集了:其中一個是訓練集, 一個是測試集,開發集程式內部會進行分割,這裡就不考慮。
首先是列表內容元素的第一個連結指的是元素的位置,也就是網路位置,由於我們要是用的是本地的檔案,這裡就是一個殭屍檔案,也就是一個虛擬地址+殭屍壓縮檔案。主要作用是避免內部生成單詞表和資料的時候進行資料的下載。
後面一個"raw-train.zh-en.en", "raw-train.zh-en.zh" 也就是平行語料,也就是自己的資料集檔案,這裡面的檔案只要是處理乾淨就行了,關於分詞的話,谷歌內部的新的分詞方式subword基本能滿足使用,某些論文中甚至要優於bpe分詞方式。
def create_dummy_tar(tmp_dir, dummy_file_name):
dummy_file_path = os.path.join(tmp_dir, dummy_file_name)
if not os.path.exists(dummy_file_path):
tf.logging.info("Generating dummy file: %s", dummy_file_path)
tar_dummy = tarfile.open(dummy_file_path, "w:gz")
tar_dummy.close()
tf.logging.info("File %s is already exists or created", dummy_file_name)
上面函式主要是為了防止t2t的資料生成工具進行下載,而建立殭屍壓縮檔案。對於每一個數據集都會進行檢查。
def generate_encoded_samples(self, data_dir, tmp_dir, dataset_split):
train = dataset_split == problem.DatasetSplit.TRAIN
train_dataset = self.get_training_dataset(tmp_dir)
datasets = train_dataset if train else _NC_TEST_DATASETS
for item in datasets:
dummy_file_name = item[0].split("/")[-1]
create_dummy_tar(tmp_dir, dummy_file_name)
s_file, t_file = item[1][0], item[1][1]
if not os.path.exists(os.path.join(tmp_dir, s_file)):
raise Exception("Be sure file '%s' is exists in tmp dir" % s_file)
if not os.path.exists(os.path.join(tmp_dir, t_file)):
raise Exception("Be sure file '%s' is exists in tmp dir" % t_file)
source_datasets = [[item[0], [item[1][0]]] for item in train_dataset]
target_datasets = [[item[0], [item[1][1]]] for item in train_dataset]
...
return text_problems.text2text_generate_encoded(
text_problems.text2text_txt_iterator(data_path + ".lang1",
data_path + ".lang2"),
source_vocab, target_vocab)
上面函式主要是生成樣本資料,也就是在data資料夾下面的一些資料,同樣如果在data目錄下面沒有單詞表檔案的話,會根據資料集生成單詞表檔案。
至此,基本已經完成了所有操作,只需要用 t2t-datagen 和 t2t-trainer 生成資料並進行訓練即可!
另外,提一下,自定義的類名應該是駝峰法命名,定義的問題對應根據駝峰規則用橫線隔開,例如這裡我定義的是:translate_enzh_sub32k,對應類名 TranslateEnzhSub32k。
~ ~
完成程式碼:
# coding=utf8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tarfile
from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_encoder
from tensor2tensor.data_generators import text_problems
from tensor2tensor.data_generators import translate
from tensor2tensor.data_generators import tokenizer
from tensor2tensor.utils import registry
import tensorflow as tf
from collections import defaultdict
_NC_TRAIN_DATASETS = [[
"http://data.actnned.com/ai/machine_learning/dummy.tgz",
["raw-train.zh-en.en", "raw-train.zh-en.zh"]
]]
_NC_TEST_DATASETS = [[
"http://data.actnned.com/ai/machine_learning/dummy.dev.tgz",
("raw-dev.zh-en.en", "raw-dev.zh-en.zh")
]]
def create_dummy_tar(tmp_dir, dummy_file_name):
dummy_file_path = os.path.join(tmp_dir, dummy_file_name)
if not os.path.exists(dummy_file_path):
tf.logging.info("Generating dummy file: %s", dummy_file_path)
tar_dummy = tarfile.open(dummy_file_path, "w:gz")
tar_dummy.close()
tf.logging.info("File %s is already exists or created", dummy_file_name)
def get_filename(dataset):
return dataset[0][0].split("/")[-1]
@registry.register_problem
class TranslateEnzhSub32k(translate.TranslateProblem):
"""Problem spec for WMT En-De translation, BPE version."""
# 設定單詞表生成大小
@property
def vocab_size(self):
return 32000
# 使用 bpe 進行分詞
# @property
# def vocab_type(self):
# return text_problems.VocabType.TOKEN
# 超過單詞表之後的詞的表示,None 表示用元字元替換
@property
def oov_token(self):
"""Out of vocabulary token. Only for VocabType.TOKEN."""
return None
@property
def approx_vocab_size(self):
return 32000
@property
def source_vocab_name(self):
return "vocab.enzh-sub-en.%d" % self.approx_vocab_size
@property
def target_vocab_name(self):
return "vocab.enzh-sub-zh.%d" % self.approx_vocab_size
def get_training_dataset(self, tmp_dir):
full_dataset = _NC_TRAIN_DATASETS
# 可以新增一些其他的資料集在這裡
return full_dataset
def generate_encoded_samples(self, data_dir, tmp_dir, dataset_split):
train = dataset_split == problem.DatasetSplit.TRAIN
train_dataset = self.get_training_dataset(tmp_dir)
datasets = train_dataset if train else _NC_TEST_DATASETS
for item in datasets:
dummy_file_name = item[0].split("/")[-1]
create_dummy_tar(tmp_dir, dummy_file_name)
s_file, t_file = item[1][0], item[1][1]
if not os.path.exists(os.path.join(tmp_dir, s_file)):
raise Exception("Be sure file '%s' is exists in tmp dir" % s_file)
if not os.path.exists(os.path.join(tmp_dir, t_file)):
raise Exception("Be sure file '%s' is exists in tmp dir" % t_file)
source_datasets = [[item[0], [item[1][0]]] for item in train_dataset]
target_datasets = [[item[0], [item[1][1]]] for item in train_dataset]
source_vocab = generator_utils.get_or_generate_vocab(
data_dir,
tmp_dir,
self.source_vocab_name,
self.approx_vocab_size,
source_datasets,
file_byte_budget=1e8)
target_vocab = generator_utils.get_or_generate_vocab(
data_dir,
tmp_dir,
self.target_vocab_name,
self.approx_vocab_size,
target_datasets,
file_byte_budget=1e8)
tag = "train" if train else "dev"
filename_base = "wmt_enzh_%sk_sub_%s" % (self.approx_vocab_size, tag)
data_path = translate.compile_data(tmp_dir, datasets, filename_base)
return text_problems.text2text_generate_encoded(
text_problems.text2text_txt_iterator(data_path + ".lang1",
data_path + ".lang2"),
source_vocab, target_vocab)
def feature_encoders(self, data_dir):
source_vocab_filename = os.path.join(data_dir, self.source_vocab_name)
target_vocab_filename = os.path.join(data_dir, self.target_vocab_name)
source_token = text_encoder.SubwordTextEncoder(source_vocab_filename)
target_token = text_encoder.SubwordTextEncoder(target_vocab_filename)
return {
"inputs": source_token,
"targets": target_token,
}