Fixing ModelParameter discrepancies (#2968)
* fixing model parameter discrepencies * multiclass LR singe that refactoring is happening in a parallel PR * review comments. Added Multiclass to NaiveBayes * Drop Classification from trainer names - v1 (more trainers to follow) * multiclass LR will be handled separately * Drop Classification from trainer names - v2 (all trainers taken care of) * fix entrypoint file
This commit is contained in:
Родитель
71693b3ac8
Коммит
08318656ed
|
@ -47,7 +47,7 @@ namespace Microsoft.ML.Samples.Dynamic.PermutationFeatureImportance
|
|||
private readonly static Action<ContinuousInputRow, BinaryOutputRow> GreaterThanAverage = (input, output)
|
||||
=> output.AboveAverage = input.MedianHomeValue > 22.6;
|
||||
|
||||
public static float[] GetLinearModelWeights(OrdinaryLeastSquaresRegressionModelParameters linearModel)
|
||||
public static float[] GetLinearModelWeights(OlsModelParameters linearModel)
|
||||
{
|
||||
return linearModel.Weights.ToArray();
|
||||
}
|
||||
|
|
|
@ -61,7 +61,7 @@ namespace Microsoft.ML.Samples.Dynamic.Trainers.BinaryClassification
|
|||
// we could do so by tweaking the 'advancedSetting'.
|
||||
var advancedPipeline = mlContext.Transforms.Text.FeaturizeText("SentimentText", "Features")
|
||||
.Append(mlContext.BinaryClassification.Trainers.SdcaCalibrated(
|
||||
new SdcaCalibratedBinaryClassificationTrainer.Options {
|
||||
new SdcaCalibratedBinaryTrainer.Options {
|
||||
LabelColumnName = "Sentiment",
|
||||
FeatureColumnName = "Features",
|
||||
ConvergenceTolerance = 0.01f, // The learning rate for adjusting bias from being regularized
|
||||
|
|
|
@ -22,7 +22,7 @@ namespace Microsoft.ML.Samples.Dynamic.Trainers.BinaryClassification
|
|||
var trainTestData = mlContext.Data.TrainTestSplit(data, testFraction: 0.1);
|
||||
|
||||
// Define the trainer options.
|
||||
var options = new SdcaCalibratedBinaryClassificationTrainer.Options()
|
||||
var options = new SdcaCalibratedBinaryTrainer.Options()
|
||||
{
|
||||
// Make the convergence tolerance tighter.
|
||||
ConvergenceTolerance = 0.05f,
|
||||
|
|
|
@ -26,7 +26,7 @@ namespace Microsoft.ML.Samples.Dynamic.Trainers.MulticlassClassification
|
|||
// CC 1.216908,1.248052,1.391902,0.4326252,1.099942,0.9262842,1.334019,1.08762,0.9468155,0.4811099
|
||||
// DD 0.7871246,1.053327,0.8971719,1.588544,1.242697,1.362964,0.6303943,0.9810045,0.9431419,1.557455
|
||||
|
||||
var options = new SdcaMulticlassClassificationTrainer.Options
|
||||
var options = new SdcaMulticlassTrainer.Options
|
||||
{
|
||||
// Add custom loss
|
||||
LossFunction = new HingeLoss(),
|
||||
|
|
|
@ -9,7 +9,7 @@ using Microsoft.ML.Internal.Internallearn;
|
|||
using Microsoft.ML.Runtime;
|
||||
using Microsoft.ML.Trainers.FastTree;
|
||||
|
||||
[assembly: EntryPointModule(typeof(FastTreeBinaryClassificationTrainer.Options))]
|
||||
[assembly: EntryPointModule(typeof(FastTreeBinaryTrainer.Options))]
|
||||
[assembly: EntryPointModule(typeof(FastTreeRegressionTrainer.Options))]
|
||||
[assembly: EntryPointModule(typeof(FastTreeTweedieTrainer.Options))]
|
||||
[assembly: EntryPointModule(typeof(FastTreeRankingTrainer.Options))]
|
||||
|
@ -52,10 +52,10 @@ namespace Microsoft.ML.Trainers.FastTree
|
|||
}
|
||||
|
||||
// XML docs are provided in the other part of this partial class. No need to duplicate the content here.
|
||||
public sealed partial class FastTreeBinaryClassificationTrainer
|
||||
public sealed partial class FastTreeBinaryTrainer
|
||||
{
|
||||
/// <summary>
|
||||
/// Options for the <see cref="FastTreeBinaryClassificationTrainer"/>.
|
||||
/// Options for the <see cref="FastTreeBinaryTrainer"/>.
|
||||
/// </summary>
|
||||
[TlcModule.Component(Name = LoadNameValue, FriendlyName = UserNameValue, Desc = Summary)]
|
||||
public sealed class Options : BoostedTreeOptions, IFastTreeTrainerFactory
|
||||
|
@ -102,7 +102,7 @@ namespace Microsoft.ML.Trainers.FastTree
|
|||
EarlyStoppingMetric = EarlyStoppingMetric.L1Norm;
|
||||
}
|
||||
|
||||
ITrainer IComponentFactory<ITrainer>.CreateComponent(IHostEnvironment env) => new FastTreeBinaryClassificationTrainer(env, this);
|
||||
ITrainer IComponentFactory<ITrainer>.CreateComponent(IHostEnvironment env) => new FastTreeBinaryTrainer(env, this);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -13,14 +13,14 @@ using Microsoft.ML.Model;
|
|||
using Microsoft.ML.Runtime;
|
||||
using Microsoft.ML.Trainers.FastTree;
|
||||
|
||||
[assembly: LoadableClass(FastTreeBinaryClassificationTrainer.Summary, typeof(FastTreeBinaryClassificationTrainer), typeof(FastTreeBinaryClassificationTrainer.Options),
|
||||
[assembly: LoadableClass(FastTreeBinaryTrainer.Summary, typeof(FastTreeBinaryTrainer), typeof(FastTreeBinaryTrainer.Options),
|
||||
new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureTreeEnsembleTrainer), typeof(SignatureFeatureScorerTrainer) },
|
||||
FastTreeBinaryClassificationTrainer.UserNameValue,
|
||||
FastTreeBinaryClassificationTrainer.LoadNameValue,
|
||||
FastTreeBinaryTrainer.UserNameValue,
|
||||
FastTreeBinaryTrainer.LoadNameValue,
|
||||
"FastTreeClassification",
|
||||
"FastTree",
|
||||
"ft",
|
||||
FastTreeBinaryClassificationTrainer.ShortName,
|
||||
FastTreeBinaryTrainer.ShortName,
|
||||
|
||||
// FastRank names
|
||||
"FastRankBinaryClassification",
|
||||
|
@ -101,8 +101,8 @@ namespace Microsoft.ML.Trainers.FastTree
|
|||
/// The <see cref="IEstimator{TTransformer}"/> for training a decision tree binary classification model using FastTree.
|
||||
/// </summary>
|
||||
/// <include file='doc.xml' path='doc/members/member[@name="FastTree_remarks"]/*' />
|
||||
public sealed partial class FastTreeBinaryClassificationTrainer :
|
||||
BoostingFastTreeTrainerBase<FastTreeBinaryClassificationTrainer.Options,
|
||||
public sealed partial class FastTreeBinaryTrainer :
|
||||
BoostingFastTreeTrainerBase<FastTreeBinaryTrainer.Options,
|
||||
BinaryPredictionTransformer<CalibratedModelParametersBase<FastTreeBinaryModelParameters, PlattCalibrator>>,
|
||||
CalibratedModelParametersBase<FastTreeBinaryModelParameters, PlattCalibrator>>
|
||||
{
|
||||
|
@ -118,7 +118,7 @@ namespace Microsoft.ML.Trainers.FastTree
|
|||
private double _sigmoidParameter;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of <see cref="FastTreeBinaryClassificationTrainer"/>
|
||||
/// Initializes a new instance of <see cref="FastTreeBinaryTrainer"/>
|
||||
/// </summary>
|
||||
/// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
|
||||
/// <param name="labelColumnName">The name of the label column.</param>
|
||||
|
@ -128,7 +128,7 @@ namespace Microsoft.ML.Trainers.FastTree
|
|||
/// <param name="minimumExampleCountPerLeaf">The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data.</param>
|
||||
/// <param name="numberOfLeaves">The max number of leaves in each regression tree.</param>
|
||||
/// <param name="numberOfTrees">Total number of decision trees to create in the ensemble.</param>
|
||||
internal FastTreeBinaryClassificationTrainer(IHostEnvironment env,
|
||||
internal FastTreeBinaryTrainer(IHostEnvironment env,
|
||||
string labelColumnName = DefaultColumnNames.Label,
|
||||
string featureColumnName = DefaultColumnNames.Features,
|
||||
string exampleWeightColumnName = null,
|
||||
|
@ -143,11 +143,11 @@ namespace Microsoft.ML.Trainers.FastTree
|
|||
}
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of <see cref="FastTreeBinaryClassificationTrainer"/> by using the <see cref="Options"/> class.
|
||||
/// Initializes a new instance of <see cref="FastTreeBinaryTrainer"/> by using the <see cref="Options"/> class.
|
||||
/// </summary>
|
||||
/// <param name="env">The instance of <see cref="IHostEnvironment"/>.</param>
|
||||
/// <param name="options">Algorithm advanced settings.</param>
|
||||
internal FastTreeBinaryClassificationTrainer(IHostEnvironment env, Options options)
|
||||
internal FastTreeBinaryTrainer(IHostEnvironment env, Options options)
|
||||
: base(env, options, TrainerUtils.MakeBoolScalarLabel(options.LabelColumnName))
|
||||
{
|
||||
// Set the sigmoid parameter to the 2 * learning rate, for traditional FastTreeClassification loss
|
||||
|
@ -278,7 +278,7 @@ namespace Microsoft.ML.Trainers.FastTree
|
|||
=> new BinaryPredictionTransformer<CalibratedModelParametersBase<FastTreeBinaryModelParameters, PlattCalibrator>>(Host, model, trainSchema, FeatureColumn.Name);
|
||||
|
||||
/// <summary>
|
||||
/// Trains a <see cref="FastTreeBinaryClassificationTrainer"/> using both training and validation data, returns
|
||||
/// Trains a <see cref="FastTreeBinaryTrainer"/> using both training and validation data, returns
|
||||
/// a <see cref="BinaryPredictionTransformer{CalibratedModelParametersBase}"/>.
|
||||
/// </summary>
|
||||
public BinaryPredictionTransformer<CalibratedModelParametersBase<FastTreeBinaryModelParameters, PlattCalibrator>> Fit(IDataView trainData, IDataView validationData)
|
||||
|
@ -403,18 +403,18 @@ namespace Microsoft.ML.Trainers.FastTree
|
|||
internal static partial class FastTree
|
||||
{
|
||||
[TlcModule.EntryPoint(Name = "Trainers.FastTreeBinaryClassifier",
|
||||
Desc = FastTreeBinaryClassificationTrainer.Summary,
|
||||
UserName = FastTreeBinaryClassificationTrainer.UserNameValue,
|
||||
ShortName = FastTreeBinaryClassificationTrainer.ShortName)]
|
||||
public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, FastTreeBinaryClassificationTrainer.Options input)
|
||||
Desc = FastTreeBinaryTrainer.Summary,
|
||||
UserName = FastTreeBinaryTrainer.UserNameValue,
|
||||
ShortName = FastTreeBinaryTrainer.ShortName)]
|
||||
public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, FastTreeBinaryTrainer.Options input)
|
||||
{
|
||||
Contracts.CheckValue(env, nameof(env));
|
||||
var host = env.Register("TrainFastTree");
|
||||
host.CheckValue(input, nameof(input));
|
||||
EntryPointUtils.CheckInputArgs(host, input);
|
||||
|
||||
return TrainerEntryPointsUtils.Train<FastTreeBinaryClassificationTrainer.Options, CommonOutputs.BinaryClassificationOutput>(host, input,
|
||||
() => new FastTreeBinaryClassificationTrainer(host, input),
|
||||
return TrainerEntryPointsUtils.Train<FastTreeBinaryTrainer.Options, CommonOutputs.BinaryClassificationOutput>(host, input,
|
||||
() => new FastTreeBinaryTrainer(host, input),
|
||||
() => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumnName),
|
||||
() => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.ExampleWeightColumnName),
|
||||
() => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.RowGroupColumnName));
|
||||
|
|
|
@ -13,16 +13,16 @@ using Microsoft.ML.Model;
|
|||
using Microsoft.ML.Runtime;
|
||||
using Microsoft.ML.Trainers.FastTree;
|
||||
|
||||
[assembly: LoadableClass(GamBinaryClassificationTrainer.Summary,
|
||||
typeof(GamBinaryClassificationTrainer), typeof(GamBinaryClassificationTrainer.Options),
|
||||
[assembly: LoadableClass(GamBinaryTrainer.Summary,
|
||||
typeof(GamBinaryTrainer), typeof(GamBinaryTrainer.Options),
|
||||
new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) },
|
||||
GamBinaryClassificationTrainer.UserNameValue,
|
||||
GamBinaryClassificationTrainer.LoadNameValue,
|
||||
GamBinaryClassificationTrainer.ShortName, DocName = "trainer/GAM.md")]
|
||||
GamBinaryTrainer.UserNameValue,
|
||||
GamBinaryTrainer.LoadNameValue,
|
||||
GamBinaryTrainer.ShortName, DocName = "trainer/GAM.md")]
|
||||
|
||||
[assembly: LoadableClass(typeof(IPredictorProducing<float>), typeof(BinaryClassificationGamModelParameters), null, typeof(SignatureLoadModel),
|
||||
[assembly: LoadableClass(typeof(IPredictorProducing<float>), typeof(GamBinaryModelParameters), null, typeof(SignatureLoadModel),
|
||||
"GAM Binary Class Predictor",
|
||||
BinaryClassificationGamModelParameters.LoaderSignature)]
|
||||
GamBinaryModelParameters.LoaderSignature)]
|
||||
|
||||
namespace Microsoft.ML.Trainers.FastTree
|
||||
{
|
||||
|
@ -30,13 +30,13 @@ namespace Microsoft.ML.Trainers.FastTree
|
|||
/// The <see cref="IEstimator{TTransformer}"/> for training a binary classification model with generalized additive models (GAM).
|
||||
/// </summary>
|
||||
/// <include file='doc.xml' path='doc/members/member[@name="GAM_remarks"]/*' />
|
||||
public sealed class GamBinaryClassificationTrainer :
|
||||
GamTrainerBase<GamBinaryClassificationTrainer.Options,
|
||||
BinaryPredictionTransformer<CalibratedModelParametersBase<BinaryClassificationGamModelParameters, PlattCalibrator>>,
|
||||
CalibratedModelParametersBase<BinaryClassificationGamModelParameters, PlattCalibrator>>
|
||||
public sealed class GamBinaryTrainer :
|
||||
GamTrainerBase<GamBinaryTrainer.Options,
|
||||
BinaryPredictionTransformer<CalibratedModelParametersBase<GamBinaryModelParameters, PlattCalibrator>>,
|
||||
CalibratedModelParametersBase<GamBinaryModelParameters, PlattCalibrator>>
|
||||
{
|
||||
/// <summary>
|
||||
/// Options for the <see cref="GamBinaryClassificationTrainer"/>.
|
||||
/// Options for the <see cref="GamBinaryTrainer"/>.
|
||||
/// </summary>
|
||||
public sealed class Options : OptionsBase
|
||||
{
|
||||
|
@ -57,16 +57,16 @@ namespace Microsoft.ML.Trainers.FastTree
|
|||
private protected override bool NeedCalibration => true;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of <see cref="GamBinaryClassificationTrainer"/>
|
||||
/// Initializes a new instance of <see cref="GamBinaryTrainer"/>
|
||||
/// </summary>
|
||||
internal GamBinaryClassificationTrainer(IHostEnvironment env, Options options)
|
||||
internal GamBinaryTrainer(IHostEnvironment env, Options options)
|
||||
: base(env, options, LoadNameValue, TrainerUtils.MakeBoolScalarLabel(options.LabelColumnName))
|
||||
{
|
||||
_sigmoidParameter = 1;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of <see cref="GamBinaryClassificationTrainer"/>
|
||||
/// Initializes a new instance of <see cref="GamBinaryTrainer"/>
|
||||
/// </summary>
|
||||
/// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
|
||||
/// <param name="labelColumnName">The name of the label column.</param>
|
||||
|
@ -75,7 +75,7 @@ namespace Microsoft.ML.Trainers.FastTree
|
|||
/// <param name="numberOfIterations">The number of iterations to use in learning the features.</param>
|
||||
/// <param name="learningRate">The learning rate. GAMs work best with a small learning rate.</param>
|
||||
/// <param name="maximumBinCountPerFeature">The maximum number of bins to use to approximate features</param>
|
||||
internal GamBinaryClassificationTrainer(IHostEnvironment env,
|
||||
internal GamBinaryTrainer(IHostEnvironment env,
|
||||
string labelColumnName = DefaultColumnNames.Label,
|
||||
string featureColumnName = DefaultColumnNames.Features,
|
||||
string rowGroupColumnName = null,
|
||||
|
@ -111,18 +111,18 @@ namespace Microsoft.ML.Trainers.FastTree
|
|||
Parallel.Invoke(new ParallelOptions { MaxDegreeOfParallelism = BlockingThreadPool.NumThreads }, actions);
|
||||
return boolArray;
|
||||
}
|
||||
private protected override CalibratedModelParametersBase<BinaryClassificationGamModelParameters, PlattCalibrator> TrainModelCore(TrainContext context)
|
||||
private protected override CalibratedModelParametersBase<GamBinaryModelParameters, PlattCalibrator> TrainModelCore(TrainContext context)
|
||||
{
|
||||
TrainBase(context);
|
||||
var predictor = new BinaryClassificationGamModelParameters(Host,
|
||||
var predictor = new GamBinaryModelParameters(Host,
|
||||
BinUpperBounds, BinEffects, MeanEffect, InputLength, FeatureMap);
|
||||
var calibrator = new PlattCalibrator(Host, -1.0 * _sigmoidParameter, 0);
|
||||
return new ValueMapperCalibratedModelParameters<BinaryClassificationGamModelParameters, PlattCalibrator>(Host, predictor, calibrator);
|
||||
return new ValueMapperCalibratedModelParameters<GamBinaryModelParameters, PlattCalibrator>(Host, predictor, calibrator);
|
||||
}
|
||||
|
||||
private protected override ObjectiveFunctionBase CreateObjectiveFunction()
|
||||
{
|
||||
return new FastTreeBinaryClassificationTrainer.ObjectiveImpl(
|
||||
return new FastTreeBinaryTrainer.ObjectiveImpl(
|
||||
TrainSet,
|
||||
ConvertTargetsToBool(TrainSet.Targets),
|
||||
GamTrainerOptions.LearningRate,
|
||||
|
@ -146,15 +146,15 @@ namespace Microsoft.ML.Trainers.FastTree
|
|||
PruningTest = new TestHistory(validTest, PruningLossIndex);
|
||||
}
|
||||
|
||||
private protected override BinaryPredictionTransformer<CalibratedModelParametersBase<BinaryClassificationGamModelParameters, PlattCalibrator>>
|
||||
MakeTransformer(CalibratedModelParametersBase<BinaryClassificationGamModelParameters, PlattCalibrator> model, DataViewSchema trainSchema)
|
||||
=> new BinaryPredictionTransformer<CalibratedModelParametersBase<BinaryClassificationGamModelParameters, PlattCalibrator>>(Host, model, trainSchema, FeatureColumn.Name);
|
||||
private protected override BinaryPredictionTransformer<CalibratedModelParametersBase<GamBinaryModelParameters, PlattCalibrator>>
|
||||
MakeTransformer(CalibratedModelParametersBase<GamBinaryModelParameters, PlattCalibrator> model, DataViewSchema trainSchema)
|
||||
=> new BinaryPredictionTransformer<CalibratedModelParametersBase<GamBinaryModelParameters, PlattCalibrator>>(Host, model, trainSchema, FeatureColumn.Name);
|
||||
|
||||
/// <summary>
|
||||
/// Trains a <see cref="GamBinaryClassificationTrainer"/> using both training and validation data, returns
|
||||
/// Trains a <see cref="GamBinaryTrainer"/> using both training and validation data, returns
|
||||
/// a <see cref="BinaryPredictionTransformer{CalibratedModelParametersBase}"/>.
|
||||
/// </summary>
|
||||
public BinaryPredictionTransformer<CalibratedModelParametersBase<BinaryClassificationGamModelParameters, PlattCalibrator>> Fit(IDataView trainData, IDataView validationData)
|
||||
public BinaryPredictionTransformer<CalibratedModelParametersBase<GamBinaryModelParameters, PlattCalibrator>> Fit(IDataView trainData, IDataView validationData)
|
||||
=> TrainTransformer(trainData, validationData);
|
||||
|
||||
private protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
|
||||
|
@ -171,7 +171,7 @@ namespace Microsoft.ML.Trainers.FastTree
|
|||
/// <summary>
|
||||
/// The model parameters class for Binary Classification GAMs
|
||||
/// </summary>
|
||||
public sealed class BinaryClassificationGamModelParameters : GamModelParametersBase, IPredictorProducing<float>
|
||||
public sealed class GamBinaryModelParameters : GamModelParametersBase, IPredictorProducing<float>
|
||||
{
|
||||
internal const string LoaderSignature = "BinaryClassGamPredictor";
|
||||
private protected override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
|
||||
|
@ -188,11 +188,11 @@ namespace Microsoft.ML.Trainers.FastTree
|
|||
/// <param name="featureToInputMap">A map from the feature shape functions, as described by <paramref name="binUpperBounds"/> and <paramref name="binEffects"/>.
|
||||
/// to the input feature. Used when the number of input features is different than the number of shape functions. Use default if all features have
|
||||
/// a shape function.</param>
|
||||
internal BinaryClassificationGamModelParameters(IHostEnvironment env,
|
||||
internal GamBinaryModelParameters(IHostEnvironment env,
|
||||
double[][] binUpperBounds, double[][] binEffects, double intercept, int inputLength, int[] featureToInputMap)
|
||||
: base(env, LoaderSignature, binUpperBounds, binEffects, intercept, inputLength, featureToInputMap) { }
|
||||
|
||||
private BinaryClassificationGamModelParameters(IHostEnvironment env, ModelLoadContext ctx)
|
||||
private GamBinaryModelParameters(IHostEnvironment env, ModelLoadContext ctx)
|
||||
: base(env, LoaderSignature, ctx) { }
|
||||
|
||||
private static VersionInfo GetVersionInfo()
|
||||
|
@ -205,7 +205,7 @@ namespace Microsoft.ML.Trainers.FastTree
|
|||
verReadableCur: 0x00010002,
|
||||
verWeCanReadBack: 0x00010001,
|
||||
loaderSignature: LoaderSignature,
|
||||
loaderAssemblyName: typeof(BinaryClassificationGamModelParameters).Assembly.FullName);
|
||||
loaderAssemblyName: typeof(GamBinaryModelParameters).Assembly.FullName);
|
||||
}
|
||||
|
||||
private static IPredictorProducing<float> Create(IHostEnvironment env, ModelLoadContext ctx)
|
||||
|
@ -214,12 +214,12 @@ namespace Microsoft.ML.Trainers.FastTree
|
|||
env.CheckValue(ctx, nameof(ctx));
|
||||
ctx.CheckAtModel(GetVersionInfo());
|
||||
|
||||
var predictor = new BinaryClassificationGamModelParameters(env, ctx);
|
||||
var predictor = new GamBinaryModelParameters(env, ctx);
|
||||
ICalibrator calibrator;
|
||||
ctx.LoadModelOrNull<ICalibrator, SignatureLoadModel>(env, out calibrator, @"Calibrator");
|
||||
if (calibrator == null)
|
||||
return predictor;
|
||||
return new SchemaBindableCalibratedModelParameters<BinaryClassificationGamModelParameters, ICalibrator>(env, predictor, calibrator);
|
||||
return new SchemaBindableCalibratedModelParameters<GamBinaryModelParameters, ICalibrator>(env, predictor, calibrator);
|
||||
}
|
||||
|
||||
private protected override void SaveCore(ModelSaveContext ctx)
|
||||
|
|
|
@ -879,12 +879,12 @@ namespace Microsoft.ML.Trainers.FastTree
|
|||
// 2. RegressionGamModelParameters
|
||||
// For (1), the trained model, GamModelParametersBase, is a field we need to extract. For (2),
|
||||
// we don't need to do anything because RegressionGamModelParameters is derived from GamModelParametersBase.
|
||||
var calibrated = rawPred as CalibratedModelParametersBase<BinaryClassificationGamModelParameters, PlattCalibrator>;
|
||||
var calibrated = rawPred as CalibratedModelParametersBase<GamBinaryModelParameters, PlattCalibrator>;
|
||||
while (calibrated != null)
|
||||
{
|
||||
hadCalibrator = true;
|
||||
rawPred = calibrated.SubModel;
|
||||
calibrated = rawPred as CalibratedModelParametersBase<BinaryClassificationGamModelParameters, PlattCalibrator>;
|
||||
calibrated = rawPred as CalibratedModelParametersBase<GamBinaryModelParameters, PlattCalibrator>;
|
||||
}
|
||||
var pred = rawPred as GamModelParametersBase;
|
||||
ch.CheckUserArg(pred != null, nameof(ImplOptions.InputModelFile), "Predictor was not a " + nameof(GamModelParametersBase));
|
||||
|
|
|
@ -17,9 +17,9 @@ using Microsoft.ML.Trainers.FastTree;
|
|||
GamRegressionTrainer.LoadNameValue,
|
||||
GamRegressionTrainer.ShortName, DocName = "trainer/GAM.md")]
|
||||
|
||||
[assembly: LoadableClass(typeof(RegressionGamModelParameters), null, typeof(SignatureLoadModel),
|
||||
[assembly: LoadableClass(typeof(GamRegressionModelParameters), null, typeof(SignatureLoadModel),
|
||||
"GAM Regression Predictor",
|
||||
RegressionGamModelParameters.LoaderSignature)]
|
||||
GamRegressionModelParameters.LoaderSignature)]
|
||||
|
||||
namespace Microsoft.ML.Trainers.FastTree
|
||||
{
|
||||
|
@ -27,7 +27,7 @@ namespace Microsoft.ML.Trainers.FastTree
|
|||
/// The <see cref="IEstimator{TTransformer}"/> for training a regression model with generalized additive models (GAM).
|
||||
/// </summary>
|
||||
/// <include file='doc.xml' path='doc/members/member[@name="GAM_remarks"]/*' />
|
||||
public sealed class GamRegressionTrainer : GamTrainerBase<GamRegressionTrainer.Options, RegressionPredictionTransformer<RegressionGamModelParameters>, RegressionGamModelParameters>
|
||||
public sealed class GamRegressionTrainer : GamTrainerBase<GamRegressionTrainer.Options, RegressionPredictionTransformer<GamRegressionModelParameters>, GamRegressionModelParameters>
|
||||
{
|
||||
/// <summary>
|
||||
/// Options for the <see cref="GamRegressionTrainer"/>.
|
||||
|
@ -55,7 +55,7 @@ namespace Microsoft.ML.Trainers.FastTree
|
|||
: base(env, options, LoadNameValue, TrainerUtils.MakeR4ScalarColumn(options.LabelColumnName)) { }
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of <see cref="FastTreeBinaryClassificationTrainer"/>
|
||||
/// Initializes a new instance of <see cref="FastTreeBinaryTrainer"/>
|
||||
/// </summary>
|
||||
/// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
|
||||
/// <param name="labelColumnName">The name of the label column.</param>
|
||||
|
@ -80,10 +80,10 @@ namespace Microsoft.ML.Trainers.FastTree
|
|||
data.CheckRegressionLabel();
|
||||
}
|
||||
|
||||
private protected override RegressionGamModelParameters TrainModelCore(TrainContext context)
|
||||
private protected override GamRegressionModelParameters TrainModelCore(TrainContext context)
|
||||
{
|
||||
TrainBase(context);
|
||||
return new RegressionGamModelParameters(Host, BinUpperBounds, BinEffects, MeanEffect, InputLength, FeatureMap);
|
||||
return new GamRegressionModelParameters(Host, BinUpperBounds, BinEffects, MeanEffect, InputLength, FeatureMap);
|
||||
}
|
||||
|
||||
private protected override ObjectiveFunctionBase CreateObjectiveFunction()
|
||||
|
@ -99,14 +99,14 @@ namespace Microsoft.ML.Trainers.FastTree
|
|||
PruningTest = new TestHistory(validTest, PruningLossIndex);
|
||||
}
|
||||
|
||||
private protected override RegressionPredictionTransformer<RegressionGamModelParameters> MakeTransformer(RegressionGamModelParameters model, DataViewSchema trainSchema)
|
||||
=> new RegressionPredictionTransformer<RegressionGamModelParameters>(Host, model, trainSchema, FeatureColumn.Name);
|
||||
private protected override RegressionPredictionTransformer<GamRegressionModelParameters> MakeTransformer(GamRegressionModelParameters model, DataViewSchema trainSchema)
|
||||
=> new RegressionPredictionTransformer<GamRegressionModelParameters>(Host, model, trainSchema, FeatureColumn.Name);
|
||||
|
||||
/// <summary>
|
||||
/// Trains a <see cref="GamRegressionTrainer"/> using both training and validation data, returns
|
||||
/// a <see cref="RegressionPredictionTransformer{RegressionGamModelParameters}"/>.
|
||||
/// </summary>
|
||||
public RegressionPredictionTransformer<RegressionGamModelParameters> Fit(IDataView trainData, IDataView validationData)
|
||||
public RegressionPredictionTransformer<GamRegressionModelParameters> Fit(IDataView trainData, IDataView validationData)
|
||||
=> TrainTransformer(trainData, validationData);
|
||||
|
||||
private protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
|
||||
|
@ -121,7 +121,7 @@ namespace Microsoft.ML.Trainers.FastTree
|
|||
/// <summary>
|
||||
/// The model parameters class for Binary Classification GAMs
|
||||
/// </summary>
|
||||
public sealed class RegressionGamModelParameters : GamModelParametersBase
|
||||
public sealed class GamRegressionModelParameters : GamModelParametersBase
|
||||
{
|
||||
internal const string LoaderSignature = "RegressionGamPredictor";
|
||||
private protected override PredictionKind PredictionKind => PredictionKind.Regression;
|
||||
|
@ -138,11 +138,11 @@ namespace Microsoft.ML.Trainers.FastTree
|
|||
/// <param name="featureToInputMap">A map from the feature shape functions (as described by the binUpperBounds and BinEffects)
|
||||
/// to the input feature. Used when the number of input features is different than the number of shape functions. Use default if all features have
|
||||
/// a shape function.</param>
|
||||
internal RegressionGamModelParameters(IHostEnvironment env,
|
||||
internal GamRegressionModelParameters(IHostEnvironment env,
|
||||
double[][] binUpperBounds, double[][] binEffects, double intercept, int inputLength = -1, int[] featureToInputMap = null)
|
||||
: base(env, LoaderSignature, binUpperBounds, binEffects, intercept, inputLength, featureToInputMap) { }
|
||||
|
||||
private RegressionGamModelParameters(IHostEnvironment env, ModelLoadContext ctx)
|
||||
private GamRegressionModelParameters(IHostEnvironment env, ModelLoadContext ctx)
|
||||
: base(env, LoaderSignature, ctx) { }
|
||||
|
||||
private static VersionInfo GetVersionInfo()
|
||||
|
@ -155,16 +155,16 @@ namespace Microsoft.ML.Trainers.FastTree
|
|||
verReadableCur: 0x00010002,
|
||||
verWeCanReadBack: 0x00010001,
|
||||
loaderSignature: LoaderSignature,
|
||||
loaderAssemblyName: typeof(RegressionGamModelParameters).Assembly.FullName);
|
||||
loaderAssemblyName: typeof(GamRegressionModelParameters).Assembly.FullName);
|
||||
}
|
||||
|
||||
private static RegressionGamModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
|
||||
private static GamRegressionModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
|
||||
{
|
||||
Contracts.CheckValue(env, nameof(env));
|
||||
env.CheckValue(ctx, nameof(ctx));
|
||||
ctx.CheckAtModel(GetVersionInfo());
|
||||
|
||||
return new RegressionGamModelParameters(env, ctx);
|
||||
return new GamRegressionModelParameters(env, ctx);
|
||||
}
|
||||
|
||||
private protected override void SaveCore(ModelSaveContext ctx)
|
||||
|
|
|
@ -696,16 +696,16 @@ namespace Microsoft.ML.Trainers.FastTree
|
|||
() => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.ExampleWeightColumnName));
|
||||
}
|
||||
|
||||
[TlcModule.EntryPoint(Name = "Trainers.GeneralizedAdditiveModelBinaryClassifier", Desc = GamBinaryClassificationTrainer.Summary, UserName = GamBinaryClassificationTrainer.UserNameValue, ShortName = GamBinaryClassificationTrainer.ShortName)]
|
||||
public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, GamBinaryClassificationTrainer.Options input)
|
||||
[TlcModule.EntryPoint(Name = "Trainers.GeneralizedAdditiveModelBinaryClassifier", Desc = GamBinaryTrainer.Summary, UserName = GamBinaryTrainer.UserNameValue, ShortName = GamBinaryTrainer.ShortName)]
|
||||
public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, GamBinaryTrainer.Options input)
|
||||
{
|
||||
Contracts.CheckValue(env, nameof(env));
|
||||
var host = env.Register("TrainGAM");
|
||||
host.CheckValue(input, nameof(input));
|
||||
EntryPointUtils.CheckInputArgs(host, input);
|
||||
|
||||
return TrainerEntryPointsUtils.Train<GamBinaryClassificationTrainer.Options, CommonOutputs.BinaryClassificationOutput>(host, input,
|
||||
() => new GamBinaryClassificationTrainer(host, input),
|
||||
return TrainerEntryPointsUtils.Train<GamBinaryTrainer.Options, CommonOutputs.BinaryClassificationOutput>(host, input,
|
||||
() => new GamBinaryTrainer(host, input),
|
||||
() => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumnName),
|
||||
() => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.ExampleWeightColumnName));
|
||||
}
|
||||
|
|
|
@ -13,17 +13,17 @@ using Microsoft.ML.Model;
|
|||
using Microsoft.ML.Runtime;
|
||||
using Microsoft.ML.Trainers.FastTree;
|
||||
|
||||
[assembly: LoadableClass(FastForestBinaryClassificationTrainer.Summary, typeof(FastForestBinaryClassificationTrainer), typeof(FastForestBinaryClassificationTrainer.Options),
|
||||
[assembly: LoadableClass(FastForestBinaryTrainer.Summary, typeof(FastForestBinaryTrainer), typeof(FastForestBinaryTrainer.Options),
|
||||
new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureTreeEnsembleTrainer), typeof(SignatureFeatureScorerTrainer) },
|
||||
FastForestBinaryClassificationTrainer.UserNameValue,
|
||||
FastForestBinaryClassificationTrainer.LoadNameValue,
|
||||
FastForestBinaryTrainer.UserNameValue,
|
||||
FastForestBinaryTrainer.LoadNameValue,
|
||||
"FastForest",
|
||||
FastForestBinaryClassificationTrainer.ShortName,
|
||||
FastForestBinaryTrainer.ShortName,
|
||||
"ffc")]
|
||||
|
||||
[assembly: LoadableClass(typeof(IPredictorProducing<float>), typeof(FastForestClassificationModelParameters), null, typeof(SignatureLoadModel),
|
||||
[assembly: LoadableClass(typeof(IPredictorProducing<float>), typeof(FastForestBinaryModelParameters), null, typeof(SignatureLoadModel),
|
||||
"FastForest Binary Executor",
|
||||
FastForestClassificationModelParameters.LoaderSignature)]
|
||||
FastForestBinaryModelParameters.LoaderSignature)]
|
||||
|
||||
[assembly: LoadableClass(typeof(void), typeof(FastForest), null, typeof(SignatureEntryPointModule), "FastForest")]
|
||||
|
||||
|
@ -48,7 +48,7 @@ namespace Microsoft.ML.Trainers.FastTree
|
|||
}
|
||||
}
|
||||
|
||||
public sealed class FastForestClassificationModelParameters :
|
||||
public sealed class FastForestBinaryModelParameters :
|
||||
TreeEnsembleModelParametersBasedOnQuantileRegressionTree
|
||||
{
|
||||
internal const string LoaderSignature = "FastForestBinaryExec";
|
||||
|
@ -67,7 +67,7 @@ namespace Microsoft.ML.Trainers.FastTree
|
|||
verReadableCur: 0x00010005,
|
||||
verWeCanReadBack: 0x00010001,
|
||||
loaderSignature: LoaderSignature,
|
||||
loaderAssemblyName: typeof(FastForestClassificationModelParameters).Assembly.FullName);
|
||||
loaderAssemblyName: typeof(FastForestBinaryModelParameters).Assembly.FullName);
|
||||
}
|
||||
|
||||
private protected override uint VerNumFeaturesSerialized => 0x00010003;
|
||||
|
@ -81,11 +81,11 @@ namespace Microsoft.ML.Trainers.FastTree
|
|||
/// </summary>
|
||||
private protected override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
|
||||
|
||||
internal FastForestClassificationModelParameters(IHostEnvironment env, InternalTreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
|
||||
internal FastForestBinaryModelParameters(IHostEnvironment env, InternalTreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
|
||||
: base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
|
||||
{ }
|
||||
|
||||
private FastForestClassificationModelParameters(IHostEnvironment env, ModelLoadContext ctx)
|
||||
private FastForestBinaryModelParameters(IHostEnvironment env, ModelLoadContext ctx)
|
||||
: base(env, RegistrationName, ctx, GetVersionInfo())
|
||||
{
|
||||
}
|
||||
|
@ -101,12 +101,12 @@ namespace Microsoft.ML.Trainers.FastTree
|
|||
Contracts.CheckValue(env, nameof(env));
|
||||
env.CheckValue(ctx, nameof(ctx));
|
||||
ctx.CheckAtModel(GetVersionInfo());
|
||||
var predictor = new FastForestClassificationModelParameters(env, ctx);
|
||||
var predictor = new FastForestBinaryModelParameters(env, ctx);
|
||||
ICalibrator calibrator;
|
||||
ctx.LoadModelOrNull<ICalibrator, SignatureLoadModel>(env, out calibrator, @"Calibrator");
|
||||
if (calibrator == null)
|
||||
return predictor;
|
||||
return new SchemaBindableCalibratedModelParameters<FastForestClassificationModelParameters, ICalibrator>(env, predictor, calibrator);
|
||||
return new SchemaBindableCalibratedModelParameters<FastForestBinaryModelParameters, ICalibrator>(env, predictor, calibrator);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -114,11 +114,11 @@ namespace Microsoft.ML.Trainers.FastTree
|
|||
/// The <see cref="IEstimator{TTransformer}"/> for training a decision tree binary classification model using Fast Forest.
|
||||
/// </summary>
|
||||
/// <include file='doc.xml' path='doc/members/member[@name="FastForest_remarks"]/*' />
|
||||
public sealed partial class FastForestBinaryClassificationTrainer :
|
||||
RandomForestTrainerBase<FastForestBinaryClassificationTrainer.Options, BinaryPredictionTransformer<FastForestClassificationModelParameters>, FastForestClassificationModelParameters>
|
||||
public sealed partial class FastForestBinaryTrainer :
|
||||
RandomForestTrainerBase<FastForestBinaryTrainer.Options, BinaryPredictionTransformer<FastForestBinaryModelParameters>, FastForestBinaryModelParameters>
|
||||
{
|
||||
/// <summary>
|
||||
/// Options for the <see cref="FastForestBinaryClassificationTrainer"/>.
|
||||
/// Options for the <see cref="FastForestBinaryTrainer"/>.
|
||||
/// </summary>
|
||||
public sealed class Options : FastForestOptionsBase
|
||||
{
|
||||
|
@ -146,7 +146,7 @@ namespace Microsoft.ML.Trainers.FastTree
|
|||
private protected override bool NeedCalibration => true;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of <see cref="FastForestBinaryClassificationTrainer"/>
|
||||
/// Initializes a new instance of <see cref="FastForestBinaryTrainer"/>
|
||||
/// </summary>
|
||||
/// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
|
||||
/// <param name="labelColumnName">The name of the label column.</param>
|
||||
|
@ -155,7 +155,7 @@ namespace Microsoft.ML.Trainers.FastTree
|
|||
/// <param name="numberOfLeaves">The max number of leaves in each regression tree.</param>
|
||||
/// <param name="numberOfTrees">Total number of decision trees to create in the ensemble.</param>
|
||||
/// <param name="minimumExampleCountPerLeaf">The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data.</param>
|
||||
internal FastForestBinaryClassificationTrainer(IHostEnvironment env,
|
||||
internal FastForestBinaryTrainer(IHostEnvironment env,
|
||||
string labelColumnName = DefaultColumnNames.Label,
|
||||
string featureColumnName = DefaultColumnNames.Features,
|
||||
string exampleWeightColumnName = null,
|
||||
|
@ -169,16 +169,16 @@ namespace Microsoft.ML.Trainers.FastTree
|
|||
}
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of <see cref="FastForestBinaryClassificationTrainer"/> by using the <see cref="Options"/> class.
|
||||
/// Initializes a new instance of <see cref="FastForestBinaryTrainer"/> by using the <see cref="Options"/> class.
|
||||
/// </summary>
|
||||
/// <param name="env">The instance of <see cref="IHostEnvironment"/>.</param>
|
||||
/// <param name="options">Algorithm advanced settings.</param>
|
||||
internal FastForestBinaryClassificationTrainer(IHostEnvironment env, Options options)
|
||||
internal FastForestBinaryTrainer(IHostEnvironment env, Options options)
|
||||
: base(env, options, TrainerUtils.MakeBoolScalarLabel(options.LabelColumnName))
|
||||
{
|
||||
}
|
||||
|
||||
private protected override FastForestClassificationModelParameters TrainModelCore(TrainContext context)
|
||||
private protected override FastForestBinaryModelParameters TrainModelCore(TrainContext context)
|
||||
{
|
||||
Host.CheckValue(context, nameof(context));
|
||||
var trainData = context.TrainingSet;
|
||||
|
@ -201,7 +201,7 @@ namespace Microsoft.ML.Trainers.FastTree
|
|||
// calibrator, transform the scores using that.
|
||||
|
||||
// REVIEW: Need a way to signal the outside world that we prefer simple sigmoid?
|
||||
return new FastForestClassificationModelParameters(Host, TrainedEnsemble, FeatureCount, InnerOptions);
|
||||
return new FastForestBinaryModelParameters(Host, TrainedEnsemble, FeatureCount, InnerOptions);
|
||||
}
|
||||
|
||||
private protected override ObjectiveFunctionBase ConstructObjFunc(IChannel ch)
|
||||
|
@ -221,14 +221,14 @@ namespace Microsoft.ML.Trainers.FastTree
|
|||
return new BinaryClassificationTest(ConstructScoreTracker(TrainSet), _trainSetLabels, 1);
|
||||
}
|
||||
|
||||
private protected override BinaryPredictionTransformer<FastForestClassificationModelParameters> MakeTransformer(FastForestClassificationModelParameters model, DataViewSchema trainSchema)
|
||||
=> new BinaryPredictionTransformer<FastForestClassificationModelParameters>(Host, model, trainSchema, FeatureColumn.Name);
|
||||
private protected override BinaryPredictionTransformer<FastForestBinaryModelParameters> MakeTransformer(FastForestBinaryModelParameters model, DataViewSchema trainSchema)
|
||||
=> new BinaryPredictionTransformer<FastForestBinaryModelParameters>(Host, model, trainSchema, FeatureColumn.Name);
|
||||
|
||||
/// <summary>
|
||||
/// Trains a <see cref="FastForestBinaryClassificationTrainer"/> using both training and validation data, returns
|
||||
/// Trains a <see cref="FastForestBinaryTrainer"/> using both training and validation data, returns
|
||||
/// a <see cref="BinaryPredictionTransformer{FastForestClassificationModelParameters}"/>.
|
||||
/// </summary>
|
||||
public BinaryPredictionTransformer<FastForestClassificationModelParameters> Fit(IDataView trainData, IDataView validationData)
|
||||
public BinaryPredictionTransformer<FastForestBinaryModelParameters> Fit(IDataView trainData, IDataView validationData)
|
||||
=> TrainTransformer(trainData, validationData);
|
||||
|
||||
private protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
|
||||
|
@ -263,18 +263,18 @@ namespace Microsoft.ML.Trainers.FastTree
|
|||
internal static partial class FastForest
|
||||
{
|
||||
[TlcModule.EntryPoint(Name = "Trainers.FastForestBinaryClassifier",
|
||||
Desc = FastForestBinaryClassificationTrainer.Summary,
|
||||
UserName = FastForestBinaryClassificationTrainer.UserNameValue,
|
||||
ShortName = FastForestBinaryClassificationTrainer.ShortName)]
|
||||
public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, FastForestBinaryClassificationTrainer.Options input)
|
||||
Desc = FastForestBinaryTrainer.Summary,
|
||||
UserName = FastForestBinaryTrainer.UserNameValue,
|
||||
ShortName = FastForestBinaryTrainer.ShortName)]
|
||||
public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, FastForestBinaryTrainer.Options input)
|
||||
{
|
||||
Contracts.CheckValue(env, nameof(env));
|
||||
var host = env.Register("TrainFastForest");
|
||||
host.CheckValue(input, nameof(input));
|
||||
EntryPointUtils.CheckInputArgs(host, input);
|
||||
|
||||
return TrainerEntryPointsUtils.Train<FastForestBinaryClassificationTrainer.Options, CommonOutputs.BinaryClassificationOutput>(host, input,
|
||||
() => new FastForestBinaryClassificationTrainer(host, input),
|
||||
return TrainerEntryPointsUtils.Train<FastForestBinaryTrainer.Options, CommonOutputs.BinaryClassificationOutput>(host, input,
|
||||
() => new FastForestBinaryTrainer(host, input),
|
||||
() => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumnName),
|
||||
() => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.ExampleWeightColumnName),
|
||||
() => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.RowGroupColumnName),
|
||||
|
|
|
@ -61,7 +61,7 @@ namespace Microsoft.ML
|
|||
}
|
||||
|
||||
/// <summary>
|
||||
/// Predict a target using a decision tree binary classification model trained with the <see cref="FastTreeBinaryClassificationTrainer"/>.
|
||||
/// Predict a target using a decision tree binary classification model trained with the <see cref="FastTreeBinaryTrainer"/>.
|
||||
/// </summary>
|
||||
/// <param name="catalog">The <see cref="BinaryClassificationCatalog"/>.</param>
|
||||
/// <param name="labelColumnName">The name of the label column.</param>
|
||||
|
@ -71,7 +71,7 @@ namespace Microsoft.ML
|
|||
/// <param name="numberOfLeaves">The maximum number of leaves per decision tree.</param>
|
||||
/// <param name="minimumExampleCountPerLeaf">The minimal number of data points required to form a new tree leaf.</param>
|
||||
/// <param name="learningRate">The learning rate.</param>
|
||||
public static FastTreeBinaryClassificationTrainer FastTree(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
|
||||
public static FastTreeBinaryTrainer FastTree(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
|
||||
string labelColumnName = DefaultColumnNames.Label,
|
||||
string featureColumnName = DefaultColumnNames.Features,
|
||||
string exampleWeightColumnName = null,
|
||||
|
@ -82,22 +82,22 @@ namespace Microsoft.ML
|
|||
{
|
||||
Contracts.CheckValue(catalog, nameof(catalog));
|
||||
var env = CatalogUtils.GetEnvironment(catalog);
|
||||
return new FastTreeBinaryClassificationTrainer(env, labelColumnName, featureColumnName, exampleWeightColumnName, numberOfLeaves, numberOfTrees, minimumExampleCountPerLeaf, learningRate);
|
||||
return new FastTreeBinaryTrainer(env, labelColumnName, featureColumnName, exampleWeightColumnName, numberOfLeaves, numberOfTrees, minimumExampleCountPerLeaf, learningRate);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Predict a target using a decision tree binary classification model trained with the <see cref="FastTreeBinaryClassificationTrainer"/> and advanced options.
|
||||
/// Predict a target using a decision tree binary classification model trained with the <see cref="FastTreeBinaryTrainer"/> and advanced options.
|
||||
/// </summary>
|
||||
/// <param name="catalog">The <see cref="BinaryClassificationCatalog"/>.</param>
|
||||
/// <param name="options">Trainer options.</param>
|
||||
public static FastTreeBinaryClassificationTrainer FastTree(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
|
||||
FastTreeBinaryClassificationTrainer.Options options)
|
||||
public static FastTreeBinaryTrainer FastTree(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
|
||||
FastTreeBinaryTrainer.Options options)
|
||||
{
|
||||
Contracts.CheckValue(catalog, nameof(catalog));
|
||||
Contracts.CheckValue(options, nameof(options));
|
||||
|
||||
var env = CatalogUtils.GetEnvironment(catalog);
|
||||
return new FastTreeBinaryClassificationTrainer(env, options);
|
||||
return new FastTreeBinaryTrainer(env, options);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -143,7 +143,7 @@ namespace Microsoft.ML
|
|||
}
|
||||
|
||||
/// <summary>
|
||||
/// Predict a target using generalized additive models (GAM) trained with the <see cref="GamBinaryClassificationTrainer"/>.
|
||||
/// Predict a target using generalized additive models (GAM) trained with the <see cref="GamBinaryTrainer"/>.
|
||||
/// </summary>
|
||||
/// <param name="catalog">The <see cref="BinaryClassificationCatalog"/>.</param>
|
||||
/// <param name="labelColumnName">The name of the label column.</param>
|
||||
|
@ -152,7 +152,7 @@ namespace Microsoft.ML
|
|||
/// <param name="numberOfIterations">The number of iterations to use in learning the features.</param>
|
||||
/// <param name="maximumBinCountPerFeature">The maximum number of bins to use to approximate features.</param>
|
||||
/// <param name="learningRate">The learning rate. GAMs work best with a small learning rate.</param>
|
||||
public static GamBinaryClassificationTrainer Gam(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
|
||||
public static GamBinaryTrainer Gam(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
|
||||
string labelColumnName = DefaultColumnNames.Label,
|
||||
string featureColumnName = DefaultColumnNames.Features,
|
||||
string exampleWeightColumnName = null,
|
||||
|
@ -162,20 +162,20 @@ namespace Microsoft.ML
|
|||
{
|
||||
Contracts.CheckValue(catalog, nameof(catalog));
|
||||
var env = CatalogUtils.GetEnvironment(catalog);
|
||||
return new GamBinaryClassificationTrainer(env, labelColumnName, featureColumnName, exampleWeightColumnName, numberOfIterations, learningRate, maximumBinCountPerFeature);
|
||||
return new GamBinaryTrainer(env, labelColumnName, featureColumnName, exampleWeightColumnName, numberOfIterations, learningRate, maximumBinCountPerFeature);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Predict a target using generalized additive models (GAM) trained with the <see cref="GamBinaryClassificationTrainer"/>.
|
||||
/// Predict a target using generalized additive models (GAM) trained with the <see cref="GamBinaryTrainer"/>.
|
||||
/// </summary>
|
||||
/// <param name="catalog">The <see cref="BinaryClassificationCatalog"/>.</param>
|
||||
/// <param name="options">Trainer options.</param>
|
||||
public static GamBinaryClassificationTrainer Gam(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
|
||||
GamBinaryClassificationTrainer.Options options)
|
||||
public static GamBinaryTrainer Gam(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
|
||||
GamBinaryTrainer.Options options)
|
||||
{
|
||||
Contracts.CheckValue(catalog, nameof(catalog));
|
||||
var env = CatalogUtils.GetEnvironment(catalog);
|
||||
return new GamBinaryClassificationTrainer(env, options);
|
||||
return new GamBinaryTrainer(env, options);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -293,7 +293,7 @@ namespace Microsoft.ML
|
|||
}
|
||||
|
||||
/// <summary>
|
||||
/// Predict a target using a decision tree regression model trained with the <see cref="FastForestBinaryClassificationTrainer"/>.
|
||||
/// Predict a target using a decision tree regression model trained with the <see cref="FastForestBinaryTrainer"/>.
|
||||
/// </summary>
|
||||
/// <param name="catalog">The <see cref="BinaryClassificationCatalog"/>.</param>
|
||||
/// <param name="labelColumnName">The name of the label column.</param>
|
||||
|
@ -302,7 +302,7 @@ namespace Microsoft.ML
|
|||
/// <param name="numberOfTrees">Total number of decision trees to create in the ensemble.</param>
|
||||
/// <param name="numberOfLeaves">The maximum number of leaves per decision tree.</param>
|
||||
/// <param name="minDatapointsInLeaves">The minimal number of data points required to form a new tree leaf.</param>
|
||||
public static FastForestBinaryClassificationTrainer FastForest(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
|
||||
public static FastForestBinaryTrainer FastForest(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
|
||||
string labelColumnName = DefaultColumnNames.Label,
|
||||
string featureColumnName = DefaultColumnNames.Features,
|
||||
string exampleWeightColumnName = null,
|
||||
|
@ -312,22 +312,22 @@ namespace Microsoft.ML
|
|||
{
|
||||
Contracts.CheckValue(catalog, nameof(catalog));
|
||||
var env = CatalogUtils.GetEnvironment(catalog);
|
||||
return new FastForestBinaryClassificationTrainer(env, labelColumnName, featureColumnName, exampleWeightColumnName, numberOfLeaves, numberOfTrees, minDatapointsInLeaves);
|
||||
return new FastForestBinaryTrainer(env, labelColumnName, featureColumnName, exampleWeightColumnName, numberOfLeaves, numberOfTrees, minDatapointsInLeaves);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Predict a target using a decision tree regression model trained with the <see cref="FastForestBinaryClassificationTrainer"/> and advanced options.
|
||||
/// Predict a target using a decision tree regression model trained with the <see cref="FastForestBinaryTrainer"/> and advanced options.
|
||||
/// </summary>
|
||||
/// <param name="catalog">The <see cref="BinaryClassificationCatalog"/>.</param>
|
||||
/// <param name="options">Trainer options.</param>
|
||||
public static FastForestBinaryClassificationTrainer FastForest(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
|
||||
FastForestBinaryClassificationTrainer.Options options)
|
||||
public static FastForestBinaryTrainer FastForest(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
|
||||
FastForestBinaryTrainer.Options options)
|
||||
{
|
||||
Contracts.CheckValue(catalog, nameof(catalog));
|
||||
Contracts.CheckValue(options, nameof(options));
|
||||
|
||||
var env = CatalogUtils.GetEnvironment(catalog);
|
||||
return new FastForestBinaryClassificationTrainer(env, options);
|
||||
return new FastForestBinaryTrainer(env, options);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -98,7 +98,7 @@ namespace Microsoft.ML.Trainers.LightGbm.StaticPipe
|
|||
}
|
||||
|
||||
/// <summary>
|
||||
/// Predict a target using a tree binary classification model trained with the <see cref="LightGbmBinaryClassificationTrainer"/>.
|
||||
/// Predict a target using a tree binary classification model trained with the <see cref="LightGbmBinaryTrainer"/>.
|
||||
/// </summary>
|
||||
/// <param name="catalog">The <see cref="BinaryClassificationCatalog"/>.</param>
|
||||
/// <param name="label">The label column.</param>
|
||||
|
@ -136,7 +136,7 @@ namespace Microsoft.ML.Trainers.LightGbm.StaticPipe
|
|||
var rec = new TrainerEstimatorReconciler.BinaryClassifier(
|
||||
(env, labelName, featuresName, weightsName) =>
|
||||
{
|
||||
var trainer = new LightGbmBinaryClassificationTrainer(env, labelName, featuresName, weightsName, numberOfLeaves,
|
||||
var trainer = new LightGbmBinaryTrainer(env, labelName, featuresName, weightsName, numberOfLeaves,
|
||||
minimumExampleCountPerLeaf, learningRate, numberOfIterations);
|
||||
|
||||
if (onFit != null)
|
||||
|
@ -149,7 +149,7 @@ namespace Microsoft.ML.Trainers.LightGbm.StaticPipe
|
|||
}
|
||||
|
||||
/// <summary>
|
||||
/// Predict a target using a tree binary classification model trained with the <see cref="LightGbmBinaryClassificationTrainer"/>.
|
||||
/// Predict a target using a tree binary classification model trained with the <see cref="LightGbmBinaryTrainer"/>.
|
||||
/// </summary>
|
||||
/// <param name="catalog">The <see cref="BinaryClassificationCatalog"/>.</param>
|
||||
/// <param name="label">The label column.</param>
|
||||
|
@ -177,7 +177,7 @@ namespace Microsoft.ML.Trainers.LightGbm.StaticPipe
|
|||
options.FeatureColumnName = featuresName;
|
||||
options.ExampleWeightColumnName = weightsName;
|
||||
|
||||
var trainer = new LightGbmBinaryClassificationTrainer(env, options);
|
||||
var trainer = new LightGbmBinaryTrainer(env, options);
|
||||
|
||||
if (onFit != null)
|
||||
return trainer.WithOnFitDelegate(trans => onFit(trans.Model));
|
||||
|
@ -278,7 +278,7 @@ namespace Microsoft.ML.Trainers.LightGbm.StaticPipe
|
|||
}
|
||||
|
||||
/// <summary>
|
||||
/// Predict a target using a tree multiclass classification model trained with the <see cref="LightGbmMulticlassClassificationTrainer"/>.
|
||||
/// Predict a target using a tree multiclass classification model trained with the <see cref="LightGbmMulticlassTrainer"/>.
|
||||
/// </summary>
|
||||
/// <param name="catalog">The multiclass classification catalog trainer object.</param>
|
||||
/// <param name="label">The label, or dependent variable.</param>
|
||||
|
@ -317,7 +317,7 @@ namespace Microsoft.ML.Trainers.LightGbm.StaticPipe
|
|||
var rec = new TrainerEstimatorReconciler.MulticlassClassificationReconciler<TVal>(
|
||||
(env, labelName, featuresName, weightsName) =>
|
||||
{
|
||||
var trainer = new LightGbmMulticlassClassificationTrainer(env, labelName, featuresName, weightsName, numberOfLeaves,
|
||||
var trainer = new LightGbmMulticlassTrainer(env, labelName, featuresName, weightsName, numberOfLeaves,
|
||||
minimumExampleCountPerLeaf, learningRate, numberOfIterations);
|
||||
|
||||
if (onFit != null)
|
||||
|
@ -329,7 +329,7 @@ namespace Microsoft.ML.Trainers.LightGbm.StaticPipe
|
|||
}
|
||||
|
||||
/// <summary>
|
||||
/// Predict a target using a tree multiclass classification model trained with the <see cref="LightGbmMulticlassClassificationTrainer"/>.
|
||||
/// Predict a target using a tree multiclass classification model trained with the <see cref="LightGbmMulticlassTrainer"/>.
|
||||
/// </summary>
|
||||
/// <param name="catalog">The multiclass classification catalog trainer object.</param>
|
||||
/// <param name="label">The label, or dependent variable.</param>
|
||||
|
@ -359,7 +359,7 @@ namespace Microsoft.ML.Trainers.LightGbm.StaticPipe
|
|||
options.FeatureColumnName = featuresName;
|
||||
options.ExampleWeightColumnName = weightsName;
|
||||
|
||||
var trainer = new LightGbmMulticlassClassificationTrainer(env, options);
|
||||
var trainer = new LightGbmMulticlassTrainer(env, options);
|
||||
|
||||
if (onFit != null)
|
||||
return trainer.WithOnFitDelegate(trans => onFit(trans.Model));
|
||||
|
|
|
@ -150,7 +150,7 @@ namespace Microsoft.ML.Trainers.LightGbm
|
|||
public class Options : ISupportBoosterParameterFactory
|
||||
{
|
||||
/// <summary>
|
||||
/// Whether training data is unbalanced. Used by <see cref="LightGbmBinaryClassificationTrainer"/>.
|
||||
/// Whether training data is unbalanced. Used by <see cref="LightGbmBinaryTrainer"/>.
|
||||
/// </summary>
|
||||
[Argument(ArgumentType.AtMostOnce, HelpText = "Use for binary classification when training data is not balanced.", ShortName = "us")]
|
||||
public bool UnbalancedSets = false;
|
||||
|
@ -263,7 +263,7 @@ namespace Microsoft.ML.Trainers.LightGbm
|
|||
public double L1Regularization = 0;
|
||||
|
||||
/// <summary>
|
||||
/// Controls the balance of positive and negative weights in <see cref="LightGbmBinaryClassificationTrainer"/>.
|
||||
/// Controls the balance of positive and negative weights in <see cref="LightGbmBinaryTrainer"/>.
|
||||
/// </summary>
|
||||
/// <value>
|
||||
/// This is useful for training on unbalanced data. A typical value to consider is sum(negative cases) / sum(positive cases).
|
||||
|
@ -518,7 +518,7 @@ namespace Microsoft.ML.Trainers.LightGbm
|
|||
public EvalMetricType EvaluationMetric = EvalMetricType.DefaultMetric;
|
||||
|
||||
/// <summary>
|
||||
/// Whether to use softmax loss. Used only by <see cref="LightGbmMulticlassClassificationTrainer"/>.
|
||||
/// Whether to use softmax loss. Used only by <see cref="LightGbmMulticlassTrainer"/>.
|
||||
/// </summary>
|
||||
[Argument(ArgumentType.AtMostOnce, HelpText = "Use softmax loss for the multi classification.")]
|
||||
[TlcModule.SweepableDiscreteParam("UseSoftmax", new object[] { true, false })]
|
||||
|
@ -542,9 +542,9 @@ namespace Microsoft.ML.Trainers.LightGbm
|
|||
public string CustomGains = "0,3,7,15,31,63,127,255,511,1023,2047,4095";
|
||||
|
||||
/// <summary>
|
||||
/// Parameter for the sigmoid function. Used only by <see cref="LightGbmBinaryClassificationTrainer"/>, <see cref="LightGbmMulticlassClassificationTrainer"/>, and <see cref="LightGbmRankingTrainer"/>.
|
||||
/// Parameter for the sigmoid function. Used only by <see cref="LightGbmBinaryTrainer"/>, <see cref="LightGbmMulticlassTrainer"/>, and <see cref="LightGbmRankingTrainer"/>.
|
||||
/// </summary>
|
||||
[Argument(ArgumentType.AtMostOnce, HelpText = "Parameter for the sigmoid function. Used only in " + nameof(LightGbmBinaryClassificationTrainer) + ", " + nameof(LightGbmMulticlassClassificationTrainer) +
|
||||
[Argument(ArgumentType.AtMostOnce, HelpText = "Parameter for the sigmoid function. Used only in " + nameof(LightGbmBinaryTrainer) + ", " + nameof(LightGbmMulticlassTrainer) +
|
||||
" and in " + nameof(LightGbmRankingTrainer) + ".", ShortName = "sigmoid")]
|
||||
[TGUI(Label = "Sigmoid", SuggestedSweeps = "0.5,1")]
|
||||
public double Sigmoid = 0.5;
|
||||
|
|
|
@ -10,9 +10,9 @@ using Microsoft.ML.Runtime;
|
|||
using Microsoft.ML.Trainers.FastTree;
|
||||
using Microsoft.ML.Trainers.LightGbm;
|
||||
|
||||
[assembly: LoadableClass(LightGbmBinaryClassificationTrainer.Summary, typeof(LightGbmBinaryClassificationTrainer), typeof(Options),
|
||||
[assembly: LoadableClass(LightGbmBinaryTrainer.Summary, typeof(LightGbmBinaryTrainer), typeof(Options),
|
||||
new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureTreeEnsembleTrainer) },
|
||||
LightGbmBinaryClassificationTrainer.UserName, LightGbmBinaryClassificationTrainer.LoadNameValue, LightGbmBinaryClassificationTrainer.ShortName, DocName = "trainer/LightGBM.md")]
|
||||
LightGbmBinaryTrainer.UserName, LightGbmBinaryTrainer.LoadNameValue, LightGbmBinaryTrainer.ShortName, DocName = "trainer/LightGBM.md")]
|
||||
|
||||
[assembly: LoadableClass(typeof(IPredictorProducing<float>), typeof(LightGbmBinaryModelParameters), null, typeof(SignatureLoadModel),
|
||||
"LightGBM Binary Executor",
|
||||
|
@ -82,7 +82,7 @@ namespace Microsoft.ML.Trainers.LightGbm
|
|||
/// The <see cref="IEstimator{TTransformer}"/> for training a boosted decision tree binary classification model using LightGBM.
|
||||
/// </summary>
|
||||
/// <include file='doc.xml' path='doc/members/member[@name="LightGBM_remarks"]/*' />
|
||||
public sealed class LightGbmBinaryClassificationTrainer : LightGbmTrainerBase<float,
|
||||
public sealed class LightGbmBinaryTrainer : LightGbmTrainerBase<float,
|
||||
BinaryPredictionTransformer<CalibratedModelParametersBase<LightGbmBinaryModelParameters, PlattCalibrator>>,
|
||||
CalibratedModelParametersBase<LightGbmBinaryModelParameters, PlattCalibrator>>
|
||||
{
|
||||
|
@ -93,13 +93,13 @@ namespace Microsoft.ML.Trainers.LightGbm
|
|||
|
||||
private protected override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
|
||||
|
||||
internal LightGbmBinaryClassificationTrainer(IHostEnvironment env, Options options)
|
||||
internal LightGbmBinaryTrainer(IHostEnvironment env, Options options)
|
||||
: base(env, LoadNameValue, options, TrainerUtils.MakeBoolScalarLabel(options.LabelColumnName))
|
||||
{
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of <see cref="LightGbmBinaryClassificationTrainer"/>
|
||||
/// Initializes a new instance of <see cref="LightGbmBinaryTrainer"/>
|
||||
/// </summary>
|
||||
/// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
|
||||
/// <param name="labelColumnName">The name of The label column.</param>
|
||||
|
@ -109,7 +109,7 @@ namespace Microsoft.ML.Trainers.LightGbm
|
|||
/// <param name="minimumExampleCountPerLeaf">The minimal number of data points allowed in a leaf of the tree, out of the subsampled data.</param>
|
||||
/// <param name="learningRate">The learning rate.</param>
|
||||
/// <param name="numberOfIterations">Number of iterations.</param>
|
||||
internal LightGbmBinaryClassificationTrainer(IHostEnvironment env,
|
||||
internal LightGbmBinaryTrainer(IHostEnvironment env,
|
||||
string labelColumnName = DefaultColumnNames.Label,
|
||||
string featureColumnName = DefaultColumnNames.Features,
|
||||
string exampleWeightColumnName = null,
|
||||
|
@ -165,7 +165,7 @@ namespace Microsoft.ML.Trainers.LightGbm
|
|||
=> new BinaryPredictionTransformer<CalibratedModelParametersBase<LightGbmBinaryModelParameters, PlattCalibrator>>(Host, model, trainSchema, FeatureColumn.Name);
|
||||
|
||||
/// <summary>
|
||||
/// Trains a <see cref="LightGbmBinaryClassificationTrainer"/> using both training and validation data, returns
|
||||
/// Trains a <see cref="LightGbmBinaryTrainer"/> using both training and validation data, returns
|
||||
/// a <see cref="BinaryPredictionTransformer{CalibratedModelParametersBase}"/>.
|
||||
/// </summary>
|
||||
public BinaryPredictionTransformer<CalibratedModelParametersBase<LightGbmBinaryModelParameters, PlattCalibrator>> Fit(IDataView trainData, IDataView validationData)
|
||||
|
@ -179,9 +179,9 @@ namespace Microsoft.ML.Trainers.LightGbm
|
|||
{
|
||||
[TlcModule.EntryPoint(
|
||||
Name = "Trainers.LightGbmBinaryClassifier",
|
||||
Desc = LightGbmBinaryClassificationTrainer.Summary,
|
||||
UserName = LightGbmBinaryClassificationTrainer.UserName,
|
||||
ShortName = LightGbmBinaryClassificationTrainer.ShortName)]
|
||||
Desc = LightGbmBinaryTrainer.Summary,
|
||||
UserName = LightGbmBinaryTrainer.UserName,
|
||||
ShortName = LightGbmBinaryTrainer.ShortName)]
|
||||
public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, Options input)
|
||||
{
|
||||
Contracts.CheckValue(env, nameof(env));
|
||||
|
@ -190,7 +190,7 @@ namespace Microsoft.ML.Trainers.LightGbm
|
|||
EntryPointUtils.CheckInputArgs(host, input);
|
||||
|
||||
return TrainerEntryPointsUtils.Train<Options, CommonOutputs.BinaryClassificationOutput>(host, input,
|
||||
() => new LightGbmBinaryClassificationTrainer(host, input),
|
||||
() => new LightGbmBinaryTrainer(host, input),
|
||||
getLabel: () => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumnName),
|
||||
getWeight: () => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.ExampleWeightColumnName));
|
||||
}
|
||||
|
|
|
@ -66,7 +66,7 @@ namespace Microsoft.ML
|
|||
}
|
||||
|
||||
/// <summary>
|
||||
/// Predict a target using a gradient boosting decision tree binary classification model trained with the <see cref="LightGbmBinaryClassificationTrainer"/>.
|
||||
/// Predict a target using a gradient boosting decision tree binary classification model trained with the <see cref="LightGbmBinaryTrainer"/>.
|
||||
/// </summary>
|
||||
/// <param name="catalog">The <see cref="BinaryClassificationCatalog"/>.</param>
|
||||
/// <param name="labelColumnName">The name of the label column.</param>
|
||||
|
@ -83,7 +83,7 @@ namespace Microsoft.ML
|
|||
/// ]]>
|
||||
/// </format>
|
||||
/// </example>
|
||||
public static LightGbmBinaryClassificationTrainer LightGbm(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
|
||||
public static LightGbmBinaryTrainer LightGbm(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
|
||||
string labelColumnName = DefaultColumnNames.Label,
|
||||
string featureColumnName = DefaultColumnNames.Features,
|
||||
string exampleWeightColumnName = null,
|
||||
|
@ -94,11 +94,11 @@ namespace Microsoft.ML
|
|||
{
|
||||
Contracts.CheckValue(catalog, nameof(catalog));
|
||||
var env = CatalogUtils.GetEnvironment(catalog);
|
||||
return new LightGbmBinaryClassificationTrainer(env, labelColumnName, featureColumnName, exampleWeightColumnName, numberOfLeaves, minimumExampleCountPerLeaf, learningRate, numberOfIterations);
|
||||
return new LightGbmBinaryTrainer(env, labelColumnName, featureColumnName, exampleWeightColumnName, numberOfLeaves, minimumExampleCountPerLeaf, learningRate, numberOfIterations);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Predict a target using a gradient boosting decision tree binary classification model trained with the <see cref="LightGbmBinaryClassificationTrainer"/> and advanced options.
|
||||
/// Predict a target using a gradient boosting decision tree binary classification model trained with the <see cref="LightGbmBinaryTrainer"/> and advanced options.
|
||||
/// </summary>
|
||||
/// <param name="catalog">The <see cref="BinaryClassificationCatalog"/>.</param>
|
||||
/// <param name="options">Trainer options.</param>
|
||||
|
@ -109,12 +109,12 @@ namespace Microsoft.ML
|
|||
/// ]]>
|
||||
/// </format>
|
||||
/// </example>
|
||||
public static LightGbmBinaryClassificationTrainer LightGbm(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
|
||||
public static LightGbmBinaryTrainer LightGbm(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
|
||||
Options options)
|
||||
{
|
||||
Contracts.CheckValue(catalog, nameof(catalog));
|
||||
var env = CatalogUtils.GetEnvironment(catalog);
|
||||
return new LightGbmBinaryClassificationTrainer(env, options);
|
||||
return new LightGbmBinaryTrainer(env, options);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -172,7 +172,7 @@ namespace Microsoft.ML
|
|||
}
|
||||
|
||||
/// <summary>
|
||||
/// Predict a target using a gradient boosting decision tree multiclass classification model trained with the <see cref="LightGbmMulticlassClassificationTrainer"/>.
|
||||
/// Predict a target using a gradient boosting decision tree multiclass classification model trained with the <see cref="LightGbmMulticlassTrainer"/>.
|
||||
/// </summary>
|
||||
/// <param name="catalog">The <see cref="MulticlassClassificationCatalog"/>.</param>
|
||||
/// <param name="labelColumnName">The name of the label column.</param>
|
||||
|
@ -189,7 +189,7 @@ namespace Microsoft.ML
|
|||
/// ]]>
|
||||
/// </format>
|
||||
/// </example>
|
||||
public static LightGbmMulticlassClassificationTrainer LightGbm(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
|
||||
public static LightGbmMulticlassTrainer LightGbm(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
|
||||
string labelColumnName = DefaultColumnNames.Label,
|
||||
string featureColumnName = DefaultColumnNames.Features,
|
||||
string exampleWeightColumnName = null,
|
||||
|
@ -200,11 +200,11 @@ namespace Microsoft.ML
|
|||
{
|
||||
Contracts.CheckValue(catalog, nameof(catalog));
|
||||
var env = CatalogUtils.GetEnvironment(catalog);
|
||||
return new LightGbmMulticlassClassificationTrainer(env, labelColumnName, featureColumnName, exampleWeightColumnName, numberOfLeaves, minimumExampleCountPerLeaf, learningRate, numberOfIterations);
|
||||
return new LightGbmMulticlassTrainer(env, labelColumnName, featureColumnName, exampleWeightColumnName, numberOfLeaves, minimumExampleCountPerLeaf, learningRate, numberOfIterations);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Predict a target using a gradient boosting decision tree multiclass classification model trained with the <see cref="LightGbmMulticlassClassificationTrainer"/> and advanced options.
|
||||
/// Predict a target using a gradient boosting decision tree multiclass classification model trained with the <see cref="LightGbmMulticlassTrainer"/> and advanced options.
|
||||
/// </summary>
|
||||
/// <param name="catalog">The <see cref="MulticlassClassificationCatalog"/>.</param>
|
||||
/// <param name="options">Trainer options.</param>
|
||||
|
@ -215,12 +215,12 @@ namespace Microsoft.ML
|
|||
/// ]]>
|
||||
/// </format>
|
||||
/// </example>
|
||||
public static LightGbmMulticlassClassificationTrainer LightGbm(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
|
||||
public static LightGbmMulticlassTrainer LightGbm(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
|
||||
Options options)
|
||||
{
|
||||
Contracts.CheckValue(catalog, nameof(catalog));
|
||||
var env = CatalogUtils.GetEnvironment(catalog);
|
||||
return new LightGbmMulticlassClassificationTrainer(env, options);
|
||||
return new LightGbmMulticlassTrainer(env, options);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -12,9 +12,9 @@ using Microsoft.ML.Runtime;
|
|||
using Microsoft.ML.Trainers.FastTree;
|
||||
using Microsoft.ML.Trainers.LightGbm;
|
||||
|
||||
[assembly: LoadableClass(LightGbmMulticlassClassificationTrainer.Summary, typeof(LightGbmMulticlassClassificationTrainer), typeof(Options),
|
||||
[assembly: LoadableClass(LightGbmMulticlassTrainer.Summary, typeof(LightGbmMulticlassTrainer), typeof(Options),
|
||||
new[] { typeof(SignatureMulticlassClassifierTrainer), typeof(SignatureTrainer) },
|
||||
"LightGBM Multi-class Classifier", LightGbmMulticlassClassificationTrainer.LoadNameValue, LightGbmMulticlassClassificationTrainer.ShortName, DocName = "trainer/LightGBM.md")]
|
||||
"LightGBM Multi-class Classifier", LightGbmMulticlassTrainer.LoadNameValue, LightGbmMulticlassTrainer.ShortName, DocName = "trainer/LightGBM.md")]
|
||||
|
||||
namespace Microsoft.ML.Trainers.LightGbm
|
||||
{
|
||||
|
@ -22,7 +22,7 @@ namespace Microsoft.ML.Trainers.LightGbm
|
|||
/// The <see cref="IEstimator{TTransformer}"/> for training a boosted decision tree multi-class classification model using LightGBM.
|
||||
/// </summary>
|
||||
/// <include file='doc.xml' path='doc/members/member[@name="LightGBM_remarks"]/*' />
|
||||
public sealed class LightGbmMulticlassClassificationTrainer : LightGbmTrainerBase<VBuffer<float>, MulticlassPredictionTransformer<OneVersusAllModelParameters>, OneVersusAllModelParameters>
|
||||
public sealed class LightGbmMulticlassTrainer : LightGbmTrainerBase<VBuffer<float>, MulticlassPredictionTransformer<OneVersusAllModelParameters>, OneVersusAllModelParameters>
|
||||
{
|
||||
internal const string Summary = "LightGBM Multi Class Classifier";
|
||||
internal const string LoadNameValue = "LightGBMMulticlass";
|
||||
|
@ -34,14 +34,14 @@ namespace Microsoft.ML.Trainers.LightGbm
|
|||
private int _tlcNumClass;
|
||||
private protected override PredictionKind PredictionKind => PredictionKind.MulticlassClassification;
|
||||
|
||||
internal LightGbmMulticlassClassificationTrainer(IHostEnvironment env, Options options)
|
||||
internal LightGbmMulticlassTrainer(IHostEnvironment env, Options options)
|
||||
: base(env, LoadNameValue, options, TrainerUtils.MakeU4ScalarColumn(options.LabelColumnName))
|
||||
{
|
||||
_numClass = -1;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of <see cref="LightGbmMulticlassClassificationTrainer"/>
|
||||
/// Initializes a new instance of <see cref="LightGbmMulticlassTrainer"/>
|
||||
/// </summary>
|
||||
/// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
|
||||
/// <param name="labelColumnName">The name of The label column.</param>
|
||||
|
@ -51,7 +51,7 @@ namespace Microsoft.ML.Trainers.LightGbm
|
|||
/// <param name="minimumExampleCountPerLeaf">The minimal number of data points allowed in a leaf of the tree, out of the subsampled data.</param>
|
||||
/// <param name="learningRate">The learning rate.</param>
|
||||
/// <param name="numberOfIterations">The number of iterations to use.</param>
|
||||
internal LightGbmMulticlassClassificationTrainer(IHostEnvironment env,
|
||||
internal LightGbmMulticlassTrainer(IHostEnvironment env,
|
||||
string labelColumnName = DefaultColumnNames.Label,
|
||||
string featureColumnName = DefaultColumnNames.Features,
|
||||
string exampleWeightColumnName = null,
|
||||
|
@ -223,7 +223,7 @@ namespace Microsoft.ML.Trainers.LightGbm
|
|||
=> new MulticlassPredictionTransformer<OneVersusAllModelParameters>(Host, model, trainSchema, FeatureColumn.Name, LabelColumn.Name);
|
||||
|
||||
/// <summary>
|
||||
/// Trains a <see cref="LightGbmMulticlassClassificationTrainer"/> using both training and validation data, returns
|
||||
/// Trains a <see cref="LightGbmMulticlassTrainer"/> using both training and validation data, returns
|
||||
/// a <see cref="MulticlassPredictionTransformer{OneVsAllModelParameters}"/>.
|
||||
/// </summary>
|
||||
public MulticlassPredictionTransformer<OneVersusAllModelParameters> Fit(IDataView trainData, IDataView validationData)
|
||||
|
@ -238,8 +238,8 @@ namespace Microsoft.ML.Trainers.LightGbm
|
|||
[TlcModule.EntryPoint(
|
||||
Name = "Trainers.LightGbmClassifier",
|
||||
Desc = "Train a LightGBM multi class model.",
|
||||
UserName = LightGbmMulticlassClassificationTrainer.Summary,
|
||||
ShortName = LightGbmMulticlassClassificationTrainer.ShortName)]
|
||||
UserName = LightGbmMulticlassTrainer.Summary,
|
||||
ShortName = LightGbmMulticlassTrainer.ShortName)]
|
||||
public static CommonOutputs.MulticlassClassificationOutput TrainMulticlass(IHostEnvironment env, Options input)
|
||||
{
|
||||
Contracts.CheckValue(env, nameof(env));
|
||||
|
@ -248,7 +248,7 @@ namespace Microsoft.ML.Trainers.LightGbm
|
|||
EntryPointUtils.CheckInputArgs(host, input);
|
||||
|
||||
return TrainerEntryPointsUtils.Train<Options, CommonOutputs.MulticlassClassificationOutput>(host, input,
|
||||
() => new LightGbmMulticlassClassificationTrainer(host, input),
|
||||
() => new LightGbmMulticlassTrainer(host, input),
|
||||
getLabel: () => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumnName),
|
||||
getWeight: () => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.ExampleWeightColumnName));
|
||||
}
|
||||
|
|
|
@ -23,16 +23,16 @@ using Microsoft.ML.Trainers;
|
|||
OlsTrainer.LoadNameValue,
|
||||
OlsTrainer.ShortName)]
|
||||
|
||||
[assembly: LoadableClass(typeof(OrdinaryLeastSquaresRegressionModelParameters), null, typeof(SignatureLoadModel),
|
||||
[assembly: LoadableClass(typeof(OlsModelParameters), null, typeof(SignatureLoadModel),
|
||||
"OLS Linear Regression Executor",
|
||||
OrdinaryLeastSquaresRegressionModelParameters.LoaderSignature)]
|
||||
OlsModelParameters.LoaderSignature)]
|
||||
|
||||
[assembly: LoadableClass(typeof(void), typeof(OlsTrainer), null, typeof(SignatureEntryPointModule), OlsTrainer.LoadNameValue)]
|
||||
|
||||
namespace Microsoft.ML.Trainers
|
||||
{
|
||||
/// <include file='doc.xml' path='doc/members/member[@name="OLS"]/*' />
|
||||
public sealed class OlsTrainer : TrainerEstimatorBase<RegressionPredictionTransformer<OrdinaryLeastSquaresRegressionModelParameters>, OrdinaryLeastSquaresRegressionModelParameters>
|
||||
public sealed class OlsTrainer : TrainerEstimatorBase<RegressionPredictionTransformer<OlsModelParameters>, OlsModelParameters>
|
||||
{
|
||||
///<summary> Advanced options for trainer.</summary>
|
||||
public sealed class Options : TrainerInputBaseWithWeight
|
||||
|
@ -85,8 +85,8 @@ namespace Microsoft.ML.Trainers
|
|||
_perParameterSignificance = options.CalculateStatistics;
|
||||
}
|
||||
|
||||
private protected override RegressionPredictionTransformer<OrdinaryLeastSquaresRegressionModelParameters> MakeTransformer(OrdinaryLeastSquaresRegressionModelParameters model, DataViewSchema trainSchema)
|
||||
=> new RegressionPredictionTransformer<OrdinaryLeastSquaresRegressionModelParameters>(Host, model, trainSchema, FeatureColumn.Name);
|
||||
private protected override RegressionPredictionTransformer<OlsModelParameters> MakeTransformer(OlsModelParameters model, DataViewSchema trainSchema)
|
||||
=> new RegressionPredictionTransformer<OlsModelParameters>(Host, model, trainSchema, FeatureColumn.Name);
|
||||
|
||||
private protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
|
||||
{
|
||||
|
@ -105,7 +105,7 @@ namespace Microsoft.ML.Trainers
|
|||
/// <returns>Either p, or 0 or 1 if it was outside the range 0 to 1</returns>
|
||||
private static Double ProbClamp(Double p) => Math.Max(0, Math.Min(p, 1));
|
||||
|
||||
private protected override OrdinaryLeastSquaresRegressionModelParameters TrainModelCore(TrainContext context)
|
||||
private protected override OlsModelParameters TrainModelCore(TrainContext context)
|
||||
{
|
||||
using (var ch = Host.Start("Training"))
|
||||
{
|
||||
|
@ -136,7 +136,7 @@ namespace Microsoft.ML.Trainers
|
|||
}
|
||||
}
|
||||
|
||||
private OrdinaryLeastSquaresRegressionModelParameters TrainCore(IChannel ch, FloatLabelCursor.Factory cursorFactory, int featureCount)
|
||||
private OlsModelParameters TrainCore(IChannel ch, FloatLabelCursor.Factory cursorFactory, int featureCount)
|
||||
{
|
||||
Host.AssertValue(ch);
|
||||
ch.AssertValue(cursorFactory);
|
||||
|
@ -267,7 +267,7 @@ namespace Microsoft.ML.Trainers
|
|||
{
|
||||
// We would expect the solution to the problem to be exact in this case.
|
||||
ch.Info("Number of examples equals number of parameters, solution is exact but no statistics can be derived");
|
||||
return new OrdinaryLeastSquaresRegressionModelParameters(Host, in weights, bias);
|
||||
return new OlsModelParameters(Host, in weights, bias);
|
||||
}
|
||||
|
||||
Double rss = 0; // residual sum of squares
|
||||
|
@ -303,7 +303,7 @@ namespace Microsoft.ML.Trainers
|
|||
// Also we can't estimate it, unless we can estimate the variance, which requires more examples than
|
||||
// parameters.
|
||||
if (!_perParameterSignificance || m >= n)
|
||||
return new OrdinaryLeastSquaresRegressionModelParameters(Host, in weights, bias, rSquared: rSquared, rSquaredAdjusted: rSquaredAdjusted);
|
||||
return new OlsModelParameters(Host, in weights, bias, rSquared: rSquared, rSquaredAdjusted: rSquaredAdjusted);
|
||||
|
||||
ch.Assert(!Double.IsNaN(rSquaredAdjusted));
|
||||
var standardErrors = new Double[m];
|
||||
|
@ -350,7 +350,7 @@ namespace Microsoft.ML.Trainers
|
|||
ch.Check(0 <= pValues[i] && pValues[i] <= 1, "p-Value calculated outside expected [0,1] range");
|
||||
}
|
||||
|
||||
return new OrdinaryLeastSquaresRegressionModelParameters(Host, in weights, bias, standardErrors, tValues, pValues, rSquared, rSquaredAdjusted);
|
||||
return new OlsModelParameters(Host, in weights, bias, standardErrors, tValues, pValues, rSquared, rSquaredAdjusted);
|
||||
}
|
||||
|
||||
internal static class Mkl
|
||||
|
@ -509,7 +509,7 @@ namespace Microsoft.ML.Trainers
|
|||
/// <summary>
|
||||
/// A linear predictor for which per parameter significance statistics are available.
|
||||
/// </summary>
|
||||
public sealed class OrdinaryLeastSquaresRegressionModelParameters : RegressionModelParameters
|
||||
public sealed class OlsModelParameters : RegressionModelParameters
|
||||
{
|
||||
internal const string LoaderSignature = "OlsLinearRegressionExec";
|
||||
internal const string RegistrationName = "OlsLinearRegressionPredictor";
|
||||
|
@ -525,7 +525,7 @@ namespace Microsoft.ML.Trainers
|
|||
verReadableCur: 0x00010001,
|
||||
verWeCanReadBack: 0x00010001,
|
||||
loaderSignature: LoaderSignature,
|
||||
loaderAssemblyName: typeof(OrdinaryLeastSquaresRegressionModelParameters).Assembly.FullName);
|
||||
loaderAssemblyName: typeof(OlsModelParameters).Assembly.FullName);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -587,7 +587,7 @@ namespace Microsoft.ML.Trainers
|
|||
/// <param name="pValues">Optional: The p-values of the weights and bias.</param>
|
||||
/// <param name="rSquared">The coefficient of determination.</param>
|
||||
/// <param name="rSquaredAdjusted">The adjusted coefficient of determination.</param>
|
||||
internal OrdinaryLeastSquaresRegressionModelParameters(IHostEnvironment env, in VBuffer<float> weights, float bias,
|
||||
internal OlsModelParameters(IHostEnvironment env, in VBuffer<float> weights, float bias,
|
||||
Double[] standardErrors = null, Double[] tValues = null, Double[] pValues = null, Double rSquared = 1, Double rSquaredAdjusted = float.NaN)
|
||||
: base(env, RegistrationName, in weights, bias)
|
||||
{
|
||||
|
@ -624,7 +624,7 @@ namespace Microsoft.ML.Trainers
|
|||
RSquaredAdjusted = rSquaredAdjusted;
|
||||
}
|
||||
|
||||
private OrdinaryLeastSquaresRegressionModelParameters(IHostEnvironment env, ModelLoadContext ctx)
|
||||
private OlsModelParameters(IHostEnvironment env, ModelLoadContext ctx)
|
||||
: base(env, RegistrationName, ctx)
|
||||
{
|
||||
// *** Binary format ***
|
||||
|
@ -708,12 +708,12 @@ namespace Microsoft.ML.Trainers
|
|||
Contracts.CheckDecode(0 <= p && p <= 1);
|
||||
}
|
||||
|
||||
private static OrdinaryLeastSquaresRegressionModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
|
||||
private static OlsModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
|
||||
{
|
||||
Contracts.CheckValue(env, nameof(env));
|
||||
env.CheckValue(ctx, nameof(ctx));
|
||||
ctx.CheckAtModel(GetVersionInfo());
|
||||
return new OrdinaryLeastSquaresRegressionModelParameters(env, ctx);
|
||||
return new OlsModelParameters(env, ctx);
|
||||
}
|
||||
|
||||
private protected override void SaveSummary(TextWriter writer, RoleMappedSchema schema)
|
||||
|
|
|
@ -23,8 +23,8 @@ using Microsoft.ML.Trainers;
|
|||
RandomizedPcaTrainer.LoadNameValue,
|
||||
RandomizedPcaTrainer.ShortName)]
|
||||
|
||||
[assembly: LoadableClass(typeof(PrincipleComponentModelParameters), null, typeof(SignatureLoadModel),
|
||||
"PCA Anomaly Executor", PrincipleComponentModelParameters.LoaderSignature)]
|
||||
[assembly: LoadableClass(typeof(PcaModelParameters), null, typeof(SignatureLoadModel),
|
||||
"PCA Anomaly Executor", PcaModelParameters.LoaderSignature)]
|
||||
|
||||
[assembly: LoadableClass(typeof(void), typeof(RandomizedPcaTrainer), null, typeof(SignatureEntryPointModule), RandomizedPcaTrainer.LoadNameValue)]
|
||||
|
||||
|
@ -39,7 +39,7 @@ namespace Microsoft.ML.Trainers
|
|||
/// <remarks>
|
||||
/// This PCA can be made into Kernel PCA by using Random Fourier Features transform
|
||||
/// </remarks>
|
||||
public sealed class RandomizedPcaTrainer : TrainerEstimatorBase<AnomalyPredictionTransformer<PrincipleComponentModelParameters>, PrincipleComponentModelParameters>
|
||||
public sealed class RandomizedPcaTrainer : TrainerEstimatorBase<AnomalyPredictionTransformer<PcaModelParameters>, PcaModelParameters>
|
||||
{
|
||||
internal const string LoadNameValue = "pcaAnomaly";
|
||||
internal const string UserNameValue = "PCA Anomaly Detector";
|
||||
|
@ -139,7 +139,7 @@ namespace Microsoft.ML.Trainers
|
|||
|
||||
}
|
||||
|
||||
private protected override PrincipleComponentModelParameters TrainModelCore(TrainContext context)
|
||||
private protected override PcaModelParameters TrainModelCore(TrainContext context)
|
||||
{
|
||||
Host.CheckValue(context, nameof(context));
|
||||
|
||||
|
@ -164,7 +164,7 @@ namespace Microsoft.ML.Trainers
|
|||
}
|
||||
|
||||
//Note: the notations used here are the same as in https://web.stanford.edu/group/mmds/slides2010/Martinsson.pdf (pg. 9)
|
||||
private PrincipleComponentModelParameters TrainCore(IChannel ch, RoleMappedData data, int dimension)
|
||||
private PcaModelParameters TrainCore(IChannel ch, RoleMappedData data, int dimension)
|
||||
{
|
||||
Host.AssertValue(ch);
|
||||
ch.AssertValue(data);
|
||||
|
@ -222,7 +222,7 @@ namespace Microsoft.ML.Trainers
|
|||
EigenUtils.EigenDecomposition(b2, out smallEigenvalues, out smallEigenvectors);
|
||||
PostProcess(b, smallEigenvalues, smallEigenvectors, dimension, oversampledRank);
|
||||
|
||||
return new PrincipleComponentModelParameters(Host, _rank, b, in mean);
|
||||
return new PcaModelParameters(Host, _rank, b, in mean);
|
||||
}
|
||||
|
||||
private static float[][] Zeros(int k, int d)
|
||||
|
@ -343,8 +343,8 @@ namespace Microsoft.ML.Trainers
|
|||
};
|
||||
}
|
||||
|
||||
private protected override AnomalyPredictionTransformer<PrincipleComponentModelParameters> MakeTransformer(PrincipleComponentModelParameters model, DataViewSchema trainSchema)
|
||||
=> new AnomalyPredictionTransformer<PrincipleComponentModelParameters>(Host, model, trainSchema, _featureColumn);
|
||||
private protected override AnomalyPredictionTransformer<PcaModelParameters> MakeTransformer(PcaModelParameters model, DataViewSchema trainSchema)
|
||||
=> new AnomalyPredictionTransformer<PcaModelParameters>(Host, model, trainSchema, _featureColumn);
|
||||
|
||||
[TlcModule.EntryPoint(Name = "Trainers.PcaAnomalyDetector",
|
||||
Desc = "Train an PCA Anomaly model.",
|
||||
|
@ -370,7 +370,7 @@ namespace Microsoft.ML.Trainers
|
|||
// REVIEW: move the predictor to a different file and fold EigenUtils.cs to this file.
|
||||
// REVIEW: Include the above detail in the XML documentation file.
|
||||
/// <include file='doc.xml' path='doc/members/member[@name="PCA"]/*' />
|
||||
public sealed class PrincipleComponentModelParameters : ModelParametersBase<float>,
|
||||
public sealed class PcaModelParameters : ModelParametersBase<float>,
|
||||
IValueMapper,
|
||||
ICanGetSummaryAsIDataView,
|
||||
ICanSaveInTextFormat,
|
||||
|
@ -387,7 +387,7 @@ namespace Microsoft.ML.Trainers
|
|||
verReadableCur: 0x00010001,
|
||||
verWeCanReadBack: 0x00010001,
|
||||
loaderSignature: LoaderSignature,
|
||||
loaderAssemblyName: typeof(PrincipleComponentModelParameters).Assembly.FullName);
|
||||
loaderAssemblyName: typeof(PcaModelParameters).Assembly.FullName);
|
||||
}
|
||||
|
||||
private readonly int _dimension;
|
||||
|
@ -408,7 +408,7 @@ namespace Microsoft.ML.Trainers
|
|||
/// <param name="rank">The rank of the PCA approximation of the covariance matrix. This is the number of eigenvectors in the model.</param>
|
||||
/// <param name="eigenVectors">Array of eigenvectors.</param>
|
||||
/// <param name="mean">The mean vector of the training data.</param>
|
||||
internal PrincipleComponentModelParameters(IHostEnvironment env, int rank, float[][] eigenVectors, in VBuffer<float> mean)
|
||||
internal PcaModelParameters(IHostEnvironment env, int rank, float[][] eigenVectors, in VBuffer<float> mean)
|
||||
: base(env, RegistrationName)
|
||||
{
|
||||
_dimension = eigenVectors[0].Length;
|
||||
|
@ -428,7 +428,7 @@ namespace Microsoft.ML.Trainers
|
|||
_inputType = new VectorType(NumberDataViewType.Single, _dimension);
|
||||
}
|
||||
|
||||
private PrincipleComponentModelParameters(IHostEnvironment env, ModelLoadContext ctx)
|
||||
private PcaModelParameters(IHostEnvironment env, ModelLoadContext ctx)
|
||||
: base(env, RegistrationName, ctx)
|
||||
{
|
||||
// *** Binary format ***
|
||||
|
@ -500,12 +500,12 @@ namespace Microsoft.ML.Trainers
|
|||
writer.WriteSinglesNoCount(_eigenVectors[i].GetValues().Slice(0, _dimension));
|
||||
}
|
||||
|
||||
private static PrincipleComponentModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
|
||||
private static PcaModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
|
||||
{
|
||||
Contracts.CheckValue(env, nameof(env));
|
||||
env.CheckValue(ctx, nameof(ctx));
|
||||
ctx.CheckAtModel(GetVersionInfo());
|
||||
return new PrincipleComponentModelParameters(env, ctx);
|
||||
return new PcaModelParameters(env, ctx);
|
||||
}
|
||||
|
||||
void ICanSaveSummary.SaveSummary(TextWriter writer, RoleMappedSchema schema)
|
||||
|
|
|
@ -15,21 +15,21 @@ using Microsoft.ML.Numeric;
|
|||
using Microsoft.ML.Runtime;
|
||||
using Microsoft.ML.Trainers;
|
||||
|
||||
[assembly: LoadableClass(LogisticRegressionBinaryClassificationTrainer.Summary, typeof(LogisticRegressionBinaryClassificationTrainer), typeof(LogisticRegressionBinaryClassificationTrainer.Options),
|
||||
[assembly: LoadableClass(LogisticRegressionBinaryTrainer.Summary, typeof(LogisticRegressionBinaryTrainer), typeof(LogisticRegressionBinaryTrainer.Options),
|
||||
new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) },
|
||||
LogisticRegressionBinaryClassificationTrainer.UserNameValue,
|
||||
LogisticRegressionBinaryClassificationTrainer.LoadNameValue,
|
||||
LogisticRegressionBinaryClassificationTrainer.ShortName,
|
||||
LogisticRegressionBinaryTrainer.UserNameValue,
|
||||
LogisticRegressionBinaryTrainer.LoadNameValue,
|
||||
LogisticRegressionBinaryTrainer.ShortName,
|
||||
"logisticregressionwrapper")]
|
||||
|
||||
[assembly: LoadableClass(typeof(void), typeof(LogisticRegressionBinaryClassificationTrainer), null, typeof(SignatureEntryPointModule), LogisticRegressionBinaryClassificationTrainer.LoadNameValue)]
|
||||
[assembly: LoadableClass(typeof(void), typeof(LogisticRegressionBinaryTrainer), null, typeof(SignatureEntryPointModule), LogisticRegressionBinaryTrainer.LoadNameValue)]
|
||||
|
||||
namespace Microsoft.ML.Trainers
|
||||
{
|
||||
|
||||
/// <include file='doc.xml' path='doc/members/member[@name="LBFGS"]/*' />
|
||||
/// <include file='doc.xml' path='docs/members/example[@name="LogisticRegressionBinaryClassifier"]/*' />
|
||||
public sealed partial class LogisticRegressionBinaryClassificationTrainer : LbfgsTrainerBase<LogisticRegressionBinaryClassificationTrainer.Options,
|
||||
public sealed partial class LogisticRegressionBinaryTrainer : LbfgsTrainerBase<LogisticRegressionBinaryTrainer.Options,
|
||||
BinaryPredictionTransformer<CalibratedModelParametersBase<LinearBinaryModelParameters, PlattCalibrator>>,
|
||||
CalibratedModelParametersBase<LinearBinaryModelParameters, PlattCalibrator>>
|
||||
{
|
||||
|
@ -54,7 +54,7 @@ namespace Microsoft.ML.Trainers
|
|||
/// <summary>
|
||||
/// The instance of <see cref="ComputeLogisticRegressionStandardDeviation"/> that computes the std of the training statistics, at the end of training.
|
||||
/// The calculations are not part of Microsoft.ML package, due to the size of MKL.
|
||||
/// If you need these calculations, add the Microsoft.ML.Mkl.Components package, and initialize <see cref="LogisticRegressionBinaryClassificationTrainer.Options.ComputeStandardDeviation"/>.
|
||||
/// If you need these calculations, add the Microsoft.ML.Mkl.Components package, and initialize <see cref="LogisticRegressionBinaryTrainer.Options.ComputeStandardDeviation"/>.
|
||||
/// to the <see cref="ComputeLogisticRegressionStandardDeviation"/> implementation in the Microsoft.ML.Mkl.Components package.
|
||||
/// </summary>
|
||||
public ComputeLogisticRegressionStandardDeviation ComputeStandardDeviation;
|
||||
|
@ -64,7 +64,7 @@ namespace Microsoft.ML.Trainers
|
|||
private LinearModelStatistics _stats;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of <see cref="LogisticRegressionBinaryClassificationTrainer"/>
|
||||
/// Initializes a new instance of <see cref="LogisticRegressionBinaryTrainer"/>
|
||||
/// </summary>
|
||||
/// <param name="env">The environment to use.</param>
|
||||
/// <param name="labelColumn">The name of the label column.</param>
|
||||
|
@ -73,9 +73,9 @@ namespace Microsoft.ML.Trainers
|
|||
/// <param name="enforceNoNegativity">Enforce non-negative weights.</param>
|
||||
/// <param name="l1Weight">Weight of L1 regularizer term.</param>
|
||||
/// <param name="l2Weight">Weight of L2 regularizer term.</param>
|
||||
/// <param name="memorySize">Memory size for <see cref="LogisticRegressionBinaryClassificationTrainer"/>. Low=faster, less accurate.</param>
|
||||
/// <param name="memorySize">Memory size for <see cref="LogisticRegressionBinaryTrainer"/>. Low=faster, less accurate.</param>
|
||||
/// <param name="optimizationTolerance">Threshold for optimizer convergence.</param>
|
||||
internal LogisticRegressionBinaryClassificationTrainer(IHostEnvironment env,
|
||||
internal LogisticRegressionBinaryTrainer(IHostEnvironment env,
|
||||
string labelColumn = DefaultColumnNames.Label,
|
||||
string featureColumn = DefaultColumnNames.Features,
|
||||
string weights = null,
|
||||
|
@ -95,9 +95,9 @@ namespace Microsoft.ML.Trainers
|
|||
}
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of <see cref="LogisticRegressionBinaryClassificationTrainer"/>
|
||||
/// Initializes a new instance of <see cref="LogisticRegressionBinaryTrainer"/>
|
||||
/// </summary>
|
||||
internal LogisticRegressionBinaryClassificationTrainer(IHostEnvironment env, Options options)
|
||||
internal LogisticRegressionBinaryTrainer(IHostEnvironment env, Options options)
|
||||
: base(env, options, TrainerUtils.MakeBoolScalarLabel(options.LabelColumnName))
|
||||
{
|
||||
_posWeight = 0;
|
||||
|
@ -127,7 +127,7 @@ namespace Microsoft.ML.Trainers
|
|||
=> new BinaryPredictionTransformer<CalibratedModelParametersBase<LinearBinaryModelParameters, PlattCalibrator>>(Host, model, trainSchema, FeatureColumn.Name);
|
||||
|
||||
/// <summary>
|
||||
/// Continues the training of a <see cref="LogisticRegressionBinaryClassificationTrainer"/> using an already trained <paramref name="modelParameters"/> and returns
|
||||
/// Continues the training of a <see cref="LogisticRegressionBinaryTrainer"/> using an already trained <paramref name="modelParameters"/> and returns
|
||||
/// a <see cref="BinaryPredictionTransformer{CalibratedModelParametersBase}"/>.
|
||||
/// </summary>
|
||||
public BinaryPredictionTransformer<CalibratedModelParametersBase<LinearBinaryModelParameters, PlattCalibrator>> Fit(IDataView trainData, LinearModelParameters modelParameters)
|
||||
|
@ -420,7 +420,7 @@ namespace Microsoft.ML.Trainers
|
|||
EntryPointUtils.CheckInputArgs(host, input);
|
||||
|
||||
return TrainerEntryPointsUtils.Train<Options, CommonOutputs.BinaryClassificationOutput>(host, input,
|
||||
() => new LogisticRegressionBinaryClassificationTrainer(host, input),
|
||||
() => new LogisticRegressionBinaryTrainer(host, input),
|
||||
() => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumnName),
|
||||
() => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.ExampleWeightColumnName));
|
||||
}
|
||||
|
@ -439,7 +439,7 @@ namespace Microsoft.ML.Trainers
|
|||
/// Computes the standard deviation matrix of each of the non-zero training weights, needed to calculate further the standard deviation,
|
||||
/// p-value and z-Score.
|
||||
/// The calculations are not part of Microsoft.ML package, due to the size of MKL.
|
||||
/// If you need these calculations, add the Microsoft.ML.Mkl.Components package, and initialize <see cref="LogisticRegressionBinaryClassificationTrainer.Options.ComputeStandardDeviation"/>
|
||||
/// If you need these calculations, add the Microsoft.ML.Mkl.Components package, and initialize <see cref="LogisticRegressionBinaryTrainer.Options.ComputeStandardDeviation"/>
|
||||
/// to the <see cref="ComputeLogisticRegressionStandardDeviation"/> implementation in the Microsoft.ML.Mkl.Components package.
|
||||
/// Due to the existence of regularization, an approximation is used to compute the variances of the trained linear coefficients.
|
||||
/// </summary>
|
||||
|
|
|
@ -80,7 +80,7 @@ namespace Microsoft.ML.Trainers
|
|||
/// <param name="enforceNoNegativity">Enforce non-negative weights.</param>
|
||||
/// <param name="l1Weight">Weight of L1 regularizer term.</param>
|
||||
/// <param name="l2Weight">Weight of L2 regularizer term.</param>
|
||||
/// <param name="memorySize">Memory size for <see cref="LogisticRegressionBinaryClassificationTrainer"/>. Low=faster, less accurate.</param>
|
||||
/// <param name="memorySize">Memory size for <see cref="LogisticRegressionBinaryTrainer"/>. Low=faster, less accurate.</param>
|
||||
/// <param name="optimizationTolerance">Threshold for optimizer convergence.</param>
|
||||
internal LogisticRegressionMulticlassClassificationTrainer(IHostEnvironment env,
|
||||
string labelColumn = DefaultColumnNames.Label,
|
||||
|
@ -429,7 +429,7 @@ namespace Microsoft.ML.Trainers
|
|||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="MulticlassLogisticRegressionModelParameters"/> class.
|
||||
/// This constructor is called by <see cref="SdcaMulticlassClassificationTrainer"/> to create the predictor.
|
||||
/// This constructor is called by <see cref="SdcaMulticlassTrainer"/> to create the predictor.
|
||||
/// </summary>
|
||||
/// <param name="env">The host environment.</param>
|
||||
/// <param name="weights">The array of weights vectors. It should contain <paramref name="numClasses"/> weights.</param>
|
||||
|
@ -1004,7 +1004,7 @@ namespace Microsoft.ML.Trainers
|
|||
/// <summary>
|
||||
/// A component to train a logistic regression model.
|
||||
/// </summary>
|
||||
public partial class LogisticRegressionBinaryClassificationTrainer
|
||||
public partial class LogisticRegressionBinaryTrainer
|
||||
{
|
||||
[TlcModule.EntryPoint(Name = "Trainers.LogisticRegressionClassifier",
|
||||
Desc = Summary,
|
||||
|
|
|
@ -13,20 +13,20 @@ using Microsoft.ML.Model;
|
|||
using Microsoft.ML.Runtime;
|
||||
using Microsoft.ML.Trainers;
|
||||
|
||||
[assembly: LoadableClass(NaiveBayesTrainer.Summary, typeof(NaiveBayesTrainer), typeof(NaiveBayesTrainer.Options),
|
||||
[assembly: LoadableClass(NaiveBayesMulticlassTrainer.Summary, typeof(NaiveBayesMulticlassTrainer), typeof(NaiveBayesMulticlassTrainer.Options),
|
||||
new[] { typeof(SignatureMulticlassClassifierTrainer), typeof(SignatureTrainer) },
|
||||
NaiveBayesTrainer.UserName,
|
||||
NaiveBayesTrainer.LoadName,
|
||||
NaiveBayesTrainer.ShortName, DocName = "trainer/NaiveBayes.md")]
|
||||
NaiveBayesMulticlassTrainer.UserName,
|
||||
NaiveBayesMulticlassTrainer.LoadName,
|
||||
NaiveBayesMulticlassTrainer.ShortName, DocName = "trainer/NaiveBayes.md")]
|
||||
|
||||
[assembly: LoadableClass(typeof(MulticlassNaiveBayesModelParameters), null, typeof(SignatureLoadModel),
|
||||
"Multi Class Naive Bayes predictor", MulticlassNaiveBayesModelParameters.LoaderSignature)]
|
||||
[assembly: LoadableClass(typeof(NaiveBayesMulticlassModelParameters), null, typeof(SignatureLoadModel),
|
||||
"Multi Class Naive Bayes predictor", NaiveBayesMulticlassModelParameters.LoaderSignature)]
|
||||
|
||||
[assembly: LoadableClass(typeof(void), typeof(NaiveBayesTrainer), null, typeof(SignatureEntryPointModule), NaiveBayesTrainer.LoadName)]
|
||||
[assembly: LoadableClass(typeof(void), typeof(NaiveBayesMulticlassTrainer), null, typeof(SignatureEntryPointModule), NaiveBayesMulticlassTrainer.LoadName)]
|
||||
|
||||
namespace Microsoft.ML.Trainers
|
||||
{
|
||||
public sealed class NaiveBayesTrainer : TrainerEstimatorBase<MulticlassPredictionTransformer<MulticlassNaiveBayesModelParameters>, MulticlassNaiveBayesModelParameters>
|
||||
public sealed class NaiveBayesMulticlassTrainer : TrainerEstimatorBase<MulticlassPredictionTransformer<NaiveBayesMulticlassModelParameters>, NaiveBayesMulticlassModelParameters>
|
||||
{
|
||||
internal const string LoadName = "MultiClassNaiveBayes";
|
||||
internal const string UserName = "Multiclass Naive Bayes";
|
||||
|
@ -49,12 +49,12 @@ namespace Microsoft.ML.Trainers
|
|||
public override TrainerInfo Info => _info;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of <see cref="NaiveBayesTrainer"/>
|
||||
/// Initializes a new instance of <see cref="NaiveBayesMulticlassTrainer"/>
|
||||
/// </summary>
|
||||
/// <param name="env">The environment to use.</param>
|
||||
/// <param name="labelColumn">The name of the label column.</param>
|
||||
/// <param name="featureColumn">The name of the feature column.</param>
|
||||
internal NaiveBayesTrainer(IHostEnvironment env,
|
||||
internal NaiveBayesMulticlassTrainer(IHostEnvironment env,
|
||||
string labelColumn = DefaultColumnNames.Label,
|
||||
string featureColumn = DefaultColumnNames.Features)
|
||||
: base(Contracts.CheckRef(env, nameof(env)).Register(LoadName), TrainerUtils.MakeR4VecFeature(featureColumn),
|
||||
|
@ -65,9 +65,9 @@ namespace Microsoft.ML.Trainers
|
|||
}
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of <see cref="NaiveBayesTrainer"/>
|
||||
/// Initializes a new instance of <see cref="NaiveBayesMulticlassTrainer"/>
|
||||
/// </summary>
|
||||
internal NaiveBayesTrainer(IHostEnvironment env, Options options)
|
||||
internal NaiveBayesMulticlassTrainer(IHostEnvironment env, Options options)
|
||||
: base(Contracts.CheckRef(env, nameof(env)).Register(LoadName), TrainerUtils.MakeR4VecFeature(options.FeatureColumnName),
|
||||
TrainerUtils.MakeU4ScalarColumn(options.LabelColumnName))
|
||||
{
|
||||
|
@ -89,10 +89,10 @@ namespace Microsoft.ML.Trainers
|
|||
};
|
||||
}
|
||||
|
||||
private protected override MulticlassPredictionTransformer<MulticlassNaiveBayesModelParameters> MakeTransformer(MulticlassNaiveBayesModelParameters model, DataViewSchema trainSchema)
|
||||
=> new MulticlassPredictionTransformer<MulticlassNaiveBayesModelParameters>(Host, model, trainSchema, FeatureColumn.Name, LabelColumn.Name);
|
||||
private protected override MulticlassPredictionTransformer<NaiveBayesMulticlassModelParameters> MakeTransformer(NaiveBayesMulticlassModelParameters model, DataViewSchema trainSchema)
|
||||
=> new MulticlassPredictionTransformer<NaiveBayesMulticlassModelParameters>(Host, model, trainSchema, FeatureColumn.Name, LabelColumn.Name);
|
||||
|
||||
private protected override MulticlassNaiveBayesModelParameters TrainModelCore(TrainContext context)
|
||||
private protected override NaiveBayesMulticlassModelParameters TrainModelCore(TrainContext context)
|
||||
{
|
||||
Host.CheckValue(context, nameof(context));
|
||||
var data = context.TrainingSet;
|
||||
|
@ -160,7 +160,7 @@ namespace Microsoft.ML.Trainers
|
|||
|
||||
Array.Resize(ref labelHistogram, labelCount);
|
||||
Array.Resize(ref featureHistogram, labelCount);
|
||||
return new MulticlassNaiveBayesModelParameters(Host, labelHistogram, featureHistogram, featureCount);
|
||||
return new NaiveBayesMulticlassModelParameters(Host, labelHistogram, featureHistogram, featureCount);
|
||||
}
|
||||
|
||||
[TlcModule.EntryPoint(Name = "Trainers.NaiveBayesClassifier",
|
||||
|
@ -175,12 +175,12 @@ namespace Microsoft.ML.Trainers
|
|||
EntryPointUtils.CheckInputArgs(host, input);
|
||||
|
||||
return TrainerEntryPointsUtils.Train<Options, CommonOutputs.MulticlassClassificationOutput>(host, input,
|
||||
() => new NaiveBayesTrainer(host, input),
|
||||
() => new NaiveBayesMulticlassTrainer(host, input),
|
||||
() => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumnName));
|
||||
}
|
||||
}
|
||||
|
||||
public sealed class MulticlassNaiveBayesModelParameters :
|
||||
public sealed class NaiveBayesMulticlassModelParameters :
|
||||
ModelParametersBase<VBuffer<float>>,
|
||||
IValueMapper
|
||||
{
|
||||
|
@ -193,7 +193,7 @@ namespace Microsoft.ML.Trainers
|
|||
verReadableCur: 0x00010001,
|
||||
verWeCanReadBack: 0x00010001,
|
||||
loaderSignature: LoaderSignature,
|
||||
loaderAssemblyName: typeof(MulticlassNaiveBayesModelParameters).Assembly.FullName);
|
||||
loaderAssemblyName: typeof(NaiveBayesMulticlassModelParameters).Assembly.FullName);
|
||||
}
|
||||
|
||||
private readonly int[] _labelHistogram;
|
||||
|
@ -229,7 +229,7 @@ namespace Microsoft.ML.Trainers
|
|||
/// <param name="labelHistogram">The histogram of labels.</param>
|
||||
/// <param name="featureHistogram">The feature histogram.</param>
|
||||
/// <param name="featureCount">The number of features.</param>
|
||||
internal MulticlassNaiveBayesModelParameters(IHostEnvironment env, int[] labelHistogram, int[][] featureHistogram, int featureCount)
|
||||
internal NaiveBayesMulticlassModelParameters(IHostEnvironment env, int[] labelHistogram, int[][] featureHistogram, int featureCount)
|
||||
: base(env, LoaderSignature)
|
||||
{
|
||||
Host.AssertValue(labelHistogram);
|
||||
|
@ -246,7 +246,7 @@ namespace Microsoft.ML.Trainers
|
|||
_outputType = new VectorType(NumberDataViewType.Single, _labelCount);
|
||||
}
|
||||
|
||||
private MulticlassNaiveBayesModelParameters(IHostEnvironment env, ModelLoadContext ctx)
|
||||
private NaiveBayesMulticlassModelParameters(IHostEnvironment env, ModelLoadContext ctx)
|
||||
: base(env, LoaderSignature, ctx)
|
||||
{
|
||||
// *** Binary format ***
|
||||
|
@ -280,12 +280,12 @@ namespace Microsoft.ML.Trainers
|
|||
_outputType = new VectorType(NumberDataViewType.Single, _labelCount);
|
||||
}
|
||||
|
||||
private static MulticlassNaiveBayesModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
|
||||
private static NaiveBayesMulticlassModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
|
||||
{
|
||||
Contracts.CheckValue(env, nameof(env));
|
||||
env.CheckValue(ctx, nameof(ctx));
|
||||
ctx.CheckAtModel(GetVersionInfo());
|
||||
return new MulticlassNaiveBayesModelParameters(env, ctx);
|
||||
return new NaiveBayesMulticlassModelParameters(env, ctx);
|
||||
}
|
||||
|
||||
private protected override void SaveCore(ModelSaveContext ctx)
|
||||
|
|
|
@ -48,7 +48,7 @@ namespace Microsoft.ML.Trainers
|
|||
/// <param name="l1Weight">Weight of L1 regularizer term.</param>
|
||||
/// <param name="l2Weight">Weight of L2 regularizer term.</param>
|
||||
/// <param name="optimizationTolerance">Threshold for optimizer convergence.</param>
|
||||
/// <param name="memorySize">Memory size for <see cref="LogisticRegressionBinaryClassificationTrainer"/>. Low=faster, less accurate.</param>
|
||||
/// <param name="memorySize">Memory size for <see cref="LogisticRegressionBinaryTrainer"/>. Low=faster, less accurate.</param>
|
||||
/// <param name="enforceNoNegativity">Enforce non-negative weights.</param>
|
||||
internal PoissonRegressionTrainer(IHostEnvironment env,
|
||||
string labelColumn = DefaultColumnNames.Label,
|
||||
|
|
|
@ -1425,8 +1425,8 @@ namespace Microsoft.ML.Trainers
|
|||
/// <summary>
|
||||
/// SDCA is a general training algorithm for (generalized) linear models such as support vector machine, linear regression, logistic regression,
|
||||
/// and so on. SDCA binary classification trainer family includes several sealed members:
|
||||
/// (1) <see cref="SdcaNonCalibratedBinaryClassificationTrainer"/> supports general loss functions and returns <see cref="LinearBinaryModelParameters"/>.
|
||||
/// (2) <see cref="SdcaCalibratedBinaryClassificationTrainer"/> essentially trains a regularized logistic regression model. Because logistic regression
|
||||
/// (1) <see cref="SdcaNonCalibratedBinaryTrainer"/> supports general loss functions and returns <see cref="LinearBinaryModelParameters"/>.
|
||||
/// (2) <see cref="SdcaCalibratedBinaryTrainer"/> essentially trains a regularized logistic regression model. Because logistic regression
|
||||
/// naturally provide probability output, this generated model's type is <see cref="CalibratedModelParametersBase{TSubModel, TCalibrator}"/>.
|
||||
/// where <see langword="TSubModel"/> is <see cref="LinearBinaryModelParameters"/> and <see langword="TCalibrator "/> is <see cref="PlattCalibrator"/>.
|
||||
/// </summary>
|
||||
|
@ -1546,17 +1546,17 @@ namespace Microsoft.ML.Trainers
|
|||
/// linear function to a <see cref="PlattCalibrator"/>.
|
||||
/// </summary>
|
||||
/// <include file='doc.xml' path='doc/members/member[@name="SDCA_remarks"]/*' />
|
||||
public sealed class SdcaCalibratedBinaryClassificationTrainer :
|
||||
public sealed class SdcaCalibratedBinaryTrainer :
|
||||
SdcaBinaryTrainerBase<CalibratedModelParametersBase<LinearBinaryModelParameters, PlattCalibrator>>
|
||||
{
|
||||
/// <summary>
|
||||
/// Options for the <see cref="SdcaCalibratedBinaryClassificationTrainer"/>.
|
||||
/// Options for the <see cref="SdcaCalibratedBinaryTrainer"/>.
|
||||
/// </summary>
|
||||
public sealed class Options : BinaryOptionsBase
|
||||
{
|
||||
}
|
||||
|
||||
internal SdcaCalibratedBinaryClassificationTrainer(IHostEnvironment env,
|
||||
internal SdcaCalibratedBinaryTrainer(IHostEnvironment env,
|
||||
string labelColumnName = DefaultColumnNames.Label,
|
||||
string featureColumnName = DefaultColumnNames.Features,
|
||||
string weightColumnName = null,
|
||||
|
@ -1567,7 +1567,7 @@ namespace Microsoft.ML.Trainers
|
|||
{
|
||||
}
|
||||
|
||||
internal SdcaCalibratedBinaryClassificationTrainer(IHostEnvironment env, Options options)
|
||||
internal SdcaCalibratedBinaryTrainer(IHostEnvironment env, Options options)
|
||||
: base(env, options, new LogLoss())
|
||||
{
|
||||
}
|
||||
|
@ -1610,10 +1610,10 @@ namespace Microsoft.ML.Trainers
|
|||
/// The <see cref="IEstimator{TTransformer}"/> for training a binary logistic regression classification model using the stochastic dual coordinate ascent method.
|
||||
/// </summary>
|
||||
/// <include file='doc.xml' path='doc/members/member[@name="SDCA_remarks"]/*' />
|
||||
public sealed class SdcaNonCalibratedBinaryClassificationTrainer : SdcaBinaryTrainerBase<LinearBinaryModelParameters>
|
||||
public sealed class SdcaNonCalibratedBinaryTrainer : SdcaBinaryTrainerBase<LinearBinaryModelParameters>
|
||||
{
|
||||
/// <summary>
|
||||
/// Options for the <see cref="SdcaNonCalibratedBinaryClassificationTrainer"/>.
|
||||
/// Options for the <see cref="SdcaNonCalibratedBinaryTrainer"/>.
|
||||
/// </summary>
|
||||
public sealed class Options : BinaryOptionsBase
|
||||
{
|
||||
|
@ -1635,7 +1635,7 @@ namespace Microsoft.ML.Trainers
|
|||
public ISupportSdcaClassificationLoss LossFunction { get; set; }
|
||||
}
|
||||
|
||||
internal SdcaNonCalibratedBinaryClassificationTrainer(IHostEnvironment env,
|
||||
internal SdcaNonCalibratedBinaryTrainer(IHostEnvironment env,
|
||||
string labelColumnName = DefaultColumnNames.Label,
|
||||
string featureColumnName = DefaultColumnNames.Features,
|
||||
string weightColumnName = null,
|
||||
|
@ -1647,7 +1647,7 @@ namespace Microsoft.ML.Trainers
|
|||
{
|
||||
}
|
||||
|
||||
internal SdcaNonCalibratedBinaryClassificationTrainer(IHostEnvironment env, Options options)
|
||||
internal SdcaNonCalibratedBinaryTrainer(IHostEnvironment env, Options options)
|
||||
: base(env, options, options.LossFunction ?? options.LossFunctionFactory.CreateComponent(env))
|
||||
{
|
||||
}
|
||||
|
@ -1673,7 +1673,7 @@ namespace Microsoft.ML.Trainers
|
|||
}
|
||||
|
||||
/// <summary>
|
||||
/// Comparing with <see cref="SdcaCalibratedBinaryClassificationTrainer.CreatePredictor(VBuffer{float}[], float[])"/>,
|
||||
/// Comparing with <see cref="SdcaCalibratedBinaryTrainer.CreatePredictor(VBuffer{float}[], float[])"/>,
|
||||
/// <see cref="CreatePredictor"/> directly outputs a <see cref="LinearBinaryModelParameters"/> built from
|
||||
/// the learned weights and bias without calibration.
|
||||
/// </summary>
|
||||
|
@ -1940,7 +1940,7 @@ namespace Microsoft.ML.Trainers
|
|||
=> new BinaryPredictionTransformer<TModel>(Host, model, trainSchema, FeatureColumn.Name);
|
||||
|
||||
/// <summary>
|
||||
/// Continues the training of a <see cref="SdcaCalibratedBinaryClassificationTrainer"/> using an already trained <paramref name="modelParameters"/> and returns a <see cref="BinaryPredictionTransformer"/>.
|
||||
/// Continues the training of a <see cref="SdcaCalibratedBinaryTrainer"/> using an already trained <paramref name="modelParameters"/> and returns a <see cref="BinaryPredictionTransformer"/>.
|
||||
/// </summary>
|
||||
public BinaryPredictionTransformer<TModel> Fit(IDataView trainData, LinearModelParameters modelParameters)
|
||||
=> TrainTransformer(trainData, initPredictor: modelParameters);
|
||||
|
|
|
@ -16,11 +16,11 @@ using Microsoft.ML.Numeric;
|
|||
using Microsoft.ML.Runtime;
|
||||
using Microsoft.ML.Trainers;
|
||||
|
||||
[assembly: LoadableClass(SdcaMulticlassClassificationTrainer.Summary, typeof(SdcaMulticlassClassificationTrainer), typeof(SdcaMulticlassClassificationTrainer.Options),
|
||||
[assembly: LoadableClass(SdcaMulticlassTrainer.Summary, typeof(SdcaMulticlassTrainer), typeof(SdcaMulticlassTrainer.Options),
|
||||
new[] { typeof(SignatureMulticlassClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) },
|
||||
SdcaMulticlassClassificationTrainer.UserNameValue,
|
||||
SdcaMulticlassClassificationTrainer.LoadNameValue,
|
||||
SdcaMulticlassClassificationTrainer.ShortName)]
|
||||
SdcaMulticlassTrainer.UserNameValue,
|
||||
SdcaMulticlassTrainer.LoadNameValue,
|
||||
SdcaMulticlassTrainer.ShortName)]
|
||||
|
||||
namespace Microsoft.ML.Trainers
|
||||
{
|
||||
|
@ -28,7 +28,7 @@ namespace Microsoft.ML.Trainers
|
|||
/// The <see cref="IEstimator{TTransformer}"/> for training a multiclass logistic regression classification model using the stochastic dual coordinate ascent method.
|
||||
/// </summary>
|
||||
/// <include file='doc.xml' path='doc/members/member[@name="SDCA_remarks"]/*' />
|
||||
public sealed class SdcaMulticlassClassificationTrainer : SdcaTrainerBase<SdcaMulticlassClassificationTrainer.Options, MulticlassPredictionTransformer<MulticlassLogisticRegressionModelParameters>, MulticlassLogisticRegressionModelParameters>
|
||||
public sealed class SdcaMulticlassTrainer : SdcaTrainerBase<SdcaMulticlassTrainer.Options, MulticlassPredictionTransformer<MulticlassLogisticRegressionModelParameters>, MulticlassLogisticRegressionModelParameters>
|
||||
{
|
||||
internal const string LoadNameValue = "SDCAMC";
|
||||
internal const string UserNameValue = "Fast Linear Multi-class Classification (SA-SDCA)";
|
||||
|
@ -36,7 +36,7 @@ namespace Microsoft.ML.Trainers
|
|||
internal const string Summary = "The SDCA linear multi-class classification trainer.";
|
||||
|
||||
/// <summary>
|
||||
/// Options for the <see cref="SdcaMulticlassClassificationTrainer"/>.
|
||||
/// Options for the <see cref="SdcaMulticlassTrainer"/>.
|
||||
/// </summary>
|
||||
public sealed class Options : OptionsBase
|
||||
{
|
||||
|
@ -63,7 +63,7 @@ namespace Microsoft.ML.Trainers
|
|||
private protected override PredictionKind PredictionKind => PredictionKind.MulticlassClassification;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of <see cref="SdcaMulticlassClassificationTrainer"/>
|
||||
/// Initializes a new instance of <see cref="SdcaMulticlassTrainer"/>
|
||||
/// </summary>
|
||||
/// <param name="env">The environment to use.</param>
|
||||
/// <param name="labelColumn">The label, or dependent variable.</param>
|
||||
|
@ -73,7 +73,7 @@ namespace Microsoft.ML.Trainers
|
|||
/// <param name="l2Const">The L2 regularization hyperparameter.</param>
|
||||
/// <param name="l1Threshold">The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model.</param>
|
||||
/// <param name="maxIterations">The maximum number of passes to perform over the data.</param>
|
||||
internal SdcaMulticlassClassificationTrainer(IHostEnvironment env,
|
||||
internal SdcaMulticlassTrainer(IHostEnvironment env,
|
||||
string labelColumn = DefaultColumnNames.Label,
|
||||
string featureColumn = DefaultColumnNames.Features,
|
||||
string weights = null,
|
||||
|
@ -90,7 +90,7 @@ namespace Microsoft.ML.Trainers
|
|||
Loss = _loss;
|
||||
}
|
||||
|
||||
internal SdcaMulticlassClassificationTrainer(IHostEnvironment env, Options options,
|
||||
internal SdcaMulticlassTrainer(IHostEnvironment env, Options options,
|
||||
string featureColumn, string labelColumn, string weightColumn = null)
|
||||
: base(env, options, TrainerUtils.MakeU4ScalarColumn(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn))
|
||||
{
|
||||
|
@ -101,7 +101,7 @@ namespace Microsoft.ML.Trainers
|
|||
Loss = _loss;
|
||||
}
|
||||
|
||||
internal SdcaMulticlassClassificationTrainer(IHostEnvironment env, Options options)
|
||||
internal SdcaMulticlassTrainer(IHostEnvironment env, Options options)
|
||||
: this(env, options, options.FeatureColumnName, options.LabelColumnName)
|
||||
{
|
||||
}
|
||||
|
@ -448,18 +448,18 @@ namespace Microsoft.ML.Trainers
|
|||
internal static partial class Sdca
|
||||
{
|
||||
[TlcModule.EntryPoint(Name = "Trainers.StochasticDualCoordinateAscentClassifier",
|
||||
Desc = SdcaMulticlassClassificationTrainer.Summary,
|
||||
UserName = SdcaMulticlassClassificationTrainer.UserNameValue,
|
||||
ShortName = SdcaMulticlassClassificationTrainer.ShortName)]
|
||||
public static CommonOutputs.MulticlassClassificationOutput TrainMulticlass(IHostEnvironment env, SdcaMulticlassClassificationTrainer.Options input)
|
||||
Desc = SdcaMulticlassTrainer.Summary,
|
||||
UserName = SdcaMulticlassTrainer.UserNameValue,
|
||||
ShortName = SdcaMulticlassTrainer.ShortName)]
|
||||
public static CommonOutputs.MulticlassClassificationOutput TrainMulticlass(IHostEnvironment env, SdcaMulticlassTrainer.Options input)
|
||||
{
|
||||
Contracts.CheckValue(env, nameof(env));
|
||||
var host = env.Register("TrainSDCA");
|
||||
host.CheckValue(input, nameof(input));
|
||||
EntryPointUtils.CheckInputArgs(host, input);
|
||||
|
||||
return TrainerEntryPointsUtils.Train<SdcaMulticlassClassificationTrainer.Options, CommonOutputs.MulticlassClassificationOutput>(host, input,
|
||||
() => new SdcaMulticlassClassificationTrainer(host, input),
|
||||
return TrainerEntryPointsUtils.Train<SdcaMulticlassTrainer.Options, CommonOutputs.MulticlassClassificationOutput>(host, input,
|
||||
() => new SdcaMulticlassTrainer(host, input),
|
||||
() => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumnName));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -9,7 +9,7 @@ using Microsoft.ML.Trainers;
|
|||
|
||||
namespace Microsoft.ML
|
||||
{
|
||||
using LROptions = LogisticRegressionBinaryClassificationTrainer.Options;
|
||||
using LROptions = LogisticRegressionBinaryTrainer.Options;
|
||||
|
||||
/// <summary>
|
||||
/// TrainerEstimator extension methods.
|
||||
|
@ -181,7 +181,7 @@ namespace Microsoft.ML
|
|||
}
|
||||
|
||||
/// <summary>
|
||||
/// Predict a target using a linear classification model trained with <see cref="SdcaCalibratedBinaryClassificationTrainer"/>.
|
||||
/// Predict a target using a linear classification model trained with <see cref="SdcaCalibratedBinaryTrainer"/>.
|
||||
/// </summary>
|
||||
/// <param name="catalog">The binary classification catalog trainer object.</param>
|
||||
/// <param name="labelColumnName">The name of the label column.</param>
|
||||
|
@ -196,7 +196,7 @@ namespace Microsoft.ML
|
|||
/// [!code-csharp[SDCA](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/StochasticDualCoordinateAscent.cs)]
|
||||
/// ]]></format>
|
||||
/// </example>
|
||||
public static SdcaCalibratedBinaryClassificationTrainer SdcaCalibrated(
|
||||
public static SdcaCalibratedBinaryTrainer SdcaCalibrated(
|
||||
this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
|
||||
string labelColumnName = DefaultColumnNames.Label,
|
||||
string featureColumnName = DefaultColumnNames.Features,
|
||||
|
@ -207,11 +207,11 @@ namespace Microsoft.ML
|
|||
{
|
||||
Contracts.CheckValue(catalog, nameof(catalog));
|
||||
var env = CatalogUtils.GetEnvironment(catalog);
|
||||
return new SdcaCalibratedBinaryClassificationTrainer(env, labelColumnName, featureColumnName, exampleWeightColumnName, l2Regularization, l1Threshold, maximumNumberOfIterations);
|
||||
return new SdcaCalibratedBinaryTrainer(env, labelColumnName, featureColumnName, exampleWeightColumnName, l2Regularization, l1Threshold, maximumNumberOfIterations);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Predict a target using a linear classification model trained with <see cref="SdcaCalibratedBinaryClassificationTrainer"/> and advanced options.
|
||||
/// Predict a target using a linear classification model trained with <see cref="SdcaCalibratedBinaryTrainer"/> and advanced options.
|
||||
/// </summary>
|
||||
/// <param name="catalog">The binary classification catalog trainer object.</param>
|
||||
/// <param name="options">Trainer options.</param>
|
||||
|
@ -221,19 +221,19 @@ namespace Microsoft.ML
|
|||
/// [!code-csharp[SDCA](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/StochasticDualCoordinateAscentWithOptions.cs)]
|
||||
/// ]]></format>
|
||||
/// </example>
|
||||
public static SdcaCalibratedBinaryClassificationTrainer SdcaCalibrated(
|
||||
public static SdcaCalibratedBinaryTrainer SdcaCalibrated(
|
||||
this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
|
||||
SdcaCalibratedBinaryClassificationTrainer.Options options)
|
||||
SdcaCalibratedBinaryTrainer.Options options)
|
||||
{
|
||||
Contracts.CheckValue(catalog, nameof(catalog));
|
||||
Contracts.CheckValue(options, nameof(options));
|
||||
|
||||
var env = CatalogUtils.GetEnvironment(catalog);
|
||||
return new SdcaCalibratedBinaryClassificationTrainer(env, options);
|
||||
return new SdcaCalibratedBinaryTrainer(env, options);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Predict a target using a linear classification model trained with <see cref="SdcaNonCalibratedBinaryClassificationTrainer"/>.
|
||||
/// Predict a target using a linear classification model trained with <see cref="SdcaNonCalibratedBinaryTrainer"/>.
|
||||
/// </summary>
|
||||
/// <param name="catalog">The binary classification catalog trainer object.</param>
|
||||
/// <param name="labelColumnName">The name of the label column.</param>
|
||||
|
@ -249,7 +249,7 @@ namespace Microsoft.ML
|
|||
/// [!code-csharp[SDCA](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/StochasticDualCoordinateAscentNonCalibrated.cs)]
|
||||
/// ]]></format>
|
||||
/// </example>
|
||||
public static SdcaNonCalibratedBinaryClassificationTrainer SdcaNonCalibrated(
|
||||
public static SdcaNonCalibratedBinaryTrainer SdcaNonCalibrated(
|
||||
this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
|
||||
string labelColumnName = DefaultColumnNames.Label,
|
||||
string featureColumnName = DefaultColumnNames.Features,
|
||||
|
@ -261,27 +261,27 @@ namespace Microsoft.ML
|
|||
{
|
||||
Contracts.CheckValue(catalog, nameof(catalog));
|
||||
var env = CatalogUtils.GetEnvironment(catalog);
|
||||
return new SdcaNonCalibratedBinaryClassificationTrainer(env, labelColumnName, featureColumnName, exampleWeightColumnName, loss, l2Regularization, l1Threshold, maximumNumberOfIterations);
|
||||
return new SdcaNonCalibratedBinaryTrainer(env, labelColumnName, featureColumnName, exampleWeightColumnName, loss, l2Regularization, l1Threshold, maximumNumberOfIterations);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Predict a target using a linear classification model trained with <see cref="SdcaNonCalibratedBinaryClassificationTrainer"/> and advanced options.
|
||||
/// Predict a target using a linear classification model trained with <see cref="SdcaNonCalibratedBinaryTrainer"/> and advanced options.
|
||||
/// </summary>
|
||||
/// <param name="catalog">The binary classification catalog trainer object.</param>
|
||||
/// <param name="options">Trainer options.</param>
|
||||
public static SdcaNonCalibratedBinaryClassificationTrainer SdcaNonCalibrated(
|
||||
public static SdcaNonCalibratedBinaryTrainer SdcaNonCalibrated(
|
||||
this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
|
||||
SdcaNonCalibratedBinaryClassificationTrainer.Options options)
|
||||
SdcaNonCalibratedBinaryTrainer.Options options)
|
||||
{
|
||||
Contracts.CheckValue(catalog, nameof(catalog));
|
||||
Contracts.CheckValue(options, nameof(options));
|
||||
|
||||
var env = CatalogUtils.GetEnvironment(catalog);
|
||||
return new SdcaNonCalibratedBinaryClassificationTrainer(env, options);
|
||||
return new SdcaNonCalibratedBinaryTrainer(env, options);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Predict a target using a linear multiclass classification model trained with <see cref="SdcaMulticlassClassificationTrainer"/>.
|
||||
/// Predict a target using a linear multiclass classification model trained with <see cref="SdcaMulticlassTrainer"/>.
|
||||
/// </summary>
|
||||
/// <param name="catalog">The multiclass classification catalog trainer object.</param>
|
||||
/// <param name="labelColumnName">The name of the label column.</param>
|
||||
|
@ -297,7 +297,7 @@ namespace Microsoft.ML
|
|||
/// [!code-csharp[SDCA](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/StochasticDualCoordinateAscent.cs)]
|
||||
/// ]]></format>
|
||||
/// </example>
|
||||
public static SdcaMulticlassClassificationTrainer Sdca(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
|
||||
public static SdcaMulticlassTrainer Sdca(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
|
||||
string labelColumnName = DefaultColumnNames.Label,
|
||||
string featureColumnName = DefaultColumnNames.Features,
|
||||
string exampleWeightColumnName = null,
|
||||
|
@ -308,11 +308,11 @@ namespace Microsoft.ML
|
|||
{
|
||||
Contracts.CheckValue(catalog, nameof(catalog));
|
||||
var env = CatalogUtils.GetEnvironment(catalog);
|
||||
return new SdcaMulticlassClassificationTrainer(env, labelColumnName, featureColumnName, exampleWeightColumnName, loss, l2Regularization, l1Threshold, maximumNumberOfIterations);
|
||||
return new SdcaMulticlassTrainer(env, labelColumnName, featureColumnName, exampleWeightColumnName, loss, l2Regularization, l1Threshold, maximumNumberOfIterations);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Predict a target using a linear multiclass classification model trained with <see cref="SdcaMulticlassClassificationTrainer"/> and advanced options.
|
||||
/// Predict a target using a linear multiclass classification model trained with <see cref="SdcaMulticlassTrainer"/> and advanced options.
|
||||
/// </summary>
|
||||
/// <param name="catalog">The multiclass classification catalog trainer object.</param>
|
||||
/// <param name="options">Trainer options.</param>
|
||||
|
@ -322,14 +322,14 @@ namespace Microsoft.ML
|
|||
/// [!code-csharp[SDCA](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/StochasticDualCoordinateAscentWithOptions.cs)]
|
||||
/// ]]></format>
|
||||
/// </example>
|
||||
public static SdcaMulticlassClassificationTrainer Sdca(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
|
||||
SdcaMulticlassClassificationTrainer.Options options)
|
||||
public static SdcaMulticlassTrainer Sdca(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
|
||||
SdcaMulticlassTrainer.Options options)
|
||||
{
|
||||
Contracts.CheckValue(catalog, nameof(catalog));
|
||||
Contracts.CheckValue(options, nameof(options));
|
||||
|
||||
var env = CatalogUtils.GetEnvironment(catalog);
|
||||
return new SdcaMulticlassClassificationTrainer(env, options);
|
||||
return new SdcaMulticlassTrainer(env, options);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -448,7 +448,7 @@ namespace Microsoft.ML
|
|||
}
|
||||
|
||||
/// <summary>
|
||||
/// Predict a target using a linear binary classification model trained with the <see cref="Trainers.LogisticRegressionBinaryClassificationTrainer"/> trainer.
|
||||
/// Predict a target using a linear binary classification model trained with the <see cref="Trainers.LogisticRegressionBinaryTrainer"/> trainer.
|
||||
/// </summary>
|
||||
/// <param name="catalog">The binary classification catalog trainer object.</param>
|
||||
/// <param name="labelColumnName">The name of the label column.</param>
|
||||
|
@ -457,7 +457,7 @@ namespace Microsoft.ML
|
|||
/// <param name="enforceNonNegativity">Enforce non-negative weights.</param>
|
||||
/// <param name="l1Regularization">Weight of L1 regularization term.</param>
|
||||
/// <param name="l2Regularization">Weight of L2 regularization term.</param>
|
||||
/// <param name="historySize">Memory size for <see cref="Trainers.LogisticRegressionBinaryClassificationTrainer"/>. Low=faster, less accurate.</param>
|
||||
/// <param name="historySize">Memory size for <see cref="Trainers.LogisticRegressionBinaryTrainer"/>. Low=faster, less accurate.</param>
|
||||
/// <param name="optimizationTolerance">Threshold for optimizer convergence.</param>
|
||||
/// <example>
|
||||
/// <format type="text/markdown">
|
||||
|
@ -466,7 +466,7 @@ namespace Microsoft.ML
|
|||
/// ]]>
|
||||
/// </format>
|
||||
/// </example>
|
||||
public static LogisticRegressionBinaryClassificationTrainer LogisticRegression(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
|
||||
public static LogisticRegressionBinaryTrainer LogisticRegression(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
|
||||
string labelColumnName = DefaultColumnNames.Label,
|
||||
string featureColumnName = DefaultColumnNames.Features,
|
||||
string exampleWeightColumnName = null,
|
||||
|
@ -478,21 +478,21 @@ namespace Microsoft.ML
|
|||
{
|
||||
Contracts.CheckValue(catalog, nameof(catalog));
|
||||
var env = CatalogUtils.GetEnvironment(catalog);
|
||||
return new LogisticRegressionBinaryClassificationTrainer(env, labelColumnName, featureColumnName, exampleWeightColumnName, l1Regularization, l2Regularization, optimizationTolerance, historySize, enforceNonNegativity);
|
||||
return new LogisticRegressionBinaryTrainer(env, labelColumnName, featureColumnName, exampleWeightColumnName, l1Regularization, l2Regularization, optimizationTolerance, historySize, enforceNonNegativity);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Predict a target using a linear binary classification model trained with the <see cref="Trainers.LogisticRegressionBinaryClassificationTrainer"/> trainer.
|
||||
/// Predict a target using a linear binary classification model trained with the <see cref="Trainers.LogisticRegressionBinaryTrainer"/> trainer.
|
||||
/// </summary>
|
||||
/// <param name="catalog">The binary classification catalog trainer object.</param>
|
||||
/// <param name="options">Advanced arguments to the algorithm.</param>
|
||||
public static LogisticRegressionBinaryClassificationTrainer LogisticRegression(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog, LROptions options)
|
||||
public static LogisticRegressionBinaryTrainer LogisticRegression(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog, LROptions options)
|
||||
{
|
||||
Contracts.CheckValue(catalog, nameof(catalog));
|
||||
Contracts.CheckValue(options, nameof(options));
|
||||
|
||||
var env = CatalogUtils.GetEnvironment(catalog);
|
||||
return new LogisticRegressionBinaryClassificationTrainer(env, options);
|
||||
return new LogisticRegressionBinaryTrainer(env, options);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -579,18 +579,18 @@ namespace Microsoft.ML
|
|||
}
|
||||
|
||||
/// <summary>
|
||||
/// Predicts a target using a linear multiclass classification model trained with the <see cref="NaiveBayesTrainer"/>.
|
||||
/// The <see cref="NaiveBayesTrainer"/> trains a multiclass Naive Bayes predictor that supports binary feature values.
|
||||
/// Predicts a target using a linear multiclass classification model trained with the <see cref="NaiveBayesMulticlassTrainer"/>.
|
||||
/// The <see cref="NaiveBayesMulticlassTrainer"/> trains a multiclass Naive Bayes predictor that supports binary feature values.
|
||||
/// </summary>
|
||||
/// <param name="catalog">The <see cref="MulticlassClassificationCatalog.MulticlassClassificationTrainers"/>.</param>
|
||||
/// <param name="labelColumnName">The name of the label column.</param>
|
||||
/// <param name="featureColumnName">The name of the feature column.</param>
|
||||
public static NaiveBayesTrainer NaiveBayes(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
|
||||
public static NaiveBayesMulticlassTrainer NaiveBayes(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
|
||||
string labelColumnName = DefaultColumnNames.Label,
|
||||
string featureColumnName = DefaultColumnNames.Features)
|
||||
{
|
||||
Contracts.CheckValue(catalog, nameof(catalog));
|
||||
return new NaiveBayesTrainer(CatalogUtils.GetEnvironment(catalog), labelColumnName, featureColumnName);
|
||||
return new NaiveBayesMulticlassTrainer(CatalogUtils.GetEnvironment(catalog), labelColumnName, featureColumnName);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
|
|
@ -9,7 +9,7 @@ using Microsoft.ML.Trainers;
|
|||
|
||||
namespace Microsoft.ML.StaticPipe
|
||||
{
|
||||
using Options = LogisticRegressionBinaryClassificationTrainer.Options;
|
||||
using Options = LogisticRegressionBinaryTrainer.Options;
|
||||
|
||||
/// <summary>
|
||||
/// Binary Classification trainer estimators.
|
||||
|
@ -17,7 +17,7 @@ namespace Microsoft.ML.StaticPipe
|
|||
public static class LbfgsBinaryClassificationStaticExtensions
|
||||
{
|
||||
/// <summary>
|
||||
/// Predict a target using a linear binary classification model trained with the <see cref="Microsoft.ML.Trainers.LogisticRegressionBinaryClassificationTrainer"/> trainer.
|
||||
/// Predict a target using a linear binary classification model trained with the <see cref="Microsoft.ML.Trainers.LogisticRegressionBinaryTrainer"/> trainer.
|
||||
/// </summary>
|
||||
/// <param name="catalog">The binary classification catalog trainer object.</param>
|
||||
/// <param name="label">The label, or dependent variable.</param>
|
||||
|
@ -26,7 +26,7 @@ namespace Microsoft.ML.StaticPipe
|
|||
/// <param name="enforceNonNegativity">Enforce non-negative weights.</param>
|
||||
/// <param name="l1Regularization">Weight of L1 regularization term.</param>
|
||||
/// <param name="l2Regularization">Weight of L2 regularization term.</param>
|
||||
/// <param name="historySize">Memory size for <see cref="Microsoft.ML.Trainers.LogisticRegressionBinaryClassificationTrainer"/>. Low=faster, less accurate.</param>
|
||||
/// <param name="historySize">Memory size for <see cref="Microsoft.ML.Trainers.LogisticRegressionBinaryTrainer"/>. Low=faster, less accurate.</param>
|
||||
/// <param name="optimizationTolerance">Threshold for optimizer convergence.</param>
|
||||
/// <param name="onFit">A delegate that is called every time the
|
||||
/// <see cref="Estimator{TInShape, TOutShape, TTransformer}.Fit(DataView{TInShape})"/> method is called on the
|
||||
|
@ -50,7 +50,7 @@ namespace Microsoft.ML.StaticPipe
|
|||
var rec = new TrainerEstimatorReconciler.BinaryClassifier(
|
||||
(env, labelName, featuresName, weightsName) =>
|
||||
{
|
||||
var trainer = new LogisticRegressionBinaryClassificationTrainer(env, labelName, featuresName, weightsName,
|
||||
var trainer = new LogisticRegressionBinaryTrainer(env, labelName, featuresName, weightsName,
|
||||
l1Regularization, l2Regularization, optimizationTolerance, historySize, enforceNonNegativity);
|
||||
|
||||
if (onFit != null)
|
||||
|
@ -63,7 +63,7 @@ namespace Microsoft.ML.StaticPipe
|
|||
}
|
||||
|
||||
/// <summary>
|
||||
/// Predict a target using a linear binary classification model trained with the <see cref="Microsoft.ML.Trainers.LogisticRegressionBinaryClassificationTrainer"/> trainer.
|
||||
/// Predict a target using a linear binary classification model trained with the <see cref="Microsoft.ML.Trainers.LogisticRegressionBinaryTrainer"/> trainer.
|
||||
/// </summary>
|
||||
/// <param name="catalog">The binary classification catalog trainer object.</param>
|
||||
/// <param name="label">The label, or dependent variable.</param>
|
||||
|
@ -95,7 +95,7 @@ namespace Microsoft.ML.StaticPipe
|
|||
options.FeatureColumnName = featuresName;
|
||||
options.ExampleWeightColumnName = weightsName;
|
||||
|
||||
var trainer = new LogisticRegressionBinaryClassificationTrainer(env, options);
|
||||
var trainer = new LogisticRegressionBinaryTrainer(env, options);
|
||||
|
||||
if (onFit != null)
|
||||
return trainer.WithOnFitDelegate(trans => onFit(trans.Model));
|
||||
|
@ -113,7 +113,7 @@ namespace Microsoft.ML.StaticPipe
|
|||
public static class LbfgsRegressionExtensions
|
||||
{
|
||||
/// <summary>
|
||||
/// Predict a target using a linear regression model trained with the <see cref="Microsoft.ML.Trainers.LogisticRegressionBinaryClassificationTrainer"/> trainer.
|
||||
/// Predict a target using a linear regression model trained with the <see cref="Microsoft.ML.Trainers.LogisticRegressionBinaryTrainer"/> trainer.
|
||||
/// </summary>
|
||||
/// <param name="catalog">The regression catalog trainer object.</param>
|
||||
/// <param name="label">The label, or dependent variable.</param>
|
||||
|
@ -122,7 +122,7 @@ namespace Microsoft.ML.StaticPipe
|
|||
/// <param name="enforceNonNegativity">Enforce non-negative weights.</param>
|
||||
/// <param name="l1Regularization">Weight of L1 regularization term.</param>
|
||||
/// <param name="l2Regularization">Weight of L2 regularization term.</param>
|
||||
/// <param name="historySize">Memory size for <see cref="Microsoft.ML.Trainers.LogisticRegressionBinaryClassificationTrainer"/>. Low=faster, less accurate.</param>
|
||||
/// <param name="historySize">Memory size for <see cref="Microsoft.ML.Trainers.LogisticRegressionBinaryTrainer"/>. Low=faster, less accurate.</param>
|
||||
/// <param name="optimizationTolerance">Threshold for optimizer convergence.</param>
|
||||
/// <param name="onFit">A delegate that is called every time the
|
||||
/// <see cref="Estimator{TInShape, TOutShape, TTransformer}.Fit(DataView{TInShape})"/> method is called on the
|
||||
|
@ -159,7 +159,7 @@ namespace Microsoft.ML.StaticPipe
|
|||
}
|
||||
|
||||
/// <summary>
|
||||
/// Predict a target using a linear regression model trained with the <see cref="Microsoft.ML.Trainers.LogisticRegressionBinaryClassificationTrainer"/> trainer.
|
||||
/// Predict a target using a linear regression model trained with the <see cref="Microsoft.ML.Trainers.LogisticRegressionBinaryTrainer"/> trainer.
|
||||
/// </summary>
|
||||
/// <param name="catalog">The regression catalog trainer object.</param>
|
||||
/// <param name="label">The label, or dependent variable.</param>
|
||||
|
@ -218,7 +218,7 @@ namespace Microsoft.ML.StaticPipe
|
|||
/// <param name="enforceNonNegativity">Enforce non-negative weights.</param>
|
||||
/// <param name="l1Regularization">Weight of L1 regularization term.</param>
|
||||
/// <param name="l2Regularization">Weight of L2 regularization term.</param>
|
||||
/// <param name="historySize">Memory size for <see cref="Microsoft.ML.Trainers.LogisticRegressionBinaryClassificationTrainer"/>. Low=faster, less accurate.</param>
|
||||
/// <param name="historySize">Memory size for <see cref="Microsoft.ML.Trainers.LogisticRegressionBinaryTrainer"/>. Low=faster, less accurate.</param>
|
||||
/// <param name="optimizationTolerance">Threshold for optimizer convergence.</param>
|
||||
/// <param name="onFit">A delegate that is called every time the
|
||||
/// <see cref="Estimator{TInShape, TOutShape, TTransformer}.Fit(DataView{TInShape})"/> method is called on the
|
||||
|
|
|
@ -29,7 +29,7 @@ namespace Microsoft.ML.StaticPipe
|
|||
MulticlassNaiveBayesTrainer<TVal>(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
|
||||
Key<uint, TVal> label,
|
||||
Vector<float> features,
|
||||
Action<MulticlassNaiveBayesModelParameters> onFit = null)
|
||||
Action<NaiveBayesMulticlassModelParameters> onFit = null)
|
||||
{
|
||||
Contracts.CheckValue(features, nameof(features));
|
||||
Contracts.CheckValue(label, nameof(label));
|
||||
|
@ -38,7 +38,7 @@ namespace Microsoft.ML.StaticPipe
|
|||
var rec = new TrainerEstimatorReconciler.MulticlassClassificationReconciler<TVal>(
|
||||
(env, labelName, featuresName, weightsName) =>
|
||||
{
|
||||
var trainer = new NaiveBayesTrainer(env, labelName, featuresName);
|
||||
var trainer = new NaiveBayesMulticlassTrainer(env, labelName, featuresName);
|
||||
|
||||
if (onFit != null)
|
||||
return trainer.WithOnFitDelegate(trans => onFit(trans.Model));
|
||||
|
|
|
@ -154,7 +154,7 @@ namespace Microsoft.ML.StaticPipe
|
|||
var rec = new TrainerEstimatorReconciler.BinaryClassifier(
|
||||
(env, labelName, featuresName, weightsName) =>
|
||||
{
|
||||
var trainer = new SdcaCalibratedBinaryClassificationTrainer(env, labelName, featuresName, weightsName, l2Regularization, l1Threshold, numberOfIterations);
|
||||
var trainer = new SdcaCalibratedBinaryTrainer(env, labelName, featuresName, weightsName, l2Regularization, l1Threshold, numberOfIterations);
|
||||
if (onFit != null)
|
||||
{
|
||||
return trainer.WithOnFitDelegate(trans =>
|
||||
|
@ -192,7 +192,7 @@ namespace Microsoft.ML.StaticPipe
|
|||
public static (Scalar<float> score, Scalar<float> probability, Scalar<bool> predictedLabel) Sdca(
|
||||
this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
|
||||
Scalar<bool> label, Vector<float> features, Scalar<float> weights,
|
||||
SdcaCalibratedBinaryClassificationTrainer.Options options,
|
||||
SdcaCalibratedBinaryTrainer.Options options,
|
||||
Action<CalibratedModelParametersBase<LinearBinaryModelParameters, PlattCalibrator>> onFit = null)
|
||||
{
|
||||
Contracts.CheckValue(label, nameof(label));
|
||||
|
@ -207,7 +207,7 @@ namespace Microsoft.ML.StaticPipe
|
|||
options.LabelColumnName = labelName;
|
||||
options.FeatureColumnName = featuresName;
|
||||
|
||||
var trainer = new SdcaCalibratedBinaryClassificationTrainer(env, options);
|
||||
var trainer = new SdcaCalibratedBinaryTrainer(env, options);
|
||||
if (onFit != null)
|
||||
{
|
||||
return trainer.WithOnFitDelegate(trans =>
|
||||
|
@ -263,7 +263,7 @@ namespace Microsoft.ML.StaticPipe
|
|||
var rec = new TrainerEstimatorReconciler.BinaryClassifierNoCalibration(
|
||||
(env, labelName, featuresName, weightsName) =>
|
||||
{
|
||||
var trainer = new SdcaNonCalibratedBinaryClassificationTrainer(env, labelName, featuresName, weightsName, loss, l2Regularization, l1Threshold, numberOfIterations);
|
||||
var trainer = new SdcaNonCalibratedBinaryTrainer(env, labelName, featuresName, weightsName, loss, l2Regularization, l1Threshold, numberOfIterations);
|
||||
if (onFit != null)
|
||||
{
|
||||
return trainer.WithOnFitDelegate(trans =>
|
||||
|
@ -299,7 +299,7 @@ namespace Microsoft.ML.StaticPipe
|
|||
this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
|
||||
Scalar<bool> label, Vector<float> features, Scalar<float> weights,
|
||||
ISupportSdcaClassificationLoss loss,
|
||||
SdcaNonCalibratedBinaryClassificationTrainer.Options options,
|
||||
SdcaNonCalibratedBinaryTrainer.Options options,
|
||||
Action<LinearBinaryModelParameters> onFit = null)
|
||||
{
|
||||
Contracts.CheckValue(label, nameof(label));
|
||||
|
@ -314,7 +314,7 @@ namespace Microsoft.ML.StaticPipe
|
|||
options.FeatureColumnName = featuresName;
|
||||
options.LabelColumnName = labelName;
|
||||
|
||||
var trainer = new SdcaNonCalibratedBinaryClassificationTrainer(env, options);
|
||||
var trainer = new SdcaNonCalibratedBinaryTrainer(env, options);
|
||||
if (onFit != null)
|
||||
{
|
||||
return trainer.WithOnFitDelegate(trans =>
|
||||
|
@ -368,7 +368,7 @@ namespace Microsoft.ML.StaticPipe
|
|||
var rec = new TrainerEstimatorReconciler.MulticlassClassificationReconciler<TVal>(
|
||||
(env, labelName, featuresName, weightsName) =>
|
||||
{
|
||||
var trainer = new SdcaMulticlassClassificationTrainer(env, labelName, featuresName, weightsName, loss, l2Regularization, l1Threshold, numberOfIterations);
|
||||
var trainer = new SdcaMulticlassTrainer(env, labelName, featuresName, weightsName, loss, l2Regularization, l1Threshold, numberOfIterations);
|
||||
if (onFit != null)
|
||||
return trainer.WithOnFitDelegate(trans => onFit(trans.Model));
|
||||
return trainer;
|
||||
|
@ -396,7 +396,7 @@ namespace Microsoft.ML.StaticPipe
|
|||
Key<uint, TVal> label,
|
||||
Vector<float> features,
|
||||
Scalar<float> weights,
|
||||
SdcaMulticlassClassificationTrainer.Options options,
|
||||
SdcaMulticlassTrainer.Options options,
|
||||
Action<MulticlassLogisticRegressionModelParameters> onFit = null)
|
||||
{
|
||||
Contracts.CheckValue(label, nameof(label));
|
||||
|
@ -411,7 +411,7 @@ namespace Microsoft.ML.StaticPipe
|
|||
options.LabelColumnName = labelName;
|
||||
options.FeatureColumnName = featuresName;
|
||||
|
||||
var trainer = new SdcaMulticlassClassificationTrainer(env, options);
|
||||
var trainer = new SdcaMulticlassTrainer(env, options);
|
||||
if (onFit != null)
|
||||
return trainer.WithOnFitDelegate(trans => onFit(trans.Model));
|
||||
return trainer;
|
||||
|
|
|
@ -108,7 +108,7 @@ namespace Microsoft.ML.StaticPipe
|
|||
|
||||
/// <summary>
|
||||
/// FastTree <see cref="BinaryClassificationCatalog"/> extension method.
|
||||
/// Predict a target using a decision tree binary classification model trained with the <see cref="FastTreeBinaryClassificationTrainer"/>.
|
||||
/// Predict a target using a decision tree binary classification model trained with the <see cref="FastTreeBinaryTrainer"/>.
|
||||
/// </summary>
|
||||
/// <param name="catalog">The <see cref="BinaryClassificationCatalog"/>.</param>
|
||||
/// <param name="label">The label column.</param>
|
||||
|
@ -144,7 +144,7 @@ namespace Microsoft.ML.StaticPipe
|
|||
var rec = new TrainerEstimatorReconciler.BinaryClassifier(
|
||||
(env, labelName, featuresName, weightsName) =>
|
||||
{
|
||||
var trainer = new FastTreeBinaryClassificationTrainer(env, labelName, featuresName, weightsName, numberOfLeaves,
|
||||
var trainer = new FastTreeBinaryTrainer(env, labelName, featuresName, weightsName, numberOfLeaves,
|
||||
numberOfTrees, minimumExampleCountPerLeaf, learningRate);
|
||||
|
||||
if (onFit != null)
|
||||
|
@ -158,7 +158,7 @@ namespace Microsoft.ML.StaticPipe
|
|||
|
||||
/// <summary>
|
||||
/// FastTree <see cref="BinaryClassificationCatalog"/> extension method.
|
||||
/// Predict a target using a decision tree binary classification model trained with the <see cref="FastTreeBinaryClassificationTrainer"/>.
|
||||
/// Predict a target using a decision tree binary classification model trained with the <see cref="FastTreeBinaryTrainer"/>.
|
||||
/// </summary>
|
||||
/// <param name="catalog">The <see cref="BinaryClassificationCatalog"/>.</param>
|
||||
/// <param name="label">The label column.</param>
|
||||
|
@ -180,7 +180,7 @@ namespace Microsoft.ML.StaticPipe
|
|||
/// </example>
|
||||
public static (Scalar<float> score, Scalar<float> probability, Scalar<bool> predictedLabel) FastTree(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
|
||||
Scalar<bool> label, Vector<float> features, Scalar<float> weights,
|
||||
FastTreeBinaryClassificationTrainer.Options options,
|
||||
FastTreeBinaryTrainer.Options options,
|
||||
Action<CalibratedModelParametersBase<FastTreeBinaryModelParameters, PlattCalibrator>> onFit = null)
|
||||
{
|
||||
Contracts.CheckValueOrNull(options);
|
||||
|
@ -193,7 +193,7 @@ namespace Microsoft.ML.StaticPipe
|
|||
options.FeatureColumnName = featuresName;
|
||||
options.ExampleWeightColumnName = weightsName;
|
||||
|
||||
var trainer = new FastTreeBinaryClassificationTrainer(env, options);
|
||||
var trainer = new FastTreeBinaryTrainer(env, options);
|
||||
|
||||
if (onFit != null)
|
||||
return trainer.WithOnFitDelegate(trans => onFit(trans.Model));
|
||||
|
|
|
@ -43,14 +43,14 @@ Trainers.AveragedPerceptronBinaryClassifier Averaged Perceptron Binary Classifie
|
|||
Trainers.EnsembleBinaryClassifier Train binary ensemble. Microsoft.ML.Trainers.Ensemble.Ensemble CreateBinaryEnsemble Microsoft.ML.Trainers.Ensemble.EnsembleTrainer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput
|
||||
Trainers.EnsembleClassification Train multiclass ensemble. Microsoft.ML.Trainers.Ensemble.Ensemble CreateMulticlassEnsemble Microsoft.ML.Trainers.Ensemble.MulticlassDataPartitionEnsembleTrainer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+MulticlassClassificationOutput
|
||||
Trainers.EnsembleRegression Train regression ensemble. Microsoft.ML.Trainers.Ensemble.Ensemble CreateRegressionEnsemble Microsoft.ML.Trainers.Ensemble.RegressionEnsembleTrainer+Arguments Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput
|
||||
Trainers.FastForestBinaryClassifier Uses a random forest learner to perform binary classification. Microsoft.ML.Trainers.FastTree.FastForest TrainBinary Microsoft.ML.Trainers.FastTree.FastForestBinaryClassificationTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput
|
||||
Trainers.FastForestBinaryClassifier Uses a random forest learner to perform binary classification. Microsoft.ML.Trainers.FastTree.FastForest TrainBinary Microsoft.ML.Trainers.FastTree.FastForestBinaryTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput
|
||||
Trainers.FastForestRegressor Trains a random forest to fit target values using least-squares. Microsoft.ML.Trainers.FastTree.FastForest TrainRegression Microsoft.ML.Trainers.FastTree.FastForestRegressionTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput
|
||||
Trainers.FastTreeBinaryClassifier Uses a logit-boost boosted tree learner to perform binary classification. Microsoft.ML.Trainers.FastTree.FastTree TrainBinary Microsoft.ML.Trainers.FastTree.FastTreeBinaryClassificationTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput
|
||||
Trainers.FastTreeBinaryClassifier Uses a logit-boost boosted tree learner to perform binary classification. Microsoft.ML.Trainers.FastTree.FastTree TrainBinary Microsoft.ML.Trainers.FastTree.FastTreeBinaryTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput
|
||||
Trainers.FastTreeRanker Trains gradient boosted decision trees to the LambdaRank quasi-gradient. Microsoft.ML.Trainers.FastTree.FastTree TrainRanking Microsoft.ML.Trainers.FastTree.FastTreeRankingTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+RankingOutput
|
||||
Trainers.FastTreeRegressor Trains gradient boosted decision trees to fit target values using least-squares. Microsoft.ML.Trainers.FastTree.FastTree TrainRegression Microsoft.ML.Trainers.FastTree.FastTreeRegressionTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput
|
||||
Trainers.FastTreeTweedieRegressor Trains gradient boosted decision trees to fit target values using a Tweedie loss function. This learner is a generalization of Poisson, compound Poisson, and gamma regression. Microsoft.ML.Trainers.FastTree.FastTree TrainTweedieRegression Microsoft.ML.Trainers.FastTree.FastTreeTweedieTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput
|
||||
Trainers.FieldAwareFactorizationMachineBinaryClassifier Train a field-aware factorization machine for binary classification Microsoft.ML.Trainers.FieldAwareFactorizationMachineTrainer TrainBinary Microsoft.ML.Trainers.FieldAwareFactorizationMachineTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput
|
||||
Trainers.GeneralizedAdditiveModelBinaryClassifier Trains a gradient boosted stump per feature, on all features simultaneously, to fit target values using least-squares. It mantains no interactions between features. Microsoft.ML.Trainers.FastTree.Gam TrainBinary Microsoft.ML.Trainers.FastTree.GamBinaryClassificationTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput
|
||||
Trainers.GeneralizedAdditiveModelBinaryClassifier Trains a gradient boosted stump per feature, on all features simultaneously, to fit target values using least-squares. It mantains no interactions between features. Microsoft.ML.Trainers.FastTree.Gam TrainBinary Microsoft.ML.Trainers.FastTree.GamBinaryTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput
|
||||
Trainers.GeneralizedAdditiveModelRegressor Trains a gradient boosted stump per feature, on all features simultaneously, to fit target values using least-squares. It mantains no interactions between features. Microsoft.ML.Trainers.FastTree.Gam TrainRegression Microsoft.ML.Trainers.FastTree.GamRegressionTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput
|
||||
Trainers.KMeansPlusPlusClusterer K-means is a popular clustering algorithm. With K-means, the data is clustered into a specified number of clusters in order to minimize the within-cluster sum of squares. K-means++ improves upon K-means by using a better method for choosing the initial cluster centers. Microsoft.ML.Trainers.KMeansTrainer TrainKMeans Microsoft.ML.Trainers.KMeansTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+ClusteringOutput
|
||||
Trainers.LightGbmBinaryClassifier Train a LightGBM binary classification model. Microsoft.ML.Trainers.LightGbm.LightGbm TrainBinary Microsoft.ML.Trainers.LightGbm.Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput
|
||||
|
@ -58,15 +58,15 @@ Trainers.LightGbmClassifier Train a LightGBM multi class model. Microsoft.ML.Tra
|
|||
Trainers.LightGbmRanker Train a LightGBM ranking model. Microsoft.ML.Trainers.LightGbm.LightGbm TrainRanking Microsoft.ML.Trainers.LightGbm.Options Microsoft.ML.EntryPoints.CommonOutputs+RankingOutput
|
||||
Trainers.LightGbmRegressor LightGBM Regression Microsoft.ML.Trainers.LightGbm.LightGbm TrainRegression Microsoft.ML.Trainers.LightGbm.Options Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput
|
||||
Trainers.LinearSvmBinaryClassifier Train a linear SVM. Microsoft.ML.Trainers.LinearSvmTrainer TrainLinearSvm Microsoft.ML.Trainers.LinearSvmTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput
|
||||
Trainers.LogisticRegressionBinaryClassifier Logistic Regression is a method in statistics used to predict the probability of occurrence of an event and can be used as a classification algorithm. The algorithm predicts the probability of occurrence of an event by fitting data to a logistical function. Microsoft.ML.Trainers.LogisticRegressionBinaryClassificationTrainer TrainBinary Microsoft.ML.Trainers.LogisticRegressionBinaryClassificationTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput
|
||||
Trainers.LogisticRegressionClassifier Logistic Regression is a method in statistics used to predict the probability of occurrence of an event and can be used as a classification algorithm. The algorithm predicts the probability of occurrence of an event by fitting data to a logistical function. Microsoft.ML.Trainers.LogisticRegressionBinaryClassificationTrainer TrainMulticlass Microsoft.ML.Trainers.LogisticRegressionMulticlassClassificationTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+MulticlassClassificationOutput
|
||||
Trainers.NaiveBayesClassifier Train a MulticlassNaiveBayesTrainer. Microsoft.ML.Trainers.NaiveBayesTrainer TrainMulticlassNaiveBayesTrainer Microsoft.ML.Trainers.NaiveBayesTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+MulticlassClassificationOutput
|
||||
Trainers.LogisticRegressionBinaryClassifier Logistic Regression is a method in statistics used to predict the probability of occurrence of an event and can be used as a classification algorithm. The algorithm predicts the probability of occurrence of an event by fitting data to a logistical function. Microsoft.ML.Trainers.LogisticRegressionBinaryTrainer TrainBinary Microsoft.ML.Trainers.LogisticRegressionBinaryTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput
|
||||
Trainers.LogisticRegressionClassifier Logistic Regression is a method in statistics used to predict the probability of occurrence of an event and can be used as a classification algorithm. The algorithm predicts the probability of occurrence of an event by fitting data to a logistical function. Microsoft.ML.Trainers.LogisticRegressionBinaryTrainer TrainMulticlass Microsoft.ML.Trainers.LogisticRegressionMulticlassClassificationTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+MulticlassClassificationOutput
|
||||
Trainers.NaiveBayesClassifier Train a MulticlassNaiveBayesTrainer. Microsoft.ML.Trainers.NaiveBayesMulticlassTrainer TrainMulticlassNaiveBayesTrainer Microsoft.ML.Trainers.NaiveBayesMulticlassTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+MulticlassClassificationOutput
|
||||
Trainers.OnlineGradientDescentRegressor Train a Online gradient descent perceptron. Microsoft.ML.Trainers.OnlineGradientDescentTrainer TrainRegression Microsoft.ML.Trainers.OnlineGradientDescentTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput
|
||||
Trainers.OrdinaryLeastSquaresRegressor Train an OLS regression model. Microsoft.ML.Trainers.OlsTrainer TrainRegression Microsoft.ML.Trainers.OlsTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput
|
||||
Trainers.PcaAnomalyDetector Train an PCA Anomaly model. Microsoft.ML.Trainers.RandomizedPcaTrainer TrainPcaAnomaly Microsoft.ML.Trainers.RandomizedPcaTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+AnomalyDetectionOutput
|
||||
Trainers.PoissonRegressor Train an Poisson regression model. Microsoft.ML.Trainers.PoissonRegressionTrainer TrainRegression Microsoft.ML.Trainers.PoissonRegressionTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput
|
||||
Trainers.StochasticDualCoordinateAscentBinaryClassifier Train an SDCA binary model. Microsoft.ML.Trainers.Sdca TrainBinary Microsoft.ML.Trainers.LegacySdcaBinaryTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput
|
||||
Trainers.StochasticDualCoordinateAscentClassifier The SDCA linear multi-class classification trainer. Microsoft.ML.Trainers.Sdca TrainMulticlass Microsoft.ML.Trainers.SdcaMulticlassClassificationTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+MulticlassClassificationOutput
|
||||
Trainers.StochasticDualCoordinateAscentClassifier The SDCA linear multi-class classification trainer. Microsoft.ML.Trainers.Sdca TrainMulticlass Microsoft.ML.Trainers.SdcaMulticlassTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+MulticlassClassificationOutput
|
||||
Trainers.StochasticDualCoordinateAscentRegressor The SDCA linear regression trainer. Microsoft.ML.Trainers.Sdca TrainRegression Microsoft.ML.Trainers.SdcaRegressionTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+RegressionOutput
|
||||
Trainers.StochasticGradientDescentBinaryClassifier Train an Hogwild SGD binary model. Microsoft.ML.Trainers.LegacySgdBinaryTrainer TrainBinary Microsoft.ML.Trainers.LegacySgdBinaryTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput
|
||||
Trainers.SymSgdBinaryClassifier Train a symbolic SGD. Microsoft.ML.Trainers.SymbolicSgdTrainer TrainSymSgd Microsoft.ML.Trainers.SymbolicSgdTrainer+Options Microsoft.ML.EntryPoints.CommonOutputs+BinaryClassificationOutput
|
||||
|
|
|
|
@ -11356,7 +11356,7 @@
|
|||
{
|
||||
"Name": "Sigmoid",
|
||||
"Type": "Float",
|
||||
"Desc": "Parameter for the sigmoid function. Used only in LightGbmBinaryClassificationTrainer, LightGbmMulticlassClassificationTrainer and in LightGbmRankingTrainer.",
|
||||
"Desc": "Parameter for the sigmoid function. Used only in LightGbmBinaryTrainer, LightGbmMulticlassTrainer and in LightGbmRankingTrainer.",
|
||||
"Aliases": [
|
||||
"sigmoid"
|
||||
],
|
||||
|
@ -11859,7 +11859,7 @@
|
|||
{
|
||||
"Name": "Sigmoid",
|
||||
"Type": "Float",
|
||||
"Desc": "Parameter for the sigmoid function. Used only in LightGbmBinaryClassificationTrainer, LightGbmMulticlassClassificationTrainer and in LightGbmRankingTrainer.",
|
||||
"Desc": "Parameter for the sigmoid function. Used only in LightGbmBinaryTrainer, LightGbmMulticlassTrainer and in LightGbmRankingTrainer.",
|
||||
"Aliases": [
|
||||
"sigmoid"
|
||||
],
|
||||
|
@ -12362,7 +12362,7 @@
|
|||
{
|
||||
"Name": "Sigmoid",
|
||||
"Type": "Float",
|
||||
"Desc": "Parameter for the sigmoid function. Used only in LightGbmBinaryClassificationTrainer, LightGbmMulticlassClassificationTrainer and in LightGbmRankingTrainer.",
|
||||
"Desc": "Parameter for the sigmoid function. Used only in LightGbmBinaryTrainer, LightGbmMulticlassTrainer and in LightGbmRankingTrainer.",
|
||||
"Aliases": [
|
||||
"sigmoid"
|
||||
],
|
||||
|
@ -12865,7 +12865,7 @@
|
|||
{
|
||||
"Name": "Sigmoid",
|
||||
"Type": "Float",
|
||||
"Desc": "Parameter for the sigmoid function. Used only in LightGbmBinaryClassificationTrainer, LightGbmMulticlassClassificationTrainer and in LightGbmRankingTrainer.",
|
||||
"Desc": "Parameter for the sigmoid function. Used only in LightGbmBinaryTrainer, LightGbmMulticlassTrainer and in LightGbmRankingTrainer.",
|
||||
"Aliases": [
|
||||
"sigmoid"
|
||||
],
|
||||
|
|
|
@ -40,7 +40,7 @@ namespace Microsoft.ML.Benchmarks
|
|||
.Append(ml.Clustering.Trainers.KMeans("Features"))
|
||||
.Append(ml.Transforms.Concatenate("Features", "Features", "Score"))
|
||||
.Append(ml.BinaryClassification.Trainers.LogisticRegression(
|
||||
new LogisticRegressionBinaryClassificationTrainer.Options { EnforceNonNegativity = true, OptmizationTolerance = 1e-3f, }));
|
||||
new LogisticRegressionBinaryTrainer.Options { EnforceNonNegativity = true, OptmizationTolerance = 1e-3f, }));
|
||||
|
||||
var model = estimatorPipeline.Fit(input);
|
||||
// Return the last model in the chain.
|
||||
|
|
|
@ -58,7 +58,7 @@ namespace Microsoft.ML.Benchmarks
|
|||
" xf=NAHandleTransform{col=Features}" +
|
||||
" tr=LightGBMRanking{}";
|
||||
|
||||
var environment = EnvironmentFactory.CreateRankingEnvironment<RankingEvaluator, TextLoader, HashingTransformer, LightGbmMulticlassClassificationTrainer, OneVersusAllModelParameters>();
|
||||
var environment = EnvironmentFactory.CreateRankingEnvironment<RankingEvaluator, TextLoader, HashingTransformer, LightGbmMulticlassTrainer, OneVersusAllModelParameters>();
|
||||
cmd.ExecuteMamlCommand(environment);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -58,7 +58,7 @@ namespace Microsoft.ML.Benchmarks
|
|||
var pipeline = new ColumnConcatenatingEstimator(env, "Features", new[] { "SepalLength", "SepalWidth", "PetalLength", "PetalWidth" })
|
||||
.Append(env.Transforms.Conversion.MapValueToKey("Label"))
|
||||
.Append(env.MulticlassClassification.Trainers.Sdca(
|
||||
new SdcaMulticlassClassificationTrainer.Options { NumberOfThreads = 1, ConvergenceTolerance = 1e-2f, }));
|
||||
new SdcaMulticlassTrainer.Options { NumberOfThreads = 1, ConvergenceTolerance = 1e-2f, }));
|
||||
|
||||
var model = pipeline.Fit(data);
|
||||
|
||||
|
@ -93,7 +93,7 @@ namespace Microsoft.ML.Benchmarks
|
|||
|
||||
var pipeline = mlContext.Transforms.Text.FeaturizeText("Features", "SentimentText")
|
||||
.Append(mlContext.BinaryClassification.Trainers.SdcaNonCalibrated(
|
||||
new SdcaNonCalibratedBinaryClassificationTrainer.Options { NumberOfThreads = 1, ConvergenceTolerance = 1e-2f, }));
|
||||
new SdcaNonCalibratedBinaryTrainer.Options { NumberOfThreads = 1, ConvergenceTolerance = 1e-2f, }));
|
||||
|
||||
var model = pipeline.Fit(data);
|
||||
|
||||
|
@ -127,7 +127,7 @@ namespace Microsoft.ML.Benchmarks
|
|||
IDataView data = loader.Load(_breastCancerDataPath);
|
||||
|
||||
var pipeline = env.BinaryClassification.Trainers.SdcaNonCalibrated(
|
||||
new SdcaNonCalibratedBinaryClassificationTrainer.Options { NumberOfThreads = 1, ConvergenceTolerance = 1e-2f, });
|
||||
new SdcaNonCalibratedBinaryTrainer.Options { NumberOfThreads = 1, ConvergenceTolerance = 1e-2f, });
|
||||
|
||||
var model = pipeline.Fit(data);
|
||||
|
||||
|
|
|
@ -53,7 +53,7 @@ namespace Microsoft.ML.Benchmarks
|
|||
" xf=Concat{col=Features:FeaturesText,logged_in,ns}" +
|
||||
" tr=LightGBMMulticlass{iter=10}";
|
||||
|
||||
var environment = EnvironmentFactory.CreateClassificationEnvironment<TextLoader, OneHotEncodingTransformer, LightGbmMulticlassClassificationTrainer, OneVersusAllModelParameters>();
|
||||
var environment = EnvironmentFactory.CreateClassificationEnvironment<TextLoader, OneHotEncodingTransformer, LightGbmMulticlassTrainer, OneVersusAllModelParameters>();
|
||||
cmd.ExecuteMamlCommand(environment);
|
||||
}
|
||||
|
||||
|
@ -85,7 +85,7 @@ namespace Microsoft.ML.Benchmarks
|
|||
" xf=WordEmbeddingsTransform{col=FeaturesWordEmbedding:FeaturesText_TransformedText model=FastTextWikipedia300D}" +
|
||||
" xf=Concat{col=Features:FeaturesWordEmbedding,logged_in,ns}";
|
||||
|
||||
var environment = EnvironmentFactory.CreateClassificationEnvironment<TextLoader, OneHotEncodingTransformer, SdcaMulticlassClassificationTrainer, MulticlassLogisticRegressionModelParameters>();
|
||||
var environment = EnvironmentFactory.CreateClassificationEnvironment<TextLoader, OneHotEncodingTransformer, SdcaMulticlassTrainer, MulticlassLogisticRegressionModelParameters>();
|
||||
cmd.ExecuteMamlCommand(environment);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -132,7 +132,7 @@ namespace Microsoft.ML.RunTests
|
|||
var dataView = GetBreastCancerDataviewWithTextColumns();
|
||||
dataView = Env.CreateTransform("Term{col=F1}", dataView);
|
||||
var trainData = FeatureCombiner.PrepareFeatures(Env, new FeatureCombiner.FeatureCombinerInput() { Data = dataView, Features = new[] { "F1", "F2", "Rest" } });
|
||||
var lrModel = LogisticRegressionBinaryClassificationTrainer.TrainBinary(Env, new LogisticRegressionBinaryClassificationTrainer.Options { TrainingData = trainData.OutputData }).PredictorModel;
|
||||
var lrModel = LogisticRegressionBinaryTrainer.TrainBinary(Env, new LogisticRegressionBinaryTrainer.Options { TrainingData = trainData.OutputData }).PredictorModel;
|
||||
var model = ModelOperations.CombineTwoModels(Env, new ModelOperations.SimplePredictorModelInput() { TransformModel = trainData.Model, PredictorModel = lrModel }).PredictorModel;
|
||||
|
||||
var scored1 = ScoreModel.Score(Env, new ScoreModel.Input() { Data = dataView, PredictorModel = model }).ScoredData;
|
||||
|
@ -362,12 +362,12 @@ namespace Microsoft.ML.RunTests
|
|||
{
|
||||
var catalog = Env.ComponentCatalog;
|
||||
|
||||
InputBuilder ib1 = new InputBuilder(Env, typeof(LogisticRegressionBinaryClassificationTrainer.Options), catalog);
|
||||
InputBuilder ib1 = new InputBuilder(Env, typeof(LogisticRegressionBinaryTrainer.Options), catalog);
|
||||
// Ensure that InputBuilder unwraps the Optional<string> correctly.
|
||||
var weightType = ib1.GetFieldTypeOrNull("ExampleWeightColumnName");
|
||||
Assert.True(weightType.Equals(typeof(string)));
|
||||
|
||||
var instance = ib1.GetInstance() as LogisticRegressionBinaryClassificationTrainer.Options;
|
||||
var instance = ib1.GetInstance() as LogisticRegressionBinaryTrainer.Options;
|
||||
Assert.True(instance.ExampleWeightColumnName == null);
|
||||
|
||||
ib1.TrySetValue("ExampleWeightColumnName", "OtherWeight");
|
||||
|
@ -420,14 +420,14 @@ namespace Microsoft.ML.RunTests
|
|||
for (int i = 0; i < nModels; i++)
|
||||
{
|
||||
var data = splitOutput.TrainData[i];
|
||||
var lrInput = new LogisticRegressionBinaryClassificationTrainer.Options
|
||||
var lrInput = new LogisticRegressionBinaryTrainer.Options
|
||||
{
|
||||
TrainingData = data,
|
||||
L1Regularization = (Single)0.1 * i,
|
||||
L2Regularization = (Single)0.01 * (1 + i),
|
||||
NormalizeFeatures = NormalizeOption.No
|
||||
};
|
||||
predictorModels[i] = LogisticRegressionBinaryClassificationTrainer.TrainBinary(Env, lrInput).PredictorModel;
|
||||
predictorModels[i] = LogisticRegressionBinaryTrainer.TrainBinary(Env, lrInput).PredictorModel;
|
||||
individualScores[i] =
|
||||
ScoreModel.Score(Env,
|
||||
new ScoreModel.Input { Data = splitOutput.TestData[nModels], PredictorModel = predictorModels[i] })
|
||||
|
@ -676,7 +676,7 @@ namespace Microsoft.ML.RunTests
|
|||
|
||||
var splitOutput = CVSplit.Split(Env, new CVSplit.Input { Data = dataView, NumFolds = 3 });
|
||||
|
||||
var lrModel = LogisticRegressionBinaryClassificationTrainer.TrainBinary(Env, new LogisticRegressionBinaryClassificationTrainer.Options { TrainingData = splitOutput.TestData[0] }).PredictorModel;
|
||||
var lrModel = LogisticRegressionBinaryTrainer.TrainBinary(Env, new LogisticRegressionBinaryTrainer.Options { TrainingData = splitOutput.TestData[0] }).PredictorModel;
|
||||
var calibratedLrModel = Calibrate.FixedPlatt(Env,
|
||||
new Calibrate.FixedPlattInput { Data = splitOutput.TestData[1], UncalibratedPredictorModel = lrModel }).PredictorModel;
|
||||
|
||||
|
@ -695,7 +695,7 @@ namespace Microsoft.ML.RunTests
|
|||
calibratedLrModel = Calibrate.Pav(Env, input).PredictorModel;
|
||||
|
||||
// This tests that the SchemaBindableCalibratedPredictor doesn't get confused if its sub-predictor is already calibrated.
|
||||
var fastForest = new FastForestBinaryClassificationTrainer(Env, "Label", "Features");
|
||||
var fastForest = new FastForestBinaryTrainer(Env, "Label", "Features");
|
||||
var rmd = new RoleMappedData(splitOutput.TrainData[0], "Label", "Features");
|
||||
var ffModel = new PredictorModelImpl(Env, rmd, splitOutput.TrainData[0], fastForest.Train(rmd));
|
||||
var calibratedFfModel = Calibrate.Platt(Env,
|
||||
|
@ -724,14 +724,14 @@ namespace Microsoft.ML.RunTests
|
|||
data = new ColumnConcatenatingTransformer(Env, "Features", new[] { "Features1", "Features2" }).Transform(data);
|
||||
data = new ValueToKeyMappingEstimator(Env, "Label", "Label", keyOrdinality: ValueToKeyMappingEstimator.KeyOrdinality.ByValue).Fit(data).Transform(data);
|
||||
|
||||
var lrInput = new LogisticRegressionBinaryClassificationTrainer.Options
|
||||
var lrInput = new LogisticRegressionBinaryTrainer.Options
|
||||
{
|
||||
TrainingData = data,
|
||||
L1Regularization = (Single)0.1 * i,
|
||||
L2Regularization = (Single)0.01 * (1 + i),
|
||||
NormalizeFeatures = NormalizeOption.Yes
|
||||
};
|
||||
predictorModels[i] = LogisticRegressionBinaryClassificationTrainer.TrainBinary(Env, lrInput).PredictorModel;
|
||||
predictorModels[i] = LogisticRegressionBinaryTrainer.TrainBinary(Env, lrInput).PredictorModel;
|
||||
var transformModel = new TransformModelImpl(Env, data, splitOutput.TrainData[i]);
|
||||
|
||||
predictorModels[i] = ModelOperations.CombineTwoModels(Env,
|
||||
|
@ -985,14 +985,14 @@ namespace Microsoft.ML.RunTests
|
|||
},
|
||||
data);
|
||||
}
|
||||
var lrInput = new LogisticRegressionBinaryClassificationTrainer.Options
|
||||
var lrInput = new LogisticRegressionBinaryTrainer.Options
|
||||
{
|
||||
TrainingData = data,
|
||||
L1Regularization = (Single)0.1 * i,
|
||||
L2Regularization = (Single)0.01 * (1 + i),
|
||||
NormalizeFeatures = NormalizeOption.Yes
|
||||
};
|
||||
predictorModels[i] = LogisticRegressionBinaryClassificationTrainer.TrainBinary(Env, lrInput).PredictorModel;
|
||||
predictorModels[i] = LogisticRegressionBinaryTrainer.TrainBinary(Env, lrInput).PredictorModel;
|
||||
var transformModel = new TransformModelImpl(Env, data, splitOutput.TrainData[i]);
|
||||
|
||||
predictorModels[i] = ModelOperations.CombineTwoModels(Env,
|
||||
|
@ -1318,7 +1318,7 @@ namespace Microsoft.ML.RunTests
|
|||
data = new ColumnConcatenatingTransformer(Env, new ColumnConcatenatingTransformer.ColumnOptions("Features", i % 2 == 0 ? new[] { "Features", "Cat" } : new[] { "Cat", "Features" })).Transform(data);
|
||||
if (i % 2 == 0)
|
||||
{
|
||||
var lrInput = new LogisticRegressionBinaryClassificationTrainer.Options
|
||||
var lrInput = new LogisticRegressionBinaryTrainer.Options
|
||||
{
|
||||
TrainingData = data,
|
||||
NormalizeFeatures = NormalizeOption.Yes,
|
||||
|
@ -1326,7 +1326,7 @@ namespace Microsoft.ML.RunTests
|
|||
ShowTrainingStatistics = true,
|
||||
ComputeStandardDeviation = new ComputeLRTrainingStdThroughMkl()
|
||||
};
|
||||
predictorModels[i] = LogisticRegressionBinaryClassificationTrainer.TrainBinary(Env, lrInput).PredictorModel;
|
||||
predictorModels[i] = LogisticRegressionBinaryTrainer.TrainBinary(Env, lrInput).PredictorModel;
|
||||
var transformModel = new TransformModelImpl(Env, data, splitOutput.TrainData[i]);
|
||||
|
||||
predictorModels[i] = ModelOperations.CombineTwoModels(Env,
|
||||
|
@ -1335,7 +1335,7 @@ namespace Microsoft.ML.RunTests
|
|||
}
|
||||
else if (i % 2 == 1)
|
||||
{
|
||||
var trainer = new FastTreeBinaryClassificationTrainer(Env, "Label", "Features");
|
||||
var trainer = new FastTreeBinaryTrainer(Env, "Label", "Features");
|
||||
var rmd = new RoleMappedData(data, false,
|
||||
RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Feature, "Features"),
|
||||
RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Label, "Label"));
|
||||
|
@ -3347,7 +3347,7 @@ namespace Microsoft.ML.RunTests
|
|||
InputFile = inputFile,
|
||||
}).Data;
|
||||
|
||||
var lrInput = new LogisticRegressionBinaryClassificationTrainer.Options
|
||||
var lrInput = new LogisticRegressionBinaryTrainer.Options
|
||||
{
|
||||
TrainingData = dataView,
|
||||
NormalizeFeatures = NormalizeOption.Yes,
|
||||
|
@ -3355,7 +3355,7 @@ namespace Microsoft.ML.RunTests
|
|||
ShowTrainingStatistics = true,
|
||||
ComputeStandardDeviation = new ComputeLRTrainingStdThroughMkl()
|
||||
};
|
||||
var model = LogisticRegressionBinaryClassificationTrainer.TrainBinary(Env, lrInput).PredictorModel;
|
||||
var model = LogisticRegressionBinaryTrainer.TrainBinary(Env, lrInput).PredictorModel;
|
||||
|
||||
var mcLrInput = new LogisticRegressionMulticlassClassificationTrainer.Options
|
||||
{
|
||||
|
@ -3364,7 +3364,7 @@ namespace Microsoft.ML.RunTests
|
|||
NumberOfThreads = 1,
|
||||
ShowTrainingStatistics = true
|
||||
};
|
||||
var mcModel = LogisticRegressionBinaryClassificationTrainer.TrainMulticlass(Env, mcLrInput).PredictorModel;
|
||||
var mcModel = LogisticRegressionBinaryTrainer.TrainMulticlass(Env, mcLrInput).PredictorModel;
|
||||
|
||||
var output = SummarizePredictor.Summarize(Env,
|
||||
new SummarizePredictor.Input() { PredictorModel = model });
|
||||
|
@ -3556,7 +3556,7 @@ namespace Microsoft.ML.RunTests
|
|||
Columns = new[] { new ColumnConcatenatingTransformer.Column { Name = "Features", Source = new[] { "Categories", "NumericFeatures" } } }
|
||||
});
|
||||
|
||||
var fastTree = Trainers.FastTree.FastTree.TrainBinary(Env, new FastTreeBinaryClassificationTrainer.Options
|
||||
var fastTree = Trainers.FastTree.FastTree.TrainBinary(Env, new FastTreeBinaryTrainer.Options
|
||||
{
|
||||
FeatureColumnName = "Features",
|
||||
NumberOfTrees = 5,
|
||||
|
|
|
@ -144,7 +144,7 @@ namespace Microsoft.ML.Functional.Tests
|
|||
}, "SentimentText")
|
||||
.AppendCacheCheckpoint(mlContext)
|
||||
.Append(mlContext.BinaryClassification.Trainers.SdcaCalibrated(
|
||||
new SdcaCalibratedBinaryClassificationTrainer.Options { NumberOfThreads = 1 }));
|
||||
new SdcaCalibratedBinaryTrainer.Options { NumberOfThreads = 1 }));
|
||||
|
||||
// Train the model.
|
||||
var model = pipeline.Fit(data);
|
||||
|
|
|
@ -65,7 +65,7 @@ namespace Microsoft.ML.Functional.Tests
|
|||
var pipeline = mlContext.Transforms.Text.FeaturizeText("Features", "SentimentText")
|
||||
.AppendCacheCheckpoint(mlContext)
|
||||
.Append(mlContext.BinaryClassification.Trainers.SdcaNonCalibrated(
|
||||
new SdcaNonCalibratedBinaryClassificationTrainer.Options { NumberOfThreads = 1 }));
|
||||
new SdcaNonCalibratedBinaryTrainer.Options { NumberOfThreads = 1 }));
|
||||
|
||||
// Train the model.
|
||||
var model = pipeline.Fit(data);
|
||||
|
@ -94,7 +94,7 @@ namespace Microsoft.ML.Functional.Tests
|
|||
var pipeline = mlContext.Transforms.Text.FeaturizeText("Features", "SentimentText")
|
||||
.AppendCacheCheckpoint(mlContext)
|
||||
.Append(mlContext.BinaryClassification.Trainers.LogisticRegression(
|
||||
new LogisticRegressionBinaryClassificationTrainer.Options { NumberOfThreads = 1 }));
|
||||
new LogisticRegressionBinaryTrainer.Options { NumberOfThreads = 1 }));
|
||||
|
||||
// Train the model.
|
||||
var model = pipeline.Fit(data);
|
||||
|
@ -152,7 +152,7 @@ namespace Microsoft.ML.Functional.Tests
|
|||
.Append(mlContext.Transforms.Conversion.MapValueToKey("Label"))
|
||||
.AppendCacheCheckpoint(mlContext)
|
||||
.Append(mlContext.MulticlassClassification.Trainers.Sdca(
|
||||
new SdcaMulticlassClassificationTrainer.Options { NumberOfThreads = 1}));
|
||||
new SdcaMulticlassTrainer.Options { NumberOfThreads = 1}));
|
||||
|
||||
// Train the model.
|
||||
var model = pipeline.Fit(data);
|
||||
|
@ -274,7 +274,7 @@ namespace Microsoft.ML.Functional.Tests
|
|||
var pipeline = mlContext.Transforms.Text.FeaturizeText("Features", "SentimentText")
|
||||
.AppendCacheCheckpoint(mlContext)
|
||||
.Append(mlContext.BinaryClassification.Trainers.LogisticRegression(
|
||||
new LogisticRegressionBinaryClassificationTrainer.Options { NumberOfThreads = 1 }));
|
||||
new LogisticRegressionBinaryTrainer.Options { NumberOfThreads = 1 }));
|
||||
|
||||
// Train the model.
|
||||
var model = pipeline.Fit(data);
|
||||
|
|
|
@ -82,7 +82,7 @@ namespace Microsoft.ML.Functional.Tests
|
|||
var pipeline = mlContext.Transforms.Text.FeaturizeText("Features", "SentimentText")
|
||||
.AppendCacheCheckpoint(mlContext)
|
||||
.Append(mlContext.BinaryClassification.Trainers.FastTree(
|
||||
new FastTreeBinaryClassificationTrainer.Options{ NumberOfLeaves = 5, NumberOfTrees= 3, NumberOfThreads = 1 }));
|
||||
new FastTreeBinaryTrainer.Options{ NumberOfLeaves = 5, NumberOfTrees= 3, NumberOfThreads = 1 }));
|
||||
|
||||
// Fit the pipeline.
|
||||
var model = pipeline.Fit(data);
|
||||
|
@ -217,7 +217,7 @@ namespace Microsoft.ML.Functional.Tests
|
|||
var pipeline = mlContext.Transforms.Text.FeaturizeText("Features", "SentimentText")
|
||||
.AppendCacheCheckpoint(mlContext)
|
||||
.Append(mlContext.BinaryClassification.Trainers.SdcaNonCalibrated(
|
||||
new SdcaNonCalibratedBinaryClassificationTrainer.Options { NumberOfThreads = 1 }));
|
||||
new SdcaNonCalibratedBinaryTrainer.Options { NumberOfThreads = 1 }));
|
||||
|
||||
// Fit the pipeline.
|
||||
var model = pipeline.Fit(data);
|
||||
|
@ -423,7 +423,7 @@ namespace Microsoft.ML.Functional.Tests
|
|||
{
|
||||
return mlContext.Transforms.Conversion.MapValueToKey("Label")
|
||||
.Append(mlContext.MulticlassClassification.Trainers.Sdca(
|
||||
new SdcaMulticlassClassificationTrainer.Options {
|
||||
new SdcaMulticlassTrainer.Options {
|
||||
MaximumNumberOfIterations = 10,
|
||||
NumberOfThreads = 1 }));
|
||||
}
|
||||
|
|
|
@ -89,19 +89,19 @@ namespace Microsoft.ML.Functional.Tests
|
|||
|
||||
var gam = ((loadedTransformerModel as ISingleFeaturePredictionTransformer<object>).Model
|
||||
as CalibratedModelParametersBase).SubModel
|
||||
as BinaryClassificationGamModelParameters;
|
||||
as GamBinaryModelParameters;
|
||||
Assert.NotNull(gam);
|
||||
|
||||
gam = (((loadedCompositeLoader as CompositeDataLoader<IMultiStreamSource, ITransformer>).Transformer.LastTransformer
|
||||
as ISingleFeaturePredictionTransformer<object>).Model
|
||||
as CalibratedModelParametersBase).SubModel
|
||||
as BinaryClassificationGamModelParameters;
|
||||
as GamBinaryModelParameters;
|
||||
Assert.NotNull(gam);
|
||||
|
||||
gam = (((loadedTransformerModel1 as TransformerChain<ITransformer>).LastTransformer
|
||||
as ISingleFeaturePredictionTransformer<object>).Model
|
||||
as CalibratedModelParametersBase).SubModel
|
||||
as BinaryClassificationGamModelParameters;
|
||||
as GamBinaryModelParameters;
|
||||
Assert.NotNull(gam);
|
||||
}
|
||||
|
||||
|
@ -146,7 +146,7 @@ namespace Microsoft.ML.Functional.Tests
|
|||
Assert.NotNull(singleFeaturePredictionTransformer);
|
||||
var calibratedModelParameters = singleFeaturePredictionTransformer.Model as CalibratedModelParametersBase;
|
||||
Assert.NotNull(calibratedModelParameters);
|
||||
var gamModel = calibratedModelParameters.SubModel as BinaryClassificationGamModelParameters;
|
||||
var gamModel = calibratedModelParameters.SubModel as GamBinaryModelParameters;
|
||||
Assert.NotNull(gamModel);
|
||||
var ageBinUpperBounds = gamModel.GetBinUpperBounds(ageIndex);
|
||||
var ageBinEffects = gamModel.GetBinEffects(ageIndex);
|
||||
|
|
|
@ -44,10 +44,10 @@ namespace Microsoft.ML.Functional.Tests
|
|||
|
||||
// Create a selection of learners.
|
||||
var sdcaTrainer = mlContext.BinaryClassification.Trainers.SdcaCalibrated(
|
||||
new SdcaCalibratedBinaryClassificationTrainer.Options { NumberOfThreads = 1 });
|
||||
new SdcaCalibratedBinaryTrainer.Options { NumberOfThreads = 1 });
|
||||
|
||||
var fastTreeTrainer = mlContext.BinaryClassification.Trainers.FastTree(
|
||||
new FastTreeBinaryClassificationTrainer.Options { NumberOfThreads = 1 });
|
||||
new FastTreeBinaryTrainer.Options { NumberOfThreads = 1 });
|
||||
|
||||
var ffmTrainer = mlContext.BinaryClassification.Trainers.FieldAwareFactorizationMachine();
|
||||
|
||||
|
@ -226,7 +226,7 @@ namespace Microsoft.ML.Functional.Tests
|
|||
.AppendCacheCheckpoint(mlContext);
|
||||
|
||||
var trainer = mlContext.BinaryClassification.Trainers.LogisticRegression(
|
||||
new LogisticRegressionBinaryClassificationTrainer.Options { NumberOfThreads = 1, MaximumNumberOfIterations = 10 });
|
||||
new LogisticRegressionBinaryTrainer.Options { NumberOfThreads = 1, MaximumNumberOfIterations = 10 });
|
||||
|
||||
// Fit the data transformation pipeline.
|
||||
var featurization = featurizationPipeline.Fit(data);
|
||||
|
@ -452,7 +452,7 @@ namespace Microsoft.ML.Functional.Tests
|
|||
|
||||
// Create a model training an OVA trainer with a binary classifier.
|
||||
var binaryClassificationTrainer = mlContext.BinaryClassification.Trainers.LogisticRegression(
|
||||
new LogisticRegressionBinaryClassificationTrainer.Options { MaximumNumberOfIterations = 10, NumberOfThreads = 1, });
|
||||
new LogisticRegressionBinaryTrainer.Options { MaximumNumberOfIterations = 10, NumberOfThreads = 1, });
|
||||
var binaryClassificationPipeline = mlContext.Transforms.Concatenate("Features", Iris.Features)
|
||||
.AppendCacheCheckpoint(mlContext)
|
||||
.Append(mlContext.Transforms.Conversion.MapValueToKey("Label"))
|
||||
|
|
|
@ -43,7 +43,7 @@ namespace Microsoft.ML.Functional.Tests
|
|||
|
||||
// Check that the results are valid
|
||||
Assert.IsType<RegressionMetrics>(cvResult[0].Metrics);
|
||||
Assert.IsType<TransformerChain<RegressionPredictionTransformer<OrdinaryLeastSquaresRegressionModelParameters>>>(cvResult[0].Model);
|
||||
Assert.IsType<TransformerChain<RegressionPredictionTransformer<OlsModelParameters>>>(cvResult[0].Model);
|
||||
Assert.True(cvResult[0].ScoredHoldOutSet is IDataView);
|
||||
Assert.Equal(5, cvResult.Length);
|
||||
|
||||
|
|
|
@ -31,7 +31,7 @@ namespace Microsoft.ML.RunTests
|
|||
new double[] { 2, 1, 0 }
|
||||
};
|
||||
|
||||
var gam = new RegressionGamModelParameters(mlContext, binUpperBounds, binEffects, intercept);
|
||||
var gam = new GamRegressionModelParameters(mlContext, binUpperBounds, binEffects, intercept);
|
||||
|
||||
// Check that the model has the right number of shape functions
|
||||
Assert.Equal(binUpperBounds.Length, gam.NumberOfShapeFunctions);
|
||||
|
@ -50,15 +50,15 @@ namespace Microsoft.ML.RunTests
|
|||
Utils.AreEqual(binEffects[i], gam.GetBinEffects(i).ToArray());
|
||||
|
||||
// Check that the constructor handles null inputs properly
|
||||
Assert.Throws<System.ArgumentNullException>(() => new RegressionGamModelParameters(mlContext, binUpperBounds, null, intercept));
|
||||
Assert.Throws<System.ArgumentNullException>(() => new RegressionGamModelParameters(mlContext, null, binEffects, intercept));
|
||||
Assert.Throws<System.ArgumentNullException>(() => new RegressionGamModelParameters(mlContext, null, null, intercept));
|
||||
Assert.Throws<System.ArgumentNullException>(() => new GamRegressionModelParameters(mlContext, binUpperBounds, null, intercept));
|
||||
Assert.Throws<System.ArgumentNullException>(() => new GamRegressionModelParameters(mlContext, null, binEffects, intercept));
|
||||
Assert.Throws<System.ArgumentNullException>(() => new GamRegressionModelParameters(mlContext, null, null, intercept));
|
||||
|
||||
// Check that the constructor handles mismatches in length between bin upper bounds and bin effects
|
||||
var misMatchArray = new double[1][];
|
||||
misMatchArray[0] = new double[] { 0 };
|
||||
Assert.Throws<System.ArgumentOutOfRangeException>(() => new RegressionGamModelParameters(mlContext, binUpperBounds, misMatchArray, intercept));
|
||||
Assert.Throws<System.ArgumentOutOfRangeException>(() => new RegressionGamModelParameters(mlContext, misMatchArray, binEffects, intercept));
|
||||
Assert.Throws<System.ArgumentOutOfRangeException>(() => new GamRegressionModelParameters(mlContext, binUpperBounds, misMatchArray, intercept));
|
||||
Assert.Throws<System.ArgumentOutOfRangeException>(() => new GamRegressionModelParameters(mlContext, misMatchArray, binEffects, intercept));
|
||||
|
||||
// Check that the constructor handles a mismatch in bin upper bounds and bin effects for a feature
|
||||
var fewerBinEffects = new double[2][]
|
||||
|
@ -66,13 +66,13 @@ namespace Microsoft.ML.RunTests
|
|||
new double[] { 0, 1 },
|
||||
new double[] { 2, 1, 0 }
|
||||
};
|
||||
Assert.Throws<System.ArgumentOutOfRangeException>(() => new RegressionGamModelParameters(mlContext, binUpperBounds, fewerBinEffects, intercept));
|
||||
Assert.Throws<System.ArgumentOutOfRangeException>(() => new GamRegressionModelParameters(mlContext, binUpperBounds, fewerBinEffects, intercept));
|
||||
var moreBinEffects = new double[2][]
|
||||
{
|
||||
new double[] { 0, 1, 2, 3 },
|
||||
new double[] { 2, 1, 0 }
|
||||
};
|
||||
Assert.Throws<System.ArgumentOutOfRangeException>(() => new RegressionGamModelParameters(mlContext, binUpperBounds, moreBinEffects, intercept));
|
||||
Assert.Throws<System.ArgumentOutOfRangeException>(() => new GamRegressionModelParameters(mlContext, binUpperBounds, moreBinEffects, intercept));
|
||||
|
||||
// Check that the constructor handles bin upper bounds that are not sorted
|
||||
var unsortedUpperBounds = new double[2][]
|
||||
|
@ -80,7 +80,7 @@ namespace Microsoft.ML.RunTests
|
|||
new double[] { 1, 3, 2 },
|
||||
new double[] { 4, 5, 6 }
|
||||
};
|
||||
Assert.Throws<System.ArgumentOutOfRangeException>(() => new RegressionGamModelParameters(mlContext, unsortedUpperBounds, binEffects, intercept));
|
||||
Assert.Throws<System.ArgumentOutOfRangeException>(() => new GamRegressionModelParameters(mlContext, unsortedUpperBounds, binEffects, intercept));
|
||||
}
|
||||
|
||||
private void CheckArrayOfArrayEquality(double[][] array1, double[][] array2)
|
||||
|
|
|
@ -594,7 +594,7 @@ namespace Microsoft.ML.RunTests
|
|||
var fastTrees = new PredictorModel[3];
|
||||
for (int i = 0; i < 3; i++)
|
||||
{
|
||||
fastTrees[i] = FastTree.TrainBinary(ML, new FastTreeBinaryClassificationTrainer.Options
|
||||
fastTrees[i] = FastTree.TrainBinary(ML, new FastTreeBinaryTrainer.Options
|
||||
{
|
||||
FeatureColumnName = "Features",
|
||||
NumberOfTrees = 5,
|
||||
|
@ -616,7 +616,7 @@ namespace Microsoft.ML.RunTests
|
|||
var fastTrees = new PredictorModel[3];
|
||||
for (int i = 0; i < 3; i++)
|
||||
{
|
||||
fastTrees[i] = FastTree.TrainBinary(ML, new FastTreeBinaryClassificationTrainer.Options
|
||||
fastTrees[i] = FastTree.TrainBinary(ML, new FastTreeBinaryTrainer.Options
|
||||
{
|
||||
FeatureColumnName = "Features",
|
||||
NumberOfTrees = 5,
|
||||
|
@ -723,7 +723,7 @@ namespace Microsoft.ML.RunTests
|
|||
|
||||
var predictors = new PredictorModel[]
|
||||
{
|
||||
FastTree.TrainBinary(ML, new FastTreeBinaryClassificationTrainer.Options
|
||||
FastTree.TrainBinary(ML, new FastTreeBinaryTrainer.Options
|
||||
{
|
||||
FeatureColumnName = "Features",
|
||||
NumberOfTrees = 5,
|
||||
|
@ -739,7 +739,7 @@ namespace Microsoft.ML.RunTests
|
|||
TrainingData = dataView,
|
||||
NormalizeFeatures = NormalizeOption.No
|
||||
}).PredictorModel,
|
||||
LogisticRegressionBinaryClassificationTrainer.TrainBinary(ML, new LogisticRegressionBinaryClassificationTrainer.Options()
|
||||
LogisticRegressionBinaryTrainer.TrainBinary(ML, new LogisticRegressionBinaryTrainer.Options()
|
||||
{
|
||||
FeatureColumnName = "Features",
|
||||
LabelColumnName = DefaultColumnNames.Label,
|
||||
|
@ -747,7 +747,7 @@ namespace Microsoft.ML.RunTests
|
|||
TrainingData = dataView,
|
||||
NormalizeFeatures = NormalizeOption.No
|
||||
}).PredictorModel,
|
||||
LogisticRegressionBinaryClassificationTrainer.TrainBinary(ML, new LogisticRegressionBinaryClassificationTrainer.Options()
|
||||
LogisticRegressionBinaryTrainer.TrainBinary(ML, new LogisticRegressionBinaryTrainer.Options()
|
||||
{
|
||||
FeatureColumnName = "Features",
|
||||
LabelColumnName = DefaultColumnNames.Label,
|
||||
|
@ -776,7 +776,7 @@ namespace Microsoft.ML.RunTests
|
|||
LabelColumnName = DefaultColumnNames.Label,
|
||||
TrainingData = dataView
|
||||
}).PredictorModel,
|
||||
LogisticRegressionBinaryClassificationTrainer.TrainMulticlass(Env, new LogisticRegressionMulticlassClassificationTrainer.Options()
|
||||
LogisticRegressionBinaryTrainer.TrainMulticlass(Env, new LogisticRegressionMulticlassClassificationTrainer.Options()
|
||||
{
|
||||
FeatureColumnName = "Features",
|
||||
LabelColumnName = DefaultColumnNames.Label,
|
||||
|
@ -784,7 +784,7 @@ namespace Microsoft.ML.RunTests
|
|||
TrainingData = dataView,
|
||||
NormalizeFeatures = NormalizeOption.No
|
||||
}).PredictorModel,
|
||||
LogisticRegressionBinaryClassificationTrainer.TrainMulticlass(Env, new LogisticRegressionMulticlassClassificationTrainer.Options()
|
||||
LogisticRegressionBinaryTrainer.TrainMulticlass(Env, new LogisticRegressionMulticlassClassificationTrainer.Options()
|
||||
{
|
||||
FeatureColumnName = "Features",
|
||||
LabelColumnName = DefaultColumnNames.Label,
|
||||
|
|
|
@ -117,7 +117,7 @@ namespace Microsoft.ML.StaticPipelineTesting
|
|||
|
||||
var est = reader.MakeNewEstimator()
|
||||
.Append(r => (r.label, preds: catalog.Trainers.Sdca(r.label, r.features, null,
|
||||
new SdcaCalibratedBinaryClassificationTrainer.Options { MaximumNumberOfIterations = 2, NumberOfThreads = 1 },
|
||||
new SdcaCalibratedBinaryTrainer.Options { MaximumNumberOfIterations = 2, NumberOfThreads = 1 },
|
||||
onFit: (p) => { pred = p; })));
|
||||
|
||||
var pipe = reader.Append(est);
|
||||
|
@ -197,7 +197,7 @@ namespace Microsoft.ML.StaticPipelineTesting
|
|||
// With a custom loss function we no longer get calibrated predictions.
|
||||
var est = reader.MakeNewEstimator()
|
||||
.Append(r => (r.label, preds: catalog.Trainers.SdcaNonCalibrated(r.label, r.features, null, loss,
|
||||
new SdcaNonCalibratedBinaryClassificationTrainer.Options { MaximumNumberOfIterations = 2, NumberOfThreads = 1 },
|
||||
new SdcaNonCalibratedBinaryTrainer.Options { MaximumNumberOfIterations = 2, NumberOfThreads = 1 },
|
||||
onFit: p => pred = p)));
|
||||
|
||||
var pipe = reader.Append(est);
|
||||
|
@ -654,7 +654,7 @@ namespace Microsoft.ML.StaticPipelineTesting
|
|||
|
||||
var est = reader.MakeNewEstimator()
|
||||
.Append(r => (r.label, preds: catalog.Trainers.LogisticRegressionBinaryClassifier(r.label, r.features, null,
|
||||
new LogisticRegressionBinaryClassificationTrainer.Options { L1Regularization = 10, NumberOfThreads = 1 }, onFit: (p) => { pred = p; })));
|
||||
new LogisticRegressionBinaryTrainer.Options { L1Regularization = 10, NumberOfThreads = 1 }, onFit: (p) => { pred = p; })));
|
||||
|
||||
var pipe = reader.Append(est);
|
||||
|
||||
|
@ -960,7 +960,7 @@ namespace Microsoft.ML.StaticPipelineTesting
|
|||
var reader = TextLoaderStatic.CreateLoader(env,
|
||||
c => (label: c.LoadText(0), features: c.LoadFloat(1, 4)));
|
||||
|
||||
MulticlassNaiveBayesModelParameters pred = null;
|
||||
NaiveBayesMulticlassModelParameters pred = null;
|
||||
|
||||
// With a custom loss function we no longer get calibrated predictions.
|
||||
var est = reader.MakeNewEstimator()
|
||||
|
|
|
@ -23,7 +23,7 @@ namespace Microsoft.ML.TestFramework
|
|||
env.ComponentCatalog.RegisterAssembly(typeof(FastTreeBinaryModelParameters).Assembly); // ML.FastTree
|
||||
env.ComponentCatalog.RegisterAssembly(typeof(EnsembleModelParameters).Assembly); // ML.Ensemble
|
||||
env.ComponentCatalog.RegisterAssembly(typeof(KMeansModelParameters).Assembly); // ML.KMeansClustering
|
||||
env.ComponentCatalog.RegisterAssembly(typeof(PrincipleComponentModelParameters).Assembly); // ML.PCA
|
||||
env.ComponentCatalog.RegisterAssembly(typeof(PcaModelParameters).Assembly); // ML.PCA
|
||||
env.ComponentCatalog.RegisterAssembly(typeof(CVSplit).Assembly); // ML.EntryPoints
|
||||
return env;
|
||||
}
|
||||
|
|
|
@ -39,7 +39,7 @@ namespace Microsoft.ML.RunTests
|
|||
static TestLearnersBase()
|
||||
{
|
||||
bool ok = true;
|
||||
ok &= typeof(FastTreeBinaryClassificationTrainer) != null;
|
||||
ok &= typeof(FastTreeBinaryTrainer) != null;
|
||||
Contracts.Check(ok, "Missing assemblies!");
|
||||
}
|
||||
|
||||
|
|
|
@ -150,7 +150,7 @@ namespace Microsoft.ML.Tests
|
|||
public void TestSDCABinary()
|
||||
{
|
||||
TestFeatureContribution(ML.BinaryClassification.Trainers.SdcaNonCalibrated(
|
||||
new SdcaNonCalibratedBinaryClassificationTrainer.Options { NumberOfThreads = 1, }), GetSparseDataset(TaskType.BinaryClassification, 100), "SDCABinary", precision: 5);
|
||||
new SdcaNonCalibratedBinaryTrainer.Options { NumberOfThreads = 1, }), GetSparseDataset(TaskType.BinaryClassification, 100), "SDCABinary", precision: 5);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
|
|
|
@ -152,7 +152,7 @@ namespace Microsoft.ML.Tests
|
|||
{
|
||||
var data = GetDenseDataset(TaskType.BinaryClassification);
|
||||
var model = ML.BinaryClassification.Trainers.LogisticRegression(
|
||||
new LogisticRegressionBinaryClassificationTrainer.Options { NumberOfThreads = 1 }).Fit(data);
|
||||
new LogisticRegressionBinaryTrainer.Options { NumberOfThreads = 1 }).Fit(data);
|
||||
var pfi = ML.BinaryClassification.PermutationFeatureImportance(model, data);
|
||||
|
||||
// Pfi Indices:
|
||||
|
@ -190,7 +190,7 @@ namespace Microsoft.ML.Tests
|
|||
{
|
||||
var data = GetSparseDataset(TaskType.BinaryClassification);
|
||||
var model = ML.BinaryClassification.Trainers.LogisticRegression(
|
||||
new LogisticRegressionBinaryClassificationTrainer.Options { NumberOfThreads = 1 }).Fit(data);
|
||||
new LogisticRegressionBinaryTrainer.Options { NumberOfThreads = 1 }).Fit(data);
|
||||
var pfi = ML.BinaryClassification.PermutationFeatureImportance(model, data);
|
||||
|
||||
// Pfi Indices:
|
||||
|
|
|
@ -32,7 +32,7 @@ namespace Microsoft.ML.Tests.Scenarios.Api
|
|||
var pipeline = new ColumnConcatenatingEstimator (ml, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth")
|
||||
.Append(new ValueToKeyMappingEstimator(ml, "Label"), TransformerScope.TrainTest)
|
||||
.Append(ml.MulticlassClassification.Trainers.Sdca(
|
||||
new SdcaMulticlassClassificationTrainer.Options { MaximumNumberOfIterations = 100, Shuffle = true, NumberOfThreads = 1, }))
|
||||
new SdcaMulticlassTrainer.Options { MaximumNumberOfIterations = 100, Shuffle = true, NumberOfThreads = 1, }))
|
||||
.Append(new KeyToValueMappingEstimator(ml, "PredictedLabel"));
|
||||
|
||||
var model = pipeline.Fit(data).GetModelFor(TransformerScope.Scoring);
|
||||
|
|
|
@ -41,7 +41,7 @@ namespace Microsoft.ML.Tests.Scenarios.Api
|
|||
.Append(new CustomMappingEstimator<IrisData, IrisData>(ml, action, null), TransformerScope.TrainTest)
|
||||
.Append(new ValueToKeyMappingEstimator(ml, "Label"), TransformerScope.TrainTest)
|
||||
.Append(ml.MulticlassClassification.Trainers.Sdca(
|
||||
new SdcaMulticlassClassificationTrainer.Options { MaximumNumberOfIterations = 100, Shuffle = true, NumberOfThreads = 1 }))
|
||||
new SdcaMulticlassTrainer.Options { MaximumNumberOfIterations = 100, Shuffle = true, NumberOfThreads = 1 }))
|
||||
.Append(new KeyToValueMappingEstimator(ml, "PredictedLabel"));
|
||||
|
||||
var model = pipeline.Fit(data).GetModelFor(TransformerScope.Scoring);
|
||||
|
|
|
@ -30,7 +30,7 @@ namespace Microsoft.ML.Tests.Scenarios.Api
|
|||
var pipeline = ml.Transforms.Text.FeaturizeText("Features", "SentimentText")
|
||||
.AppendCacheCheckpoint(ml)
|
||||
.Append(ml.BinaryClassification.Trainers.SdcaNonCalibrated(
|
||||
new SdcaNonCalibratedBinaryClassificationTrainer.Options { NumberOfThreads = 1 }));
|
||||
new SdcaNonCalibratedBinaryTrainer.Options { NumberOfThreads = 1 }));
|
||||
|
||||
// Train.
|
||||
var model = pipeline.Fit(data);
|
||||
|
|
|
@ -30,7 +30,7 @@ namespace Microsoft.ML.Tests.Scenarios.Api
|
|||
var pipeline = ml.Transforms.Concatenate("Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth")
|
||||
.Append(ml.Transforms.Conversion.MapValueToKey("Label"), TransformerScope.TrainTest)
|
||||
.Append(ml.MulticlassClassification.Trainers.Sdca(
|
||||
new SdcaMulticlassClassificationTrainer.Options { MaximumNumberOfIterations = 100, Shuffle = true, NumberOfThreads = 1, }));
|
||||
new SdcaMulticlassTrainer.Options { MaximumNumberOfIterations = 100, Shuffle = true, NumberOfThreads = 1, }));
|
||||
|
||||
var model = pipeline.Fit(data).GetModelFor(TransformerScope.Scoring);
|
||||
var engine = ml.Model.CreatePredictionEngine<IrisDataNoLabel, IrisPredictionNotCasted>(model);
|
||||
|
|
|
@ -27,7 +27,7 @@ namespace Microsoft.ML.Tests.Scenarios.Api
|
|||
var pipeline = ml.Transforms.Text.FeaturizeText("Features", "SentimentText")
|
||||
.AppendCacheCheckpoint(ml)
|
||||
.Append(ml.BinaryClassification.Trainers.SdcaNonCalibrated(
|
||||
new SdcaNonCalibratedBinaryClassificationTrainer.Options { NumberOfThreads = 1 }));
|
||||
new SdcaNonCalibratedBinaryTrainer.Options { NumberOfThreads = 1 }));
|
||||
|
||||
// Train.
|
||||
var model = pipeline.Fit(data);
|
||||
|
|
|
@ -28,7 +28,7 @@ namespace Microsoft.ML.Tests.Scenarios.Api
|
|||
var pipeline = ml.Transforms.Text.FeaturizeText("Features", "SentimentText")
|
||||
.AppendCacheCheckpoint(ml)
|
||||
.Append(ml.BinaryClassification.Trainers.SdcaNonCalibrated(
|
||||
new SdcaNonCalibratedBinaryClassificationTrainer.Options { NumberOfThreads = 1 }));
|
||||
new SdcaNonCalibratedBinaryTrainer.Options { NumberOfThreads = 1 }));
|
||||
|
||||
// Train.
|
||||
var model = pipeline.Fit(data);
|
||||
|
|
|
@ -32,7 +32,7 @@ namespace Microsoft.ML.Tests.Scenarios.Api
|
|||
|
||||
// Train the first predictor.
|
||||
var trainer = ml.BinaryClassification.Trainers.SdcaNonCalibrated(
|
||||
new SdcaNonCalibratedBinaryClassificationTrainer.Options { NumberOfThreads = 1 });
|
||||
new SdcaNonCalibratedBinaryTrainer.Options { NumberOfThreads = 1 });
|
||||
|
||||
var firstModel = trainer.Fit(trainData);
|
||||
|
||||
|
|
|
@ -33,7 +33,7 @@ namespace Microsoft.ML.Scenarios
|
|||
.Append(mlContext.Transforms.Conversion.MapValueToKey("Label"))
|
||||
.AppendCacheCheckpoint(mlContext)
|
||||
.Append(mlContext.MulticlassClassification.Trainers.Sdca(
|
||||
new SdcaMulticlassClassificationTrainer.Options { NumberOfThreads = 1 }));
|
||||
new SdcaMulticlassTrainer.Options { NumberOfThreads = 1 }));
|
||||
|
||||
// Read training and test data sets
|
||||
string dataPath = GetDataPath(TestDatasets.iris.trainFilename);
|
||||
|
|
|
@ -38,7 +38,7 @@ namespace Microsoft.ML.Scenarios
|
|||
.Append(mlContext.Transforms.Conversion.MapValueToKey("Label", "IrisPlantType"), TransformerScope.TrainTest)
|
||||
.AppendCacheCheckpoint(mlContext)
|
||||
.Append(mlContext.MulticlassClassification.Trainers.Sdca(
|
||||
new SdcaMulticlassClassificationTrainer.Options { NumberOfThreads = 1 }))
|
||||
new SdcaMulticlassTrainer.Options { NumberOfThreads = 1 }))
|
||||
.Append(mlContext.Transforms.Conversion.MapKeyToValue(("Plant", "PredictedLabel")));
|
||||
|
||||
// Train the pipeline
|
||||
|
|
|
@ -104,7 +104,7 @@ namespace Microsoft.ML.Scenarios
|
|||
|
||||
// Pipeline
|
||||
var pipeline = mlContext.MulticlassClassification.Trainers.OneVersusAll(
|
||||
mlContext.BinaryClassification.Trainers.FastTree(new FastTreeBinaryClassificationTrainer.Options { NumberOfThreads = 1 }),
|
||||
mlContext.BinaryClassification.Trainers.FastTree(new FastTreeBinaryTrainer.Options { NumberOfThreads = 1 }),
|
||||
useProbabilities: false);
|
||||
|
||||
var model = pipeline.Fit(data);
|
||||
|
|
|
@ -31,7 +31,7 @@ namespace Microsoft.ML.Scenarios
|
|||
.Append(mlContext.Transforms.Conversion.MapValueToKey("Label"))
|
||||
.AppendCacheCheckpoint(mlContext)
|
||||
.Append(mlContext.MulticlassClassification.Trainers.Sdca(
|
||||
new SdcaMulticlassClassificationTrainer.Options { NumberOfThreads = 1 }));
|
||||
new SdcaMulticlassTrainer.Options { NumberOfThreads = 1 }));
|
||||
|
||||
// Read training and test data sets
|
||||
string dataPath = GetDataPath(TestDatasets.iris.trainFilename);
|
||||
|
|
|
@ -56,7 +56,7 @@ namespace Microsoft.ML.Tests.TrainerEstimators
|
|||
{
|
||||
(IEstimator<ITransformer> pipe, IDataView dataView) = GetBinaryClassificationPipeline();
|
||||
|
||||
pipe = pipe.Append(ML.BinaryClassification.Trainers.LogisticRegression(new LogisticRegressionBinaryClassificationTrainer.Options { ShowTrainingStatistics = true }));
|
||||
pipe = pipe.Append(ML.BinaryClassification.Trainers.LogisticRegression(new LogisticRegressionBinaryTrainer.Options { ShowTrainingStatistics = true }));
|
||||
var transformerChain = pipe.Fit(dataView) as TransformerChain<BinaryPredictionTransformer<CalibratedModelParametersBase<LinearBinaryModelParameters, PlattCalibrator>>>;
|
||||
|
||||
var linearModel = transformerChain.LastTransformer.Model.SubModel as LinearBinaryModelParameters;
|
||||
|
@ -73,7 +73,7 @@ namespace Microsoft.ML.Tests.TrainerEstimators
|
|||
(IEstimator<ITransformer> pipe, IDataView dataView) = GetBinaryClassificationPipeline();
|
||||
|
||||
pipe = pipe.Append(ML.BinaryClassification.Trainers.LogisticRegression(
|
||||
new LogisticRegressionBinaryClassificationTrainer.Options
|
||||
new LogisticRegressionBinaryTrainer.Options
|
||||
{
|
||||
ShowTrainingStatistics = true,
|
||||
ComputeStandardDeviation = new ComputeLRTrainingStdThroughMkl(),
|
||||
|
|
|
@ -42,7 +42,7 @@ namespace Microsoft.ML.Tests.TrainerEstimators
|
|||
{
|
||||
var (pipeline, data) = GetMulticlassPipeline();
|
||||
var sdcaTrainer = ML.BinaryClassification.Trainers.SdcaNonCalibrated(
|
||||
new SdcaNonCalibratedBinaryClassificationTrainer.Options { MaximumNumberOfIterations = 100, Shuffle = true, NumberOfThreads = 1 });
|
||||
new SdcaNonCalibratedBinaryTrainer.Options { MaximumNumberOfIterations = 100, Shuffle = true, NumberOfThreads = 1 });
|
||||
|
||||
pipeline = pipeline.Append(ML.MulticlassClassification.Trainers.OneVersusAll(sdcaTrainer, useProbabilities: false))
|
||||
.Append(new KeyToValueMappingEstimator(Env, "PredictedLabel"));
|
||||
|
@ -60,7 +60,7 @@ namespace Microsoft.ML.Tests.TrainerEstimators
|
|||
var (pipeline, data) = GetMulticlassPipeline();
|
||||
|
||||
var sdcaTrainer = ML.BinaryClassification.Trainers.SdcaNonCalibrated(
|
||||
new SdcaNonCalibratedBinaryClassificationTrainer.Options { MaximumNumberOfIterations = 100, Shuffle = true, NumberOfThreads = 1 });
|
||||
new SdcaNonCalibratedBinaryTrainer.Options { MaximumNumberOfIterations = 100, Shuffle = true, NumberOfThreads = 1 });
|
||||
|
||||
pipeline = pipeline.Append(ML.MulticlassClassification.Trainers.PairwiseCoupling(sdcaTrainer))
|
||||
.Append(ML.Transforms.Conversion.MapKeyToValue("PredictedLabel"));
|
||||
|
@ -83,7 +83,7 @@ namespace Microsoft.ML.Tests.TrainerEstimators
|
|||
var data = loader.Load(GetDataPath(TestDatasets.irisData.trainFilename));
|
||||
|
||||
var sdcaTrainer = ML.BinaryClassification.Trainers.SdcaNonCalibrated(
|
||||
new SdcaNonCalibratedBinaryClassificationTrainer.Options {
|
||||
new SdcaNonCalibratedBinaryTrainer.Options {
|
||||
LabelColumnName = "Label",
|
||||
FeatureColumnName = "Vars",
|
||||
MaximumNumberOfIterations = 100,
|
||||
|
|
|
@ -24,11 +24,11 @@ namespace Microsoft.ML.Tests.TrainerEstimators
|
|||
.Fit(data.AsDynamic).Transform(data.AsDynamic);
|
||||
|
||||
var binaryTrainer = ML.BinaryClassification.Trainers.SdcaCalibrated(
|
||||
new SdcaCalibratedBinaryClassificationTrainer.Options { ConvergenceTolerance = 1e-2f, MaximumNumberOfIterations = 10 });
|
||||
new SdcaCalibratedBinaryTrainer.Options { ConvergenceTolerance = 1e-2f, MaximumNumberOfIterations = 10 });
|
||||
TestEstimatorCore(binaryTrainer, binaryData);
|
||||
|
||||
var nonCalibratedBinaryTrainer = ML.BinaryClassification.Trainers.SdcaNonCalibrated(
|
||||
new SdcaNonCalibratedBinaryClassificationTrainer.Options { ConvergenceTolerance = 1e-2f, MaximumNumberOfIterations = 10 });
|
||||
new SdcaNonCalibratedBinaryTrainer.Options { ConvergenceTolerance = 1e-2f, MaximumNumberOfIterations = 10 });
|
||||
TestEstimatorCore(nonCalibratedBinaryTrainer, binaryData);
|
||||
|
||||
var regressionTrainer = ML.Regression.Trainers.Sdca(
|
||||
|
@ -38,7 +38,7 @@ namespace Microsoft.ML.Tests.TrainerEstimators
|
|||
var mcData = ML.Transforms.Conversion.MapValueToKey("Label").Fit(data.AsDynamic).Transform(data.AsDynamic);
|
||||
|
||||
var mcTrainer = ML.MulticlassClassification.Trainers.Sdca(
|
||||
new SdcaMulticlassClassificationTrainer.Options { ConvergenceTolerance = 1e-2f, MaximumNumberOfIterations = 10 });
|
||||
new SdcaMulticlassTrainer.Options { ConvergenceTolerance = 1e-2f, MaximumNumberOfIterations = 10 });
|
||||
TestEstimatorCore(mcTrainer, mcData);
|
||||
|
||||
Done();
|
||||
|
|
|
@ -21,7 +21,7 @@ namespace Microsoft.ML.Tests.TrainerEstimators
|
|||
|
||||
// Define a tree model whose trees will be extracted to construct a tree featurizer.
|
||||
var trainer = ML.BinaryClassification.Trainers.FastTree(
|
||||
new FastTreeBinaryClassificationTrainer.Options
|
||||
new FastTreeBinaryTrainer.Options
|
||||
{
|
||||
NumberOfThreads = 1,
|
||||
NumberOfTrees = 10,
|
||||
|
|
|
@ -29,7 +29,7 @@ namespace Microsoft.ML.Tests.TrainerEstimators
|
|||
var (pipe, dataView) = GetBinaryClassificationPipeline();
|
||||
|
||||
var trainer = ML.BinaryClassification.Trainers.FastTree(
|
||||
new FastTreeBinaryClassificationTrainer.Options
|
||||
new FastTreeBinaryTrainer.Options
|
||||
{
|
||||
NumberOfThreads = 1,
|
||||
NumberOfTrees = 10,
|
||||
|
@ -70,7 +70,7 @@ namespace Microsoft.ML.Tests.TrainerEstimators
|
|||
{
|
||||
var (pipe, dataView) = GetBinaryClassificationPipeline();
|
||||
|
||||
var trainer = new GamBinaryClassificationTrainer(Env, new GamBinaryClassificationTrainer.Options
|
||||
var trainer = new GamBinaryTrainer(Env, new GamBinaryTrainer.Options
|
||||
{
|
||||
GainConfidenceLevel = 0,
|
||||
NumberOfIterations = 15,
|
||||
|
@ -90,7 +90,7 @@ namespace Microsoft.ML.Tests.TrainerEstimators
|
|||
var (pipe, dataView) = GetBinaryClassificationPipeline();
|
||||
|
||||
var trainer = ML.BinaryClassification.Trainers.FastForest(
|
||||
new FastForestBinaryClassificationTrainer.Options
|
||||
new FastForestBinaryTrainer.Options
|
||||
{
|
||||
NumberOfLeaves = 10,
|
||||
NumberOfTrees = 20,
|
||||
|
@ -294,7 +294,7 @@ namespace Microsoft.ML.Tests.TrainerEstimators
|
|||
var mlContext = new MLContext(seed: 0);
|
||||
var dataView = mlContext.Data.LoadFromEnumerable(dataList);
|
||||
int numberOfTrainingIterations = 3;
|
||||
var gbmTrainer = new LightGbmMulticlassClassificationTrainer(mlContext, new Options
|
||||
var gbmTrainer = new LightGbmMulticlassTrainer(mlContext, new Options
|
||||
{
|
||||
NumberOfIterations = numberOfTrainingIterations,
|
||||
MinimumExampleCountPerGroup = 1,
|
||||
|
|
Загрузка…
Ссылка в новой задаче