OuterQuantiles and InnerQuantiles can be serialized and compared.

Their second constructor is now a factory method.
This commit is contained in:
Tom Minka 2019-06-13 22:35:24 +01:00
Родитель 7e79cb8b4b
Коммит 63fe682bed
4 изменённых файлов: 70 добавлений и 24 удалений

Просмотреть файл

@ -7,19 +7,22 @@ namespace Microsoft.ML.Probabilistic.Distributions
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.Serialization;
using System.Text;
using Microsoft.ML.Probabilistic.Collections;
using Microsoft.ML.Probabilistic.Math;
using Microsoft.ML.Probabilistic.Utilities;
/// <summary>
/// Represents a distribution using the quantiles at probabilities (1,...,n)/(n+1)
/// </summary>
[Serializable, DataContract]
public class InnerQuantiles : CanGetQuantile, CanGetProbLessThan
{
/// <summary>
/// Numbers in increasing order.
/// </summary>
private readonly double[] quantiles;
[DataMember] private readonly double[] quantiles;
/// <summary>
/// Gaussian approximation of the lower tail.
/// </summary>
@ -40,18 +43,11 @@ namespace Microsoft.ML.Probabilistic.Distributions
upperGaussian = GetUpperGaussian(quantiles);
}
public InnerQuantiles(int quantileCount, CanGetQuantile canGetQuantile)
public static InnerQuantiles FromDistribution(int quantileCount, CanGetQuantile canGetQuantile)
{
if (quantileCount == 0) throw new ArgumentException("quantileCount == 0", nameof(quantiles));
this.quantiles = new double[quantileCount];
for (int i = 0; i < quantileCount; i++)
{
this.quantiles[i] = canGetQuantile.GetQuantile((i + 1.0) / (quantileCount + 1.0));
}
OuterQuantiles.AssertFinite(quantiles, nameof(canGetQuantile));
OuterQuantiles.AssertNondecreasing(quantiles, nameof(canGetQuantile));
lowerGaussian = GetLowerGaussian(quantiles);
upperGaussian = GetUpperGaussian(quantiles);
if (quantileCount == 0) throw new ArgumentOutOfRangeException(nameof(quantileCount), quantileCount, "quantileCount == 0");
var quantiles = Util.ArrayInit(quantileCount, i => canGetQuantile.GetQuantile((i + 1.0) / (quantileCount + 1.0)));
return new InnerQuantiles(quantiles);
}
public override string ToString()
@ -74,6 +70,17 @@ namespace Microsoft.ML.Probabilistic.Distributions
return quantiles;
}
public override bool Equals(object obj)
{
if (!(obj is InnerQuantiles that)) return false;
return quantiles.ValueEquals(that.quantiles);
}
public override int GetHashCode()
{
return Hash.GetHashCodeAsSequence(quantiles);
}
/// <inheritdoc/>
public double GetProbLessThan(double x)
{

Просмотреть файл

@ -7,18 +7,23 @@ namespace Microsoft.ML.Probabilistic.Distributions
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.Serialization;
using System.Text;
using System.Threading.Tasks;
using Microsoft.ML.Probabilistic.Collections;
using Microsoft.ML.Probabilistic.Math;
using Microsoft.ML.Probabilistic.Utilities;
/// <summary>
/// Represents a distribution using the quantiles at probabilities (0,...,n-1)/(n-1)
/// </summary>
[Serializable, DataContract]
public class OuterQuantiles : CanGetQuantile, CanGetProbLessThan
{
/// <summary>
/// Numbers in increasing order.
/// </summary>
[DataMember]
private readonly double[] quantiles;
public OuterQuantiles(double[] quantiles)
@ -28,6 +33,37 @@ namespace Microsoft.ML.Probabilistic.Distributions
this.quantiles = quantiles;
}
public override string ToString()
{
string quantileString;
if (quantiles.Length <= 5)
{
quantileString = StringUtil.CollectionToString(quantiles, " ");
}
else
{
int n = quantiles.Length;
quantileString = $"{quantiles[0]:g2} {quantiles[1]:g2} ... {quantiles[n - 2]:g2} {quantiles[n - 1]:g2}";
}
return $"OuterQuantiles({quantiles.Length}, {quantileString})";
}
public double[] ToArray()
{
return quantiles;
}
public override bool Equals(object obj)
{
if (!(obj is OuterQuantiles that)) return false;
return quantiles.ValueEquals(that.quantiles);
}
public override int GetHashCode()
{
return Hash.GetHashCodeAsSequence(quantiles);
}
internal static void AssertFinite(double[] array, string paramName)
{
for (int i = 0; i < array.Length; i++)
@ -45,13 +81,10 @@ namespace Microsoft.ML.Probabilistic.Distributions
}
}
public OuterQuantiles(int quantileCount, CanGetQuantile canGetQuantile)
public static OuterQuantiles FromDistribution(int quantileCount, CanGetQuantile canGetQuantile)
{
this.quantiles = new double[quantileCount];
for (int i = 0; i < quantileCount; i++)
{
this.quantiles[i] = canGetQuantile.GetQuantile(i / (quantileCount - 1.0));
}
var quantiles = Util.ArrayInit(quantileCount, i => canGetQuantile.GetQuantile(i / (quantileCount - 1.0)));
return new OuterQuantiles(quantiles);
}
public double GetProbLessThan(double x)

Просмотреть файл

@ -180,6 +180,8 @@ namespace Microsoft.ML.Probabilistic.Tests
[DataMember] private IDistribution<Vector[][]> vgaJ;
[DataMember] private SparseGP sparseGp;
[DataMember] private QuantileEstimator quantileEstimator;
[DataMember] private OuterQuantiles outerQuantiles;
[DataMember] private InnerQuantiles innerQuantiles;
[DataMember] private StringDistribution stringDistribution1;
[DataMember] private StringDistribution stringDistribution2;
@ -235,6 +237,8 @@ namespace Microsoft.ML.Probabilistic.Tests
this.quantileEstimator = new QuantileEstimator(0.01);
this.quantileEstimator.Add(5);
this.outerQuantiles = OuterQuantiles.FromDistribution(3, this.quantileEstimator);
this.innerQuantiles = InnerQuantiles.FromDistribution(3, this.outerQuantiles);
this.stringDistribution1 = StringDistribution.String("aa").Append(StringDistribution.OneOf("b", "ccc")).Append("dddd");
this.stringDistribution2 = new StringDistribution();
@ -274,6 +278,8 @@ namespace Microsoft.ML.Probabilistic.Tests
Assert.Equal(0, vgaJ.MaxDiff(that.vgaJ));
Assert.Equal(0, this.sparseGp.MaxDiff(that.sparseGp));
Assert.True(this.quantileEstimator.ValueEquals(that.quantileEstimator));
Assert.True(this.innerQuantiles.Equals(that.innerQuantiles));
Assert.True(this.outerQuantiles.Equals(that.outerQuantiles));
Assert.Equal(0, this.stringDistribution1.MaxDiff(that.stringDistribution1));
Assert.Equal(0, this.stringDistribution2.MaxDiff(that.stringDistribution2));
}

Просмотреть файл

@ -48,7 +48,7 @@ namespace Microsoft.ML.Probabilistic.Tests
var est = new QuantileEstimator(0.1);
est.Add(double.PositiveInfinity);
//est.Add(double.NegativeInfinity);
var inner = new InnerQuantiles(10, est);
var inner = InnerQuantiles.FromDistribution(10, est);
});
}
@ -159,7 +159,7 @@ namespace Microsoft.ML.Probabilistic.Tests
double[] x = { left, middle, right };
var outer = new OuterQuantiles(x);
Assert.Equal(middle, outer.GetQuantile(0.5));
var inner = new InnerQuantiles(3, outer);
var inner = InnerQuantiles.FromDistribution(3, outer);
Assert.Equal(middle, inner.GetQuantile(0.5));
inner = new InnerQuantiles(x);
CheckGetQuantile(inner, inner, 25, 75);
@ -260,7 +260,7 @@ namespace Microsoft.ML.Probabilistic.Tests
Assert.Equal(outer.GetQuantile(0.5), middle);
Assert.Equal(outer.GetQuantile(0.7), middle);
CheckGetQuantile(outer, outer);
var inner = new InnerQuantiles(7, outer);
var inner = InnerQuantiles.FromDistribution(7, outer);
Assert.Equal(0.25, inner.GetProbLessThan(middle));
Assert.Equal(outer.GetQuantile(0.3), middle);
Assert.Equal(outer.GetQuantile(0.5), middle);
@ -294,7 +294,7 @@ namespace Microsoft.ML.Probabilistic.Tests
Assert.Equal(1.0, outer.GetProbLessThan(next));
Assert.Equal(next, outer.GetQuantile(1.0));
CheckGetQuantile(outer, outer);
var inner = new InnerQuantiles(5, outer);
var inner = InnerQuantiles.FromDistribution(5, outer);
CheckGetQuantile(inner, inner, (int)Math.Ceiling(100.0 / 6), (int)Math.Floor(100.0 * 5 / 6));
var est = new QuantileEstimator(0.01);
est.Add(first, 2);
@ -331,7 +331,7 @@ namespace Microsoft.ML.Probabilistic.Tests
Assert.Equal(data[4], outer.GetQuantile(0.76));
Assert.Equal(data[2], outer.GetQuantile(0.3));
CheckGetQuantile(outer, outer);
var inner = new InnerQuantiles(7, outer);
var inner = InnerQuantiles.FromDistribution(7, outer);
CheckGetQuantile(inner, inner, (int)Math.Ceiling(100.0 / 8), (int)Math.Floor(100.0 * 7 / 8));
}
@ -444,7 +444,7 @@ namespace Microsoft.ML.Probabilistic.Tests
var sortedData = new OuterQuantiles(x.ToArray());
// compute quantiles
var quantiles = new InnerQuantiles(100, sortedData);
var quantiles = InnerQuantiles.FromDistribution(100, sortedData);
// loop over x's and compare true quantile rank
var testPoints = EpTests.linspace(MMath.Min(x) - stddev, MMath.Max(x) + stddev, 100);