зеркало из https://github.com/dotnet/infer.git
Observed Beta variables support a Gamma-distributed pseudocount if the other pseudocount is always 1 (#386)
* ShowFactorManager shows more patterns * Tidy up code
This commit is contained in:
Родитель
fc93e12851
Коммит
079250ef67
|
@ -257,9 +257,8 @@ namespace Microsoft.ML.Probabilistic.Compiler.Attributes
|
|||
for (int depth = 0; depth < arrayDepth; depth++)
|
||||
{
|
||||
bool notLast = (depth < arrayDepth - 1);
|
||||
int rank;
|
||||
Type arrayType = sourceArray.GetExpressionType();
|
||||
Util.GetElementType(arrayType, out rank);
|
||||
Util.GetElementType(arrayType, out int rank);
|
||||
if (sizes.Count <= depth) sizes.Add(new IExpression[rank]);
|
||||
IExpression[] indices = new IExpression[rank];
|
||||
for (int i = 0; i < rank; i++)
|
||||
|
@ -822,8 +821,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Attributes
|
|||
if (char.IsDigit(lastChar))
|
||||
prefix += "_";
|
||||
}
|
||||
int count;
|
||||
counts.TryGetValue(prefix, out count);
|
||||
counts.TryGetValue(prefix, out int count);
|
||||
if (count == 0) count = 1;
|
||||
counts[prefix] = count + 1;
|
||||
if (count == 1) return prefix;
|
||||
|
|
|
@ -553,7 +553,6 @@ namespace Microsoft.ML.Probabilistic.Compiler
|
|||
}
|
||||
|
||||
// Model methods with up to ten parameters are supported directly.
|
||||
#pragma warning disable 1591
|
||||
/// <exclude/>
|
||||
public delegate void ModelDefinitionMethod();
|
||||
|
||||
|
@ -716,7 +715,6 @@ namespace Microsoft.ML.Probabilistic.Compiler
|
|||
{
|
||||
return CompileWithoutParams(method.Method);
|
||||
}
|
||||
#pragma warning restore 1591
|
||||
|
||||
/// <summary>
|
||||
/// Compiles the model defined in MSL by the specified method. The model parameters are set
|
||||
|
|
|
@ -4648,7 +4648,7 @@ namespace Microsoft.ML.Probabilistic.Models
|
|||
/// </remarks>
|
||||
public static Variable<string> StringCapitalized(Variable<int> minLength, Variable<int> maxLength = null)
|
||||
{
|
||||
return ReferenceEquals(maxLength, null)
|
||||
return maxLength is null
|
||||
? Variable<string>.Factor(Factor.StringCapitalized, minLength)
|
||||
: Variable<string>.Factor(Factor.StringCapitalized, minLength, maxLength);
|
||||
}
|
||||
|
@ -4711,7 +4711,7 @@ namespace Microsoft.ML.Probabilistic.Models
|
|||
/// </remarks>
|
||||
public static Variable<string> String(Variable<int> minLength, Variable<int> maxLength, Variable<DiscreteChar> allowedCharacters)
|
||||
{
|
||||
return ReferenceEquals(maxLength, null)
|
||||
return maxLength is null
|
||||
? Variable<string>.Factor(Factor.String, minLength, allowedCharacters)
|
||||
: Variable<string>.Factor(Factor.String, minLength, maxLength, allowedCharacters);
|
||||
}
|
||||
|
@ -5222,7 +5222,6 @@ namespace Microsoft.ML.Probabilistic.Models
|
|||
/// <summary>
|
||||
/// Enumeration over supported operators.
|
||||
/// </summary>
|
||||
#pragma warning disable 1591
|
||||
public enum Operator
|
||||
{
|
||||
Plus,
|
||||
|
@ -5245,7 +5244,6 @@ namespace Microsoft.ML.Probabilistic.Models
|
|||
GreaterThanOrEqual,
|
||||
LessThanOrEqual
|
||||
};
|
||||
#pragma warning restore 1591
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -6482,11 +6480,11 @@ namespace Microsoft.ML.Probabilistic.Models
|
|||
{
|
||||
diff = null;
|
||||
}
|
||||
if ((object)diff == null)
|
||||
if (diff is null)
|
||||
{
|
||||
diff = OperatorFactor<T>(Operator.Minus, a, b);
|
||||
}
|
||||
if ((object)diff != null)
|
||||
if (diff is object)
|
||||
{
|
||||
return IsPositive((Variable<double>)(Variable)diff);
|
||||
}
|
||||
|
@ -6497,7 +6495,7 @@ namespace Microsoft.ML.Probabilistic.Models
|
|||
|
||||
private static Variable<bool> NotOrNull(Variable<bool> Variable)
|
||||
{
|
||||
return ((object)Variable == null) ? null : !Variable;
|
||||
return (Variable is null) ? null : !Variable;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -6683,7 +6681,7 @@ namespace Microsoft.ML.Probabilistic.Models
|
|||
public static Variable<T> operator -(Variable<T> a)
|
||||
{
|
||||
Variable<T> f = OperatorFactor<T>(Operator.Negative, a);
|
||||
if ((object)f != null) return f;
|
||||
if (f is object) return f;
|
||||
else if (a is Variable<double>)
|
||||
{
|
||||
Variable<T> zero = (Variable<T>)(object)Constant(0.0);
|
||||
|
|
|
@ -19,10 +19,6 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
using Microsoft.ML.Probabilistic.Compiler.CodeModel;
|
||||
using Microsoft.ML.Probabilistic.Algorithms;
|
||||
|
||||
#if SUPPRESS_XMLDOC_WARNINGS
|
||||
#pragma warning disable 1591
|
||||
#endif
|
||||
|
||||
internal class FactorManager
|
||||
{
|
||||
/// <summary>
|
||||
|
@ -476,8 +472,10 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
if (method.IsDefined(typeof(MultiplyAllAttribute), true))
|
||||
{
|
||||
// MultiplyAll implies SkipIfAllUniform
|
||||
var list = new List<object>(attrs);
|
||||
list.Add(new SkipIfAllUniformAttribute());
|
||||
var list = new List<object>(attrs)
|
||||
{
|
||||
new SkipIfAllUniformAttribute()
|
||||
};
|
||||
attrs = list.ToArray();
|
||||
}
|
||||
foreach (SkipIfAllUniformAttribute attr in attrs)
|
||||
|
@ -617,7 +615,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
/// </summary>
|
||||
public bool IsVoid
|
||||
{
|
||||
get { return (Method.ReturnType == typeof(void)); }
|
||||
get { return Method.ReturnType == typeof(void); }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -1481,9 +1479,11 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
protected IDictionary<string, string> GetNameMapping(FactorMethodAttribute attr)
|
||||
{
|
||||
string[] newParameterNames = attr.NewParameterNames;
|
||||
Dictionary<string, string> map = new Dictionary<string, string>();
|
||||
// always include the empty string
|
||||
map[string.Empty] = string.Empty;
|
||||
Dictionary<string, string> map = new Dictionary<string, string>
|
||||
{
|
||||
// always include the empty string
|
||||
[string.Empty] = string.Empty
|
||||
};
|
||||
if (newParameterNames == null)
|
||||
{
|
||||
// map the fields to themselves
|
||||
|
@ -1892,8 +1892,4 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
return "i";
|
||||
}
|
||||
}
|
||||
|
||||
#if SUPPRESS_XMLDOC_WARNINGS
|
||||
#pragma warning restore 1591
|
||||
#endif
|
||||
}
|
|
@ -136,9 +136,8 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
{
|
||||
foreach (var container in containers.inputs)
|
||||
{
|
||||
if (container is IForStatement)
|
||||
if (container is IForStatement ifs)
|
||||
{
|
||||
IForStatement ifs = (IForStatement)container;
|
||||
IVariableDeclaration loopVar = Recognizer.LoopVariable(ifs);
|
||||
if (context.InputAttributes.Has<Partitioned>(loopVar)) return true;
|
||||
}
|
||||
|
@ -345,14 +344,14 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
/// <summary>
|
||||
/// Holds the ContainerInfo for each 'for' loop on the inputStack as they are being converted.
|
||||
/// </summary>
|
||||
private Stack<ContainerInfo> containerInfos = new Stack<ContainerInfo>();
|
||||
private readonly Stack<ContainerInfo> containerInfos = new Stack<ContainerInfo>();
|
||||
/// <summary>
|
||||
/// Maps an IForStatement to its LocalInfos
|
||||
/// </summary>
|
||||
internal Dictionary<IStatement, Dictionary<IExpression, LocalInfo>> localInfoOfStmt = new Dictionary<IStatement, Dictionary<IExpression, LocalInfo>>(ReferenceEqualityComparer<IStatement>.Instance);
|
||||
private Stack<IStatement> openContainers = new Stack<IStatement>();
|
||||
private readonly Stack<IStatement> openContainers = new Stack<IStatement>();
|
||||
private bool InPartitionedLoop;
|
||||
ModelCompiler compiler;
|
||||
readonly ModelCompiler compiler;
|
||||
|
||||
internal LocalAnalysisTransform(ModelCompiler compiler)
|
||||
{
|
||||
|
@ -435,8 +434,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
else
|
||||
minExpr = GetCommonPrefix(minExpr, expr);
|
||||
}
|
||||
LocalInfo minInfo;
|
||||
localInfos.TryGetValue(minExpr, out minInfo);
|
||||
localInfos.TryGetValue(minExpr, out LocalInfo minInfo);
|
||||
foreach (var entry in localInfos)
|
||||
{
|
||||
var expr = entry.Key;
|
||||
|
@ -518,9 +516,9 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
{
|
||||
var loopSize = Recognizer.LoopSizeExpression(closedContainer);
|
||||
bool loopMustExecute = false;
|
||||
if (loopSize is ILiteralExpression)
|
||||
if (loopSize is ILiteralExpression ile)
|
||||
{
|
||||
int loopSizeAsInt = (int)((ILiteralExpression)loopSize).Value;
|
||||
int loopSizeAsInt = (int)ile.Value;
|
||||
if (loopSizeAsInt > 0)
|
||||
{
|
||||
loopMustExecute = true;
|
||||
|
@ -566,9 +564,8 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
{
|
||||
string containerString = StringUtil.ToString(containers.Select(c =>
|
||||
{
|
||||
if (c is IForStatement)
|
||||
if (c is IForStatement ifs)
|
||||
{
|
||||
IForStatement ifs = (IForStatement)c;
|
||||
if (c is IBrokenForStatement)
|
||||
return ifs.Initializer.ToString() + " // broken";
|
||||
else
|
||||
|
@ -584,9 +581,9 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
|
||||
private static IExpression GetContainerExpression(IStatement container)
|
||||
{
|
||||
if (container is IForStatement) return Recognizer.LoopSizeExpression((IForStatement)container);
|
||||
else if (container is IConditionStatement) return ((IConditionStatement)container).Condition;
|
||||
else if (container is IRepeatStatement) return ((IRepeatStatement)container).Count;
|
||||
if (container is IForStatement ifs) return Recognizer.LoopSizeExpression(ifs);
|
||||
else if (container is IConditionStatement ics) return ics.Condition;
|
||||
else if (container is IRepeatStatement irs) return irs.Count;
|
||||
else throw new ArgumentException($"unrecognized container type: {container.GetType()}");
|
||||
}
|
||||
|
||||
|
@ -741,9 +738,9 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
AddUsage(prefix, false);
|
||||
}
|
||||
}
|
||||
if (prefix.Target is IArrayIndexerExpression)
|
||||
if (prefix.Target is IArrayIndexerExpression target)
|
||||
{
|
||||
prefix = (IArrayIndexerExpression)prefix.Target;
|
||||
prefix = target;
|
||||
isPrefix = true;
|
||||
}
|
||||
else
|
||||
|
@ -763,9 +760,8 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
if (expr is IArrayIndexerExpression)
|
||||
{
|
||||
IExpression expr2 = expr;
|
||||
while (expr2 is IArrayIndexerExpression)
|
||||
while (expr2 is IArrayIndexerExpression iaie)
|
||||
{
|
||||
IArrayIndexerExpression iaie = (IArrayIndexerExpression)expr2;
|
||||
bool hasExcludedLoopVar = iaie.Indices.Any(index => Recognizer.GetVariables(index).Any(excludedLoopVar.Equals));
|
||||
if (hasExcludedLoopVar)
|
||||
return GetPrefixInParent(iaie.Target, excludedLoopVar);
|
||||
|
|
|
@ -39,8 +39,8 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
loopStart = Recognizer.LoopStartExpression(ifs);
|
||||
loopSize = Recognizer.LoopSizeExpression(ifs);
|
||||
}
|
||||
else if(ifs.Condition is IBinaryExpression) {
|
||||
IBinaryExpression ibe = (IBinaryExpression)ifs.Condition;
|
||||
else if (ifs.Condition is IBinaryExpression ibe)
|
||||
{
|
||||
if (ibe.Operator == BinaryOperator.GreaterThanOrEqual)
|
||||
{
|
||||
// loop is "for(int i = end; i >= start; i--)"
|
||||
|
|
|
@ -14,7 +14,6 @@ using Microsoft.ML.Probabilistic.Factors.Attributes;
|
|||
using Microsoft.ML.Probabilistic.Utilities;
|
||||
using Microsoft.ML.Probabilistic.Collections;
|
||||
using Microsoft.ML.Probabilistic.Compiler.Transforms;
|
||||
using Microsoft.ML.Probabilistic.Compiler;
|
||||
using Microsoft.ML.Probabilistic.Compiler.Reflection;
|
||||
|
||||
namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
|
||||
|
@ -25,7 +24,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
|
|||
/// </summary>
|
||||
internal class DefaultFactorManager
|
||||
{
|
||||
private FactorManager factorManager = new FactorManager();
|
||||
private readonly FactorManager factorManager = new FactorManager();
|
||||
|
||||
public IAlgorithm[] algs = { new VariationalMessagePassing(), new ExpectationPropagation() };
|
||||
|
||||
|
@ -48,21 +47,21 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
|
|||
|
||||
private string GenerateNewFactorTable()
|
||||
{
|
||||
List<Tuple<string, Compiler.Transforms.FactorManager.FactorInfo>> factorList = new List<Tuple<string, Compiler.Transforms.FactorManager.FactorInfo>>();
|
||||
IEnumerable<Compiler.Transforms.FactorManager.FactorInfo> factorInfos = Compiler.Transforms.FactorManager.GetFactorInfos();
|
||||
List<Tuple<string, FactorManager.FactorInfo>> factorList = new List<Tuple<string, FactorManager.FactorInfo>>();
|
||||
IEnumerable<FactorManager.FactorInfo> factorInfos = FactorManager.GetFactorInfos();
|
||||
|
||||
foreach (Compiler.Transforms.FactorManager.FactorInfo info in factorInfos)
|
||||
foreach (FactorManager.FactorInfo info in factorInfos)
|
||||
{
|
||||
MethodInfo method = info.Method;
|
||||
// omit obsolete, unsupported, and hidden factors
|
||||
if (HiddenAttribute.IsDefined(method) ||
|
||||
HiddenAttribute.IsDefined(method.DeclaringType) ||
|
||||
Attribute.IsDefined(method, typeof(System.ObsoleteAttribute))) continue;
|
||||
//if (method.Name != "Logistic") continue;
|
||||
Attribute.IsDefined(method, typeof(ObsoleteAttribute))) continue;
|
||||
string itemName = StringUtil.MethodFullNameToString(method);
|
||||
factorList.Add(new Tuple<string, Compiler.Transforms.FactorManager.FactorInfo>(itemName, info));
|
||||
////if (itemName != "EnumSupport.AreEqual<TEnum>") continue;
|
||||
factorList.Add(new Tuple<string, FactorManager.FactorInfo>(itemName, info));
|
||||
}
|
||||
factorList.Sort((elem1, elem2) => elem1.Item1.CompareTo(elem2.Item1));
|
||||
factorList.Sort((elem1, elem2) => elem1.Item2.ToString().CompareTo(elem2.Item2.ToString()));
|
||||
|
||||
string html = $@"<!DOCTYPE html>
|
||||
<html>
|
||||
|
@ -88,7 +87,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
|
|||
return headers.ToString();
|
||||
}
|
||||
|
||||
private string GetFactorHtml(string name, Compiler.Transforms.FactorManager.FactorInfo info)
|
||||
private string GetFactorHtml(string name, FactorManager.FactorInfo info)
|
||||
{
|
||||
Console.WriteLine($"Scanning {info}");
|
||||
var currentFactor = new StringBuilder();
|
||||
|
@ -104,13 +103,12 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
|
|||
args.Append("</span>");
|
||||
|
||||
string boldString = info.IsDeterministicFactor ? "" : " text-bold";
|
||||
currentFactor.Append($@"<div class=""box-left{boldString}"">{StringUtil.EscapeXmlCharacters(name)}</div><div class=""box-left"">{args.ToString()}</div>");
|
||||
currentFactor.Append($@"<div class=""box-left{boldString}"">{StringUtil.EscapeXmlCharacters(name)}</div><div class=""box-left"">{args}</div>");
|
||||
|
||||
foreach (IAlgorithm alg in this.algs)
|
||||
{
|
||||
QualityBand minQB, modeQB, maxQB;
|
||||
ICollection<StochasticityPattern> patterns =
|
||||
GetAlgorithmPatterns(alg, info, ShowMissingEvidences, out minQB, out modeQB, out maxQB);
|
||||
GetAlgorithmPatterns(alg, info, ShowMissingEvidences, out QualityBand minQB, out QualityBand modeQB, out QualityBand maxQB);
|
||||
if (patterns.Count == 0)
|
||||
{
|
||||
// not implemented
|
||||
|
@ -142,7 +140,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
|
|||
if (sp.Partial) partialCount++;
|
||||
if (sp.evidenceFound) evidenceCount++;
|
||||
if (sb.Length > 0) sb.Append(",</span> ");
|
||||
sb.Append($"<span>{sp.ToString()}");
|
||||
sb.Append($"<span>{sp}");
|
||||
}
|
||||
sb.Append("</span>");
|
||||
if (notSupportedCount > 0)
|
||||
|
@ -216,7 +214,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
|
|||
.container {{
|
||||
display: grid;
|
||||
width: 100%;
|
||||
grid-template-columns: {gridTemplate.ToString()};
|
||||
grid-template-columns: {gridTemplate};
|
||||
max-width: 1700px;
|
||||
box-sizing: border-box;
|
||||
border-left: 1px solid #aaa;
|
||||
|
@ -310,9 +308,13 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
|
|||
{
|
||||
try
|
||||
{
|
||||
var p = new Process();
|
||||
p.StartInfo = new ProcessStartInfo(filename);
|
||||
p.StartInfo.UseShellExecute = true;
|
||||
var p = new Process
|
||||
{
|
||||
StartInfo = new ProcessStartInfo(filename)
|
||||
{
|
||||
UseShellExecute = true
|
||||
}
|
||||
};
|
||||
p.Start();
|
||||
}
|
||||
catch (Exception ex)
|
||||
|
@ -331,9 +333,10 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
|
|||
/// <param name="modeQB">The most common quality band attached to operators for this factor</param>
|
||||
/// <param name="maxQB">The maximum quality band attached to operators for this factor</param>
|
||||
/// <returns></returns>
|
||||
private ICollection<StochasticityPattern> GetAlgorithmPatterns(IAlgorithm alg, Compiler.Transforms.FactorManager.FactorInfo info, bool ShowMissingEvidences,
|
||||
private ICollection<StochasticityPattern> GetAlgorithmPatterns(IAlgorithm alg, FactorManager.FactorInfo info, bool ShowMissingEvidences,
|
||||
out QualityBand minQB, out QualityBand modeQB, out QualityBand maxQB)
|
||||
{
|
||||
bool verbose = false;
|
||||
ICollection<StochasticityPattern> patterns = new Set<StochasticityPattern>();
|
||||
string suffix = alg.GetOperatorMethodSuffix(new List<ICompilerAttribute>());
|
||||
string evidenceMethodName = alg.GetEvidenceMethodName(new List<ICompilerAttribute>());
|
||||
|
@ -372,13 +375,12 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
|
|||
// for stochastic arguments.
|
||||
foreach (MessageFcnInfo mfi in mfis)
|
||||
{
|
||||
//if(info.Method.ContainsGenericParameters) Console.WriteLine();
|
||||
// Console.WriteLine("mfi=" + mfi.Method);
|
||||
if (verbose) Trace.WriteLine(mfi.Method);
|
||||
StochasticityPattern sp = new StochasticityPattern(info);
|
||||
sp.notSupported = mfi.NotSupportedMessage;
|
||||
Dictionary<string, Type> parameterTypes = new Dictionary<string, Type>();
|
||||
foreach (KeyValuePair<string, Type> kvp in mfi.GetParameterTypes()) parameterTypes[kvp.Key] = kvp.Value;
|
||||
string target = (string)mfi.TargetParameter;
|
||||
string target = mfi.TargetParameter;
|
||||
if (!parameterTypes.ContainsKey(target))
|
||||
{
|
||||
if (!mfi.PassResultIndex)
|
||||
|
@ -403,17 +405,17 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
|
|||
if (info.Method.IsGenericMethodDefinition)
|
||||
{
|
||||
// fill in the factor's type parameters from the type arguments of the operator class.
|
||||
IDictionary<string, Type> typeArgs = Compiler.Transforms.FactorManager.FactorInfo.GetTypeArguments(mfi.Method.DeclaringType);
|
||||
IDictionary<string, Type> typeArgs = FactorManager.FactorInfo.GetTypeArguments(mfi.Method.DeclaringType);
|
||||
try
|
||||
{
|
||||
MethodInfo newMethod = Compiler.Transforms.FactorManager.FactorInfo.MakeGenericMethod(info.Method, typeArgs);
|
||||
sp.info = Compiler.Transforms.FactorManager.GetFactorInfo(newMethod);
|
||||
MethodInfo newMethod = FactorManager.FactorInfo.MakeGenericMethod(info.Method, typeArgs);
|
||||
sp.info = FactorManager.GetFactorInfo(newMethod);
|
||||
}
|
||||
catch (Exception)
|
||||
{
|
||||
sp.info = Compiler.Transforms.FactorManager.GetFactorInfo(info.Method);
|
||||
//Console.WriteLine("Could not infer generic type parameters of "+StringUtil.MethodFullNameToString(info.Method));
|
||||
//continue;
|
||||
sp.info = FactorManager.GetFactorInfo(info.Method);
|
||||
if (verbose)
|
||||
Trace.WriteLine("Could not infer generic type parameters of " + StringUtil.MethodFullNameToString(info.Method));
|
||||
}
|
||||
// from now on, sp.info != info
|
||||
}
|
||||
|
@ -423,7 +425,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
|
|||
ArgInfo ai = new ArgInfo();
|
||||
if (!parameterTypes.ContainsKey(field))
|
||||
{
|
||||
//Console.WriteLine("not found: " + field + " in " + mfi.Method);
|
||||
if (verbose) Trace.WriteLine("not found: " + field + " in " + mfi.Method);
|
||||
continue;
|
||||
}
|
||||
ai.factorType = sp.info.ParameterTypes[field];
|
||||
|
@ -431,7 +433,12 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
|
|||
sp.argInfos[field] = ai;
|
||||
}
|
||||
|
||||
if (!sp.IsValid()) continue;
|
||||
if (verbose) Trace.WriteLine(sp);
|
||||
if (!sp.IsValid())
|
||||
{
|
||||
if (verbose) Trace.WriteLine("Invalid stochasticity pattern");
|
||||
continue;
|
||||
}
|
||||
List<StochasticityPattern> toRemove = new List<StochasticityPattern>();
|
||||
foreach (StochasticityPattern sp2 in patterns)
|
||||
{
|
||||
|
@ -441,7 +448,6 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
|
|||
StochasticityPattern sp3 = sp.Intersect(sp2);
|
||||
if (sp3 == null)
|
||||
{
|
||||
StochasticityPattern sp4 = sp.Intersect(sp2);
|
||||
throw new Exception("intersection is null");
|
||||
}
|
||||
sp = sp3;
|
||||
|
@ -457,30 +463,35 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
|
|||
int qbCnt = -1;
|
||||
modeQB = QualityBand.Unknown;
|
||||
foreach (KeyValuePair<QualityBand, int> kvp in qbCounts)
|
||||
{
|
||||
if (kvp.Value > qbCnt)
|
||||
{
|
||||
qbCnt = kvp.Value;
|
||||
modeQB = kvp.Key;
|
||||
}
|
||||
}
|
||||
|
||||
modeQB = minQB;
|
||||
patterns = IntersectPatterns(patterns);
|
||||
patterns = GetCompletePatterns(patterns, suffix);
|
||||
patterns = AddDeterministicPatterns(patterns);
|
||||
VerifyPatterns(patterns, suffix, evidenceMethodName, ShowMissingEvidences);
|
||||
StochasticityPattern bestp = GetBestPattern(patterns);
|
||||
patterns = new Set<StochasticityPattern>();
|
||||
if (bestp != null)
|
||||
{
|
||||
patterns.Add(bestp);
|
||||
((Set<StochasticityPattern>)patterns).AddRange(bestp.deterministicPatterns);
|
||||
}
|
||||
return patterns;
|
||||
return RemoveDuplicatePatterns(patterns);
|
||||
//return GetBestPatternPlusDeterministic(patterns);
|
||||
}
|
||||
|
||||
#if SUPPRESS_UNREACHABLE_CODE_WARNINGS
|
||||
#pragma warning disable 162
|
||||
#endif
|
||||
private static List<StochasticityPattern> RemoveDuplicatePatterns(IEnumerable<StochasticityPattern> patterns)
|
||||
{
|
||||
var result = new List<StochasticityPattern>();
|
||||
foreach (var pattern in patterns.OrderByDescending(sp => sp.foundCount))
|
||||
{
|
||||
if (!result.Any(sp => sp.IsSamePattern(pattern)))
|
||||
{
|
||||
result.Add(pattern);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Create a new set of patterns containing the closure of all pairwise intersections of patterns.
|
||||
|
@ -502,10 +513,6 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
|
|||
return result;
|
||||
}
|
||||
|
||||
#if SUPPRESS_UNREACHABLE_CODE_WARNINGS
|
||||
#pragma warning restore 162
|
||||
#endif
|
||||
|
||||
/// <summary>
|
||||
/// Modify the collection by adding all pairwise intersections of patterns.
|
||||
/// </summary>
|
||||
|
@ -526,7 +533,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
|
|||
}
|
||||
}
|
||||
|
||||
private static StochasticityPattern GetStochasticityPattern(Compiler.Transforms.FactorManager.FactorInfo info, MessageFcnInfo mfi)
|
||||
private static StochasticityPattern GetStochasticityPattern(FactorManager.FactorInfo info, MessageFcnInfo mfi)
|
||||
{
|
||||
StochasticityPattern sp = new StochasticityPattern(info);
|
||||
sp.notSupported = mfi.NotSupportedMessage;
|
||||
|
@ -548,15 +555,15 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
|
|||
if (info.Method.IsGenericMethodDefinition)
|
||||
{
|
||||
// fill in the factor's type parameters from the type arguments of the operator class.
|
||||
IDictionary<string, Type> typeArgs = Compiler.Transforms.FactorManager.FactorInfo.GetTypeArguments(mfi.Method.DeclaringType);
|
||||
IDictionary<string, Type> typeArgs = FactorManager.FactorInfo.GetTypeArguments(mfi.Method.DeclaringType);
|
||||
try
|
||||
{
|
||||
MethodInfo newMethod = Compiler.Transforms.FactorManager.FactorInfo.MakeGenericMethod(info.Method, typeArgs);
|
||||
sp.info = Compiler.Transforms.FactorManager.GetFactorInfo(newMethod);
|
||||
MethodInfo newMethod = FactorManager.FactorInfo.MakeGenericMethod(info.Method, typeArgs);
|
||||
sp.info = FactorManager.GetFactorInfo(newMethod);
|
||||
}
|
||||
catch (Exception)
|
||||
{
|
||||
sp.info = Compiler.Transforms.FactorManager.GetFactorInfo(info.Method);
|
||||
sp.info = FactorManager.GetFactorInfo(info.Method);
|
||||
}
|
||||
// from now on, sp.info != info
|
||||
}
|
||||
|
@ -576,6 +583,17 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
|
|||
return sp;
|
||||
}
|
||||
|
||||
private static ICollection<StochasticityPattern> GetBestPatternPlusDeterministic(ICollection<StochasticityPattern> patterns)
|
||||
{
|
||||
StochasticityPattern bestp = GetBestPattern(patterns);
|
||||
if (bestp == null) return patterns;
|
||||
var result = new Set<StochasticityPattern>();
|
||||
result.Add(bestp);
|
||||
result.AddRange(bestp.deterministicPatterns);
|
||||
return result;
|
||||
}
|
||||
|
||||
// Returns null if patterns is empty
|
||||
private static StochasticityPattern GetBestPattern(IEnumerable<StochasticityPattern> patterns)
|
||||
{
|
||||
StochasticityPattern bestp = null;
|
||||
|
@ -597,7 +615,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
|
|||
return bestp;
|
||||
}
|
||||
|
||||
private static ICollection<StochasticityPattern> GetCompletePatterns(IEnumerable<StochasticityPattern> patterns, string suffix)
|
||||
private static Set<StochasticityPattern> GetCompletePatterns(IEnumerable<StochasticityPattern> patterns, string suffix)
|
||||
{
|
||||
Set<StochasticityPattern> pendingPatterns;
|
||||
Set<StochasticityPattern> completePatterns = new Set<StochasticityPattern>();
|
||||
|
@ -689,9 +707,11 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
|
|||
if (entry2.Key == entry.Key)
|
||||
{
|
||||
// replace type with deterministic type.
|
||||
ArgInfo arg2 = new ArgInfo();
|
||||
arg2.factorType = arg.factorType;
|
||||
arg2.opType = arg.factorType;
|
||||
ArgInfo arg2 = new ArgInfo
|
||||
{
|
||||
factorType = arg.factorType,
|
||||
opType = arg.factorType
|
||||
};
|
||||
sp2.argInfos[entry2.Key] = arg2;
|
||||
}
|
||||
else
|
||||
|
@ -854,7 +874,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
|
|||
/// </summary>
|
||||
internal class StochasticityPattern
|
||||
{
|
||||
internal Compiler.Transforms.FactorManager.FactorInfo info;
|
||||
internal FactorManager.FactorInfo info;
|
||||
internal Dictionary<string, ArgInfo> argInfos = new Dictionary<string, ArgInfo>();
|
||||
public Set<StochasticityPattern> deterministicPatterns = new Set<StochasticityPattern>();
|
||||
|
||||
|
@ -866,7 +886,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
|
|||
|
||||
public bool IsComplete
|
||||
{
|
||||
get { return (argInfos.Count >= nonConstantCount); }
|
||||
get { return argInfos.Count >= nonConstantCount; }
|
||||
}
|
||||
|
||||
public bool Partial
|
||||
|
@ -874,7 +894,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
|
|||
get { return foundCount < neededCount; }
|
||||
}
|
||||
|
||||
public StochasticityPattern(Compiler.Transforms.FactorManager.FactorInfo info)
|
||||
public StochasticityPattern(FactorManager.FactorInfo info)
|
||||
{
|
||||
this.info = info;
|
||||
nonConstantCount = 0;
|
||||
|
@ -947,8 +967,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
|
|||
|
||||
public override bool Equals(object o)
|
||||
{
|
||||
StochasticityPattern sp = o as StochasticityPattern;
|
||||
if (sp == null) return false;
|
||||
if (!(o is StochasticityPattern sp)) return false;
|
||||
if (sp.info != info) return false;
|
||||
foreach (string field in info.ParameterNames)
|
||||
{
|
||||
|
@ -977,16 +996,16 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
|
|||
return hash;
|
||||
}
|
||||
|
||||
internal bool IsSamePattern(StochasticityPattern sp)
|
||||
public bool IsSamePattern(StochasticityPattern sp)
|
||||
{
|
||||
if (sp.info != info) return false;
|
||||
foreach (string field in info.ParameterNames)
|
||||
{
|
||||
if (!argInfos.ContainsKey(field) || !sp.argInfos.ContainsKey(field)) continue;
|
||||
if (argInfos[field].IsStoch != sp.argInfos[field].IsStoch) return false;
|
||||
// Must compare types because a pattern involving Gaussians is not equivalent
|
||||
// to one using Gammas.
|
||||
//if (argInfos[field].IsStoch != sp.argInfos[field].IsStoch) return false;
|
||||
if (argInfos[field].opType != sp.argInfos[field].opType) return false;
|
||||
//if (argInfos[field].opType != sp.argInfos[field].opType) return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
@ -1102,8 +1121,8 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
|
|||
if (result.notSupported == null) result.notSupported = that.notSupported;
|
||||
foreach (string field in info.ParameterNames)
|
||||
{
|
||||
ArgInfo arg1, arg2;
|
||||
that.argInfos.TryGetValue(field, out arg2);
|
||||
ArgInfo arg1;
|
||||
that.argInfos.TryGetValue(field, out ArgInfo arg2);
|
||||
if (!argInfos.TryGetValue(field, out arg1))
|
||||
{
|
||||
if (arg2 == null)
|
||||
|
@ -1142,13 +1161,15 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
|
|||
/// <returns></returns>
|
||||
internal Type IntersectTypes(Type t1, Type t2)
|
||||
{
|
||||
Set<Type> types = new Set<Type>();
|
||||
types.Add(t1);
|
||||
types.Add(t2);
|
||||
Set<Type> types = new Set<Type>
|
||||
{
|
||||
t1,
|
||||
t2
|
||||
};
|
||||
Type t;
|
||||
if (!intersectionCache.TryGetValue(types, out t))
|
||||
{
|
||||
t = Microsoft.ML.Probabilistic.Compiler.Reflection.Binding.IntersectTypes(t1, t2);
|
||||
t = Binding.IntersectTypes(t1, t2);
|
||||
if (t != null) intersectionCache[types] = t;
|
||||
}
|
||||
return t;
|
||||
|
|
|
@ -138,7 +138,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
|
|||
// http://www.lovettsoftware.com/blogengine.net/post/2010/07/16/DGML-with-Style.aspx
|
||||
var settings = new XmlWriterSettings();
|
||||
settings.Indent = true;
|
||||
Func<Color, string> colorToString = c => c.ToString().Trim('"');
|
||||
////string colorToString(Color c) => c.ToString().Trim('"');
|
||||
using (var writer = XmlWriter.Create(path, settings))
|
||||
{
|
||||
writer.WriteStartElement("DirectedGraph", "http://schemas.microsoft.com/vs/2009/dgml");
|
||||
|
@ -166,7 +166,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
|
|||
writer.WriteAttributeString("NodeRadius", "100");
|
||||
}
|
||||
// The Outline attribute seems to be ignored.
|
||||
//writer.WriteAttributeString("Outline", colorToString(node.Attr.Color));
|
||||
////writer.WriteAttributeString("Outline", colorToString(node.Attr.Color));
|
||||
writer.WriteEndElement();
|
||||
}
|
||||
writer.WriteEndElement();
|
||||
|
|
|
@ -12,10 +12,6 @@ using Microsoft.ML.Probabilistic.Collections;
|
|||
|
||||
namespace Microsoft.ML.Probabilistic.Utilities
|
||||
{
|
||||
#if SUPPRESS_XMLDOC_WARNINGS
|
||||
#pragma warning disable 1591
|
||||
#endif
|
||||
|
||||
/// <summary>
|
||||
/// Helpful methods for converting objects to strings.
|
||||
/// </summary>
|
||||
|
@ -804,6 +800,5 @@ namespace Microsoft.ML.Probabilistic.Utilities
|
|||
}
|
||||
|
||||
#if SUPPRESS_XMLDOC_WARNINGS
|
||||
#pragma warning restore 1591
|
||||
#endif
|
||||
}
|
|
@ -260,7 +260,6 @@ namespace Microsoft.ML.Probabilistic.Factors
|
|||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteAreEqualOp"]/doc/*'/>
|
||||
[FactorMethod(typeof(Factor), "AreEqual", typeof(int), typeof(int))]
|
||||
[FactorMethod(typeof(EnumSupport), "AreEqual<>")]
|
||||
[Quality(QualityBand.Mature)]
|
||||
public static class DiscreteAreEqualOp
|
||||
{
|
||||
|
|
|
@ -0,0 +1,189 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
namespace Microsoft.ML.Probabilistic.Factors
|
||||
{
|
||||
using System;
|
||||
|
||||
using Microsoft.ML.Probabilistic.Distributions;
|
||||
using Microsoft.ML.Probabilistic.Math;
|
||||
using Microsoft.ML.Probabilistic.Factors.Attributes;
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndTotalCountOp"]/doc/*'/>
|
||||
[FactorMethod(typeof(Factor), "BetaFromMeanAndTotalCount")]
|
||||
[Quality(QualityBand.Experimental)]
|
||||
public static class BetaFromMeanAndTotalCountOp
|
||||
{
|
||||
/// <summary>
|
||||
/// How much damping to use to avoid improper messages. A higher value implies more damping.
|
||||
/// </summary>
|
||||
public static double damping = 0.0;
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndTotalCountOp"]/message_doc[@name="AverageLogFactor(double, double, double)"]/*'/>
|
||||
public static double AverageLogFactor(double prob, double mean, double totalCount)
|
||||
{
|
||||
return LogAverageFactor(prob, mean, totalCount);
|
||||
}
|
||||
|
||||
// TODO: VMP evidence messages for stochastic inputs (see DirichletOp)
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndTotalCountOp"]/message_doc[@name="ProbAverageLogarithm(Beta, Gamma)"]/*'/>
|
||||
public static Beta ProbAverageLogarithm(Beta mean, [Proper] Gamma totalCount)
|
||||
{
|
||||
double meanMean = mean.GetMean();
|
||||
double totalCountMean = totalCount.GetMean();
|
||||
return new Beta(meanMean * totalCountMean, (1 - meanMean) * totalCountMean);
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndTotalCountOp"]/message_doc[@name="ProbAverageLogarithm(Beta, double)"]/*'/>
|
||||
public static Beta ProbAverageLogarithm(Beta mean, double totalCount)
|
||||
{
|
||||
double meanMean = mean.GetMean();
|
||||
return new Beta(meanMean * totalCount, (1 - meanMean) * totalCount);
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndTotalCountOp"]/message_doc[@name="ProbAverageLogarithm(double, Gamma)"]/*'/>
|
||||
public static Beta ProbAverageLogarithm(double mean, [Proper] Gamma totalCount)
|
||||
{
|
||||
double totalCountMean = totalCount.GetMean();
|
||||
return new Beta(mean * totalCountMean, (1 - mean) * totalCountMean);
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndTotalCountOp"]/message_doc[@name="MeanAverageLogarithm(double, Beta, Gamma, Beta)"]/*'/>
|
||||
public static Beta MeanAverageLogarithm(double prob, Beta mean, [Proper] Gamma totalCount, Beta to_mean)
|
||||
{
|
||||
return MeanAverageLogarithm(Beta.PointMass(prob), mean, totalCount, to_mean);
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndTotalCountOp"]/message_doc[@name="MeanAverageLogarithm(double, Beta, double, Beta)"]/*'/>
|
||||
public static Beta MeanAverageLogarithm(double prob, Beta mean, double totalCount, Beta to_mean)
|
||||
{
|
||||
return MeanAverageLogarithm(Beta.PointMass(prob), mean, Gamma.PointMass(totalCount), to_mean);
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndTotalCountOp"]/message_doc[@name="MeanAverageLogarithm(Beta, Beta, Gamma, Beta)"]/*'/>
|
||||
public static Beta MeanAverageLogarithm([Proper] Beta prob, Beta mean, [Proper] Gamma totalCount, Beta to_mean)
|
||||
{
|
||||
// Calculate gradient using method for DirichletOp
|
||||
prob.GetMeanLogs(out double ELogP, out double ELogOneMinusP);
|
||||
Vector gradS = DirichletOp.CalculateGradientForMean(
|
||||
Vector.FromArray(new double[] { mean.TrueCount, mean.FalseCount }),
|
||||
totalCount,
|
||||
Vector.FromArray(new double[] { ELogP, ELogOneMinusP }));
|
||||
// Project onto a Beta distribution
|
||||
Matrix A = new Matrix(2, 2);
|
||||
double c = MMath.Trigamma(mean.TotalCount);
|
||||
A[0, 0] = MMath.Trigamma(mean.TrueCount) - c;
|
||||
A[1, 0] = A[0, 1] = -c;
|
||||
A[1, 1] = MMath.Trigamma(mean.FalseCount) - c;
|
||||
Vector theta = GammaFromShapeAndRateOp.twoByTwoInverse(A) * gradS;
|
||||
Beta approximateFactor = new Beta(theta[0] + 1, theta[1] + 1);
|
||||
if (damping == 0.0)
|
||||
return approximateFactor;
|
||||
else
|
||||
return (approximateFactor ^ (1 - damping)) * (to_mean ^ damping);
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndTotalCountOp"]/message_doc[@name="TotalCountAverageLogarithm(Beta, Gamma, Beta, Gamma)"]/*'/>
|
||||
public static Gamma TotalCountAverageLogarithm([Proper] Beta mean, Gamma totalCount, [SkipIfUniform] Beta prob, Gamma to_totalCount)
|
||||
{
|
||||
prob.GetMeanLogs(out double ELogP, out double ELogOneMinusP);
|
||||
Gamma approximateFactor = DirichletOp.TotalCountAverageLogarithmHelper(
|
||||
Vector.FromArray(new double[] { mean.TrueCount, mean.FalseCount }),
|
||||
totalCount,
|
||||
Vector.FromArray(new double[] { ELogP, ELogOneMinusP }));
|
||||
if (damping == 0.0)
|
||||
return approximateFactor;
|
||||
else
|
||||
return (approximateFactor ^ (1 - damping)) * (to_totalCount ^ damping);
|
||||
}
|
||||
|
||||
//---------------------------- EP -----------------------------
|
||||
|
||||
private const string NotSupportedMessage = "Expectation Propagation does not currently support beta distributions with stochastic arguments.";
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndTotalCountOp"]/message_doc[@name="LogAverageFactor(double, double, double)"]/*'/>
|
||||
public static double LogAverageFactor(double prob, double mean, double totalCount)
|
||||
{
|
||||
var g = new Beta(mean * totalCount, (1 - mean) * totalCount);
|
||||
return g.GetLogProb(prob);
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndTotalCountOp"]/message_doc[@name="LogAverageFactor(Beta, double, double)"]/*'/>
|
||||
public static double LogAverageFactor(Beta prob, double mean, double totalCount)
|
||||
{
|
||||
var g = new Beta(mean * totalCount, (1 - mean) * totalCount);
|
||||
return g.GetLogAverageOf(prob);
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndTotalCountOp"]/message_doc[@name="ProbAverageConditional(double, double)"]/*'/>
|
||||
public static Beta ProbAverageConditional(double mean, double totalCount)
|
||||
{
|
||||
return new Beta(mean * totalCount, (1 - mean) * totalCount);
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndTotalCountOp"]/message_doc[@name="ProbAverageConditional(Beta, Gamma)"]/*'/>
|
||||
[NotSupported(NotSupportedMessage)]
|
||||
public static Beta ProbAverageConditional([SkipIfUniform] Beta mean, [SkipIfUniform] Gamma totalCount)
|
||||
{
|
||||
throw new NotSupportedException(NotSupportedMessage);
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndTotalCountOp"]/message_doc[@name="ProbAverageConditional(Beta, double)"]/*'/>
|
||||
[NotSupported(NotSupportedMessage)]
|
||||
public static Beta ProbAverageConditional([SkipIfUniform] Beta mean, double totalCount)
|
||||
{
|
||||
throw new NotSupportedException(NotSupportedMessage);
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndTotalCountOp"]/message_doc[@name="ProbAverageConditional(double, Gamma)"]/*'/>
|
||||
[NotSupported(NotSupportedMessage)]
|
||||
public static Beta ProbAverageConditional(double mean, [SkipIfUniform] Gamma totalCount)
|
||||
{
|
||||
throw new NotSupportedException(NotSupportedMessage);
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndTotalCountOp"]/message_doc[@name="MeanAverageConditional(Beta, Gamma, double, Beta)"]/*'/>
|
||||
[NotSupported(NotSupportedMessage)]
|
||||
public static Beta MeanAverageConditional([SkipIfUniform] Beta mean, [SkipIfUniform] Gamma totalCount, double prob, [SkipIfUniform] Beta result)
|
||||
{
|
||||
throw new NotSupportedException(NotSupportedMessage);
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndTotalCountOp"]/message_doc[@name="MeanAverageConditional(Beta, double, double, Beta)"]/*'/>
|
||||
[NotSupported(NotSupportedMessage)]
|
||||
public static Beta MeanAverageConditional([SkipIfUniform] Beta mean, double totalCount, double prob, [SkipIfUniform] Beta result)
|
||||
{
|
||||
throw new NotSupportedException(NotSupportedMessage);
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndTotalCountOp"]/message_doc[@name="MeanAverageConditional(Beta, double, Beta, Beta)"]/*'/>
|
||||
[NotSupported(NotSupportedMessage)]
|
||||
public static Beta MeanAverageConditional([SkipIfUniform] Beta mean, double totalCount, [SkipIfUniform] Beta prob, [SkipIfUniform] Beta result)
|
||||
{
|
||||
throw new NotSupportedException(NotSupportedMessage);
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndTotalCountOp"]/message_doc[@name="MeanAverageConditional(Beta, Gamma, Beta, Beta)"]/*'/>
|
||||
[NotSupported(NotSupportedMessage)]
|
||||
public static Beta MeanAverageConditional([SkipIfUniform] Beta mean, [SkipIfUniform] Gamma totalCount, [SkipIfUniform] Beta prob, [SkipIfUniform] Beta result)
|
||||
{
|
||||
throw new NotSupportedException(NotSupportedMessage);
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndTotalCountOp"]/message_doc[@name="TotalCountAverageConditional(Beta, Gamma, double, Gamma)"]/*'/>
|
||||
[NotSupported(NotSupportedMessage)]
|
||||
public static Gamma TotalCountAverageConditional([SkipIfUniform] Beta mean, [SkipIfUniform] Gamma totalCount, double prob, [SkipIfUniform] Gamma result)
|
||||
{
|
||||
throw new NotSupportedException(NotSupportedMessage);
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndTotalCountOp"]/message_doc[@name="TotalCountAverageConditional(Beta, Gamma, Beta, Gamma)"]/*'/>
|
||||
[NotSupported(NotSupportedMessage)]
|
||||
public static Gamma TotalCountAverageConditional([SkipIfUniform] Beta mean, [SkipIfUniform] Gamma totalCount, [SkipIfUniform] Beta prob, [SkipIfUniform] Gamma result)
|
||||
{
|
||||
throw new NotSupportedException(NotSupportedMessage);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,65 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
namespace Microsoft.ML.Probabilistic.Factors
|
||||
{
|
||||
|
||||
using Microsoft.ML.Probabilistic.Distributions;
|
||||
using Microsoft.ML.Probabilistic.Factors.Attributes;
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndVarianceOp"]/doc/*'/>
|
||||
[FactorMethod(new string[] { "sample", "mean", "variance" }, typeof(Beta), "SampleFromMeanAndVariance")]
|
||||
[Quality(QualityBand.Stable)]
|
||||
public static class BetaFromMeanAndVarianceOp
|
||||
{
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndVarianceOp"]/message_doc[@name="LogAverageFactor(Beta, double, double, Beta)"]/*'/>
|
||||
public static double LogAverageFactor(Beta sample, double mean, double variance, [Fresh] Beta to_sample)
|
||||
{
|
||||
return to_sample.GetLogAverageOf(sample);
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndVarianceOp"]/message_doc[@name="LogEvidenceRatio(Beta, double, double)"]/*'/>
|
||||
[Skip]
|
||||
public static double LogEvidenceRatio(Beta sample, double mean, double variance)
|
||||
{
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndVarianceOp"]/message_doc[@name="AverageLogFactor(Beta, double, double, Beta)"]/*'/>
|
||||
public static double AverageLogFactor(Beta sample, double mean, double variance, [Fresh] Beta to_sample)
|
||||
{
|
||||
return to_sample.GetAverageLog(sample);
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndVarianceOp"]/message_doc[@name="LogAverageFactor(double, double, double)"]/*'/>
|
||||
public static double LogAverageFactor(double sample, double mean, double variance)
|
||||
{
|
||||
return SampleAverageConditional(mean, variance).GetLogProb(sample);
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndVarianceOp"]/message_doc[@name="LogEvidenceRatio(double, double, double)"]/*'/>
|
||||
public static double LogEvidenceRatio(double sample, double mean, double variance)
|
||||
{
|
||||
return LogAverageFactor(sample, mean, variance);
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndVarianceOp"]/message_doc[@name="AverageLogFactor(double, double, double)"]/*'/>
|
||||
public static double AverageLogFactor(double sample, double mean, double variance)
|
||||
{
|
||||
return LogAverageFactor(sample, mean, variance);
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndVarianceOp"]/message_doc[@name="SampleAverageLogarithm(double, double)"]/*'/>
|
||||
public static Beta SampleAverageLogarithm(double mean, double variance)
|
||||
{
|
||||
return Beta.FromMeanAndVariance(mean, variance);
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndVarianceOp"]/message_doc[@name="SampleAverageConditional(double, double)"]/*'/>
|
||||
public static Beta SampleAverageConditional(double mean, double variance)
|
||||
{
|
||||
return Beta.FromMeanAndVariance(mean, variance);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,301 +1,183 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
namespace Microsoft.ML.Probabilistic.Factors
|
||||
{
|
||||
using System;
|
||||
|
||||
using Microsoft.ML.Probabilistic.Distributions;
|
||||
using Microsoft.ML.Probabilistic.Math;
|
||||
using Microsoft.ML.Probabilistic.Factors.Attributes;
|
||||
|
||||
using System;
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/doc/*'/>
|
||||
[FactorMethod(typeof(Factor), "BetaFromMeanAndTotalCount")]
|
||||
[Quality(QualityBand.Experimental)]
|
||||
public static class BetaOp
|
||||
{
|
||||
/// <summary>
|
||||
/// How much damping to use to avoid improper messages. A higher value implies more damping.
|
||||
/// </summary>
|
||||
public static double damping = 0.0;
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="AverageLogFactor(double, double, double)"]/*'/>
|
||||
public static double AverageLogFactor(double prob, double mean, double totalCount)
|
||||
{
|
||||
return LogAverageFactor(prob, mean, totalCount);
|
||||
}
|
||||
|
||||
// TODO: VMP evidence messages for stochastic inputs (see DirichletOp)
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="ProbAverageLogarithm(Beta, Gamma)"]/*'/>
|
||||
public static Beta ProbAverageLogarithm(Beta mean, [Proper] Gamma totalCount)
|
||||
{
|
||||
double meanMean = mean.GetMean();
|
||||
double totalCountMean = totalCount.GetMean();
|
||||
return (new Beta(meanMean * totalCountMean, (1 - meanMean) * totalCountMean));
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="ProbAverageLogarithm(Beta, double)"]/*'/>
|
||||
public static Beta ProbAverageLogarithm(Beta mean, double totalCount)
|
||||
{
|
||||
double meanMean = mean.GetMean();
|
||||
return (new Beta(meanMean * totalCount, (1 - meanMean) * totalCount));
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="ProbAverageLogarithm(double, Gamma)"]/*'/>
|
||||
public static Beta ProbAverageLogarithm(double mean, [Proper] Gamma totalCount)
|
||||
{
|
||||
double totalCountMean = totalCount.GetMean();
|
||||
return (new Beta(mean * totalCountMean, (1 - mean) * totalCountMean));
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="MeanAverageLogarithm(double, Beta, Gamma, Beta)"]/*'/>
|
||||
public static Beta MeanAverageLogarithm(double prob, Beta mean, [Proper] Gamma totalCount, Beta to_mean)
|
||||
{
|
||||
return MeanAverageLogarithm(Beta.PointMass(prob), mean, totalCount, to_mean);
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="MeanAverageLogarithm(double, Beta, double, Beta)"]/*'/>
|
||||
public static Beta MeanAverageLogarithm(double prob, Beta mean, double totalCount, Beta to_mean)
|
||||
{
|
||||
return MeanAverageLogarithm(Beta.PointMass(prob), mean, Gamma.PointMass(totalCount), to_mean);
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="MeanAverageLogarithm(Beta, Beta, Gamma, Beta)"]/*'/>
|
||||
public static Beta MeanAverageLogarithm([Proper] Beta prob, Beta mean, [Proper] Gamma totalCount, Beta to_mean)
|
||||
{
|
||||
// Calculate gradient using method for DirichletOp
|
||||
double ELogP, ELogOneMinusP;
|
||||
prob.GetMeanLogs(out ELogP, out ELogOneMinusP);
|
||||
Vector gradS = DirichletOp.CalculateGradientForMean(
|
||||
Vector.FromArray(new double[] { mean.TrueCount, mean.FalseCount }),
|
||||
totalCount,
|
||||
Vector.FromArray(new double[] { ELogP, ELogOneMinusP }));
|
||||
// Project onto a Beta distribution
|
||||
Matrix A = new Matrix(2, 2);
|
||||
double c = MMath.Trigamma(mean.TotalCount);
|
||||
A[0, 0] = MMath.Trigamma(mean.TrueCount) - c;
|
||||
A[1, 0] = A[0, 1] = -c;
|
||||
A[1, 1] = MMath.Trigamma(mean.FalseCount) - c;
|
||||
Vector theta = GammaFromShapeAndRateOp.twoByTwoInverse(A) * gradS;
|
||||
Beta approximateFactor = new Beta(theta[0] + 1, theta[1] + 1);
|
||||
if (damping == 0.0)
|
||||
return approximateFactor;
|
||||
else
|
||||
return (approximateFactor ^ (1 - damping)) * (to_mean ^ damping);
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="TotalCountAverageLogarithm(Beta, Gamma, Beta, Gamma)"]/*'/>
|
||||
public static Gamma TotalCountAverageLogarithm([Proper] Beta mean, Gamma totalCount, [SkipIfUniform] Beta prob, Gamma to_totalCount)
|
||||
{
|
||||
double ELogP, ELogOneMinusP;
|
||||
prob.GetMeanLogs(out ELogP, out ELogOneMinusP);
|
||||
Gamma approximateFactor = DirichletOp.TotalCountAverageLogarithmHelper(
|
||||
Vector.FromArray(new double[] { mean.TrueCount, mean.FalseCount }),
|
||||
totalCount,
|
||||
Vector.FromArray(new double[] { ELogP, ELogOneMinusP }));
|
||||
if (damping == 0.0)
|
||||
return approximateFactor;
|
||||
else
|
||||
return (approximateFactor ^ (1 - damping)) * (to_totalCount ^ damping);
|
||||
}
|
||||
|
||||
//---------------------------- EP -----------------------------
|
||||
|
||||
private const string NotSupportedMessage = "Expectation Propagation does not currently support beta distributions with stochastic arguments.";
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="LogAverageFactor(double, double, double)"]/*'/>
|
||||
public static double LogAverageFactor(double prob, double mean, double totalCount)
|
||||
{
|
||||
var g = new Beta(mean * totalCount, (1 - mean) * totalCount);
|
||||
return g.GetLogProb(prob);
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="LogAverageFactor(Beta, double, double)"]/*'/>
|
||||
public static double LogAverageFactor(Beta prob, double mean, double totalCount)
|
||||
{
|
||||
var g = new Beta(mean * totalCount, (1 - mean) * totalCount);
|
||||
return g.GetLogAverageOf(prob);
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="ProbAverageConditional(double, double)"]/*'/>
|
||||
public static Beta ProbAverageConditional(double mean, double totalCount)
|
||||
{
|
||||
return new Beta(mean * totalCount, (1 - mean) * totalCount);
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="ProbAverageConditional(Beta, Gamma)"]/*'/>
|
||||
[NotSupported(NotSupportedMessage)]
|
||||
public static Beta ProbAverageConditional([SkipIfUniform] Beta mean, [SkipIfUniform] Gamma totalCount)
|
||||
{
|
||||
throw new NotSupportedException(NotSupportedMessage);
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="ProbAverageConditional(Beta, double)"]/*'/>
|
||||
[NotSupported(NotSupportedMessage)]
|
||||
public static Beta ProbAverageConditional([SkipIfUniform] Beta mean, double totalCount)
|
||||
{
|
||||
throw new NotSupportedException(NotSupportedMessage);
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="ProbAverageConditional(double, Gamma)"]/*'/>
|
||||
[NotSupported(NotSupportedMessage)]
|
||||
public static Beta ProbAverageConditional(double mean, [SkipIfUniform] Gamma totalCount)
|
||||
{
|
||||
throw new NotSupportedException(NotSupportedMessage);
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="MeanAverageConditional(Beta, Gamma, double, Beta)"]/*'/>
|
||||
[NotSupported(NotSupportedMessage)]
|
||||
public static Beta MeanAverageConditional([SkipIfUniform] Beta mean, [SkipIfUniform] Gamma totalCount, double prob, [SkipIfUniform] Beta result)
|
||||
{
|
||||
throw new NotSupportedException(NotSupportedMessage);
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="MeanAverageConditional(Beta, double, double, Beta)"]/*'/>
|
||||
[NotSupported(NotSupportedMessage)]
|
||||
public static Beta MeanAverageConditional([SkipIfUniform] Beta mean, double totalCount, double prob, [SkipIfUniform] Beta result)
|
||||
{
|
||||
throw new NotSupportedException(NotSupportedMessage);
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="MeanAverageConditional(Beta, double, Beta, Beta)"]/*'/>
|
||||
[NotSupported(NotSupportedMessage)]
|
||||
public static Beta MeanAverageConditional([SkipIfUniform] Beta mean, double totalCount, [SkipIfUniform] Beta prob, [SkipIfUniform] Beta result)
|
||||
{
|
||||
throw new NotSupportedException(NotSupportedMessage);
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="MeanAverageConditional(Beta, Gamma, Beta, Beta)"]/*'/>
|
||||
[NotSupported(NotSupportedMessage)]
|
||||
public static Beta MeanAverageConditional([SkipIfUniform] Beta mean, [SkipIfUniform] Gamma totalCount, [SkipIfUniform] Beta prob, [SkipIfUniform] Beta result)
|
||||
{
|
||||
throw new NotSupportedException(NotSupportedMessage);
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="TotalCountAverageConditional(Beta, Gamma, double, Gamma)"]/*'/>
|
||||
[NotSupported(NotSupportedMessage)]
|
||||
public static Gamma TotalCountAverageConditional([SkipIfUniform] Beta mean, [SkipIfUniform] Gamma totalCount, double prob, [SkipIfUniform] Gamma result)
|
||||
{
|
||||
throw new NotSupportedException(NotSupportedMessage);
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="TotalCountAverageConditional(Beta, Gamma, Beta, Gamma)"]/*'/>
|
||||
[NotSupported(NotSupportedMessage)]
|
||||
public static Gamma TotalCountAverageConditional([SkipIfUniform] Beta mean, [SkipIfUniform] Gamma totalCount, [SkipIfUniform] Beta prob, [SkipIfUniform] Gamma result)
|
||||
{
|
||||
throw new NotSupportedException(NotSupportedMessage);
|
||||
}
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromTrueAndFalseCountsOp"]/doc/*'/>
|
||||
[FactorMethod(typeof(Beta), "Sample", typeof(double), typeof(double))]
|
||||
[Quality(QualityBand.Stable)]
|
||||
public static class BetaFromTrueAndFalseCountsOp
|
||||
public static class BetaOp
|
||||
{
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromTrueAndFalseCountsOp"]/message_doc[@name="LogAverageFactor(Beta, double, double, Beta)"]/*'/>
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="LogAverageFactor(Beta, double, double, Beta)"]/*'/>
|
||||
public static double LogAverageFactor(Beta sample, double trueCount, double falseCount, [Fresh] Beta to_sample)
|
||||
{
|
||||
return to_sample.GetLogAverageOf(sample);
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromTrueAndFalseCountsOp"]/message_doc[@name="LogEvidenceRatio(Beta, double, double)"]/*'/>
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="LogEvidenceRatio(Beta, double, double)"]/*'/>
|
||||
[Skip]
|
||||
public static double LogEvidenceRatio(Beta sample, double trueCount, double falseCount)
|
||||
{
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromTrueAndFalseCountsOp"]/message_doc[@name="AverageLogFactor(Beta, double, double, Beta)"]/*'/>
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="AverageLogFactor(Beta, double, double, Beta)"]/*'/>
|
||||
public static double AverageLogFactor(Beta sample, double trueCount, double falseCount, [Fresh] Beta to_sample)
|
||||
{
|
||||
return to_sample.GetAverageLog(sample);
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromTrueAndFalseCountsOp"]/message_doc[@name="LogAverageFactor(double, double, double)"]/*'/>
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="LogAverageFactor(double, double, double)"]/*'/>
|
||||
public static double LogAverageFactor(double sample, double trueCount, double falseCount)
|
||||
{
|
||||
return SampleAverageConditional(trueCount, falseCount).GetLogProb(sample);
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromTrueAndFalseCountsOp"]/message_doc[@name="LogEvidenceRatio(double, double, double)"]/*'/>
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="LogEvidenceRatio(double, double, double)"]/*'/>
|
||||
public static double LogEvidenceRatio(double sample, double trueCount, double falseCount)
|
||||
{
|
||||
return LogAverageFactor(sample, trueCount, falseCount);
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromTrueAndFalseCountsOp"]/message_doc[@name="AverageLogFactor(double, double, double)"]/*'/>
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="AverageLogFactor(double, double, double)"]/*'/>
|
||||
public static double AverageLogFactor(double sample, double trueCount, double falseCount)
|
||||
{
|
||||
return LogAverageFactor(sample, trueCount, falseCount);
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromTrueAndFalseCountsOp"]/message_doc[@name="SampleAverageLogarithm(double, double)"]/*'/>
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="SampleAverageLogarithm(double, double)"]/*'/>
|
||||
public static Beta SampleAverageLogarithm(double trueCount, double falseCount)
|
||||
{
|
||||
return new Beta(trueCount, falseCount);
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromTrueAndFalseCountsOp"]/message_doc[@name="SampleAverageConditional(double, double)"]/*'/>
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="SampleAverageConditional(double, double)"]/*'/>
|
||||
public static Beta SampleAverageConditional(double trueCount, double falseCount)
|
||||
{
|
||||
return new Beta(trueCount, falseCount);
|
||||
}
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndVarianceOp"]/doc/*'/>
|
||||
[FactorMethod(new string[] { "sample", "mean", "variance" }, typeof(Beta), "SampleFromMeanAndVariance")]
|
||||
[Quality(QualityBand.Stable)]
|
||||
public static class BetaFromMeanAndVarianceOp
|
||||
{
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndVarianceOp"]/message_doc[@name="LogAverageFactor(Beta, double, double, Beta)"]/*'/>
|
||||
public static double LogAverageFactor(Beta sample, double mean, double variance, [Fresh] Beta to_sample)
|
||||
const string TrueCountMustBeOneMessage = "falseCount is Gamma and trueCount is not 1";
|
||||
const string FalseCountMustBeOneMessage = "trueCount is Gamma and falseCount is not 1";
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="AverageLogFactor(double, Gamma, double)"]/*'/>
|
||||
public static double AverageLogFactor(double sample, Gamma trueCount, double falseCount)
|
||||
{
|
||||
return to_sample.GetLogAverageOf(sample);
|
||||
if (trueCount.IsPointMass)
|
||||
{
|
||||
return LogEvidenceRatio(sample, trueCount.Point, falseCount);
|
||||
}
|
||||
else if (falseCount == 1)
|
||||
{
|
||||
// The factor is f(x, a) = a x^(a-1)
|
||||
// whose logarithm is log(a) + (a-1)*log(x)
|
||||
return trueCount.GetMeanLog() + (trueCount.GetMean() - 1) * Math.Log(sample);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new NotSupportedException(FalseCountMustBeOneMessage);
|
||||
}
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndVarianceOp"]/message_doc[@name="LogEvidenceRatio(Beta, double, double)"]/*'/>
|
||||
[Skip]
|
||||
public static double LogEvidenceRatio(Beta sample, double mean, double variance)
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="LogEvidenceRatio(double, Gamma, double)"]/*'/>
|
||||
public static double LogEvidenceRatio(double sample, Gamma trueCount, double falseCount)
|
||||
{
|
||||
return 0.0;
|
||||
if (trueCount.IsPointMass)
|
||||
{
|
||||
return LogEvidenceRatio(sample, trueCount.Point, falseCount);
|
||||
}
|
||||
else if (falseCount == 1)
|
||||
{
|
||||
// The factor is f(x, a) = a x^(a-1)
|
||||
// f(x, a) Ga(a; s, r) = a^s exp(-r*a + (a-1)*log(x)) r^s / Gamma(s)
|
||||
// Z = Gamma(s+1)/Gamma(s) * r^s / (r - log(x))^(s+1) / x
|
||||
return Math.Log(trueCount.Shape) - Math.Log(sample) + trueCount.Shape * Math.Log(trueCount.Rate) - (trueCount.Shape + 1) * Math.Log(trueCount.Rate - Math.Log(sample));
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new NotSupportedException(FalseCountMustBeOneMessage);
|
||||
}
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndVarianceOp"]/message_doc[@name="AverageLogFactor(Beta, double, double, Beta)"]/*'/>
|
||||
public static double AverageLogFactor(Beta sample, double mean, double variance, [Fresh] Beta to_sample)
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="AverageLogFactor(double, double, Gamma)"]/*'/>
|
||||
public static double AverageLogFactor(double sample, double trueCount, Gamma falseCount)
|
||||
{
|
||||
return to_sample.GetAverageLog(sample);
|
||||
if (falseCount.IsPointMass)
|
||||
{
|
||||
return LogEvidenceRatio(sample, trueCount, falseCount.Point);
|
||||
}
|
||||
else if (trueCount == 1)
|
||||
{
|
||||
// The factor is f(x, b) = b (1-x)^(b-1)
|
||||
return falseCount.GetMeanLog() + (falseCount.GetMean() - 1) * Math.Log(1 - sample);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new NotSupportedException(TrueCountMustBeOneMessage);
|
||||
}
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndVarianceOp"]/message_doc[@name="LogAverageFactor(double, double, double)"]/*'/>
|
||||
public static double LogAverageFactor(double sample, double mean, double variance)
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="LogEvidenceRatio(double, double, Gamma)"]/*'/>
|
||||
public static double LogEvidenceRatio(double sample, double trueCount, Gamma falseCount)
|
||||
{
|
||||
return SampleAverageConditional(mean, variance).GetLogProb(sample);
|
||||
if (falseCount.IsPointMass)
|
||||
{
|
||||
return LogEvidenceRatio(sample, trueCount, falseCount.Point);
|
||||
}
|
||||
else if (trueCount == 1)
|
||||
{
|
||||
// The factor is f(x, b) = b (1-x)^(b-1)
|
||||
return Math.Log(falseCount.Shape) - Math.Log(1 - sample) + falseCount.Shape * Math.Log(falseCount.Rate) - (falseCount.Shape + 1) * Math.Log(falseCount.Rate - Math.Log(1 - sample));
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new NotSupportedException(TrueCountMustBeOneMessage);
|
||||
}
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndVarianceOp"]/message_doc[@name="LogEvidenceRatio(double, double, double)"]/*'/>
|
||||
public static double LogEvidenceRatio(double sample, double mean, double variance)
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="TrueCountAverageConditional(double, double)"]/*'/>
|
||||
public static Gamma TrueCountAverageConditional(double sample, double falseCount)
|
||||
{
|
||||
return LogAverageFactor(sample, mean, variance);
|
||||
return TrueCountAverageLogarithm(sample, falseCount);
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndVarianceOp"]/message_doc[@name="AverageLogFactor(double, double, double)"]/*'/>
|
||||
public static double AverageLogFactor(double sample, double mean, double variance)
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="TrueCountAverageLogarithm(double, double)"]/*'/>
|
||||
public static Gamma TrueCountAverageLogarithm(double sample, double falseCount)
|
||||
{
|
||||
return LogAverageFactor(sample, mean, variance);
|
||||
if (falseCount == 1)
|
||||
{
|
||||
return Gamma.FromShapeAndRate(2, -Math.Log(sample));
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new NotSupportedException(FalseCountMustBeOneMessage);
|
||||
}
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndVarianceOp"]/message_doc[@name="SampleAverageLogarithm(double, double)"]/*'/>
|
||||
public static Beta SampleAverageLogarithm(double mean, double variance)
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="FalseCountAverageConditional(double, double)"]/*'/>
|
||||
public static Gamma FalseCountAverageConditional(double sample, double trueCount)
|
||||
{
|
||||
return Beta.FromMeanAndVariance(mean, variance);
|
||||
return FalseCountAverageLogarithm(sample, trueCount);
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndVarianceOp"]/message_doc[@name="SampleAverageConditional(double, double)"]/*'/>
|
||||
public static Beta SampleAverageConditional(double mean, double variance)
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="FalseCountAverageLogarithm(double, double)"]/*'/>
|
||||
public static Gamma FalseCountAverageLogarithm(double sample, double trueCount)
|
||||
{
|
||||
return Beta.FromMeanAndVariance(mean, variance);
|
||||
if (trueCount == 1)
|
||||
{
|
||||
return Gamma.FromShapeAndRate(2, -Math.Log(1 - sample));
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new NotSupportedException(TrueCountMustBeOneMessage);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -346,5 +346,25 @@ namespace Microsoft.ML.Probabilistic.Factors
|
|||
{
|
||||
return DiscreteAreEqualOp.LogEvidenceRatio(areEqual, ToInt(A), ToInt(B));
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteEnumAreEqualOp{TEnum}"]/message_doc[@name="LogEvidenceRatio(Bernoulli)"]/*'/>
|
||||
[Skip]
|
||||
public static double LogEvidenceRatio(Bernoulli areEqual)
|
||||
{
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Evidence message for VMP.
|
||||
/// </summary>
|
||||
/// <returns>Zero</returns>
|
||||
/// <remarks><para>
|
||||
/// In Variational Message Passing, the evidence contribution of a deterministic factor is zero.
|
||||
/// </para></remarks>
|
||||
[Skip]
|
||||
public static double AverageLogFactor()
|
||||
{
|
||||
return 0.0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -677,7 +677,7 @@ namespace Microsoft.ML.Probabilistic.Factors
|
|||
int sum = 0;
|
||||
for (int i = 0; i < array.Count; i++)
|
||||
{
|
||||
sum = sum + array[i];
|
||||
sum += array[i];
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
|
@ -692,7 +692,7 @@ namespace Microsoft.ML.Probabilistic.Factors
|
|||
double sum = 0;
|
||||
for (int i = 0; i < array.Count; i++)
|
||||
{
|
||||
sum = sum + array[i];
|
||||
sum += array[i];
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
|
|
|
@ -3952,7 +3952,7 @@
|
|||
</remarks>
|
||||
</message_doc>
|
||||
</message_op_class>
|
||||
<message_op_class name="BetaOp">
|
||||
<message_op_class name="BetaFromMeanAndTotalCountOp">
|
||||
<doc>
|
||||
<summary>Provides outgoing messages for <see cref="Factor.BetaFromMeanAndTotalCount(double, double)" />, given random arguments to the function.</summary>
|
||||
</doc>
|
||||
|
@ -4220,7 +4220,92 @@
|
|||
<paramref name="prob" /> is not a proper distribution.</exception>
|
||||
</message_doc>
|
||||
</message_op_class>
|
||||
<message_op_class name="BetaFromTrueAndFalseCountsOp">
|
||||
<message_op_class name="BetaFromMeanAndVarianceOp">
|
||||
<doc>
|
||||
<summary>Provides outgoing messages for <see cref="Beta.SampleFromMeanAndVariance(double, double)" />, given random arguments to the function.</summary>
|
||||
</doc>
|
||||
<message_doc name="LogAverageFactor(Beta, double, double, Beta)">
|
||||
<summary>Evidence message for EP.</summary>
|
||||
<param name="sample">Incoming message from <c>sample</c>.</param>
|
||||
<param name="mean">Constant value for <c>mean</c>.</param>
|
||||
<param name="variance">Constant value for <c>variance</c>.</param>
|
||||
<param name="to_sample">Outgoing message to <c>sample</c>.</param>
|
||||
<returns>Logarithm of the factor's average value across the given argument distributions.</returns>
|
||||
<remarks>
|
||||
<para>The formula for the result is <c>log(sum_(sample) p(sample) factor(sample,mean,variance))</c>.</para>
|
||||
</remarks>
|
||||
</message_doc>
|
||||
<message_doc name="LogEvidenceRatio(Beta, double, double)">
|
||||
<summary>Evidence message for EP.</summary>
|
||||
<param name="sample">Incoming message from <c>sample</c>.</param>
|
||||
<param name="mean">Constant value for <c>mean</c>.</param>
|
||||
<param name="variance">Constant value for <c>variance</c>.</param>
|
||||
<returns>Logarithm of the factor's contribution the EP model evidence.</returns>
|
||||
<remarks>
|
||||
<para>The formula for the result is <c>log(sum_(sample) p(sample) factor(sample,mean,variance) / sum_sample p(sample) messageTo(sample))</c>. Adding up these values across all factors and variables gives the log-evidence estimate for EP.</para>
|
||||
</remarks>
|
||||
</message_doc>
|
||||
<message_doc name="AverageLogFactor(Beta, double, double, Beta)">
|
||||
<summary>Evidence message for VMP.</summary>
|
||||
<param name="sample">Incoming message from <c>sample</c>.</param>
|
||||
<param name="mean">Constant value for <c>mean</c>.</param>
|
||||
<param name="variance">Constant value for <c>variance</c>.</param>
|
||||
<param name="to_sample">Outgoing message to <c>sample</c>.</param>
|
||||
<returns>Average of the factor's log-value across the given argument distributions.</returns>
|
||||
<remarks>
|
||||
<para>The formula for the result is <c>sum_(sample) p(sample) log(factor(sample,mean,variance))</c>. Adding up these values across all factors and variables gives the log-evidence estimate for VMP.</para>
|
||||
</remarks>
|
||||
</message_doc>
|
||||
<message_doc name="LogAverageFactor(double, double, double)">
|
||||
<summary>Evidence message for EP.</summary>
|
||||
<param name="sample">Constant value for <c>sample</c>.</param>
|
||||
<param name="mean">Constant value for <c>mean</c>.</param>
|
||||
<param name="variance">Constant value for <c>variance</c>.</param>
|
||||
<returns>Logarithm of the factor's average value across the given argument distributions.</returns>
|
||||
<remarks>
|
||||
<para>The formula for the result is <c>log(factor(sample,mean,variance))</c>.</para>
|
||||
</remarks>
|
||||
</message_doc>
|
||||
<message_doc name="LogEvidenceRatio(double, double, double)">
|
||||
<summary>Evidence message for EP.</summary>
|
||||
<param name="sample">Constant value for <c>sample</c>.</param>
|
||||
<param name="mean">Constant value for <c>mean</c>.</param>
|
||||
<param name="variance">Constant value for <c>variance</c>.</param>
|
||||
<returns>Logarithm of the factor's contribution the EP model evidence.</returns>
|
||||
<remarks>
|
||||
<para>The formula for the result is <c>log(factor(sample,mean,variance))</c>. Adding up these values across all factors and variables gives the log-evidence estimate for EP.</para>
|
||||
</remarks>
|
||||
</message_doc>
|
||||
<message_doc name="AverageLogFactor(double, double, double)">
|
||||
<summary>Evidence message for VMP.</summary>
|
||||
<param name="sample">Constant value for <c>sample</c>.</param>
|
||||
<param name="mean">Constant value for <c>mean</c>.</param>
|
||||
<param name="variance">Constant value for <c>variance</c>.</param>
|
||||
<returns>Average of the factor's log-value across the given argument distributions.</returns>
|
||||
<remarks>
|
||||
<para>The formula for the result is <c>log(factor(sample,mean,variance))</c>. Adding up these values across all factors and variables gives the log-evidence estimate for VMP.</para>
|
||||
</remarks>
|
||||
</message_doc>
|
||||
<message_doc name="SampleAverageLogarithm(double, double)">
|
||||
<summary>VMP message to <c>sample</c>.</summary>
|
||||
<param name="mean">Constant value for <c>mean</c>.</param>
|
||||
<param name="variance">Constant value for <c>variance</c>.</param>
|
||||
<returns>The outgoing VMP message to the <c>sample</c> argument.</returns>
|
||||
<remarks>
|
||||
<para>The outgoing message is the factor viewed as a function of <c>sample</c> conditioned on the given values.</para>
|
||||
</remarks>
|
||||
</message_doc>
|
||||
<message_doc name="SampleAverageConditional(double, double)">
|
||||
<summary>EP message to <c>sample</c>.</summary>
|
||||
<param name="mean">Constant value for <c>mean</c>.</param>
|
||||
<param name="variance">Constant value for <c>variance</c>.</param>
|
||||
<returns>The outgoing EP message to the <c>sample</c> argument.</returns>
|
||||
<remarks>
|
||||
<para>The outgoing message is the factor viewed as a function of <c>sample</c> conditioned on the given values.</para>
|
||||
</remarks>
|
||||
</message_doc>
|
||||
</message_op_class>
|
||||
<message_op_class name="BetaOp">
|
||||
<doc>
|
||||
<summary>Provides outgoing messages for <see cref="Beta.Sample(double, double)" />, given random arguments to the function.</summary>
|
||||
</doc>
|
||||
|
@ -4304,89 +4389,80 @@
|
|||
<para>The outgoing message is the factor viewed as a function of <c>sample</c> conditioned on the given values.</para>
|
||||
</remarks>
|
||||
</message_doc>
|
||||
</message_op_class>
|
||||
<message_op_class name="BetaFromMeanAndVarianceOp">
|
||||
<doc>
|
||||
<summary>Provides outgoing messages for <see cref="Beta.SampleFromMeanAndVariance(double, double)" />, given random arguments to the function.</summary>
|
||||
</doc>
|
||||
<message_doc name="LogAverageFactor(Beta, double, double, Beta)">
|
||||
<summary>Evidence message for EP.</summary>
|
||||
<param name="sample">Incoming message from <c>sample</c>.</param>
|
||||
<param name="mean">Constant value for <c>mean</c>.</param>
|
||||
<param name="variance">Constant value for <c>variance</c>.</param>
|
||||
<param name="to_sample">Outgoing message to <c>sample</c>.</param>
|
||||
<returns>Logarithm of the factor's average value across the given argument distributions.</returns>
|
||||
<remarks>
|
||||
<para>The formula for the result is <c>log(sum_(sample) p(sample) factor(sample,mean,variance))</c>.</para>
|
||||
</remarks>
|
||||
</message_doc>
|
||||
<message_doc name="LogEvidenceRatio(Beta, double, double)">
|
||||
<summary>Evidence message for EP.</summary>
|
||||
<param name="sample">Incoming message from <c>sample</c>.</param>
|
||||
<param name="mean">Constant value for <c>mean</c>.</param>
|
||||
<param name="variance">Constant value for <c>variance</c>.</param>
|
||||
<returns>Logarithm of the factor's contribution the EP model evidence.</returns>
|
||||
<remarks>
|
||||
<para>The formula for the result is <c>log(sum_(sample) p(sample) factor(sample,mean,variance) / sum_sample p(sample) messageTo(sample))</c>. Adding up these values across all factors and variables gives the log-evidence estimate for EP.</para>
|
||||
</remarks>
|
||||
</message_doc>
|
||||
<message_doc name="AverageLogFactor(Beta, double, double, Beta)">
|
||||
<summary>Evidence message for VMP.</summary>
|
||||
<param name="sample">Incoming message from <c>sample</c>.</param>
|
||||
<param name="mean">Constant value for <c>mean</c>.</param>
|
||||
<param name="variance">Constant value for <c>variance</c>.</param>
|
||||
<param name="to_sample">Outgoing message to <c>sample</c>.</param>
|
||||
<returns>Average of the factor's log-value across the given argument distributions.</returns>
|
||||
<remarks>
|
||||
<para>The formula for the result is <c>sum_(sample) p(sample) log(factor(sample,mean,variance))</c>. Adding up these values across all factors and variables gives the log-evidence estimate for VMP.</para>
|
||||
</remarks>
|
||||
</message_doc>
|
||||
<message_doc name="LogAverageFactor(double, double, double)">
|
||||
<summary>Evidence message for EP.</summary>
|
||||
<param name="sample">Constant value for <c>sample</c>.</param>
|
||||
<param name="mean">Constant value for <c>mean</c>.</param>
|
||||
<param name="variance">Constant value for <c>variance</c>.</param>
|
||||
<returns>Logarithm of the factor's average value across the given argument distributions.</returns>
|
||||
<remarks>
|
||||
<para>The formula for the result is <c>log(factor(sample,mean,variance))</c>.</para>
|
||||
</remarks>
|
||||
</message_doc>
|
||||
<message_doc name="LogEvidenceRatio(double, double, double)">
|
||||
<summary>Evidence message for EP.</summary>
|
||||
<param name="sample">Constant value for <c>sample</c>.</param>
|
||||
<param name="mean">Constant value for <c>mean</c>.</param>
|
||||
<param name="variance">Constant value for <c>variance</c>.</param>
|
||||
<returns>Logarithm of the factor's contribution the EP model evidence.</returns>
|
||||
<remarks>
|
||||
<para>The formula for the result is <c>log(factor(sample,mean,variance))</c>. Adding up these values across all factors and variables gives the log-evidence estimate for EP.</para>
|
||||
</remarks>
|
||||
</message_doc>
|
||||
<message_doc name="AverageLogFactor(double, double, double)">
|
||||
<message_doc name="AverageLogFactor(double, Gamma, double)">
|
||||
<summary>Evidence message for VMP.</summary>
|
||||
<param name="sample">Constant value for <c>sample</c>.</param>
|
||||
<param name="mean">Constant value for <c>mean</c>.</param>
|
||||
<param name="variance">Constant value for <c>variance</c>.</param>
|
||||
<param name="trueCount">Incoming message from <c>trueCount</c>.</param>
|
||||
<param name="falseCount">Constant value for <c>falseCount</c>.</param>
|
||||
<returns>Average of the factor's log-value across the given argument distributions.</returns>
|
||||
<remarks>
|
||||
<para>The formula for the result is <c>log(factor(sample,mean,variance))</c>. Adding up these values across all factors and variables gives the log-evidence estimate for VMP.</para>
|
||||
<para>The formula for the result is <c>sum_(trueCount) p(trueCount) log(factor(sample,trueCount,falseCount))</c>. Adding up these values across all factors and variables gives the log-evidence estimate for VMP.</para>
|
||||
</remarks>
|
||||
</message_doc>
|
||||
<message_doc name="SampleAverageLogarithm(double, double)">
|
||||
<summary>VMP message to <c>sample</c>.</summary>
|
||||
<param name="mean">Constant value for <c>mean</c>.</param>
|
||||
<param name="variance">Constant value for <c>variance</c>.</param>
|
||||
<returns>The outgoing VMP message to the <c>sample</c> argument.</returns>
|
||||
<message_doc name="LogEvidenceRatio(double, Gamma, double)">
|
||||
<summary>Evidence message for EP.</summary>
|
||||
<param name="sample">Constant value for <c>sample</c>.</param>
|
||||
<param name="trueCount">Incoming message from <c>trueCount</c>.</param>
|
||||
<param name="falseCount">Constant value for <c>falseCount</c>.</param>
|
||||
<returns>Logarithm of the factor's contribution the EP model evidence.</returns>
|
||||
<remarks>
|
||||
<para>The outgoing message is the factor viewed as a function of <c>sample</c> conditioned on the given values.</para>
|
||||
<para>The formula for the result is <c>log(sum_(trueCount) p(trueCount) factor(sample,trueCount,falseCount))</c>. Adding up these values across all factors and variables gives the log-evidence estimate for EP.</para>
|
||||
</remarks>
|
||||
</message_doc>
|
||||
<message_doc name="SampleAverageConditional(double, double)">
|
||||
<summary>EP message to <c>sample</c>.</summary>
|
||||
<param name="mean">Constant value for <c>mean</c>.</param>
|
||||
<param name="variance">Constant value for <c>variance</c>.</param>
|
||||
<returns>The outgoing EP message to the <c>sample</c> argument.</returns>
|
||||
<message_doc name="AverageLogFactor(double, double, Gamma)">
|
||||
<summary>Evidence message for VMP.</summary>
|
||||
<param name="sample">Constant value for <c>sample</c>.</param>
|
||||
<param name="trueCount">Constant value for <c>trueCount</c>.</param>
|
||||
<param name="falseCount">Incoming message from <c>falseCount</c>.</param>
|
||||
<returns>Average of the factor's log-value across the given argument distributions.</returns>
|
||||
<remarks>
|
||||
<para>The outgoing message is the factor viewed as a function of <c>sample</c> conditioned on the given values.</para>
|
||||
<para>The formula for the result is <c>sum_(falseCount) p(falseCount) log(factor(sample,trueCount,falseCount))</c>. Adding up these values across all factors and variables gives the log-evidence estimate for VMP.</para>
|
||||
</remarks>
|
||||
</message_doc>
|
||||
<message_doc name="LogEvidenceRatio(double, double, Gamma)">
|
||||
<summary>Evidence message for EP.</summary>
|
||||
<param name="sample">Constant value for <c>sample</c>.</param>
|
||||
<param name="trueCount">Constant value for <c>trueCount</c>.</param>
|
||||
<param name="falseCount">Incoming message from <c>falseCount</c>.</param>
|
||||
<returns>Logarithm of the factor's contribution the EP model evidence.</returns>
|
||||
<remarks>
|
||||
<para>The formula for the result is <c>log(sum_(falseCount) p(falseCount) factor(sample,trueCount,falseCount))</c>. Adding up these values across all factors and variables gives the log-evidence estimate for EP.</para>
|
||||
</remarks>
|
||||
</message_doc>
|
||||
<message_doc name="TrueCountAverageConditional(double, double)">
|
||||
<summary>EP message to <c>trueCount</c>.</summary>
|
||||
<param name="sample">Constant value for <c>sample</c>.</param>
|
||||
<param name="falseCount">Constant value for <c>falseCount</c>.</param>
|
||||
<returns>The outgoing EP message to the <c>trueCount</c> argument.</returns>
|
||||
<remarks>
|
||||
<para>The outgoing message is the factor viewed as a function of <c>trueCount</c> conditioned on the given values.</para>
|
||||
</remarks>
|
||||
</message_doc>
|
||||
<message_doc name="TrueCountAverageLogarithm(double, double)">
|
||||
<summary>VMP message to <c>trueCount</c>.</summary>
|
||||
<param name="sample">Constant value for <c>sample</c>.</param>
|
||||
<param name="falseCount">Constant value for <c>falseCount</c>.</param>
|
||||
<returns>The outgoing VMP message to the <c>trueCount</c> argument.</returns>
|
||||
<remarks>
|
||||
<para>The outgoing message is the factor viewed as a function of <c>trueCount</c> conditioned on the given values.</para>
|
||||
</remarks>
|
||||
</message_doc>
|
||||
<message_doc name="FalseCountAverageConditional(double, double)">
|
||||
<summary>EP message to <c>falseCount</c>.</summary>
|
||||
<param name="sample">Constant value for <c>sample</c>.</param>
|
||||
<param name="trueCount">Constant value for <c>trueCount</c>.</param>
|
||||
<returns>The outgoing EP message to the <c>falseCount</c> argument.</returns>
|
||||
<remarks>
|
||||
<para>The outgoing message is the factor viewed as a function of <c>falseCount</c> conditioned on the given values.</para>
|
||||
</remarks>
|
||||
</message_doc>
|
||||
<message_doc name="FalseCountAverageLogarithm(double, double)">
|
||||
<summary>VMP message to <c>falseCount</c>.</summary>
|
||||
<param name="sample">Constant value for <c>sample</c>.</param>
|
||||
<param name="trueCount">Constant value for <c>trueCount</c>.</param>
|
||||
<returns>The outgoing VMP message to the <c>falseCount</c> argument.</returns>
|
||||
<remarks>
|
||||
<para>The outgoing message is the factor viewed as a function of <c>falseCount</c> conditioned on the given values.</para>
|
||||
</remarks>
|
||||
</message_doc>
|
||||
</message_op_class>
|
||||
|
|
|
@ -59,58 +59,56 @@ namespace Microsoft.ML.Probabilistic.Tools.PrepareSource
|
|||
|
||||
private static void ProcessFile(string sourceFileName, string destinationFileName, Dictionary<string, XDocument> loadedDocFiles)
|
||||
{
|
||||
using (var reader = new StreamReader(sourceFileName))
|
||||
using (var writer = new StreamWriter(destinationFileName))
|
||||
using var reader = new StreamReader(sourceFileName);
|
||||
using var writer = new StreamWriter(destinationFileName);
|
||||
string line;
|
||||
int lineNumber = 0;
|
||||
while ((line = reader.ReadLine()) != null)
|
||||
{
|
||||
string line;
|
||||
int lineNumber = 0;
|
||||
while ((line = reader.ReadLine()) != null)
|
||||
++lineNumber;
|
||||
|
||||
string trimmedLine = line.Trim();
|
||||
if (!trimmedLine.StartsWith("/// <include", StringComparison.InvariantCulture))
|
||||
{
|
||||
++lineNumber;
|
||||
// Not a line with an include directive
|
||||
writer.WriteLine(line);
|
||||
continue;
|
||||
}
|
||||
|
||||
string trimmedLine = line.Trim();
|
||||
if (!trimmedLine.StartsWith("/// <include", StringComparison.InvariantCulture))
|
||||
{
|
||||
// Not a line with an include directive
|
||||
writer.WriteLine(line);
|
||||
continue;
|
||||
}
|
||||
string includeString = trimmedLine.Substring("/// ".Length);
|
||||
var includeDoc = XDocument.Parse(includeString);
|
||||
|
||||
string includeString = trimmedLine.Substring("/// ".Length);
|
||||
var includeDoc = XDocument.Parse(includeString);
|
||||
XAttribute fileAttribute = includeDoc.Root.Attribute("file");
|
||||
XAttribute pathAttribute = includeDoc.Root.Attribute("path");
|
||||
if (fileAttribute == null || pathAttribute == null)
|
||||
{
|
||||
Error("An ill-formed include directive at {0}:{1}", sourceFileName, lineNumber);
|
||||
}
|
||||
|
||||
XAttribute fileAttribute = includeDoc.Root.Attribute("file");
|
||||
XAttribute pathAttribute = includeDoc.Root.Attribute("path");
|
||||
if (fileAttribute == null || pathAttribute == null)
|
||||
{
|
||||
Error("An ill-formed include directive at {0}:{1}", sourceFileName, lineNumber);
|
||||
}
|
||||
string fullDocFileName = Path.GetFullPath(Path.Combine(Path.GetDirectoryName(sourceFileName), fileAttribute.Value));
|
||||
XDocument docFile;
|
||||
if (!loadedDocFiles.TryGetValue(fullDocFileName, out docFile))
|
||||
{
|
||||
docFile = XDocument.Load(fullDocFileName);
|
||||
loadedDocFiles.Add(fullDocFileName, docFile);
|
||||
}
|
||||
|
||||
string fullDocFileName = Path.GetFullPath(Path.Combine(Path.GetDirectoryName(sourceFileName), fileAttribute.Value));
|
||||
XDocument docFile;
|
||||
if (!loadedDocFiles.TryGetValue(fullDocFileName, out docFile))
|
||||
XElement[] docElements = ((IEnumerable)docFile.XPathEvaluate(pathAttribute.Value)).Cast<XElement>().ToArray();
|
||||
if (docElements.Length == 0)
|
||||
{
|
||||
Console.WriteLine("WARNING: nothing to include for the include directive at {0}:{1}", sourceFileName, lineNumber);
|
||||
}
|
||||
else
|
||||
{
|
||||
int indexOfDocStart = line.IndexOf("/// <include", StringComparison.InvariantCulture);
|
||||
foreach (XElement docElement in docElements)
|
||||
{
|
||||
docFile = XDocument.Load(fullDocFileName);
|
||||
loadedDocFiles.Add(fullDocFileName, docFile);
|
||||
}
|
||||
|
||||
XElement[] docElements = ((IEnumerable)docFile.XPathEvaluate(pathAttribute.Value)).Cast<XElement>().ToArray();
|
||||
if (docElements.Length == 0)
|
||||
{
|
||||
Console.WriteLine("WARNING: nothing to include for the include directive at {0}:{1}", sourceFileName, lineNumber);
|
||||
}
|
||||
else
|
||||
{
|
||||
int indexOfDocStart = line.IndexOf("/// <include", StringComparison.InvariantCulture);
|
||||
foreach (XElement docElement in docElements)
|
||||
string[] docElementStringLines = docElement.ToString().Split(new[] { Environment.NewLine }, StringSplitOptions.None);
|
||||
string indentation = new string(' ', indexOfDocStart);
|
||||
foreach (string docElementStringLine in docElementStringLines)
|
||||
{
|
||||
string[] docElementStringLines = docElement.ToString().Split(new[] { Environment.NewLine }, StringSplitOptions.None);
|
||||
string indentation = new string(' ', indexOfDocStart);
|
||||
foreach (string docElementStringLine in docElementStringLines)
|
||||
{
|
||||
writer.WriteLine("{0}/// {1}", indentation, docElementStringLine);
|
||||
}
|
||||
}
|
||||
writer.WriteLine("{0}/// {1}", indentation, docElementStringLine);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -83,7 +83,8 @@ namespace TestApp
|
|||
Stopwatch watch = new Stopwatch();
|
||||
watch.Start();
|
||||
|
||||
if (false)
|
||||
bool runAllTests = false;
|
||||
if (runAllTests)
|
||||
{
|
||||
// Run all tests (need to run in 64-bit else OutOfMemory due to loading many DLLs)
|
||||
// This is useful when looking for failures due to certain compiler options.
|
||||
|
@ -100,7 +101,11 @@ namespace TestApp
|
|||
//}
|
||||
//TestUtils.CheckTransformNames();
|
||||
}
|
||||
//InferenceEngine.ShowFactorManager(true);
|
||||
bool showFactorManager = true;
|
||||
if (showFactorManager)
|
||||
{
|
||||
InferenceEngine.ShowFactorManager(true);
|
||||
}
|
||||
#if NETFRAMEWORK
|
||||
logWriter.Dispose();
|
||||
#endif
|
||||
|
|
|
@ -465,8 +465,11 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
}
|
||||
}
|
||||
|
||||
private VariableArray<VariableArray<T>, T[][]> Transpose<T>(VariableArray<VariableArray<T>, T[][]> array, VariableArray<VariableArray<int>, int[][]> indices, Range r1,
|
||||
Range r2)
|
||||
private VariableArray<VariableArray<T>, T[][]> Transpose<T>(
|
||||
VariableArray<VariableArray<T>, T[][]> array,
|
||||
VariableArray<VariableArray<int>, int[][]> indices,
|
||||
Range r1,
|
||||
Range r2)
|
||||
{
|
||||
return array;
|
||||
}
|
||||
|
|
|
@ -1518,7 +1518,7 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
|
||||
public class VectorIsPositiveEP
|
||||
{
|
||||
IGeneratedAlgorithm gen;
|
||||
readonly IGeneratedAlgorithm gen;
|
||||
public int NumberOfIterations = 100;
|
||||
|
||||
public VectorIsPositiveEP(int dim)
|
||||
|
|
|
@ -27,7 +27,91 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
|
||||
public class GatedFactorTests
|
||||
{
|
||||
private static bool verbose = false;
|
||||
private static readonly bool verbose = false;
|
||||
|
||||
[Fact]
|
||||
public void BetaTrueCountIsGamma()
|
||||
{
|
||||
BetaTrueCountIsGamma(new ExpectationPropagation());
|
||||
BetaTrueCountIsGamma(new VariationalMessagePassing());
|
||||
|
||||
void BetaTrueCountIsGamma(IAlgorithm algorithm)
|
||||
{
|
||||
Variable<bool> evidence = Variable.Bernoulli(0.5).Named("evidence");
|
||||
var evBlock = Variable.If(evidence);
|
||||
|
||||
double shape = 2.5;
|
||||
double rate = 3.5;
|
||||
Variable<double> trueCount = Variable.GammaFromShapeAndRate(shape, rate);
|
||||
trueCount.Name = nameof(trueCount);
|
||||
Variable<double> x = Variable.Beta(trueCount, 1).Named("x");
|
||||
x.ObservedValue = 0.25;
|
||||
|
||||
evBlock.CloseBlock();
|
||||
|
||||
InferenceEngine engine = new InferenceEngine(algorithm);
|
||||
var trueCountActual = engine.Infer<Gamma>(trueCount);
|
||||
double xRate = -System.Math.Log(x.ObservedValue);
|
||||
var trueCountExpected = Gamma.FromShapeAndRate(shape + 1, rate + xRate);
|
||||
Assert.True(trueCountExpected.MaxDiff(trueCountActual) < 1e-10);
|
||||
|
||||
var evActual = engine.Infer<Bernoulli>(evidence).LogOdds;
|
||||
|
||||
double evExpected = double.NegativeInfinity;
|
||||
var trueCounts = EpTests.linspace(1e-6, 10, 1000);
|
||||
foreach (var trueCountValue in trueCounts)
|
||||
{
|
||||
trueCount.ObservedValue = trueCountValue;
|
||||
double ev = engine.Infer<Bernoulli>(evidence).LogOdds;
|
||||
evExpected = MMath.LogSumExp(evExpected, ev);
|
||||
}
|
||||
double increment = trueCounts[1] - trueCounts[0];
|
||||
evExpected += System.Math.Log(increment);
|
||||
Assert.True(MMath.AbsDiff(evExpected, evActual, 1e-8) < 1e-6);
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void BetaFalseCountIsGamma()
|
||||
{
|
||||
BetaFalseCountIsGamma(new ExpectationPropagation());
|
||||
BetaFalseCountIsGamma(new VariationalMessagePassing());
|
||||
|
||||
void BetaFalseCountIsGamma(IAlgorithm algorithm)
|
||||
{
|
||||
Variable<bool> evidence = Variable.Bernoulli(0.5).Named("evidence");
|
||||
var evBlock = Variable.If(evidence);
|
||||
|
||||
double shape = 2.5;
|
||||
double rate = 3.5;
|
||||
Variable<double> falseCount = Variable.GammaFromShapeAndRate(shape, rate);
|
||||
falseCount.Name = nameof(falseCount);
|
||||
Variable<double> x = Variable.Beta(1, falseCount).Named("x");
|
||||
x.ObservedValue = 0.25;
|
||||
|
||||
evBlock.CloseBlock();
|
||||
|
||||
InferenceEngine engine = new InferenceEngine(algorithm);
|
||||
var falseCountActual = engine.Infer<Gamma>(falseCount);
|
||||
double xRate = -System.Math.Log(1 - x.ObservedValue);
|
||||
var falseCountExpected = Gamma.FromShapeAndRate(shape + 1, rate + xRate);
|
||||
Assert.True(falseCountExpected.MaxDiff(falseCountActual) < 1e-10);
|
||||
|
||||
var evActual = engine.Infer<Bernoulli>(evidence).LogOdds;
|
||||
|
||||
double evExpected = double.NegativeInfinity;
|
||||
var falseCounts = EpTests.linspace(1e-6, 10, 1000);
|
||||
foreach (var falseCountValue in falseCounts)
|
||||
{
|
||||
falseCount.ObservedValue = falseCountValue;
|
||||
double ev = engine.Infer<Bernoulli>(evidence).LogOdds;
|
||||
evExpected = MMath.LogSumExp(evExpected, ev);
|
||||
}
|
||||
double increment = falseCounts[1] - falseCounts[0];
|
||||
evExpected += System.Math.Log(increment);
|
||||
Assert.True(MMath.AbsDiff(evExpected, evActual, 1e-8) < 1e-6);
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void TruncatedGaussianIsBetweenTest()
|
||||
|
@ -2452,7 +2536,7 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
Variable.ConstrainEqualRandom(x[item], xPrior[item]);
|
||||
block.CloseBlock();
|
||||
|
||||
InferenceEngine engine = new InferenceEngine();
|
||||
InferenceEngine engine = new InferenceEngine(algorithm);
|
||||
for (int ctrial = 0; ctrial < 3; ctrial++)
|
||||
{
|
||||
Vector cMean = Vector.Zero(dc);
|
||||
|
|
|
@ -76,8 +76,7 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
/// <returns>Beta-distributed random variable</returns>
|
||||
public static Variable<double> BetaFromMeanAndTotalCount(Variable<double> mean, Variable<double> totalCount)
|
||||
{
|
||||
return Variable<double>.Factor(Factor.BetaFromMeanAndTotalCount, mean, totalCount)
|
||||
.Attrib(new MarginalPrototype(new Beta(0, 0)));
|
||||
return Variable<double>.Factor(Factor.BetaFromMeanAndTotalCount, mean, totalCount);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
|
|
Загрузка…
Ссылка в новой задаче