Guessing Cartoon Lines with ML.NET
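My previous articles were mostly based on Microsoft's official samples, which I adapted for my own purposes:
Sentiment Analysis with ML.NET [Beginner]
Predicting New York Taxi Fares with ML.NET
Having Fun with Machine Learning on .NET Core
Sentiment Analysis with ML.NET [Beginner]: Addendum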
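After reading them, you probably have a basic understanding of ML.NET. Because it is still at version 0.1 and Microsoft has not yet released many official samples, many people feel its features are not powerful enough, so it is understandable that they are waiting to see how it develops.
This article combines ML.NET with Azure's speech recognition service to show another fun use of ML.NET: guessing which cartoon a line comes from.
The scenario is easy to picture. It is a "you say it, I guess it" game: I first use ML.NET to train a classifier on lines from several cartoons. Then the user speaks a cartoon line into the microphone (it has to be a line that exists in the dataset, so no cheating!), and the program predicts which cartoon it comes from. Follow along and try it yourself.
Preparation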
準備工作
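This time we need a speech API from Azure Cognitive Services. The trial page below is listed as Speaker Recognition, but what the code actually uses is speech-to-text through the Speech SDK (Microsoft.CognitiveServices.Speech). The service is currently free to try. Open https://azure.microsoft.com/zh-cn/try/cognitive-services/?api=speaker-recognition to reach the trial page.
Click "Get API Key" and sign in with your Azure account; your keys will then be displayed.
Create the Project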
創建項目
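Note that this time we need to create a console application targeting .NET Framework 4.6.1 or later, and use NuGet to reference three libraries: Microsoft.ML, JiebaNet.Analyser, and Microsoft.CognitiveServices.Speech.
Then change the build platform to x64 instead of Any CPU. (This is very important.)
Code Walkthrough
In Main we only need to care about a few steps: first segment the text into words, then train the model, and finally wait in a loop for the user to speak and use the model to predict the cartoon.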
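Why x64 matters, as far as I understand it: a .NET Framework console project built as Any CPU has "Prefer 32-bit" enabled by default, so it runs as a 32-bit process, while the native components these libraries shipped at the time expected a 64-bit process. Below is a minimal sketch of a startup guard, which is an addition and not part of the original program, that fails fast with a readable message instead of an obscure native-load error. `PlatformGuard` is a hypothetical name; `Environment.Is64BitProcess` is the real .NET property.

```csharp
using System;

// Hypothetical startup guard (not in the original program).
public static class PlatformGuard
{
    // Pure helper so the decision can be tested without changing the real process.
    // Returns null when the platform is fine, otherwise an error message.
    public static string Check(bool is64BitProcess)
    {
        return is64BitProcess
            ? null
            : "This program must run as a 64-bit process. Set the build platform to x64 instead of Any CPU.";
    }

    // Wrapper around the real runtime flag; call it as the first line of Main.
    public static void EnsureX64()
    {
        string error = Check(Environment.Is64BitProcess);
        if (error != null)
        {
            throw new PlatformNotSupportedException(error);
        }
    }
}
```

Calling `PlatformGuard.EnsureX64()` at the start of Main makes a misconfigured build fail immediately with a clear explanation.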
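```csharp
static void Main(string[] args)
{
    // 1. Segment the raw lines with Jieba and write the training file.
    Segment(_dataPath, _dataTrainPath);

    // 2. Train the model and print its quality metrics.
    var model = Train();
    Evaluate(model);

    // 3. Recognize speech and predict until the user presses 0.
    ConsoleKeyInfo x;
    do
    {
        // Recognize returns null when recognition fails, so check before predicting.
        var sentence = Recognize().GetAwaiter().GetResult();
        if (!string.IsNullOrWhiteSpace(sentence))
        {
            Predict(model, sentence);
        }

        Console.WriteLine("\nRecognition done. Press 0 to stop, or any other key to continue: ");
        x = Console.ReadKey(true);
    } while (x.Key != ConsoleKey.D0);
}
```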
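The control flow of this loop can be sketched independently of the console and Azure. In the hypothetical `GameLoop.Run` helper below (not part of the original code), recognition, prediction and the "continue?" prompt are passed in as delegates, and a failed recognition (null) is skipped instead of being handed to prediction.

```csharp
using System;
using System.Threading.Tasks;

// Hypothetical helper: the recognize -> predict loop with its dependencies injected.
public static class GameLoop
{
    // recognize: returns the recognized sentence, or null when recognition failed.
    // predict:   handles a successfully recognized sentence.
    // keepGoing: asked after each round; returning false ends the loop.
    // Returns how many sentences were passed to predict.
    public static int Run(Func<Task<string>> recognize, Action<string> predict, Func<bool> keepGoing)
    {
        int predicted = 0;
        do
        {
            // Block on the async call, as a console Main on .NET Framework typically does.
            string sentence = recognize().GetAwaiter().GetResult();

            // Only predict when something was actually recognized.
            if (!string.IsNullOrWhiteSpace(sentence))
            {
                predict(sentence);
                predicted++;
            }
        } while (keepGoing());
        return predicted;
    }
}
```

In the real program, `recognize` would be `Recognize`, `predict` would be `s => Predict(model, s)`, and `keepGoing` would print the prompt and return `Console.ReadKey(true).Key != ConsoleKey.D0`.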
The variables to initialize are mainly the data file paths and the Azure speech key and region. Note that YourServiceRegion is the region name, "westus", not a URL.
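```csharp
const string SubscriptionKey = "your-subscription-key";
const string YourServiceRegion = "westus"; // region name, not a URL
const string _dataPath = @".\data\dubs.txt";             // raw lines: text<TAB>cartoon
const string _dataTrainPath = @".\data\dubs_result.txt"; // segmented lines used for training
```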
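Hard-coding the key is fine for a demo but easy to leak. A minimal sketch of an alternative, assuming you store the values in environment variables named SPEECH_KEY and SPEECH_REGION (names chosen for illustration only), also checks that the region is a short name such as "westus" rather than a URL. `SpeechSettings` is a hypothetical helper.

```csharp
using System;

// Hypothetical settings loader (not in the original program).
public static class SpeechSettings
{
    // Returns null if the region looks like a short region name, otherwise an error message.
    public static string ValidateRegion(string region)
    {
        if (string.IsNullOrWhiteSpace(region))
            return "Region is empty.";
        // Region names like "westus" contain neither a scheme nor dots.
        if (region.Contains("://") || region.Contains("."))
            return "'" + region + "' looks like a URL; use the short region name such as \"westus\".";
        return null;
    }

    // Reads the key and region from the (illustrative) environment variables.
    public static void Load(out string key, out string region)
    {
        key = Environment.GetEnvironmentVariable("SPEECH_KEY");
        region = Environment.GetEnvironmentVariable("SPEECH_REGION") ?? "westus";

        if (string.IsNullOrWhiteSpace(key))
            throw new InvalidOperationException("Set the SPEECH_KEY environment variable to your subscription key.");

        string error = ValidateRegion(region);
        if (error != null)
            throw new InvalidOperationException(error);
    }
}
```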
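The input and prediction classes are the same as in my earlier articles; there is nothing special about them. Column 0 of the training file is the segmented line (DubbingText) and column 1 is the cartoon name (Label).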
```csharp
public class DubbingData
{
    [Column(ordinal: "0")]
    public string DubbingText;

    [Column(ordinal: "1", name: "Label")]
    public string Label;
}

public class DubbingPrediction
{
    [ColumnName("PredictedLabel")]
    public string PredictedLabel;
}
```
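In the word-segmentation step, pay attention to the separator: each input line has the form "line text<TAB>cartoon name", and lines that do not split into exactly two parts are skipped.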
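```csharp
public static void Segment(string source, string result)
{
    var segmenter = new JiebaSegmenter();
    using (var reader = new StreamReader(source))
    using (var writer = new StreamWriter(result))
    {
        string line;
        // Read to the end of the file; skip blank lines instead of stopping at the first one.
        while ((line = reader.ReadLine()) != null)
        {
            if (string.IsNullOrWhiteSpace(line)) continue;

            // Each line is "<line text>\t<cartoon name>"; skip malformed lines.
            var parts = line.Split(new[] { '\t' }, StringSplitOptions.RemoveEmptyEntries);
            if (parts.Length != 2) continue;

            // Jieba splits the Chinese sentence into words; joining them with spaces
            // lets the text featurizer treat each word as a token.
            var segments = segmenter.Cut(parts[0]);
            writer.WriteLine("{0}\t{1}", string.Join(" ", segments), parts[1]);
        }
    }
}
```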
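To show the separator handling on its own, here is a sketch of the same preprocessing with the word segmenter passed in as a delegate, so it runs without Jieba. `Preprocessor`, `TryParse` and the `tokenize` parameter are hypothetical; in the real program the delegate would wrap `JiebaSegmenter.Cut`. It reads until the end of input and only skips blank or malformed lines.

```csharp
using System;
using System.Collections.Generic;
using System.IO;

// Hypothetical, Jieba-free version of the segmentation step.
public static class Preprocessor
{
    // Splits "text<TAB>label"; returns false for blank or malformed lines.
    public static bool TryParse(string line, out string text, out string label)
    {
        text = null;
        label = null;
        if (string.IsNullOrWhiteSpace(line)) return false;

        string[] parts = line.Split(new[] { '\t' }, StringSplitOptions.RemoveEmptyEntries);
        if (parts.Length != 2) return false;

        text = parts[0];
        label = parts[1];
        return true;
    }

    // Writes "token token ...<TAB>label" for every valid line; returns how many were written.
    public static int Segment(TextReader reader, TextWriter writer, Func<string, IEnumerable<string>> tokenize)
    {
        int written = 0;
        string line;
        while ((line = reader.ReadLine()) != null)
        {
            string text, label;
            if (!TryParse(line, out text, out label)) continue;

            writer.WriteLine("{0}\t{1}", string.Join(" ", tokenize(text)), label);
            written++;
        }
        return written;
    }
}
```

With a blank line in the middle of the file, this version keeps going, whereas a loop that breaks on the first blank line would silently drop every line after it.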
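Training again uses the familiar multiclass trainer StochasticDualCoordinateAscentClassifier. TextFeaturizer converts the text into a numeric feature vector, Dictionarizer turns the string Label into the key type the trainer needs, and PredictedLabelColumnOriginalValueConverter maps the predicted key back to the original cartoon name.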
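```csharp
public static PredictionModel<DubbingData, DubbingPrediction> Train()
{
    var pipeline = new LearningPipeline();
    // Load the tab-separated, already segmented training file (no header row).
    pipeline.Add(new TextLoader<DubbingData>(_dataTrainPath, useHeader: false, separator: "tab"));
    // Turn the segmented text into the numeric "Features" column.
    pipeline.Add(new TextFeaturizer("Features", "DubbingText"));
    // Map the string label (cartoon name) to a key.
    pipeline.Add(new Dictionarizer("Label"));
    // Multiclass classifier.
    pipeline.Add(new StochasticDualCoordinateAscentClassifier());
    // Convert the predicted key back into the original cartoon name.
    pipeline.Add(new PredictedLabelColumnOriginalValueConverter() { PredictedLabelColumn = "PredictedLabel" });

    var model = pipeline.Train<DubbingData, DubbingPrediction>();
    return model;
}
```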
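To give an intuition for what the featurizing and label steps produce, here is a deliberately simplified sketch: a bag-of-words count vector over the space-separated tokens that segmentation creates, plus a label-to-index map in the spirit of Dictionarizer. `ToyFeaturizer` is hypothetical and is not how ML.NET implements these transforms; TextFeaturizer, for example, also normalizes the text and uses n-grams.

```csharp
using System;
using System.Collections.Generic;

// Simplified illustration only; not ML.NET's implementation.
public sealed class ToyFeaturizer
{
    private readonly Dictionary<string, int> _vocabulary = new Dictionary<string, int>();
    private readonly Dictionary<string, int> _labelToIndex = new Dictionary<string, int>();
    private readonly List<string> _indexToLabel = new List<string>();

    public int VocabularySize => _vocabulary.Count;

    // Learn the token vocabulary and label set from (segmented text, label) pairs.
    public void Fit(IEnumerable<KeyValuePair<string, string>> examples)
    {
        foreach (var example in examples)
        {
            foreach (string token in example.Key.Split(new[] { ' ' }, StringSplitOptions.RemoveEmptyEntries))
            {
                if (!_vocabulary.ContainsKey(token))
                    _vocabulary[token] = _vocabulary.Count; // next free index
            }
            if (!_labelToIndex.ContainsKey(example.Value))
            {
                _labelToIndex[example.Value] = _indexToLabel.Count;
                _indexToLabel.Add(example.Value);
            }
        }
    }

    // Count occurrences of each known token; unknown tokens are ignored.
    public float[] Transform(string segmentedText)
    {
        var features = new float[_vocabulary.Count];
        foreach (string token in segmentedText.Split(new[] { ' ' }, StringSplitOptions.RemoveEmptyEntries))
        {
            if (_vocabulary.TryGetValue(token, out int index))
                features[index] += 1f;
        }
        return features;
    }

    // Dictionarizer-style mapping: label text -> numeric key, and back again.
    public int LabelToIndex(string label) => _labelToIndex[label];
    public string IndexToLabel(int index) => _indexToLabel[index];
}
```

This is also why segmentation matters for Chinese: without the spaces Jieba inserts, a whitespace tokenizer would see each sentence as a single token.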
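This time the evaluation focuses on the log-loss score. Note that it evaluates on the same file used for training, so the number is optimistic; a separate held-out file would give a fairer estimate.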
```csharp
public static void Evaluate(PredictionModel<DubbingData, DubbingPrediction> model)
{
    // Evaluates on the training file itself, so the result is optimistic.
    var testData = new TextLoader<DubbingData>(_dataTrainPath, useHeader: false, separator: "tab");
    var evaluator = new ClassificationEvaluator();
    var metrics = evaluator.Evaluate(model, testData);

    Console.WriteLine();
    Console.WriteLine("PredictionModel quality metrics evaluation");
    Console.WriteLine("------------------------------------------");
    // Log loss is a non-negative number (lower is better), not a percentage.
    Console.WriteLine($"LogLoss: {metrics.LogLoss:F4}");
}
```
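As a reference for reading the LogLoss number, the sketch below computes multiclass log loss by hand: the average of -ln(p) over all examples, where p is the probability the model assigned to the correct class. `Metrics.LogLoss` is a hypothetical helper and the clamping epsilon is an assumption; ML.NET's evaluator computes the metric itself.

```csharp
using System;
using System.Collections.Generic;

public static class Metrics
{
    // probabilities[i][k]: predicted probability of class k for example i.
    // trueLabels[i]: index of the correct class for example i.
    public static double LogLoss(IReadOnlyList<double[]> probabilities, IReadOnlyList<int> trueLabels)
    {
        if (probabilities.Count == 0 || probabilities.Count != trueLabels.Count)
            throw new ArgumentException("Need one non-empty probability row per label.");

        const double epsilon = 1e-15; // clamp so ln(0) does not produce infinity
        double total = 0;
        for (int i = 0; i < probabilities.Count; i++)
        {
            double p = Math.Max(probabilities[i][trueLabels[i]], epsilon);
            total += -Math.Log(p);
        }
        return total / probabilities.Count;
    }
}
```

For two examples whose correct classes received probabilities 0.5 and 0.25, the loss is (ln 2 + ln 4) / 2 ≈ 1.04. A uniform guess over k classes scores ln k, which is a useful baseline when judging the printed value.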
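The prediction step is essentially unchanged. The input is segmented the same way as the training data, and the output shows the original Chinese sentence (without the segmentation spaces) next to the predicted cartoon.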
```csharp
public static void Predict(PredictionModel<DubbingData, DubbingPrediction> model, string sentence)
{
    // Segment the recognized sentence exactly as the training data was segmented.
    var segmenter = new JiebaSegmenter();
    var inputs = new[]
    {
        new DubbingData { DubbingText = string.Join(" ", segmenter.Cut(sentence)) }
    };

    IEnumerable<DubbingPrediction> predictions = model.Predict(inputs);

    Console.WriteLine();
    Console.WriteLine("Category Predictions");
    Console.WriteLine("---------------------");
    foreach (var prediction in predictions)
    {
        // Show the original sentence rather than the space-separated tokens.
        Console.WriteLine($"Line: {sentence} | From cartoon: {prediction.PredictedLabel}");
    }
    Console.WriteLine();
}
```
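If you also want to show how confident the model is, one option (an extension, not something the original program does) is to pick the best entry of a per-class score vector and print its score. This sketch assumes you already have the scores and the class names in matching order; obtaining them from ML.NET, for example through a Score column on the prediction class, is not shown here and is an assumption. `PredictionFormatter` is hypothetical.

```csharp
using System;
using System.Collections.Generic;
using System.Globalization;

public static class PredictionFormatter
{
    // Index of the highest score; on ties the first one wins.
    public static int ArgMax(float[] scores)
    {
        if (scores == null || scores.Length == 0)
            throw new ArgumentException("scores must not be empty");

        int best = 0;
        for (int i = 1; i < scores.Length; i++)
        {
            if (scores[i] > scores[best]) best = i;
        }
        return best;
    }

    // labels[k] is assumed to be the class name for scores[k].
    public static string Format(string sentence, float[] scores, IReadOnlyList<string> labels)
    {
        int best = ArgMax(scores);
        string score = scores[best].ToString("F2", CultureInfo.InvariantCulture);
        return $"Line: {sentence} | From cartoon: {labels[best]} (score {score})";
    }
}
```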
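The Azure speech recognition call looks like this. It uses the API of the preview (0.x) Speech SDK; later 1.x releases replaced SpeechFactory with SpeechConfig and SpeechRecognizer, so newer package versions need different code.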
```csharp
static async Task<string> Recognize()
{
    var factory = SpeechFactory.FromSubscription(SubscriptionKey, YourServiceRegion);
    var lang = "zh-cn"; // recognize Mandarin Chinese
    using (var recognizer = factory.CreateSpeechRecognizer(lang))
    {
        Console.WriteLine("Say something...");
        var result = await recognizer.RecognizeAsync().ConfigureAwait(false);

        if (result.RecognitionStatus != RecognitionStatus.Recognized)
        {
            Console.WriteLine($"There was an error. Status:{result.RecognitionStatus}, Reason:{result.RecognitionFailureReason}");
            return null;
        }

        Console.WriteLine($"We recognized: {result.RecognizedText}");
        return result.RecognizedText;
    }
}
```
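A single recognition attempt can easily come back empty (silence, background noise). A small retry wrapper can help; this is a hypothetical helper, not part of the original program or the Speech SDK. It calls any single-shot recognizer that returns null on failure, such as Recognize, up to a fixed number of times and returns the first non-empty result.

```csharp
using System;
using System.Threading.Tasks;

public static class RecognitionRetry
{
    // Returns the first non-empty recognition result, or null if every attempt failed.
    public static async Task<string> RecognizeWithRetryAsync(Func<Task<string>> recognizeOnce, int maxAttempts)
    {
        if (maxAttempts < 1) throw new ArgumentOutOfRangeException(nameof(maxAttempts));

        for (int attempt = 1; attempt <= maxAttempts; attempt++)
        {
            string text = await recognizeOnce().ConfigureAwait(false);
            if (!string.IsNullOrWhiteSpace(text))
                return text;

            Console.WriteLine($"Nothing recognized (attempt {attempt} of {maxAttempts}), please try again...");
        }
        return null;
    }
}
```

In Main you could then call `RecognitionRetry.RecognizeWithRetryAsync(Recognize, 3)` in place of `Recognize()`.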
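Running it looks like this:
It may look a bit childish, but it still made you smile, didn't it? Stay tuned for more fun examples.
Dataset used in this article: Download
The complete code: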
```csharp
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Threading.Tasks;
using JiebaNet.Segmenter;
using Microsoft.CognitiveServices.Speech;
using Microsoft.ML;
using Microsoft.ML.Models;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Api;
using Microsoft.ML.Trainers;
using Microsoft.ML.Transforms;

namespace DubbingRecognition
{
    class Program
    {
        public class DubbingData
        {
            [Column(ordinal: "0")]
            public string DubbingText;

            [Column(ordinal: "1", name: "Label")]
            public string Label;
        }

        public class DubbingPrediction
        {
            [ColumnName("PredictedLabel")]
            public string PredictedLabel;
        }

        const string SubscriptionKey = "your-subscription-key";
        const string YourServiceRegion = "westus"; // region name, not a URL
        const string _dataPath = @".\data\dubs.txt";
        const string _dataTrainPath = @".\data\dubs_result.txt";

        static void Main(string[] args)
        {
            Segment(_dataPath, _dataTrainPath);
            var model = Train();
            Evaluate(model);

            ConsoleKeyInfo x;
            do
            {
                // Recognize returns null on failure; only predict on real text.
                var sentence = Recognize().GetAwaiter().GetResult();
                if (!string.IsNullOrWhiteSpace(sentence))
                {
                    Predict(model, sentence);
                }

                Console.WriteLine("\nRecognition done. Press 0 to stop, or any other key to continue: ");
                x = Console.ReadKey(true);
            } while (x.Key != ConsoleKey.D0);
        }

        public static void Segment(string source, string result)
        {
            var segmenter = new JiebaSegmenter();
            using (var reader = new StreamReader(source))
            using (var writer = new StreamWriter(result))
            {
                string line;
                // Read to the end of the file; skip blank and malformed lines.
                while ((line = reader.ReadLine()) != null)
                {
                    if (string.IsNullOrWhiteSpace(line)) continue;
                    var parts = line.Split(new[] { '\t' }, StringSplitOptions.RemoveEmptyEntries);
                    if (parts.Length != 2) continue;
                    var segments = segmenter.Cut(parts[0]);
                    writer.WriteLine("{0}\t{1}", string.Join(" ", segments), parts[1]);
                }
            }
        }

        public static PredictionModel<DubbingData, DubbingPrediction> Train()
        {
            var pipeline = new LearningPipeline();
            pipeline.Add(new TextLoader<DubbingData>(_dataTrainPath, useHeader: false, separator: "tab"));
            pipeline.Add(new TextFeaturizer("Features", "DubbingText"));
            pipeline.Add(new Dictionarizer("Label"));
            pipeline.Add(new StochasticDualCoordinateAscentClassifier());
            pipeline.Add(new PredictedLabelColumnOriginalValueConverter() { PredictedLabelColumn = "PredictedLabel" });
            var model = pipeline.Train<DubbingData, DubbingPrediction>();
            return model;
        }

        public static void Evaluate(PredictionModel<DubbingData, DubbingPrediction> model)
        {
            // Evaluates on the training file itself, so the result is optimistic.
            var testData = new TextLoader<DubbingData>(_dataTrainPath, useHeader: false, separator: "tab");
            var evaluator = new ClassificationEvaluator();
            var metrics = evaluator.Evaluate(model, testData);
            Console.WriteLine();
            Console.WriteLine("PredictionModel quality metrics evaluation");
            Console.WriteLine("------------------------------------------");
            Console.WriteLine($"LogLoss: {metrics.LogLoss:F4}"); // not a percentage
        }

        public static void Predict(PredictionModel<DubbingData, DubbingPrediction> model, string sentence)
        {
            // Segment the input the same way as the training data.
            var segmenter = new JiebaSegmenter();
            var inputs = new[]
            {
                new DubbingData { DubbingText = string.Join(" ", segmenter.Cut(sentence)) }
            };

            IEnumerable<DubbingPrediction> predictions = model.Predict(inputs);

            Console.WriteLine();
            Console.WriteLine("Category Predictions");
            Console.WriteLine("---------------------");
            foreach (var prediction in predictions)
            {
                Console.WriteLine($"Line: {sentence} | From cartoon: {prediction.PredictedLabel}");
            }
            Console.WriteLine();
        }

        static async Task<string> Recognize()
        {
            var factory = SpeechFactory.FromSubscription(SubscriptionKey, YourServiceRegion);
            var lang = "zh-cn";
            using (var recognizer = factory.CreateSpeechRecognizer(lang))
            {
                Console.WriteLine("Say something...");
                var result = await recognizer.RecognizeAsync().ConfigureAwait(false);
                if (result.RecognitionStatus != RecognitionStatus.Recognized)
                {
                    Console.WriteLine($"There was an error. Status:{result.RecognitionStatus}, Reason:{result.RecognitionFailureReason}");
                    return null;
                }

                Console.WriteLine($"We recognized: {result.RecognizedText}");
                return result.RecognizedText;
            }
        }
    }
}
```