From 63fe682bede8fca7d33653e00f1fd991428e254b Mon Sep 17 00:00:00 2001 From: Tom Minka <8955276+tminka@users.noreply.github.com> Date: Thu, 13 Jun 2019 22:35:24 +0100 Subject: [PATCH] OuterQuantiles and InnerQuantiles can be serialized and compared. Their second constructor is now a factory method. --- src/Runtime/Distributions/InnerQuantiles.cs | 31 ++++++++------ src/Runtime/Distributions/OuterQuantiles.cs | 45 +++++++++++++++++--- test/Tests/Distributions/SerializableTest.cs | 6 +++ test/Tests/QuantileTests.cs | 12 +++--- 4 files changed, 70 insertions(+), 24 deletions(-) diff --git a/src/Runtime/Distributions/InnerQuantiles.cs b/src/Runtime/Distributions/InnerQuantiles.cs index 3e71f61d..f0147790 100644 --- a/src/Runtime/Distributions/InnerQuantiles.cs +++ b/src/Runtime/Distributions/InnerQuantiles.cs @@ -7,19 +7,22 @@ namespace Microsoft.ML.Probabilistic.Distributions using System; using System.Collections.Generic; using System.Linq; + using System.Runtime.Serialization; using System.Text; + using Microsoft.ML.Probabilistic.Collections; using Microsoft.ML.Probabilistic.Math; using Microsoft.ML.Probabilistic.Utilities; /// /// Represents a distribution using the quantiles at probabilities (1,...,n)/(n+1) /// + [Serializable, DataContract] public class InnerQuantiles : CanGetQuantile, CanGetProbLessThan { /// /// Numbers in increasing order. /// - private readonly double[] quantiles; + [DataMember] private readonly double[] quantiles; /// /// Gaussian approximation of the lower tail. /// @@ -40,18 +43,11 @@ namespace Microsoft.ML.Probabilistic.Distributions upperGaussian = GetUpperGaussian(quantiles); } - public InnerQuantiles(int quantileCount, CanGetQuantile canGetQuantile) + public static InnerQuantiles FromDistribution(int quantileCount, CanGetQuantile canGetQuantile) { - if (quantileCount == 0) throw new ArgumentException("quantileCount == 0", nameof(quantiles)); - this.quantiles = new double[quantileCount]; - for (int i = 0; i < quantileCount; i++) - { - this.quantiles[i] = canGetQuantile.GetQuantile((i + 1.0) / (quantileCount + 1.0)); - } - OuterQuantiles.AssertFinite(quantiles, nameof(canGetQuantile)); - OuterQuantiles.AssertNondecreasing(quantiles, nameof(canGetQuantile)); - lowerGaussian = GetLowerGaussian(quantiles); - upperGaussian = GetUpperGaussian(quantiles); + if (quantileCount == 0) throw new ArgumentOutOfRangeException(nameof(quantileCount), quantileCount, "quantileCount == 0"); + var quantiles = Util.ArrayInit(quantileCount, i => canGetQuantile.GetQuantile((i + 1.0) / (quantileCount + 1.0))); + return new InnerQuantiles(quantiles); } public override string ToString() @@ -74,6 +70,17 @@ namespace Microsoft.ML.Probabilistic.Distributions return quantiles; } + public override bool Equals(object obj) + { + if (!(obj is InnerQuantiles that)) return false; + return quantiles.ValueEquals(that.quantiles); + } + + public override int GetHashCode() + { + return Hash.GetHashCodeAsSequence(quantiles); + } + /// public double GetProbLessThan(double x) { diff --git a/src/Runtime/Distributions/OuterQuantiles.cs b/src/Runtime/Distributions/OuterQuantiles.cs index d1083a9b..1dce8e65 100644 --- a/src/Runtime/Distributions/OuterQuantiles.cs +++ b/src/Runtime/Distributions/OuterQuantiles.cs @@ -7,18 +7,23 @@ namespace Microsoft.ML.Probabilistic.Distributions using System; using System.Collections.Generic; using System.Linq; + using System.Runtime.Serialization; using System.Text; using System.Threading.Tasks; + using Microsoft.ML.Probabilistic.Collections; using Microsoft.ML.Probabilistic.Math; + using Microsoft.ML.Probabilistic.Utilities; /// /// Represents a distribution using the quantiles at probabilities (0,...,n-1)/(n-1) /// + [Serializable, DataContract] public class OuterQuantiles : CanGetQuantile, CanGetProbLessThan { /// /// Numbers in increasing order. /// + [DataMember] private readonly double[] quantiles; public OuterQuantiles(double[] quantiles) @@ -28,6 +33,37 @@ namespace Microsoft.ML.Probabilistic.Distributions this.quantiles = quantiles; } + public override string ToString() + { + string quantileString; + if (quantiles.Length <= 5) + { + quantileString = StringUtil.CollectionToString(quantiles, " "); + } + else + { + int n = quantiles.Length; + quantileString = $"{quantiles[0]:g2} {quantiles[1]:g2} ... {quantiles[n - 2]:g2} {quantiles[n - 1]:g2}"; + } + return $"OuterQuantiles({quantiles.Length}, {quantileString})"; + } + + public double[] ToArray() + { + return quantiles; + } + + public override bool Equals(object obj) + { + if (!(obj is OuterQuantiles that)) return false; + return quantiles.ValueEquals(that.quantiles); + } + + public override int GetHashCode() + { + return Hash.GetHashCodeAsSequence(quantiles); + } + internal static void AssertFinite(double[] array, string paramName) { for (int i = 0; i < array.Length; i++) @@ -45,13 +81,10 @@ namespace Microsoft.ML.Probabilistic.Distributions } } - public OuterQuantiles(int quantileCount, CanGetQuantile canGetQuantile) + public static OuterQuantiles FromDistribution(int quantileCount, CanGetQuantile canGetQuantile) { - this.quantiles = new double[quantileCount]; - for (int i = 0; i < quantileCount; i++) - { - this.quantiles[i] = canGetQuantile.GetQuantile(i / (quantileCount - 1.0)); - } + var quantiles = Util.ArrayInit(quantileCount, i => canGetQuantile.GetQuantile(i / (quantileCount - 1.0))); + return new OuterQuantiles(quantiles); } public double GetProbLessThan(double x) diff --git a/test/Tests/Distributions/SerializableTest.cs b/test/Tests/Distributions/SerializableTest.cs index b8d16ce7..1991315d 100644 --- a/test/Tests/Distributions/SerializableTest.cs +++ b/test/Tests/Distributions/SerializableTest.cs @@ -180,6 +180,8 @@ namespace Microsoft.ML.Probabilistic.Tests [DataMember] private IDistribution vgaJ; [DataMember] private SparseGP sparseGp; [DataMember] private QuantileEstimator quantileEstimator; + [DataMember] private OuterQuantiles outerQuantiles; + [DataMember] private InnerQuantiles innerQuantiles; [DataMember] private StringDistribution stringDistribution1; [DataMember] private StringDistribution stringDistribution2; @@ -235,6 +237,8 @@ namespace Microsoft.ML.Probabilistic.Tests this.quantileEstimator = new QuantileEstimator(0.01); this.quantileEstimator.Add(5); + this.outerQuantiles = OuterQuantiles.FromDistribution(3, this.quantileEstimator); + this.innerQuantiles = InnerQuantiles.FromDistribution(3, this.outerQuantiles); this.stringDistribution1 = StringDistribution.String("aa").Append(StringDistribution.OneOf("b", "ccc")).Append("dddd"); this.stringDistribution2 = new StringDistribution(); @@ -274,6 +278,8 @@ namespace Microsoft.ML.Probabilistic.Tests Assert.Equal(0, vgaJ.MaxDiff(that.vgaJ)); Assert.Equal(0, this.sparseGp.MaxDiff(that.sparseGp)); Assert.True(this.quantileEstimator.ValueEquals(that.quantileEstimator)); + Assert.True(this.innerQuantiles.Equals(that.innerQuantiles)); + Assert.True(this.outerQuantiles.Equals(that.outerQuantiles)); Assert.Equal(0, this.stringDistribution1.MaxDiff(that.stringDistribution1)); Assert.Equal(0, this.stringDistribution2.MaxDiff(that.stringDistribution2)); } diff --git a/test/Tests/QuantileTests.cs b/test/Tests/QuantileTests.cs index 5e300d48..4728ba31 100644 --- a/test/Tests/QuantileTests.cs +++ b/test/Tests/QuantileTests.cs @@ -48,7 +48,7 @@ namespace Microsoft.ML.Probabilistic.Tests var est = new QuantileEstimator(0.1); est.Add(double.PositiveInfinity); //est.Add(double.NegativeInfinity); - var inner = new InnerQuantiles(10, est); + var inner = InnerQuantiles.FromDistribution(10, est); }); } @@ -159,7 +159,7 @@ namespace Microsoft.ML.Probabilistic.Tests double[] x = { left, middle, right }; var outer = new OuterQuantiles(x); Assert.Equal(middle, outer.GetQuantile(0.5)); - var inner = new InnerQuantiles(3, outer); + var inner = InnerQuantiles.FromDistribution(3, outer); Assert.Equal(middle, inner.GetQuantile(0.5)); inner = new InnerQuantiles(x); CheckGetQuantile(inner, inner, 25, 75); @@ -260,7 +260,7 @@ namespace Microsoft.ML.Probabilistic.Tests Assert.Equal(outer.GetQuantile(0.5), middle); Assert.Equal(outer.GetQuantile(0.7), middle); CheckGetQuantile(outer, outer); - var inner = new InnerQuantiles(7, outer); + var inner = InnerQuantiles.FromDistribution(7, outer); Assert.Equal(0.25, inner.GetProbLessThan(middle)); Assert.Equal(outer.GetQuantile(0.3), middle); Assert.Equal(outer.GetQuantile(0.5), middle); @@ -294,7 +294,7 @@ namespace Microsoft.ML.Probabilistic.Tests Assert.Equal(1.0, outer.GetProbLessThan(next)); Assert.Equal(next, outer.GetQuantile(1.0)); CheckGetQuantile(outer, outer); - var inner = new InnerQuantiles(5, outer); + var inner = InnerQuantiles.FromDistribution(5, outer); CheckGetQuantile(inner, inner, (int)Math.Ceiling(100.0 / 6), (int)Math.Floor(100.0 * 5 / 6)); var est = new QuantileEstimator(0.01); est.Add(first, 2); @@ -331,7 +331,7 @@ namespace Microsoft.ML.Probabilistic.Tests Assert.Equal(data[4], outer.GetQuantile(0.76)); Assert.Equal(data[2], outer.GetQuantile(0.3)); CheckGetQuantile(outer, outer); - var inner = new InnerQuantiles(7, outer); + var inner = InnerQuantiles.FromDistribution(7, outer); CheckGetQuantile(inner, inner, (int)Math.Ceiling(100.0 / 8), (int)Math.Floor(100.0 * 7 / 8)); } @@ -444,7 +444,7 @@ namespace Microsoft.ML.Probabilistic.Tests var sortedData = new OuterQuantiles(x.ToArray()); // compute quantiles - var quantiles = new InnerQuantiles(100, sortedData); + var quantiles = InnerQuantiles.FromDistribution(100, sortedData); // loop over x's and compare true quantile rank var testPoints = EpTests.linspace(MMath.Min(x) - stddev, MMath.Max(x) + stddev, 100);