A Detailed Look at How tf.nn.dynamic_rnn Works in an LSTM
In Tang Yudi's TensorFlow study-notes project walkthrough (LSTM sentiment analysis), available at https://blog.csdn.net/liushao123456789/article/details/78991581, the code around tf.nn.dynamic_rnn reads as follows, but each step comes without a detailed explanation. This post aims to help readers understand every step in detail, and why it is done that way.
lstmCell = tf.contrib.rnn.BasicLSTMCell(lstmUnits)
lstmCell = tf.contrib.rnn.DropoutWrapper(cell=lstmCell, output_keep_prob=0.75)
value, _ = tf.nn.dynamic_rnn(lstmCell, data, dtype=tf.float32)
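To see these three lines in context, here is a minimal runnable sketch, assuming TensorFlow 1.x; the sizes and the names batchSize, maxSeqLen, embedDim are illustrative placeholders of mine, not values from the original project:

import tensorflow as tf  # assumes TensorFlow 1.x

# hypothetical sizes, not from the original project
batchSize, maxSeqLen, embedDim = 24, 250, 50
lstmUnits = 64

# one batch of embedded sequences: [batchSize, maxSeqLen, embedDim]
data = tf.placeholder(tf.float32, [batchSize, maxSeqLen, embedDim])

lstmCell = tf.contrib.rnn.BasicLSTMCell(lstmUnits)
lstmCell = tf.contrib.rnn.DropoutWrapper(cell=lstmCell, output_keep_prob=0.75)

# dynamic_rnn returns a pair (outputs, final_state); the snippet keeps the
# per-step outputs as value and discards the final state as _
value, _ = tf.nn.dynamic_rnn(lstmCell, data, dtype=tf.float32)
print(value)  # shape = (24, 250, 64), i.e. [batchSize, maxSeqLen, lstmUnits]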
Here lstmUnits is the number of hidden units. The first two lines of the original snippet are easy enough to understand; it was the value and _ produced by the third line that left me scratching my head. Then a few more lines of code appear that confused me even further.
weight = tf.Variable(tf.truncated_normal([lstmUnits, numClasses]))
bias = tf.Variable(tf.constant(0.1, shape=[numClasses]))
value = tf.transpose(value, [1, 0, 2])
# take the output of the final step
last = tf.gather(value, int(value.get_shape()[0]) - 1)
prediction = (tf.matmul(last, weight) + bias)
Seeing this, you cannot help but ask: why apply value = tf.transpose(value, [1, 0, 2]) to value, and what does last = tf.gather(value, int(value.get_shape()[0]) - 1) accomplish? Carrying these questions through endless searching, I finally found the answer with the help of https://blog.csdn.net/junjun150013652/article/details/81331448.
First of all, tf.nn.dynamic_rnn returns two parts, outputs and states. Following the linked article: the input X is a tensor of shape [batch_size, step, input_size] = [4, 2, 3], where step is the sequence length and input_size is the dimension of each input vector.
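That experiment can be reproduced with a sketch along the following lines. The cell configuration and feed values are my reconstruction (three stacked BasicRNNCell layers with ReLU activation, and a sequence length of 1 for the second instance, both inferred from the printouts below), so treat them as assumptions; the exact numbers printed also depend on the random weight initialization:

import numpy as np
import tensorflow as tf  # assumes TensorFlow 1.x

n_steps, n_inputs, n_neurons, n_layers = 2, 3, 5, 3

X = tf.placeholder(tf.float32, [None, n_steps, n_inputs])
seq_length = tf.placeholder(tf.int32, [None])

# three stacked ReLU cells; states then comes back as a tuple of three
# [batch_size, n_neurons] tensors, one final state per layer
cells = [tf.contrib.rnn.BasicRNNCell(num_units=n_neurons, activation=tf.nn.relu)
         for _ in range(n_layers)]
outputs, states = tf.nn.dynamic_rnn(tf.contrib.rnn.MultiRNNCell(cells), X,
                                    dtype=tf.float32, sequence_length=seq_length)

# a batch of 4 instances, 2 steps each; the second instance has only 1 valid step
X_batch = np.array([[[0, 1, 2], [9, 8, 7]],
                    [[3, 4, 5], [0, 0, 0]],
                    [[6, 7, 8], [6, 5, 4]],
                    [[9, 0, 1], [3, 2, 1]]], dtype=np.float32)
seq_length_batch = np.array([2, 1, 2, 2])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    outputs_val, states_val = sess.run(
        [outputs, states], feed_dict={X: X_batch, seq_length: seq_length_batch})
    print(outputs)      # the symbolic tensor, shape (?, 2, 5)
    print(states)       # a tuple of three (?, 5) tensors
    print(outputs_val)
    print(states_val)

Running it prints shapes and values like the following: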
outputs:
Tensor("rnn/transpose_1:0", shape=(?, 2, 5), dtype=float32)
states:
(<tf.Tensor 'rnn/while/Exit_3:0' shape=(?, 5) dtype=float32>,
 <tf.Tensor 'rnn/while/Exit_4:0' shape=(?, 5) dtype=float32>,
 <tf.Tensor 'rnn/while/Exit_5:0' shape=(?, 5) dtype=float32>)
outputs_val:
[[[0.         0.         0.         0.         0.        ]
  [0.         0.18740742 0.         0.2997518  0.        ]]

 [[0.         0.07222144 0.         0.11551574 0.        ]
  [0.         0.         0.         0.         0.        ]]

 [[0.         0.13463384 0.         0.21534224 0.        ]
  [0.03702604 0.18443246 0.         0.34539366 0.        ]]

 [[0.         0.54511094 0.         0.8718864  0.        ]
  [0.5382122  0.         0.04396425 0.4040263  0.        ]]]
states_val:
(array([[0.        , 0.83723307, 0.        , 0.        , 2.8518028 ],
        [0.        , 0.1996038 , 0.        , 0.        , 1.5456247 ],
        [0.        , 1.1372368 , 0.        , 0.        , 0.832613  ],
        [0.        , 0.7904129 , 2.4675028 , 0.        , 0.36980057]],
       dtype=float32),
 array([[0.6524607 , 0.        , 0.        , 0.        , 0.        ],
        [0.25143963, 0.        , 0.        , 0.        , 0.        ],
        [0.5010576 , 0.        , 0.        , 0.        , 0.        ],
        [0.        , 0.3166597 , 0.4545995 , 0.        , 0.        ]],
       dtype=float32),
 array([[0.        , 0.18740742, 0.        , 0.2997518 , 0.        ],
        [0.        , 0.07222144, 0.        , 0.11551574, 0.        ],
        [0.03702604, 0.18443246, 0.        , 0.34539366, 0.        ],
        [0.5382122 , 0.        , 0.04396425, 0.4040263 , 0.        ]],
       dtype=float32))
Notice in these printouts that every row of the third tensor in states_val also appears as a row of outputs_val (compare, say, [0, 0.18740742, 0, 0.2997518, 0]): the final state of the top layer is simply its output at each instance's final step. That final-step output is exactly what a classifier wants.

In Tang Yudi's example, value corresponds to outputs, and we need the output of the last step. After value = tf.transpose(value, [1, 0, 2]), the shape becomes [step, batch_size, lstmUnits], so the time axis comes first. Then last = tf.gather(value, int(value.get_shape()[0]) - 1) indexes the transposed tensor at position int(value.get_shape()[0]) - 1, the last slice along the time axis, so last is the final [batch_size, lstmUnits] slice: the LSTM's output at the last step. Finally, weight = tf.Variable(tf.truncated_normal([lstmUnits, numClasses])) has shape [lstmUnits, numClasses] while last has shape [batch_size, lstmUnits], so tf.matmul(last, weight) has shape [batch_size, numClasses]; adding the bias vector gives the prediction.
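To make the shape bookkeeping concrete, here is a small sketch with hypothetical sizes of mine (not from the original project) that traces each step; it also shows that the same slice can be taken by slicing value directly:

import tensorflow as tf  # assumes TensorFlow 1.x

# hypothetical sizes to make the shapes concrete
batchSize, maxSeqLen, lstmUnits, numClasses = 24, 250, 64, 2

value = tf.placeholder(tf.float32, [batchSize, maxSeqLen, lstmUnits])

transposed = tf.transpose(value, [1, 0, 2])  # (250, 24, 64) = [maxSeqLen, batchSize, lstmUnits]
last = tf.gather(transposed, int(transposed.get_shape()[0]) - 1)  # index 249: the final step
print(last)  # shape (24, 64) = [batchSize, lstmUnits]

# the same slice, taken without the transpose:
same = value[:, -1, :]  # also [batchSize, lstmUnits]

weight = tf.Variable(tf.truncated_normal([lstmUnits, numClasses]))
bias = tf.Variable(tf.constant(0.1, shape=[numClasses]))
prediction = tf.matmul(last, weight) + bias  # shape (24, 2) = [batchSize, numClasses]

The transpose-then-gather pattern exists because tf.gather indexes along the first axis by default; moving the time axis to the front lets a single gather call pick out the final step for the whole batch.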