1. 程式人生 > >解析train-images-idx3-ubyte與train-labels-idx1-ubyte(mnist資料集)

解析train-images-idx3-ubyte與train-labels-idx1-ubyte(mnist資料集)

import numpy as np
import struct

def decode_idx3_ubyte(idx3_ubyte_file):
    """
    解析idx3檔案的通用函式
    :param idx3_ubyte_file: idx3檔案路徑
    :return: 資料集
    """
    # 讀取二進位制資料
    bin_data = open(idx3_ubyte_file, 'rb').read()
 
    # 解析檔案頭資訊,依次為魔數、圖片數量、每張圖片高、每張圖片寬
    offset = 0
    fmt_header = '>iiii'
    magic_number, num_images, num_rows, num_cols = struct.unpack_from(fmt_header, bin_data, offset)
    print ('魔數:%d, 圖片數量: %d張, 圖片大小: %d*%d' % (magic_number, num_images, num_rows, num_cols))
 
    # 解析資料集
    image_size = num_rows * num_cols
    offset += struct.calcsize(fmt_header)
    fmt_image = '>' + str(image_size) + 'B'
    images = np.empty((num_images, num_rows, num_cols))
    for i in range(num_images):
        if (i + 1) % 10000 == 0:
            print ('已解析 %d' % (i + 1) + '張')
        images[i] = np.array(struct.unpack_from(fmt_image, bin_data, offset)).reshape((num_rows, num_cols))
        offset += struct.calcsize(fmt_image)
    return images
 
 
def decode_idx1_ubyte(idx1_ubyte_file):
    """
    解析idx1檔案的通用函式
    :param idx1_ubyte_file: idx1檔案路徑
    :return: 資料集
    """
    # 讀取二進位制資料
    bin_data = open(idx1_ubyte_file, 'rb').read()
 
    # 解析檔案頭資訊,依次為魔數和標籤數
    offset = 0
    fmt_header = '>ii'
    magic_number, num_images = struct.unpack_from(fmt_header, bin_data, offset)
    print ('魔數:%d, 圖片數量: %d張' % (magic_number, num_images))
 
    # 解析資料集
    offset += struct.calcsize(fmt_header)
    fmt_image = '>B'
    labels = np.empty(num_images)
    for i in range(num_images):
        if (i + 1) % 10000 == 0:
            print ('已解析 %d' % (i + 1) + '張')
        labels[i] = struct.unpack_from(fmt_image, bin_data, offset)[0]
        offset += struct.calcsize(fmt_image)
    return labels