diff --git a/src/Runtime/Distributions/Automata/Automaton.Builder.cs b/src/Runtime/Distributions/Automata/Automaton.Builder.cs index b71bff4a..ae7e0e94 100644 --- a/src/Runtime/Distributions/Automata/Automaton.Builder.cs +++ b/src/Runtime/Distributions/Automata/Automaton.Builder.cs @@ -20,7 +20,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata /// /// States created so far. /// - private readonly List states; + private readonly List states; /// /// Transitions created so far. @@ -29,8 +29,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata /// Unlike in , transitions /// for single entity are not represented by contiguous segment of array, but rather as a linked /// list. It is done this way, because transitions can be added at any moment and inserting - /// transition into a middle of array is not feasible. - /// references head of list and references last element of list. + /// transition into a middle of array is not feasible. /// private readonly List transitions; @@ -45,7 +44,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata /// public Builder(int startStateCount = 1) { - this.states = new List(); + this.states = new List(); this.transitions = new List(); this.AddStates(startStateCount); } @@ -120,10 +119,10 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata var index = this.states.Count; this.states.Add( - new StateData + new LinkedStateData { - FirstTransition = -1, - LastTransition = -1, + FirstTransitionIndex = -1, + LastTransitionIndex = -1, EndWeight = Weight.Zero, }); return new StateBuilder(this, index); @@ -357,7 +356,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata for (var i = 0; i < oldStateCount; ++i) { var state = this.states[i]; - if (state.CanEnd && state.FirstTransition != -1) + if (!state.EndWeight.IsZero && state.FirstTransitionIndex != -1) { return false; } @@ -397,9 +396,8 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata for (var i = 0; i < resultStates.Length; ++i) { - var state = this.states[i]; - var transitionIndex = state.FirstTransition; - state.FirstTransition = nextResultTransitionIndex; + var firstResultTransitionIndex = nextResultTransitionIndex; + var transitionIndex = this.states[i].FirstTransitionIndex; while (transitionIndex != -1) { var node = this.transitions[transitionIndex]; @@ -415,8 +413,10 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata transitionIndex = node.Next; } - state.LastTransition = nextResultTransitionIndex; - resultStates[i] = state; + resultStates[i] = new StateData( + firstResultTransitionIndex, + nextResultTransitionIndex - firstResultTransitionIndex, + this.states[i].EndWeight); } Debug.Assert( @@ -470,7 +470,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata /// /// Gets a value indicating whether the ending weight of this state is greater than zero. /// - public bool CanEnd => this.builder.states[this.Index].CanEnd; + public bool CanEnd => !this.builder.states[this.Index].EndWeight.IsZero; /// /// Gets or sets the ending weight of the state. @@ -486,7 +486,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata /// Gets over transitions of this state. /// public TransitionIterator TransitionIterator => - new TransitionIterator(this.builder, this.Index, this.builder.states[this.Index].FirstTransition); + new TransitionIterator(this.builder, this.Index, this.builder.states[this.Index].FirstTransitionIndex); /// /// Initializes a new instance of struct. @@ -525,22 +525,22 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata { Transition = transition, Next = -1, - Prev = state.LastTransition, + Prev = state.LastTransitionIndex, }); - if (state.LastTransition != -1) + if (state.LastTransitionIndex != -1) { // update "next" field in old tail - var oldTail = this.builder.transitions[state.LastTransition]; + var oldTail = this.builder.transitions[state.LastTransitionIndex]; oldTail.Next = transitionIndex; - this.builder.transitions[state.LastTransition] = oldTail; + this.builder.transitions[state.LastTransitionIndex] = oldTail; } else { - state.FirstTransition = transitionIndex; + state.FirstTransitionIndex = transitionIndex; } - state.LastTransition = transitionIndex; + state.LastTransitionIndex = transitionIndex; this.builder.states[this.Index] = state; @@ -759,16 +759,16 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata } // update references in state - if (state.FirstTransition == this.index || state.LastTransition == this.index) + if (state.FirstTransitionIndex == this.index || state.LastTransitionIndex == this.index) { - if (state.FirstTransition == this.index) + if (state.FirstTransitionIndex == this.index) { - state.FirstTransition = node.Next; + state.FirstTransitionIndex = node.Next; } - if (state.LastTransition == this.index) + if (state.LastTransitionIndex == this.index) { - state.LastTransition = node.Prev; + state.LastTransitionIndex = node.Prev; } this.builder.states[this.stateIndex] = state; @@ -793,6 +793,29 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata } + /// + /// Version of that is used during automaton constructions. Unlike + /// regular transitions are stored as a linked list over + /// array. + /// + private struct LinkedStateData + { + /// + /// Index of the head of transitions list in . + /// + public int FirstTransitionIndex { get; internal set; } + + /// + /// Index of the tail of transitions list in . + /// + public int LastTransitionIndex { get; internal set; } + + /// + /// Ending weight of the state. + /// + public Weight EndWeight { get; internal set; } + } + /// /// Linked list node for representing transitions for state. /// diff --git a/src/Runtime/Distributions/Automata/Automaton.DataContainer.cs b/src/Runtime/Distributions/Automata/Automaton.DataContainer.cs index 15d9b561..705d19ed 100644 --- a/src/Runtime/Distributions/Automata/Automaton.DataContainer.cs +++ b/src/Runtime/Distributions/Automata/Automaton.DataContainer.cs @@ -108,12 +108,13 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata foreach (var state in this.States) { - if (state.FirstTransition < 0 || state.LastTransition > this.Transitions.Count) + var lastTransitionIndex = state.FirstTransitionIndex + state.TransitionsCount; + if (state.FirstTransitionIndex < 0 || lastTransitionIndex > this.Transitions.Count) { return false; } - for (var i = state.FirstTransition; i < state.LastTransition; ++i) + for (var i = state.FirstTransitionIndex; i < lastTransitionIndex; ++i) { var transition = this.Transitions[i]; if (transition.DestinationStateIndex < 0 || diff --git a/src/Runtime/Distributions/Automata/Automaton.State.cs b/src/Runtime/Distributions/Automata/Automaton.State.cs index 4b5e40da..159a0d8b 100644 --- a/src/Runtime/Distributions/Automata/Automaton.State.cs +++ b/src/Runtime/Distributions/Automata/Automaton.State.cs @@ -67,8 +67,8 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata public ReadOnlyArraySegment Transitions => new ReadOnlyArraySegment( this.transitions, - this.Data.FirstTransition, - this.Data.LastTransition - this.Data.FirstTransition); + this.Data.FirstTransitionIndex, + this.Data.TransitionsCount); internal StateData Data => this.states[this.Index]; diff --git a/src/Runtime/Distributions/Automata/Automaton.StateData.cs b/src/Runtime/Distributions/Automata/Automaton.StateData.cs index a13ea643..c76b8994 100644 --- a/src/Runtime/Distributions/Automata/Automaton.StateData.cs +++ b/src/Runtime/Distributions/Automata/Automaton.StateData.cs @@ -29,20 +29,15 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata /// this property contains index of the head of the linked-list of transitions. /// [DataMember] - public int FirstTransition { get; internal set; } + public int FirstTransitionIndex { get; internal set; } /// - /// Gets or sets index of the first transition in after - /// which does not belong to this state. All transitions for + /// Gets or sets count of transition in after + /// which belong to this state. All transitions for /// the same state are stored as a contiguous block. /// - /// - /// During automaton construction - /// stores transitions as linked-list instead of contiguous block. So, during construction - /// this property contains index of the tail of the linked-list of transitions. - /// [DataMember] - public int LastTransition { get; internal set; } + public int TransitionsCount { get; internal set; } /// /// Gets or sets ending weight of the state. @@ -53,11 +48,11 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata /// /// Initializes a new instance of the struct. /// - [Construction("FirstTransition", "LastTransition", "EndWeight")] - public StateData(int firstTransition, int lastTransition, Weight endWeight) + [Construction("FirstTransitionIndex", "TransitionsCount", "EndWeight")] + public StateData(int firstTransitionIndex, int transitionsCount, Weight endWeight) { - this.FirstTransition = firstTransition; - this.LastTransition = lastTransition; + this.FirstTransitionIndex = firstTransitionIndex; + this.TransitionsCount = transitionsCount; this.EndWeight = endWeight; } diff --git a/test/Tests/Strings/AutomatonTests.cs b/test/Tests/Strings/AutomatonTests.cs index d81d4503..d72225d6 100644 --- a/test/Tests/Strings/AutomatonTests.cs +++ b/test/Tests/Strings/AutomatonTests.cs @@ -864,7 +864,7 @@ namespace Microsoft.ML.Probabilistic.Tests new[] { new StringAutomaton.StateData(0, 1, Weight.One), - new StringAutomaton.StateData(1, 1, Weight.One), + new StringAutomaton.StateData(1, 0, Weight.One), }, new[] { new StringAutomaton.Transition(DiscreteChar.PointMass('a'), Weight.One, 1) }));