Simplify discrete char implementation (#361)

There are two changes:
1. (major) `LogProbabilityOverride` was removed from `DiscreteChar` and `ImmutableDiscreteChar`.
   This was unsound functionality that was unused and complicated the implementation of char distributions.
2. (minor) `ImmutableDiscreteChar.Multiply` now reuses existing immutable instances in more cases out of the box.
   This reduces GC pressure a little.
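For change 1, the replacement pattern appears in the updated word-model test further down: instead of overriding the log probability inside a character distribution, the non-normalized mass is expressed as an explicit transition weight. A minimal sketch of that pattern (the `using` directives and the 0.05 target probability are illustrative assumptions; the Builder calls mirror the updated test code):

```csharp
using System.Linq;
using Microsoft.ML.Probabilistic.Distributions;
using Microsoft.ML.Probabilistic.Distributions.Automata;

// Build a single-transition automaton whose total character mass is 0.05
// rather than 1, by scaling the transition weight instead of the distribution.
var workspace = new StringAutomaton.Builder();
var state = workspace.Start;

var upper = ImmutableDiscreteChar.Upper();
// Per-character probability of the partial-uniform distribution.
var perCharProb = upper.Ranges.First().Probability;
// Weight that makes the whole transition carry mass 0.05.
var weight = Weight.FromValue(0.05) * Weight.Inverse(perCharProb);
state = state.AddTransition(upper, weight);
state.SetEndWeight(Weight.One);

var model = new StringDistribution();
model.SetWeightFunction(workspace.GetAutomaton());
```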
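For change 2, multiplying by a point mass now returns one of the existing operands instead of rebuilding storage. A small sketch of the new behaviour (the `PointMass` factory name is an assumption, by analogy with `DiscreteChar.PointMass`):

```csharp
using Microsoft.ML.Probabilistic.Distributions;

// Assumed factory: a distribution concentrated on 'a'.
var point = ImmutableDiscreteChar.PointMass('a');
var lower = ImmutableDiscreteChar.Lower();

// 'a' has non-zero probability under Lower(), so the product is the
// point mass itself: the existing instance is reused, nothing is allocated.
var product1 = point.Multiply(lower);

// Previously only the point-mass-times-point-mass case was short-circuited;
// now a point mass on either side is reused.
var product2 = lower.Multiply(point);

// A zero product still throws: 'a' is outside Upper()'s support.
// ImmutableDiscreteChar.Upper().Multiply(point); // -> AllZeroException
```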
Ivan Korostelev 2021-09-14 21:53:05 +01:00 committed by GitHub
Parent abae518039
Commit 8370c9e8e7
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
5 changed files: 59 additions and 352 deletions


@@ -21,19 +21,6 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
public StringAutomaton()
{
}
- /// <summary>
- /// Whether there are log value overrides at the element level.
- /// </summary>
- public bool HasElementLogValueOverrides
- {
-     get
-     {
-         return this.States.transitions.Any(
-             trans => trans.ElementDistribution.HasValue &&
-                 trans.ElementDistribution.Value.HasLogProbabilityOverride);
-     }
- }
/// <summary>
/// Computes a set of outgoing transitions from a given state of the determinization result.


@@ -379,7 +379,6 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
// Populate the stack with start destination state
result.StartStateIndex = CreateDestState(mappingAutomaton.Start, srcAutomaton.Start);
var stringAutomaton = srcAutomaton as StringAutomaton;
- var sourceDistributionHasLogProbabilityOverrides = stringAutomaton?.HasElementLogValueOverrides ?? false;
while (stack.Count > 0)
{
@@ -420,25 +419,10 @@
continue;
}
- // In the special case of a log probability override in an ImmutableDiscreteChar element distribution,
- // we need to compensate for the fact that the distribution is not normalized.
- if (destElementDistribution.HasValue && sourceDistributionHasLogProbabilityOverrides)
- {
-     var discreteChar =
-         (ImmutableDiscreteChar)(IImmutableDistribution<char, ImmutableDiscreteChar>)srcTransition.ElementDistribution.Value;
-     if (discreteChar.HasLogProbabilityOverride)
-     {
-         var totalMass = discreteChar.Ranges.EnumerableSum(rng =>
-             rng.Probability.Value * (rng.EndExclusive - rng.StartInclusive));
-         projectionLogScale -= System.Math.Log(totalMass);
-     }
- }
- var destWeight =
-     sourceDistributionHasLogProbabilityOverrides && destElementDistribution.HasNoValue
-         ? Weight.One
-         : Weight.Product(mappingTransition.Weight, srcTransition.Weight,
-             Weight.FromLogValue(projectionLogScale));
+ var destWeight = Weight.Product(
+     mappingTransition.Weight,
+     srcTransition.Weight,
+     Weight.FromLogValue(projectionLogScale));
// We don't want an unnormalizable distribution to become normalizable due to a rounding error.
if (Math.Abs(destWeight.LogValue) < 1e-12)


@@ -147,15 +147,6 @@ namespace Microsoft.ML.Probabilistic.Distributions
/// </summary>
public bool IsWordChar => WrappedDistribution.IsWordChar;
- /// <inheritdoc cref="ImmutableDiscreteChar.IsBroad"/>
- public bool IsBroad => WrappedDistribution.IsBroad;
- /// <inheritdoc cref="ImmutableDiscreteChar.HasLogProbabilityOverride"/>
- public bool HasLogProbabilityOverride => WrappedDistribution.HasLogProbabilityOverride;
- /// <inheritdoc cref="ImmutableDiscreteChar.LogProbabilityOverride"/>
- public double? LogProbabilityOverride => WrappedDistribution.LogProbabilityOverride;
#endregion
/// <summary>
@@ -541,21 +532,6 @@
/// <inheritdoc cref="ImmutableDiscreteChar.GetProbs"/>
public PiecewiseVector GetProbs() => WrappedDistribution.GetProbs();
- /// <summary>
- /// Sets the distribution to be uniform over the support of a given distribution.
- /// </summary>
- /// <param name="distribution">The distribution whose support will be used to setup the current distribution.</param>
- /// <param name="logProbabilityOverride">An optional value to override for the log probability calculation
- /// against this distribution. If this is set, then the distribution will not be normalize;
- /// i.e. the probabilities will not sum to 1 over the support.</param>
- /// <remarks>Overriding the log probability calculation in this way is useful within the context of using <see cref="StringDistribution"/>
- /// to create more realistic language model priors. Distributions with this override are always uncached.
- /// </remarks>
- public void SetToPartialUniformOf(DiscreteChar distribution, double? logProbabilityOverride)
- {
-     WrappedDistribution = distribution.WrappedDistribution.CreatePartialUniform(logProbabilityOverride);
- }
#endregion
#region Serialization
@@ -705,22 +681,6 @@
/// </summary>
public bool IsWordChar => this.Data.IsWordChar;
- /// <summary>
- /// Gets a value indicating whether this distribution is broad - i.e. a general class of values.
- /// </summary>
- public bool IsBroad => this.Data.IsBroad;
- /// <summary>
- /// Gets a value indicating whether this distribution is partial uniform with a log probability override.
- /// </summary>
- public bool HasLogProbabilityOverride => this.Data.HasLogProbabilityOverride;
- /// <summary>
- /// Gets a value for the log probability override if it exists.
- /// </summary>
- public double? LogProbabilityOverride =>
-     this.HasLogProbabilityOverride ? this.Ranges.First().Probability.LogValue : (double?)null;
#endregion
#region Distribution properties
@@ -1039,9 +999,9 @@
/// <inheritdoc/>
public ImmutableDiscreteChar Multiply(ImmutableDiscreteChar other)
{
- if (IsPointMass && other.IsPointMass)
+ if (IsPointMass)
  {
-     if (Point != other.Point)
+     if (other.FindProb(Point) == Weight.Zero)
{
throw new AllZeroException("A character distribution that is zero everywhere has been produced.");
}
@@ -1049,6 +1009,16 @@
return this;
}
+ if (other.IsPointMass)
+ {
+     if (FindProb(other.Point) == Weight.Zero)
+     {
+         throw new AllZeroException("A character distribution that is zero everywhere has been produced.");
+     }
+     return other;
+ }
var builder = StorageBuilder.Create();
foreach (var pair in CharRangePair.IntersectRanges(this, other))
@@ -1057,27 +1027,7 @@
builder.AddRange(new CharRange(pair.StartInclusive, pair.EndExclusive, probProduct));
}
- double? logProbabilityOverride = null;
- var thisLogProbabilityOverride = LogProbabilityOverride;
- var otherLogProbabilityOverride = other.LogProbabilityOverride;
- if (thisLogProbabilityOverride.HasValue)
- {
-     if (otherLogProbabilityOverride.HasValue)
-     {
-         throw new ArgumentException("Only one distribution in a ImmutableDiscreteChar product may have a log probability override");
-     }
-     if (other.IsBroad)
-     {
-         logProbabilityOverride = thisLogProbabilityOverride;
-     }
- }
- else if (otherLogProbabilityOverride.HasValue && IsBroad)
- {
-     logProbabilityOverride = otherLogProbabilityOverride;
- }
- return new ImmutableDiscreteChar(builder.GetResult(logProbabilityOverride));
+ return new ImmutableDiscreteChar(builder.GetResult());
}
/// <summary>
@@ -1157,16 +1107,7 @@
}
/// <inheritdoc/>
- public ImmutableDiscreteChar CreatePartialUniform() => CreatePartialUniform(null);
- /// <inheritdoc cref="CreatePartialUniform()"/>
- /// <param name="logProbabilityOverride">An optional value to override for the log probability calculation
- /// against this distribution. If this is set, then the distribution will not be normalize;
- /// i.e. the probabilities will not sum to 1 over the support.</param>
- /// <remarks>Overriding the log probability calculation in this way is useful within the context of using <see cref="StringDistribution"/>
- /// to create more realistic language model priors. Distributions with this override are always uncached.
- /// </remarks>
- public ImmutableDiscreteChar CreatePartialUniform(double? logProbabilityOverride)
+ public ImmutableDiscreteChar CreatePartialUniform()
{
var builder = StorageBuilder.Create();
foreach (var range in Ranges)
@@ -1178,7 +1119,7 @@
range.Probability.IsZero ? Weight.Zero : Weight.One));
}
- return new ImmutableDiscreteChar(builder.GetResult(logProbabilityOverride));
+ return new ImmutableDiscreteChar(builder.GetResult());
}
/// <inheritdoc/>
@@ -1899,32 +1840,6 @@
private string regexRepresentation;
private string symbolRepresentation;
- /// <summary>
- /// Flags derived from ranges.
- /// </summary>
- /// <returns></returns>
- private readonly Flags flags;
- /// <summary>
- /// Whether this distribution has broad support. We want to be able to
- /// distinguish between specific distributions and general distributions.
- /// Use the number of non-zero digits as a threshold.
- /// </summary>
- /// <returns></returns>
- public bool IsBroad => (this.flags & Flags.IsBroad) != 0;
- /// <summary>
- /// Whether this distribution is partial uniform with a log probability override.
- /// </summary>
- /// <returns></returns>
- public bool HasLogProbabilityOverride => (this.flags & Flags.HasLogProbabilityOverride) != 0;
- [Flags]
- private enum Flags
- {
-     HasLogProbabilityOverride = 0x01,
-     IsBroad = 0x02
- }
#endregion
#region Constructor and factory methods
@@ -1942,20 +1857,6 @@
this.regexRepresentation = regexRepresentation;
this.symbolRepresentation = symbolRepresentation;
- var supportCount = this.Ranges.Where(range => !range.Probability.IsZero).Sum(range => range.EndExclusive - range.StartInclusive);
- var isBroad = supportCount >= 9; // Number of non-zero digits.
- var totalMass = this.Ranges.Sum(range =>
-     range.Probability.Value * (range.EndExclusive - range.StartInclusive));
- var isScaled = Math.Abs(totalMass - 1.0) > 1e-10;
- this.flags = 0;
- if (isBroad)
- {
-     flags |= Flags.IsBroad;
- }
- if (isScaled)
- {
-     flags |= Flags.HasLogProbabilityOverride;
- }
}
public static Storage CreateUncached(
@@ -2508,23 +2409,15 @@
/// <summary>
/// Normalizes probabilities in ranges and returns build Storage.
/// </summary>
- public Storage GetResult(double? maximumProbability = null)
+ public Storage GetResult()
{
this.MergeNeighboringRanges();
- NormalizeProbabilities(this.ranges, maximumProbability);
- return
-     maximumProbability.HasValue
-         ? Storage.CreateUncached(
-             this.ranges.ToReadOnlyArray(),
-             null,
-             this.charClasses,
-             this.regexRepresentation,
-             this.symbolRepresentation)
-         : Storage.Create(
-             this.ranges.ToReadOnlyArray(),
-             this.charClasses,
-             this.regexRepresentation,
-             this.symbolRepresentation);
+ NormalizeProbabilities(this.ranges);
+ return Storage.Create(
+     this.ranges.ToReadOnlyArray(),
+     this.charClasses,
+     this.regexRepresentation,
+     this.symbolRepresentation);
}
#endregion
@@ -2563,36 +2456,17 @@
/// Normalizes probabilities in ranges.
/// </summary>
/// <param name="ranges">The ranges.</param>
- /// <param name="logProbabilityOverride">Ignores the probabilities in the ranges and creates a non-normalized partial uniform distribution.</param>
- /// <exception cref="ArgumentException">Thrown if logProbabilityOverride has value corresponding to a non-probability.</exception>
- public static void NormalizeProbabilities(IList<CharRange> ranges, double? logProbabilityOverride = null)
+ public static void NormalizeProbabilities(IList<CharRange> ranges)
  {
-     if (logProbabilityOverride.HasValue)
+     var normalizer = ComputeInvNormalizer(ranges);
+     for (var i = 0; i < ranges.Count; ++i)
      {
-         var weight = Weight.FromLogValue(logProbabilityOverride.Value);
-         if (weight.IsZero || weight.Value > 1)
-         {
-             throw new ArgumentException("Invalid log probability override.");
-         }
+         var range = ranges[i];
+         var probability = range.Probability * normalizer;
-         for (var i = 0; i < ranges.Count; ++i)
-         {
-             var range = ranges[i];
-             ranges[i] = new CharRange(
-                 range.StartInclusive, range.EndExclusive, weight);
-         }
-     }
-     else
-     {
-         var normalizer = ComputeInvNormalizer(ranges);
-         for (var i = 0; i < ranges.Count; ++i)
-         {
-             var range = ranges[i];
-             var probability = range.Probability * normalizer;
-             ranges[i] = new CharRange(
-                 range.StartInclusive, range.EndExclusive, probability);
-         }
+         ranges[i] = new CharRange(
+             range.StartInclusive, range.EndExclusive, probability);
}
}


@@ -138,108 +138,6 @@ namespace Microsoft.ML.Probabilistic.Tests
}
}
- [Fact]
- public void PartialUniformWithLogProbabilityOverride()
- {
-     var dist = DiscreteChar.LetterOrDigit();
-     var probLetter = Math.Exp(dist.GetLogProb('j'));
-     var probNumber = Math.Exp(dist.GetLogProb('5'));
-     var logProbabilityOverride = Math.Log(0.7);
-     var scaledDist = DiscreteChar.Uniform();
-     scaledDist.SetToPartialUniformOf(dist, logProbabilityOverride);
-     var scaledLogProbLetter = scaledDist.GetLogProb('j');
-     var scaledLogProbNumber = scaledDist.GetLogProb('5');
-     Assert.Equal(scaledLogProbLetter, logProbabilityOverride, Eps);
-     Assert.Equal(scaledLogProbNumber, logProbabilityOverride, Eps);
-     // Check that cache has not been compromised.
-     Assert.Equal(probLetter, Math.Exp(dist.GetLogProb('j')), Eps);
-     Assert.Equal(probNumber, Math.Exp(dist.GetLogProb('5')), Eps);
-     // Check that an exception is thrown if a bad maximumProbability is passed down.
-     Xunit.Assert.Throws<ArgumentException>(() =>
-     {
-         var badDist = DiscreteChar.Uniform();
-         badDist.SetToPartialUniformOf(dist, Math.Log(1.2));
-     });
- }
- [Fact]
- public void BroadAndNarrow()
- {
-     var dist1 = DiscreteChar.Digit();
-     Xunit.Assert.True(dist1.IsBroad);
-     var dist2 = DiscreteChar.OneOf('1', '3', '5', '6');
-     Xunit.Assert.False(dist2.IsBroad);
- }
- [Fact]
- public void HasLogOverride()
- {
-     var dist1 = DiscreteChar.LetterOrDigit();
-     Xunit.Assert.False(dist1.HasLogProbabilityOverride);
-     dist1.SetToPartialUniformOf(dist1, Math.Log(0.9));
-     Xunit.Assert.True(dist1.HasLogProbabilityOverride);
- }
- [Fact]
- public void ProductWithLogOverrideBroad()
- {
-     for (var i = 0; i < 2; i++)
-     {
-         var dist1 = DiscreteChar.LetterOrDigit();
-         var dist2 = DiscreteChar.Digit();
-         var logOverrideProbability = Math.Log(0.9);
-         dist1.SetToPartialUniformOf(dist1, logOverrideProbability);
-         Xunit.Assert.True(dist1.HasLogProbabilityOverride);
-         Xunit.Assert.True(dist2.IsBroad);
-         if (i == 1)
-         {
-             Util.Swap(ref dist1, ref dist2);
-         }
-         var dist3 = DiscreteChar.Uniform();
-         dist3.SetToProduct(dist1, dist2);
-         Xunit.Assert.True(dist3.HasLogProbabilityOverride);
-         Assert.Equal(logOverrideProbability, dist3.GetLogProb('5'), Eps);
-         Xunit.Assert.True(double.IsNegativeInfinity(dist3.GetLogProb('a')));
-     }
- }
- [Fact]
- public void ProductWithLogOverrideNarrow()
- {
-     for (var i = 0; i < 2; i++)
-     {
-         var dist1 = DiscreteChar.LetterOrDigit();
-         var dist2 = DiscreteChar.OneOf('1', '3', '5', '6');
-         var logOverrideProbability = Math.Log(0.9);
-         dist1.SetToPartialUniformOf(dist1, logOverrideProbability);
-         Xunit.Assert.True(dist1.HasLogProbabilityOverride);
-         Xunit.Assert.False(dist2.IsBroad);
-         if (i == 1)
-         {
-             Util.Swap(ref dist1, ref dist2);
-         }
-         var dist3 = DiscreteChar.Uniform();
-         dist3.SetToProduct(dist1, dist2);
-         Xunit.Assert.False(dist3.HasLogProbabilityOverride);
-         Assert.Equal(Math.Log(0.25), dist3.GetLogProb('5'), Eps);
-         Xunit.Assert.True(double.IsNegativeInfinity(dist3.GetLogProb('a')));
-     }
- }
/// <summary>
/// Tests the support of a character distribution.
/// </summary>


@@ -567,8 +567,7 @@ namespace Microsoft.ML.Probabilistic.Tests
// significantly less than the probability of a shorter word.
// (2) The probability of a specific word conditioned on its length matches that of
// words in the target language.
- // We achieve this by putting non-normalized character distributions on the edges. The
- // StringDistribution is unaware that these are non-normalized.
+ // We achieve this by putting weights which do not sum to 1 on transitions.
+ // The StringDistribution itself is non-normalizable.
const double TargetProb1 = 0.05;
const double Ratio1 = 0.4;
@@ -592,62 +591,42 @@
var charDistUpper = ImmutableDiscreteChar.Upper();
var charDistLower = ImmutableDiscreteChar.Lower();
var charDistUpperNarrow = ImmutableDiscreteChar.OneOf('A', 'B');
var charDistLowerNarrow = ImmutableDiscreteChar.OneOf('a', 'b');
- var charDistUpperScaled = charDistUpper.CreatePartialUniform(Math.Log(TargetProb1));
- var charDistLowerScaled1 = charDistLower.CreatePartialUniform(Math.Log(Ratio1));
- var charDistLowerScaled2 = charDistLower.CreatePartialUniform(Math.Log(Ratio2));
- var charDistLowerScaled3 = charDistLower.CreatePartialUniform(Math.Log(Ratio3));
- var charDistLowerScaledEnd = charDistLower.CreatePartialUniform(Math.Log(Ratio4));
+ var workspace = new StringAutomaton.Builder();
+ var state = workspace.Start;
- var wordModel = StringDistribution.Concatenate(
-     new List<ImmutableDiscreteChar>
-     {
-         charDistUpperScaled,
-         charDistLowerScaled1,
-         charDistLowerScaled2,
-         charDistLowerScaled2,
-         charDistLowerScaled2,
-         charDistLowerScaled3,
-         charDistLowerScaled3,
-         charDistLowerScaled3,
-         charDistLowerScaledEnd
-     },
-     true,
-     true);
+ void AddCharToModel(ImmutableDiscreteChar c, double ratio)
+ {
+     var realCharProb = c.Ranges.First().Probability;
+     var weight = Weight.FromValue(ratio) * Weight.Inverse(realCharProb);
+     state = state.AddTransition(c, weight);
+     state.SetEndWeight(Weight.One);
+ }
+ AddCharToModel(charDistUpper, TargetProb1);
+ AddCharToModel(charDistLower, Ratio1);
+ AddCharToModel(charDistLower, Ratio2);
+ AddCharToModel(charDistLower, Ratio2);
+ AddCharToModel(charDistLower, Ratio2);
+ AddCharToModel(charDistLower, Ratio3);
+ AddCharToModel(charDistLower, Ratio3);
+ AddCharToModel(charDistLower, Ratio3);
+ AddCharToModel(charDistLower, Ratio4);
+ AddCharToModel(charDistLower, Ratio4);
+ state.AddTransition(charDistLower, Weight.One, state.Index);
+ var wordModel = new StringDistribution();
+ wordModel.SetWeightFunction(workspace.GetAutomaton());
const string Word = "Abcdefghij";
const double Eps = 1e-5;
var broadDist = StringDistribution.Char(charDistUpper);
var narrowDist = StringDistribution.Char(charDistUpperNarrow);
var narrowWord = "A";
var expectedProbForNarrow = 0.5;
for (var i = 0; i < targetProbabilitiesPerLength.Length; i++)
{
var currentWord = Word.Substring(0, i + 1);
var probCurrentWord = Math.Exp(wordModel.GetLogProb(currentWord));
Assert.Equal(targetProbabilitiesPerLength[i], probCurrentWord, Eps);
var logAvg = Math.Exp(wordModel.GetLogAverageOf(broadDist));
Assert.Equal(targetProbabilitiesPerLength[i], logAvg, Eps);
var prod = StringDistribution.Zero();
prod.SetToProduct(broadDist, wordModel);
- Xunit.Assert.True(prod.ToAutomaton().HasElementLogValueOverrides);
probCurrentWord = Math.Exp(prod.GetLogProb(currentWord));
Assert.Equal(targetProbabilitiesPerLength[i], probCurrentWord, Eps);
prod.SetToProduct(narrowDist, wordModel);
- Xunit.Assert.False(prod.ToAutomaton().HasElementLogValueOverrides);
var probNarrowWord = Math.Exp(prod.GetLogProb(narrowWord));
Assert.Equal(expectedProbForNarrow, probNarrowWord, Eps);
broadDist = broadDist.Append(charDistLower);
narrowDist = narrowDist.Append(charDistLowerNarrow);
narrowWord += "a";
expectedProbForNarrow *= 0.5;
}
// Copied model
@@ -659,21 +638,6 @@
var probCurrentWord = Math.Exp(copiedModel.GetLogProb(currentWord));
Assert.Equal(targetProbabilitiesPerLength[i], probCurrentWord, Eps);
}
- // Rescaled model
- var scale = 0.5;
- var newTargetProb1 = TargetProb1 * scale;
- var charDistUpperScaled1 = charDistUpper.CreatePartialUniform(Math.Log(newTargetProb1));
- var reWeightingTransducer =
-     StringTransducer.Replace(StringDistribution.Char(charDistUpper).ToAutomaton(), StringDistribution.Char(charDistUpperScaled1).ToAutomaton())
-         .Append(StringTransducer.Copy());
- var reWeightedWordModel = StringDistribution.FromWeightFunction(reWeightingTransducer.ProjectSource(wordModel.ToAutomaton()));
- for (var i = 0; i < targetProbabilitiesPerLength.Length; i++)
- {
-     var currentWord = Word.Substring(0, i + 1);
-     var probCurrentWord = Math.Exp(reWeightedWordModel.GetLogProb(currentWord));
-     Assert.Equal(scale * targetProbabilitiesPerLength[i], probCurrentWord, Eps);
- }
}
[Fact]