Jonathan Tims 2022-06-22 07:54:27 +01:00
Parent 5156776f4c
Commit 2f58f64afb
50 changed files with 12 additions and 3393 deletions

View file

@@ -1,6 +1,6 @@
Microsoft Visual Studio Solution File, Format Version 12.00
# Visual Studio Version 16
VisualStudioVersion = 16.0.29020.237
# Visual Studio Version 17
VisualStudioVersion = 17.2.32408.312
MinimumVisualStudioVersion = 10.0.40219.1
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution Items", "{A181C943-2E01-454D-9008-2E3C53AA09CC}"
ProjectSection(SolutionItems) = preProject
@@ -56,8 +56,6 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Learners", "Learners", "{29
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Runners", "Runners", "{3DB795A6-5FE8-447C-89B9-9E608285C6F8}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "CommandLine", "src\Learners\Runners\CommandLine\CommandLine.csproj", "{10FD3E08-53E8-42B2-8E4F-A5C23DEE3B96}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Common", "src\Learners\Runners\Common\Common.csproj", "{25D28099-E338-4543-B1DE-261439654CA6}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Evaluator", "src\Learners\Runners\Evaluator\Evaluator.csproj", "{040FA938-BE24-4391-86BA-D04B331A787A}"
@@ -321,18 +319,6 @@ Global
{5B669C82-B04C-4DD6-8CE6-47D025D98777}.ReleaseCore|Any CPU.Build.0 = Release|Any CPU
{5B669C82-B04C-4DD6-8CE6-47D025D98777}.ReleaseFull|Any CPU.ActiveCfg = Release|Any CPU
{5B669C82-B04C-4DD6-8CE6-47D025D98777}.ReleaseFull|Any CPU.Build.0 = Release|Any CPU
{10FD3E08-53E8-42B2-8E4F-A5C23DEE3B96}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{10FD3E08-53E8-42B2-8E4F-A5C23DEE3B96}.Debug|Any CPU.Build.0 = Debug|Any CPU
{10FD3E08-53E8-42B2-8E4F-A5C23DEE3B96}.DebugCore|Any CPU.ActiveCfg = DebugCore|Any CPU
{10FD3E08-53E8-42B2-8E4F-A5C23DEE3B96}.DebugCore|Any CPU.Build.0 = DebugCore|Any CPU
{10FD3E08-53E8-42B2-8E4F-A5C23DEE3B96}.DebugFull|Any CPU.ActiveCfg = DebugFull|Any CPU
{10FD3E08-53E8-42B2-8E4F-A5C23DEE3B96}.DebugFull|Any CPU.Build.0 = DebugFull|Any CPU
{10FD3E08-53E8-42B2-8E4F-A5C23DEE3B96}.Release|Any CPU.ActiveCfg = Release|Any CPU
{10FD3E08-53E8-42B2-8E4F-A5C23DEE3B96}.Release|Any CPU.Build.0 = Release|Any CPU
{10FD3E08-53E8-42B2-8E4F-A5C23DEE3B96}.ReleaseCore|Any CPU.ActiveCfg = ReleaseCore|Any CPU
{10FD3E08-53E8-42B2-8E4F-A5C23DEE3B96}.ReleaseCore|Any CPU.Build.0 = ReleaseCore|Any CPU
{10FD3E08-53E8-42B2-8E4F-A5C23DEE3B96}.ReleaseFull|Any CPU.ActiveCfg = ReleaseFull|Any CPU
{10FD3E08-53E8-42B2-8E4F-A5C23DEE3B96}.ReleaseFull|Any CPU.Build.0 = ReleaseFull|Any CPU
{25D28099-E338-4543-B1DE-261439654CA6}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{25D28099-E338-4543-B1DE-261439654CA6}.Debug|Any CPU.Build.0 = Debug|Any CPU
{25D28099-E338-4543-B1DE-261439654CA6}.DebugCore|Any CPU.ActiveCfg = Debug|Any CPU
@@ -570,7 +556,6 @@ Global
{87D09BD4-119E-49C1-B0B4-86DF962A00EE} = {DC5F5BC4-CDB0-41F7-8B03-CD4C38C8DEB2}
{6FF3E672-378C-4D61-B4CA-A5A5E01C2563} = {DC5F5BC4-CDB0-41F7-8B03-CD4C38C8DEB2}
{3DB795A6-5FE8-447C-89B9-9E608285C6F8} = {2964BB90-4E6D-49ED-AA35-645D94337C76}
{10FD3E08-53E8-42B2-8E4F-A5C23DEE3B96} = {3DB795A6-5FE8-447C-89B9-9E608285C6F8}
{25D28099-E338-4543-B1DE-261439654CA6} = {3DB795A6-5FE8-447C-89B9-9E608285C6F8}
{040FA938-BE24-4391-86BA-D04B331A787A} = {3DB795A6-5FE8-447C-89B9-9E608285C6F8}
{07E9E91D-6593-4FF9-A266-270ED5241C98} = {2964BB90-4E6D-49ED-AA35-645D94337C76}

View file

@@ -4,7 +4,7 @@
# Nightly build using .NET Core
name: 0.4.$(Date:yyMM).$(Date:dd)$(Rev:rr)
name: 0.5.$(Date:yyMM).$(Date:dd)$(Rev:rr)
resources:
- repo: self

View file

@@ -4,7 +4,7 @@
# Nightly build for Windows. Tests on x86 and x64. Produces NuGet packages
name: 0.4.$(Date:yyMM).$(Date:dd)$(Rev:rr)
name: 0.5.$(Date:yyMM).$(Date:dd)$(Rev:rr)
resources:
- repo: self

View file

@@ -4,7 +4,7 @@
# Official signed build
name: 0.4.$(Date:yyMM).$(Date:dd)$(Rev:rr)
name: 0.5.$(Date:yyMM).$(Date:dd)$(Rev:rr)
resources:
- repo: self

View file

@@ -109,26 +109,6 @@ namespace Microsoft.ML.Probabilistic.Learners
#region .Net binary deserialization
/// <summary>
/// Deserializes a Bayes point machine classifier from a file.
/// </summary>
/// <typeparam name="TInstanceSource">The type of a source of instances.</typeparam>
/// <typeparam name="TInstance">The type of an instance.</typeparam>
/// <typeparam name="TLabelSource">The type of a source of labels.</typeparam>
/// <typeparam name="TLabel">The type of a label.</typeparam>
/// <typeparam name="TLabelDistribution">The type of a distribution over labels.</typeparam>
/// <typeparam name="TTrainingSettings">The type of the settings for training.</typeparam>
/// <typeparam name="TPredictionSettings">The type of the settings for prediction.</typeparam>
/// <param name="fileName">The file name.</param>
/// <returns>The deserialized Bayes point machine classifier object.</returns>
public static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, TTrainingSettings, TPredictionSettings>
Load<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, TTrainingSettings, TPredictionSettings>(string fileName)
where TTrainingSettings : BayesPointMachineClassifierTrainingSettings
where TPredictionSettings : IBayesPointMachineClassifierPredictionSettings<TLabel>
{
return Utilities.Load<IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, TTrainingSettings, TPredictionSettings>>(fileName);
}
/// <summary>
/// Deserializes a Bayes point machine classifier from a stream and formatter.
/// </summary>
@@ -150,24 +130,6 @@ namespace Microsoft.ML.Probabilistic.Learners
return Utilities.Load<IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, TTrainingSettings, TPredictionSettings>>(stream, formatter);
}
/// <summary>
/// Deserializes a binary Bayes point machine classifier from a file.
/// </summary>
/// <typeparam name="TInstanceSource">The type of a source of instances.</typeparam>
/// <typeparam name="TInstance">The type of an instance.</typeparam>
/// <typeparam name="TLabelSource">The type of a source of labels.</typeparam>
/// <typeparam name="TLabel">The type of a label.</typeparam>
/// <typeparam name="TLabelDistribution">The type of a distribution over labels.</typeparam>
/// <typeparam name="TTrainingSettings">The type of the settings for training.</typeparam>
/// <param name="fileName">The file name.</param>
/// <returns>The deserialized binary Bayes point machine classifier object.</returns>
public static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, TTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<TLabel>>
LoadBinaryClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, TTrainingSettings>(string fileName)
where TTrainingSettings : BayesPointMachineClassifierTrainingSettings
{
return Utilities.Load<IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, TTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<TLabel>>>(fileName);
}
/// <summary>
/// Deserializes a binary Bayes point machine classifier from a stream and formatter.
/// </summary>
@@ -187,24 +149,6 @@ namespace Microsoft.ML.Probabilistic.Learners
return Utilities.Load<IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, TTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<TLabel>>>(stream, formatter);
}
/// <summary>
/// Deserializes a multi-class Bayes point machine classifier from a file.
/// </summary>
/// <typeparam name="TInstanceSource">The type of a source of instances.</typeparam>
/// <typeparam name="TInstance">The type of an instance.</typeparam>
/// <typeparam name="TLabelSource">The type of a source of labels.</typeparam>
/// <typeparam name="TLabel">The type of a label.</typeparam>
/// <typeparam name="TLabelDistribution">The type of a distribution over labels.</typeparam>
/// <typeparam name="TTrainingSettings">The type of the settings for training.</typeparam>
/// <param name="fileName">The file name.</param>
/// <returns>The deserialized multi-class Bayes point machine classifier object.</returns>
public static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, TTrainingSettings, MulticlassBayesPointMachineClassifierPredictionSettings<TLabel>>
LoadMulticlassClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, TTrainingSettings>(string fileName)
where TTrainingSettings : BayesPointMachineClassifierTrainingSettings
{
return Utilities.Load<IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, TTrainingSettings, MulticlassBayesPointMachineClassifierPredictionSettings<TLabel>>>(fileName);
}
/// <summary>
/// Deserializes a multi-class Bayes point machine classifier from a stream and a formatter.
/// </summary>
@@ -224,22 +168,6 @@ namespace Microsoft.ML.Probabilistic.Learners
return Utilities.Load<IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, TTrainingSettings, MulticlassBayesPointMachineClassifierPredictionSettings<TLabel>>>(stream, formatter);
}
/// <summary>
/// Deserializes a binary Bayes point machine classifier from a file.
/// </summary>
/// <typeparam name="TInstanceSource">The type of a source of instances.</typeparam>
/// <typeparam name="TInstance">The type of an instance.</typeparam>
/// <typeparam name="TLabelSource">The type of a source of labels.</typeparam>
/// <typeparam name="TLabel">The type of a label.</typeparam>
/// <typeparam name="TLabelDistribution">The type of a distribution over labels.</typeparam>
/// <param name="fileName">The file name.</param>
/// <returns>The deserialized binary Bayes point machine classifier object.</returns>
public static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, BayesPointMachineClassifierTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<TLabel>>
LoadBinaryClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution>(string fileName)
{
return Utilities.Load<IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, BayesPointMachineClassifierTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<TLabel>>>(fileName);
}
/// <summary>
/// Deserializes a binary Bayes point machine classifier from a stream and formatter.
/// </summary>
@@ -257,22 +185,6 @@ namespace Microsoft.ML.Probabilistic.Learners
return Utilities.Load<IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, BayesPointMachineClassifierTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<TLabel>>>(stream, formatter);
}
/// <summary>
/// Deserializes a multi-class Bayes point machine classifier from a file.
/// </summary>
/// <typeparam name="TInstanceSource">The type of a source of instances.</typeparam>
/// <typeparam name="TInstance">The type of an instance.</typeparam>
/// <typeparam name="TLabelSource">The type of a source of labels.</typeparam>
/// <typeparam name="TLabel">The type of a label.</typeparam>
/// <typeparam name="TLabelDistribution">The type of a distribution over labels.</typeparam>
/// <param name="fileName">The file name.</param>
/// <returns>The deserialized multi-class Bayes point machine classifier object.</returns>
public static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, BayesPointMachineClassifierTrainingSettings, MulticlassBayesPointMachineClassifierPredictionSettings<TLabel>>
LoadMulticlassClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution>(string fileName)
{
return Utilities.Load<IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, BayesPointMachineClassifierTrainingSettings, MulticlassBayesPointMachineClassifierPredictionSettings<TLabel>>>(fileName);
}
/// <summary>
/// Deserializes a multi-class Bayes point machine classifier from a stream and formatter.
/// </summary>
@@ -592,23 +504,6 @@ namespace Microsoft.ML.Probabilistic.Learners
#region Internal .Net binary deserialization
/// <summary>
/// Deserializes a binary Bayes point machine classifier with <see cref="Gaussian"/> prior distributions
/// over factorized weights from a file.
/// </summary>
/// <typeparam name="TInstanceSource">The type of a source of instances.</typeparam>
/// <typeparam name="TInstance">The type of an instance.</typeparam>
/// <typeparam name="TLabelSource">The type of a source of labels.</typeparam>
/// <typeparam name="TLabel">The type of a label.</typeparam>
/// <typeparam name="TLabelDistribution">The type of a distribution over labels.</typeparam>
/// <param name="fileName">The file name.</param>
/// <returns>The deserialized binary Bayes point machine classifier object.</returns>
public static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, GaussianBayesPointMachineClassifierTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<TLabel>>
LoadGaussianPriorBinaryClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution>(string fileName)
{
return Utilities.Load<IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, GaussianBayesPointMachineClassifierTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<TLabel>>>(fileName);
}
/// <summary>
/// Deserializes a binary Bayes point machine classifier with <see cref="Gaussian"/> prior distributions
/// over factorized weights from a stream and formatter.
@@ -627,23 +522,6 @@ namespace Microsoft.ML.Probabilistic.Learners
return Utilities.Load<IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, GaussianBayesPointMachineClassifierTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<TLabel>>>(stream, formatter);
}
/// <summary>
/// Deserializes a multi-class Bayes point machine classifier with <see cref="Gaussian"/> prior distributions
/// over factorized weights from a file.
/// </summary>
/// <typeparam name="TInstanceSource">The type of a source of instances.</typeparam>
/// <typeparam name="TInstance">The type of an instance.</typeparam>
/// <typeparam name="TLabelSource">The type of a source of labels.</typeparam>
/// <typeparam name="TLabel">The type of a label.</typeparam>
/// <typeparam name="TLabelDistribution">The type of a distribution over labels.</typeparam>
/// <param name="fileName">The file name.</param>
/// <returns>The deserialized multi-class Bayes point machine classifier object.</returns>
public static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, GaussianBayesPointMachineClassifierTrainingSettings, MulticlassBayesPointMachineClassifierPredictionSettings<TLabel>>
LoadBackwardCompatibleGaussianPriorMulticlassClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution>(string fileName)
{
return Utilities.Load<IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, GaussianBayesPointMachineClassifierTrainingSettings, MulticlassBayesPointMachineClassifierPredictionSettings<TLabel>>>(fileName);
}
/// <summary>
/// Deserializes a multi-class Bayes point machine classifier with <see cref="Gaussian"/> prior distributions
/// over factorized weights from a stream and formatter.
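
With the file-name overloads above removed, callers open the stream themselves and pass a formatter to the stream/formatter overloads that this change keeps. A minimal sketch, assuming the stream-based LoadBinaryClassifier shown in this hunk and the type arguments used by the runner modules further down; the helper class, method and file name are illustrative only:

namespace Microsoft.ML.Probabilistic.Learners.Runners
{
    using System.Collections.Generic;
    using System.IO;
    using System.Runtime.Serialization.Formatters.Binary;

    // Illustrative helper, not part of this commit.
    internal static class ClassifierLoadSketch
    {
        internal static void Run(string fileName)
        {
            // Open the stream at the call site and supply the formatter explicitly,
            // which is what the removed file-name overload did internally.
            using (Stream stream = File.OpenRead(fileName))
            {
                var classifier = BayesPointMachineClassifier
                    .LoadBinaryClassifier<IList<LabeledFeatureValues>, LabeledFeatureValues, IList<LabelDistribution>, string, IDictionary<string, double>>(
                        stream, new BinaryFormatter());

                // Use the classifier as before, e.g. classifier.PredictDistribution(testSet).
            }
        }
    }
}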

View file

@@ -23,30 +23,6 @@ namespace Microsoft.ML.Probabilistic.Learners
#region Save
/// <summary>
/// Persists a learner to a file.
/// </summary>
/// <param name="learner">The learner to serialize.</param>
/// <param name="fileName">The name of the file.</param>
public static void Save(this ILearner learner, string fileName)
{
if (learner == null)
{
throw new ArgumentNullException(nameof(learner));
}
if (fileName == null)
{
throw new ArgumentNullException(nameof(fileName));
}
using (Stream stream = File.Open(fileName, FileMode.Create))
{
var formatter = new BinaryFormatter();
learner.Save(stream, formatter);
}
}
/// <summary>
/// Serializes a learner to a given stream using a given formatter.
/// </summary>
@@ -181,26 +157,6 @@ namespace Microsoft.ML.Probabilistic.Learners
return (TLearner)formatter.Deserialize(stream);
}
/// <summary>
/// Deserializes a learner from a file.
/// </summary>
/// <typeparam name="TLearner">The type of a learner.</typeparam>
/// <param name="fileName">The file name.</param>
/// <returns>The deserialized learner object.</returns>
public static TLearner Load<TLearner>(string fileName)
{
if (fileName == null)
{
throw new ArgumentNullException(nameof(fileName));
}
using (Stream stream = File.Open(fileName, FileMode.Open))
{
var formatter = new BinaryFormatter();
return Load<TLearner>(stream, formatter);
}
}
#endregion
#endregion
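
After this change only the stream/formatter Save and Load remain in these serialization utilities. A minimal sketch of reproducing the removed file-based calls at the call site; the class and method names are illustrative, and the BinaryFormatter mirrors what the removed overloads constructed internally:

namespace Microsoft.ML.Probabilistic.Learners
{
    using System.IO;
    using System.Runtime.Serialization.Formatters.Binary;

    // Illustrative call-site replacements, not part of this commit.
    internal static class LearnerFileSerializationSketch
    {
        internal static void SaveToFile(ILearner learner, string fileName)
        {
            using (Stream stream = File.Open(fileName, FileMode.Create))
            {
                // The Save(stream, formatter) overload shown above is kept.
                learner.Save(stream, new BinaryFormatter());
            }
        }

        internal static TLearner LoadFromFile<TLearner>(string fileName)
        {
            using (Stream stream = File.Open(fileName, FileMode.Open))
            {
                // Mirrors the kept Load<TLearner>(stream, formatter), which deserializes and casts.
                return (TLearner)new BinaryFormatter().Deserialize(stream);
            }
        }
    }
}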

View file

@@ -70,22 +70,6 @@ namespace Microsoft.ML.Probabilistic.Learners
#region .NET binary deserialization
/// <summary>
/// Deserializes a Matchbox recommender from a file.
/// </summary>
/// <typeparam name="TInstanceSource">The type of a source of instances.</typeparam>
/// <typeparam name="TUser">The type of a user.</typeparam>
/// <typeparam name="TItem">The type of an item.</typeparam>
/// <typeparam name="TRatingDistribution">The type of a distribution over ratings.</typeparam>
/// <typeparam name="TFeatureSource">The type of a feature source.</typeparam>
/// <param name="fileName">The file name.</param>
/// <returns>The deserialized recommender object.</returns>
public static IMatchboxRecommender<TInstanceSource, TUser, TItem, TRatingDistribution, TFeatureSource>
Load<TInstanceSource, TUser, TItem, TRatingDistribution, TFeatureSource>(string fileName)
{
return Utilities.Load<IMatchboxRecommender<TInstanceSource, TUser, TItem, TRatingDistribution, TFeatureSource>>(fileName);
}
/// <summary>
/// Deserializes a recommender from a given stream and formatter.
/// </summary>
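
The Matchbox recommender keeps only the stream/formatter Load documented above, so a file-based load becomes a thin wrapper. A minimal sketch, assuming these overloads live on the static MatchboxRecommender class; the wrapper and its generic parameters are illustrative:

namespace Microsoft.ML.Probabilistic.Learners
{
    using System.IO;
    using System.Runtime.Serialization.Formatters.Binary;

    // Illustrative wrapper, not part of this commit.
    internal static class RecommenderLoadSketch
    {
        internal static IMatchboxRecommender<TInstanceSource, TUser, TItem, TRatingDistribution, TFeatureSource>
            LoadFromFile<TInstanceSource, TUser, TItem, TRatingDistribution, TFeatureSource>(string fileName)
        {
            using (Stream stream = File.OpenRead(fileName))
            {
                return MatchboxRecommender.Load<TInstanceSource, TUser, TItem, TRatingDistribution, TFeatureSource>(
                    stream, new BinaryFormatter());
            }
        }
    }
}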

View file

@@ -1,311 +0,0 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
namespace Microsoft.ML.Probabilistic.Learners.Runners
{
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using Microsoft.ML.Probabilistic.Distributions;
using Microsoft.ML.Probabilistic.Math;
using Microsoft.ML.Probabilistic.Utilities;
/// <summary>
/// Utilities for the Bayes point machine classifier modules.
/// </summary>
internal static class BayesPointMachineClassifierModuleUtilities
{
/// <summary>
/// Diagnoses the Bayes point machine classifier on the specified data set.
/// </summary>
/// <typeparam name="TTrainingSettings">The type of the settings for training.</typeparam>
/// <param name="classifier">The Bayes point machine classifier.</param>
/// <param name="trainingSet">The dataset.</param>
/// <param name="maxParameterChangesFileName">The name of the file to store the maximum parameter differences.</param>
/// <param name="modelFileName">The name of the file to store the trained Bayes point machine model.</param>
public static void DiagnoseClassifier<TTrainingSettings>(
IBayesPointMachineClassifier<IList<LabeledFeatureValues>, LabeledFeatureValues, IList<LabelDistribution>, string, IDictionary<string, double>, TTrainingSettings, IBayesPointMachineClassifierPredictionSettings<string>> classifier,
IList<LabeledFeatureValues> trainingSet,
string maxParameterChangesFileName,
string modelFileName)
where TTrainingSettings : BayesPointMachineClassifierTrainingSettings
{
// Create prior distributions over weights
int classCount = trainingSet[0].LabelDistribution.LabelSet.Count;
int featureCount = trainingSet[0].GetDenseFeatureVector().Count;
var priorWeightDistributions = Util.ArrayInit(classCount, c => Util.ArrayInit(featureCount, f => new Gaussian(0.0, 1.0)));
// Create IterationChanged handler
var watch = new Stopwatch();
classifier.IterationChanged += (sender, eventArgs) =>
{
watch.Stop();
double maxParameterChange = MaxDiff(eventArgs.WeightPosteriorDistributions, priorWeightDistributions);
if (!string.IsNullOrEmpty(maxParameterChangesFileName))
{
SaveMaximumParameterDifference(
maxParameterChangesFileName,
eventArgs.CompletedIterationCount,
maxParameterChange,
watch.ElapsedMilliseconds);
}
Console.WriteLine(
"[{0}] Iteration {1,-4} dp = {2,-20} dt = {3,5}ms",
DateTime.Now.ToLongTimeString(),
eventArgs.CompletedIterationCount,
maxParameterChange,
watch.ElapsedMilliseconds);
// Copy weight marginals
for (int c = 0; c < eventArgs.WeightPosteriorDistributions.Count; c++)
{
for (int f = 0; f < eventArgs.WeightPosteriorDistributions[c].Count; f++)
{
priorWeightDistributions[c][f] = eventArgs.WeightPosteriorDistributions[c][f];
}
}
watch.Restart();
};
// Write file header
if (!string.IsNullOrEmpty(maxParameterChangesFileName))
{
using (var writer = new StreamWriter(maxParameterChangesFileName))
{
writer.WriteLine("# time, # iteration, # maximum absolute parameter difference, # iteration time in milliseconds");
}
}
// Train the Bayes point machine classifier
Console.WriteLine("[{0}] Starting training...", DateTime.Now.ToLongTimeString());
watch.Start();
classifier.Train(trainingSet);
// Compute evidence
if (classifier.Settings.Training.ComputeModelEvidence)
{
Console.WriteLine("Log evidence = {0,10:0.0000}", classifier.LogModelEvidence);
}
// Save trained model
if (!string.IsNullOrEmpty(modelFileName))
{
classifier.Save(modelFileName);
}
}
/// <summary>
/// Samples weights from the learned weight distribution of the Bayes point machine classifier.
/// </summary>
/// <typeparam name="TTrainingSettings">The type of the settings for training.</typeparam>
/// <param name="classifier">The Bayes point machine used to sample weights from.</param>
/// <param name="samplesFile">The name of the file to which the weights will be written.</param>
public static void SampleWeights<TTrainingSettings>(
IBayesPointMachineClassifier<IList<LabeledFeatureValues>, LabeledFeatureValues, IList<LabelDistribution>, string, IDictionary<string, double>, TTrainingSettings, IBayesPointMachineClassifierPredictionSettings<string>> classifier,
string samplesFile)
where TTrainingSettings : BayesPointMachineClassifierTrainingSettings
{
// Sample weights
var samples = SampleWeights(classifier);
// Write samples to file
if (!string.IsNullOrEmpty(samplesFile))
{
ClassifierPersistenceUtils.SaveVectors(samplesFile, samples);
}
}
/// <summary>
/// Saves the performance metrics to a file with the specified name.
/// </summary>
/// <param name="fileName">The name of the file to save the metrics to.</param>
/// <param name="accuracy">The accuracy.</param>
/// <param name="negativeLogProbability">The mean negative log probability.</param>
/// <param name="auc">The AUC.</param>
/// <param name="evidence">The model's log evidence.</param>
/// <param name="iterationCount">The number of training iterations.</param>
/// <param name="trainingTime">The training time in milliseconds.</param>
public static void SavePerformanceMetrics(
string fileName,
ICollection<double> accuracy,
IEnumerable<double> negativeLogProbability,
IEnumerable<double> auc,
IEnumerable<double> evidence,
IEnumerable<double> iterationCount,
IEnumerable<double> trainingTime)
{
using (var writer = new StreamWriter(fileName))
{
// Write header
for (int fold = 0; fold < accuracy.Count; fold++)
{
if (fold == 0)
{
writer.Write("# ");
}
writer.Write("Fold {0}, ", fold + 1);
}
writer.WriteLine("Mean, Standard deviation");
writer.WriteLine();
// Write metrics
SaveSinglePerformanceMetric(writer, "Accuracy", accuracy);
SaveSinglePerformanceMetric(writer, "Mean negative log probability", negativeLogProbability);
SaveSinglePerformanceMetric(writer, "AUC", auc);
SaveSinglePerformanceMetric(writer, "Log evidence", evidence);
SaveSinglePerformanceMetric(writer, "Training time", trainingTime);
SaveSinglePerformanceMetric(writer, "Iteration count", iterationCount);
}
}
/// <summary>
/// Converts elapsed time in milliseconds into a human readable format.
/// </summary>
/// <param name="elapsedMilliseconds">The elapsed time in milliseconds.</param>
/// <returns>A human readable string of specified time.</returns>
public static string FormatElapsedTime(double elapsedMilliseconds)
{
TimeSpan time = TimeSpan.FromMilliseconds(elapsedMilliseconds);
string formattedTime = time.Hours > 0 ? string.Format("{0}:", time.Hours) : string.Empty;
formattedTime += time.Hours > 0 ? string.Format("{0:D2}:", time.Minutes) : time.Minutes > 0 ? string.Format("{0}:", time.Minutes) : string.Empty;
formattedTime += time.Hours > 0 || time.Minutes > 0 ? string.Format("{0:D2}.{1:D3}", time.Seconds, time.Milliseconds) : string.Format("{0}.{1:D3} seconds", time.Seconds, time.Milliseconds);
return formattedTime;
}
/// <summary>
/// Writes key statistics of the specified data set to console.
/// </summary>
/// <param name="dataSet">The data set.</param>
public static void WriteDataSetInfo(IList<LabeledFeatureValues> dataSet)
{
Console.WriteLine(
"Data set contains {0} instances, {1} classes and {2} features.",
dataSet.Count,
dataSet.Count > 0 ? dataSet.First().LabelDistribution.LabelSet.Count : 0,
dataSet.Count > 0 ? dataSet.First().FeatureSet.Count : 0);
}
#region Helper methods
/// <summary>
/// Writes a single performance metric to the specified writer.
/// </summary>
/// <param name="writer">The writer to write the metrics to.</param>
/// <param name="description">The metric description.</param>
/// <param name="metric">The metric.</param>
private static void SaveSinglePerformanceMetric(TextWriter writer, string description, IEnumerable<double> metric)
{
// Write description
writer.WriteLine("# " + description);
// Write metric
var mva = new MeanVarianceAccumulator();
foreach (double value in metric)
{
writer.Write("{0}, ", value);
mva.Add(value);
}
writer.WriteLine("{0}, {1}", mva.Mean, Math.Sqrt(mva.Variance));
writer.WriteLine();
}
/// <summary>
/// Saves the maximum absolute difference between two given weight distributions to a file with the specified name.
/// </summary>
/// <param name="fileName">The name of the file to save the maximum absolute difference between weight distributions to.</param>
/// <param name="iteration">The inference algorithm iteration.</param>
/// <param name="maxParameterChange">The maximum absolute difference in any parameter of two weight distributions.</param>
/// <param name="elapsedMilliseconds">The elapsed milliseconds.</param>
private static void SaveMaximumParameterDifference(string fileName, int iteration, double maxParameterChange, long elapsedMilliseconds)
{
using (var writer = new StreamWriter(fileName, true))
{
writer.WriteLine("{0}, {1}, {2}, {3}", DateTime.Now.ToLongTimeString(), iteration, maxParameterChange, elapsedMilliseconds);
}
}
/// <summary>
/// Computes the maximum difference in any parameter of two Gaussian distributions.
/// </summary>
/// <param name="first">The first Gaussian.</param>
/// <param name="second">The second Gaussian.</param>
/// <returns>The maximum absolute difference in any parameter.</returns>
/// <remarks>This difference computation is based on mean and variance instead of mean*precision and precision.</remarks>
private static double MaxDiff(IReadOnlyList<IReadOnlyList<Gaussian>> first, Gaussian[][] second)
{
int classCount = first.Count;
int featureCount = first[0].Count;
double maxDiff = double.NegativeInfinity;
for (int c = 0; c < classCount; c++)
{
for (int f = 0; f < featureCount; f++)
{
double firstMean, firstVariance, secondMean, secondVariance;
first[c][f].GetMeanAndVariance(out firstMean, out firstVariance);
second[c][f].GetMeanAndVariance(out secondMean, out secondVariance);
double meanDifference = Math.Abs(firstMean - secondMean);
double varianceDifference = Math.Abs(firstVariance - secondVariance);
if (meanDifference > maxDiff)
{
maxDiff = Math.Abs(meanDifference);
}
if (Math.Abs(varianceDifference) > maxDiff)
{
maxDiff = Math.Abs(varianceDifference);
}
}
}
return maxDiff;
}
/// <summary>
/// Samples weights from the learned weight distribution of the Bayes point machine classifier.
/// </summary>
/// <typeparam name="TTrainingSettings">The type of the settings for training.</typeparam>
/// <param name="classifier">The Bayes point machine used to sample weights from.</param>
/// <returns>The samples from the weight distribution of the Bayes point machine classifier.</returns>
private static IEnumerable<Vector> SampleWeights<TTrainingSettings>(
IBayesPointMachineClassifier<IList<LabeledFeatureValues>, LabeledFeatureValues, IList<LabelDistribution>, string, IDictionary<string, double>, TTrainingSettings, IBayesPointMachineClassifierPredictionSettings<string>> classifier)
where TTrainingSettings : BayesPointMachineClassifierTrainingSettings
{
Debug.Assert(classifier != null, "The classifier must not be null.");
IReadOnlyList<IReadOnlyList<Gaussian>> weightPosteriorDistributions = classifier.WeightPosteriorDistributions;
int classCount = weightPosteriorDistributions.Count < 2 ? 2 : weightPosteriorDistributions.Count;
int featureCount = weightPosteriorDistributions[0].Count;
var samples = new Vector[classCount - 1];
for (int c = 0; c < classCount - 1; c++)
{
var sample = Vector.Zero(featureCount);
for (int f = 0; f < featureCount; f++)
{
sample[f] = weightPosteriorDistributions[c][f].Sample();
}
samples[c] = sample;
}
return samples;
}
#endregion
}
}
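
For reference, the formatting rules in FormatElapsedTime above produce output such as the following (values worked out from the code, shown only to illustrate the removed behaviour):

// FormatElapsedTime(1234)    == "1.234 seconds"
// FormatElapsedTime(65432)   == "1:05.432"
// FormatElapsedTime(3723456) == "1:02:03.456"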

View file

@@ -1,153 +0,0 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
namespace Microsoft.ML.Probabilistic.Learners.Runners
{
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using Microsoft.ML.Probabilistic.Learners.Mappings;
using Microsoft.ML.Probabilistic.Math;
/// <summary>
/// A command-line module to cross-validate a binary Bayes point machine classifier on given data.
/// </summary>
internal class BinaryBayesPointMachineClassifierCrossValidationModule : CommandLineModule
{
/// <summary>
/// Runs the module.
/// </summary>
/// <param name="args">The command line arguments for the module.</param>
/// <param name="usagePrefix">The prefix to print before the usage string.</param>
/// <returns>True if the run was successful, false otherwise.</returns>
public override bool Run(string[] args, string usagePrefix)
{
string dataSetFile = string.Empty;
string resultsFile = string.Empty;
int crossValidationFoldCount = 5;
int iterationCount = BayesPointMachineClassifierTrainingSettings.IterationCountDefault;
int batchCount = BayesPointMachineClassifierTrainingSettings.BatchCountDefault;
bool computeModelEvidence = BayesPointMachineClassifierTrainingSettings.ComputeModelEvidenceDefault;
var parser = new CommandLineParser();
parser.RegisterParameterHandler("--data-set", "FILE", "File with training data", v => dataSetFile = v, CommandLineParameterType.Required);
parser.RegisterParameterHandler("--results", "FILE", "File with cross-validation results", v => resultsFile = v, CommandLineParameterType.Required);
parser.RegisterParameterHandler("--folds", "NUM", "Number of cross-validation folds (defaults to " + crossValidationFoldCount + ")", v => crossValidationFoldCount = v, CommandLineParameterType.Optional);
parser.RegisterParameterHandler("--iterations", "NUM", "Number of training algorithm iterations (defaults to " + iterationCount + ")", v => iterationCount = v, CommandLineParameterType.Optional);
parser.RegisterParameterHandler("--batches", "NUM", "Number of batches to split the training data into (defaults to " + batchCount + ")", v => batchCount = v, CommandLineParameterType.Optional);
parser.RegisterParameterHandler("--compute-evidence", "Compute model evidence (defaults to " + computeModelEvidence + ")", () => computeModelEvidence = true);
if (!parser.TryParse(args, usagePrefix))
{
return false;
}
// Load and shuffle data
var dataSet = ClassifierPersistenceUtils.LoadLabeledFeatureValues(dataSetFile);
BayesPointMachineClassifierModuleUtilities.WriteDataSetInfo(dataSet);
Rand.Restart(562);
Rand.Shuffle(dataSet);
// Create evaluator
var evaluatorMapping = Mappings.Classifier.ForEvaluation();
var evaluator = new ClassifierEvaluator<IList<LabeledFeatureValues>, LabeledFeatureValues, IList<LabelDistribution>, string>(evaluatorMapping);
// Create performance metrics
var accuracy = new List<double>();
var negativeLogProbability = new List<double>();
var auc = new List<double>();
var evidence = new List<double>();
var iterationCounts = new List<double>();
var trainingTime = new List<double>();
// Run cross-validation
int validationSetSize = dataSet.Count / crossValidationFoldCount;
Console.WriteLine("Running {0}-fold cross-validation on {1}", crossValidationFoldCount, dataSetFile);
// TODO: Use chained mapping to implement cross-validation
for (int fold = 0; fold < crossValidationFoldCount; fold++)
{
// Construct training and validation sets for fold
int validationSetStart = fold * validationSetSize;
int validationSetEnd = (fold + 1 == crossValidationFoldCount)
? dataSet.Count
: (fold + 1) * validationSetSize;
var trainingSet = new List<LabeledFeatureValues>();
var validationSet = new List<LabeledFeatureValues>();
for (int instance = 0; instance < dataSet.Count; instance++)
{
if (validationSetStart <= instance && instance < validationSetEnd)
{
validationSet.Add(dataSet[instance]);
}
else
{
trainingSet.Add(dataSet[instance]);
}
}
// Print info
Console.WriteLine(" Fold {0} [validation set instances {1} - {2}]", fold + 1, validationSetStart, validationSetEnd - 1);
// Create classifier
var classifier = BayesPointMachineClassifier.CreateBinaryClassifier(Mappings.Classifier);
classifier.Settings.Training.IterationCount = iterationCount;
classifier.Settings.Training.BatchCount = batchCount;
classifier.Settings.Training.ComputeModelEvidence = computeModelEvidence;
int currentIterationCount = 0;
classifier.IterationChanged += (sender, eventArgs) => { currentIterationCount = eventArgs.CompletedIterationCount; };
// Train classifier
var stopWatch = new Stopwatch();
stopWatch.Start();
classifier.Train(trainingSet);
stopWatch.Stop();
// Produce predictions
var predictions = classifier.PredictDistribution(validationSet).ToList();
var predictedLabels = predictions.Select(
prediction => prediction.Aggregate((aggregate, next) => next.Value > aggregate.Value ? next : aggregate).Key).ToList();
// Iteration count
iterationCounts.Add(currentIterationCount);
// Training time
trainingTime.Add(stopWatch.ElapsedMilliseconds);
// Compute accuracy
accuracy.Add(1 - (evaluator.Evaluate(validationSet, predictedLabels, Metrics.ZeroOneError) / predictions.Count));
// Compute mean negative log probability
negativeLogProbability.Add(evaluator.Evaluate(validationSet, predictions, Metrics.NegativeLogProbability) / predictions.Count);
// Compute M-measure (averaged pairwise AUC)
auc.Add(evaluator.AreaUnderRocCurve(validationSet, predictions));
// Compute log evidence if desired
evidence.Add(computeModelEvidence ? classifier.LogModelEvidence : double.NaN);
// Persist performance metrics
Console.WriteLine(
" Accuracy = {0,5:0.0000} NegLogProb = {1,5:0.0000} AUC = {2,5:0.0000}{3} Iterations = {4} Training time = {5}",
accuracy[fold],
negativeLogProbability[fold],
auc[fold],
computeModelEvidence ? string.Format(" Log evidence = {0,5:0.0000}", evidence[fold]) : string.Empty,
iterationCounts[fold],
BayesPointMachineClassifierModuleUtilities.FormatElapsedTime(trainingTime[fold]));
BayesPointMachineClassifierModuleUtilities.SavePerformanceMetrics(
resultsFile, accuracy, negativeLogProbability, auc, evidence, iterationCounts, trainingTime);
}
return true;
}
}
}
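
As a quick check of the fold arithmetic above: with, say, 103 instances and the default 5 folds, validationSetSize is 103 / 5 = 20, so the validation ranges are [0, 20), [20, 40), [40, 60), [60, 80) and, because the last fold runs to dataSet.Count, [80, 103). (Illustrative numbers only.)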

View file

@@ -1,54 +0,0 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
namespace Microsoft.ML.Probabilistic.Learners.Runners
{
using System.Collections.Generic;
/// <summary>
/// A command-line module to incrementally train a binary Bayes point machine classifier on some given data.
/// </summary>
internal class BinaryBayesPointMachineClassifierIncrementalTrainingModule : CommandLineModule
{
/// <summary>
/// Runs the module.
/// </summary>
/// <param name="args">The command line arguments for the module.</param>
/// <param name="usagePrefix">The prefix to print before the usage string.</param>
/// <returns>True if the run was successful, false otherwise.</returns>
public override bool Run(string[] args, string usagePrefix)
{
string trainingSetFile = string.Empty;
string inputModelFile = string.Empty;
string outputModelFile = string.Empty;
int iterationCount = BayesPointMachineClassifierTrainingSettings.IterationCountDefault;
int batchCount = BayesPointMachineClassifierTrainingSettings.BatchCountDefault;
var parser = new CommandLineParser();
parser.RegisterParameterHandler("--training-set", "FILE", "File with training data", v => trainingSetFile = v, CommandLineParameterType.Required);
parser.RegisterParameterHandler("--input-model", "FILE", "File with the trained binary Bayes point machine model", v => inputModelFile = v, CommandLineParameterType.Required);
parser.RegisterParameterHandler("--model", "FILE", "File to store the incrementally trained binary Bayes point machine model", v => outputModelFile = v, CommandLineParameterType.Required);
parser.RegisterParameterHandler("--iterations", "NUM", "Number of training algorithm iterations (defaults to " + iterationCount + ")", v => iterationCount = v, CommandLineParameterType.Optional);
parser.RegisterParameterHandler("--batches", "NUM", "Number of batches to split the training data into (defaults to " + batchCount + ")", v => batchCount = v, CommandLineParameterType.Optional);
if (!parser.TryParse(args, usagePrefix))
{
return false;
}
var trainingSet = ClassifierPersistenceUtils.LoadLabeledFeatureValues(trainingSetFile);
BayesPointMachineClassifierModuleUtilities.WriteDataSetInfo(trainingSet);
var classifier = BayesPointMachineClassifier.LoadBinaryClassifier<IList<LabeledFeatureValues>, LabeledFeatureValues, IList<LabelDistribution>, string, IDictionary<string, double>>(inputModelFile);
classifier.Settings.Training.IterationCount = iterationCount;
classifier.Settings.Training.BatchCount = batchCount;
classifier.TrainIncremental(trainingSet);
classifier.Save(outputModelFile);
return true;
}
}
}

View file

@@ -1,50 +0,0 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
namespace Microsoft.ML.Probabilistic.Learners.Runners
{
using System.Collections.Generic;
/// <summary>
/// A command-line module to predict labels given a trained binary Bayes point machine classifier model and a test set.
/// </summary>
internal class BinaryBayesPointMachineClassifierPredictionModule : CommandLineModule
{
/// <summary>
/// Runs the module.
/// </summary>
/// <param name="args">The command line arguments for the module.</param>
/// <param name="usagePrefix">The prefix to print before the usage string.</param>
/// <returns>True if the run was successful, false otherwise.</returns>
public override bool Run(string[] args, string usagePrefix)
{
string testSetFile = string.Empty;
string modelFile = string.Empty;
string predictionsFile = string.Empty;
var parser = new CommandLineParser();
parser.RegisterParameterHandler("--test-set", "FILE", "File with test data", v => testSetFile = v, CommandLineParameterType.Required);
parser.RegisterParameterHandler("--model", "FILE", "File with a trained binary Bayes point machine model", v => modelFile = v, CommandLineParameterType.Required);
parser.RegisterParameterHandler("--predictions", "FILE", "File to store predictions for the test data", v => predictionsFile = v, CommandLineParameterType.Required);
if (!parser.TryParse(args, usagePrefix))
{
return false;
}
var testSet = ClassifierPersistenceUtils.LoadLabeledFeatureValues(testSetFile);
BayesPointMachineClassifierModuleUtilities.WriteDataSetInfo(testSet);
var classifier =
BayesPointMachineClassifier.LoadBinaryClassifier<IList<LabeledFeatureValues>, LabeledFeatureValues, IList<LabelDistribution>, string, IDictionary<string, double>>(modelFile);
// Predict labels
var predictions = classifier.PredictDistribution(testSet);
// Write labels to file
ClassifierPersistenceUtils.SaveLabelDistributions(predictionsFile, predictions);
return true;
}
}
}

View file

@@ -1,41 +0,0 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
namespace Microsoft.ML.Probabilistic.Learners.Runners
{
using System.Collections.Generic;
/// <summary>
/// A command-line module to sample weights from a trained binary Bayes point machine classifier model.
/// </summary>
internal class BinaryBayesPointMachineClassifierSampleWeightsModule : CommandLineModule
{
/// <summary>
/// Runs the module.
/// </summary>
/// <param name="args">The command line arguments for the module.</param>
/// <param name="usagePrefix">The prefix to print before the usage string.</param>
/// <returns>True if the run was successful, false otherwise.</returns>
public override bool Run(string[] args, string usagePrefix)
{
string modelFile = string.Empty;
string samplesFile = string.Empty;
var parser = new CommandLineParser();
parser.RegisterParameterHandler("--model", "FILE", "File with a trained binary Bayes point machine model", v => modelFile = v, CommandLineParameterType.Required);
parser.RegisterParameterHandler("--samples", "FILE", "File to store samples of the weights", v => samplesFile = v, CommandLineParameterType.Required);
if (!parser.TryParse(args, usagePrefix))
{
return false;
}
var classifier =
BayesPointMachineClassifier.LoadBinaryClassifier<IList<LabeledFeatureValues>, LabeledFeatureValues, IList<LabelDistribution>, string, IDictionary<string, double>>(modelFile);
BayesPointMachineClassifierModuleUtilities.SampleWeights(classifier, samplesFile);
return true;
}
}
}

View file

@@ -1,50 +0,0 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
namespace Microsoft.ML.Probabilistic.Learners.Runners
{
/// <summary>
/// A command-line module to diagnose training of a binary Bayes point machine classifier on given data.
/// </summary>
internal class BinaryBayesPointMachineClassifierTrainingDiagnosticsModule : CommandLineModule
{
/// <summary>
/// Runs the module.
/// </summary>
/// <param name="args">The command line arguments for the module.</param>
/// <param name="usagePrefix">The prefix to print before the usage string.</param>
/// <returns>True if the run was successful, false otherwise.</returns>
public override bool Run(string[] args, string usagePrefix)
{
string trainingSetFile = string.Empty;
string maxParameterChangesFile = string.Empty;
string modelFile = string.Empty;
int iterationCount = BayesPointMachineClassifierTrainingSettings.IterationCountDefault;
int batchCount = BayesPointMachineClassifierTrainingSettings.BatchCountDefault;
var parser = new CommandLineParser();
parser.RegisterParameterHandler("--training-set", "FILE", "File with training data", v => trainingSetFile = v, CommandLineParameterType.Required);
parser.RegisterParameterHandler("--results", "FILE", "File to store the maximum parameter differences", v => maxParameterChangesFile = v, CommandLineParameterType.Optional);
parser.RegisterParameterHandler("--model", "FILE", "File to store the trained binary Bayes point machine model", v => modelFile = v, CommandLineParameterType.Optional);
parser.RegisterParameterHandler("--iterations", "NUM", "Number of training algorithm iterations (defaults to " + iterationCount + ")", v => iterationCount = v, CommandLineParameterType.Optional);
parser.RegisterParameterHandler("--batches", "NUM", "Number of batches to split the training data into (defaults to " + batchCount + ")", v => batchCount = v, CommandLineParameterType.Optional);
if (!parser.TryParse(args, usagePrefix))
{
return false;
}
var trainingSet = ClassifierPersistenceUtils.LoadLabeledFeatureValues(trainingSetFile);
BayesPointMachineClassifierModuleUtilities.WriteDataSetInfo(trainingSet);
var classifier = BayesPointMachineClassifier.CreateBinaryClassifier(Mappings.Classifier);
classifier.Settings.Training.IterationCount = iterationCount;
classifier.Settings.Training.BatchCount = batchCount;
BayesPointMachineClassifierModuleUtilities.DiagnoseClassifier(classifier, trainingSet, maxParameterChangesFile, modelFile);
return true;
}
}
}

View file

@@ -1,63 +0,0 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
namespace Microsoft.ML.Probabilistic.Learners.Runners
{
using System;
using System.Linq;
/// <summary>
/// A command-line module to train a binary Bayes point machine classifier on some given data.
/// </summary>
internal class BinaryBayesPointMachineClassifierTrainingModule : CommandLineModule
{
/// <summary>
/// Runs the module.
/// </summary>
/// <param name="args">The command line arguments for the module.</param>
/// <param name="usagePrefix">The prefix to print before the usage string.</param>
/// <returns>True if the run was successful, false otherwise.</returns>
public override bool Run(string[] args, string usagePrefix)
{
string trainingSetFile = string.Empty;
string modelFile = string.Empty;
int iterationCount = BayesPointMachineClassifierTrainingSettings.IterationCountDefault;
int batchCount = BayesPointMachineClassifierTrainingSettings.BatchCountDefault;
bool computeModelEvidence = BayesPointMachineClassifierTrainingSettings.ComputeModelEvidenceDefault;
var parser = new CommandLineParser();
parser.RegisterParameterHandler("--training-set", "FILE", "File with training data", v => trainingSetFile = v, CommandLineParameterType.Required);
parser.RegisterParameterHandler("--model", "FILE", "File to store the trained binary Bayes point machine model", v => modelFile = v, CommandLineParameterType.Required);
parser.RegisterParameterHandler("--iterations", "NUM", "Number of training algorithm iterations (defaults to " + iterationCount + ")", v => iterationCount = v, CommandLineParameterType.Optional);
parser.RegisterParameterHandler("--batches", "NUM", "Number of batches to split the training data into (defaults to " + batchCount + ")", v => batchCount = v, CommandLineParameterType.Optional);
parser.RegisterParameterHandler("--compute-evidence", "Compute model evidence (defaults to " + computeModelEvidence + ")", () => computeModelEvidence = true);
if (!parser.TryParse(args, usagePrefix))
{
return false;
}
var trainingSet = ClassifierPersistenceUtils.LoadLabeledFeatureValues(trainingSetFile);
BayesPointMachineClassifierModuleUtilities.WriteDataSetInfo(trainingSet);
var featureSet = trainingSet.Count > 0 ? trainingSet.First().FeatureSet : null;
var mapping = new ClassifierMapping(featureSet);
var classifier = BayesPointMachineClassifier.CreateBinaryClassifier(mapping);
classifier.Settings.Training.IterationCount = iterationCount;
classifier.Settings.Training.BatchCount = batchCount;
classifier.Settings.Training.ComputeModelEvidence = computeModelEvidence;
classifier.Train(trainingSet);
if (classifier.Settings.Training.ComputeModelEvidence)
{
Console.WriteLine("Log evidence = {0,10:0.0000}", classifier.LogModelEvidence);
}
classifier.Save(modelFile);
return true;
}
}
}

View file

@@ -1,674 +0,0 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
namespace Microsoft.ML.Probabilistic.Learners.Runners
{
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Globalization;
using System.IO;
using System.Linq;
using Microsoft.ML.Probabilistic.Learners.Mappings;
/// <summary>
/// A command-line module to evaluate the label predictions of classifiers.
/// </summary>
internal class ClassifierEvaluationModule : CommandLineModule
{
/// <summary>
/// Runs the module.
/// </summary>
/// <param name="args">The command line arguments for the module.</param>
/// <param name="usagePrefix">The prefix to print before the usage string.</param>
/// <returns>True if the run was successful, false otherwise.</returns>
public override bool Run(string[] args, string usagePrefix)
{
string groundTruthFileName = string.Empty;
string predictionsFileName = string.Empty;
string reportFileName = string.Empty;
string calibrationCurveFileName = string.Empty;
string rocCurveFileName = string.Empty;
string precisionRecallCurveFileName = string.Empty;
string positiveClassLabel = string.Empty;
var parser = new CommandLineParser();
parser.RegisterParameterHandler("--ground-truth", "FILE", "File with ground truth labels", v => groundTruthFileName = v, CommandLineParameterType.Required);
parser.RegisterParameterHandler("--predictions", "FILE", "File with label predictions", v => predictionsFileName = v, CommandLineParameterType.Required);
parser.RegisterParameterHandler("--report", "FILE", "File to store the evaluation report", v => reportFileName = v, CommandLineParameterType.Optional);
parser.RegisterParameterHandler("--calibration-curve", "FILE", "File to store the empirical calibration curve", v => calibrationCurveFileName = v, CommandLineParameterType.Optional);
parser.RegisterParameterHandler("--roc-curve", "FILE", "File to store the receiver operating characteristic curve", v => rocCurveFileName = v, CommandLineParameterType.Optional);
parser.RegisterParameterHandler("--precision-recall-curve", "FILE", "File to store the precision-recall curve", v => precisionRecallCurveFileName = v, CommandLineParameterType.Optional);
parser.RegisterParameterHandler("--positive-class", "STRING", "Label of the positive class to use in curves", v => positiveClassLabel = v, CommandLineParameterType.Optional);
if (!parser.TryParse(args, usagePrefix))
{
return false;
}
// Read ground truth
var groundTruth = ClassifierPersistenceUtils.LoadLabeledFeatureValues(groundTruthFileName);
// Read predictions using ground truth label dictionary
var predictions = ClassifierPersistenceUtils.LoadLabelDistributions(predictionsFileName, groundTruth.First().LabelDistribution.LabelSet);
// Check that there are at least two distinct class labels
if (predictions.First().LabelSet.Count < 2)
{
throw new InvalidFileFormatException("Ground truth and predictions must contain at least two distinct class labels.");
}
// Distill distributions and point estimates
var predictiveDistributions = predictions.Select(i => i.ToDictionary()).ToList();
var predictivePointEstimates = predictions.Select(i => i.GetMode()).ToList();
// Create evaluator
var evaluatorMapping = Mappings.Classifier.ForEvaluation();
var evaluator = new ClassifierEvaluator<IList<LabeledFeatureValues>, LabeledFeatureValues, IList<LabelDistribution>, string>(evaluatorMapping);
// Write evaluation report
if (!string.IsNullOrEmpty(reportFileName))
{
using (var writer = new StreamWriter(reportFileName))
{
this.WriteReportHeader(writer, groundTruthFileName, predictionsFileName);
this.WriteReport(writer, evaluator, groundTruth, predictiveDistributions, predictivePointEstimates);
}
}
// Compute and write the empirical probability calibration curve
positiveClassLabel = this.CheckPositiveClassLabel(groundTruth, positiveClassLabel);
if (!string.IsNullOrEmpty(calibrationCurveFileName))
{
this.WriteCalibrationCurve(calibrationCurveFileName, evaluator, groundTruth, predictiveDistributions, positiveClassLabel);
}
// Compute and write the precision-recall curve
if (!string.IsNullOrEmpty(precisionRecallCurveFileName))
{
this.WritePrecisionRecallCurve(precisionRecallCurveFileName, evaluator, groundTruth, predictiveDistributions, positiveClassLabel);
}
// Compute and write the receiver operating characteristic curve
if (!string.IsNullOrEmpty(rocCurveFileName))
{
this.WriteRocCurve(rocCurveFileName, evaluator, groundTruth, predictiveDistributions, positiveClassLabel);
}
return true;
}
#region Helper methods
/// <summary>
/// Writes the evaluation results to a file with the specified name.
/// </summary>
/// <param name="writer">The name of the file to write the report to.</param>
/// <param name="evaluator">The classifier evaluator.</param>
/// <param name="groundTruth">The ground truth.</param>
/// <param name="predictiveDistributions">The predictive distributions.</param>
/// <param name="predictedLabels">The predicted labels.</param>
private void WriteReport(
StreamWriter writer,
ClassifierEvaluator<IList<LabeledFeatureValues>, LabeledFeatureValues, IList<LabelDistribution>, string> evaluator,
IList<LabeledFeatureValues> groundTruth,
ICollection<IDictionary<string, double>> predictiveDistributions,
IEnumerable<string> predictedLabels)
{
// Compute confusion matrix
var confusionMatrix = evaluator.ConfusionMatrix(groundTruth, predictedLabels);
// Compute mean negative log probability
double meanNegativeLogProbability =
evaluator.Evaluate(groundTruth, predictiveDistributions, Metrics.NegativeLogProbability) / predictiveDistributions.Count;
// Compute M-measure (averaged pairwise AUC)
IDictionary<string, IDictionary<string, double>> aucMatrix;
double auc = evaluator.AreaUnderRocCurve(groundTruth, predictiveDistributions, out aucMatrix);
// Compute per-label AUC as well as micro- and macro-averaged AUC
double microAuc;
double macroAuc;
int macroAucClassLabelCount;
var labelAuc = this.ComputeLabelAuc(
confusionMatrix,
evaluator,
groundTruth,
predictiveDistributions,
out microAuc,
out macroAuc,
out macroAucClassLabelCount);
// Instance-averaged performance
this.WriteInstanceAveragedPerformance(writer, confusionMatrix, meanNegativeLogProbability, microAuc);
// Class-averaged performance
this.WriteClassAveragedPerformance(writer, confusionMatrix, auc, macroAuc, macroAucClassLabelCount);
// Performance on individual classes
this.WriteIndividualClassPerformance(writer, confusionMatrix, labelAuc);
// Confusion matrix
this.WriteConfusionMatrix(writer, confusionMatrix);
// Pairwise AUC
this.WriteAucMatrix(writer, aucMatrix);
}
/// <summary>
/// Computes all per-label AUCs as well as the micro- and macro-averaged AUCs.
/// </summary>
/// <param name="confusionMatrix">The confusion matrix.</param>
/// <param name="evaluator">The classifier evaluator.</param>
/// <param name="groundTruth">The ground truth.</param>
/// <param name="predictiveDistributions">The predictive distributions.</param>
/// <param name="microAuc">The micro-averaged area under the receiver operating characteristic curve.</param>
/// <param name="macroAuc">The macro-averaged area under the receiver operating characteristic curve.</param>
/// <param name="macroAucClassLabelCount">The number of class labels for which the AUC if defined.</param>
/// <returns>The area under the receiver operating characteristic curve for each class label.</returns>
private IDictionary<string, double> ComputeLabelAuc(
ConfusionMatrix<string> confusionMatrix,
ClassifierEvaluator<IList<LabeledFeatureValues>, LabeledFeatureValues, IList<LabelDistribution>, string> evaluator,
IList<LabeledFeatureValues> groundTruth,
ICollection<IDictionary<string, double>> predictiveDistributions,
out double microAuc,
out double macroAuc,
out int macroAucClassLabelCount)
{
int instanceCount = predictiveDistributions.Count;
var classLabels = confusionMatrix.ClassLabelSet.Elements.ToArray();
int classLabelCount = classLabels.Length;
var labelAuc = new Dictionary<string, double>();
// Compute per-label AUC
macroAucClassLabelCount = classLabelCount;
foreach (var classLabel in classLabels)
{
// One versus rest
double auc;
try
{
auc = evaluator.AreaUnderRocCurve(classLabel, groundTruth, predictiveDistributions);
}
catch (ArgumentException)
{
auc = double.NaN;
macroAucClassLabelCount--;
}
labelAuc.Add(classLabel, auc);
}
// Compute micro- and macro-averaged AUC
microAuc = 0;
macroAuc = 0;
foreach (var label in classLabels)
{
if (double.IsNaN(labelAuc[label]))
{
continue;
}
microAuc += confusionMatrix.TrueLabelCount(label) * labelAuc[label] / instanceCount;
macroAuc += labelAuc[label] / macroAucClassLabelCount;
}
return labelAuc;
}
/// <summary>
/// Writes the header of the evaluation report to a specified stream writer.
/// </summary>
/// <param name="writer">The <see cref="StreamWriter"/> to write to.</param>
/// <param name="groundTruthFileName">The name of the file containing the ground truth.</param>
/// <param name="predictionsFileName">The name of the file containing the predictions.</param>
private void WriteReportHeader(StreamWriter writer, string groundTruthFileName, string predictionsFileName)
{
writer.WriteLine();
writer.WriteLine(" Classifier evaluation report ");
writer.WriteLine("******************************");
writer.WriteLine();
writer.WriteLine(" Date: {0}", DateTime.Now);
writer.WriteLine(" Ground truth: {0}", groundTruthFileName);
writer.WriteLine(" Predictions: {0}", predictionsFileName);
}
/// <summary>
/// Writes instance-averaged performance results to a specified stream writer.
/// </summary>
/// <param name="writer">The <see cref="StreamWriter"/> to write to.</param>
/// <param name="confusionMatrix">The confusion matrix.</param>
/// <param name="negativeLogProbability">The negative log-probability.</param>
/// <param name="microAuc">The micro-averaged AUC.</param>
private void WriteInstanceAveragedPerformance(
StreamWriter writer,
ConfusionMatrix<string> confusionMatrix,
double negativeLogProbability,
double microAuc)
{
long instanceCount = 0;
long correctInstanceCount = 0;
foreach (var classLabelIndex in confusionMatrix.ClassLabelSet.Indexes)
{
string classLabel = confusionMatrix.ClassLabelSet.GetElementByIndex(classLabelIndex);
instanceCount += confusionMatrix.TrueLabelCount(classLabel);
correctInstanceCount += confusionMatrix[classLabel, classLabel];
}
writer.WriteLine();
writer.WriteLine(" Instance-averaged performance (micro-averages)");
writer.WriteLine("================================================");
writer.WriteLine();
writer.WriteLine(" Precision = {0,10:0.0000}", confusionMatrix.MicroPrecision);
writer.WriteLine(" Recall = {0,10:0.0000}", confusionMatrix.MicroRecall);
writer.WriteLine(" F1 = {0,10:0.0000}", confusionMatrix.MicroF1);
writer.WriteLine();
writer.WriteLine(" #Correct = {0,10}", correctInstanceCount);
writer.WriteLine(" #Total = {0,10}", instanceCount);
writer.WriteLine(" Accuracy = {0,10:0.0000}", confusionMatrix.MicroAccuracy);
writer.WriteLine(" Error = {0,10:0.0000}", 1 - confusionMatrix.MicroAccuracy);
writer.WriteLine();
writer.WriteLine(" AUC = {0,10:0.0000}", microAuc);
writer.WriteLine();
writer.WriteLine(" Log-loss = {0,10:0.0000}", negativeLogProbability);
}
/// <summary>
/// Writes class-averaged performance results to a specified stream writer.
/// </summary>
/// <param name="writer">The <see cref="StreamWriter"/> to write to.</param>
/// <param name="confusionMatrix">The confusion matrix.</param>
/// <param name="auc">The AUC.</param>
/// <param name="macroAuc">The macro-averaged AUC.</param>
/// <param name="macroAucClassLabelCount">The number of distinct class labels used to compute macro-averaged AUC.</param>
private void WriteClassAveragedPerformance(
StreamWriter writer,
ConfusionMatrix<string> confusionMatrix,
double auc,
double macroAuc,
int macroAucClassLabelCount)
{
int classLabelCount = confusionMatrix.ClassLabelSet.Count;
writer.WriteLine();
writer.WriteLine(" Class-averaged performance (macro-averages)");
writer.WriteLine("=============================================");
writer.WriteLine();
if (confusionMatrix.MacroPrecisionClassLabelCount < classLabelCount)
{
writer.WriteLine(
" Precision = {0,10:0.0000} {1,10}",
confusionMatrix.MacroPrecision,
"[only " + confusionMatrix.MacroPrecisionClassLabelCount + "/" + classLabelCount + " classes defined]");
}
else
{
writer.WriteLine(" Precision = {0,10:0.0000}", confusionMatrix.MacroPrecision);
}
if (confusionMatrix.MacroRecallClassLabelCount < classLabelCount)
{
writer.WriteLine(
" Recall = {0,10:0.0000} {1,10}",
confusionMatrix.MacroRecall,
"[only " + confusionMatrix.MacroRecallClassLabelCount + "/" + classLabelCount + " classes defined]");
}
else
{
writer.WriteLine(" Recall = {0,10:0.0000}", confusionMatrix.MacroRecall);
}
if (confusionMatrix.MacroF1ClassLabelCount < classLabelCount)
{
writer.WriteLine(
" F1 = {0,10:0.0000} {1,10}",
confusionMatrix.MacroF1,
"[only " + confusionMatrix.MacroF1ClassLabelCount + "/" + classLabelCount + " classes defined]");
}
else
{
writer.WriteLine(" F1 = {0,10:0.0000}", confusionMatrix.MacroF1);
}
writer.WriteLine();
if (confusionMatrix.MacroF1ClassLabelCount < classLabelCount)
{
writer.WriteLine(
" Accuracy = {0,10:0.0000} {1,10}",
confusionMatrix.MacroAccuracy,
"[only " + confusionMatrix.MacroAccuracyClassLabelCount + "/" + classLabelCount + " classes defined]");
writer.WriteLine(
" Error = {0,10:0.0000} {1,10}",
1 - confusionMatrix.MacroAccuracy,
"[only " + confusionMatrix.MacroAccuracyClassLabelCount + "/" + classLabelCount + " classes defined]");
}
else
{
writer.WriteLine(" Accuracy = {0,10:0.0000}", confusionMatrix.MacroAccuracy);
writer.WriteLine(" Error = {0,10:0.0000}", 1 - confusionMatrix.MacroAccuracy);
}
writer.WriteLine();
if (macroAucClassLabelCount < classLabelCount)
{
writer.WriteLine(
" AUC = {0,10:0.0000} {1,10}",
macroAuc,
"[only " + macroAucClassLabelCount + "/" + classLabelCount + " classes defined]");
}
else
{
writer.WriteLine(" AUC = {0,10:0.0000}", macroAuc);
}
writer.WriteLine();
writer.WriteLine(" M (pairwise AUC) = {0,10:0.0000}", auc);
}
/// <summary>
/// Writes performance results for individual classes to a specified stream writer.
/// </summary>
/// <param name="writer">The <see cref="StreamWriter"/> to write to.</param>
/// <param name="confusionMatrix">The confusion matrix.</param>
/// <param name="auc">The per-class AUC.</param>
private void WriteIndividualClassPerformance(
StreamWriter writer,
ConfusionMatrix<string> confusionMatrix,
IDictionary<string, double> auc)
{
writer.WriteLine();
writer.WriteLine(" Performance on individual classes");
writer.WriteLine("===================================");
writer.WriteLine();
writer.WriteLine(
" {0,5} {1,15} {2,10} {3,11} {4,9} {5,10} {6,10} {7,10} {8,10}",
"Index",
"Label",
"#Truth",
"#Predicted",
"#Correct",
"Precision",
"Recall",
"F1",
"AUC");
writer.WriteLine("----------------------------------------------------------------------------------------------------");
foreach (var classLabelIndex in confusionMatrix.ClassLabelSet.Indexes)
{
string classLabel = confusionMatrix.ClassLabelSet.GetElementByIndex(classLabelIndex);
writer.WriteLine(
" {0,5} {1,15} {2,10} {3,11} {4,9} {5,10:0.0000} {6,10:0.0000} {7,10:0.0000} {8,10:0.0000}",
classLabelIndex + 1,
classLabel,
confusionMatrix.TrueLabelCount(classLabel),
confusionMatrix.PredictedLabelCount(classLabel),
confusionMatrix[classLabel, classLabel],
confusionMatrix.Precision(classLabel),
confusionMatrix.Recall(classLabel),
confusionMatrix.F1(classLabel),
auc[classLabel]);
}
}
/// <summary>
/// Writes the confusion matrix to a specified stream writer.
/// </summary>
/// <param name="writer">The <see cref="StreamWriter"/> to write to.</param>
/// <param name="confusionMatrix">The confusion matrix.</param>
private void WriteConfusionMatrix(StreamWriter writer, ConfusionMatrix<string> confusionMatrix)
{
writer.WriteLine();
writer.WriteLine(" Confusion matrix");
writer.WriteLine("==================");
writer.WriteLine();
writer.WriteLine(confusionMatrix);
}
/// <summary>
/// Writes the matrix of pairwise AUC metrics to a specified stream writer.
/// </summary>
/// <param name="writer">The <see cref="StreamWriter"/> to write to.</param>
/// <param name="aucMatrix">The matrix containing the pairwise AUC metrics.</param>
private void WriteAucMatrix(StreamWriter writer, IDictionary<string, IDictionary<string, double>> aucMatrix)
{
writer.WriteLine();
writer.WriteLine(" Pairwise AUC matrix");
writer.WriteLine("=====================");
writer.WriteLine();
const int MaxLabelWidth = 20;
const int MaxValueWidth = 6;
// Widths of the columns
string[] labels = aucMatrix.Keys.ToArray();
int classLabelCount = aucMatrix.Count;
var columnWidths = new int[classLabelCount + 1];
// For each column of the confusion matrix...
for (int c = 0; c < classLabelCount; c++)
{
// ...find the longest string among counts and label
int labelWidth = Math.Min(labels[c].Length, MaxLabelWidth);
columnWidths[c + 1] = Math.Max(MaxValueWidth, labelWidth);
if (labelWidth > columnWidths[0])
{
columnWidths[0] = labelWidth;
}
}
// Print title row
string format = string.Format("{{0,{0}}} \\ Prediction ->", columnWidths[0]);
writer.WriteLine(format, "Truth");
// Print column labels
this.WriteLabel(writer, string.Empty, columnWidths[0]);
for (int c = 0; c < classLabelCount; c++)
{
this.WriteLabel(writer, labels[c], columnWidths[c + 1]);
}
writer.WriteLine();
// For each row (true labels) in confusion matrix...
for (int r = 0; r < classLabelCount; r++)
{
// Print row label
this.WriteLabel(writer, labels[r], columnWidths[0]);
// For each column (predicted labels) in the confusion matrix...
for (int c = 0; c < classLabelCount; c++)
{
// Print count
this.WriteAucValue(writer, labels[r].Equals(labels[c]) ? -1 : aucMatrix[labels[r]][labels[c]], columnWidths[c + 1]);
}
writer.WriteLine();
}
}
/// <summary>
/// Writes the probability calibration plot to the file with the specified name.
/// </summary>
/// <param name="fileName">The name of the file to write the calibration plot to.</param>
/// <param name="evaluator">The classifier evaluator.</param>
/// <param name="groundTruth">The ground truth.</param>
/// <param name="predictiveDistributions">The predictive distributions.</param>
/// <param name="positiveClassLabel">The label of the positive class.</param>
private void WriteCalibrationCurve(
string fileName,
ClassifierEvaluator<IList<LabeledFeatureValues>, LabeledFeatureValues, IList<LabelDistribution>, string> evaluator,
IList<LabeledFeatureValues> groundTruth,
IList<IDictionary<string, double>> predictiveDistributions,
string positiveClassLabel)
{
Debug.Assert(predictiveDistributions != null, "The predictive distributions must not be null.");
Debug.Assert(predictiveDistributions.Count > 0, "The predictive distributions must not be empty.");
Debug.Assert(positiveClassLabel != null, "The label of the positive class must not be null.");
var calibrationCurve = evaluator.CalibrationCurve(positiveClassLabel, groundTruth, predictiveDistributions);
double calibrationError = calibrationCurve.Select(i => Metrics.AbsoluteError(i.EmpiricalProbability, i.PredictedProbability)).Average();
using (var writer = new StreamWriter(fileName))
{
writer.WriteLine("# Empirical probability calibration plot");
writer.WriteLine("#");
writer.WriteLine("# Class '" + positiveClassLabel + "' (versus the rest)");
writer.WriteLine("# Calibration error = {0} (mean absolute error)", calibrationError);
writer.WriteLine("#");
writer.WriteLine("# Predicted probability, empirical probability");
foreach (var point in calibrationCurve)
{
writer.WriteLine(point);
}
}
}
/// <summary>
/// Writes the precision-recall curve to the file with the specified name.
/// </summary>
/// <param name="fileName">The name of the file to write the precision-recall curve to.</param>
/// <param name="evaluator">The classifier evaluator.</param>
/// <param name="groundTruth">The ground truth.</param>
/// <param name="predictiveDistributions">The predictive distributions.</param>
/// <param name="positiveClassLabel">The label of the positive class.</param>
private void WritePrecisionRecallCurve(
string fileName,
ClassifierEvaluator<IList<LabeledFeatureValues>, LabeledFeatureValues, IList<LabelDistribution>, string> evaluator,
IList<LabeledFeatureValues> groundTruth,
IList<IDictionary<string, double>> predictiveDistributions,
string positiveClassLabel)
{
Debug.Assert(predictiveDistributions != null, "The predictive distributions must not be null.");
Debug.Assert(predictiveDistributions.Count > 0, "The predictive distributions must not be empty.");
Debug.Assert(positiveClassLabel != null, "The label of the positive class must not be null.");
var precisionRecallCurve = evaluator.PrecisionRecallCurve(positiveClassLabel, groundTruth, predictiveDistributions);
using (var writer = new StreamWriter(fileName))
{
writer.WriteLine("# Precision-recall curve");
writer.WriteLine("#");
writer.WriteLine("# Class '" + positiveClassLabel + "' (versus the rest)");
writer.WriteLine("#");
writer.WriteLine("# precision (P), Recall (R)");
foreach (var point in precisionRecallCurve)
{
writer.WriteLine(point);
}
}
}
/// <summary>
/// Writes the receiver operating characteristic curve to the file with the specified name.
/// </summary>
/// <param name="fileName">The name of the file to write the receiver operating characteristic curve to.</param>
/// <param name="evaluator">The classifier evaluator.</param>
/// <param name="groundTruth">The ground truth.</param>
/// <param name="predictiveDistributions">The predictive distributions.</param>
/// <param name="positiveClassLabel">The label of the positive class.</param>
private void WriteRocCurve(
string fileName,
ClassifierEvaluator<IList<LabeledFeatureValues>, LabeledFeatureValues, IList<LabelDistribution>, string> evaluator,
IList<LabeledFeatureValues> groundTruth,
IList<IDictionary<string, double>> predictiveDistributions,
string positiveClassLabel)
{
Debug.Assert(predictiveDistributions != null, "The predictive distributions must not be null.");
Debug.Assert(predictiveDistributions.Count > 0, "The predictive distributions must not be empty.");
Debug.Assert(positiveClassLabel != null, "The label of the positive class must not be null.");
var rocCurve = evaluator.ReceiverOperatingCharacteristicCurve(positiveClassLabel, groundTruth, predictiveDistributions);
using (var writer = new StreamWriter(fileName))
{
writer.WriteLine("# Receiver operating characteristic (ROC) curve");
writer.WriteLine("#");
writer.WriteLine("# Class '" + positiveClassLabel + "' (versus the rest)");
writer.WriteLine("#");
writer.WriteLine("# False positive rate (FPR), True positive rate (TPR)");
foreach (var point in rocCurve)
{
writer.WriteLine(point);
}
}
}
/// <summary>
/// Writes an AUC value to a specified stream writer.
/// </summary>
/// <param name="writer">The <see cref="StreamWriter"/> to write to.</param>
/// <param name="auc">The AUC value.</param>
/// <param name="width">The width in characters used to print the value.</param>
private void WriteAucValue(StreamWriter writer, double auc, int width)
{
string paddedCount;
if (auc > 0)
{
paddedCount = auc.ToString(CultureInfo.InvariantCulture);
}
else
{
if (auc < 0)
{
paddedCount = ".";
}
else
{
paddedCount = double.IsNaN(auc) ? "NaN" : "0";
}
}
paddedCount = paddedCount.Length > width ? paddedCount.Substring(0, width) : paddedCount;
paddedCount = paddedCount.PadLeft(width + 2);
writer.Write(paddedCount);
}
/// <summary>
/// Writes a label to a specified stream writer.
/// </summary>
/// <param name="writer">The <see cref="StreamWriter"/> to write to.</param>
/// <param name="label">The label.</param>
/// <param name="width">The width in characters used to print the label.</param>
private void WriteLabel(StreamWriter writer, string label, int width)
{
string paddedLabel = label.Length > width ? label.Substring(0, width) : label;
paddedLabel = paddedLabel.PadLeft(width + 2);
writer.Write(paddedLabel);
}
/// <summary>
/// Checks the positive class label.
/// </summary>
/// <param name="groundTruth">The ground truth.</param>
/// <param name="positiveClassLabel">An optional positive class label provided by the user. Defaults to the first class label.</param>
/// <returns>The actually used positive class label.</returns>
private string CheckPositiveClassLabel(IList<LabeledFeatureValues> groundTruth, string positiveClassLabel = null)
{
Debug.Assert(groundTruth != null, "The ground truth labels must not be null.");
Debug.Assert(groundTruth.Count > 0, "There must be at least one ground truth label.");
if (string.IsNullOrEmpty(positiveClassLabel))
{
positiveClassLabel = groundTruth[0].LabelDistribution.LabelSet.Elements.First();
}
else
{
if (!groundTruth.First().LabelDistribution.LabelSet.Contains(positiveClassLabel))
{
throw new ArgumentException(
"The label '" + positiveClassLabel + "' of the positive class is not a valid class label.");
}
}
return positiveClassLabel;
}
#endregion
}
}
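The micro- and macro-averages in ComputeLabelAuc above combine the per-label AUCs in two different ways: the micro-average weights each label by its number of true instances, while the macro-average gives every label with a defined AUC equal weight and skips the rest. A minimal standalone sketch of that arithmetic, using made-up per-label AUCs and counts (all numbers below are hypothetical):

using System;
using System.Collections.Generic;
using System.Linq;

internal static class AucAveragingSketch
{
    public static void Main()
    {
        // Hypothetical per-label AUCs; NaN marks a label whose AUC is undefined.
        var labelAuc = new Dictionary<string, double> { { "A", 0.95 }, { "B", 0.80 }, { "C", double.NaN } };

        // Hypothetical counts of true instances per label.
        var trueLabelCount = new Dictionary<string, int> { { "A", 70 }, { "B", 20 }, { "C", 10 } };

        int instanceCount = trueLabelCount.Values.Sum();
        int definedClassLabelCount = labelAuc.Values.Count(a => !double.IsNaN(a));

        double microAuc = 0;
        double macroAuc = 0;
        foreach (var label in labelAuc.Keys)
        {
            if (double.IsNaN(labelAuc[label]))
            {
                continue; // Undefined AUCs are skipped in both averages.
            }

            microAuc += trueLabelCount[label] * labelAuc[label] / instanceCount; // weighted by class frequency
            macroAuc += labelAuc[label] / definedClassLabelCount;                // equal weight per defined class
        }

        Console.WriteLine("Micro-averaged AUC = {0:0.0000}", microAuc); // 0.8250
        Console.WriteLine("Macro-averaged AUC = {0:0.0000}", macroAuc); // 0.8750
    }
}

Skipping undefined AUCs in this way is also why the report prints macroAucClassLabelCount, rather than the full class count, when some per-label AUCs could not be computed.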


@ -1,153 +0,0 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
namespace Microsoft.ML.Probabilistic.Learners.Runners
{
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using Microsoft.ML.Probabilistic.Learners.Mappings;
using Microsoft.ML.Probabilistic.Math;
/// <summary>
/// A command-line module to cross-validate a multi-class Bayes point machine classifier on given data.
/// </summary>
internal class MulticlassBayesPointMachineClassifierCrossValidationModule : CommandLineModule
{
/// <summary>
/// Runs the module.
/// </summary>
/// <param name="args">The command line arguments for the module.</param>
/// <param name="usagePrefix">The prefix to print before the usage string.</param>
/// <returns>True if the run was successful, false otherwise.</returns>
public override bool Run(string[] args, string usagePrefix)
{
string dataSetFile = string.Empty;
string resultsFile = string.Empty;
int crossValidationFoldCount = 5;
int iterationCount = BayesPointMachineClassifierTrainingSettings.IterationCountDefault;
int batchCount = BayesPointMachineClassifierTrainingSettings.BatchCountDefault;
bool computeModelEvidence = BayesPointMachineClassifierTrainingSettings.ComputeModelEvidenceDefault;
var parser = new CommandLineParser();
parser.RegisterParameterHandler("--data-set", "FILE", "File with training data", v => dataSetFile = v, CommandLineParameterType.Required);
parser.RegisterParameterHandler("--results", "FILE", "File with cross-validation results", v => resultsFile = v, CommandLineParameterType.Required);
parser.RegisterParameterHandler("--folds", "NUM", "Number of cross-validation folds (defaults to " + crossValidationFoldCount + ")", v => crossValidationFoldCount = v, CommandLineParameterType.Optional);
parser.RegisterParameterHandler("--iterations", "NUM", "Number of training algorithm iterations (defaults to " + iterationCount + ")", v => iterationCount = v, CommandLineParameterType.Optional);
parser.RegisterParameterHandler("--batches", "NUM", "Number of batches to split the training data into (defaults to " + batchCount + ")", v => batchCount = v, CommandLineParameterType.Optional);
parser.RegisterParameterHandler("--compute-evidence", "Compute model evidence (defaults to " + computeModelEvidence + ")", () => computeModelEvidence = true);
if (!parser.TryParse(args, usagePrefix))
{
return false;
}
// Load and shuffle data
var dataSet = ClassifierPersistenceUtils.LoadLabeledFeatureValues(dataSetFile);
BayesPointMachineClassifierModuleUtilities.WriteDataSetInfo(dataSet);
Rand.Restart(562);
Rand.Shuffle(dataSet);
// Create evaluator
var evaluatorMapping = Mappings.Classifier.ForEvaluation();
var evaluator = new ClassifierEvaluator<IList<LabeledFeatureValues>, LabeledFeatureValues, IList<LabelDistribution>, string>(evaluatorMapping);
// Create performance metrics
var accuracy = new List<double>();
var negativeLogProbability = new List<double>();
var auc = new List<double>();
var evidence = new List<double>();
var iterationCounts = new List<double>();
var trainingTime = new List<double>();
// Run cross-validation
int validationSetSize = dataSet.Count / crossValidationFoldCount;
Console.WriteLine("Running {0}-fold cross-validation on {1}", crossValidationFoldCount, dataSetFile);
// TODO: Use chained mapping to implement cross-validation
for (int fold = 0; fold < crossValidationFoldCount; fold++)
{
// Construct training and validation sets for fold
int validationSetStart = fold * validationSetSize;
int validationSetEnd = (fold + 1 == crossValidationFoldCount)
? dataSet.Count
: (fold + 1) * validationSetSize;
var trainingSet = new List<LabeledFeatureValues>();
var validationSet = new List<LabeledFeatureValues>();
for (int instance = 0; instance < dataSet.Count; instance++)
{
if (validationSetStart <= instance && instance < validationSetEnd)
{
validationSet.Add(dataSet[instance]);
}
else
{
trainingSet.Add(dataSet[instance]);
}
}
// Print info
Console.WriteLine(" Fold {0} [validation set instances {1} - {2}]", fold + 1, validationSetStart, validationSetEnd - 1);
// Create classifier
var classifier = BayesPointMachineClassifier.CreateMulticlassClassifier(Mappings.Classifier);
classifier.Settings.Training.IterationCount = iterationCount;
classifier.Settings.Training.BatchCount = batchCount;
classifier.Settings.Training.ComputeModelEvidence = computeModelEvidence;
int currentIterationCount = 0;
classifier.IterationChanged += (sender, eventArgs) => { currentIterationCount = eventArgs.CompletedIterationCount; };
// Train classifier
var stopWatch = new Stopwatch();
stopWatch.Start();
classifier.Train(trainingSet);
stopWatch.Stop();
// Produce predictions
var predictions = classifier.PredictDistribution(validationSet).ToList();
var predictedLabels = predictions.Select(
prediction => prediction.Aggregate((aggregate, next) => next.Value > aggregate.Value ? next : aggregate).Key).ToList();
// Iteration count
iterationCounts.Add(currentIterationCount);
// Training time
trainingTime.Add(stopWatch.ElapsedMilliseconds);
// Compute accuracy
accuracy.Add(1 - (evaluator.Evaluate(validationSet, predictedLabels, Metrics.ZeroOneError) / predictions.Count));
// Compute mean negative log probability
negativeLogProbability.Add(evaluator.Evaluate(validationSet, predictions, Metrics.NegativeLogProbability) / predictions.Count);
// Compute M-measure (averaged pairwise AUC)
auc.Add(evaluator.AreaUnderRocCurve(validationSet, predictions));
// Compute log evidence if desired
evidence.Add(computeModelEvidence ? classifier.LogModelEvidence : double.NaN);
// Persist performance metrics
Console.WriteLine(
" Accuracy = {0,5:0.0000} NegLogProb = {1,5:0.0000} AUC = {2,5:0.0000}{3} Iterations = {4} Training time = {5}",
accuracy[fold],
negativeLogProbability[fold],
auc[fold],
computeModelEvidence ? string.Format(" Log evidence = {0,5:0.0000}", evidence[fold]) : string.Empty,
iterationCounts[fold],
BayesPointMachineClassifierModuleUtilities.FormatElapsedTime(trainingTime[fold]));
BayesPointMachineClassifierModuleUtilities.SavePerformanceMetrics(
resultsFile, accuracy, negativeLogProbability, auc, evidence, iterationCounts, trainingTime);
}
return true;
}
}
}
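The fold boundaries in the cross-validation loop above come from integer division, with the final fold absorbing any remainder. A small standalone check of that arithmetic (the instance count below is an arbitrary example, the fold count is the module's default):

using System;

internal static class FoldBoundarySketch
{
    public static void Main()
    {
        int instanceCount = 103;          // arbitrary example size
        int crossValidationFoldCount = 5; // same default as in the module above

        int validationSetSize = instanceCount / crossValidationFoldCount; // 20 (integer division)
        for (int fold = 0; fold < crossValidationFoldCount; fold++)
        {
            int validationSetStart = fold * validationSetSize;
            int validationSetEnd = (fold + 1 == crossValidationFoldCount)
                ? instanceCount                      // the last fold takes the remaining 23 instances
                : (fold + 1) * validationSetSize;

            Console.WriteLine(
                "Fold {0}: validation instances {1} - {2} ({3} instances)",
                fold + 1,
                validationSetStart,
                validationSetEnd - 1,
                validationSetEnd - validationSetStart);
        }
    }
}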


@ -1,54 +0,0 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
namespace Microsoft.ML.Probabilistic.Learners.Runners
{
using System.Collections.Generic;
/// <summary>
/// A command-line module to incrementally train a multi-class Bayes point machine classifier on some given data.
/// </summary>
internal class MulticlassBayesPointMachineClassifierIncrementalTrainingModule : CommandLineModule
{
/// <summary>
/// Runs the module.
/// </summary>
/// <param name="args">The command line arguments for the module.</param>
/// <param name="usagePrefix">The prefix to print before the usage string.</param>
/// <returns>True if the run was successful, false otherwise.</returns>
public override bool Run(string[] args, string usagePrefix)
{
string trainingSetFile = string.Empty;
string inputModelFile = string.Empty;
string outputModelFile = string.Empty;
int iterationCount = BayesPointMachineClassifierTrainingSettings.IterationCountDefault;
int batchCount = BayesPointMachineClassifierTrainingSettings.BatchCountDefault;
var parser = new CommandLineParser();
parser.RegisterParameterHandler("--training-set", "FILE", "File with training data", v => trainingSetFile = v, CommandLineParameterType.Required);
parser.RegisterParameterHandler("--input-model", "FILE", "File with the trained multi-class Bayes point machine model", v => inputModelFile = v, CommandLineParameterType.Required);
parser.RegisterParameterHandler("--model", "FILE", "File to store the incrementally trained multi-class Bayes point machine model", v => outputModelFile = v, CommandLineParameterType.Required);
parser.RegisterParameterHandler("--iterations", "NUM", "Number of training algorithm iterations (defaults to " + iterationCount + ")", v => iterationCount = v, CommandLineParameterType.Optional);
parser.RegisterParameterHandler("--batches", "NUM", "Number of batches to split the training data into (defaults to " + batchCount + ")", v => batchCount = v, CommandLineParameterType.Optional);
if (!parser.TryParse(args, usagePrefix))
{
return false;
}
var trainingSet = ClassifierPersistenceUtils.LoadLabeledFeatureValues(trainingSetFile);
BayesPointMachineClassifierModuleUtilities.WriteDataSetInfo(trainingSet);
var classifier = BayesPointMachineClassifier.LoadMulticlassClassifier<IList<LabeledFeatureValues>, LabeledFeatureValues, IList<LabelDistribution>, string, IDictionary<string, double>>(inputModelFile);
classifier.Settings.Training.IterationCount = iterationCount;
classifier.Settings.Training.BatchCount = batchCount;
classifier.TrainIncremental(trainingSet);
classifier.Save(outputModelFile);
return true;
}
}
}


@ -1,50 +0,0 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
namespace Microsoft.ML.Probabilistic.Learners.Runners
{
using System.Collections.Generic;
/// <summary>
/// A command-line module to predict labels given a trained multi-class Bayes point machine classifier model and a test set.
/// </summary>
internal class MulticlassBayesPointMachineClassifierPredictionModule : CommandLineModule
{
/// <summary>
/// Runs the module.
/// </summary>
/// <param name="args">The command line arguments for the module.</param>
/// <param name="usagePrefix">The prefix to print before the usage string.</param>
/// <returns>True if the run was successful, false otherwise.</returns>
public override bool Run(string[] args, string usagePrefix)
{
string testSetFile = string.Empty;
string modelFile = string.Empty;
string predictionsFile = string.Empty;
var parser = new CommandLineParser();
parser.RegisterParameterHandler("--test-set", "FILE", "File with test data", v => testSetFile = v, CommandLineParameterType.Required);
parser.RegisterParameterHandler("--model", "FILE", "File with a trained multi-class Bayes point machine model", v => modelFile = v, CommandLineParameterType.Required);
parser.RegisterParameterHandler("--predictions", "FILE", "File to store predictions for the test data", v => predictionsFile = v, CommandLineParameterType.Required);
if (!parser.TryParse(args, usagePrefix))
{
return false;
}
var testSet = ClassifierPersistenceUtils.LoadLabeledFeatureValues(testSetFile);
BayesPointMachineClassifierModuleUtilities.WriteDataSetInfo(testSet);
var classifier =
BayesPointMachineClassifier.LoadMulticlassClassifier<IList<LabeledFeatureValues>, LabeledFeatureValues, IList<LabelDistribution>, string, IDictionary<string, double>>(modelFile);
// Predict labels
var predictions = classifier.PredictDistribution(testSet);
// Write labels to file
ClassifierPersistenceUtils.SaveLabelDistributions(predictionsFile, predictions);
return true;
}
}
}


@ -1,41 +0,0 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
namespace Microsoft.ML.Probabilistic.Learners.Runners
{
using System.Collections.Generic;
/// <summary>
/// A command-line module to sample weights from a trained multi-class Bayes point machine classifier model.
/// </summary>
internal class MulticlassBayesPointMachineClassifierSampleWeightsModule : CommandLineModule
{
/// <summary>
/// Runs the module.
/// </summary>
/// <param name="args">The command line arguments for the module.</param>
/// <param name="usagePrefix">The prefix to print before the usage string.</param>
/// <returns>True if the run was successful, false otherwise.</returns>
public override bool Run(string[] args, string usagePrefix)
{
string modelFile = string.Empty;
string samplesFile = string.Empty;
var parser = new CommandLineParser();
parser.RegisterParameterHandler("--model", "FILE", "File with a trained multi-class Bayes point machine model", v => modelFile = v, CommandLineParameterType.Required);
parser.RegisterParameterHandler("--samples", "FILE", "File to store samples of the weights", v => samplesFile = v, CommandLineParameterType.Required);
if (!parser.TryParse(args, usagePrefix))
{
return false;
}
var classifier =
BayesPointMachineClassifier.LoadMulticlassClassifier<IList<LabeledFeatureValues>, LabeledFeatureValues, IList<LabelDistribution>, string, IDictionary<string, double>>(modelFile);
BayesPointMachineClassifierModuleUtilities.SampleWeights(classifier, samplesFile);
return true;
}
}
}


@ -1,50 +0,0 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
namespace Microsoft.ML.Probabilistic.Learners.Runners
{
/// <summary>
/// A command-line module to diagnose training of a multi-class Bayes point machine classifier on given data.
/// </summary>
internal class MulticlassBayesPointMachineClassifierTrainingDiagnosticsModule : CommandLineModule
{
/// <summary>
/// Runs the module.
/// </summary>
/// <param name="args">The command line arguments for the module.</param>
/// <param name="usagePrefix">The prefix to print before the usage string.</param>
/// <returns>True if the run was successful, false otherwise.</returns>
public override bool Run(string[] args, string usagePrefix)
{
string trainingSetFile = string.Empty;
string maxParameterChangesFile = string.Empty;
string modelFile = string.Empty;
int iterationCount = BayesPointMachineClassifierTrainingSettings.IterationCountDefault;
int batchCount = BayesPointMachineClassifierTrainingSettings.BatchCountDefault;
var parser = new CommandLineParser();
parser.RegisterParameterHandler("--training-set", "FILE", "File with training data", v => trainingSetFile = v, CommandLineParameterType.Required);
parser.RegisterParameterHandler("--results", "FILE", "File to store the maximum parameter differences", v => maxParameterChangesFile = v, CommandLineParameterType.Optional);
parser.RegisterParameterHandler("--model", "FILE", "File to store the trained multi-class Bayes point machine model", v => modelFile = v, CommandLineParameterType.Optional);
parser.RegisterParameterHandler("--iterations", "NUM", "Number of training algorithm iterations (defaults to " + iterationCount + ")", v => iterationCount = v, CommandLineParameterType.Optional);
parser.RegisterParameterHandler("--batches", "NUM", "Number of batches to split the training data into (defaults to " + batchCount + ")", v => batchCount = v, CommandLineParameterType.Optional);
if (!parser.TryParse(args, usagePrefix))
{
return false;
}
var trainingSet = ClassifierPersistenceUtils.LoadLabeledFeatureValues(trainingSetFile);
BayesPointMachineClassifierModuleUtilities.WriteDataSetInfo(trainingSet);
var classifier = BayesPointMachineClassifier.CreateMulticlassClassifier(Mappings.Classifier);
classifier.Settings.Training.IterationCount = iterationCount;
classifier.Settings.Training.BatchCount = batchCount;
BayesPointMachineClassifierModuleUtilities.DiagnoseClassifier(classifier, trainingSet, maxParameterChangesFile, modelFile);
return true;
}
}
}


@ -1,63 +0,0 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
namespace Microsoft.ML.Probabilistic.Learners.Runners
{
using System;
using System.Linq;
/// <summary>
/// A command-line module to train a multi-class Bayes point machine classifier on some given data.
/// </summary>
internal class MulticlassBayesPointMachineClassifierTrainingModule : CommandLineModule
{
/// <summary>
/// Runs the module.
/// </summary>
/// <param name="args">The command line arguments for the module.</param>
/// <param name="usagePrefix">The prefix to print before the usage string.</param>
/// <returns>True if the run was successful, false otherwise.</returns>
public override bool Run(string[] args, string usagePrefix)
{
string trainingSetFile = string.Empty;
string modelFile = string.Empty;
int iterationCount = BayesPointMachineClassifierTrainingSettings.IterationCountDefault;
int batchCount = BayesPointMachineClassifierTrainingSettings.BatchCountDefault;
bool computeModelEvidence = BayesPointMachineClassifierTrainingSettings.ComputeModelEvidenceDefault;
var parser = new CommandLineParser();
parser.RegisterParameterHandler("--training-set", "FILE", "File with training data", v => trainingSetFile = v, CommandLineParameterType.Required);
parser.RegisterParameterHandler("--model", "FILE", "File to store the trained multi-class Bayes point machine model", v => modelFile = v, CommandLineParameterType.Required);
parser.RegisterParameterHandler("--iterations", "NUM", "Number of training algorithm iterations (defaults to " + iterationCount + ")", v => iterationCount = v, CommandLineParameterType.Optional);
parser.RegisterParameterHandler("--batches", "NUM", "Number of batches to split the training data into (defaults to " + batchCount + ")", v => batchCount = v, CommandLineParameterType.Optional);
parser.RegisterParameterHandler("--compute-evidence", "Compute model evidence (defaults to " + computeModelEvidence + ")", () => computeModelEvidence = true);
if (!parser.TryParse(args, usagePrefix))
{
return false;
}
var trainingSet = ClassifierPersistenceUtils.LoadLabeledFeatureValues(trainingSetFile);
BayesPointMachineClassifierModuleUtilities.WriteDataSetInfo(trainingSet);
var featureSet = trainingSet.Count > 0 ? trainingSet.First().FeatureSet : null;
var mapping = new ClassifierMapping(featureSet);
var classifier = BayesPointMachineClassifier.CreateMulticlassClassifier(mapping);
classifier.Settings.Training.IterationCount = iterationCount;
classifier.Settings.Training.BatchCount = batchCount;
classifier.Settings.Training.ComputeModelEvidence = computeModelEvidence;
classifier.Train(trainingSet);
if (classifier.Settings.Training.ComputeModelEvidence)
{
Console.WriteLine("Log evidence = {0,10:0.0000}", classifier.LogModelEvidence);
}
classifier.Save(modelFile);
return true;
}
}
}


@ -1,50 +0,0 @@
<Project Sdk="Microsoft.NET.Sdk">
<Import Project="$(MSBuildThisFileDirectory)..\..\..\..\build\common.props" />
<PropertyGroup>
<OutputType>Exe</OutputType>
<AssemblyName>Learner</AssemblyName>
<ErrorReport>prompt</ErrorReport>
<Prefer32Bit>false</Prefer32Bit>
<DefineConstants>TRACE</DefineConstants>
<RootNamespace>Microsoft.ML.Probabilistic.Learners.Runners</RootNamespace>
<Configurations>Debug;Release;DebugFull;DebugCore;ReleaseFull;ReleaseCore</Configurations>
</PropertyGroup>
<Choose>
<When Condition="'$(Configuration)'=='DebugFull' OR '$(Configuration)'=='ReleaseFull'">
<PropertyGroup>
<TargetFramework>net461</TargetFramework>
</PropertyGroup>
</When>
<When Condition="'$(Configuration)'=='DebugCore' OR '$(Configuration)'=='ReleaseCore'">
<PropertyGroup>
<TargetFramework>net5.0</TargetFramework>
</PropertyGroup>
</When>
<Otherwise>
<PropertyGroup>
<TargetFrameworks>net5.0;net461</TargetFrameworks>
</PropertyGroup>
</Otherwise>
</Choose>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU' OR '$(Configuration)|$(Platform)'=='DebugFull|AnyCPU' OR '$(Configuration)|$(Platform)'=='DebugCore|AnyCPU'">
<DebugSymbols>true</DebugSymbols>
<DebugType>full</DebugType>
<Optimize>false</Optimize>
<DefineConstants>$(DefineConstants);DEBUG</DefineConstants>
</PropertyGroup>
<ItemGroup>
<ProjectReference Include="..\..\..\Runtime\Runtime.csproj" />
<ProjectReference Include="..\..\Classifier\Classifier.csproj" />
<ProjectReference Include="..\..\Core\Core.csproj" />
<ProjectReference Include="..\..\Recommender\Recommender.csproj" />
<ProjectReference Include="..\Common\Common.csproj" />
</ItemGroup>
<ItemGroup>
<Compile Include="..\..\..\Shared\SharedAssemblyFileVersion.cs" />
<Compile Include="..\..\..\Shared\SharedAssemblyInfo.cs" />
</ItemGroup>
</Project>


@ -1,74 +0,0 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
namespace Microsoft.ML.Probabilistic.Learners.Runners
{
using System;
/// <summary>
/// The program.
/// </summary>
public static class Program
{
/// <summary>
/// The entry point for the program.
/// </summary>
/// <param name="args">The array of command-line arguments.</param>
public static void Main(string[] args)
{
// Matchbox recommender
var recommenderModuleSelector = new CommandLineModuleSelector();
recommenderModuleSelector.RegisterModule("SplitData", new RecommenderSplitDataModule());
recommenderModuleSelector.RegisterModule("GenerateNegativeData", new RecommenderGenerateNegativeDataModule());
recommenderModuleSelector.RegisterModule("Train", new RecommenderTrainModule());
recommenderModuleSelector.RegisterModule("PredictRatings", new RecommenderPredictRatingsModule());
recommenderModuleSelector.RegisterModule("RecommendItems", new RecommenderRecommendItemsModule());
recommenderModuleSelector.RegisterModule("FindRelatedUsers", new RecommenderFindRelatedUsersModule());
recommenderModuleSelector.RegisterModule("FindRelatedItems", new RecommenderFindRelatedItemsModule());
recommenderModuleSelector.RegisterModule("EvaluateRatingPrediction", new RecommenderEvaluateRatingPredictionModule());
recommenderModuleSelector.RegisterModule("EvaluateItemRecommendation", new RecommenderEvaluateItemRecommendationModule());
recommenderModuleSelector.RegisterModule("EvaluateFindRelatedUsers", new RecommenderEvaluateFindRelatedUsersModule());
recommenderModuleSelector.RegisterModule("EvaluateFindRelatedItems", new RecommenderEvaluateFindRelatedItemsModule());
// Binary Bayes point machine classifier
var binaryBayesPointMachineModuleSelector = new CommandLineModuleSelector();
binaryBayesPointMachineModuleSelector.RegisterModule("Train", new BinaryBayesPointMachineClassifierTrainingModule());
binaryBayesPointMachineModuleSelector.RegisterModule("TrainIncremental", new BinaryBayesPointMachineClassifierIncrementalTrainingModule());
binaryBayesPointMachineModuleSelector.RegisterModule("Predict", new BinaryBayesPointMachineClassifierPredictionModule());
binaryBayesPointMachineModuleSelector.RegisterModule("CrossValidate", new BinaryBayesPointMachineClassifierCrossValidationModule());
binaryBayesPointMachineModuleSelector.RegisterModule("SampleWeights", new BinaryBayesPointMachineClassifierSampleWeightsModule());
binaryBayesPointMachineModuleSelector.RegisterModule("DiagnoseTrain", new BinaryBayesPointMachineClassifierTrainingDiagnosticsModule());
// Multi-class Bayes point machine classifier
var multiclassBayesPointMachineModuleSelector = new CommandLineModuleSelector();
multiclassBayesPointMachineModuleSelector.RegisterModule("Train", new MulticlassBayesPointMachineClassifierTrainingModule());
multiclassBayesPointMachineModuleSelector.RegisterModule("TrainIncremental", new MulticlassBayesPointMachineClassifierIncrementalTrainingModule());
multiclassBayesPointMachineModuleSelector.RegisterModule("Predict", new MulticlassBayesPointMachineClassifierPredictionModule());
multiclassBayesPointMachineModuleSelector.RegisterModule("CrossValidate", new MulticlassBayesPointMachineClassifierCrossValidationModule());
multiclassBayesPointMachineModuleSelector.RegisterModule("SampleWeights", new MulticlassBayesPointMachineClassifierSampleWeightsModule());
multiclassBayesPointMachineModuleSelector.RegisterModule("DiagnoseTrain", new MulticlassBayesPointMachineClassifierTrainingDiagnosticsModule());
// Classifier
var classifierModuleSelector = new CommandLineModuleSelector();
classifierModuleSelector.RegisterModule("Evaluate", new ClassifierEvaluationModule());
classifierModuleSelector.RegisterModule("BinaryBayesPointMachine", binaryBayesPointMachineModuleSelector);
classifierModuleSelector.RegisterModule("MulticlassBayesPointMachine", multiclassBayesPointMachineModuleSelector);
// Modules
var moduleSelector = new CommandLineModuleSelector();
moduleSelector.RegisterModule("Recommender", recommenderModuleSelector);
moduleSelector.RegisterModule("Classifier", classifierModuleSelector);
try
{
bool success = moduleSelector.Run(args, Environment.GetCommandLineArgs()[0]);
Environment.Exit(success ? 0 : 1);
}
catch (Exception e)
{
Console.WriteLine("An error has occured in one of the modules. {0}", e.Message);
}
}
}
}
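Judging from the nested RegisterModule calls above and the Learner assembly name in the project file, a run selects modules level by level before the leaf module parses its own --options, i.e. something along the lines of: Learner Classifier MulticlassBayesPointMachine Train --training-set data.txt --model model.bin. The following is a simplified, hypothetical sketch of that dispatch, not the real CommandLineModuleSelector:

using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical stand-in for the real CommandLineModuleSelector, for illustration only.
internal sealed class SelectorSketch
{
    private readonly Dictionary<string, Action<string[]>> modules = new Dictionary<string, Action<string[]>>();

    public void Register(string name, Action<string[]> module)
    {
        this.modules[name] = module;
    }

    public void Run(string[] args)
    {
        if (args.Length == 0 || !this.modules.ContainsKey(args[0]))
        {
            Console.WriteLine("Usage: Learner <{0}> ...", string.Join("|", this.modules.Keys));
            return;
        }

        // Consume the first token and hand the remaining arguments to the selected module.
        this.modules[args[0]](args.Skip(1).ToArray());
    }
}

internal static class DispatchSketch
{
    public static void Main()
    {
        var multiclassSelector = new SelectorSketch();
        multiclassSelector.Register("Train", rest => Console.WriteLine("Train would now parse: " + string.Join(" ", rest)));

        var classifierSelector = new SelectorSketch();
        classifierSelector.Register("MulticlassBayesPointMachine", rest => multiclassSelector.Run(rest));

        var rootSelector = new SelectorSketch();
        rootSelector.Register("Classifier", rest => classifierSelector.Run(rest));

        // Mirrors a command line such as:
        //   Learner Classifier MulticlassBayesPointMachine Train --training-set data.txt --model model.bin
        rootSelector.Run(new[] { "Classifier", "MulticlassBayesPointMachine", "Train", "--training-set", "data.txt", "--model", "model.bin" });
    }
}

In the real program the leaf handler is a CommandLineModule whose Run method registers the --options shown in the training module above.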


@ -1,23 +0,0 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
// General Information about an assembly is controlled through the following
// set of attributes. Change these attribute values to modify the information
// associated with an assembly.
[assembly: AssemblyTitle("CommandLine")]
[assembly: AssemblyDescription("Runs the Infer.NET learner modules from the command line")]
// Setting ComVisible to false makes the types in this assembly not visible
// to COM components. If you need to access a type in this assembly from
// COM, set the ComVisible attribute to true on that type.
[assembly: ComVisible(false)]
// The following GUID is for the ID of the typelib if this project is exposed to COM
[assembly: Guid("9d6aa264-e13a-4314-a39f-d485e21c3103")]
[assembly: InternalsVisibleTo("ML.Probabilistic.Learners.Tests,PublicKey=0024000004800000940000000602000000240000525341310004000001000100551f07a755a3e3f2901fa321ab631d13d6192b4e6ac9c87279500f49d6635cde6902587752eff20402f46f6ea9c3d80e827580a799840aaab9a49b1d2597e4c1798ee93c5cb66851e9d22f4d6e8110571f4a2e59f1d760f7be04fb10e7dc43ee7ed2831907731427b9815c5fe7f4888f9933ee7a1ad5d1f293fd8ab834fac1be")]


@ -1,59 +0,0 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
namespace Microsoft.ML.Probabilistic.Learners.Runners
{
using System.Collections.Generic;
using System.IO;
using Microsoft.ML.Probabilistic.Learners.Mappings;
using Microsoft.ML.Probabilistic.Math;
/// <summary>
/// A command-line module to evaluate item recommendation.
/// </summary>
internal class RecommenderEvaluateFindRelatedItemsModule : CommandLineModule
{
/// <summary>
/// Runs the module.
/// </summary>
/// <param name="args">The command line arguments for the module.</param>
/// <param name="usagePrefix">The prefix to print before the usage string.</param>
/// <returns>True if the run was successful, false otherwise.</returns>
public override bool Run(string[] args, string usagePrefix)
{
string testDatasetFile = string.Empty;
string predictionsFile = string.Empty;
string reportFile = string.Empty;
int minCommonRatingCount = 5;
var parser = new CommandLineParser();
parser.RegisterParameterHandler("--test-data", "FILE", "Test dataset used to obtain ground truth", v => testDatasetFile = v, CommandLineParameterType.Required);
parser.RegisterParameterHandler("--predictions", "FILE", "Predictions to evaluate", v => predictionsFile = v, CommandLineParameterType.Required);
parser.RegisterParameterHandler("--report", "FILE", "Evaluation report file", v => reportFile = v, CommandLineParameterType.Required);
parser.RegisterParameterHandler("--min-common-items", "NUM", "Minimum number of users that the query item and the related item should have been rated by in common; defaults to 5", v => minCommonRatingCount = v, CommandLineParameterType.Optional);
if (!parser.TryParse(args, usagePrefix))
{
return false;
}
RecommenderDataset testDataset = RecommenderDataset.Load(testDatasetFile);
IDictionary<Item, IEnumerable<Item>> relatedItems = RecommenderPersistenceUtils.LoadRelatedItems(predictionsFile);
var evaluatorMapping = Mappings.StarRatingRecommender.ForEvaluation();
var evaluator = new StarRatingRecommenderEvaluator<RecommenderDataset, User, Item, int>(evaluatorMapping);
using (var writer = new StreamWriter(reportFile))
{
writer.WriteLine(
"L1 Sim NDCG: {0:0.000}",
evaluator.RelatedItemsMetric(testDataset, relatedItems, minCommonRatingCount, Metrics.Ndcg, Metrics.NormalizedManhattanSimilarity));
writer.WriteLine(
"L2 Sim NDCG: {0:0.000}",
evaluator.RelatedItemsMetric(testDataset, relatedItems, minCommonRatingCount, Metrics.Ndcg, Metrics.NormalizedEuclideanSimilarity));
}
return true;
}
}
}


@ -1,59 +0,0 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
namespace Microsoft.ML.Probabilistic.Learners.Runners
{
using System.Collections.Generic;
using System.IO;
using Microsoft.ML.Probabilistic.Learners.Mappings;
using Microsoft.ML.Probabilistic.Math;
/// <summary>
/// A command-line module to evaluate item recommendation.
/// </summary>
internal class RecommenderEvaluateFindRelatedUsersModule : CommandLineModule
{
/// <summary>
/// Runs the module.
/// </summary>
/// <param name="args">The command line arguments for the module.</param>
/// <param name="usagePrefix">The prefix to print before the usage string.</param>
/// <returns>True if the run was successful, false otherwise.</returns>
public override bool Run(string[] args, string usagePrefix)
{
string testDatasetFile = string.Empty;
string predictionsFile = string.Empty;
string reportFile = string.Empty;
int minCommonRatingCount = 5;
var parser = new CommandLineParser();
parser.RegisterParameterHandler("--test-data", "FILE", "Test dataset used to obtain ground truth", v => testDatasetFile = v, CommandLineParameterType.Required);
parser.RegisterParameterHandler("--predictions", "FILE", "Predictions to evaluate", v => predictionsFile = v, CommandLineParameterType.Required);
parser.RegisterParameterHandler("--report", "FILE", "Evaluation report file", v => reportFile = v, CommandLineParameterType.Required);
parser.RegisterParameterHandler("--min-common-items", "NUM", "Minimum number of items that the query user and the related user should have rated in common; defaults to 5", v => minCommonRatingCount = v, CommandLineParameterType.Optional);
if (!parser.TryParse(args, usagePrefix))
{
return false;
}
RecommenderDataset testDataset = RecommenderDataset.Load(testDatasetFile);
IDictionary<User, IEnumerable<User>> relatedUsers = RecommenderPersistenceUtils.LoadRelatedUsers(predictionsFile);
var evaluatorMapping = Mappings.StarRatingRecommender.ForEvaluation();
var evaluator = new StarRatingRecommenderEvaluator<RecommenderDataset, User, Item, int>(evaluatorMapping);
using (var writer = new StreamWriter(reportFile))
{
writer.WriteLine(
"L1 Sim NDCG: {0:0.000}",
evaluator.RelatedUsersMetric(testDataset, relatedUsers, minCommonRatingCount, Metrics.Ndcg, Metrics.NormalizedManhattanSimilarity));
writer.WriteLine(
"L2 Sim NDCG: {0:0.000}",
evaluator.RelatedUsersMetric(testDataset, relatedUsers, minCommonRatingCount, Metrics.Ndcg, Metrics.NormalizedEuclideanSimilarity));
}
return true;
}
}
}


@ -1,61 +0,0 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
namespace Microsoft.ML.Probabilistic.Learners.Runners
{
using System;
using System.Collections.Generic;
using System.IO;
using Microsoft.ML.Probabilistic.Learners.Mappings;
using Microsoft.ML.Probabilistic.Math;
/// <summary>
/// A command-line module to evaluate item recommendation.
/// </summary>
internal class RecommenderEvaluateItemRecommendationModule : CommandLineModule
{
/// <summary>
/// Runs the module.
/// </summary>
/// <param name="args">The command line arguments for the module.</param>
/// <param name="usagePrefix">The prefix to print before the usage string.</param>
/// <returns>True if the run was successful, false otherwise.</returns>
public override bool Run(string[] args, string usagePrefix)
{
string testDatasetFile = string.Empty;
string predictionsFile = string.Empty;
string reportFile = string.Empty;
var parser = new CommandLineParser();
parser.RegisterParameterHandler("--test-data", "FILE", "Test dataset used to obtain ground truth", v => testDatasetFile = v, CommandLineParameterType.Required);
parser.RegisterParameterHandler("--predictions", "FILE", "Predictions to evaluate", v => predictionsFile = v, CommandLineParameterType.Required);
parser.RegisterParameterHandler("--report", "FILE", "Evaluation report file", v => reportFile = v, CommandLineParameterType.Required);
if (!parser.TryParse(args, usagePrefix))
{
return false;
}
RecommenderDataset testDataset = RecommenderDataset.Load(testDatasetFile);
int minRating = Mappings.StarRatingRecommender.GetRatingInfo(testDataset).MinStarRating;
IDictionary<User, IEnumerable<Item>> recommendedItems = RecommenderPersistenceUtils.LoadRecommendedItems(predictionsFile);
var evaluatorMapping = Mappings.StarRatingRecommender.ForEvaluation();
var evaluator = new StarRatingRecommenderEvaluator<RecommenderDataset, User, Item, int>(evaluatorMapping);
using (var writer = new StreamWriter(reportFile))
{
writer.WriteLine(
"NDCG: {0:0.000}",
evaluator.ItemRecommendationMetric(
testDataset,
recommendedItems,
Metrics.Ndcg,
rating => Convert.ToDouble(rating) - minRating + 1));
}
return true;
}
}
}
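The gain function passed to Metrics.Ndcg above shifts star ratings so that the lowest rating maps to a gain of 1. As a rough, self-contained illustration of what such an NDCG measures, assuming the common 1 / log2(position + 1) discount (the exact convention inside Metrics.Ndcg is not reproduced here):

using System;
using System.Linq;

internal static class NdcgSketch
{
    // Discounted cumulative gain with the common 1 / log2(position + 1) discount.
    private static double Dcg(double[] gains)
    {
        return gains.Select((gain, i) => gain / Math.Log(i + 2, 2)).Sum();
    }

    public static void Main()
    {
        int minRating = 1; // stands in for GetRatingInfo(testDataset).MinStarRating

        // Hypothetical true ratings of the recommended items, in recommended order.
        int[] recommendedRatings = { 5, 3, 4 };
        double[] gains = recommendedRatings.Select(r => (double)(r - minRating + 1)).ToArray(); // 5, 3, 4

        // The ideal ordering puts the highest gains first.
        double[] idealGains = gains.OrderByDescending(g => g).ToArray(); // 5, 4, 3

        double ndcg = Dcg(gains) / Dcg(idealGains);
        Console.WriteLine("NDCG: {0:0.000}", ndcg); // just below 1 because the last two items are swapped
    }
}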


@ -1,58 +0,0 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
namespace Microsoft.ML.Probabilistic.Learners.Runners
{
using System;
using System.Collections.Generic;
using System.IO;
using Microsoft.ML.Probabilistic.Learners.Mappings;
using Microsoft.ML.Probabilistic.Math;
/// <summary>
/// A command-line module to evaluate rating prediction.
/// </summary>
internal class RecommenderEvaluateRatingPredictionModule : CommandLineModule
{
/// <summary>
/// Runs the module.
/// </summary>
/// <param name="args">The command line arguments for the module.</param>
/// <param name="usagePrefix">The prefix to print before the usage string.</param>
/// <returns>True if the run was successful, false otherwise.</returns>
public override bool Run(string[] args, string usagePrefix)
{
string testDatasetFile = string.Empty;
string predictionsFile = string.Empty;
string reportFile = string.Empty;
var parser = new CommandLineParser();
parser.RegisterParameterHandler("--test-data", "FILE", "Test dataset used to obtain ground truth", v => testDatasetFile = v, CommandLineParameterType.Required);
parser.RegisterParameterHandler("--predictions", "FILE", "Predictions to evaluate", v => predictionsFile = v, CommandLineParameterType.Required);
parser.RegisterParameterHandler("--report", "FILE", "Evaluation report file", v => reportFile = v, CommandLineParameterType.Required);
if (!parser.TryParse(args, usagePrefix))
{
return false;
}
RecommenderDataset testDataset = RecommenderDataset.Load(testDatasetFile);
IDictionary<User, IDictionary<Item, int>> ratingPredictions = RecommenderPersistenceUtils.LoadPredictedRatings(predictionsFile);
var evaluatorMapping = Mappings.StarRatingRecommender.ForEvaluation();
var evaluator = new StarRatingRecommenderEvaluator<RecommenderDataset, User, Item, int>(evaluatorMapping);
using (var writer = new StreamWriter(reportFile))
{
writer.WriteLine(
"Mean absolute error: {0:0.000}",
evaluator.RatingPredictionMetric(testDataset, ratingPredictions, Metrics.AbsoluteError));
writer.WriteLine(
"Root mean squared error: {0:0.000}",
Math.Sqrt(evaluator.RatingPredictionMetric(testDataset, ratingPredictions, Metrics.SquaredError)));
}
return true;
}
}
}
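The two numbers written to the report above are plain per-rating averages: the mean absolute error of the predicted ratings and the root of the mean squared error, obtained there via Metrics.AbsoluteError and Metrics.SquaredError. A self-contained sketch of the same arithmetic on hypothetical rating pairs:

using System;
using System.Linq;

internal static class RatingErrorSketch
{
    public static void Main()
    {
        // Hypothetical (actual rating, predicted rating) pairs from a test set.
        var pairs = new[] { (Actual: 4, Predicted: 3), (Actual: 5, Predicted: 5), (Actual: 2, Predicted: 4) };

        // Mean absolute error: average of |actual - predicted| over all pairs.
        double meanAbsoluteError = pairs.Average(p => Math.Abs(p.Actual - p.Predicted)); // (1 + 0 + 2) / 3 = 1.000

        // Root mean squared error: square root of the average squared difference.
        double rootMeanSquaredError = Math.Sqrt(pairs.Average(p => Math.Pow(p.Actual - p.Predicted, 2))); // sqrt(5 / 3) ≈ 1.291

        Console.WriteLine("Mean absolute error: {0:0.000}", meanAbsoluteError);
        Console.WriteLine("Root mean squared error: {0:0.000}", rootMeanSquaredError);
    }
}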


@ -1,57 +0,0 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
namespace Microsoft.ML.Probabilistic.Learners.Runners
{
using System.Collections.Generic;
using Microsoft.ML.Probabilistic.Learners.Mappings;
using RatingDistribution = System.Collections.Generic.IDictionary<int, double>;
/// <summary>
/// A command-line module to find related items given a trained recommender model and a dataset.
/// </summary>
internal class RecommenderFindRelatedItemsModule : CommandLineModule
{
/// <summary>
/// Runs the module.
/// </summary>
/// <param name="args">The command line arguments for the module.</param>
/// <param name="usagePrefix">The prefix to print before the usage string.</param>
/// <returns>True if the run was successful, false otherwise.</returns>
public override bool Run(string[] args, string usagePrefix)
{
string datasetFile = string.Empty;
string trainedModelFile = string.Empty;
string predictionsFile = string.Empty;
int maxRelatedItemCount = 5;
int minCommonRatingCount = 5;
int minRelatedItemPoolSize = 5;
var parser = new CommandLineParser();
parser.RegisterParameterHandler("--data", "FILE", "Dataset to make predictions for", v => datasetFile = v, CommandLineParameterType.Required);
parser.RegisterParameterHandler("--model", "FILE", "File with trained model", v => trainedModelFile = v, CommandLineParameterType.Required);
parser.RegisterParameterHandler("--predictions", "FILE", "File with generated predictions", v => predictionsFile = v, CommandLineParameterType.Required);
parser.RegisterParameterHandler("--max-items", "NUM", "Maximum number of related items for a single item; defaults to 5", v => maxRelatedItemCount = v, CommandLineParameterType.Optional);
parser.RegisterParameterHandler("--min-common-users", "NUM", "Minimum number of users that the query item and the related item should have been rated by in common; defaults to 5", v => minCommonRatingCount = v, CommandLineParameterType.Optional);
parser.RegisterParameterHandler("--min-pool-size", "NUM", "Minimum size of the related item pool for a single item; defaults to 5", v => minRelatedItemPoolSize = v, CommandLineParameterType.Optional);
if (!parser.TryParse(args, usagePrefix))
{
return false;
}
RecommenderDataset testDataset = RecommenderDataset.Load(datasetFile);
var trainedModel = MatchboxRecommender.Load<RecommenderDataset, User, Item, RatingDistribution, DummyFeatureSource>(trainedModelFile);
var evaluator = new RecommenderEvaluator<RecommenderDataset, User, Item, int, int, RatingDistribution>(
Mappings.StarRatingRecommender.ForEvaluation());
IDictionary<Item, IEnumerable<Item>> relatedItems = evaluator.FindRelatedItemsRatedBySameUsers(
trainedModel, testDataset, maxRelatedItemCount, minCommonRatingCount, minRelatedItemPoolSize);
RecommenderPersistenceUtils.SaveRelatedItems(predictionsFile, relatedItems);
return true;
}
}
}


@ -1,57 +0,0 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
namespace Microsoft.ML.Probabilistic.Learners.Runners
{
using System.Collections.Generic;
using Microsoft.ML.Probabilistic.Learners.Mappings;
using RatingDistribution = System.Collections.Generic.IDictionary<int, double>;
/// <summary>
/// A command-line module to find related users given a trained recommender model and a dataset.
/// </summary>
internal class RecommenderFindRelatedUsersModule : CommandLineModule
{
/// <summary>
/// Runs the module.
/// </summary>
/// <param name="args">The command line arguments for the module.</param>
/// <param name="usagePrefix">The prefix to print before the usage string.</param>
/// <returns>True if the run was successful, false otherwise.</returns>
public override bool Run(string[] args, string usagePrefix)
{
string datasetFile = string.Empty;
string trainedModelFile = string.Empty;
string predictionsFile = string.Empty;
int maxRelatedUserCount = 5;
int minCommonRatingCount = 5;
int minRelatedUserPoolSize = 5;
var parser = new CommandLineParser();
parser.RegisterParameterHandler("--data", "FILE", "Dataset to make predictions for", v => datasetFile = v, CommandLineParameterType.Required);
parser.RegisterParameterHandler("--model", "FILE", "File with trained model", v => trainedModelFile = v, CommandLineParameterType.Required);
parser.RegisterParameterHandler("--predictions", "FILE", "File with generated predictions", v => predictionsFile = v, CommandLineParameterType.Required);
parser.RegisterParameterHandler("--max-users", "NUM", "Maximum number of related users for a single user; defaults to 5", v => maxRelatedUserCount = v, CommandLineParameterType.Optional);
parser.RegisterParameterHandler("--min-common-items", "NUM", "Minimum number of items that the query user and the related user should have rated in common; defaults to 5", v => minCommonRatingCount = v, CommandLineParameterType.Optional);
parser.RegisterParameterHandler("--min-pool-size", "NUM", "Minimum size of the related user pool for a single user; defaults to 5", v => minRelatedUserPoolSize = v, CommandLineParameterType.Optional);
if (!parser.TryParse(args, usagePrefix))
{
return false;
}
RecommenderDataset testDataset = RecommenderDataset.Load(datasetFile);
var trainedModel = MatchboxRecommender.Load<RecommenderDataset, User, Item, RatingDistribution, DummyFeatureSource>(trainedModelFile);
var evaluator = new RecommenderEvaluator<RecommenderDataset, User, Item, int, int, RatingDistribution>(
Mappings.StarRatingRecommender.ForEvaluation());
IDictionary<User, IEnumerable<User>> relatedUsers = evaluator.FindRelatedUsersWhoRatedSameItems(
trainedModel, testDataset, maxRelatedUserCount, minCommonRatingCount, minRelatedUserPoolSize);
RecommenderPersistenceUtils.SaveRelatedUsers(predictionsFile, relatedUsers);
return true;
}
}
}

@ -1,48 +0,0 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
namespace Microsoft.ML.Probabilistic.Learners.Runners
{
using System.Linq;
using Microsoft.ML.Probabilistic.Learners.Mappings;
using Microsoft.ML.Probabilistic.Math;
/// <summary>
/// A command-line module to generate negative data for a positive-only recommender dataset.
/// </summary>
internal class RecommenderGenerateNegativeDataModule : CommandLineModule
{
/// <summary>
/// Runs the module.
/// </summary>
/// <param name="args">The command line arguments for the module.</param>
/// <param name="usagePrefix">The prefix to print before the usage string.</param>
/// <returns>True if the run was successful, false otherwise.</returns>
public override bool Run(string[] args, string usagePrefix)
{
string inputDatasetFile = string.Empty;
string outputDatasetFile = string.Empty;
var parser = new CommandLineParser();
parser.RegisterParameterHandler("--input-data", "FILE", "Input dataset, treated as if all the ratings are positive", v => inputDatasetFile = v, CommandLineParameterType.Required);
parser.RegisterParameterHandler("--output-data", "FILE", "Output dataset with both posisitve and negative data", v => outputDatasetFile = v, CommandLineParameterType.Required);
if (!parser.TryParse(args, usagePrefix))
{
return false;
}
var generatorMapping = Mappings.StarRatingRecommender.WithGeneratedNegativeData();
var inputDataset = RecommenderDataset.Load(inputDatasetFile);
var outputDataset = new RecommenderDataset(
generatorMapping.GetInstances(inputDataset).Select(i => new RatedUserItem(i.User, i.Item, i.Rating)),
generatorMapping.GetRatingInfo(inputDataset));
outputDataset.Save(outputDatasetFile);
return true;
}
}
}
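
The negative-data generation above can likewise be invoked without the command line; a minimal sketch, again with placeholder file names:

namespace Microsoft.ML.Probabilistic.Learners.Runners
{
    using System.Linq;
    using Microsoft.ML.Probabilistic.Learners.Mappings;

    // Sketch mirroring the deleted RecommenderGenerateNegativeDataModule without command-line parsing.
    internal static class GenerateNegativeDataExample
    {
        public static void Run()
        {
            var generatorMapping = Mappings.StarRatingRecommender.WithGeneratedNegativeData();
            var inputDataset = RecommenderDataset.Load("positive-only.dat"); // placeholder path
            var outputDataset = new RecommenderDataset(
                generatorMapping.GetInstances(inputDataset).Select(i => new RatedUserItem(i.User, i.Item, i.Rating)),
                generatorMapping.GetRatingInfo(inputDataset));
            outputDataset.Save("with-negative-data.dat"); // placeholder path
        }
    }
}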

@ -1,46 +0,0 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
namespace Microsoft.ML.Probabilistic.Learners.Runners
{
using System.Collections.Generic;
using RatingDistribution = System.Collections.Generic.IDictionary<int, double>;
/// <summary>
/// A command-line module to predict ratings given a trained recommender model and a dataset.
/// </summary>
internal class RecommenderPredictRatingsModule : CommandLineModule
{
/// <summary>
/// Runs the module.
/// </summary>
/// <param name="args">The command line arguments for the module.</param>
/// <param name="usagePrefix">The prefix to print before the usage string.</param>
/// <returns>True if the run was successful, false otherwise.</returns>
public override bool Run(string[] args, string usagePrefix)
{
string datasetFile = string.Empty;
string trainedModelFile = string.Empty;
string predictionsFile = string.Empty;
var parser = new CommandLineParser();
parser.RegisterParameterHandler("--data", "FILE", "Dataset to make predictions for", v => datasetFile = v, CommandLineParameterType.Required);
parser.RegisterParameterHandler("--model", "FILE", "File with trained model", v => trainedModelFile = v, CommandLineParameterType.Required);
parser.RegisterParameterHandler("--predictions", "FILE", "File with generated predictions", v => predictionsFile = v, CommandLineParameterType.Required);
if (!parser.TryParse(args, usagePrefix))
{
return false;
}
RecommenderDataset testDataset = RecommenderDataset.Load(datasetFile);
var trainedModel = MatchboxRecommender.Load<RecommenderDataset, User, Item, RatingDistribution, DummyFeatureSource>(trainedModelFile);
IDictionary<User, IDictionary<Item, int>> predictions = trainedModel.Predict(testDataset);
RecommenderPersistenceUtils.SavePredictedRatings(predictionsFile, predictions);
return true;
}
}
}

@ -1,55 +0,0 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
namespace Microsoft.ML.Probabilistic.Learners.Runners
{
using System.Collections.Generic;
using Microsoft.ML.Probabilistic.Learners.Mappings;
using RatingDistribution = System.Collections.Generic.IDictionary<int, double>;
/// <summary>
/// A command-line module to recommend items given a trained recommender model and a dataset.
/// </summary>
internal class RecommenderRecommendItemsModule : CommandLineModule
{
/// <summary>
/// Runs the module.
/// </summary>
/// <param name="args">The command line arguments for the module.</param>
/// <param name="usagePrefix">The prefix to print before the usage string.</param>
/// <returns>True if the run was successful, false otherwise.</returns>
public override bool Run(string[] args, string usagePrefix)
{
string datasetFile = string.Empty;
string trainedModelFile = string.Empty;
string predictionsFile = string.Empty;
int maxRecommendedItemCount = 5;
int minRecommendationPoolSize = 5;
var parser = new CommandLineParser();
parser.RegisterParameterHandler("--data", "FILE", "Dataset to make predictions for", v => datasetFile = v, CommandLineParameterType.Required);
parser.RegisterParameterHandler("--model", "FILE", "File with trained model", v => trainedModelFile = v, CommandLineParameterType.Required);
parser.RegisterParameterHandler("--predictions", "FILE", "File with generated predictions", v => predictionsFile = v, CommandLineParameterType.Required);
parser.RegisterParameterHandler("--max-items", "NUM", "Maximum number of items to recommend; defaults to 5", v => maxRecommendedItemCount = v, CommandLineParameterType.Optional);
parser.RegisterParameterHandler("--min-pool-size", "NUM", "Minimum size of the recommendation pool for a single user; defaults to 5", v => minRecommendationPoolSize = v, CommandLineParameterType.Optional);
if (!parser.TryParse(args, usagePrefix))
{
return false;
}
RecommenderDataset testDataset = RecommenderDataset.Load(datasetFile);
var trainedModel = MatchboxRecommender.Load<RecommenderDataset, User, Item, RatingDistribution, DummyFeatureSource>(trainedModelFile);
var evaluator = new RecommenderEvaluator<RecommenderDataset, User, Item, int, int, RatingDistribution>(
Mappings.StarRatingRecommender.ForEvaluation());
IDictionary<User, IEnumerable<Item>> itemRecommendations = evaluator.RecommendRatedItems(
trainedModel, testDataset, maxRecommendedItemCount, minRecommendationPoolSize);
RecommenderPersistenceUtils.SaveRecommendedItems(predictionsFile, itemRecommendations);
return true;
}
}
}

@ -1,73 +0,0 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
namespace Microsoft.ML.Probabilistic.Learners.Runners
{
using Microsoft.ML.Probabilistic.Learners.Mappings;
using Microsoft.ML.Probabilistic.Math;
/// <summary>
/// A command-line module to split a given recommendation dataset into training and test parts.
/// </summary>
internal class RecommenderSplitDataModule : CommandLineModule
{
/// <summary>
/// Runs the module.
/// </summary>
/// <param name="args">The command line arguments for the module.</param>
/// <param name="usagePrefix">The prefix to print before the usage string.</param>
/// <returns>True if the run was successful, false otherwise.</returns>
public override bool Run(string[] args, string usagePrefix)
{
string inputDatasetFile = string.Empty;
string outputTrainingDatasetFile = string.Empty;
string outputTestDatasetFile = string.Empty;
double trainingOnlyUserFraction = 0.5;
double testUserRatingTrainingFraction = 0.25;
double coldUserFraction = 0;
double coldItemFraction = 0;
double ignoredUserFraction = 0;
double ignoredItemFraction = 0;
bool removeOccasionalColdItems = false;
var parser = new CommandLineParser();
parser.RegisterParameterHandler("--input-data", "FILE", "Dataset to split", v => inputDatasetFile = v, CommandLineParameterType.Required);
parser.RegisterParameterHandler("--output-data-train", "FILE", "Training part of the split dataset", v => outputTrainingDatasetFile = v, CommandLineParameterType.Required);
parser.RegisterParameterHandler("--output-data-test", "FILE", "Test part of the split dataset", v => outputTestDatasetFile = v, CommandLineParameterType.Required);
parser.RegisterParameterHandler("--training-users", "NUM", "Fraction of training-only users; defaults to 0.5", (double v) => trainingOnlyUserFraction = v, CommandLineParameterType.Optional);
parser.RegisterParameterHandler("--test-user-training-ratings", "NUM", "Fraction of test user ratings for training; defaults to 0.25", (double v) => testUserRatingTrainingFraction = v, CommandLineParameterType.Optional);
parser.RegisterParameterHandler("--cold-users", "NUM", "Fraction of cold (test-only) users; defaults to 0", (double v) => coldUserFraction = v, CommandLineParameterType.Optional);
parser.RegisterParameterHandler("--cold-items", "NUM", "Fraction of cold (test-only) items; defaults to 0", (double v) => coldItemFraction = v, CommandLineParameterType.Optional);
parser.RegisterParameterHandler("--ignored-users", "NUM", "Fraction of ignored users; defaults to 0", (double v) => ignoredUserFraction = v, CommandLineParameterType.Optional);
parser.RegisterParameterHandler("--ignored-items", "NUM", "Fraction of ignored items; defaults to 0", (double v) => ignoredItemFraction = v, CommandLineParameterType.Optional);
parser.RegisterParameterHandler("--remove-occasional-cold-items", "Remove occasionally produced cold items", () => removeOccasionalColdItems = true);
if (!parser.TryParse(args, usagePrefix))
{
return false;
}
var splittingMapping = Mappings.StarRatingRecommender.SplitToTrainTest(
trainingOnlyUserFraction,
testUserRatingTrainingFraction,
coldUserFraction,
coldItemFraction,
ignoredUserFraction,
ignoredItemFraction,
removeOccasionalColdItems);
var inputDataset = RecommenderDataset.Load(inputDatasetFile);
var outputTrainingDataset = new RecommenderDataset(
splittingMapping.GetInstances(SplitInstanceSource.Training(inputDataset)),
inputDataset.StarRatingInfo);
outputTrainingDataset.Save(outputTrainingDatasetFile);
var outputTestDataset = new RecommenderDataset(
splittingMapping.GetInstances(SplitInstanceSource.Test(inputDataset)),
inputDataset.StarRatingInfo);
outputTestDataset.Save(outputTestDatasetFile);
return true;
}
}
}
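
Without the command-line wrapper, the train/test split above reduces to the following sketch; the fractions shown are the module's former defaults and the file names are placeholders.

namespace Microsoft.ML.Probabilistic.Learners.Runners
{
    using Microsoft.ML.Probabilistic.Learners.Mappings;

    // Sketch mirroring the deleted RecommenderSplitDataModule without command-line parsing.
    internal static class SplitDataExample
    {
        public static void Run()
        {
            var splittingMapping = Mappings.StarRatingRecommender.SplitToTrainTest(
                0.5,    // fraction of training-only users
                0.25,   // fraction of test-user ratings used for training
                0,      // fraction of cold (test-only) users
                0,      // fraction of cold (test-only) items
                0,      // fraction of ignored users
                0,      // fraction of ignored items
                false); // do not remove occasionally produced cold items
            var inputDataset = RecommenderDataset.Load("ratings.dat"); // placeholder path
            var trainingDataset = new RecommenderDataset(
                splittingMapping.GetInstances(SplitInstanceSource.Training(inputDataset)),
                inputDataset.StarRatingInfo);
            trainingDataset.Save("ratings-train.dat"); // placeholder path
            var testDataset = new RecommenderDataset(
                splittingMapping.GetInstances(SplitInstanceSource.Test(inputDataset)),
                inputDataset.StarRatingInfo);
            testDataset.Save("ratings-test.dat"); // placeholder path
        }
    }
}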

@ -1,56 +0,0 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
namespace Microsoft.ML.Probabilistic.Learners.Runners
{
/// <summary>
/// A command-line module to train a recommendation model on a given dataset.
/// </summary>
internal class RecommenderTrainModule : CommandLineModule
{
/// <summary>
/// Runs the module.
/// </summary>
/// <param name="args">The command line arguments for the module.</param>
/// <param name="usagePrefix">The prefix to print before the usage string.</param>
/// <returns>True if the run was successful, false otherwise.</returns>
public override bool Run(string[] args, string usagePrefix)
{
string trainingDatasetFile = string.Empty;
string trainedModelFile = string.Empty;
int traitCount = MatchboxRecommenderTrainingSettings.TraitCountDefault;
int iterationCount = MatchboxRecommenderTrainingSettings.IterationCountDefault;
int batchCount = MatchboxRecommenderTrainingSettings.BatchCountDefault;
bool useUserFeatures = MatchboxRecommenderTrainingSettings.UseUserFeaturesDefault;
bool useItemFeatures = MatchboxRecommenderTrainingSettings.UseItemFeaturesDefault;
var parser = new CommandLineParser();
parser.RegisterParameterHandler("--training-data", "FILE", "Training dataset", v => trainingDatasetFile = v, CommandLineParameterType.Required);
parser.RegisterParameterHandler("--trained-model", "FILE", "Trained model file", v => trainedModelFile = v, CommandLineParameterType.Required);
parser.RegisterParameterHandler("--traits", "NUM", "Number of traits (defaults to " + traitCount + ")", v => traitCount = v, CommandLineParameterType.Optional);
parser.RegisterParameterHandler("--iterations", "NUM", "Number of inference iterations (defaults to " + iterationCount + ")", v => iterationCount = v, CommandLineParameterType.Optional);
parser.RegisterParameterHandler("--batches", "NUM", "Number of batches to split the training data into (defaults to " + batchCount + ")", v => batchCount = v, CommandLineParameterType.Optional);
parser.RegisterParameterHandler("--use-user-features", "Use user features in the model (defaults to " + useUserFeatures + ")", () => useUserFeatures = true);
parser.RegisterParameterHandler("--use-item-features", "Use item features in the model (defaults to " + useItemFeatures + ")", () => useItemFeatures = true);
if (!parser.TryParse(args, usagePrefix))
{
return false;
}
RecommenderDataset trainingDataset = RecommenderDataset.Load(trainingDatasetFile);
var recommender = MatchboxRecommender.Create(Mappings.StarRatingRecommender);
recommender.Settings.Training.TraitCount = traitCount;
recommender.Settings.Training.IterationCount = iterationCount;
recommender.Settings.Training.BatchCount = batchCount;
recommender.Settings.Training.UseUserFeatures = useUserFeatures;
recommender.Settings.Training.UseItemFeatures = useItemFeatures;
recommender.Train(trainingDataset);
recommender.Save(trainedModelFile);
return true;
}
}
}
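
Training and rating prediction, previously driven by the RecommenderTrainModule above and the RecommenderPredictRatingsModule removed earlier, can be combined into one short C# sketch; the settings shown are the library defaults and the file names are placeholders.

namespace Microsoft.ML.Probabilistic.Learners.Runners
{
    using System.Collections.Generic;

    using RatingDistribution = System.Collections.Generic.IDictionary<int, double>;

    // Sketch combining the deleted train and predict-ratings modules without command-line parsing.
    internal static class TrainAndPredictExample
    {
        public static void Run()
        {
            // Train with the default Matchbox settings and persist the model.
            RecommenderDataset trainingDataset = RecommenderDataset.Load("ratings-train.dat"); // placeholder path
            var recommender = MatchboxRecommender.Create(Mappings.StarRatingRecommender);
            recommender.Settings.Training.TraitCount = MatchboxRecommenderTrainingSettings.TraitCountDefault;
            recommender.Settings.Training.IterationCount = MatchboxRecommenderTrainingSettings.IterationCountDefault;
            recommender.Train(trainingDataset);
            recommender.Save("model.bin"); // placeholder path

            // Reload the model and predict ratings for a test dataset.
            RecommenderDataset testDataset = RecommenderDataset.Load("ratings-test.dat"); // placeholder path
            var trainedModel = MatchboxRecommender.Load<RecommenderDataset, User, Item, RatingDistribution, DummyFeatureSource>("model.bin");
            IDictionary<User, IDictionary<Item, int>> predictions = trainedModel.Predict(testDataset);
            RecommenderPersistenceUtils.SavePredictedRatings("predictions.dat", predictions); // placeholder path
        }
    }
}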

@ -147,7 +147,7 @@ namespace Microsoft.ML.Probabilistic.Collections
FileStats.AddRead();
string path = prefix + index.ToString(CultureInfo.InvariantCulture) + ".bin";
if (!File.Exists(path)) return default(T);
IFormatter formatter = new BinaryFormatter();
IFormatter formatter = new BinaryFormatter(); // This is only used within the runtime for caching temporary files -- it is not used to persist data and no untrusted data is read.
using (var stream = new FileStream(prefix + index.ToString(CultureInfo.InvariantCulture) + ".bin", FileMode.Open, FileAccess.Read, FileShare.Read))
{
return (T) formatter.Deserialize(stream);
@ -163,7 +163,7 @@ namespace Microsoft.ML.Probabilistic.Collections
if (object.ReferenceEquals(value, null) || value.Equals(default(T))) File.Delete(path);
else
{
IFormatter formatter = new BinaryFormatter();
IFormatter formatter = new BinaryFormatter(); // This is only used within the runtime for caching temporary files -- it is not used to persist data and no untrusted data is read.
using (var stream = new FileStream(path, FileMode.Create, FileAccess.Write, FileShare.None))
{
formatter.Serialize(stream, value);
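
The comments added above note that the remaining BinaryFormatter usage is confined to FileArray's temporary cache files. Purely for illustration, and not as part of this change, a BinaryFormatter-free variant of the same read/write pattern could look like the sketch below, assuming the cached element type can be handled by DataContractSerializer.

// Hypothetical alternative, not part of this commit: a temporary-file cache without BinaryFormatter.
// Assumes T is serializable by DataContractSerializer.
using System.IO;
using System.Runtime.Serialization;

internal static class TempFileCache<T>
{
    public static void Write(string path, T value)
    {
        var serializer = new DataContractSerializer(typeof(T));
        using (var stream = new FileStream(path, FileMode.Create, FileAccess.Write, FileShare.None))
        {
            serializer.WriteObject(stream, value);
        }
    }

    public static T Read(string path)
    {
        if (!File.Exists(path)) return default(T);
        var serializer = new DataContractSerializer(typeof(T));
        using (var stream = new FileStream(path, FileMode.Open, FileAccess.Read, FileShare.Read))
        {
            return (T)serializer.ReadObject(stream);
        }
    }
}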

@ -90,21 +90,6 @@ namespace Microsoft.ML.Probabilistic.Serialization
return deserializedVersion;
}
/// <summary>
/// Reads an object from the reader of a binary stream.
/// </summary>
/// <param name="reader">The reader of the binary stream.</param>
/// <returns>The object constructed from the binary stream.</returns>
public static T ReadObject<T>(this BinaryReader reader)
{
if (reader == null)
{
throw new ArgumentNullException(nameof(reader));
}
return ByteArrayToObject<T>(reader.ReadByteArray());
}
/// <summary>
/// Reads a <see cref="Guid"/> from the reader of a binary stream.
/// </summary>
@ -260,27 +245,5 @@ namespace Microsoft.ML.Probabilistic.Serialization
return new GammaArray(length, i => ReadGamma(reader));
}
#region Helper methods
/// <summary>
/// Converts a byte array to an object.
/// </summary>
/// <param name="array">The byte array to convert.</param>
/// <returns>The object for the specified byte array.</returns>
private static T ByteArrayToObject<T>(byte[] array)
{
Debug.Assert(array != null, "The array must not be null.");
using (var memoryStream = new MemoryStream())
{
var binaryFormatter = new BinaryFormatter();
memoryStream.Write(array, 0, array.Length);
memoryStream.Seek(0, SeekOrigin.Begin);
return (T)binaryFormatter.Deserialize(memoryStream);
}
}
#endregion
}
}

@ -19,28 +19,6 @@ namespace Microsoft.ML.Probabilistic.Serialization
/// </summary>
public static class BinaryWriterExtensions
{
/// <summary>
/// Writes the specified object in binary to the stream.
/// </summary>
/// <param name="writer">The binary writer.</param>
/// <param name="obj">The object to write to the stream.</param>
public static void WriteObject(this BinaryWriter writer, object obj)
{
if (writer == null)
{
throw new ArgumentNullException(nameof(writer));
}
if (obj == null)
{
throw new ArgumentNullException(nameof(obj));
}
byte[] array = ObjectToByteArray(obj);
writer.Write(array.Length);
writer.Write(array);
}
/// <summary>
/// Writes the specified <see cref="Guid"/> in binary to the stream.
/// </summary>
@ -198,26 +176,5 @@ namespace Microsoft.ML.Probabilistic.Serialization
writer.Write(element);
}
}
#region Helper methods
/// <summary>
/// Converts an object to a byte array.
/// </summary>
/// <param name="obj">The object to convert.</param>
/// <returns>The byte array of the specified object.</returns>
private static byte[] ObjectToByteArray(object obj)
{
Debug.Assert(obj != null, "The object must not be null.");
using (var memoryStream = new MemoryStream())
{
var binaryFormatter = new BinaryFormatter();
binaryFormatter.Serialize(memoryStream, obj);
return memoryStream.ToArray();
}
}
#endregion
}
}
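
With the object-graph WriteObject/ReadObject extensions above removed, persistence goes through explicitly typed writes and reads. A generic illustration of that pattern (not code taken from the library) is:

// Generic illustration of typed binary IO in place of object-graph serialization; not library code.
using System;
using System.IO;
using System.Text;

internal static class TypedBinaryIoExample
{
    public static void Run()
    {
        Guid id = Guid.NewGuid();
        double[] values = { 1.5, 2.5, 3.5 };

        using (var stream = new MemoryStream())
        {
            using (var writer = new BinaryWriter(stream, Encoding.UTF8, leaveOpen: true))
            {
                // Write each field with its concrete type instead of serializing an object graph.
                writer.Write(id.ToByteArray());
                writer.Write(values.Length);
                foreach (double value in values) writer.Write(value);
            }

            stream.Position = 0;
            using (var reader = new BinaryReader(stream))
            {
                var readId = new Guid(reader.ReadBytes(16));
                int count = reader.ReadInt32();
                var readValues = new double[count];
                for (int i = 0; i < count; i++) readValues[i] = reader.ReadDouble();
            }
        }
    }
}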

@ -41,7 +41,7 @@ namespace Microsoft.ML.Probabilistic.Serialization
public T ReadObject<T>()
{
return binaryReader.ReadObject<T>();
throw new NotImplementedException("This reader cannot read objects");
}
public string ReadString()

@ -48,7 +48,7 @@ namespace Microsoft.ML.Probabilistic.Serialization
public void WriteObject(object value)
{
binaryWriter.WriteObject(value);
throw new NotImplementedException("This writer cannot write objects.");
}
}
}

@ -179,7 +179,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
#region Serialization
/// <summary>
/// Constructor used by Json and BinaryFormatter serializers. Informally needed to be
/// Constructor used by Json serializer. Informally needed to be
/// implemented for <see cref="ISerializable"/> interface.
/// </summary>
internal DataContainer(SerializationInfo info, StreamingContext context)

@ -1800,7 +1800,7 @@ namespace Microsoft.ML.Probabilistic.Distributions
new ImmutableDiscreteChar(Storage.Read(readInt32, readDouble));
/// <summary>
/// Constructor used during deserialization by Newtonsoft.Json and BinaryFormatter.
/// Constructor used during deserialization by Newtonsoft.Json.
/// </summary>
private ImmutableDiscreteChar(SerializationInfo info, StreamingContext context) =>
this.data_ = Storage.FromSerializationInfo(info);
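
These ISerializable constructors are now exercised only by Newtonsoft.Json. In the spirit of the CloneJsonNet test helper further below, a minimal round-trip helper might look as follows; the TypeNameHandling setting is an assumption rather than the exact settings used by the tests.

// Minimal Json.NET round-trip sketch; the serializer settings are an assumption.
using Newtonsoft.Json;

internal static class JsonCloneExample
{
    public static T Clone<T>(T obj)
    {
        var settings = new JsonSerializerSettings { TypeNameHandling = TypeNameHandling.Auto };
        string json = JsonConvert.SerializeObject(obj, settings);
        return JsonConvert.DeserializeObject<T>(json, settings);
    }
}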

@ -520,22 +520,6 @@ namespace Microsoft.ML.Probabilistic.Learners.Tests
CheckPredictedBernoulliDistributionNativeTestingDataset);
}
/// <summary>
/// Tests serialization and deserialization of the binary Bayes point machine classifier for data in native format and
/// features in a dense representation.
/// </summary>
[Fact]
public void DenseBinaryNativeSerializationRegressionTest()
{
TestRegressionSerialization(
BayesPointMachineClassifier.CreateBinaryClassifier(this.binaryNativeMapping),
this.denseNativeTrainingData,
this.denseNativePredictionData,
this.expectedPredictiveBernoulliDistributions,
this.expectedIncrementalPredictiveBernoulliDistributions,
CheckPredictedBernoulliDistributionNativeTestingDataset);
}
/// <summary>
/// Tests correctness of training of the binary Bayes point machine classifier for data in native format and
/// features in a dense representation.
@ -820,22 +804,6 @@ namespace Microsoft.ML.Probabilistic.Learners.Tests
CheckPredictedBernoulliDistributionNativeTestingDataset);
}
/// <summary>
/// Tests serialization and deserialization of the binary Bayes point machine classifier for data in native format and
/// features in a sparse representation.
/// </summary>
[Fact]
public void SparseBinaryNativeSerializationRegressionTest()
{
TestRegressionSerialization(
BayesPointMachineClassifier.CreateBinaryClassifier(this.binaryNativeMapping),
this.sparseNativeTrainingData,
this.sparseNativePredictionData,
this.expectedPredictiveBernoulliDistributions,
this.expectedIncrementalPredictiveBernoulliDistributions,
CheckPredictedBernoulliDistributionNativeTestingDataset);
}
/// <summary>
/// Tests correctness of training of the binary Bayes point machine classifier for data in native format and
/// features in a sparse representation.
@ -1079,22 +1047,6 @@ namespace Microsoft.ML.Probabilistic.Learners.Tests
CheckPredictedDiscreteDistributionNativeTestingDataset);
}
/// <summary>
/// Tests serialization and deserialization of the multi-class Bayes point machine classifier for data in native format and
/// features in a dense representation.
/// </summary>
[Fact]
public void DenseMulticlassNativeSerializationRegressionTest()
{
TestRegressionSerialization(
BayesPointMachineClassifier.CreateMulticlassClassifier(this.multiclassNativeMapping),
this.denseNativeTrainingData,
this.denseNativePredictionData,
this.expectedPredictiveDiscreteDistributions,
this.expectedIncrementalPredictiveDiscreteDistributions,
CheckPredictedDiscreteDistributionNativeTestingDataset);
}
/// <summary>
/// Tests correctness of training of the multi-class Bayes point machine classifier for data in native format and
/// features in a dense representation.
@ -1361,22 +1313,6 @@ namespace Microsoft.ML.Probabilistic.Learners.Tests
CheckPredictedDiscreteDistributionNativeTestingDataset);
}
/// <summary>
/// Tests serialization and deserialization of the multi-class Bayes point machine classifier for data in native format and
/// features in a sparse representation.
/// </summary>
[Fact]
public void SparseMulticlassNativeSerializationRegressionTest()
{
TestRegressionSerialization(
BayesPointMachineClassifier.CreateMulticlassClassifier(this.multiclassNativeMapping),
this.sparseNativeTrainingData,
this.sparseNativePredictionData,
this.expectedPredictiveDiscreteDistributions,
this.expectedIncrementalPredictiveDiscreteDistributions,
CheckPredictedDiscreteDistributionNativeTestingDataset);
}
/// <summary>
/// Tests correctness of training of the multi-class Bayes point machine classifier for data in native format and
/// features in a sparse representation.
@ -1619,22 +1555,6 @@ namespace Microsoft.ML.Probabilistic.Learners.Tests
CheckPredictedBernoulliDistributionSimpleTestingDataset);
}
/// <summary>
/// Tests serialization and deserialization of the binary Bayes point machine classifier for data in standard format and
/// features in a dense representation.
/// </summary>
[Fact]
public void DenseBinaryStandardSerializationRegressionTest()
{
TestRegressionSerialization(
BayesPointMachineClassifier.CreateBinaryClassifier(this.binaryStandardMapping),
this.denseStandardTrainingData,
this.denseStandardPredictionData,
this.expectedPredictiveBernoulliStandardDistributions,
this.expectedIncrementalPredictiveBernoulliStandardDistributions,
CheckPredictedBernoulliDistributionStandardTestingDataset);
}
/// <summary>
/// Tests correctness of training of the binary Bayes point machine classifier for data in standard format and
/// features in a dense representation.
@ -1851,22 +1771,6 @@ namespace Microsoft.ML.Probabilistic.Learners.Tests
CheckPredictedBernoulliDistributionStandardTestingDataset);
}
/// <summary>
/// Tests serialization and deserialization of the binary Bayes point machine classifier for data in standard format and
/// features in a sparse representation.
/// </summary>
[Fact]
public void SparseBinaryStandardSerializationRegressionTest()
{
TestRegressionSerialization(
BayesPointMachineClassifier.CreateBinaryClassifier(this.binaryStandardMapping),
this.sparseStandardTrainingData,
this.sparseStandardPredictionData,
this.expectedPredictiveBernoulliStandardDistributions,
this.expectedIncrementalPredictiveBernoulliStandardDistributions,
CheckPredictedBernoulliDistributionStandardTestingDataset);
}
/// <summary>
/// Tests correctness of training of the binary Bayes point machine classifier for data in standard format and
/// features in a sparse representation.
@ -2074,22 +1978,6 @@ namespace Microsoft.ML.Probabilistic.Learners.Tests
CheckPredictedDiscreteDistributionStandardTestingDataset);
}
/// <summary>
/// Tests serialization and deserialization of the multi-class Bayes point machine classifier for data in standard format and
/// features in a dense representation.
/// </summary>
[Fact]
public void DenseMulticlassStandardSerializationRegressionTest()
{
TestRegressionSerialization(
BayesPointMachineClassifier.CreateMulticlassClassifier(this.multiclassStandardMapping),
this.denseStandardTrainingData,
this.denseStandardPredictionData,
this.expectedPredictiveDiscreteStandardDistributions,
this.expectedIncrementalPredictiveDiscreteStandardDistributions,
CheckPredictedDiscreteDistributionStandardTestingDataset);
}
/// <summary>
/// Tests correctness of training of the multi-class Bayes point machine classifier for data in standard format and
/// features in a dense representation.
@ -2316,22 +2204,6 @@ namespace Microsoft.ML.Probabilistic.Learners.Tests
CheckPredictedDiscreteDistributionStandardTestingDataset);
}
/// <summary>
/// Tests serialization and deserialization of the multi-class Bayes point machine classifier for data in standard format and
/// features in a sparse representation.
/// </summary>
[Fact]
public void SparseMulticlassStandardSerializationRegressionTest()
{
TestRegressionSerialization(
BayesPointMachineClassifier.CreateMulticlassClassifier(this.multiclassStandardMapping),
this.sparseStandardTrainingData,
this.sparseStandardPredictionData,
this.expectedPredictiveDiscreteStandardDistributions,
this.expectedIncrementalPredictiveDiscreteStandardDistributions,
CheckPredictedDiscreteDistributionStandardTestingDataset);
}
/// <summary>
/// Tests correctness of training of the multi-class Bayes point machine classifier for data in standard format and
/// features in a sparse representation.
@ -2569,22 +2441,6 @@ namespace Microsoft.ML.Probabilistic.Learners.Tests
CheckPredictedBernoulliDistributionNativeTestingDataset);
}
/// <summary>
/// Tests serialization and deserialization of the binary Bayes point machine classifier with <see cref="Gaussian"/> prior distributions
/// over weights for data in native format and features in a dense representation.
/// </summary>
[Fact]
public void GaussianDenseBinaryNativeSerializationRegressionTest()
{
TestRegressionSerialization(
BayesPointMachineClassifier.CreateGaussianPriorBinaryClassifier(this.binaryNativeMapping),
this.denseNativeTrainingData,
this.denseNativePredictionData,
this.gaussianPriorExpectedPredictiveBernoulliDistributions,
this.gaussianPriorExpectedIncrementalPredictiveBernoulliDistributions,
CheckPredictedBernoulliDistributionNativeTestingDataset);
}
/// <summary>
/// Tests correctness of training of the binary Bayes point machine classifier with <see cref="Gaussian"/> prior distributions
/// over weights for data in native format and features in a dense representation.
@ -2839,22 +2695,6 @@ namespace Microsoft.ML.Probabilistic.Learners.Tests
CheckPredictedBernoulliDistributionNativeTestingDataset);
}
/// <summary>
/// Tests serialization and deserialization of the binary Bayes point machine classifier with <see cref="Gaussian"/> prior distributions
/// over weights for data in native format and features in a sparse representation.
/// </summary>
[Fact]
public void GaussianSparseBinaryNativeSerializationRegressionTest()
{
TestRegressionSerialization(
BayesPointMachineClassifier.CreateGaussianPriorBinaryClassifier(this.binaryNativeMapping),
this.sparseNativeTrainingData,
this.sparseNativePredictionData,
this.gaussianPriorExpectedPredictiveBernoulliDistributions,
this.gaussianPriorExpectedIncrementalPredictiveBernoulliDistributions,
CheckPredictedBernoulliDistributionNativeTestingDataset);
}
/// <summary>
/// Tests correctness of training of the binary Bayes point machine classifier with <see cref="Gaussian"/> prior distributions
/// over weights for data in native format and features in a sparse representation.
@ -3099,22 +2939,6 @@ namespace Microsoft.ML.Probabilistic.Learners.Tests
CheckPredictedDiscreteDistributionNativeTestingDataset);
}
/// <summary>
/// Tests serialization and deserialization of the multi-class Bayes point machine classifier with <see cref="Gaussian"/> prior distributions
/// over weights for data in native format and features in a dense representation.
/// </summary>
[Fact]
public void GaussianDenseMulticlassNativeSerializationRegressionTest()
{
TestRegressionSerialization(
BayesPointMachineClassifier.CreateGaussianPriorMulticlassClassifier(this.multiclassNativeMapping),
this.denseNativeTrainingData,
this.denseNativePredictionData,
this.gaussianPriorExpectedPredictiveDiscreteDistributions,
this.gaussianPriorExpectedIncrementalPredictiveDiscreteDistributions,
CheckPredictedDiscreteDistributionNativeTestingDataset);
}
/// <summary>
/// Tests correctness of training of the multi-class Bayes point machine classifier with <see cref="Gaussian"/> prior distributions
/// over weights for data in native format and features in a dense representation.
@ -3382,22 +3206,6 @@ namespace Microsoft.ML.Probabilistic.Learners.Tests
CheckPredictedDiscreteDistributionNativeTestingDataset);
}
/// <summary>
/// Tests serialization and deserialization of the multi-class Bayes point machine classifier with <see cref="Gaussian"/> prior distributions
/// over weights for data in native format and features in a sparse representation.
/// </summary>
[Fact]
public void GaussianSparseMulticlassNativeSerializationRegressionTest()
{
TestRegressionSerialization(
BayesPointMachineClassifier.CreateGaussianPriorMulticlassClassifier(this.multiclassNativeMapping),
this.sparseNativeTrainingData,
this.sparseNativePredictionData,
this.gaussianPriorExpectedPredictiveDiscreteDistributions,
this.gaussianPriorExpectedIncrementalPredictiveDiscreteDistributions,
CheckPredictedDiscreteDistributionNativeTestingDataset);
}
/// <summary>
/// Tests correctness of training of the multi-class Bayes point machine classifier with <see cref="Gaussian"/> prior distributions
/// over weights for data in native format and features in a sparse representation.
@ -3593,22 +3401,6 @@ namespace Microsoft.ML.Probabilistic.Learners.Tests
CheckPredictedBernoulliDistributionStandardTestingDataset);
}
/// <summary>
/// Tests serialization and deserialization of the binary Bayes point machine classifier with <see cref="Gaussian"/> prior distributions
/// over weights for data in standard format and features in a dense representation.
/// </summary>
[Fact]
public void GaussianDenseBinaryStandardSerializationRegressionTest()
{
TestRegressionSerialization(
BayesPointMachineClassifier.CreateGaussianPriorBinaryClassifier(this.binaryStandardMapping),
this.denseStandardTrainingData,
this.denseStandardPredictionData,
this.gaussianPriorExpectedPredictiveBernoulliStandardDistributions,
this.gaussianPriorExpectedIncrementalPredictiveBernoulliStandardDistributions,
CheckPredictedBernoulliDistributionStandardTestingDataset);
}
/// <summary>
/// Tests correctness of training of the binary Bayes point machine classifier with <see cref="Gaussian"/> prior distributions
/// over weights for data in standard format and features in a dense representation.
@ -3826,22 +3618,6 @@ namespace Microsoft.ML.Probabilistic.Learners.Tests
CheckPredictedBernoulliDistributionStandardTestingDataset);
}
/// <summary>
/// Tests serialization and deserialization of the binary Bayes point machine classifier with <see cref="Gaussian"/> prior distributions
/// over weights for data in standard format and features in a sparse representation.
/// </summary>
[Fact]
public void GaussianSparseBinaryStandardSerializationRegressionTest()
{
TestRegressionSerialization(
BayesPointMachineClassifier.CreateGaussianPriorBinaryClassifier(this.binaryStandardMapping),
this.sparseStandardTrainingData,
this.sparseStandardPredictionData,
this.gaussianPriorExpectedPredictiveBernoulliStandardDistributions,
this.gaussianPriorExpectedIncrementalPredictiveBernoulliStandardDistributions,
CheckPredictedBernoulliDistributionStandardTestingDataset);
}
/// <summary>
/// Tests correctness of training of the binary Bayes point machine classifier with <see cref="Gaussian"/> prior distributions
/// over weights for data in standard format and features in a sparse representation.
@ -4050,22 +3826,6 @@ namespace Microsoft.ML.Probabilistic.Learners.Tests
CheckPredictedDiscreteDistributionStandardTestingDataset);
}
/// <summary>
/// Tests serialization and deserialization of the multi-class Bayes point machine classifier with <see cref="Gaussian"/> prior distributions
/// over weights for data in standard format and features in a dense representation.
/// </summary>
[Fact]
public void GaussianDenseMulticlassStandardSerializationRegressionTest()
{
TestRegressionSerialization(
BayesPointMachineClassifier.CreateGaussianPriorMulticlassClassifier(this.multiclassStandardMapping),
this.denseStandardTrainingData,
this.denseStandardPredictionData,
this.gaussianPriorExpectedPredictiveDiscreteStandardDistributions,
this.gaussianPriorExpectedIncrementalPredictiveDiscreteStandardDistributions,
CheckPredictedDiscreteDistributionStandardTestingDataset);
}
/// <summary>
/// Tests correctness of training of the multi-class Bayes point machine classifier with <see cref="Gaussian"/> prior distributions
/// over weights for data in standard format and features in a dense representation.
@ -4293,22 +4053,6 @@ namespace Microsoft.ML.Probabilistic.Learners.Tests
CheckPredictedDiscreteDistributionStandardTestingDataset);
}
/// <summary>
/// Tests serialization and deserialization of the multi-class Bayes point machine classifier with <see cref="Gaussian"/> prior distributions
/// over weights for data in standard format and features in a sparse representation.
/// </summary>
[Fact]
public void GaussianSparseMulticlassStandardSerializationRegressionTest()
{
TestRegressionSerialization(
BayesPointMachineClassifier.CreateGaussianPriorMulticlassClassifier(this.multiclassStandardMapping),
this.sparseStandardTrainingData,
this.sparseStandardPredictionData,
this.gaussianPriorExpectedPredictiveDiscreteStandardDistributions,
this.gaussianPriorExpectedIncrementalPredictiveDiscreteStandardDistributions,
CheckPredictedDiscreteDistributionStandardTestingDataset);
}
/// <summary>
/// Tests correctness of training of the multi-class Bayes point machine classifier with <see cref="Gaussian"/> prior distributions
/// over weights for data in standard format and features in a sparse representation.
@ -4988,57 +4732,6 @@ namespace Microsoft.ML.Probabilistic.Learners.Tests
});
}
/// <summary>
/// Tests .NET serialization and deserialization of the Bayes point machine classifier.
/// </summary>
/// <typeparam name="TInstanceSource">The type of a source of instances.</typeparam>
/// <typeparam name="TInstance">The type of an instance.</typeparam>
/// <typeparam name="TLabelSource">The type of a source of labels.</typeparam>
/// <typeparam name="TLabel">The type of a label.</typeparam>
/// <typeparam name="TLabelDistribution">The type of a distribution over labels.</typeparam>
/// <typeparam name="TTrainingSettings">The type of the settings for training.</typeparam>
/// <param name="classifier">The Bayes point machine classifier.</param>
/// <param name="trainingData">The training data.</param>
/// <param name="testData">The prediction data.</param>
/// <param name="expectedLabelDistributions">The expected label distributions.</param>
/// <param name="expectedIncrementalLabelDistributions">The expected label distributions for incremental training.</param>
/// <param name="checkPrediction">A method which asserts the equality of expected and predicted distributions.</param>
private static void TestRegressionSerialization<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, TTrainingSettings>(
IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, TTrainingSettings, IBayesPointMachineClassifierPredictionSettings<TLabel>> classifier,
TInstanceSource trainingData,
TInstanceSource testData,
IEnumerable<TLabelDistribution> expectedLabelDistributions,
IEnumerable<TLabelDistribution> expectedIncrementalLabelDistributions,
Action<IEnumerable<TLabelDistribution>, IEnumerable<TLabelDistribution>, double> checkPrediction)
where TTrainingSettings : BayesPointMachineClassifierTrainingSettings
{
const string TrainedFileName = "trainedClassifier.bin";
const string UntrainedFileName = "untrainedClassifier.bin";
// Train and serialize
classifier.Settings.Training.IterationCount = IterationCount;
classifier.Save(UntrainedFileName);
classifier.Train(trainingData);
classifier.Save(TrainedFileName);
// Deserialize and test
var trainedClassifier = BayesPointMachineClassifier.Load<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, TTrainingSettings, IBayesPointMachineClassifierPredictionSettings<TLabel>>(TrainedFileName);
var untrainedClassifier = BayesPointMachineClassifier.Load<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, TTrainingSettings, IBayesPointMachineClassifierPredictionSettings<TLabel>>(UntrainedFileName);
untrainedClassifier.Train(trainingData);
checkPrediction(expectedLabelDistributions, trainedClassifier.PredictDistribution(testData), Tolerance);
checkPrediction(expectedLabelDistributions, untrainedClassifier.PredictDistribution(testData), Tolerance);
// Incremental training
trainedClassifier.TrainIncremental(trainingData);
untrainedClassifier.TrainIncremental(trainingData);
checkPrediction(expectedIncrementalLabelDistributions, trainedClassifier.PredictDistribution(testData), Tolerance);
checkPrediction(expectedIncrementalLabelDistributions, untrainedClassifier.PredictDistribution(testData), Tolerance);
}
#region Custom binary serialization
/// <summary>

@ -60,7 +60,6 @@
<ProjectReference Include="..\..\..\src\Learners\Classifier\Classifier.csproj" />
<ProjectReference Include="..\..\..\src\Learners\Core\Core.csproj" />
<ProjectReference Include="..\..\..\src\Learners\Recommender\Recommender.csproj" />
<ProjectReference Include="..\..\..\src\Learners\Runners\CommandLine\CommandLine.csproj" />
<ProjectReference Include="..\..\..\src\Learners\Runners\Common\Common.csproj" />
</ItemGroup>
<ItemGroup>

@ -642,75 +642,6 @@ namespace Microsoft.ML.Probabilistic.Learners.Tests
this.TestDataConsistency(data, true, false); // Valid item features restored, must not throw
}
/// <summary>
/// Tests serialization/deserialization of the recommender operating on the native Matchbox data format.
/// </summary>
[Fact]
public void NativeDataFormatSerializationTest()
{
const string TrainedFileName = "trainedNativeRecommender.bin";
const string NotTrainedFileName = "notTrainedNativeRecommender.bin";
// Train and serialize
{
var recommender = this.CreateNativeDataFormatMatchboxRecommender();
recommender.Save(NotTrainedFileName);
recommender.Train(this.nativeTrainingData, this.nativeTrainingData);
recommender.Save(TrainedFileName);
}
// Deserialize and test
{
var trainedRecommender = MatchboxRecommender.Load<NativeDataset, int, int, Discrete, NativeDataset>(TrainedFileName);
var notTrainedRecommender = MatchboxRecommender.Load<NativeDataset, int, int, Discrete, NativeDataset>(NotTrainedFileName);
notTrainedRecommender.Train(this.nativeTrainingData, this.nativeTrainingData);
this.VerifyNativeRatingDistributionOfUserOneAndItemThree(trainedRecommender.PredictDistribution(1, 3, this.nativeTrainingData));
Assert.Equal(MaxStarRating - MinStarRating, trainedRecommender.Predict(1, 3, this.nativeTrainingData));
this.VerifyNativeRatingDistributionOfUserOneAndItemThree(notTrainedRecommender.PredictDistribution(1, 3, this.nativeTrainingData));
Assert.Equal(MaxStarRating - MinStarRating, notTrainedRecommender.Predict(1, 3, this.nativeTrainingData));
}
}
/// <summary>
/// Tests serialization/deserialization of the recommender operating on the standard data format.
/// </summary>
[Fact]
public void StandardDataFormatSerializationTest()
{
const int BatchCount = 2;
const string TrainedFileName = "trainedStandardRecommender.bin";
const string NotTrainedFileName = "notTrainedStandardRecommender.bin";
// Add features for cold test set user and item
this.standardTrainingDataFeatures.UserFeatures.Add(User.WithId("u2"), Vector.FromArray(4, 2, 1.3, 1.2, 2));
this.standardTrainingDataFeatures.ItemFeatures.Add(Item.WithId("i4"), Vector.FromArray(6.3, 0.5));
// Train and serialize
{
var recommender = this.CreateStandardDataFormatMatchboxRecommender();
recommender.Settings.Training.BatchCount = BatchCount;
recommender.Save(NotTrainedFileName);
recommender.Train(this.standardTrainingData, this.standardTrainingDataFeatures);
recommender.Save(TrainedFileName);
CheckStandardRatingPrediction(recommender, this.standardTrainingDataFeatures);
}
// Deserialize and test
{
var trainedRecommender = MatchboxRecommender.Load<StandardDataset, User, Item, RatingDistribution, FeatureProvider>(TrainedFileName);
var notTrainedRecommender = MatchboxRecommender.Load<StandardDataset, User, Item, RatingDistribution, FeatureProvider>(NotTrainedFileName);
notTrainedRecommender.Train(this.standardTrainingData, this.standardTrainingDataFeatures);
CheckStandardRatingPrediction(trainedRecommender, this.standardTrainingDataFeatures);
CheckStandardRatingPrediction(notTrainedRecommender, this.standardTrainingDataFeatures);
}
}
/// <summary>
/// Tests binary custom serialization/deserialization of the recommender operating on the native data format.
/// </summary>

@ -136,21 +136,6 @@ namespace Microsoft.ML.Probabilistic.Tests
}
}
writer.Dispose();
#if NETFRAMEWORK
// In .NET 5.0 and later, BinaryFormatter is obsolete
// and would produce errors. This test code should be migrated.
// See https://aka.ms/binaryformatter
if (true)
{
BinaryFormatter serializer = new BinaryFormatter();
using (Stream stream = File.Create(Path.Combine(dataFolder, "weights.bin")))
{
serializer.Serialize(stream, train.wPost);
serializer.Serialize(stream, train.biasPost);
}
}
#endif
}
// (0.5,0.5):
@ -170,37 +155,6 @@ namespace Microsoft.ML.Probabilistic.Tests
#pragma warning restore 162
#endif
#if NETFRAMEWORK
public static void Rcv1Test2()
{
GaussianArray wPost;
Gaussian biasPost;
BinaryFormatter serializer = new BinaryFormatter();
using (Stream stream = File.OpenRead(Path.Combine(dataFolder, "weights.bin")))
{
wPost = (GaussianArray)serializer.Deserialize(stream);
biasPost = (Gaussian)serializer.Deserialize(stream);
}
if (true)
{
GaussianEstimator est = new GaussianEstimator();
foreach (Gaussian item in wPost) est.Add(item.GetMean());
Console.WriteLine("weight distribution = {0}", est.GetDistribution(new Gaussian()));
}
var predict = new BpmPredict2();
predict.SetPriors(wPost, biasPost);
int count = 0;
int errors = 0;
foreach (Instance instance in new VwReader(Path.Combine(dataFolder, "rcv1.test.vw.gz")))
{
bool yPred = predict.Predict(instance);
if (yPred != instance.label) errors++;
count++;
}
Console.WriteLine("error rate = {0} = {1}/{2}", (double) errors/count, errors, count);
}
#endif
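
The deleted Rcv1Test2 depended on weights.bin written with BinaryFormatter. One BinaryFormatter-free way to persist such posteriors is to store each Gaussian as a mean/variance pair, as sketched below; this is an illustration only, and the GaussianArray(length, index => Gaussian) constructor is assumed to mirror the GammaArray constructor that appears earlier in this diff.

// Illustration only: persist Gaussian posteriors as mean/variance pairs instead of using BinaryFormatter.
// Assumes GaussianArray offers a (length, index => Gaussian) constructor analogous to GammaArray.
using System.IO;
using Microsoft.ML.Probabilistic.Distributions;

internal static class WeightPersistenceExample
{
    public static void Save(BinaryWriter writer, GaussianArray weights)
    {
        writer.Write(weights.Count);
        foreach (Gaussian weight in weights)
        {
            writer.Write(weight.GetMean());
            writer.Write(weight.GetVariance());
        }
    }

    public static GaussianArray Load(BinaryReader reader)
    {
        int count = reader.ReadInt32();
        return new GaussianArray(count, i =>
            Gaussian.FromMeanAndVariance(reader.ReadDouble(), reader.ReadDouble()));
    }
}
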
public static void Rcv1Test3()
{
int nf = 47152;

@ -32,6 +32,7 @@ namespace Microsoft.ML.Probabilistic.Tests
using Microsoft.ML.Probabilistic.Compiler;
using Microsoft.ML.Probabilistic.Algorithms;
using Microsoft.ML.Probabilistic.Models.Attributes;
using System.Runtime.Serialization;
public class DistributedTests
{

@ -82,18 +82,6 @@ namespace Microsoft.ML.Probabilistic.Tests
mc.AssertEqualTo(mc2);
}
#if NETFRAMEWORK
[Fact]
public void BinaryFormatterTest()
{
var mc = new MyClass();
mc.Initialize(skipStringDistributions: true);
var mc2 = CloneBinaryFormatter(mc);
mc.AssertEqualTo(mc2);
}
#endif
[Fact]
public void JsonNetSerializerTest()
{
@ -104,46 +92,6 @@ namespace Microsoft.ML.Probabilistic.Tests
mc.AssertEqualTo(mc2);
}
#if NETFRAMEWORK
[Fact]
public void VectorSerializeTests()
{
Sparsity approxSparsity = Sparsity.ApproximateWithTolerance(0.001);
double[] fromArray = new double[] {1.2, 2.3, 3.4, 1.2, 1.2, 2.3};
Vector vdense = Vector.FromArray(fromArray);
Vector vsparse = Vector.FromArray(fromArray, Sparsity.Sparse);
Vector vapprox = Vector.FromArray(fromArray, approxSparsity);
MemoryStream stream = new MemoryStream();
BinaryFormatter serializer = new BinaryFormatter();
serializer.Serialize(stream, vdense);
serializer.Serialize(stream, vsparse);
serializer.Serialize(stream, vapprox);
stream.Position = 0;
Vector vdense2 = (Vector) serializer.Deserialize(stream);
SparseVector vsparse2 = (SparseVector) serializer.Deserialize(stream);
ApproximateSparseVector vapprox2 = (ApproximateSparseVector) serializer.Deserialize(stream);
Assert.Equal(6, vdense2.Count);
for (int i = 0; i < fromArray.Length; i++) Assert.Equal(fromArray[i], vdense2[i]);
Assert.Equal(vdense2.Sparsity, Sparsity.Dense);
Assert.Equal(6, vsparse2.Count);
for (int i = 0; i < fromArray.Length; i++) Assert.Equal(vsparse2[i], fromArray[i]);
Assert.Equal(vsparse2.Sparsity, Sparsity.Sparse);
Assert.Equal(1.2, vsparse2.CommonValue);
Assert.Equal(3, vsparse2.SparseValues.Count);
Assert.True(vsparse2.HasCommonElements);
Assert.Equal(6, vapprox2.Count);
for (int i = 0; i < fromArray.Length; i++) Assert.Equal(vapprox2[i], fromArray[i]);
Assert.Equal(vapprox2.Sparsity, approxSparsity);
Assert.Equal(1.2, vapprox2.CommonValue);
Assert.Equal(3, vapprox2.SparseValues.Count);
Assert.True(vapprox2.HasCommonElements);
}
#endif
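
The deleted VectorSerializeTests round-tripped dense, sparse, and approximate vectors with BinaryFormatter. A simple BinaryFormatter-free alternative is to persist the element values and rebuild the vector with Vector.FromArray; note that this sketch deliberately drops the Sparsity information, so it is not a drop-in replacement for the removed test.

// Round-trips a Vector's values without BinaryFormatter; Sparsity is intentionally not preserved.
using System.IO;
using System.Text;
using Microsoft.ML.Probabilistic.Math;

internal static class VectorRoundTripExample
{
    public static void Run()
    {
        Vector original = Vector.FromArray(new double[] { 1.2, 2.3, 3.4, 1.2, 1.2, 2.3 }, Sparsity.Sparse);

        using (var stream = new MemoryStream())
        {
            using (var writer = new BinaryWriter(stream, Encoding.UTF8, leaveOpen: true))
            {
                writer.Write(original.Count);
                for (int i = 0; i < original.Count; i++) writer.Write(original[i]);
            }

            stream.Position = 0;
            using (var reader = new BinaryReader(stream))
            {
                int count = reader.ReadInt32();
                var values = new double[count];
                for (int i = 0; i < count; i++) values[i] = reader.ReadDouble();
                Vector copy = Vector.FromArray(values); // dense copy of the original values
            }
        }
    }
}
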
[DataContract]
[Serializable]
public class MyClass
@ -306,19 +254,6 @@ namespace Microsoft.ML.Probabilistic.Tests
}
}
#if NETFRAMEWORK
private static T CloneBinaryFormatter<T>(T obj)
{
var bf = new BinaryFormatter();
using (var ms = new MemoryStream())
{
bf.Serialize(ms, obj);
ms.Position = 0;
return (T)bf.Deserialize(ms);
}
}
#endif
private static T CloneJsonNet<T>(T obj)
{
var serializerSettings = new JsonSerializerSettings

@ -1452,25 +1452,6 @@ namespace Microsoft.ML.Probabilistic.Tests
Console.WriteLine(engine.Infer(mean));
}
#if NETFRAMEWORK
internal void BinarySerializationExample()
{
Dirichlet d = new Dirichlet(3.0, 1.0, 2.0);
BinaryFormatter serializer = new BinaryFormatter();
// write to disk
using (FileStream stream = new FileStream("temp.bin", FileMode.Create))
{
serializer.Serialize(stream, d);
}
// read from disk
using (FileStream stream = new FileStream("temp.bin", FileMode.Open))
{
Dirichlet d2 = (Dirichlet)serializer.Deserialize(stream);
Console.WriteLine(d2);
}
}
#endif
internal void XmlSerializationExample()
{
Dirichlet d = new Dirichlet(3.0, 1.0, 2.0);

@ -679,21 +679,6 @@ namespace Microsoft.ML.Probabilistic.Tests
}
}
#if NETFRAMEWORK
if (false)
{
BinaryFormatter serializer = new BinaryFormatter();
using (var writer = new FileStream("userTraits.bin", FileMode.Create))
{
serializer.Serialize(writer, engine.Infer(userTraits));
}
using (var writer = new FileStream("itemTraits.bin", FileMode.Create))
{
serializer.Serialize(writer, engine.Infer(itemTraits));
}
}
#endif
// test resetting inference
if (engine.Compiler.ReturnCopies)
{