From 69a7a1979ff13399e42d568190e7b29da94d9329 Mon Sep 17 00:00:00 2001
From: Tom Minka <8955276+tminka@users.noreply.github.com>
Date: Wed, 18 Nov 2020 23:44:47 +0000
Subject: [PATCH] Improved convergence rate of GammaPower.FromMeanAndMeanPower
(#301)
* Improved accuracy of ExpOp.ExpAverageConditional
* Tutorials project uses Unicode OutputEncoding
* Fixed MethodBodySynthesizer
---
src/Csoft/MethodBodySynthesizer.cs | 5 +--
src/Runtime/Distributions/Gamma.cs | 14 ++++----
src/Runtime/Distributions/GammaPower.cs | 13 ++++---
src/Runtime/Factors/Exp.cs | 32 ++++++++---------
src/Tutorials/RunMe.cs | 1 +
test/Tests/Distributions/DistributionTests.cs | 12 +++++++
test/Tests/InferTests.cs | 36 ++++++++++++++-----
test/Tests/Operators/OperatorTests.cs | 26 ++++++++++++++
8 files changed, 102 insertions(+), 37 deletions(-)
diff --git a/src/Csoft/MethodBodySynthesizer.cs b/src/Csoft/MethodBodySynthesizer.cs
index 8d680217..ce501cd3 100644
--- a/src/Csoft/MethodBodySynthesizer.cs
+++ b/src/Csoft/MethodBodySynthesizer.cs
@@ -43,7 +43,8 @@ namespace Microsoft.ML.Probabilistic.Compiler
{
pdCache[paramDecl.Name] = paramDecl;
}
- methodDecl.Body = ConvertBlock(methodSyntax.Body);
+ if (methodSyntax.Body != null)
+ methodDecl.Body = ConvertBlock(methodSyntax.Body);
pdCache.Clear();
vdCache.Clear();
}
@@ -599,7 +600,7 @@ namespace Microsoft.ML.Probabilistic.Compiler
private IExpression ConvertNamedType(INamedTypeSymbol symbol, MemberAccessExpressionSyntax memberAccess)
{
var typeRef = (ITypeReference)ConvertTypeReference(symbol);
- return new XTypeReferenceExpression
+ return new XTypeReferenceExpression
{
Type = typeRef
};
diff --git a/src/Runtime/Distributions/Gamma.cs b/src/Runtime/Distributions/Gamma.cs
index a4105834..773c53f7 100644
--- a/src/Runtime/Distributions/Gamma.cs
+++ b/src/Runtime/Distributions/Gamma.cs
@@ -140,7 +140,7 @@ namespace Microsoft.ML.Probabilistic.Distributions
}
else if (variance < 0)
{
- throw new ArgumentException("variance < 0 (" + variance + ")");
+ throw new ArgumentOutOfRangeException(nameof(variance), variance, "variance < 0");
}
else
{
@@ -227,7 +227,7 @@ namespace Microsoft.ML.Probabilistic.Distributions
/// Scale
public void SetShapeAndScale(double shape, double scale)
{
- if (double.IsPositiveInfinity(shape)) throw new ArgumentOutOfRangeException(nameof(shape), "shape is infinite. To create a point mass, set the Point property.");
+ if (double.IsPositiveInfinity(shape)) throw new ArgumentOutOfRangeException(nameof(shape), shape, "shape is infinite. To create a point mass, set the Point property.");
SetShapeAndRate(shape, 1.0 / scale);
}
@@ -296,7 +296,7 @@ namespace Microsoft.ML.Probabilistic.Distributions
if (MMath.AreEqual(g, 0)) break;
shape /= 1 + g / (1 - shape * MMath.Trigamma(shape));
}
- if (double.IsNaN(shape)) throw new Exception("shape is nan");
+ if (double.IsNaN(shape)) throw new InferRuntimeException("shape is nan");
if (shape > double.MaxValue) return Gamma.PointMass(mean);
return Gamma.FromShapeAndRate(shape, shape / mean);
}
@@ -473,8 +473,8 @@ namespace Microsoft.ML.Probabilistic.Distributions
///
public double GetQuantile(double probability)
{
- if (probability < 0) throw new ArgumentOutOfRangeException("probability < 0");
- if (probability > 1) throw new ArgumentOutOfRangeException("probability > 1");
+ if (probability < 0) throw new ArgumentOutOfRangeException(nameof(probability), probability, "probability < 0");
+ if (probability > 1) throw new ArgumentOutOfRangeException(nameof(probability), probability, "probability > 1");
if (this.IsPointMass)
{
return (probability == 1.0) ? MMath.NextDouble(this.Point) : this.Point;
@@ -881,7 +881,7 @@ namespace Microsoft.ML.Probabilistic.Distributions
}
else
{
- throw new DivideByZeroException();
+ throw new DivideByZeroException($"numerator {numerator} and denominator {denominator} are different point masses");
}
}
else
@@ -891,7 +891,7 @@ namespace Microsoft.ML.Probabilistic.Distributions
}
else if (denominator.IsPointMass)
{
- throw new DivideByZeroException();
+ throw new DivideByZeroException($"denominator {denominator} is a point mass but numerator {numerator} is not");
}
else if (forceProper && numerator.Rate == 0)
{
diff --git a/src/Runtime/Distributions/GammaPower.cs b/src/Runtime/Distributions/GammaPower.cs
index 570fe315..e344dc58 100644
--- a/src/Runtime/Distributions/GammaPower.cs
+++ b/src/Runtime/Distributions/GammaPower.cs
@@ -354,14 +354,19 @@ namespace Microsoft.ML.Probabilistic.Distributions
// derivative wrt shape is (digamma(Shape + power) - digamma(Shape))/power
double risingFactorial = MMath.RisingFactorialLnOverN(shape, power);
logRate = risingFactorial - logMeanOverPower;
- if (power == 2 && backtrackCount < 2)
+ if (power > 1 && backtrackCount < 2)
{
// logRate = risingFactorial - logMeanOverPower
// meanLogOverPower = digamma(shape) - logRate
// = digamma(shape) - risingFactorial + logMeanOverPower
- // = digamma(shape)-log(shape)+0.5/shape +log(shape)+0.5/shape - risingFactorial + logMeanOverPower - 1/shape
- // 1/shape = digamma(shape)-log(shape)+0.5/shape +log(shape)+0.5/shape - risingFactorial + delta/power
- shape = 1 / ((MMath.Digamma(shape) - (Math.Log(shape) - 0.5 / shape)) + (0.5 / shape + Math.Log(shape) - risingFactorial) + delta / power);
+ // = digamma(shape)-log(shape)+0.5/shape +log(shape)+(power-1)*0.5/shape - risingFactorial + logMeanOverPower - power*0.5/shape
+ // power*0.5/shape = digamma(shape)-log(shape)+0.5/shape +log(shape)+(power-1)*0.5/shape - risingFactorial + delta/power
+ // log(x+n) <= log(x) + n/x
+ // risingFactorial <= log(x) + (n-1)/2/x if n >= 1
+ // risingFactorial <= log(x) + (1-n)/2/x if n <= -1
+ double logShape = Math.Log(shape);
+ double halfOverShape = 0.5 / shape;
+ shape = power * 0.5 / ((MMath.Digamma(shape) - (logShape - halfOverShape)) + ((power - 1) * halfOverShape + logShape - risingFactorial) + delta / power);
}
else
{
diff --git a/src/Runtime/Factors/Exp.cs b/src/Runtime/Factors/Exp.cs
index d2726a88..a80adb9a 100644
--- a/src/Runtime/Factors/Exp.cs
+++ b/src/Runtime/Factors/Exp.cs
@@ -264,8 +264,7 @@ namespace Microsoft.ML.Probabilistic.Factors
double Z = 0;
double sumY = 0;
- double sumExpY = 0;
- //double sumExpMinusY = 0;
+ double logsumExpY = double.NegativeInfinity;
bool useHermite = true;
//if (vD < 10)
if (useHermite)
@@ -297,7 +296,9 @@ namespace Microsoft.ML.Probabilistic.Factors
for (int i = 0; i < weights.Length; i++)
{
double y = nodes[i] + mD;
- double logf = shapeMinus1 * y - exp.Rate * Math.Exp(y) + weights[i];
+ double logf = weights[i];
+ if (shapeMinus1 != 0) logf += shapeMinus1 * y; // avoid 0*inf
+ if (exp.Rate != 0) logf -= exp.Rate * Math.Exp(y); // avoid 0*inf
if (logf > maxLogF)
{
maxLogF = logf;
@@ -307,13 +308,12 @@ namespace Microsoft.ML.Probabilistic.Factors
for (int i = 0; i < weights.Length; i++)
{
double y = nodes[i];
- double f = Math.Exp(weights[i] - maxLogF);
+ double logf = weights[i] - maxLogF;
+ double f = Math.Exp(logf);
double f_y = f * y;
- double fexpy = f * Math.Exp(y);
Z += f;
sumY += f_y;
- sumExpY += fexpy;
- //sumExpMinusY += f * Math.Exp(-y);
+ logsumExpY = MMath.LogSumExp(logsumExpY, logf + y);
}
}
else
@@ -326,16 +326,15 @@ namespace Microsoft.ML.Probabilistic.Factors
double scale = 1;
Z = Quadrature.AdaptiveClenshawCurtis(z => Math.Exp(p(sc * z + mD) - offset), scale, nodeCount, relTol);
sumY = Quadrature.AdaptiveClenshawCurtis(z => (sc * z) * Math.Exp(p(sc * z + mD) - offset), scale, nodeCount, relTol);
- sumExpY = Quadrature.AdaptiveClenshawCurtis(z => Math.Exp(sc * z + p(sc * z + mD) - offset), scale, nodeCount, relTol);
+ double sumExpY = Quadrature.AdaptiveClenshawCurtis(z => Math.Exp(sc * z + p(sc * z + mD) - offset), scale, nodeCount, relTol);
+ logsumExpY = Math.Log(sumExpY);
}
if (Z == 0)
throw new InferRuntimeException("Z==0");
double meanLog = sumY / Z + mD;
- double expmD = Math.Exp(mD);
- double mean = sumExpY / Z * expmD;
- //double meanInverse = sumExpMinusY / Z / expmD;
- //Trace.WriteLine($"mean = {mean} meanLog = {meanLog} meanInverse = {meanInverse}");
- Gamma result = Gamma.FromMeanAndMeanLog(mean, meanLog);
+ double logMean = logsumExpY - Math.Log(Z) + mD;
+ double mean = Math.Exp(logMean);
+ Gamma result = Gamma.FromMeanAndMeanLog(mean, meanLog, logMean);
result.SetToRatio(result, exp, ForceProper);
if (Double.IsNaN(result.Shape) || Double.IsNaN(result.Rate))
throw new InferRuntimeException($"result is NaN. exp={exp}, d={d}, to_d={to_d}");
@@ -949,9 +948,10 @@ namespace Microsoft.ML.Probabilistic.Factors
//double bpost = b + d.Precision/expx;
//double mpost = expx - d.Precision*(MMath.Digamma(apost) - Math.Log(apost))/bpost;
double v = 1 / (d.Precision + b * expx);
- double mlogpost = x - 0.5 * v * v * b * expx;
- double mpost = expx * (1 + 0.5 * v * v * d.Precision);
- Gamma result = Gamma.FromMeanAndMeanLog(mpost, mlogpost);
+ double meanLog = x - 0.5 * v * v * b * expx;
+ double logMean = x + Math.Log(1 + 0.5 * v * v * d.Precision);
+ double mean = Math.Exp(logMean);
+ Gamma result = Gamma.FromMeanAndMeanLog(mean, meanLog, logMean);
result.SetToRatio(result, exp, true);
return result;
}
diff --git a/src/Tutorials/RunMe.cs b/src/Tutorials/RunMe.cs
index 4fa57871..32843b8a 100644
--- a/src/Tutorials/RunMe.cs
+++ b/src/Tutorials/RunMe.cs
@@ -34,6 +34,7 @@ namespace Microsoft.ML.Probabilistic.Tutorials
//Tutorials
//Uncomment one of these lines to run a particular tutorial in console application
+ Console.OutputEncoding = System.Text.Encoding.Unicode;
new FirstExample().Run();
//new TruncatedGaussian().Run();
diff --git a/test/Tests/Distributions/DistributionTests.cs b/test/Tests/Distributions/DistributionTests.cs
index 6a7d0705..fd2f1d00 100644
--- a/test/Tests/Distributions/DistributionTests.cs
+++ b/test/Tests/Distributions/DistributionTests.cs
@@ -298,6 +298,18 @@ namespace Microsoft.ML.Probabilistic.Tests
//SetMomentTest(g, 1.1, 2.2);
PointMassMomentTest(g, 7.7, 4.4, 5.5);
SamplingTest(g, 7.7);
+
+ var ratio = g / g2;
+ Assert.Throws(() =>
+ {
+ ratio = g / new TruncatedGaussian(4.4, 5.5, lowerBound + 1, upperBound);
+ });
+ ratio = TruncatedGaussian.PointMass(lowerBound) / new TruncatedGaussian(4.4, 5.5, lowerBound, upperBound);
+ Assert.Throws(() =>
+ {
+ ratio = TruncatedGaussian.PointMass(2) / new TruncatedGaussian(4.4, 5.5, lowerBound + 1, upperBound);
+ });
+
g.SetToUniform();
//GetAndSetMomentTest(g, 0.0, Double.PositiveInfinity);
diff --git a/test/Tests/InferTests.cs b/test/Tests/InferTests.cs
index d690164a..8444eb7d 100644
--- a/test/Tests/InferTests.cs
+++ b/test/Tests/InferTests.cs
@@ -1433,16 +1433,36 @@ namespace Microsoft.ML.Probabilistic.Tests
InferenceEngine engine = new InferenceEngine();
// This namespace is deliberately chosen to cause a potential name conflict when referring to Math.Exp.
engine.ModelNamespace = "Microsoft.ML.Probabilistic";
- engine.Compiler.GivePriorityTo(typeof(GammaFromShapeAndRateOp_Laplace));
- //engine.Algorithm = new VariationalMessagePassing();
Gaussian xExpected = new Gaussian(1.395, 0.0568); // VMP on simple Poisson
xExpected = new Gaussian(1.395, 0.05716); // EP on simple Poisson
- Gaussian xActual = engine.Infer(x);
- Console.WriteLine("x = {0} should be {1}", xActual, xExpected);
- Assert.True(xExpected.MaxDiff(xActual) < 1);
- IList<Gamma> p = engine.Infer<IList<Gamma>>(poissonRate);
- Console.WriteLine(p);
- Console.WriteLine(StringUtil.VerboseToString(p.Select(g => g.Shape)));
+ bool useSampling = false;
+ if(useSampling)
+ {
+ GaussianEstimator estimator = new GaussianEstimator();
+ for (int iter = 0; iter < 100_000; iter++)
+ {
+ double xSample = Gaussian.FromMeanAndPrecision(0, 1).Sample();
+ double logWeight = data.Sum(yi => {
+ double rateSample = Gamma.FromShapeAndRate(System.Math.Exp(gammaLnShape), System.Math.Exp(gammaLnShape - xSample)).Sample();
+ Poisson yDist = new Poisson(rateSample);
+ return yDist.GetLogProb(yi);
+ });
+ double weight = System.Math.Exp(logWeight);
+ estimator.Add(xSample, weight);
+ }
+ xExpected = estimator.GetDistribution(new Gaussian());
+ }
+ for (int trial = 0; trial < 2; trial++)
+ {
+ if(trial == 1) engine.Compiler.GivePriorityTo(typeof(GammaFromShapeAndRateOp_Laplace));
+ else if(trial == 2) engine.Algorithm = new VariationalMessagePassing();
+ Gaussian xActual = engine.Infer(x);
+ Console.WriteLine("x = {0} should be {1}", xActual, xExpected);
+ Assert.True(xExpected.MaxDiff(xActual) < 1);
+ IList<Gamma> p = engine.Infer<IList<Gamma>>(poissonRate);
+ //Console.WriteLine(p);
+ //Console.WriteLine(StringUtil.VerboseToString(p.Select(g => g.Shape)));
+ }
}
// model from Thore Graepel
diff --git a/test/Tests/Operators/OperatorTests.cs b/test/Tests/Operators/OperatorTests.cs
index 90520aa3..05bbf9c0 100644
--- a/test/Tests/Operators/OperatorTests.cs
+++ b/test/Tests/Operators/OperatorTests.cs
@@ -654,10 +654,36 @@ namespace Microsoft.ML.Probabilistic.Tests
}
}
+ [Fact]
+ public void ExpOpGammaPowerTest() // Exercises ExpOp.ExpAverageConditional and ExpOp.DAverageConditional with GammaPower arguments.
+ {
+ ExpOp.ExpAverageConditional(GammaPower.FromShapeAndRate(-1, 283.673, -1), Gaussian.FromNatural(0.004859823703146038, 6.6322755562737905E-06), Gaussian.FromNatural(0.00075506803981220758, 8.24487022054953E-07)); // Result is discarded: this call can only fail the test by throwing.
+ GammaPower exp = GammaPower.FromShapeAndRate(0, 0, -1); // The same exp message is reused for every d below.
+ Gaussian[] ds = new[] // Inputs for d, including extreme values.
+ {
+ Gaussian.FromNatural(-1.6171314269768655E+308, 4.8976001759138024), // First natural parameter is near the largest-magnitude finite double.
+ Gaussian.FromNatural(-0.037020622891705768, 0.00034989765084474117),
+ Gaussian.PointMass(double.NegativeInfinity), // Point mass at negative infinity.
+ };
+ foreach (var d in ds)
+ {
+ Gaussian to_d = ExpOp.DAverageConditional(exp, d, Gaussian.Uniform()); // Third argument is a uniform Gaussian.
+ Gaussian to_d_slow = ExpOp_Slow.DAverageConditional(exp, d); // Reference value from ExpOp_Slow for comparison.
+ Trace.WriteLine($"{to_d}");
+ Trace.WriteLine($"{to_d_slow}");
+ Assert.True(to_d_slow.MaxDiff(to_d) < 1e-10); // Both implementations must agree to within 1e-10 as measured by MaxDiff.
+ to_d = Gaussian.FromNatural(1, 0); // to_d is overwritten with a fixed message before the next call.
+ GammaPower to_exp = ExpOp.ExpAverageConditional(exp, d, to_d); // Result is only traced, not asserted.
+ Trace.WriteLine($"{to_exp}");
+ }
+ ExpOp.ExpAverageConditional(GammaPower.FromShapeAndRate(-1, 883.22399999999993, -1), Gaussian.FromNatural(0.0072160312702854888, 8.1788482512051846E-06), Gaussian.FromNatural(0.00057861649495666474, 5.6316164560235272E-07)); // Result is discarded: this call can only fail the test by throwing.
+ }
+
[Fact]
[Trait("Category", "OpenBug")]
public void ExpOpTest()
{
+ Assert.True(ExpOp.ExpAverageConditional(Gamma.FromShapeAndRate(3.302758272196654, 0.00060601537137241492), Gaussian.FromNatural(55.350150233321628, 6.3510247863590683), Gaussian.FromNatural(27.960892513144643, 3.4099170930572216)).Rate > 0);
Gamma exp = new Gamma(1, 1);
Gaussian[] ds = new[]
{