解析train-images-idx3-ubyte與train-labels-idx1-ubyte(mnist資料集)
阿新 • • 發佈:2018-12-28
import numpy as np import struct def decode_idx3_ubyte(idx3_ubyte_file): """ 解析idx3檔案的通用函式 :param idx3_ubyte_file: idx3檔案路徑 :return: 資料集 """ # 讀取二進位制資料 bin_data = open(idx3_ubyte_file, 'rb').read() # 解析檔案頭資訊,依次為魔數、圖片數量、每張圖片高、每張圖片寬 offset = 0 fmt_header = '>iiii' magic_number, num_images, num_rows, num_cols = struct.unpack_from(fmt_header, bin_data, offset) print ('魔數:%d, 圖片數量: %d張, 圖片大小: %d*%d' % (magic_number, num_images, num_rows, num_cols)) # 解析資料集 image_size = num_rows * num_cols offset += struct.calcsize(fmt_header) fmt_image = '>' + str(image_size) + 'B' images = np.empty((num_images, num_rows, num_cols)) for i in range(num_images): if (i + 1) % 10000 == 0: print ('已解析 %d' % (i + 1) + '張') images[i] = np.array(struct.unpack_from(fmt_image, bin_data, offset)).reshape((num_rows, num_cols)) offset += struct.calcsize(fmt_image) return images def decode_idx1_ubyte(idx1_ubyte_file): """ 解析idx1檔案的通用函式 :param idx1_ubyte_file: idx1檔案路徑 :return: 資料集 """ # 讀取二進位制資料 bin_data = open(idx1_ubyte_file, 'rb').read() # 解析檔案頭資訊,依次為魔數和標籤數 offset = 0 fmt_header = '>ii' magic_number, num_images = struct.unpack_from(fmt_header, bin_data, offset) print ('魔數:%d, 圖片數量: %d張' % (magic_number, num_images)) # 解析資料集 offset += struct.calcsize(fmt_header) fmt_image = '>B' labels = np.empty(num_images) for i in range(num_images): if (i + 1) % 10000 == 0: print ('已解析 %d' % (i + 1) + '張') labels[i] = struct.unpack_from(fmt_image, bin_data, offset)[0] offset += struct.calcsize(fmt_image) return labels