diff --git a/src/Runtime/Core/Maths/Region.cs b/src/Runtime/Core/Maths/Region.cs
index 4868256b..f9398a1c 100644
--- a/src/Runtime/Core/Maths/Region.cs
+++ b/src/Runtime/Core/Maths/Region.cs
@@ -4,6 +4,7 @@
 
 namespace Microsoft.ML.Probabilistic.Math
 {
+    using Microsoft.ML.Probabilistic.Utilities;
     using System;
 
     /// <summary>
@@ -11,7 +12,7 @@ namespace Microsoft.ML.Probabilistic.Math
     /// </summary>
     public class Region
     {
-        public Vector Lower, Upper;
+        public readonly Vector Lower, Upper;
 
         public int Dimension
         {
@@ -113,5 +114,44 @@ namespace Microsoft.ML.Probabilistic.Math
         {
             return string.Format("[{0},{1}]", Lower.ToString(format), Upper.ToString(format));
         }
+
+        public override bool Equals(object obj)
+        {
+            Region that = obj as Region;
+            if (that == null) return false;
+            return (that.Lower == this.Lower) && (that.Upper == this.Upper);
+        }
+
+        public override int GetHashCode()
+        {
+            return Hash.Combine(Lower.GetHashCode(), Upper.GetHashCode());
+        }
+
+        /// <summary>
+        /// Orders regions lexicographically by Lower, then by Upper.
+        /// </summary>
+        /// <param name="other">The region to compare against.</param>
+        /// <returns>A negative number, zero, or a positive number when this region sorts before, equal to, or after <paramref name="other"/>.</returns>
+        public int CompareTo(Region other)
+        {
+            int result = CompareTo(Lower, other.Lower);
+            if (result == 0) result = CompareTo(Upper, other.Upper);
+            return result;
+        }
+
+        /// <summary>
+        /// Compares two vectors element by element; if one is a prefix of the other, the shorter vector sorts first.
+        /// </summary>
+        public int CompareTo(Vector a, Vector b)
+        {
+            // Only index the common prefix so that vectors of different length cannot cause an out-of-range access.
+            int count = System.Math.Min(a.Count, b.Count);
+            for (int i = 0; i < count; i++)
+            {
+                int result = a[i].CompareTo(b[i]);
+                if (result != 0) return result;
+            }
+            return a.Count.CompareTo(b.Count);
+        }
     }
 }
diff --git a/src/Runtime/Distributions/InnerQuantiles.cs b/src/Runtime/Distributions/InnerQuantiles.cs
index 96f9f9c6..2dd31d3d 100644
--- a/src/Runtime/Distributions/InnerQuantiles.cs
+++ b/src/Runtime/Distributions/InnerQuantiles.cs
@@ -33,6 +33,8 @@ namespace Microsoft.ML.Probabilistic.Distributions
         {
             if (quantiles == null) throw new ArgumentNullException(nameof(quantiles));
             if (quantiles.Length == 0) throw new ArgumentException("quantiles array is empty", nameof(quantiles));
+            OuterQuantiles.AssertFinite(quantiles, nameof(quantiles));
+            OuterQuantiles.AssertNondecreasing(quantiles, nameof(quantiles));
             this.quantiles = quantiles;
             lowerGaussian = GetLowerGaussian(quantiles);
             upperGaussian = GetUpperGaussian(quantiles);
diff --git a/src/Runtime/Distributions/IrregularQuantiles.cs b/src/Runtime/Distributions/IrregularQuantiles.cs
new file mode 100644
index 00000000..52a5bc0a
--- /dev/null
+++ b/src/Runtime/Distributions/IrregularQuantiles.cs
@@ -0,0 +1,118 @@
+using Microsoft.ML.Probabilistic.Math;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace Microsoft.ML.Probabilistic.Distributions
+{
+    /// <summary>
+    /// Represents a distribution using the quantiles at arbitrary probabilities.
+    /// </summary>
+    public class IrregularQuantiles : CanGetQuantile, CanGetProbLessThan
+    {
+        /// <summary>
+        /// Sorted in ascending order.
+        /// </summary>
+        private readonly double[] probabilities, quantiles;
+
+        /// <summary>
+        /// Creates a distribution whose quantile at probabilities[i] is quantiles[i].
+        /// </summary>
+        /// <param name="probabilities">Strictly increasing numbers in [0,1].</param>
+        /// <param name="quantiles">Finite numbers in non-decreasing order, one per probability.</param>
+        public IrregularQuantiles(double[] probabilities, double[] quantiles)
+        {
+            if (probabilities == null) throw new ArgumentNullException(nameof(probabilities));
+            if (quantiles == null) throw new ArgumentNullException(nameof(quantiles));
+            if (quantiles.Length == 0) throw new ArgumentException("quantiles array is empty", nameof(quantiles));
+            // The arrays are indexed in parallel, so a length mismatch would otherwise surface later as an IndexOutOfRangeException.
+            if (probabilities.Length != quantiles.Length) throw new ArgumentException($"probabilities.Length ({probabilities.Length}) != quantiles.Length ({quantiles.Length})", nameof(probabilities));
+            AssertIncreasing(probabilities, nameof(probabilities));
+            AssertInRange(probabilities, nameof(probabilities));
+            OuterQuantiles.AssertNondecreasing(quantiles, nameof(quantiles));
+            OuterQuantiles.AssertFinite(quantiles, nameof(quantiles));
+            this.probabilities = probabilities;
+            this.quantiles = quantiles;
+        }
+
+        private static void AssertIncreasing(double[] array, string paramName)
+        {
+            for (int i = 1; i < array.Length; i++)
+            {
+                if (array[i] <= array[i - 1]) throw new ArgumentException($"Array is not increasing: {paramName}[{i}] {array[i]} <= {paramName}[{i - 1}] {array[i - 1]}", paramName);
+            }
+        }
+
+        private static void AssertInRange(double[] array, string paramName)
+        {
+            for (int i = 0; i < array.Length; i++)
+            {
+                if (array[i] < 0) throw new ArgumentOutOfRangeException(paramName, $"{paramName}[{i}] {array[i]} < 0");
+                if (array[i] > 1) throw new ArgumentOutOfRangeException(paramName, $"{paramName}[{i}] {array[i]} > 1");
+                if (double.IsNaN(array[i])) throw new ArgumentOutOfRangeException(paramName, $"{paramName}[{i}] {array[i]}");
+            }
+        }
+
+        /// <summary>
+        /// Returns the quantile rank of x. This is a probability such that GetQuantile(probability) == x, whenever x is inside the support of the distribution. May be discontinuous due to duplicates.
+        /// </summary>
+        /// <param name="x">The value whose quantile rank is requested.</param>
+        /// <returns>A real number in [0,1]</returns>
+        public double GetProbLessThan(double x)
+        {
+            int index = Array.BinarySearch(quantiles, x);
+            if (index >= 0)
+            {
+                // In case of duplicates, find the smallest copy.
+                while (index > 0 && quantiles[index - 1] == x) index--;
+                return probabilities[index];
+            }
+            else
+            {
+                // Linear interpolation
+                int largerIndex = ~index;
+                if (largerIndex == 0) return 0;
+                if (largerIndex == quantiles.Length) return 1;
+                int smallerIndex = largerIndex - 1;
+                double slope = (quantiles[largerIndex] - quantiles[smallerIndex]) / (probabilities[largerIndex] - probabilities[smallerIndex]);
+                return probabilities[smallerIndex] + (x - quantiles[smallerIndex]) / slope;
+            }
+        }
+
+        /// <summary>
+        /// Returns the largest value x such that GetProbLessThan(x) &lt;= probability.
+        /// </summary>
+        /// <param name="probability">A real number in [0,1].</param>
+        /// <returns>The interpolated quantile at the given probability.</returns>
+        public double GetQuantile(double probability)
+        {
+            if (probability < 0) throw new ArgumentOutOfRangeException(nameof(probability), "probability < 0");
+            if (probability > 1.0) throw new ArgumentOutOfRangeException(nameof(probability), "probability > 1.0");
+            // NaN passes both range checks above, so reject it explicitly.
+            if (double.IsNaN(probability)) throw new ArgumentOutOfRangeException(nameof(probability), "probability is NaN");
+            // The zero-based index of item in the sorted List, if item is found;
+            // otherwise, a negative number that is the bitwise complement of the index of the next element that is larger than item
+            // or, if there is no larger element, the bitwise complement of Count.
+            int index = Array.BinarySearch(probabilities, probability);
+            if (index >= 0)
+            {
+                return quantiles[index];
+            }
+            else
+            {
+                // Linear interpolation
+                int largerIndex = ~index;
+                if (largerIndex == 0) return quantiles[largerIndex];
+                int smallerIndex = largerIndex - 1;
+                if (largerIndex == probabilities.Length) return quantiles[smallerIndex];
+                double slope = (quantiles[largerIndex] - quantiles[smallerIndex]) / (probabilities[largerIndex] - probabilities[smallerIndex]);
+                // Solve for the largest x such that probabilities[smallerIndex] + (x - quantiles[smallerIndex]) / slope <= probability.
+                double frac = MMath.LargestDoubleSum(-probabilities[smallerIndex], probability);
+                double offset = MMath.LargestDoubleProduct(slope, frac);
+                return MMath.LargestDoubleSum(quantiles[smallerIndex], offset);
+            }
+        }
+    }
+}
diff --git a/src/Runtime/Distributions/OuterQuantiles.cs b/src/Runtime/Distributions/OuterQuantiles.cs
index c8d914be..732ca3f9 100644
--- a/src/Runtime/Distributions/OuterQuantiles.cs
+++ b/src/Runtime/Distributions/OuterQuantiles.cs
@@ -19,13 +19,39 @@ namespace Microsoft.ML.Probabilistic.Distributions
         /// <summary>
         /// Numbers in increasing order.
         /// </summary>
-        private double[] quantiles;
+        private readonly double[] quantiles;
 
         public OuterQuantiles(double[] quantiles)
         {
+            if (quantiles == null) throw new ArgumentNullException(nameof(quantiles));
+            AssertNondecreasing(quantiles, nameof(quantiles));
+            AssertFinite(quantiles, nameof(quantiles));
             this.quantiles = quantiles;
         }
 
+        /// <summary>
+        /// Throws if any element of the array is infinite or NaN.
+        /// </summary>
+        internal static void AssertFinite(double[] array, string paramName)
+        {
+            for (int i = 0; i < array.Length; i++)
+            {
+                if (double.IsInfinity(array[i])) throw new ArgumentOutOfRangeException(paramName, $"{paramName}[{i}] {array[i]}");
+                if (double.IsNaN(array[i])) throw new ArgumentOutOfRangeException(paramName, $"{paramName}[{i}] {array[i]}");
+            }
+        }
+
+        /// <summary>
+        /// Throws if any element of the array is smaller than its predecessor.
+        /// </summary>
+        internal static void AssertNondecreasing(double[] array, string paramName)
+        {
+            for (int i = 1; i < array.Length; i++)
+            {
+                if (array[i] < array[i - 1]) throw new ArgumentException($"Array is not non-decreasing: {paramName}[{i}] {array[i]} < {paramName}[{i - 1}] {array[i - 1]}", paramName);
+            }
+        }
+
         public OuterQuantiles(int quantileCount, CanGetQuantile canGetQuantile)
         {
             this.quantiles = new double[quantileCount];
@@ -99,7 +125,7 @@ namespace Microsoft.ML.Probabilistic.Distributions
             double pos = MMath.LargestDoubleProduct(n - 1, probability);
             int lower = (int)Math.Floor(pos);
             if (lower == n - 1) return quantiles[lower];
-            return GetQuantile(probability, lower, quantiles[lower], quantiles[lower+1], n);
+            return GetQuantile(probability, lower, quantiles[lower], quantiles[lower + 1], n);
         }
 
         /// <summary>
diff --git a/test/Tests/BlogTests.cs b/test/Tests/BlogTests.cs
index fae4ee76..47536bfb 100644
--- a/test/Tests/BlogTests.cs
+++ b/test/Tests/BlogTests.cs
@@ -2317,7 +2317,8 @@ namespace Microsoft.ML.Probabilistic.Tests
             Console.WriteLine("P(woman's height|isTaller) = {0}", engine.Infer(heightWoman));
         }
 
-        internal void Handedness()
+        [Fact]
+        public void Handedness()
         {
             bool[] studentData = {false, true, true, true, true, true, true, true, false, false};
             bool[] lecturerData = {false, true, true, true, true, true, true, true, true, true};
@@ -2339,7 +2340,9 @@ namespace Microsoft.ML.Probabilistic.Tests
             // -----------------------------------
             InferenceEngine engine = new InferenceEngine();
             //Console.WriteLine("isRightHanded = {0}", engine.Infer(isRightHanded));
-            Console.WriteLine("probRightHanded = {0}", engine.Infer(probRightHanded));
+            var probRightHandedExpected = new Beta(7.72, 3.08);
+            var probRightHandedActual = engine.Infer<Beta>(probRightHanded);
+            Assert.True(probRightHandedExpected.MaxDiff(probRightHandedActual) < 1e-4);
         }
 
         [Fact]
diff --git a/test/Tests/DiscreteTests.cs b/test/Tests/DiscreteTests.cs
index 46ddc051..7733cf36 100644
--- a/test/Tests/DiscreteTests.cs
+++ b/test/Tests/DiscreteTests.cs
@@ -85,6 +85,7 @@ namespace Microsoft.ML.Probabilistic.Tests
             VariableArray<bool> observedBlue = Variable.Array<bool>(draw).Named("observedBlue");
             using (Variable.ForEach(draw))
             {
+                // cannot use Variable.DiscreteUniform(numBalls) here since ballIndex will get the wrong range.
                 Variable<int> ballIndex = Variable.DiscreteUniform(ball, numBalls).Named("ballIndex");
                 using (Variable.Switch(ballIndex))
                 {
@@ -100,10 +101,10 @@ namespace Microsoft.ML.Probabilistic.Tests
             // 16 iters with good schedule
             // 120 iters with bad schedule
             engine.NumberOfIterations = 150;
-            Discrete numUsersActual = engine.Infer<Discrete>(numBalls);
-            Console.WriteLine("numBalls = {0}", numUsersActual);
-            Discrete numUsersExpected = new Discrete(0, 0.5079, 0.3097, 0.09646, 0.03907, 0.02015, 0.01225, 0.008336, 0.006133);
-            Assert.True(numUsersExpected.MaxDiff(numUsersActual) < 1e-4);
+            Discrete numBallsActual = engine.Infer<Discrete>(numBalls);
+            Console.WriteLine("numBalls = {0}", numBallsActual);
+            Discrete numBallsExpected = new Discrete(0, 0.5079, 0.3097, 0.09646, 0.03907, 0.02015, 0.01225, 0.008336, 0.006133);
+            Assert.True(numBallsExpected.MaxDiff(numBallsActual) < 1e-4);
             numBalls.ObservedValue = 1;
             Console.WriteLine(engine.Infer(isBlue));
         }
diff --git a/test/Tests/Properties/AssemblyInfo.cs b/test/Tests/Properties/AssemblyInfo.cs
index ea12f586..067fe854 100644
--- a/test/Tests/Properties/AssemblyInfo.cs
+++ b/test/Tests/Properties/AssemblyInfo.cs
@@ -9,7 +9,7 @@ using System.Runtime.InteropServices;
 // set of attributes. Change these attribute values to modify the information
 // associated with an assembly.
 
-[assembly: AssemblyTitle("Infer2Tests")]
+[assembly: AssemblyTitle("Microsoft.ML.Probabilistic.Tests")]
 [assembly: AssemblyDescription("")]
 [assembly: AssemblyConfiguration("")]
 [assembly: AssemblyCompany("Microsoft Research Limited")]
diff --git a/test/Tests/QuantileTests.cs b/test/Tests/QuantileTests.cs
index 1462839e..10e275e4 100644
--- a/test/Tests/QuantileTests.cs
+++ b/test/Tests/QuantileTests.cs
@@ -16,6 +16,25 @@ namespace Microsoft.ML.Probabilistic.Tests
 
     public class QuantileTests
     {
+        [Fact]
+        public void IrregularQuantilesTest()
+        {
+            var iq = new IrregularQuantiles(new double[] { 0, 0.4, 1 }, new double[] { 3, 4, 5 });
+            Assert.Equal(3.25, iq.GetQuantile(0.1));
+            Assert.Equal(0.1, iq.GetProbLessThan(3.25));
+            CheckGetQuantile(iq, iq);
+        }
+
+        [Fact]
+        public void IrregularQuantiles_InfinityTest()
+        {
+            // The constructor must reject infinite quantiles, so no statement after it could ever run.
+            Assert.Throws<ArgumentOutOfRangeException>(() =>
+            {
+                new IrregularQuantiles(new double[] { 0, 0.4, 1 }, new double[] { double.NegativeInfinity, 4, double.PositiveInfinity });
+            });
+        }
+
         [Fact]
         public void QuantileEstimator_SinglePointIsMedian()
         {
@@ -279,11 +298,11 @@ namespace Microsoft.ML.Probabilistic.Tests
             foreach(double maximumError in new[] { 0.05, 0.01, 0.005, 0.001 })
             {
                 int n = (int)(2.0 / maximumError);
-                QuantileEstimator(maximumError, n);
+                QuantileEstimatorTester(maximumError, n);
             }
         }
 
-        private void QuantileEstimator(double maximumError, int n)
+        private void QuantileEstimatorTester(double maximumError, int n)
         {
             // draw many samples from N(m,v)
             Rand.Restart(0);