зеркало из https://github.com/dotnet/infer.git
DiscreteEnumAreEqualOp supports VMP (#387)
This commit is contained in:
Родитель
079250ef67
Коммит
b6dd9630f5
|
@ -26,7 +26,7 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
||||||
}
|
}
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Creates a distribution over the enum values using the specified probs.
|
/// Creates a distribution over the enum values using the specified probabilities.
|
||||||
/// </summary>
|
/// </summary>
|
||||||
public DiscreteEnum(params double[] probs) :
|
public DiscreteEnum(params double[] probs) :
|
||||||
base(values.Length, Sparsity.Dense)
|
base(values.Length, Sparsity.Dense)
|
||||||
|
@ -34,6 +34,15 @@ namespace Microsoft.ML.Probabilistic.Distributions
|
||||||
disc.SetProbs(Vector.FromArray(probs));
|
disc.SetProbs(Vector.FromArray(probs));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Creates a distribution over the enum values using the specified probabilities.
|
||||||
|
/// </summary>
|
||||||
|
public DiscreteEnum(Vector probs) :
|
||||||
|
base(values.Length, Sparsity.Dense)
|
||||||
|
{
|
||||||
|
disc.SetProbs(probs);
|
||||||
|
}
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Converts from an integer to an enum value
|
/// Converts from an integer to an enum value
|
||||||
/// </summary>
|
/// </summary>
|
||||||
|
|
|
@ -233,6 +233,12 @@ namespace Microsoft.ML.Probabilistic.Factors
|
||||||
return (int)(object)en;
|
return (int)(object)en;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteEnumAreEqualOp{TEnum}"]/message_doc[@name="AAverageConditional(Bernoulli, DiscreteEnum{TEnum}, DiscreteEnum{TEnum})"]/*'/>
|
||||||
|
public static DiscreteEnum<TEnum> AAverageConditional([SkipIfUniform] Bernoulli areEqual, DiscreteEnum<TEnum> B, DiscreteEnum<TEnum> result)
|
||||||
|
{
|
||||||
|
return DiscreteEnum<TEnum>.FromDiscrete(DiscreteAreEqualOp.AAverageConditional(areEqual, B.GetInternalDiscrete(), result.GetInternalDiscrete()));
|
||||||
|
}
|
||||||
|
|
||||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteEnumAreEqualOp{TEnum}"]/message_doc[@name="AAverageConditional(Bernoulli, TEnum, DiscreteEnum{TEnum})"]/*'/>
|
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteEnumAreEqualOp{TEnum}"]/message_doc[@name="AAverageConditional(Bernoulli, TEnum, DiscreteEnum{TEnum})"]/*'/>
|
||||||
public static DiscreteEnum<TEnum> AAverageConditional([SkipIfUniform] Bernoulli areEqual, TEnum B, DiscreteEnum<TEnum> result)
|
public static DiscreteEnum<TEnum> AAverageConditional([SkipIfUniform] Bernoulli areEqual, TEnum B, DiscreteEnum<TEnum> result)
|
||||||
{
|
{
|
||||||
|
@ -245,6 +251,12 @@ namespace Microsoft.ML.Probabilistic.Factors
|
||||||
return DiscreteEnum<TEnum>.FromDiscrete(DiscreteAreEqualOp.AAverageConditional(areEqual, ToInt(B), result.GetInternalDiscrete()));
|
return DiscreteEnum<TEnum>.FromDiscrete(DiscreteAreEqualOp.AAverageConditional(areEqual, ToInt(B), result.GetInternalDiscrete()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteEnumAreEqualOp{TEnum}"]/message_doc[@name="AAverageLogarithm(Bernoulli, DiscreteEnum{TEnum}, DiscreteEnum{TEnum})"]/*'/>
|
||||||
|
public static DiscreteEnum<TEnum> AAverageLogarithm([SkipIfUniform] Bernoulli areEqual, DiscreteEnum<TEnum> B, DiscreteEnum<TEnum> result)
|
||||||
|
{
|
||||||
|
return DiscreteEnum<TEnum>.FromDiscrete(DiscreteAreEqualOp.AAverageLogarithm(areEqual, B.GetInternalDiscrete(), result.GetInternalDiscrete()));
|
||||||
|
}
|
||||||
|
|
||||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteEnumAreEqualOp{TEnum}"]/message_doc[@name="AAverageLogarithm(bool, TEnum, DiscreteEnum{TEnum})"]/*'/>
|
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteEnumAreEqualOp{TEnum}"]/message_doc[@name="AAverageLogarithm(bool, TEnum, DiscreteEnum{TEnum})"]/*'/>
|
||||||
public static DiscreteEnum<TEnum> AAverageLogarithm(bool areEqual, TEnum B, DiscreteEnum<TEnum> result)
|
public static DiscreteEnum<TEnum> AAverageLogarithm(bool areEqual, TEnum B, DiscreteEnum<TEnum> result)
|
||||||
{
|
{
|
||||||
|
@ -257,6 +269,12 @@ namespace Microsoft.ML.Probabilistic.Factors
|
||||||
return DiscreteEnum<TEnum>.FromDiscrete(DiscreteAreEqualOp.AAverageLogarithm(areEqual, ToInt(B), result.GetInternalDiscrete()));
|
return DiscreteEnum<TEnum>.FromDiscrete(DiscreteAreEqualOp.AAverageLogarithm(areEqual, ToInt(B), result.GetInternalDiscrete()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteEnumAreEqualOp{TEnum}"]/message_doc[@name="BAverageConditional(Bernoulli, DiscreteEnum{TEnum}, DiscreteEnum{TEnum})"]/*'/>
|
||||||
|
public static DiscreteEnum<TEnum> BAverageConditional([SkipIfUniform] Bernoulli areEqual, DiscreteEnum<TEnum> A, DiscreteEnum<TEnum> result)
|
||||||
|
{
|
||||||
|
return DiscreteEnum<TEnum>.FromDiscrete(DiscreteAreEqualOp.BAverageConditional(areEqual, A.GetInternalDiscrete(), result.GetInternalDiscrete()));
|
||||||
|
}
|
||||||
|
|
||||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteEnumAreEqualOp{TEnum}"]/message_doc[@name="BAverageConditional(Bernoulli, TEnum, DiscreteEnum{TEnum})"]/*'/>
|
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteEnumAreEqualOp{TEnum}"]/message_doc[@name="BAverageConditional(Bernoulli, TEnum, DiscreteEnum{TEnum})"]/*'/>
|
||||||
public static DiscreteEnum<TEnum> BAverageConditional([SkipIfUniform] Bernoulli areEqual, TEnum A, DiscreteEnum<TEnum> result)
|
public static DiscreteEnum<TEnum> BAverageConditional([SkipIfUniform] Bernoulli areEqual, TEnum A, DiscreteEnum<TEnum> result)
|
||||||
{
|
{
|
||||||
|
@ -269,6 +287,12 @@ namespace Microsoft.ML.Probabilistic.Factors
|
||||||
return DiscreteEnum<TEnum>.FromDiscrete(DiscreteAreEqualOp.BAverageConditional(areEqual, ToInt(A), result.GetInternalDiscrete()));
|
return DiscreteEnum<TEnum>.FromDiscrete(DiscreteAreEqualOp.BAverageConditional(areEqual, ToInt(A), result.GetInternalDiscrete()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteEnumAreEqualOp{TEnum}"]/message_doc[@name="BAverageLogarithm(Bernoulli, DiscreteEnum{TEnum}, DiscreteEnum{TEnum})"]/*'/>
|
||||||
|
public static DiscreteEnum<TEnum> BAverageLogarithm([SkipIfUniform] Bernoulli areEqual, DiscreteEnum<TEnum> A, DiscreteEnum<TEnum> result)
|
||||||
|
{
|
||||||
|
return DiscreteEnum<TEnum>.FromDiscrete(DiscreteAreEqualOp.BAverageLogarithm(areEqual, A.GetInternalDiscrete(), result.GetInternalDiscrete()));
|
||||||
|
}
|
||||||
|
|
||||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteEnumAreEqualOp{TEnum}"]/message_doc[@name="BAverageLogarithm(Bernoulli, TEnum, DiscreteEnum{TEnum})"]/*'/>
|
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteEnumAreEqualOp{TEnum}"]/message_doc[@name="BAverageLogarithm(Bernoulli, TEnum, DiscreteEnum{TEnum})"]/*'/>
|
||||||
public static DiscreteEnum<TEnum> BAverageLogarithm([SkipIfUniform] Bernoulli areEqual, TEnum A, DiscreteEnum<TEnum> result)
|
public static DiscreteEnum<TEnum> BAverageLogarithm([SkipIfUniform] Bernoulli areEqual, TEnum A, DiscreteEnum<TEnum> result)
|
||||||
{
|
{
|
||||||
|
@ -281,6 +305,12 @@ namespace Microsoft.ML.Probabilistic.Factors
|
||||||
return DiscreteEnum<TEnum>.FromDiscrete(DiscreteAreEqualOp.BAverageLogarithm(areEqual, ToInt(A), result.GetInternalDiscrete()));
|
return DiscreteEnum<TEnum>.FromDiscrete(DiscreteAreEqualOp.BAverageLogarithm(areEqual, ToInt(A), result.GetInternalDiscrete()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteEnumAreEqualOp{TEnum}"]/message_doc[@name="AreEqualAverageConditional(DiscreteEnum{TEnum}, DiscreteEnum{TEnum})"]/*'/>
|
||||||
|
public static Bernoulli AreEqualAverageConditional(DiscreteEnum<TEnum> A, DiscreteEnum<TEnum> B)
|
||||||
|
{
|
||||||
|
return DiscreteAreEqualOp.AreEqualAverageConditional(A.GetInternalDiscrete(), B.GetInternalDiscrete());
|
||||||
|
}
|
||||||
|
|
||||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteEnumAreEqualOp{TEnum}"]/message_doc[@name="AreEqualAverageConditional(TEnum, DiscreteEnum{TEnum})"]/*'/>
|
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteEnumAreEqualOp{TEnum}"]/message_doc[@name="AreEqualAverageConditional(TEnum, DiscreteEnum{TEnum})"]/*'/>
|
||||||
public static Bernoulli AreEqualAverageConditional(TEnum A, DiscreteEnum<TEnum> B)
|
public static Bernoulli AreEqualAverageConditional(TEnum A, DiscreteEnum<TEnum> B)
|
||||||
{
|
{
|
||||||
|
@ -293,6 +323,12 @@ namespace Microsoft.ML.Probabilistic.Factors
|
||||||
return DiscreteAreEqualOp.AreEqualAverageConditional(A.GetInternalDiscrete(), ToInt(B));
|
return DiscreteAreEqualOp.AreEqualAverageConditional(A.GetInternalDiscrete(), ToInt(B));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteEnumAreEqualOp{TEnum}"]/message_doc[@name="AreEqualAverageLogarithm(DiscreteEnum{TEnum}, DiscreteEnum{TEnum})"]/*'/>
|
||||||
|
public static Bernoulli AreEqualAverageLogarithm(DiscreteEnum<TEnum> A, DiscreteEnum<TEnum> B)
|
||||||
|
{
|
||||||
|
return DiscreteAreEqualOp.AreEqualAverageLogarithm(A.GetInternalDiscrete(), B.GetInternalDiscrete());
|
||||||
|
}
|
||||||
|
|
||||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteEnumAreEqualOp{TEnum}"]/message_doc[@name="AreEqualAverageLogarithm(TEnum, DiscreteEnum{TEnum})"]/*'/>
|
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteEnumAreEqualOp{TEnum}"]/message_doc[@name="AreEqualAverageLogarithm(TEnum, DiscreteEnum{TEnum})"]/*'/>
|
||||||
public static Bernoulli AreEqualAverageLogarithm(TEnum A, DiscreteEnum<TEnum> B)
|
public static Bernoulli AreEqualAverageLogarithm(TEnum A, DiscreteEnum<TEnum> B)
|
||||||
{
|
{
|
||||||
|
@ -329,6 +365,12 @@ namespace Microsoft.ML.Probabilistic.Factors
|
||||||
return DiscreteAreEqualOp.LogAverageFactor(areEqual, A.GetInternalDiscrete(), ToInt(B));
|
return DiscreteAreEqualOp.LogAverageFactor(areEqual, A.GetInternalDiscrete(), ToInt(B));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteEnumAreEqualOp{TEnum}"]/message_doc[@name="LogEvidenceRatio(bool, DiscreteEnum{TEnum}, DiscreteEnum{TEnum}, DiscreteEnum{TEnum})"]/*'/>
|
||||||
|
public static double LogEvidenceRatio(bool areEqual, DiscreteEnum<TEnum> A, DiscreteEnum<TEnum> B, [Fresh] DiscreteEnum<TEnum> to_A)
|
||||||
|
{
|
||||||
|
return DiscreteAreEqualOp.LogEvidenceRatio(areEqual, A.GetInternalDiscrete(), B.GetInternalDiscrete(), to_A.GetInternalDiscrete());
|
||||||
|
}
|
||||||
|
|
||||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteEnumAreEqualOp{TEnum}"]/message_doc[@name="LogEvidenceRatio(bool, TEnum, DiscreteEnum{TEnum}, DiscreteEnum{TEnum})"]/*'/>
|
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteEnumAreEqualOp{TEnum}"]/message_doc[@name="LogEvidenceRatio(bool, TEnum, DiscreteEnum{TEnum}, DiscreteEnum{TEnum})"]/*'/>
|
||||||
public static double LogEvidenceRatio(bool areEqual, TEnum A, DiscreteEnum<TEnum> B, [Fresh] DiscreteEnum<TEnum> to_B)
|
public static double LogEvidenceRatio(bool areEqual, TEnum A, DiscreteEnum<TEnum> B, [Fresh] DiscreteEnum<TEnum> to_B)
|
||||||
{
|
{
|
||||||
|
|
|
@ -101,7 +101,7 @@ namespace TestApp
|
||||||
//}
|
//}
|
||||||
//TestUtils.CheckTransformNames();
|
//TestUtils.CheckTransformNames();
|
||||||
}
|
}
|
||||||
bool showFactorManager = true;
|
bool showFactorManager = false;
|
||||||
if (showFactorManager)
|
if (showFactorManager)
|
||||||
{
|
{
|
||||||
InferenceEngine.ShowFactorManager(true);
|
InferenceEngine.ShowFactorManager(true);
|
||||||
|
|
|
@ -7,6 +7,7 @@ using Xunit;
|
||||||
using Microsoft.ML.Probabilistic.Math;
|
using Microsoft.ML.Probabilistic.Math;
|
||||||
using Microsoft.ML.Probabilistic.Models;
|
using Microsoft.ML.Probabilistic.Models;
|
||||||
using Microsoft.ML.Probabilistic.Distributions;
|
using Microsoft.ML.Probabilistic.Distributions;
|
||||||
|
using Microsoft.ML.Probabilistic.Algorithms;
|
||||||
using Range = Microsoft.ML.Probabilistic.Models.Range;
|
using Range = Microsoft.ML.Probabilistic.Models.Range;
|
||||||
|
|
||||||
namespace Microsoft.ML.Probabilistic.Tests
|
namespace Microsoft.ML.Probabilistic.Tests
|
||||||
|
@ -90,7 +91,6 @@ namespace Microsoft.ML.Probabilistic.Tests
|
||||||
Assert.True(expected.MaxDiff(actual) < 1e-10);
|
Assert.True(expected.MaxDiff(actual) < 1e-10);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public enum Outcome
|
public enum Outcome
|
||||||
{
|
{
|
||||||
Good,
|
Good,
|
||||||
|
@ -133,5 +133,96 @@ namespace Microsoft.ML.Probabilistic.Tests
|
||||||
Console.WriteLine("Distribution over outcomes if control = "
|
Console.WriteLine("Distribution over outcomes if control = "
|
||||||
+ DiscreteEnum<Outcome>.FromVector(ie.Infer<Dirichlet>(probIfControl).GetMean()));
|
+ DiscreteEnum<Outcome>.FromVector(ie.Infer<Dirichlet>(probIfControl).GetMean()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void GatedOutcomeAreEqualTest()
|
||||||
|
{
|
||||||
|
foreach (var algorithm in new Models.Attributes.IAlgorithm[] { new ExpectationPropagation(), new VariationalMessagePassing() })
|
||||||
|
{
|
||||||
|
Variable<bool> evidence = Variable.Bernoulli(0.5).Named("evidence");
|
||||||
|
IfBlock block = Variable.If(evidence);
|
||||||
|
Vector priorA = Vector.FromArray(0.1, 0.9);
|
||||||
|
Vector priorB = Vector.FromArray(0.2, 0.8);
|
||||||
|
Variable<Outcome> a = Variable.EnumDiscrete<Outcome>(priorA).Named("a");
|
||||||
|
Variable<Outcome> b = Variable.EnumDiscrete<Outcome>(priorB).Named("b");
|
||||||
|
Variable<bool> c = (a == b).Named("c");
|
||||||
|
double priorC = 0.3;
|
||||||
|
Variable.ConstrainEqualRandom(c, new Bernoulli(priorC));
|
||||||
|
block.CloseBlock();
|
||||||
|
|
||||||
|
InferenceEngine engine = new InferenceEngine(algorithm);
|
||||||
|
|
||||||
|
double probEqual = priorA.Inner(priorB);
|
||||||
|
double evPrior = 0;
|
||||||
|
for (int atrial = 0; atrial < 2; atrial++)
|
||||||
|
{
|
||||||
|
if (atrial == 1)
|
||||||
|
{
|
||||||
|
a.ObservedValue = Outcome.Bad;
|
||||||
|
probEqual = priorB[1];
|
||||||
|
c.ClearObservedValue();
|
||||||
|
evPrior = System.Math.Log(priorA[1]);
|
||||||
|
priorA[0] = 0.0;
|
||||||
|
priorA[1] = 1.0;
|
||||||
|
}
|
||||||
|
double evExpected = System.Math.Log(probEqual * priorC + (1 - probEqual) * (1 - priorC)) + evPrior;
|
||||||
|
double evActual = engine.Infer<Bernoulli>(evidence).LogOdds;
|
||||||
|
Console.WriteLine("evidence = {0} should be {1}", evActual, evExpected);
|
||||||
|
if (algorithm is ExpectationPropagation || atrial == 1)
|
||||||
|
Assert.True(MMath.AbsDiff(evExpected, evActual, 1e-5) < 1e-5);
|
||||||
|
|
||||||
|
Bernoulli cExpected = new Bernoulli(probEqual * priorC / (probEqual * priorC + (1 - probEqual) * (1 - priorC)));
|
||||||
|
Bernoulli cActual = engine.Infer<Bernoulli>(c);
|
||||||
|
Console.WriteLine("c = {0} should be {1}", cActual, cExpected);
|
||||||
|
if (algorithm is ExpectationPropagation || atrial == 1)
|
||||||
|
Assert.True(cExpected.MaxDiff(cActual) < 1e-10);
|
||||||
|
|
||||||
|
Vector postB = Vector.Zero(2);
|
||||||
|
postB[0] = priorB[0] * (priorA[0] * priorC + priorA[1] * (1 - priorC));
|
||||||
|
postB[1] = priorB[1] * (priorA[1] * priorC + priorA[0] * (1 - priorC));
|
||||||
|
postB.Scale(1.0 / postB.Sum());
|
||||||
|
DiscreteEnum<Outcome> bExpected = new DiscreteEnum<Outcome>(postB);
|
||||||
|
DiscreteEnum<Outcome> bActual = engine.Infer<DiscreteEnum<Outcome>>(b);
|
||||||
|
Console.WriteLine("b = {0} should be {1}", bActual, bExpected);
|
||||||
|
if (algorithm is ExpectationPropagation || atrial == 1)
|
||||||
|
Assert.True(bExpected.MaxDiff(bActual) < 1e-10);
|
||||||
|
|
||||||
|
if (atrial == 0 && algorithm is VariationalMessagePassing) continue;
|
||||||
|
|
||||||
|
for (int trial = 0; trial < 2; trial++)
|
||||||
|
{
|
||||||
|
if (trial == 0)
|
||||||
|
{
|
||||||
|
c.ObservedValue = true;
|
||||||
|
evExpected = System.Math.Log(probEqual * priorC) + evPrior;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
c.ObservedValue = false;
|
||||||
|
evExpected = System.Math.Log((1 - probEqual) * (1 - priorC)) + evPrior;
|
||||||
|
}
|
||||||
|
evActual = engine.Infer<Bernoulli>(evidence).LogOdds;
|
||||||
|
Console.WriteLine("evidence = {0} should be {1}", evActual, evExpected);
|
||||||
|
Assert.True(MMath.AbsDiff(evExpected, evActual, 1e-5) < 1e-5);
|
||||||
|
|
||||||
|
if (a.IsObserved)
|
||||||
|
{
|
||||||
|
Outcome flip(Outcome x) => (x == Outcome.Good ? Outcome.Bad : Outcome.Good);
|
||||||
|
bExpected = DiscreteEnum<Outcome>.PointMass(c.ObservedValue ? a.ObservedValue : flip(a.ObservedValue));
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
postB[0] = priorB[0] * (c.ObservedValue ? priorA[0] : priorA[1]);
|
||||||
|
postB[1] = priorB[1] * (c.ObservedValue ? priorA[1] : priorA[0]);
|
||||||
|
postB.Scale(1.0 / postB.Sum());
|
||||||
|
bExpected = new DiscreteEnum<Outcome>(postB);
|
||||||
|
}
|
||||||
|
bActual = engine.Infer<DiscreteEnum<Outcome>>(b);
|
||||||
|
Console.WriteLine("b = {0} should be {1}", bActual, bExpected);
|
||||||
|
Assert.True(bExpected.MaxDiff(bActual) < 1e-10);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -3784,82 +3784,89 @@ namespace Microsoft.ML.Probabilistic.Tests
|
||||||
[Fact]
|
[Fact]
|
||||||
public void GatedIntAreEqualTest()
|
public void GatedIntAreEqualTest()
|
||||||
{
|
{
|
||||||
Variable<bool> evidence = Variable.Bernoulli(0.5).Named("evidence");
|
foreach (var algorithm in new IAlgorithm[] { new ExpectationPropagation(), new VariationalMessagePassing() })
|
||||||
IfBlock block = Variable.If(evidence);
|
|
||||||
Vector priorA = Vector.FromArray(0.1, 0.9);
|
|
||||||
Vector priorB = Vector.FromArray(0.2, 0.8);
|
|
||||||
Variable<int> a = Variable.Discrete(priorA).Named("a");
|
|
||||||
Variable<int> b = Variable.Discrete(priorB).Named("b");
|
|
||||||
Variable<bool> c = (a == b);
|
|
||||||
double priorC = 0.3;
|
|
||||||
Variable.ConstrainEqualRandom(c, new Bernoulli(priorC));
|
|
||||||
block.CloseBlock();
|
|
||||||
|
|
||||||
InferenceEngine engine = new InferenceEngine();
|
|
||||||
double evExpected, evActual;
|
|
||||||
|
|
||||||
double probEqual = priorA.Inner(priorB);
|
|
||||||
double evPrior = 0;
|
|
||||||
for (int atrial = 0; atrial < 2; atrial++)
|
|
||||||
{
|
{
|
||||||
if (atrial == 1)
|
Variable<bool> evidence = Variable.Bernoulli(0.5).Named("evidence");
|
||||||
|
IfBlock block = Variable.If(evidence);
|
||||||
|
Vector priorA = Vector.FromArray(0.1, 0.9);
|
||||||
|
Vector priorB = Vector.FromArray(0.2, 0.8);
|
||||||
|
Variable<int> a = Variable.Discrete(priorA).Named("a");
|
||||||
|
Variable<int> b = Variable.Discrete(priorB).Named("b");
|
||||||
|
Variable<bool> c = (a == b).Named("c");
|
||||||
|
double priorC = 0.3;
|
||||||
|
Variable.ConstrainEqualRandom(c, new Bernoulli(priorC));
|
||||||
|
block.CloseBlock();
|
||||||
|
|
||||||
|
InferenceEngine engine = new InferenceEngine(algorithm);
|
||||||
|
|
||||||
|
double probEqual = priorA.Inner(priorB);
|
||||||
|
double evPrior = 0;
|
||||||
|
for (int atrial = 0; atrial < 2; atrial++)
|
||||||
{
|
{
|
||||||
a.ObservedValue = 1;
|
if (atrial == 1)
|
||||||
probEqual = priorB[1];
|
|
||||||
c.ClearObservedValue();
|
|
||||||
evPrior = System.Math.Log(priorA[1]);
|
|
||||||
priorA[0] = 0.0;
|
|
||||||
priorA[1] = 1.0;
|
|
||||||
}
|
|
||||||
evExpected = System.Math.Log(probEqual * priorC + (1 - probEqual) * (1 - priorC)) + evPrior;
|
|
||||||
evActual = engine.Infer<Bernoulli>(evidence).LogOdds;
|
|
||||||
Console.WriteLine("evidence = {0} should be {1}", evActual, evExpected);
|
|
||||||
Assert.True(MMath.AbsDiff(evExpected, evActual, 1e-5) < 1e-5);
|
|
||||||
|
|
||||||
Bernoulli cExpected = new Bernoulli(probEqual * priorC / (probEqual * priorC + (1 - probEqual) * (1 - priorC)));
|
|
||||||
Bernoulli cActual = engine.Infer<Bernoulli>(c);
|
|
||||||
Console.WriteLine("c = {0} should be {1}", cActual, cExpected);
|
|
||||||
Assert.True(cExpected.MaxDiff(cActual) < 1e-10);
|
|
||||||
|
|
||||||
Vector postB = Vector.Zero(2);
|
|
||||||
postB[0] = priorB[0] * (priorA[0] * priorC + priorA[1] * (1 - priorC));
|
|
||||||
postB[1] = priorB[1] * (priorA[1] * priorC + priorA[0] * (1 - priorC));
|
|
||||||
postB.Scale(1.0 / postB.Sum());
|
|
||||||
Discrete bExpected = new Discrete(postB);
|
|
||||||
Discrete bActual = engine.Infer<Discrete>(b);
|
|
||||||
Console.WriteLine("b = {0} should be {1}", bActual, bExpected);
|
|
||||||
Assert.True(bExpected.MaxDiff(bActual) < 1e-10);
|
|
||||||
|
|
||||||
for (int trial = 0; trial < 2; trial++)
|
|
||||||
{
|
|
||||||
if (trial == 0)
|
|
||||||
{
|
{
|
||||||
c.ObservedValue = true;
|
a.ObservedValue = 1;
|
||||||
evExpected = System.Math.Log(probEqual * priorC) + evPrior;
|
probEqual = priorB[1];
|
||||||
|
c.ClearObservedValue();
|
||||||
|
evPrior = System.Math.Log(priorA[1]);
|
||||||
|
priorA[0] = 0.0;
|
||||||
|
priorA[1] = 1.0;
|
||||||
}
|
}
|
||||||
else
|
double evExpected = System.Math.Log(probEqual * priorC + (1 - probEqual) * (1 - priorC)) + evPrior;
|
||||||
{
|
double evActual = engine.Infer<Bernoulli>(evidence).LogOdds;
|
||||||
c.ObservedValue = false;
|
|
||||||
evExpected = System.Math.Log((1 - probEqual) * (1 - priorC)) + evPrior;
|
|
||||||
}
|
|
||||||
evActual = engine.Infer<Bernoulli>(evidence).LogOdds;
|
|
||||||
Console.WriteLine("evidence = {0} should be {1}", evActual, evExpected);
|
Console.WriteLine("evidence = {0} should be {1}", evActual, evExpected);
|
||||||
Assert.True(MMath.AbsDiff(evExpected, evActual, 1e-5) < 1e-5);
|
if (algorithm is ExpectationPropagation || atrial == 1)
|
||||||
|
Assert.True(MMath.AbsDiff(evExpected, evActual, 1e-5) < 1e-5);
|
||||||
|
|
||||||
if (a.IsObserved)
|
Bernoulli cExpected = new Bernoulli(probEqual * priorC / (probEqual * priorC + (1 - probEqual) * (1 - priorC)));
|
||||||
{
|
Bernoulli cActual = engine.Infer<Bernoulli>(c);
|
||||||
bExpected = Discrete.PointMass(c.ObservedValue ? a.ObservedValue : 1 - a.ObservedValue, 2);
|
Console.WriteLine("c = {0} should be {1}", cActual, cExpected);
|
||||||
}
|
if (algorithm is ExpectationPropagation || atrial == 1)
|
||||||
else
|
Assert.True(cExpected.MaxDiff(cActual) < 1e-10);
|
||||||
{
|
|
||||||
postB[0] = priorB[0] * (c.ObservedValue ? priorA[0] : priorA[1]);
|
Vector postB = Vector.Zero(2);
|
||||||
postB[1] = priorB[1] * (c.ObservedValue ? priorA[1] : priorA[0]);
|
postB[0] = priorB[0] * (priorA[0] * priorC + priorA[1] * (1 - priorC));
|
||||||
postB.Scale(1.0 / postB.Sum());
|
postB[1] = priorB[1] * (priorA[1] * priorC + priorA[0] * (1 - priorC));
|
||||||
bExpected = new Discrete(postB);
|
postB.Scale(1.0 / postB.Sum());
|
||||||
}
|
Discrete bExpected = new Discrete(postB);
|
||||||
bActual = engine.Infer<Discrete>(b);
|
Discrete bActual = engine.Infer<Discrete>(b);
|
||||||
Console.WriteLine("b = {0} should be {1}", bActual, bExpected);
|
Console.WriteLine("b = {0} should be {1}", bActual, bExpected);
|
||||||
Assert.True(bExpected.MaxDiff(bActual) < 1e-10);
|
if (algorithm is ExpectationPropagation || atrial == 1)
|
||||||
|
Assert.True(bExpected.MaxDiff(bActual) < 1e-10);
|
||||||
|
|
||||||
|
if (atrial == 0 && algorithm is VariationalMessagePassing) continue;
|
||||||
|
|
||||||
|
for (int trial = 0; trial < 2; trial++)
|
||||||
|
{
|
||||||
|
if (trial == 0)
|
||||||
|
{
|
||||||
|
c.ObservedValue = true;
|
||||||
|
evExpected = System.Math.Log(probEqual * priorC) + evPrior;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
c.ObservedValue = false;
|
||||||
|
evExpected = System.Math.Log((1 - probEqual) * (1 - priorC)) + evPrior;
|
||||||
|
}
|
||||||
|
evActual = engine.Infer<Bernoulli>(evidence).LogOdds;
|
||||||
|
Console.WriteLine("evidence = {0} should be {1}", evActual, evExpected);
|
||||||
|
Assert.True(MMath.AbsDiff(evExpected, evActual, 1e-5) < 1e-5);
|
||||||
|
|
||||||
|
if (a.IsObserved)
|
||||||
|
{
|
||||||
|
bExpected = Discrete.PointMass(c.ObservedValue ? a.ObservedValue : 1 - a.ObservedValue, 2);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
postB[0] = priorB[0] * (c.ObservedValue ? priorA[0] : priorA[1]);
|
||||||
|
postB[1] = priorB[1] * (c.ObservedValue ? priorA[1] : priorA[0]);
|
||||||
|
postB.Scale(1.0 / postB.Sum());
|
||||||
|
bExpected = new Discrete(postB);
|
||||||
|
}
|
||||||
|
bActual = engine.Infer<Discrete>(b);
|
||||||
|
Console.WriteLine("b = {0} should be {1}", bActual, bExpected);
|
||||||
|
Assert.True(bExpected.MaxDiff(bActual) < 1e-10);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Загрузка…
Ссылка в новой задаче