python讀取cifar10資料集
阿新 • • 發佈:2018-11-29
最近學習卷積網路用到cifar10資料集,自己寫了一個工具類,用來讀取已經下載到本地的cifar10資料集。
程式碼寫的不算好,但是自己用起來還可以。所以放到網上,有需要的可以拿去用。程式碼比較少,所以沒有寫註釋。下面介紹一下實現的功能。完整的程式碼可以在github上下載。地址:https://github.com/NewQJX/DeepLearning/tree/master/Cifar10
檔名為:input_data.py
建立了一個類Cifar10():用於讀取本地資料集,對資料集進行操作
__init__(self, path, one_hot = True): 引數path為本地資料集儲存路徑。one_hot:決定是否對類別獨熱編碼
_load_data():用於載入資料集
next_batch(batch_size, shuffle = True): 該方法返回指定batch_size大小的訓練集, shuffle:決定是否打亂順序
下面是使用該類的方法:
import input_data import numpy as np path = r"E:\pythonCode\TensorFlow\cifar10\cifar-10-batches-py" cifar10 = input_data.load_cifar10(path, one_hot = True) images = cifar10.images print("訓練集圖片:" + str(images.shape)) labels = cifar10.labels print("訓練集類別:" + str(labels.shape)) test_images = cifar10.test.images print("測試集圖片:"+ str(test_images.shape)) test_labels = cifar10.test.labels print("測試集類別:"+ str(test_labels.shape)) batch_xs, batch_ys = cifar10.next_batch(batch_size = 500, shuffle = True) print("batch_xs shape is:" + str(batch_xs.shape)) print("batch_ys shape is:" + str(batch_ys.shape))