Fixed an edge case in InnerQuantiles.GetQuantile.

This commit is contained in:
Tom Minka 2018-10-15 15:55:32 +01:00
Родитель fe253eb525
Коммит 79905f740d
3 изменённых файлов: 30 добавлений и 21 удалений

Просмотреть файл

@ -53,14 +53,14 @@ namespace Microsoft.ML.Probabilistic.Distributions
public override string ToString()
{
string quantileString;
if(quantiles.Length <= 5)
if (quantiles.Length <= 5)
{
quantileString = StringUtil.CollectionToString(quantiles, " ");
}
else
{
int n = quantiles.Length;
quantileString = $"{quantiles[0]:g2} {quantiles[1]:g2} ... {quantiles[n-2]:g2} {quantiles[n-1]:g2}";
quantileString = $"{quantiles[0]:g2} {quantiles[1]:g2} ... {quantiles[n - 2]:g2} {quantiles[n - 1]:g2}";
}
return $"InnerQuantiles({quantiles.Length}, {quantileString})";
}
@ -174,7 +174,7 @@ namespace Microsoft.ML.Probabilistic.Distributions
double p0 = (double)n / (n + 1);
double p1 = (i + 1.0) / (n + 1);
double mean, stddev;
GetGaussianFromQuantiles(quantiles[n-1], p0, quantiles[i], p1, out mean, out stddev);
GetGaussianFromQuantiles(quantiles[n - 1], p0, quantiles[i], p1, out mean, out stddev);
return Gaussian.FromMeanAndVariance(mean, stddev * stddev);
}
}
@ -208,25 +208,29 @@ namespace Microsoft.ML.Probabilistic.Distributions
return (index + frac) / (n + 1);
}
/// <summary>
/// Returns the largest value x such that GetProbLessThan(x) &lt;= probability.
/// </summary>
/// <param name="probability">A real number in [0,1].</param>
/// <returns></returns>
public double GetQuantile(double probability)
{
if (probability < 0) throw new ArgumentOutOfRangeException(nameof(probability), "probability < 0");
if (probability > 1.0) throw new ArgumentOutOfRangeException(nameof(probability), "probability > 1.0");
int n = quantiles.Length;
if(probability < 1.0/(n+1.0))
if (probability < 1.0 / (n + 1.0))
{
return lowerGaussian.GetQuantile(probability);
}
if(probability > n/(n+1.0))
if (probability > n / (n + 1.0))
{
return upperGaussian.GetQuantile(probability);
}
if (n == 1) return quantiles[0]; // probability must be 0.5
double pos = MMath.LargestDoubleProduct(n + 1, probability) - 1;
int lower = (int)Math.Floor(pos);
int upper = (int)Math.Ceiling(pos);
if (upper == lower) upper = lower + 1;
return OuterQuantiles.GetQuantile(probability, lower + 1, quantiles[lower], quantiles[upper], n + 2);
if (lower == n - 1) return quantiles[lower];
return OuterQuantiles.GetQuantile(probability, lower + 1, quantiles[lower], quantiles[lower + 1], n + 2);
}
}
}

Просмотреть файл

@ -98,9 +98,8 @@ namespace Microsoft.ML.Probabilistic.Distributions
if (n == 1) return quantiles[0];
double pos = MMath.LargestDoubleProduct(n - 1, probability);
int lower = (int)Math.Floor(pos);
int upper = (int)Math.Ceiling(pos);
if (upper == lower) upper = lower + 1;
return GetQuantile(probability, lower, quantiles[lower], quantiles[upper], n);
if (lower == n - 1) return quantiles[lower];
return GetQuantile(probability, lower, quantiles[lower], quantiles[lower+1], n);
}
/// <summary>

Просмотреть файл

@ -40,6 +40,7 @@ namespace Microsoft.ML.Probabilistic.Tests
double[] quantiles = { -2.3396737042130806, -2.1060851851919309, -1.8587796492436919, -1.7515040214502977, -1.6631549706936311, -1.5649421094540212, -1.4760970897199182, -1.4120516891795316, -1.3472276831887715, -1.2800915764085863, -1.2315546431485036, -1.1733035015194753, -1.1275506999997809, -1.0868191452824896, -1.0423720676050061, -1.0030087867587449, -0.96427545374863111, -0.917480799606264, -0.88868683894166878, -0.85040868414900372, -0.80942702953353063, -0.78299794937710787, -0.74791530550923879, -0.71057667829968463, -0.6764786230399974, -0.64937712088706545, -0.61647747819758114, -0.585418062478127, -0.55212155586237877, -0.52794712262708809, -0.49602391921870309, -0.4699661621821, -0.44707572988386468, -0.41779003649017038, -0.38751278424822111, -0.3659754249474671, -0.33671101603741, -0.30844051169056652, -0.28736460398884939, -0.26394181175383763, -0.2339421108026867, -0.20421395179821347, -0.17975005820876525, -0.15495505128166037, -0.12881080807789203, -0.10882854018038969, -0.080502768973386082, -0.054592779524389491, -0.030746556623353873, 0.0010699779508669754, 0.018476164506076323, 0.042997842018717161, 0.068170326454891988, 0.098939711480485845, 0.12364243085219064, 0.14897752107634207, 0.17232065117344095, 0.19510514320430472, 0.21967681963331126, 0.25144866739098226, 0.26627058021030359, 0.28976112810281413, 0.325183138022793, 0.34611510490686043, 0.37135045464414679, 0.40484250840269187, 0.423660564514518, 0.45260008550109493, 0.47897070643517381, 0.513466904702678, 0.54074552445523427, 0.56782579073247685, 0.59191546380311844, 0.630594130276651, 0.66170186000470765, 0.69059427870805967, 0.72267836185626344, 0.75388989983592025, 0.78095231060517345, 0.81945443104186122, 0.85806474163877222, 0.88543000730858912, 0.9254742516670329, 0.96663287584250224, 1.0081099518226813, 1.0414716524617549, 1.0873521052324735, 1.138068925150572, 1.1769604530537776, 1.2209510765755074, 1.2805602443304192, 1.3529085306332467, 1.4111760504339896, 1.4822842454846386, 1.5518312997748602, 1.6439254270476189, 1.7357210363862619, 1.9281504259252962, 2.064331420559117, 2.3554568165928291, };
InnerQuantiles inner = new InnerQuantiles(quantiles);
inner.GetQuantile(0.49471653842100138);
inner.GetQuantile((double)quantiles.Length / (quantiles.Length + 1));
}
/// <summary>
@ -76,12 +77,16 @@ namespace Microsoft.ML.Probabilistic.Tests
[Fact]
public void QuantileEstimator_MedianTest()
{
double left = 1.2;
double middle = 3.4;
double[] x = { 1.2, middle, 5.6 };
double right = 5.6;
double[] x = { left, middle, right };
var outer = new OuterQuantiles(x);
Assert.Equal(outer.GetQuantile(0.5), middle);
Assert.Equal(middle, outer.GetQuantile(0.5));
var inner = new InnerQuantiles(3, outer);
Assert.Equal(inner.GetQuantile(0.5), middle);
Assert.Equal(middle, inner.GetQuantile(0.5));
inner = new InnerQuantiles(x);
CheckGetQuantile(inner, inner, 25, 75);
var est = new QuantileEstimator(0.01);
est.AddRange(x);
Assert.Equal(est.GetQuantile(0.5), middle);
@ -250,22 +255,23 @@ namespace Microsoft.ML.Probabilistic.Tests
double middle = 3.4;
double next = MMath.NextDouble(middle);
double[] x = { middle, middle, middle };
var outer = new OuterQuantiles(x);
Assert.Equal(0.0, outer.GetProbLessThan(middle));
Assert.Equal(middle, outer.GetQuantile(0.0));
Assert.Equal(middle, outer.GetQuantile(0.75));
Assert.Equal(1.0, outer.GetProbLessThan(next));
Assert.Equal(next, outer.GetQuantile(1.0));
foreach (int weight in new[] { 1, 2, 3 })
{
var outer = new OuterQuantiles(x);
Assert.Equal(0.0, outer.GetProbLessThan(middle));
Assert.Equal(outer.GetQuantile(0.0), middle);
Assert.Equal(1.0, outer.GetProbLessThan(next));
Assert.Equal(outer.GetQuantile(1.0), next);
var est = new QuantileEstimator(0.01);
foreach (var item in x)
{
est.Add(item, weight);
}
Assert.Equal(0.0, est.GetProbLessThan(middle));
Assert.Equal(est.GetQuantile(0.0), middle);
Assert.Equal(middle, est.GetQuantile(0.0));
Assert.Equal(1.0, est.GetProbLessThan(next));
Assert.Equal(est.GetQuantile(1.0), next);
Assert.Equal(next, est.GetQuantile(1.0));
}
}