1. 程式人生 > >keras搭建神經網路分類新聞主題

keras搭建神經網路分類新聞主題

from keras.datasets import reuters
import numpy as np
from keras import models
from keras import layers
from keras.optimizers import RMSprop
from keras.losses import categorical_crossentropy
'''
資料集介紹:
用路透社資料集,它包含許多短新聞及其對應的主題,由路透社在 1986 年釋出。它
是一個簡單的、廣泛使用的文字分類資料集。它包括 46 個不同的主題:某些主題的樣本更多,
但訓練集中每個主題都有至少 10 個樣本。
'''

#  載入資料(引數 num_words=10000 將資料限定為前 10 000 個最常出現的單詞)
(train_data, train_labels), (test_data, test_labels) = reuters.load_data(num_words=10000)

#  先來看看資料集大小(訓練集8982個,測試集2246個)
print(train_data.shape)
print(train_data[0])
print(train_labels)
print(test_data.shape)
print(test_data[0])

# 使用one-hot編碼,使列表轉化為大小一致的張量
def vectorize_sequences(sequences, dimension=10000):
    results = np.zeros((len(sequences), dimension))  #  建立一個sequences行,10000列的0向量
    for i, sequence in enumerate(sequences):
        results[i, sequence] = 1  #  將單詞出現指定處置為1
    return results

#  對資料編碼
x_train = vectorize_sequences(train_data)
x_test = vectorize_sequences(test_data)

#  進行One-hot編碼
def to_one_hot(labels, dimension=46):
    results = np.zeros((len(labels), dimension))
    for i, label in enumerate(labels):
        results[i, label] = 1
    return results

#  對標籤進行One-hot編碼
one_hot_train_labels = to_one_hot(train_labels)
one_hot_test_labels = to_one_hot(test_labels)


#  準備好了資料,接下來就可以構建網路啦
model = models.Sequential()
model.add(layers.Dense(units=64, activation='relu', input_shape=(10000, )))
model.add(layers.Dense(units=64, activation='relu'))
model.add(layers.Dense(units=46, activation='softmax'))  # 一共46類

model.compile(optimizer=RMSprop(lr=0.001), loss='categorical_crossentropy', metrics=['accuracy'])


#  開始訓練網路
model.fit(x_train, one_hot_train_labels, epochs=4, batch_size=128, validation_data=[x_test, one_hot_test_labels])

# 來在測試集上測試一下模型的效能吧
test_loss, test_accuracy = model.evaluate(x_test, one_hot_test_labels)
print("test_loss:", test_loss, "    test_accuracy:", test_accuracy)

#  儲存模型
model.save('retuters.json')

幾點說明:

  • 網路的最後一層是大小為 46 的 Dense 層。這意味著,對於每個輸入樣本,網路都會輸出一個 46 維向量。這個向量的每個元素(即每個維度)代表不同的輸出類別。
  • 最後一層使用了 softmax 啟用。網路將輸出在 46個不同輸出類別上的概率分佈——對於每一個輸入樣本,網路都會輸出一個 46 維向量,其中 output[i] 是樣本屬於第 i 個類別的概率。46 個概率的總和為 1。對於單標籤、多分類問題,網路的最後一層應該使用 softmax 啟用,這樣可以輸出在 N個輸出類別上的概率分佈。
  • 最終輸出是 46 維的,因此中間層的隱藏單元個數不應該比 46 小太多。
  • 這種問題的損失函式幾乎總是應該使用分類交叉熵。它將網路輸出的概率分佈與目標的真實分佈之間的距離最小化。

結果如下:

在這裡插入圖片描述