Do not store Owner and Index properties in States (#69)

Owner and Index properties bring 2 issues:
1. They were taking 16 bytes (after aligning everything)
2. They complicate serialization due to cyclic references

Changes made:
1. State class is split in 2 parts: StateData and State. StateData is stored inside Automaton; State is created on demand and acts as a fat reference that allows reading/modifying the StateData it references and carries around additional properties - Owner and Index.
2. Introduced StateCollection class which wraps List<StateData> and acts like ReadOnlyList<State> but does StateData to State conversion on demand

C# compiler and .NET runtime are good at optimizing away all StateData->State wrapping. No performance degradation was measured in my tests.

Binary serialization remains backward-compatible, but JSON/BinaryFormatter/DataContract ones are changed.
This commit is contained in:
Ivan Korostelev 2018-11-05 14:38:15 +00:00 коммит произвёл GitHub
Родитель 456bbfa809
Коммит e900bfa7e6
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
15 изменённых файлов: 747 добавлений и 596 удалений

Просмотреть файл

@ -46,9 +46,9 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// <returns>The computed condensation.</returns>
public Condensation ComputeCondensation(State root, Func<Transition, bool> transitionFilter, bool useApproximateClosure)
{
Argument.CheckIfNotNull(root, "root");
Argument.CheckIfNotNull(transitionFilter, "transitionFilter");
Argument.CheckIfValid(ReferenceEquals(root.Owner, this), "root", "The given node belongs to a different automaton.");
Argument.CheckIfValid(!root.IsNull, nameof(root));
Argument.CheckIfNotNull(transitionFilter, nameof(transitionFilter));
Argument.CheckIfValid(ReferenceEquals(root.Owner, this), nameof(root), "The given node belongs to a different automaton.");
return new Condensation(root, transitionFilter, useApproximateClosure);
}
@ -108,7 +108,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// </param>
internal Condensation(State root, Func<Transition, bool> transitionFilter, bool useApproximateClosure)
{
Debug.Assert(root != null, "A valid root node must be provided.");
Debug.Assert(!root.IsNull, "A valid root node must be provided.");
Debug.Assert(transitionFilter != null, "A valid transition filter must be provided.");
this.Root = root;
@ -164,7 +164,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// <returns>The computed total weight.</returns>
public Weight GetWeightToEnd(State state)
{
Argument.CheckIfNotNull(state, "state");
Argument.CheckIfValid(!state.IsNull, nameof(state));
Argument.CheckIfValid(ReferenceEquals(state.Owner, this.Root.Owner), "state", "The given state belongs to a different automaton.");
if (!this.weightsToEndComputed)
@ -189,7 +189,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// <returns>The computed total weight.</returns>
public Weight GetWeightFromRoot(State state)
{
Argument.CheckIfNotNull(state, "state");
Argument.CheckIfValid(!state.IsNull, nameof(state));
Argument.CheckIfValid(ReferenceEquals(state.Owner, this.Root.Owner), "state", "The given state belongs to a different automaton.");
if (!this.weightsFromRootComputed)
@ -241,7 +241,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
if (!stateIdToStateInfo.TryGetValue(transition.DestinationStateIndex, out destinationStateInfo))
{
this.FindStronglyConnectedComponents(
this.Root.Owner.states[transition.DestinationStateIndex], ref traversalIndex, stateIdToStateInfo, stateIdStack);
this.Root.Owner.States[transition.DestinationStateIndex], ref traversalIndex, stateIdToStateInfo, stateIdStack);
stateInfo.Lowlink = Math.Min(stateInfo.Lowlink, stateIdToStateInfo[transition.DestinationStateIndex].Lowlink);
}
else if (destinationStateInfo.InStack)
@ -288,7 +288,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
for (int transitionIndex = 0; transitionIndex < state.TransitionCount; ++transitionIndex)
{
Transition transition = state.GetTransition(transitionIndex);
State destState = state.Owner.states[transition.DestinationStateIndex];
State destState = state.Owner.States[transition.DestinationStateIndex];
if (this.transitionFilter(transition) && !currentComponent.HasState(destState))
{
weightToAdd = Weight.Sum(
@ -367,7 +367,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
for (int transitionIndex = 0; transitionIndex < srcState.TransitionCount; ++transitionIndex)
{
Transition transition = srcState.GetTransition(transitionIndex);
State destState = srcState.Owner.states[transition.DestinationStateIndex];
State destState = srcState.Owner.States[transition.DestinationStateIndex];
if (this.transitionFilter(transition) && !currentComponent.HasState(destState))
{
CondensationStateInfo destStateInfo = this.stateIdToInfo[destState.Index];

Просмотреть файл

@ -2,6 +2,8 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Diagnostics;
namespace Microsoft.ML.Probabilistic.Distributions.Automata
{
using System.Collections.Generic;
@ -40,7 +42,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// <param name="state">The state, which epsilon closure this instance will represent.</param>
internal EpsilonClosure(State state)
{
Argument.CheckIfNotNull(state, "state");
Argument.CheckIfValid(!state.IsNull, nameof(state));
// Optimize for a very common case: a single-node closure
bool singleNodeClosure = true;

Просмотреть файл

@ -23,17 +23,16 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
{
internal static Dictionary<int, TThis> ExtractGroups(Automaton<TSequence, TElement, TElementDistribution, TSequenceManipulator, TThis> automaton)
{
Dictionary<int, HashSet<int>> subGraphs;
var order = ComputeTopologicalOrderAndGroupSubgraphs(automaton, out subGraphs);
return BuildSubautomata(automaton.states, order, subGraphs);
var order = ComputeTopologicalOrderAndGroupSubgraphs(automaton, out var subGraphs);
return BuildSubautomata(automaton.States, order, subGraphs);
}
private static Dictionary<int, TThis> BuildSubautomata(
List<State> states,
List<State> topologicalOrder,
IReadOnlyList<State> states,
IReadOnlyList<State> topologicalOrder,
Dictionary<int, HashSet<int>> groupSubGraphs) => groupSubGraphs.ToDictionary(g => g.Key, g => BuildSubautomaton(states, topologicalOrder, g.Key, g.Value));
private static TThis BuildSubautomaton(List<State> states, List<State> topologicalOrder, int group, HashSet<int> subgraph)
private static TThis BuildSubautomaton(IReadOnlyList<State> states, IReadOnlyList<State> topologicalOrder, int group, HashSet<int> subgraph)
{
var weightsFromRoot = ComputeWeightsFromRoot(states.Count, topologicalOrder, group);
var weightsToEnd = ComputeWeightsToEnd(states.Count, topologicalOrder, group);
@ -69,14 +68,14 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
var weightFromRoot = newSourceState.TransitionCount > 0 ? weightsFromRoot[stateIndex] : Weight.Zero;
if (!weightFromRoot.IsZero)
{
subautomaton.startState.AddEpsilonTransition(weightFromRoot, newSourceState);
subautomaton.Start.AddEpsilonTransition(weightFromRoot, newSourceState);
}
// consider end states
var weightToEnd = !hasNoIncomingTransitions.Contains(stateIndex) ? weightsToEnd[stateIndex] : Weight.Zero;
if (!weightToEnd.IsZero)
{
newSourceState.EndWeight = weightToEnd;
newSourceState.SetEndWeight(weightToEnd);
}
correctionFactor = Weight.Sum(correctionFactor, Weight.Product(weightFromRoot, weightToEnd));
@ -84,7 +83,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
if (!correctionFactor.IsZero) throw new Exception("Write a unit test for this case. Code should be fine.");
var epsilonWeight = Weight.AbsoluteDifference(weightsToEnd[topologicalOrder[0].Index], correctionFactor);
subautomaton.startState.EndWeight = epsilonWeight;
subautomaton.Start.SetEndWeight(epsilonWeight);
return subautomaton;
}
@ -106,16 +105,15 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
private static List<State> ComputeTopologicalOrderAndGroupSubgraphs(Automaton<TSequence, TElement, TElementDistribution, TSequenceManipulator, TThis> automaton, out Dictionary<int, HashSet<int>> groupSubGraphs)
{
var topologicalOrder = new Stack<int>();
var states = automaton.states;
var temporary = new BitArray(states.Count);
var permanent = new BitArray(states.Count);
var temporary = new BitArray(automaton.States.Count);
var permanent = new BitArray(automaton.States.Count);
groupSubGraphs = new Dictionary<int, HashSet<int>>();
VisitNode(states, automaton.startState.Index, temporary, permanent, groupSubGraphs, topologicalOrder);
return topologicalOrder.Select(idx => states[idx]).ToList();
VisitNode(automaton.States, automaton.Start.Index, temporary, permanent, groupSubGraphs, topologicalOrder);
return topologicalOrder.Select(idx => automaton.States[idx]).ToList();
}
private static void VisitNode(List<State> states, int stateIdx, BitArray temporary, BitArray permanent, Dictionary<int, HashSet<int>> groupSubGraphs, Stack<int> topologicalOrder)
private static void VisitNode(IReadOnlyList<State> states, int stateIdx, BitArray temporary, BitArray permanent, Dictionary<int, HashSet<int>> groupSubGraphs, Stack<int> topologicalOrder)
{
if (temporary[stateIdx])
{
@ -158,7 +156,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// Ending weights are taken into account.
/// </summary>
/// <remarks>The weights are computed using dynamic programming, going up from leafs to the root.</remarks>
private static Weight[] ComputeWeightsToEnd(int nStates, List<State> topologicalOrder, int group)
private static Weight[] ComputeWeightsToEnd(int nStates, IReadOnlyList<State> topologicalOrder, int group)
{
var weights = CreateZeroWeights(nStates);
// Iterate in the reverse topological order
@ -190,7 +188,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// and ending at that state. Ending weights are not taken into account.
/// </summary>
/// <remarks>The weights are computed using dynamic programming, going down from the root to leafs.</remarks>
private static Weight[] ComputeWeightsFromRoot(int nStates, List<State> topologicalOrder, int group)
private static Weight[] ComputeWeightsFromRoot(int nStates, IReadOnlyList<State> topologicalOrder, int group)
{
var weights = CreateZeroWeights(nStates);
weights[topologicalOrder[0].Index] = Weight.One;

Просмотреть файл

@ -43,9 +43,9 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
return false;
}
for (int stateId = 0; stateId < this.states.Count; ++stateId)
for (int stateId = 0; stateId < this.States.Count; ++stateId)
{
var state = this.states[stateId];
var state = this.States[stateId];
// There should be no epsilon transitions
for (int transitionIndex = 0; transitionIndex < state.TransitionCount; ++transitionIndex)
@ -135,7 +135,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
var startWeightedStateSet = new Determinization.WeightedStateSet { { this.Start.Index, Weight.One } };
weightedStateSetQueue.Enqueue(startWeightedStateSet);
weightedStateSetToNewState.Add(startWeightedStateSet, result.Start);
result.Start.EndWeight = this.Start.EndWeight;
result.Start.SetEndWeight(this.Start.EndWeight);
while (weightedStateSetQueue.Count > 0)
{
@ -169,12 +169,12 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
weightedStateSetQueue.Enqueue(destWeightedStateSet);
// Compute its ending weight
destinationState.EndWeight = Weight.Zero;
destinationState.SetEndWeight(Weight.Zero);
foreach (KeyValuePair<int, Weight> stateIdWithWeight in destWeightedStateSet)
{
destinationState.EndWeight = Weight.Sum(
destinationState.SetEndWeight(Weight.Sum(
destinationState.EndWeight,
Weight.Product(stateIdWithWeight.Value, this.States[stateIdWithWeight.Key].EndWeight));
Weight.Product(stateIdWithWeight.Value, this.States[stateIdWithWeight.Key].EndWeight)));
}
}
@ -202,7 +202,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// </summary>
public void SimplifyIfNeeded()
{
if (this.states.Count > MaxStateCountBeforeSimplification || this.PruneTransitionsWithLogWeightLessThan != null)
if (this.States.Count > MaxStateCountBeforeSimplification || this.PruneTransitionsWithLogWeightLessThan != null)
{
////Console.WriteLine(this.ToString(AutomatonFormats.GraphViz));
this.Simplify();
@ -306,10 +306,10 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// <param name="minStatesToActuallyRemove">If the number of states to remove is less than this value, the removal will not be done.</param>
private void RemoveStates(bool[] statesToKeep, int minStatesToActuallyRemove)
{
int[] oldToNewStateIdMapping = new int[this.states.Count];
int[] oldToNewStateIdMapping = new int[this.States.Count];
int newStateId = 0;
int deadStateCount = 0;
for (int stateId = 0; stateId < this.states.Count; ++stateId)
for (int stateId = 0; stateId < this.States.Count; ++stateId)
{
if (statesToKeep[stateId])
{
@ -340,16 +340,16 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
funcWithoutStates.Start = funcWithoutStates.States[oldToNewStateIdMapping[this.Start.Index]];
funcWithoutStates.LogValueOverride = this.LogValueOverride;
funcWithoutStates.PruneTransitionsWithLogWeightLessThan = this.PruneTransitionsWithLogWeightLessThan;
for (int i = 0; i < this.states.Count; ++i)
for (int i = 0; i < this.States.Count; ++i)
{
if (oldToNewStateIdMapping[i] == -1)
{
continue;
}
State oldState = this.states[i];
State oldState = this.States[i];
State newState = funcWithoutStates.States[oldToNewStateIdMapping[i]];
newState.EndWeight = oldState.EndWeight;
newState.SetEndWeight(oldState.EndWeight);
for (int transitionIndex = 0; transitionIndex < oldState.TransitionCount; ++transitionIndex)
{
Transition transition = oldState.GetTransition(transitionIndex);
@ -374,7 +374,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// <param name="logWeightThreshold">The smallest log weight that a transition can have and not be removed.</param>
public void RemoveTransitionsWithSmallWeights(double logWeightThreshold)
{
foreach (var state in this.states)
foreach (var state in this.States)
{
for (int i = state.TransitionCount-1; i >=0; i--)
{
@ -441,7 +441,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
// Self-loops are allowed
if (transition.DestinationStateIndex != currentState.Index)
{
isGeneralizedTree &= this.DoLabelStatesForSimplification(this.states[transition.DestinationStateIndex], stateLabels);
isGeneralizedTree &= this.DoLabelStatesForSimplification(this.States[transition.DestinationStateIndex], stateLabels);
}
}
@ -455,9 +455,9 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// </summary>
private void MergeParallelTransitions()
{
for (int stateIndex = 0; stateIndex < this.states.Count; ++stateIndex)
for (int stateIndex = 0; stateIndex < this.States.Count; ++stateIndex)
{
State state = this.states[stateIndex];
State state = this.States[stateIndex];
for (int transitionIndex1 = 0; transitionIndex1 < state.TransitionCount; ++transitionIndex1)
{
Transition transition1 = state.GetTransition(transitionIndex1);
@ -541,13 +541,13 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
}
copiedState = this.AddState();
copiedState.EndWeight = stateToCopy.EndWeight;
copiedState.SetEndWeight(stateToCopy.EndWeight);
copiedStateCache.Add(stateToCopy.Index, copiedState);
for (int i = 0; i < stateToCopy.TransitionCount; ++i)
{
Transition transitionToCopy = stateToCopy.GetTransition(i);
State destStateToCopy = stateToCopy.Owner.states[transitionToCopy.DestinationStateIndex];
State destStateToCopy = stateToCopy.Owner.States[transitionToCopy.DestinationStateIndex];
if (!lookAtLabels || !stateLabels[destStateToCopy.Index])
{
State copiedDestState = this.DoCopyNonSimplifiable(destStateToCopy, stateLabels, false, copiedStateCache);
@ -696,7 +696,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
stack.Push(
new StateWeight(
states[transition.DestinationStateIndex],
States[transition.DestinationStateIndex],
Weight.Product(currentWeight, transition.Weight)));
if (!transition.IsEpsilon)
@ -784,7 +784,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
}
}
state.EndWeight = Weight.Sum(state.EndWeight, weight);
state.SetEndWeight(Weight.Sum(state.EndWeight, weight));
return true;
}
@ -802,7 +802,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
if (transition.DestinationStateIndex != state.Index && transition.IsEpsilon && transition.DestinationStateIndex >= firstAllowedStateIndex)
{
if (this.DoAddGeneralizedSequence(
this.states[transition.DestinationStateIndex],
this.States[transition.DestinationStateIndex],
false,
false,
firstAllowedStateIndex,
@ -831,7 +831,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
{
// Try this epsilon transition
if (this.DoAddGeneralizedSequence(
this.states[transition.DestinationStateIndex], false, false, firstAllowedStateIndex, currentSequencePos, sequence, weight))
this.States[transition.DestinationStateIndex], false, false, firstAllowedStateIndex, currentSequencePos, sequence, weight))
{
return true;
}
@ -878,7 +878,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
{
// Try this epsilon transition
if (this.DoAddGeneralizedSequence(
this.states[transition.DestinationStateIndex], false, false, firstAllowedStateIndex, currentSequencePos, sequence, weight))
this.States[transition.DestinationStateIndex], false, false, firstAllowedStateIndex, currentSequencePos, sequence, weight))
{
return true;
}
@ -908,7 +908,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
// Weight of the existing transition must be taken into account
// This case can fail if the next element is a self-loop and the destination state already has a different one
if (this.DoAddGeneralizedSequence(
this.states[transition.DestinationStateIndex],
this.States[transition.DestinationStateIndex],
false,
false,
firstAllowedStateIndex,
@ -921,7 +921,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
}
// Add a new transition
State newChild = state.AddTransition(element.ElementDistribution, Weight.One, null, element.Group);
State newChild = state.AddTransition(element.ElementDistribution, Weight.One, default(State), element.Group);
success = this.DoAddGeneralizedSequence(newChild, true, false, firstAllowedStateIndex, currentSequencePos + 1, sequence, weight);
Debug.Assert(success, "This call must always succeed.");
return true;

Просмотреть файл

@ -7,6 +7,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.Serialization;
using System.Text;
@ -16,9 +17,6 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
using Microsoft.ML.Probabilistic.Serialization;
using Microsoft.ML.Probabilistic.Utilities;
/// <content>
/// Contains the class used to represent a state of an automaton.
/// </content>
public abstract partial class Automaton<TSequence, TElement, TElementDistribution, TSequenceManipulator, TThis>
where TSequence : class, IEnumerable<TElement>
where TElementDistribution : class, IDistribution<TElement>, SettableToProduct<TElementDistribution>, SettableToWeightedSumExact<TElementDistribution>, CanGetLogAverageOf<TElementDistribution>, SettableToPartialUniform<TElementDistribution>, new()
@ -26,44 +24,32 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
where TThis : Automaton<TSequence, TElement, TElementDistribution, TSequenceManipulator, TThis>, new()
{
/// <summary>
/// Represents a state of an automaton.
/// Represents a reference to a state of automaton for exposure in public API.
/// </summary>
[Serializable]
[DataContract(IsReference = true)]
public class State
/// <remarks>
/// Acts as a "fat reference" to state in automaton. In addition to reference to actual StateData it carries
/// 2 additional properties for convenience: <see cref="Owner"/> automaton and <see cref="Index"/> of the state.
/// We don't store them in <see cref="StateData"/> to save some memory. C# compiler and .NET jitter are good
/// at optimizing wrapping where it is not needed.
/// </remarks>
public struct State : IEquatable<State>
{
//// This class has been made inner so that the user doesn't have to deal with a lot of generic parameters on it.
internal readonly StateData Data;
/// <summary>
/// The default capacity of the <see cref="transitions"/>.
/// Initializes a new instance of <see cref="State"/> class. Used internally by automaton implementation
/// to wrap StateData for use in public Automaton APIs.
/// </summary>
private const int DefaultTransitionArrayCapacity = 1;
/// <summary>
/// The array of outgoing transitions.
/// </summary>
/// <remarks>
/// We don't use <see cref="List{T}"/> here for performance reasons.
/// </remarks>
[DataMember]
private Transition[] transitions = new Transition[DefaultTransitionArrayCapacity];
/// <summary>
/// The number of outgoing transitions from the state.
/// </summary>
[DataMember]
private int transitionCount;
/// <summary>
/// Initializes a new instance of the <see cref="State"/> class.
/// </summary>
public State()
internal State(Automaton<TSequence, TElement, TElementDistribution, TSequenceManipulator, TThis> owner, int index, StateData data)
{
this.EndWeight = Weight.Zero;
this.Owner = owner;
this.Index = index;
this.Data = data;
}
/// <summary>
/// Initializes a new instance of the <see cref="State"/> class.
/// Initializes a new instance of the <see cref="State"/> class. Created state does not belong
/// to any automaton and has to be added to some automaton explicitly via Automaton.AddStates.
/// </summary>
/// <param name="index">The index of the state.</param>
/// <param name="transitions">The outgoing transitions.</param>
@ -73,73 +59,90 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
: this()
{
Argument.CheckIfInRange(index >= 0, "index", "State index must be non-negative.");
Argument.CheckIfNotNull(transitions, "transitions");
this.Index = index;
this.EndWeight = endWeight;
foreach (var transition in transitions)
{
this.DoAddTransition(transition);
}
this.Data = new StateData(transitions, endWeight);
}
/// <summary>
/// Gets the automaton which owns the state.
/// Returns whether this State represents some valid state in automaton.
/// </summary>
/// <remarks>
/// Owner is not serialized to avoid circular references. It has to be restored manually upon deserialization.
/// BinaryFormatter and Newtonsoft.Json handle circular references differently by default.
/// At the same time [DataMember] does the right thing, because IsReference=true property on State DataContract
/// makes DataContractSerializer handle circular references just fine.
/// </remarks>
[DataMember]
[NonSerializedProperty]
public TThis Owner { get; internal set; }
public bool IsNull => this.Data == null;
/// <summary>
/// Helper method for Newtonsoft.Json to skip serialization of <see cref="Owner"/> property.
/// Automaton to which this state belongs.
/// </summary>
public bool ShouldSerializeOwner() => false;
public Automaton<TSequence, TElement, TElementDistribution, TSequenceManipulator, TThis> Owner { get; }
/// <summary>
/// Gets the index of the state.
/// </summary>
[DataMember]
public int Index { get; internal set; } // TODO: setter of this property is needed only for the state removal procedure
public int Index { get; }
/// <summary>
/// Gets or sets the ending weight of the state.
/// </summary>
[DataMember]
public Weight EndWeight { get; set; }
/// <remarks>
/// C# compiler disallows to use property setter if it sees that <see cref="State"/> instance is a temporary.
/// It is not smart enough to understand that property setter actually changes something behind a reference.
/// To overcome this issue special <see cref="SetEndWeight"/> method is added calling which is equivalent
/// to calling property setter but is not rejected by compiler.
/// </remarks>
public Weight EndWeight => this.Data.EndWeight;
/// <summary>
/// Sets the <see cref="EndWeight"/> property of State.
///
/// Because <see cref="State"/> is a struct, trying to set <see cref="EndWeight"/> on it
/// (if property setter was provided) would result in compilation error. Compiler isn't
/// smart enough to see that setting property just updates the value in referenced <see cref="Data"/>.
/// Having a method call doesn't create this problem.
/// </summary>
/// <param name="weight">New end weight.</param>
public void SetEndWeight(Weight weight)
{
this.Data.EndWeight = weight;
}
/// <summary>
/// Gets a value indicating whether the ending weight of this state is greater than zero.
/// </summary>
public bool CanEnd
{
get { return !this.EndWeight.IsZero; }
}
public bool CanEnd => this.Data.CanEnd;
/// <summary>
/// Gets the number of outgoing transitions.
/// </summary>
public int TransitionCount
{
get { return this.transitionCount; }
}
public int TransitionCount => this.Data.TransitionCount;
/// <summary>
/// Creates the copy of the array of outgoing transitions. Used by quoting.
/// </summary>
/// <returns>The copy of the array of outgoing transitions.</returns>
public Transition[] GetTransitions()
{
var result = new Transition[this.transitionCount];
Array.Copy(this.transitions, result, this.transitionCount);
return result;
}
public Transition[] GetTransitions() => this.Data.GetTransitions();
/// <summary>
/// Compares 2 states for equality.
/// </summary>
public static bool operator ==(State a, State b) => a.Data == b.Data;
/// <summary>
/// Compares 2 states for inequality.
/// </summary>
public static bool operator !=(State a, State b) => !(a == b);
/// <summary>
/// Compares 2 states for equality.
/// </summary>
public bool Equals(State that) => this == that;
/// <summary>
/// Compares 2 states for equality.
/// </summary>
public override bool Equals(object obj) => obj is State that && this.Equals(that);
/// <summary>
/// Returns HashCode of this state.
/// </summary>
public override int GetHashCode() => this.Data?.GetHashCode() ?? 0;
/// <summary>
/// Adds a series of transitions labeled with the elements of a given sequence to the current state,
@ -152,16 +155,19 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// </param>
/// <param name="group">The group of the added transitions.</param>
/// <returns>The last state in the added transition series.</returns>
public State AddTransitionsForSequence(TSequence sequence, State destinationState = null, int group = 0)
public State AddTransitionsForSequence(TSequence sequence, State destinationState = default(State), int group = 0)
{
State currentState = this;
IEnumerator<TElement> enumerator = sequence.GetEnumerator();
bool moveNext = enumerator.MoveNext();
while (moveNext)
var currentState = this;
using (var enumerator = sequence.GetEnumerator())
{
TElement element = enumerator.Current;
moveNext = enumerator.MoveNext();
currentState = currentState.AddTransition(element, Weight.One, moveNext ? null : destinationState, group);
var moveNext = enumerator.MoveNext();
while (moveNext)
{
var element = enumerator.Current;
moveNext = enumerator.MoveNext();
currentState = currentState.AddTransition(
element, Weight.One, moveNext ? default(State) : destinationState, group);
}
}
return currentState;
@ -177,7 +183,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// If the value of this parameter is <see langword="null"/>, a new state will be created.</param>
/// <param name="group">The group of the added transition.</param>
/// <returns>The destination state of the added transition.</returns>
public State AddTransition(TElement element, Weight weight, State destinationState = null, int group = 0)
public State AddTransition(TElement element, Weight weight, State destinationState = default(State), int group = 0)
{
return this.AddTransition(new TElementDistribution { Point = element }, weight, destinationState, group);
}
@ -191,7 +197,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// If the value of this parameter is <see langword="null"/>, a new state will be created.</param>
/// <param name="group">The group of the added transition.</param>
/// <returns>The destination state of the added transition.</returns>
public State AddEpsilonTransition(Weight weight, State destinationState = null, int group = 0)
public State AddEpsilonTransition(Weight weight, State destinationState = default(State), int group = 0)
{
return this.AddTransition(null, weight, destinationState, group);
}
@ -209,9 +215,9 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// If the value of this parameter is <see langword="null"/>, a new state will be created.</param>
/// <param name="group">The group of the added transition.</param>
/// <returns>The destination state of the added transition.</returns>
public State AddTransition(TElementDistribution elementDistribution, Weight weight, State destinationState = null, int group = 0)
public State AddTransition(TElementDistribution elementDistribution, Weight weight, State destinationState = default(State), int group = 0)
{
if (destinationState == null)
if (destinationState.IsNull)
{
destinationState = this.Owner.AddState();
}
@ -231,11 +237,15 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// <returns>The destination state of the added transition.</returns>
public State AddTransition(Transition transition)
{
Argument.CheckIfValid(this.Owner == null || transition.DestinationStateIndex < this.Owner.states.Count, "transition", "The destination state index is not valid.");
Argument.CheckIfValid(this.Owner == null || transition.DestinationStateIndex < this.Owner.statesData.Count, "transition", "The destination state index is not valid.");
this.Data.AddTransition(transition);
if (this.Owner.isEpsilonFree == true)
{
this.Owner.isEpsilonFree = !transition.IsEpsilon;
}
this.DoAddTransition(transition);
if (this.Owner.isEpsilonFree==true) this.Owner.isEpsilonFree = !transition.IsEpsilon;
return this.Owner.states[transition.DestinationStateIndex];
return this.Owner.States[transition.DestinationStateIndex];
}
/// <summary>
@ -270,11 +280,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// </summary>
/// <param name="index">The index of the transition.</param>
/// <returns>The transition.</returns>
public Transition GetTransition(int index)
{
//Argument.CheckIfInRange(index >= 0 && index < this.transitionCount, "index", "An invalid transition index given.");
return this.transitions[index];
}
public Transition GetTransition(int index) => this.Data.GetTransition(index);
/// <summary>
/// Replaces the transition at a given index with a given transition.
@ -283,28 +289,26 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// <param name="updatedTransition">The transition to replace with.</param>
public void SetTransition(int index, Transition updatedTransition)
{
Argument.CheckIfInRange(index >= 0 && index < this.transitionCount, "index", "An invalid transition index given.");
Argument.CheckIfValid(updatedTransition.DestinationStateIndex < this.Owner.states.Count, "updatedTransition", "The destination state index is not valid.");
Argument.CheckIfInRange(index >= 0 && index < this.TransitionCount, "index", "An invalid transition index given.");
Argument.CheckIfValid(updatedTransition.DestinationStateIndex < this.Owner.statesData.Count, "updatedTransition", "The destination state index is not valid.");
if (updatedTransition.IsEpsilon) {
if (updatedTransition.IsEpsilon)
{
this.Owner.isEpsilonFree = false;
}
else
{
this.Owner.isEpsilonFree = null;
}
this.transitions[index] = updatedTransition;
this.Data.SetTransition(index, updatedTransition);
}
/// <summary>
/// Removes the transition with a given index.
/// </summary>
/// <param name="index">The index of the transition to remove.</param>
public void RemoveTransition(int index)
{
Argument.CheckIfInRange(index >= 0 && index < this.transitionCount, "index", "An invalid transition index given.");
this.transitions[index] = this.transitions[--this.transitionCount];
}
public void RemoveTransition(int index) => this.Data.RemoveTransition(index);
/// <summary>
/// Returns a string that represents the state.
@ -314,7 +318,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
{
const string StartStateMarker = "START ->";
const string TransitionSeparator = ",";
var sb = new StringBuilder();
bool isStartState = this.Owner != null && this.Owner.Start == this;
@ -341,7 +345,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
if (CanEnd)
{
if (!firstTransition) sb.Append(TransitionSeparator);
sb.Append(this.EndWeight.Value+" -> END");
sb.Append(this.EndWeight.Value + " -> END");
}
return sb.ToString();
@ -389,22 +393,6 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
#region Helpers
/// <summary>
/// Adds a given transition to the transition array, increasing the size of the array if necessary.
/// </summary>
/// <param name="transition">The transition to add.</param>
private void DoAddTransition(Transition transition)
{
if (this.transitionCount == this.transitions.Length)
{
var newTransitions = new Transition[this.transitionCount * 2];
Array.Copy(this.transitions, newTransitions, this.transitionCount);
this.transitions = newTransitions;
}
this.transitions[this.transitionCount++] = transition;
}
/// <summary>
/// Recursively checks if the automaton has non-trivial loops
/// (i.e. loops consisting of more than one transition).
@ -426,11 +414,12 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
stateInStack.Add(this.Index, true);
for (int i = 0; i < this.transitionCount; ++i)
for (int i = 0; i < this.TransitionCount; ++i)
{
if (this.transitions[i].DestinationStateIndex != this.Index)
var transition = this.GetTransition(i);
if (transition.DestinationStateIndex != this.Index)
{
State destState = this.Owner.States[this.transitions[i].DestinationStateIndex];
var destState = this.Owner.States[transition.DestinationStateIndex];
if (destState.DoHasNonTrivialLoops(stateInStack))
{
return true;
@ -459,13 +448,14 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
visitedStates[this.Index] = true;
bool isZero = !this.CanEnd;
int transitionIndex = 0;
while (isZero && transitionIndex < this.transitionCount)
var isZero = !this.CanEnd;
var transitionIndex = 0;
while (isZero && transitionIndex < this.TransitionCount)
{
if (!this.transitions[transitionIndex].Weight.IsZero)
var transition = this.GetTransition(transitionIndex);
if (!transition.Weight.IsZero)
{
State destState = this.Owner.States[this.transitions[transitionIndex].DestinationStateIndex];
var destState = this.Owner.States[transition.DestinationStateIndex];
isZero = destState.DoIsZero(visitedStates);
}
@ -486,43 +476,42 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
TSequence sequence, int sequencePosition, Dictionary<(int, int), Weight> valueCache)
{
var stateIndexPair = (this.Index, sequencePosition);
Weight cachedValue;
if (valueCache.TryGetValue(stateIndexPair, out cachedValue))
if (valueCache.TryGetValue(stateIndexPair, out var cachedValue))
{
return cachedValue;
}
EpsilonClosure closure = this.GetEpsilonClosure();
var closure = this.GetEpsilonClosure();
Weight value = Weight.Zero;
int count = Automaton<TSequence, TElement, TElementDistribution, TSequenceManipulator, TThis>.SequenceManipulator.GetLength(sequence);
bool isCurrent = sequencePosition < count;
var value = Weight.Zero;
var count = SequenceManipulator.GetLength(sequence);
var isCurrent = sequencePosition < count;
if (isCurrent)
{
TElement element = Automaton<TSequence, TElement, TElementDistribution, TSequenceManipulator, TThis>.SequenceManipulator.GetElement(sequence, sequencePosition);
for (int closureStateIndex = 0; closureStateIndex < closure.Size; ++closureStateIndex)
var element = SequenceManipulator.GetElement(sequence, sequencePosition);
for (var closureStateIndex = 0; closureStateIndex < closure.Size; ++closureStateIndex)
{
State closureState = closure.GetStateByIndex(closureStateIndex);
Weight closureStateWeight = closure.GetStateWeightByIndex(closureStateIndex);
var closureState = closure.GetStateByIndex(closureStateIndex);
var closureStateWeight = closure.GetStateWeightByIndex(closureStateIndex);
for (int transitionIndex = 0; transitionIndex < closureState.transitionCount; transitionIndex++)
for (int transitionIndex = 0; transitionIndex < closureState.TransitionCount; transitionIndex++)
{
Transition transition = closureState.transitions[transitionIndex];
var transition = closureState.GetTransition(transitionIndex);
if (transition.IsEpsilon)
{
continue; // The destination is a part of the closure anyway
}
State destState = this.Owner.states[transition.DestinationStateIndex];
Weight distWeight = Weight.FromLogValue(transition.ElementDistribution.GetLogProb(element));
var destState = this.Owner.States[transition.DestinationStateIndex];
var distWeight = Weight.FromLogValue(transition.ElementDistribution.GetLogProb(element));
if (!distWeight.IsZero && !transition.Weight.IsZero)
{
Weight destValue = destState.DoGetValue(sequence, sequencePosition + 1, valueCache);
var destValue = destState.DoGetValue(sequence, sequencePosition + 1, valueCache);
if (!destValue.IsZero)
{
value = Weight.Sum(
value,
Weight.Product(closureStateWeight, transition.Weight, distWeight, destValue));
Weight.Product(closureStateWeight, transition.Weight, distWeight, destValue));
}
}
}
@ -546,15 +535,10 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
{
get
{
foreach (var state in Owner.States)
{
var trans = state.GetTransitions();
for (int i = 0; i < trans.Length; i++)
{
if (trans[i].DestinationStateIndex == Index) return true;
}
}
return false;
var this_ = this;
return this.Owner.States.Any(
state => state.GetTransitions().Any(
transition => transition.DestinationStateIndex == this_.Index));
}
}
@ -566,10 +550,10 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
{
this.EndWeight.Write(writeDouble);
writeInt32(this.Index);
writeInt32(this.transitionCount);
for (var i = 0; i < transitionCount; i++)
writeInt32(this.TransitionCount);
for (var i = 0; i < TransitionCount; i++)
{
transitions[i].Write(writeInt32, writeDouble, writeElementDistribution);
GetTransition(i).Write(writeInt32, writeDouble, writeElementDistribution);
}
}
@ -578,27 +562,17 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// </summary>
public static State Read(Func<int> readInt32, Func<double> readDouble, Func<TElementDistribution> readElementDistribution)
{
var res = new State();
res.EndWeight = Weight.Read(readDouble);
res.Index = readInt32();
var endWeight = Weight.Read(readDouble);
// Note: index is serialized for compatibility with old binary serializations
var index = readInt32();
var transitionCount = readInt32();
var transitionLength = res.transitions.Length;
while (transitionLength < transitionCount)
{
transitionLength <<= 1;
}
var transitions = transitionLength == res.transitions.Length ? res.transitions : new Transition[transitionLength];
var transitions = new Transition[transitionCount];
for (var i = 0; i < transitionCount; i++)
{
transitions[i] = Transition.Read(readInt32, readDouble, readElementDistribution);
}
res.transitionCount = transitionCount;
res.transitions = transitions;
return res;
return new State(index, transitions, endWeight);
}
}
}

Просмотреть файл

@ -0,0 +1,75 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
namespace Microsoft.ML.Probabilistic.Distributions.Automata
{
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.Probabilistic.Distributions;
using Microsoft.ML.Probabilistic.Math;
public abstract partial class Automaton<TSequence, TElement, TElementDistribution, TSequenceManipulator, TThis>
where TSequence : class, IEnumerable<TElement>
where TElementDistribution : class, IDistribution<TElement>, SettableToProduct<TElementDistribution>, SettableToWeightedSumExact<TElementDistribution>, CanGetLogAverageOf<TElementDistribution>, SettableToPartialUniform<TElementDistribution>, new()
where TSequenceManipulator : ISequenceManipulator<TSequence, TElement>, new()
where TThis : Automaton<TSequence, TElement, TElementDistribution, TSequenceManipulator, TThis>, new()
{
/// <summary>
/// Represents a collection of automaton states for use in public APIs.
/// </summary>
/// <remarks>
/// Is a thin wrapper around the automaton's <c>statesData</c> list. Each <see cref="StateData"/>
/// is wrapped into a <see cref="State"/> on demand; the collection itself stores only the owner
/// and the list of <see cref="StateData"/>.
/// </remarks>
public struct StateCollection : IReadOnlyList<State>
{
    /// <summary>
    /// Owner automaton of all states in collection.
    /// </summary>
    private readonly Automaton<TSequence, TElement, TElementDistribution, TSequenceManipulator, TThis> owner;

    /// <summary>
    /// The list of state data wrapped by this collection. Kept in a field so that
    /// element access does not have to go through <see cref="owner"/>.
    /// </summary>
    private readonly List<StateData> statesData;

    /// <summary>
    /// Initializes instance of <see cref="StateCollection"/>.
    /// </summary>
    /// <param name="owner">The automaton that every produced <see cref="State"/> will reference as its owner.</param>
    /// <param name="states">The list of state data to expose as <see cref="State"/> values.</param>
    internal StateCollection(
        Automaton<TSequence, TElement, TElementDistribution, TSequenceManipulator, TThis> owner,
        List<StateData> states)
    {
        this.owner = owner;

        // Use the list that was actually passed in; previously the parameter was ignored
        // and the list was re-read from the owner.
        this.statesData = states;
    }

    /// <summary>
    /// Gets state by its index.
    /// </summary>
    /// <param name="index">The index of the state in the underlying list.</param>
    /// <returns>A <see cref="State"/> wrapping the state data stored at <paramref name="index"/>.</returns>
    public State this[int index] => new State(this.owner, index, this.statesData[index]);

    /// <summary>
    /// Gets number of states in collection.
    /// </summary>
    public int Count => this.statesData.Count;

    /// <summary>
    /// Returns enumerator over all states in collection.
    /// </summary>
    /// <returns>An enumerator that wraps each stored <see cref="StateData"/> into a <see cref="State"/>.</returns>
    public IEnumerator<State> GetEnumerator()
    {
        // A lambda inside a struct cannot capture 'this', so copy the owner into a local first.
        var owner = this.owner;
        return this.statesData.Select((data, index) => new State(owner, index, data)).GetEnumerator();
    }

    /// <summary>
    /// Returns enumerator over all states in collection.
    /// </summary>
    /// <returns>The non-generic view of <see cref="GetEnumerator"/>.</returns>
    IEnumerator IEnumerable.GetEnumerator()
    {
        return this.GetEnumerator();
    }
}
}
}

Просмотреть файл

@ -0,0 +1,149 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
namespace Microsoft.ML.Probabilistic.Distributions.Automata
{
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Runtime.Serialization;
using Microsoft.ML.Probabilistic.Distributions;
using Microsoft.ML.Probabilistic.Math;
using Microsoft.ML.Probabilistic.Serialization;
using Microsoft.ML.Probabilistic.Utilities;
public abstract partial class Automaton<TSequence, TElement, TElementDistribution, TSequenceManipulator, TThis>
where TSequence : class, IEnumerable<TElement>
where TElementDistribution : class, IDistribution<TElement>, SettableToProduct<TElementDistribution>, SettableToWeightedSumExact<TElementDistribution>, CanGetLogAverageOf<TElementDistribution>, SettableToPartialUniform<TElementDistribution>, new()
where TSequenceManipulator : ISequenceManipulator<TSequence, TElement>, new()
where TThis : Automaton<TSequence, TElement, TElementDistribution, TSequenceManipulator, TThis>, new()
{
/// <summary>
/// Represents a state of an automaton that is stored in the Automaton.statesData. This is an internal representation
/// of the state. <see cref="State"/> struct should be used in public APIs.
/// </summary>
/// <remarks>
/// Holds only the ending weight and the outgoing transitions; it has no reference to an owner
/// automaton and does not know its own index.
/// </remarks>
[Serializable]
[DataContract]
internal class StateData
{
    /// <summary>
    /// The default capacity of the <see cref="transitions"/>.
    /// </summary>
    private const int DefaultTransitionArrayCapacity = 1;

    /// <summary>
    /// The array of outgoing transitions. Only the first <see cref="transitionCount"/> elements are in use.
    /// </summary>
    /// <remarks>
    /// We don't use <see cref="List{T}"/> here for performance reasons.
    /// </remarks>
    [DataMember]
    private Transition[] transitions = new Transition[DefaultTransitionArrayCapacity];

    /// <summary>
    /// The number of outgoing transitions from the state.
    /// </summary>
    [DataMember]
    private int transitionCount;

    /// <summary>
    /// Initializes a new instance of the <see cref="StateData"/> class
    /// with no transitions and a zero ending weight.
    /// </summary>
    public StateData() => this.EndWeight = Weight.Zero;

    /// <summary>
    /// Initializes a new instance of the <see cref="StateData"/> class.
    /// </summary>
    /// <param name="transitions">The outgoing transitions.</param>
    /// <param name="endWeight">The ending weight of the state.</param>
    [Construction("GetTransitions", "EndWeight")]
    public StateData(IEnumerable<Transition> transitions, Weight endWeight)
        : this()
    {
        Argument.CheckIfNotNull(transitions, nameof(transitions));

        this.EndWeight = endWeight;

        foreach (var transition in transitions)
        {
            this.AddTransition(transition);
        }
    }

    /// <summary>
    /// Gets or sets the ending weight of the state.
    /// </summary>
    [DataMember]
    public Weight EndWeight { get; set; }

    /// <summary>
    /// Gets a value indicating whether the ending weight of this state is greater than zero.
    /// </summary>
    public bool CanEnd => !this.EndWeight.IsZero;

    /// <summary>
    /// Gets the number of outgoing transitions.
    /// </summary>
    public int TransitionCount => this.transitionCount;

    /// <summary>
    /// Creates the copy of the array of outgoing transitions. Used by quoting.
    /// </summary>
    /// <returns>The copy of the array of outgoing transitions.</returns>
    public Transition[] GetTransitions()
    {
        var result = new Transition[this.transitionCount];
        Array.Copy(this.transitions, result, this.transitionCount);
        return result;
    }

    /// <summary>
    /// Adds a transition to the current state, growing the backing array if it is full.
    /// </summary>
    /// <param name="transition">The transition to add.</param>
    public void AddTransition(Transition transition)
    {
        if (this.transitionCount == this.transitions.Length)
        {
            // Doubling an empty array would give another empty array and the store below would throw.
            // NOTE(review): an empty backing array presumably can only come from deserialized data,
            // since the field initializer allocates DefaultTransitionArrayCapacity slots - confirm.
            var newCapacity = this.transitionCount == 0
                ? DefaultTransitionArrayCapacity
                : this.transitionCount * 2;

            var newTransitions = new Transition[newCapacity];
            Array.Copy(this.transitions, newTransitions, this.transitionCount);
            this.transitions = newTransitions;
        }

        this.transitions[this.transitionCount++] = transition;
    }

    /// <summary>
    /// Gets the transition at a specified index.
    /// </summary>
    /// <param name="index">The index of the transition.</param>
    /// <returns>The transition.</returns>
    public Transition GetTransition(int index)
    {
        // The range is only asserted (debug builds); release builds rely on the array bounds check.
        Debug.Assert(index >= 0 && index < this.transitionCount, nameof(index), "An invalid transition index given.");
        return this.transitions[index];
    }

    /// <summary>
    /// Replaces the transition at a given index with a given transition.
    /// </summary>
    /// <param name="index">The index of the transition to replace.</param>
    /// <param name="updatedTransition">The transition to replace with.</param>
    public void SetTransition(int index, Transition updatedTransition)
    {
        // Same debug-only check as GetTransition: an index in [transitionCount, Length) would
        // otherwise silently write into an unused slot of the backing array.
        Debug.Assert(index >= 0 && index < this.transitionCount, nameof(index), "An invalid transition index given.");
        this.transitions[index] = updatedTransition;
    }

    /// <summary>
    /// Removes the transition with a given index. The last transition is moved into the freed
    /// position, so the order of the remaining transitions is not preserved.
    /// </summary>
    /// <param name="index">The index of the transition to remove.</param>
    public void RemoveTransition(int index)
    {
        Argument.CheckIfInRange(index >= 0 && index < this.transitionCount, "index", "An invalid transition index given.");

        var lastIndex = --this.transitionCount;
        this.transitions[index] = this.transitions[lastIndex];

        // Reset the vacated slot so that the unused tail of the array does not keep
        // a stale copy of the moved transition.
        this.transitions[lastIndex] = default(Transition);
    }
}
}
}

Просмотреть файл

@ -114,7 +114,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// </returns>
public bool HasState(State state)
{
Argument.CheckIfNotNull(state, "state");
Argument.CheckIfValid(!state.IsNull, nameof(state));
return this.GetIndexByState(state) != -1;
}
@ -128,12 +128,12 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// </returns>
public int GetIndexByState(State state)
{
Argument.CheckIfNotNull(state, "state");
Argument.CheckIfValid(!state.IsNull, nameof(state));
Argument.CheckIfValid(ReferenceEquals(state.Owner, this.statesInComponent[0].Owner), "state", "The given state belongs to other automaton.");
if (this.statesInComponent.Count == 1)
{
return ReferenceEquals(this.statesInComponent[0], state) ? 0 : -1;
return this.statesInComponent[0].Index == state.Index ? 0 : -1;
}
if (this.stateIdToIndexInComponent == null)

Просмотреть файл

@ -88,18 +88,13 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// The collection of states.
/// </summary>
[DataMember]
private List<State> states = new List<State>();
/// <summary>
/// A read-only wrapper around the <see cref="states"/>.
/// </summary>
private ReadOnlyList<State> statesReadOnly;
private List<StateData> statesData = new List<StateData>();
/// <summary>
/// The start state.
/// </summary>
[DataMember]
private State startState;
private int startStateIndex;
/// <summary>
/// Whether the automaton is free of epsilon transition.
@ -127,7 +122,6 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
protected Automaton()
{
// Zero by default
this.statesReadOnly = new ReadOnlyList<State>(this.states);
this.SetToZero();
}
@ -179,14 +173,11 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// </summary>
public static int MaxStateCount
{
get
{
return maxStateCount;
}
get => maxStateCount;
set
{
Argument.CheckIfInRange(value > 0, "value", "The maximum number of states must be positive.");
Argument.CheckIfInRange(value > 0, nameof(value), "The maximum number of states must be positive.");
maxStateCount = value;
}
}
@ -197,14 +188,11 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// </summary>
public static int MaxStateCountBeforeSimplification
{
get
{
return maxStateCountBeforeSimplification;
}
get => maxStateCountBeforeSimplification;
set
{
Argument.CheckIfInRange(value > 0, "value", "The maximum number of states before simplification must be positive.");
Argument.CheckIfInRange(value > 0, nameof(value), "The maximum number of states before simplification must be positive.");
maxStateCountBeforeSimplification = value;
}
}
@ -215,14 +203,11 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// </summary>
public static int MaxDeadStateCount
{
get
{
return maxDeadStateCount;
}
get => maxDeadStateCount;
set
{
Argument.CheckIfInRange(value >= 0, "value", "The maximum number of dead states should be non-negative.");
Argument.CheckIfInRange(value >= 0, nameof(value), "The maximum number of dead states should be non-negative.");
maxDeadStateCount = value;
}
}
@ -230,13 +215,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// <summary>
/// Gets the collection of the states of the automaton.
/// </summary>
public ReadOnlyList<State> States
{
get
{
return this.statesReadOnly;
}
}
public StateCollection States => new StateCollection(this, this.statesData);
/// <summary>
/// Gets or sets the start state of the automaton.
@ -246,16 +225,13 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// </remarks>
public State Start
{
get
{
return this.startState;
}
get => new State(this, this.startStateIndex, this.statesData[this.startStateIndex]);
set
{
Argument.CheckIfNotNull(value, "value");
Argument.CheckIfValid(ReferenceEquals(value.Owner, this), "value", "The given state does not belong to this automaton.");
this.startState = value;
Argument.CheckIfValid(!value.IsNull, nameof(value));
Argument.CheckIfValid(ReferenceEquals(value.Owner, this), nameof(value), "The given state does not belong to this automaton.");
this.startStateIndex = value.Index;
}
}
@ -277,13 +253,13 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
[Construction("GetStates", "Start")]
public static TThis FromStates(IEnumerable<State> states, State startState)
{
Argument.CheckIfNotNull(states, "states");
Argument.CheckIfNotNull(startState, "startState");
Argument.CheckIfNotNull(states, nameof(states));
Argument.CheckIfValid(!startState.IsNull, nameof(startState));
CheckStateConsistency(states, startState);
var result = new TThis();
result.SetStates(states, startState.Index);
result.SetStates(states.Select(state => state.Data), startState.Index);
return result;
}
@ -380,7 +356,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
{
allowedElements = Distribution.CreatePartialUniform(allowedElements);
State finish = result.Start.AddTransition(allowedElements, Weight.FromLogValue(-allowedElements.GetLogAverageOf(allowedElements)));
finish.EndWeight = Weight.FromLogValue(logValue);
finish.SetEndWeight(Weight.FromLogValue(logValue));
}
return result;
@ -433,7 +409,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
foreach (TSequence sequence in sequences)
{
State sequenceEndState = result.Start.AddTransitionsForSequence(sequence);
sequenceEndState.EndWeight = Weight.FromLogValue(logValue);
sequenceEndState.SetEndWeight(Weight.FromLogValue(logValue));
}
}
@ -503,7 +479,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
}
TThis result = Zero();
result.startState.EndWeight = Weight.FromLogValue(Math.Log(value));
result.Start.SetEndWeight(Weight.FromLogValue(Math.Log(value)));
return result;
}
@ -607,9 +583,9 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
continue;
}
int index = result.states.Count;
result.AddStates(automaton.states);
result.Start.AddEpsilonTransition(Weight.One, result.states[index + automaton.Start.Index]);
int index = result.statesData.Count;
result.AddStates(automaton.statesData);
result.Start.AddEpsilonTransition(Weight.One, result.States[index + automaton.Start.Index]);
}
return result;
@ -699,20 +675,20 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
TThis result = ConstantOn(1.0, SequenceManipulator.ToSequence(new TElement[0]));
// Build a list of all intermediate end states with their target ending weights while adding repetitions
var endStatesWithTargetWeights = new List<Pair<State, Weight>>();
var endStatesWithTargetWeights = new List<(State, Weight)>();
int prevStateCount = 0;
for (int i = 0; i <= maxTimes; ++i)
{
// Remember added ending states
if (repetitionNumberWeights[i] > 0)
{
for (int j = prevStateCount; j < result.states.Count; ++j)
for (int j = prevStateCount; j < result.statesData.Count; ++j)
{
if (result.states[j].CanEnd)
if (result.statesData[j].CanEnd)
{
endStatesWithTargetWeights.Add(Pair.Create(
result.states[j],
Weight.Product(Weight.FromValue(repetitionNumberWeights[i]), result.states[j].EndWeight)));
endStatesWithTargetWeights.Add(ValueTuple.Create(
result.States[j],
Weight.Product(Weight.FromValue(repetitionNumberWeights[i]), result.statesData[j].EndWeight)));
}
}
}
@ -720,7 +696,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
// Add one more repetition
if (i != maxTimes)
{
prevStateCount = result.States.Count;
prevStateCount = result.statesData.Count;
result.AppendInPlaceNoOptimizations(automaton);
}
}
@ -728,7 +704,8 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
// Set target ending weights
for (int i = 0; i < endStatesWithTargetWeights.Count; ++i)
{
endStatesWithTargetWeights[i].First.EndWeight = endStatesWithTargetWeights[i].Second;
var (state, weight) = endStatesWithTargetWeights[i];
state.SetEndWeight(weight);
}
return result;
@ -774,11 +751,11 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
}
TThis optionalPart = automaton.Clone();
for (int i = 0; i < optionalPart.states.Count; ++i)
for (int i = 0; i < optionalPart.statesData.Count; ++i)
{
if (optionalPart.states[i].CanEnd)
if (optionalPart.statesData[i].CanEnd)
{
optionalPart.states[i].AddEpsilonTransition(optionalPart.states[i].EndWeight, optionalPart.Start);
optionalPart.States[i].AddEpsilonTransition(optionalPart.States[i].EndWeight, optionalPart.Start);
}
}
@ -829,7 +806,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
else
{
StringBuilder builder = new StringBuilder();
this.AppendString(builder, new HashSet<int>(), this.startState.Index, appendElement);
this.AppendString(builder, new HashSet<int>(), this.Start.Index, appendElement);
return builder.ToString();
}
}
@ -857,9 +834,9 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// <returns>True if the automaton has this group, false otherwise.</returns>
public bool HasGroup(int group)
{
for (int stateIndex = 0; stateIndex < this.states.Count; stateIndex++)
for (int stateIndex = 0; stateIndex < this.statesData.Count; stateIndex++)
{
State state = this.states[stateIndex];
var state = this.statesData[stateIndex];
for (int transitionIndex = 0; transitionIndex < state.TransitionCount; transitionIndex++)
{
Transition transition = state.GetTransition(transitionIndex);
@ -879,9 +856,9 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// <returns>True if the automaton has groups, false otherwise.</returns>
public bool UsesGroups()
{
for (int stateIndex = 0; stateIndex < this.states.Count; stateIndex++)
for (int stateIndex = 0; stateIndex < this.statesData.Count; stateIndex++)
{
State state = this.states[stateIndex];
var state = this.statesData[stateIndex];
for (int transitionIndex = 0; transitionIndex < state.TransitionCount; transitionIndex++)
{
Transition transition = state.GetTransition(transitionIndex);
@ -911,9 +888,9 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// <param name="group">The specified group.</param>
public void SetGroup(int group)
{
for (int stateIndex = 0; stateIndex < this.states.Count; stateIndex++)
for (int stateIndex = 0; stateIndex < this.statesData.Count; stateIndex++)
{
State state = this.states[stateIndex];
var state = this.statesData[stateIndex];
for (int transitionIndex = 0; transitionIndex < state.TransitionCount; transitionIndex++)
{
Transition transition = state.GetTransition(transitionIndex);
@ -1027,7 +1004,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// </remarks>
public bool IsCanonicConstant()
{
if (this.States.Count != 1)
if (this.statesData.Count != 1)
{
return false;
}
@ -1098,13 +1075,13 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
throw new NotImplementedException("Not yet supported for non-determinizable automata.");
}
for (int stateId = 0; stateId < result.states.Count; ++stateId)
for (int stateId = 0; stateId < result.States.Count; ++stateId)
{
var state = result.states[stateId];
var state = result.States[stateId];
if (state.CanEnd)
{
// Make all accepting states contribute the desired value to the result
state.EndWeight = value;
state.SetEndWeight(value);
}
for (int transitionIndex = 0; transitionIndex < state.TransitionCount; ++transitionIndex)
@ -1152,29 +1129,29 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
var result = Zero();
// Result already has 1 state, we add the remaining Count-1 states
result.AddStates(this.states.Count - 1);
result.AddStates(this.statesData.Count - 1);
// And the new start state
result.Start = result.AddState();
// The start state in the original automaton is going to be the one and only end state in result
result.states[this.Start.Index].EndWeight = Weight.One;
result.States[this.Start.Index].SetEndWeight(Weight.One);
for (int i = 0; i < this.states.Count; ++i)
for (int i = 0; i < this.statesData.Count; ++i)
{
var oldState = this.states[i];
for (int j = 0; j < this.states[i].TransitionCount; ++j)
var oldState = this.statesData[i];
for (int j = 0; j < this.statesData[i].TransitionCount; ++j)
{
// Result has original transitions reversed
var oldTransition = oldState.GetTransition(j);
result.states[oldTransition.DestinationStateIndex].AddTransition(
oldTransition.ElementDistribution, oldTransition.Weight, result.states[i]);
result.States[oldTransition.DestinationStateIndex].AddTransition(
oldTransition.ElementDistribution, oldTransition.Weight, result.States[i]);
}
// End states of the original automaton are the new start states
if (oldState.CanEnd)
{
result.Start.AddEpsilonTransition(oldState.EndWeight, result.states[i]);
result.Start.AddEpsilonTransition(oldState.EndWeight, result.States[i]);
}
}
@ -1247,11 +1224,11 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
}
// Append the states of the second automaton
List<State> endStates = this.states.Where(nd => nd.CanEnd).ToList();
int stateCount = this.states.Count;
var endStates = this.States.Where(nd => nd.CanEnd).ToList();
int stateCount = this.statesData.Count;
this.AddStates(automaton.states, group);
State secondStartState = this.states[stateCount + automaton.Start.Index];
this.AddStates(automaton.statesData, group);
var secondStartState = this.States[stateCount + automaton.Start.Index];
// todo: make efficient
bool startIncoming = automaton.Start.HasIncomingTransitions;
@ -1277,10 +1254,10 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
transition.Weight = Weight.Product(transition.Weight, endState.EndWeight);
}
endState.AddTransition(transition);
endState.Data.AddTransition(transition);
}
endState.EndWeight = Weight.Product(endState.EndWeight, secondStartState.EndWeight);
endState.SetEndWeight(Weight.Product(endState.EndWeight, secondStartState.EndWeight));
}
this.RemoveState(secondStartState.Index);
@ -1291,7 +1268,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
{
State state = endStates[i];
state.AddEpsilonTransition(state.EndWeight, secondStartState, group);
state.EndWeight = Weight.Zero;
state.SetEndWeight(Weight.Zero);
}
}
@ -1482,15 +1459,15 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
{
if (hasFirstTerm)
{
result.AddStates(automaton1.states);
result.Start.AddEpsilonTransition(Weight.FromLogValue(logWeight1), result.states[1 + automaton1.Start.Index]);
result.AddStates(automaton1.statesData);
result.Start.AddEpsilonTransition(Weight.FromLogValue(logWeight1), result.States[1 + automaton1.Start.Index]);
}
if (hasSecondTerm)
{
int cnt = result.states.Count;
result.AddStates(automaton2.states);
result.Start.AddEpsilonTransition(Weight.FromLogValue(logWeight2), result.states[cnt + automaton2.Start.Index]);
int cnt = result.statesData.Count;
result.AddStates(automaton2.statesData);
result.Start.AddEpsilonTransition(Weight.FromLogValue(logWeight2), result.States[cnt + automaton2.Start.Index]);
}
}
@ -1556,8 +1533,8 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// </summary>
public void SetToZero()
{
this.states.Clear();
this.startState = this.AddState();
this.statesData.Clear();
this.startStateIndex = this.AddState().Index;
this.isEpsilonFree = true;
}
@ -1611,7 +1588,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
this.SetToZero();
if (!double.IsNegativeInfinity(logValue))
{
this.Start.EndWeight = Weight.FromLogValue(logValue);
this.Start.SetEndWeight(Weight.FromLogValue(logValue));
this.Start.AddTransition(allowedElements, Weight.FromLogValue(-allowedElements.GetLogAverageOf(allowedElements)), this.Start);
}
}
@ -1626,7 +1603,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
if (!ReferenceEquals(this, automaton))
{
this.SetStates(automaton.states, automaton.Start.Index);
this.SetStates(automaton.statesData, automaton.Start.Index);
this.isEpsilonFree = automaton.isEpsilonFree;
this.LogValueOverride = automaton.LogValueOverride;
this.PruneTransitionsWithLogWeightLessThan = automaton.PruneTransitionsWithLogWeightLessThan;
@ -1657,15 +1634,15 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
// Add states
this.SetToZero();
this.AddStates(sourceAutomaton.states.Count - 1);
this.AddStates(sourceAutomaton.statesData.Count - 1);
// Copy state parameters and transitions
for (int stateIndex = 0; stateIndex < sourceAutomaton.states.Count; stateIndex++)
for (int stateIndex = 0; stateIndex < sourceAutomaton.statesData.Count; stateIndex++)
{
State thisState = this.states[stateIndex];
var otherState = sourceAutomaton.states[stateIndex];
var thisState = this.States[stateIndex];
var otherState = sourceAutomaton.States[stateIndex];
thisState.EndWeight = otherState.EndWeight;
thisState.SetEndWeight(otherState.EndWeight);
if (otherState == sourceAutomaton.Start)
{
this.Start = thisState;
@ -1675,10 +1652,10 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
{
var otherTransition = otherState.GetTransition(transitionIndex);
var transformedTransition = transitionTransform(otherTransition.ElementDistribution, otherTransition.Weight, otherTransition.Group);
this.states[stateIndex].AddTransition(
this.States[stateIndex].AddTransition(
transformedTransition.Item1,
transformedTransition.Item2,
this.states[otherTransition.DestinationStateIndex],
this.States[otherTransition.DestinationStateIndex],
otherTransition.Group);
}
}
@ -1728,7 +1705,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
var point = new List<TElement>();
int? pointLength = null;
var stateDepth = new ArrayDictionary<int>(this.states.Count);
var stateDepth = new ArrayDictionary<int>(this.statesData.Count);
bool isPoint = this.TryComputePointDfs(this.Start, 0, stateDepth, endNodeReachability, point, ref pointLength);
return isPoint && pointLength.HasValue ? SequenceManipulator.ToSequence(point) : null;
}
@ -1848,7 +1825,9 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// <returns>The created state collection copy.</returns>
public State[] GetStates()
{
// Old form: built a new State per element from its index, transitions and end weight.
return this.states.Select(s => new State(s.Index, s.GetTransitions(), s.EndWeight)).ToArray();
// FIXME: the intended contract of this method still needs to be discussed.
// If callers need an independent copy, implement a real deep copy of the whole automaton here.
// New form: materialises the States collection into an array without copying per-state data.
// NOTE(review): the returned State values presumably still refer to this automaton's data,
// which would not match the "copy" wording in the method's documentation - confirm.
return this.States.ToArray();
}
/// <summary>
@ -1858,19 +1837,16 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// <remarks>Indices of the added states are guaranteed to be increasing consecutive.</remarks>
public State AddState()
{
// Refuse to grow past the configured limit. The check is shown twice: against the old
// `states` list, then against the `statesData` list that replaces it.
if (this.states.Count >= maxStateCount)
if (this.statesData.Count >= maxStateCount)
{
throw new AutomatonTooLargeException(MaxStateCount);
}
// Old construction path: the state object itself recorded its owner and its index.
var state = new State
{
Owner = (TThis)this,
Index = this.states.Count
};
this.states.Add(state);
// New construction path: only a StateData is appended to the list. The index is read
// before the append, so it equals the position of the new element.
var index = this.statesData.Count;
var stateImpl = new StateData();
this.statesData.Add(stateImpl);
return state;
// The State handed back is created on the spot from this automaton, the index and the data.
return new State(this, index, stateImpl);
}
/// <summary>
@ -1897,7 +1873,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
if (this.isEpsilonFree == null)
{
this.isEpsilonFree = true;
foreach (var state in this.states)
foreach (var state in this.statesData)
{
for (int i = 0; i < state.TransitionCount; i++)
{
@ -2020,7 +1996,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
{
TThis automaton = automata[automatonIndex];
for (int stateIndex = 0; stateIndex < automaton.states.Count; ++stateIndex)
for (int stateIndex = 0; stateIndex < automaton.statesData.Count; ++stateIndex)
{
State state = automaton.States[stateIndex];
Weight transitionWeightSum = Weight.Zero;
@ -2049,7 +2025,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
Weight.FromLogValue(-maxLogTransitionWeightSum),
Weight.FromValue(0.99));
theConverger.Start.AddSelfTransition(uniformDist, transitionWeight);
theConverger.Start.EndWeight = Weight.One;
theConverger.Start.SetEndWeight(Weight.One);
return theConverger;
}
@ -2125,17 +2101,17 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
automaton = automaton.Clone();
}
int stateCount = this.states.Count;
List<State> endStates = this.states.Where(nd => nd.CanEnd).ToList();
int stateCount = this.statesData.Count;
var endStates = this.States.Where(nd => nd.CanEnd).ToList();
this.AddStates(automaton.states);
State secondStartState = this.states[stateCount + automaton.Start.Index];
this.AddStates(automaton.statesData);
var secondStartState = this.States[stateCount + automaton.Start.Index];
for (int i = 0; i < endStates.Count; i++)
{
State state = endStates[i];
var state = endStates[i];
state.AddEpsilonTransition(state.EndWeight, secondStartState);
state.EndWeight = Weight.Zero;
state.SetEndWeight(Weight.Zero);
}
}
@ -2147,10 +2123,10 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
{
//// First, build a reversed graph
int[] edgePlacementIndices = new int[this.states.Count + 1];
for (int i = 0; i < this.states.Count; ++i)
int[] edgePlacementIndices = new int[this.statesData.Count + 1];
for (int i = 0; i < this.statesData.Count; ++i)
{
State state = this.states[i];
var state = this.statesData[i];
for (int j = 0; j < state.TransitionCount; ++j)
{
var transition = state.GetTransition(j);
@ -2170,11 +2146,11 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
}
int[] edgeArrayStarts = (int[])edgePlacementIndices.Clone();
int totalEdgeCount = edgePlacementIndices[this.states.Count];
int totalEdgeCount = edgePlacementIndices[this.statesData.Count];
int[] edgeDestinationIndices = new int[totalEdgeCount];
for (int i = 0; i < this.states.Count; ++i)
for (int i = 0; i < this.statesData.Count; ++i)
{
State state = this.states[i];
var state = this.statesData[i];
for (int j = 0; j < state.TransitionCount; ++j)
{
var transition = state.GetTransition(j);
@ -2191,10 +2167,10 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
//// Now run a depth-first search to label all reachable nodes
bool[] visitedNodes = new bool[this.states.Count];
for (int i = 0; i < this.states.Count; ++i)
bool[] visitedNodes = new bool[this.statesData.Count];
for (int i = 0; i < this.statesData.Count; ++i)
{
if (!visitedNodes[i] && this.states[i].CanEnd)
if (!visitedNodes[i] && this.statesData[i].CanEnd)
{
LabelReachableNodesDfs(i, visitedNodes, edgeDestinationIndices, edgeArrayStarts);
}
@ -2211,10 +2187,10 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
{
//// First, build a reversed graph
int[] edgePlacementIndices = new int[this.states.Count + 1];
for (int i = 0; i < this.states.Count; ++i)
int[] edgePlacementIndices = new int[this.statesData.Count + 1];
for (int i = 0; i < this.statesData.Count; ++i)
{
State state = this.states[i];
var state = this.statesData[i];
for (int j = 0; j < state.TransitionCount; ++j)
{
var transition = state.GetTransition(j);
@ -2234,11 +2210,11 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
}
int[] edgeArrayStarts = (int[])edgePlacementIndices.Clone();
int totalEdgeCount = edgePlacementIndices[this.states.Count];
int totalEdgeCount = edgePlacementIndices[this.statesData.Count];
int[] edgeDestinationIndices = new int[totalEdgeCount];
for (int i = 0; i < this.states.Count; ++i)
for (int i = 0; i < this.statesData.Count; ++i)
{
State state = this.states[i];
var state = this.statesData[i];
for (int j = 0; j < state.TransitionCount; ++j)
{
var transition = state.GetTransition(j);
@ -2254,8 +2230,8 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
}
//// Now run a depth-first search to label all reachable nodes
bool[] visitedNodes = new bool[this.states.Count];
LabelReachableNodesDfs(this.startState.Index, visitedNodes, edgeDestinationIndices, edgeArrayStarts);
bool[] visitedNodes = new bool[this.statesData.Count];
LabelReachableNodesDfs(this.Start.Index, visitedNodes, edgeDestinationIndices, edgeArrayStarts);
return visitedNodes;
}
@ -2292,9 +2268,9 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// <param name="condensation">A condensation of the automaton.</param>
private void PushWeights(Condensation condensation)
{
for (int i = 0; i < this.states.Count; ++i)
for (int i = 0; i < this.statesData.Count; ++i)
{
State state = this.states[i];
var state = this.States[i];
Weight weightToEnd = condensation.GetWeightToEnd(state);
if (weightToEnd.IsZero)
{
@ -2307,12 +2283,12 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
Transition transition = state.GetTransition(j);
transition.Weight = Weight.Product(
transition.Weight,
condensation.GetWeightToEnd(this.states[transition.DestinationStateIndex]),
condensation.GetWeightToEnd(this.States[transition.DestinationStateIndex]),
weightToEndInv);
state.SetTransition(j, transition);
}
state.EndWeight = Weight.Product(state.EndWeight, weightToEndInv);
state.SetEndWeight(Weight.Product(state.EndWeight, weightToEndInv));
}
}
@ -2370,7 +2346,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
{
Transition transition = currentState.GetTransition(i);
State destState = this.states[transition.DestinationStateIndex];
State destState = this.States[transition.DestinationStateIndex];
if (!isEndNodeReachable[destState.Index])
{
continue; // Only walk through the accepting part of the automaton
@ -2448,7 +2424,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
for (int transition1Index = 0; transition1Index < state1.TransitionCount; transition1Index++)
{
Transition transition1 = state1.GetTransition(transition1Index);
State destState1 = state1.Owner.states[transition1.DestinationStateIndex];
State destState1 = state1.Owner.States[transition1.DestinationStateIndex];
if (transition1.IsEpsilon)
{
@ -2463,7 +2439,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
{
Transition transition2 = state2.GetTransition(transition2Index);
Debug.Assert(!transition2.IsEpsilon, "The second argument of the product operation must be epsilon-free.");
State destState2 = state2.Owner.states[transition2.DestinationStateIndex];
State destState2 = state2.Owner.States[transition2.DestinationStateIndex];
TElementDistribution product;
double productLogNormalizer = Distribution<TElement>.GetLogAverageOf(
@ -2483,7 +2459,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
}
}
productState.EndWeight = Weight.Product(state1.EndWeight, state2.EndWeight);
productState.SetEndWeight(Weight.Product(state1.EndWeight, state2.EndWeight));
return productState;
}
@ -2510,7 +2486,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
oldToNewState.Add(state.Index, resultState);
EpsilonClosure closure = state.GetEpsilonClosure();
resultState.EndWeight = closure.EndWeight;
resultState.SetEndWeight(closure.EndWeight);
for (int stateIndex = 0; stateIndex < closure.Size; ++stateIndex)
{
State closureState = closure.GetStateByIndex(stateIndex);
@ -2520,7 +2496,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
Transition transition = closureState.GetTransition(transitionIndex);
if (!transition.IsEpsilon)
{
State destState = state.Owner.states[transition.DestinationStateIndex];
State destState = state.Owner.States[transition.DestinationStateIndex];
State closureDestState = this.BuildEpsilonClosure(destState, oldToNewState);
resultState.AddTransition(
transition.ElementDistribution, Weight.Product(transition.Weight, closureStateWeight), closureDestState, transition.Group);
@ -2540,9 +2516,8 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
Debug.Assert(automaton != null, "A valid automaton must be provided.");
// Swap contents
Util.Swap(ref this.states, ref automaton.states);
Util.Swap(ref this.startState, ref automaton.startState);
Util.Swap(ref this.statesReadOnly, ref automaton.statesReadOnly);
Util.Swap(ref this.statesData, ref automaton.statesData);
Util.Swap(ref this.startStateIndex, ref automaton.startStateIndex);
Util.Swap(ref this.isEpsilonFree, ref automaton.isEpsilonFree);
var dummy = this.LogValueOverride;
this.LogValueOverride = automaton.LogValueOverride;
@ -2550,17 +2525,6 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
dummy = this.PruneTransitionsWithLogWeightLessThan;
this.PruneTransitionsWithLogWeightLessThan = automaton.PruneTransitionsWithLogWeightLessThan;
automaton.PruneTransitionsWithLogWeightLessThan = dummy;
// Update backward references
for (int i = 0; i < this.states.Count; ++i)
{
this.states[i].Owner = (TThis)this;
}
for (int i = 0; i < automaton.states.Count; ++i)
{
automaton.states[i].Owner = automaton;
}
}
/// <summary>
@ -2568,11 +2532,11 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// </summary>
/// <param name="newStates">The states to replace the existing states with.</param>
/// <param name="newStartStateIndex">The index of the new start state.</param>
private void SetStates(IEnumerable<State> newStates, int newStartStateIndex)
private void SetStates(IEnumerable<StateData> newStates, int newStartStateIndex)
{
// The two signature lines above are the old and the new form; the element type of
// newStates changes from State to StateData.
// Empty the current storage first so that the states added next start at index 0.
this.states.Clear();
this.statesData.Clear();
// AddStates appends after whatever is already stored; the list is empty here, so the
// added states keep the positions they had in newStates.
this.AddStates(newStates);
// newStartStateIndex can therefore be used unchanged to pick the start state.
this.Start = this.states[newStartStateIndex];
this.Start = this.States[newStartStateIndex];
}
/// <summary>
@ -2581,34 +2545,34 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// </summary>
/// <param name="statesToAdd">The states to add.</param>
/// <param name="group">The group for the transitions of the states being added.</param>
private void AddStates(IEnumerable<State> statesToAdd, int group = 0)
private void AddStates(IEnumerable<StateData> statesToAdd, int group = 0)
{
Debug.Assert(statesToAdd != null, "A valid state collection must be provided.");
int startIndex = this.states.Count;
var statesToAddList = statesToAdd as IList<State> ?? statesToAdd.ToList();
int startIndex = this.statesData.Count;
var statesToAddList = statesToAdd as IList<StateData> ?? statesToAdd.ToList();
// Add states
for (int i = 0; i < statesToAddList.Count; ++i)
{
State newState = this.AddState();
newState.EndWeight = statesToAddList[i].EndWeight;
newState.SetEndWeight(statesToAddList[i].EndWeight);
Debug.Assert(statesToAddList[i].Index == i && newState.Index == i + startIndex, "State indices must always be consequent.");
Debug.Assert(newState.Index == i + startIndex, "State indices must always be consequent.");
}
// Add transitions
for (int i = 0; i < statesToAddList.Count; ++i)
{
State stateToAdd = statesToAddList[i];
var stateToAdd = statesToAddList[i];
for (int transitionIndex = 0; transitionIndex < stateToAdd.TransitionCount; transitionIndex++)
{
Transition transitionToAdd = stateToAdd.GetTransition(transitionIndex);
Debug.Assert(transitionToAdd.DestinationStateIndex < statesToAddList.Count, "Self-inconsistent collection of states provided.");
this.states[i + startIndex].AddTransition(
this.States[i + startIndex].AddTransition(
transitionToAdd.ElementDistribution,
transitionToAdd.Weight,
this.states[transitionToAdd.DestinationStateIndex + startIndex],
this.States[transitionToAdd.DestinationStateIndex + startIndex],
group != 0 ? group : transitionToAdd.Group);
}
}
@ -2630,19 +2594,18 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
{
//// TODO: see remarks
Debug.Assert(index >= 0 && index < this.states.Count, "An invalid state index provided.");
Debug.Assert(index >= 0 && index < this.statesData.Count, "An invalid state index provided.");
Debug.Assert(index != this.Start.Index, "Cannot remove the start state.");
Debug.Assert(
!replaceIndex.HasValue || (replaceIndex.Value >= 0 && replaceIndex.Value < this.states.Count),
!replaceIndex.HasValue || (replaceIndex.Value >= 0 && replaceIndex.Value < this.statesData.Count),
"An invalid replace index provided.");
Debug.Assert(!replaceIndex.HasValue || replaceIndex.Value != index, "Replace index must point to a different state.");
this.states.RemoveAt(index);
int stateCount = this.states.Count;
this.statesData.RemoveAt(index);
int stateCount = this.statesData.Count;
for (int i = 0; i < stateCount; i++)
{
State state = this.states[i];
state.Index = i;
StateData state = this.statesData[i];
for (int j = state.TransitionCount - 1; j >= 0; j--)
{
Transition transition = state.GetTransition(j);
@ -2931,20 +2894,15 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
/// </remarks>
protected Automaton(SerializationInfo info, StreamingContext context)
{
// Old layout: a list of State objects plus the start State itself. After reading, every
// state had its Owner re-attached and the read-only wrapper was rebuilt.
this.states = (List<State>)info.GetValue(nameof(this.states), typeof(List<State>));
foreach (var state in this.states)
{
state.Owner = (TThis)this;
}
this.statesReadOnly = new ReadOnlyList<State>(this.states);
this.startState = (State)info.GetValue(nameof(this.startState), typeof(State));
// New layout: a list of StateData plus the integer index of the start state; nothing is
// patched up after reading. Entry names come from nameof on the fields, so they differ
// from the old ones ("statesData"/"startStateIndex" instead of "states"/"startState").
this.statesData = (List<StateData>)info.GetValue(nameof(this.statesData), typeof(List<StateData>));
this.startStateIndex = (int)info.GetValue(nameof(this.startStateIndex), typeof(int));
// Nullable flag, read the same way in both layouts.
this.isEpsilonFree = (bool?)info.GetValue(nameof(this.isEpsilonFree), typeof(bool?));
}
void ISerializable.GetObjectData(SerializationInfo info, StreamingContext context)
{
// Entry names are taken from the field names via nameof and must match the names read
// back by the (SerializationInfo, StreamingContext) constructor.
// Old layout: the State list and the start State object.
info.AddValue(nameof(this.states), this.states);
info.AddValue(nameof(this.startState), this.startState);
// New layout: the StateData list and the index of the start state.
info.AddValue(nameof(this.statesData), this.statesData);
info.AddValue(nameof(this.startStateIndex), this.startStateIndex);
// The nullable epsilon-free flag is stored as-is in both layouts.
info.AddValue(nameof(this.isEpsilonFree), this.isEpsilonFree);
}
@ -2959,7 +2917,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
propertyMask[1 << idx++] = this.isEpsilonFree.HasValue && this.isEpsilonFree.Value;
propertyMask[1 << idx++] = this.LogValueOverride.HasValue;
propertyMask[1 << idx++] = this.PruneTransitionsWithLogWeightLessThan.HasValue;
propertyMask[1 << idx++] = this.startState != null;
propertyMask[1 << idx++] = !this.Start.IsNull;
writeInt32(propertyMask.Data);
@ -2973,23 +2931,14 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
writeDouble(this.PruneTransitionsWithLogWeightLessThan.Value);
}
if (startState != null)
if (!this.Start.IsNull)
{
if (startState.Owner != this)
{
throw new InvalidOperationException("the state owner is not set to the current automaton");
}
startState.Write(writeInt32, writeDouble, writeElementDistribution);
this.Start.Write(writeInt32, writeDouble, writeElementDistribution);
}
writeInt32(states.Count);
foreach (var state in states)
writeInt32(this.statesData.Count);
foreach (var state in this.States)
{
if (state.Owner != this)
{
throw new InvalidOperationException("the state owner is not set to the current automaton");
}
state.Write(writeInt32, writeDouble, writeElementDistribution);
}
}
@ -3020,20 +2969,24 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
res.PruneTransitionsWithLogWeightLessThan = readDouble();
}
if (hasStartState)
{
res.startState = State.Read(readInt32, readDouble, readElementDistribution);
res.startState.Owner = res;
}
var startState =
hasStartState
? State.Read(readInt32, readDouble, readElementDistribution)
: default(State);
var numStates = readInt32();
res.states.Clear();
res.statesData.Clear();
res.AddStates(numStates);
for (var i = 0; i < numStates; i++)
{
var state = State.Read(readInt32, readDouble, readElementDistribution);
state.Owner = res;
res.states.Add(state);
res.statesData[i] = State.Read(readInt32, readDouble, readElementDistribution).Data;
}
if (hasStartState)
{
res.startStateIndex = startState.Index;
}
return res;
}
#endregion

Просмотреть файл

@ -468,7 +468,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
}
}
destState.EndWeight = srcSequenceIndex == srcSequenceLength ? mappingState.EndWeight : Weight.Zero;
destState.SetEndWeight(srcSequenceIndex == srcSequenceLength ? mappingState.EndWeight : Weight.Zero);
return destState;
}
@ -541,7 +541,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
}
}
destState.EndWeight = Weight.Product(mappingState.EndWeight, srcState.EndWeight);
destState.SetEndWeight(Weight.Product(mappingState.EndWeight, srcState.EndWeight));
return destState;
}

Просмотреть файл

@ -226,7 +226,7 @@ namespace Microsoft.ML.Probabilistic.Distributions
{
var func = Automaton<TSequence, TElement, TElementDistribution, TSequenceManipulator, TWeightFunction>.Zero();
var end = func.Start.AddTransition(elementDistribution, Weight.One);
end.EndWeight = Weight.One;
end.SetEndWeight(Weight.One);
return FromWorkspace(func);
}
@ -431,7 +431,7 @@ namespace Microsoft.ML.Probabilistic.Distributions
for (int i = 0; i <= iterationBound; i++)
{
bool isLengthAllowed = i >= minTimes;
state.EndWeight = isLengthAllowed ? Weight.One : Weight.Zero;
state.SetEndWeight(isLengthAllowed ? Weight.One : Weight.Zero);
if (i < iterationBound)
{
state = state.AddTransition(allowedElements, weight); // todo: clone set?

Просмотреть файл

@ -31,7 +31,7 @@ namespace Microsoft.ML.Probabilistic.Tests
public void Clone()
{
StringAutomaton automaton = StringAutomaton.Zero();
automaton.Start.AddTransition('a', Weight.One).EndWeight = Weight.One;
automaton.Start.AddTransition('a', Weight.One).SetEndWeight(Weight.One);
StringAutomaton clone = automaton.Clone();
Assert.Equal(automaton, clone);
@ -390,8 +390,8 @@ namespace Microsoft.ML.Probabilistic.Tests
public void ProductNoDeadBranches()
{
StringAutomaton automaton = StringAutomaton.Zero();
// Two branches from the start state, 'a' then 'b' and 'a' then 'c', each ending with weight one.
// The pair of lines using the EndWeight setter is the old spelling; the SetEndWeight pair
// is its replacement.
automaton.Start.AddTransition('a', Weight.One).AddTransition('b', Weight.One).EndWeight = Weight.One;
automaton.Start.AddTransition('a', Weight.One).AddTransition('c', Weight.One).EndWeight = Weight.One;
automaton.Start.AddTransition('a', Weight.One).AddTransition('b', Weight.One).SetEndWeight(Weight.One);
automaton.Start.AddTransition('a', Weight.One).AddTransition('c', Weight.One).SetEndWeight(Weight.One);
// Multiply the automaton by itself and check the size of the result.
StringAutomaton automatonSqr = automaton.Product(automaton);
// NOTE(review): going by the test name, 4 is presumably the count once state pairs that
// cannot reach an end state are left out of the product - confirm against Product.
Assert.Equal(4, automatonSqr.States.Count);
}
@ -462,8 +462,8 @@ namespace Microsoft.ML.Probabilistic.Tests
{
StringAutomaton automaton = StringAutomaton.Zero();
var otherState = automaton.Start.AddSelfTransition('a', Weight.FromValue(0.5)).AddTransition('b', Weight.FromValue(0.7));
automaton.Start.EndWeight = Weight.FromValue(0.3);
otherState.EndWeight = Weight.FromValue(0.8);
automaton.Start.SetEndWeight(Weight.FromValue(0.3));
otherState.SetEndWeight(Weight.FromValue(0.8));
StringAutomaton reverse = automaton.Reverse();
StringInferenceTestUtilities.TestValue(reverse, 0.7 * 0.8, "b");
@ -616,7 +616,7 @@ namespace Microsoft.ML.Probabilistic.Tests
{
StringAutomaton automaton = StringAutomaton.Zero();
automaton.Start.AddTransitionsForSequence("abc");
automaton.Start.AddTransitionsForSequence("def").EndWeight = Weight.FromValue(4.0);
automaton.Start.AddTransitionsForSequence("def").SetEndWeight(Weight.FromValue(4.0));
double logNormalizer;
Assert.Equal(Math.Log(4.0), automaton.GetLogNormalizer(), 1e-8);
Assert.True(automaton.TryNormalizeValues(out logNormalizer));
@ -637,7 +637,7 @@ namespace Microsoft.ML.Probabilistic.Tests
StringAutomaton automaton = StringAutomaton.Zero();
automaton.Start.AddTransition('a', Weight.FromValue(TransitionProbability), automaton.Start);
automaton.Start.EndWeight = Weight.FromValue(EndWeight);
automaton.Start.SetEndWeight(Weight.FromValue(EndWeight));
double logNormalizer = automaton.GetLogNormalizer();
Assert.Equal(Math.Log(EndWeight / (1 - TransitionProbability)), logNormalizer, 1e-8);
Assert.Equal(logNormalizer, automaton.NormalizeValues());
@ -672,7 +672,7 @@ namespace Microsoft.ML.Probabilistic.Tests
StringAutomaton automaton = StringAutomaton.Zero();
automaton.Start.AddTransition('a', Weight.FromValue(1.01), automaton.Start);
automaton.Start.EndWeight = Weight.One;
automaton.Start.SetEndWeight(Weight.One);
TestNonNormalizable(automaton, false);
}
@ -722,7 +722,7 @@ namespace Microsoft.ML.Probabilistic.Tests
state = nextState;
}
state.EndWeight = Weight.One;
state.SetEndWeight(Weight.One);
var closure = automaton.Start.GetEpsilonClosure();
@ -772,7 +772,7 @@ namespace Microsoft.ML.Probabilistic.Tests
state = state.AddTransition('a', Weight.One);
}
state.EndWeight = Weight.One;
state.SetEndWeight(Weight.One);
string point = new string('a', stateCount - 1);
@ -803,10 +803,10 @@ namespace Microsoft.ML.Probabilistic.Tests
Assert.True(automaton2.IsZero());
// Null states collection
Assert.Throws<ArgumentNullException>(() => StringAutomaton.FromStates(null, null));
Assert.Throws<ArgumentNullException>(() => StringAutomaton.FromStates(null, default(StringAutomaton.State)));
// Null start state
Assert.Throws<ArgumentNullException>(() => StringAutomaton.FromStates(new[] { theOnlyState }, null));
Assert.Throws<ArgumentException>(() => StringAutomaton.FromStates(new[] { theOnlyState }, default(StringAutomaton.State)));
// Duplicate state indices
Assert.Throws<ArgumentException>(
@ -874,7 +874,7 @@ namespace Microsoft.ML.Probabilistic.Tests
public void ConvertToString1()
{
StringAutomaton automaton = StringAutomaton.Zero();
// Chain: 'a', then a self-loop on 'b', then 'c', with end weight one on the last state.
// The line using the EndWeight setter is the old spelling; the SetEndWeight line replaces it.
automaton.Start.AddTransition('a', Weight.One).AddSelfTransition('b', Weight.One).AddTransition('c', Weight.One).EndWeight = Weight.One;
automaton.Start.AddTransition('a', Weight.One).AddSelfTransition('b', Weight.One).AddTransition('c', Weight.One).SetEndWeight(Weight.One);
// Both output formats are expected to render the self-loop as a star.
Assert.Equal("ab*c", automaton.ToString(AutomatonFormats.Friendly));
Assert.Equal("ab*c", automaton.ToString(AutomatonFormats.Regexp));
}
@ -890,7 +890,7 @@ namespace Microsoft.ML.Probabilistic.Tests
StringAutomaton automaton = StringAutomaton.Zero();
var middleState = automaton.Start.AddTransition('b', Weight.One).AddTransition('c', Weight.One);
automaton.Start.AddTransition('a', Weight.One, middleState);
middleState.AddTransition('d', Weight.One).EndWeight = Weight.One;
middleState.AddTransition('d', Weight.One).SetEndWeight(Weight.One);
Assert.Equal("(bc|a)d", automaton.ToString(AutomatonFormats.Friendly));
Assert.Equal("(bc|a)d", automaton.ToString(AutomatonFormats.Regexp));
}
@ -903,7 +903,7 @@ namespace Microsoft.ML.Probabilistic.Tests
public void ConvertToString3()
{
StringAutomaton automaton = StringAutomaton.Zero();
// Add transitions for the sequence "hello" and give the state returned by the call end weight one.
// The line using the EndWeight setter is the old spelling; the SetEndWeight line replaces it.
automaton.Start.AddTransitionsForSequence("hello").EndWeight = Weight.One;
automaton.Start.AddTransitionsForSequence("hello").SetEndWeight(Weight.One);
// Both output formats are expected to print the literal sequence.
Assert.Equal("hello", automaton.ToString(AutomatonFormats.Friendly));
Assert.Equal("hello", automaton.ToString(AutomatonFormats.Regexp));
}
@ -916,7 +916,7 @@ namespace Microsoft.ML.Probabilistic.Tests
public void ConvertToString4()
{
StringAutomaton automaton = StringAutomaton.Zero();
// Same "hello" chain as a plain sequence automaton, but the final state gets end weight zero.
// The line using the EndWeight setter is the old spelling; the SetEndWeight line replaces it.
automaton.Start.AddTransitionsForSequence("hello").EndWeight = Weight.Zero;
automaton.Start.AddTransitionsForSequence("hello").SetEndWeight(Weight.Zero);
// With a zero end weight both formats are expected to print the empty-set symbol.
Assert.Equal("Ø", automaton.ToString(AutomatonFormats.Friendly));
Assert.Equal("Ø", automaton.ToString(AutomatonFormats.Regexp));
}
@ -929,9 +929,9 @@ namespace Microsoft.ML.Probabilistic.Tests
public void ConvertToString5()
{
StringAutomaton automaton = StringAutomaton.Zero();
// Three alternatives from the start state: "hello" directly, and "hi" and "hey" each
// behind an epsilon transition; every alternative ends with weight one.
// The three lines using the EndWeight setter are the old spelling; the SetEndWeight lines replace them.
automaton.Start.AddTransitionsForSequence("hello").EndWeight = Weight.One;
automaton.Start.AddEpsilonTransition(Weight.One).AddTransitionsForSequence("hi").EndWeight = Weight.One;
automaton.Start.AddEpsilonTransition(Weight.One).AddTransitionsForSequence("hey").EndWeight = Weight.One;
automaton.Start.AddTransitionsForSequence("hello").SetEndWeight(Weight.One);
automaton.Start.AddEpsilonTransition(Weight.One).AddTransitionsForSequence("hi").SetEndWeight(Weight.One);
automaton.Start.AddEpsilonTransition(Weight.One).AddTransitionsForSequence("hey").SetEndWeight(Weight.One);
// The expected text lists the alternatives in the reverse of the order they were added.
Assert.Equal("hey|hi|hello", automaton.ToString(AutomatonFormats.Friendly));
Assert.Equal("hey|hi|hello", automaton.ToString(AutomatonFormats.Regexp));
}
@ -1048,7 +1048,7 @@ namespace Microsoft.ML.Probabilistic.Tests
automaton.States[5].AddTransition('l', Weight.FromValue(1), automaton.States[6]);
automaton.States[6].AddTransition('m', Weight.FromValue(1), automaton.States[3]);
automaton.States[6].AddTransition('n', Weight.FromValue(1), automaton.States[1]);
automaton.States[7].EndWeight = Weight.FromValue(1);
automaton.States[7].SetEndWeight(Weight.FromValue(1));
var distribution = StringDistribution.FromWorkspace(automaton);
var regexPattern = distribution.ToRegex();
@ -1095,7 +1095,7 @@ namespace Microsoft.ML.Probabilistic.Tests
automaton.States[0].AddTransition('o', Weight.FromValue(1), automaton.States[6]);
automaton.States[1].AddTransition('p', Weight.FromValue(1), automaton.States[7]);
automaton.States[6].AddTransition('q', Weight.FromValue(1), automaton.States[7]);
automaton.States[7].EndWeight = Weight.FromValue(1);
automaton.States[7].SetEndWeight(Weight.FromValue(1));
var distribution = StringDistribution.FromWorkspace(automaton);
var regexPattern = distribution.ToRegex();
@ -1176,7 +1176,7 @@ namespace Microsoft.ML.Probabilistic.Tests
public void Equality3()
{
StringAutomaton automaton = StringAutomaton.Zero();
automaton.Start.EndWeight = Weight.One;
automaton.Start.SetEndWeight(Weight.One);
automaton.Start.AddTransition(DiscreteChar.Lower(), Weight.FromLogValue(26 + 1e-3), automaton.Start);
AssertEquals(automaton, automaton);
@ -1191,7 +1191,7 @@ namespace Microsoft.ML.Probabilistic.Tests
public void Equality4()
{
StringAutomaton automaton = StringAutomaton.Zero();
automaton.Start.EndWeight = Weight.One;
automaton.Start.SetEndWeight(Weight.One);
automaton.Start.AddTransition('a', Weight.FromLogValue(1.0 - 1e-3), automaton.Start);
AssertEquals(automaton, automaton);
@ -1222,7 +1222,7 @@ namespace Microsoft.ML.Probabilistic.Tests
{
StringAutomaton func1 = StringAutomaton.Constant(1.0, DiscreteChar.OneOf('a', 'b'));
StringAutomaton func2 = StringAutomaton.Zero();
func2.Start.AddSelfTransition('a', Weight.One).AddSelfTransition('b', Weight.One).EndWeight = Weight.One;
func2.Start.AddSelfTransition('a', Weight.One).AddSelfTransition('b', Weight.One).SetEndWeight(Weight.One);
AssertEquals(func1, func2);
}
@ -1235,13 +1235,13 @@ namespace Microsoft.ML.Probabilistic.Tests
public void Equality7()
{
// First automaton: a self-loop on 'a' at the start state, with end weight one set on the
// state returned by AddSelfTransition.
// The line using the EndWeight setter is the old spelling; the SetEndWeight line replaces it.
StringAutomaton func1 = StringAutomaton.Zero();
func1.Start.AddSelfTransition(DiscreteChar.PointMass('a'), Weight.One).EndWeight = Weight.One;
func1.Start.AddSelfTransition(DiscreteChar.PointMass('a'), Weight.One).SetEndWeight(Weight.One);
// Second automaton: an epsilon transition out of the start state followed by an 'a'
// transition whose destination is the start state again.
// NOTE(review): the end weight is set on whatever the three-argument AddTransition returns,
// presumably the destination (func2.Start) - confirm.
StringAutomaton func2 = StringAutomaton.Zero();
func2.Start
.AddEpsilonTransition(Weight.One)
.AddTransition(DiscreteChar.PointMass('a'), Weight.One, func2.Start)
.EndWeight = Weight.One;
.SetEndWeight(Weight.One);
// The two constructions are expected to compare as equal.
AssertEquals(func1, func2);
}
@ -1258,10 +1258,10 @@ namespace Microsoft.ML.Probabilistic.Tests
public void Simplify1()
{
StringAutomaton automaton = StringAutomaton.Zero();
automaton.Start.AddTransition('a', Weight.One).AddTransition('b', Weight.One).EndWeight =
Weight.FromValue(2.0);
automaton.Start.AddTransition('a', Weight.One).AddSelfTransition('d', Weight.One).AddTransition('c', Weight.One).EndWeight =
Weight.FromValue(3.0);
automaton.Start.AddTransition('a', Weight.One).AddTransition('b', Weight.One).SetEndWeight(
Weight.FromValue(2.0));
automaton.Start.AddTransition('a', Weight.One).AddSelfTransition('d', Weight.One).AddTransition('c', Weight.One).SetEndWeight(
Weight.FromValue(3.0));
for (int i = 0; i < 3; ++i)
{
@ -1281,10 +1281,10 @@ namespace Microsoft.ML.Probabilistic.Tests
public void Simplify2()
{
StringAutomaton automaton = StringAutomaton.Zero();
automaton.Start.AddTransition('a', Weight.One).AddSelfTransition('d', Weight.One).AddTransition('c', Weight.One).EndWeight =
Weight.FromValue(3.0);
automaton.Start.AddTransition('a', Weight.One).AddTransition('b', Weight.One).EndWeight =
Weight.FromValue(2.0);
automaton.Start.AddTransition('a', Weight.One).AddSelfTransition('d', Weight.One).AddTransition('c', Weight.One).SetEndWeight(
Weight.FromValue(3.0));
automaton.Start.AddTransition('a', Weight.One).AddTransition('b', Weight.One).SetEndWeight(
Weight.FromValue(2.0));
for (int i = 0; i < 3; ++i)
{
@ -1304,10 +1304,10 @@ namespace Microsoft.ML.Probabilistic.Tests
public void Simplify3()
{
StringAutomaton automaton = StringAutomaton.Zero();
automaton.Start.AddTransition('a', Weight.One).AddTransition('d', Weight.One).AddTransition('c', Weight.One).EndWeight =
Weight.FromValue(2.0);
automaton.Start.AddTransition('a', Weight.One).AddSelfTransition('d', Weight.One).AddTransition('c', Weight.One).EndWeight =
Weight.FromValue(3.0);
automaton.Start.AddTransition('a', Weight.One).AddTransition('d', Weight.One).AddTransition('c', Weight.One).SetEndWeight(
Weight.FromValue(2.0));
automaton.Start.AddTransition('a', Weight.One).AddSelfTransition('d', Weight.One).AddTransition('c', Weight.One).SetEndWeight(
Weight.FromValue(3.0));
for (int i = 0; i < 3; ++i)
{
@ -1327,11 +1327,11 @@ namespace Microsoft.ML.Probabilistic.Tests
{
StringAutomaton automaton = StringAutomaton.Zero();
automaton.Start.AddEpsilonTransition(Weight.One).AddSelfTransition('a', Weight.One)
.AddEpsilonTransition(Weight.One).AddSelfTransition('b', Weight.One).EndWeight = Weight.FromValue(2.0);
.AddEpsilonTransition(Weight.One).AddSelfTransition('b', Weight.One).SetEndWeight(Weight.FromValue(2.0));
automaton.Start.AddEpsilonTransition(Weight.One).AddSelfTransition('a', Weight.One)
.AddEpsilonTransition(Weight.One).AddSelfTransition('c', Weight.One).EndWeight = Weight.FromValue(3.0);
.AddEpsilonTransition(Weight.One).AddSelfTransition('c', Weight.One).SetEndWeight(Weight.FromValue(3.0));
automaton.Start.AddSelfTransition('x', Weight.One);
automaton.Start.EndWeight = Weight.One;
automaton.Start.SetEndWeight(Weight.One);
for (int i = 0; i < 3; ++i)
{
@ -1361,12 +1361,12 @@ namespace Microsoft.ML.Probabilistic.Tests
state = nextState;
}
state.EndWeight = Weight.One;
state.SetEndWeight(Weight.One);
const int AdditionalSequenceCount = 5;
for (int i = 0; i < AdditionalSequenceCount; ++i)
{
automaton.Start.AddTransitionsForSequence(AcceptedSequence).EndWeight = Weight.One;
automaton.Start.AddTransitionsForSequence(AcceptedSequence).SetEndWeight(Weight.One);
}
for (int i = 0; i < 3; ++i)
@ -1393,8 +1393,8 @@ namespace Microsoft.ML.Probabilistic.Tests
automaton.States[2].AddTransition('a', Weight.FromValue(6.0), automaton.States[5]);
automaton.States[4].AddTransition('a', Weight.FromValue(5.0), automaton.States[2]);
automaton.States[3].EndWeight = Weight.FromValue(2.0);
automaton.States[5].EndWeight = Weight.FromValue(3.0);
automaton.States[3].SetEndWeight(Weight.FromValue(2.0));
automaton.States[5].SetEndWeight(Weight.FromValue(3.0));
for (int i = 0; i < 3; ++i)
{
@ -1417,13 +1417,13 @@ namespace Microsoft.ML.Probabilistic.Tests
StringAutomaton automaton = StringAutomaton.Zero();
var branch1 = automaton.Start.AddEpsilonTransition(Weight.FromValue(0.5)).AddTransition('a', Weight.FromValue(1.0 / 3.0)).AddTransition('B', Weight.FromValue(1.0 / 4.0));
branch1.EndWeight = Weight.FromValue(3.0);
branch1.AddTransition('X', Weight.FromValue(1.0 / 6.0)).EndWeight = Weight.FromValue(5.0);
branch1.AddEpsilonTransition(Weight.FromValue(1.0 / 8.0)).EndWeight = Weight.FromValue(7.0);
branch1.SetEndWeight(Weight.FromValue(3.0));
branch1.AddTransition('X', Weight.FromValue(1.0 / 6.0)).SetEndWeight(Weight.FromValue(5.0));
branch1.AddEpsilonTransition(Weight.FromValue(1.0 / 8.0)).SetEndWeight(Weight.FromValue(7.0));
var branch2 = automaton.Start.AddTransition(lowerEnglish, Weight.FromValue(2.0));
branch2.EndWeight = Weight.FromValue(4.0);
branch2.SetEndWeight(Weight.FromValue(4.0));
branch2.AddTransition(upperEnglish, Weight.FromValue(3.0), branch2);
branch2.AddTransition('X', Weight.FromValue(4.0)).EndWeight = Weight.FromValue(5.0);
branch2.AddTransition('X', Weight.FromValue(4.0)).SetEndWeight(Weight.FromValue(5.0));
for (int i = 0; i < 3; ++i)
{
@ -1464,10 +1464,10 @@ namespace Microsoft.ML.Probabilistic.Tests
automaton.States[8].AddTransition('b', Weight.One, automaton.States[8]);
automaton.States[8].AddTransition('a', Weight.One, automaton.States[9]);
automaton.States[3].EndWeight = Weight.FromValue(0.1);
automaton.States[6].EndWeight = Weight.FromValue(0.2);
automaton.States[9].EndWeight = Weight.FromValue(0.3);
automaton.States[10].EndWeight = Weight.FromValue(0.4);
automaton.States[3].SetEndWeight(Weight.FromValue(0.1));
automaton.States[6].SetEndWeight(Weight.FromValue(0.2));
automaton.States[9].SetEndWeight(Weight.FromValue(0.3));
automaton.States[10].SetEndWeight(Weight.FromValue(0.4));
for (int i = 0; i < 3; ++i)
{
@ -1518,9 +1518,9 @@ namespace Microsoft.ML.Probabilistic.Tests
public void Determinize1()
{
StringAutomaton automaton = StringAutomaton.Zero();
automaton.Start.AddTransition('a', Weight.FromValue(2)).AddTransition('b', Weight.FromValue(3)).EndWeight = Weight.FromValue(4);
automaton.Start.AddTransition('a', Weight.FromValue(5)).AddTransition('c', Weight.FromValue(6)).EndWeight = Weight.FromValue(7);
automaton.Start.EndWeight = Weight.FromValue(17);
automaton.Start.AddTransition('a', Weight.FromValue(2)).AddTransition('b', Weight.FromValue(3)).SetEndWeight(Weight.FromValue(4));
automaton.Start.AddTransition('a', Weight.FromValue(5)).AddTransition('c', Weight.FromValue(6)).SetEndWeight(Weight.FromValue(7));
automaton.Start.SetEndWeight(Weight.FromValue(17));
Assert.False(automaton.IsDeterministic());
@ -1546,9 +1546,9 @@ namespace Microsoft.ML.Probabilistic.Tests
public void Determinize2()
{
StringAutomaton automaton = StringAutomaton.Zero();
automaton.Start.AddTransition(DiscreteChar.UniformInRange('a', 'z'), Weight.FromValue(2)).AddTransition('b', Weight.FromValue(3)).EndWeight = Weight.FromValue(4);
automaton.Start.AddTransition(DiscreteChar.UniformInRange('a', 'z'), Weight.FromValue(5)).AddTransition('c', Weight.FromValue(6)).EndWeight = Weight.FromValue(7);
automaton.Start.EndWeight = Weight.FromValue(17);
automaton.Start.AddTransition(DiscreteChar.UniformInRange('a', 'z'), Weight.FromValue(2)).AddTransition('b', Weight.FromValue(3)).SetEndWeight(Weight.FromValue(4));
automaton.Start.AddTransition(DiscreteChar.UniformInRange('a', 'z'), Weight.FromValue(5)).AddTransition('c', Weight.FromValue(6)).SetEndWeight(Weight.FromValue(7));
automaton.Start.SetEndWeight(Weight.FromValue(17));
Assert.False(automaton.IsDeterministic());
@ -1574,10 +1574,10 @@ namespace Microsoft.ML.Probabilistic.Tests
public void Determinize3()
{
StringAutomaton automaton = StringAutomaton.Zero();
automaton.Start.AddTransition(DiscreteChar.Uniform(), Weight.FromValue(2)).AddTransition('b', Weight.FromValue(3)).EndWeight = Weight.FromValue(4);
automaton.Start.AddTransition(DiscreteChar.UniformInRange('a', 'z'), Weight.FromValue(5)).AddTransition('c', Weight.FromValue(6)).EndWeight = Weight.FromValue(7);
automaton.Start.AddTransition(DiscreteChar.UniformInRange('x', 'z'), Weight.FromValue(8)).AddTransition('d', Weight.FromValue(9)).EndWeight = Weight.FromValue(10);
automaton.Start.EndWeight = Weight.FromValue(17);
automaton.Start.AddTransition(DiscreteChar.Uniform(), Weight.FromValue(2)).AddTransition('b', Weight.FromValue(3)).SetEndWeight(Weight.FromValue(4));
automaton.Start.AddTransition(DiscreteChar.UniformInRange('a', 'z'), Weight.FromValue(5)).AddTransition('c', Weight.FromValue(6)).SetEndWeight(Weight.FromValue(7));
automaton.Start.AddTransition(DiscreteChar.UniformInRange('x', 'z'), Weight.FromValue(8)).AddTransition('d', Weight.FromValue(9)).SetEndWeight(Weight.FromValue(10));
automaton.Start.SetEndWeight(Weight.FromValue(17));
Assert.False(automaton.IsDeterministic());
@ -1607,8 +1607,8 @@ namespace Microsoft.ML.Probabilistic.Tests
public void Determinize4()
{
StringAutomaton automaton = StringAutomaton.Zero();
automaton.Start.AddTransition('a', Weight.FromValue(2)).AddSelfTransition('b', Weight.FromValue(0.5)).AddTransition('c', Weight.FromValue(3.0)).EndWeight = Weight.FromValue(4);
automaton.Start.AddTransition('a', Weight.FromValue(5)).AddSelfTransition('b', Weight.FromValue(0.5)).AddTransition('d', Weight.FromValue(6.0)).EndWeight = Weight.FromValue(7);
automaton.Start.AddTransition('a', Weight.FromValue(2)).AddSelfTransition('b', Weight.FromValue(0.5)).AddTransition('c', Weight.FromValue(3.0)).SetEndWeight(Weight.FromValue(4));
automaton.Start.AddTransition('a', Weight.FromValue(5)).AddSelfTransition('b', Weight.FromValue(0.5)).AddTransition('d', Weight.FromValue(6.0)).SetEndWeight(Weight.FromValue(7));
Assert.False(automaton.IsDeterministic());
@ -1637,8 +1637,8 @@ namespace Microsoft.ML.Probabilistic.Tests
public void Determinize5()
{
StringAutomaton automaton = StringAutomaton.Zero();
automaton.Start.AddTransition('a', Weight.FromValue(2)).AddSelfTransition('b', Weight.FromValue(0.5)).AddTransition('c', Weight.FromValue(3.0)).EndWeight = Weight.FromValue(4);
automaton.Start.AddTransition('a', Weight.FromValue(5)).AddSelfTransition('b', Weight.FromValue(0.5)).AddTransition('c', Weight.FromValue(6.0)).EndWeight = Weight.FromValue(7);
automaton.Start.AddTransition('a', Weight.FromValue(2)).AddSelfTransition('b', Weight.FromValue(0.5)).AddTransition('c', Weight.FromValue(3.0)).SetEndWeight(Weight.FromValue(4));
automaton.Start.AddTransition('a', Weight.FromValue(5)).AddSelfTransition('b', Weight.FromValue(0.5)).AddTransition('c', Weight.FromValue(6.0)).SetEndWeight(Weight.FromValue(7));
Assert.False(automaton.IsDeterministic());
@ -1664,11 +1664,11 @@ namespace Microsoft.ML.Probabilistic.Tests
public void Determinize6()
{
StringAutomaton automaton = StringAutomaton.Zero();
automaton.Start.AddTransition(DiscreteChar.UniformInRange('a', 'c'), Weight.FromValue(2)).EndWeight = Weight.FromValue(3.0);
automaton.Start.AddTransition(DiscreteChar.UniformInRange('b', 'c'), Weight.FromValue(4)).EndWeight = Weight.FromValue(5.0);
automaton.Start.AddTransition(DiscreteChar.UniformInRange('b', 'd'), Weight.FromValue(6)).EndWeight = Weight.FromValue(7.0);
automaton.Start.AddTransition(DiscreteChar.UniformInRange('d', 'd'), Weight.FromValue(8)).EndWeight = Weight.FromValue(9.0);
automaton.Start.AddTransition(DiscreteChar.UniformInRange('d', 'e'), Weight.FromValue(10)).EndWeight = Weight.FromValue(11.0);
automaton.Start.AddTransition(DiscreteChar.UniformInRange('a', 'c'), Weight.FromValue(2)).SetEndWeight(Weight.FromValue(3.0));
automaton.Start.AddTransition(DiscreteChar.UniformInRange('b', 'c'), Weight.FromValue(4)).SetEndWeight(Weight.FromValue(5.0));
automaton.Start.AddTransition(DiscreteChar.UniformInRange('b', 'd'), Weight.FromValue(6)).SetEndWeight(Weight.FromValue(7.0));
automaton.Start.AddTransition(DiscreteChar.UniformInRange('d', 'd'), Weight.FromValue(8)).SetEndWeight(Weight.FromValue(9.0));
automaton.Start.AddTransition(DiscreteChar.UniformInRange('d', 'e'), Weight.FromValue(10)).SetEndWeight(Weight.FromValue(11.0));
Assert.False(automaton.IsDeterministic());
@ -1706,9 +1706,9 @@ namespace Microsoft.ML.Probabilistic.Tests
automaton.States[2].AddTransition('b', Weight.FromValue(6), automaton.States[4]);
automaton.States[3].AddTransition('c', Weight.FromValue(7), automaton.States[4]);
automaton.States[2].EndWeight = Weight.FromValue(0.5);
automaton.States[3].EndWeight = Weight.FromValue(1);
automaton.States[4].EndWeight = Weight.FromValue(2);
automaton.States[2].SetEndWeight(Weight.FromValue(0.5));
automaton.States[3].SetEndWeight(Weight.FromValue(1));
automaton.States[4].SetEndWeight(Weight.FromValue(2));
Assert.False(automaton.IsDeterministic());
@ -1747,7 +1747,7 @@ namespace Microsoft.ML.Probabilistic.Tests
state = nextState;
}
state.EndWeight = Weight.One;
state.SetEndWeight(Weight.One);
Assert.False(automaton.IsDeterministic());
@ -1777,11 +1777,11 @@ namespace Microsoft.ML.Probabilistic.Tests
const int TransitionsPerCharacter = 3;
for (int i = 0; i < TransitionsPerCharacter; ++i)
{
automaton.Start.AddTransition('a', Weight.One).EndWeight = Weight.One;
automaton.Start.AddTransition('b', Weight.One).EndWeight = Weight.One;
automaton.Start.AddTransition('d', Weight.One).EndWeight = Weight.One;
automaton.Start.AddTransition('e', Weight.One).EndWeight = Weight.One;
automaton.Start.AddTransition('g', Weight.One).EndWeight = Weight.One;
automaton.Start.AddTransition('a', Weight.One).SetEndWeight(Weight.One);
automaton.Start.AddTransition('b', Weight.One).SetEndWeight(Weight.One);
automaton.Start.AddTransition('d', Weight.One).SetEndWeight(Weight.One);
automaton.Start.AddTransition('e', Weight.One).SetEndWeight(Weight.One);
automaton.Start.AddTransition('g', Weight.One).SetEndWeight(Weight.One);
}
Assert.False(automaton.IsDeterministic() || TransitionsPerCharacter <= 1);
@ -1817,9 +1817,9 @@ namespace Microsoft.ML.Probabilistic.Tests
automaton.States[2].AddTransition('b', Weight.FromValue(6), automaton.States[4]);
automaton.States[3].AddTransition('c', Weight.FromValue(7), automaton.States[4]);
automaton.States[2].EndWeight = Weight.FromValue(0.5);
automaton.States[3].EndWeight = Weight.FromValue(1);
automaton.States[4].EndWeight = Weight.FromValue(2);
automaton.States[2].SetEndWeight(Weight.FromValue(0.5));
automaton.States[3].SetEndWeight(Weight.FromValue(1));
automaton.States[4].SetEndWeight(Weight.FromValue(2));
Assert.False(automaton.IsDeterministic());
@ -1846,8 +1846,8 @@ namespace Microsoft.ML.Probabilistic.Tests
public void NonDeterminizable1()
{
StringAutomaton automaton = StringAutomaton.Zero();
automaton.Start.AddTransition('a', Weight.FromValue(2)).AddSelfTransition('b', Weight.FromValue(0.5)).AddTransition('c', Weight.FromValue(3.0)).EndWeight = Weight.FromValue(4);
automaton.Start.AddTransition('a', Weight.FromValue(5)).AddSelfTransition('b', Weight.FromValue(0.1)).AddTransition('c', Weight.FromValue(6.0)).EndWeight = Weight.FromValue(7);
automaton.Start.AddTransition('a', Weight.FromValue(2)).AddSelfTransition('b', Weight.FromValue(0.5)).AddTransition('c', Weight.FromValue(3.0)).SetEndWeight(Weight.FromValue(4));
automaton.Start.AddTransition('a', Weight.FromValue(5)).AddSelfTransition('b', Weight.FromValue(0.1)).AddTransition('c', Weight.FromValue(6.0)).SetEndWeight(Weight.FromValue(7));
Assert.False(automaton.IsDeterministic());
@ -1887,11 +1887,11 @@ namespace Microsoft.ML.Probabilistic.Tests
const int TransitionsPerCharacter = 3;
for (int i = 0; i < TransitionsPerCharacter; ++i)
{
automaton.Start.AddTransition("a", Weight.One).EndWeight = Weight.One;
automaton.Start.AddTransition("b", Weight.One).EndWeight = Weight.One;
automaton.Start.AddTransition("d", Weight.One).EndWeight = Weight.One;
automaton.Start.AddTransition("e", Weight.One).EndWeight = Weight.One;
automaton.Start.AddTransition("g", Weight.One).EndWeight = Weight.One;
automaton.Start.AddTransition("a", Weight.One).SetEndWeight(Weight.One);
automaton.Start.AddTransition("b", Weight.One).SetEndWeight(Weight.One);
automaton.Start.AddTransition("d", Weight.One).SetEndWeight(Weight.One);
automaton.Start.AddTransition("e", Weight.One).SetEndWeight(Weight.One);
automaton.Start.AddTransition("g", Weight.One).SetEndWeight(Weight.One);
}
Assert.False(automaton.IsDeterministic() || TransitionsPerCharacter <= 1);
@ -1923,11 +1923,11 @@ namespace Microsoft.ML.Probabilistic.Tests
const int TransitionsPerCharacter = 3;
for (int i = 0; i < TransitionsPerCharacter; ++i)
{
automaton.Start.AddTransition("a", Weight.One).EndWeight = Weight.One;
automaton.Start.AddTransition("b", Weight.One).EndWeight = Weight.One;
automaton.Start.AddTransition("d", Weight.One).EndWeight = Weight.One;
automaton.Start.AddTransition("e", Weight.One).EndWeight = Weight.One;
automaton.Start.AddTransition(scaledUniform, Weight.One).EndWeight = Weight.One;
automaton.Start.AddTransition("a", Weight.One).SetEndWeight(Weight.One);
automaton.Start.AddTransition("b", Weight.One).SetEndWeight(Weight.One);
automaton.Start.AddTransition("d", Weight.One).SetEndWeight(Weight.One);
automaton.Start.AddTransition("e", Weight.One).SetEndWeight(Weight.One);
automaton.Start.AddTransition(scaledUniform, Weight.One).SetEndWeight(Weight.One);
}
Assert.False(automaton.IsDeterministic() || TransitionsPerCharacter <= 1);
@ -1964,10 +1964,10 @@ namespace Microsoft.ML.Probabilistic.Tests
automaton.States[4].AddTransition(DiscreteChar.UniformOver('i', 'j'), Weight.FromValue(1), automaton.States[5]);
automaton.States[4].AddTransition(DiscreteChar.UniformOver('k', 'l'), Weight.FromValue(1), automaton.States[5]);
automaton.States[1].EndWeight = Weight.FromValue(1);
automaton.States[3].EndWeight = Weight.FromValue(1);
automaton.States[5].EndWeight = Weight.FromValue(1);
automaton.States[6].EndWeight = Weight.FromValue(1);
automaton.States[1].SetEndWeight(Weight.FromValue(1));
automaton.States[3].SetEndWeight(Weight.FromValue(1));
automaton.States[5].SetEndWeight(Weight.FromValue(1));
automaton.States[6].SetEndWeight(Weight.FromValue(1));
var expectedSupport = new HashSet<string>
{
@ -1999,10 +1999,10 @@ namespace Microsoft.ML.Probabilistic.Tests
automaton.States[4].AddTransition(DiscreteChar.UniformOver('i', 'j'), Weight.FromValue(1), automaton.States[5]);
automaton.States[4].AddTransition(DiscreteChar.UniformOver('k', 'l'), Weight.FromValue(1), automaton.States[5]);
automaton.States[1].EndWeight = Weight.FromValue(1);
automaton.States[3].EndWeight = Weight.FromValue(1);
automaton.States[5].EndWeight = Weight.FromValue(1);
automaton.States[6].EndWeight = Weight.FromValue(1);
automaton.States[1].SetEndWeight(Weight.FromValue(1));
automaton.States[3].SetEndWeight(Weight.FromValue(1));
automaton.States[5].SetEndWeight(Weight.FromValue(1));
automaton.States[6].SetEndWeight(Weight.FromValue(1));
int numPasses = 10000;
Stopwatch watch = new Stopwatch();

Просмотреть файл

@ -109,7 +109,7 @@ namespace Microsoft.ML.Probabilistic.Tests
{
StringAutomaton f = StringAutomaton.Zero();
AddEpsilonLoop(f.Start, 5, 0.5);
f.Start.AddTransitionsForSequence("abc").EndWeight = Weight.One;
f.Start.AddTransitionsForSequence("abc").SetEndWeight(Weight.One);
Assert.Equal("abc", f.TryComputePoint());
}
@ -122,7 +122,7 @@ namespace Microsoft.ML.Probabilistic.Tests
{
StringAutomaton f = StringAutomaton.Zero();
f.Start.AddTransition('a', Weight.FromValue(0.5)).AddTransition('b', Weight.Zero, f.Start);
f.Start.AddTransitionsForSequence("abc").EndWeight = Weight.One;
f.Start.AddTransitionsForSequence("abc").SetEndWeight(Weight.One);
Assert.Equal("abc", f.TryComputePoint());
}
@ -135,7 +135,7 @@ namespace Microsoft.ML.Probabilistic.Tests
{
StringAutomaton f = StringAutomaton.Zero();
f.Start.AddTransition('a', Weight.FromValue(0.5)).AddSelfTransition('a', Weight.FromValue(0.5)).AddTransition('b', Weight.One);
f.Start.AddTransition('b', Weight.FromValue(0.5)).EndWeight = Weight.One;
f.Start.AddTransition('b', Weight.FromValue(0.5)).SetEndWeight(Weight.One);
Assert.Equal("b", f.TryComputePoint());
}
@ -147,7 +147,7 @@ namespace Microsoft.ML.Probabilistic.Tests
public void NoPoint1()
{
StringAutomaton f = StringAutomaton.Zero();
f.Start.AddTransition('a', Weight.FromValue(0.5)).AddTransition('b', Weight.FromValue(0.5), f.Start).EndWeight = Weight.One;
f.Start.AddTransition('a', Weight.FromValue(0.5)).AddTransition('b', Weight.FromValue(0.5), f.Start).SetEndWeight(Weight.One);
Assert.Null(f.TryComputePoint());
}
@ -160,7 +160,7 @@ namespace Microsoft.ML.Probabilistic.Tests
{
StringAutomaton f = StringAutomaton.Zero();
var state = f.Start.AddTransition('a', Weight.FromValue(0.5));
state.EndWeight = Weight.One;
state.SetEndWeight(Weight.One);
state.AddTransition('b', Weight.FromValue(0.5), f.Start);
Assert.Null(f.TryComputePoint());
}
@ -216,7 +216,7 @@ namespace Microsoft.ML.Probabilistic.Tests
{
StringAutomaton f = StringAutomaton.Zero();
f.Start.AddSelfTransition('x', Weight.Zero);
f.Start.AddTransition('y', Weight.Zero).EndWeight = Weight.One;
f.Start.AddTransition('y', Weight.Zero).SetEndWeight(Weight.One);
Assert.True(f.IsZero());
}
@ -232,10 +232,10 @@ namespace Microsoft.ML.Probabilistic.Tests
public void LoopyArithmetic()
{
StringAutomaton automaton1 = StringAutomaton.Zero();
automaton1.Start.AddTransition('a', Weight.FromValue(4.0)).AddTransition('b', Weight.One, automaton1.Start).EndWeight = Weight.One;
automaton1.Start.AddTransition('a', Weight.FromValue(4.0)).AddTransition('b', Weight.One, automaton1.Start).SetEndWeight(Weight.One);
StringAutomaton automaton2 = StringAutomaton.Zero();
automaton2.Start.AddSelfTransition('a', Weight.FromValue(2)).AddSelfTransition('b', Weight.FromValue(3)).EndWeight = Weight.One;
automaton2.Start.AddSelfTransition('a', Weight.FromValue(2)).AddSelfTransition('b', Weight.FromValue(3)).SetEndWeight(Weight.One);
StringAutomaton sum = automaton1.Sum(automaton2);
StringInferenceTestUtilities.TestValue(sum, 2.0, string.Empty);
@ -272,7 +272,7 @@ namespace Microsoft.ML.Probabilistic.Tests
{
StringAutomaton automaton = StringAutomaton.Zero();
automaton.Start.AddSelfTransition('a', Weight.FromValue(0.7));
automaton.Start.EndWeight = Weight.FromValue(0.3);
automaton.Start.SetEndWeight(Weight.FromValue(0.3));
AssertStochastic(automaton);
Assert.Equal(0.0, automaton.GetLogNormalizer(), 1e-6);
@ -289,11 +289,11 @@ namespace Microsoft.ML.Probabilistic.Tests
{
StringAutomaton automaton = StringAutomaton.Zero();
var state = automaton.Start;
state.EndWeight = Weight.FromValue(0.1);
state.SetEndWeight(Weight.FromValue(0.1));
state.AddSelfTransition('a', Weight.FromValue(0.7));
state = state.AddTransition('b', Weight.FromValue(0.2));
state.AddSelfTransition('a', Weight.FromValue(0.4));
state.EndWeight = Weight.FromValue(0.6);
state.SetEndWeight(Weight.FromValue(0.6));
AssertStochastic(automaton);
Assert.Equal(0.0, automaton.GetLogNormalizer(), 1e-6);
@ -311,14 +311,14 @@ namespace Microsoft.ML.Probabilistic.Tests
StringAutomaton automaton = StringAutomaton.Zero();
automaton.Start.AddSelfTransition('a', Weight.FromValue(0.7));
automaton.Start.EndWeight = Weight.FromValue(0.1);
automaton.Start.SetEndWeight(Weight.FromValue(0.1));
var state1 = automaton.Start.AddTransition('b', Weight.FromValue(0.15));
state1.AddSelfTransition('a', Weight.FromValue(0.4));
state1.EndWeight = Weight.FromValue(0.6);
state1.SetEndWeight(Weight.FromValue(0.6));
var state2 = automaton.Start.AddTransition('c', Weight.FromValue(0.05));
state2.EndWeight = Weight.One;
state2.SetEndWeight(Weight.One);
AssertStochastic(automaton);
Assert.Equal(0.0, automaton.GetLogNormalizer(), 1e-6);
@ -336,11 +336,11 @@ namespace Microsoft.ML.Probabilistic.Tests
StringAutomaton automaton = StringAutomaton.Zero();
var state = automaton.Start.AddTransition('a', Weight.FromValue(0.9));
state.AddTransition('a', Weight.FromValue(0.1)).EndWeight = Weight.One;
state.AddTransition('a', Weight.FromValue(0.1)).SetEndWeight(Weight.One);
state = state.AddTransition('a', Weight.FromValue(0.9));
state.AddTransition('a', Weight.FromValue(0.1)).EndWeight = Weight.One;
state.AddTransition('a', Weight.FromValue(0.1)).SetEndWeight(Weight.One);
state = state.AddTransition('a', Weight.FromValue(0.9), automaton.Start);
state.AddTransition('a', Weight.FromValue(0.1)).EndWeight = Weight.One;
state.AddTransition('a', Weight.FromValue(0.1)).SetEndWeight(Weight.One);
AssertStochastic(automaton);
Assert.Equal(0.0, automaton.GetLogNormalizer(), 1e-6);
@ -358,7 +358,7 @@ namespace Microsoft.ML.Probabilistic.Tests
StringAutomaton automaton = StringAutomaton.Zero();
var endState = automaton.Start.AddTransition('a', Weight.FromValue(2.0));
endState.EndWeight = Weight.FromValue(5.0);
endState.SetEndWeight(Weight.FromValue(5.0));
endState.AddTransition('b', Weight.FromValue(0.25), automaton.Start);
endState.AddTransition('c', Weight.FromValue(0.2), automaton.Start);
@ -377,7 +377,7 @@ namespace Microsoft.ML.Probabilistic.Tests
StringAutomaton automaton = StringAutomaton.Zero();
var endState = automaton.Start.AddTransition('a', Weight.FromValue(2.0));
endState.EndWeight = Weight.FromValue(5.0);
endState.SetEndWeight(Weight.FromValue(5.0));
endState.AddTransition('b', Weight.FromValue(0.1), automaton.Start);
endState.AddTransition('c', Weight.FromValue(0.05), automaton.Start);
endState.AddSelfTransition('!', Weight.FromValue(0.5));
@ -407,13 +407,13 @@ namespace Microsoft.ML.Probabilistic.Tests
AddEpsilonLoop(automaton.Start, 3, 0.2);
AddEpsilonLoop(automaton.Start, 5, 0.3);
automaton.Start.EndWeight = Weight.FromValue(0.1);
automaton.Start.SetEndWeight(Weight.FromValue(0.1));
var nextState = automaton.Start.AddTransition('a', Weight.FromValue(0.4));
nextState.EndWeight = Weight.FromValue(0.6);
nextState.SetEndWeight(Weight.FromValue(0.6));
AddEpsilonLoop(nextState, 0, 0.3);
nextState = nextState.AddTransition('b', Weight.FromValue(0.1));
AddEpsilonLoop(nextState, 1, 0.9);
nextState.EndWeight = Weight.FromValue(0.1);
nextState.SetEndWeight(Weight.FromValue(0.1));
AssertStochastic(automaton);
Assert.Equal(0.0, automaton.GetLogNormalizer(), 1e-6);
@ -433,21 +433,21 @@ namespace Microsoft.ML.Probabilistic.Tests
automaton.States[0].AddEpsilonTransition(Weight.FromValue(0.2), automaton.States[1]);
automaton.States[0].AddEpsilonTransition(Weight.FromValue(0.5), automaton.States[3]);
automaton.States[0].EndWeight = Weight.FromValue(0.3);
automaton.States[0].SetEndWeight(Weight.FromValue(0.3));
automaton.States[1].AddEpsilonTransition(Weight.FromValue(0.8), automaton.States[0]);
automaton.States[1].AddEpsilonTransition(Weight.FromValue(0.1), automaton.States[2]);
automaton.States[1].EndWeight = Weight.FromValue(0.1);
automaton.States[2].EndWeight = Weight.FromValue(1.0);
automaton.States[1].SetEndWeight(Weight.FromValue(0.1));
automaton.States[2].SetEndWeight(Weight.FromValue(1.0));
automaton.States[3].AddEpsilonTransition(Weight.FromValue(0.2), automaton.States[4]);
automaton.States[3].AddEpsilonTransition(Weight.FromValue(0.1), automaton.States[5]);
automaton.States[3].EndWeight = Weight.FromValue(0.7);
automaton.States[3].SetEndWeight(Weight.FromValue(0.7));
automaton.States[4].AddEpsilonTransition(Weight.FromValue(0.5), automaton.States[2]);
automaton.States[4].AddEpsilonTransition(Weight.FromValue(0.5), automaton.States[6]);
automaton.States[4].EndWeight = Weight.FromValue(0.0);
automaton.States[4].SetEndWeight(Weight.FromValue(0.0));
automaton.States[5].AddEpsilonTransition(Weight.FromValue(0.1), automaton.States[3]);
automaton.States[5].AddEpsilonTransition(Weight.FromValue(0.9), automaton.States[6]);
automaton.States[5].EndWeight = Weight.Zero;
automaton.States[6].EndWeight = Weight.One;
automaton.States[5].SetEndWeight(Weight.Zero);
automaton.States[6].SetEndWeight(Weight.One);
AssertStochastic(automaton);
Assert.Equal(0.0, automaton.GetLogNormalizer(), 1e-6);
@ -465,7 +465,7 @@ namespace Microsoft.ML.Probabilistic.Tests
StringAutomaton automaton = StringAutomaton.Zero();
var endState = automaton.Start.AddTransition('a', Weight.FromValue(3.5));
endState.EndWeight = Weight.FromValue(5.0);
endState.SetEndWeight(Weight.FromValue(5.0));
endState.AddTransition('b', Weight.FromValue(0.1), automaton.Start);
endState.AddTransition('c', Weight.FromValue(0.05), automaton.Start);
endState.AddSelfTransition('!', Weight.FromValue(0.5));
@ -486,7 +486,7 @@ namespace Microsoft.ML.Probabilistic.Tests
StringAutomaton automaton = StringAutomaton.Zero();
var endState = automaton.Start.AddTransition('a', Weight.FromValue(2.0));
endState.EndWeight = Weight.FromValue(5.0);
endState.SetEndWeight(Weight.FromValue(5.0));
endState.AddTransition('b', Weight.FromValue(0.1), automaton.Start);
endState.AddTransition('c', Weight.FromValue(0.05), automaton.Start);
endState.AddSelfTransition('!', Weight.FromValue(0.75));
@ -506,7 +506,7 @@ namespace Microsoft.ML.Probabilistic.Tests
{
StringAutomaton automaton = StringAutomaton.Zero();
automaton.Start.AddTransition('a', Weight.FromValue(2.0), automaton.Start);
automaton.Start.EndWeight = Weight.FromValue(5.0);
automaton.Start.SetEndWeight(Weight.FromValue(5.0));
StringAutomaton copyOfAutomaton = automaton.Clone();
Assert.Throws<InvalidOperationException>(() => automaton.NormalizeValues());
@ -525,9 +525,9 @@ namespace Microsoft.ML.Probabilistic.Tests
automaton.Start.AddSelfTransition('a', Weight.FromValue(0.1));
var branch1 = automaton.Start.AddTransition('a', Weight.FromValue(2.0));
branch1.AddSelfTransition('a', Weight.FromValue(2.0));
branch1.EndWeight = Weight.One;
branch1.SetEndWeight(Weight.One);
var branch2 = automaton.Start.AddTransition('a', Weight.FromValue(2.0));
branch2.EndWeight = Weight.One;
branch2.SetEndWeight(Weight.One);
StringAutomaton copyOfAutomaton = automaton.Clone();
Assert.Throws<InvalidOperationException>(() => automaton.NormalizeValues());
@ -543,7 +543,7 @@ namespace Microsoft.ML.Probabilistic.Tests
public void NormalizeWithInfiniteEpsilon1()
{
StringAutomaton automaton = StringAutomaton.Zero();
automaton.Start.AddTransition('a', Weight.One).AddSelfTransition(null, Weight.FromValue(3)).EndWeight = Weight.One;
automaton.Start.AddTransition('a', Weight.One).AddSelfTransition(null, Weight.FromValue(3)).SetEndWeight(Weight.One);
// The automaton takes an infinite value on "a", and yet the normalization must work
Assert.True(automaton.TryNormalizeValues());
@ -559,8 +559,8 @@ namespace Microsoft.ML.Probabilistic.Tests
public void NormalizeWithInfiniteEpsilon2()
{
StringAutomaton automaton = StringAutomaton.Zero();
automaton.Start.AddTransition('a', Weight.One).AddSelfTransition(null, Weight.FromValue(2)).EndWeight = Weight.One;
automaton.Start.AddTransition('b', Weight.One).AddSelfTransition(null, Weight.FromValue(1)).EndWeight = Weight.One;
automaton.Start.AddTransition('a', Weight.One).AddSelfTransition(null, Weight.FromValue(2)).SetEndWeight(Weight.One);
automaton.Start.AddTransition('b', Weight.One).AddSelfTransition(null, Weight.FromValue(1)).SetEndWeight(Weight.One);
// "a" branch infinitely dominates over the "b" branch
Assert.True(automaton.TryNormalizeValues());
@ -583,7 +583,7 @@ namespace Microsoft.ML.Probabilistic.Tests
automaton.Start.AddEpsilonTransition(Weight.FromValue(0.5), automaton.Start);
var nextState = automaton.Start.AddEpsilonTransition(Weight.FromValue(0.4));
nextState.AddEpsilonTransition(Weight.One).AddEpsilonTransition(Weight.One, automaton.Start);
automaton.Start.EndWeight = Weight.FromValue(0.1);
automaton.Start.SetEndWeight(Weight.FromValue(0.1));
AssertStochastic(automaton);
@ -614,7 +614,7 @@ namespace Microsoft.ML.Probabilistic.Tests
var middleNode = automaton.Start.AddTransition('a', Weight.One);
middleNode.AddTransitionsForSequence("bbb", automaton.Start);
middleNode.AddTransition('c', Weight.One, automaton.Start);
automaton.Start.EndWeight = Weight.One;
automaton.Start.SetEndWeight(Weight.One);
Assert.Equal("(a(c|bbb))*", automaton.ToString(AutomatonFormats.Friendly));
Assert.Equal("(a(c|bbb))*", automaton.ToString(AutomatonFormats.Regexp));
@ -630,7 +630,7 @@ namespace Microsoft.ML.Probabilistic.Tests
StringAutomaton automaton = StringAutomaton.Zero();
automaton.Start.AddSelfTransition('a', Weight.One);
automaton.Start.AddSelfTransition('b', Weight.One);
automaton.Start.EndWeight = Weight.One;
automaton.Start.SetEndWeight(Weight.One);
Assert.Equal("(a|b)*", automaton.ToString(AutomatonFormats.Friendly));
Assert.Equal("(a|b)*", automaton.ToString(AutomatonFormats.Regexp));
@ -649,8 +649,8 @@ namespace Microsoft.ML.Probabilistic.Tests
automaton.Start.AddTransition('y', Weight.One, state);
state.AddSelfTransition('a', Weight.One);
state.AddSelfTransition('b', Weight.One);
state.EndWeight = Weight.One;
state.AddTransitionsForSequence("zzz").EndWeight = Weight.One;
state.SetEndWeight(Weight.One);
state.AddTransitionsForSequence("zzz").SetEndWeight(Weight.One);
Assert.Equal("(x|y)(a|b)*[zzz]", automaton.ToString(AutomatonFormats.Friendly));
Assert.Equal("(x|y)(a|b)*(|zzz)", automaton.ToString(AutomatonFormats.Regexp));
@ -665,7 +665,7 @@ namespace Microsoft.ML.Probabilistic.Tests
{
StringAutomaton automaton = StringAutomaton.Zero();
automaton.Start.AddTransitionsForSequence("xyz", automaton.Start);
automaton.Start.AddTransition('!', Weight.One).EndWeight = Weight.One;
automaton.Start.AddTransition('!', Weight.One).SetEndWeight(Weight.One);
Assert.Equal("(xyz)*!", automaton.ToString(AutomatonFormats.Friendly));
Assert.Equal("(xyz)*!", automaton.ToString(AutomatonFormats.Regexp));
}
@ -681,7 +681,7 @@ namespace Microsoft.ML.Probabilistic.Tests
var state = automaton.Start.AddTransition('x', Weight.One);
automaton.Start.AddTransition('y', Weight.Zero, state);
state.AddSelfTransition('a', Weight.One);
state.EndWeight = Weight.One;
state.SetEndWeight(Weight.One);
Assert.Equal("xa*", automaton.ToString(AutomatonFormats.Friendly));
Assert.Equal("xa*", automaton.ToString(AutomatonFormats.Regexp));
@ -696,7 +696,7 @@ namespace Microsoft.ML.Probabilistic.Tests
{
StringAutomaton automaton = StringAutomaton.Zero();
automaton.Start.AddSelfTransition('x', Weight.Zero);
automaton.Start.AddTransition('y', Weight.Zero).EndWeight = Weight.One;
automaton.Start.AddTransition('y', Weight.Zero).SetEndWeight(Weight.One);
Assert.Equal("Ø", automaton.ToString(AutomatonFormats.Friendly));
Assert.Equal("Ø", automaton.ToString(AutomatonFormats.Regexp));
@ -766,7 +766,7 @@ namespace Microsoft.ML.Probabilistic.Tests
{
currentState = currentState.AddEpsilonTransition(
i == 0 ? Weight.FromValue(loopWeight) : Weight.One,
i == loopSize ? state : null);
i == loopSize ? state : default(StringAutomaton.State));
}
}

Просмотреть файл

@ -579,7 +579,7 @@ namespace Microsoft.ML.Probabilistic.Tests
// The length of sequences sampled from this distribution must follow a geometric distribution
StringAutomaton automaton = StringAutomaton.Zero();
automaton.Start = automaton.AddState();
automaton.Start.EndWeight = Weight.FromValue(StoppingProbability);
automaton.Start.SetEndWeight(Weight.FromValue(StoppingProbability));
automaton.Start.AddTransition('a', Weight.FromValue(1 - StoppingProbability), automaton.Start);
StringDistribution dist = StringDistribution.FromWeightFunction(automaton);

Просмотреть файл

@ -38,8 +38,8 @@ namespace Microsoft.ML.Probabilistic.Tests
StringAutomaton automaton = StringAutomaton.Zero();
var nextState = automaton.Start.AddTransitionsForSequence("abc");
nextState.AddSelfTransition('d', Weight.FromValue(0.1));
nextState.AddTransitionsForSequence("efg").EndWeight = Weight.One;
nextState.AddTransitionsForSequence("hejfhoenmf").EndWeight = Weight.One;
nextState.AddTransitionsForSequence("efg").SetEndWeight(Weight.One);
nextState.AddTransitionsForSequence("hejfhoenmf").SetEndWeight(Weight.One);
ProfileAction(() => automaton.GetLogNormalizer(), 100000);
}, 10000);
@ -57,13 +57,13 @@ namespace Microsoft.ML.Probabilistic.Tests
{
StringAutomaton automaton = StringAutomaton.Zero();
var nextState = automaton.Start.AddTransitionsForSequence("abc");
nextState.EndWeight = Weight.One;
nextState.SetEndWeight(Weight.One);
nextState.AddSelfTransition('d', Weight.FromValue(0.1));
nextState = nextState.AddTransitionsForSequence("efg");
nextState.EndWeight = Weight.One;
nextState.SetEndWeight(Weight.One);
nextState.AddSelfTransition('h', Weight.FromValue(0.2));
nextState = nextState.AddTransitionsForSequence("grlkhgn;lk3rng");
nextState.EndWeight = Weight.One;
nextState.SetEndWeight(Weight.One);
nextState.AddSelfTransition('h', Weight.FromValue(0.3));
ProfileAction(() => automaton.GetLogNormalizer(), 100000);
@ -82,10 +82,10 @@ namespace Microsoft.ML.Probabilistic.Tests
{
StringAutomaton automaton = StringAutomaton.Zero();
automaton.Start.AddSelfTransition('a', Weight.FromValue(0.5));
automaton.Start.EndWeight = Weight.One;
automaton.Start.SetEndWeight(Weight.One);
var nextState = automaton.Start.AddTransitionsForSequence("aa");
nextState.AddSelfTransition('a', Weight.FromValue(0.5));
nextState.EndWeight = Weight.One;
nextState.SetEndWeight(Weight.One);
for (int i = 0; i < 3; ++i)
{