зеркало из https://github.com/dotnet/infer.git
Get rid of recursive implementations (#125)
We finally reached a point where automatons are big enough that recursive implementations of algorithms fail with `StackOverflowException`. There are 4 tests which test operations with big automatons (100k states): * `TryComputePointLargeAutomaton` * `SetToProductLargeAutomaton` * `GetLogNormalizerLargeAutomaton` * `ProjectSourceLargeAutomaton` All four used to fail before code was rewritten to use explicit stack instead of recursive calls. All changes except one didn't change algorithms used, code was almost mechanically changed to use stack. The only exception - automaton simplification. Old code used recursion in very non-trivial way. It was rewritten from scratch, using different algorithm: instead of extraction of generalized sequences and then reinserting them back new code merges states directly in automaton. (There's a comment at the beginning of Simplify() method explaining all operations)
This commit is contained in:
Родитель
702874f9aa
Коммит
6d2ce9a993
|
@ -214,15 +214,14 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
}
|
||||
}
|
||||
|
||||
this.StartStateIndex = oldToNewStateIdMapping[this.StartStateIndex];
|
||||
if (this.StartStateIndex == -1)
|
||||
if (deadStateCount == 0)
|
||||
{
|
||||
// Cannot reach any end state from the start state => the automaton is zero everywhere
|
||||
this.Clear();
|
||||
this.AddState();
|
||||
return deadStateCount;
|
||||
return 0;
|
||||
}
|
||||
|
||||
// may invalidate automaton
|
||||
this.StartStateIndex = oldToNewStateIdMapping[this.StartStateIndex];
|
||||
|
||||
for (var i = 0; i < this.states.Count; ++i)
|
||||
{
|
||||
var newId = oldToNewStateIdMapping[i];
|
||||
|
|
|
@ -2,6 +2,8 @@
|
|||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
using Microsoft.ML.Probabilistic.Collections;
|
||||
|
||||
namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
||||
{
|
||||
using System;
|
||||
|
@ -108,12 +110,9 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
this.useApproximateClosure = useApproximateClosure;
|
||||
|
||||
this.components = new List<StronglyConnectedComponent>();
|
||||
var stateIdStack = new Stack<State>();
|
||||
var stateIdToTarjanInfo = new Dictionary<int, TarjanStateInfo>();
|
||||
int traversalIndex = 0;
|
||||
this.FindStronglyConnectedComponents(this.Root, ref traversalIndex, stateIdToTarjanInfo, stateIdStack);
|
||||
this.FindStronglyConnectedComponents();
|
||||
|
||||
this.stateIdToInfo = new Dictionary<int, CondensationStateInfo>(stateIdToTarjanInfo.Count);
|
||||
this.stateIdToInfo = new Dictionary<int, CondensationStateInfo>();
|
||||
for (int i = 0; i < this.components.Count; ++i)
|
||||
{
|
||||
StronglyConnectedComponent component = this.components[i];
|
||||
|
@ -194,57 +193,99 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
/// Implements <a href="http://en.wikipedia.org/wiki/Tarjan's_strongly_connected_components_algorithm">Tarjan's algorithm</a>
|
||||
/// for finding the strongly connected components of the automaton graph.
|
||||
/// </summary>
|
||||
/// <param name="currentState">The state currently being traversed.</param>
|
||||
/// <param name="traversalIndex">The traversal index (as defined by the Tarjan's algorithm).</param>
|
||||
/// <param name="stateIdToStateInfo">A dictionary mapping state indices to info records maintained by the Tarjan's algorithm.</param>
|
||||
/// <param name="stateIdStack">The traversal stack (as defined by the Tarjan's algorithm).</param>
|
||||
private void FindStronglyConnectedComponents(
|
||||
State currentState,
|
||||
ref int traversalIndex,
|
||||
Dictionary<int, TarjanStateInfo> stateIdToStateInfo,
|
||||
Stack<State> stateIdStack)
|
||||
/// <remarks>
|
||||
/// This implementation closely follows algorithm from wikipedia. But is rewritten to use explicit
|
||||
/// traversal stack instead of recursion.
|
||||
///
|
||||
/// Basic idea is each time we need to do a recursive call we save all needed state into traversal
|
||||
/// stack, and push this "continuation" onto the stack. All state that need to be preserved are:
|
||||
/// - currentStateIndex - state being processed
|
||||
/// - currentTransitionIndex - index of the transition which was recursed last.
|
||||
/// If it is -1, that means that state was not processed yet at all.
|
||||
/// </remarks>
|
||||
private void FindStronglyConnectedComponents()
|
||||
{
|
||||
Debug.Assert(!stateIdToStateInfo.ContainsKey(currentState.Index), "Visited states must not be revisited.");
|
||||
var states = this.Root.Owner.States;
|
||||
var traversalIndex = 0;
|
||||
var stateIdStack = new Stack<int>(); // Stack that maintains Tarjan algorithm invariant
|
||||
var info = new TarjanStateInfo[this.Root.Owner.States.Count];
|
||||
var traversalStack = new Stack<(int currentStateIndex, int lastDestination)>();
|
||||
|
||||
var stateInfo = new TarjanStateInfo(traversalIndex);
|
||||
stateIdToStateInfo.Add(currentState.Index, stateInfo);
|
||||
++traversalIndex;
|
||||
traversalStack.Push((this.Root.Index, -1));
|
||||
|
||||
stateIdStack.Push(currentState);
|
||||
stateInfo.InStack = true;
|
||||
|
||||
foreach (var transition in currentState.Transitions)
|
||||
while (traversalStack.Count > 0)
|
||||
{
|
||||
if (!this.transitionFilter(transition))
|
||||
var (current, currentTransitionIndex) = traversalStack.Pop();
|
||||
var currentState = states[current];
|
||||
|
||||
if (currentTransitionIndex < 0)
|
||||
{
|
||||
// Just entered
|
||||
Debug.Assert(!info[current].Visited);
|
||||
info[current].Visited = true;
|
||||
info[current].TraversalIndex = traversalIndex;
|
||||
info[current].Lowlink = traversalIndex;
|
||||
info[current].InStack = true;
|
||||
stateIdStack.Push(current);
|
||||
++traversalIndex;
|
||||
currentTransitionIndex = 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
// we already traversed some state, and its index is in enumerator
|
||||
var lastTraversed = currentState.Transitions[currentTransitionIndex].DestinationStateIndex;
|
||||
info[current].Lowlink = Math.Min(info[current].Lowlink, info[lastTraversed].Lowlink);
|
||||
++currentTransitionIndex;
|
||||
}
|
||||
|
||||
// Continue processing
|
||||
for (; currentTransitionIndex < currentState.Transitions.Count; ++currentTransitionIndex)
|
||||
{
|
||||
var transition = currentState.Transitions[currentTransitionIndex];
|
||||
if (transitionFilter(transition))
|
||||
{
|
||||
var destination = transition.DestinationStateIndex;
|
||||
if (!info[destination].Visited)
|
||||
{
|
||||
// return to this state after destination is traversed
|
||||
traversalStack.Push((current, currentTransitionIndex));
|
||||
// traverse destination
|
||||
traversalStack.Push((destination, -1));
|
||||
// Processing of this state will effectively be resumed after destination is processed
|
||||
break;
|
||||
}
|
||||
|
||||
if (info[destination].InStack)
|
||||
{
|
||||
info[current].Lowlink =
|
||||
Math.Min(info[current].Lowlink, info[destination].TraversalIndex);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// We can break from for-loop above before end condition is met only if we pushed some
|
||||
// work to do onto traversal stack. One of the things pushed to stack will effectively
|
||||
// resume the loop from the currentTransitionIndex.
|
||||
if (currentTransitionIndex < currentState.Transitions.Count)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!stateIdToStateInfo.TryGetValue(transition.DestinationStateIndex, out TarjanStateInfo destinationStateInfo))
|
||||
if (info[current].Lowlink == info[current].TraversalIndex)
|
||||
{
|
||||
this.FindStronglyConnectedComponents(
|
||||
this.Root.Owner.States[transition.DestinationStateIndex], ref traversalIndex, stateIdToStateInfo, stateIdStack);
|
||||
stateInfo.Lowlink = Math.Min(stateInfo.Lowlink, stateIdToStateInfo[transition.DestinationStateIndex].Lowlink);
|
||||
}
|
||||
else if (destinationStateInfo.InStack)
|
||||
{
|
||||
stateInfo.Lowlink = Math.Min(stateInfo.Lowlink, destinationStateInfo.TraversalIndex);
|
||||
}
|
||||
}
|
||||
var statesInComponent = new List<State>();
|
||||
State state;
|
||||
do
|
||||
{
|
||||
state = states[stateIdStack.Pop()];
|
||||
info[state.Index].InStack = false;
|
||||
statesInComponent.Add(state);
|
||||
} while (state.Index != current);
|
||||
|
||||
if (stateInfo.Lowlink == stateInfo.TraversalIndex)
|
||||
{
|
||||
var statesInComponent = new List<State>();
|
||||
State state;
|
||||
do
|
||||
{
|
||||
state = stateIdStack.Pop();
|
||||
stateIdToStateInfo[state.Index].InStack = false;
|
||||
statesInComponent.Add(state);
|
||||
this.components.Add(
|
||||
new StronglyConnectedComponent(
|
||||
this.transitionFilter, statesInComponent, this.useApproximateClosure));
|
||||
}
|
||||
while (state.Index != currentState.Index);
|
||||
|
||||
this.components.Add(new StronglyConnectedComponent(this.transitionFilter, statesInComponent, this.useApproximateClosure));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -407,17 +448,12 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
/// <summary>
|
||||
/// Stores the information maintained by the Tarjan's algorithm.
|
||||
/// </summary>
|
||||
private class TarjanStateInfo
|
||||
private struct TarjanStateInfo
|
||||
{
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="TarjanStateInfo"/> class.
|
||||
/// Gets or sets value indicating whether this node has already been visited
|
||||
/// </summary>
|
||||
/// <param name="traversalIndex">The current traversal index.</param>
|
||||
public TarjanStateInfo(int traversalIndex)
|
||||
{
|
||||
this.TraversalIndex = traversalIndex;
|
||||
this.Lowlink = traversalIndex;
|
||||
}
|
||||
public bool Visited { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets a value indicating whether the state is currently in stack.
|
||||
|
@ -427,7 +463,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
/// <summary>
|
||||
/// Gets the traversal index of the state.
|
||||
/// </summary>
|
||||
public int TraversalIndex { get; private set; }
|
||||
public int TraversalIndex { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the lowlink of the state.
|
||||
|
|
|
@ -123,11 +123,11 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
}
|
||||
}
|
||||
|
||||
var simplification = new Simplification(builder, this.PruneTransitionsWithLogWeightLessThan);
|
||||
var simplification = new Simplification(builder, this.PruneStatesWithLogEndWeightLessThan);
|
||||
simplification.MergeParallelTransitions(); // Determinization produces a separate transition for each segment
|
||||
|
||||
var result = builder.GetAutomaton();
|
||||
result.PruneTransitionsWithLogWeightLessThan = this.PruneTransitionsWithLogWeightLessThan;
|
||||
result.PruneStatesWithLogEndWeightLessThan = this.PruneStatesWithLogEndWeightLessThan;
|
||||
result.LogValueOverride = this.LogValueOverride;
|
||||
this.SwapWith(result);
|
||||
|
||||
|
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -2,11 +2,10 @@
|
|||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
using System.Collections;
|
||||
|
||||
namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
||||
{
|
||||
using System;
|
||||
using System.Collections;
|
||||
using System.Collections.Generic;
|
||||
using System.Collections.Specialized;
|
||||
using System.Diagnostics;
|
||||
|
@ -143,7 +142,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
/// <remarks>
|
||||
/// TODO: We need to develop more elegant automaton approximation methods, this is a simple placeholder for those.
|
||||
/// </remarks>
|
||||
public double? PruneTransitionsWithLogWeightLessThan { get; set; }
|
||||
public double? PruneStatesWithLogEndWeightLessThan { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the maximum number of states an automaton can have.
|
||||
|
@ -763,12 +762,107 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
{
|
||||
return this.ToString();
|
||||
}
|
||||
else
|
||||
|
||||
var builder = new StringBuilder();
|
||||
var visitedStates = new HashSet<int>();
|
||||
var stack = new Stack<(string prefix, Option<TElementDistribution> prefixDistribution, int state)>();
|
||||
stack.Push((string.Empty, Option.None, Start.Index));
|
||||
|
||||
while (stack.Count > 0)
|
||||
{
|
||||
StringBuilder builder = new StringBuilder();
|
||||
this.AppendString(builder, new HashSet<int>(), this.Start.Index, appendElement);
|
||||
return builder.ToString();
|
||||
var (prefix, prefixDistribution, stateIndex) = stack.Pop();
|
||||
|
||||
builder.Append(prefix);
|
||||
if (prefixDistribution.HasValue)
|
||||
{
|
||||
if (appendElement != null)
|
||||
{
|
||||
appendElement(prefixDistribution.Value, builder);
|
||||
}
|
||||
else
|
||||
{
|
||||
builder.Append(prefixDistribution);
|
||||
}
|
||||
}
|
||||
|
||||
if (stateIndex == -1)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if (visitedStates.Contains(stateIndex))
|
||||
{
|
||||
builder.Append('➥');
|
||||
continue;
|
||||
}
|
||||
|
||||
visitedStates.Add(stateIndex);
|
||||
|
||||
var state = this.States[stateIndex];
|
||||
var transitions = state.Transitions.Where(t => !t.Weight.IsZero);
|
||||
var selfTransitions = transitions.Where(t => t.DestinationStateIndex == stateIndex);
|
||||
var selfTransitionCount = selfTransitions.Count();
|
||||
var nonSelfTransitions = transitions.Where(t => t.DestinationStateIndex != stateIndex).ToArray();
|
||||
var nonSelfTransitionCount = nonSelfTransitions.Count();
|
||||
|
||||
if (state.CanEnd && nonSelfTransitionCount > 0)
|
||||
{
|
||||
builder.Append('▪');
|
||||
}
|
||||
|
||||
if (selfTransitionCount > 1)
|
||||
{
|
||||
builder.Append('[');
|
||||
}
|
||||
|
||||
var transIdx = 0;
|
||||
foreach (var transition in selfTransitions)
|
||||
{
|
||||
if (!transition.IsEpsilon)
|
||||
{
|
||||
if (appendElement != null)
|
||||
{
|
||||
appendElement(transition.ElementDistribution.Value, builder);
|
||||
}
|
||||
else
|
||||
{
|
||||
builder.Append(transition.ElementDistribution);
|
||||
}
|
||||
}
|
||||
|
||||
if (++transIdx < selfTransitionCount)
|
||||
{
|
||||
builder.Append('|');
|
||||
}
|
||||
}
|
||||
|
||||
if (selfTransitionCount > 1)
|
||||
{
|
||||
builder.Append(']');
|
||||
}
|
||||
|
||||
if (selfTransitionCount > 0)
|
||||
{
|
||||
builder.Append('*');
|
||||
}
|
||||
|
||||
if (nonSelfTransitionCount > 1)
|
||||
{
|
||||
builder.Append('[');
|
||||
stack.Push(("]", Option.None, -1));
|
||||
}
|
||||
|
||||
// Iterate transitions in reverse order, because after destinations are pushed to stack,
|
||||
// order will be reversed again.
|
||||
for (transIdx = nonSelfTransitions.Length - 1; transIdx >= 0; --transIdx)
|
||||
{
|
||||
var transition = nonSelfTransitions[transIdx];
|
||||
var transitionPrefix = transIdx == 0 ? string.Empty : "|";
|
||||
stack.Push((transitionPrefix, transition.ElementDistribution, transition.DestinationStateIndex));
|
||||
}
|
||||
}
|
||||
|
||||
return builder.ToString();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -853,10 +947,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
/// </summary>
|
||||
/// <returns>The logarithm of the normalizer.</returns>
|
||||
/// <remarks>Returns <see cref="double.PositiveInfinity"/> if the sum diverges.</remarks>
|
||||
public double GetLogNormalizer()
|
||||
{
|
||||
return this.DoGetLogNormalizer(false);
|
||||
}
|
||||
public double GetLogNormalizer() => this.DoGetLogNormalizer(false);
|
||||
|
||||
/// <summary>
|
||||
/// Normalizes the automaton so that the sum of its values over all possible sequences equals to one
|
||||
|
@ -1036,7 +1127,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
var state = result[stateId];
|
||||
if (state.CanEnd)
|
||||
{
|
||||
// Make all accepting states contibute the desired value to the result
|
||||
// Make all accepting states contribute the desired value to the result
|
||||
state.SetEndWeight(value);
|
||||
}
|
||||
|
||||
|
@ -1242,37 +1333,35 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
|
||||
var builder = new Builder(0);
|
||||
var productStateCache = new Dictionary<(int, int), int>(automaton1.States.Count + automaton2.States.Count);
|
||||
builder.StartStateIndex = BuildProduct(automaton1.Start, automaton2.Start);
|
||||
var stack = new Stack<(int state1, int state2, int productStateIndex)>();
|
||||
|
||||
var simplification = new Simplification(builder, this.PruneTransitionsWithLogWeightLessThan);
|
||||
simplification.RemoveDeadStates(); // Product can potentially create dead states
|
||||
simplification.SimplifyIfNeeded();
|
||||
|
||||
this.Data = builder.GetData();
|
||||
if (this is StringAutomaton && tryDeterminize)
|
||||
// Creates product state and schedules product computation for it.
|
||||
// If computation is already scheduled or done the state index is simply taken from cache
|
||||
int CreateProductState(State state1, State state2)
|
||||
{
|
||||
this.TryDeterminize();
|
||||
}
|
||||
|
||||
// Recursively builds an automaton representing the product of two given automata.
|
||||
// Returns start state of product automaton.
|
||||
int BuildProduct(State state1, State state2)
|
||||
{
|
||||
Debug.Assert(state1 != null && state2 != null, "Valid states must be provided.");
|
||||
Debug.Assert(
|
||||
state2.Owner.IsEpsilonFree,
|
||||
"The second argument of the product operation must be epsilon-free.");
|
||||
|
||||
// State already exists, return its index
|
||||
var statePair = (state1.Index, state2.Index);
|
||||
if (productStateCache.TryGetValue(statePair, out var productStateIndex))
|
||||
var destPair = (state1.Index, state2.Index);
|
||||
if (!productStateCache.TryGetValue(destPair, out var productStateIndex))
|
||||
{
|
||||
return productStateIndex;
|
||||
var productState = builder.AddState();
|
||||
productState.SetEndWeight(Weight.Product(state1.EndWeight, state2.EndWeight));
|
||||
stack.Push((state1.Index, state2.Index, productState.Index));
|
||||
productStateCache[destPair] = productState.Index;
|
||||
productStateIndex = productState.Index;
|
||||
}
|
||||
|
||||
// Create a new state
|
||||
var productState = builder.AddState();
|
||||
productStateCache.Add(statePair, productState.Index);
|
||||
return productStateIndex;
|
||||
}
|
||||
|
||||
// Populate the stack with start product state
|
||||
builder.StartStateIndex = CreateProductState(automaton1.Start, automaton2.Start);
|
||||
|
||||
while (stack.Count > 0)
|
||||
{
|
||||
var (state1Index, state2Index, productStateIndex) = stack.Pop();
|
||||
|
||||
var state1 = automaton1.States[state1Index];
|
||||
var state2 = automaton2.States[state2Index];
|
||||
var productState = builder[productStateIndex];
|
||||
|
||||
// Iterate over transitions in state1
|
||||
foreach (var transition1 in state1.Transitions)
|
||||
|
@ -1282,7 +1371,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
if (transition1.IsEpsilon)
|
||||
{
|
||||
// Epsilon transition case
|
||||
var destProductStateIndex = BuildProduct(destState1, state2);
|
||||
var destProductStateIndex = CreateProductState(destState1, state2);
|
||||
productState.AddEpsilonTransition(transition1.Weight, destProductStateIndex, transition1.Group);
|
||||
continue;
|
||||
}
|
||||
|
@ -1295,23 +1384,29 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
"The second argument of the product operation must be epsilon-free.");
|
||||
var destState2 = state2.Owner.States[transition2.DestinationStateIndex];
|
||||
var productLogNormalizer = Distribution<TElement>.GetLogAverageOf(
|
||||
transition1.ElementDistribution.Value, transition2.ElementDistribution.Value, out var product);
|
||||
transition1.ElementDistribution.Value, transition2.ElementDistribution.Value,
|
||||
out var product);
|
||||
if (double.IsNegativeInfinity(productLogNormalizer))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
var productWeight = Weight.Product(
|
||||
transition1.Weight,
|
||||
transition2.Weight,
|
||||
Weight.FromLogValue(productLogNormalizer));
|
||||
var destProductStateIndex = BuildProduct(destState1, destState2);
|
||||
transition1.Weight, transition2.Weight, Weight.FromLogValue(productLogNormalizer));
|
||||
var destProductStateIndex = CreateProductState(destState1, destState2);
|
||||
productState.AddTransition(product, productWeight, destProductStateIndex, transition1.Group);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
productState.SetEndWeight(Weight.Product(state1.EndWeight, state2.EndWeight));
|
||||
return productState.Index;
|
||||
var simplification = new Simplification(builder, this.PruneStatesWithLogEndWeightLessThan);
|
||||
simplification.RemoveDeadStates(); // Product can potentially create dead states
|
||||
simplification.SimplifyIfNeeded();
|
||||
|
||||
this.Data = builder.GetData();
|
||||
if (this is StringAutomaton && tryDeterminize)
|
||||
{
|
||||
this.TryDeterminize();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1323,17 +1418,17 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
/// <returns>The merged pruning threshold.</returns>
|
||||
private double? MergePruningWeights(TThis automaton1, TThis automaton2)
|
||||
{
|
||||
if (automaton1.PruneTransitionsWithLogWeightLessThan == null)
|
||||
if (automaton1.PruneStatesWithLogEndWeightLessThan == null)
|
||||
{
|
||||
return automaton2.PruneTransitionsWithLogWeightLessThan;
|
||||
return automaton2.PruneStatesWithLogEndWeightLessThan;
|
||||
}
|
||||
|
||||
if (automaton2.PruneTransitionsWithLogWeightLessThan == null)
|
||||
if (automaton2.PruneStatesWithLogEndWeightLessThan == null)
|
||||
{
|
||||
return automaton1.PruneTransitionsWithLogWeightLessThan;
|
||||
return automaton1.PruneStatesWithLogEndWeightLessThan;
|
||||
}
|
||||
|
||||
return Math.Min(automaton1.PruneTransitionsWithLogWeightLessThan.Value, automaton2.PruneTransitionsWithLogWeightLessThan.Value);
|
||||
return Math.Min(automaton1.PruneStatesWithLogEndWeightLessThan.Value, automaton2.PruneStatesWithLogEndWeightLessThan.Value);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -1437,7 +1532,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
}
|
||||
}
|
||||
|
||||
var simplification = new Simplification(result, this.PruneTransitionsWithLogWeightLessThan);
|
||||
var simplification = new Simplification(result, this.PruneStatesWithLogEndWeightLessThan);
|
||||
simplification.SimplifyIfNeeded();
|
||||
|
||||
this.Data = result.GetData();
|
||||
|
@ -1571,7 +1666,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
|
||||
this.Data = automaton.Data;
|
||||
this.LogValueOverride = automaton.LogValueOverride;
|
||||
this.PruneTransitionsWithLogWeightLessThan = automaton.PruneTransitionsWithLogWeightLessThan;
|
||||
this.PruneStatesWithLogEndWeightLessThan = automaton.PruneStatesWithLogEndWeightLessThan;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -1632,56 +1727,85 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
/// </summary>
|
||||
/// <param name="sequence">The sequence to compute the value on.</param>
|
||||
/// <returns>The logarithm of the value.</returns>
|
||||
/// <remarks></remarks>
|
||||
/// <remarks>Recursive implementation would be simpler but prone to stack overflows with large automata</remarks>
|
||||
public double GetLogValue(TSequence sequence)
|
||||
{
|
||||
Argument.CheckIfNotNull(sequence, "sequence");
|
||||
var valueCache = new Dictionary<(int, int), Weight>();
|
||||
var logValue = DoGetValue(this.Start.Index, 0).LogValue;
|
||||
var sequenceLength = SequenceManipulator.GetLength(sequence);
|
||||
|
||||
return
|
||||
!double.IsNegativeInfinity(logValue) && this.LogValueOverride.HasValue
|
||||
? this.LogValueOverride.Value
|
||||
: logValue;
|
||||
// This algorithm is unwinding of trivial recursion with cache. It encodes recursion through explicit
|
||||
// stack to avoid stack-overflow. It makes calculation not straightforward. Recursive algorithm that
|
||||
// we want to encode looks like this:
|
||||
// GetLogValue(state, sequencePos) =
|
||||
// closure = EpsilonClosure(state)
|
||||
// if sequencePos == sequenceLength:
|
||||
// return closure.EndWeight
|
||||
// else:
|
||||
// return sum([GetLogValue(transition.Destination, sequencePos+1) * transitionWeight'
|
||||
// for transition in closure.Transitions])
|
||||
// where transitionWeight' = transition.Weight * closureStateWeight * transition.ElementDistribution.GetLogWeight(sequence[sequencePos])
|
||||
//
|
||||
// To encode that with explicit stack we split each GetLogValue() call into 2 parts which are put
|
||||
// on operations stack:
|
||||
// a) (stateIndex, sequencePos, mul, -1) -
|
||||
// schedules computation for each transition from this state and (b) operation after them
|
||||
// b) (stateIndex, sequencePos, mul, sumUntil) -
|
||||
// sums values of transitions from this state. They will be on top of the stack
|
||||
// As an optimization, value calculated by (b) is cached, so if (a) notices that result for
|
||||
// (stateIndex, sequencePos) already calculated it doesn't schedule any extra work
|
||||
var operationsStack = new Stack<(int stateIndex, int sequencePos, Weight multiplier, int sumUntil)>();
|
||||
var valuesStack = new Stack<Weight>();
|
||||
var valueCache = new Dictionary<(int stateIndex, int sequencePos), Weight>();
|
||||
operationsStack.Push((this.Start.Index, 0, Weight.One, -1));
|
||||
|
||||
Weight DoGetValue(int stateIndex, int sequencePosition)
|
||||
while (operationsStack.Count > 0)
|
||||
{
|
||||
var state = this.States[stateIndex];
|
||||
var stateIndexPair = (stateIndex, sequencePosition);
|
||||
if (valueCache.TryGetValue(stateIndexPair, out var cachedValue))
|
||||
{
|
||||
return cachedValue;
|
||||
}
|
||||
var (stateIndex, sequencePos, multiplier, sumUntil) = operationsStack.Pop();
|
||||
var statePosPair = (stateIndex, sequencePos);
|
||||
|
||||
var closure = state.GetEpsilonClosure();
|
||||
|
||||
var value = Weight.Zero;
|
||||
var count = SequenceManipulator.GetLength(sequence);
|
||||
var isCurrent = sequencePosition < count;
|
||||
if (isCurrent)
|
||||
if (sumUntil < 0)
|
||||
{
|
||||
var element = SequenceManipulator.GetElement(sequence, sequencePosition);
|
||||
for (var closureStateIndex = 0; closureStateIndex < closure.Size; ++closureStateIndex)
|
||||
if (valueCache.TryGetValue(statePosPair, out var cachedValue))
|
||||
{
|
||||
var closureState = closure.GetStateByIndex(closureStateIndex);
|
||||
var closureStateWeight = closure.GetStateWeightByIndex(closureStateIndex);
|
||||
valuesStack.Push(Weight.Product(cachedValue, multiplier));
|
||||
}
|
||||
else
|
||||
{
|
||||
var closure = this.States[stateIndex].GetEpsilonClosure();
|
||||
|
||||
foreach (var transition in closureState.Transitions)
|
||||
if (sequencePos == sequenceLength)
|
||||
{
|
||||
if (transition.IsEpsilon)
|
||||
{
|
||||
continue; // The destination is a part of the closure anyway
|
||||
}
|
||||
// We are at the end of sequence. So put an answer on stack
|
||||
valuesStack.Push(Weight.Product(closure.EndWeight, multiplier));
|
||||
valueCache[statePosPair] = closure.EndWeight;
|
||||
}
|
||||
else
|
||||
{
|
||||
// schedule second part of computation - sum values for all transitions
|
||||
// Note: it is put on stack before operations for any transitions, that
|
||||
// means that it will be executed after them
|
||||
operationsStack.Push((stateIndex, sequencePos, multiplier, valuesStack.Count));
|
||||
|
||||
var destState = this.States[transition.DestinationStateIndex];
|
||||
var distWeight = Weight.FromLogValue(transition.ElementDistribution.Value.GetLogProb(element));
|
||||
if (!distWeight.IsZero && !transition.Weight.IsZero)
|
||||
var element = SequenceManipulator.GetElement(sequence, sequencePos);
|
||||
for (var closureStateIndex = 0; closureStateIndex < closure.Size; ++closureStateIndex)
|
||||
{
|
||||
var destValue = DoGetValue(destState.Index, sequencePosition + 1);
|
||||
if (!destValue.IsZero)
|
||||
var closureState = closure.GetStateByIndex(closureStateIndex);
|
||||
var closureStateWeight = closure.GetStateWeightByIndex(closureStateIndex);
|
||||
foreach (var transition in closureState.Transitions)
|
||||
{
|
||||
value = Weight.Sum(
|
||||
value,
|
||||
Weight.Product(closureStateWeight, transition.Weight, distWeight, destValue));
|
||||
if (transition.IsEpsilon)
|
||||
{
|
||||
continue; // The destination is a part of the closure anyway
|
||||
}
|
||||
|
||||
var destStateIndex = transition.DestinationStateIndex;
|
||||
var distWeight = Weight.FromLogValue(transition.ElementDistribution.Value.GetLogProb(element));
|
||||
if (!distWeight.IsZero && !transition.Weight.IsZero)
|
||||
{
|
||||
var weightMul = Weight.Product(closureStateWeight, transition.Weight, distWeight);
|
||||
operationsStack.Push((destStateIndex, sequencePos + 1, weightMul, -1));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1689,12 +1813,25 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
}
|
||||
else
|
||||
{
|
||||
value = closure.EndWeight;
|
||||
}
|
||||
// All transitions value from this state are already calculated, sum them and store on stack and in cache
|
||||
var sum = Weight.Zero;
|
||||
while (valuesStack.Count > sumUntil)
|
||||
{
|
||||
var transitionValue = valuesStack.Pop();
|
||||
sum = Weight.Sum(sum, transitionValue);
|
||||
}
|
||||
|
||||
valueCache.Add(stateIndexPair, value);
|
||||
return value;
|
||||
valuesStack.Push(Weight.Product(multiplier, sum));
|
||||
valueCache[statePosPair] = sum;
|
||||
}
|
||||
}
|
||||
|
||||
Debug.Assert(valuesStack.Count == 1);
|
||||
var result = valuesStack.Pop();
|
||||
return
|
||||
!result.IsZero && this.LogValueOverride.HasValue
|
||||
? this.LogValueOverride.Value
|
||||
: result.LogValue;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -1714,10 +1851,11 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
/// The only sequence having non-zero value, if found.
|
||||
/// <see langword="null"/>, if the automaton is zero everywhere or is non-zero on more than one sequence.
|
||||
/// </returns>
|
||||
/// <remarks>Recursive implementation would be simpler but prone to stack overflows with large automata</remarks>
|
||||
public TSequence TryComputePoint()
|
||||
{
|
||||
bool[] endNodeReachability = this.ComputeEndStateReachability();
|
||||
if (!endNodeReachability[this.Start.Index])
|
||||
var isEndNodeReachable = this.ComputeEndStateReachability();
|
||||
if (!isEndNodeReachable[this.Start.Index])
|
||||
{
|
||||
return null;
|
||||
}
|
||||
|
@ -1725,8 +1863,92 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
var point = new List<TElement>();
|
||||
int? pointLength = null;
|
||||
var stateDepth = new ArrayDictionary<int>(this.States.Count);
|
||||
bool isPoint = this.TryComputePointDfs(this.Start, 0, stateDepth, endNodeReachability, point, ref pointLength);
|
||||
return isPoint && pointLength.HasValue ? SequenceManipulator.ToSequence(point) : null;
|
||||
var stack = new Stack<(int stateIndex, int sequencePos)>();
|
||||
stack.Push((this.Start.Index, 0));
|
||||
|
||||
// Note: this algorithm looks simpler if implemented recursively. But recursive implementation
|
||||
// causes StackOverflowException.
|
||||
// Algorithm is simple: traverse automaton id depth-first fashion and check that element transitions
|
||||
// along all paths are equal. If any inconsistency is found
|
||||
|
||||
while (stack.Count != 0)
|
||||
{
|
||||
var (stateIndex, sequencePos) = stack.Pop();
|
||||
Debug.Assert(isEndNodeReachable[stateIndex], "Dead branches must not be visited.");
|
||||
|
||||
if (stateDepth.TryGetValue(stateIndex, out var cachedStateDepth))
|
||||
{
|
||||
// If we've already been in this state, we must be at the same sequence pos
|
||||
if (sequencePos != cachedStateDepth)
|
||||
{
|
||||
return null;
|
||||
}
|
||||
|
||||
// This state was already processed, goto next one
|
||||
continue;
|
||||
}
|
||||
|
||||
stateDepth.Add(stateIndex, sequencePos);
|
||||
|
||||
var state = this.States[stateIndex];
|
||||
|
||||
// Can we stop in this state?
|
||||
if (state.CanEnd)
|
||||
{
|
||||
// Is this a suffix or a prefix of the point already found?
|
||||
if (pointLength.HasValue)
|
||||
{
|
||||
if (sequencePos != pointLength.Value)
|
||||
{
|
||||
return null;
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
// Now we know the length of the sequence
|
||||
pointLength = sequencePos;
|
||||
}
|
||||
|
||||
foreach (var transition in state.Transitions)
|
||||
{
|
||||
var destStateIndex = transition.DestinationStateIndex;
|
||||
if (!isEndNodeReachable[destStateIndex])
|
||||
{
|
||||
// Only walk through the accepting part of the automaton
|
||||
continue;
|
||||
}
|
||||
|
||||
if (transition.IsEpsilon)
|
||||
{
|
||||
// Move to the next state, keep the sequence position
|
||||
stack.Push((destStateIndex, sequencePos));
|
||||
}
|
||||
else if (!transition.ElementDistribution.Value.IsPointMass)
|
||||
{
|
||||
// If there's non-point distribution on transition, than automaton doesn't have point either
|
||||
return null;
|
||||
}
|
||||
else
|
||||
{
|
||||
var element = transition.ElementDistribution.Value.Point;
|
||||
if (sequencePos == point.Count)
|
||||
{
|
||||
// It is the first time at this sequence position
|
||||
point.Add(element);
|
||||
}
|
||||
else if (!point[sequencePos].Equals(element))
|
||||
{
|
||||
// This is not the first time at this sequence position, and the elements are different
|
||||
return null;
|
||||
}
|
||||
|
||||
stack.Push((destStateIndex, sequencePos + 1));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return pointLength.HasValue ? SequenceManipulator.ToSequence(point) : null;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -1936,7 +2158,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
|
||||
this.Data = builder.GetData();
|
||||
this.LogValueOverride = automaton.LogValueOverride;
|
||||
this.PruneTransitionsWithLogWeightLessThan = automaton.LogValueOverride;
|
||||
this.PruneStatesWithLogEndWeightLessThan = automaton.LogValueOverride;
|
||||
|
||||
// Recursively builds an automaton representing the epsilon closure of a given automaton.
|
||||
// Returns the state index of state representing the closure
|
||||
|
@ -2077,12 +2299,13 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
/// For each state computes whether any state with non-zero ending weight can be reached from it.
|
||||
/// </summary>
|
||||
/// <returns>An array mapping state indices to end state reachability.</returns>
|
||||
/// <remarks>Recursive implementation would be simpler but prone to stack overflows with large automatons</remarks>
|
||||
private bool[] ComputeEndStateReachability()
|
||||
{
|
||||
//// First, build a reversed graph
|
||||
|
||||
int[] edgePlacementIndices = new int[this.States.Count + 1];
|
||||
for (int i = 0; i < this.States.Count; ++i)
|
||||
var edgePlacementIndices = new int[this.States.Count + 1];
|
||||
for (var i = 0; i < this.States.Count; ++i)
|
||||
{
|
||||
var state = this.States[i];
|
||||
foreach (var transition in state.Transitions)
|
||||
|
@ -2097,15 +2320,15 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
// The element of edgePlacementIndices at index i+1 contains a count of the number of edges
|
||||
// going into the i'th state (the indegree of the state).
|
||||
// Convert this into a cumulative count (which will be used to give a unique index to each edge).
|
||||
for (int i = 1; i < edgePlacementIndices.Length; ++i)
|
||||
for (var i = 1; i < edgePlacementIndices.Length; ++i)
|
||||
{
|
||||
edgePlacementIndices[i] += edgePlacementIndices[i - 1];
|
||||
}
|
||||
|
||||
int[] edgeArrayStarts = (int[])edgePlacementIndices.Clone();
|
||||
int totalEdgeCount = edgePlacementIndices[this.States.Count];
|
||||
int[] edgeDestinationIndices = new int[totalEdgeCount];
|
||||
for (int i = 0; i < this.States.Count; ++i)
|
||||
var edgeArrayStarts = (int[])edgePlacementIndices.Clone();
|
||||
var totalEdgeCount = edgePlacementIndices[this.States.Count];
|
||||
var edgeDestinationIndices = new int[totalEdgeCount];
|
||||
for (var i = 0; i < this.States.Count; ++i)
|
||||
{
|
||||
var state = this.States[i];
|
||||
foreach (var transition in state.Transitions)
|
||||
|
@ -2113,7 +2336,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
if (!transition.Weight.IsZero)
|
||||
{
|
||||
// The unique index for this edge
|
||||
int edgePlacementIndex = edgePlacementIndices[transition.DestinationStateIndex]++;
|
||||
var edgePlacementIndex = edgePlacementIndices[transition.DestinationStateIndex]++;
|
||||
|
||||
// The source index for the edge (which is the destination edge in the reversed graph)
|
||||
edgeDestinationIndices[edgePlacementIndex] = i;
|
||||
|
@ -2122,38 +2345,38 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
}
|
||||
|
||||
//// Now run a depth-first search to label all reachable nodes
|
||||
bool[] visitedNodes = new bool[this.States.Count];
|
||||
for (int i = 0; i < this.States.Count; ++i)
|
||||
var stack = new Stack<int>();
|
||||
var visitedNodes = new bool[this.States.Count];
|
||||
for (var i = 0; i < this.States.Count; ++i)
|
||||
{
|
||||
if (!visitedNodes[i] && this.States[i].CanEnd)
|
||||
{
|
||||
LabelReachableNodesDfs(i);
|
||||
visitedNodes[i] = true;
|
||||
stack.Push(i);
|
||||
while (stack.Count != 0)
|
||||
{
|
||||
var stateIndex = stack.Pop();
|
||||
for (var edgeIndex = edgeArrayStarts[stateIndex]; edgeIndex < edgeArrayStarts[stateIndex + 1]; ++edgeIndex)
|
||||
{
|
||||
var destinationIndex = edgeDestinationIndices[edgeIndex];
|
||||
if (!visitedNodes[destinationIndex])
|
||||
{
|
||||
visitedNodes[destinationIndex] = true;
|
||||
stack.Push(destinationIndex);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return visitedNodes;
|
||||
|
||||
void LabelReachableNodesDfs(int currentVertex)
|
||||
{
|
||||
Debug.Assert(!visitedNodes[currentVertex], "Visited vertices must not be revisited.");
|
||||
visitedNodes[currentVertex] = true;
|
||||
|
||||
for (int edgeIndex = edgeArrayStarts[currentVertex]; edgeIndex < edgeArrayStarts[currentVertex + 1]; ++edgeIndex)
|
||||
{
|
||||
int destVertexIndex = edgeDestinationIndices[edgeIndex];
|
||||
if (!visitedNodes[destVertexIndex])
|
||||
{
|
||||
LabelReachableNodesDfs(destVertexIndex);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Computes the logarithm of the normalizer of the automaton, normalizing it afterwards if requested.
|
||||
/// </summary>
|
||||
/// <param name="normalize">Specifies whether the automaton must be normalized after computing the normalizer.</param>
|
||||
/// <returns>The logarithm of the normalizer.</returns>
|
||||
/// <returns>The logarithm of the normalizer and condensation of automaton.</returns>
|
||||
/// <remarks>The automaton is normalized only if the normalizer has a finite non-zero value.</remarks>
|
||||
private double DoGetLogNormalizer(bool normalize)
|
||||
{
|
||||
|
@ -2161,7 +2384,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
TThis noEpsilonTransitions = Zero();
|
||||
noEpsilonTransitions.SetToEpsilonClosureOf((TThis)this); // To get rid of infinite weight closures
|
||||
|
||||
Condensation condensation = noEpsilonTransitions.ComputeCondensation(noEpsilonTransitions.Start, tr => true, false);
|
||||
var condensation = noEpsilonTransitions.ComputeCondensation(noEpsilonTransitions.Start, tr => true, false);
|
||||
double logNormalizer = condensation.GetWeightToEnd(noEpsilonTransitions.Start.Index).LogValue;
|
||||
if (normalize)
|
||||
{
|
||||
|
@ -2205,101 +2428,6 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Recursively looks for the only sequence which has non-zero value under the automaton.
|
||||
/// </summary>
|
||||
/// <param name="currentState">The state currently being traversed.</param>
|
||||
/// <param name="currentSequencePos">The current position in the sequence.</param>
|
||||
/// <param name="stateSequencePosCache">A lookup table for memoization.</param>
|
||||
/// <param name="isEndNodeReachable">End node reachability table to avoid dead branches.</param>
|
||||
/// <param name="point">The candidate sequence.</param>
|
||||
/// <param name="pointLength">The length of the candidate sequence
|
||||
/// or <see langword="null"/> if the length isn't known yet.</param>
|
||||
/// <returns>
|
||||
/// <see langword="true"/> if the sequence was found, <see langword="false"/> otherwise.
|
||||
/// </returns>
|
||||
private bool TryComputePointDfs(
|
||||
State currentState,
|
||||
int currentSequencePos,
|
||||
ArrayDictionary<int> stateSequencePosCache,
|
||||
bool[] isEndNodeReachable,
|
||||
List<TElement> point,
|
||||
ref int? pointLength)
|
||||
{
|
||||
Debug.Assert(isEndNodeReachable[currentState.Index], "Dead branches must not be visited.");
|
||||
|
||||
int cachedStateDepth;
|
||||
if (stateSequencePosCache.TryGetValue(currentState.Index, out cachedStateDepth))
|
||||
{
|
||||
// If we've already been in this state, we must be at the same sequence pos
|
||||
return currentSequencePos == cachedStateDepth;
|
||||
}
|
||||
|
||||
stateSequencePosCache.Add(currentState.Index, currentSequencePos);
|
||||
|
||||
// Can we stop in this state?
|
||||
if (currentState.CanEnd)
|
||||
{
|
||||
// Is this a suffix or a prefix of the point already found?
|
||||
if (pointLength.HasValue)
|
||||
{
|
||||
if (currentSequencePos != pointLength.Value)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Now we know the length of the sequence
|
||||
pointLength = currentSequencePos;
|
||||
}
|
||||
}
|
||||
|
||||
foreach (var transition in currentState.Transitions)
|
||||
{
|
||||
State destState = this.States[transition.DestinationStateIndex];
|
||||
if (!isEndNodeReachable[destState.Index])
|
||||
{
|
||||
continue; // Only walk through the accepting part of the automaton
|
||||
}
|
||||
|
||||
if (transition.IsEpsilon)
|
||||
{
|
||||
// Move to the next state, keep the sequence position
|
||||
if (!this.TryComputePointDfs(destState, currentSequencePos, stateSequencePosCache, isEndNodeReachable, point, ref pointLength))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if (!transition.ElementDistribution.Value.IsPointMass)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
else
|
||||
{
|
||||
TElement element = transition.ElementDistribution.Value.Point;
|
||||
if (currentSequencePos == point.Count)
|
||||
{
|
||||
// It is the first time at this sequence position
|
||||
point.Add(element);
|
||||
}
|
||||
else if (!point[currentSequencePos].Equals(element))
|
||||
{
|
||||
// This is not the first time at this sequence position, and the elements are different
|
||||
return false;
|
||||
}
|
||||
|
||||
// Look at the next sequence position
|
||||
if (!this.TryComputePointDfs(destState, currentSequencePos + 1, stateSequencePosCache, isEndNodeReachable, point, ref pointLength))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Swaps the current automaton with a given one.
|
||||
/// </summary>
|
||||
|
@ -2317,107 +2445,9 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
this.LogValueOverride = automaton.LogValueOverride;
|
||||
automaton.LogValueOverride = dummy;
|
||||
|
||||
dummy = this.PruneTransitionsWithLogWeightLessThan;
|
||||
this.PruneTransitionsWithLogWeightLessThan = automaton.PruneTransitionsWithLogWeightLessThan;
|
||||
automaton.PruneTransitionsWithLogWeightLessThan = dummy;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Appends a string representing the automaton from the current state.
|
||||
/// </summary>
|
||||
/// <param name="builder">The string builder.</param>
|
||||
/// <param name="visitedStates">The set of states that has already been visited.</param>
|
||||
/// <param name="stateIndex">The index of the current state.</param>
|
||||
/// <param name="appendRegex">Optional method for appending at the element distribution level.</param>
|
||||
private void AppendString(StringBuilder builder, HashSet<int> visitedStates, int stateIndex, Action<TElementDistribution, StringBuilder> appendRegex = null)
|
||||
{
|
||||
if (visitedStates.Contains(stateIndex))
|
||||
{
|
||||
builder.Append('➥');
|
||||
return;
|
||||
}
|
||||
|
||||
visitedStates.Add(stateIndex);
|
||||
|
||||
var currentState = this.States[stateIndex];
|
||||
var transitions = currentState.Transitions.Where(t => !t.Weight.IsZero);
|
||||
var selfTransitions = transitions.Where(t => t.DestinationStateIndex == stateIndex);
|
||||
int selfTransitionCount = selfTransitions.Count();
|
||||
var nonSelfTransitions = transitions.Where(t => t.DestinationStateIndex != stateIndex);
|
||||
int nonSelfTransitionCount = nonSelfTransitions.Count();
|
||||
|
||||
if (currentState.CanEnd && nonSelfTransitionCount > 0)
|
||||
{
|
||||
builder.Append('▪');
|
||||
}
|
||||
|
||||
if (selfTransitionCount > 1)
|
||||
{
|
||||
builder.Append('[');
|
||||
}
|
||||
|
||||
int transIdx = 0;
|
||||
foreach (var transition in selfTransitions)
|
||||
{
|
||||
if (!transition.IsEpsilon)
|
||||
{
|
||||
if (appendRegex != null)
|
||||
{
|
||||
appendRegex(transition.ElementDistribution.Value, builder);
|
||||
}
|
||||
else
|
||||
{
|
||||
builder.Append(transition.ElementDistribution);
|
||||
}
|
||||
}
|
||||
|
||||
if (++transIdx < selfTransitionCount)
|
||||
{
|
||||
builder.Append('|');
|
||||
}
|
||||
}
|
||||
|
||||
if (selfTransitionCount > 1)
|
||||
{
|
||||
builder.Append(']');
|
||||
}
|
||||
|
||||
if (selfTransitionCount > 0)
|
||||
{
|
||||
builder.Append('*');
|
||||
}
|
||||
|
||||
if (nonSelfTransitionCount > 1)
|
||||
{
|
||||
builder.Append('[');
|
||||
}
|
||||
|
||||
transIdx = 0;
|
||||
foreach (var transition in nonSelfTransitions)
|
||||
{
|
||||
if (!transition.IsEpsilon)
|
||||
{
|
||||
if (appendRegex != null)
|
||||
{
|
||||
appendRegex(transition.ElementDistribution.Value, builder);
|
||||
}
|
||||
else
|
||||
{
|
||||
builder.Append(transition.ElementDistribution);
|
||||
}
|
||||
}
|
||||
|
||||
this.AppendString(builder, visitedStates, transition.DestinationStateIndex, appendRegex);
|
||||
if (++transIdx < nonSelfTransitionCount)
|
||||
{
|
||||
builder.Append('|');
|
||||
}
|
||||
}
|
||||
|
||||
if (nonSelfTransitionCount > 1)
|
||||
{
|
||||
builder.Append(']');
|
||||
}
|
||||
dummy = this.PruneStatesWithLogEndWeightLessThan;
|
||||
this.PruneStatesWithLogEndWeightLessThan = automaton.PruneStatesWithLogEndWeightLessThan;
|
||||
automaton.PruneStatesWithLogEndWeightLessThan = dummy;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -2587,7 +2617,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
propertyMask[1 << idx++] = true; // isEpsilonFree is alway known
|
||||
propertyMask[1 << idx++] = this.Data.IsEpsilonFree;
|
||||
propertyMask[1 << idx++] = this.LogValueOverride.HasValue;
|
||||
propertyMask[1 << idx++] = this.PruneTransitionsWithLogWeightLessThan.HasValue;
|
||||
propertyMask[1 << idx++] = this.PruneStatesWithLogEndWeightLessThan.HasValue;
|
||||
propertyMask[1 << idx++] = true; // start state is alway serialized
|
||||
|
||||
writeInt32(propertyMask.Data);
|
||||
|
@ -2597,9 +2627,9 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
writeDouble(this.LogValueOverride.Value);
|
||||
}
|
||||
|
||||
if (this.PruneTransitionsWithLogWeightLessThan.HasValue)
|
||||
if (this.PruneStatesWithLogEndWeightLessThan.HasValue)
|
||||
{
|
||||
writeDouble(this.PruneTransitionsWithLogWeightLessThan.Value);
|
||||
writeDouble(this.PruneStatesWithLogEndWeightLessThan.Value);
|
||||
}
|
||||
|
||||
// This state is serialized only for its index.
|
||||
|
@ -2624,11 +2654,11 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
var propertyMask = new BitVector32(readInt32());
|
||||
var res = new TThis();
|
||||
var idx = 0;
|
||||
// we do not trust serialized "isEpsilonFree". Will take it from builder anyway
|
||||
// Serialized "isEpsilonFree" is not used. Will be taken from builder anyway
|
||||
var hasEpsilonFreeIgnored = propertyMask[1 << idx++];
|
||||
var isEpsilonFreeIgnored = propertyMask[1 << idx++];
|
||||
var hasLogValueOverride = propertyMask[1 << idx++];
|
||||
var hasPruneTransitions = propertyMask[1 << idx++];
|
||||
var hasPruneWeights = propertyMask[1 << idx++];
|
||||
var hasStartState = propertyMask[1 << idx++];
|
||||
|
||||
if (hasLogValueOverride)
|
||||
|
@ -2636,9 +2666,9 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
res.LogValueOverride = readDouble();
|
||||
}
|
||||
|
||||
if (hasPruneTransitions)
|
||||
if (hasPruneWeights)
|
||||
{
|
||||
res.PruneTransitionsWithLogWeightLessThan = readDouble();
|
||||
res.PruneStatesWithLogEndWeightLessThan = readDouble();
|
||||
}
|
||||
|
||||
var builder = new Builder(0);
|
||||
|
|
|
@ -318,47 +318,56 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
/// </summary>
|
||||
/// <param name="srcAutomaton">The automaton to project.</param>
|
||||
/// <returns>The projection.</returns>
|
||||
/// <remarks>
|
||||
/// The code of this method has a lot in common with the code of Automaton.SetToProduct.
|
||||
/// Unfortunately, it's not clear how to avoid the duplication in the current design.
|
||||
/// </remarks>
|
||||
public TDestAutomaton ProjectSource(TSrcAutomaton srcAutomaton)
|
||||
{
|
||||
Argument.CheckIfNotNull(srcAutomaton, "srcAutomaton");
|
||||
|
||||
var result = new Automaton<TDestSequence, TDestElement, TDestElementDistribution, TDestSequenceManipulator, TDestAutomaton>.Builder();
|
||||
|
||||
if (srcAutomaton.IsCanonicZero() || this.sequencePairToWeight.IsCanonicZero())
|
||||
var mappingAutomaton = this.sequencePairToWeight;
|
||||
if (srcAutomaton.IsCanonicZero() || mappingAutomaton.IsCanonicZero())
|
||||
{
|
||||
return result.GetAutomaton();
|
||||
return Automaton<TDestSequence, TDestElement, TDestElementDistribution, TDestSequenceManipulator, TDestAutomaton>.Zero();
|
||||
}
|
||||
|
||||
// The projected automaton must be epsilon-free
|
||||
srcAutomaton.MakeEpsilonFree();
|
||||
|
||||
var result = new Automaton<TDestSequence, TDestElement, TDestElementDistribution, TDestSequenceManipulator, TDestAutomaton>.Builder();
|
||||
var destStateCache = new Dictionary<(int, int), int>();
|
||||
result.StartStateIndex = BuildProjectionOfAutomaton(this.sequencePairToWeight.Start, srcAutomaton.Start);
|
||||
var stack = new Stack<(int state1, int state2, int destStateIndex)>();
|
||||
|
||||
var simplification = new Automaton<TDestSequence, TDestElement, TDestElementDistribution, TDestSequenceManipulator, TDestAutomaton>.Simplification(result, null);
|
||||
simplification.RemoveDeadStates();
|
||||
simplification.SimplifyIfNeeded();
|
||||
|
||||
return result.GetAutomaton();
|
||||
|
||||
// Recursively builds the projection of a given automaton onto this transducer.
|
||||
// The projected automaton must be epsilon-free.
|
||||
int BuildProjectionOfAutomaton(
|
||||
// Creates destination state and schedules projection computation for it.
|
||||
// If computation is already scheduled or done the state index is simply taken from cache
|
||||
int CreateDestState(
|
||||
PairListAutomaton.State mappingState,
|
||||
Automaton<TSrcSequence, TSrcElement, TSrcElementDistribution, TSrcSequenceManipulator, TSrcAutomaton>.State srcState)
|
||||
{
|
||||
//// The code of this method has a lot in common with the code of Automaton<>.BuildProduct.
|
||||
//// Unfortunately, it's not clear how to avoid the duplication in the current design.
|
||||
|
||||
// State already exists, return its index
|
||||
var statePair = (mappingState.Index, srcState.Index);
|
||||
if (destStateCache.TryGetValue(statePair, out var destStateIndex))
|
||||
var destPair = (mappingState.Index, srcState.Index);
|
||||
if (!destStateCache.TryGetValue(destPair, out var destStateIndex))
|
||||
{
|
||||
return destStateIndex;
|
||||
var destState = result.AddState();
|
||||
destState.SetEndWeight(Weight.Product(mappingState.EndWeight, srcState.EndWeight));
|
||||
stack.Push((mappingState.Index, srcState.Index, destState.Index));
|
||||
destStateCache[destPair] = destState.Index;
|
||||
destStateIndex = destState.Index;
|
||||
}
|
||||
|
||||
var destState = result.AddState();
|
||||
destStateCache.Add(statePair, destState.Index);
|
||||
return destStateIndex;
|
||||
}
|
||||
|
||||
// Populate the stack with start destination state
|
||||
result.StartStateIndex = CreateDestState(mappingAutomaton.Start, srcAutomaton.Start);
|
||||
|
||||
while (stack.Count > 0)
|
||||
{
|
||||
var (mappingStateIndex, srcStateIndex, destStateIndex) = stack.Pop();
|
||||
|
||||
var mappingState = mappingAutomaton.States[mappingStateIndex];
|
||||
var srcState = srcAutomaton.States[srcStateIndex];
|
||||
var destState = result[destStateIndex];
|
||||
|
||||
// Iterate over transitions from mappingState
|
||||
foreach (var mappingTransition in mappingState.Transitions)
|
||||
|
@ -372,7 +381,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
mappingTransition.ElementDistribution.HasValue
|
||||
? mappingTransition.ElementDistribution.Value.Second
|
||||
: Option.None;
|
||||
var childDestStateIndex = BuildProjectionOfAutomaton(childMappingState, srcState);
|
||||
var childDestStateIndex = CreateDestState(childMappingState, srcState);
|
||||
destState.AddTransition(destElementDistribution, mappingTransition.Weight, childDestStateIndex, mappingTransition.Group);
|
||||
continue;
|
||||
}
|
||||
|
@ -392,14 +401,17 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
}
|
||||
|
||||
var destWeight = Weight.Product(mappingTransition.Weight, srcTransition.Weight, Weight.FromLogValue(projectionLogScale));
|
||||
var childDestStateIndex = BuildProjectionOfAutomaton(childMappingState, srcChildState);
|
||||
var childDestStateIndex = CreateDestState(childMappingState, srcChildState);
|
||||
destState.AddTransition(destElementDistribution, destWeight, childDestStateIndex, mappingTransition.Group);
|
||||
}
|
||||
}
|
||||
|
||||
destState.SetEndWeight(Weight.Product(mappingState.EndWeight, srcState.EndWeight));
|
||||
return destState.Index;
|
||||
}
|
||||
|
||||
var simplification = new Automaton<TDestSequence, TDestElement, TDestElementDistribution, TDestSequenceManipulator, TDestAutomaton>.Simplification(result, null);
|
||||
simplification.RemoveDeadStates();
|
||||
simplification.SimplifyIfNeeded();
|
||||
|
||||
return result.GetAutomaton();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -411,45 +423,57 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
/// Using this method is more efficient than applying <see cref="ProjectSource(TSrcAutomaton)"/>
|
||||
/// to the automaton representation of a projected sequence.
|
||||
/// </remarks>
|
||||
/// <remarks>
|
||||
/// The code of this method has a lot in common with the code of Automaton.SetToProduct.
|
||||
/// Unfortunately, it's not clear how to avoid the duplication in the current design.
|
||||
/// </remarks>
|
||||
public TDestAutomaton ProjectSource(TSrcSequence srcSequence)
|
||||
{
|
||||
Argument.CheckIfNotNull(srcSequence, "srcSequence");
|
||||
|
||||
var result = new Automaton<TDestSequence, TDestElement, TDestElementDistribution, TDestSequenceManipulator, TDestAutomaton>.Builder();
|
||||
|
||||
if (this.sequencePairToWeight.IsCanonicZero())
|
||||
var mappingAutomaton = this.sequencePairToWeight;
|
||||
if (mappingAutomaton.IsCanonicZero())
|
||||
{
|
||||
return result.GetAutomaton();
|
||||
return Automaton<TDestSequence, TDestElement, TDestElementDistribution, TDestSequenceManipulator, TDestAutomaton>.Zero();
|
||||
}
|
||||
|
||||
var sourceSequenceManipulator =
|
||||
Automaton<TSrcSequence, TSrcElement, TSrcElementDistribution, TSrcSequenceManipulator, TSrcAutomaton>.SequenceManipulator;
|
||||
var srcSequenceLength = sourceSequenceManipulator.GetLength(srcSequence);
|
||||
|
||||
var result = new Automaton<TDestSequence, TDestElement, TDestElementDistribution, TDestSequenceManipulator, TDestAutomaton>.Builder();
|
||||
var destStateCache = new Dictionary<(int, int), int>();
|
||||
result.StartStateIndex = BuildProjectionOfSequence(this.sequencePairToWeight.Start, 0);
|
||||
var stack = new Stack<(int state1, int state2, int destStateIndex)>();
|
||||
|
||||
var simplification = new Automaton<TDestSequence, TDestElement, TDestElementDistribution, TDestSequenceManipulator, TDestAutomaton>.Simplification(result, null);
|
||||
simplification.RemoveDeadStates();
|
||||
simplification.SimplifyIfNeeded();
|
||||
|
||||
return result.GetAutomaton();
|
||||
|
||||
// Recursively builds the projection of a given sequence onto this transducer.
|
||||
int BuildProjectionOfSequence(PairListAutomaton.State mappingState, int srcSequenceIndex)
|
||||
// Creates destination state and schedules projection computation for it.
|
||||
// If computation is already scheduled or done the state index is simply taken from cache
|
||||
int CreateDestState(PairListAutomaton.State mappingState, int srcSequenceIndex)
|
||||
{
|
||||
//// The code of this method has a lot in common with the code of Automaton<>.BuildProduct.
|
||||
//// Unfortunately, it's not clear how to avoid the duplication in the current design.
|
||||
|
||||
var sourceSequenceManipulator =
|
||||
Automaton<TSrcSequence, TSrcElement, TSrcElementDistribution, TSrcSequenceManipulator, TSrcAutomaton>.SequenceManipulator;
|
||||
|
||||
var statePair = (mappingState.Index, srcSequenceIndex);
|
||||
if (destStateCache.TryGetValue(statePair, out var destStateIndex))
|
||||
var destPair = (mappingState.Index, srcSequenceIndex);
|
||||
if (!destStateCache.TryGetValue(destPair, out var destStateIndex))
|
||||
{
|
||||
return destStateIndex;
|
||||
var destState = result.AddState();
|
||||
destState.SetEndWeight(
|
||||
srcSequenceIndex == srcSequenceLength
|
||||
? mappingState.EndWeight
|
||||
: Weight.Zero);
|
||||
stack.Push((mappingState.Index, srcSequenceIndex, destState.Index));
|
||||
destStateCache[destPair] = destState.Index;
|
||||
destStateIndex = destState.Index;
|
||||
}
|
||||
|
||||
var destState = result.AddState();
|
||||
destStateCache.Add(statePair, destState.Index);
|
||||
return destStateIndex;
|
||||
}
|
||||
|
||||
var srcSequenceLength = sourceSequenceManipulator.GetLength(srcSequence);
|
||||
// Populate the stack with start destination state
|
||||
result.StartStateIndex = CreateDestState(mappingAutomaton.Start, 0);
|
||||
|
||||
while (stack.Count > 0)
|
||||
{
|
||||
var (mappingStateIndex, srcSequenceIndex, destStateIndex) = stack.Pop();
|
||||
|
||||
var mappingState = mappingAutomaton.States[mappingStateIndex];
|
||||
var destState = result[destStateIndex];
|
||||
|
||||
// Enumerate transitions from the current mapping state
|
||||
foreach (var mappingTransition in mappingState.Transitions)
|
||||
|
@ -463,7 +487,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
mappingTransition.ElementDistribution.HasValue
|
||||
? mappingTransition.ElementDistribution.Value.Second
|
||||
: Option.None;
|
||||
var childDestStateIndex = BuildProjectionOfSequence(destMappingState, srcSequenceIndex);
|
||||
var childDestStateIndex = CreateDestState(destMappingState, srcSequenceIndex);
|
||||
destState.AddTransition(destElementWeights, mappingTransition.Weight, childDestStateIndex, mappingTransition.Group);
|
||||
continue;
|
||||
}
|
||||
|
@ -481,14 +505,17 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
}
|
||||
|
||||
var weight = Weight.Product(mappingTransition.Weight, Weight.FromLogValue(projectionLogScale));
|
||||
var childDestState = BuildProjectionOfSequence(destMappingState,srcSequenceIndex + 1);
|
||||
var childDestState = CreateDestState(destMappingState,srcSequenceIndex + 1);
|
||||
destState.AddTransition(destElementDistribution, weight, childDestState, mappingTransition.Group);
|
||||
}
|
||||
}
|
||||
|
||||
destState.SetEndWeight(srcSequenceIndex == srcSequenceLength ? mappingState.EndWeight : Weight.Zero);
|
||||
return destState.Index;
|
||||
}
|
||||
|
||||
var simplification = new Automaton<TDestSequence, TDestElement, TDestElementDistribution, TDestSequenceManipulator, TDestAutomaton>.Simplification(result, null);
|
||||
simplification.RemoveDeadStates();
|
||||
simplification.SimplifyIfNeeded();
|
||||
|
||||
return result.GetAutomaton();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
|
|
@ -88,6 +88,11 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
get { return double.IsNegativeInfinity(this.LogValue); }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets value indicating whether weight is infinite.
|
||||
/// </summary>
|
||||
public bool IsInfinity => double.IsInfinity(Value);
|
||||
|
||||
/// <summary>
|
||||
/// Creates a weight from the logarithm of its value.
|
||||
/// </summary>
|
||||
|
|
|
@ -1792,15 +1792,6 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
|||
this.isNormalized = true;
|
||||
}
|
||||
|
||||
////// todo: consider moving the following logic into the Automaton class
|
||||
////if (product.PruneTransitionsWithLogWeightLessThan.HasValue)
|
||||
////{
|
||||
//// if (this.isNormalized || product.TryNormalizeValues())
|
||||
//// {
|
||||
//// product.RemoveTransitionsWithSmallWeights(product.PruneTransitionsWithLogWeightLessThan.Value);
|
||||
//// }
|
||||
////}
|
||||
|
||||
return logNormalizer;
|
||||
}
|
||||
|
||||
|
|
|
@ -2,8 +2,6 @@
|
|||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
using Microsoft.ML.Probabilistic.Collections;
|
||||
|
||||
namespace Microsoft.ML.Probabilistic.Tests
|
||||
{
|
||||
using System;
|
||||
|
@ -15,10 +13,10 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
using Xunit;
|
||||
using Assert = Microsoft.ML.Probabilistic.Tests.AssertHelper;
|
||||
|
||||
using Microsoft.ML.Probabilistic.Collections;
|
||||
using Microsoft.ML.Probabilistic.Distributions;
|
||||
using Microsoft.ML.Probabilistic.Distributions.Automata;
|
||||
using Microsoft.ML.Probabilistic.Math;
|
||||
using Microsoft.ML.Probabilistic.Utilities;
|
||||
|
||||
/// <summary>
|
||||
/// Tests for weighted finite state automata.
|
||||
|
@ -64,10 +62,12 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
StringInferenceTestUtilities.TestValue(zero3, 0.0, "abc", "ab", "a", string.Empty);
|
||||
|
||||
StringAutomaton zero4 =
|
||||
StringAutomaton.Constant(2.0, DiscreteChar.Lower())
|
||||
.Product(
|
||||
StringAutomaton.Constant(3.0, DiscreteChar.Upper())
|
||||
.Append(StringAutomaton.ConstantOnElement(1.5, DiscreteChar.Digit())));
|
||||
StringAutomaton
|
||||
.Constant(2.0, DiscreteChar.Lower())
|
||||
.Product(
|
||||
StringAutomaton
|
||||
.Constant(3.0, DiscreteChar.Upper())
|
||||
.Append(StringAutomaton.ConstantOnElement(1.5, DiscreteChar.Digit())));
|
||||
Assert.True(zero4.IsZero());
|
||||
Assert.True(zero4.IsCanonicZero());
|
||||
StringInferenceTestUtilities.TestValue(zero4, 0.0, "abc", "ab", "a", string.Empty);
|
||||
|
@ -760,34 +760,91 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
Assert.Equal(expectedLogValue, logValue, 1e-8);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Tests whether the point mass computation operations fails due to a stack overflow when
|
||||
/// an automaton becomes sufficiently large.
|
||||
/// </summary>
|
||||
[Fact]
|
||||
[Trait("Category", "StringInference")]
|
||||
public void TryComputePointLargeAutomaton()
|
||||
{
|
||||
using (var unlimited = new StringAutomaton.UnlimitedStatesComputation())
|
||||
{
|
||||
const int StateCount = 100_000;
|
||||
|
||||
var builder = new StringAutomaton.Builder();
|
||||
var state = builder.Start;
|
||||
|
||||
for (var i = 1; i < StateCount; ++i)
|
||||
{
|
||||
state = state.AddTransition('a', Weight.One);
|
||||
}
|
||||
|
||||
state.SetEndWeight(Weight.One);
|
||||
|
||||
var automaton = builder.GetAutomaton();
|
||||
var point = new string('a', StateCount - 1);
|
||||
|
||||
Assert.True(automaton.TryComputePoint() == point);
|
||||
StringInferenceTestUtilities.TestValue(automaton, 1.0, point);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Tests whether the point mass computation operations fails due to a stack overflow when an automaton becomes sufficiently large.
|
||||
/// </summary>
|
||||
[Fact]
|
||||
[Trait("Category", "StringInference")]
|
||||
[Trait("Category", "OpenBug")]
|
||||
public void TryComputePointLargeAutomaton()
|
||||
public void SetToProductLargeAutomaton()
|
||||
{
|
||||
//// Fails with ~2500 states due to stack overflow
|
||||
//// Fails on MacOS 64-bit with 750 states due to stack overflow
|
||||
int stateCount = Environment.Is64BitProcess ? 600 : 1500; // Stack frames are larger on 64bit
|
||||
Debug.Assert(stateCount <= StringAutomaton.MaxStateCount, "MaxStateCount must be adjusted first.");
|
||||
|
||||
var builder = new StringAutomaton.Builder();
|
||||
var state = builder.Start;
|
||||
|
||||
for (int i = 1; i < stateCount; ++i)
|
||||
using (var unlimited = new StringAutomaton.UnlimitedStatesComputation())
|
||||
{
|
||||
state = state.AddTransition('a', Weight.One);
|
||||
const int StateCount = 100_000;
|
||||
|
||||
var builder = new StringAutomaton.Builder();
|
||||
var state = builder.Start;
|
||||
|
||||
for (var i = 1; i < StateCount; ++i)
|
||||
{
|
||||
state = state.AddTransition('a', Weight.One);
|
||||
}
|
||||
|
||||
state.SetEndWeight(Weight.One);
|
||||
|
||||
var automaton1 = builder.GetAutomaton();
|
||||
var automaton2 = builder.GetAutomaton();
|
||||
var point = new string('a', StateCount - 1);
|
||||
|
||||
var productAutomaton = StringAutomaton.Product(automaton1, automaton2);
|
||||
StringInferenceTestUtilities.TestValue(productAutomaton, 1.0, point);
|
||||
}
|
||||
}
|
||||
|
||||
state.SetEndWeight(Weight.One);
|
||||
/// <summary>
|
||||
/// Tests whether the point mass computation operations fails due to a stack overflow when an automaton becomes sufficiently large.
|
||||
/// </summary>
|
||||
[Fact]
|
||||
[Trait("Category", "StringInference")]
|
||||
public void GetLogNormalizerLargeAutomaton()
|
||||
{
|
||||
using (var unlimited = new StringAutomaton.UnlimitedStatesComputation())
|
||||
{
|
||||
const int StateCount = 100_000;
|
||||
|
||||
var automaton = builder.GetAutomaton();
|
||||
string point = new string('a', stateCount - 1);
|
||||
|
||||
Assert.True(automaton.TryComputePoint() == point);
|
||||
StringInferenceTestUtilities.TestValue(automaton, 1.0, point);
|
||||
var builder = new StringAutomaton.Builder();
|
||||
var state = builder.Start;
|
||||
|
||||
for (var i = 1; i < StateCount; ++i)
|
||||
{
|
||||
state = state.AddTransition('a', Weight.One);
|
||||
}
|
||||
|
||||
state.SetEndWeight(Weight.One);
|
||||
|
||||
var logNormalizer = builder.GetAutomaton().GetLogNormalizer();
|
||||
|
||||
Assert.Equal(0.0, logNormalizer);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -1282,7 +1339,7 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
var automaton = builder.GetAutomaton();
|
||||
|
||||
for (int i = 0; i < 3; ++i)
|
||||
{
|
||||
{
|
||||
StringInferenceTestUtilities.TestValue(automaton, 2.0, "ab");
|
||||
StringInferenceTestUtilities.TestValue(automaton, 3.0, "adc", "adddc", "ac");
|
||||
StringInferenceTestUtilities.TestValue(automaton, 0.0, "adb", "ad", string.Empty);
|
||||
|
|
|
@ -536,5 +536,38 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
Assert.Equal(referenceValue2, transpose1.GetValue(valuePair[0], valuePair[1]));
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Tests whether the point mass computation operations fails due to a stack overflow when an automaton becomes sufficiently large.
|
||||
/// </summary>
|
||||
[Fact]
|
||||
[Trait("Category", "StringInference")]
|
||||
public void ProjectSourceLargeAutomaton()
|
||||
{
|
||||
using (var unlimited = new StringAutomaton.UnlimitedStatesComputation())
|
||||
{
|
||||
const int StateCount = 100_000;
|
||||
|
||||
var builder = new StringAutomaton.Builder();
|
||||
var state = builder.Start;
|
||||
|
||||
for (var i = 1; i < StateCount; ++i)
|
||||
{
|
||||
state = state.AddTransition('a', Weight.One);
|
||||
}
|
||||
|
||||
state.SetEndWeight(Weight.One);
|
||||
|
||||
var automaton = builder.GetAutomaton();
|
||||
var point = new string('a', StateCount - 1);
|
||||
var copyTransducer = StringTransducer.Copy();
|
||||
|
||||
var projectedAutomaton = copyTransducer.ProjectSource(automaton);
|
||||
var projectedPoint = copyTransducer.ProjectSource(point);
|
||||
|
||||
StringInferenceTestUtilities.TestValue(projectedAutomaton, 1.0, point);
|
||||
StringInferenceTestUtilities.TestValue(projectedPoint, 1.0, point);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -910,13 +910,13 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
/// <param name="eps">Precision.</param>
|
||||
public static void Equal(double expected, double observed, double eps)
|
||||
{
|
||||
// Infinty check
|
||||
// Infinity check
|
||||
if (expected == observed)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
Assert.True(Math.Abs(expected - observed) < eps, $"Equality failure\n. Expected: {expected}\nActual: {observed}");
|
||||
Assert.True(Math.Abs(expected - observed) < eps, $"Equality failure.\nExpected: {expected}\nActual: {observed}");
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
|
Загрузка…
Ссылка в новой задаче