Automaton determinization improvements (#144)

* Determinization no longer changes the language of the automaton.
  Two new tests were added to check that it doesn't.
  Previously, all very low probability transitions (with probability less than e^-35) were removed.
  That was done for two reasons:
  - Some probabilities were represented in linear space, and e^-35 is about the lowest
    resolution you can get with regular doubles. (That is, if one probability is 1 and another
    is e^-35, then when you add them together, the second one is indistinguishable from zero.)
    This was fixed earlier, when discrete char probabilities were switched to a log-space representation.
  - Trying to determinize some non-determinizable automata led to an explosion of low-probability
    states and transitions, which resulted in very poor performance
    (e.g. the `AutomatonNormalizationPerformance3` test). Now a smarter strategy for detecting
    these non-determinizable automata is used: during traversal, all sets of states reachable from
    the root are remembered. If the traversal reaches the same set of states but with different
    weights, it stops immediately, because otherwise it would be caught in an infinite loop.
* `Equals()` and `GetHashCode()` for `WeightedStateSet` take into account only the high 32 bits of each weight.
  Coupled with the normalization of weights, this allows already-added states with very close weights
  to be reused (see the sketch after this list). This speeds up the "PropertyInferencePerformanceTest" 2.5x
  due to smaller intermediate automata.
* Weighted sets of size one are handled specially in `TryDeterminize` - they don't need to be
  determinized and can be copied into the result almost as is. (Unless they have non-deterministic
  transitions; a simple "has transitions with different destination states" heuristic is used to
  detect that and fall back to the slow general path.)
* The representation of `WeightedStateSet` was changed from an (int -> weight) dictionary to a
  sorted array of (int, weight) pairs. As an optimization, the common case of a single-element
  set does not allocate any arrays.
* Determinization code for `ListAutomaton` was removed, because it has never worked.
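
The fuzzy weight matching mentioned above can be illustrated with a short self-contained sketch.
This is an illustration only, not library code: the type `FuzzyWeightedState` and its members are
hypothetical, but the bit manipulation mirrors the idea of keeping only the high 32 bits of the
log-weight for equality and hashing.

```csharp
// Minimal sketch, assuming weights are stored as log-values (doubles).
using System;

struct FuzzyWeightedState : IEquatable<FuzzyWeightedState>
{
    public int Index { get; }          // state id in the source automaton
    public double LogWeight { get; }   // log of the accumulated state weight

    // Sign, exponent and the top 20 bits of the mantissa of the log-weight.
    // Log-weights that agree to roughly 1e-6 relative precision share these bits.
    public int WeightHighBits => (int)(BitConverter.DoubleToInt64Bits(LogWeight) >> 32);

    public FuzzyWeightedState(int index, double logWeight)
    {
        Index = index;
        LogWeight = logWeight;
    }

    public bool Equals(FuzzyWeightedState other) =>
        Index == other.Index && WeightHighBits == other.WeightHighBits;

    public override bool Equals(object obj) => obj is FuzzyWeightedState other && Equals(other);

    public override int GetHashCode() => Index ^ WeightHighBits;
}
```

With keys like this, a state set reached along two different paths with weights that differ only by
floating-point noise maps to the same dictionary entry, so the already-created result state is reused
instead of spawning a near-duplicate.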

@ -2,13 +2,12 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Diagnostics;
namespace Microsoft.ML.Probabilistic.Collections
{
using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.Runtime.Serialization;
using Microsoft.ML.Probabilistic.Serialization;

@ -500,11 +500,12 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// <summary>
/// Sets a new end weight for this state.
/// </summary>
public void SetEndWeight(Weight weight)
public StateBuilder SetEndWeight(Weight weight)
{
var state = this.builder.states[this.Index];
state.EndWeight = weight;
this.builder.states[this.Index] = state;
return this;
}
#region AddTransition variants

@ -2,6 +2,9 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Diagnostics;
using Microsoft.ML.Probabilistic.Collections;
namespace Microsoft.ML.Probabilistic.Distributions.Automata
{
using System;
@ -45,61 +48,51 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
return false;
}
// A weighted state set is a set of (stateId, weight) pairs, where state ids correspond to states of the original automaton.
// Such sets correspond to states of the resulting automaton.
var weightedStateSetQueue = new Queue<Determinization.WeightedStateSet>();
var weightedStateSetToNewState = new Dictionary<Determinization.WeightedStateSet, int>();
var builder = new Builder();
var startWeightedStateSet = new Determinization.WeightedStateSet { { this.Start.Index, Weight.One } };
weightedStateSetQueue.Enqueue(startWeightedStateSet);
weightedStateSetToNewState.Add(startWeightedStateSet, builder.StartStateIndex);
builder.Start.SetEndWeight(this.Start.EndWeight);
while (weightedStateSetQueue.Count > 0)
var weightedStateSetStack = new Stack<(bool enter, Determinization.WeightedStateSet set)>();
var enqueuedWeightedStateSetStack = new Stack<(bool enter, Determinization.WeightedStateSet set)>();
var weightedStateSetToNewState = new Dictionary<Determinization.WeightedStateSet, int>();
// This dictionary is used to track the sets currently on the path from the root. If we encounter
// a set of states that we have already seen on the current path from the root, but with different
// weights, that means we've found a non-converging loop - an infinite number of weighted sets
// would be generated if we continued traversal, and determinization would fail. For performance
// reasons we want to fail fast if such a loop is found.
var stateSetsInPath = new Dictionary<Determinization.WeightedStateSet, Determinization.WeightedStateSet>(
Determinization.WeightedStateSetOnlyStateComparer.Instance);
var startWeightedStateSet = new Determinization.WeightedStateSet(this.Start.Index);
weightedStateSetStack.Push((true, startWeightedStateSet));
weightedStateSetToNewState.Add(startWeightedStateSet, builder.StartStateIndex);
while (weightedStateSetStack.Count > 0)
{
// Take one unprocessed state of the resulting automaton
Determinization.WeightedStateSet currentWeightedStateSet = weightedStateSetQueue.Dequeue();
var currentStateIndex = weightedStateSetToNewState[currentWeightedStateSet];
var currentState = builder[currentStateIndex];
var (enter, currentWeightedStateSet) = weightedStateSetStack.Pop();
// Find out what transitions we should add for this state
var outgoingTransitionInfos = this.GetOutgoingTransitionsForDeterminization(currentWeightedStateSet);
if (enter)
{
if (currentWeightedStateSet.Count > 1)
{
// Only sets with more than 1 state can lead to infinite loops with different weights,
// because if there's only 1 state, its weight is always Weight.One.
if (!stateSetsInPath.ContainsKey(currentWeightedStateSet))
{
stateSetsInPath.Add(currentWeightedStateSet, currentWeightedStateSet);
}
// For each transition to add
foreach ((TElementDistribution, Weight, Determinization.WeightedStateSet) outgoingTransitionInfo in outgoingTransitionInfos)
{
TElementDistribution elementDistribution = outgoingTransitionInfo.Item1;
Weight weight = outgoingTransitionInfo.Item2;
Determinization.WeightedStateSet destWeightedStateSet = outgoingTransitionInfo.Item3;
weightedStateSetStack.Push((false, currentWeightedStateSet));
}
int destinationStateIndex;
if (!weightedStateSetToNewState.TryGetValue(destWeightedStateSet, out destinationStateIndex))
if (!EnqueueOutgoingTransitions(currentWeightedStateSet))
{
if (builder.StatesCount == maxStatesBeforeStop)
{
// Too many states, determinization attempt failed
return false;
}
// Add new state to the result
var destinationState = builder.AddState();
weightedStateSetToNewState.Add(destWeightedStateSet, destinationState.Index);
weightedStateSetQueue.Enqueue(destWeightedStateSet);
// Compute its ending weight
destinationState.SetEndWeight(Weight.Zero);
foreach (KeyValuePair<int, Weight> stateIdWithWeight in destWeightedStateSet)
}
else
{
var addedWeight = stateIdWithWeight.Value * this.States[stateIdWithWeight.Key].EndWeight;
destinationState.SetEndWeight(destinationState.EndWeight + addedWeight);
}
destinationStateIndex = destinationState.Index;
}
// Add transition to the destination state
currentState.AddTransition(elementDistribution, weight, destinationStateIndex);
stateSetsInPath.Remove(currentWeightedStateSet);
}
}
@ -111,6 +104,137 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
this.LogValueOverride = this.LogValueOverride;
return true;
bool EnqueueOutgoingTransitions(Determinization.WeightedStateSet currentWeightedStateSet)
{
var currentStateIndex = weightedStateSetToNewState[currentWeightedStateSet];
var currentState = builder[currentStateIndex];
// Common special case: definitely deterministic transitions from a single state.
// In this case no complicated determinization procedure is needed.
if (currentWeightedStateSet.Count == 1 &&
AllDestinationsAreSame(currentWeightedStateSet[0].Index))
{
Debug.Assert(currentWeightedStateSet[0].Weight == Weight.One);
var sourceState = this.States[currentWeightedStateSet[0].Index];
foreach (var transition in sourceState.Transitions)
{
var destinationStates = new Determinization.WeightedStateSet(transition.DestinationStateIndex);
var outgoingTransitionInfo = new Determinization.OutgoingTransition(
transition.ElementDistribution.Value, transition.Weight, destinationStates);
if (!TryAddTransition(enqueuedWeightedStateSetStack, outgoingTransitionInfo, currentState))
{
return false;
}
}
}
else
{
// Find out what transitions we should add for this state
var outgoingTransitions =
this.GetOutgoingTransitionsForDeterminization(currentWeightedStateSet);
foreach (var outgoingTransition in outgoingTransitions)
{
if (!TryAddTransition(enqueuedWeightedStateSetStack, outgoingTransition, currentState))
{
return false;
}
}
}
while (enqueuedWeightedStateSetStack.Count > 0)
{
weightedStateSetStack.Push(enqueuedWeightedStateSetStack.Pop());
}
return true;
}
// Checks that all transitions from the state end up in the same destination. This is used
// as a very fast "is deterministic" check that doesn't care about distributions.
// A state can have deterministic transitions with different destinations. That case will be
// handled by the slow path.
bool AllDestinationsAreSame(int stateIndex)
{
var transitions = this.States[stateIndex].Transitions;
if (transitions.Count <= 1)
{
return true;
}
var destination = transitions[0].DestinationStateIndex;
for (var i = 1; i < transitions.Count; ++i)
{
if (transitions[i].DestinationStateIndex != destination)
{
return false;
}
}
return true;
}
// Adds a transition from currentState into the state corresponding to the weighted state set from
// outgoingTransitionInfo. If that state does not exist yet, it is created and put onto the stack
// for further processing. This function returns false if determinization has failed.
// That can happen for 2 reasons:
// - Too many states were created and it's not feasible to continue trying to determinize
//   the automaton further.
// - An infinite loop with non-converging weights was found. It leads to an infinite number of
//   states, so determinization is aborted early.
bool TryAddTransition(
Stack<(bool enter, Determinization.WeightedStateSet set)> destinationStack,
Determinization.OutgoingTransition transition,
Builder.StateBuilder currentState)
{
var destinations = transition.Destinations;
if (!weightedStateSetToNewState.TryGetValue(destinations, out var destinationStateIndex))
{
if (builder.StatesCount == maxStatesBeforeStop)
{
// Too many states, determinization attempt failed
return false;
}
var visitedWeightedStateSet = default(Determinization.WeightedStateSet);
var sameSetVisited =
destinations.Count > 1 &&
stateSetsInPath.TryGetValue(destinations, out visitedWeightedStateSet);
if (sameSetVisited && !destinations.Equals(visitedWeightedStateSet))
{
// We arrived at the same state set as before, but with different weights.
// This is an infinite non-converging loop. Determinization has failed.
return false;
}
// Add new state to the result
var destinationState = builder.AddState();
weightedStateSetToNewState.Add(destinations, destinationState.Index);
destinationStack.Push((true, destinations));
if (destinations.Count > 1 && !sameSetVisited)
{
destinationStack.Push((false, destinations));
}
// Compute its ending weight
destinationState.SetEndWeight(Weight.Zero);
for (var i = 0; i < destinations.Count; ++i)
{
var weightedState = destinations[i];
var addedWeight = weightedState.Weight * this.States[weightedState.Index].EndWeight;
destinationState.SetEndWeight(destinationState.EndWeight + addedWeight);
}
destinationStateIndex = destinationState.Index;
}
// Add transition to the destination state
currentState.AddTransition(transition.ElementDistribution, transition.Weight, destinationStateIndex);
return true;
}
}
/// <summary>
@ -124,82 +248,121 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// The first two elements of a tuple define the element distribution and the weight of a transition.
/// The third element defines the outgoing state.
/// </returns>
protected abstract List<(TElementDistribution, Weight, Determinization.WeightedStateSet)> GetOutgoingTransitionsForDeterminization(
protected abstract IEnumerable<Determinization.OutgoingTransition> GetOutgoingTransitionsForDeterminization(
Determinization.WeightedStateSet sourceState);
/// <summary>
/// Groups together helper classes used for automata determinization.
/// </summary>
protected static class Determinization
{
public struct OutgoingTransition
{
public TElementDistribution ElementDistribution { get; }
public Weight Weight { get; }
public WeightedStateSet Destinations { get; }
public OutgoingTransition(
TElementDistribution elementDistribution, Weight weight, WeightedStateSet destinations)
{
this.ElementDistribution = elementDistribution;
this.Weight = weight;
this.Destinations = destinations;
}
}
/// <summary>
/// Represents a state of the original automaton together with its weight.
/// </summary>
public struct WeightedState : IComparable, IComparable<WeightedState>
{
/// <summary>
/// Index of the state.
/// </summary>
public int Index { get; }
/// <summary>
/// High bits of the state weight. Only these bits are used when comparing weighted states
/// for equality or when calculating hash code.
/// </summary>
/// <remarks>
/// Using the high bits for the hash code allows "fuzzy" matching of WeightedStateSets.
/// Sometimes there's more than one way to get to the same weighted state set in an automaton,
/// but due to floating point arithmetic the calculated weights are slightly off. Using
/// only the high 32 bits for comparison means that 20 bits of the mantissa are used, so the
/// difference between weights (in log space) is no more than ~10^-6, which is sufficient
/// precision for all practical purposes.
/// </remarks>
public int WeightHighBits { get; }
public Weight Weight { get; }
public WeightedState(int index, Weight weight)
{
this.Index = index;
this.WeightHighBits = (int)(BitConverter.DoubleToInt64Bits(weight.LogValue) >> 32);
this.Weight = weight;
}
public int CompareTo(object obj)
{
return obj is WeightedState that
? this.CompareTo(that)
: throw new InvalidOperationException(
"WeightedState can be compared only to another WeightedState");
}
public int CompareTo(WeightedState that) => Index.CompareTo(that.Index);
public override int GetHashCode() => (Index ^ WeightHighBits).GetHashCode();
}
/// <summary>
/// Represents a state of the resulting automaton in the power set construction.
/// It is essentially a set of (stateId, weight) pairs of the source automaton, where each state id is unique.
/// Supports a quick lookup of the weight by state id.
/// </summary>
public class WeightedStateSet : IEnumerable<KeyValuePair<int, Weight>>
public struct WeightedStateSet
{
/// <summary>
/// A mapping from state ids to weights.
/// A mapping from state ids to weights. This array is sorted by state Id.
/// </summary>
private readonly Dictionary<int, Weight> stateIdToWeight;
private readonly ReadOnlyArray<WeightedState> weightedStates;
/// <summary>
/// Initializes a new instance of the <see cref="WeightedStateSet"/> class.
/// </summary>
public WeightedStateSet() =>
this.stateIdToWeight = new Dictionary<int, Weight>();
private readonly int singleStateIndex;
/// <summary>
/// Initializes a new instance of the <see cref="WeightedStateSet"/> class.
/// </summary>
/// <param name="stateIdToWeight">A collection of (stateId, weight) pairs.
/// </param>
public WeightedStateSet(IEnumerable<KeyValuePair<int, Weight>> stateIdToWeight) =>
this.stateIdToWeight = stateIdToWeight.ToDictionary(kv => kv.Key, kv => kv.Value);
/// <summary>
/// Gets or sets the weight for a given state id.
/// </summary>
/// <param name="stateId">The state id.</param>
/// <returns>The weight.</returns>
public Weight this[int stateId]
public WeightedStateSet(int stateIndex)
{
get => this.stateIdToWeight[stateId];
set => this.stateIdToWeight[stateId] = value;
this.weightedStates = null;
this.singleStateIndex = stateIndex;
}
/// <summary>
/// Adds a given state id and a weight to the set.
/// </summary>
/// <param name="stateId">The state id.</param>
/// <param name="weight">The weight.</param>
public void Add(int stateId, Weight weight) =>
this.stateIdToWeight.Add(stateId, weight);
public WeightedStateSet(ReadOnlyArray<WeightedState> weightedStates)
{
Debug.Assert(weightedStates.Count > 0);
Debug.Assert(IsSorted(weightedStates));
if (weightedStates.Count == 1)
{
Debug.Assert(weightedStates[0].Weight == Weight.One);
this.weightedStates = null;
this.singleStateIndex = weightedStates[0].Index;
}
else
{
this.weightedStates = weightedStates;
this.singleStateIndex = 0; // <- value doesn't matter, but silences the compiler
}
}
/// <summary>
/// Attempts to retrieve the weight corresponding to a given state id from the set.
/// </summary>
/// <param name="stateId">The state id.</param>
/// <param name="weight">When the method returns, contains the retrieved weight.</param>
/// <returns>
/// <see langword="true"/> if the given state id was present in the set,
/// <see langword="false"/> otherwise.
/// </returns>
public bool TryGetWeight(int stateId, out Weight weight) =>
this.stateIdToWeight.TryGetValue(stateId, out weight);
public int Count =>
this.weightedStates.IsNull
? 1
: this.weightedStates.Count;
/// <summary>
/// Checks whether the state with a given id is present in the set.
/// </summary>
/// <param name="stateId">The state id,</param>
/// <returns>
/// <see langword="true"/> if the given state id was present in the set,
/// <see langword="false"/> otherwise.
/// </returns>
public bool ContainsState(int stateId) =>
this.stateIdToWeight.ContainsKey(stateId);
public WeightedState this[int index] =>
this.weightedStates.IsNull
? new WeightedState(this.singleStateIndex, Weight.One)
: this.weightedStates[index];
/// <summary>
/// Checks whether this object is equal to a given one.
@ -209,25 +372,28 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// <see langword="true"/> if the objects are equal,
/// <see langword="false"/> otherwise.
/// </returns>
public override bool Equals(object obj)
public override bool Equals(object obj) => obj is WeightedStateSet that && this.Equals(that);
/// <summary>
/// Checks whether this object is equal to a given one.
/// </summary>
/// <param name="that">The object to compare this object with.</param>
/// <returns>
/// <see langword="true"/> if the objects are equal,
/// <see langword="false"/> otherwise.
/// </returns>
public bool Equals(WeightedStateSet that)
{
if (obj == null || obj.GetType() != typeof(WeightedStateSet))
if (this.Count != that.Count)
{
return false;
}
var other = (WeightedStateSet)obj;
if (this.stateIdToWeight.Count != other.stateIdToWeight.Count)
for (var i = 0; i < this.Count; ++i)
{
return false;
}
foreach (KeyValuePair<int, Weight> pair in this.stateIdToWeight)
{
// TODO: Should we allow for some tolerance? But what about hashing then?
Weight otherWeight;
if (!other.stateIdToWeight.TryGetValue(pair.Key, out otherWeight) || otherWeight != pair.Value)
var state1 = this[i];
var state2 = that[i];
if (state1.Index != state2.Index || state1.WeightHighBits != state2.WeightHighBits)
{
return false;
}
@ -240,17 +406,13 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// Computes the hash code of this instance.
/// </summary>
/// <returns>The computed hash code.</returns>
/// <remarks>Only state ids and the high bits of the weights contribute to the hash code.</remarks>
public override int GetHashCode()
{
int result = 0;
foreach (KeyValuePair<int, Weight> pair in this.stateIdToWeight)
var result = this[0].GetHashCode();
for (var i = 1; i < this.Count; ++i)
{
int pairHash = Hash.Start;
pairHash = Hash.Combine(pairHash, pair.Key.GetHashCode());
pairHash = Hash.Combine(pairHash, pair.Value.GetHashCode());
// Use commutative hashing combination because dictionaries are not ordered
result ^= pairHash;
result = Hash.Combine(result, this[i].GetHashCode());
}
return result;
@ -260,38 +422,128 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// Returns a string representation of the instance.
/// </summary>
/// <returns>A string representation of the instance.</returns>
public override string ToString()
{
StringBuilder builder = new StringBuilder();
foreach (var kvp in this.stateIdToWeight)
{
builder.AppendLine(kvp.ToString());
}
return builder.ToString();
}
#region IEnumerable implementation
public override string ToString() => string.Join(", ", weightedStates);
/// <summary>
/// Gets the enumerator.
/// Turns weighted state set into an array. This is convenient for writing LINQ queries
/// in tests.
/// </summary>
/// <returns>
/// The enumerator.
/// </returns>
public IEnumerator<KeyValuePair<int, Weight>> GetEnumerator() =>
this.stateIdToWeight.GetEnumerator();
public WeightedState[] ToArray()
{
var result = new WeightedState[this.Count];
for (var i = 0; i < this.Count; ++i)
{
result[i] = this[i];
}
return result;
}
/// <summary>
/// Gets the enumerator.
/// Checks whether the states array is sorted in ascending order by Index.
/// </summary>
/// <returns>
/// The enumerator.
/// </returns>
IEnumerator IEnumerable.GetEnumerator() =>
this.GetEnumerator();
private static bool IsSorted(ReadOnlyArray<WeightedState> array)
{
for (var i = 1; i < array.Count; ++i)
{
if (array[i].Index <= array[i - 1].Index)
{
return false;
}
}
#endregion
return true;
}
}
/// <summary>
/// Builder for weighted sets.
/// </summary>
public struct WeightedStateSetBuilder
{
private List<WeightedState> weightedStates;
public static WeightedStateSetBuilder Create() =>
new WeightedStateSetBuilder()
{
weightedStates = new List<WeightedState>(1),
};
public void Add(int index, Weight weight) =>
this.weightedStates.Add(new WeightedState(index, weight));
public void Reset() => this.weightedStates.Clear();
public (WeightedStateSet, Weight) Get()
{
Debug.Assert(this.weightedStates.Count > 0);
var sortedStates = this.weightedStates.ToArray();
if (sortedStates.Length == 1)
{
var state = sortedStates[0];
sortedStates[0] = new WeightedState(state.Index, Weight.One);
return (new WeightedStateSet(sortedStates), state.Weight);
}
else
{
Array.Sort(sortedStates);
var maxWeight = sortedStates[0].Weight;
for (var i = 1; i < sortedStates.Length; ++i)
{
if (sortedStates[i].Weight > maxWeight)
{
maxWeight = sortedStates[i].Weight;
}
}
var normalizer = Weight.Inverse(maxWeight);
for (var i = 0; i < sortedStates.Length; ++i)
{
var state = sortedStates[i];
sortedStates[i] = new WeightedState(state.Index, state.Weight * normalizer);
}
return (new WeightedStateSet(sortedStates), maxWeight);
}
}
}
public class WeightedStateSetOnlyStateComparer : IEqualityComparer<WeightedStateSet>
{
public static readonly WeightedStateSetOnlyStateComparer Instance =
new WeightedStateSetOnlyStateComparer();
public bool Equals(WeightedStateSet x, WeightedStateSet y)
{
if (x.Count != y.Count)
{
return false;
}
for (var i = 0; i < x.Count; ++i)
{
if (x[i].Index != y[i].Index)
{
return false;
}
}
return true;
}
public int GetHashCode(WeightedStateSet set)
{
var result = set[0].Index.GetHashCode();
for (var i = 1; i < set.Count; ++i)
{
result = Hash.Combine(result, set[i].Index.GetHashCode());
}
return result;
}
}
}
}

@ -53,7 +53,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// The first two elements of a tuple define the element distribution and the weight of a transition.
/// The third element defines the outgoing state.
/// </returns>
protected override List<(TElementDistribution, Weight, Determinization.WeightedStateSet)> GetOutgoingTransitionsForDeterminization(
protected override IEnumerable<Determinization.OutgoingTransition> GetOutgoingTransitionsForDeterminization(
Determinization.WeightedStateSet sourceState)
{
throw new NotImplementedException("Determinization is not yet supported for this type of automata.");
@ -84,106 +84,10 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// The first two elements of a tuple define the element distribution and the weight of a transition.
/// The third element defines the outgoing state.
/// </returns>
protected override List<(TElementDistribution, Weight, Determinization.WeightedStateSet)> GetOutgoingTransitionsForDeterminization(
protected override IEnumerable<Determinization.OutgoingTransition> GetOutgoingTransitionsForDeterminization(
Determinization.WeightedStateSet sourceState)
{
throw new NotImplementedException();
//// Build a list of elements, with probabilities
//var elementLists = new Dictionary<TElement, List<TransitionElement>>();
//var uniformList = new List<TransitionElement>();
//foreach (KeyValuePair<int, Weight> stateIdWeight in sourceState)
//{
// var state = this.States[stateIdWeight.Key];
// for (int i = 0; i < state.TransitionCount; ++i)
// {
// AddTransitionElements(state.GetTransition(i), stateIdWeight.Value, elementLists, uniformList);
// }
//}
//// Produce an outgoing transition for each unique subset of overlapping segments
//var results = new List<Tuple<TElementDistribution, Weight, Determinization.WeightedStateSet>>();
//foreach (var kvp in elementLists)
//{
// AddResult(results, kvp.Value);
//}
//AddResult(results, uniformList);
//return results;
}
private static void AddResult(List<Tuple<TElementDistribution, Weight, Determinization.WeightedStateSet>> results,
List<TransitionElement> transitionElements)
{
if (transitionElements.Count == 0) return;
const double LogEps = -30; // Don't add transitions with log-weight less than this as they have been produced by numerical inaccuracies
var elementStatesWeights = new Determinization.WeightedStateSet();
var elementStateWeightSum = Weight.Zero;
foreach (var element in transitionElements)
{
if (!elementStatesWeights.TryGetWeight(element.destIndex, out var prevStateWeight))
{
prevStateWeight = Weight.Zero;
}
elementStatesWeights[element.destIndex] = prevStateWeight + element.weight;
elementStateWeightSum += element.weight;
}
var destinationState = new Determinization.WeightedStateSet();
foreach (KeyValuePair<int, Weight> stateIdWithWeight in elementStatesWeights)
{
if (stateIdWithWeight.Value.LogValue > LogEps)
{
Weight stateWeight = stateIdWithWeight.Value * Weight.Inverse(elementStateWeightSum);
destinationState.Add(stateIdWithWeight.Key, stateWeight);
}
}
Weight transitionWeight = elementStateWeightSum;
results.Add(Tuple.Create(transitionElements[0].distribution, transitionWeight, destinationState));
}
/// <summary>
/// Given a transition and the residual weight of its source state, adds weighted non-zero probability elements
/// associated with the transition to the list.
/// </summary>
/// <param name="transition">The transition.</param>
/// <param name="sourceStateResidualWeight">The logarithm of the residual weight of the source state of the transition.</param>
/// <param name="elements">The list for storing transition elements.</param>
/// <param name="uniformList">The list for storing transition elements for uniform.</param>
private static void AddTransitionElements(
Transition transition, Weight sourceStateResidualWeight,
Dictionary<TElement, List<TransitionElement>> elements, List<TransitionElement> uniformList)
{
var dist = transition.ElementDistribution.Value;
Weight weightBase = transition.Weight * sourceStateResidualWeight;
if (dist.IsPointMass)
{
var pt = dist.Point;
// todo: enumerate distribution
if (!elements.ContainsKey(pt)) elements[pt] = new List<TransitionElement>();
elements[pt].Add(new TransitionElement(transition.DestinationStateIndex, weightBase, dist));
}
else
{
uniformList.Add(new TransitionElement(transition.DestinationStateIndex, weightBase, dist));
}
}
private class TransitionElement
{
internal int destIndex;
internal Weight weight;
internal TElementDistribution distribution;
internal TransitionElement(int destIndex, Weight weight, TElementDistribution distribution)
{
this.destIndex = destIndex;
this.distribution = distribution;
this.weight = weight;
}
}
}
}

@ -27,78 +27,62 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// <summary>
/// Computes a set of outgoing transitions from a given state of the determinization result.
/// </summary>
/// <param name="sourceState">The source state of the determinized automaton represented as
/// <param name="sourceStateSet">The source state of the determinized automaton represented as
/// a set of (stateId, weight) pairs, where state ids correspond to states of the original automaton.</param>
/// <returns>
/// A collection of (element distribution, weight, weighted state set) triples corresponding to outgoing transitions from <paramref name="sourceState"/>.
/// A collection of (element distribution, weight, weighted state set) triples corresponding to outgoing
/// transitions from <paramref name="sourceStateSet"/>.
/// The first two elements of a tuple define the element distribution and the weight of a transition.
/// The third element defines the outgoing state.
/// </returns>
protected override List<(DiscreteChar, Weight, Determinization.WeightedStateSet)> GetOutgoingTransitionsForDeterminization(
Determinization.WeightedStateSet sourceState)
protected override IEnumerable<Determinization.OutgoingTransition> GetOutgoingTransitionsForDeterminization(
Determinization.WeightedStateSet sourceStateSet)
{
const double LogEps = -35; // Don't add transitions with log-weight less than this as they have been produced by numerical inaccuracies
// Build a list of numbered non-zero probability character segment bounds (they are numbered here due to perf. reasons)
var segmentBounds = new List<ValueTuple<int, TransitionCharSegmentBound>>();
int transitionsProcessed = 0;
foreach (KeyValuePair<int, Weight> stateIdWeight in sourceState)
var segmentBounds = new List<TransitionCharSegmentBound>();
for (var i = 0; i < sourceStateSet.Count; ++i)
{
var state = this.States[stateIdWeight.Key];
var sourceState = sourceStateSet[i];
var state = this.States[sourceState.Index];
foreach (var transition in state.Transitions)
{
AddTransitionCharSegmentBounds(transition, stateIdWeight.Value, segmentBounds);
AddTransitionCharSegmentBounds(transition, sourceState.Weight, segmentBounds);
}
}
transitionsProcessed += state.Transitions.Count;
}
// Sort segment bounds left-to-right, start-to-end
var sortedIndexedSegmentBounds = segmentBounds.ToArray();
if (transitionsProcessed > 1)
{
Array.Sort(sortedIndexedSegmentBounds, CompareSegmentBounds);
int CompareSegmentBounds((int, TransitionCharSegmentBound) a, (int, TransitionCharSegmentBound) b) =>
a.Item2.CompareTo(b.Item2);
}
segmentBounds.Sort();
// Produce an outgoing transition for each unique subset of overlapping segments
var result = new List<(DiscreteChar, Weight, Determinization.WeightedStateSet)>();
Weight currentSegmentStateWeightSum = Weight.Zero;
var currentSegmentTotal = WeightSum.Zero();
var currentSegmentStateWeights = new Dictionary<int, Weight>();
foreach (var sb in segmentBounds)
var currentSegmentStateWeights = new Dictionary<int, WeightSum>();
var currentSegmentStart = (int)char.MinValue;
var destinationStateSetBuilder = Determinization.WeightedStateSetBuilder.Create();
foreach (var segmentBound in segmentBounds)
{
currentSegmentStateWeights[sb.Item2.DestinationStateId] = Weight.Zero;
}
var activeSegments = new HashSet<TransitionCharSegmentBound>();
int currentSegmentStart = char.MinValue;
foreach (var tup in sortedIndexedSegmentBounds)
{
TransitionCharSegmentBound segmentBound = tup.Item2;
if (currentSegmentStateWeightSum.LogValue > LogEps && currentSegmentStart < segmentBound.Bound)
if (currentSegmentTotal.Count != 0 && currentSegmentStart < segmentBound.Bound)
{
// Flush previous segment
char segmentEnd = (char)(segmentBound.Bound - 1);
int segmentLength = segmentEnd - currentSegmentStart + 1;
DiscreteChar elementDist = DiscreteChar.InRange((char)currentSegmentStart, segmentEnd);
var segmentEnd = (char)(segmentBound.Bound - 1);
var segmentLength = segmentEnd - currentSegmentStart + 1;
var elementDist = DiscreteChar.InRange((char)currentSegmentStart, segmentEnd);
var invTotalWeight = Weight.Inverse(currentSegmentTotal.Sum);
var destinationState = new Determinization.WeightedStateSet();
foreach (KeyValuePair<int, Weight> stateIdWithWeight in currentSegmentStateWeights)
destinationStateSetBuilder.Reset();
foreach (var stateIdWithWeight in currentSegmentStateWeights)
{
if (stateIdWithWeight.Value.LogValue > LogEps)
{
Weight stateWeight = stateIdWithWeight.Value * Weight.Inverse(currentSegmentStateWeightSum);
destinationState.Add(stateIdWithWeight.Key, stateWeight);
}
var stateWeight = stateIdWithWeight.Value.Sum * invTotalWeight;
destinationStateSetBuilder.Add(stateIdWithWeight.Key, stateWeight);
}
Weight transitionWeight = Weight.FromValue(segmentLength) * currentSegmentStateWeightSum;
result.Add((elementDist, transitionWeight, destinationState));
var (destinationStateSet, destinationStateSetWeight) = destinationStateSetBuilder.Get();
var transitionWeight = Weight.Product(
Weight.FromValue(segmentLength),
currentSegmentTotal.Sum,
destinationStateSetWeight);
yield return new Determinization.OutgoingTransition(
elementDist, transitionWeight, destinationStateSet);
}
// Update current segment
@ -106,39 +90,35 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
if (segmentBound.IsStart)
{
activeSegments.Add(segmentBound);
currentSegmentStateWeightSum += segmentBound.Weight;
currentSegmentStateWeights[segmentBound.DestinationStateId] += segmentBound.Weight;
currentSegmentTotal += segmentBound.Weight;
if (currentSegmentStateWeights.TryGetValue(segmentBound.DestinationStateId, out var stateWeight))
{
currentSegmentStateWeights[segmentBound.DestinationStateId] =
stateWeight + segmentBound.Weight;
}
else
{
currentSegmentStateWeights[segmentBound.DestinationStateId] = new WeightSum(segmentBound.Weight);
}
}
else
{
Debug.Assert(currentSegmentStateWeights.ContainsKey(segmentBound.DestinationStateId), "We shouldn't exit a state we didn't enter.");
activeSegments.Remove(segmentBounds[tup.Item1 - 1].Item2); // End follows start in original.
if (double.IsInfinity(segmentBound.Weight.Value))
Debug.Assert(!segmentBound.Weight.IsInfinity);
currentSegmentTotal -= segmentBound.Weight;
var prevStateWeight = currentSegmentStateWeights[segmentBound.DestinationStateId];
var newStateWeight = prevStateWeight - segmentBound.Weight;
if (newStateWeight.Count == 0)
{
// Cannot subtract because of the infinities involved.
currentSegmentStateWeightSum =
activeSegments
.Select(sb => sb.Weight)
.Aggregate(Weight.Zero, Weight.Sum);
currentSegmentStateWeights[segmentBound.DestinationStateId] =
activeSegments
.Where(sb => sb.DestinationStateId == segmentBound.DestinationStateId)
.Select(sb => sb.Weight)
.Aggregate(Weight.Zero, Weight.Sum);
currentSegmentStateWeights.Remove(segmentBound.DestinationStateId);
}
else
{
currentSegmentStateWeightSum = activeSegments.Count == 0 ? Weight.Zero : Weight.AbsoluteDifference(currentSegmentStateWeightSum, segmentBound.Weight);
Weight prevStateWeight = currentSegmentStateWeights[segmentBound.DestinationStateId];
currentSegmentStateWeights[segmentBound.DestinationStateId] = Weight.AbsoluteDifference(
prevStateWeight, segmentBound.Weight);
currentSegmentStateWeights[segmentBound.DestinationStateId] = newStateWeight;
}
}
}
return result;
}
/// <summary>
@ -149,41 +129,32 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// <param name="sourceStateResidualWeight">The logarithm of the residual weight of the source state of the transition.</param>
/// <param name="bounds">The list for storing numbered segment bounds.</param>
private static void AddTransitionCharSegmentBounds(
Transition transition, Weight sourceStateResidualWeight, List<ValueTuple<int, TransitionCharSegmentBound>> bounds)
Transition transition, Weight sourceStateResidualWeight, List<TransitionCharSegmentBound> bounds)
{
var distribution = transition.ElementDistribution.Value;
var ranges = distribution.Ranges;
int commonValueStart = char.MinValue;
Weight commonValue = distribution.ProbabilityOutsideRanges;
Weight weightBase = transition.Weight * sourceStateResidualWeight;
TransitionCharSegmentBound newSegmentBound;
var commonValueStart = (int)char.MinValue;
var commonValue = distribution.ProbabilityOutsideRanges;
var weightBase = transition.Weight * sourceStateResidualWeight;
////if (double.IsInfinity(weightBase.Value))
////{
//// Console.WriteLine("Weight base infinity");
////}
void AddEndPoints(int start, int end, int destinationIndex, Weight weight)
{
bounds.Add(new TransitionCharSegmentBound(start, destinationIndex, weight * weightBase, true));
bounds.Add(new TransitionCharSegmentBound(end, destinationIndex, weight * weightBase, false));
}
foreach (var range in ranges)
{
if (range.StartInclusive > commonValueStart && !commonValue.IsZero)
{
// Add endpoints for the common value
Weight segmentWeight = commonValue * weightBase;
newSegmentBound = new TransitionCharSegmentBound(commonValueStart, transition.DestinationStateIndex, segmentWeight, true);
bounds.Add(new ValueTuple<int, TransitionCharSegmentBound>(bounds.Count, newSegmentBound));
newSegmentBound = new TransitionCharSegmentBound(range.StartInclusive, transition.DestinationStateIndex, segmentWeight, false);
bounds.Add(new ValueTuple<int, TransitionCharSegmentBound>(bounds.Count, newSegmentBound));
AddEndPoints(commonValueStart, range.StartInclusive, transition.DestinationStateIndex, commonValue);
}
// Add segment endpoints
Weight pieceValue = range.Probability;
var pieceValue = range.Probability;
if (!pieceValue.IsZero)
{
Weight segmentWeight = pieceValue * weightBase;
newSegmentBound = new TransitionCharSegmentBound(range.StartInclusive, transition.DestinationStateIndex, segmentWeight, true);
bounds.Add(new ValueTuple<int, TransitionCharSegmentBound>(bounds.Count, newSegmentBound));
newSegmentBound = new TransitionCharSegmentBound(range.EndExclusive, transition.DestinationStateIndex, segmentWeight, false);
bounds.Add(new ValueTuple<int, TransitionCharSegmentBound>(bounds.Count,newSegmentBound));
AddEndPoints(range.StartInclusive, range.EndExclusive, transition.DestinationStateIndex, pieceValue);
}
commonValueStart = range.EndExclusive;
@ -191,12 +162,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
if (!commonValue.IsZero && (ranges.Count == 0 || ranges[ranges.Count - 1].EndExclusive != DiscreteChar.CharRangeEndExclusive))
{
// Add endpoints for the last common value segment
Weight segmentWeight = commonValue * weightBase;
newSegmentBound = new TransitionCharSegmentBound(commonValueStart, transition.DestinationStateIndex, segmentWeight, true);
bounds.Add(new ValueTuple<int, TransitionCharSegmentBound>(bounds.Count, newSegmentBound));
newSegmentBound = new TransitionCharSegmentBound(char.MaxValue + 1, transition.DestinationStateIndex, segmentWeight, false);
bounds.Add(new ValueTuple<int, TransitionCharSegmentBound>(bounds.Count, newSegmentBound));
AddEndPoints(commonValueStart, char.MaxValue + 1, transition.DestinationStateIndex, commonValue);
}
}

@ -572,7 +572,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// The first two elements of a tuple define the element distribution and the weight of a transition.
/// The third element defines the outgoing state.
/// </returns>
protected override List<(TPairDistribution, Weight, Determinization.WeightedStateSet)> GetOutgoingTransitionsForDeterminization(
protected override IEnumerable<Determinization.OutgoingTransition> GetOutgoingTransitionsForDeterminization(
Determinization.WeightedStateSet sourceState)
{
throw new NotImplementedException("Determinization is not yet supported for this type of automata.");

@ -73,7 +73,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// <summary>
/// Gets value indicating whether weight is infinite.
/// </summary>
public bool IsInfinity => double.IsPositiveInfinity(Value);
public bool IsInfinity => double.IsPositiveInfinity(this.LogValue);
/// <summary>
/// Creates a weight from the logarithm of its value.

@ -0,0 +1,74 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
namespace Microsoft.ML.Probabilistic.Distributions.Automata
{
/// <summary>
/// Helper struct for calculating the sum of multiple <see cref="Weight"/> values. It supports adding
/// and then subtracting infinite values. Only values which were added to the sum can be subtracted from it.
/// </summary>
public struct WeightSum
{
/// <summary>
/// Number of weights (finite or infinite) participating in the sum
/// </summary>
private readonly int count;
/// <summary>
/// Number of infinite weights participating in the sum
/// </summary>
private readonly int infCount;
/// <summary>
/// Sum of all non-infinite weights
/// </summary>
private readonly Weight sum;
public WeightSum(int count, int infCount, Weight sum)
{
this.count = count;
this.infCount = infCount;
this.sum = sum;
}
public WeightSum(Weight init)
{
this.count = 1;
if (init.IsInfinity)
{
this.infCount = 1;
this.sum = Weight.Zero;
}
else
{
this.infCount = 0;
this.sum = init;
}
}
/// <summary>
/// Constructs a new empty accumulator instance
/// </summary>
public static WeightSum Zero() => new WeightSum(0, 0, Weight.Zero);
public static WeightSum operator +(WeightSum a, Weight b) =>
b.IsInfinity
? new WeightSum(a.count + 1, a.infCount + 1, a.sum)
: new WeightSum(a.count + 1, a.infCount, a.sum + b);
public static WeightSum operator -(WeightSum a, Weight b) =>
a.count == 1
? WeightSum.Zero()
: (b.IsInfinity
? new WeightSum(a.count - 1, a.infCount - 1, a.sum)
: new WeightSum(a.count - 1, a.infCount, Weight.AbsoluteDifference(a.sum, b)));
public int Count => this.count;
public Weight Sum =>
this.infCount != 0
? Weight.Infinity
: this.sum;
}
}
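
A hypothetical usage sketch of `WeightSum` (variable names are illustrative; only the `WeightSum` and
`Weight` members shown above are assumed): infinite weights are counted rather than added numerically,
so a previously added infinite weight can later be removed, which plain `Weight` arithmetic cannot do.

```csharp
// Sketch only, assuming the WeightSum and Weight types above.
var total = WeightSum.Zero();
total += Weight.FromValue(2.0);
total += Weight.Infinity;        // a segment with infinite weight becomes active
var whileActive = total.Sum;     // infinite, because infCount > 0
total -= Weight.Infinity;        // that same infinite weight leaves the sum
var afterwards = total.Sum;      // finite again (2.0): the infinity was counted, not summed
```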

@ -2244,6 +2244,49 @@ namespace Microsoft.ML.Probabilistic.Tests
Assert.True(nonUniform.Equals(uniform));
}
/// <summary>
/// Tests that SetToConstantSupportOf() doesn't throw low-probability transitions out.
/// </summary>
[Fact]
[Trait("Category", "StringInference")]
public void SetToConstantSupportOfWithLowProbabilityTransition1()
{
var builder = new StringAutomaton.Builder(1);
builder.Start
.AddTransition('a', Weight.FromValue(5e-40))
.SetEndWeight(Weight.One);
var automaton = builder.GetAutomaton();
automaton.SetToConstantOnSupportOfLog(0.0, automaton);
Assert.Equal(new[] { "a" }, automaton.EnumerateSupport().ToArray());
}
/// <summary>
/// Tests that SetToConstantSupportOf() doesn't throw low-probability transitions out.
/// </summary>
[Fact]
[Trait("Category", "StringInference")]
public void SetToConstantSupportOfWithMultipleLowProbabilityTransitions()
{
var builder = new StringAutomaton.Builder(1);
builder.Start
.AddTransition(DiscreteChar.OneOf('a', 'b'), Weight.One)
.SetEndWeight(Weight.One);
builder.Start
.AddTransition(DiscreteChar.OneOf('a', 'b'), Weight.FromLogValue(-1000))
.AddTransition('c', Weight.One)
.SetEndWeight(Weight.One);
builder.Start
.AddTransition('c', Weight.FromLogValue(-10000))
.SetEndWeight(Weight.One);
var automaton = builder.GetAutomaton();
var tmp = automaton.Clone();
tmp.TryDeterminize();
automaton.SetToConstantOnSupportOfLog(0.0, automaton);
var support = automaton.EnumerateSupport().OrderBy(s => s).ToArray();
Assert.Equal(new[] {"a", "ac", "b", "bc", "c"}, support);
}
#endregion
#region Helpers

@ -36,11 +36,13 @@ namespace Microsoft.ML.Probabilistic.Tests
var wrapper = new StringAutomatonWrapper(builder);
var outgoingTransitions =
wrapper.GetOutgoingTransitionsForDeterminization(new Dictionary<int, Weight> { { 0, Weight.FromValue(3) } });
wrapper.GetOutgoingTransitionsForDeterminization(0, Weight.FromValue(3));
var expectedOutgoingTransitions = new[]
{
new ValueTuple<DiscreteChar, Weight, IEnumerable<KeyValuePair<int, Weight>>>(
DiscreteChar.Uniform(), Weight.FromValue(6), new Dictionary<int, Weight> { { 1, Weight.FromValue(1) } })
ValueTuple.Create(
DiscreteChar.Uniform(),
Weight.FromValue(6),
new[] {(1, Weight.FromValue(1))})
};
AssertCollectionsEqual(expectedOutgoingTransitions, outgoingTransitions, TransitionInfoEqualityComparer.Instance);
@ -61,13 +63,17 @@ namespace Microsoft.ML.Probabilistic.Tests
var wrapper = new StringAutomatonWrapper(builder);
var outgoingTransitions =
wrapper.GetOutgoingTransitionsForDeterminization(new Dictionary<int, Weight> { { 0, Weight.FromValue(5) } });
wrapper.GetOutgoingTransitionsForDeterminization(0, Weight.FromValue(5));
var expectedOutgoingTransitions = new[]
{
new ValueTuple<DiscreteChar, Weight, IEnumerable<KeyValuePair<int, Weight>>>(
DiscreteChar.UniformInRange('A', 'Z'), Weight.FromValue(7.5), new Dictionary<int, Weight> { { 2, Weight.FromValue(1) } }),
new ValueTuple<DiscreteChar, Weight, IEnumerable<KeyValuePair<int, Weight>>>(
DiscreteChar.UniformInRange('a', 'z'), Weight.FromValue(17.5), new Dictionary<int, Weight> { { 1, Weight.FromValue(10 / 17.5) }, { 2, Weight.FromValue(7.5 / 17.5) } }),
ValueTuple.Create(
DiscreteChar.UniformInRange('A', 'Z'),
Weight.FromValue(7.5),
new[] {(2, Weight.FromValue(1))}),
ValueTuple.Create(
DiscreteChar.UniformInRange('a', 'z'),
Weight.FromValue(10),
new[] {(1, Weight.FromValue(1)), (2, Weight.FromValue(0.75))}),
};
AssertCollectionsEqual(expectedOutgoingTransitions, outgoingTransitions, TransitionInfoEqualityComparer.Instance);
@ -91,29 +97,33 @@ namespace Microsoft.ML.Probabilistic.Tests
var wrapper = new StringAutomatonWrapper(builder);
var outgoingTransitions =
wrapper.GetOutgoingTransitionsForDeterminization(new Dictionary<int, Weight> { { 0, Weight.FromValue(6) } });
wrapper.GetOutgoingTransitionsForDeterminization(0, Weight.FromValue(6));
var expectedOutgoingTransitions = new[]
{
new ValueTuple<DiscreteChar, Weight, IEnumerable<KeyValuePair<int, Weight>>>(
DiscreteChar.UniformInRange(char.MinValue, (char)('a' - 1)),
Weight.FromValue(30.0 * 97.0 / 98.0),
new Dictionary<int, Weight> { { 4, Weight.FromValue(1) } }),
new ValueTuple<DiscreteChar, Weight, IEnumerable<KeyValuePair<int, Weight>>>(
ValueTuple.Create(
DiscreteChar.UniformInRange(char.MinValue, (char) ('a' - 1)),
Weight.FromValue(6 * 5.0 * 97.0 / 98.0),
new[] {(4, Weight.FromValue(1))}),
ValueTuple.Create(
DiscreteChar.PointMass('a'),
Weight.FromValue((30.0 / 98.0) + 6.0),
new Dictionary<int, Weight> { { 1, Weight.FromValue(6.0 / ((30.0 / 98.0) + 6.0)) }, { 4, Weight.FromValue((30.0 / 98.0) / ((30.0 / 98.0) + 6.0)) } }),
new ValueTuple<DiscreteChar, Weight, IEnumerable<KeyValuePair<int, Weight>>>(
Weight.FromValue(6),
new[]
{
(1, Weight.FromValue(1.0)),
(4, Weight.FromValue(5.0 / 98.0))
}),
ValueTuple.Create(
DiscreteChar.PointMass('b'),
Weight.FromValue(12.0),
new Dictionary<int, Weight> { { 1, Weight.FromValue(0.5) }, { 2, Weight.FromValue(0.5)} }),
new ValueTuple<DiscreteChar, Weight, IEnumerable<KeyValuePair<int, Weight>>>(
Weight.FromValue(6),
new[] {(1, Weight.FromValue(1)), (2, Weight.FromValue(1))}),
ValueTuple.Create(
DiscreteChar.UniformInRange('c', 'd'),
Weight.FromValue(12.0),
new Dictionary<int, Weight> { { 2, Weight.FromValue(1.0) } }),
new ValueTuple<DiscreteChar, Weight, IEnumerable<KeyValuePair<int, Weight>>>(
Weight.FromValue(6 * 3 * (2.0 / 3)),
new[] {(2, Weight.FromValue(1))}),
ValueTuple.Create(
DiscreteChar.UniformInRange('e', 'g'),
Weight.FromValue(24.0),
new Dictionary<int, Weight> { { 3, Weight.FromValue(1.0) } }),
Weight.FromValue(6 * 4),
new[] {(3, Weight.FromValue(1.0))}),
};
AssertCollectionsEqual(expectedOutgoingTransitions, outgoingTransitions, TransitionInfoEqualityComparer.Instance);
@ -135,36 +145,46 @@ namespace Microsoft.ML.Probabilistic.Tests
var wrapper = new StringAutomatonWrapper(builder);
var outgoingTransitions =
wrapper.GetOutgoingTransitionsForDeterminization(new Dictionary<int, Weight> { { 0, Weight.FromValue(5) } });
wrapper.GetOutgoingTransitionsForDeterminization(0, Weight.FromValue(1));
// we have 3 segments:
// 1. [char.MinValue, 'a')
// 2. ['a', 'z')
// 3. ['z', char.MaxValue]
var transition1Segment1Weight = 2.0 * 'a' / (char.MaxValue + 1.0);
var transition1Segment2Weight = 2.0 * ('z' - 'a') / (char.MaxValue + 1.0);
var transition1Segment3Weight = 2.0 * (char.MaxValue - 'z' + 1.0) / (char.MaxValue + 1.0);
var transition2Segment2Weight = 3.0 * ('z' - 'a') / (char.MaxValue - 'a' + 1.0);
var transition2Segment3Weight = 3.0 * (char.MaxValue - 'z' + 1.0) / (char.MaxValue - 'a' + 1.0);
var transition3Segment3Weight = 4.0;
var maxSegment2Weight = Math.Max(transition1Segment2Weight, transition2Segment2Weight);
var maxSegment3Weight = Math.Max(
transition1Segment3Weight,
Math.Max(transition2Segment3Weight, transition3Segment3Weight));
double transition1Segment1Weight = 10.0 * 'a' / (char.MaxValue + 1.0);
double transition1Segment2Weight = 10.0 * ('z' - 'a') / (char.MaxValue + 1.0);
double transition1Segment3Weight = 10.0 * (char.MaxValue - 'z' + 1.0) / (char.MaxValue + 1.0);
double transition2Segment1Weight = 15.0 * ('z' - 'a') / (char.MaxValue - 'a' + 1.0);
double transition2Segment2Weight = 15.0 * (char.MaxValue - 'z' + 1.0) / (char.MaxValue - 'a' + 1.0);
double transition3Segment1Weight = 20.0;
var expectedOutgoingTransitions = new[]
{
new ValueTuple<DiscreteChar, Weight, IEnumerable<KeyValuePair<int, Weight>>>(
ValueTuple.Create(
DiscreteChar.UniformInRange(char.MinValue, (char)('a' - 1)),
Weight.FromValue(transition1Segment1Weight),
new Dictionary<int, Weight> { { 1, Weight.FromValue(1) } }),
new ValueTuple<DiscreteChar, Weight, IEnumerable<KeyValuePair<int, Weight>>>(
new[] {(1, Weight.FromValue(1))}),
ValueTuple.Create(
DiscreteChar.UniformInRange('a', (char)('z' - 1)),
Weight.FromValue(transition1Segment2Weight + transition2Segment1Weight),
new Dictionary<int, Weight>
Weight.FromValue(maxSegment2Weight),
new[]
{
{ 1, Weight.FromValue(transition1Segment2Weight / (transition1Segment2Weight + transition2Segment1Weight)) },
{ 2, Weight.FromValue(transition2Segment1Weight / (transition1Segment2Weight + transition2Segment1Weight)) }
(1, Weight.FromValue(transition1Segment2Weight / maxSegment2Weight)),
(2, Weight.FromValue(transition2Segment2Weight / maxSegment2Weight)),
}),
new ValueTuple<DiscreteChar, Weight, IEnumerable<KeyValuePair<int, Weight>>>(
ValueTuple.Create(
DiscreteChar.UniformInRange('z', char.MaxValue),
Weight.FromValue(transition1Segment3Weight + transition2Segment2Weight + transition3Segment1Weight),
new Dictionary<int, Weight>
Weight.FromValue(maxSegment3Weight),
new[]
{
{ 1, Weight.FromValue(transition1Segment3Weight / (transition1Segment3Weight + transition2Segment2Weight + transition3Segment1Weight)) },
{ 2, Weight.FromValue(transition2Segment2Weight / (transition1Segment3Weight + transition2Segment2Weight + transition3Segment1Weight)) },
{ 3, Weight.FromValue(transition3Segment1Weight / (transition1Segment3Weight + transition2Segment2Weight + transition3Segment1Weight)) }
(1, Weight.FromValue(transition1Segment3Weight / maxSegment3Weight)),
(2, Weight.FromValue(transition2Segment3Weight / maxSegment3Weight)),
(3, Weight.FromValue(transition3Segment3Weight / maxSegment3Weight)),
}),
};
@ -188,8 +208,8 @@ namespace Microsoft.ML.Probabilistic.Tests
var wrapper = new StringAutomatonWrapper(builder);
var outgoingTransitions = wrapper.GetOutgoingTransitionsForDeterminization(
new Dictionary<int, Weight> { { 0, Weight.FromValue(1) } }).ToArray();
var outgoingTransitions =
wrapper.GetOutgoingTransitionsForDeterminization(0, Weight.FromValue(1)).ToArray();
Assert.Equal(5, outgoingTransitions.Length);
Assert.True(outgoingTransitions.All(ot => ot.Item1.IsPointMass));
@ -277,7 +297,7 @@ namespace Microsoft.ML.Probabilistic.Tests
/// A comparer for weighted states that allows for some tolerance when comparing weights.
/// </summary>
private class WeightedStateEqualityComparer :
EqualityComparerBase<KeyValuePair<int, Weight>, WeightedStateEqualityComparer>
EqualityComparerBase<(int, Weight), WeightedStateEqualityComparer>
{
/// <summary>
/// Checks whether two objects are equal.
@ -288,9 +308,9 @@ namespace Microsoft.ML.Probabilistic.Tests
/// <see langword="true"/> if the objects are equal,
/// <see langword="false"/> otherwise.
/// </returns>
public override bool Equals(KeyValuePair<int, Weight> x, KeyValuePair<int, Weight> y)
public override bool Equals((int, Weight) x, (int, Weight) y)
{
return x.Key == y.Key && WeightEqualityComparer.Instance.Equals(x.Value, y.Value);
return x.Item1 == y.Item1 && WeightEqualityComparer.Instance.Equals(x.Item2, y.Item2);
}
}
@ -298,7 +318,7 @@ namespace Microsoft.ML.Probabilistic.Tests
/// A comparer for transition information that allows for some tolerance when comparing weights.
/// </summary>
private class TransitionInfoEqualityComparer :
EqualityComparerBase<ValueTuple<DiscreteChar, Weight, IEnumerable<KeyValuePair<int, Weight>>>, TransitionInfoEqualityComparer>
EqualityComparerBase<ValueTuple<DiscreteChar, Weight, (int, Weight)[]>, TransitionInfoEqualityComparer>
{
/// <summary>
/// Checks whether two objects are equal.
@ -310,13 +330,13 @@ namespace Microsoft.ML.Probabilistic.Tests
/// <see langword="false"/> otherwise.
/// </returns>
public override bool Equals(
ValueTuple<DiscreteChar, Weight, IEnumerable<KeyValuePair<int, Weight>>> x,
ValueTuple<DiscreteChar, Weight, IEnumerable<KeyValuePair<int, Weight>>> y)
ValueTuple<DiscreteChar, Weight, (int, Weight)[]> x,
ValueTuple<DiscreteChar, Weight, (int, Weight)[]> y)
{
return
object.Equals(x.Item1, y.Item1) &&
WeightEqualityComparer.Instance.Equals(x.Item2, y.Item2) &&
Enumerable.SequenceEqual(x.Item3, y.Item3, WeightedStateEqualityComparer.Instance);
x.Item3.SequenceEqual(y.Item3, WeightedStateEqualityComparer.Instance);
}
}
@ -338,14 +358,20 @@ namespace Microsoft.ML.Probabilistic.Tests
/// <param name="sourceState">The source state.</param>
/// <returns>The produced transitions.</returns>
/// <remarks>See the doc of the original method.</remarks>
public IEnumerable<(DiscreteChar, Weight, IEnumerable<KeyValuePair<int, Weight>>)>
GetOutgoingTransitionsForDeterminization(IEnumerable<KeyValuePair<int, Weight>> sourceState)
public IEnumerable<(DiscreteChar, Weight, (int, Weight)[])>
GetOutgoingTransitionsForDeterminization(int sourceState, Weight sourceWeight)
{
var result = base.GetOutgoingTransitionsForDeterminization(new Determinization.WeightedStateSet(sourceState));
return result.Select(t => (t.Item1, t.Item2, (IEnumerable<KeyValuePair<int, Weight>>)t.Item3));
}
}
var weightedStateSetBuilder = Determinization.WeightedStateSetBuilder.Create();
weightedStateSetBuilder.Add(sourceState, sourceWeight);
var (weightedStateSet, weight) = weightedStateSetBuilder.Get();
var result = base.GetOutgoingTransitionsForDeterminization(weightedStateSet);
return result.Select(t => (
t.ElementDistribution,
t.Weight * weight,
t.Destinations.ToArray().Select(state => (state.Index, state.Weight)).ToArray()));
}
}
#endregion
}
}