Model Builder Multiclass Classification sample for Docs tutorial (#697)
* Initial commit * Initial Commit * Updated solution name and README * Removing irrelevant files * Renaming top-level folder * Working * Adding condensed dataset * Updated to use GH hosted dataset * Updated README and console output * Fixed project names in README * Updated data source * Back to working state with updates to use DB * Updated README to emphasize CSharp * Added C# * Updated based n feedback
This commit is contained in:
Родитель
f857f0c323
Коммит
cbacd6022a
|
@ -0,0 +1,40 @@
|
|||
# Restaurant Violation Inspections
|
||||
|
||||
|
||||
| ML.NET version | Status | App Type | Data type | Scenario | ML Task | Algorithms |
|
||||
|----------------|-------------------------------|-------------|-----------|---------------------|---------------------------|-----------------------------|
|
||||
| v1.3.1 | Up-to-date | Console App | Single data sample | Issue Classification | Multiclass classification | Linear Classification |
|
||||
|
||||
## Goal
|
||||
|
||||
Create a C# .NET Core Console application that uses an ML.NET multiclass classification machine learning model trained using Model Builder to categorize the risk level of restaurant violations found during health inspections.
|
||||
|
||||
![](./images/console.PNG)
|
||||
|
||||
## Application
|
||||
|
||||
- RestaurantViolations: A C# .NET Core Console application that uses a multiclass classification model to assign risk to violations encountered during restaurant inspections.
|
||||
- RestaurantViolationsML.ConsoleApp: A .NET Core Console application that contains the model training and test prediction code.
|
||||
- RestaurantViolationsML.Model: A .NET Standard class library containing the data models that define the schema of input and output model data as well as the persisted version of the best performing model during training.
|
||||
|
||||
## The data
|
||||
|
||||
> The data set used to train and evaluate the machine learning model is originally from the [San Francisco Department of Public Health Restaurant Safety Scores](https://www.sfdph.org/dph/EH/Food/score/default.asp). For convenience, the dataset has been condensed to only include the columns relevant to train the model and make predictions. Visit the following website to learn more about the [dataset](https://data.sfgov.org/Health-and-Social-Services/Restaurant-Scores-LIVES-Standard/pyih-qa8i?row_index=0).
|
||||
|
||||
| inspection_type | violation_description | risk_category |
|
||||
| --- | --- | --- |
|
||||
| Routine - Unscheduled | Inadequately cleaned or sanitized food contact surfaces | Moderate Risk |
|
||||
| New Ownership | High risk vermin infestation | High Risk |
|
||||
| Routine - Unscheduled | Wiping cloths not clean or properly stored or inadequate sanitizer | Low Risk |
|
||||
|
||||
## The model
|
||||
|
||||
The goal of the application is to predict whether an inspection violation belongs to one of several categories (low/moderate/high risk). The Machine Learning Task to use in this scenario is multiclass classification. The model in this application was trained using Model Builder.
|
||||
|
||||
[Model Builder](https://marketplace.visualstudio.com/items?itemName=MLNET.07) is an intuitive graphical Visual Studio extension to build, train, and deploy custom machine learning models.
|
||||
|
||||
Model Builder uses automated machine learning (AutoML) to explore different machine learning algorithms and settings to help you find the one that best suits your scenario.
|
||||
|
||||
You don't need machine learning expertise to use Model Builder. All you need is some data, and a problem to solve. Model Builder generates the code to add the model to your .NET application.
|
||||
|
||||
In this solution, both the `RestaurantViolationsML.ConsoleApp` and `RestaurantViolationsML.Model` projects are autogenerated by Model Builder.
|
Двоичные данные
samples/modelbuilder/MulticlassClassification_RestaurantViolations/RestaurantScores.zip
Normal file
Двоичные данные
samples/modelbuilder/MulticlassClassification_RestaurantViolations/RestaurantScores.zip
Normal file
Двоичный файл не отображается.
|
@ -0,0 +1,37 @@
|
|||
|
||||
Microsoft Visual Studio Solution File, Format Version 12.00
|
||||
# Visual Studio Version 16
|
||||
VisualStudioVersion = 16.0.29403.142
|
||||
MinimumVisualStudioVersion = 10.0.40219.1
|
||||
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "RestaurantViolations", "RestaurantViolations\RestaurantViolations.csproj", "{FD0B76FD-773C-4DE0-BD65-EC60ECAA7503}"
|
||||
EndProject
|
||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "RestaurantViolationsML.Model", "RestaurantViolationsML.Model\RestaurantViolationsML.Model.csproj", "{016977A3-1164-42E2-9EFC-D793E3B05EAB}"
|
||||
EndProject
|
||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "RestaurantViolationsML.ConsoleApp", "RestaurantViolationsML.ConsoleApp\RestaurantViolationsML.ConsoleApp.csproj", "{27E67B4C-1941-4AD4-8CAC-EAC8DB6048B0}"
|
||||
EndProject
|
||||
Global
|
||||
GlobalSection(SolutionConfigurationPlatforms) = preSolution
|
||||
Debug|Any CPU = Debug|Any CPU
|
||||
Release|Any CPU = Release|Any CPU
|
||||
EndGlobalSection
|
||||
GlobalSection(ProjectConfigurationPlatforms) = postSolution
|
||||
{FD0B76FD-773C-4DE0-BD65-EC60ECAA7503}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
|
||||
{FD0B76FD-773C-4DE0-BD65-EC60ECAA7503}.Debug|Any CPU.Build.0 = Debug|Any CPU
|
||||
{FD0B76FD-773C-4DE0-BD65-EC60ECAA7503}.Release|Any CPU.ActiveCfg = Release|Any CPU
|
||||
{FD0B76FD-773C-4DE0-BD65-EC60ECAA7503}.Release|Any CPU.Build.0 = Release|Any CPU
|
||||
{016977A3-1164-42E2-9EFC-D793E3B05EAB}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
|
||||
{016977A3-1164-42E2-9EFC-D793E3B05EAB}.Debug|Any CPU.Build.0 = Debug|Any CPU
|
||||
{016977A3-1164-42E2-9EFC-D793E3B05EAB}.Release|Any CPU.ActiveCfg = Release|Any CPU
|
||||
{016977A3-1164-42E2-9EFC-D793E3B05EAB}.Release|Any CPU.Build.0 = Release|Any CPU
|
||||
{27E67B4C-1941-4AD4-8CAC-EAC8DB6048B0}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
|
||||
{27E67B4C-1941-4AD4-8CAC-EAC8DB6048B0}.Debug|Any CPU.Build.0 = Debug|Any CPU
|
||||
{27E67B4C-1941-4AD4-8CAC-EAC8DB6048B0}.Release|Any CPU.ActiveCfg = Release|Any CPU
|
||||
{27E67B4C-1941-4AD4-8CAC-EAC8DB6048B0}.Release|Any CPU.Build.0 = Release|Any CPU
|
||||
EndGlobalSection
|
||||
GlobalSection(SolutionProperties) = preSolution
|
||||
HideSolutionNode = FALSE
|
||||
EndGlobalSection
|
||||
GlobalSection(ExtensibilityGlobals) = postSolution
|
||||
SolutionGuid = {E3421DCB-2CD0-457B-A7BE-E03D4A1DE96A}
|
||||
EndGlobalSection
|
||||
EndGlobal
|
|
@ -0,0 +1,27 @@
|
|||
using System;
|
||||
using RestaurantViolationsML.Model;
|
||||
|
||||
namespace RestaurantViolations
|
||||
{
|
||||
class Program
|
||||
{
|
||||
static void Main(string[] args)
|
||||
{
|
||||
// Create sample data
|
||||
ModelInput input = new ModelInput
|
||||
{
|
||||
InspectionType = "Complaint",
|
||||
ViolationDescription = "Inadequate sewage or wastewater disposal"
|
||||
};
|
||||
|
||||
// Make prediction
|
||||
ModelOutput result = ConsumeModel.Predict(input);
|
||||
|
||||
// Print Prediction
|
||||
Console.WriteLine($"Inspection type: {input.InspectionType}");
|
||||
Console.WriteLine($"Violation description: {input.ViolationDescription}");
|
||||
Console.WriteLine($"Predicted risk category: {result.Prediction}");
|
||||
Console.ReadKey();
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,12 @@
|
|||
<Project Sdk="Microsoft.NET.Sdk">
|
||||
|
||||
<PropertyGroup>
|
||||
<OutputType>Exe</OutputType>
|
||||
<TargetFramework>netcoreapp2.1</TargetFramework>
|
||||
</PropertyGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\RestaurantViolationsML.Model\RestaurantViolationsML.Model.csproj" />
|
||||
</ItemGroup>
|
||||
|
||||
</Project>
|
|
@ -0,0 +1,190 @@
|
|||
// This file was auto-generated by ML.NET Model Builder.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
using System.IO.Compression;
|
||||
using System.Linq;
|
||||
using System.Net.Http;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.ML;
|
||||
using Microsoft.ML.Data;
|
||||
using RestaurantViolationsML.Model;
|
||||
|
||||
namespace RestaurantViolationsML.ConsoleApp
|
||||
{
|
||||
public static class ModelBuilder
|
||||
{
|
||||
private static string TRAIN_DATA_FILEPATH = GetAbsolutePath(@"RestaurantScores.tsv");
|
||||
private static string MODEL_FILEPATH = @"../../../../RestaurantViolationsML.Model/MLModel.zip";
|
||||
|
||||
// Create MLContext to be shared across the model creation workflow objects
|
||||
// Set a random seed for repeatable/deterministic results across multiple trainings.
|
||||
private static MLContext mlContext = new MLContext(seed: 1);
|
||||
|
||||
public static async Task CreateModel()
|
||||
{
|
||||
// Download Data
|
||||
if(!File.Exists(TRAIN_DATA_FILEPATH))
|
||||
{
|
||||
await DownloadData();
|
||||
}
|
||||
|
||||
// Load Data
|
||||
IDataView trainingDataView = mlContext.Data.LoadFromTextFile<ModelInput>(
|
||||
path: TRAIN_DATA_FILEPATH,
|
||||
hasHeader: true,
|
||||
separatorChar: '\t',
|
||||
allowQuoting: true,
|
||||
allowSparse: false);
|
||||
|
||||
// Build training pipeline
|
||||
IEstimator<ITransformer> trainingPipeline = BuildTrainingPipeline(mlContext);
|
||||
|
||||
// Evaluate quality of Model
|
||||
Evaluate(mlContext, trainingDataView, trainingPipeline);
|
||||
|
||||
// Train Model
|
||||
ITransformer mlModel = TrainModel(mlContext, trainingDataView, trainingPipeline);
|
||||
|
||||
// Save model
|
||||
SaveModel(mlContext, mlModel, MODEL_FILEPATH, trainingDataView.Schema);
|
||||
}
|
||||
|
||||
private static async Task DownloadData()
|
||||
{
|
||||
using (var client = new HttpClient())
|
||||
{
|
||||
var response = await client.GetStreamAsync("https://github.com/luisquintanilla/machinelearning-samples/raw/AB1608219/samples/modelbuilder/MulticlassClassification_RestaurantViolations/RestaurantScores.zip");
|
||||
|
||||
using (var archive = new ZipArchive(response))
|
||||
{
|
||||
foreach(var file in archive.Entries)
|
||||
{
|
||||
if (Path.GetExtension(file.FullName) == ".tsv")
|
||||
{
|
||||
ZipFileExtensions.ExtractToFile(file, GetAbsolutePath(TRAIN_DATA_FILEPATH));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public static IEstimator<ITransformer> BuildTrainingPipeline(MLContext mlContext)
|
||||
{
|
||||
// Data process configuration with pipeline data transformations
|
||||
var dataProcessPipeline = mlContext.Transforms.Conversion.MapValueToKey("RiskCategory", "RiskCategory")
|
||||
.Append(mlContext.Transforms.Categorical.OneHotEncoding(new[] { new InputOutputColumnPair("InspectionType", "InspectionType"), new InputOutputColumnPair("ViolationDescription", "ViolationDescription") }))
|
||||
.Append(mlContext.Transforms.Concatenate("Features", new[] { "InspectionType", "ViolationDescription" }))
|
||||
.Append(mlContext.Transforms.NormalizeMinMax("Features", "Features"))
|
||||
.AppendCacheCheckpoint(mlContext);
|
||||
|
||||
// Set the training algorithm
|
||||
var trainer = mlContext.MulticlassClassification.Trainers.OneVersusAll(mlContext.BinaryClassification.Trainers.AveragedPerceptron(labelColumnName: "RiskCategory", numberOfIterations: 10, featureColumnName: "Features"), labelColumnName: "RiskCategory")
|
||||
.Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel", "PredictedLabel"));
|
||||
var trainingPipeline = dataProcessPipeline.Append(trainer);
|
||||
|
||||
return trainingPipeline;
|
||||
}
|
||||
|
||||
public static ITransformer TrainModel(MLContext mlContext, IDataView trainingDataView, IEstimator<ITransformer> trainingPipeline)
|
||||
{
|
||||
Console.WriteLine("=============== Training model ===============");
|
||||
|
||||
ITransformer model = trainingPipeline.Fit(trainingDataView);
|
||||
|
||||
Console.WriteLine("=============== End of training process ===============");
|
||||
return model;
|
||||
}
|
||||
|
||||
private static void Evaluate(MLContext mlContext, IDataView trainingDataView, IEstimator<ITransformer> trainingPipeline)
|
||||
{
|
||||
// Cross-Validate with single dataset (since we don't have two datasets, one for training and for evaluate)
|
||||
// in order to evaluate and get the model's accuracy metrics
|
||||
Console.WriteLine("=============== Cross-validating to get model's accuracy metrics ===============");
|
||||
var crossValidationResults = mlContext.MulticlassClassification.CrossValidate(trainingDataView, trainingPipeline, numberOfFolds: 5, labelColumnName: "RiskCategory");
|
||||
PrintMulticlassClassificationFoldsAverageMetrics(crossValidationResults);
|
||||
}
|
||||
private static void SaveModel(MLContext mlContext, ITransformer mlModel, string modelRelativePath, DataViewSchema modelInputSchema)
|
||||
{
|
||||
// Save/persist the trained model to a .ZIP file
|
||||
Console.WriteLine($"=============== Saving the model ===============");
|
||||
mlContext.Model.Save(mlModel, modelInputSchema, GetAbsolutePath(modelRelativePath));
|
||||
Console.WriteLine("The model is saved to {0}", GetAbsolutePath(modelRelativePath));
|
||||
}
|
||||
|
||||
public static string GetAbsolutePath(string relativePath)
|
||||
{
|
||||
FileInfo _dataRoot = new FileInfo(typeof(Program).Assembly.Location);
|
||||
string assemblyFolderPath = _dataRoot.Directory.FullName;
|
||||
|
||||
string fullPath = Path.Combine(assemblyFolderPath, relativePath);
|
||||
|
||||
return fullPath;
|
||||
}
|
||||
|
||||
public static void PrintMulticlassClassificationMetrics(MulticlassClassificationMetrics metrics)
|
||||
{
|
||||
Console.WriteLine($"************************************************************");
|
||||
Console.WriteLine($"* Metrics for multi-class classification model ");
|
||||
Console.WriteLine($"*-----------------------------------------------------------");
|
||||
Console.WriteLine($" MacroAccuracy = {metrics.MacroAccuracy:0.####}, a value between 0 and 1, the closer to 1, the better");
|
||||
Console.WriteLine($" MicroAccuracy = {metrics.MicroAccuracy:0.####}, a value between 0 and 1, the closer to 1, the better");
|
||||
Console.WriteLine($" LogLoss = {metrics.LogLoss:0.####}, the closer to 0, the better");
|
||||
for (int i = 0; i < metrics.PerClassLogLoss.Count; i++)
|
||||
{
|
||||
Console.WriteLine($" LogLoss for class {i + 1} = {metrics.PerClassLogLoss[i]:0.####}, the closer to 0, the better");
|
||||
}
|
||||
Console.WriteLine($"************************************************************");
|
||||
}
|
||||
|
||||
public static void PrintMulticlassClassificationFoldsAverageMetrics(IEnumerable<TrainCatalogBase.CrossValidationResult<MulticlassClassificationMetrics>> crossValResults)
|
||||
{
|
||||
var metricsInMultipleFolds = crossValResults.Select(r => r.Metrics);
|
||||
|
||||
var microAccuracyValues = metricsInMultipleFolds.Select(m => m.MicroAccuracy);
|
||||
var microAccuracyAverage = microAccuracyValues.Average();
|
||||
var microAccuraciesStdDeviation = CalculateStandardDeviation(microAccuracyValues);
|
||||
var microAccuraciesConfidenceInterval95 = CalculateConfidenceInterval95(microAccuracyValues);
|
||||
|
||||
var macroAccuracyValues = metricsInMultipleFolds.Select(m => m.MacroAccuracy);
|
||||
var macroAccuracyAverage = macroAccuracyValues.Average();
|
||||
var macroAccuraciesStdDeviation = CalculateStandardDeviation(macroAccuracyValues);
|
||||
var macroAccuraciesConfidenceInterval95 = CalculateConfidenceInterval95(macroAccuracyValues);
|
||||
|
||||
var logLossValues = metricsInMultipleFolds.Select(m => m.LogLoss);
|
||||
var logLossAverage = logLossValues.Average();
|
||||
var logLossStdDeviation = CalculateStandardDeviation(logLossValues);
|
||||
var logLossConfidenceInterval95 = CalculateConfidenceInterval95(logLossValues);
|
||||
|
||||
var logLossReductionValues = metricsInMultipleFolds.Select(m => m.LogLossReduction);
|
||||
var logLossReductionAverage = logLossReductionValues.Average();
|
||||
var logLossReductionStdDeviation = CalculateStandardDeviation(logLossReductionValues);
|
||||
var logLossReductionConfidenceInterval95 = CalculateConfidenceInterval95(logLossReductionValues);
|
||||
|
||||
Console.WriteLine($"*************************************************************************************************************");
|
||||
Console.WriteLine($"* Metrics for Multi-class Classification model ");
|
||||
Console.WriteLine($"*------------------------------------------------------------------------------------------------------------");
|
||||
Console.WriteLine($"* Average MicroAccuracy: {microAccuracyAverage:0.###} - Standard deviation: ({microAccuraciesStdDeviation:#.###}) - Confidence Interval 95%: ({microAccuraciesConfidenceInterval95:#.###})");
|
||||
Console.WriteLine($"* Average MacroAccuracy: {macroAccuracyAverage:0.###} - Standard deviation: ({macroAccuraciesStdDeviation:#.###}) - Confidence Interval 95%: ({macroAccuraciesConfidenceInterval95:#.###})");
|
||||
Console.WriteLine($"* Average LogLoss: {logLossAverage:#.###} - Standard deviation: ({logLossStdDeviation:#.###}) - Confidence Interval 95%: ({logLossConfidenceInterval95:#.###})");
|
||||
Console.WriteLine($"* Average LogLossReduction: {logLossReductionAverage:#.###} - Standard deviation: ({logLossReductionStdDeviation:#.###}) - Confidence Interval 95%: ({logLossReductionConfidenceInterval95:#.###})");
|
||||
Console.WriteLine($"*************************************************************************************************************");
|
||||
|
||||
}
|
||||
|
||||
public static double CalculateStandardDeviation(IEnumerable<double> values)
|
||||
{
|
||||
double average = values.Average();
|
||||
double sumOfSquaresOfDifferences = values.Select(val => (val - average) * (val - average)).Sum();
|
||||
double standardDeviation = Math.Sqrt(sumOfSquaresOfDifferences / (values.Count() - 1));
|
||||
return standardDeviation;
|
||||
}
|
||||
|
||||
public static double CalculateConfidenceInterval95(IEnumerable<double> values)
|
||||
{
|
||||
double confidenceInterval95 = 1.96 * CalculateStandardDeviation(values) / Math.Sqrt((values.Count() - 1));
|
||||
return confidenceInterval95;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,49 @@
|
|||
// This file was auto-generated by ML.NET Model Builder.
|
||||
|
||||
using System;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.ML;
|
||||
using RestaurantViolationsML.Model;
|
||||
|
||||
namespace RestaurantViolationsML.ConsoleApp
|
||||
{
|
||||
class Program
|
||||
{
|
||||
static async Task Main(string[] args)
|
||||
{
|
||||
//await ModelBuilder.CreateModel();
|
||||
|
||||
// Create single instance of sample data from first line of dataset for model input
|
||||
ModelInput sampleData = CreateSingleDataSample();
|
||||
|
||||
// Make a single prediction on the sample data and print results
|
||||
ModelOutput predictionResult = ConsumeModel.Predict(sampleData);
|
||||
|
||||
Console.WriteLine("Using model to make single prediction -- Comparing actual RiskCategory with predicted RiskCategory from sample data...\n\n");
|
||||
Console.WriteLine($"InspectionType: {sampleData.InspectionType}");
|
||||
Console.WriteLine($"ViolationDescription: {sampleData.ViolationDescription}");
|
||||
Console.WriteLine($"\n\nActual RiskCategory: {sampleData.RiskCategory} \nPredicted RiskCategory value {predictionResult.Prediction} \nPredicted RiskCategory scores: [{String.Join(",", predictionResult.Score)}]\n\n");
|
||||
Console.WriteLine("=============== End of process, hit any key to finish ===============");
|
||||
Console.ReadKey();
|
||||
}
|
||||
|
||||
// Change this code to create your own sample data
|
||||
#region CreateSingleDataSample
|
||||
// Method to load single data sample to try a single prediction
|
||||
private static ModelInput CreateSingleDataSample()
|
||||
{
|
||||
|
||||
// Use new test data
|
||||
ModelInput sampleForPrediction = new ModelInput
|
||||
{
|
||||
InspectionType = "Complaint",
|
||||
ViolationDescription = "Inadequate sewage or wastewater disposal"
|
||||
};
|
||||
|
||||
return sampleForPrediction;
|
||||
}
|
||||
#endregion
|
||||
}
|
||||
}
|
|
@ -0,0 +1,14 @@
|
|||
<Project Sdk="Microsoft.NET.Sdk">
|
||||
|
||||
<PropertyGroup>
|
||||
<OutputType>Exe</OutputType>
|
||||
<TargetFramework>netcoreapp2.1</TargetFramework>
|
||||
<LangVersion>latest</LangVersion>
|
||||
</PropertyGroup>
|
||||
<ItemGroup>
|
||||
<PackageReference Include="Microsoft.ML" Version="1.3.1" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\RestaurantViolationsML.Model\RestaurantViolationsML.Model.csproj" />
|
||||
</ItemGroup>
|
||||
</Project>
|
|
@ -0,0 +1,32 @@
|
|||
// This file was auto-generated by ML.NET Model Builder.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using Microsoft.ML;
|
||||
using RestaurantViolationsML.Model;
|
||||
|
||||
namespace RestaurantViolationsML.Model
|
||||
{
|
||||
public class ConsumeModel
|
||||
{
|
||||
// For more info on consuming ML.NET models, visit https://aka.ms/model-builder-consume
|
||||
// Method for consuming model in your app
|
||||
public static ModelOutput Predict(ModelInput input)
|
||||
{
|
||||
|
||||
// Create new MLContext
|
||||
MLContext mlContext = new MLContext();
|
||||
|
||||
// Load model & create prediction engine
|
||||
string modelPath = AppDomain.CurrentDomain.BaseDirectory + "MLModel.zip";
|
||||
ITransformer mlModel = mlContext.Model.Load(modelPath, out var modelInputSchema);
|
||||
var predEngine = mlContext.Model.CreatePredictionEngine<ModelInput, ModelOutput>(mlModel);
|
||||
|
||||
// Use model to make prediction on input data
|
||||
ModelOutput result = predEngine.Predict(input);
|
||||
return result;
|
||||
}
|
||||
}
|
||||
}
|
Двоичные данные
samples/modelbuilder/MulticlassClassification_RestaurantViolations/RestaurantViolationsML.Model/MLModel.zip
Normal file
Двоичные данные
samples/modelbuilder/MulticlassClassification_RestaurantViolations/RestaurantViolationsML.Model/MLModel.zip
Normal file
Двоичный файл не отображается.
|
@ -0,0 +1,22 @@
|
|||
// This file was auto-generated by ML.NET Model Builder.
|
||||
|
||||
using Microsoft.ML.Data;
|
||||
|
||||
namespace RestaurantViolationsML.Model
|
||||
{
|
||||
public class ModelInput
|
||||
{
|
||||
[ColumnName("InspectionType"), LoadColumn(0)]
|
||||
public string InspectionType { get; set; }
|
||||
|
||||
|
||||
[ColumnName("ViolationDescription"), LoadColumn(1)]
|
||||
public string ViolationDescription { get; set; }
|
||||
|
||||
|
||||
[ColumnName("RiskCategory"), LoadColumn(2)]
|
||||
public string RiskCategory { get; set; }
|
||||
|
||||
|
||||
}
|
||||
}
|
|
@ -0,0 +1,16 @@
|
|||
// This file was auto-generated by ML.NET Model Builder.
|
||||
|
||||
using System;
|
||||
using Microsoft.ML.Data;
|
||||
|
||||
namespace RestaurantViolationsML.Model
|
||||
{
|
||||
public class ModelOutput
|
||||
{
|
||||
// ColumnName attribute is used to change the column name from
|
||||
// its default value, which is the name of the field.
|
||||
[ColumnName("PredictedLabel")]
|
||||
public String Prediction { get; set; }
|
||||
public float[] Score { get; set; }
|
||||
}
|
||||
}
|
|
@ -0,0 +1,16 @@
|
|||
<Project Sdk="Microsoft.NET.Sdk">
|
||||
|
||||
<PropertyGroup>
|
||||
<TargetFramework>netstandard2.0</TargetFramework>
|
||||
</PropertyGroup>
|
||||
<ItemGroup>
|
||||
<PackageReference Include="Microsoft.ML" Version="1.3.1" />
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<None Update="MLModel.zip">
|
||||
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
|
||||
</None>
|
||||
</ItemGroup>
|
||||
|
||||
</Project>
|
Двоичные данные
samples/modelbuilder/MulticlassClassification_RestaurantViolations/images/console.png
Normal file
Двоичные данные
samples/modelbuilder/MulticlassClassification_RestaurantViolations/images/console.png
Normal file
Двоичный файл не отображается.
После Ширина: | Высота: | Размер: 27 KiB |
Загрузка…
Ссылка в новой задаче