
torch.nn.Embedding(num_embeddings, embedding_dim): converting text into word vectors for a text sentiment classification task

1. Processing the dataset

import torch
import os
import re
from torch.utils.data import Dataset, DataLoader


dataset_path = r'C:\Users\ci21615\Downloads\aclImdb_v1\aclImdb'


def tokenize(text):
    """
    Tokenize and clean the raw text
    :param text:
    :return:
    """
    filters = ['!', '"', '#', '$', '%', '&', '\(', '\)', '\*', '\+', ',', '-', '\.', '/', ':', ';', '<', '=', '>',
               '\?', '@', '\[', '\\\\', '\]', '^', '_', '`', '\{', '\|', '\}', '~', '\t', '\n', '\x97', '\x96']
    # Strip HTML tags, then replace every punctuation/control character with a space
    text = re.sub("<.*?>", " ", text, flags=re.S)
    text = re.sub("|".join(filters), " ", text, flags=re.S)
    return [i.strip() for i in text.split()]


class ImdbDataset(Dataset):
    """
    Prepare the dataset
    """
    def __init__(self, mode):
        super(ImdbDataset, self).__init__()
        if mode == 'train':
            text_path = [os.path.join(dataset_path, i) for i in ['train/neg', 'train/pos']]
        else:
            text_path = [os.path.join(dataset_path, i) for i in ['test/neg', 'test/pos']]
        self.total_file_path_list = []
        for i in text_path:
            self.total_file_path_list.extend([os.path.join(i, j) for j in os.listdir(i)])

    def __getitem__(self, item):
        cur_path = self.total_file_path_list[item]
        cur_filename = os.path.basename(cur_path)
        # Get the label: the rating encoded in the file name after the underscore
        label = int(cur_filename.split('_')[-1].split('.')[0]) - 1
        text = tokenize(open(cur_path, encoding='utf-8').read().strip())
        return label, text

    def __len__(self):
        return len(self.total_file_path_list)


if __name__ == '__main__':
    imdb_dataset = ImdbDataset('train')
    print(imdb_dataset[0])

What the processed dataset looks like at this point:

2. A custom collate_fn for the DataLoader

def collate_fn(batch):
    """
    batch is a list of tuples; each tuple is the result of __getitem__ in the dataset
    :param batch:
    :return:
    """
    batch = list(zip(*batch))
    labels = torch.tensor(batch[0], dtype=torch.int32)
    texts = batch[1]
    del batch
    return labels, texts


dataset = ImdbDataset('train')
dataloader = DataLoader(dataset=dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)


if __name__ == '__main__':
    for index, (label, text) in enumerate(dataloader):
        print(index)
        print(label)
        print(text)
        break

Current output:
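As an aside, the zip(*batch) transpose inside collate_fn can be seen with a toy batch of (label, tokens) pairs (made-up values, not actual dataloader output):

batch = [(1, ['good', 'movie']), (0, ['bad', 'film'])]
labels, texts = zip(*batch)
print(labels)   # (1, 0)
print(texts)    # (['good', 'movie'], ['bad', 'film'])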

3. Text serialization

Each word must first be assigned an integer index, and that index is then converted into a vector.
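A toy sketch of these two steps (the vocabulary and sentence below are made-up placeholders, not the IMDB dictionary):

import torch
import torch.nn as nn

# Step 1: assign every word an integer index (unknown words fall back to UNK)
vocab = {'UNK': 0, 'PAD': 1, 'good': 2, 'movie': 3}
indices = torch.tensor([vocab.get(w, vocab['UNK']) for w in ['good', 'movie', 'film']])
print(indices)        # tensor([2, 3, 0]) -- 'film' is not in the vocabulary, so it maps to UNK

# Step 2: nn.Embedding turns each index into a learnable dense vector
embedding = nn.Embedding(num_embeddings=len(vocab), embedding_dim=4)
vectors = embedding(indices)
print(vectors.shape)  # torch.Size([3, 4])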

import numpy as np


class Word2Sequence():
    """
    Text serialization
    Approach:
    1. Tokenize every sentence
    2. Store the words in a dictionary, filter them by frequency, and keep the counts
    3. Implement a method that turns text into a sequence of numbers
    4. Implement a method that turns a sequence of numbers back into text
    """
    UNK_TAG = 'UNK'
    PAD_TAG = 'PAD'
    UNK = 0
    PAD = 1

    def __init__(self):
        self.dict = {
            self.UNK_TAG: self.UNK,
            self.PAD_TAG: self.PAD
        }
        self.fited = False

    def to_index(self, word):
        """
        Word to index
        :param word:
        :return:
        """
        assert self.fited
        return self.dict.get(word, self.UNK)

    def to_word(self, index):
        """
        Index to word
        :param index:
        :return:
        """
        assert self.fited
        if index in self.inversed_dict:
            return self.inversed_dict[index]
        return self.UNK_TAG

    def __len__(self):
        return len(self.dict)

    def fit(self, sentences, min_count=1, max_count=None, max_feature=None):
        """
        :param sentences: [[word1, word2, word3], [word1, word3, wordn...], ...]
        :param min_count: minimum number of occurrences
        :param max_count: maximum number of occurrences
        :param max_feature: maximum total number of words
        :return:
        """
        count = {}
        # Count how often each word occurs
        for sentence in sentences:
            for a in sentence:
                if a not in count:
                    count[a] = 0
                count[a] += 1
        # Filter by frequency, i.e. drop words that occur too rarely (or too often)
        if min_count is not None:
            count = {k: v for k, v in count.items() if v >= min_count}
        if max_count is not None:
            count = {k: v for k, v in count.items() if v <= max_count}
        # Limit the total number of words
        # Each word's index is simply the size of dict at the moment it is added
        if isinstance(max_feature, int):
            count = sorted(list(count.items()), key=lambda x: x[1])
            if max_feature is not None and len(count) > max_feature:
                count = count[-int(max_feature):]
            for w, _ in count:
                self.dict[w] = len(self.dict)
        else:
            for w in sorted(count.keys()):
                self.dict[w] = len(self.dict)
        self.fited = True
        self.inversed_dict = dict(zip(self.dict.values(), self.dict.keys()))

    def transform(self, sentence, max_len=None):
        """
        Convert a sentence into an array of indices
        :param sentence:
        :param max_len:
        :return:
        """
        assert self.fited
        if max_len is not None:
            r = [self.PAD] * max_len
        else:
            r = [self.PAD] * len(sentence)
        if max_len is not None and len(sentence) > max_len:
            sentence = sentence[:max_len]
        for index, word in enumerate(sentence):
            r[index] = self.to_index(word)
        return np.array(r, dtype=np.int64)

    def inverse_transform(self, indices):
        """
        Convert an array of indices back into words
        :param indices: [1, 2, 3, ...]
        :return: [word1, word2, ...]
        """
        sentence = []
        for i in indices:
            word = self.to_word(i)
            sentence.append(word)
        return sentence


if __name__ == '__main__':
    # Toy sentences for a quick check
    w2s = Word2Sequence()
    w2s.fit([
        ['this', 'movie', 'is', 'good'],
        ['that', 'movie', 'is', 'bad']
    ])
    print(w2s.dict)
    print(w2s.fited)
    print(w2s.transform(['this', 'movie', 'rocks']))
    print(w2s.transform(['this', 'movie', 'rocks'], max_len=10))

Output:

4. Building the dictionary over the IMDB data: one number per word

import pickle
from tqdm import tqdm


# Process the IMDB data and save the fitted dictionary
def fit_save_word_sequence():
    """
    Build the dictionary from the training set
    :return:
    """
    ws = Word2Sequence()
    train_path = [os.path.join(dataset_path, i) for i in ['train/neg', 'train/pos']]
    total_file_path_list = []
    for i in train_path:
        total_file_path_list.extend([os.path.join(i, j) for j in os.listdir(i)])
    for cur_path in tqdm(total_file_path_list, desc='fitting'):
        sentence = open(cur_path, encoding='utf-8').read().strip()
        res = tokenize(sentence)
        ws.fit([res])
    # Save the fitted Word2Sequence object
    print(ws.dict)
    print(len(ws))
    pickle.dump(ws, open('./model/ws.pkl', 'wb'))


if __name__ == '__main__':
    fit_save_word_sequence()
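A quick sanity check after fitting, assuming ./model/ws.pkl has already been written by fit_save_word_sequence() (the words below are arbitrary examples):

import pickle

ws = pickle.load(open('./model/ws.pkl', 'rb'))
print(len(ws))                                                      # vocabulary size, including UNK and PAD
print(ws.transform(['this', 'movie', 'is', 'great'], max_len=10))  # ten indices, padded with PAD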

5. Converting each text into a fixed-length index sequence, with a configurable max_len

def get_dataloader(mode='train'):
    """
    Build a DataLoader that yields the serialized (index) texts
    :param mode:
    :return:
    """
    # Load the fitted dictionary
    ws = pickle.load(open('./model/ws.pkl', 'rb'))
    print(len(ws))

    # Custom collate_fn
    def collate_fn(batch):
        """
        batch is a list of tuples; each tuple is the result of __getitem__ in the dataset
        :param batch:
        :return:
        """
        max_len = 500
        batch = list(zip(*batch))
        labels = torch.tensor(batch[0], dtype=torch.int32)
        texts = batch[1]
        # Record the (truncated) length of every text
        lengths = [len(i) if len(i) < max_len else max_len for i in texts]
        # Every text is turned into a sequence of max_len indices, i.e. 500 per sample
        temp = [ws.transform(i, max_len) for i in texts]
        texts = torch.tensor(temp)

        del batch
        return labels, texts, lengths

    dataset = ImdbDataset(mode)
    dataloader = DataLoader(dataset=dataset, batch_size=20, shuffle=True, collate_fn=collate_fn)
    return dataloader


if __name__ == '__main__':
    for index, (label, texts, length) in enumerate(get_dataloader()):
        print(index)
        print(label)
        print(texts)
        print(length)
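To connect this back to the title, here is a minimal sketch of pushing one serialized batch through nn.Embedding (embedding_dim=100 is an arbitrary choice here, and the + 1 on the vocabulary size is explained in the error note below):

import pickle
import torch.nn as nn

ws = pickle.load(open('./model/ws.pkl', 'rb'))
embedding = nn.Embedding(num_embeddings=len(ws) + 1, embedding_dim=100)
for label, texts, lengths in get_dataloader('train'):
    embedded = embedding(texts)    # [batch_size, 500] -> [batch_size, 500, 100]
    print(embedded.shape)          # torch.Size([20, 500, 100])
    break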

The error: feeding these index sequences into nn.Embedding raises an index-out-of-range error. Recall the signature:

torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2, scale_grad_by_freq=False, sparse=False)

Plainly put, num_embeddings (the number of words in the dictionary) is not large enough. Why isn't it?

In principle, when we embed words, the dictionary maps all of our words (or characters) to the indices 0, 1, ..., n-1.

num_embeddings = n would then be enough. But we also have to account for padding: PAD is conventionally 0, so the word-to-index mapping is shifted to 1, 2, ..., n.

Now the largest index is n, and every index passed to nn.Embedding must be strictly smaller than num_embeddings, so num_embeddings = n + 1 is needed to cover the mapping.

In other words, adding 1 is enough.

After that the error no longer occurs.
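A minimal sketch of the point above, with a toy vocabulary size (not the actual IMDB dictionary): indices passed to nn.Embedding must lie in [0, num_embeddings - 1], so if the largest index is n, num_embeddings must be at least n + 1.

import torch
import torch.nn as nn

n = 5000   # toy size: words occupy indices 1..n, PAD is 0

emb_too_small = nn.Embedding(num_embeddings=n, embedding_dim=100)
# emb_too_small(torch.tensor([n]))          # IndexError: index out of range in self

# With num_embeddings = n + 1, every index from 0 to n is valid
emb = nn.Embedding(num_embeddings=n + 1, embedding_dim=100, padding_idx=0)
out = emb(torch.tensor([0, 1, n]))
print(out.shape)                            # torch.Size([3, 100])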

Reference:

https://blog.csdn.net/weixin_36488653/article/details/118485063