зеркало из https://github.com/dotnet/infer.git
DiscreteEnumAreEqualOp supports VMP (#387)
This commit is contained in:
Родитель
079250ef67
Коммит
b6dd9630f5
|
@ -26,7 +26,7 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
|||
}
|
||||
|
||||
/// <summary>
|
||||
/// Creates a distribution over the enum values using the specified probs.
|
||||
/// Creates a distribution over the enum values using the specified probabilities.
|
||||
/// </summary>
|
||||
public DiscreteEnum(params double[] probs) :
|
||||
base(values.Length, Sparsity.Dense)
|
||||
|
@ -34,6 +34,15 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
|||
disc.SetProbs(Vector.FromArray(probs));
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Creates a distribution over the enum values using the specified probabilities.
|
||||
/// </summary>
|
||||
public DiscreteEnum(Vector probs) :
|
||||
base(values.Length, Sparsity.Dense)
|
||||
{
|
||||
disc.SetProbs(probs);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Converts from an integer to an enum value
|
||||
/// </summary>
|
||||
|
|
|
@ -233,6 +233,12 @@ namespace Microsoft.ML.Probabilistic.Factors
|
|||
return (int)(object)en;
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteEnumAreEqualOp{TEnum}"]/message_doc[@name="AAverageConditional(Bernoulli, DiscreteEnum{TEnum}, DiscreteEnum{TEnum})"]/*'/>
|
||||
public static DiscreteEnum<TEnum> AAverageConditional([SkipIfUniform] Bernoulli areEqual, DiscreteEnum<TEnum> B, DiscreteEnum<TEnum> result)
|
||||
{
|
||||
return DiscreteEnum<TEnum>.FromDiscrete(DiscreteAreEqualOp.AAverageConditional(areEqual, B.GetInternalDiscrete(), result.GetInternalDiscrete()));
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteEnumAreEqualOp{TEnum}"]/message_doc[@name="AAverageConditional(Bernoulli, TEnum, DiscreteEnum{TEnum})"]/*'/>
|
||||
public static DiscreteEnum<TEnum> AAverageConditional([SkipIfUniform] Bernoulli areEqual, TEnum B, DiscreteEnum<TEnum> result)
|
||||
{
|
||||
|
@ -245,6 +251,12 @@ namespace Microsoft.ML.Probabilistic.Factors
|
|||
return DiscreteEnum<TEnum>.FromDiscrete(DiscreteAreEqualOp.AAverageConditional(areEqual, ToInt(B), result.GetInternalDiscrete()));
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteEnumAreEqualOp{TEnum}"]/message_doc[@name="AAverageLogarithm(Bernoulli, DiscreteEnum{TEnum}, DiscreteEnum{TEnum})"]/*'/>
|
||||
public static DiscreteEnum<TEnum> AAverageLogarithm([SkipIfUniform] Bernoulli areEqual, DiscreteEnum<TEnum> B, DiscreteEnum<TEnum> result)
|
||||
{
|
||||
return DiscreteEnum<TEnum>.FromDiscrete(DiscreteAreEqualOp.AAverageLogarithm(areEqual, B.GetInternalDiscrete(), result.GetInternalDiscrete()));
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteEnumAreEqualOp{TEnum}"]/message_doc[@name="AAverageLogarithm(bool, TEnum, DiscreteEnum{TEnum})"]/*'/>
|
||||
public static DiscreteEnum<TEnum> AAverageLogarithm(bool areEqual, TEnum B, DiscreteEnum<TEnum> result)
|
||||
{
|
||||
|
@ -257,6 +269,12 @@ namespace Microsoft.ML.Probabilistic.Factors
|
|||
return DiscreteEnum<TEnum>.FromDiscrete(DiscreteAreEqualOp.AAverageLogarithm(areEqual, ToInt(B), result.GetInternalDiscrete()));
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteEnumAreEqualOp{TEnum}"]/message_doc[@name="BAverageConditional(Bernoulli, DiscreteEnum{TEnum}, DiscreteEnum{TEnum})"]/*'/>
|
||||
public static DiscreteEnum<TEnum> BAverageConditional([SkipIfUniform] Bernoulli areEqual, DiscreteEnum<TEnum> A, DiscreteEnum<TEnum> result)
|
||||
{
|
||||
return DiscreteEnum<TEnum>.FromDiscrete(DiscreteAreEqualOp.BAverageConditional(areEqual, A.GetInternalDiscrete(), result.GetInternalDiscrete()));
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteEnumAreEqualOp{TEnum}"]/message_doc[@name="BAverageConditional(Bernoulli, TEnum, DiscreteEnum{TEnum})"]/*'/>
|
||||
public static DiscreteEnum<TEnum> BAverageConditional([SkipIfUniform] Bernoulli areEqual, TEnum A, DiscreteEnum<TEnum> result)
|
||||
{
|
||||
|
@ -269,6 +287,12 @@ namespace Microsoft.ML.Probabilistic.Factors
|
|||
return DiscreteEnum<TEnum>.FromDiscrete(DiscreteAreEqualOp.BAverageConditional(areEqual, ToInt(A), result.GetInternalDiscrete()));
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteEnumAreEqualOp{TEnum}"]/message_doc[@name="BAverageLogarithm(Bernoulli, DiscreteEnum{TEnum}, DiscreteEnum{TEnum})"]/*'/>
|
||||
public static DiscreteEnum<TEnum> BAverageLogarithm([SkipIfUniform] Bernoulli areEqual, DiscreteEnum<TEnum> A, DiscreteEnum<TEnum> result)
|
||||
{
|
||||
return DiscreteEnum<TEnum>.FromDiscrete(DiscreteAreEqualOp.BAverageLogarithm(areEqual, A.GetInternalDiscrete(), result.GetInternalDiscrete()));
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteEnumAreEqualOp{TEnum}"]/message_doc[@name="BAverageLogarithm(Bernoulli, TEnum, DiscreteEnum{TEnum})"]/*'/>
|
||||
public static DiscreteEnum<TEnum> BAverageLogarithm([SkipIfUniform] Bernoulli areEqual, TEnum A, DiscreteEnum<TEnum> result)
|
||||
{
|
||||
|
@ -281,6 +305,12 @@ namespace Microsoft.ML.Probabilistic.Factors
|
|||
return DiscreteEnum<TEnum>.FromDiscrete(DiscreteAreEqualOp.BAverageLogarithm(areEqual, ToInt(A), result.GetInternalDiscrete()));
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteEnumAreEqualOp{TEnum}"]/message_doc[@name="AreEqualAverageConditional(DiscreteEnum{TEnum}, DiscreteEnum{TEnum})"]/*'/>
|
||||
public static Bernoulli AreEqualAverageConditional(DiscreteEnum<TEnum> A, DiscreteEnum<TEnum> B)
|
||||
{
|
||||
return DiscreteAreEqualOp.AreEqualAverageConditional(A.GetInternalDiscrete(), B.GetInternalDiscrete());
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteEnumAreEqualOp{TEnum}"]/message_doc[@name="AreEqualAverageConditional(TEnum, DiscreteEnum{TEnum})"]/*'/>
|
||||
public static Bernoulli AreEqualAverageConditional(TEnum A, DiscreteEnum<TEnum> B)
|
||||
{
|
||||
|
@ -293,6 +323,12 @@ namespace Microsoft.ML.Probabilistic.Factors
|
|||
return DiscreteAreEqualOp.AreEqualAverageConditional(A.GetInternalDiscrete(), ToInt(B));
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteEnumAreEqualOp{TEnum}"]/message_doc[@name="AreEqualAverageLogarithm(DiscreteEnum{TEnum}, DiscreteEnum{TEnum})"]/*'/>
|
||||
public static Bernoulli AreEqualAverageLogarithm(DiscreteEnum<TEnum> A, DiscreteEnum<TEnum> B)
|
||||
{
|
||||
return DiscreteAreEqualOp.AreEqualAverageLogarithm(A.GetInternalDiscrete(), B.GetInternalDiscrete());
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteEnumAreEqualOp{TEnum}"]/message_doc[@name="AreEqualAverageLogarithm(TEnum, DiscreteEnum{TEnum})"]/*'/>
|
||||
public static Bernoulli AreEqualAverageLogarithm(TEnum A, DiscreteEnum<TEnum> B)
|
||||
{
|
||||
|
@ -329,6 +365,12 @@ namespace Microsoft.ML.Probabilistic.Factors
|
|||
return DiscreteAreEqualOp.LogAverageFactor(areEqual, A.GetInternalDiscrete(), ToInt(B));
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteEnumAreEqualOp{TEnum}"]/message_doc[@name="LogEvidenceRatio(bool, DiscreteEnum{TEnum}, DiscreteEnum{TEnum}, DiscreteEnum{TEnum})"]/*'/>
|
||||
public static double LogEvidenceRatio(bool areEqual, DiscreteEnum<TEnum> A, DiscreteEnum<TEnum> B, [Fresh] DiscreteEnum<TEnum> to_A)
|
||||
{
|
||||
return DiscreteAreEqualOp.LogEvidenceRatio(areEqual, A.GetInternalDiscrete(), B.GetInternalDiscrete(), to_A.GetInternalDiscrete());
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteEnumAreEqualOp{TEnum}"]/message_doc[@name="LogEvidenceRatio(bool, TEnum, DiscreteEnum{TEnum}, DiscreteEnum{TEnum})"]/*'/>
|
||||
public static double LogEvidenceRatio(bool areEqual, TEnum A, DiscreteEnum<TEnum> B, [Fresh] DiscreteEnum<TEnum> to_B)
|
||||
{
|
||||
|
|
|
@ -101,7 +101,7 @@ namespace TestApp
|
|||
//}
|
||||
//TestUtils.CheckTransformNames();
|
||||
}
|
||||
bool showFactorManager = true;
|
||||
bool showFactorManager = false;
|
||||
if (showFactorManager)
|
||||
{
|
||||
InferenceEngine.ShowFactorManager(true);
|
||||
|
|
|
@ -7,6 +7,7 @@ using Xunit;
|
|||
using Microsoft.ML.Probabilistic.Math;
|
||||
using Microsoft.ML.Probabilistic.Models;
|
||||
using Microsoft.ML.Probabilistic.Distributions;
|
||||
using Microsoft.ML.Probabilistic.Algorithms;
|
||||
using Range = Microsoft.ML.Probabilistic.Models.Range;
|
||||
|
||||
namespace Microsoft.ML.Probabilistic.Tests
|
||||
|
@ -90,7 +91,6 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
Assert.True(expected.MaxDiff(actual) < 1e-10);
|
||||
}
|
||||
|
||||
|
||||
public enum Outcome
|
||||
{
|
||||
Good,
|
||||
|
@ -133,5 +133,96 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
Console.WriteLine("Distribution over outcomes if control = "
|
||||
+ DiscreteEnum<Outcome>.FromVector(ie.Infer<Dirichlet>(probIfControl).GetMean()));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void GatedOutcomeAreEqualTest()
|
||||
{
|
||||
foreach (var algorithm in new Models.Attributes.IAlgorithm[] { new ExpectationPropagation(), new VariationalMessagePassing() })
|
||||
{
|
||||
Variable<bool> evidence = Variable.Bernoulli(0.5).Named("evidence");
|
||||
IfBlock block = Variable.If(evidence);
|
||||
Vector priorA = Vector.FromArray(0.1, 0.9);
|
||||
Vector priorB = Vector.FromArray(0.2, 0.8);
|
||||
Variable<Outcome> a = Variable.EnumDiscrete<Outcome>(priorA).Named("a");
|
||||
Variable<Outcome> b = Variable.EnumDiscrete<Outcome>(priorB).Named("b");
|
||||
Variable<bool> c = (a == b).Named("c");
|
||||
double priorC = 0.3;
|
||||
Variable.ConstrainEqualRandom(c, new Bernoulli(priorC));
|
||||
block.CloseBlock();
|
||||
|
||||
InferenceEngine engine = new InferenceEngine(algorithm);
|
||||
|
||||
double probEqual = priorA.Inner(priorB);
|
||||
double evPrior = 0;
|
||||
for (int atrial = 0; atrial < 2; atrial++)
|
||||
{
|
||||
if (atrial == 1)
|
||||
{
|
||||
a.ObservedValue = Outcome.Bad;
|
||||
probEqual = priorB[1];
|
||||
c.ClearObservedValue();
|
||||
evPrior = System.Math.Log(priorA[1]);
|
||||
priorA[0] = 0.0;
|
||||
priorA[1] = 1.0;
|
||||
}
|
||||
double evExpected = System.Math.Log(probEqual * priorC + (1 - probEqual) * (1 - priorC)) + evPrior;
|
||||
double evActual = engine.Infer<Bernoulli>(evidence).LogOdds;
|
||||
Console.WriteLine("evidence = {0} should be {1}", evActual, evExpected);
|
||||
if (algorithm is ExpectationPropagation || atrial == 1)
|
||||
Assert.True(MMath.AbsDiff(evExpected, evActual, 1e-5) < 1e-5);
|
||||
|
||||
Bernoulli cExpected = new Bernoulli(probEqual * priorC / (probEqual * priorC + (1 - probEqual) * (1 - priorC)));
|
||||
Bernoulli cActual = engine.Infer<Bernoulli>(c);
|
||||
Console.WriteLine("c = {0} should be {1}", cActual, cExpected);
|
||||
if (algorithm is ExpectationPropagation || atrial == 1)
|
||||
Assert.True(cExpected.MaxDiff(cActual) < 1e-10);
|
||||
|
||||
Vector postB = Vector.Zero(2);
|
||||
postB[0] = priorB[0] * (priorA[0] * priorC + priorA[1] * (1 - priorC));
|
||||
postB[1] = priorB[1] * (priorA[1] * priorC + priorA[0] * (1 - priorC));
|
||||
postB.Scale(1.0 / postB.Sum());
|
||||
DiscreteEnum<Outcome> bExpected = new DiscreteEnum<Outcome>(postB);
|
||||
DiscreteEnum<Outcome> bActual = engine.Infer<DiscreteEnum<Outcome>>(b);
|
||||
Console.WriteLine("b = {0} should be {1}", bActual, bExpected);
|
||||
if (algorithm is ExpectationPropagation || atrial == 1)
|
||||
Assert.True(bExpected.MaxDiff(bActual) < 1e-10);
|
||||
|
||||
if (atrial == 0 && algorithm is VariationalMessagePassing) continue;
|
||||
|
||||
for (int trial = 0; trial < 2; trial++)
|
||||
{
|
||||
if (trial == 0)
|
||||
{
|
||||
c.ObservedValue = true;
|
||||
evExpected = System.Math.Log(probEqual * priorC) + evPrior;
|
||||
}
|
||||
else
|
||||
{
|
||||
c.ObservedValue = false;
|
||||
evExpected = System.Math.Log((1 - probEqual) * (1 - priorC)) + evPrior;
|
||||
}
|
||||
evActual = engine.Infer<Bernoulli>(evidence).LogOdds;
|
||||
Console.WriteLine("evidence = {0} should be {1}", evActual, evExpected);
|
||||
Assert.True(MMath.AbsDiff(evExpected, evActual, 1e-5) < 1e-5);
|
||||
|
||||
if (a.IsObserved)
|
||||
{
|
||||
Outcome flip(Outcome x) => (x == Outcome.Good ? Outcome.Bad : Outcome.Good);
|
||||
bExpected = DiscreteEnum<Outcome>.PointMass(c.ObservedValue ? a.ObservedValue : flip(a.ObservedValue));
|
||||
}
|
||||
else
|
||||
{
|
||||
postB[0] = priorB[0] * (c.ObservedValue ? priorA[0] : priorA[1]);
|
||||
postB[1] = priorB[1] * (c.ObservedValue ? priorA[1] : priorA[0]);
|
||||
postB.Scale(1.0 / postB.Sum());
|
||||
bExpected = new DiscreteEnum<Outcome>(postB);
|
||||
}
|
||||
bActual = engine.Infer<DiscreteEnum<Outcome>>(b);
|
||||
Console.WriteLine("b = {0} should be {1}", bActual, bExpected);
|
||||
Assert.True(bExpected.MaxDiff(bActual) < 1e-10);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -3784,82 +3784,89 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
[Fact]
|
||||
public void GatedIntAreEqualTest()
|
||||
{
|
||||
Variable<bool> evidence = Variable.Bernoulli(0.5).Named("evidence");
|
||||
IfBlock block = Variable.If(evidence);
|
||||
Vector priorA = Vector.FromArray(0.1, 0.9);
|
||||
Vector priorB = Vector.FromArray(0.2, 0.8);
|
||||
Variable<int> a = Variable.Discrete(priorA).Named("a");
|
||||
Variable<int> b = Variable.Discrete(priorB).Named("b");
|
||||
Variable<bool> c = (a == b);
|
||||
double priorC = 0.3;
|
||||
Variable.ConstrainEqualRandom(c, new Bernoulli(priorC));
|
||||
block.CloseBlock();
|
||||
|
||||
InferenceEngine engine = new InferenceEngine();
|
||||
double evExpected, evActual;
|
||||
|
||||
double probEqual = priorA.Inner(priorB);
|
||||
double evPrior = 0;
|
||||
for (int atrial = 0; atrial < 2; atrial++)
|
||||
foreach (var algorithm in new IAlgorithm[] { new ExpectationPropagation(), new VariationalMessagePassing() })
|
||||
{
|
||||
if (atrial == 1)
|
||||
Variable<bool> evidence = Variable.Bernoulli(0.5).Named("evidence");
|
||||
IfBlock block = Variable.If(evidence);
|
||||
Vector priorA = Vector.FromArray(0.1, 0.9);
|
||||
Vector priorB = Vector.FromArray(0.2, 0.8);
|
||||
Variable<int> a = Variable.Discrete(priorA).Named("a");
|
||||
Variable<int> b = Variable.Discrete(priorB).Named("b");
|
||||
Variable<bool> c = (a == b).Named("c");
|
||||
double priorC = 0.3;
|
||||
Variable.ConstrainEqualRandom(c, new Bernoulli(priorC));
|
||||
block.CloseBlock();
|
||||
|
||||
InferenceEngine engine = new InferenceEngine(algorithm);
|
||||
|
||||
double probEqual = priorA.Inner(priorB);
|
||||
double evPrior = 0;
|
||||
for (int atrial = 0; atrial < 2; atrial++)
|
||||
{
|
||||
a.ObservedValue = 1;
|
||||
probEqual = priorB[1];
|
||||
c.ClearObservedValue();
|
||||
evPrior = System.Math.Log(priorA[1]);
|
||||
priorA[0] = 0.0;
|
||||
priorA[1] = 1.0;
|
||||
}
|
||||
evExpected = System.Math.Log(probEqual * priorC + (1 - probEqual) * (1 - priorC)) + evPrior;
|
||||
evActual = engine.Infer<Bernoulli>(evidence).LogOdds;
|
||||
Console.WriteLine("evidence = {0} should be {1}", evActual, evExpected);
|
||||
Assert.True(MMath.AbsDiff(evExpected, evActual, 1e-5) < 1e-5);
|
||||
|
||||
Bernoulli cExpected = new Bernoulli(probEqual * priorC / (probEqual * priorC + (1 - probEqual) * (1 - priorC)));
|
||||
Bernoulli cActual = engine.Infer<Bernoulli>(c);
|
||||
Console.WriteLine("c = {0} should be {1}", cActual, cExpected);
|
||||
Assert.True(cExpected.MaxDiff(cActual) < 1e-10);
|
||||
|
||||
Vector postB = Vector.Zero(2);
|
||||
postB[0] = priorB[0] * (priorA[0] * priorC + priorA[1] * (1 - priorC));
|
||||
postB[1] = priorB[1] * (priorA[1] * priorC + priorA[0] * (1 - priorC));
|
||||
postB.Scale(1.0 / postB.Sum());
|
||||
Discrete bExpected = new Discrete(postB);
|
||||
Discrete bActual = engine.Infer<Discrete>(b);
|
||||
Console.WriteLine("b = {0} should be {1}", bActual, bExpected);
|
||||
Assert.True(bExpected.MaxDiff(bActual) < 1e-10);
|
||||
|
||||
for (int trial = 0; trial < 2; trial++)
|
||||
{
|
||||
if (trial == 0)
|
||||
if (atrial == 1)
|
||||
{
|
||||
c.ObservedValue = true;
|
||||
evExpected = System.Math.Log(probEqual * priorC) + evPrior;
|
||||
a.ObservedValue = 1;
|
||||
probEqual = priorB[1];
|
||||
c.ClearObservedValue();
|
||||
evPrior = System.Math.Log(priorA[1]);
|
||||
priorA[0] = 0.0;
|
||||
priorA[1] = 1.0;
|
||||
}
|
||||
else
|
||||
{
|
||||
c.ObservedValue = false;
|
||||
evExpected = System.Math.Log((1 - probEqual) * (1 - priorC)) + evPrior;
|
||||
}
|
||||
evActual = engine.Infer<Bernoulli>(evidence).LogOdds;
|
||||
double evExpected = System.Math.Log(probEqual * priorC + (1 - probEqual) * (1 - priorC)) + evPrior;
|
||||
double evActual = engine.Infer<Bernoulli>(evidence).LogOdds;
|
||||
Console.WriteLine("evidence = {0} should be {1}", evActual, evExpected);
|
||||
Assert.True(MMath.AbsDiff(evExpected, evActual, 1e-5) < 1e-5);
|
||||
if (algorithm is ExpectationPropagation || atrial == 1)
|
||||
Assert.True(MMath.AbsDiff(evExpected, evActual, 1e-5) < 1e-5);
|
||||
|
||||
if (a.IsObserved)
|
||||
{
|
||||
bExpected = Discrete.PointMass(c.ObservedValue ? a.ObservedValue : 1 - a.ObservedValue, 2);
|
||||
}
|
||||
else
|
||||
{
|
||||
postB[0] = priorB[0] * (c.ObservedValue ? priorA[0] : priorA[1]);
|
||||
postB[1] = priorB[1] * (c.ObservedValue ? priorA[1] : priorA[0]);
|
||||
postB.Scale(1.0 / postB.Sum());
|
||||
bExpected = new Discrete(postB);
|
||||
}
|
||||
bActual = engine.Infer<Discrete>(b);
|
||||
Bernoulli cExpected = new Bernoulli(probEqual * priorC / (probEqual * priorC + (1 - probEqual) * (1 - priorC)));
|
||||
Bernoulli cActual = engine.Infer<Bernoulli>(c);
|
||||
Console.WriteLine("c = {0} should be {1}", cActual, cExpected);
|
||||
if (algorithm is ExpectationPropagation || atrial == 1)
|
||||
Assert.True(cExpected.MaxDiff(cActual) < 1e-10);
|
||||
|
||||
Vector postB = Vector.Zero(2);
|
||||
postB[0] = priorB[0] * (priorA[0] * priorC + priorA[1] * (1 - priorC));
|
||||
postB[1] = priorB[1] * (priorA[1] * priorC + priorA[0] * (1 - priorC));
|
||||
postB.Scale(1.0 / postB.Sum());
|
||||
Discrete bExpected = new Discrete(postB);
|
||||
Discrete bActual = engine.Infer<Discrete>(b);
|
||||
Console.WriteLine("b = {0} should be {1}", bActual, bExpected);
|
||||
Assert.True(bExpected.MaxDiff(bActual) < 1e-10);
|
||||
if (algorithm is ExpectationPropagation || atrial == 1)
|
||||
Assert.True(bExpected.MaxDiff(bActual) < 1e-10);
|
||||
|
||||
if (atrial == 0 && algorithm is VariationalMessagePassing) continue;
|
||||
|
||||
for (int trial = 0; trial < 2; trial++)
|
||||
{
|
||||
if (trial == 0)
|
||||
{
|
||||
c.ObservedValue = true;
|
||||
evExpected = System.Math.Log(probEqual * priorC) + evPrior;
|
||||
}
|
||||
else
|
||||
{
|
||||
c.ObservedValue = false;
|
||||
evExpected = System.Math.Log((1 - probEqual) * (1 - priorC)) + evPrior;
|
||||
}
|
||||
evActual = engine.Infer<Bernoulli>(evidence).LogOdds;
|
||||
Console.WriteLine("evidence = {0} should be {1}", evActual, evExpected);
|
||||
Assert.True(MMath.AbsDiff(evExpected, evActual, 1e-5) < 1e-5);
|
||||
|
||||
if (a.IsObserved)
|
||||
{
|
||||
bExpected = Discrete.PointMass(c.ObservedValue ? a.ObservedValue : 1 - a.ObservedValue, 2);
|
||||
}
|
||||
else
|
||||
{
|
||||
postB[0] = priorB[0] * (c.ObservedValue ? priorA[0] : priorA[1]);
|
||||
postB[1] = priorB[1] * (c.ObservedValue ? priorA[1] : priorA[0]);
|
||||
postB.Scale(1.0 / postB.Sum());
|
||||
bExpected = new Discrete(postB);
|
||||
}
|
||||
bActual = engine.Infer<Discrete>(b);
|
||||
Console.WriteLine("b = {0} should be {1}", bActual, bExpected);
|
||||
Assert.True(bExpected.MaxDiff(bActual) < 1e-10);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче