【PyTorch】PyTorch進階教程三
阿新 • • 發佈:2019-01-22
前面介紹了使用PyTorch構造CNN網路,這一節介紹點高階的東西LSTM。
關於LSTM的理論介紹請參考兩篇有名的部落格:
以及我之前的一篇中文翻譯部落格:
LSTM
class torch.nn.LSTM(*args, **kwargs)
Parameters
- input_size 輸入特徵維數
- hidden_size 隱層狀態的維數
- num_layers RNN層的個數
- bias 隱層狀態是否帶bias,預設為true
- batch_first 是否輸入輸出的第一維為batchsize
- dropout 是否在除最後一個RNN層外的RNN層後面加dropout層
- bidirectional 是否是雙向RNN,預設為false
Inputs: input, (h_0, c_0)
- input (seq_len, batch, input_size) 包含特徵的輸入序列,如果設定了batch_first,則batch為第一維
- (h_0, c_0) 隱層狀態
Outputs: output, (h_n, c_n)
- output (seq_len, batch, hidden_size * num_directions) 包含每一個時刻的輸出特徵,如果設定了batch_first,則batch為第一維
- (h_n, c_n) 隱層狀態
Model
class RNN(nn.Module):
def __init__(self, input_size, hidden_size, num_layers, num_classes):
super(RNN, self).__init__()
self.hidden_size = hidden_size
self.num_layers = num_layers
self.lstm = nn.LSTM(input_size, hidden_size, num_layers,
batch_first=True )
self.fc = nn.Linear(hidden_size, num_classes) # 2 for bidirection
def forward(self, x):
# Forward propagate RNN
out, _ = self.lstm(x)
# Decode hidden state of last time step
out = self.fc(out[:, -1, :])
return out
rnn = RNN(input_size, hidden_size, num_layers, num_classes)
rnn.cuda()
PyTorch中實現LSTM是十分方便的,只需要定義輸入維度,隱層維度,RNN個數,以及分類個數就可以了。lstm的輸入狀態如果為空的話,表示預設初始化為0。在MNIST上,只需要2個epoch就可以達到97%的正確率。