зеркало из https://github.com/dotnet/infer.git
Add helper for faster creation of objects of generic type
This commit is contained in:
Родитель
68212a02f3
Коммит
cac3556389
|
@ -6,16 +6,13 @@ using System;
|
|||
using System.Collections.Generic;
|
||||
using System.Reflection;
|
||||
using System.Text;
|
||||
using System.IO;
|
||||
using System.Globalization;
|
||||
using Microsoft.ML.Probabilistic.Collections;
|
||||
using Microsoft.ML.Probabilistic.Math;
|
||||
using System.Xml;
|
||||
using System.Xml.Schema;
|
||||
|
||||
namespace Microsoft.ML.Probabilistic.Utilities
|
||||
{
|
||||
using System.Linq;
|
||||
using System.Linq.Expressions;
|
||||
|
||||
#if SUPPRESS_XMLDOC_WARNINGS
|
||||
#pragma warning disable 1591
|
||||
|
@ -464,6 +461,23 @@ namespace Microsoft.ML.Probabilistic.Utilities
|
|||
|
||||
yield break;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// A faster versio of `new T()` when T is a generic type parameter.
|
||||
/// See: https://stackoverflow.com/a/1280832
|
||||
/// </summary>
|
||||
public static T New<T>()
|
||||
where T : new()
|
||||
{
|
||||
return NewFuncCache<T>.NewFunc();
|
||||
}
|
||||
|
||||
private static class NewFuncCache<T>
|
||||
where T : new()
|
||||
{
|
||||
public static Expression<Func<T>> NewExpression = () => new T();
|
||||
public static Func<T> NewFunc = NewExpression.Compile();
|
||||
}
|
||||
}
|
||||
|
||||
#if SUPPRESS_XMLDOC_WARNINGS
|
||||
|
|
|
@ -9,6 +9,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
using System.Diagnostics;
|
||||
|
||||
using Microsoft.ML.Probabilistic.Collections;
|
||||
using Microsoft.ML.Probabilistic.Utilities;
|
||||
|
||||
public abstract partial class Automaton<TSequence, TElement, TElementDistribution, TSequenceManipulator, TThis>
|
||||
{
|
||||
|
@ -428,7 +429,12 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
/// <summary>
|
||||
/// Builds new automaton object. Builder must not be used after this method is called
|
||||
/// </summary>
|
||||
public TThis GetAutomaton() => new TThis() { Data = this.GetData() };
|
||||
public TThis GetAutomaton()
|
||||
{
|
||||
var result = Util.New<TThis>();
|
||||
result.Data = this.GetData();
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Stores built automaton in pre-allocated <see cref="Automaton{TSequence,TElement,TElementDistribution,TSequenceManipulator,TThis}"/> object.
|
||||
|
|
|
@ -321,7 +321,9 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
builder.Start.AddTransition(allowedElements, Weight.FromLogValue(-allowedElements.GetLogAverageOf(allowedElements)), builder.StartStateIndex);
|
||||
}
|
||||
|
||||
return new TThis() { Data = builder.GetData(true) };
|
||||
var result = Util.New<TThis>();
|
||||
result.Data = builder.GetData(true);
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -348,15 +350,17 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
{
|
||||
Argument.CheckIfNotNull(allowedElements, nameof(allowedElements));
|
||||
|
||||
var result = new Builder();
|
||||
var builder = new Builder();
|
||||
if (!double.IsNegativeInfinity(logValue))
|
||||
{
|
||||
allowedElements = allowedElements.CreatePartialUniform();
|
||||
var finish = result.Start.AddTransition(allowedElements, Weight.FromLogValue(-allowedElements.GetLogAverageOf(allowedElements)));
|
||||
var finish = builder.Start.AddTransition(allowedElements, Weight.FromLogValue(-allowedElements.GetLogAverageOf(allowedElements)));
|
||||
finish.SetEndWeight(Weight.FromLogValue(logValue));
|
||||
}
|
||||
|
||||
return new TThis() { Data = result.GetData(true) };
|
||||
var result = Util.New<TThis>();
|
||||
result.Data = builder.GetData(true);
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -400,19 +404,21 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
{
|
||||
Argument.CheckIfNotNull(sequences, "sequences");
|
||||
|
||||
var result = new Builder();
|
||||
var builder = new Builder();
|
||||
int sequenceCount = 0;
|
||||
if (!double.IsNegativeInfinity(logValue))
|
||||
{
|
||||
foreach (var sequence in sequences)
|
||||
{
|
||||
var sequenceEndState = result.Start.AddTransitionsForSequence(sequence);
|
||||
var sequenceEndState = builder.Start.AddTransitionsForSequence(sequence);
|
||||
sequenceEndState.SetEndWeight(Weight.FromLogValue(logValue));
|
||||
++sequenceCount;
|
||||
}
|
||||
}
|
||||
|
||||
return new TThis() { Data = result.GetData(sequenceCount <= 1 ? (bool?)true : null) };
|
||||
var result = Util.New<TThis>();
|
||||
result.Data = builder.GetData(sequenceCount <= 1 ? (bool?)true : null);
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -1262,7 +1268,8 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
}
|
||||
}
|
||||
|
||||
automaton = new TThis() { Data = result.GetData(true) };
|
||||
automaton = Util.New<TThis>();
|
||||
automaton.Data = result.GetData(true);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -1477,7 +1484,8 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
var bothInputsDeterminized = automaton1.Data.IsDeterminized == true && automaton2.Data.IsDeterminized == true;
|
||||
var determinizationState = bothInputsDeterminized ? (bool?)true : null;
|
||||
|
||||
var result = new TThis() { Data = builder.GetData(determinizationState) };
|
||||
var result = Util.New<TThis>();
|
||||
result.Data = builder.GetData(determinizationState);
|
||||
if (determinizationState != true && result is StringAutomaton && tryDeterminize)
|
||||
{
|
||||
result = result.TryDeterminize();
|
||||
|
@ -2141,36 +2149,42 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
|
||||
#region Helpers
|
||||
|
||||
protected TThis WithData(DataContainer data) => new TThis()
|
||||
protected TThis WithData(DataContainer data)
|
||||
{
|
||||
Data = data,
|
||||
LogValueOverride = LogValueOverride,
|
||||
PruneStatesWithLogEndWeightLessThan = PruneStatesWithLogEndWeightLessThan
|
||||
};
|
||||
var result = Util.New<TThis>();
|
||||
result.Data = data;
|
||||
result.LogValueOverride = LogValueOverride;
|
||||
result.PruneStatesWithLogEndWeightLessThan = PruneStatesWithLogEndWeightLessThan;
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Creates a copy of the current automaton with a different value of <see cref="LogValueOverride"/>.
|
||||
/// </summary>
|
||||
/// <param name="logValueOverride">New <see cref="LogValueOverride"/>.</param>
|
||||
/// <returns>The created automaton.</returns>
|
||||
public TThis WithLogValueOverride(double? logValueOverride) => new TThis()
|
||||
public TThis WithLogValueOverride(double? logValueOverride)
|
||||
{
|
||||
Data = Data,
|
||||
LogValueOverride = logValueOverride,
|
||||
PruneStatesWithLogEndWeightLessThan = PruneStatesWithLogEndWeightLessThan
|
||||
};
|
||||
var result = Util.New<TThis>();
|
||||
result.Data = Data;
|
||||
result.LogValueOverride = logValueOverride;
|
||||
result.PruneStatesWithLogEndWeightLessThan = PruneStatesWithLogEndWeightLessThan;
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Creates a copy of the current automaton with a different value of <see cref="PruneStatesWithLogEndWeightLessThan"/>.
|
||||
/// </summary>
|
||||
/// <param name="pruneStatesWithLogEndWeightLessThan">New <see cref="PruneStatesWithLogEndWeightLessThan"/>.</param>
|
||||
/// <returns>The created automaton.</returns>
|
||||
public TThis WithPruneStatesWithLogEndWeightLessThan(double? pruneStatesWithLogEndWeightLessThan) => new TThis()
|
||||
public TThis WithPruneStatesWithLogEndWeightLessThan(double? pruneStatesWithLogEndWeightLessThan)
|
||||
{
|
||||
Data = Data,
|
||||
LogValueOverride = LogValueOverride,
|
||||
PruneStatesWithLogEndWeightLessThan = pruneStatesWithLogEndWeightLessThan
|
||||
};
|
||||
var result = Util.New<TThis>();
|
||||
result.Data = Data;
|
||||
result.LogValueOverride = LogValueOverride;
|
||||
result.PruneStatesWithLogEndWeightLessThan = pruneStatesWithLogEndWeightLessThan;
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets a value indicating how close two given automata are
|
||||
|
|
|
@ -2,15 +2,12 @@
|
|||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
using Microsoft.ML.Probabilistic.Collections;
|
||||
|
||||
namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
||||
{
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Diagnostics;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
|
||||
|
||||
/// <summary>
|
||||
/// Represents a weighted finite state automaton defined on <see cref="string"/>.
|
||||
|
|
|
@ -59,7 +59,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
/// <param name="sequenceWeightPairs">The collection of pairs of a sequence and the weight on that sequence.</param>
|
||||
public static TThis FromWeights(IEnumerable<KeyValuePair<TSequence, Weight>> sequenceWeightPairs)
|
||||
{
|
||||
var result = new TThis();
|
||||
var result = Util.New<TThis>();
|
||||
result.SetWeights(sequenceWeightPairs);
|
||||
return result;
|
||||
|
||||
|
@ -73,7 +73,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
[Construction(nameof(Dictionary))]
|
||||
public static TThis FromDistinctWeights(IEnumerable<KeyValuePair<TSequence, Weight>> sequenceWeightPairs)
|
||||
{
|
||||
var result = new TThis();
|
||||
var result = Util.New<TThis>();
|
||||
result.SetDistinctWeights(sequenceWeightPairs);
|
||||
return result;
|
||||
}
|
||||
|
@ -100,7 +100,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
/// <param name="point">The only sequence contained in the dictionary.</param>
|
||||
public static TThis FromPoint(TSequence point)
|
||||
{
|
||||
var result = new TThis();
|
||||
var result = Util.New<TThis>();
|
||||
result.SetDistinctWeights(new[] { new KeyValuePair<TSequence, Weight>(point, Weight.One) });
|
||||
return result;
|
||||
}
|
||||
|
|
|
@ -65,7 +65,7 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
|||
/// A function mapping sequences to weights (non-normalized probabilities).
|
||||
/// </summary>
|
||||
[DataMember]
|
||||
private TWeightFunction sequenceToWeight = new TWeightFunction();
|
||||
private TWeightFunction sequenceToWeight = default;
|
||||
|
||||
/// <summary>
|
||||
/// Specifies whether the <see cref="sequenceToWeight"/> is normalized.
|
||||
|
@ -142,7 +142,9 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
|||
{
|
||||
Argument.CheckIfNotNull(point, "point", "Point mass must not be null.");
|
||||
|
||||
return new TThis { Point = point };
|
||||
var result = Util.New<TThis>();
|
||||
result.Point = point;
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -153,7 +155,7 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
|||
[Skip]
|
||||
public static TThis Uniform()
|
||||
{
|
||||
var result = new TThis();
|
||||
var result = Util.New<TThis>();
|
||||
result.SetToUniform();
|
||||
return result;
|
||||
}
|
||||
|
@ -165,7 +167,7 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
|||
[Construction(UseWhen = "IsZero")]
|
||||
public static TThis Zero()
|
||||
{
|
||||
var result = new TThis();
|
||||
var result = Util.New<TThis>();
|
||||
result.SetToZero();
|
||||
return result;
|
||||
}
|
||||
|
@ -180,7 +182,7 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
|||
{
|
||||
Argument.CheckIfNotNull(weightFunction, nameof(weightFunction));
|
||||
|
||||
var result = new TThis();
|
||||
var result = Util.New<TThis>();
|
||||
result.SetWeightFunction(weightFunction);
|
||||
return result;
|
||||
}
|
||||
|
@ -293,7 +295,7 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
|||
}
|
||||
|
||||
var probFunctions = enumerable.Select(d => d.sequenceToWeight);
|
||||
var result = new TThis();
|
||||
var result = Util.New<TThis>();
|
||||
result.sequenceToWeight = WeightFunctionFactory.Sum(probFunctions).NormalizeStructure();
|
||||
return result;
|
||||
}
|
||||
|
@ -318,7 +320,7 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
|||
/// <returns>The created mixture distribution.</returns>
|
||||
public static TThis OneOf(double weight1, TThis dist1, double weight2, TThis dist2)
|
||||
{
|
||||
var result = new TThis();
|
||||
var result = Util.New<TThis>();
|
||||
result.SetToSum(weight1, dist1, weight2, dist2);
|
||||
return result;
|
||||
}
|
||||
|
@ -331,7 +333,7 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
|||
/// <returns>The created distribution.</returns>
|
||||
public static TThis OneOf(IEnumerable<KeyValuePair<TSequence, double>> sequenceProbPairs)
|
||||
{
|
||||
var result = new TThis();
|
||||
var result = Util.New<TThis>();
|
||||
result.sequenceToWeight = WeightFunctionFactory.FromValues(sequenceProbPairs).NormalizeStructure();
|
||||
return result;
|
||||
}
|
||||
|
@ -1111,9 +1113,9 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
|||
{
|
||||
Argument.CheckIfNotNull(that, "that");
|
||||
|
||||
var auto = new TThis();
|
||||
auto.SetToProduct((TThis)this, that);
|
||||
return auto;
|
||||
var result = Util.New<TThis>();
|
||||
result.SetToProduct((TThis)this, that);
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -1157,8 +1159,8 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
|||
{
|
||||
Argument.CheckIfNotNull(that, nameof(that));
|
||||
|
||||
var temp = new TThis();
|
||||
return temp.SetToProductAndReturnLogNormalizer((TThis)this, that, false);
|
||||
var result = Util.New<TThis>();
|
||||
return result.SetToProductAndReturnLogNormalizer((TThis)this, that, false);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -1319,7 +1321,7 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
|||
/// <returns>The created copy.</returns>
|
||||
public TThis Clone()
|
||||
{
|
||||
var result = new TThis();
|
||||
var result = Util.New<TThis>();
|
||||
result.SetTo((TThis)this);
|
||||
return result;
|
||||
}
|
||||
|
@ -1619,7 +1621,7 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
|||
/// <returns>The created distribution.</returns>
|
||||
protected static TThis UniformOf(TElementDistribution allowedElements, double uniformLogProb)
|
||||
{
|
||||
var result = new TThis();
|
||||
var result = Util.New<TThis>();
|
||||
result.SetToUniformOf(allowedElements, uniformLogProb);
|
||||
return result;
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче