Added Variable.CutForwardWhen, Range.IsIncreasing (#243)

Renamed Factor.Cut to Cut.Backward
Added InferNet.IsIncreasing, IsIncreasingTransform
This commit is contained in:
Tom Minka 2020-05-01 11:24:59 +01:00 коммит произвёл GitHub
Родитель 09d05e617a
Коммит 268353c5bb
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
14 изменённых файлов: 335 добавлений и 196 удалений

Просмотреть файл

@ -982,6 +982,7 @@ namespace Microsoft.ML.Probabilistic.Compiler
tc.AddTransform(new IterativeProcessTransform(this, algorithm));
// LoopMerging is required to support offset indexing (see GateModelTests.CaseLoopIndexTest2)
tc.AddTransform(new LoopMergingTransform());
tc.AddTransform(new IsIncreasingTransform());
// Local is required for DistributedTests
tc.AddTransform(new LocalTransform(this));
if (OptimiseInferenceCode)

Просмотреть файл

@ -39,13 +39,7 @@ namespace Microsoft.ML.Probabilistic.Models
// Attributes of the method invocation i.e. the factor or constraint.
internal List<ICompilerAttribute> attributes = new List<ICompilerAttribute>();
// The condition blocks this method is contained in
private List<IStatementBlock> containers;
internal List<IStatementBlock> Containers
{
get { return containers; }
}
internal List<IStatementBlock> Containers { get; }
// Provides global ordering for ModelBuilder
internal readonly int timestamp;
@ -67,7 +61,7 @@ namespace Microsoft.ML.Probabilistic.Models
this.timestamp = GetTimestamp();
this.method = method;
this.args.AddRange(args);
this.containers = new List<IStatementBlock>(containers);
this.Containers = new List<IStatementBlock>(containers);
foreach (IModelExpression arg in args)
{
if (ReferenceEquals(arg, null)) throw new ArgumentNullException();
@ -77,7 +71,7 @@ namespace Microsoft.ML.Probabilistic.Models
if (v.IsObserved) continue;
foreach (ConditionBlock cb in v.GetContainers<ConditionBlock>())
{
if (!this.containers.Contains(cb))
if (!this.Containers.Contains(cb))
{
throw new InvalidOperationException($"{arg} was created in condition {cb} and cannot be used outside. " +
$"To give {arg} a conditional definition, use SetTo inside {cb} rather than assignment (=). " +
@ -195,7 +189,7 @@ namespace Microsoft.ML.Probabilistic.Models
else if (op == Variable.Operator.Equal) return Builder.BinaryExpr(argExprs[0], BinaryOperator.ValueEquality, argExprs[1]);
else if (op == Variable.Operator.NotEqual) return Builder.BinaryExpr(argExprs[0], BinaryOperator.ValueInequality, argExprs[1]);
}
IMethodInvokeExpression imie = null;
IMethodInvokeExpression imie;
if (method.IsGenericMethod && !method.ContainsGenericParameters)
{
imie = Builder.StaticGenericMethod(method, argExprs);
@ -240,7 +234,7 @@ namespace Microsoft.ML.Probabilistic.Models
{
Set<Range> ranges = new Set<Range>();
foreach (IModelExpression arg in returnValueAndArgs()) ForEachRange(arg, ranges.Add);
foreach (IStatementBlock b in containers)
foreach (IStatementBlock b in Containers)
{
if (b is HasRange)
{
@ -262,7 +256,7 @@ namespace Microsoft.ML.Probabilistic.Models
{
ForEachRange(arg, delegate(Range r) { if (!ranges.Contains(r)) ranges.Add(r); });
}
foreach (IStatementBlock b in containers)
foreach (IStatementBlock b in Containers)
{
if (b is HasRange)
{

Просмотреть файл

@ -65,7 +65,7 @@ namespace Microsoft.ML.Probabilistic.Models
modelType.Owner = null;
modelType.BaseType = null;
modelType.Visibility = TypeVisibility.Public;
modelMethod = Builder.MethodDecl(MethodVisibility.Public, "Model", typeof (void), modelType);
modelMethod = Builder.MethodDecl(MethodVisibility.Public, "Model", typeof(void), modelType);
IBlockStatement body = Builder.BlockStmt();
modelMethod.Body = body;
//blocks = new List<IList<IStatement>>();
@ -114,15 +114,13 @@ namespace Microsoft.ML.Probabilistic.Models
List<IModelExpression> exprs = new List<IModelExpression>();
foreach (IModelExpression expr in ModelExpressions)
{
if (expr is Variable)
if (expr is Variable var)
{
Variable var = (Variable) expr;
exprs.Add(var);
timestamps.Add(var.timestamp);
}
else if (expr is MethodInvoke)
else if (expr is MethodInvoke mi)
{
MethodInvoke mi = (MethodInvoke) expr;
exprs.Add(mi);
timestamps.Add(mi.timestamp);
}
@ -183,9 +181,9 @@ namespace Microsoft.ML.Probabilistic.Models
{
if (var == null) throw new NullReferenceException("Model expression was null.");
// Console.WriteLine("Building expression: "+var+" "+builtVars.ContainsKey(var));
if (var is MethodInvoke)
if (var is MethodInvoke methodInvoke)
{
BuildMethodInvoke((MethodInvoke)var, null);
BuildMethodInvoke(methodInvoke);
return;
}
MethodInfo mb = new Action<IModelExpression<object>>(this.BuildExpression<object>).Method.GetGenericMethodDefinition();
@ -220,7 +218,7 @@ namespace Microsoft.ML.Probabilistic.Models
Type[] faces = expr.GetType().GetInterfaces();
foreach (Type face in faces)
{
if (face.IsGenericType && face.GetGenericTypeDefinition() == typeof (IModelExpression<>))
if (face.IsGenericType && face.GetGenericTypeDefinition() == typeof(IModelExpression<>))
{
domainType = face.GetGenericArguments()[0];
break;
@ -236,14 +234,14 @@ namespace Microsoft.ML.Probabilistic.Models
if (expr == null) throw new NullReferenceException("Model expression was null.");
// Console.WriteLine("Searching expression: "+var+" "+builtVars.ContainsKey(var));
if (searched.Contains(expr)) return;
if (expr is MethodInvoke)
if (expr is MethodInvoke methodInvoke)
{
SearchMethodInvoke((MethodInvoke) expr);
SearchMethodInvoke(methodInvoke);
return;
}
if (expr is Range)
if (expr is Range range)
{
SearchRange((Range) expr);
SearchRange(range);
return;
}
MethodInfo mb = new Action<IModelExpression<object>>(this.SearchExpression<object>).Method.GetGenericMethodDefinition();
@ -253,7 +251,7 @@ namespace Microsoft.ML.Probabilistic.Models
Type[] faces = expr.GetType().GetInterfaces();
foreach (Type face in faces)
{
if (face.IsGenericType && face.GetGenericTypeDefinition() == typeof (IModelExpression<>))
if (face.IsGenericType && face.GetGenericTypeDefinition() == typeof(IModelExpression<>))
{
domainType = face.GetGenericArguments()[0];
break;
@ -284,15 +282,13 @@ namespace Microsoft.ML.Probabilistic.Models
/// Add a statement of the form x = f(...) to the MSL.
/// </summary>
/// <param name="method">Stores the method to call, the argument variables, and target variable.</param>
/// <param name="lhs">Stores the name and type of the target variable, if it is not already declared. Otherwise null.</param>
///
/// <remarks>
/// If any variable in the statement is an item variable, then we surround the statement with a loop over its range.
/// Since there may be multiple item variables, and each item may depend on multiple ranges, we may end up with multiple loops.
/// </remarks>
private void BuildMethodInvoke(MethodInvoke method, IExpression lhs)
private void BuildMethodInvoke(MethodInvoke method)
{
if (method.ReturnValue is Variable && ((Variable) method.ReturnValue).Inline) return;
if (method.ReturnValue is Variable && ((Variable)method.ReturnValue).Inline) return;
// Open containing blocks
List<IStatementBlock> stBlocks = method.Containers;
List<Range> localRanges = new List<Range>();
@ -304,7 +300,7 @@ namespace Microsoft.ML.Probabilistic.Models
foreach (IModelExpression arg in method.returnValueAndArgs())
{
MethodInvoke.ForEachRange(arg,
delegate(Range r) { if (!localRanges.Contains(r)) localRanges.Add(r); });
delegate (Range r) { if (!localRanges.Contains(r)) localRanges.Add(r); });
}
ParameterInfo[] pis = method.method.GetParameters();
for (int i = 0; i < pis.Length; i++)
@ -320,32 +316,20 @@ namespace Microsoft.ML.Probabilistic.Models
{
if (b is HasRange)
{
HasRange br = (HasRange) b;
HasRange br = (HasRange)b;
localRanges.Remove(br.Range);
}
}
localRanges.Sort(delegate(Range a, Range b) { return MethodInvoke.CompareRanges(dict, a, b); });
localRanges.Sort(delegate (Range a, Range b) { return MethodInvoke.CompareRanges(dict, a, b); });
// convert from List<Range> to List<IStatementBlock>
List<IStatementBlock> localRangeBlocks = new List<IStatementBlock>(localRanges.Select(r => r));
BuildStatementBlocks(stBlocks, true);
if (lhs != null)
{
// If there are inline loops, we need to incorporate the variable declaration outside the loops
AddStatement(Builder.ExprStatement(lhs));
lhs = null;
}
BuildStatementBlocks(localRangeBlocks, true);
// Invoke method
IExpression methodExpr = method.GetExpression();
if (lhs != null)
{
// If this is the definition of a declared variable, merge the declaration in to the assignment.
IAssignExpression iae = (IAssignExpression) methodExpr;
iae.Target = lhs;
}
IStatement st = Builder.ExprStatement(methodExpr);
if (methodExpr is IAssignExpression && method.ReturnValue is HasObservedValue && ((HasObservedValue) method.ReturnValue).IsObserved)
if (methodExpr is IAssignExpression && method.ReturnValue is HasObservedValue && ((HasObservedValue)method.ReturnValue).IsObserved)
{
Attributes.Set(st, new Constraint());
}
@ -388,23 +372,23 @@ namespace Microsoft.ML.Probabilistic.Models
{
if (isb is IfBlock)
{
IfBlock ib = (IfBlock) isb;
IfBlock ib = (IfBlock)isb;
var condVar = ib.ConditionVariable;
if (!negatedConditionVariables.Contains(condVar))
{
if (condVar.definition != null)
{
MethodInvoke mi = condVar.definition;
if (mi.method.Equals(new Func<int, int, bool>(Microsoft.ML.Probabilistic.Factors.Factor.AreEqual).Method))
if (mi.method.Equals(new Func<int, int, bool>(Factor.AreEqual).Method))
{
if (mi.Arguments[1] is Variable)
{
Variable arg1 = (Variable) mi.Arguments[1];
Variable arg1 = (Variable)mi.Arguments[1];
if (arg1.IsObserved || arg1.IsLoopIndex)
{
// convert 'if(vbool1)' into 'if(x==value)' where value is observed (or a loop index) and vbool1 is never negated.
// if vbool1 is negated, then we cannot make this substitution since we need to match the corresponding 'if(!vbool1)' condition.
IConditionStatement ics = (IConditionStatement) ist;
IConditionStatement ics = (IConditionStatement)ist;
ics.Condition = Builder.BinaryExpr(mi.Arguments[0].GetExpression(), BinaryOperator.ValueEquality, arg1.GetExpression());
}
}
@ -422,7 +406,7 @@ namespace Microsoft.ML.Probabilistic.Models
///
public void SearchExpression<T>(IModelExpression<T> var)
{
if (var is Variable<T>) SearchVariable<T>((Variable<T>) var);
if (var is Variable<T> varT) SearchVariable<T>(varT);
else throw new InferCompilerException("Unhandled model expression type: " + var.GetType());
}
@ -473,9 +457,11 @@ namespace Microsoft.ML.Probabilistic.Models
if (!variable.Inline)
{
// Determine if the variable should be inlined
bool inline = false;
bool inline;
if (variable.definition != null)
{
inline = variable.definition.CanBeInlined();
}
else
{
inline = (variable.conditionalDefinitions.Values.Count == 1);
@ -496,7 +482,6 @@ namespace Microsoft.ML.Probabilistic.Models
if (variable is IVariableArray)
{
IVariableArray iva = (IVariableArray)variable;
IList<IStatement> sc = Builder.StmtCollection();
IList<IVariableDeclaration[]> jaggedIndexVars;
IList<IExpression[]> jaggedSizes;
GetJaggedArrayIndicesAndSizes(iva, out jaggedIndexVars, out jaggedSizes);
@ -565,8 +550,53 @@ namespace Microsoft.ML.Probabilistic.Models
private void FinishVariable<T>(Variable<T> variable, IAlgorithm alg)
{
if (variable.IsLoopIndex) return; // do nothing
if (variable.IsArrayElement) return;
if (variable.Inline) return;
FinishRandVar(variable, alg);
object ivd = variable.GetDeclaration();
bool doNotInfer = false;
// Add attributes
foreach (ICompilerAttribute attr in variable.GetAttributes<ICompilerAttribute>())
{
if (attr is DoNotInfer) doNotInfer = true;
else Attributes.Add(ivd, attr);
}
foreach (IStatementBlock stBlock in variable.Containers)
{
if (stBlock is HasRange)
{
doNotInfer = true;
break;
}
}
List<IStatementBlock> stBlocks = new List<IStatementBlock>();
stBlocks.AddRange(variable.Containers);
// Add Infer statement
bool isConstant = (variable.IsBase && variable.IsReadOnly);
if (!doNotInfer && ((!inferOnlySpecifiedVars && !isConstant) || variablesToInfer.Contains(variable)))
{
// If there has been no explicit indication of query types for inference, set the
// default types
List<QueryTypeCompilerAttribute> qtlist = Attributes.GetAll<QueryTypeCompilerAttribute>(ivd);
if (qtlist.Count == 0)
{
alg.ForEachDefaultQueryType(qt => Attributes.Add(ivd, new QueryTypeCompilerAttribute(qt)));
qtlist = Attributes.GetAll<QueryTypeCompilerAttribute>(ivd);
}
variablesToInfer.Add(variable);
BuildStatementBlocks(stBlocks, true);
IExpression varExpr = variable.GetExpression();
IExpression varName = Builder.LiteralExpr(variable.NameInGeneratedCode);
foreach (QueryTypeCompilerAttribute qt in qtlist)
{
IExpression queryExpr = Builder.FieldRefExpr(Builder.TypeRefExpr(typeof(QueryTypes)), typeof(QueryTypes), qt.QueryType.Name);
// for a constant, we must get the variable reference, not the value
if (isConstant) varExpr = Builder.VarRefExpr((IVariableDeclaration)variable.GetDeclaration());
AddStatement(Builder.ExprStatement(
Builder.StaticMethod(new Action<object>(InferNet.Infer), varExpr, varName, queryExpr)));
}
BuildStatementBlocks(stBlocks, false);
}
}
/// <summary>
@ -576,7 +606,7 @@ namespace Microsoft.ML.Probabilistic.Models
/// <param name="expr">The variable expression</param>
private void BuildExpression<T>(IModelExpression<T> expr)
{
if (expr is Variable<T>) BuildVariable<T>((Variable<T>) expr);
if (expr is Variable<T> var) BuildVariable<T>(var);
else throw new InferCompilerException("Unhandled model expression type: " + expr.GetType());
}
@ -606,7 +636,7 @@ namespace Microsoft.ML.Probabilistic.Models
private bool ShouldInlineConstant<T>(Variable<T> constant)
{
return (Quoter.ShouldInlineType(typeof (T)) && (!constant.IsDefined) && !variablesToInfer.Contains(constant));
return (Quoter.ShouldInlineType(typeof(T)) && (!constant.IsDefined) && !variablesToInfer.Contains(constant));
}
/// <summary>
@ -637,9 +667,9 @@ namespace Microsoft.ML.Probabilistic.Models
if (!useExisting)
{
// create a new declaration
ivd = (IVariableDeclaration) constant.GetDeclaration();
ivd = (IVariableDeclaration)constant.GetDeclaration();
var rhs = Quoter.Quote(constant.ObservedValue);
if (ReferenceEquals(constant.ObservedValue, null)) rhs = Builder.CastExpr(rhs, typeof (T));
if (ReferenceEquals(constant.ObservedValue, null)) rhs = Builder.CastExpr(rhs, typeof(T));
AddStatement(Builder.AssignStmt(Builder.VarDeclExpr(ivd), rhs));
constants[key] = ivd;
}
@ -692,7 +722,7 @@ namespace Microsoft.ML.Probabilistic.Models
private void SearchRange(Range range)
{
if (searched.Contains(range)) return;
string name = ((IModelExpression) range).Name;
string name = ((IModelExpression)range).Name;
foreach (IModelExpression expr in searched)
{
if (name.Equals(expr.Name))
@ -716,9 +746,9 @@ namespace Microsoft.ML.Probabilistic.Models
{
var ds = (DistributedSchedule)attr;
toSearch.Push(ds.commExpression);
if(ds.scheduleExpression != null)
if (ds.scheduleExpression != null)
toSearch.Push(ds.scheduleExpression);
if(ds.schedulePerThreadExpression != null)
if (ds.schedulePerThreadExpression != null)
toSearch.Push(ds.schedulePerThreadExpression);
var attr2 = new DistributedScheduleExpression(ds.commExpression.GetExpression(), ds.scheduleExpression?.GetExpression(), ds.schedulePerThreadExpression?.GetExpression());
Attributes.Set(ivd, attr2);
@ -740,28 +770,26 @@ namespace Microsoft.ML.Probabilistic.Models
{
foreach (IStatementBlock sb in containers)
{
if (sb is ConditionBlock)
if (sb is ConditionBlock cb)
{
if (sb is SwitchBlock)
if (cb is SwitchBlock swb)
{
SearchRange(((SwitchBlock) sb).Range);
SearchRange(swb.Range);
}
ConditionBlock cb = (ConditionBlock) sb;
Variable condVar = cb.ConditionVariableUntyped;
if (cb is IfBlock)
if (cb is IfBlock ib)
{
IfBlock ib = (IfBlock) cb;
if (ib.ConditionValue == false) negatedConditionVariables.Add(condVar);
}
toSearch.Push(condVar);
}
else if (sb is ForEachBlock)
else if (sb is ForEachBlock fb)
{
SearchRange(((ForEachBlock) sb).Range);
SearchRange(fb.Range);
}
else if (sb is RepeatBlock)
else if (sb is RepeatBlock rb)
{
toSearch.Push(((RepeatBlock) sb).Count);
toSearch.Push(rb.Count);
}
}
}
@ -814,7 +842,7 @@ namespace Microsoft.ML.Probabilistic.Models
}
return;
}
IVariableDeclaration ivd = (IVariableDeclaration) variable.GetDeclaration();
IVariableDeclaration ivd = (IVariableDeclaration)variable.GetDeclaration();
if (variable.initialiseTo != null)
{
Attributes.Set(ivd, new InitialiseTo(variable.initialiseTo.GetExpression()));
@ -837,9 +865,8 @@ namespace Microsoft.ML.Probabilistic.Models
Set<IVariableDeclaration> loopVars = new Set<IVariableDeclaration>();
foreach (IStatementBlock stBlock in stBlocks)
{
if (stBlock is ForEachBlock)
if (stBlock is ForEachBlock fb)
{
ForEachBlock fb = (ForEachBlock) stBlock;
IVariableDeclaration loopVar = fb.Range.GetIndexDeclaration();
if (loopVars.Contains(loopVar))
throw new InvalidOperationException("Variable '" + ivd.Name + "' uses range '" + loopVar.Name + "' twice. Use a cloned range instead.");
@ -881,53 +908,6 @@ namespace Microsoft.ML.Probabilistic.Models
protected void FinishRandVar<T>(Variable<T> variable, IAlgorithm alg)
{
if (variable.IsArrayElement) return;
if (variable.Inline) return;
object ivd = variable.GetDeclaration();
bool doNotInfer = false;
// Add attributes
foreach (ICompilerAttribute attr in variable.GetAttributes<ICompilerAttribute>())
{
if (attr is DoNotInfer) doNotInfer = true;
else Attributes.Add(ivd, attr);
}
foreach (IStatementBlock stBlock in variable.Containers)
{
if (stBlock is HasRange)
{
doNotInfer = true;
break;
}
}
List<IStatementBlock> stBlocks = new List<IStatementBlock>();
stBlocks.AddRange(variable.Containers);
// Add Infer statement
bool isConstant = (variable.IsBase && variable.IsReadOnly);
if (!doNotInfer && ((!inferOnlySpecifiedVars && !isConstant) || variablesToInfer.Contains(variable)))
{
// If there has been no explicit indication of query types for inference, set the
// default types
List<QueryTypeCompilerAttribute> qtlist = Attributes.GetAll<QueryTypeCompilerAttribute>(ivd);
if (qtlist.Count == 0)
{
alg.ForEachDefaultQueryType(qt => Attributes.Add(ivd, new QueryTypeCompilerAttribute(qt)));
qtlist = Attributes.GetAll<QueryTypeCompilerAttribute>(ivd);
}
variablesToInfer.Add(variable);
BuildStatementBlocks(stBlocks, true);
IExpression varExpr = variable.GetExpression();
IExpression varName = Builder.LiteralExpr(variable.NameInGeneratedCode);
foreach (QueryTypeCompilerAttribute qt in qtlist)
{
IExpression queryExpr = Builder.FieldRefExpr(Builder.TypeRefExpr(typeof (QueryTypes)), typeof (QueryTypes), qt.QueryType.Name);
// for a constant, we must get the variable reference, not the value
if (isConstant) varExpr = Builder.VarRefExpr((IVariableDeclaration)variable.GetDeclaration());
AddStatement(Builder.ExprStatement(
Builder.StaticMethod(new Action<object>(InferNet.Infer), varExpr, varName, queryExpr)));
}
BuildStatementBlocks(stBlocks, false);
}
}
protected void GetJaggedArrayIndicesAndSizes(IVariableArray array, out IList<IVariableDeclaration[]> jaggedIndexVars, out IList<IExpression[]> jaggedSizes)
@ -955,10 +935,10 @@ namespace Microsoft.ML.Probabilistic.Models
}
jaggedIndexVars.Add(indexVars);
jaggedSizes.Add(sizes);
if (array is IVariableJaggedArray)
if (array is IVariableJaggedArray variableJaggedArray)
{
IVariable itemPrototype = ((IVariableJaggedArray) array).ItemPrototype;
if (itemPrototype is IVariableArray) array = (IVariableArray) itemPrototype;
IVariable itemPrototype = variableJaggedArray.ItemPrototype;
if (itemPrototype is IVariableArray variableArray) array = variableArray;
else break;
}
else break;

Просмотреть файл

@ -106,7 +106,7 @@ namespace Microsoft.ML.Probabilistic.Models
{
foreach (ICompilerAttribute attr in attributes)
{
if (attr is AttributeType) yield return (AttributeType) attr;
if (attr is AttributeType attributeType) yield return attributeType;
}
}
@ -161,11 +161,13 @@ namespace Microsoft.ML.Probabilistic.Models
{
get
{
if (!(Size is Variable<int>)) throw new InvalidOperationException("The Range does not have constant size. Set IsReadOnly=true on the range size.");
Variable<int> sizeVar = (Variable<int>) Size;
if (!(sizeVar.IsObserved && sizeVar.IsReadOnly))
throw new InvalidOperationException("The Range does not have constant size. To use SizeAsInt, set IsReadOnly=true on the range size.");
return sizeVar.ObservedValue;
if (Size is Variable<int> sizeVar)
{
if (!(sizeVar.IsObserved && sizeVar.IsReadOnly))
throw new InvalidOperationException("The Range does not have constant size. To use SizeAsInt, set IsReadOnly=true on the range size.");
return sizeVar.ObservedValue;
}
else throw new InvalidOperationException("The Range does not have constant size. Set IsReadOnly=true on the range size.");
}
}
@ -200,7 +202,7 @@ namespace Microsoft.ML.Probabilistic.Models
internal IVariableDeclaration GetIndexDeclaration()
{
if (index == null) index = Builder.VarDecl(NameInGeneratedCode, typeof (int));
if (index == null) index = Builder.VarDecl(NameInGeneratedCode, typeof(int));
return index;
}
@ -233,6 +235,18 @@ namespace Microsoft.ML.Probabilistic.Models
return root;
}
/// <summary>
/// Get an expression that evaluates to true when this loop counter is increasing in the currently executing loop.
/// </summary>
/// <returns></returns>
public Variable<bool> IsIncreasing()
{
Variable<bool> v = new Variable<bool>();
v.SetTo(new Func<int, bool>(Factors.InferNet.IsIncreasing).Method, this);
v.Inline = true;
return v;
}
private static Range ReplaceExpressions(Range r, Dictionary<IModelExpression, IModelExpression> replacements)
{
IModelExpression<int> newSize = (IModelExpression<int>)ReplaceExpressions(r.Size, replacements);
@ -247,17 +261,16 @@ namespace Microsoft.ML.Probabilistic.Models
private static IModelExpression ReplaceExpressions(IModelExpression expr, Dictionary<IModelExpression, IModelExpression> replacements)
{
if (replacements.ContainsKey(expr)) return replacements[expr];
if (expr is Range)
if (expr is Range range)
{
return ReplaceExpressions((Range)expr, replacements);
return ReplaceExpressions(range, replacements);
}
else if (expr is Variable)
else if (expr is Variable v)
{
Variable v = (Variable) expr;
if (v.IsArrayElement)
{
bool changed = false;
IVariableArray newArray = (IVariableArray) ReplaceExpressions(v.ArrayVariable, replacements);
IVariableArray newArray = (IVariableArray)ReplaceExpressions(v.ArrayVariable, replacements);
if (!ReferenceEquals(newArray, v.ArrayVariable)) changed = true;
IModelExpression[] newIndices = new IModelExpression[v.indices.Count];
for (int i = 0; i < newIndices.Length; i++)
@ -294,13 +307,12 @@ namespace Microsoft.ML.Probabilistic.Models
/// <exclude/>
internal bool IsCompatibleWith(IModelExpression index)
{
if (index is Range) return (((Range) index).GetRoot() == GetRoot());
else if (index is Variable)
if (index is Range range) return (range.GetRoot() == GetRoot());
else if (index is Variable indexVar)
{
Variable indexVar = (Variable) index;
Range range = indexVar.GetValueRange(false);
if (range == null) return true;
return IsCompatibleWith(range);
Range valueRange = indexVar.GetValueRange(false);
if (valueRange == null) return true;
return IsCompatibleWith(valueRange);
}
else
{

Просмотреть файл

@ -3746,6 +3746,22 @@ namespace Microsoft.ML.Probabilistic.Models
return first;
}
/// <summary>
/// Creates a copy of the argument where the forward message is uniform when <paramref name="shouldCut"/> is true. Used to control inference.
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="x"></param>
/// <param name="shouldCut"></param>
/// <returns></returns>
public static Variable<T> CutForwardWhen<T>(Variable<T> x, Variable<bool> shouldCut)
{
Variable<T> result = Variable<T>.Factor(Factors.Cut.ForwardWhen, x, shouldCut);
Range valueRange = x.GetValueRange(false);
if (valueRange != null)
result.AddAttribute(new ValueRange(valueRange));
return result;
}
/// <summary>
/// Returns a cut of the argument. Cut is equivalent to random(infer()).
/// </summary>
@ -3755,7 +3771,7 @@ namespace Microsoft.ML.Probabilistic.Models
/// <remarks>Cut allows forward messages to pass through unchanged, whereas backward messages are cut off.</remarks>
public static Variable<T> Cut<T>(Variable<T> x)
{
Variable<T> result = Variable<T>.Factor(Factor.Cut<T>, x);
Variable<T> result = Variable<T>.Factor(Factors.Cut.Backward, x);
Range valueRange = x.GetValueRange(false);
if (valueRange != null)
result.AddAttribute(new ValueRange(valueRange));

Просмотреть файл

@ -631,9 +631,8 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
protected override IExpression ConvertAssign(IAssignExpression iae)
{
iae = (IAssignExpression)base.ConvertAssign(iae);
if (iae.Expression is IMethodInvokeExpression)
if (iae.Expression is IMethodInvokeExpression imie)
{
IMethodInvokeExpression imie = (IMethodInvokeExpression)iae.Expression;
IVariableDeclaration ivd = Recognizer.GetVariableDeclaration(iae.Target);
if (ivd != null)
{

Просмотреть файл

@ -0,0 +1,58 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using Microsoft.ML.Probabilistic.Compiler.Attributes;
using Microsoft.ML.Probabilistic.Collections;
using Microsoft.ML.Probabilistic.Compiler;
using Microsoft.ML.Probabilistic.Compiler.CodeModel;
using Microsoft.ML.Probabilistic.Utilities;
using System.Linq;
using Microsoft.ML.Probabilistic.Models.Attributes;
namespace Microsoft.ML.Probabilistic.Compiler.Transforms
{
/// <summary>
/// Convert occurrences of <c>IsIncreasing(i)</c> into literal boolean constants.
/// </summary>
internal class IsIncreasingTransform : ShallowCopyTransform
{
public override string Name
{
get
{
return nameof(IsIncreasingTransform);
}
}
HashSet<string> backwardLoops = new HashSet<string>();
protected override IStatement ConvertFor(IForStatement ifs)
{
string toRemove = null;
if(!Recognizer.IsForwardLoop(ifs))
{
var loopVar = Recognizer.LoopVariable(ifs);
toRemove = loopVar.Name;
backwardLoops.Add(toRemove);
}
var result = base.ConvertFor(ifs);
if (toRemove != null) backwardLoops.Remove(toRemove);
return result;
}
protected override IExpression ConvertMethodInvoke(IMethodInvokeExpression imie)
{
if(CodeRecognizer.IsIsIncreasing(imie))
{
IExpression arg = imie.Arguments[0];
IVariableDeclaration ivd = Recognizer.GetVariableDeclaration(arg);
if (backwardLoops.Contains(ivd.Name)) return Builder.LiteralExpr(false);
else return Builder.LiteralExpr(true);
}
return base.ConvertMethodInvoke(imie);
}
}
}

Просмотреть файл

@ -206,6 +206,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
protected override IExpression ConvertMethodInvoke(IMethodInvokeExpression imie)
{
if (CodeRecognizer.IsInfer(imie)) return ConvertInfer(imie);
if (CodeRecognizer.IsIsIncreasing(imie)) return imie;
if (context.FindAncestor<IExpressionStatement>() == null) return imie;
IExpression expr = imie;
IAssignExpression iae = context.FindAncestor<IAssignExpression>();

Просмотреть файл

@ -239,8 +239,8 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
if (converted is IMethodInvokeExpression)
{
var mie = (IMethodInvokeExpression)converted;
bool isAnd = Recognizer.IsStaticMethod(converted, new Func<bool, bool, bool>(Microsoft.ML.Probabilistic.Factors.Factor.And));
bool isOr = Recognizer.IsStaticMethod(converted, new Func<bool, bool, bool>(Microsoft.ML.Probabilistic.Factors.Factor.Or));
bool isAnd = Recognizer.IsStaticMethod(converted, new Func<bool, bool, bool>(Factors.Factor.And));
bool isOr = Recognizer.IsStaticMethod(converted, new Func<bool, bool, bool>(Factors.Factor.Or));
bool anyArgumentIsLiteral = mie.Arguments.Any(arg => arg is ILiteralExpression);
if (anyArgumentIsLiteral)
{
@ -262,7 +262,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
if (reducedArguments.Count() == 1) return reducedArguments.First();
else return Builder.LiteralExpr(false);
}
else if (Recognizer.IsStaticMethod(converted, new Func<bool, bool>(Microsoft.ML.Probabilistic.Factors.Factor.Not)))
else if (Recognizer.IsStaticMethod(converted, new Func<bool, bool>(Factors.Factor.Not)))
{
bool allArgumentsAreLiteral = mie.Arguments.All(arg => arg is ILiteralExpression);
if (allArgumentsAreLiteral)

Просмотреть файл

@ -896,7 +896,8 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
else if (
Recognizer.IsStaticGenericMethod(imie, new Func<PlaceHolder, PlaceHolder>(Factor.Copy)) ||
Recognizer.IsStaticGenericMethod(imie, new Func<PlaceHolder, PlaceHolder>(Diode.Copy)) ||
Recognizer.IsStaticGenericMethod(imie, new Func<PlaceHolder, PlaceHolder>(Factor.Cut)) ||
Recognizer.IsStaticGenericMethod(imie, new Func<PlaceHolder, PlaceHolder>(Cut.Backward)) ||
Recognizer.IsStaticGenericMethod(imie, new Func<PlaceHolder, bool, PlaceHolder>(Cut.ForwardWhen)) ||
Recognizer.IsStaticGenericMethod(imie, new Func<PlaceHolder, double, PlaceHolder>(PowerPlate.Enter)) ||
Recognizer.IsStaticGenericMethod(imie, new Func<PlaceHolder, double, PlaceHolder>(Damp.Forward)) ||
Recognizer.IsStaticGenericMethod(imie, new Func<PlaceHolder, double, PlaceHolder>(Damp.Backward)) ||

Просмотреть файл

@ -2338,6 +2338,11 @@ namespace Microsoft.ML.Probabilistic.Compiler
Instance.IsStaticMethod(expr, new Action<object, string, QueryType>(InferNet.Infer));
}
internal static bool IsIsIncreasing(IExpression expr)
{
return Instance.IsStaticMethod(expr, new Func<int,bool>(InferNet.IsIncreasing));
}
internal static FactorManager.FactorInfo GetFactorInfo(BasicTransformContext context, IMethodInvokeExpression imie)
{
if (!context.InputAttributes.Has<FactorManager.FactorInfo>(imie))

Просмотреть файл

@ -4,13 +4,6 @@
namespace Microsoft.ML.Probabilistic.Factors
{
using Microsoft.ML.Probabilistic;
using System;
#if SUPPRESS_XMLDOC_WARNINGS
#pragma warning disable 1591
#endif
/// <summary>
/// Class used in MSL only.
/// </summary>
@ -18,37 +11,40 @@ namespace Microsoft.ML.Probabilistic.Factors
public static class InferNet
{
/// <summary>
/// For use in MSL only.
/// Used in MSL to indicate that a variable will be inferred.
/// </summary>
/// <param name="obj"></param>
/// <param name="obj">A variable reference expression</param>
public static void Infer(object obj)
{
}
/// <summary>
/// Used in MSL to indicate that a variable will be inferred under a specific name.
/// </summary>
/// <param name="obj">A variable reference expression</param>
/// <param name="name">The external name of the variable</param>
public static void Infer(object obj, string name)
{
}
/// <summary>
/// Used in MSL to indicate that a variable will be inferred under a specific name and query type.
/// </summary>
/// <param name="obj">A variable reference expression</param>
/// <param name="name">The external name of the variable</param>
/// <param name="query">The query type</param>
public static void Infer(object obj, string name, QueryType query)
{
}
/*public class Power : IDisposable
/// <summary>
/// Used in MSL to indicate that the loop counter is increasing in the currently executing loop.
/// </summary>
/// <param name="loopCounter"></param>
/// <returns></returns>
public static bool IsIncreasing(int loopCounter)
{
double power;
public Power(double p)
{
power = p;
}
public void Dispose()
{
}
}*/
return true;
}
}
#if SUPPRESS_XMLDOC_WARNINGS
#pragma warning restore 1591
#endif
}

Просмотреть файл

@ -6,10 +6,11 @@ namespace Microsoft.ML.Probabilistic.Factors
{
using Microsoft.ML.Probabilistic.Distributions;
using Microsoft.ML.Probabilistic.Factors.Attributes;
using Microsoft.ML.Probabilistic.Math;
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="CutOp{T}"]/doc/*'/>
// /// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="CutOp{T}"]/doc/*'/>
/// <typeparam name="T">The type of the variable being copied.</typeparam>
[FactorMethod(typeof(Factor), "Cut<>")]
[FactorMethod(typeof(Cut), "Backward<>")]
[Quality(QualityBand.Preview)]
public static class CutOp<T>
{
@ -23,23 +24,23 @@ namespace Microsoft.ML.Probabilistic.Factors
return result;
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="CutOp{T}"]/message_doc[@name="CutAverageConditional{TDist}(TDist)"]/*'/>
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="CutOp{T}"]/message_doc[@name="BackwardAverageConditional{TDist}(TDist)"]/*'/>
/// <typeparam name="TDist">The type of the distribution over the variable being copied.</typeparam>
public static TDist CutAverageConditional<TDist>([IsReturned] TDist Value)
public static TDist BackwardAverageConditional<TDist>([IsReturned] TDist Value)
where TDist : IDistribution<T>
{
return Value;
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="CutOp{T}"]/message_doc[@name="CutAverageConditional(T)"]/*'/>
public static T CutAverageConditional([IsReturned] T Value)
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="CutOp{T}"]/message_doc[@name="BackwardAverageConditional(T)"]/*'/>
public static T BackwardAverageConditional([IsReturned] T Value)
{
return Value;
}
// VMP /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="CutOp{T}"]/message_doc[@name="ValueAverageConditional{TDist}(TDist)"]/*'/>
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="CutOp{T}"]/message_doc[@name="ValueAverageLogarithm{TDist}(TDist)"]/*'/>
/// <typeparam name="TDist">The type of the distribution over the variable being copied.</typeparam>
[Skip]
public static TDist ValueAverageLogarithm<TDist>(TDist result)
@ -49,18 +50,105 @@ namespace Microsoft.ML.Probabilistic.Factors
return result;
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="CutOp{T}"]/message_doc[@name="CutAverageConditional{TDist}(TDist)"]/*'/>
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="CutOp{T}"]/message_doc[@name="BackwardAverageLogarithm{TDist}(TDist)"]/*'/>
/// <typeparam name="TDist">The type of the distribution over the variable being copied.</typeparam>
public static TDist CutAverageLogarithm<TDist>([IsReturned] TDist Value)
public static TDist BackwardAverageLogarithm<TDist>([IsReturned] TDist Value)
where TDist : IDistribution<T>
{
return Value;
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="CutOp{T}"]/message_doc[@name="CutAverageConditional(T)"]/*'/>
public static T CutAverageLogarithm([IsReturned] T Value)
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="CutOp{T}"]/message_doc[@name="BackwardAverageLogarithm(T)"]/*'/>
public static T BackwardAverageLogarithm([IsReturned] T Value)
{
return Value;
}
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="CutForwardWhenOp{T}"]/doc/*'/>
/// <typeparam name="T">The type of the variable being copied.</typeparam>
[FactorMethod(typeof(Cut), "ForwardWhen<>")]
[Quality(QualityBand.Preview)]
public static class CutForwardWhenOp<T>
{
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="CutForwardWhenOp{T}"]/message_doc[@name="ValueAverageConditional{TDist}(TDist)"]/*'/>
/// <typeparam name="TDist">The type of the distribution over the variable being copied.</typeparam>
public static TDist ValueAverageConditional<TDist>([IsReturned] TDist forwardWhen)
where TDist : IDistribution<T>
{
return forwardWhen;
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="CutForwardWhenOp{T}"]/message_doc[@name="ForwardWhenAverageConditional{TDist}(TDist,bool,TDist)"]/*'/>
/// <typeparam name="TDist">The type of the distribution over the variable being copied.</typeparam>
public static TDist ForwardWhenAverageConditional<TDist>(TDist Value, bool shouldCut, TDist result)
where TDist : IDistribution<T>, SettableTo<TDist>
{
if (shouldCut) result.SetToUniform();
else result.SetTo(Value);
return result;
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="CutForwardWhenOp{T}"]/message_doc[@name="ForwardWhenAverageConditional(T)"]/*'/>
public static T ForwardWhenAverageConditional([IsReturned] T Value)
{
return Value;
}
// VMP /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="CutForwardWhenOp{T}"]/message_doc[@name="ValueAverageLogarithm{TDist}(TDist)"]/*'/>
/// <typeparam name="TDist">The type of the distribution over the variable being copied.</typeparam>
[Skip]
public static TDist ValueAverageLogarithm<TDist>(TDist result)
where TDist : IDistribution<T>
{
result.SetToUniform();
return result;
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="CutForwardWhenOp{T}"]/message_doc[@name="ForwardWhenAverageLogarithm{TDist}(TDist)"]/*'/>
/// <typeparam name="TDist">The type of the distribution over the variable being copied.</typeparam>
public static TDist ForwardWhenAverageLogarithm<TDist>([IsReturned] TDist Value)
where TDist : IDistribution<T>
{
return Value;
}
/// <include file='FactorDocs.xml' path='factor_docs/message_op_class[@name="CutForwardWhenOp{T}"]/message_doc[@name="ForwardWhenAverageLogarithm(T)"]/*'/>
public static T ForwardWhenAverageLogarithm([IsReturned] T Value)
{
return Value;
}
}
/// <summary>
/// Cut factor methods
/// </summary>
[Hidden]
public static class Cut
{
/// <summary>
/// Copy a value and cut the backward message (it will always be uniform).
/// </summary>
/// <typeparam name="T">The type the input value.</typeparam>
/// <param name="value">The value to return.</param>
/// <returns>The supplied value.</returns>
public static T Backward<T>([SkipIfUniform] T value)
{
return value;
}
/// <summary>
/// Copy a value and cut the forward message when <paramref name="shouldCut"/> is true.
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="value"></param>
/// <param name="shouldCut"></param>
/// <returns></returns>
public static T ForwardWhen<T>([IsReturned] T value, bool shouldCut)
{
return value;
}
}
}

Просмотреть файл

@ -1302,18 +1302,6 @@ namespace Microsoft.ML.Probabilistic.Factors
return value;
}
/// <summary>
/// Passes the input through to the output. Used to set backward messages to uniform.
/// </summary>
/// <typeparam name="T">The type the input value.</typeparam>
/// <param name="value">The value to return.</param>
/// <returns>The supplied value.</returns>
[Hidden]
public static T Cut<T>([SkipIfUniform] T value)
{
return value;
}
/// <summary>
/// Generate a jagged array of Gaussian random variables.
/// </summary>