從一句程式碼看tf.scan
阿新 • • 發佈:2019-01-14
在讀這篇文章的時候遇到了以下程式碼:
cell = tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True) cell = tf.nn.rnn_cell.MultiRNNCell([cell] * num_layers, state_is_tuple=True) init_state = cell.zero_state(batch_size, tf.float32) rnn_outputs, final_states = \ tf.scan(lambda a, x: cell(x, a[1]), tf.transpose(rnn_inputs, [1,0,2]), initializer=(tf.zeros([batch_size, state_size]), init_state))
這裡來解釋一下:
首先,tf.scan 第一個輸入是函式,也就是:
tf.scan(lambda a, x: cell(x, a[1])
等價於(未驗證,僅作illustration用):
def func(a, x):
return cell(x, a[1])
x是輸入,a是上一步函式func的輸出。為什麼輸入cell的是a[1]呢?這是因為,根據官方文件,MultiRNNCell的輸出是:
Returns: A pair containing: Output: A 2-D tensor with shape [batch_size, self.output_size]. New state: Either a single 2-D tensor, or a tuple of tensors matching the arity and shapes of state.
換句話說,就是: (output,New state),也就是a。
那麼:
a = (output, new_state)
a[0] = output
a[1] = new_state
所以,cell的輸入,其一是x,也就是每一個time step的輸入,其二是a[1],也就是上一個time step 輸出的hidden state。
然後,tf.scan 的第二個輸入是input,這個沒什麼好說的,需要注意資料的形狀要從[batch_size,num_steps, state_size] 調整為[num_steps, batch_size, state_size]。tf.scan 會一步一步的把input輸入cell,每次的形狀是:[batch_size, state_size]
tf.scan第三個引數是a的初始化,那麼水到渠成,它分別初始化了output和new_state:
initializer=(tf.zeros([batch_size, state_size]), init_state))
至此這句程式碼就分析完畢了。有不懂的同學還請細細鑽研,弄懂了就不難。另附例子如下:
def testScan_SingleInputMultiOutput(self):
with self.test_session() as sess:
elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
initializer = (np.array(1.0), np.array(-1.0))
r = tf.scan(lambda a, x: (a[0] * x, -a[1] * x), elems, initializer)
r_value = sess.run(r)
self.assertAllEqual([1.0, 2.0, 6.0, 24.0, 120.0, 720.0], r_value[0])
self.assertAllEqual([1.0, -2.0, 6.0, -24.0, 120.0, -720.0], r_value[1])