зеркало из https://github.com/dotnet/infer.git
Added Variable.CutForwardWhen, Range.IsIncreasing (#243)
Renamed Factor.Cut to Cut.Backward Added InferNet.IsIncreasing, IsIncreasingTransform
This commit is contained in:
Родитель
09d05e617a
Коммит
268353c5bb
|
@ -982,6 +982,7 @@ namespace Microsoft.ML.Probabilistic.Compiler
|
|||
tc.AddTransform(new IterativeProcessTransform(this, algorithm));
|
||||
// LoopMerging is required to support offset indexing (see GateModelTests.CaseLoopIndexTest2)
|
||||
tc.AddTransform(new LoopMergingTransform());
|
||||
tc.AddTransform(new IsIncreasingTransform());
|
||||
// Local is required for DistributedTests
|
||||
tc.AddTransform(new LocalTransform(this));
|
||||
if (OptimiseInferenceCode)
|
||||
|
|
|
@ -39,13 +39,7 @@ namespace Microsoft.ML.Probabilistic.Models
|
|||
// Attributes of the method invocation i.e. the factor or constraint.
|
||||
internal List<ICompilerAttribute> attributes = new List<ICompilerAttribute>();
|
||||
|
||||
// The condition blocks this method is contained in
|
||||
private List<IStatementBlock> containers;
|
||||
|
||||
internal List<IStatementBlock> Containers
|
||||
{
|
||||
get { return containers; }
|
||||
}
|
||||
internal List<IStatementBlock> Containers { get; }
|
||||
|
||||
// Provides global ordering for ModelBuilder
|
||||
internal readonly int timestamp;
|
||||
|
@ -67,7 +61,7 @@ namespace Microsoft.ML.Probabilistic.Models
|
|||
this.timestamp = GetTimestamp();
|
||||
this.method = method;
|
||||
this.args.AddRange(args);
|
||||
this.containers = new List<IStatementBlock>(containers);
|
||||
this.Containers = new List<IStatementBlock>(containers);
|
||||
foreach (IModelExpression arg in args)
|
||||
{
|
||||
if (ReferenceEquals(arg, null)) throw new ArgumentNullException();
|
||||
|
@ -77,7 +71,7 @@ namespace Microsoft.ML.Probabilistic.Models
|
|||
if (v.IsObserved) continue;
|
||||
foreach (ConditionBlock cb in v.GetContainers<ConditionBlock>())
|
||||
{
|
||||
if (!this.containers.Contains(cb))
|
||||
if (!this.Containers.Contains(cb))
|
||||
{
|
||||
throw new InvalidOperationException($"{arg} was created in condition {cb} and cannot be used outside. " +
|
||||
$"To give {arg} a conditional definition, use SetTo inside {cb} rather than assignment (=). " +
|
||||
|
@ -195,7 +189,7 @@ namespace Microsoft.ML.Probabilistic.Models
|
|||
else if (op == Variable.Operator.Equal) return Builder.BinaryExpr(argExprs[0], BinaryOperator.ValueEquality, argExprs[1]);
|
||||
else if (op == Variable.Operator.NotEqual) return Builder.BinaryExpr(argExprs[0], BinaryOperator.ValueInequality, argExprs[1]);
|
||||
}
|
||||
IMethodInvokeExpression imie = null;
|
||||
IMethodInvokeExpression imie;
|
||||
if (method.IsGenericMethod && !method.ContainsGenericParameters)
|
||||
{
|
||||
imie = Builder.StaticGenericMethod(method, argExprs);
|
||||
|
@ -240,7 +234,7 @@ namespace Microsoft.ML.Probabilistic.Models
|
|||
{
|
||||
Set<Range> ranges = new Set<Range>();
|
||||
foreach (IModelExpression arg in returnValueAndArgs()) ForEachRange(arg, ranges.Add);
|
||||
foreach (IStatementBlock b in containers)
|
||||
foreach (IStatementBlock b in Containers)
|
||||
{
|
||||
if (b is HasRange)
|
||||
{
|
||||
|
@ -262,7 +256,7 @@ namespace Microsoft.ML.Probabilistic.Models
|
|||
{
|
||||
ForEachRange(arg, delegate(Range r) { if (!ranges.Contains(r)) ranges.Add(r); });
|
||||
}
|
||||
foreach (IStatementBlock b in containers)
|
||||
foreach (IStatementBlock b in Containers)
|
||||
{
|
||||
if (b is HasRange)
|
||||
{
|
||||
|
|
|
@ -65,7 +65,7 @@ namespace Microsoft.ML.Probabilistic.Models
|
|||
modelType.Owner = null;
|
||||
modelType.BaseType = null;
|
||||
modelType.Visibility = TypeVisibility.Public;
|
||||
modelMethod = Builder.MethodDecl(MethodVisibility.Public, "Model", typeof (void), modelType);
|
||||
modelMethod = Builder.MethodDecl(MethodVisibility.Public, "Model", typeof(void), modelType);
|
||||
IBlockStatement body = Builder.BlockStmt();
|
||||
modelMethod.Body = body;
|
||||
//blocks = new List<IList<IStatement>>();
|
||||
|
@ -114,15 +114,13 @@ namespace Microsoft.ML.Probabilistic.Models
|
|||
List<IModelExpression> exprs = new List<IModelExpression>();
|
||||
foreach (IModelExpression expr in ModelExpressions)
|
||||
{
|
||||
if (expr is Variable)
|
||||
if (expr is Variable var)
|
||||
{
|
||||
Variable var = (Variable) expr;
|
||||
exprs.Add(var);
|
||||
timestamps.Add(var.timestamp);
|
||||
}
|
||||
else if (expr is MethodInvoke)
|
||||
else if (expr is MethodInvoke mi)
|
||||
{
|
||||
MethodInvoke mi = (MethodInvoke) expr;
|
||||
exprs.Add(mi);
|
||||
timestamps.Add(mi.timestamp);
|
||||
}
|
||||
|
@ -183,9 +181,9 @@ namespace Microsoft.ML.Probabilistic.Models
|
|||
{
|
||||
if (var == null) throw new NullReferenceException("Model expression was null.");
|
||||
// Console.WriteLine("Building expression: "+var+" "+builtVars.ContainsKey(var));
|
||||
if (var is MethodInvoke)
|
||||
if (var is MethodInvoke methodInvoke)
|
||||
{
|
||||
BuildMethodInvoke((MethodInvoke)var, null);
|
||||
BuildMethodInvoke(methodInvoke);
|
||||
return;
|
||||
}
|
||||
MethodInfo mb = new Action<IModelExpression<object>>(this.BuildExpression<object>).Method.GetGenericMethodDefinition();
|
||||
|
@ -220,7 +218,7 @@ namespace Microsoft.ML.Probabilistic.Models
|
|||
Type[] faces = expr.GetType().GetInterfaces();
|
||||
foreach (Type face in faces)
|
||||
{
|
||||
if (face.IsGenericType && face.GetGenericTypeDefinition() == typeof (IModelExpression<>))
|
||||
if (face.IsGenericType && face.GetGenericTypeDefinition() == typeof(IModelExpression<>))
|
||||
{
|
||||
domainType = face.GetGenericArguments()[0];
|
||||
break;
|
||||
|
@ -236,14 +234,14 @@ namespace Microsoft.ML.Probabilistic.Models
|
|||
if (expr == null) throw new NullReferenceException("Model expression was null.");
|
||||
// Console.WriteLine("Searching expression: "+var+" "+builtVars.ContainsKey(var));
|
||||
if (searched.Contains(expr)) return;
|
||||
if (expr is MethodInvoke)
|
||||
if (expr is MethodInvoke methodInvoke)
|
||||
{
|
||||
SearchMethodInvoke((MethodInvoke) expr);
|
||||
SearchMethodInvoke(methodInvoke);
|
||||
return;
|
||||
}
|
||||
if (expr is Range)
|
||||
if (expr is Range range)
|
||||
{
|
||||
SearchRange((Range) expr);
|
||||
SearchRange(range);
|
||||
return;
|
||||
}
|
||||
MethodInfo mb = new Action<IModelExpression<object>>(this.SearchExpression<object>).Method.GetGenericMethodDefinition();
|
||||
|
@ -253,7 +251,7 @@ namespace Microsoft.ML.Probabilistic.Models
|
|||
Type[] faces = expr.GetType().GetInterfaces();
|
||||
foreach (Type face in faces)
|
||||
{
|
||||
if (face.IsGenericType && face.GetGenericTypeDefinition() == typeof (IModelExpression<>))
|
||||
if (face.IsGenericType && face.GetGenericTypeDefinition() == typeof(IModelExpression<>))
|
||||
{
|
||||
domainType = face.GetGenericArguments()[0];
|
||||
break;
|
||||
|
@ -284,15 +282,13 @@ namespace Microsoft.ML.Probabilistic.Models
|
|||
/// Add a statement of the form x = f(...) to the MSL.
|
||||
/// </summary>
|
||||
/// <param name="method">Stores the method to call, the argument variables, and target variable.</param>
|
||||
/// <param name="lhs">Stores the name and type of the target variable, if it is not already declared. Otherwise null.</param>
|
||||
///
|
||||
/// <remarks>
|
||||
/// If any variable in the statement is an item variable, then we surround the statement with a loop over its range.
|
||||
/// Since there may be multiple item variables, and each item may depend on multiple ranges, we may end up with multiple loops.
|
||||
/// </remarks>
|
||||
private void BuildMethodInvoke(MethodInvoke method, IExpression lhs)
|
||||
private void BuildMethodInvoke(MethodInvoke method)
|
||||
{
|
||||
if (method.ReturnValue is Variable && ((Variable) method.ReturnValue).Inline) return;
|
||||
if (method.ReturnValue is Variable && ((Variable)method.ReturnValue).Inline) return;
|
||||
// Open containing blocks
|
||||
List<IStatementBlock> stBlocks = method.Containers;
|
||||
List<Range> localRanges = new List<Range>();
|
||||
|
@ -304,7 +300,7 @@ namespace Microsoft.ML.Probabilistic.Models
|
|||
foreach (IModelExpression arg in method.returnValueAndArgs())
|
||||
{
|
||||
MethodInvoke.ForEachRange(arg,
|
||||
delegate(Range r) { if (!localRanges.Contains(r)) localRanges.Add(r); });
|
||||
delegate (Range r) { if (!localRanges.Contains(r)) localRanges.Add(r); });
|
||||
}
|
||||
ParameterInfo[] pis = method.method.GetParameters();
|
||||
for (int i = 0; i < pis.Length; i++)
|
||||
|
@ -320,32 +316,20 @@ namespace Microsoft.ML.Probabilistic.Models
|
|||
{
|
||||
if (b is HasRange)
|
||||
{
|
||||
HasRange br = (HasRange) b;
|
||||
HasRange br = (HasRange)b;
|
||||
localRanges.Remove(br.Range);
|
||||
}
|
||||
}
|
||||
localRanges.Sort(delegate(Range a, Range b) { return MethodInvoke.CompareRanges(dict, a, b); });
|
||||
localRanges.Sort(delegate (Range a, Range b) { return MethodInvoke.CompareRanges(dict, a, b); });
|
||||
// convert from List<Range> to List<IStatementBlock>
|
||||
List<IStatementBlock> localRangeBlocks = new List<IStatementBlock>(localRanges.Select(r => r));
|
||||
BuildStatementBlocks(stBlocks, true);
|
||||
if (lhs != null)
|
||||
{
|
||||
// If there are inline loops, we need to incorporate the variable declaration outside the loops
|
||||
AddStatement(Builder.ExprStatement(lhs));
|
||||
lhs = null;
|
||||
}
|
||||
BuildStatementBlocks(localRangeBlocks, true);
|
||||
|
||||
// Invoke method
|
||||
IExpression methodExpr = method.GetExpression();
|
||||
if (lhs != null)
|
||||
{
|
||||
// If this is the definition of a declared variable, merge the declaration in to the assignment.
|
||||
IAssignExpression iae = (IAssignExpression) methodExpr;
|
||||
iae.Target = lhs;
|
||||
}
|
||||
IStatement st = Builder.ExprStatement(methodExpr);
|
||||
if (methodExpr is IAssignExpression && method.ReturnValue is HasObservedValue && ((HasObservedValue) method.ReturnValue).IsObserved)
|
||||
if (methodExpr is IAssignExpression && method.ReturnValue is HasObservedValue && ((HasObservedValue)method.ReturnValue).IsObserved)
|
||||
{
|
||||
Attributes.Set(st, new Constraint());
|
||||
}
|
||||
|
@ -388,23 +372,23 @@ namespace Microsoft.ML.Probabilistic.Models
|
|||
{
|
||||
if (isb is IfBlock)
|
||||
{
|
||||
IfBlock ib = (IfBlock) isb;
|
||||
IfBlock ib = (IfBlock)isb;
|
||||
var condVar = ib.ConditionVariable;
|
||||
if (!negatedConditionVariables.Contains(condVar))
|
||||
{
|
||||
if (condVar.definition != null)
|
||||
{
|
||||
MethodInvoke mi = condVar.definition;
|
||||
if (mi.method.Equals(new Func<int, int, bool>(Microsoft.ML.Probabilistic.Factors.Factor.AreEqual).Method))
|
||||
if (mi.method.Equals(new Func<int, int, bool>(Factor.AreEqual).Method))
|
||||
{
|
||||
if (mi.Arguments[1] is Variable)
|
||||
{
|
||||
Variable arg1 = (Variable) mi.Arguments[1];
|
||||
Variable arg1 = (Variable)mi.Arguments[1];
|
||||
if (arg1.IsObserved || arg1.IsLoopIndex)
|
||||
{
|
||||
// convert 'if(vbool1)' into 'if(x==value)' where value is observed (or a loop index) and vbool1 is never negated.
|
||||
// if vbool1 is negated, then we cannot make this substitution since we need to match the corresponding 'if(!vbool1)' condition.
|
||||
IConditionStatement ics = (IConditionStatement) ist;
|
||||
IConditionStatement ics = (IConditionStatement)ist;
|
||||
ics.Condition = Builder.BinaryExpr(mi.Arguments[0].GetExpression(), BinaryOperator.ValueEquality, arg1.GetExpression());
|
||||
}
|
||||
}
|
||||
|
@ -422,7 +406,7 @@ namespace Microsoft.ML.Probabilistic.Models
|
|||
///
|
||||
public void SearchExpression<T>(IModelExpression<T> var)
|
||||
{
|
||||
if (var is Variable<T>) SearchVariable<T>((Variable<T>) var);
|
||||
if (var is Variable<T> varT) SearchVariable<T>(varT);
|
||||
else throw new InferCompilerException("Unhandled model expression type: " + var.GetType());
|
||||
}
|
||||
|
||||
|
@ -473,9 +457,11 @@ namespace Microsoft.ML.Probabilistic.Models
|
|||
if (!variable.Inline)
|
||||
{
|
||||
// Determine if the variable should be inlined
|
||||
bool inline = false;
|
||||
bool inline;
|
||||
if (variable.definition != null)
|
||||
{
|
||||
inline = variable.definition.CanBeInlined();
|
||||
}
|
||||
else
|
||||
{
|
||||
inline = (variable.conditionalDefinitions.Values.Count == 1);
|
||||
|
@ -496,7 +482,6 @@ namespace Microsoft.ML.Probabilistic.Models
|
|||
if (variable is IVariableArray)
|
||||
{
|
||||
IVariableArray iva = (IVariableArray)variable;
|
||||
IList<IStatement> sc = Builder.StmtCollection();
|
||||
IList<IVariableDeclaration[]> jaggedIndexVars;
|
||||
IList<IExpression[]> jaggedSizes;
|
||||
GetJaggedArrayIndicesAndSizes(iva, out jaggedIndexVars, out jaggedSizes);
|
||||
|
@ -565,8 +550,53 @@ namespace Microsoft.ML.Probabilistic.Models
|
|||
private void FinishVariable<T>(Variable<T> variable, IAlgorithm alg)
|
||||
{
|
||||
if (variable.IsLoopIndex) return; // do nothing
|
||||
if (variable.IsArrayElement) return;
|
||||
if (variable.Inline) return;
|
||||
|
||||
FinishRandVar(variable, alg);
|
||||
object ivd = variable.GetDeclaration();
|
||||
bool doNotInfer = false;
|
||||
// Add attributes
|
||||
foreach (ICompilerAttribute attr in variable.GetAttributes<ICompilerAttribute>())
|
||||
{
|
||||
if (attr is DoNotInfer) doNotInfer = true;
|
||||
else Attributes.Add(ivd, attr);
|
||||
}
|
||||
foreach (IStatementBlock stBlock in variable.Containers)
|
||||
{
|
||||
if (stBlock is HasRange)
|
||||
{
|
||||
doNotInfer = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
List<IStatementBlock> stBlocks = new List<IStatementBlock>();
|
||||
stBlocks.AddRange(variable.Containers);
|
||||
// Add Infer statement
|
||||
bool isConstant = (variable.IsBase && variable.IsReadOnly);
|
||||
if (!doNotInfer && ((!inferOnlySpecifiedVars && !isConstant) || variablesToInfer.Contains(variable)))
|
||||
{
|
||||
// If there has been no explicit indication of query types for inference, set the
|
||||
// default types
|
||||
List<QueryTypeCompilerAttribute> qtlist = Attributes.GetAll<QueryTypeCompilerAttribute>(ivd);
|
||||
if (qtlist.Count == 0)
|
||||
{
|
||||
alg.ForEachDefaultQueryType(qt => Attributes.Add(ivd, new QueryTypeCompilerAttribute(qt)));
|
||||
qtlist = Attributes.GetAll<QueryTypeCompilerAttribute>(ivd);
|
||||
}
|
||||
variablesToInfer.Add(variable);
|
||||
BuildStatementBlocks(stBlocks, true);
|
||||
IExpression varExpr = variable.GetExpression();
|
||||
IExpression varName = Builder.LiteralExpr(variable.NameInGeneratedCode);
|
||||
foreach (QueryTypeCompilerAttribute qt in qtlist)
|
||||
{
|
||||
IExpression queryExpr = Builder.FieldRefExpr(Builder.TypeRefExpr(typeof(QueryTypes)), typeof(QueryTypes), qt.QueryType.Name);
|
||||
// for a constant, we must get the variable reference, not the value
|
||||
if (isConstant) varExpr = Builder.VarRefExpr((IVariableDeclaration)variable.GetDeclaration());
|
||||
AddStatement(Builder.ExprStatement(
|
||||
Builder.StaticMethod(new Action<object>(InferNet.Infer), varExpr, varName, queryExpr)));
|
||||
}
|
||||
BuildStatementBlocks(stBlocks, false);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -576,7 +606,7 @@ namespace Microsoft.ML.Probabilistic.Models
|
|||
/// <param name="expr">The variable expression</param>
|
||||
private void BuildExpression<T>(IModelExpression<T> expr)
|
||||
{
|
||||
if (expr is Variable<T>) BuildVariable<T>((Variable<T>) expr);
|
||||
if (expr is Variable<T> var) BuildVariable<T>(var);
|
||||
else throw new InferCompilerException("Unhandled model expression type: " + expr.GetType());
|
||||
}
|
||||
|
||||
|
@ -606,7 +636,7 @@ namespace Microsoft.ML.Probabilistic.Models
|
|||
|
||||
private bool ShouldInlineConstant<T>(Variable<T> constant)
|
||||
{
|
||||
return (Quoter.ShouldInlineType(typeof (T)) && (!constant.IsDefined) && !variablesToInfer.Contains(constant));
|
||||
return (Quoter.ShouldInlineType(typeof(T)) && (!constant.IsDefined) && !variablesToInfer.Contains(constant));
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
@ -637,9 +667,9 @@ namespace Microsoft.ML.Probabilistic.Models
|
|||
if (!useExisting)
|
||||
{
|
||||
// create a new declaration
|
||||
ivd = (IVariableDeclaration) constant.GetDeclaration();
|
||||
ivd = (IVariableDeclaration)constant.GetDeclaration();
|
||||
var rhs = Quoter.Quote(constant.ObservedValue);
|
||||
if (ReferenceEquals(constant.ObservedValue, null)) rhs = Builder.CastExpr(rhs, typeof (T));
|
||||
if (ReferenceEquals(constant.ObservedValue, null)) rhs = Builder.CastExpr(rhs, typeof(T));
|
||||
AddStatement(Builder.AssignStmt(Builder.VarDeclExpr(ivd), rhs));
|
||||
constants[key] = ivd;
|
||||
}
|
||||
|
@ -692,7 +722,7 @@ namespace Microsoft.ML.Probabilistic.Models
|
|||
private void SearchRange(Range range)
|
||||
{
|
||||
if (searched.Contains(range)) return;
|
||||
string name = ((IModelExpression) range).Name;
|
||||
string name = ((IModelExpression)range).Name;
|
||||
foreach (IModelExpression expr in searched)
|
||||
{
|
||||
if (name.Equals(expr.Name))
|
||||
|
@ -716,9 +746,9 @@ namespace Microsoft.ML.Probabilistic.Models
|
|||
{
|
||||
var ds = (DistributedSchedule)attr;
|
||||
toSearch.Push(ds.commExpression);
|
||||
if(ds.scheduleExpression != null)
|
||||
if (ds.scheduleExpression != null)
|
||||
toSearch.Push(ds.scheduleExpression);
|
||||
if(ds.schedulePerThreadExpression != null)
|
||||
if (ds.schedulePerThreadExpression != null)
|
||||
toSearch.Push(ds.schedulePerThreadExpression);
|
||||
var attr2 = new DistributedScheduleExpression(ds.commExpression.GetExpression(), ds.scheduleExpression?.GetExpression(), ds.schedulePerThreadExpression?.GetExpression());
|
||||
Attributes.Set(ivd, attr2);
|
||||
|
@ -740,28 +770,26 @@ namespace Microsoft.ML.Probabilistic.Models
|
|||
{
|
||||
foreach (IStatementBlock sb in containers)
|
||||
{
|
||||
if (sb is ConditionBlock)
|
||||
if (sb is ConditionBlock cb)
|
||||
{
|
||||
if (sb is SwitchBlock)
|
||||
if (cb is SwitchBlock swb)
|
||||
{
|
||||
SearchRange(((SwitchBlock) sb).Range);
|
||||
SearchRange(swb.Range);
|
||||
}
|
||||
ConditionBlock cb = (ConditionBlock) sb;
|
||||
Variable condVar = cb.ConditionVariableUntyped;
|
||||
if (cb is IfBlock)
|
||||
if (cb is IfBlock ib)
|
||||
{
|
||||
IfBlock ib = (IfBlock) cb;
|
||||
if (ib.ConditionValue == false) negatedConditionVariables.Add(condVar);
|
||||
}
|
||||
toSearch.Push(condVar);
|
||||
}
|
||||
else if (sb is ForEachBlock)
|
||||
else if (sb is ForEachBlock fb)
|
||||
{
|
||||
SearchRange(((ForEachBlock) sb).Range);
|
||||
SearchRange(fb.Range);
|
||||
}
|
||||
else if (sb is RepeatBlock)
|
||||
else if (sb is RepeatBlock rb)
|
||||
{
|
||||
toSearch.Push(((RepeatBlock) sb).Count);
|
||||
toSearch.Push(rb.Count);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -814,7 +842,7 @@ namespace Microsoft.ML.Probabilistic.Models
|
|||
}
|
||||
return;
|
||||
}
|
||||
IVariableDeclaration ivd = (IVariableDeclaration) variable.GetDeclaration();
|
||||
IVariableDeclaration ivd = (IVariableDeclaration)variable.GetDeclaration();
|
||||
if (variable.initialiseTo != null)
|
||||
{
|
||||
Attributes.Set(ivd, new InitialiseTo(variable.initialiseTo.GetExpression()));
|
||||
|
@ -837,9 +865,8 @@ namespace Microsoft.ML.Probabilistic.Models
|
|||
Set<IVariableDeclaration> loopVars = new Set<IVariableDeclaration>();
|
||||
foreach (IStatementBlock stBlock in stBlocks)
|
||||
{
|
||||
if (stBlock is ForEachBlock)
|
||||
if (stBlock is ForEachBlock fb)
|
||||
{
|
||||
ForEachBlock fb = (ForEachBlock) stBlock;
|
||||
IVariableDeclaration loopVar = fb.Range.GetIndexDeclaration();
|
||||
if (loopVars.Contains(loopVar))
|
||||
throw new InvalidOperationException("Variable '" + ivd.Name + "' uses range '" + loopVar.Name + "' twice. Use a cloned range instead.");
|
||||
|
@ -881,53 +908,6 @@ namespace Microsoft.ML.Probabilistic.Models
|
|||
|
||||
protected void FinishRandVar<T>(Variable<T> variable, IAlgorithm alg)
|
||||
{
|
||||
if (variable.IsArrayElement) return;
|
||||
if (variable.Inline) return;
|
||||
|
||||
object ivd = variable.GetDeclaration();
|
||||
bool doNotInfer = false;
|
||||
// Add attributes
|
||||
foreach (ICompilerAttribute attr in variable.GetAttributes<ICompilerAttribute>())
|
||||
{
|
||||
if (attr is DoNotInfer) doNotInfer = true;
|
||||
else Attributes.Add(ivd, attr);
|
||||
}
|
||||
foreach (IStatementBlock stBlock in variable.Containers)
|
||||
{
|
||||
if (stBlock is HasRange)
|
||||
{
|
||||
doNotInfer = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
List<IStatementBlock> stBlocks = new List<IStatementBlock>();
|
||||
stBlocks.AddRange(variable.Containers);
|
||||
// Add Infer statement
|
||||
bool isConstant = (variable.IsBase && variable.IsReadOnly);
|
||||
if (!doNotInfer && ((!inferOnlySpecifiedVars && !isConstant) || variablesToInfer.Contains(variable)))
|
||||
{
|
||||
// If there has been no explicit indication of query types for inference, set the
|
||||
// default types
|
||||
List<QueryTypeCompilerAttribute> qtlist = Attributes.GetAll<QueryTypeCompilerAttribute>(ivd);
|
||||
if (qtlist.Count == 0)
|
||||
{
|
||||
alg.ForEachDefaultQueryType(qt => Attributes.Add(ivd, new QueryTypeCompilerAttribute(qt)));
|
||||
qtlist = Attributes.GetAll<QueryTypeCompilerAttribute>(ivd);
|
||||
}
|
||||
variablesToInfer.Add(variable);
|
||||
BuildStatementBlocks(stBlocks, true);
|
||||
IExpression varExpr = variable.GetExpression();
|
||||
IExpression varName = Builder.LiteralExpr(variable.NameInGeneratedCode);
|
||||
foreach (QueryTypeCompilerAttribute qt in qtlist)
|
||||
{
|
||||
IExpression queryExpr = Builder.FieldRefExpr(Builder.TypeRefExpr(typeof (QueryTypes)), typeof (QueryTypes), qt.QueryType.Name);
|
||||
// for a constant, we must get the variable reference, not the value
|
||||
if (isConstant) varExpr = Builder.VarRefExpr((IVariableDeclaration)variable.GetDeclaration());
|
||||
AddStatement(Builder.ExprStatement(
|
||||
Builder.StaticMethod(new Action<object>(InferNet.Infer), varExpr, varName, queryExpr)));
|
||||
}
|
||||
BuildStatementBlocks(stBlocks, false);
|
||||
}
|
||||
}
|
||||
|
||||
protected void GetJaggedArrayIndicesAndSizes(IVariableArray array, out IList<IVariableDeclaration[]> jaggedIndexVars, out IList<IExpression[]> jaggedSizes)
|
||||
|
@ -955,10 +935,10 @@ namespace Microsoft.ML.Probabilistic.Models
|
|||
}
|
||||
jaggedIndexVars.Add(indexVars);
|
||||
jaggedSizes.Add(sizes);
|
||||
if (array is IVariableJaggedArray)
|
||||
if (array is IVariableJaggedArray variableJaggedArray)
|
||||
{
|
||||
IVariable itemPrototype = ((IVariableJaggedArray) array).ItemPrototype;
|
||||
if (itemPrototype is IVariableArray) array = (IVariableArray) itemPrototype;
|
||||
IVariable itemPrototype = variableJaggedArray.ItemPrototype;
|
||||
if (itemPrototype is IVariableArray variableArray) array = variableArray;
|
||||
else break;
|
||||
}
|
||||
else break;
|
||||
|
|
|
@ -106,7 +106,7 @@ namespace Microsoft.ML.Probabilistic.Models
|
|||
{
|
||||
foreach (ICompilerAttribute attr in attributes)
|
||||
{
|
||||
if (attr is AttributeType) yield return (AttributeType) attr;
|
||||
if (attr is AttributeType attributeType) yield return attributeType;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -161,11 +161,13 @@ namespace Microsoft.ML.Probabilistic.Models
|
|||
{
|
||||
get
|
||||
{
|
||||
if (!(Size is Variable<int>)) throw new InvalidOperationException("The Range does not have constant size. Set IsReadOnly=true on the range size.");
|
||||
Variable<int> sizeVar = (Variable<int>) Size;
|
||||
if (!(sizeVar.IsObserved && sizeVar.IsReadOnly))
|
||||
throw new InvalidOperationException("The Range does not have constant size. To use SizeAsInt, set IsReadOnly=true on the range size.");
|
||||
return sizeVar.ObservedValue;
|
||||
if (Size is Variable<int> sizeVar)
|
||||
{
|
||||
if (!(sizeVar.IsObserved && sizeVar.IsReadOnly))
|
||||
throw new InvalidOperationException("The Range does not have constant size. To use SizeAsInt, set IsReadOnly=true on the range size.");
|
||||
return sizeVar.ObservedValue;
|
||||
}
|
||||
else throw new InvalidOperationException("The Range does not have constant size. Set IsReadOnly=true on the range size.");
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -200,7 +202,7 @@ namespace Microsoft.ML.Probabilistic.Models
|
|||
|
||||
internal IVariableDeclaration GetIndexDeclaration()
|
||||
{
|
||||
if (index == null) index = Builder.VarDecl(NameInGeneratedCode, typeof (int));
|
||||
if (index == null) index = Builder.VarDecl(NameInGeneratedCode, typeof(int));
|
||||
return index;
|
||||
}
|
||||
|
||||
|
@ -233,6 +235,18 @@ namespace Microsoft.ML.Probabilistic.Models
|
|||
return root;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Get an expression that evaluates to true when this loop counter is increasing in the currently executing loop.
|
||||
/// </summary>
|
||||
/// <returns></returns>
|
||||
public Variable<bool> IsIncreasing()
|
||||
{
|
||||
Variable<bool> v = new Variable<bool>();
|
||||
v.SetTo(new Func<int, bool>(Factors.InferNet.IsIncreasing).Method, this);
|
||||
v.Inline = true;
|
||||
return v;
|
||||
}
|
||||
|
||||
private static Range ReplaceExpressions(Range r, Dictionary<IModelExpression, IModelExpression> replacements)
|
||||
{
|
||||
IModelExpression<int> newSize = (IModelExpression<int>)ReplaceExpressions(r.Size, replacements);
|
||||
|
@ -247,17 +261,16 @@ namespace Microsoft.ML.Probabilistic.Models
|
|||
private static IModelExpression ReplaceExpressions(IModelExpression expr, Dictionary<IModelExpression, IModelExpression> replacements)
|
||||
{
|
||||
if (replacements.ContainsKey(expr)) return replacements[expr];
|
||||
if (expr is Range)
|
||||
if (expr is Range range)
|
||||
{
|
||||
return ReplaceExpressions((Range)expr, replacements);
|
||||
return ReplaceExpressions(range, replacements);
|
||||
}
|
||||
else if (expr is Variable)
|
||||
else if (expr is Variable v)
|
||||
{
|
||||
Variable v = (Variable) expr;
|
||||
if (v.IsArrayElement)
|
||||
{
|
||||
bool changed = false;
|
||||
IVariableArray newArray = (IVariableArray) ReplaceExpressions(v.ArrayVariable, replacements);
|
||||
IVariableArray newArray = (IVariableArray)ReplaceExpressions(v.ArrayVariable, replacements);
|
||||
if (!ReferenceEquals(newArray, v.ArrayVariable)) changed = true;
|
||||
IModelExpression[] newIndices = new IModelExpression[v.indices.Count];
|
||||
for (int i = 0; i < newIndices.Length; i++)
|
||||
|
@ -294,13 +307,12 @@ namespace Microsoft.ML.Probabilistic.Models
|
|||
/// <exclude/>
|
||||
internal bool IsCompatibleWith(IModelExpression index)
|
||||
{
|
||||
if (index is Range) return (((Range) index).GetRoot() == GetRoot());
|
||||
else if (index is Variable)
|
||||
if (index is Range range) return (range.GetRoot() == GetRoot());
|
||||
else if (index is Variable indexVar)
|
||||
{
|
||||
Variable indexVar = (Variable) index;
|
||||
Range range = indexVar.GetValueRange(false);
|
||||
if (range == null) return true;
|
||||
return IsCompatibleWith(range);
|
||||
Range valueRange = indexVar.GetValueRange(false);
|
||||
if (valueRange == null) return true;
|
||||
return IsCompatibleWith(valueRange);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
|
|
@ -3746,6 +3746,22 @@ namespace Microsoft.ML.Probabilistic.Models
|
|||
return first;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Creates a copy of the argument where the forward message is uniform when <paramref name="shouldCut"/> is true. Used to control inference.
|
||||
/// </summary>
|
||||
/// <typeparam name="T"></typeparam>
|
||||
/// <param name="x"></param>
|
||||
/// <param name="shouldCut"></param>
|
||||
/// <returns></returns>
|
||||
public static Variable<T> CutForwardWhen<T>(Variable<T> x, Variable<bool> shouldCut)
|
||||
{
|
||||
Variable<T> result = Variable<T>.Factor(Factors.Cut.ForwardWhen, x, shouldCut);
|
||||
Range valueRange = x.GetValueRange(false);
|
||||
if (valueRange != null)
|
||||
result.AddAttribute(new ValueRange(valueRange));
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns a cut of the argument. Cut is equivalent to random(infer()).
|
||||
/// </summary>
|
||||
|
@ -3755,7 +3771,7 @@ namespace Microsoft.ML.Probabilistic.Models
|
|||
/// <remarks>Cut allows forward messages to pass through unchanged, whereas backward messages are cut off.</remarks>
|
||||
public static Variable<T> Cut<T>(Variable<T> x)
|
||||
{
|
||||
Variable<T> result = Variable<T>.Factor(Factor.Cut<T>, x);
|
||||
Variable<T> result = Variable<T>.Factor(Factors.Cut.Backward, x);
|
||||
Range valueRange = x.GetValueRange(false);
|
||||
if (valueRange != null)
|
||||
result.AddAttribute(new ValueRange(valueRange));
|
||||
|
|
|
@ -631,9 +631,8 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
protected override IExpression ConvertAssign(IAssignExpression iae)
|
||||
{
|
||||
iae = (IAssignExpression)base.ConvertAssign(iae);
|
||||
if (iae.Expression is IMethodInvokeExpression)
|
||||
if (iae.Expression is IMethodInvokeExpression imie)
|
||||
{
|
||||
IMethodInvokeExpression imie = (IMethodInvokeExpression)iae.Expression;
|
||||
IVariableDeclaration ivd = Recognizer.GetVariableDeclaration(iae.Target);
|
||||
if (ivd != null)
|
||||
{
|
||||
|
|
|
@ -0,0 +1,58 @@
|
|||
// Licensed to the .NET Foundation under one or more agreements.
|
||||
// The .NET Foundation licenses this file to you under the MIT license.
|
||||
// See the LICENSE file in the project root for more information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using Microsoft.ML.Probabilistic.Compiler.Attributes;
|
||||
using Microsoft.ML.Probabilistic.Collections;
|
||||
using Microsoft.ML.Probabilistic.Compiler;
|
||||
using Microsoft.ML.Probabilistic.Compiler.CodeModel;
|
||||
using Microsoft.ML.Probabilistic.Utilities;
|
||||
using System.Linq;
|
||||
using Microsoft.ML.Probabilistic.Models.Attributes;
|
||||
|
||||
namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
||||
{
|
||||
/// <summary>
|
||||
/// Convert occurrences of <c>IsIncreasing(i)</c> into literal boolean constants.
|
||||
/// </summary>
|
||||
internal class IsIncreasingTransform : ShallowCopyTransform
|
||||
{
|
||||
public override string Name
|
||||
{
|
||||
get
|
||||
{
|
||||
return nameof(IsIncreasingTransform);
|
||||
}
|
||||
}
|
||||
|
||||
HashSet<string> backwardLoops = new HashSet<string>();
|
||||
|
||||
protected override IStatement ConvertFor(IForStatement ifs)
|
||||
{
|
||||
string toRemove = null;
|
||||
if(!Recognizer.IsForwardLoop(ifs))
|
||||
{
|
||||
var loopVar = Recognizer.LoopVariable(ifs);
|
||||
toRemove = loopVar.Name;
|
||||
backwardLoops.Add(toRemove);
|
||||
}
|
||||
var result = base.ConvertFor(ifs);
|
||||
if (toRemove != null) backwardLoops.Remove(toRemove);
|
||||
return result;
|
||||
}
|
||||
|
||||
protected override IExpression ConvertMethodInvoke(IMethodInvokeExpression imie)
|
||||
{
|
||||
if(CodeRecognizer.IsIsIncreasing(imie))
|
||||
{
|
||||
IExpression arg = imie.Arguments[0];
|
||||
IVariableDeclaration ivd = Recognizer.GetVariableDeclaration(arg);
|
||||
if (backwardLoops.Contains(ivd.Name)) return Builder.LiteralExpr(false);
|
||||
else return Builder.LiteralExpr(true);
|
||||
}
|
||||
return base.ConvertMethodInvoke(imie);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -206,6 +206,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
protected override IExpression ConvertMethodInvoke(IMethodInvokeExpression imie)
|
||||
{
|
||||
if (CodeRecognizer.IsInfer(imie)) return ConvertInfer(imie);
|
||||
if (CodeRecognizer.IsIsIncreasing(imie)) return imie;
|
||||
if (context.FindAncestor<IExpressionStatement>() == null) return imie;
|
||||
IExpression expr = imie;
|
||||
IAssignExpression iae = context.FindAncestor<IAssignExpression>();
|
||||
|
|
|
@ -239,8 +239,8 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
if (converted is IMethodInvokeExpression)
|
||||
{
|
||||
var mie = (IMethodInvokeExpression)converted;
|
||||
bool isAnd = Recognizer.IsStaticMethod(converted, new Func<bool, bool, bool>(Microsoft.ML.Probabilistic.Factors.Factor.And));
|
||||
bool isOr = Recognizer.IsStaticMethod(converted, new Func<bool, bool, bool>(Microsoft.ML.Probabilistic.Factors.Factor.Or));
|
||||
bool isAnd = Recognizer.IsStaticMethod(converted, new Func<bool, bool, bool>(Factors.Factor.And));
|
||||
bool isOr = Recognizer.IsStaticMethod(converted, new Func<bool, bool, bool>(Factors.Factor.Or));
|
||||
bool anyArgumentIsLiteral = mie.Arguments.Any(arg => arg is ILiteralExpression);
|
||||
if (anyArgumentIsLiteral)
|
||||
{
|
||||
|
@ -262,7 +262,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
if (reducedArguments.Count() == 1) return reducedArguments.First();
|
||||
else return Builder.LiteralExpr(false);
|
||||
}
|
||||
else if (Recognizer.IsStaticMethod(converted, new Func<bool, bool>(Microsoft.ML.Probabilistic.Factors.Factor.Not)))
|
||||
else if (Recognizer.IsStaticMethod(converted, new Func<bool, bool>(Factors.Factor.Not)))
|
||||
{
|
||||
bool allArgumentsAreLiteral = mie.Arguments.All(arg => arg is ILiteralExpression);
|
||||
if (allArgumentsAreLiteral)
|
||||
|
|
|
@ -896,7 +896,8 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
else if (
|
||||
Recognizer.IsStaticGenericMethod(imie, new Func<PlaceHolder, PlaceHolder>(Factor.Copy)) ||
|
||||
Recognizer.IsStaticGenericMethod(imie, new Func<PlaceHolder, PlaceHolder>(Diode.Copy)) ||
|
||||
Recognizer.IsStaticGenericMethod(imie, new Func<PlaceHolder, PlaceHolder>(Factor.Cut)) ||
|
||||
Recognizer.IsStaticGenericMethod(imie, new Func<PlaceHolder, PlaceHolder>(Cut.Backward)) ||
|
||||
Recognizer.IsStaticGenericMethod(imie, new Func<PlaceHolder, bool, PlaceHolder>(Cut.ForwardWhen)) ||
|
||||
Recognizer.IsStaticGenericMethod(imie, new Func<PlaceHolder, double, PlaceHolder>(PowerPlate.Enter)) ||
|
||||
Recognizer.IsStaticGenericMethod(imie, new Func<PlaceHolder, double, PlaceHolder>(Damp.Forward)) ||
|
||||
Recognizer.IsStaticGenericMethod(imie, new Func<PlaceHolder, double, PlaceHolder>(Damp.Backward)) ||
|
||||
|
|
|
@ -2338,6 +2338,11 @@ namespace Microsoft.ML.Probabilistic.Compiler
|
|||
Instance.IsStaticMethod(expr, new Action<object, string, QueryType>(InferNet.Infer));
|
||||
}
|
||||
|
||||
internal static bool IsIsIncreasing(IExpression expr)
|
||||
{
|
||||
return Instance.IsStaticMethod(expr, new Func<int,bool>(InferNet.IsIncreasing));
|
||||
}
|
||||
|
||||
internal static FactorManager.FactorInfo GetFactorInfo(BasicTransformContext context, IMethodInvokeExpression imie)
|
||||
{
|
||||
if (!context.InputAttributes.Has<FactorManager.FactorInfo>(imie))
|
||||
|
|
|
@ -4,13 +4,6 @@
|
|||
|
||||
namespace Microsoft.ML.Probabilistic.Factors
|
||||
{
|
||||
using Microsoft.ML.Probabilistic;
|
||||
using System;
|
||||
|
||||
#if SUPPRESS_XMLDOC_WARNINGS
|
||||
#pragma warning disable 1591
|
||||
#endif
|
||||
|
||||
/// <summary>
|
||||
/// Class used in MSL only.
|
||||
/// </summary>
|
||||
|
@ -18,37 +11,40 @@ namespace Microsoft.ML.Probabilistic.Factors
|
|||
public static class InferNet
|
||||
{
|
||||
/// <summary>
|
||||
/// For use in MSL only.
|
||||
/// Used in MSL to indicate that a variable will be inferred.
|
||||
/// </summary>
|
||||
/// <param name="obj"></param>
|
||||
/// <param name="obj">A variable reference expression</param>
|
||||
public static void Infer(object obj)
|
||||
{
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Used in MSL to indicate that a variable will be inferred under a specific name.
|
||||
/// </summary>
|
||||
/// <param name="obj">A variable reference expression</param>
|
||||
/// <param name="name">The external name of the variable</param>
|
||||
public static void Infer(object obj, string name)
|
||||
{
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Used in MSL to indicate that a variable will be inferred under a specific name and query type.
|
||||
/// </summary>
|
||||
/// <param name="obj">A variable reference expression</param>
|
||||
/// <param name="name">The external name of the variable</param>
|
||||
/// <param name="query">The query type</param>
|
||||
public static void Infer(object obj, string name, QueryType query)
|
||||
{
|
||||
}
|
||||
|
||||
/*public class Power : IDisposable
|
||||
/// <summary>
|
||||
/// Used in MSL to indicate that the loop counter is increasing in the currently executing loop.
|
||||
/// </summary>
|
||||
/// <param name="loopCounter"></param>
|
||||
/// <returns></returns>
|
||||
public static bool IsIncreasing(int loopCounter)
|
||||
{
|
||||
double power;
|
||||
|
||||
public Power(double p)
|
||||
{
|
||||
power = p;
|
||||
}
|
||||
|
||||
public void Dispose()
|
||||
{
|
||||
}
|
||||
}*/
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
#if SUPPRESS_XMLDOC_WARNINGS
|
||||
#pragma warning restore 1591
|
||||
#endif
|
||||
}
|
|
@ -6,10 +6,11 @@ namespace Microsoft.ML.Probabilistic.Factors
|
|||
{
|
||||
using Microsoft.ML.Probabilistic.Distributions;
|
||||
using Microsoft.ML.Probabilistic.Factors.Attributes;
|
||||
using Microsoft.ML.Probabilistic.Math;
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="CutOp{T}"]/doc/*'/>
|
||||
// /// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="CutOp{T}"]/doc/*'/>
|
||||
/// <typeparam name="T">The type of the variable being copied.</typeparam>
|
||||
[FactorMethod(typeof(Factor), "Cut<>")]
|
||||
[FactorMethod(typeof(Cut), "Backward<>")]
|
||||
[Quality(QualityBand.Preview)]
|
||||
public static class CutOp<T>
|
||||
{
|
||||
|
@ -23,23 +24,23 @@ namespace Microsoft.ML.Probabilistic.Factors
|
|||
return result;
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="CutOp{T}"]/message_doc[@name="CutAverageConditional{TDist}(TDist)"]/*'/>
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="CutOp{T}"]/message_doc[@name="BackwardAverageConditional{TDist}(TDist)"]/*'/>
|
||||
/// <typeparam name="TDist">The type of the distribution over the variable being copied.</typeparam>
|
||||
public static TDist CutAverageConditional<TDist>([IsReturned] TDist Value)
|
||||
public static TDist BackwardAverageConditional<TDist>([IsReturned] TDist Value)
|
||||
where TDist : IDistribution<T>
|
||||
{
|
||||
return Value;
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="CutOp{T}"]/message_doc[@name="CutAverageConditional(T)"]/*'/>
|
||||
public static T CutAverageConditional([IsReturned] T Value)
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="CutOp{T}"]/message_doc[@name="BackwardAverageConditional(T)"]/*'/>
|
||||
public static T BackwardAverageConditional([IsReturned] T Value)
|
||||
{
|
||||
return Value;
|
||||
}
|
||||
|
||||
// VMP /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="CutOp{T}"]/message_doc[@name="ValueAverageConditional{TDist}(TDist)"]/*'/>
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="CutOp{T}"]/message_doc[@name="ValueAverageLogarithm{TDist}(TDist)"]/*'/>
|
||||
/// <typeparam name="TDist">The type of the distribution over the variable being copied.</typeparam>
|
||||
[Skip]
|
||||
public static TDist ValueAverageLogarithm<TDist>(TDist result)
|
||||
|
@ -49,18 +50,105 @@ namespace Microsoft.ML.Probabilistic.Factors
|
|||
return result;
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="CutOp{T}"]/message_doc[@name="CutAverageConditional{TDist}(TDist)"]/*'/>
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="CutOp{T}"]/message_doc[@name="BackwardAverageLogarithm{TDist}(TDist)"]/*'/>
|
||||
/// <typeparam name="TDist">The type of the distribution over the variable being copied.</typeparam>
|
||||
public static TDist CutAverageLogarithm<TDist>([IsReturned] TDist Value)
|
||||
public static TDist BackwardAverageLogarithm<TDist>([IsReturned] TDist Value)
|
||||
where TDist : IDistribution<T>
|
||||
{
|
||||
return Value;
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="CutOp{T}"]/message_doc[@name="CutAverageConditional(T)"]/*'/>
|
||||
public static T CutAverageLogarithm([IsReturned] T Value)
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="CutOp{T}"]/message_doc[@name="BackwardAverageLogarithm(T)"]/*'/>
|
||||
public static T BackwardAverageLogarithm([IsReturned] T Value)
|
||||
{
|
||||
return Value;
|
||||
}
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="CutForwardWhenOp{T}"]/doc/*'/>
|
||||
/// <typeparam name="T">The type of the variable being copied.</typeparam>
|
||||
[FactorMethod(typeof(Cut), "ForwardWhen<>")]
|
||||
[Quality(QualityBand.Preview)]
|
||||
public static class CutForwardWhenOp<T>
|
||||
{
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="CutForwardWhenOp{T}"]/message_doc[@name="ValueAverageConditional{TDist}(TDist)"]/*'/>
|
||||
/// <typeparam name="TDist">The type of the distribution over the variable being copied.</typeparam>
|
||||
public static TDist ValueAverageConditional<TDist>([IsReturned] TDist forwardWhen)
|
||||
where TDist : IDistribution<T>
|
||||
{
|
||||
return forwardWhen;
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="CutForwardWhenOp{T}"]/message_doc[@name="ForwardWhenAverageConditional{TDist}(TDist,bool,TDist)"]/*'/>
|
||||
/// <typeparam name="TDist">The type of the distribution over the variable being copied.</typeparam>
|
||||
public static TDist ForwardWhenAverageConditional<TDist>(TDist Value, bool shouldCut, TDist result)
|
||||
where TDist : IDistribution<T>, SettableTo<TDist>
|
||||
{
|
||||
if (shouldCut) result.SetToUniform();
|
||||
else result.SetTo(Value);
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="CutForwardWhenOp{T}"]/message_doc[@name="ForwardWhenAverageConditional(T)"]/*'/>
|
||||
public static T ForwardWhenAverageConditional([IsReturned] T Value)
|
||||
{
|
||||
return Value;
|
||||
}
|
||||
|
||||
// VMP /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="CutForwardWhenOp{T}"]/message_doc[@name="ValueAverageLogarithm{TDist}(TDist)"]/*'/>
|
||||
/// <typeparam name="TDist">The type of the distribution over the variable being copied.</typeparam>
|
||||
[Skip]
|
||||
public static TDist ValueAverageLogarithm<TDist>(TDist result)
|
||||
where TDist : IDistribution<T>
|
||||
{
|
||||
result.SetToUniform();
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="CutForwardWhenOp{T}"]/message_doc[@name="ForwardWhenAverageLogarithm{TDist}(TDist)"]/*'/>
|
||||
/// <typeparam name="TDist">The type of the distribution over the variable being copied.</typeparam>
|
||||
public static TDist ForwardWhenAverageLogarithm<TDist>([IsReturned] TDist Value)
|
||||
where TDist : IDistribution<T>
|
||||
{
|
||||
return Value;
|
||||
}
|
||||
|
||||
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="CutForwardWhenOp{T}"]/message_doc[@name="ForwardWhenAverageLogarithm(T)"]/*'/>
|
||||
public static T ForwardWhenAverageLogarithm([IsReturned] T Value)
|
||||
{
|
||||
return Value;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Cut factor methods
|
||||
/// </summary>
|
||||
[Hidden]
|
||||
public static class Cut
|
||||
{
|
||||
/// <summary>
|
||||
/// Copy a value and cut the backward message (it will always be uniform).
|
||||
/// </summary>
|
||||
/// <typeparam name="T">The type the input value.</typeparam>
|
||||
/// <param name="value">The value to return.</param>
|
||||
/// <returns>The supplied value.</returns>
|
||||
public static T Backward<T>([SkipIfUniform] T value)
|
||||
{
|
||||
return value;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Copy a value and cut the forward message when <paramref name="shouldCut"/> is true.
|
||||
/// </summary>
|
||||
/// <typeparam name="T"></typeparam>
|
||||
/// <param name="value"></param>
|
||||
/// <param name="shouldCut"></param>
|
||||
/// <returns></returns>
|
||||
public static T ForwardWhen<T>([IsReturned] T value, bool shouldCut)
|
||||
{
|
||||
return value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1302,18 +1302,6 @@ namespace Microsoft.ML.Probabilistic.Factors
|
|||
return value;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Passes the input through to the output. Used to set backward messages to uniform.
|
||||
/// </summary>
|
||||
/// <typeparam name="T">The type the input value.</typeparam>
|
||||
/// <param name="value">The value to return.</param>
|
||||
/// <returns>The supplied value.</returns>
|
||||
[Hidden]
|
||||
public static T Cut<T>([SkipIfUniform] T value)
|
||||
{
|
||||
return value;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Generate a jagged array of Gaussian random variables.
|
||||
/// </summary>
|
||||
|
|
Загрузка…
Ссылка в новой задаче