讀取MNIST資料
阿新 • 發佈:2021-12-20
Python依賴庫:numpy matplotlib
資料集下載地址:http://yann.lecun.com/exdb/mnist/
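資料集的資料格式(IDX 格式,檔頭整數皆為大端序 32 位元無號整數):
- 標籤檔(*-labels.idx1-ubyte):魔數 2049、標籤數量 N,其後為 N 個無號位元組,每個位元組是一張影像的標籤(數字 0–9)。
- 影像檔(*-images.idx3-ubyte):魔數 2051、影像數量 N、影像高度 rows、影像寬度 columns,其後為 N × rows × columns 個無號位元組的像素值(0–255),每張影像依列優先(row-major)順序排列。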
# Source implementation:
import struct

import numpy as np

# Magic numbers that open the two kinds of MNIST IDX file.
_LABEL_MAGIC = 2049
_IMAGE_MAGIC = 2051
# Header fields are big-endian unsigned 32-bit integers.
_UINT32 = struct.Struct(">I")
_IMAGE_DIMS = struct.Struct(">III")  # image count, rows, columns


def _check_payload(buf, offset, nbytes, filepath):
    """Raise ValueError if fewer than *nbytes* bytes follow *offset* in *buf*."""
    available = len(buf) - offset
    if available < nbytes:
        raise ValueError(
            f"input file error! {filepath} is truncated: expected {nbytes} "
            f"data bytes after offset {offset}, found {available}"
        )


def _unpack_labels(buf, offset, filepath):
    """Decode a label file body: a uint32 count, then one unsigned byte per label."""
    (label_num,) = _UINT32.unpack_from(buf, offset)
    offset += _UINT32.size
    _check_payload(buf, offset, label_num, filepath)
    # Iterating a bytes slice yields ints -- the same values struct '>B' gives.
    labels = list(buf[offset:offset + label_num])
    return label_num, labels


def _unpack_images(buf, offset, filepath):
    """Decode an image file body: count, rows, columns, then the pixel bytes."""
    img_num, rows, columns = _IMAGE_DIMS.unpack_from(buf, offset)
    offset += _IMAGE_DIMS.size
    # Each image is rows * columns unsigned bytes; the size comes from the
    # header rather than being hard-coded to 784 (28 x 28).
    image = struct.Struct(f">{rows * columns}B")
    if image.size == 0:
        # iter_unpack rejects zero-length structs; every image is empty.
        return img_num, rows, columns, [()] * img_num
    payload_size = img_num * image.size
    _check_payload(buf, offset, payload_size, filepath)
    # memoryview slices without copying the pixel block before decoding.
    payload = memoryview(buf)[offset:offset + payload_size]
    imgs = list(image.iter_unpack(payload))
    return img_num, rows, columns, imgs


def unpack_mnist(filepath):
    """Read an MNIST IDX file and decode its contents.

    The file kind is chosen from the leading big-endian uint32 magic number:

    * 2049 (label file): returns ``(labelNum, labels)`` where *labels* is a
      list of ints, one per label byte.
    * 2051 (image file): returns ``(imgNum, rows, columns, imgs)`` where
      *imgs* is a list of tuples, each holding ``rows * columns`` pixel ints.

    Args:
        filepath: path of the IDX file to read.

    Raises:
        ValueError: if the magic number is not 2049/2051, or the file holds
            fewer data bytes than its header announces.
        struct.error: if the file is too short to contain its header.
    """
    # 'with' guarantees the file is closed even if decoding raises.
    with open(filepath, "rb") as f:
        buf = f.read()
    (magic_num,) = _UINT32.unpack_from(buf, 0)
    offset = _UINT32.size
    if magic_num == _LABEL_MAGIC:
        return _unpack_labels(buf, offset, filepath)
    if magic_num == _IMAGE_MAGIC:
        return _unpack_images(buf, offset, filepath)
    # Previously this printed and returned None, which then crashed every
    # caller's tuple unpacking with an unrelated TypeError.
    raise ValueError(
        f"input file error! unknown magic number {magic_num} in {filepath}"
    )


def _show_sample(imgs, labels, rows, columns, index):
    """Show image *index* in grayscale with its label as the title."""
    # Imported here so unpack_mnist can be used without matplotlib installed.
    import matplotlib.pyplot as plt

    img = np.array(imgs[index]).reshape(rows, columns)
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.set_title(str(labels[index]))
    ax.imshow(img, cmap="gray")
    plt.show()


if __name__ == "__main__":
    # Training data (kept in its own names instead of being overwritten).
    train_num, train_rows, train_columns, train_imgs = unpack_mnist(
        "../data/train-images.idx3-ubyte")
    train_label_num, train_labels = unpack_mnist(
        "../data/train-labels.idx1-ubyte")
    # Test data.
    test_num, test_rows, test_columns, test_imgs = unpack_mnist(
        "../data/t10k-images.idx3-ubyte")
    test_label_num, test_labels = unpack_mnist(
        "../data/t10k-labels.idx1-ubyte")
    print(f"train: {train_num} images of {train_rows}x{train_columns}, "
          f"{train_label_num} labels")
    print(f"test: {test_num} images of {test_rows}x{test_columns}, "
          f"{test_label_num} labels")
    # Index of the test image to display.
    index = 5
    _show_sample(test_imgs, test_labels, test_rows, test_columns, index)