From 69a7a1979ff13399e42d568190e7b29da94d9329 Mon Sep 17 00:00:00 2001 From: Tom Minka <8955276+tminka@users.noreply.github.com> Date: Wed, 18 Nov 2020 23:44:47 +0000 Subject: [PATCH] Improved convergence rate of GammaPower.FromMeanAndMeanPower (#301) * Improved accuracy of ExpOp.ExpAverageConditional * Tutorials project uses Unicode OutputEncoding * Fixed MethodBodySynthesizer --- src/Csoft/MethodBodySynthesizer.cs | 5 +-- src/Runtime/Distributions/Gamma.cs | 14 ++++---- src/Runtime/Distributions/GammaPower.cs | 13 ++++--- src/Runtime/Factors/Exp.cs | 32 ++++++++--------- src/Tutorials/RunMe.cs | 1 + test/Tests/Distributions/DistributionTests.cs | 12 +++++++ test/Tests/InferTests.cs | 36 ++++++++++++++----- test/Tests/Operators/OperatorTests.cs | 26 ++++++++++++++ 8 files changed, 102 insertions(+), 37 deletions(-) diff --git a/src/Csoft/MethodBodySynthesizer.cs b/src/Csoft/MethodBodySynthesizer.cs index 8d680217..ce501cd3 100644 --- a/src/Csoft/MethodBodySynthesizer.cs +++ b/src/Csoft/MethodBodySynthesizer.cs @@ -43,7 +43,8 @@ namespace Microsoft.ML.Probabilistic.Compiler { pdCache[paramDecl.Name] = paramDecl; } - methodDecl.Body = ConvertBlock(methodSyntax.Body); + if (methodSyntax.Body != null) + methodDecl.Body = ConvertBlock(methodSyntax.Body); pdCache.Clear(); vdCache.Clear(); } @@ -599,7 +600,7 @@ namespace Microsoft.ML.Probabilistic.Compiler private IExpression ConvertNamedType(INamedTypeSymbol symbol, MemberAccessExpressionSyntax memberAccess) { var typeRef = (ITypeReference)ConvertTypeReference(symbol); - return new XTypeReferenceExpression + return new XTypeReferenceExpression { Type = typeRef }; diff --git a/src/Runtime/Distributions/Gamma.cs b/src/Runtime/Distributions/Gamma.cs index a4105834..773c53f7 100644 --- a/src/Runtime/Distributions/Gamma.cs +++ b/src/Runtime/Distributions/Gamma.cs @@ -140,7 +140,7 @@ namespace Microsoft.ML.Probabilistic.Distributions } else if (variance < 0) { - throw new ArgumentException("variance < 0 (" + variance + ")"); + throw new ArgumentOutOfRangeException(nameof(variance), variance, "variance < 0"); } else { @@ -227,7 +227,7 @@ namespace Microsoft.ML.Probabilistic.Distributions /// Scale public void SetShapeAndScale(double shape, double scale) { - if (double.IsPositiveInfinity(shape)) throw new ArgumentOutOfRangeException(nameof(shape), "shape is infinite. To create a point mass, set the Point property."); + if (double.IsPositiveInfinity(shape)) throw new ArgumentOutOfRangeException(nameof(shape), shape, "shape is infinite. To create a point mass, set the Point property."); SetShapeAndRate(shape, 1.0 / scale); } @@ -296,7 +296,7 @@ namespace Microsoft.ML.Probabilistic.Distributions if (MMath.AreEqual(g, 0)) break; shape /= 1 + g / (1 - shape * MMath.Trigamma(shape)); } - if (double.IsNaN(shape)) throw new Exception("shape is nan"); + if (double.IsNaN(shape)) throw new InferRuntimeException("shape is nan"); if (shape > double.MaxValue) return Gamma.PointMass(mean); return Gamma.FromShapeAndRate(shape, shape / mean); } @@ -473,8 +473,8 @@ namespace Microsoft.ML.Probabilistic.Distributions /// public double GetQuantile(double probability) { - if (probability < 0) throw new ArgumentOutOfRangeException("probability < 0"); - if (probability > 1) throw new ArgumentOutOfRangeException("probability > 1"); + if (probability < 0) throw new ArgumentOutOfRangeException(nameof(probability), probability, "probability < 0"); + if (probability > 1) throw new ArgumentOutOfRangeException(nameof(probability), probability, "probability > 1"); if (this.IsPointMass) { return (probability == 1.0) ? MMath.NextDouble(this.Point) : this.Point; @@ -881,7 +881,7 @@ namespace Microsoft.ML.Probabilistic.Distributions } else { - throw new DivideByZeroException(); + throw new DivideByZeroException($"numerator {numerator} and denominator {denominator} are different point masses"); } } else @@ -891,7 +891,7 @@ namespace Microsoft.ML.Probabilistic.Distributions } else if (denominator.IsPointMass) { - throw new DivideByZeroException(); + throw new DivideByZeroException($"denominator {denominator} is a point mass but numerator {numerator} is not"); } else if (forceProper && numerator.Rate == 0) { diff --git a/src/Runtime/Distributions/GammaPower.cs b/src/Runtime/Distributions/GammaPower.cs index 570fe315..e344dc58 100644 --- a/src/Runtime/Distributions/GammaPower.cs +++ b/src/Runtime/Distributions/GammaPower.cs @@ -354,14 +354,19 @@ namespace Microsoft.ML.Probabilistic.Distributions // derivative wrt shape is (digamma(Shape + power) - digamma(Shape))/power double risingFactorial = MMath.RisingFactorialLnOverN(shape, power); logRate = risingFactorial - logMeanOverPower; - if (power == 2 && backtrackCount < 2) + if (power > 1 && backtrackCount < 2) { // logRate = risingFactorial - logMeanOverPower // meanLogOverPower = digamma(shape) - logRate // = digamma(shape) - risingFactorial + logMeanOverPower - // = digamma(shape)-log(shape)+0.5/shape +log(shape)+0.5/shape - risingFactorial + logMeanOverPower - 1/shape - // 1/shape = digamma(shape)-log(shape)+0.5/shape +log(shape)+0.5/shape - risingFactorial + delta/power - shape = 1 / ((MMath.Digamma(shape) - (Math.Log(shape) - 0.5 / shape)) + (0.5 / shape + Math.Log(shape) - risingFactorial) + delta / power); + // = digamma(shape)-log(shape)+0.5/shape +log(shape)+(power-1)*0.5/shape - risingFactorial + logMeanOverPower - power*0.5/shape + // power*0.5/shape = digamma(shape)-log(shape)+0.5/shape +log(shape)+(power-1)*0.5/shape - risingFactorial + delta/power + // log(x+n) <= log(x) + n/x + // risingFactorial <= log(x) + (n-1)/2/x if n >= 1 + // risingFactorial <= log(x) + (1-n)/2/x if n <= -1 + double logShape = Math.Log(shape); + double halfOverShape = 0.5 / shape; + shape = power * 0.5 / ((MMath.Digamma(shape) - (logShape - halfOverShape)) + ((power - 1) * halfOverShape + logShape - risingFactorial) + delta / power); } else { diff --git a/src/Runtime/Factors/Exp.cs b/src/Runtime/Factors/Exp.cs index d2726a88..a80adb9a 100644 --- a/src/Runtime/Factors/Exp.cs +++ b/src/Runtime/Factors/Exp.cs @@ -264,8 +264,7 @@ namespace Microsoft.ML.Probabilistic.Factors double Z = 0; double sumY = 0; - double sumExpY = 0; - //double sumExpMinusY = 0; + double logsumExpY = double.NegativeInfinity; bool useHermite = true; //if (vD < 10) if (useHermite) @@ -297,7 +296,9 @@ namespace Microsoft.ML.Probabilistic.Factors for (int i = 0; i < weights.Length; i++) { double y = nodes[i] + mD; - double logf = shapeMinus1 * y - exp.Rate * Math.Exp(y) + weights[i]; + double logf = weights[i]; + if (shapeMinus1 != 0) logf += shapeMinus1 * y; // avoid 0*inf + if (exp.Rate != 0) logf -= exp.Rate * Math.Exp(y); // avoid 0*inf if (logf > maxLogF) { maxLogF = logf; @@ -307,13 +308,12 @@ namespace Microsoft.ML.Probabilistic.Factors for (int i = 0; i < weights.Length; i++) { double y = nodes[i]; - double f = Math.Exp(weights[i] - maxLogF); + double logf = weights[i] - maxLogF; + double f = Math.Exp(logf); double f_y = f * y; - double fexpy = f * Math.Exp(y); Z += f; sumY += f_y; - sumExpY += fexpy; - //sumExpMinusY += f * Math.Exp(-y); + logsumExpY = MMath.LogSumExp(logsumExpY, logf + y); } } else @@ -326,16 +326,15 @@ namespace Microsoft.ML.Probabilistic.Factors double scale = 1; Z = Quadrature.AdaptiveClenshawCurtis(z => Math.Exp(p(sc * z + mD) - offset), scale, nodeCount, relTol); sumY = Quadrature.AdaptiveClenshawCurtis(z => (sc * z) * Math.Exp(p(sc * z + mD) - offset), scale, nodeCount, relTol); - sumExpY = Quadrature.AdaptiveClenshawCurtis(z => Math.Exp(sc * z + p(sc * z + mD) - offset), scale, nodeCount, relTol); + double sumExpY = Quadrature.AdaptiveClenshawCurtis(z => Math.Exp(sc * z + p(sc * z + mD) - offset), scale, nodeCount, relTol); + logsumExpY = Math.Log(sumExpY); } if (Z == 0) throw new InferRuntimeException("Z==0"); double meanLog = sumY / Z + mD; - double expmD = Math.Exp(mD); - double mean = sumExpY / Z * expmD; - //double meanInverse = sumExpMinusY / Z / expmD; - //Trace.WriteLine($"mean = {mean} meanLog = {meanLog} meanInverse = {meanInverse}"); - Gamma result = Gamma.FromMeanAndMeanLog(mean, meanLog); + double logMean = logsumExpY - Math.Log(Z) + mD; + double mean = Math.Exp(logMean); + Gamma result = Gamma.FromMeanAndMeanLog(mean, meanLog, logMean); result.SetToRatio(result, exp, ForceProper); if (Double.IsNaN(result.Shape) || Double.IsNaN(result.Rate)) throw new InferRuntimeException($"result is NaN. exp={exp}, d={d}, to_d={to_d}"); @@ -949,9 +948,10 @@ namespace Microsoft.ML.Probabilistic.Factors //double bpost = b + d.Precision/expx; //double mpost = expx - d.Precision*(MMath.Digamma(apost) - Math.Log(apost))/bpost; double v = 1 / (d.Precision + b * expx); - double mlogpost = x - 0.5 * v * v * b * expx; - double mpost = expx * (1 + 0.5 * v * v * d.Precision); - Gamma result = Gamma.FromMeanAndMeanLog(mpost, mlogpost); + double meanLog = x - 0.5 * v * v * b * expx; + double logMean = x + Math.Log(1 + 0.5 * v * v * d.Precision); + double mean = Math.Exp(logMean); + Gamma result = Gamma.FromMeanAndMeanLog(mean, meanLog, logMean); result.SetToRatio(result, exp, true); return result; } diff --git a/src/Tutorials/RunMe.cs b/src/Tutorials/RunMe.cs index 4fa57871..32843b8a 100644 --- a/src/Tutorials/RunMe.cs +++ b/src/Tutorials/RunMe.cs @@ -34,6 +34,7 @@ namespace Microsoft.ML.Probabilistic.Tutorials //Tutorials //Uncomment one of these lines to run a particular tutorial in console application + Console.OutputEncoding = System.Text.Encoding.Unicode; new FirstExample().Run(); //new TruncatedGaussian().Run(); diff --git a/test/Tests/Distributions/DistributionTests.cs b/test/Tests/Distributions/DistributionTests.cs index 6a7d0705..fd2f1d00 100644 --- a/test/Tests/Distributions/DistributionTests.cs +++ b/test/Tests/Distributions/DistributionTests.cs @@ -298,6 +298,18 @@ namespace Microsoft.ML.Probabilistic.Tests //SetMomentTest(g, 1.1, 2.2); PointMassMomentTest(g, 7.7, 4.4, 5.5); SamplingTest(g, 7.7); + + var ratio = g / g2; + Assert.Throws(() => + { + ratio = g / new TruncatedGaussian(4.4, 5.5, lowerBound + 1, upperBound); + }); + ratio = TruncatedGaussian.PointMass(lowerBound) / new TruncatedGaussian(4.4, 5.5, lowerBound, upperBound); + Assert.Throws(() => + { + ratio = TruncatedGaussian.PointMass(2) / new TruncatedGaussian(4.4, 5.5, lowerBound + 1, upperBound); + }); + g.SetToUniform(); //GetAndSetMomentTest(g, 0.0, Double.PositiveInfinity); diff --git a/test/Tests/InferTests.cs b/test/Tests/InferTests.cs index d690164a..8444eb7d 100644 --- a/test/Tests/InferTests.cs +++ b/test/Tests/InferTests.cs @@ -1433,16 +1433,36 @@ namespace Microsoft.ML.Probabilistic.Tests InferenceEngine engine = new InferenceEngine(); // This namespace is deliberately chosen to cause a potential name conflict when referring to Math.Exp. engine.ModelNamespace = "Microsoft.ML.Probabilistic"; - engine.Compiler.GivePriorityTo(typeof(GammaFromShapeAndRateOp_Laplace)); - //engine.Algorithm = new VariationalMessagePassing(); Gaussian xExpected = new Gaussian(1.395, 0.0568); // VMP on simple Poisson xExpected = new Gaussian(1.395, 0.05716); // EP on simple Poisson - Gaussian xActual = engine.Infer(x); - Console.WriteLine("x = {0} should be {1}", xActual, xExpected); - Assert.True(xExpected.MaxDiff(xActual) < 1); - IList p = engine.Infer>(poissonRate); - Console.WriteLine(p); - Console.WriteLine(StringUtil.VerboseToString(p.Select(g => g.Shape))); + bool useSampling = false; + if(useSampling) + { + GaussianEstimator estimator = new GaussianEstimator(); + for (int iter = 0; iter < 100_000; iter++) + { + double xSample = Gaussian.FromMeanAndPrecision(0, 1).Sample(); + double logWeight = data.Sum(yi => { + double rateSample = Gamma.FromShapeAndRate(System.Math.Exp(gammaLnShape), System.Math.Exp(gammaLnShape - xSample)).Sample(); + Poisson yDist = new Poisson(rateSample); + return yDist.GetLogProb(yi); + }); + double weight = System.Math.Exp(logWeight); + estimator.Add(xSample, weight); + } + xExpected = estimator.GetDistribution(new Gaussian()); + } + for (int trial = 0; trial < 2; trial++) + { + if(trial == 1) engine.Compiler.GivePriorityTo(typeof(GammaFromShapeAndRateOp_Laplace)); + else if(trial == 2) engine.Algorithm = new VariationalMessagePassing(); + Gaussian xActual = engine.Infer(x); + Console.WriteLine("x = {0} should be {1}", xActual, xExpected); + Assert.True(xExpected.MaxDiff(xActual) < 1); + IList p = engine.Infer>(poissonRate); + //Console.WriteLine(p); + //Console.WriteLine(StringUtil.VerboseToString(p.Select(g => g.Shape))); + } } // model from Thore Graepel diff --git a/test/Tests/Operators/OperatorTests.cs b/test/Tests/Operators/OperatorTests.cs index 90520aa3..05bbf9c0 100644 --- a/test/Tests/Operators/OperatorTests.cs +++ b/test/Tests/Operators/OperatorTests.cs @@ -654,10 +654,36 @@ namespace Microsoft.ML.Probabilistic.Tests } } + [Fact] + public void ExpOpGammaPowerTest() + { + ExpOp.ExpAverageConditional(GammaPower.FromShapeAndRate(-1, 283.673, -1), Gaussian.FromNatural(0.004859823703146038, 6.6322755562737905E-06), Gaussian.FromNatural(0.00075506803981220758, 8.24487022054953E-07)); + GammaPower exp = GammaPower.FromShapeAndRate(0, 0, -1); + Gaussian[] ds = new[] + { + Gaussian.FromNatural(-1.6171314269768655E+308, 4.8976001759138024), + Gaussian.FromNatural(-0.037020622891705768, 0.00034989765084474117), + Gaussian.PointMass(double.NegativeInfinity), + }; + foreach (var d in ds) + { + Gaussian to_d = ExpOp.DAverageConditional(exp, d, Gaussian.Uniform()); + Gaussian to_d_slow = ExpOp_Slow.DAverageConditional(exp, d); + Trace.WriteLine($"{to_d}"); + Trace.WriteLine($"{to_d_slow}"); + Assert.True(to_d_slow.MaxDiff(to_d) < 1e-10); + to_d = Gaussian.FromNatural(1, 0); + GammaPower to_exp = ExpOp.ExpAverageConditional(exp, d, to_d); + Trace.WriteLine($"{to_exp}"); + } + ExpOp.ExpAverageConditional(GammaPower.FromShapeAndRate(-1, 883.22399999999993, -1), Gaussian.FromNatural(0.0072160312702854888, 8.1788482512051846E-06), Gaussian.FromNatural(0.00057861649495666474, 5.6316164560235272E-07)); + } + [Fact] [Trait("Category", "OpenBug")] public void ExpOpTest() { + Assert.True(ExpOp.ExpAverageConditional(Gamma.FromShapeAndRate(3.302758272196654, 0.00060601537137241492), Gaussian.FromNatural(55.350150233321628, 6.3510247863590683), Gaussian.FromNatural(27.960892513144643, 3.4099170930572216)).Rate > 0); Gamma exp = new Gamma(1, 1); Gaussian[] ds = new[] {