機器學習框架ML.NET學習筆記【4】多元分類之手寫數字識別
一、問題與解決方案
通過多元分類演算法進行手寫數字識別,手寫數字的圖片解析度為8*8的灰度圖片、已經預先進行過處理,讀取了各畫素點的灰度值,並進行了標記。
其中第0列是序號(不參與運算)、1-64列是畫素值、65列是結果。
我們以64位畫素值為特徵進行多元分類,演算法採用SDCA最大熵分類演算法。
二、原始碼
先貼出全部程式碼:
namespace MulticlassClassification_Mnist { class Program { static readonly string TrainDataPath = Path.Combine(Environment.CurrentDirectory, "Data", "optdigits-full.csv"); static readonly string ModelPath = Path.Combine(Environment.CurrentDirectory, "Data", "SDCA-Model.zip"); static void Main(string[] args) { MLContext mlContext = new MLContext(seed: 1); TrainAndSaveModel(mlContext); TestSomePredictions(mlContext); Console.WriteLine("Hit any key to finish the app"); Console.ReadKey(); } public static void TrainAndSaveModel(MLContext mlContext) { // STEP 1: 準備資料 var fulldata = mlContext.Data.LoadFromTextFile(path: TrainDataPath, columns: new[] { new TextLoader.Column("Serial", DataKind.Single, 0), new TextLoader.Column("PixelValues", DataKind.Single, 1, 64), new TextLoader.Column("Number", DataKind.Single, 65) }, hasHeader: true, separatorChar: ',' ); var trainTestData = mlContext.Data.TrainTestSplit(fulldata, testFraction: 0.2); var trainData = trainTestData.TrainSet; var testData = trainTestData.TestSet; // STEP 2: 配置資料處理管道 var dataProcessPipeline = mlContext.Transforms.Conversion.MapValueToKey("Label", "Number", keyOrdinality: ValueToKeyMappingEstimator.KeyOrdinality.ByValue); // STEP 3: 配置訓練演算法 var trainer = mlContext.MulticlassClassification.Trainers.SdcaMaximumEntropy(labelColumnName: "Label", featureColumnName: "PixelValues"); var trainingPipeline = dataProcessPipeline.Append(trainer) .Append(mlContext.Transforms.Conversion.MapKeyToValue("Number", "Label")); // STEP 4: 訓練模型使其與資料集擬合 Console.WriteLine("=============== Train the model fitting to the DataSet ==============="); ITransformer trainedModel = trainingPipeline.Fit(trainData); // STEP 5:評估模型的準確性 Console.WriteLine("===== Evaluating Model's accuracy with Test data ====="); var predictions = trainedModel.Transform(testData); var metrics = mlContext.MulticlassClassification.Evaluate(data: predictions, labelColumnName: "Number", scoreColumnName: "Score"); PrintMultiClassClassificationMetrics(trainer.ToString(), metrics); // STEP 6:儲存模型 mlContext.ComponentCatalog.RegisterAssembly(typeof(DebugConversion).Assembly); mlContext.Model.Save(trainedModel, trainData.Schema, ModelPath); Console.WriteLine("The model is saved to {0}", ModelPath); } private static void TestSomePredictions(MLContext mlContext) { // Load Model ITransformer trainedModel = mlContext.Model.Load(ModelPath, out var modelInputSchema); // Create prediction engine var predEngine = mlContext.Model.CreatePredictionEngine<InputData, OutPutData>(trainedModel); //num 1 InputData MNIST1 = new InputData() { PixelValues = new float[] { 0, 0, 0, 0, 14, 13, 1, 0, 0, 0, 0, 5, 16, 16, 2, 0, 0, 0, 0, 14, 16, 12, 0, 0, 0, 1, 10, 16, 16, 12, 0, 0, 0, 3, 12, 14, 16, 9, 0, 0, 0, 0, 0, 5, 16, 15, 0, 0, 0, 0, 0, 4, 16, 14, 0, 0, 0, 0, 0, 1, 13, 16, 1, 0 } }; var resultprediction1 = predEngine.Predict(MNIST1); resultprediction1.PrintToConsole(); } } class InputData { public float Serial; [VectorType(64)] public float[] PixelValues; public float Number; } class OutPutData : InputData { public float[] Score; } }
三、分析
整體流程和二元分類沒有什麼區別,下面解釋一下有差異的兩個地方。
1、載入資料
// STEP 1: 準備資料 var fulldata = mlContext.Data.LoadFromTextFile(path: TrainDataPath, columns: new[] { new TextLoader.Column("Serial", DataKind.Single, 0), new TextLoader.Column("PixelValues", DataKind.Single, 1, 64), new TextLoader.Column("Number", DataKind.Single, 65) }, hasHeader: true, separatorChar: ',' );
這次我們不是通過實體物件來載入資料,而是通過列資訊來進行載入,其中PixelValues是特徵值,Number是標籤值。
2、訓練通道
// STEP 2: 配置資料處理管道 var dataProcessPipeline = mlContext.Transforms.Conversion.MapValueToKey("Label", "Number", keyOrdinality: ValueToKeyMappingEstimator.KeyOrdinality.ByValue)
// STEP 3: 配置訓練演算法 var trainer = mlContext.MulticlassClassification.Trainers.SdcaMaximumEntropy(labelColumnName: "Label", featureColumnName: "PixelValues");
var trainingPipeline = dataProcessPipeline.Append(trainer)
.Append(mlContext.Transforms.Conversion.MapKeyToValue("Number", "Label"));
// STEP 4: 訓練模型使其與資料集擬合
ITransformer trainedModel = trainingPipeline.Fit(trainData);
首先通過MapValueToKey方法將Number值轉換為Key型別,多元分類演算法要求標籤值必須是這種型別(類似列舉型別,二元分類要求標籤為BOOL型別)。關於這個轉換的原因及編碼方式,下面詳細介紹。
四、鍵值型別編碼與獨熱編碼
MapValueToKey功能是將(字串)值型別轉換為KeyTpye型別。
有時候某些輸入欄位用來表示型別(類別特徵),但本身並沒有特別的含義,比如編號、電話號碼、行政區域名稱或編碼等,這裡需要把這些型別轉換為1到一個整數如1-300來進行重新編號。
舉個簡單的例子,我們進行圖片識別的時候,目標結果可能是“貓咪”、“小狗”、“人物”這些分類,需要把這些分類轉換為1、2、3這樣的整數。但本文的標籤值本身就是1、2、3,為什麼還要轉換呢?因為我們這裡的一二三其實不是數學意義上的數字,而是一種標誌,可以理解為壹、貳、叄,所以要進行編碼。
MapKeyToValue和MapValueToKey相反,它把將鍵型別轉換回其原始值(字串)。就是說標籤是文字格式,在運算前已經被轉換為數字列舉型別了,此時預測結果為數字,通過MapKeyToValue將其結果轉換為對應文字。
MapValueToKey一般是對標籤值進行編碼,一般不用於特徵值,如果是特徵值為字串型別的,建議採用獨熱編碼。獨熱編碼即 One-Hot 編碼,又稱一位有效編碼,其方法是使用N位狀態暫存器來對N個狀態進行編碼,每個狀態都由他獨立的暫存器位,並且在任意時候,其中只有一位有效。例如:
自然狀態碼為:0,1,2,3,4,5
獨熱編碼為:000001,000010,000100,001000,010000,100000
怎麼理解這個事情呢?舉個例子,假如我們要進行人的身材的分析,但我們希望加入地域特徵,比如:“黑龍江”、“山東”、“湖南”、“廣東”這種特徵,但這種字串機器學習是不認識的,必須轉換為浮點數,剛才提到MapKeyToValue可以把字串轉換為數字,為什麼這裡要採用獨熱編碼呢?簡單來說,假設把地域名稱轉換為1到10幾個數字,在歐氏幾何中1到3的尤拉距離和1到9的尤拉距離是不等的,但經過獨熱編碼後,任意兩點間的尤拉距離都是相等的,而我們這裡的地域特徵僅僅是想表達分類關係,彼此之間沒有其他邏輯關係,所以應該採用獨熱編碼。
五、進度除錯
一般機器演算法的資料擬合過程時間都比較長,有時程式跑了兩個小時還沒結束,也不知道還需要多長時間,著實讓人著急,所以及時瞭解學習進度,是很有必要的。
由於機器學習演算法一般都有“遞迴直到收斂”這種操作,所以我們是沒有辦法預先知道最終運算次數的,能做到的只能列印一些過程資訊,看到程式在動,心裡也有點底,當系統跑過一次之後,基本就大致知道需要多少次擬合了,後面再除錯就可以大致瞭解進度了。補充一句,可不可以在測試階段先減少樣本資料進行快速除錯,除錯通過後再切換到全樣本進行訓練?其實不行,有時候樣本數量小,可能會引起指標震盪,時間反而長了。
之前在Githube上看到有人通過MLContext.LOG事件來列印除錯資訊,我試了一下,發現沒法控制篩選內容,不太方便,後來想到一個方法,就是新增一個自定義資料處理通道,這個通道不做具體事情,就列印除錯資訊。
類定義:
namespace MulticlassClassification_Mnist { public class DebugConversionInput { public float Serial { get; set; } } public class DebugConversionOutput { public float DebugFeature { get; set; } } [CustomMappingFactoryAttribute("DebugConversionAction")] public class DebugConversion : CustomMappingFactory<DebugConversionInput, DebugConversionOutput> { static long TotalCount = 0; public void CustomAction(DebugConversionInput input, DebugConversionOutput output) { output.DebugFeature = 1.0f;
TotalCount++; Console.WriteLine($"DebugConversion.CustomAction's debug info.TotalCount={TotalCount} "); } public override Action<DebugConversionInput, DebugConversionOutput> GetMapping() => CustomAction; } }
使用方法:
var dataProcessPipeline = mlContext.Transforms.CustomMapping(new DebugConversion().GetMapping(), contractName: "DebugConversionAction") .Append(...) .Append(mlContext.Transforms.Concatenate("Features", new string[] { "RealFeatures", "DebugFeature" }));
通過CustomMapping載入我們自定義的資料處理通道,由於資料集是懶載入(Lazy)的,所以必須把我們自定義資料處理通道的輸出加入為特徵值,才能參與運算,然後演算法在操作每一條資料時都會呼叫到CustomAction方法,這樣就可以列印進度資訊了。為了不影響運算結果,我們把這個資料處理通道的輸出值固定為1.0f 。
六、資源獲取
原始碼下載地址:https://github.com/seabluescn/Study_ML.NET
工程名稱:MulticlassClassification_Mnist
點選檢視機器學習框架ML.NET學習筆記系列文章