TensorFlow常見模型實現之一(LSTM/BiLSTM)
阿新 • 發佈:2018-12-15
1. LSTM
import tensorflow as tf
import tensorflow.contrib as contrib
from tensorflow.python.ops import array_ops


class lstm(object):
    """Unidirectional LSTM encoder built with ``tf.nn.dynamic_rnn`` (TF 1.x).

    The selected view of the RNN outputs is stored in ``self.out``.

    Args:
        in_data: input tensor fed to ``tf.nn.dynamic_rnn``. ``time_major`` is
            left at its default (False), so this is expected to be
            batch-major, i.e. ``(batch, time, depth)``.
        hidden_dim: number of units in the LSTM cell.
        batch_seqlen: optional per-example sequence lengths, forwarded as
            ``sequence_length``. Outputs past each length are zero-filled
            by ``dynamic_rnn``.
        flag: which view of the outputs to expose as ``self.out``:
            ``'all_ht'``   -- the outputs at every time step;
            ``'first_ht'`` -- the output at time step 0;
            ``'last_ht'``  -- the final hidden state ``h``, i.e. the output
                              at each example's last valid step;
            ``'concat'``   -- ``first_ht`` and ``last_ht`` joined on axis 1.

    Raises:
        ValueError: if ``flag`` is not one of the values listed above.
    """

    _FLAGS = ('all_ht', 'first_ht', 'last_ht', 'concat')

    def __init__(self, in_data, hidden_dim, batch_seqlen=None, flag='concat'):
        # Validate before adding any ops to the graph. Previously an unknown
        # flag silently left ``self.out`` undefined.
        if flag not in self._FLAGS:
            raise ValueError(
                "flag must be one of {}, got {!r}".format(self._FLAGS, flag))

        self.in_data = in_data
        self.hidden_dim = hidden_dim
        self.batch_seqlen = batch_seqlen
        self.flag = flag

        lstm_cell = contrib.rnn.LSTMCell(self.hidden_dim)
        out, state = tf.nn.dynamic_rnn(
            cell=lstm_cell,
            inputs=self.in_data,
            sequence_length=self.batch_seqlen,
            dtype=tf.float32)

        if flag == 'all_ht':
            self.out = out
        elif flag == 'first_ht':
            self.out = out[:, 0, :]
        elif flag == 'last_ht':
            # out[:, -1, :] would be all zeros for padded rows when
            # sequence_length is given. The final state is copied through
            # from each example's last valid step, so use its h instead.
            # This equals out[:, -1, :] when no lengths are passed.
            self.out = state.h
        else:  # 'concat'
            self.out = tf.concat([out[:, 0, :], state.h], 1)
2. Bi-LSTM
import tensorflow as tf
import tensorflow.contrib as contrib
from tensorflow.python.ops import array_ops
from tensorflow.python.framework import dtypes


class bilstm(object):
    """Bidirectional LSTM encoder built with ``tf.nn.bidirectional_dynamic_rnn``.

    Targets TF 1.x. The forward and backward outputs are concatenated on the
    feature axis (axis 2). The selected view is stored in ``self.out``.

    Args:
        in_data: input tensor fed to ``bidirectional_dynamic_rnn``.
            ``time_major`` is left at its default (False), so this is
            expected to be batch-major, i.e. ``(batch, time, depth)``.
        hidden_dim: number of units in each of the forward/backward cells.
        batch_seqlen: optional per-example sequence lengths, forwarded as
            ``sequence_length``.
        flag: which view to expose as ``self.out``:
            ``'all_ht'``   -- concatenated fw/bw outputs at every time step;
            ``'first_ht'`` -- concatenated fw/bw outputs at time step 0;
            ``'last_ht'``  -- the forward and backward final states' ``h``,
                              concatenated on axis 1;
            ``'concat'``   -- ``first_ht`` and ``last_ht`` joined on axis 1.

    Raises:
        ValueError: if ``flag`` is not one of the values listed above.
    """

    _FLAGS = ('all_ht', 'first_ht', 'last_ht', 'concat')

    def __init__(self, in_data, hidden_dim, batch_seqlen=None, flag='concat'):
        # Validate up front. Previously an unknown flag silently left
        # ``self.out`` undefined.
        if flag not in self._FLAGS:
            raise ValueError(
                "flag must be one of {}, got {!r}".format(self._FLAGS, flag))

        self.in_data = in_data
        self.hidden_dim = hidden_dim
        self.batch_seqlen = batch_seqlen
        self.flag = flag

        lstm_cell_fw = contrib.rnn.LSTMCell(self.hidden_dim)
        lstm_cell_bw = contrib.rnn.LSTMCell(self.hidden_dim)
        # The keyword was misspelled ``sequence_lenth``. The API has no such
        # parameter, so every call raised TypeError.
        out, state = tf.nn.bidirectional_dynamic_rnn(
            cell_fw=lstm_cell_fw,
            cell_bw=lstm_cell_bw,
            inputs=self.in_data,
            sequence_length=self.batch_seqlen,
            dtype=tf.float32)

        # out is (output_fw, output_bw): join them on the feature axis.
        bi_out = tf.concat(out, 2)

        if flag == 'all_ht':
            self.out = bi_out
        elif flag == 'first_ht':
            self.out = bi_out[:, 0, :]
        else:
            # state is (state_fw, state_bw), one final LSTM state per
            # direction. Take each one's hidden vector h.
            last_ht = tf.concat([state[0].h, state[1].h], 1)
            if flag == 'last_ht':
                self.out = last_ht
            else:  # 'concat'
                self.out = tf.concat([bi_out[:, 0, :], last_ht], 1)