機器學習(2) - KNN識別MNIST
代碼
https://github.com/s055523/MNISTTensorFlowSharp
數據的獲得
數據可以由http://yann.lecun.com/exdb/mnist/下載。之後,儲存在trainDir中,下次就不需要下載了。
/// <summary> /// 如果文件不存在就去下載 /// </summary> /// <param name="urlBase">下載地址</param> /// <param name="trainDir">文件目錄地址</param>View Code/// <param name="file">文件名</param> /// <returns></returns> public static Stream MaybeDownload(string urlBase, string trainDir, string file) { if (!Directory.Exists(trainDir)) { Directory.CreateDirectory(trainDir); }var target = Path.Combine(trainDir, file); if (!File.Exists(target)) { var wc = new WebClient(); wc.DownloadFile(urlBase + file, target); } return File.OpenRead(target); }
數據格式處理
下載下來的文件共有四個,都是擴展名為gz的壓縮包。
train-images-idx3-ubyte.gz 55000張訓練圖片和5000張驗證圖片
train-labels-idx1-ubyte.gz 訓練圖片對應的數字標簽(即答案)
t10k-images-idx3-ubyte.gz 10000張測試圖片
t10k-labels-idx1-ubyte.gz 測試圖片對應的數字標簽(即答案)
處理圖片數據壓縮包
每個壓縮包的格式為:
偏移量 |
類型 |
值 |
意義 |
0 |
Int32 |
2051或2049 |
一個定死的魔術數。用來驗證該壓縮包是訓練集(2051)或測試集(2049) |
4 |
Int32 |
60000或10000 |
壓縮包的圖片數 |
8 |
Int32 |
28 |
每個圖片的行數 |
12 |
Int32 |
28 |
每個圖片的列數 |
16 |
Unsigned byte |
0 - 255 |
第一張圖片的第一個像素 |
17 |
Unsigned byte |
0 - 255 |
第一張圖片的第二個像素 |
… |
… |
… |
… |
因此,我們可以使用一個統一的方式將數據處理。我們只需要那些圖片像素。
/// <summary> /// 從數據流中讀取下一個int32 /// </summary> /// <param name="s"></param> /// <returns></returns> int Read32(Stream s) { var x = new byte[4]; s.Read(x, 0, 4); return DataConverter.BigEndian.GetInt32(x, 0); } /// <summary> /// 處理圖片數據 /// </summary> /// <param name="input"></param> /// <param name="file"></param> /// <returns></returns> MnistImage[] ExtractImages(Stream input, string file) { //文件是gz格式的 using (var gz = new GZipStream(input, CompressionMode.Decompress)) { //不是2051說明下載的文件不對 if (Read32(gz) != 2051) { throw new Exception("不是2051說明下載的文件不對: " + file); } //圖片數 var count = Read32(gz); //行數 var rows = Read32(gz); //列數 var cols = Read32(gz); Console.WriteLine($"準備讀取{count}張圖片。"); var result = new MnistImage[count]; for (int i = 0; i < count; i++) { //圖片的大小(每個像素占一個bit) var size = rows * cols; var data = new byte[size]; //從數據流中讀取這麽大的一塊內容 gz.Read(data, 0, size); //將讀取到的內容轉換為MnistImage類型 result[i] = new MnistImage(cols, rows, data); } return result; } }View Code
準備一個MnistImage類型:
/// <summary> /// 圖片類型 /// </summary> public struct MnistImage { public int Cols, Rows; public byte[] Data; public float[] DataFloat; public MnistImage(int cols, int rows, byte[] data) { Cols = cols; Rows = rows; Data = data; DataFloat = new float[data.Length]; for (int i = 0; i < data.Length; i++) { //數據歸一化(這裏將0-255除255變成了0-1之間的小數) //也可以歸一為-0.5到0.5之間 DataFloat[i] = Data[i] / 255f; } } }View Code
這樣一來,圖片數據就處理完成了。
處理數字標簽數據壓縮包
數字標簽數據壓縮包和圖片數據壓縮包的格式類似。
偏移量 |
類型 |
值 |
意義 |
0 |
Int32 |
2051或2049 |
一個定死的魔術數。用來驗證該壓縮包是訓練集(2051)或測試集(2049) |
4 |
Int32 |
60000或10000 |
壓縮包的數字標簽數 |
5 |
Unsigned byte |
0 - 9 |
第一張圖片對應的數字 |
6 |
Unsigned byte |
0 - 9 |
第二張圖片對應的數字 |
… |
… |
… |
… |
它的處理更加簡單。
/// <summary> /// 處理標簽數據 /// </summary> /// <param name="input"></param> /// <param name="file"></param> /// <returns></returns> byte[] ExtractLabels(Stream input, string file) { using (var gz = new GZipStream(input, CompressionMode.Decompress)) { //不是2049說明下載的文件不對 if (Read32(gz) != 2049) { throw new Exception("不是2049說明下載的文件不對:" + file); } var count = Read32(gz); var labels = new byte[count]; gz.Read(labels, 0, count); return labels; } }View Code
將數字標簽轉化為二維數組:one-hot編碼
由於我們的數字為0-9,所以,可以視為有十個class。此時,為了後續的處理方便,我們將數字標簽轉化為數組。因此,一組標簽就轉換為了一個二維數組。
例如,標簽0變成[1,0,0,0,0,0,0,0,0,0]