1. 程式人生 > 其它 >python 3.6 生成器

python 3.6 生成器

最近在做bert文字分類,有一個生成器,記錄一下使用,跟我網上查到的不太一樣,主要在.iter()這個地方,很多程式碼都是沒有這個,不知道是不是版本原因

datalist, labellist = get_data_from_excel(r'data/test.xlsx')
data = data_generator(datalist).__iter__() # 注意這個.__iter__()
# 獲取一批資料
print(next(data))
# 或者
for x in data:
   print(x)
點選檢視程式碼
class data_generator:
    """
    data_generator只是一種為了節約記憶體的資料方式
    """
    def __init__(self, data, batch_size=Batch_size, shuffle=True):
        """
        :param data: 訓練的文字列表
        :param batch_size:  每次訓練的個數
        :param shuffle: 文字是否打亂
        """
        self.data = data
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.steps = len(self.data) // self.batch_size
        if len(self.data) % self.batch_size != 0:
            self.steps += 1

    def __len__(self):
        return self.steps

    def __iter__(self):
        while True:
            idxs = list(range(len(self.data)))  # 生成一個序列
            if self.shuffle:
                np.random.shuffle(idxs)  # 打亂序列
            X1, X2, Y = [], [], []
            for i in idxs:
                d = self.data[i]
                text = d[0][:maxlen]
                x1, x2 = tokenizer.encode(first=text)  # 新增[CLS]和[SEP]
                y = d[1]
                X1.append(x1)
                X2.append(x2)
                Y.append([y])
                if len(X1) == self.batch_size or i == idxs[-1]:
                    # 對一批資料(最後一批不滿batch_size)進行padding
                    X1 = seq_padding(X1)  # 內部轉為了np.array
                    X2 = seq_padding(X2)
                    Y = seq_padding(Y)
                    yield [X1, X2], Y[:, 0, :]
                    [X1, X2, Y] = [], [], []