зеркало из https://github.com/dotnet/infer.git
Automaton determinization improvements (#144)
* Determinization doesn't change language of the automaton anymore. 2 new tests were added which check that it doesn't happen. Previously all very low probability transitions (with probability of less than e^-35) were removed. That was done because of 2 reasons: - some probabilities were represented in linear space. And e^-35 is as low resolution as you can get with regular doubles. (That is, if one probability = 1 and another is e^-35, then when you add them together, second one is indistinguishable from zero). This was fixed when discrete charprobabilities - Trying to determinize some non-determinizable automata lead to explosion of low-probability states and transitions which led to a very poor performance. (E.g. `AutomatonNormalizationPerformance3` test). Now a smarter strategy for detecting these non-determinizable automata is used - during traversal all sets of states from root are remembered. If automaton comes to the same set of states but with different weights than it stops immediately, because otherwise it will be caught in infinite loop * `Equals()` and `GetHashCode()` for `WeightedStateSet` take into account only high 32 bits of weight. This coupled with normalization of weights allows to reuse already added states with very close weights. This speeds up the "PropertyInferencePerformanceTest" 2.5x due to smaller intermediate automata. * Weighted sets of size one are handled specially in `TryDeterminize` - they don't need to be determinized and can be copied into result almost as is. (Unless they have non-deterministic transitions. Simple heuristic of "has different destination states" is used to detect that and fallback to slow general path). * Representation for `WeightedStateSet` is changed from (int -> float) dictionary to sorted array of (int, float) pairs. As an optimization, a common case of single-element set does not allocate any arrays. * Determinization code for `ListAutomaton` was removed, because it has never worked
This commit is contained in:
Родитель
526bfefd3b
Коммит
e486a155b9
|
@ -2,13 +2,12 @@
|
|||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
using System.Diagnostics;
|
||||
|
||||
namespace Microsoft.ML.Probabilistic.Collections
|
||||
{
|
||||
using System;
|
||||
using System.Collections;
|
||||
using System.Collections.Generic;
|
||||
using System.Diagnostics;
|
||||
using System.Runtime.Serialization;
|
||||
|
||||
using Microsoft.ML.Probabilistic.Serialization;
|
||||
|
|
|
@ -500,11 +500,12 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
/// <summary>
|
||||
/// Sets a new end weight for this state.
|
||||
/// </summary>
|
||||
public void SetEndWeight(Weight weight)
|
||||
public StateBuilder SetEndWeight(Weight weight)
|
||||
{
|
||||
var state = this.builder.states[this.Index];
|
||||
state.EndWeight = weight;
|
||||
this.builder.states[this.Index] = state;
|
||||
return this;
|
||||
}
|
||||
|
||||
#region AddTransition variants
|
||||
|
|
|
@ -2,6 +2,9 @@
|
|||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
using System.Diagnostics;
|
||||
using Microsoft.ML.Probabilistic.Collections;
|
||||
|
||||
namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
||||
{
|
||||
using System;
|
||||
|
@ -45,61 +48,51 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
return false;
|
||||
}
|
||||
|
||||
// Weighted state set is a set of (stateId, weight) pairs, where state ids correspond to states of the original automaton..
|
||||
// Such pairs correspond to states of the resulting automaton.
|
||||
var weightedStateSetQueue = new Queue<Determinization.WeightedStateSet>();
|
||||
var weightedStateSetToNewState = new Dictionary<Determinization.WeightedStateSet, int>();
|
||||
var builder = new Builder();
|
||||
|
||||
var startWeightedStateSet = new Determinization.WeightedStateSet { { this.Start.Index, Weight.One } };
|
||||
weightedStateSetQueue.Enqueue(startWeightedStateSet);
|
||||
weightedStateSetToNewState.Add(startWeightedStateSet, builder.StartStateIndex);
|
||||
builder.Start.SetEndWeight(this.Start.EndWeight);
|
||||
|
||||
while (weightedStateSetQueue.Count > 0)
|
||||
var weightedStateSetStack = new Stack<(bool enter, Determinization.WeightedStateSet set)>();
|
||||
var enqueuedWeightedStateSetStack = new Stack<(bool enter, Determinization.WeightedStateSet set)>();
|
||||
var weightedStateSetToNewState = new Dictionary<Determinization.WeightedStateSet, int>();
|
||||
// This hash set is used to track sets currently in path from root. If we've found a set of states
|
||||
// that we have already seen during current path from root, but weights are different, that means
|
||||
// we've found a non-converging loop - infinite number of weighed sets will be generated if
|
||||
// we continue traversal and determinization will fail. For performance reasons we want to fail
|
||||
// fast if such loop is found.
|
||||
var stateSetsInPath = new Dictionary<Determinization.WeightedStateSet, Determinization.WeightedStateSet>(
|
||||
Determinization.WeightedStateSetOnlyStateComparer.Instance);
|
||||
|
||||
var startWeightedStateSet = new Determinization.WeightedStateSet(this.Start.Index);
|
||||
weightedStateSetStack.Push((true, startWeightedStateSet));
|
||||
weightedStateSetToNewState.Add(startWeightedStateSet, builder.StartStateIndex);
|
||||
|
||||
while (weightedStateSetStack.Count > 0)
|
||||
{
|
||||
// Take one unprocessed state of the resulting automaton
|
||||
Determinization.WeightedStateSet currentWeightedStateSet = weightedStateSetQueue.Dequeue();
|
||||
var currentStateIndex = weightedStateSetToNewState[currentWeightedStateSet];
|
||||
var currentState = builder[currentStateIndex];
|
||||
var (enter, currentWeightedStateSet) = weightedStateSetStack.Pop();
|
||||
|
||||
// Find out what transitions we should add for this state
|
||||
var outgoingTransitionInfos = this.GetOutgoingTransitionsForDeterminization(currentWeightedStateSet);
|
||||
|
||||
// For each transition to add
|
||||
foreach ((TElementDistribution, Weight, Determinization.WeightedStateSet) outgoingTransitionInfo in outgoingTransitionInfos)
|
||||
if (enter)
|
||||
{
|
||||
TElementDistribution elementDistribution = outgoingTransitionInfo.Item1;
|
||||
Weight weight = outgoingTransitionInfo.Item2;
|
||||
Determinization.WeightedStateSet destWeightedStateSet = outgoingTransitionInfo.Item3;
|
||||
|
||||
int destinationStateIndex;
|
||||
if (!weightedStateSetToNewState.TryGetValue(destWeightedStateSet, out destinationStateIndex))
|
||||
if (currentWeightedStateSet.Count > 1)
|
||||
{
|
||||
if (builder.StatesCount == maxStatesBeforeStop)
|
||||
// Only sets with more than 1 state can lead to infinite loops with different weights.
|
||||
// Because if there's only 1 state, than it's weight is always Weight.One.
|
||||
if (!stateSetsInPath.ContainsKey(currentWeightedStateSet))
|
||||
{
|
||||
// Too many states, determinization attempt failed
|
||||
return false;
|
||||
stateSetsInPath.Add(currentWeightedStateSet, currentWeightedStateSet);
|
||||
}
|
||||
|
||||
// Add new state to the result
|
||||
var destinationState = builder.AddState();
|
||||
weightedStateSetToNewState.Add(destWeightedStateSet, destinationState.Index);
|
||||
weightedStateSetQueue.Enqueue(destWeightedStateSet);
|
||||
|
||||
// Compute its ending weight
|
||||
destinationState.SetEndWeight(Weight.Zero);
|
||||
foreach (KeyValuePair<int, Weight> stateIdWithWeight in destWeightedStateSet)
|
||||
{
|
||||
var addedWeight = stateIdWithWeight.Value * this.States[stateIdWithWeight.Key].EndWeight;
|
||||
destinationState.SetEndWeight(destinationState.EndWeight + addedWeight);
|
||||
}
|
||||
|
||||
destinationStateIndex = destinationState.Index;
|
||||
weightedStateSetStack.Push((false, currentWeightedStateSet));
|
||||
}
|
||||
|
||||
// Add transition to the destination state
|
||||
currentState.AddTransition(elementDistribution, weight, destinationStateIndex);
|
||||
if (!EnqueueOutgoingTransitions(currentWeightedStateSet))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
stateSetsInPath.Remove(currentWeightedStateSet);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -111,6 +104,137 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
this.LogValueOverride = this.LogValueOverride;
|
||||
|
||||
return true;
|
||||
|
||||
bool EnqueueOutgoingTransitions(Determinization.WeightedStateSet currentWeightedStateSet)
|
||||
{
|
||||
var currentStateIndex = weightedStateSetToNewState[currentWeightedStateSet];
|
||||
var currentState = builder[currentStateIndex];
|
||||
|
||||
// Common special-case: definitely deterministic transitions from single state.
|
||||
// In this case no complicated determinization procedure is needed.
|
||||
if (currentWeightedStateSet.Count == 1 &&
|
||||
AllDestinationsAreSame(currentWeightedStateSet[0].Index))
|
||||
{
|
||||
Debug.Assert(currentWeightedStateSet[0].Weight == Weight.One);
|
||||
|
||||
var sourceState = this.States[currentWeightedStateSet[0].Index];
|
||||
foreach (var transition in sourceState.Transitions)
|
||||
{
|
||||
var destinationStates = new Determinization.WeightedStateSet(transition.DestinationStateIndex);
|
||||
var outgoingTransitionInfo = new Determinization.OutgoingTransition(
|
||||
transition.ElementDistribution.Value, transition.Weight, destinationStates);
|
||||
if (!TryAddTransition(enqueuedWeightedStateSetStack, outgoingTransitionInfo, currentState))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Find out what transitions we should add for this state
|
||||
var outgoingTransitions =
|
||||
this.GetOutgoingTransitionsForDeterminization(currentWeightedStateSet);
|
||||
foreach (var outgoingTransition in outgoingTransitions)
|
||||
{
|
||||
if (!TryAddTransition(enqueuedWeightedStateSetStack, outgoingTransition, currentState))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
while (enqueuedWeightedStateSetStack.Count > 0)
|
||||
{
|
||||
weightedStateSetStack.Push(enqueuedWeightedStateSetStack.Pop());
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Checks that all transitions from state end up in the same destination. This is used
|
||||
// as a very fast "is determenistic" check, that doesn't care about distributions.
|
||||
// State can have determenistic transitions with different destinations. This case will be
|
||||
// handled by slow path.
|
||||
bool AllDestinationsAreSame(int stateIndex)
|
||||
{
|
||||
var transitions = this.States[stateIndex].Transitions;
|
||||
if (transitions.Count <= 1)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
var destination = transitions[0].DestinationStateIndex;
|
||||
for (var i = 1; i < transitions.Count; ++i)
|
||||
{
|
||||
if (transitions[i].DestinationStateIndex != destination)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Adds transition from currentState into state corresponding to weighted state set from
|
||||
// outgoingTransitionInfo. If that state does not exist yet it is created and is put into stack
|
||||
// for further processing. This function returns false if determinization has failed.
|
||||
// That can happen because of 2 ressons:
|
||||
// - Too many states were created and its not feasible to continue trying to determinize
|
||||
// automaton further
|
||||
// - An infinite loop with not converging weights was found. It leads to infinite number of states.
|
||||
// So determinization is aborted early.
|
||||
bool TryAddTransition(
|
||||
Stack<(bool enter, Determinization.WeightedStateSet set)> destinationStack,
|
||||
Determinization.OutgoingTransition transition,
|
||||
Builder.StateBuilder currentState)
|
||||
{
|
||||
var destinations = transition.Destinations;
|
||||
if (!weightedStateSetToNewState.TryGetValue(destinations, out var destinationStateIndex))
|
||||
{
|
||||
if (builder.StatesCount == maxStatesBeforeStop)
|
||||
{
|
||||
// Too many states, determinization attempt failed
|
||||
return false;
|
||||
}
|
||||
|
||||
var visitedWeightedStateSet = default(Determinization.WeightedStateSet);
|
||||
var sameSetVisited =
|
||||
destinations.Count > 1 &&
|
||||
stateSetsInPath.TryGetValue(destinations, out visitedWeightedStateSet);
|
||||
|
||||
if (sameSetVisited && !destinations.Equals(visitedWeightedStateSet))
|
||||
{
|
||||
// We arrived into the same state set as before, but with different weights.
|
||||
// This is an infinite non-converging loop. Determinization has failed
|
||||
return false;
|
||||
}
|
||||
|
||||
// Add new state to the result
|
||||
var destinationState = builder.AddState();
|
||||
weightedStateSetToNewState.Add(destinations, destinationState.Index);
|
||||
destinationStack.Push((true, destinations));
|
||||
|
||||
if (destinations.Count > 1 && !sameSetVisited)
|
||||
{
|
||||
destinationStack.Push((false, destinations));
|
||||
}
|
||||
|
||||
// Compute its ending weight
|
||||
destinationState.SetEndWeight(Weight.Zero);
|
||||
for (var i = 0; i < destinations.Count; ++i)
|
||||
{
|
||||
var weightedState = destinations[i];
|
||||
var addedWeight = weightedState.Weight * this.States[weightedState.Index].EndWeight;
|
||||
destinationState.SetEndWeight(destinationState.EndWeight + addedWeight);
|
||||
}
|
||||
|
||||
destinationStateIndex = destinationState.Index;
|
||||
}
|
||||
|
||||
// Add transition to the destination state
|
||||
currentState.AddTransition(transition.ElementDistribution, transition.Weight, destinationStateIndex);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -124,82 +248,121 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
/// The first two elements of a tuple define the element distribution and the weight of a transition.
|
||||
/// The third element defines the outgoing state.
|
||||
/// </returns>
|
||||
protected abstract List<(TElementDistribution, Weight, Determinization.WeightedStateSet)> GetOutgoingTransitionsForDeterminization(
|
||||
protected abstract IEnumerable<Determinization.OutgoingTransition> GetOutgoingTransitionsForDeterminization(
|
||||
Determinization.WeightedStateSet sourceState);
|
||||
|
||||
|
||||
/// <summary>
|
||||
/// Groups together helper classes used for automata determinization.
|
||||
/// </summary>
|
||||
protected static class Determinization
|
||||
{
|
||||
public struct OutgoingTransition
|
||||
{
|
||||
public TElementDistribution ElementDistribution { get; }
|
||||
public Weight Weight { get; }
|
||||
public WeightedStateSet Destinations { get; }
|
||||
|
||||
public OutgoingTransition(
|
||||
TElementDistribution elementDistribution, Weight weight, WeightedStateSet destinations)
|
||||
{
|
||||
this.ElementDistribution = elementDistribution;
|
||||
this.Weight = weight;
|
||||
this.Destinations = destinations;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Represents
|
||||
/// </summary>
|
||||
public struct WeightedState : IComparable, IComparable<WeightedState>
|
||||
{
|
||||
/// <summary>
|
||||
/// Index of the state.
|
||||
/// </summary>
|
||||
public int Index { get; }
|
||||
|
||||
/// <summary>
|
||||
/// High bits of the state weight. Only these bits are used when comparing weighted states
|
||||
/// for equality or when calculating hash code.
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// Using high bits for hash code allows to have "fuzzy" matching of WeightedStateSets.
|
||||
/// Sometimes there's more than one way to get to the same weighted state set in automaton,
|
||||
/// but due to using floating point numbers calculated weights are slightly off. Using
|
||||
/// only high 32 bits for comparison means that 20 bits of mantissa are used. Which means that
|
||||
/// difference between weights (in log space) is no more than ~~ 10e-6 which is a sufficiently
|
||||
/// good precision for all practical purposes.
|
||||
/// </remarks>
|
||||
public int WeightHighBits { get; }
|
||||
|
||||
public Weight Weight { get; }
|
||||
|
||||
public WeightedState(int index, Weight weight)
|
||||
{
|
||||
this.Index = index;
|
||||
this.WeightHighBits = (int)(BitConverter.DoubleToInt64Bits(weight.LogValue) >> 32);
|
||||
this.Weight = weight;
|
||||
}
|
||||
|
||||
public int CompareTo(object obj)
|
||||
{
|
||||
return obj is WeightedState that
|
||||
? this.CompareTo(that)
|
||||
: throw new InvalidOperationException(
|
||||
"WeightedState can be compared only to another WeightedState");
|
||||
}
|
||||
|
||||
public int CompareTo(WeightedState that) => Index.CompareTo(that.Index);
|
||||
|
||||
public override int GetHashCode() => (Index ^ WeightHighBits).GetHashCode();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Represents a state of the resulting automaton in the power set construction.
|
||||
/// It is essentially a set of (stateId, weight) pairs of the source automaton, where each state id is unique.
|
||||
/// Supports a quick lookup of the weight by state id.
|
||||
/// </summary>
|
||||
public class WeightedStateSet : IEnumerable<KeyValuePair<int, Weight>>
|
||||
public struct WeightedStateSet
|
||||
{
|
||||
/// <summary>
|
||||
/// A mapping from state ids to weights.
|
||||
/// A mapping from state ids to weights. This array is sorted by state Id.
|
||||
/// </summary>
|
||||
private readonly Dictionary<int, Weight> stateIdToWeight;
|
||||
private readonly ReadOnlyArray<WeightedState> weightedStates;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="WeightedStateSet"/> class.
|
||||
/// </summary>
|
||||
public WeightedStateSet() =>
|
||||
this.stateIdToWeight = new Dictionary<int, Weight>();
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="WeightedStateSet"/> class.
|
||||
/// </summary>
|
||||
/// <param name="stateIdToWeight">A collection of (stateId, weight) pairs.
|
||||
/// </param>
|
||||
public WeightedStateSet(IEnumerable<KeyValuePair<int, Weight>> stateIdToWeight) =>
|
||||
this.stateIdToWeight = stateIdToWeight.ToDictionary(kv => kv.Key, kv => kv.Value);
|
||||
private readonly int singleStateIndex;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the weight for a given state id.
|
||||
/// </summary>
|
||||
/// <param name="stateId">The state id.</param>
|
||||
/// <returns>The weight.</returns>
|
||||
public Weight this[int stateId]
|
||||
public WeightedStateSet(int stateIndex)
|
||||
{
|
||||
get => this.stateIdToWeight[stateId];
|
||||
set => this.stateIdToWeight[stateId] = value;
|
||||
this.weightedStates = null;
|
||||
this.singleStateIndex = stateIndex;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Adds a given state id and a weight to the set.
|
||||
/// </summary>
|
||||
/// <param name="stateId">The state id.</param>
|
||||
/// <param name="weight">The weight.</param>
|
||||
public void Add(int stateId, Weight weight) =>
|
||||
this.stateIdToWeight.Add(stateId, weight);
|
||||
public WeightedStateSet(ReadOnlyArray<WeightedState> weightedStates)
|
||||
{
|
||||
Debug.Assert(weightedStates.Count > 0);
|
||||
Debug.Assert(IsSorted(weightedStates));
|
||||
if (weightedStates.Count == 1)
|
||||
{
|
||||
Debug.Assert(weightedStates[0].Weight == Weight.One);
|
||||
this.weightedStates = null;
|
||||
this.singleStateIndex = weightedStates[0].Index;
|
||||
}
|
||||
else
|
||||
{
|
||||
this.weightedStates = weightedStates;
|
||||
this.singleStateIndex = 0; // <- value doesn't matter, but silences the compiler
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Attempts to retrieve the weight corresponding to a given state id from the set.
|
||||
/// </summary>
|
||||
/// <param name="stateId">The state id.</param>
|
||||
/// <param name="weight">When the method returns, contains the retrieved weight.</param>
|
||||
/// <returns>
|
||||
/// <see langword="true"/> if the given state id was present in the set,
|
||||
/// <see langword="false"/> otherwise.
|
||||
/// </returns>
|
||||
public bool TryGetWeight(int stateId, out Weight weight) =>
|
||||
this.stateIdToWeight.TryGetValue(stateId, out weight);
|
||||
public int Count =>
|
||||
this.weightedStates.IsNull
|
||||
? 1
|
||||
: this.weightedStates.Count;
|
||||
|
||||
/// <summary>
|
||||
/// Checks whether the state with a given id is present in the set.
|
||||
/// </summary>
|
||||
/// <param name="stateId">The state id,</param>
|
||||
/// <returns>
|
||||
/// <see langword="true"/> if the given state id was present in the set,
|
||||
/// <see langword="false"/> otherwise.
|
||||
/// </returns>
|
||||
public bool ContainsState(int stateId) =>
|
||||
this.stateIdToWeight.ContainsKey(stateId);
|
||||
public WeightedState this[int index] =>
|
||||
this.weightedStates.IsNull
|
||||
? new WeightedState(this.singleStateIndex, Weight.One)
|
||||
: this.weightedStates[index];
|
||||
|
||||
/// <summary>
|
||||
/// Checks whether this object is equal to a given one.
|
||||
|
@ -209,25 +372,28 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
/// <see langword="true"/> if the objects are equal,
|
||||
/// <see langword="false"/> otherwise.
|
||||
/// </returns>
|
||||
public override bool Equals(object obj)
|
||||
public override bool Equals(object obj) => obj is WeightedStateSet that && this.Equals(that);
|
||||
|
||||
/// <summary>
|
||||
/// Checks whether this object is equal to a given one.
|
||||
/// </summary>
|
||||
/// <param name="that">The object to compare this object with.</param>
|
||||
/// <returns>
|
||||
/// <see langword="true"/> if the objects are equal,
|
||||
/// <see langword="false"/> otherwise.
|
||||
/// </returns>
|
||||
public bool Equals(WeightedStateSet that)
|
||||
{
|
||||
if (obj == null || obj.GetType() != typeof(WeightedStateSet))
|
||||
if (this.Count != that.Count)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
var other = (WeightedStateSet)obj;
|
||||
|
||||
if (this.stateIdToWeight.Count != other.stateIdToWeight.Count)
|
||||
for (var i = 0; i < this.Count; ++i)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
foreach (KeyValuePair<int, Weight> pair in this.stateIdToWeight)
|
||||
{
|
||||
// TODO: Should we allow for some tolerance? But what about hashing then?
|
||||
Weight otherWeight;
|
||||
if (!other.stateIdToWeight.TryGetValue(pair.Key, out otherWeight) || otherWeight != pair.Value)
|
||||
var state1 = this[i];
|
||||
var state2 = that[i];
|
||||
if (state1.Index != state2.Index || state1.WeightHighBits != state2.WeightHighBits)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
@ -240,17 +406,13 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
/// Computes the hash code of this instance.
|
||||
/// </summary>
|
||||
/// <returns>The computed hash code.</returns>
|
||||
/// <remarks>Only state ids</remarks>
|
||||
public override int GetHashCode()
|
||||
{
|
||||
int result = 0;
|
||||
foreach (KeyValuePair<int, Weight> pair in this.stateIdToWeight)
|
||||
var result = this[0].GetHashCode();
|
||||
for (var i = 1; i < this.Count; ++i)
|
||||
{
|
||||
int pairHash = Hash.Start;
|
||||
pairHash = Hash.Combine(pairHash, pair.Key.GetHashCode());
|
||||
pairHash = Hash.Combine(pairHash, pair.Value.GetHashCode());
|
||||
|
||||
// Use commutative hashing combination because dictionaries are not ordered
|
||||
result ^= pairHash;
|
||||
result = Hash.Combine(result, this[i].GetHashCode());
|
||||
}
|
||||
|
||||
return result;
|
||||
|
@ -260,38 +422,128 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
/// Returns a string representation of the instance.
|
||||
/// </summary>
|
||||
/// <returns>A string representation of the instance.</returns>
|
||||
public override string ToString()
|
||||
public override string ToString() => string.Join(", ", weightedStates);
|
||||
|
||||
/// <summary>
|
||||
/// Turns weighted state set into an array. This is convenient for writing LINQ queries
|
||||
/// in tests.
|
||||
/// </summary>
|
||||
public WeightedState[] ToArray()
|
||||
{
|
||||
StringBuilder builder = new StringBuilder();
|
||||
foreach (var kvp in this.stateIdToWeight)
|
||||
var result = new WeightedState[this.Count];
|
||||
for (var i = 0; i < this.Count; ++i)
|
||||
{
|
||||
builder.AppendLine(kvp.ToString());
|
||||
result[i] = this[i];
|
||||
}
|
||||
|
||||
return builder.ToString();
|
||||
return result;
|
||||
}
|
||||
|
||||
#region IEnumerable implementation
|
||||
|
||||
/// <summary>
|
||||
/// Gets the enumerator.
|
||||
/// Checks weather states array is sorted in ascending order by Index.
|
||||
/// </summary>
|
||||
/// <returns>
|
||||
/// The enumerator.
|
||||
/// </returns>
|
||||
public IEnumerator<KeyValuePair<int, Weight>> GetEnumerator() =>
|
||||
this.stateIdToWeight.GetEnumerator();
|
||||
private static bool IsSorted(ReadOnlyArray<WeightedState> array)
|
||||
{
|
||||
for (var i = 1; i < array.Count; ++i)
|
||||
{
|
||||
if (array[i].Index <= array[i - 1].Index)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets the enumerator.
|
||||
/// </summary>
|
||||
/// <returns>
|
||||
/// The enumerator.
|
||||
/// </returns>
|
||||
IEnumerator IEnumerable.GetEnumerator() =>
|
||||
this.GetEnumerator();
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
#endregion
|
||||
/// <summary>
|
||||
/// Builder for weighted sets.
|
||||
/// </summary>
|
||||
public struct WeightedStateSetBuilder
|
||||
{
|
||||
private List<WeightedState> weightedStates;
|
||||
|
||||
public static WeightedStateSetBuilder Create() =>
|
||||
new WeightedStateSetBuilder()
|
||||
{
|
||||
weightedStates = new List<WeightedState>(1),
|
||||
};
|
||||
|
||||
public void Add(int index, Weight weight) =>
|
||||
this.weightedStates.Add(new WeightedState(index, weight));
|
||||
|
||||
public void Reset() => this.weightedStates.Clear();
|
||||
|
||||
public (WeightedStateSet, Weight) Get()
|
||||
{
|
||||
Debug.Assert(this.weightedStates.Count > 0);
|
||||
|
||||
var sortedStates = this.weightedStates.ToArray();
|
||||
if (sortedStates.Length == 1)
|
||||
{
|
||||
var state = sortedStates[0];
|
||||
sortedStates[0] = new WeightedState(state.Index, Weight.One);
|
||||
return (new WeightedStateSet(sortedStates), state.Weight);
|
||||
}
|
||||
else
|
||||
{
|
||||
Array.Sort(sortedStates);
|
||||
|
||||
var maxWeight = sortedStates[0].Weight;
|
||||
for (var i = 1; i < sortedStates.Length; ++i)
|
||||
{
|
||||
if (sortedStates[i].Weight > maxWeight)
|
||||
{
|
||||
maxWeight = sortedStates[i].Weight;
|
||||
}
|
||||
}
|
||||
|
||||
var normalizer = Weight.Inverse(maxWeight);
|
||||
|
||||
for (var i = 0; i < sortedStates.Length; ++i)
|
||||
{
|
||||
var state = sortedStates[i];
|
||||
sortedStates[i] = new WeightedState(state.Index, state.Weight * normalizer);
|
||||
}
|
||||
|
||||
return (new WeightedStateSet(sortedStates), maxWeight);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public class WeightedStateSetOnlyStateComparer : IEqualityComparer<WeightedStateSet>
|
||||
{
|
||||
public static readonly WeightedStateSetOnlyStateComparer Instance =
|
||||
new WeightedStateSetOnlyStateComparer();
|
||||
|
||||
public bool Equals(WeightedStateSet x, WeightedStateSet y)
|
||||
{
|
||||
if (x.Count != y.Count)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
for (var i = 0; i < x.Count; ++i)
|
||||
{
|
||||
if (x[i].Index != y[i].Index)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
public int GetHashCode(WeightedStateSet set)
|
||||
{
|
||||
var result = set[0].Index.GetHashCode();
|
||||
for (var i = 1; i < set.Count; ++i)
|
||||
{
|
||||
result = Hash.Combine(result, set[i].Index.GetHashCode());
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -53,7 +53,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
/// The first two elements of a tuple define the element distribution and the weight of a transition.
|
||||
/// The third element defines the outgoing state.
|
||||
/// </returns>
|
||||
protected override List<(TElementDistribution, Weight, Determinization.WeightedStateSet)> GetOutgoingTransitionsForDeterminization(
|
||||
protected override IEnumerable<Determinization.OutgoingTransition> GetOutgoingTransitionsForDeterminization(
|
||||
Determinization.WeightedStateSet sourceState)
|
||||
{
|
||||
throw new NotImplementedException("Determinization is not yet supported for this type of automata.");
|
||||
|
@ -84,106 +84,10 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
/// The first two elements of a tuple define the element distribution and the weight of a transition.
|
||||
/// The third element defines the outgoing state.
|
||||
/// </returns>
|
||||
protected override List<(TElementDistribution, Weight, Determinization.WeightedStateSet)> GetOutgoingTransitionsForDeterminization(
|
||||
protected override IEnumerable<Determinization.OutgoingTransition> GetOutgoingTransitionsForDeterminization(
|
||||
Determinization.WeightedStateSet sourceState)
|
||||
{
|
||||
throw new NotImplementedException();
|
||||
|
||||
//// Build a list of elements, with probabilities
|
||||
//var elementLists = new Dictionary<TElement, List<TransitionElement>>();
|
||||
//var uniformList = new List<TransitionElement>();
|
||||
//foreach (KeyValuePair<int, Weight> stateIdWeight in sourceState)
|
||||
//{
|
||||
// var state = this.States[stateIdWeight.Key];
|
||||
// for (int i = 0; i < state.TransitionCount; ++i)
|
||||
// {
|
||||
// AddTransitionElements(state.GetTransition(i), stateIdWeight.Value, elementLists, uniformList);
|
||||
// }
|
||||
//}
|
||||
|
||||
//// Produce an outgoing transition for each unique subset of overlapping segments
|
||||
//var results = new List<Tuple<TElementDistribution, Weight, Determinization.WeightedStateSet>>();
|
||||
|
||||
//foreach (var kvp in elementLists)
|
||||
//{
|
||||
// AddResult(results, kvp.Value);
|
||||
//}
|
||||
//AddResult(results, uniformList);
|
||||
//return results;
|
||||
}
|
||||
|
||||
private static void AddResult(List<Tuple<TElementDistribution, Weight, Determinization.WeightedStateSet>> results,
|
||||
List<TransitionElement> transitionElements)
|
||||
{
|
||||
if (transitionElements.Count == 0) return;
|
||||
const double LogEps = -30; // Don't add transitions with log-weight less than this as they have been produced by numerical inaccuracies
|
||||
|
||||
var elementStatesWeights = new Determinization.WeightedStateSet();
|
||||
var elementStateWeightSum = Weight.Zero;
|
||||
foreach (var element in transitionElements)
|
||||
{
|
||||
if (!elementStatesWeights.TryGetWeight(element.destIndex, out var prevStateWeight))
|
||||
{
|
||||
prevStateWeight = Weight.Zero;
|
||||
}
|
||||
|
||||
elementStatesWeights[element.destIndex] = prevStateWeight + element.weight;
|
||||
elementStateWeightSum += element.weight;
|
||||
}
|
||||
|
||||
var destinationState = new Determinization.WeightedStateSet();
|
||||
foreach (KeyValuePair<int, Weight> stateIdWithWeight in elementStatesWeights)
|
||||
{
|
||||
if (stateIdWithWeight.Value.LogValue > LogEps)
|
||||
{
|
||||
Weight stateWeight = stateIdWithWeight.Value * Weight.Inverse(elementStateWeightSum);
|
||||
destinationState.Add(stateIdWithWeight.Key, stateWeight);
|
||||
}
|
||||
}
|
||||
|
||||
Weight transitionWeight = elementStateWeightSum;
|
||||
results.Add(Tuple.Create(transitionElements[0].distribution, transitionWeight, destinationState));
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Given a transition and the residual weight of its source state, adds weighted non-zero probability elements
|
||||
/// associated with the transition to the list.
|
||||
/// </summary>
|
||||
/// <param name="transition">The transition.</param>
|
||||
/// <param name="sourceStateResidualWeight">The logarithm of the residual weight of the source state of the transition.</param>
|
||||
/// <param name="elements">The list for storing transition elements.</param>
|
||||
/// <param name="uniformList">The list for storing transition elements for uniform.</param>
|
||||
private static void AddTransitionElements(
|
||||
Transition transition, Weight sourceStateResidualWeight,
|
||||
Dictionary<TElement, List<TransitionElement>> elements, List<TransitionElement> uniformList)
|
||||
{
|
||||
var dist = transition.ElementDistribution.Value;
|
||||
Weight weightBase = transition.Weight * sourceStateResidualWeight;
|
||||
if (dist.IsPointMass)
|
||||
{
|
||||
var pt = dist.Point;
|
||||
// todo: enumerate distribution
|
||||
if (!elements.ContainsKey(pt)) elements[pt] = new List<TransitionElement>();
|
||||
elements[pt].Add(new TransitionElement(transition.DestinationStateIndex, weightBase, dist));
|
||||
}
|
||||
else
|
||||
{
|
||||
uniformList.Add(new TransitionElement(transition.DestinationStateIndex, weightBase, dist));
|
||||
}
|
||||
}
|
||||
|
||||
private class TransitionElement
|
||||
{
|
||||
internal int destIndex;
|
||||
internal Weight weight;
|
||||
internal TElementDistribution distribution;
|
||||
|
||||
internal TransitionElement(int destIndex, Weight weight, TElementDistribution distribution)
|
||||
{
|
||||
this.destIndex = destIndex;
|
||||
this.distribution = distribution;
|
||||
this.weight = weight;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -27,78 +27,62 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
/// <summary>
|
||||
/// Computes a set of outgoing transitions from a given state of the determinization result.
|
||||
/// </summary>
|
||||
/// <param name="sourceState">The source state of the determinized automaton represented as
|
||||
/// <param name="sourceStateSet">The source state of the determinized automaton represented as
|
||||
/// a set of (stateId, weight) pairs, where state ids correspond to states of the original automaton.</param>
|
||||
/// <returns>
|
||||
/// A collection of (element distribution, weight, weighted state set) triples corresponding to outgoing transitions from <paramref name="sourceState"/>.
|
||||
/// A collection of (element distribution, weight, weighted state set) triples corresponding to outgoing
|
||||
/// transitions from <paramref name="sourceStateSet"/>.
|
||||
/// The first two elements of a tuple define the element distribution and the weight of a transition.
|
||||
/// The third element defines the outgoing state.
|
||||
/// </returns>
|
||||
protected override List<(DiscreteChar, Weight, Determinization.WeightedStateSet)> GetOutgoingTransitionsForDeterminization(
|
||||
Determinization.WeightedStateSet sourceState)
|
||||
protected override IEnumerable<Determinization.OutgoingTransition> GetOutgoingTransitionsForDeterminization(
|
||||
Determinization.WeightedStateSet sourceStateSet)
|
||||
{
|
||||
const double LogEps = -35; // Don't add transitions with log-weight less than this as they have been produced by numerical inaccuracies
|
||||
|
||||
// Build a list of numbered non-zero probability character segment bounds (they are numbered here due to perf. reasons)
|
||||
var segmentBounds = new List<ValueTuple<int, TransitionCharSegmentBound>>();
|
||||
int transitionsProcessed = 0;
|
||||
foreach (KeyValuePair<int, Weight> stateIdWeight in sourceState)
|
||||
var segmentBounds = new List<TransitionCharSegmentBound>();
|
||||
for (var i = 0; i < sourceStateSet.Count; ++i)
|
||||
{
|
||||
|
||||
var state = this.States[stateIdWeight.Key];
|
||||
var sourceState = sourceStateSet[i];
|
||||
var state = this.States[sourceState.Index];
|
||||
foreach (var transition in state.Transitions)
|
||||
{
|
||||
AddTransitionCharSegmentBounds(transition, stateIdWeight.Value, segmentBounds);
|
||||
AddTransitionCharSegmentBounds(transition, sourceState.Weight, segmentBounds);
|
||||
}
|
||||
|
||||
transitionsProcessed += state.Transitions.Count;
|
||||
}
|
||||
|
||||
// Sort segment bounds left-to-right, start-to-end
|
||||
var sortedIndexedSegmentBounds = segmentBounds.ToArray();
|
||||
if (transitionsProcessed > 1)
|
||||
{
|
||||
Array.Sort(sortedIndexedSegmentBounds, CompareSegmentBounds);
|
||||
|
||||
int CompareSegmentBounds((int, TransitionCharSegmentBound) a, (int, TransitionCharSegmentBound) b) =>
|
||||
a.Item2.CompareTo(b.Item2);
|
||||
}
|
||||
segmentBounds.Sort();
|
||||
|
||||
// Produce an outgoing transition for each unique subset of overlapping segments
|
||||
var result = new List<(DiscreteChar, Weight, Determinization.WeightedStateSet)>();
|
||||
Weight currentSegmentStateWeightSum = Weight.Zero;
|
||||
var currentSegmentTotal = WeightSum.Zero();
|
||||
|
||||
var currentSegmentStateWeights = new Dictionary<int, Weight>();
|
||||
foreach (var sb in segmentBounds)
|
||||
var currentSegmentStateWeights = new Dictionary<int, WeightSum>();
|
||||
var currentSegmentStart = (int)char.MinValue;
|
||||
var destinationStateSetBuilder = Determinization.WeightedStateSetBuilder.Create();
|
||||
foreach (var segmentBound in segmentBounds)
|
||||
{
|
||||
currentSegmentStateWeights[sb.Item2.DestinationStateId] = Weight.Zero;
|
||||
}
|
||||
|
||||
var activeSegments = new HashSet<TransitionCharSegmentBound>();
|
||||
int currentSegmentStart = char.MinValue;
|
||||
foreach (var tup in sortedIndexedSegmentBounds)
|
||||
{
|
||||
TransitionCharSegmentBound segmentBound = tup.Item2;
|
||||
|
||||
if (currentSegmentStateWeightSum.LogValue > LogEps && currentSegmentStart < segmentBound.Bound)
|
||||
if (currentSegmentTotal.Count != 0 && currentSegmentStart < segmentBound.Bound)
|
||||
{
|
||||
// Flush previous segment
|
||||
char segmentEnd = (char)(segmentBound.Bound - 1);
|
||||
int segmentLength = segmentEnd - currentSegmentStart + 1;
|
||||
DiscreteChar elementDist = DiscreteChar.InRange((char)currentSegmentStart, segmentEnd);
|
||||
var segmentEnd = (char)(segmentBound.Bound - 1);
|
||||
var segmentLength = segmentEnd - currentSegmentStart + 1;
|
||||
var elementDist = DiscreteChar.InRange((char)currentSegmentStart, segmentEnd);
|
||||
var invTotalWeight = Weight.Inverse(currentSegmentTotal.Sum);
|
||||
|
||||
var destinationState = new Determinization.WeightedStateSet();
|
||||
foreach (KeyValuePair<int, Weight> stateIdWithWeight in currentSegmentStateWeights)
|
||||
destinationStateSetBuilder.Reset();
|
||||
foreach (var stateIdWithWeight in currentSegmentStateWeights)
|
||||
{
|
||||
if (stateIdWithWeight.Value.LogValue > LogEps)
|
||||
{
|
||||
Weight stateWeight = stateIdWithWeight.Value * Weight.Inverse(currentSegmentStateWeightSum);
|
||||
destinationState.Add(stateIdWithWeight.Key, stateWeight);
|
||||
}
|
||||
var stateWeight = stateIdWithWeight.Value.Sum * invTotalWeight;
|
||||
destinationStateSetBuilder.Add(stateIdWithWeight.Key, stateWeight);
|
||||
}
|
||||
|
||||
Weight transitionWeight = Weight.FromValue(segmentLength) * currentSegmentStateWeightSum;
|
||||
result.Add((elementDist, transitionWeight, destinationState));
|
||||
var (destinationStateSet, destinationStateSetWeight) = destinationStateSetBuilder.Get();
|
||||
|
||||
var transitionWeight = Weight.Product(
|
||||
Weight.FromValue(segmentLength),
|
||||
currentSegmentTotal.Sum,
|
||||
destinationStateSetWeight);
|
||||
yield return new Determinization.OutgoingTransition(
|
||||
elementDist, transitionWeight, destinationStateSet);
|
||||
}
|
||||
|
||||
// Update current segment
|
||||
|
@ -106,39 +90,35 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
|
||||
if (segmentBound.IsStart)
|
||||
{
|
||||
activeSegments.Add(segmentBound);
|
||||
currentSegmentStateWeightSum += segmentBound.Weight;
|
||||
currentSegmentStateWeights[segmentBound.DestinationStateId] += segmentBound.Weight;
|
||||
currentSegmentTotal += segmentBound.Weight;
|
||||
if (currentSegmentStateWeights.TryGetValue(segmentBound.DestinationStateId, out var stateWeight))
|
||||
{
|
||||
currentSegmentStateWeights[segmentBound.DestinationStateId] =
|
||||
stateWeight + segmentBound.Weight;
|
||||
}
|
||||
else
|
||||
{
|
||||
currentSegmentStateWeights[segmentBound.DestinationStateId] = new WeightSum(segmentBound.Weight);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
Debug.Assert(currentSegmentStateWeights.ContainsKey(segmentBound.DestinationStateId), "We shouldn't exit a state we didn't enter.");
|
||||
activeSegments.Remove(segmentBounds[tup.Item1 - 1].Item2); // End follows start in original.
|
||||
if (double.IsInfinity(segmentBound.Weight.Value))
|
||||
Debug.Assert(!segmentBound.Weight.IsInfinity);
|
||||
currentSegmentTotal -= segmentBound.Weight;
|
||||
|
||||
var prevStateWeight = currentSegmentStateWeights[segmentBound.DestinationStateId];
|
||||
var newStateWeight = prevStateWeight - segmentBound.Weight;
|
||||
if (newStateWeight.Count == 0)
|
||||
{
|
||||
// Cannot subtract because of the infinities involved.
|
||||
currentSegmentStateWeightSum =
|
||||
activeSegments
|
||||
.Select(sb => sb.Weight)
|
||||
.Aggregate(Weight.Zero, Weight.Sum);
|
||||
currentSegmentStateWeights[segmentBound.DestinationStateId] =
|
||||
activeSegments
|
||||
.Where(sb => sb.DestinationStateId == segmentBound.DestinationStateId)
|
||||
.Select(sb => sb.Weight)
|
||||
.Aggregate(Weight.Zero, Weight.Sum);
|
||||
currentSegmentStateWeights.Remove(segmentBound.DestinationStateId);
|
||||
}
|
||||
else
|
||||
{
|
||||
currentSegmentStateWeightSum = activeSegments.Count == 0 ? Weight.Zero : Weight.AbsoluteDifference(currentSegmentStateWeightSum, segmentBound.Weight);
|
||||
|
||||
Weight prevStateWeight = currentSegmentStateWeights[segmentBound.DestinationStateId];
|
||||
currentSegmentStateWeights[segmentBound.DestinationStateId] = Weight.AbsoluteDifference(
|
||||
prevStateWeight, segmentBound.Weight);
|
||||
currentSegmentStateWeights[segmentBound.DestinationStateId] = newStateWeight;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -149,41 +129,32 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
/// <param name="sourceStateResidualWeight">The logarithm of the residual weight of the source state of the transition.</param>
|
||||
/// <param name="bounds">The list for storing numbered segment bounds.</param>
|
||||
private static void AddTransitionCharSegmentBounds(
|
||||
Transition transition, Weight sourceStateResidualWeight, List<ValueTuple<int, TransitionCharSegmentBound>> bounds)
|
||||
Transition transition, Weight sourceStateResidualWeight, List<TransitionCharSegmentBound> bounds)
|
||||
{
|
||||
var distribution = transition.ElementDistribution.Value;
|
||||
var ranges = distribution.Ranges;
|
||||
int commonValueStart = char.MinValue;
|
||||
Weight commonValue = distribution.ProbabilityOutsideRanges;
|
||||
Weight weightBase = transition.Weight * sourceStateResidualWeight;
|
||||
TransitionCharSegmentBound newSegmentBound;
|
||||
var commonValueStart = (int)char.MinValue;
|
||||
var commonValue = distribution.ProbabilityOutsideRanges;
|
||||
var weightBase = transition.Weight * sourceStateResidualWeight;
|
||||
|
||||
////if (double.IsInfinity(weightBase.Value))
|
||||
////{
|
||||
//// Console.WriteLine("Weight base infinity");
|
||||
////}
|
||||
void AddEndPoints(int start, int end, int destinationIndex, Weight weight)
|
||||
{
|
||||
bounds.Add(new TransitionCharSegmentBound(start, destinationIndex, weight * weightBase, true));
|
||||
bounds.Add(new TransitionCharSegmentBound(end, destinationIndex, weight * weightBase, false));
|
||||
}
|
||||
|
||||
foreach (var range in ranges)
|
||||
{
|
||||
if (range.StartInclusive > commonValueStart && !commonValue.IsZero)
|
||||
{
|
||||
// Add endpoints for the common value
|
||||
Weight segmentWeight = commonValue * weightBase;
|
||||
newSegmentBound = new TransitionCharSegmentBound(commonValueStart, transition.DestinationStateIndex, segmentWeight, true);
|
||||
bounds.Add(new ValueTuple<int, TransitionCharSegmentBound>(bounds.Count, newSegmentBound));
|
||||
newSegmentBound = new TransitionCharSegmentBound(range.StartInclusive, transition.DestinationStateIndex, segmentWeight, false);
|
||||
bounds.Add(new ValueTuple<int, TransitionCharSegmentBound>(bounds.Count, newSegmentBound));
|
||||
AddEndPoints(commonValueStart, range.StartInclusive, transition.DestinationStateIndex, commonValue);
|
||||
}
|
||||
|
||||
// Add segment endpoints
|
||||
Weight pieceValue = range.Probability;
|
||||
var pieceValue = range.Probability;
|
||||
if (!pieceValue.IsZero)
|
||||
{
|
||||
Weight segmentWeight = pieceValue * weightBase;
|
||||
newSegmentBound = new TransitionCharSegmentBound(range.StartInclusive, transition.DestinationStateIndex, segmentWeight, true);
|
||||
bounds.Add(new ValueTuple<int, TransitionCharSegmentBound>(bounds.Count, newSegmentBound));
|
||||
newSegmentBound = new TransitionCharSegmentBound(range.EndExclusive, transition.DestinationStateIndex, segmentWeight, false);
|
||||
bounds.Add(new ValueTuple<int, TransitionCharSegmentBound>(bounds.Count,newSegmentBound));
|
||||
AddEndPoints(range.StartInclusive, range.EndExclusive, transition.DestinationStateIndex, pieceValue);
|
||||
}
|
||||
|
||||
commonValueStart = range.EndExclusive;
|
||||
|
@ -191,12 +162,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
|
||||
if (!commonValue.IsZero && (ranges.Count == 0 || ranges[ranges.Count - 1].EndExclusive != DiscreteChar.CharRangeEndExclusive))
|
||||
{
|
||||
// Add endpoints for the last common value segment
|
||||
Weight segmentWeight = commonValue * weightBase;
|
||||
newSegmentBound = new TransitionCharSegmentBound(commonValueStart, transition.DestinationStateIndex, segmentWeight, true);
|
||||
bounds.Add(new ValueTuple<int, TransitionCharSegmentBound>(bounds.Count, newSegmentBound));
|
||||
newSegmentBound = new TransitionCharSegmentBound(char.MaxValue + 1, transition.DestinationStateIndex, segmentWeight, false);
|
||||
bounds.Add(new ValueTuple<int, TransitionCharSegmentBound>(bounds.Count, newSegmentBound));
|
||||
AddEndPoints(commonValueStart, char.MaxValue + 1, transition.DestinationStateIndex, commonValue);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -572,7 +572,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
/// The first two elements of a tuple define the element distribution and the weight of a transition.
|
||||
/// The third element defines the outgoing state.
|
||||
/// </returns>
|
||||
protected override List<(TPairDistribution, Weight, Determinization.WeightedStateSet)> GetOutgoingTransitionsForDeterminization(
|
||||
protected override IEnumerable<Determinization.OutgoingTransition> GetOutgoingTransitionsForDeterminization(
|
||||
Determinization.WeightedStateSet sourceState)
|
||||
{
|
||||
throw new NotImplementedException("Determinization is not yet supported for this type of automata.");
|
||||
|
|
|
@ -73,7 +73,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
/// <summary>
|
||||
/// Gets value indicating whether weight is infinite.
|
||||
/// </summary>
|
||||
public bool IsInfinity => double.IsPositiveInfinity(Value);
|
||||
public bool IsInfinity => double.IsPositiveInfinity(this.LogValue);
|
||||
|
||||
/// <summary>
|
||||
/// Creates a weight from the logarithm of its value.
|
||||
|
|
|
@ -0,0 +1,74 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
||||
{
|
||||
/// <summary>
|
||||
/// Helper class for calculating sum of multiple <see cref="Weight"/>. It supports adding
|
||||
/// and then substracting infinite values. Only values which were added into sum can be substracted from it.
|
||||
/// </summary>
|
||||
public struct WeightSum
|
||||
{
|
||||
/// <summary>
|
||||
/// Number of non-infinite weights participating in the sum
|
||||
/// </summary>
|
||||
private readonly int count;
|
||||
|
||||
/// <summary>
|
||||
/// Number of infinite weights participating in the sum
|
||||
/// </summary>
|
||||
private readonly int infCount;
|
||||
|
||||
/// <summary>
|
||||
/// Sum of all non-infinite weights
|
||||
/// </summary>
|
||||
private readonly Weight sum;
|
||||
|
||||
public WeightSum(int count, int infCount, Weight sum)
|
||||
{
|
||||
this.count = count;
|
||||
this.infCount = infCount;
|
||||
this.sum = sum;
|
||||
}
|
||||
|
||||
public WeightSum(Weight init)
|
||||
{
|
||||
this.count = 1;
|
||||
if (init.IsInfinity)
|
||||
{
|
||||
this.infCount = 1;
|
||||
this.sum = Weight.Zero;
|
||||
}
|
||||
else
|
||||
{
|
||||
this.infCount = 0;
|
||||
this.sum = init;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Constructs new empty accumulator instance
|
||||
/// </summary>
|
||||
public static WeightSum Zero() => new WeightSum(0, 0, Weight.Zero);
|
||||
|
||||
public static WeightSum operator +(WeightSum a, Weight b) =>
|
||||
b.IsInfinity
|
||||
? new WeightSum(a.count + 1, a.infCount + 1, a.sum)
|
||||
: new WeightSum(a.count + 1, a.infCount, a.sum + b);
|
||||
|
||||
public static WeightSum operator -(WeightSum a, Weight b) =>
|
||||
a.count == 1
|
||||
? WeightSum.Zero()
|
||||
: (b.IsInfinity
|
||||
? new WeightSum(a.count - 1, a.infCount - 1, a.sum)
|
||||
: new WeightSum(a.count - 1, a.infCount, Weight.AbsoluteDifference(a.sum, b)));
|
||||
|
||||
public int Count => this.count;
|
||||
|
||||
public Weight Sum =>
|
||||
this.infCount != 0
|
||||
? Weight.Infinity
|
||||
: this.sum;
|
||||
}
|
||||
}
|
|
@ -2244,6 +2244,49 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
Assert.True(nonUniform.Equals(uniform));
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Tests that SetToConstantSupportOf() Doesn't throw low probability transitions out.
|
||||
/// </summary>
|
||||
[Fact]
|
||||
[Trait("Category", "StringInference")]
|
||||
public void SetToConstantSupportOfWithLowProbabilityTransition1()
|
||||
{
|
||||
var builder = new StringAutomaton.Builder(1);
|
||||
builder.Start
|
||||
.AddTransition('a', Weight.FromValue(5e-40))
|
||||
.SetEndWeight(Weight.One);
|
||||
var automaton = builder.GetAutomaton();
|
||||
automaton.SetToConstantOnSupportOfLog(0.0, automaton);
|
||||
Assert.Equal(new[] { "a" }, automaton.EnumerateSupport().ToArray());
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Tests that SetToConstantSupportOf() Doesn't throw low probability transitions out.
|
||||
/// </summary>
|
||||
[Fact]
|
||||
[Trait("Category", "StringInference")]
|
||||
public void SetToConstantSupportOfWithMultipleLowProbabilityTransitions()
|
||||
{
|
||||
var builder = new StringAutomaton.Builder(1);
|
||||
builder.Start
|
||||
.AddTransition(DiscreteChar.OneOf('a', 'b'), Weight.One)
|
||||
.SetEndWeight(Weight.One);
|
||||
builder.Start
|
||||
.AddTransition(DiscreteChar.OneOf('a', 'b'), Weight.FromLogValue(-1000))
|
||||
.AddTransition('c', Weight.One)
|
||||
.SetEndWeight(Weight.One);
|
||||
builder.Start
|
||||
.AddTransition('c', Weight.FromLogValue(-10000))
|
||||
.SetEndWeight(Weight.One);
|
||||
var automaton = builder.GetAutomaton();
|
||||
var tmp = automaton.Clone();
|
||||
tmp.TryDeterminize();
|
||||
automaton.SetToConstantOnSupportOfLog(0.0, automaton);
|
||||
var support = automaton.EnumerateSupport().OrderBy(s => s).ToArray();
|
||||
Assert.Equal(new[] {"a", "ac", "b", "bc", "c"}, support);
|
||||
}
|
||||
|
||||
|
||||
#endregion
|
||||
|
||||
#region Helpers
|
||||
|
|
|
@ -36,11 +36,13 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
var wrapper = new StringAutomatonWrapper(builder);
|
||||
|
||||
var outgoingTransitions =
|
||||
wrapper.GetOutgoingTransitionsForDeterminization(new Dictionary<int, Weight> { { 0, Weight.FromValue(3) } });
|
||||
wrapper.GetOutgoingTransitionsForDeterminization(0, Weight.FromValue(3));
|
||||
var expectedOutgoingTransitions = new[]
|
||||
{
|
||||
new ValueTuple<DiscreteChar, Weight, IEnumerable<KeyValuePair<int, Weight>>>(
|
||||
DiscreteChar.Uniform(), Weight.FromValue(6), new Dictionary<int, Weight> { { 1, Weight.FromValue(1) } })
|
||||
ValueTuple.Create(
|
||||
DiscreteChar.Uniform(),
|
||||
Weight.FromValue(6),
|
||||
new[] {(1, Weight.FromValue(1))})
|
||||
};
|
||||
|
||||
AssertCollectionsEqual(expectedOutgoingTransitions, outgoingTransitions, TransitionInfoEqualityComparer.Instance);
|
||||
|
@ -61,13 +63,17 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
var wrapper = new StringAutomatonWrapper(builder);
|
||||
|
||||
var outgoingTransitions =
|
||||
wrapper.GetOutgoingTransitionsForDeterminization(new Dictionary<int, Weight> { { 0, Weight.FromValue(5) } });
|
||||
wrapper.GetOutgoingTransitionsForDeterminization(0, Weight.FromValue(5));
|
||||
var expectedOutgoingTransitions = new[]
|
||||
{
|
||||
new ValueTuple<DiscreteChar, Weight, IEnumerable<KeyValuePair<int, Weight>>>(
|
||||
DiscreteChar.UniformInRange('A', 'Z'), Weight.FromValue(7.5), new Dictionary<int, Weight> { { 2, Weight.FromValue(1) } }),
|
||||
new ValueTuple<DiscreteChar, Weight, IEnumerable<KeyValuePair<int, Weight>>>(
|
||||
DiscreteChar.UniformInRange('a', 'z'), Weight.FromValue(17.5), new Dictionary<int, Weight> { { 1, Weight.FromValue(10 / 17.5) }, { 2, Weight.FromValue(7.5 / 17.5) } }),
|
||||
ValueTuple.Create(
|
||||
DiscreteChar.UniformInRange('A', 'Z'),
|
||||
Weight.FromValue(7.5),
|
||||
new[] {(2, Weight.FromValue(1))}),
|
||||
ValueTuple.Create(
|
||||
DiscreteChar.UniformInRange('a', 'z'),
|
||||
Weight.FromValue(10),
|
||||
new[] {(1, Weight.FromValue(1)), (2, Weight.FromValue(0.75))}),
|
||||
};
|
||||
|
||||
AssertCollectionsEqual(expectedOutgoingTransitions, outgoingTransitions, TransitionInfoEqualityComparer.Instance);
|
||||
|
@ -91,29 +97,33 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
var wrapper = new StringAutomatonWrapper(builder);
|
||||
|
||||
var outgoingTransitions =
|
||||
wrapper.GetOutgoingTransitionsForDeterminization(new Dictionary<int, Weight> { { 0, Weight.FromValue(6) } });
|
||||
wrapper.GetOutgoingTransitionsForDeterminization(0, Weight.FromValue(6));
|
||||
var expectedOutgoingTransitions = new[]
|
||||
{
|
||||
new ValueTuple<DiscreteChar, Weight, IEnumerable<KeyValuePair<int, Weight>>>(
|
||||
DiscreteChar.UniformInRange(char.MinValue, (char)('a' - 1)),
|
||||
Weight.FromValue(30.0 * 97.0 / 98.0),
|
||||
new Dictionary<int, Weight> { { 4, Weight.FromValue(1) } }),
|
||||
new ValueTuple<DiscreteChar, Weight, IEnumerable<KeyValuePair<int, Weight>>>(
|
||||
ValueTuple.Create(
|
||||
DiscreteChar.UniformInRange(char.MinValue, (char) ('a' - 1)),
|
||||
Weight.FromValue(6 * 5.0 * 97.0 / 98.0),
|
||||
new[] {(4, Weight.FromValue(1))}),
|
||||
ValueTuple.Create(
|
||||
DiscreteChar.PointMass('a'),
|
||||
Weight.FromValue((30.0 / 98.0) + 6.0),
|
||||
new Dictionary<int, Weight> { { 1, Weight.FromValue(6.0 / ((30.0 / 98.0) + 6.0)) }, { 4, Weight.FromValue((30.0 / 98.0) / ((30.0 / 98.0) + 6.0)) } }),
|
||||
new ValueTuple<DiscreteChar, Weight, IEnumerable<KeyValuePair<int, Weight>>>(
|
||||
Weight.FromValue(6),
|
||||
new[]
|
||||
{
|
||||
(1, Weight.FromValue(1.0)),
|
||||
(4, Weight.FromValue(5.0 / 98.0))
|
||||
}),
|
||||
ValueTuple.Create(
|
||||
DiscreteChar.PointMass('b'),
|
||||
Weight.FromValue(12.0),
|
||||
new Dictionary<int, Weight> { { 1, Weight.FromValue(0.5) }, { 2, Weight.FromValue(0.5)} }),
|
||||
new ValueTuple<DiscreteChar, Weight, IEnumerable<KeyValuePair<int, Weight>>>(
|
||||
Weight.FromValue(6),
|
||||
new[] {(1, Weight.FromValue(1)), (2, Weight.FromValue(1))}),
|
||||
ValueTuple.Create(
|
||||
DiscreteChar.UniformInRange('c', 'd'),
|
||||
Weight.FromValue(12.0),
|
||||
new Dictionary<int, Weight> { { 2, Weight.FromValue(1.0) } }),
|
||||
new ValueTuple<DiscreteChar, Weight, IEnumerable<KeyValuePair<int, Weight>>>(
|
||||
Weight.FromValue(6 * 3 * (2.0 / 3)),
|
||||
new[] {(2, Weight.FromValue(1))}),
|
||||
ValueTuple.Create(
|
||||
DiscreteChar.UniformInRange('e', 'g'),
|
||||
Weight.FromValue(24.0),
|
||||
new Dictionary<int, Weight> { { 3, Weight.FromValue(1.0) } }),
|
||||
Weight.FromValue(6 * 4),
|
||||
new[] {(3, Weight.FromValue(1.0))}),
|
||||
};
|
||||
|
||||
AssertCollectionsEqual(expectedOutgoingTransitions, outgoingTransitions, TransitionInfoEqualityComparer.Instance);
|
||||
|
@ -133,39 +143,49 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
builder.Start.AddTransition(DiscreteChar.UniformInRanges('z', char.MaxValue), Weight.FromValue(4));
|
||||
|
||||
var wrapper = new StringAutomatonWrapper(builder);
|
||||
|
||||
var outgoingTransitions =
|
||||
wrapper.GetOutgoingTransitionsForDeterminization(new Dictionary<int, Weight> { { 0, Weight.FromValue(5) } });
|
||||
|
||||
double transition1Segment1Weight = 10.0 * 'a' / (char.MaxValue + 1.0);
|
||||
double transition1Segment2Weight = 10.0 * ('z' - 'a') / (char.MaxValue + 1.0);
|
||||
double transition1Segment3Weight = 10.0 * (char.MaxValue - 'z' + 1.0) / (char.MaxValue + 1.0);
|
||||
double transition2Segment1Weight = 15.0 * ('z' - 'a') / (char.MaxValue - 'a' + 1.0);
|
||||
double transition2Segment2Weight = 15.0 * (char.MaxValue - 'z' + 1.0) / (char.MaxValue - 'a' + 1.0);
|
||||
double transition3Segment1Weight = 20.0;
|
||||
var outgoingTransitions =
|
||||
wrapper.GetOutgoingTransitionsForDeterminization(0, Weight.FromValue(1));
|
||||
|
||||
// we have 3 segments:
|
||||
// 1. [char.MinValue, 'a')
|
||||
// 2. ['a', 'z')
|
||||
// 3. ['z', char.MaxValue]
|
||||
var transition1Segment1Weight = 2.0 * 'a' / (char.MaxValue + 1.0);
|
||||
var transition1Segment2Weight = 2.0 * ('z' - 'a') / (char.MaxValue + 1.0);
|
||||
var transition1Segment3Weight = 2.0 * (char.MaxValue - 'z' + 1.0) / (char.MaxValue + 1.0);
|
||||
var transition2Segment2Weight = 3.0 * ('z' - 'a') / (char.MaxValue - 'a' + 1.0);
|
||||
var transition2Segment3Weight = 3.0 * (char.MaxValue - 'z' + 1.0) / (char.MaxValue - 'a' + 1.0);
|
||||
var transition3Segment3Weight = 4.0;
|
||||
|
||||
var maxSegment2Weight = Math.Max(transition1Segment2Weight, transition2Segment2Weight);
|
||||
var maxSegment3Weight = Math.Max(
|
||||
transition1Segment3Weight,
|
||||
Math.Max(transition2Segment3Weight, transition3Segment3Weight));
|
||||
|
||||
var expectedOutgoingTransitions = new[]
|
||||
{
|
||||
new ValueTuple<DiscreteChar, Weight, IEnumerable<KeyValuePair<int, Weight>>>(
|
||||
ValueTuple.Create(
|
||||
DiscreteChar.UniformInRange(char.MinValue, (char)('a' - 1)),
|
||||
Weight.FromValue(transition1Segment1Weight),
|
||||
new Dictionary<int, Weight> { { 1, Weight.FromValue(1) } }),
|
||||
new ValueTuple<DiscreteChar, Weight, IEnumerable<KeyValuePair<int, Weight>>>(
|
||||
new[] {(1, Weight.FromValue(1))}),
|
||||
ValueTuple.Create(
|
||||
DiscreteChar.UniformInRange('a', (char)('z' - 1)),
|
||||
Weight.FromValue(transition1Segment2Weight + transition2Segment1Weight),
|
||||
new Dictionary<int, Weight>
|
||||
{
|
||||
{ 1, Weight.FromValue(transition1Segment2Weight / (transition1Segment2Weight + transition2Segment1Weight)) },
|
||||
{ 2, Weight.FromValue(transition2Segment1Weight / (transition1Segment2Weight + transition2Segment1Weight)) }
|
||||
}),
|
||||
new ValueTuple<DiscreteChar, Weight, IEnumerable<KeyValuePair<int, Weight>>>(
|
||||
Weight.FromValue(maxSegment2Weight),
|
||||
new[]
|
||||
{
|
||||
(1, Weight.FromValue(transition1Segment2Weight / maxSegment2Weight)),
|
||||
(2, Weight.FromValue(transition2Segment2Weight / maxSegment2Weight)),
|
||||
}),
|
||||
ValueTuple.Create(
|
||||
DiscreteChar.UniformInRange('z', char.MaxValue),
|
||||
Weight.FromValue(transition1Segment3Weight + transition2Segment2Weight + transition3Segment1Weight),
|
||||
new Dictionary<int, Weight>
|
||||
{
|
||||
{ 1, Weight.FromValue(transition1Segment3Weight / (transition1Segment3Weight + transition2Segment2Weight + transition3Segment1Weight)) },
|
||||
{ 2, Weight.FromValue(transition2Segment2Weight / (transition1Segment3Weight + transition2Segment2Weight + transition3Segment1Weight)) },
|
||||
{ 3, Weight.FromValue(transition3Segment1Weight / (transition1Segment3Weight + transition2Segment2Weight + transition3Segment1Weight)) }
|
||||
}),
|
||||
Weight.FromValue(maxSegment3Weight),
|
||||
new[]
|
||||
{
|
||||
(1, Weight.FromValue(transition1Segment3Weight / maxSegment3Weight)),
|
||||
(2, Weight.FromValue(transition2Segment3Weight / maxSegment3Weight)),
|
||||
(3, Weight.FromValue(transition3Segment3Weight / maxSegment3Weight)),
|
||||
}),
|
||||
};
|
||||
|
||||
AssertCollectionsEqual(expectedOutgoingTransitions, outgoingTransitions, TransitionInfoEqualityComparer.Instance);
|
||||
|
@ -188,8 +208,8 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
|
||||
var wrapper = new StringAutomatonWrapper(builder);
|
||||
|
||||
var outgoingTransitions = wrapper.GetOutgoingTransitionsForDeterminization(
|
||||
new Dictionary<int, Weight> { { 0, Weight.FromValue(1) } }).ToArray();
|
||||
var outgoingTransitions =
|
||||
wrapper.GetOutgoingTransitionsForDeterminization(0, Weight.FromValue(1)).ToArray();
|
||||
|
||||
Assert.Equal(5, outgoingTransitions.Length);
|
||||
Assert.True(outgoingTransitions.All(ot => ot.Item1.IsPointMass));
|
||||
|
@ -277,7 +297,7 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
/// A comparer for weighted states that allows for some tolerance when comparing weights.
|
||||
/// </summary>
|
||||
private class WeightedStateEqualityComparer :
|
||||
EqualityComparerBase<KeyValuePair<int, Weight>, WeightedStateEqualityComparer>
|
||||
EqualityComparerBase<(int, Weight), WeightedStateEqualityComparer>
|
||||
{
|
||||
/// <summary>
|
||||
/// Checks whether two objects are equal.
|
||||
|
@ -288,9 +308,9 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
/// <see langword="true"/> if the objects are equal,
|
||||
/// <see langword="false"/> otherwise.
|
||||
/// </returns>
|
||||
public override bool Equals(KeyValuePair<int, Weight> x, KeyValuePair<int, Weight> y)
|
||||
public override bool Equals((int, Weight) x, (int, Weight) y)
|
||||
{
|
||||
return x.Key == y.Key && WeightEqualityComparer.Instance.Equals(x.Value, y.Value);
|
||||
return x.Item1 == y.Item1 && WeightEqualityComparer.Instance.Equals(x.Item2, y.Item2);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -298,7 +318,7 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
/// A comparer for transition information that allows for some tolerance when comparing weights.
|
||||
/// </summary>
|
||||
private class TransitionInfoEqualityComparer :
|
||||
EqualityComparerBase<ValueTuple<DiscreteChar, Weight, IEnumerable<KeyValuePair<int, Weight>>>, TransitionInfoEqualityComparer>
|
||||
EqualityComparerBase<ValueTuple<DiscreteChar, Weight, (int, Weight)[]>, TransitionInfoEqualityComparer>
|
||||
{
|
||||
/// <summary>
|
||||
/// Checks whether two objects are equal.
|
||||
|
@ -310,13 +330,13 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
/// <see langword="false"/> otherwise.
|
||||
/// </returns>
|
||||
public override bool Equals(
|
||||
ValueTuple<DiscreteChar, Weight, IEnumerable<KeyValuePair<int, Weight>>> x,
|
||||
ValueTuple<DiscreteChar, Weight, IEnumerable<KeyValuePair<int, Weight>>> y)
|
||||
ValueTuple<DiscreteChar, Weight, (int, Weight)[]> x,
|
||||
ValueTuple<DiscreteChar, Weight, (int, Weight)[]> y)
|
||||
{
|
||||
return
|
||||
object.Equals(x.Item1, y.Item1) &&
|
||||
WeightEqualityComparer.Instance.Equals(x.Item2, y.Item2) &&
|
||||
Enumerable.SequenceEqual(x.Item3, y.Item3, WeightedStateEqualityComparer.Instance);
|
||||
x.Item3.SequenceEqual(y.Item3, WeightedStateEqualityComparer.Instance);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -338,14 +358,20 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
/// <param name="sourceState">The source state.</param>
|
||||
/// <returns>The produced transitions.</returns>
|
||||
/// <remarks>See the doc of the original method.</remarks>
|
||||
public IEnumerable<(DiscreteChar, Weight, IEnumerable<KeyValuePair<int, Weight>>)>
|
||||
GetOutgoingTransitionsForDeterminization(IEnumerable<KeyValuePair<int, Weight>> sourceState)
|
||||
public IEnumerable<(DiscreteChar, Weight, (int, Weight)[])>
|
||||
GetOutgoingTransitionsForDeterminization(int sourceState, Weight sourceWeight)
|
||||
{
|
||||
var result = base.GetOutgoingTransitionsForDeterminization(new Determinization.WeightedStateSet(sourceState));
|
||||
return result.Select(t => (t.Item1, t.Item2, (IEnumerable<KeyValuePair<int, Weight>>)t.Item3));
|
||||
var weightedStateSetBuilder = Determinization.WeightedStateSetBuilder.Create();
|
||||
weightedStateSetBuilder.Add(sourceState, sourceWeight);
|
||||
|
||||
var (weightedStateSet, weight) = weightedStateSetBuilder.Get();
|
||||
var result = base.GetOutgoingTransitionsForDeterminization(weightedStateSet);
|
||||
return result.Select(t => (
|
||||
t.ElementDistribution,
|
||||
t.Weight * weight,
|
||||
t.Destinations.ToArray().Select(state => (state.Index, state.Weight)).ToArray()));
|
||||
}
|
||||
}
|
||||
|
||||
#endregion
|
||||
}
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче