Generate built-in SweepableEstimator classes for all available estimators (#6125)

This commit is contained in:
Xiaoyun Zhang 2022-03-28 13:40:36 -07:00 коммит произвёл GitHub
Родитель bfba5d9836
Коммит a758217121
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
53 изменённых файлов: 3230 добавлений и 24 удалений

Просмотреть файл

@ -3,6 +3,10 @@
// See the LICENSE file in the project root for more information.
using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics.Contracts;
using Microsoft.ML.AutoML.CodeGen;
using Microsoft.ML.Data;
using Microsoft.ML.SearchSpace;
@ -289,5 +293,170 @@ namespace Microsoft.ML.AutoML
{
return new SweepableEstimator((MLContext context, Parameter param) => factory(context, param.AsType<T>()), ss);
}
/// <summary>
/// Builds the sweepable binary-classification trainers selected by the <c>use*</c> flags.
/// </summary>
/// <remarks>
/// Estimators are added in the order FastTree, FastForest, LightGbm, Lbfgs, Sdca.
/// Any option object supplied by the caller is modified in place: its label, feature and
/// example-weight column names are overwritten with the values passed to this method.
/// When an option or search space is not supplied, a default-constructed one is used.
/// </remarks>
/// <returns>One estimator per enabled trainer.</returns>
internal SweepableEstimator[] BinaryClassification(string labelColumnName = DefaultColumnNames.Label, string featureColumnName = DefaultColumnNames.Features, string exampleWeightColumnName = null, bool useFastForest = true, bool useLgbm = true, bool useFastTree = true, bool useLbfgs = true, bool useSdca = true,
FastTreeOption fastTreeOption = null, LgbmOption lgbmOption = null, FastForestOption fastForestOption = null, LbfgsOption lbfgsOption = null, SdcaOption sdcaOption = null,
SearchSpace<FastTreeOption> fastTreeSearchSpace = null, SearchSpace<LgbmOption> lgbmSearchSpace = null, SearchSpace<FastForestOption> fastForestSearchSpace = null, SearchSpace<LbfgsOption> lbfgsSearchSpace = null, SearchSpace<SdcaOption> sdcaSearchSpace = null)
{
    var estimators = new List<SweepableEstimator>();

    if (useFastTree)
    {
        if (fastTreeOption == null)
        {
            fastTreeOption = new FastTreeOption();
        }

        fastTreeOption.LabelColumnName = labelColumnName;
        fastTreeOption.FeatureColumnName = featureColumnName;
        fastTreeOption.ExampleWeightColumnName = exampleWeightColumnName;

        var fastTree = SweepableEstimatorFactory.CreateFastTreeBinary(fastTreeOption, fastTreeSearchSpace ?? new SearchSpace<FastTreeOption>());
        estimators.Add(fastTree);
    }

    if (useFastForest)
    {
        if (fastForestOption == null)
        {
            fastForestOption = new FastForestOption();
        }

        fastForestOption.LabelColumnName = labelColumnName;
        fastForestOption.FeatureColumnName = featureColumnName;
        fastForestOption.ExampleWeightColumnName = exampleWeightColumnName;

        var fastForest = SweepableEstimatorFactory.CreateFastForestBinary(fastForestOption, fastForestSearchSpace ?? new SearchSpace<FastForestOption>());
        estimators.Add(fastForest);
    }

    if (useLgbm)
    {
        if (lgbmOption == null)
        {
            lgbmOption = new LgbmOption();
        }

        lgbmOption.LabelColumnName = labelColumnName;
        lgbmOption.FeatureColumnName = featureColumnName;
        lgbmOption.ExampleWeightColumnName = exampleWeightColumnName;

        var lightGbm = SweepableEstimatorFactory.CreateLightGbmBinary(lgbmOption, lgbmSearchSpace ?? new SearchSpace<LgbmOption>());
        estimators.Add(lightGbm);
    }

    if (useLbfgs)
    {
        if (lbfgsOption == null)
        {
            lbfgsOption = new LbfgsOption();
        }

        lbfgsOption.LabelColumnName = labelColumnName;
        lbfgsOption.FeatureColumnName = featureColumnName;
        lbfgsOption.ExampleWeightColumnName = exampleWeightColumnName;

        var lbfgs = SweepableEstimatorFactory.CreateLbfgsLogisticRegressionBinary(lbfgsOption, lbfgsSearchSpace ?? new SearchSpace<LbfgsOption>());
        estimators.Add(lbfgs);
    }

    if (useSdca)
    {
        if (sdcaOption == null)
        {
            sdcaOption = new SdcaOption();
        }

        sdcaOption.LabelColumnName = labelColumnName;
        sdcaOption.FeatureColumnName = featureColumnName;
        sdcaOption.ExampleWeightColumnName = exampleWeightColumnName;

        var sdca = SweepableEstimatorFactory.CreateSdcaLogisticRegressionBinary(sdcaOption, sdcaSearchSpace ?? new SearchSpace<SdcaOption>());
        estimators.Add(sdca);
    }

    return estimators.ToArray();
}
/// <summary>
/// Builds the sweepable multiclass-classification trainers selected by the <c>use*</c> flags.
/// </summary>
/// <remarks>
/// Estimators are added in the order FastTree (OVA), FastForest (OVA), LightGbm, Lbfgs and Sdca.
/// Lbfgs contributes two estimators (logistic-regression OVA, then maximum entropy) and Sdca
/// contributes two (maximum entropy, then logistic-regression OVA); both estimators of a pair
/// receive the same option object.
/// Any option object supplied by the caller is modified in place: its label, feature and
/// example-weight column names are overwritten with the values passed to this method.
/// </remarks>
/// <returns>The estimators for every enabled trainer.</returns>
internal SweepableEstimator[] MultiClassification(string labelColumnName = DefaultColumnNames.Label, string featureColumnName = DefaultColumnNames.Features, string exampleWeightColumnName = null, bool useFastForest = true, bool useLgbm = true, bool useFastTree = true, bool useLbfgs = true, bool useSdca = true,
FastTreeOption fastTreeOption = null, LgbmOption lgbmOption = null, FastForestOption fastForestOption = null, LbfgsOption lbfgsOption = null, SdcaOption sdcaOption = null,
SearchSpace<FastTreeOption> fastTreeSearchSpace = null, SearchSpace<LgbmOption> lgbmSearchSpace = null, SearchSpace<FastForestOption> fastForestSearchSpace = null, SearchSpace<LbfgsOption> lbfgsSearchSpace = null, SearchSpace<SdcaOption> sdcaSearchSpace = null)
{
    var estimators = new List<SweepableEstimator>();

    if (useFastTree)
    {
        if (fastTreeOption == null)
        {
            fastTreeOption = new FastTreeOption();
        }

        fastTreeOption.LabelColumnName = labelColumnName;
        fastTreeOption.FeatureColumnName = featureColumnName;
        fastTreeOption.ExampleWeightColumnName = exampleWeightColumnName;

        var fastTreeOva = SweepableEstimatorFactory.CreateFastTreeOva(fastTreeOption, fastTreeSearchSpace ?? new SearchSpace<FastTreeOption>());
        estimators.Add(fastTreeOva);
    }

    if (useFastForest)
    {
        if (fastForestOption == null)
        {
            fastForestOption = new FastForestOption();
        }

        fastForestOption.LabelColumnName = labelColumnName;
        fastForestOption.FeatureColumnName = featureColumnName;
        fastForestOption.ExampleWeightColumnName = exampleWeightColumnName;

        var fastForestOva = SweepableEstimatorFactory.CreateFastForestOva(fastForestOption, fastForestSearchSpace ?? new SearchSpace<FastForestOption>());
        estimators.Add(fastForestOva);
    }

    if (useLgbm)
    {
        if (lgbmOption == null)
        {
            lgbmOption = new LgbmOption();
        }

        lgbmOption.LabelColumnName = labelColumnName;
        lgbmOption.FeatureColumnName = featureColumnName;
        lgbmOption.ExampleWeightColumnName = exampleWeightColumnName;

        var lightGbm = SweepableEstimatorFactory.CreateLightGbmMulti(lgbmOption, lgbmSearchSpace ?? new SearchSpace<LgbmOption>());
        estimators.Add(lightGbm);
    }

    if (useLbfgs)
    {
        if (lbfgsOption == null)
        {
            lbfgsOption = new LbfgsOption();
        }

        lbfgsOption.LabelColumnName = labelColumnName;
        lbfgsOption.FeatureColumnName = featureColumnName;
        lbfgsOption.ExampleWeightColumnName = exampleWeightColumnName;

        // The fallback search space is evaluated once per estimator, so each one gets its own
        // default instance when the caller did not supply a search space.
        var lbfgsOva = SweepableEstimatorFactory.CreateLbfgsLogisticRegressionOva(lbfgsOption, lbfgsSearchSpace ?? new SearchSpace<LbfgsOption>());
        estimators.Add(lbfgsOva);
        var lbfgsMaxEntropy = SweepableEstimatorFactory.CreateLbfgsMaximumEntropyMulti(lbfgsOption, lbfgsSearchSpace ?? new SearchSpace<LbfgsOption>());
        estimators.Add(lbfgsMaxEntropy);
    }

    if (useSdca)
    {
        if (sdcaOption == null)
        {
            sdcaOption = new SdcaOption();
        }

        sdcaOption.LabelColumnName = labelColumnName;
        sdcaOption.FeatureColumnName = featureColumnName;
        sdcaOption.ExampleWeightColumnName = exampleWeightColumnName;

        // As above: one fallback search-space instance per estimator.
        var sdcaMaxEntropy = SweepableEstimatorFactory.CreateSdcaMaximumEntropyMulti(sdcaOption, sdcaSearchSpace ?? new SearchSpace<SdcaOption>());
        estimators.Add(sdcaMaxEntropy);
        var sdcaOva = SweepableEstimatorFactory.CreateSdcaLogisticRegressionOva(sdcaOption, sdcaSearchSpace ?? new SearchSpace<SdcaOption>());
        estimators.Add(sdcaOva);
    }

    return estimators.ToArray();
}
/// <summary>
/// Builds the sweepable regression trainers selected by the <c>use*</c> flags.
/// </summary>
/// <remarks>
/// Estimators are added in the order FastTree, FastForest, LightGbm, Lbfgs (Poisson regression)
/// and Sdca. FastTree contributes two estimators (regression, then Tweedie regression) that
/// receive the same option object.
/// Any option object supplied by the caller is modified in place: its label, feature and
/// example-weight column names are overwritten with the values passed to this method.
/// </remarks>
/// <returns>The estimators for every enabled trainer.</returns>
internal SweepableEstimator[] Regression(string labelColumnName = DefaultColumnNames.Label, string featureColumnName = DefaultColumnNames.Features, string exampleWeightColumnName = null, bool useFastForest = true, bool useLgbm = true, bool useFastTree = true, bool useLbfgs = true, bool useSdca = true,
FastTreeOption fastTreeOption = null, LgbmOption lgbmOption = null, FastForestOption fastForestOption = null, LbfgsOption lbfgsOption = null, SdcaOption sdcaOption = null,
SearchSpace<FastTreeOption> fastTreeSearchSpace = null, SearchSpace<LgbmOption> lgbmSearchSpace = null, SearchSpace<FastForestOption> fastForestSearchSpace = null, SearchSpace<LbfgsOption> lbfgsSearchSpace = null, SearchSpace<SdcaOption> sdcaSearchSpace = null)
{
    var estimators = new List<SweepableEstimator>();

    if (useFastTree)
    {
        if (fastTreeOption == null)
        {
            fastTreeOption = new FastTreeOption();
        }

        fastTreeOption.LabelColumnName = labelColumnName;
        fastTreeOption.FeatureColumnName = featureColumnName;
        fastTreeOption.ExampleWeightColumnName = exampleWeightColumnName;

        // The fallback search space is evaluated once per estimator, so each one gets its own
        // default instance when the caller did not supply a search space.
        var fastTree = SweepableEstimatorFactory.CreateFastTreeRegression(fastTreeOption, fastTreeSearchSpace ?? new SearchSpace<FastTreeOption>());
        estimators.Add(fastTree);
        var fastTreeTweedie = SweepableEstimatorFactory.CreateFastTreeTweedieRegression(fastTreeOption, fastTreeSearchSpace ?? new SearchSpace<FastTreeOption>());
        estimators.Add(fastTreeTweedie);
    }

    if (useFastForest)
    {
        if (fastForestOption == null)
        {
            fastForestOption = new FastForestOption();
        }

        fastForestOption.LabelColumnName = labelColumnName;
        fastForestOption.FeatureColumnName = featureColumnName;
        fastForestOption.ExampleWeightColumnName = exampleWeightColumnName;

        var fastForest = SweepableEstimatorFactory.CreateFastForestRegression(fastForestOption, fastForestSearchSpace ?? new SearchSpace<FastForestOption>());
        estimators.Add(fastForest);
    }

    if (useLgbm)
    {
        if (lgbmOption == null)
        {
            lgbmOption = new LgbmOption();
        }

        lgbmOption.LabelColumnName = labelColumnName;
        lgbmOption.FeatureColumnName = featureColumnName;
        lgbmOption.ExampleWeightColumnName = exampleWeightColumnName;

        var lightGbm = SweepableEstimatorFactory.CreateLightGbmRegression(lgbmOption, lgbmSearchSpace ?? new SearchSpace<LgbmOption>());
        estimators.Add(lightGbm);
    }

    if (useLbfgs)
    {
        if (lbfgsOption == null)
        {
            lbfgsOption = new LbfgsOption();
        }

        lbfgsOption.LabelColumnName = labelColumnName;
        lbfgsOption.FeatureColumnName = featureColumnName;
        lbfgsOption.ExampleWeightColumnName = exampleWeightColumnName;

        var lbfgsPoisson = SweepableEstimatorFactory.CreateLbfgsPoissonRegressionRegression(lbfgsOption, lbfgsSearchSpace ?? new SearchSpace<LbfgsOption>());
        estimators.Add(lbfgsPoisson);
    }

    if (useSdca)
    {
        if (sdcaOption == null)
        {
            sdcaOption = new SdcaOption();
        }

        sdcaOption.LabelColumnName = labelColumnName;
        sdcaOption.FeatureColumnName = featureColumnName;
        sdcaOption.ExampleWeightColumnName = exampleWeightColumnName;

        var sdca = SweepableEstimatorFactory.CreateSdcaRegression(sdcaOption, sdcaSearchSpace ?? new SearchSpace<SdcaOption>());
        estimators.Add(sdca);
    }

    return estimators.ToArray();
}
}
}

Просмотреть файл

@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
namespace Microsoft.ML.AutoML
@ -27,7 +28,34 @@ namespace Microsoft.ML.AutoML
/// <summary>
/// Creates a new <see cref="SweepableEstimatorPipeline"/> that contains <paramref name="estimator"/>
/// followed by <paramref name="estimator1"/>.
/// </summary>
/// <param name="estimator">The sweepable estimator that becomes the first step.</param>
/// <param name="estimator1">The estimator appended as the second step.</param>
/// <returns>The newly created two-step pipeline.</returns>
public static SweepableEstimatorPipeline Append(this SweepableEstimator estimator, IEstimator<ITransformer> estimator1)
{
    // A leading `return estimator.Append(estimator1);` used to sit in front of this statement.
    // It made the pipeline construction below unreachable and, given the argument types,
    // appears to resolve back to this very extension method (unless SweepableEstimator
    // declares a matching instance method), i.e. unbounded recursion. Only the explicit
    // pipeline construction is kept.
    return new SweepableEstimatorPipeline().Append(estimator).Append(estimator1);
}
/// <summary>
/// Starts a <see cref="MultiModelPipeline"/> with a fixed (non-swept) estimator and appends
/// the given sweepable estimators to it.
/// </summary>
/// <param name="estimator">The estimator placed first in the pipeline.</param>
/// <param name="estimators">The sweepable estimators appended after it.</param>
/// <returns>The resulting multi-model pipeline.</returns>
public static MultiModelPipeline Append(this IEstimator<ITransformer> estimator, params SweepableEstimator[] estimators)
{
    // Wrap the plain estimator in a SweepableEstimator whose factory ignores its arguments
    // and always hands back the same instance, paired with a default-constructed search space.
    var wrapped = new SweepableEstimator((context, parameter) => estimator, new SearchSpace.SearchSpace());

    return new MultiModelPipeline().Append(wrapped).Append(estimators);
}
/// <summary>
/// Converts a <see cref="SweepableEstimatorPipeline"/> into a <see cref="MultiModelPipeline"/>
/// and appends the given sweepable estimators to it.
/// </summary>
/// <param name="pipeline">The pipeline whose estimators are copied, in order, into the result.</param>
/// <param name="estimators">The sweepable estimators appended after the copied ones.</param>
/// <returns>The resulting multi-model pipeline.</returns>
public static MultiModelPipeline Append(this SweepableEstimatorPipeline pipeline, params SweepableEstimator[] estimators)
{
    var result = new MultiModelPipeline();

    // Append returns a pipeline, so the accumulator is re-assigned on every step.
    foreach (var step in pipeline.Estimators)
    {
        result = result.Append(step);
    }

    return result.Append(estimators);
}
/// <summary>
/// Starts a <see cref="MultiModelPipeline"/> with a single sweepable estimator and appends
/// the given sweepable estimators to it.
/// </summary>
/// <param name="estimator">The sweepable estimator placed first in the pipeline.</param>
/// <param name="estimators">The sweepable estimators appended after it.</param>
/// <returns>The resulting multi-model pipeline.</returns>
public static MultiModelPipeline Append(this SweepableEstimator estimator, params SweepableEstimator[] estimators)
{
    var seeded = new MultiModelPipeline().Append(estimator);
    return seeded.Append(estimators);
}
}
}

Просмотреть файл

@ -3,12 +3,46 @@
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Threading;
namespace Microsoft.ML.AutoML
{
/// <summary>
/// Small helper routines shared by the AutoML code in this assembly.
/// </summary>
internal static class AutoMlUtils
{
    // Name of the environment variable read by GetNumberOfThreadFromEnvrionment.
    private const string MLNetMaxThread = "MLNET_MAX_THREAD";

    /// <summary>
    /// One <see cref="System.Random"/> instance per thread, created on first use.
    /// </summary>
    public static readonly ThreadLocal<Random> Random = new ThreadLocal<Random>(() => new Random());

    /// <summary>
    /// Returns the integer stored in the MLNET_MAX_THREAD environment variable, or null when
    /// the variable is not set or does not contain a valid integer.
    /// </summary>
    public static int? GetNumberOfThreadFromEnvrionment()
    {
        var raw = Environment.GetEnvironmentVariable(MLNetMaxThread);

        // An environment variable is machine-readable configuration, so it is parsed with the
        // invariant culture instead of whatever culture the current thread happens to use.
        if (int.TryParse(raw, System.Globalization.NumberStyles.Integer, System.Globalization.CultureInfo.InvariantCulture, out var numberOfThread))
        {
            return numberOfThread;
        }

        return null;
    }

    /// <summary>
    /// Pairs up <paramref name="inputs"/> and <paramref name="outputs"/> by position.
    /// </summary>
    /// <param name="inputs">Input column names.</param>
    /// <param name="outputs">Output column names; must have the same length as <paramref name="inputs"/>.</param>
    /// <returns>One <see cref="InputOutputColumnPair"/> per position.</returns>
    /// <exception cref="ArgumentNullException">Either array is null.</exception>
    /// <exception cref="ArgumentException">The arrays have different lengths.</exception>
    public static InputOutputColumnPair[] CreateInputOutputColumnPairsFromStrings(string[] inputs, string[] outputs)
    {
        if (inputs == null)
        {
            throw new ArgumentNullException(nameof(inputs));
        }

        if (outputs == null)
        {
            throw new ArgumentNullException(nameof(outputs));
        }

        if (inputs.Length != outputs.Length)
        {
            throw new ArgumentException("inputs and outputs count must match");
        }

        // The result size is known up front, so fill an array directly.
        var pairs = new InputOutputColumnPair[inputs.Length];
        for (int i = 0; i < pairs.Length; ++i)
        {
            // The output name is passed first, then the input name.
            pairs[i] = new InputOutputColumnPair(outputs[i], inputs[i]);
        }

        return pairs;
    }
}
}

Просмотреть файл

@ -1,7 +1,7 @@
{
"EstimatorFactoryGenerator": false,
"CodeGenCatalogGenerator": false,
"SweepableEstimatorFactory": true,
"EstimatorTypeGenerator": true,
"SearchSpaceGenerator": true,
"SweepableEstimatorGenerator": false
"SweepableEstimatorGenerator": true
}

Просмотреть файл

@ -14,11 +14,13 @@
<PrivateAssets>all</PrivateAssets>
</ProjectReference>
<ProjectReference Include="..\Microsoft.ML.CpuMath\Microsoft.ML.CpuMath.csproj" />
<ProjectReference Include="..\Microsoft.ML.OnnxTransformer\Microsoft.ML.OnnxTransformer.csproj" />
<ProjectReference Include="..\Microsoft.ML.SearchSpace\Microsoft.ML.SearchSpace.csproj">
<PrivateAssets>all</PrivateAssets>
<IncludeInNuget>true</IncludeInNuget>
</ProjectReference>
<PackageReference Include="Microsoft.CodeAnalysis.CSharp" Version="$(MicrosoftCodeAnalysisCSharpVersion)" />
<ProjectReference Include="..\Microsoft.ML.TimeSeries\Microsoft.ML.TimeSeries.csproj" />
<ProjectReference Include="..\Microsoft.ML.Vision\Microsoft.ML.Vision.csproj" />
<ProjectReference Include="..\Microsoft.ML.ImageAnalytics\Microsoft.ML.ImageAnalytics.csproj" />
<ProjectReference Include="..\Microsoft.ML.LightGbm\Microsoft.ML.LightGbm.csproj" />
@ -43,13 +45,13 @@
<Target DependsOnTargets="ResolveReferences" Name="CopyProjectReferencesToPackage">
<ItemGroup>
<!--Include DLLs of Project References-->
<BuildOutputInPackage Include="@(ReferenceCopyLocalPaths->WithMetadataValue('ReferenceSourceTarget', 'ProjectReference')->WithMetadataValue('IncludeInNuget','true'))"/>
<BuildOutputInPackage Include="@(ReferenceCopyLocalPaths-&gt;WithMetadataValue('ReferenceSourceTarget', 'ProjectReference')-&gt;WithMetadataValue('IncludeInNuget','true'))" />
<!--Include PDBs of Project References-->
<BuildOutputInPackage Include="@(ReferenceCopyLocalPaths->WithMetadataValue('ReferenceSourceTarget', 'ProjectReference')->WithMetadataValue('IncludeInNuget','true')->Replace('.dll', '.pdb'))"/>
<BuildOutputInPackage Include="@(ReferenceCopyLocalPaths-&gt;WithMetadataValue('ReferenceSourceTarget', 'ProjectReference')-&gt;WithMetadataValue('IncludeInNuget','true')-&gt;Replace('.dll', '.pdb'))" />
<!--Include PDBs for Native binaries-->
<!--The path needed to be hardcoded for this to work on our publishing CI-->
<BuildOutputInPackage Condition="Exists('$(PackageAssetsPath)$(PackageIdFolderName)\runtimes\win-x86\native\LdaNative.pdb')" Include="$(PackageAssetsPath)$(PackageIdFolderName)\runtimes\win-x86\native\LdaNative.pdb" TargetPath="..\..\runtimes\win-x86\native"/>
<BuildOutputInPackage Condition="Exists('$(PackageAssetsPath)$(PackageIdFolderName)\runtimes\win-x64\native\LdaNative.pdb')" Include="$(PackageAssetsPath)$(PackageIdFolderName)\runtimes\win-x64\native\LdaNative.pdb" TargetPath="..\..\runtimes\win-x64\native"/>
<BuildOutputInPackage Condition="Exists('$(PackageAssetsPath)$(PackageIdFolderName)\runtimes\win-x86\native\LdaNative.pdb')" Include="$(PackageAssetsPath)$(PackageIdFolderName)\runtimes\win-x86\native\LdaNative.pdb" TargetPath="..\..\runtimes\win-x86\native" />
<BuildOutputInPackage Condition="Exists('$(PackageAssetsPath)$(PackageIdFolderName)\runtimes\win-x64\native\LdaNative.pdb')" Include="$(PackageAssetsPath)$(PackageIdFolderName)\runtimes\win-x64\native\LdaNative.pdb" TargetPath="..\..\runtimes\win-x64\native" />
</ItemGroup>
</Target>

Просмотреть файл

@ -0,0 +1,34 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Text;
using System.Text.Json;
using System.Text.Json.Nodes;
using System.Text.Json.Serialization;
namespace Microsoft.ML.AutoML
{
/// <summary>
/// System.Text.Json converter for <see cref="MultiModelPipeline"/>. A pipeline is written as
/// a JSON object with a "schema" string and an "estimators" value.
/// </summary>
internal class MultiModelPipelineConverter : JsonConverter<MultiModelPipeline>
{
    // Property names shared by Read and Write so the two cannot drift apart again
    // (Read previously looked up "estimator" while Write emitted "estimators").
    private const string SchemaKey = "schema";
    private const string EstimatorsKey = "estimators";

    /// <summary>
    /// Reads a pipeline from the JSON object produced by <see cref="Write"/>.
    /// </summary>
    /// <exception cref="JsonException">A required property is missing.</exception>
    public override MultiModelPipeline Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
    {
        var root = JsonNode.Parse(ref reader);
        var schemaNode = root?[SchemaKey];
        var estimatorsNode = root?[EstimatorsKey];
        if (schemaNode == null || estimatorsNode == null)
        {
            throw new JsonException($"Expected a JSON object with '{SchemaKey}' and '{EstimatorsKey}' properties.");
        }

        var schema = schemaNode.GetValue<string>();

        // JsonNode.GetValue<T> only works on JsonValue nodes; the estimators are stored as a
        // JSON object, so they have to go through the serializer (with the caller's options
        // so that any registered converters are honoured).
        var estimators = estimatorsNode.Deserialize<Dictionary<string, SweepableEstimator>>(options);

        return new MultiModelPipeline(estimators, Entity.FromExpression(schema));
    }

    /// <summary>
    /// Writes the pipeline's schema (via ToString) and its estimators as a JSON object.
    /// </summary>
    public override void Write(Utf8JsonWriter writer, MultiModelPipeline value, JsonSerializerOptions options)
    {
        var jsonObject = new JsonObject
        {
            [SchemaKey] = value.Schema.ToString(),
            [EstimatorsKey] = JsonValue.Create(value.Estimators),
        };

        jsonObject.WriteTo(writer, options);
    }
}
}

Просмотреть файл

@ -0,0 +1,37 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Text;
using System.Text.Json;
using System.Text.Json.Nodes;
using System.Text.Json.Serialization;
using Microsoft.ML.AutoML.CodeGen;
using Microsoft.ML.SearchSpace;
namespace Microsoft.ML.AutoML
{
/// <summary>
/// System.Text.Json converter for <see cref="SweepableEstimator"/>. An estimator is written
/// as a JSON object with an "estimatorType" and a "parameter" property.
/// </summary>
internal class SweepableEstimatorConverter : JsonConverter<SweepableEstimator>
{
    private const string EstimatorTypeKey = "estimatorType";
    private const string ParameterKey = "parameter";

    /// <summary>
    /// Reads an estimator from the JSON object produced by <see cref="Write"/>.
    /// </summary>
    /// <exception cref="JsonException">A required property is missing.</exception>
    public override SweepableEstimator Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
    {
        var root = JsonNode.Parse(ref reader);
        var typeNode = root?[EstimatorTypeKey];
        var parameterNode = root?[ParameterKey];
        if (typeNode == null || parameterNode == null)
        {
            throw new JsonException($"Expected a JSON object with '{EstimatorTypeKey}' and '{ParameterKey}' properties.");
        }

        // JsonNode.GetValue<T> only converts JsonValue nodes to a fixed set of primitive types,
        // so the enum and the parameter are deserialized through the serializer instead, using
        // the caller's options to mirror how Write serialized them.
        var estimatorType = typeNode.Deserialize<EstimatorType>(options);
        var parameter = parameterNode.Deserialize<Parameter>(options);

        var estimator = new SweepableEstimator(estimatorType);
        estimator.Parameter = parameter;
        return estimator;
    }

    /// <summary>
    /// Writes the estimator's type and parameter as a JSON object.
    /// </summary>
    public override void Write(Utf8JsonWriter writer, SweepableEstimator value, JsonSerializerOptions options)
    {
        var jObject = new JsonObject
        {
            [EstimatorTypeKey] = JsonValue.Create(value.EstimatorType),
            [ParameterKey] = JsonValue.Create(value.Parameter),
        };

        jObject.WriteTo(writer, options);
    }
}
}

Просмотреть файл

@ -0,0 +1,39 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Text;
using System.Text.Json;
using System.Text.Json.Nodes;
using System.Text.Json.Serialization;
using Microsoft.ML.SearchSpace;
using Newtonsoft.Json;
namespace Microsoft.ML.AutoML
{
/// <summary>
/// System.Text.Json converter for <see cref="SweepableEstimatorPipeline"/>. A pipeline is
/// written as a JSON object with a "parameter" and an "estimators" property.
/// </summary>
/// <remarks>
/// This file also imports Newtonsoft.Json, which declares its own JsonConverter&lt;T&gt; and
/// JsonException in recent versions, so those two names are fully qualified here to keep them
/// bound to System.Text.Json.
/// </remarks>
internal class SweepableEstimatorPipelineConverter : System.Text.Json.Serialization.JsonConverter<SweepableEstimatorPipeline>
{
    private const string ParameterKey = "parameter";
    private const string EstimatorsKey = "estimators";

    /// <summary>
    /// Reads a pipeline from the JSON object produced by <see cref="Write"/>.
    /// </summary>
    /// <exception cref="System.Text.Json.JsonException">A required property is missing.</exception>
    public override SweepableEstimatorPipeline Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
    {
        var root = JsonNode.Parse(ref reader);
        var parameterNode = root?[ParameterKey];
        var estimatorsNode = root?[EstimatorsKey];
        if (parameterNode == null || estimatorsNode == null)
        {
            throw new System.Text.Json.JsonException($"Expected a JSON object with '{ParameterKey}' and '{EstimatorsKey}' properties.");
        }

        // JsonNode.GetValue<T> only works on JsonValue nodes holding primitives; the estimator
        // array and the parameter are therefore deserialized through the serializer, using the
        // caller's options so that registered converters are applied.
        var parameter = parameterNode.Deserialize<Parameter>(options);
        var estimators = estimatorsNode.Deserialize<SweepableEstimator[]>(options);

        return new SweepableEstimatorPipeline(estimators, parameter);
    }

    /// <summary>
    /// Writes the pipeline's parameter and estimators as a JSON object.
    /// </summary>
    public override void Write(Utf8JsonWriter writer, SweepableEstimatorPipeline value, JsonSerializerOptions options)
    {
        var jNode = new JsonObject
        {
            [ParameterKey] = JsonValue.Create(value.Parameter),
            [EstimatorsKey] = JsonValue.Create(value.Estimators),
        };

        jNode.WriteTo(writer, options);
    }
}
}

Просмотреть файл

@ -12,6 +12,7 @@ namespace Microsoft.ML.AutoML
/// <summary>
/// Initializes the estimator with a newly created nested parameter and the
/// <c>Unknown</c> estimator type.
/// </summary>
protected Estimator()
{
    Parameter = Parameter.CreateNestedParameter();
    EstimatorType = EstimatorType.Unknown;
}
internal Estimator(EstimatorType estimatorType)

Просмотреть файл

@ -0,0 +1,18 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Text;
namespace Microsoft.ML.AutoML.CodeGen
{
/// <summary>
/// Partial class supplying the <c>BuildFromOption</c> override for the ApplyOnnxModel estimator.
/// </summary>
internal partial class ApplyOnnxModel
{
    /// <summary>
    /// Creates the ONNX scoring transform from the column names, model file and GPU settings
    /// carried by <paramref name="param"/>.
    /// </summary>
    public override IEstimator<ITransformer> BuildFromOption(MLContext context, ApplyOnnxModelOption param)
        => context.Transforms.ApplyOnnxModel(
            outputColumnName: param.OutputColumnName,
            inputColumnName: param.InputColumnName,
            modelFile: param.ModelFile,
            gpuDeviceId: param.GpuDeviceId,
            fallbackToCpu: param.FallbackToCpu);
}
}

Просмотреть файл

@ -0,0 +1,18 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Text;
namespace Microsoft.ML.AutoML.CodeGen
{
/// <summary>
/// Partial class supplying the <c>BuildFromOption</c> override for the naive calibrator.
/// </summary>
internal partial class Naive
{
    /// <summary>
    /// Creates the binary-classification naive calibrator for the label and score columns
    /// named in <paramref name="param"/>.
    /// </summary>
    public override IEstimator<ITransformer> BuildFromOption(MLContext context, NaiveOption param)
    {
        var calibrators = context.BinaryClassification.Calibrators;
        return calibrators.Naive(param.LabelColumnName, param.ScoreColumnName);
    }
}
}

Просмотреть файл

@ -0,0 +1,14 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
namespace Microsoft.ML.AutoML.CodeGen
{
/// <summary>
/// Partial class supplying the <c>BuildFromOption</c> override for column concatenation.
/// </summary>
internal partial class Concatenate
{
    /// <summary>
    /// Creates the transform that concatenates the input columns named in
    /// <paramref name="param"/> into its output column.
    /// </summary>
    public override IEstimator<ITransformer> BuildFromOption(MLContext context, ConcatOption param)
        => context.Transforms.Concatenate(param.OutputColumnName, param.InputColumnNames);
}
}

Просмотреть файл

@ -0,0 +1,63 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using Microsoft.ML.Trainers.FastTree;
namespace Microsoft.ML.AutoML.CodeGen
{
/// <summary>
/// Partial class supplying the <c>BuildFromOption</c> override for FastForest used as a
/// one-versus-all multiclass trainer.
/// </summary>
internal partial class FastForestOva
{
    /// <summary>
    /// Builds a binary FastForest trainer from <paramref name="param"/> and wraps it in a
    /// one-versus-all multiclass trainer on the same label column.
    /// </summary>
    public override IEstimator<ITransformer> BuildFromOption(MLContext context, FastForestOption param)
    {
        var option = new FastForestBinaryTrainer.Options()
        {
            NumberOfTrees = param.NumberOfTrees,
            // Forwarded for consistency with FastForestBinary, which builds the same options
            // type from the same FastForestOption; without it the value in param was ignored.
            NumberOfLeaves = param.NumberOfLeaves,
            LabelColumnName = param.LabelColumnName,
            FeatureColumnName = param.FeatureColumnName,
            ExampleWeightColumnName = param.ExampleWeightColumnName,
            FeatureFraction = param.FeatureFraction,
            // Taken from the MLNET_MAX_THREAD environment variable, not from param.
            NumberOfThreads = AutoMlUtils.GetNumberOfThreadFromEnvrionment(),
        };

        return context.MulticlassClassification.Trainers.OneVersusAll(context.BinaryClassification.Trainers.FastForest(option), labelColumnName: param.LabelColumnName);
    }
}
/// <summary>
/// Partial class supplying the <c>BuildFromOption</c> override for the FastForest regression trainer.
/// </summary>
internal partial class FastForestRegression
{
    /// <summary>
    /// Builds a FastForest regression trainer configured from <paramref name="param"/>.
    /// </summary>
    public override IEstimator<ITransformer> BuildFromOption(MLContext context, FastForestOption param)
    {
        var option = new FastForestRegressionTrainer.Options()
        {
            NumberOfTrees = param.NumberOfTrees,
            // Forwarded for consistency with FastForestBinary, which reads the same
            // FastForestOption; without it the value in param was ignored.
            NumberOfLeaves = param.NumberOfLeaves,
            FeatureFraction = param.FeatureFraction,
            LabelColumnName = param.LabelColumnName,
            FeatureColumnName = param.FeatureColumnName,
            ExampleWeightColumnName = param.ExampleWeightColumnName,
            // Taken from the MLNET_MAX_THREAD environment variable, not from param.
            NumberOfThreads = AutoMlUtils.GetNumberOfThreadFromEnvrionment(),
        };

        return context.Regression.Trainers.FastForest(option);
    }
}
/// <summary>
/// Partial class supplying the <c>BuildFromOption</c> override for the FastForest binary trainer.
/// </summary>
internal partial class FastForestBinary
{
    /// <summary>
    /// Builds a FastForest binary-classification trainer configured from <paramref name="param"/>.
    /// </summary>
    public override IEstimator<ITransformer> BuildFromOption(MLContext context, FastForestOption param)
    {
        var trainerOptions = new FastForestBinaryTrainer.Options();
        trainerOptions.NumberOfTrees = param.NumberOfTrees;
        trainerOptions.NumberOfLeaves = param.NumberOfLeaves;
        trainerOptions.FeatureFraction = param.FeatureFraction;
        trainerOptions.LabelColumnName = param.LabelColumnName;
        trainerOptions.FeatureColumnName = param.FeatureColumnName;
        trainerOptions.ExampleWeightColumnName = param.ExampleWeightColumnName;
        // Thread count comes from the MLNET_MAX_THREAD environment variable, not from param.
        trainerOptions.NumberOfThreads = AutoMlUtils.GetNumberOfThreadFromEnvrionment();

        return context.BinaryClassification.Trainers.FastForest(trainerOptions);
    }
}
}

Просмотреть файл

@ -0,0 +1,96 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using Microsoft.ML.Trainers.FastTree;
namespace Microsoft.ML.AutoML.CodeGen
{
/// <summary>
/// Partial class supplying the <c>BuildFromOption</c> override for FastTree used as a
/// one-versus-all multiclass trainer.
/// </summary>
internal partial class FastTreeOva
{
    /// <summary>
    /// Builds a binary FastTree trainer from <paramref name="param"/> and wraps it in a
    /// one-versus-all multiclass trainer on the same label column.
    /// </summary>
    public override IEstimator<ITransformer> BuildFromOption(MLContext context, FastTreeOption param)
    {
        var trainerOptions = new FastTreeBinaryTrainer.Options();
        trainerOptions.NumberOfLeaves = param.NumberOfLeaves;
        trainerOptions.NumberOfTrees = param.NumberOfTrees;
        trainerOptions.MinimumExampleCountPerLeaf = param.MinimumExampleCountPerLeaf;
        trainerOptions.LearningRate = param.LearningRate;
        trainerOptions.LabelColumnName = param.LabelColumnName;
        trainerOptions.FeatureColumnName = param.FeatureColumnName;
        trainerOptions.ExampleWeightColumnName = param.ExampleWeightColumnName;
        // Thread count comes from the MLNET_MAX_THREAD environment variable, not from param.
        trainerOptions.NumberOfThreads = AutoMlUtils.GetNumberOfThreadFromEnvrionment();
        trainerOptions.MaximumBinCountPerFeature = param.MaximumBinCountPerFeature;
        trainerOptions.FeatureFraction = param.FeatureFraction;

        var binaryTrainer = context.BinaryClassification.Trainers.FastTree(trainerOptions);
        return context.MulticlassClassification.Trainers.OneVersusAll(binaryTrainer, labelColumnName: param.LabelColumnName);
    }
}
/// <summary>
/// Partial class supplying the <c>BuildFromOption</c> override for the FastTree regression trainer.
/// </summary>
internal partial class FastTreeRegression
{
    /// <summary>
    /// Builds a FastTree regression trainer configured from <paramref name="param"/>.
    /// </summary>
    public override IEstimator<ITransformer> BuildFromOption(MLContext context, FastTreeOption param)
    {
        var trainerOptions = new FastTreeRegressionTrainer.Options();
        trainerOptions.NumberOfLeaves = param.NumberOfLeaves;
        trainerOptions.NumberOfTrees = param.NumberOfTrees;
        trainerOptions.MinimumExampleCountPerLeaf = param.MinimumExampleCountPerLeaf;
        trainerOptions.LearningRate = param.LearningRate;
        trainerOptions.LabelColumnName = param.LabelColumnName;
        trainerOptions.FeatureColumnName = param.FeatureColumnName;
        trainerOptions.ExampleWeightColumnName = param.ExampleWeightColumnName;
        // Thread count comes from the MLNET_MAX_THREAD environment variable, not from param.
        trainerOptions.NumberOfThreads = AutoMlUtils.GetNumberOfThreadFromEnvrionment();
        trainerOptions.MaximumBinCountPerFeature = param.MaximumBinCountPerFeature;
        trainerOptions.FeatureFraction = param.FeatureFraction;

        return context.Regression.Trainers.FastTree(trainerOptions);
    }
}
/// <summary>
/// Partial class supplying the <c>BuildFromOption</c> override for the FastTree Tweedie
/// regression trainer.
/// </summary>
internal partial class FastTreeTweedieRegression
{
    /// <summary>
    /// Builds a FastTree Tweedie regression trainer configured from <paramref name="param"/>.
    /// </summary>
    public override IEstimator<ITransformer> BuildFromOption(MLContext context, FastTreeOption param)
    {
        var trainerOptions = new FastTreeTweedieTrainer.Options();
        trainerOptions.NumberOfLeaves = param.NumberOfLeaves;
        trainerOptions.NumberOfTrees = param.NumberOfTrees;
        trainerOptions.MinimumExampleCountPerLeaf = param.MinimumExampleCountPerLeaf;
        trainerOptions.LearningRate = param.LearningRate;
        trainerOptions.LabelColumnName = param.LabelColumnName;
        trainerOptions.FeatureColumnName = param.FeatureColumnName;
        trainerOptions.ExampleWeightColumnName = param.ExampleWeightColumnName;
        // Thread count comes from the MLNET_MAX_THREAD environment variable, not from param.
        trainerOptions.NumberOfThreads = AutoMlUtils.GetNumberOfThreadFromEnvrionment();
        trainerOptions.MaximumBinCountPerFeature = param.MaximumBinCountPerFeature;
        trainerOptions.FeatureFraction = param.FeatureFraction;

        return context.Regression.Trainers.FastTreeTweedie(trainerOptions);
    }
}
/// <summary>
/// Partial class supplying the <c>BuildFromOption</c> override for the FastTree binary trainer.
/// </summary>
internal partial class FastTreeBinary
{
    /// <summary>
    /// Builds a FastTree binary-classification trainer configured from <paramref name="param"/>.
    /// </summary>
    public override IEstimator<ITransformer> BuildFromOption(MLContext context, FastTreeOption param)
    {
        var trainerOptions = new FastTreeBinaryTrainer.Options();
        trainerOptions.NumberOfLeaves = param.NumberOfLeaves;
        trainerOptions.NumberOfTrees = param.NumberOfTrees;
        trainerOptions.MinimumExampleCountPerLeaf = param.MinimumExampleCountPerLeaf;
        trainerOptions.LearningRate = param.LearningRate;
        trainerOptions.LabelColumnName = param.LabelColumnName;
        trainerOptions.FeatureColumnName = param.FeatureColumnName;
        trainerOptions.ExampleWeightColumnName = param.ExampleWeightColumnName;
        // Thread count comes from the MLNET_MAX_THREAD environment variable, not from param.
        trainerOptions.NumberOfThreads = AutoMlUtils.GetNumberOfThreadFromEnvrionment();
        trainerOptions.MaximumBinCountPerFeature = param.MaximumBinCountPerFeature;
        trainerOptions.FeatureFraction = param.FeatureFraction;

        return context.BinaryClassification.Trainers.FastTree(trainerOptions);
    }
}
}

Просмотреть файл

@ -0,0 +1,14 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
namespace Microsoft.ML.AutoML.CodeGen
{
/// <summary>
/// Partial class supplying the <c>BuildFromOption</c> override for text featurization.
/// </summary>
internal partial class FeaturizeText
{
    /// <summary>
    /// Creates the text-featurizing transform for the input and output columns named in
    /// <paramref name="param"/>.
    /// </summary>
    public override IEstimator<ITransformer> BuildFromOption(MLContext context, FeaturizeTextOption param)
        => context.Transforms.Text.FeaturizeText(param.OutputColumnName, param.InputColumnName);
}
}

Просмотреть файл

@ -0,0 +1,21 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
namespace Microsoft.ML.AutoML.CodeGen
{
/// <summary>
/// Partial class supplying the <c>BuildFromOption</c> override for SSA forecasting.
/// </summary>
internal partial class ForecastBySsa
{
    /// <summary>
    /// Validates the window-related sizes in <paramref name="param"/> and creates the SSA
    /// forecasting estimator.
    /// </summary>
    /// <exception cref="Exception">
    /// The series length is not greater than the window size, or the train size is not
    /// greater than twice the window size.
    /// </exception>
    public override IEstimator<ITransformer> BuildFromOption(MLContext context, SsaOption param)
    {
        // Both conditions must hold: SeriesLength > WindowSize and TrainSize > 2 * WindowSize.
        if (!(param.SeriesLength > param.WindowSize && param.TrainSize > 2 * param.WindowSize))
        {
            throw new Exception("ForecastBySsa param check error");
        }

        return context.Forecasting.ForecastBySsa(
            param.OutputColumnName,
            param.InputColumnName,
            param.WindowSize,
            param.SeriesLength,
            param.TrainSize,
            param.Horizon,
            confidenceLowerBoundColumn: param.ConfidenceLowerBoundColumn,
            confidenceUpperBoundColumn: param.ConfidenceUpperBoundColumn);
    }
}
}

Просмотреть файл

@ -0,0 +1,47 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
namespace Microsoft.ML.AutoML.CodeGen
{
/// <summary>
/// Partial class supplying the <c>BuildFromOption</c> override for image loading.
/// </summary>
internal partial class LoadImages
{
    /// <summary>
    /// Creates the image-loading transform from the output column, image folder and input
    /// column carried by <paramref name="param"/>.
    /// </summary>
    public override IEstimator<ITransformer> BuildFromOption(MLContext context, LoadImageOption param)
        => context.Transforms.LoadImages(param.OutputColumnName, param.ImageFolder, param.InputColumnName);
}
/// <summary>
/// Partial class supplying the <c>BuildFromOption</c> override for raw image-byte loading.
/// </summary>
internal partial class LoadRawImageBytes
{
    /// <summary>
    /// Creates the raw-image-bytes loading transform from the output column, image folder and
    /// input column carried by <paramref name="param"/>.
    /// </summary>
    public override IEstimator<ITransformer> BuildFromOption(MLContext context, LoadImageOption param)
        => context.Transforms.LoadRawImageBytes(param.OutputColumnName, param.ImageFolder, param.InputColumnName);
}
/// <summary>
/// Partial class supplying the <c>BuildFromOption</c> override for image resizing.
/// </summary>
internal partial class ResizeImages
{
    /// <summary>
    /// Creates the image-resizing transform from the columns, target width/height, resizing
    /// mode and crop anchor carried by <paramref name="param"/>.
    /// </summary>
    public override IEstimator<ITransformer> BuildFromOption(MLContext context, ResizeImageOption param)
        => context.Transforms.ResizeImages(
            param.OutputColumnName,
            param.ImageWidth,
            param.ImageHeight,
            param.InputColumnName,
            param.Resizing,
            param.CropAnchor);
}
/// <summary>
/// Partial class supplying the <c>BuildFromOption</c> override for pixel extraction.
/// </summary>
internal partial class ExtractPixels
{
    /// <summary>
    /// Creates the pixel-extraction transform from the columns, colour selection, extraction
    /// order and float-output flag carried by <paramref name="param"/>.
    /// </summary>
    public override IEstimator<ITransformer> BuildFromOption(MLContext context, ExtractPixelsOption param)
        => context.Transforms.ExtractPixels(
            param.OutputColumnName,
            param.InputColumnName,
            param.ColorsToExtract,
            param.OrderOfExtraction,
            outputAsFloatArray: param.OutputAsFloatArray);
}
/// <summary>
/// Partial class supplying the <c>BuildFromOption</c> override for the multiclass image
/// classification trainer.
/// </summary>
internal partial class ImageClassificationMulti
{
    /// <summary>
    /// Creates the image-classification trainer for the label, feature and score columns
    /// named in <paramref name="param"/>.
    /// </summary>
    public override IEstimator<ITransformer> BuildFromOption(MLContext context, ImageClassificationOption param)
    {
        var trainers = context.MulticlassClassification.Trainers;
        return trainers.ImageClassification(param.LabelColumnName, param.FeatureColumnName, param.ScoreColumnName);
    }
}
}

Просмотреть файл

@ -0,0 +1,81 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using Microsoft.ML.Trainers;
namespace Microsoft.ML.AutoML.CodeGen
{
/// <summary>
/// Partial class supplying the <c>BuildFromOption</c> override for the L-BFGS maximum-entropy
/// multiclass trainer.
/// </summary>
internal partial class LbfgsMaximumEntropyMulti
{
    /// <summary>
    /// Builds an L-BFGS maximum-entropy multiclass trainer configured from <paramref name="param"/>.
    /// </summary>
    public override IEstimator<ITransformer> BuildFromOption(MLContext context, LbfgsOption param)
    {
        var trainerOptions = new LbfgsMaximumEntropyMulticlassTrainer.Options();
        trainerOptions.L1Regularization = param.L1Regularization;
        trainerOptions.L2Regularization = param.L2Regularization;
        trainerOptions.LabelColumnName = param.LabelColumnName;
        trainerOptions.FeatureColumnName = param.FeatureColumnName;
        trainerOptions.ExampleWeightColumnName = param.ExampleWeightColumnName;
        // Thread count comes from the MLNET_MAX_THREAD environment variable, not from param.
        trainerOptions.NumberOfThreads = AutoMlUtils.GetNumberOfThreadFromEnvrionment();

        return context.MulticlassClassification.Trainers.LbfgsMaximumEntropy(trainerOptions);
    }
}
/// <summary>
/// Partial class supplying the <c>BuildFromOption</c> override for the L-BFGS Poisson
/// regression trainer.
/// </summary>
internal partial class LbfgsPoissonRegressionRegression
{
    /// <summary>
    /// Builds an L-BFGS Poisson regression trainer configured from <paramref name="param"/>.
    /// </summary>
    public override IEstimator<ITransformer> BuildFromOption(MLContext context, LbfgsOption param)
    {
        var trainerOptions = new LbfgsPoissonRegressionTrainer.Options();
        trainerOptions.L1Regularization = param.L1Regularization;
        trainerOptions.L2Regularization = param.L2Regularization;
        trainerOptions.LabelColumnName = param.LabelColumnName;
        trainerOptions.FeatureColumnName = param.FeatureColumnName;
        trainerOptions.ExampleWeightColumnName = param.ExampleWeightColumnName;
        // Thread count comes from the MLNET_MAX_THREAD environment variable, not from param.
        trainerOptions.NumberOfThreads = AutoMlUtils.GetNumberOfThreadFromEnvrionment();

        return context.Regression.Trainers.LbfgsPoissonRegression(trainerOptions);
    }
}
/// <summary>
/// Partial class supplying the <c>BuildFromOption</c> override for the L-BFGS logistic
/// regression binary trainer.
/// </summary>
internal partial class LbfgsLogisticRegressionBinary
{
    /// <summary>
    /// Builds an L-BFGS logistic-regression binary trainer configured from <paramref name="param"/>.
    /// </summary>
    public override IEstimator<ITransformer> BuildFromOption(MLContext context, LbfgsOption param)
    {
        var trainerOptions = new LbfgsLogisticRegressionBinaryTrainer.Options();
        trainerOptions.L1Regularization = param.L1Regularization;
        trainerOptions.L2Regularization = param.L2Regularization;
        trainerOptions.LabelColumnName = param.LabelColumnName;
        trainerOptions.FeatureColumnName = param.FeatureColumnName;
        trainerOptions.ExampleWeightColumnName = param.ExampleWeightColumnName;
        // Thread count comes from the MLNET_MAX_THREAD environment variable, not from param.
        trainerOptions.NumberOfThreads = AutoMlUtils.GetNumberOfThreadFromEnvrionment();

        return context.BinaryClassification.Trainers.LbfgsLogisticRegression(trainerOptions);
    }
}
/// <summary>
/// Partial class supplying the <c>BuildFromOption</c> override for L-BFGS logistic regression
/// used as a one-versus-all multiclass trainer.
/// </summary>
internal partial class LbfgsLogisticRegressionOva
{
    /// <summary>
    /// Builds a binary L-BFGS logistic-regression trainer from <paramref name="param"/> and
    /// wraps it in a one-versus-all multiclass trainer on the same label column.
    /// </summary>
    public override IEstimator<ITransformer> BuildFromOption(MLContext context, LbfgsOption param)
    {
        var trainerOptions = new LbfgsLogisticRegressionBinaryTrainer.Options();
        trainerOptions.L1Regularization = param.L1Regularization;
        trainerOptions.L2Regularization = param.L2Regularization;
        trainerOptions.LabelColumnName = param.LabelColumnName;
        trainerOptions.FeatureColumnName = param.FeatureColumnName;
        trainerOptions.ExampleWeightColumnName = param.ExampleWeightColumnName;
        // Thread count comes from the MLNET_MAX_THREAD environment variable, not from param.
        trainerOptions.NumberOfThreads = AutoMlUtils.GetNumberOfThreadFromEnvrionment();

        var innerTrainer = context.BinaryClassification.Trainers.LbfgsLogisticRegression(trainerOptions);
        var multiclassTrainers = context.MulticlassClassification.Trainers;
        return multiclassTrainers.OneVersusAll(innerTrainer, param.LabelColumnName);
    }
}
}

Просмотреть файл

@ -0,0 +1,92 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using Microsoft.ML.Trainers.LightGbm;
namespace Microsoft.ML.AutoML.CodeGen
{
/// <summary>
/// Sweepable estimator that creates a <see cref="LightGbmMulticlassTrainer"/> from an <see cref="LgbmOption"/>.
/// </summary>
internal partial class LightGbmMulti
{
    /// <summary>
    /// Copies the tree, column and booster settings from <paramref name="param"/> onto
    /// <see cref="LightGbmMulticlassTrainer.Options"/> and asks the multiclass catalog for the trainer.
    /// </summary>
    /// <param name="context">The <see cref="MLContext"/> whose multiclass catalog creates the trainer.</param>
    /// <param name="param">Option values copied onto the trainer options.</param>
    /// <returns>The estimator returned by <c>context.MulticlassClassification.Trainers.LightGbm</c>.</returns>
    public override IEstimator<ITransformer> BuildFromOption(MLContext context, LgbmOption param)
    {
        var trainerOptions = new LightGbmMulticlassTrainer.Options();
        trainerOptions.NumberOfLeaves = param.NumberOfLeaves;

        // The option's NumberOfTrees is what the trainer calls NumberOfIterations.
        trainerOptions.NumberOfIterations = param.NumberOfTrees;
        trainerOptions.MinimumExampleCountPerLeaf = param.MinimumExampleCountPerLeaf;
        trainerOptions.LearningRate = param.LearningRate;

        // The thread count is not part of the option; it is supplied by AutoMlUtils.
        trainerOptions.NumberOfThreads = AutoMlUtils.GetNumberOfThreadFromEnvrionment();
        trainerOptions.LabelColumnName = param.LabelColumnName;
        trainerOptions.FeatureColumnName = param.FeatureColumnName;
        trainerOptions.ExampleWeightColumnName = param.ExampleWeightColumnName;

        // Sampling fractions and regularization live on the booster options rather than on the trainer options.
        var boosterOptions = new GradientBooster.Options();
        boosterOptions.SubsampleFraction = param.SubsampleFraction;
        boosterOptions.FeatureFraction = param.FeatureFraction;
        boosterOptions.L1Regularization = param.L1Regularization;
        boosterOptions.L2Regularization = param.L2Regularization;
        trainerOptions.Booster = boosterOptions;

        trainerOptions.MaximumBinCountPerFeature = param.MaximumBinCountPerFeature;

        return context.MulticlassClassification.Trainers.LightGbm(trainerOptions);
    }
}
/// <summary>
/// Sweepable estimator that creates a <see cref="LightGbmBinaryTrainer"/> from an <see cref="LgbmOption"/>.
/// </summary>
internal partial class LightGbmBinary
{
    /// <summary>
    /// Copies the tree, column and booster settings from <paramref name="param"/> onto
    /// <see cref="LightGbmBinaryTrainer.Options"/> and asks the binary classification catalog for the trainer.
    /// </summary>
    /// <param name="context">The <see cref="MLContext"/> whose binary classification catalog creates the trainer.</param>
    /// <param name="param">Option values copied onto the trainer options.</param>
    /// <returns>The estimator returned by <c>context.BinaryClassification.Trainers.LightGbm</c>.</returns>
    public override IEstimator<ITransformer> BuildFromOption(MLContext context, LgbmOption param)
    {
        var trainerOptions = new LightGbmBinaryTrainer.Options();
        trainerOptions.NumberOfLeaves = param.NumberOfLeaves;

        // The option's NumberOfTrees is what the trainer calls NumberOfIterations.
        trainerOptions.NumberOfIterations = param.NumberOfTrees;
        trainerOptions.MinimumExampleCountPerLeaf = param.MinimumExampleCountPerLeaf;
        trainerOptions.LearningRate = param.LearningRate;

        // The thread count is not part of the option; it is supplied by AutoMlUtils.
        trainerOptions.NumberOfThreads = AutoMlUtils.GetNumberOfThreadFromEnvrionment();
        trainerOptions.LabelColumnName = param.LabelColumnName;
        trainerOptions.FeatureColumnName = param.FeatureColumnName;
        trainerOptions.ExampleWeightColumnName = param.ExampleWeightColumnName;

        // Sampling fractions and regularization live on the booster options rather than on the trainer options.
        var boosterOptions = new GradientBooster.Options();
        boosterOptions.SubsampleFraction = param.SubsampleFraction;
        boosterOptions.FeatureFraction = param.FeatureFraction;
        boosterOptions.L1Regularization = param.L1Regularization;
        boosterOptions.L2Regularization = param.L2Regularization;
        trainerOptions.Booster = boosterOptions;

        trainerOptions.MaximumBinCountPerFeature = param.MaximumBinCountPerFeature;

        return context.BinaryClassification.Trainers.LightGbm(trainerOptions);
    }
}
/// <summary>
/// Sweepable estimator that creates a <see cref="LightGbmRegressionTrainer"/> from an <see cref="LgbmOption"/>.
/// </summary>
internal partial class LightGbmRegression
{
    /// <summary>
    /// Copies the tree, column and booster settings from <paramref name="param"/> onto
    /// <see cref="LightGbmRegressionTrainer.Options"/> and asks the regression catalog for the trainer.
    /// </summary>
    /// <param name="context">The <see cref="MLContext"/> whose regression catalog creates the trainer.</param>
    /// <param name="param">Option values copied onto the trainer options.</param>
    /// <returns>The estimator returned by <c>context.Regression.Trainers.LightGbm</c>.</returns>
    public override IEstimator<ITransformer> BuildFromOption(MLContext context, LgbmOption param)
    {
        var trainerOptions = new LightGbmRegressionTrainer.Options();
        trainerOptions.NumberOfLeaves = param.NumberOfLeaves;

        // The option's NumberOfTrees is what the trainer calls NumberOfIterations.
        trainerOptions.NumberOfIterations = param.NumberOfTrees;
        trainerOptions.MinimumExampleCountPerLeaf = param.MinimumExampleCountPerLeaf;
        trainerOptions.LearningRate = param.LearningRate;

        // The thread count is not part of the option; it is supplied by AutoMlUtils.
        trainerOptions.NumberOfThreads = AutoMlUtils.GetNumberOfThreadFromEnvrionment();
        trainerOptions.LabelColumnName = param.LabelColumnName;
        trainerOptions.FeatureColumnName = param.FeatureColumnName;
        trainerOptions.ExampleWeightColumnName = param.ExampleWeightColumnName;

        // Sampling fractions and regularization live on the booster options rather than on the trainer options.
        var boosterOptions = new GradientBooster.Options();
        boosterOptions.SubsampleFraction = param.SubsampleFraction;
        boosterOptions.FeatureFraction = param.FeatureFraction;
        boosterOptions.L1Regularization = param.L1Regularization;
        boosterOptions.L2Regularization = param.L2Regularization;
        trainerOptions.Booster = boosterOptions;

        trainerOptions.MaximumBinCountPerFeature = param.MaximumBinCountPerFeature;

        return context.Regression.Trainers.LightGbm(trainerOptions);
    }
}
}

Просмотреть файл

@ -0,0 +1,22 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
namespace Microsoft.ML.AutoML.CodeGen
{
/// <summary>
/// Sweepable estimator that creates a value-to-key mapping transform from a <see cref="MapValueToKeyOption"/>.
/// </summary>
internal partial class MapValueToKey
{
    /// <summary>
    /// Calls <c>context.Transforms.Conversion.MapValueToKey</c> with the output column name first and the input column name second.
    /// </summary>
    /// <param name="context">The <see cref="MLContext"/> whose conversion catalog creates the transform.</param>
    /// <param name="param">Supplies the output and input column names.</param>
    /// <returns>The estimator returned by the conversion catalog.</returns>
    public override IEstimator<ITransformer> BuildFromOption(MLContext context, MapValueToKeyOption param)
    {
        var outputColumn = param.OutputColumnName;
        var inputColumn = param.InputColumnName;

        return context.Transforms.Conversion.MapValueToKey(outputColumn, inputColumn);
    }
}
/// <summary>
/// Sweepable estimator that creates a key-to-value mapping transform from a <see cref="MapKeyToValueOption"/>.
/// </summary>
internal partial class MapKeyToValue
{
    /// <summary>
    /// Calls <c>context.Transforms.Conversion.MapKeyToValue</c> with the output column name first and the input column name second.
    /// </summary>
    /// <param name="context">The <see cref="MLContext"/> whose conversion catalog creates the transform.</param>
    /// <param name="param">Supplies the output and input column names.</param>
    /// <returns>The estimator returned by the conversion catalog.</returns>
    public override IEstimator<ITransformer> BuildFromOption(MLContext context, MapKeyToValueOption param)
    {
        var outputColumn = param.OutputColumnName;
        var inputColumn = param.InputColumnName;

        return context.Transforms.Conversion.MapKeyToValue(outputColumn, inputColumn);
    }
}
}

Просмотреть файл

@ -0,0 +1,14 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
namespace Microsoft.ML.AutoML.CodeGen
{
/// <summary>
/// Sweepable estimator that creates a matrix factorization trainer from a <see cref="MatrixFactorizationOption"/>.
/// </summary>
internal partial class MatrixFactorization
{
    /// <summary>
    /// Forwards the label, matrix column index and matrix row index column names, followed by the approximation rank,
    /// learning rate and iteration count, to <c>context.Recommendation().Trainers.MatrixFactorization</c>.
    /// </summary>
    /// <param name="context">The <see cref="MLContext"/> whose recommendation catalog creates the trainer.</param>
    /// <param name="param">Option values forwarded positionally to the trainer factory.</param>
    /// <returns>The estimator returned by the recommendation catalog.</returns>
    public override IEstimator<ITransformer> BuildFromOption(MLContext context, MatrixFactorizationOption param)
    {
        var recommendationTrainers = context.Recommendation().Trainers;

        return recommendationTrainers.MatrixFactorization(
            param.LabelColumnName,
            param.MatrixColumnIndexColumnName,
            param.MatrixRowIndexColumnName,
            param.ApproximationRank,
            param.LearningRate,
            param.NumberOfIterations);
    }
}
}

Просмотреть файл

@ -0,0 +1,16 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
namespace Microsoft.ML.AutoML.CodeGen
{
/// <summary>
/// Sweepable estimator that creates a min-max normalization transform from a <see cref="NormalizeMinMaxOption"/>.
/// </summary>
internal partial class NormalizeMinMax
{
    /// <summary>
    /// Pairs the input and output column names from <paramref name="param"/> and passes the pairs to
    /// <c>context.Transforms.NormalizeMinMax</c>.
    /// </summary>
    /// <param name="context">The <see cref="MLContext"/> whose transforms catalog creates the normalizer.</param>
    /// <param name="param">Supplies the input and output column names.</param>
    /// <returns>The estimator returned by the transforms catalog.</returns>
    public override IEstimator<ITransformer> BuildFromOption(MLContext context, NormalizeMinMaxOption param)
    {
        // Input names go first, output names second: the same argument order every other caller of
        // CreateInputOutputColumnPairsFromStrings in this change uses (OneHotEncoding, ReplaceMissingValues,
        // ConvertType). Passing them the other way round swaps the source and destination columns.
        var inputOutputPairs = AutoMlUtils.CreateInputOutputColumnPairsFromStrings(param.InputColumnNames, param.OutputColumnNames);
        return context.Transforms.NormalizeMinMax(inputOutputPairs);
    }
}
}

Просмотреть файл

@ -0,0 +1,24 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
namespace Microsoft.ML.AutoML.CodeGen
{
/// <summary>
/// Sweepable estimator that creates a one-hot encoding transform from a <see cref="OneHotOption"/>.
/// </summary>
internal partial class OneHotEncoding
{
    /// <summary>
    /// Pairs the input and output column names from <paramref name="param"/> and passes the pairs to
    /// <c>context.Transforms.Categorical.OneHotEncoding</c>.
    /// </summary>
    /// <param name="context">The <see cref="MLContext"/> whose categorical catalog creates the transform.</param>
    /// <param name="param">Supplies the input and output column names.</param>
    /// <returns>The estimator returned by the categorical catalog.</returns>
    public override IEstimator<ITransformer> BuildFromOption(MLContext context, OneHotOption param)
    {
        var columnPairs = AutoMlUtils.CreateInputOutputColumnPairsFromStrings(
            param.InputColumnNames,
            param.OutputColumnNames);

        return context.Transforms.Categorical.OneHotEncoding(columnPairs);
    }
}
/// <summary>
/// Sweepable estimator that creates a one-hot hash encoding transform from a <see cref="OneHotOption"/>.
/// </summary>
internal partial class OneHotHashEncoding
{
    /// <summary>
    /// Pairs the input and output column names from <paramref name="param"/> and passes the pairs to
    /// <c>context.Transforms.Categorical.OneHotHashEncoding</c>.
    /// </summary>
    /// <param name="context">The <see cref="MLContext"/> whose categorical catalog creates the transform.</param>
    /// <param name="param">Supplies the input and output column names.</param>
    /// <returns>The estimator returned by the categorical catalog.</returns>
    public override IEstimator<ITransformer> BuildFromOption(MLContext context, OneHotOption param)
    {
        var columnPairs = AutoMlUtils.CreateInputOutputColumnPairsFromStrings(
            param.InputColumnNames,
            param.OutputColumnNames);

        return context.Transforms.Categorical.OneHotHashEncoding(columnPairs);
    }
}
}

Просмотреть файл

@ -0,0 +1,15 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
namespace Microsoft.ML.AutoML.CodeGen
{
/// <summary>
/// Sweepable estimator that creates a missing-value replacement transform from a <see cref="ReplaceMissingValueOption"/>.
/// </summary>
internal partial class ReplaceMissingValues
{
    /// <summary>
    /// Pairs the input and output column names from <paramref name="param"/> and passes the pairs to
    /// <c>context.Transforms.ReplaceMissingValues</c>.
    /// </summary>
    /// <param name="context">The <see cref="MLContext"/> whose transforms catalog creates the transform.</param>
    /// <param name="param">Supplies the input and output column names.</param>
    /// <returns>The estimator returned by the transforms catalog.</returns>
    public override IEstimator<ITransformer> BuildFromOption(MLContext context, ReplaceMissingValueOption param)
    {
        var columnPairs = AutoMlUtils.CreateInputOutputColumnPairsFromStrings(
            param.InputColumnNames,
            param.OutputColumnNames);

        return context.Transforms.ReplaceMissingValues(columnPairs);
    }
}
}

Просмотреть файл

@ -0,0 +1,81 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using Microsoft.ML.Trainers;
namespace Microsoft.ML.AutoML.CodeGen
{
/// <summary>
/// Sweepable estimator that creates a <see cref="SdcaRegressionTrainer"/> from an <see cref="SdcaOption"/>.
/// </summary>
internal partial class SdcaRegression
{
    /// <summary>
    /// Copies the column names and regularization weights from <paramref name="param"/> onto
    /// <see cref="SdcaRegressionTrainer.Options"/> and asks the regression catalog for the trainer.
    /// </summary>
    /// <param name="context">The <see cref="MLContext"/> whose regression catalog creates the trainer.</param>
    /// <param name="param">Option values copied onto the trainer options.</param>
    /// <returns>The estimator returned by <c>context.Regression.Trainers.Sdca</c>.</returns>
    public override IEstimator<ITransformer> BuildFromOption(MLContext context, SdcaOption param)
    {
        var trainerOptions = new SdcaRegressionTrainer.Options();
        trainerOptions.LabelColumnName = param.LabelColumnName;
        trainerOptions.FeatureColumnName = param.FeatureColumnName;
        trainerOptions.ExampleWeightColumnName = param.ExampleWeightColumnName;
        trainerOptions.L1Regularization = param.L1Regularization;
        trainerOptions.L2Regularization = param.L2Regularization;

        // The thread count is not part of the option; it is supplied by AutoMlUtils.
        trainerOptions.NumberOfThreads = AutoMlUtils.GetNumberOfThreadFromEnvrionment();

        return context.Regression.Trainers.Sdca(trainerOptions);
    }
}
/// <summary>
/// Sweepable estimator that creates a <see cref="SdcaMaximumEntropyMulticlassTrainer"/> from an <see cref="SdcaOption"/>.
/// </summary>
internal partial class SdcaMaximumEntropyMulti
{
    /// <summary>
    /// Copies the column names and regularization weights from <paramref name="param"/> onto
    /// <see cref="SdcaMaximumEntropyMulticlassTrainer.Options"/> and asks the multiclass catalog for the trainer.
    /// </summary>
    /// <param name="context">The <see cref="MLContext"/> whose multiclass catalog creates the trainer.</param>
    /// <param name="param">Option values copied onto the trainer options.</param>
    /// <returns>The estimator returned by <c>context.MulticlassClassification.Trainers.SdcaMaximumEntropy</c>.</returns>
    public override IEstimator<ITransformer> BuildFromOption(MLContext context, SdcaOption param)
    {
        var trainerOptions = new SdcaMaximumEntropyMulticlassTrainer.Options();
        trainerOptions.LabelColumnName = param.LabelColumnName;
        trainerOptions.FeatureColumnName = param.FeatureColumnName;
        trainerOptions.ExampleWeightColumnName = param.ExampleWeightColumnName;
        trainerOptions.L1Regularization = param.L1Regularization;
        trainerOptions.L2Regularization = param.L2Regularization;

        // The thread count is not part of the option; it is supplied by AutoMlUtils.
        trainerOptions.NumberOfThreads = AutoMlUtils.GetNumberOfThreadFromEnvrionment();

        return context.MulticlassClassification.Trainers.SdcaMaximumEntropy(trainerOptions);
    }
}
/// <summary>
/// Sweepable estimator that creates a <see cref="SdcaLogisticRegressionBinaryTrainer"/> from an <see cref="SdcaOption"/>.
/// </summary>
internal partial class SdcaLogisticRegressionBinary
{
    /// <summary>
    /// Copies the column names and regularization weights from <paramref name="param"/> onto
    /// <see cref="SdcaLogisticRegressionBinaryTrainer.Options"/> and asks the binary classification catalog for the trainer.
    /// </summary>
    /// <param name="context">The <see cref="MLContext"/> whose binary classification catalog creates the trainer.</param>
    /// <param name="param">Option values copied onto the trainer options.</param>
    /// <returns>The estimator returned by <c>context.BinaryClassification.Trainers.SdcaLogisticRegression</c>.</returns>
    public override IEstimator<ITransformer> BuildFromOption(MLContext context, SdcaOption param)
    {
        var trainerOptions = new SdcaLogisticRegressionBinaryTrainer.Options();
        trainerOptions.LabelColumnName = param.LabelColumnName;
        trainerOptions.FeatureColumnName = param.FeatureColumnName;
        trainerOptions.ExampleWeightColumnName = param.ExampleWeightColumnName;
        trainerOptions.L1Regularization = param.L1Regularization;
        trainerOptions.L2Regularization = param.L2Regularization;

        // The thread count is not part of the option; it is supplied by AutoMlUtils.
        trainerOptions.NumberOfThreads = AutoMlUtils.GetNumberOfThreadFromEnvrionment();

        return context.BinaryClassification.Trainers.SdcaLogisticRegression(trainerOptions);
    }
}
/// <summary>
/// Sweepable estimator that wraps a <see cref="SdcaLogisticRegressionBinaryTrainer"/> in a one-versus-all multiclass trainer.
/// </summary>
internal partial class SdcaLogisticRegressionOva
{
    /// <summary>
    /// Builds the binary trainer from <paramref name="param"/> and passes it, together with the label column name,
    /// to <c>context.MulticlassClassification.Trainers.OneVersusAll</c>.
    /// </summary>
    /// <param name="context">The <see cref="MLContext"/> whose catalogs create both trainers.</param>
    /// <param name="param">Option values copied onto the binary trainer options.</param>
    /// <returns>The one-versus-all estimator built around the binary trainer.</returns>
    public override IEstimator<ITransformer> BuildFromOption(MLContext context, SdcaOption param)
    {
        var binaryOptions = new SdcaLogisticRegressionBinaryTrainer.Options();
        binaryOptions.LabelColumnName = param.LabelColumnName;
        binaryOptions.FeatureColumnName = param.FeatureColumnName;
        binaryOptions.ExampleWeightColumnName = param.ExampleWeightColumnName;
        binaryOptions.L1Regularization = param.L1Regularization;
        binaryOptions.L2Regularization = param.L2Regularization;

        // The thread count is not part of the option; it is supplied by AutoMlUtils.
        binaryOptions.NumberOfThreads = AutoMlUtils.GetNumberOfThreadFromEnvrionment();

        var binaryLearner = context.BinaryClassification.Trainers.SdcaLogisticRegression(binaryOptions);
        return context.MulticlassClassification.Trainers.OneVersusAll(
            binaryEstimator: binaryLearner,
            labelColumnName: param.LabelColumnName);
    }
}
}

Просмотреть файл

@ -0,0 +1,15 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
namespace Microsoft.ML.AutoML.CodeGen
{
/// <summary>
/// Sweepable estimator that creates a type conversion transform from a <see cref="ConvertTypeOption"/>.
/// </summary>
internal partial class ConvertType
{
    /// <summary>
    /// Pairs the input and output column names from <paramref name="param"/> and passes the pairs to
    /// <c>context.Transforms.Conversion.ConvertType</c>.
    /// </summary>
    /// <param name="context">The <see cref="MLContext"/> whose conversion catalog creates the transform.</param>
    /// <param name="param">Supplies the input and output column names.</param>
    /// <returns>The estimator returned by the conversion catalog.</returns>
    public override IEstimator<ITransformer> BuildFromOption(MLContext context, ConvertTypeOption param)
    {
        var columnPairs = AutoMlUtils.CreateInputOutputColumnPairsFromStrings(
            param.InputColumnNames,
            param.OutputColumnNames);

        return context.Transforms.Conversion.ConvertType(columnPairs);
    }
}
}

Просмотреть файл

@ -6,9 +6,11 @@ using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Text.Json.Serialization;
namespace Microsoft.ML.AutoML
{
[JsonConverter(typeof(MultiModelPipelineConverter))]
internal class MultiModelPipeline
{
private static readonly StringEntity _nilStringEntity = new StringEntity("Nil");

Просмотреть файл

@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
using System.Text.Json.Serialization;
using Microsoft.ML.AutoML.CodeGen;
using Microsoft.ML.SearchSpace;
@ -12,6 +13,7 @@ namespace Microsoft.ML.AutoML
/// <summary>
/// Estimator with search space.
/// </summary>
[JsonConverter(typeof(SweepableEstimatorConverter))]
internal class SweepableEstimator : Estimator
{
private readonly Func<MLContext, Parameter, IEstimator<ITransformer>> _factory;
@ -50,11 +52,6 @@ namespace Microsoft.ML.AutoML
internal virtual IEnumerable<string> NugetDependencies { get; }
internal virtual string FunctionName { get; }
internal virtual string ToDisplayString(Parameter param)
{
throw new NotImplementedException();
}
}
internal abstract class SweepableEstimator<TOption> : SweepableEstimator
@ -75,12 +72,5 @@ namespace Microsoft.ML.AutoML
{
return this.BuildFromOption(context, param.AsType<TOption>());
}
internal abstract string ToDisplayString(TOption param);
internal override string ToDisplayString(Parameter param)
{
return this.ToDisplayString(param.AsType<TOption>());
}
}
}

Просмотреть файл

@ -5,11 +5,13 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.Json.Serialization;
using Microsoft.ML.Data;
using Microsoft.ML.SearchSpace;
namespace Microsoft.ML.AutoML
{
[JsonConverter(typeof(SweepableEstimatorPipelineConverter))]
internal class SweepableEstimatorPipeline
{
private readonly List<SweepableEstimator> _estimators;

Просмотреть файл

@ -50,7 +50,7 @@ namespace Microsoft.ML.SearchSpace
}
/// <summary>
/// <see cref="Parameter"/> is used to save sweeping result from <see cref="ITuner.Propose(SearchSpace)"/> and is used to restore mlnet pipeline from sweepable pipline.
/// <see cref="Parameter"/> is used to save the sweeping result from a tuner and to restore an ML.NET pipeline from a sweepable pipeline.
/// </summary>
[JsonConverter(typeof(ParameterConverter))]
public sealed class Parameter : IDictionary<string, Parameter>

Просмотреть файл

@ -0,0 +1,61 @@
{
"0": {
"FeatureSpaceDim": 0,
"Default": [],
"Step": []
},
"1": {
"FeatureSpaceDim": 0,
"Default": [],
"Step": []
},
"2": {
"FeatureSpaceDim": 0,
"Default": [],
"Step": []
},
"3": {
"FeatureSpaceDim": 9,
"Default": [
1,
0,
1,
1,
0.71428571428571408,
0,
0,
0,
1
],
"Step": [
null,
null,
null,
null,
null,
null,
null,
null,
null
]
},
"4": {
"FeatureSpaceDim": 6,
"Default": [
1,
0.89689626841273451,
0.71428571428571408,
0.55365468248122718,
0,
0
],
"Step": [
null,
null,
null,
null,
null,
null
]
}
}

Просмотреть файл

@ -0,0 +1,28 @@
{
"0": {},
"1": {},
"2": {},
"3": {
"NumberOfLeaves": 4,
"MinimumExampleCountPerLeaf": 20,
"LearningRate": 1,
"NumberOfTrees": 4,
"SubsampleFraction": 1,
"MaximumBinCountPerFeature": 255,
"FeatureFraction": 1,
"L1Regularization": 0.0000000002,
"L2Regularization": 1,
"LabelColumnName": "Label",
"FeatureColumnName": "Feature"
},
"4": {
"NumberOfLeaves": 4,
"MinimumExampleCountPerLeaf": 20,
"NumberOfTrees": 4,
"MaximumBinCountPerFeature": 255,
"FeatureFraction": 1,
"LearningRate": 0.10,
"LabelColumnName": "Label",
"FeatureColumnName": "Feature"
}
}

Просмотреть файл

@ -0,0 +1,66 @@
{
"schema": "e0 * (e1 + e2 + e3 + e4 + e5)",
"estimators": {
"e0": {
"estimatorType": "Unknown",
"parameter": {}
},
"e1": {
"estimatorType": "FastTreeBinary",
"parameter": {
"NumberOfLeaves": 4,
"MinimumExampleCountPerLeaf": 20,
"NumberOfTrees": 4,
"MaximumBinCountPerFeature": 256,
"FeatureFraction": 1,
"LearningRate": 0.1,
"LabelColumnName": "Label",
"FeatureColumnName": "Features"
}
},
"e2": {
"estimatorType": "FastForestBinary",
"parameter": {
"NumberOfTrees": 4,
"NumberOfLeaves": 4,
"FeatureFraction": 1,
"LabelColumnName": "Label",
"FeatureColumnName": "Features"
}
},
"e3": {
"estimatorType": "LightGbmBinary",
"parameter": {
"NumberOfLeaves": 4,
"MinimumExampleCountPerLeaf": 20,
"LearningRate": 1,
"NumberOfTrees": 4,
"SubsampleFraction": 1,
"MaximumBinCountPerFeature": 256,
"FeatureFraction": 1,
"L1Regularization": 0.0000000002,
"L2Regularization": 1,
"LabelColumnName": "Label",
"FeatureColumnName": "Features"
}
},
"e4": {
"estimatorType": "LbfgsLogisticRegressionBinary",
"parameter": {
"L1Regularization": 1,
"L2Regularization": 1,
"LabelColumnName": "Label",
"FeatureColumnName": "Features"
}
},
"e5": {
"estimatorType": "SdcaLogisticRegressionBinary",
"parameter": {
"L1Regularization": 1,
"L2Regularization": 0.1,
"LabelColumnName": "Label",
"FeatureColumnName": "Features"
}
}
}
}

Просмотреть файл

@ -0,0 +1,84 @@
{
"schema": "e0 * (e1 + e2 + e3 + e4 + e5 + e6 + e7)",
"estimators": {
"e0": {
"estimatorType": "Unknown",
"parameter": {}
},
"e1": {
"estimatorType": "FastTreeOva",
"parameter": {
"NumberOfLeaves": 4,
"MinimumExampleCountPerLeaf": 20,
"NumberOfTrees": 4,
"MaximumBinCountPerFeature": 256,
"FeatureFraction": 1,
"LearningRate": 0.1,
"LabelColumnName": "Label",
"FeatureColumnName": "Features"
}
},
"e2": {
"estimatorType": "FastForestOva",
"parameter": {
"NumberOfTrees": 4,
"NumberOfLeaves": 4,
"FeatureFraction": 1,
"LabelColumnName": "Label",
"FeatureColumnName": "Features"
}
},
"e3": {
"estimatorType": "LightGbmMulti",
"parameter": {
"NumberOfLeaves": 4,
"MinimumExampleCountPerLeaf": 20,
"LearningRate": 1,
"NumberOfTrees": 4,
"SubsampleFraction": 1,
"MaximumBinCountPerFeature": 256,
"FeatureFraction": 1,
"L1Regularization": 0.0000000002,
"L2Regularization": 1,
"LabelColumnName": "Label",
"FeatureColumnName": "Features"
}
},
"e4": {
"estimatorType": "LbfgsLogisticRegressionOva",
"parameter": {
"L1Regularization": 1,
"L2Regularization": 1,
"LabelColumnName": "Label",
"FeatureColumnName": "Features"
}
},
"e5": {
"estimatorType": "LbfgsMaximumEntropyMulti",
"parameter": {
"L1Regularization": 1,
"L2Regularization": 1,
"LabelColumnName": "Label",
"FeatureColumnName": "Features"
}
},
"e6": {
"estimatorType": "SdcaMaximumEntropyMulti",
"parameter": {
"L1Regularization": 1,
"L2Regularization": 0.1,
"LabelColumnName": "Label",
"FeatureColumnName": "Features"
}
},
"e7": {
"estimatorType": "SdcaLogisticRegressionOva",
"parameter": {
"L1Regularization": 1,
"L2Regularization": 0.1,
"LabelColumnName": "Label",
"FeatureColumnName": "Features"
}
}
}
}

Просмотреть файл

@ -0,0 +1,84 @@
{
"schema": "e0 * (e1 + e2 + e3 + e4 + e5 + e6 + e7)",
"estimators": {
"e0": {
"estimatorType": "Unknown",
"parameter": {}
},
"e1": {
"estimatorType": "FastTreeOva",
"parameter": {
"NumberOfLeaves": 4,
"MinimumExampleCountPerLeaf": 20,
"NumberOfTrees": 4,
"MaximumBinCountPerFeature": 256,
"FeatureFraction": 1,
"LearningRate": 0.1,
"LabelColumnName": "Label",
"FeatureColumnName": "Features"
}
},
"e2": {
"estimatorType": "FastForestOva",
"parameter": {
"NumberOfTrees": 4,
"NumberOfLeaves": 4,
"FeatureFraction": 1,
"LabelColumnName": "Label",
"FeatureColumnName": "Features"
}
},
"e3": {
"estimatorType": "LightGbmMulti",
"parameter": {
"NumberOfLeaves": 4,
"MinimumExampleCountPerLeaf": 20,
"LearningRate": 1,
"NumberOfTrees": 4,
"SubsampleFraction": 1,
"MaximumBinCountPerFeature": 256,
"FeatureFraction": 1,
"L1Regularization": 0.0000000002,
"L2Regularization": 1,
"LabelColumnName": "Label",
"FeatureColumnName": "Features"
}
},
"e4": {
"estimatorType": "LbfgsLogisticRegressionOva",
"parameter": {
"L1Regularization": 1,
"L2Regularization": 1,
"LabelColumnName": "Label",
"FeatureColumnName": "Features"
}
},
"e5": {
"estimatorType": "LbfgsMaximumEntropyMulti",
"parameter": {
"L1Regularization": 1,
"L2Regularization": 1,
"LabelColumnName": "Label",
"FeatureColumnName": "Features"
}
},
"e6": {
"estimatorType": "SdcaMaximumEntropyMulti",
"parameter": {
"L1Regularization": 1,
"L2Regularization": 0.1,
"LabelColumnName": "Label",
"FeatureColumnName": "Features"
}
},
"e7": {
"estimatorType": "SdcaLogisticRegressionOva",
"parameter": {
"L1Regularization": 1,
"L2Regularization": 0.1,
"LabelColumnName": "Label",
"FeatureColumnName": "Features"
}
}
}
}

Просмотреть файл

@ -0,0 +1,90 @@
{
"schema": "e0 * (e1 + e2 + e3 + e4 + e5 + e6 + e7)",
"estimators": {
"e0": {
"estimatorType": "FastForestBinary",
"parameter": {
"NumberOfTrees": 4,
"NumberOfLeaves": 4,
"FeatureFraction": 1,
"LabelColumnName": "Label",
"FeatureColumnName": "Feature"
}
},
"e1": {
"estimatorType": "FastTreeOva",
"parameter": {
"NumberOfLeaves": 4,
"MinimumExampleCountPerLeaf": 20,
"NumberOfTrees": 4,
"MaximumBinCountPerFeature": 256,
"FeatureFraction": 1,
"LearningRate": 0.1,
"LabelColumnName": "Label",
"FeatureColumnName": "Features"
}
},
"e2": {
"estimatorType": "FastForestOva",
"parameter": {
"NumberOfTrees": 4,
"NumberOfLeaves": 4,
"FeatureFraction": 1,
"LabelColumnName": "Label",
"FeatureColumnName": "Features"
}
},
"e3": {
"estimatorType": "LightGbmMulti",
"parameter": {
"NumberOfLeaves": 4,
"MinimumExampleCountPerLeaf": 20,
"LearningRate": 1,
"NumberOfTrees": 4,
"SubsampleFraction": 1,
"MaximumBinCountPerFeature": 256,
"FeatureFraction": 1,
"L1Regularization": 0.0000000002,
"L2Regularization": 1,
"LabelColumnName": "Label",
"FeatureColumnName": "Features"
}
},
"e4": {
"estimatorType": "LbfgsLogisticRegressionOva",
"parameter": {
"L1Regularization": 1,
"L2Regularization": 1,
"LabelColumnName": "Label",
"FeatureColumnName": "Features"
}
},
"e5": {
"estimatorType": "LbfgsMaximumEntropyMulti",
"parameter": {
"L1Regularization": 1,
"L2Regularization": 1,
"LabelColumnName": "Label",
"FeatureColumnName": "Features"
}
},
"e6": {
"estimatorType": "SdcaMaximumEntropyMulti",
"parameter": {
"L1Regularization": 1,
"L2Regularization": 0.1,
"LabelColumnName": "Label",
"FeatureColumnName": "Features"
}
},
"e7": {
"estimatorType": "SdcaLogisticRegressionOva",
"parameter": {
"L1Regularization": 1,
"L2Regularization": 0.1,
"LabelColumnName": "Label",
"FeatureColumnName": "Features"
}
}
}
}

Просмотреть файл

@ -0,0 +1,88 @@
{
"schema": "e0 * e1 * (e2 + e3 + e4 + e5 + e6 + e7 + e8)",
"estimators": {
"e0": {
"estimatorType": "Unknown",
"parameter": {}
},
"e1": {
"estimatorType": "FeaturizeText",
"parameter": {}
},
"e2": {
"estimatorType": "FastTreeOva",
"parameter": {
"NumberOfLeaves": 4,
"MinimumExampleCountPerLeaf": 20,
"NumberOfTrees": 4,
"MaximumBinCountPerFeature": 256,
"FeatureFraction": 1,
"LearningRate": 0.1,
"LabelColumnName": "Label",
"FeatureColumnName": "Features"
}
},
"e3": {
"estimatorType": "FastForestOva",
"parameter": {
"NumberOfTrees": 4,
"NumberOfLeaves": 4,
"FeatureFraction": 1,
"LabelColumnName": "Label",
"FeatureColumnName": "Features"
}
},
"e4": {
"estimatorType": "LightGbmMulti",
"parameter": {
"NumberOfLeaves": 4,
"MinimumExampleCountPerLeaf": 20,
"LearningRate": 1,
"NumberOfTrees": 4,
"SubsampleFraction": 1,
"MaximumBinCountPerFeature": 256,
"FeatureFraction": 1,
"L1Regularization": 0.0000000002,
"L2Regularization": 1,
"LabelColumnName": "Label",
"FeatureColumnName": "Features"
}
},
"e5": {
"estimatorType": "LbfgsLogisticRegressionOva",
"parameter": {
"L1Regularization": 1,
"L2Regularization": 1,
"LabelColumnName": "Label",
"FeatureColumnName": "Features"
}
},
"e6": {
"estimatorType": "LbfgsMaximumEntropyMulti",
"parameter": {
"L1Regularization": 1,
"L2Regularization": 1,
"LabelColumnName": "Label",
"FeatureColumnName": "Features"
}
},
"e7": {
"estimatorType": "SdcaMaximumEntropyMulti",
"parameter": {
"L1Regularization": 1,
"L2Regularization": 0.1,
"LabelColumnName": "Label",
"FeatureColumnName": "Features"
}
},
"e8": {
"estimatorType": "SdcaLogisticRegressionOva",
"parameter": {
"L1Regularization": 1,
"L2Regularization": 0.1,
"LabelColumnName": "Label",
"FeatureColumnName": "Features"
}
}
}
}

Просмотреть файл

@ -11,17 +11,34 @@ using ApprovalTests.Namers;
using ApprovalTests.Reporters;
using FluentAssertions;
using Microsoft.ML.TestFramework;
using Newtonsoft.Json;
using Xunit;
using Xunit.Abstractions;
using Microsoft.ML.AutoML.CodeGen;
using System.Text.Json;
using System.Text.Json.Serialization;
namespace Microsoft.ML.AutoML.Test
{
public class SweepableEstimatorPipelineTest : BaseTestClass
{
private readonly JsonSerializerOptions _jsonSerializerOptions;
/// <summary>
/// Prepares the JSON serializer options used by the approval tests in this class: indented output plus the
/// string-enum, double-to-decimal and float-to-decimal converters.
/// </summary>
/// <param name="output">xUnit output helper forwarded to the base test class.</param>
public SweepableEstimatorPipelineTest(ITestOutputHelper output)
    : base(output)
{
    var serializerOptions = new JsonSerializerOptions();
    serializerOptions.WriteIndented = true;
    serializerOptions.Converters.Add(new JsonStringEnumConverter());
    serializerOptions.Converters.Add(new DoubleToDecimalConverter());
    serializerOptions.Converters.Add(new FloatToDecimalConverter());
    _jsonSerializerOptions = serializerOptions;

    // When HELIX_CORRELATION_ID is set, tell ApprovalTests to use the assembly location for approved files.
    var helixCorrelationId = Environment.GetEnvironmentVariable("HELIX_CORRELATION_ID");
    if (helixCorrelationId != null)
    {
        Approvals.UseAssemblyLocationForApprovedFiles();
    }
}
[Fact]
@ -51,5 +68,86 @@ namespace Microsoft.ML.AutoML.Test
pipeline.BuildSweepableEstimatorPipeline("e0 * e2").ToString().Should().Be("Concatenate=>ApplyOnnxModel");
pipeline.BuildSweepableEstimatorPipeline("e1 * Nil").ToString().Should().Be("ConvertType");
}
/// <summary>
/// Appending one <see cref="MultiModelPipeline"/> to another should produce a schema in which the estimators of
/// the appended pipeline are numbered after those of the first.
/// </summary>
[Fact]
public void MultiModelPipeline_append_pipeline_test()
{
    var concatenate = new SweepableEstimator(CodeGen.EstimatorType.Concatenate);
    var convertType = new SweepableEstimator(CodeGen.EstimatorType.ConvertType);
    var applyOnnxModel = new SweepableEstimator(CodeGen.EstimatorType.ApplyOnnxModel);
    var lightGbmBinary = new SweepableEstimator(CodeGen.EstimatorType.LightGbmBinary);
    var fastTreeBinary = new SweepableEstimator(CodeGen.EstimatorType.FastTreeBinary);

    var head = new MultiModelPipeline();
    var tail = new MultiModelPipeline();

    head = head.Append(concatenate + convertType * applyOnnxModel);
    tail = tail.Append(concatenate * (applyOnnxModel + lightGbmBinary) + fastTreeBinary);
    head = head.Append(tail);

    var schemaText = head.Schema.ToString();
    schemaText.Should().Be("(e0 + e1 * e2) * (e3 * (e4 + e5) + e6)");
}
/// <summary>
/// The search space of the pipeline built by <c>CreateSweepbaleEstimatorPipeline</c> should have 15 feature-space dimensions.
/// </summary>
[Fact]
public void SweepableEstimatorPipeline_search_space_test()
{
    var searchSpace = CreateSweepbaleEstimatorPipeline().SearchSpace;

    searchSpace.FeatureSpaceDim.Should().Be(15);

    // TODO
    // verify other properties in search space.
}
/// <summary>
/// Checks the pipeline ids exposed by the multi-model pipeline and that single-model pipelines built from the
/// first and third ids render the expected estimator chains.
/// </summary>
[Fact]
public void SweepableEstimatorPipeline_can_be_created_from_MultiModelPipeline()
{
    var multiModel = CreateMultiModelPipeline();
    var pipelineIds = multiModel.PipelineIds;

    pipelineIds.Should().BeEquivalentTo("e0 * e3 * e4", "e1 * e2 * e3 * e4", "e0 * Nil * e4", "e1 * e2 * Nil * e4", "Nil * e3 * e4", "e0 * e3 * e5", "e1 * e2 * e3 * e5", "e0 * Nil * e5", "e1 * e2 * Nil * e5", "Nil * e3 * e5", "Nil * Nil * e4", "Nil * Nil * e5");

    // First id: every stage is present.
    var built = multiModel.BuildSweepableEstimatorPipeline(pipelineIds[0]);
    built.ToString().Should().Be("ReplaceMissingValues=>Concatenate=>LightGbmBinary");

    // Third id: the middle stage is "Nil" and does not appear in the rendered chain.
    built = multiModel.BuildSweepableEstimatorPipeline(pipelineIds[2]);
    built.ToString().Should().Be("ReplaceMissingValues=>LightGbmBinary");
}
/// <summary>
/// Approval test: samples the search space at its default feature vector and compares the serialized parameter
/// with the approved file.
/// </summary>
[Fact]
[UseApprovalSubdirectory("ApprovalTests")]
[UseReporter(typeof(DiffReporter))]
public void SweepableEstimatorPipeline_search_space_init_value_test()
{
    var searchSpace = CreateSweepbaleEstimatorPipeline().SearchSpace;
    var sampled = searchSpace.SampleFromFeatureSpace(searchSpace.Default);
    var json = JsonSerializer.Serialize(sampled, _jsonSerializerOptions);

    Approvals.Verify(json);
}
/// <summary>
/// Builds a single pipeline of five estimators, in order: concatenate, replace-missing-values, one-hot encoding,
/// LightGBM binary and FastTree binary, each created from a default-constructed option.
/// </summary>
/// <returns>The pipeline containing the five estimators.</returns>
private SweepableEstimatorPipeline CreateSweepbaleEstimatorPipeline()
{
    // Array elements are evaluated left to right, so the estimators are created in pipeline order.
    var estimators = new SweepableEstimator[]
    {
        SweepableEstimatorFactory.CreateConcatenate(new ConcatOption()),
        SweepableEstimatorFactory.CreateReplaceMissingValues(new ReplaceMissingValueOption()),
        SweepableEstimatorFactory.CreateOneHotEncoding(new OneHotOption()),
        SweepableEstimatorFactory.CreateLightGbmBinary(new LgbmOption()),
        SweepableEstimatorFactory.CreateFastTreeBinary(new FastTreeOption()),
    };

    return new SweepableEstimatorPipeline(estimators);
}
/// <summary>
/// Builds a three-stage multi-model pipeline: an optional missing-value stage (alone or followed by one-hot
/// encoding), an optional concatenate stage, and a trainer stage choosing between LightGBM and FastTree.
/// </summary>
/// <returns>The assembled <see cref="MultiModelPipeline"/>.</returns>
private MultiModelPipeline CreateMultiModelPipeline()
{
    var concatenate = SweepableEstimatorFactory.CreateConcatenate(new ConcatOption());
    var replaceMissing = SweepableEstimatorFactory.CreateReplaceMissingValues(new ReplaceMissingValueOption());
    var oneHotEncoding = SweepableEstimatorFactory.CreateOneHotEncoding(new OneHotOption());
    var lightGbmBinary = SweepableEstimatorFactory.CreateLightGbmBinary(new LgbmOption());
    var fastTreeBinary = SweepableEstimatorFactory.CreateFastTreeBinary(new FastTreeOption());

    var multiModel = new MultiModelPipeline();
    multiModel = multiModel.AppendOrSkip(replaceMissing + replaceMissing * oneHotEncoding);
    multiModel = multiModel.AppendOrSkip(concatenate);
    multiModel = multiModel.Append(lightGbmBinary + fastTreeBinary);

    return multiModel;
}
}
}

Просмотреть файл

@ -0,0 +1,152 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Text;
using Xunit;
using Microsoft.ML.AutoML.CodeGen;
using FluentAssertions;
using Microsoft.ML.TestFramework;
using Xunit.Abstractions;
using ApprovalTests.Namers;
using ApprovalTests.Reporters;
using System.Text.Json.Serialization;
using System.Text.Json;
using ApprovalTests;
namespace Microsoft.ML.AutoML.Test
{
/// <summary>
/// Tests for the <c>Append</c> extensions that combine ML.NET estimators, sweepable estimators and the built-in
/// estimator sets returned by <c>context.Auto()</c>.
/// </summary>
public class SweepableExtensionTest : BaseTestClass
{
    private readonly JsonSerializerOptions _jsonSerializerOptions;

    /// <summary>
    /// Prepares the JSON serializer options used by the approval tests: indented output, relaxed escaping and the
    /// string-enum, double-to-decimal and float-to-decimal converters.
    /// </summary>
    /// <param name="output">xUnit output helper forwarded to the base test class.</param>
    public SweepableExtensionTest(ITestOutputHelper output)
        : base(output)
    {
        _jsonSerializerOptions = new JsonSerializerOptions()
        {
            WriteIndented = true,
            Encoder = System.Text.Encodings.Web.JavaScriptEncoder.UnsafeRelaxedJsonEscaping,
            Converters =
            {
                new JsonStringEnumConverter(),
                new DoubleToDecimalConverter(),
                new FloatToDecimalConverter(),
            },
        };

        // When HELIX_CORRELATION_ID is set, tell ApprovalTests to use the assembly location for approved files.
        var helixCorrelationId = Environment.GetEnvironmentVariable("HELIX_CORRELATION_ID");
        if (helixCorrelationId != null)
        {
            Approvals.UseAssemblyLocationForApprovedFiles();
        }
    }

    /// <summary>
    /// An ML.NET estimator followed by a sweepable estimator renders as "Unknown" followed by the sweepable one.
    /// </summary>
    [Fact]
    public void CreateSweepableEstimatorPipelineFromIEstimatorTest()
    {
        var mlContext = new MLContext();
        var concatenate = mlContext.Transforms.Concatenate("output", "input");

        var combined = concatenate.Append(SweepableEstimatorFactory.CreateFastForestBinary(new FastForestOption()));

        combined.ToString().Should().Be("Unknown=>FastForestBinary");
    }

    /// <summary>
    /// An ML.NET estimator appended to an existing sweepable pipeline is rendered as a trailing "Unknown" stage.
    /// </summary>
    [Fact]
    public void AppendIEstimatorToSweepabeEstimatorPipelineTest()
    {
        var mlContext = new MLContext();
        var concatenate = mlContext.Transforms.Concatenate("output", "input");

        var combined = concatenate.Append(SweepableEstimatorFactory.CreateFastForestBinary(new FastForestOption()));
        combined = combined.Append(mlContext.Transforms.CopyColumns("output", "input"));

        combined.ToString().Should().Be("Unknown=>FastForestBinary=>Unknown");
    }

    /// <summary>
    /// A sweepable estimator appended to itself renders as the same stage twice.
    /// </summary>
    [Fact]
    public void CreateSweepableEstimatorPipelineFromSweepableEstimatorTest()
    {
        var fastForest = SweepableEstimatorFactory.CreateFastForestBinary(new FastForestOption());

        var combined = fastForest.Append(fastForest);

        combined.ToString().Should().Be("FastForestBinary=>FastForestBinary");
    }

    /// <summary>
    /// A sweepable estimator followed by an ML.NET estimator renders with a trailing "Unknown" stage.
    /// </summary>
    [Fact]
    public void CreateSweepableEstimatorPipelineFromSweepableEstimatorAndIEstimatorTest()
    {
        var mlContext = new MLContext();
        var fastForest = SweepableEstimatorFactory.CreateFastForestBinary(new FastForestOption());

        var combined = fastForest.Append(mlContext.Transforms.Concatenate("output", "input"));

        combined.ToString().Should().Be("FastForestBinary=>Unknown");
    }

    /// <summary>
    /// Approval test: an ML.NET estimator followed by the built-in binary classification estimators.
    /// </summary>
    [Fact]
    [UseApprovalSubdirectory("ApprovalTests")]
    [UseReporter(typeof(DiffReporter))]
    public void CreateMultiModelPipelineFromIEstimatorAndBinaryClassifiers()
    {
        var mlContext = new MLContext();
        var multiModel = mlContext.Transforms.Concatenate("output", "input")
            .Append(mlContext.Auto().BinaryClassification());

        Approvals.Verify(JsonSerializer.Serialize(multiModel, _jsonSerializerOptions));
    }

    /// <summary>
    /// Approval test: an ML.NET estimator followed by the built-in multiclass estimators.
    /// </summary>
    [Fact]
    [UseReporter(typeof(DiffReporter))]
    [UseApprovalSubdirectory("ApprovalTests")]
    public void CreateMultiModelPipelineFromIEstimatorAndMultiClassifiers()
    {
        var mlContext = new MLContext();
        var multiModel = mlContext.Transforms.Concatenate("output", "input")
            .Append(mlContext.Auto().MultiClassification());

        Approvals.Verify(JsonSerializer.Serialize(multiModel, _jsonSerializerOptions));
    }

    /// <summary>
    /// Approval test: an ML.NET estimator followed by the estimator set returned by <c>MultiClassification()</c>.
    /// </summary>
    // NOTE(review): despite the "Regressors" name, this test appends context.Auto().MultiClassification(), exactly
    // like CreateMultiModelPipelineFromIEstimatorAndMultiClassifiers. Confirm whether a regression estimator set
    // was intended; switching it would presumably require regenerating the approved file.
    [Fact]
    [UseReporter(typeof(DiffReporter))]
    [UseApprovalSubdirectory("ApprovalTests")]
    public void CreateMultiModelPipelineFromIEstimatorAndRegressors()
    {
        var mlContext = new MLContext();
        var multiModel = mlContext.Transforms.Concatenate("output", "input")
            .Append(mlContext.Auto().MultiClassification());

        Approvals.Verify(JsonSerializer.Serialize(multiModel, _jsonSerializerOptions));
    }

    /// <summary>
    /// Approval test: a sweepable estimator followed by the built-in multiclass estimators.
    /// </summary>
    [Fact]
    [UseReporter(typeof(DiffReporter))]
    [UseApprovalSubdirectory("ApprovalTests")]
    public void CreateMultiModelPipelineFromSweepableEstimatorAndMultiClassifiers()
    {
        var mlContext = new MLContext();
        var multiModel = SweepableEstimatorFactory.CreateFastForestBinary(new FastForestOption())
            .Append(mlContext.Auto().MultiClassification());

        Approvals.Verify(JsonSerializer.Serialize(multiModel, _jsonSerializerOptions));
    }

    /// <summary>
    /// Approval test: a two-stage sweepable pipeline followed by the built-in multiclass estimators.
    /// </summary>
    [Fact]
    [UseReporter(typeof(DiffReporter))]
    [UseApprovalSubdirectory("ApprovalTests")]
    public void CreateMultiModelPipelineFromSweepableEstimatorPipelineAndMultiClassifiers()
    {
        var mlContext = new MLContext();
        var multiModel = mlContext.Transforms.Concatenate("output", "input")
            .Append(SweepableEstimatorFactory.CreateFeaturizeText(new FeaturizeTextOption()))
            .Append(mlContext.Auto().MultiClassification());

        Approvals.Verify(JsonSerializer.Serialize(multiModel, _jsonSerializerOptions));
    }
}
}

Просмотреть файл

@ -0,0 +1,23 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Text.Json.Serialization;
using System.Text.Json;
namespace Microsoft.ML.AutoML.Test
{
    /// <summary>
    /// JSON converter that round-trips <see cref="double"/> values through
    /// <see cref="decimal"/>: numbers are read as decimal and narrowed to double, and
    /// doubles are widened to decimal before being written.
    /// </summary>
    /// <remarks>
    /// Presumably used to keep approval-test JSON stable across runtimes; confirm
    /// against the serializer options that register it.
    /// </remarks>
    internal class DoubleToDecimalConverter : JsonConverter<double>
    {
        /// <summary>Reads the current JSON number as a decimal and converts it to double.</summary>
        public override double Read(ref Utf8JsonReader reader, Type type, JsonSerializerOptions options)
            => (double)reader.GetDecimal();

        /// <summary>Writes <paramref name="value"/> as a JSON number after converting it to decimal.</summary>
        public override void Write(Utf8JsonWriter writer, double value, JsonSerializerOptions options)
            => writer.WriteNumberValue((decimal)value);
    }
}

Просмотреть файл

@ -0,0 +1,23 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Text.Json.Serialization;
using System.Text.Json;
namespace Microsoft.ML.AutoML.Test
{
    /// <summary>
    /// JSON converter that round-trips <see cref="float"/> values through
    /// <see cref="decimal"/>: numbers are read as decimal and narrowed to float, and
    /// floats are widened to decimal before being written.
    /// </summary>
    /// <remarks>
    /// Presumably used to keep approval-test JSON stable across runtimes; confirm
    /// against the serializer options that register it.
    /// </remarks>
    internal class FloatToDecimalConverter : JsonConverter<float>
    {
        /// <summary>Reads the current JSON number as a decimal and converts it to float.</summary>
        public override float Read(ref Utf8JsonReader reader, Type type, JsonSerializerOptions options)
            => (float)reader.GetDecimal();

        /// <summary>Writes <paramref name="value"/> as a JSON number after converting it to decimal.</summary>
        public override void Write(Utf8JsonWriter writer, float value, JsonSerializerOptions options)
            => writer.WriteNumberValue((decimal)value);
    }
}

Просмотреть файл

@ -31,6 +31,11 @@
</ItemGroup>
<ItemGroup>
<Compile Update="Template\SweepableEstimatorFactory.cs">
<DesignTime>True</DesignTime>
<AutoGen>True</AutoGen>
<DependentUpon>SweepableEstimatorFactory.tt</DependentUpon>
</Compile>
<Compile Update="Template\EstimatorType.cs">
<DesignTime>True</DesignTime>
<AutoGen>True</AutoGen>
@ -41,9 +46,23 @@
<AutoGen>True</AutoGen>
<DependentUpon>SearchSpace.tt</DependentUpon>
</Compile>
<Compile Update="Template\SweepableEstimator.cs">
<DesignTime>True</DesignTime>
<AutoGen>True</AutoGen>
<DependentUpon>SweepableEstimator.tt</DependentUpon>
</Compile>
<Compile Update="Template\SweepableEstimator_T_.cs">
<DesignTime>True</DesignTime>
<AutoGen>True</AutoGen>
<DependentUpon>SweepableEstimator_T_.tt</DependentUpon>
</Compile>
</ItemGroup>
<ItemGroup>
<None Update="Template\SweepableEstimatorFactory.tt">
<Generator>TextTemplatingFilePreprocessor</Generator>
<LastGenOutput>SweepableEstimatorFactory.cs</LastGenOutput>
</None>
<None Update="Template\EstimatorType.tt">
<Generator>TextTemplatingFilePreprocessor</Generator>
<LastGenOutput>EstimatorType.cs</LastGenOutput>
@ -52,6 +71,14 @@
<Generator>TextTemplatingFilePreprocessor</Generator>
<LastGenOutput>SearchSpace.cs</LastGenOutput>
</None>
<None Update="Template\SweepableEstimator.tt">
<Generator>TextTemplatingFilePreprocessor</Generator>
<LastGenOutput>SweepableEstimator.cs</LastGenOutput>
</None>
<None Update="Template\SweepableEstimator_T_.tt">
<Generator>TextTemplatingFilePreprocessor</Generator>
<LastGenOutput>SweepableEstimator_T_.cs</LastGenOutput>
</None>
</ItemGroup>
</Project>

Просмотреть файл

@ -0,0 +1,55 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Text.Json;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.Text;
using Microsoft.ML.AutoML.SourceGenerator.Template;
namespace Microsoft.ML.AutoML.SourceGenerator
{
    /// <summary>
    /// Source generator that emits <c>SweepableEstimatorFactory.cs</c>, rendered by the
    /// <see cref="SweepableEstimatorFactory"/> template from the estimators listed in the
    /// <c>trainer-estimators.json</c> and <c>transformer-estimators.json</c> additional files.
    /// </summary>
    [Generator]
    public class SweepableEstimatorFactoryGenerator : ISourceGenerator
    {
        private const string className = "SweepableEstimatorFactory";

        /// <summary>
        /// Generates the factory source unless <c>code_gen_flag.json</c> maps this
        /// generator's name to <c>false</c>.
        /// </summary>
        /// <param name="context">Generator context providing the additional files and receiving the output.</param>
        public void Execute(GeneratorExecutionContext context)
        {
            // FirstOrDefault rather than First: when no flag file is supplied the
            // generator runs with its default behaviour instead of throwing.
            if (context.AdditionalFiles.FirstOrDefault(f => f.Path.Contains("code_gen_flag.json")) is AdditionalText flagFile)
            {
                var json = flagFile.GetText().ToString();
                var flags = JsonSerializer.Deserialize<Dictionary<string, bool>>(json);
                if (flags.TryGetValue(nameof(SweepableEstimatorFactoryGenerator), out var enabled) && !enabled)
                {
                    return;
                }
            }

            // Trainers come first, then transformers, matching the order of the emitted factory methods.
            var trainers = GetEstimatorNames(context, "trainer-estimators.json");
            var transformers = GetEstimatorNames(context, "transformer-estimators.json");

            var code = new SweepableEstimatorFactory()
            {
                NameSpace = Constant.CodeGeneratorNameSpace,
                EstimatorNames = trainers.Concat(transformers),
            };

            context.AddSource(className + ".cs", SourceText.From(code.TransformText(), Encoding.UTF8));
        }

        /// <summary>No initialization is performed by this generator.</summary>
        public void Initialize(GeneratorInitializationContext context)
        {
        }

        // Reads every additional file whose path contains fileName and returns one
        // (estimator class name, title-cased search option name) pair per estimator type.
        private static (string, string)[] GetEstimatorNames(GeneratorExecutionContext context, string fileName)
        {
            return context.AdditionalFiles
                .Where(f => f.Path.Contains(fileName))
                .SelectMany(file => Utils.GetEstimatorsFromJson(file.GetText().ToString()).Estimators)
                .SelectMany(estimator => estimator.EstimatorTypes.Select(t => (Utils.CreateEstimatorName(estimator.FunctionName, t), Utils.ToTitleCase(estimator.SearchOption))))
                .ToArray();
        }
    }
}

Просмотреть файл

@ -0,0 +1,88 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using System.Text.Json;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Text;
using Microsoft.ML.AutoML.SourceGenerator;
using SweepableEstimator = Microsoft.ML.AutoML.SourceGenerator.Template.SweepableEstimator;
using SweepableEstimatorT = Microsoft.ML.AutoML.SourceGenerator.Template.SweepableEstimator_T_;
namespace Microsoft.ML.ModelBuilder.SweepableEstimator.CodeGenerator
{
    /// <summary>
    /// Source generator that emits one estimator class per (estimator, estimator type)
    /// pair listed in the <c>trainer-estimators.json</c> and
    /// <c>transformer-estimators.json</c> additional files.
    /// </summary>
    [Generator]
    public class SweepableEstimatorGenerator : ISourceGenerator
    {
        private const string SweepableEstimatorAttributeDisplayName = Constant.CodeGeneratorNameSpace + "." + "SweepableEstimatorAttribute";

        /// <summary>
        /// Generates the estimator sources unless <c>code_gen_flag.json</c> maps this
        /// generator's name to <c>false</c>.
        /// </summary>
        /// <param name="context">Generator context providing the additional files and receiving the output.</param>
        public void Execute(GeneratorExecutionContext context)
        {
            // FirstOrDefault rather than First: when no flag file is supplied the
            // generator runs with its default behaviour instead of throwing.
            if (context.AdditionalFiles.FirstOrDefault(f => f.Path.Contains("code_gen_flag.json")) is AdditionalText flagFile)
            {
                var json = flagFile.GetText().ToString();
                var flags = JsonSerializer.Deserialize<Dictionary<string, bool>>(json);
                if (flags.TryGetValue(nameof(SweepableEstimatorGenerator), out var enabled) && !enabled)
                {
                    return;
                }
            }

            var estimators = context.AdditionalFiles
                .Where(f => f.Path.Contains("trainer-estimators.json") || f.Path.Contains("transformer-estimators.json"))
                .SelectMany(file => Utils.GetEstimatorsFromJson(file.GetText().ToString()).Estimators)
                .ToArray();

            // Item3 of each tuple is the generated class name for that estimator type.
            var code = estimators
                .SelectMany(e => e.EstimatorTypes.Select(eType => (e, eType, Utils.CreateEstimatorName(e.FunctionName, eType))))
                .Select(x =>
                {
                    if (x.e.SearchOption == null)
                    {
                        // No search option: render the non-generic template. The type is
                        // written with its namespace because this file's enclosing namespace
                        // contains a segment that is also called SweepableEstimator.
                        return
                            (x.Item3,
                             new AutoML.SourceGenerator.Template.SweepableEstimator()
                             {
                                 NameSpace = Constant.CodeGeneratorNameSpace,
                                 UsingStatements = x.e.UsingStatements,
                                 ArgumentsList = x.e.ArgumentsList,
                                 ClassName = x.Item3,
                                 FunctionName = x.e.FunctionName,
                                 NugetDependencies = x.e.NugetDependencies,
                                 Type = x.eType,
                             }.TransformText());
                    }

                    // A search option is declared: render the generic template bound to it.
                    return
                        (x.Item3,
                         new SweepableEstimatorT()
                         {
                             NameSpace = Constant.CodeGeneratorNameSpace,
                             UsingStatements = x.e.UsingStatements,
                             ArgumentsList = x.e.ArgumentsList,
                             ClassName = x.Item3,
                             FunctionName = x.e.FunctionName,
                             NugetDependencies = x.e.NugetDependencies,
                             Type = x.eType,
                             TOption = Utils.ToTitleCase(x.e.SearchOption),
                         }.TransformText());
                });

            foreach (var c in code)
            {
                context.AddSource(c.Item1 + ".cs", SourceText.From(c.Item2, Encoding.UTF8));
            }
        }

        /// <summary>No initialization is performed by this generator.</summary>
        public void Initialize(GeneratorInitializationContext context)
        {
        }
    }
}

Просмотреть файл

@ -40,9 +40,9 @@ namespace Microsoft.ML.AutoML.SourceGenerator.Template
this.Write(this.ToStringHelper.ToStringWithCulture(e));
this.Write(",\r\n");
}
this.Write(" }\r\n\r\n public static class EstimatorTypeExtension\r\n {\r\n public st" +
"atic bool IsTrainer(this EstimatorType estimatorType)\r\n {\r\n sw" +
"itch(estimatorType)\r\n {\r\n");
this.Write(" Unknown,\r\n }\r\n\r\n public static class EstimatorTypeExtension\r\n {\r" +
"\n public static bool IsTrainer(this EstimatorType estimatorType)\r\n " +
" {\r\n switch(estimatorType)\r\n {\r\n");
foreach(var estimator in TrainerNames){
this.Write(" case EstimatorType.");
this.Write(this.ToStringHelper.ToStringWithCulture(estimator));

Просмотреть файл

@ -18,6 +18,7 @@ namespace <#=NameSpace#>
<# foreach(var e in TransformerNames){#>
<#=e#>,
<#}#>
Unknown,
}
public static class EstimatorTypeExtension

Просмотреть файл

@ -0,0 +1,353 @@
// ------------------------------------------------------------------------------
// <auto-generated>
// This code was generated by a tool.
// Runtime Version: 17.0.0.0
//
// Changes to this file may cause incorrect behavior and will be lost if
// the code is regenerated.
// </auto-generated>
// ------------------------------------------------------------------------------
namespace Microsoft.ML.AutoML.SourceGenerator.Template
{
using System.Linq;
using System.Text;
using System.Collections.Generic;
using System;
/// <summary>
/// Runtime text template that renders the C# source of one non-generic estimator class:
/// an internal partial class deriving from SweepableEstimator, with a constructor that sets
/// its EstimatorType, one string property per argument, and overrides for
/// CSharpUsingStatements, NugetDependencies and FunctionName.
/// </summary>
[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.VisualStudio.TextTemplating", "17.0.0.0")]
internal partial class SweepableEstimator : SweepableEstimatorBase
{
/// <summary>
/// Renders the template with the current property values and returns the accumulated output.
/// </summary>
public virtual string TransformText()
{
this.Write(@"
using System.Collections.Generic;
using Newtonsoft.Json;
using SweepableEstimator = Microsoft.ML.AutoML.SweepableEstimator;
using Microsoft.ML.AutoML.CodeGen;
using ColorsOrder = Microsoft.ML.Transforms.Image.ImagePixelExtractingEstimator.ColorsOrder;
using ColorBits = Microsoft.ML.Transforms.Image.ImagePixelExtractingEstimator.ColorBits;
using ResizingKind = Microsoft.ML.Transforms.Image.ImageResizingEstimator.ResizingKind;
using Anchor = Microsoft.ML.Transforms.Image.ImageResizingEstimator.Anchor;
namespace ");
this.Write(this.ToStringHelper.ToStringWithCulture(NameSpace));
this.Write("\r\n{\r\n internal partial class ");
this.Write(this.ToStringHelper.ToStringWithCulture(ClassName));
this.Write(" : SweepableEstimator\r\n {\r\n public ");
this.Write(this.ToStringHelper.ToStringWithCulture(ClassName));
this.Write("()\r\n {\r\n this.EstimatorType = EstimatorType.");
this.Write(this.ToStringHelper.ToStringWithCulture(ClassName));
this.Write(";\r\n }\r\n \r\n");
// One attributed string property is emitted per argument; the attribute name comes
// from the argument type and the property name from the argument name.
foreach(var arg in ArgumentsList){
var typeAttributeName = Utils.CapitalFirstLetter(arg.ArgumentType);
var propertyName = Utils.CapitalFirstLetter(arg.ArgumentName);
this.Write(" [");
this.Write(this.ToStringHelper.ToStringWithCulture(typeAttributeName));
this.Write("]\r\n [JsonProperty(NullValueHandling=NullValueHandling.Ignore)]\r\n pu" +
"blic string ");
this.Write(this.ToStringHelper.ToStringWithCulture(propertyName));
this.Write(" { get; set; }\r\n\r\n");
}
this.Write(" internal override IEnumerable<string> CSharpUsingStatements \r\n {\r\n" +
" get => new string[] {");
this.Write(this.ToStringHelper.ToStringWithCulture(Utils.PrettyPrintListOfString(UsingStatements.Select(x => $"using {x};"))));
this.Write("};\r\n }\r\n\r\n internal override IEnumerable<string> NugetDependencies\r" +
"\n {\r\n get => new string[] {");
this.Write(this.ToStringHelper.ToStringWithCulture(Utils.PrettyPrintListOfString(NugetDependencies)));
this.Write("};\r\n }\r\n\r\n internal override string FunctionName \r\n {\r\n " +
" get => \"");
this.Write(this.ToStringHelper.ToStringWithCulture(Utils.GetPrefix(Type)));
this.Write(".");
this.Write(this.ToStringHelper.ToStringWithCulture(FunctionName));
this.Write("\";\r\n }\r\n }\r\n}\r\n\r\n");
return this.GenerationEnvironment.ToString();
}
// Namespace the emitted class is declared in.
public string NameSpace {get;set;}
// Name of the emitted class; also written as the EstimatorType member assigned in its constructor.
public string ClassName {get;set;}
// Written after the prefix and "." in the emitted FunctionName override.
public string FunctionName {get;set;}
// Passed to Utils.GetPrefix to produce the prefix of the emitted FunctionName.
public string Type {get;set;}
// Arguments that each become one attributed string property on the emitted class.
public IEnumerable<Argument> ArgumentsList {get;set;}
// Each entry is rendered as "using <entry>;" in the emitted CSharpUsingStatements override.
public IEnumerable<string> UsingStatements {get; set;}
// Entries listed in the emitted NugetDependencies override.
public IEnumerable<string> NugetDependencies {get; set;}
}
#region Base class
/// <summary>
/// Base class for this transformation. It owns the output buffer, the current
/// indentation, the error collection and the ToString helper used by the template class.
/// </summary>
[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.VisualStudio.TextTemplating", "17.0.0.0")]
internal class SweepableEstimatorBase
{
#region Fields
private global::System.Text.StringBuilder generationEnvironmentField;
private global::System.CodeDom.Compiler.CompilerErrorCollection errorsField;
private global::System.Collections.Generic.List<int> indentLengthsField;
private string currentIndentField = "";
// True when the text written last ended with a newline, so the next Write emits the indent first.
private bool endsWithNewline;
private global::System.Collections.Generic.IDictionary<string, object> sessionField;
#endregion
#region Properties
/// <summary>
/// The string builder that generation-time code is using to assemble generated output
/// </summary>
protected System.Text.StringBuilder GenerationEnvironment
{
get
{
if ((this.generationEnvironmentField == null))
{
this.generationEnvironmentField = new global::System.Text.StringBuilder();
}
return this.generationEnvironmentField;
}
set
{
this.generationEnvironmentField = value;
}
}
/// <summary>
/// The error collection for the generation process
/// </summary>
public System.CodeDom.Compiler.CompilerErrorCollection Errors
{
get
{
if ((this.errorsField == null))
{
this.errorsField = new global::System.CodeDom.Compiler.CompilerErrorCollection();
}
return this.errorsField;
}
}
/// <summary>
/// A list of the lengths of each indent that was added with PushIndent
/// </summary>
private System.Collections.Generic.List<int> indentLengths
{
get
{
if ((this.indentLengthsField == null))
{
this.indentLengthsField = new global::System.Collections.Generic.List<int>();
}
return this.indentLengthsField;
}
}
/// <summary>
/// Gets the current indent we use when adding lines to the output
/// </summary>
public string CurrentIndent
{
get
{
return this.currentIndentField;
}
}
/// <summary>
/// Current transformation session
/// </summary>
public virtual global::System.Collections.Generic.IDictionary<string, object> Session
{
get
{
return this.sessionField;
}
set
{
this.sessionField = value;
}
}
#endregion
#region Transform-time helpers
/// <summary>
/// Write text directly into the generated output. Null or empty text is ignored.
/// </summary>
public void Write(string textToAppend)
{
if (string.IsNullOrEmpty(textToAppend))
{
return;
}
// If we're starting off, or if the previous text ended with a newline,
// we have to append the current indent first.
if (((this.GenerationEnvironment.Length == 0)
|| this.endsWithNewline))
{
this.GenerationEnvironment.Append(this.currentIndentField);
this.endsWithNewline = false;
}
// Check if the current text ends with a newline
if (textToAppend.EndsWith(global::System.Environment.NewLine, global::System.StringComparison.CurrentCulture))
{
this.endsWithNewline = true;
}
// This is an optimization. If the current indent is "", then we don't have to do any
// of the more complex stuff further down.
if ((this.currentIndentField.Length == 0))
{
this.GenerationEnvironment.Append(textToAppend);
return;
}
// Everywhere there is a newline in the text, add an indent after it
textToAppend = textToAppend.Replace(global::System.Environment.NewLine, (global::System.Environment.NewLine + this.currentIndentField));
// If the text ends with a newline, then we should strip off the indent added at the very end
// because the appropriate indent will be added when the next time Write() is called
if (this.endsWithNewline)
{
this.GenerationEnvironment.Append(textToAppend, 0, (textToAppend.Length - this.currentIndentField.Length));
}
else
{
this.GenerationEnvironment.Append(textToAppend);
}
}
/// <summary>
/// Write text directly into the generated output, followed by a line terminator
/// </summary>
public void WriteLine(string textToAppend)
{
this.Write(textToAppend);
this.GenerationEnvironment.AppendLine();
this.endsWithNewline = true;
}
/// <summary>
/// Write formatted text directly into the generated output
/// </summary>
public void Write(string format, params object[] args)
{
this.Write(string.Format(global::System.Globalization.CultureInfo.CurrentCulture, format, args));
}
/// <summary>
/// Write formatted text directly into the generated output, followed by a line terminator
/// </summary>
public void WriteLine(string format, params object[] args)
{
this.WriteLine(string.Format(global::System.Globalization.CultureInfo.CurrentCulture, format, args));
}
/// <summary>
/// Raise an error
/// </summary>
public void Error(string message)
{
System.CodeDom.Compiler.CompilerError error = new global::System.CodeDom.Compiler.CompilerError();
error.ErrorText = message;
this.Errors.Add(error);
}
/// <summary>
/// Raise a warning
/// </summary>
public void Warning(string message)
{
System.CodeDom.Compiler.CompilerError error = new global::System.CodeDom.Compiler.CompilerError();
error.ErrorText = message;
error.IsWarning = true;
this.Errors.Add(error);
}
/// <summary>
/// Increase the indent
/// </summary>
public void PushIndent(string indent)
{
if ((indent == null))
{
throw new global::System.ArgumentNullException("indent");
}
this.currentIndentField = (this.currentIndentField + indent);
this.indentLengths.Add(indent.Length);
}
/// <summary>
/// Remove the last indent that was added with PushIndent and return it
/// (an empty string when nothing was removed)
/// </summary>
public string PopIndent()
{
string returnValue = "";
if ((this.indentLengths.Count > 0))
{
int indentLength = this.indentLengths[(this.indentLengths.Count - 1)];
this.indentLengths.RemoveAt((this.indentLengths.Count - 1));
if ((indentLength > 0))
{
returnValue = this.currentIndentField.Substring((this.currentIndentField.Length - indentLength));
this.currentIndentField = this.currentIndentField.Remove((this.currentIndentField.Length - indentLength));
}
}
return returnValue;
}
/// <summary>
/// Remove any indentation
/// </summary>
public void ClearIndent()
{
this.indentLengths.Clear();
this.currentIndentField = "";
}
#endregion
#region ToString Helpers
/// <summary>
/// Utility class to produce culture-oriented representation of an object as a string.
/// </summary>
public class ToStringInstanceHelper
{
private System.IFormatProvider formatProviderField = global::System.Globalization.CultureInfo.InvariantCulture;
/// <summary>
/// Gets or sets format provider to be used by ToStringWithCulture method.
/// Assigning null leaves the current provider unchanged.
/// </summary>
public System.IFormatProvider FormatProvider
{
get
{
return this.formatProviderField ;
}
set
{
if ((value != null))
{
this.formatProviderField = value;
}
}
}
/// <summary>
/// This is called from the compile/run appdomain to convert objects within an expression block to a string
/// </summary>
public string ToStringWithCulture(object objectToConvert)
{
if ((objectToConvert == null))
{
throw new global::System.ArgumentNullException("objectToConvert");
}
// Prefer a ToString(IFormatProvider) overload, looked up by reflection, so the
// configured format provider is applied; otherwise fall back to plain ToString().
System.Type t = objectToConvert.GetType();
System.Reflection.MethodInfo method = t.GetMethod("ToString", new System.Type[] {
typeof(System.IFormatProvider)});
if ((method == null))
{
return objectToConvert.ToString();
}
else
{
return ((string)(method.Invoke(objectToConvert, new object[] {
this.formatProviderField })));
}
}
}
private ToStringInstanceHelper toStringHelperField = new ToStringInstanceHelper();
/// <summary>
/// Helper to produce culture-oriented representation of an object as a string
/// </summary>
public ToStringInstanceHelper ToStringHelper
{
get
{
return this.toStringHelperField;
}
}
#endregion
}
#endregion
}

Просмотреть файл

@ -0,0 +1,58 @@
<#@ template language="C#" linePragmas="false" visibility = "internal"#>
<#@ assembly name="System.Core" #>
<#@ import namespace="System.Linq" #>
<#@ import namespace="System.Text" #>
<#@ import namespace="System.Collections.Generic" #>
using System.Collections.Generic;
using Newtonsoft.Json;
using SweepableEstimator = Microsoft.ML.AutoML.SweepableEstimator;
using Microsoft.ML.AutoML.CodeGen;
using ColorsOrder = Microsoft.ML.Transforms.Image.ImagePixelExtractingEstimator.ColorsOrder;
using ColorBits = Microsoft.ML.Transforms.Image.ImagePixelExtractingEstimator.ColorBits;
using ResizingKind = Microsoft.ML.Transforms.Image.ImageResizingEstimator.ResizingKind;
using Anchor = Microsoft.ML.Transforms.Image.ImageResizingEstimator.Anchor;
namespace <#=NameSpace#>
{
internal partial class <#=ClassName#> : SweepableEstimator
{
public <#=ClassName#>()
{
this.EstimatorType = EstimatorType.<#=ClassName#>;
}
<# foreach(var arg in ArgumentsList){
var typeAttributeName = Utils.CapitalFirstLetter(arg.ArgumentType);
var propertyName = Utils.CapitalFirstLetter(arg.ArgumentName);#>
[<#=typeAttributeName#>]
[JsonProperty(NullValueHandling=NullValueHandling.Ignore)]
public string <#=propertyName#> { get; set; }
<#}#>
internal override IEnumerable<string> CSharpUsingStatements
{
get => new string[] {<#=Utils.PrettyPrintListOfString(UsingStatements.Select(x => $"using {x};"))#>};
}
internal override IEnumerable<string> NugetDependencies
{
get => new string[] {<#=Utils.PrettyPrintListOfString(NugetDependencies)#>};
}
internal override string FunctionName
{
get => "<#=Utils.GetPrefix(Type)#>.<#=FunctionName#>";
}
}
}
<#+
// Namespace the emitted class is declared in.
public string NameSpace {get;set;}
// Name of the emitted class; also written as the EstimatorType member assigned in its constructor.
public string ClassName {get;set;}
// Written after the prefix and "." in the emitted FunctionName override.
public string FunctionName {get;set;}
// Passed to Utils.GetPrefix to produce the prefix of the emitted FunctionName.
public string Type {get;set;}
// Arguments that each become one attributed string property on the emitted class.
public IEnumerable<Argument> ArgumentsList {get;set;}
// Each entry is rendered as "using <entry>;" in the emitted CSharpUsingStatements override.
public IEnumerable<string> UsingStatements {get; set;}
// Entries listed in the emitted NugetDependencies override.
public IEnumerable<string> NugetDependencies {get; set;}
#>

Просмотреть файл

@ -0,0 +1,329 @@
// ------------------------------------------------------------------------------
// <auto-generated>
// This code was generated by a tool.
// Runtime Version: 17.0.0.0
//
// Changes to this file may cause incorrect behavior and will be lost if
// the code is regenerated.
// </auto-generated>
// ------------------------------------------------------------------------------
namespace Microsoft.ML.AutoML.SourceGenerator.Template
{
using System.Linq;
using System.Text;
using System.Collections.Generic;
using System;
/// <summary>
/// Runtime text template that renders the C# source of the internal static
/// SweepableEstimatorFactory class, with one Create method per entry in EstimatorNames.
/// </summary>
[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.VisualStudio.TextTemplating", "17.0.0.0")]
internal partial class SweepableEstimatorFactory : SweepableEstimatorFactoryBase
{
/// <summary>
/// Renders the template with the current property values and returns the accumulated output.
/// </summary>
public virtual string TransformText()
{
this.Write("\r\nusing System;\r\nusing System.Collections.Generic;\r\nusing System.Text;\r\nusing New" +
"tonsoft.Json;\r\nusing Newtonsoft.Json.Linq;\r\nusing Microsoft.ML.SearchSpace;\r\nusi" +
"ng Microsoft.ML;\r\n\r\nnamespace ");
this.Write(this.ToStringHelper.ToStringWithCulture(NameSpace));
this.Write("\r\n{\r\n internal static class SweepableEstimatorFactory\r\n {\r\n");
// Each entry yields one Create method: the first tuple item is the estimator class
// (return type, method-name suffix and constructed type), the second is the option
// type of defaultOption and the SearchSpace type argument.
foreach((var estimator, var tOption) in EstimatorNames){
this.Write(" public static ");
this.Write(this.ToStringHelper.ToStringWithCulture(estimator));
this.Write(" Create");
this.Write(this.ToStringHelper.ToStringWithCulture(estimator));
this.Write("(");
this.Write(this.ToStringHelper.ToStringWithCulture(tOption));
this.Write(" defaultOption, SearchSpace<");
this.Write(this.ToStringHelper.ToStringWithCulture(tOption));
this.Write("> searchSpace = null)\r\n {\r\n if(searchSpace == null){\r\n " +
" searchSpace = new SearchSpace<");
this.Write(this.ToStringHelper.ToStringWithCulture(tOption));
this.Write(">(defaultOption);\r\n }\r\n\r\n return new ");
this.Write(this.ToStringHelper.ToStringWithCulture(estimator));
this.Write("(defaultOption, searchSpace);\r\n }\r\n\r\n");
}
this.Write(" }\r\n}\r\n\r\n");
return this.GenerationEnvironment.ToString();
}
// Namespace the emitted factory class is declared in.
public string NameSpace {get;set;}
// (estimator class name, option type name) pairs; one Create method is emitted per pair.
public IEnumerable<(string, string)> EstimatorNames {get;set;}
}
#region Base class
/// <summary>
/// Base class for this transformation. It owns the output buffer, the current
/// indentation, the error collection and the ToString helper used by the template class.
/// </summary>
[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.VisualStudio.TextTemplating", "17.0.0.0")]
internal class SweepableEstimatorFactoryBase
{
#region Fields
private global::System.Text.StringBuilder generationEnvironmentField;
private global::System.CodeDom.Compiler.CompilerErrorCollection errorsField;
private global::System.Collections.Generic.List<int> indentLengthsField;
private string currentIndentField = "";
// True when the text written last ended with a newline, so the next Write emits the indent first.
private bool endsWithNewline;
private global::System.Collections.Generic.IDictionary<string, object> sessionField;
#endregion
#region Properties
/// <summary>
/// The string builder that generation-time code is using to assemble generated output
/// </summary>
protected System.Text.StringBuilder GenerationEnvironment
{
get
{
if ((this.generationEnvironmentField == null))
{
this.generationEnvironmentField = new global::System.Text.StringBuilder();
}
return this.generationEnvironmentField;
}
set
{
this.generationEnvironmentField = value;
}
}
/// <summary>
/// The error collection for the generation process
/// </summary>
public System.CodeDom.Compiler.CompilerErrorCollection Errors
{
get
{
if ((this.errorsField == null))
{
this.errorsField = new global::System.CodeDom.Compiler.CompilerErrorCollection();
}
return this.errorsField;
}
}
/// <summary>
/// A list of the lengths of each indent that was added with PushIndent
/// </summary>
private System.Collections.Generic.List<int> indentLengths
{
get
{
if ((this.indentLengthsField == null))
{
this.indentLengthsField = new global::System.Collections.Generic.List<int>();
}
return this.indentLengthsField;
}
}
/// <summary>
/// Gets the current indent we use when adding lines to the output
/// </summary>
public string CurrentIndent
{
get
{
return this.currentIndentField;
}
}
/// <summary>
/// Current transformation session
/// </summary>
public virtual global::System.Collections.Generic.IDictionary<string, object> Session
{
get
{
return this.sessionField;
}
set
{
this.sessionField = value;
}
}
#endregion
#region Transform-time helpers
/// <summary>
/// Write text directly into the generated output. Null or empty text is ignored.
/// </summary>
public void Write(string textToAppend)
{
if (string.IsNullOrEmpty(textToAppend))
{
return;
}
// If we're starting off, or if the previous text ended with a newline,
// we have to append the current indent first.
if (((this.GenerationEnvironment.Length == 0)
|| this.endsWithNewline))
{
this.GenerationEnvironment.Append(this.currentIndentField);
this.endsWithNewline = false;
}
// Check if the current text ends with a newline
if (textToAppend.EndsWith(global::System.Environment.NewLine, global::System.StringComparison.CurrentCulture))
{
this.endsWithNewline = true;
}
// This is an optimization. If the current indent is "", then we don't have to do any
// of the more complex stuff further down.
if ((this.currentIndentField.Length == 0))
{
this.GenerationEnvironment.Append(textToAppend);
return;
}
// Everywhere there is a newline in the text, add an indent after it
textToAppend = textToAppend.Replace(global::System.Environment.NewLine, (global::System.Environment.NewLine + this.currentIndentField));
// If the text ends with a newline, then we should strip off the indent added at the very end
// because the appropriate indent will be added when the next time Write() is called
if (this.endsWithNewline)
{
this.GenerationEnvironment.Append(textToAppend, 0, (textToAppend.Length - this.currentIndentField.Length));
}
else
{
this.GenerationEnvironment.Append(textToAppend);
}
}
/// <summary>
/// Write text directly into the generated output, followed by a line terminator
/// </summary>
public void WriteLine(string textToAppend)
{
this.Write(textToAppend);
this.GenerationEnvironment.AppendLine();
this.endsWithNewline = true;
}
/// <summary>
/// Write formatted text directly into the generated output
/// </summary>
public void Write(string format, params object[] args)
{
this.Write(string.Format(global::System.Globalization.CultureInfo.CurrentCulture, format, args));
}
/// <summary>
/// Write formatted text directly into the generated output, followed by a line terminator
/// </summary>
public void WriteLine(string format, params object[] args)
{
this.WriteLine(string.Format(global::System.Globalization.CultureInfo.CurrentCulture, format, args));
}
/// <summary>
/// Raise an error
/// </summary>
public void Error(string message)
{
System.CodeDom.Compiler.CompilerError error = new global::System.CodeDom.Compiler.CompilerError();
error.ErrorText = message;
this.Errors.Add(error);
}
/// <summary>
/// Raise a warning
/// </summary>
public void Warning(string message)
{
System.CodeDom.Compiler.CompilerError error = new global::System.CodeDom.Compiler.CompilerError();
error.ErrorText = message;
error.IsWarning = true;
this.Errors.Add(error);
}
/// <summary>
/// Increase the indent
/// </summary>
public void PushIndent(string indent)
{
if ((indent == null))
{
throw new global::System.ArgumentNullException("indent");
}
this.currentIndentField = (this.currentIndentField + indent);
this.indentLengths.Add(indent.Length);
}
/// <summary>
/// Remove the last indent that was added with PushIndent and return it
/// (an empty string when nothing was removed)
/// </summary>
public string PopIndent()
{
string returnValue = "";
if ((this.indentLengths.Count > 0))
{
int indentLength = this.indentLengths[(this.indentLengths.Count - 1)];
this.indentLengths.RemoveAt((this.indentLengths.Count - 1));
if ((indentLength > 0))
{
returnValue = this.currentIndentField.Substring((this.currentIndentField.Length - indentLength));
this.currentIndentField = this.currentIndentField.Remove((this.currentIndentField.Length - indentLength));
}
}
return returnValue;
}
/// <summary>
/// Remove any indentation
/// </summary>
public void ClearIndent()
{
this.indentLengths.Clear();
this.currentIndentField = "";
}
#endregion
#region ToString Helpers
/// <summary>
/// Utility class to produce culture-oriented representation of an object as a string.
/// </summary>
public class ToStringInstanceHelper
{
private System.IFormatProvider formatProviderField = global::System.Globalization.CultureInfo.InvariantCulture;
/// <summary>
/// Gets or sets format provider to be used by ToStringWithCulture method.
/// Assigning null leaves the current provider unchanged.
/// </summary>
public System.IFormatProvider FormatProvider
{
get
{
return this.formatProviderField ;
}
set
{
if ((value != null))
{
this.formatProviderField = value;
}
}
}
/// <summary>
/// This is called from the compile/run appdomain to convert objects within an expression block to a string
/// </summary>
public string ToStringWithCulture(object objectToConvert)
{
if ((objectToConvert == null))
{
throw new global::System.ArgumentNullException("objectToConvert");
}
// Prefer a ToString(IFormatProvider) overload, looked up by reflection, so the
// configured format provider is applied; otherwise fall back to plain ToString().
System.Type t = objectToConvert.GetType();
System.Reflection.MethodInfo method = t.GetMethod("ToString", new System.Type[] {
typeof(System.IFormatProvider)});
if ((method == null))
{
return objectToConvert.ToString();
}
else
{
return ((string)(method.Invoke(objectToConvert, new object[] {
this.formatProviderField })));
}
}
}
private ToStringInstanceHelper toStringHelperField = new ToStringInstanceHelper();
/// <summary>
/// Helper to produce culture-oriented representation of an object as a string
/// </summary>
public ToStringInstanceHelper ToStringHelper
{
get
{
return this.toStringHelperField;
}
}
#endregion
}
#endregion
}

Просмотреть файл

@ -0,0 +1,36 @@
<#@ template language="C#" linePragmas="false" visibility = "internal" #>
<#@ assembly name="System.Core" #>
<#@ import namespace="System.Linq" #>
<#@ import namespace="System.Text" #>
<#@ import namespace="System.Collections.Generic" #>
using System;
using System.Collections.Generic;
using System.Text;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
using Microsoft.ML.SearchSpace;
using Microsoft.ML;
namespace <#=NameSpace#>
{
internal static class SweepableEstimatorFactory
{
<# foreach((var estimator, var tOption) in EstimatorNames){#>
public static <#=estimator#> Create<#=estimator#>(<#=tOption#> defaultOption, SearchSpace<<#=tOption#>> searchSpace = null)
{
if(searchSpace == null){
searchSpace = new SearchSpace<<#=tOption#>>(defaultOption);
}
return new <#=estimator#>(defaultOption, searchSpace);
}
<#}#>
}
}
<#+
// Namespace the emitted factory class is declared in.
public string NameSpace {get;set;}
// (estimator class name, option type name) pairs; one Create method is emitted per pair.
public IEnumerable<(string, string)> EstimatorNames {get;set;}
#>

Просмотреть файл

@ -0,0 +1,358 @@
// ------------------------------------------------------------------------------
// <auto-generated>
// This code was generated by a tool.
// Runtime Version: 17.0.0.0
//
// Changes to this file may cause incorrect behavior and will be lost if
// the code is regenerated.
// </auto-generated>
// ------------------------------------------------------------------------------
namespace Microsoft.ML.AutoML.SourceGenerator.Template
{
using System.Linq;
using System.Text;
using System.Collections.Generic;
using System;
/// <summary>
/// Generated text-template class that produces the C# source of a single
/// estimator class deriving from SweepableEstimator&lt;TOption&gt;. The emitted text is
/// parameterized by the properties declared at the end of this class.
/// </summary>
[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.VisualStudio.TextTemplating", "17.0.0.0")]
internal partial class SweepableEstimator_T_ : SweepableEstimator_T_Base
{
/// <summary>
/// Writes the fixed text fragments and the property values, in order, into the
/// generation buffer and returns the accumulated text.
/// </summary>
/// <returns>The complete generated source text.</returns>
public virtual string TransformText()
{
// Fixed using directives and aliases of the generated file, up to the "namespace" keyword.
this.Write(@"
using System.Collections.Generic;
using Newtonsoft.Json;
using SweepableEstimator = Microsoft.ML.AutoML.SweepableEstimator;
using Microsoft.ML.AutoML.CodeGen;
using ColorsOrder = Microsoft.ML.Transforms.Image.ImagePixelExtractingEstimator.ColorsOrder;
using ColorBits = Microsoft.ML.Transforms.Image.ImagePixelExtractingEstimator.ColorBits;
using ResizingKind = Microsoft.ML.Transforms.Image.ImageResizingEstimator.ResizingKind;
using Anchor = Microsoft.ML.Transforms.Image.ImageResizingEstimator.Anchor;
using Microsoft.ML.SearchSpace;
namespace ");
this.Write(this.ToStringHelper.ToStringWithCulture(NameSpace));
// Class header: "internal partial class <ClassName> : SweepableEstimator<<TOption>>".
this.Write("\r\n{\r\n internal partial class ");
this.Write(this.ToStringHelper.ToStringWithCulture(ClassName));
this.Write(" : SweepableEstimator<");
this.Write(this.ToStringHelper.ToStringWithCulture(TOption));
// Public constructor taking a default option and an optional search space.
this.Write(">\r\n {\r\n public ");
this.Write(this.ToStringHelper.ToStringWithCulture(ClassName));
this.Write("(");
this.Write(this.ToStringHelper.ToStringWithCulture(TOption));
this.Write(" defaultOption, SearchSpace<");
this.Write(this.ToStringHelper.ToStringWithCulture(TOption));
this.Write("> searchSpace = null)\r\n {\r\n this.TParameter = defaultOption;\r\n " +
" this.SearchSpace = searchSpace;\r\n this.EstimatorType = Est" +
"imatorType.");
this.Write(this.ToStringHelper.ToStringWithCulture(ClassName));
// Internal parameterless constructor: sets EstimatorType and a new TOption instance.
this.Write(";\r\n }\r\n\r\n internal ");
this.Write(this.ToStringHelper.ToStringWithCulture(ClassName));
this.Write("()\r\n {\r\n this.EstimatorType = EstimatorType.");
this.Write(this.ToStringHelper.ToStringWithCulture(ClassName));
this.Write(";\r\n this.TParameter = new ");
this.Write(this.ToStringHelper.ToStringWithCulture(TOption));
// CSharpUsingStatements override: every UsingStatements entry is wrapped as "using <x>;".
// NOTE(review): Utils.PrettyPrintListOfString is not visible here; presumably it renders
// the items as a comma-separated list of string literals - confirm.
this.Write("();\r\n }\r\n \r\n internal override IEnumerable<string> CSharpUsingSt" +
"atements \r\n {\r\n get => new string[] {");
this.Write(this.ToStringHelper.ToStringWithCulture(Utils.PrettyPrintListOfString(UsingStatements.Select(x => $"using {x};"))));
// NugetDependencies override, rendered through the same helper.
this.Write("};\r\n }\r\n\r\n internal override IEnumerable<string> NugetDependencies\r" +
"\n {\r\n get => new string[] {");
this.Write(this.ToStringHelper.ToStringWithCulture(Utils.PrettyPrintListOfString(NugetDependencies)));
// FunctionName override: "<Utils.GetPrefix(Type)>.<FunctionName>".
this.Write("};\r\n }\r\n\r\n internal override string FunctionName \r\n {\r\n " +
" get => \"");
this.Write(this.ToStringHelper.ToStringWithCulture(Utils.GetPrefix(Type)));
this.Write(".");
this.Write(this.ToStringHelper.ToStringWithCulture(FunctionName));
this.Write("\";\r\n }\r\n }\r\n}\r\n\r\n");
return this.GenerationEnvironment.ToString();
}
/// <summary>Namespace the generated class is declared in.</summary>
public string NameSpace {get;set;}
/// <summary>Name of the generated class; also used as the EstimatorType member name.</summary>
public string ClassName {get;set;}
/// <summary>Function name written after the prefix in the generated FunctionName property.</summary>
public string FunctionName {get;set;}
/// <summary>Estimator type name passed to Utils.GetPrefix to obtain that prefix.</summary>
public string Type {get;set;}
/// <summary>Argument descriptions; not referenced by TransformText in this class.</summary>
public IEnumerable<Argument> ArgumentsList {get;set;}
/// <summary>Namespaces emitted, each as "using x;", in the CSharpUsingStatements array.</summary>
public IEnumerable<string> UsingStatements {get; set;}
/// <summary>Entries emitted in the generated NugetDependencies array.</summary>
public IEnumerable<string> NugetDependencies {get; set;}
/// <summary>Option type name used as the SweepableEstimator type argument and constructor parameter type.</summary>
public string TOption {get; set;}
}
#region Base class
/// <summary>
/// Base class for this transformation. It owns the output buffer, the error
/// collection and the indentation state, and provides the Write/WriteLine helpers
/// that the derived template class calls to assemble its text.
/// </summary>
[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.VisualStudio.TextTemplating", "17.0.0.0")]
internal class SweepableEstimator_T_Base
{
    #region Fields
    private global::System.Text.StringBuilder generationEnvironmentField;
    private global::System.CodeDom.Compiler.CompilerErrorCollection errorsField;
    private global::System.Collections.Generic.List<int> indentLengthsField;
    private string currentIndentField = "";
    private bool endsWithNewline;
    private global::System.Collections.Generic.IDictionary<string, object> sessionField;
    #endregion
    #region Properties
    /// <summary>
    /// The string builder that generation-time code is using to assemble generated output.
    /// Created lazily on first access.
    /// </summary>
    protected System.Text.StringBuilder GenerationEnvironment
    {
        get
        {
            if ((this.generationEnvironmentField == null))
            {
                this.generationEnvironmentField = new global::System.Text.StringBuilder();
            }
            return this.generationEnvironmentField;
        }
        set
        {
            this.generationEnvironmentField = value;
        }
    }
    /// <summary>
    /// The error collection for the generation process. Created lazily on first access.
    /// </summary>
    public System.CodeDom.Compiler.CompilerErrorCollection Errors
    {
        get
        {
            if ((this.errorsField == null))
            {
                this.errorsField = new global::System.CodeDom.Compiler.CompilerErrorCollection();
            }
            return this.errorsField;
        }
    }
    /// <summary>
    /// A list of the lengths of each indent that was added with PushIndent.
    /// Created lazily on first access.
    /// </summary>
    private System.Collections.Generic.List<int> indentLengths
    {
        get
        {
            if ((this.indentLengthsField == null))
            {
                this.indentLengthsField = new global::System.Collections.Generic.List<int>();
            }
            return this.indentLengthsField;
        }
    }
    /// <summary>
    /// Gets the current indent we use when adding lines to the output.
    /// </summary>
    public string CurrentIndent
    {
        get
        {
            return this.currentIndentField;
        }
    }
    /// <summary>
    /// Current transformation session.
    /// </summary>
    public virtual global::System.Collections.Generic.IDictionary<string, object> Session
    {
        get
        {
            return this.sessionField;
        }
        set
        {
            this.sessionField = value;
        }
    }
    #endregion
    #region Transform-time helpers
    /// <summary>
    /// Write text directly into the generated output, applying the current indent
    /// at the start of the output and after every newline.
    /// </summary>
    /// <param name="textToAppend">Text to append; null or empty text is ignored.</param>
    public void Write(string textToAppend)
    {
        if (string.IsNullOrEmpty(textToAppend))
        {
            return;
        }
        // If we're starting off, or if the previous text ended with a newline,
        // we have to append the current indent first.
        if (((this.GenerationEnvironment.Length == 0)
                    || this.endsWithNewline))
        {
            this.GenerationEnvironment.Append(this.currentIndentField);
            this.endsWithNewline = false;
        }
        // Check if the current text ends with a newline.
        // The comparison is ordinal: a newline is a fixed character sequence, and the
        // Replace call below already matches it ordinally. A culture-sensitive check can
        // disagree with it (e.g. text ending in "\r\n" tested against "\n").
        if (textToAppend.EndsWith(global::System.Environment.NewLine, global::System.StringComparison.Ordinal))
        {
            this.endsWithNewline = true;
        }
        // This is an optimization. If the current indent is "", then we don't have to do any
        // of the more complex stuff further down.
        if ((this.currentIndentField.Length == 0))
        {
            this.GenerationEnvironment.Append(textToAppend);
            return;
        }
        // Everywhere there is a newline in the text, add an indent after it
        textToAppend = textToAppend.Replace(global::System.Environment.NewLine, (global::System.Environment.NewLine + this.currentIndentField));
        // If the text ends with a newline, then we should strip off the indent added at the very end
        // because the appropriate indent will be added when the next time Write() is called
        if (this.endsWithNewline)
        {
            this.GenerationEnvironment.Append(textToAppend, 0, (textToAppend.Length - this.currentIndentField.Length));
        }
        else
        {
            this.GenerationEnvironment.Append(textToAppend);
        }
    }
    /// <summary>
    /// Write text directly into the generated output, followed by a line terminator.
    /// </summary>
    /// <param name="textToAppend">Text to append before the line terminator.</param>
    public void WriteLine(string textToAppend)
    {
        this.Write(textToAppend);
        this.GenerationEnvironment.AppendLine();
        this.endsWithNewline = true;
    }
    /// <summary>
    /// Write formatted text directly into the generated output.
    /// The text is formatted with the current culture.
    /// </summary>
    /// <param name="format">Composite format string.</param>
    /// <param name="args">Values to format.</param>
    public void Write(string format, params object[] args)
    {
        this.Write(string.Format(global::System.Globalization.CultureInfo.CurrentCulture, format, args));
    }
    /// <summary>
    /// Write formatted text directly into the generated output, followed by a line terminator.
    /// The text is formatted with the current culture.
    /// </summary>
    /// <param name="format">Composite format string.</param>
    /// <param name="args">Values to format.</param>
    public void WriteLine(string format, params object[] args)
    {
        this.WriteLine(string.Format(global::System.Globalization.CultureInfo.CurrentCulture, format, args));
    }
    /// <summary>
    /// Raise an error by adding it to <see cref="Errors"/>.
    /// </summary>
    /// <param name="message">Error text.</param>
    public void Error(string message)
    {
        System.CodeDom.Compiler.CompilerError error = new global::System.CodeDom.Compiler.CompilerError();
        error.ErrorText = message;
        this.Errors.Add(error);
    }
    /// <summary>
    /// Raise a warning by adding it to <see cref="Errors"/> with IsWarning set.
    /// </summary>
    /// <param name="message">Warning text.</param>
    public void Warning(string message)
    {
        System.CodeDom.Compiler.CompilerError error = new global::System.CodeDom.Compiler.CompilerError();
        error.ErrorText = message;
        error.IsWarning = true;
        this.Errors.Add(error);
    }
    /// <summary>
    /// Increase the indent.
    /// </summary>
    /// <param name="indent">Text appended to the current indent.</param>
    /// <exception cref="System.ArgumentNullException"><paramref name="indent"/> is null.</exception>
    public void PushIndent(string indent)
    {
        if ((indent == null))
        {
            throw new global::System.ArgumentNullException(nameof(indent));
        }
        this.currentIndentField = (this.currentIndentField + indent);
        // Remember the length so PopIndent can strip exactly this segment again.
        this.indentLengths.Add(indent.Length);
    }
    /// <summary>
    /// Remove the last indent that was added with PushIndent.
    /// </summary>
    /// <returns>The removed indent text, or an empty string when nothing was removed.</returns>
    public string PopIndent()
    {
        string returnValue = "";
        if ((this.indentLengths.Count > 0))
        {
            int indentLength = this.indentLengths[(this.indentLengths.Count - 1)];
            this.indentLengths.RemoveAt((this.indentLengths.Count - 1));
            if ((indentLength > 0))
            {
                returnValue = this.currentIndentField.Substring((this.currentIndentField.Length - indentLength));
                this.currentIndentField = this.currentIndentField.Remove((this.currentIndentField.Length - indentLength));
            }
        }
        return returnValue;
    }
    /// <summary>
    /// Remove any indentation.
    /// </summary>
    public void ClearIndent()
    {
        this.indentLengths.Clear();
        this.currentIndentField = "";
    }
    #endregion
    #region ToString Helpers
    /// <summary>
    /// Utility class to produce culture-oriented representation of an object as a string.
    /// </summary>
    public class ToStringInstanceHelper
    {
        private System.IFormatProvider formatProviderField  = global::System.Globalization.CultureInfo.InvariantCulture;
        /// <summary>
        /// Gets or sets format provider to be used by ToStringWithCulture method.
        /// Assigning null is ignored and keeps the current provider.
        /// </summary>
        public System.IFormatProvider FormatProvider
        {
            get
            {
                return this.formatProviderField ;
            }
            set
            {
                if ((value != null))
                {
                    this.formatProviderField  = value;
                }
            }
        }
        /// <summary>
        /// This is called from the compile/run appdomain to convert objects within an expression block to a string.
        /// Uses the object's ToString(IFormatProvider) overload when it has one, otherwise ToString().
        /// </summary>
        /// <param name="objectToConvert">Value to convert.</param>
        /// <returns>The string representation of <paramref name="objectToConvert"/>.</returns>
        /// <exception cref="System.ArgumentNullException"><paramref name="objectToConvert"/> is null.</exception>
        public string ToStringWithCulture(object objectToConvert)
        {
            if ((objectToConvert == null))
            {
                throw new global::System.ArgumentNullException(nameof(objectToConvert));
            }
            System.Type t = objectToConvert.GetType();
            System.Reflection.MethodInfo method = t.GetMethod("ToString", new System.Type[] {
                        typeof(System.IFormatProvider)});
            if ((method == null))
            {
                return objectToConvert.ToString();
            }
            else
            {
                return ((string)(method.Invoke(objectToConvert, new object[] {
                            this.formatProviderField })));
            }
        }
    }
    private ToStringInstanceHelper toStringHelperField = new ToStringInstanceHelper();
    /// <summary>
    /// Helper to produce culture-oriented representation of an object as a string.
    /// </summary>
    public ToStringInstanceHelper ToStringHelper
    {
        get
        {
            return this.toStringHelperField;
        }
    }
    #endregion
}
#endregion
}

Просмотреть файл

@ -0,0 +1,60 @@
<#@ template language="C#" linePragmas="false" visibility = "internal"#>
<#@ assembly name="System.Core" #>
<#@ import namespace="System.Linq" #>
<#@ import namespace="System.Text" #>
<#@ import namespace="System.Collections.Generic" #>
using System.Collections.Generic;
using Newtonsoft.Json;
using SweepableEstimator = Microsoft.ML.AutoML.SweepableEstimator;
using Microsoft.ML.AutoML.CodeGen;
using ColorsOrder = Microsoft.ML.Transforms.Image.ImagePixelExtractingEstimator.ColorsOrder;
using ColorBits = Microsoft.ML.Transforms.Image.ImagePixelExtractingEstimator.ColorBits;
using ResizingKind = Microsoft.ML.Transforms.Image.ImageResizingEstimator.ResizingKind;
using Anchor = Microsoft.ML.Transforms.Image.ImageResizingEstimator.Anchor;
using Microsoft.ML.SearchSpace;
namespace <#=NameSpace#>
{
internal partial class <#=ClassName#> : SweepableEstimator<<#=TOption#>>
{
public <#=ClassName#>(<#=TOption#> defaultOption, SearchSpace<<#=TOption#>> searchSpace = null)
{
this.TParameter = defaultOption;
this.SearchSpace = searchSpace;
this.EstimatorType = EstimatorType.<#=ClassName#>;
}
internal <#=ClassName#>()
{
this.EstimatorType = EstimatorType.<#=ClassName#>;
this.TParameter = new <#=TOption#>();
}
internal override IEnumerable<string> CSharpUsingStatements
{
get => new string[] {<#=Utils.PrettyPrintListOfString(UsingStatements.Select(x => $"using {x};"))#>};
}
internal override IEnumerable<string> NugetDependencies
{
get => new string[] {<#=Utils.PrettyPrintListOfString(NugetDependencies)#>};
}
internal override string FunctionName
{
get => "<#=Utils.GetPrefix(Type)#>.<#=FunctionName#>";
}
}
}
<#+
// Inputs of this template. It emits one partial class named ClassName that derives from
// SweepableEstimator<TOption>, with two constructors and overrides of
// CSharpUsingStatements, NugetDependencies and FunctionName.

// Namespace the generated class is declared in.
public string NameSpace {get;set;}
// Name of the generated class; also used as the EstimatorType member name.
public string ClassName {get;set;}
// Function name written after the prefix in the generated FunctionName property.
public string FunctionName {get;set;}
// Estimator type name passed to Utils.GetPrefix to obtain that prefix.
public string Type {get;set;}
// Argument descriptions; not referenced anywhere in this template's text.
public IEnumerable<Argument> ArgumentsList {get;set;}
// Namespaces emitted, each as "using x;", in the generated CSharpUsingStatements array.
public IEnumerable<string> UsingStatements {get; set;}
// Entries emitted in the generated NugetDependencies array.
public IEnumerable<string> NugetDependencies {get; set;}
// Option type name used as the SweepableEstimator type argument and constructor parameter type.
public string TOption {get; set;}
#>

Просмотреть файл

@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
@ -70,5 +71,59 @@ namespace Microsoft.ML.AutoML.SourceGenerator
{
return string.Join(string.Empty, str.Split('_', ' ', '-').Select(x => CapitalFirstLetter(x)));
}
/// <summary>
/// Maps an estimator type name to the prefix that the estimator template writes in
/// front of the function name ("&lt;prefix&gt;.&lt;FunctionName&gt;").
/// </summary>
/// <param name="estimatorType">
/// Estimator type name, e.g. "BinaryClassification" or "Transforms". Matching is exact
/// and case-sensitive.
/// </param>
/// <returns>The prefix for <paramref name="estimatorType"/>, e.g. "BinaryClassification.Trainers".</returns>
/// <exception cref="NotImplementedException">
/// <paramref name="estimatorType"/> is null or not one of the supported names.
/// </exception>
public static string GetPrefix(string estimatorType)
{
    switch (estimatorType)
    {
        // "OneVersusAll" deliberately shares the binary-classification prefix.
        case "BinaryClassification":
        case "OneVersusAll":
            return "BinaryClassification.Trainers";
        case "MultiClassification":
            return "MulticlassClassification.Trainers";
        case "Regression":
            return "Regression.Trainers";
        case "Ranking":
            return "Ranking.Trainers";
        case "Recommendation":
            // Unlike the other prefixes this one contains a call ("Recommendation()").
            return "Recommendation().Trainers";
        case "Transforms":
            return "Transforms";
        case "Categorical":
            return "Transforms.Categorical";
        case "Conversion":
            return "Transforms.Conversion";
        case "Text":
            return "Transforms.Text";
        case "Calibrators":
            return "BinaryClassification.Calibrators";
        case "Forecasting":
            return "Forecasting";
        default:
            // Same exception type as before, but name the offending value so an
            // unsupported estimator type can be diagnosed from the error alone.
            throw new NotImplementedException("No prefix is defined for estimator type '" + estimatorType + "'.");
    }
}
}
}