Observed Beta variables support a Gamma-distributed pseudocount if the other pseudocount is always 1 (#386)

* ShowFactorManager shows more patterns
* Tidy up code
This commit is contained in:
Tom Minka 2022-01-07 18:07:48 +00:00 коммит произвёл GitHub
Родитель fc93e12851
Коммит 079250ef67
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
22 изменённых файлов: 799 добавлений и 477 удалений

Просмотреть файл

@ -257,9 +257,8 @@ namespace Microsoft.ML.Probabilistic.Compiler.Attributes
for (int depth = 0; depth < arrayDepth; depth++)
{
bool notLast = (depth < arrayDepth - 1);
int rank;
Type arrayType = sourceArray.GetExpressionType();
Util.GetElementType(arrayType, out rank);
Util.GetElementType(arrayType, out int rank);
if (sizes.Count <= depth) sizes.Add(new IExpression[rank]);
IExpression[] indices = new IExpression[rank];
for (int i = 0; i < rank; i++)
@ -822,8 +821,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Attributes
if (char.IsDigit(lastChar))
prefix += "_";
}
int count;
counts.TryGetValue(prefix, out count);
counts.TryGetValue(prefix, out int count);
if (count == 0) count = 1;
counts[prefix] = count + 1;
if (count == 1) return prefix;

Просмотреть файл

@ -553,7 +553,6 @@ namespace Microsoft.ML.Probabilistic.Compiler
}
// Model methods with up to ten parameters are supported directly.
#pragma warning disable 1591
/// <exclude/>
public delegate void ModelDefinitionMethod();
@ -716,7 +715,6 @@ namespace Microsoft.ML.Probabilistic.Compiler
{
return CompileWithoutParams(method.Method);
}
#pragma warning restore 1591
/// <summary>
/// Compiles the model defined in MSL by the specified method. The model parameters are set

Просмотреть файл

@ -4648,7 +4648,7 @@ namespace Microsoft.ML.Probabilistic.Models
/// </remarks>
public static Variable<string> StringCapitalized(Variable<int> minLength, Variable<int> maxLength = null)
{
return ReferenceEquals(maxLength, null)
return maxLength is null
? Variable<string>.Factor(Factor.StringCapitalized, minLength)
: Variable<string>.Factor(Factor.StringCapitalized, minLength, maxLength);
}
@ -4711,7 +4711,7 @@ namespace Microsoft.ML.Probabilistic.Models
/// </remarks>
public static Variable<string> String(Variable<int> minLength, Variable<int> maxLength, Variable<DiscreteChar> allowedCharacters)
{
return ReferenceEquals(maxLength, null)
return maxLength is null
? Variable<string>.Factor(Factor.String, minLength, allowedCharacters)
: Variable<string>.Factor(Factor.String, minLength, maxLength, allowedCharacters);
}
@ -5222,7 +5222,6 @@ namespace Microsoft.ML.Probabilistic.Models
/// <summary>
/// Enumeration over supported operators.
/// </summary>
#pragma warning disable 1591
public enum Operator
{
Plus,
@ -5245,7 +5244,6 @@ namespace Microsoft.ML.Probabilistic.Models
GreaterThanOrEqual,
LessThanOrEqual
};
#pragma warning restore 1591
}
/// <summary>
@ -6482,11 +6480,11 @@ namespace Microsoft.ML.Probabilistic.Models
{
diff = null;
}
if ((object)diff == null)
if (diff is null)
{
diff = OperatorFactor<T>(Operator.Minus, a, b);
}
if ((object)diff != null)
if (diff is object)
{
return IsPositive((Variable<double>)(Variable)diff);
}
@ -6497,7 +6495,7 @@ namespace Microsoft.ML.Probabilistic.Models
private static Variable<bool> NotOrNull(Variable<bool> Variable)
{
return ((object)Variable == null) ? null : !Variable;
return (Variable is null) ? null : !Variable;
}
/// <summary>
@ -6683,7 +6681,7 @@ namespace Microsoft.ML.Probabilistic.Models
public static Variable<T> operator -(Variable<T> a)
{
Variable<T> f = OperatorFactor<T>(Operator.Negative, a);
if ((object)f != null) return f;
if (f is object) return f;
else if (a is Variable<double>)
{
Variable<T> zero = (Variable<T>)(object)Constant(0.0);

Просмотреть файл

@ -19,10 +19,6 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
using Microsoft.ML.Probabilistic.Compiler.CodeModel;
using Microsoft.ML.Probabilistic.Algorithms;
#if SUPPRESS_XMLDOC_WARNINGS
#pragma warning disable 1591
#endif
internal class FactorManager
{
/// <summary>
@ -476,8 +472,10 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
if (method.IsDefined(typeof(MultiplyAllAttribute), true))
{
// MultiplyAll implies SkipIfAllUniform
var list = new List<object>(attrs);
list.Add(new SkipIfAllUniformAttribute());
var list = new List<object>(attrs)
{
new SkipIfAllUniformAttribute()
};
attrs = list.ToArray();
}
foreach (SkipIfAllUniformAttribute attr in attrs)
@ -617,7 +615,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
/// </summary>
public bool IsVoid
{
get { return (Method.ReturnType == typeof(void)); }
get { return Method.ReturnType == typeof(void); }
}
/// <summary>
@ -1481,9 +1479,11 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
protected IDictionary<string, string> GetNameMapping(FactorMethodAttribute attr)
{
string[] newParameterNames = attr.NewParameterNames;
Dictionary<string, string> map = new Dictionary<string, string>();
// always include the empty string
map[string.Empty] = string.Empty;
Dictionary<string, string> map = new Dictionary<string, string>
{
// always include the empty string
[string.Empty] = string.Empty
};
if (newParameterNames == null)
{
// map the fields to themselves
@ -1892,8 +1892,4 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
return "i";
}
}
#if SUPPRESS_XMLDOC_WARNINGS
#pragma warning restore 1591
#endif
}

Просмотреть файл

@ -136,9 +136,8 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
{
foreach (var container in containers.inputs)
{
if (container is IForStatement)
if (container is IForStatement ifs)
{
IForStatement ifs = (IForStatement)container;
IVariableDeclaration loopVar = Recognizer.LoopVariable(ifs);
if (context.InputAttributes.Has<Partitioned>(loopVar)) return true;
}
@ -345,14 +344,14 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
/// <summary>
/// Holds the ContainerInfo for each 'for' loop on the inputStack as they are being converted.
/// </summary>
private Stack<ContainerInfo> containerInfos = new Stack<ContainerInfo>();
private readonly Stack<ContainerInfo> containerInfos = new Stack<ContainerInfo>();
/// <summary>
/// Maps an IForStatement to its LocalInfos
/// </summary>
internal Dictionary<IStatement, Dictionary<IExpression, LocalInfo>> localInfoOfStmt = new Dictionary<IStatement, Dictionary<IExpression, LocalInfo>>(ReferenceEqualityComparer<IStatement>.Instance);
private Stack<IStatement> openContainers = new Stack<IStatement>();
private readonly Stack<IStatement> openContainers = new Stack<IStatement>();
private bool InPartitionedLoop;
ModelCompiler compiler;
readonly ModelCompiler compiler;
internal LocalAnalysisTransform(ModelCompiler compiler)
{
@ -435,8 +434,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
else
minExpr = GetCommonPrefix(minExpr, expr);
}
LocalInfo minInfo;
localInfos.TryGetValue(minExpr, out minInfo);
localInfos.TryGetValue(minExpr, out LocalInfo minInfo);
foreach (var entry in localInfos)
{
var expr = entry.Key;
@ -518,9 +516,9 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
{
var loopSize = Recognizer.LoopSizeExpression(closedContainer);
bool loopMustExecute = false;
if (loopSize is ILiteralExpression)
if (loopSize is ILiteralExpression ile)
{
int loopSizeAsInt = (int)((ILiteralExpression)loopSize).Value;
int loopSizeAsInt = (int)ile.Value;
if (loopSizeAsInt > 0)
{
loopMustExecute = true;
@ -566,9 +564,8 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
{
string containerString = StringUtil.ToString(containers.Select(c =>
{
if (c is IForStatement)
if (c is IForStatement ifs)
{
IForStatement ifs = (IForStatement)c;
if (c is IBrokenForStatement)
return ifs.Initializer.ToString() + " // broken";
else
@ -584,9 +581,9 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
private static IExpression GetContainerExpression(IStatement container)
{
if (container is IForStatement) return Recognizer.LoopSizeExpression((IForStatement)container);
else if (container is IConditionStatement) return ((IConditionStatement)container).Condition;
else if (container is IRepeatStatement) return ((IRepeatStatement)container).Count;
if (container is IForStatement ifs) return Recognizer.LoopSizeExpression(ifs);
else if (container is IConditionStatement ics) return ics.Condition;
else if (container is IRepeatStatement irs) return irs.Count;
else throw new ArgumentException($"unrecognized container type: {container.GetType()}");
}
@ -741,9 +738,9 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
AddUsage(prefix, false);
}
}
if (prefix.Target is IArrayIndexerExpression)
if (prefix.Target is IArrayIndexerExpression target)
{
prefix = (IArrayIndexerExpression)prefix.Target;
prefix = target;
isPrefix = true;
}
else
@ -763,9 +760,8 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
if (expr is IArrayIndexerExpression)
{
IExpression expr2 = expr;
while (expr2 is IArrayIndexerExpression)
while (expr2 is IArrayIndexerExpression iaie)
{
IArrayIndexerExpression iaie = (IArrayIndexerExpression)expr2;
bool hasExcludedLoopVar = iaie.Indices.Any(index => Recognizer.GetVariables(index).Any(excludedLoopVar.Equals));
if (hasExcludedLoopVar)
return GetPrefixInParent(iaie.Target, excludedLoopVar);

Просмотреть файл

@ -39,8 +39,8 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
loopStart = Recognizer.LoopStartExpression(ifs);
loopSize = Recognizer.LoopSizeExpression(ifs);
}
else if(ifs.Condition is IBinaryExpression) {
IBinaryExpression ibe = (IBinaryExpression)ifs.Condition;
else if (ifs.Condition is IBinaryExpression ibe)
{
if (ibe.Operator == BinaryOperator.GreaterThanOrEqual)
{
// loop is "for(int i = end; i >= start; i--)"

Просмотреть файл

@ -14,7 +14,6 @@ using Microsoft.ML.Probabilistic.Factors.Attributes;
using Microsoft.ML.Probabilistic.Utilities;
using Microsoft.ML.Probabilistic.Collections;
using Microsoft.ML.Probabilistic.Compiler.Transforms;
using Microsoft.ML.Probabilistic.Compiler;
using Microsoft.ML.Probabilistic.Compiler.Reflection;
namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
@ -25,7 +24,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
/// </summary>
internal class DefaultFactorManager
{
private FactorManager factorManager = new FactorManager();
private readonly FactorManager factorManager = new FactorManager();
public IAlgorithm[] algs = { new VariationalMessagePassing(), new ExpectationPropagation() };
@ -48,21 +47,21 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
private string GenerateNewFactorTable()
{
List<Tuple<string, Compiler.Transforms.FactorManager.FactorInfo>> factorList = new List<Tuple<string, Compiler.Transforms.FactorManager.FactorInfo>>();
IEnumerable<Compiler.Transforms.FactorManager.FactorInfo> factorInfos = Compiler.Transforms.FactorManager.GetFactorInfos();
List<Tuple<string, FactorManager.FactorInfo>> factorList = new List<Tuple<string, FactorManager.FactorInfo>>();
IEnumerable<FactorManager.FactorInfo> factorInfos = FactorManager.GetFactorInfos();
foreach (Compiler.Transforms.FactorManager.FactorInfo info in factorInfos)
foreach (FactorManager.FactorInfo info in factorInfos)
{
MethodInfo method = info.Method;
// omit obsolete, unsupported, and hidden factors
if (HiddenAttribute.IsDefined(method) ||
HiddenAttribute.IsDefined(method.DeclaringType) ||
Attribute.IsDefined(method, typeof(System.ObsoleteAttribute))) continue;
//if (method.Name != "Logistic") continue;
Attribute.IsDefined(method, typeof(ObsoleteAttribute))) continue;
string itemName = StringUtil.MethodFullNameToString(method);
factorList.Add(new Tuple<string, Compiler.Transforms.FactorManager.FactorInfo>(itemName, info));
////if (itemName != "EnumSupport.AreEqual<TEnum>") continue;
factorList.Add(new Tuple<string, FactorManager.FactorInfo>(itemName, info));
}
factorList.Sort((elem1, elem2) => elem1.Item1.CompareTo(elem2.Item1));
factorList.Sort((elem1, elem2) => elem1.Item2.ToString().CompareTo(elem2.Item2.ToString()));
string html = $@"<!DOCTYPE html>
<html>
@ -88,7 +87,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
return headers.ToString();
}
private string GetFactorHtml(string name, Compiler.Transforms.FactorManager.FactorInfo info)
private string GetFactorHtml(string name, FactorManager.FactorInfo info)
{
Console.WriteLine($"Scanning {info}");
var currentFactor = new StringBuilder();
@ -104,13 +103,12 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
args.Append("</span>");
string boldString = info.IsDeterministicFactor ? "" : " text-bold";
currentFactor.Append($@"<div class=""box-left{boldString}"">{StringUtil.EscapeXmlCharacters(name)}</div><div class=""box-left"">{args.ToString()}</div>");
currentFactor.Append($@"<div class=""box-left{boldString}"">{StringUtil.EscapeXmlCharacters(name)}</div><div class=""box-left"">{args}</div>");
foreach (IAlgorithm alg in this.algs)
{
QualityBand minQB, modeQB, maxQB;
ICollection<StochasticityPattern> patterns =
GetAlgorithmPatterns(alg, info, ShowMissingEvidences, out minQB, out modeQB, out maxQB);
GetAlgorithmPatterns(alg, info, ShowMissingEvidences, out QualityBand minQB, out QualityBand modeQB, out QualityBand maxQB);
if (patterns.Count == 0)
{
// not implemented
@ -142,7 +140,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
if (sp.Partial) partialCount++;
if (sp.evidenceFound) evidenceCount++;
if (sb.Length > 0) sb.Append(",</span> ");
sb.Append($"<span>{sp.ToString()}");
sb.Append($"<span>{sp}");
}
sb.Append("</span>");
if (notSupportedCount > 0)
@ -216,7 +214,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
.container {{
display: grid;
width: 100%;
grid-template-columns: {gridTemplate.ToString()};
grid-template-columns: {gridTemplate};
max-width: 1700px;
box-sizing: border-box;
border-left: 1px solid #aaa;
@ -310,9 +308,13 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
{
try
{
var p = new Process();
p.StartInfo = new ProcessStartInfo(filename);
p.StartInfo.UseShellExecute = true;
var p = new Process
{
StartInfo = new ProcessStartInfo(filename)
{
UseShellExecute = true
}
};
p.Start();
}
catch (Exception ex)
@ -331,9 +333,10 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
/// <param name="modeQB">The most common quality band attached to operators for this factor</param>
/// <param name="maxQB">The maximum quality band attached to operators for this factor</param>
/// <returns></returns>
private ICollection<StochasticityPattern> GetAlgorithmPatterns(IAlgorithm alg, Compiler.Transforms.FactorManager.FactorInfo info, bool ShowMissingEvidences,
private ICollection<StochasticityPattern> GetAlgorithmPatterns(IAlgorithm alg, FactorManager.FactorInfo info, bool ShowMissingEvidences,
out QualityBand minQB, out QualityBand modeQB, out QualityBand maxQB)
{
bool verbose = false;
ICollection<StochasticityPattern> patterns = new Set<StochasticityPattern>();
string suffix = alg.GetOperatorMethodSuffix(new List<ICompilerAttribute>());
string evidenceMethodName = alg.GetEvidenceMethodName(new List<ICompilerAttribute>());
@ -372,13 +375,12 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
// for stochastic arguments.
foreach (MessageFcnInfo mfi in mfis)
{
//if(info.Method.ContainsGenericParameters) Console.WriteLine();
// Console.WriteLine("mfi=" + mfi.Method);
if (verbose) Trace.WriteLine(mfi.Method);
StochasticityPattern sp = new StochasticityPattern(info);
sp.notSupported = mfi.NotSupportedMessage;
Dictionary<string, Type> parameterTypes = new Dictionary<string, Type>();
foreach (KeyValuePair<string, Type> kvp in mfi.GetParameterTypes()) parameterTypes[kvp.Key] = kvp.Value;
string target = (string)mfi.TargetParameter;
string target = mfi.TargetParameter;
if (!parameterTypes.ContainsKey(target))
{
if (!mfi.PassResultIndex)
@ -403,17 +405,17 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
if (info.Method.IsGenericMethodDefinition)
{
// fill in the factor's type parameters from the type arguments of the operator class.
IDictionary<string, Type> typeArgs = Compiler.Transforms.FactorManager.FactorInfo.GetTypeArguments(mfi.Method.DeclaringType);
IDictionary<string, Type> typeArgs = FactorManager.FactorInfo.GetTypeArguments(mfi.Method.DeclaringType);
try
{
MethodInfo newMethod = Compiler.Transforms.FactorManager.FactorInfo.MakeGenericMethod(info.Method, typeArgs);
sp.info = Compiler.Transforms.FactorManager.GetFactorInfo(newMethod);
MethodInfo newMethod = FactorManager.FactorInfo.MakeGenericMethod(info.Method, typeArgs);
sp.info = FactorManager.GetFactorInfo(newMethod);
}
catch (Exception)
{
sp.info = Compiler.Transforms.FactorManager.GetFactorInfo(info.Method);
//Console.WriteLine("Could not infer generic type parameters of "+StringUtil.MethodFullNameToString(info.Method));
//continue;
sp.info = FactorManager.GetFactorInfo(info.Method);
if (verbose)
Trace.WriteLine("Could not infer generic type parameters of " + StringUtil.MethodFullNameToString(info.Method));
}
// from now on, sp.info != info
}
@ -423,7 +425,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
ArgInfo ai = new ArgInfo();
if (!parameterTypes.ContainsKey(field))
{
//Console.WriteLine("not found: " + field + " in " + mfi.Method);
if (verbose) Trace.WriteLine("not found: " + field + " in " + mfi.Method);
continue;
}
ai.factorType = sp.info.ParameterTypes[field];
@ -431,7 +433,12 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
sp.argInfos[field] = ai;
}
if (!sp.IsValid()) continue;
if (verbose) Trace.WriteLine(sp);
if (!sp.IsValid())
{
if (verbose) Trace.WriteLine("Invalid stochasticity pattern");
continue;
}
List<StochasticityPattern> toRemove = new List<StochasticityPattern>();
foreach (StochasticityPattern sp2 in patterns)
{
@ -441,7 +448,6 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
StochasticityPattern sp3 = sp.Intersect(sp2);
if (sp3 == null)
{
StochasticityPattern sp4 = sp.Intersect(sp2);
throw new Exception("intersection is null");
}
sp = sp3;
@ -457,30 +463,35 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
int qbCnt = -1;
modeQB = QualityBand.Unknown;
foreach (KeyValuePair<QualityBand, int> kvp in qbCounts)
{
if (kvp.Value > qbCnt)
{
qbCnt = kvp.Value;
modeQB = kvp.Key;
}
}
modeQB = minQB;
patterns = IntersectPatterns(patterns);
patterns = GetCompletePatterns(patterns, suffix);
patterns = AddDeterministicPatterns(patterns);
VerifyPatterns(patterns, suffix, evidenceMethodName, ShowMissingEvidences);
StochasticityPattern bestp = GetBestPattern(patterns);
patterns = new Set<StochasticityPattern>();
if (bestp != null)
{
patterns.Add(bestp);
((Set<StochasticityPattern>)patterns).AddRange(bestp.deterministicPatterns);
}
return patterns;
return RemoveDuplicatePatterns(patterns);
//return GetBestPatternPlusDeterministic(patterns);
}
#if SUPPRESS_UNREACHABLE_CODE_WARNINGS
#pragma warning disable 162
#endif
private static List<StochasticityPattern> RemoveDuplicatePatterns(IEnumerable<StochasticityPattern> patterns)
{
var result = new List<StochasticityPattern>();
foreach (var pattern in patterns.OrderByDescending(sp => sp.foundCount))
{
if (!result.Any(sp => sp.IsSamePattern(pattern)))
{
result.Add(pattern);
}
}
return result;
}
/// <summary>
/// Create a new set of patterns containing the closure of all pairwise intersections of patterns.
@ -502,10 +513,6 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
return result;
}
#if SUPPRESS_UNREACHABLE_CODE_WARNINGS
#pragma warning restore 162
#endif
/// <summary>
/// Modify the collection by adding all pairwise intersections of patterns.
/// </summary>
@ -526,7 +533,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
}
}
private static StochasticityPattern GetStochasticityPattern(Compiler.Transforms.FactorManager.FactorInfo info, MessageFcnInfo mfi)
private static StochasticityPattern GetStochasticityPattern(FactorManager.FactorInfo info, MessageFcnInfo mfi)
{
StochasticityPattern sp = new StochasticityPattern(info);
sp.notSupported = mfi.NotSupportedMessage;
@ -548,15 +555,15 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
if (info.Method.IsGenericMethodDefinition)
{
// fill in the factor's type parameters from the type arguments of the operator class.
IDictionary<string, Type> typeArgs = Compiler.Transforms.FactorManager.FactorInfo.GetTypeArguments(mfi.Method.DeclaringType);
IDictionary<string, Type> typeArgs = FactorManager.FactorInfo.GetTypeArguments(mfi.Method.DeclaringType);
try
{
MethodInfo newMethod = Compiler.Transforms.FactorManager.FactorInfo.MakeGenericMethod(info.Method, typeArgs);
sp.info = Compiler.Transforms.FactorManager.GetFactorInfo(newMethod);
MethodInfo newMethod = FactorManager.FactorInfo.MakeGenericMethod(info.Method, typeArgs);
sp.info = FactorManager.GetFactorInfo(newMethod);
}
catch (Exception)
{
sp.info = Compiler.Transforms.FactorManager.GetFactorInfo(info.Method);
sp.info = FactorManager.GetFactorInfo(info.Method);
}
// from now on, sp.info != info
}
@ -576,6 +583,17 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
return sp;
}
private static ICollection<StochasticityPattern> GetBestPatternPlusDeterministic(ICollection<StochasticityPattern> patterns)
{
StochasticityPattern bestp = GetBestPattern(patterns);
if (bestp == null) return patterns;
var result = new Set<StochasticityPattern>();
result.Add(bestp);
result.AddRange(bestp.deterministicPatterns);
return result;
}
// Returns null if patterns is empty
private static StochasticityPattern GetBestPattern(IEnumerable<StochasticityPattern> patterns)
{
StochasticityPattern bestp = null;
@ -597,7 +615,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
return bestp;
}
private static ICollection<StochasticityPattern> GetCompletePatterns(IEnumerable<StochasticityPattern> patterns, string suffix)
private static Set<StochasticityPattern> GetCompletePatterns(IEnumerable<StochasticityPattern> patterns, string suffix)
{
Set<StochasticityPattern> pendingPatterns;
Set<StochasticityPattern> completePatterns = new Set<StochasticityPattern>();
@ -689,9 +707,11 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
if (entry2.Key == entry.Key)
{
// replace type with deterministic type.
ArgInfo arg2 = new ArgInfo();
arg2.factorType = arg.factorType;
arg2.opType = arg.factorType;
ArgInfo arg2 = new ArgInfo
{
factorType = arg.factorType,
opType = arg.factorType
};
sp2.argInfos[entry2.Key] = arg2;
}
else
@ -854,7 +874,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
/// </summary>
internal class StochasticityPattern
{
internal Compiler.Transforms.FactorManager.FactorInfo info;
internal FactorManager.FactorInfo info;
internal Dictionary<string, ArgInfo> argInfos = new Dictionary<string, ArgInfo>();
public Set<StochasticityPattern> deterministicPatterns = new Set<StochasticityPattern>();
@ -866,7 +886,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
public bool IsComplete
{
get { return (argInfos.Count >= nonConstantCount); }
get { return argInfos.Count >= nonConstantCount; }
}
public bool Partial
@ -874,7 +894,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
get { return foundCount < neededCount; }
}
public StochasticityPattern(Compiler.Transforms.FactorManager.FactorInfo info)
public StochasticityPattern(FactorManager.FactorInfo info)
{
this.info = info;
nonConstantCount = 0;
@ -947,8 +967,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
public override bool Equals(object o)
{
StochasticityPattern sp = o as StochasticityPattern;
if (sp == null) return false;
if (!(o is StochasticityPattern sp)) return false;
if (sp.info != info) return false;
foreach (string field in info.ParameterNames)
{
@ -977,16 +996,16 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
return hash;
}
internal bool IsSamePattern(StochasticityPattern sp)
public bool IsSamePattern(StochasticityPattern sp)
{
if (sp.info != info) return false;
foreach (string field in info.ParameterNames)
{
if (!argInfos.ContainsKey(field) || !sp.argInfos.ContainsKey(field)) continue;
if (argInfos[field].IsStoch != sp.argInfos[field].IsStoch) return false;
// Must compare types because a pattern involving Gaussians is not equivalent
// to one using Gammas.
//if (argInfos[field].IsStoch != sp.argInfos[field].IsStoch) return false;
if (argInfos[field].opType != sp.argInfos[field].opType) return false;
//if (argInfos[field].opType != sp.argInfos[field].opType) return false;
}
return true;
}
@ -1102,8 +1121,8 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
if (result.notSupported == null) result.notSupported = that.notSupported;
foreach (string field in info.ParameterNames)
{
ArgInfo arg1, arg2;
that.argInfos.TryGetValue(field, out arg2);
ArgInfo arg1;
that.argInfos.TryGetValue(field, out ArgInfo arg2);
if (!argInfos.TryGetValue(field, out arg1))
{
if (arg2 == null)
@ -1142,13 +1161,15 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
/// <returns></returns>
internal Type IntersectTypes(Type t1, Type t2)
{
Set<Type> types = new Set<Type>();
types.Add(t1);
types.Add(t2);
Set<Type> types = new Set<Type>
{
t1,
t2
};
Type t;
if (!intersectionCache.TryGetValue(types, out t))
{
t = Microsoft.ML.Probabilistic.Compiler.Reflection.Binding.IntersectTypes(t1, t2);
t = Binding.IntersectTypes(t1, t2);
if (t != null) intersectionCache[types] = t;
}
return t;

Просмотреть файл

@ -138,7 +138,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
// http://www.lovettsoftware.com/blogengine.net/post/2010/07/16/DGML-with-Style.aspx
var settings = new XmlWriterSettings();
settings.Indent = true;
Func<Color, string> colorToString = c => c.ToString().Trim('"');
////string colorToString(Color c) => c.ToString().Trim('"');
using (var writer = XmlWriter.Create(path, settings))
{
writer.WriteStartElement("DirectedGraph", "http://schemas.microsoft.com/vs/2009/dgml");
@ -166,7 +166,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Visualizers
writer.WriteAttributeString("NodeRadius", "100");
}
// The Outline attribute seems to be ignored.
//writer.WriteAttributeString("Outline", colorToString(node.Attr.Color));
////writer.WriteAttributeString("Outline", colorToString(node.Attr.Color));
writer.WriteEndElement();
}
writer.WriteEndElement();

Просмотреть файл

@ -12,10 +12,6 @@ using Microsoft.ML.Probabilistic.Collections;
namespace Microsoft.ML.Probabilistic.Utilities
{
#if SUPPRESS_XMLDOC_WARNINGS
#pragma warning disable 1591
#endif
/// <summary>
/// Helpful methods for converting objects to strings.
/// </summary>
@ -804,6 +800,5 @@ namespace Microsoft.ML.Probabilistic.Utilities
}
#if SUPPRESS_XMLDOC_WARNINGS
#pragma warning restore 1591
#endif
}

Просмотреть файл

@ -260,7 +260,6 @@ namespace Microsoft.ML.Probabilistic.Factors
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteAreEqualOp"]/doc/*'/>
[FactorMethod(typeof(Factor), "AreEqual", typeof(int), typeof(int))]
[FactorMethod(typeof(EnumSupport), "AreEqual<>")]
[Quality(QualityBand.Mature)]
public static class DiscreteAreEqualOp
{

Просмотреть файл

@ -0,0 +1,189 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
namespace Microsoft.ML.Probabilistic.Factors
{
using System;
using Microsoft.ML.Probabilistic.Distributions;
using Microsoft.ML.Probabilistic.Math;
using Microsoft.ML.Probabilistic.Factors.Attributes;
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndTotalCountOp"]/doc/*'/>
[FactorMethod(typeof(Factor), "BetaFromMeanAndTotalCount")]
[Quality(QualityBand.Experimental)]
public static class BetaFromMeanAndTotalCountOp
{
/// <summary>
/// How much damping to use to avoid improper messages. A higher value implies more damping.
/// </summary>
public static double damping = 0.0;
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndTotalCountOp"]/message_doc[@name="AverageLogFactor(double, double, double)"]/*'/>
public static double AverageLogFactor(double prob, double mean, double totalCount)
{
return LogAverageFactor(prob, mean, totalCount);
}
// TODO: VMP evidence messages for stochastic inputs (see DirichletOp)
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndTotalCountOp"]/message_doc[@name="ProbAverageLogarithm(Beta, Gamma)"]/*'/>
public static Beta ProbAverageLogarithm(Beta mean, [Proper] Gamma totalCount)
{
double meanMean = mean.GetMean();
double totalCountMean = totalCount.GetMean();
return new Beta(meanMean * totalCountMean, (1 - meanMean) * totalCountMean);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndTotalCountOp"]/message_doc[@name="ProbAverageLogarithm(Beta, double)"]/*'/>
public static Beta ProbAverageLogarithm(Beta mean, double totalCount)
{
double meanMean = mean.GetMean();
return new Beta(meanMean * totalCount, (1 - meanMean) * totalCount);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndTotalCountOp"]/message_doc[@name="ProbAverageLogarithm(double, Gamma)"]/*'/>
public static Beta ProbAverageLogarithm(double mean, [Proper] Gamma totalCount)
{
double totalCountMean = totalCount.GetMean();
return new Beta(mean * totalCountMean, (1 - mean) * totalCountMean);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndTotalCountOp"]/message_doc[@name="MeanAverageLogarithm(double, Beta, Gamma, Beta)"]/*'/>
public static Beta MeanAverageLogarithm(double prob, Beta mean, [Proper] Gamma totalCount, Beta to_mean)
{
return MeanAverageLogarithm(Beta.PointMass(prob), mean, totalCount, to_mean);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndTotalCountOp"]/message_doc[@name="MeanAverageLogarithm(double, Beta, double, Beta)"]/*'/>
public static Beta MeanAverageLogarithm(double prob, Beta mean, double totalCount, Beta to_mean)
{
return MeanAverageLogarithm(Beta.PointMass(prob), mean, Gamma.PointMass(totalCount), to_mean);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndTotalCountOp"]/message_doc[@name="MeanAverageLogarithm(Beta, Beta, Gamma, Beta)"]/*'/>
public static Beta MeanAverageLogarithm([Proper] Beta prob, Beta mean, [Proper] Gamma totalCount, Beta to_mean)
{
// Calculate gradient using method for DirichletOp
prob.GetMeanLogs(out double ELogP, out double ELogOneMinusP);
Vector gradS = DirichletOp.CalculateGradientForMean(
Vector.FromArray(new double[] { mean.TrueCount, mean.FalseCount }),
totalCount,
Vector.FromArray(new double[] { ELogP, ELogOneMinusP }));
// Project onto a Beta distribution
Matrix A = new Matrix(2, 2);
double c = MMath.Trigamma(mean.TotalCount);
A[0, 0] = MMath.Trigamma(mean.TrueCount) - c;
A[1, 0] = A[0, 1] = -c;
A[1, 1] = MMath.Trigamma(mean.FalseCount) - c;
Vector theta = GammaFromShapeAndRateOp.twoByTwoInverse(A) * gradS;
Beta approximateFactor = new Beta(theta[0] + 1, theta[1] + 1);
if (damping == 0.0)
return approximateFactor;
else
return (approximateFactor ^ (1 - damping)) * (to_mean ^ damping);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndTotalCountOp"]/message_doc[@name="TotalCountAverageLogarithm(Beta, Gamma, Beta, Gamma)"]/*'/>
public static Gamma TotalCountAverageLogarithm([Proper] Beta mean, Gamma totalCount, [SkipIfUniform] Beta prob, Gamma to_totalCount)
{
prob.GetMeanLogs(out double ELogP, out double ELogOneMinusP);
Gamma approximateFactor = DirichletOp.TotalCountAverageLogarithmHelper(
Vector.FromArray(new double[] { mean.TrueCount, mean.FalseCount }),
totalCount,
Vector.FromArray(new double[] { ELogP, ELogOneMinusP }));
if (damping == 0.0)
return approximateFactor;
else
return (approximateFactor ^ (1 - damping)) * (to_totalCount ^ damping);
}
//---------------------------- EP -----------------------------
private const string NotSupportedMessage = "Expectation Propagation does not currently support beta distributions with stochastic arguments.";
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndTotalCountOp"]/message_doc[@name="LogAverageFactor(double, double, double)"]/*'/>
public static double LogAverageFactor(double prob, double mean, double totalCount)
{
var g = new Beta(mean * totalCount, (1 - mean) * totalCount);
return g.GetLogProb(prob);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndTotalCountOp"]/message_doc[@name="LogAverageFactor(Beta, double, double)"]/*'/>
public static double LogAverageFactor(Beta prob, double mean, double totalCount)
{
var g = new Beta(mean * totalCount, (1 - mean) * totalCount);
return g.GetLogAverageOf(prob);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndTotalCountOp"]/message_doc[@name="ProbAverageConditional(double, double)"]/*'/>
public static Beta ProbAverageConditional(double mean, double totalCount)
{
return new Beta(mean * totalCount, (1 - mean) * totalCount);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndTotalCountOp"]/message_doc[@name="ProbAverageConditional(Beta, Gamma)"]/*'/>
[NotSupported(NotSupportedMessage)]
public static Beta ProbAverageConditional([SkipIfUniform] Beta mean, [SkipIfUniform] Gamma totalCount)
{
throw new NotSupportedException(NotSupportedMessage);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndTotalCountOp"]/message_doc[@name="ProbAverageConditional(Beta, double)"]/*'/>
[NotSupported(NotSupportedMessage)]
public static Beta ProbAverageConditional([SkipIfUniform] Beta mean, double totalCount)
{
throw new NotSupportedException(NotSupportedMessage);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndTotalCountOp"]/message_doc[@name="ProbAverageConditional(double, Gamma)"]/*'/>
[NotSupported(NotSupportedMessage)]
public static Beta ProbAverageConditional(double mean, [SkipIfUniform] Gamma totalCount)
{
throw new NotSupportedException(NotSupportedMessage);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndTotalCountOp"]/message_doc[@name="MeanAverageConditional(Beta, Gamma, double, Beta)"]/*'/>
[NotSupported(NotSupportedMessage)]
public static Beta MeanAverageConditional([SkipIfUniform] Beta mean, [SkipIfUniform] Gamma totalCount, double prob, [SkipIfUniform] Beta result)
{
throw new NotSupportedException(NotSupportedMessage);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndTotalCountOp"]/message_doc[@name="MeanAverageConditional(Beta, double, double, Beta)"]/*'/>
[NotSupported(NotSupportedMessage)]
public static Beta MeanAverageConditional([SkipIfUniform] Beta mean, double totalCount, double prob, [SkipIfUniform] Beta result)
{
throw new NotSupportedException(NotSupportedMessage);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndTotalCountOp"]/message_doc[@name="MeanAverageConditional(Beta, double, Beta, Beta)"]/*'/>
[NotSupported(NotSupportedMessage)]
public static Beta MeanAverageConditional([SkipIfUniform] Beta mean, double totalCount, [SkipIfUniform] Beta prob, [SkipIfUniform] Beta result)
{
throw new NotSupportedException(NotSupportedMessage);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndTotalCountOp"]/message_doc[@name="MeanAverageConditional(Beta, Gamma, Beta, Beta)"]/*'/>
[NotSupported(NotSupportedMessage)]
public static Beta MeanAverageConditional([SkipIfUniform] Beta mean, [SkipIfUniform] Gamma totalCount, [SkipIfUniform] Beta prob, [SkipIfUniform] Beta result)
{
throw new NotSupportedException(NotSupportedMessage);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndTotalCountOp"]/message_doc[@name="TotalCountAverageConditional(Beta, Gamma, double, Gamma)"]/*'/>
[NotSupported(NotSupportedMessage)]
public static Gamma TotalCountAverageConditional([SkipIfUniform] Beta mean, [SkipIfUniform] Gamma totalCount, double prob, [SkipIfUniform] Gamma result)
{
throw new NotSupportedException(NotSupportedMessage);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndTotalCountOp"]/message_doc[@name="TotalCountAverageConditional(Beta, Gamma, Beta, Gamma)"]/*'/>
[NotSupported(NotSupportedMessage)]
public static Gamma TotalCountAverageConditional([SkipIfUniform] Beta mean, [SkipIfUniform] Gamma totalCount, [SkipIfUniform] Beta prob, [SkipIfUniform] Gamma result)
{
throw new NotSupportedException(NotSupportedMessage);
}
}
}

Просмотреть файл

@ -0,0 +1,65 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
namespace Microsoft.ML.Probabilistic.Factors
{
using Microsoft.ML.Probabilistic.Distributions;
using Microsoft.ML.Probabilistic.Factors.Attributes;
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndVarianceOp"]/doc/*'/>
[FactorMethod(new string[] { "sample", "mean", "variance" }, typeof(Beta), "SampleFromMeanAndVariance")]
[Quality(QualityBand.Stable)]
public static class BetaFromMeanAndVarianceOp
{
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndVarianceOp"]/message_doc[@name="LogAverageFactor(Beta, double, double, Beta)"]/*'/>
public static double LogAverageFactor(Beta sample, double mean, double variance, [Fresh] Beta to_sample)
{
return to_sample.GetLogAverageOf(sample);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndVarianceOp"]/message_doc[@name="LogEvidenceRatio(Beta, double, double)"]/*'/>
[Skip]
public static double LogEvidenceRatio(Beta sample, double mean, double variance)
{
return 0.0;
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndVarianceOp"]/message_doc[@name="AverageLogFactor(Beta, double, double, Beta)"]/*'/>
public static double AverageLogFactor(Beta sample, double mean, double variance, [Fresh] Beta to_sample)
{
return to_sample.GetAverageLog(sample);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndVarianceOp"]/message_doc[@name="LogAverageFactor(double, double, double)"]/*'/>
public static double LogAverageFactor(double sample, double mean, double variance)
{
return SampleAverageConditional(mean, variance).GetLogProb(sample);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndVarianceOp"]/message_doc[@name="LogEvidenceRatio(double, double, double)"]/*'/>
public static double LogEvidenceRatio(double sample, double mean, double variance)
{
return LogAverageFactor(sample, mean, variance);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndVarianceOp"]/message_doc[@name="AverageLogFactor(double, double, double)"]/*'/>
public static double AverageLogFactor(double sample, double mean, double variance)
{
return LogAverageFactor(sample, mean, variance);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndVarianceOp"]/message_doc[@name="SampleAverageLogarithm(double, double)"]/*'/>
public static Beta SampleAverageLogarithm(double mean, double variance)
{
return Beta.FromMeanAndVariance(mean, variance);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndVarianceOp"]/message_doc[@name="SampleAverageConditional(double, double)"]/*'/>
public static Beta SampleAverageConditional(double mean, double variance)
{
return Beta.FromMeanAndVariance(mean, variance);
}
}
}

Просмотреть файл

@ -1,301 +1,183 @@
// Licensed to the .NET Foundation under one or more agreements.
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
namespace Microsoft.ML.Probabilistic.Factors
{
using System;
using Microsoft.ML.Probabilistic.Distributions;
using Microsoft.ML.Probabilistic.Math;
using Microsoft.ML.Probabilistic.Factors.Attributes;
using System;
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/doc/*'/>
[FactorMethod(typeof(Factor), "BetaFromMeanAndTotalCount")]
[Quality(QualityBand.Experimental)]
public static class BetaOp
{
/// <summary>
/// How much damping to use to avoid improper messages. A higher value implies more damping.
/// </summary>
public static double damping = 0.0;
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="AverageLogFactor(double, double, double)"]/*'/>
public static double AverageLogFactor(double prob, double mean, double totalCount)
{
return LogAverageFactor(prob, mean, totalCount);
}
// TODO: VMP evidence messages for stochastic inputs (see DirichletOp)
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="ProbAverageLogarithm(Beta, Gamma)"]/*'/>
public static Beta ProbAverageLogarithm(Beta mean, [Proper] Gamma totalCount)
{
double meanMean = mean.GetMean();
double totalCountMean = totalCount.GetMean();
return (new Beta(meanMean * totalCountMean, (1 - meanMean) * totalCountMean));
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="ProbAverageLogarithm(Beta, double)"]/*'/>
public static Beta ProbAverageLogarithm(Beta mean, double totalCount)
{
double meanMean = mean.GetMean();
return (new Beta(meanMean * totalCount, (1 - meanMean) * totalCount));
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="ProbAverageLogarithm(double, Gamma)"]/*'/>
public static Beta ProbAverageLogarithm(double mean, [Proper] Gamma totalCount)
{
double totalCountMean = totalCount.GetMean();
return (new Beta(mean * totalCountMean, (1 - mean) * totalCountMean));
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="MeanAverageLogarithm(double, Beta, Gamma, Beta)"]/*'/>
public static Beta MeanAverageLogarithm(double prob, Beta mean, [Proper] Gamma totalCount, Beta to_mean)
{
return MeanAverageLogarithm(Beta.PointMass(prob), mean, totalCount, to_mean);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="MeanAverageLogarithm(double, Beta, double, Beta)"]/*'/>
public static Beta MeanAverageLogarithm(double prob, Beta mean, double totalCount, Beta to_mean)
{
return MeanAverageLogarithm(Beta.PointMass(prob), mean, Gamma.PointMass(totalCount), to_mean);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="MeanAverageLogarithm(Beta, Beta, Gamma, Beta)"]/*'/>
public static Beta MeanAverageLogarithm([Proper] Beta prob, Beta mean, [Proper] Gamma totalCount, Beta to_mean)
{
// Calculate gradient using method for DirichletOp
double ELogP, ELogOneMinusP;
prob.GetMeanLogs(out ELogP, out ELogOneMinusP);
Vector gradS = DirichletOp.CalculateGradientForMean(
Vector.FromArray(new double[] { mean.TrueCount, mean.FalseCount }),
totalCount,
Vector.FromArray(new double[] { ELogP, ELogOneMinusP }));
// Project onto a Beta distribution
Matrix A = new Matrix(2, 2);
double c = MMath.Trigamma(mean.TotalCount);
A[0, 0] = MMath.Trigamma(mean.TrueCount) - c;
A[1, 0] = A[0, 1] = -c;
A[1, 1] = MMath.Trigamma(mean.FalseCount) - c;
Vector theta = GammaFromShapeAndRateOp.twoByTwoInverse(A) * gradS;
Beta approximateFactor = new Beta(theta[0] + 1, theta[1] + 1);
if (damping == 0.0)
return approximateFactor;
else
return (approximateFactor ^ (1 - damping)) * (to_mean ^ damping);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="TotalCountAverageLogarithm(Beta, Gamma, Beta, Gamma)"]/*'/>
public static Gamma TotalCountAverageLogarithm([Proper] Beta mean, Gamma totalCount, [SkipIfUniform] Beta prob, Gamma to_totalCount)
{
double ELogP, ELogOneMinusP;
prob.GetMeanLogs(out ELogP, out ELogOneMinusP);
Gamma approximateFactor = DirichletOp.TotalCountAverageLogarithmHelper(
Vector.FromArray(new double[] { mean.TrueCount, mean.FalseCount }),
totalCount,
Vector.FromArray(new double[] { ELogP, ELogOneMinusP }));
if (damping == 0.0)
return approximateFactor;
else
return (approximateFactor ^ (1 - damping)) * (to_totalCount ^ damping);
}
//---------------------------- EP -----------------------------
private const string NotSupportedMessage = "Expectation Propagation does not currently support beta distributions with stochastic arguments.";
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="LogAverageFactor(double, double, double)"]/*'/>
public static double LogAverageFactor(double prob, double mean, double totalCount)
{
var g = new Beta(mean * totalCount, (1 - mean) * totalCount);
return g.GetLogProb(prob);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="LogAverageFactor(Beta, double, double)"]/*'/>
public static double LogAverageFactor(Beta prob, double mean, double totalCount)
{
var g = new Beta(mean * totalCount, (1 - mean) * totalCount);
return g.GetLogAverageOf(prob);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="ProbAverageConditional(double, double)"]/*'/>
public static Beta ProbAverageConditional(double mean, double totalCount)
{
return new Beta(mean * totalCount, (1 - mean) * totalCount);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="ProbAverageConditional(Beta, Gamma)"]/*'/>
[NotSupported(NotSupportedMessage)]
public static Beta ProbAverageConditional([SkipIfUniform] Beta mean, [SkipIfUniform] Gamma totalCount)
{
throw new NotSupportedException(NotSupportedMessage);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="ProbAverageConditional(Beta, double)"]/*'/>
[NotSupported(NotSupportedMessage)]
public static Beta ProbAverageConditional([SkipIfUniform] Beta mean, double totalCount)
{
throw new NotSupportedException(NotSupportedMessage);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="ProbAverageConditional(double, Gamma)"]/*'/>
[NotSupported(NotSupportedMessage)]
public static Beta ProbAverageConditional(double mean, [SkipIfUniform] Gamma totalCount)
{
throw new NotSupportedException(NotSupportedMessage);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="MeanAverageConditional(Beta, Gamma, double, Beta)"]/*'/>
[NotSupported(NotSupportedMessage)]
public static Beta MeanAverageConditional([SkipIfUniform] Beta mean, [SkipIfUniform] Gamma totalCount, double prob, [SkipIfUniform] Beta result)
{
throw new NotSupportedException(NotSupportedMessage);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="MeanAverageConditional(Beta, double, double, Beta)"]/*'/>
[NotSupported(NotSupportedMessage)]
public static Beta MeanAverageConditional([SkipIfUniform] Beta mean, double totalCount, double prob, [SkipIfUniform] Beta result)
{
throw new NotSupportedException(NotSupportedMessage);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="MeanAverageConditional(Beta, double, Beta, Beta)"]/*'/>
[NotSupported(NotSupportedMessage)]
public static Beta MeanAverageConditional([SkipIfUniform] Beta mean, double totalCount, [SkipIfUniform] Beta prob, [SkipIfUniform] Beta result)
{
throw new NotSupportedException(NotSupportedMessage);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="MeanAverageConditional(Beta, Gamma, Beta, Beta)"]/*'/>
[NotSupported(NotSupportedMessage)]
public static Beta MeanAverageConditional([SkipIfUniform] Beta mean, [SkipIfUniform] Gamma totalCount, [SkipIfUniform] Beta prob, [SkipIfUniform] Beta result)
{
throw new NotSupportedException(NotSupportedMessage);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="TotalCountAverageConditional(Beta, Gamma, double, Gamma)"]/*'/>
[NotSupported(NotSupportedMessage)]
public static Gamma TotalCountAverageConditional([SkipIfUniform] Beta mean, [SkipIfUniform] Gamma totalCount, double prob, [SkipIfUniform] Gamma result)
{
throw new NotSupportedException(NotSupportedMessage);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="TotalCountAverageConditional(Beta, Gamma, Beta, Gamma)"]/*'/>
[NotSupported(NotSupportedMessage)]
public static Gamma TotalCountAverageConditional([SkipIfUniform] Beta mean, [SkipIfUniform] Gamma totalCount, [SkipIfUniform] Beta prob, [SkipIfUniform] Gamma result)
{
throw new NotSupportedException(NotSupportedMessage);
}
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromTrueAndFalseCountsOp"]/doc/*'/>
[FactorMethod(typeof(Beta), "Sample", typeof(double), typeof(double))]
[Quality(QualityBand.Stable)]
public static class BetaFromTrueAndFalseCountsOp
public static class BetaOp
{
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromTrueAndFalseCountsOp"]/message_doc[@name="LogAverageFactor(Beta, double, double, Beta)"]/*'/>
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="LogAverageFactor(Beta, double, double, Beta)"]/*'/>
public static double LogAverageFactor(Beta sample, double trueCount, double falseCount, [Fresh] Beta to_sample)
{
return to_sample.GetLogAverageOf(sample);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromTrueAndFalseCountsOp"]/message_doc[@name="LogEvidenceRatio(Beta, double, double)"]/*'/>
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="LogEvidenceRatio(Beta, double, double)"]/*'/>
[Skip]
public static double LogEvidenceRatio(Beta sample, double trueCount, double falseCount)
{
return 0.0;
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromTrueAndFalseCountsOp"]/message_doc[@name="AverageLogFactor(Beta, double, double, Beta)"]/*'/>
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="AverageLogFactor(Beta, double, double, Beta)"]/*'/>
public static double AverageLogFactor(Beta sample, double trueCount, double falseCount, [Fresh] Beta to_sample)
{
return to_sample.GetAverageLog(sample);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromTrueAndFalseCountsOp"]/message_doc[@name="LogAverageFactor(double, double, double)"]/*'/>
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="LogAverageFactor(double, double, double)"]/*'/>
public static double LogAverageFactor(double sample, double trueCount, double falseCount)
{
return SampleAverageConditional(trueCount, falseCount).GetLogProb(sample);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromTrueAndFalseCountsOp"]/message_doc[@name="LogEvidenceRatio(double, double, double)"]/*'/>
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="LogEvidenceRatio(double, double, double)"]/*'/>
public static double LogEvidenceRatio(double sample, double trueCount, double falseCount)
{
return LogAverageFactor(sample, trueCount, falseCount);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromTrueAndFalseCountsOp"]/message_doc[@name="AverageLogFactor(double, double, double)"]/*'/>
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="AverageLogFactor(double, double, double)"]/*'/>
public static double AverageLogFactor(double sample, double trueCount, double falseCount)
{
return LogAverageFactor(sample, trueCount, falseCount);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromTrueAndFalseCountsOp"]/message_doc[@name="SampleAverageLogarithm(double, double)"]/*'/>
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="SampleAverageLogarithm(double, double)"]/*'/>
public static Beta SampleAverageLogarithm(double trueCount, double falseCount)
{
return new Beta(trueCount, falseCount);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromTrueAndFalseCountsOp"]/message_doc[@name="SampleAverageConditional(double, double)"]/*'/>
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="SampleAverageConditional(double, double)"]/*'/>
public static Beta SampleAverageConditional(double trueCount, double falseCount)
{
return new Beta(trueCount, falseCount);
}
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndVarianceOp"]/doc/*'/>
[FactorMethod(new string[] { "sample", "mean", "variance" }, typeof(Beta), "SampleFromMeanAndVariance")]
[Quality(QualityBand.Stable)]
public static class BetaFromMeanAndVarianceOp
{
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndVarianceOp"]/message_doc[@name="LogAverageFactor(Beta, double, double, Beta)"]/*'/>
public static double LogAverageFactor(Beta sample, double mean, double variance, [Fresh] Beta to_sample)
const string TrueCountMustBeOneMessage = "falseCount is Gamma and trueCount is not 1";
const string FalseCountMustBeOneMessage = "trueCount is Gamma and falseCount is not 1";
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="AverageLogFactor(double, Gamma, double)"]/*'/>
public static double AverageLogFactor(double sample, Gamma trueCount, double falseCount)
{
return to_sample.GetLogAverageOf(sample);
if (trueCount.IsPointMass)
{
return LogEvidenceRatio(sample, trueCount.Point, falseCount);
}
else if (falseCount == 1)
{
// The factor is f(x, a) = a x^(a-1)
// whose logarithm is log(a) + (a-1)*log(x)
return trueCount.GetMeanLog() + (trueCount.GetMean() - 1) * Math.Log(sample);
}
else
{
throw new NotSupportedException(FalseCountMustBeOneMessage);
}
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndVarianceOp"]/message_doc[@name="LogEvidenceRatio(Beta, double, double)"]/*'/>
[Skip]
public static double LogEvidenceRatio(Beta sample, double mean, double variance)
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="LogEvidenceRatio(double, Gamma, double)"]/*'/>
public static double LogEvidenceRatio(double sample, Gamma trueCount, double falseCount)
{
return 0.0;
if (trueCount.IsPointMass)
{
return LogEvidenceRatio(sample, trueCount.Point, falseCount);
}
else if (falseCount == 1)
{
// The factor is f(x, a) = a x^(a-1)
// f(x, a) Ga(a; s, r) = a^s exp(-r*a + (a-1)*log(x)) r^s / Gamma(s)
// Z = Gamma(s+1)/Gamma(s) * r^s / (r - log(x))^(s+1) / x
return Math.Log(trueCount.Shape) - Math.Log(sample) + trueCount.Shape * Math.Log(trueCount.Rate) - (trueCount.Shape + 1) * Math.Log(trueCount.Rate - Math.Log(sample));
}
else
{
throw new NotSupportedException(FalseCountMustBeOneMessage);
}
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndVarianceOp"]/message_doc[@name="AverageLogFactor(Beta, double, double, Beta)"]/*'/>
public static double AverageLogFactor(Beta sample, double mean, double variance, [Fresh] Beta to_sample)
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="AverageLogFactor(double, double, Gamma)"]/*'/>
public static double AverageLogFactor(double sample, double trueCount, Gamma falseCount)
{
return to_sample.GetAverageLog(sample);
if (falseCount.IsPointMass)
{
return LogEvidenceRatio(sample, trueCount, falseCount.Point);
}
else if (trueCount == 1)
{
// The factor is f(x, b) = b (1-x)^(b-1)
return falseCount.GetMeanLog() + (falseCount.GetMean() - 1) * Math.Log(1 - sample);
}
else
{
throw new NotSupportedException(TrueCountMustBeOneMessage);
}
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndVarianceOp"]/message_doc[@name="LogAverageFactor(double, double, double)"]/*'/>
public static double LogAverageFactor(double sample, double mean, double variance)
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="LogEvidenceRatio(double, double, Gamma)"]/*'/>
public static double LogEvidenceRatio(double sample, double trueCount, Gamma falseCount)
{
return SampleAverageConditional(mean, variance).GetLogProb(sample);
if (falseCount.IsPointMass)
{
return LogEvidenceRatio(sample, trueCount, falseCount.Point);
}
else if (trueCount == 1)
{
// The factor is f(x, b) = b (1-x)^(b-1)
return Math.Log(falseCount.Shape) - Math.Log(1 - sample) + falseCount.Shape * Math.Log(falseCount.Rate) - (falseCount.Shape + 1) * Math.Log(falseCount.Rate - Math.Log(1 - sample));
}
else
{
throw new NotSupportedException(TrueCountMustBeOneMessage);
}
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndVarianceOp"]/message_doc[@name="LogEvidenceRatio(double, double, double)"]/*'/>
public static double LogEvidenceRatio(double sample, double mean, double variance)
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="TrueCountAverageConditional(double, double)"]/*'/>
public static Gamma TrueCountAverageConditional(double sample, double falseCount)
{
return LogAverageFactor(sample, mean, variance);
return TrueCountAverageLogarithm(sample, falseCount);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndVarianceOp"]/message_doc[@name="AverageLogFactor(double, double, double)"]/*'/>
public static double AverageLogFactor(double sample, double mean, double variance)
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="TrueCountAverageLogarithm(double, double)"]/*'/>
public static Gamma TrueCountAverageLogarithm(double sample, double falseCount)
{
return LogAverageFactor(sample, mean, variance);
if (falseCount == 1)
{
return Gamma.FromShapeAndRate(2, -Math.Log(sample));
}
else
{
throw new NotSupportedException(FalseCountMustBeOneMessage);
}
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndVarianceOp"]/message_doc[@name="SampleAverageLogarithm(double, double)"]/*'/>
public static Beta SampleAverageLogarithm(double mean, double variance)
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="FalseCountAverageConditional(double, double)"]/*'/>
public static Gamma FalseCountAverageConditional(double sample, double trueCount)
{
return Beta.FromMeanAndVariance(mean, variance);
return FalseCountAverageLogarithm(sample, trueCount);
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaFromMeanAndVarianceOp"]/message_doc[@name="SampleAverageConditional(double, double)"]/*'/>
public static Beta SampleAverageConditional(double mean, double variance)
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="BetaOp"]/message_doc[@name="FalseCountAverageLogarithm(double, double)"]/*'/>
public static Gamma FalseCountAverageLogarithm(double sample, double trueCount)
{
return Beta.FromMeanAndVariance(mean, variance);
if (trueCount == 1)
{
return Gamma.FromShapeAndRate(2, -Math.Log(1 - sample));
}
else
{
throw new NotSupportedException(TrueCountMustBeOneMessage);
}
}
}
}

Просмотреть файл

@ -346,5 +346,25 @@ namespace Microsoft.ML.Probabilistic.Factors
{
return DiscreteAreEqualOp.LogEvidenceRatio(areEqual, ToInt(A), ToInt(B));
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="DiscreteEnumAreEqualOp{TEnum}"]/message_doc[@name="LogEvidenceRatio(Bernoulli)"]/*'/>
[Skip]
public static double LogEvidenceRatio(Bernoulli areEqual)
{
return 0.0;
}
/// <summary>
/// Evidence message for VMP.
/// </summary>
/// <returns>Zero</returns>
/// <remarks><para>
/// In Variational Message Passing, the evidence contribution of a deterministic factor is zero.
/// </para></remarks>
[Skip]
public static double AverageLogFactor()
{
return 0.0;
}
}
}

Просмотреть файл

@ -677,7 +677,7 @@ namespace Microsoft.ML.Probabilistic.Factors
int sum = 0;
for (int i = 0; i < array.Count; i++)
{
sum = sum + array[i];
sum += array[i];
}
return sum;
}
@ -692,7 +692,7 @@ namespace Microsoft.ML.Probabilistic.Factors
double sum = 0;
for (int i = 0; i < array.Count; i++)
{
sum = sum + array[i];
sum += array[i];
}
return sum;
}

Просмотреть файл

@ -3952,7 +3952,7 @@
</remarks>
</message_doc>
</message_op_class>
<message_op_class name="BetaOp">
<message_op_class name="BetaFromMeanAndTotalCountOp">
<doc>
<summary>Provides outgoing messages for <see cref="Factor.BetaFromMeanAndTotalCount(double, double)" />, given random arguments to the function.</summary>
</doc>
@ -4220,7 +4220,92 @@
<paramref name="prob" /> is not a proper distribution.</exception>
</message_doc>
</message_op_class>
<message_op_class name="BetaFromTrueAndFalseCountsOp">
<message_op_class name="BetaFromMeanAndVarianceOp">
<doc>
<summary>Provides outgoing messages for <see cref="Beta.SampleFromMeanAndVariance(double, double)" />, given random arguments to the function.</summary>
</doc>
<message_doc name="LogAverageFactor(Beta, double, double, Beta)">
<summary>Evidence message for EP.</summary>
<param name="sample">Incoming message from <c>sample</c>.</param>
<param name="mean">Constant value for <c>mean</c>.</param>
<param name="variance">Constant value for <c>variance</c>.</param>
<param name="to_sample">Outgoing message to <c>sample</c>.</param>
<returns>Logarithm of the factor's average value across the given argument distributions.</returns>
<remarks>
<para>The formula for the result is <c>log(sum_(sample) p(sample) factor(sample,mean,variance))</c>.</para>
</remarks>
</message_doc>
<message_doc name="LogEvidenceRatio(Beta, double, double)">
<summary>Evidence message for EP.</summary>
<param name="sample">Incoming message from <c>sample</c>.</param>
<param name="mean">Constant value for <c>mean</c>.</param>
<param name="variance">Constant value for <c>variance</c>.</param>
<returns>Logarithm of the factor's contribution the EP model evidence.</returns>
<remarks>
<para>The formula for the result is <c>log(sum_(sample) p(sample) factor(sample,mean,variance) / sum_sample p(sample) messageTo(sample))</c>. Adding up these values across all factors and variables gives the log-evidence estimate for EP.</para>
</remarks>
</message_doc>
<message_doc name="AverageLogFactor(Beta, double, double, Beta)">
<summary>Evidence message for VMP.</summary>
<param name="sample">Incoming message from <c>sample</c>.</param>
<param name="mean">Constant value for <c>mean</c>.</param>
<param name="variance">Constant value for <c>variance</c>.</param>
<param name="to_sample">Outgoing message to <c>sample</c>.</param>
<returns>Average of the factor's log-value across the given argument distributions.</returns>
<remarks>
<para>The formula for the result is <c>sum_(sample) p(sample) log(factor(sample,mean,variance))</c>. Adding up these values across all factors and variables gives the log-evidence estimate for VMP.</para>
</remarks>
</message_doc>
<message_doc name="LogAverageFactor(double, double, double)">
<summary>Evidence message for EP.</summary>
<param name="sample">Constant value for <c>sample</c>.</param>
<param name="mean">Constant value for <c>mean</c>.</param>
<param name="variance">Constant value for <c>variance</c>.</param>
<returns>Logarithm of the factor's average value across the given argument distributions.</returns>
<remarks>
<para>The formula for the result is <c>log(factor(sample,mean,variance))</c>.</para>
</remarks>
</message_doc>
<message_doc name="LogEvidenceRatio(double, double, double)">
<summary>Evidence message for EP.</summary>
<param name="sample">Constant value for <c>sample</c>.</param>
<param name="mean">Constant value for <c>mean</c>.</param>
<param name="variance">Constant value for <c>variance</c>.</param>
<returns>Logarithm of the factor's contribution the EP model evidence.</returns>
<remarks>
<para>The formula for the result is <c>log(factor(sample,mean,variance))</c>. Adding up these values across all factors and variables gives the log-evidence estimate for EP.</para>
</remarks>
</message_doc>
<message_doc name="AverageLogFactor(double, double, double)">
<summary>Evidence message for VMP.</summary>
<param name="sample">Constant value for <c>sample</c>.</param>
<param name="mean">Constant value for <c>mean</c>.</param>
<param name="variance">Constant value for <c>variance</c>.</param>
<returns>Average of the factor's log-value across the given argument distributions.</returns>
<remarks>
<para>The formula for the result is <c>log(factor(sample,mean,variance))</c>. Adding up these values across all factors and variables gives the log-evidence estimate for VMP.</para>
</remarks>
</message_doc>
<message_doc name="SampleAverageLogarithm(double, double)">
<summary>VMP message to <c>sample</c>.</summary>
<param name="mean">Constant value for <c>mean</c>.</param>
<param name="variance">Constant value for <c>variance</c>.</param>
<returns>The outgoing VMP message to the <c>sample</c> argument.</returns>
<remarks>
<para>The outgoing message is the factor viewed as a function of <c>sample</c> conditioned on the given values.</para>
</remarks>
</message_doc>
<message_doc name="SampleAverageConditional(double, double)">
<summary>EP message to <c>sample</c>.</summary>
<param name="mean">Constant value for <c>mean</c>.</param>
<param name="variance">Constant value for <c>variance</c>.</param>
<returns>The outgoing EP message to the <c>sample</c> argument.</returns>
<remarks>
<para>The outgoing message is the factor viewed as a function of <c>sample</c> conditioned on the given values.</para>
</remarks>
</message_doc>
</message_op_class>
<message_op_class name="BetaOp">
<doc>
<summary>Provides outgoing messages for <see cref="Beta.Sample(double, double)" />, given random arguments to the function.</summary>
</doc>
@ -4304,89 +4389,80 @@
<para>The outgoing message is the factor viewed as a function of <c>sample</c> conditioned on the given values.</para>
</remarks>
</message_doc>
</message_op_class>
<message_op_class name="BetaFromMeanAndVarianceOp">
<doc>
<summary>Provides outgoing messages for <see cref="Beta.SampleFromMeanAndVariance(double, double)" />, given random arguments to the function.</summary>
</doc>
<message_doc name="LogAverageFactor(Beta, double, double, Beta)">
<summary>Evidence message for EP.</summary>
<param name="sample">Incoming message from <c>sample</c>.</param>
<param name="mean">Constant value for <c>mean</c>.</param>
<param name="variance">Constant value for <c>variance</c>.</param>
<param name="to_sample">Outgoing message to <c>sample</c>.</param>
<returns>Logarithm of the factor's average value across the given argument distributions.</returns>
<remarks>
<para>The formula for the result is <c>log(sum_(sample) p(sample) factor(sample,mean,variance))</c>.</para>
</remarks>
</message_doc>
<message_doc name="LogEvidenceRatio(Beta, double, double)">
<summary>Evidence message for EP.</summary>
<param name="sample">Incoming message from <c>sample</c>.</param>
<param name="mean">Constant value for <c>mean</c>.</param>
<param name="variance">Constant value for <c>variance</c>.</param>
<returns>Logarithm of the factor's contribution the EP model evidence.</returns>
<remarks>
<para>The formula for the result is <c>log(sum_(sample) p(sample) factor(sample,mean,variance) / sum_sample p(sample) messageTo(sample))</c>. Adding up these values across all factors and variables gives the log-evidence estimate for EP.</para>
</remarks>
</message_doc>
<message_doc name="AverageLogFactor(Beta, double, double, Beta)">
<summary>Evidence message for VMP.</summary>
<param name="sample">Incoming message from <c>sample</c>.</param>
<param name="mean">Constant value for <c>mean</c>.</param>
<param name="variance">Constant value for <c>variance</c>.</param>
<param name="to_sample">Outgoing message to <c>sample</c>.</param>
<returns>Average of the factor's log-value across the given argument distributions.</returns>
<remarks>
<para>The formula for the result is <c>sum_(sample) p(sample) log(factor(sample,mean,variance))</c>. Adding up these values across all factors and variables gives the log-evidence estimate for VMP.</para>
</remarks>
</message_doc>
<message_doc name="LogAverageFactor(double, double, double)">
<summary>Evidence message for EP.</summary>
<param name="sample">Constant value for <c>sample</c>.</param>
<param name="mean">Constant value for <c>mean</c>.</param>
<param name="variance">Constant value for <c>variance</c>.</param>
<returns>Logarithm of the factor's average value across the given argument distributions.</returns>
<remarks>
<para>The formula for the result is <c>log(factor(sample,mean,variance))</c>.</para>
</remarks>
</message_doc>
<message_doc name="LogEvidenceRatio(double, double, double)">
<summary>Evidence message for EP.</summary>
<param name="sample">Constant value for <c>sample</c>.</param>
<param name="mean">Constant value for <c>mean</c>.</param>
<param name="variance">Constant value for <c>variance</c>.</param>
<returns>Logarithm of the factor's contribution the EP model evidence.</returns>
<remarks>
<para>The formula for the result is <c>log(factor(sample,mean,variance))</c>. Adding up these values across all factors and variables gives the log-evidence estimate for EP.</para>
</remarks>
</message_doc>
<message_doc name="AverageLogFactor(double, double, double)">
<message_doc name="AverageLogFactor(double, Gamma, double)">
<summary>Evidence message for VMP.</summary>
<param name="sample">Constant value for <c>sample</c>.</param>
<param name="mean">Constant value for <c>mean</c>.</param>
<param name="variance">Constant value for <c>variance</c>.</param>
<param name="trueCount">Incoming message from <c>trueCount</c>.</param>
<param name="falseCount">Constant value for <c>falseCount</c>.</param>
<returns>Average of the factor's log-value across the given argument distributions.</returns>
<remarks>
<para>The formula for the result is <c>log(factor(sample,mean,variance))</c>. Adding up these values across all factors and variables gives the log-evidence estimate for VMP.</para>
<para>The formula for the result is <c>sum_(trueCount) p(trueCount) log(factor(sample,trueCount,falseCount))</c>. Adding up these values across all factors and variables gives the log-evidence estimate for VMP.</para>
</remarks>
</message_doc>
<message_doc name="SampleAverageLogarithm(double, double)">
<summary>VMP message to <c>sample</c>.</summary>
<param name="mean">Constant value for <c>mean</c>.</param>
<param name="variance">Constant value for <c>variance</c>.</param>
<returns>The outgoing VMP message to the <c>sample</c> argument.</returns>
<message_doc name="LogEvidenceRatio(double, Gamma, double)">
<summary>Evidence message for EP.</summary>
<param name="sample">Constant value for <c>sample</c>.</param>
<param name="trueCount">Incoming message from <c>trueCount</c>.</param>
<param name="falseCount">Constant value for <c>falseCount</c>.</param>
<returns>Logarithm of the factor's contribution the EP model evidence.</returns>
<remarks>
<para>The outgoing message is the factor viewed as a function of <c>sample</c> conditioned on the given values.</para>
<para>The formula for the result is <c>log(sum_(trueCount) p(trueCount) factor(sample,trueCount,falseCount))</c>. Adding up these values across all factors and variables gives the log-evidence estimate for EP.</para>
</remarks>
</message_doc>
<message_doc name="SampleAverageConditional(double, double)">
<summary>EP message to <c>sample</c>.</summary>
<param name="mean">Constant value for <c>mean</c>.</param>
<param name="variance">Constant value for <c>variance</c>.</param>
<returns>The outgoing EP message to the <c>sample</c> argument.</returns>
<message_doc name="AverageLogFactor(double, double, Gamma)">
<summary>Evidence message for VMP.</summary>
<param name="sample">Constant value for <c>sample</c>.</param>
<param name="trueCount">Constant value for <c>trueCount</c>.</param>
<param name="falseCount">Incoming message from <c>falseCount</c>.</param>
<returns>Average of the factor's log-value across the given argument distributions.</returns>
<remarks>
<para>The outgoing message is the factor viewed as a function of <c>sample</c> conditioned on the given values.</para>
<para>The formula for the result is <c>sum_(falseCount) p(falseCount) log(factor(sample,trueCount,falseCount))</c>. Adding up these values across all factors and variables gives the log-evidence estimate for VMP.</para>
</remarks>
</message_doc>
<message_doc name="LogEvidenceRatio(double, double, Gamma)">
<summary>Evidence message for EP.</summary>
<param name="sample">Constant value for <c>sample</c>.</param>
<param name="trueCount">Constant value for <c>trueCount</c>.</param>
<param name="falseCount">Incoming message from <c>falseCount</c>.</param>
<returns>Logarithm of the factor's contribution the EP model evidence.</returns>
<remarks>
<para>The formula for the result is <c>log(sum_(falseCount) p(falseCount) factor(sample,trueCount,falseCount))</c>. Adding up these values across all factors and variables gives the log-evidence estimate for EP.</para>
</remarks>
</message_doc>
<message_doc name="TrueCountAverageConditional(double, double)">
<summary>EP message to <c>trueCount</c>.</summary>
<param name="sample">Constant value for <c>sample</c>.</param>
<param name="falseCount">Constant value for <c>falseCount</c>.</param>
<returns>The outgoing EP message to the <c>trueCount</c> argument.</returns>
<remarks>
<para>The outgoing message is the factor viewed as a function of <c>trueCount</c> conditioned on the given values.</para>
</remarks>
</message_doc>
<message_doc name="TrueCountAverageLogarithm(double, double)">
<summary>VMP message to <c>trueCount</c>.</summary>
<param name="sample">Constant value for <c>sample</c>.</param>
<param name="falseCount">Constant value for <c>falseCount</c>.</param>
<returns>The outgoing VMP message to the <c>trueCount</c> argument.</returns>
<remarks>
<para>The outgoing message is the factor viewed as a function of <c>trueCount</c> conditioned on the given values.</para>
</remarks>
</message_doc>
<message_doc name="FalseCountAverageConditional(double, double)">
<summary>EP message to <c>falseCount</c>.</summary>
<param name="sample">Constant value for <c>sample</c>.</param>
<param name="trueCount">Constant value for <c>trueCount</c>.</param>
<returns>The outgoing EP message to the <c>falseCount</c> argument.</returns>
<remarks>
<para>The outgoing message is the factor viewed as a function of <c>falseCount</c> conditioned on the given values.</para>
</remarks>
</message_doc>
<message_doc name="FalseCountAverageLogarithm(double, double)">
<summary>VMP message to <c>falseCount</c>.</summary>
<param name="sample">Constant value for <c>sample</c>.</param>
<param name="trueCount">Constant value for <c>trueCount</c>.</param>
<returns>The outgoing VMP message to the <c>falseCount</c> argument.</returns>
<remarks>
<para>The outgoing message is the factor viewed as a function of <c>falseCount</c> conditioned on the given values.</para>
</remarks>
</message_doc>
</message_op_class>

Просмотреть файл

@ -59,58 +59,56 @@ namespace Microsoft.ML.Probabilistic.Tools.PrepareSource
private static void ProcessFile(string sourceFileName, string destinationFileName, Dictionary<string, XDocument> loadedDocFiles)
{
using (var reader = new StreamReader(sourceFileName))
using (var writer = new StreamWriter(destinationFileName))
using var reader = new StreamReader(sourceFileName);
using var writer = new StreamWriter(destinationFileName);
string line;
int lineNumber = 0;
while ((line = reader.ReadLine()) != null)
{
string line;
int lineNumber = 0;
while ((line = reader.ReadLine()) != null)
++lineNumber;
string trimmedLine = line.Trim();
if (!trimmedLine.StartsWith("/// <include", StringComparison.InvariantCulture))
{
++lineNumber;
// Not a line with an include directive
writer.WriteLine(line);
continue;
}
string trimmedLine = line.Trim();
if (!trimmedLine.StartsWith("/// <include", StringComparison.InvariantCulture))
{
// Not a line with an include directive
writer.WriteLine(line);
continue;
}
string includeString = trimmedLine.Substring("/// ".Length);
var includeDoc = XDocument.Parse(includeString);
string includeString = trimmedLine.Substring("/// ".Length);
var includeDoc = XDocument.Parse(includeString);
XAttribute fileAttribute = includeDoc.Root.Attribute("file");
XAttribute pathAttribute = includeDoc.Root.Attribute("path");
if (fileAttribute == null || pathAttribute == null)
{
Error("An ill-formed include directive at {0}:{1}", sourceFileName, lineNumber);
}
XAttribute fileAttribute = includeDoc.Root.Attribute("file");
XAttribute pathAttribute = includeDoc.Root.Attribute("path");
if (fileAttribute == null || pathAttribute == null)
{
Error("An ill-formed include directive at {0}:{1}", sourceFileName, lineNumber);
}
string fullDocFileName = Path.GetFullPath(Path.Combine(Path.GetDirectoryName(sourceFileName), fileAttribute.Value));
XDocument docFile;
if (!loadedDocFiles.TryGetValue(fullDocFileName, out docFile))
{
docFile = XDocument.Load(fullDocFileName);
loadedDocFiles.Add(fullDocFileName, docFile);
}
string fullDocFileName = Path.GetFullPath(Path.Combine(Path.GetDirectoryName(sourceFileName), fileAttribute.Value));
XDocument docFile;
if (!loadedDocFiles.TryGetValue(fullDocFileName, out docFile))
XElement[] docElements = ((IEnumerable)docFile.XPathEvaluate(pathAttribute.Value)).Cast<XElement>().ToArray();
if (docElements.Length == 0)
{
Console.WriteLine("WARNING: nothing to include for the include directive at {0}:{1}", sourceFileName, lineNumber);
}
else
{
int indexOfDocStart = line.IndexOf("/// <include", StringComparison.InvariantCulture);
foreach (XElement docElement in docElements)
{
docFile = XDocument.Load(fullDocFileName);
loadedDocFiles.Add(fullDocFileName, docFile);
}
XElement[] docElements = ((IEnumerable)docFile.XPathEvaluate(pathAttribute.Value)).Cast<XElement>().ToArray();
if (docElements.Length == 0)
{
Console.WriteLine("WARNING: nothing to include for the include directive at {0}:{1}", sourceFileName, lineNumber);
}
else
{
int indexOfDocStart = line.IndexOf("/// <include", StringComparison.InvariantCulture);
foreach (XElement docElement in docElements)
string[] docElementStringLines = docElement.ToString().Split(new[] { Environment.NewLine }, StringSplitOptions.None);
string indentation = new string(' ', indexOfDocStart);
foreach (string docElementStringLine in docElementStringLines)
{
string[] docElementStringLines = docElement.ToString().Split(new[] { Environment.NewLine }, StringSplitOptions.None);
string indentation = new string(' ', indexOfDocStart);
foreach (string docElementStringLine in docElementStringLines)
{
writer.WriteLine("{0}/// {1}", indentation, docElementStringLine);
}
}
writer.WriteLine("{0}/// {1}", indentation, docElementStringLine);
}
}
}
}

Просмотреть файл

@ -83,7 +83,8 @@ namespace TestApp
Stopwatch watch = new Stopwatch();
watch.Start();
if (false)
bool runAllTests = false;
if (runAllTests)
{
// Run all tests (need to run in 64-bit else OutOfMemory due to loading many DLLs)
// This is useful when looking for failures due to certain compiler options.
@ -100,7 +101,11 @@ namespace TestApp
//}
//TestUtils.CheckTransformNames();
}
//InferenceEngine.ShowFactorManager(true);
bool showFactorManager = true;
if (showFactorManager)
{
InferenceEngine.ShowFactorManager(true);
}
#if NETFRAMEWORK
logWriter.Dispose();
#endif

Просмотреть файл

@ -465,8 +465,11 @@ namespace Microsoft.ML.Probabilistic.Tests
}
}
private VariableArray<VariableArray<T>, T[][]> Transpose<T>(VariableArray<VariableArray<T>, T[][]> array, VariableArray<VariableArray<int>, int[][]> indices, Range r1,
Range r2)
private VariableArray<VariableArray<T>, T[][]> Transpose<T>(
VariableArray<VariableArray<T>, T[][]> array,
VariableArray<VariableArray<int>, int[][]> indices,
Range r1,
Range r2)
{
return array;
}

Просмотреть файл

@ -1518,7 +1518,7 @@ namespace Microsoft.ML.Probabilistic.Tests
public class VectorIsPositiveEP
{
IGeneratedAlgorithm gen;
readonly IGeneratedAlgorithm gen;
public int NumberOfIterations = 100;
public VectorIsPositiveEP(int dim)

Просмотреть файл

@ -27,7 +27,91 @@ namespace Microsoft.ML.Probabilistic.Tests
public class GatedFactorTests
{
private static bool verbose = false;
private static readonly bool verbose = false;
[Fact]
public void BetaTrueCountIsGamma()
{
BetaTrueCountIsGamma(new ExpectationPropagation());
BetaTrueCountIsGamma(new VariationalMessagePassing());
void BetaTrueCountIsGamma(IAlgorithm algorithm)
{
Variable<bool> evidence = Variable.Bernoulli(0.5).Named("evidence");
var evBlock = Variable.If(evidence);
double shape = 2.5;
double rate = 3.5;
Variable<double> trueCount = Variable.GammaFromShapeAndRate(shape, rate);
trueCount.Name = nameof(trueCount);
Variable<double> x = Variable.Beta(trueCount, 1).Named("x");
x.ObservedValue = 0.25;
evBlock.CloseBlock();
InferenceEngine engine = new InferenceEngine(algorithm);
var trueCountActual = engine.Infer<Gamma>(trueCount);
double xRate = -System.Math.Log(x.ObservedValue);
var trueCountExpected = Gamma.FromShapeAndRate(shape + 1, rate + xRate);
Assert.True(trueCountExpected.MaxDiff(trueCountActual) < 1e-10);
var evActual = engine.Infer<Bernoulli>(evidence).LogOdds;
double evExpected = double.NegativeInfinity;
var trueCounts = EpTests.linspace(1e-6, 10, 1000);
foreach (var trueCountValue in trueCounts)
{
trueCount.ObservedValue = trueCountValue;
double ev = engine.Infer<Bernoulli>(evidence).LogOdds;
evExpected = MMath.LogSumExp(evExpected, ev);
}
double increment = trueCounts[1] - trueCounts[0];
evExpected += System.Math.Log(increment);
Assert.True(MMath.AbsDiff(evExpected, evActual, 1e-8) < 1e-6);
}
}
[Fact]
public void BetaFalseCountIsGamma()
{
BetaFalseCountIsGamma(new ExpectationPropagation());
BetaFalseCountIsGamma(new VariationalMessagePassing());
void BetaFalseCountIsGamma(IAlgorithm algorithm)
{
Variable<bool> evidence = Variable.Bernoulli(0.5).Named("evidence");
var evBlock = Variable.If(evidence);
double shape = 2.5;
double rate = 3.5;
Variable<double> falseCount = Variable.GammaFromShapeAndRate(shape, rate);
falseCount.Name = nameof(falseCount);
Variable<double> x = Variable.Beta(1, falseCount).Named("x");
x.ObservedValue = 0.25;
evBlock.CloseBlock();
InferenceEngine engine = new InferenceEngine(algorithm);
var falseCountActual = engine.Infer<Gamma>(falseCount);
double xRate = -System.Math.Log(1 - x.ObservedValue);
var falseCountExpected = Gamma.FromShapeAndRate(shape + 1, rate + xRate);
Assert.True(falseCountExpected.MaxDiff(falseCountActual) < 1e-10);
var evActual = engine.Infer<Bernoulli>(evidence).LogOdds;
double evExpected = double.NegativeInfinity;
var falseCounts = EpTests.linspace(1e-6, 10, 1000);
foreach (var falseCountValue in falseCounts)
{
falseCount.ObservedValue = falseCountValue;
double ev = engine.Infer<Bernoulli>(evidence).LogOdds;
evExpected = MMath.LogSumExp(evExpected, ev);
}
double increment = falseCounts[1] - falseCounts[0];
evExpected += System.Math.Log(increment);
Assert.True(MMath.AbsDiff(evExpected, evActual, 1e-8) < 1e-6);
}
}
[Fact]
public void TruncatedGaussianIsBetweenTest()
@ -2452,7 +2536,7 @@ namespace Microsoft.ML.Probabilistic.Tests
Variable.ConstrainEqualRandom(x[item], xPrior[item]);
block.CloseBlock();
InferenceEngine engine = new InferenceEngine();
InferenceEngine engine = new InferenceEngine(algorithm);
for (int ctrial = 0; ctrial < 3; ctrial++)
{
Vector cMean = Vector.Zero(dc);

Просмотреть файл

@ -76,8 +76,7 @@ namespace Microsoft.ML.Probabilistic.Tests
/// <returns>Beta-distributed random variable</returns>
public static Variable<double> BetaFromMeanAndTotalCount(Variable<double> mean, Variable<double> totalCount)
{
return Variable<double>.Factor(Factor.BetaFromMeanAndTotalCount, mean, totalCount)
.Attrib(new MarginalPrototype(new Beta(0, 0)));
return Variable<double>.Factor(Factor.BetaFromMeanAndTotalCount, mean, totalCount);
}
[Fact]