Generate built-in SweepableEstimator classes for all available estimators (#6125)
This commit is contained in:
Родитель
bfba5d9836
Коммит
a758217121
|
@ -3,6 +3,10 @@
|
|||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
using System;
|
||||
using System.Collections;
|
||||
using System.Collections.Generic;
|
||||
using System.Diagnostics.Contracts;
|
||||
using Microsoft.ML.AutoML.CodeGen;
|
||||
using Microsoft.ML.Data;
|
||||
using Microsoft.ML.SearchSpace;
|
||||
|
||||
|
@ -289,5 +293,170 @@ namespace Microsoft.ML.AutoML
|
|||
{
|
||||
return new SweepableEstimator((MLContext context, Parameter param) => factory(context, param.AsType<T>()), ss);
|
||||
}
|
||||
|
||||
internal SweepableEstimator[] BinaryClassification(string labelColumnName = DefaultColumnNames.Label, string featureColumnName = DefaultColumnNames.Features, string exampleWeightColumnName = null, bool useFastForest = true, bool useLgbm = true, bool useFastTree = true, bool useLbfgs = true, bool useSdca = true,
|
||||
FastTreeOption fastTreeOption = null, LgbmOption lgbmOption = null, FastForestOption fastForestOption = null, LbfgsOption lbfgsOption = null, SdcaOption sdcaOption = null,
|
||||
SearchSpace<FastTreeOption> fastTreeSearchSpace = null, SearchSpace<LgbmOption> lgbmSearchSpace = null, SearchSpace<FastForestOption> fastForestSearchSpace = null, SearchSpace<LbfgsOption> lbfgsSearchSpace = null, SearchSpace<SdcaOption> sdcaSearchSpace = null)
|
||||
{
|
||||
var res = new List<SweepableEstimator>();
|
||||
|
||||
if (useFastTree)
|
||||
{
|
||||
fastTreeOption = fastTreeOption ?? new FastTreeOption();
|
||||
fastTreeOption.LabelColumnName = labelColumnName;
|
||||
fastTreeOption.FeatureColumnName = featureColumnName;
|
||||
fastTreeOption.ExampleWeightColumnName = exampleWeightColumnName;
|
||||
res.Add(SweepableEstimatorFactory.CreateFastTreeBinary(fastTreeOption, fastTreeSearchSpace ?? new SearchSpace<FastTreeOption>()));
|
||||
}
|
||||
|
||||
if (useFastForest)
|
||||
{
|
||||
fastForestOption = fastForestOption ?? new FastForestOption();
|
||||
fastForestOption.LabelColumnName = labelColumnName;
|
||||
fastForestOption.FeatureColumnName = featureColumnName;
|
||||
fastForestOption.ExampleWeightColumnName = exampleWeightColumnName;
|
||||
res.Add(SweepableEstimatorFactory.CreateFastForestBinary(fastForestOption, fastForestSearchSpace ?? new SearchSpace<FastForestOption>()));
|
||||
}
|
||||
|
||||
if (useLgbm)
|
||||
{
|
||||
lgbmOption = lgbmOption ?? new LgbmOption();
|
||||
lgbmOption.LabelColumnName = labelColumnName;
|
||||
lgbmOption.FeatureColumnName = featureColumnName;
|
||||
lgbmOption.ExampleWeightColumnName = exampleWeightColumnName;
|
||||
res.Add(SweepableEstimatorFactory.CreateLightGbmBinary(lgbmOption, lgbmSearchSpace ?? new SearchSpace<LgbmOption>()));
|
||||
}
|
||||
|
||||
if (useLbfgs)
|
||||
{
|
||||
lbfgsOption = lbfgsOption ?? new LbfgsOption();
|
||||
lbfgsOption.LabelColumnName = labelColumnName;
|
||||
lbfgsOption.FeatureColumnName = featureColumnName;
|
||||
lbfgsOption.ExampleWeightColumnName = exampleWeightColumnName;
|
||||
res.Add(SweepableEstimatorFactory.CreateLbfgsLogisticRegressionBinary(lbfgsOption, lbfgsSearchSpace ?? new SearchSpace<LbfgsOption>()));
|
||||
}
|
||||
|
||||
if (useSdca)
|
||||
{
|
||||
sdcaOption = sdcaOption ?? new SdcaOption();
|
||||
sdcaOption.LabelColumnName = labelColumnName;
|
||||
sdcaOption.FeatureColumnName = featureColumnName;
|
||||
sdcaOption.ExampleWeightColumnName = exampleWeightColumnName;
|
||||
res.Add(SweepableEstimatorFactory.CreateSdcaLogisticRegressionBinary(sdcaOption, sdcaSearchSpace ?? new SearchSpace<SdcaOption>()));
|
||||
}
|
||||
|
||||
return res.ToArray();
|
||||
}
|
||||
|
||||
internal SweepableEstimator[] MultiClassification(string labelColumnName = DefaultColumnNames.Label, string featureColumnName = DefaultColumnNames.Features, string exampleWeightColumnName = null, bool useFastForest = true, bool useLgbm = true, bool useFastTree = true, bool useLbfgs = true, bool useSdca = true,
|
||||
FastTreeOption fastTreeOption = null, LgbmOption lgbmOption = null, FastForestOption fastForestOption = null, LbfgsOption lbfgsOption = null, SdcaOption sdcaOption = null,
|
||||
SearchSpace<FastTreeOption> fastTreeSearchSpace = null, SearchSpace<LgbmOption> lgbmSearchSpace = null, SearchSpace<FastForestOption> fastForestSearchSpace = null, SearchSpace<LbfgsOption> lbfgsSearchSpace = null, SearchSpace<SdcaOption> sdcaSearchSpace = null)
|
||||
{
|
||||
var res = new List<SweepableEstimator>();
|
||||
|
||||
if (useFastTree)
|
||||
{
|
||||
fastTreeOption = fastTreeOption ?? new FastTreeOption();
|
||||
fastTreeOption.LabelColumnName = labelColumnName;
|
||||
fastTreeOption.FeatureColumnName = featureColumnName;
|
||||
fastTreeOption.ExampleWeightColumnName = exampleWeightColumnName;
|
||||
res.Add(SweepableEstimatorFactory.CreateFastTreeOva(fastTreeOption, fastTreeSearchSpace ?? new SearchSpace<FastTreeOption>()));
|
||||
}
|
||||
|
||||
if (useFastForest)
|
||||
{
|
||||
fastForestOption = fastForestOption ?? new FastForestOption();
|
||||
fastForestOption.LabelColumnName = labelColumnName;
|
||||
fastForestOption.FeatureColumnName = featureColumnName;
|
||||
fastForestOption.ExampleWeightColumnName = exampleWeightColumnName;
|
||||
res.Add(SweepableEstimatorFactory.CreateFastForestOva(fastForestOption, fastForestSearchSpace ?? new SearchSpace<FastForestOption>()));
|
||||
}
|
||||
|
||||
if (useLgbm)
|
||||
{
|
||||
lgbmOption = lgbmOption ?? new LgbmOption();
|
||||
lgbmOption.LabelColumnName = labelColumnName;
|
||||
lgbmOption.FeatureColumnName = featureColumnName;
|
||||
lgbmOption.ExampleWeightColumnName = exampleWeightColumnName;
|
||||
res.Add(SweepableEstimatorFactory.CreateLightGbmMulti(lgbmOption, lgbmSearchSpace ?? new SearchSpace<LgbmOption>()));
|
||||
}
|
||||
|
||||
if (useLbfgs)
|
||||
{
|
||||
lbfgsOption = lbfgsOption ?? new LbfgsOption();
|
||||
lbfgsOption.LabelColumnName = labelColumnName;
|
||||
lbfgsOption.FeatureColumnName = featureColumnName;
|
||||
lbfgsOption.ExampleWeightColumnName = exampleWeightColumnName;
|
||||
res.Add(SweepableEstimatorFactory.CreateLbfgsLogisticRegressionOva(lbfgsOption, lbfgsSearchSpace ?? new SearchSpace<LbfgsOption>()));
|
||||
res.Add(SweepableEstimatorFactory.CreateLbfgsMaximumEntropyMulti(lbfgsOption, lbfgsSearchSpace ?? new SearchSpace<LbfgsOption>()));
|
||||
}
|
||||
|
||||
if (useSdca)
|
||||
{
|
||||
sdcaOption = sdcaOption ?? new SdcaOption();
|
||||
sdcaOption.LabelColumnName = labelColumnName;
|
||||
sdcaOption.FeatureColumnName = featureColumnName;
|
||||
sdcaOption.ExampleWeightColumnName = exampleWeightColumnName;
|
||||
res.Add(SweepableEstimatorFactory.CreateSdcaMaximumEntropyMulti(sdcaOption, sdcaSearchSpace ?? new SearchSpace<SdcaOption>()));
|
||||
res.Add(SweepableEstimatorFactory.CreateSdcaLogisticRegressionOva(sdcaOption, sdcaSearchSpace ?? new SearchSpace<SdcaOption>()));
|
||||
}
|
||||
|
||||
return res.ToArray();
|
||||
}
|
||||
|
||||
internal SweepableEstimator[] Regression(string labelColumnName = DefaultColumnNames.Label, string featureColumnName = DefaultColumnNames.Features, string exampleWeightColumnName = null, bool useFastForest = true, bool useLgbm = true, bool useFastTree = true, bool useLbfgs = true, bool useSdca = true,
|
||||
FastTreeOption fastTreeOption = null, LgbmOption lgbmOption = null, FastForestOption fastForestOption = null, LbfgsOption lbfgsOption = null, SdcaOption sdcaOption = null,
|
||||
SearchSpace<FastTreeOption> fastTreeSearchSpace = null, SearchSpace<LgbmOption> lgbmSearchSpace = null, SearchSpace<FastForestOption> fastForestSearchSpace = null, SearchSpace<LbfgsOption> lbfgsSearchSpace = null, SearchSpace<SdcaOption> sdcaSearchSpace = null)
|
||||
{
|
||||
var res = new List<SweepableEstimator>();
|
||||
|
||||
if (useFastTree)
|
||||
{
|
||||
fastTreeOption = fastTreeOption ?? new FastTreeOption();
|
||||
fastTreeOption.LabelColumnName = labelColumnName;
|
||||
fastTreeOption.FeatureColumnName = featureColumnName;
|
||||
fastTreeOption.ExampleWeightColumnName = exampleWeightColumnName;
|
||||
res.Add(SweepableEstimatorFactory.CreateFastTreeRegression(fastTreeOption, fastTreeSearchSpace ?? new SearchSpace<FastTreeOption>()));
|
||||
res.Add(SweepableEstimatorFactory.CreateFastTreeTweedieRegression(fastTreeOption, fastTreeSearchSpace ?? new SearchSpace<FastTreeOption>()));
|
||||
}
|
||||
|
||||
if (useFastForest)
|
||||
{
|
||||
fastForestOption = fastForestOption ?? new FastForestOption();
|
||||
fastForestOption.LabelColumnName = labelColumnName;
|
||||
fastForestOption.FeatureColumnName = featureColumnName;
|
||||
fastForestOption.ExampleWeightColumnName = exampleWeightColumnName;
|
||||
res.Add(SweepableEstimatorFactory.CreateFastForestRegression(fastForestOption, fastForestSearchSpace ?? new SearchSpace<FastForestOption>()));
|
||||
}
|
||||
|
||||
if (useLgbm)
|
||||
{
|
||||
lgbmOption = lgbmOption ?? new LgbmOption();
|
||||
lgbmOption.LabelColumnName = labelColumnName;
|
||||
lgbmOption.FeatureColumnName = featureColumnName;
|
||||
lgbmOption.ExampleWeightColumnName = exampleWeightColumnName;
|
||||
res.Add(SweepableEstimatorFactory.CreateLightGbmRegression(lgbmOption, lgbmSearchSpace ?? new SearchSpace<LgbmOption>()));
|
||||
}
|
||||
|
||||
if (useLbfgs)
|
||||
{
|
||||
lbfgsOption = lbfgsOption ?? new LbfgsOption();
|
||||
lbfgsOption.LabelColumnName = labelColumnName;
|
||||
lbfgsOption.FeatureColumnName = featureColumnName;
|
||||
lbfgsOption.ExampleWeightColumnName = exampleWeightColumnName;
|
||||
res.Add(SweepableEstimatorFactory.CreateLbfgsPoissonRegressionRegression(lbfgsOption, lbfgsSearchSpace ?? new SearchSpace<LbfgsOption>()));
|
||||
}
|
||||
|
||||
if (useSdca)
|
||||
{
|
||||
sdcaOption = sdcaOption ?? new SdcaOption();
|
||||
sdcaOption.LabelColumnName = labelColumnName;
|
||||
sdcaOption.FeatureColumnName = featureColumnName;
|
||||
sdcaOption.ExampleWeightColumnName = exampleWeightColumnName;
|
||||
res.Add(SweepableEstimatorFactory.CreateSdcaRegression(sdcaOption, sdcaSearchSpace ?? new SearchSpace<SdcaOption>()));
|
||||
}
|
||||
|
||||
return res.ToArray();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
|
||||
namespace Microsoft.ML.AutoML
|
||||
|
@ -27,7 +28,34 @@ namespace Microsoft.ML.AutoML
|
|||
|
||||
public static SweepableEstimatorPipeline Append(this SweepableEstimator estimator, IEstimator<ITransformer> estimator1)
|
||||
{
|
||||
return estimator.Append(estimator1);
|
||||
return new SweepableEstimatorPipeline().Append(estimator).Append(estimator1);
|
||||
}
|
||||
|
||||
public static MultiModelPipeline Append(this IEstimator<ITransformer> estimator, params SweepableEstimator[] estimators)
|
||||
{
|
||||
var sweepableEstimator = new SweepableEstimator((context, parameter) => estimator, new SearchSpace.SearchSpace());
|
||||
var multiModelPipeline = new MultiModelPipeline().Append(sweepableEstimator).Append(estimators);
|
||||
|
||||
return multiModelPipeline;
|
||||
}
|
||||
|
||||
public static MultiModelPipeline Append(this SweepableEstimatorPipeline pipeline, params SweepableEstimator[] estimators)
|
||||
{
|
||||
var multiModelPipeline = new MultiModelPipeline();
|
||||
foreach (var estimator in pipeline.Estimators)
|
||||
{
|
||||
multiModelPipeline = multiModelPipeline.Append(estimator);
|
||||
}
|
||||
|
||||
return multiModelPipeline.Append(estimators);
|
||||
}
|
||||
|
||||
public static MultiModelPipeline Append(this SweepableEstimator estimator, params SweepableEstimator[] estimators)
|
||||
{
|
||||
var multiModelPipeline = new MultiModelPipeline();
|
||||
multiModelPipeline = multiModelPipeline.Append(estimator);
|
||||
|
||||
return multiModelPipeline.Append(estimators);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,12 +3,46 @@
|
|||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Threading;
|
||||
|
||||
namespace Microsoft.ML.AutoML
|
||||
{
|
||||
internal static class AutoMlUtils
|
||||
{
|
||||
private const string MLNetMaxThread = "MLNET_MAX_THREAD";
|
||||
|
||||
public static readonly ThreadLocal<Random> Random = new ThreadLocal<Random>(() => new Random());
|
||||
|
||||
/// <summary>
|
||||
/// Return number of thread if MLNET_MAX_THREAD is set, otherwise return null.
|
||||
/// </summary>
|
||||
public static int? GetNumberOfThreadFromEnvrionment()
|
||||
{
|
||||
var res = Environment.GetEnvironmentVariable(MLNetMaxThread);
|
||||
|
||||
if (int.TryParse(res, out var numberOfThread))
|
||||
{
|
||||
return numberOfThread;
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
public static InputOutputColumnPair[] CreateInputOutputColumnPairsFromStrings(string[] inputs, string[] outputs)
|
||||
{
|
||||
if (inputs.Length != outputs.Length)
|
||||
{
|
||||
throw new Exception("inputs and outputs count must match");
|
||||
}
|
||||
|
||||
var res = new List<InputOutputColumnPair>();
|
||||
for (int i = 0; i != inputs.Length; ++i)
|
||||
{
|
||||
res.Add(new InputOutputColumnPair(outputs[i], inputs[i]));
|
||||
}
|
||||
|
||||
return res.ToArray();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
{
|
||||
"EstimatorFactoryGenerator": false,
|
||||
"CodeGenCatalogGenerator": false,
|
||||
"SweepableEstimatorFactory": true,
|
||||
"EstimatorTypeGenerator": true,
|
||||
"SearchSpaceGenerator": true,
|
||||
"SweepableEstimatorGenerator": false
|
||||
"SweepableEstimatorGenerator": true
|
||||
}
|
||||
|
|
|
@ -14,11 +14,13 @@
|
|||
<PrivateAssets>all</PrivateAssets>
|
||||
</ProjectReference>
|
||||
<ProjectReference Include="..\Microsoft.ML.CpuMath\Microsoft.ML.CpuMath.csproj" />
|
||||
<ProjectReference Include="..\Microsoft.ML.OnnxTransformer\Microsoft.ML.OnnxTransformer.csproj" />
|
||||
<ProjectReference Include="..\Microsoft.ML.SearchSpace\Microsoft.ML.SearchSpace.csproj">
|
||||
<PrivateAssets>all</PrivateAssets>
|
||||
<IncludeInNuget>true</IncludeInNuget>
|
||||
</ProjectReference>
|
||||
<PackageReference Include="Microsoft.CodeAnalysis.CSharp" Version="$(MicrosoftCodeAnalysisCSharpVersion)" />
|
||||
<ProjectReference Include="..\Microsoft.ML.TimeSeries\Microsoft.ML.TimeSeries.csproj" />
|
||||
<ProjectReference Include="..\Microsoft.ML.Vision\Microsoft.ML.Vision.csproj" />
|
||||
<ProjectReference Include="..\Microsoft.ML.ImageAnalytics\Microsoft.ML.ImageAnalytics.csproj" />
|
||||
<ProjectReference Include="..\Microsoft.ML.LightGbm\Microsoft.ML.LightGbm.csproj" />
|
||||
|
@ -43,13 +45,13 @@
|
|||
<Target DependsOnTargets="ResolveReferences" Name="CopyProjectReferencesToPackage">
|
||||
<ItemGroup>
|
||||
<!--Include DLLs of Project References-->
|
||||
<BuildOutputInPackage Include="@(ReferenceCopyLocalPaths->WithMetadataValue('ReferenceSourceTarget', 'ProjectReference')->WithMetadataValue('IncludeInNuget','true'))"/>
|
||||
<BuildOutputInPackage Include="@(ReferenceCopyLocalPaths->WithMetadataValue('ReferenceSourceTarget', 'ProjectReference')->WithMetadataValue('IncludeInNuget','true'))" />
|
||||
<!--Include PDBs of Project References-->
|
||||
<BuildOutputInPackage Include="@(ReferenceCopyLocalPaths->WithMetadataValue('ReferenceSourceTarget', 'ProjectReference')->WithMetadataValue('IncludeInNuget','true')->Replace('.dll', '.pdb'))"/>
|
||||
<BuildOutputInPackage Include="@(ReferenceCopyLocalPaths->WithMetadataValue('ReferenceSourceTarget', 'ProjectReference')->WithMetadataValue('IncludeInNuget','true')->Replace('.dll', '.pdb'))" />
|
||||
<!--Include PDBs for Native binaries-->
|
||||
<!--The path needed to be hardcoded for this to work on our publishing CI-->
|
||||
<BuildOutputInPackage Condition="Exists('$(PackageAssetsPath)$(PackageIdFolderName)\runtimes\win-x86\native\LdaNative.pdb')" Include="$(PackageAssetsPath)$(PackageIdFolderName)\runtimes\win-x86\native\LdaNative.pdb" TargetPath="..\..\runtimes\win-x86\native"/>
|
||||
<BuildOutputInPackage Condition="Exists('$(PackageAssetsPath)$(PackageIdFolderName)\runtimes\win-x64\native\LdaNative.pdb')" Include="$(PackageAssetsPath)$(PackageIdFolderName)\runtimes\win-x64\native\LdaNative.pdb" TargetPath="..\..\runtimes\win-x64\native"/>
|
||||
<BuildOutputInPackage Condition="Exists('$(PackageAssetsPath)$(PackageIdFolderName)\runtimes\win-x86\native\LdaNative.pdb')" Include="$(PackageAssetsPath)$(PackageIdFolderName)\runtimes\win-x86\native\LdaNative.pdb" TargetPath="..\..\runtimes\win-x86\native" />
|
||||
<BuildOutputInPackage Condition="Exists('$(PackageAssetsPath)$(PackageIdFolderName)\runtimes\win-x64\native\LdaNative.pdb')" Include="$(PackageAssetsPath)$(PackageIdFolderName)\runtimes\win-x64\native\LdaNative.pdb" TargetPath="..\..\runtimes\win-x64\native" />
|
||||
</ItemGroup>
|
||||
</Target>
|
||||
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Text;
|
||||
using System.Text.Json;
|
||||
using System.Text.Json.Nodes;
|
||||
using System.Text.Json.Serialization;
|
||||
|
||||
namespace Microsoft.ML.AutoML
|
||||
{
|
||||
internal class MultiModelPipelineConverter : JsonConverter<MultiModelPipeline>
|
||||
{
|
||||
public override MultiModelPipeline Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
|
||||
{
|
||||
var jValue = JsonValue.Parse(ref reader);
|
||||
var schema = jValue["schema"].GetValue<string>();
|
||||
var estimators = jValue["estimator"].GetValue<Dictionary<string, SweepableEstimator>>();
|
||||
|
||||
return new MultiModelPipeline(estimators, Entity.FromExpression(schema));
|
||||
}
|
||||
|
||||
public override void Write(Utf8JsonWriter writer, MultiModelPipeline value, JsonSerializerOptions options)
|
||||
{
|
||||
var jsonObject = JsonNode.Parse("{}");
|
||||
jsonObject["schema"] = value.Schema.ToString();
|
||||
jsonObject["estimators"] = JsonValue.Create(value.Estimators);
|
||||
|
||||
jsonObject.WriteTo(writer, options);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,37 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Text;
|
||||
using System.Text.Json;
|
||||
using System.Text.Json.Nodes;
|
||||
using System.Text.Json.Serialization;
|
||||
using Microsoft.ML.AutoML.CodeGen;
|
||||
using Microsoft.ML.SearchSpace;
|
||||
|
||||
namespace Microsoft.ML.AutoML
|
||||
{
|
||||
internal class SweepableEstimatorConverter : JsonConverter<SweepableEstimator>
|
||||
{
|
||||
public override SweepableEstimator Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
|
||||
{
|
||||
var jsonObject = JsonValue.Parse(ref reader);
|
||||
var estimatorType = jsonObject["estimatorType"].GetValue<EstimatorType>();
|
||||
var parameter = jsonObject["parameter"].GetValue<Parameter>();
|
||||
var estimator = new SweepableEstimator(estimatorType);
|
||||
estimator.Parameter = parameter;
|
||||
|
||||
return estimator;
|
||||
}
|
||||
|
||||
public override void Write(Utf8JsonWriter writer, SweepableEstimator value, JsonSerializerOptions options)
|
||||
{
|
||||
var jObject = JsonObject.Parse("{}");
|
||||
jObject["estimatorType"] = JsonValue.Create(value.EstimatorType);
|
||||
jObject["parameter"] = JsonValue.Create(value.Parameter);
|
||||
jObject.WriteTo(writer, options);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,39 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Text;
|
||||
using System.Text.Json;
|
||||
using System.Text.Json.Nodes;
|
||||
using System.Text.Json.Serialization;
|
||||
using Microsoft.ML.SearchSpace;
|
||||
using Newtonsoft.Json;
|
||||
|
||||
namespace Microsoft.ML.AutoML
|
||||
{
|
||||
internal class SweepableEstimatorPipelineConverter : JsonConverter<SweepableEstimatorPipeline>
|
||||
{
|
||||
public override SweepableEstimatorPipeline Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
|
||||
{
|
||||
var jNode = JsonNode.Parse(ref reader);
|
||||
var parameter = jNode["parameter"].GetValue<Parameter>();
|
||||
var estimators = jNode["estimators"].GetValue<SweepableEstimator[]>();
|
||||
var pipeline = new SweepableEstimatorPipeline(estimators, parameter);
|
||||
|
||||
return pipeline;
|
||||
}
|
||||
|
||||
public override void Write(Utf8JsonWriter writer, SweepableEstimatorPipeline value, JsonSerializerOptions options)
|
||||
{
|
||||
var parameter = value.Parameter;
|
||||
var estimators = value.Estimators;
|
||||
var jNode = JsonNode.Parse("{}");
|
||||
jNode["parameter"] = JsonValue.Create(parameter);
|
||||
jNode["estimators"] = JsonValue.Create(estimators);
|
||||
|
||||
jNode.WriteTo(writer, options);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -12,6 +12,7 @@ namespace Microsoft.ML.AutoML
|
|||
protected Estimator()
|
||||
{
|
||||
this.Parameter = Parameter.CreateNestedParameter();
|
||||
this.EstimatorType = EstimatorType.Unknown;
|
||||
}
|
||||
|
||||
internal Estimator(EstimatorType estimatorType)
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Text;
|
||||
|
||||
namespace Microsoft.ML.AutoML.CodeGen
|
||||
{
|
||||
internal partial class ApplyOnnxModel
|
||||
{
|
||||
public override IEstimator<ITransformer> BuildFromOption(MLContext context, ApplyOnnxModelOption param)
|
||||
{
|
||||
return context.Transforms.ApplyOnnxModel(outputColumnName: param.OutputColumnName, inputColumnName: param.InputColumnName, modelFile: param.ModelFile, gpuDeviceId: param.GpuDeviceId, fallbackToCpu: param.FallbackToCpu);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,18 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Text;
|
||||
|
||||
namespace Microsoft.ML.AutoML.CodeGen
|
||||
{
|
||||
internal partial class Naive
|
||||
{
|
||||
public override IEstimator<ITransformer> BuildFromOption(MLContext context, NaiveOption param)
|
||||
{
|
||||
return context.BinaryClassification.Calibrators.Naive(param.LabelColumnName, param.ScoreColumnName);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,14 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
namespace Microsoft.ML.AutoML.CodeGen
|
||||
{
|
||||
internal partial class Concatenate
|
||||
{
|
||||
public override IEstimator<ITransformer> BuildFromOption(MLContext context, ConcatOption param)
|
||||
{
|
||||
return context.Transforms.Concatenate(param.OutputColumnName, param.InputColumnNames);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,63 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
using Microsoft.ML.Trainers.FastTree;
|
||||
|
||||
namespace Microsoft.ML.AutoML.CodeGen
|
||||
{
|
||||
internal partial class FastForestOva
|
||||
{
|
||||
public override IEstimator<ITransformer> BuildFromOption(MLContext context, FastForestOption param)
|
||||
{
|
||||
var option = new FastForestBinaryTrainer.Options()
|
||||
{
|
||||
NumberOfTrees = param.NumberOfTrees,
|
||||
LabelColumnName = param.LabelColumnName,
|
||||
FeatureColumnName = param.FeatureColumnName,
|
||||
ExampleWeightColumnName = param.ExampleWeightColumnName,
|
||||
FeatureFraction = param.FeatureFraction,
|
||||
NumberOfThreads = AutoMlUtils.GetNumberOfThreadFromEnvrionment(),
|
||||
};
|
||||
|
||||
return context.MulticlassClassification.Trainers.OneVersusAll(context.BinaryClassification.Trainers.FastForest(option), labelColumnName: param.LabelColumnName);
|
||||
}
|
||||
}
|
||||
|
||||
internal partial class FastForestRegression
|
||||
{
|
||||
public override IEstimator<ITransformer> BuildFromOption(MLContext context, FastForestOption param)
|
||||
{
|
||||
var option = new FastForestRegressionTrainer.Options()
|
||||
{
|
||||
NumberOfTrees = param.NumberOfTrees,
|
||||
FeatureFraction = param.FeatureFraction,
|
||||
LabelColumnName = param.LabelColumnName,
|
||||
FeatureColumnName = param.FeatureColumnName,
|
||||
ExampleWeightColumnName = param.ExampleWeightColumnName,
|
||||
NumberOfThreads = AutoMlUtils.GetNumberOfThreadFromEnvrionment(),
|
||||
};
|
||||
|
||||
return context.Regression.Trainers.FastForest(option);
|
||||
}
|
||||
}
|
||||
|
||||
internal partial class FastForestBinary
|
||||
{
|
||||
public override IEstimator<ITransformer> BuildFromOption(MLContext context, FastForestOption param)
|
||||
{
|
||||
var option = new FastForestBinaryTrainer.Options()
|
||||
{
|
||||
NumberOfTrees = param.NumberOfTrees,
|
||||
NumberOfLeaves = param.NumberOfLeaves,
|
||||
FeatureFraction = param.FeatureFraction,
|
||||
LabelColumnName = param.LabelColumnName,
|
||||
FeatureColumnName = param.FeatureColumnName,
|
||||
ExampleWeightColumnName = param.ExampleWeightColumnName,
|
||||
NumberOfThreads = AutoMlUtils.GetNumberOfThreadFromEnvrionment(),
|
||||
};
|
||||
|
||||
return context.BinaryClassification.Trainers.FastForest(option);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,96 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
using Microsoft.ML.Trainers.FastTree;
|
||||
|
||||
namespace Microsoft.ML.AutoML.CodeGen
|
||||
{
|
||||
internal partial class FastTreeOva
|
||||
{
|
||||
public override IEstimator<ITransformer> BuildFromOption(MLContext context, FastTreeOption param)
|
||||
{
|
||||
var option = new FastTreeBinaryTrainer.Options()
|
||||
{
|
||||
NumberOfLeaves = param.NumberOfLeaves,
|
||||
NumberOfTrees = param.NumberOfTrees,
|
||||
MinimumExampleCountPerLeaf = param.MinimumExampleCountPerLeaf,
|
||||
LearningRate = param.LearningRate,
|
||||
LabelColumnName = param.LabelColumnName,
|
||||
FeatureColumnName = param.FeatureColumnName,
|
||||
ExampleWeightColumnName = param.ExampleWeightColumnName,
|
||||
NumberOfThreads = AutoMlUtils.GetNumberOfThreadFromEnvrionment(),
|
||||
MaximumBinCountPerFeature = param.MaximumBinCountPerFeature,
|
||||
FeatureFraction = param.FeatureFraction,
|
||||
};
|
||||
|
||||
return context.MulticlassClassification.Trainers.OneVersusAll(context.BinaryClassification.Trainers.FastTree(option), labelColumnName: param.LabelColumnName);
|
||||
}
|
||||
}
|
||||
|
||||
internal partial class FastTreeRegression
|
||||
{
|
||||
public override IEstimator<ITransformer> BuildFromOption(MLContext context, FastTreeOption param)
|
||||
{
|
||||
var option = new FastTreeRegressionTrainer.Options()
|
||||
{
|
||||
NumberOfLeaves = param.NumberOfLeaves,
|
||||
NumberOfTrees = param.NumberOfTrees,
|
||||
MinimumExampleCountPerLeaf = param.MinimumExampleCountPerLeaf,
|
||||
LearningRate = param.LearningRate,
|
||||
LabelColumnName = param.LabelColumnName,
|
||||
FeatureColumnName = param.FeatureColumnName,
|
||||
ExampleWeightColumnName = param.ExampleWeightColumnName,
|
||||
NumberOfThreads = AutoMlUtils.GetNumberOfThreadFromEnvrionment(),
|
||||
MaximumBinCountPerFeature = param.MaximumBinCountPerFeature,
|
||||
FeatureFraction = param.FeatureFraction,
|
||||
};
|
||||
|
||||
return context.Regression.Trainers.FastTree(option);
|
||||
}
|
||||
}
|
||||
|
||||
internal partial class FastTreeTweedieRegression
|
||||
{
|
||||
public override IEstimator<ITransformer> BuildFromOption(MLContext context, FastTreeOption param)
|
||||
{
|
||||
var option = new FastTreeTweedieTrainer.Options()
|
||||
{
|
||||
NumberOfLeaves = param.NumberOfLeaves,
|
||||
NumberOfTrees = param.NumberOfTrees,
|
||||
MinimumExampleCountPerLeaf = param.MinimumExampleCountPerLeaf,
|
||||
LearningRate = param.LearningRate,
|
||||
LabelColumnName = param.LabelColumnName,
|
||||
FeatureColumnName = param.FeatureColumnName,
|
||||
ExampleWeightColumnName = param.ExampleWeightColumnName,
|
||||
NumberOfThreads = AutoMlUtils.GetNumberOfThreadFromEnvrionment(),
|
||||
MaximumBinCountPerFeature = param.MaximumBinCountPerFeature,
|
||||
FeatureFraction = param.FeatureFraction,
|
||||
};
|
||||
|
||||
return context.Regression.Trainers.FastTreeTweedie(option);
|
||||
}
|
||||
}
|
||||
|
||||
internal partial class FastTreeBinary
|
||||
{
|
||||
public override IEstimator<ITransformer> BuildFromOption(MLContext context, FastTreeOption param)
|
||||
{
|
||||
var option = new FastTreeBinaryTrainer.Options()
|
||||
{
|
||||
NumberOfLeaves = param.NumberOfLeaves,
|
||||
NumberOfTrees = param.NumberOfTrees,
|
||||
MinimumExampleCountPerLeaf = param.MinimumExampleCountPerLeaf,
|
||||
LearningRate = param.LearningRate,
|
||||
LabelColumnName = param.LabelColumnName,
|
||||
FeatureColumnName = param.FeatureColumnName,
|
||||
ExampleWeightColumnName = param.ExampleWeightColumnName,
|
||||
NumberOfThreads = AutoMlUtils.GetNumberOfThreadFromEnvrionment(),
|
||||
MaximumBinCountPerFeature = param.MaximumBinCountPerFeature,
|
||||
FeatureFraction = param.FeatureFraction,
|
||||
};
|
||||
|
||||
return context.BinaryClassification.Trainers.FastTree(option);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,14 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
namespace Microsoft.ML.AutoML.CodeGen
|
||||
{
|
||||
internal partial class FeaturizeText
|
||||
{
|
||||
public override IEstimator<ITransformer> BuildFromOption(MLContext context, FeaturizeTextOption param)
|
||||
{
|
||||
return context.Transforms.Text.FeaturizeText(param.OutputColumnName, param.InputColumnName);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,21 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
using System;
|
||||
|
||||
namespace Microsoft.ML.AutoML.CodeGen
|
||||
{
|
||||
internal partial class ForecastBySsa
|
||||
{
|
||||
public override IEstimator<ITransformer> BuildFromOption(MLContext context, SsaOption param)
|
||||
{
|
||||
if (param.SeriesLength <= param.WindowSize || param.TrainSize <= 2 * param.WindowSize)
|
||||
{
|
||||
throw new Exception("ForecastBySsa param check error");
|
||||
}
|
||||
|
||||
return context.Forecasting.ForecastBySsa(param.OutputColumnName, param.InputColumnName, param.WindowSize, param.SeriesLength, param.TrainSize, param.Horizon, confidenceLowerBoundColumn: param.ConfidenceLowerBoundColumn, confidenceUpperBoundColumn: param.ConfidenceUpperBoundColumn);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,47 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
namespace Microsoft.ML.AutoML.CodeGen
|
||||
{
|
||||
internal partial class LoadImages
|
||||
{
|
||||
public override IEstimator<ITransformer> BuildFromOption(MLContext context, LoadImageOption param)
|
||||
{
|
||||
return context.Transforms.LoadImages(param.OutputColumnName, param.ImageFolder, param.InputColumnName);
|
||||
}
|
||||
}
|
||||
|
||||
internal partial class LoadRawImageBytes
|
||||
{
|
||||
public override IEstimator<ITransformer> BuildFromOption(MLContext context, LoadImageOption param)
|
||||
{
|
||||
return context.Transforms.LoadRawImageBytes(param.OutputColumnName, param.ImageFolder, param.InputColumnName);
|
||||
}
|
||||
}
|
||||
|
||||
internal partial class ResizeImages
|
||||
{
|
||||
public override IEstimator<ITransformer> BuildFromOption(MLContext context, ResizeImageOption param)
|
||||
{
|
||||
return context.Transforms.ResizeImages(param.OutputColumnName, param.ImageWidth, param.ImageHeight, param.InputColumnName, param.Resizing, param.CropAnchor);
|
||||
}
|
||||
}
|
||||
|
||||
internal partial class ExtractPixels
|
||||
{
|
||||
public override IEstimator<ITransformer> BuildFromOption(MLContext context, ExtractPixelsOption param)
|
||||
{
|
||||
return context.Transforms.ExtractPixels(param.OutputColumnName, param.InputColumnName, param.ColorsToExtract, param.OrderOfExtraction, outputAsFloatArray: param.OutputAsFloatArray);
|
||||
}
|
||||
}
|
||||
|
||||
internal partial class ImageClassificationMulti
|
||||
{
|
||||
public override IEstimator<ITransformer> BuildFromOption(MLContext context, ImageClassificationOption param)
|
||||
{
|
||||
|
||||
return context.MulticlassClassification.Trainers.ImageClassification(param.LabelColumnName, param.FeatureColumnName, param.ScoreColumnName);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,81 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
using Microsoft.ML.Trainers;
|
||||
|
||||
namespace Microsoft.ML.AutoML.CodeGen
|
||||
{
|
||||
internal partial class LbfgsMaximumEntropyMulti
|
||||
{
|
||||
public override IEstimator<ITransformer> BuildFromOption(MLContext context, LbfgsOption param)
|
||||
{
|
||||
var option = new LbfgsMaximumEntropyMulticlassTrainer.Options()
|
||||
{
|
||||
L1Regularization = param.L1Regularization,
|
||||
L2Regularization = param.L2Regularization,
|
||||
LabelColumnName = param.LabelColumnName,
|
||||
FeatureColumnName = param.FeatureColumnName,
|
||||
ExampleWeightColumnName = param.ExampleWeightColumnName,
|
||||
NumberOfThreads = AutoMlUtils.GetNumberOfThreadFromEnvrionment(),
|
||||
};
|
||||
|
||||
return context.MulticlassClassification.Trainers.LbfgsMaximumEntropy(option);
|
||||
}
|
||||
}
|
||||
|
||||
internal partial class LbfgsPoissonRegressionRegression
|
||||
{
|
||||
public override IEstimator<ITransformer> BuildFromOption(MLContext context, LbfgsOption param)
|
||||
{
|
||||
var option = new LbfgsPoissonRegressionTrainer.Options()
|
||||
{
|
||||
L1Regularization = param.L1Regularization,
|
||||
L2Regularization = param.L2Regularization,
|
||||
LabelColumnName = param.LabelColumnName,
|
||||
FeatureColumnName = param.FeatureColumnName,
|
||||
ExampleWeightColumnName = param.ExampleWeightColumnName,
|
||||
NumberOfThreads = AutoMlUtils.GetNumberOfThreadFromEnvrionment(),
|
||||
};
|
||||
|
||||
return context.Regression.Trainers.LbfgsPoissonRegression(option);
|
||||
}
|
||||
}
|
||||
|
||||
internal partial class LbfgsLogisticRegressionBinary
|
||||
{
|
||||
public override IEstimator<ITransformer> BuildFromOption(MLContext context, LbfgsOption param)
|
||||
{
|
||||
var option = new LbfgsLogisticRegressionBinaryTrainer.Options()
|
||||
{
|
||||
L1Regularization = param.L1Regularization,
|
||||
L2Regularization = param.L2Regularization,
|
||||
LabelColumnName = param.LabelColumnName,
|
||||
FeatureColumnName = param.FeatureColumnName,
|
||||
ExampleWeightColumnName = param.ExampleWeightColumnName,
|
||||
NumberOfThreads = AutoMlUtils.GetNumberOfThreadFromEnvrionment(),
|
||||
};
|
||||
|
||||
return context.BinaryClassification.Trainers.LbfgsLogisticRegression(option);
|
||||
}
|
||||
}
|
||||
|
||||
internal partial class LbfgsLogisticRegressionOva
|
||||
{
|
||||
public override IEstimator<ITransformer> BuildFromOption(MLContext context, LbfgsOption param)
|
||||
{
|
||||
var option = new LbfgsLogisticRegressionBinaryTrainer.Options()
|
||||
{
|
||||
L1Regularization = param.L1Regularization,
|
||||
L2Regularization = param.L2Regularization,
|
||||
LabelColumnName = param.LabelColumnName,
|
||||
FeatureColumnName = param.FeatureColumnName,
|
||||
ExampleWeightColumnName = param.ExampleWeightColumnName,
|
||||
NumberOfThreads = AutoMlUtils.GetNumberOfThreadFromEnvrionment(),
|
||||
};
|
||||
|
||||
var binaryTrainer = context.BinaryClassification.Trainers.LbfgsLogisticRegression(option);
|
||||
return context.MulticlassClassification.Trainers.OneVersusAll(binaryTrainer, param.LabelColumnName);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,92 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
using Microsoft.ML.Trainers.LightGbm;
|
||||
|
||||
namespace Microsoft.ML.AutoML.CodeGen
|
||||
{
|
||||
internal partial class LightGbmMulti
|
||||
{
|
||||
public override IEstimator<ITransformer> BuildFromOption(MLContext context, LgbmOption param)
|
||||
{
|
||||
var option = new LightGbmMulticlassTrainer.Options()
|
||||
{
|
||||
NumberOfLeaves = param.NumberOfLeaves,
|
||||
NumberOfIterations = param.NumberOfTrees,
|
||||
MinimumExampleCountPerLeaf = param.MinimumExampleCountPerLeaf,
|
||||
LearningRate = param.LearningRate,
|
||||
NumberOfThreads = AutoMlUtils.GetNumberOfThreadFromEnvrionment(),
|
||||
LabelColumnName = param.LabelColumnName,
|
||||
FeatureColumnName = param.FeatureColumnName,
|
||||
ExampleWeightColumnName = param.ExampleWeightColumnName,
|
||||
Booster = new GradientBooster.Options()
|
||||
{
|
||||
SubsampleFraction = param.SubsampleFraction,
|
||||
FeatureFraction = param.FeatureFraction,
|
||||
L1Regularization = param.L1Regularization,
|
||||
L2Regularization = param.L2Regularization,
|
||||
},
|
||||
MaximumBinCountPerFeature = param.MaximumBinCountPerFeature,
|
||||
};
|
||||
|
||||
return context.MulticlassClassification.Trainers.LightGbm(option);
|
||||
}
|
||||
}
|
||||
|
||||
internal partial class LightGbmBinary
|
||||
{
|
||||
public override IEstimator<ITransformer> BuildFromOption(MLContext context, LgbmOption param)
|
||||
{
|
||||
var option = new LightGbmBinaryTrainer.Options()
|
||||
{
|
||||
NumberOfLeaves = param.NumberOfLeaves,
|
||||
NumberOfIterations = param.NumberOfTrees,
|
||||
MinimumExampleCountPerLeaf = param.MinimumExampleCountPerLeaf,
|
||||
LearningRate = param.LearningRate,
|
||||
NumberOfThreads = AutoMlUtils.GetNumberOfThreadFromEnvrionment(),
|
||||
LabelColumnName = param.LabelColumnName,
|
||||
FeatureColumnName = param.FeatureColumnName,
|
||||
ExampleWeightColumnName = param.ExampleWeightColumnName,
|
||||
Booster = new GradientBooster.Options()
|
||||
{
|
||||
SubsampleFraction = param.SubsampleFraction,
|
||||
FeatureFraction = param.FeatureFraction,
|
||||
L1Regularization = param.L1Regularization,
|
||||
L2Regularization = param.L2Regularization,
|
||||
},
|
||||
MaximumBinCountPerFeature = param.MaximumBinCountPerFeature,
|
||||
};
|
||||
|
||||
return context.BinaryClassification.Trainers.LightGbm(option);
|
||||
}
|
||||
}
|
||||
|
||||
internal partial class LightGbmRegression
|
||||
{
|
||||
public override IEstimator<ITransformer> BuildFromOption(MLContext context, LgbmOption param)
|
||||
{
|
||||
var option = new LightGbmRegressionTrainer.Options()
|
||||
{
|
||||
NumberOfLeaves = param.NumberOfLeaves,
|
||||
NumberOfIterations = param.NumberOfTrees,
|
||||
MinimumExampleCountPerLeaf = param.MinimumExampleCountPerLeaf,
|
||||
LearningRate = param.LearningRate,
|
||||
NumberOfThreads = AutoMlUtils.GetNumberOfThreadFromEnvrionment(),
|
||||
LabelColumnName = param.LabelColumnName,
|
||||
FeatureColumnName = param.FeatureColumnName,
|
||||
ExampleWeightColumnName = param.ExampleWeightColumnName,
|
||||
Booster = new GradientBooster.Options()
|
||||
{
|
||||
SubsampleFraction = param.SubsampleFraction,
|
||||
FeatureFraction = param.FeatureFraction,
|
||||
L1Regularization = param.L1Regularization,
|
||||
L2Regularization = param.L2Regularization,
|
||||
},
|
||||
MaximumBinCountPerFeature = param.MaximumBinCountPerFeature,
|
||||
};
|
||||
|
||||
return context.Regression.Trainers.LightGbm(option);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,22 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
namespace Microsoft.ML.AutoML.CodeGen
|
||||
{
|
||||
internal partial class MapValueToKey
|
||||
{
|
||||
public override IEstimator<ITransformer> BuildFromOption(MLContext context, MapValueToKeyOption param)
|
||||
{
|
||||
return context.Transforms.Conversion.MapValueToKey(param.OutputColumnName, param.InputColumnName);
|
||||
}
|
||||
}
|
||||
|
||||
internal partial class MapKeyToValue
|
||||
{
|
||||
public override IEstimator<ITransformer> BuildFromOption(MLContext context, MapKeyToValueOption param)
|
||||
{
|
||||
return context.Transforms.Conversion.MapKeyToValue(param.OutputColumnName, param.InputColumnName);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,14 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
namespace Microsoft.ML.AutoML.CodeGen
|
||||
{
|
||||
internal partial class MatrixFactorization
|
||||
{
|
||||
public override IEstimator<ITransformer> BuildFromOption(MLContext context, MatrixFactorizationOption param)
|
||||
{
|
||||
return context.Recommendation().Trainers.MatrixFactorization(param.LabelColumnName, param.MatrixColumnIndexColumnName, param.MatrixRowIndexColumnName, param.ApproximationRank, param.LearningRate, param.NumberOfIterations);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,16 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
namespace Microsoft.ML.AutoML.CodeGen
|
||||
{
|
||||
internal partial class NormalizeMinMax
|
||||
{
|
||||
public override IEstimator<ITransformer> BuildFromOption(MLContext context, NormalizeMinMaxOption param)
|
||||
{
|
||||
var inputOutputPairs = AutoMlUtils.CreateInputOutputColumnPairsFromStrings(param.OutputColumnNames, param.InputColumnNames);
|
||||
|
||||
return context.Transforms.NormalizeMinMax(inputOutputPairs);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,24 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
namespace Microsoft.ML.AutoML.CodeGen
|
||||
{
|
||||
internal partial class OneHotEncoding
|
||||
{
|
||||
public override IEstimator<ITransformer> BuildFromOption(MLContext context, OneHotOption param)
|
||||
{
|
||||
var inputOutputPairs = AutoMlUtils.CreateInputOutputColumnPairsFromStrings(param.InputColumnNames, param.OutputColumnNames);
|
||||
return context.Transforms.Categorical.OneHotEncoding(inputOutputPairs);
|
||||
}
|
||||
}
|
||||
|
||||
internal partial class OneHotHashEncoding
|
||||
{
|
||||
public override IEstimator<ITransformer> BuildFromOption(MLContext context, OneHotOption param)
|
||||
{
|
||||
var inputOutputPairs = AutoMlUtils.CreateInputOutputColumnPairsFromStrings(param.InputColumnNames, param.OutputColumnNames);
|
||||
return context.Transforms.Categorical.OneHotHashEncoding(inputOutputPairs);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,15 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
namespace Microsoft.ML.AutoML.CodeGen
|
||||
{
|
||||
internal partial class ReplaceMissingValues
|
||||
{
|
||||
public override IEstimator<ITransformer> BuildFromOption(MLContext context, ReplaceMissingValueOption param)
|
||||
{
|
||||
var inputOutputPairs = AutoMlUtils.CreateInputOutputColumnPairsFromStrings(param.InputColumnNames, param.OutputColumnNames);
|
||||
return context.Transforms.ReplaceMissingValues(inputOutputPairs);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,81 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
using Microsoft.ML.Trainers;
|
||||
|
||||
namespace Microsoft.ML.AutoML.CodeGen
|
||||
{
|
||||
internal partial class SdcaRegression
|
||||
{
|
||||
public override IEstimator<ITransformer> BuildFromOption(MLContext context, SdcaOption param)
|
||||
{
|
||||
var option = new SdcaRegressionTrainer.Options()
|
||||
{
|
||||
LabelColumnName = param.LabelColumnName,
|
||||
FeatureColumnName = param.FeatureColumnName,
|
||||
ExampleWeightColumnName = param.ExampleWeightColumnName,
|
||||
L1Regularization = param.L1Regularization,
|
||||
L2Regularization = param.L2Regularization,
|
||||
NumberOfThreads = AutoMlUtils.GetNumberOfThreadFromEnvrionment(),
|
||||
};
|
||||
|
||||
return context.Regression.Trainers.Sdca(option);
|
||||
}
|
||||
}
|
||||
|
||||
internal partial class SdcaMaximumEntropyMulti
|
||||
{
|
||||
public override IEstimator<ITransformer> BuildFromOption(MLContext context, SdcaOption param)
|
||||
{
|
||||
var option = new SdcaMaximumEntropyMulticlassTrainer.Options()
|
||||
{
|
||||
LabelColumnName = param.LabelColumnName,
|
||||
FeatureColumnName = param.FeatureColumnName,
|
||||
ExampleWeightColumnName = param.ExampleWeightColumnName,
|
||||
L1Regularization = param.L1Regularization,
|
||||
L2Regularization = param.L2Regularization,
|
||||
NumberOfThreads = AutoMlUtils.GetNumberOfThreadFromEnvrionment(),
|
||||
};
|
||||
|
||||
return context.MulticlassClassification.Trainers.SdcaMaximumEntropy(option);
|
||||
}
|
||||
}
|
||||
|
||||
internal partial class SdcaLogisticRegressionBinary
|
||||
{
|
||||
public override IEstimator<ITransformer> BuildFromOption(MLContext context, SdcaOption param)
|
||||
{
|
||||
var option = new SdcaLogisticRegressionBinaryTrainer.Options()
|
||||
{
|
||||
LabelColumnName = param.LabelColumnName,
|
||||
FeatureColumnName = param.FeatureColumnName,
|
||||
ExampleWeightColumnName = param.ExampleWeightColumnName,
|
||||
L1Regularization = param.L1Regularization,
|
||||
L2Regularization = param.L2Regularization,
|
||||
NumberOfThreads = AutoMlUtils.GetNumberOfThreadFromEnvrionment(),
|
||||
};
|
||||
|
||||
return context.BinaryClassification.Trainers.SdcaLogisticRegression(option);
|
||||
}
|
||||
}
|
||||
|
||||
internal partial class SdcaLogisticRegressionOva
|
||||
{
|
||||
public override IEstimator<ITransformer> BuildFromOption(MLContext context, SdcaOption param)
|
||||
{
|
||||
var option = new SdcaLogisticRegressionBinaryTrainer.Options()
|
||||
{
|
||||
LabelColumnName = param.LabelColumnName,
|
||||
FeatureColumnName = param.FeatureColumnName,
|
||||
ExampleWeightColumnName = param.ExampleWeightColumnName,
|
||||
L1Regularization = param.L1Regularization,
|
||||
L2Regularization = param.L2Regularization,
|
||||
NumberOfThreads = AutoMlUtils.GetNumberOfThreadFromEnvrionment(),
|
||||
};
|
||||
|
||||
var binaryTrainer = context.BinaryClassification.Trainers.SdcaLogisticRegression(option);
|
||||
return context.MulticlassClassification.Trainers.OneVersusAll(binaryEstimator: binaryTrainer, labelColumnName: param.LabelColumnName);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,15 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
namespace Microsoft.ML.AutoML.CodeGen
|
||||
{
|
||||
internal partial class ConvertType
|
||||
{
|
||||
public override IEstimator<ITransformer> BuildFromOption(MLContext context, ConvertTypeOption param)
|
||||
{
|
||||
var inputOutputPairs = AutoMlUtils.CreateInputOutputColumnPairsFromStrings(param.InputColumnNames, param.OutputColumnNames);
|
||||
return context.Transforms.Conversion.ConvertType(inputOutputPairs);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -6,9 +6,11 @@ using System;
|
|||
using System.Collections.Generic;
|
||||
using System.Collections.Immutable;
|
||||
using System.Linq;
|
||||
using System.Text.Json.Serialization;
|
||||
|
||||
namespace Microsoft.ML.AutoML
|
||||
{
|
||||
[JsonConverter(typeof(MultiModelPipelineConverter))]
|
||||
internal class MultiModelPipeline
|
||||
{
|
||||
private static readonly StringEntity _nilStringEntity = new StringEntity("Nil");
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Text.Json.Serialization;
|
||||
using Microsoft.ML.AutoML.CodeGen;
|
||||
using Microsoft.ML.SearchSpace;
|
||||
|
||||
|
@ -12,6 +13,7 @@ namespace Microsoft.ML.AutoML
|
|||
/// <summary>
|
||||
/// Estimator with search space.
|
||||
/// </summary>
|
||||
[JsonConverter(typeof(SweepableEstimatorConverter))]
|
||||
internal class SweepableEstimator : Estimator
|
||||
{
|
||||
private readonly Func<MLContext, Parameter, IEstimator<ITransformer>> _factory;
|
||||
|
@ -50,11 +52,6 @@ namespace Microsoft.ML.AutoML
|
|||
internal virtual IEnumerable<string> NugetDependencies { get; }
|
||||
|
||||
internal virtual string FunctionName { get; }
|
||||
|
||||
internal virtual string ToDisplayString(Parameter param)
|
||||
{
|
||||
throw new NotImplementedException();
|
||||
}
|
||||
}
|
||||
|
||||
internal abstract class SweepableEstimator<TOption> : SweepableEstimator
|
||||
|
@ -75,12 +72,5 @@ namespace Microsoft.ML.AutoML
|
|||
{
|
||||
return this.BuildFromOption(context, param.AsType<TOption>());
|
||||
}
|
||||
|
||||
internal abstract string ToDisplayString(TOption param);
|
||||
|
||||
internal override string ToDisplayString(Parameter param)
|
||||
{
|
||||
return this.ToDisplayString(param.AsType<TOption>());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,11 +5,13 @@
|
|||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Text.Json.Serialization;
|
||||
using Microsoft.ML.Data;
|
||||
using Microsoft.ML.SearchSpace;
|
||||
|
||||
namespace Microsoft.ML.AutoML
|
||||
{
|
||||
[JsonConverter(typeof(SweepableEstimatorPipelineConverter))]
|
||||
internal class SweepableEstimatorPipeline
|
||||
{
|
||||
private readonly List<SweepableEstimator> _estimators;
|
||||
|
|
|
@ -50,7 +50,7 @@ namespace Microsoft.ML.SearchSpace
|
|||
}
|
||||
|
||||
/// <summary>
|
||||
/// <see cref="Parameter"/> is used to save sweeping result from <see cref="ITuner.Propose(SearchSpace)"/> and is used to restore mlnet pipeline from sweepable pipline.
|
||||
/// <see cref="Parameter"/> is used to save sweeping result from tuner and is used to restore mlnet pipeline from sweepable pipline.
|
||||
/// </summary>
|
||||
[JsonConverter(typeof(ParameterConverter))]
|
||||
public sealed class Parameter : IDictionary<string, Parameter>
|
||||
|
|
|
@ -0,0 +1,61 @@
|
|||
{
|
||||
"0": {
|
||||
"FeatureSpaceDim": 0,
|
||||
"Default": [],
|
||||
"Step": []
|
||||
},
|
||||
"1": {
|
||||
"FeatureSpaceDim": 0,
|
||||
"Default": [],
|
||||
"Step": []
|
||||
},
|
||||
"2": {
|
||||
"FeatureSpaceDim": 0,
|
||||
"Default": [],
|
||||
"Step": []
|
||||
},
|
||||
"3": {
|
||||
"FeatureSpaceDim": 9,
|
||||
"Default": [
|
||||
1,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
0.71428571428571408,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
1
|
||||
],
|
||||
"Step": [
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null
|
||||
]
|
||||
},
|
||||
"4": {
|
||||
"FeatureSpaceDim": 6,
|
||||
"Default": [
|
||||
1,
|
||||
0.89689626841273451,
|
||||
0.71428571428571408,
|
||||
0.55365468248122718,
|
||||
0,
|
||||
0
|
||||
],
|
||||
"Step": [
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null
|
||||
]
|
||||
}
|
||||
}
|
|
@ -0,0 +1,28 @@
|
|||
{
|
||||
"0": {},
|
||||
"1": {},
|
||||
"2": {},
|
||||
"3": {
|
||||
"NumberOfLeaves": 4,
|
||||
"MinimumExampleCountPerLeaf": 20,
|
||||
"LearningRate": 1,
|
||||
"NumberOfTrees": 4,
|
||||
"SubsampleFraction": 1,
|
||||
"MaximumBinCountPerFeature": 255,
|
||||
"FeatureFraction": 1,
|
||||
"L1Regularization": 0.0000000002,
|
||||
"L2Regularization": 1,
|
||||
"LabelColumnName": "Label",
|
||||
"FeatureColumnName": "Feature"
|
||||
},
|
||||
"4": {
|
||||
"NumberOfLeaves": 4,
|
||||
"MinimumExampleCountPerLeaf": 20,
|
||||
"NumberOfTrees": 4,
|
||||
"MaximumBinCountPerFeature": 255,
|
||||
"FeatureFraction": 1,
|
||||
"LearningRate": 0.10,
|
||||
"LabelColumnName": "Label",
|
||||
"FeatureColumnName": "Feature"
|
||||
}
|
||||
}
|
|
@ -0,0 +1,66 @@
|
|||
{
|
||||
"schema": "e0 * (e1 + e2 + e3 + e4 + e5)",
|
||||
"estimators": {
|
||||
"e0": {
|
||||
"estimatorType": "Unknown",
|
||||
"parameter": {}
|
||||
},
|
||||
"e1": {
|
||||
"estimatorType": "FastTreeBinary",
|
||||
"parameter": {
|
||||
"NumberOfLeaves": 4,
|
||||
"MinimumExampleCountPerLeaf": 20,
|
||||
"NumberOfTrees": 4,
|
||||
"MaximumBinCountPerFeature": 256,
|
||||
"FeatureFraction": 1,
|
||||
"LearningRate": 0.1,
|
||||
"LabelColumnName": "Label",
|
||||
"FeatureColumnName": "Features"
|
||||
}
|
||||
},
|
||||
"e2": {
|
||||
"estimatorType": "FastForestBinary",
|
||||
"parameter": {
|
||||
"NumberOfTrees": 4,
|
||||
"NumberOfLeaves": 4,
|
||||
"FeatureFraction": 1,
|
||||
"LabelColumnName": "Label",
|
||||
"FeatureColumnName": "Features"
|
||||
}
|
||||
},
|
||||
"e3": {
|
||||
"estimatorType": "LightGbmBinary",
|
||||
"parameter": {
|
||||
"NumberOfLeaves": 4,
|
||||
"MinimumExampleCountPerLeaf": 20,
|
||||
"LearningRate": 1,
|
||||
"NumberOfTrees": 4,
|
||||
"SubsampleFraction": 1,
|
||||
"MaximumBinCountPerFeature": 256,
|
||||
"FeatureFraction": 1,
|
||||
"L1Regularization": 0.0000000002,
|
||||
"L2Regularization": 1,
|
||||
"LabelColumnName": "Label",
|
||||
"FeatureColumnName": "Features"
|
||||
}
|
||||
},
|
||||
"e4": {
|
||||
"estimatorType": "LbfgsLogisticRegressionBinary",
|
||||
"parameter": {
|
||||
"L1Regularization": 1,
|
||||
"L2Regularization": 1,
|
||||
"LabelColumnName": "Label",
|
||||
"FeatureColumnName": "Features"
|
||||
}
|
||||
},
|
||||
"e5": {
|
||||
"estimatorType": "SdcaLogisticRegressionBinary",
|
||||
"parameter": {
|
||||
"L1Regularization": 1,
|
||||
"L2Regularization": 0.1,
|
||||
"LabelColumnName": "Label",
|
||||
"FeatureColumnName": "Features"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,84 @@
|
|||
{
|
||||
"schema": "e0 * (e1 + e2 + e3 + e4 + e5 + e6 + e7)",
|
||||
"estimators": {
|
||||
"e0": {
|
||||
"estimatorType": "Unknown",
|
||||
"parameter": {}
|
||||
},
|
||||
"e1": {
|
||||
"estimatorType": "FastTreeOva",
|
||||
"parameter": {
|
||||
"NumberOfLeaves": 4,
|
||||
"MinimumExampleCountPerLeaf": 20,
|
||||
"NumberOfTrees": 4,
|
||||
"MaximumBinCountPerFeature": 256,
|
||||
"FeatureFraction": 1,
|
||||
"LearningRate": 0.1,
|
||||
"LabelColumnName": "Label",
|
||||
"FeatureColumnName": "Features"
|
||||
}
|
||||
},
|
||||
"e2": {
|
||||
"estimatorType": "FastForestOva",
|
||||
"parameter": {
|
||||
"NumberOfTrees": 4,
|
||||
"NumberOfLeaves": 4,
|
||||
"FeatureFraction": 1,
|
||||
"LabelColumnName": "Label",
|
||||
"FeatureColumnName": "Features"
|
||||
}
|
||||
},
|
||||
"e3": {
|
||||
"estimatorType": "LightGbmMulti",
|
||||
"parameter": {
|
||||
"NumberOfLeaves": 4,
|
||||
"MinimumExampleCountPerLeaf": 20,
|
||||
"LearningRate": 1,
|
||||
"NumberOfTrees": 4,
|
||||
"SubsampleFraction": 1,
|
||||
"MaximumBinCountPerFeature": 256,
|
||||
"FeatureFraction": 1,
|
||||
"L1Regularization": 0.0000000002,
|
||||
"L2Regularization": 1,
|
||||
"LabelColumnName": "Label",
|
||||
"FeatureColumnName": "Features"
|
||||
}
|
||||
},
|
||||
"e4": {
|
||||
"estimatorType": "LbfgsLogisticRegressionOva",
|
||||
"parameter": {
|
||||
"L1Regularization": 1,
|
||||
"L2Regularization": 1,
|
||||
"LabelColumnName": "Label",
|
||||
"FeatureColumnName": "Features"
|
||||
}
|
||||
},
|
||||
"e5": {
|
||||
"estimatorType": "LbfgsMaximumEntropyMulti",
|
||||
"parameter": {
|
||||
"L1Regularization": 1,
|
||||
"L2Regularization": 1,
|
||||
"LabelColumnName": "Label",
|
||||
"FeatureColumnName": "Features"
|
||||
}
|
||||
},
|
||||
"e6": {
|
||||
"estimatorType": "SdcaMaximumEntropyMulti",
|
||||
"parameter": {
|
||||
"L1Regularization": 1,
|
||||
"L2Regularization": 0.1,
|
||||
"LabelColumnName": "Label",
|
||||
"FeatureColumnName": "Features"
|
||||
}
|
||||
},
|
||||
"e7": {
|
||||
"estimatorType": "SdcaLogisticRegressionOva",
|
||||
"parameter": {
|
||||
"L1Regularization": 1,
|
||||
"L2Regularization": 0.1,
|
||||
"LabelColumnName": "Label",
|
||||
"FeatureColumnName": "Features"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,84 @@
|
|||
{
|
||||
"schema": "e0 * (e1 + e2 + e3 + e4 + e5 + e6 + e7)",
|
||||
"estimators": {
|
||||
"e0": {
|
||||
"estimatorType": "Unknown",
|
||||
"parameter": {}
|
||||
},
|
||||
"e1": {
|
||||
"estimatorType": "FastTreeOva",
|
||||
"parameter": {
|
||||
"NumberOfLeaves": 4,
|
||||
"MinimumExampleCountPerLeaf": 20,
|
||||
"NumberOfTrees": 4,
|
||||
"MaximumBinCountPerFeature": 256,
|
||||
"FeatureFraction": 1,
|
||||
"LearningRate": 0.1,
|
||||
"LabelColumnName": "Label",
|
||||
"FeatureColumnName": "Features"
|
||||
}
|
||||
},
|
||||
"e2": {
|
||||
"estimatorType": "FastForestOva",
|
||||
"parameter": {
|
||||
"NumberOfTrees": 4,
|
||||
"NumberOfLeaves": 4,
|
||||
"FeatureFraction": 1,
|
||||
"LabelColumnName": "Label",
|
||||
"FeatureColumnName": "Features"
|
||||
}
|
||||
},
|
||||
"e3": {
|
||||
"estimatorType": "LightGbmMulti",
|
||||
"parameter": {
|
||||
"NumberOfLeaves": 4,
|
||||
"MinimumExampleCountPerLeaf": 20,
|
||||
"LearningRate": 1,
|
||||
"NumberOfTrees": 4,
|
||||
"SubsampleFraction": 1,
|
||||
"MaximumBinCountPerFeature": 256,
|
||||
"FeatureFraction": 1,
|
||||
"L1Regularization": 0.0000000002,
|
||||
"L2Regularization": 1,
|
||||
"LabelColumnName": "Label",
|
||||
"FeatureColumnName": "Features"
|
||||
}
|
||||
},
|
||||
"e4": {
|
||||
"estimatorType": "LbfgsLogisticRegressionOva",
|
||||
"parameter": {
|
||||
"L1Regularization": 1,
|
||||
"L2Regularization": 1,
|
||||
"LabelColumnName": "Label",
|
||||
"FeatureColumnName": "Features"
|
||||
}
|
||||
},
|
||||
"e5": {
|
||||
"estimatorType": "LbfgsMaximumEntropyMulti",
|
||||
"parameter": {
|
||||
"L1Regularization": 1,
|
||||
"L2Regularization": 1,
|
||||
"LabelColumnName": "Label",
|
||||
"FeatureColumnName": "Features"
|
||||
}
|
||||
},
|
||||
"e6": {
|
||||
"estimatorType": "SdcaMaximumEntropyMulti",
|
||||
"parameter": {
|
||||
"L1Regularization": 1,
|
||||
"L2Regularization": 0.1,
|
||||
"LabelColumnName": "Label",
|
||||
"FeatureColumnName": "Features"
|
||||
}
|
||||
},
|
||||
"e7": {
|
||||
"estimatorType": "SdcaLogisticRegressionOva",
|
||||
"parameter": {
|
||||
"L1Regularization": 1,
|
||||
"L2Regularization": 0.1,
|
||||
"LabelColumnName": "Label",
|
||||
"FeatureColumnName": "Features"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,90 @@
|
|||
{
|
||||
"schema": "e0 * (e1 + e2 + e3 + e4 + e5 + e6 + e7)",
|
||||
"estimators": {
|
||||
"e0": {
|
||||
"estimatorType": "FastForestBinary",
|
||||
"parameter": {
|
||||
"NumberOfTrees": 4,
|
||||
"NumberOfLeaves": 4,
|
||||
"FeatureFraction": 1,
|
||||
"LabelColumnName": "Label",
|
||||
"FeatureColumnName": "Feature"
|
||||
}
|
||||
},
|
||||
"e1": {
|
||||
"estimatorType": "FastTreeOva",
|
||||
"parameter": {
|
||||
"NumberOfLeaves": 4,
|
||||
"MinimumExampleCountPerLeaf": 20,
|
||||
"NumberOfTrees": 4,
|
||||
"MaximumBinCountPerFeature": 256,
|
||||
"FeatureFraction": 1,
|
||||
"LearningRate": 0.1,
|
||||
"LabelColumnName": "Label",
|
||||
"FeatureColumnName": "Features"
|
||||
}
|
||||
},
|
||||
"e2": {
|
||||
"estimatorType": "FastForestOva",
|
||||
"parameter": {
|
||||
"NumberOfTrees": 4,
|
||||
"NumberOfLeaves": 4,
|
||||
"FeatureFraction": 1,
|
||||
"LabelColumnName": "Label",
|
||||
"FeatureColumnName": "Features"
|
||||
}
|
||||
},
|
||||
"e3": {
|
||||
"estimatorType": "LightGbmMulti",
|
||||
"parameter": {
|
||||
"NumberOfLeaves": 4,
|
||||
"MinimumExampleCountPerLeaf": 20,
|
||||
"LearningRate": 1,
|
||||
"NumberOfTrees": 4,
|
||||
"SubsampleFraction": 1,
|
||||
"MaximumBinCountPerFeature": 256,
|
||||
"FeatureFraction": 1,
|
||||
"L1Regularization": 0.0000000002,
|
||||
"L2Regularization": 1,
|
||||
"LabelColumnName": "Label",
|
||||
"FeatureColumnName": "Features"
|
||||
}
|
||||
},
|
||||
"e4": {
|
||||
"estimatorType": "LbfgsLogisticRegressionOva",
|
||||
"parameter": {
|
||||
"L1Regularization": 1,
|
||||
"L2Regularization": 1,
|
||||
"LabelColumnName": "Label",
|
||||
"FeatureColumnName": "Features"
|
||||
}
|
||||
},
|
||||
"e5": {
|
||||
"estimatorType": "LbfgsMaximumEntropyMulti",
|
||||
"parameter": {
|
||||
"L1Regularization": 1,
|
||||
"L2Regularization": 1,
|
||||
"LabelColumnName": "Label",
|
||||
"FeatureColumnName": "Features"
|
||||
}
|
||||
},
|
||||
"e6": {
|
||||
"estimatorType": "SdcaMaximumEntropyMulti",
|
||||
"parameter": {
|
||||
"L1Regularization": 1,
|
||||
"L2Regularization": 0.1,
|
||||
"LabelColumnName": "Label",
|
||||
"FeatureColumnName": "Features"
|
||||
}
|
||||
},
|
||||
"e7": {
|
||||
"estimatorType": "SdcaLogisticRegressionOva",
|
||||
"parameter": {
|
||||
"L1Regularization": 1,
|
||||
"L2Regularization": 0.1,
|
||||
"LabelColumnName": "Label",
|
||||
"FeatureColumnName": "Features"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,88 @@
|
|||
{
|
||||
"schema": "e0 * e1 * (e2 + e3 + e4 + e5 + e6 + e7 + e8)",
|
||||
"estimators": {
|
||||
"e0": {
|
||||
"estimatorType": "Unknown",
|
||||
"parameter": {}
|
||||
},
|
||||
"e1": {
|
||||
"estimatorType": "FeaturizeText",
|
||||
"parameter": {}
|
||||
},
|
||||
"e2": {
|
||||
"estimatorType": "FastTreeOva",
|
||||
"parameter": {
|
||||
"NumberOfLeaves": 4,
|
||||
"MinimumExampleCountPerLeaf": 20,
|
||||
"NumberOfTrees": 4,
|
||||
"MaximumBinCountPerFeature": 256,
|
||||
"FeatureFraction": 1,
|
||||
"LearningRate": 0.1,
|
||||
"LabelColumnName": "Label",
|
||||
"FeatureColumnName": "Features"
|
||||
}
|
||||
},
|
||||
"e3": {
|
||||
"estimatorType": "FastForestOva",
|
||||
"parameter": {
|
||||
"NumberOfTrees": 4,
|
||||
"NumberOfLeaves": 4,
|
||||
"FeatureFraction": 1,
|
||||
"LabelColumnName": "Label",
|
||||
"FeatureColumnName": "Features"
|
||||
}
|
||||
},
|
||||
"e4": {
|
||||
"estimatorType": "LightGbmMulti",
|
||||
"parameter": {
|
||||
"NumberOfLeaves": 4,
|
||||
"MinimumExampleCountPerLeaf": 20,
|
||||
"LearningRate": 1,
|
||||
"NumberOfTrees": 4,
|
||||
"SubsampleFraction": 1,
|
||||
"MaximumBinCountPerFeature": 256,
|
||||
"FeatureFraction": 1,
|
||||
"L1Regularization": 0.0000000002,
|
||||
"L2Regularization": 1,
|
||||
"LabelColumnName": "Label",
|
||||
"FeatureColumnName": "Features"
|
||||
}
|
||||
},
|
||||
"e5": {
|
||||
"estimatorType": "LbfgsLogisticRegressionOva",
|
||||
"parameter": {
|
||||
"L1Regularization": 1,
|
||||
"L2Regularization": 1,
|
||||
"LabelColumnName": "Label",
|
||||
"FeatureColumnName": "Features"
|
||||
}
|
||||
},
|
||||
"e6": {
|
||||
"estimatorType": "LbfgsMaximumEntropyMulti",
|
||||
"parameter": {
|
||||
"L1Regularization": 1,
|
||||
"L2Regularization": 1,
|
||||
"LabelColumnName": "Label",
|
||||
"FeatureColumnName": "Features"
|
||||
}
|
||||
},
|
||||
"e7": {
|
||||
"estimatorType": "SdcaMaximumEntropyMulti",
|
||||
"parameter": {
|
||||
"L1Regularization": 1,
|
||||
"L2Regularization": 0.1,
|
||||
"LabelColumnName": "Label",
|
||||
"FeatureColumnName": "Features"
|
||||
}
|
||||
},
|
||||
"e8": {
|
||||
"estimatorType": "SdcaLogisticRegressionOva",
|
||||
"parameter": {
|
||||
"L1Regularization": 1,
|
||||
"L2Regularization": 0.1,
|
||||
"LabelColumnName": "Label",
|
||||
"FeatureColumnName": "Features"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -11,17 +11,34 @@ using ApprovalTests.Namers;
|
|||
using ApprovalTests.Reporters;
|
||||
using FluentAssertions;
|
||||
using Microsoft.ML.TestFramework;
|
||||
using Newtonsoft.Json;
|
||||
using Xunit;
|
||||
using Xunit.Abstractions;
|
||||
using Microsoft.ML.AutoML.CodeGen;
|
||||
using System.Text.Json;
|
||||
using System.Text.Json.Serialization;
|
||||
|
||||
namespace Microsoft.ML.AutoML.Test
|
||||
{
|
||||
public class SweepableEstimatorPipelineTest : BaseTestClass
|
||||
{
|
||||
private readonly JsonSerializerOptions _jsonSerializerOptions;
|
||||
|
||||
public SweepableEstimatorPipelineTest(ITestOutputHelper output)
|
||||
: base(output)
|
||||
{
|
||||
this._jsonSerializerOptions = new JsonSerializerOptions()
|
||||
{
|
||||
WriteIndented = true,
|
||||
Converters =
|
||||
{
|
||||
new JsonStringEnumConverter(), new DoubleToDecimalConverter(), new FloatToDecimalConverter(),
|
||||
},
|
||||
};
|
||||
|
||||
if (Environment.GetEnvironmentVariable("HELIX_CORRELATION_ID") != null)
|
||||
{
|
||||
Approvals.UseAssemblyLocationForApprovedFiles();
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
|
@ -51,5 +68,86 @@ namespace Microsoft.ML.AutoML.Test
|
|||
pipeline.BuildSweepableEstimatorPipeline("e0 * e2").ToString().Should().Be("Concatenate=>ApplyOnnxModel");
|
||||
pipeline.BuildSweepableEstimatorPipeline("e1 * Nil").ToString().Should().Be("ConvertType");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void MultiModelPipeline_append_pipeline_test()
|
||||
{
|
||||
var e1 = new SweepableEstimator(CodeGen.EstimatorType.Concatenate);
|
||||
var e2 = new SweepableEstimator(CodeGen.EstimatorType.ConvertType);
|
||||
var e3 = new SweepableEstimator(CodeGen.EstimatorType.ApplyOnnxModel);
|
||||
var e4 = new SweepableEstimator(CodeGen.EstimatorType.LightGbmBinary);
|
||||
var e5 = new SweepableEstimator(CodeGen.EstimatorType.FastTreeBinary);
|
||||
|
||||
var pipeline1 = new MultiModelPipeline();
|
||||
var pipeline2 = new MultiModelPipeline();
|
||||
|
||||
pipeline1 = pipeline1.Append(e1 + e2 * e3);
|
||||
pipeline2 = pipeline2.Append(e1 * (e3 + e4) + e5);
|
||||
|
||||
pipeline1 = pipeline1.Append(pipeline2);
|
||||
|
||||
pipeline1.Schema.ToString().Should().Be("(e0 + e1 * e2) * (e3 * (e4 + e5) + e6)");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void SweepableEstimatorPipeline_search_space_test()
|
||||
{
|
||||
var pipeline = this.CreateSweepbaleEstimatorPipeline();
|
||||
pipeline.SearchSpace.FeatureSpaceDim.Should().Be(15);
|
||||
|
||||
// TODO
|
||||
// verify other properties in search space.
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void SweepableEstimatorPipeline_can_be_created_from_MultiModelPipeline()
|
||||
{
|
||||
var multiModelPipeline = this.CreateMultiModelPipeline();
|
||||
var pipelines = multiModelPipeline.PipelineIds;
|
||||
|
||||
pipelines.Should().BeEquivalentTo("e0 * e3 * e4", "e1 * e2 * e3 * e4", "e0 * Nil * e4", "e1 * e2 * Nil * e4", "Nil * e3 * e4", "e0 * e3 * e5", "e1 * e2 * e3 * e5", "e0 * Nil * e5", "e1 * e2 * Nil * e5", "Nil * e3 * e5", "Nil * Nil * e4", "Nil * Nil * e5");
|
||||
var singleModelPipeline = multiModelPipeline.BuildSweepableEstimatorPipeline(pipelines[0]);
|
||||
singleModelPipeline.ToString().Should().Be("ReplaceMissingValues=>Concatenate=>LightGbmBinary");
|
||||
singleModelPipeline = multiModelPipeline.BuildSweepableEstimatorPipeline(pipelines[2]);
|
||||
singleModelPipeline.ToString().Should().Be("ReplaceMissingValues=>LightGbmBinary");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[UseReporter(typeof(DiffReporter))]
|
||||
[UseApprovalSubdirectory("ApprovalTests")]
|
||||
public void SweepableEstimatorPipeline_search_space_init_value_test()
|
||||
{
|
||||
var singleModelPipeline = this.CreateSweepbaleEstimatorPipeline();
|
||||
var defaultParam = singleModelPipeline.SearchSpace.SampleFromFeatureSpace(singleModelPipeline.SearchSpace.Default);
|
||||
Approvals.Verify(JsonSerializer.Serialize(defaultParam, this._jsonSerializerOptions));
|
||||
}
|
||||
|
||||
private SweepableEstimatorPipeline CreateSweepbaleEstimatorPipeline()
|
||||
{
|
||||
var concat = SweepableEstimatorFactory.CreateConcatenate(new ConcatOption());
|
||||
var replaceMissingValue = SweepableEstimatorFactory.CreateReplaceMissingValues(new ReplaceMissingValueOption());
|
||||
var oneHot = SweepableEstimatorFactory.CreateOneHotEncoding(new OneHotOption());
|
||||
var lightGbm = SweepableEstimatorFactory.CreateLightGbmBinary(new LgbmOption());
|
||||
var fastTree = SweepableEstimatorFactory.CreateFastTreeBinary(new FastTreeOption());
|
||||
|
||||
var pipeline = new SweepableEstimatorPipeline(new SweepableEstimator[] { concat, replaceMissingValue, oneHot, lightGbm, fastTree });
|
||||
return pipeline;
|
||||
}
|
||||
|
||||
private MultiModelPipeline CreateMultiModelPipeline()
|
||||
{
|
||||
var concat = SweepableEstimatorFactory.CreateConcatenate(new ConcatOption());
|
||||
var replaceMissingValue = SweepableEstimatorFactory.CreateReplaceMissingValues(new ReplaceMissingValueOption());
|
||||
var oneHot = SweepableEstimatorFactory.CreateOneHotEncoding(new OneHotOption());
|
||||
var lightGbm = SweepableEstimatorFactory.CreateLightGbmBinary(new LgbmOption());
|
||||
var fastTree = SweepableEstimatorFactory.CreateFastTreeBinary(new FastTreeOption());
|
||||
|
||||
var pipeline = new MultiModelPipeline();
|
||||
pipeline = pipeline.AppendOrSkip(replaceMissingValue + replaceMissingValue * oneHot);
|
||||
pipeline = pipeline.AppendOrSkip(concat);
|
||||
pipeline = pipeline.Append(lightGbm + fastTree);
|
||||
|
||||
return pipeline;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,152 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Text;
|
||||
using Xunit;
|
||||
using Microsoft.ML.AutoML.CodeGen;
|
||||
using FluentAssertions;
|
||||
using Microsoft.ML.TestFramework;
|
||||
using Xunit.Abstractions;
|
||||
using ApprovalTests.Namers;
|
||||
using ApprovalTests.Reporters;
|
||||
using System.Text.Json.Serialization;
|
||||
using System.Text.Json;
|
||||
using ApprovalTests;
|
||||
|
||||
namespace Microsoft.ML.AutoML.Test
|
||||
{
|
||||
public class SweepableExtensionTest : BaseTestClass
|
||||
{
|
||||
private readonly JsonSerializerOptions _jsonSerializerOptions;
|
||||
|
||||
public SweepableExtensionTest(ITestOutputHelper output)
|
||||
: base(output)
|
||||
{
|
||||
this._jsonSerializerOptions = new JsonSerializerOptions()
|
||||
{
|
||||
WriteIndented = true,
|
||||
Converters =
|
||||
{
|
||||
new JsonStringEnumConverter(), new DoubleToDecimalConverter(), new FloatToDecimalConverter(),
|
||||
},
|
||||
};
|
||||
|
||||
this._jsonSerializerOptions.Encoder = System.Text.Encodings.Web.JavaScriptEncoder.UnsafeRelaxedJsonEscaping;
|
||||
|
||||
if (Environment.GetEnvironmentVariable("HELIX_CORRELATION_ID") != null)
|
||||
{
|
||||
Approvals.UseAssemblyLocationForApprovedFiles();
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void CreateSweepableEstimatorPipelineFromIEstimatorTest()
|
||||
{
|
||||
var context = new MLContext();
|
||||
var estimator = context.Transforms.Concatenate("output", "input");
|
||||
var pipeline = estimator.Append(SweepableEstimatorFactory.CreateFastForestBinary(new FastForestOption()));
|
||||
|
||||
pipeline.ToString().Should().Be("Unknown=>FastForestBinary");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void AppendIEstimatorToSweepabeEstimatorPipelineTest()
|
||||
{
|
||||
var context = new MLContext();
|
||||
var estimator = context.Transforms.Concatenate("output", "input");
|
||||
var pipeline = estimator.Append(SweepableEstimatorFactory.CreateFastForestBinary(new FastForestOption()));
|
||||
pipeline = pipeline.Append(context.Transforms.CopyColumns("output", "input"));
|
||||
|
||||
pipeline.ToString().Should().Be("Unknown=>FastForestBinary=>Unknown");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void CreateSweepableEstimatorPipelineFromSweepableEstimatorTest()
|
||||
{
|
||||
var estimator = SweepableEstimatorFactory.CreateFastForestBinary(new FastForestOption());
|
||||
var pipeline = estimator.Append(estimator);
|
||||
|
||||
pipeline.ToString().Should().Be("FastForestBinary=>FastForestBinary");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void CreateSweepableEstimatorPipelineFromSweepableEstimatorAndIEstimatorTest()
|
||||
{
|
||||
var context = new MLContext();
|
||||
var estimator = SweepableEstimatorFactory.CreateFastForestBinary(new FastForestOption());
|
||||
var pipeline = estimator.Append(context.Transforms.Concatenate("output", "input"));
|
||||
|
||||
pipeline.ToString().Should().Be("FastForestBinary=>Unknown");
|
||||
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[UseReporter(typeof(DiffReporter))]
|
||||
[UseApprovalSubdirectory("ApprovalTests")]
|
||||
public void CreateMultiModelPipelineFromIEstimatorAndBinaryClassifiers()
|
||||
{
|
||||
var context = new MLContext();
|
||||
var pipeline = context.Transforms.Concatenate("output", "input")
|
||||
.Append(context.Auto().BinaryClassification());
|
||||
|
||||
var json = JsonSerializer.Serialize(pipeline, this._jsonSerializerOptions);
|
||||
Approvals.Verify(json);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[UseApprovalSubdirectory("ApprovalTests")]
|
||||
[UseReporter(typeof(DiffReporter))]
|
||||
public void CreateMultiModelPipelineFromIEstimatorAndMultiClassifiers()
|
||||
{
|
||||
var context = new MLContext();
|
||||
var pipeline = context.Transforms.Concatenate("output", "input")
|
||||
.Append(context.Auto().MultiClassification());
|
||||
|
||||
var json = JsonSerializer.Serialize(pipeline, this._jsonSerializerOptions);
|
||||
Approvals.Verify(json);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[UseApprovalSubdirectory("ApprovalTests")]
|
||||
[UseReporter(typeof(DiffReporter))]
|
||||
public void CreateMultiModelPipelineFromIEstimatorAndRegressors()
|
||||
{
|
||||
var context = new MLContext();
|
||||
var pipeline = context.Transforms.Concatenate("output", "input")
|
||||
.Append(context.Auto().MultiClassification());
|
||||
|
||||
var json = JsonSerializer.Serialize(pipeline, this._jsonSerializerOptions);
|
||||
Approvals.Verify(json);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[UseApprovalSubdirectory("ApprovalTests")]
|
||||
[UseReporter(typeof(DiffReporter))]
|
||||
public void CreateMultiModelPipelineFromSweepableEstimatorAndMultiClassifiers()
|
||||
{
|
||||
var context = new MLContext();
|
||||
var pipeline = SweepableEstimatorFactory.CreateFastForestBinary(new FastForestOption())
|
||||
.Append(context.Auto().MultiClassification());
|
||||
|
||||
var json = JsonSerializer.Serialize(pipeline, this._jsonSerializerOptions);
|
||||
Approvals.Verify(json);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
[UseApprovalSubdirectory("ApprovalTests")]
|
||||
[UseReporter(typeof(DiffReporter))]
|
||||
public void CreateMultiModelPipelineFromSweepableEstimatorPipelineAndMultiClassifiers()
|
||||
{
|
||||
var context = new MLContext();
|
||||
var pipeline = context.Transforms.Concatenate("output", "input")
|
||||
.Append(SweepableEstimatorFactory.CreateFeaturizeText(new FeaturizeTextOption()))
|
||||
.Append(context.Auto().MultiClassification());
|
||||
|
||||
var json = JsonSerializer.Serialize(pipeline, this._jsonSerializerOptions);
|
||||
Approvals.Verify(json);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,23 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
using System;
|
||||
using System.Text.Json.Serialization;
|
||||
using System.Text.Json;
|
||||
|
||||
namespace Microsoft.ML.AutoML.Test
|
||||
{
|
||||
internal class DoubleToDecimalConverter : JsonConverter<double>
|
||||
{
|
||||
public override double Read(ref Utf8JsonReader reader, Type type, JsonSerializerOptions options)
|
||||
{
|
||||
return Convert.ToDouble(reader.GetDecimal());
|
||||
}
|
||||
|
||||
public override void Write(Utf8JsonWriter writer, double value, JsonSerializerOptions options)
|
||||
{
|
||||
writer.WriteNumberValue(Convert.ToDecimal(value));
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,23 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
using System;
|
||||
using System.Text.Json.Serialization;
|
||||
using System.Text.Json;
|
||||
|
||||
namespace Microsoft.ML.AutoML.Test
|
||||
{
|
||||
internal class FloatToDecimalConverter : JsonConverter<float>
|
||||
{
|
||||
public override float Read(ref Utf8JsonReader reader, Type type, JsonSerializerOptions options)
|
||||
{
|
||||
return Convert.ToSingle(reader.GetDecimal());
|
||||
}
|
||||
|
||||
public override void Write(Utf8JsonWriter writer, float value, JsonSerializerOptions options)
|
||||
{
|
||||
writer.WriteNumberValue(Convert.ToDecimal(value));
|
||||
}
|
||||
}
|
||||
}
|
|
@ -31,6 +31,11 @@
|
|||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<Compile Update="Template\SweepableEstimatorFactory.cs">
|
||||
<DesignTime>True</DesignTime>
|
||||
<AutoGen>True</AutoGen>
|
||||
<DependentUpon>SweepableEstimatorFactory.tt</DependentUpon>
|
||||
</Compile>
|
||||
<Compile Update="Template\EstimatorType.cs">
|
||||
<DesignTime>True</DesignTime>
|
||||
<AutoGen>True</AutoGen>
|
||||
|
@ -41,9 +46,23 @@
|
|||
<AutoGen>True</AutoGen>
|
||||
<DependentUpon>SearchSpace.tt</DependentUpon>
|
||||
</Compile>
|
||||
<Compile Update="Template\SweepableEstimator.cs">
|
||||
<DesignTime>True</DesignTime>
|
||||
<AutoGen>True</AutoGen>
|
||||
<DependentUpon>SweepableEstimator.tt</DependentUpon>
|
||||
</Compile>
|
||||
<Compile Update="Template\SweepableEstimator_T_.cs">
|
||||
<DesignTime>True</DesignTime>
|
||||
<AutoGen>True</AutoGen>
|
||||
<DependentUpon>SweepableEstimator_T_.tt</DependentUpon>
|
||||
</Compile>
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<None Update="Template\SweepableEstimatorFactory.tt">
|
||||
<Generator>TextTemplatingFilePreprocessor</Generator>
|
||||
<LastGenOutput>SweepableEstimatorFactory.cs</LastGenOutput>
|
||||
</None>
|
||||
<None Update="Template\EstimatorType.tt">
|
||||
<Generator>TextTemplatingFilePreprocessor</Generator>
|
||||
<LastGenOutput>EstimatorType.cs</LastGenOutput>
|
||||
|
@ -52,6 +71,14 @@
|
|||
<Generator>TextTemplatingFilePreprocessor</Generator>
|
||||
<LastGenOutput>SearchSpace.cs</LastGenOutput>
|
||||
</None>
|
||||
<None Update="Template\SweepableEstimator.tt">
|
||||
<Generator>TextTemplatingFilePreprocessor</Generator>
|
||||
<LastGenOutput>SweepableEstimator.cs</LastGenOutput>
|
||||
</None>
|
||||
<None Update="Template\SweepableEstimator_T_.tt">
|
||||
<Generator>TextTemplatingFilePreprocessor</Generator>
|
||||
<LastGenOutput>SweepableEstimator_T_.cs</LastGenOutput>
|
||||
</None>
|
||||
</ItemGroup>
|
||||
|
||||
</Project>
|
||||
|
|
|
@ -0,0 +1,55 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Text.Json;
|
||||
using Microsoft.CodeAnalysis;
|
||||
using Microsoft.CodeAnalysis.Text;
|
||||
using Microsoft.ML.AutoML.SourceGenerator.Template;
|
||||
|
||||
namespace Microsoft.ML.AutoML.SourceGenerator
|
||||
{
|
||||
[Generator]
|
||||
public class SweepableEstimatorFactoryGenerator : ISourceGenerator
|
||||
{
|
||||
private const string className = "SweepableEstimatorFactory";
|
||||
|
||||
public void Execute(GeneratorExecutionContext context)
|
||||
{
|
||||
if (context.AdditionalFiles.Where(f => f.Path.Contains("code_gen_flag.json")).First() is AdditionalText text)
|
||||
{
|
||||
var json = text.GetText().ToString();
|
||||
var flags = JsonSerializer.Deserialize<Dictionary<string, bool>>(json);
|
||||
if (flags.TryGetValue(nameof(SweepableEstimatorFactoryGenerator), out var res) && res == false)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
var trainers = context.AdditionalFiles.Where(f => f.Path.Contains("trainer-estimators.json"))
|
||||
.SelectMany(file => Utils.GetEstimatorsFromJson(file.GetText().ToString()).Estimators, (text, estimator) => (estimator.FunctionName, estimator.EstimatorTypes, estimator.SearchOption))
|
||||
.SelectMany(union => union.EstimatorTypes.Select(t => (Utils.CreateEstimatorName(union.FunctionName, t), Utils.ToTitleCase(union.SearchOption))))
|
||||
.ToArray();
|
||||
|
||||
var transformers = context.AdditionalFiles.Where(f => f.Path.Contains("transformer-estimators.json"))
|
||||
.SelectMany(file => Utils.GetEstimatorsFromJson(file.GetText().ToString()).Estimators, (text, estimator) => (estimator.FunctionName, estimator.EstimatorTypes, estimator.SearchOption))
|
||||
.SelectMany(union => union.EstimatorTypes.Select(t => (Utils.CreateEstimatorName(union.FunctionName, t), Utils.ToTitleCase(union.SearchOption))))
|
||||
.ToArray();
|
||||
|
||||
var code = new SweepableEstimatorFactory()
|
||||
{
|
||||
NameSpace = Constant.CodeGeneratorNameSpace,
|
||||
EstimatorNames = trainers.Concat(transformers),
|
||||
};
|
||||
|
||||
context.AddSource(className + ".cs", SourceText.From(code.TransformText(), Encoding.UTF8));
|
||||
}
|
||||
|
||||
public void Initialize(GeneratorInitializationContext context)
|
||||
{
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,88 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Diagnostics;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Text.Json;
|
||||
using Microsoft.CodeAnalysis;
|
||||
using Microsoft.CodeAnalysis.CSharp.Syntax;
|
||||
using Microsoft.CodeAnalysis.Text;
|
||||
using Microsoft.ML.AutoML.SourceGenerator;
|
||||
using SweepableEstimator = Microsoft.ML.AutoML.SourceGenerator.Template.SweepableEstimator;
|
||||
using SweepableEstimatorT = Microsoft.ML.AutoML.SourceGenerator.Template.SweepableEstimator_T_;
|
||||
namespace Microsoft.ML.ModelBuilder.SweepableEstimator.CodeGenerator
|
||||
{
|
||||
[Generator]
|
||||
public class SweepableEstimatorGenerator : ISourceGenerator
|
||||
{
|
||||
private const string SweepableEstimatorAttributeDisplayName = Constant.CodeGeneratorNameSpace + "." + "SweepableEstimatorAttribute";
|
||||
|
||||
public void Execute(GeneratorExecutionContext context)
|
||||
{
|
||||
if (context.AdditionalFiles.Where(f => f.Path.Contains("code_gen_flag.json")).First() is AdditionalText text)
|
||||
{
|
||||
var json = text.GetText().ToString();
|
||||
var flags = JsonSerializer.Deserialize<Dictionary<string, bool>>(json);
|
||||
if (flags.TryGetValue(nameof(SweepableEstimatorGenerator), out var res) && res == false)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
var estimators = context.AdditionalFiles.Where(f => f.Path.Contains("trainer-estimators.json") || f.Path.Contains("transformer-estimators.json"))
|
||||
.SelectMany(file => Utils.GetEstimatorsFromJson(file.GetText().ToString()).Estimators)
|
||||
.ToArray();
|
||||
|
||||
var code = estimators.SelectMany(e => e.EstimatorTypes.Select(eType => (e, eType, Utils.CreateEstimatorName(e.FunctionName, eType)))
|
||||
.Select(x =>
|
||||
{
|
||||
if (x.e.SearchOption == null)
|
||||
{
|
||||
return
|
||||
(x.Item3,
|
||||
new AutoML.SourceGenerator.Template.SweepableEstimator()
|
||||
{
|
||||
NameSpace = Constant.CodeGeneratorNameSpace,
|
||||
UsingStatements = x.e.UsingStatements,
|
||||
ArgumentsList = x.e.ArgumentsList,
|
||||
ClassName = x.Item3,
|
||||
FunctionName = x.e.FunctionName,
|
||||
NugetDependencies = x.e.NugetDependencies,
|
||||
Type = x.eType,
|
||||
}.TransformText());
|
||||
}
|
||||
else
|
||||
{
|
||||
return
|
||||
(x.Item3,
|
||||
new SweepableEstimatorT()
|
||||
{
|
||||
NameSpace = Constant.CodeGeneratorNameSpace,
|
||||
UsingStatements = x.e.UsingStatements,
|
||||
ArgumentsList = x.e.ArgumentsList,
|
||||
ClassName = x.Item3,
|
||||
FunctionName = x.e.FunctionName,
|
||||
NugetDependencies = x.e.NugetDependencies,
|
||||
Type = x.eType,
|
||||
TOption = Utils.ToTitleCase(x.e.SearchOption),
|
||||
}.TransformText());
|
||||
}
|
||||
}));
|
||||
|
||||
foreach (var c in code)
|
||||
{
|
||||
context.AddSource(c.Item1 + ".cs", SourceText.From(c.Item2, Encoding.UTF8));
|
||||
}
|
||||
}
|
||||
|
||||
public void Initialize(GeneratorInitializationContext context)
|
||||
{
|
||||
return;
|
||||
//context.RegisterForPostInitialization(i => i.AddSource(nameof(SweepableEstimatorAttribute), SweepableEstimatorAttribute));
|
||||
}
|
||||
}
|
||||
}
|
|
@ -40,9 +40,9 @@ namespace Microsoft.ML.AutoML.SourceGenerator.Template
|
|||
this.Write(this.ToStringHelper.ToStringWithCulture(e));
|
||||
this.Write(",\r\n");
|
||||
}
|
||||
this.Write(" }\r\n\r\n public static class EstimatorTypeExtension\r\n {\r\n public st" +
|
||||
"atic bool IsTrainer(this EstimatorType estimatorType)\r\n {\r\n sw" +
|
||||
"itch(estimatorType)\r\n {\r\n");
|
||||
this.Write(" Unknown,\r\n }\r\n\r\n public static class EstimatorTypeExtension\r\n {\r" +
|
||||
"\n public static bool IsTrainer(this EstimatorType estimatorType)\r\n " +
|
||||
" {\r\n switch(estimatorType)\r\n {\r\n");
|
||||
foreach(var estimator in TrainerNames){
|
||||
this.Write(" case EstimatorType.");
|
||||
this.Write(this.ToStringHelper.ToStringWithCulture(estimator));
|
||||
|
|
|
@ -18,6 +18,7 @@ namespace <#=NameSpace#>
|
|||
<# foreach(var e in TransformerNames){#>
|
||||
<#=e#>,
|
||||
<#}#>
|
||||
Unknown,
|
||||
}
|
||||
|
||||
public static class EstimatorTypeExtension
|
||||
|
|
|
@ -0,0 +1,353 @@
|
|||
// ------------------------------------------------------------------------------
|
||||
// <auto-generated>
|
||||
// This code was generated by a tool.
|
||||
// Runtime Version: 17.0.0.0
|
||||
//
|
||||
// Changes to this file may cause incorrect behavior and will be lost if
|
||||
// the code is regenerated.
|
||||
// </auto-generated>
|
||||
// ------------------------------------------------------------------------------
|
||||
namespace Microsoft.ML.AutoML.SourceGenerator.Template
|
||||
{
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Collections.Generic;
|
||||
using System;
|
||||
|
||||
/// <summary>
|
||||
/// Class to produce the template output
|
||||
/// </summary>
|
||||
[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.VisualStudio.TextTemplating", "17.0.0.0")]
|
||||
internal partial class SweepableEstimator : SweepableEstimatorBase
|
||||
{
|
||||
/// <summary>
|
||||
/// Create the template output
|
||||
/// </summary>
|
||||
public virtual string TransformText()
|
||||
{
|
||||
this.Write(@"
|
||||
using System.Collections.Generic;
|
||||
using Newtonsoft.Json;
|
||||
using SweepableEstimator = Microsoft.ML.AutoML.SweepableEstimator;
|
||||
using Microsoft.ML.AutoML.CodeGen;
|
||||
using ColorsOrder = Microsoft.ML.Transforms.Image.ImagePixelExtractingEstimator.ColorsOrder;
|
||||
using ColorBits = Microsoft.ML.Transforms.Image.ImagePixelExtractingEstimator.ColorBits;
|
||||
using ResizingKind = Microsoft.ML.Transforms.Image.ImageResizingEstimator.ResizingKind;
|
||||
using Anchor = Microsoft.ML.Transforms.Image.ImageResizingEstimator.Anchor;
|
||||
|
||||
namespace ");
|
||||
this.Write(this.ToStringHelper.ToStringWithCulture(NameSpace));
|
||||
this.Write("\r\n{\r\n internal partial class ");
|
||||
this.Write(this.ToStringHelper.ToStringWithCulture(ClassName));
|
||||
this.Write(" : SweepableEstimator\r\n {\r\n public ");
|
||||
this.Write(this.ToStringHelper.ToStringWithCulture(ClassName));
|
||||
this.Write("()\r\n {\r\n this.EstimatorType = EstimatorType.");
|
||||
this.Write(this.ToStringHelper.ToStringWithCulture(ClassName));
|
||||
this.Write(";\r\n }\r\n \r\n");
|
||||
foreach(var arg in ArgumentsList){
|
||||
var typeAttributeName = Utils.CapitalFirstLetter(arg.ArgumentType);
|
||||
var propertyName = Utils.CapitalFirstLetter(arg.ArgumentName);
|
||||
this.Write(" [");
|
||||
this.Write(this.ToStringHelper.ToStringWithCulture(typeAttributeName));
|
||||
this.Write("]\r\n [JsonProperty(NullValueHandling=NullValueHandling.Ignore)]\r\n pu" +
|
||||
"blic string ");
|
||||
this.Write(this.ToStringHelper.ToStringWithCulture(propertyName));
|
||||
this.Write(" { get; set; }\r\n\r\n");
|
||||
}
|
||||
this.Write(" internal override IEnumerable<string> CSharpUsingStatements \r\n {\r\n" +
|
||||
" get => new string[] {");
|
||||
this.Write(this.ToStringHelper.ToStringWithCulture(Utils.PrettyPrintListOfString(UsingStatements.Select(x => $"using {x};"))));
|
||||
this.Write("};\r\n }\r\n\r\n internal override IEnumerable<string> NugetDependencies\r" +
|
||||
"\n {\r\n get => new string[] {");
|
||||
this.Write(this.ToStringHelper.ToStringWithCulture(Utils.PrettyPrintListOfString(NugetDependencies)));
|
||||
this.Write("};\r\n }\r\n\r\n internal override string FunctionName \r\n {\r\n " +
|
||||
" get => \"");
|
||||
this.Write(this.ToStringHelper.ToStringWithCulture(Utils.GetPrefix(Type)));
|
||||
this.Write(".");
|
||||
this.Write(this.ToStringHelper.ToStringWithCulture(FunctionName));
|
||||
this.Write("\";\r\n }\r\n }\r\n}\r\n\r\n");
|
||||
return this.GenerationEnvironment.ToString();
|
||||
}
|
||||
|
||||
public string NameSpace {get;set;}
|
||||
public string ClassName {get;set;}
|
||||
public string FunctionName {get;set;}
|
||||
public string Type {get;set;}
|
||||
public IEnumerable<Argument> ArgumentsList {get;set;}
|
||||
public IEnumerable<string> UsingStatements {get; set;}
|
||||
public IEnumerable<string> NugetDependencies {get; set;}
|
||||
|
||||
}
|
||||
#region Base class
|
||||
/// <summary>
|
||||
/// Base class for this transformation
|
||||
/// </summary>
|
||||
[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.VisualStudio.TextTemplating", "17.0.0.0")]
|
||||
internal class SweepableEstimatorBase
|
||||
{
|
||||
#region Fields
|
||||
private global::System.Text.StringBuilder generationEnvironmentField;
|
||||
private global::System.CodeDom.Compiler.CompilerErrorCollection errorsField;
|
||||
private global::System.Collections.Generic.List<int> indentLengthsField;
|
||||
private string currentIndentField = "";
|
||||
private bool endsWithNewline;
|
||||
private global::System.Collections.Generic.IDictionary<string, object> sessionField;
|
||||
#endregion
|
||||
#region Properties
|
||||
/// <summary>
|
||||
/// The string builder that generation-time code is using to assemble generated output
|
||||
/// </summary>
|
||||
protected System.Text.StringBuilder GenerationEnvironment
|
||||
{
|
||||
get
|
||||
{
|
||||
if ((this.generationEnvironmentField == null))
|
||||
{
|
||||
this.generationEnvironmentField = new global::System.Text.StringBuilder();
|
||||
}
|
||||
return this.generationEnvironmentField;
|
||||
}
|
||||
set
|
||||
{
|
||||
this.generationEnvironmentField = value;
|
||||
}
|
||||
}
|
||||
/// <summary>
|
||||
/// The error collection for the generation process
|
||||
/// </summary>
|
||||
public System.CodeDom.Compiler.CompilerErrorCollection Errors
|
||||
{
|
||||
get
|
||||
{
|
||||
if ((this.errorsField == null))
|
||||
{
|
||||
this.errorsField = new global::System.CodeDom.Compiler.CompilerErrorCollection();
|
||||
}
|
||||
return this.errorsField;
|
||||
}
|
||||
}
|
||||
/// <summary>
|
||||
/// A list of the lengths of each indent that was added with PushIndent
|
||||
/// </summary>
|
||||
private System.Collections.Generic.List<int> indentLengths
|
||||
{
|
||||
get
|
||||
{
|
||||
if ((this.indentLengthsField == null))
|
||||
{
|
||||
this.indentLengthsField = new global::System.Collections.Generic.List<int>();
|
||||
}
|
||||
return this.indentLengthsField;
|
||||
}
|
||||
}
|
||||
/// <summary>
|
||||
/// Gets the current indent we use when adding lines to the output
|
||||
/// </summary>
|
||||
public string CurrentIndent
|
||||
{
|
||||
get
|
||||
{
|
||||
return this.currentIndentField;
|
||||
}
|
||||
}
|
||||
/// <summary>
|
||||
/// Current transformation session
|
||||
/// </summary>
|
||||
public virtual global::System.Collections.Generic.IDictionary<string, object> Session
|
||||
{
|
||||
get
|
||||
{
|
||||
return this.sessionField;
|
||||
}
|
||||
set
|
||||
{
|
||||
this.sessionField = value;
|
||||
}
|
||||
}
|
||||
#endregion
|
||||
#region Transform-time helpers
|
||||
/// <summary>
|
||||
/// Write text directly into the generated output
|
||||
/// </summary>
|
||||
public void Write(string textToAppend)
|
||||
{
|
||||
if (string.IsNullOrEmpty(textToAppend))
|
||||
{
|
||||
return;
|
||||
}
|
||||
// If we're starting off, or if the previous text ended with a newline,
|
||||
// we have to append the current indent first.
|
||||
if (((this.GenerationEnvironment.Length == 0)
|
||||
|| this.endsWithNewline))
|
||||
{
|
||||
this.GenerationEnvironment.Append(this.currentIndentField);
|
||||
this.endsWithNewline = false;
|
||||
}
|
||||
// Check if the current text ends with a newline
|
||||
if (textToAppend.EndsWith(global::System.Environment.NewLine, global::System.StringComparison.CurrentCulture))
|
||||
{
|
||||
this.endsWithNewline = true;
|
||||
}
|
||||
// This is an optimization. If the current indent is "", then we don't have to do any
|
||||
// of the more complex stuff further down.
|
||||
if ((this.currentIndentField.Length == 0))
|
||||
{
|
||||
this.GenerationEnvironment.Append(textToAppend);
|
||||
return;
|
||||
}
|
||||
// Everywhere there is a newline in the text, add an indent after it
|
||||
textToAppend = textToAppend.Replace(global::System.Environment.NewLine, (global::System.Environment.NewLine + this.currentIndentField));
|
||||
// If the text ends with a newline, then we should strip off the indent added at the very end
|
||||
// because the appropriate indent will be added when the next time Write() is called
|
||||
if (this.endsWithNewline)
|
||||
{
|
||||
this.GenerationEnvironment.Append(textToAppend, 0, (textToAppend.Length - this.currentIndentField.Length));
|
||||
}
|
||||
else
|
||||
{
|
||||
this.GenerationEnvironment.Append(textToAppend);
|
||||
}
|
||||
}
|
||||
/// <summary>
|
||||
/// Write text directly into the generated output
|
||||
/// </summary>
|
||||
public void WriteLine(string textToAppend)
|
||||
{
|
||||
this.Write(textToAppend);
|
||||
this.GenerationEnvironment.AppendLine();
|
||||
this.endsWithNewline = true;
|
||||
}
|
||||
/// <summary>
|
||||
/// Write formatted text directly into the generated output
|
||||
/// </summary>
|
||||
public void Write(string format, params object[] args)
|
||||
{
|
||||
this.Write(string.Format(global::System.Globalization.CultureInfo.CurrentCulture, format, args));
|
||||
}
|
||||
/// <summary>
|
||||
/// Write formatted text directly into the generated output
|
||||
/// </summary>
|
||||
public void WriteLine(string format, params object[] args)
|
||||
{
|
||||
this.WriteLine(string.Format(global::System.Globalization.CultureInfo.CurrentCulture, format, args));
|
||||
}
|
||||
/// <summary>
|
||||
/// Raise an error
|
||||
/// </summary>
|
||||
public void Error(string message)
|
||||
{
|
||||
System.CodeDom.Compiler.CompilerError error = new global::System.CodeDom.Compiler.CompilerError();
|
||||
error.ErrorText = message;
|
||||
this.Errors.Add(error);
|
||||
}
|
||||
/// <summary>
|
||||
/// Raise a warning
|
||||
/// </summary>
|
||||
public void Warning(string message)
|
||||
{
|
||||
System.CodeDom.Compiler.CompilerError error = new global::System.CodeDom.Compiler.CompilerError();
|
||||
error.ErrorText = message;
|
||||
error.IsWarning = true;
|
||||
this.Errors.Add(error);
|
||||
}
|
||||
/// <summary>
|
||||
/// Increase the indent
|
||||
/// </summary>
|
||||
public void PushIndent(string indent)
|
||||
{
|
||||
if ((indent == null))
|
||||
{
|
||||
throw new global::System.ArgumentNullException("indent");
|
||||
}
|
||||
this.currentIndentField = (this.currentIndentField + indent);
|
||||
this.indentLengths.Add(indent.Length);
|
||||
}
|
||||
/// <summary>
|
||||
/// Remove the last indent that was added with PushIndent
|
||||
/// </summary>
|
||||
public string PopIndent()
|
||||
{
|
||||
string returnValue = "";
|
||||
if ((this.indentLengths.Count > 0))
|
||||
{
|
||||
int indentLength = this.indentLengths[(this.indentLengths.Count - 1)];
|
||||
this.indentLengths.RemoveAt((this.indentLengths.Count - 1));
|
||||
if ((indentLength > 0))
|
||||
{
|
||||
returnValue = this.currentIndentField.Substring((this.currentIndentField.Length - indentLength));
|
||||
this.currentIndentField = this.currentIndentField.Remove((this.currentIndentField.Length - indentLength));
|
||||
}
|
||||
}
|
||||
return returnValue;
|
||||
}
|
||||
/// <summary>
|
||||
/// Remove any indentation
|
||||
/// </summary>
|
||||
public void ClearIndent()
|
||||
{
|
||||
this.indentLengths.Clear();
|
||||
this.currentIndentField = "";
|
||||
}
|
||||
#endregion
|
||||
#region ToString Helpers
|
||||
/// <summary>
|
||||
/// Utility class to produce culture-oriented representation of an object as a string.
|
||||
/// </summary>
|
||||
public class ToStringInstanceHelper
|
||||
{
|
||||
private System.IFormatProvider formatProviderField = global::System.Globalization.CultureInfo.InvariantCulture;
|
||||
/// <summary>
|
||||
/// Gets or sets format provider to be used by ToStringWithCulture method.
|
||||
/// </summary>
|
||||
public System.IFormatProvider FormatProvider
|
||||
{
|
||||
get
|
||||
{
|
||||
return this.formatProviderField ;
|
||||
}
|
||||
set
|
||||
{
|
||||
if ((value != null))
|
||||
{
|
||||
this.formatProviderField = value;
|
||||
}
|
||||
}
|
||||
}
|
||||
/// <summary>
|
||||
/// This is called from the compile/run appdomain to convert objects within an expression block to a string
|
||||
/// </summary>
|
||||
public string ToStringWithCulture(object objectToConvert)
|
||||
{
|
||||
if ((objectToConvert == null))
|
||||
{
|
||||
throw new global::System.ArgumentNullException("objectToConvert");
|
||||
}
|
||||
System.Type t = objectToConvert.GetType();
|
||||
System.Reflection.MethodInfo method = t.GetMethod("ToString", new System.Type[] {
|
||||
typeof(System.IFormatProvider)});
|
||||
if ((method == null))
|
||||
{
|
||||
return objectToConvert.ToString();
|
||||
}
|
||||
else
|
||||
{
|
||||
return ((string)(method.Invoke(objectToConvert, new object[] {
|
||||
this.formatProviderField })));
|
||||
}
|
||||
}
|
||||
}
|
||||
private ToStringInstanceHelper toStringHelperField = new ToStringInstanceHelper();
|
||||
/// <summary>
|
||||
/// Helper to produce culture-oriented representation of an object as a string
|
||||
/// </summary>
|
||||
public ToStringInstanceHelper ToStringHelper
|
||||
{
|
||||
get
|
||||
{
|
||||
return this.toStringHelperField;
|
||||
}
|
||||
}
|
||||
#endregion
|
||||
}
|
||||
#endregion
|
||||
}
|
|
@ -0,0 +1,58 @@
|
|||
<#@ template language="C#" linePragmas="false" visibility = "internal"#>
|
||||
<#@ assembly name="System.Core" #>
|
||||
<#@ import namespace="System.Linq" #>
|
||||
<#@ import namespace="System.Text" #>
|
||||
<#@ import namespace="System.Collections.Generic" #>
|
||||
|
||||
using System.Collections.Generic;
|
||||
using Newtonsoft.Json;
|
||||
using SweepableEstimator = Microsoft.ML.AutoML.SweepableEstimator;
|
||||
using Microsoft.ML.AutoML.CodeGen;
|
||||
using ColorsOrder = Microsoft.ML.Transforms.Image.ImagePixelExtractingEstimator.ColorsOrder;
|
||||
using ColorBits = Microsoft.ML.Transforms.Image.ImagePixelExtractingEstimator.ColorBits;
|
||||
using ResizingKind = Microsoft.ML.Transforms.Image.ImageResizingEstimator.ResizingKind;
|
||||
using Anchor = Microsoft.ML.Transforms.Image.ImageResizingEstimator.Anchor;
|
||||
|
||||
namespace <#=NameSpace#>
|
||||
{
|
||||
internal partial class <#=ClassName#> : SweepableEstimator
|
||||
{
|
||||
public <#=ClassName#>()
|
||||
{
|
||||
this.EstimatorType = EstimatorType.<#=ClassName#>;
|
||||
}
|
||||
|
||||
<# foreach(var arg in ArgumentsList){
|
||||
var typeAttributeName = Utils.CapitalFirstLetter(arg.ArgumentType);
|
||||
var propertyName = Utils.CapitalFirstLetter(arg.ArgumentName);#>
|
||||
[<#=typeAttributeName#>]
|
||||
[JsonProperty(NullValueHandling=NullValueHandling.Ignore)]
|
||||
public string <#=propertyName#> { get; set; }
|
||||
|
||||
<#}#>
|
||||
internal override IEnumerable<string> CSharpUsingStatements
|
||||
{
|
||||
get => new string[] {<#=Utils.PrettyPrintListOfString(UsingStatements.Select(x => $"using {x};"))#>};
|
||||
}
|
||||
|
||||
internal override IEnumerable<string> NugetDependencies
|
||||
{
|
||||
get => new string[] {<#=Utils.PrettyPrintListOfString(NugetDependencies)#>};
|
||||
}
|
||||
|
||||
internal override string FunctionName
|
||||
{
|
||||
get => "<#=Utils.GetPrefix(Type)#>.<#=FunctionName#>";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
<#+
|
||||
public string NameSpace {get;set;}
|
||||
public string ClassName {get;set;}
|
||||
public string FunctionName {get;set;}
|
||||
public string Type {get;set;}
|
||||
public IEnumerable<Argument> ArgumentsList {get;set;}
|
||||
public IEnumerable<string> UsingStatements {get; set;}
|
||||
public IEnumerable<string> NugetDependencies {get; set;}
|
||||
#>
|
|
@ -0,0 +1,329 @@
|
|||
// ------------------------------------------------------------------------------
|
||||
// <auto-generated>
|
||||
// This code was generated by a tool.
|
||||
// Runtime Version: 17.0.0.0
|
||||
//
|
||||
// Changes to this file may cause incorrect behavior and will be lost if
|
||||
// the code is regenerated.
|
||||
// </auto-generated>
|
||||
// ------------------------------------------------------------------------------
|
||||
namespace Microsoft.ML.AutoML.SourceGenerator.Template
|
||||
{
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Collections.Generic;
|
||||
using System;
|
||||
|
||||
/// <summary>
|
||||
/// Class to produce the template output
|
||||
/// </summary>
|
||||
[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.VisualStudio.TextTemplating", "17.0.0.0")]
|
||||
internal partial class SweepableEstimatorFactory : SweepableEstimatorFactoryBase
|
||||
{
|
||||
/// <summary>
|
||||
/// Create the template output
|
||||
/// </summary>
|
||||
public virtual string TransformText()
|
||||
{
|
||||
this.Write("\r\nusing System;\r\nusing System.Collections.Generic;\r\nusing System.Text;\r\nusing New" +
|
||||
"tonsoft.Json;\r\nusing Newtonsoft.Json.Linq;\r\nusing Microsoft.ML.SearchSpace;\r\nusi" +
|
||||
"ng Microsoft.ML;\r\n\r\nnamespace ");
|
||||
this.Write(this.ToStringHelper.ToStringWithCulture(NameSpace));
|
||||
this.Write("\r\n{\r\n internal static class SweepableEstimatorFactory\r\n {\r\n");
|
||||
foreach((var estimator, var tOption) in EstimatorNames){
|
||||
this.Write(" public static ");
|
||||
this.Write(this.ToStringHelper.ToStringWithCulture(estimator));
|
||||
this.Write(" Create");
|
||||
this.Write(this.ToStringHelper.ToStringWithCulture(estimator));
|
||||
this.Write("(");
|
||||
this.Write(this.ToStringHelper.ToStringWithCulture(tOption));
|
||||
this.Write(" defaultOption, SearchSpace<");
|
||||
this.Write(this.ToStringHelper.ToStringWithCulture(tOption));
|
||||
this.Write("> searchSpace = null)\r\n {\r\n if(searchSpace == null){\r\n " +
|
||||
" searchSpace = new SearchSpace<");
|
||||
this.Write(this.ToStringHelper.ToStringWithCulture(tOption));
|
||||
this.Write(">(defaultOption);\r\n }\r\n\r\n return new ");
|
||||
this.Write(this.ToStringHelper.ToStringWithCulture(estimator));
|
||||
this.Write("(defaultOption, searchSpace);\r\n }\r\n\r\n");
|
||||
}
|
||||
this.Write(" }\r\n}\r\n\r\n");
|
||||
return this.GenerationEnvironment.ToString();
|
||||
}
|
||||
|
||||
public string NameSpace {get;set;}
|
||||
public IEnumerable<(string, string)> EstimatorNames {get;set;}
|
||||
|
||||
}
|
||||
#region Base class
|
||||
/// <summary>
|
||||
/// Base class for this transformation
|
||||
/// </summary>
|
||||
[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.VisualStudio.TextTemplating", "17.0.0.0")]
|
||||
internal class SweepableEstimatorFactoryBase
|
||||
{
|
||||
#region Fields
|
||||
private global::System.Text.StringBuilder generationEnvironmentField;
|
||||
private global::System.CodeDom.Compiler.CompilerErrorCollection errorsField;
|
||||
private global::System.Collections.Generic.List<int> indentLengthsField;
|
||||
private string currentIndentField = "";
|
||||
private bool endsWithNewline;
|
||||
private global::System.Collections.Generic.IDictionary<string, object> sessionField;
|
||||
#endregion
|
||||
#region Properties
|
||||
/// <summary>
|
||||
/// The string builder that generation-time code is using to assemble generated output
|
||||
/// </summary>
|
||||
protected System.Text.StringBuilder GenerationEnvironment
|
||||
{
|
||||
get
|
||||
{
|
||||
if ((this.generationEnvironmentField == null))
|
||||
{
|
||||
this.generationEnvironmentField = new global::System.Text.StringBuilder();
|
||||
}
|
||||
return this.generationEnvironmentField;
|
||||
}
|
||||
set
|
||||
{
|
||||
this.generationEnvironmentField = value;
|
||||
}
|
||||
}
|
||||
/// <summary>
|
||||
/// The error collection for the generation process
|
||||
/// </summary>
|
||||
public System.CodeDom.Compiler.CompilerErrorCollection Errors
|
||||
{
|
||||
get
|
||||
{
|
||||
if ((this.errorsField == null))
|
||||
{
|
||||
this.errorsField = new global::System.CodeDom.Compiler.CompilerErrorCollection();
|
||||
}
|
||||
return this.errorsField;
|
||||
}
|
||||
}
|
||||
/// <summary>
|
||||
/// A list of the lengths of each indent that was added with PushIndent
|
||||
/// </summary>
|
||||
private System.Collections.Generic.List<int> indentLengths
|
||||
{
|
||||
get
|
||||
{
|
||||
if ((this.indentLengthsField == null))
|
||||
{
|
||||
this.indentLengthsField = new global::System.Collections.Generic.List<int>();
|
||||
}
|
||||
return this.indentLengthsField;
|
||||
}
|
||||
}
|
||||
/// <summary>
|
||||
/// Gets the current indent we use when adding lines to the output
|
||||
/// </summary>
|
||||
public string CurrentIndent
|
||||
{
|
||||
get
|
||||
{
|
||||
return this.currentIndentField;
|
||||
}
|
||||
}
|
||||
/// <summary>
|
||||
/// Current transformation session
|
||||
/// </summary>
|
||||
public virtual global::System.Collections.Generic.IDictionary<string, object> Session
|
||||
{
|
||||
get
|
||||
{
|
||||
return this.sessionField;
|
||||
}
|
||||
set
|
||||
{
|
||||
this.sessionField = value;
|
||||
}
|
||||
}
|
||||
#endregion
|
||||
#region Transform-time helpers
|
||||
/// <summary>
|
||||
/// Write text directly into the generated output
|
||||
/// </summary>
|
||||
public void Write(string textToAppend)
|
||||
{
|
||||
if (string.IsNullOrEmpty(textToAppend))
|
||||
{
|
||||
return;
|
||||
}
|
||||
// If we're starting off, or if the previous text ended with a newline,
|
||||
// we have to append the current indent first.
|
||||
if (((this.GenerationEnvironment.Length == 0)
|
||||
|| this.endsWithNewline))
|
||||
{
|
||||
this.GenerationEnvironment.Append(this.currentIndentField);
|
||||
this.endsWithNewline = false;
|
||||
}
|
||||
// Check if the current text ends with a newline
|
||||
if (textToAppend.EndsWith(global::System.Environment.NewLine, global::System.StringComparison.CurrentCulture))
|
||||
{
|
||||
this.endsWithNewline = true;
|
||||
}
|
||||
// This is an optimization. If the current indent is "", then we don't have to do any
|
||||
// of the more complex stuff further down.
|
||||
if ((this.currentIndentField.Length == 0))
|
||||
{
|
||||
this.GenerationEnvironment.Append(textToAppend);
|
||||
return;
|
||||
}
|
||||
// Everywhere there is a newline in the text, add an indent after it
|
||||
textToAppend = textToAppend.Replace(global::System.Environment.NewLine, (global::System.Environment.NewLine + this.currentIndentField));
|
||||
// If the text ends with a newline, then we should strip off the indent added at the very end
|
||||
// because the appropriate indent will be added when the next time Write() is called
|
||||
if (this.endsWithNewline)
|
||||
{
|
||||
this.GenerationEnvironment.Append(textToAppend, 0, (textToAppend.Length - this.currentIndentField.Length));
|
||||
}
|
||||
else
|
||||
{
|
||||
this.GenerationEnvironment.Append(textToAppend);
|
||||
}
|
||||
}
|
||||
/// <summary>
|
||||
/// Write text directly into the generated output
|
||||
/// </summary>
|
||||
public void WriteLine(string textToAppend)
|
||||
{
|
||||
this.Write(textToAppend);
|
||||
this.GenerationEnvironment.AppendLine();
|
||||
this.endsWithNewline = true;
|
||||
}
|
||||
/// <summary>
|
||||
/// Write formatted text directly into the generated output
|
||||
/// </summary>
|
||||
public void Write(string format, params object[] args)
|
||||
{
|
||||
this.Write(string.Format(global::System.Globalization.CultureInfo.CurrentCulture, format, args));
|
||||
}
|
||||
/// <summary>
|
||||
/// Write formatted text directly into the generated output
|
||||
/// </summary>
|
||||
public void WriteLine(string format, params object[] args)
|
||||
{
|
||||
this.WriteLine(string.Format(global::System.Globalization.CultureInfo.CurrentCulture, format, args));
|
||||
}
|
||||
/// <summary>
|
||||
/// Raise an error
|
||||
/// </summary>
|
||||
public void Error(string message)
|
||||
{
|
||||
System.CodeDom.Compiler.CompilerError error = new global::System.CodeDom.Compiler.CompilerError();
|
||||
error.ErrorText = message;
|
||||
this.Errors.Add(error);
|
||||
}
|
||||
/// <summary>
|
||||
/// Raise a warning
|
||||
/// </summary>
|
||||
public void Warning(string message)
|
||||
{
|
||||
System.CodeDom.Compiler.CompilerError error = new global::System.CodeDom.Compiler.CompilerError();
|
||||
error.ErrorText = message;
|
||||
error.IsWarning = true;
|
||||
this.Errors.Add(error);
|
||||
}
|
||||
/// <summary>
|
||||
/// Increase the indent
|
||||
/// </summary>
|
||||
public void PushIndent(string indent)
|
||||
{
|
||||
if ((indent == null))
|
||||
{
|
||||
throw new global::System.ArgumentNullException("indent");
|
||||
}
|
||||
this.currentIndentField = (this.currentIndentField + indent);
|
||||
this.indentLengths.Add(indent.Length);
|
||||
}
|
||||
/// <summary>
|
||||
/// Remove the last indent that was added with PushIndent
|
||||
/// </summary>
|
||||
public string PopIndent()
|
||||
{
|
||||
string returnValue = "";
|
||||
if ((this.indentLengths.Count > 0))
|
||||
{
|
||||
int indentLength = this.indentLengths[(this.indentLengths.Count - 1)];
|
||||
this.indentLengths.RemoveAt((this.indentLengths.Count - 1));
|
||||
if ((indentLength > 0))
|
||||
{
|
||||
returnValue = this.currentIndentField.Substring((this.currentIndentField.Length - indentLength));
|
||||
this.currentIndentField = this.currentIndentField.Remove((this.currentIndentField.Length - indentLength));
|
||||
}
|
||||
}
|
||||
return returnValue;
|
||||
}
|
||||
/// <summary>
|
||||
/// Remove any indentation
|
||||
/// </summary>
|
||||
public void ClearIndent()
|
||||
{
|
||||
this.indentLengths.Clear();
|
||||
this.currentIndentField = "";
|
||||
}
|
||||
#endregion
|
||||
#region ToString Helpers
|
||||
/// <summary>
|
||||
/// Utility class to produce culture-oriented representation of an object as a string.
|
||||
/// </summary>
|
||||
public class ToStringInstanceHelper
|
||||
{
|
||||
private System.IFormatProvider formatProviderField = global::System.Globalization.CultureInfo.InvariantCulture;
|
||||
/// <summary>
|
||||
/// Gets or sets format provider to be used by ToStringWithCulture method.
|
||||
/// </summary>
|
||||
public System.IFormatProvider FormatProvider
|
||||
{
|
||||
get
|
||||
{
|
||||
return this.formatProviderField ;
|
||||
}
|
||||
set
|
||||
{
|
||||
if ((value != null))
|
||||
{
|
||||
this.formatProviderField = value;
|
||||
}
|
||||
}
|
||||
}
|
||||
/// <summary>
|
||||
/// This is called from the compile/run appdomain to convert objects within an expression block to a string
|
||||
/// </summary>
|
||||
public string ToStringWithCulture(object objectToConvert)
|
||||
{
|
||||
if ((objectToConvert == null))
|
||||
{
|
||||
throw new global::System.ArgumentNullException("objectToConvert");
|
||||
}
|
||||
System.Type t = objectToConvert.GetType();
|
||||
System.Reflection.MethodInfo method = t.GetMethod("ToString", new System.Type[] {
|
||||
typeof(System.IFormatProvider)});
|
||||
if ((method == null))
|
||||
{
|
||||
return objectToConvert.ToString();
|
||||
}
|
||||
else
|
||||
{
|
||||
return ((string)(method.Invoke(objectToConvert, new object[] {
|
||||
this.formatProviderField })));
|
||||
}
|
||||
}
|
||||
}
|
||||
private ToStringInstanceHelper toStringHelperField = new ToStringInstanceHelper();
|
||||
/// <summary>
|
||||
/// Helper to produce culture-oriented representation of an object as a string
|
||||
/// </summary>
|
||||
public ToStringInstanceHelper ToStringHelper
|
||||
{
|
||||
get
|
||||
{
|
||||
return this.toStringHelperField;
|
||||
}
|
||||
}
|
||||
#endregion
|
||||
}
|
||||
#endregion
|
||||
}
|
|
@ -0,0 +1,36 @@
|
|||
<#@ template language="C#" linePragmas="false" visibility = "internal" #>
|
||||
<#@ assembly name="System.Core" #>
|
||||
<#@ import namespace="System.Linq" #>
|
||||
<#@ import namespace="System.Text" #>
|
||||
<#@ import namespace="System.Collections.Generic" #>
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Text;
|
||||
using Newtonsoft.Json;
|
||||
using Newtonsoft.Json.Linq;
|
||||
using Microsoft.ML.SearchSpace;
|
||||
using Microsoft.ML;
|
||||
|
||||
namespace <#=NameSpace#>
|
||||
{
|
||||
internal static class SweepableEstimatorFactory
|
||||
{
|
||||
<# foreach((var estimator, var tOption) in EstimatorNames){#>
|
||||
public static <#=estimator#> Create<#=estimator#>(<#=tOption#> defaultOption, SearchSpace<<#=tOption#>> searchSpace = null)
|
||||
{
|
||||
if(searchSpace == null){
|
||||
searchSpace = new SearchSpace<<#=tOption#>>(defaultOption);
|
||||
}
|
||||
|
||||
return new <#=estimator#>(defaultOption, searchSpace);
|
||||
}
|
||||
|
||||
<#}#>
|
||||
}
|
||||
}
|
||||
|
||||
<#+
|
||||
public string NameSpace {get;set;}
|
||||
public IEnumerable<(string, string)> EstimatorNames {get;set;}
|
||||
#>
|
|
@ -0,0 +1,358 @@
|
|||
// ------------------------------------------------------------------------------
|
||||
// <auto-generated>
|
||||
// This code was generated by a tool.
|
||||
// Runtime Version: 17.0.0.0
|
||||
//
|
||||
// Changes to this file may cause incorrect behavior and will be lost if
|
||||
// the code is regenerated.
|
||||
// </auto-generated>
|
||||
// ------------------------------------------------------------------------------
|
||||
namespace Microsoft.ML.AutoML.SourceGenerator.Template
|
||||
{
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Collections.Generic;
|
||||
using System;
|
||||
|
||||
/// <summary>
|
||||
/// Class to produce the template output
|
||||
/// </summary>
|
||||
[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.VisualStudio.TextTemplating", "17.0.0.0")]
|
||||
internal partial class SweepableEstimator_T_ : SweepableEstimator_T_Base
|
||||
{
|
||||
/// <summary>
|
||||
/// Create the template output
|
||||
/// </summary>
|
||||
public virtual string TransformText()
|
||||
{
|
||||
this.Write(@"
|
||||
using System.Collections.Generic;
|
||||
using Newtonsoft.Json;
|
||||
using SweepableEstimator = Microsoft.ML.AutoML.SweepableEstimator;
|
||||
using Microsoft.ML.AutoML.CodeGen;
|
||||
using ColorsOrder = Microsoft.ML.Transforms.Image.ImagePixelExtractingEstimator.ColorsOrder;
|
||||
using ColorBits = Microsoft.ML.Transforms.Image.ImagePixelExtractingEstimator.ColorBits;
|
||||
using ResizingKind = Microsoft.ML.Transforms.Image.ImageResizingEstimator.ResizingKind;
|
||||
using Anchor = Microsoft.ML.Transforms.Image.ImageResizingEstimator.Anchor;
|
||||
using Microsoft.ML.SearchSpace;
|
||||
|
||||
namespace ");
|
||||
this.Write(this.ToStringHelper.ToStringWithCulture(NameSpace));
|
||||
this.Write("\r\n{\r\n internal partial class ");
|
||||
this.Write(this.ToStringHelper.ToStringWithCulture(ClassName));
|
||||
this.Write(" : SweepableEstimator<");
|
||||
this.Write(this.ToStringHelper.ToStringWithCulture(TOption));
|
||||
this.Write(">\r\n {\r\n public ");
|
||||
this.Write(this.ToStringHelper.ToStringWithCulture(ClassName));
|
||||
this.Write("(");
|
||||
this.Write(this.ToStringHelper.ToStringWithCulture(TOption));
|
||||
this.Write(" defaultOption, SearchSpace<");
|
||||
this.Write(this.ToStringHelper.ToStringWithCulture(TOption));
|
||||
this.Write("> searchSpace = null)\r\n {\r\n this.TParameter = defaultOption;\r\n " +
|
||||
" this.SearchSpace = searchSpace;\r\n this.EstimatorType = Est" +
|
||||
"imatorType.");
|
||||
this.Write(this.ToStringHelper.ToStringWithCulture(ClassName));
|
||||
this.Write(";\r\n }\r\n\r\n internal ");
|
||||
this.Write(this.ToStringHelper.ToStringWithCulture(ClassName));
|
||||
this.Write("()\r\n {\r\n this.EstimatorType = EstimatorType.");
|
||||
this.Write(this.ToStringHelper.ToStringWithCulture(ClassName));
|
||||
this.Write(";\r\n this.TParameter = new ");
|
||||
this.Write(this.ToStringHelper.ToStringWithCulture(TOption));
|
||||
this.Write("();\r\n }\r\n \r\n internal override IEnumerable<string> CSharpUsingSt" +
|
||||
"atements \r\n {\r\n get => new string[] {");
|
||||
this.Write(this.ToStringHelper.ToStringWithCulture(Utils.PrettyPrintListOfString(UsingStatements.Select(x => $"using {x};"))));
|
||||
this.Write("};\r\n }\r\n\r\n internal override IEnumerable<string> NugetDependencies\r" +
|
||||
"\n {\r\n get => new string[] {");
|
||||
this.Write(this.ToStringHelper.ToStringWithCulture(Utils.PrettyPrintListOfString(NugetDependencies)));
|
||||
this.Write("};\r\n }\r\n\r\n internal override string FunctionName \r\n {\r\n " +
|
||||
" get => \"");
|
||||
this.Write(this.ToStringHelper.ToStringWithCulture(Utils.GetPrefix(Type)));
|
||||
this.Write(".");
|
||||
this.Write(this.ToStringHelper.ToStringWithCulture(FunctionName));
|
||||
this.Write("\";\r\n }\r\n }\r\n}\r\n\r\n");
|
||||
return this.GenerationEnvironment.ToString();
|
||||
}
|
||||
|
||||
public string NameSpace {get;set;}
|
||||
public string ClassName {get;set;}
|
||||
public string FunctionName {get;set;}
|
||||
public string Type {get;set;}
|
||||
public IEnumerable<Argument> ArgumentsList {get;set;}
|
||||
public IEnumerable<string> UsingStatements {get; set;}
|
||||
public IEnumerable<string> NugetDependencies {get; set;}
|
||||
public string TOption {get; set;}
|
||||
|
||||
}
|
||||
#region Base class
|
||||
/// <summary>
|
||||
/// Base class for this transformation
|
||||
/// </summary>
|
||||
[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.VisualStudio.TextTemplating", "17.0.0.0")]
|
||||
internal class SweepableEstimator_T_Base
|
||||
{
|
||||
#region Fields
|
||||
private global::System.Text.StringBuilder generationEnvironmentField;
|
||||
private global::System.CodeDom.Compiler.CompilerErrorCollection errorsField;
|
||||
private global::System.Collections.Generic.List<int> indentLengthsField;
|
||||
private string currentIndentField = "";
|
||||
private bool endsWithNewline;
|
||||
private global::System.Collections.Generic.IDictionary<string, object> sessionField;
|
||||
#endregion
|
||||
#region Properties
|
||||
/// <summary>
|
||||
/// The string builder that generation-time code is using to assemble generated output
|
||||
/// </summary>
|
||||
protected System.Text.StringBuilder GenerationEnvironment
|
||||
{
|
||||
get
|
||||
{
|
||||
if ((this.generationEnvironmentField == null))
|
||||
{
|
||||
this.generationEnvironmentField = new global::System.Text.StringBuilder();
|
||||
}
|
||||
return this.generationEnvironmentField;
|
||||
}
|
||||
set
|
||||
{
|
||||
this.generationEnvironmentField = value;
|
||||
}
|
||||
}
|
||||
/// <summary>
|
||||
/// The error collection for the generation process
|
||||
/// </summary>
|
||||
public System.CodeDom.Compiler.CompilerErrorCollection Errors
|
||||
{
|
||||
get
|
||||
{
|
||||
if ((this.errorsField == null))
|
||||
{
|
||||
this.errorsField = new global::System.CodeDom.Compiler.CompilerErrorCollection();
|
||||
}
|
||||
return this.errorsField;
|
||||
}
|
||||
}
|
||||
/// <summary>
|
||||
/// A list of the lengths of each indent that was added with PushIndent
|
||||
/// </summary>
|
||||
private System.Collections.Generic.List<int> indentLengths
|
||||
{
|
||||
get
|
||||
{
|
||||
if ((this.indentLengthsField == null))
|
||||
{
|
||||
this.indentLengthsField = new global::System.Collections.Generic.List<int>();
|
||||
}
|
||||
return this.indentLengthsField;
|
||||
}
|
||||
}
|
||||
/// <summary>
|
||||
/// Gets the current indent we use when adding lines to the output
|
||||
/// </summary>
|
||||
public string CurrentIndent
|
||||
{
|
||||
get
|
||||
{
|
||||
return this.currentIndentField;
|
||||
}
|
||||
}
|
||||
/// <summary>
|
||||
/// Current transformation session
|
||||
/// </summary>
|
||||
public virtual global::System.Collections.Generic.IDictionary<string, object> Session
|
||||
{
|
||||
get
|
||||
{
|
||||
return this.sessionField;
|
||||
}
|
||||
set
|
||||
{
|
||||
this.sessionField = value;
|
||||
}
|
||||
}
|
||||
#endregion
|
||||
#region Transform-time helpers
|
||||
/// <summary>
|
||||
/// Write text directly into the generated output
|
||||
/// </summary>
|
||||
public void Write(string textToAppend)
|
||||
{
|
||||
if (string.IsNullOrEmpty(textToAppend))
|
||||
{
|
||||
return;
|
||||
}
|
||||
// If we're starting off, or if the previous text ended with a newline,
|
||||
// we have to append the current indent first.
|
||||
if (((this.GenerationEnvironment.Length == 0)
|
||||
|| this.endsWithNewline))
|
||||
{
|
||||
this.GenerationEnvironment.Append(this.currentIndentField);
|
||||
this.endsWithNewline = false;
|
||||
}
|
||||
// Check if the current text ends with a newline
|
||||
if (textToAppend.EndsWith(global::System.Environment.NewLine, global::System.StringComparison.CurrentCulture))
|
||||
{
|
||||
this.endsWithNewline = true;
|
||||
}
|
||||
// This is an optimization. If the current indent is "", then we don't have to do any
|
||||
// of the more complex stuff further down.
|
||||
if ((this.currentIndentField.Length == 0))
|
||||
{
|
||||
this.GenerationEnvironment.Append(textToAppend);
|
||||
return;
|
||||
}
|
||||
// Everywhere there is a newline in the text, add an indent after it
|
||||
textToAppend = textToAppend.Replace(global::System.Environment.NewLine, (global::System.Environment.NewLine + this.currentIndentField));
|
||||
// If the text ends with a newline, then we should strip off the indent added at the very end
|
||||
// because the appropriate indent will be added when the next time Write() is called
|
||||
if (this.endsWithNewline)
|
||||
{
|
||||
this.GenerationEnvironment.Append(textToAppend, 0, (textToAppend.Length - this.currentIndentField.Length));
|
||||
}
|
||||
else
|
||||
{
|
||||
this.GenerationEnvironment.Append(textToAppend);
|
||||
}
|
||||
}
|
||||
/// <summary>
|
||||
/// Write text directly into the generated output
|
||||
/// </summary>
|
||||
public void WriteLine(string textToAppend)
|
||||
{
|
||||
this.Write(textToAppend);
|
||||
this.GenerationEnvironment.AppendLine();
|
||||
this.endsWithNewline = true;
|
||||
}
|
||||
/// <summary>
|
||||
/// Write formatted text directly into the generated output
|
||||
/// </summary>
|
||||
public void Write(string format, params object[] args)
|
||||
{
|
||||
this.Write(string.Format(global::System.Globalization.CultureInfo.CurrentCulture, format, args));
|
||||
}
|
||||
/// <summary>
|
||||
/// Write formatted text directly into the generated output
|
||||
/// </summary>
|
||||
public void WriteLine(string format, params object[] args)
|
||||
{
|
||||
this.WriteLine(string.Format(global::System.Globalization.CultureInfo.CurrentCulture, format, args));
|
||||
}
|
||||
/// <summary>
|
||||
/// Raise an error
|
||||
/// </summary>
|
||||
public void Error(string message)
|
||||
{
|
||||
System.CodeDom.Compiler.CompilerError error = new global::System.CodeDom.Compiler.CompilerError();
|
||||
error.ErrorText = message;
|
||||
this.Errors.Add(error);
|
||||
}
|
||||
/// <summary>
|
||||
/// Raise a warning
|
||||
/// </summary>
|
||||
public void Warning(string message)
|
||||
{
|
||||
System.CodeDom.Compiler.CompilerError error = new global::System.CodeDom.Compiler.CompilerError();
|
||||
error.ErrorText = message;
|
||||
error.IsWarning = true;
|
||||
this.Errors.Add(error);
|
||||
}
|
||||
/// <summary>
|
||||
/// Increase the indent
|
||||
/// </summary>
|
||||
public void PushIndent(string indent)
|
||||
{
|
||||
if ((indent == null))
|
||||
{
|
||||
throw new global::System.ArgumentNullException("indent");
|
||||
}
|
||||
this.currentIndentField = (this.currentIndentField + indent);
|
||||
this.indentLengths.Add(indent.Length);
|
||||
}
|
||||
/// <summary>
|
||||
/// Remove the last indent that was added with PushIndent
|
||||
/// </summary>
|
||||
public string PopIndent()
|
||||
{
|
||||
string returnValue = "";
|
||||
if ((this.indentLengths.Count > 0))
|
||||
{
|
||||
int indentLength = this.indentLengths[(this.indentLengths.Count - 1)];
|
||||
this.indentLengths.RemoveAt((this.indentLengths.Count - 1));
|
||||
if ((indentLength > 0))
|
||||
{
|
||||
returnValue = this.currentIndentField.Substring((this.currentIndentField.Length - indentLength));
|
||||
this.currentIndentField = this.currentIndentField.Remove((this.currentIndentField.Length - indentLength));
|
||||
}
|
||||
}
|
||||
return returnValue;
|
||||
}
|
||||
/// <summary>
|
||||
/// Remove any indentation
|
||||
/// </summary>
|
||||
public void ClearIndent()
|
||||
{
|
||||
this.indentLengths.Clear();
|
||||
this.currentIndentField = "";
|
||||
}
|
||||
#endregion
|
||||
#region ToString Helpers
|
||||
/// <summary>
|
||||
/// Utility class to produce culture-oriented representation of an object as a string.
|
||||
/// </summary>
|
||||
public class ToStringInstanceHelper
|
||||
{
|
||||
private System.IFormatProvider formatProviderField = global::System.Globalization.CultureInfo.InvariantCulture;
|
||||
/// <summary>
|
||||
/// Gets or sets format provider to be used by ToStringWithCulture method.
|
||||
/// </summary>
|
||||
public System.IFormatProvider FormatProvider
|
||||
{
|
||||
get
|
||||
{
|
||||
return this.formatProviderField ;
|
||||
}
|
||||
set
|
||||
{
|
||||
if ((value != null))
|
||||
{
|
||||
this.formatProviderField = value;
|
||||
}
|
||||
}
|
||||
}
|
||||
/// <summary>
|
||||
/// This is called from the compile/run appdomain to convert objects within an expression block to a string
|
||||
/// </summary>
|
||||
public string ToStringWithCulture(object objectToConvert)
|
||||
{
|
||||
if ((objectToConvert == null))
|
||||
{
|
||||
throw new global::System.ArgumentNullException("objectToConvert");
|
||||
}
|
||||
System.Type t = objectToConvert.GetType();
|
||||
System.Reflection.MethodInfo method = t.GetMethod("ToString", new System.Type[] {
|
||||
typeof(System.IFormatProvider)});
|
||||
if ((method == null))
|
||||
{
|
||||
return objectToConvert.ToString();
|
||||
}
|
||||
else
|
||||
{
|
||||
return ((string)(method.Invoke(objectToConvert, new object[] {
|
||||
this.formatProviderField })));
|
||||
}
|
||||
}
|
||||
}
|
||||
private ToStringInstanceHelper toStringHelperField = new ToStringInstanceHelper();
|
||||
/// <summary>
|
||||
/// Helper to produce culture-oriented representation of an object as a string
|
||||
/// </summary>
|
||||
public ToStringInstanceHelper ToStringHelper
|
||||
{
|
||||
get
|
||||
{
|
||||
return this.toStringHelperField;
|
||||
}
|
||||
}
|
||||
#endregion
|
||||
}
|
||||
#endregion
|
||||
}
|
|
@ -0,0 +1,60 @@
|
|||
<#@ template language="C#" linePragmas="false" visibility = "internal"#>
|
||||
<#@ assembly name="System.Core" #>
|
||||
<#@ import namespace="System.Linq" #>
|
||||
<#@ import namespace="System.Text" #>
|
||||
<#@ import namespace="System.Collections.Generic" #>
|
||||
|
||||
using System.Collections.Generic;
|
||||
using Newtonsoft.Json;
|
||||
using SweepableEstimator = Microsoft.ML.AutoML.SweepableEstimator;
|
||||
using Microsoft.ML.AutoML.CodeGen;
|
||||
using ColorsOrder = Microsoft.ML.Transforms.Image.ImagePixelExtractingEstimator.ColorsOrder;
|
||||
using ColorBits = Microsoft.ML.Transforms.Image.ImagePixelExtractingEstimator.ColorBits;
|
||||
using ResizingKind = Microsoft.ML.Transforms.Image.ImageResizingEstimator.ResizingKind;
|
||||
using Anchor = Microsoft.ML.Transforms.Image.ImageResizingEstimator.Anchor;
|
||||
using Microsoft.ML.SearchSpace;
|
||||
|
||||
namespace <#=NameSpace#>
|
||||
{
|
||||
internal partial class <#=ClassName#> : SweepableEstimator<<#=TOption#>>
|
||||
{
|
||||
public <#=ClassName#>(<#=TOption#> defaultOption, SearchSpace<<#=TOption#>> searchSpace = null)
|
||||
{
|
||||
this.TParameter = defaultOption;
|
||||
this.SearchSpace = searchSpace;
|
||||
this.EstimatorType = EstimatorType.<#=ClassName#>;
|
||||
}
|
||||
|
||||
internal <#=ClassName#>()
|
||||
{
|
||||
this.EstimatorType = EstimatorType.<#=ClassName#>;
|
||||
this.TParameter = new <#=TOption#>();
|
||||
}
|
||||
|
||||
internal override IEnumerable<string> CSharpUsingStatements
|
||||
{
|
||||
get => new string[] {<#=Utils.PrettyPrintListOfString(UsingStatements.Select(x => $"using {x};"))#>};
|
||||
}
|
||||
|
||||
internal override IEnumerable<string> NugetDependencies
|
||||
{
|
||||
get => new string[] {<#=Utils.PrettyPrintListOfString(NugetDependencies)#>};
|
||||
}
|
||||
|
||||
internal override string FunctionName
|
||||
{
|
||||
get => "<#=Utils.GetPrefix(Type)#>.<#=FunctionName#>";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
<#+
|
||||
public string NameSpace {get;set;}
|
||||
public string ClassName {get;set;}
|
||||
public string FunctionName {get;set;}
|
||||
public string Type {get;set;}
|
||||
public IEnumerable<Argument> ArgumentsList {get;set;}
|
||||
public IEnumerable<string> UsingStatements {get; set;}
|
||||
public IEnumerable<string> NugetDependencies {get; set;}
|
||||
public string TOption {get; set;}
|
||||
#>
|
|
@ -2,6 +2,7 @@
|
|||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
|
@ -70,5 +71,59 @@ namespace Microsoft.ML.AutoML.SourceGenerator
|
|||
{
|
||||
return string.Join(string.Empty, str.Split('_', ' ', '-').Select(x => CapitalFirstLetter(x)));
|
||||
}
|
||||
|
||||
public static string GetPrefix(string estimatorType)
|
||||
{
|
||||
if (estimatorType == "BinaryClassification")
|
||||
{
|
||||
return "BinaryClassification.Trainers";
|
||||
}
|
||||
if (estimatorType == "MultiClassification")
|
||||
{
|
||||
return "MulticlassClassification.Trainers";
|
||||
}
|
||||
if (estimatorType == "Regression")
|
||||
{
|
||||
return "Regression.Trainers";
|
||||
}
|
||||
if (estimatorType == "Ranking")
|
||||
{
|
||||
return "Ranking.Trainers";
|
||||
}
|
||||
if (estimatorType == "OneVersusAll")
|
||||
{
|
||||
return "BinaryClassification.Trainers";
|
||||
}
|
||||
if (estimatorType == "Recommendation")
|
||||
{
|
||||
return "Recommendation().Trainers";
|
||||
}
|
||||
if (estimatorType == "Transforms")
|
||||
{
|
||||
return "Transforms";
|
||||
}
|
||||
if (estimatorType == "Categorical")
|
||||
{
|
||||
return "Transforms.Categorical";
|
||||
}
|
||||
if (estimatorType == "Conversion")
|
||||
{
|
||||
return "Transforms.Conversion";
|
||||
}
|
||||
if (estimatorType == "Text")
|
||||
{
|
||||
return "Transforms.Text";
|
||||
}
|
||||
if (estimatorType == "Calibrators")
|
||||
{
|
||||
return "BinaryClassification.Calibrators";
|
||||
}
|
||||
if (estimatorType == "Forecasting")
|
||||
{
|
||||
return "Forecasting";
|
||||
}
|
||||
|
||||
throw new NotImplementedException();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче