
ValueError: Variable rnnlm/multi_rnn_cell/cell_0/basic_lstm_cell/kernel already exists, disallowed.

Error message:
ValueError: Variable rnnlm/multi_rnn_cell/cell_0/basic_lstm_cell/kernel already exists, disallowed.

Cause: the model-building code is executed more than once (model reuse), so TensorFlow tries to create RNN variables that already exist in the current default graph.
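
For example, calling the network-construction code twice against the same default graph, as happens when a notebook cell is re-run or a build function is invoked a second time, reproduces the error. A minimal sketch (the layer sizes and input shape here are arbitrary illustration values):

import tensorflow as tf

def build():
    # Each call constructs a fresh cell object, but the variables it creates
    # land under the same scope in the same default graph.
    cell = tf.nn.rnn_cell.MultiRNNCell(
        [tf.nn.rnn_cell.BasicLSTMCell(128) for _ in range(2)])
    inputs = tf.placeholder(tf.float32, [None, 10, 8])
    return tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32, scope="rnnlm")

build()
build()  # ValueError: Variable rnnlm/multi_rnn_cell/cell_0/basic_lstm_cell/kernel already exists, disallowed.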

Solution: wrap the graph-building code block in with tf.Graph().as_default():. A complete, runnable version of the fixed script follows (the imports, hyperparameters, and MNIST loading were not shown in the original snippet, so example values are filled in):

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Hyperparameters (example values for 28x28 MNIST images read row by row)
sequence_length = 28    # time steps: one image row per step
frame_size = 28         # input size per step: pixels per row
hidden_num = 100        # hidden units in each RNN cell
n_classes = 10          # digit classes 0-9
batch_size = 128
train_rate = 0.001      # learning rate for Adam
train_step = 10000      # total training iterations
display_step = 100      # evaluate every display_step iterations

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

graph = tf.Graph()
with graph.as_default():

    # Define the inputs and outputs
    x = tf.placeholder(dtype=tf.float32, shape=[None, sequence_length * frame_size], name="inputx")
    y = tf.placeholder(dtype=tf.float32, shape=[None, n_classes], name="expected_y")
    # Define the weights
    weights = tf.Variable(tf.truncated_normal(shape=[hidden_num, n_classes]))
    bias = tf.Variable(tf.zeros(shape=[n_classes]))

    # Define the RNN network
    def RNN(x, weights, bias):
        '''Return logits of shape [batch_size, n_classes].'''
        x = tf.reshape(x, shape=[-1, sequence_length, frame_size])
        # Choose BasicRNNCell / BasicLSTMCell / GRUCell here for a plain RNN,
        # an LSTM, or a GRU; each cell has hidden_num hidden units.
        # rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_num)
        # Stack 3 layers with MultiRNNCell (MultiRNNCell is used for every cell
        # type; there is no MultiGRUCell and so on).
        rnn_cell = tf.nn.rnn_cell.MultiRNNCell(
            [tf.nn.rnn_cell.GRUCell(hidden_num) for _ in range(3)])
        output, states = tf.nn.dynamic_rnn(rnn_cell, x, dtype=tf.float32)
        # Return raw logits from the last time step; softmax_cross_entropy_with_logits
        # below applies softmax itself, so the tf.nn.softmax call in the original
        # snippet would have applied softmax twice.
        return tf.matmul(output[:, -1, :], weights) + bias

    # Compute the predicted output
    predy = RNN(x, weights, bias)
    # Define the loss function and the optimizer
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=predy, labels=y))
    train = tf.train.AdamOptimizer(train_rate).minimize(cost)
    # Compute accuracy
    correct_pred = tf.equal(tf.argmax(predy, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.to_float(correct_pred))
    # Build the init op inside this graph, not in the default graph
    init = tf.global_variables_initializer()

# Start training
with tf.Session(graph=graph) as sess:
    print('step', 'accuracy', 'loss')
    sess.run(init)
    step = 1
    testx, testy = mnist.test.next_batch(batch_size)
    while step < train_step:
        batch_x, batch_y = mnist.train.next_batch(batch_size)
        _loss, _ = sess.run([cost, train], feed_dict={x: batch_x, y: batch_y})
        if step % display_step == 0:
            acc, loss = sess.run([accuracy, cost], feed_dict={x: testx, y: testy})
            print(step, acc, loss)
        step += 1

With this change, the error no longer occurs: each run builds its variables in a fresh graph instead of colliding with the ones left in the default graph.
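
If you would rather keep working in the default graph (for example, when re-running a single notebook cell), two other common ways to avoid the error are resetting the default graph before rebuilding, or opening the variable scope with reuse=tf.AUTO_REUSE so repeated runs reuse the existing kernels. A minimal sketch, with the same illustrative sizes as above:

import tensorflow as tf

# Option 1: throw away the old default graph before rebuilding.
tf.reset_default_graph()

# Option 2: let the scope create variables on the first run and reuse them after.
with tf.variable_scope("rnnlm", reuse=tf.AUTO_REUSE):
    cell = tf.nn.rnn_cell.MultiRNNCell(
        [tf.nn.rnn_cell.BasicLSTMCell(128) for _ in range(2)])
    inputs = tf.placeholder(tf.float32, [None, 10, 8])
    outputs, state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)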