1. 程式人生 > 其它 >讀取MNIST資料

讀取MNIST資料

Python依賴庫:numpy matplotlib

資料集下載地址:http://yann.lecun.com/exdb/mnist/

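資料集的資料格式(IDX 格式,檔頭中的整數皆為大端序的 32 位元無號整數):

- 標籤檔:魔數 2049、標籤數量,之後每個標籤佔 1 個位元組。
- 影像檔:魔數 2051、影像數量、影像高度(rows)、影像寬度(columns),之後每張影像依序存放 rows × columns 個像素,每個像素佔 1 個位元組(0–255)。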

原始碼實現:

import struct

import numpy as np
import matplotlib.pyplot as plt

def unpack_mnist(filepath):
    """Parse an MNIST file in IDX format and return its contents.

    The file type is detected from its leading big-endian 32-bit magic
    number:

    * 2049 -- label file: returns ``(labelNum, labels)`` where ``labels``
      is a list of ints (0-255), one per item.
    * 2051 -- image file: returns ``(imgNum, rows, columns, imgs)`` where
      ``imgs`` is a list of tuples, each holding ``rows * columns`` unsigned
      byte pixel values in the order they are stored in the file.

    Args:
        filepath: Path to the raw IDX file. The bytes are parsed directly,
            so a gzip-compressed download must be decompressed first.

    Raises:
        ValueError: if the magic number is neither 2049 nor 2051.
        struct.error: if the file is shorter than its header declares.
    """
    # Read the whole file; the context manager closes it on every path.
    with open(filepath, 'rb') as f:
        buf = f.read()

    index = 0
    # Identify the file type from the magic number.
    (magicNum,) = struct.unpack_from('>I', buf, index)
    index += struct.calcsize('>I')

    if magicNum == 2049:
        # Label file: item count, then one unsigned byte per label.
        (labelNum,) = struct.unpack_from('>I', buf, index)
        index += struct.calcsize('>I')
        # Unpack all labels in one call; raises struct.error if truncated.
        labels = list(struct.unpack_from('>{}B'.format(labelNum), buf, index))
        return labelNum, labels

    if magicNum == 2051:
        # Image file: image count, rows, columns, then rows*columns bytes
        # per image.
        imgNum, rows, columns = struct.unpack_from('>III', buf, index)
        index += struct.calcsize('>III')
        # Derive the per-image size from the header instead of assuming 28x28.
        img_struct = struct.Struct('>{}B'.format(rows * columns))
        imgs = [img_struct.unpack_from(buf, index + i * img_struct.size)
                for i in range(imgNum)]
        return imgNum, rows, columns, imgs

    # Previously this only printed a message and returned None, which made
    # the caller's tuple unpacking fail with an unrelated TypeError.
    raise ValueError("input file error! unrecognized magic number {} in {!r}"
                     .format(magicNum, filepath))


if __name__ == "__main__":
    # Read the training data (kept under distinct names so it is not
    # silently overwritten by the test data below).
    trainImgNum, trainRows, trainColumns, trainImgs = unpack_mnist(
        "../data/train-images.idx3-ubyte")
    trainLabelNum, trainLabels = unpack_mnist(
        "../data/train-labels.idx1-ubyte")
    # Read the test data.
    imgNum, rows, columns, imgs = unpack_mnist("../data/t10k-images.idx3-ubyte")
    labelNum, labels = unpack_mnist("../data/t10k-labels.idx1-ubyte")

    # Index of the test image to display.
    index = 5
    # Use the image size from the file header rather than a hard-coded 28x28.
    img = np.array(imgs[index]).reshape(rows, columns)

    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.set_title(str(labels[index]))
    ax.imshow(img, cmap='gray')
    plt.show()