зеркало из https://github.com/dotnet/infer.git
Simplify discrete char implementation (#361)
There are two changes: 1. (major) `LogProbabilityOverride` was removed from `DiscreteChar` and `ImmutableDiscreteChar`. This was unsound functionality that was not used and was complicating the implementation of char distributions. 2. (minor) `ImmutableDiscreteChar.Multiply` may now reuse an existing immutable discrete char in more cases out of the box. This reduces GC pressure a little.
This commit is contained in:
Родитель
abae518039
Коммит
8370c9e8e7
|
@ -21,19 +21,6 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
public StringAutomaton()
|
||||
{
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Whether there are log value overrides at the element level.
|
||||
/// </summary>
|
||||
public bool HasElementLogValueOverrides
|
||||
{
|
||||
get
|
||||
{
|
||||
return this.States.transitions.Any(
|
||||
trans => trans.ElementDistribution.HasValue &&
|
||||
trans.ElementDistribution.Value.HasLogProbabilityOverride);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Computes a set of outgoing transitions from a given state of the determinization result.
|
||||
|
|
|
@ -379,7 +379,6 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
// Populate the stack with start destination state
|
||||
result.StartStateIndex = CreateDestState(mappingAutomaton.Start, srcAutomaton.Start);
|
||||
var stringAutomaton = srcAutomaton as StringAutomaton;
|
||||
var sourceDistributionHasLogProbabilityOverrides = stringAutomaton?.HasElementLogValueOverrides ?? false;
|
||||
|
||||
while (stack.Count > 0)
|
||||
{
|
||||
|
@ -420,25 +419,10 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
continue;
|
||||
}
|
||||
|
||||
// In the special case of a log probability override in an ImmutableDiscreteChar element distribution,
|
||||
// we need to compensate for the fact that the distribution is not normalized.
|
||||
if (destElementDistribution.HasValue && sourceDistributionHasLogProbabilityOverrides)
|
||||
{
|
||||
var discreteChar =
|
||||
(ImmutableDiscreteChar)(IImmutableDistribution<char, ImmutableDiscreteChar>)srcTransition.ElementDistribution.Value;
|
||||
if (discreteChar.HasLogProbabilityOverride)
|
||||
{
|
||||
var totalMass = discreteChar.Ranges.EnumerableSum(rng =>
|
||||
rng.Probability.Value * (rng.EndExclusive - rng.StartInclusive));
|
||||
projectionLogScale -= System.Math.Log(totalMass);
|
||||
}
|
||||
}
|
||||
|
||||
var destWeight =
|
||||
sourceDistributionHasLogProbabilityOverrides && destElementDistribution.HasNoValue
|
||||
? Weight.One
|
||||
: Weight.Product(mappingTransition.Weight, srcTransition.Weight,
|
||||
Weight.FromLogValue(projectionLogScale));
|
||||
var destWeight = Weight.Product(
|
||||
mappingTransition.Weight,
|
||||
srcTransition.Weight,
|
||||
Weight.FromLogValue(projectionLogScale));
|
||||
|
||||
// We don't want an unnormalizable distribution to become normalizable due to a rounding error.
|
||||
if (Math.Abs(destWeight.LogValue) < 1e-12)
|
||||
|
|
|
@ -147,15 +147,6 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
|||
/// </summary>
|
||||
public bool IsWordChar => WrappedDistribution.IsWordChar;
|
||||
|
||||
/// <inheritdoc cref="ImmutableDiscreteChar.IsBroad"/>
|
||||
public bool IsBroad => WrappedDistribution.IsBroad;
|
||||
|
||||
/// <inheritdoc cref="ImmutableDiscreteChar.HasLogProbabilityOverride"/>
|
||||
public bool HasLogProbabilityOverride => WrappedDistribution.HasLogProbabilityOverride;
|
||||
|
||||
/// <inheritdoc cref="ImmutableDiscreteChar.LogProbabilityOverride"/>
|
||||
public double? LogProbabilityOverride => WrappedDistribution.LogProbabilityOverride;
|
||||
|
||||
#endregion
|
||||
|
||||
/// <summary>
|
||||
|
@ -541,21 +532,6 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
|||
/// <inheritdoc cref="ImmutableDiscreteChar.GetProbs"/>
|
||||
public PiecewiseVector GetProbs() => WrappedDistribution.GetProbs();
|
||||
|
||||
/// <summary>
|
||||
/// Sets the distribution to be uniform over the support of a given distribution.
|
||||
/// </summary>
|
||||
/// <param name="distribution">The distribution whose support will be used to setup the current distribution.</param>
|
||||
/// <param name="logProbabilityOverride">An optional value to override for the log probability calculation
|
||||
/// against this distribution. If this is set, then the distribution will not be normalize;
|
||||
/// i.e. the probabilities will not sum to 1 over the support.</param>
|
||||
/// <remarks>Overriding the log probability calculation in this way is useful within the context of using <see cref="StringDistribution"/>
|
||||
/// to create more realistic language model priors. Distributions with this override are always uncached.
|
||||
/// </remarks>
|
||||
public void SetToPartialUniformOf(DiscreteChar distribution, double? logProbabilityOverride)
|
||||
{
|
||||
WrappedDistribution = distribution.WrappedDistribution.CreatePartialUniform(logProbabilityOverride);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region Serialization
|
||||
|
@ -705,22 +681,6 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
|||
/// </summary>
|
||||
public bool IsWordChar => this.Data.IsWordChar;
|
||||
|
||||
/// <summary>
|
||||
/// Gets a value indicating whether this distribution is broad - i.e. a general class of values.
|
||||
/// </summary>
|
||||
public bool IsBroad => this.Data.IsBroad;
|
||||
|
||||
/// <summary>
|
||||
/// Gets a value indicating whether this distribution is partial uniform with a log probability override.
|
||||
/// </summary>
|
||||
public bool HasLogProbabilityOverride => this.Data.HasLogProbabilityOverride;
|
||||
|
||||
/// <summary>
|
||||
/// Gets a value for the log probability override if it exists.
|
||||
/// </summary>
|
||||
public double? LogProbabilityOverride =>
|
||||
this.HasLogProbabilityOverride ? this.Ranges.First().Probability.LogValue : (double?)null;
|
||||
|
||||
#endregion
|
||||
|
||||
#region Distribution properties
|
||||
|
@ -1039,9 +999,9 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
|||
/// <inheritdoc/>
|
||||
public ImmutableDiscreteChar Multiply(ImmutableDiscreteChar other)
|
||||
{
|
||||
if (IsPointMass && other.IsPointMass)
|
||||
if (IsPointMass)
|
||||
{
|
||||
if (Point != other.Point)
|
||||
if (other.FindProb(Point) == Weight.Zero)
|
||||
{
|
||||
throw new AllZeroException("A character distribution that is zero everywhere has been produced.");
|
||||
}
|
||||
|
@ -1049,6 +1009,16 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
|||
return this;
|
||||
}
|
||||
|
||||
if (other.IsPointMass)
|
||||
{
|
||||
if (FindProb(other.Point) == Weight.Zero)
|
||||
{
|
||||
throw new AllZeroException("A character distribution that is zero everywhere has been produced.");
|
||||
}
|
||||
|
||||
return other;
|
||||
}
|
||||
|
||||
var builder = StorageBuilder.Create();
|
||||
|
||||
foreach (var pair in CharRangePair.IntersectRanges(this, other))
|
||||
|
@ -1057,27 +1027,7 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
|||
builder.AddRange(new CharRange(pair.StartInclusive, pair.EndExclusive, probProduct));
|
||||
}
|
||||
|
||||
double? logProbabilityOverride = null;
|
||||
var thisLogProbabilityOverride = LogProbabilityOverride;
|
||||
var otherLogProbabilityOverride = other.LogProbabilityOverride;
|
||||
if (thisLogProbabilityOverride.HasValue)
|
||||
{
|
||||
if (otherLogProbabilityOverride.HasValue)
|
||||
{
|
||||
throw new ArgumentException("Only one distribution in a ImmutableDiscreteChar product may have a log probability override");
|
||||
}
|
||||
|
||||
if (other.IsBroad)
|
||||
{
|
||||
logProbabilityOverride = thisLogProbabilityOverride;
|
||||
}
|
||||
}
|
||||
else if (otherLogProbabilityOverride.HasValue && IsBroad)
|
||||
{
|
||||
logProbabilityOverride = otherLogProbabilityOverride;
|
||||
}
|
||||
|
||||
return new ImmutableDiscreteChar(builder.GetResult(logProbabilityOverride));
|
||||
return new ImmutableDiscreteChar(builder.GetResult());
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -1157,16 +1107,7 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
|||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
public ImmutableDiscreteChar CreatePartialUniform() => CreatePartialUniform(null);
|
||||
|
||||
/// <inheritdoc cref="CreatePartialUniform()"/>
|
||||
/// <param name="logProbabilityOverride">An optional value to override for the log probability calculation
|
||||
/// against this distribution. If this is set, then the distribution will not be normalize;
|
||||
/// i.e. the probabilities will not sum to 1 over the support.</param>
|
||||
/// <remarks>Overriding the log probability calculation in this way is useful within the context of using <see cref="StringDistribution"/>
|
||||
/// to create more realistic language model priors. Distributions with this override are always uncached.
|
||||
/// </remarks>
|
||||
public ImmutableDiscreteChar CreatePartialUniform(double? logProbabilityOverride)
|
||||
public ImmutableDiscreteChar CreatePartialUniform()
|
||||
{
|
||||
var builder = StorageBuilder.Create();
|
||||
foreach (var range in Ranges)
|
||||
|
@ -1178,7 +1119,7 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
|||
range.Probability.IsZero ? Weight.Zero : Weight.One));
|
||||
}
|
||||
|
||||
return new ImmutableDiscreteChar(builder.GetResult(logProbabilityOverride));
|
||||
return new ImmutableDiscreteChar(builder.GetResult());
|
||||
}
|
||||
|
||||
/// <inheritdoc/>
|
||||
|
@ -1899,32 +1840,6 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
|||
private string regexRepresentation;
|
||||
private string symbolRepresentation;
|
||||
|
||||
/// <summary>
|
||||
/// Flags derived from ranges.
|
||||
/// </summary>
|
||||
/// <returns></returns>
|
||||
private readonly Flags flags;
|
||||
|
||||
/// <summary>
|
||||
/// Whether this distribution has broad support. We want to be able to
|
||||
/// distinguish between specific distributions and general distributions.
|
||||
/// Use the number of non-zero digits as a threshold.
|
||||
/// </summary>
|
||||
/// <returns></returns>
|
||||
public bool IsBroad => (this.flags & Flags.IsBroad) != 0;
|
||||
|
||||
/// <summary>
|
||||
/// Whether this distribution is partial uniform with a log probability override.
|
||||
/// </summary>
|
||||
/// <returns></returns>
|
||||
public bool HasLogProbabilityOverride => (this.flags & Flags.HasLogProbabilityOverride) != 0;
|
||||
|
||||
[Flags]
|
||||
private enum Flags
|
||||
{
|
||||
HasLogProbabilityOverride = 0x01,
|
||||
IsBroad = 0x02
|
||||
}
|
||||
#endregion
|
||||
|
||||
#region Constructor and factory methods
|
||||
|
@ -1942,20 +1857,6 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
|||
this.regexRepresentation = regexRepresentation;
|
||||
this.symbolRepresentation = symbolRepresentation;
|
||||
var supportCount = this.Ranges.Where(range => !range.Probability.IsZero).Sum(range => range.EndExclusive - range.StartInclusive);
|
||||
var isBroad = supportCount >= 9; // Number of non-zero digits.
|
||||
var totalMass = this.Ranges.Sum(range =>
|
||||
range.Probability.Value * (range.EndExclusive - range.StartInclusive));
|
||||
var isScaled = Math.Abs(totalMass - 1.0) > 1e-10;
|
||||
this.flags = 0;
|
||||
if (isBroad)
|
||||
{
|
||||
flags |= Flags.IsBroad;
|
||||
}
|
||||
|
||||
if (isScaled)
|
||||
{
|
||||
flags |= Flags.HasLogProbabilityOverride;
|
||||
}
|
||||
}
|
||||
|
||||
public static Storage CreateUncached(
|
||||
|
@ -2508,23 +2409,15 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
|||
/// <summary>
|
||||
/// Normalizes probabilities in ranges and returns build Storage.
|
||||
/// </summary>
|
||||
public Storage GetResult(double? maximumProbability = null)
|
||||
public Storage GetResult()
|
||||
{
|
||||
this.MergeNeighboringRanges();
|
||||
NormalizeProbabilities(this.ranges, maximumProbability);
|
||||
return
|
||||
maximumProbability.HasValue
|
||||
? Storage.CreateUncached(
|
||||
this.ranges.ToReadOnlyArray(),
|
||||
null,
|
||||
this.charClasses,
|
||||
this.regexRepresentation,
|
||||
this.symbolRepresentation)
|
||||
: Storage.Create(
|
||||
this.ranges.ToReadOnlyArray(),
|
||||
this.charClasses,
|
||||
this.regexRepresentation,
|
||||
this.symbolRepresentation);
|
||||
NormalizeProbabilities(this.ranges);
|
||||
return Storage.Create(
|
||||
this.ranges.ToReadOnlyArray(),
|
||||
this.charClasses,
|
||||
this.regexRepresentation,
|
||||
this.symbolRepresentation);
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
@ -2563,36 +2456,17 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
|||
/// Normalizes probabilities in ranges.
|
||||
/// </summary>
|
||||
/// <param name="ranges">The ranges.</param>
|
||||
/// <param name="logProbabilityOverride">Ignores the probabilities in the ranges and creates a non-normalized partial uniform distribution.</param>
|
||||
/// <exception cref="ArgumentException">Thrown if logProbabilityOverride has value corresponding to a non-probability.</exception>
|
||||
public static void NormalizeProbabilities(IList<CharRange> ranges, double? logProbabilityOverride = null)
|
||||
public static void NormalizeProbabilities(IList<CharRange> ranges)
|
||||
{
|
||||
if (logProbabilityOverride.HasValue)
|
||||
var normalizer = ComputeInvNormalizer(ranges);
|
||||
for (var i = 0; i < ranges.Count; ++i)
|
||||
{
|
||||
var weight = Weight.FromLogValue(logProbabilityOverride.Value);
|
||||
if (weight.IsZero || weight.Value > 1)
|
||||
{
|
||||
throw new ArgumentException("Invalid log probability override.");
|
||||
}
|
||||
var range = ranges[i];
|
||||
var probability = range.Probability * normalizer;
|
||||
|
||||
for (var i = 0; i < ranges.Count; ++i)
|
||||
{
|
||||
var range = ranges[i];
|
||||
ranges[i] = new CharRange(
|
||||
range.StartInclusive, range.EndExclusive, weight);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
var normalizer = ComputeInvNormalizer(ranges);
|
||||
for (var i = 0; i < ranges.Count; ++i)
|
||||
{
|
||||
var range = ranges[i];
|
||||
var probability = range.Probability * normalizer;
|
||||
|
||||
ranges[i] = new CharRange(
|
||||
range.StartInclusive, range.EndExclusive, probability);
|
||||
}
|
||||
ranges[i] = new CharRange(
|
||||
range.StartInclusive, range.EndExclusive, probability);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -138,108 +138,6 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void PartialUniformWithLogProbabilityOverride()
|
||||
{
|
||||
var dist = DiscreteChar.LetterOrDigit();
|
||||
var probLetter = Math.Exp(dist.GetLogProb('j'));
|
||||
var probNumber = Math.Exp(dist.GetLogProb('5'));
|
||||
|
||||
var logProbabilityOverride = Math.Log(0.7);
|
||||
var scaledDist = DiscreteChar.Uniform();
|
||||
scaledDist.SetToPartialUniformOf(dist, logProbabilityOverride);
|
||||
var scaledLogProbLetter = scaledDist.GetLogProb('j');
|
||||
var scaledLogProbNumber = scaledDist.GetLogProb('5');
|
||||
|
||||
Assert.Equal(scaledLogProbLetter, logProbabilityOverride, Eps);
|
||||
Assert.Equal(scaledLogProbNumber, logProbabilityOverride, Eps);
|
||||
|
||||
// Check that cache has not been compromised.
|
||||
Assert.Equal(probLetter, Math.Exp(dist.GetLogProb('j')), Eps);
|
||||
Assert.Equal(probNumber, Math.Exp(dist.GetLogProb('5')), Eps);
|
||||
|
||||
// Check that an exception is thrown if a bad maximumProbability is passed down.
|
||||
Xunit.Assert.Throws<ArgumentException>(() =>
|
||||
{
|
||||
var badDist = DiscreteChar.Uniform();
|
||||
badDist.SetToPartialUniformOf(dist, Math.Log(1.2));
|
||||
});
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void BroadAndNarrow()
|
||||
{
|
||||
var dist1 = DiscreteChar.Digit();
|
||||
Xunit.Assert.True(dist1.IsBroad);
|
||||
|
||||
var dist2 = DiscreteChar.OneOf('1', '3', '5', '6');
|
||||
Xunit.Assert.False(dist2.IsBroad);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void HasLogOverride()
|
||||
{
|
||||
var dist1 = DiscreteChar.LetterOrDigit();
|
||||
Xunit.Assert.False(dist1.HasLogProbabilityOverride);
|
||||
|
||||
dist1.SetToPartialUniformOf(dist1, Math.Log(0.9));
|
||||
Xunit.Assert.True(dist1.HasLogProbabilityOverride);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ProductWithLogOverrideBroad()
|
||||
{
|
||||
for (var i = 0; i < 2; i++)
|
||||
{
|
||||
var dist1 = DiscreteChar.LetterOrDigit();
|
||||
var dist2 = DiscreteChar.Digit();
|
||||
|
||||
var logOverrideProbability = Math.Log(0.9);
|
||||
dist1.SetToPartialUniformOf(dist1, logOverrideProbability);
|
||||
Xunit.Assert.True(dist1.HasLogProbabilityOverride);
|
||||
Xunit.Assert.True(dist2.IsBroad);
|
||||
|
||||
if (i == 1)
|
||||
{
|
||||
Util.Swap(ref dist1, ref dist2);
|
||||
}
|
||||
|
||||
var dist3 = DiscreteChar.Uniform();
|
||||
dist3.SetToProduct(dist1, dist2);
|
||||
|
||||
Xunit.Assert.True(dist3.HasLogProbabilityOverride);
|
||||
Assert.Equal(logOverrideProbability, dist3.GetLogProb('5'), Eps);
|
||||
Xunit.Assert.True(double.IsNegativeInfinity(dist3.GetLogProb('a')));
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ProductWithLogOverrideNarrow()
|
||||
{
|
||||
for (var i = 0; i < 2; i++)
|
||||
{
|
||||
var dist1 = DiscreteChar.LetterOrDigit();
|
||||
var dist2 = DiscreteChar.OneOf('1', '3', '5', '6');
|
||||
|
||||
var logOverrideProbability = Math.Log(0.9);
|
||||
dist1.SetToPartialUniformOf(dist1, logOverrideProbability);
|
||||
Xunit.Assert.True(dist1.HasLogProbabilityOverride);
|
||||
Xunit.Assert.False(dist2.IsBroad);
|
||||
|
||||
if (i == 1)
|
||||
{
|
||||
Util.Swap(ref dist1, ref dist2);
|
||||
}
|
||||
|
||||
var dist3 = DiscreteChar.Uniform();
|
||||
dist3.SetToProduct(dist1, dist2);
|
||||
|
||||
Xunit.Assert.False(dist3.HasLogProbabilityOverride);
|
||||
Assert.Equal(Math.Log(0.25), dist3.GetLogProb('5'), Eps);
|
||||
Xunit.Assert.True(double.IsNegativeInfinity(dist3.GetLogProb('a')));
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Tests the support of a character distribution.
|
||||
/// </summary>
|
||||
|
|
|
@ -567,8 +567,7 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
// significantly less than the probability of a shorter word.
|
||||
// (2) The probability of a specific word conditioned on its length matches that of
|
||||
// words in the target language.
|
||||
// We achieve this by putting non-normalized character distributions on the edges. The
|
||||
// StringDistribution is unaware that these are non-normalized.
|
||||
// We achieve this by putting weights which do not sum to 1 on transitions.
|
||||
// The StringDistribution itself is non-normalizable.
|
||||
const double TargetProb1 = 0.05;
|
||||
const double Ratio1 = 0.4;
|
||||
|
@ -592,62 +591,42 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
|
||||
var charDistUpper = ImmutableDiscreteChar.Upper();
|
||||
var charDistLower = ImmutableDiscreteChar.Lower();
|
||||
var charDistUpperNarrow = ImmutableDiscreteChar.OneOf('A', 'B');
|
||||
var charDistLowerNarrow = ImmutableDiscreteChar.OneOf('a', 'b');
|
||||
|
||||
var charDistUpperScaled = charDistUpper.CreatePartialUniform(Math.Log(TargetProb1));
|
||||
var charDistLowerScaled1 = charDistLower.CreatePartialUniform(Math.Log(Ratio1));
|
||||
var charDistLowerScaled2 = charDistLower.CreatePartialUniform(Math.Log(Ratio2));
|
||||
var charDistLowerScaled3 = charDistLower.CreatePartialUniform(Math.Log(Ratio3));
|
||||
var charDistLowerScaledEnd = charDistLower.CreatePartialUniform(Math.Log(Ratio4));
|
||||
var workspace = new StringAutomaton.Builder();
|
||||
var state = workspace.Start;
|
||||
|
||||
var wordModel = StringDistribution.Concatenate(
|
||||
new List<ImmutableDiscreteChar>
|
||||
{
|
||||
charDistUpperScaled,
|
||||
charDistLowerScaled1,
|
||||
charDistLowerScaled2,
|
||||
charDistLowerScaled2,
|
||||
charDistLowerScaled2,
|
||||
charDistLowerScaled3,
|
||||
charDistLowerScaled3,
|
||||
charDistLowerScaled3,
|
||||
charDistLowerScaledEnd
|
||||
},
|
||||
true,
|
||||
true);
|
||||
void AddCharToModel(ImmutableDiscreteChar c, double ratio)
|
||||
{
|
||||
var realCharProb = c.Ranges.First().Probability;
|
||||
var weight = Weight.FromValue(ratio) * Weight.Inverse(realCharProb);
|
||||
state = state.AddTransition(c, weight);
|
||||
state.SetEndWeight(Weight.One);
|
||||
}
|
||||
|
||||
AddCharToModel(charDistUpper, TargetProb1);
|
||||
AddCharToModel(charDistLower, Ratio1);
|
||||
AddCharToModel(charDistLower, Ratio2);
|
||||
AddCharToModel(charDistLower, Ratio2);
|
||||
AddCharToModel(charDistLower, Ratio2);
|
||||
AddCharToModel(charDistLower, Ratio3);
|
||||
AddCharToModel(charDistLower, Ratio3);
|
||||
AddCharToModel(charDistLower, Ratio3);
|
||||
AddCharToModel(charDistLower, Ratio4);
|
||||
AddCharToModel(charDistLower, Ratio4);
|
||||
|
||||
state.AddTransition(charDistLower, Weight.One, state.Index);
|
||||
|
||||
var wordModel = new StringDistribution();
|
||||
wordModel.SetWeightFunction(workspace.GetAutomaton());
|
||||
|
||||
const string Word = "Abcdefghij";
|
||||
|
||||
const double Eps = 1e-5;
|
||||
var broadDist = StringDistribution.Char(charDistUpper);
|
||||
var narrowDist = StringDistribution.Char(charDistUpperNarrow);
|
||||
var narrowWord = "A";
|
||||
var expectedProbForNarrow = 0.5;
|
||||
for (var i = 0; i < targetProbabilitiesPerLength.Length; i++)
|
||||
{
|
||||
var currentWord = Word.Substring(0, i + 1);
|
||||
var probCurrentWord = Math.Exp(wordModel.GetLogProb(currentWord));
|
||||
Assert.Equal(targetProbabilitiesPerLength[i], probCurrentWord, Eps);
|
||||
|
||||
var logAvg = Math.Exp(wordModel.GetLogAverageOf(broadDist));
|
||||
Assert.Equal(targetProbabilitiesPerLength[i], logAvg, Eps);
|
||||
|
||||
var prod = StringDistribution.Zero();
|
||||
prod.SetToProduct(broadDist, wordModel);
|
||||
Xunit.Assert.True(prod.ToAutomaton().HasElementLogValueOverrides);
|
||||
probCurrentWord = Math.Exp(prod.GetLogProb(currentWord));
|
||||
Assert.Equal(targetProbabilitiesPerLength[i], probCurrentWord, Eps);
|
||||
|
||||
prod.SetToProduct(narrowDist, wordModel);
|
||||
Xunit.Assert.False(prod.ToAutomaton().HasElementLogValueOverrides);
|
||||
var probNarrowWord = Math.Exp(prod.GetLogProb(narrowWord));
|
||||
Assert.Equal(expectedProbForNarrow, probNarrowWord, Eps);
|
||||
|
||||
broadDist = broadDist.Append(charDistLower);
|
||||
narrowDist = narrowDist.Append(charDistLowerNarrow);
|
||||
narrowWord += "a";
|
||||
expectedProbForNarrow *= 0.5;
|
||||
}
|
||||
|
||||
// Copied model
|
||||
|
@ -659,21 +638,6 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
var probCurrentWord = Math.Exp(copiedModel.GetLogProb(currentWord));
|
||||
Assert.Equal(targetProbabilitiesPerLength[i], probCurrentWord, Eps);
|
||||
}
|
||||
|
||||
// Rescaled model
|
||||
var scale = 0.5;
|
||||
var newTargetProb1 = TargetProb1 * scale;
|
||||
var charDistUpperScaled1 = charDistUpper.CreatePartialUniform(Math.Log(newTargetProb1));
|
||||
var reWeightingTransducer =
|
||||
StringTransducer.Replace(StringDistribution.Char(charDistUpper).ToAutomaton(), StringDistribution.Char(charDistUpperScaled1).ToAutomaton())
|
||||
.Append(StringTransducer.Copy());
|
||||
var reWeightedWordModel = StringDistribution.FromWeightFunction(reWeightingTransducer.ProjectSource(wordModel.ToAutomaton()));
|
||||
for (var i = 0; i < targetProbabilitiesPerLength.Length; i++)
|
||||
{
|
||||
var currentWord = Word.Substring(0, i + 1);
|
||||
var probCurrentWord = Math.Exp(reWeightedWordModel.GetLogProb(currentWord));
|
||||
Assert.Equal(scale * targetProbabilitiesPerLength[i], probCurrentWord, Eps);
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
|
|
Загрузка…
Ссылка в новой задаче