diff --git a/src/Runtime/Distributions/Automata/Automaton.Condensation.cs b/src/Runtime/Distributions/Automata/Automaton.Condensation.cs index e2e3fd87..1a806e65 100644 --- a/src/Runtime/Distributions/Automata/Automaton.Condensation.cs +++ b/src/Runtime/Distributions/Automata/Automaton.Condensation.cs @@ -46,9 +46,9 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata /// The computed condensation. public Condensation ComputeCondensation(State root, Func transitionFilter, bool useApproximateClosure) { - Argument.CheckIfNotNull(root, "root"); - Argument.CheckIfNotNull(transitionFilter, "transitionFilter"); - Argument.CheckIfValid(ReferenceEquals(root.Owner, this), "root", "The given node belongs to a different automaton."); + Argument.CheckIfValid(!root.IsNull, nameof(root)); + Argument.CheckIfNotNull(transitionFilter, nameof(transitionFilter)); + Argument.CheckIfValid(ReferenceEquals(root.Owner, this), nameof(root), "The given node belongs to a different automaton."); return new Condensation(root, transitionFilter, useApproximateClosure); } @@ -108,7 +108,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata /// internal Condensation(State root, Func transitionFilter, bool useApproximateClosure) { - Debug.Assert(root != null, "A valid root node must be provided."); + Debug.Assert(!root.IsNull, "A valid root node must be provided."); Debug.Assert(transitionFilter != null, "A valid transition filter must be provided."); this.Root = root; @@ -164,7 +164,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata /// The computed total weight. public Weight GetWeightToEnd(State state) { - Argument.CheckIfNotNull(state, "state"); + Argument.CheckIfValid(!state.IsNull, nameof(state)); Argument.CheckIfValid(ReferenceEquals(state.Owner, this.Root.Owner), "state", "The given state belongs to a different automaton."); if (!this.weightsToEndComputed) @@ -189,7 +189,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata /// The computed total weight. public Weight GetWeightFromRoot(State state) { - Argument.CheckIfNotNull(state, "state"); + Argument.CheckIfValid(!state.IsNull, nameof(state)); Argument.CheckIfValid(ReferenceEquals(state.Owner, this.Root.Owner), "state", "The given state belongs to a different automaton."); if (!this.weightsFromRootComputed) @@ -241,7 +241,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata if (!stateIdToStateInfo.TryGetValue(transition.DestinationStateIndex, out destinationStateInfo)) { this.FindStronglyConnectedComponents( - this.Root.Owner.states[transition.DestinationStateIndex], ref traversalIndex, stateIdToStateInfo, stateIdStack); + this.Root.Owner.States[transition.DestinationStateIndex], ref traversalIndex, stateIdToStateInfo, stateIdStack); stateInfo.Lowlink = Math.Min(stateInfo.Lowlink, stateIdToStateInfo[transition.DestinationStateIndex].Lowlink); } else if (destinationStateInfo.InStack) @@ -288,7 +288,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata for (int transitionIndex = 0; transitionIndex < state.TransitionCount; ++transitionIndex) { Transition transition = state.GetTransition(transitionIndex); - State destState = state.Owner.states[transition.DestinationStateIndex]; + State destState = state.Owner.States[transition.DestinationStateIndex]; if (this.transitionFilter(transition) && !currentComponent.HasState(destState)) { weightToAdd = Weight.Sum( @@ -367,7 +367,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata for (int transitionIndex = 0; transitionIndex < srcState.TransitionCount; ++transitionIndex) { Transition transition = srcState.GetTransition(transitionIndex); - State destState = srcState.Owner.states[transition.DestinationStateIndex]; + State destState = srcState.Owner.States[transition.DestinationStateIndex]; if (this.transitionFilter(transition) && !currentComponent.HasState(destState)) { CondensationStateInfo destStateInfo = this.stateIdToInfo[destState.Index]; diff --git a/src/Runtime/Distributions/Automata/Automaton.EpsilonClosure.cs b/src/Runtime/Distributions/Automata/Automaton.EpsilonClosure.cs index 03fe3c76..bd4934ed 100644 --- a/src/Runtime/Distributions/Automata/Automaton.EpsilonClosure.cs +++ b/src/Runtime/Distributions/Automata/Automaton.EpsilonClosure.cs @@ -2,6 +2,8 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System.Diagnostics; + namespace Microsoft.ML.Probabilistic.Distributions.Automata { using System.Collections.Generic; @@ -40,7 +42,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata /// The state, which epsilon closure this instance will represent. internal EpsilonClosure(State state) { - Argument.CheckIfNotNull(state, "state"); + Argument.CheckIfValid(!state.IsNull, nameof(state)); // Optimize for a very common case: a single-node closure bool singleNodeClosure = true; diff --git a/src/Runtime/Distributions/Automata/Automaton.GroupExtractor.cs b/src/Runtime/Distributions/Automata/Automaton.GroupExtractor.cs index db097e0f..bc918860 100644 --- a/src/Runtime/Distributions/Automata/Automaton.GroupExtractor.cs +++ b/src/Runtime/Distributions/Automata/Automaton.GroupExtractor.cs @@ -23,17 +23,16 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata { internal static Dictionary ExtractGroups(Automaton automaton) { - Dictionary> subGraphs; - var order = ComputeTopologicalOrderAndGroupSubgraphs(automaton, out subGraphs); - return BuildSubautomata(automaton.states, order, subGraphs); + var order = ComputeTopologicalOrderAndGroupSubgraphs(automaton, out var subGraphs); + return BuildSubautomata(automaton.States, order, subGraphs); } private static Dictionary BuildSubautomata( - List states, - List topologicalOrder, + IReadOnlyList states, + IReadOnlyList topologicalOrder, Dictionary> groupSubGraphs) => groupSubGraphs.ToDictionary(g => g.Key, g => BuildSubautomaton(states, topologicalOrder, g.Key, g.Value)); - private static TThis BuildSubautomaton(List states, List topologicalOrder, int group, HashSet subgraph) + private static TThis BuildSubautomaton(IReadOnlyList states, IReadOnlyList topologicalOrder, int group, HashSet subgraph) { var weightsFromRoot = ComputeWeightsFromRoot(states.Count, topologicalOrder, group); var weightsToEnd = ComputeWeightsToEnd(states.Count, topologicalOrder, group); @@ -69,14 +68,14 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata var weightFromRoot = newSourceState.TransitionCount > 0 ? weightsFromRoot[stateIndex] : Weight.Zero; if (!weightFromRoot.IsZero) { - subautomaton.startState.AddEpsilonTransition(weightFromRoot, newSourceState); + subautomaton.Start.AddEpsilonTransition(weightFromRoot, newSourceState); } // consider end states var weightToEnd = !hasNoIncomingTransitions.Contains(stateIndex) ? weightsToEnd[stateIndex] : Weight.Zero; if (!weightToEnd.IsZero) { - newSourceState.EndWeight = weightToEnd; + newSourceState.SetEndWeight(weightToEnd); } correctionFactor = Weight.Sum(correctionFactor, Weight.Product(weightFromRoot, weightToEnd)); @@ -84,7 +83,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata if (!correctionFactor.IsZero) throw new Exception("Write a unit test for this case. Code should be fine."); var epsilonWeight = Weight.AbsoluteDifference(weightsToEnd[topologicalOrder[0].Index], correctionFactor); - subautomaton.startState.EndWeight = epsilonWeight; + subautomaton.Start.SetEndWeight(epsilonWeight); return subautomaton; } @@ -106,16 +105,15 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata private static List ComputeTopologicalOrderAndGroupSubgraphs(Automaton automaton, out Dictionary> groupSubGraphs) { var topologicalOrder = new Stack(); - var states = automaton.states; - var temporary = new BitArray(states.Count); - var permanent = new BitArray(states.Count); + var temporary = new BitArray(automaton.States.Count); + var permanent = new BitArray(automaton.States.Count); groupSubGraphs = new Dictionary>(); - VisitNode(states, automaton.startState.Index, temporary, permanent, groupSubGraphs, topologicalOrder); - return topologicalOrder.Select(idx => states[idx]).ToList(); + VisitNode(automaton.States, automaton.Start.Index, temporary, permanent, groupSubGraphs, topologicalOrder); + return topologicalOrder.Select(idx => automaton.States[idx]).ToList(); } - private static void VisitNode(List states, int stateIdx, BitArray temporary, BitArray permanent, Dictionary> groupSubGraphs, Stack topologicalOrder) + private static void VisitNode(IReadOnlyList states, int stateIdx, BitArray temporary, BitArray permanent, Dictionary> groupSubGraphs, Stack topologicalOrder) { if (temporary[stateIdx]) { @@ -158,7 +156,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata /// Ending weights are taken into account. /// /// The weights are computed using dynamic programming, going up from leafs to the root. - private static Weight[] ComputeWeightsToEnd(int nStates, List topologicalOrder, int group) + private static Weight[] ComputeWeightsToEnd(int nStates, IReadOnlyList topologicalOrder, int group) { var weights = CreateZeroWeights(nStates); // Iterate in the reverse topological order @@ -190,7 +188,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata /// and ending at that state. Ending weights are not taken into account. /// /// The weights are computed using dynamic programming, going down from the root to leafs. - private static Weight[] ComputeWeightsFromRoot(int nStates, List topologicalOrder, int group) + private static Weight[] ComputeWeightsFromRoot(int nStates, IReadOnlyList topologicalOrder, int group) { var weights = CreateZeroWeights(nStates); weights[topologicalOrder[0].Index] = Weight.One; diff --git a/src/Runtime/Distributions/Automata/Automaton.Simplification.cs b/src/Runtime/Distributions/Automata/Automaton.Simplification.cs index bc38df7e..e9ecac0b 100644 --- a/src/Runtime/Distributions/Automata/Automaton.Simplification.cs +++ b/src/Runtime/Distributions/Automata/Automaton.Simplification.cs @@ -43,9 +43,9 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata return false; } - for (int stateId = 0; stateId < this.states.Count; ++stateId) + for (int stateId = 0; stateId < this.States.Count; ++stateId) { - var state = this.states[stateId]; + var state = this.States[stateId]; // There should be no epsilon transitions for (int transitionIndex = 0; transitionIndex < state.TransitionCount; ++transitionIndex) @@ -135,7 +135,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata var startWeightedStateSet = new Determinization.WeightedStateSet { { this.Start.Index, Weight.One } }; weightedStateSetQueue.Enqueue(startWeightedStateSet); weightedStateSetToNewState.Add(startWeightedStateSet, result.Start); - result.Start.EndWeight = this.Start.EndWeight; + result.Start.SetEndWeight(this.Start.EndWeight); while (weightedStateSetQueue.Count > 0) { @@ -169,12 +169,12 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata weightedStateSetQueue.Enqueue(destWeightedStateSet); // Compute its ending weight - destinationState.EndWeight = Weight.Zero; + destinationState.SetEndWeight(Weight.Zero); foreach (KeyValuePair stateIdWithWeight in destWeightedStateSet) { - destinationState.EndWeight = Weight.Sum( + destinationState.SetEndWeight(Weight.Sum( destinationState.EndWeight, - Weight.Product(stateIdWithWeight.Value, this.States[stateIdWithWeight.Key].EndWeight)); + Weight.Product(stateIdWithWeight.Value, this.States[stateIdWithWeight.Key].EndWeight))); } } @@ -202,7 +202,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata /// public void SimplifyIfNeeded() { - if (this.states.Count > MaxStateCountBeforeSimplification || this.PruneTransitionsWithLogWeightLessThan != null) + if (this.States.Count > MaxStateCountBeforeSimplification || this.PruneTransitionsWithLogWeightLessThan != null) { ////Console.WriteLine(this.ToString(AutomatonFormats.GraphViz)); this.Simplify(); @@ -306,10 +306,10 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata /// If the number of stats to remove is less than this value, the removal will not be done. private void RemoveStates(bool[] statesToKeep, int minStatesToActuallyRemove) { - int[] oldToNewStateIdMapping = new int[this.states.Count]; + int[] oldToNewStateIdMapping = new int[this.States.Count]; int newStateId = 0; int deadStateCount = 0; - for (int stateId = 0; stateId < this.states.Count; ++stateId) + for (int stateId = 0; stateId < this.States.Count; ++stateId) { if (statesToKeep[stateId]) { @@ -340,16 +340,16 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata funcWithoutStates.Start = funcWithoutStates.States[oldToNewStateIdMapping[this.Start.Index]]; funcWithoutStates.LogValueOverride = this.LogValueOverride; funcWithoutStates.PruneTransitionsWithLogWeightLessThan = this.PruneTransitionsWithLogWeightLessThan; - for (int i = 0; i < this.states.Count; ++i) + for (int i = 0; i < this.States.Count; ++i) { if (oldToNewStateIdMapping[i] == -1) { continue; } - State oldState = this.states[i]; + State oldState = this.States[i]; State newState = funcWithoutStates.States[oldToNewStateIdMapping[i]]; - newState.EndWeight = oldState.EndWeight; + newState.SetEndWeight(oldState.EndWeight); for (int transitionIndex = 0; transitionIndex < oldState.TransitionCount; ++transitionIndex) { Transition transition = oldState.GetTransition(transitionIndex); @@ -374,7 +374,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata /// The smallest log weight that a transition can have and not be removed. public void RemoveTransitionsWithSmallWeights(double logWeightThreshold) { - foreach (var state in this.states) + foreach (var state in this.States) { for (int i = state.TransitionCount-1; i >=0; i--) { @@ -441,7 +441,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata // Self-loops are allowed if (transition.DestinationStateIndex != currentState.Index) { - isGeneralizedTree &= this.DoLabelStatesForSimplification(this.states[transition.DestinationStateIndex], stateLabels); + isGeneralizedTree &= this.DoLabelStatesForSimplification(this.States[transition.DestinationStateIndex], stateLabels); } } @@ -455,9 +455,9 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata /// private void MergeParallelTransitions() { - for (int stateIndex = 0; stateIndex < this.states.Count; ++stateIndex) + for (int stateIndex = 0; stateIndex < this.States.Count; ++stateIndex) { - State state = this.states[stateIndex]; + State state = this.States[stateIndex]; for (int transitionIndex1 = 0; transitionIndex1 < state.TransitionCount; ++transitionIndex1) { Transition transition1 = state.GetTransition(transitionIndex1); @@ -541,13 +541,13 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata } copiedState = this.AddState(); - copiedState.EndWeight = stateToCopy.EndWeight; + copiedState.SetEndWeight(stateToCopy.EndWeight); copiedStateCache.Add(stateToCopy.Index, copiedState); for (int i = 0; i < stateToCopy.TransitionCount; ++i) { Transition transitionToCopy = stateToCopy.GetTransition(i); - State destStateToCopy = stateToCopy.Owner.states[transitionToCopy.DestinationStateIndex]; + State destStateToCopy = stateToCopy.Owner.States[transitionToCopy.DestinationStateIndex]; if (!lookAtLabels || !stateLabels[destStateToCopy.Index]) { State copiedDestState = this.DoCopyNonSimplifiable(destStateToCopy, stateLabels, false, copiedStateCache); @@ -696,7 +696,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata stack.Push( new StateWeight( - states[transition.DestinationStateIndex], + States[transition.DestinationStateIndex], Weight.Product(currentWeight, transition.Weight))); if (!transition.IsEpsilon) @@ -784,7 +784,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata } } - state.EndWeight = Weight.Sum(state.EndWeight, weight); + state.SetEndWeight(Weight.Sum(state.EndWeight, weight)); return true; } @@ -802,7 +802,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata if (transition.DestinationStateIndex != state.Index && transition.IsEpsilon && transition.DestinationStateIndex >= firstAllowedStateIndex) { if (this.DoAddGeneralizedSequence( - this.states[transition.DestinationStateIndex], + this.States[transition.DestinationStateIndex], false, false, firstAllowedStateIndex, @@ -831,7 +831,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata { // Try this epsilon transition if (this.DoAddGeneralizedSequence( - this.states[transition.DestinationStateIndex], false, false, firstAllowedStateIndex, currentSequencePos, sequence, weight)) + this.States[transition.DestinationStateIndex], false, false, firstAllowedStateIndex, currentSequencePos, sequence, weight)) { return true; } @@ -878,7 +878,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata { // Try this epsilon transition if (this.DoAddGeneralizedSequence( - this.states[transition.DestinationStateIndex], false, false, firstAllowedStateIndex, currentSequencePos, sequence, weight)) + this.States[transition.DestinationStateIndex], false, false, firstAllowedStateIndex, currentSequencePos, sequence, weight)) { return true; } @@ -908,7 +908,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata // Weight of the existing transition must be taken into account // This case can fail if the next element is a self-loop and the destination state already has a different one if (this.DoAddGeneralizedSequence( - this.states[transition.DestinationStateIndex], + this.States[transition.DestinationStateIndex], false, false, firstAllowedStateIndex, @@ -921,7 +921,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata } // Add a new transition - State newChild = state.AddTransition(element.ElementDistribution, Weight.One, null, element.Group); + State newChild = state.AddTransition(element.ElementDistribution, Weight.One, default(State), element.Group); success = this.DoAddGeneralizedSequence(newChild, true, false, firstAllowedStateIndex, currentSequencePos + 1, sequence, weight); Debug.Assert(success, "This call must always succeed."); return true; diff --git a/src/Runtime/Distributions/Automata/Automaton.State.cs b/src/Runtime/Distributions/Automata/Automaton.State.cs index 7ae64760..5b43d70e 100644 --- a/src/Runtime/Distributions/Automata/Automaton.State.cs +++ b/src/Runtime/Distributions/Automata/Automaton.State.cs @@ -7,6 +7,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata using System; using System.Collections; using System.Collections.Generic; + using System.Linq; using System.Runtime.Serialization; using System.Text; @@ -16,9 +17,6 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata using Microsoft.ML.Probabilistic.Serialization; using Microsoft.ML.Probabilistic.Utilities; - /// - /// Contains the class used to represent a state of an automaton. - /// public abstract partial class Automaton where TSequence : class, IEnumerable where TElementDistribution : class, IDistribution, SettableToProduct, SettableToWeightedSumExact, CanGetLogAverageOf, SettableToPartialUniform, new() @@ -26,44 +24,32 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata where TThis : Automaton, new() { /// - /// Represents a state of an automaton. + /// Represents a reference to a state of automaton for exposure in public API. /// - [Serializable] - [DataContract(IsReference = true)] - public class State + /// + /// Acts as a "fat reference" to state in automaton. In addition to reference to actual StateData it carries + /// 2 additional properties for convinience: automaton and of the state. + /// We don't store them in to save some memoty. C# compiler and .NET jitter are good + /// at optimizing wrapping where it is not needed. + /// + public struct State : IEquatable { - //// This class has been made inner so that the user doesn't have to deal with a lot of generic parameters on it. + internal readonly StateData Data; /// - /// The default capacity of the . + /// Initializes a new instance of class. Used internally by automaton implementation + /// to wrap StateData for use in public Automaton APIs. /// - private const int DefaultTransitionArrayCapacity = 1; - - /// - /// The array of outgoing transitions. - /// - /// - /// We don't use here for performance reasons. - /// - [DataMember] - private Transition[] transitions = new Transition[DefaultTransitionArrayCapacity]; - - /// - /// The number of outgoing transitions from the state. - /// - [DataMember] - private int transitionCount; - - /// - /// Initializes a new instance of the class. - /// - public State() + internal State(Automaton owner, int index, StateData data) { - this.EndWeight = Weight.Zero; + this.Owner = owner; + this.Index = index; + this.Data = data; } /// - /// Initializes a new instance of the class. + /// Initializes a new instance of the class. Created state does not belong + /// to any automaton and has to be added to some automaton explicitly via Automaton.AddStates. /// /// The index of the state. /// The outgoing transitions. @@ -73,73 +59,90 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata : this() { Argument.CheckIfInRange(index >= 0, "index", "State index must be non-negative."); - Argument.CheckIfNotNull(transitions, "transitions"); - this.Index = index; - this.EndWeight = endWeight; - - foreach (var transition in transitions) - { - this.DoAddTransition(transition); - } + this.Data = new StateData(transitions, endWeight); } /// - /// Gets the automaton which owns the state. + /// Returns where this State represents some valid state in automaton. /// - /// - /// Owner is not serialized to avoid circular references. It has to be restored manually upon deserialization. - /// BinaryFormatter and Newtonsoft.Json handle circular references differently by default. - /// At the same time [DataMember] does the right thing, because IsReference=true property on State DataContract - /// makes DataContractSerializer handle circular references just fine. - /// - [DataMember] - [NonSerializedProperty] - public TThis Owner { get; internal set; } + public bool IsNull => this.Data == null; /// - /// Helper method for Newtonsoft.Json to skip serialization of property. + /// Automaton to which this state belongs. /// - public bool ShouldSerializeOwner() => false; + public Automaton Owner { get; } /// /// Gets the index of the state. /// - [DataMember] - public int Index { get; internal set; } // TODO: setter of this property is needed only for the state removal procedure + public int Index { get; } /// /// Gets or sets the ending weight of the state. /// - [DataMember] - public Weight EndWeight { get; set; } - + /// + /// C# compiler disallows to use property setter if it sees that instance is a temporary. + /// It is not smart enough to understand that property setter actually changes something behind a reference. + /// To overcome this issue special method is added calling which is equivalent + /// to calling property setter but is not rejected by compiler. + /// + public Weight EndWeight => this.Data.EndWeight; + + /// + /// Sets the property of State. + /// + /// Because is a struct, trying to set on it + /// (if property setter was provided) would result in compilation error. Compiler isn't + /// smart enough to see that setting property just updates the value in referenced . + /// Having a method call doesn't create this problem. + /// + /// New end weight. + public void SetEndWeight(Weight weight) + { + this.Data.EndWeight = weight; + } + /// /// Gets a value indicating whether the ending weight of this state is greater than zero. /// - public bool CanEnd - { - get { return !this.EndWeight.IsZero; } - } + public bool CanEnd => this.Data.CanEnd; /// /// Gets the number of outgoing transitions. /// - public int TransitionCount - { - get { return this.transitionCount; } - } + public int TransitionCount => this.Data.TransitionCount; /// /// Creates the copy of the array of outgoing transitions. Used by quoting. /// /// The copy of the array of outgoing transitions. - public Transition[] GetTransitions() - { - var result = new Transition[this.transitionCount]; - Array.Copy(this.transitions, result, this.transitionCount); - return result; - } + public Transition[] GetTransitions() => this.Data.GetTransitions(); + + /// + /// Compares 2 states for equality. + /// + public static bool operator ==(State a, State b) => a.Data == b.Data; + + /// + /// Compares 2 states for inequality. + /// + public static bool operator !=(State a, State b) => !(a == b); + + /// + /// Compares 2 states for equality. + /// + public bool Equals(State that) => this == that; + + /// + /// Compares 2 states for equality. + /// + public override bool Equals(object obj) => obj is State that && this.Equals(that); + + /// + /// Returns HashCode of this state. + /// + public override int GetHashCode() => this.Data?.GetHashCode() ?? 0; /// /// Adds a series of transitions labeled with the elements of a given sequence to the current state, @@ -152,16 +155,19 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata /// /// The group of the added transitions. /// The last state in the added transition series. - public State AddTransitionsForSequence(TSequence sequence, State destinationState = null, int group = 0) + public State AddTransitionsForSequence(TSequence sequence, State destinationState = default(State), int group = 0) { - State currentState = this; - IEnumerator enumerator = sequence.GetEnumerator(); - bool moveNext = enumerator.MoveNext(); - while (moveNext) + var currentState = this; + using (var enumerator = sequence.GetEnumerator()) { - TElement element = enumerator.Current; - moveNext = enumerator.MoveNext(); - currentState = currentState.AddTransition(element, Weight.One, moveNext ? null : destinationState, group); + var moveNext = enumerator.MoveNext(); + while (moveNext) + { + var element = enumerator.Current; + moveNext = enumerator.MoveNext(); + currentState = currentState.AddTransition( + element, Weight.One, moveNext ? default(State) : destinationState, group); + } } return currentState; @@ -177,7 +183,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata /// If the value of this parameter is , a new state will be created. /// The group of the added transition. /// The destination state of the added transition. - public State AddTransition(TElement element, Weight weight, State destinationState = null, int group = 0) + public State AddTransition(TElement element, Weight weight, State destinationState = default(State), int group = 0) { return this.AddTransition(new TElementDistribution { Point = element }, weight, destinationState, group); } @@ -191,7 +197,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata /// If the value of this parameter is , a new state will be created. /// The group of the added transition. /// The destination state of the added transition. - public State AddEpsilonTransition(Weight weight, State destinationState = null, int group = 0) + public State AddEpsilonTransition(Weight weight, State destinationState = default(State), int group = 0) { return this.AddTransition(null, weight, destinationState, group); } @@ -209,9 +215,9 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata /// If the value of this parameter is , a new state will be created. /// The group of the added transition. /// The destination state of the added transition. - public State AddTransition(TElementDistribution elementDistribution, Weight weight, State destinationState = null, int group = 0) + public State AddTransition(TElementDistribution elementDistribution, Weight weight, State destinationState = default(State), int group = 0) { - if (destinationState == null) + if (destinationState.IsNull) { destinationState = this.Owner.AddState(); } @@ -231,11 +237,15 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata /// The destination state of the added transition. public State AddTransition(Transition transition) { - Argument.CheckIfValid(this.Owner == null || transition.DestinationStateIndex < this.Owner.states.Count, "transition", "The destination state index is not valid."); + Argument.CheckIfValid(this.Owner == null || transition.DestinationStateIndex < this.Owner.statesData.Count, "transition", "The destination state index is not valid."); + + this.Data.AddTransition(transition); + if (this.Owner.isEpsilonFree == true) + { + this.Owner.isEpsilonFree = !transition.IsEpsilon; + } - this.DoAddTransition(transition); - if (this.Owner.isEpsilonFree==true) this.Owner.isEpsilonFree = !transition.IsEpsilon; - return this.Owner.states[transition.DestinationStateIndex]; + return this.Owner.States[transition.DestinationStateIndex]; } /// @@ -270,11 +280,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata /// /// The index of the transition. /// The transition. - public Transition GetTransition(int index) - { - //Argument.CheckIfInRange(index >= 0 && index < this.transitionCount, "index", "An invalid transition index given."); - return this.transitions[index]; - } + public Transition GetTransition(int index) => this.Data.GetTransition(index); /// /// Replaces the transition at a given index with a given transition. @@ -283,28 +289,26 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata /// The transition to replace with. public void SetTransition(int index, Transition updatedTransition) { - Argument.CheckIfInRange(index >= 0 && index < this.transitionCount, "index", "An invalid transition index given."); - Argument.CheckIfValid(updatedTransition.DestinationStateIndex < this.Owner.states.Count, "updatedTransition", "The destination state index is not valid."); + Argument.CheckIfInRange(index >= 0 && index < this.TransitionCount, "index", "An invalid transition index given."); + Argument.CheckIfValid(updatedTransition.DestinationStateIndex < this.Owner.statesData.Count, "updatedTransition", "The destination state index is not valid."); - if (updatedTransition.IsEpsilon) { + if (updatedTransition.IsEpsilon) + { this.Owner.isEpsilonFree = false; } else { this.Owner.isEpsilonFree = null; } - this.transitions[index] = updatedTransition; + + this.Data.SetTransition(index, updatedTransition); } /// /// Removes the transition with a given index. /// /// The index of the transition to remove. - public void RemoveTransition(int index) - { - Argument.CheckIfInRange(index >= 0 && index < this.transitionCount, "index", "An invalid transition index given."); - this.transitions[index] = this.transitions[--this.transitionCount]; - } + public void RemoveTransition(int index) => this.Data.RemoveTransition(index); /// /// Returns a string that represents the state. @@ -314,7 +318,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata { const string StartStateMarker = "START ->"; const string TransitionSeparator = ","; - + var sb = new StringBuilder(); bool isStartState = this.Owner != null && this.Owner.Start == this; @@ -341,7 +345,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata if (CanEnd) { if (!firstTransition) sb.Append(TransitionSeparator); - sb.Append(this.EndWeight.Value+" -> END"); + sb.Append(this.EndWeight.Value + " -> END"); } return sb.ToString(); @@ -389,22 +393,6 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata #region Helpers - /// - /// Adds a given transition to the transition array, increasing the size of the array if necessary. - /// - /// The transition to add. - private void DoAddTransition(Transition transition) - { - if (this.transitionCount == this.transitions.Length) - { - var newTransitions = new Transition[this.transitionCount * 2]; - Array.Copy(this.transitions, newTransitions, this.transitionCount); - this.transitions = newTransitions; - } - - this.transitions[this.transitionCount++] = transition; - } - /// /// Recursively checks if the automaton has non-trivial loops /// (i.e. loops consisting of more than one transition). @@ -426,11 +414,12 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata stateInStack.Add(this.Index, true); - for (int i = 0; i < this.transitionCount; ++i) + for (int i = 0; i < this.TransitionCount; ++i) { - if (this.transitions[i].DestinationStateIndex != this.Index) + var transition = this.GetTransition(i); + if (transition.DestinationStateIndex != this.Index) { - State destState = this.Owner.States[this.transitions[i].DestinationStateIndex]; + var destState = this.Owner.States[transition.DestinationStateIndex]; if (destState.DoHasNonTrivialLoops(stateInStack)) { return true; @@ -459,13 +448,14 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata visitedStates[this.Index] = true; - bool isZero = !this.CanEnd; - int transitionIndex = 0; - while (isZero && transitionIndex < this.transitionCount) + var isZero = !this.CanEnd; + var transitionIndex = 0; + while (isZero && transitionIndex < this.TransitionCount) { - if (!this.transitions[transitionIndex].Weight.IsZero) + var transition = this.GetTransition(transitionIndex); + if (!transition.Weight.IsZero) { - State destState = this.Owner.States[this.transitions[transitionIndex].DestinationStateIndex]; + var destState = this.Owner.States[transition.DestinationStateIndex]; isZero = destState.DoIsZero(visitedStates); } @@ -486,43 +476,42 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata TSequence sequence, int sequencePosition, Dictionary<(int, int), Weight> valueCache) { var stateIndexPair = (this.Index, sequencePosition); - Weight cachedValue; - if (valueCache.TryGetValue(stateIndexPair, out cachedValue)) + if (valueCache.TryGetValue(stateIndexPair, out var cachedValue)) { return cachedValue; } - EpsilonClosure closure = this.GetEpsilonClosure(); + var closure = this.GetEpsilonClosure(); - Weight value = Weight.Zero; - int count = Automaton.SequenceManipulator.GetLength(sequence); - bool isCurrent = sequencePosition < count; + var value = Weight.Zero; + var count = SequenceManipulator.GetLength(sequence); + var isCurrent = sequencePosition < count; if (isCurrent) { - TElement element = Automaton.SequenceManipulator.GetElement(sequence, sequencePosition); - for (int closureStateIndex = 0; closureStateIndex < closure.Size; ++closureStateIndex) + var element = SequenceManipulator.GetElement(sequence, sequencePosition); + for (var closureStateIndex = 0; closureStateIndex < closure.Size; ++closureStateIndex) { - State closureState = closure.GetStateByIndex(closureStateIndex); - Weight closureStateWeight = closure.GetStateWeightByIndex(closureStateIndex); + var closureState = closure.GetStateByIndex(closureStateIndex); + var closureStateWeight = closure.GetStateWeightByIndex(closureStateIndex); - for (int transitionIndex = 0; transitionIndex < closureState.transitionCount; transitionIndex++) + for (int transitionIndex = 0; transitionIndex < closureState.TransitionCount; transitionIndex++) { - Transition transition = closureState.transitions[transitionIndex]; + var transition = closureState.GetTransition(transitionIndex); if (transition.IsEpsilon) { continue; // The destination is a part of the closure anyway } - State destState = this.Owner.states[transition.DestinationStateIndex]; - Weight distWeight = Weight.FromLogValue(transition.ElementDistribution.GetLogProb(element)); + var destState = this.Owner.States[transition.DestinationStateIndex]; + var distWeight = Weight.FromLogValue(transition.ElementDistribution.GetLogProb(element)); if (!distWeight.IsZero && !transition.Weight.IsZero) { - Weight destValue = destState.DoGetValue(sequence, sequencePosition + 1, valueCache); + var destValue = destState.DoGetValue(sequence, sequencePosition + 1, valueCache); if (!destValue.IsZero) { value = Weight.Sum( value, - Weight.Product(closureStateWeight, transition.Weight, distWeight, destValue)); + Weight.Product(closureStateWeight, transition.Weight, distWeight, destValue)); } } } @@ -546,15 +535,10 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata { get { - foreach (var state in Owner.States) - { - var trans = state.GetTransitions(); - for (int i = 0; i < trans.Length; i++) - { - if (trans[i].DestinationStateIndex == Index) return true; - } - } - return false; + var this_ = this; + return this.Owner.States.Any( + state => state.GetTransitions().Any( + transition => transition.DestinationStateIndex == this_.Index)); } } @@ -566,10 +550,10 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata { this.EndWeight.Write(writeDouble); writeInt32(this.Index); - writeInt32(this.transitionCount); - for (var i = 0; i < transitionCount; i++) + writeInt32(this.TransitionCount); + for (var i = 0; i < TransitionCount; i++) { - transitions[i].Write(writeInt32, writeDouble, writeElementDistribution); + GetTransition(i).Write(writeInt32, writeDouble, writeElementDistribution); } } @@ -578,27 +562,17 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata /// public static State Read(Func readInt32, Func readDouble, Func readElementDistribution) { - var res = new State(); - res.EndWeight = Weight.Read(readDouble); - res.Index = readInt32(); + var endWeight = Weight.Read(readDouble); + // Note: index is serialized for compatibility with old binary serializations + var index = readInt32(); var transitionCount = readInt32(); - - var transitionLength = res.transitions.Length; - while (transitionLength < transitionCount) - { - transitionLength <<= 1; - } - - var transitions = transitionLength == res.transitions.Length ? res.transitions : new Transition[transitionLength]; + var transitions = new Transition[transitionCount]; for (var i = 0; i < transitionCount; i++) { transitions[i] = Transition.Read(readInt32, readDouble, readElementDistribution); } - res.transitionCount = transitionCount; - res.transitions = transitions; - - return res; + return new State(index, transitions, endWeight); } } } diff --git a/src/Runtime/Distributions/Automata/Automaton.StateCollection.cs b/src/Runtime/Distributions/Automata/Automaton.StateCollection.cs new file mode 100644 index 00000000..f22623d1 --- /dev/null +++ b/src/Runtime/Distributions/Automata/Automaton.StateCollection.cs @@ -0,0 +1,75 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace Microsoft.ML.Probabilistic.Distributions.Automata +{ + using System.Collections; + using System.Collections.Generic; + using System.Linq; + + using Microsoft.ML.Probabilistic.Distributions; + using Microsoft.ML.Probabilistic.Math; + + public abstract partial class Automaton + where TSequence : class, IEnumerable + where TElementDistribution : class, IDistribution, SettableToProduct, SettableToWeightedSumExact, CanGetLogAverageOf, SettableToPartialUniform, new() + where TSequenceManipulator : ISequenceManipulator, new() + where TThis : Automaton, new() + { + /// + /// Represents a collection of automaton states for use in public APIs + /// + /// + /// Is a thin wrapper around Automaton.stateData. Wraps each into on demand. + /// + public struct StateCollection : IReadOnlyList + { + /// + /// Owner automaton of all states in collection. + /// + private readonly Automaton owner; + + /// + /// Cached value of owner.statesData. Cached for performance reasons. + /// + private readonly List statesData; + + /// + /// Initializes instance of . + /// + internal StateCollection(Automaton owner, List states) + { + this.owner = owner; + this.statesData = owner.statesData; + } + + /// + /// Gets state by its index. + /// + public State this[int index] => new State(this.owner, index, this.statesData[index]); + + /// + /// Gets number of states in collection. + /// + public int Count => this.statesData.Count; + + /// + /// Returns enumerator over all states in collection. + /// + public IEnumerator GetEnumerator() + { + var owner = this.owner; + return this.statesData.Select((data, index) => new State(owner, index, data)).GetEnumerator(); + } + + /// + /// Returns enumerator over all states in collection. + /// + IEnumerator IEnumerable.GetEnumerator() + { + return this.GetEnumerator(); + } + } + } +} diff --git a/src/Runtime/Distributions/Automata/Automaton.StateData.cs b/src/Runtime/Distributions/Automata/Automaton.StateData.cs new file mode 100644 index 00000000..c4f23c80 --- /dev/null +++ b/src/Runtime/Distributions/Automata/Automaton.StateData.cs @@ -0,0 +1,149 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace Microsoft.ML.Probabilistic.Distributions.Automata +{ + using System; + using System.Collections.Generic; + using System.Diagnostics; + using System.Runtime.Serialization; + + using Microsoft.ML.Probabilistic.Distributions; + using Microsoft.ML.Probabilistic.Math; + using Microsoft.ML.Probabilistic.Serialization; + using Microsoft.ML.Probabilistic.Utilities; + + public abstract partial class Automaton + where TSequence : class, IEnumerable + where TElementDistribution : class, IDistribution, SettableToProduct, SettableToWeightedSumExact, CanGetLogAverageOf, SettableToPartialUniform, new() + where TSequenceManipulator : ISequenceManipulator, new() + where TThis : Automaton, new() + { + /// + /// Represents a state of an automaton that is stored in the Automaton.statesData. This is an internal representation + /// of the state. struct should be used in public APIs. + /// + [Serializable] + [DataContract] + internal class StateData + { + /// + /// The default capacity of the . + /// + private const int DefaultTransitionArrayCapacity = 1; + + /// + /// The array of outgoing transitions. + /// + /// + /// We don't use here for performance reasons. + /// + [DataMember] + private Transition[] transitions = new Transition[DefaultTransitionArrayCapacity]; + + /// + /// The number of outgoing transitions from the state. + /// + [DataMember] + private int transitionCount; + + /// + /// Initializes a new instance of the class. + /// + public StateData() => this.EndWeight = Weight.Zero; + + /// + /// Initializes a new instance of the class. + /// + /// The outgoing transitions. + /// The ending weight of the state. + [Construction("GetTransitions", "EndWeight")] + public StateData(IEnumerable transitions, Weight endWeight) + : this() + { + Argument.CheckIfNotNull(transitions, "transitions"); + + this.EndWeight = endWeight; + + foreach (var transition in transitions) + { + this.AddTransition(transition); + } + } + + /// + /// Gets or sets the ending weight of the state. + /// + [DataMember] + public Weight EndWeight { get; set; } + + /// + /// Gets a value indicating whether the ending weight of this state is greater than zero. + /// + public bool CanEnd => !this.EndWeight.IsZero; + + /// + /// Gets the number of outgoing transitions. + /// + public int TransitionCount => this.transitionCount; + + /// + /// Creates the copy of the array of outgoing transitions. Used by quoting. + /// + /// The copy of the array of outgoing transitions. + public Transition[] GetTransitions() + { + var result = new Transition[this.transitionCount]; + Array.Copy(this.transitions, result, this.transitionCount); + return result; + } + + /// + /// Adds a transition to the current state. + /// + /// The transition to add. + /// The destination state of the added transition. + public void AddTransition(Transition transition) + { + if (this.transitionCount == this.transitions.Length) + { + var newTransitions = new Transition[this.transitionCount * 2]; + Array.Copy(this.transitions, newTransitions, this.transitionCount); + this.transitions = newTransitions; + } + + this.transitions[this.transitionCount++] = transition; + } + + /// + /// Gets the transition at a specified index. + /// + /// The index of the transition. + /// The transition. + public Transition GetTransition(int index) + { + Debug.Assert(index >= 0 && index < this.transitionCount, nameof(index), "An invalid transition index given."); + return this.transitions[index]; + } + + /// + /// Replaces the transition at a given index with a given transition. + /// + /// The index of the transition to replace. + /// The transition to replace with. + public void SetTransition(int index, Transition updatedTransition) => + this.transitions[index] = updatedTransition; + + /// + /// Removes the transition with a given index. + /// + /// The index of the transition to remove. + public void RemoveTransition(int index) + { + Argument.CheckIfInRange(index >= 0 && index < this.transitionCount, "index", "An invalid transition index given."); + this.transitions[index] = this.transitions[--this.transitionCount]; + } + } + } +} diff --git a/src/Runtime/Distributions/Automata/Automaton.StronglyConnectedComponent.cs b/src/Runtime/Distributions/Automata/Automaton.StronglyConnectedComponent.cs index 0b425821..2034af43 100644 --- a/src/Runtime/Distributions/Automata/Automaton.StronglyConnectedComponent.cs +++ b/src/Runtime/Distributions/Automata/Automaton.StronglyConnectedComponent.cs @@ -114,7 +114,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata /// public bool HasState(State state) { - Argument.CheckIfNotNull(state, "state"); + Argument.CheckIfValid(!state.IsNull, nameof(state)); return this.GetIndexByState(state) != -1; } @@ -128,12 +128,12 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata /// public int GetIndexByState(State state) { - Argument.CheckIfNotNull(state, "state"); + Argument.CheckIfValid(!state.IsNull, nameof(state)); Argument.CheckIfValid(ReferenceEquals(state.Owner, this.statesInComponent[0].Owner), "state", "The given state belongs to other automaton."); if (this.statesInComponent.Count == 1) { - return ReferenceEquals(this.statesInComponent[0], state) ? 0 : -1; + return this.statesInComponent[0].Index == state.Index ? 0 : -1; } if (this.stateIdToIndexInComponent == null) diff --git a/src/Runtime/Distributions/Automata/Automaton.cs b/src/Runtime/Distributions/Automata/Automaton.cs index a77070a5..a587b75b 100644 --- a/src/Runtime/Distributions/Automata/Automaton.cs +++ b/src/Runtime/Distributions/Automata/Automaton.cs @@ -88,18 +88,13 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata /// The collection of states. /// [DataMember] - private List states = new List(); - - /// - /// A read-only wrapper around the . - /// - private ReadOnlyList statesReadOnly; + private List statesData = new List(); /// /// The start state. /// [DataMember] - private State startState; + private int startStateIndex; /// /// Whether the automaton is free of epsilon transition. @@ -127,7 +122,6 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata protected Automaton() { // Zero by default - this.statesReadOnly = new ReadOnlyList(this.states); this.SetToZero(); } @@ -179,14 +173,11 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata /// public static int MaxStateCount { - get - { - return maxStateCount; - } + get => maxStateCount; set { - Argument.CheckIfInRange(value > 0, "value", "The maximum number of states must be positive."); + Argument.CheckIfInRange(value > 0, nameof(value), "The maximum number of states must be positive."); maxStateCount = value; } } @@ -197,14 +188,11 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata /// public static int MaxStateCountBeforeSimplification { - get - { - return maxStateCountBeforeSimplification; - } + get => maxStateCountBeforeSimplification; set { - Argument.CheckIfInRange(value > 0, "value", "The maximum number of states before simplification must be positive."); + Argument.CheckIfInRange(value > 0, nameof(value), "The maximum number of states before simplification must be positive."); maxStateCountBeforeSimplification = value; } } @@ -215,14 +203,11 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata /// public static int MaxDeadStateCount { - get - { - return maxDeadStateCount; - } + get => maxDeadStateCount; set { - Argument.CheckIfInRange(value >= 0, "value", "The maximum number of dead states should be non-negative."); + Argument.CheckIfInRange(value >= 0, nameof(value), "The maximum number of dead states should be non-negative."); maxDeadStateCount = value; } } @@ -230,13 +215,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata /// /// Gets the collection of the states of the automaton. /// - public ReadOnlyList States - { - get - { - return this.statesReadOnly; - } - } + public StateCollection States => new StateCollection(this, this.statesData); /// /// Gets or sets the start state of the automaton. @@ -246,16 +225,13 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata /// public State Start { - get - { - return this.startState; - } + get => new State(this, this.startStateIndex, this.statesData[this.startStateIndex]); set { - Argument.CheckIfNotNull(value, "value"); - Argument.CheckIfValid(ReferenceEquals(value.Owner, this), "value", "The given state does not belong to this automaton."); - this.startState = value; + Argument.CheckIfValid(!value.IsNull, nameof(value)); + Argument.CheckIfValid(ReferenceEquals(value.Owner, this), nameof(value), "The given state does not belong to this automaton."); + this.startStateIndex = value.Index; } } @@ -277,13 +253,13 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata [Construction("GetStates", "Start")] public static TThis FromStates(IEnumerable states, State startState) { - Argument.CheckIfNotNull(states, "states"); - Argument.CheckIfNotNull(startState, "startState"); + Argument.CheckIfNotNull(states, nameof(states)); + Argument.CheckIfValid(!startState.IsNull, nameof(startState)); CheckStateConsistency(states, startState); var result = new TThis(); - result.SetStates(states, startState.Index); + result.SetStates(states.Select(state => state.Data), startState.Index); return result; } @@ -380,7 +356,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata { allowedElements = Distribution.CreatePartialUniform(allowedElements); State finish = result.Start.AddTransition(allowedElements, Weight.FromLogValue(-allowedElements.GetLogAverageOf(allowedElements))); - finish.EndWeight = Weight.FromLogValue(logValue); + finish.SetEndWeight(Weight.FromLogValue(logValue)); } return result; @@ -433,7 +409,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata foreach (TSequence sequence in sequences) { State sequenceEndState = result.Start.AddTransitionsForSequence(sequence); - sequenceEndState.EndWeight = Weight.FromLogValue(logValue); + sequenceEndState.SetEndWeight(Weight.FromLogValue(logValue)); } } @@ -503,7 +479,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata } TThis result = Zero(); - result.startState.EndWeight = Weight.FromLogValue(Math.Log(value)); + result.Start.SetEndWeight(Weight.FromLogValue(Math.Log(value))); return result; } @@ -607,9 +583,9 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata continue; } - int index = result.states.Count; - result.AddStates(automaton.states); - result.Start.AddEpsilonTransition(Weight.One, result.states[index + automaton.Start.Index]); + int index = result.statesData.Count; + result.AddStates(automaton.statesData); + result.Start.AddEpsilonTransition(Weight.One, result.States[index + automaton.Start.Index]); } return result; @@ -699,20 +675,20 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata TThis result = ConstantOn(1.0, SequenceManipulator.ToSequence(new TElement[0])); // Build a list of all intermediate end states with their target ending weights while adding repetitions - var endStatesWithTargetWeights = new List>(); + var endStatesWithTargetWeights = new List<(State, Weight)>(); int prevStateCount = 0; for (int i = 0; i <= maxTimes; ++i) { // Remember added ending states if (repetitionNumberWeights[i] > 0) { - for (int j = prevStateCount; j < result.states.Count; ++j) + for (int j = prevStateCount; j < result.statesData.Count; ++j) { - if (result.states[j].CanEnd) + if (result.statesData[j].CanEnd) { - endStatesWithTargetWeights.Add(Pair.Create( - result.states[j], - Weight.Product(Weight.FromValue(repetitionNumberWeights[i]), result.states[j].EndWeight))); + endStatesWithTargetWeights.Add(ValueTuple.Create( + result.States[j], + Weight.Product(Weight.FromValue(repetitionNumberWeights[i]), result.statesData[j].EndWeight))); } } } @@ -720,7 +696,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata // Add one more repetition if (i != maxTimes) { - prevStateCount = result.States.Count; + prevStateCount = result.statesData.Count; result.AppendInPlaceNoOptimizations(automaton); } } @@ -728,7 +704,8 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata // Set target ending weights for (int i = 0; i < endStatesWithTargetWeights.Count; ++i) { - endStatesWithTargetWeights[i].First.EndWeight = endStatesWithTargetWeights[i].Second; + var (state, weight) = endStatesWithTargetWeights[i]; + state.SetEndWeight(weight); } return result; @@ -774,11 +751,11 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata } TThis optionalPart = automaton.Clone(); - for (int i = 0; i < optionalPart.states.Count; ++i) + for (int i = 0; i < optionalPart.statesData.Count; ++i) { - if (optionalPart.states[i].CanEnd) + if (optionalPart.statesData[i].CanEnd) { - optionalPart.states[i].AddEpsilonTransition(optionalPart.states[i].EndWeight, optionalPart.Start); + optionalPart.States[i].AddEpsilonTransition(optionalPart.States[i].EndWeight, optionalPart.Start); } } @@ -829,7 +806,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata else { StringBuilder builder = new StringBuilder(); - this.AppendString(builder, new HashSet(), this.startState.Index, appendElement); + this.AppendString(builder, new HashSet(), this.Start.Index, appendElement); return builder.ToString(); } } @@ -857,9 +834,9 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata /// True if it the automaton has this group, false otherwise. public bool HasGroup(int group) { - for (int stateIndex = 0; stateIndex < this.states.Count; stateIndex++) + for (int stateIndex = 0; stateIndex < this.statesData.Count; stateIndex++) { - State state = this.states[stateIndex]; + var state = this.statesData[stateIndex]; for (int transitionIndex = 0; transitionIndex < state.TransitionCount; transitionIndex++) { Transition transition = state.GetTransition(transitionIndex); @@ -879,9 +856,9 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata /// True if it the automaton has groups, false otherwise. public bool UsesGroups() { - for (int stateIndex = 0; stateIndex < this.states.Count; stateIndex++) + for (int stateIndex = 0; stateIndex < this.statesData.Count; stateIndex++) { - State state = this.states[stateIndex]; + var state = this.statesData[stateIndex]; for (int transitionIndex = 0; transitionIndex < state.TransitionCount; transitionIndex++) { Transition transition = state.GetTransition(transitionIndex); @@ -911,9 +888,9 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata /// The specified group. public void SetGroup(int group) { - for (int stateIndex = 0; stateIndex < this.states.Count; stateIndex++) + for (int stateIndex = 0; stateIndex < this.statesData.Count; stateIndex++) { - State state = this.states[stateIndex]; + var state = this.statesData[stateIndex]; for (int transitionIndex = 0; transitionIndex < state.TransitionCount; transitionIndex++) { Transition transition = state.GetTransition(transitionIndex); @@ -1027,7 +1004,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata /// public bool IsCanonicConstant() { - if (this.States.Count != 1) + if (this.statesData.Count != 1) { return false; } @@ -1098,13 +1075,13 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata throw new NotImplementedException("Not yet supported for non-determinizable automata."); } - for (int stateId = 0; stateId < result.states.Count; ++stateId) + for (int stateId = 0; stateId < result.States.Count; ++stateId) { - var state = result.states[stateId]; + var state = result.States[stateId]; if (state.CanEnd) { // Make all accepting states contibute the desired value to the result - state.EndWeight = value; + state.SetEndWeight(value); } for (int transitionIndex = 0; transitionIndex < state.TransitionCount; ++transitionIndex) @@ -1152,29 +1129,29 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata var result = Zero(); // Result already has 1 state, we add the remaining Count-1 states - result.AddStates(this.states.Count - 1); + result.AddStates(this.statesData.Count - 1); // And the new start state result.Start = result.AddState(); // The start state in the original automaton is going to be the one and only end state in result - result.states[this.Start.Index].EndWeight = Weight.One; + result.States[this.Start.Index].SetEndWeight(Weight.One); - for (int i = 0; i < this.states.Count; ++i) + for (int i = 0; i < this.statesData.Count; ++i) { - var oldState = this.states[i]; - for (int j = 0; j < this.states[i].TransitionCount; ++j) + var oldState = this.statesData[i]; + for (int j = 0; j < this.statesData[i].TransitionCount; ++j) { // Result has original transitions reversed var oldTransition = oldState.GetTransition(j); - result.states[oldTransition.DestinationStateIndex].AddTransition( - oldTransition.ElementDistribution, oldTransition.Weight, result.states[i]); + result.States[oldTransition.DestinationStateIndex].AddTransition( + oldTransition.ElementDistribution, oldTransition.Weight, result.States[i]); } // End states of the original automaton are the new start states if (oldState.CanEnd) { - result.Start.AddEpsilonTransition(oldState.EndWeight, result.states[i]); + result.Start.AddEpsilonTransition(oldState.EndWeight, result.States[i]); } } @@ -1247,11 +1224,11 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata } // Append the states of the second automaton - List endStates = this.states.Where(nd => nd.CanEnd).ToList(); - int stateCount = this.states.Count; + var endStates = this.States.Where(nd => nd.CanEnd).ToList(); + int stateCount = this.statesData.Count; - this.AddStates(automaton.states, group); - State secondStartState = this.states[stateCount + automaton.Start.Index]; + this.AddStates(automaton.statesData, group); + var secondStartState = this.States[stateCount + automaton.Start.Index]; // todo: make efficient bool startIncoming = automaton.Start.HasIncomingTransitions; @@ -1277,10 +1254,10 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata transition.Weight = Weight.Product(transition.Weight, endState.EndWeight); } - endState.AddTransition(transition); + endState.Data.AddTransition(transition); } - endState.EndWeight = Weight.Product(endState.EndWeight, secondStartState.EndWeight); + endState.SetEndWeight(Weight.Product(endState.EndWeight, secondStartState.EndWeight)); } this.RemoveState(secondStartState.Index); @@ -1291,7 +1268,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata { State state = endStates[i]; state.AddEpsilonTransition(state.EndWeight, secondStartState, group); - state.EndWeight = Weight.Zero; + state.SetEndWeight(Weight.Zero); } } @@ -1482,15 +1459,15 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata { if (hasFirstTerm) { - result.AddStates(automaton1.states); - result.Start.AddEpsilonTransition(Weight.FromLogValue(logWeight1), result.states[1 + automaton1.Start.Index]); + result.AddStates(automaton1.statesData); + result.Start.AddEpsilonTransition(Weight.FromLogValue(logWeight1), result.States[1 + automaton1.Start.Index]); } if (hasSecondTerm) { - int cnt = result.states.Count; - result.AddStates(automaton2.states); - result.Start.AddEpsilonTransition(Weight.FromLogValue(logWeight2), result.states[cnt + automaton2.Start.Index]); + int cnt = result.statesData.Count; + result.AddStates(automaton2.statesData); + result.Start.AddEpsilonTransition(Weight.FromLogValue(logWeight2), result.States[cnt + automaton2.Start.Index]); } } @@ -1556,8 +1533,8 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata /// public void SetToZero() { - this.states.Clear(); - this.startState = this.AddState(); + this.statesData.Clear(); + this.startStateIndex = this.AddState().Index; this.isEpsilonFree = true; } @@ -1611,7 +1588,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata this.SetToZero(); if (!double.IsNegativeInfinity(logValue)) { - this.Start.EndWeight = Weight.FromLogValue(logValue); + this.Start.SetEndWeight(Weight.FromLogValue(logValue)); this.Start.AddTransition(allowedElements, Weight.FromLogValue(-allowedElements.GetLogAverageOf(allowedElements)), this.Start); } } @@ -1626,7 +1603,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata if (!ReferenceEquals(this, automaton)) { - this.SetStates(automaton.states, automaton.Start.Index); + this.SetStates(automaton.statesData, automaton.Start.Index); this.isEpsilonFree = automaton.isEpsilonFree; this.LogValueOverride = automaton.LogValueOverride; this.PruneTransitionsWithLogWeightLessThan = automaton.PruneTransitionsWithLogWeightLessThan; @@ -1657,15 +1634,15 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata // Add states this.SetToZero(); - this.AddStates(sourceAutomaton.states.Count - 1); + this.AddStates(sourceAutomaton.statesData.Count - 1); // Copy state parameters and transitions - for (int stateIndex = 0; stateIndex < sourceAutomaton.states.Count; stateIndex++) + for (int stateIndex = 0; stateIndex < sourceAutomaton.statesData.Count; stateIndex++) { - State thisState = this.states[stateIndex]; - var otherState = sourceAutomaton.states[stateIndex]; + var thisState = this.States[stateIndex]; + var otherState = sourceAutomaton.States[stateIndex]; - thisState.EndWeight = otherState.EndWeight; + thisState.SetEndWeight(otherState.EndWeight); if (otherState == sourceAutomaton.Start) { this.Start = thisState; @@ -1675,10 +1652,10 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata { var otherTransition = otherState.GetTransition(transitionIndex); var transformedTransition = transitionTransform(otherTransition.ElementDistribution, otherTransition.Weight, otherTransition.Group); - this.states[stateIndex].AddTransition( + this.States[stateIndex].AddTransition( transformedTransition.Item1, transformedTransition.Item2, - this.states[otherTransition.DestinationStateIndex], + this.States[otherTransition.DestinationStateIndex], otherTransition.Group); } } @@ -1728,7 +1705,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata var point = new List(); int? pointLength = null; - var stateDepth = new ArrayDictionary(this.states.Count); + var stateDepth = new ArrayDictionary(this.statesData.Count); bool isPoint = this.TryComputePointDfs(this.Start, 0, stateDepth, endNodeReachability, point, ref pointLength); return isPoint && pointLength.HasValue ? SequenceManipulator.ToSequence(point) : null; } @@ -1848,7 +1825,9 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata /// The created state collection copy. public State[] GetStates() { - return this.states.Select(s => new State(s.Index, s.GetTransitions(), s.EndWeight)).ToArray(); + // FIXME: discuss what it is supposed to do + // if needed - implement real full automaton deep-copy + return this.States.ToArray(); } /// @@ -1858,19 +1837,16 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata /// Indices of the added states are guaranteed to be increasing consecutive. public State AddState() { - if (this.states.Count >= maxStateCount) + if (this.statesData.Count >= maxStateCount) { throw new AutomatonTooLargeException(MaxStateCount); } - var state = new State - { - Owner = (TThis)this, - Index = this.states.Count - }; - this.states.Add(state); + var index = this.statesData.Count; + var stateImpl = new StateData(); + this.statesData.Add(stateImpl); - return state; + return new State(this, index, stateImpl); } /// @@ -1897,7 +1873,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata if (this.isEpsilonFree == null) { this.isEpsilonFree = true; - foreach (var state in this.states) + foreach (var state in this.statesData) { for (int i = 0; i < state.TransitionCount; i++) { @@ -2020,7 +1996,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata { TThis automaton = automata[automatonIndex]; - for (int stateIndex = 0; stateIndex < automaton.states.Count; ++stateIndex) + for (int stateIndex = 0; stateIndex < automaton.statesData.Count; ++stateIndex) { State state = automaton.States[stateIndex]; Weight transitionWeightSum = Weight.Zero; @@ -2049,7 +2025,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata Weight.FromLogValue(-maxLogTransitionWeightSum), Weight.FromValue(0.99)); theConverger.Start.AddSelfTransition(uniformDist, transitionWeight); - theConverger.Start.EndWeight = Weight.One; + theConverger.Start.SetEndWeight(Weight.One); return theConverger; } @@ -2125,17 +2101,17 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata automaton = automaton.Clone(); } - int stateCount = this.states.Count; - List endStates = this.states.Where(nd => nd.CanEnd).ToList(); + int stateCount = this.statesData.Count; + var endStates = this.States.Where(nd => nd.CanEnd).ToList(); - this.AddStates(automaton.states); - State secondStartState = this.states[stateCount + automaton.Start.Index]; + this.AddStates(automaton.statesData); + var secondStartState = this.States[stateCount + automaton.Start.Index]; for (int i = 0; i < endStates.Count; i++) { - State state = endStates[i]; + var state = endStates[i]; state.AddEpsilonTransition(state.EndWeight, secondStartState); - state.EndWeight = Weight.Zero; + state.SetEndWeight(Weight.Zero); } } @@ -2147,10 +2123,10 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata { //// First, build a reversed graph - int[] edgePlacementIndices = new int[this.states.Count + 1]; - for (int i = 0; i < this.states.Count; ++i) + int[] edgePlacementIndices = new int[this.statesData.Count + 1]; + for (int i = 0; i < this.statesData.Count; ++i) { - State state = this.states[i]; + var state = this.statesData[i]; for (int j = 0; j < state.TransitionCount; ++j) { var transition = state.GetTransition(j); @@ -2170,11 +2146,11 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata } int[] edgeArrayStarts = (int[])edgePlacementIndices.Clone(); - int totalEdgeCount = edgePlacementIndices[this.states.Count]; + int totalEdgeCount = edgePlacementIndices[this.statesData.Count]; int[] edgeDestinationIndices = new int[totalEdgeCount]; - for (int i = 0; i < this.states.Count; ++i) + for (int i = 0; i < this.statesData.Count; ++i) { - State state = this.states[i]; + var state = this.statesData[i]; for (int j = 0; j < state.TransitionCount; ++j) { var transition = state.GetTransition(j); @@ -2191,10 +2167,10 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata //// Now run a depth-first search to label all reachable nodes - bool[] visitedNodes = new bool[this.states.Count]; - for (int i = 0; i < this.states.Count; ++i) + bool[] visitedNodes = new bool[this.statesData.Count]; + for (int i = 0; i < this.statesData.Count; ++i) { - if (!visitedNodes[i] && this.states[i].CanEnd) + if (!visitedNodes[i] && this.statesData[i].CanEnd) { LabelReachableNodesDfs(i, visitedNodes, edgeDestinationIndices, edgeArrayStarts); } @@ -2211,10 +2187,10 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata { //// First, build a reversed graph - int[] edgePlacementIndices = new int[this.states.Count + 1]; - for (int i = 0; i < this.states.Count; ++i) + int[] edgePlacementIndices = new int[this.statesData.Count + 1]; + for (int i = 0; i < this.statesData.Count; ++i) { - State state = this.states[i]; + var state = this.statesData[i]; for (int j = 0; j < state.TransitionCount; ++j) { var transition = state.GetTransition(j); @@ -2234,11 +2210,11 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata } int[] edgeArrayStarts = (int[])edgePlacementIndices.Clone(); - int totalEdgeCount = edgePlacementIndices[this.states.Count]; + int totalEdgeCount = edgePlacementIndices[this.statesData.Count]; int[] edgeDestinationIndices = new int[totalEdgeCount]; - for (int i = 0; i < this.states.Count; ++i) + for (int i = 0; i < this.statesData.Count; ++i) { - State state = this.states[i]; + var state = this.statesData[i]; for (int j = 0; j < state.TransitionCount; ++j) { var transition = state.GetTransition(j); @@ -2254,8 +2230,8 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata } //// Now run a depth-first search to label all reachable nodes - bool[] visitedNodes = new bool[this.states.Count]; - LabelReachableNodesDfs(this.startState.Index, visitedNodes, edgeDestinationIndices, edgeArrayStarts); + bool[] visitedNodes = new bool[this.statesData.Count]; + LabelReachableNodesDfs(this.Start.Index, visitedNodes, edgeDestinationIndices, edgeArrayStarts); return visitedNodes; } @@ -2292,9 +2268,9 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata /// A condensation of the automaton. private void PushWeights(Condensation condensation) { - for (int i = 0; i < this.states.Count; ++i) + for (int i = 0; i < this.statesData.Count; ++i) { - State state = this.states[i]; + var state = this.States[i]; Weight weightToEnd = condensation.GetWeightToEnd(state); if (weightToEnd.IsZero) { @@ -2307,12 +2283,12 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata Transition transition = state.GetTransition(j); transition.Weight = Weight.Product( transition.Weight, - condensation.GetWeightToEnd(this.states[transition.DestinationStateIndex]), + condensation.GetWeightToEnd(this.States[transition.DestinationStateIndex]), weightToEndInv); state.SetTransition(j, transition); } - state.EndWeight = Weight.Product(state.EndWeight, weightToEndInv); + state.SetEndWeight(Weight.Product(state.EndWeight, weightToEndInv)); } } @@ -2370,7 +2346,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata { Transition transition = currentState.GetTransition(i); - State destState = this.states[transition.DestinationStateIndex]; + State destState = this.States[transition.DestinationStateIndex]; if (!isEndNodeReachable[destState.Index]) { continue; // Only walk through the accepting part of the automaton @@ -2448,7 +2424,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata for (int transition1Index = 0; transition1Index < state1.TransitionCount; transition1Index++) { Transition transition1 = state1.GetTransition(transition1Index); - State destState1 = state1.Owner.states[transition1.DestinationStateIndex]; + State destState1 = state1.Owner.States[transition1.DestinationStateIndex]; if (transition1.IsEpsilon) { @@ -2463,7 +2439,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata { Transition transition2 = state2.GetTransition(transition2Index); Debug.Assert(!transition2.IsEpsilon, "The second argument of the product operation must be epsilon-free."); - State destState2 = state2.Owner.states[transition2.DestinationStateIndex]; + State destState2 = state2.Owner.States[transition2.DestinationStateIndex]; TElementDistribution product; double productLogNormalizer = Distribution.GetLogAverageOf( @@ -2483,7 +2459,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata } } - productState.EndWeight = Weight.Product(state1.EndWeight, state2.EndWeight); + productState.SetEndWeight(Weight.Product(state1.EndWeight, state2.EndWeight)); return productState; } @@ -2510,7 +2486,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata oldToNewState.Add(state.Index, resultState); EpsilonClosure closure = state.GetEpsilonClosure(); - resultState.EndWeight = closure.EndWeight; + resultState.SetEndWeight(closure.EndWeight); for (int stateIndex = 0; stateIndex < closure.Size; ++stateIndex) { State closureState = closure.GetStateByIndex(stateIndex); @@ -2520,7 +2496,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata Transition transition = closureState.GetTransition(transitionIndex); if (!transition.IsEpsilon) { - State destState = state.Owner.states[transition.DestinationStateIndex]; + State destState = state.Owner.States[transition.DestinationStateIndex]; State closureDestState = this.BuildEpsilonClosure(destState, oldToNewState); resultState.AddTransition( transition.ElementDistribution, Weight.Product(transition.Weight, closureStateWeight), closureDestState, transition.Group); @@ -2540,9 +2516,8 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata Debug.Assert(automaton != null, "A valid automaton must be provided."); // Swap contents - Util.Swap(ref this.states, ref automaton.states); - Util.Swap(ref this.startState, ref automaton.startState); - Util.Swap(ref this.statesReadOnly, ref automaton.statesReadOnly); + Util.Swap(ref this.statesData, ref automaton.statesData); + Util.Swap(ref this.startStateIndex, ref automaton.startStateIndex); Util.Swap(ref this.isEpsilonFree, ref automaton.isEpsilonFree); var dummy = this.LogValueOverride; this.LogValueOverride = automaton.LogValueOverride; @@ -2550,17 +2525,6 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata dummy = this.PruneTransitionsWithLogWeightLessThan; this.PruneTransitionsWithLogWeightLessThan = automaton.PruneTransitionsWithLogWeightLessThan; automaton.PruneTransitionsWithLogWeightLessThan = dummy; - - // Update backward references - for (int i = 0; i < this.states.Count; ++i) - { - this.states[i].Owner = (TThis)this; - } - - for (int i = 0; i < automaton.states.Count; ++i) - { - automaton.states[i].Owner = automaton; - } } /// @@ -2568,11 +2532,11 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata /// /// The states to replace the existing states with. /// The index of the new start state. - private void SetStates(IEnumerable newStates, int newStartStateIndex) + private void SetStates(IEnumerable newStates, int newStartStateIndex) { - this.states.Clear(); + this.statesData.Clear(); this.AddStates(newStates); - this.Start = this.states[newStartStateIndex]; + this.Start = this.States[newStartStateIndex]; } /// @@ -2581,34 +2545,34 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata /// /// The states to add. /// The group for the transitions of the states being added. - private void AddStates(IEnumerable statesToAdd, int group = 0) + private void AddStates(IEnumerable statesToAdd, int group = 0) { Debug.Assert(statesToAdd != null, "A valid state collection must be provided."); - int startIndex = this.states.Count; - var statesToAddList = statesToAdd as IList ?? statesToAdd.ToList(); + int startIndex = this.statesData.Count; + var statesToAddList = statesToAdd as IList ?? statesToAdd.ToList(); // Add states for (int i = 0; i < statesToAddList.Count; ++i) { State newState = this.AddState(); - newState.EndWeight = statesToAddList[i].EndWeight; + newState.SetEndWeight(statesToAddList[i].EndWeight); - Debug.Assert(statesToAddList[i].Index == i && newState.Index == i + startIndex, "State indices must always be consequent."); + Debug.Assert(newState.Index == i + startIndex, "State indices must always be consequent."); } // Add transitions for (int i = 0; i < statesToAddList.Count; ++i) { - State stateToAdd = statesToAddList[i]; + var stateToAdd = statesToAddList[i]; for (int transitionIndex = 0; transitionIndex < stateToAdd.TransitionCount; transitionIndex++) { Transition transitionToAdd = stateToAdd.GetTransition(transitionIndex); Debug.Assert(transitionToAdd.DestinationStateIndex < statesToAddList.Count, "Self-inconsistent collection of states provided."); - this.states[i + startIndex].AddTransition( + this.States[i + startIndex].AddTransition( transitionToAdd.ElementDistribution, transitionToAdd.Weight, - this.states[transitionToAdd.DestinationStateIndex + startIndex], + this.States[transitionToAdd.DestinationStateIndex + startIndex], group != 0 ? group : transitionToAdd.Group); } } @@ -2630,19 +2594,18 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata { //// TODO: see remarks - Debug.Assert(index >= 0 && index < this.states.Count, "An invalid state index provided."); + Debug.Assert(index >= 0 && index < this.statesData.Count, "An invalid state index provided."); Debug.Assert(index != this.Start.Index, "Cannot remove the start state."); Debug.Assert( - !replaceIndex.HasValue || (replaceIndex.Value >= 0 && replaceIndex.Value < this.states.Count), + !replaceIndex.HasValue || (replaceIndex.Value >= 0 && replaceIndex.Value < this.statesData.Count), "An invalid replace index provided."); Debug.Assert(!replaceIndex.HasValue || replaceIndex.Value != index, "Replace index must point to a different state."); - this.states.RemoveAt(index); - int stateCount = this.states.Count; + this.statesData.RemoveAt(index); + int stateCount = this.statesData.Count; for (int i = 0; i < stateCount; i++) { - State state = this.states[i]; - state.Index = i; + StateData state = this.statesData[i]; for (int j = state.TransitionCount - 1; j >= 0; j--) { Transition transition = state.GetTransition(j); @@ -2931,20 +2894,15 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata /// protected Automaton(SerializationInfo info, StreamingContext context) { - this.states = (List)info.GetValue(nameof(this.states), typeof(List)); - foreach (var state in this.states) - { - state.Owner = (TThis)this; - } - this.statesReadOnly = new ReadOnlyList(this.states); - this.startState = (State)info.GetValue(nameof(this.startState), typeof(State)); + this.statesData = (List)info.GetValue(nameof(this.statesData), typeof(List)); + this.startStateIndex = (int)info.GetValue(nameof(this.startStateIndex), typeof(int)); this.isEpsilonFree = (bool?)info.GetValue(nameof(this.isEpsilonFree), typeof(bool?)); } void ISerializable.GetObjectData(SerializationInfo info, StreamingContext context) { - info.AddValue(nameof(this.states), this.states); - info.AddValue(nameof(this.startState), this.startState); + info.AddValue(nameof(this.statesData), this.statesData); + info.AddValue(nameof(this.startStateIndex), this.startStateIndex); info.AddValue(nameof(this.isEpsilonFree), this.isEpsilonFree); } @@ -2959,7 +2917,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata propertyMask[1 << idx++] = this.isEpsilonFree.HasValue && this.isEpsilonFree.Value; propertyMask[1 << idx++] = this.LogValueOverride.HasValue; propertyMask[1 << idx++] = this.PruneTransitionsWithLogWeightLessThan.HasValue; - propertyMask[1 << idx++] = this.startState != null; + propertyMask[1 << idx++] = !this.Start.IsNull; writeInt32(propertyMask.Data); @@ -2973,23 +2931,14 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata writeDouble(this.PruneTransitionsWithLogWeightLessThan.Value); } - if (startState != null) + if (!this.Start.IsNull) { - if (startState.Owner != this) - { - throw new InvalidOperationException("the state owner is not set to the current automaton"); - } - startState.Write(writeInt32, writeDouble, writeElementDistribution); + this.Start.Write(writeInt32, writeDouble, writeElementDistribution); } - writeInt32(states.Count); - foreach (var state in states) + writeInt32(this.statesData.Count); + foreach (var state in this.States) { - if (state.Owner != this) - { - throw new InvalidOperationException("the state owner is not set to the current automaton"); - } - state.Write(writeInt32, writeDouble, writeElementDistribution); } } @@ -3020,20 +2969,24 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata res.PruneTransitionsWithLogWeightLessThan = readDouble(); } - if (hasStartState) - { - res.startState = State.Read(readInt32, readDouble, readElementDistribution); - res.startState.Owner = res; - } + var startState = + hasStartState + ? State.Read(readInt32, readDouble, readElementDistribution) + : default(State); var numStates = readInt32(); - res.states.Clear(); + res.statesData.Clear(); + res.AddStates(numStates); for (var i = 0; i < numStates; i++) { - var state = State.Read(readInt32, readDouble, readElementDistribution); - state.Owner = res; - res.states.Add(state); + res.statesData[i] = State.Read(readInt32, readDouble, readElementDistribution).Data; } + + if (hasStartState) + { + res.startStateIndex = startState.Index; + } + return res; } #endregion diff --git a/src/Runtime/Distributions/Automata/TransducerBase.cs b/src/Runtime/Distributions/Automata/TransducerBase.cs index bcb9039d..3140dfe9 100644 --- a/src/Runtime/Distributions/Automata/TransducerBase.cs +++ b/src/Runtime/Distributions/Automata/TransducerBase.cs @@ -468,7 +468,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata } } - destState.EndWeight = srcSequenceIndex == srcSequenceLength ? mappingState.EndWeight : Weight.Zero; + destState.SetEndWeight(srcSequenceIndex == srcSequenceLength ? mappingState.EndWeight : Weight.Zero); return destState; } @@ -541,7 +541,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata } } - destState.EndWeight = Weight.Product(mappingState.EndWeight, srcState.EndWeight); + destState.SetEndWeight(Weight.Product(mappingState.EndWeight, srcState.EndWeight)); return destState; } diff --git a/src/Runtime/Distributions/SequenceDistribution.cs b/src/Runtime/Distributions/SequenceDistribution.cs index bfe92a3c..886538aa 100644 --- a/src/Runtime/Distributions/SequenceDistribution.cs +++ b/src/Runtime/Distributions/SequenceDistribution.cs @@ -226,7 +226,7 @@ namespace Microsoft.ML.Probabilistic.Distributions { var func = Automaton.Zero(); var end = func.Start.AddTransition(elementDistribution, Weight.One); - end.EndWeight = Weight.One; + end.SetEndWeight(Weight.One); return FromWorkspace(func); } @@ -431,7 +431,7 @@ namespace Microsoft.ML.Probabilistic.Distributions for (int i = 0; i <= iterationBound; i++) { bool isLengthAllowed = i >= minTimes; - state.EndWeight = isLengthAllowed ? Weight.One : Weight.Zero; + state.SetEndWeight(isLengthAllowed ? Weight.One : Weight.Zero); if (i < iterationBound) { state = state.AddTransition(allowedElements, weight); // todo: clone set? diff --git a/test/Tests/Strings/AutomatonTests.cs b/test/Tests/Strings/AutomatonTests.cs index e0d5a44a..ba5fd4b5 100644 --- a/test/Tests/Strings/AutomatonTests.cs +++ b/test/Tests/Strings/AutomatonTests.cs @@ -31,7 +31,7 @@ namespace Microsoft.ML.Probabilistic.Tests public void Clone() { StringAutomaton automaton = StringAutomaton.Zero(); - automaton.Start.AddTransition('a', Weight.One).EndWeight = Weight.One; + automaton.Start.AddTransition('a', Weight.One).SetEndWeight(Weight.One); StringAutomaton clone = automaton.Clone(); Assert.Equal(automaton, clone); @@ -390,8 +390,8 @@ namespace Microsoft.ML.Probabilistic.Tests public void ProductNoDeadBranches() { StringAutomaton automaton = StringAutomaton.Zero(); - automaton.Start.AddTransition('a', Weight.One).AddTransition('b', Weight.One).EndWeight = Weight.One; - automaton.Start.AddTransition('a', Weight.One).AddTransition('c', Weight.One).EndWeight = Weight.One; + automaton.Start.AddTransition('a', Weight.One).AddTransition('b', Weight.One).SetEndWeight(Weight.One); + automaton.Start.AddTransition('a', Weight.One).AddTransition('c', Weight.One).SetEndWeight(Weight.One); StringAutomaton automatonSqr = automaton.Product(automaton); Assert.Equal(4, automatonSqr.States.Count); } @@ -462,8 +462,8 @@ namespace Microsoft.ML.Probabilistic.Tests { StringAutomaton automaton = StringAutomaton.Zero(); var otherState = automaton.Start.AddSelfTransition('a', Weight.FromValue(0.5)).AddTransition('b', Weight.FromValue(0.7)); - automaton.Start.EndWeight = Weight.FromValue(0.3); - otherState.EndWeight = Weight.FromValue(0.8); + automaton.Start.SetEndWeight(Weight.FromValue(0.3)); + otherState.SetEndWeight(Weight.FromValue(0.8)); StringAutomaton reverse = automaton.Reverse(); StringInferenceTestUtilities.TestValue(reverse, 0.7 * 0.8, "b"); @@ -616,7 +616,7 @@ namespace Microsoft.ML.Probabilistic.Tests { StringAutomaton automaton = StringAutomaton.Zero(); automaton.Start.AddTransitionsForSequence("abc"); - automaton.Start.AddTransitionsForSequence("def").EndWeight = Weight.FromValue(4.0); + automaton.Start.AddTransitionsForSequence("def").SetEndWeight(Weight.FromValue(4.0)); double logNormalizer; Assert.Equal(Math.Log(4.0), automaton.GetLogNormalizer(), 1e-8); Assert.True(automaton.TryNormalizeValues(out logNormalizer)); @@ -637,7 +637,7 @@ namespace Microsoft.ML.Probabilistic.Tests StringAutomaton automaton = StringAutomaton.Zero(); automaton.Start.AddTransition('a', Weight.FromValue(TransitionProbability), automaton.Start); - automaton.Start.EndWeight = Weight.FromValue(EndWeight); + automaton.Start.SetEndWeight(Weight.FromValue(EndWeight)); double logNormalizer = automaton.GetLogNormalizer(); Assert.Equal(Math.Log(EndWeight / (1 - TransitionProbability)), logNormalizer, 1e-8); Assert.Equal(logNormalizer, automaton.NormalizeValues()); @@ -672,7 +672,7 @@ namespace Microsoft.ML.Probabilistic.Tests StringAutomaton automaton = StringAutomaton.Zero(); automaton.Start.AddTransition('a', Weight.FromValue(1.01), automaton.Start); - automaton.Start.EndWeight = Weight.One; + automaton.Start.SetEndWeight(Weight.One); TestNonNormalizable(automaton, false); } @@ -722,7 +722,7 @@ namespace Microsoft.ML.Probabilistic.Tests state = nextState; } - state.EndWeight = Weight.One; + state.SetEndWeight(Weight.One); var closure = automaton.Start.GetEpsilonClosure(); @@ -772,7 +772,7 @@ namespace Microsoft.ML.Probabilistic.Tests state = state.AddTransition('a', Weight.One); } - state.EndWeight = Weight.One; + state.SetEndWeight(Weight.One); string point = new string('a', stateCount - 1); @@ -803,10 +803,10 @@ namespace Microsoft.ML.Probabilistic.Tests Assert.True(automaton2.IsZero()); // Null states collection - Assert.Throws(() => StringAutomaton.FromStates(null, null)); + Assert.Throws(() => StringAutomaton.FromStates(null, default(StringAutomaton.State))); // Null start state - Assert.Throws(() => StringAutomaton.FromStates(new[] { theOnlyState }, null)); + Assert.Throws(() => StringAutomaton.FromStates(new[] { theOnlyState }, default(StringAutomaton.State))); // Duplicate state indices Assert.Throws( @@ -874,7 +874,7 @@ namespace Microsoft.ML.Probabilistic.Tests public void ConvertToString1() { StringAutomaton automaton = StringAutomaton.Zero(); - automaton.Start.AddTransition('a', Weight.One).AddSelfTransition('b', Weight.One).AddTransition('c', Weight.One).EndWeight = Weight.One; + automaton.Start.AddTransition('a', Weight.One).AddSelfTransition('b', Weight.One).AddTransition('c', Weight.One).SetEndWeight(Weight.One); Assert.Equal("ab*c", automaton.ToString(AutomatonFormats.Friendly)); Assert.Equal("ab*c", automaton.ToString(AutomatonFormats.Regexp)); } @@ -890,7 +890,7 @@ namespace Microsoft.ML.Probabilistic.Tests StringAutomaton automaton = StringAutomaton.Zero(); var middleState = automaton.Start.AddTransition('b', Weight.One).AddTransition('c', Weight.One); automaton.Start.AddTransition('a', Weight.One, middleState); - middleState.AddTransition('d', Weight.One).EndWeight = Weight.One; + middleState.AddTransition('d', Weight.One).SetEndWeight(Weight.One); Assert.Equal("(bc|a)d", automaton.ToString(AutomatonFormats.Friendly)); Assert.Equal("(bc|a)d", automaton.ToString(AutomatonFormats.Regexp)); } @@ -903,7 +903,7 @@ namespace Microsoft.ML.Probabilistic.Tests public void ConvertToString3() { StringAutomaton automaton = StringAutomaton.Zero(); - automaton.Start.AddTransitionsForSequence("hello").EndWeight = Weight.One; + automaton.Start.AddTransitionsForSequence("hello").SetEndWeight(Weight.One); Assert.Equal("hello", automaton.ToString(AutomatonFormats.Friendly)); Assert.Equal("hello", automaton.ToString(AutomatonFormats.Regexp)); } @@ -916,7 +916,7 @@ namespace Microsoft.ML.Probabilistic.Tests public void ConvertToString4() { StringAutomaton automaton = StringAutomaton.Zero(); - automaton.Start.AddTransitionsForSequence("hello").EndWeight = Weight.Zero; + automaton.Start.AddTransitionsForSequence("hello").SetEndWeight(Weight.Zero); Assert.Equal("Ø", automaton.ToString(AutomatonFormats.Friendly)); Assert.Equal("Ø", automaton.ToString(AutomatonFormats.Regexp)); } @@ -929,9 +929,9 @@ namespace Microsoft.ML.Probabilistic.Tests public void ConvertToString5() { StringAutomaton automaton = StringAutomaton.Zero(); - automaton.Start.AddTransitionsForSequence("hello").EndWeight = Weight.One; - automaton.Start.AddEpsilonTransition(Weight.One).AddTransitionsForSequence("hi").EndWeight = Weight.One; - automaton.Start.AddEpsilonTransition(Weight.One).AddTransitionsForSequence("hey").EndWeight = Weight.One; + automaton.Start.AddTransitionsForSequence("hello").SetEndWeight(Weight.One); + automaton.Start.AddEpsilonTransition(Weight.One).AddTransitionsForSequence("hi").SetEndWeight(Weight.One); + automaton.Start.AddEpsilonTransition(Weight.One).AddTransitionsForSequence("hey").SetEndWeight(Weight.One); Assert.Equal("hey|hi|hello", automaton.ToString(AutomatonFormats.Friendly)); Assert.Equal("hey|hi|hello", automaton.ToString(AutomatonFormats.Regexp)); } @@ -1048,7 +1048,7 @@ namespace Microsoft.ML.Probabilistic.Tests automaton.States[5].AddTransition('l', Weight.FromValue(1), automaton.States[6]); automaton.States[6].AddTransition('m', Weight.FromValue(1), automaton.States[3]); automaton.States[6].AddTransition('n', Weight.FromValue(1), automaton.States[1]); - automaton.States[7].EndWeight = Weight.FromValue(1); + automaton.States[7].SetEndWeight(Weight.FromValue(1)); var distribution = StringDistribution.FromWorkspace(automaton); var regexPattern = distribution.ToRegex(); @@ -1095,7 +1095,7 @@ namespace Microsoft.ML.Probabilistic.Tests automaton.States[0].AddTransition('o', Weight.FromValue(1), automaton.States[6]); automaton.States[1].AddTransition('p', Weight.FromValue(1), automaton.States[7]); automaton.States[6].AddTransition('q', Weight.FromValue(1), automaton.States[7]); - automaton.States[7].EndWeight = Weight.FromValue(1); + automaton.States[7].SetEndWeight(Weight.FromValue(1)); var distribution = StringDistribution.FromWorkspace(automaton); var regexPattern = distribution.ToRegex(); @@ -1176,7 +1176,7 @@ namespace Microsoft.ML.Probabilistic.Tests public void Equality3() { StringAutomaton automaton = StringAutomaton.Zero(); - automaton.Start.EndWeight = Weight.One; + automaton.Start.SetEndWeight(Weight.One); automaton.Start.AddTransition(DiscreteChar.Lower(), Weight.FromLogValue(26 + 1e-3), automaton.Start); AssertEquals(automaton, automaton); @@ -1191,7 +1191,7 @@ namespace Microsoft.ML.Probabilistic.Tests public void Equality4() { StringAutomaton automaton = StringAutomaton.Zero(); - automaton.Start.EndWeight = Weight.One; + automaton.Start.SetEndWeight(Weight.One); automaton.Start.AddTransition('a', Weight.FromLogValue(1.0 - 1e-3), automaton.Start); AssertEquals(automaton, automaton); @@ -1222,7 +1222,7 @@ namespace Microsoft.ML.Probabilistic.Tests { StringAutomaton func1 = StringAutomaton.Constant(1.0, DiscreteChar.OneOf('a', 'b')); StringAutomaton func2 = StringAutomaton.Zero(); - func2.Start.AddSelfTransition('a', Weight.One).AddSelfTransition('b', Weight.One).EndWeight = Weight.One; + func2.Start.AddSelfTransition('a', Weight.One).AddSelfTransition('b', Weight.One).SetEndWeight(Weight.One); AssertEquals(func1, func2); } @@ -1235,13 +1235,13 @@ namespace Microsoft.ML.Probabilistic.Tests public void Equality7() { StringAutomaton func1 = StringAutomaton.Zero(); - func1.Start.AddSelfTransition(DiscreteChar.PointMass('a'), Weight.One).EndWeight = Weight.One; + func1.Start.AddSelfTransition(DiscreteChar.PointMass('a'), Weight.One).SetEndWeight(Weight.One); StringAutomaton func2 = StringAutomaton.Zero(); func2.Start .AddEpsilonTransition(Weight.One) .AddTransition(DiscreteChar.PointMass('a'), Weight.One, func2.Start) - .EndWeight = Weight.One; + .SetEndWeight(Weight.One); AssertEquals(func1, func2); } @@ -1258,10 +1258,10 @@ namespace Microsoft.ML.Probabilistic.Tests public void Simplify1() { StringAutomaton automaton = StringAutomaton.Zero(); - automaton.Start.AddTransition('a', Weight.One).AddTransition('b', Weight.One).EndWeight = - Weight.FromValue(2.0); - automaton.Start.AddTransition('a', Weight.One).AddSelfTransition('d', Weight.One).AddTransition('c', Weight.One).EndWeight = - Weight.FromValue(3.0); + automaton.Start.AddTransition('a', Weight.One).AddTransition('b', Weight.One).SetEndWeight( + Weight.FromValue(2.0)); + automaton.Start.AddTransition('a', Weight.One).AddSelfTransition('d', Weight.One).AddTransition('c', Weight.One).SetEndWeight( + Weight.FromValue(3.0)); for (int i = 0; i < 3; ++i) { @@ -1281,10 +1281,10 @@ namespace Microsoft.ML.Probabilistic.Tests public void Simplify2() { StringAutomaton automaton = StringAutomaton.Zero(); - automaton.Start.AddTransition('a', Weight.One).AddSelfTransition('d', Weight.One).AddTransition('c', Weight.One).EndWeight = - Weight.FromValue(3.0); - automaton.Start.AddTransition('a', Weight.One).AddTransition('b', Weight.One).EndWeight = - Weight.FromValue(2.0); + automaton.Start.AddTransition('a', Weight.One).AddSelfTransition('d', Weight.One).AddTransition('c', Weight.One).SetEndWeight( + Weight.FromValue(3.0)); + automaton.Start.AddTransition('a', Weight.One).AddTransition('b', Weight.One).SetEndWeight( + Weight.FromValue(2.0)); for (int i = 0; i < 3; ++i) { @@ -1304,10 +1304,10 @@ namespace Microsoft.ML.Probabilistic.Tests public void Simplify3() { StringAutomaton automaton = StringAutomaton.Zero(); - automaton.Start.AddTransition('a', Weight.One).AddTransition('d', Weight.One).AddTransition('c', Weight.One).EndWeight = - Weight.FromValue(2.0); - automaton.Start.AddTransition('a', Weight.One).AddSelfTransition('d', Weight.One).AddTransition('c', Weight.One).EndWeight = - Weight.FromValue(3.0); + automaton.Start.AddTransition('a', Weight.One).AddTransition('d', Weight.One).AddTransition('c', Weight.One).SetEndWeight( + Weight.FromValue(2.0)); + automaton.Start.AddTransition('a', Weight.One).AddSelfTransition('d', Weight.One).AddTransition('c', Weight.One).SetEndWeight( + Weight.FromValue(3.0)); for (int i = 0; i < 3; ++i) { @@ -1327,11 +1327,11 @@ namespace Microsoft.ML.Probabilistic.Tests { StringAutomaton automaton = StringAutomaton.Zero(); automaton.Start.AddEpsilonTransition(Weight.One).AddSelfTransition('a', Weight.One) - .AddEpsilonTransition(Weight.One).AddSelfTransition('b', Weight.One).EndWeight = Weight.FromValue(2.0); + .AddEpsilonTransition(Weight.One).AddSelfTransition('b', Weight.One).SetEndWeight(Weight.FromValue(2.0)); automaton.Start.AddEpsilonTransition(Weight.One).AddSelfTransition('a', Weight.One) - .AddEpsilonTransition(Weight.One).AddSelfTransition('c', Weight.One).EndWeight = Weight.FromValue(3.0); + .AddEpsilonTransition(Weight.One).AddSelfTransition('c', Weight.One).SetEndWeight(Weight.FromValue(3.0)); automaton.Start.AddSelfTransition('x', Weight.One); - automaton.Start.EndWeight = Weight.One; + automaton.Start.SetEndWeight(Weight.One); for (int i = 0; i < 3; ++i) { @@ -1361,12 +1361,12 @@ namespace Microsoft.ML.Probabilistic.Tests state = nextState; } - state.EndWeight = Weight.One; + state.SetEndWeight(Weight.One); const int AdditionalSequenceCount = 5; for (int i = 0; i < AdditionalSequenceCount; ++i) { - automaton.Start.AddTransitionsForSequence(AcceptedSequence).EndWeight = Weight.One; + automaton.Start.AddTransitionsForSequence(AcceptedSequence).SetEndWeight(Weight.One); } for (int i = 0; i < 3; ++i) @@ -1393,8 +1393,8 @@ namespace Microsoft.ML.Probabilistic.Tests automaton.States[2].AddTransition('a', Weight.FromValue(6.0), automaton.States[5]); automaton.States[4].AddTransition('a', Weight.FromValue(5.0), automaton.States[2]); - automaton.States[3].EndWeight = Weight.FromValue(2.0); - automaton.States[5].EndWeight = Weight.FromValue(3.0); + automaton.States[3].SetEndWeight(Weight.FromValue(2.0)); + automaton.States[5].SetEndWeight(Weight.FromValue(3.0)); for (int i = 0; i < 3; ++i) { @@ -1417,13 +1417,13 @@ namespace Microsoft.ML.Probabilistic.Tests StringAutomaton automaton = StringAutomaton.Zero(); var branch1 = automaton.Start.AddEpsilonTransition(Weight.FromValue(0.5)).AddTransition('a', Weight.FromValue(1.0 / 3.0)).AddTransition('B', Weight.FromValue(1.0 / 4.0)); - branch1.EndWeight = Weight.FromValue(3.0); - branch1.AddTransition('X', Weight.FromValue(1.0 / 6.0)).EndWeight = Weight.FromValue(5.0); - branch1.AddEpsilonTransition(Weight.FromValue(1.0 / 8.0)).EndWeight = Weight.FromValue(7.0); + branch1.SetEndWeight(Weight.FromValue(3.0)); + branch1.AddTransition('X', Weight.FromValue(1.0 / 6.0)).SetEndWeight(Weight.FromValue(5.0)); + branch1.AddEpsilonTransition(Weight.FromValue(1.0 / 8.0)).SetEndWeight(Weight.FromValue(7.0)); var branch2 = automaton.Start.AddTransition(lowerEnglish, Weight.FromValue(2.0)); - branch2.EndWeight = Weight.FromValue(4.0); + branch2.SetEndWeight(Weight.FromValue(4.0)); branch2.AddTransition(upperEnglish, Weight.FromValue(3.0), branch2); - branch2.AddTransition('X', Weight.FromValue(4.0)).EndWeight = Weight.FromValue(5.0); + branch2.AddTransition('X', Weight.FromValue(4.0)).SetEndWeight(Weight.FromValue(5.0)); for (int i = 0; i < 3; ++i) { @@ -1464,10 +1464,10 @@ namespace Microsoft.ML.Probabilistic.Tests automaton.States[8].AddTransition('b', Weight.One, automaton.States[8]); automaton.States[8].AddTransition('a', Weight.One, automaton.States[9]); - automaton.States[3].EndWeight = Weight.FromValue(0.1); - automaton.States[6].EndWeight = Weight.FromValue(0.2); - automaton.States[9].EndWeight = Weight.FromValue(0.3); - automaton.States[10].EndWeight = Weight.FromValue(0.4); + automaton.States[3].SetEndWeight(Weight.FromValue(0.1)); + automaton.States[6].SetEndWeight(Weight.FromValue(0.2)); + automaton.States[9].SetEndWeight(Weight.FromValue(0.3)); + automaton.States[10].SetEndWeight(Weight.FromValue(0.4)); for (int i = 0; i < 3; ++i) { @@ -1518,9 +1518,9 @@ namespace Microsoft.ML.Probabilistic.Tests public void Determinize1() { StringAutomaton automaton = StringAutomaton.Zero(); - automaton.Start.AddTransition('a', Weight.FromValue(2)).AddTransition('b', Weight.FromValue(3)).EndWeight = Weight.FromValue(4); - automaton.Start.AddTransition('a', Weight.FromValue(5)).AddTransition('c', Weight.FromValue(6)).EndWeight = Weight.FromValue(7); - automaton.Start.EndWeight = Weight.FromValue(17); + automaton.Start.AddTransition('a', Weight.FromValue(2)).AddTransition('b', Weight.FromValue(3)).SetEndWeight(Weight.FromValue(4)); + automaton.Start.AddTransition('a', Weight.FromValue(5)).AddTransition('c', Weight.FromValue(6)).SetEndWeight(Weight.FromValue(7)); + automaton.Start.SetEndWeight(Weight.FromValue(17)); Assert.False(automaton.IsDeterministic()); @@ -1546,9 +1546,9 @@ namespace Microsoft.ML.Probabilistic.Tests public void Determinize2() { StringAutomaton automaton = StringAutomaton.Zero(); - automaton.Start.AddTransition(DiscreteChar.UniformInRange('a', 'z'), Weight.FromValue(2)).AddTransition('b', Weight.FromValue(3)).EndWeight = Weight.FromValue(4); - automaton.Start.AddTransition(DiscreteChar.UniformInRange('a', 'z'), Weight.FromValue(5)).AddTransition('c', Weight.FromValue(6)).EndWeight = Weight.FromValue(7); - automaton.Start.EndWeight = Weight.FromValue(17); + automaton.Start.AddTransition(DiscreteChar.UniformInRange('a', 'z'), Weight.FromValue(2)).AddTransition('b', Weight.FromValue(3)).SetEndWeight(Weight.FromValue(4)); + automaton.Start.AddTransition(DiscreteChar.UniformInRange('a', 'z'), Weight.FromValue(5)).AddTransition('c', Weight.FromValue(6)).SetEndWeight(Weight.FromValue(7)); + automaton.Start.SetEndWeight(Weight.FromValue(17)); Assert.False(automaton.IsDeterministic()); @@ -1574,10 +1574,10 @@ namespace Microsoft.ML.Probabilistic.Tests public void Determinize3() { StringAutomaton automaton = StringAutomaton.Zero(); - automaton.Start.AddTransition(DiscreteChar.Uniform(), Weight.FromValue(2)).AddTransition('b', Weight.FromValue(3)).EndWeight = Weight.FromValue(4); - automaton.Start.AddTransition(DiscreteChar.UniformInRange('a', 'z'), Weight.FromValue(5)).AddTransition('c', Weight.FromValue(6)).EndWeight = Weight.FromValue(7); - automaton.Start.AddTransition(DiscreteChar.UniformInRange('x', 'z'), Weight.FromValue(8)).AddTransition('d', Weight.FromValue(9)).EndWeight = Weight.FromValue(10); - automaton.Start.EndWeight = Weight.FromValue(17); + automaton.Start.AddTransition(DiscreteChar.Uniform(), Weight.FromValue(2)).AddTransition('b', Weight.FromValue(3)).SetEndWeight(Weight.FromValue(4)); + automaton.Start.AddTransition(DiscreteChar.UniformInRange('a', 'z'), Weight.FromValue(5)).AddTransition('c', Weight.FromValue(6)).SetEndWeight(Weight.FromValue(7)); + automaton.Start.AddTransition(DiscreteChar.UniformInRange('x', 'z'), Weight.FromValue(8)).AddTransition('d', Weight.FromValue(9)).SetEndWeight(Weight.FromValue(10)); + automaton.Start.SetEndWeight(Weight.FromValue(17)); Assert.False(automaton.IsDeterministic()); @@ -1607,8 +1607,8 @@ namespace Microsoft.ML.Probabilistic.Tests public void Determinize4() { StringAutomaton automaton = StringAutomaton.Zero(); - automaton.Start.AddTransition('a', Weight.FromValue(2)).AddSelfTransition('b', Weight.FromValue(0.5)).AddTransition('c', Weight.FromValue(3.0)).EndWeight = Weight.FromValue(4); - automaton.Start.AddTransition('a', Weight.FromValue(5)).AddSelfTransition('b', Weight.FromValue(0.5)).AddTransition('d', Weight.FromValue(6.0)).EndWeight = Weight.FromValue(7); + automaton.Start.AddTransition('a', Weight.FromValue(2)).AddSelfTransition('b', Weight.FromValue(0.5)).AddTransition('c', Weight.FromValue(3.0)).SetEndWeight(Weight.FromValue(4)); + automaton.Start.AddTransition('a', Weight.FromValue(5)).AddSelfTransition('b', Weight.FromValue(0.5)).AddTransition('d', Weight.FromValue(6.0)).SetEndWeight(Weight.FromValue(7)); Assert.False(automaton.IsDeterministic()); @@ -1637,8 +1637,8 @@ namespace Microsoft.ML.Probabilistic.Tests public void Determinize5() { StringAutomaton automaton = StringAutomaton.Zero(); - automaton.Start.AddTransition('a', Weight.FromValue(2)).AddSelfTransition('b', Weight.FromValue(0.5)).AddTransition('c', Weight.FromValue(3.0)).EndWeight = Weight.FromValue(4); - automaton.Start.AddTransition('a', Weight.FromValue(5)).AddSelfTransition('b', Weight.FromValue(0.5)).AddTransition('c', Weight.FromValue(6.0)).EndWeight = Weight.FromValue(7); + automaton.Start.AddTransition('a', Weight.FromValue(2)).AddSelfTransition('b', Weight.FromValue(0.5)).AddTransition('c', Weight.FromValue(3.0)).SetEndWeight(Weight.FromValue(4)); + automaton.Start.AddTransition('a', Weight.FromValue(5)).AddSelfTransition('b', Weight.FromValue(0.5)).AddTransition('c', Weight.FromValue(6.0)).SetEndWeight(Weight.FromValue(7)); Assert.False(automaton.IsDeterministic()); @@ -1664,11 +1664,11 @@ namespace Microsoft.ML.Probabilistic.Tests public void Determinize6() { StringAutomaton automaton = StringAutomaton.Zero(); - automaton.Start.AddTransition(DiscreteChar.UniformInRange('a', 'c'), Weight.FromValue(2)).EndWeight = Weight.FromValue(3.0); - automaton.Start.AddTransition(DiscreteChar.UniformInRange('b', 'c'), Weight.FromValue(4)).EndWeight = Weight.FromValue(5.0); - automaton.Start.AddTransition(DiscreteChar.UniformInRange('b', 'd'), Weight.FromValue(6)).EndWeight = Weight.FromValue(7.0); - automaton.Start.AddTransition(DiscreteChar.UniformInRange('d', 'd'), Weight.FromValue(8)).EndWeight = Weight.FromValue(9.0); - automaton.Start.AddTransition(DiscreteChar.UniformInRange('d', 'e'), Weight.FromValue(10)).EndWeight = Weight.FromValue(11.0); + automaton.Start.AddTransition(DiscreteChar.UniformInRange('a', 'c'), Weight.FromValue(2)).SetEndWeight(Weight.FromValue(3.0)); + automaton.Start.AddTransition(DiscreteChar.UniformInRange('b', 'c'), Weight.FromValue(4)).SetEndWeight(Weight.FromValue(5.0)); + automaton.Start.AddTransition(DiscreteChar.UniformInRange('b', 'd'), Weight.FromValue(6)).SetEndWeight(Weight.FromValue(7.0)); + automaton.Start.AddTransition(DiscreteChar.UniformInRange('d', 'd'), Weight.FromValue(8)).SetEndWeight(Weight.FromValue(9.0)); + automaton.Start.AddTransition(DiscreteChar.UniformInRange('d', 'e'), Weight.FromValue(10)).SetEndWeight(Weight.FromValue(11.0)); Assert.False(automaton.IsDeterministic()); @@ -1706,9 +1706,9 @@ namespace Microsoft.ML.Probabilistic.Tests automaton.States[2].AddTransition('b', Weight.FromValue(6), automaton.States[4]); automaton.States[3].AddTransition('c', Weight.FromValue(7), automaton.States[4]); - automaton.States[2].EndWeight = Weight.FromValue(0.5); - automaton.States[3].EndWeight = Weight.FromValue(1); - automaton.States[4].EndWeight = Weight.FromValue(2); + automaton.States[2].SetEndWeight(Weight.FromValue(0.5)); + automaton.States[3].SetEndWeight(Weight.FromValue(1)); + automaton.States[4].SetEndWeight(Weight.FromValue(2)); Assert.False(automaton.IsDeterministic()); @@ -1747,7 +1747,7 @@ namespace Microsoft.ML.Probabilistic.Tests state = nextState; } - state.EndWeight = Weight.One; + state.SetEndWeight(Weight.One); Assert.False(automaton.IsDeterministic()); @@ -1777,11 +1777,11 @@ namespace Microsoft.ML.Probabilistic.Tests const int TransitionsPerCharacter = 3; for (int i = 0; i < TransitionsPerCharacter; ++i) { - automaton.Start.AddTransition('a', Weight.One).EndWeight = Weight.One; - automaton.Start.AddTransition('b', Weight.One).EndWeight = Weight.One; - automaton.Start.AddTransition('d', Weight.One).EndWeight = Weight.One; - automaton.Start.AddTransition('e', Weight.One).EndWeight = Weight.One; - automaton.Start.AddTransition('g', Weight.One).EndWeight = Weight.One; + automaton.Start.AddTransition('a', Weight.One).SetEndWeight(Weight.One); + automaton.Start.AddTransition('b', Weight.One).SetEndWeight(Weight.One); + automaton.Start.AddTransition('d', Weight.One).SetEndWeight(Weight.One); + automaton.Start.AddTransition('e', Weight.One).SetEndWeight(Weight.One); + automaton.Start.AddTransition('g', Weight.One).SetEndWeight(Weight.One); } Assert.False(automaton.IsDeterministic() || TransitionsPerCharacter <= 1); @@ -1817,9 +1817,9 @@ namespace Microsoft.ML.Probabilistic.Tests automaton.States[2].AddTransition('b', Weight.FromValue(6), automaton.States[4]); automaton.States[3].AddTransition('c', Weight.FromValue(7), automaton.States[4]); - automaton.States[2].EndWeight = Weight.FromValue(0.5); - automaton.States[3].EndWeight = Weight.FromValue(1); - automaton.States[4].EndWeight = Weight.FromValue(2); + automaton.States[2].SetEndWeight(Weight.FromValue(0.5)); + automaton.States[3].SetEndWeight(Weight.FromValue(1)); + automaton.States[4].SetEndWeight(Weight.FromValue(2)); Assert.False(automaton.IsDeterministic()); @@ -1846,8 +1846,8 @@ namespace Microsoft.ML.Probabilistic.Tests public void NonDeterminizable1() { StringAutomaton automaton = StringAutomaton.Zero(); - automaton.Start.AddTransition('a', Weight.FromValue(2)).AddSelfTransition('b', Weight.FromValue(0.5)).AddTransition('c', Weight.FromValue(3.0)).EndWeight = Weight.FromValue(4); - automaton.Start.AddTransition('a', Weight.FromValue(5)).AddSelfTransition('b', Weight.FromValue(0.1)).AddTransition('c', Weight.FromValue(6.0)).EndWeight = Weight.FromValue(7); + automaton.Start.AddTransition('a', Weight.FromValue(2)).AddSelfTransition('b', Weight.FromValue(0.5)).AddTransition('c', Weight.FromValue(3.0)).SetEndWeight(Weight.FromValue(4)); + automaton.Start.AddTransition('a', Weight.FromValue(5)).AddSelfTransition('b', Weight.FromValue(0.1)).AddTransition('c', Weight.FromValue(6.0)).SetEndWeight(Weight.FromValue(7)); Assert.False(automaton.IsDeterministic()); @@ -1887,11 +1887,11 @@ namespace Microsoft.ML.Probabilistic.Tests const int TransitionsPerCharacter = 3; for (int i = 0; i < TransitionsPerCharacter; ++i) { - automaton.Start.AddTransition("a", Weight.One).EndWeight = Weight.One; - automaton.Start.AddTransition("b", Weight.One).EndWeight = Weight.One; - automaton.Start.AddTransition("d", Weight.One).EndWeight = Weight.One; - automaton.Start.AddTransition("e", Weight.One).EndWeight = Weight.One; - automaton.Start.AddTransition("g", Weight.One).EndWeight = Weight.One; + automaton.Start.AddTransition("a", Weight.One).SetEndWeight(Weight.One); + automaton.Start.AddTransition("b", Weight.One).SetEndWeight(Weight.One); + automaton.Start.AddTransition("d", Weight.One).SetEndWeight(Weight.One); + automaton.Start.AddTransition("e", Weight.One).SetEndWeight(Weight.One); + automaton.Start.AddTransition("g", Weight.One).SetEndWeight(Weight.One); } Assert.False(automaton.IsDeterministic() || TransitionsPerCharacter <= 1); @@ -1923,11 +1923,11 @@ namespace Microsoft.ML.Probabilistic.Tests const int TransitionsPerCharacter = 3; for (int i = 0; i < TransitionsPerCharacter; ++i) { - automaton.Start.AddTransition("a", Weight.One).EndWeight = Weight.One; - automaton.Start.AddTransition("b", Weight.One).EndWeight = Weight.One; - automaton.Start.AddTransition("d", Weight.One).EndWeight = Weight.One; - automaton.Start.AddTransition("e", Weight.One).EndWeight = Weight.One; - automaton.Start.AddTransition(scaledUniform, Weight.One).EndWeight = Weight.One; + automaton.Start.AddTransition("a", Weight.One).SetEndWeight(Weight.One); + automaton.Start.AddTransition("b", Weight.One).SetEndWeight(Weight.One); + automaton.Start.AddTransition("d", Weight.One).SetEndWeight(Weight.One); + automaton.Start.AddTransition("e", Weight.One).SetEndWeight(Weight.One); + automaton.Start.AddTransition(scaledUniform, Weight.One).SetEndWeight(Weight.One); } Assert.False(automaton.IsDeterministic() || TransitionsPerCharacter <= 1); @@ -1964,10 +1964,10 @@ namespace Microsoft.ML.Probabilistic.Tests automaton.States[4].AddTransition(DiscreteChar.UniformOver('i', 'j'), Weight.FromValue(1), automaton.States[5]); automaton.States[4].AddTransition(DiscreteChar.UniformOver('k', 'l'), Weight.FromValue(1), automaton.States[5]); - automaton.States[1].EndWeight = Weight.FromValue(1); - automaton.States[3].EndWeight = Weight.FromValue(1); - automaton.States[5].EndWeight = Weight.FromValue(1); - automaton.States[6].EndWeight = Weight.FromValue(1); + automaton.States[1].SetEndWeight(Weight.FromValue(1)); + automaton.States[3].SetEndWeight(Weight.FromValue(1)); + automaton.States[5].SetEndWeight(Weight.FromValue(1)); + automaton.States[6].SetEndWeight(Weight.FromValue(1)); var expectedSupport = new HashSet { @@ -1999,10 +1999,10 @@ namespace Microsoft.ML.Probabilistic.Tests automaton.States[4].AddTransition(DiscreteChar.UniformOver('i', 'j'), Weight.FromValue(1), automaton.States[5]); automaton.States[4].AddTransition(DiscreteChar.UniformOver('k', 'l'), Weight.FromValue(1), automaton.States[5]); - automaton.States[1].EndWeight = Weight.FromValue(1); - automaton.States[3].EndWeight = Weight.FromValue(1); - automaton.States[5].EndWeight = Weight.FromValue(1); - automaton.States[6].EndWeight = Weight.FromValue(1); + automaton.States[1].SetEndWeight(Weight.FromValue(1)); + automaton.States[3].SetEndWeight(Weight.FromValue(1)); + automaton.States[5].SetEndWeight(Weight.FromValue(1)); + automaton.States[6].SetEndWeight(Weight.FromValue(1)); int numPasses = 10000; Stopwatch watch = new Stopwatch(); diff --git a/test/Tests/Strings/LoopyAutomatonTests.cs b/test/Tests/Strings/LoopyAutomatonTests.cs index 4a24cf21..5c947818 100644 --- a/test/Tests/Strings/LoopyAutomatonTests.cs +++ b/test/Tests/Strings/LoopyAutomatonTests.cs @@ -109,7 +109,7 @@ namespace Microsoft.ML.Probabilistic.Tests { StringAutomaton f = StringAutomaton.Zero(); AddEpsilonLoop(f.Start, 5, 0.5); - f.Start.AddTransitionsForSequence("abc").EndWeight = Weight.One; + f.Start.AddTransitionsForSequence("abc").SetEndWeight(Weight.One); Assert.Equal("abc", f.TryComputePoint()); } @@ -122,7 +122,7 @@ namespace Microsoft.ML.Probabilistic.Tests { StringAutomaton f = StringAutomaton.Zero(); f.Start.AddTransition('a', Weight.FromValue(0.5)).AddTransition('b', Weight.Zero, f.Start); - f.Start.AddTransitionsForSequence("abc").EndWeight = Weight.One; + f.Start.AddTransitionsForSequence("abc").SetEndWeight(Weight.One); Assert.Equal("abc", f.TryComputePoint()); } @@ -135,7 +135,7 @@ namespace Microsoft.ML.Probabilistic.Tests { StringAutomaton f = StringAutomaton.Zero(); f.Start.AddTransition('a', Weight.FromValue(0.5)).AddSelfTransition('a', Weight.FromValue(0.5)).AddTransition('b', Weight.One); - f.Start.AddTransition('b', Weight.FromValue(0.5)).EndWeight = Weight.One; + f.Start.AddTransition('b', Weight.FromValue(0.5)).SetEndWeight(Weight.One); Assert.Equal("b", f.TryComputePoint()); } @@ -147,7 +147,7 @@ namespace Microsoft.ML.Probabilistic.Tests public void NoPoint1() { StringAutomaton f = StringAutomaton.Zero(); - f.Start.AddTransition('a', Weight.FromValue(0.5)).AddTransition('b', Weight.FromValue(0.5), f.Start).EndWeight = Weight.One; + f.Start.AddTransition('a', Weight.FromValue(0.5)).AddTransition('b', Weight.FromValue(0.5), f.Start).SetEndWeight(Weight.One); Assert.Null(f.TryComputePoint()); } @@ -160,7 +160,7 @@ namespace Microsoft.ML.Probabilistic.Tests { StringAutomaton f = StringAutomaton.Zero(); var state = f.Start.AddTransition('a', Weight.FromValue(0.5)); - state.EndWeight = Weight.One; + state.SetEndWeight(Weight.One); state.AddTransition('b', Weight.FromValue(0.5), f.Start); Assert.Null(f.TryComputePoint()); } @@ -216,7 +216,7 @@ namespace Microsoft.ML.Probabilistic.Tests { StringAutomaton f = StringAutomaton.Zero(); f.Start.AddSelfTransition('x', Weight.Zero); - f.Start.AddTransition('y', Weight.Zero).EndWeight = Weight.One; + f.Start.AddTransition('y', Weight.Zero).SetEndWeight(Weight.One); Assert.True(f.IsZero()); } @@ -232,10 +232,10 @@ namespace Microsoft.ML.Probabilistic.Tests public void LoopyArithmetic() { StringAutomaton automaton1 = StringAutomaton.Zero(); - automaton1.Start.AddTransition('a', Weight.FromValue(4.0)).AddTransition('b', Weight.One, automaton1.Start).EndWeight = Weight.One; + automaton1.Start.AddTransition('a', Weight.FromValue(4.0)).AddTransition('b', Weight.One, automaton1.Start).SetEndWeight(Weight.One); StringAutomaton automaton2 = StringAutomaton.Zero(); - automaton2.Start.AddSelfTransition('a', Weight.FromValue(2)).AddSelfTransition('b', Weight.FromValue(3)).EndWeight = Weight.One; + automaton2.Start.AddSelfTransition('a', Weight.FromValue(2)).AddSelfTransition('b', Weight.FromValue(3)).SetEndWeight(Weight.One); StringAutomaton sum = automaton1.Sum(automaton2); StringInferenceTestUtilities.TestValue(sum, 2.0, string.Empty); @@ -272,7 +272,7 @@ namespace Microsoft.ML.Probabilistic.Tests { StringAutomaton automaton = StringAutomaton.Zero(); automaton.Start.AddSelfTransition('a', Weight.FromValue(0.7)); - automaton.Start.EndWeight = Weight.FromValue(0.3); + automaton.Start.SetEndWeight(Weight.FromValue(0.3)); AssertStochastic(automaton); Assert.Equal(0.0, automaton.GetLogNormalizer(), 1e-6); @@ -289,11 +289,11 @@ namespace Microsoft.ML.Probabilistic.Tests { StringAutomaton automaton = StringAutomaton.Zero(); var state = automaton.Start; - state.EndWeight = Weight.FromValue(0.1); + state.SetEndWeight(Weight.FromValue(0.1)); state.AddSelfTransition('a', Weight.FromValue(0.7)); state = state.AddTransition('b', Weight.FromValue(0.2)); state.AddSelfTransition('a', Weight.FromValue(0.4)); - state.EndWeight = Weight.FromValue(0.6); + state.SetEndWeight(Weight.FromValue(0.6)); AssertStochastic(automaton); Assert.Equal(0.0, automaton.GetLogNormalizer(), 1e-6); @@ -311,14 +311,14 @@ namespace Microsoft.ML.Probabilistic.Tests StringAutomaton automaton = StringAutomaton.Zero(); automaton.Start.AddSelfTransition('a', Weight.FromValue(0.7)); - automaton.Start.EndWeight = Weight.FromValue(0.1); + automaton.Start.SetEndWeight(Weight.FromValue(0.1)); var state1 = automaton.Start.AddTransition('b', Weight.FromValue(0.15)); state1.AddSelfTransition('a', Weight.FromValue(0.4)); - state1.EndWeight = Weight.FromValue(0.6); + state1.SetEndWeight(Weight.FromValue(0.6)); var state2 = automaton.Start.AddTransition('c', Weight.FromValue(0.05)); - state2.EndWeight = Weight.One; + state2.SetEndWeight(Weight.One); AssertStochastic(automaton); Assert.Equal(0.0, automaton.GetLogNormalizer(), 1e-6); @@ -336,11 +336,11 @@ namespace Microsoft.ML.Probabilistic.Tests StringAutomaton automaton = StringAutomaton.Zero(); var state = automaton.Start.AddTransition('a', Weight.FromValue(0.9)); - state.AddTransition('a', Weight.FromValue(0.1)).EndWeight = Weight.One; + state.AddTransition('a', Weight.FromValue(0.1)).SetEndWeight(Weight.One); state = state.AddTransition('a', Weight.FromValue(0.9)); - state.AddTransition('a', Weight.FromValue(0.1)).EndWeight = Weight.One; + state.AddTransition('a', Weight.FromValue(0.1)).SetEndWeight(Weight.One); state = state.AddTransition('a', Weight.FromValue(0.9), automaton.Start); - state.AddTransition('a', Weight.FromValue(0.1)).EndWeight = Weight.One; + state.AddTransition('a', Weight.FromValue(0.1)).SetEndWeight(Weight.One); AssertStochastic(automaton); Assert.Equal(0.0, automaton.GetLogNormalizer(), 1e-6); @@ -358,7 +358,7 @@ namespace Microsoft.ML.Probabilistic.Tests StringAutomaton automaton = StringAutomaton.Zero(); var endState = automaton.Start.AddTransition('a', Weight.FromValue(2.0)); - endState.EndWeight = Weight.FromValue(5.0); + endState.SetEndWeight(Weight.FromValue(5.0)); endState.AddTransition('b', Weight.FromValue(0.25), automaton.Start); endState.AddTransition('c', Weight.FromValue(0.2), automaton.Start); @@ -377,7 +377,7 @@ namespace Microsoft.ML.Probabilistic.Tests StringAutomaton automaton = StringAutomaton.Zero(); var endState = automaton.Start.AddTransition('a', Weight.FromValue(2.0)); - endState.EndWeight = Weight.FromValue(5.0); + endState.SetEndWeight(Weight.FromValue(5.0)); endState.AddTransition('b', Weight.FromValue(0.1), automaton.Start); endState.AddTransition('c', Weight.FromValue(0.05), automaton.Start); endState.AddSelfTransition('!', Weight.FromValue(0.5)); @@ -407,13 +407,13 @@ namespace Microsoft.ML.Probabilistic.Tests AddEpsilonLoop(automaton.Start, 3, 0.2); AddEpsilonLoop(automaton.Start, 5, 0.3); - automaton.Start.EndWeight = Weight.FromValue(0.1); + automaton.Start.SetEndWeight(Weight.FromValue(0.1)); var nextState = automaton.Start.AddTransition('a', Weight.FromValue(0.4)); - nextState.EndWeight = Weight.FromValue(0.6); + nextState.SetEndWeight(Weight.FromValue(0.6)); AddEpsilonLoop(nextState, 0, 0.3); nextState = nextState.AddTransition('b', Weight.FromValue(0.1)); AddEpsilonLoop(nextState, 1, 0.9); - nextState.EndWeight = Weight.FromValue(0.1); + nextState.SetEndWeight(Weight.FromValue(0.1)); AssertStochastic(automaton); Assert.Equal(0.0, automaton.GetLogNormalizer(), 1e-6); @@ -433,21 +433,21 @@ namespace Microsoft.ML.Probabilistic.Tests automaton.States[0].AddEpsilonTransition(Weight.FromValue(0.2), automaton.States[1]); automaton.States[0].AddEpsilonTransition(Weight.FromValue(0.5), automaton.States[3]); - automaton.States[0].EndWeight = Weight.FromValue(0.3); + automaton.States[0].SetEndWeight(Weight.FromValue(0.3)); automaton.States[1].AddEpsilonTransition(Weight.FromValue(0.8), automaton.States[0]); automaton.States[1].AddEpsilonTransition(Weight.FromValue(0.1), automaton.States[2]); - automaton.States[1].EndWeight = Weight.FromValue(0.1); - automaton.States[2].EndWeight = Weight.FromValue(1.0); + automaton.States[1].SetEndWeight(Weight.FromValue(0.1)); + automaton.States[2].SetEndWeight(Weight.FromValue(1.0)); automaton.States[3].AddEpsilonTransition(Weight.FromValue(0.2), automaton.States[4]); automaton.States[3].AddEpsilonTransition(Weight.FromValue(0.1), automaton.States[5]); - automaton.States[3].EndWeight = Weight.FromValue(0.7); + automaton.States[3].SetEndWeight(Weight.FromValue(0.7)); automaton.States[4].AddEpsilonTransition(Weight.FromValue(0.5), automaton.States[2]); automaton.States[4].AddEpsilonTransition(Weight.FromValue(0.5), automaton.States[6]); - automaton.States[4].EndWeight = Weight.FromValue(0.0); + automaton.States[4].SetEndWeight(Weight.FromValue(0.0)); automaton.States[5].AddEpsilonTransition(Weight.FromValue(0.1), automaton.States[3]); automaton.States[5].AddEpsilonTransition(Weight.FromValue(0.9), automaton.States[6]); - automaton.States[5].EndWeight = Weight.Zero; - automaton.States[6].EndWeight = Weight.One; + automaton.States[5].SetEndWeight(Weight.Zero); + automaton.States[6].SetEndWeight(Weight.One); AssertStochastic(automaton); Assert.Equal(0.0, automaton.GetLogNormalizer(), 1e-6); @@ -465,7 +465,7 @@ namespace Microsoft.ML.Probabilistic.Tests StringAutomaton automaton = StringAutomaton.Zero(); var endState = automaton.Start.AddTransition('a', Weight.FromValue(3.5)); - endState.EndWeight = Weight.FromValue(5.0); + endState.SetEndWeight(Weight.FromValue(5.0)); endState.AddTransition('b', Weight.FromValue(0.1), automaton.Start); endState.AddTransition('c', Weight.FromValue(0.05), automaton.Start); endState.AddSelfTransition('!', Weight.FromValue(0.5)); @@ -486,7 +486,7 @@ namespace Microsoft.ML.Probabilistic.Tests StringAutomaton automaton = StringAutomaton.Zero(); var endState = automaton.Start.AddTransition('a', Weight.FromValue(2.0)); - endState.EndWeight = Weight.FromValue(5.0); + endState.SetEndWeight(Weight.FromValue(5.0)); endState.AddTransition('b', Weight.FromValue(0.1), automaton.Start); endState.AddTransition('c', Weight.FromValue(0.05), automaton.Start); endState.AddSelfTransition('!', Weight.FromValue(0.75)); @@ -506,7 +506,7 @@ namespace Microsoft.ML.Probabilistic.Tests { StringAutomaton automaton = StringAutomaton.Zero(); automaton.Start.AddTransition('a', Weight.FromValue(2.0), automaton.Start); - automaton.Start.EndWeight = Weight.FromValue(5.0); + automaton.Start.SetEndWeight(Weight.FromValue(5.0)); StringAutomaton copyOfAutomaton = automaton.Clone(); Assert.Throws(() => automaton.NormalizeValues()); @@ -525,9 +525,9 @@ namespace Microsoft.ML.Probabilistic.Tests automaton.Start.AddSelfTransition('a', Weight.FromValue(0.1)); var branch1 = automaton.Start.AddTransition('a', Weight.FromValue(2.0)); branch1.AddSelfTransition('a', Weight.FromValue(2.0)); - branch1.EndWeight = Weight.One; + branch1.SetEndWeight(Weight.One); var branch2 = automaton.Start.AddTransition('a', Weight.FromValue(2.0)); - branch2.EndWeight = Weight.One; + branch2.SetEndWeight(Weight.One); StringAutomaton copyOfAutomaton = automaton.Clone(); Assert.Throws(() => automaton.NormalizeValues()); @@ -543,7 +543,7 @@ namespace Microsoft.ML.Probabilistic.Tests public void NormalizeWithInfiniteEpsilon1() { StringAutomaton automaton = StringAutomaton.Zero(); - automaton.Start.AddTransition('a', Weight.One).AddSelfTransition(null, Weight.FromValue(3)).EndWeight = Weight.One; + automaton.Start.AddTransition('a', Weight.One).AddSelfTransition(null, Weight.FromValue(3)).SetEndWeight(Weight.One); // The automaton takes an infinite value on "a", and yet the normalization must work Assert.True(automaton.TryNormalizeValues()); @@ -559,8 +559,8 @@ namespace Microsoft.ML.Probabilistic.Tests public void NormalizeWithInfiniteEpsilon2() { StringAutomaton automaton = StringAutomaton.Zero(); - automaton.Start.AddTransition('a', Weight.One).AddSelfTransition(null, Weight.FromValue(2)).EndWeight = Weight.One; - automaton.Start.AddTransition('b', Weight.One).AddSelfTransition(null, Weight.FromValue(1)).EndWeight = Weight.One; + automaton.Start.AddTransition('a', Weight.One).AddSelfTransition(null, Weight.FromValue(2)).SetEndWeight(Weight.One); + automaton.Start.AddTransition('b', Weight.One).AddSelfTransition(null, Weight.FromValue(1)).SetEndWeight(Weight.One); // "a" branch infinitely dominates over the "b" branch Assert.True(automaton.TryNormalizeValues()); @@ -583,7 +583,7 @@ namespace Microsoft.ML.Probabilistic.Tests automaton.Start.AddEpsilonTransition(Weight.FromValue(0.5), automaton.Start); var nextState = automaton.Start.AddEpsilonTransition(Weight.FromValue(0.4)); nextState.AddEpsilonTransition(Weight.One).AddEpsilonTransition(Weight.One, automaton.Start); - automaton.Start.EndWeight = Weight.FromValue(0.1); + automaton.Start.SetEndWeight(Weight.FromValue(0.1)); AssertStochastic(automaton); @@ -614,7 +614,7 @@ namespace Microsoft.ML.Probabilistic.Tests var middleNode = automaton.Start.AddTransition('a', Weight.One); middleNode.AddTransitionsForSequence("bbb", automaton.Start); middleNode.AddTransition('c', Weight.One, automaton.Start); - automaton.Start.EndWeight = Weight.One; + automaton.Start.SetEndWeight(Weight.One); Assert.Equal("(a(c|bbb))*", automaton.ToString(AutomatonFormats.Friendly)); Assert.Equal("(a(c|bbb))*", automaton.ToString(AutomatonFormats.Regexp)); @@ -630,7 +630,7 @@ namespace Microsoft.ML.Probabilistic.Tests StringAutomaton automaton = StringAutomaton.Zero(); automaton.Start.AddSelfTransition('a', Weight.One); automaton.Start.AddSelfTransition('b', Weight.One); - automaton.Start.EndWeight = Weight.One; + automaton.Start.SetEndWeight(Weight.One); Assert.Equal("(a|b)*", automaton.ToString(AutomatonFormats.Friendly)); Assert.Equal("(a|b)*", automaton.ToString(AutomatonFormats.Regexp)); @@ -649,8 +649,8 @@ namespace Microsoft.ML.Probabilistic.Tests automaton.Start.AddTransition('y', Weight.One, state); state.AddSelfTransition('a', Weight.One); state.AddSelfTransition('b', Weight.One); - state.EndWeight = Weight.One; - state.AddTransitionsForSequence("zzz").EndWeight = Weight.One; + state.SetEndWeight(Weight.One); + state.AddTransitionsForSequence("zzz").SetEndWeight(Weight.One); Assert.Equal("(x|y)(a|b)*[zzz]", automaton.ToString(AutomatonFormats.Friendly)); Assert.Equal("(x|y)(a|b)*(|zzz)", automaton.ToString(AutomatonFormats.Regexp)); @@ -665,7 +665,7 @@ namespace Microsoft.ML.Probabilistic.Tests { StringAutomaton automaton = StringAutomaton.Zero(); automaton.Start.AddTransitionsForSequence("xyz", automaton.Start); - automaton.Start.AddTransition('!', Weight.One).EndWeight = Weight.One; + automaton.Start.AddTransition('!', Weight.One).SetEndWeight(Weight.One); Assert.Equal("(xyz)*!", automaton.ToString(AutomatonFormats.Friendly)); Assert.Equal("(xyz)*!", automaton.ToString(AutomatonFormats.Regexp)); } @@ -681,7 +681,7 @@ namespace Microsoft.ML.Probabilistic.Tests var state = automaton.Start.AddTransition('x', Weight.One); automaton.Start.AddTransition('y', Weight.Zero, state); state.AddSelfTransition('a', Weight.One); - state.EndWeight = Weight.One; + state.SetEndWeight(Weight.One); Assert.Equal("xa*", automaton.ToString(AutomatonFormats.Friendly)); Assert.Equal("xa*", automaton.ToString(AutomatonFormats.Regexp)); @@ -696,7 +696,7 @@ namespace Microsoft.ML.Probabilistic.Tests { StringAutomaton automaton = StringAutomaton.Zero(); automaton.Start.AddSelfTransition('x', Weight.Zero); - automaton.Start.AddTransition('y', Weight.Zero).EndWeight = Weight.One; + automaton.Start.AddTransition('y', Weight.Zero).SetEndWeight(Weight.One); Assert.Equal("Ø", automaton.ToString(AutomatonFormats.Friendly)); Assert.Equal("Ø", automaton.ToString(AutomatonFormats.Regexp)); @@ -766,7 +766,7 @@ namespace Microsoft.ML.Probabilistic.Tests { currentState = currentState.AddEpsilonTransition( i == 0 ? Weight.FromValue(loopWeight) : Weight.One, - i == loopSize ? state : null); + i == loopSize ? state : default(StringAutomaton.State)); } } diff --git a/test/Tests/Strings/SequenceDistributionTests.cs b/test/Tests/Strings/SequenceDistributionTests.cs index e13cb1b9..53980a9d 100644 --- a/test/Tests/Strings/SequenceDistributionTests.cs +++ b/test/Tests/Strings/SequenceDistributionTests.cs @@ -579,7 +579,7 @@ namespace Microsoft.ML.Probabilistic.Tests // The length of sequences sampled from this distribution must follow a geometric distribution StringAutomaton automaton = StringAutomaton.Zero(); automaton.Start = automaton.AddState(); - automaton.Start.EndWeight = Weight.FromValue(StoppingProbability); + automaton.Start.SetEndWeight(Weight.FromValue(StoppingProbability)); automaton.Start.AddTransition('a', Weight.FromValue(1 - StoppingProbability), automaton.Start); StringDistribution dist = StringDistribution.FromWeightFunction(automaton); diff --git a/test/Tests/Strings/StringInferencePerformanceTests.cs b/test/Tests/Strings/StringInferencePerformanceTests.cs index 4c88ea4f..5cd0ebd3 100644 --- a/test/Tests/Strings/StringInferencePerformanceTests.cs +++ b/test/Tests/Strings/StringInferencePerformanceTests.cs @@ -38,8 +38,8 @@ namespace Microsoft.ML.Probabilistic.Tests StringAutomaton automaton = StringAutomaton.Zero(); var nextState = automaton.Start.AddTransitionsForSequence("abc"); nextState.AddSelfTransition('d', Weight.FromValue(0.1)); - nextState.AddTransitionsForSequence("efg").EndWeight = Weight.One; - nextState.AddTransitionsForSequence("hejfhoenmf").EndWeight = Weight.One; + nextState.AddTransitionsForSequence("efg").SetEndWeight(Weight.One); + nextState.AddTransitionsForSequence("hejfhoenmf").SetEndWeight(Weight.One); ProfileAction(() => automaton.GetLogNormalizer(), 100000); }, 10000); @@ -57,13 +57,13 @@ namespace Microsoft.ML.Probabilistic.Tests { StringAutomaton automaton = StringAutomaton.Zero(); var nextState = automaton.Start.AddTransitionsForSequence("abc"); - nextState.EndWeight = Weight.One; + nextState.SetEndWeight(Weight.One); nextState.AddSelfTransition('d', Weight.FromValue(0.1)); nextState = nextState.AddTransitionsForSequence("efg"); - nextState.EndWeight = Weight.One; + nextState.SetEndWeight(Weight.One); nextState.AddSelfTransition('h', Weight.FromValue(0.2)); nextState = nextState.AddTransitionsForSequence("grlkhgn;lk3rng"); - nextState.EndWeight = Weight.One; + nextState.SetEndWeight(Weight.One); nextState.AddSelfTransition('h', Weight.FromValue(0.3)); ProfileAction(() => automaton.GetLogNormalizer(), 100000); @@ -82,10 +82,10 @@ namespace Microsoft.ML.Probabilistic.Tests { StringAutomaton automaton = StringAutomaton.Zero(); automaton.Start.AddSelfTransition('a', Weight.FromValue(0.5)); - automaton.Start.EndWeight = Weight.One; + automaton.Start.SetEndWeight(Weight.One); var nextState = automaton.Start.AddTransitionsForSequence("aa"); nextState.AddSelfTransition('a', Weight.FromValue(0.5)); - nextState.EndWeight = Weight.One; + nextState.SetEndWeight(Weight.One); for (int i = 0; i < 3; ++i) {