Get rid of recursive implementations (#125)

We finally reached a point where automatons are big enough that recursive implementations
of algorithms fail with `StackOverflowException`.

There are 4 tests which test operations with big automatons (100k states):
* `TryComputePointLargeAutomaton`
* `SetToProductLargeAutomaton`
* `GetLogNormalizerLargeAutomaton`
* `ProjectSourceLargeAutomaton`

All four used to fail before the code was rewritten to use an explicit stack instead of recursive calls.
All changes except one left the algorithms themselves unchanged; the code was converted almost
mechanically to use an explicit stack.

The only exception is automaton simplification. The old code used recursion in a very non-trivial way,
so it was rewritten from scratch using a different algorithm: instead of extracting generalized
sequences and then reinserting them, the new code merges states directly in the automaton.
(There is a comment at the beginning of the Simplify() method explaining all operations.)
This commit is contained in:
Ivan Korostelev 2019-03-11 11:33:29 +00:00 коммит произвёл GitHub
Родитель 702874f9aa
Коммит 6d2ce9a993
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
11 изменённых файлов: 986 добавлений и 1303 удалений

Просмотреть файл

@ -214,15 +214,14 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
}
}
this.StartStateIndex = oldToNewStateIdMapping[this.StartStateIndex];
if (this.StartStateIndex == -1)
if (deadStateCount == 0)
{
// Cannot reach any end state from the start state => the automaton is zero everywhere
this.Clear();
this.AddState();
return deadStateCount;
return 0;
}
// may invalidate automaton
this.StartStateIndex = oldToNewStateIdMapping[this.StartStateIndex];
for (var i = 0; i < this.states.Count; ++i)
{
var newId = oldToNewStateIdMapping[i];

Просмотреть файл

@ -2,6 +2,8 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using Microsoft.ML.Probabilistic.Collections;
namespace Microsoft.ML.Probabilistic.Distributions.Automata
{
using System;
@ -108,12 +110,9 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
this.useApproximateClosure = useApproximateClosure;
this.components = new List<StronglyConnectedComponent>();
var stateIdStack = new Stack<State>();
var stateIdToTarjanInfo = new Dictionary<int, TarjanStateInfo>();
int traversalIndex = 0;
this.FindStronglyConnectedComponents(this.Root, ref traversalIndex, stateIdToTarjanInfo, stateIdStack);
this.FindStronglyConnectedComponents();
this.stateIdToInfo = new Dictionary<int, CondensationStateInfo>(stateIdToTarjanInfo.Count);
this.stateIdToInfo = new Dictionary<int, CondensationStateInfo>();
for (int i = 0; i < this.components.Count; ++i)
{
StronglyConnectedComponent component = this.components[i];
@ -194,57 +193,99 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// Implements <a href="http://en.wikipedia.org/wiki/Tarjan's_strongly_connected_components_algorithm">Tarjan's algorithm</a>
/// for finding the strongly connected components of the automaton graph.
/// </summary>
/// <param name="currentState">The state currently being traversed.</param>
/// <param name="traversalIndex">The traversal index (as defined by the Tarjan's algorithm).</param>
/// <param name="stateIdToStateInfo">A dictionary mapping state indices to info records maintained by the Tarjan's algorithm.</param>
/// <param name="stateIdStack">The traversal stack (as defined by the Tarjan's algorithm).</param>
private void FindStronglyConnectedComponents(
State currentState,
ref int traversalIndex,
Dictionary<int, TarjanStateInfo> stateIdToStateInfo,
Stack<State> stateIdStack)
/// <remarks>
/// This implementation closely follows the algorithm from Wikipedia, but is rewritten to use an explicit
/// traversal stack instead of recursion.
///
/// The basic idea is that each time a recursive call would be made, all the state needed to resume
/// is saved as a "continuation" and pushed onto the traversal stack. The state that needs to be preserved is:
/// - currentStateIndex - the state being processed
/// - currentTransitionIndex - index of the transition which was recursed into last.
/// If it is -1, the state has not been processed at all yet.
/// </remarks>
private void FindStronglyConnectedComponents()
{
Debug.Assert(!stateIdToStateInfo.ContainsKey(currentState.Index), "Visited states must not be revisited.");
var states = this.Root.Owner.States;
var traversalIndex = 0;
var stateIdStack = new Stack<int>(); // Stack that maintains Tarjan algorithm invariant
var info = new TarjanStateInfo[this.Root.Owner.States.Count];
var traversalStack = new Stack<(int currentStateIndex, int lastDestination)>();
var stateInfo = new TarjanStateInfo(traversalIndex);
stateIdToStateInfo.Add(currentState.Index, stateInfo);
++traversalIndex;
traversalStack.Push((this.Root.Index, -1));
stateIdStack.Push(currentState);
stateInfo.InStack = true;
foreach (var transition in currentState.Transitions)
while (traversalStack.Count > 0)
{
if (!this.transitionFilter(transition))
var (current, currentTransitionIndex) = traversalStack.Pop();
var currentState = states[current];
if (currentTransitionIndex < 0)
{
// Just entered
Debug.Assert(!info[current].Visited);
info[current].Visited = true;
info[current].TraversalIndex = traversalIndex;
info[current].Lowlink = traversalIndex;
info[current].InStack = true;
stateIdStack.Push(current);
++traversalIndex;
currentTransitionIndex = 0;
}
else
{
// We have just returned from traversing a destination state; currentTransitionIndex identifies the transition that led to it
var lastTraversed = currentState.Transitions[currentTransitionIndex].DestinationStateIndex;
info[current].Lowlink = Math.Min(info[current].Lowlink, info[lastTraversed].Lowlink);
++currentTransitionIndex;
}
// Continue processing
for (; currentTransitionIndex < currentState.Transitions.Count; ++currentTransitionIndex)
{
var transition = currentState.Transitions[currentTransitionIndex];
if (transitionFilter(transition))
{
var destination = transition.DestinationStateIndex;
if (!info[destination].Visited)
{
// return to this state after destination is traversed
traversalStack.Push((current, currentTransitionIndex));
// traverse destination
traversalStack.Push((destination, -1));
// Processing of this state will effectively be resumed after destination is processed
break;
}
if (info[destination].InStack)
{
info[current].Lowlink =
Math.Min(info[current].Lowlink, info[destination].TraversalIndex);
}
}
}
// We can break from for-loop above before end condition is met only if we pushed some
// work to do onto traversal stack. One of the things pushed to stack will effectively
// resume the loop from the currentTransitionIndex.
if (currentTransitionIndex < currentState.Transitions.Count)
{
continue;
}
if (!stateIdToStateInfo.TryGetValue(transition.DestinationStateIndex, out TarjanStateInfo destinationStateInfo))
if (info[current].Lowlink == info[current].TraversalIndex)
{
this.FindStronglyConnectedComponents(
this.Root.Owner.States[transition.DestinationStateIndex], ref traversalIndex, stateIdToStateInfo, stateIdStack);
stateInfo.Lowlink = Math.Min(stateInfo.Lowlink, stateIdToStateInfo[transition.DestinationStateIndex].Lowlink);
}
else if (destinationStateInfo.InStack)
{
stateInfo.Lowlink = Math.Min(stateInfo.Lowlink, destinationStateInfo.TraversalIndex);
}
}
var statesInComponent = new List<State>();
State state;
do
{
state = states[stateIdStack.Pop()];
info[state.Index].InStack = false;
statesInComponent.Add(state);
} while (state.Index != current);
if (stateInfo.Lowlink == stateInfo.TraversalIndex)
{
var statesInComponent = new List<State>();
State state;
do
{
state = stateIdStack.Pop();
stateIdToStateInfo[state.Index].InStack = false;
statesInComponent.Add(state);
this.components.Add(
new StronglyConnectedComponent(
this.transitionFilter, statesInComponent, this.useApproximateClosure));
}
while (state.Index != currentState.Index);
this.components.Add(new StronglyConnectedComponent(this.transitionFilter, statesInComponent, this.useApproximateClosure));
}
}
@ -407,17 +448,12 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// <summary>
/// Stores the information maintained by the Tarjan's algorithm.
/// </summary>
private class TarjanStateInfo
private struct TarjanStateInfo
{
/// <summary>
/// Initializes a new instance of the <see cref="TarjanStateInfo"/> class.
/// Gets or sets value indicating whether this node has already been visited
/// </summary>
/// <param name="traversalIndex">The current traversal index.</param>
public TarjanStateInfo(int traversalIndex)
{
this.TraversalIndex = traversalIndex;
this.Lowlink = traversalIndex;
}
public bool Visited { get; set; }
/// <summary>
/// Gets or sets a value indicating whether the state is currently in stack.
@ -427,7 +463,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// <summary>
/// Gets the traversal index of the state.
/// </summary>
public int TraversalIndex { get; private set; }
public int TraversalIndex { get; set; }
/// <summary>
/// Gets or sets the lowlink of the state.

Просмотреть файл

@ -123,11 +123,11 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
}
}
var simplification = new Simplification(builder, this.PruneTransitionsWithLogWeightLessThan);
var simplification = new Simplification(builder, this.PruneStatesWithLogEndWeightLessThan);
simplification.MergeParallelTransitions(); // Determinization produces a separate transition for each segment
var result = builder.GetAutomaton();
result.PruneTransitionsWithLogWeightLessThan = this.PruneTransitionsWithLogWeightLessThan;
result.PruneStatesWithLogEndWeightLessThan = this.PruneStatesWithLogEndWeightLessThan;
result.LogValueOverride = this.LogValueOverride;
this.SwapWith(result);

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Просмотреть файл

@ -2,11 +2,10 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Collections;
namespace Microsoft.ML.Probabilistic.Distributions.Automata
{
using System;
using System.Collections;
using System.Collections.Generic;
using System.Collections.Specialized;
using System.Diagnostics;
@ -143,7 +142,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// <remarks>
/// TODO: We need to develop more elegant automaton approximation methods, this is a simple placeholder for those.
/// </remarks>
public double? PruneTransitionsWithLogWeightLessThan { get; set; }
public double? PruneStatesWithLogEndWeightLessThan { get; set; }
/// <summary>
/// Gets or sets the maximum number of states an automaton can have.
@ -763,12 +762,107 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
{
return this.ToString();
}
else
var builder = new StringBuilder();
var visitedStates = new HashSet<int>();
var stack = new Stack<(string prefix, Option<TElementDistribution> prefixDistribution, int state)>();
stack.Push((string.Empty, Option.None, Start.Index));
while (stack.Count > 0)
{
StringBuilder builder = new StringBuilder();
this.AppendString(builder, new HashSet<int>(), this.Start.Index, appendElement);
return builder.ToString();
var (prefix, prefixDistribution, stateIndex) = stack.Pop();
builder.Append(prefix);
if (prefixDistribution.HasValue)
{
if (appendElement != null)
{
appendElement(prefixDistribution.Value, builder);
}
else
{
builder.Append(prefixDistribution);
}
}
if (stateIndex == -1)
{
continue;
}
if (visitedStates.Contains(stateIndex))
{
builder.Append('➥');
continue;
}
visitedStates.Add(stateIndex);
var state = this.States[stateIndex];
var transitions = state.Transitions.Where(t => !t.Weight.IsZero);
var selfTransitions = transitions.Where(t => t.DestinationStateIndex == stateIndex);
var selfTransitionCount = selfTransitions.Count();
var nonSelfTransitions = transitions.Where(t => t.DestinationStateIndex != stateIndex).ToArray();
var nonSelfTransitionCount = nonSelfTransitions.Count();
if (state.CanEnd && nonSelfTransitionCount > 0)
{
builder.Append('▪');
}
if (selfTransitionCount > 1)
{
builder.Append('[');
}
var transIdx = 0;
foreach (var transition in selfTransitions)
{
if (!transition.IsEpsilon)
{
if (appendElement != null)
{
appendElement(transition.ElementDistribution.Value, builder);
}
else
{
builder.Append(transition.ElementDistribution);
}
}
if (++transIdx < selfTransitionCount)
{
builder.Append('|');
}
}
if (selfTransitionCount > 1)
{
builder.Append(']');
}
if (selfTransitionCount > 0)
{
builder.Append('*');
}
if (nonSelfTransitionCount > 1)
{
builder.Append('[');
stack.Push(("]", Option.None, -1));
}
// Iterate transitions in reverse order, because after destinations are pushed to stack,
// order will be reversed again.
for (transIdx = nonSelfTransitions.Length - 1; transIdx >= 0; --transIdx)
{
var transition = nonSelfTransitions[transIdx];
var transitionPrefix = transIdx == 0 ? string.Empty : "|";
stack.Push((transitionPrefix, transition.ElementDistribution, transition.DestinationStateIndex));
}
}
return builder.ToString();
}
/// <summary>
@ -853,10 +947,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// </summary>
/// <returns>The logarithm of the normalizer.</returns>
/// <remarks>Returns <see cref="double.PositiveInfinity"/> if the sum diverges.</remarks>
public double GetLogNormalizer()
{
return this.DoGetLogNormalizer(false);
}
public double GetLogNormalizer() => this.DoGetLogNormalizer(false);
/// <summary>
/// Normalizes the automaton so that the sum of its values over all possible sequences equals to one
@ -1036,7 +1127,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
var state = result[stateId];
if (state.CanEnd)
{
// Make all accepting states contibute the desired value to the result
// Make all accepting states contribute the desired value to the result
state.SetEndWeight(value);
}
@ -1242,37 +1333,35 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
var builder = new Builder(0);
var productStateCache = new Dictionary<(int, int), int>(automaton1.States.Count + automaton2.States.Count);
builder.StartStateIndex = BuildProduct(automaton1.Start, automaton2.Start);
var stack = new Stack<(int state1, int state2, int productStateIndex)>();
var simplification = new Simplification(builder, this.PruneTransitionsWithLogWeightLessThan);
simplification.RemoveDeadStates(); // Product can potentially create dead states
simplification.SimplifyIfNeeded();
this.Data = builder.GetData();
if (this is StringAutomaton && tryDeterminize)
// Creates product state and schedules product computation for it.
// If computation is already scheduled or done the state index is simply taken from cache
int CreateProductState(State state1, State state2)
{
this.TryDeterminize();
}
// Recursively builds an automaton representing the product of two given automata.
// Returns start state of product automaton.
int BuildProduct(State state1, State state2)
{
Debug.Assert(state1 != null && state2 != null, "Valid states must be provided.");
Debug.Assert(
state2.Owner.IsEpsilonFree,
"The second argument of the product operation must be epsilon-free.");
// State already exists, return its index
var statePair = (state1.Index, state2.Index);
if (productStateCache.TryGetValue(statePair, out var productStateIndex))
var destPair = (state1.Index, state2.Index);
if (!productStateCache.TryGetValue(destPair, out var productStateIndex))
{
return productStateIndex;
var productState = builder.AddState();
productState.SetEndWeight(Weight.Product(state1.EndWeight, state2.EndWeight));
stack.Push((state1.Index, state2.Index, productState.Index));
productStateCache[destPair] = productState.Index;
productStateIndex = productState.Index;
}
// Create a new state
var productState = builder.AddState();
productStateCache.Add(statePair, productState.Index);
return productStateIndex;
}
// Populate the stack with start product state
builder.StartStateIndex = CreateProductState(automaton1.Start, automaton2.Start);
while (stack.Count > 0)
{
var (state1Index, state2Index, productStateIndex) = stack.Pop();
var state1 = automaton1.States[state1Index];
var state2 = automaton2.States[state2Index];
var productState = builder[productStateIndex];
// Iterate over transitions in state1
foreach (var transition1 in state1.Transitions)
@ -1282,7 +1371,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
if (transition1.IsEpsilon)
{
// Epsilon transition case
var destProductStateIndex = BuildProduct(destState1, state2);
var destProductStateIndex = CreateProductState(destState1, state2);
productState.AddEpsilonTransition(transition1.Weight, destProductStateIndex, transition1.Group);
continue;
}
@ -1295,23 +1384,29 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
"The second argument of the product operation must be epsilon-free.");
var destState2 = state2.Owner.States[transition2.DestinationStateIndex];
var productLogNormalizer = Distribution<TElement>.GetLogAverageOf(
transition1.ElementDistribution.Value, transition2.ElementDistribution.Value, out var product);
transition1.ElementDistribution.Value, transition2.ElementDistribution.Value,
out var product);
if (double.IsNegativeInfinity(productLogNormalizer))
{
continue;
}
var productWeight = Weight.Product(
transition1.Weight,
transition2.Weight,
Weight.FromLogValue(productLogNormalizer));
var destProductStateIndex = BuildProduct(destState1, destState2);
transition1.Weight, transition2.Weight, Weight.FromLogValue(productLogNormalizer));
var destProductStateIndex = CreateProductState(destState1, destState2);
productState.AddTransition(product, productWeight, destProductStateIndex, transition1.Group);
}
}
}
productState.SetEndWeight(Weight.Product(state1.EndWeight, state2.EndWeight));
return productState.Index;
var simplification = new Simplification(builder, this.PruneStatesWithLogEndWeightLessThan);
simplification.RemoveDeadStates(); // Product can potentially create dead states
simplification.SimplifyIfNeeded();
this.Data = builder.GetData();
if (this is StringAutomaton && tryDeterminize)
{
this.TryDeterminize();
}
}
@ -1323,17 +1418,17 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// <returns>The merged pruning threshold.</returns>
private double? MergePruningWeights(TThis automaton1, TThis automaton2)
{
if (automaton1.PruneTransitionsWithLogWeightLessThan == null)
if (automaton1.PruneStatesWithLogEndWeightLessThan == null)
{
return automaton2.PruneTransitionsWithLogWeightLessThan;
return automaton2.PruneStatesWithLogEndWeightLessThan;
}
if (automaton2.PruneTransitionsWithLogWeightLessThan == null)
if (automaton2.PruneStatesWithLogEndWeightLessThan == null)
{
return automaton1.PruneTransitionsWithLogWeightLessThan;
return automaton1.PruneStatesWithLogEndWeightLessThan;
}
return Math.Min(automaton1.PruneTransitionsWithLogWeightLessThan.Value, automaton2.PruneTransitionsWithLogWeightLessThan.Value);
return Math.Min(automaton1.PruneStatesWithLogEndWeightLessThan.Value, automaton2.PruneStatesWithLogEndWeightLessThan.Value);
}
/// <summary>
@ -1437,7 +1532,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
}
}
var simplification = new Simplification(result, this.PruneTransitionsWithLogWeightLessThan);
var simplification = new Simplification(result, this.PruneStatesWithLogEndWeightLessThan);
simplification.SimplifyIfNeeded();
this.Data = result.GetData();
@ -1571,7 +1666,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
this.Data = automaton.Data;
this.LogValueOverride = automaton.LogValueOverride;
this.PruneTransitionsWithLogWeightLessThan = automaton.PruneTransitionsWithLogWeightLessThan;
this.PruneStatesWithLogEndWeightLessThan = automaton.PruneStatesWithLogEndWeightLessThan;
}
/// <summary>
@ -1632,56 +1727,85 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// </summary>
/// <param name="sequence">The sequence to compute the value on.</param>
/// <returns>The logarithm of the value.</returns>
/// <remarks></remarks>
/// <remarks>Recursive implementation would be simpler but prone to stack overflows with large automata</remarks>
public double GetLogValue(TSequence sequence)
{
Argument.CheckIfNotNull(sequence, "sequence");
var valueCache = new Dictionary<(int, int), Weight>();
var logValue = DoGetValue(this.Start.Index, 0).LogValue;
var sequenceLength = SequenceManipulator.GetLength(sequence);
return
!double.IsNegativeInfinity(logValue) && this.LogValueOverride.HasValue
? this.LogValueOverride.Value
: logValue;
// This algorithm is unwinding of trivial recursion with cache. It encodes recursion through explicit
// stack to avoid stack-overflow. It makes calculation not straightforward. Recursive algorithm that
// we want to encode looks like this:
// GetLogValue(state, sequencePos) =
// closure = EpsilonClosure(state)
// if sequencePos == sequenceLength:
// return closure.EndWeight
// else:
// return sum([GetLogValue(transition.Destination, sequencePos+1) * transitionWeight'
// for transition in closure.Transitions])
// where transitionWeight' = transition.Weight * closureStateWeight * transition.ElementDistribution.GetLogWeight(sequence[sequencePos])
//
// To encode that with explicit stack we split each GetLogValue() call into 2 parts which are put
// on operations stack:
// a) (stateIndex, sequencePos, mul, -1) -
// schedules computation for each transition from this state and (b) operation after them
// b) (stateIndex, sequencePos, mul, sumUntil) -
// sums values of transitions from this state. They will be on top of the stack
// As an optimization, the value calculated by (b) is cached, so if (a) notices that the result for
// (stateIndex, sequencePos) has already been calculated, it doesn't schedule any extra work
var operationsStack = new Stack<(int stateIndex, int sequencePos, Weight multiplier, int sumUntil)>();
var valuesStack = new Stack<Weight>();
var valueCache = new Dictionary<(int stateIndex, int sequencePos), Weight>();
operationsStack.Push((this.Start.Index, 0, Weight.One, -1));
Weight DoGetValue(int stateIndex, int sequencePosition)
while (operationsStack.Count > 0)
{
var state = this.States[stateIndex];
var stateIndexPair = (stateIndex, sequencePosition);
if (valueCache.TryGetValue(stateIndexPair, out var cachedValue))
{
return cachedValue;
}
var (stateIndex, sequencePos, multiplier, sumUntil) = operationsStack.Pop();
var statePosPair = (stateIndex, sequencePos);
var closure = state.GetEpsilonClosure();
var value = Weight.Zero;
var count = SequenceManipulator.GetLength(sequence);
var isCurrent = sequencePosition < count;
if (isCurrent)
if (sumUntil < 0)
{
var element = SequenceManipulator.GetElement(sequence, sequencePosition);
for (var closureStateIndex = 0; closureStateIndex < closure.Size; ++closureStateIndex)
if (valueCache.TryGetValue(statePosPair, out var cachedValue))
{
var closureState = closure.GetStateByIndex(closureStateIndex);
var closureStateWeight = closure.GetStateWeightByIndex(closureStateIndex);
valuesStack.Push(Weight.Product(cachedValue, multiplier));
}
else
{
var closure = this.States[stateIndex].GetEpsilonClosure();
foreach (var transition in closureState.Transitions)
if (sequencePos == sequenceLength)
{
if (transition.IsEpsilon)
{
continue; // The destination is a part of the closure anyway
}
// We are at the end of sequence. So put an answer on stack
valuesStack.Push(Weight.Product(closure.EndWeight, multiplier));
valueCache[statePosPair] = closure.EndWeight;
}
else
{
// schedule second part of computation - sum values for all transitions
// Note: it is put on stack before operations for any transitions, that
// means that it will be executed after them
operationsStack.Push((stateIndex, sequencePos, multiplier, valuesStack.Count));
var destState = this.States[transition.DestinationStateIndex];
var distWeight = Weight.FromLogValue(transition.ElementDistribution.Value.GetLogProb(element));
if (!distWeight.IsZero && !transition.Weight.IsZero)
var element = SequenceManipulator.GetElement(sequence, sequencePos);
for (var closureStateIndex = 0; closureStateIndex < closure.Size; ++closureStateIndex)
{
var destValue = DoGetValue(destState.Index, sequencePosition + 1);
if (!destValue.IsZero)
var closureState = closure.GetStateByIndex(closureStateIndex);
var closureStateWeight = closure.GetStateWeightByIndex(closureStateIndex);
foreach (var transition in closureState.Transitions)
{
value = Weight.Sum(
value,
Weight.Product(closureStateWeight, transition.Weight, distWeight, destValue));
if (transition.IsEpsilon)
{
continue; // The destination is a part of the closure anyway
}
var destStateIndex = transition.DestinationStateIndex;
var distWeight = Weight.FromLogValue(transition.ElementDistribution.Value.GetLogProb(element));
if (!distWeight.IsZero && !transition.Weight.IsZero)
{
var weightMul = Weight.Product(closureStateWeight, transition.Weight, distWeight);
operationsStack.Push((destStateIndex, sequencePos + 1, weightMul, -1));
}
}
}
}
@ -1689,12 +1813,25 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
}
else
{
value = closure.EndWeight;
}
// All transitions value from this state are already calculated, sum them and store on stack and in cache
var sum = Weight.Zero;
while (valuesStack.Count > sumUntil)
{
var transitionValue = valuesStack.Pop();
sum = Weight.Sum(sum, transitionValue);
}
valueCache.Add(stateIndexPair, value);
return value;
valuesStack.Push(Weight.Product(multiplier, sum));
valueCache[statePosPair] = sum;
}
}
Debug.Assert(valuesStack.Count == 1);
var result = valuesStack.Pop();
return
!result.IsZero && this.LogValueOverride.HasValue
? this.LogValueOverride.Value
: result.LogValue;
}
/// <summary>
@ -1714,10 +1851,11 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// The only sequence having non-zero value, if found.
/// <see langword="null"/>, if the automaton is zero everywhere or is non-zero on more than one sequence.
/// </returns>
/// <remarks>Recursive implementation would be simpler but prone to stack overflows with large automata</remarks>
public TSequence TryComputePoint()
{
bool[] endNodeReachability = this.ComputeEndStateReachability();
if (!endNodeReachability[this.Start.Index])
var isEndNodeReachable = this.ComputeEndStateReachability();
if (!isEndNodeReachable[this.Start.Index])
{
return null;
}
@ -1725,8 +1863,92 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
var point = new List<TElement>();
int? pointLength = null;
var stateDepth = new ArrayDictionary<int>(this.States.Count);
bool isPoint = this.TryComputePointDfs(this.Start, 0, stateDepth, endNodeReachability, point, ref pointLength);
return isPoint && pointLength.HasValue ? SequenceManipulator.ToSequence(point) : null;
var stack = new Stack<(int stateIndex, int sequencePos)>();
stack.Push((this.Start.Index, 0));
// Note: this algorithm looks simpler if implemented recursively. But recursive implementation
// causes StackOverflowException.
// The algorithm is simple: traverse the automaton in depth-first fashion and check that element transitions
// along all paths are equal. If any inconsistency is found, there is no single point and null is returned.
while (stack.Count != 0)
{
var (stateIndex, sequencePos) = stack.Pop();
Debug.Assert(isEndNodeReachable[stateIndex], "Dead branches must not be visited.");
if (stateDepth.TryGetValue(stateIndex, out var cachedStateDepth))
{
// If we've already been in this state, we must be at the same sequence pos
if (sequencePos != cachedStateDepth)
{
return null;
}
// This state was already processed, goto next one
continue;
}
stateDepth.Add(stateIndex, sequencePos);
var state = this.States[stateIndex];
// Can we stop in this state?
if (state.CanEnd)
{
// Is this a suffix or a prefix of the point already found?
if (pointLength.HasValue)
{
if (sequencePos != pointLength.Value)
{
return null;
}
continue;
}
// Now we know the length of the sequence
pointLength = sequencePos;
}
foreach (var transition in state.Transitions)
{
var destStateIndex = transition.DestinationStateIndex;
if (!isEndNodeReachable[destStateIndex])
{
// Only walk through the accepting part of the automaton
continue;
}
if (transition.IsEpsilon)
{
// Move to the next state, keep the sequence position
stack.Push((destStateIndex, sequencePos));
}
else if (!transition.ElementDistribution.Value.IsPointMass)
{
// If there's a non-point distribution on a transition, then the automaton doesn't have a point either
return null;
}
else
{
var element = transition.ElementDistribution.Value.Point;
if (sequencePos == point.Count)
{
// It is the first time at this sequence position
point.Add(element);
}
else if (!point[sequencePos].Equals(element))
{
// This is not the first time at this sequence position, and the elements are different
return null;
}
stack.Push((destStateIndex, sequencePos + 1));
}
}
}
return pointLength.HasValue ? SequenceManipulator.ToSequence(point) : null;
}
/// <summary>
@ -1936,7 +2158,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
this.Data = builder.GetData();
this.LogValueOverride = automaton.LogValueOverride;
this.PruneTransitionsWithLogWeightLessThan = automaton.LogValueOverride;
this.PruneStatesWithLogEndWeightLessThan = automaton.LogValueOverride;
// Recursively builds an automaton representing the epsilon closure of a given automaton.
// Returns the state index of state representing the closure
@ -2077,12 +2299,13 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// For each state computes whether any state with non-zero ending weight can be reached from it.
/// </summary>
/// <returns>An array mapping state indices to end state reachability.</returns>
/// <remarks>Recursive implementation would be simpler but prone to stack overflows with large automatons</remarks>
private bool[] ComputeEndStateReachability()
{
//// First, build a reversed graph
int[] edgePlacementIndices = new int[this.States.Count + 1];
for (int i = 0; i < this.States.Count; ++i)
var edgePlacementIndices = new int[this.States.Count + 1];
for (var i = 0; i < this.States.Count; ++i)
{
var state = this.States[i];
foreach (var transition in state.Transitions)
@ -2097,15 +2320,15 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
// The element of edgePlacementIndices at index i+1 contains a count of the number of edges
// going into the i'th state (the indegree of the state).
// Convert this into a cumulative count (which will be used to give a unique index to each edge).
for (int i = 1; i < edgePlacementIndices.Length; ++i)
for (var i = 1; i < edgePlacementIndices.Length; ++i)
{
edgePlacementIndices[i] += edgePlacementIndices[i - 1];
}
int[] edgeArrayStarts = (int[])edgePlacementIndices.Clone();
int totalEdgeCount = edgePlacementIndices[this.States.Count];
int[] edgeDestinationIndices = new int[totalEdgeCount];
for (int i = 0; i < this.States.Count; ++i)
var edgeArrayStarts = (int[])edgePlacementIndices.Clone();
var totalEdgeCount = edgePlacementIndices[this.States.Count];
var edgeDestinationIndices = new int[totalEdgeCount];
for (var i = 0; i < this.States.Count; ++i)
{
var state = this.States[i];
foreach (var transition in state.Transitions)
@ -2113,7 +2336,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
if (!transition.Weight.IsZero)
{
// The unique index for this edge
int edgePlacementIndex = edgePlacementIndices[transition.DestinationStateIndex]++;
var edgePlacementIndex = edgePlacementIndices[transition.DestinationStateIndex]++;
// The source index for the edge (which is the destination edge in the reversed graph)
edgeDestinationIndices[edgePlacementIndex] = i;
@ -2122,38 +2345,38 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
}
//// Now run a depth-first search to label all reachable nodes
bool[] visitedNodes = new bool[this.States.Count];
for (int i = 0; i < this.States.Count; ++i)
var stack = new Stack<int>();
var visitedNodes = new bool[this.States.Count];
for (var i = 0; i < this.States.Count; ++i)
{
if (!visitedNodes[i] && this.States[i].CanEnd)
{
LabelReachableNodesDfs(i);
visitedNodes[i] = true;
stack.Push(i);
while (stack.Count != 0)
{
var stateIndex = stack.Pop();
for (var edgeIndex = edgeArrayStarts[stateIndex]; edgeIndex < edgeArrayStarts[stateIndex + 1]; ++edgeIndex)
{
var destinationIndex = edgeDestinationIndices[edgeIndex];
if (!visitedNodes[destinationIndex])
{
visitedNodes[destinationIndex] = true;
stack.Push(destinationIndex);
}
}
}
}
}
return visitedNodes;
// Recursively marks currentVertex and every not-yet-visited vertex that can be reached from it
// by following the edges stored in edgeDestinationIndices.
void LabelReachableNodesDfs(int currentVertex)
{
    Debug.Assert(!visitedNodes[currentVertex], "Visited vertices must not be revisited.");
    visitedNodes[currentVertex] = true;

    // Edges of a vertex v occupy the half-open range
    // [edgeArrayStarts[v], edgeArrayStarts[v + 1]) of edgeDestinationIndices.
    var edgeIndex = edgeArrayStarts[currentVertex];
    while (edgeIndex < edgeArrayStarts[currentVertex + 1])
    {
        var neighbourIndex = edgeDestinationIndices[edgeIndex++];
        if (visitedNodes[neighbourIndex])
        {
            continue;
        }

        LabelReachableNodesDfs(neighbourIndex);
    }
}
}
/// <summary>
/// Computes the logarithm of the normalizer of the automaton, normalizing it afterwards if requested.
/// </summary>
/// <param name="normalize">Specifies whether the automaton must be normalized after computing the normalizer.</param>
/// <returns>The logarithm of the normalizer.</returns>
/// <returns>The logarithm of the normalizer and condensation of automaton.</returns>
/// <remarks>The automaton is normalized only if the normalizer has a finite non-zero value.</remarks>
private double DoGetLogNormalizer(bool normalize)
{
@ -2161,7 +2384,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
TThis noEpsilonTransitions = Zero();
noEpsilonTransitions.SetToEpsilonClosureOf((TThis)this); // To get rid of infinite weight closures
Condensation condensation = noEpsilonTransitions.ComputeCondensation(noEpsilonTransitions.Start, tr => true, false);
var condensation = noEpsilonTransitions.ComputeCondensation(noEpsilonTransitions.Start, tr => true, false);
double logNormalizer = condensation.GetWeightToEnd(noEpsilonTransitions.Start.Index).LogValue;
if (normalize)
{
@ -2205,101 +2428,6 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
}
}
/// <summary>
/// Recursively looks for the only sequence which has non-zero value under the automaton.
/// </summary>
/// <param name="currentState">The state currently being traversed.</param>
/// <param name="currentSequencePos">The current position in the sequence.</param>
/// <param name="stateSequencePosCache">A lookup table for memoization: maps the index of every state visited so far
/// to the sequence position at which it was first reached.</param>
/// <param name="isEndNodeReachable">End node reachability table to avoid dead branches, indexed by state index.</param>
/// <param name="point">The candidate sequence. Elements are appended to it as new sequence positions are reached.</param>
/// <param name="pointLength">The length of the candidate sequence
/// or <see langword="null"/> if the length isn't known yet.</param>
/// <returns>
/// <see langword="true"/> if the sequence was found, <see langword="false"/> otherwise.
/// </returns>
/// <remarks>
/// The method calls itself once per followed transition, so the recursion depth grows with the length
/// of the paths being explored.
/// </remarks>
private bool TryComputePointDfs(
State currentState,
int currentSequencePos,
ArrayDictionary<int> stateSequencePosCache,
bool[] isEndNodeReachable,
List<TElement> point,
ref int? pointLength)
{
Debug.Assert(isEndNodeReachable[currentState.Index], "Dead branches must not be visited.");
int cachedStateDepth;
if (stateSequencePosCache.TryGetValue(currentState.Index, out cachedStateDepth))
{
// If we've already been in this state, we must be at the same sequence pos.
// The state is not traversed again: its outgoing transitions were already handled on the first visit.
return currentSequencePos == cachedStateDepth;
}
// Remember the position at which this state was first reached
stateSequencePosCache.Add(currentState.Index, currentSequencePos);
// Can we stop in this state?
if (currentState.CanEnd)
{
// If an accepting state has been seen before, this one must be reached at the same sequence position;
// otherwise sequences of different lengths would be accepted.
if (pointLength.HasValue)
{
if (currentSequencePos != pointLength.Value)
{
return false;
}
}
else
{
// Now we know the length of the sequence
pointLength = currentSequencePos;
}
}
foreach (var transition in currentState.Transitions)
{
State destState = this.States[transition.DestinationStateIndex];
if (!isEndNodeReachable[destState.Index])
{
continue; // Only walk through the accepting part of the automaton
}
if (transition.IsEpsilon)
{
// Move to the next state, keep the sequence position
if (!this.TryComputePointDfs(destState, currentSequencePos, stateSequencePosCache, isEndNodeReachable, point, ref pointLength))
{
return false;
}
}
else if (!transition.ElementDistribution.Value.IsPointMass)
{
// The transition does not pin down a single element, so no unique sequence can be extracted
return false;
}
else
{
TElement element = transition.ElementDistribution.Value.Point;
if (currentSequencePos == point.Count)
{
// It is the first time at this sequence position
point.Add(element);
}
else if (!point[currentSequencePos].Equals(element))
{
// This is not the first time at this sequence position, and the elements are different
return false;
}
// Look at the next sequence position
if (!this.TryComputePointDfs(destState, currentSequencePos + 1, stateSequencePosCache, isEndNodeReachable, point, ref pointLength))
{
return false;
}
}
}
return true;
}
/// <summary>
/// Swaps the current automaton with a given one.
/// </summary>
@ -2317,107 +2445,9 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
this.LogValueOverride = automaton.LogValueOverride;
automaton.LogValueOverride = dummy;
dummy = this.PruneTransitionsWithLogWeightLessThan;
this.PruneTransitionsWithLogWeightLessThan = automaton.PruneTransitionsWithLogWeightLessThan;
automaton.PruneTransitionsWithLogWeightLessThan = dummy;
}
/// <summary>
/// Appends a string representing the automaton from the current state.
/// </summary>
/// <param name="builder">The string builder.</param>
/// <param name="visitedStates">The set of states that has already been visited.</param>
/// <param name="stateIndex">The index of the current state.</param>
/// <param name="appendRegex">Optional method for appending at the element distribution level.</param>
private void AppendString(StringBuilder builder, HashSet<int> visitedStates, int stateIndex, Action<TElementDistribution, StringBuilder> appendRegex = null)
{
if (visitedStates.Contains(stateIndex))
{
builder.Append('➥');
return;
}
visitedStates.Add(stateIndex);
var currentState = this.States[stateIndex];
var transitions = currentState.Transitions.Where(t => !t.Weight.IsZero);
var selfTransitions = transitions.Where(t => t.DestinationStateIndex == stateIndex);
int selfTransitionCount = selfTransitions.Count();
var nonSelfTransitions = transitions.Where(t => t.DestinationStateIndex != stateIndex);
int nonSelfTransitionCount = nonSelfTransitions.Count();
if (currentState.CanEnd && nonSelfTransitionCount > 0)
{
builder.Append('▪');
}
if (selfTransitionCount > 1)
{
builder.Append('[');
}
int transIdx = 0;
foreach (var transition in selfTransitions)
{
if (!transition.IsEpsilon)
{
if (appendRegex != null)
{
appendRegex(transition.ElementDistribution.Value, builder);
}
else
{
builder.Append(transition.ElementDistribution);
}
}
if (++transIdx < selfTransitionCount)
{
builder.Append('|');
}
}
if (selfTransitionCount > 1)
{
builder.Append(']');
}
if (selfTransitionCount > 0)
{
builder.Append('*');
}
if (nonSelfTransitionCount > 1)
{
builder.Append('[');
}
transIdx = 0;
foreach (var transition in nonSelfTransitions)
{
if (!transition.IsEpsilon)
{
if (appendRegex != null)
{
appendRegex(transition.ElementDistribution.Value, builder);
}
else
{
builder.Append(transition.ElementDistribution);
}
}
this.AppendString(builder, visitedStates, transition.DestinationStateIndex, appendRegex);
if (++transIdx < nonSelfTransitionCount)
{
builder.Append('|');
}
}
if (nonSelfTransitionCount > 1)
{
builder.Append(']');
}
dummy = this.PruneStatesWithLogEndWeightLessThan;
this.PruneStatesWithLogEndWeightLessThan = automaton.PruneStatesWithLogEndWeightLessThan;
automaton.PruneStatesWithLogEndWeightLessThan = dummy;
}
/// <summary>
@ -2587,7 +2617,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
propertyMask[1 << idx++] = true; // isEpsilonFree is always known
propertyMask[1 << idx++] = this.Data.IsEpsilonFree;
propertyMask[1 << idx++] = this.LogValueOverride.HasValue;
propertyMask[1 << idx++] = this.PruneTransitionsWithLogWeightLessThan.HasValue;
propertyMask[1 << idx++] = this.PruneStatesWithLogEndWeightLessThan.HasValue;
propertyMask[1 << idx++] = true; // start state is always serialized
writeInt32(propertyMask.Data);
@ -2597,9 +2627,9 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
writeDouble(this.LogValueOverride.Value);
}
if (this.PruneTransitionsWithLogWeightLessThan.HasValue)
if (this.PruneStatesWithLogEndWeightLessThan.HasValue)
{
writeDouble(this.PruneTransitionsWithLogWeightLessThan.Value);
writeDouble(this.PruneStatesWithLogEndWeightLessThan.Value);
}
// This state is serialized only for its index.
@ -2624,11 +2654,11 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
var propertyMask = new BitVector32(readInt32());
var res = new TThis();
var idx = 0;
// we do not trust serialized "isEpsilonFree". Will take it from builder anyway
// Serialized "isEpsilonFree" is not used. Will be taken from builder anyway
var hasEpsilonFreeIgnored = propertyMask[1 << idx++];
var isEpsilonFreeIgnored = propertyMask[1 << idx++];
var hasLogValueOverride = propertyMask[1 << idx++];
var hasPruneTransitions = propertyMask[1 << idx++];
var hasPruneWeights = propertyMask[1 << idx++];
var hasStartState = propertyMask[1 << idx++];
if (hasLogValueOverride)
@ -2636,9 +2666,9 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
res.LogValueOverride = readDouble();
}
if (hasPruneTransitions)
if (hasPruneWeights)
{
res.PruneTransitionsWithLogWeightLessThan = readDouble();
res.PruneStatesWithLogEndWeightLessThan = readDouble();
}
var builder = new Builder(0);

Просмотреть файл

@ -318,47 +318,56 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// </summary>
/// <param name="srcAutomaton">The automaton to project.</param>
/// <returns>The projection.</returns>
/// <remarks>
/// The code of this method has a lot in common with the code of Automaton.SetToProduct.
/// Unfortunately, it's not clear how to avoid the duplication in the current design.
/// </remarks>
public TDestAutomaton ProjectSource(TSrcAutomaton srcAutomaton)
{
Argument.CheckIfNotNull(srcAutomaton, "srcAutomaton");
var result = new Automaton<TDestSequence, TDestElement, TDestElementDistribution, TDestSequenceManipulator, TDestAutomaton>.Builder();
if (srcAutomaton.IsCanonicZero() || this.sequencePairToWeight.IsCanonicZero())
var mappingAutomaton = this.sequencePairToWeight;
if (srcAutomaton.IsCanonicZero() || mappingAutomaton.IsCanonicZero())
{
return result.GetAutomaton();
return Automaton<TDestSequence, TDestElement, TDestElementDistribution, TDestSequenceManipulator, TDestAutomaton>.Zero();
}
// The projected automaton must be epsilon-free
srcAutomaton.MakeEpsilonFree();
var result = new Automaton<TDestSequence, TDestElement, TDestElementDistribution, TDestSequenceManipulator, TDestAutomaton>.Builder();
var destStateCache = new Dictionary<(int, int), int>();
result.StartStateIndex = BuildProjectionOfAutomaton(this.sequencePairToWeight.Start, srcAutomaton.Start);
var stack = new Stack<(int state1, int state2, int destStateIndex)>();
var simplification = new Automaton<TDestSequence, TDestElement, TDestElementDistribution, TDestSequenceManipulator, TDestAutomaton>.Simplification(result, null);
simplification.RemoveDeadStates();
simplification.SimplifyIfNeeded();
return result.GetAutomaton();
// Recursively builds the projection of a given automaton onto this transducer.
// The projected automaton must be epsilon-free.
int BuildProjectionOfAutomaton(
// Creates the destination state and schedules the projection computation for it.
// If the computation is already scheduled or done, the state index is simply taken from the cache.
int CreateDestState(
PairListAutomaton.State mappingState,
Automaton<TSrcSequence, TSrcElement, TSrcElementDistribution, TSrcSequenceManipulator, TSrcAutomaton>.State srcState)
{
//// The code of this method has a lot in common with the code of Automaton<>.BuildProduct.
//// Unfortunately, it's not clear how to avoid the duplication in the current design.
// State already exists, return its index
var statePair = (mappingState.Index, srcState.Index);
if (destStateCache.TryGetValue(statePair, out var destStateIndex))
var destPair = (mappingState.Index, srcState.Index);
if (!destStateCache.TryGetValue(destPair, out var destStateIndex))
{
return destStateIndex;
var destState = result.AddState();
destState.SetEndWeight(Weight.Product(mappingState.EndWeight, srcState.EndWeight));
stack.Push((mappingState.Index, srcState.Index, destState.Index));
destStateCache[destPair] = destState.Index;
destStateIndex = destState.Index;
}
var destState = result.AddState();
destStateCache.Add(statePair, destState.Index);
return destStateIndex;
}
// Populate the stack with start destination state
result.StartStateIndex = CreateDestState(mappingAutomaton.Start, srcAutomaton.Start);
while (stack.Count > 0)
{
var (mappingStateIndex, srcStateIndex, destStateIndex) = stack.Pop();
var mappingState = mappingAutomaton.States[mappingStateIndex];
var srcState = srcAutomaton.States[srcStateIndex];
var destState = result[destStateIndex];
// Iterate over transitions from mappingState
foreach (var mappingTransition in mappingState.Transitions)
@ -372,7 +381,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
mappingTransition.ElementDistribution.HasValue
? mappingTransition.ElementDistribution.Value.Second
: Option.None;
var childDestStateIndex = BuildProjectionOfAutomaton(childMappingState, srcState);
var childDestStateIndex = CreateDestState(childMappingState, srcState);
destState.AddTransition(destElementDistribution, mappingTransition.Weight, childDestStateIndex, mappingTransition.Group);
continue;
}
@ -392,14 +401,17 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
}
var destWeight = Weight.Product(mappingTransition.Weight, srcTransition.Weight, Weight.FromLogValue(projectionLogScale));
var childDestStateIndex = BuildProjectionOfAutomaton(childMappingState, srcChildState);
var childDestStateIndex = CreateDestState(childMappingState, srcChildState);
destState.AddTransition(destElementDistribution, destWeight, childDestStateIndex, mappingTransition.Group);
}
}
destState.SetEndWeight(Weight.Product(mappingState.EndWeight, srcState.EndWeight));
return destState.Index;
}
var simplification = new Automaton<TDestSequence, TDestElement, TDestElementDistribution, TDestSequenceManipulator, TDestAutomaton>.Simplification(result, null);
simplification.RemoveDeadStates();
simplification.SimplifyIfNeeded();
return result.GetAutomaton();
}
/// <summary>
@ -411,45 +423,57 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// Using this method is more efficient than applying <see cref="ProjectSource(TSrcAutomaton)"/>
/// to the automaton representation of a projected sequence.
/// </remarks>
/// <remarks>
/// The code of this method has a lot in common with the code of Automaton.SetToProduct.
/// Unfortunately, it's not clear how to avoid the duplication in the current design.
/// </remarks>
public TDestAutomaton ProjectSource(TSrcSequence srcSequence)
{
Argument.CheckIfNotNull(srcSequence, "srcSequence");
var result = new Automaton<TDestSequence, TDestElement, TDestElementDistribution, TDestSequenceManipulator, TDestAutomaton>.Builder();
if (this.sequencePairToWeight.IsCanonicZero())
var mappingAutomaton = this.sequencePairToWeight;
if (mappingAutomaton.IsCanonicZero())
{
return result.GetAutomaton();
return Automaton<TDestSequence, TDestElement, TDestElementDistribution, TDestSequenceManipulator, TDestAutomaton>.Zero();
}
var sourceSequenceManipulator =
Automaton<TSrcSequence, TSrcElement, TSrcElementDistribution, TSrcSequenceManipulator, TSrcAutomaton>.SequenceManipulator;
var srcSequenceLength = sourceSequenceManipulator.GetLength(srcSequence);
var result = new Automaton<TDestSequence, TDestElement, TDestElementDistribution, TDestSequenceManipulator, TDestAutomaton>.Builder();
var destStateCache = new Dictionary<(int, int), int>();
result.StartStateIndex = BuildProjectionOfSequence(this.sequencePairToWeight.Start, 0);
var stack = new Stack<(int state1, int state2, int destStateIndex)>();
var simplification = new Automaton<TDestSequence, TDestElement, TDestElementDistribution, TDestSequenceManipulator, TDestAutomaton>.Simplification(result, null);
simplification.RemoveDeadStates();
simplification.SimplifyIfNeeded();
return result.GetAutomaton();
// Recursively builds the projection of a given sequence onto this transducer.
int BuildProjectionOfSequence(PairListAutomaton.State mappingState, int srcSequenceIndex)
// Creates the destination state and schedules the projection computation for it.
// If the computation is already scheduled or done, the state index is simply taken from the cache.
int CreateDestState(PairListAutomaton.State mappingState, int srcSequenceIndex)
{
//// The code of this method has a lot in common with the code of Automaton<>.BuildProduct.
//// Unfortunately, it's not clear how to avoid the duplication in the current design.
var sourceSequenceManipulator =
Automaton<TSrcSequence, TSrcElement, TSrcElementDistribution, TSrcSequenceManipulator, TSrcAutomaton>.SequenceManipulator;
var statePair = (mappingState.Index, srcSequenceIndex);
if (destStateCache.TryGetValue(statePair, out var destStateIndex))
var destPair = (mappingState.Index, srcSequenceIndex);
if (!destStateCache.TryGetValue(destPair, out var destStateIndex))
{
return destStateIndex;
var destState = result.AddState();
destState.SetEndWeight(
srcSequenceIndex == srcSequenceLength
? mappingState.EndWeight
: Weight.Zero);
stack.Push((mappingState.Index, srcSequenceIndex, destState.Index));
destStateCache[destPair] = destState.Index;
destStateIndex = destState.Index;
}
var destState = result.AddState();
destStateCache.Add(statePair, destState.Index);
return destStateIndex;
}
var srcSequenceLength = sourceSequenceManipulator.GetLength(srcSequence);
// Populate the stack with start destination state
result.StartStateIndex = CreateDestState(mappingAutomaton.Start, 0);
while (stack.Count > 0)
{
var (mappingStateIndex, srcSequenceIndex, destStateIndex) = stack.Pop();
var mappingState = mappingAutomaton.States[mappingStateIndex];
var destState = result[destStateIndex];
// Enumerate transitions from the current mapping state
foreach (var mappingTransition in mappingState.Transitions)
@ -463,7 +487,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
mappingTransition.ElementDistribution.HasValue
? mappingTransition.ElementDistribution.Value.Second
: Option.None;
var childDestStateIndex = BuildProjectionOfSequence(destMappingState, srcSequenceIndex);
var childDestStateIndex = CreateDestState(destMappingState, srcSequenceIndex);
destState.AddTransition(destElementWeights, mappingTransition.Weight, childDestStateIndex, mappingTransition.Group);
continue;
}
@ -481,14 +505,17 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
}
var weight = Weight.Product(mappingTransition.Weight, Weight.FromLogValue(projectionLogScale));
var childDestState = BuildProjectionOfSequence(destMappingState,srcSequenceIndex + 1);
var childDestState = CreateDestState(destMappingState,srcSequenceIndex + 1);
destState.AddTransition(destElementDistribution, weight, childDestState, mappingTransition.Group);
}
}
destState.SetEndWeight(srcSequenceIndex == srcSequenceLength ? mappingState.EndWeight : Weight.Zero);
return destState.Index;
}
var simplification = new Automaton<TDestSequence, TDestElement, TDestElementDistribution, TDestSequenceManipulator, TDestAutomaton>.Simplification(result, null);
simplification.RemoveDeadStates();
simplification.SimplifyIfNeeded();
return result.GetAutomaton();
}
/// <summary>

Просмотреть файл

@ -88,6 +88,11 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
get { return double.IsNegativeInfinity(this.LogValue); }
}
/// <summary>
/// Gets a value indicating whether the weight is infinite.
/// </summary>
public bool IsInfinity
{
    get { return double.IsInfinity(this.Value); }
}
/// <summary>
/// Creates a weight from the logarithm of its value.
/// </summary>

Просмотреть файл

@ -1792,15 +1792,6 @@ namespace Microsoft.ML.Probabilistic.Distributions
this.isNormalized = true;
}
////// todo: consider moving the following logic into the Automaton class
////if (product.PruneTransitionsWithLogWeightLessThan.HasValue)
////{
//// if (this.isNormalized || product.TryNormalizeValues())
//// {
//// product.RemoveTransitionsWithSmallWeights(product.PruneTransitionsWithLogWeightLessThan.Value);
//// }
////}
return logNormalizer;
}

Просмотреть файл

@ -2,8 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using Microsoft.ML.Probabilistic.Collections;
namespace Microsoft.ML.Probabilistic.Tests
{
using System;
@ -15,10 +13,10 @@ namespace Microsoft.ML.Probabilistic.Tests
using Xunit;
using Assert = Microsoft.ML.Probabilistic.Tests.AssertHelper;
using Microsoft.ML.Probabilistic.Collections;
using Microsoft.ML.Probabilistic.Distributions;
using Microsoft.ML.Probabilistic.Distributions.Automata;
using Microsoft.ML.Probabilistic.Math;
using Microsoft.ML.Probabilistic.Utilities;
/// <summary>
/// Tests for weighted finite state automata.
@ -64,10 +62,12 @@ namespace Microsoft.ML.Probabilistic.Tests
StringInferenceTestUtilities.TestValue(zero3, 0.0, "abc", "ab", "a", string.Empty);
StringAutomaton zero4 =
StringAutomaton.Constant(2.0, DiscreteChar.Lower())
.Product(
StringAutomaton.Constant(3.0, DiscreteChar.Upper())
.Append(StringAutomaton.ConstantOnElement(1.5, DiscreteChar.Digit())));
StringAutomaton
.Constant(2.0, DiscreteChar.Lower())
.Product(
StringAutomaton
.Constant(3.0, DiscreteChar.Upper())
.Append(StringAutomaton.ConstantOnElement(1.5, DiscreteChar.Digit())));
Assert.True(zero4.IsZero());
Assert.True(zero4.IsCanonicZero());
StringInferenceTestUtilities.TestValue(zero4, 0.0, "abc", "ab", "a", string.Empty);
@ -760,34 +760,91 @@ namespace Microsoft.ML.Probabilistic.Tests
Assert.Equal(expectedLogValue, logValue, 1e-8);
}
/// <summary>
/// Tests that the point mass computation does not fail due to a stack overflow
/// when an automaton becomes sufficiently large (a chain of 100,000 states).
/// </summary>
[Fact]
[Trait("Category", "StringInference")]
public void TryComputePointLargeAutomaton()
{
    // The state-count limit is lifted for the duration of the test
    using (new StringAutomaton.UnlimitedStatesComputation())
    {
        const int StateCount = 100_000;

        // Build a single chain of 'a'-transitions; only its last state is accepting
        var chainBuilder = new StringAutomaton.Builder();
        var tail = chainBuilder.Start;
        for (var remaining = StateCount - 1; remaining > 0; --remaining)
        {
            tail = tail.AddTransition('a', Weight.One);
        }

        tail.SetEndWeight(Weight.One);

        var automaton = chainBuilder.GetAutomaton();
        var expectedPoint = new string('a', StateCount - 1);

        Assert.True(automaton.TryComputePoint() == expectedPoint);
        StringInferenceTestUtilities.TestValue(automaton, 1.0, expectedPoint);
    }
}
/// <summary>
/// Tests whether computing the product of two automata fails due to a stack overflow when the automata become sufficiently large.
/// </summary>
[Fact]
[Trait("Category", "StringInference")]
[Trait("Category", "OpenBug")]
public void TryComputePointLargeAutomaton()
public void SetToProductLargeAutomaton()
{
//// Fails with ~2500 states due to stack overflow
//// Fails on MacOS 64-bit with 750 states due to stack overflow
int stateCount = Environment.Is64BitProcess ? 600 : 1500; // Stack frames are larger on 64bit
Debug.Assert(stateCount <= StringAutomaton.MaxStateCount, "MaxStateCount must be adjusted first.");
var builder = new StringAutomaton.Builder();
var state = builder.Start;
for (int i = 1; i < stateCount; ++i)
using (var unlimited = new StringAutomaton.UnlimitedStatesComputation())
{
state = state.AddTransition('a', Weight.One);
const int StateCount = 100_000;
var builder = new StringAutomaton.Builder();
var state = builder.Start;
for (var i = 1; i < StateCount; ++i)
{
state = state.AddTransition('a', Weight.One);
}
state.SetEndWeight(Weight.One);
var automaton1 = builder.GetAutomaton();
var automaton2 = builder.GetAutomaton();
var point = new string('a', StateCount - 1);
var productAutomaton = StringAutomaton.Product(automaton1, automaton2);
StringInferenceTestUtilities.TestValue(productAutomaton, 1.0, point);
}
}
state.SetEndWeight(Weight.One);
/// <summary>
/// Tests whether the log-normalizer computation fails due to a stack overflow when an automaton becomes sufficiently large.
/// </summary>
[Fact]
[Trait("Category", "StringInference")]
public void GetLogNormalizerLargeAutomaton()
{
using (var unlimited = new StringAutomaton.UnlimitedStatesComputation())
{
const int StateCount = 100_000;
var automaton = builder.GetAutomaton();
string point = new string('a', stateCount - 1);
Assert.True(automaton.TryComputePoint() == point);
StringInferenceTestUtilities.TestValue(automaton, 1.0, point);
var builder = new StringAutomaton.Builder();
var state = builder.Start;
for (var i = 1; i < StateCount; ++i)
{
state = state.AddTransition('a', Weight.One);
}
state.SetEndWeight(Weight.One);
var logNormalizer = builder.GetAutomaton().GetLogNormalizer();
Assert.Equal(0.0, logNormalizer);
}
}
/// <summary>
@ -1282,7 +1339,7 @@ namespace Microsoft.ML.Probabilistic.Tests
var automaton = builder.GetAutomaton();
for (int i = 0; i < 3; ++i)
{
{
StringInferenceTestUtilities.TestValue(automaton, 2.0, "ab");
StringInferenceTestUtilities.TestValue(automaton, 3.0, "adc", "adddc", "ac");
StringInferenceTestUtilities.TestValue(automaton, 0.0, "adb", "ad", string.Empty);

Просмотреть файл

@ -536,5 +536,38 @@ namespace Microsoft.ML.Probabilistic.Tests
Assert.Equal(referenceValue2, transpose1.GetValue(valuePair[0], valuePair[1]));
}
}
/// <summary>
/// Tests that projecting a sufficiently large automaton (a chain of 100,000 states), and the equally long
/// sequence it accepts, through the transducer returned by <c>StringTransducer.Copy()</c>
/// does not fail due to a stack overflow.
/// </summary>
[Fact]
[Trait("Category", "StringInference")]
public void ProjectSourceLargeAutomaton()
{
    // The state-count limit is lifted for the duration of the test
    using (new StringAutomaton.UnlimitedStatesComputation())
    {
        const int StateCount = 100_000;

        // Build a single chain of 'a'-transitions; only its last state is accepting
        var chainBuilder = new StringAutomaton.Builder();
        var tail = chainBuilder.Start;
        for (var remaining = StateCount - 1; remaining > 0; --remaining)
        {
            tail = tail.AddTransition('a', Weight.One);
        }

        tail.SetEndWeight(Weight.One);

        var sourceAutomaton = chainBuilder.GetAutomaton();
        var sourcePoint = new string('a', StateCount - 1);
        var transducer = StringTransducer.Copy();

        // Both projections are computed before either of them is checked
        var automatonProjection = transducer.ProjectSource(sourceAutomaton);
        var pointProjection = transducer.ProjectSource(sourcePoint);

        StringInferenceTestUtilities.TestValue(automatonProjection, 1.0, sourcePoint);
        StringInferenceTestUtilities.TestValue(pointProjection, 1.0, sourcePoint);
    }
}
}
}

Просмотреть файл

@ -910,13 +910,13 @@ namespace Microsoft.ML.Probabilistic.Tests
/// <param name="eps">Precision.</param>
public static void Equal(double expected, double observed, double eps)
{
// Infinty check
// Infinity check
if (expected == observed)
{
return;
}
Assert.True(Math.Abs(expected - observed) < eps, $"Equality failure\n. Expected: {expected}\nActual: {observed}");
Assert.True(Math.Abs(expected - observed) < eps, $"Equality failure.\nExpected: {expected}\nActual: {observed}");
}
/// <summary>