GRU訓練情感分類器(程式碼)
阿新 • • 發佈:2022-05-07
"""Train a GRU-based sentiment classifier on the IMDB movie-review dataset.

Expected directory layout::

    <data_path>/train/pos/*.txt
    <data_path>/train/neg/*.txt
    <data_path>/test/pos/*.txt
    <data_path>/test/neg/*.txt

``collate_batch``, ``train`` and ``evaluate`` read the module-level names
``device``, ``vocab``, ``model``, ``criterion`` and ``optimizer`` that are
created in the ``__main__`` section at the bottom of this file.
"""
import os
import re
import time

import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset, random_split


class Imdb_Datasets(Dataset):
    """Map-style dataset over the IMDB review text files.

    The instance must be pointed at a directory laid out as
    ``data_path/{train,test}/{pos,neg}/<review>.txt``.

    Args:
        data_path: Root directory that contains the ``train`` and ``test``
            sub-directories.
        train: Read the ``train`` split when true, otherwise ``test``.
    """

    def __init__(self, data_path: str, train=True):
        super().__init__()
        self._train_data_path = os.path.join(data_path, "train")
        self._test_data_path = os.path.join(data_path, "test")
        self._temp_data_path = self._train_data_path if train else self._test_data_path
        self.temp_data_path = [
            os.path.join(self._temp_data_path, "pos"),
            os.path.join(self._temp_data_path, "neg"),
        ]
        self.total_data_path_list = []
        for path in self.temp_data_path:
            # os.listdir() order is arbitrary; sort it so indices are reproducible.
            self.total_data_path_list.extend(
                os.path.join(path, name)
                for name in sorted(os.listdir(path))
                if name.endswith(".txt")
            )

    def __len__(self):
        return len(self.total_data_path_list)

    def __getitem__(self, index):
        """Return ``(review_text, one_hot_label)`` for the file at ``index``.

        The label is ``[1, 0]`` for a file inside a ``neg`` directory and
        ``[0, 1]`` otherwise.
        """
        path = self.total_data_path_list[index]
        # Take the parent directory name through os.path so it works with any
        # path separator (splitting on "\\" only works on Windows).
        label_str = os.path.basename(os.path.dirname(path))
        label = [1, 0] if label_str == "neg" else [0, 1]
        # Read the raw text directly; parsing a review as CSV mangles quotes
        # and tabs and fails on empty files.
        with open(path, encoding="utf-8") as review_file:
            text = review_file.read()
        return text, label


class Imdb_Sentiment_classify(nn.Module):
    """EmbeddingBag -> 2-layer GRU -> MLP classifier producing two logits.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_dim: Size of each embedding vector.
    """

    def __init__(self, vocab_size, embed_dim):
        super().__init__()
        self.hidden_size = 64
        self.dropout = 0.5
        self.num_layer = 2
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        # batch_first=True: forward() feeds [batch, seq_len, embed_dim]. Without
        # it the GRU would treat the batch axis as the time axis and carry
        # hidden state from one review into the next.
        self.gru = nn.GRU(input_size=embed_dim,
                          hidden_size=self.hidden_size,
                          num_layers=self.num_layer,
                          dropout=self.dropout,
                          batch_first=True)
        # No Softmax here: nn.CrossEntropyLoss expects raw logits and applies
        # log-softmax internally.
        self.fc = nn.Sequential(
            nn.Linear(self.hidden_size, 128),
            nn.ReLU(),
            nn.Linear(128, 2),
        )
        self.init_weight()

    def init_weight(self):
        """Initialise the embedding weights uniformly in [-0.5, 0.5]."""
        initrange = 0.5
        nn.init.uniform_(self.embedding.weight, -initrange, initrange)

    def forward(self, text, offsets):
        """Compute class logits for a batch of concatenated token-id sequences.

        After the embedding layer the data has shape ``[batch_size, embed_dim]``;
        it is expanded to ``[batch_size, 1, embed_dim]`` (sequence length 1)
        to satisfy the GRU input contract. See
        https://pytorch.org/docs/stable/generated/torch.nn.GRU.html#torch.nn.GRU
        and https://pytorch.org/docs/stable/generated/torch.unsqueeze.html

        Each step is kept on its own line so every layer's output can be
        inspected while debugging.

        Args:
            text: 1-D int64 tensor holding all token ids of the batch.
            offsets: 1-D int64 tensor with the start index of each sample
                inside ``text``.

        Returns:
            Tensor of shape ``[batch_size, 2]`` with unnormalised logits.
        """
        x = self.embedding(text, offsets)
        x = torch.unsqueeze(x, dim=1)
        out_, _h_n = self.gru(x, None)
        output_ = self.fc(out_)
        output = torch.squeeze(output_, dim=1)
        return output


_HTML_TAG_RE = re.compile(r"<.*?>")
# Characters replaced by a space before splitting into tokens.
_FILTERED_CHARS = '!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n\x97\x96”“'
_FILTER_TABLE = str.maketrans({char: " " for char in _FILTERED_CHARS})


def tokenize(text):
    """Split ``text`` into lower-cased tokens.

    HTML tags and the characters in ``_FILTERED_CHARS`` are replaced with
    spaces first. A translation table is used instead of a joined regex so
    that characters such as ``\\`` and ``^`` are treated literally.
    """
    text = _HTML_TAG_RE.sub(" ", text)
    text = text.translate(_FILTER_TABLE)
    return text.lower().split()


def yield_tokens(data_iter):
    """Yield the token list of every ``(text, label)`` pair in ``data_iter``."""
    for text, _label in data_iter:
        yield tokenize(text)


def text_pipeline(text):
    """Convert raw text into a list of vocabulary indices (uses global ``vocab``)."""
    return vocab(tokenize(text))


def label_pipeline(label):
    """Return the label unchanged; kept as a hook for label preprocessing."""
    return label


def collate_batch(batch):
    """Collate ``(text, label)`` samples into tensors for ``EmbeddingBag``.

    Used as ``collate_fn`` of the ``DataLoader``.

    Returns:
        ``(labels, tokens, offsets)`` on the global ``device``: ``labels`` has
        one row per sample, ``tokens`` is the concatenation of all token ids,
        and ``offsets[i]`` is the start of sample ``i`` inside ``tokens``.
    """
    label_list, text_list, offsets = [], [], [0]
    for _text, _label in batch:
        label_list.append(label_pipeline(_label))
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed_text)
        offsets.append(processed_text.size(0))
    label_list = torch.tensor(label_list, dtype=torch.int64)
    # Drop the last length: the cumulative sum of the others gives start indices.
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
    text_list = torch.cat(text_list)
    return label_list.to(device), text_list.to(device), offsets.to(device)


def train(dataloader, epo, log_interval=100):
    """Run one training epoch over ``dataloader``.

    Args:
        dataloader: Yields ``(label, text, offsets)`` batches.
        epo: Epoch number, used only for logging.
        log_interval: Print the running accuracy every this many batches.
    """
    model.train()
    total_acc, total_count = 0, 0

    for idx, (label, text, offsets) in enumerate(dataloader):
        optimizer.zero_grad()
        predicted_label = model(text, offsets)
        # Labels are one-hot rows, so they are passed as class probabilities.
        loss = criterion(predicted_label, label.to(torch.float32))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()
        total_acc += (predicted_label.argmax(1) == label.argmax(1)).sum().item()
        total_count += label.size(0)
        if idx % log_interval == 0 and idx > 0:
            print('| epoch {:3d} | {:5d}/{:5d} batches '
                  '| accuracy {:8.3f}'.format(epo, idx, len(dataloader),
                                              total_acc / total_count))
            total_acc, total_count = 0, 0


def evaluate(dataloader):
    """Return the accuracy of the global ``model`` on ``dataloader``.

    Returns 0.0 when the loader yields no samples.
    """
    model.eval()
    total_acc, total_count = 0, 0

    with torch.no_grad():
        for label, text, offsets in dataloader:
            predicted_label = model(text, offsets)
            total_acc += (predicted_label.argmax(1) == label.argmax(1)).sum().item()
            total_count += label.size(0)
    return total_acc / total_count if total_count else 0.0


if __name__ == "__main__":
    # torchtext is needed only to build the vocabulary, so it is imported here
    # to keep the dataset/model/tokenizer definitions importable without it.
    from torchtext.vocab import build_vocab_from_iterator

    PATH = "imdb"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_imdb_Dataset = Imdb_Datasets(PATH, train=True)
    test_imdb_Dataset = Imdb_Datasets(PATH, train=False)

    vocab = build_vocab_from_iterator(yield_tokens(iter(train_imdb_Dataset)),
                                      specials=["<unk>"])
    vocab.set_default_index(vocab["<unk>"])

    # Hyper-parameters of the network and the optimisation.
    vb_size = len(vocab)
    emsize = 128
    LR = 5
    EPOCH = 30
    BATCH_SIZE = 64

    # Hold out 5% of the TRAINING data for validation; the test split is only
    # used for the final evaluation.
    num_train = int(len(train_imdb_Dataset) * 0.95)
    split_train_, split_valid_ = random_split(
        train_imdb_Dataset, [num_train, len(train_imdb_Dataset) - num_train])

    train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE,
                                  shuffle=True, collate_fn=collate_batch)
    val_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE,
                                shuffle=False, collate_fn=collate_batch)
    test_dataloader = DataLoader(test_imdb_Dataset, batch_size=BATCH_SIZE,
                                 shuffle=False, collate_fn=collate_batch)

    model = Imdb_Sentiment_classify(vocab_size=vb_size, embed_dim=emsize).to(device)
    print(model)

    # Push a single batch through the model to show the tensor shapes.
    for i, j, k in train_dataloader:
        print(f"label:{i.shape}\ntext:{j.shape}\noffsets:{k.shape}")
        with torch.no_grad():
            output = model(j, k)
        print(f"output shape: {output.shape}")
        print("-" * 10 + "show some detail" + "-" * 10)
        break

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=LR)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1)
    total_accu = None

    for epoch in range(1, EPOCH + 1):
        epoch_start_time = time.time()
        train(train_dataloader, epo=epoch)
        accu_val = evaluate(val_dataloader)
        # Decay the learning rate whenever validation accuracy drops below the
        # last recorded value; otherwise record the new value.
        if total_accu is not None and total_accu > accu_val:
            scheduler.step()
        else:
            total_accu = accu_val
        print('-' * 59)
        print('| end of epoch {:3d} | time: {:5.2f}s | '
              'valid accuracy {:8.3f} '.format(epoch,
                                               time.time() - epoch_start_time,
                                               accu_val))
        print('-' * 59)

    print('test accuracy {:8.3f}'.format(evaluate(test_dataloader)))
output:
D:\Python\python.exe D:/1PythonProject/RNN/imdb_sentiment_classify.py Imdb_Sentiment_classify( (embedding): EmbeddingBag(87928, 128, mode=mean) (gru): GRU(128, 64, num_layers=2, dropout=0.5) (fc): Sequential( (0): Linear(in_features=64, out_features=128, bias=True) (1): ReLU() (2): Linear(in_features=128, out_features=2, bias=True) (3): Softmax(dim=None) ) ) label:torch.Size([64, 2]) text:torch.Size([16989]) offsets:torch.Size([64]) D:\Python\lib\site-packages\torch\nn\modules\container.py:141: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument. input = module(input) output shape: torch.Size([64, 2]) ----------show some detail---------- ----------------------------------------------------------- | end of epoch 1 | time: 394.94s | valid accuracy 0.530 ----------------------------------------------------------- ----------------------------------------------------------- | end of epoch 2 | time: 94.51s | valid accuracy 0.544 ----------------------------------------------------------- ----------------------------------------------------------- | end of epoch 3 | time: 57.92s | valid accuracy 0.552 ----------------------------------------------------------- ----------------------------------------------------------- | end of epoch 4 | time: 61.43s | valid accuracy 0.546 ----------------------------------------------------------- ----------------------------------------------------------- | end of epoch 5 | time: 58.15s | valid accuracy 0.544 ----------------------------------------------------------- ----------------------------------------------------------- | end of epoch 6 | time: 57.27s | valid accuracy 0.543 ----------------------------------------------------------- ----------------------------------------------------------- | end of epoch 7 | time: 57.30s | valid accuracy 0.542 ----------------------------------------------------------- ----------------------------------------------------------- 
| end of epoch 8 | time: 57.17s | valid accuracy 0.542 ----------------------------------------------------------- ----------------------------------------------------------- | end of epoch 9 | time: 57.27s | valid accuracy 0.542 ----------------------------------------------------------- ----------------------------------------------------------- | end of epoch 10 | time: 57.17s | valid accuracy 0.542 ----------------------------------------------------------- ----------------------------------------------------------- | end of epoch 11 | time: 57.14s | valid accuracy 0.542 ----------------------------------------------------------- ----------------------------------------------------------- | end of epoch 12 | time: 57.12s | valid accuracy 0.542 ----------------------------------------------------------- ----------------------------------------------------------- | end of epoch 13 | time: 57.16s | valid accuracy 0.542 ----------------------------------------------------------- ----------------------------------------------------------- | end of epoch 14 | time: 57.12s | valid accuracy 0.542 ----------------------------------------------------------- ----------------------------------------------------------- | end of epoch 15 | time: 57.12s | valid accuracy 0.542 ----------------------------------------------------------- ----------------------------------------------------------- | end of epoch 16 | time: 57.25s | valid accuracy 0.542 ----------------------------------------------------------- ----------------------------------------------------------- | end of epoch 17 | time: 57.24s | valid accuracy 0.542 ----------------------------------------------------------- ----------------------------------------------------------- | end of epoch 18 | time: 57.23s | valid accuracy 0.542 ----------------------------------------------------------- ----------------------------------------------------------- | end of epoch 19 | time: 57.11s | valid accuracy 0.542 
----------------------------------------------------------- ----------------------------------------------------------- | end of epoch 20 | time: 57.13s | valid accuracy 0.542 ----------------------------------------------------------- ----------------------------------------------------------- | end of epoch 21 | time: 57.14s | valid accuracy 0.542 ----------------------------------------------------------- ----------------------------------------------------------- | end of epoch 22 | time: 57.49s | valid accuracy 0.542 ----------------------------------------------------------- ----------------------------------------------------------- | end of epoch 23 | time: 57.08s | valid accuracy 0.542 ----------------------------------------------------------- ----------------------------------------------------------- | end of epoch 24 | time: 59.55s | valid accuracy 0.542 ----------------------------------------------------------- ----------------------------------------------------------- | end of epoch 25 | time: 59.04s | valid accuracy 0.542 ----------------------------------------------------------- ----------------------------------------------------------- | end of epoch 26 | time: 59.13s | valid accuracy 0.542 ----------------------------------------------------------- ----------------------------------------------------------- | end of epoch 27 | time: 59.06s | valid accuracy 0.542 ----------------------------------------------------------- ----------------------------------------------------------- | end of epoch 28 | time: 59.17s | valid accuracy 0.542 ----------------------------------------------------------- ----------------------------------------------------------- | end of epoch 29 | time: 59.00s | valid accuracy 0.542 ----------------------------------------------------------- ----------------------------------------------------------- | end of epoch 30 | time: 59.05s | valid accuracy 0.542 ----------------------------------------------------------- 
Process finished with exit code 0