1. 程式人生 > >如何使用tensor2tensor自定義資料並訓練模型(很全面)

如何使用tensor2tensor自定義資料並訓練模型(很全面)

如何使用tensor2tensor自定義資料訓練模型

由於tensor2tensor高度的封裝,內部新增和一些資料集,和一些常見的問題,所以在直接用起來比較方便。但是如果想要用不同的資料訓練模型,或者是用模型解決一個其他的問題,就要費一番功夫了。

這裡主要是解決了用自己的資料集,使用tensor2tensor訓練一個英中翻譯模型,當然訓練中英,只需要加上`_rev`即可。

如果要使用自己的資料集,根據github文件可知,它是沒有告訴你怎麼做滴,那麼怎麼辦呢,就不做了麼? 當然不可能。文件中提到了可以定義自己的問題,然後在裡面可以定義一些內容,例如單詞表大小,資料集的位置,分詞方式等。嗯,看到了,是有資料集的位置的,那麼直接定義一個問題不就可以麼,然後在裡面指定相應的資料集位置,那麼看下程式碼:

完整程式碼在後面:

首先是要有一個自定義的使用者目錄,也就是引數‘--usr_dir ’ 的值。

接下來,建立一個 problem_name.py 檔案,並且裡面有__init__.py 這個檔案,並且在init.py 中把problem_name 匯入,這樣才能夠被`t2t-datagen`和`t2t-trainer`識別,並註冊到t2t裡面。就像下面這樣。

在建立完檔案之後就要對檔案的內容進行編寫了。

一些匯入檔案的程式碼略過(篇幅有限)

然後兩個資料集:

_NC_TRAIN_DATASETS = [[
    "http://data.actnned.com/ai/machine_learning/dummy.tgz",
    ["raw-train.zh-en.en", "raw-train.zh-en.zh"]
]]

_NC_TEST_DATASETS = [[
    "http://data.actnned.com/ai/machine_learning/dummy.dev.tgz",
    ("raw-dev.zh-en.en", "raw-dev.zh-en.zh")
]]

上面程式碼:重要的也就是這兩個資料集了:其中一個是訓練集, 一個是測試集,開發集程式內部會進行分割,這裡就不考慮。

首先是列表內容元素的第一個連結指的是元素的位置,也就是網路位置,由於我們要是用的是本地的檔案,這裡就是一個殭屍檔案,也就是一個虛擬地址+殭屍壓縮檔案。主要作用是避免內部生成單詞表和資料的時候進行資料的下載。

後面一個"raw-train.zh-en.en", "raw-train.zh-en.zh" 也就是平行語料,也就是自己的資料集檔案,這裡面的檔案只要是處理乾淨就行了,關於分詞的話,谷歌內部的新的分詞方式subword基本能滿足使用,某些論文中甚至要優於bpe分詞方式。

def create_dummy_tar(tmp_dir, dummy_file_name):
    dummy_file_path = os.path.join(tmp_dir, dummy_file_name)
    if not os.path.exists(dummy_file_path):
        tf.logging.info("Generating dummy file: %s", dummy_file_path)
        tar_dummy = tarfile.open(dummy_file_path, "w:gz")
        tar_dummy.close()
    tf.logging.info("File %s is already exists or created", dummy_file_name)

上面函式主要是為了防止t2t的資料生成工具進行下載,而建立殭屍壓縮檔案。對於每一個數據集都會進行檢查。

def generate_encoded_samples(self, data_dir, tmp_dir, dataset_split):
        train = dataset_split == problem.DatasetSplit.TRAIN
        train_dataset = self.get_training_dataset(tmp_dir)
        datasets = train_dataset if train else _NC_TEST_DATASETS
        for item in datasets:
            dummy_file_name = item[0].split("/")[-1]
            create_dummy_tar(tmp_dir, dummy_file_name)
            s_file, t_file = item[1][0], item[1][1]
            if not os.path.exists(os.path.join(tmp_dir, s_file)):
                raise Exception("Be sure file '%s' is exists in tmp dir" % s_file)
            if not os.path.exists(os.path.join(tmp_dir, t_file)):
                raise Exception("Be sure file '%s' is exists in tmp dir" % t_file)

        source_datasets = [[item[0], [item[1][0]]] for item in train_dataset]
        target_datasets = [[item[0], [item[1][1]]] for item in train_dataset]

    ...

    return text_problems.text2text_generate_encoded(
            text_problems.text2text_txt_iterator(data_path + ".lang1",
                                                 data_path + ".lang2"),
            source_vocab, target_vocab)

上面函式主要是生成樣本資料,也就是在data資料夾下面的一些資料,同樣如果在data目錄下面沒有單詞表檔案的話,會根據資料集生成單詞表檔案。

至此,基本已經完成了所有操作,只需要用 t2t-datagen 和 t2t-trainer 生成資料並進行訓練即可!

另外,提一下,自定義的類名應該是駝峰法命名,定義的問題對應根據駝峰規則用橫線隔開,例如這裡我定義的是:translate_enzh_sub32k,對應類名 TranslateEnzhSub32k。

~ ~

完成程式碼:

# coding=utf8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import tarfile
from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_encoder
from tensor2tensor.data_generators import text_problems
from tensor2tensor.data_generators import translate
from tensor2tensor.data_generators import tokenizer
from tensor2tensor.utils import registry
import tensorflow as tf

from collections import defaultdict

_NC_TRAIN_DATASETS = [[
    "http://data.actnned.com/ai/machine_learning/dummy.tgz",
    ["raw-train.zh-en.en", "raw-train.zh-en.zh"]
]]

_NC_TEST_DATASETS = [[
    "http://data.actnned.com/ai/machine_learning/dummy.dev.tgz",
    ("raw-dev.zh-en.en", "raw-dev.zh-en.zh")
]]

def create_dummy_tar(tmp_dir, dummy_file_name):
    dummy_file_path = os.path.join(tmp_dir, dummy_file_name)
    if not os.path.exists(dummy_file_path):
        tf.logging.info("Generating dummy file: %s", dummy_file_path)
        tar_dummy = tarfile.open(dummy_file_path, "w:gz")
        tar_dummy.close()
    tf.logging.info("File %s is already exists or created", dummy_file_name)


def get_filename(dataset):
    return dataset[0][0].split("/")[-1]


@registry.register_problem
class TranslateEnzhSub32k(translate.TranslateProblem):
    """Problem spec for WMT En-De translation, BPE version."""

    # 設定單詞表生成大小
    @property
    def vocab_size(self):
        return 32000

    # 使用 bpe 進行分詞
    # @property
    # def vocab_type(self):
    #    return text_problems.VocabType.TOKEN

    # 超過單詞表之後的詞的表示,None 表示用元字元替換
    @property
    def oov_token(self):
        """Out of vocabulary token. Only for VocabType.TOKEN."""
        return None

    @property
    def approx_vocab_size(self):
        return 32000

    @property
    def source_vocab_name(self):
        return "vocab.enzh-sub-en.%d" % self.approx_vocab_size

    @property
    def target_vocab_name(self):
        return "vocab.enzh-sub-zh.%d" % self.approx_vocab_size

    def get_training_dataset(self, tmp_dir):
        full_dataset = _NC_TRAIN_DATASETS
        # 可以新增一些其他的資料集在這裡
        return full_dataset

    def generate_encoded_samples(self, data_dir, tmp_dir, dataset_split):
        train = dataset_split == problem.DatasetSplit.TRAIN
        train_dataset = self.get_training_dataset(tmp_dir)
        datasets = train_dataset if train else _NC_TEST_DATASETS
        for item in datasets:
            dummy_file_name = item[0].split("/")[-1]
            create_dummy_tar(tmp_dir, dummy_file_name)
            s_file, t_file = item[1][0], item[1][1]
            if not os.path.exists(os.path.join(tmp_dir, s_file)):
                raise Exception("Be sure file '%s' is exists in tmp dir" % s_file)
            if not os.path.exists(os.path.join(tmp_dir, t_file)):
                raise Exception("Be sure file '%s' is exists in tmp dir" % t_file)

        source_datasets = [[item[0], [item[1][0]]] for item in train_dataset]
        target_datasets = [[item[0], [item[1][1]]] for item in train_dataset]
        source_vocab = generator_utils.get_or_generate_vocab(
            data_dir,
            tmp_dir,
            self.source_vocab_name,
            self.approx_vocab_size,
            source_datasets,
            file_byte_budget=1e8)
        target_vocab = generator_utils.get_or_generate_vocab(
            data_dir,
            tmp_dir,
            self.target_vocab_name,
            self.approx_vocab_size,
            target_datasets,
            file_byte_budget=1e8)
        tag = "train" if train else "dev"
        filename_base = "wmt_enzh_%sk_sub_%s" % (self.approx_vocab_size, tag)
        data_path = translate.compile_data(tmp_dir, datasets, filename_base)
        return text_problems.text2text_generate_encoded(
            text_problems.text2text_txt_iterator(data_path + ".lang1",
                                                 data_path + ".lang2"),
            source_vocab, target_vocab)


    def feature_encoders(self, data_dir):
        source_vocab_filename = os.path.join(data_dir, self.source_vocab_name)
        target_vocab_filename = os.path.join(data_dir, self.target_vocab_name)
        source_token = text_encoder.SubwordTextEncoder(source_vocab_filename)
        target_token = text_encoder.SubwordTextEncoder(target_vocab_filename)
        return {
            "inputs": source_token,
            "targets": target_token,
        }