tensorflow學習之MultiRNNCell詳解
阿新 • • 發佈:2018-12-11
tf.contrib.rnn.MultiRNNCell
Aliases:
- Class tf.contrib.rnn.MultiRNNCell
- Class tf.nn.rnn_cell.MultiRNNCell
由多個簡單的cells組成的RNN cell。用於構建多層迴圈神經網路。
__init__( cells, state_is_tuple=True ) |
引數:
- cells:RNNCells的list。
- state_is_tuple:如果為True,接受和返回的states是n-tuples,其中n=len(cells)。如果為False,states是concatenated沿著列軸.後者即將棄用。
程式碼例項:
import tensorflow as tf batch_size=10 depth=128 inputs=tf.Variable(tf.random_normal([batch_size,depth])) previous_state0=(tf.random_normal([batch_size,100]),tf.random_normal([batch_size,100])) previous_state1=(tf.random_normal([batch_size,200]),tf.random_normal([batch_size,200])) previous_state2=(tf.random_normal([batch_size,300]),tf.random_normal([batch_size,300])) num_units=[100,200,300] print(inputs) cells=[tf.nn.rnn_cell.BasicLSTMCell(num_unit) for num_unit in num_units] mul_cells=tf.nn.rnn_cell.MultiRNNCell(cells) outputs,states=mul_cells(inputs,(previous_state0,previous_state1,previous_state2)) print(outputs.shape) #(10, 300) print(states[0]) #第一層LSTM print(states[1]) #第二層LSTM print(states[2]) ##第三層LSTM print(states[0].h.shape) #第一層LSTM的h狀態,(10, 100) print(states[0].c.shape) #第一層LSTM的c狀態,(10, 100) print(states[1].h.shape) #第二層LSTM的h狀態,(10, 200)