python讀取mnist資料
阿新 • • 發佈:2019-01-05
這裡用python實現了對mnist手寫數字識別圖片的讀取
函式read_data_sets
的輸入是mnist中那四份檔案所在的目錄名,輸出是如下資料結構:
{
'test': {
'images': [[...],[...]...], # numpy陣列,其大小為[10000, 784],每一行即對應一張圖片
'labels': [...] # numpy陣列,若引數one-hot為True,則大小為[10000,10],否則為[10000,]
},
'train': {
'images': [[...],[...],...], # numpy陣列,其大小為[60000, 784]
'labels' : [...] # numpy陣列,若引數one-hot為True,則大小為[60000,10],否則為[60000,]
},
'img_shape': (img_h,img_w,img_c) # 表示影象的大小,依次含義為“高"、“寬”、“通道”。在mnist中,實質為(28,28,1)
}
函式引數含義:
dir
: mnist的四個資料檔案所在的檔案目錄名one_hot
: 若為True
,則其標籤是 one-hot 編碼,否則就是其對應的數字
注意,這裡對images
進行了歸一化和 flatten 處理,若要顯示影象,應先將之線性縮放到區間[0,255]
,同還原圖片的大小。具體可參考程式碼裡的 if main
# -*- coding: utf-8 -*-
import numpy as np
import struct
import os
from collections import defaultdict
def read_data_sets(dir, one_hot=True):
files = {
'test': ['t10k-images.idx3-ubyte', 't10k-labels.idx1-ubyte'],
'train': ['train-images.idx3-ubyte', 'train-labels.idx1-ubyte']
}
data_set = defaultdict( dict)
for key,value in files.items():
for i,fn in enumerate(value):
file = open(os.path.join(dir, fn), 'rb')
f = file.read()
file.close()
if not i:
img_index = struct.calcsize('>IIII')
_,size,row,column = struct.unpack('>IIII', f[:img_index])
imgs = struct.unpack_from(str(size*row*column) + 'B', f, img_index)
data_set['img_shape'] = (row, column, 1)
imgs = np.reshape(imgs, (size, row*column)).astype(np.float32)
imgs = (imgs - np.min(imgs)) / (np.max(imgs) - np.min(imgs))
data_set[key]['images'] = imgs
else:
label_index = struct.calcsize('>II')
_,size = struct.unpack('>II', f[:label_index])
labels = struct.unpack_from(str(size) + 'B', f, label_index)
labels = np.reshape(labels, (size,))
if one_hot:
tmp = np.zeros((size, np.max(labels)+1))
tmp[np.arange(size),labels] = 1
labels = tmp
data_set[key]['labels'] = labels
return data_set
if __name__ == '__main__':
import matplotlib.pyplot as plt
data_set = read_data_sets('data')
imgs = data_set['train']['images'] * 255
labels = data_set['train']['labels']
img_shape = data_set['img_shape'] * 255
# print(imgs[0])
plt.figure()
plt.subplot(221); plt.imshow(imgs[0].reshape(img_shape[:2])); print(np.argmax(labels[0]))
plt.subplot(222); plt.imshow(imgs[10].reshape(img_shape[:2])); print(np.argmax(labels[10]))
plt.subplot(223); plt.imshow(imgs[100].reshape(img_shape[:2])); print(np.argmax(labels[100]))
plt.subplot(224); plt.imshow(imgs[1000].reshape(img_shape[:2])); print(np.argmax(labels[1000]))
plt.show()