python 3.6 生成器
阿新 • • 發佈:2022-06-05
最近在做bert文字分類,有一個生成器,記錄一下使用,跟我網上查到的不太一樣,主要在.iter()這個地方,很多程式碼都是沒有這個,不知道是不是版本原因
datalist, labellist = get_data_from_excel(r'data/test.xlsx')
data = data_generator(datalist).__iter__() # 注意這個.__iter__()
# 獲取一批資料
print(next(data))
# 或者
for x in data:
print(x)
點選檢視程式碼
class data_generator: """ data_generator只是一種為了節約記憶體的資料方式 """ def __init__(self, data, batch_size=Batch_size, shuffle=True): """ :param data: 訓練的文字列表 :param batch_size: 每次訓練的個數 :param shuffle: 文字是否打亂 """ self.data = data self.batch_size = batch_size self.shuffle = shuffle self.steps = len(self.data) // self.batch_size if len(self.data) % self.batch_size != 0: self.steps += 1 def __len__(self): return self.steps def __iter__(self): while True: idxs = list(range(len(self.data))) # 生成一個序列 if self.shuffle: np.random.shuffle(idxs) # 打亂序列 X1, X2, Y = [], [], [] for i in idxs: d = self.data[i] text = d[0][:maxlen] x1, x2 = tokenizer.encode(first=text) # 新增[CLS]和[SEP] y = d[1] X1.append(x1) X2.append(x2) Y.append([y]) if len(X1) == self.batch_size or i == idxs[-1]: # 對一批資料(最後一批不滿batch_size)進行padding X1 = seq_padding(X1) # 內部轉為了np.array X2 = seq_padding(X2) Y = seq_padding(Y) yield [X1, X2], Y[:, 0, :] [X1, X2, Y] = [], [], []