Log probability override in DiscreteChar (#206)

Log probability override for Discrete char
This commit is contained in:
John Guiver 2020-01-13 14:27:31 +00:00 коммит произвёл GitHub
Родитель 000818158c
Коммит 053bac1751
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
7 изменённых файлов: 495 добавлений и 48 удалений

Просмотреть файл

@ -2255,7 +2255,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
this.Data = builder.GetData();
this.LogValueOverride = automaton.LogValueOverride;
this.PruneStatesWithLogEndWeightLessThan = automaton.LogValueOverride;
this.PruneStatesWithLogEndWeightLessThan = automaton.PruneStatesWithLogEndWeightLessThan;
}
#endregion
@ -2276,7 +2276,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
return automaton2.IsZero() ? double.NegativeInfinity : 1;
}
TThis theConverger = GetConverger(automaton1, automaton2);
TThis theConverger = GetConverger(new TThis[] {automaton1, automaton2});
var automaton1conv = automaton1.Product(theConverger);
var automaton2conv = automaton2.Product(theConverger);
@ -2310,8 +2310,20 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// Gets an automaton such that every given automaton, if multiplied by it, becomes normalizable.
/// </summary>
/// <param name="automata">The automata.</param>
/// <param name="decayWeight">The decay weight.</param>
/// <returns>An automaton, product with which will make every given automaton normalizable.</returns>
public static TThis GetConverger(params TThis[] automata)
public static TThis GetConverger(TThis automata, double decayWeight = 0.99)
{
return GetConverger(new TThis[] {automata}, decayWeight);
}
/// <summary>
/// Gets an automaton such that every given automaton, if multiplied by it, becomes normalizable.
/// </summary>
/// <param name="automata">The automata.</param>
/// <param name="decayWeight">The decay weight.</param>
/// <returns>An automaton, product with which will make every given automaton normalizable.</returns>
public static TThis GetConverger(TThis[] automata, double decayWeight = 0.99)
{
// TODO: This method might not work in the presense of non-trivial loops.
@ -2347,7 +2359,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
Weight transitionWeight = Weight.Product(
Weight.FromLogValue(-uniformDist.GetLogAverageOf(uniformDist)),
Weight.FromLogValue(-maxLogTransitionWeightSum),
Weight.FromValue(0.99));
Weight.FromValue(decayWeight));
theConverger.Start.AddSelfTransition(uniformDist, transitionWeight);
theConverger.Start.SetEndWeight(Weight.One);
@ -2679,11 +2691,11 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
{
var propertyMask = new BitVector32();
var idx = 0;
propertyMask[1 << idx++] = true; // isEpsilonFree is alway known
propertyMask[1 << idx++] = true; // isEpsilonFree is always known
propertyMask[1 << idx++] = this.Data.IsEpsilonFree;
propertyMask[1 << idx++] = this.LogValueOverride.HasValue;
propertyMask[1 << idx++] = this.PruneStatesWithLogEndWeightLessThan.HasValue;
propertyMask[1 << idx++] = true; // start state is alway serialized
propertyMask[1 << idx++] = true; // start state is always serialized
writeInt32(propertyMask.Data);
@ -2711,8 +2723,8 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// Reads an automaton from.
/// </summary>
/// <remarks>
/// Serializtion format is a bit unnatural, but we do it for compatiblity with old serialized data.
/// So we don't have to maintain 2 versions of derserialization
/// Serialization format is a bit unnatural, but we do it for compatibility with old serialized data.
/// So we don't have to maintain 2 versions of deserialization.
/// </remarks>
public static TThis Read(Func<double> readDouble, Func<int> readInt32, Func<TElementDistribution> readElementDistribution)
{

Просмотреть файл

@ -10,6 +10,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
/// <summary>
/// Represents a weighted finite state automaton defined on <see cref="string"/>.
@ -21,6 +22,19 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
{
}
/// <summary>
/// Whether there are log value overrides at the element level.
/// </summary>
public bool HasElementLogValueOverrides
{
get
{
return this.States.transitions.Any(
trans => trans.ElementDistribution.HasValue &&
trans.ElementDistribution.Value.HasLogProbabilityOverride);
}
}
/// <summary>
/// Computes a set of outgoing transitions from a given state of the determinization result.
/// </summary>

Просмотреть файл

@ -347,6 +347,8 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
// Populate the stack with start destination state
result.StartStateIndex = CreateDestState(mappingAutomaton.Start, srcAutomaton.Start);
var stringAutomaton = srcAutomaton as StringAutomaton;
var sourceDistributionHasLogProbabilityOverrides = stringAutomaton?.HasElementLogValueOverrides ?? false;
while (stack.Count > 0)
{
@ -387,7 +389,32 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
continue;
}
var destWeight = Weight.Product(mappingTransition.Weight, srcTransition.Weight, Weight.FromLogValue(projectionLogScale));
// In the special case of a log probability override in a DiscreteChar element distribution,
// we need to compensate for the fact that the distribution is not normalized.
if (destElementDistribution.HasValue && sourceDistributionHasLogProbabilityOverrides)
{
var discreteChar =
(DiscreteChar)(IDistribution<char>)srcTransition.ElementDistribution.Value;
if (discreteChar.HasLogProbabilityOverride)
{
var totalMass = discreteChar.Ranges.EnumerableSum(rng =>
rng.Probability.Value * (rng.EndExclusive - rng.StartInclusive));
projectionLogScale -= System.Math.Log(totalMass);
}
}
var destWeight =
sourceDistributionHasLogProbabilityOverrides && destElementDistribution.HasNoValue
? Weight.One
: Weight.Product(mappingTransition.Weight, srcTransition.Weight,
Weight.FromLogValue(projectionLogScale));
// We don't want an unnormalizable distribution to become normalizable due to a rounding error.
if (Math.Abs(destWeight.LogValue) < 1e-12)
{
destWeight = Weight.One;
}
var childDestStateIndex = CreateDestState(childMappingState, srcChildState);
destState.AddTransition(destElementDistribution, destWeight, childDestStateIndex, mappingTransition.Group);
}

Просмотреть файл

@ -2,6 +2,8 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Collections;
namespace Microsoft.ML.Probabilistic.Distributions
{
using System;
@ -183,6 +185,22 @@ namespace Microsoft.ML.Probabilistic.Distributions
/// </summary>
public bool IsWordChar => this.Data.IsWordChar;
/// <summary>
/// Gets a value indicating whether this distribution is broad - i.e. a general class of values.
/// </summary>
public bool IsBroad => this.Data.IsBroad;
/// <summary>
/// Gets a value indicating whether this distribution is partial uniform with a log probability override.
/// </summary>
public bool HasLogProbabilityOverride => this.Data.HasLogProbabilityOverride;
/// <summary>
/// Gets a value for the log probability override if it exists.
/// </summary>
public double? LogProbabilityOverride =>
this.HasLogProbabilityOverride ? this.Ranges.First().Probability.LogValue : (double?)null;
#endregion
#region Distribution properties
@ -530,13 +548,34 @@ namespace Microsoft.ML.Probabilistic.Distributions
}
var builder = StorageBuilder.Create();
foreach (var pair in CharRangePair.IntersectRanges(distribution1, distribution2))
{
var probProduct = pair.Probability1 * pair.Probability2;
builder.AddRange(new CharRange(pair.StartInclusive, pair.EndExclusive, probProduct));
}
this.Data = builder.GetResult();
double? logProbabilityOverride = null;
var distribution1LogProbabilityOverride = distribution1.LogProbabilityOverride;
var distribution2LogProbabilityOverride = distribution2.LogProbabilityOverride;
if (distribution1LogProbabilityOverride.HasValue)
{
if (distribution2LogProbabilityOverride.HasValue)
{
throw new ArgumentException("Only one distribution in a DiscreteChar product may have a log probability override");
}
if (distribution2.IsBroad)
{
logProbabilityOverride = distribution1LogProbabilityOverride;
}
}
else if (distribution2LogProbabilityOverride.HasValue && distribution1.IsBroad)
{
logProbabilityOverride = distribution2LogProbabilityOverride;
}
this.Data = builder.GetResult(logProbabilityOverride);
}
/// <summary>
@ -629,8 +668,23 @@ namespace Microsoft.ML.Probabilistic.Distributions
/// <summary>
/// Sets the distribution to be uniform over the support of a given distribution.
/// </summary>
/// <param name="distribution">The distribution which support will be used to setup the current distribution.</param>
/// <param name="distribution">The distribution whose support will be used to setup the current distribution.</param>
public void SetToPartialUniformOf(DiscreteChar distribution)
{
SetToPartialUniformOf(distribution, null);
}
/// <summary>
/// Sets the distribution to be uniform over the support of a given distribution.
/// </summary>
/// <param name="distribution">The distribution whose support will be used to setup the current distribution.</param>
/// <param name="logProbabilityOverride">An optional value to override for the log probability calculation
/// against this distribution. If this is set, then the distribution will not be normalize;
/// i.e. the probabilities will not sum to 1 over the support.</param>
/// <remarks>Overriding the log probability calculation in this way is useful within the context of using <see cref="StringDistribution"/>
/// to create more realistic language model priors. Distributions with this override are always uncached.
/// </remarks>
public void SetToPartialUniformOf(DiscreteChar distribution, double? logProbabilityOverride)
{
var builder = StorageBuilder.Create();
foreach (var range in distribution.Ranges)
@ -642,7 +696,7 @@ namespace Microsoft.ML.Probabilistic.Distributions
range.Probability.IsZero ? Weight.Zero : Weight.One));
}
this.Data = builder.GetResult();
this.Data = builder.GetResult(logProbabilityOverride);
}
/// <summary>
@ -1398,11 +1452,38 @@ namespace Microsoft.ML.Probabilistic.Distributions
public char? Point { get; }
// Following 3 members are not immutable and can be recalculated on-demand
// Following members are not immutable and can be recalculated on-demand
public CharClasses CharClasses { get; private set; }
private string regexRepresentation;
private string symbolRepresentation;
/// <summary>
/// Flags derived from ranges.
/// </summary>
/// <returns></returns>
private readonly Flags flags;
/// <summary>
/// Whether this distribution has broad support. We want to be able to
/// distinguish between specific distributions and general distributions.
/// Use the number of non-zero digits as a threshold.
/// </summary>
/// <returns></returns>
public bool IsBroad => (this.flags & Flags.IsBroad) != 0;
/// <summary>
/// Whether this distribution is partial uniform with a log probability override.
/// </summary>
/// <returns></returns>
public bool HasLogProbabilityOverride => (this.flags & Flags.HasLogProbabilityOverride) != 0;
[Flags]
private enum Flags
{
HasLogProbabilityOverride = 0x01,
IsBroad = 0x02
}
#endregion
#region Constructor and factory methods
@ -1419,6 +1500,21 @@ namespace Microsoft.ML.Probabilistic.Distributions
this.CharClasses = charClasses;
this.regexRepresentation = regexRepresentation;
this.symbolRepresentation = symbolRepresentation;
var supportCount = this.Ranges.Where(range => !range.Probability.IsZero).Sum(range => range.EndExclusive - range.StartInclusive);
var isBroad = supportCount >= 9; // Number of non-zero digits.
var totalMass = this.Ranges.Sum(range =>
range.Probability.Value * (range.EndExclusive - range.StartInclusive));
var isScaled = Math.Abs(totalMass - 1.0) > 1e-10;
this.flags = 0;
if (isBroad)
{
flags |= Flags.IsBroad;
}
if (isScaled)
{
flags |= Flags.HasLogProbabilityOverride;
}
}
public static Storage CreateUncached(
@ -1896,7 +1992,7 @@ namespace Microsoft.ML.Probabilistic.Distributions
private readonly List<CharRange> ranges;
/// <summary>
/// Precomuted character class.
/// Precomputed character class.
/// </summary>
private readonly CharClasses charClasses;
@ -1971,11 +2067,19 @@ namespace Microsoft.ML.Probabilistic.Distributions
/// <summary>
/// Normalizes probabilities in ranges and returns build Storage.
/// </summary>
public Storage GetResult()
public Storage GetResult(double? maximumProbability = null)
{
this.MergeNeighboringRanges();
this.NormalizeProbabilities();
return Storage.Create(
NormalizeProbabilities(this.ranges, maximumProbability);
return
maximumProbability.HasValue
? Storage.CreateUncached(
this.ranges.ToArray(),
null,
this.charClasses,
this.regexRepresentation,
this.symbolRepresentation)
: Storage.Create(
this.ranges.ToArray(),
this.charClasses,
this.regexRepresentation,
@ -2015,16 +2119,39 @@ namespace Microsoft.ML.Probabilistic.Distributions
}
/// <summary>
/// Normalizes probabilities in ranges
/// Normalizes probabilities in ranges.
/// </summary>
private void NormalizeProbabilities()
/// <param name="ranges">The ranges.</param>
/// <param name="logProbabilityOverride">Ignores the probabilities in the ranges and creates a non-normalized partial uniform distribution.</param>
/// <exception cref="ArgumentException">Thrown if logProbabilityOverride has value corresponding to a non-probability.</exception>
public static void NormalizeProbabilities(IList<CharRange> ranges, double? logProbabilityOverride = null)
{
var normalizer = this.ComputeInvNormalizer();
for (int i = 0; i < this.ranges.Count; ++i)
if (logProbabilityOverride.HasValue)
{
var range = this.ranges[i];
this.ranges[i] = new CharRange(
range.StartInclusive, range.EndExclusive, range.Probability * normalizer);
var weight = Weight.FromLogValue(logProbabilityOverride.Value);
if (weight.IsZero || weight.Value > 1)
{
throw new ArgumentException("Invalid log probability override.");
}
for (var i = 0; i < ranges.Count; ++i)
{
var range = ranges[i];
ranges[i] = new CharRange(
range.StartInclusive, range.EndExclusive, weight);
}
}
else
{
var normalizer = ComputeInvNormalizer(ranges);
for (var i = 0; i < ranges.Count; ++i)
{
var range = ranges[i];
var probability = range.Probability * normalizer;
ranges[i] = new CharRange(
range.StartInclusive, range.EndExclusive, probability);
}
}
}
@ -2032,11 +2159,11 @@ namespace Microsoft.ML.Probabilistic.Distributions
/// Computes the normalizer of this distribution.
/// </summary>
/// <returns>The computed normalizer.</returns>
private Weight ComputeInvNormalizer()
private static Weight ComputeInvNormalizer(IEnumerable<CharRange> ranges)
{
Weight normalizer = Weight.Zero;
var normalizer = Weight.Zero;
foreach (var range in this.ranges)
foreach (var range in ranges)
{
normalizer += Weight.FromValue(range.EndExclusive - range.StartInclusive) * range.Probability;
}

Просмотреть файл

@ -244,17 +244,35 @@ namespace Microsoft.ML.Probabilistic.Distributions
}
/// <summary>
/// Creates a distribution over sequences induced by a given list of distributions over sequence elements.
/// Creates a distribution over sequences induced by a given list of distributions over sequence elements
/// where the sequence can optionally end at any length, and the last element can optionally repeat without limit.
/// </summary>
/// <param name="sequence">Enumerable of distributions over sequence elements.</param>
/// <param name="elementDistributions">Enumerable of distributions over sequence elements and the transition weights.</param>
/// <param name="allowEarlyEnd">Allow the sequence to end at any point.</param>
/// <param name="repeatLastElement">Repeat the last element.</param>
/// <returns>The created distribution.</returns>
public static TThis Concatenate(IEnumerable<TElementDistribution> sequence)
public static TThis Concatenate(IEnumerable<TElementDistribution> elementDistributions, bool allowEarlyEnd = false, bool repeatLastElement = false)
{
var result = new Automaton<TSequence, TElement, TElementDistribution, TSequenceManipulator, TWeightFunction>.Builder();
var last = result.Start;
foreach (var elem in sequence)
var elementDistributionArray = elementDistributions.ToArray();
for (var i = 0; i < elementDistributionArray.Length - 1; i++)
{
last = last.AddTransition(elem, Weight.One);
last = last.AddTransition(elementDistributionArray[i], Weight.One);
if (allowEarlyEnd)
{
last.SetEndWeight(Weight.One);
}
}
var lastElement = elementDistributionArray[elementDistributionArray.Length - 1];
if (repeatLastElement)
{
last.AddSelfTransition(lastElement, Weight.One);
}
else
{
last = last.AddTransition(lastElement, Weight.One);
}
last.SetEndWeight(Weight.One);
@ -1602,6 +1620,23 @@ namespace Microsoft.ML.Probabilistic.Distributions
return !this.IsPointMass && this.sequenceToWeight.IsZero();
}
/// <summary>
/// Converges an improper sequence distribution
/// </summary>
/// <param name="dist">The original distribution.</param>
/// <param name="decayWeight">The decay weight.</param>
/// <returns>The converged distribution.</returns>
public static TThis Converge(TThis dist, double decayWeight = 0.99)
{
var converger =
Automaton<TSequence, TElement, TElementDistribution, TSequenceManipulator, TWeightFunction>
.GetConverger(new TWeightFunction[]
{
dist.sequenceToWeight
}, decayWeight);
return dist.Product(FromWorkspace(converger));
}
/// <summary>
/// Checks if <paramref name="obj"/> equals to this distribution (i.e. represents the same distribution over sequences).
/// </summary>

Просмотреть файл

@ -2,7 +2,9 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Reflection.Metadata.Ecma335;
using Microsoft.ML.Probabilistic.Math;
using Microsoft.ML.Probabilistic.Utilities;
namespace Microsoft.ML.Probabilistic.Tests
{
@ -10,13 +12,15 @@ namespace Microsoft.ML.Probabilistic.Tests
using System.Collections.Generic;
using Xunit;
using Microsoft.ML.Probabilistic.Distributions;
using Assert = Xunit.Assert;
using Assert = Microsoft.ML.Probabilistic.Tests.AssertHelper;
/// <summary>
/// Tests for <see cref="DiscreteChar"/>.
/// </summary>
public class DiscreteCharTest
{
const double Eps = 1e-10;
/// <summary>
/// Runs a set of common distribution tests for <see cref="DiscreteChar"/>.
/// </summary>
@ -63,7 +67,7 @@ namespace Microsoft.ML.Probabilistic.Tests
return;
}
Assert.True(false);
Xunit.Assert.True(false);
}
[Fact]
@ -88,7 +92,7 @@ namespace Microsoft.ML.Probabilistic.Tests
var unif = Vector.Constant(numChars, 1.0 / numChars);
var maxDiff = hist.MaxDiff(unif);
Assert.True(maxDiff < 0.01);
Xunit.Assert.True(maxDiff < 0.01);
}
[Fact]
@ -106,7 +110,7 @@ namespace Microsoft.ML.Probabilistic.Tests
ab.SetToSum(1, a, 2, b);
// 2 subsequent ranges
Assert.Equal(2, ab.Ranges.Count);
Xunit.Assert.Equal(2, ab.Ranges.Count);
TestComplement(ab);
void TestComplement(DiscreteChar dist)
@ -117,21 +121,123 @@ namespace Microsoft.ML.Probabilistic.Tests
var complement = dist.Complement();
// complement should always be partial uniform
Assert.True(complement.IsPartialUniform());
Xunit.Assert.True(complement.IsPartialUniform());
// overlap is zero
Assert.True(double.IsNegativeInfinity(dist.GetLogAverageOf(complement)));
Assert.True(double.IsNegativeInfinity(uniformDist.GetLogAverageOf(complement)));
Xunit.Assert.True(double.IsNegativeInfinity(dist.GetLogAverageOf(complement)));
Xunit.Assert.True(double.IsNegativeInfinity(uniformDist.GetLogAverageOf(complement)));
// union is covers the whole range
var sum = default(DiscreteChar);
sum.SetToSum(1, dist, 1, complement);
sum.SetToPartialUniform();
Assert.True(sum.IsUniform());
Xunit.Assert.True(sum.IsUniform());
// Doing complement again will cover the same set of characters
var complement2 = complement.Complement();
Assert.Equal(uniformDist, complement2);
Xunit.Assert.Equal(uniformDist, complement2);
}
}
[Fact]
public void PartialUniformWithLogProbabilityOverride()
{
var dist = DiscreteChar.LetterOrDigit();
var probLetter = Math.Exp(dist.GetLogProb('j'));
var probNumber = Math.Exp(dist.GetLogProb('5'));
var logProbabilityOverride = Math.Log(0.7);
var scaledDist = DiscreteChar.Uniform();
scaledDist.SetToPartialUniformOf(dist, logProbabilityOverride);
var scaledLogProbLetter = scaledDist.GetLogProb('j');
var scaledLogProbNumber = scaledDist.GetLogProb('5');
Assert.Equal(scaledLogProbLetter, logProbabilityOverride, Eps);
Assert.Equal(scaledLogProbNumber, logProbabilityOverride, Eps);
// Check that cache has not been compromised.
Assert.Equal(probLetter, Math.Exp(dist.GetLogProb('j')), Eps);
Assert.Equal(probNumber, Math.Exp(dist.GetLogProb('5')), Eps);
// Check that an exception is thrown if a bad maximumProbability is passed down.
Xunit.Assert.Throws<ArgumentException>(() =>
{
var badDist = DiscreteChar.Uniform();
badDist.SetToPartialUniformOf(dist, Math.Log(1.2));
});
}
[Fact]
public void BroadAndNarrow()
{
var dist1 = DiscreteChar.Digit();
Xunit.Assert.True(dist1.IsBroad);
var dist2 = DiscreteChar.OneOf('1', '3', '5', '6');
Xunit.Assert.False(dist2.IsBroad);
}
[Fact]
public void HasLogOverride()
{
var dist1 = DiscreteChar.LetterOrDigit();
Xunit.Assert.False(dist1.HasLogProbabilityOverride);
dist1.SetToPartialUniformOf(dist1, Math.Log(0.9));
Xunit.Assert.True(dist1.HasLogProbabilityOverride);
}
[Fact]
public void ProductWithLogOverrideBroad()
{
for (var i = 0; i < 2; i++)
{
var dist1 = DiscreteChar.LetterOrDigit();
var dist2 = DiscreteChar.Digit();
var logOverrideProbability = Math.Log(0.9);
dist1.SetToPartialUniformOf(dist1, logOverrideProbability);
Xunit.Assert.True(dist1.HasLogProbabilityOverride);
Xunit.Assert.True(dist2.IsBroad);
if (i == 1)
{
Util.Swap(ref dist1, ref dist2);
}
var dist3 = DiscreteChar.Uniform();
dist3.SetToProduct(dist1, dist2);
Xunit.Assert.True(dist3.HasLogProbabilityOverride);
Assert.Equal(logOverrideProbability, dist3.GetLogProb('5'), Eps);
Xunit.Assert.True(double.IsNegativeInfinity(dist3.GetLogProb('a')));
}
}
[Fact]
public void ProductWithLogOverrideNarrow()
{
for (var i = 0; i < 2; i++)
{
var dist1 = DiscreteChar.LetterOrDigit();
var dist2 = DiscreteChar.OneOf('1', '3', '5', '6');
var logOverrideProbability = Math.Log(0.9);
dist1.SetToPartialUniformOf(dist1, logOverrideProbability);
Xunit.Assert.True(dist1.HasLogProbabilityOverride);
Xunit.Assert.False(dist2.IsBroad);
if (i == 1)
{
Util.Swap(ref dist1, ref dist2);
}
var dist3 = DiscreteChar.Uniform();
dist3.SetToProduct(dist1, dist2);
Xunit.Assert.False(dist3.HasLogProbabilityOverride);
Assert.Equal(Math.Log(0.25), dist3.GetLogProb('5'), Eps);
Xunit.Assert.True(double.IsNegativeInfinity(dist3.GetLogProb('a')));
}
}
@ -152,12 +258,12 @@ namespace Microsoft.ML.Probabilistic.Tests
foreach (var ch in included)
{
Assert.True(!double.IsNegativeInfinity(distribution.GetLogProb(ch)), distribution + " should contain " + ch);
Xunit.Assert.True(!double.IsNegativeInfinity(distribution.GetLogProb(ch)), distribution + " should contain " + ch);
}
foreach (var ch in excluded)
{
Assert.True(double.IsNegativeInfinity(distribution.GetLogProb(ch)), distribution + " should not contain " + ch);
Xunit.Assert.True(double.IsNegativeInfinity(distribution.GetLogProb(ch)), distribution + " should not contain " + ch);
}
}
}

Просмотреть файл

@ -2,6 +2,8 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Text.RegularExpressions;
namespace Microsoft.ML.Probabilistic.Tests
{
using System;
@ -543,6 +545,130 @@ namespace Microsoft.ML.Probabilistic.Tests
});
}
[Fact]
public void WordModel()
{
// We want to build a word model as a reasonably simple StringDistribution. It
// should satisfy the following:
// (1) The probability of a word of moderate length should not be
// significantly less than the probability of a shorter word.
// (2) The probability of a specific word conditioned on its length matches that of
// words in the target language.
// We achieve this by putting non-normalized character distributions on the edges. The
// StringDistribution is unaware that these are non-normalized.
// The StringDistribution itself is non-normalizable.
const double TargetProb1 = 0.05;
const double Ratio1 = 0.4;
const double TargetProb2 = TargetProb1 * Ratio1;
const double Ratio2 = 0.2;
const double TargetProb3 = TargetProb2 * Ratio2;
const double TargetProb4 = TargetProb3 * Ratio2;
const double TargetProb5 = TargetProb4 * Ratio2;
const double Ratio3 = 0.999;
const double TargetProb6 = TargetProb5 * Ratio3;
const double TargetProb7 = TargetProb6 * Ratio3;
const double TargetProb8 = TargetProb7 * Ratio3;
const double Ratio4 = 0.9;
const double TargetProb9 = TargetProb8 * Ratio4;
const double TargetProb10 = TargetProb9 * Ratio4;
var targetProbabilitiesPerLength = new double[]
{
TargetProb1, TargetProb2, TargetProb3, TargetProb4, TargetProb5, TargetProb6, TargetProb7, TargetProb8, TargetProb9, TargetProb10
};
var charDistUpper = DiscreteChar.Upper();
var charDistLower = DiscreteChar.Lower();
var charDistUpperNarrow = DiscreteChar.OneOf('A', 'B');
var charDistLowerNarrow = DiscreteChar.OneOf('a', 'b');
var charDistUpperScaled = DiscreteChar.Uniform();
var charDistLowerScaled1 = DiscreteChar.Uniform();
var charDistLowerScaled2 = DiscreteChar.Uniform();
var charDistLowerScaled3 = DiscreteChar.Uniform();
var charDistLowerScaledEnd = DiscreteChar.Uniform();
charDistUpperScaled.SetToPartialUniformOf(charDistUpper, Math.Log(TargetProb1));
charDistLowerScaled1.SetToPartialUniformOf(charDistLower, Math.Log(Ratio1));
charDistLowerScaled2.SetToPartialUniformOf(charDistLower, Math.Log(Ratio2));
charDistLowerScaled3.SetToPartialUniformOf(charDistLower, Math.Log(Ratio3));
charDistLowerScaledEnd.SetToPartialUniformOf(charDistLower, Math.Log(Ratio4));
var wordModel = StringDistribution.Concatenate(
new List<DiscreteChar>
{
charDistUpperScaled,
charDistLowerScaled1,
charDistLowerScaled2,
charDistLowerScaled2,
charDistLowerScaled2,
charDistLowerScaled3,
charDistLowerScaled3,
charDistLowerScaled3,
charDistLowerScaledEnd
},
true,
true);
const string Word = "Abcdefghij";
const double Eps = 1e-5;
var broadDist = StringDistribution.Char(charDistUpper);
var narrowDist = StringDistribution.Char(charDistUpperNarrow);
var narrowWord = "A";
var expectedProbForNarrow = 0.5;
for (var i = 0; i < targetProbabilitiesPerLength.Length; i++)
{
var currentWord = Word.Substring(0, i + 1);
var probCurrentWord = Math.Exp(wordModel.GetLogProb(currentWord));
Assert.Equal(targetProbabilitiesPerLength[i], probCurrentWord, Eps);
var logAvg = Math.Exp(wordModel.GetLogAverageOf(broadDist));
Assert.Equal(targetProbabilitiesPerLength[i], logAvg, Eps);
var prod = StringDistribution.Zero();
prod.SetToProduct(broadDist, wordModel);
Xunit.Assert.True(prod.GetWorkspaceOrPoint().HasElementLogValueOverrides);
probCurrentWord = Math.Exp(prod.GetLogProb(currentWord));
Assert.Equal(targetProbabilitiesPerLength[i], probCurrentWord, Eps);
prod.SetToProduct(narrowDist, wordModel);
Xunit.Assert.False(prod.GetWorkspaceOrPoint().HasElementLogValueOverrides);
var probNarrowWord = Math.Exp(prod.GetLogProb(narrowWord));
Assert.Equal(expectedProbForNarrow, probNarrowWord, Eps);
broadDist = broadDist.Append(charDistLower);
narrowDist = narrowDist.Append(charDistLowerNarrow);
narrowWord += "a";
expectedProbForNarrow *= 0.5;
}
// Copied model
var copiedModel = StringDistribution.FromWorkspace(StringTransducer.Copy().ProjectSource(wordModel.GetWorkspaceOrPoint()));
// Under transducer.
for (var i = 0; i < targetProbabilitiesPerLength.Length; i++)
{
var currentWord = Word.Substring(0, i + 1);
var probCurrentWord = Math.Exp(copiedModel.GetLogProb(currentWord));
Assert.Equal(targetProbabilitiesPerLength[i], probCurrentWord, Eps);
}
// Rescaled model
var scale = 0.5;
var newTargetProb1 = TargetProb1 * scale;
var charDistUpperScaled1 = DiscreteChar.Uniform();
charDistUpperScaled1.SetToPartialUniformOf(charDistUpper, Math.Log(newTargetProb1));
var reWeightingTransducer =
StringTransducer.Replace(StringDistribution.Char(charDistUpper).GetWorkspaceOrPoint(), StringDistribution.Char(charDistUpperScaled1).GetWorkspaceOrPoint())
.Append(StringTransducer.Copy());
var reWeightedWordModel = StringDistribution.FromWorkspace(reWeightingTransducer.ProjectSource(wordModel.GetWorkspaceOrPoint()));
for (var i = 0; i < targetProbabilitiesPerLength.Length; i++)
{
var currentWord = Word.Substring(0, i + 1);
var probCurrentWord = Math.Exp(reWeightedWordModel.GetLogProb(currentWord));
Assert.Equal(scale * targetProbabilitiesPerLength[i], probCurrentWord, Eps);
}
}
#region Sampling tests
/// <summary>