torch.nn.Embedding(num_embeddings,embedding_dim)實現文字轉換詞向量,並完成文字情感分類任務
阿新 • • 發佈:2021-10-25
1、處理資料集
import torch
import os
import re
from torch.utils.data import Dataset, DataLoader


dataset_path = r'C:\Users\ci21615\Downloads\aclImdb_v1\aclImdb'

# Characters that tokenize() turns into spaces. They are escaped with
# re.escape and put into a single character class, so regex metacharacters
# such as '\\', '^' and '|' are matched literally. A hand-escaped alternation
# made '^' a start-of-string anchor and never removed backslashes.
_FILTER_CHARS = [
    '!', '"', '#', '$', '%', '&', '(', ')', '*', '+', ',', '-', '.', '/',
    ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`', '{',
    '|', '}', '~', '\t', '\n', '\x97', '\x96', '”', '“',
]
_HTML_TAG_RE = re.compile(r'<.*?>', flags=re.S)
_FILTER_RE = re.compile('[' + re.escape(''.join(_FILTER_CHARS)) + ']')


def tokenize(text):
    """
    Tokenize raw review text.

    HTML tags are replaced by spaces first, then every character in
    _FILTER_CHARS is replaced by a space, and the result is split on
    whitespace. Case is preserved.

    :param text: raw text
    :return: list of tokens
    """
    text = _HTML_TAG_RE.sub(' ', text)
    text = _FILTER_RE.sub(' ', text)
    # str.split() with no argument already drops surrounding whitespace
    return text.split()


class ImdbDataset(Dataset):
    """
    Dataset of (label, tokens) pairs read from the aclImdb directory layout.

    :param mode: 'train' reads <root>/train/neg and <root>/train/pos;
                 any other value reads <root>/test/neg and <root>/test/pos
    :param root: dataset root directory; defaults to the module-level dataset_path
    """

    def __init__(self, mode, root=None):
        super().__init__()
        root = dataset_path if root is None else root
        split = 'train' if mode == 'train' else 'test'
        text_path = [os.path.join(root, split, sub) for sub in ('neg', 'pos')]
        self.total_file_path_list = []
        for folder in text_path:
            # sorted() makes item order independent of the filesystem's listing order
            self.total_file_path_list.extend(
                os.path.join(folder, name) for name in sorted(os.listdir(folder))
            )

    def __getitem__(self, item):
        cur_path = self.total_file_path_list[item]
        cur_filename = os.path.basename(cur_path)
        # Label = (integer after the last '_' in the file name, before the '.') - 1.
        # NOTE(review): in aclImdb that number is presumably the review rating,
        # so labels are rating-based rather than binary neg/pos -- confirm intended.
        label = int(cur_filename.split('_')[-1].split('.')[0]) - 1
        # Read explicitly as UTF-8 (as fit_save_word_sequence does): the platform
        # default encoding (e.g. GBK / cp1252 on Windows) can fail on these files.
        with open(cur_path, encoding='utf-8') as f:
            text = tokenize(f.read().strip())
        return label, text

    def __len__(self):
        return len(self.total_file_path_list)


if __name__ == '__main__':
    imdb_dataset = ImdbDataset('train')
    print(imdb_dataset[0])
當前資料集處理後樣式:每個樣本為 (label, 分詞後的詞列表) 元組,例如 (label, ['word1', 'word2', ...])。
2、自定義dataloader中的collate_fn
def collate_fn(batch):
    """
    Collate a list of (label, tokens) samples into one batch.

    :param batch: list of tuples, each being the result of ImdbDataset.__getitem__
    :return: (labels, texts) -- labels is an int64 tensor of shape (batch_size,),
             texts is a tuple of the samples' token lists (still variable length)
    """
    labels, texts = zip(*batch)
    # Class-index targets for losses such as nn.CrossEntropyLoss are expected
    # as torch.long (int64); int32 targets are rejected there.
    labels = torch.tensor(labels, dtype=torch.int64)
    return labels, texts


dataset = ImdbDataset('train')
dataloader = DataLoader(dataset=dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)


if __name__ == '__main__':
    for index, (label, text) in enumerate(dataloader):
        print(index)
        print(label)
        print(text)
        break
當前結果:每個 batch 為 (labels, texts),其中 labels 是形狀為 (batch_size,) 的張量,texts 是由各樣本詞列表組成的元組。
3、文字序列化
每個詞都需要先對映為一個整數索引,再由 nn.Embedding 把該索引轉換成詞向量。
import numpy as np


class Word2Sequence():
    """
    Vocabulary mapping words to integer indices and back.

    Workflow:
    1. tokenize every sentence;
    2. count word frequencies, filter them and store the words (fit);
    3. convert a word list to an index array (transform);
    4. convert an index sequence back to words (inverse_transform).
    """
    UNK_TAG = 'UNK'
    PAD_TAG = 'PAD'
    UNK = 0
    PAD = 1

    def __init__(self):
        self.dict = {
            self.UNK_TAG: self.UNK,
            self.PAD_TAG: self.PAD
        }
        self.fited = False
        # index -> word, rebuilt by fit()
        self.inversed_dict = {}

    def to_index(self, word):
        """
        Convert a word to its index.

        :param word: a token
        :return: its index, or UNK if the word is not in the vocabulary
        """
        assert self.fited == True
        return self.dict.get(word, self.UNK)

    def to_word(self, index):
        """
        Convert an index to its word.

        :param index: an integer index
        :return: the word, or UNK_TAG if the index is unknown
        """
        assert self.fited
        return self.inversed_dict.get(index, self.UNK_TAG)

    def __len__(self):
        return len(self.dict)

    def fit(self, sentences, min_count=1, max_count=None, max_feature=None):
        """
        Add the words of *sentences* to the vocabulary.

        Frequencies are counted over the *sentences* of this call only.
        Words already in the vocabulary keep their index; each new word gets
        the next free index, so indices always stay contiguous in
        [0, len(self)).

        :param sentences: [[word1, word2, word3], [word1, word3, wordn, ...], ...]
                          (iterated once, so a generator is fine)
        :param min_count: keep words occurring at least this many times (None = no limit)
        :param max_count: keep words occurring at most this many times (None = no limit)
        :param max_feature: if an int, keep only the max_feature most frequent words
        """
        count = {}
        for sentence in sentences:
            for a in sentence:
                count[a] = count.get(a, 0) + 1
        # Filter by frequency
        if min_count is not None:
            count = {k: v for k, v in count.items() if v >= min_count}
        if max_count is not None:
            count = {k: v for k, v in count.items() if v <= max_count}
        if isinstance(max_feature, int):
            # Ascending by frequency, so the tail holds the most frequent words.
            items = sorted(count.items(), key=lambda x: x[1])
            if len(items) > max_feature:
                # len - max_feature (not -max_feature) so that max_feature=0 keeps nothing
                items = items[len(items) - max_feature:]
            new_words = [w for w, _ in items]
        else:
            new_words = sorted(count)
        for w in new_words:
            # Giving an already-known word the index len(self.dict) would leave a
            # gap at its old index, collide with the next new word, and let the
            # largest index reach len(self) -- out of range for
            # nn.Embedding(len(ws), ...).
            if w not in self.dict:
                self.dict[w] = len(self.dict)
        self.fited = True
        self.inversed_dict = dict(zip(self.dict.values(), self.dict.keys()))

    def transform(self, sentence, max_len=None):
        """
        Convert a sentence (list of words) to an index array.

        :param sentence: list of words
        :param max_len: if given, truncate or right-pad with PAD to exactly this length
        :return: np.ndarray of dtype int64
        """
        assert self.fited
        if max_len is not None:
            sentence = sentence[:max_len]
        r = [self.to_index(word) for word in sentence]
        if max_len is not None:
            r += [self.PAD] * (max_len - len(r))
        return np.array(r, dtype=np.int64)

    def inverse_transform(self, indices):
        """
        Convert an index sequence back to words.

        :param indices: [1, 2, 3, ...]
        :return: [word1, word2, ...]
        """
        return [self.to_word(i) for i in indices]


if __name__ == '__main__':
    w2s = Word2Sequence()
    w2s.fit([
        ['這', '是', '什', '麼'],
        ['那', '是', '神', '麼']
    ])
    print(w2s.dict)
    print(w2s.fited)
    print(w2s.transform(['神', '麼', '這']))
    print(w2s.transform(['神麼這'], max_len=10))
結果:
4、對Imdb資料構建字典,每個詞對應一個數字
# Build the vocabulary from the IMDB training data and save it to disk.
def fit_save_word_sequence():
    """
    Build a Word2Sequence vocabulary from every review under
    dataset_path/train/neg and dataset_path/train/pos, print its contents and
    size, and pickle it to ./model/ws.pkl.
    """
    ws = Word2Sequence()
    train_path = [os.path.join(dataset_path, i) for i in ['train/neg', 'train/pos']]
    total_file_path_list = []
    for i in train_path:
        total_file_path_list.extend([os.path.join(i, j) for j in os.listdir(i)])

    def _iter_sentences():
        # Yield one tokenized review at a time instead of building one big list.
        for cur_path in tqdm(total_file_path_list, desc='fitting'):
            with open(cur_path, encoding='utf-8') as f:
                yield tokenize(f.read().strip())

    # One fit() over the whole corpus counts frequencies globally (so
    # min_count / max_feature apply to the corpus, not to a single file) and
    # assigns each word its index exactly once. Calling fit() per file would
    # rely on fit() never re-indexing words it has already seen.
    ws.fit(_iter_sentences())
    print(ws.dict)
    print(len(ws))
    # Make sure the output directory exists, and close the file so the pickle is flushed.
    os.makedirs('./model', exist_ok=True)
    with open('./model/ws.pkl', 'wb') as f:
        pickle.dump(ws, f)


if __name__ == '__main__':
    fit_save_word_sequence()
5、將每一段文字轉換成數字序列,可透過 max_len 指定序列長度(過長截斷,不足用 PAD 補齊)
def get_dataloader(mode='train', batch_size=20, max_len=500):
    """
    Build a DataLoader that yields (labels, texts, lengths) batches of index sequences.

    :param mode: passed to ImdbDataset ('train', otherwise the test split)
    :param batch_size: samples per batch
    :param max_len: every review is truncated / PAD-padded to this many tokens
    :return: DataLoader yielding
             labels  -- int64 tensor of shape (B,)
             texts   -- int64 tensor of shape (B, max_len) with vocabulary indices
             lengths -- list of un-padded lengths, capped at max_len
    """
    # Load the vocabulary produced by fit_save_word_sequence().
    # NOTE: pickle.load can execute code from the file -- only load a ws.pkl you created.
    with open('./model/ws.pkl', 'rb') as f:
        ws = pickle.load(f)
    print(len(ws))

    def collate_fn(batch):
        """
        Collate (label, tokens) samples into tensors.

        :param batch: list of tuples, each the result of ImdbDataset.__getitem__
        :return: (labels, texts, lengths) as described in get_dataloader
        """
        labels, texts = zip(*batch)
        # Class-index targets for losses such as nn.CrossEntropyLoss are expected as int64.
        labels = torch.tensor(labels, dtype=torch.int64)
        lengths = [min(len(t), max_len) for t in texts]
        # ws.transform returns an int64 numpy array of length max_len; stacking
        # tensors avoids torch.tensor's slow path for a list of ndarrays.
        texts = torch.stack([torch.from_numpy(ws.transform(t, max_len)) for t in texts])
        return labels, texts, lengths

    dataset = ImdbDataset(mode)
    dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    return dataloader


if __name__ == '__main__':
    for index, (label, texts, length) in enumerate(get_dataloader()):
        print(index)
        print(label)
        print(texts)
        print(length)
        break
報錯問題:把上述資料送入 nn.Embedding 時報錯 IndexError: index out of range in self,即輸入中存在大於或等於 num_embeddings 的索引。
torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2, scale_grad_by_freq=False, sparse=False)
說白了就是num_embeddings(詞典的詞個數)不夠大,為什麼不夠呢
按道理說,如果詞典用 0,1,…,n−1 對映我們所有的 n 個詞(或者字),
那麼 num_embeddings = n 是夠用的;但是我們考慮 pad,pad 預設一般佔用索引 0,所以我們會把詞的對映重新處理為 1,2,…,n
這時候最大索引為 n,num_embeddings = n+1 才夠對映(nn.Embedding 要求所有索引滿足 0 ≤ index < num_embeddings)
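所以 num_embeddings 取詞典大小 +1 就夠了。另一個常見原因:如果多次呼叫 fit,並且對已存在的詞也重新賦值 self.dict[w] = len(self.dict),最大索引會等於 len(ws),同樣超出 nn.Embedding(len(ws), ...) 的範圍;更根本的做法是保證索引連續地落在 [0, len(ws)) 內,並令 num_embeddings 大於最大索引。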
然後就不會報錯了
參考:
https://blog.csdn.net/weixin_36488653/article/details/118485063