Improved TruncatedGamma support (#222)

Improved accuracy of TruncatedGamma.GetMeanAndVariance, GetMeanPower, GetNormalizer, Sample
Improved accuracy of Factor.TruncatedGammaFromShapeAndRate
GammaFromShapeAndRateOp_Slow.SampleAverageConditional handles sample=0
GammaFromShapeAndRateOp_Slow.RateAverageConditional handles rate=0
Improved accuracy of MMath.ExpMinus1, Gamma, GammaUpper
Gamma.SetShapeAndScale throws if shape is infinite
Added TruncatedGamma.GetMode
Added MMath.ToStringExact, GammaUpperScale
Changed uses of the :r numeric format string to :g17
MMath.IndexOfMinimum takes IEnumerable
PowerOp supports GammaPower = TruncatedGamma ^ y
Tom Minka 2020-03-13 00:28:01 +00:00, committed by GitHub
Parent 085a5a4675
Commit 9b2cbe9619
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 1350 additions and 339 deletions
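
The snippet below is a hypothetical usage sketch, not part of the commit, exercising a few of the APIs listed above; the class and namespace names are taken from the diffs that follow.

// Illustrative only: assumes the Microsoft.ML.Probabilistic.Math namespace shown in the diff below.
using System;
using System.Linq;
using Microsoft.ML.Probabilistic.Math;

class TruncatedGammaSupportExamples
{
    static void Main()
    {
        // ToStringExact (new) prints the exact decimal value of a double,
        // unlike double.ToString which rounds.
        Console.WriteLine(MMath.ToStringExact(0.1));

        // GammaUpperScale(a, x) = x^a * exp(-x) / Gamma(a), now public.
        Console.WriteLine(MMath.GammaUpperScale(3.5, 2.0));

        // ExpMinus1 keeps full accuracy where Math.Exp(x) - 1 cancels.
        Console.WriteLine(MMath.ExpMinus1(1e-12));

        // IndexOfMinimum now accepts any IEnumerable<T>, not just IList<T>.
        int argmin = MMath.IndexOfMinimum(Enumerable.Range(0, 10).Select(i => Math.Abs(i - 6.3)));
        Console.WriteLine(argmin); // 6
    }
}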

View file

@ -1097,8 +1097,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
mpa.prototypeExpression = Builder.StaticMethod(new Func<Microsoft.ML.Probabilistic.Distributions.TruncatedGaussian>(Microsoft.ML.Probabilistic.Distributions.TruncatedGaussian.Uniform));
return mpa;
}
else if (Recognizer.IsStaticMethod(imie, new Func<double, double, double, double, double>(TruncatedGamma.Sample))
|| Recognizer.IsStaticMethod(imie, new Func<double, double, double, double, double>(Factor.TruncatedGammaFromShapeAndRate))
else if (Recognizer.IsStaticMethod(imie, new Func<double, double, double, double, double>(Factor.TruncatedGammaFromShapeAndRate))
)
{
MarginalPrototype mpa = new MarginalPrototype(null);

View file

@ -36,7 +36,7 @@ namespace Microsoft.ML.Probabilistic.Core.Maths
public override string ToString()
{
return $"{Mantissa:r}*exp({Exponent:r})";
return $"{Mantissa:g17}*exp({Exponent:g17})";
}
public static ExtendedDouble Zero()

View file

@ -2,17 +2,18 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.Probabilistic.Utilities;
using Microsoft.ML.Probabilistic.Distributions; // for Gaussian.GetLogProb
namespace Microsoft.ML.Probabilistic.Math
{
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Numerics;
using Microsoft.ML.Probabilistic.Collections;
using Microsoft.ML.Probabilistic.Core.Maths;
using Microsoft.ML.Probabilistic.Distributions; // for Gaussian.GetLogProb
using Microsoft.ML.Probabilistic.Utilities;
/// <summary>
/// This class provides mathematical constants and special functions,
@ -687,7 +688,6 @@ namespace Microsoft.ML.Probabilistic.Math
/// <returns>Gamma(x).</returns>
public static double Gamma(double x)
{
/* Negative values */
if (x < 0)
{
// this test also catches -inf
@ -700,6 +700,11 @@ namespace Microsoft.ML.Probabilistic.Math
return -Math.PI / (x * Math.Sin(Math.PI * x) * Gamma(-x));
}
if (x <= GammaSmallX)
{
return 1 / x + GammaSeries(x);
}
if (x > 180)
{
return Double.PositiveInfinity;
@ -708,6 +713,8 @@ namespace Microsoft.ML.Probabilistic.Math
return Math.Exp(GammaLn(x));
}
private const double GammaSmallX = 1e-3;
/// <summary>
/// Computes the natural logarithm of the Gamma function.
/// </summary>
@ -885,7 +892,7 @@ namespace Microsoft.ML.Probabilistic.Math
/* Shift the argument and use Taylor series near 1 if argument <= S */
if (x <= c_digamma_small)
{
return - 1 / x +
return -1 / x +
// Truncated series 2: Digamma at 1
// Generated automatically by /src/Tools/GenerateSeries/GenerateSeries.py
-0.577215664901532860606512090082 +
@ -1141,7 +1148,7 @@ namespace Microsoft.ML.Probabilistic.Math
/// <returns></returns>
public static double ChooseLn(double n, double k)
{
if (k <= -1 || k >= n+1)
if (k <= -1 || k >= n + 1)
return Double.NegativeInfinity;
return GammaLn(n + 1) - GammaLn(k + 1) - GammaLn(n - k + 1);
}
@ -1179,20 +1186,20 @@ namespace Microsoft.ML.Probabilistic.Math
throw new ArgumentException($"x ({x}) < 0");
if (!regularized)
{
if (a < 1) return GammaUpperConFrac(a, x, regularized);
if (a < 1 && x >= 1) return GammaUpperConFrac(a, x, regularized);
else if (a <= GammaSmallX) return GammaUpperSeries(a, x, regularized);
else return Gamma(a) * GammaUpper(a, x, true);
}
if (a <= 0)
throw new ArgumentException($"a ({a}) <= 0");
if (x == 0) return 1; // avoid 0/0
// Use the criterion from Gautschi (1979) to determine whether GammaLower(a,x) or GammaUpper(a,x) is smaller.
// useLower = true means that GammaLower is smaller.
bool useLower;
bool lowerIsSmaller;
if (x > 0.25)
useLower = (a > x + 0.25);
lowerIsSmaller = (a > x + 0.25);
else
useLower = (a > -MMath.Ln2 / Math.Log(x));
if (useLower)
lowerIsSmaller = (a > -MMath.Ln2 / Math.Log(x));
if (lowerIsSmaller)
{
if (x < 0.5 * a + 67)
return 1 - GammaLowerSeries(a, x);
@ -1203,6 +1210,9 @@ namespace Microsoft.ML.Probabilistic.Math
return GammaAsympt(a, x, true);
else if (x > 1.5)
return GammaUpperConFrac(a, x);
else if (a <= 1e-16)
// Gamma(a) = 1/a for a <= 1e-16
return a * GammaUpperSeries(a, x, false);
else
return GammaUpperSeries(a, x);
}
@ -1222,13 +1232,12 @@ namespace Microsoft.ML.Probabilistic.Math
throw new ArgumentException($"a ({a}) <= 0");
if (x == 0) return 0; // avoid 0/0
// Use the criterion from Gautschi (1979) to determine whether GammaLower(a,x) or GammaUpper(a,x) is smaller.
// useLower = true means that GammaLower is smaller.
bool useLower;
bool lowerIsSmaller;
if (x > 0.25)
useLower = (a > x + 0.25);
lowerIsSmaller = (a > x + 0.25);
else
useLower = (a > -MMath.Ln2 / Math.Log(x));
if (useLower)
lowerIsSmaller = (a > -MMath.Ln2 / Math.Log(x));
if (lowerIsSmaller)
{
if (x < 0.5 * a + 67)
return GammaLowerSeries(a, x);
@ -1243,6 +1252,31 @@ namespace Microsoft.ML.Probabilistic.Math
return 1 - GammaUpperSeries(a, x);
}
/// <summary>
/// Computes <c>x - log(1+x)</c> to high accuracy.
/// </summary>
/// <param name="x">Any real number &gt;= -1</param>
/// <returns>A real number &gt;= 0</returns>
private static double XMinusLog1Plus(double x)
{
if (Math.Abs(x) < 1e-1)
{
return
// Truncated series 12: x - log(1 + x)
// Generated automatically by /src/Tools/GenerateSeries/GenerateSeries.py
x * x * (1.0 / 2.0 +
x * (-1.0 / 3.0 +
x * (1.0 / 4.0 +
x * (-1.0 / 5.0 +
x * (1.0 / 6.0
)))));
}
else
{
return x - MMath.Log1Plus(x);
}
}
// Reference:
// "Computation of the incomplete gamma function ratios and their inverse"
// Armido R DiDonato and Alfred H Morris, Jr.
@ -1256,7 +1290,7 @@ namespace Microsoft.ML.Probabilistic.Math
/// Compute the regularized lower incomplete Gamma function: <c>int_0^x t^(a-1) exp(-t) dt / Gamma(a)</c>
/// </summary>
/// <param name="a">Must be &gt; 20</param>
/// <param name="x"></param>
/// <param name="x">A real number &gt;= 0</param>
/// <param name="upper">If true, compute the upper incomplete Gamma function</param>
/// <returns></returns>
private static double GammaAsympt(double a, double x, bool upper)
@ -1264,38 +1298,28 @@ namespace Microsoft.ML.Probabilistic.Math
if (a <= 20)
throw new Exception("a <= 20");
double xOverAMinus1 = (x - a) / a;
double phi;
if (Math.Abs(xOverAMinus1) < 1e-1)
{
phi =
// Truncated series 12: x - log(1 + x)
// Generated automatically by /src/Tools/GenerateSeries/GenerateSeries.py
xOverAMinus1 * xOverAMinus1 * (1.0 / 2.0 +
xOverAMinus1 * (-1.0 / 3.0 +
xOverAMinus1 * (1.0 / 4.0 +
xOverAMinus1 * (-1.0 / 5.0 +
xOverAMinus1 * 1.0 / 6.0))))
;
}
else
{
phi = xOverAMinus1 - MMath.Log1Plus(xOverAMinus1);
}
double phi = XMinusLog1Plus(xOverAMinus1);
// phi >= 0
double y = a * phi;
double z = Math.Sqrt(2 * phi);
if (x <= a)
z *= -1;
double[] b = new double[GammaLowerAsympt_C0.Length];
b[b.Length - 1] = GammaLowerAsympt_C0[b.Length - 1];
double sum = b[b.Length - 1];
b[b.Length - 2] = GammaLowerAsympt_C0[b.Length - 2];
sum = z * sum + b[b.Length - 2];
for (int i = b.Length - 3; i >= 0; i--)
int length = GammaLowerAsympt_C0.Length;
double bEven = GammaLowerAsympt_C0[length - 1];
double sum = bEven;
double bOdd = GammaLowerAsympt_C0[length - 2];
sum = z * sum + bOdd;
for (int i = length - 3; i >= 0; i -= 2)
{
b[i] = b[i + 2] * (i + 2) / a + GammaLowerAsympt_C0[i];
sum = z * sum + b[i];
bEven = bEven * (i + 2) / a + GammaLowerAsympt_C0[i];
sum = z * sum + bEven;
if (i > 0)
{
bOdd = bOdd * (i + 1) / a + GammaLowerAsympt_C0[i - 1];
sum = z * sum + bOdd;
}
}
sum *= a / (a + b[1]);
sum *= a / (a + bOdd);
if (x <= a)
sum *= -1;
double result = 0.5 * Erfc(Math.Sqrt(y)) + sum * Math.Exp(-y) / (Math.Sqrt(a) * MMath.Sqrt2PI);
@ -1366,47 +1390,90 @@ namespace Microsoft.ML.Probabilistic.Math
if (AreEqual(sum, oldSum))
return sum * scale;
}
throw new Exception(string.Format("GammaLowerSeries not converging for a={0} x={1}", a, x));
throw new Exception($"GammaLowerSeries not converging for a={a:g17} x={x:g17}");
}
/// <summary>
/// Compute the regularized upper incomplete Gamma function by a series expansion
/// Compute the upper incomplete Gamma function by a series expansion
/// </summary>
/// <param name="a">The shape parameter, &gt; 0</param>
/// <param name="x">The lower bound of the integral, &gt;= 0</param>
/// <param name="regularized">If true, result is divided by Gamma(a)</param>
/// <returns></returns>
private static double GammaUpperSeries(double a, double x)
private static double GammaUpperSeries(double a, double x, bool regularized = true)
{
// this series should only be applied when x is small
// the series is: 1 - x^a sum_{k=0}^inf (-x)^k /(k! Gamma(a+k+1))
// = (1 - 1/Gamma(a+1)) + (1 - x^a)/Gamma(a+1) - x^a sum_{k=1}^inf (-x)^k/(k! Gamma(a+k+1))
double xaMinus1 = MMath.ExpMinus1(a * Math.Log(x));
double aReciprocalFactorial, aReciprocalFactorialMinus1;
if (a > 0.3)
// the regularized series is: 1 - x^a/Gamma(a) sum_{k=0}^inf (-x)^k /(k! (a+k))
// = (1 - 1/Gamma(a+1)) + (1 - x^a)/Gamma(a+1) - x^a/Gamma(a) sum_{k=1}^inf (-x)^k/(k! (a+k))
// The unregularized series is:
// = (Gamma(a) - 1/a) + (1 - x^a)/a - x^a sum_{k=1}^inf (-x)^k/(k! (a+k))
double logx = Math.Log(x);
double alogx = a * logx;
double xaMinus1 = MMath.ExpMinus1(alogx);
double offset, scale, term;
if (regularized)
{
aReciprocalFactorial = 1 / MMath.Gamma(a + 1);
aReciprocalFactorialMinus1 = aReciprocalFactorial - 1;
double aReciprocalFactorial, aReciprocalFactorialMinus1;
if (a > 0.3)
{
aReciprocalFactorial = 1 / MMath.Gamma(a + 1);
aReciprocalFactorialMinus1 = aReciprocalFactorial - 1;
}
else
{
aReciprocalFactorialMinus1 = ReciprocalFactorialMinus1(a);
aReciprocalFactorial = 1 + aReciprocalFactorialMinus1;
}
// offset = 1 - x^a/Gamma(a+1)
offset = -xaMinus1 * aReciprocalFactorial - aReciprocalFactorialMinus1;
scale = 1 - offset;
term = x * a;
}
else
{
aReciprocalFactorialMinus1 = ReciprocalFactorialMinus1(a);
aReciprocalFactorial = 1 + aReciprocalFactorialMinus1;
if (Math.Abs(alogx) <= 1e-16) offset = GammaSeries(a) - logx;
else offset = GammaSeries(a) - xaMinus1 / a;
scale = 1 + xaMinus1;
term = x;
}
// offset = 1 - x^a/Gamma(a+1)
double offset = -xaMinus1 * aReciprocalFactorial - aReciprocalFactorialMinus1;
double scale = 1 - offset;
double term = x / (a + 1) * a;
double sum = term;
double sum = term / (a + 1);
for (int i = 1; i < 1000; i++)
{
term *= -(a + i) * x / ((a + i + 1) * (i + 1));
term *= -x / (i + 1);
double sumOld = sum;
sum += term;
sum += term / (a + i + 1);
//Console.WriteLine("{0}: {1}", i, sum);
if (AreEqual(sum, sumOld))
{
return scale * sum + offset;
}
}
throw new Exception(string.Format("GammaUpperSeries not converging for a={0} x={1}", a, x));
throw new Exception($"GammaUpperSeries not converging for a={a:g17} x={x:g17} regularized={regularized}");
}
/// <summary>
/// Compute <c>Gamma(x) - 1/x</c> to high accuracy
/// </summary>
/// <param name="x">A real number &gt;= 0</param>
/// <returns></returns>
private static double GammaSeries(double x)
{
if (x > GammaSmallX)
return MMath.Gamma(x) - 1 / x;
else
/* using http://sagecell.sagemath.org/ (must not be indented)
var('x');
f = gamma(x)-1/x
[c[0].n(100) for c in f.taylor(x,0,16).coefficients()]
*/
return Digamma1
+ x * (0.98905599532797255539539565150
+ x * (-0.90747907608088628901656016736
+ x * (0.98172808683440018733638029402
+ x * (-0.98199506890314520210470141379
+ x * (0.99314911462127619315386725333
+ x * (-0.99600176044243153397007841966
))))));
}
/// <summary>
@ -1447,20 +1514,32 @@ namespace Microsoft.ML.Probabilistic.Math
/// <param name="a">A positive real number</param>
/// <param name="x"></param>
/// <returns></returns>
private static double GammaUpperScale(double a, double x)
public static double GammaUpperScale(double a, double x)
{
return Math.Exp(GammaUpperLogScale(a, x));
}
/// <summary>
/// Computes <c>log(x^a e^(-x)/Gamma(a))</c> to high accuracy.
/// </summary>
/// <param name="a">A positive real number</param>
/// <param name="x"></param>
/// <returns></returns>
public static double GammaUpperLogScale(double a, double x)
{
if (double.IsPositiveInfinity(x) || double.IsPositiveInfinity(a))
return 0;
double scale;
return double.NegativeInfinity;
if (a < 10)
scale = Math.Exp(a * Math.Log(x) - x - GammaLn(a));
{
return a * Math.Log(x) - x - GammaLn(a);
}
else
{
double xia = x / a;
double phi = xia - 1 - Math.Log(xia);
scale = Math.Exp(0.5 * Math.Log(a) - MMath.LnSqrt2PI - GammaLnSeries(a) - a * phi);
// Result is inaccurate for a=100, x=3
double xOverAMinus1 = (x - a) / a;
double phi = XMinusLog1Plus(xOverAMinus1);
return 0.5 * Math.Log(a) - MMath.LnSqrt2PI - GammaLnSeries(a) - a * phi;
}
return scale;
}
// Origin: James McCaffrey, http://msdn.microsoft.com/en-us/magazine/dn520240.aspx
@ -1469,7 +1548,7 @@ namespace Microsoft.ML.Probabilistic.Math
/// </summary>
/// <param name="a">A real number. Must be &gt; 0 if regularized is true.</param>
/// <param name="x">A real number &gt;= 1.1</param>
/// <param name="regularized">If true, result is divded by Gamma(a)</param>
/// <param name="regularized">If true, result is divided by Gamma(a)</param>
/// <returns></returns>
private static double GammaUpperConFrac(double a, double x, bool regularized = true)
{
@ -1481,29 +1560,29 @@ namespace Microsoft.ML.Probabilistic.Math
// a_i = -i*(i-a)
// b_i = x+1-a+2*i
// the confrac is evaluated using Lentz's algorithm
double b = x + 1.0 - a;
double b = x - a + 1.0;
const double tiny = 1e-30;
double c = 1.0 / tiny;
double d = 1.0 / b;
double h = d * scale;
double d = b;
double h = scale / d;
for (int i = 1; i < 1000; ++i)
{
double an = -i * (i - a);
b += 2.0;
d = an * d + b;
d = an / d + b;
if (Math.Abs(d) < tiny)
d = tiny;
c = b + an / c;
if (Math.Abs(c) < tiny)
c = tiny;
d = 1.0 / d;
double del = d * c;
double del = c / d;
double oldH = h;
h *= del;
//Trace.WriteLine($"h = {h} del = {del}");
if (AreEqual(h, oldH))
return h;
}
throw new Exception(string.Format("GammaUpperConFrac not converging for a={0} x={1}", a, x));
throw new Exception($"GammaUpperConFrac not converging for a={a:g17} x={x:g17}");
}
/// <summary>
@ -1532,8 +1611,8 @@ namespace Microsoft.ML.Probabilistic.Math
invX2 * (-1.0 / 1680.0 +
invX2 * (1.0 / 1188.0 +
invX2 * (-691.0 / 360360.0 +
invX2 * 1.0 / 156.0)))))
);
invX2 * (1.0 / 156.0
)))))));
return sum;
}
}
@ -2103,7 +2182,7 @@ namespace Microsoft.ML.Probabilistic.Math
return r;
rOld = r;
}
throw new Exception($"Not converging for n={n},x={x:r}");
throw new Exception($"Not converging for n={n},x={x:g17}");
}
/// <summary>
@ -2504,7 +2583,7 @@ namespace Microsoft.ML.Probabilistic.Math
double sumOld = sum;
for (int n = 2; n <= 100; n++)
{
//Console.WriteLine($"n = {n - 1} sum = {sum:r}");
//Console.WriteLine($"n = {n - 1} sum = {sum:g17}");
double dlogphiOverFactorial;
if (n % 2 == 0) dlogphiOverFactorial = 1.0 / n - Halfx2y2;
else dlogphiOverFactorial = xy;
@ -2519,7 +2598,7 @@ namespace Microsoft.ML.Probabilistic.Math
rPowerN *= r;
sum += QderivOverFactorial * rPowerN;
if ((sum > double.MaxValue) || double.IsNaN(sum) || n >= 100)
throw new Exception($"NormalCdfRatioTaylor not converging for x={x:r}, y={y:r}, r={r:r}");
throw new Exception($"NormalCdfRatioTaylor not converging for x={x:g17}, y={y:g17}, r={r:g17}");
if (AreEqual(sum, sumOld)) break;
sumOld = sum;
}
@ -2738,7 +2817,7 @@ rr = mpf('-0.99999824265582826');
numerPrevPlusC = scale * (c1 * R1xmry - c2 * R2xmry - c3);
//numer3 = scale / 3 * x * (c1 * R1xmry - c2 * R2xmry - c3 + r * omr2 * R2xmry);
//numer4 = scale / 3 * ((3 + x * x) * c1 * R1xmry + x * r * sqrtomr2 * omr2 * R3xmry - (3 * c2 + x * x * (c2 - r * omr2)) * R2xmry - (3 + x * x) * c3);
//Trace.WriteLine($"numerPrevPlusC = {numerPrevPlusC:r} numer4 = {numer4:r}");
//Trace.WriteLine($"numerPrevPlusC = {numerPrevPlusC:g17} numer4 = {numer4:g17}");
//shiftNumer = omr2 * x * x > 100;
shiftNumer = true;
}
@ -2905,9 +2984,9 @@ rr = mpf('-0.99999824265582826');
else numer2 = numer;
double result = numer2 / (denom - 1);
if (TraceConFrac)
Trace.WriteLine($"iter {i}: result={result:r} c={c:r} cOdd={cOdd:r} numer={numer:r} numer2={numer2:r} denom={denom:r} numerPrev={numerPrev:r}");
Trace.WriteLine($"iter {i}: result={result:g17} c={c:g17} cOdd={cOdd:g17} numer={numer:g17} numer2={numer2:g17} denom={denom:g17} numerPrev={numerPrev:g17}");
if ((result > double.MaxValue) || double.IsNaN(result) || result < 0 || i >= iterationCount - 1)
throw new Exception($"NormalCdfRatioConFrac2 not converging for x={x:r} y={y:r} r={r:r} sqrtomr2={sqrtomr2:r} scale={scale:r}");
throw new Exception($"NormalCdfRatioConFrac2 not converging for x={x:g17} y={y:g17} r={r:g17} sqrtomr2={sqrtomr2:g17} scale={scale:g17}");
if (AreEqual(result, resultPrev) || AbsDiff(result, resultPrev, 0) < 1e-13)
break;
resultPrev = result;
@ -2954,9 +3033,9 @@ rr = mpf('-0.99999824265582826');
{
double result = numer / denom;
if (TraceConFrac)
Trace.WriteLine($"iter {i}: result={result:r} c={c:g4} numer={numer:r} denom={denom:r} numerPrev={numerPrev:r}");
Trace.WriteLine($"iter {i}: result={result:g17} c={c:g4} numer={numer:g17} denom={denom:g17} numerPrev={numerPrev:g17}");
if ((result > double.MaxValue) || double.IsNaN(result) || result < 0 || i >= 1000)
throw new Exception($"NormalCdfRatioConFrac2b not converging for x={x:r} y={y:r} r={r:r} scale={scale}");
throw new Exception($"NormalCdfRatioConFrac2b not converging for x={x:g17} y={y:g17} r={r:g17} scale={scale}");
if (AreEqual(result, resultPrev))
break;
resultPrev = result;
@ -3456,7 +3535,7 @@ rr = mpf('-0.99999824265582826');
/// </remarks>
public static double ExpMinus1(double x)
{
if (Math.Abs(x) < 2e-3)
if (Math.Abs(x) < 55e-3)
{
return
// Truncated series 13: exp(x) - 1
@ -3464,8 +3543,12 @@ rr = mpf('-0.99999824265582826');
x * (1.0 +
x * (1.0 / 2.0 +
x * (1.0 / 6.0 +
x * 1.0 / 24.0)))
;
x * (1.0 / 24.0 +
x * (1.0 / 120.0 +
x * (1.0 / 720.0 +
x * (1.0 / 5040.0 +
x * (1.0 / 40320.0
))))))));
}
else
{
@ -4243,22 +4326,24 @@ else if (m < 20.0 - 60.0/11.0 * s) {
/// <typeparam name="T"></typeparam>
/// <param name="list"></param>
/// <returns></returns>
public static int IndexOfMinimum<T>(IList<T> list)
public static int IndexOfMinimum<T>(IEnumerable<T> list)
where T : IComparable<T>
{
if (list.Count == 0)
IEnumerator<T> iter = list.GetEnumerator();
if (!iter.MoveNext())
return -1;
T min = list[0];
int pos = 0;
for (int i = 1; i < list.Count; i++)
T min = iter.Current;
int indexOfMinimum = 0;
for (int i = 1; iter.MoveNext(); i++)
{
if (min.CompareTo(list[i]) > 0)
T item = iter.Current;
if (min.CompareTo(item) > 0)
{
min = list[i];
pos = i;
min = item;
indexOfMinimum = i;
}
}
return pos;
return indexOfMinimum;
}
/// <summary>
@ -4267,22 +4352,24 @@ else if (m < 20.0 - 60.0/11.0 * s) {
/// <typeparam name="T"></typeparam>
/// <param name="list"></param>
/// <returns></returns>
public static int IndexOfMaximum<T>(IList<T> list)
public static int IndexOfMaximum<T>(IEnumerable<T> list)
where T : IComparable<T>
{
if (list.Count == 0)
IEnumerator<T> iter = list.GetEnumerator();
if (!iter.MoveNext())
return -1;
T max = list[0];
int pos = 0;
for (int i = 1; i < list.Count; i++)
T max = iter.Current;
int indexOfMaximum = 0;
for (int i = 1; iter.MoveNext(); i++)
{
if (max.CompareTo(list[i]) < 0)
T item = iter.Current;
if (max.CompareTo(item) < 0)
{
max = list[i];
pos = i;
max = item;
indexOfMaximum = i;
}
}
return pos;
return indexOfMaximum;
}
/// <summary>
@ -4516,7 +4603,7 @@ else if (m < 20.0 - 60.0/11.0 * s) {
{
iterCount++;
double value = (double)Average(lowerBound, upperBound);
if (value < lowerBound || value > upperBound) throw new Exception($"value={value:r}, lowerBound={lowerBound:r}, upperBound={upperBound:r}, denominator={denominator:r}, ratio={numerator:r}");
if (value < lowerBound || value > upperBound) throw new Exception($"value={value:g17}, lowerBound={lowerBound:g17}, upperBound={upperBound:g17}, denominator={denominator:g17}, ratio={numerator:g17}");
if ((double)(value * denominator) <= numerator)
{
double value2 = NextDouble(value);
@ -4531,14 +4618,14 @@ else if (m < 20.0 - 60.0/11.0 * s) {
{
// value is too low
lowerBound = value2;
if (lowerBound > upperBound || double.IsNaN(lowerBound)) throw new Exception($"value={value:r}, lowerBound={lowerBound:r}, upperBound={upperBound:r}, denominator={denominator:r}, ratio={numerator:r}");
if (lowerBound > upperBound || double.IsNaN(lowerBound)) throw new Exception($"value={value:g17}, lowerBound={lowerBound:g17}, upperBound={upperBound:g17}, denominator={denominator:g17}, ratio={numerator:g17}");
}
}
else
{
// value is too high
upperBound = PreviousDouble(value);
if (lowerBound > upperBound || double.IsNaN(upperBound)) throw new Exception($"value={value:r}, lowerBound={lowerBound:r}, upperBound={upperBound:r}, denominator={denominator:r}, ratio={numerator:r}");
if (lowerBound > upperBound || double.IsNaN(upperBound)) throw new Exception($"value={value:g17}, lowerBound={lowerBound:g17}, upperBound={upperBound:g17}, denominator={denominator:g17}, ratio={numerator:g17}");
}
}
}
@ -4598,7 +4685,7 @@ else if (m < 20.0 - 60.0/11.0 * s) {
{
iterCount++;
double value = (double)Average(lowerBound, upperBound);
if (value < lowerBound || value > upperBound) throw new Exception($"value={value:r}, lowerBound={lowerBound:r}, upperBound={upperBound:r}, denominator={denominator:r}, ratio={ratio:r}");
if (value < lowerBound || value > upperBound) throw new Exception($"value={value:g17}, lowerBound={lowerBound:g17}, upperBound={upperBound:g17}, denominator={denominator:g17}, ratio={ratio:g17}");
if ((double)(value / denominator) <= ratio)
{
double value2 = NextDouble(value);
@ -4613,14 +4700,14 @@ else if (m < 20.0 - 60.0/11.0 * s) {
{
// value is too low
lowerBound = value2;
if (lowerBound > upperBound || double.IsNaN(lowerBound)) throw new Exception($"value={value:r}, lowerBound={lowerBound:r}, upperBound={upperBound:r}, denominator={denominator:r}, ratio={ratio:r}");
if (lowerBound > upperBound || double.IsNaN(lowerBound)) throw new Exception($"value={value:g17}, lowerBound={lowerBound:g17}, upperBound={upperBound:g17}, denominator={denominator:g17}, ratio={ratio:g17}");
}
}
else
{
// value is too high
upperBound = PreviousDouble(value);
if (lowerBound > upperBound || double.IsNaN(upperBound)) throw new Exception($"value={value:r}, lowerBound={lowerBound:r}, upperBound={upperBound:r}, denominator={denominator:r}, ratio={ratio:r}");
if (lowerBound > upperBound || double.IsNaN(upperBound)) throw new Exception($"value={value:g17}, lowerBound={lowerBound:g17}, upperBound={upperBound:g17}, denominator={denominator:g17}, ratio={ratio:g17}");
}
}
}
@ -4660,7 +4747,7 @@ else if (m < 20.0 - 60.0/11.0 * s) {
iterCount++;
double value = (double)Average(lowerBound, upperBound);
//double value = RepresentationMidpoint(lowerBound, upperBound);
if (value < lowerBound || value > upperBound) throw new Exception($"value={value:r}, lowerBound={lowerBound:r}, upperBound={upperBound:r}, b={b:r}, sum={sum:r}");
if (value < lowerBound || value > upperBound) throw new Exception($"value={value:g17}, lowerBound={lowerBound:g17}, upperBound={upperBound:g17}, b={b:g17}, sum={sum:g17}");
if ((double)(value - b) <= sum)
{
double value2 = NextDouble(value);
@ -4675,14 +4762,14 @@ else if (m < 20.0 - 60.0/11.0 * s) {
{
// value is too low
lowerBound = value2;
if (lowerBound > upperBound || double.IsNaN(lowerBound)) throw new Exception($"value={value:r}, lowerBound={lowerBound:r}, upperBound={upperBound:r}, b={b:r}, sum={sum:r}");
if (lowerBound > upperBound || double.IsNaN(lowerBound)) throw new Exception($"value={value:g17}, lowerBound={lowerBound:g17}, upperBound={upperBound:g17}, b={b:g17}, sum={sum:g17}");
}
}
else
{
// value is too high
upperBound = PreviousDouble(value);
if (lowerBound > upperBound || double.IsNaN(upperBound)) throw new Exception($"value={value:r}, lowerBound={lowerBound:r}, upperBound={upperBound:r}, b={b:r}, sum={sum:r}");
if (lowerBound > upperBound || double.IsNaN(upperBound)) throw new Exception($"value={value:g17}, lowerBound={lowerBound:g17}, upperBound={upperBound:g17}, b={b:g17}, sum={sum:g17}");
}
}
}
@ -4738,6 +4825,62 @@ else if (m < 20.0 - 60.0/11.0 * s) {
return BitConverter.Int64BitsToDouble(midpoint);
}
/// <summary>
/// Returns a decimal string that exactly equals a double-precision number, unlike double.ToString which always returns a rounded result.
/// </summary>
/// <param name="x"></param>
/// <returns></returns>
public static string ToStringExact(double x)
{
if (double.IsNaN(x) || double.IsInfinity(x) || x == 0) return x.ToString(System.Globalization.CultureInfo.InvariantCulture);
long bits = BitConverter.DoubleToInt64Bits(x);
ulong fraction = Convert.ToUInt64(bits & 0x000fffffffffffff);
short exponent = Convert.ToInt16((bits & 0x7ff0000000000000) >> 52);
if (exponent == 0)
{
// subnormal number
exponent = -1022 - 52;
}
else
{
// normal number
fraction += 0x0010000000000000;
exponent = Convert.ToInt16(exponent - 1023 - 52);
}
while ((fraction & 1) == 0)
{
fraction >>= 1;
exponent++;
}
string sign = (x >= 0) ? "" : "-";
BigInteger big;
if (exponent >= 0)
{
big = BigInteger.Pow(2, exponent) * fraction;
return $"{sign}{big}";
}
else
{
// Rewrite 2^-4 as 5^4 * 10^-4
big = BigInteger.Pow(5, -exponent) * fraction;
// At this point, we could output the big integer with an "E"{exponent} suffix.
// However, double.Parse does not correctly parse such strings.
// Instead we insert a decimal point and eliminate the "E" suffix if possible.
int digitCount = big.ToString().Length;
if (digitCount < -exponent)
{
return $"{sign}0.{big}e{exponent + digitCount}";
}
else
{
BigInteger pow10 = BigInteger.Pow(10, -exponent);
BigInteger integerPart = big / pow10;
BigInteger fractionalPart = big - integerPart*pow10;
string zeros = new string('0', -exponent);
return $"{sign}{integerPart}.{fractionalPart.ToString(zeros)}";
}
}
}
#region Enumerations and constants
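
A standalone sanity check (hypothetical, not part of the diff above) for the ToStringExact logic: each output string should parse back to exactly the same double.

using System;
using System.Globalization;
using Microsoft.ML.Probabilistic.Math;

static class ToStringExactCheck
{
    public static void Run()
    {
        // Expected output for 0.1: 0.1000000000000000055511151231257827021181583404541015625
        foreach (double x in new[] { 0.1, 1.0 / 3, -2.5e-10, Math.PI })
        {
            string s = MMath.ToStringExact(x);
            bool roundTrips = double.Parse(s, CultureInfo.InvariantCulture) == x;
            Console.WriteLine($"{s} round-trips: {roundTrips}");
        }
    }
}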

View file

@ -174,15 +174,17 @@ namespace Microsoft.ML.Probabilistic.Distributions
/// <param name="rate">rate = 1/scale</param>
public void SetShapeAndRate(double shape, double rate)
{
if (rate > double.MaxValue)
this.Shape = shape;
this.Rate = rate;
CheckForPointMass();
}
private void CheckForPointMass()
{
if (!IsPointMass && Rate > double.MaxValue)
{
Point = 0;
}
else
{
this.Shape = shape;
this.Rate = rate;
}
}
/// <summary>
@ -225,14 +227,8 @@ namespace Microsoft.ML.Probabilistic.Distributions
/// <param name="scale">Scale</param>
public void SetShapeAndScale(double shape, double scale)
{
if (scale == 0)
{
Point = 0;
}
else
{
SetShapeAndRate(shape, 1.0 / scale);
}
if (double.IsPositiveInfinity(shape)) throw new ArgumentOutOfRangeException(nameof(shape), "shape is infinite. To create a point mass, set the Point property.");
SetShapeAndRate(shape, 1.0 / scale);
}
/// <summary>
@ -537,15 +533,6 @@ namespace Microsoft.ML.Probabilistic.Distributions
get { return (Shape == Double.PositiveInfinity); }
}
/// <summary>
/// Sets this instance to a point mass. The location of the
/// point mass is the existing Rate parameter
/// </summary>
private void SetToPointMass()
{
Shape = Double.PositiveInfinity;
}
/// <summary>
/// Sets/gets the instance as a point mass
/// </summary>
@ -560,7 +547,7 @@ namespace Microsoft.ML.Probabilistic.Distributions
}
set
{
SetToPointMass();
Shape = Double.PositiveInfinity;
Rate = value;
}
}
@ -599,7 +586,7 @@ namespace Microsoft.ML.Probabilistic.Distributions
{
if (x < 0) return double.NegativeInfinity;
if (x > double.MaxValue) // Avoid subtracting infinities below
{
{
if (rate > 0) return -x;
else if (rate < 0) return x;
// fall through when rate == 0
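
A hypothetical illustration (not part of the commit) of the Gamma changes above: SetShapeAndScale now rejects an infinite shape, and point masses are created through the Point property instead.

using System;
using Microsoft.ML.Probabilistic.Distributions;

static class GammaPointMassExample
{
    public static void Run()
    {
        Gamma g = Gamma.Uniform();
        g.SetShapeAndScale(2.0, 3.0);     // ordinary case: shape = 2, rate = 1/scale = 1/3

        try
        {
            g.SetShapeAndScale(double.PositiveInfinity, 3.0);
        }
        catch (ArgumentOutOfRangeException)
        {
            // New behavior: an infinite shape is no longer accepted silently.
        }

        g.Point = 3.0;                    // the supported way to create a point mass
        Console.WriteLine(g.IsPointMass); // True
    }
}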

View file

@ -206,8 +206,7 @@ namespace Microsoft.ML.Probabilistic.Distributions
{
if (!IsPointMass && Rate > double.MaxValue)
{
Rate = Math.Pow(0, Power);
SetToPointMass();
Point = Math.Pow(0, Power);
}
}
@ -283,7 +282,7 @@ namespace Microsoft.ML.Probabilistic.Distributions
double oldShape = shape;
logRate = MMath.RisingFactorialLnOverN(shape, power) - logMeanOverPower;
shape = Math.Exp(meanLogOverPower + logRate) + 0.5;
//Console.WriteLine($"shape = {shape:r}, logRate = {logRate:r}");
//Console.WriteLine($"shape = {shape:g17}, logRate = {logRate:g17}");
if (MMath.AreEqual(oldLogRate, logRate) && MMath.AreEqual(oldShape, shape)) break;
if (double.IsNaN(shape)) throw new Exception("Failed to converge");
}
@ -450,15 +449,6 @@ namespace Microsoft.ML.Probabilistic.Distributions
get { return (Shape == Double.PositiveInfinity); }
}
/// <summary>
/// Sets this instance to a point mass. The location of the
/// point mass is the existing Rate parameter
/// </summary>
private void SetToPointMass()
{
Shape = Double.PositiveInfinity;
}
/// <summary>
/// Sets/gets the instance as a point mass
/// </summary>
@ -472,7 +462,8 @@ namespace Microsoft.ML.Probabilistic.Distributions
}
set
{
SetToPointMass();
// Change this instance to a point mass.
Shape = Double.PositiveInfinity;
Rate = value;
}
}

View file

@ -29,7 +29,7 @@ namespace Microsoft.ML.Probabilistic.Distributions
Sampleable<double>, SettableToWeightedSum<TruncatedGamma>,
CanGetMean<double>, CanGetVariance<double>, CanGetMeanAndVarianceOut<double, double>,
CanGetLogNormalizer, CanGetLogAverageOf<TruncatedGamma>, CanGetLogAverageOfPower<TruncatedGamma>,
CanGetAverageLog<TruncatedGamma>
CanGetAverageLog<TruncatedGamma>, CanGetMode<double>
{
/// <summary>
/// Untruncated Gamma
@ -249,7 +249,7 @@ namespace Microsoft.ML.Probabilistic.Distributions
return double.NegativeInfinity;
else
{
return this.Gamma.GetLogProb(value) + this.Gamma.GetLogNormalizer() - GetLogNormalizer();
return this.Gamma.GetLogProb(value) + (this.Gamma.GetLogNormalizer() - GetLogNormalizer());
}
}
@ -259,9 +259,11 @@ namespace Microsoft.ML.Probabilistic.Distributions
/// <returns></returns>
public double GetNormalizer()
{
if (IsProper())
if (IsProper() && !IsPointMass)
{
return this.Gamma.GetProbLessThan(UpperBound) - this.Gamma.GetProbLessThan(LowerBound);
// Equivalent but less accurate:
//return this.Gamma.GetProbLessThan(UpperBound) - this.Gamma.GetProbLessThan(LowerBound);
return GammaProbBetween(this.Gamma.Shape, this.Gamma.Rate, LowerBound, UpperBound);
}
else
{
@ -275,8 +277,22 @@ namespace Microsoft.ML.Probabilistic.Distributions
/// <returns></returns>
public double GetLogNormalizer()
{
// TODO: make this more accurate.
return Math.Log(GetNormalizer());
if (IsProper() && !IsPointMass)
{
if (this.Gamma.Shape < 1 && (double)(this.Gamma.Rate * LowerBound) > 0)
{
// When Shape < 1, Gamma(Shape) > 1 so use the unregularized version to avoid underflow.
return Math.Log(GammaProbBetween(this.Gamma.Shape, this.Gamma.Rate, LowerBound, UpperBound, false)) - MMath.GammaLn(this.Gamma.Shape);
}
else
{
return Math.Log(GammaProbBetween(this.Gamma.Shape, this.Gamma.Rate, LowerBound, UpperBound));
}
}
else
{
return 0.0;
}
}
/// <summary>
@ -448,12 +464,12 @@ namespace Microsoft.ML.Probabilistic.Distributions
/// </summary>
/// <returns>The sample value</returns>
[Stochastic]
public static double Sample(double shape, double scale, double lowerBound, double upperBound)
public static double Sample(Gamma gamma, double lowerBound, double upperBound)
{
double sample;
do
{
sample = Gamma.Sample(shape, scale);
sample = gamma.Sample();
} while (sample < lowerBound || sample > upperBound);
return sample;
}
@ -471,7 +487,7 @@ namespace Microsoft.ML.Probabilistic.Distributions
}
else
{
return Sample(Gamma.Shape, 1 / Gamma.Rate, LowerBound, UpperBound);
return Sample(Gamma, LowerBound, UpperBound);
}
}
@ -486,29 +502,58 @@ namespace Microsoft.ML.Probabilistic.Distributions
return Sample();
}
/// <summary>
/// Get the mode (highest density point) of this distribution
/// </summary>
/// <returns></returns>
public double GetMode()
{
return Math.Min(Math.Max(this.Gamma.GetMode(), this.LowerBound), this.UpperBound);
}
/// <summary>
/// Returns the mean (first moment) of the distribution
/// </summary>
/// <returns></returns>
public double GetMean()
{
if (this.Gamma.IsPointMass)
return this.Gamma.Point;
else if (!IsProper())
throw new ImproperDistributionException(this);
else
double mean, variance;
GetMeanAndVariance(out mean, out variance);
return mean;
}
/// <summary>
/// Get the variance of this distribution
/// </summary>
/// <returns></returns>
public double GetVariance()
{
double mean, variance;
GetMeanAndVariance(out mean, out variance);
return variance;
}
/// <summary>
/// Computes <c>GammaUpper(s,x)/(x^(s-1)*exp(-x)) - 1</c> to high accuracy
/// </summary>
/// <param name="s"></param>
/// <param name="x">A real number gt;= 45 and gt; <paramref name="s"/>/0.99</param>
/// <param name="regularized"></param>
/// <returns></returns>
public static double GammaUpperRatio(double s, double x, bool regularized = true)
{
if (s >= x * 0.99) throw new ArgumentOutOfRangeException(nameof(s), s, "s >= x*0.99");
if (x < 45) throw new ArgumentOutOfRangeException(nameof(x), x, "x < 45");
double term = (s - 1) / x;
double sum = term;
for (int i = 2; i < 1000; i++)
{
double Z = GetNormalizer();
if (Z == 0)
{
double mean = this.Gamma.GetMean();
return Math.Min(UpperBound, Math.Max(LowerBound, mean));
}
// if Z is not zero, then Z1 cannot be zero.
double Z1 = MMath.GammaLower(this.Gamma.Shape + 1, this.Gamma.Rate * UpperBound) - MMath.GammaLower(this.Gamma.Shape + 1, this.Gamma.Rate * LowerBound);
double sum = this.Gamma.Shape / this.Gamma.Rate * Z1;
return sum / Z;
term *= (s - i) / x;
double oldSum = sum;
sum += term;
if (MMath.AreEqual(sum, oldSum)) return regularized ? sum / MMath.Gamma(s) : sum;
}
throw new Exception($"GammaUpperRatio not converging for s={s:g17}, x={x:g17}, regularized={regularized}");
}
/// <summary>
@ -535,34 +580,50 @@ namespace Microsoft.ML.Probabilistic.Distributions
throw new ImproperDistributionException(this);
else
{
double Z = GetNormalizer();
if (Z == 0)
{
mean = Math.Min(UpperBound, Math.Max(LowerBound, this.Gamma.GetMean()));
variance = 0.0;
return;
}
// Apply the recurrence GammaUpper(s+1,x,false) = s*GammaUpper(s,x,false) + x^s*exp(-x)
double rl = this.Gamma.Rate * LowerBound;
double ru = this.Gamma.Rate * UpperBound;
double m = this.Gamma.Shape / this.Gamma.Rate;
// t = x * Rate
// dt = dx * Rate
double Z1 = MMath.GammaLower(this.Gamma.Shape + 1, this.Gamma.Rate * UpperBound) - MMath.GammaLower(this.Gamma.Shape + 1, this.Gamma.Rate * LowerBound);
mean = m * Z1 / Z;
double sum2 = m * (this.Gamma.Shape + 1) / this.Gamma.Rate * (MMath.GammaLower(this.Gamma.Shape + 2, this.Gamma.Rate * UpperBound) - MMath.GammaLower(this.Gamma.Shape + 2, this.Gamma.Rate * LowerBound));
variance = sum2 / Z - mean * mean;
double offset, offset2;
if (ru > double.MaxValue)
{
double logZ = GetLogNormalizer();
if (logZ < double.MinValue)
{
mean = GetMode();
variance = 0.0;
return;
}
offset = Math.Exp(MMath.GammaUpperLogScale(this.Gamma.Shape, rl) - logZ);
offset2 = (rl - this.Gamma.Shape) / this.Gamma.Rate * offset;
}
else
{
// This fails when GammaUpperScale underflows to 0
double Z = GetNormalizer();
if (Z == 0)
{
mean = GetMode();
variance = 0.0;
return;
}
double gammaUpperScaleLower = MMath.GammaUpperScale(this.Gamma.Shape, rl);
double gammaUpperScaleUpper = MMath.GammaUpperScale(this.Gamma.Shape, ru);
offset = (gammaUpperScaleLower - gammaUpperScaleUpper) / Z;
offset2 = ((rl - this.Gamma.Shape) / this.Gamma.Rate * gammaUpperScaleLower - (ru - this.Gamma.Shape) / this.Gamma.Rate * gammaUpperScaleUpper) / Z;
}
if (rl == this.Gamma.Shape) mean = LowerBound + offset / this.Gamma.Rate;
else
{
mean = (this.Gamma.Shape + offset) / this.Gamma.Rate;
if (mean < LowerBound) mean = MMath.NextDouble(mean);
if (mean < LowerBound) mean = MMath.NextDouble(mean);
}
if (mean > double.MaxValue) variance = mean;
else variance = (m + offset2 + (1 - offset) * offset / this.Gamma.Rate) / this.Gamma.Rate;
}
}
/// <summary>
/// Get the variance of this distribution
/// </summary>
/// <returns></returns>
public double GetVariance()
{
double mean, var;
GetMeanAndVariance(out mean, out var);
return var;
}
/// <summary>
/// Computes E[x^power]
/// </summary>
@ -570,6 +631,7 @@ namespace Microsoft.ML.Probabilistic.Distributions
public double GetMeanPower(double power)
{
if (power == 0.0) return 1.0;
else if (power == 1.0) return GetMean();
else if (IsPointMass) return Math.Pow(Point, power);
//else if (Rate == 0.0) return (power > 0) ? Double.PositiveInfinity : 0.0;
else if (!IsProper()) throw new ImproperDistributionException(this);
@ -577,23 +639,94 @@ namespace Microsoft.ML.Probabilistic.Distributions
{
throw new ArgumentException("Cannot compute E[x^" + power + "] for " + this + " (shape <= " + (-power) + ")");
}
else
else if (power != 1)
{
double Z = GetNormalizer();
// Large powers lead to overflow
power = Math.Min(Math.Max(power, -1e300), 1e300);
double logZ = GetLogNormalizer();
if (logZ < double.MinValue)
{
return Math.Pow(GetMode(), power);
}
double shapePlusPower = this.Gamma.Shape + power;
double Z1;
double logZ1;
bool regularized = shapePlusPower >= 1;
if (regularized)
{
Z1 = Math.Exp(MMath.GammaLn(shapePlusPower) - MMath.GammaLn(this.Gamma.Shape)) *
(MMath.GammaLower(shapePlusPower, this.Gamma.Rate * UpperBound) - MMath.GammaLower(shapePlusPower, this.Gamma.Rate * LowerBound));
// This formula cannot be used when shapePlusPower <= 0
logZ1 = (power * MMath.RisingFactorialLnOverN(this.Gamma.Shape, power)) +
Math.Log(GammaProbBetween(shapePlusPower, this.Gamma.Rate, LowerBound, UpperBound, regularized));
}
else
{
Z1 = Math.Exp(- MMath.GammaLn(this.Gamma.Shape)) *
(MMath.GammaUpper(shapePlusPower, this.Gamma.Rate * LowerBound, regularized) - MMath.GammaUpper(shapePlusPower, this.Gamma.Rate * UpperBound, regularized));
logZ1 = -MMath.GammaLn(this.Gamma.Shape) +
Math.Log(GammaProbBetween(shapePlusPower, this.Gamma.Rate, LowerBound, UpperBound, regularized));
}
return Math.Pow(this.Gamma.Rate, -power) * Z1 / Z;
return Math.Exp(-power * Math.Log(this.Gamma.Rate) + logZ1 - logZ);
}
else
{
double Z = GetNormalizer();
if (Z == 0.0)
{
return Math.Pow(GetMode(), power);
}
double shapePlusPower = this.Gamma.Shape + power;
double Z1;
double gammaLnShapePlusPower = MMath.GammaLn(shapePlusPower);
double gammaLnShape = MMath.GammaLn(this.Gamma.Shape);
bool regularized = true; // (gammaLnShapePlusPower - gammaLnShape <= 700);
if (regularized)
{
// If shapePlusPower is large and Gamma.Rate * UpperBound is small, then this can lead to Inf * 0
Z1 = Math.Exp(power * MMath.RisingFactorialLnOverN(this.Gamma.Shape, power)) *
GammaProbBetween(shapePlusPower, this.Gamma.Rate, LowerBound, UpperBound, regularized);
}
else
{
Z1 = Math.Exp(-gammaLnShape) *
GammaProbBetween(shapePlusPower, this.Gamma.Rate, LowerBound, UpperBound, regularized);
}
return Z1 / (Math.Pow(this.Gamma.Rate, power) * Z);
}
}
/// <summary>
/// Computes GammaLower(a, r*u) - GammaLower(a, r*l) to high accuracy.
/// </summary>
/// <param name="shape"></param>
/// <param name="rate"></param>
/// <param name="lowerBound"></param>
/// <param name="upperBound"></param>
/// <param name="regularized"></param>
/// <returns></returns>
public static double GammaProbBetween(double shape, double rate, double lowerBound, double upperBound, bool regularized = true)
{
double rl = rate * lowerBound;
// Use the criterion from Gautschi (1979) to determine whether GammaLower(a,x) or GammaUpper(a,x) is smaller.
bool lowerIsSmaller;
if (rl > 0.25)
lowerIsSmaller = (shape > rl + 0.25);
else
lowerIsSmaller = (shape > -MMath.Ln2 / Math.Log(rl));
if (!lowerIsSmaller)
{
double logl = Math.Log(lowerBound);
if (rate * upperBound < 1e-16 && shape < -1e-16 / (Math.Log(rate) + logl))
{
double logu = Math.Log(upperBound);
return shape * (logu - logl);
}
else
{
// This is inaccurate when lowerBound is close to upperBound. In that case, use a Taylor expansion of lowerBound around upperBound.
return MMath.GammaUpper(shape, rl, regularized) - MMath.GammaUpper(shape, rate * upperBound, regularized);
}
}
else
{
double diff = MMath.GammaLower(shape, rate * upperBound) - MMath.GammaLower(shape, rl);
return regularized ? diff : (MMath.Gamma(shape) * diff);
}
}
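
A hypothetical sketch (not part of the commit) exercising the TruncatedGamma members changed or added above; all names are taken from the diff.

using System;
using Microsoft.ML.Probabilistic.Distributions;

static class TruncatedGammaExample
{
    public static void Run()
    {
        // Gamma(shape = 2, rate = 1) truncated to the interval [1, 4].
        var tg = new TruncatedGamma(Gamma.FromShapeAndRate(2, 1), 1, 4);

        tg.GetMeanAndVariance(out double mean, out double variance);
        Console.WriteLine($"mean = {mean:g17}, variance = {variance:g17}");

        Console.WriteLine(tg.GetMode());        // new in this commit: the Gamma mode clamped to [1, 4]
        Console.WriteLine(tg.GetMeanPower(2));  // E[x^2] = variance + mean^2
        Console.WriteLine(tg.GetNormalizer());  // P(1 <= x <= 4) under the untruncated Gamma
    }
}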

View file

@ -81,7 +81,7 @@ namespace Microsoft.ML.Probabilistic.Factors
[ParameterNames("sample", "shape", "rate", "lowerBound", "upperBound")]
public static double TruncatedGammaFromShapeAndRate(double shape, double rate, double lowerBound, double upperBound)
{
return TruncatedGamma.Sample(shape, 1 / rate, lowerBound, upperBound);
return TruncatedGamma.Sample(Gamma.FromShapeAndRate(shape, rate), lowerBound, upperBound);
}
/// <summary>

View file

@ -536,9 +536,27 @@ namespace Microsoft.ML.Probabilistic.Factors
double x = sample.Point;
double shape2 = shape + rate.Shape;
double xrr = x + rate.Rate;
double dlogf = (shape - 1) / x - shape2 / xrr;
double ddlogf = -(shape - 1) / (x * x) + shape2 / (xrr * xrr);
return Gamma.FromDerivatives(x, dlogf, ddlogf, GammaFromShapeAndRateOp.ForceProper);
if (x == 0)
{
if (shape == 1)
{
double dlogf = -shape2 / xrr;
double ddlogf = shape2 / (xrr * xrr);
return Gamma.FromDerivatives(x, dlogf, ddlogf, GammaFromShapeAndRateOp.ForceProper);
}
else
{
// a = -x*x*ddLogP
// b = a / x - dLogP
return Gamma.FromShapeAndRate(shape, shape2 / xrr);
}
}
else
{
double dlogf = (shape - 1) / x - shape2 / xrr;
double ddlogf = -(shape - 1) / (x * x) + shape2 / (xrr * xrr);
return Gamma.FromDerivatives(x, dlogf, ddlogf, GammaFromShapeAndRateOp.ForceProper);
}
}
double sampleMean, sampleVariance;
if (sample.Rate == 0)
@ -681,12 +699,16 @@ namespace Microsoft.ML.Probabilistic.Factors
/// <returns></returns>
internal static double FindMaximum(double shape1, double shape2, double yRate, double rateRate)
{
if (shape2 == 0)
{
return shape1 / rateRate;
}
if (yRate < 0)
throw new ArgumentException("yRate < 0");
// f = shape1*log(rs) - shape2*log(rs+by) - br*rs
// df = shape1/rs - shape2/(rs + by) - br
// df=0 when shape1*(rs+by) - shape2*rs - br*rs*(rs+by) = 0
// -br*rs^2 + (shape1-shape2-br*by)*rs + shape1*by = 0
throw new ArgumentOutOfRangeException(nameof(yRate), yRate, "yRate < 0");
// f = shape1*log(x) - shape2*log(x+yRate) - x*rateRate
// df = shape1/x - shape2/(x + yrate) - rateRate
// df=0 when shape1*(x+yRate) - shape2*x - rateRate*x*(x+yRate) = 0
// -rateRate*x^2 + (shape1-shape2-rateRate*yRate)*x + shape1*yRate = 0
double a = -rateRate;
double b = shape1 - shape2 - yRate * rateRate;
double c = shape1 * yRate;
@ -714,11 +736,11 @@ namespace Microsoft.ML.Probabilistic.Factors
// compute the derivative wrt log(rs)
double sum = r0 + yRate;
double p = r0 / sum;
double df = shape1 - shape2*p - rateRate*r0;
double df = shape1 - shape2 * p - rateRate * r0;
if (Math.Abs(df) > 1)
{
// take a Newton step for extra accuracy
double ddf = shape2*p*(p-1) - rateRate*r0;
double ddf = shape2 * p * (p - 1) - rateRate * r0;
r0 *= Math.Exp(-df / ddf);
}
if (double.IsNaN(r0))
@ -814,7 +836,7 @@ namespace Microsoft.ML.Probabilistic.Factors
{
if (hasInflection)
rmax = r * 1.1; // restart closer to the stationary point
else
else
throw new Exception("rmax < r");
}
if (MMath.AreEqual(rmax, r))
@ -908,12 +930,19 @@ namespace Microsoft.ML.Probabilistic.Factors
// dlogf = s/r - (s+xs-1)/(r+xr)
// ddlogf = -s/r^2 + (s+xs-1)/(r+xr)^2
r = rate.Point;
double v = 1 / r;
double r2 = r + sample.Rate;
double v2 = 1 / r2;
double dlogf = shape * v - shape2 * v2;
double ddlogf = -shape * v * v + shape2 * v2 * v2;
return Gamma.FromDerivatives(r, dlogf, ddlogf, GammaFromShapeAndRateOp.ForceProper);
if (r == 0)
{
// a = -r*r*ddLogP
// b = a / r - dLogP
return Gamma.FromShapeAndRate(shape + 1, shape2 / r2);
}
else
{
double dlogf = shape / r - shape2 / r2;
double ddlogf = -shape / (r * r) + shape2 / (r2 * r2);
return Gamma.FromDerivatives(r, dlogf, ddlogf, GammaFromShapeAndRateOp.ForceProper);
}
}
double shape1 = shape + rate.Shape;
double rateMean, rateVariance;
@ -1043,8 +1072,8 @@ namespace Microsoft.ML.Probabilistic.Factors
double p = r / (r + y.Rate);
double p2 = p * p;
double shape2 = GammaFromShapeAndRateOp_Slow.AddShapesMinus1(y.Shape, shape);
double dlogf = shape - shape2 * p;
double ddlogf = -shape + shape2 * p2;
double dlogf = shape - shape2 * p;
double ddlogf = -shape + shape2 * p2;
double dddlogf = 2 * shape - 2 * shape2 * p * p2;
double d4logf = -6 * shape + 6 * shape2 * p2 * p2;
return new double[] { dlogf, ddlogf, dddlogf, d4logf };
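
A standalone numeric check (not part of the commit) of the stationary-point derivation in the FindMaximum comments above: f(x) = shape1*log(x) - shape2*log(x + yRate) - x*rateRate, so f'(x) = 0 reduces to the quadratic -rateRate*x^2 + (shape1 - shape2 - rateRate*yRate)*x + shape1*yRate = 0.

using System;

static class FindMaximumDerivationCheck
{
    public static void Run()
    {
        double shape1 = 3, shape2 = 5, yRate = 2, rateRate = 0.7; // illustrative values
        double a = -rateRate;
        double b = shape1 - shape2 - rateRate * yRate;
        double c = shape1 * yRate;
        // With a < 0 and c > 0 the quadratic has exactly one positive root.
        double x = (-b - Math.Sqrt(b * b - 4 * a * c)) / (2 * a);
        double df = shape1 / x - shape2 / (x + yRate) - rateRate;
        Console.WriteLine($"x = {x:g17}, f'(x) = {df:g3}"); // f'(x) should be ~0
    }
}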

View file

@ -14,6 +14,189 @@ namespace Microsoft.ML.Probabilistic.Factors
[Quality(QualityBand.Experimental)]
public static class PlusGammaOp
{
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="PlusGammaOp"]/message_doc[@name="SumAverageConditional(GammaPower, GammaPower)"]/*'/>
public static GammaPower SumAverageConditional([SkipIfUniform] GammaPower a, [SkipIfUniform] GammaPower b, GammaPower result)
{
a.GetMeanAndVariance(out double aMean, out double aVariance);
b.GetMeanAndVariance(out double bMean, out double bVariance);
double mean = aMean + bMean;
double variance = aVariance + bVariance;
return GammaPower.FromMeanAndVariance(mean, variance, result.Power);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="PlusGammaOp"]/message_doc[@name="AAverageConditional(GammaPower, GammaPower)"]/*'/>
public static GammaPower AAverageConditional([SkipIfUniform] GammaPower sum, [SkipIfUniform] GammaPower a, [SkipIfUniform] GammaPower b, GammaPower result)
{
if (sum.IsUniform()) return sum;
sum.GetMeanAndVariance(out double sumMean, out double sumVariance);
b.GetMeanAndVariance(out double bMean, out double bVariance);
double rMean = Math.Max(0, sumMean - bMean);
double rVariance = sumVariance + bVariance;
double aVariance = a.GetVariance();
if (rVariance > aVariance)
{
if (sum.Power == 1)
{
GetGammaMomentDerivs(a, out double mean, out double dmean, out double ddmean, out double variance, out double dvariance, out double ddvariance);
mean += b.GetMean();
variance += b.GetVariance();
GetGammaDerivs(mean, dmean, ddmean, variance, dvariance, ddvariance, out double ds, out double dds, out double dr, out double ddr);
GetDerivLogZ(sum, GammaPower.FromMeanAndVariance(mean, variance, sum.Power), ds, dds, dr, ddr, out double dlogZ, out double ddlogZ);
return GammaPowerFromDerivLogZ(a, dlogZ, ddlogZ);
}
else if (sum.Power == -1)
{
GetInverseGammaMomentDerivs(a, out double mean, out double dmean, out double ddmean, out double variance, out double dvariance, out double ddvariance);
mean += b.GetMean();
variance += b.GetVariance();
if (variance > double.MaxValue) return GammaPower.Uniform(a.Power); //throw new NotSupportedException();
GetInverseGammaDerivs(mean, dmean, ddmean, variance, dvariance, ddvariance, out double ds, out double dds, out double dr, out double ddr);
if (sum.IsPointMass && sum.Point == 0) return GammaPower.PointMass(0, a.Power);
GetDerivLogZ(sum, GammaPower.FromMeanAndVariance(mean, variance, sum.Power), ds, dds, dr, ddr, out double dlogZ, out double ddlogZ);
return GammaPowerFromDerivLogZ(a, dlogZ, ddlogZ);
}
}
return GammaPower.FromMeanAndVariance(rMean, rVariance, result.Power);
}
public static GammaPower GammaPowerFromDerivLogZ(GammaPower a, double dlogZ, double ddlogZ)
{
bool method1 = false;
if (method1)
{
GetPosteriorMeanAndVariance(Gamma.FromShapeAndRate(a.Shape, a.Rate), dlogZ, ddlogZ, out double iaMean, out double iaVariance);
Gamma ia = Gamma.FromMeanAndVariance(iaMean, iaVariance);
return GammaPower.FromShapeAndRate(ia.Shape, ia.Rate, a.Power) / a;
}
else
{
double alpha = -a.Rate * dlogZ;
// dalpha/dr = -dlogZ - r*ddlogZ
// beta = -r * dalpha/dr
double beta = a.Rate * dlogZ + a.Rate * a.Rate * ddlogZ;
Gamma prior = Gamma.FromShapeAndRate(a.Shape, a.Rate);
// ia is the marginal of a^(1/a.Power)
Gamma ia = GaussianOp.GammaFromAlphaBeta(prior, alpha, beta, true) * prior;
return GammaPower.FromShapeAndRate(ia.Shape, ia.Rate, a.Power) / a;
}
}
/// <summary>
/// Gets first and second derivatives of the moments with respect to the rate parameter of the distribution.
/// </summary>
/// <param name="gammaPower"></param>
/// <param name="mean"></param>
/// <param name="dmean"></param>
/// <param name="ddmean"></param>
/// <param name="variance"></param>
/// <param name="dvariance"></param>
/// <param name="ddvariance"></param>
public static void GetInverseGammaMomentDerivs(GammaPower gammaPower, out double mean, out double dmean, out double ddmean, out double variance, out double dvariance, out double ddvariance)
{
if (gammaPower.Power != -1) throw new ArgumentException();
if (gammaPower.Shape <= 2) throw new ArgumentOutOfRangeException($"gammaPower.Shape <= 2");
mean = gammaPower.Rate / (gammaPower.Shape - 1);
dmean = 1 / (gammaPower.Shape - 1);
ddmean = 0;
variance = mean * mean / (gammaPower.Shape - 2);
dvariance = 2 * mean * dmean / (gammaPower.Shape - 2);
ddvariance = 2 * dmean * dmean / (gammaPower.Shape - 2);
}
public static void GetInverseGammaDerivs(double mean, double dmean, double ddmean, double variance, double dvariance, double ddvariance, out double ds, out double dds, out double dr, out double ddr)
{
double shape = 2 + mean * mean / variance;
double v2 = variance * variance;
ds = 2 * mean * dmean / variance - mean * mean / v2 * dvariance;
dds = 2 * mean * ddmean / variance - mean * mean / v2 * ddvariance + 2 * dmean * dmean / variance - 4 * mean * dmean / v2 * dvariance + 2 * mean * mean / (v2 * variance) * dvariance * dvariance;
dr = dmean * (shape - 1) + mean * ds;
ddr = ddmean * (shape - 1) + 2 * dmean * ds + mean * dds;
}
public static void GetGammaMomentDerivs(GammaPower gammaPower, out double mean, out double dmean, out double ddmean, out double variance, out double dvariance, out double ddvariance)
{
if (gammaPower.Power != 1) throw new ArgumentException();
mean = gammaPower.Shape / gammaPower.Rate;
variance = mean / gammaPower.Rate;
dmean = -variance;
ddmean = 2 * variance / gammaPower.Rate;
dvariance = -ddmean;
ddvariance = -3 * dvariance / gammaPower.Rate;
}
public static void GetGammaDerivs(double mean, double dmean, double ddmean, double variance, double dvariance, double ddvariance, out double ds, out double dds, out double dr, out double ddr)
{
double rate = mean / variance;
//double shape = mean * rate;
double v2 = variance * variance;
dr = dmean / variance - mean / v2 * dvariance;
ddr = ddmean / variance - mean / v2 * ddvariance - 2 * dmean / v2 * dvariance + 2 * mean / (v2 * variance) * dvariance * dvariance;
ds = dmean * rate + mean * dr;
dds = ddmean * rate + 2 * dmean * dr + mean * ddr;
}
public static void GetDerivLogZ(GammaPower sum, GammaPower toSum, double ds, double dds, double dr, double ddr, out double dlogZ, out double ddlogZ)
{
if (sum.Power != toSum.Power) throw new ArgumentException($"sum.Power ({sum.Power}) != toSum.Power ({toSum.Power})");
if (toSum.IsPointMass) throw new NotSupportedException();
if(toSum.IsUniform())
{
dlogZ = 0;
ddlogZ = 0;
return;
}
if (sum.IsPointMass)
{
// Z = toSum.GetLogProb(sum.Point)
// log(Z) = (toSum.Shape/toSum.Power - 1)*log(sum.Point) - toSum.Rate*sum.Point^(1/toSum.Power) + toSum.Shape*log(toSum.Rate) - GammaLn(toSum.Shape)
if (sum.Point == 0) throw new NotSupportedException();
double logSumOverPower = Math.Log(sum.Point) / toSum.Power;
double powSum = Math.Exp(logSumOverPower);
double logRate = Math.Log(toSum.Rate);
double digammaShape = MMath.Digamma(toSum.Shape);
double shapeOverRate = toSum.Shape / toSum.Rate;
dlogZ = ds * logSumOverPower - dr * powSum + ds * logRate + shapeOverRate * dr - digammaShape * ds;
ddlogZ = dds * logSumOverPower - ddr * powSum + dds * logRate + 2 * ds * dr / toSum.Rate + shapeOverRate * ddr - MMath.Trigamma(toSum.Shape) * ds - digammaShape * dds;
}
else
{
GammaPower product = sum * toSum;
double cs = (MMath.Digamma(product.Shape) - Math.Log(product.Shape)) - (MMath.Digamma(toSum.Shape) - Math.Log(toSum.Shape));
double cr = toSum.Shape / toSum.Rate - product.Shape / product.Rate;
double css = MMath.Trigamma(product.Shape) - MMath.Trigamma(toSum.Shape);
double csr = 1 / toSum.Rate - 1 / product.Rate;
double crr = product.Shape / (product.Rate * product.Rate) - toSum.Shape / (toSum.Rate * toSum.Rate);
dlogZ = cs * ds + cr * dr;
ddlogZ = cs * dds + cr * ddr + css * ds * ds + 2 * csr * ds * dr + crr * dr * dr;
}
}
public static void GetPosteriorMeanAndVariance(Gamma prior, double dlogZ, double ddlogZ, out double mean, out double variance)
{
// dlogZ is derivative of log(Z) wrt prior rate parameter
prior.GetMeanAndVariance(out double priorMean, out double priorVariance);
mean = priorMean - dlogZ;
variance = priorVariance + ddlogZ;
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="PlusGammaOp"]/message_doc[@name="BAverageConditional(GammaPower, GammaPower)"]/*'/>
public static GammaPower BAverageConditional([SkipIfUniform] GammaPower sum, [SkipIfUniform] GammaPower a, [SkipIfUniform] GammaPower b, GammaPower result)
{
return AAverageConditional(sum, b, a, result);
}
public static double LogAverageFactor([SkipIfUniform] GammaPower sum, [SkipIfUniform] GammaPower a, [SkipIfUniform] GammaPower b)
{
GammaPower toSum = SumAverageConditional(a, b, sum);
return toSum.GetLogAverageOf(sum);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="PlusGammaOp"]/message_doc[@name="LogEvidenceRatio(GammaPower, GammaPower, GammaPower)"]/*'/>
public static double LogEvidenceRatio([SkipIfUniform] GammaPower sum, [SkipIfUniform] GammaPower a, [SkipIfUniform] GammaPower b)
{
return 0.0;
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="PlusGammaOp"]/message_doc[@name="SumAverageConditional(GammaPower, double)"]/*'/>
public static GammaPower SumAverageConditional([SkipIfUniform] GammaPower a, double b)
{

View file

@ -15,6 +15,38 @@ namespace Microsoft.ML.Probabilistic.Factors
[Quality(QualityBand.Experimental)]
public static class PowerOp
{
public static Gamma GammaFromMeanAndMeanInverse(double mean, double meanInverse)
{
if (mean < 0) throw new ArgumentOutOfRangeException(nameof(mean), mean, "mean < 0");
if (meanInverse < 0) throw new ArgumentOutOfRangeException(nameof(meanInverse), meanInverse, "meanInverse < 0");
// mean = a/b
// meanInverse = b/(a-1)
// a = mean*meanInverse / (mean*meanInverse - 1)
// b = a/mean
double rate = meanInverse / (mean * meanInverse - 1);
if (rate < 0 || rate > double.MaxValue) return Gamma.PointMass(mean);
double shape = mean * rate;
if (shape > double.MaxValue)
return Gamma.PointMass(mean);
else
return Gamma.FromShapeAndRate(shape, rate);
}
public static Gamma GammaFromGammaPower(GammaPower message)
{
if (message.Power == 1) return Gamma.FromShapeAndRate(message.Shape, message.Rate); // same as below, but faster
if (message.IsUniform()) return Gamma.Uniform();
message.GetMeanAndVariance(out double mean, out double variance);
return Gamma.FromMeanAndVariance(mean, variance);
}
public static Gamma FromMeanPowerAndMeanLog(double meanPower, double meanLog, double power)
{
// We want E[log(x)] = meanLog but this sets E[log(x^power)] = meanLog, so we scale meanLog
var gammaPower = GammaPower.FromMeanAndMeanLog(meanPower, meanLog * power, power);
return Gamma.FromShapeAndRate(gammaPower.Shape, gammaPower.Rate);
}
// Gamma = TruncatedGamma ^ y /////////////////////////////////////////////////////////
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="PowerOp"]/message_doc[@name="PowAverageConditional(TruncatedGamma, double)"]/*'/>
@ -33,18 +65,6 @@ namespace Microsoft.ML.Probabilistic.Factors
}
}
public static Gamma GammaFromMeanAndMeanInverse(double mean, double meanInverse)
{
// mean = a/b
// meanInverse = b/(a-1)
// a = mean*meanInverse / (mean*meanInverse - 1)
// b = a/mean
double rate = meanInverse / (mean * meanInverse - 1);
double shape = mean * rate;
return Gamma.FromShapeAndRate(shape, rate);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="PowerOp"]/message_doc[@name="XAverageConditional(Gamma, TruncatedGamma, double)"]/*'/>
public static TruncatedGamma XAverageConditional([SkipIfUniform] Gamma pow, TruncatedGamma x, double y)
{
// message computed below should be uniform when pow is uniform, but may not due to roundoff error.
@ -58,7 +78,48 @@ namespace Microsoft.ML.Probabilistic.Factors
var powMarginal = pow * toPow;
// xMarginal2 is the exact distribution of pow^(1/y) where pow has distribution powMarginal
GammaPower xMarginal2 = GammaPower.FromShapeAndRate(powMarginal.Shape, powMarginal.Rate, power);
var xMarginal = new TruncatedGamma(GammaFromGammaPower(xMarginal2));
var xMarginal = new TruncatedGamma(GammaFromGammaPower(xMarginal2), x.LowerBound, x.UpperBound);
var result = xMarginal;
result.SetToRatio(xMarginal, x, GammaProductOp_Laplace.ForceProper);
return result;
}
// GammaPower = TruncatedGamma ^ y /////////////////////////////////////////////////////////
public static GammaPower PowAverageConditional([SkipIfUniform] TruncatedGamma x, double y, GammaPower result)
{
if (result.Power == -1) y = -y;
else if (result.Power != 1) throw new ArgumentException($"result.Power ({result.Power}) is not 1 or -1", nameof(result));
double mean = x.GetMeanPower(y);
if (x.LowerBound > 0)
{
double meanInverse = x.GetMeanPower(-y);
Gamma result1 = GammaFromMeanAndMeanInverse(mean, meanInverse);
return GammaPower.FromShapeAndRate(result1.Shape, result1.Rate, result.Power);
}
else
{
double variance = x.GetMeanPower(2 * y) - mean * mean;
Gamma result1 = Gamma.FromMeanAndVariance(mean, variance);
return GammaPower.FromShapeAndRate(result1.Shape, result1.Rate, result.Power);
}
}
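A usage sketch for the new overload, with numbers borrowed from TruncatedGammaPowerTest further down; because the input has LowerBound = 1, both E[x^y] and E[x^-y] are finite, so the first branch applies and the message should be proper:
var x = new TruncatedGamma(2.333, 0.02547, 1, double.PositiveInfinity);
GammaPower toPow = PowerOp.PowAverageConditional(x, 1.1209480955953663, GammaPower.Uniform(-1));
// toPow.IsProper() is expected to be true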
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="PowerOp"]/message_doc[@name="XAverageConditional(GammaPower, TruncatedGamma, double)"]/*'/>
public static TruncatedGamma XAverageConditional([SkipIfUniform] GammaPower pow, TruncatedGamma x, double y)
{
// message computed below should be uniform when pow is uniform, but may not be due to roundoff error.
if (pow.IsUniform()) return TruncatedGamma.Uniform();
// Factor is (x^y)^(pow.Shape/pow.Power - 1) * exp(-pow.Rate*(x^y)^(1/pow.Power))
// =propto x^(pow.Shape/(pow.Power/y) - y) * exp(-pow.Rate*x^(y/pow.Power))
// newShape/(pow.Power/y) - 1 = pow.Shape/(pow.Power/y) - y
// newShape = pow.Shape + (1-y)*(pow.Power/y)
double power = pow.Power / y;
var toPow = PowAverageConditional(x, y, pow);
var powMarginal = pow * toPow;
// xMarginal2 is the exact distribution of pow^(1/y) where pow has distribution powMarginal
GammaPower xMarginal2 = GammaPower.FromShapeAndRate(powMarginal.Shape, powMarginal.Rate, power);
var xMarginal = new TruncatedGamma(GammaFromGammaPower(xMarginal2), x.LowerBound, x.UpperBound);
var result = xMarginal;
result.SetToRatio(xMarginal, x, GammaProductOp_Laplace.ForceProper);
return result;
@ -92,14 +153,6 @@ namespace Microsoft.ML.Probabilistic.Factors
return result;
}
public static Gamma GammaFromGammaPower(GammaPower message)
{
if (message.Power == 1) return Gamma.FromShapeAndRate(message.Shape, message.Rate); // same as below, but faster
if (message.IsUniform()) return Gamma.Uniform();
message.GetMeanAndVariance(out double mean, out double variance);
return Gamma.FromMeanAndVariance(mean, variance);
}
// GammaPower = GammaPower ^ y /////////////////////////////////////////////////////////
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="PowerOp"]/message_doc[@name="LogAverageFactor(GammaPower, GammaPower, double)"]/*'/>
@ -141,13 +194,6 @@ namespace Microsoft.ML.Probabilistic.Factors
return result;
}
public static Gamma FromMeanPowerAndMeanLog(double meanPower, double meanLog, double power)
{
// We want E[log(x)] = meanLog but this sets E[log(x^power)] = meanLog, so we scale meanLog
var gammaPower = GammaPower.FromMeanAndMeanLog(meanPower, meanLog * power, power);
return Gamma.FromShapeAndRate(gammaPower.Shape, gammaPower.Rate);
}
public static GammaPower GammaPowerFromDifferentPower(GammaPower message, double newPower)
{
if (message.Power == newPower) return message; // same as below, but faster

View file

@ -154,7 +154,6 @@ namespace Microsoft.ML.Probabilistic.Factors
if (double.IsNaN(currPoint)) throw new ArgumentException("currPoint is NaN");
if (double.IsInfinity(currPoint)) throw new ArgumentException("currPoint is infinite");
if (double.IsNaN(currDeriv)) throw new ArgumentException("currDeriv is NaN");
if (double.IsInfinity(currDeriv)) throw new ArgumentException("currDeriv is infinite");
if (hasPrevious)
{
double prevStep = currPoint - prevPoint;

View file

@ -53,10 +53,10 @@ log_1_minus_variable_name = "expx"
log_1_minus_indent = " "
x_minus_log_1_plus_series_length = 7
x_minus_log_1_plus_variable_name = "xOverAMinus1"
x_minus_log_1_plus_variable_name = "x"
x_minus_log_1_plus_indent = " "
exp_minus_1_series_length = 5
exp_minus_1_series_length = 9
exp_minus_1_variable_name = "x"
exp_minus_1_indent = " "
@ -126,14 +126,17 @@ def print_polynomial_with_rational_coefficients(varname, coefficients, indent):
idx = 1
parentheses = 0
print(indent, end='')
while idx < last_non_zero_idx:
while idx <= last_non_zero_idx:
print(f"{varname} * ", end='')
if coefficients[idx] != 0:
print(f"({format_rational_coefficient(coefficients[idx])} +")
if idx < last_non_zero_idx:
suffix = ' +'
else:
suffix = ''
print(f"({format_rational_coefficient(coefficients[idx])}{suffix}")
print(indent, end='')
parentheses = parentheses + 1
idx = idx + 1
print(f"{varname} * {format_rational_coefficient(coefficients[last_non_zero_idx])}", end='')
for i in range(0, parentheses):
print(")", end='')
print()
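These settings appear to control truncated Taylor series; for reference, the exp_minus_1 series whose length grows from 5 to 9 terms is
exp(x) - 1 = x + x^2/2! + x^3/3! + ... + x^n/n!
and the loop change above prints such a polynomial in nested (Horner) form, x * (1 + x * (1/2 + x * (1/6 + ...))), closing all parentheses at the end.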

View file

@ -75,13 +75,12 @@ namespace TestApp
//InferenceEngine.DefaultEngine.Compiler.UseLocals = false;
TestUtils.SetDebugOptions();
TestUtils.SetBrowserMode(BrowserMode.OnError);
TestUtils.SetBrowserMode(BrowserMode.Always);
//TestUtils.SetBrowserMode(BrowserMode.Always);
//TestUtils.SetBrowserMode(BrowserMode.WriteFiles);
Stopwatch watch = new Stopwatch();
watch.Start();
if (false)
{
// Run all tests (need to run in 64-bit else OutOfMemory due to loading many DLLs)

View file

@ -33,6 +33,23 @@ namespace Microsoft.ML.Probabilistic.Tests
public double[,] pairs;
};
[Fact]
public void ToStringExactTest()
{
Assert.Equal("0", MMath.ToStringExact(0));
Assert.Equal("NaN", MMath.ToStringExact(double.NaN));
Assert.Equal(double.MaxValue, double.Parse(MMath.ToStringExact(double.MaxValue)));
Assert.Equal(double.MinValue, double.Parse(MMath.ToStringExact(double.MinValue)));
Assert.Equal("10.5", MMath.ToStringExact(10.5));
Assert.Equal(10.05, double.Parse(MMath.ToStringExact(10.05)));
Assert.Equal("0.100000000000000002505909183520875968569614680770370524992534231990046604318405148467630281218195010089496230627027825414891031146499880413081224609160619018271942662793458427551041478278701507022263926060379361392435977509403014386614147912551359088259101734169222292122040491862182202915561954185941852588326204092831631787205015401996986616948980410676557942431921652541808732242554300585073938340203330993157646467433638479065531661724812599598594906293782493759617177861888792970476530542335134710418229637566637950767497147854236589795152044892049176025289756709261767081824924720105632337755616538050643653812583050224659631159300563236507929025398878153811554013986009587978081167432804936359631140419153283449560376539011485874652862548828125e-299", MMath.ToStringExact(1e-300));
Assert.Equal("0.988131291682493088353137585736442744730119605228649528851171365001351014540417503730599672723271984759593129390891435461853313420711879592797549592021563756252601426380622809055691634335697964207377437272113997461446100012774818307129968774624946794546339230280063430770796148252477131182342053317113373536374079120621249863890543182984910658610913088802254960259419999083863978818160833126649049514295738029453560318710477223100269607052986944038758053621421498340666445368950667144166486387218476578691673612021202301233961950615668455463665849580996504946155275185449574931216955640746893939906729403594535543517025132110239826300978220290207572547633450191167477946719798732961988232841140527418055848553508913045817507736501283943653106689453125e-322", MMath.ToStringExact(1e-322));
Assert.Equal("0.4940656458412465441765687928682213723650598026143247644255856825006755072702087518652998363616359923797965646954457177309266567103559397963987747960107818781263007131903114045278458171678489821036887186360569987307230500063874091535649843873124733972731696151400317153853980741262385655911710266585566867681870395603106249319452715914924553293054565444011274801297099995419319894090804165633245247571478690147267801593552386115501348035264934720193790268107107491703332226844753335720832431936092382893458368060106011506169809753078342277318329247904982524730776375927247874656084778203734469699533647017972677717585125660551199131504891101451037862738167250955837389733598993664809941164205702637090279242767544565229087538682506419718265533447265625e-323", MMath.ToStringExact(double.Epsilon));
Assert.Equal(1e-300, double.Parse(MMath.ToStringExact(1e-300)));
Assert.Equal(1e-322, double.Parse(MMath.ToStringExact(1e-322)));
Assert.Equal(double.Epsilon, double.Parse(MMath.ToStringExact(double.Epsilon)));
}
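The very long expected strings are intentional: every finite double equals m * 2^e for an integer m with |m| < 2^53, so it has an exact, terminating decimal expansion, and that expansion is what ToStringExact prints. For tiny values like 1e-322 it runs to hundreds of digits, and (as the assertions check) parsing it back recovers the original double exactly.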
[Fact]
public void GammaSpecialFunctionsTest()
{
@ -48,8 +65,22 @@ namespace Microsoft.ML.Probabilistic.Tests
};
CheckFunctionValues("BesselI", MMath.BesselI, BesselI_pairs);
/* In python mpmath:
from mpmath import *
mp.dps = 500
mp.pretty = True
gamma(mpf('8.5'))
*/
double[,] Gamma_pairs = new double[,]
{
{1e-18, 999999999999999999.422784335 },
{1e-17, 99999999999999999.422784335 },
{1e-16, 9999999999999999.422784335 },
{1e-15, 999999999999999.422784335 },
{1e-14, 99999999999999.422784335 },
{1e-13, 9999999999999.422784335 },
{1e-12, 999999999999.422784335 },
{1e-11, 99999999999.4227843351 },
{System.Math.Pow(0.5,20), 1048575.42278527833494202474 },
{System.Math.Pow(0.5,15), 32767.42281451784694671242432333 },
{System.Math.Pow(0.5,10), 1023.4237493455678303694987182399 },
@ -396,8 +427,28 @@ log(1-exp(mpf('-3')))
double[,] expminus1_pairs = new double[,]
{
{0, 0},
{1e-1, 0.1051709180756476248117 },
{6e-2, 0.061836546545359622224684877 },
{5.6e-2, 0.057597683736611251657434658737 },
{5.5e-2, 0.0565406146754942858469448477 },
{5.4e-2, 0.05548460215508004058489867657 },
{5.3e-2, 0.054429645119355907456004582 },
{5.2e-2, 0.05337574251336476282304 },
{5.1e-2, 0.05232289328320391286964 },
{5e-2, 0.0512710963760240396975 },
{1e-2, 0.0100501670841680575421654569 },
{3e-3, 0.003004504503377026012934 },
{2e-3, 0.00200200133400026675558 },
{1e-3, 0.001000500166708341668 },
{1e-4, 0.100005000166670833416668e-3},
{-1e-4, -0.9999500016666250008333e-4},
{-1e-3, -0.000999500166625008331944642832344 },
{-1e-2, -0.009950166250831946426094 },
{-2e-2, -0.01980132669324469777918589577469 },
{-3e-2, -0.029554466451491823067471648 },
{-4e-2, -0.0392105608476767905607893 },
{-5e-2, -0.04877057549928599090857468 },
{-1e-1, -0.09516258196404042683575 },
{Double.PositiveInfinity, Double.PositiveInfinity},
{Double.NegativeInfinity, -1},
{Double.NaN, Double.NaN}
@ -1014,6 +1065,13 @@ ncdf(-12.2)
[Fact]
public void GammaUpperTest()
{
double[,] gammaUpperScale_pairs = new double[,]
{
{100,3, 2.749402805834002258937858149557e-110},
{1e30,1.0000000000000024E+30, 22798605571598.2221521928234647 },
};
CheckFunctionValues(nameof(MMath.GammaUpperScale), MMath.GammaUpperScale, gammaUpperScale_pairs);
double[,] gammaLower_pairs = new double[,] {
{1e-6,1e-1,0.9999981770769746499},
{0.05,3e-20,0.1085221036950261},
@ -1037,7 +1095,7 @@ ncdf(-12.2)
{double.PositiveInfinity,1,0 },
{double.Epsilon,0,0 },
};
CheckFunctionValues("GammaLower", MMath.GammaLower, gammaLower_pairs);
CheckFunctionValues(nameof(MMath.GammaLower), MMath.GammaLower, gammaLower_pairs);
/* In python mpmath:
from mpmath import *
@ -1085,18 +1143,28 @@ gammainc(mpf('1'),mpf('1'),mpf('inf'),regularized=True)
{double.PositiveInfinity,1,1 },
{double.Epsilon,0,1 },
};
CheckFunctionValues("GammaUpperRegularized", (a,x) => MMath.GammaUpper(a, x, true), gammaUpperRegularized_pairs);
//CheckFunctionValues("GammaUpperRegularized", (a,x) => MMath.GammaUpper(a, x, true), gammaUpperRegularized_pairs);
/* In python mpmath:
from mpmath import *
mp.dps = 500
mp.pretty = True
gammainc(mpf('1'),mpf('1'),mpf('inf'),regularized=True)
gammainc(mpf('1'),mpf('1'),mpf('inf'),regularized=False)
*/
double[,] gammaUpper_pairs = new double[,] {
{1e-20,0.3,0.9056766516758467124267199175638778988963728798249333},
{1e-6,1e-1,1.8229219731321746872065707723366373632},
{2,1,0.73575888234288464319104754 },
{1, 1e-1, 0.9048374180359595731642490594464366211947 },
{0.5, 1e-1, 1.160462484793744246763365832264165338881 },
{1e-1, 1e-1, 1.6405876628018872105125369365484 },
{1e-2, 1e-1, 1.803241356902497279052687858810883 },
{1e-3, 1e-1, 1.8209403811279321641732411796 },
{1e-4, 1e-1, 1.8227254466517034567872146606649738 },
{1e-5, 1e-1, 1.822904105701365262441009216994 },
{1e-6, 1e-1, 1.8229219731321746872065707723366373632},
{1e-20,0.3,0.9056766516758467124267199175638778988963728798249333},
{1e-20, 1, 0.21938393439552027367814220743228835 },
{1e-20, 2, 0.048900510708061119567721436 },
{9.8813129168249309E-324, 4.94065645841247E-323, 741.5602711634856828468858990353816714074565 },
};
CheckFunctionValues("GammaUpper", (a, x) => MMath.GammaUpper(a, x, false), gammaUpper_pairs);
}
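On the new GammaUpperScale cases: assuming GammaUpperScale(s, x) = x^s * exp(-x) / Gamma(s), the prefactor that appears in the series and continued-fraction expansions of the incomplete gamma functions, the first pair is consistent: 3^100 * exp(-3) / 99! is approximately 2.7494e-110, matching the expected value above.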
@ -1896,7 +1964,7 @@ exp(x*x/4)*pcfu(0.5+n,-x)
long ticks2 = watch.ElapsedTicks;
bool overtime = ticks > 10 * ticks2;
if (double.IsNaN(result1) /*|| overtime*/)
Trace.WriteLine($"({x:r},{y:r},{r:r},{x-r*y}): {good} {ticks} {ticks2} {result1} {result2}");
Trace.WriteLine($"({x:g17},{y:g17},{r:g17},{x-r*y}): {good} {ticks} {ticks2} {result1} {result2}");
}
}
}
@ -2310,7 +2378,7 @@ exp(x*x/4)*pcfu(0.5+n,-x)
r = 0.1;
Trace.WriteLine($"(x,y,r) = {x:r}, {y:r}, {r:r}");
Trace.WriteLine($"(x,y,r) = {x:g17}, {y:g17}, {r:g17}");
double intZOverZ;
try
@ -2321,7 +2389,7 @@ exp(x*x/4)*pcfu(0.5+n,-x)
{
intZOverZ = double.NaN;
}
Trace.WriteLine($"intZOverZ = {intZOverZ:r}");
Trace.WriteLine($"intZOverZ = {intZOverZ:g17}");
double intZ0 = NormalCdfIntegralBasic(x, y, r);
double intZ1 = 0; // NormalCdfIntegralFlip(x, y, r);
@ -2337,7 +2405,7 @@ exp(x*x/4)*pcfu(0.5+n,-x)
intZ = ExtendedDouble.NaN();
}
//double intZ = intZ0;
Trace.WriteLine($"intZ = {intZ:r} {intZ.ToDouble():r} {intZ0:r} {intZ1:r} {intZr:r}");
Trace.WriteLine($"intZ = {intZ:g17} {intZ.ToDouble():g17} {intZ0:g17} {intZ1:g17} {intZr:g17}");
if (intZ.Mantissa < 0) throw new Exception();
//double intZ2 = NormalCdfIntegralBasic(y, x, r);
//Trace.WriteLine($"intZ2 = {intZ2} {r*intZ}");
@ -2783,7 +2851,7 @@ exp(x*x/4)*pcfu(0.5+n,-x)
if (i % 2 == 1)
{
result = -numer / denom;
Console.WriteLine($"iter {i}: {result:r} {c:g4}");
Console.WriteLine($"iter {i}: {result:g17} {c:g4}");
if (double.IsInfinity(result) || double.IsNaN(result))
throw new Exception($"NormalCdfConFrac5 not converging for x={x} y={y} r={r}");
if (result == rOld)
@ -24016,7 +24084,7 @@ exp(x*x/4)*pcfu(0.5+n,-x)
double result = (double)Util.DynamicInvoke(fcn, args);
if (!double.IsNaN(result) && System.Math.Sign(result) != System.Math.Sign(fx) && fx != 0)
{
string strMsg = $"{name}({x:r})\t has wrong sign (result = {result:r})";
string strMsg = $"{name}({x:g17})\t has wrong sign (result = {result:g17})";
Trace.WriteLine(strMsg);
Assert.True(false, strMsg);
}
@ -24033,11 +24101,11 @@ exp(x*x/4)*pcfu(0.5+n,-x)
}
if (err < TOLERANCE)
{
Trace.WriteLine($"{name}({x:r})\t ok");
Trace.WriteLine($"{name}({x:g17})\t ok");
}
else
{
string strMsg = $"{name}({x:r})\t wrong by {err.ToString("g2")} (result = {result:r})";
string strMsg = $"{name}({x:g17})\t wrong by {err.ToString("g2")} (result = {result:g17})";
Trace.WriteLine(strMsg);
if (err > assertTolerance || double.IsNaN(err))
Assert.True(false, strMsg);

View file

@ -133,7 +133,46 @@ namespace Microsoft.ML.Probabilistic.Tests
(mode <= double.Epsilon && gammaPower.GetLogProb(smallestNormalized) >= max)
);
Interlocked.Add(ref count, 1);
if(count % 100000 == 0)
if (count % 100000 == 0)
Trace.WriteLine($"{count} cases passed");
});
Trace.WriteLine($"{count} cases passed");
}
[Fact]
public void TruncatedGamma_GetMode_MaximizesGetLogProb()
{
long count = 0;
Parallel.ForEach(OperatorTests.TruncatedGammas().Take(100000), dist =>
{
double argmax = double.NaN;
double max = double.NegativeInfinity;
foreach (var x in OperatorTests.DoublesAtLeastZero())
{
double logProb = dist.GetLogProb(x);
Assert.False(double.IsNaN(logProb));
if (logProb > max)
{
max = logProb;
argmax = x;
}
}
double mode = dist.GetMode();
Assert.False(double.IsNaN(mode));
double logProbBelowMode = dist.GetLogProb(MMath.PreviousDouble(mode));
Assert.False(double.IsNaN(logProbBelowMode));
double logProbAboveMode = dist.GetLogProb(MMath.NextDouble(mode));
Assert.False(double.IsNaN(logProbAboveMode));
double logProbAtMode = dist.GetLogProb(mode);
Assert.False(double.IsNaN(logProbAtMode));
logProbAtMode = System.Math.Max(System.Math.Max(logProbAtMode, logProbAboveMode), logProbBelowMode);
const double smallestNormalized = 1e-308;
Assert.True(logProbAtMode >= max ||
MMath.AbsDiff(logProbAtMode, max, 1e-8) < 1e-8 ||
(mode <= double.Epsilon && dist.GetLogProb(smallestNormalized) >= max)
);
Interlocked.Add(ref count, 1);
if (count % 100000 == 0)
Trace.WriteLine($"{count} cases passed");
});
Trace.WriteLine($"{count} cases passed");
@ -188,7 +227,7 @@ namespace Microsoft.ML.Probabilistic.Tests
[Fact]
public void GammaPowerMeanAndVarianceFuzzTest()
{
foreach(var gammaPower in OperatorTests.GammaPowers().Take(100000))
foreach (var gammaPower in OperatorTests.GammaPowers().Take(100000))
{
gammaPower.GetMeanAndVariance(out double mean, out double variance);
Assert.False(double.IsNaN(mean));
@ -289,6 +328,9 @@ namespace Microsoft.ML.Probabilistic.Tests
g = new TruncatedGamma(2, 1, 3, 3);
Assert.True(g.IsPointMass);
Assert.Equal(3.0, g.Point);
g = new TruncatedGamma(Gamma.FromShapeAndRate(4.94065645841247E-324, 4.94065645841247E-324), 0, 1e14);
Assert.True(g.Sample() >= 0);
}
/// <summary>
@ -303,7 +345,7 @@ namespace Microsoft.ML.Probabilistic.Tests
{
TruncatedGamma g = new TruncatedGamma(1, System.Math.Exp(-i), target, double.PositiveInfinity);
var mean = g.GetMean();
Console.WriteLine($"GetNormalizer = {g.GetNormalizer()} GetMean = {g.GetMean()}");
//Trace.WriteLine($"GetNormalizer = {g.GetNormalizer()} GetMean = {g.GetMean()}");
Assert.False(double.IsInfinity(mean));
Assert.False(double.IsNaN(mean));
double diff = System.Math.Abs(mean - target);
@ -318,7 +360,7 @@ namespace Microsoft.ML.Probabilistic.Tests
{
TruncatedGamma g = new TruncatedGamma(System.Math.Exp(i), 1, 0, target);
var mean = g.GetMean();
Console.WriteLine($"GetNormalizer = {g.GetNormalizer()} GetMean = {g.GetMean()}");
//Trace.WriteLine($"GetNormalizer = {g.GetNormalizer()} GetMean = {g.GetMean()}");
Assert.False(double.IsInfinity(mean));
Assert.False(double.IsNaN(mean));
double diff = System.Math.Abs(mean - target);
@ -341,13 +383,81 @@ namespace Microsoft.ML.Probabilistic.Tests
for (int i = 0; i < 100; i++)
{
var meanPower = g.GetMeanPower(-i);
Trace.WriteLine($"GetMeanPower({-i}) = {meanPower}");
//Trace.WriteLine($"GetMeanPower({-i}) = {meanPower}");
Assert.False(double.IsNaN(meanPower));
Assert.False(double.IsInfinity(meanPower));
if (i == 1) Assert.Equal(MMath.GammaUpper(shape-1, 1, false)/MMath.GammaUpper(shape, 1, false), meanPower, 1e-8);
if (i == 1) Assert.Equal(MMath.GammaUpper(shape - 1, 1, false) / MMath.GammaUpper(shape, 1, false), meanPower, 1e-8);
}
}
[Fact]
public void TruncatedGamma_GetMeanAndVariance_WithinBounds()
{
long count = 0;
Parallel.ForEach(OperatorTests.LowerTruncatedGammas()
.Take(100000), dist =>
{
dist.GetMeanAndVariance(out double mean, out double variance);
// Compiler.Quoter.Quote(dist)
Assert.True(mean >= dist.LowerBound);
Assert.True(mean <= dist.UpperBound);
Assert.Equal(mean, dist.GetMean());
Assert.True(variance >= 0);
Interlocked.Add(ref count, 1);
if (count % 100000 == 0)
Trace.WriteLine($"{count} cases passed");
});
Trace.WriteLine($"{count} cases passed");
}
[Fact]
[Trait("Category", "OpenBug")]
public void TruncatedGamma_GetMeanPower_WithinBounds()
{
var g = new TruncatedGamma(Gamma.FromShapeAndRate(4.94065645841247E-324, 4.94065645841247E-324), 0, 1e14);
Assert.True(g.GetMean() <= g.UpperBound);
for (int i = 0; i < 308; i++)
{
double power = System.Math.Pow(10, i);
//Trace.WriteLine($"GetMeanPower({power}) = {g.GetMeanPower(power)}");
Assert.True(g.GetMeanPower(power) <= g.UpperBound);
}
Assert.True(g.GetMeanPower(1.7976931348623157E+308) <= g.UpperBound);
Assert.True(new TruncatedGamma(Gamma.FromShapeAndRate(4.94065645841247E-324, 4.94065645841247E-324), 0, 1e9).GetMeanPower(1.7976931348623157E+308) <= 1e9);
Assert.True(new TruncatedGamma(Gamma.FromShapeAndRate(4.94065645841247E-324, 4.94065645841247E-324), 0, 1e6).GetMeanPower(1.7976931348623157E+308) <= 1e6);
Assert.True(new TruncatedGamma(Gamma.FromShapeAndRate(4.94065645841247E-324, 4.94065645841247E-324), 0, 100).GetMeanPower(4.94065645841247E-324) <= 100);
long count = 0;
Parallel.ForEach(OperatorTests.LowerTruncatedGammas()
.Take(100000), dist =>
{
foreach (var power in OperatorTests.Doubles())
{
if (dist.Gamma.Shape <= -power && dist.LowerBound == 0) continue;
double meanPower = dist.GetMeanPower(power);
if (power >= 0)
{
// Compiler.Quoter.Quote(dist)
Assert.True(meanPower >= System.Math.Pow(dist.LowerBound, power));
Assert.True(meanPower <= System.Math.Pow(dist.UpperBound, power));
}
else
{
Assert.True(meanPower <= System.Math.Pow(dist.LowerBound, power));
Assert.True(meanPower >= System.Math.Pow(dist.UpperBound, power));
}
if (power == 1)
{
Assert.Equal(meanPower, dist.GetMean());
}
}
Interlocked.Add(ref count, 1);
if (count % 100000 == 0)
Trace.WriteLine($"{count} cases passed");
});
Trace.WriteLine($"{count} cases passed");
}
[Fact]
public void GaussianTest()
{

View file

@ -39,7 +39,15 @@ namespace Microsoft.ML.Probabilistic.Tests
}
[Fact]
public void GammaPower_ReturnsShapeGreaterThan1()
public void TruncatedGammaPowerTest()
{
Assert.True(PowerOp.PowAverageConditional(new TruncatedGamma(2.333, 0.02547, 1, double.PositiveInfinity), 1.1209480955953663, GammaPower.Uniform(-1)).IsProper());
Assert.True(PowerOp.PowAverageConditional(new TruncatedGamma(5.196e+48, 5.567e-50, 1, double.PositiveInfinity), 0.0016132617913803061, GammaPower.Uniform(-1)).IsProper());
Assert.True(PowerOp.PowAverageConditional(new TruncatedGamma(23.14, 0.06354, 1, double.PositiveInfinity), 1.5543122344752203E-15, GammaPower.Uniform(-1)).IsProper());
}
[Fact]
public void TruncatedGammaPower_ReturnsGammaShapeGreaterThan1()
{
Variable<TruncatedGamma> xPriorVar = Variable.Observed(default(TruncatedGamma)).Named("xPrior");
Variable<double> x = Variable<double>.Random(xPriorVar).Named("x");
@ -58,12 +66,13 @@ namespace Microsoft.ML.Probabilistic.Tests
Gamma yLike = Gamma.Uniform();
yLikeVar.ObservedValue = yLike;
power.ObservedValue = powerValue;
var xActual = engine.Infer<TruncatedGamma>(x);
var yActual = engine.Infer<Gamma>(y);
// Importance sampling
GammaEstimator xEstimator = new GammaEstimator();
GammaEstimator yEstimator = new GammaEstimator();
MeanVarianceAccumulator mva = new MeanVarianceAccumulator();
MeanVarianceAccumulator yExpectedInverse = new MeanVarianceAccumulator();
int nSamples = 1000000;
for (int i = 0; i < nSamples; i++)
{
@ -73,13 +82,74 @@ namespace Microsoft.ML.Probabilistic.Tests
double weight = System.Math.Exp(logWeight);
xEstimator.Add(xSample, weight);
yEstimator.Add(ySample, weight);
mva.Add(1/ySample, weight);
yExpectedInverse.Add(1/ySample, weight);
}
Gamma xExpected = xEstimator.GetDistribution(new Gamma());
Gamma yExpected = yEstimator.GetDistribution(yLike);
double yActualMeanInverse = yActual.GetMeanPower(-1);
double meanInverseError = MMath.AbsDiff(mva.Mean, yActualMeanInverse, 1e-8);
Trace.WriteLine($"power = {powerValue}: y = {yActual}[E^-1={yActual.GetMeanPower(-1)}] should be {yExpected}[E^-1={mva.Mean}], error = {meanInverseError}");
double meanInverseError = MMath.AbsDiff(yExpectedInverse.Mean, yActualMeanInverse, 1e-8);
Trace.WriteLine($"power = {powerValue}:");
Trace.WriteLine($" x = {xActual} should be {xExpected}");
Trace.WriteLine($" y = {yActual}[E^-1={yActual.GetMeanPower(-1)}] should be {yExpected}[E^-1={yExpectedInverse.Mean}], E^-1 error = {meanInverseError}");
Assert.True(yActual.Shape > 1);
Assert.True(MMath.AbsDiff(yExpected.GetMean(), yActual.GetMean(), 1e-8) < 1);
Assert.True(meanInverseError < 1e-2);
}
}
[Fact]
public void TruncatedGammaPower_ReturnsGammaPowerShapeGreaterThan1()
{
var result = PowerOp.PowAverageConditional(new TruncatedGamma(0.4, 0.5, 1, double.PositiveInfinity), 0, GammaPower.PointMass(0, -1));
Assert.True(result.IsPointMass);
Assert.Equal(1.0, result.Point);
Variable<TruncatedGamma> xPriorVar = Variable.Observed(default(TruncatedGamma)).Named("xPrior");
Variable<double> x = Variable<double>.Random(xPriorVar).Named("x");
Variable<double> power = Variable.Observed(0.5).Named("power");
var y = x ^ power;
y.Name = nameof(y);
Variable<GammaPower> yLikeVar = Variable.Observed(default(GammaPower)).Named("yLike");
Variable.ConstrainEqualRandom(y, yLikeVar);
y.SetMarginalPrototype(yLikeVar);
InferenceEngine engine = new InferenceEngine();
foreach (var powerValue in linspace(1, 10, 10))
{
TruncatedGamma xPrior = new TruncatedGamma(Gamma.FromShapeAndRate(3, 3), 1, double.PositiveInfinity);
xPriorVar.ObservedValue = xPrior;
GammaPower yLike = GammaPower.Uniform(-1);
//GammaPower yLike = GammaPower.FromShapeAndRate(1, 0.5, -1);
yLikeVar.ObservedValue = yLike;
power.ObservedValue = powerValue;
var xActual = engine.Infer<TruncatedGamma>(x);
var yActual = engine.Infer<GammaPower>(y);
// Importance sampling
GammaEstimator xEstimator = new GammaEstimator();
GammaPowerEstimator yEstimator = new GammaPowerEstimator(yLike.Power);
MeanVarianceAccumulator yExpectedInverse = new MeanVarianceAccumulator();
MeanVarianceAccumulator yMva = new MeanVarianceAccumulator();
int nSamples = 1000000;
for (int i = 0; i < nSamples; i++)
{
double xSample = xPrior.Sample();
double ySample = System.Math.Pow(xSample, power.ObservedValue);
double logWeight = yLike.GetLogProb(ySample);
double weight = System.Math.Exp(logWeight);
xEstimator.Add(xSample, weight);
yEstimator.Add(ySample, weight);
yExpectedInverse.Add(1 / ySample, weight);
yMva.Add(ySample, weight);
}
Gamma xExpected = xEstimator.GetDistribution(new Gamma());
GammaPower yExpected = yEstimator.GetDistribution(yLike);
yExpected = GammaPower.FromMeanAndVariance(yMva.Mean, yMva.Variance, yLike.Power);
double yActualMeanInverse = yActual.GetMeanPower(-1);
double meanInverseError = MMath.AbsDiff(yExpectedInverse.Mean, yActualMeanInverse, 1e-8);
Trace.WriteLine($"power = {powerValue}:");
Trace.WriteLine($" x = {xActual} should be {xExpected}");
Trace.WriteLine($" y = {yActual}[E^-1={yActual.GetMeanPower(-1)}] should be {yExpected}[E^-1={yExpectedInverse.Mean}], error = {meanInverseError}");
Assert.True(yActual.Shape > 1);
Assert.True(MMath.AbsDiff(yExpected.GetMean(), yActual.GetMean(), 1e-8) < 1);
Assert.True(meanInverseError < 1e-2);
@ -106,6 +176,7 @@ namespace Microsoft.ML.Probabilistic.Tests
GammaPower yLike = GammaPower.Uniform(-1);
yLikeVar.ObservedValue = yLike;
power.ObservedValue = powerValue;
var xActual = engine.Infer<GammaPower>(x);
var yActual = engine.Infer<GammaPower>(y);
// Importance sampling
@ -127,7 +198,9 @@ namespace Microsoft.ML.Probabilistic.Tests
Gamma yExpected = yEstimator.GetDistribution(new Gamma());
double yActualMeanInverse = yActual.GetMeanPower(-1);
double meanInverseError = MMath.AbsDiff(mva.Mean, yActualMeanInverse, 1e-8);
Trace.WriteLine($"power = {powerValue}: y = {yActual}[E^-1={yActualMeanInverse}] should be {yExpected}[E^-1={mva.Mean}], error = {meanInverseError}");
Trace.WriteLine($"power = {powerValue}:");
Trace.WriteLine($" x = {xActual} should be {xExpected}");
Trace.WriteLine($" y = {yActual}[E^-1={yActualMeanInverse}] should be {yExpected}[E^-1={mva.Mean}], error = {meanInverseError}");
Assert.True(yActual.Shape > 2);
Assert.True(MMath.AbsDiff(yExpected.GetMean(), yActual.GetMean(), 1e-8) < 1);
//Assert.True(meanInverseError < 10);
@ -1900,7 +1973,7 @@ namespace Microsoft.ML.Probabilistic.Tests
double aError = aExpected.MaxDiff(aActual);
double productError = productExpected.MaxDiff(productActual);
double evError = MMath.AbsDiff(evExpected, evActual, 1e-6);
bool trace = false;
bool trace = true;
if (trace)
{
Trace.WriteLine($"b = {bActual} should be {bExpected}, error = {bError}");
@ -1916,6 +1989,212 @@ namespace Microsoft.ML.Probabilistic.Tests
}
}
internal static void TestLogEvidence()
{
LogEvidenceScale(new GammaPower(100, 5.0 / 100, -1), new GammaPower(100, 2.0 / 100, -1), new GammaPower(100, 3.0 / 100, -1), 0.2);
}
internal static void LogEvidenceShift(GammaPower sum, GammaPower a, GammaPower b)
{
double logz100 = PlusGammaOp.LogAverageFactor(GammaPower.FromShapeAndRate(sum.Shape-1, sum.Rate, sum.Power), GammaPower.FromShapeAndRate(a.Shape, a.Rate, a.Power), GammaPower.FromShapeAndRate(b.Shape, b.Rate, b.Power));
double logz010 = PlusGammaOp.LogAverageFactor(GammaPower.FromShapeAndRate(sum.Shape, sum.Rate, sum.Power), GammaPower.FromShapeAndRate(a.Shape-1, a.Rate, a.Power), GammaPower.FromShapeAndRate(b.Shape, b.Rate, b.Power));
double logz001 = PlusGammaOp.LogAverageFactor(GammaPower.FromShapeAndRate(sum.Shape, sum.Rate, sum.Power), GammaPower.FromShapeAndRate(a.Shape, a.Rate, a.Power), GammaPower.FromShapeAndRate(b.Shape-1, b.Rate, b.Power));
double lhs = logz100 + System.Math.Log(sum.Rate / (sum.Shape - 1));
double rhs1 = logz010 + System.Math.Log(a.Rate / (a.Shape - 1));
double rhs2 = logz001 + System.Math.Log(b.Rate / (b.Shape - 1));
Trace.WriteLine($"lhs = {lhs} rhs = {MMath.LogSumExp(rhs1, rhs2)}");
}
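The consistency check above rests on the Gamma shape-shift identity
Ga(x; a-1, b) * b/(a-1) = Ga(x; a, b) / x.
A sketch for the power = -1 case exercised by TestLogEvidence: with power -1 the random variable is z = 1/x, so lowering a shape by one and multiplying by rate/(shape - 1) weights the corresponding density by z. Hence lhs integrates sum * p(sum) p(a) p(b) under the constraint sum = a + b, while rhs1 and rhs2 integrate a * (...) and b * (...); because sum = a + b, lhs should equal logsumexp(rhs1, rhs2).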
internal static void LogEvidenceScale(GammaPower sum, GammaPower a, GammaPower b, double scale)
{
double logZ = LogEvidenceBrute(sum, a, b);
double logZ2 = System.Math.Log(scale) + LogEvidenceBrute(GammaPower.FromShapeAndRate(sum.Shape, scale * sum.Rate, sum.Power), GammaPower.FromShapeAndRate(a.Shape, scale * a.Rate, a.Power), GammaPower.FromShapeAndRate(b.Shape, scale * b.Rate, b.Power));
Trace.WriteLine($"logZ = {logZ} {logZ2}");
}
internal static double LogEvidenceBrute(GammaPower sumPrior, GammaPower aPrior, GammaPower bPrior)
{
double totalWeight = 0;
int numIter = 1000000;
for (int iter = 0; iter < numIter; iter++)
{
if (iter % 1000000 == 0) Trace.WriteLine($"iter = {iter}");
double bSample = bPrior.Sample();
double aSample = aPrior.Sample();
if (sumPrior.Rate > 1e100)
{
bSample = 0;
aSample = 0;
}
double sumSample = aSample + bSample;
double logWeight = sumPrior.GetLogProb(sumSample);
double weight = System.Math.Exp(logWeight);
totalWeight += weight;
}
Trace.WriteLine($"totalWeight = {totalWeight}");
return System.Math.Log(totalWeight / numIter);
}
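Apart from the special case for extreme Rate values, this is plain importance sampling from the two priors:
Z = integral of p_sum(a + b) p_a(a) p_b(b) da db ~= (1/N) * sum_i p_sum(a_i + b_i), with a_i ~ p_a and b_i ~ p_b,
and the method returns the log of that average weight.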
internal static double LogEvidenceIncrementBShape(GammaPower sum, GammaPower a, GammaPower b)
{
const double threshold = 0;
if (b.Shape > threshold)
{
//return PlusGammaOp.LogAverageFactor(sum, a, b);
return LogEvidenceBrute(sum, a, b);
}
double logz100 = LogEvidenceIncrementBShape(GammaPower.FromShapeAndRate(sum.Shape - 1, sum.Rate, sum.Power), GammaPower.FromShapeAndRate(a.Shape, a.Rate, a.Power), GammaPower.FromShapeAndRate(b.Shape + 1, b.Rate, b.Power));
double logz010 = LogEvidenceIncrementBShape(GammaPower.FromShapeAndRate(sum.Shape, sum.Rate, sum.Power), GammaPower.FromShapeAndRate(a.Shape - 1, a.Rate, a.Power), GammaPower.FromShapeAndRate(b.Shape + 1, b.Rate, b.Power));
double lhs = logz100 + System.Math.Log(sum.Rate / (sum.Shape - 1));
double rhs1 = logz010 + System.Math.Log(a.Rate / (a.Shape - 1));
double rhs2 = System.Math.Log(b.Rate / b.Shape);
return MMath.LogDifferenceOfExp(lhs, rhs1) - rhs2;
}
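The recursion applies the same shift identity as in the sketch above, now centered at shapes (sum.Shape, a.Shape, b.Shape + 1): lhs = logsumexp(rhs1, logZ(sum, a, b) + log(b.Rate / b.Shape)), so
logZ(sum, a, b) = logdiffexp(lhs, rhs1) - log(b.Rate / b.Shape),
which is the value returned. Each level raises b.Shape by one until it exceeds the threshold, where the brute-force estimate takes over.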
[Fact]
[Trait("Category", "OpenBug")]
public void GammaPowerSumRRRTest()
{
//Assert.True(PlusGammaOp.AAverageConditional(GammaPower.FromShapeAndRate(299, 2135, -1), GammaPower.FromShapeAndRate(2.01, 10, -1), GammaPower.FromShapeAndRate(12, 22, -1), GammaPower.Uniform(-1)).Shape > 2);
//Assert.True(PlusGammaOp.AAverageConditional(GammaPower.Uniform(-1), GammaPower.FromShapeAndRate(2.0095439611576689, 43.241375394505766, -1), GammaPower.FromShapeAndRate(12, 11, -1), GammaPower.Uniform(-1)).IsUniform());
//Assert.False(double.IsNaN(PlusGammaOp.BAverageConditional(new GammaPower(287, 0.002132, -1), new GammaPower(1.943, 1.714, -1), new GammaPower(12, 0.09091, -1), GammaPower.Uniform(-1)).Shape));
Variable<bool> evidence = Variable.Bernoulli(0.5).Named("evidence");
IfBlock block = Variable.If(evidence);
Variable<GammaPower> bPriorVar = Variable.Observed(default(GammaPower)).Named("bPrior");
Variable<double> b = Variable<double>.Random(bPriorVar).Named("b");
Variable<GammaPower> aPriorVar = Variable.Observed(default(GammaPower)).Named("aPrior");
Variable<double> a = Variable<double>.Random(aPriorVar).Named("a");
Variable<double> sum = (a + b).Named("sum");
Variable<GammaPower> sumPriorVar = Variable.Observed(default(GammaPower)).Named("sumPrior");
Variable.ConstrainEqualRandom(sum, sumPriorVar);
block.CloseBlock();
InferenceEngine engine = new InferenceEngine();
var groundTruthArray = new[]
{
//((new GammaPower(12, 0.09091, -1), new GammaPower(1.943, 1.714, -1), new GammaPower(287, 0.002132, -1)),
// (GammaPower.FromShapeAndRate(23.445316648707465, 25.094880573396285, -1.0), GammaPower.FromShapeAndRate(6.291922598211336, 2.6711637040924909, -1.0), GammaPower.FromShapeAndRate(297.59289156399706, 481.31323394825631, -1.0), -0.517002984399292)),
//((GammaPower.FromShapeAndRate(12, 22, -1), GammaPower.FromShapeAndRate(2.01, 10, -1), GammaPower.FromShapeAndRate(299, 2135, -1)),
// (GammaPower.FromShapeAndRate(12.4019151884055, 23.487535138993064, -1.0), GammaPower.FromShapeAndRate(47.605465737960976, 236.41203334327037, -1.0), GammaPower.FromShapeAndRate(303.94717779788243, 2160.7976040127091, -1.0), -2.26178042225837)),
//((GammaPower.FromShapeAndRate(1, 2, 1), GammaPower.FromShapeAndRate(10, 10, 1), GammaPower.FromShapeAndRate(101, double.MaxValue, 1)),
// (GammaPower.PointMass(0, 1.0), GammaPower.FromShapeAndScale(9, 0.1, 1), GammaPower.PointMass(5.6183114927306835E-307, 1), 0.79850769622135)),
//((GammaPower.FromShapeAndRate(1, 2, 1), GammaPower.FromShapeAndRate(10, 10, 1), GammaPower.FromShapeAndRate(101, double.PositiveInfinity, 1)),
// (GammaPower.PointMass(0, 1.0), GammaPower.PointMass(0, 1.0), GammaPower.FromShapeAndRate(101, double.PositiveInfinity, 1), double.NegativeInfinity)),
//((GammaPower.FromShapeAndRate(2.25, 0.625, -1), GammaPower.FromShapeAndRate(100000002, 100000001, -1), GammaPower.PointMass(5, -1)),
// (GammaPower.FromShapeAndRate(1599999864.8654146, 6399999443.0866585, -1.0), GammaPower.FromShapeAndRate(488689405.117356, 488689405.88170129, -1.0), GammaPower.FromShapeAndRate(double.PositiveInfinity, 5.0, -1.0), -4.80649551611576)),
//((GammaPower.FromShapeAndRate(2.25, 0.625, -1), GammaPower.FromShapeAndRate(100000002, 100000001, -1), GammaPower.PointMass(0, -1)),
// (GammaPower.FromShapeAndRate(5.25, 0.625, -1.0), GammaPower.PointMass(0, -1.0), GammaPower.PointMass(0, -1), double.NegativeInfinity)),
((GammaPower.FromShapeAndRate(0.83228652924877289, 0.31928405884349487, -1), GammaPower.FromShapeAndRate(1.7184321234630087, 0.709692740551586, -1), GammaPower.FromShapeAndRate(491, 1583.0722891566263, -1)),
(GammaPower.FromShapeAndRate(5.6062357530254419, 8.7330355320375, -1.0), GammaPower.FromShapeAndRate(3.7704064465114597, 3.6618414405426956, -1.0), GammaPower.FromShapeAndRate(493.79911104976264, 1585.67297686381, -1.0), -2.62514943790608)),
//((GammaPower.FromShapeAndRate(1, 1, 1), GammaPower.FromShapeAndRate(1, 1, 1), GammaPower.Uniform(1)),
// (GammaPower.FromShapeAndRate(1, 1, 1), GammaPower.FromShapeAndRate(1, 1, 1), new GammaPower(2, 1, 1), 0)),
//((GammaPower.FromShapeAndRate(1, 1, 1), GammaPower.FromShapeAndRate(1, 1, 1), GammaPower.FromShapeAndRate(10, 1, 1)),
// (GammaPower.FromShapeAndRate(2.2, 0.8, 1), GammaPower.FromShapeAndRate(2.2, 0.8, 1), GammaPower.FromShapeAndRate(11, 2, 1), -5.32133409609914)),
//((GammaPower.FromShapeAndRate(3, 1, -1), GammaPower.FromShapeAndRate(4, 1, -1), GammaPower.Uniform(-1)),
// (GammaPower.FromShapeAndRate(3, 1, -1), GammaPower.FromShapeAndRate(4, 1, -1), GammaPower.FromShapeAndRate(4.311275674659143, 2.7596322350392035, -1.0), 0)),
//((GammaPower.FromShapeAndRate(3, 1, -1), GammaPower.FromShapeAndRate(4, 1, -1), GammaPower.FromShapeAndRate(10, 1, -1)),
// (new GammaPower(10.17, 0.6812, -1), new GammaPower(10.7, 0.7072, -1), new GammaPower(17.04, 0.2038, -1), -5.80097480415528)),
//((GammaPower.FromShapeAndRate(2, 1, -1), GammaPower.FromShapeAndRate(2, 1, -1), GammaPower.Uniform(-1)),
// (GammaPower.FromShapeAndRate(2, 1, -1), GammaPower.FromShapeAndRate(2, 1, -1), new GammaPower(2, 2, -1), 0)),
//((GammaPower.FromShapeAndRate(1, 1, -1), GammaPower.FromShapeAndRate(1, 1, -1), GammaPower.FromShapeAndRate(30, 1, -1)),
// (GammaPower.FromShapeAndRate(11.057594449558747, 2.0731054100295871, -1.0), GammaPower.FromShapeAndRate(11.213079710986863, 2.1031756133562678, -1.0), GammaPower.FromShapeAndRate(28.815751741667615, 1.0848182432207041, -1.0), -4.22210057295786)),
//((GammaPower.FromShapeAndRate(1, 1, 2), GammaPower.FromShapeAndRate(1, 1, 2), GammaPower.Uniform(2)),
// (GammaPower.FromShapeAndRate(1, 1, 2), GammaPower.FromShapeAndRate(1, 1, 2), GammaPower.FromShapeAndRate(0.16538410345846666, 0.219449497990138, 2.0), 0)),
//((GammaPower.FromShapeAndRate(1, 1, 2), GammaPower.FromShapeAndRate(1, 1, 2), GammaPower.FromShapeAndRate(30, 1, 2)),
// (GammaPower.FromShapeAndRate(8.72865708291647, 1.71734403810018, 2.0), GammaPower.FromShapeAndRate(8.5298603954575931, 1.6767026737490067, 2.0), GammaPower.FromShapeAndRate(25.831187278202215, 1.0852321896648485, 2.0), -14.5369973268808)),
};
//using (TestUtils.TemporarilyAllowGammaImproperProducts)
{
foreach (var groundTruth in groundTruthArray)
{
var (bPrior, aPrior, sumPrior) = groundTruth.Item1;
var (bExpected, aExpected, sumExpected, evExpected) = groundTruth.Item2;
bPriorVar.ObservedValue = bPrior;
aPriorVar.ObservedValue = aPrior;
sumPriorVar.ObservedValue = sumPrior;
GammaPower bActual = engine.Infer<GammaPower>(b);
GammaPower aActual = engine.Infer<GammaPower>(a);
GammaPower sumActual = engine.Infer<GammaPower>(sum);
double evActual = engine.Infer<Bernoulli>(evidence).LogOdds;
double logZ = LogEvidenceIncrementBShape(sumPrior, aPrior, bPrior);
Trace.WriteLine($"LogZ = {logZ}");
if (true)
{
// importance sampling
Rand.Restart(0);
double totalWeight = 0;
GammaPowerEstimator bEstimator = new GammaPowerEstimator(bPrior.Power);
GammaPowerEstimator aEstimator = new GammaPowerEstimator(aPrior.Power);
GammaPowerEstimator sumEstimator = new GammaPowerEstimator(sumPrior.Power);
MeanVarianceAccumulator bMva = new MeanVarianceAccumulator();
MeanVarianceAccumulator aMva = new MeanVarianceAccumulator();
MeanVarianceAccumulator sumMva = new MeanVarianceAccumulator();
int numIter = 10000000;
for (int iter = 0; iter < numIter; iter++)
{
if (iter % 1000000 == 0) Trace.WriteLine($"iter = {iter}");
double bSample = bPrior.Sample();
double aSample = aPrior.Sample();
if (sumPrior.Rate > 1e100)
{
bSample = 0;
aSample = 0;
}
double sumSample = aSample + bSample;
double logWeight = sumPrior.GetLogProb(sumSample);
double weight = System.Math.Exp(logWeight);
totalWeight += weight;
bEstimator.Add(bSample, weight);
aEstimator.Add(aSample, weight);
sumEstimator.Add(sumSample, weight);
bMva.Add(bSample, weight);
aMva.Add(aSample, weight);
sumMva.Add(sumSample, weight);
}
Trace.WriteLine($"totalWeight = {totalWeight}");
evExpected = System.Math.Log(totalWeight / numIter);
bExpected = bEstimator.GetDistribution(bPrior);
aExpected = aEstimator.GetDistribution(aPrior);
sumExpected = sumEstimator.GetDistribution(sumPrior);
bExpected = GammaPower.FromMeanAndVariance(bMva.Mean, bMva.Variance, bPrior.Power);
aExpected = GammaPower.FromMeanAndVariance(aMva.Mean, aMva.Variance, aPrior.Power);
sumExpected = GammaPower.FromMeanAndVariance(sumMva.Mean, sumMva.Variance, sumPrior.Power);
Trace.WriteLine($"{Quoter.Quote(bExpected)}, {Quoter.Quote(aExpected)}, {Quoter.Quote(sumExpected)}, {evExpected}");
}
else Trace.WriteLine($"{Quoter.Quote(bActual)}, {Quoter.Quote(aActual)}, {Quoter.Quote(sumActual)}, {evActual}");
double bError = MomentDiff(bExpected, bActual);
double aError = MomentDiff(aExpected, aActual);
double productError = MomentDiff(sumExpected, sumActual);
double evError = MMath.AbsDiff(evExpected, evActual, 1e-6);
bool trace = true;
if (trace)
{
Trace.WriteLine($"b = {bActual} should be {bExpected}, error = {bError}");
Trace.WriteLine($"a = {aActual}[variance={aActual.GetVariance()}] should be {aExpected}[variance={aExpected.GetVariance()}], error = {aError}");
Trace.WriteLine($"product = {sumActual} should be {sumExpected}, error = {productError}");
Trace.WriteLine($"evidence = {evActual} should be {evExpected}, error = {evError}");
}
Assert.True(bError < 3);
Assert.True(aError < 1);
Assert.True(productError < 1);
Assert.True(evError < 0.3);
}
}
}
public static double MomentDiff(GammaPower expected, GammaPower actual)
{
expected.GetMeanAndVariance(out double meanExpected, out double varianceExpected);
actual.GetMeanAndVariance(out double meanActual, out double varianceActual);
const double rel = 1e-8;
return System.Math.Max(MMath.AbsDiff(meanExpected, meanActual, rel), MMath.AbsDiff(varianceExpected, varianceActual, rel));
}
[Fact]
public void GammaCCRTest()
{

View file

@ -384,6 +384,9 @@ namespace Microsoft.ML.Probabilistic.Tests
[Fact]
public void GammaFromShapeAndRateOpTest()
{
Assert.False(double.IsNaN(GammaFromShapeAndRateOp_Slow.SampleAverageConditional(Gamma.PointMass(0), 2.0, new Gamma(1, 1)).Rate));
Assert.False(double.IsNaN(GammaFromShapeAndRateOp_Slow.RateAverageConditional(new Gamma(1, 1), 2.0, Gamma.PointMass(0)).Rate));
Gamma sample, rate, result;
double prevDiff;
double shape = 3;
@ -856,7 +859,7 @@ namespace Microsoft.ML.Probabilistic.Tests
{
double point = 3;
Gaussian toPoint = MaxGaussianOp.AAverageConditional(max, Gaussian.PointMass(point), b);
//Console.WriteLine($"{point} {toPoint} {toPoint.MeanTimesPrecision:r} {toPoint.Precision:r}");
//Console.WriteLine($"{point} {toPoint} {toPoint.MeanTimesPrecision:g17} {toPoint.Precision:g17}");
if (max.IsPointMass && b.IsPointMass)
{
Gaussian toUniform = MaxGaussianOp.AAverageConditional(max, Gaussian.Uniform(), b);
@ -874,7 +877,7 @@ namespace Microsoft.ML.Probabilistic.Tests
{
Gaussian a = Gaussian.FromMeanAndPrecision(point, System.Math.Pow(10, i));
Gaussian to_a = MaxGaussianOp.AAverageConditional(max, a, b);
//Console.WriteLine($"{a} {to_a} {to_a.MeanTimesPrecision:r} {to_a.Precision:r}");
//Console.WriteLine($"{a} {to_a} {to_a.MeanTimesPrecision:g17} {to_a.Precision:g17}");
double diff = toPoint.MaxDiff(to_a);
if (diff < 1e-14) diff = 0;
Assert.True(diff <= oldDiff);
@ -1362,16 +1365,6 @@ namespace Microsoft.ML.Probabilistic.Tests
});
}
[Fact]
[Trait("Category", "OpenBug")]
public void GammaUpper_IsDecreasingInX()
{
foreach (double a in DoublesGreaterThanZero())
{
IsIncreasingForAtLeastZero(x => -MMath.GammaUpper(a, x));
}
}
[Fact]
public void GammaLower_IsDecreasingInA()
{
@ -1382,13 +1375,21 @@ namespace Microsoft.ML.Probabilistic.Tests
}
[Fact]
[Trait("Category", "OpenBug")]
public void GammaUpper_IsDecreasingInX()
{
Parallel.ForEach(DoublesGreaterThanZero(), a =>
{
IsIncreasingForAtLeastZero(x => -MMath.GammaUpper(a, x));
});
}
[Fact]
public void GammaUpper_IsIncreasingInA()
{
foreach (double x in DoublesAtLeastZero())
Parallel.ForEach(DoublesAtLeastZero(), x =>
{
IsIncreasingForAtLeastZero(a => MMath.GammaUpper(a + double.Epsilon, x));
}
IsIncreasingForAtLeastZero(a => MMath.GammaUpper(a + double.Epsilon, x), 2);
});
}
[Fact]
@ -1454,18 +1455,22 @@ namespace Microsoft.ML.Probabilistic.Tests
/// </summary>
/// <param name="func"></param>
/// <returns></returns>
public bool IsIncreasingForAtLeastZero(Func<double, double> func)
public bool IsIncreasingForAtLeastZero(Func<double, double> func, double ulpError = 0)
{
double scale = 1 + 2e-16 * ulpError;
foreach (var x in DoublesAtLeastZero())
{
double fx = func(x);
double smallFx;
if (fx >= 0) smallFx = fx / scale;
else smallFx = fx * scale;
foreach (var delta in DoublesGreaterThanZero())
{
double x2 = x + delta;
if (double.IsPositiveInfinity(delta)) x2 = delta;
double fx2 = func(x2);
// The cast here is important when running in 32-bit, Release mode.
Assert.True((double)fx2 >= fx);
Assert.True((double)fx2 >= smallFx);
}
}
return true;
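For context on the tolerance: one ulp of a double near fx is roughly 2.2e-16 * |fx|, so scale = 1 + 2e-16 * ulpError and the comparison against smallFx let fx2 fall about ulpError ulps below fx before failing; with the default ulpError = 0 the check stays exact.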
@ -1552,7 +1557,7 @@ zL = (L - mx)*sqrt(prec)
//X = Gaussian.FromMeanAndPrecision(mx, X.Precision + 1.0000000000000011E-19);
Gaussian toX2 = DoubleIsBetweenOp.XAverageConditional(Bernoulli.PointMass(true), X, lowerBound, upperBound);
Gaussian xPost = X * toX2;
Console.WriteLine($"mx = {X.GetMean():r} mp = {xPost.GetMean():r} vp = {xPost.GetVariance():r} toX = {toX2}");
Console.WriteLine($"mx = {X.GetMean():g17} mp = {xPost.GetMean():g17} vp = {xPost.GetVariance():g17} toX = {toX2}");
//X.Precision *= 100;
//X.MeanTimesPrecision *= 0.999999;
//X.SetMeanAndPrecision(mx, X.Precision * 2);
@ -1671,7 +1676,7 @@ zL = (L - mx)*sqrt(prec)
// all we need is a good approx for (ZR/diff - 1)
double ZR5 = (1.0 / 6 * diffs * diffs * diffs * (-1 + zL * zL) + 0.5 * diffs * diffs * (-zL) + diffs) / sqrtPrec;
double ZR6 = (1.0 / 24 * diffs * diffs * diffs * diffs * (zL - zL * zL * zL + 2 * zL) + 1.0 / 6 * diffs * diffs * diffs * (-1 + zL * zL) + 0.5 * diffs * diffs * (-zL) + diffs) / sqrtPrec;
//Console.WriteLine($"zL = {zL:r} delta = {delta:r} (-zL-zU)/2*diffs={(-zL - zU) / 2 * diffs:r} diffs = {diffs:r} diffs*zL = {diffs * zL}");
//Console.WriteLine($"zL = {zL:g17} delta = {delta:g17} (-zL-zU)/2*diffs={(-zL - zU) / 2 * diffs:g17} diffs = {diffs:g17} diffs*zL = {diffs * zL}");
//Console.WriteLine($"Z/N = {ZR} {ZR2} {ZR2b} {ZR2c} asympt:{ZRasympt} {ZR4} {ZR5} {ZR6}");
// want to compute Z/X.Prob(L)/diffs + (exp(delta)-1)/delta
double expMinus1RatioMinus1RatioMinusHalf = MMath.ExpMinus1RatioMinus1RatioMinusHalf(delta);
@ -1741,7 +1746,7 @@ zL = (L - mx)*sqrt(prec)
// delta = 0.0002 diffs = 0.00014142135623731: bad
// delta = 2E-08 diffs = 1.4142135623731E-08: good
double numer5 = delta * diffs * diffs / 6 + diffs * diffs * diffs * diffs / 24 - 1.0 / 24 * delta * delta * delta - 1.0 / 120 * delta * delta * delta * delta + diffs * diffs / 12;
//Console.WriteLine($"numer = {numer} smallzL:{numer1SmallzL} largezL:{numerLargezL} {numerLargezL2} {numerLargezL3} {numerLargezL4:r} {numerLargezL5:r} {numerLargezL6:r} {numerLargezL7:r} {numerLargezL8:r} {numer1e} asympt:{numerAsympt} {numerAsympt2} {numer2} {numer3} {numer4} {numer5}");
//Console.WriteLine($"numer = {numer} smallzL:{numer1SmallzL} largezL:{numerLargezL} {numerLargezL2} {numerLargezL3} {numerLargezL4:g17} {numerLargezL5:g17} {numerLargezL6:g17} {numerLargezL7:g17} {numerLargezL8:g17} {numer1e} asympt:{numerAsympt} {numerAsympt2} {numer2} {numer3} {numer4} {numer5}");
double mp = mx - System.Math.Exp(logPhiL - logZ) * expMinus1 / X.Precision;
double mp2 = center + (delta / diff - System.Math.Exp(logPhiL - logZ) * expMinus1) / X.Precision;
double mp3 = center + (delta / diff * ZR2b - expMinus1) * System.Math.Exp(logPhiL - logZ) / X.Precision;
@ -1776,7 +1781,7 @@ zL = (L - mx)*sqrt(prec)
//WriteLast(mpSmallzUs);
double mp5 = center + numer5 * delta / diffs * alphaXcLprecDiffs;
//double mpBrute = Util.ArrayInit(10000000, i => X.Sample()).Where(sample => (sample > lowerBound) && (sample < upperBound)).Average();
//Console.WriteLine($"mp = {mp} {mp2} {mp3} {mpLargezL4:r} {mpLargezL5:r} {mpLargezL6:r} {mpLargezL7:r} {mpLargezL8:r} asympt:{mpAsympt} {mpAsympt2} {mp5}");
//Console.WriteLine($"mp = {mp} {mp2} {mp3} {mpLargezL4:g17} {mpLargezL5:g17} {mpLargezL6:g17} {mpLargezL7:g17} {mpLargezL8:g17} asympt:{mpAsympt} {mpAsympt2} {mp5}");
double cL = -1 / expMinus1;
// rU*diffs = rU*zU - rU*zL = r1U - 1 - rU*zL + rL*zL - rL*zL = r1U - 1 - drU*zL - (r1L-1) = dr1U - drU*zL
// zL = -diffs/2 - delta/diffs
@ -2226,6 +2231,43 @@ zL = (L - mx)*sqrt(prec)
}
}
/// <summary>
/// Generates a representative set of proper TruncatedGamma distributions with infinite upper bound.
/// </summary>
/// <returns></returns>
public static IEnumerable<TruncatedGamma> LowerTruncatedGammas()
{
foreach (var gamma in Gammas())
{
foreach (var lowerBound in DoublesAtLeastZero())
{
if (gamma.IsPointMass && gamma.Point < lowerBound) continue;
yield return new TruncatedGamma(gamma, lowerBound, double.PositiveInfinity);
}
}
}
/// <summary>
/// Generates a representative set of proper TruncatedGamma distributions.
/// </summary>
/// <returns></returns>
public static IEnumerable<TruncatedGamma> TruncatedGammas()
{
foreach (var gamma in Gammas())
{
foreach (var lowerBound in DoublesAtLeastZero())
{
foreach (var gap in DoublesGreaterThanZero())
{
double upperBound = lowerBound + gap;
if (upperBound == lowerBound) continue;
if (gamma.IsPointMass && (gamma.Point < lowerBound || gamma.Point > upperBound)) continue;
yield return new TruncatedGamma(gamma, lowerBound, upperBound);
}
}
}
}
/// <summary>
/// Generates a representative set of proper Gaussian distributions.
/// </summary>
@ -2272,10 +2314,10 @@ zL = (L - mx)*sqrt(prec)
{
Parallel.ForEach(DoublesLessThanZero(), lowerBound =>
{
//Console.WriteLine($"isBetween = {isBetween}, lowerBound = {lowerBound:r}");
//Console.WriteLine($"isBetween = {isBetween}, lowerBound = {lowerBound:g17}");
foreach (var upperBound in new[] { -lowerBound }.Concat(UpperBounds(lowerBound)).Take(1))
{
//Console.WriteLine($"lowerBound = {lowerBound:r}, upperBound = {upperBound:r}");
//Console.WriteLine($"lowerBound = {lowerBound:g17}, upperBound = {upperBound:g17}");
double center = MMath.Average(lowerBound, upperBound);
if (double.IsNegativeInfinity(lowerBound) && double.IsPositiveInfinity(upperBound))
center = 0;
@ -2314,8 +2356,8 @@ zL = (L - mx)*sqrt(prec)
}
});
}
Console.WriteLine($"meanMaxUlpError = {meanMaxUlpError}, lowerBound = {meanMaxUlpErrorLowerBound:r}, upperBound = {meanMaxUlpErrorUpperBound:r}, isBetween = {meanMaxUlpErrorIsBetween}");
Console.WriteLine($"precMaxUlpError = {precMaxUlpError}, lowerBound = {precMaxUlpErrorLowerBound:r}, upperBound = {precMaxUlpErrorUpperBound:r}, isBetween = {precMaxUlpErrorIsBetween}");
Console.WriteLine($"meanMaxUlpError = {meanMaxUlpError}, lowerBound = {meanMaxUlpErrorLowerBound:g17}, upperBound = {meanMaxUlpErrorUpperBound:g17}, isBetween = {meanMaxUlpErrorIsBetween}");
Console.WriteLine($"precMaxUlpError = {precMaxUlpError}, lowerBound = {precMaxUlpErrorLowerBound:g17}, upperBound = {precMaxUlpErrorUpperBound:g17}, isBetween = {precMaxUlpErrorIsBetween}");
Assert.True(meanMaxUlpError == 0);
Assert.True(precMaxUlpError == 0);
}
@ -2331,7 +2373,7 @@ zL = (L - mx)*sqrt(prec)
for (int i = 0; i < 1000; i++)
{
double logProb = DoubleIsBetweenOp.LogProbBetween(x, lowerBound, upperBound);
Console.WriteLine($"{x.Precision:r} {logProb:r}");
Console.WriteLine($"{x.Precision:g17} {logProb:g17}");
x = Gaussian.FromMeanAndPrecision(x.GetMean(), x.Precision + 1000000000000 * MMath.Ulp(x.Precision));
}
}
@ -2348,7 +2390,7 @@ zL = (L - mx)*sqrt(prec)
{
foreach (double upperBound in new double[] { 1 }.Concat(UpperBounds(lowerBound)).Take(1))
{
if (trace) Trace.WriteLine($"lowerBound = {lowerBound:r}, upperBound = {upperBound:r}");
if (trace) Trace.WriteLine($"lowerBound = {lowerBound:g17}, upperBound = {upperBound:g17}");
foreach (var x in Gaussians())
{
if (x.IsPointMass) continue;
@ -2383,7 +2425,7 @@ zL = (L - mx)*sqrt(prec)
}
}
}
if (trace) Trace.WriteLine($"maxUlpError = {maxUlpError}, lowerBound = {maxUlpErrorLowerBound:r}, upperBound = {maxUlpErrorUpperBound:r}");
if (trace) Trace.WriteLine($"maxUlpError = {maxUlpError}, lowerBound = {maxUlpErrorLowerBound:g17}, upperBound = {maxUlpErrorUpperBound:g17}");
}
});
Assert.True(maxUlpError < 1e3);
@ -2404,7 +2446,7 @@ zL = (L - mx)*sqrt(prec)
{
foreach (double upperBound in new[] { -9999.9999999999982 }.Concat(UpperBounds(lowerBound)).Take(1))
{
if (trace) Trace.WriteLine($"lowerBound = {lowerBound:r}, upperBound = {upperBound:r}");
if (trace) Trace.WriteLine($"lowerBound = {lowerBound:g17}, upperBound = {upperBound:g17}");
Parallel.ForEach(Gaussians().Where(g => !g.IsPointMass), x =>
{
double mx = x.GetMean();
@ -2488,8 +2530,8 @@ zL = (L - mx)*sqrt(prec)
}
if (trace)
{
Trace.WriteLine($"meanMaxUlpError = {meanMaxUlpError}, lowerBound = {meanMaxUlpErrorLowerBound:r}, upperBound = {meanMaxUlpErrorUpperBound:r}");
Trace.WriteLine($"precMaxUlpError = {precMaxUlpError}, lowerBound = {precMaxUlpErrorLowerBound:r}, upperBound = {precMaxUlpErrorUpperBound:r}");
Trace.WriteLine($"meanMaxUlpError = {meanMaxUlpError}, lowerBound = {meanMaxUlpErrorLowerBound:g17}, upperBound = {meanMaxUlpErrorUpperBound:g17}");
Trace.WriteLine($"precMaxUlpError = {precMaxUlpError}, lowerBound = {precMaxUlpErrorLowerBound:g17}, upperBound = {precMaxUlpErrorUpperBound:g17}");
}
}
// meanMaxUlpError = 4271.53318407361, lowerBound = -1.0000000000000006E-12, upperBound = inf
@ -2513,7 +2555,7 @@ zL = (L - mx)*sqrt(prec)
{
foreach (double upperBound in new[] { 0.0 }.Concat(UpperBounds(lowerBound)).Take(1))
{
if (trace) Console.WriteLine($"lowerBound = {lowerBound:r}, upperBound = {upperBound:r}");
if (trace) Console.WriteLine($"lowerBound = {lowerBound:g17}, upperBound = {upperBound:g17}");
double center = (lowerBound + upperBound) / 2;
if (double.IsNegativeInfinity(lowerBound) && double.IsPositiveInfinity(upperBound))
center = 0;
@ -2586,8 +2628,8 @@ zL = (L - mx)*sqrt(prec)
}
if (trace)
{
Console.WriteLine($"meanMaxUlpError = {meanMaxUlpError}, lowerBound = {meanMaxUlpErrorLowerBound:r}, upperBound = {meanMaxUlpErrorUpperBound:r}");
Console.WriteLine($"precMaxUlpError = {precMaxUlpError}, lowerBound = {precMaxUlpErrorLowerBound:r}, upperBound = {precMaxUlpErrorUpperBound:r}");
Console.WriteLine($"meanMaxUlpError = {meanMaxUlpError}, lowerBound = {meanMaxUlpErrorLowerBound:g17}, upperBound = {meanMaxUlpErrorUpperBound:g17}");
Console.WriteLine($"precMaxUlpError = {precMaxUlpError}, lowerBound = {precMaxUlpErrorLowerBound:g17}, upperBound = {precMaxUlpErrorUpperBound:g17}");
}
}
// meanMaxUlpError = 104.001435643838, lowerBound = -1.0000000000000022E-37, upperBound = 9.9000000000000191E-36
@ -2612,7 +2654,7 @@ zL = (L - mx)*sqrt(prec)
{
foreach (double upperBound in new[] { 1.0 }.Concat(UpperBounds(lowerBound)).Take(1))
{
if (trace) Console.WriteLine($"lowerBound = {lowerBound:r}, upperBound = {upperBound:r}");
if (trace) Console.WriteLine($"lowerBound = {lowerBound:g17}, upperBound = {upperBound:g17}");
Parallel.ForEach(Gaussians(), x =>
{
Gaussian toX = DoubleIsBetweenOp.XAverageConditional(true, x, lowerBound, upperBound);
@ -2678,8 +2720,8 @@ zL = (L - mx)*sqrt(prec)
}
if (trace)
{
Console.WriteLine($"meanMaxUlpError = {meanMaxUlpError}, lowerBound = {meanMaxUlpErrorLowerBound:r}, upperBound = {meanMaxUlpErrorUpperBound:r}");
Console.WriteLine($"precMaxUlpError = {precMaxUlpError}, lowerBound = {precMaxUlpErrorLowerBound:r}, upperBound = {precMaxUlpErrorUpperBound:r}");
Console.WriteLine($"meanMaxUlpError = {meanMaxUlpError}, lowerBound = {meanMaxUlpErrorLowerBound:g17}, upperBound = {meanMaxUlpErrorUpperBound:g17}");
Console.WriteLine($"precMaxUlpError = {precMaxUlpError}, lowerBound = {precMaxUlpErrorLowerBound:g17}, upperBound = {precMaxUlpErrorUpperBound:g17}");
}
}
// meanMaxUlpError = 33584, lowerBound = -1E+30, upperBound = 9.9E+31
@ -3115,7 +3157,7 @@ weight * (tau + alphaX) + alphaX
Gaussian X = Gaussian.FromMeanAndPrecision(mean, System.Math.Pow(2, -i * 1 - 20));
Gaussian toX = DoubleIsBetweenOp.XAverageConditional_Slow(Bernoulli.PointMass(true), X, lowerBound, upperBound);
Gaussian toLowerBound = toLowerBoundPrev;// DoubleIsBetweenOp.LowerBoundAverageConditional_Slow(Bernoulli.PointMass(true), X, lowerBound, upperBound);
Trace.WriteLine($"{i} {X}: {toX.MeanTimesPrecision:r} {toX.Precision:r} {toLowerBound.MeanTimesPrecision:r} {toLowerBound.Precision:r}");
Trace.WriteLine($"{i} {X}: {toX.MeanTimesPrecision:g17} {toX.Precision:g17} {toLowerBound.MeanTimesPrecision:g17} {toLowerBound.Precision:g17}");
Assert.False(toLowerBound.IsPointMass);
if ((mean > 0 && toLowerBound.MeanTimesPrecision > toLowerBoundPrev.MeanTimesPrecision) ||
(mean < 0 && toLowerBound.MeanTimesPrecision < toLowerBoundPrev.MeanTimesPrecision))