tensorflow遷移學習:VGG16花朵分類
阿新 • • 發佈:2018-12-10
其實上面的文章已經寫的很詳細了,但是還有一點小小的問題,通過參考其他的程式碼,將其進行補充,這樣完整的程式就可以運行了。
下面我就主要說說進行補充的地方:
補充1:
如果按照原文章的步驟一步步進行,在進行到這步的時候,會有報錯,會提示找不到labels和codes檔案,原因是我們在執行上一步的程式碼沒有對其進行儲存。
# read codes and labels from file import csv with open('labels') as f: reader = csv.reader(f, delimiter='\n') labels = np.array([each for each in reader if len(each) > 0]).squeeze() with open('codes') as f: codes = np.fromfile(f, dtype=np.float32) codes = codes.reshape((len(labels), -1))
我們只需要在這段上面的程式碼處,對輸出的labels和codes檔案進行儲存處理即可。如下所示:
#將影象批量batches通過VGG模型,將輸出作為新的輸入: # Set the batch size higher if you can fit in in your GPU memory batch_size = 10 codes_list = [] labels = [] batch = [] codes = None with tf.Session() as sess: vgg = vgg16.Vgg16() input_ = tf.placeholder(tf.float32, [None, 224, 224, 3]) with tf.name_scope("content_vgg"): vgg.build(input_) for each in classes: print("Starting {} images".format(each)) class_path = data_dir + each files = os.listdir(class_path) for ii, file in enumerate(files, 1): # Add images to the current batch # utils.load_image crops the input images for us, from the center img = utils.load_image(os.path.join(class_path, file)) batch.append(img.reshape((1, 224, 224, 3))) labels.append(each) # Running the batch through the network to get the codes if ii % batch_size == 0 or ii == len(files): images = np.concatenate(batch) feed_dict = {input_: images} codes_batch = sess.run(vgg.relu6, feed_dict=feed_dict) # Here I'm building an array of the codes if codes is None: codes = codes_batch else: codes = np.concatenate((codes, codes_batch)) # Reset to start building the next batch batch = [] print('{} images processed'.format(ii)) #這裡就是新增的儲存的程式碼 #這樣我們就可以得到一個 codes 陣列,和一個 labels 陣列,分別儲存了所有花朵的特徵值和類別。 with open('codes', 'w') as f: codes.tofile(f) import csv with open('labels', 'w') as f: writer = csv.writer(f, delimiter='\n') writer.writerow(labels)
接下來就會在檔案所在的目錄內自動生成labels和codes檔案,然後繼續原文章的步驟進行就可以實現最後的結果。
補充2:
文章結尾處,作者是以柱狀圖的形式來展示預測結果的,現在我們以只顯示概率和品種的形式來展示結果。jin
#將陣列轉換為list
predic_list = prediction.tolist()
print(type(predic_list))
index = predic_list.index(max(predic_list))
print(lb.classes_[index]+":"+str(max(predic_list)))
這樣就可以滿足部分同學對概率顯示預測結果的需求啦。
今天第一次寫,以後會繼續堅持,加油。