зеркало из https://github.com/dotnet/infer.git
Log probability override in DiscreteChar (#206)
Log probability override for Discrete char
This commit is contained in:
Родитель
000818158c
Коммит
053bac1751
|
@ -2255,7 +2255,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
|
||||
this.Data = builder.GetData();
|
||||
this.LogValueOverride = automaton.LogValueOverride;
|
||||
this.PruneStatesWithLogEndWeightLessThan = automaton.LogValueOverride;
|
||||
this.PruneStatesWithLogEndWeightLessThan = automaton.PruneStatesWithLogEndWeightLessThan;
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
@ -2276,7 +2276,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
return automaton2.IsZero() ? double.NegativeInfinity : 1;
|
||||
}
|
||||
|
||||
TThis theConverger = GetConverger(automaton1, automaton2);
|
||||
TThis theConverger = GetConverger(new TThis[] {automaton1, automaton2});
|
||||
var automaton1conv = automaton1.Product(theConverger);
|
||||
var automaton2conv = automaton2.Product(theConverger);
|
||||
|
||||
|
@ -2310,8 +2310,20 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
/// Gets an automaton such that every given automaton, if multiplied by it, becomes normalizable.
|
||||
/// </summary>
|
||||
/// <param name="automata">The automata.</param>
|
||||
/// <param name="decayWeight">The decay weight.</param>
|
||||
/// <returns>An automaton, product with which will make every given automaton normalizable.</returns>
|
||||
public static TThis GetConverger(params TThis[] automata)
|
||||
public static TThis GetConverger(TThis automata, double decayWeight = 0.99)
|
||||
{
|
||||
return GetConverger(new TThis[] {automata}, decayWeight);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets an automaton such that every given automaton, if multiplied by it, becomes normalizable.
|
||||
/// </summary>
|
||||
/// <param name="automata">The automata.</param>
|
||||
/// <param name="decayWeight">The decay weight.</param>
|
||||
/// <returns>An automaton, product with which will make every given automaton normalizable.</returns>
|
||||
public static TThis GetConverger(TThis[] automata, double decayWeight = 0.99)
|
||||
{
|
||||
// TODO: This method might not work in the presense of non-trivial loops.
|
||||
|
||||
|
@ -2347,7 +2359,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
Weight transitionWeight = Weight.Product(
|
||||
Weight.FromLogValue(-uniformDist.GetLogAverageOf(uniformDist)),
|
||||
Weight.FromLogValue(-maxLogTransitionWeightSum),
|
||||
Weight.FromValue(0.99));
|
||||
Weight.FromValue(decayWeight));
|
||||
theConverger.Start.AddSelfTransition(uniformDist, transitionWeight);
|
||||
theConverger.Start.SetEndWeight(Weight.One);
|
||||
|
||||
|
@ -2679,11 +2691,11 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
{
|
||||
var propertyMask = new BitVector32();
|
||||
var idx = 0;
|
||||
propertyMask[1 << idx++] = true; // isEpsilonFree is alway known
|
||||
propertyMask[1 << idx++] = true; // isEpsilonFree is always known
|
||||
propertyMask[1 << idx++] = this.Data.IsEpsilonFree;
|
||||
propertyMask[1 << idx++] = this.LogValueOverride.HasValue;
|
||||
propertyMask[1 << idx++] = this.PruneStatesWithLogEndWeightLessThan.HasValue;
|
||||
propertyMask[1 << idx++] = true; // start state is alway serialized
|
||||
propertyMask[1 << idx++] = true; // start state is always serialized
|
||||
|
||||
writeInt32(propertyMask.Data);
|
||||
|
||||
|
@ -2711,8 +2723,8 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
/// Reads an automaton from.
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// Serializtion format is a bit unnatural, but we do it for compatiblity with old serialized data.
|
||||
/// So we don't have to maintain 2 versions of derserialization
|
||||
/// Serialization format is a bit unnatural, but we do it for compatibility with old serialized data.
|
||||
/// So we don't have to maintain 2 versions of deserialization.
|
||||
/// </remarks>
|
||||
public static TThis Read(Func<double> readDouble, Func<int> readInt32, Func<TElementDistribution> readElementDistribution)
|
||||
{
|
||||
|
|
|
@ -10,6 +10,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
using System.Collections.Generic;
|
||||
using System.Diagnostics;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
|
||||
/// <summary>
|
||||
/// Represents a weighted finite state automaton defined on <see cref="string"/>.
|
||||
|
@ -21,6 +22,19 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
{
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Whether there are log value overrides at the element level.
|
||||
/// </summary>
|
||||
public bool HasElementLogValueOverrides
|
||||
{
|
||||
get
|
||||
{
|
||||
return this.States.transitions.Any(
|
||||
trans => trans.ElementDistribution.HasValue &&
|
||||
trans.ElementDistribution.Value.HasLogProbabilityOverride);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Computes a set of outgoing transitions from a given state of the determinization result.
|
||||
/// </summary>
|
||||
|
|
|
@ -347,6 +347,8 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
|
||||
// Populate the stack with start destination state
|
||||
result.StartStateIndex = CreateDestState(mappingAutomaton.Start, srcAutomaton.Start);
|
||||
var stringAutomaton = srcAutomaton as StringAutomaton;
|
||||
var sourceDistributionHasLogProbabilityOverrides = stringAutomaton?.HasElementLogValueOverrides ?? false;
|
||||
|
||||
while (stack.Count > 0)
|
||||
{
|
||||
|
@ -387,7 +389,32 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
continue;
|
||||
}
|
||||
|
||||
var destWeight = Weight.Product(mappingTransition.Weight, srcTransition.Weight, Weight.FromLogValue(projectionLogScale));
|
||||
// In the special case of a log probability override in a DiscreteChar element distribution,
|
||||
// we need to compensate for the fact that the distribution is not normalized.
|
||||
if (destElementDistribution.HasValue && sourceDistributionHasLogProbabilityOverrides)
|
||||
{
|
||||
var discreteChar =
|
||||
(DiscreteChar)(IDistribution<char>)srcTransition.ElementDistribution.Value;
|
||||
if (discreteChar.HasLogProbabilityOverride)
|
||||
{
|
||||
var totalMass = discreteChar.Ranges.EnumerableSum(rng =>
|
||||
rng.Probability.Value * (rng.EndExclusive - rng.StartInclusive));
|
||||
projectionLogScale -= System.Math.Log(totalMass);
|
||||
}
|
||||
}
|
||||
|
||||
var destWeight =
|
||||
sourceDistributionHasLogProbabilityOverrides && destElementDistribution.HasNoValue
|
||||
? Weight.One
|
||||
: Weight.Product(mappingTransition.Weight, srcTransition.Weight,
|
||||
Weight.FromLogValue(projectionLogScale));
|
||||
|
||||
// We don't want an unnormalizable distribution to become normalizable due to a rounding error.
|
||||
if (Math.Abs(destWeight.LogValue) < 1e-12)
|
||||
{
|
||||
destWeight = Weight.One;
|
||||
}
|
||||
|
||||
var childDestStateIndex = CreateDestState(childMappingState, srcChildState);
|
||||
destState.AddTransition(destElementDistribution, destWeight, childDestStateIndex, mappingTransition.Group);
|
||||
}
|
||||
|
|
|
@ -2,6 +2,8 @@
|
|||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
using System.Collections;
|
||||
|
||||
namespace Microsoft.ML.Probabilistic.Distributions
|
||||
{
|
||||
using System;
|
||||
|
@ -183,6 +185,22 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
|||
/// </summary>
|
||||
public bool IsWordChar => this.Data.IsWordChar;
|
||||
|
||||
/// <summary>
|
||||
/// Gets a value indicating whether this distribution is broad - i.e. a general class of values.
|
||||
/// </summary>
|
||||
public bool IsBroad => this.Data.IsBroad;
|
||||
|
||||
/// <summary>
|
||||
/// Gets a value indicating whether this distribution is partial uniform with a log probability override.
|
||||
/// </summary>
|
||||
public bool HasLogProbabilityOverride => this.Data.HasLogProbabilityOverride;
|
||||
|
||||
/// <summary>
|
||||
/// Gets a value for the log probability override if it exists.
|
||||
/// </summary>
|
||||
public double? LogProbabilityOverride =>
|
||||
this.HasLogProbabilityOverride ? this.Ranges.First().Probability.LogValue : (double?)null;
|
||||
|
||||
#endregion
|
||||
|
||||
#region Distribution properties
|
||||
|
@ -530,13 +548,34 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
|||
}
|
||||
|
||||
var builder = StorageBuilder.Create();
|
||||
|
||||
foreach (var pair in CharRangePair.IntersectRanges(distribution1, distribution2))
|
||||
{
|
||||
var probProduct = pair.Probability1 * pair.Probability2;
|
||||
builder.AddRange(new CharRange(pair.StartInclusive, pair.EndExclusive, probProduct));
|
||||
}
|
||||
|
||||
this.Data = builder.GetResult();
|
||||
double? logProbabilityOverride = null;
|
||||
var distribution1LogProbabilityOverride = distribution1.LogProbabilityOverride;
|
||||
var distribution2LogProbabilityOverride = distribution2.LogProbabilityOverride;
|
||||
if (distribution1LogProbabilityOverride.HasValue)
|
||||
{
|
||||
if (distribution2LogProbabilityOverride.HasValue)
|
||||
{
|
||||
throw new ArgumentException("Only one distribution in a DiscreteChar product may have a log probability override");
|
||||
}
|
||||
|
||||
if (distribution2.IsBroad)
|
||||
{
|
||||
logProbabilityOverride = distribution1LogProbabilityOverride;
|
||||
}
|
||||
}
|
||||
else if (distribution2LogProbabilityOverride.HasValue && distribution1.IsBroad)
|
||||
{
|
||||
logProbabilityOverride = distribution2LogProbabilityOverride;
|
||||
}
|
||||
|
||||
this.Data = builder.GetResult(logProbabilityOverride);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -629,8 +668,23 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
|||
/// <summary>
|
||||
/// Sets the distribution to be uniform over the support of a given distribution.
|
||||
/// </summary>
|
||||
/// <param name="distribution">The distribution which support will be used to setup the current distribution.</param>
|
||||
/// <param name="distribution">The distribution whose support will be used to setup the current distribution.</param>
|
||||
public void SetToPartialUniformOf(DiscreteChar distribution)
|
||||
{
|
||||
SetToPartialUniformOf(distribution, null);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Sets the distribution to be uniform over the support of a given distribution.
|
||||
/// </summary>
|
||||
/// <param name="distribution">The distribution whose support will be used to setup the current distribution.</param>
|
||||
/// <param name="logProbabilityOverride">An optional value to override for the log probability calculation
|
||||
/// against this distribution. If this is set, then the distribution will not be normalize;
|
||||
/// i.e. the probabilities will not sum to 1 over the support.</param>
|
||||
/// <remarks>Overriding the log probability calculation in this way is useful within the context of using <see cref="StringDistribution"/>
|
||||
/// to create more realistic language model priors. Distributions with this override are always uncached.
|
||||
/// </remarks>
|
||||
public void SetToPartialUniformOf(DiscreteChar distribution, double? logProbabilityOverride)
|
||||
{
|
||||
var builder = StorageBuilder.Create();
|
||||
foreach (var range in distribution.Ranges)
|
||||
|
@ -642,7 +696,7 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
|||
range.Probability.IsZero ? Weight.Zero : Weight.One));
|
||||
}
|
||||
|
||||
this.Data = builder.GetResult();
|
||||
this.Data = builder.GetResult(logProbabilityOverride);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -1398,11 +1452,38 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
|||
|
||||
public char? Point { get; }
|
||||
|
||||
// Following 3 members are not immutable and can be recalculated on-demand
|
||||
// Following members are not immutable and can be recalculated on-demand
|
||||
public CharClasses CharClasses { get; private set; }
|
||||
|
||||
private string regexRepresentation;
|
||||
private string symbolRepresentation;
|
||||
|
||||
/// <summary>
|
||||
/// Flags derived from ranges.
|
||||
/// </summary>
|
||||
/// <returns></returns>
|
||||
private readonly Flags flags;
|
||||
|
||||
/// <summary>
|
||||
/// Whether this distribution has broad support. We want to be able to
|
||||
/// distinguish between specific distributions and general distributions.
|
||||
/// Use the number of non-zero digits as a threshold.
|
||||
/// </summary>
|
||||
/// <returns></returns>
|
||||
public bool IsBroad => (this.flags & Flags.IsBroad) != 0;
|
||||
|
||||
/// <summary>
|
||||
/// Whether this distribution is partial uniform with a log probability override.
|
||||
/// </summary>
|
||||
/// <returns></returns>
|
||||
public bool HasLogProbabilityOverride => (this.flags & Flags.HasLogProbabilityOverride) != 0;
|
||||
|
||||
[Flags]
|
||||
private enum Flags
|
||||
{
|
||||
HasLogProbabilityOverride = 0x01,
|
||||
IsBroad = 0x02
|
||||
}
|
||||
#endregion
|
||||
|
||||
#region Constructor and factory methods
|
||||
|
@ -1419,6 +1500,21 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
|||
this.CharClasses = charClasses;
|
||||
this.regexRepresentation = regexRepresentation;
|
||||
this.symbolRepresentation = symbolRepresentation;
|
||||
var supportCount = this.Ranges.Where(range => !range.Probability.IsZero).Sum(range => range.EndExclusive - range.StartInclusive);
|
||||
var isBroad = supportCount >= 9; // Number of non-zero digits.
|
||||
var totalMass = this.Ranges.Sum(range =>
|
||||
range.Probability.Value * (range.EndExclusive - range.StartInclusive));
|
||||
var isScaled = Math.Abs(totalMass - 1.0) > 1e-10;
|
||||
this.flags = 0;
|
||||
if (isBroad)
|
||||
{
|
||||
flags |= Flags.IsBroad;
|
||||
}
|
||||
|
||||
if (isScaled)
|
||||
{
|
||||
flags |= Flags.HasLogProbabilityOverride;
|
||||
}
|
||||
}
|
||||
|
||||
public static Storage CreateUncached(
|
||||
|
@ -1896,7 +1992,7 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
|||
private readonly List<CharRange> ranges;
|
||||
|
||||
/// <summary>
|
||||
/// Precomuted character class.
|
||||
/// Precomputed character class.
|
||||
/// </summary>
|
||||
private readonly CharClasses charClasses;
|
||||
|
||||
|
@ -1971,11 +2067,19 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
|||
/// <summary>
|
||||
/// Normalizes probabilities in ranges and returns build Storage.
|
||||
/// </summary>
|
||||
public Storage GetResult()
|
||||
public Storage GetResult(double? maximumProbability = null)
|
||||
{
|
||||
this.MergeNeighboringRanges();
|
||||
this.NormalizeProbabilities();
|
||||
return Storage.Create(
|
||||
NormalizeProbabilities(this.ranges, maximumProbability);
|
||||
return
|
||||
maximumProbability.HasValue
|
||||
? Storage.CreateUncached(
|
||||
this.ranges.ToArray(),
|
||||
null,
|
||||
this.charClasses,
|
||||
this.regexRepresentation,
|
||||
this.symbolRepresentation)
|
||||
: Storage.Create(
|
||||
this.ranges.ToArray(),
|
||||
this.charClasses,
|
||||
this.regexRepresentation,
|
||||
|
@ -2015,16 +2119,39 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
|||
}
|
||||
|
||||
/// <summary>
|
||||
/// Normalizes probabilities in ranges
|
||||
/// Normalizes probabilities in ranges.
|
||||
/// </summary>
|
||||
private void NormalizeProbabilities()
|
||||
/// <param name="ranges">The ranges.</param>
|
||||
/// <param name="logProbabilityOverride">Ignores the probabilities in the ranges and creates a non-normalized partial uniform distribution.</param>
|
||||
/// <exception cref="ArgumentException">Thrown if logProbabilityOverride has value corresponding to a non-probability.</exception>
|
||||
public static void NormalizeProbabilities(IList<CharRange> ranges, double? logProbabilityOverride = null)
|
||||
{
|
||||
var normalizer = this.ComputeInvNormalizer();
|
||||
for (int i = 0; i < this.ranges.Count; ++i)
|
||||
if (logProbabilityOverride.HasValue)
|
||||
{
|
||||
var range = this.ranges[i];
|
||||
this.ranges[i] = new CharRange(
|
||||
range.StartInclusive, range.EndExclusive, range.Probability * normalizer);
|
||||
var weight = Weight.FromLogValue(logProbabilityOverride.Value);
|
||||
if (weight.IsZero || weight.Value > 1)
|
||||
{
|
||||
throw new ArgumentException("Invalid log probability override.");
|
||||
}
|
||||
|
||||
for (var i = 0; i < ranges.Count; ++i)
|
||||
{
|
||||
var range = ranges[i];
|
||||
ranges[i] = new CharRange(
|
||||
range.StartInclusive, range.EndExclusive, weight);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
var normalizer = ComputeInvNormalizer(ranges);
|
||||
for (var i = 0; i < ranges.Count; ++i)
|
||||
{
|
||||
var range = ranges[i];
|
||||
var probability = range.Probability * normalizer;
|
||||
|
||||
ranges[i] = new CharRange(
|
||||
range.StartInclusive, range.EndExclusive, probability);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2032,11 +2159,11 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
|||
/// Computes the normalizer of this distribution.
|
||||
/// </summary>
|
||||
/// <returns>The computed normalizer.</returns>
|
||||
private Weight ComputeInvNormalizer()
|
||||
private static Weight ComputeInvNormalizer(IEnumerable<CharRange> ranges)
|
||||
{
|
||||
Weight normalizer = Weight.Zero;
|
||||
var normalizer = Weight.Zero;
|
||||
|
||||
foreach (var range in this.ranges)
|
||||
foreach (var range in ranges)
|
||||
{
|
||||
normalizer += Weight.FromValue(range.EndExclusive - range.StartInclusive) * range.Probability;
|
||||
}
|
||||
|
|
|
@ -244,17 +244,35 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
|||
}
|
||||
|
||||
/// <summary>
|
||||
/// Creates a distribution over sequences induced by a given list of distributions over sequence elements.
|
||||
/// Creates a distribution over sequences induced by a given list of distributions over sequence elements
|
||||
/// where the sequence can optionally end at any length, and the last element can optionally repeat without limit.
|
||||
/// </summary>
|
||||
/// <param name="sequence">Enumerable of distributions over sequence elements.</param>
|
||||
/// <param name="elementDistributions">Enumerable of distributions over sequence elements and the transition weights.</param>
|
||||
/// <param name="allowEarlyEnd">Allow the sequence to end at any point.</param>
|
||||
/// <param name="repeatLastElement">Repeat the last element.</param>
|
||||
/// <returns>The created distribution.</returns>
|
||||
public static TThis Concatenate(IEnumerable<TElementDistribution> sequence)
|
||||
public static TThis Concatenate(IEnumerable<TElementDistribution> elementDistributions, bool allowEarlyEnd = false, bool repeatLastElement = false)
|
||||
{
|
||||
var result = new Automaton<TSequence, TElement, TElementDistribution, TSequenceManipulator, TWeightFunction>.Builder();
|
||||
var last = result.Start;
|
||||
foreach (var elem in sequence)
|
||||
var elementDistributionArray = elementDistributions.ToArray();
|
||||
for (var i = 0; i < elementDistributionArray.Length - 1; i++)
|
||||
{
|
||||
last = last.AddTransition(elem, Weight.One);
|
||||
last = last.AddTransition(elementDistributionArray[i], Weight.One);
|
||||
if (allowEarlyEnd)
|
||||
{
|
||||
last.SetEndWeight(Weight.One);
|
||||
}
|
||||
}
|
||||
|
||||
var lastElement = elementDistributionArray[elementDistributionArray.Length - 1];
|
||||
if (repeatLastElement)
|
||||
{
|
||||
last.AddSelfTransition(lastElement, Weight.One);
|
||||
}
|
||||
else
|
||||
{
|
||||
last = last.AddTransition(lastElement, Weight.One);
|
||||
}
|
||||
|
||||
last.SetEndWeight(Weight.One);
|
||||
|
@ -1602,6 +1620,23 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
|||
return !this.IsPointMass && this.sequenceToWeight.IsZero();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Converges an improper sequence distribution
|
||||
/// </summary>
|
||||
/// <param name="dist">The original distribution.</param>
|
||||
/// <param name="decayWeight">The decay weight.</param>
|
||||
/// <returns>The converged distribution.</returns>
|
||||
public static TThis Converge(TThis dist, double decayWeight = 0.99)
|
||||
{
|
||||
var converger =
|
||||
Automaton<TSequence, TElement, TElementDistribution, TSequenceManipulator, TWeightFunction>
|
||||
.GetConverger(new TWeightFunction[]
|
||||
{
|
||||
dist.sequenceToWeight
|
||||
}, decayWeight);
|
||||
return dist.Product(FromWorkspace(converger));
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Checks if <paramref name="obj"/> equals to this distribution (i.e. represents the same distribution over sequences).
|
||||
/// </summary>
|
||||
|
|
|
@ -2,7 +2,9 @@
|
|||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
using System.Reflection.Metadata.Ecma335;
|
||||
using Microsoft.ML.Probabilistic.Math;
|
||||
using Microsoft.ML.Probabilistic.Utilities;
|
||||
|
||||
namespace Microsoft.ML.Probabilistic.Tests
|
||||
{
|
||||
|
@ -10,13 +12,15 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
using System.Collections.Generic;
|
||||
using Xunit;
|
||||
using Microsoft.ML.Probabilistic.Distributions;
|
||||
using Assert = Xunit.Assert;
|
||||
using Assert = Microsoft.ML.Probabilistic.Tests.AssertHelper;
|
||||
|
||||
/// <summary>
|
||||
/// Tests for <see cref="DiscreteChar"/>.
|
||||
/// </summary>
|
||||
public class DiscreteCharTest
|
||||
{
|
||||
const double Eps = 1e-10;
|
||||
|
||||
/// <summary>
|
||||
/// Runs a set of common distribution tests for <see cref="DiscreteChar"/>.
|
||||
/// </summary>
|
||||
|
@ -63,7 +67,7 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
return;
|
||||
}
|
||||
|
||||
Assert.True(false);
|
||||
Xunit.Assert.True(false);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
|
@ -88,7 +92,7 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
var unif = Vector.Constant(numChars, 1.0 / numChars);
|
||||
var maxDiff = hist.MaxDiff(unif);
|
||||
|
||||
Assert.True(maxDiff < 0.01);
|
||||
Xunit.Assert.True(maxDiff < 0.01);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
|
@ -106,7 +110,7 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
ab.SetToSum(1, a, 2, b);
|
||||
|
||||
// 2 subsequent ranges
|
||||
Assert.Equal(2, ab.Ranges.Count);
|
||||
Xunit.Assert.Equal(2, ab.Ranges.Count);
|
||||
TestComplement(ab);
|
||||
|
||||
void TestComplement(DiscreteChar dist)
|
||||
|
@ -117,21 +121,123 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
var complement = dist.Complement();
|
||||
|
||||
// complement should always be partial uniform
|
||||
Assert.True(complement.IsPartialUniform());
|
||||
Xunit.Assert.True(complement.IsPartialUniform());
|
||||
|
||||
// overlap is zero
|
||||
Assert.True(double.IsNegativeInfinity(dist.GetLogAverageOf(complement)));
|
||||
Assert.True(double.IsNegativeInfinity(uniformDist.GetLogAverageOf(complement)));
|
||||
Xunit.Assert.True(double.IsNegativeInfinity(dist.GetLogAverageOf(complement)));
|
||||
Xunit.Assert.True(double.IsNegativeInfinity(uniformDist.GetLogAverageOf(complement)));
|
||||
|
||||
// union is covers the whole range
|
||||
var sum = default(DiscreteChar);
|
||||
sum.SetToSum(1, dist, 1, complement);
|
||||
sum.SetToPartialUniform();
|
||||
Assert.True(sum.IsUniform());
|
||||
Xunit.Assert.True(sum.IsUniform());
|
||||
|
||||
// Doing complement again will cover the same set of characters
|
||||
var complement2 = complement.Complement();
|
||||
Assert.Equal(uniformDist, complement2);
|
||||
Xunit.Assert.Equal(uniformDist, complement2);
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void PartialUniformWithLogProbabilityOverride()
|
||||
{
|
||||
var dist = DiscreteChar.LetterOrDigit();
|
||||
var probLetter = Math.Exp(dist.GetLogProb('j'));
|
||||
var probNumber = Math.Exp(dist.GetLogProb('5'));
|
||||
|
||||
var logProbabilityOverride = Math.Log(0.7);
|
||||
var scaledDist = DiscreteChar.Uniform();
|
||||
scaledDist.SetToPartialUniformOf(dist, logProbabilityOverride);
|
||||
var scaledLogProbLetter = scaledDist.GetLogProb('j');
|
||||
var scaledLogProbNumber = scaledDist.GetLogProb('5');
|
||||
|
||||
Assert.Equal(scaledLogProbLetter, logProbabilityOverride, Eps);
|
||||
Assert.Equal(scaledLogProbNumber, logProbabilityOverride, Eps);
|
||||
|
||||
// Check that cache has not been compromised.
|
||||
Assert.Equal(probLetter, Math.Exp(dist.GetLogProb('j')), Eps);
|
||||
Assert.Equal(probNumber, Math.Exp(dist.GetLogProb('5')), Eps);
|
||||
|
||||
// Check that an exception is thrown if a bad maximumProbability is passed down.
|
||||
Xunit.Assert.Throws<ArgumentException>(() =>
|
||||
{
|
||||
var badDist = DiscreteChar.Uniform();
|
||||
badDist.SetToPartialUniformOf(dist, Math.Log(1.2));
|
||||
});
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void BroadAndNarrow()
|
||||
{
|
||||
var dist1 = DiscreteChar.Digit();
|
||||
Xunit.Assert.True(dist1.IsBroad);
|
||||
|
||||
var dist2 = DiscreteChar.OneOf('1', '3', '5', '6');
|
||||
Xunit.Assert.False(dist2.IsBroad);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void HasLogOverride()
|
||||
{
|
||||
var dist1 = DiscreteChar.LetterOrDigit();
|
||||
Xunit.Assert.False(dist1.HasLogProbabilityOverride);
|
||||
|
||||
dist1.SetToPartialUniformOf(dist1, Math.Log(0.9));
|
||||
Xunit.Assert.True(dist1.HasLogProbabilityOverride);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ProductWithLogOverrideBroad()
|
||||
{
|
||||
for (var i = 0; i < 2; i++)
|
||||
{
|
||||
var dist1 = DiscreteChar.LetterOrDigit();
|
||||
var dist2 = DiscreteChar.Digit();
|
||||
|
||||
var logOverrideProbability = Math.Log(0.9);
|
||||
dist1.SetToPartialUniformOf(dist1, logOverrideProbability);
|
||||
Xunit.Assert.True(dist1.HasLogProbabilityOverride);
|
||||
Xunit.Assert.True(dist2.IsBroad);
|
||||
|
||||
if (i == 1)
|
||||
{
|
||||
Util.Swap(ref dist1, ref dist2);
|
||||
}
|
||||
|
||||
var dist3 = DiscreteChar.Uniform();
|
||||
dist3.SetToProduct(dist1, dist2);
|
||||
|
||||
Xunit.Assert.True(dist3.HasLogProbabilityOverride);
|
||||
Assert.Equal(logOverrideProbability, dist3.GetLogProb('5'), Eps);
|
||||
Xunit.Assert.True(double.IsNegativeInfinity(dist3.GetLogProb('a')));
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ProductWithLogOverrideNarrow()
|
||||
{
|
||||
for (var i = 0; i < 2; i++)
|
||||
{
|
||||
var dist1 = DiscreteChar.LetterOrDigit();
|
||||
var dist2 = DiscreteChar.OneOf('1', '3', '5', '6');
|
||||
|
||||
var logOverrideProbability = Math.Log(0.9);
|
||||
dist1.SetToPartialUniformOf(dist1, logOverrideProbability);
|
||||
Xunit.Assert.True(dist1.HasLogProbabilityOverride);
|
||||
Xunit.Assert.False(dist2.IsBroad);
|
||||
|
||||
if (i == 1)
|
||||
{
|
||||
Util.Swap(ref dist1, ref dist2);
|
||||
}
|
||||
|
||||
var dist3 = DiscreteChar.Uniform();
|
||||
dist3.SetToProduct(dist1, dist2);
|
||||
|
||||
Xunit.Assert.False(dist3.HasLogProbabilityOverride);
|
||||
Assert.Equal(Math.Log(0.25), dist3.GetLogProb('5'), Eps);
|
||||
Xunit.Assert.True(double.IsNegativeInfinity(dist3.GetLogProb('a')));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -152,12 +258,12 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
|
||||
foreach (var ch in included)
|
||||
{
|
||||
Assert.True(!double.IsNegativeInfinity(distribution.GetLogProb(ch)), distribution + " should contain " + ch);
|
||||
Xunit.Assert.True(!double.IsNegativeInfinity(distribution.GetLogProb(ch)), distribution + " should contain " + ch);
|
||||
}
|
||||
|
||||
foreach (var ch in excluded)
|
||||
{
|
||||
Assert.True(double.IsNegativeInfinity(distribution.GetLogProb(ch)), distribution + " should not contain " + ch);
|
||||
Xunit.Assert.True(double.IsNegativeInfinity(distribution.GetLogProb(ch)), distribution + " should not contain " + ch);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,6 +2,8 @@
|
|||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
using System.Text.RegularExpressions;
|
||||
|
||||
namespace Microsoft.ML.Probabilistic.Tests
|
||||
{
|
||||
using System;
|
||||
|
@ -543,6 +545,130 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
});
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void WordModel()
|
||||
{
|
||||
// We want to build a word model as a reasonably simple StringDistribution. It
|
||||
// should satisfy the following:
|
||||
// (1) The probability of a word of moderate length should not be
|
||||
// significantly less than the probability of a shorter word.
|
||||
// (2) The probability of a specific word conditioned on its length matches that of
|
||||
// words in the target language.
|
||||
// We achieve this by putting non-normalized character distributions on the edges. The
|
||||
// StringDistribution is unaware that these are non-normalized.
|
||||
// The StringDistribution itself is non-normalizable.
|
||||
const double TargetProb1 = 0.05;
|
||||
const double Ratio1 = 0.4;
|
||||
const double TargetProb2 = TargetProb1 * Ratio1;
|
||||
const double Ratio2 = 0.2;
|
||||
const double TargetProb3 = TargetProb2 * Ratio2;
|
||||
const double TargetProb4 = TargetProb3 * Ratio2;
|
||||
const double TargetProb5 = TargetProb4 * Ratio2;
|
||||
const double Ratio3 = 0.999;
|
||||
const double TargetProb6 = TargetProb5 * Ratio3;
|
||||
const double TargetProb7 = TargetProb6 * Ratio3;
|
||||
const double TargetProb8 = TargetProb7 * Ratio3;
|
||||
const double Ratio4 = 0.9;
|
||||
const double TargetProb9 = TargetProb8 * Ratio4;
|
||||
const double TargetProb10 = TargetProb9 * Ratio4;
|
||||
|
||||
var targetProbabilitiesPerLength = new double[]
|
||||
{
|
||||
TargetProb1, TargetProb2, TargetProb3, TargetProb4, TargetProb5, TargetProb6, TargetProb7, TargetProb8, TargetProb9, TargetProb10
|
||||
};
|
||||
|
||||
var charDistUpper = DiscreteChar.Upper();
|
||||
var charDistLower = DiscreteChar.Lower();
|
||||
var charDistUpperNarrow = DiscreteChar.OneOf('A', 'B');
|
||||
var charDistLowerNarrow = DiscreteChar.OneOf('a', 'b');
|
||||
var charDistUpperScaled = DiscreteChar.Uniform();
|
||||
var charDistLowerScaled1 = DiscreteChar.Uniform();
|
||||
var charDistLowerScaled2 = DiscreteChar.Uniform();
|
||||
var charDistLowerScaled3 = DiscreteChar.Uniform();
|
||||
var charDistLowerScaledEnd = DiscreteChar.Uniform();
|
||||
|
||||
charDistUpperScaled.SetToPartialUniformOf(charDistUpper, Math.Log(TargetProb1));
|
||||
charDistLowerScaled1.SetToPartialUniformOf(charDistLower, Math.Log(Ratio1));
|
||||
charDistLowerScaled2.SetToPartialUniformOf(charDistLower, Math.Log(Ratio2));
|
||||
charDistLowerScaled3.SetToPartialUniformOf(charDistLower, Math.Log(Ratio3));
|
||||
charDistLowerScaledEnd.SetToPartialUniformOf(charDistLower, Math.Log(Ratio4));
|
||||
|
||||
var wordModel = StringDistribution.Concatenate(
|
||||
new List<DiscreteChar>
|
||||
{
|
||||
charDistUpperScaled,
|
||||
charDistLowerScaled1,
|
||||
charDistLowerScaled2,
|
||||
charDistLowerScaled2,
|
||||
charDistLowerScaled2,
|
||||
charDistLowerScaled3,
|
||||
charDistLowerScaled3,
|
||||
charDistLowerScaled3,
|
||||
charDistLowerScaledEnd
|
||||
},
|
||||
true,
|
||||
true);
|
||||
|
||||
const string Word = "Abcdefghij";
|
||||
|
||||
const double Eps = 1e-5;
|
||||
var broadDist = StringDistribution.Char(charDistUpper);
|
||||
var narrowDist = StringDistribution.Char(charDistUpperNarrow);
|
||||
var narrowWord = "A";
|
||||
var expectedProbForNarrow = 0.5;
|
||||
for (var i = 0; i < targetProbabilitiesPerLength.Length; i++)
|
||||
{
|
||||
var currentWord = Word.Substring(0, i + 1);
|
||||
var probCurrentWord = Math.Exp(wordModel.GetLogProb(currentWord));
|
||||
Assert.Equal(targetProbabilitiesPerLength[i], probCurrentWord, Eps);
|
||||
|
||||
var logAvg = Math.Exp(wordModel.GetLogAverageOf(broadDist));
|
||||
Assert.Equal(targetProbabilitiesPerLength[i], logAvg, Eps);
|
||||
|
||||
var prod = StringDistribution.Zero();
|
||||
prod.SetToProduct(broadDist, wordModel);
|
||||
Xunit.Assert.True(prod.GetWorkspaceOrPoint().HasElementLogValueOverrides);
|
||||
probCurrentWord = Math.Exp(prod.GetLogProb(currentWord));
|
||||
Assert.Equal(targetProbabilitiesPerLength[i], probCurrentWord, Eps);
|
||||
|
||||
prod.SetToProduct(narrowDist, wordModel);
|
||||
Xunit.Assert.False(prod.GetWorkspaceOrPoint().HasElementLogValueOverrides);
|
||||
var probNarrowWord = Math.Exp(prod.GetLogProb(narrowWord));
|
||||
Assert.Equal(expectedProbForNarrow, probNarrowWord, Eps);
|
||||
|
||||
broadDist = broadDist.Append(charDistLower);
|
||||
narrowDist = narrowDist.Append(charDistLowerNarrow);
|
||||
narrowWord += "a";
|
||||
expectedProbForNarrow *= 0.5;
|
||||
}
|
||||
|
||||
// Copied model
|
||||
var copiedModel = StringDistribution.FromWorkspace(StringTransducer.Copy().ProjectSource(wordModel.GetWorkspaceOrPoint()));
|
||||
// Under transducer.
|
||||
for (var i = 0; i < targetProbabilitiesPerLength.Length; i++)
|
||||
{
|
||||
var currentWord = Word.Substring(0, i + 1);
|
||||
var probCurrentWord = Math.Exp(copiedModel.GetLogProb(currentWord));
|
||||
Assert.Equal(targetProbabilitiesPerLength[i], probCurrentWord, Eps);
|
||||
}
|
||||
|
||||
// Rescaled model
|
||||
var scale = 0.5;
|
||||
var newTargetProb1 = TargetProb1 * scale;
|
||||
var charDistUpperScaled1 = DiscreteChar.Uniform();
|
||||
charDistUpperScaled1.SetToPartialUniformOf(charDistUpper, Math.Log(newTargetProb1));
|
||||
var reWeightingTransducer =
|
||||
StringTransducer.Replace(StringDistribution.Char(charDistUpper).GetWorkspaceOrPoint(), StringDistribution.Char(charDistUpperScaled1).GetWorkspaceOrPoint())
|
||||
.Append(StringTransducer.Copy());
|
||||
var reWeightedWordModel = StringDistribution.FromWorkspace(reWeightingTransducer.ProjectSource(wordModel.GetWorkspaceOrPoint()));
|
||||
for (var i = 0; i < targetProbabilitiesPerLength.Length; i++)
|
||||
{
|
||||
var currentWord = Word.Substring(0, i + 1);
|
||||
var probCurrentWord = Math.Exp(reWeightedWordModel.GetLogProb(currentWord));
|
||||
Assert.Equal(scale * targetProbabilitiesPerLength[i], probCurrentWord, Eps);
|
||||
}
|
||||
}
|
||||
|
||||
#region Sampling tests
|
||||
|
||||
/// <summary>
|
||||
|
|
Загрузка…
Ссылка в новой задаче