TensorFlow詳解貓狗識別(一)--讀取自己的資料集
資料集下載 連結: https://pan.baidu.com/s/1SlNAPf3NbgPyf93XluM7Fg 密碼: hpn4
資料集分別有12500張cat,12500張dog
讀取資料集 資料集的讀取,查閱了那麼多文件,大致瞭解到,資料集的讀取方法大概會分為兩種
1、先生成圖片list,和標籤list,把圖片名稱和標籤對應起來,再讀取製作迭代器(個人認為此方法一般用在,圖片名稱上可以明確的知道label的)
2、直接生成TFRecord檔案,用tf.TFRecordReader()來讀取,個人認為,當圖片量很大的時候(如:ImageNet)很使用,儲存了TFRecord檔案後,一勞永逸,省去了生成list的過程
下面貼出程式碼,簡單介紹兩種讀取資料集的方式。
方法一: import os import tensorflow as tf from PIL import Image import matplotlib.pyplot as plt import numpy as np import cv2 # os模組包含作業系統相關的功能, # 可以處理檔案和目錄這些我們日常手動需要做的操作。因為我們需要獲取test目錄下的檔案,所以要匯入os模組。 # # 資料構成,在訓練資料中,There are 12500 cat,There are 12500 dogs,共25000張 # 獲取檔案路徑和標籤 def get_files(file_dir): # file_dir: 資料夾路徑 # return: 亂序後的圖片和標籤 cats = [] label_cats = [] dogs = [] label_dogs = [] # 載入資料路徑並寫入標籤值 for file in os.listdir(file_dir): name = file.split(sep='.') # name的形式為['dog', '9981', 'jpg'] # os.listdir將名字轉換為列表表達 if name[0] == 'cat': cats.append(file_dir + file) # 注意檔案路徑和名字之間要加分隔符,不然後面查詢圖片會提示找不到圖片 # 或者在後面傳路徑的時候末尾加兩// 'D:/Python/neural network/Cats_vs_Dogs/data/train//' label_cats.append(0) else: dogs.append(file_dir + file) label_dogs.append(1) # 貓為0,狗為1 print("There are %d cats\nThere are %d dogs" % (len(cats), len(dogs))) # 打亂檔案順序 image_list = np.hstack((cats, dogs)) label_list = np.hstack((label_cats, label_dogs)) # np.hstack()方法將貓和狗圖片和標籤整合到一起,標籤也整合到一起 temp = np.array([image_list, label_list]) # 這裡的陣列出來的是2行10列,第一行是image_list的資料,第二行是label_list的資料 temp = temp.transpose() # 轉置 # 將其轉換為10行2列,第一列是image_list的資料,第二列是label_list的資料 np.random.shuffle(temp) # 對應的打亂順序 image_list = list(temp[:, 0]) # 取所有行的第0列資料 label_list = list(temp[:, 1]) # 取所有行的第1列資料,並轉換為int label_list = [int(i) for i in label_list] return image_list, label_list # 生成相同大小的批次 def get_batch(image, label, image_W, image_H, batch_size, capacity): # image, label: 要生成batch的影象和標籤list # image_W, image_H: 圖片的寬高 # batch_size: 每個batch有多少張圖片 # capacity: 佇列容量 # return: 影象和標籤的batch # 將原來的python.list型別轉換成tf能夠識別的格式 image = tf.cast(image, tf.string)#強制型別轉換 label = tf.cast(label, tf.int32) # 生成佇列。我們使用slice_input_producer()來建立一個佇列,將image和label放入一個list中當做引數傳給該函式 input_queue = tf.train.slice_input_producer([image, label]) image_contents = tf.read_file(input_queue[0]) # 按佇列讀資料和標籤 label = input_queue[1] image = tf.image.decode_jpeg(image_contents, channels=3) # 要按照圖片格式進行解碼。本例程中訓練資料是jpg格式的,所以使用decode_jpeg()解碼器, # 如果是其他格式,就要用其他geshi具體可以從官方API中查詢。 # 注意decode出來的資料型別是uint8,之後模型卷積層裡面conv2d()要求輸入資料為float32型別 # 統一圖片大小 # 通過裁剪統一,包括裁剪和擴充 # image = tf.image.resize_image_with_crop_or_pad(image, image_W, image_H) # 我的方法,通過縮小圖片,採用NEAREST_NEIGHBOR插值方法 image = tf.image.resize_images(image, [image_H, image_W], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR, align_corners=False) image = tf.cast(image, tf.float32) # 因為沒有標準化,所以需要轉換型別 # image = tf.image.per_image_standardization(image) # 標準化資料 image_batch, label_batch = tf.train.batch([image, label], batch_size=batch_size, num_threads=64, # 執行緒 capacity=capacity) # image_batch是一個4D的tensor,[batch, width, height, channels], # label_batch是一個1D的tensor,[batch]。 # 這行多餘? label_batch = tf.reshape(label_batch, [batch_size]) return image_batch, label_batch ''' 下面程式碼為檢視圖片效果,主要用於觀察圖片是否打亂,你會可能會發現,圖片顯示出來的是一堆亂點,不用擔心,這是因為你對圖片的每一個畫素進行了強制型別轉化為了tf.float32,使畫素值介於-1~1之間,若想看原圖,可使用tf.uint8,畫素介於0~255 ''' # print("yes") # image_list,label_list = get_files("E:\\Pycharm\\tf-01\\Bigwork\\train\\") # image_batch,label_batch = train_batch,train_label_batch = get_batch(image_list,label_list,208,208,4,256) # print("ok") # # for i in range(4): # with tf.Session() as sess: # i = 0 # coord = tf.train.Coordinator() # threads = tf.train.start_queue_runners(coord=coord) # try: # while not coord.should_stop() and i < 1: # # just plot one batch size # image, label = sess.run([image_batch, label_batch]) # for j in np.arange(4): # print('label: %d' % label[j]) # plt.imshow(image[j, :, :, :]) # plt.show() # i += 1 # except tf.errors.OutOfRangeError: # print('done!') # finally: # coord.request_stop() # coord.join(threads) # for i in range(4): # sess = tf.Session() # image,label = sess.run([image_batch,label_batch]) # for j in range(4): # print('label:%d' % label[j]) # plt.imshow(image[j, :, :, :]) # plt.show() # sess.close() 方法二 import os import tensorflow as tf from PIL import Image import matplotlib.pyplot as plt import numpy as np import cv2 cwd = "E:\\Pycharm\\tf-01\\Bigwork\\test\\" classes = {'cat', 'dog'} # 預先自己定義的類別 writer = tf.python_io.TFRecordWriter('test.tfrecords') # 輸出成tfrecord檔案 def _int64_feature(value): return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) def _bytes_feature(value): return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) for index, name in enumerate(classes): class_path = cwd + name + '\\' print(class_path) for img_name in os.listdir(class_path): img_path = class_path + img_name # 每個圖片的地址 img = Image.open(img_path) img = img.resize((208, 208)) img_raw = img.tobytes() # 將圖片轉化為二進位制格式 example = tf.train.Example(features=tf.train.Features(feature={ "label": _int64_feature(index), "img_raw": _bytes_feature(img_raw), })) writer.write(example.SerializeToString()) # 序列化為字串 writer.close() print("writed OK") #生成tfrecord檔案後,下次可以不用再執行這段程式碼!!! def read_and_decode(filename,batch_size): # read train.tfrecords filename_queue = tf.train.string_input_producer([filename]) reader = tf.TFRecordReader() _, serialized_example = reader.read(filename_queue) features = tf.parse_single_example(serialized_example, features={ 'label': tf.FixedLenFeature([], tf.int64), 'img_raw': tf.FixedLenFeature([], tf.string), }) img = tf.decode_raw(features['img_raw'], tf.float32) img = tf.reshape(img, [128, 128, 3]) # reshape image to 208*208*3 #據說下面這行多餘 #img = tf.cast(img,tf.float32)*(1./255)-0.5 label = tf.cast(features['label'], tf.int64) img_batch, label_batch = tf.train.shuffle_batch([img, label], batch_size=batch_size, num_threads = 8, capacity = 100, min_after_dequeue = 60,) return img_batch, tf.reshape(label_batch, [batch_size]) filename = './/train.tfrecords'
image_batch、label_batch = read_and_decode(filename,batch_size)
--------------------- 作者:hush_yang 來源:CSDN 原文:https://blog.csdn.net/qq_41004007/article/details/81987631 版權宣告:本文為博主原創文章,轉載請附上博文連結!