用tensorflow訓練自己的資料_3、訓練模型
阿新 • • 發佈:2019-02-15
訓練模型的時候,維數一定要匹配,同時要了解你自己的資料的格式,和讀取的型別,一個one_hot編碼用的函式和非one_hot用的函式完全不一樣,這也是我當時一直出現問題的原因。
#!/usr/bin/env python2 # -*- coding: utf-8 -*- """ Created on Thu Jan 25 11:32:40 2018 @author: huangxudong """ import dr_alexnet import tensorflow as tf import read_data2 #定義網路超引數 learning_rate=0.01 train_iters=2000 batch_size=5 capacity=256 display_step=10 #讀取資料 tra_list,tra_labels,val_list,val_labels=read_data2.get_files('/home/bigvision/Desktop/DR_model',0.2) tra_list_batch,tra_label_batch=read_data2.get_batch(tra_list,tra_labels,512,512,batch_size,capacity) val_list_batch,val_label_batch=read_data2.get_batch(val_list,val_labels,512,512,batch_size,capacity) #定義網路引數 n_class=6 #標記維度 dropout=0.75 skip=[] #輸入佔位符 x=tf.placeholder(tf.float32,[None,786432]) #2800*2100*3,512*512*3 y=tf.placeholder(tf.int32,[None]) #print(y.shape) keep_prob=tf.placeholder(tf.float32) #dropout ''''構建模型,定義損失函式和優化器''''' pred=dr_alexnet.alexNet(x,dropout,n_class,skip) #定義損失函式和優化器 cost=tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y,logits=pred.fc3)) optimizer=tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost) #評估函式,優化函式 correct_pred=tf.nn.in_top_k(pred.fc3,y,1) #1表示列上去最大,0是行,這個地方如果是one_hot就是tf.argmax accuracy=tf.reduce_mean(tf.cast(correct_pred,tf.float32)) #改型別 '''訓練模型''' init=tf.global_variables_initializer() #初始化所有變數 with tf.Session() as sess: sess.run(init) coord=tf.train.Coordinator() threads= tf.train.start_queue_runners(coord=coord) step=1 #開始訓練,達到最大訓練次數 while step*batch_size<train_iters: batch_x,batch_y=tra_list_batch.eval(session=sess),tra_label_batch.eval(session=sess) batch_x=batch_x.reshape((batch_size,786432)) batch_y=batch_y.T sess.run(optimizer,feed_dict={x:batch_x,y:batch_y,keep_prob:dropout}) if step%display_step==2: #計算損失值和準確度,輸出 loss,acc=sess.run([cost,accuracy],feed_dict={x:batch_x,y:batch_y,keep_prob:1.}) print("Iter"+str(step*batch_size)+",Minibatch Loss="+ "{:.6f}".format(loss)+", Training Acc"+ "{:.5f}".format(acc)) step+=1 print("Optimization Finished!") coord.request_stop() coord.join(threads) #多執行緒進行batch送入
feed_dict字典讀取資料的時候不能是tensor型別,必須是list,numpy型別(還有一個忘了),所以在送入batch資料的時候加入了.eval(session.sess),當初這塊也是磨了很久。希望以後不在犯錯
本人新人,對大家有幫助的話就點贊哦