зеркало из https://github.com/dotnet/infer.git
Added IrregularQuantiles. (#93)
* Added IrregularQuantiles * IrregularQuantiles throws ArgumentOutOfRangeException * Added Region.Equals and CompareTo * BlogTests.Handedness is a test * BallCountingTest comment * More documentation and testing
This commit is contained in:
Родитель
13def19155
Коммит
9f20867b92
|
@ -4,6 +4,7 @@
|
|||
|
||||
namespace Microsoft.ML.Probabilistic.Math
|
||||
{
|
||||
using Microsoft.ML.Probabilistic.Utilities;
|
||||
using System;
|
||||
|
||||
/// <summary>
|
||||
|
@ -11,7 +12,7 @@ namespace Microsoft.ML.Probabilistic.Math
|
|||
/// </summary>
|
||||
public class Region
|
||||
{
|
||||
public Vector Lower, Upper;
|
||||
public readonly Vector Lower, Upper;
|
||||
|
||||
public int Dimension
|
||||
{
|
||||
|
@ -113,5 +114,35 @@ namespace Microsoft.ML.Probabilistic.Math
|
|||
{
|
||||
return string.Format("[{0},{1}]", Lower.ToString(format), Upper.ToString(format));
|
||||
}
|
||||
|
||||
/// <summary>
/// Tests whether <paramref name="obj"/> is a Region whose Lower and Upper
/// compare equal (via the Vector == operator) to those of this instance.
/// </summary>
/// <param name="obj">The object to compare against; may be null or of another type.</param>
/// <returns>False when obj is not a Region; otherwise the result of comparing both bounds.</returns>
public override bool Equals(object obj)
{
    Region other = obj as Region;
    if (other == null)
    {
        return false;
    }

    // Upper is only examined once Lower is known to match.
    if (!(other.Lower == this.Lower))
    {
        return false;
    }

    return other.Upper == this.Upper;
}
|
||||
|
||||
/// <summary>
/// Combines the hash codes of Lower and Upper into one hash code.
/// </summary>
public override int GetHashCode()
{
    int lowerHash = Lower.GetHashCode();
    int upperHash = Upper.GetHashCode();
    return Hash.Combine(lowerHash, upperHash);
}
|
||||
|
||||
/// <summary>
/// Orders regions by their Lower vectors, using the Upper vectors to break ties.
/// </summary>
/// <param name="other">The region to compare with.</param>
/// <returns>Negative, zero, or positive, following the element-wise vector comparison.</returns>
public int CompareTo(Region other)
{
    int byLower = CompareTo(Lower, other.Lower);
    if (byLower != 0)
    {
        return byLower;
    }

    return CompareTo(Upper, other.Upper);
}
|
||||
|
||||
/// <summary>
/// Compares two vectors element by element, in index order.
/// </summary>
/// <param name="a">The first vector; its Count bounds the loop.</param>
/// <param name="b">The second vector, indexed at the same positions as <paramref name="a"/>.</param>
/// <returns>The first non-zero element comparison, or zero if every compared element matches.</returns>
public int CompareTo(Vector a, Vector b)
{
    for (int dim = 0; dim < a.Count; dim++)
    {
        int order = a[dim].CompareTo(b[dim]);
        if (order != 0)
        {
            return order;
        }
    }

    // Either a has no elements or every compared pair was equal.
    return 0;
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -33,6 +33,8 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
|||
{
|
||||
if (quantiles == null) throw new ArgumentNullException(nameof(quantiles));
|
||||
if (quantiles.Length == 0) throw new ArgumentException("quantiles array is empty", nameof(quantiles));
|
||||
OuterQuantiles.AssertFinite(quantiles, nameof(quantiles));
|
||||
OuterQuantiles.AssertNondecreasing(quantiles, nameof(quantiles));
|
||||
this.quantiles = quantiles;
|
||||
lowerGaussian = GetLowerGaussian(quantiles);
|
||||
upperGaussian = GetUpperGaussian(quantiles);
|
||||
|
|
|
@ -0,0 +1,106 @@
|
|||
using Microsoft.ML.Probabilistic.Math;
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Microsoft.ML.Probabilistic.Distributions
|
||||
{
|
||||
/// <summary>
|
||||
/// Represents a distribution using the quantiles at arbitrary probabilities.
|
||||
/// </summary>
|
||||
public class IrregularQuantiles : CanGetQuantile, CanGetProbLessThan
{
    /// <summary>
    /// Sorted in ascending order.  quantiles[i] is the quantile at probabilities[i].
    /// </summary>
    private readonly double[] probabilities, quantiles;

    /// <summary>
    /// Creates a distribution from matching arrays of probabilities and quantiles.
    /// The arrays are stored without copying.
    /// </summary>
    /// <param name="probabilities">Strictly increasing numbers in [0,1].</param>
    /// <param name="quantiles">Finite numbers in non-decreasing order, one per probability.</param>
    /// <exception cref="ArgumentNullException">Either array is null.</exception>
    /// <exception cref="ArgumentException">
    /// probabilities is not strictly increasing, quantiles is decreasing somewhere,
    /// the arrays are empty, or their lengths differ.
    /// </exception>
    /// <exception cref="ArgumentOutOfRangeException">
    /// A probability is outside [0,1] or NaN, or a quantile is infinite or NaN.
    /// </exception>
    public IrregularQuantiles(double[] probabilities, double[] quantiles)
    {
        if (probabilities == null) throw new ArgumentNullException(nameof(probabilities));
        if (quantiles == null) throw new ArgumentNullException(nameof(quantiles));
        AssertIncreasing(probabilities, nameof(probabilities));
        AssertInRange(probabilities, nameof(probabilities));
        OuterQuantiles.AssertNondecreasing(quantiles, nameof(quantiles));
        OuterQuantiles.AssertFinite(quantiles, nameof(quantiles));
        // These checks run after the element checks so that any input rejected before
        // still fails with the same exception type.  Without them, GetQuantile and
        // GetProbLessThan would index past the end of the shorter (or empty) array.
        if (probabilities.Length == 0) throw new ArgumentException("probabilities array is empty", nameof(probabilities));
        if (quantiles.Length != probabilities.Length)
        {
            throw new ArgumentException($"quantiles.Length ({quantiles.Length}) != probabilities.Length ({probabilities.Length})", nameof(quantiles));
        }
        this.probabilities = probabilities;
        this.quantiles = quantiles;
    }

    /// <summary>
    /// Throws if any element is less than or equal to its predecessor.
    /// </summary>
    /// <param name="array">The array to check.</param>
    /// <param name="paramName">The parameter name reported in the exception.</param>
    private static void AssertIncreasing(double[] array, string paramName)
    {
        for (int i = 1; i < array.Length; i++)
        {
            if (array[i] <= array[i - 1]) throw new ArgumentException($"Array is not increasing: {paramName}[{i}] {array[i]} <= {paramName}[{i - 1}] {array[i - 1]}", paramName);
        }
    }

    /// <summary>
    /// Throws if any element is below 0, above 1, or NaN.
    /// </summary>
    /// <param name="array">The array to check.</param>
    /// <param name="paramName">The parameter name reported in the exception.</param>
    private static void AssertInRange(double[] array, string paramName)
    {
        for (int i = 0; i < array.Length; i++)
        {
            if (array[i] < 0) throw new ArgumentOutOfRangeException(paramName, $"{paramName}[{i}] {array[i]} < 0");
            if (array[i] > 1) throw new ArgumentOutOfRangeException(paramName, $"{paramName}[{i}] {array[i]} > 1");
            // NaN fails both comparisons above, so it needs its own test.
            if (double.IsNaN(array[i])) throw new ArgumentOutOfRangeException(paramName, $"{paramName}[{i}] {array[i]}");
        }
    }

    /// <summary>
    /// Returns the quantile rank of x.  This is a probability such that GetQuantile(probability) == x, whenever x is inside the support of the distribution.  May be discontinuous due to duplicates.
    /// </summary>
    /// <param name="x">The value whose quantile rank is requested.</param>
    /// <returns>
    /// A real number in [0,1]: 0 when x is below the smallest stored quantile,
    /// 1 when x is above the largest, otherwise a stored or interpolated probability.
    /// </returns>
    public double GetProbLessThan(double x)
    {
        int index = Array.BinarySearch(quantiles, x);
        if (index >= 0)
        {
            // In case of duplicates, find the smallest copy.
            while (index > 0 && quantiles[index - 1] == x) index--;
            return probabilities[index];
        }
        else
        {
            // Not found: ~index is the position of the first stored quantile larger than x,
            // or quantiles.Length when there is none.
            int largerIndex = ~index;
            if (largerIndex == 0) return 0;
            if (largerIndex == quantiles.Length) return 1;
            // Linear interpolation between the two neighbouring points.
            int smallerIndex = largerIndex - 1;
            double slope = (quantiles[largerIndex] - quantiles[smallerIndex]) / (probabilities[largerIndex] - probabilities[smallerIndex]);
            return probabilities[smallerIndex] + (x - quantiles[smallerIndex]) / slope;
        }
    }

    /// <summary>
    /// Returns the largest value x such that GetProbLessThan(x) &lt;= probability.
    /// </summary>
    /// <param name="probability">A real number in [0,1].</param>
    /// <returns>
    /// The stored quantile when probability matches a stored probability exactly;
    /// the first or last stored quantile when probability lies outside the stored range;
    /// otherwise a value interpolated between the two neighbouring quantiles.
    /// </returns>
    /// <exception cref="ArgumentOutOfRangeException">probability is less than 0 or greater than 1.</exception>
    public double GetQuantile(double probability)
    {
        if (probability < 0) throw new ArgumentOutOfRangeException(nameof(probability), "probability < 0");
        if (probability > 1.0) throw new ArgumentOutOfRangeException(nameof(probability), "probability > 1.0");
        // Array.BinarySearch returns the index of the item if it is found;
        // otherwise, a negative number that is the bitwise complement of the index of the next element that is larger than the item
        // or, if there is no larger element, the bitwise complement of the array length.
        int index = Array.BinarySearch(probabilities, probability);
        if (index >= 0)
        {
            return quantiles[index];
        }
        else
        {
            // Linear interpolation
            int largerIndex = ~index;
            if (largerIndex == 0) return quantiles[largerIndex];
            int smallerIndex = largerIndex - 1;
            if (largerIndex == probabilities.Length) return quantiles[smallerIndex];
            double slope = (quantiles[largerIndex] - quantiles[smallerIndex]) / (probabilities[largerIndex] - probabilities[smallerIndex]);
            // Solve for the largest x such that probabilities[smallerIndex] + (x - quantiles[smallerIndex]) / slope <= probability.
            double frac = MMath.LargestDoubleSum(-probabilities[smallerIndex], probability);
            double offset = MMath.LargestDoubleProduct(slope, frac);
            return MMath.LargestDoubleSum(quantiles[smallerIndex], offset);
        }
    }
}
|
||||
}
|
|
@ -19,13 +19,32 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
|||
/// <summary>
|
||||
/// Numbers in increasing order.
|
||||
/// </summary>
|
||||
private double[] quantiles;
|
||||
private readonly double[] quantiles;
|
||||
|
||||
/// <summary>
/// Creates a distribution from the given quantiles.  The array is stored without copying.
/// </summary>
/// <param name="quantiles">Finite numbers in non-decreasing order.</param>
/// <exception cref="ArgumentNullException">quantiles is null.</exception>
/// <exception cref="ArgumentException">quantiles decreases somewhere.</exception>
/// <exception cref="ArgumentOutOfRangeException">quantiles contains an infinity or NaN.</exception>
public OuterQuantiles(double[] quantiles)
{
    // Report a null argument by name instead of failing inside the checks below.
    if (quantiles == null) throw new ArgumentNullException(nameof(quantiles));
    AssertNondecreasing(quantiles, nameof(quantiles));
    AssertFinite(quantiles, nameof(quantiles));
    this.quantiles = quantiles;
}
|
||||
|
||||
/// <summary>
/// Throws if any element of the array is infinite or NaN.
/// </summary>
/// <param name="array">The array to check.</param>
/// <param name="paramName">The parameter name reported in the exception.</param>
/// <exception cref="ArgumentOutOfRangeException">An element is infinite or NaN.</exception>
internal static void AssertFinite(double[] array, string paramName)
{
    int position = 0;
    foreach (double value in array)
    {
        // Both failure modes produce the same message, so they share one throw.
        if (double.IsInfinity(value) || double.IsNaN(value))
        {
            throw new ArgumentOutOfRangeException(paramName, $"{paramName}[{position}] {value}");
        }

        position++;
    }
}
|
||||
|
||||
/// <summary>
/// Throws if any element of the array is smaller than its predecessor.
/// </summary>
/// <param name="array">The array to check.</param>
/// <param name="paramName">The parameter name reported in the exception.</param>
/// <exception cref="ArgumentException">An element is less than the one before it.</exception>
internal static void AssertNondecreasing(double[] array, string paramName)
{
    int index = 1;
    while (index < array.Length)
    {
        double previous = array[index - 1];
        double current = array[index];
        if (current < previous)
        {
            throw new ArgumentException($"Array is not non-decreasing: {paramName}[{index}] {current} < {paramName}[{index - 1}] {previous}", paramName);
        }

        index++;
    }
}
|
||||
|
||||
public OuterQuantiles(int quantileCount, CanGetQuantile canGetQuantile)
|
||||
{
|
||||
this.quantiles = new double[quantileCount];
|
||||
|
@ -99,7 +118,7 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
|||
double pos = MMath.LargestDoubleProduct(n - 1, probability);
|
||||
int lower = (int)Math.Floor(pos);
|
||||
if (lower == n - 1) return quantiles[lower];
|
||||
return GetQuantile(probability, lower, quantiles[lower], quantiles[lower+1], n);
|
||||
return GetQuantile(probability, lower, quantiles[lower], quantiles[lower + 1], n);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
|
|
@ -2317,7 +2317,8 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
Console.WriteLine("P(woman's height|isTaller) = {0}", engine.Infer(heightWoman));
|
||||
}
|
||||
|
||||
internal void Handedness()
|
||||
[Fact]
|
||||
public void Handedness()
|
||||
{
|
||||
bool[] studentData = {false, true, true, true, true, true, true, true, false, false};
|
||||
bool[] lecturerData = {false, true, true, true, true, true, true, true, true, true};
|
||||
|
@ -2339,7 +2340,9 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
// -----------------------------------
|
||||
InferenceEngine engine = new InferenceEngine();
|
||||
//Console.WriteLine("isRightHanded = {0}", engine.Infer(isRightHanded));
|
||||
Console.WriteLine("probRightHanded = {0}", engine.Infer(probRightHanded));
|
||||
var probRightHandedExpected = new Beta(7.72, 3.08);
|
||||
var probRightHandedActual = engine.Infer<Beta>(probRightHanded);
|
||||
Assert.True(probRightHandedExpected.MaxDiff(probRightHandedActual) < 1e-4);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
|
|
|
@ -85,6 +85,7 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
VariableArray<bool> observedBlue = Variable.Array<bool>(draw).Named("observedBlue");
|
||||
using (Variable.ForEach(draw))
|
||||
{
|
||||
// cannot use Variable.DiscreteUniform(numBalls) here since ballIndex will get the wrong range.
|
||||
Variable<int> ballIndex = Variable.DiscreteUniform(ball, numBalls).Named("ballIndex");
|
||||
using (Variable.Switch(ballIndex))
|
||||
{
|
||||
|
@ -100,10 +101,10 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
// 16 iters with good schedule
|
||||
// 120 iters with bad schedule
|
||||
engine.NumberOfIterations = 150;
|
||||
Discrete numUsersActual = engine.Infer<Discrete>(numBalls);
|
||||
Console.WriteLine("numBalls = {0}", numUsersActual);
|
||||
Discrete numUsersExpected = new Discrete(0, 0.5079, 0.3097, 0.09646, 0.03907, 0.02015, 0.01225, 0.008336, 0.006133);
|
||||
Assert.True(numUsersExpected.MaxDiff(numUsersActual) < 1e-4);
|
||||
Discrete numBallsActual = engine.Infer<Discrete>(numBalls);
|
||||
Console.WriteLine("numBalls = {0}", numBallsActual);
|
||||
Discrete numBallsExpected = new Discrete(0, 0.5079, 0.3097, 0.09646, 0.03907, 0.02015, 0.01225, 0.008336, 0.006133);
|
||||
Assert.True(numBallsExpected.MaxDiff(numBallsActual) < 1e-4);
|
||||
numBalls.ObservedValue = 1;
|
||||
Console.WriteLine(engine.Infer(isBlue));
|
||||
}
|
||||
|
|
|
@ -9,7 +9,7 @@ using System.Runtime.InteropServices;
|
|||
// set of attributes. Change these attribute values to modify the information
|
||||
// associated with an assembly.
|
||||
|
||||
[assembly: AssemblyTitle("Infer2Tests")]
|
||||
[assembly: AssemblyTitle("Microsoft.ML.Probabilistic.Tests")]
|
||||
[assembly: AssemblyDescription("")]
|
||||
[assembly: AssemblyConfiguration("")]
|
||||
[assembly: AssemblyCompany("Microsoft Research Limited")]
|
||||
|
|
|
@ -16,6 +16,26 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
|
||||
public class QuantileTests
|
||||
{
|
||||
[Fact]
public void IrregularQuantilesTest()
{
    double[] probabilities = { 0, 0.4, 1 };
    double[] quantiles = { 3, 4, 5 };
    IrregularQuantiles distribution = new IrregularQuantiles(probabilities, quantiles);

    // 0.1 lies a quarter of the way from probability 0 to 0.4,
    // so the quantile lies a quarter of the way from 3 to 4.
    double quantile = distribution.GetQuantile(0.1);
    Assert.Equal(3.25, quantile);

    double rank = distribution.GetProbLessThan(3.25);
    Assert.Equal(0.1, rank);

    CheckGetQuantile(distribution, distribution);
}
|
||||
|
||||
[Fact]
public void IrregularQuantiles_InfinityTest()
{
    double[] probabilities = { 0, 0.4, 1 };
    double[] infiniteQuantiles = { double.NegativeInfinity, 4, double.PositiveInfinity };

    // The constructor itself must reject infinite quantiles.  Only the constructor
    // call is inside the lambda, so the test states which call is expected to throw.
    Assert.Throws<ArgumentOutOfRangeException>(() =>
    {
        new IrregularQuantiles(probabilities, infiniteQuantiles);
    });
}
|
||||
|
||||
[Fact]
|
||||
public void QuantileEstimator_SinglePointIsMedian()
|
||||
{
|
||||
|
@ -279,11 +299,11 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
foreach(double maximumError in new[] { 0.05, 0.01, 0.005, 0.001 })
|
||||
{
|
||||
int n = (int)(2.0 / maximumError);
|
||||
QuantileEstimator(maximumError, n);
|
||||
QuantileEstimatorTester(maximumError, n);
|
||||
}
|
||||
}
|
||||
|
||||
private void QuantileEstimator(double maximumError, int n)
|
||||
private void QuantileEstimatorTester(double maximumError, int n)
|
||||
{
|
||||
// draw many samples from N(m,v)
|
||||
Rand.Restart(0);
|
||||
|
|
Загрузка…
Ссылка в новой задаче