1. 程式人生 > 實用技巧 >二、LSTM處理不定長句子

二、LSTM處理不定長句子

import torch
from torch import nn
import torch.nn.utils.rnn as rnn_utils
from torch.utils.data import DataLoader
import torch.utils.data as data

# Toy variable-length batch: two sequences of 2-d feature vectors
# (lengths 3 and 1) and one binary label per sequence.
x1 = [
    torch.full((3, 2), 6.0),  # sequence of length 3, all sixes
    torch.full((1, 2), 7.0),  # sequence of length 1, all sevens
]
y = [torch.tensor([1]), torch.tensor([0])]



class MyData(data.Dataset):
    """Dataset over parallel lists of variable-length sequences and labels.

    Each item is a tuple ``(sequence, label)`` where ``sequence`` is a
    ``(seq_len_i, feature_dim)`` float tensor and ``label`` is a tensor.
    """

    def __init__(self, data_seq, y):
        self.data_seq = data_seq  # list of (seq_len_i, feature_dim) tensors
        self.y = y                # list of label tensors, one per sequence

    def __len__(self):
        return len(self.data_seq)

    def __getitem__(self, idx):
        return self.data_seq[idx], self.y[idx]


def collate_fn(data_tuple):
    """Collate a batch of (sequence, label) pairs into padded tensors.

    Sorts the batch by descending sequence length — required because
    ``pack_padded_sequence`` defaults to ``enforce_sorted=True`` — then
    zero-pads sequences and labels to a common length.

    Returns:
        (padded_seqs, padded_labels, lengths) where ``lengths`` is the
        list of original sequence lengths in sorted order.
    """
    data_tuple.sort(key=lambda pair: len(pair[0]), reverse=True)
    seqs = [pair[0] for pair in data_tuple]
    labels = [pair[1] for pair in data_tuple]
    lengths = [len(s) for s in seqs]
    padded = rnn_utils.pad_sequence(seqs, batch_first=True, padding_value=0.0)
    labels = rnn_utils.pad_sequence(labels, batch_first=True, padding_value=0.0)
    return padded, labels, lengths


if __name__ == '__main__':
    learning_rate = 0.001
    # Renamed from `data` to avoid shadowing the torch.utils.data alias,
    # and from `y` to avoid clobbering the global label list.
    dataset = MyData(x1, y)
    data_loader = DataLoader(dataset, batch_size=2, shuffle=True,
                             collate_fn=collate_fn)
    # Fix: Python 3 iterators have no .next() method — use next(iter(...)).
    batch_x, batch_y, batch_x_len = next(iter(data_loader))
    print(batch_x)
    print(batch_x.shape)
    print(batch_x_len)
    print(batch_y)
    print(batch_y.shape)
    # Pack so the LSTM skips the zero padding; lengths must be descending
    # (guaranteed by collate_fn's sort above).
    batch_x_pack = rnn_utils.pack_padded_sequence(batch_x, batch_x_len,
                                                  batch_first=True)
    net = nn.LSTM(input_size=2, hidden_size=10, num_layers=4, batch_first=True)
    criteria = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
    print(batch_x_pack)
    out, (h1, c1) = net(batch_x_pack)