1. 程式人生 > >Tensorflow - 統計batch中資料的長度.

Tensorflow - 統計batch中資料的長度.

就是記錄一下,這個用於統計batch中資料的長度,用於RNN或者LSTM的輸入,由於使用了0進行資料填充,因此統計實際資料長度,輸入到RNN的length中,其中輸入為一個batch,輸出為(batch_size,)為一個batch中所有資料的長度。

def compute_mask(batch_index):
    c_mask = tf.cast(batch_index, tf.bool)
    index_len = tf.reshape(tf.reduce_sum(tf.cast(c_mask, tf.int32), axis=-1),[-1])
    return index_len