1. 程式人生 > >tf.nn.rnn_cell.BasicLSTMCell函式用法

tf.nn.rnn_cell.BasicLSTMCell函式用法

tf.nn.rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0, state_is_tuple=True): n_hidden表示神經元的個數,forget_bias就是LSTM們的忘記係數,如果等於1,就是不會忘記任何資訊。如果等於0,就都忘記。state_is_tuple預設就是True,官方建議用True,就是表示返回的狀態用一個元祖表示。這個裡面存在一個狀態初始化函式,就是zero_state(batch_size,dtype)兩個引數。batch_size就是輸入樣本批次的數目,dtype就是資料型別。

例如:

import tensorflow as tf

batch_size = 4 
input = tf.random_normal(shape=[3, batch_size, 6], dtype=tf.float32)
cell = tf.nn.rnn_cell.BasicLSTMCell(10, forget_bias=1.0, state_is_tuple=True)
init_state = cell.zero_state(batch_size, dtype=tf.float32)
output, final_state = tf.nn.dynamic_rnn(cell, input, initial_state=init_state, time_major=True) #time_major如果是True,就表示RNN的steps用第一個維度表示,建議用這個,執行速度快一點。
#如果是False,那麼輸入的第二個維度就是steps。
#如果是True,output的維度是[steps, batch_size, depth],反之就是[batch_size, max_time, depth]。就是和輸入是一樣的
#final_state就是整個LSTM輸出的最終的狀態,包含c和h。c和h的維度都是[batch_size, n_hidden]
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(output))
    print(sess.run(final_state))
輸出:

[[[-0.17837059  0.01385643  0.11524696 -0.04611184  0.05751593 -0.02275656
    0.10593235 -0.07636188  0.12855089  0.00768109]
  [ 0.07553699 -0.23295973 -0.00144508  0.09547552 -0.05839045 -0.06769165
   -0.41666976  0.3499622  -0.01430317 -0.02479473]
  [ 0.08574327 -0.05990489  0.06817424  0.03434218  0.10152793 -0.10594042
   -0.25310516  0.07232092  0.064815    0.0659876 ]
  [ 0.15607212 -0.31474397 -0.06477047 -0.06982201 -0.05489461  0.0188695
   -0.30281037  0.39494631 -0.05267519 -0.03253869]]


 [[-0.03209484 -0.06323308 -0.25410452 -0.10886975  0.00253956 -0.08053195
    0.18729064 -0.0788438   0.14781287 -0.20489833]
  [ 0.3164973  -0.10971865 -0.35004857 -0.00576114 -0.08092841  0.00883496
   -0.17579219  0.19092172 -0.0237403  -0.43207553]
  [ 0.2409949  -0.17808972 -0.1486263   0.02179234 -0.21656732  0.0522153
   -0.21345614  0.18841118 -0.0094095  -0.34072629]
  [ 0.12034108 -0.23767222  0.03664704  0.13274716 -0.04165298 -0.04095407
   -0.31182185  0.36334303 -0.01146755  0.05028744]]


 [[-0.12453001 -0.1567502  -0.16580626 -0.03544752  0.06869993  0.09097657
   -0.02214662 -0.18668351  0.06159507 -0.35843855]
  [ 0.2010586   0.03222289 -0.31237942  0.01898964 -0.08158109 -0.02510365
    0.02967031  0.12587228 -0.22250202 -0.08734316]
  [ 0.14316584  0.02029586 -0.1062321   0.02968353 -0.02318866  0.07653226
   -0.13600637 -0.00440343  0.07305693 -0.26385978]
  [ 0.23669831 -0.13415271 -0.10488234  0.03128149 -0.11343875 -0.05327768
   -0.22888957  0.17797095 -0.02945257 -0.18901967]]]
LSTMStateTuple(c=array([[-0.72714508,  0.32974839,  0.67756736,  0.11421457,  0.39167076,
         0.31247479,  0.0755761 , -0.62171376,  0.58582318, -0.19749212],
       [ 0.44815305,  0.06901363, -0.88840145,  0.22841501,  0.04539755,
         0.17472507, -0.50547051,  0.46637267, -0.07522876, -0.80750966],
       [-0.19392423, -0.16717091, -0.19510591, -0.48713976, -0.18430954,
         0.1046299 ,  0.30127296, -0.03556332, -0.37671563, -0.1388765 ],
       [-0.47982571,  0.2172934 ,  0.56419176,  0.15874679,  0.29927608,
         0.16362543,  0.11525643, -0.47210076,  0.56833684, -0.18866351]], dtype=float32), h=array([[-0.36339632,  0.17585619,  0.29174498,  0.03471305,  0.2237694 ,
         0.13323013,  0.03002708, -0.26190156,  0.28289214, -0.12495621],
       [ 0.1543802 ,  0.04264591, -0.27087522,  0.084597  ,  0.01555507,
         0.10631134, -0.23696639,  0.2758382 , -0.03724022, -0.4389703 ],
       [-0.14088678, -0.10961234, -0.10831701, -0.19923639, -0.10324109,
         0.04290821,  0.10720341, -0.01477169, -0.14518294, -0.04280116],
       [-0.34502122,  0.10841226,  0.32169446,  0.03053316,  0.20867576,
         0.04689977,  0.03286072, -0.11068864,  0.37977526, -0.12110116]], dtype=float32))