1. 程式人生 > >使用ML.NET實現猜動畫片臺詞

使用ML.NET實現猜動畫片臺詞

conf num 分數 ict 版本 source style ogl post

前面幾篇主要內容出自微軟官方,經我特意修改的案例的文章:

使用ML.NET實現情感分析[新手篇]

使用ML.NET預測紐約出租車費

.NET Core玩轉機器學習

使用ML.NET實現情感分析[新手篇]後補

相信看過後大家對ML.NET有了一定的了解了,由於目前還是0.1的版本,也沒有更多官方示例放出來,大家普遍覺得提供的特性還不夠強大,所以處在觀望狀態也是能理解的。

本文結合Azure提供的語音識別服務,向大家展示另一種ML.NET有趣的玩法——猜動畫片臺詞。

這個場景特別容易想像,是一種你說我猜的遊戲,我會事先用ML.NET對若幹動畫片的臺詞進行分類學習,然後使用麥克風,讓使用者隨便說一句動畫片的臺詞(當然得是數據集中已存在的,沒有的不要搞事情呀!),然後來預測出自哪一部。跟隨我動手做做看。

準備工作


這次需要使用Azure的認知服務中一項API——Speaker Recognition,目前還處於免費試用階段,打開https://azure.microsoft.com/zh-cn/try/cognitive-services/?api=speaker-recognition,能看到如下頁面:

技術分享圖片

點擊獲取API密鑰,用自己的Azure賬號登錄,然後就能看到自己的密鑰了,類似如下圖:

技術分享圖片

創建項目


這一次請註意,我們要創建一個.NET Framework 4.6.1或以上版本的控制臺應用程序,通過NuGet分別引用三個類庫:Microsoft.ML,JiebaNet.Analyser,Microsoft.CognitiveServices.Speech。

然後把編譯平臺修改成x64,而不是Any CPU。(這一點非常重要)

代碼分解


在Main函數部分,我們只需要關心幾個主要步驟,先切詞,然後訓練模型,最後在一個循環中等待使用者說話,用模型進行預測。

static void Main(string[] args)
{
    Segment(_dataPath, _dataTrainPath);
    var model = Train();
    Evaluate(model);
    ConsoleKeyInfo x;
    do
    {
        var speech = Recognize();
        speech.Wait();
        Predict(model, speech.Result);
        Console.WriteLine(
"\nRecognition done. Your Choice (0: Stop Any key to continue): "); x = Console.ReadKey(true); } while (x.Key != ConsoleKey.D0); }

初始化的變量主要就是訓練數據,Azure語音識別密鑰等。註意YourServiceRegion的值是“westus”,而不是網址。

const string SubscriptionKey = "你的密鑰";
const string YourServiceRegion = "westus";
const string _dataPath = @".\data\dubs.txt";
const string _dataTrainPath = @".\data\dubs_result.txt";

定義數據結構和預測結構和我之前的文章一樣,沒有什麽特別之處。

public class DubbingData
{
    [Column(ordinal: "0")]
    public string DubbingText;
    [Column(ordinal: "1", name: "Label")]
    public string Label;
}

public class DubbingPrediction
{
    [ColumnName("PredictedLabel")]
    public string PredictedLabel;
}

切記部分註意對分隔符的過濾。

public static void Segment(string source, string result)
{
    var segmenter = new JiebaSegmenter();
    using (var reader = new StreamReader(source))
    {
        using (var writer = new StreamWriter(result))
        {
            while (true)
            {
                var line = reader.ReadLine();
                if (string.IsNullOrWhiteSpace(line))
                    break;
                var parts = line.Split(new[] { \t }, StringSplitOptions.RemoveEmptyEntries);
                if (parts.Length != 2) continue;
                var segments = segmenter.Cut(parts[0]);
                writer.WriteLine("{0}\t{1}", string.Join(" ", segments), parts[1]);
            }
        }
    }
}

訓練部分依然使用熟悉的多分類訓練器StochasticDualCoordinateAscentClassifier。TextFeaturizer用於對文本內容向量化處理。

public static PredictionModel<DubbingData, DubbingPrediction> Train()
{
    var pipeline = new LearningPipeline();
    pipeline.Add(new TextLoader<DubbingData>(_dataTrainPath, useHeader: false, separator: "tab"));
    pipeline.Add(new TextFeaturizer("Features", "DubbingText"));
    pipeline.Add(new Dictionarizer("Label"));
    pipeline.Add(new StochasticDualCoordinateAscentClassifier());
    pipeline.Add(new PredictedLabelColumnOriginalValueConverter() { PredictedLabelColumn = "PredictedLabel" });
    var model = pipeline.Train<DubbingData, DubbingPrediction>();
    return model;
}

驗證部分這次重點是看損失程度分數。

public static void Evaluate(PredictionModel<DubbingData, DubbingPrediction> model)
{
    var testData = new TextLoader<DubbingData>(_dataTrainPath, useHeader: false, separator: "tab");
    var evaluator = new ClassificationEvaluator();
    var metrics = evaluator.Evaluate(model, testData);
    Console.WriteLine();
    Console.WriteLine("PredictionModel quality metrics evaluation");
    Console.WriteLine("------------------------------------------");
    //Console.WriteLine($"TopKAccuracy: {metrics.TopKAccuracy:P2}");
    Console.WriteLine($"LogLoss: {metrics.LogLoss:P2}");
}

預測部分沒有什麽大變化,就是對中文交互進行了友好展示。

public static void Predict(PredictionModel<DubbingData, DubbingPrediction> model, string sentence)
{
    IEnumerable<DubbingData> sentences = new[]
    {
        new DubbingData
        {
            DubbingText = sentence
        }
    };

    var segmenter = new JiebaSegmenter();
    foreach (var item in sentences)
    {
        item.DubbingText = string.Join(" ", segmenter.Cut(item.DubbingText));
    }

    IEnumerable<DubbingPrediction> predictions = model.Predict(sentences);
    Console.WriteLine();
    Console.WriteLine("Category Predictions");
    Console.WriteLine("---------------------");

    var sentencesAndPredictions = sentences.Zip(predictions, (sentiment, prediction) => (sentiment, prediction));
    foreach (var item in sentencesAndPredictions)
    {
        Console.WriteLine($"臺詞: {item.sentiment.DubbingText.Replace(" ", string.Empty)} | 來自動畫片: {item.prediction.PredictedLabel}");
    }
    Console.WriteLine();
}

Azure語音識別的調用如下。

static async Task<string> Recognize()
{
    var factory = SpeechFactory.FromSubscription(SubscriptionKey, YourServiceRegion);
    var lang = "zh-cn";

    using (var recognizer = factory.CreateSpeechRecognizer(lang))
    {
        Console.WriteLine("Say something...");

        var result = await recognizer.RecognizeAsync().ConfigureAwait(false);

        if (result.RecognitionStatus != RecognitionStatus.Recognized)
        {
            Console.WriteLine($"There was an error. Status:{result.RecognitionStatus.ToString()}, Reason:{result.RecognitionFailureReason}");
            return null;
        }
        else
        {
            Console.WriteLine($"We recognized: {result.RecognizedText}");
            return result.RecognizedText;
        }
    }
}

運行過程如下:

技術分享圖片

雖然這看上去有點幼稚,不過一樣讓你開心一笑了,不是麽?請期待更多有趣的案例。

本文使用的數據集:下載

完整的代碼如下:

using System;
using Microsoft.ML.Models;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Api;
using Microsoft.ML.Trainers;
using Microsoft.ML.Transforms;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML;
using JiebaNet.Segmenter;
using System.IO;
using Microsoft.CognitiveServices.Speech;
using System.Threading.Tasks;

namespace DubbingRecognition
{
    class Program
    {
        public class DubbingData
        {
            [Column(ordinal: "0")]
            public string DubbingText;
            [Column(ordinal: "1", name: "Label")]
            public string Label;
        }

        public class DubbingPrediction
        {
            [ColumnName("PredictedLabel")]
            public string PredictedLabel;
        }

        const string SubscriptionKey = "你的密鑰";
        const string YourServiceRegion = "westus";
        const string _dataPath = @".\data\dubs.txt";
        const string _dataTrainPath = @".\data\dubs_result.txt";


        static void Main(string[] args)
        {
            Segment(_dataPath, _dataTrainPath);
            var model = Train();
            Evaluate(model);
            ConsoleKeyInfo x;
            do
            {
                var speech = Recognize();
                speech.Wait();
                Predict(model, speech.Result);
                Console.WriteLine("\nRecognition done. Your Choice (0: Stop Any key to continue): ");
                x = Console.ReadKey(true);
            } while (x.Key != ConsoleKey.D0);
        }

        public static void Segment(string source, string result)
        {
            var segmenter = new JiebaSegmenter();
            using (var reader = new StreamReader(source))
            {
                using (var writer = new StreamWriter(result))
                {
                    while (true)
                    {
                        var line = reader.ReadLine();
                        if (string.IsNullOrWhiteSpace(line))
                            break;
                        var parts = line.Split(new[] { \t }, StringSplitOptions.RemoveEmptyEntries);
                        if (parts.Length != 2) continue;
                        var segments = segmenter.Cut(parts[0]);
                        writer.WriteLine("{0}\t{1}", string.Join(" ", segments), parts[1]);
                    }
                }
            }
        }

        public static PredictionModel<DubbingData, DubbingPrediction> Train()
        {
            var pipeline = new LearningPipeline();
            pipeline.Add(new TextLoader<DubbingData>(_dataTrainPath, useHeader: false, separator: "tab"));

            //pipeline.Add(new ColumnConcatenator("Features", "DubbingText"));

            pipeline.Add(new TextFeaturizer("Features", "DubbingText"));
            //pipeline.Add(new TextFeaturizer("Label", "Category"));
            pipeline.Add(new Dictionarizer("Label"));
            pipeline.Add(new StochasticDualCoordinateAscentClassifier());
            pipeline.Add(new PredictedLabelColumnOriginalValueConverter() { PredictedLabelColumn = "PredictedLabel" });
            var model = pipeline.Train<DubbingData, DubbingPrediction>();
            return model;
        }

        public static void Evaluate(PredictionModel<DubbingData, DubbingPrediction> model)
        {
            var testData = new TextLoader<DubbingData>(_dataTrainPath, useHeader: false, separator: "tab");
            var evaluator = new ClassificationEvaluator();
            var metrics = evaluator.Evaluate(model, testData);
            Console.WriteLine();
            Console.WriteLine("PredictionModel quality metrics evaluation");
            Console.WriteLine("------------------------------------------");
            //Console.WriteLine($"TopKAccuracy: {metrics.TopKAccuracy:P2}");
            Console.WriteLine($"LogLoss: {metrics.LogLoss:P2}");
        }

        public static void Predict(PredictionModel<DubbingData, DubbingPrediction> model, string sentence)
        {
            IEnumerable<DubbingData> sentences = new[]
            {
                new DubbingData
                {
                    DubbingText = sentence
                }
            };

            var segmenter = new JiebaSegmenter();
            foreach (var item in sentences)
            {
                item.DubbingText = string.Join(" ", segmenter.Cut(item.DubbingText));
            }

            IEnumerable<DubbingPrediction> predictions = model.Predict(sentences);
            Console.WriteLine();
            Console.WriteLine("Category Predictions");
            Console.WriteLine("---------------------");

            var sentencesAndPredictions = sentences.Zip(predictions, (sentiment, prediction) => (sentiment, prediction));
            foreach (var item in sentencesAndPredictions)
            {
                Console.WriteLine($"臺詞: {item.sentiment.DubbingText.Replace(" ", string.Empty)} | 來自動畫片: {item.prediction.PredictedLabel}");
            }
            Console.WriteLine();
        }
        static async Task<string> Recognize()
        {
            var factory = SpeechFactory.FromSubscription(SubscriptionKey, YourServiceRegion);
            var lang = "zh-cn";

            using (var recognizer = factory.CreateSpeechRecognizer(lang))
            {
                Console.WriteLine("Say something...");

                var result = await recognizer.RecognizeAsync().ConfigureAwait(false);

                if (result.RecognitionStatus != RecognitionStatus.Recognized)
                {
                    Console.WriteLine($"There was an error. Status:{result.RecognitionStatus.ToString()}, Reason:{result.RecognitionFailureReason}");
                    return null;
                }
                else
                {
                    Console.WriteLine($"We recognized: {result.RecognizedText}");
                    return result.RecognizedText;
                }
            }
        }
    }
}

使用ML.NET實現猜動畫片臺詞