使用C++解析MNIST數據庫
阿新 • • 發佈:2018-07-17
可視化 win 大坑 fop 存儲 ring ios num 出了
遇到的兩個大坑
1.官方主頁給出了每個文件的字節數是個玄幻數據,training set images (9912422 bytes) ,這個字節數是解壓前的,解壓後字節數應該為47,040,016,這個數等於4 + 4 + 4 + 4 + 60000 * 28 * 28。
2.windows下的fgetc是個玄幻函數,以文本方式"r"讀取時會把字節0x1A(Ctrl+Z)當作文件結束而錯誤判斷EOF標誌(還會把\r\n轉換成\n),改成"rb",以字節流方式讀取即可。
#include <cstdio> #include <vector> #include <cstring> #include <iostream> #include <algorithm> #include <opencv2/opencv.hpp> using namespace std; using namespace cv; const int MnistTrainNumber = 6000; const int MnistTestNumber = 1000; //存儲像素信息 struct Image { cv::Mat pixs; Image() { pixs.create(Size(28, 28), CV_8U); } }; struct MnistImage { //檢驗值 int magicNumber; //圖片數量 int number; //圖片行數 int rows; //圖片列數 int cols; //圖片數組 vector<Image> images; }; struct MnistLabel { //檢驗值 int magicNumber; //標簽數量 int number; //標簽數組 vector<int> labels; }; //訓練集 struct MnistTrainSet { MnistImage trainImages; MnistLabel trainLabels; }; //測試集 struct MnistTestSet { MnistImage trainImages; MnistLabel trainLabels; }; //從file當前指針開始,連續讀取length個字節,返回讀取到的整數 int readData(FILE *file, int length) { int ans = 0; for (int i = 0; i < length; i++) { ans = ans * 256 + fgetc(file); } return ans; } //解析圖片字節流文件 int parseMnistImage(const char *fileName, MnistImage &mnistImage, int imagesNumber) { FILE *out = fopen(fileName, "rb"); if (out == NULL) return -1; mnistImage.magicNumber = readData(out, 4); mnistImage.number = readData(out, 4); mnistImage.rows = readData(out, 4); mnistImage.cols = readData(out, 4); for (int k = 0; k < imagesNumber; k++) { Image image; for (int i = 0; i < 28; i++) { for (int j = 0; j < 28; j++) { int x = fgetc(out); image.pixs.at<uchar>(i, j) = x; } } mnistImage.images.push_back(image); } fclose(out); return mnistImage.magicNumber; } //解析標簽字節流文件 int parseMnistLabel(const char *fileName, MnistLabel &mnistLabel, int labelNumber) { FILE *out = fopen(fileName, "rb"); if (out == NULL) return -1; mnistLabel.magicNumber = readData(out, 4); mnistLabel.number = readData(out, 4); for (int i = 0; i < labelNumber; i++) { int x = fgetc(out); mnistLabel.labels.push_back(x); } fclose(out); return mnistLabel.magicNumber; } void virtualizeData(Mat &mat) { imshow("virtualizeData", mat); waitKey(); } int main() { MnistTrainSet mnistTrainSet; MnistTestSet mnistTestSet; //magic 
number分別為2051,2049,2051,2049,與官方提供的檢驗值比對以確定解析程序是否有誤 cout << parseMnistImage("train-images.idx3-ubyte", mnistTrainSet.trainImages, MnistTrainNumber) << endl; cout << parseMnistLabel("train-labels.idx1-ubyte", mnistTrainSet.trainLabels, MnistTrainNumber) << endl; cout << parseMnistImage("t10k-images.idx3-ubyte", mnistTestSet.trainImages, MnistTestNumber) << endl; cout << parseMnistLabel("t10k-labels.idx1-ubyte", mnistTestSet.trainLabels, MnistTestNumber) << endl; //可視化訓練集中第k張圖片 int k = 5; virtualizeData(mnistTrainSet.trainImages.images[k].pixs); return 0; }
使用C++解析MNIST數據庫