1. 程式人生 > >Bi-LSTM-CRF(一)--tensorflow原始碼解析

Bi-LSTM-CRF(一)--tensorflow原始碼解析

1.1.核心程式碼:

cell_fw = tf.contrib.rnn.LSTMCell(num_units=100)

cell_bw = tf.contrib.rnn.LSTMCell(num_units=100)

(outputs, output_states) = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=300)

1.2.LSTMCell

其實LSTM使用起來很簡單,就是輸入一排的向量,然後輸出一排的向量。構建時只要設定兩個超引數:num_units和sequence_length。

tf.contrib.rnn.LSTMCell(

    num_units,

    use_peepholes=False,

    cell_clip=None,

    initializer=None,

    num_proj=None,

    proj_clip=None,

    num_unit_shards=None,

    num_proj_shards=None,

    forget_bias=1.0,

    state_is_tuple=True,

    activation=None,

    reuse=None

)

上面的LSTM Cell只有一個超引數需要設定,num_units,即輸出向量的維度。

1.3.bidirectional_dynamic_rnn

這個函式唯一需要設定的超引數就是序列長度sequence_length。

(outputs, output_states) = tf.nn.bidirectional_dynamic_rnn(

    cell_fw,

    cell_bw,

    inputs,

    sequence_length=None,

    initial_state_fw=None,

    initial_state_bw=None,

    dtype=None,

    parallel_iterations=None,

    swap_memory=False,

    time_major=False,

    scope=None

)

輸入: inputs的shape通常是[batch_size, sequence_length, dim_embedding]。

輸出: outputs是一個(output_fw, output_bw)元組,output_fw和output_bw的shape都是[batch_size, sequence_length, num_units]

output_states是一個(output_state_fw, output_state_bw) 元組,分別是前向和後向最後一個Cell的Output,output_state_fw和output_state_bw的型別都是LSTMStateTuple,這個類有兩個屬性c和h,分別表示Memory Cell和Hidden State