keras搭建神經網路分類新聞主題
阿新 • • 發佈:2018-11-07
from keras.datasets import reuters import numpy as np from keras import models from keras import layers from keras.optimizers import RMSprop from keras.losses import categorical_crossentropy ''' 資料集介紹: 用路透社資料集,它包含許多短新聞及其對應的主題,由路透社在 1986 年釋出。它 是一個簡單的、廣泛使用的文字分類資料集。它包括 46 個不同的主題:某些主題的樣本更多, 但訓練集中每個主題都有至少 10 個樣本。 ''' # 載入資料(引數 num_words=10000 將資料限定為前 10 000 個最常出現的單詞) (train_data, train_labels), (test_data, test_labels) = reuters.load_data(num_words=10000) # 先來看看資料集大小(訓練集8982個,測試集2246個) print(train_data.shape) print(train_data[0]) print(train_labels) print(test_data.shape) print(test_data[0]) # 使用one-hot編碼,使列表轉化為大小一致的張量 def vectorize_sequences(sequences, dimension=10000): results = np.zeros((len(sequences), dimension)) # 建立一個sequences行,10000列的0向量 for i, sequence in enumerate(sequences): results[i, sequence] = 1 # 將單詞出現指定處置為1 return results # 對資料編碼 x_train = vectorize_sequences(train_data) x_test = vectorize_sequences(test_data) # 進行One-hot編碼 def to_one_hot(labels, dimension=46): results = np.zeros((len(labels), dimension)) for i, label in enumerate(labels): results[i, label] = 1 return results # 對標籤進行One-hot編碼 one_hot_train_labels = to_one_hot(train_labels) one_hot_test_labels = to_one_hot(test_labels) # 準備好了資料,接下來就可以構建網路啦 model = models.Sequential() model.add(layers.Dense(units=64, activation='relu', input_shape=(10000, ))) model.add(layers.Dense(units=64, activation='relu')) model.add(layers.Dense(units=46, activation='softmax')) # 一共46類 model.compile(optimizer=RMSprop(lr=0.001), loss='categorical_crossentropy', metrics=['accuracy']) # 開始訓練網路 model.fit(x_train, one_hot_train_labels, epochs=4, batch_size=128, validation_data=[x_test, one_hot_test_labels]) # 來在測試集上測試一下模型的效能吧 test_loss, test_accuracy = model.evaluate(x_test, one_hot_test_labels) print("test_loss:", test_loss, " test_accuracy:", test_accuracy) # 儲存模型 model.save('retuters.json')
幾點說明:
- 網路的最後一層是大小為 46 的 Dense 層。這意味著,對於每個輸入樣本,網路都會輸出一個 46 維向量。這個向量的每個元素(即每個維度)代表不同的輸出類別。
- 最後一層使用了 softmax 啟用。網路將輸出在 46個不同輸出類別上的概率分佈——對於每一個輸入樣本,網路都會輸出一個 46 維向量,其中 output[i] 是樣本屬於第 i 個類別的概率。46 個概率的總和為 1。對於單標籤、多分類問題,網路的最後一層應該使用 softmax 啟用,這樣可以輸出在 N個輸出類別上的概率分佈。
- 最終輸出是 46 維的,因此中間層的隱藏單元個數不應該比 46 小太多。
- 這種問題的損失函式幾乎總是應該使用分類交叉熵。它將網路輸出的概率分佈與目標的真實分佈之間的距離最小化。
結果如下: