зеркало из https://github.com/dotnet/infer.git
This commit is contained in:
Родитель
5156776f4c
Коммит
2f58f64afb
19
Infer.sln
19
Infer.sln
|
@ -1,6 +1,6 @@
|
|||
Microsoft Visual Studio Solution File, Format Version 12.00
|
||||
# Visual Studio Version 16
|
||||
VisualStudioVersion = 16.0.29020.237
|
||||
# Visual Studio Version 17
|
||||
VisualStudioVersion = 17.2.32408.312
|
||||
MinimumVisualStudioVersion = 10.0.40219.1
|
||||
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution Items", "{A181C943-2E01-454D-9008-2E3C53AA09CC}"
|
||||
ProjectSection(SolutionItems) = preProject
|
||||
|
@ -56,8 +56,6 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Learners", "Learners", "{29
|
|||
EndProject
|
||||
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Runners", "Runners", "{3DB795A6-5FE8-447C-89B9-9E608285C6F8}"
|
||||
EndProject
|
||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "CommandLine", "src\Learners\Runners\CommandLine\CommandLine.csproj", "{10FD3E08-53E8-42B2-8E4F-A5C23DEE3B96}"
|
||||
EndProject
|
||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Common", "src\Learners\Runners\Common\Common.csproj", "{25D28099-E338-4543-B1DE-261439654CA6}"
|
||||
EndProject
|
||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Evaluator", "src\Learners\Runners\Evaluator\Evaluator.csproj", "{040FA938-BE24-4391-86BA-D04B331A787A}"
|
||||
|
@ -321,18 +319,6 @@ Global
|
|||
{5B669C82-B04C-4DD6-8CE6-47D025D98777}.ReleaseCore|Any CPU.Build.0 = Release|Any CPU
|
||||
{5B669C82-B04C-4DD6-8CE6-47D025D98777}.ReleaseFull|Any CPU.ActiveCfg = Release|Any CPU
|
||||
{5B669C82-B04C-4DD6-8CE6-47D025D98777}.ReleaseFull|Any CPU.Build.0 = Release|Any CPU
|
||||
{10FD3E08-53E8-42B2-8E4F-A5C23DEE3B96}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
|
||||
{10FD3E08-53E8-42B2-8E4F-A5C23DEE3B96}.Debug|Any CPU.Build.0 = Debug|Any CPU
|
||||
{10FD3E08-53E8-42B2-8E4F-A5C23DEE3B96}.DebugCore|Any CPU.ActiveCfg = DebugCore|Any CPU
|
||||
{10FD3E08-53E8-42B2-8E4F-A5C23DEE3B96}.DebugCore|Any CPU.Build.0 = DebugCore|Any CPU
|
||||
{10FD3E08-53E8-42B2-8E4F-A5C23DEE3B96}.DebugFull|Any CPU.ActiveCfg = DebugFull|Any CPU
|
||||
{10FD3E08-53E8-42B2-8E4F-A5C23DEE3B96}.DebugFull|Any CPU.Build.0 = DebugFull|Any CPU
|
||||
{10FD3E08-53E8-42B2-8E4F-A5C23DEE3B96}.Release|Any CPU.ActiveCfg = Release|Any CPU
|
||||
{10FD3E08-53E8-42B2-8E4F-A5C23DEE3B96}.Release|Any CPU.Build.0 = Release|Any CPU
|
||||
{10FD3E08-53E8-42B2-8E4F-A5C23DEE3B96}.ReleaseCore|Any CPU.ActiveCfg = ReleaseCore|Any CPU
|
||||
{10FD3E08-53E8-42B2-8E4F-A5C23DEE3B96}.ReleaseCore|Any CPU.Build.0 = ReleaseCore|Any CPU
|
||||
{10FD3E08-53E8-42B2-8E4F-A5C23DEE3B96}.ReleaseFull|Any CPU.ActiveCfg = ReleaseFull|Any CPU
|
||||
{10FD3E08-53E8-42B2-8E4F-A5C23DEE3B96}.ReleaseFull|Any CPU.Build.0 = ReleaseFull|Any CPU
|
||||
{25D28099-E338-4543-B1DE-261439654CA6}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
|
||||
{25D28099-E338-4543-B1DE-261439654CA6}.Debug|Any CPU.Build.0 = Debug|Any CPU
|
||||
{25D28099-E338-4543-B1DE-261439654CA6}.DebugCore|Any CPU.ActiveCfg = Debug|Any CPU
|
||||
|
@ -570,7 +556,6 @@ Global
|
|||
{87D09BD4-119E-49C1-B0B4-86DF962A00EE} = {DC5F5BC4-CDB0-41F7-8B03-CD4C38C8DEB2}
|
||||
{6FF3E672-378C-4D61-B4CA-A5A5E01C2563} = {DC5F5BC4-CDB0-41F7-8B03-CD4C38C8DEB2}
|
||||
{3DB795A6-5FE8-447C-89B9-9E608285C6F8} = {2964BB90-4E6D-49ED-AA35-645D94337C76}
|
||||
{10FD3E08-53E8-42B2-8E4F-A5C23DEE3B96} = {3DB795A6-5FE8-447C-89B9-9E608285C6F8}
|
||||
{25D28099-E338-4543-B1DE-261439654CA6} = {3DB795A6-5FE8-447C-89B9-9E608285C6F8}
|
||||
{040FA938-BE24-4391-86BA-D04B331A787A} = {3DB795A6-5FE8-447C-89B9-9E608285C6F8}
|
||||
{07E9E91D-6593-4FF9-A266-270ED5241C98} = {2964BB90-4E6D-49ED-AA35-645D94337C76}
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
|
||||
# Nightly build using .NET Core
|
||||
|
||||
name: 0.4.$(Date:yyMM).$(Date:dd)$(Rev:rr)
|
||||
name: 0.5.$(Date:yyMM).$(Date:dd)$(Rev:rr)
|
||||
|
||||
resources:
|
||||
- repo: self
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
|
||||
# Nightly build for Windows. Tests on x86 and x64. Produces NuGet packages
|
||||
|
||||
name: 0.4.$(Date:yyMM).$(Date:dd)$(Rev:rr)
|
||||
name: 0.5.$(Date:yyMM).$(Date:dd)$(Rev:rr)
|
||||
|
||||
resources:
|
||||
- repo: self
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
|
||||
# Official signed build
|
||||
|
||||
name: 0.4.$(Date:yyMM).$(Date:dd)$(Rev:rr)
|
||||
name: 0.5.$(Date:yyMM).$(Date:dd)$(Rev:rr)
|
||||
|
||||
resources:
|
||||
- repo: self
|
||||
|
|
|
@ -109,26 +109,6 @@ namespace Microsoft.ML.Probabilistic.Learners
|
|||
|
||||
#region .Net binary deserialization
|
||||
|
||||
/// <summary>
|
||||
/// Deserializes a Bayes point machine classifier from a file.
|
||||
/// </summary>
|
||||
/// <typeparam name="TInstanceSource">The type of a source of instances.</typeparam>
|
||||
/// <typeparam name="TInstance">The type of an instance.</typeparam>
|
||||
/// <typeparam name="TLabelSource">The type of a source of labels.</typeparam>
|
||||
/// <typeparam name="TLabel">The type of a label.</typeparam>
|
||||
/// <typeparam name="TLabelDistribution">The type of a distribution over labels.</typeparam>
|
||||
/// <typeparam name="TTrainingSettings">The type of the settings for training.</typeparam>
|
||||
/// <typeparam name="TPredictionSettings">The type of the settings for prediction.</typeparam>
|
||||
/// <param name="fileName">The file name.</param>
|
||||
/// <returns>The deserialized Bayes point machine classifier object.</returns>
|
||||
public static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, TTrainingSettings, TPredictionSettings>
    Load<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, TTrainingSettings, TPredictionSettings>(string fileName)
    where TTrainingSettings : BayesPointMachineClassifierTrainingSettings
    where TPredictionSettings : IBayesPointMachineClassifierPredictionSettings<TLabel>
    // Pure forwarding: the file-based learner loader does all the work.
    => Utilities.Load<IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, TTrainingSettings, TPredictionSettings>>(fileName);
|
||||
|
||||
/// <summary>
|
||||
/// Deserializes a Bayes point machine classifier from a stream and formatter.
|
||||
/// </summary>
|
||||
|
@ -150,24 +130,6 @@ namespace Microsoft.ML.Probabilistic.Learners
|
|||
return Utilities.Load<IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, TTrainingSettings, TPredictionSettings>>(stream, formatter);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Deserializes a binary Bayes point machine classifier from a file.
|
||||
/// </summary>
|
||||
/// <typeparam name="TInstanceSource">The type of a source of instances.</typeparam>
|
||||
/// <typeparam name="TInstance">The type of an instance.</typeparam>
|
||||
/// <typeparam name="TLabelSource">The type of a source of labels.</typeparam>
|
||||
/// <typeparam name="TLabel">The type of a label.</typeparam>
|
||||
/// <typeparam name="TLabelDistribution">The type of a distribution over labels.</typeparam>
|
||||
/// <typeparam name="TTrainingSettings">The type of the settings for training.</typeparam>
|
||||
/// <param name="fileName">The file name.</param>
|
||||
/// <returns>The deserialized binary Bayes point machine classifier object.</returns>
|
||||
public static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, TTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<TLabel>>
    LoadBinaryClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, TTrainingSettings>(string fileName)
    where TTrainingSettings : BayesPointMachineClassifierTrainingSettings
    // Pure forwarding: the file-based learner loader does all the work.
    => Utilities.Load<IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, TTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<TLabel>>>(fileName);
|
||||
|
||||
/// <summary>
|
||||
/// Deserializes a binary Bayes point machine classifier from a stream and formatter.
|
||||
/// </summary>
|
||||
|
@ -187,24 +149,6 @@ namespace Microsoft.ML.Probabilistic.Learners
|
|||
return Utilities.Load<IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, TTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<TLabel>>>(stream, formatter);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Deserializes a multi-class Bayes point machine classifier from a file.
|
||||
/// </summary>
|
||||
/// <typeparam name="TInstanceSource">The type of a source of instances.</typeparam>
|
||||
/// <typeparam name="TInstance">The type of an instance.</typeparam>
|
||||
/// <typeparam name="TLabelSource">The type of a source of labels.</typeparam>
|
||||
/// <typeparam name="TLabel">The type of a label.</typeparam>
|
||||
/// <typeparam name="TLabelDistribution">The type of a distribution over labels.</typeparam>
|
||||
/// <typeparam name="TTrainingSettings">The type of the settings for training.</typeparam>
|
||||
/// <param name="fileName">The file name.</param>
|
||||
/// <returns>The deserialized multi-class Bayes point machine classifier object.</returns>
|
||||
public static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, TTrainingSettings, MulticlassBayesPointMachineClassifierPredictionSettings<TLabel>>
    LoadMulticlassClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, TTrainingSettings>(string fileName)
    where TTrainingSettings : BayesPointMachineClassifierTrainingSettings
    // Pure forwarding: the file-based learner loader does all the work.
    => Utilities.Load<IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, TTrainingSettings, MulticlassBayesPointMachineClassifierPredictionSettings<TLabel>>>(fileName);
|
||||
|
||||
/// <summary>
|
||||
/// Deserializes a multi-class Bayes point machine classifier from a stream and a formatter.
|
||||
/// </summary>
|
||||
|
@ -224,22 +168,6 @@ namespace Microsoft.ML.Probabilistic.Learners
|
|||
return Utilities.Load<IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, TTrainingSettings, MulticlassBayesPointMachineClassifierPredictionSettings<TLabel>>>(stream, formatter);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Deserializes a binary Bayes point machine classifier from a file.
|
||||
/// </summary>
|
||||
/// <typeparam name="TInstanceSource">The type of a source of instances.</typeparam>
|
||||
/// <typeparam name="TInstance">The type of an instance.</typeparam>
|
||||
/// <typeparam name="TLabelSource">The type of a source of labels.</typeparam>
|
||||
/// <typeparam name="TLabel">The type of a label.</typeparam>
|
||||
/// <typeparam name="TLabelDistribution">The type of a distribution over labels.</typeparam>
|
||||
/// <param name="fileName">The file name.</param>
|
||||
/// <returns>The deserialized binary Bayes point machine classifier object.</returns>
|
||||
public static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, BayesPointMachineClassifierTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<TLabel>>
    LoadBinaryClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution>(string fileName)
    // Pure forwarding: the file-based learner loader does all the work.
    => Utilities.Load<IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, BayesPointMachineClassifierTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<TLabel>>>(fileName);
|
||||
|
||||
/// <summary>
|
||||
/// Deserializes a binary Bayes point machine classifier from a stream and format.
|
||||
/// </summary>
|
||||
|
@ -257,22 +185,6 @@ namespace Microsoft.ML.Probabilistic.Learners
|
|||
return Utilities.Load<IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, BayesPointMachineClassifierTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<TLabel>>>(stream, formatter);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Deserializes a multi-class Bayes point machine classifier from a file.
|
||||
/// </summary>
|
||||
/// <typeparam name="TInstanceSource">The type of a source of instances.</typeparam>
|
||||
/// <typeparam name="TInstance">The type of an instance.</typeparam>
|
||||
/// <typeparam name="TLabelSource">The type of a source of labels.</typeparam>
|
||||
/// <typeparam name="TLabel">The type of a label.</typeparam>
|
||||
/// <typeparam name="TLabelDistribution">The type of a distribution over labels.</typeparam>
|
||||
/// <param name="fileName">The file name.</param>
|
||||
/// <returns>The deserialized multi-class Bayes point machine classifier object.</returns>
|
||||
public static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, BayesPointMachineClassifierTrainingSettings, MulticlassBayesPointMachineClassifierPredictionSettings<TLabel>>
    LoadMulticlassClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution>(string fileName)
    // Pure forwarding: the file-based learner loader does all the work.
    => Utilities.Load<IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, BayesPointMachineClassifierTrainingSettings, MulticlassBayesPointMachineClassifierPredictionSettings<TLabel>>>(fileName);
|
||||
|
||||
/// <summary>
|
||||
/// Deserializes a multi-class Bayes point machine classifier from a stream and formatter.
|
||||
/// </summary>
|
||||
|
@ -592,23 +504,6 @@ namespace Microsoft.ML.Probabilistic.Learners
|
|||
|
||||
#region Internal .Net binary deserialization
|
||||
|
||||
/// <summary>
|
||||
/// Deserializes a binary Bayes point machine classifier with <see cref="Gaussian"/> prior distributions
|
||||
/// over factorized weights from a file.
|
||||
/// </summary>
|
||||
/// <typeparam name="TInstanceSource">The type of a source of instances.</typeparam>
|
||||
/// <typeparam name="TInstance">The type of an instance.</typeparam>
|
||||
/// <typeparam name="TLabelSource">The type of a source of labels.</typeparam>
|
||||
/// <typeparam name="TLabel">The type of a label.</typeparam>
|
||||
/// <typeparam name="TLabelDistribution">The type of a distribution over labels.</typeparam>
|
||||
/// <param name="fileName">The file name.</param>
|
||||
/// <returns>The deserialized binary Bayes point machine classifier object.</returns>
|
||||
public static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, GaussianBayesPointMachineClassifierTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<TLabel>>
    LoadGaussianPriorBinaryClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution>(string fileName)
    // Pure forwarding: the file-based learner loader does all the work.
    => Utilities.Load<IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, GaussianBayesPointMachineClassifierTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<TLabel>>>(fileName);
|
||||
|
||||
/// <summary>
|
||||
/// Deserializes a binary Bayes point machine classifier with <see cref="Gaussian"/> prior distributions
|
||||
/// over factorized weights from a stream and formatter.
|
||||
|
@ -627,23 +522,6 @@ namespace Microsoft.ML.Probabilistic.Learners
|
|||
return Utilities.Load<IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, GaussianBayesPointMachineClassifierTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<TLabel>>>(stream, formatter);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Deserializes a multi-class Bayes point machine classifier with <see cref="Gaussian"/> prior distributions
|
||||
/// over factorized weights from a file.
|
||||
/// </summary>
|
||||
/// <typeparam name="TInstanceSource">The type of a source of instances.</typeparam>
|
||||
/// <typeparam name="TInstance">The type of an instance.</typeparam>
|
||||
/// <typeparam name="TLabelSource">The type of a source of labels.</typeparam>
|
||||
/// <typeparam name="TLabel">The type of a label.</typeparam>
|
||||
/// <typeparam name="TLabelDistribution">The type of a distribution over labels.</typeparam>
|
||||
/// <param name="fileName">The file name.</param>
|
||||
/// <returns>The deserialized multi-class Bayes point machine classifier object.</returns>
|
||||
public static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, GaussianBayesPointMachineClassifierTrainingSettings, MulticlassBayesPointMachineClassifierPredictionSettings<TLabel>>
    LoadBackwardCompatibleGaussianPriorMulticlassClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution>(string fileName)
    // Pure forwarding: the file-based learner loader does all the work.
    => Utilities.Load<IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, GaussianBayesPointMachineClassifierTrainingSettings, MulticlassBayesPointMachineClassifierPredictionSettings<TLabel>>>(fileName);
|
||||
|
||||
/// <summary>
|
||||
/// Deserializes a multi-class Bayes point machine classifier with <see cref="Gaussian"/> prior distributions
|
||||
/// over factorized weights from a stream and formatter.
|
||||
|
|
|
@ -23,30 +23,6 @@ namespace Microsoft.ML.Probabilistic.Learners
|
|||
|
||||
#region Save
|
||||
|
||||
/// <summary>
/// Writes a learner to a file using .NET binary serialization.
/// The file is created, or overwritten if it already exists.
/// </summary>
/// <param name="learner">The learner to persist.</param>
/// <param name="fileName">The path of the file to write.</param>
/// <exception cref="ArgumentNullException">
/// Thrown when <paramref name="learner"/> or <paramref name="fileName"/> is null.
/// </exception>
public static void Save(this ILearner learner, string fileName)
{
    if (learner == null)
    {
        throw new ArgumentNullException(nameof(learner));
    }

    if (fileName == null)
    {
        throw new ArgumentNullException(nameof(fileName));
    }

    using (Stream output = File.Open(fileName, FileMode.Create))
    {
        // Delegate to the stream-based overload with a binary formatter.
        learner.Save(output, new BinaryFormatter());
    }
}
|
||||
|
||||
/// <summary>
|
||||
/// Serializes a learner to a given stream using a given formatter.
|
||||
/// </summary>
|
||||
|
@ -181,26 +157,6 @@ namespace Microsoft.ML.Probabilistic.Learners
|
|||
return (TLearner)formatter.Deserialize(stream);
|
||||
}
|
||||
|
||||
/// <summary>
/// Reads a learner back from a file using .NET binary deserialization.
/// </summary>
/// <typeparam name="TLearner">The type the deserialized object is returned as.</typeparam>
/// <param name="fileName">The path of an existing file to read.</param>
/// <returns>The learner deserialized from the file.</returns>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="fileName"/> is null.</exception>
/// <remarks>
/// NOTE(review): the file content is handed to <see cref="BinaryFormatter"/>, which is
/// generally considered unsafe for untrusted input - confirm callers only load trusted files.
/// </remarks>
public static TLearner Load<TLearner>(string fileName)
{
    if (fileName == null)
    {
        throw new ArgumentNullException(nameof(fileName));
    }

    using (Stream input = File.Open(fileName, FileMode.Open))
    {
        // Delegate to the stream-based overload with a binary formatter.
        return Load<TLearner>(input, new BinaryFormatter());
    }
}
|
||||
|
||||
#endregion
|
||||
|
||||
#endregion
|
||||
|
|
|
@ -70,22 +70,6 @@ namespace Microsoft.ML.Probabilistic.Learners
|
|||
|
||||
#region .NET binary deserialization
|
||||
|
||||
/// <summary>
|
||||
/// Deserializes a Matchbox recommender from a file.
|
||||
/// </summary>
|
||||
/// <typeparam name="TInstanceSource">The type of a source of instances.</typeparam>
|
||||
/// <typeparam name="TUser">The type of a user.</typeparam>
|
||||
/// <typeparam name="TItem">The type of an item.</typeparam>
|
||||
/// <typeparam name="TRatingDistribution">The type of a distribution over ratings.</typeparam>
|
||||
/// <typeparam name="TFeatureSource">The type of a feature source.</typeparam>
|
||||
/// <param name="fileName">The file name.</param>
|
||||
/// <returns>The deserialized recommender object.</returns>
|
||||
public static IMatchboxRecommender<TInstanceSource, TUser, TItem, TRatingDistribution, TFeatureSource>
    Load<TInstanceSource, TUser, TItem, TRatingDistribution, TFeatureSource>(string fileName)
    // Pure forwarding: the file-based learner loader does all the work.
    => Utilities.Load<IMatchboxRecommender<TInstanceSource, TUser, TItem, TRatingDistribution, TFeatureSource>>(fileName);
|
||||
|
||||
/// <summary>
|
||||
/// Deserializes a recommender from a given stream and formatter.
|
||||
/// </summary>
|
||||
|
|
|
@ -1,311 +0,0 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
namespace Microsoft.ML.Probabilistic.Learners.Runners
|
||||
{
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Diagnostics;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
|
||||
using Microsoft.ML.Probabilistic.Distributions;
|
||||
using Microsoft.ML.Probabilistic.Math;
|
||||
using Microsoft.ML.Probabilistic.Utilities;
|
||||
|
||||
/// <summary>
|
||||
/// Utilities for the Bayes point machine classifier modules.
|
||||
/// </summary>
|
||||
internal static class BayesPointMachineClassifierModuleUtilities
|
||||
{
|
||||
/// <summary>
/// Trains the Bayes point machine classifier on the given data set while reporting,
/// after every iteration, the largest change in any weight distribution parameter.
/// </summary>
/// <typeparam name="TTrainingSettings">The type of the settings for training.</typeparam>
/// <param name="classifier">The Bayes point machine classifier to train and diagnose.</param>
/// <param name="trainingSet">The training data; its first instance determines the class and feature counts.</param>
/// <param name="maxParameterChangesFileName">
/// The file receiving one line per iteration with the maximum parameter change.
/// Nothing is written when this is null or empty.
/// </param>
/// <param name="modelFileName">
/// The file receiving the trained model. The model is not saved when this is null or empty.
/// </param>
public static void DiagnoseClassifier<TTrainingSettings>(
    IBayesPointMachineClassifier<IList<LabeledFeatureValues>, LabeledFeatureValues, IList<LabelDistribution>, string, IDictionary<string, double>, TTrainingSettings, IBayesPointMachineClassifierPredictionSettings<string>> classifier,
    IList<LabeledFeatureValues> trainingSet,
    string maxParameterChangesFileName,
    string modelFileName)
    where TTrainingSettings : BayesPointMachineClassifierTrainingSettings
{
    bool recordChanges = !string.IsNullOrEmpty(maxParameterChangesFileName);

    // Baseline for the first comparison: a Gaussian(0, 1) for every (class, feature) weight,
    // with the dimensions taken from the first training instance.
    var firstInstance = trainingSet[0];
    int classCount = firstInstance.LabelDistribution.LabelSet.Count;
    int featureCount = firstInstance.GetDenseFeatureVector().Count;
    var previousWeights = Util.ArrayInit(classCount, c => Util.ArrayInit(featureCount, f => new Gaussian(0.0, 1.0)));

    // Report progress after every training iteration.
    var stopwatch = new Stopwatch();
    classifier.IterationChanged += (sender, eventArgs) =>
    {
        // Stop timing first so that the reporting below is not counted as iteration time.
        stopwatch.Stop();
        long elapsedMilliseconds = stopwatch.ElapsedMilliseconds;

        var posteriors = eventArgs.WeightPosteriorDistributions;
        var iteration = eventArgs.CompletedIterationCount;
        double maxParameterChange = MaxDiff(posteriors, previousWeights);

        if (recordChanges)
        {
            SaveMaximumParameterDifference(
                maxParameterChangesFileName,
                iteration,
                maxParameterChange,
                elapsedMilliseconds);
        }

        Console.WriteLine(
            "[{0}] Iteration {1,-4} dp = {2,-20} dt = {3,5}ms",
            DateTime.Now.ToLongTimeString(),
            iteration,
            maxParameterChange,
            elapsedMilliseconds);

        // Remember the current posteriors as the baseline for the next iteration.
        for (int classIndex = 0; classIndex < posteriors.Count; classIndex++)
        {
            for (int featureIndex = 0; featureIndex < posteriors[classIndex].Count; featureIndex++)
            {
                previousWeights[classIndex][featureIndex] = posteriors[classIndex][featureIndex];
            }
        }

        stopwatch.Restart();
    };

    // Start the changes file afresh with a header line; iterations are appended to it later.
    if (recordChanges)
    {
        using (var header = new StreamWriter(maxParameterChangesFileName))
        {
            header.WriteLine("# time, # iteration, # maximum absolute parameter difference, # iteration time in milliseconds");
        }
    }

    // Run training
    Console.WriteLine("[{0}] Starting training...", DateTime.Now.ToLongTimeString());
    stopwatch.Start();

    classifier.Train(trainingSet);

    // Report the evidence if the training settings requested it
    if (classifier.Settings.Training.ComputeModelEvidence)
    {
        Console.WriteLine("Log evidence = {0,10:0.0000}", classifier.LogModelEvidence);
    }

    // Persist the trained model if a destination was given
    if (!string.IsNullOrEmpty(modelFileName))
    {
        classifier.Save(modelFileName);
    }
}
|
||||
|
||||
/// <summary>
/// Draws weight samples from the classifier's learned weight distributions and,
/// if a file name is given, writes them to that file.
/// </summary>
/// <typeparam name="TTrainingSettings">The type of the settings for training.</typeparam>
/// <param name="classifier">The Bayes point machine classifier to sample weights from.</param>
/// <param name="samplesFile">
/// The file to write the sampled weights to. Nothing is written when this is null or empty.
/// </param>
public static void SampleWeights<TTrainingSettings>(
    IBayesPointMachineClassifier<IList<LabeledFeatureValues>, LabeledFeatureValues, IList<LabelDistribution>, string, IDictionary<string, double>, TTrainingSettings, IBayesPointMachineClassifierPredictionSettings<string>> classifier,
    string samplesFile)
    where TTrainingSettings : BayesPointMachineClassifierTrainingSettings
{
    // Sampling always happens, even when there is no file to write to.
    var drawnWeights = SampleWeights(classifier);

    if (string.IsNullOrEmpty(samplesFile))
    {
        return;
    }

    ClassifierPersistenceUtils.SaveVectors(samplesFile, drawnWeights);
}
|
||||
|
||||
/// <summary>
/// Writes per-fold performance metrics, each followed by its mean and standard deviation,
/// to a newly created text file.
/// </summary>
/// <param name="fileName">The name of the file to write the metrics to.</param>
/// <param name="accuracy">The accuracy per fold; its element count determines the header columns.</param>
/// <param name="negativeLogProbability">The mean negative log probability per fold.</param>
/// <param name="auc">The AUC per fold.</param>
/// <param name="evidence">The model's log evidence per fold.</param>
/// <param name="iterationCount">The number of training iterations per fold.</param>
/// <param name="trainingTime">The training time per fold.</param>
public static void SavePerformanceMetrics(
    string fileName,
    ICollection<double> accuracy,
    IEnumerable<double> negativeLogProbability,
    IEnumerable<double> auc,
    IEnumerable<double> evidence,
    IEnumerable<double> iterationCount,
    IEnumerable<double> trainingTime)
{
    using (var output = new StreamWriter(fileName))
    {
        // Header row: "# Fold 1, Fold 2, ..., Mean, Standard deviation".
        // The comment marker is only emitted when there is at least one fold.
        int foldCount = accuracy.Count;
        if (foldCount > 0)
        {
            output.Write("# ");
        }

        for (int foldNumber = 1; foldNumber <= foldCount; foldNumber++)
        {
            output.Write("Fold {0}, ", foldNumber);
        }

        output.WriteLine("Mean, Standard deviation");
        output.WriteLine();

        // One section per metric; training time is written before the iteration count.
        SaveSinglePerformanceMetric(output, "Accuracy", accuracy);
        SaveSinglePerformanceMetric(output, "Mean negative log probability", negativeLogProbability);
        SaveSinglePerformanceMetric(output, "AUC", auc);
        SaveSinglePerformanceMetric(output, "Log evidence", evidence);
        SaveSinglePerformanceMetric(output, "Training time", trainingTime);
        SaveSinglePerformanceMetric(output, "Iteration count", iterationCount);
    }
}
|
||||
|
||||
/// <summary>
/// Converts elapsed time in milliseconds into a human readable format.
/// </summary>
/// <param name="elapsedMilliseconds">The elapsed time in milliseconds.</param>
/// <returns>
/// <c>h:mm:ss.fff</c> when the time is at least one hour, <c>m:ss.fff</c> when it is at least
/// one minute, and <c>s.fff seconds</c> otherwise. The hour field counts whole hours of the
/// entire duration, so times of a day or more are not truncated (25 hours gives <c>25:00:00.000</c>).
/// </returns>
public static string FormatElapsedTime(double elapsedMilliseconds)
{
    TimeSpan time = TimeSpan.FromMilliseconds(elapsedMilliseconds);

    // TimeSpan.Hours is only the 0-23 component and silently drops whole days,
    // so use the truncated total number of hours instead.
    int totalHours = (int)time.TotalHours;

    if (totalHours > 0)
    {
        return string.Format("{0}:{1:D2}:{2:D2}.{3:D3}", totalHours, time.Minutes, time.Seconds, time.Milliseconds);
    }

    if (time.Minutes > 0)
    {
        return string.Format("{0}:{1:D2}.{2:D3}", time.Minutes, time.Seconds, time.Milliseconds);
    }

    return string.Format("{0}.{1:D3} seconds", time.Seconds, time.Milliseconds);
}
|
||||
|
||||
/// <summary>
/// Prints the number of instances, classes and features of a data set to the console.
/// </summary>
/// <param name="dataSet">
/// The data set. The class and feature counts are read from its first instance
/// and are reported as 0 when the data set is empty.
/// </param>
public static void WriteDataSetInfo(IList<LabeledFeatureValues> dataSet)
{
    var instanceCount = dataSet.Count;
    var classCount = dataSet.Count > 0 ? dataSet.First().LabelDistribution.LabelSet.Count : 0;
    var featureCount = dataSet.Count > 0 ? dataSet.First().FeatureSet.Count : 0;

    Console.WriteLine(
        "Data set contains {0} instances, {1} classes and {2} features.",
        instanceCount,
        classCount,
        featureCount);
}
|
||||
|
||||
#region Helper methods
|
||||
|
||||
/// <summary>
/// Writes one metric section to the writer: a comment line with the description, then a
/// single line listing every value followed by the mean and standard deviation, then a blank line.
/// </summary>
/// <param name="writer">The writer to write the metric to.</param>
/// <param name="description">The text written after the "# " comment marker.</param>
/// <param name="metric">The metric values, one per fold.</param>
private static void SaveSinglePerformanceMetric(TextWriter writer, string description, IEnumerable<double> metric)
{
    // Section heading
    writer.WriteLine(string.Concat("# ", description));

    // Values, accumulated on the fly for the summary statistics
    var statistics = new MeanVarianceAccumulator();
    foreach (double foldValue in metric)
    {
        writer.Write("{0}, ", foldValue);
        statistics.Add(foldValue);
    }

    // Summary columns: mean, then standard deviation (square root of the variance)
    var mean = statistics.Mean;
    var standardDeviation = Math.Sqrt(statistics.Variance);
    writer.WriteLine("{0}, {1}", mean, standardDeviation);
    writer.WriteLine();
}
|
||||
|
||||
/// <summary>
/// Appends one line with the current time, the iteration, the maximum parameter change
/// and the elapsed milliseconds to the specified file.
/// </summary>
/// <param name="fileName">The name of the file to append to.</param>
/// <param name="iteration">The inference algorithm iteration.</param>
/// <param name="maxParameterChange">The maximum absolute difference in any parameter of two weight distributions.</param>
/// <param name="elapsedMilliseconds">The elapsed milliseconds.</param>
private static void SaveMaximumParameterDifference(string fileName, int iteration, double maxParameterChange, long elapsedMilliseconds)
{
    // Open in append mode so that earlier lines are kept.
    using (var output = new StreamWriter(fileName, append: true))
    {
        string timestamp = DateTime.Now.ToLongTimeString();
        output.WriteLine("{0}, {1}, {2}, {3}", timestamp, iteration, maxParameterChange, elapsedMilliseconds);
    }
}
|
||||
|
||||
/// <summary>
/// Computes the largest absolute difference in mean or variance between corresponding
/// Gaussians of two equally shaped collections of weight distributions.
/// </summary>
/// <param name="first">The first collection of Gaussians; it determines the iteration bounds.</param>
/// <param name="second">The second collection of Gaussians.</param>
/// <returns>
/// The maximum absolute difference in any parameter, or negative infinity
/// when no pair of Gaussians is compared.
/// </returns>
/// <remarks>This difference computation is based on mean and variance instead of mean*precision and precision.</remarks>
private static double MaxDiff(IReadOnlyList<IReadOnlyList<Gaussian>> first, Gaussian[][] second)
{
    // The number of columns is taken from the first row and applied to every row.
    int rowCount = first.Count;
    int columnCount = first[0].Count;

    double largest = double.NegativeInfinity;
    for (int row = 0; row < rowCount; row++)
    {
        for (int column = 0; column < columnCount; column++)
        {
            double meanA, varianceA, meanB, varianceB;
            first[row][column].GetMeanAndVariance(out meanA, out varianceA);
            second[row][column].GetMeanAndVariance(out meanB, out varianceB);

            double meanGap = Math.Abs(meanA - meanB);
            double varianceGap = Math.Abs(varianceA - varianceB);

            // Strict comparisons: a NaN gap never replaces the current maximum.
            if (meanGap > largest)
            {
                largest = meanGap;
            }

            if (varianceGap > largest)
            {
                largest = varianceGap;
            }
        }
    }

    return largest;
}
|
||||
|
||||
/// <summary>
/// Draws one sample per weight from the posterior weight distributions of a Bayes point machine classifier.
/// </summary>
/// <typeparam name="TTrainingSettings">The type of the settings for training.</typeparam>
/// <param name="classifier">The Bayes point machine whose weight posteriors are sampled.</param>
/// <returns>
/// One weight vector per sampled class. The number of vectors is one less than the number of posterior rows,
/// but never fewer than one.
/// </returns>
private static IEnumerable<Vector> SampleWeights<TTrainingSettings>(
    IBayesPointMachineClassifier<IList<LabeledFeatureValues>, LabeledFeatureValues, IList<LabelDistribution>, string, IDictionary<string, double>, TTrainingSettings, IBayesPointMachineClassifierPredictionSettings<string>> classifier)
    where TTrainingSettings : BayesPointMachineClassifierTrainingSettings
{
    Debug.Assert(classifier != null, "The classifier must not be null.");

    IReadOnlyList<IReadOnlyList<Gaussian>> posteriors = classifier.WeightPosteriorDistributions;

    // A posterior table with fewer than two rows is treated as a two-class problem, i.e. a single weight vector.
    int vectorCount = Math.Max(posteriors.Count, 2) - 1;
    int weightCount = posteriors[0].Count;

    var drawn = new Vector[vectorCount];
    for (int row = 0; row < vectorCount; row++)
    {
        var weights = Vector.Zero(weightCount);
        for (int column = 0; column < weightCount; column++)
        {
            weights[column] = posteriors[row][column].Sample();
        }

        drawn[row] = weights;
    }

    return drawn;
}
|
||||
|
||||
#endregion
|
||||
}
|
||||
}
|
|
@ -1,153 +0,0 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
namespace Microsoft.ML.Probabilistic.Learners.Runners
|
||||
{
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Diagnostics;
|
||||
using System.Linq;
|
||||
|
||||
using Microsoft.ML.Probabilistic.Learners.Mappings;
|
||||
using Microsoft.ML.Probabilistic.Math;
|
||||
|
||||
/// <summary>
/// A command-line module to cross-validate a binary Bayes point machine classifier on given data.
/// </summary>
internal class BinaryBayesPointMachineClassifierCrossValidationModule : CommandLineModule
{
    /// <summary>
    /// Runs the module: shuffles the data set with a fixed seed, trains and evaluates one classifier per fold,
    /// and saves the performance metrics collected so far after every fold.
    /// </summary>
    /// <param name="args">The command line arguments for the module.</param>
    /// <param name="usagePrefix">The prefix to print before the usage string.</param>
    /// <returns>
    /// True if the run was successful; false if the arguments could not be parsed or the requested
    /// number of folds cannot be applied to the data set.
    /// </returns>
    public override bool Run(string[] args, string usagePrefix)
    {
        string dataSetFile = string.Empty;
        string resultsFile = string.Empty;
        int crossValidationFoldCount = 5;
        int iterationCount = BayesPointMachineClassifierTrainingSettings.IterationCountDefault;
        int batchCount = BayesPointMachineClassifierTrainingSettings.BatchCountDefault;
        bool computeModelEvidence = BayesPointMachineClassifierTrainingSettings.ComputeModelEvidenceDefault;

        var parser = new CommandLineParser();
        parser.RegisterParameterHandler("--data-set", "FILE", "File with training data", v => dataSetFile = v, CommandLineParameterType.Required);
        parser.RegisterParameterHandler("--results", "FILE", "File with cross-validation results", v => resultsFile = v, CommandLineParameterType.Required);
        parser.RegisterParameterHandler("--folds", "NUM", "Number of cross-validation folds (defaults to " + crossValidationFoldCount + ")", v => crossValidationFoldCount = v, CommandLineParameterType.Optional);
        parser.RegisterParameterHandler("--iterations", "NUM", "Number of training algorithm iterations (defaults to " + iterationCount + ")", v => iterationCount = v, CommandLineParameterType.Optional);
        parser.RegisterParameterHandler("--batches", "NUM", "Number of batches to split the training data into (defaults to " + batchCount + ")", v => batchCount = v, CommandLineParameterType.Optional);
        parser.RegisterParameterHandler("--compute-evidence", "Compute model evidence (defaults to " + computeModelEvidence + ")", () => computeModelEvidence = true);

        if (!parser.TryParse(args, usagePrefix))
        {
            return false;
        }

        // Load and shuffle data
        var dataSet = ClassifierPersistenceUtils.LoadLabeledFeatureValues(dataSetFile);
        BayesPointMachineClassifierModuleUtilities.WriteDataSetInfo(dataSet);

        // A fold count of zero would divide by zero below, a negative one would silently run no folds,
        // and more folds than instances would produce empty validation sets (and divisions by a zero
        // prediction count when averaging the metrics).
        if (crossValidationFoldCount < 1 || crossValidationFoldCount > dataSet.Count)
        {
            Console.WriteLine(
                "The number of cross-validation folds must be between 1 and the number of instances in the data set ({0}).",
                dataSet.Count);
            return false;
        }

        // Fixed seed so that the fold assignment is reproducible across runs.
        Rand.Restart(562);
        Rand.Shuffle(dataSet);

        // Create evaluator
        var evaluatorMapping = Mappings.Classifier.ForEvaluation();
        var evaluator = new ClassifierEvaluator<IList<LabeledFeatureValues>, LabeledFeatureValues, IList<LabelDistribution>, string>(evaluatorMapping);

        // Create performance metrics (one entry per completed fold)
        var accuracy = new List<double>();
        var negativeLogProbability = new List<double>();
        var auc = new List<double>();
        var evidence = new List<double>();
        var iterationCounts = new List<double>();
        var trainingTime = new List<double>();

        // Run cross-validation
        int validationSetSize = dataSet.Count / crossValidationFoldCount;
        Console.WriteLine("Running {0}-fold cross-validation on {1}", crossValidationFoldCount, dataSetFile);

        // TODO: Use chained mapping to implement cross-validation
        for (int fold = 0; fold < crossValidationFoldCount; fold++)
        {
            // Construct training and validation sets for fold.
            // The last fold also takes the remainder left over by the integer division.
            int validationSetStart = fold * validationSetSize;
            int validationSetEnd = (fold + 1 == crossValidationFoldCount)
                                       ? dataSet.Count
                                       : (fold + 1) * validationSetSize;

            var trainingSet = new List<LabeledFeatureValues>();
            var validationSet = new List<LabeledFeatureValues>();

            for (int instance = 0; instance < dataSet.Count; instance++)
            {
                if (validationSetStart <= instance && instance < validationSetEnd)
                {
                    validationSet.Add(dataSet[instance]);
                }
                else
                {
                    trainingSet.Add(dataSet[instance]);
                }
            }

            // Print info
            Console.WriteLine(" Fold {0} [validation set instances {1} - {2}]", fold + 1, validationSetStart, validationSetEnd - 1);

            // Create classifier
            var classifier = BayesPointMachineClassifier.CreateBinaryClassifier(Mappings.Classifier);
            classifier.Settings.Training.IterationCount = iterationCount;
            classifier.Settings.Training.BatchCount = batchCount;
            classifier.Settings.Training.ComputeModelEvidence = computeModelEvidence;

            // Track the number of iterations the training algorithm reports as completed.
            int currentIterationCount = 0;
            classifier.IterationChanged += (sender, eventArgs) => { currentIterationCount = eventArgs.CompletedIterationCount; };

            // Train classifier
            var stopWatch = Stopwatch.StartNew();
            classifier.Train(trainingSet);
            stopWatch.Stop();

            // Produce predictions; the predicted label is the one with the highest predicted probability.
            var predictions = classifier.PredictDistribution(validationSet).ToList();
            var predictedLabels = predictions.Select(
                prediction => prediction.Aggregate((aggregate, next) => next.Value > aggregate.Value ? next : aggregate).Key).ToList();

            // Iteration count
            iterationCounts.Add(currentIterationCount);

            // Training time
            trainingTime.Add(stopWatch.ElapsedMilliseconds);

            // Compute accuracy
            accuracy.Add(1 - (evaluator.Evaluate(validationSet, predictedLabels, Metrics.ZeroOneError) / predictions.Count));

            // Compute mean negative log probability
            negativeLogProbability.Add(evaluator.Evaluate(validationSet, predictions, Metrics.NegativeLogProbability) / predictions.Count);

            // Compute M-measure (averaged pairwise AUC)
            auc.Add(evaluator.AreaUnderRocCurve(validationSet, predictions));

            // Compute log evidence if desired
            evidence.Add(computeModelEvidence ? classifier.LogModelEvidence : double.NaN);

            // Persist performance metrics
            Console.WriteLine(
                " Accuracy = {0,5:0.0000} NegLogProb = {1,5:0.0000} AUC = {2,5:0.0000}{3} Iterations = {4} Training time = {5}",
                accuracy[fold],
                negativeLogProbability[fold],
                auc[fold],
                computeModelEvidence ? string.Format(" Log evidence = {0,5:0.0000}", evidence[fold]) : string.Empty,
                iterationCounts[fold],
                BayesPointMachineClassifierModuleUtilities.FormatElapsedTime(trainingTime[fold]));

            BayesPointMachineClassifierModuleUtilities.SavePerformanceMetrics(
                resultsFile, accuracy, negativeLogProbability, auc, evidence, iterationCounts, trainingTime);
        }

        return true;
    }
}
|
||||
}
|
|
@ -1,54 +0,0 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
namespace Microsoft.ML.Probabilistic.Learners.Runners
|
||||
{
|
||||
using System.Collections.Generic;
|
||||
|
||||
/// <summary>
/// A command-line module which continues training of a previously saved binary Bayes point machine classifier.
/// </summary>
internal class BinaryBayesPointMachineClassifierIncrementalTrainingModule : CommandLineModule
{
    /// <summary>
    /// Runs the module: loads a saved model, trains it incrementally on the given data and saves the result.
    /// </summary>
    /// <param name="args">The command line arguments for the module.</param>
    /// <param name="usagePrefix">The prefix to print before the usage string.</param>
    /// <returns>True if the run was successful, false if the command line could not be parsed.</returns>
    public override bool Run(string[] args, string usagePrefix)
    {
        string trainingDataPath = string.Empty;
        string sourceModelPath = string.Empty;
        string targetModelPath = string.Empty;
        int iterations = BayesPointMachineClassifierTrainingSettings.IterationCountDefault;
        int batches = BayesPointMachineClassifierTrainingSettings.BatchCountDefault;

        var commandLine = new CommandLineParser();
        commandLine.RegisterParameterHandler(
            "--training-set", "FILE", "File with training data", value => trainingDataPath = value, CommandLineParameterType.Required);
        commandLine.RegisterParameterHandler(
            "--input-model", "FILE", "File with the trained binary Bayes point machine model", value => sourceModelPath = value, CommandLineParameterType.Required);
        commandLine.RegisterParameterHandler(
            "--model", "FILE", "File to store the incrementally trained binary Bayes point machine model", value => targetModelPath = value, CommandLineParameterType.Required);
        commandLine.RegisterParameterHandler(
            "--iterations", "NUM", "Number of training algorithm iterations (defaults to " + iterations + ")", value => iterations = value, CommandLineParameterType.Optional);
        commandLine.RegisterParameterHandler(
            "--batches", "NUM", "Number of batches to split the training data into (defaults to " + batches + ")", value => batches = value, CommandLineParameterType.Optional);

        bool parsed = commandLine.TryParse(args, usagePrefix);
        if (!parsed)
        {
            return false;
        }

        // Load the additional training data and report on it.
        var trainingSet = ClassifierPersistenceUtils.LoadLabeledFeatureValues(trainingDataPath);
        BayesPointMachineClassifierModuleUtilities.WriteDataSetInfo(trainingSet);

        // Restore the previously trained model and apply the requested training settings.
        var classifier = BayesPointMachineClassifier
            .LoadBinaryClassifier<IList<LabeledFeatureValues>, LabeledFeatureValues, IList<LabelDistribution>, string, IDictionary<string, double>>(sourceModelPath);
        classifier.Settings.Training.IterationCount = iterations;
        classifier.Settings.Training.BatchCount = batches;

        // Continue training and persist the updated model.
        classifier.TrainIncremental(trainingSet);
        classifier.Save(targetModelPath);

        return true;
    }
}
|
||||
}
|
|
@ -1,50 +0,0 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
namespace Microsoft.ML.Probabilistic.Learners.Runners
|
||||
{
|
||||
using System.Collections.Generic;
|
||||
|
||||
/// <summary>
/// A command-line module which applies a trained binary Bayes point machine classifier model to a test set.
/// </summary>
internal class BinaryBayesPointMachineClassifierPredictionModule : CommandLineModule
{
    /// <summary>
    /// Runs the module: loads the test set and the model, predicts label distributions and saves them.
    /// </summary>
    /// <param name="args">The command line arguments for the module.</param>
    /// <param name="usagePrefix">The prefix to print before the usage string.</param>
    /// <returns>True if the run was successful, false if the command line could not be parsed.</returns>
    public override bool Run(string[] args, string usagePrefix)
    {
        string testDataPath = string.Empty;
        string modelPath = string.Empty;
        string outputPath = string.Empty;

        var commandLine = new CommandLineParser();
        commandLine.RegisterParameterHandler(
            "--test-set", "FILE", "File with test data", value => testDataPath = value, CommandLineParameterType.Required);
        commandLine.RegisterParameterHandler(
            "--model", "FILE", "File with a trained binary Bayes point machine model", value => modelPath = value, CommandLineParameterType.Required);
        commandLine.RegisterParameterHandler(
            "--predictions", "FILE", "File to store predictions for the test data", value => outputPath = value, CommandLineParameterType.Required);

        bool parsed = commandLine.TryParse(args, usagePrefix);
        if (!parsed)
        {
            return false;
        }

        // Load the test data and report on it.
        var testSet = ClassifierPersistenceUtils.LoadLabeledFeatureValues(testDataPath);
        BayesPointMachineClassifierModuleUtilities.WriteDataSetInfo(testSet);

        // Restore the trained model.
        var classifier = BayesPointMachineClassifier
            .LoadBinaryClassifier<IList<LabeledFeatureValues>, LabeledFeatureValues, IList<LabelDistribution>, string, IDictionary<string, double>>(modelPath);

        // Predict the label distributions and write them to the output file.
        ClassifierPersistenceUtils.SaveLabelDistributions(outputPath, classifier.PredictDistribution(testSet));

        return true;
    }
}
|
||||
}
|
|
@ -1,41 +0,0 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
namespace Microsoft.ML.Probabilistic.Learners.Runners
|
||||
{
|
||||
using System.Collections.Generic;
|
||||
|
||||
/// <summary>
/// A command-line module which draws weight samples from a trained binary Bayes point machine classifier model.
/// </summary>
internal class BinaryBayesPointMachineClassifierSampleWeightsModule : CommandLineModule
{
    /// <summary>
    /// Runs the module: loads the model and hands it, together with the samples file name, to the sampling utility.
    /// </summary>
    /// <param name="args">The command line arguments for the module.</param>
    /// <param name="usagePrefix">The prefix to print before the usage string.</param>
    /// <returns>True if the run was successful, false if the command line could not be parsed.</returns>
    public override bool Run(string[] args, string usagePrefix)
    {
        string modelPath = string.Empty;
        string samplesPath = string.Empty;

        var commandLine = new CommandLineParser();
        commandLine.RegisterParameterHandler(
            "--model", "FILE", "File with a trained binary Bayes point machine model", value => modelPath = value, CommandLineParameterType.Required);
        commandLine.RegisterParameterHandler(
            "--samples", "FILE", "File to store samples of the weights", value => samplesPath = value, CommandLineParameterType.Required);

        bool parsed = commandLine.TryParse(args, usagePrefix);
        if (!parsed)
        {
            return false;
        }

        // Restore the trained model and sample from its weights.
        var classifier = BayesPointMachineClassifier
            .LoadBinaryClassifier<IList<LabeledFeatureValues>, LabeledFeatureValues, IList<LabelDistribution>, string, IDictionary<string, double>>(modelPath);
        BayesPointMachineClassifierModuleUtilities.SampleWeights(classifier, samplesPath);

        return true;
    }
}
|
||||
}
|
|
@ -1,50 +0,0 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
namespace Microsoft.ML.Probabilistic.Learners.Runners
|
||||
{
|
||||
/// <summary>
/// A command-line module which runs training diagnostics for a binary Bayes point machine classifier on given data.
/// </summary>
internal class BinaryBayesPointMachineClassifierTrainingDiagnosticsModule : CommandLineModule
{
    /// <summary>
    /// Runs the module: creates a binary classifier with the requested settings and passes it to the diagnostics utility.
    /// </summary>
    /// <param name="args">The command line arguments for the module.</param>
    /// <param name="usagePrefix">The prefix to print before the usage string.</param>
    /// <returns>True if the run was successful, false if the command line could not be parsed.</returns>
    public override bool Run(string[] args, string usagePrefix)
    {
        string trainingDataPath = string.Empty;

        // Both output files are optional; their names stay empty when not given on the command line.
        string parameterChangesPath = string.Empty;
        string modelPath = string.Empty;

        int iterations = BayesPointMachineClassifierTrainingSettings.IterationCountDefault;
        int batches = BayesPointMachineClassifierTrainingSettings.BatchCountDefault;

        var commandLine = new CommandLineParser();
        commandLine.RegisterParameterHandler(
            "--training-set", "FILE", "File with training data", value => trainingDataPath = value, CommandLineParameterType.Required);
        commandLine.RegisterParameterHandler(
            "--results", "FILE", "File to store the maximum parameter differences", value => parameterChangesPath = value, CommandLineParameterType.Optional);
        commandLine.RegisterParameterHandler(
            "--model", "FILE", "File to store the trained binary Bayes point machine model", value => modelPath = value, CommandLineParameterType.Optional);
        commandLine.RegisterParameterHandler(
            "--iterations", "NUM", "Number of training algorithm iterations (defaults to " + iterations + ")", value => iterations = value, CommandLineParameterType.Optional);
        commandLine.RegisterParameterHandler(
            "--batches", "NUM", "Number of batches to split the training data into (defaults to " + batches + ")", value => batches = value, CommandLineParameterType.Optional);

        bool parsed = commandLine.TryParse(args, usagePrefix);
        if (!parsed)
        {
            return false;
        }

        // Load the training data and report on it.
        var trainingSet = ClassifierPersistenceUtils.LoadLabeledFeatureValues(trainingDataPath);
        BayesPointMachineClassifierModuleUtilities.WriteDataSetInfo(trainingSet);

        // Set up a fresh classifier with the requested training settings.
        var classifier = BayesPointMachineClassifier.CreateBinaryClassifier(Mappings.Classifier);
        classifier.Settings.Training.IterationCount = iterations;
        classifier.Settings.Training.BatchCount = batches;

        BayesPointMachineClassifierModuleUtilities.DiagnoseClassifier(classifier, trainingSet, parameterChangesPath, modelPath);

        return true;
    }
}
|
||||
}
|
|
@ -1,63 +0,0 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
namespace Microsoft.ML.Probabilistic.Learners.Runners
|
||||
{
|
||||
using System;
|
||||
using System.Linq;
|
||||
|
||||
/// <summary>
/// A command-line module which trains a binary Bayes point machine classifier on given data and saves the model.
/// </summary>
internal class BinaryBayesPointMachineClassifierTrainingModule : CommandLineModule
{
    /// <summary>
    /// Runs the module: loads the training set, trains a new binary classifier and saves it to the model file.
    /// </summary>
    /// <param name="args">The command line arguments for the module.</param>
    /// <param name="usagePrefix">The prefix to print before the usage string.</param>
    /// <returns>True if the run was successful, false if the command line could not be parsed.</returns>
    public override bool Run(string[] args, string usagePrefix)
    {
        string trainingDataPath = string.Empty;
        string modelPath = string.Empty;
        int iterations = BayesPointMachineClassifierTrainingSettings.IterationCountDefault;
        int batches = BayesPointMachineClassifierTrainingSettings.BatchCountDefault;
        bool withEvidence = BayesPointMachineClassifierTrainingSettings.ComputeModelEvidenceDefault;

        var commandLine = new CommandLineParser();
        commandLine.RegisterParameterHandler(
            "--training-set", "FILE", "File with training data", value => trainingDataPath = value, CommandLineParameterType.Required);
        commandLine.RegisterParameterHandler(
            "--model", "FILE", "File to store the trained binary Bayes point machine model", value => modelPath = value, CommandLineParameterType.Required);
        commandLine.RegisterParameterHandler(
            "--iterations", "NUM", "Number of training algorithm iterations (defaults to " + iterations + ")", value => iterations = value, CommandLineParameterType.Optional);
        commandLine.RegisterParameterHandler(
            "--batches", "NUM", "Number of batches to split the training data into (defaults to " + batches + ")", value => batches = value, CommandLineParameterType.Optional);
        commandLine.RegisterParameterHandler(
            "--compute-evidence", "Compute model evidence (defaults to " + withEvidence + ")", () => withEvidence = true);

        bool parsed = commandLine.TryParse(args, usagePrefix);
        if (!parsed)
        {
            return false;
        }

        // Load the training data and report on it.
        var trainingSet = ClassifierPersistenceUtils.LoadLabeledFeatureValues(trainingDataPath);
        BayesPointMachineClassifierModuleUtilities.WriteDataSetInfo(trainingSet);

        // The mapping is built from the feature set of the first instance, or from null for an empty training set.
        var featureSet = trainingSet.Count == 0 ? null : trainingSet.First().FeatureSet;
        var classifier = BayesPointMachineClassifier.CreateBinaryClassifier(new ClassifierMapping(featureSet));

        var trainingSettings = classifier.Settings.Training;
        trainingSettings.IterationCount = iterations;
        trainingSettings.BatchCount = batches;
        trainingSettings.ComputeModelEvidence = withEvidence;

        classifier.Train(trainingSet);

        // The evidence is only reported when its computation is switched on in the classifier settings.
        if (classifier.Settings.Training.ComputeModelEvidence)
        {
            Console.WriteLine("Log evidence = {0,10:0.0000}", classifier.LogModelEvidence);
        }

        classifier.Save(modelPath);

        return true;
    }
}
|
||||
}
|
|
@ -1,674 +0,0 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
namespace Microsoft.ML.Probabilistic.Learners.Runners
|
||||
{
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Diagnostics;
|
||||
using System.Globalization;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
|
||||
using Microsoft.ML.Probabilistic.Learners.Mappings;
|
||||
|
||||
/// <summary>
|
||||
/// A command-line module to evaluate the label predictions of classifiers.
|
||||
/// </summary>
|
||||
internal class ClassifierEvaluationModule : CommandLineModule
|
||||
{
|
||||
/// <summary>
/// Runs the module: loads ground truth and predictions and writes each evaluation output that was requested.
/// </summary>
/// <param name="args">The command line arguments for the module.</param>
/// <param name="usagePrefix">The prefix to print before the usage string.</param>
/// <returns>True if the run was successful, false if the command line could not be parsed.</returns>
public override bool Run(string[] args, string usagePrefix)
{
    string groundTruthPath = string.Empty;
    string predictionsPath = string.Empty;
    string reportPath = string.Empty;
    string calibrationCurvePath = string.Empty;
    string rocCurvePath = string.Empty;
    string precisionRecallCurvePath = string.Empty;
    string positiveClassLabel = string.Empty;

    var commandLine = new CommandLineParser();
    commandLine.RegisterParameterHandler(
        "--ground-truth", "FILE", "File with ground truth labels", value => groundTruthPath = value, CommandLineParameterType.Required);
    commandLine.RegisterParameterHandler(
        "--predictions", "FILE", "File with label predictions", value => predictionsPath = value, CommandLineParameterType.Required);
    commandLine.RegisterParameterHandler(
        "--report", "FILE", "File to store the evaluation report", value => reportPath = value, CommandLineParameterType.Optional);
    commandLine.RegisterParameterHandler(
        "--calibration-curve", "FILE", "File to store the empirical calibration curve", value => calibrationCurvePath = value, CommandLineParameterType.Optional);
    commandLine.RegisterParameterHandler(
        "--roc-curve", "FILE", "File to store the receiver operating characteristic curve", value => rocCurvePath = value, CommandLineParameterType.Optional);
    commandLine.RegisterParameterHandler(
        "--precision-recall-curve", "FILE", "File to store the precision-recall curve", value => precisionRecallCurvePath = value, CommandLineParameterType.Optional);
    commandLine.RegisterParameterHandler(
        "--positive-class", "STRING", "Label of the positive class to use in curves", value => positiveClassLabel = value, CommandLineParameterType.Optional);

    bool parsed = commandLine.TryParse(args, usagePrefix);
    if (!parsed)
    {
        return false;
    }

    // The predictions are read with the label set of the first ground truth instance.
    var groundTruth = ClassifierPersistenceUtils.LoadLabeledFeatureValues(groundTruthPath);
    var predictions = ClassifierPersistenceUtils.LoadLabelDistributions(predictionsPath, groundTruth.First().LabelDistribution.LabelSet);

    // An evaluation needs at least two distinct class labels.
    bool tooFewLabels = predictions.First().LabelSet.Count < 2;
    if (tooFewLabels)
    {
        throw new InvalidFileFormatException("Ground truth and predictions must contain at least two distinct class labels.");
    }

    // Full predictive distributions and their modes (point estimates).
    var predictiveDistributions = predictions.Select(i => i.ToDictionary()).ToList();
    var predictivePointEstimates = predictions.Select(i => i.GetMode()).ToList();

    var evaluator = new ClassifierEvaluator<IList<LabeledFeatureValues>, LabeledFeatureValues, IList<LabelDistribution>, string>(
        Mappings.Classifier.ForEvaluation());

    // Evaluation report
    if (!string.IsNullOrEmpty(reportPath))
    {
        using (var reportWriter = new StreamWriter(reportPath))
        {
            this.WriteReportHeader(reportWriter, groundTruthPath, predictionsPath);
            this.WriteReport(reportWriter, evaluator, groundTruth, predictiveDistributions, predictivePointEstimates);
        }
    }

    // The positive class label is checked once and then shared by all curves.
    positiveClassLabel = this.CheckPositiveClassLabel(groundTruth, positiveClassLabel);

    // Empirical probability calibration curve
    if (!string.IsNullOrEmpty(calibrationCurvePath))
    {
        this.WriteCalibrationCurve(calibrationCurvePath, evaluator, groundTruth, predictiveDistributions, positiveClassLabel);
    }

    // Precision-recall curve
    if (!string.IsNullOrEmpty(precisionRecallCurvePath))
    {
        this.WritePrecisionRecallCurve(precisionRecallCurvePath, evaluator, groundTruth, predictiveDistributions, positiveClassLabel);
    }

    // Receiver operating characteristic curve
    if (!string.IsNullOrEmpty(rocCurvePath))
    {
        this.WriteRocCurve(rocCurvePath, evaluator, groundTruth, predictiveDistributions, positiveClassLabel);
    }

    return true;
}
|
||||
|
||||
#region Helper methods
|
||||
|
||||
/// <summary>
/// Computes the evaluation metrics and writes all sections of the evaluation report to the given writer.
/// </summary>
/// <param name="writer">The <see cref="StreamWriter"/> the report is written to.</param>
/// <param name="evaluator">The classifier evaluator.</param>
/// <param name="groundTruth">The ground truth.</param>
/// <param name="predictiveDistributions">The predictive distributions.</param>
/// <param name="predictedLabels">The predicted labels.</param>
private void WriteReport(
    StreamWriter writer,
    ClassifierEvaluator<IList<LabeledFeatureValues>, LabeledFeatureValues, IList<LabelDistribution>, string> evaluator,
    IList<LabeledFeatureValues> groundTruth,
    ICollection<IDictionary<string, double>> predictiveDistributions,
    IEnumerable<string> predictedLabels)
{
    // Confusion matrix of the point predictions
    var confusion = evaluator.ConfusionMatrix(groundTruth, predictedLabels);

    // Negative log probability, averaged over the number of predictive distributions
    var summedNegativeLogProbability = evaluator.Evaluate(groundTruth, predictiveDistributions, Metrics.NegativeLogProbability);
    double averageNegativeLogProbability = summedNegativeLogProbability / predictiveDistributions.Count;

    // M-measure (averaged pairwise AUC) together with the matrix of pairwise AUCs
    IDictionary<string, IDictionary<string, double>> pairwiseAuc;
    double pairwiseAverageAuc = evaluator.AreaUnderRocCurve(groundTruth, predictiveDistributions, out pairwiseAuc);

    // Per-label AUC as well as micro- and macro-averaged AUC
    double microAverageAuc, macroAverageAuc;
    int definedAucLabelCount;
    var perLabelAuc = this.ComputeLabelAuc(
        confusion, evaluator, groundTruth, predictiveDistributions, out microAverageAuc, out macroAverageAuc, out definedAucLabelCount);

    // Report sections, in the order in which they appear in the output
    this.WriteInstanceAveragedPerformance(writer, confusion, averageNegativeLogProbability, microAverageAuc);
    this.WriteClassAveragedPerformance(writer, confusion, pairwiseAverageAuc, macroAverageAuc, definedAucLabelCount);
    this.WriteIndividualClassPerformance(writer, confusion, perLabelAuc);
    this.WriteConfusionMatrix(writer, confusion);
    this.WriteAucMatrix(writer, pairwiseAuc);
}
|
||||
|
||||
/// <summary>
/// Computes the one-versus-rest AUC of every class label together with the micro- and macro-averaged AUCs.
/// </summary>
/// <param name="confusionMatrix">The confusion matrix.</param>
/// <param name="evaluator">The classifier evaluator.</param>
/// <param name="groundTruth">The ground truth.</param>
/// <param name="predictiveDistributions">The predictive distributions.</param>
/// <param name="microAuc">The micro-averaged area under the receiver operating characteristic curve.</param>
/// <param name="macroAuc">The macro-averaged area under the receiver operating characteristic curve.</param>
/// <param name="macroAucClassLabelCount">The number of class labels for which the AUC is defined.</param>
/// <returns>
/// The area under the receiver operating characteristic curve for each class label,
/// with NaN for labels where the evaluator rejected the computation.
/// </returns>
private IDictionary<string, double> ComputeLabelAuc(
    ConfusionMatrix<string> confusionMatrix,
    ClassifierEvaluator<IList<LabeledFeatureValues>, LabeledFeatureValues, IList<LabelDistribution>, string> evaluator,
    IList<LabeledFeatureValues> groundTruth,
    ICollection<IDictionary<string, double>> predictiveDistributions,
    out double microAuc,
    out double macroAuc,
    out int macroAucClassLabelCount)
{
    int totalInstances = predictiveDistributions.Count;
    var labels = confusionMatrix.ClassLabelSet.Elements.ToArray();
    var aucByLabel = new Dictionary<string, double>();

    // First pass: one-versus-rest AUC per label. Labels for which the evaluator throws an
    // ArgumentException are recorded as NaN and excluded from the count of defined AUCs.
    macroAucClassLabelCount = labels.Length;
    foreach (var label in labels)
    {
        double value;
        try
        {
            value = evaluator.AreaUnderRocCurve(label, groundTruth, predictiveDistributions);
        }
        catch (ArgumentException)
        {
            value = double.NaN;
            macroAucClassLabelCount--;
        }

        aucByLabel.Add(label, value);
    }

    // Second pass: the averages need the final count of defined AUCs, hence the separate loop.
    microAuc = 0;
    macroAuc = 0;
    foreach (var label in labels)
    {
        double value = aucByLabel[label];
        if (!double.IsNaN(value))
        {
            // Micro-average weights each label by its number of true instances; macro-average weights labels equally.
            microAuc += confusionMatrix.TrueLabelCount(label) * value / totalInstances;
            macroAuc += value / macroAucClassLabelCount;
        }
    }

    return aucByLabel;
}
|
||||
|
||||
/// <summary>
/// Emits the title banner of the evaluation report, followed by the current date
/// and the names of the two input files.
/// </summary>
/// <param name="writer">The <see cref="StreamWriter"/> to write to.</param>
/// <param name="groundTruthFileName">The name of the file containing the ground truth.</param>
/// <param name="predictionsFileName">The name of the file containing the predictions.</param>
private void WriteReportHeader(StreamWriter writer, string groundTruthFileName, string predictionsFileName)
{
    // Title banner
    writer.WriteLine();
    writer.WriteLine(" Classifier evaluation report ");
    writer.WriteLine("******************************");
    writer.WriteLine();

    // Provenance of the report: local time of creation and the input files
    DateTime reportDate = DateTime.Now;
    writer.WriteLine(" Date: {0}", reportDate);
    writer.WriteLine(" Ground truth: {0}", groundTruthFileName);
    writer.WriteLine(" Predictions: {0}", predictionsFileName);
}
|
||||
|
||||
/// <summary>
/// Emits the instance-averaged (micro-averaged) section of the report: precision, recall, F1,
/// instance counts, accuracy, error, AUC and log-loss.
/// </summary>
/// <param name="writer">The <see cref="StreamWriter"/> to write to.</param>
/// <param name="confusionMatrix">The confusion matrix.</param>
/// <param name="negativeLogProbability">The negative log-probability, reported as log-loss.</param>
/// <param name="microAuc">The micro-averaged AUC.</param>
private void WriteInstanceAveragedPerformance(
    StreamWriter writer,
    ConfusionMatrix<string> confusionMatrix,
    double negativeLogProbability,
    double microAuc)
{
    // Total = sum of the true label counts; correct = sum of the diagonal of the confusion matrix.
    long correct = 0;
    long total = 0;
    foreach (var index in confusionMatrix.ClassLabelSet.Indexes)
    {
        string label = confusionMatrix.ClassLabelSet.GetElementByIndex(index);
        total += confusionMatrix.TrueLabelCount(label);
        correct += confusionMatrix[label, label];
    }

    // Section title
    writer.WriteLine();
    writer.WriteLine(" Instance-averaged performance (micro-averages)");
    writer.WriteLine("================================================");
    writer.WriteLine();

    // Precision / recall / F1
    writer.WriteLine(" Precision = {0,10:0.0000}", confusionMatrix.MicroPrecision);
    writer.WriteLine(" Recall = {0,10:0.0000}", confusionMatrix.MicroRecall);
    writer.WriteLine(" F1 = {0,10:0.0000}", confusionMatrix.MicroF1);
    writer.WriteLine();

    // Counts, accuracy and its complement
    writer.WriteLine(" #Correct = {0,10}", correct);
    writer.WriteLine(" #Total = {0,10}", total);
    writer.WriteLine(" Accuracy = {0,10:0.0000}", confusionMatrix.MicroAccuracy);
    writer.WriteLine(" Error = {0,10:0.0000}", 1 - confusionMatrix.MicroAccuracy);
    writer.WriteLine();

    // Ranking quality
    writer.WriteLine(" AUC = {0,10:0.0000}", microAuc);
    writer.WriteLine();

    // Probabilistic quality
    writer.WriteLine(" Log-loss = {0,10:0.0000}", negativeLogProbability);
}
|
||||
|
||||
/// <summary>
/// Writes class-averaged (macro-averaged) performance results to a specified stream writer.
/// </summary>
/// <remarks>
/// Each metric is annotated with "[only x/y classes defined]" when the number of class labels
/// over which it was averaged is smaller than the total number of class labels.
/// </remarks>
/// <param name="writer">The <see cref="StreamWriter"/> to write to.</param>
/// <param name="confusionMatrix">The confusion matrix.</param>
/// <param name="auc">The M-measure (averaged pairwise AUC).</param>
/// <param name="macroAuc">The macro-averaged AUC.</param>
/// <param name="macroAucClassLabelCount">The number of distinct class labels used to compute macro-averaged AUC.</param>
private void WriteClassAveragedPerformance(
    StreamWriter writer,
    ConfusionMatrix<string> confusionMatrix,
    double auc,
    double macroAuc,
    int macroAucClassLabelCount)
{
    int classLabelCount = confusionMatrix.ClassLabelSet.Count;

    writer.WriteLine();
    writer.WriteLine(" Class-averaged performance (macro-averages)");
    writer.WriteLine("=============================================");
    writer.WriteLine();

    // Precision
    if (confusionMatrix.MacroPrecisionClassLabelCount < classLabelCount)
    {
        writer.WriteLine(
            " Precision = {0,10:0.0000} {1,10}",
            confusionMatrix.MacroPrecision,
            "[only " + confusionMatrix.MacroPrecisionClassLabelCount + "/" + classLabelCount + " classes defined]");
    }
    else
    {
        writer.WriteLine(" Precision = {0,10:0.0000}", confusionMatrix.MacroPrecision);
    }

    // Recall
    if (confusionMatrix.MacroRecallClassLabelCount < classLabelCount)
    {
        writer.WriteLine(
            " Recall = {0,10:0.0000} {1,10}",
            confusionMatrix.MacroRecall,
            "[only " + confusionMatrix.MacroRecallClassLabelCount + "/" + classLabelCount + " classes defined]");
    }
    else
    {
        writer.WriteLine(" Recall = {0,10:0.0000}", confusionMatrix.MacroRecall);
    }

    // F1
    if (confusionMatrix.MacroF1ClassLabelCount < classLabelCount)
    {
        writer.WriteLine(
            " F1 = {0,10:0.0000} {1,10}",
            confusionMatrix.MacroF1,
            "[only " + confusionMatrix.MacroF1ClassLabelCount + "/" + classLabelCount + " classes defined]");
    }
    else
    {
        writer.WriteLine(" F1 = {0,10:0.0000}", confusionMatrix.MacroF1);
    }

    // Accuracy and error. The annotation must be driven by the accuracy label count
    // (the value that is printed), not by the F1 label count.
    writer.WriteLine();
    if (confusionMatrix.MacroAccuracyClassLabelCount < classLabelCount)
    {
        writer.WriteLine(
            " Accuracy = {0,10:0.0000} {1,10}",
            confusionMatrix.MacroAccuracy,
            "[only " + confusionMatrix.MacroAccuracyClassLabelCount + "/" + classLabelCount + " classes defined]");
        writer.WriteLine(
            " Error = {0,10:0.0000} {1,10}",
            1 - confusionMatrix.MacroAccuracy,
            "[only " + confusionMatrix.MacroAccuracyClassLabelCount + "/" + classLabelCount + " classes defined]");
    }
    else
    {
        writer.WriteLine(" Accuracy = {0,10:0.0000}", confusionMatrix.MacroAccuracy);
        writer.WriteLine(" Error = {0,10:0.0000}", 1 - confusionMatrix.MacroAccuracy);
    }

    // Macro-averaged one-versus-rest AUC
    writer.WriteLine();
    if (macroAucClassLabelCount < classLabelCount)
    {
        writer.WriteLine(
            " AUC = {0,10:0.0000} {1,10}",
            macroAuc,
            "[only " + macroAucClassLabelCount + "/" + classLabelCount + " classes defined]");
    }
    else
    {
        writer.WriteLine(" AUC = {0,10:0.0000}", macroAuc);
    }

    // M-measure
    writer.WriteLine();
    writer.WriteLine(" M (pairwise AUC) = {0,10:0.0000}", auc);
}
|
||||
|
||||
/// <summary>
/// Emits a table with one row per class label showing its counts, precision, recall, F1 and AUC.
/// </summary>
/// <param name="writer">The <see cref="StreamWriter"/> to write to.</param>
/// <param name="confusionMatrix">The confusion matrix.</param>
/// <param name="auc">The per-class AUC, keyed by class label.</param>
private void WriteIndividualClassPerformance(
    StreamWriter writer,
    ConfusionMatrix<string> confusionMatrix,
    IDictionary<string, double> auc)
{
    const string HeaderFormat = " {0,5} {1,15} {2,10} {3,11} {4,9} {5,10} {6,10} {7,10} {8,10}";
    const string RowFormat = " {0,5} {1,15} {2,10} {3,11} {4,9} {5,10:0.0000} {6,10:0.0000} {7,10:0.0000} {8,10:0.0000}";

    // Section title
    writer.WriteLine();
    writer.WriteLine(" Performance on individual classes");
    writer.WriteLine("===================================");
    writer.WriteLine();

    // Column headings and separator
    writer.WriteLine(HeaderFormat, "Index", "Label", "#Truth", "#Predicted", "#Correct", "Precision", "Recall", "F1", "AUC");
    writer.WriteLine("----------------------------------------------------------------------------------------------------");

    // One row per class label; the printed index is one-based
    var labelSet = confusionMatrix.ClassLabelSet;
    foreach (var index in labelSet.Indexes)
    {
        string label = labelSet.GetElementByIndex(index);

        writer.WriteLine(
            RowFormat,
            index + 1,
            label,
            confusionMatrix.TrueLabelCount(label),
            confusionMatrix.PredictedLabelCount(label),
            confusionMatrix[label, label],
            confusionMatrix.Precision(label),
            confusionMatrix.Recall(label),
            confusionMatrix.F1(label),
            auc[label]);
    }
}
|
||||
|
||||
/// <summary>
/// Emits the confusion matrix section of the report: a title followed by
/// the string representation of the matrix.
/// </summary>
/// <param name="writer">The <see cref="StreamWriter"/> to write to.</param>
/// <param name="confusionMatrix">The confusion matrix.</param>
private void WriteConfusionMatrix(StreamWriter writer, ConfusionMatrix<string> confusionMatrix)
{
    // Section title
    writer.WriteLine();
    writer.WriteLine(" Confusion matrix");
    writer.WriteLine("==================");
    writer.WriteLine();

    // The matrix renders itself
    writer.WriteLine(confusionMatrix);
}
|
||||
|
||||
/// <summary>
/// Emits the matrix of pairwise AUC metrics as a table with true labels as rows and
/// predicted labels as columns.
/// </summary>
/// <param name="writer">The <see cref="StreamWriter"/> to write to.</param>
/// <param name="aucMatrix">The matrix containing the pairwise AUC metrics, keyed by true label and then by predicted label.</param>
private void WriteAucMatrix(StreamWriter writer, IDictionary<string, IDictionary<string, double>> aucMatrix)
{
    const int MaxLabelWidth = 20;
    const int MaxValueWidth = 6;

    // Section title
    writer.WriteLine();
    writer.WriteLine(" Pairwise AUC matrix");
    writer.WriteLine("=====================");
    writer.WriteLine();

    // Column widths: slot 0 is the row-label column, slot c + 1 belongs to label c.
    string[] labels = aucMatrix.Keys.ToArray();
    int labelCount = aucMatrix.Count;
    var widths = new int[labelCount + 1];
    for (int column = 0; column < labelCount; column++)
    {
        // Labels wider than MaxLabelWidth are truncated when printed
        int cappedLabelWidth = Math.Min(labels[column].Length, MaxLabelWidth);

        // A value column is never narrower than MaxValueWidth
        widths[column + 1] = Math.Max(MaxValueWidth, cappedLabelWidth);

        // The row-label column fits the widest (capped) label
        widths[0] = Math.Max(widths[0], cappedLabelWidth);
    }

    // Title row
    string titleFormat = string.Format("{{0,{0}}} \\ Prediction ->", widths[0]);
    writer.WriteLine(titleFormat, "Truth");

    // Column labels, preceded by an empty cell above the row labels
    this.WriteLabel(writer, string.Empty, widths[0]);
    for (int column = 0; column < labelCount; column++)
    {
        this.WriteLabel(writer, labels[column], widths[column + 1]);
    }

    writer.WriteLine();

    // One row per true label
    for (int row = 0; row < labelCount; row++)
    {
        string rowLabel = labels[row];
        this.WriteLabel(writer, rowLabel, widths[0]);

        for (int column = 0; column < labelCount; column++)
        {
            string columnLabel = labels[column];

            // A label paired with itself has no pairwise AUC; -1 is the marker
            // WriteAucValue is given for such cells.
            double cell = string.Equals(rowLabel, columnLabel) ? -1 : aucMatrix[rowLabel][columnLabel];
            this.WriteAucValue(writer, cell, widths[column + 1]);
        }

        writer.WriteLine();
    }
}
|
||||
|
||||
/// <summary>
/// Computes the probability calibration curve of the positive class and saves it, preceded by
/// a commented header that includes the mean absolute calibration error, to the given file.
/// </summary>
/// <param name="fileName">The name of the file to write the calibration plot to.</param>
/// <param name="evaluator">The classifier evaluator.</param>
/// <param name="groundTruth">The ground truth.</param>
/// <param name="predictiveDistributions">The predictive distributions.</param>
/// <param name="positiveClassLabel">The label of the positive class.</param>
private void WriteCalibrationCurve(
    string fileName,
    ClassifierEvaluator<IList<LabeledFeatureValues>, LabeledFeatureValues, IList<LabelDistribution>, string> evaluator,
    IList<LabeledFeatureValues> groundTruth,
    IList<IDictionary<string, double>> predictiveDistributions,
    string positiveClassLabel)
{
    Debug.Assert(predictiveDistributions != null, "The predictive distributions must not be null.");
    Debug.Assert(predictiveDistributions.Count > 0, "The predictive distributions must not be empty.");
    Debug.Assert(positiveClassLabel != null, "The label of the positive class must not be null.");

    // The curve and its error are computed before the output file is opened
    var curve = evaluator.CalibrationCurve(positiveClassLabel, groundTruth, predictiveDistributions);
    double meanAbsoluteError = curve
        .Select(point => Metrics.AbsoluteError(point.EmpiricalProbability, point.PredictedProbability))
        .Average();

    using (var output = new StreamWriter(fileName))
    {
        // Commented header
        output.WriteLine("# Empirical probability calibration plot");
        output.WriteLine("#");
        output.WriteLine("# Class '" + positiveClassLabel + "' (versus the rest)");
        output.WriteLine("# Calibration error = {0} (mean absolute error)", meanAbsoluteError);
        output.WriteLine("#");
        output.WriteLine("# Predicted probability, empirical probability");

        // One line per point of the curve
        foreach (var point in curve)
        {
            output.WriteLine(point);
        }
    }
}
|
||||
|
||||
/// <summary>
/// Computes the precision-recall curve of the positive class and saves it, preceded by
/// a commented header, to the given file.
/// </summary>
/// <param name="fileName">The name of the file to write the precision-recall curve to.</param>
/// <param name="evaluator">The classifier evaluator.</param>
/// <param name="groundTruth">The ground truth.</param>
/// <param name="predictiveDistributions">The predictive distributions.</param>
/// <param name="positiveClassLabel">The label of the positive class.</param>
private void WritePrecisionRecallCurve(
    string fileName,
    ClassifierEvaluator<IList<LabeledFeatureValues>, LabeledFeatureValues, IList<LabelDistribution>, string> evaluator,
    IList<LabeledFeatureValues> groundTruth,
    IList<IDictionary<string, double>> predictiveDistributions,
    string positiveClassLabel)
{
    Debug.Assert(predictiveDistributions != null, "The predictive distributions must not be null.");
    Debug.Assert(predictiveDistributions.Count > 0, "The predictive distributions must not be empty.");
    Debug.Assert(positiveClassLabel != null, "The label of the positive class must not be null.");

    // The curve is computed before the output file is opened
    var curve = evaluator.PrecisionRecallCurve(positiveClassLabel, groundTruth, predictiveDistributions);

    string[] header =
    {
        "# Precision-recall curve",
        "#",
        "# Class '" + positiveClassLabel + "' (versus the rest)",
        "#",
        "# precision (P), Recall (R)"
    };

    using (var output = new StreamWriter(fileName))
    {
        foreach (string headerLine in header)
        {
            output.WriteLine(headerLine);
        }

        // One line per point of the curve
        foreach (var point in curve)
        {
            output.WriteLine(point);
        }
    }
}
|
||||
|
||||
/// <summary>
/// Computes the receiver operating characteristic curve of the positive class and saves it,
/// preceded by a commented header, to the given file.
/// </summary>
/// <param name="fileName">The name of the file to write the receiver operating characteristic curve to.</param>
/// <param name="evaluator">The classifier evaluator.</param>
/// <param name="groundTruth">The ground truth.</param>
/// <param name="predictiveDistributions">The predictive distributions.</param>
/// <param name="positiveClassLabel">The label of the positive class.</param>
private void WriteRocCurve(
    string fileName,
    ClassifierEvaluator<IList<LabeledFeatureValues>, LabeledFeatureValues, IList<LabelDistribution>, string> evaluator,
    IList<LabeledFeatureValues> groundTruth,
    IList<IDictionary<string, double>> predictiveDistributions,
    string positiveClassLabel)
{
    Debug.Assert(predictiveDistributions != null, "The predictive distributions must not be null.");
    Debug.Assert(predictiveDistributions.Count > 0, "The predictive distributions must not be empty.");
    Debug.Assert(positiveClassLabel != null, "The label of the positive class must not be null.");

    // The curve is computed before the output file is opened
    var curve = evaluator.ReceiverOperatingCharacteristicCurve(positiveClassLabel, groundTruth, predictiveDistributions);

    string[] header =
    {
        "# Receiver operating characteristic (ROC) curve",
        "#",
        "# Class '" + positiveClassLabel + "' (versus the rest)",
        "#",
        "# False positive rate (FPR), True positive rate (TPR)"
    };

    using (var output = new StreamWriter(fileName))
    {
        foreach (string headerLine in header)
        {
            output.WriteLine(headerLine);
        }

        // One line per point of the curve
        foreach (var point in curve)
        {
            output.WriteLine(point);
        }
    }
}
|
||||
|
||||
/// <summary>
/// Writes a single right-aligned AUC cell to a specified stream writer.
/// </summary>
/// <remarks>
/// Positive values are printed using the invariant culture, negative values as ".",
/// NaN as "NaN" and zero as "0". The text is cut to at most <paramref name="width"/>
/// characters and then left-padded to <paramref name="width"/> + 2 characters.
/// </remarks>
/// <param name="writer">The <see cref="StreamWriter"/> to write to.</param>
/// <param name="auc">The AUC value.</param>
/// <param name="width">The width in characters used to print the value.</param>
private void WriteAucValue(StreamWriter writer, double auc, int width)
{
    string cell;
    if (double.IsNaN(auc))
    {
        cell = "NaN";
    }
    else if (auc > 0)
    {
        cell = auc.ToString(CultureInfo.InvariantCulture);
    }
    else if (auc < 0)
    {
        cell = ".";
    }
    else
    {
        cell = "0";
    }

    // Truncate first, then right-align with two extra characters of padding
    if (cell.Length > width)
    {
        cell = cell.Substring(0, width);
    }

    writer.Write(cell.PadLeft(width + 2));
}
|
||||
|
||||
/// <summary>
/// Writes a single right-aligned label cell to a specified stream writer.
/// </summary>
/// <remarks>
/// The label is cut to at most <paramref name="width"/> characters and then
/// left-padded to <paramref name="width"/> + 2 characters.
/// </remarks>
/// <param name="writer">The <see cref="StreamWriter"/> to write to.</param>
/// <param name="label">The label.</param>
/// <param name="width">The width in characters used to print the label.</param>
private void WriteLabel(StreamWriter writer, string label, int width)
{
    string cell = label;
    if (cell.Length > width)
    {
        cell = cell.Substring(0, width);
    }

    writer.Write(cell.PadLeft(width + 2));
}
|
||||
|
||||
/// <summary>
/// Resolves and validates the label of the positive class.
/// </summary>
/// <param name="groundTruth">The ground truth; the label set of its first instance is consulted.</param>
/// <param name="positiveClassLabel">
/// An optional positive class label provided by the user. When null or empty,
/// the first element of the label set is used instead.
/// </param>
/// <returns>The actually used positive class label.</returns>
/// <exception cref="ArgumentException">
/// Thrown if a non-empty <paramref name="positiveClassLabel"/> is not contained in the label set.
/// </exception>
private string CheckPositiveClassLabel(IList<LabeledFeatureValues> groundTruth, string positiveClassLabel = null)
{
    Debug.Assert(groundTruth != null, "The ground truth labels must not be null.");
    Debug.Assert(groundTruth.Count > 0, "There must be at least one ground truth label.");

    // No label given: fall back to the first label of the label set
    if (string.IsNullOrEmpty(positiveClassLabel))
    {
        return groundTruth[0].LabelDistribution.LabelSet.Elements.First();
    }

    // A given label must be one of the known class labels
    bool isKnownLabel = groundTruth.First().LabelDistribution.LabelSet.Contains(positiveClassLabel);
    if (isKnownLabel)
    {
        return positiveClassLabel;
    }

    throw new ArgumentException(
        "The label '" + positiveClassLabel + "' of the positive class is not a valid class label.");
}
|
||||
|
||||
#endregion
|
||||
}
|
||||
}
|
|
@ -1,153 +0,0 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
namespace Microsoft.ML.Probabilistic.Learners.Runners
{
    using System;
    using System.Collections.Generic;
    using System.Diagnostics;
    using System.Linq;

    using Microsoft.ML.Probabilistic.Learners.Mappings;
    using Microsoft.ML.Probabilistic.Math;

    /// <summary>
    /// A command-line module to cross-validate a multi-class Bayes point machine classifier on given data.
    /// </summary>
    internal class MulticlassBayesPointMachineClassifierCrossValidationModule : CommandLineModule
    {
        /// <summary>
        /// Runs the module: parses the arguments, shuffles the data set with a fixed seed and,
        /// for every fold, trains a classifier on the remaining data, evaluates it on the fold
        /// and saves the metrics collected so far.
        /// </summary>
        /// <param name="args">The command line arguments for the module.</param>
        /// <param name="usagePrefix">The prefix to print before the usage string.</param>
        /// <returns>
        /// True if the run was successful, false if the arguments could not be parsed
        /// or the number of folds is not positive.
        /// </returns>
        public override bool Run(string[] args, string usagePrefix)
        {
            string dataSetFile = string.Empty;
            string resultsFile = string.Empty;
            int crossValidationFoldCount = 5;
            int iterationCount = BayesPointMachineClassifierTrainingSettings.IterationCountDefault;
            int batchCount = BayesPointMachineClassifierTrainingSettings.BatchCountDefault;
            bool computeModelEvidence = BayesPointMachineClassifierTrainingSettings.ComputeModelEvidenceDefault;

            var parser = new CommandLineParser();
            parser.RegisterParameterHandler("--data-set", "FILE", "File with training data", v => dataSetFile = v, CommandLineParameterType.Required);
            parser.RegisterParameterHandler("--results", "FILE", "File with cross-validation results", v => resultsFile = v, CommandLineParameterType.Required);
            parser.RegisterParameterHandler("--folds", "NUM", "Number of cross-validation folds (defaults to " + crossValidationFoldCount + ")", v => crossValidationFoldCount = v, CommandLineParameterType.Optional);
            parser.RegisterParameterHandler("--iterations", "NUM", "Number of training algorithm iterations (defaults to " + iterationCount + ")", v => iterationCount = v, CommandLineParameterType.Optional);
            parser.RegisterParameterHandler("--batches", "NUM", "Number of batches to split the training data into (defaults to " + batchCount + ")", v => batchCount = v, CommandLineParameterType.Optional);
            parser.RegisterParameterHandler("--compute-evidence", "Compute model evidence (defaults to " + computeModelEvidence + ")", () => computeModelEvidence = true);

            if (!parser.TryParse(args, usagePrefix))
            {
                return false;
            }

            // The fold count is used as a divisor below: zero would raise a DivideByZeroException
            // and a negative value would silently run no fold at all.
            if (crossValidationFoldCount < 1)
            {
                Console.WriteLine("The number of cross-validation folds must be positive.");
                return false;
            }

            // Load and shuffle data
            var dataSet = ClassifierPersistenceUtils.LoadLabeledFeatureValues(dataSetFile);
            BayesPointMachineClassifierModuleUtilities.WriteDataSetInfo(dataSet);

            // Fixed seed, so that the same folds are produced on every run
            Rand.Restart(562);
            Rand.Shuffle(dataSet);

            // Create evaluator
            var evaluatorMapping = Mappings.Classifier.ForEvaluation();
            var evaluator = new ClassifierEvaluator<IList<LabeledFeatureValues>, LabeledFeatureValues, IList<LabelDistribution>, string>(evaluatorMapping);

            // Create performance metrics (one entry per completed fold)
            var accuracy = new List<double>();
            var negativeLogProbability = new List<double>();
            var auc = new List<double>();
            var evidence = new List<double>();
            var iterationCounts = new List<double>();
            var trainingTime = new List<double>();

            // Run cross-validation
            int validationSetSize = dataSet.Count / crossValidationFoldCount;
            Console.WriteLine("Running {0}-fold cross-validation on {1}", crossValidationFoldCount, dataSetFile);

            // TODO: Use chained mapping to implement cross-validation
            for (int fold = 0; fold < crossValidationFoldCount; fold++)
            {
                // Construct training and validation sets for fold.
                // The last fold additionally takes the remainder of the integer division.
                int validationSetStart = fold * validationSetSize;
                int validationSetEnd = (fold + 1 == crossValidationFoldCount)
                                           ? dataSet.Count
                                           : (fold + 1) * validationSetSize;

                var trainingSet = new List<LabeledFeatureValues>();
                var validationSet = new List<LabeledFeatureValues>();

                for (int instance = 0; instance < dataSet.Count; instance++)
                {
                    if (validationSetStart <= instance && instance < validationSetEnd)
                    {
                        validationSet.Add(dataSet[instance]);
                    }
                    else
                    {
                        trainingSet.Add(dataSet[instance]);
                    }
                }

                // Print info
                Console.WriteLine(" Fold {0} [validation set instances {1} - {2}]", fold + 1, validationSetStart, validationSetEnd - 1);

                // Create classifier
                var classifier = BayesPointMachineClassifier.CreateMulticlassClassifier(Mappings.Classifier);
                classifier.Settings.Training.IterationCount = iterationCount;
                classifier.Settings.Training.BatchCount = batchCount;
                classifier.Settings.Training.ComputeModelEvidence = computeModelEvidence;

                // Track the number of iterations completed by the training algorithm
                int currentIterationCount = 0;
                classifier.IterationChanged += (sender, eventArgs) => { currentIterationCount = eventArgs.CompletedIterationCount; };

                // Train classifier
                var stopWatch = Stopwatch.StartNew();
                classifier.Train(trainingSet);
                stopWatch.Stop();

                // Produce predictions; the predicted label is the one with the largest probability
                var predictions = classifier.PredictDistribution(validationSet).ToList();
                var predictedLabels = predictions.Select(
                    prediction => prediction.Aggregate((aggregate, next) => next.Value > aggregate.Value ? next : aggregate).Key).ToList();

                // Iteration count
                iterationCounts.Add(currentIterationCount);

                // Training time
                trainingTime.Add(stopWatch.ElapsedMilliseconds);

                // Compute accuracy
                accuracy.Add(1 - (evaluator.Evaluate(validationSet, predictedLabels, Metrics.ZeroOneError) / predictions.Count));

                // Compute mean negative log probability
                negativeLogProbability.Add(evaluator.Evaluate(validationSet, predictions, Metrics.NegativeLogProbability) / predictions.Count);

                // Compute M-measure (averaged pairwise AUC)
                auc.Add(evaluator.AreaUnderRocCurve(validationSet, predictions));

                // Compute log evidence if desired
                evidence.Add(computeModelEvidence ? classifier.LogModelEvidence : double.NaN);

                // Print and persist performance metrics; the results file is rewritten
                // after every fold with all metrics collected so far
                Console.WriteLine(
                    " Accuracy = {0,5:0.0000} NegLogProb = {1,5:0.0000} AUC = {2,5:0.0000}{3} Iterations = {4} Training time = {5}",
                    accuracy[fold],
                    negativeLogProbability[fold],
                    auc[fold],
                    computeModelEvidence ? string.Format(" Log evidence = {0,5:0.0000}", evidence[fold]) : string.Empty,
                    iterationCounts[fold],
                    BayesPointMachineClassifierModuleUtilities.FormatElapsedTime(trainingTime[fold]));

                BayesPointMachineClassifierModuleUtilities.SavePerformanceMetrics(
                    resultsFile, accuracy, negativeLogProbability, auc, evidence, iterationCounts, trainingTime);
            }

            return true;
        }
    }
}
|
|
@ -1,54 +0,0 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
namespace Microsoft.ML.Probabilistic.Learners.Runners
{
    using System.Collections.Generic;

    /// <summary>
    /// A command-line module which loads a previously trained multi-class Bayes point machine classifier,
    /// trains it incrementally on some given data and saves the resulting model.
    /// </summary>
    internal class MulticlassBayesPointMachineClassifierIncrementalTrainingModule : CommandLineModule
    {
        /// <summary>
        /// Runs the module.
        /// </summary>
        /// <param name="args">The command line arguments for the module.</param>
        /// <param name="usagePrefix">The prefix to print before the usage string.</param>
        /// <returns>True if the run was successful, false if the arguments could not be parsed.</returns>
        public override bool Run(string[] args, string usagePrefix)
        {
            // Values filled in by the command-line parser
            string trainingSetFile = string.Empty;
            string inputModelFile = string.Empty;
            string outputModelFile = string.Empty;
            int iterationCount = BayesPointMachineClassifierTrainingSettings.IterationCountDefault;
            int batchCount = BayesPointMachineClassifierTrainingSettings.BatchCountDefault;

            var commandLine = new CommandLineParser();
            commandLine.RegisterParameterHandler(
                "--training-set",
                "FILE",
                "File with training data",
                v => trainingSetFile = v,
                CommandLineParameterType.Required);
            commandLine.RegisterParameterHandler(
                "--input-model",
                "FILE",
                "File with the trained multi-class Bayes point machine model",
                v => inputModelFile = v,
                CommandLineParameterType.Required);
            commandLine.RegisterParameterHandler(
                "--model",
                "FILE",
                "File to store the incrementally trained multi-class Bayes point machine model",
                v => outputModelFile = v,
                CommandLineParameterType.Required);
            commandLine.RegisterParameterHandler(
                "--iterations",
                "NUM",
                "Number of training algorithm iterations (defaults to " + iterationCount + ")",
                v => iterationCount = v,
                CommandLineParameterType.Optional);
            commandLine.RegisterParameterHandler(
                "--batches",
                "NUM",
                "Number of batches to split the training data into (defaults to " + batchCount + ")",
                v => batchCount = v,
                CommandLineParameterType.Optional);

            bool parsed = commandLine.TryParse(args, usagePrefix);
            if (!parsed)
            {
                return false;
            }

            // Load the additional training data and report on it
            var trainingSet = ClassifierPersistenceUtils.LoadLabeledFeatureValues(trainingSetFile);
            BayesPointMachineClassifierModuleUtilities.WriteDataSetInfo(trainingSet);

            // Restore the existing model and apply the requested training settings
            var classifier = BayesPointMachineClassifier.LoadMulticlassClassifier<IList<LabeledFeatureValues>, LabeledFeatureValues, IList<LabelDistribution>, string, IDictionary<string, double>>(inputModelFile);
            var trainingSettings = classifier.Settings.Training;
            trainingSettings.IterationCount = iterationCount;
            trainingSettings.BatchCount = batchCount;

            // Continue training, then persist the updated model
            classifier.TrainIncremental(trainingSet);
            classifier.Save(outputModelFile);

            return true;
        }
    }
}
|
|
@ -1,50 +0,0 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
namespace Microsoft.ML.Probabilistic.Learners.Runners
|
||||
{
|
||||
using System.Collections.Generic;
|
||||
|
||||
/// <summary>
/// Command-line module which loads a trained multi-class Bayes point machine classifier,
/// predicts label distributions for a test set and stores the predictions in a file.
/// </summary>
internal class MulticlassBayesPointMachineClassifierPredictionModule : CommandLineModule
{
    /// <summary>
    /// Parses the module arguments and, if they are valid, runs prediction on the test set.
    /// </summary>
    /// <param name="args">The command line arguments for the module.</param>
    /// <param name="usagePrefix">The prefix to print before the usage string.</param>
    /// <returns>False if the arguments could not be parsed; true once the predictions have been saved.</returns>
    public override bool Run(string[] args, string usagePrefix)
    {
        string testDataPath = string.Empty;
        string trainedModelPath = string.Empty;
        string outputPath = string.Empty;

        // All three parameters are mandatory; each handler just records the supplied file name.
        var argumentParser = new CommandLineParser();
        argumentParser.RegisterParameterHandler(
            "--test-set",
            "FILE",
            "File with test data",
            path => testDataPath = path,
            CommandLineParameterType.Required);
        argumentParser.RegisterParameterHandler(
            "--model",
            "FILE",
            "File with a trained multi-class Bayes point machine model",
            path => trainedModelPath = path,
            CommandLineParameterType.Required);
        argumentParser.RegisterParameterHandler(
            "--predictions",
            "FILE",
            "File to store predictions for the test data",
            path => outputPath = path,
            CommandLineParameterType.Required);

        bool argumentsAreValid = argumentParser.TryParse(args, usagePrefix);
        if (!argumentsAreValid)
        {
            return false;
        }

        // The test data is loaded (and its summary written) before the model is touched.
        var testData = ClassifierPersistenceUtils.LoadLabeledFeatureValues(testDataPath);
        BayesPointMachineClassifierModuleUtilities.WriteDataSetInfo(testData);

        var trainedClassifier = BayesPointMachineClassifier
            .LoadMulticlassClassifier<IList<LabeledFeatureValues>, LabeledFeatureValues, IList<LabelDistribution>, string, IDictionary<string, double>>(trainedModelPath);

        // Predict the label distributions and persist them straight away.
        ClassifierPersistenceUtils.SaveLabelDistributions(outputPath, trainedClassifier.PredictDistribution(testData));

        return true;
    }
}
|
||||
}
|
|
@ -1,41 +0,0 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
namespace Microsoft.ML.Probabilistic.Learners.Runners
|
||||
{
|
||||
using System.Collections.Generic;
|
||||
|
||||
/// <summary>
/// Command-line module which loads a trained multi-class Bayes point machine classifier
/// and hands it to the weight-sampling utility together with an output file name.
/// </summary>
internal class MulticlassBayesPointMachineClassifierSampleWeightsModule : CommandLineModule
{
    /// <summary>
    /// Parses the module arguments and, if they are valid, samples weights from the loaded model.
    /// </summary>
    /// <param name="args">The command line arguments for the module.</param>
    /// <param name="usagePrefix">The prefix to print before the usage string.</param>
    /// <returns>False if the arguments could not be parsed; true once sampling has completed.</returns>
    public override bool Run(string[] args, string usagePrefix)
    {
        string trainedModelPath = string.Empty;
        string weightSamplesPath = string.Empty;

        // Both parameters are mandatory; each handler just records the supplied file name.
        var argumentParser = new CommandLineParser();
        argumentParser.RegisterParameterHandler(
            "--model",
            "FILE",
            "File with a trained multi-class Bayes point machine model",
            path => trainedModelPath = path,
            CommandLineParameterType.Required);
        argumentParser.RegisterParameterHandler(
            "--samples",
            "FILE",
            "File to store samples of the weights",
            path => weightSamplesPath = path,
            CommandLineParameterType.Required);

        bool argumentsAreValid = argumentParser.TryParse(args, usagePrefix);
        if (!argumentsAreValid)
        {
            return false;
        }

        var trainedClassifier = BayesPointMachineClassifier
            .LoadMulticlassClassifier<IList<LabeledFeatureValues>, LabeledFeatureValues, IList<LabelDistribution>, string, IDictionary<string, double>>(trainedModelPath);

        BayesPointMachineClassifierModuleUtilities.SampleWeights(trainedClassifier, weightSamplesPath);

        return true;
    }
}
|
||||
}
|
|
@ -1,50 +0,0 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
namespace Microsoft.ML.Probabilistic.Learners.Runners
|
||||
{
|
||||
/// <summary>
/// Command-line module which creates a multi-class Bayes point machine classifier with the
/// requested training settings and runs the training diagnostics utility on the given data.
/// </summary>
internal class MulticlassBayesPointMachineClassifierTrainingDiagnosticsModule : CommandLineModule
{
    /// <summary>
    /// Parses the module arguments and, if they are valid, diagnoses training on the training set.
    /// </summary>
    /// <param name="args">The command line arguments for the module.</param>
    /// <param name="usagePrefix">The prefix to print before the usage string.</param>
    /// <returns>False if the arguments could not be parsed; true once the diagnostics have run.</returns>
    public override bool Run(string[] args, string usagePrefix)
    {
        string trainingDataPath = string.Empty;
        string parameterChangesPath = string.Empty;
        string trainedModelPath = string.Empty;
        int iterations = BayesPointMachineClassifierTrainingSettings.IterationCountDefault;
        int batches = BayesPointMachineClassifierTrainingSettings.BatchCountDefault;

        // The help texts quote the library defaults, so build them before any value can be overridden.
        string iterationsHelp = "Number of training algorithm iterations (defaults to " + iterations + ")";
        string batchesHelp = "Number of batches to split the training data into (defaults to " + batches + ")";

        // Only the training set is mandatory; result and model files are optional outputs.
        var argumentParser = new CommandLineParser();
        argumentParser.RegisterParameterHandler(
            "--training-set",
            "FILE",
            "File with training data",
            path => trainingDataPath = path,
            CommandLineParameterType.Required);
        argumentParser.RegisterParameterHandler(
            "--results",
            "FILE",
            "File to store the maximum parameter differences",
            path => parameterChangesPath = path,
            CommandLineParameterType.Optional);
        argumentParser.RegisterParameterHandler(
            "--model",
            "FILE",
            "File to store the trained multi-class Bayes point machine model",
            path => trainedModelPath = path,
            CommandLineParameterType.Optional);
        argumentParser.RegisterParameterHandler(
            "--iterations",
            "NUM",
            iterationsHelp,
            count => iterations = count,
            CommandLineParameterType.Optional);
        argumentParser.RegisterParameterHandler(
            "--batches",
            "NUM",
            batchesHelp,
            count => batches = count,
            CommandLineParameterType.Optional);

        bool argumentsAreValid = argumentParser.TryParse(args, usagePrefix);
        if (!argumentsAreValid)
        {
            return false;
        }

        var trainingData = ClassifierPersistenceUtils.LoadLabeledFeatureValues(trainingDataPath);
        BayesPointMachineClassifierModuleUtilities.WriteDataSetInfo(trainingData);

        // A fresh (untrained) classifier is configured with the requested settings.
        var untrainedClassifier = BayesPointMachineClassifier.CreateMulticlassClassifier(Mappings.Classifier);
        var trainingSettings = untrainedClassifier.Settings.Training;
        trainingSettings.IterationCount = iterations;
        trainingSettings.BatchCount = batches;

        BayesPointMachineClassifierModuleUtilities.DiagnoseClassifier(
            untrainedClassifier,
            trainingData,
            parameterChangesPath,
            trainedModelPath);

        return true;
    }
}
|
||||
}
|
|
@ -1,63 +0,0 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
namespace Microsoft.ML.Probabilistic.Learners.Runners
|
||||
{
|
||||
using System;
|
||||
using System.Linq;
|
||||
|
||||
/// <summary>
/// Command-line module which trains a multi-class Bayes point machine classifier
/// on the given training data and saves the trained model to a file.
/// </summary>
internal class MulticlassBayesPointMachineClassifierTrainingModule : CommandLineModule
{
    /// <summary>
    /// Parses the module arguments and, if they are valid, trains and saves the classifier.
    /// </summary>
    /// <param name="args">The command line arguments for the module.</param>
    /// <param name="usagePrefix">The prefix to print before the usage string.</param>
    /// <returns>False if the arguments could not be parsed; true once the model has been saved.</returns>
    public override bool Run(string[] args, string usagePrefix)
    {
        string trainingDataPath = string.Empty;
        string trainedModelPath = string.Empty;
        int iterations = BayesPointMachineClassifierTrainingSettings.IterationCountDefault;
        int batches = BayesPointMachineClassifierTrainingSettings.BatchCountDefault;
        bool withEvidence = BayesPointMachineClassifierTrainingSettings.ComputeModelEvidenceDefault;

        // The help texts quote the library defaults, so build them before any value can be overridden.
        string iterationsHelp = "Number of training algorithm iterations (defaults to " + iterations + ")";
        string batchesHelp = "Number of batches to split the training data into (defaults to " + batches + ")";
        string evidenceHelp = "Compute model evidence (defaults to " + withEvidence + ")";

        var argumentParser = new CommandLineParser();
        argumentParser.RegisterParameterHandler(
            "--training-set",
            "FILE",
            "File with training data",
            path => trainingDataPath = path,
            CommandLineParameterType.Required);
        argumentParser.RegisterParameterHandler(
            "--model",
            "FILE",
            "File to store the trained multi-class Bayes point machine model",
            path => trainedModelPath = path,
            CommandLineParameterType.Required);
        argumentParser.RegisterParameterHandler(
            "--iterations",
            "NUM",
            iterationsHelp,
            count => iterations = count,
            CommandLineParameterType.Optional);
        argumentParser.RegisterParameterHandler(
            "--batches",
            "NUM",
            batchesHelp,
            count => batches = count,
            CommandLineParameterType.Optional);

        // A value-less switch: its presence turns evidence computation on.
        argumentParser.RegisterParameterHandler("--compute-evidence", evidenceHelp, () => withEvidence = true);

        bool argumentsAreValid = argumentParser.TryParse(args, usagePrefix);
        if (!argumentsAreValid)
        {
            return false;
        }

        var trainingData = ClassifierPersistenceUtils.LoadLabeledFeatureValues(trainingDataPath);
        BayesPointMachineClassifierModuleUtilities.WriteDataSetInfo(trainingData);

        // The mapping is built from the feature set of the first instance (null for an empty training set).
        var sharedFeatureSet = trainingData.Count == 0 ? null : trainingData.First().FeatureSet;
        var classifier = BayesPointMachineClassifier.CreateMulticlassClassifier(new ClassifierMapping(sharedFeatureSet));

        classifier.Settings.Training.IterationCount = iterations;
        classifier.Settings.Training.BatchCount = batches;
        classifier.Settings.Training.ComputeModelEvidence = withEvidence;

        classifier.Train(trainingData);

        // The evidence is reported only when the classifier settings say it was computed.
        bool evidenceWasComputed = classifier.Settings.Training.ComputeModelEvidence;
        if (evidenceWasComputed)
        {
            Console.WriteLine("Log evidence = {0,10:0.0000}", classifier.LogModelEvidence);
        }

        classifier.Save(trainedModelPath);

        return true;
    }
}
|
||||
}
|
|
@ -1,50 +0,0 @@
|
|||
<Project Sdk="Microsoft.NET.Sdk">
|
||||
<Import Project="$(MSBuildThisFileDirectory)..\..\..\..\build\common.props" />
|
||||
|
||||
<PropertyGroup>
|
||||
<OutputType>Exe</OutputType>
|
||||
<AssemblyName>Learner</AssemblyName>
|
||||
<ErrorReport>prompt</ErrorReport>
|
||||
<Prefer32Bit>false</Prefer32Bit>
|
||||
<DefineConstants>TRACE</DefineConstants>
|
||||
<RootNamespace>Microsoft.ML.Probabilistic.Learners.Runners</RootNamespace>
|
||||
<Configurations>Debug;Release;DebugFull;DebugCore;ReleaseFull;ReleaseCore</Configurations>
|
||||
</PropertyGroup>
|
||||
<Choose>
|
||||
<When Condition="'$(Configuration)'=='DebugFull' OR '$(Configuration)'=='ReleaseFull'">
|
||||
<PropertyGroup>
|
||||
<TargetFramework>net461</TargetFramework>
|
||||
</PropertyGroup>
|
||||
</When>
|
||||
<When Condition="'$(Configuration)'=='DebugCore' OR '$(Configuration)'=='ReleaseCore'">
|
||||
<PropertyGroup>
|
||||
<TargetFramework>net5.0</TargetFramework>
|
||||
</PropertyGroup>
|
||||
</When>
|
||||
<Otherwise>
|
||||
<PropertyGroup>
|
||||
<TargetFrameworks>net5.0;net461</TargetFrameworks>
|
||||
</PropertyGroup>
|
||||
</Otherwise>
|
||||
</Choose>
|
||||
|
||||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU' OR '$(Configuration)|$(Platform)'=='DebugFull|AnyCPU' OR '$(Configuration)|$(Platform)'=='DebugCore|AnyCPU'">
|
||||
<DebugSymbols>true</DebugSymbols>
|
||||
<DebugType>full</DebugType>
|
||||
<Optimize>false</Optimize>
|
||||
<DefineConstants>$(DefineConstants);DEBUG</DefineConstants>
|
||||
</PropertyGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\..\..\Runtime\Runtime.csproj" />
|
||||
<ProjectReference Include="..\..\Classifier\Classifier.csproj" />
|
||||
<ProjectReference Include="..\..\Core\Core.csproj" />
|
||||
<ProjectReference Include="..\..\Recommender\Recommender.csproj" />
|
||||
<ProjectReference Include="..\Common\Common.csproj" />
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<Compile Include="..\..\..\Shared\SharedAssemblyFileVersion.cs" />
|
||||
<Compile Include="..\..\..\Shared\SharedAssemblyInfo.cs" />
|
||||
</ItemGroup>
|
||||
</Project>
|
|
@ -1,74 +0,0 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
namespace Microsoft.ML.Probabilistic.Learners.Runners
|
||||
{
|
||||
using System;
|
||||
|
||||
/// <summary>
/// The command-line entry point which dispatches to the recommender and classifier modules.
/// </summary>
public static class Program
{
    /// <summary>
    /// The entry point for the program. Builds the tree of command-line modules,
    /// runs the module selected by the arguments and terminates the process with
    /// exit code 0 on success and 1 on failure.
    /// </summary>
    /// <param name="args">The array of command-line arguments.</param>
    public static void Main(string[] args)
    {
        // Matchbox recommender
        var recommenderModuleSelector = new CommandLineModuleSelector();
        recommenderModuleSelector.RegisterModule("SplitData", new RecommenderSplitDataModule());
        recommenderModuleSelector.RegisterModule("GenerateNegativeData", new RecommenderGenerateNegativeDataModule());
        recommenderModuleSelector.RegisterModule("Train", new RecommenderTrainModule());
        recommenderModuleSelector.RegisterModule("PredictRatings", new RecommenderPredictRatingsModule());
        recommenderModuleSelector.RegisterModule("RecommendItems", new RecommenderRecommendItemsModule());
        recommenderModuleSelector.RegisterModule("FindRelatedUsers", new RecommenderFindRelatedUsersModule());
        recommenderModuleSelector.RegisterModule("FindRelatedItems", new RecommenderFindRelatedItemsModule());
        recommenderModuleSelector.RegisterModule("EvaluateRatingPrediction", new RecommenderEvaluateRatingPredictionModule());
        recommenderModuleSelector.RegisterModule("EvaluateItemRecommendation", new RecommenderEvaluateItemRecommendationModule());
        recommenderModuleSelector.RegisterModule("EvaluateFindRelatedUsers", new RecommenderEvaluateFindRelatedUsersModule());
        recommenderModuleSelector.RegisterModule("EvaluateFindRelatedItems", new RecommenderEvaluateFindRelatedItemsModule());

        // Binary Bayes point machine classifier
        var binaryBayesPointMachineModuleSelector = new CommandLineModuleSelector();
        binaryBayesPointMachineModuleSelector.RegisterModule("Train", new BinaryBayesPointMachineClassifierTrainingModule());
        binaryBayesPointMachineModuleSelector.RegisterModule("TrainIncremental", new BinaryBayesPointMachineClassifierIncrementalTrainingModule());
        binaryBayesPointMachineModuleSelector.RegisterModule("Predict", new BinaryBayesPointMachineClassifierPredictionModule());
        binaryBayesPointMachineModuleSelector.RegisterModule("CrossValidate", new BinaryBayesPointMachineClassifierCrossValidationModule());
        binaryBayesPointMachineModuleSelector.RegisterModule("SampleWeights", new BinaryBayesPointMachineClassifierSampleWeightsModule());
        binaryBayesPointMachineModuleSelector.RegisterModule("DiagnoseTrain", new BinaryBayesPointMachineClassifierTrainingDiagnosticsModule());

        // Multi-class Bayes point machine classifier
        var multiclassBayesPointMachineModuleSelector = new CommandLineModuleSelector();
        multiclassBayesPointMachineModuleSelector.RegisterModule("Train", new MulticlassBayesPointMachineClassifierTrainingModule());
        multiclassBayesPointMachineModuleSelector.RegisterModule("TrainIncremental", new MulticlassBayesPointMachineClassifierIncrementalTrainingModule());
        multiclassBayesPointMachineModuleSelector.RegisterModule("Predict", new MulticlassBayesPointMachineClassifierPredictionModule());
        multiclassBayesPointMachineModuleSelector.RegisterModule("CrossValidate", new MulticlassBayesPointMachineClassifierCrossValidationModule());
        multiclassBayesPointMachineModuleSelector.RegisterModule("SampleWeights", new MulticlassBayesPointMachineClassifierSampleWeightsModule());
        multiclassBayesPointMachineModuleSelector.RegisterModule("DiagnoseTrain", new MulticlassBayesPointMachineClassifierTrainingDiagnosticsModule());

        // Classifier
        var classifierModuleSelector = new CommandLineModuleSelector();
        classifierModuleSelector.RegisterModule("Evaluate", new ClassifierEvaluationModule());
        classifierModuleSelector.RegisterModule("BinaryBayesPointMachine", binaryBayesPointMachineModuleSelector);
        classifierModuleSelector.RegisterModule("MulticlassBayesPointMachine", multiclassBayesPointMachineModuleSelector);

        // Modules
        var moduleSelector = new CommandLineModuleSelector();
        moduleSelector.RegisterModule("Recommender", recommenderModuleSelector);
        moduleSelector.RegisterModule("Classifier", classifierModuleSelector);

        try
        {
            bool success = moduleSelector.Run(args, Environment.GetCommandLineArgs()[0]);
            Environment.Exit(success ? 0 : 1);
        }
        catch (Exception e)
        {
            Console.WriteLine("An error has occured in one of the modules. {0}", e.Message);

            // A module that throws has failed: report it through the exit code as well,
            // otherwise the process would end with code 0 and callers would see success.
            Environment.Exit(1);
        }
    }
}
|
||||
}
|
|
@ -1,23 +0,0 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
using System.Reflection;
|
||||
using System.Runtime.CompilerServices;
|
||||
using System.Runtime.InteropServices;
|
||||
|
||||
// General Information about an assembly is controlled through the following
|
||||
// set of attributes. Change these attribute values to modify the information
|
||||
// associated with an assembly.
|
||||
[assembly: AssemblyTitle("CommandLine")]
|
||||
[assembly: AssemblyDescription("Runs the Infer.NET learner modules from the command line")]
|
||||
|
||||
// Setting ComVisible to false makes the types in this assembly not visible
|
||||
// to COM components. If you need to access a type in this assembly from
|
||||
// COM, set the ComVisible attribute to true on that type.
|
||||
[assembly: ComVisible(false)]
|
||||
|
||||
// The following GUID is for the ID of the typelib if this project is exposed to COM
|
||||
[assembly: Guid("9d6aa264-e13a-4314-a39f-d485e21c3103")]
|
||||
|
||||
[assembly: InternalsVisibleTo("ML.Probabilistic.Learners.Tests,PublicKey=0024000004800000940000000602000000240000525341310004000001000100551f07a755a3e3f2901fa321ab631d13d6192b4e6ac9c87279500f49d6635cde6902587752eff20402f46f6ea9c3d80e827580a799840aaab9a49b1d2597e4c1798ee93c5cb66851e9d22f4d6e8110571f4a2e59f1d760f7be04fb10e7dc43ee7ed2831907731427b9815c5fe7f4888f9933ee7a1ad5d1f293fd8ab834fac1be")]
|
|
@ -1,59 +0,0 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
namespace Microsoft.ML.Probabilistic.Learners.Runners
|
||||
{
|
||||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
|
||||
using Microsoft.ML.Probabilistic.Learners.Mappings;
|
||||
using Microsoft.ML.Probabilistic.Math;
|
||||
|
||||
/// <summary>
/// Command-line module which evaluates related item predictions against a test dataset
/// and writes the resulting NDCG metrics to a report file.
/// </summary>
internal class RecommenderEvaluateFindRelatedItemsModule : CommandLineModule
{
    /// <summary>
    /// Parses the module arguments and, if they are valid, writes the evaluation report.
    /// </summary>
    /// <param name="args">The command line arguments for the module.</param>
    /// <param name="usagePrefix">The prefix to print before the usage string.</param>
    /// <returns>False if the arguments could not be parsed; true once the report has been written.</returns>
    public override bool Run(string[] args, string usagePrefix)
    {
        string groundTruthPath = string.Empty;
        string relatedItemsPath = string.Empty;
        string reportPath = string.Empty;
        int minCommonRatings = 5;

        var argumentParser = new CommandLineParser();
        argumentParser.RegisterParameterHandler(
            "--test-data",
            "FILE",
            "Test dataset used to obtain ground truth",
            path => groundTruthPath = path,
            CommandLineParameterType.Required);
        argumentParser.RegisterParameterHandler(
            "--predictions",
            "FILE",
            "Predictions to evaluate",
            path => relatedItemsPath = path,
            CommandLineParameterType.Required);
        argumentParser.RegisterParameterHandler(
            "--report",
            "FILE",
            "Evaluation report file",
            path => reportPath = path,
            CommandLineParameterType.Required);
        argumentParser.RegisterParameterHandler(
            "--min-common-items",
            "NUM",
            "Minimum number of users that the query item and the related item should have been rated by in common; defaults to 5",
            count => minCommonRatings = count,
            CommandLineParameterType.Optional);

        bool argumentsAreValid = argumentParser.TryParse(args, usagePrefix);
        if (!argumentsAreValid)
        {
            return false;
        }

        RecommenderDataset groundTruth = RecommenderDataset.Load(groundTruthPath);
        IDictionary<Item, IEnumerable<Item>> predictedRelatedItems = RecommenderPersistenceUtils.LoadRelatedItems(relatedItemsPath);

        var evaluator = new StarRatingRecommenderEvaluator<RecommenderDataset, User, Item, int>(
            Mappings.StarRatingRecommender.ForEvaluation());

        // The report file is opened first; each metric is computed and written in turn.
        using (var report = new StreamWriter(reportPath))
        {
            var manhattanNdcg = evaluator.RelatedItemsMetric(
                groundTruth,
                predictedRelatedItems,
                minCommonRatings,
                Metrics.Ndcg,
                Metrics.NormalizedManhattanSimilarity);
            report.WriteLine("L1 Sim NDCG: {0:0.000}", manhattanNdcg);

            var euclideanNdcg = evaluator.RelatedItemsMetric(
                groundTruth,
                predictedRelatedItems,
                minCommonRatings,
                Metrics.Ndcg,
                Metrics.NormalizedEuclideanSimilarity);
            report.WriteLine("L2 Sim NDCG: {0:0.000}", euclideanNdcg);
        }

        return true;
    }
}
|
||||
}
|
|
@ -1,59 +0,0 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
namespace Microsoft.ML.Probabilistic.Learners.Runners
|
||||
{
|
||||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
|
||||
using Microsoft.ML.Probabilistic.Learners.Mappings;
|
||||
using Microsoft.ML.Probabilistic.Math;
|
||||
|
||||
/// <summary>
/// Command-line module which evaluates related user predictions against a test dataset
/// and writes the resulting NDCG metrics to a report file.
/// </summary>
internal class RecommenderEvaluateFindRelatedUsersModule : CommandLineModule
{
    /// <summary>
    /// Parses the module arguments and, if they are valid, writes the evaluation report.
    /// </summary>
    /// <param name="args">The command line arguments for the module.</param>
    /// <param name="usagePrefix">The prefix to print before the usage string.</param>
    /// <returns>False if the arguments could not be parsed; true once the report has been written.</returns>
    public override bool Run(string[] args, string usagePrefix)
    {
        string groundTruthPath = string.Empty;
        string relatedUsersPath = string.Empty;
        string reportPath = string.Empty;
        int minCommonRatings = 5;

        var argumentParser = new CommandLineParser();
        argumentParser.RegisterParameterHandler(
            "--test-data",
            "FILE",
            "Test dataset used to obtain ground truth",
            path => groundTruthPath = path,
            CommandLineParameterType.Required);
        argumentParser.RegisterParameterHandler(
            "--predictions",
            "FILE",
            "Predictions to evaluate",
            path => relatedUsersPath = path,
            CommandLineParameterType.Required);
        argumentParser.RegisterParameterHandler(
            "--report",
            "FILE",
            "Evaluation report file",
            path => reportPath = path,
            CommandLineParameterType.Required);
        argumentParser.RegisterParameterHandler(
            "--min-common-items",
            "NUM",
            "Minimum number of items that the query user and the related user should have rated in common; defaults to 5",
            count => minCommonRatings = count,
            CommandLineParameterType.Optional);

        bool argumentsAreValid = argumentParser.TryParse(args, usagePrefix);
        if (!argumentsAreValid)
        {
            return false;
        }

        RecommenderDataset groundTruth = RecommenderDataset.Load(groundTruthPath);
        IDictionary<User, IEnumerable<User>> predictedRelatedUsers = RecommenderPersistenceUtils.LoadRelatedUsers(relatedUsersPath);

        var evaluator = new StarRatingRecommenderEvaluator<RecommenderDataset, User, Item, int>(
            Mappings.StarRatingRecommender.ForEvaluation());

        // The report file is opened first; each metric is computed and written in turn.
        using (var report = new StreamWriter(reportPath))
        {
            var manhattanNdcg = evaluator.RelatedUsersMetric(
                groundTruth,
                predictedRelatedUsers,
                minCommonRatings,
                Metrics.Ndcg,
                Metrics.NormalizedManhattanSimilarity);
            report.WriteLine("L1 Sim NDCG: {0:0.000}", manhattanNdcg);

            var euclideanNdcg = evaluator.RelatedUsersMetric(
                groundTruth,
                predictedRelatedUsers,
                minCommonRatings,
                Metrics.Ndcg,
                Metrics.NormalizedEuclideanSimilarity);
            report.WriteLine("L2 Sim NDCG: {0:0.000}", euclideanNdcg);
        }

        return true;
    }
}
|
||||
}
|
|
@ -1,61 +0,0 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
namespace Microsoft.ML.Probabilistic.Learners.Runners
|
||||
{
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
|
||||
using Microsoft.ML.Probabilistic.Learners.Mappings;
|
||||
using Microsoft.ML.Probabilistic.Math;
|
||||
|
||||
/// <summary>
/// Command-line module which evaluates item recommendations against a test dataset
/// and writes the resulting NDCG metric to a report file.
/// </summary>
internal class RecommenderEvaluateItemRecommendationModule : CommandLineModule
{
    /// <summary>
    /// Parses the module arguments and, if they are valid, writes the evaluation report.
    /// </summary>
    /// <param name="args">The command line arguments for the module.</param>
    /// <param name="usagePrefix">The prefix to print before the usage string.</param>
    /// <returns>False if the arguments could not be parsed; true once the report has been written.</returns>
    public override bool Run(string[] args, string usagePrefix)
    {
        string groundTruthPath = string.Empty;
        string recommendationsPath = string.Empty;
        string reportPath = string.Empty;

        var argumentParser = new CommandLineParser();
        argumentParser.RegisterParameterHandler(
            "--test-data",
            "FILE",
            "Test dataset used to obtain ground truth",
            path => groundTruthPath = path,
            CommandLineParameterType.Required);
        argumentParser.RegisterParameterHandler(
            "--predictions",
            "FILE",
            "Predictions to evaluate",
            path => recommendationsPath = path,
            CommandLineParameterType.Required);
        argumentParser.RegisterParameterHandler(
            "--report",
            "FILE",
            "Evaluation report file",
            path => reportPath = path,
            CommandLineParameterType.Required);

        bool argumentsAreValid = argumentParser.TryParse(args, usagePrefix);
        if (!argumentsAreValid)
        {
            return false;
        }

        RecommenderDataset groundTruth = RecommenderDataset.Load(groundTruthPath);
        int lowestStarRating = Mappings.StarRatingRecommender.GetRatingInfo(groundTruth).MinStarRating;

        IDictionary<User, IEnumerable<Item>> predictedItems = RecommenderPersistenceUtils.LoadRecommendedItems(recommendationsPath);

        var evaluator = new StarRatingRecommenderEvaluator<RecommenderDataset, User, Item, int>(
            Mappings.StarRatingRecommender.ForEvaluation());

        using (var report = new StreamWriter(reportPath))
        {
            // Ratings are shifted so that the lowest star rating of the dataset maps to 1.
            var ndcg = evaluator.ItemRecommendationMetric(
                groundTruth,
                predictedItems,
                Metrics.Ndcg,
                starRating => Convert.ToDouble(starRating) - lowestStarRating + 1);
            report.WriteLine("NDCG: {0:0.000}", ndcg);
        }

        return true;
    }
}
|
||||
}
|
|
@ -1,58 +0,0 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
namespace Microsoft.ML.Probabilistic.Learners.Runners
|
||||
{
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
|
||||
using Microsoft.ML.Probabilistic.Learners.Mappings;
|
||||
using Microsoft.ML.Probabilistic.Math;
|
||||
|
||||
/// <summary>
/// Command-line module which evaluates rating predictions against a test dataset
/// and writes the mean absolute error and root mean squared error to a report file.
/// </summary>
internal class RecommenderEvaluateRatingPredictionModule : CommandLineModule
{
    /// <summary>
    /// Parses the module arguments and, if they are valid, writes the evaluation report.
    /// </summary>
    /// <param name="args">The command line arguments for the module.</param>
    /// <param name="usagePrefix">The prefix to print before the usage string.</param>
    /// <returns>False if the arguments could not be parsed; true once the report has been written.</returns>
    public override bool Run(string[] args, string usagePrefix)
    {
        string groundTruthPath = string.Empty;
        string predictedRatingsPath = string.Empty;
        string reportPath = string.Empty;

        var argumentParser = new CommandLineParser();
        argumentParser.RegisterParameterHandler(
            "--test-data",
            "FILE",
            "Test dataset used to obtain ground truth",
            path => groundTruthPath = path,
            CommandLineParameterType.Required);
        argumentParser.RegisterParameterHandler(
            "--predictions",
            "FILE",
            "Predictions to evaluate",
            path => predictedRatingsPath = path,
            CommandLineParameterType.Required);
        argumentParser.RegisterParameterHandler(
            "--report",
            "FILE",
            "Evaluation report file",
            path => reportPath = path,
            CommandLineParameterType.Required);

        bool argumentsAreValid = argumentParser.TryParse(args, usagePrefix);
        if (!argumentsAreValid)
        {
            return false;
        }

        RecommenderDataset groundTruth = RecommenderDataset.Load(groundTruthPath);
        IDictionary<User, IDictionary<Item, int>> predictedRatings = RecommenderPersistenceUtils.LoadPredictedRatings(predictedRatingsPath);

        var evaluator = new StarRatingRecommenderEvaluator<RecommenderDataset, User, Item, int>(
            Mappings.StarRatingRecommender.ForEvaluation());

        // The report file is opened first; each metric is computed and written in turn.
        using (var report = new StreamWriter(reportPath))
        {
            var absoluteError = evaluator.RatingPredictionMetric(groundTruth, predictedRatings, Metrics.AbsoluteError);
            report.WriteLine("Mean absolute error: {0:0.000}", absoluteError);

            // The squared-error metric is reported as its square root.
            var squaredError = evaluator.RatingPredictionMetric(groundTruth, predictedRatings, Metrics.SquaredError);
            report.WriteLine("Root mean squared error: {0:0.000}", Math.Sqrt(squaredError));
        }

        return true;
    }
}
|
||||
}
|
|
@ -1,57 +0,0 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
namespace Microsoft.ML.Probabilistic.Learners.Runners
|
||||
{
|
||||
using System.Collections.Generic;
|
||||
|
||||
using Microsoft.ML.Probabilistic.Learners.Mappings;
|
||||
|
||||
using RatingDistribution = System.Collections.Generic.IDictionary<int, double>;
|
||||
|
||||
/// <summary>
/// Command-line module that loads a dataset and a trained recommender model, asks the evaluator
/// for the items related to each item, and saves the result to a file.
/// </summary>
internal class RecommenderFindRelatedItemsModule : CommandLineModule
{
    /// <summary>
    /// Parses the module arguments, finds related items and writes them to the predictions file.
    /// </summary>
    /// <param name="args">The command line arguments for the module.</param>
    /// <param name="usagePrefix">The prefix to print before the usage string.</param>
    /// <returns>False if the arguments could not be parsed; true once the related items have been saved.</returns>
    public override bool Run(string[] args, string usagePrefix)
    {
        // Optional numeric settings: each keeps its default of 5 unless overridden on the command line.
        int relatedItemLimit = 5;
        int commonRatingThreshold = 5;
        int poolSizeThreshold = 5;

        // Required file locations, filled in by the parser callbacks below.
        string dataPath = string.Empty;
        string modelPath = string.Empty;
        string outputPath = string.Empty;

        // Registration order is kept as-is, since the parser may use it when printing usage.
        var argumentParser = new CommandLineParser();
        argumentParser.RegisterParameterHandler(
            "--data", "FILE", "Dataset to make predictions for", path => dataPath = path, CommandLineParameterType.Required);
        argumentParser.RegisterParameterHandler(
            "--model", "FILE", "File with trained model", path => modelPath = path, CommandLineParameterType.Required);
        argumentParser.RegisterParameterHandler(
            "--predictions", "FILE", "File with generated predictions", path => outputPath = path, CommandLineParameterType.Required);
        argumentParser.RegisterParameterHandler(
            "--max-items", "NUM", "Maximum number of related items for a single item; defaults to 5", count => relatedItemLimit = count, CommandLineParameterType.Optional);
        argumentParser.RegisterParameterHandler(
            "--min-common-users", "NUM", "Minimum number of users that the query item and the related item should have been rated by in common; defaults to 5", count => commonRatingThreshold = count, CommandLineParameterType.Optional);
        argumentParser.RegisterParameterHandler(
            "--min-pool-size", "NUM", "Minimum size of the related item pool for a single item; defaults to 5", count => poolSizeThreshold = count, CommandLineParameterType.Optional);

        bool argumentsAccepted = argumentParser.TryParse(args, usagePrefix);
        if (!argumentsAccepted)
        {
            return false;
        }

        // The dataset is loaded before the model, exactly as in the original flow.
        RecommenderDataset dataset = RecommenderDataset.Load(dataPath);
        var model = MatchboxRecommender.Load<RecommenderDataset, User, Item, RatingDistribution, DummyFeatureSource>(modelPath);

        var relatedItemFinder = new RecommenderEvaluator<RecommenderDataset, User, Item, int, int, RatingDistribution>(
            Mappings.StarRatingRecommender.ForEvaluation());
        IDictionary<Item, IEnumerable<Item>> itemsByQueryItem = relatedItemFinder.FindRelatedItemsRatedBySameUsers(
            model, dataset, relatedItemLimit, commonRatingThreshold, poolSizeThreshold);

        RecommenderPersistenceUtils.SaveRelatedItems(outputPath, itemsByQueryItem);
        return true;
    }
}
|
||||
}
|
|
@ -1,57 +0,0 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
namespace Microsoft.ML.Probabilistic.Learners.Runners
|
||||
{
|
||||
using System.Collections.Generic;
|
||||
|
||||
using Microsoft.ML.Probabilistic.Learners.Mappings;
|
||||
|
||||
using RatingDistribution = System.Collections.Generic.IDictionary<int, double>;
|
||||
|
||||
/// <summary>
/// Command-line module that loads a dataset and a trained recommender model, asks the evaluator
/// for the users related to each user, and saves the result to a file.
/// </summary>
internal class RecommenderFindRelatedUsersModule : CommandLineModule
{
    /// <summary>
    /// Parses the module arguments, finds related users and writes them to the predictions file.
    /// </summary>
    /// <param name="args">The command line arguments for the module.</param>
    /// <param name="usagePrefix">The prefix to print before the usage string.</param>
    /// <returns>False if the arguments could not be parsed; true once the related users have been saved.</returns>
    public override bool Run(string[] args, string usagePrefix)
    {
        // Optional numeric settings: each keeps its default of 5 unless overridden on the command line.
        int relatedUserLimit = 5;
        int commonRatingThreshold = 5;
        int poolSizeThreshold = 5;

        // Required file locations, filled in by the parser callbacks below.
        string dataPath = string.Empty;
        string modelPath = string.Empty;
        string outputPath = string.Empty;

        // Registration order is kept as-is, since the parser may use it when printing usage.
        var argumentParser = new CommandLineParser();
        argumentParser.RegisterParameterHandler(
            "--data", "FILE", "Dataset to make predictions for", path => dataPath = path, CommandLineParameterType.Required);
        argumentParser.RegisterParameterHandler(
            "--model", "FILE", "File with trained model", path => modelPath = path, CommandLineParameterType.Required);
        argumentParser.RegisterParameterHandler(
            "--predictions", "FILE", "File with generated predictions", path => outputPath = path, CommandLineParameterType.Required);
        argumentParser.RegisterParameterHandler(
            "--max-users", "NUM", "Maximum number of related users for a single user; defaults to 5", count => relatedUserLimit = count, CommandLineParameterType.Optional);
        argumentParser.RegisterParameterHandler(
            "--min-common-items", "NUM", "Minimum number of items that the query user and the related user should have rated in common; defaults to 5", count => commonRatingThreshold = count, CommandLineParameterType.Optional);
        argumentParser.RegisterParameterHandler(
            "--min-pool-size", "NUM", "Minimum size of the related user pool for a single user; defaults to 5", count => poolSizeThreshold = count, CommandLineParameterType.Optional);

        bool argumentsAccepted = argumentParser.TryParse(args, usagePrefix);
        if (!argumentsAccepted)
        {
            return false;
        }

        // The dataset is loaded before the model, exactly as in the original flow.
        RecommenderDataset dataset = RecommenderDataset.Load(dataPath);
        var model = MatchboxRecommender.Load<RecommenderDataset, User, Item, RatingDistribution, DummyFeatureSource>(modelPath);

        var relatedUserFinder = new RecommenderEvaluator<RecommenderDataset, User, Item, int, int, RatingDistribution>(
            Mappings.StarRatingRecommender.ForEvaluation());
        IDictionary<User, IEnumerable<User>> usersByQueryUser = relatedUserFinder.FindRelatedUsersWhoRatedSameItems(
            model, dataset, relatedUserLimit, commonRatingThreshold, poolSizeThreshold);

        RecommenderPersistenceUtils.SaveRelatedUsers(outputPath, usersByQueryUser);
        return true;
    }
}
|
||||
}
|
|
@ -1,48 +0,0 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
namespace Microsoft.ML.Probabilistic.Learners.Runners
|
||||
{
|
||||
using System.Linq;
|
||||
|
||||
using Microsoft.ML.Probabilistic.Learners.Mappings;
|
||||
using Microsoft.ML.Probabilistic.Math;
|
||||
|
||||
/// <summary>
/// A command-line module to generate negative data for a positive-only recommender dataset.
/// </summary>
internal class RecommenderGenerateNegativeDataModule : CommandLineModule
{
    /// <summary>
    /// Parses the module arguments, runs the input dataset through the negative-data generating mapping
    /// and saves the resulting dataset.
    /// </summary>
    /// <param name="args">The command line arguments for the module.</param>
    /// <param name="usagePrefix">The prefix to print before the usage string.</param>
    /// <returns>False if the arguments could not be parsed; true once the output dataset has been saved.</returns>
    public override bool Run(string[] args, string usagePrefix)
    {
        string inputDatasetFile = string.Empty;
        string outputDatasetFile = string.Empty;

        var parser = new CommandLineParser();
        parser.RegisterParameterHandler("--input-data", "FILE", "Input dataset, treated as if all the ratings are positive", v => inputDatasetFile = v, CommandLineParameterType.Required);

        // Help text fix: "positive" was previously misspelled as "posisitve" in the usage output.
        parser.RegisterParameterHandler("--output-data", "FILE", "Output dataset with both positive and negative data", v => outputDatasetFile = v, CommandLineParameterType.Required);

        if (!parser.TryParse(args, usagePrefix))
        {
            return false;
        }

        // The mapping is created before the dataset is loaded, matching the original statement order.
        var generatorMapping = Mappings.StarRatingRecommender.WithGeneratedNegativeData();

        var inputDataset = RecommenderDataset.Load(inputDatasetFile);

        // Each instance produced by the mapping is converted to a RatedUserItem for the output dataset;
        // the rating info is also taken from the mapping rather than from the input dataset.
        var outputDataset = new RecommenderDataset(
            generatorMapping.GetInstances(inputDataset).Select(i => new RatedUserItem(i.User, i.Item, i.Rating)),
            generatorMapping.GetRatingInfo(inputDataset));
        outputDataset.Save(outputDatasetFile);

        return true;
    }
}
|
||||
}
|
|
@ -1,46 +0,0 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
namespace Microsoft.ML.Probabilistic.Learners.Runners
|
||||
{
|
||||
using System.Collections.Generic;
|
||||
|
||||
using RatingDistribution = System.Collections.Generic.IDictionary<int, double>;
|
||||
|
||||
/// <summary>
/// Command-line module that loads a dataset and a trained recommender model, predicts ratings
/// for the dataset and saves the predictions to a file.
/// </summary>
internal class RecommenderPredictRatingsModule : CommandLineModule
{
    /// <summary>
    /// Parses the module arguments, runs rating prediction and writes the result to the predictions file.
    /// </summary>
    /// <param name="args">The command line arguments for the module.</param>
    /// <param name="usagePrefix">The prefix to print before the usage string.</param>
    /// <returns>False if the arguments could not be parsed; true once the predictions have been saved.</returns>
    public override bool Run(string[] args, string usagePrefix)
    {
        // Required file locations, filled in by the parser callbacks below.
        string dataPath = string.Empty;
        string modelPath = string.Empty;
        string outputPath = string.Empty;

        var argumentParser = new CommandLineParser();
        argumentParser.RegisterParameterHandler(
            "--data", "FILE", "Dataset to make predictions for", path => dataPath = path, CommandLineParameterType.Required);
        argumentParser.RegisterParameterHandler(
            "--model", "FILE", "File with trained model", path => modelPath = path, CommandLineParameterType.Required);
        argumentParser.RegisterParameterHandler(
            "--predictions", "FILE", "File with generated predictions", path => outputPath = path, CommandLineParameterType.Required);

        bool argumentsAccepted = argumentParser.TryParse(args, usagePrefix);
        if (!argumentsAccepted)
        {
            return false;
        }

        // The dataset is loaded before the model, exactly as in the original flow.
        RecommenderDataset dataset = RecommenderDataset.Load(dataPath);
        var model = MatchboxRecommender.Load<RecommenderDataset, User, Item, RatingDistribution, DummyFeatureSource>(modelPath);

        IDictionary<User, IDictionary<Item, int>> ratingsByUser = model.Predict(dataset);
        RecommenderPersistenceUtils.SavePredictedRatings(outputPath, ratingsByUser);

        return true;
    }
}
|
||||
}
|
|
@ -1,55 +0,0 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
namespace Microsoft.ML.Probabilistic.Learners.Runners
|
||||
{
|
||||
using System.Collections.Generic;
|
||||
|
||||
using Microsoft.ML.Probabilistic.Learners.Mappings;
|
||||
|
||||
using RatingDistribution = System.Collections.Generic.IDictionary<int, double>;
|
||||
|
||||
/// <summary>
/// A command-line module to recommend items to users given a trained recommender model and a dataset.
/// </summary>
internal class RecommenderRecommendItemsModule : CommandLineModule
{
    /// <summary>
    /// Parses the module arguments, produces item recommendations and writes them to the predictions file.
    /// </summary>
    /// <param name="args">The command line arguments for the module.</param>
    /// <param name="usagePrefix">The prefix to print before the usage string.</param>
    /// <returns>False if the arguments could not be parsed; true once the recommendations have been saved.</returns>
    public override bool Run(string[] args, string usagePrefix)
    {
        // Optional numeric settings: each keeps its default of 5 unless overridden on the command line.
        int recommendationLimit = 5;
        int poolSizeThreshold = 5;

        // Required file locations, filled in by the parser callbacks below.
        string dataPath = string.Empty;
        string modelPath = string.Empty;
        string outputPath = string.Empty;

        // Registration order is kept as-is, since the parser may use it when printing usage.
        var argumentParser = new CommandLineParser();
        argumentParser.RegisterParameterHandler(
            "--data", "FILE", "Dataset to make predictions for", path => dataPath = path, CommandLineParameterType.Required);
        argumentParser.RegisterParameterHandler(
            "--model", "FILE", "File with trained model", path => modelPath = path, CommandLineParameterType.Required);
        argumentParser.RegisterParameterHandler(
            "--predictions", "FILE", "File with generated predictions", path => outputPath = path, CommandLineParameterType.Required);
        argumentParser.RegisterParameterHandler(
            "--max-items", "NUM", "Maximum number of items to recommend; defaults to 5", count => recommendationLimit = count, CommandLineParameterType.Optional);
        argumentParser.RegisterParameterHandler(
            "--min-pool-size", "NUM", "Minimum size of the recommendation pool for a single user; defaults to 5", count => poolSizeThreshold = count, CommandLineParameterType.Optional);

        bool argumentsAccepted = argumentParser.TryParse(args, usagePrefix);
        if (!argumentsAccepted)
        {
            return false;
        }

        // The dataset is loaded before the model, exactly as in the original flow.
        RecommenderDataset dataset = RecommenderDataset.Load(dataPath);
        var model = MatchboxRecommender.Load<RecommenderDataset, User, Item, RatingDistribution, DummyFeatureSource>(modelPath);

        var recommendationSource = new RecommenderEvaluator<RecommenderDataset, User, Item, int, int, RatingDistribution>(
            Mappings.StarRatingRecommender.ForEvaluation());
        IDictionary<User, IEnumerable<Item>> itemsByUser = recommendationSource.RecommendRatedItems(
            model, dataset, recommendationLimit, poolSizeThreshold);

        RecommenderPersistenceUtils.SaveRecommendedItems(outputPath, itemsByUser);
        return true;
    }
}
|
||||
}
|
|
@ -1,73 +0,0 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
namespace Microsoft.ML.Probabilistic.Learners.Runners
|
||||
{
|
||||
using Microsoft.ML.Probabilistic.Learners.Mappings;
|
||||
using Microsoft.ML.Probabilistic.Math;
|
||||
|
||||
/// <summary>
/// Command-line module that splits a recommendation dataset into a training part and a test part
/// and saves each part to its own file.
/// </summary>
internal class RecommenderSplitDataModule : CommandLineModule
{
    /// <summary>
    /// Parses the module arguments, builds the train/test splitting mapping and saves both parts of the split.
    /// </summary>
    /// <param name="args">The command line arguments for the module.</param>
    /// <param name="usagePrefix">The prefix to print before the usage string.</param>
    /// <returns>False if the arguments could not be parsed; true once both datasets have been saved.</returns>
    public override bool Run(string[] args, string usagePrefix)
    {
        // Required file locations, filled in by the parser callbacks below.
        string sourcePath = string.Empty;
        string trainingPath = string.Empty;
        string testPath = string.Empty;

        // Split fractions with the defaults advertised in the help strings.
        double trainingOnlyUsers = 0.5;
        double testUserTrainingRatings = 0.25;
        double coldUsers = 0;
        double coldItems = 0;
        double ignoredUsers = 0;
        double ignoredItems = 0;
        bool dropOccasionalColdItems = false;

        // Registration order is kept as-is, since the parser may use it when printing usage.
        var argumentParser = new CommandLineParser();
        argumentParser.RegisterParameterHandler(
            "--input-data", "FILE", "Dataset to split", path => sourcePath = path, CommandLineParameterType.Required);
        argumentParser.RegisterParameterHandler(
            "--output-data-train", "FILE", "Training part of the split dataset", path => trainingPath = path, CommandLineParameterType.Required);
        argumentParser.RegisterParameterHandler(
            "--output-data-test", "FILE", "Test part of the split dataset", path => testPath = path, CommandLineParameterType.Required);

        // The lambda parameter type is spelled out for the fractions, as in the original registrations.
        argumentParser.RegisterParameterHandler(
            "--training-users", "NUM", "Fraction of training-only users; defaults to 0.5", (double fraction) => trainingOnlyUsers = fraction, CommandLineParameterType.Optional);
        argumentParser.RegisterParameterHandler(
            "--test-user-training-ratings", "NUM", "Fraction of test user ratings for training; defaults to 0.25", (double fraction) => testUserTrainingRatings = fraction, CommandLineParameterType.Optional);
        argumentParser.RegisterParameterHandler(
            "--cold-users", "NUM", "Fraction of cold (test-only) users; defaults to 0", (double fraction) => coldUsers = fraction, CommandLineParameterType.Optional);
        argumentParser.RegisterParameterHandler(
            "--cold-items", "NUM", "Fraction of cold (test-only) items; defaults to 0", (double fraction) => coldItems = fraction, CommandLineParameterType.Optional);
        argumentParser.RegisterParameterHandler(
            "--ignored-users", "NUM", "Fraction of ignored users; defaults to 0", (double fraction) => ignoredUsers = fraction, CommandLineParameterType.Optional);
        argumentParser.RegisterParameterHandler(
            "--ignored-items", "NUM", "Fraction of ignored items; defaults to 0", (double fraction) => ignoredItems = fraction, CommandLineParameterType.Optional);
        argumentParser.RegisterParameterHandler(
            "--remove-occasional-cold-items", "Remove occasionally produced cold items", () => dropOccasionalColdItems = true);

        bool argumentsAccepted = argumentParser.TryParse(args, usagePrefix);
        if (!argumentsAccepted)
        {
            return false;
        }

        var splitter = Mappings.StarRatingRecommender.SplitToTrainTest(
            trainingOnlyUsers,
            testUserTrainingRatings,
            coldUsers,
            coldItems,
            ignoredUsers,
            ignoredItems,
            dropOccasionalColdItems);

        var source = RecommenderDataset.Load(sourcePath);

        // The training part is built and saved before the test part is requested from the mapping.
        var trainingPart = new RecommenderDataset(
            splitter.GetInstances(SplitInstanceSource.Training(source)),
            source.StarRatingInfo);
        trainingPart.Save(trainingPath);

        var testPart = new RecommenderDataset(
            splitter.GetInstances(SplitInstanceSource.Test(source)),
            source.StarRatingInfo);
        testPart.Save(testPath);

        return true;
    }
}
|
||||
}
|
|
@ -1,56 +0,0 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
namespace Microsoft.ML.Probabilistic.Learners.Runners
|
||||
{
|
||||
/// <summary>
/// Command-line module that trains a recommendation model on a dataset and saves the trained model to a file.
/// </summary>
internal class RecommenderTrainModule : CommandLineModule
{
    /// <summary>
    /// Parses the module arguments, configures and trains a Matchbox recommender, then saves it.
    /// </summary>
    /// <param name="args">The command line arguments for the module.</param>
    /// <param name="usagePrefix">The prefix to print before the usage string.</param>
    /// <returns>False if the arguments could not be parsed; true once the trained model has been saved.</returns>
    public override bool Run(string[] args, string usagePrefix)
    {
        // Required file locations, filled in by the parser callbacks below.
        string trainingDataPath = string.Empty;
        string modelPath = string.Empty;

        // Training settings start from the defaults published by MatchboxRecommenderTrainingSettings.
        int traits = MatchboxRecommenderTrainingSettings.TraitCountDefault;
        int iterations = MatchboxRecommenderTrainingSettings.IterationCountDefault;
        int batches = MatchboxRecommenderTrainingSettings.BatchCountDefault;
        bool withUserFeatures = MatchboxRecommenderTrainingSettings.UseUserFeaturesDefault;
        bool withItemFeatures = MatchboxRecommenderTrainingSettings.UseItemFeaturesDefault;

        // Help strings are composed up front, while the variables still hold their default values.
        string traitsHelp = "Number of traits (defaults to " + traits + ")";
        string iterationsHelp = "Number of inference iterations (defaults to " + iterations + ")";
        string batchesHelp = "Number of batches to split the training data into (defaults to " + batches + ")";
        string userFeaturesHelp = "Use user features in the model (defaults to " + withUserFeatures + ")";
        string itemFeaturesHelp = "Use item features in the model (defaults to " + withItemFeatures + ")";

        // Registration order is kept as-is, since the parser may use it when printing usage.
        var argumentParser = new CommandLineParser();
        argumentParser.RegisterParameterHandler(
            "--training-data", "FILE", "Training dataset", path => trainingDataPath = path, CommandLineParameterType.Required);
        argumentParser.RegisterParameterHandler(
            "--trained-model", "FILE", "Trained model file", path => modelPath = path, CommandLineParameterType.Required);
        argumentParser.RegisterParameterHandler(
            "--traits", "NUM", traitsHelp, count => traits = count, CommandLineParameterType.Optional);
        argumentParser.RegisterParameterHandler(
            "--iterations", "NUM", iterationsHelp, count => iterations = count, CommandLineParameterType.Optional);
        argumentParser.RegisterParameterHandler(
            "--batches", "NUM", batchesHelp, count => batches = count, CommandLineParameterType.Optional);
        argumentParser.RegisterParameterHandler("--use-user-features", userFeaturesHelp, () => withUserFeatures = true);
        argumentParser.RegisterParameterHandler("--use-item-features", itemFeaturesHelp, () => withItemFeatures = true);

        bool argumentsAccepted = argumentParser.TryParse(args, usagePrefix);
        if (!argumentsAccepted)
        {
            return false;
        }

        RecommenderDataset trainingData = RecommenderDataset.Load(trainingDataPath);
        var model = MatchboxRecommender.Create(Mappings.StarRatingRecommender);

        // Apply the parsed settings before training.
        model.Settings.Training.TraitCount = traits;
        model.Settings.Training.IterationCount = iterations;
        model.Settings.Training.BatchCount = batches;
        model.Settings.Training.UseUserFeatures = withUserFeatures;
        model.Settings.Training.UseItemFeatures = withItemFeatures;

        model.Train(trainingData);
        model.Save(modelPath);

        return true;
    }
}
|
||||
}
|
|
@ -147,7 +147,7 @@ namespace Microsoft.ML.Probabilistic.Collections
|
|||
FileStats.AddRead();
|
||||
string path = prefix + index.ToString(CultureInfo.InvariantCulture) + ".bin";
|
||||
if (!File.Exists(path)) return default(T);
|
||||
IFormatter formatter = new BinaryFormatter();
|
||||
IFormatter formatter = new BinaryFormatter(); // This is only used within the runtime for caching temporary files -- it is not used to persist data and no untrusted data is read.
|
||||
using (var stream = new FileStream(prefix + index.ToString(CultureInfo.InvariantCulture) + ".bin", FileMode.Open, FileAccess.Read, FileShare.Read))
|
||||
{
|
||||
return (T) formatter.Deserialize(stream);
|
||||
|
@ -163,7 +163,7 @@ namespace Microsoft.ML.Probabilistic.Collections
|
|||
if (object.ReferenceEquals(value, null) || value.Equals(default(T))) File.Delete(path);
|
||||
else
|
||||
{
|
||||
IFormatter formatter = new BinaryFormatter();
|
||||
IFormatter formatter = new BinaryFormatter(); // This is only used within the runtime for caching temporary files -- it is not used to persist data and no untrusted data is read.
|
||||
using (var stream = new FileStream(path, FileMode.Create, FileAccess.Write, FileShare.None))
|
||||
{
|
||||
formatter.Serialize(stream, value);
|
||||
|
|
|
@ -90,21 +90,6 @@ namespace Microsoft.ML.Probabilistic.Serialization
|
|||
return deserializedVersion;
|
||||
}
|
||||
|
||||
/// <summary>
/// Reads a byte array from the binary reader and converts it into an object of the requested type.
/// </summary>
/// <typeparam name="T">The type the deserialized object is cast to.</typeparam>
/// <param name="reader">The reader of the binary stream.</param>
/// <returns>The object constructed from the binary stream.</returns>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="reader"/> is null.</exception>
public static T ReadObject<T>(this BinaryReader reader)
{
    if (reader == null)
    {
        throw new ArgumentNullException(nameof(reader));
    }

    // Fetch the raw bytes first, then hand them to the conversion helper.
    var payload = reader.ReadByteArray();
    return ByteArrayToObject<T>(payload);
}
|
||||
|
||||
/// <summary>
|
||||
/// Reads a <see cref="Guid"/> from the reader of a binary stream.
|
||||
/// </summary>
|
||||
|
@ -260,27 +245,5 @@ namespace Microsoft.ML.Probabilistic.Serialization
|
|||
|
||||
return new GammaArray(length, i => ReadGamma(reader));
|
||||
}
|
||||
|
||||
#region Helper methods
|
||||
|
||||
/// <summary>
/// Converts a byte array to an object by deserializing it with <see cref="BinaryFormatter"/>.
/// </summary>
/// <typeparam name="T">The type the deserialized object is cast to.</typeparam>
/// <param name="array">The byte array to convert. Must not be null.</param>
/// <returns>The object for the specified byte array.</returns>
private static T ByteArrayToObject<T>(byte[] array)
{
    Debug.Assert(array != null, "The array must not be null.");

    // NOTE(review): BinaryFormatter deserialization is unsafe on untrusted input and is
    // deprecated by the platform; callers must only pass bytes from a trusted source.
    // Wrap the array directly in a read-only stream instead of copying it into a growable
    // stream and rewinding; the stream starts at position 0.
    using (var memoryStream = new MemoryStream(array, writable: false))
    {
        var binaryFormatter = new BinaryFormatter();
        return (T)binaryFormatter.Deserialize(memoryStream);
    }
}
|
||||
|
||||
#endregion
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,28 +19,6 @@ namespace Microsoft.ML.Probabilistic.Serialization
|
|||
/// </summary>
|
||||
public static class BinaryWriterExtensions
|
||||
{
|
||||
/// <summary>
/// Serializes the given object to a byte array and writes the array length followed by the bytes to the stream.
/// </summary>
/// <param name="writer">The binary writer.</param>
/// <param name="obj">The object to write to the stream.</param>
/// <exception cref="ArgumentNullException">
/// Thrown when <paramref name="writer"/> or <paramref name="obj"/> is null.
/// </exception>
public static void WriteObject(this BinaryWriter writer, object obj)
{
    // The writer is validated before the object, so a null writer is reported first.
    if (writer == null)
    {
        throw new ArgumentNullException(nameof(writer));
    }

    if (obj == null)
    {
        throw new ArgumentNullException(nameof(obj));
    }

    byte[] serialized = ObjectToByteArray(obj);

    // Length prefix first, payload second.
    int byteCount = serialized.Length;
    writer.Write(byteCount);
    writer.Write(serialized);
}
|
||||
|
||||
/// <summary>
|
||||
/// Writes the specified <see cref="Guid"/> in binary to the stream.
|
||||
/// </summary>
|
||||
|
@ -198,26 +176,5 @@ namespace Microsoft.ML.Probabilistic.Serialization
|
|||
writer.Write(element);
|
||||
}
|
||||
}
|
||||
|
||||
#region Helper methods
|
||||
|
||||
/// <summary>
/// Serializes an object with <see cref="BinaryFormatter"/> and returns the resulting bytes.
/// </summary>
/// <param name="obj">The object to convert. Must not be null.</param>
/// <returns>The byte array of the specified object.</returns>
private static byte[] ObjectToByteArray(object obj)
{
    Debug.Assert(obj != null, "The object must not be null.");

    using (var buffer = new MemoryStream())
    {
        // Serialize into the in-memory buffer, then copy out exactly what was written.
        var serializer = new BinaryFormatter();
        serializer.Serialize(buffer, obj);

        byte[] bytes = buffer.ToArray();
        return bytes;
    }
}
|
||||
|
||||
#endregion
|
||||
}
|
||||
}
|
||||
|
|
|
@ -41,7 +41,7 @@ namespace Microsoft.ML.Probabilistic.Serialization
|
|||
|
||||
public T ReadObject<T>()
|
||||
{
|
||||
return binaryReader.ReadObject<T>();
|
||||
throw new NotImplementedException("This reader cannot read objects");
|
||||
}
|
||||
|
||||
public string ReadString()
|
||||
|
|
|
@ -48,7 +48,7 @@ namespace Microsoft.ML.Probabilistic.Serialization
|
|||
|
||||
public void WriteObject(object value)
|
||||
{
|
||||
binaryWriter.WriteObject(value);
|
||||
throw new NotImplementedException("This writer cannot write objects.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -179,7 +179,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
#region Serialization
|
||||
|
||||
/// <summary>
|
||||
/// Constructor used by Json and BinaryFormatter serializers. Informally needed to be
|
||||
/// Constructor used by Json serializer. Informally needed to be
|
||||
/// implemented for <see cref="ISerializable"/> interface.
|
||||
/// </summary>
|
||||
internal DataContainer(SerializationInfo info, StreamingContext context)
|
||||
|
|
|
@ -1800,7 +1800,7 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
|||
new ImmutableDiscreteChar(Storage.Read(readInt32, readDouble));
|
||||
|
||||
/// <summary>
|
||||
/// Constructor used during deserialization by Newtonsoft.Json and BinaryFormatter.
|
||||
/// Constructor used during deserialization by Newtonsoft.Json.
|
||||
/// </summary>
|
||||
private ImmutableDiscreteChar(SerializationInfo info, StreamingContext context) =>
|
||||
this.data_ = Storage.FromSerializationInfo(info);
|
||||
|
|
|
@ -520,22 +520,6 @@ namespace Microsoft.ML.Probabilistic.Learners.Tests
|
|||
CheckPredictedBernoulliDistributionNativeTestingDataset);
|
||||
}
|
||||
|
||||
/// <summary>
/// Regression test for serialization and deserialization of the binary Bayes point machine classifier,
/// using data in native format with features in a dense representation.
/// </summary>
[Fact]
public void DenseBinaryNativeSerializationRegressionTest()
{
    var classifier = BayesPointMachineClassifier.CreateBinaryClassifier(this.binaryNativeMapping);

    TestRegressionSerialization(
        classifier,
        this.denseNativeTrainingData,
        this.denseNativePredictionData,
        this.expectedPredictiveBernoulliDistributions,
        this.expectedIncrementalPredictiveBernoulliDistributions,
        CheckPredictedBernoulliDistributionNativeTestingDataset);
}
|
||||
|
||||
/// <summary>
|
||||
/// Tests correctness of training of the binary Bayes point machine classifier for data in native format and
|
||||
/// features in a dense representation.
|
||||
|
@ -820,22 +804,6 @@ namespace Microsoft.ML.Probabilistic.Learners.Tests
|
|||
CheckPredictedBernoulliDistributionNativeTestingDataset);
|
||||
}
|
||||
|
||||
/// <summary>
/// Regression test for serialization and deserialization of the binary Bayes point machine classifier,
/// using data in native format with features in a sparse representation.
/// </summary>
[Fact]
public void SparseBinaryNativeSerializationRegressionTest()
{
    var classifier = BayesPointMachineClassifier.CreateBinaryClassifier(this.binaryNativeMapping);

    TestRegressionSerialization(
        classifier,
        this.sparseNativeTrainingData,
        this.sparseNativePredictionData,
        this.expectedPredictiveBernoulliDistributions,
        this.expectedIncrementalPredictiveBernoulliDistributions,
        CheckPredictedBernoulliDistributionNativeTestingDataset);
}
|
||||
|
||||
/// <summary>
|
||||
/// Tests correctness of training of the binary Bayes point machine classifier for data in native format and
|
||||
/// features in a sparse representation.
|
||||
|
@ -1079,22 +1047,6 @@ namespace Microsoft.ML.Probabilistic.Learners.Tests
|
|||
CheckPredictedDiscreteDistributionNativeTestingDataset);
|
||||
}
|
||||
|
||||
/// <summary>
/// Regression test for serialization and deserialization of the multi-class Bayes point machine classifier,
/// using data in native format with features in a dense representation.
/// </summary>
[Fact]
public void DenseMulticlassNativeSerializationRegressionTest()
{
    var classifier = BayesPointMachineClassifier.CreateMulticlassClassifier(this.multiclassNativeMapping);

    TestRegressionSerialization(
        classifier,
        this.denseNativeTrainingData,
        this.denseNativePredictionData,
        this.expectedPredictiveDiscreteDistributions,
        this.expectedIncrementalPredictiveDiscreteDistributions,
        CheckPredictedDiscreteDistributionNativeTestingDataset);
}
|
||||
|
||||
/// <summary>
|
||||
/// Tests correctness of training of the multi-class Bayes point machine classifier for data in native format and
|
||||
/// features in a dense representation.
|
||||
|
@ -1361,22 +1313,6 @@ namespace Microsoft.ML.Probabilistic.Learners.Tests
|
|||
CheckPredictedDiscreteDistributionNativeTestingDataset);
|
||||
}
|
||||
|
||||
/// <summary>
/// Regression test for serialization and deserialization of the multi-class Bayes point machine classifier,
/// using data in native format with features in a sparse representation.
/// </summary>
[Fact]
public void SparseMulticlassNativeSerializationRegressionTest()
{
    var classifier = BayesPointMachineClassifier.CreateMulticlassClassifier(this.multiclassNativeMapping);

    TestRegressionSerialization(
        classifier,
        this.sparseNativeTrainingData,
        this.sparseNativePredictionData,
        this.expectedPredictiveDiscreteDistributions,
        this.expectedIncrementalPredictiveDiscreteDistributions,
        CheckPredictedDiscreteDistributionNativeTestingDataset);
}
|
||||
|
||||
/// <summary>
|
||||
/// Tests correctness of training of the multi-class Bayes point machine classifier for data in native format and
|
||||
/// features in a sparse representation.
|
||||
|
@ -1619,22 +1555,6 @@ namespace Microsoft.ML.Probabilistic.Learners.Tests
|
|||
CheckPredictedBernoulliDistributionSimpleTestingDataset);
|
||||
}
|
||||
|
||||
/// <summary>
/// Regression test for serialization and deserialization of the binary Bayes point machine classifier,
/// using data in standard format with features in a dense representation.
/// </summary>
[Fact]
public void DenseBinaryStandardSerializationRegressionTest()
{
    var classifier = BayesPointMachineClassifier.CreateBinaryClassifier(this.binaryStandardMapping);

    TestRegressionSerialization(
        classifier,
        this.denseStandardTrainingData,
        this.denseStandardPredictionData,
        this.expectedPredictiveBernoulliStandardDistributions,
        this.expectedIncrementalPredictiveBernoulliStandardDistributions,
        CheckPredictedBernoulliDistributionStandardTestingDataset);
}
|
||||
|
||||
/// <summary>
|
||||
/// Tests correctness of training of the binary Bayes point machine classifier for data in standard format and
|
||||
/// features in a dense representation.
|
||||
|
@ -1851,22 +1771,6 @@ namespace Microsoft.ML.Probabilistic.Learners.Tests
|
|||
CheckPredictedBernoulliDistributionStandardTestingDataset);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Tests serialization and deserialization of the binary Bayes point machine classifier for data in standard format and
|
||||
/// features in a sparse representation.
|
||||
/// </summary>
|
||||
[Fact]
|
||||
public void SparseBinaryStandardSerializationRegressionTest()
|
||||
{
|
||||
TestRegressionSerialization(
|
||||
BayesPointMachineClassifier.CreateBinaryClassifier(this.binaryStandardMapping),
|
||||
this.sparseStandardTrainingData,
|
||||
this.sparseStandardPredictionData,
|
||||
this.expectedPredictiveBernoulliStandardDistributions,
|
||||
this.expectedIncrementalPredictiveBernoulliStandardDistributions,
|
||||
CheckPredictedBernoulliDistributionStandardTestingDataset);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Tests correctness of training of the binary Bayes point machine classifier for data in standard format and
|
||||
/// features in a sparse representation.
|
||||
|
@ -2074,22 +1978,6 @@ namespace Microsoft.ML.Probabilistic.Learners.Tests
|
|||
CheckPredictedDiscreteDistributionStandardTestingDataset);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Tests serialization and deserialization of the multi-class Bayes point machine classifier for data in standard format and
|
||||
/// features in a dense representation.
|
||||
/// </summary>
|
||||
[Fact]
|
||||
public void DenseMulticlassStandardSerializationRegressionTest()
|
||||
{
|
||||
TestRegressionSerialization(
|
||||
BayesPointMachineClassifier.CreateMulticlassClassifier(this.multiclassStandardMapping),
|
||||
this.denseStandardTrainingData,
|
||||
this.denseStandardPredictionData,
|
||||
this.expectedPredictiveDiscreteStandardDistributions,
|
||||
this.expectedIncrementalPredictiveDiscreteStandardDistributions,
|
||||
CheckPredictedDiscreteDistributionStandardTestingDataset);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Tests correctness of training of the multi-class Bayes point machine classifier for data in standard format and
|
||||
/// features in a dense representation.
|
||||
|
@ -2316,22 +2204,6 @@ namespace Microsoft.ML.Probabilistic.Learners.Tests
|
|||
CheckPredictedDiscreteDistributionStandardTestingDataset);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Tests serialization and deserialization of the multi-class Bayes point machine classifier for data in standard format and
|
||||
/// features in a sparse representation.
|
||||
/// </summary>
|
||||
[Fact]
|
||||
public void SparseMulticlassStandardSerializationRegressionTest()
|
||||
{
|
||||
TestRegressionSerialization(
|
||||
BayesPointMachineClassifier.CreateMulticlassClassifier(this.multiclassStandardMapping),
|
||||
this.sparseStandardTrainingData,
|
||||
this.sparseStandardPredictionData,
|
||||
this.expectedPredictiveDiscreteStandardDistributions,
|
||||
this.expectedIncrementalPredictiveDiscreteStandardDistributions,
|
||||
CheckPredictedDiscreteDistributionStandardTestingDataset);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Tests correctness of training of the multi-class Bayes point machine classifier for data in standard format and
|
||||
/// features in a sparse representation.
|
||||
|
@ -2569,22 +2441,6 @@ namespace Microsoft.ML.Probabilistic.Learners.Tests
|
|||
CheckPredictedBernoulliDistributionNativeTestingDataset);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Tests serialization and deserialization of the binary Bayes point machine classifier with <see cref="Gaussian"/> prior distributions
|
||||
/// over weights for data in native format and features in a dense representation.
|
||||
/// </summary>
|
||||
[Fact]
|
||||
public void GaussianDenseBinaryNativeSerializationRegressionTest()
|
||||
{
|
||||
TestRegressionSerialization(
|
||||
BayesPointMachineClassifier.CreateGaussianPriorBinaryClassifier(this.binaryNativeMapping),
|
||||
this.denseNativeTrainingData,
|
||||
this.denseNativePredictionData,
|
||||
this.gaussianPriorExpectedPredictiveBernoulliDistributions,
|
||||
this.gaussianPriorExpectedIncrementalPredictiveBernoulliDistributions,
|
||||
CheckPredictedBernoulliDistributionNativeTestingDataset);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Tests correctness of training of the binary Bayes point machine classifier with <see cref="Gaussian"/> prior distributions
|
||||
/// over weights for data in native format and features in a dense representation.
|
||||
|
@ -2839,22 +2695,6 @@ namespace Microsoft.ML.Probabilistic.Learners.Tests
|
|||
CheckPredictedBernoulliDistributionNativeTestingDataset);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Tests serialization and deserialization of the binary Bayes point machine classifier with <see cref="Gaussian"/> prior distributions
|
||||
/// over weights for data in native format and features in a sparse representation.
|
||||
/// </summary>
|
||||
[Fact]
|
||||
public void GaussianSparseBinaryNativeSerializationRegressionTest()
|
||||
{
|
||||
TestRegressionSerialization(
|
||||
BayesPointMachineClassifier.CreateGaussianPriorBinaryClassifier(this.binaryNativeMapping),
|
||||
this.sparseNativeTrainingData,
|
||||
this.sparseNativePredictionData,
|
||||
this.gaussianPriorExpectedPredictiveBernoulliDistributions,
|
||||
this.gaussianPriorExpectedIncrementalPredictiveBernoulliDistributions,
|
||||
CheckPredictedBernoulliDistributionNativeTestingDataset);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Tests correctness of training of the binary Bayes point machine classifier with <see cref="Gaussian"/> prior distributions
|
||||
/// over weights for data in native format and features in a sparse representation.
|
||||
|
@ -3099,22 +2939,6 @@ namespace Microsoft.ML.Probabilistic.Learners.Tests
|
|||
CheckPredictedDiscreteDistributionNativeTestingDataset);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Tests serialization and deserialization of the multi-class Bayes point machine classifier with <see cref="Gaussian"/> prior distributions
|
||||
/// over weights for data in native format and features in a dense representation.
|
||||
/// </summary>
|
||||
[Fact]
|
||||
public void GaussianDenseMulticlassNativeSerializationRegressionTest()
|
||||
{
|
||||
TestRegressionSerialization(
|
||||
BayesPointMachineClassifier.CreateGaussianPriorMulticlassClassifier(this.multiclassNativeMapping),
|
||||
this.denseNativeTrainingData,
|
||||
this.denseNativePredictionData,
|
||||
this.gaussianPriorExpectedPredictiveDiscreteDistributions,
|
||||
this.gaussianPriorExpectedIncrementalPredictiveDiscreteDistributions,
|
||||
CheckPredictedDiscreteDistributionNativeTestingDataset);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Tests correctness of training of the multi-class Bayes point machine classifier with <see cref="Gaussian"/> prior distributions
|
||||
/// over weights for data in native format and features in a dense representation.
|
||||
|
@ -3382,22 +3206,6 @@ namespace Microsoft.ML.Probabilistic.Learners.Tests
|
|||
CheckPredictedDiscreteDistributionNativeTestingDataset);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Tests serialization and deserialization of the multi-class Bayes point machine classifier with <see cref="Gaussian"/> prior distributions
|
||||
/// over weights for data in native format and features in a sparse representation.
|
||||
/// </summary>
|
||||
[Fact]
|
||||
public void GaussianSparseMulticlassNativeSerializationRegressionTest()
|
||||
{
|
||||
TestRegressionSerialization(
|
||||
BayesPointMachineClassifier.CreateGaussianPriorMulticlassClassifier(this.multiclassNativeMapping),
|
||||
this.sparseNativeTrainingData,
|
||||
this.sparseNativePredictionData,
|
||||
this.gaussianPriorExpectedPredictiveDiscreteDistributions,
|
||||
this.gaussianPriorExpectedIncrementalPredictiveDiscreteDistributions,
|
||||
CheckPredictedDiscreteDistributionNativeTestingDataset);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Tests correctness of training of the multi-class Bayes point machine classifier with <see cref="Gaussian"/> prior distributions
|
||||
/// over weights for data in native format and features in a sparse representation.
|
||||
|
@ -3593,22 +3401,6 @@ namespace Microsoft.ML.Probabilistic.Learners.Tests
|
|||
CheckPredictedBernoulliDistributionStandardTestingDataset);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Tests serialization and deserialization of the binary Bayes point machine classifier with <see cref="Gaussian"/> prior distributions
|
||||
/// over weights for data in standard format and features in a dense representation.
|
||||
/// </summary>
|
||||
[Fact]
|
||||
public void GaussianDenseBinaryStandardSerializationRegressionTest()
|
||||
{
|
||||
TestRegressionSerialization(
|
||||
BayesPointMachineClassifier.CreateGaussianPriorBinaryClassifier(this.binaryStandardMapping),
|
||||
this.denseStandardTrainingData,
|
||||
this.denseStandardPredictionData,
|
||||
this.gaussianPriorExpectedPredictiveBernoulliStandardDistributions,
|
||||
this.gaussianPriorExpectedIncrementalPredictiveBernoulliStandardDistributions,
|
||||
CheckPredictedBernoulliDistributionStandardTestingDataset);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Tests correctness of training of the binary Bayes point machine classifier with <see cref="Gaussian"/> prior distributions
|
||||
/// over weights for data in standard format and features in a dense representation.
|
||||
|
@ -3826,22 +3618,6 @@ namespace Microsoft.ML.Probabilistic.Learners.Tests
|
|||
CheckPredictedBernoulliDistributionStandardTestingDataset);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Tests serialization and deserialization of the binary Bayes point machine classifier with <see cref="Gaussian"/> prior distributions
|
||||
/// over weights for data in standard format and features in a sparse representation.
|
||||
/// </summary>
|
||||
[Fact]
|
||||
public void GaussianSparseBinaryStandardSerializationRegressionTest()
|
||||
{
|
||||
TestRegressionSerialization(
|
||||
BayesPointMachineClassifier.CreateGaussianPriorBinaryClassifier(this.binaryStandardMapping),
|
||||
this.sparseStandardTrainingData,
|
||||
this.sparseStandardPredictionData,
|
||||
this.gaussianPriorExpectedPredictiveBernoulliStandardDistributions,
|
||||
this.gaussianPriorExpectedIncrementalPredictiveBernoulliStandardDistributions,
|
||||
CheckPredictedBernoulliDistributionStandardTestingDataset);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Tests correctness of training of the binary Bayes point machine classifier with <see cref="Gaussian"/> prior distributions
|
||||
/// over weights for data in standard format and features in a sparse representation.
|
||||
|
@ -4050,22 +3826,6 @@ namespace Microsoft.ML.Probabilistic.Learners.Tests
|
|||
CheckPredictedDiscreteDistributionStandardTestingDataset);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Tests serialization and deserialization of the multi-class Bayes point machine classifier with <see cref="Gaussian"/> prior distributions
|
||||
/// over weights for data in standard format and features in a dense representation.
|
||||
/// </summary>
|
||||
[Fact]
|
||||
public void GaussianDenseMulticlassStandardSerializationRegressionTest()
|
||||
{
|
||||
TestRegressionSerialization(
|
||||
BayesPointMachineClassifier.CreateGaussianPriorMulticlassClassifier(this.multiclassStandardMapping),
|
||||
this.denseStandardTrainingData,
|
||||
this.denseStandardPredictionData,
|
||||
this.gaussianPriorExpectedPredictiveDiscreteStandardDistributions,
|
||||
this.gaussianPriorExpectedIncrementalPredictiveDiscreteStandardDistributions,
|
||||
CheckPredictedDiscreteDistributionStandardTestingDataset);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Tests correctness of training of the multi-class Bayes point machine classifier with <see cref="Gaussian"/> prior distributions
|
||||
/// over weights for data in standard format and features in a dense representation.
|
||||
|
@ -4293,22 +4053,6 @@ namespace Microsoft.ML.Probabilistic.Learners.Tests
|
|||
CheckPredictedDiscreteDistributionStandardTestingDataset);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Tests serialization and deserialization of the multi-class Bayes point machine classifier with <see cref="Gaussian"/> prior distributions
|
||||
/// over weights for data in standard format and features in a sparse representation.
|
||||
/// </summary>
|
||||
[Fact]
|
||||
public void GaussianSparseMulticlassStandardSerializationRegressionTest()
|
||||
{
|
||||
TestRegressionSerialization(
|
||||
BayesPointMachineClassifier.CreateGaussianPriorMulticlassClassifier(this.multiclassStandardMapping),
|
||||
this.sparseStandardTrainingData,
|
||||
this.sparseStandardPredictionData,
|
||||
this.gaussianPriorExpectedPredictiveDiscreteStandardDistributions,
|
||||
this.gaussianPriorExpectedIncrementalPredictiveDiscreteStandardDistributions,
|
||||
CheckPredictedDiscreteDistributionStandardTestingDataset);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Tests correctness of training of the multi-class Bayes point machine classifier with <see cref="Gaussian"/> prior distributions
|
||||
/// over weights for data in standard format and features in a sparse representation.
|
||||
|
@ -4988,57 +4732,6 @@ namespace Microsoft.ML.Probabilistic.Learners.Tests
|
|||
});
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Tests .NET serialization and deserialization of the Bayes point machine classifier.
|
||||
/// </summary>
|
||||
/// <typeparam name="TInstanceSource">The type of a source of instances.</typeparam>
|
||||
/// <typeparam name="TInstance">The type of an instance.</typeparam>
|
||||
/// <typeparam name="TLabelSource">The type of a source of labels.</typeparam>
|
||||
/// <typeparam name="TLabel">The type of a label.</typeparam>
|
||||
/// <typeparam name="TLabelDistribution">The type of a distribution over labels.</typeparam>
|
||||
/// <typeparam name="TTrainingSettings">The type of the settings for training.</typeparam>
|
||||
/// <param name="classifier">The Bayes point machine classifier.</param>
|
||||
/// <param name="trainingData">The training data.</param>
|
||||
/// <param name="testData">The prediction data.</param>
|
||||
/// <param name="expectedLabelDistributions">The expected label distributions.</param>
|
||||
/// <param name="expectedIncrementalLabelDistributions">The expected label distributions for incremental training.</param>
|
||||
/// <param name="checkPrediction">A method which asserts the equality of expected and predicted distributions.</param>
|
||||
private static void TestRegressionSerialization<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, TTrainingSettings>(
|
||||
IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, TTrainingSettings, IBayesPointMachineClassifierPredictionSettings<TLabel>> classifier,
|
||||
TInstanceSource trainingData,
|
||||
TInstanceSource testData,
|
||||
IEnumerable<TLabelDistribution> expectedLabelDistributions,
|
||||
IEnumerable<TLabelDistribution> expectedIncrementalLabelDistributions,
|
||||
Action<IEnumerable<TLabelDistribution>, IEnumerable<TLabelDistribution>, double> checkPrediction)
|
||||
where TTrainingSettings : BayesPointMachineClassifierTrainingSettings
|
||||
{
|
||||
const string TrainedFileName = "trainedClassifier.bin";
|
||||
const string UntrainedFileName = "untrainedClassifier.bin";
|
||||
|
||||
// Train and serialize
|
||||
classifier.Settings.Training.IterationCount = IterationCount;
|
||||
|
||||
classifier.Save(UntrainedFileName);
|
||||
classifier.Train(trainingData);
|
||||
classifier.Save(TrainedFileName);
|
||||
|
||||
// Deserialize and test
|
||||
var trainedClassifier = BayesPointMachineClassifier.Load<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, TTrainingSettings, IBayesPointMachineClassifierPredictionSettings<TLabel>>(TrainedFileName);
|
||||
var untrainedClassifier = BayesPointMachineClassifier.Load<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, TTrainingSettings, IBayesPointMachineClassifierPredictionSettings<TLabel>>(UntrainedFileName);
|
||||
|
||||
untrainedClassifier.Train(trainingData);
|
||||
|
||||
checkPrediction(expectedLabelDistributions, trainedClassifier.PredictDistribution(testData), Tolerance);
|
||||
checkPrediction(expectedLabelDistributions, untrainedClassifier.PredictDistribution(testData), Tolerance);
|
||||
|
||||
// Incremental training
|
||||
trainedClassifier.TrainIncremental(trainingData);
|
||||
untrainedClassifier.TrainIncremental(trainingData);
|
||||
|
||||
checkPrediction(expectedIncrementalLabelDistributions, trainedClassifier.PredictDistribution(testData), Tolerance);
|
||||
checkPrediction(expectedIncrementalLabelDistributions, untrainedClassifier.PredictDistribution(testData), Tolerance);
|
||||
}
|
||||
|
||||
#region Custom binary serialization
|
||||
|
||||
/// <summary>
|
||||
|
|
|
@ -60,7 +60,6 @@
|
|||
<ProjectReference Include="..\..\..\src\Learners\Classifier\Classifier.csproj" />
|
||||
<ProjectReference Include="..\..\..\src\Learners\Core\Core.csproj" />
|
||||
<ProjectReference Include="..\..\..\src\Learners\Recommender\Recommender.csproj" />
|
||||
<ProjectReference Include="..\..\..\src\Learners\Runners\CommandLine\CommandLine.csproj" />
|
||||
<ProjectReference Include="..\..\..\src\Learners\Runners\Common\Common.csproj" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
|
|
|
@ -642,75 +642,6 @@ namespace Microsoft.ML.Probabilistic.Learners.Tests
|
|||
this.TestDataConsistency(data, true, false); // Valid item features restored, must not throw
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Tests serialization/deserialization of the recommender operating on the native Matchbox data format.
|
||||
/// </summary>
|
||||
[Fact]
|
||||
public void NativeDataFormatSerializationTest()
|
||||
{
|
||||
const string TrainedFileName = "trainedNativeRecommender.bin";
|
||||
const string NotTrainedFileName = "notTrainedNativeRecommender.bin";
|
||||
|
||||
// Train and serialize
|
||||
{
|
||||
var recommender = this.CreateNativeDataFormatMatchboxRecommender();
|
||||
recommender.Save(NotTrainedFileName);
|
||||
recommender.Train(this.nativeTrainingData, this.nativeTrainingData);
|
||||
recommender.Save(TrainedFileName);
|
||||
}
|
||||
|
||||
// Deserialize and test
|
||||
{
|
||||
var trainedRecommender = MatchboxRecommender.Load<NativeDataset, int, int, Discrete, NativeDataset>(TrainedFileName);
|
||||
var notTrainedRecommender = MatchboxRecommender.Load<NativeDataset, int, int, Discrete, NativeDataset>(NotTrainedFileName);
|
||||
|
||||
notTrainedRecommender.Train(this.nativeTrainingData, this.nativeTrainingData);
|
||||
|
||||
this.VerifyNativeRatingDistributionOfUserOneAndItemThree(trainedRecommender.PredictDistribution(1, 3, this.nativeTrainingData));
|
||||
Assert.Equal(MaxStarRating - MinStarRating, trainedRecommender.Predict(1, 3, this.nativeTrainingData));
|
||||
|
||||
this.VerifyNativeRatingDistributionOfUserOneAndItemThree(notTrainedRecommender.PredictDistribution(1, 3, this.nativeTrainingData));
|
||||
Assert.Equal(MaxStarRating - MinStarRating, notTrainedRecommender.Predict(1, 3, this.nativeTrainingData));
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Tests serialization/deserialization of the recommender operating on the standard data format.
|
||||
/// </summary>
|
||||
[Fact]
|
||||
public void StandardDataFormatSerializationTest()
|
||||
{
|
||||
const int BatchCount = 2;
|
||||
const string TrainedFileName = "trainedStandardRecommender.bin";
|
||||
const string NotTrainedFileName = "notTrainedStandardRecommender.bin";
|
||||
|
||||
// Add features for cold test set user and item
|
||||
this.standardTrainingDataFeatures.UserFeatures.Add(User.WithId("u2"), Vector.FromArray(4, 2, 1.3, 1.2, 2));
|
||||
this.standardTrainingDataFeatures.ItemFeatures.Add(Item.WithId("i4"), Vector.FromArray(6.3, 0.5));
|
||||
|
||||
// Train and serialize
|
||||
{
|
||||
var recommender = this.CreateStandardDataFormatMatchboxRecommender();
|
||||
recommender.Settings.Training.BatchCount = BatchCount;
|
||||
recommender.Save(NotTrainedFileName);
|
||||
recommender.Train(this.standardTrainingData, this.standardTrainingDataFeatures);
|
||||
recommender.Save(TrainedFileName);
|
||||
|
||||
CheckStandardRatingPrediction(recommender, this.standardTrainingDataFeatures);
|
||||
}
|
||||
|
||||
// Deserialize and test
|
||||
{
|
||||
var trainedRecommender = MatchboxRecommender.Load<StandardDataset, User, Item, RatingDistribution, FeatureProvider>(TrainedFileName);
|
||||
var notTrainedRecommender = MatchboxRecommender.Load<StandardDataset, User, Item, RatingDistribution, FeatureProvider>(NotTrainedFileName);
|
||||
|
||||
notTrainedRecommender.Train(this.standardTrainingData, this.standardTrainingDataFeatures);
|
||||
|
||||
CheckStandardRatingPrediction(trainedRecommender, this.standardTrainingDataFeatures);
|
||||
CheckStandardRatingPrediction(notTrainedRecommender, this.standardTrainingDataFeatures);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Tests binary custom serialization/deserialization of the recommender operating on the native data format.
|
||||
/// </summary>
|
||||
|
|
|
@ -136,21 +136,6 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
}
|
||||
}
|
||||
writer.Dispose();
|
||||
#if NETFRAMEWORK
|
||||
// In the .NET 5.0 BinaryFormatter is obsolete
|
||||
// and would produce errors. This test code should be migrated.
|
||||
// See https://aka.ms/binaryformatter
|
||||
|
||||
if (true)
|
||||
{
|
||||
BinaryFormatter serializer = new BinaryFormatter();
|
||||
using (Stream stream = File.Create(Path.Combine(dataFolder, "weights.bin")))
|
||||
{
|
||||
serializer.Serialize(stream, train.wPost);
|
||||
serializer.Serialize(stream, train.biasPost);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// (0.5,0.5):
|
||||
|
@ -170,37 +155,6 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
#pragma warning restore 162
|
||||
#endif
|
||||
|
||||
#if NETFRAMEWORK
|
||||
public static void Rcv1Test2()
|
||||
{
|
||||
GaussianArray wPost;
|
||||
Gaussian biasPost;
|
||||
BinaryFormatter serializer = new BinaryFormatter();
|
||||
using (Stream stream = File.OpenRead(Path.Combine(dataFolder, "weights.bin")))
|
||||
{
|
||||
wPost = (GaussianArray)serializer.Deserialize(stream);
|
||||
biasPost = (Gaussian)serializer.Deserialize(stream);
|
||||
}
|
||||
if (true)
|
||||
{
|
||||
GaussianEstimator est = new GaussianEstimator();
|
||||
foreach (Gaussian item in wPost) est.Add(item.GetMean());
|
||||
Console.WriteLine("weight distribution = {0}", est.GetDistribution(new Gaussian()));
|
||||
}
|
||||
var predict = new BpmPredict2();
|
||||
predict.SetPriors(wPost, biasPost);
|
||||
int count = 0;
|
||||
int errors = 0;
|
||||
foreach (Instance instance in new VwReader(Path.Combine(dataFolder, "rcv1.test.vw.gz")))
|
||||
{
|
||||
bool yPred = predict.Predict(instance);
|
||||
if (yPred != instance.label) errors++;
|
||||
count++;
|
||||
}
|
||||
Console.WriteLine("error rate = {0} = {1}/{2}", (double) errors/count, errors, count);
|
||||
}
|
||||
#endif
|
||||
|
||||
public static void Rcv1Test3()
|
||||
{
|
||||
int nf = 47152;
|
||||
|
|
|
@ -32,6 +32,7 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
using Microsoft.ML.Probabilistic.Compiler;
|
||||
using Microsoft.ML.Probabilistic.Algorithms;
|
||||
using Microsoft.ML.Probabilistic.Models.Attributes;
|
||||
using System.Runtime.Serialization;
|
||||
|
||||
public class DistributedTests
|
||||
{
|
||||
|
|
|
@ -82,18 +82,6 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
mc.AssertEqualTo(mc2);
|
||||
}
|
||||
|
||||
#if NETFRAMEWORK
|
||||
[Fact]
|
||||
public void BinaryFormatterTest()
|
||||
{
|
||||
var mc = new MyClass();
|
||||
mc.Initialize(skipStringDistributions: true);
|
||||
|
||||
var mc2 = CloneBinaryFormatter(mc);
|
||||
mc.AssertEqualTo(mc2);
|
||||
}
|
||||
#endif
|
||||
|
||||
[Fact]
|
||||
public void JsonNetSerializerTest()
|
||||
{
|
||||
|
@ -104,46 +92,6 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
mc.AssertEqualTo(mc2);
|
||||
}
|
||||
|
||||
#if NETFRAMEWORK
|
||||
[Fact]
|
||||
public void VectorSerializeTests()
|
||||
{
|
||||
Sparsity approxSparsity = Sparsity.ApproximateWithTolerance(0.001);
|
||||
double[] fromArray = new double[] {1.2, 2.3, 3.4, 1.2, 1.2, 2.3};
|
||||
Vector vdense = Vector.FromArray(fromArray);
|
||||
Vector vsparse = Vector.FromArray(fromArray, Sparsity.Sparse);
|
||||
Vector vapprox = Vector.FromArray(fromArray, approxSparsity);
|
||||
MemoryStream stream = new MemoryStream();
|
||||
BinaryFormatter serializer = new BinaryFormatter();
|
||||
serializer.Serialize(stream, vdense);
|
||||
serializer.Serialize(stream, vsparse);
|
||||
serializer.Serialize(stream, vapprox);
|
||||
|
||||
stream.Position = 0;
|
||||
Vector vdense2 = (Vector) serializer.Deserialize(stream);
|
||||
SparseVector vsparse2 = (SparseVector) serializer.Deserialize(stream);
|
||||
ApproximateSparseVector vapprox2 = (ApproximateSparseVector) serializer.Deserialize(stream);
|
||||
|
||||
Assert.Equal(6, vdense2.Count);
|
||||
for (int i = 0; i < fromArray.Length; i++) Assert.Equal(fromArray[i], vdense2[i]);
|
||||
Assert.Equal(vdense2.Sparsity, Sparsity.Dense);
|
||||
|
||||
Assert.Equal(6, vsparse2.Count);
|
||||
for (int i = 0; i < fromArray.Length; i++) Assert.Equal(vsparse2[i], fromArray[i]);
|
||||
Assert.Equal(vsparse2.Sparsity, Sparsity.Sparse);
|
||||
Assert.Equal(1.2, vsparse2.CommonValue);
|
||||
Assert.Equal(3, vsparse2.SparseValues.Count);
|
||||
Assert.True(vsparse2.HasCommonElements);
|
||||
|
||||
Assert.Equal(6, vapprox2.Count);
|
||||
for (int i = 0; i < fromArray.Length; i++) Assert.Equal(vapprox2[i], fromArray[i]);
|
||||
Assert.Equal(vapprox2.Sparsity, approxSparsity);
|
||||
Assert.Equal(1.2, vapprox2.CommonValue);
|
||||
Assert.Equal(3, vapprox2.SparseValues.Count);
|
||||
Assert.True(vapprox2.HasCommonElements);
|
||||
}
|
||||
#endif
|
||||
|
||||
[DataContract]
|
||||
[Serializable]
|
||||
public class MyClass
|
||||
|
@ -306,19 +254,6 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
}
|
||||
}
|
||||
|
||||
#if NETFRAMEWORK
|
||||
private static T CloneBinaryFormatter<T>(T obj)
|
||||
{
|
||||
var bf = new BinaryFormatter();
|
||||
using (var ms = new MemoryStream())
|
||||
{
|
||||
bf.Serialize(ms, obj);
|
||||
ms.Position = 0;
|
||||
return (T)bf.Deserialize(ms);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
private static T CloneJsonNet<T>(T obj)
|
||||
{
|
||||
var serializerSettings = new JsonSerializerSettings
|
||||
|
|
|
@ -1452,25 +1452,6 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
Console.WriteLine(engine.Infer(mean));
|
||||
}
|
||||
|
||||
#if NETFRAMEWORK
|
||||
internal void BinarySerializationExample()
|
||||
{
|
||||
Dirichlet d = new Dirichlet(3.0, 1.0, 2.0);
|
||||
BinaryFormatter serializer = new BinaryFormatter();
|
||||
// write to disk
|
||||
using (FileStream stream = new FileStream("temp.bin", FileMode.Create))
|
||||
{
|
||||
serializer.Serialize(stream, d);
|
||||
}
|
||||
// read from disk
|
||||
using (FileStream stream = new FileStream("temp.bin", FileMode.Open))
|
||||
{
|
||||
Dirichlet d2 = (Dirichlet)serializer.Deserialize(stream);
|
||||
Console.WriteLine(d2);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
internal void XmlSerializationExample()
|
||||
{
|
||||
Dirichlet d = new Dirichlet(3.0, 1.0, 2.0);
|
||||
|
|
|
@ -679,21 +679,6 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
}
|
||||
}
|
||||
|
||||
#if NETFRAMEWORK
|
||||
if (false)
|
||||
{
|
||||
BinaryFormatter serializer = new BinaryFormatter();
|
||||
using (var writer = new FileStream("userTraits.bin", FileMode.Create))
|
||||
{
|
||||
serializer.Serialize(writer, engine.Infer(userTraits));
|
||||
}
|
||||
using (var writer = new FileStream("itemTraits.bin", FileMode.Create))
|
||||
{
|
||||
serializer.Serialize(writer, engine.Infer(itemTraits));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// test resetting inference
|
||||
if (engine.Compiler.ReturnCopies)
|
||||
{
|
||||
|
|
Загрузка…
Ссылка в новой задаче