DiscreteEnumAreEqualOp supports VMP (#387)

This commit is contained in:
Tom Minka 2022-01-08 18:47:02 +00:00 коммит произвёл GitHub
Родитель 079250ef67
Коммит b6dd9630f5
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
5 изменённых файлов: 221 добавлений и 72 удалений

Просмотреть файл

@ -26,7 +26,7 @@ namespace Microsoft.ML.Probabilistic.Distributions
}
/// <summary>
/// Creates a distribution over the enum values using the specified probs.
/// Creates a distribution over the enum values using the specified probabilities.
/// </summary>
public DiscreteEnum(params double[] probs) :
base(values.Length, Sparsity.Dense)
@ -34,6 +34,15 @@ namespace Microsoft.ML.Probabilistic.Distributions
disc.SetProbs(Vector.FromArray(probs));
}
/// <summary>
/// Creates a distribution over the enum values using the specified probabilities.
/// </summary>
public DiscreteEnum(Vector probs) :
base(values.Length, Sparsity.Dense)
{
disc.SetProbs(probs);
}
/// <summary>
/// Converts from an integer to an enum value
/// </summary>

Просмотреть файл

@ -233,6 +233,12 @@ namespace Microsoft.ML.Probabilistic.Factors
return (int)(object)en;
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteEnumAreEqualOp{TEnum}"]/message_doc[@name="AAverageConditional(Bernoulli, DiscreteEnum{TEnum}, DiscreteEnum{TEnum})"]/*'/>
public static DiscreteEnum<TEnum> AAverageConditional([SkipIfUniform] Bernoulli areEqual, DiscreteEnum<TEnum> B, DiscreteEnum<TEnum> result)
{
return DiscreteEnum<TEnum>.FromDiscrete(DiscreteAreEqualOp.AAverageConditional(areEqual, B.GetInternalDiscrete(), result.GetInternalDiscrete()));
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteEnumAreEqualOp{TEnum}"]/message_doc[@name="AAverageConditional(Bernoulli, TEnum, DiscreteEnum{TEnum})"]/*'/>
public static DiscreteEnum<TEnum> AAverageConditional([SkipIfUniform] Bernoulli areEqual, TEnum B, DiscreteEnum<TEnum> result)
{
@ -245,6 +251,12 @@ namespace Microsoft.ML.Probabilistic.Factors
return DiscreteEnum<TEnum>.FromDiscrete(DiscreteAreEqualOp.AAverageConditional(areEqual, ToInt(B), result.GetInternalDiscrete()));
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteEnumAreEqualOp{TEnum}"]/message_doc[@name="AAverageLogarithm(Bernoulli, DiscreteEnum{TEnum}, DiscreteEnum{TEnum})"]/*'/>
public static DiscreteEnum<TEnum> AAverageLogarithm([SkipIfUniform] Bernoulli areEqual, DiscreteEnum<TEnum> B, DiscreteEnum<TEnum> result)
{
return DiscreteEnum<TEnum>.FromDiscrete(DiscreteAreEqualOp.AAverageLogarithm(areEqual, B.GetInternalDiscrete(), result.GetInternalDiscrete()));
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteEnumAreEqualOp{TEnum}"]/message_doc[@name="AAverageLogarithm(bool, TEnum, DiscreteEnum{TEnum})"]/*'/>
public static DiscreteEnum<TEnum> AAverageLogarithm(bool areEqual, TEnum B, DiscreteEnum<TEnum> result)
{
@ -257,6 +269,12 @@ namespace Microsoft.ML.Probabilistic.Factors
return DiscreteEnum<TEnum>.FromDiscrete(DiscreteAreEqualOp.AAverageLogarithm(areEqual, ToInt(B), result.GetInternalDiscrete()));
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteEnumAreEqualOp{TEnum}"]/message_doc[@name="BAverageConditional(Bernoulli, DiscreteEnum{TEnum}, DiscreteEnum{TEnum})"]/*'/>
public static DiscreteEnum<TEnum> BAverageConditional([SkipIfUniform] Bernoulli areEqual, DiscreteEnum<TEnum> A, DiscreteEnum<TEnum> result)
{
return DiscreteEnum<TEnum>.FromDiscrete(DiscreteAreEqualOp.BAverageConditional(areEqual, A.GetInternalDiscrete(), result.GetInternalDiscrete()));
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteEnumAreEqualOp{TEnum}"]/message_doc[@name="BAverageConditional(Bernoulli, TEnum, DiscreteEnum{TEnum})"]/*'/>
public static DiscreteEnum<TEnum> BAverageConditional([SkipIfUniform] Bernoulli areEqual, TEnum A, DiscreteEnum<TEnum> result)
{
@ -269,6 +287,12 @@ namespace Microsoft.ML.Probabilistic.Factors
return DiscreteEnum<TEnum>.FromDiscrete(DiscreteAreEqualOp.BAverageConditional(areEqual, ToInt(A), result.GetInternalDiscrete()));
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteEnumAreEqualOp{TEnum}"]/message_doc[@name="BAverageLogarithm(Bernoulli, DiscreteEnum{TEnum}, DiscreteEnum{TEnum})"]/*'/>
public static DiscreteEnum<TEnum> BAverageLogarithm([SkipIfUniform] Bernoulli areEqual, DiscreteEnum<TEnum> A, DiscreteEnum<TEnum> result)
{
return DiscreteEnum<TEnum>.FromDiscrete(DiscreteAreEqualOp.BAverageLogarithm(areEqual, A.GetInternalDiscrete(), result.GetInternalDiscrete()));
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteEnumAreEqualOp{TEnum}"]/message_doc[@name="BAverageLogarithm(Bernoulli, TEnum, DiscreteEnum{TEnum})"]/*'/>
public static DiscreteEnum<TEnum> BAverageLogarithm([SkipIfUniform] Bernoulli areEqual, TEnum A, DiscreteEnum<TEnum> result)
{
@ -281,6 +305,12 @@ namespace Microsoft.ML.Probabilistic.Factors
return DiscreteEnum<TEnum>.FromDiscrete(DiscreteAreEqualOp.BAverageLogarithm(areEqual, ToInt(A), result.GetInternalDiscrete()));
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteEnumAreEqualOp{TEnum}"]/message_doc[@name="AreEqualAverageConditional(DiscreteEnum{TEnum}, DiscreteEnum{TEnum})"]/*'/>
public static Bernoulli AreEqualAverageConditional(DiscreteEnum<TEnum> A, DiscreteEnum<TEnum> B)
{
return DiscreteAreEqualOp.AreEqualAverageConditional(A.GetInternalDiscrete(), B.GetInternalDiscrete());
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteEnumAreEqualOp{TEnum}"]/message_doc[@name="AreEqualAverageConditional(TEnum, DiscreteEnum{TEnum})"]/*'/>
public static Bernoulli AreEqualAverageConditional(TEnum A, DiscreteEnum<TEnum> B)
{
@ -293,6 +323,12 @@ namespace Microsoft.ML.Probabilistic.Factors
return DiscreteAreEqualOp.AreEqualAverageConditional(A.GetInternalDiscrete(), ToInt(B));
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteEnumAreEqualOp{TEnum}"]/message_doc[@name="AreEqualAverageLogarithm(DiscreteEnum{TEnum}, DiscreteEnum{TEnum})"]/*'/>
public static Bernoulli AreEqualAverageLogarithm(DiscreteEnum<TEnum> A, DiscreteEnum<TEnum> B)
{
return DiscreteAreEqualOp.AreEqualAverageLogarithm(A.GetInternalDiscrete(), B.GetInternalDiscrete());
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteEnumAreEqualOp{TEnum}"]/message_doc[@name="AreEqualAverageLogarithm(TEnum, DiscreteEnum{TEnum})"]/*'/>
public static Bernoulli AreEqualAverageLogarithm(TEnum A, DiscreteEnum<TEnum> B)
{
@ -329,6 +365,12 @@ namespace Microsoft.ML.Probabilistic.Factors
return DiscreteAreEqualOp.LogAverageFactor(areEqual, A.GetInternalDiscrete(), ToInt(B));
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteEnumAreEqualOp{TEnum}"]/message_doc[@name="LogEvidenceRatio(bool, DiscreteEnum{TEnum}, DiscreteEnum{TEnum}, DiscreteEnum{TEnum})"]/*'/>
public static double LogEvidenceRatio(bool areEqual, DiscreteEnum<TEnum> A, DiscreteEnum<TEnum> B, [Fresh] DiscreteEnum<TEnum> to_A)
{
return DiscreteAreEqualOp.LogEvidenceRatio(areEqual, A.GetInternalDiscrete(), B.GetInternalDiscrete(), to_A.GetInternalDiscrete());
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteEnumAreEqualOp{TEnum}"]/message_doc[@name="LogEvidenceRatio(bool, TEnum, DiscreteEnum{TEnum}, DiscreteEnum{TEnum})"]/*'/>
public static double LogEvidenceRatio(bool areEqual, TEnum A, DiscreteEnum<TEnum> B, [Fresh] DiscreteEnum<TEnum> to_B)
{

Просмотреть файл

@ -101,7 +101,7 @@ namespace TestApp
//}
//TestUtils.CheckTransformNames();
}
bool showFactorManager = true;
bool showFactorManager = false;
if (showFactorManager)
{
InferenceEngine.ShowFactorManager(true);

Просмотреть файл

@ -7,6 +7,7 @@ using Xunit;
using Microsoft.ML.Probabilistic.Math;
using Microsoft.ML.Probabilistic.Models;
using Microsoft.ML.Probabilistic.Distributions;
using Microsoft.ML.Probabilistic.Algorithms;
using Range = Microsoft.ML.Probabilistic.Models.Range;
namespace Microsoft.ML.Probabilistic.Tests
@ -90,7 +91,6 @@ namespace Microsoft.ML.Probabilistic.Tests
Assert.True(expected.MaxDiff(actual) < 1e-10);
}
public enum Outcome
{
Good,
@ -133,5 +133,96 @@ namespace Microsoft.ML.Probabilistic.Tests
Console.WriteLine("Distribution over outcomes if control = "
+ DiscreteEnum<Outcome>.FromVector(ie.Infer<Dirichlet>(probIfControl).GetMean()));
}
[Fact]
public void GatedOutcomeAreEqualTest()
{
foreach (var algorithm in new Models.Attributes.IAlgorithm[] { new ExpectationPropagation(), new VariationalMessagePassing() })
{
Variable<bool> evidence = Variable.Bernoulli(0.5).Named("evidence");
IfBlock block = Variable.If(evidence);
Vector priorA = Vector.FromArray(0.1, 0.9);
Vector priorB = Vector.FromArray(0.2, 0.8);
Variable<Outcome> a = Variable.EnumDiscrete<Outcome>(priorA).Named("a");
Variable<Outcome> b = Variable.EnumDiscrete<Outcome>(priorB).Named("b");
Variable<bool> c = (a == b).Named("c");
double priorC = 0.3;
Variable.ConstrainEqualRandom(c, new Bernoulli(priorC));
block.CloseBlock();
InferenceEngine engine = new InferenceEngine(algorithm);
double probEqual = priorA.Inner(priorB);
double evPrior = 0;
for (int atrial = 0; atrial < 2; atrial++)
{
if (atrial == 1)
{
a.ObservedValue = Outcome.Bad;
probEqual = priorB[1];
c.ClearObservedValue();
evPrior = System.Math.Log(priorA[1]);
priorA[0] = 0.0;
priorA[1] = 1.0;
}
double evExpected = System.Math.Log(probEqual * priorC + (1 - probEqual) * (1 - priorC)) + evPrior;
double evActual = engine.Infer<Bernoulli>(evidence).LogOdds;
Console.WriteLine("evidence = {0} should be {1}", evActual, evExpected);
if (algorithm is ExpectationPropagation || atrial == 1)
Assert.True(MMath.AbsDiff(evExpected, evActual, 1e-5) < 1e-5);
Bernoulli cExpected = new Bernoulli(probEqual * priorC / (probEqual * priorC + (1 - probEqual) * (1 - priorC)));
Bernoulli cActual = engine.Infer<Bernoulli>(c);
Console.WriteLine("c = {0} should be {1}", cActual, cExpected);
if (algorithm is ExpectationPropagation || atrial == 1)
Assert.True(cExpected.MaxDiff(cActual) < 1e-10);
Vector postB = Vector.Zero(2);
postB[0] = priorB[0] * (priorA[0] * priorC + priorA[1] * (1 - priorC));
postB[1] = priorB[1] * (priorA[1] * priorC + priorA[0] * (1 - priorC));
postB.Scale(1.0 / postB.Sum());
DiscreteEnum<Outcome> bExpected = new DiscreteEnum<Outcome>(postB);
DiscreteEnum<Outcome> bActual = engine.Infer<DiscreteEnum<Outcome>>(b);
Console.WriteLine("b = {0} should be {1}", bActual, bExpected);
if (algorithm is ExpectationPropagation || atrial == 1)
Assert.True(bExpected.MaxDiff(bActual) < 1e-10);
if (atrial == 0 && algorithm is VariationalMessagePassing) continue;
for (int trial = 0; trial < 2; trial++)
{
if (trial == 0)
{
c.ObservedValue = true;
evExpected = System.Math.Log(probEqual * priorC) + evPrior;
}
else
{
c.ObservedValue = false;
evExpected = System.Math.Log((1 - probEqual) * (1 - priorC)) + evPrior;
}
evActual = engine.Infer<Bernoulli>(evidence).LogOdds;
Console.WriteLine("evidence = {0} should be {1}", evActual, evExpected);
Assert.True(MMath.AbsDiff(evExpected, evActual, 1e-5) < 1e-5);
if (a.IsObserved)
{
Outcome flip(Outcome x) => (x == Outcome.Good ? Outcome.Bad : Outcome.Good);
bExpected = DiscreteEnum<Outcome>.PointMass(c.ObservedValue ? a.ObservedValue : flip(a.ObservedValue));
}
else
{
postB[0] = priorB[0] * (c.ObservedValue ? priorA[0] : priorA[1]);
postB[1] = priorB[1] * (c.ObservedValue ? priorA[1] : priorA[0]);
postB.Scale(1.0 / postB.Sum());
bExpected = new DiscreteEnum<Outcome>(postB);
}
bActual = engine.Infer<DiscreteEnum<Outcome>>(b);
Console.WriteLine("b = {0} should be {1}", bActual, bExpected);
Assert.True(bExpected.MaxDiff(bActual) < 1e-10);
}
}
}
}
}
}

Просмотреть файл

@ -3784,82 +3784,89 @@ namespace Microsoft.ML.Probabilistic.Tests
[Fact]
public void GatedIntAreEqualTest()
{
Variable<bool> evidence = Variable.Bernoulli(0.5).Named("evidence");
IfBlock block = Variable.If(evidence);
Vector priorA = Vector.FromArray(0.1, 0.9);
Vector priorB = Vector.FromArray(0.2, 0.8);
Variable<int> a = Variable.Discrete(priorA).Named("a");
Variable<int> b = Variable.Discrete(priorB).Named("b");
Variable<bool> c = (a == b);
double priorC = 0.3;
Variable.ConstrainEqualRandom(c, new Bernoulli(priorC));
block.CloseBlock();
InferenceEngine engine = new InferenceEngine();
double evExpected, evActual;
double probEqual = priorA.Inner(priorB);
double evPrior = 0;
for (int atrial = 0; atrial < 2; atrial++)
foreach (var algorithm in new IAlgorithm[] { new ExpectationPropagation(), new VariationalMessagePassing() })
{
if (atrial == 1)
Variable<bool> evidence = Variable.Bernoulli(0.5).Named("evidence");
IfBlock block = Variable.If(evidence);
Vector priorA = Vector.FromArray(0.1, 0.9);
Vector priorB = Vector.FromArray(0.2, 0.8);
Variable<int> a = Variable.Discrete(priorA).Named("a");
Variable<int> b = Variable.Discrete(priorB).Named("b");
Variable<bool> c = (a == b).Named("c");
double priorC = 0.3;
Variable.ConstrainEqualRandom(c, new Bernoulli(priorC));
block.CloseBlock();
InferenceEngine engine = new InferenceEngine(algorithm);
double probEqual = priorA.Inner(priorB);
double evPrior = 0;
for (int atrial = 0; atrial < 2; atrial++)
{
a.ObservedValue = 1;
probEqual = priorB[1];
c.ClearObservedValue();
evPrior = System.Math.Log(priorA[1]);
priorA[0] = 0.0;
priorA[1] = 1.0;
}
evExpected = System.Math.Log(probEqual * priorC + (1 - probEqual) * (1 - priorC)) + evPrior;
evActual = engine.Infer<Bernoulli>(evidence).LogOdds;
Console.WriteLine("evidence = {0} should be {1}", evActual, evExpected);
Assert.True(MMath.AbsDiff(evExpected, evActual, 1e-5) < 1e-5);
Bernoulli cExpected = new Bernoulli(probEqual * priorC / (probEqual * priorC + (1 - probEqual) * (1 - priorC)));
Bernoulli cActual = engine.Infer<Bernoulli>(c);
Console.WriteLine("c = {0} should be {1}", cActual, cExpected);
Assert.True(cExpected.MaxDiff(cActual) < 1e-10);
Vector postB = Vector.Zero(2);
postB[0] = priorB[0] * (priorA[0] * priorC + priorA[1] * (1 - priorC));
postB[1] = priorB[1] * (priorA[1] * priorC + priorA[0] * (1 - priorC));
postB.Scale(1.0 / postB.Sum());
Discrete bExpected = new Discrete(postB);
Discrete bActual = engine.Infer<Discrete>(b);
Console.WriteLine("b = {0} should be {1}", bActual, bExpected);
Assert.True(bExpected.MaxDiff(bActual) < 1e-10);
for (int trial = 0; trial < 2; trial++)
{
if (trial == 0)
if (atrial == 1)
{
c.ObservedValue = true;
evExpected = System.Math.Log(probEqual * priorC) + evPrior;
a.ObservedValue = 1;
probEqual = priorB[1];
c.ClearObservedValue();
evPrior = System.Math.Log(priorA[1]);
priorA[0] = 0.0;
priorA[1] = 1.0;
}
else
{
c.ObservedValue = false;
evExpected = System.Math.Log((1 - probEqual) * (1 - priorC)) + evPrior;
}
evActual = engine.Infer<Bernoulli>(evidence).LogOdds;
double evExpected = System.Math.Log(probEqual * priorC + (1 - probEqual) * (1 - priorC)) + evPrior;
double evActual = engine.Infer<Bernoulli>(evidence).LogOdds;
Console.WriteLine("evidence = {0} should be {1}", evActual, evExpected);
Assert.True(MMath.AbsDiff(evExpected, evActual, 1e-5) < 1e-5);
if (algorithm is ExpectationPropagation || atrial == 1)
Assert.True(MMath.AbsDiff(evExpected, evActual, 1e-5) < 1e-5);
if (a.IsObserved)
{
bExpected = Discrete.PointMass(c.ObservedValue ? a.ObservedValue : 1 - a.ObservedValue, 2);
}
else
{
postB[0] = priorB[0] * (c.ObservedValue ? priorA[0] : priorA[1]);
postB[1] = priorB[1] * (c.ObservedValue ? priorA[1] : priorA[0]);
postB.Scale(1.0 / postB.Sum());
bExpected = new Discrete(postB);
}
bActual = engine.Infer<Discrete>(b);
Bernoulli cExpected = new Bernoulli(probEqual * priorC / (probEqual * priorC + (1 - probEqual) * (1 - priorC)));
Bernoulli cActual = engine.Infer<Bernoulli>(c);
Console.WriteLine("c = {0} should be {1}", cActual, cExpected);
if (algorithm is ExpectationPropagation || atrial == 1)
Assert.True(cExpected.MaxDiff(cActual) < 1e-10);
Vector postB = Vector.Zero(2);
postB[0] = priorB[0] * (priorA[0] * priorC + priorA[1] * (1 - priorC));
postB[1] = priorB[1] * (priorA[1] * priorC + priorA[0] * (1 - priorC));
postB.Scale(1.0 / postB.Sum());
Discrete bExpected = new Discrete(postB);
Discrete bActual = engine.Infer<Discrete>(b);
Console.WriteLine("b = {0} should be {1}", bActual, bExpected);
Assert.True(bExpected.MaxDiff(bActual) < 1e-10);
if (algorithm is ExpectationPropagation || atrial == 1)
Assert.True(bExpected.MaxDiff(bActual) < 1e-10);
if (atrial == 0 && algorithm is VariationalMessagePassing) continue;
for (int trial = 0; trial < 2; trial++)
{
if (trial == 0)
{
c.ObservedValue = true;
evExpected = System.Math.Log(probEqual * priorC) + evPrior;
}
else
{
c.ObservedValue = false;
evExpected = System.Math.Log((1 - probEqual) * (1 - priorC)) + evPrior;
}
evActual = engine.Infer<Bernoulli>(evidence).LogOdds;
Console.WriteLine("evidence = {0} should be {1}", evActual, evExpected);
Assert.True(MMath.AbsDiff(evExpected, evActual, 1e-5) < 1e-5);
if (a.IsObserved)
{
bExpected = Discrete.PointMass(c.ObservedValue ? a.ObservedValue : 1 - a.ObservedValue, 2);
}
else
{
postB[0] = priorB[0] * (c.ObservedValue ? priorA[0] : priorA[1]);
postB[1] = priorB[1] * (c.ObservedValue ? priorA[1] : priorA[0]);
postB.Scale(1.0 / postB.Sum());
bExpected = new Discrete(postB);
}
bActual = engine.Infer<Discrete>(b);
Console.WriteLine("b = {0} should be {1}", bActual, bExpected);
Assert.True(bExpected.MaxDiff(bActual) < 1e-10);
}
}
}
}