Implementing an attention-based sequence-to-sequence model in Keras (first draft)
from keras import backend as K
from keras.layers import GRU
from keras.regularizers import l2


# Note: this subclasses the old Recurrent-based GRU, i.e. a Keras version in
# which GRU still exposes a per-timestep step() method.
class AttentionGRU(GRU):
    def __init__(self, atten_states, states_len, L2Strength, **kwargs):
        '''
        :param atten_states: encoder states to attend over, shape [batch_size, states_len, units]
        :param states_len: number of encoder time steps
        :param L2Strength: L2 regularization strength for the attention weights
        :param kwargs: passed through to the underlying GRU (must include units)
        '''
        self.p_states = atten_states
        self.states_len = states_len
        self.size = kwargs['units']
        self.L2Strength = L2Strength
        super(AttentionGRU, self).__init__(**kwargs)

    def build(self, input_shape):
        # W1/b1 score each encoder state against the current decoder state;
        # W2/b2 mix the attention context back into the decoder output.
        # Encoder and decoder are assumed to share the same hidden size, so the
        # concatenated vectors below have dimension 2 * units.
        self.W1 = self.add_weight(name='att_W1',
                                  shape=(2 * self.units, 1),
                                  initializer='random_uniform',
                                  regularizer=l2(self.L2Strength),
                                  trainable=True)
        self.b1 = self.add_weight(name='att_b1',
                                  shape=(1,),
                                  initializer='zeros',
                                  regularizer=l2(self.L2Strength),
                                  trainable=True)
        self.W2 = self.add_weight(name='att_W2',
                                  shape=(2 * self.units, self.units),
                                  initializer='random_uniform',
                                  regularizer=l2(self.L2Strength),
                                  trainable=True)
        self.b2 = self.add_weight(name='att_b2',
                                  shape=(self.units,),
                                  initializer='zeros',
                                  regularizer=l2(self.L2Strength),
                                  trainable=True)
        super(AttentionGRU, self).build(input_shape)

    def step(self, inputs, states):
        # Ordinary GRU update first, then attend over the encoder states.
        h, _ = super(AttentionGRU, self).step(inputs, states)
        alfa = K.repeat(h, self.states_len)                    # [batch_size, states_len, units]
        alfa = K.concatenate([self.p_states, alfa], axis=2)    # [batch_size, states_len, 2*units]
        scores = K.tanh(K.dot(alfa, self.W1) + self.b1)        # [batch_size, states_len, 1]
        scores = K.reshape(scores, (-1, self.states_len))      # [batch_size, states_len]
        scores = K.softmax(scores)                             # normalize over encoder positions
        scores = K.reshape(scores, (-1, 1, self.states_len))   # [batch_size, 1, states_len]
        attn = K.batch_dot(scores, self.p_states)              # [batch_size, 1, units]
        attn = K.reshape(attn, (-1, self.units))               # [batch_size, units]
        h = K.concatenate([h, attn])                           # [batch_size, 2*units]
        h = K.dot(h, self.W2) + self.b2                        # [batch_size, units]
        return h, [h]

    def compute_output_shape(self, input_shape):
        # Only the final step's output is returned (assumes return_sequences=False).
        return (input_shape[0], self.units)
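In summary, at each decoder time step the step method above first runs an ordinary GRU update to get h, then scores every encoder state s_i with e_i = tanh([s_i; h]·W1 + b1), normalizes the scores with a softmax over the encoder positions to get weights α_i, forms the context vector c = Σ_i α_i·s_i, and finally returns [h; c]·W2 + b2 as both the output and the hidden state passed to the next step.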
Pass the encoder's states to the atten_states argument, and the layer can then be used just like the standard Keras GRU. Because this is a GRU rather than an LSTM, the computation inside step differs slightly from the one in the paper. units is the hidden size; the encoder and decoder are assumed here to have the same hidden size. A usage sketch follows below.
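As a rough usage sketch (the vocabulary size, sequence lengths, hidden size and variable names below are made up for illustration, and it assumes a Keras version whose GRU still exposes the per-step step() method required by the class above):

from keras.layers import Input, Embedding, GRU, Dense
from keras.models import Model

src_len, tgt_len = 20, 20        # hypothetical source/target sequence lengths
vocab, hidden = 5000, 128        # hypothetical vocabulary and shared hidden size

# Encoder: a plain GRU that returns its hidden state at every source position.
enc_in = Input(shape=(src_len,))
enc_states = GRU(hidden, return_sequences=True)(Embedding(vocab, hidden)(enc_in))

# Decoder: the AttentionGRU attends over the encoder states at every step.
dec_in = Input(shape=(tgt_len,))
dec_emb = Embedding(vocab, hidden)(dec_in)
dec_out = AttentionGRU(atten_states=enc_states, states_len=src_len,
                       L2Strength=1e-6, units=hidden)(dec_emb)

pred = Dense(vocab, activation='softmax')(dec_out)
model = Model(inputs=[enc_in, dec_in], outputs=pred)
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')

Note that the encoder states reach the decoder only through the tensor captured in __init__, not through Keras's layer graph, so both enc_in and dec_in still have to be listed as model inputs so the encoder branch is fed at run time. Also, compute_output_shape as written assumes return_sequences=False, so the decoder emits only its final output; to predict a distribution at every target position you would set return_sequences=True and adjust compute_output_shape accordingly.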