зеркало из https://github.com/dotnet/infer.git
Log probability override in DiscreteChar (#206)
Log probability override for Discrete char
This commit is contained in:
Родитель
000818158c
Коммит
053bac1751
|
@ -2255,7 +2255,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
||||||
|
|
||||||
this.Data = builder.GetData();
|
this.Data = builder.GetData();
|
||||||
this.LogValueOverride = automaton.LogValueOverride;
|
this.LogValueOverride = automaton.LogValueOverride;
|
||||||
this.PruneStatesWithLogEndWeightLessThan = automaton.LogValueOverride;
|
this.PruneStatesWithLogEndWeightLessThan = automaton.PruneStatesWithLogEndWeightLessThan;
|
||||||
}
|
}
|
||||||
|
|
||||||
#endregion
|
#endregion
|
||||||
|
@ -2276,7 +2276,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
||||||
return automaton2.IsZero() ? double.NegativeInfinity : 1;
|
return automaton2.IsZero() ? double.NegativeInfinity : 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
TThis theConverger = GetConverger(automaton1, automaton2);
|
TThis theConverger = GetConverger(new TThis[] {automaton1, automaton2});
|
||||||
var automaton1conv = automaton1.Product(theConverger);
|
var automaton1conv = automaton1.Product(theConverger);
|
||||||
var automaton2conv = automaton2.Product(theConverger);
|
var automaton2conv = automaton2.Product(theConverger);
|
||||||
|
|
||||||
|
@ -2310,8 +2310,20 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
||||||
/// Gets an automaton such that every given automaton, if multiplied by it, becomes normalizable.
|
/// Gets an automaton such that every given automaton, if multiplied by it, becomes normalizable.
|
||||||
/// </summary>
|
/// </summary>
|
||||||
/// <param name="automata">The automata.</param>
|
/// <param name="automata">The automata.</param>
|
||||||
|
/// <param name="decayWeight">The decay weight.</param>
|
||||||
/// <returns>An automaton, product with which will make every given automaton normalizable.</returns>
|
/// <returns>An automaton, product with which will make every given automaton normalizable.</returns>
|
||||||
public static TThis GetConverger(params TThis[] automata)
|
public static TThis GetConverger(TThis automata, double decayWeight = 0.99)
|
||||||
|
{
|
||||||
|
return GetConverger(new TThis[] {automata}, decayWeight);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Gets an automaton such that every given automaton, if multiplied by it, becomes normalizable.
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="automata">The automata.</param>
|
||||||
|
/// <param name="decayWeight">The decay weight.</param>
|
||||||
|
/// <returns>An automaton, product with which will make every given automaton normalizable.</returns>
|
||||||
|
public static TThis GetConverger(TThis[] automata, double decayWeight = 0.99)
|
||||||
{
|
{
|
||||||
// TODO: This method might not work in the presense of non-trivial loops.
|
// TODO: This method might not work in the presense of non-trivial loops.
|
||||||
|
|
||||||
|
@ -2347,7 +2359,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
||||||
Weight transitionWeight = Weight.Product(
|
Weight transitionWeight = Weight.Product(
|
||||||
Weight.FromLogValue(-uniformDist.GetLogAverageOf(uniformDist)),
|
Weight.FromLogValue(-uniformDist.GetLogAverageOf(uniformDist)),
|
||||||
Weight.FromLogValue(-maxLogTransitionWeightSum),
|
Weight.FromLogValue(-maxLogTransitionWeightSum),
|
||||||
Weight.FromValue(0.99));
|
Weight.FromValue(decayWeight));
|
||||||
theConverger.Start.AddSelfTransition(uniformDist, transitionWeight);
|
theConverger.Start.AddSelfTransition(uniformDist, transitionWeight);
|
||||||
theConverger.Start.SetEndWeight(Weight.One);
|
theConverger.Start.SetEndWeight(Weight.One);
|
||||||
|
|
||||||
|
@ -2679,11 +2691,11 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
||||||
{
|
{
|
||||||
var propertyMask = new BitVector32();
|
var propertyMask = new BitVector32();
|
||||||
var idx = 0;
|
var idx = 0;
|
||||||
propertyMask[1 << idx++] = true; // isEpsilonFree is alway known
|
propertyMask[1 << idx++] = true; // isEpsilonFree is always known
|
||||||
propertyMask[1 << idx++] = this.Data.IsEpsilonFree;
|
propertyMask[1 << idx++] = this.Data.IsEpsilonFree;
|
||||||
propertyMask[1 << idx++] = this.LogValueOverride.HasValue;
|
propertyMask[1 << idx++] = this.LogValueOverride.HasValue;
|
||||||
propertyMask[1 << idx++] = this.PruneStatesWithLogEndWeightLessThan.HasValue;
|
propertyMask[1 << idx++] = this.PruneStatesWithLogEndWeightLessThan.HasValue;
|
||||||
propertyMask[1 << idx++] = true; // start state is alway serialized
|
propertyMask[1 << idx++] = true; // start state is always serialized
|
||||||
|
|
||||||
writeInt32(propertyMask.Data);
|
writeInt32(propertyMask.Data);
|
||||||
|
|
||||||
|
@ -2711,8 +2723,8 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
||||||
/// Reads an automaton from.
|
/// Reads an automaton from.
|
||||||
/// </summary>
|
/// </summary>
|
||||||
/// <remarks>
|
/// <remarks>
|
||||||
/// Serializtion format is a bit unnatural, but we do it for compatiblity with old serialized data.
|
/// Serialization format is a bit unnatural, but we do it for compatibility with old serialized data.
|
||||||
/// So we don't have to maintain 2 versions of derserialization
|
/// So we don't have to maintain 2 versions of deserialization.
|
||||||
/// </remarks>
|
/// </remarks>
|
||||||
public static TThis Read(Func<double> readDouble, Func<int> readInt32, Func<TElementDistribution> readElementDistribution)
|
public static TThis Read(Func<double> readDouble, Func<int> readInt32, Func<TElementDistribution> readElementDistribution)
|
||||||
{
|
{
|
||||||
|
|
|
@ -10,6 +10,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
||||||
using System.Collections.Generic;
|
using System.Collections.Generic;
|
||||||
using System.Diagnostics;
|
using System.Diagnostics;
|
||||||
using System.IO;
|
using System.IO;
|
||||||
|
using System.Linq;
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Represents a weighted finite state automaton defined on <see cref="string"/>.
|
/// Represents a weighted finite state automaton defined on <see cref="string"/>.
|
||||||
|
@ -20,6 +21,19 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
||||||
public StringAutomaton()
|
public StringAutomaton()
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Whether there are log value overrides at the element level.
|
||||||
|
/// </summary>
|
||||||
|
public bool HasElementLogValueOverrides
|
||||||
|
{
|
||||||
|
get
|
||||||
|
{
|
||||||
|
return this.States.transitions.Any(
|
||||||
|
trans => trans.ElementDistribution.HasValue &&
|
||||||
|
trans.ElementDistribution.Value.HasLogProbabilityOverride);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Computes a set of outgoing transitions from a given state of the determinization result.
|
/// Computes a set of outgoing transitions from a given state of the determinization result.
|
||||||
|
|
|
@ -347,6 +347,8 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
||||||
|
|
||||||
// Populate the stack with start destination state
|
// Populate the stack with start destination state
|
||||||
result.StartStateIndex = CreateDestState(mappingAutomaton.Start, srcAutomaton.Start);
|
result.StartStateIndex = CreateDestState(mappingAutomaton.Start, srcAutomaton.Start);
|
||||||
|
var stringAutomaton = srcAutomaton as StringAutomaton;
|
||||||
|
var sourceDistributionHasLogProbabilityOverrides = stringAutomaton?.HasElementLogValueOverrides ?? false;
|
||||||
|
|
||||||
while (stack.Count > 0)
|
while (stack.Count > 0)
|
||||||
{
|
{
|
||||||
|
@ -387,7 +389,32 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
var destWeight = Weight.Product(mappingTransition.Weight, srcTransition.Weight, Weight.FromLogValue(projectionLogScale));
|
// In the special case of a log probability override in a DiscreteChar element distribution,
|
||||||
|
// we need to compensate for the fact that the distribution is not normalized.
|
||||||
|
if (destElementDistribution.HasValue && sourceDistributionHasLogProbabilityOverrides)
|
||||||
|
{
|
||||||
|
var discreteChar =
|
||||||
|
(DiscreteChar)(IDistribution<char>)srcTransition.ElementDistribution.Value;
|
||||||
|
if (discreteChar.HasLogProbabilityOverride)
|
||||||
|
{
|
||||||
|
var totalMass = discreteChar.Ranges.EnumerableSum(rng =>
|
||||||
|
rng.Probability.Value * (rng.EndExclusive - rng.StartInclusive));
|
||||||
|
projectionLogScale -= System.Math.Log(totalMass);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var destWeight =
|
||||||
|
sourceDistributionHasLogProbabilityOverrides && destElementDistribution.HasNoValue
|
||||||
|
? Weight.One
|
||||||
|
: Weight.Product(mappingTransition.Weight, srcTransition.Weight,
|
||||||
|
Weight.FromLogValue(projectionLogScale));
|
||||||
|
|
||||||
|
// We don't want an unnormalizable distribution to become normalizable due to a rounding error.
|
||||||
|
if (Math.Abs(destWeight.LogValue) < 1e-12)
|
||||||
|
{
|
||||||
|
destWeight = Weight.One;
|
||||||
|
}
|
||||||
|
|
||||||
var childDestStateIndex = CreateDestState(childMappingState, srcChildState);
|
var childDestStateIndex = CreateDestState(childMappingState, srcChildState);
|
||||||
destState.AddTransition(destElementDistribution, destWeight, childDestStateIndex, mappingTransition.Group);
|
destState.AddTransition(destElementDistribution, destWeight, childDestStateIndex, mappingTransition.Group);
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,6 +2,8 @@
|
||||||
// The .NET Foundation licenses this file to you under the MIT license.
|
// The .NET Foundation licenses this file to you under the MIT license.
|
||||||
// See the LICENSE file in the project root for more information.
|
// See the LICENSE file in the project root for more information.
|
||||||
|
|
||||||
|
using System.Collections;
|
||||||
|
|
||||||
namespace Microsoft.ML.Probabilistic.Distributions
|
namespace Microsoft.ML.Probabilistic.Distributions
|
||||||
{
|
{
|
||||||
using System;
|
using System;
|
||||||
|
@ -183,6 +185,22 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
||||||
/// </summary>
|
/// </summary>
|
||||||
public bool IsWordChar => this.Data.IsWordChar;
|
public bool IsWordChar => this.Data.IsWordChar;
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Gets a value indicating whether this distribution is broad - i.e. a general class of values.
|
||||||
|
/// </summary>
|
||||||
|
public bool IsBroad => this.Data.IsBroad;
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Gets a value indicating whether this distribution is partial uniform with a log probability override.
|
||||||
|
/// </summary>
|
||||||
|
public bool HasLogProbabilityOverride => this.Data.HasLogProbabilityOverride;
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Gets a value for the log probability override if it exists.
|
||||||
|
/// </summary>
|
||||||
|
public double? LogProbabilityOverride =>
|
||||||
|
this.HasLogProbabilityOverride ? this.Ranges.First().Probability.LogValue : (double?)null;
|
||||||
|
|
||||||
#endregion
|
#endregion
|
||||||
|
|
||||||
#region Distribution properties
|
#region Distribution properties
|
||||||
|
@ -530,13 +548,34 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
||||||
}
|
}
|
||||||
|
|
||||||
var builder = StorageBuilder.Create();
|
var builder = StorageBuilder.Create();
|
||||||
|
|
||||||
foreach (var pair in CharRangePair.IntersectRanges(distribution1, distribution2))
|
foreach (var pair in CharRangePair.IntersectRanges(distribution1, distribution2))
|
||||||
{
|
{
|
||||||
var probProduct = pair.Probability1 * pair.Probability2;
|
var probProduct = pair.Probability1 * pair.Probability2;
|
||||||
builder.AddRange(new CharRange(pair.StartInclusive, pair.EndExclusive, probProduct));
|
builder.AddRange(new CharRange(pair.StartInclusive, pair.EndExclusive, probProduct));
|
||||||
}
|
}
|
||||||
|
|
||||||
this.Data = builder.GetResult();
|
double? logProbabilityOverride = null;
|
||||||
|
var distribution1LogProbabilityOverride = distribution1.LogProbabilityOverride;
|
||||||
|
var distribution2LogProbabilityOverride = distribution2.LogProbabilityOverride;
|
||||||
|
if (distribution1LogProbabilityOverride.HasValue)
|
||||||
|
{
|
||||||
|
if (distribution2LogProbabilityOverride.HasValue)
|
||||||
|
{
|
||||||
|
throw new ArgumentException("Only one distribution in a DiscreteChar product may have a log probability override");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (distribution2.IsBroad)
|
||||||
|
{
|
||||||
|
logProbabilityOverride = distribution1LogProbabilityOverride;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else if (distribution2LogProbabilityOverride.HasValue && distribution1.IsBroad)
|
||||||
|
{
|
||||||
|
logProbabilityOverride = distribution2LogProbabilityOverride;
|
||||||
|
}
|
||||||
|
|
||||||
|
this.Data = builder.GetResult(logProbabilityOverride);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
|
@ -629,8 +668,23 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Sets the distribution to be uniform over the support of a given distribution.
|
/// Sets the distribution to be uniform over the support of a given distribution.
|
||||||
/// </summary>
|
/// </summary>
|
||||||
/// <param name="distribution">The distribution which support will be used to setup the current distribution.</param>
|
/// <param name="distribution">The distribution whose support will be used to setup the current distribution.</param>
|
||||||
public void SetToPartialUniformOf(DiscreteChar distribution)
|
public void SetToPartialUniformOf(DiscreteChar distribution)
|
||||||
|
{
|
||||||
|
SetToPartialUniformOf(distribution, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Sets the distribution to be uniform over the support of a given distribution.
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="distribution">The distribution whose support will be used to setup the current distribution.</param>
|
||||||
|
/// <param name="logProbabilityOverride">An optional value to override for the log probability calculation
|
||||||
|
/// against this distribution. If this is set, then the distribution will not be normalize;
|
||||||
|
/// i.e. the probabilities will not sum to 1 over the support.</param>
|
||||||
|
/// <remarks>Overriding the log probability calculation in this way is useful within the context of using <see cref="StringDistribution"/>
|
||||||
|
/// to create more realistic language model priors. Distributions with this override are always uncached.
|
||||||
|
/// </remarks>
|
||||||
|
public void SetToPartialUniformOf(DiscreteChar distribution, double? logProbabilityOverride)
|
||||||
{
|
{
|
||||||
var builder = StorageBuilder.Create();
|
var builder = StorageBuilder.Create();
|
||||||
foreach (var range in distribution.Ranges)
|
foreach (var range in distribution.Ranges)
|
||||||
|
@ -642,7 +696,7 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
||||||
range.Probability.IsZero ? Weight.Zero : Weight.One));
|
range.Probability.IsZero ? Weight.Zero : Weight.One));
|
||||||
}
|
}
|
||||||
|
|
||||||
this.Data = builder.GetResult();
|
this.Data = builder.GetResult(logProbabilityOverride);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
|
@ -1398,11 +1452,38 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
||||||
|
|
||||||
public char? Point { get; }
|
public char? Point { get; }
|
||||||
|
|
||||||
// Following 3 members are not immutable and can be recalculated on-demand
|
// Following members are not immutable and can be recalculated on-demand
|
||||||
public CharClasses CharClasses { get; private set; }
|
public CharClasses CharClasses { get; private set; }
|
||||||
|
|
||||||
private string regexRepresentation;
|
private string regexRepresentation;
|
||||||
private string symbolRepresentation;
|
private string symbolRepresentation;
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Flags derived from ranges.
|
||||||
|
/// </summary>
|
||||||
|
/// <returns></returns>
|
||||||
|
private readonly Flags flags;
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Whether this distribution has broad support. We want to be able to
|
||||||
|
/// distinguish between specific distributions and general distributions.
|
||||||
|
/// Use the number of non-zero digits as a threshold.
|
||||||
|
/// </summary>
|
||||||
|
/// <returns></returns>
|
||||||
|
public bool IsBroad => (this.flags & Flags.IsBroad) != 0;
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Whether this distribution is partial uniform with a log probability override.
|
||||||
|
/// </summary>
|
||||||
|
/// <returns></returns>
|
||||||
|
public bool HasLogProbabilityOverride => (this.flags & Flags.HasLogProbabilityOverride) != 0;
|
||||||
|
|
||||||
|
[Flags]
|
||||||
|
private enum Flags
|
||||||
|
{
|
||||||
|
HasLogProbabilityOverride = 0x01,
|
||||||
|
IsBroad = 0x02
|
||||||
|
}
|
||||||
#endregion
|
#endregion
|
||||||
|
|
||||||
#region Constructor and factory methods
|
#region Constructor and factory methods
|
||||||
|
@ -1419,6 +1500,21 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
||||||
this.CharClasses = charClasses;
|
this.CharClasses = charClasses;
|
||||||
this.regexRepresentation = regexRepresentation;
|
this.regexRepresentation = regexRepresentation;
|
||||||
this.symbolRepresentation = symbolRepresentation;
|
this.symbolRepresentation = symbolRepresentation;
|
||||||
|
var supportCount = this.Ranges.Where(range => !range.Probability.IsZero).Sum(range => range.EndExclusive - range.StartInclusive);
|
||||||
|
var isBroad = supportCount >= 9; // Number of non-zero digits.
|
||||||
|
var totalMass = this.Ranges.Sum(range =>
|
||||||
|
range.Probability.Value * (range.EndExclusive - range.StartInclusive));
|
||||||
|
var isScaled = Math.Abs(totalMass - 1.0) > 1e-10;
|
||||||
|
this.flags = 0;
|
||||||
|
if (isBroad)
|
||||||
|
{
|
||||||
|
flags |= Flags.IsBroad;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isScaled)
|
||||||
|
{
|
||||||
|
flags |= Flags.HasLogProbabilityOverride;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public static Storage CreateUncached(
|
public static Storage CreateUncached(
|
||||||
|
@ -1896,7 +1992,7 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
||||||
private readonly List<CharRange> ranges;
|
private readonly List<CharRange> ranges;
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Precomuted character class.
|
/// Precomputed character class.
|
||||||
/// </summary>
|
/// </summary>
|
||||||
private readonly CharClasses charClasses;
|
private readonly CharClasses charClasses;
|
||||||
|
|
||||||
|
@ -1971,15 +2067,23 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Normalizes probabilities in ranges and returns build Storage.
|
/// Normalizes probabilities in ranges and returns build Storage.
|
||||||
/// </summary>
|
/// </summary>
|
||||||
public Storage GetResult()
|
public Storage GetResult(double? maximumProbability = null)
|
||||||
{
|
{
|
||||||
this.MergeNeighboringRanges();
|
this.MergeNeighboringRanges();
|
||||||
this.NormalizeProbabilities();
|
NormalizeProbabilities(this.ranges, maximumProbability);
|
||||||
return Storage.Create(
|
return
|
||||||
this.ranges.ToArray(),
|
maximumProbability.HasValue
|
||||||
this.charClasses,
|
? Storage.CreateUncached(
|
||||||
this.regexRepresentation,
|
this.ranges.ToArray(),
|
||||||
this.symbolRepresentation);
|
null,
|
||||||
|
this.charClasses,
|
||||||
|
this.regexRepresentation,
|
||||||
|
this.symbolRepresentation)
|
||||||
|
: Storage.Create(
|
||||||
|
this.ranges.ToArray(),
|
||||||
|
this.charClasses,
|
||||||
|
this.regexRepresentation,
|
||||||
|
this.symbolRepresentation);
|
||||||
}
|
}
|
||||||
|
|
||||||
#endregion
|
#endregion
|
||||||
|
@ -2015,16 +2119,39 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
||||||
}
|
}
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Normalizes probabilities in ranges
|
/// Normalizes probabilities in ranges.
|
||||||
/// </summary>
|
/// </summary>
|
||||||
private void NormalizeProbabilities()
|
/// <param name="ranges">The ranges.</param>
|
||||||
|
/// <param name="logProbabilityOverride">Ignores the probabilities in the ranges and creates a non-normalized partial uniform distribution.</param>
|
||||||
|
/// <exception cref="ArgumentException">Thrown if logProbabilityOverride has value corresponding to a non-probability.</exception>
|
||||||
|
public static void NormalizeProbabilities(IList<CharRange> ranges, double? logProbabilityOverride = null)
|
||||||
{
|
{
|
||||||
var normalizer = this.ComputeInvNormalizer();
|
if (logProbabilityOverride.HasValue)
|
||||||
for (int i = 0; i < this.ranges.Count; ++i)
|
|
||||||
{
|
{
|
||||||
var range = this.ranges[i];
|
var weight = Weight.FromLogValue(logProbabilityOverride.Value);
|
||||||
this.ranges[i] = new CharRange(
|
if (weight.IsZero || weight.Value > 1)
|
||||||
range.StartInclusive, range.EndExclusive, range.Probability * normalizer);
|
{
|
||||||
|
throw new ArgumentException("Invalid log probability override.");
|
||||||
|
}
|
||||||
|
|
||||||
|
for (var i = 0; i < ranges.Count; ++i)
|
||||||
|
{
|
||||||
|
var range = ranges[i];
|
||||||
|
ranges[i] = new CharRange(
|
||||||
|
range.StartInclusive, range.EndExclusive, weight);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
var normalizer = ComputeInvNormalizer(ranges);
|
||||||
|
for (var i = 0; i < ranges.Count; ++i)
|
||||||
|
{
|
||||||
|
var range = ranges[i];
|
||||||
|
var probability = range.Probability * normalizer;
|
||||||
|
|
||||||
|
ranges[i] = new CharRange(
|
||||||
|
range.StartInclusive, range.EndExclusive, probability);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2032,11 +2159,11 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
||||||
/// Computes the normalizer of this distribution.
|
/// Computes the normalizer of this distribution.
|
||||||
/// </summary>
|
/// </summary>
|
||||||
/// <returns>The computed normalizer.</returns>
|
/// <returns>The computed normalizer.</returns>
|
||||||
private Weight ComputeInvNormalizer()
|
private static Weight ComputeInvNormalizer(IEnumerable<CharRange> ranges)
|
||||||
{
|
{
|
||||||
Weight normalizer = Weight.Zero;
|
var normalizer = Weight.Zero;
|
||||||
|
|
||||||
foreach (var range in this.ranges)
|
foreach (var range in ranges)
|
||||||
{
|
{
|
||||||
normalizer += Weight.FromValue(range.EndExclusive - range.StartInclusive) * range.Probability;
|
normalizer += Weight.FromValue(range.EndExclusive - range.StartInclusive) * range.Probability;
|
||||||
}
|
}
|
||||||
|
|
|
@ -244,19 +244,37 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
||||||
}
|
}
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Creates a distribution over sequences induced by a given list of distributions over sequence elements.
|
/// Creates a distribution over sequences induced by a given list of distributions over sequence elements
|
||||||
|
/// where the sequence can optionally end at any length, and the last element can optionally repeat without limit.
|
||||||
/// </summary>
|
/// </summary>
|
||||||
/// <param name="sequence">Enumerable of distributions over sequence elements.</param>
|
/// <param name="elementDistributions">Enumerable of distributions over sequence elements and the transition weights.</param>
|
||||||
|
/// <param name="allowEarlyEnd">Allow the sequence to end at any point.</param>
|
||||||
|
/// <param name="repeatLastElement">Repeat the last element.</param>
|
||||||
/// <returns>The created distribution.</returns>
|
/// <returns>The created distribution.</returns>
|
||||||
public static TThis Concatenate(IEnumerable<TElementDistribution> sequence)
|
public static TThis Concatenate(IEnumerable<TElementDistribution> elementDistributions, bool allowEarlyEnd = false, bool repeatLastElement = false)
|
||||||
{
|
{
|
||||||
var result = new Automaton<TSequence, TElement, TElementDistribution, TSequenceManipulator, TWeightFunction>.Builder();
|
var result = new Automaton<TSequence, TElement, TElementDistribution, TSequenceManipulator, TWeightFunction>.Builder();
|
||||||
var last = result.Start;
|
var last = result.Start;
|
||||||
foreach (var elem in sequence)
|
var elementDistributionArray = elementDistributions.ToArray();
|
||||||
|
for (var i = 0; i < elementDistributionArray.Length - 1; i++)
|
||||||
{
|
{
|
||||||
last = last.AddTransition(elem, Weight.One);
|
last = last.AddTransition(elementDistributionArray[i], Weight.One);
|
||||||
|
if (allowEarlyEnd)
|
||||||
|
{
|
||||||
|
last.SetEndWeight(Weight.One);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var lastElement = elementDistributionArray[elementDistributionArray.Length - 1];
|
||||||
|
if (repeatLastElement)
|
||||||
|
{
|
||||||
|
last.AddSelfTransition(lastElement, Weight.One);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
last = last.AddTransition(lastElement, Weight.One);
|
||||||
|
}
|
||||||
|
|
||||||
last.SetEndWeight(Weight.One);
|
last.SetEndWeight(Weight.One);
|
||||||
return FromWorkspace(result.GetAutomaton());
|
return FromWorkspace(result.GetAutomaton());
|
||||||
}
|
}
|
||||||
|
@ -1602,6 +1620,23 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
||||||
return !this.IsPointMass && this.sequenceToWeight.IsZero();
|
return !this.IsPointMass && this.sequenceToWeight.IsZero();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Converges an improper sequence distribution
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="dist">The original distribution.</param>
|
||||||
|
/// <param name="decayWeight">The decay weight.</param>
|
||||||
|
/// <returns>The converged distribution.</returns>
|
||||||
|
public static TThis Converge(TThis dist, double decayWeight = 0.99)
|
||||||
|
{
|
||||||
|
var converger =
|
||||||
|
Automaton<TSequence, TElement, TElementDistribution, TSequenceManipulator, TWeightFunction>
|
||||||
|
.GetConverger(new TWeightFunction[]
|
||||||
|
{
|
||||||
|
dist.sequenceToWeight
|
||||||
|
}, decayWeight);
|
||||||
|
return dist.Product(FromWorkspace(converger));
|
||||||
|
}
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Checks if <paramref name="obj"/> equals to this distribution (i.e. represents the same distribution over sequences).
|
/// Checks if <paramref name="obj"/> equals to this distribution (i.e. represents the same distribution over sequences).
|
||||||
/// </summary>
|
/// </summary>
|
||||||
|
|
|
@ -2,7 +2,9 @@
|
||||||
// The .NET Foundation licenses this file to you under the MIT license.
|
// The .NET Foundation licenses this file to you under the MIT license.
|
||||||
// See the LICENSE file in the project root for more information.
|
// See the LICENSE file in the project root for more information.
|
||||||
|
|
||||||
|
using System.Reflection.Metadata.Ecma335;
|
||||||
using Microsoft.ML.Probabilistic.Math;
|
using Microsoft.ML.Probabilistic.Math;
|
||||||
|
using Microsoft.ML.Probabilistic.Utilities;
|
||||||
|
|
||||||
namespace Microsoft.ML.Probabilistic.Tests
|
namespace Microsoft.ML.Probabilistic.Tests
|
||||||
{
|
{
|
||||||
|
@ -10,13 +12,15 @@ namespace Microsoft.ML.Probabilistic.Tests
|
||||||
using System.Collections.Generic;
|
using System.Collections.Generic;
|
||||||
using Xunit;
|
using Xunit;
|
||||||
using Microsoft.ML.Probabilistic.Distributions;
|
using Microsoft.ML.Probabilistic.Distributions;
|
||||||
using Assert = Xunit.Assert;
|
using Assert = Microsoft.ML.Probabilistic.Tests.AssertHelper;
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Tests for <see cref="DiscreteChar"/>.
|
/// Tests for <see cref="DiscreteChar"/>.
|
||||||
/// </summary>
|
/// </summary>
|
||||||
public class DiscreteCharTest
|
public class DiscreteCharTest
|
||||||
{
|
{
|
||||||
|
const double Eps = 1e-10;
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Runs a set of common distribution tests for <see cref="DiscreteChar"/>.
|
/// Runs a set of common distribution tests for <see cref="DiscreteChar"/>.
|
||||||
/// </summary>
|
/// </summary>
|
||||||
|
@ -63,7 +67,7 @@ namespace Microsoft.ML.Probabilistic.Tests
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
Assert.True(false);
|
Xunit.Assert.True(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
[Fact]
|
[Fact]
|
||||||
|
@ -88,7 +92,7 @@ namespace Microsoft.ML.Probabilistic.Tests
|
||||||
var unif = Vector.Constant(numChars, 1.0 / numChars);
|
var unif = Vector.Constant(numChars, 1.0 / numChars);
|
||||||
var maxDiff = hist.MaxDiff(unif);
|
var maxDiff = hist.MaxDiff(unif);
|
||||||
|
|
||||||
Assert.True(maxDiff < 0.01);
|
Xunit.Assert.True(maxDiff < 0.01);
|
||||||
}
|
}
|
||||||
|
|
||||||
[Fact]
|
[Fact]
|
||||||
|
@ -106,7 +110,7 @@ namespace Microsoft.ML.Probabilistic.Tests
|
||||||
ab.SetToSum(1, a, 2, b);
|
ab.SetToSum(1, a, 2, b);
|
||||||
|
|
||||||
// 2 subsequent ranges
|
// 2 subsequent ranges
|
||||||
Assert.Equal(2, ab.Ranges.Count);
|
Xunit.Assert.Equal(2, ab.Ranges.Count);
|
||||||
TestComplement(ab);
|
TestComplement(ab);
|
||||||
|
|
||||||
void TestComplement(DiscreteChar dist)
|
void TestComplement(DiscreteChar dist)
|
||||||
|
@ -117,21 +121,123 @@ namespace Microsoft.ML.Probabilistic.Tests
|
||||||
var complement = dist.Complement();
|
var complement = dist.Complement();
|
||||||
|
|
||||||
// complement should always be partial uniform
|
// complement should always be partial uniform
|
||||||
Assert.True(complement.IsPartialUniform());
|
Xunit.Assert.True(complement.IsPartialUniform());
|
||||||
|
|
||||||
// overlap is zero
|
// overlap is zero
|
||||||
Assert.True(double.IsNegativeInfinity(dist.GetLogAverageOf(complement)));
|
Xunit.Assert.True(double.IsNegativeInfinity(dist.GetLogAverageOf(complement)));
|
||||||
Assert.True(double.IsNegativeInfinity(uniformDist.GetLogAverageOf(complement)));
|
Xunit.Assert.True(double.IsNegativeInfinity(uniformDist.GetLogAverageOf(complement)));
|
||||||
|
|
||||||
// union is covers the whole range
|
// union is covers the whole range
|
||||||
var sum = default(DiscreteChar);
|
var sum = default(DiscreteChar);
|
||||||
sum.SetToSum(1, dist, 1, complement);
|
sum.SetToSum(1, dist, 1, complement);
|
||||||
sum.SetToPartialUniform();
|
sum.SetToPartialUniform();
|
||||||
Assert.True(sum.IsUniform());
|
Xunit.Assert.True(sum.IsUniform());
|
||||||
|
|
||||||
// Doing complement again will cover the same set of characters
|
// Doing complement again will cover the same set of characters
|
||||||
var complement2 = complement.Complement();
|
var complement2 = complement.Complement();
|
||||||
Assert.Equal(uniformDist, complement2);
|
Xunit.Assert.Equal(uniformDist, complement2);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void PartialUniformWithLogProbabilityOverride()
|
||||||
|
{
|
||||||
|
var dist = DiscreteChar.LetterOrDigit();
|
||||||
|
var probLetter = Math.Exp(dist.GetLogProb('j'));
|
||||||
|
var probNumber = Math.Exp(dist.GetLogProb('5'));
|
||||||
|
|
||||||
|
var logProbabilityOverride = Math.Log(0.7);
|
||||||
|
var scaledDist = DiscreteChar.Uniform();
|
||||||
|
scaledDist.SetToPartialUniformOf(dist, logProbabilityOverride);
|
||||||
|
var scaledLogProbLetter = scaledDist.GetLogProb('j');
|
||||||
|
var scaledLogProbNumber = scaledDist.GetLogProb('5');
|
||||||
|
|
||||||
|
Assert.Equal(scaledLogProbLetter, logProbabilityOverride, Eps);
|
||||||
|
Assert.Equal(scaledLogProbNumber, logProbabilityOverride, Eps);
|
||||||
|
|
||||||
|
// Check that cache has not been compromised.
|
||||||
|
Assert.Equal(probLetter, Math.Exp(dist.GetLogProb('j')), Eps);
|
||||||
|
Assert.Equal(probNumber, Math.Exp(dist.GetLogProb('5')), Eps);
|
||||||
|
|
||||||
|
// Check that an exception is thrown if a bad maximumProbability is passed down.
|
||||||
|
Xunit.Assert.Throws<ArgumentException>(() =>
|
||||||
|
{
|
||||||
|
var badDist = DiscreteChar.Uniform();
|
||||||
|
badDist.SetToPartialUniformOf(dist, Math.Log(1.2));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void BroadAndNarrow()
|
||||||
|
{
|
||||||
|
var dist1 = DiscreteChar.Digit();
|
||||||
|
Xunit.Assert.True(dist1.IsBroad);
|
||||||
|
|
||||||
|
var dist2 = DiscreteChar.OneOf('1', '3', '5', '6');
|
||||||
|
Xunit.Assert.False(dist2.IsBroad);
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void HasLogOverride()
|
||||||
|
{
|
||||||
|
var dist1 = DiscreteChar.LetterOrDigit();
|
||||||
|
Xunit.Assert.False(dist1.HasLogProbabilityOverride);
|
||||||
|
|
||||||
|
dist1.SetToPartialUniformOf(dist1, Math.Log(0.9));
|
||||||
|
Xunit.Assert.True(dist1.HasLogProbabilityOverride);
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void ProductWithLogOverrideBroad()
|
||||||
|
{
|
||||||
|
for (var i = 0; i < 2; i++)
|
||||||
|
{
|
||||||
|
var dist1 = DiscreteChar.LetterOrDigit();
|
||||||
|
var dist2 = DiscreteChar.Digit();
|
||||||
|
|
||||||
|
var logOverrideProbability = Math.Log(0.9);
|
||||||
|
dist1.SetToPartialUniformOf(dist1, logOverrideProbability);
|
||||||
|
Xunit.Assert.True(dist1.HasLogProbabilityOverride);
|
||||||
|
Xunit.Assert.True(dist2.IsBroad);
|
||||||
|
|
||||||
|
if (i == 1)
|
||||||
|
{
|
||||||
|
Util.Swap(ref dist1, ref dist2);
|
||||||
|
}
|
||||||
|
|
||||||
|
var dist3 = DiscreteChar.Uniform();
|
||||||
|
dist3.SetToProduct(dist1, dist2);
|
||||||
|
|
||||||
|
Xunit.Assert.True(dist3.HasLogProbabilityOverride);
|
||||||
|
Assert.Equal(logOverrideProbability, dist3.GetLogProb('5'), Eps);
|
||||||
|
Xunit.Assert.True(double.IsNegativeInfinity(dist3.GetLogProb('a')));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void ProductWithLogOverrideNarrow()
|
||||||
|
{
|
||||||
|
for (var i = 0; i < 2; i++)
|
||||||
|
{
|
||||||
|
var dist1 = DiscreteChar.LetterOrDigit();
|
||||||
|
var dist2 = DiscreteChar.OneOf('1', '3', '5', '6');
|
||||||
|
|
||||||
|
var logOverrideProbability = Math.Log(0.9);
|
||||||
|
dist1.SetToPartialUniformOf(dist1, logOverrideProbability);
|
||||||
|
Xunit.Assert.True(dist1.HasLogProbabilityOverride);
|
||||||
|
Xunit.Assert.False(dist2.IsBroad);
|
||||||
|
|
||||||
|
if (i == 1)
|
||||||
|
{
|
||||||
|
Util.Swap(ref dist1, ref dist2);
|
||||||
|
}
|
||||||
|
|
||||||
|
var dist3 = DiscreteChar.Uniform();
|
||||||
|
dist3.SetToProduct(dist1, dist2);
|
||||||
|
|
||||||
|
Xunit.Assert.False(dist3.HasLogProbabilityOverride);
|
||||||
|
Assert.Equal(Math.Log(0.25), dist3.GetLogProb('5'), Eps);
|
||||||
|
Xunit.Assert.True(double.IsNegativeInfinity(dist3.GetLogProb('a')));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -152,12 +258,12 @@ namespace Microsoft.ML.Probabilistic.Tests
|
||||||
|
|
||||||
foreach (var ch in included)
|
foreach (var ch in included)
|
||||||
{
|
{
|
||||||
Assert.True(!double.IsNegativeInfinity(distribution.GetLogProb(ch)), distribution + " should contain " + ch);
|
Xunit.Assert.True(!double.IsNegativeInfinity(distribution.GetLogProb(ch)), distribution + " should contain " + ch);
|
||||||
}
|
}
|
||||||
|
|
||||||
foreach (var ch in excluded)
|
foreach (var ch in excluded)
|
||||||
{
|
{
|
||||||
Assert.True(double.IsNegativeInfinity(distribution.GetLogProb(ch)), distribution + " should not contain " + ch);
|
Xunit.Assert.True(double.IsNegativeInfinity(distribution.GetLogProb(ch)), distribution + " should not contain " + ch);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,6 +2,8 @@
|
||||||
// The .NET Foundation licenses this file to you under the MIT license.
|
// The .NET Foundation licenses this file to you under the MIT license.
|
||||||
// See the LICENSE file in the project root for more information.
|
// See the LICENSE file in the project root for more information.
|
||||||
|
|
||||||
|
using System.Text.RegularExpressions;
|
||||||
|
|
||||||
namespace Microsoft.ML.Probabilistic.Tests
|
namespace Microsoft.ML.Probabilistic.Tests
|
||||||
{
|
{
|
||||||
using System;
|
using System;
|
||||||
|
@ -543,6 +545,130 @@ namespace Microsoft.ML.Probabilistic.Tests
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void WordModel()
|
||||||
|
{
|
||||||
|
// We want to build a word model as a reasonably simple StringDistribution. It
|
||||||
|
// should satisfy the following:
|
||||||
|
// (1) The probability of a word of moderate length should not be
|
||||||
|
// significantly less than the probability of a shorter word.
|
||||||
|
// (2) The probability of a specific word conditioned on its length matches that of
|
||||||
|
// words in the target language.
|
||||||
|
// We achieve this by putting non-normalized character distributions on the edges. The
|
||||||
|
// StringDistribution is unaware that these are non-normalized.
|
||||||
|
// The StringDistribution itself is non-normalizable.
|
||||||
|
const double TargetProb1 = 0.05;
|
||||||
|
const double Ratio1 = 0.4;
|
||||||
|
const double TargetProb2 = TargetProb1 * Ratio1;
|
||||||
|
const double Ratio2 = 0.2;
|
||||||
|
const double TargetProb3 = TargetProb2 * Ratio2;
|
||||||
|
const double TargetProb4 = TargetProb3 * Ratio2;
|
||||||
|
const double TargetProb5 = TargetProb4 * Ratio2;
|
||||||
|
const double Ratio3 = 0.999;
|
||||||
|
const double TargetProb6 = TargetProb5 * Ratio3;
|
||||||
|
const double TargetProb7 = TargetProb6 * Ratio3;
|
||||||
|
const double TargetProb8 = TargetProb7 * Ratio3;
|
||||||
|
const double Ratio4 = 0.9;
|
||||||
|
const double TargetProb9 = TargetProb8 * Ratio4;
|
||||||
|
const double TargetProb10 = TargetProb9 * Ratio4;
|
||||||
|
|
||||||
|
var targetProbabilitiesPerLength = new double[]
|
||||||
|
{
|
||||||
|
TargetProb1, TargetProb2, TargetProb3, TargetProb4, TargetProb5, TargetProb6, TargetProb7, TargetProb8, TargetProb9, TargetProb10
|
||||||
|
};
|
||||||
|
|
||||||
|
var charDistUpper = DiscreteChar.Upper();
|
||||||
|
var charDistLower = DiscreteChar.Lower();
|
||||||
|
var charDistUpperNarrow = DiscreteChar.OneOf('A', 'B');
|
||||||
|
var charDistLowerNarrow = DiscreteChar.OneOf('a', 'b');
|
||||||
|
var charDistUpperScaled = DiscreteChar.Uniform();
|
||||||
|
var charDistLowerScaled1 = DiscreteChar.Uniform();
|
||||||
|
var charDistLowerScaled2 = DiscreteChar.Uniform();
|
||||||
|
var charDistLowerScaled3 = DiscreteChar.Uniform();
|
||||||
|
var charDistLowerScaledEnd = DiscreteChar.Uniform();
|
||||||
|
|
||||||
|
charDistUpperScaled.SetToPartialUniformOf(charDistUpper, Math.Log(TargetProb1));
|
||||||
|
charDistLowerScaled1.SetToPartialUniformOf(charDistLower, Math.Log(Ratio1));
|
||||||
|
charDistLowerScaled2.SetToPartialUniformOf(charDistLower, Math.Log(Ratio2));
|
||||||
|
charDistLowerScaled3.SetToPartialUniformOf(charDistLower, Math.Log(Ratio3));
|
||||||
|
charDistLowerScaledEnd.SetToPartialUniformOf(charDistLower, Math.Log(Ratio4));
|
||||||
|
|
||||||
|
var wordModel = StringDistribution.Concatenate(
|
||||||
|
new List<DiscreteChar>
|
||||||
|
{
|
||||||
|
charDistUpperScaled,
|
||||||
|
charDistLowerScaled1,
|
||||||
|
charDistLowerScaled2,
|
||||||
|
charDistLowerScaled2,
|
||||||
|
charDistLowerScaled2,
|
||||||
|
charDistLowerScaled3,
|
||||||
|
charDistLowerScaled3,
|
||||||
|
charDistLowerScaled3,
|
||||||
|
charDistLowerScaledEnd
|
||||||
|
},
|
||||||
|
true,
|
||||||
|
true);
|
||||||
|
|
||||||
|
const string Word = "Abcdefghij";
|
||||||
|
|
||||||
|
const double Eps = 1e-5;
|
||||||
|
var broadDist = StringDistribution.Char(charDistUpper);
|
||||||
|
var narrowDist = StringDistribution.Char(charDistUpperNarrow);
|
||||||
|
var narrowWord = "A";
|
||||||
|
var expectedProbForNarrow = 0.5;
|
||||||
|
for (var i = 0; i < targetProbabilitiesPerLength.Length; i++)
|
||||||
|
{
|
||||||
|
var currentWord = Word.Substring(0, i + 1);
|
||||||
|
var probCurrentWord = Math.Exp(wordModel.GetLogProb(currentWord));
|
||||||
|
Assert.Equal(targetProbabilitiesPerLength[i], probCurrentWord, Eps);
|
||||||
|
|
||||||
|
var logAvg = Math.Exp(wordModel.GetLogAverageOf(broadDist));
|
||||||
|
Assert.Equal(targetProbabilitiesPerLength[i], logAvg, Eps);
|
||||||
|
|
||||||
|
var prod = StringDistribution.Zero();
|
||||||
|
prod.SetToProduct(broadDist, wordModel);
|
||||||
|
Xunit.Assert.True(prod.GetWorkspaceOrPoint().HasElementLogValueOverrides);
|
||||||
|
probCurrentWord = Math.Exp(prod.GetLogProb(currentWord));
|
||||||
|
Assert.Equal(targetProbabilitiesPerLength[i], probCurrentWord, Eps);
|
||||||
|
|
||||||
|
prod.SetToProduct(narrowDist, wordModel);
|
||||||
|
Xunit.Assert.False(prod.GetWorkspaceOrPoint().HasElementLogValueOverrides);
|
||||||
|
var probNarrowWord = Math.Exp(prod.GetLogProb(narrowWord));
|
||||||
|
Assert.Equal(expectedProbForNarrow, probNarrowWord, Eps);
|
||||||
|
|
||||||
|
broadDist = broadDist.Append(charDistLower);
|
||||||
|
narrowDist = narrowDist.Append(charDistLowerNarrow);
|
||||||
|
narrowWord += "a";
|
||||||
|
expectedProbForNarrow *= 0.5;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copied model
|
||||||
|
var copiedModel = StringDistribution.FromWorkspace(StringTransducer.Copy().ProjectSource(wordModel.GetWorkspaceOrPoint()));
|
||||||
|
// Under transducer.
|
||||||
|
for (var i = 0; i < targetProbabilitiesPerLength.Length; i++)
|
||||||
|
{
|
||||||
|
var currentWord = Word.Substring(0, i + 1);
|
||||||
|
var probCurrentWord = Math.Exp(copiedModel.GetLogProb(currentWord));
|
||||||
|
Assert.Equal(targetProbabilitiesPerLength[i], probCurrentWord, Eps);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rescaled model
|
||||||
|
var scale = 0.5;
|
||||||
|
var newTargetProb1 = TargetProb1 * scale;
|
||||||
|
var charDistUpperScaled1 = DiscreteChar.Uniform();
|
||||||
|
charDistUpperScaled1.SetToPartialUniformOf(charDistUpper, Math.Log(newTargetProb1));
|
||||||
|
var reWeightingTransducer =
|
||||||
|
StringTransducer.Replace(StringDistribution.Char(charDistUpper).GetWorkspaceOrPoint(), StringDistribution.Char(charDistUpperScaled1).GetWorkspaceOrPoint())
|
||||||
|
.Append(StringTransducer.Copy());
|
||||||
|
var reWeightedWordModel = StringDistribution.FromWorkspace(reWeightingTransducer.ProjectSource(wordModel.GetWorkspaceOrPoint()));
|
||||||
|
for (var i = 0; i < targetProbabilitiesPerLength.Length; i++)
|
||||||
|
{
|
||||||
|
var currentWord = Word.Substring(0, i + 1);
|
||||||
|
var probCurrentWord = Math.Exp(reWeightedWordModel.GetLogProb(currentWord));
|
||||||
|
Assert.Equal(scale * targetProbabilitiesPerLength[i], probCurrentWord, Eps);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#region Sampling tests
|
#region Sampling tests
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
|
|
Загрузка…
Ссылка в новой задаче