зеркало из https://github.com/dotnet/infer.git
MMath.WeightedAverage fix (#379)
This commit is contained in:
Родитель
5f394f9b8a
Коммит
791d11a8a1
|
@ -4824,6 +4824,7 @@ rr = mpf('-0.99999824265582826');
|
|||
/// <returns></returns>
|
||||
public static double WeightedAverage(double weight1, double value1, double weight2, double value2)
|
||||
{
|
||||
const double ScaleBM1024 = 5.5626846462680035E-309; // Math.ScaleB(1,-1024)
|
||||
if (weight1 < weight2)
|
||||
{
|
||||
return WeightedAverage(weight2, value2, weight1, value1);
|
||||
|
@ -4851,7 +4852,8 @@ rr = mpf('-0.99999824265582826');
|
|||
// We know that (weight1+weight2) == weight1
|
||||
// weight2 < 1
|
||||
double scale = double.MaxValue / weight1; // scale >= 1 but scale*weight2 < 1
|
||||
//result = (scale * weight1 * value1 + scale * weight2 * value2) / (scale * weight1 + scale * weight2);
|
||||
// Below is equivalent to:
|
||||
// result = (scale * weight1 * value1 + scale * weight2 * value2) / (scale * weight1 + scale * weight2);
|
||||
double scaleWeight2Value2 = scale * weight2 * value2; // cannot overflow
|
||||
result = (double.MaxValue * value1 + scaleWeight2Value2) / double.MaxValue;
|
||||
if (double.IsNaN(result) || double.IsInfinity(result))
|
||||
|
@ -4859,7 +4861,6 @@ rr = mpf('-0.99999824265582826');
|
|||
// Overflow happened. Scale down to avoid overflow.
|
||||
// We scale by a power of 2 to ensure that the result is rounded the same way as above.
|
||||
// Otherwise, the function would not always be monotonic in value1 and value2.
|
||||
const double ScaleBM1024 = 5.5626846462680035E-309; // ScaleB(1,-1024)
|
||||
const double nextBelowOne = double.MaxValue * ScaleBM1024; // nextBelowOne < 1
|
||||
// IsInfinity(result) implies value1 > 1 so the first term cannot underflow.
|
||||
// This cannot overflow, because the second term's magnitude is less than 1.
|
||||
|
@ -4879,27 +4880,11 @@ rr = mpf('-0.99999824265582826');
|
|||
if (double.IsNaN(result) || double.IsInfinity(result))
|
||||
{
|
||||
// Overflow happened. Scale down to avoid overflow.
|
||||
// 0 <= weight2/weight1 <= 1
|
||||
// This case cannot overflow but can underflow.
|
||||
// It is not clear whether this result will be rounded the same way as above,
|
||||
// so it is not clear that this code ensures monotonicity.
|
||||
double ratio2 = weight2 / weight1;
|
||||
result = (0.5 * value1 + 0.5 * ratio2 * value2) / (0.5 + 0.5 * ratio2);
|
||||
if (double.IsNaN(result))
|
||||
{
|
||||
// a/b returns NaN in 4 cases:
|
||||
// 1. a=b=0
|
||||
// 2. abs(a)=abs(b)=infinity
|
||||
// 3. a is NaN
|
||||
// 4. b is NaN
|
||||
// (weight2/weight1) cannot be NaN since none of these cases can happen.
|
||||
// Therefore IsNaN(result) implies IsNaN(numerator).
|
||||
// IsNaN(numerator) happens in 2 ways:
|
||||
// 1. abs(value1)=inf and (weight2 / weight1) * value2 = -value1. This implies abs(value2)=-value1.
|
||||
// 2. (weight2 / weight1) * value2 is NaN. This implies (weight2/weight1)=0 and abs(value2)=inf.
|
||||
// In both cases, the weights are irrelevant since at least one value is infinite.
|
||||
return value1 + value2;
|
||||
}
|
||||
// We scale by a power of 2 to ensure that the result is rounded the same way as above.
|
||||
// ratio * ScaleBM1024 < 1 therefore ratio * ScaleBM1024 * value1 cannot overflow.
|
||||
// value2 * ScaleBM1024 < 1 therefore the sum cannot overflow.
|
||||
// Both terms in the denominator are < 1 therefore the denominator cannot overflow.
|
||||
result = (ratio * ScaleBM1024 * value1 + value2 * ScaleBM1024) / (ratio * ScaleBM1024 + ScaleBM1024);
|
||||
}
|
||||
}
|
||||
result = Math.Min(result, Math.Max(value1, value2));
|
||||
|
|
|
@ -108,6 +108,22 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
Assert.True(midpoint2 >= midpoint, $"Failed assertion: {midpoint2} >= {midpoint}, wa={wa:r}, a={a:r}, a2={a2:r}, wb={wb:r}, b={b:r}");
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Tests a specific case that previously failed.
|
||||
/// </summary>
|
||||
[Fact]
|
||||
public void WeightedAverage_IsMonotonic3()
|
||||
{
|
||||
double wa = 1E-05;
|
||||
double a = -1.7976931348623157E+308;
|
||||
double a2 = -1.7976931348623155E+308;
|
||||
double wb = 1.0;
|
||||
double b = -1E+287;
|
||||
double midpoint = MMath.WeightedAverage(wa, a, wb, b);
|
||||
double midpoint2 = MMath.WeightedAverage(wa, a2, wb, b);
|
||||
Assert.True(midpoint2 >= midpoint, $"Failed assertion: {midpoint2} >= {midpoint}, wa={wa:r}, a={a:r}, a2={a2:r}, wb={wb:r}, b={b:r}");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void MeanAccumulator_Add_ZeroWeight()
|
||||
{
|
||||
|
|
Загрузка…
Ссылка в новой задаче