* Added IrregularQuantiles

* IrregularQuantiles throws ArgumentOutOfRangeException

* Added Region.Equals and CompareTo

* BlogTests.Handedness is a test

* BallCountingTest comment

* More documentation and testing
This commit is contained in:
Tom Minka 2018-12-12 18:18:20 +00:00 коммит произвёл GitHub
Родитель 13def19155
Коммит 9f20867b92
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
8 изменённых файлов: 194 добавлений и 12 удалений

Просмотреть файл

@ -4,6 +4,7 @@
namespace Microsoft.ML.Probabilistic.Math
{
using Microsoft.ML.Probabilistic.Utilities;
using System;
/// <summary>
@ -11,7 +12,7 @@ namespace Microsoft.ML.Probabilistic.Math
/// </summary>
public class Region
{
public Vector Lower, Upper;
public readonly Vector Lower, Upper;
public int Dimension
{
@ -113,5 +114,35 @@ namespace Microsoft.ML.Probabilistic.Math
{
return string.Format("[{0},{1}]", Lower.ToString(format), Upper.ToString(format));
}
public override bool Equals(object obj)
{
Region that = obj as Region;
if (that == null) return false;
return (that.Lower == this.Lower) && (that.Upper == this.Upper);
}
public override int GetHashCode()
{
return Hash.Combine(Lower.GetHashCode(), Upper.GetHashCode());
}
public int CompareTo(Region other)
{
int result = CompareTo(Lower, other.Lower);
if (result == 0) result = CompareTo(Upper, other.Upper);
return result;
}
public int CompareTo(Vector a, Vector b)
{
int result = 0;
for (int i = 0; i < a.Count; i++)
{
result = a[i].CompareTo(b[i]);
if (result != 0) return result;
}
return result;
}
}
}

Просмотреть файл

@ -33,6 +33,8 @@ namespace Microsoft.ML.Probabilistic.Distributions
{
if (quantiles == null) throw new ArgumentNullException(nameof(quantiles));
if (quantiles.Length == 0) throw new ArgumentException("quantiles array is empty", nameof(quantiles));
OuterQuantiles.AssertFinite(quantiles, nameof(quantiles));
OuterQuantiles.AssertNondecreasing(quantiles, nameof(quantiles));
this.quantiles = quantiles;
lowerGaussian = GetLowerGaussian(quantiles);
upperGaussian = GetUpperGaussian(quantiles);

Просмотреть файл

@ -0,0 +1,106 @@
using Microsoft.ML.Probabilistic.Math;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace Microsoft.ML.Probabilistic.Distributions
{
/// <summary>
/// Represents a distribution using the quantiles at arbitrary probabilities.
/// </summary>
public class IrregularQuantiles : CanGetQuantile, CanGetProbLessThan
{
/// <summary>
/// Sorted in ascending order.
/// </summary>
private readonly double[] probabilities, quantiles;
public IrregularQuantiles(double[] probabilities, double[] quantiles)
{
AssertIncreasing(probabilities, nameof(probabilities));
AssertInRange(probabilities, nameof(probabilities));
OuterQuantiles.AssertNondecreasing(quantiles, nameof(quantiles));
OuterQuantiles.AssertFinite(quantiles, nameof(quantiles));
this.probabilities = probabilities;
this.quantiles = quantiles;
}
private void AssertIncreasing(double[] array, string paramName)
{
for (int i = 1; i < array.Length; i++)
{
if (array[i] <= array[i - 1]) throw new ArgumentException($"Array is not increasing: {paramName}[{i}] {array[i]} <= {paramName}[{i - 1}] {array[i - 1]}", paramName);
}
}
private static void AssertInRange(double[] array, string paramName)
{
for (int i = 0; i < array.Length; i++)
{
if (array[i] < 0) throw new ArgumentOutOfRangeException(paramName, $"{paramName}[{i}] {array[i]} < 0");
if (array[i] > 1) throw new ArgumentOutOfRangeException(paramName, $"{paramName}[{i}] {array[i]} > 1");
if (double.IsNaN(array[i])) throw new ArgumentOutOfRangeException(paramName, $"{paramName}[{i}] {array[i]}");
}
}
/// <summary>
/// Returns the quantile rank of x. This is a probability such that GetQuantile(probability) == x, whenever x is inside the support of the distribution. May be discontinuous due to duplicates.
/// </summary>
/// <param name="x"></param>
/// <returns>A real number in [0,1]</returns>
public double GetProbLessThan(double x)
{
int index = Array.BinarySearch(quantiles, x);
if (index >= 0)
{
// In case of duplicates, find the smallest copy.
while (index > 0 && quantiles[index - 1] == x) index--;
return probabilities[index];
}
else
{
// Linear interpolation
int largerIndex = ~index;
if (largerIndex == 0) return 0;
if (largerIndex == quantiles.Length) return 1;
int smallerIndex = largerIndex - 1;
double slope = (quantiles[largerIndex] - quantiles[smallerIndex]) / (probabilities[largerIndex] - probabilities[smallerIndex]);
return probabilities[smallerIndex] + (x - quantiles[smallerIndex]) / slope;
}
}
/// <summary>
/// Returns the largest value x such that GetProbLessThan(x) &lt;= probability.
/// </summary>
/// <param name="probability">A real number in [0,1].</param>
/// <returns></returns>
public double GetQuantile(double probability)
{
if (probability < 0) throw new ArgumentOutOfRangeException(nameof(probability), "probability < 0");
if (probability > 1.0) throw new ArgumentOutOfRangeException(nameof(probability), "probability > 1.0");
// The zero-based index of item in the sorted List<T>, if item is found;
// otherwise, a negative number that is the bitwise complement of the index of the next element that is larger than item
// or, if there is no larger element, the bitwise complement of Count.
int index = Array.BinarySearch(probabilities, probability);
if(index >= 0)
{
return quantiles[index];
}
else
{
// Linear interpolation
int largerIndex = ~index;
if (largerIndex == 0) return quantiles[largerIndex];
int smallerIndex = largerIndex - 1;
if (largerIndex == probabilities.Length) return quantiles[smallerIndex];
double slope = (quantiles[largerIndex] - quantiles[smallerIndex]) / (probabilities[largerIndex] - probabilities[smallerIndex]);
// Solve for the largest x such that probabilities[smallerIndex] + (x - quantiles[smallerIndex]) / slope <= probability.
double frac = MMath.LargestDoubleSum(-probabilities[smallerIndex], probability);
double offset = MMath.LargestDoubleProduct(slope, frac);
return MMath.LargestDoubleSum(quantiles[smallerIndex], offset);
}
}
}
}

Просмотреть файл

@ -19,13 +19,32 @@ namespace Microsoft.ML.Probabilistic.Distributions
/// <summary>
/// Numbers in increasing order.
/// </summary>
private double[] quantiles;
private readonly double[] quantiles;
public OuterQuantiles(double[] quantiles)
{
AssertNondecreasing(quantiles, nameof(quantiles));
AssertFinite(quantiles, nameof(quantiles));
this.quantiles = quantiles;
}
internal static void AssertFinite(double[] array, string paramName)
{
for (int i = 0; i < array.Length; i++)
{
if (double.IsInfinity(array[i])) throw new ArgumentOutOfRangeException(paramName, $"{paramName}[{i}] {array[i]}");
if (double.IsNaN(array[i])) throw new ArgumentOutOfRangeException(paramName, $"{paramName}[{i}] {array[i]}");
}
}
internal static void AssertNondecreasing(double[] array, string paramName)
{
for (int i = 1; i < array.Length; i++)
{
if (array[i] < array[i - 1]) throw new ArgumentException($"Array is not non-decreasing: {paramName}[{i}] {array[i]} < {paramName}[{i - 1}] {array[i - 1]}", paramName);
}
}
public OuterQuantiles(int quantileCount, CanGetQuantile canGetQuantile)
{
this.quantiles = new double[quantileCount];
@ -99,7 +118,7 @@ namespace Microsoft.ML.Probabilistic.Distributions
double pos = MMath.LargestDoubleProduct(n - 1, probability);
int lower = (int)Math.Floor(pos);
if (lower == n - 1) return quantiles[lower];
return GetQuantile(probability, lower, quantiles[lower], quantiles[lower+1], n);
return GetQuantile(probability, lower, quantiles[lower], quantiles[lower + 1], n);
}
/// <summary>

Просмотреть файл

@ -2317,7 +2317,8 @@ namespace Microsoft.ML.Probabilistic.Tests
Console.WriteLine("P(woman's height|isTaller) = {0}", engine.Infer(heightWoman));
}
internal void Handedness()
[Fact]
public void Handedness()
{
bool[] studentData = {false, true, true, true, true, true, true, true, false, false};
bool[] lecturerData = {false, true, true, true, true, true, true, true, true, true};
@ -2339,7 +2340,9 @@ namespace Microsoft.ML.Probabilistic.Tests
// -----------------------------------
InferenceEngine engine = new InferenceEngine();
//Console.WriteLine("isRightHanded = {0}", engine.Infer(isRightHanded));
Console.WriteLine("probRightHanded = {0}", engine.Infer(probRightHanded));
var probRightHandedExpected = new Beta(7.72, 3.08);
var probRightHandedActual = engine.Infer<Beta>(probRightHanded);
Assert.True(probRightHandedExpected.MaxDiff(probRightHandedActual) < 1e-4);
}
[Fact]

Просмотреть файл

@ -85,6 +85,7 @@ namespace Microsoft.ML.Probabilistic.Tests
VariableArray<bool> observedBlue = Variable.Array<bool>(draw).Named("observedBlue");
using (Variable.ForEach(draw))
{
// cannot use Variable.DiscreteUniform(numBalls) here since ballIndex will get the wrong range.
Variable<int> ballIndex = Variable.DiscreteUniform(ball, numBalls).Named("ballIndex");
using (Variable.Switch(ballIndex))
{
@ -100,10 +101,10 @@ namespace Microsoft.ML.Probabilistic.Tests
// 16 iters with good schedule
// 120 iters with bad schedule
engine.NumberOfIterations = 150;
Discrete numUsersActual = engine.Infer<Discrete>(numBalls);
Console.WriteLine("numBalls = {0}", numUsersActual);
Discrete numUsersExpected = new Discrete(0, 0.5079, 0.3097, 0.09646, 0.03907, 0.02015, 0.01225, 0.008336, 0.006133);
Assert.True(numUsersExpected.MaxDiff(numUsersActual) < 1e-4);
Discrete numBallsActual = engine.Infer<Discrete>(numBalls);
Console.WriteLine("numBalls = {0}", numBallsActual);
Discrete numBallsExpected = new Discrete(0, 0.5079, 0.3097, 0.09646, 0.03907, 0.02015, 0.01225, 0.008336, 0.006133);
Assert.True(numBallsExpected.MaxDiff(numBallsActual) < 1e-4);
numBalls.ObservedValue = 1;
Console.WriteLine(engine.Infer(isBlue));
}

Просмотреть файл

@ -9,7 +9,7 @@ using System.Runtime.InteropServices;
// set of attributes. Change these attribute values to modify the information
// associated with an assembly.
[assembly: AssemblyTitle("Infer2Tests")]
[assembly: AssemblyTitle("Microsoft.ML.Probabilistic.Tests")]
[assembly: AssemblyDescription("")]
[assembly: AssemblyConfiguration("")]
[assembly: AssemblyCompany("Microsoft Research Limited")]

Просмотреть файл

@ -16,6 +16,26 @@ namespace Microsoft.ML.Probabilistic.Tests
public class QuantileTests
{
[Fact]
public void IrregularQuantilesTest()
{
var iq = new IrregularQuantiles(new double[] { 0, 0.4, 1 }, new double[] { 3, 4, 5 });
Assert.Equal(3.25, iq.GetQuantile(0.1));
Assert.Equal(0.1, iq.GetProbLessThan(3.25));
CheckGetQuantile(iq, iq);
}
[Fact]
public void IrregularQuantiles_InfinityTest()
{
Assert.Throws<ArgumentOutOfRangeException>(() =>
{
var iq = new IrregularQuantiles(new double[] { 0, 0.4, 1 }, new double[] { double.NegativeInfinity, 4, double.PositiveInfinity });
Assert.Equal(3.25, iq.GetQuantile(0.1));
Assert.Equal(0.1, iq.GetProbLessThan(3.25));
});
}
[Fact]
public void QuantileEstimator_SinglePointIsMedian()
{
@ -279,11 +299,11 @@ namespace Microsoft.ML.Probabilistic.Tests
foreach(double maximumError in new[] { 0.05, 0.01, 0.005, 0.001 })
{
int n = (int)(2.0 / maximumError);
QuantileEstimator(maximumError, n);
QuantileEstimatorTester(maximumError, n);
}
}
private void QuantileEstimator(double maximumError, int n)
private void QuantileEstimatorTester(double maximumError, int n)
{
// draw many samples from N(m,v)
Rand.Restart(0);