tf.contrib.rnn.BasicLSTMCell, tf.contrib.rnn.MultiRNNCell: An In-Depth Analysis
阿新 • Published: 2019-02-04
tf.contrib.rnn.BasicRNNCell
First, let's look at the source code of BasicRNNCell:
class BasicRNNCell(RNNCell):
  """The most basic RNN cell."""

  def __init__(self, num_units, input_size=None, activation=tanh, reuse=None):
    if input_size is not None:
      logging.warn("%s: The input_size parameter is deprecated.", self)
    self._num_units = num_units
    self._activation = activation
    self._reuse = reuse

  @property
  def state_size(self):
    return self._num_units

  @property
  def output_size(self):
    return self._num_units

  def __call__(self, inputs, state, scope=None):
    """Most basic RNN: output = new_state = act(W * input + U * state + B)."""
    with _checked_scope(self, scope or "basic_rnn_cell", reuse=self._reuse):
      output = self._activation(
          _linear([inputs, state], self._num_units, True))
    return output, output
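BasicRNNCell is the most basic RNN cell.
Constructor arguments:
num_units: the number of units (neurons) in the RNN layer
input_size: deprecated and ignored
activation: the activation function applied to the inner state (tanh by default)
reuse: Python boolean describing whether to reuse variables in an existing scope
As the source shows, for a cell instance created from BasicRNNCell, both properties cell.state_size and cell.output_size return num_units. `__call__` makes the instance callable: given an input `inputs` and a state `state`, it computes output = new_state = act(W * input + U * state + B) and returns that same tensor twice, once as the output and once as the new state. Internally, `_linear([inputs, state], num_units, True)` concatenates inputs and state along the feature axis and multiplies the result by a single weight matrix before adding a bias, which is equivalent to the W * input + U * state + B form in the docstring.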
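To make the equivalence between the concatenated `_linear` call and W * input + U * state + B concrete, here is a minimal pure-Python sketch of one BasicRNNCell step for a single example (no batch, no TensorFlow). `TinyBasicRNNCell` and `matvec_rows` are hypothetical names, and the random weight initialization is an assumption, not TensorFlow's initializer.

```python
import math
import random


def matvec_rows(x, w):
    """Multiply a row vector x (length n) by a matrix w (n rows, m columns)."""
    return [sum(x[k] * w[k][col] for k in range(len(x))) for col in range(len(w[0]))]


class TinyBasicRNNCell:
    """Hypothetical pure-Python imitation of BasicRNNCell for one example."""

    def __init__(self, input_dim, num_units, activation=math.tanh, seed=0):
        rng = random.Random(seed)
        self._num_units = num_units
        self._activation = activation
        # One kernel of shape [input_dim + num_units, num_units], as _linear builds:
        # the top input_dim rows act as W, the bottom num_units rows act as U.
        self._kernel = [[rng.uniform(-0.1, 0.1) for _ in range(num_units)]
                        for _ in range(input_dim + num_units)]
        self._bias = [0.0] * num_units  # assumed zero initial bias

    @property
    def state_size(self):
        return self._num_units

    @property
    def output_size(self):
        return self._num_units

    def __call__(self, inputs, state):
        concat = list(inputs) + list(state)  # [inputs, state] joined on the feature axis
        pre = [z + b for z, b in zip(matvec_rows(concat, self._kernel), self._bias)]
        output = [self._activation(v) for v in pre]
        return output, output  # the output and the new state are the same object


cell = TinyBasicRNNCell(input_dim=3, num_units=4)
state = [0.0] * cell.state_size
for x in ([1.0, 0.0, -1.0], [0.5, 0.5, 0.5]):
    output, state = cell(x, state)
print(cell.state_size, cell.output_size, len(output))  # 4 4 4
print(output is state)  # True
```

Splitting the kernel into its top and bottom rows reproduces W * input + U * state exactly, which is why a single concatenated matmul is enough.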
tf.contrib.rnn.BasicLSTMCell
The source code is as follows:
class BasicLSTMCell(RNNCell):
  """Basic LSTM recurrent network cell.

  The implementation is based on: http://arxiv.org/abs/1409.2329.

  We add forget_bias (default: 1) to the biases of the forget gate in order to
  reduce the scale of forgetting in the beginning of the training.

  It does not allow cell clipping, a projection layer, and does not
  use peep-hole connections: it is the basic baseline.

  For advanced models, please use the full LSTMCell that follows.
  """

  def __init__(self, num_units, forget_bias=1.0, input_size=None,
               state_is_tuple=True, activation=tanh, reuse=None):
    """Initialize the basic LSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell.
      forget_bias: float, The bias added to forget gates (see above).
      input_size: Deprecated and unused.
      state_is_tuple: If True, accepted and returned states are 2-tuples of
        the `c_state` and `m_state`. If False, they are concatenated
        along the column axis. The latter behavior will soon be deprecated.
      activation: Activation function of the inner states.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope. If not `True`, and the existing scope already has
        the given variables, an error is raised.
    """
    if not state_is_tuple:
      logging.warn("%s: Using a concatenated state is slower and will soon be "
                   "deprecated. Use state_is_tuple=True.", self)
    if input_size is not None:
      logging.warn("%s: The input_size parameter is deprecated.", self)
    self._num_units = num_units
    self._forget_bias = forget_bias
    self._state_is_tuple = state_is_tuple
    self._activation = activation
    self._reuse = reuse

  @property
  def state_size(self):
    return (LSTMStateTuple(self._num_units, self._num_units)
            if self._state_is_tuple else 2 * self._num_units)

  @property
  def output_size(self):
    return self._num_units

  def __call__(self, inputs, state, scope=None):
    """Long short-term memory cell (LSTM)."""
    with _checked_scope(self, scope or "basic_lstm_cell", reuse=self._reuse):
      # Parameters of gates are concatenated into one multiply for efficiency.
      if self._state_is_tuple:
        c, h = state
      else:
        c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)
      concat = _linear([inputs, h], 4 * self._num_units, True)

      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
      i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)

      new_c = (c * sigmoid(f + self._forget_bias) + sigmoid(i) *
               self._activation(j))
      new_h = self._activation(new_c) * sigmoid(o)

      if self._state_is_tuple:
        new_state = LSTMStateTuple(new_c, new_h)
      else:
        new_state = array_ops.concat([new_c, new_h], 1)
      return new_h, new_state
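Read line by line, the `__call__` body computes the following update. This is a transcription of the code above, with σ for sigmoid, ⊙ for element-wise multiplication, tanh standing in for the default `activation`, and b_forget for `forget_bias`; W has shape (input_dim + num_units) × 4·num_units, and the row-vector convention matches `_linear`:

$$
\begin{aligned}
[\,i,\ j,\ f,\ o\,] &= [\,x_t,\ h_{t-1}\,]\,W + b \\
c_t &= c_{t-1} \odot \sigma(f + b_{\text{forget}}) + \sigma(i) \odot \tanh(j) \\
h_t &= \tanh(c_t) \odot \sigma(o)
\end{aligned}
$$

The cell returns h_t as its output and (c_t, h_t) as its new state.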
The source code of LSTMStateTuple is as follows:
class LSTMStateTuple(_LSTMStateTuple):
  """Tuple used by LSTM Cells for `state_size`, `zero_state`, and output state.

  Stores two elements: `(c, h)`, in that order.

  Only used when `state_is_tuple=True`.
  """
  __slots__ = ()

  @property
  def dtype(self):
    (c, h) = self
    if not c.dtype == h.dtype:
      raise TypeError("Inconsistent internal state: %s vs %s" %
                      (str(c.dtype), str(h.dtype)))
    return c.dtype
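The `_LSTMStateTuple` base class is not shown above; in TF 1.x it is defined as `collections.namedtuple("LSTMStateTuple", ("c", "h"))`. The pure-Python sketch below reuses that pattern (no TensorFlow needed) to show what the state looks like in both modes. `Vec` is a hypothetical stand-in for a tensor that only carries a `dtype` tag, and the concatenated branch imitates `array_ops.concat` / `array_ops.split` with list slicing for a single example.

```python
import collections

# Same construction TF 1.x uses for the base class: a namedtuple with fields ("c", "h").
_LSTMStateTuple = collections.namedtuple("LSTMStateTuple", ("c", "h"))


class LSTMStateTuple(_LSTMStateTuple):
    """Same dtype check as the TF class; works on any objects with a .dtype."""
    __slots__ = ()

    @property
    def dtype(self):
        (c, h) = self
        if not c.dtype == h.dtype:
            raise TypeError("Inconsistent internal state: %s vs %s" %
                            (str(c.dtype), str(h.dtype)))
        return c.dtype


class Vec(list):
    """Hypothetical tensor stand-in: a list that carries a dtype tag."""

    def __init__(self, values, dtype="float32"):
        super().__init__(values)
        self.dtype = dtype


num_units = 2
c, h = Vec([0.1, 0.2]), Vec([0.3, 0.4])

# state_is_tuple=True: the state is a named 2-tuple in (c, h) order.
state = LSTMStateTuple(c, h)
c_again, h_again = state  # plain tuple unpacking, as in __call__
print(state.c == c_again, state.dtype)  # True float32

# state_is_tuple=False: c and h are concatenated along the feature axis,
# and __call__ splits the vector back into two halves of num_units each.
flat = list(c) + list(h)
c_split, h_split = flat[:num_units], flat[num_units:]
print(len(flat), c_split == list(c), h_split == list(h))  # 4 True True
```

The tuple form keeps c and h addressable by name and avoids the split/concat on every step, which is one reason the docstring calls the concatenated form slower.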
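BasicLSTMCell is the most basic LSTM recurrent network cell.
Its constructor arguments are similar to those of BasicRNNCell:
num_units: the number of units in the LSTM cell
forget_bias: the bias added to the forget gates (1.0 by default)
state_is_tuple: best left at True; the state is then a 2-tuple (c_state, m_state), i.e. (c, h)
activation: the activation function used for the inner states (tanh by default)
reuse: Python boolean describing whether to reuse variables in an existing scope
- state_size property: if state_is_tuple is True, returns the two-element state tuple LSTMStateTuple(num_units, num_units); otherwise it returns 2 * num_units.
- output_size property: returns num_units, the number of units in the LSTM cell that was passed in at construction.
- `__call__()` makes the instance callable: given an input `inputs` and a state `state`, it applies the LSTM equations and returns new_h together with the new state new_state, where new_state = (new_c, new_h). For the underlying theory, see the paper https://arxiv.org/pdf/1409.2329.pdf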
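As a sketch of what one `__call__` step does with `state_is_tuple=True`, the pure-Python code below follows the gate order i, j, f, o and the placement of `forget_bias` from the source above, assuming the default tanh activation and a single example instead of a batch. `TinyBasicLSTMCell` and `LSTMState` are hypothetical names, and the random weights are an assumption, not TensorFlow's initializer.

```python
import math
import random
from collections import namedtuple

LSTMState = namedtuple("LSTMState", ("c", "h"))  # hypothetical stand-in for LSTMStateTuple


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


class TinyBasicLSTMCell:
    """Hypothetical pure-Python imitation of BasicLSTMCell.__call__ for one example."""

    def __init__(self, input_dim, num_units, forget_bias=1.0, seed=0):
        rng = random.Random(seed)
        self._num_units = num_units
        self._forget_bias = forget_bias
        # One kernel of shape [input_dim + num_units, 4 * num_units], mirroring
        # _linear([inputs, h], 4 * num_units, True); the bias is assumed to start at zero.
        self._kernel = [[rng.uniform(-0.1, 0.1) for _ in range(4 * num_units)]
                        for _ in range(input_dim + num_units)]
        self._bias = [0.0] * (4 * num_units)

    @property
    def state_size(self):
        return LSTMState(self._num_units, self._num_units)

    @property
    def output_size(self):
        return self._num_units

    def __call__(self, inputs, state):
        c, h = state
        x = list(inputs) + list(h)  # concatenate [inputs, h]
        concat = [sum(x[k] * self._kernel[k][col] for k in range(len(x))) + self._bias[col]
                  for col in range(4 * self._num_units)]
        n = self._num_units
        # Four equal chunks: input gate, new input, forget gate, output gate.
        i, j, f, o = concat[:n], concat[n:2 * n], concat[2 * n:3 * n], concat[3 * n:]
        new_c = [c_k * sigmoid(f_k + self._forget_bias) + sigmoid(i_k) * math.tanh(j_k)
                 for c_k, i_k, j_k, f_k in zip(c, i, j, f)]
        new_h = [math.tanh(nc) * sigmoid(o_k) for nc, o_k in zip(new_c, o)]
        return new_h, LSTMState(new_c, new_h)


cell = TinyBasicLSTMCell(input_dim=3, num_units=2)
state = LSTMState([0.0] * 2, [0.0] * 2)  # zero initial state
for x in ([1.0, 0.0, -1.0], [0.5, 0.5, 0.5]):
    output, state = cell(x, state)
print(cell.state_size, cell.output_size)  # LSTMState(c=2, h=2) 2
print(output is state.h)  # True
```

With all weights and biases at zero, the forget gate evaluates to sigmoid(forget_bias) = sigmoid(1.0) ≈ 0.73 instead of 0.5, so more of the old cell state is kept early in training; this is the effect the docstring describes for `forget_bias`.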