1. 程式人生 > >python讀取mnist資料

python讀取mnist資料

這裡用python實現了對mnist手寫數字識別圖片的讀取
函式read_data_sets的輸入是mnist中那四份檔案所在的目錄名,輸出是如下資料結構:

{
	'test': {
		'images': [[...],[...]...],  # numpy陣列,其大小為[10000, 784],每一行即對應一張圖片
		'labels': [...]  # numpy陣列,若引數one-hot為True,則大小為[10000,10],否則為[10000,]
	},
	'train': {
		'images': [[...],[...],...],  # numpy陣列,其大小為[60000, 784]
		'labels'
: [...] # numpy陣列,若引數one-hot為True,則大小為[60000,10],否則為[60000,] }, 'img_shape': (img_h,img_w,img_c) # 表示影象的大小,依次含義為“高"、“寬”、“通道”。在mnist中,實質為(28,28,1) }

函式引數含義:

  • dir: mnist的四個資料檔案所在的檔案目錄名
  • one_hot: 若為True,則其標籤是 one-hot 編碼,否則就是其對應的數字

注意,這裡對images進行了歸一化和 flatten 處理,若要顯示影象,應先將之線性縮放到區間[0,255],同還原圖片的大小。具體可參考程式碼裡的 if main

# -*- coding: utf-8 -*-
import numpy as np
import struct
import os
from collections import defaultdict

def read_data_sets(dir, one_hot=True):
	files = {
		'test': ['t10k-images.idx3-ubyte', 't10k-labels.idx1-ubyte'],
		'train': ['train-images.idx3-ubyte', 'train-labels.idx1-ubyte']
	}
	data_set = defaultdict(
dict) for key,value in files.items(): for i,fn in enumerate(value): file = open(os.path.join(dir, fn), 'rb') f = file.read() file.close() if not i: img_index = struct.calcsize('>IIII') _,size,row,column = struct.unpack('>IIII', f[:img_index]) imgs = struct.unpack_from(str(size*row*column) + 'B', f, img_index) data_set['img_shape'] = (row, column, 1) imgs = np.reshape(imgs, (size, row*column)).astype(np.float32) imgs = (imgs - np.min(imgs)) / (np.max(imgs) - np.min(imgs)) data_set[key]['images'] = imgs else: label_index = struct.calcsize('>II') _,size = struct.unpack('>II', f[:label_index]) labels = struct.unpack_from(str(size) + 'B', f, label_index) labels = np.reshape(labels, (size,)) if one_hot: tmp = np.zeros((size, np.max(labels)+1)) tmp[np.arange(size),labels] = 1 labels = tmp data_set[key]['labels'] = labels return data_set if __name__ == '__main__': import matplotlib.pyplot as plt data_set = read_data_sets('data') imgs = data_set['train']['images'] * 255 labels = data_set['train']['labels'] img_shape = data_set['img_shape'] * 255 # print(imgs[0]) plt.figure() plt.subplot(221); plt.imshow(imgs[0].reshape(img_shape[:2])); print(np.argmax(labels[0])) plt.subplot(222); plt.imshow(imgs[10].reshape(img_shape[:2])); print(np.argmax(labels[10])) plt.subplot(223); plt.imshow(imgs[100].reshape(img_shape[:2])); print(np.argmax(labels[100])) plt.subplot(224); plt.imshow(imgs[1000].reshape(img_shape[:2])); print(np.argmax(labels[1000])) plt.show()