1. 程式人生 > >使用C++解析MNIST數據庫

使用C++解析MNIST數據庫

可視化 win 大坑 fop 存儲 ring ios num 出了

遇到的兩個大坑
1.官方主頁給出的每個文件的字節數容易誤導:training set images (9912422 bytes) 這個字節數是解壓前(壓縮包)的大小,解壓後的字節數應該為47,040,016,這個數等於4 + 4 + 4 + 4 + 60000 * 28 * 28(4個4字節的文件頭字段,加上60000張28×28的圖片,每個像素1字節)。
2.windows下以文本方式"r"打開文件時,fgetc會錯誤判斷EOF標誌:文本模式會把字節0x1A(Ctrl+Z)當作文件結束,並且會轉換回車換行符,而圖片數據中隨時可能出現這些字節。改成"rb",以二進制字節流方式讀取即可。

#include <cstdio>
#include <vector>
#include <cstring>
#include <iostream>
#include <algorithm>
#include <opencv2/opencv.hpp>
using namespace std;
using namespace cv;

// Number of training images/labels that main() asks the parsers to read.
// NOTE(review): the uncompressed training image file holds 60000 images, so this
// reads only a subset — confirm that is intended.
const int MnistTrainNumber = 6000;
// Number of test images/labels that main() asks the parsers to read.
// NOTE(review): presumably also a subset of the t10k files — confirm.
const int MnistTestNumber = 1000;

//存儲像素信息
struct Image
{
    cv::Mat pixs;
    Image()
    {
        pixs.create(Size(28, 28), CV_8U);
    }
};
struct MnistImage
{
    //檢驗值
    int magicNumber;
    //圖片數量
    int number;
    //圖片行數
    int rows;
    //圖片列數
    int cols;
    //圖片數組
    vector<Image> images;
};
// Contents of a label file: the two header integers followed by the labels.
struct MnistLabel
{
    // Magic number read from the file header (check value).
    int magicNumber;
    // Number of labels declared in the header.
    int number;
    // Decoded labels, in file order (one byte each in the file).
    std::vector<int> labels;

    // Zero the header fields so they never hold indeterminate values when a
    // parse fails before assigning them (e.g. the file could not be opened).
    MnistLabel() : magicNumber(0), number(0) {}
};
// Training set: the parsed training images together with their labels.
struct MnistTrainSet
{
    // Images parsed from the training image file.
    MnistImage trainImages;
    // Labels parsed from the training label file.
    MnistLabel trainLabels;
};
// Test set: the parsed test images together with their labels.
// NOTE(review): the members keep the "train" prefix although main() fills them
// from the t10k (test) files; renaming would require updating every user.
struct MnistTestSet
{
    // Images parsed from the test image file.
    MnistImage trainImages;
    // Labels parsed from the test label file.
    MnistLabel trainLabels;
};
// Reads `length` consecutive bytes starting at the current position of `file`
// and combines them, most significant byte first, into a single integer.
//
// Returns the decoded integer, or -1 if the stream ends or a read error occurs
// before `length` bytes were read. A `length` of 0 (or less) reads nothing and
// returns 0. Note that a genuine 4-byte value of 0xFFFFFFFF is indistinguishable
// from the -1 failure value.
int readData(FILE *file, int length)
{
    // Accumulate unsigned: a leading byte >= 0x80 would overflow a signed int,
    // which is undefined behaviour.
    unsigned int value = 0;
    for (int i = 0; i < length; i++)
    {
        const int byte = fgetc(file);
        if (byte == EOF)
        {
            // Do not fold the EOF marker (-1) into the result as if it were data.
            return -1;
        }
        value = value * 256u + static_cast<unsigned int>(byte);
    }
    return static_cast<int>(value);
}
//解析圖片字節流文件
int parseMnistImage(const char *fileName, MnistImage &mnistImage, int imagesNumber)
{
    FILE *out = fopen(fileName, "rb");
    if (out == NULL) return -1;
    mnistImage.magicNumber = readData(out, 4);
    mnistImage.number = readData(out, 4);
    mnistImage.rows = readData(out, 4);
    mnistImage.cols = readData(out, 4);
    for (int k = 0; k < imagesNumber; k++)
    {
        Image image;
        for (int i = 0; i < 28; i++)
        {
            for (int j = 0; j < 28; j++)
            {
                int x = fgetc(out);
                image.pixs.at<uchar>(i, j) = x;
            }
        }
        mnistImage.images.push_back(image);
    }
    fclose(out);
    return mnistImage.magicNumber;
}
// Parses a label byte-stream file.
//
// Reads the two 4-byte header fields (magic number, label count) into
// `mnistLabel`, then appends up to `labelNumber` one-byte labels to
// `mnistLabel.labels`. Never reads more labels than the header declares.
//
// fileName    - path of the label file; opened in binary mode.
// mnistLabel  - receives the header fields and the decoded labels.
// labelNumber - maximum number of labels to read.
//
// Returns the magic number from the header on success, or -1 when the file
// cannot be opened, the header is incomplete, or the label data ends early.
// Labels read before a failure stay in `mnistLabel.labels`.
int parseMnistLabel(const char *fileName, MnistLabel &mnistLabel, int labelNumber)
{
    FILE *file = fopen(fileName, "rb");
    if (file == NULL) return -1;

    mnistLabel.magicNumber = readData(file, 4);
    mnistLabel.number = readData(file, 4);
    if (feof(file) || ferror(file))
    {
        fclose(file);
        return -1;
    }

    // Reading past the declared label count would only produce garbage labels.
    const int count = std::min(labelNumber, mnistLabel.number);
    for (int i = 0; i < count; i++)
    {
        const int label = fgetc(file);
        if (label == EOF)
        {
            // Truncated file: do not store the EOF marker (-1) as a label.
            fclose(file);
            return -1;
        }
        mnistLabel.labels.push_back(label);
    }

    fclose(file);
    return mnistLabel.magicNumber;
}
// Shows `mat` in a window titled "virtualizeData" and blocks until a key is pressed.
// An empty matrix is reported on stderr and skipped instead of being passed to
// imshow (which, as far as the OpenCV docs indicate, rejects images without a size).
void virtualizeData(Mat &mat)
{
    if (mat.empty())
    {
        std::cerr << "virtualizeData: empty image, nothing to display" << std::endl;
        return;
    }
    imshow("virtualizeData", mat);
    waitKey();
}
int main()
{
    MnistTrainSet mnistTrainSet;
    MnistTestSet mnistTestSet;
    //magic number分別為2051,2049,2051,2049,與官方提供的檢驗值比對以確定解析程序是否有誤
    cout << parseMnistImage("train-images.idx3-ubyte", mnistTrainSet.trainImages, MnistTrainNumber) << endl;
    cout << parseMnistLabel("train-labels.idx1-ubyte", mnistTrainSet.trainLabels, MnistTrainNumber) << endl;
    cout << parseMnistImage("t10k-images.idx3-ubyte", mnistTestSet.trainImages, MnistTestNumber) << endl;
    cout << parseMnistLabel("t10k-labels.idx1-ubyte", mnistTestSet.trainLabels, MnistTestNumber) << endl;
    //可視化訓練集中第k張圖片
    int k = 5;
    virtualizeData(mnistTrainSet.trainImages.images[k].pixs);
    return 0;
}

使用C++解析MNIST數據庫