MMath.DigammaLookup and MMath.NormalCdfMomentRatioTable are Lazy

This commit is contained in:
Tom Minka 2020-07-08 15:00:59 +01:00
Родитель 659b224ce7
Коммит dbda600379
2 изменённых файлов: 65 добавлений и 43 удалений

Просмотреть файл

@ -9,6 +9,7 @@ namespace Microsoft.ML.Probabilistic.Math
using System.Diagnostics;
using System.Linq;
using System.Numerics;
using Microsoft.ML.Probabilistic.Collections;
using Microsoft.ML.Probabilistic.Distributions; // for Gaussian.GetLogProb
using Microsoft.ML.Probabilistic.Utilities;
@ -836,7 +837,18 @@ namespace Microsoft.ML.Probabilistic.Math
else return (MMath.GammaLn(x + n) - MMath.GammaLn(x)) / n;
}
private static double[] DigammaLookup;
const int DigammaTableLength = 100;
private static readonly Lazy<double[]> DigammaTable = new Lazy<double[]>(MakeDigammaTable);
private static double[] MakeDigammaTable()
{
double[] table = new double[DigammaTableLength];
table[0] = double.NegativeInfinity;
table[1] = Digamma1;
for (int i = 2; i < table.Length; i++)
table[i] = table[i - 1] + 1.0 / (i - 1);
return table;
}
/// <summary>
/// Evaluates Digamma(x), the derivative of ln(Gamma(x)).
@ -870,22 +882,10 @@ namespace Microsoft.ML.Probabilistic.Math
}
/* Lookup table for when x is an integer */
const int tableLength = 100;
int xAsInt = (int)x;
if ((xAsInt == x) && (xAsInt < tableLength))
if ((xAsInt == x) && (xAsInt < DigammaTableLength))
{
if (DigammaLookup == null)
{
double[] table = new double[tableLength];
table[0] = double.NegativeInfinity;
table[1] = Digamma1;
for (int i = 2; i < table.Length; i++)
table[i] = table[i - 1] + 1.0 / (i - 1);
// This is thread-safe because read/write to a reference type is atomic. See
// http://msdn.microsoft.com/en-us/library/aa691278%28VS.71%29.aspx
DigammaLookup = table;
}
return DigammaLookup[xAsInt];
return DigammaTable.Value[xAsInt];
}
if (x <= 2.5)
@ -1357,7 +1357,7 @@ namespace Microsoft.ML.Probabilistic.Math
// ACM Transactions on Mathematical Software (TOMS)
// Volume 12 Issue 4, Dec. 1986
// http://dl.acm.org/citation.cfm?id=23109
private static double[] GammaLowerAsympt_C0 =
private static readonly double[] GammaLowerAsympt_C0 =
{
-.333333333333333333333333333333E+00,
.833333333333333333333333333333E-01,
@ -2064,9 +2064,27 @@ namespace Microsoft.ML.Probabilistic.Math
#region NormalCdf functions
const int NormalCdfMomentRatioTableSize = 200;
// [0] contains moments for x=-2
// [1] contains moments for x=-3, etc.
private static readonly double[][] NormalCdfMomentRatioTable = new double[7][];
private static readonly Lazy<double[][]> NormalCdfMomentRatioTable = new Lazy<double[][]>(MakeNormalCdfMomentRatioTable);
private static double[][] MakeNormalCdfMomentRatioTable()
{
return Util.ArrayInit(7, index =>
{
double[] derivs = new double[NormalCdfMomentRatioTableSize];
double x0 = -index - 2;
var iter = NormalCdfMomentRatioSequence(0, x0, true);
for (int i = 0; i < NormalCdfMomentRatioTableSize; i++)
{
iter.MoveNext();
derivs[i] = iter.Current;
}
return derivs;
});
}
/// <summary>
/// Computes int_0^infinity t^n N(t;x,1) dt / (n! N(x;0,1))
@ -2076,27 +2094,14 @@ namespace Microsoft.ML.Probabilistic.Math
/// <returns></returns>
public static double NormalCdfMomentRatio(int n, double x)
{
const int tableSize = 200;
const int maxTerms = 60;
if (x >= -0.5)
return NormalCdfMomentRatioRecurrence(n, x);
else if (n <= tableSize - maxTerms && x > -8)
else if (n <= NormalCdfMomentRatioTableSize - maxTerms && x > -8)
{
int index = (int)(-x - 1.5); // index ranges from 0 to 6
double x0 = -index - 2;
if (NormalCdfMomentRatioTable[index] == null)
{
double[] derivs = new double[tableSize];
// this must not try to use the lookup table, since we are building it
var iter = NormalCdfMomentRatioSequence(0, x0, true);
for (int i = 0; i < derivs.Length; i++)
{
iter.MoveNext();
derivs[i] = iter.Current;
}
NormalCdfMomentRatioTable[index] = derivs;
}
return NormalCdfMomentRatioTaylor(n, x - x0, NormalCdfMomentRatioTable[index]);
return NormalCdfMomentRatioTaylor(n, x - x0, NormalCdfMomentRatioTable.Value[index]);
}
else if (x > -2)
{

Просмотреть файл

@ -16,6 +16,21 @@ namespace Microsoft.ML.Probabilistic.Tests
#if NETCOREAPP3_1
public class FloatTests
{
/// <summary>
/// Tests whether two expressions are equal for all floating-point values.
/// </summary>
internal void EqualTest()
{
float x = float.Epsilon;
while(x < float.PositiveInfinity)
{
if ((float)(x + x + x) + x != 4 * x) throw new Exception();
//if (x + x + x + x + x + x != 6 * x) throw new Exception();
//if ((x + x) + (x + x) + (x + x) != 6 * x) throw new Exception();
x = NextFloat(x);
}
}
internal void SincTest()
{
float x = float.Epsilon;
@ -49,21 +64,23 @@ namespace Microsoft.ML.Probabilistic.Tests
public static float NextFloat(float value)
{
if (value < 0) return -PreviousFloat(-value);
value = System.Math.Abs(value); // needed to handle -0
if (float.IsNaN(value)) return value;
if (float.IsPositiveInfinity(value)) return value;
int bits = BitConverter.SingleToInt32Bits(value);
return BitConverter.Int32BitsToSingle(bits + 1);
return MathF.BitIncrement(value);
//if (value < 0) return -PreviousFloat(-value);
//value = System.Math.Abs(value); // needed to handle -0
//if (float.IsNaN(value)) return value;
//if (float.IsPositiveInfinity(value)) return value;
//int bits = BitConverter.SingleToInt32Bits(value);
//return BitConverter.Int32BitsToSingle(bits + 1);
}
public static float PreviousFloat(float value)
{
if (value <= 0) return -NextFloat(-value);
if (float.IsNaN(value)) return value;
if (float.IsNegativeInfinity(value)) return value;
int bits = BitConverter.SingleToInt32Bits(value);
return BitConverter.Int32BitsToSingle(bits - 1);
return MathF.BitDecrement(value);
//if (value <= 0) return -NextFloat(-value);
//if (float.IsNaN(value)) return value;
//if (float.IsNegativeInfinity(value)) return value;
//int bits = BitConverter.SingleToInt32Bits(value);
//return BitConverter.Int32BitsToSingle(bits - 1);
}
[Fact]