ML.NET 2 - 預測出租車價格
阿新 • 發佈：2018-12-02
1. 準備訓練與測試資料
2. 建立學習管線（載入資料並定義模型）
3. 訓練並評估模型
4. 預測
實現:
// TaxiFarePrediction.cs
using System;
using System.Collections.Generic;
using System.IO;
using System.Text;
using System.Threading.Tasks;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Models;
using Microsoft.ML.Trainers;
using Microsoft.ML.Transforms;

namespace _01_TaxiFare
{
    /// <summary>
    /// Trains a FastTree regression model on the taxi-fare training CSV, prints its
    /// metrics on the test CSV, and uses it to predict the fare of a single trip.
    /// </summary>
    public class TaxiFarePrediction
    {
        // All paths are resolved against the process's current working directory.
        static readonly string _datapath = Path.Combine(Environment.CurrentDirectory, "taxi-fare-train.csv");
        static readonly string _testdatapath = Path.Combine(Environment.CurrentDirectory, "taxi-fare-test.csv");
        static readonly string _modelpath = Path.Combine(Environment.CurrentDirectory, "Model.zip");

        /// <summary>
        /// Trains a fresh model, prints its evaluation metrics, and predicts the fare for <paramref name="tt"/>.
        /// </summary>
        /// <param name="tt">
        /// The trip to price. Its FareAmount is not part of the feature vector, so any value may be passed.
        /// </param>
        /// <returns>The prediction; its FareAmount holds the model's Score.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="tt"/> is null.</exception>
        /// <exception cref="FileNotFoundException">The training or test CSV file does not exist.</exception>
        /// <remarks>
        /// Every call retrains the model, rewrites Model.zip and re-runs evaluation.
        /// </remarks>
        public static async Task<TaxiTripFarePrediction> Predict(TaxiTrip tt)
        {
            // Validate everything up front: training is expensive, so fail before doing any work.
            if (tt == null)
            {
                throw new ArgumentNullException(nameof(tt));
            }
            EnsureFileExists(_datapath);
            EnsureFileExists(_testdatapath);

            var model = await Train();
            Evaluate(model);
            return model.Predict(tt);
        }

        /// <summary>
        /// Builds the learning pipeline, trains it on the training CSV and saves the result to Model.zip.
        /// </summary>
        /// <returns>The trained prediction model.</returns>
        private static async Task<PredictionModel<TaxiTrip, TaxiTripFarePrediction>> Train()
        {
            var pipeline = new LearningPipeline
            {
                // Read the comma-separated training file (first row is a header) as TaxiTrip rows.
                new TextLoader(_datapath).CreateFrom<TaxiTrip>(useHeader: true, separator: ','),
                // Expose the target value under "Label", the column ML.NET trainers read by default.
                new ColumnCopier(("FareAmount", "Label")),
                // One-hot encode the categorical string columns.
                new CategoricalOneHotVectorizer(
                    "VendorId", "RateCode", "PaymentType"),
                // Assemble the feature vector. TripTime is loaded but not included here.
                new ColumnConcatenator(
                    "Features",
                    "VendorId",
                    "RateCode",
                    "PassengerCount",
                    "TripDistance",
                    "PaymentType"),
                new FastTreeRegressor()
            };

            PredictionModel<TaxiTrip, TaxiTripFarePrediction> model = pipeline.Train<TaxiTrip, TaxiTripFarePrediction>();
            // Save the trained model to _modelpath.
            await model.WriteAsync(_modelpath);
            return model;
        }

        /// <summary>
        /// Scores <paramref name="model"/> against the test CSV and prints RMS and R-squared to the console.
        /// </summary>
        /// <param name="model">The trained model to evaluate.</param>
        private static void Evaluate(PredictionModel<TaxiTrip, TaxiTripFarePrediction> model)
        {
            var testData = new TextLoader(_testdatapath).CreateFrom<TaxiTrip>(useHeader: true, separator: ',');
            var evaluator = new RegressionEvaluator();
            RegressionMetrics metrics = evaluator.Evaluate(model, testData);
            Console.WriteLine($"Rms = {metrics.Rms}");
            Console.WriteLine($"RSquared = {metrics.RSquared}");
        }

        // Throws FileNotFoundException naming the path when a required data file is absent.
        private static void EnsureFileExists(string path)
        {
            if (!File.Exists(path))
            {
                throw new FileNotFoundException($"Data file not found: {path}", path);
            }
        }
    }
}
TaxiTrip.cs:
using System;
using System.Collections.Generic;
using System.Text;
using Microsoft.ML.Runtime.Api;

namespace _01_TaxiFare
{
    /// <summary>
    /// One row of the taxi-fare CSV files. Each Column attribute binds a field to
    /// the ordinal (0-6) of the corresponding column in the source file.
    /// </summary>
    public class TaxiTrip
    {
        /// <summary>Vendor code, e.g. "VTS"; one-hot encoded as a categorical feature.</summary>
        [Column("0")]
        public string VendorId;

        /// <summary>Rate code, e.g. "1"; one-hot encoded as a categorical feature.</summary>
        [Column("1")]
        public string RateCode;

        /// <summary>Number of passengers.</summary>
        [Column("2")]
        public float PassengerCount;

        /// <summary>Trip time; loaded from the file but not used as a model feature.</summary>
        [Column("3")]
        public float TripTime;

        /// <summary>Trip distance.</summary>
        [Column("4")]
        public float TripDistance;

        /// <summary>Payment type code, e.g. "CSH"; one-hot encoded as a categorical feature.</summary>
        [Column("5")]
        public string PaymentType;

        /// <summary>Fare paid; serves as the regression target during training.</summary>
        [Column("6")]
        public float FareAmount;
    }

    /// <summary>
    /// Model output row: the regressor's "Score" column surfaces as FareAmount.
    /// </summary>
    public class TaxiTripFarePrediction
    {
        [ColumnName("Score")]
        public float FareAmount;
    }
}
呼叫:
using System;

namespace _01_TaxiFare
{
    class Program
    {
        /// <summary>
        /// Trains the taxi-fare model and prints the predicted fare for one sample trip
        /// next to that trip's known fare.
        /// </summary>
        /// <param name="args">Unused.</param>
        static void Main(string[] args)
        {
            var trip = new TaxiTrip
            {
                VendorId = "VTS",
                RateCode = "1",
                PassengerCount = 1,
                TripDistance = 10.33f,
                PaymentType = "CSH",
                FareAmount = 0 // The value to predict; the actual fare for this trip is 29.5.
            };

            // Main is synchronous, so block on the task. GetAwaiter().GetResult() rethrows the
            // original exception rather than wrapping it in an AggregateException as .Result does.
            var prediction = TaxiFarePrediction.Predict(trip).GetAwaiter().GetResult();

            Console.WriteLine("Predicted fare: {0}, actual fare: 29.5", prediction.FareAmount);
            Console.ReadLine();
        }
    }
}