diff --git a/src/Runtime/Core/Maths/SpecialFunctions.cs b/src/Runtime/Core/Maths/SpecialFunctions.cs index ca69e050..bea93293 100644 --- a/src/Runtime/Core/Maths/SpecialFunctions.cs +++ b/src/Runtime/Core/Maths/SpecialFunctions.cs @@ -9,6 +9,7 @@ namespace Microsoft.ML.Probabilistic.Math using System.Diagnostics; using System.Linq; using System.Numerics; + using Microsoft.ML.Probabilistic.Collections; using Microsoft.ML.Probabilistic.Distributions; // for Gaussian.GetLogProb using Microsoft.ML.Probabilistic.Utilities; @@ -836,7 +837,18 @@ namespace Microsoft.ML.Probabilistic.Math else return (MMath.GammaLn(x + n) - MMath.GammaLn(x)) / n; } - private static double[] DigammaLookup; + const int DigammaTableLength = 100; + private static readonly Lazy DigammaTable = new Lazy(MakeDigammaTable); + + private static double[] MakeDigammaTable() + { + double[] table = new double[DigammaTableLength]; + table[0] = double.NegativeInfinity; + table[1] = Digamma1; + for (int i = 2; i < table.Length; i++) + table[i] = table[i - 1] + 1.0 / (i - 1); + return table; + } /// /// Evaluates Digamma(x), the derivative of ln(Gamma(x)). @@ -870,22 +882,10 @@ namespace Microsoft.ML.Probabilistic.Math } /* Lookup table for when x is an integer */ - const int tableLength = 100; int xAsInt = (int)x; - if ((xAsInt == x) && (xAsInt < tableLength)) + if ((xAsInt == x) && (xAsInt < DigammaTableLength)) { - if (DigammaLookup == null) - { - double[] table = new double[tableLength]; - table[0] = double.NegativeInfinity; - table[1] = Digamma1; - for (int i = 2; i < table.Length; i++) - table[i] = table[i - 1] + 1.0 / (i - 1); - // This is thread-safe because read/write to a reference type is atomic. See - // http://msdn.microsoft.com/en-us/library/aa691278%28VS.71%29.aspx - DigammaLookup = table; - } - return DigammaLookup[xAsInt]; + return DigammaTable.Value[xAsInt]; } if (x <= 2.5) @@ -1357,7 +1357,7 @@ namespace Microsoft.ML.Probabilistic.Math // ACM Transactions on Mathematical Software (TOMS) // Volume 12 Issue 4, Dec. 1986 // http://dl.acm.org/citation.cfm?id=23109 - private static double[] GammaLowerAsympt_C0 = + private static readonly double[] GammaLowerAsympt_C0 = { -.333333333333333333333333333333E+00, .833333333333333333333333333333E-01, @@ -2064,9 +2064,27 @@ namespace Microsoft.ML.Probabilistic.Math #region NormalCdf functions + const int NormalCdfMomentRatioTableSize = 200; + // [0] contains moments for x=-2 // [1] contains moments for x=-3, etc. - private static readonly double[][] NormalCdfMomentRatioTable = new double[7][]; + private static readonly Lazy NormalCdfMomentRatioTable = new Lazy(MakeNormalCdfMomentRatioTable); + + private static double[][] MakeNormalCdfMomentRatioTable() + { + return Util.ArrayInit(7, index => + { + double[] derivs = new double[NormalCdfMomentRatioTableSize]; + double x0 = -index - 2; + var iter = NormalCdfMomentRatioSequence(0, x0, true); + for (int i = 0; i < NormalCdfMomentRatioTableSize; i++) + { + iter.MoveNext(); + derivs[i] = iter.Current; + } + return derivs; + }); + } /// /// Computes int_0^infinity t^n N(t;x,1) dt / (n! N(x;0,1)) @@ -2076,27 +2094,14 @@ namespace Microsoft.ML.Probabilistic.Math /// public static double NormalCdfMomentRatio(int n, double x) { - const int tableSize = 200; const int maxTerms = 60; if (x >= -0.5) return NormalCdfMomentRatioRecurrence(n, x); - else if (n <= tableSize - maxTerms && x > -8) + else if (n <= NormalCdfMomentRatioTableSize - maxTerms && x > -8) { int index = (int)(-x - 1.5); // index ranges from 0 to 6 double x0 = -index - 2; - if (NormalCdfMomentRatioTable[index] == null) - { - double[] derivs = new double[tableSize]; - // this must not try to use the lookup table, since we are building it - var iter = NormalCdfMomentRatioSequence(0, x0, true); - for (int i = 0; i < derivs.Length; i++) - { - iter.MoveNext(); - derivs[i] = iter.Current; - } - NormalCdfMomentRatioTable[index] = derivs; - } - return NormalCdfMomentRatioTaylor(n, x - x0, NormalCdfMomentRatioTable[index]); + return NormalCdfMomentRatioTaylor(n, x - x0, NormalCdfMomentRatioTable.Value[index]); } else if (x > -2) { diff --git a/test/Tests/Core/FloatTests.cs b/test/Tests/Core/FloatTests.cs index d34e3b13..b4c94b9f 100644 --- a/test/Tests/Core/FloatTests.cs +++ b/test/Tests/Core/FloatTests.cs @@ -16,6 +16,21 @@ namespace Microsoft.ML.Probabilistic.Tests #if NETCOREAPP3_1 public class FloatTests { + /// + /// Tests whether two expressions are equal for all floating-point values. + /// + internal void EqualTest() + { + float x = float.Epsilon; + while(x < float.PositiveInfinity) + { + if ((float)(x + x + x) + x != 4 * x) throw new Exception(); + //if (x + x + x + x + x + x != 6 * x) throw new Exception(); + //if ((x + x) + (x + x) + (x + x) != 6 * x) throw new Exception(); + x = NextFloat(x); + } + } + internal void SincTest() { float x = float.Epsilon; @@ -49,21 +64,23 @@ namespace Microsoft.ML.Probabilistic.Tests public static float NextFloat(float value) { - if (value < 0) return -PreviousFloat(-value); - value = System.Math.Abs(value); // needed to handle -0 - if (float.IsNaN(value)) return value; - if (float.IsPositiveInfinity(value)) return value; - int bits = BitConverter.SingleToInt32Bits(value); - return BitConverter.Int32BitsToSingle(bits + 1); + return MathF.BitIncrement(value); + //if (value < 0) return -PreviousFloat(-value); + //value = System.Math.Abs(value); // needed to handle -0 + //if (float.IsNaN(value)) return value; + //if (float.IsPositiveInfinity(value)) return value; + //int bits = BitConverter.SingleToInt32Bits(value); + //return BitConverter.Int32BitsToSingle(bits + 1); } public static float PreviousFloat(float value) { - if (value <= 0) return -NextFloat(-value); - if (float.IsNaN(value)) return value; - if (float.IsNegativeInfinity(value)) return value; - int bits = BitConverter.SingleToInt32Bits(value); - return BitConverter.Int32BitsToSingle(bits - 1); + return MathF.BitDecrement(value); + //if (value <= 0) return -NextFloat(-value); + //if (float.IsNaN(value)) return value; + //if (float.IsNegativeInfinity(value)) return value; + //int bits = BitConverter.SingleToInt32Bits(value); + //return BitConverter.Int32BitsToSingle(bits - 1); } [Fact]