1. 程式人生 > 實用技巧 >利用RNN進行中文文字分類(資料集是復旦中文語料)

利用RNN進行中文文字分類(資料集是復旦中文語料)

利用TfidfVectorizer進行中文文字分類(資料集是復旦中文語料)

1、訓練詞向量

資料預處理參考利用TfidfVectorizer進行中文文字分類(資料集是復旦中文語料),現在我們有了分詞後的train_jieba.txt和test_jieba.txt,看一下部分內容:

fenci_path = '/content/drive/My Drive/NLP/dataset/Fudan/train_jieba.txt'
with open(fenci_path,'r',encoding='utf-8') as fp:
    i = 0
    lines = fp.readlines()
    for
line in lines: print(line) i += 1 if i == 10: break

每一篇文章的分詞結果和標籤都是一行,標籤之間使用'\t'隔開。

由於之前只是簡略的進行分詞,沒有過濾掉一些停用詞,接下來還需要進行一些處理,我們之前已經建立了停用詞文字:stopwords.txt,現在我們要使用它。

def clean():
  label_list = []
  content_list = []
  with open('/content/drive/My Drive/NLP/dataset/Fudan/train_jieba.txt
','r',encoding='utf-8') as fp: lines = fp.readlines() for line in lines: tmp = line.strip().split("\t") content,label = tmp[0],tmp[1] label_list.append(label) out_list = [] for word in content.strip().split(' '): if word not in stopwords_list and word != ''
: out_list.append(word) else: continue content_list.append(" ".join(out_list)) return content_list,label_list content_list,label_list = clean() i = 0 for content,label in zip(content_list,label_list): print(content,label) i += 1 if i == 10: break

確實是過濾掉了一些停用詞,如果效果不好可以根據當前任務情況繼續擴充停用詞 ,這裡就暫時到這了。

對訓練集和測試集進行同樣的清理後儲存:

def save(content_list,label_list):
  path = '/content/drive/My Drive/NLP/dataset/Fudan/train_clean_jieba.txt'
  fp = open(path,'w',encoding='utf-8')
  for content,label in zip(content_list,label_list):
    fp.write(content+str(label)+'\n')
  fp.close()
save(content_list,label_list)

對測試集進行相同的操作時這一句content,label = tmp[0],tmp[1]出現了:list out of range

只需要多加一句:if len(tmp) == 2:過濾以下即可。

def clean():
  label_list = []
  content_list = []
  with open('/content/drive/My Drive/NLP/dataset/Fudan/test_jieba.txt','r',encoding='utf-8') as fp:
    lines = fp.readlines()
    for line in lines:
      tmp = line.strip().split("\t")
      if len(tmp) == 2:
        content,label = tmp[0],tmp[1]
      label_list.append(label)
      out_list = []
      for word in content.strip().split(' '):
        if word not in stopwords_list and word != '':
          out_list.append(word)
        else:
          continue
      content_list.append(" ".join(out_list))
  return content_list,label_list
content_list,label_list = clean()    
def save(content_list,label_list):
  path = '/content/drive/My Drive/NLP/dataset/Fudan/test_clean_jieba.txt'
  fp = open(path,'w',encoding='utf-8')
  for content,label in zip(content_list,label_list):
    fp.write(content+'\t'+str(label)+'\n')
  fp.close()
save(content_list,label_list)

2、訓練word2vec,構建詞向量

我們新建一個data資料夾,然後將train_clean_jieba.txt和test_clean_jieba.txt放進去。這裡word2vec的用法就不具體介紹了。

from gensim.models import Word2Vec
from gensim.models.word2vec import PathLineSentences
import multiprocessing
import os
import sys
import logging

# 日誌資訊輸出
program = os.path.basename(sys.argv[0])
logger = logging.getLogger(program)
logging.basicConfig(format='%(asctime)s: %(levelname)s: %(message)s')
logging.root.setLevel(level=logging.INFO)
logger.info("running %s" % ' '.join(sys.argv))

# check and process input arguments
# if len(sys.argv) < 4:
#     print(globals()['__doc__'] % locals())
#     sys.exit(1)
# input_dir, outp1, outp2 = sys.argv[1:4]

# 訓練模型 
# 輸入語料目錄:PathLineSentences(input_dir)
# embedding size:100 共現視窗大小:5 去除出現次數5以下的詞,多執行緒執行,迭代5次
model = Word2Vec(PathLineSentences('/content/drive/My Drive/NLP/dataset/Fudan/data/'),
                     size=100, window=5, min_count=5,
                     workers=multiprocessing.cpu_count(), iter=5)
model.save('/content/drive/My Drive/NLP/dataset/Fudan/Word2vec.w2v')

執行之後是這個樣子:

2020-10-16 13:57:28,601: INFO: running /usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py -f /root/.local/share/jupyter/runtime/kernel-52776eb8-5141-458e-8f04-3d3a0f11d46f.json
2020-10-16 13:57:28,606: INFO: reading directory /content/drive/My Drive/NLP/dataset/Fudan/data/
2020-10-16 13:57:28,608: INFO: files read into PathLineSentences:/content/drive/My Drive/NLP/dataset/Fudan/data/test_clean_jieba.txt
/content/drive/My Drive/NLP/dataset/Fudan/data/train_clean_jieba.txt
2020-10-16 13:57:28,610: INFO: collecting all words and their counts
2020-10-16 13:57:28,612: INFO: reading file /content/drive/My Drive/NLP/dataset/Fudan/data/test_clean_jieba.txt
/usr/local/lib/python3.6/dist-packages/smart_open/smart_open_lib.py:252: UserWarning: This function is deprecated, use smart_open.open instead. See the migration notes for details: https://github.com/RaRe-Technologies/smart_open/blob/master/README.rst#migrating-to-the-new-open-function
  'See the migration notes for details: %s' % _MIGRATION_NOTES_URL
2020-10-16 13:57:28,627: INFO: PROGRESS: at sentence #0, processed 0 words, keeping 0 word types
2020-10-16 13:57:33,897: INFO: reading file /content/drive/My Drive/NLP/dataset/Fudan/data/train_clean_jieba.txt
2020-10-16 13:57:34,040: INFO: PROGRESS: at sentence #10000, processed 18311769 words, keeping 440372 word types
2020-10-16 13:57:39,060: INFO: collected 584112 word types from a corpus of 35545042 raw words and 19641 sentences
2020-10-16 13:57:39,062: INFO: Loading a fresh vocabulary
2020-10-16 13:57:39,768: INFO: effective_min_count=5 retains 183664 unique words (31% of original 584112, drops 400448)
2020-10-16 13:57:39,769: INFO: effective_min_count=5 leaves 34810846 word corpus (97% of original 35545042, drops 734196)
2020-10-16 13:57:40,320: INFO: deleting the raw counts dictionary of 584112 items
2020-10-16 13:57:40,345: INFO: sample=0.001 downsamples 19 most-common words
2020-10-16 13:57:40,345: INFO: downsampling leaves estimated 33210825 word corpus (95.4% of prior 34810846)
2020-10-16 13:57:40,951: INFO: estimated required memory for 183664 words and 100 dimensions: 238763200 bytes
2020-10-16 13:57:40,952: INFO: resetting layer weights
2020-10-16 13:58:15,170: INFO: training model with 2 workers on 183664 vocabulary and 100 features, using sg=0 hs=0 sample=0.001 negative=5 window=5
2020-10-16 13:58:15,174: INFO: reading file /content/drive/My Drive/NLP/dataset/Fudan/data/test_clean_jieba.txt
2020-10-16 13:58:16,183: INFO: EPOCH 1 - PROGRESS: at 1.11% examples, 481769 words/s, in_qsize 3, out_qsize 0

最後會生成:

接下來我們要使用模型,然後檢視詞以及其對應的詞向量:

from gensim.models import Word2Vec
model = Word2Vec.load('/content/drive/My Drive/NLP/dataset/Fudan/Word2vec.w2v')
#所有詞的數目
print(len(model.wv.index2word))
word_vector_dict = {}
for word in model.wv.index2word:
  word_vector_dict[word] = list(model[word])
i = 0
for k,v in word_vector_dict.items():
  print(k,v)
  i += 1
  if i == 5:
    break

結果:

/usr/local/lib/python3.6/dist-packages/smart_open/smart_open_lib.py:252: UserWarning: This function is deprecated, use smart_open.open instead. See the migration notes for details: https://github.com/RaRe-Technologies/smart_open/blob/master/README.rst#migrating-to-the-new-open-function
  'See the migration notes for details: %s' % _MIGRATION_NOTES_URL
183664
/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:7: DeprecationWarning: Call to deprecated `__getitem__` (Method will be removed in 4.0.0, use self.wv.__getitem__() instead).
  import sys
. [-2.8709345, -0.47548708, 0.86331373, 1.2737428, 2.3575406, 2.0570302, -0.53931403, 1.2613002, 0.5172711, -1.6461672, 1.3732913, 0.86122376, -0.21252058, 2.0552237, 0.9418685, 0.3278085, 0.588585, -0.7969468, -1.8978101, -0.43336996, -0.4861237, -0.25338736, -0.5043334, 0.6816521, 4.776381, 1.3428804, 1.9577577, 0.2862259, -1.3767976, 1.2107555, -0.21500991, 2.584977, -3.157238, -0.08438093, -1.4721884, -0.5101056, 0.39259034, 0.74332994, -0.6534138, 0.04722414, 2.2819524, 1.9146276, -0.13876201, -1.3124858, -1.2666191, 0.1447281, -0.5460836, 1.7340208, 0.5979215, -4.0311975, 0.11542667, -0.6193901, -0.058931056, 1.9952455, -0.8310607, -0.9370241, 0.2416995, -1.4236349, -0.41856983, -0.5497827, 1.2359228, 0.01779593, 0.9849501, 1.2311344, 1.8523129, 2.363041, 1.0974075, -1.2220355, 0.110876285, 0.17010106, -0.9745132, 1.1252304, 0.20266196, 1.6555228, -0.69005895, -0.15593, -2.6057267, 0.59146214, -0.29261357, 0.83551484, -2.1035368, 1.1904488, -1.0554912, -0.641594, 1.2142769, -1.4514563, 0.9756896, 0.52437824, -0.8486732, -3.358046, -0.69511414, 1.8128188, 0.45924014, -1.1814638, -0.48232678, -0.12257868, 0.23399891, -3.303544, -0.6949516, 0.5121446]
, [-2.618333, -1.8558567, 1.8535767, -0.21151228, 1.7623954, 4.3192573, 0.09128157, 1.5980599, 0.7076833, -1.7116284, 1.0046017, -0.15326972, 0.4059908, 0.9488417, 2.2387662, 0.20677945, -0.7107643, -2.758641, -0.3840812, 0.16083181, -2.1107125, 0.24038436, -1.2403657, 2.7272208, 1.9277251, 0.1489557, 2.1110923, 0.5919174, -2.1878436, 0.36604762, 0.31739056, 5.550043, -3.364542, 0.70963943, 0.13099277, -2.2344782, -0.39852622, -0.24567917, -1.3379095, -0.27352497, 1.3079535, -0.3681397, 1.2069534, -0.7798161, -0.18939576, -0.373316, -1.1903548, 1.2864754, -0.61407185, -3.171876, -1.2982743, 1.7416263, 0.73636365, 0.9905826, -0.3719811, 0.05626492, -2.6127703, 0.83886856, 0.66923296, 1.2502893, 0.9262052, 0.42174354, -1.484305, -0.17558077, 1.9593159, 4.8938365, 0.61336166, -1.0788211, -1.0862421, -0.5105872, -2.6575727, 2.091327, -0.23270625, 2.284086, -0.98763543, 0.28696263, -2.2600112, -3.2595506, 0.025764514, 1.3404137, -0.71168816, 2.2680438, 0.48311472, 0.36931905, 0.938186, -1.6107051, -0.15926446, 1.3209386, -0.801876, -2.303902, -0.436481, 0.8073558, 0.38733667, -0.26957598, -1.4267699, -0.8020603, 0.414129, -3.3372293, 0.6402213, -0.19667119]
) [-0.80750054, -0.6121455, -1.0710338, -2.9930687, 2.0432, 4.141169, -0.15709901, 0.81717527, -1.5162835, -3.1241925, -0.10446141, 1.010525, -3.1002233, 1.6662389, 0.9942944, 0.85855705, 2.0851238, -1.6842883, -2.9477723, -0.2876924, -0.6282387, -0.28349137, -3.1225855, 2.2486699, 1.2903367, 2.2274559, 0.27433106, 0.57094145, -1.1607213, -0.4642481, -1.0572903, 3.2884996, -1.2198547, -1.6459501, 0.67363816, -2.5827177, -0.25848988, -1.1222432, 0.21818976, 1.8232889, 2.8271437, -0.617807, -1.4015028, 1.2166779, -0.8353678, 0.34809938, -0.46445072, -0.084388316, 0.7031371, -4.1085744, -0.50515014, -3.1198754, 0.72745895, 1.4460654, 0.9307348, -2.758027, 0.018058121, -0.8535555, 0.6409112, 0.1882723, -1.1798013, 1.3632597, -0.1337653, 0.51510906, -0.5415601, 4.006427, -0.91912925, -3.4697065, -2.7071013, -0.6627828, -2.9176655, 1.0004271, 0.8123536, 2.1355457, -0.013824586, -0.10087594, 0.115427904, -0.46978354, 2.071482, 1.8447496, 0.99563545, 2.845259, 1.1902128, 0.02504066, 2.6136658, -0.6704431, -0.47580847, 1.1602222, 1.2428118, -2.3880181, -1.6264966, 0.74079543, -0.54774994, 1.0163826, -0.736786, -1.8922712, 0.5381837, -1.1004277, 0.33553576, 0.40247878]
( [-2.4204996, -1.0095057, 0.36723495, -1.9701287, 1.5028982, 1.0829349, -0.72509646, 1.0087173, -0.8471445, 0.21284652, -0.4341774, -0.9700405, -1.300372, 0.9491097, 3.350109, 1.4735373, 2.9339328, -0.3343834, -3.6445296, -0.41197056, -1.338803, 0.28331625, 0.10618747, -1.3739557, 1.1008664, 0.17741367, 0.45283958, 1.5100185, -1.7710751, 1.0186597, 0.7735381, 2.491264, 0.07328774, -1.1831408, -3.2152338, -2.5108373, -0.34185433, 0.34209073, -0.14207332, -2.194724, 1.0734048, -1.1285906, 1.9627889, -1.5373456, -1.9735036, 2.2119362, -0.21241511, 1.8747587, -0.67907304, -4.566279, -2.0092149, -1.3107775, 0.3573235, 0.9350223, 0.4996264, 1.6724535, -0.79917055, -0.14005652, 2.7869322, 0.80775166, 0.13976693, 0.5046433, -0.34996128, 0.3425343, 3.6427495, 2.3169396, -1.0229387, -4.0736656, 0.09746367, 0.79698503, -3.6760647, 0.53965265, -2.018294, 2.074562, -0.5203732, 0.06932237, -1.1419374, -1.2626162, 1.5128584, 1.1419917, -2.4901378, 3.0212705, 3.0879154, -1.0666283, 1.4316878, 0.25575432, 1.0118675, -0.210056, 1.5728005, -3.074708, -2.050965, 2.177831, -1.4306773, 0.5591415, -1.6649296, -2.479498, 0.27199566, -0.7439327, 1.065499, -1.7122517]
中 [-1.4137642, 0.07996469, -0.84706545, 0.9269082, -0.5876861, 0.9406654, -2.7666419, 0.013692471, 0.7948517, -3.7575817, -3.0255227, -0.1290994, 0.15024899, 1.7057111, -1.783816, 1.2594382, -0.80985075, 1.2856516, -1.1239803, 0.33939472, 1.7681189, 0.5220787, -3.093301, -0.72288835, -0.27703923, 0.6913874, -0.62614673, 0.16310164, 1.6016583, -0.9558958, -0.65395266, -0.81403816, -0.35800782, -1.6817136, 0.0038451876, 0.924515, 0.7525097, -0.55127585, -2.7082217, -0.5226547, 0.65330553, -0.13418457, -0.11833907, -4.0032573, -0.56922513, -1.323926, 0.097095534, 1.0593758, 0.48968402, -0.6643793, 1.4596446, -2.0395942, 2.7365487, -1.0603454, -0.54655385, -2.8474076, 0.3412293, 0.96139586, 0.9478409, 0.7041088, 4.2240176, -0.5293954, -3.0038583, -3.1062794, 0.55948454, 0.37824842, 0.13522537, 0.00925424, -1.3225565, 0.4190299, 0.57395566, -1.2779645, -0.6505884, 3.8218825, -1.2415665, -0.06736558, -1.7298794, 1.6446227, -1.0105107, -1.0007042, -0.7136034, 1.7795436, -0.8232877, 0.3342558, -1.9837192, -0.043689013, 0.4572051, 0.5139073, 1.9465048, 1.3884708, -1.18057, 3.5671742, -2.4114704, 1.324688, -0.14609453, -0.724388, 0.6249127, 0.600731, -2.1366022, 2.421635]

還是有一些標點符號沒有去除,可以補充下停止詞文字,這裡就暫時這樣了。

之後我們將詞語儲存在一個檔案中,將對應的詞向量儲存在另一個檔案中。

vocabulary_path = '/content/drive/My Drive/NLP/dataset/Fudan/vocabulary.txt'
vector_path = '/content/drive/My Drive/NLP/dataset/Fudan/vector.txt'
fp1 = open(vocabulary_path,'w',encoding='utf-8')
fp2 = open(vector_path,'w',encoding='utf-8')
for word in model.wv.index2word:
  fp1.write(word+'\n')
  vector_list = model[word]
  vector_str_list = [str(num) for num in vector_list]
  fp2.write(" ".join(vector_str_list)+"\n")
fp1.close()
fp2.close()

接著我們還需要進行一系列的轉換操作:

import keras
#將詞彙表中的單詞對映成id
def word2id():
  vocabulary_path = '/content/drive/My Drive/NLP/dataset/Fudan/vocabulary.txt'
  fp1 = open(vocabulary_path,'r',encoding='utf-8')
  word2id_dict = {}
  for i,line in enumerate(fp1.readlines()):
    word2id_dict[line.strip()] = i
  print(word2id_dict)
  fp1.close()
  return word2id_dict
#得到文字內容及對應的標籤
def get_content_label():
  data = '/content/drive/My Drive/NLP/dataset/Fudan/data/train_clean_jieba.txt'
  fp = open(data,'r',encoding='utf-8')
  content_list = []
  label_list = []
  for line in fp.readlines():
    line = line.strip().split('\t')
    if len(line) == 2:
      content_list.append(line[0])
      label_list.append(line[1])
  print(content_list[:5])
  print(label_list[:5])
  fp.close()
  return content_list,label_list
#得到標籤對應的id
def get_label_id():
  label = '/content/drive/My Drive/NLP/dataset/Fudan/label.txt'
  label2id_dict = {}
  fp = open(label,'r',encoding='utf-8')
  for line in fp.readlines():
    line = line.strip().split('\t')
    label2id_dict[line[0]] = line[1]
  #print(label2id_dict)
  return label2id_dict
#將文字內容中的詞替換成詞對應的id,並設定文字的最大長度
#對標籤進行one-hot編碼
def process():
  max_length = 600
  data = '/content/drive/My Drive/NLP/dataset/Fudan/data/train_clean_jieba.txt'
  contents,labels = get_content_label()
  word_to_id = word2id()
  cat_to_id = get_label_id()
  data_id = []
  label_id = []
  for i in range(len(contents)):
    data_id.append([word_to_id[x] for x in contents[i] if x in word_to_id])
    label_id.append(cat_to_id[labels[i]])

  # 使用keras提供的pad_sequences來將文字pad為固定長度
  x_pad = keras.preprocessing.sequence.pad_sequences(data_id, max_length)
  y_pad = keras.utils.to_categorical(label_id, num_classes=len(cat_to_id))  # 將標籤轉換為one-hot表示
  return x_pad,y_pad
x_pad,y_pad = process()
print(x_pad[0])
print(y_pad[0])
print(len(x_pad),len(y_pad))

結果:

[  3464   2264   1227   1015   1844  34754   3464   2264   5781   2933
   1214   1499    519   2558    603  68784  50747   2706   1499   2127
   2558   3388   2912   1128   4617   1499   2127   3464   2264      4
   1499   2127   1244   5645  22020  55754   3464   2264   4419   5781
   2933   3464   2264   2558    603   1538     80   1104   1844      4
   1363   2821   5602   3464   2264   1244   5645   5308   2558    603
   1244   5645   1844  34754   3464   2264    238   1499   2558    603
   5602   5308   2127   2558    603    538    762   4437   2127   2558
    603   3388   2264   1024   1139    538   1818   1024   1139   1851
   1851   2327    139    929   1548    314    160   2602    482  10087
  13030   1730  40786   4754    139    562    366   6089      4    562
    160   2602     85   2433   5781     80    466   1139   1503   4453
   4617   1244   5645   3560   6058   3459      4    562    160   2602
   2558    603   3829   2517    410   4585   2558    603   3464   2264
   3848    423  11739   5645   3560   6058    431   3950   2127   1499
   2127     35    423  11739   5645    319   2558    603   1499   2127
   3773   4383      4   1503   1499   2558    603   1994   4419   1257
   1553    603    926   6065   1257   1553    603   1376    431   1538
     80   1090   2646   6506   7261    519   2558    603   1994   4419
   2456   2127   2558    603  20160   1553    603   1182   1090  16160
   4414   1137   1503   1844  34754      4    864  22754   1844  34754
   1730   3464   2264   2558    603  68784   3464   2264   2558    603
   5658  16754   6608   2558    603   3468   1776   4780  11201   5634
    429   1994   4419  38671   1730   3464   2264    755   2332  25839
    828   2558    603   3464   2264    429   3174    144   2840    429
   3174   1305   1164   2094  41825  33950      7      4    562   3464
   2264   3773   4383   7131    787   2264   3773   4383   3773   4383
   5326      8   1336  22020   2181   3464   2264   2558    603    915
    429  19614  11857   1844  34754    905   5372    429   3140   1116
   1371    780    858    780  22020  55754   3464   2264   2558    603
   4526   1032   1227   1015   1104   1844  17286   5308   2456   1104
   2193    429   3464   2264   2558    603   1336   3464   2264    755
   2558    603    755    888   2127   2558    603   1182   1090    139
   1499   2193    429   3464   2264   2558    603    220    201    144
   1844  34754   5223   3355    296   1321      0   1844   2602   5368
   4815    319    144    160   2602    915    429   2332   1996   1227
   1015   2114    384   2691  25814   2261    160   2602   1844  12894
   1996  20370  15958   1844  34754   4711   3994   1996      0   1844
  34754   1866   3241   6754    201   1305   2181   6754    201   2558
    603   2558    603   2193    429   2127   1090   4617   4982   2706
   1025   3119  10028   3464   2264   2558    603   1116    160   1182
   1090    950    384   1215  26769 116663    160   2602   1996    864
   2578   1864   5223    431  19429   3355    296   2578   1864   1851
   1851   2327   5223      0   1844  34754    238   2433   3464   2264
    458  39604    787    395   8527  30953    519   1090   4617   1321
    201   3119   2710   1321    201    519   1321    201   2558    603
   1321    201   1844  10087      0   1844  34754   1540    431    861
    562    787   1844    864     10   1411    787   2264   9301    519
  58253  13086   8527   3560   5648   3464   2264  10478   2181   1844
  34754      4      0   1844  34754     85   1077   2578   1864   1548
   8068   2578   1864      4    562    787   2264   1692   1938   2924
   1692   3837   2181   3683   7285     35   1844  34754    864    238
   1499    139    519   2806   1321    562   2236    301    395  50747
   2706   2574    429     35    254   2806   1321   1227    176   2574
    429    562    731   2281    139   1127   4668   3459    716   1548
   8068   2578   1864   2927   1636   2400   1851    139  14986   3773
  12279     80   3275   8128   2033   1723   7131    867   3468   2790
   1938  22337   2895  32268   2790   1723   1938  22337   2067   4914
   1723   1938  22337      7   3812   8246   4899   4178   8553   8595
   5487   1553    731   9237  45100    482    429   2684   1221      8]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
9803 9803

最後我們再定義一個將資料製作成batch的操作:

def batch_iter(x, y, batch_size=64):
    """生成批次資料"""
    data_len = len(x)
    num_batch = int((data_len - 1) / batch_size) + 1

    indices = np.random.permutation(np.arange(data_len))
    x_shuffle = x[indices]
    y_shuffle = y[indices]

    for i in range(num_batch):
        start_id = i * batch_size
        end_id = min((i + 1) * batch_size, data_len)
        yield x_shuffle[start_id:end_id], y_shuffle[start_id:end_id]

3、ternsorflow中的RNN

RNN在tensorflow中有靜態RNN,動態RNN之分。兩者差異挺大,我們在使用tensorflow進行RNN實踐時,主要注意以下幾點:

  1. 靜態RNN一般需要將所有句子padding成等長處理,這點與TextCNN一樣的,但動態rnn稍顯靈活一點,動態RNN中,只要一個batch中的所有句子等長就可以;
  2. 靜態RNN的輸入與輸出是list或二維張量;動態RNN中輸入輸出的是三維張量,相對與TextCNN,少了一維;
  3. 靜態RNN生成過程所需的時間更長,網路所佔記憶體會更大,但模型中會帶有每個序列的中間資訊,利於除錯;動態RNN生成過程所需時間相對少,所佔記憶體相對更小,但模型中只有最後的狀態。

本文介紹使用動態RNN進行文字分類。

(1)我們首先要定義模型

class TRNNConfig(object):
    """RNN配置引數"""

    # 模型引數
    embedding_dim = 100      # 詞向量維度
    seq_length = 600        # 序列長度
    num_classes = 20        # 類別數
    vocab_size = 183664       # 詞彙總數

    num_layers= 2           # 隱藏層層數
    hidden_dim = 128        # 隱藏層神經元
    rnn = 'gru'             # lstm 或 gru

    dropout_keep_prob = 0.8 # dropout保留比例
    learning_rate = 1e-3    # 學習率

    batch_size = 128         # 每批訓練大小
    num_epochs = 10         # 總迭代輪次

    print_per_batch = 20    # 每多少輪輸出一次結果
    save_per_batch = 10      # 每多少輪存入tensorboard


class TextRNN(object):
    """文字分類,RNN模型"""
    def __init__(self, config):
        self.config = config

        # 三個待輸入的資料
        self.input_x = tf.placeholder(tf.int32, [None, self.config.seq_length], name='input_x')
        self.input_y = tf.placeholder(tf.float32, [None, self.config.num_classes], name='input_y')
        self.keep_prob = tf.placeholder(tf.float32, name='keep_prob')

        self.rnn()

    def rnn(self):
        """rnn模型"""

        def lstm_cell():   # lstm核
            return tf.contrib.rnn.BasicLSTMCell(self.config.hidden_dim, state_is_tuple=True)

        def gru_cell():  # gru核
            return tf.contrib.rnn.GRUCell(self.config.hidden_dim)

        def dropout(): # 為每一個rnn核後面加一個dropout層
            if (self.config.rnn == 'lstm'):
                cell = lstm_cell()
            else:
                cell = gru_cell()
            return tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=self.keep_prob)

        # 詞向量對映
        with tf.device('/cpu:0'):
            embedding = tf.get_variable('embedding', [self.config.vocab_size, self.config.embedding_dim])
            embedding_inputs = tf.nn.embedding_lookup(embedding, self.input_x)

        with tf.name_scope("rnn"):
            # 多層rnn網路
            cells = [dropout() for _ in range(self.config.num_layers)]
            rnn_cell = tf.contrib.rnn.MultiRNNCell(cells, state_is_tuple=True)

            _outputs, _ = tf.nn.dynamic_rnn(cell=rnn_cell, inputs=embedding_inputs, dtype=tf.float32)
            last = _outputs[:, -1, :]  # 取最後一個時序輸出作為結果

        with tf.name_scope("score"):
            # 全連線層,後面接dropout以及relu啟用
            fc = tf.layers.dense(last, self.config.hidden_dim, name='fc1')
            fc = tf.contrib.layers.dropout(fc, self.keep_prob)
            fc = tf.nn.relu(fc)

            # 分類器
            self.logits = tf.layers.dense(fc, self.config.num_classes, name='fc2')
            self.y_pred_cls = tf.argmax(tf.nn.softmax(self.logits), 1)  # 預測類別

        with tf.name_scope("optimize"):
            # 損失函式,交叉熵
            cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=self.input_y)
            self.loss = tf.reduce_mean(cross_entropy)
            # 優化器
            self.optim = tf.train.AdamOptimizer(learning_rate=self.config.learning_rate).minimize(self.loss)

        with tf.name_scope("accuracy"):
            # 準確率
            correct_pred = tf.equal(tf.argmax(self.input_y, 1), self.y_pred_cls)
            self.acc = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

模型大致結構如下:

(2)定義一些輔助函式

def evaluate(sess, x_, y_):
    """評估在某一資料上的準確率和損失"""
    data_len = len(x_)
    batch_eval = batch_iter(x_, y_, 128)
    total_loss = 0.0
    total_acc = 0.0
    for x_batch, y_batch in batch_eval:
        batch_len = len(x_batch)
        feed_dict = feed_data(x_batch, y_batch, 1.0)
        loss, acc = sess.run([model.loss, model.acc], feed_dict=feed_dict)
        total_loss += loss * batch_len
        total_acc += acc * batch_len

    return total_loss / data_len, total_acc / data_len

def get_time_dif(start_time):
    """獲取已使用時間"""
    end_time = time.time()
    time_dif = end_time - start_time
    return timedelta(seconds=int(round(time_dif)))


def feed_data(x_batch, y_batch, keep_prob):
    feed_dict = {
        model.input_x: x_batch,
        model.input_y: y_batch,
        model.keep_prob: keep_prob
    }
    return feed_dict

(3)定義訓練主函式

def train():
    print("Configuring TensorBoard and Saver...")
    # 配置 Tensorboard,重新訓練時,請將tensorboard資料夾刪除,不然圖會覆蓋
    tensorboard_dir = 'tensorboard/textrnn'
    if not os.path.exists(tensorboard_dir):
        os.makedirs(tensorboard_dir)

    tf.summary.scalar("loss", model.loss)
    tf.summary.scalar("accuracy", model.acc)
    merged_summary = tf.summary.merge_all()
    writer = tf.summary.FileWriter(tensorboard_dir)
    
    save_dir = 'checkpoints/textrnn'
    save_path = os.path.join(save_dir, 'best_validation')  # 最佳驗證結果儲存路徑
    # 配置 Saver
    saver = tf.train.Saver()
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    print("Loading training and validation data...")
    # 載入訓練集與驗證集
    start_time = time.time()
    train_dir = '/content/drive/My Drive/NLP/dataset/Fudan/data/train_clean_jieba.txt'
    val_dir = '/content/drive/My Drive/NLP/dataset/Fudan/data/test_clean_jieba.txt'
    x_train, y_train = process(train_dir, config.seq_length)
    x_val, y_val = process(val_dir, config.seq_length)
    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)

    # 建立session
    session = tf.Session()
    session.run(tf.global_variables_initializer())
    writer.add_graph(session.graph)

    print('Training and evaluating...')
    start_time = time.time()
    total_batch = 0  # 總批次
    best_acc_val = 0.0  # 最佳驗證集準確率
    last_improved = 0  # 記錄上一次提升批次
    require_improvement = 1000  # 如果超過1000輪未提升,提前結束訓練

    flag = False
    for epoch in range(config.num_epochs):
        print('Epoch:', epoch + 1)
        batch_train = batch_iter(x_train, y_train, config.batch_size)
        for x_batch, y_batch in batch_train:
            feed_dict = feed_data(x_batch, y_batch, config.dropout_keep_prob)

            if total_batch % config.save_per_batch == 0:
                # 每多少輪次將訓練結果寫入tensorboard scalar
                s = session.run(merged_summary, feed_dict=feed_dict)
                writer.add_summary(s, total_batch)

            if total_batch % config.print_per_batch == 0:
                # 每多少輪次輸出在訓練集和驗證集上的效能
                feed_dict[model.keep_prob] = 1.0
                loss_train, acc_train = session.run([model.loss, model.acc], feed_dict=feed_dict)
                loss_val, acc_val = evaluate(session, x_val, y_val)  # todo

                if acc_val > best_acc_val:
                    # 儲存最好結果
                    best_acc_val = acc_val
                    last_improved = total_batch
                    saver.save(sess=session, save_path=save_path)
                    improved_str = '*'
                else:
                    improved_str = ''

                time_dif = get_time_dif(start_time)
                msg = 'Iter: {0:>6}, Train Loss: {1:>6.2}, Train Acc: {2:>7.2%},' \
                      + ' Val Loss: {3:>6.2}, Val Acc: {4:>7.2%}, Time: {5} {6}'
                print(msg.format(total_batch, loss_train, acc_train, loss_val, acc_val, time_dif, improved_str))
            
            feed_dict[model.keep_prob] = config.dropout_keep_prob
            session.run(model.optim, feed_dict=feed_dict)  # 執行優化
            total_batch += 1

            if total_batch - last_improved > require_improvement:
                # 驗證集正確率長期不提升,提前結束訓練
                print("No optimization for a long time, auto-stopping...")
                flag = True
                break  # 跳出迴圈
        if flag:  # 同上
            break
if __name__ == '__main__':
  print('Configuring RNN model...')
  config = TRNNConfig()
  model = TextRNN(config)
  train()

執行部分結果:

Epoch: 8
Iter:    540, Train Loss:   0.25, Train Acc:  92.19%, Val Loss:   0.62, Val Acc:  83.12%, Time: 0:22:00 
Iter:    560, Train Loss:   0.28, Train Acc:  91.41%, Val Loss:   0.61, Val Acc:  84.18%, Time: 0:22:48 
Iter:    580, Train Loss:   0.25, Train Acc:  91.41%, Val Loss:   0.59, Val Acc:  84.61%, Time: 0:23:36 *
Iter:    600, Train Loss:   0.39, Train Acc:  89.06%, Val Loss:   0.62, Val Acc:  83.94%, Time: 0:24:24 
Epoch: 9
Iter:    620, Train Loss:   0.17, Train Acc:  95.31%, Val Loss:   0.59, Val Acc:  84.75%, Time: 0:25:12 *
Iter:    640, Train Loss:   0.24, Train Acc:  92.97%, Val Loss:   0.57, Val Acc:  85.21%, Time: 0:26:00 *
Iter:    660, Train Loss:   0.23, Train Acc:  94.53%, Val Loss:   0.61, Val Acc:  83.84%, Time: 0:26:47 
Iter:    680, Train Loss:   0.33, Train Acc:  90.62%, Val Loss:    0.6, Val Acc:  85.02%, Time: 0:27:35 
Epoch: 10
Iter:    700, Train Loss:   0.23, Train Acc:  92.97%, Val Loss:   0.63, Val Acc:  83.92%, Time: 0:28:22 
Iter:    720, Train Loss:   0.29, Train Acc:  92.97%, Val Loss:   0.59, Val Acc:  85.37%, Time: 0:29:10 *
Iter:    740, Train Loss:   0.13, Train Acc:  96.09%, Val Loss:   0.59, Val Acc:  84.92%, Time: 0:29:57 
Iter:    760, Train Loss:   0.32, Train Acc:  91.41%, Val Loss:   0.62, Val Acc:  84.72%, Time: 0:30:44 

在tensorboard視覺化結果:

同時會生成儲存的檔案:

進行測試,這裡我們的測試集和驗證集是同樣的:

def test():
  print("Loading test data...")
  start_time = time.time()
  test_dir = '/content/drive/My Drive/NLP/dataset/Fudan/data/test_clean_jieba.txt'
  x_test, y_test = process(test_dir, config.seq_length)
  save_path = 'checkpoint/textrnn/best_validation'
  session = tf.Session()
  session.run(tf.global_variables_initializer())
  saver = tf.train.Saver()
  saver.restore(sess=session, save_path=save_path)  # 讀取儲存的模型

  print('Testing...')
  loss_test, acc_test = evaluate(session, x_test, y_test)
  msg = 'Test Loss: {0:>6.2}, Test Acc: {1:>7.2%}'
  print(msg.format(loss_test, acc_test))

  batch_size = 128
  data_len = len(x_test)
  num_batch = int((data_len - 1) / batch_size) + 1

  y_test_cls = np.argmax(y_test, 1)
  y_pred_cls = np.zeros(shape=len(x_test), dtype=np.int32)  # 儲存預測結果
  for i in range(num_batch):  # 逐批次處理
      start_id = i * batch_size
      end_id = min((i + 1) * batch_size, data_len)
      feed_dict = {
          model.input_x: x_test[start_id:end_id],
          model.keep_prob: 1.0
      }
      y_pred_cls[start_id:end_id] = session.run(model.y_pred_cls, feed_dict=feed_dict)

  # 評估
  print("Precision, Recall and F1-Score...")
  categories = get_label_id().values()
  print(metrics.classification_report(y_test_cls, y_pred_cls, target_names=categories))

  # 混淆矩陣
  print("Confusion Matrix...")
  cm = metrics.confusion_matrix(y_test_cls, y_pred_cls)
  print(cm)

  time_dif = get_time_dif(start_time)
  print("Time usage:", time_dif)
if __name__ == '__main__':
  print('Configuring RNN model...')
  config = TRNNConfig()
  model = TextRNN(config)
  test()

結果:這裡9833是因為最後面多出了一行空行

Test Loss:   0.61, Test Acc:  84.53%
Precision, Recall and F1-Score...
/usr/local/lib/python3.6/dist-packages/sklearn/metrics/_classification.py:1272: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
  _warn_prf(average, modifier, msg_start, len(result))
              precision    recall  f1-score   support

           0       0.00      0.00      0.00        61
           1       0.87      0.90      0.88      1022
           2       0.28      0.32      0.30        59
           3       0.87      0.91      0.89      1254
           4       0.60      0.40      0.48        52
           5       0.74      0.88      0.80      1026
           6       0.95      0.94      0.94      1358
           7       0.50      0.02      0.04        45
           8       0.40      0.24      0.30        76
           9       0.84      0.88      0.86       742
          10       0.60      0.09      0.15        34
          11       0.00      0.00      0.00        28
          12       0.91      0.92      0.92      1218
          13       0.85      0.85      0.85       642
          14       0.36      0.12      0.18        33
          15       0.44      0.15      0.22        27
          16       0.88      0.88      0.88      1601
          17       0.27      0.45      0.34        53
          18       0.33      0.12      0.17        34
          19       0.65      0.52      0.58       468

    accuracy                           0.85      9833
   macro avg       0.57      0.48      0.49      9833
weighted avg       0.83      0.85      0.84      9833

Confusion Matrix...
[[   0    3    2   43    0    3    0    0    1    1    0    0    0    1
     0    0    2    0    0    5]
 [   0  916    0   13    0    6    0    0    0    1    0    0   21    0
     0    0   49    8    2    6]
 [   0    2   19    2    1    1    3    0    1    0    0    0    5    5
     2    2    1   13    1    1]
 [   0    8    1 1147    0   45    1    0    2    7    0    0    4    5
     0    0   12    3    1   18]
 [   0    2    1    5   21    4    2    0    1    3    0    0    2    1
     0    0    6    2    0    2]
 [   0    4    0   23    1  898    0    0    3   13    0    0    0    0
     0    0   67    0    1   16]
 [   0    0    1    9    0    1 1278    0    0    8    1    0    6   46
     0    0    7    1    0    0]
 [   0    0    1    9    0   16    1    1    0   11    0    0    0    0
     0    1    2    0    0    3]
 [   0    1    3    7    0   23    1    0   18    2    0    0    0    2
     1    0    1    3    0   14]
 [   0    0    0    2    2   29    2    0    1  651    1    0    0    0
     0    0    3    1    0   50]
 [   0    0    0    1    0    4    0    1    2   15    3    0    0    0
     0    0    2    1    0    5]
 [   0    0    0    3    0    1    4    0    0    0    0    0    5    6
     0    0    6    3    0    0]
 [   0   32    5    5    3    0   15    0    0    0    0    0 1117   13
     1    1   21    3    2    0]
 [   0    6   15    8    3    0   33    0    4    1    0    0   18  546
     0    0    0    8    0    0]
 [   0    2    2    0    1    2    0    0    0    1    0    0   11    6
     4    0    3    0    0    1]
 [   0    0    0    2    0    1    8    0    2    0    0    0    2    6
     0    4    1    0    0    1]
 [   0   59    3   21    1   55    3    0    3    2    0    0   25    0
     2    0 1416    5    1    5]
 [   0    7    9    4    0    1    0    0    3    0    0    0    0    0
     0    0    2   24    0    3]
 [   0    4    5    0    1    2    0    0    1    0    0    0    5    0
     1    0    2    8    4    1]
 [   0    4    1   15    1  118    0    0    3   61    0    0    0    2
     0    1   10    7    0  245]]
Time usage: 0:01:01

上面的模型是沒有加入到我們預先訓練好的詞向量的,接下來,我們要將自己的詞向量匯入到模型中,再進行訓練。

4、將詞向量加入到網路中

首先我們需要對詞向量進行處理:生成一個詞嵌入,然後將詞向量賦值給對應的位置

import numpy as np
def export_word2vec_vectors():
  word2vec_dir = '/content/drive/My Drive/NLP/dataset/Fudan/vector.txt'
  trimmed_filename = '/content/drive/My Drive/NLP/dataset/Fudan/vector_word.npz'
  file_r = open(word2vec_dir, 'r', encoding='utf-8')
  #(183664,100)
  lines = file_r.readlines()
  embeddings = np.zeros([183664, 100])
  for i,vec in enumerate(lines):
    vec = vec.strip().split(" ")
    vec = np.asarray(vec,dtype='float32')
    embeddings[i] = vec
  np.savez_compressed(trimmed_filename, embeddings=embeddings)
export_word2vec_vectors()

之後用這種方式進行讀取:

def get_training_word2vec_vectors(filename):
  with np.load(filename) as data:
      return data["embeddings"]

接下來看看我們需要修改的地方:

在模型配置檔案中加入:

    pre_trianing = None
    vector_word_npz = '/content/drive/My Drive/NLP/dataset/Fudan/vector_word.npz'

在模型中修改:

#embedding = tf.get_variable('embedding', [self.config.vocab_size, self.config.embedding_dim])
            embedding = tf.get_variable("embeddings", shape=[self.config.vocab_size, self.config.embedding_dim],
                                             initializer=tf.constant_initializer(self.config.pre_trianing))
            embedding_inputs = tf.nn.embedding_lookup(embedding, self.input_x)

在main中修改:

if __name__ == '__main__':
  print('Configuring RNN model...')
  config = TRNNConfig()
  config.pre_trianing = get_training_word2vec_vectors(config.vector_word_npz)
  model = TextRNN(config)
  train()

然後我們執行:

Epoch: 8
Iter:    540, Train Loss:   0.17, Train Acc:  92.97%, Val Loss:   0.44, Val Acc:  87.80%, Time: 0:22:14 
Iter:    560, Train Loss:   0.17, Train Acc:  96.09%, Val Loss:   0.39, Val Acc:  89.10%, Time: 0:23:04 *
Iter:    580, Train Loss:   0.14, Train Acc:  94.53%, Val Loss:    0.4, Val Acc:  88.71%, Time: 0:23:51 
Iter:    600, Train Loss:   0.16, Train Acc:  92.97%, Val Loss:   0.39, Val Acc:  89.10%, Time: 0:24:37 
Epoch: 9
Iter:    620, Train Loss:   0.14, Train Acc:  93.75%, Val Loss:    0.4, Val Acc:  88.78%, Time: 0:25:25 
Iter:    640, Train Loss:   0.16, Train Acc:  96.09%, Val Loss:   0.42, Val Acc:  88.67%, Time: 0:26:13 
Iter:    660, Train Loss:   0.13, Train Acc:  96.09%, Val Loss:   0.42, Val Acc:  88.95%, Time: 0:26:59 
Iter:    680, Train Loss:   0.18, Train Acc:  94.53%, Val Loss:    0.4, Val Acc:  89.17%, Time: 0:27:47 *
Epoch: 10
Iter:    700, Train Loss:   0.19, Train Acc:  94.53%, Val Loss:   0.43, Val Acc:  89.06%, Time: 0:28:35 
Iter:    720, Train Loss:  0.046, Train Acc:  98.44%, Val Loss:    0.4, Val Acc:  89.72%, Time: 0:29:22 *
Iter:    740, Train Loss:   0.11, Train Acc:  96.09%, Val Loss:   0.44, Val Acc:  88.86%, Time: 0:30:10 
Iter:    760, Train Loss:  0.059, Train Acc:  97.66%, Val Loss:   0.39, Val Acc:  89.47%, Time: 0:30:57 

再進行測試:

Test Loss:    0.4, Test Acc:  89.72%
Precision, Recall and F1-Score...
              precision    recall  f1-score   support

           0       0.48      0.38      0.42        61
           1       0.93      0.91      0.92      1022
           2       0.58      0.51      0.54        59
           3       0.95      0.93      0.94      1254
           4       0.75      0.40      0.53        52
           5       0.87      0.91      0.89      1026
           6       0.93      0.98      0.96      1358
           7       0.41      0.31      0.35        45
           8       0.64      0.57      0.60        76
           9       0.89      0.91      0.90       742
          10       0.57      0.12      0.20        34
          11       0.36      0.18      0.24        28
          12       0.94      0.95      0.95      1218
          13       0.93      0.92      0.92       642
          14       0.42      0.15      0.22        33
          15       0.33      0.07      0.12        27
          16       0.90      0.94      0.92      1601
          17       0.56      0.60      0.58        53
          18       0.36      0.15      0.21        34
          19       0.75      0.74      0.75       468

    accuracy                           0.90      9833
   macro avg       0.68      0.58      0.61      9833
weighted avg       0.89      0.90      0.89      9833

Confusion Matrix...
[[  23    0    0   17    0    2    1    1    0    5    0    0    2    1
     0    0    3    6    0    0]
 [   0  926    0    0    0    3    0    0    0    0    0    0    7    1
     0    0   72    1    0   12]
 [   0    1   30    0    1    0   13    0    0    0    0    1    0    5
     0    1    6    1    0    0]
 [   8    6    0 1165    0   21    4    0    1   14    0    0    8    3
     0    0    8    3    0   13]
 [   0    0    4    0   21    5    4    0    3    0    0    1    4    0
     0    1    9    0    0    0]
 [   3    5    0   12    2  932    0    6   11    4    0    0    3    0
     0    0   28    1    0   19]
 [   0    0    1    1    0    0 1336    0    0    0    0    3    3   12
     0    0    2    0    0    0]
 [   3    0    0   10    0    8    0   14    0    6    0    0    0    1
     0    0    1    0    0    2]
 [   1    1    2    0    0   15    2    0   43    0    0    0    0    3
     0    0    0    8    0    1]
 [   0    0    1    2    1    0    2    5    1  675    3    0    0    0
     0    0    1    0    0   51]
 [   0    0    0    2    0    2    0    4    2   10    4    0    0    0
     0    0    1    0    0    9]
 [   0    0    1    1    0    0    9    0    0    0    0    5    0    6
     0    1    4    1    0    0]
 [   1   14    0    0    0    2   13    0    2    0    0    0 1161    5
     0    0   17    0    3    0]
 [   0    6    1    3    0    0   28    0    0    1    0    0   12  589
     0    0    1    1    0    0]
 [   0    1    2    0    0    1    0    0    0    0    0    1   14    2
     5    0    4    0    3    0]
 [   0    0    6    0    0    1   12    0    1    0    0    1    0    2
     0    2    2    0    0    0]
 [   1   27    3    4    2   32    3    3    0    0    0    0    4    0
     1    1 1509    3    3    5]
 [   8    2    0    3    1    1    0    0    0    0    0    1    2    0
     1    0    2   32    0    0]
 [   0    1    1    0    0    0    1    0    0    0    0    1   12    2
     5    0    6    0    5    0]
 [   0    4    0    5    0   48    4    1    3   46    0    0    0    4
     0    0    8    0    0  345]]
Time usage: 0:01:02

使用了我們預先訓練的詞向量之後,發現比隨機生成的詞向量相比,確實能夠提升網路的效能。

最後做個總結:

使用RNN進行文字分類的過程如下:

  • 獲取資料;
  • 無論資料是什麼格式的,我們需要對其進行分詞(去掉停用詞)可以根據頻率進行選擇前N個詞(可選);
  • 我們需要所有詞,並對它們進行編號;
  • 訓練詞向量(可選),要將訓練好的向量和詞編號進行對應;
  • 將資料集中的句子中的每個詞用編號代替,對標籤也進行編號,讓標籤和標籤編號對應;
  • 文字可使用keras限制它的最大長度,標籤進行onehot編碼;
  • 讀取資料集(文字和標籤),然後構建batchsize
  • 搭建模型並進行訓練和測試;

至此從資料的處理到文字分類的整個流程就已經全部完成了,接下來還是對該資料集,使用CNN進行訓練和測試。歡迎關注我的微信公眾號-西西嘛呦,它不橡部落格園發表那樣長篇大論的文章,只希望能夠帶給你有用的知識。

參考:

https://www.jianshu.com/p/cd9563a3f6c9

https://github.com/cjymz886/text-cnn

https://github.com/gaussic/text-classification-cnn-rnn/