pytorch + visdom RNN分類手寫數字(MNIST)
阿新 • • 發佈:2019-02-15
環境
系統:win10
cpu:i7-6700HQ
gpu:gtx965m
python : 3.6
pytorch :0.3
RNN
簡單來講,rnn與傳統的神經網路相比,最大的優勢是隱藏層的神經單元互通,隱藏層會對之前的輸出做整合,也就是做能記錄之前的幾個狀態,於是具有了聯絡上下文的能力,但是也僅僅只能 記錄幾個狀態而已,因此有一定的應用侷限,LSTM解決這個問題。
LSTM
我們來用LSTM 對 MNIST進行分類,我們知道圖片其實也有由數字組成,每一個畫素點就是由數字組成,黑白圖片是一層,彩色圖片是rgb三層,組成影象的數字從上到下也是有順序的,屬於連續的序列,因此根據序列的連續上下內容我們也可以根據這戲資訊對資料做分類。
資料
MNIST 手寫數字資料集
train_dataset = datasets.MNIST('./mnist', True, transforms.ToTensor(), download=False)
test_dataset = datasets.MNIST('./mnist', False, transforms.ToTensor())
train_loader = DataLoader(train_dataset, BATCH_SIZE, True)
# 縮短測試時間,只取2000個數據
test_data = test_dataset.test_data[:2000]
test_label =test_dataset.test_labels[:2000 ]
# 資料視覺化
viz.images(torch.unsqueeze(test_data[:25], 1), nrow=5)
資料是28x28的大小,我們可以將每一個圖片理解為28個序列,每個序列28維度。
def __init__(self, in_dim, n_class):
super(RNN, self).__init__()
self.rnn = nn.LSTM(
# 輸入28 ,輸出 64
input_size=in_dim,
hidden_size=64,
# 神經層 2 層
num_layers=2 ,
# out(batch, 序列長度, 維度)
batch_first=True)
# n_class為輸出的分類數
self.cf = nn.Linear(64, n_class)
def forward(self, x):
# RNN(LSTM)的輸出是output 和 hidden 這裡只取 output
out = self.rnn(x)[0]
# 由於LSTM是根據具有對序列的記憶能力,我們只輸出序列的最後一位,來判斷
out = out[:, -1, :]
out = self.cf(out)
return out
執行20個epoch看看,,結果:
epoch: [20/20] | Loss: 0.0228 | TR_acc: 0.9910 | TS_acc: 0.9835 | Time: 377.4
epoch: [20/20] | Loss: 0.0221 | TR_acc: 0.9926 | TS_acc: 0.9860 | Time: 378.1
epoch: [20/20] | Loss: 0.0203 | TR_acc: 0.9934 | TS_acc: 0.9820 | Time: 378.8
準確率還可以,由於CNN(LSTM)的這種處理時間序列上根據有優勢,於是多由於處理自然語言方面。