diff --git a/src/Runtime/Core/Maths/Region.cs b/src/Runtime/Core/Maths/Region.cs
index 4868256b..f9398a1c 100644
--- a/src/Runtime/Core/Maths/Region.cs
+++ b/src/Runtime/Core/Maths/Region.cs
@@ -4,6 +4,7 @@
namespace Microsoft.ML.Probabilistic.Math
{
+ using Microsoft.ML.Probabilistic.Utilities;
using System;
///
@@ -11,7 +12,7 @@ namespace Microsoft.ML.Probabilistic.Math
///
public class Region
{
- public Vector Lower, Upper;
+ public readonly Vector Lower, Upper;
public int Dimension
{
@@ -113,5 +114,59 @@ namespace Microsoft.ML.Probabilistic.Math
{
return string.Format("[{0},{1}]", Lower.ToString(format), Upper.ToString(format));
}
+
+        /// <summary>
+        /// Determines whether <paramref name="obj"/> is a <see cref="Region"/> with the same bounds as this one.
+        /// </summary>
+        /// <param name="obj">The object to compare with; may be null.</param>
+        /// <returns>True when obj is a Region whose Lower and Upper both compare equal (via the Vector == operator) to this region's; otherwise false.</returns>
+        public override bool Equals(object obj)
+        {
+            Region that = obj as Region;
+            if (that == null) return false;
+            return (that.Lower == this.Lower) && (that.Upper == this.Upper);
+        }
+
+        /// <summary>
+        /// Returns a hash code built by combining the hash codes of Lower and Upper.
+        /// </summary>
+        public override int GetHashCode()
+        {
+            return Hash.Combine(Lower.GetHashCode(), Upper.GetHashCode());
+        }
+
+        /// <summary>
+        /// Orders regions by Lower first and, when those tie, by Upper.
+        /// </summary>
+        /// <param name="other">The region to compare with; may be null.</param>
+        /// <returns>A negative number, zero, or a positive number when this region sorts before, with, or after <paramref name="other"/>.  A null argument sorts before every region.</returns>
+        public int CompareTo(Region other)
+        {
+            // Treat null as smaller than any instance (the usual CompareTo convention)
+            // rather than failing with a NullReferenceException on other.Lower.
+            if (ReferenceEquals(other, null)) return 1;
+            int result = CompareTo(Lower, other.Lower);
+            if (result == 0) result = CompareTo(Upper, other.Upper);
+            return result;
+        }
+
+        /// <summary>
+        /// Compares two vectors lexicographically, element by element.
+        /// </summary>
+        /// <param name="a">The first vector.</param>
+        /// <param name="b">The second vector.</param>
+        /// <returns>The result of comparing the first pair of differing elements; if one vector is a prefix of the other, the shorter vector sorts first; zero when both have the same elements.</returns>
+        public int CompareTo(Vector a, Vector b)
+        {
+            // Only index positions that exist in both vectors, so a shorter b cannot be read out of range.
+            int commonCount = (a.Count < b.Count) ? a.Count : b.Count;
+            for (int i = 0; i < commonCount; i++)
+            {
+                int result = a[i].CompareTo(b[i]);
+                if (result != 0) return result;
+            }
+            // Break ties on length so that vectors of different sizes never compare as equal.
+            return a.Count.CompareTo(b.Count);
+        }
}
}
diff --git a/src/Runtime/Distributions/InnerQuantiles.cs b/src/Runtime/Distributions/InnerQuantiles.cs
index 96f9f9c6..2dd31d3d 100644
--- a/src/Runtime/Distributions/InnerQuantiles.cs
+++ b/src/Runtime/Distributions/InnerQuantiles.cs
@@ -33,6 +33,8 @@ namespace Microsoft.ML.Probabilistic.Distributions
{
if (quantiles == null) throw new ArgumentNullException(nameof(quantiles));
if (quantiles.Length == 0) throw new ArgumentException("quantiles array is empty", nameof(quantiles));
+ OuterQuantiles.AssertFinite(quantiles, nameof(quantiles));
+ OuterQuantiles.AssertNondecreasing(quantiles, nameof(quantiles));
this.quantiles = quantiles;
lowerGaussian = GetLowerGaussian(quantiles);
upperGaussian = GetUpperGaussian(quantiles);
diff --git a/src/Runtime/Distributions/IrregularQuantiles.cs b/src/Runtime/Distributions/IrregularQuantiles.cs
new file mode 100644
index 00000000..52a5bc0a
--- /dev/null
+++ b/src/Runtime/Distributions/IrregularQuantiles.cs
@@ -0,0 +1,129 @@
+using Microsoft.ML.Probabilistic.Math;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace Microsoft.ML.Probabilistic.Distributions
+{
+    /// <summary>
+    /// Represents a distribution using the quantiles at arbitrary probabilities.
+    /// </summary>
+    public class IrregularQuantiles : CanGetQuantile, CanGetProbLessThan
+    {
+        /// <summary>
+        /// Sorted in ascending order.
+        /// </summary>
+        private readonly double[] probabilities, quantiles;
+
+        /// <summary>
+        /// Creates a distribution whose quantile at probabilities[i] is quantiles[i].
+        /// The arrays are stored by reference, not copied.
+        /// </summary>
+        /// <param name="probabilities">Strictly increasing numbers in [0,1].</param>
+        /// <param name="quantiles">Finite numbers in non-decreasing order, one per probability.</param>
+        /// <exception cref="ArgumentNullException">Either array is null.</exception>
+        /// <exception cref="ArgumentException">The arrays are empty, differ in length, or are not ordered as required.</exception>
+        /// <exception cref="ArgumentOutOfRangeException">A probability is outside [0,1] or NaN, or a quantile is infinite or NaN.</exception>
+        public IrregularQuantiles(double[] probabilities, double[] quantiles)
+        {
+            if (probabilities == null) throw new ArgumentNullException(nameof(probabilities));
+            if (quantiles == null) throw new ArgumentNullException(nameof(quantiles));
+            if (quantiles.Length == 0) throw new ArgumentException("quantiles array is empty", nameof(quantiles));
+            // Both lookups below index one array with a position found in the other, so the lengths must agree.
+            if (probabilities.Length != quantiles.Length) throw new ArgumentException($"probabilities.Length ({probabilities.Length}) != quantiles.Length ({quantiles.Length})", nameof(probabilities));
+            AssertIncreasing(probabilities, nameof(probabilities));
+            AssertInRange(probabilities, nameof(probabilities));
+            OuterQuantiles.AssertNondecreasing(quantiles, nameof(quantiles));
+            OuterQuantiles.AssertFinite(quantiles, nameof(quantiles));
+            this.probabilities = probabilities;
+            this.quantiles = quantiles;
+        }
+
+        /// <summary>
+        /// Throws if any element is less than or equal to its predecessor.  NaN elements pass this check; AssertInRange rejects them.
+        /// </summary>
+        private static void AssertIncreasing(double[] array, string paramName)
+        {
+            for (int i = 1; i < array.Length; i++)
+            {
+                if (array[i] <= array[i - 1]) throw new ArgumentException($"Array is not increasing: {paramName}[{i}] {array[i]} <= {paramName}[{i - 1}] {array[i - 1]}", paramName);
+            }
+        }
+
+        /// <summary>
+        /// Throws if any element is below 0, above 1, or NaN.
+        /// </summary>
+        private static void AssertInRange(double[] array, string paramName)
+        {
+            for (int i = 0; i < array.Length; i++)
+            {
+                if (array[i] < 0) throw new ArgumentOutOfRangeException(paramName, $"{paramName}[{i}] {array[i]} < 0");
+                if (array[i] > 1) throw new ArgumentOutOfRangeException(paramName, $"{paramName}[{i}] {array[i]} > 1");
+                if (double.IsNaN(array[i])) throw new ArgumentOutOfRangeException(paramName, $"{paramName}[{i}] {array[i]}");
+            }
+        }
+
+        /// <summary>
+        /// Returns the quantile rank of x.  This is a probability such that GetQuantile(probability) == x, whenever x is inside the support of the distribution.  May be discontinuous due to duplicates.
+        /// </summary>
+        /// <param name="x">The value whose rank is wanted.</param>
+        /// <returns>A real number in [0,1]</returns>
+        public double GetProbLessThan(double x)
+        {
+            int index = Array.BinarySearch(quantiles, x);
+            if (index >= 0)
+            {
+                // In case of duplicates, find the smallest copy.
+                while (index > 0 && quantiles[index - 1] == x) index--;
+                return probabilities[index];
+            }
+            else
+            {
+                // Linear interpolation
+                int largerIndex = ~index;
+                if (largerIndex == 0) return 0;
+                if (largerIndex == quantiles.Length) return 1;
+                int smallerIndex = largerIndex - 1;
+                double slope = (quantiles[largerIndex] - quantiles[smallerIndex]) / (probabilities[largerIndex] - probabilities[smallerIndex]);
+                return probabilities[smallerIndex] + (x - quantiles[smallerIndex]) / slope;
+            }
+        }
+
+        /// <summary>
+        /// Returns the largest value x such that GetProbLessThan(x) &lt;= probability.
+        /// </summary>
+        /// <param name="probability">A real number in [0,1].</param>
+        /// <returns>The stored quantile when probability matches a stored probability exactly, the first or last stored quantile when probability lies beyond the stored range, and otherwise a linear interpolation between the two neighbouring entries.</returns>
+        /// <exception cref="ArgumentOutOfRangeException">probability is NaN or outside [0,1].</exception>
+        public double GetQuantile(double probability)
+        {
+            if (probability < 0) throw new ArgumentOutOfRangeException(nameof(probability), "probability < 0");
+            if (probability > 1.0) throw new ArgumentOutOfRangeException(nameof(probability), "probability > 1.0");
+            // NaN fails both comparisons above; reject it explicitly instead of returning an arbitrary quantile.
+            if (double.IsNaN(probability)) throw new ArgumentOutOfRangeException(nameof(probability), "probability is NaN");
+            // The zero-based index of item in the sorted List, if item is found;
+            // otherwise, a negative number that is the bitwise complement of the index of the next element that is larger than item
+            // or, if there is no larger element, the bitwise complement of Count.
+            int index = Array.BinarySearch(probabilities, probability);
+            if (index >= 0)
+            {
+                return quantiles[index];
+            }
+            else
+            {
+                // Linear interpolation
+                int largerIndex = ~index;
+                if (largerIndex == 0) return quantiles[largerIndex];
+                int smallerIndex = largerIndex - 1;
+                if (largerIndex == probabilities.Length) return quantiles[smallerIndex];
+                double slope = (quantiles[largerIndex] - quantiles[smallerIndex]) / (probabilities[largerIndex] - probabilities[smallerIndex]);
+                // Solve for the largest x such that probabilities[smallerIndex] + (x - quantiles[smallerIndex]) / slope <= probability.
+                double frac = MMath.LargestDoubleSum(-probabilities[smallerIndex], probability);
+                double offset = MMath.LargestDoubleProduct(slope, frac);
+                return MMath.LargestDoubleSum(quantiles[smallerIndex], offset);
+            }
+        }
+    }
+}
diff --git a/src/Runtime/Distributions/OuterQuantiles.cs b/src/Runtime/Distributions/OuterQuantiles.cs
index c8d914be..732ca3f9 100644
--- a/src/Runtime/Distributions/OuterQuantiles.cs
+++ b/src/Runtime/Distributions/OuterQuantiles.cs
@@ -19,13 +19,49 @@ namespace Microsoft.ML.Probabilistic.Distributions
///
/// Numbers in increasing order.
///
- private double[] quantiles;
+        private readonly double[] quantiles;
public OuterQuantiles(double[] quantiles)
{
+            AssertNondecreasing(quantiles, nameof(quantiles));
+            AssertFinite(quantiles, nameof(quantiles));
this.quantiles = quantiles;
}
+        /// <summary>
+        /// Throws unless every element of the array is a finite number.
+        /// </summary>
+        /// <param name="array">The values to check.</param>
+        /// <param name="paramName">The caller's parameter name, reported in the exception.</param>
+        /// <exception cref="ArgumentNullException">array is null.</exception>
+        /// <exception cref="ArgumentOutOfRangeException">An element is infinite or NaN.</exception>
+        internal static void AssertFinite(double[] array, string paramName)
+        {
+            // Fail with a named argument instead of a NullReferenceException on array.Length.
+            if (array == null) throw new ArgumentNullException(paramName);
+            for (int i = 0; i < array.Length; i++)
+            {
+                double value = array[i];
+                if (double.IsInfinity(value) || double.IsNaN(value)) throw new ArgumentOutOfRangeException(paramName, $"{paramName}[{i}] {value}");
+            }
+        }
+
+        /// <summary>
+        /// Throws if any element is smaller than its predecessor.  NaN elements are not detected by this check.
+        /// </summary>
+        /// <param name="array">The values to check.</param>
+        /// <param name="paramName">The caller's parameter name, reported in the exception.</param>
+        /// <exception cref="ArgumentNullException">array is null.</exception>
+        /// <exception cref="ArgumentException">An element is smaller than the one before it.</exception>
+        internal static void AssertNondecreasing(double[] array, string paramName)
+        {
+            if (array == null) throw new ArgumentNullException(paramName);
+            for (int i = 1; i < array.Length; i++)
+            {
+                if (array[i] < array[i - 1]) throw new ArgumentException($"Array is not non-decreasing: {paramName}[{i}] {array[i]} < {paramName}[{i - 1}] {array[i - 1]}", paramName);
+            }
+        }
+
public OuterQuantiles(int quantileCount, CanGetQuantile canGetQuantile)
{
this.quantiles = new double[quantileCount];
@@ -99,7 +118,7 @@ namespace Microsoft.ML.Probabilistic.Distributions
double pos = MMath.LargestDoubleProduct(n - 1, probability);
int lower = (int)Math.Floor(pos);
if (lower == n - 1) return quantiles[lower];
- return GetQuantile(probability, lower, quantiles[lower], quantiles[lower+1], n);
+ return GetQuantile(probability, lower, quantiles[lower], quantiles[lower + 1], n);
}
///
diff --git a/test/Tests/BlogTests.cs b/test/Tests/BlogTests.cs
index fae4ee76..47536bfb 100644
--- a/test/Tests/BlogTests.cs
+++ b/test/Tests/BlogTests.cs
@@ -2317,7 +2317,8 @@ namespace Microsoft.ML.Probabilistic.Tests
Console.WriteLine("P(woman's height|isTaller) = {0}", engine.Infer(heightWoman));
}
- internal void Handedness()
+ [Fact]
+ public void Handedness()
{
bool[] studentData = {false, true, true, true, true, true, true, true, false, false};
bool[] lecturerData = {false, true, true, true, true, true, true, true, true, true};
@@ -2339,7 +2340,9 @@ namespace Microsoft.ML.Probabilistic.Tests
// -----------------------------------
InferenceEngine engine = new InferenceEngine();
//Console.WriteLine("isRightHanded = {0}", engine.Infer(isRightHanded));
- Console.WriteLine("probRightHanded = {0}", engine.Infer(probRightHanded));
+ var probRightHandedExpected = new Beta(7.72, 3.08);
+ var probRightHandedActual = engine.Infer(probRightHanded);
+ Assert.True(probRightHandedExpected.MaxDiff(probRightHandedActual) < 1e-4);
}
[Fact]
diff --git a/test/Tests/DiscreteTests.cs b/test/Tests/DiscreteTests.cs
index 46ddc051..7733cf36 100644
--- a/test/Tests/DiscreteTests.cs
+++ b/test/Tests/DiscreteTests.cs
@@ -85,6 +85,7 @@ namespace Microsoft.ML.Probabilistic.Tests
VariableArray observedBlue = Variable.Array(draw).Named("observedBlue");
using (Variable.ForEach(draw))
{
+ // cannot use Variable.DiscreteUniform(numBalls) here since ballIndex will get the wrong range.
Variable ballIndex = Variable.DiscreteUniform(ball, numBalls).Named("ballIndex");
using (Variable.Switch(ballIndex))
{
@@ -100,10 +101,10 @@ namespace Microsoft.ML.Probabilistic.Tests
// 16 iters with good schedule
// 120 iters with bad schedule
engine.NumberOfIterations = 150;
- Discrete numUsersActual = engine.Infer(numBalls);
- Console.WriteLine("numBalls = {0}", numUsersActual);
- Discrete numUsersExpected = new Discrete(0, 0.5079, 0.3097, 0.09646, 0.03907, 0.02015, 0.01225, 0.008336, 0.006133);
- Assert.True(numUsersExpected.MaxDiff(numUsersActual) < 1e-4);
+ Discrete numBallsActual = engine.Infer(numBalls);
+ Console.WriteLine("numBalls = {0}", numBallsActual);
+ Discrete numBallsExpected = new Discrete(0, 0.5079, 0.3097, 0.09646, 0.03907, 0.02015, 0.01225, 0.008336, 0.006133);
+ Assert.True(numBallsExpected.MaxDiff(numBallsActual) < 1e-4);
numBalls.ObservedValue = 1;
Console.WriteLine(engine.Infer(isBlue));
}
diff --git a/test/Tests/Properties/AssemblyInfo.cs b/test/Tests/Properties/AssemblyInfo.cs
index ea12f586..067fe854 100644
--- a/test/Tests/Properties/AssemblyInfo.cs
+++ b/test/Tests/Properties/AssemblyInfo.cs
@@ -9,7 +9,7 @@ using System.Runtime.InteropServices;
// set of attributes. Change these attribute values to modify the information
// associated with an assembly.
-[assembly: AssemblyTitle("Infer2Tests")]
+[assembly: AssemblyTitle("Microsoft.ML.Probabilistic.Tests")]
[assembly: AssemblyDescription("")]
[assembly: AssemblyConfiguration("")]
[assembly: AssemblyCompany("Microsoft Research Limited")]
diff --git a/test/Tests/QuantileTests.cs b/test/Tests/QuantileTests.cs
index 1462839e..10e275e4 100644
--- a/test/Tests/QuantileTests.cs
+++ b/test/Tests/QuantileTests.cs
@@ -16,6 +16,31 @@ namespace Microsoft.ML.Probabilistic.Tests
public class QuantileTests
{
+        /// <summary>
+        /// Checks GetQuantile and GetProbLessThan against each other on a three-point table.
+        /// </summary>
+        [Fact]
+        public void IrregularQuantilesTest()
+        {
+            var iq = new IrregularQuantiles(new double[] { 0, 0.4, 1 }, new double[] { 3, 4, 5 });
+            Assert.Equal(3.25, iq.GetQuantile(0.1));
+            Assert.Equal(0.1, iq.GetProbLessThan(3.25));
+            CheckGetQuantile(iq, iq);
+        }
+
+        /// <summary>
+        /// Infinite quantiles must be rejected by the constructor.
+        /// </summary>
+        [Fact]
+        public void IrregularQuantiles_InfinityTest()
+        {
+            // The constructor throws, so nothing placed after it inside the delegate would ever run.
+            Assert.Throws<System.ArgumentOutOfRangeException>(() =>
+            {
+                new IrregularQuantiles(new double[] { 0, 0.4, 1 }, new double[] { double.NegativeInfinity, 4, double.PositiveInfinity });
+            });
+        }
+
[Fact]
public void QuantileEstimator_SinglePointIsMedian()
{
@@ -279,11 +299,11 @@ namespace Microsoft.ML.Probabilistic.Tests
foreach(double maximumError in new[] { 0.05, 0.01, 0.005, 0.001 })
{
int n = (int)(2.0 / maximumError);
- QuantileEstimator(maximumError, n);
+ QuantileEstimatorTester(maximumError, n);
}
}
- private void QuantileEstimator(double maximumError, int n)
+ private void QuantileEstimatorTester(double maximumError, int n)
{
// draw many samples from N(m,v)
Rand.Restart(0);