Added GatedInnerProductVectorTest and missing overloads of InnerProductOp

This commit is contained in:
Tom Minka 2019-08-23 16:03:05 +01:00
Родитель 98e07ae335
Коммит a280a6272f
2 изменённых файлов: 89 добавлений и 16 удалений

Просмотреть файл

@ -87,8 +87,7 @@ namespace Microsoft.ML.Probabilistic.Factors
[NotSupported(InnerProductOp.NotSupportedMessage)]
public static VectorGaussian BAverageLogarithm(double innerProduct, [SkipIfUniform] VectorGaussian A, VectorGaussian result)
{
throw new NotSupportedException(InnerProductOp.NotSupportedMessage);
//return AAverageLogarithm(innerProduct, A, result);
return AAverageLogarithm(innerProduct, A, result);
}
private const string LowRankNotSupportedMessage = "A InnerProduct factor with fixed output is not yet implemented.";
@ -120,29 +119,21 @@ namespace Microsoft.ML.Probabilistic.Factors
[NotSupported(InnerProductOp.LowRankNotSupportedMessage)]
public static VectorGaussian BAverageConditional(double innerProduct, Vector A, VectorGaussian result)
{
throw new NotImplementedException(LowRankNotSupportedMessage);
//return AAverageConditional(innerProduct, A, result);
return AAverageConditional(innerProduct, A, result);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="InnerProductOp"]/message_doc[@name="AAverageLogarithm(double, Vector, VectorGaussian)"]/*'/>
[NotSupported(InnerProductOp.LowRankNotSupportedMessage)]
public static VectorGaussian AAverageLogarithm(double innerProduct, Vector B, VectorGaussian result)
{
// This case could be supported if we had low-rank VectorGaussian distributions.
throw new NotSupportedException(LowRankNotSupportedMessage);
//if (result == default(VectorGaussian))
// result = new VectorGaussian(B.Count);
//result.Point = result.Point;
//result.Point.SetToProduct(B, innerProduct / B.Inner(B));
//return result;
return AAverageConditional(innerProduct, B, result);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="InnerProductOp"]/message_doc[@name="BAverageLogarithm(double, Vector, VectorGaussian)"]/*'/>
[NotSupported(InnerProductOp.LowRankNotSupportedMessage)]
public static VectorGaussian BAverageLogarithm(double innerProduct, Vector A, VectorGaussian result)
{
throw new NotSupportedException(LowRankNotSupportedMessage);
//return AAverageLogarithm(innerProduct, A, result);
return AAverageLogarithm(innerProduct, A, result);
}
//-- VMP ---------------------------------------------------------------------------------------------

Просмотреть файл

@ -4346,6 +4346,14 @@ namespace Microsoft.ML.Probabilistic.Tests
[Fact]
public void GatedInnerProductArrayTest()
{
foreach (var flip in new[] { false, true })
{
GatedInnerProductArray(flip);
}
}
internal void GatedInnerProductArray(bool flip)
{
Variable<bool> evidence = Variable.Bernoulli(0.5).Named("evidence");
var evBlock = Variable.If(evidence);
@ -4354,7 +4362,7 @@ namespace Microsoft.ML.Probabilistic.Tests
Gaussian aPrior = Gaussian.FromMeanAndVariance(1.2, 3.4);
a[item] = Variable<double>.Random(aPrior).ForEach(item);
VariableArray<double> b = Variable.Observed(default(double[]), item).Named("b");
Variable<double> c = Variable.InnerProduct(a, b);
Variable<double> c = flip ? Variable.InnerProduct(b, a) : Variable.InnerProduct(a, b);
Variable<Gaussian> cLike = Variable.Observed(default(Gaussian));
Variable.ConstrainEqualRandom(c, cLike);
evBlock.CloseBlock();
@ -4367,7 +4375,7 @@ namespace Microsoft.ML.Probabilistic.Tests
b.ObservedValue = new[] { 2.0, 3.0 };
cLike.ObservedValue = Gaussian.FromMeanAndVariance(4, 5);
}
else if(trial == 1)
else if (trial == 1)
{
b.ObservedValue = new[] { 0.0, 0.0 };
cLike.ObservedValue = Gaussian.PointMass(0);
@ -4403,6 +4411,80 @@ namespace Microsoft.ML.Probabilistic.Tests
}
}
[Fact]
public void GatedInnerProductVectorTest()
{
foreach (var algorithm in new IAlgorithm[] { new Algorithms.ExpectationPropagation(), new Algorithms.VariationalMessagePassing() })
{
foreach (var flip in new[] { false, true })
{
GatedInnerProductVector(flip, algorithm);
}
}
}
internal void GatedInnerProductVector(bool flip, IAlgorithm algorithm)
{
Variable<bool> evidence = Variable.Bernoulli(0.5).Named("evidence");
var evBlock = Variable.If(evidence);
Range item = new Range(2).Named("item");
DenseVector aMean = DenseVector.FromArray(1.2, 3.4);
PositiveDefiniteMatrix aVariance = new PositiveDefiniteMatrix(new double[,] {
{ 4.5, 2.3 }, { 2.3, 6.7 }
});
VectorGaussian aPrior = VectorGaussian.FromMeanAndVariance(aMean, aVariance);
Variable<Vector> a = Variable.Random(aPrior).Named("a");
a.SetValueRange(item);
Variable<Vector> b = Variable.Observed(default(Vector)).Named("b");
b.SetValueRange(item);
Variable<double> c = flip ? Variable.InnerProduct(b, a) : Variable.InnerProduct(a, b);
Variable<Gaussian> cLike = Variable.Observed(default(Gaussian));
Variable.ConstrainEqualRandom(c, cLike);
evBlock.CloseBlock();
InferenceEngine engine = new InferenceEngine(algorithm);
for (int trial = 0; trial <= 2; trial++)
{
if (trial == 0)
{
b.ObservedValue = Vector.FromArray(2.0, 3.0);
cLike.ObservedValue = Gaussian.FromMeanAndVariance(4, 5);
}
else if (trial == 1)
{
b.ObservedValue = Vector.FromArray(0.0, 0.0);
cLike.ObservedValue = Gaussian.PointMass(0);
}
else
{
b.ObservedValue = Vector.FromArray(0.0, 2.0);
cLike.ObservedValue = Gaussian.PointMass(2.3);
}
var aActual = engine.Infer<VectorGaussian>(a);
VectorGaussian vectorGaussianExpected = new VectorGaussian(item.SizeAsInt);
var bVector = b.ObservedValue;
InnerProductOp.AAverageConditional(cLike.ObservedValue, bVector, vectorGaussianExpected);
vectorGaussianExpected.SetToProduct(vectorGaussianExpected, aPrior);
Assert.True(vectorGaussianExpected.MaxDiff(aActual) < 1e-10);
var cActual = engine.Infer<Gaussian>(c);
var toC = InnerProductOp.InnerProductAverageConditional(aMean, aVariance, bVector);
var cExpected = toC * cLike.ObservedValue;
Assert.True(cExpected.MaxDiff(cActual) < 1e-10);
double evActual = engine.Infer<Bernoulli>(evidence).LogOdds;
double evExpected;
if(engine.Algorithm is Algorithms.VariationalMessagePassing)
{
// The only stochastic variable is A, and the only stochastic factors are A's prior and cLike.
evExpected = aActual.GetAverageLog(aPrior) + cExpected.GetAverageLog(cLike.ObservedValue)
-aActual.GetAverageLog(aActual);
}
else evExpected = cLike.ObservedValue.GetLogAverageOf(toC);
Assert.True(MMath.AbsDiff(evExpected, evActual, 1e-8) < 1e-10);
}
}
[Fact]
[Trait("Category", "CsoftModel")]
public void GatedInnerProductRRCTest()
@ -5177,7 +5259,7 @@ namespace Microsoft.ML.Probabilistic.Tests
double evExpected = 0;
if (!double.IsPositiveInfinity(xObserved))
{
evExpected = Gamma.FromShapeAndRate(shape, rate/xObserved).GetLogProb(xObserved);
evExpected = Gamma.FromShapeAndRate(shape, rate / xObserved).GetLogProb(xObserved);
yExpected = Gamma.PointMass(1);
}