зеркало из https://github.com/dotnet/infer.git
Add operator overloads to Weight class (#121)
This commit is contained in:
Родитель
bb11960fc0
Коммит
dc0b30487e
|
@ -327,13 +327,13 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
}
|
||||
else
|
||||
{
|
||||
transition.Weight = Weight.Product(transition.Weight, endState.EndWeight);
|
||||
transition.Weight *= endState.EndWeight;
|
||||
}
|
||||
|
||||
endState.AddTransition(transition);
|
||||
}
|
||||
|
||||
endState.SetEndWeight(Weight.Product(endState.EndWeight, secondStartState.EndWeight));
|
||||
endState.SetEndWeight(endState.EndWeight * secondStartState.EndWeight);
|
||||
}
|
||||
|
||||
this.RemoveState(secondStartState.Index);
|
||||
|
|
|
@ -313,9 +313,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
State destState = state.Owner.States[transition.DestinationStateIndex];
|
||||
if (this.transitionFilter(transition) && !currentComponent.HasState(destState))
|
||||
{
|
||||
weightToAdd = Weight.Sum(
|
||||
weightToAdd,
|
||||
Weight.Product(transition.Weight, this.stateIdToInfo[transition.DestinationStateIndex].WeightToEnd));
|
||||
weightToAdd += transition.Weight * this.stateIdToInfo[transition.DestinationStateIndex].WeightToEnd;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -326,9 +324,8 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
{
|
||||
State updatedState = currentComponent.GetStateByIndex(updatedStateIndex);
|
||||
CondensationStateInfo updatedStateInfo = this.stateIdToInfo[updatedState.Index];
|
||||
updatedStateInfo.WeightToEnd = Weight.Sum(
|
||||
updatedStateInfo.WeightToEnd,
|
||||
Weight.Product(currentComponent.GetWeight(updatedStateIndex, stateIndex), weightToAdd));
|
||||
updatedStateInfo.WeightToEnd +=
|
||||
currentComponent.GetWeight(updatedStateIndex, stateIndex) * weightToAdd;
|
||||
this.stateIdToInfo[updatedState.Index] = updatedStateInfo;
|
||||
}
|
||||
}
|
||||
|
@ -368,9 +365,8 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
{
|
||||
State destState = currentComponent.GetStateByIndex(destStateIndex);
|
||||
CondensationStateInfo destStateInfo = this.stateIdToInfo[destState.Index];
|
||||
destStateInfo.WeightFromRoot = Weight.Sum(
|
||||
destStateInfo.WeightFromRoot,
|
||||
Weight.Product(srcStateInfo.UpwardWeightFromRoot, currentComponent.GetWeight(srcStateIndex, destStateIndex)));
|
||||
destStateInfo.WeightFromRoot +=
|
||||
srcStateInfo.UpwardWeightFromRoot * currentComponent.GetWeight(srcStateIndex, destStateIndex);
|
||||
this.stateIdToInfo[destState.Index] = destStateInfo;
|
||||
}
|
||||
}
|
||||
|
@ -392,9 +388,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
if (this.transitionFilter(transition) && !currentComponent.HasState(destState))
|
||||
{
|
||||
CondensationStateInfo destStateInfo = this.stateIdToInfo[destState.Index];
|
||||
destStateInfo.UpwardWeightFromRoot = Weight.Sum(
|
||||
destStateInfo.UpwardWeightFromRoot,
|
||||
Weight.Product(srcStateInfo.WeightFromRoot, transition.Weight));
|
||||
destStateInfo.UpwardWeightFromRoot += srcStateInfo.WeightFromRoot * transition.Weight;
|
||||
this.stateIdToInfo[transition.DestinationStateIndex] = destStateInfo;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -110,9 +110,8 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
destinationState.SetEndWeight(Weight.Zero);
|
||||
foreach (KeyValuePair<int, Weight> stateIdWithWeight in destWeightedStateSet)
|
||||
{
|
||||
destinationState.SetEndWeight(Weight.Sum(
|
||||
destinationState.EndWeight,
|
||||
Weight.Product(stateIdWithWeight.Value, this.States[stateIdWithWeight.Key].EndWeight)));
|
||||
var addedWeight = stateIdWithWeight.Value * this.States[stateIdWithWeight.Key].EndWeight;
|
||||
destinationState.SetEndWeight(destinationState.EndWeight + addedWeight);
|
||||
}
|
||||
|
||||
destinationStateIndex = destinationState.Index;
|
||||
|
|
|
@ -49,7 +49,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
break;
|
||||
}
|
||||
|
||||
selfLoopWeight = Weight.Sum(selfLoopWeight, transition.Weight);
|
||||
selfLoopWeight += transition.Weight;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -57,7 +57,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
{
|
||||
Weight stateWeight = Weight.ApproximateClosure(selfLoopWeight);
|
||||
this.weightedStates.Add((state, stateWeight));
|
||||
this.EndWeight = Weight.Product(stateWeight, state.EndWeight);
|
||||
this.EndWeight = stateWeight * state.EndWeight;
|
||||
}
|
||||
else
|
||||
{
|
||||
|
|
|
@ -74,7 +74,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
newSourceState.SetEndWeight(weightToEnd);
|
||||
}
|
||||
|
||||
correctionFactor = Weight.Sum(correctionFactor, Weight.Product(weightFromRoot, weightToEnd));
|
||||
correctionFactor += weightFromRoot * weightToEnd;
|
||||
}
|
||||
|
||||
if (!correctionFactor.IsZero)
|
||||
|
@ -168,9 +168,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
{
|
||||
if (transition.Group == group) continue;
|
||||
|
||||
weightToAdd = Weight.Sum(
|
||||
weightToAdd,
|
||||
Weight.Product(transition.Weight, weights[transition.DestinationStateIndex]));
|
||||
weightToAdd += transition.Weight * weights[transition.DestinationStateIndex];
|
||||
}
|
||||
|
||||
weights[state.Index] = weightToAdd;
|
||||
|
@ -205,10 +203,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
{
|
||||
if (transition.Group == group) continue;
|
||||
|
||||
var destWeight = weights[transition.DestinationStateIndex];
|
||||
var weight = Weight.Sum(destWeight, Weight.Product(srcWeight, transition.Weight));
|
||||
|
||||
weights[transition.DestinationStateIndex] = weight;
|
||||
weights[transition.DestinationStateIndex] += srcWeight * transition.Weight;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -187,7 +187,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
|
||||
if (transition1.IsEpsilon && transition2.IsEpsilon)
|
||||
{
|
||||
transition1.Weight = Weight.Sum(transition1.Weight, transition2.Weight);
|
||||
transition1.Weight += transition2.Weight;
|
||||
return transition1;
|
||||
}
|
||||
|
||||
|
@ -213,7 +213,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
|
||||
return new Transition(
|
||||
newElementDistribution,
|
||||
Weight.Sum(transition1.Weight, transition2.Weight),
|
||||
transition1.Weight + transition2.Weight,
|
||||
transition1.DestinationStateIndex,
|
||||
transition1.Group);
|
||||
}
|
||||
|
@ -266,7 +266,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
MergeStates(
|
||||
transition1.DestinationStateIndex,
|
||||
transition2.DestinationStateIndex,
|
||||
Weight.Product(transition2.Weight, Weight.Inverse(transition1.Weight)));
|
||||
transition2.Weight * Weight.Inverse(transition1.Weight));
|
||||
isRemovedNode[transition2.DestinationStateIndex] = true;
|
||||
iterator2.Remove();
|
||||
}
|
||||
|
@ -370,8 +370,8 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
// sum end weights
|
||||
if (!state2.EndWeight.IsZero)
|
||||
{
|
||||
var state2EndWeight = Weight.Product(state2WeightMultiplier, state2.EndWeight);
|
||||
state1.SetEndWeight(Weight.Sum(state1.EndWeight, state2EndWeight));
|
||||
var state2EndWeight = state2WeightMultiplier * state2.EndWeight;
|
||||
state1.SetEndWeight(state1.EndWeight + state2EndWeight);
|
||||
}
|
||||
|
||||
// Copy all transitions except self-loop.
|
||||
|
@ -383,7 +383,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
var transition = iterator.Value;
|
||||
if (transition.DestinationStateIndex != state2Index)
|
||||
{
|
||||
transition.Weight = Weight.Product(transition.Weight, state2WeightMultiplier);
|
||||
transition.Weight *= state2WeightMultiplier;
|
||||
state1.AddTransition(transition);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -165,8 +165,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
{
|
||||
if (this.transitionFilter(transition) && transition.DestinationStateIndex == state.Index)
|
||||
{
|
||||
this.singleStatePairwiseWeight = Weight.Sum(
|
||||
this.singleStatePairwiseWeight.Value, transition.Weight);
|
||||
this.singleStatePairwiseWeight += transition.Weight;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -203,8 +202,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
int destStateIndexInComponent;
|
||||
if (this.transitionFilter(transition) && (destStateIndexInComponent = this.GetIndexByState(destState)) != -1)
|
||||
{
|
||||
this.pairwiseWeights[srcStateIndexInComponent, destStateIndexInComponent] = Weight.Sum(
|
||||
this.pairwiseWeights[srcStateIndexInComponent, destStateIndexInComponent], transition.Weight);
|
||||
this.pairwiseWeights[srcStateIndexInComponent, destStateIndexInComponent] += transition.Weight;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -228,16 +226,15 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
continue;
|
||||
}
|
||||
|
||||
Weight additionalWeight = Weight.Product(
|
||||
this.pairwiseWeights[i, j] += Weight.Product(
|
||||
this.pairwiseWeights[i, k], loopWeight, this.pairwiseWeights[k, j]);
|
||||
this.pairwiseWeights[i, j] = Weight.Sum(this.pairwiseWeights[i, j], additionalWeight);
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < this.Size; ++i)
|
||||
{
|
||||
this.pairwiseWeights[i, k] = Weight.Product(this.pairwiseWeights[i, k], loopWeight);
|
||||
this.pairwiseWeights[k, i] = Weight.Product(this.pairwiseWeights[k, i], loopWeight);
|
||||
this.pairwiseWeights[i, k] *= loopWeight;
|
||||
this.pairwiseWeights[k, i] *= loopWeight;
|
||||
}
|
||||
|
||||
this.pairwiseWeights[k, k] = loopWeight;
|
||||
|
|
|
@ -646,7 +646,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
{
|
||||
endStatesWithTargetWeights.Add(ValueTuple.Create(
|
||||
state.Index,
|
||||
Weight.Product(Weight.FromValue(repetitionNumberWeights[i]), state.EndWeight)));
|
||||
Weight.FromValue(repetitionNumberWeights[i]) * state.EndWeight));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1343,7 +1343,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
if (!productStateCache.TryGetValue(destPair, out var productStateIndex))
|
||||
{
|
||||
var productState = builder.AddState();
|
||||
productState.SetEndWeight(Weight.Product(state1.EndWeight, state2.EndWeight));
|
||||
productState.SetEndWeight(state1.EndWeight * state2.EndWeight);
|
||||
stack.Push((state1.Index, state2.Index, productState.Index));
|
||||
productStateCache[destPair] = productState.Index;
|
||||
productStateIndex = productState.Index;
|
||||
|
@ -1768,7 +1768,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
{
|
||||
if (valueCache.TryGetValue(statePosPair, out var cachedValue))
|
||||
{
|
||||
valuesStack.Push(Weight.Product(cachedValue, multiplier));
|
||||
valuesStack.Push(cachedValue * multiplier);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
@ -1777,7 +1777,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
if (sequencePos == sequenceLength)
|
||||
{
|
||||
// We are at the end of sequence. So put an answer on stack
|
||||
valuesStack.Push(Weight.Product(closure.EndWeight, multiplier));
|
||||
valuesStack.Push(closure.EndWeight * multiplier);
|
||||
valueCache[statePosPair] = closure.EndWeight;
|
||||
}
|
||||
else
|
||||
|
@ -1817,11 +1817,10 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
var sum = Weight.Zero;
|
||||
while (valuesStack.Count > sumUntil)
|
||||
{
|
||||
var transitionValue = valuesStack.Pop();
|
||||
sum = Weight.Sum(sum, transitionValue);
|
||||
sum += valuesStack.Pop();
|
||||
}
|
||||
|
||||
valuesStack.Push(Weight.Product(multiplier, sum));
|
||||
valuesStack.Push(multiplier * sum);
|
||||
valueCache[statePosPair] = sum;
|
||||
}
|
||||
}
|
||||
|
@ -2189,7 +2188,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
var closureDestStateIndex = BuildEpsilonClosure(destState);
|
||||
resultState.AddTransition(
|
||||
transition.ElementDistribution,
|
||||
Weight.Product(transition.Weight, closureStateWeight),
|
||||
transition.Weight * closureStateWeight,
|
||||
closureDestStateIndex,
|
||||
transition.Group);
|
||||
}
|
||||
|
@ -2268,7 +2267,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
Weight transitionWeightSum = Weight.Zero;
|
||||
foreach (var transition in state.Transitions)
|
||||
{
|
||||
transitionWeightSum = Weight.Sum(transitionWeightSum, transition.Weight);
|
||||
transitionWeightSum += transition.Weight;
|
||||
}
|
||||
|
||||
maxLogTransitionWeightSum = Math.Max(maxLogTransitionWeightSum, transitionWeightSum.LogValue);
|
||||
|
@ -2421,7 +2420,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
transitionIterator.Value = transition;
|
||||
}
|
||||
|
||||
state.SetEndWeight(Weight.Product(state.EndWeight, weightToEndInv));
|
||||
state.SetEndWeight(state.EndWeight * weightToEndInv);
|
||||
}
|
||||
|
||||
return builder.GetData();
|
||||
|
@ -2531,7 +2530,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
var currentState = this.States[stateIndex];
|
||||
if (currentState.CanEnd)
|
||||
{
|
||||
var newWeight = Weight.Product(weight, currentState.EndWeight);
|
||||
var newWeight = weight * currentState.EndWeight;
|
||||
yield return new Tuple<List<TElementDistribution>, double>(prefix.Reverse().ToList(), newWeight.LogValue);
|
||||
}
|
||||
|
||||
|
@ -2556,7 +2555,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
prefix.Push(transition.ElementDistribution.Value);
|
||||
}
|
||||
|
||||
foreach (var support in this.EnumeratePaths(prefix, visitedStates, Weight.Product(weight, transition.Weight), transition.DestinationStateIndex))
|
||||
foreach (var support in this.EnumeratePaths(prefix, visitedStates, weight * transition.Weight, transition.DestinationStateIndex))
|
||||
{
|
||||
yield return support;
|
||||
}
|
||||
|
|
|
@ -122,14 +122,13 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
var elementStateWeightSum = Weight.Zero;
|
||||
foreach (var element in transitionElements)
|
||||
{
|
||||
Weight prevStateWeight;
|
||||
if (!elementStatesWeights.TryGetWeight(element.destIndex, out prevStateWeight))
|
||||
if (!elementStatesWeights.TryGetWeight(element.destIndex, out var prevStateWeight))
|
||||
{
|
||||
prevStateWeight = Weight.Zero;
|
||||
}
|
||||
|
||||
elementStatesWeights[element.destIndex] = Weight.Sum(prevStateWeight, element.weight);
|
||||
elementStateWeightSum = Weight.Sum(elementStateWeightSum, element.weight);
|
||||
elementStatesWeights[element.destIndex] = prevStateWeight + element.weight;
|
||||
elementStateWeightSum += element.weight;
|
||||
}
|
||||
|
||||
var destinationState = new Determinization.WeightedStateSet();
|
||||
|
@ -137,13 +136,13 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
{
|
||||
if (stateIdWithWeight.Value.LogValue > LogEps)
|
||||
{
|
||||
Weight stateWeight = Weight.Product(stateIdWithWeight.Value, Weight.Inverse(elementStateWeightSum));
|
||||
Weight stateWeight = stateIdWithWeight.Value * Weight.Inverse(elementStateWeightSum);
|
||||
destinationState.Add(stateIdWithWeight.Key, stateWeight);
|
||||
}
|
||||
}
|
||||
|
||||
Weight transitionWeight = Weight.Product(Weight.FromValue(1), elementStateWeightSum);
|
||||
results.Add(Tuple.Create(transitionElements[0].distribution,transitionWeight, destinationState));
|
||||
Weight transitionWeight = elementStateWeightSum;
|
||||
results.Add(Tuple.Create(transitionElements[0].distribution, transitionWeight, destinationState));
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -159,7 +158,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
Dictionary<TElement, List<TransitionElement>> elements, List<TransitionElement> uniformList)
|
||||
{
|
||||
var dist = transition.ElementDistribution.Value;
|
||||
Weight weightBase = Weight.Product(transition.Weight, sourceStateResidualWeight);
|
||||
Weight weightBase = transition.Weight * sourceStateResidualWeight;
|
||||
if (dist.IsPointMass)
|
||||
{
|
||||
var pt = dist.Point;
|
||||
|
|
|
@ -92,12 +92,12 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
{
|
||||
if (stateIdWithWeight.Value.LogValue > LogEps)
|
||||
{
|
||||
Weight stateWeight = Weight.Product(stateIdWithWeight.Value, Weight.Inverse(currentSegmentStateWeightSum));
|
||||
Weight stateWeight = stateIdWithWeight.Value * Weight.Inverse(currentSegmentStateWeightSum);
|
||||
destinationState.Add(stateIdWithWeight.Key, stateWeight);
|
||||
}
|
||||
}
|
||||
|
||||
Weight transitionWeight = Weight.Product(Weight.FromValue(segmentLength), currentSegmentStateWeightSum);
|
||||
Weight transitionWeight = Weight.FromValue(segmentLength) * currentSegmentStateWeightSum;
|
||||
result.Add((elementDist, transitionWeight, destinationState));
|
||||
}
|
||||
|
||||
|
@ -107,8 +107,8 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
if (segmentBound.IsStart)
|
||||
{
|
||||
activeSegments.Add(segmentBound);
|
||||
currentSegmentStateWeightSum = Weight.Sum(currentSegmentStateWeightSum, segmentBound.Weight);
|
||||
currentSegmentStateWeights[segmentBound.DestinationStateId] = Weight.Sum(currentSegmentStateWeights[segmentBound.DestinationStateId], segmentBound.Weight);
|
||||
currentSegmentStateWeightSum += segmentBound.Weight;
|
||||
currentSegmentStateWeights[segmentBound.DestinationStateId] += segmentBound.Weight;
|
||||
}
|
||||
else
|
||||
{
|
||||
|
@ -117,9 +117,15 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
if (double.IsInfinity(segmentBound.Weight.Value))
|
||||
{
|
||||
// Cannot subtract because of the infinities involved.
|
||||
currentSegmentStateWeightSum = activeSegments.Select(sb => sb.Weight).Aggregate(Weight.Zero, (acc, w) => Weight.Sum(acc, w));
|
||||
currentSegmentStateWeightSum =
|
||||
activeSegments
|
||||
.Select(sb => sb.Weight)
|
||||
.Aggregate(Weight.Zero, Weight.Sum);
|
||||
currentSegmentStateWeights[segmentBound.DestinationStateId] =
|
||||
activeSegments.Where(sb => sb.DestinationStateId == segmentBound.DestinationStateId).Select(sb => sb.Weight).Aggregate(Weight.Zero, (acc, w) => Weight.Sum(acc, w));
|
||||
activeSegments
|
||||
.Where(sb => sb.DestinationStateId == segmentBound.DestinationStateId)
|
||||
.Select(sb => sb.Weight)
|
||||
.Aggregate(Weight.Zero, Weight.Sum);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
@ -149,7 +155,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
var ranges = distribution.Ranges;
|
||||
int commonValueStart = char.MinValue;
|
||||
Weight commonValue = Weight.FromValue(distribution.ProbabilityOutsideRanges);
|
||||
Weight weightBase = Weight.Product(transition.Weight, sourceStateResidualWeight);
|
||||
Weight weightBase = transition.Weight * sourceStateResidualWeight;
|
||||
TransitionCharSegmentBound newSegmentBound;
|
||||
|
||||
////if (double.IsInfinity(weightBase.Value))
|
||||
|
@ -162,7 +168,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
if (range.StartInclusive > commonValueStart && !commonValue.IsZero)
|
||||
{
|
||||
// Add endpoints for the common value
|
||||
Weight segmentWeight = Weight.Product(commonValue, weightBase);
|
||||
Weight segmentWeight = commonValue * weightBase;
|
||||
newSegmentBound = new TransitionCharSegmentBound(commonValueStart, transition.DestinationStateIndex, segmentWeight, true);
|
||||
bounds.Add(new ValueTuple<int, TransitionCharSegmentBound>(bounds.Count, newSegmentBound));
|
||||
newSegmentBound = new TransitionCharSegmentBound(range.StartInclusive, transition.DestinationStateIndex, segmentWeight, false);
|
||||
|
@ -173,7 +179,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
Weight pieceValue = Weight.FromValue(range.Probability);
|
||||
if (!pieceValue.IsZero)
|
||||
{
|
||||
Weight segmentWeight = Weight.Product(pieceValue, weightBase);
|
||||
Weight segmentWeight = pieceValue * weightBase;
|
||||
newSegmentBound = new TransitionCharSegmentBound(range.StartInclusive, transition.DestinationStateIndex, segmentWeight, true);
|
||||
bounds.Add(new ValueTuple<int, TransitionCharSegmentBound>(bounds.Count, newSegmentBound));
|
||||
newSegmentBound = new TransitionCharSegmentBound(range.EndExclusive, transition.DestinationStateIndex, segmentWeight, false);
|
||||
|
@ -186,7 +192,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
if (!commonValue.IsZero && (ranges.Count == 0 || ranges[ranges.Count - 1].EndExclusive != DiscreteChar.CharRangeEndExclusive))
|
||||
{
|
||||
// Add endpoints for the last common value segment
|
||||
Weight segmentWeight = Weight.Product(commonValue, weightBase);
|
||||
Weight segmentWeight = commonValue * weightBase;
|
||||
newSegmentBound = new TransitionCharSegmentBound(commonValueStart, transition.DestinationStateIndex, segmentWeight, true);
|
||||
bounds.Add(new ValueTuple<int, TransitionCharSegmentBound>(bounds.Count, newSegmentBound));
|
||||
newSegmentBound = new TransitionCharSegmentBound(char.MaxValue + 1, transition.DestinationStateIndex, segmentWeight, false);
|
||||
|
|
|
@ -349,7 +349,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
if (!destStateCache.TryGetValue(destPair, out var destStateIndex))
|
||||
{
|
||||
var destState = result.AddState();
|
||||
destState.SetEndWeight(Weight.Product(mappingState.EndWeight, srcState.EndWeight));
|
||||
destState.SetEndWeight(mappingState.EndWeight * srcState.EndWeight);
|
||||
stack.Push((mappingState.Index, srcState.Index, destState.Index));
|
||||
destStateCache[destPair] = destState.Index;
|
||||
destStateIndex = destState.Index;
|
||||
|
@ -504,7 +504,7 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
continue;
|
||||
}
|
||||
|
||||
var weight = Weight.Product(mappingTransition.Weight, Weight.FromLogValue(projectionLogScale));
|
||||
var weight = mappingTransition.Weight * Weight.FromLogValue(projectionLogScale);
|
||||
var childDestState = CreateDestState(destMappingState,srcSequenceIndex + 1);
|
||||
destState.AddTransition(destElementDistribution, weight, childDestState, mappingTransition.Group);
|
||||
}
|
||||
|
|
|
@ -43,50 +43,32 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
/// <summary>
|
||||
/// Gets the zero weight.
|
||||
/// </summary>
|
||||
public static Weight Zero
|
||||
{
|
||||
get { return new Weight(double.NegativeInfinity); }
|
||||
}
|
||||
public static Weight Zero => new Weight(double.NegativeInfinity);
|
||||
|
||||
/// <summary>
|
||||
/// Gets the unit weight.
|
||||
/// </summary>
|
||||
public static Weight One
|
||||
{
|
||||
get { return new Weight(0); }
|
||||
}
|
||||
public static Weight One => new Weight(0);
|
||||
|
||||
/// <summary>
|
||||
/// Gets the infinite weight.
|
||||
/// </summary>
|
||||
public static Weight Infinity
|
||||
{
|
||||
get { return new Weight(double.PositiveInfinity); }
|
||||
}
|
||||
public static Weight Infinity => new Weight(double.PositiveInfinity);
|
||||
|
||||
/// <summary>
|
||||
/// Gets the logarithm of the weight value.
|
||||
/// </summary>
|
||||
public double LogValue
|
||||
{
|
||||
get { return this.logValue; }
|
||||
}
|
||||
public double LogValue => this.logValue;
|
||||
|
||||
/// <summary>
|
||||
/// Gets the weight value.
|
||||
/// </summary>
|
||||
public double Value
|
||||
{
|
||||
get { return Math.Exp(this.LogValue); }
|
||||
}
|
||||
public double Value => Math.Exp(this.LogValue);
|
||||
|
||||
/// <summary>
|
||||
/// Gets a value indicating whether the weight is zero.
|
||||
/// </summary>
|
||||
public bool IsZero
|
||||
{
|
||||
get { return double.IsNegativeInfinity(this.LogValue); }
|
||||
}
|
||||
public bool IsZero => double.IsNegativeInfinity(this.LogValue);
|
||||
|
||||
/// <summary>
|
||||
/// Gets value indicating whether weight is infinite.
|
||||
|
@ -282,6 +264,22 @@ namespace Microsoft.ML.Probabilistic.Distributions.Automata
|
|||
return !(weight1 == weight2);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Compute the product of given weights.
|
||||
/// </summary>
|
||||
/// <param name="weight1">The first weight.</param>
|
||||
/// <param name="weight2">The second weight.</param>
|
||||
/// <returns>The computed product.</returns>
|
||||
public static Weight operator *(Weight weight1, Weight weight2) => Product(weight1, weight2);
|
||||
|
||||
/// <summary>
|
||||
/// Compute the sum of given weights.
|
||||
/// </summary>
|
||||
/// <param name="weight1">The first weight.</param>
|
||||
/// <param name="weight2">The second weight.</param>
|
||||
/// <returns>The computed sum.</returns>
|
||||
public static Weight operator +(Weight weight1, Weight weight2) => Sum(weight1, weight2);
|
||||
|
||||
/// <summary>
|
||||
/// Checks if this instance is equal to a given object.
|
||||
/// </summary>
|
||||
|
|
|
@ -1467,7 +1467,7 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
|||
Weight probSum = Weight.Zero;
|
||||
foreach (var transition in currentState.Transitions)
|
||||
{
|
||||
probSum = Weight.Sum(probSum, transition.Weight);
|
||||
probSum += transition.Weight;
|
||||
if (logSample < probSum.LogValue)
|
||||
{
|
||||
if (!transition.IsEpsilon)
|
||||
|
|
|
@ -86,7 +86,7 @@ namespace Microsoft.ML.Probabilistic.Factors
|
|||
logValues1.SetToFunction(
|
||||
logValues1,
|
||||
logValues2,
|
||||
(x, y) => Weight.Sum(Weight.FromLogValue(x), Weight.Product(values2Scale, Weight.FromLogValue(y))).LogValue);
|
||||
(x, y) => (Weight.FromLogValue(x) + values2Scale * Weight.FromLogValue(y)).LogValue);
|
||||
return logValues1;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -799,7 +799,7 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
Weight weightSum = automatonClone.States[i].EndWeight;
|
||||
for (int j = 0; j < automatonClone.States[i].Transitions.Count; ++j)
|
||||
{
|
||||
weightSum = Weight.Sum(weightSum, automatonClone.States[i].Transitions[j].Weight);
|
||||
weightSum += automatonClone.States[i].Transitions[j].Weight;
|
||||
}
|
||||
|
||||
Assert.Equal(0.0, weightSum.LogValue, 1e-6);
|
||||
|
|
Загрузка…
Ссылка в новой задаче