зеркало из https://github.com/dotnet/infer.git
BayesPointMachineClassifier.CreateGaussianPriorBinaryClassifier is public. (#376)
* Added Variable.Max(int,int). * Compiler warns about excess memory consumption in more cases when it should, and fewer cases when it shouldn't. * TransformBrowser shows attributes by default. * Updated FactorDocs
This commit is contained in:
Родитель
cc132f77f8
Коммит
3c825f2d29
|
@ -3465,6 +3465,17 @@ namespace Microsoft.ML.Probabilistic.Models
|
|||
return Variable<double>.Factor<double, double>(System.Math.Max, a, b);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns an int variable which is the maximum of two int variables
|
||||
/// </summary>
|
||||
/// <param name="a">The first variable</param>
|
||||
/// <param name="b">The second variable</param>
|
||||
/// <returns>A new variable</returns>
|
||||
public static Variable<int> Max(Variable<int> a, Variable<int> b)
|
||||
{
|
||||
return Variable<int>.Factor(System.Math.Max, a, b);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns a double variable which is the minimum of two double variables
|
||||
/// </summary>
|
||||
|
|
|
@ -591,72 +591,6 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
return expr;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns an expression equal to expr1 and expr2 under their respective bindings, or null if the expressions are not equal.
|
||||
/// </summary>
|
||||
/// <param name="expr1"></param>
|
||||
/// <param name="bindings1"></param>
|
||||
/// <param name="expr2"></param>
|
||||
/// <param name="bindings2"></param>
|
||||
/// <returns></returns>
|
||||
private static IExpression Unify(
|
||||
IExpression expr1,
|
||||
IEnumerable<IReadOnlyCollection<ConditionBinding>> bindings1,
|
||||
IExpression expr2,
|
||||
IEnumerable<IReadOnlyCollection<ConditionBinding>> bindings2)
|
||||
{
|
||||
if (expr1.Equals(expr2))
|
||||
{
|
||||
return expr1;
|
||||
}
|
||||
foreach (IReadOnlyCollection<ConditionBinding> binding2 in bindings2)
|
||||
{
|
||||
IExpression expr1b = ReplaceExpression(binding2, expr1);
|
||||
if (expr1b.Equals(expr2))
|
||||
{
|
||||
return expr1;
|
||||
}
|
||||
}
|
||||
foreach (IReadOnlyCollection<ConditionBinding> binding1 in bindings1)
|
||||
{
|
||||
IExpression expr2b = ReplaceExpression(binding1, expr2);
|
||||
if (expr2b.Equals(expr1))
|
||||
{
|
||||
return expr2;
|
||||
}
|
||||
}
|
||||
bool lift = false;
|
||||
if (lift)
|
||||
{
|
||||
IExpression lifted1 = GetLiftedExpression(expr1, bindings1);
|
||||
IExpression lifted2 = GetLiftedExpression(expr2, bindings2);
|
||||
if (lifted1 != null && lifted1.Equals(lifted2)) return lifted1;
|
||||
}
|
||||
return Builder.StaticMethod(new Func<int>(GateAnalysisTransform.AnyIndex));
|
||||
}
|
||||
|
||||
private static IExpression GetLiftedExpression(IExpression expr, IEnumerable<IReadOnlyCollection<ConditionBinding>> bindings)
|
||||
{
|
||||
IExpression lifted = null;
|
||||
foreach (IReadOnlyCollection<ConditionBinding> binding in bindings)
|
||||
{
|
||||
IExpression lhs = null;
|
||||
foreach (ConditionBinding b in binding)
|
||||
{
|
||||
if (b.rhs.Equals(expr)) lhs = b.lhs;
|
||||
}
|
||||
if (lifted == null)
|
||||
{
|
||||
lifted = lhs;
|
||||
}
|
||||
else if (lhs == null || !lifted.Equals(lhs))
|
||||
{
|
||||
return null;
|
||||
}
|
||||
}
|
||||
return lifted;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Used only in ArrayIndexerExpressions, to represent a wildcard.
|
||||
/// </summary>
|
||||
|
@ -747,6 +681,65 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
result.Bindings.AddRange(eb2.Bindings);
|
||||
}
|
||||
return result;
|
||||
|
||||
// Returns an expression equal to expr1 and expr2 under their respective bindings, or null if the expressions are not equal.
|
||||
IExpression Unify(
|
||||
IExpression expr1,
|
||||
IEnumerable<IReadOnlyCollection<ConditionBinding>> bindings1,
|
||||
IExpression expr2,
|
||||
IEnumerable<IReadOnlyCollection<ConditionBinding>> bindings2)
|
||||
{
|
||||
if (expr1.Equals(expr2))
|
||||
{
|
||||
return expr1;
|
||||
}
|
||||
foreach (IReadOnlyCollection<ConditionBinding> binding2 in bindings2)
|
||||
{
|
||||
IExpression expr1b = ReplaceExpression(binding2, expr1);
|
||||
if (expr1b.Equals(expr2))
|
||||
{
|
||||
return expr1;
|
||||
}
|
||||
}
|
||||
foreach (IReadOnlyCollection<ConditionBinding> binding1 in bindings1)
|
||||
{
|
||||
IExpression expr2b = ReplaceExpression(binding1, expr2);
|
||||
if (expr2b.Equals(expr1))
|
||||
{
|
||||
return expr2;
|
||||
}
|
||||
}
|
||||
bool lift = false;
|
||||
if (lift)
|
||||
{
|
||||
IExpression lifted1 = GetLiftedExpression(expr1, bindings1);
|
||||
IExpression lifted2 = GetLiftedExpression(expr2, bindings2);
|
||||
if (lifted1 != null && lifted1.Equals(lifted2)) return lifted1;
|
||||
}
|
||||
return Builder.StaticMethod(new Func<int>(GateAnalysisTransform.AnyIndex));
|
||||
|
||||
IExpression GetLiftedExpression(IExpression expr, IEnumerable<IReadOnlyCollection<ConditionBinding>> bindings)
|
||||
{
|
||||
IExpression lifted = null;
|
||||
foreach (IReadOnlyCollection<ConditionBinding> binding in bindings)
|
||||
{
|
||||
IExpression lhs = null;
|
||||
foreach (ConditionBinding b in binding)
|
||||
{
|
||||
if (b.rhs.Equals(expr)) lhs = b.lhs;
|
||||
}
|
||||
if (lifted == null)
|
||||
{
|
||||
lifted = lhs;
|
||||
}
|
||||
else if (lhs == null || !lifted.Equals(lhs))
|
||||
{
|
||||
return null;
|
||||
}
|
||||
}
|
||||
return lifted;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@ using Microsoft.ML.Probabilistic.Utilities;
|
|||
using System.Linq;
|
||||
using Microsoft.ML.Probabilistic.Algorithms;
|
||||
using Microsoft.ML.Probabilistic.Models.Attributes;
|
||||
using Microsoft.ML.Probabilistic.Models;
|
||||
|
||||
namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
||||
{
|
||||
|
@ -82,11 +83,13 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
}
|
||||
|
||||
string warningText =
|
||||
"This model will consume a lot of memory due to the following mix of indexing expressions inside of a conditional: {0}";
|
||||
"This model will consume excess memory due to the following mix of indexing expressions inside of a conditional: {0}";
|
||||
foreach (var entry in inefficientReplacements)
|
||||
{
|
||||
if (entry.Value != null)
|
||||
{
|
||||
Warning(string.Format(warningText, StringUtil.CollectionToString(entry.Value, ", ")));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -554,28 +557,28 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
if (toReplace != null)
|
||||
{
|
||||
var currentBindings = bindings.Take(conditionContextIndex + 1).ToList();
|
||||
IExpression clone = ci.GetClone(context, toClone, currentBindings, this, start, conditionContextIndex,
|
||||
isDef);
|
||||
IExpression clone = ci.GetClone(context, toClone, currentBindings, this, start, conditionContextIndex, isDef);
|
||||
clone = Builder.JaggedArrayIndex(clone, indices);
|
||||
if (true)
|
||||
// check if indices contains an expression that is not a top-level loop variable or loop local
|
||||
// check for dependence on variables tagged with GateBlock that do not depend on a inner lop
|
||||
bool isSubset = indices.Any(bracket => bracket.Any(index =>
|
||||
{
|
||||
// check if indices contains an expression that is not a top-level loop variable or loop local
|
||||
bool isSubset = false;
|
||||
foreach (var bracket in indices)
|
||||
// Based on GateAnalysisTransform.ContainsLocalVars
|
||||
bool containsLocalVars = Recognizer.GetVariables(index).Any(indexVar =>
|
||||
context.InputAttributes.Get<GateBlock>(indexVar) == ci.gateBlock);
|
||||
if (containsLocalVars)
|
||||
{
|
||||
foreach (var index in bracket)
|
||||
{
|
||||
var indexVar = Recognizer.GetVariableDeclaration(index);
|
||||
if (indexVar == null)
|
||||
{
|
||||
isSubset = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
return Recognizer.GetVariables(index).Any(indexVar =>
|
||||
context.InputAttributes.Get<GateBlock>(indexVar) == ci.gateBlock &&
|
||||
Recognizer.GetLoopForVariable(context, indexVar) == null);
|
||||
}
|
||||
|
||||
RecordReplacement(expr, toClone, !isSubset);
|
||||
}
|
||||
else
|
||||
{
|
||||
var indexVarDecl = Recognizer.GetVariableDeclaration(index);
|
||||
return indexVarDecl == null;
|
||||
}
|
||||
}));
|
||||
RecordReplacement(expr, toClone, !isSubset);
|
||||
|
||||
int replaceCount = 0;
|
||||
expr = Builder.ReplaceExpression(expr, toReplace, clone, ref replaceCount);
|
||||
|
@ -614,7 +617,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
/// </summary>
|
||||
/// <param name="expr1"></param>
|
||||
/// <param name="expr2"></param>
|
||||
/// <param name="indices"></param>
|
||||
/// <param name="indices">Contains indices of <paramref name="expr1"/> that were replaced by wildcards in <paramref name="expr2"/></param>
|
||||
/// <returns></returns>
|
||||
private static IExpression GetMatchingPrefix(IExpression expr1, IExpression expr2,
|
||||
out List<IEnumerable<IExpression>> indices)
|
||||
|
@ -891,7 +894,10 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
if (Builder.ContainsExpression(expr, conditionLhs))
|
||||
context.Error($"Internal: expr ({expr}) contains the conditionLhs ({conditionLhs})");
|
||||
ClonedVarInfo cvi = GetClonedVarInfo(context, eb, isDef);
|
||||
if (cvi == null) return expr;
|
||||
if (cvi == null)
|
||||
{
|
||||
return expr;
|
||||
}
|
||||
if (cvi.arrayDecl == null)
|
||||
{
|
||||
var extraBindings = IndexingTransform.FilterBindingSet(eb.Bindings,
|
||||
|
@ -929,9 +935,8 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
/// <returns></returns>
|
||||
public IExpression ReplaceAnyItem(BasicTransformContext context, IExpression expr, List<IList<IExpression>> indices)
|
||||
{
|
||||
if (expr is IArrayIndexerExpression)
|
||||
if (expr is IArrayIndexerExpression iaie)
|
||||
{
|
||||
IArrayIndexerExpression iaie = (IArrayIndexerExpression) expr;
|
||||
IExpression result = ReplaceAnyItem(context, iaie.Target, indices);
|
||||
IList<IExpression> newIndices = Builder.ExprCollection();
|
||||
IList<IExpression> allIndices = Builder.ExprCollection();
|
||||
|
@ -1090,8 +1095,8 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
// This is done by attaching the attribute GateExitingVariable to the clones array.
|
||||
// ChannelTransform later reads this attribute.
|
||||
bool useExitRandom =
|
||||
((algorithm is VariationalMessagePassing) &&
|
||||
((VariationalMessagePassing) algorithm).UseGateExitRandom);
|
||||
(algorithm is VariationalMessagePassing vmp) &&
|
||||
vmp.UseGateExitRandom;
|
||||
// if using Gate.ExitRandom, the clones should be marked as GateExiting variables
|
||||
if (useExitRandom)
|
||||
context.OutputAttributes.Set(cvi.arrayDecl, new VariationalMessagePassing.GateExitRandomVariable());
|
||||
|
@ -1245,13 +1250,12 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
|
||||
internal string ToString(IExpression expr)
|
||||
{
|
||||
if (expr is IVariableReferenceExpression)
|
||||
if (expr is IVariableReferenceExpression ivre)
|
||||
{
|
||||
return ((IVariableReferenceExpression) expr).Variable.Resolve().Name;
|
||||
return ivre.Variable.Resolve().Name;
|
||||
}
|
||||
else if (expr is IArrayIndexerExpression)
|
||||
else if (expr is IArrayIndexerExpression iaie)
|
||||
{
|
||||
IArrayIndexerExpression iaie = (IArrayIndexerExpression) expr;
|
||||
StringBuilder sb = new StringBuilder(ToString(iaie.Target));
|
||||
foreach (IExpression indExpr in iaie.Indices)
|
||||
sb.Append("_" + ToString(indExpr));
|
||||
|
@ -1330,10 +1334,9 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
for (int i = 0; i < statements.Count; i++)
|
||||
{
|
||||
IStatement s = statements[i];
|
||||
if (s is IForStatement)
|
||||
if (s is IForStatement ifs)
|
||||
{
|
||||
// Recursively wrap the body statements.
|
||||
IForStatement ifs = (IForStatement) s;
|
||||
IForStatement fs = Builder.ForStmt();
|
||||
fs.Condition = ifs.Condition;
|
||||
fs.Increment = ifs.Increment;
|
||||
|
@ -1344,10 +1347,9 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
currentConditionStatement = null;
|
||||
continue;
|
||||
}
|
||||
if (s is IRepeatStatement)
|
||||
if (s is IRepeatStatement irs)
|
||||
{
|
||||
// Recursively wrap the body statements.
|
||||
IRepeatStatement irs = (IRepeatStatement) s;
|
||||
IRepeatStatement rs = Builder.RepeatStmt();
|
||||
rs.Count = irs.Count;
|
||||
rs.Body = WrapBlockWithConditionals(context, irs.Body);
|
||||
|
@ -1356,10 +1358,9 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
currentConditionStatement = null;
|
||||
continue;
|
||||
}
|
||||
if (s is IConditionStatement && !CodeRecognizer.IsStochastic(context, ((IConditionStatement) s).Condition))
|
||||
if (s is IConditionStatement ics && !CodeRecognizer.IsStochastic(context, ics.Condition))
|
||||
{
|
||||
// Recursively wrap the body statements.
|
||||
IConditionStatement ics = (IConditionStatement) s;
|
||||
IConditionStatement cs = Builder.CondStmt();
|
||||
cs.Condition = ics.Condition;
|
||||
cs.Then = WrapBlockWithConditionals(context, ics.Then);
|
||||
|
@ -1611,8 +1612,8 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
if (caseNumbers == null) return true;
|
||||
if (definitionCount == 0) return true;
|
||||
IExpression dimension = caseNumbers.Dimensions[0];
|
||||
if (!(dimension is ILiteralExpression)) return true;
|
||||
return (definitionCount == (int) ((ILiteralExpression) dimension).Value);
|
||||
if (dimension is ILiteralExpression ile) return definitionCount == (int)ile.Value;
|
||||
return true;
|
||||
}
|
||||
|
||||
public override string ToString()
|
||||
|
|
|
@ -92,14 +92,13 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
bool extraLiteralsAreZero = true;
|
||||
int parentIndex = context.InputStack.Count - 2;
|
||||
object parent = context.GetAncestor(parentIndex);
|
||||
while (parent is IArrayIndexerExpression)
|
||||
while (parent is IArrayIndexerExpression parent_iaie)
|
||||
{
|
||||
IArrayIndexerExpression parent_iaie = (IArrayIndexerExpression)parent;
|
||||
foreach (IExpression index in parent_iaie.Indices)
|
||||
{
|
||||
if (index is ILiteralExpression)
|
||||
if (index is ILiteralExpression ile)
|
||||
{
|
||||
int value = (int)((ILiteralExpression)index).Value;
|
||||
int value = (int)ile.Value;
|
||||
if (value != 0)
|
||||
{
|
||||
extraLiteralsAreZero = false;
|
||||
|
@ -154,73 +153,74 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
var stmtsAfter = Builder.StmtCollection();
|
||||
|
||||
// does the expression have the form array[indices[k]][indices2[k]][indices3[k]]?
|
||||
if (newvd == null && UseGetItems && iaie.Target is IArrayIndexerExpression &&
|
||||
iaie.Indices.Count == 1 && iaie.Indices[0] is IArrayIndexerExpression)
|
||||
if (newvd == null && UseGetItems && iaie.Indices.Count == 1)
|
||||
{
|
||||
IArrayIndexerExpression index3 = (IArrayIndexerExpression)iaie.Indices[0];
|
||||
IArrayIndexerExpression iaie2 = (IArrayIndexerExpression)iaie.Target;
|
||||
if (index3.Indices.Count == 1 && index3.Indices[0] is IVariableReferenceExpression &&
|
||||
iaie2.Target is IArrayIndexerExpression &&
|
||||
iaie2.Indices.Count == 1 && iaie2.Indices[0] is IArrayIndexerExpression)
|
||||
if (iaie.Target is IArrayIndexerExpression iaie2 &&
|
||||
iaie.Indices[0] is IArrayIndexerExpression index3 &&
|
||||
index3.Indices.Count == 1 &&
|
||||
index3.Indices[0] is IVariableReferenceExpression innerIndex3 &&
|
||||
iaie2.Target is IArrayIndexerExpression iaie3 &&
|
||||
iaie2.Indices.Count == 1 &&
|
||||
iaie2.Indices[0] is IArrayIndexerExpression index2 &&
|
||||
index2.Indices.Count == 1 &&
|
||||
index2.Indices[0] is IVariableReferenceExpression innerIndex2 &&
|
||||
innerIndex2.Equals(innerIndex3) &&
|
||||
iaie3.Indices.Count == 1 &&
|
||||
iaie3.Indices[0] is IArrayIndexerExpression index &&
|
||||
index.Indices.Count == 1 &&
|
||||
index.Indices[0] is IVariableReferenceExpression innerIndex &&
|
||||
innerIndex.Equals(innerIndex2))
|
||||
{
|
||||
IArrayIndexerExpression index2 = (IArrayIndexerExpression)iaie2.Indices[0];
|
||||
IArrayIndexerExpression iaie3 = (IArrayIndexerExpression)iaie2.Target;
|
||||
if (index2.Indices.Count == 1 && index2.Indices[0] is IVariableReferenceExpression &&
|
||||
iaie3.Indices.Count == 1 && iaie3.Indices[0] is IArrayIndexerExpression)
|
||||
IForStatement innerLoop = Recognizer.GetLoopForVariable(context, innerIndex);
|
||||
if (innerLoop != null &&
|
||||
AreLoopsDisjoint(innerLoop, iaie3.Target, index.Target))
|
||||
{
|
||||
IArrayIndexerExpression index = (IArrayIndexerExpression)iaie3.Indices[0];
|
||||
IVariableReferenceExpression innerIndex = (IVariableReferenceExpression)index.Indices[0];
|
||||
IForStatement innerLoop = Recognizer.GetLoopForVariable(context, innerIndex);
|
||||
if (index.Indices.Count == 1 && index2.Indices[0].Equals(innerIndex)
|
||||
&& index3.Indices[0].Equals(innerIndex) &&
|
||||
innerLoop != null && AreLoopsDisjoint(innerLoop, iaie3.Target, index.Target))
|
||||
// expression has the form array[indices[k]][indices2[k]][indices3[k]]
|
||||
if (isDef)
|
||||
{
|
||||
// expression has the form array[indices[k]][indices2[k]][indices3[k]]
|
||||
if (isDef)
|
||||
{
|
||||
Error("fancy indexing not allowed on left hand side");
|
||||
return iaie;
|
||||
}
|
||||
WarnIfLocal(index.Target, iaie3.Target, iaie);
|
||||
WarnIfLocal(index2.Target, iaie3.Target, iaie);
|
||||
WarnIfLocal(index3.Target, iaie3.Target, iaie);
|
||||
containers = RemoveReferencesTo(containers, innerIndex);
|
||||
IExpression loopSize = Recognizer.LoopSizeExpression(innerLoop);
|
||||
var indices = Recognizer.GetIndices(iaie);
|
||||
// Build name of replacement variable from index values
|
||||
StringBuilder sb = new StringBuilder("_item");
|
||||
AppendIndexString(sb, iaie3);
|
||||
AppendIndexString(sb, iaie2);
|
||||
AppendIndexString(sb, iaie);
|
||||
string name = ToString(iaie3.Target) + sb.ToString();
|
||||
VariableInformation varInfo = VariableInformation.GetVariableInformation(context, baseVar);
|
||||
newvd = varInfo.DeriveArrayVariable(stmts, context, name, loopSize, Recognizer.GetVariableDeclaration(innerIndex), indices);
|
||||
if (!context.InputAttributes.Has<DerivedVariable>(newvd))
|
||||
context.InputAttributes.Set(newvd, new DerivedVariable());
|
||||
IExpression getItems = Builder.StaticGenericMethod(new Func<IReadOnlyList<IReadOnlyList<IReadOnlyList<PlaceHolder>>>, IReadOnlyList<int>, IReadOnlyList<int>, IReadOnlyList<int>, PlaceHolder[]>(Collection.GetItemsFromDeepJagged),
|
||||
new Type[] { tp }, iaie3.Target, index.Target, index2.Target, index3.Target);
|
||||
context.InputAttributes.CopyObjectAttributesTo<Algorithm>(baseVar, context.OutputAttributes, getItems);
|
||||
stmts.Add(Builder.AssignStmt(Builder.VarRefExpr(newvd), getItems));
|
||||
newExpr = Builder.ArrayIndex(Builder.VarRefExpr(newvd), innerIndex);
|
||||
rhsExpr = getItems;
|
||||
Error("fancy indexing not allowed on left hand side");
|
||||
return iaie;
|
||||
}
|
||||
WarnIfLocal(index.Target, iaie3.Target, iaie);
|
||||
WarnIfLocal(index2.Target, iaie3.Target, iaie);
|
||||
WarnIfLocal(index3.Target, iaie3.Target, iaie);
|
||||
containers = RemoveReferencesTo(containers, innerIndex);
|
||||
IExpression loopSize = Recognizer.LoopSizeExpression(innerLoop);
|
||||
var indices = Recognizer.GetIndices(iaie);
|
||||
// Build name of replacement variable from index values
|
||||
StringBuilder sb = new StringBuilder("_item");
|
||||
AppendIndexString(sb, iaie3);
|
||||
AppendIndexString(sb, iaie2);
|
||||
AppendIndexString(sb, iaie);
|
||||
string name = ToString(iaie3.Target) + sb.ToString();
|
||||
VariableInformation varInfo = VariableInformation.GetVariableInformation(context, baseVar);
|
||||
newvd = varInfo.DeriveArrayVariable(stmts, context, name, loopSize, Recognizer.GetVariableDeclaration(innerIndex), indices);
|
||||
if (!context.InputAttributes.Has<DerivedVariable>(newvd))
|
||||
context.InputAttributes.Set(newvd, new DerivedVariable());
|
||||
IExpression getItems = Builder.StaticGenericMethod(new Func<IReadOnlyList<IReadOnlyList<IReadOnlyList<PlaceHolder>>>, IReadOnlyList<int>, IReadOnlyList<int>, IReadOnlyList<int>, PlaceHolder[]>(Collection.GetItemsFromDeepJagged),
|
||||
new Type[] { tp }, iaie3.Target, index.Target, index2.Target, index3.Target);
|
||||
context.InputAttributes.CopyObjectAttributesTo<Algorithm>(baseVar, context.OutputAttributes, getItems);
|
||||
stmts.Add(Builder.AssignStmt(Builder.VarRefExpr(newvd), getItems));
|
||||
newExpr = Builder.ArrayIndex(Builder.VarRefExpr(newvd), innerIndex);
|
||||
rhsExpr = getItems;
|
||||
}
|
||||
}
|
||||
}
|
||||
// does the expression have the form array[indices[k]][indices2[k]]?
|
||||
if (newvd == null && UseGetItems && iaie.Target is IArrayIndexerExpression &&
|
||||
iaie.Indices.Count == 1 && iaie.Indices[0] is IArrayIndexerExpression)
|
||||
if (newvd == null && UseGetItems && iaie.Indices.Count == 1)
|
||||
{
|
||||
IArrayIndexerExpression index2 = (IArrayIndexerExpression)iaie.Indices[0];
|
||||
IArrayIndexerExpression target = (IArrayIndexerExpression)iaie.Target;
|
||||
if (index2.Indices.Count == 1 && index2.Indices[0] is IVariableReferenceExpression &&
|
||||
target.Indices.Count == 1 && target.Indices[0] is IArrayIndexerExpression)
|
||||
if (iaie.Target is IArrayIndexerExpression target &&
|
||||
iaie.Indices[0] is IArrayIndexerExpression index2 &&
|
||||
index2.Indices.Count == 1 &&
|
||||
index2.Indices[0] is IVariableReferenceExpression innerIndex &&
|
||||
target.Indices.Count == 1 &&
|
||||
target.Indices[0] is IArrayIndexerExpression index)
|
||||
{
|
||||
IVariableReferenceExpression innerIndex = (IVariableReferenceExpression)index2.Indices[0];
|
||||
IArrayIndexerExpression index = (IArrayIndexerExpression)target.Indices[0];
|
||||
IForStatement innerLoop = Recognizer.GetLoopForVariable(context, innerIndex);
|
||||
if (index.Indices.Count == 1 && index.Indices[0].Equals(innerIndex) &&
|
||||
innerLoop != null && AreLoopsDisjoint(innerLoop, target.Target, index.Target))
|
||||
if (index.Indices.Count == 1 &&
|
||||
index.Indices[0].Equals(innerIndex) &&
|
||||
innerLoop != null &&
|
||||
AreLoopsDisjoint(innerLoop, target.Target, index.Target))
|
||||
{
|
||||
// expression has the form array[indices[k]][indices2[k]]
|
||||
if (isDef)
|
||||
|
@ -233,17 +233,17 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
var indexTarget = index.Target;
|
||||
var index2Target = index2.Target;
|
||||
// check if the index array is jagged, i.e. array[indices[k][j]]
|
||||
while (indexTarget is IArrayIndexerExpression && index2Target is IArrayIndexerExpression)
|
||||
while (indexTarget is IArrayIndexerExpression indexTargetExpr &&
|
||||
index2Target is IArrayIndexerExpression index2TargetExpr)
|
||||
{
|
||||
IArrayIndexerExpression indexTargetExpr = (IArrayIndexerExpression)indexTarget;
|
||||
IArrayIndexerExpression index2TargetExpr = (IArrayIndexerExpression)index2Target;
|
||||
if (indexTargetExpr.Indices.Count == 1 && indexTargetExpr.Indices[0] is IVariableReferenceExpression &&
|
||||
index2TargetExpr.Indices.Count == 1 && index2TargetExpr.Indices[0] is IVariableReferenceExpression)
|
||||
if (indexTargetExpr.Indices.Count == 1 &&
|
||||
indexTargetExpr.Indices[0] is IVariableReferenceExpression innerIndexTarget &&
|
||||
index2TargetExpr.Indices.Count == 1 &&
|
||||
index2TargetExpr.Indices[0] is IVariableReferenceExpression innerIndex2Target)
|
||||
{
|
||||
IVariableReferenceExpression innerIndexTarget = (IVariableReferenceExpression)indexTargetExpr.Indices[0];
|
||||
IVariableReferenceExpression innerIndex2Target = (IVariableReferenceExpression)index2TargetExpr.Indices[0];
|
||||
IForStatement indexTargetLoop = Recognizer.GetLoopForVariable(context, innerIndexTarget);
|
||||
if (indexTargetLoop != null && AreLoopsDisjoint(indexTargetLoop, target.Target, indexTargetExpr.Target) &&
|
||||
if (indexTargetLoop != null &&
|
||||
AreLoopsDisjoint(indexTargetLoop, target.Target, indexTargetExpr.Target) &&
|
||||
innerIndexTarget.Equals(innerIndex2Target))
|
||||
{
|
||||
innerLoops.Add(indexTargetLoop);
|
||||
|
@ -319,76 +319,76 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
string name = ToString(iaie.Target) + sb.ToString();
|
||||
|
||||
// does the expression have the form array[indices[k]]?
|
||||
if (UseGetItems && iaie.Indices.Count == 1 && iaie.Indices[0] is IArrayIndexerExpression)
|
||||
if (UseGetItems &&
|
||||
iaie.Indices.Count == 1 &&
|
||||
iaie.Indices[0] is IArrayIndexerExpression index &&
|
||||
index.Indices.Count == 1 &&
|
||||
index.Indices[0] is IVariableReferenceExpression innerIndex)
|
||||
{
|
||||
IArrayIndexerExpression index = (IArrayIndexerExpression)iaie.Indices[0];
|
||||
if (index.Indices.Count == 1 && index.Indices[0] is IVariableReferenceExpression)
|
||||
// expression has the form array[indices[k]]
|
||||
IForStatement innerLoop = Recognizer.GetLoopForVariable(context, innerIndex);
|
||||
if (innerLoop != null &&
|
||||
AreLoopsDisjoint(innerLoop, iaie.Target, index.Target))
|
||||
{
|
||||
// expression has the form array[indices[k]]
|
||||
IVariableReferenceExpression innerIndex = (IVariableReferenceExpression)index.Indices[0];
|
||||
IForStatement innerLoop = Recognizer.GetLoopForVariable(context, innerIndex);
|
||||
if (innerLoop != null && AreLoopsDisjoint(innerLoop, iaie.Target, index.Target))
|
||||
if (isDef)
|
||||
{
|
||||
if (isDef)
|
||||
Error("fancy indexing not allowed on left hand side");
|
||||
return iaie;
|
||||
}
|
||||
var innerLoops = new List<IForStatement>();
|
||||
innerLoops.Add(innerLoop);
|
||||
var indexTarget = index.Target;
|
||||
// check if the index array is jagged, i.e. array[indices[k][j]]
|
||||
while (indexTarget is IArrayIndexerExpression index2)
|
||||
{
|
||||
if (index2.Indices.Count == 1 &&
|
||||
index2.Indices[0] is IVariableReferenceExpression innerIndex2)
|
||||
{
|
||||
Error("fancy indexing not allowed on left hand side");
|
||||
return iaie;
|
||||
}
|
||||
var innerLoops = new List<IForStatement>();
|
||||
innerLoops.Add(innerLoop);
|
||||
var indexTarget = index.Target;
|
||||
// check if the index array is jagged, i.e. array[indices[k][j]]
|
||||
while (indexTarget is IArrayIndexerExpression)
|
||||
{
|
||||
IArrayIndexerExpression index2 = (IArrayIndexerExpression)indexTarget;
|
||||
if (index2.Indices.Count == 1 && index2.Indices[0] is IVariableReferenceExpression)
|
||||
IForStatement innerLoop2 = Recognizer.GetLoopForVariable(context, innerIndex2);
|
||||
if (innerLoop2 != null &&
|
||||
AreLoopsDisjoint(innerLoop2, iaie.Target, index2.Target))
|
||||
{
|
||||
IVariableReferenceExpression innerIndex2 = (IVariableReferenceExpression)index2.Indices[0];
|
||||
IForStatement innerLoop2 = Recognizer.GetLoopForVariable(context, innerIndex2);
|
||||
if (innerLoop2 != null && AreLoopsDisjoint(innerLoop2, iaie.Target, index2.Target))
|
||||
{
|
||||
innerLoops.Add(innerLoop2);
|
||||
indexTarget = index2.Target;
|
||||
// This limit must match the number of handled cases below.
|
||||
if (innerLoops.Count == 3) break;
|
||||
}
|
||||
else
|
||||
break;
|
||||
innerLoops.Add(innerLoop2);
|
||||
indexTarget = index2.Target;
|
||||
// This limit must match the number of handled cases below.
|
||||
if (innerLoops.Count == 3) break;
|
||||
}
|
||||
else
|
||||
break;
|
||||
}
|
||||
WarnIfLocal(indexTarget, iaie.Target, originalExpr);
|
||||
innerLoops.Reverse();
|
||||
var loopSizes = innerLoops.ListSelect(ifs => new[] { Recognizer.LoopSizeExpression(ifs) });
|
||||
var newIndexVars = innerLoops.ListSelect(ifs => new[] { Recognizer.LoopVariable(ifs) });
|
||||
newvd = varInfo.DeriveArrayVariable(stmts, context, name, loopSizes, newIndexVars, indices);
|
||||
if (!context.InputAttributes.Has<DerivedVariable>(newvd))
|
||||
context.InputAttributes.Set(newvd, new DerivedVariable());
|
||||
IExpression getItems;
|
||||
if (innerLoops.Count == 1)
|
||||
{
|
||||
getItems = Builder.StaticGenericMethod(new Func<IReadOnlyList<PlaceHolder>, IReadOnlyList<int>, PlaceHolder[]>(Collection.GetItems),
|
||||
new Type[] { tp }, iaie.Target, indexTarget);
|
||||
}
|
||||
else if (innerLoops.Count == 2)
|
||||
{
|
||||
getItems = Builder.StaticGenericMethod(new Func<IReadOnlyList<PlaceHolder>, IReadOnlyList<IReadOnlyList<int>>, PlaceHolder[][]>(Collection.GetJaggedItems),
|
||||
new Type[] { tp }, iaie.Target, indexTarget);
|
||||
}
|
||||
else if (innerLoops.Count == 3)
|
||||
{
|
||||
getItems = Builder.StaticGenericMethod(new Func<IReadOnlyList<PlaceHolder>, IReadOnlyList<IReadOnlyList<IReadOnlyList<int>>>, PlaceHolder[][][]>(Collection.GetDeepJaggedItems),
|
||||
new Type[] { tp }, iaie.Target, indexTarget);
|
||||
}
|
||||
else
|
||||
throw new NotImplementedException($"innerLoops.Count = {innerLoops.Count}");
|
||||
context.InputAttributes.CopyObjectAttributesTo<Algorithm>(baseVar, context.OutputAttributes, getItems);
|
||||
stmts.Add(Builder.AssignStmt(Builder.VarRefExpr(newvd), getItems));
|
||||
var newIndices = newIndexVars.ListSelect(ivds => Util.ArrayInit(ivds.Length, i => Builder.VarRefExpr(ivds[i])));
|
||||
newExpr = Builder.JaggedArrayIndex(Builder.VarRefExpr(newvd), newIndices);
|
||||
rhsExpr = getItems;
|
||||
break;
|
||||
}
|
||||
WarnIfLocal(indexTarget, iaie.Target, originalExpr);
|
||||
innerLoops.Reverse();
|
||||
var loopSizes = innerLoops.ListSelect(ifs => new[] { Recognizer.LoopSizeExpression(ifs) });
|
||||
var newIndexVars = innerLoops.ListSelect(ifs => new[] { Recognizer.LoopVariable(ifs) });
|
||||
newvd = varInfo.DeriveArrayVariable(stmts, context, name, loopSizes, newIndexVars, indices);
|
||||
if (!context.InputAttributes.Has<DerivedVariable>(newvd))
|
||||
context.InputAttributes.Set(newvd, new DerivedVariable());
|
||||
IExpression getItems;
|
||||
if (innerLoops.Count == 1)
|
||||
{
|
||||
getItems = Builder.StaticGenericMethod(new Func<IReadOnlyList<PlaceHolder>, IReadOnlyList<int>, PlaceHolder[]>(Collection.GetItems),
|
||||
new Type[] { tp }, iaie.Target, indexTarget);
|
||||
}
|
||||
else if (innerLoops.Count == 2)
|
||||
{
|
||||
getItems = Builder.StaticGenericMethod(new Func<IReadOnlyList<PlaceHolder>, IReadOnlyList<IReadOnlyList<int>>, PlaceHolder[][]>(Collection.GetJaggedItems),
|
||||
new Type[] { tp }, iaie.Target, indexTarget);
|
||||
}
|
||||
else if (innerLoops.Count == 3)
|
||||
{
|
||||
getItems = Builder.StaticGenericMethod(new Func<IReadOnlyList<PlaceHolder>, IReadOnlyList<IReadOnlyList<IReadOnlyList<int>>>, PlaceHolder[][][]>(Collection.GetDeepJaggedItems),
|
||||
new Type[] { tp }, iaie.Target, indexTarget);
|
||||
}
|
||||
else
|
||||
throw new NotImplementedException($"innerLoops.Count = {innerLoops.Count}");
|
||||
context.InputAttributes.CopyObjectAttributesTo<Algorithm>(baseVar, context.OutputAttributes, getItems);
|
||||
stmts.Add(Builder.AssignStmt(Builder.VarRefExpr(newvd), getItems));
|
||||
var newIndices = newIndexVars.ListSelect(ivds => Util.ArrayInit(ivds.Length, i => Builder.VarRefExpr(ivds[i])));
|
||||
newExpr = Builder.JaggedArrayIndex(Builder.VarRefExpr(newvd), newIndices);
|
||||
rhsExpr = getItems;
|
||||
}
|
||||
}
|
||||
if (newvd == null)
|
||||
|
|
|
@ -337,7 +337,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
}
|
||||
if (embeddedLoopIndices.Contains(currentLoop))
|
||||
{
|
||||
string warningText = "This model will consume a lot of memory due to the indexing expression {0} inside of a loop over {1}. Try simplifying this expression in your model, perhaps by creating auxiliary index arrays.";
|
||||
string warningText = "This model will consume excess memory due to the indexing expression {0} inside of a loop over {1}. Try simplifying this expression in your model, perhaps by creating auxiliary index arrays.";
|
||||
Warning(string.Format(warningText, originalExpr, loopVar.Name));
|
||||
}
|
||||
// split expr into a target and extra indices, where target will be replicated and extra indices will be added later
|
||||
|
|
|
@ -511,7 +511,7 @@ namespace Microsoft.ML.Probabilistic.Learners
|
|||
/// <typeparam name="TLabelSource">The type of a source of labels.</typeparam>
|
||||
/// <param name="mapping">The mapping used for accessing data in the native format.</param>
|
||||
/// <returns>The binary Bayes point machine classifier instance.</returns>
|
||||
internal static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, bool, Bernoulli, GaussianBayesPointMachineClassifierTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<bool>>
|
||||
public static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, bool, Bernoulli, GaussianBayesPointMachineClassifierTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<bool>>
|
||||
CreateGaussianPriorBinaryClassifier<TInstanceSource, TInstance, TLabelSource>(
|
||||
IBayesPointMachineClassifierMapping<TInstanceSource, TInstance, TLabelSource, bool> mapping)
|
||||
{
|
||||
|
@ -532,7 +532,7 @@ namespace Microsoft.ML.Probabilistic.Learners
|
|||
/// <typeparam name="TLabelSource">The type of a source of labels.</typeparam>
|
||||
/// <param name="mapping">The mapping used for accessing data in the native format.</param>
|
||||
/// <returns>The multi-class Bayes point machine classifier instance.</returns>
|
||||
internal static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, int, Discrete, GaussianBayesPointMachineClassifierTrainingSettings, MulticlassBayesPointMachineClassifierPredictionSettings<int>>
|
||||
public static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, int, Discrete, GaussianBayesPointMachineClassifierTrainingSettings, MulticlassBayesPointMachineClassifierPredictionSettings<int>>
|
||||
CreateGaussianPriorMulticlassClassifier<TInstanceSource, TInstance, TLabelSource>(
|
||||
IBayesPointMachineClassifierMapping<TInstanceSource, TInstance, TLabelSource, int> mapping)
|
||||
{
|
||||
|
@ -554,7 +554,7 @@ namespace Microsoft.ML.Probabilistic.Learners
|
|||
/// <typeparam name="TLabel">The type of a label.</typeparam>
|
||||
/// <param name="mapping">The mapping used for accessing data in the standard format.</param>
|
||||
/// <returns>The binary Bayes point machine classifier instance.</returns>
|
||||
internal static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, IDictionary<TLabel, double>, GaussianBayesPointMachineClassifierTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<TLabel>>
|
||||
public static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, IDictionary<TLabel, double>, GaussianBayesPointMachineClassifierTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<TLabel>>
|
||||
CreateGaussianPriorBinaryClassifier<TInstanceSource, TInstance, TLabelSource, TLabel>(
|
||||
IClassifierMapping<TInstanceSource, TInstance, TLabelSource, TLabel, Vector> mapping)
|
||||
{
|
||||
|
@ -576,7 +576,7 @@ namespace Microsoft.ML.Probabilistic.Learners
|
|||
/// <typeparam name="TLabel">The type of a label.</typeparam>
|
||||
/// <param name="mapping">The mapping used for accessing data in the standard format.</param>
|
||||
/// <returns>The multi-class Bayes point machine classifier instance.</returns>
|
||||
internal static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, IDictionary<TLabel, double>, GaussianBayesPointMachineClassifierTrainingSettings, MulticlassBayesPointMachineClassifierPredictionSettings<TLabel>>
|
||||
public static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, IDictionary<TLabel, double>, GaussianBayesPointMachineClassifierTrainingSettings, MulticlassBayesPointMachineClassifierPredictionSettings<TLabel>>
|
||||
CreateGaussianPriorMulticlassClassifier<TInstanceSource, TInstance, TLabelSource, TLabel>(
|
||||
IClassifierMapping<TInstanceSource, TInstance, TLabelSource, TLabel, Vector> mapping)
|
||||
{
|
||||
|
@ -603,7 +603,7 @@ namespace Microsoft.ML.Probabilistic.Learners
|
|||
/// <typeparam name="TLabelDistribution">The type of a distribution over labels.</typeparam>
|
||||
/// <param name="fileName">The file name.</param>
|
||||
/// <returns>The deserialized binary Bayes point machine classifier object.</returns>
|
||||
internal static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, GaussianBayesPointMachineClassifierTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<TLabel>>
|
||||
public static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, GaussianBayesPointMachineClassifierTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<TLabel>>
|
||||
LoadGaussianPriorBinaryClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution>(string fileName)
|
||||
{
|
||||
return Utilities.Load<IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, GaussianBayesPointMachineClassifierTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<TLabel>>>(fileName);
|
||||
|
@ -621,7 +621,7 @@ namespace Microsoft.ML.Probabilistic.Learners
|
|||
/// <param name="stream">The stream.</param>
|
||||
/// <param name="formatter">The formatter.</param>
|
||||
/// <returns>The deserialized binary Bayes point machine classifier object.</returns>
|
||||
internal static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, GaussianBayesPointMachineClassifierTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<TLabel>>
|
||||
public static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, GaussianBayesPointMachineClassifierTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<TLabel>>
|
||||
LoadGaussianPriorBinaryClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution>(Stream stream, IFormatter formatter)
|
||||
{
|
||||
return Utilities.Load<IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, GaussianBayesPointMachineClassifierTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<TLabel>>>(stream, formatter);
|
||||
|
@ -638,7 +638,7 @@ namespace Microsoft.ML.Probabilistic.Learners
|
|||
/// <typeparam name="TLabelDistribution">The type of a distribution over labels.</typeparam>
|
||||
/// <param name="fileName">The file name.</param>
|
||||
/// <returns>The deserialized multi-class Bayes point machine classifier object.</returns>
|
||||
internal static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, GaussianBayesPointMachineClassifierTrainingSettings, MulticlassBayesPointMachineClassifierPredictionSettings<TLabel>>
|
||||
public static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, GaussianBayesPointMachineClassifierTrainingSettings, MulticlassBayesPointMachineClassifierPredictionSettings<TLabel>>
|
||||
LoadBackwardCompatibleGaussianPriorMulticlassClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution>(string fileName)
|
||||
{
|
||||
return Utilities.Load<IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, GaussianBayesPointMachineClassifierTrainingSettings, MulticlassBayesPointMachineClassifierPredictionSettings<TLabel>>>(fileName);
|
||||
|
@ -656,7 +656,7 @@ namespace Microsoft.ML.Probabilistic.Learners
|
|||
/// <param name="stream">The stream.</param>
|
||||
/// <param name="formatter">The formatter.</param>
|
||||
/// <returns>The deserialized multi-class Bayes point machine classifier object.</returns>
|
||||
internal static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, GaussianBayesPointMachineClassifierTrainingSettings, MulticlassBayesPointMachineClassifierPredictionSettings<TLabel>>
|
||||
public static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, GaussianBayesPointMachineClassifierTrainingSettings, MulticlassBayesPointMachineClassifierPredictionSettings<TLabel>>
|
||||
LoadBackwardCompatibleGaussianPriorMulticlassClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution>(Stream stream, IFormatter formatter)
|
||||
{
|
||||
return Utilities.Load<IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, GaussianBayesPointMachineClassifierTrainingSettings, MulticlassBayesPointMachineClassifierPredictionSettings<TLabel>>>(stream, formatter);
|
||||
|
@ -676,7 +676,7 @@ namespace Microsoft.ML.Probabilistic.Learners
|
|||
/// <param name="reader">The reader to a stream of a serialized binary Bayes point machine classifier.</param>
|
||||
/// <param name="mapping">The mapping used for accessing data in the native format.</param>
|
||||
/// <returns>The binary Bayes point machine classifier instance.</returns>
|
||||
internal static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, bool, Bernoulli, GaussianBayesPointMachineClassifierTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<bool>>
|
||||
public static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, bool, Bernoulli, GaussianBayesPointMachineClassifierTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<bool>>
|
||||
LoadBackwardCompatibleGaussianPriorBinaryClassifier<TInstanceSource, TInstance, TLabelSource>(
|
||||
IReader reader, IBayesPointMachineClassifierMapping<TInstanceSource, TInstance, TLabelSource, bool> mapping)
|
||||
{
|
||||
|
@ -703,7 +703,7 @@ namespace Microsoft.ML.Probabilistic.Learners
|
|||
/// <param name="fileName">The name of the file of a serialized binary Bayes point machine classifier.</param>
|
||||
/// <param name="mapping">The mapping used for accessing data in the native format.</param>
|
||||
/// <returns>The binary Bayes point machine classifier instance.</returns>
|
||||
internal static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, bool, Bernoulli, GaussianBayesPointMachineClassifierTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<bool>>
|
||||
public static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, bool, Bernoulli, GaussianBayesPointMachineClassifierTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<bool>>
|
||||
LoadBackwardCompatibleGaussianPriorBinaryClassifier<TInstanceSource, TInstance, TLabelSource>(
|
||||
string fileName, IBayesPointMachineClassifierMapping<TInstanceSource, TInstance, TLabelSource, bool> mapping)
|
||||
{
|
||||
|
@ -728,7 +728,7 @@ namespace Microsoft.ML.Probabilistic.Learners
|
|||
/// <param name="reader">The reader to a stream of a serialized multi-class Bayes point machine classifier.</param>
|
||||
/// <param name="mapping">The mapping used for accessing data in the native format.</param>
|
||||
/// <returns>The multi-class Bayes point machine classifier instance.</returns>
|
||||
internal static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, int, Discrete, GaussianBayesPointMachineClassifierTrainingSettings, MulticlassBayesPointMachineClassifierPredictionSettings<int>>
|
||||
public static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, int, Discrete, GaussianBayesPointMachineClassifierTrainingSettings, MulticlassBayesPointMachineClassifierPredictionSettings<int>>
|
||||
LoadBackwardCompatibleGaussianPriorMulticlassClassifier<TInstanceSource, TInstance, TLabelSource>(
|
||||
IReader reader, IBayesPointMachineClassifierMapping<TInstanceSource, TInstance, TLabelSource, int> mapping)
|
||||
{
|
||||
|
@ -755,7 +755,7 @@ namespace Microsoft.ML.Probabilistic.Learners
|
|||
/// <param name="fileName">The name of the file of a serialized multi-class Bayes point machine classifier.</param>
|
||||
/// <param name="mapping">The mapping used for accessing data in the native format.</param>
|
||||
/// <returns>The multi-class Bayes point machine classifier instance.</returns>
|
||||
internal static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, int, Discrete, GaussianBayesPointMachineClassifierTrainingSettings, MulticlassBayesPointMachineClassifierPredictionSettings<int>>
|
||||
public static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, int, Discrete, GaussianBayesPointMachineClassifierTrainingSettings, MulticlassBayesPointMachineClassifierPredictionSettings<int>>
|
||||
LoadBackwardCompatibleGaussianPriorMulticlassClassifier<TInstanceSource, TInstance, TLabelSource>(
|
||||
string fileName, IBayesPointMachineClassifierMapping<TInstanceSource, TInstance, TLabelSource, int> mapping)
|
||||
{
|
||||
|
@ -781,7 +781,7 @@ namespace Microsoft.ML.Probabilistic.Learners
|
|||
/// <param name="reader">The reader to a stream of a serialized binary Bayes point machine classifier.</param>
|
||||
/// <param name="mapping">The mapping used for accessing data in the standard format.</param>
|
||||
/// <returns>The binary Bayes point machine classifier instance.</returns>
|
||||
internal static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, IDictionary<TLabel, double>, GaussianBayesPointMachineClassifierTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<TLabel>>
|
||||
public static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, IDictionary<TLabel, double>, GaussianBayesPointMachineClassifierTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<TLabel>>
|
||||
LoadBackwardCompatibleGaussianPriorBinaryClassifier<TInstanceSource, TInstance, TLabelSource, TLabel>(
|
||||
IReader reader, IClassifierMapping<TInstanceSource, TInstance, TLabelSource, TLabel, Vector> mapping)
|
||||
{
|
||||
|
@ -809,7 +809,7 @@ namespace Microsoft.ML.Probabilistic.Learners
|
|||
/// <param name="fileName">The name of the file of a serialized binary Bayes point machine classifier.</param>
|
||||
/// <param name="mapping">The mapping used for accessing data in the standard format.</param>
|
||||
/// <returns>The binary Bayes point machine classifier instance.</returns>
|
||||
internal static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, IDictionary<TLabel, double>, GaussianBayesPointMachineClassifierTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<TLabel>>
|
||||
public static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, IDictionary<TLabel, double>, GaussianBayesPointMachineClassifierTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<TLabel>>
|
||||
LoadBackwardCompatibleGaussianPriorBinaryClassifier<TInstanceSource, TInstance, TLabelSource, TLabel>(
|
||||
string fileName, IClassifierMapping<TInstanceSource, TInstance, TLabelSource, TLabel, Vector> mapping)
|
||||
{
|
||||
|
@ -835,7 +835,7 @@ namespace Microsoft.ML.Probabilistic.Learners
|
|||
/// <param name="reader">The reader to a stream of a serialized multi-class Bayes point machine classifier.</param>
|
||||
/// <param name="mapping">The mapping used for accessing data in the standard format.</param>
|
||||
/// <returns>The multi-class Bayes point machine classifier instance.</returns>
|
||||
internal static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, IDictionary<TLabel, double>, GaussianBayesPointMachineClassifierTrainingSettings, MulticlassBayesPointMachineClassifierPredictionSettings<TLabel>>
|
||||
public static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, IDictionary<TLabel, double>, GaussianBayesPointMachineClassifierTrainingSettings, MulticlassBayesPointMachineClassifierPredictionSettings<TLabel>>
|
||||
LoadBackwardCompatibleGaussianPriorMulticlassClassifier<TInstanceSource, TInstance, TLabelSource, TLabel>(
|
||||
IReader reader, IClassifierMapping<TInstanceSource, TInstance, TLabelSource, TLabel, Vector> mapping)
|
||||
{
|
||||
|
@ -863,7 +863,7 @@ namespace Microsoft.ML.Probabilistic.Learners
|
|||
/// <param name="fileName">The name of the file of a serialized multi-class Bayes point machine classifier.</param>
|
||||
/// <param name="mapping">The mapping used for accessing data in the standard format.</param>
|
||||
/// <returns>The multi-class Bayes point machine classifier instance.</returns>
|
||||
internal static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, IDictionary<TLabel, double>, GaussianBayesPointMachineClassifierTrainingSettings, MulticlassBayesPointMachineClassifierPredictionSettings<TLabel>>
|
||||
public static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, IDictionary<TLabel, double>, GaussianBayesPointMachineClassifierTrainingSettings, MulticlassBayesPointMachineClassifierPredictionSettings<TLabel>>
|
||||
LoadBackwardCompatibleGaussianPriorMulticlassClassifier<TInstanceSource, TInstance, TLabelSource, TLabel>(
|
||||
string fileName, IClassifierMapping<TInstanceSource, TInstance, TLabelSource, TLabel, Vector> mapping)
|
||||
{
|
||||
|
|
|
@ -17,7 +17,7 @@ namespace Microsoft.ML.Probabilistic.Learners.BayesPointMachineClassifierInterna
|
|||
/// These settings cannot be modified after training.
|
||||
/// </remarks>
|
||||
[Serializable]
|
||||
internal class GaussianBayesPointMachineClassifierAdvancedTrainingSettings : ICustomSerializable
|
||||
public class GaussianBayesPointMachineClassifierAdvancedTrainingSettings : ICustomSerializable
|
||||
{
|
||||
/// <summary>
|
||||
/// The current custom binary serialization version of the <see cref="GaussianBayesPointMachineClassifierAdvancedTrainingSettings"/> class.
|
||||
|
|
|
@ -14,7 +14,7 @@ namespace Microsoft.ML.Probabilistic.Learners.BayesPointMachineClassifierInterna
|
|||
/// Settings which affect training of the Bayes point machine classifier with <see cref="Gaussian"/> prior distributions over weights.
|
||||
/// </summary>
|
||||
[Serializable]
|
||||
internal class GaussianBayesPointMachineClassifierTrainingSettings : BayesPointMachineClassifierTrainingSettings
|
||||
public class GaussianBayesPointMachineClassifierTrainingSettings : BayesPointMachineClassifierTrainingSettings
|
||||
{
|
||||
/// <summary>
|
||||
/// The current serialization version of <see cref="GaussianBayesPointMachineClassifierTrainingSettings"/>.
|
||||
|
|
|
@ -18805,7 +18805,7 @@
|
|||
</remarks>
|
||||
</message_doc>
|
||||
</message_op_class>
|
||||
<message_op_class name="DoubleIsBetweenOp">
|
||||
<message_op_class name="IsBetweenGaussianOp">
|
||||
<doc>
|
||||
<summary>Provides outgoing messages for <see cref="Factor.IsBetween(double, double, double)" />, given random arguments to the function.</summary>
|
||||
</doc>
|
||||
|
@ -19212,7 +19212,7 @@
|
|||
</remarks>
|
||||
</message_doc>
|
||||
</message_op_class>
|
||||
<message_op_class name="TruncatedGaussianIsBetweenOp">
|
||||
<message_op_class name="IsBetweenTruncatedGaussianOp">
|
||||
<doc>
|
||||
<summary>Provides outgoing messages for <see cref="Factor.IsBetween(double, double, double)" />, given random arguments to the function.</summary>
|
||||
</doc>
|
||||
|
|
|
@ -27,7 +27,7 @@
|
|||
<Grid.ColumnDefinitions>
|
||||
<ColumnDefinition Width="3*"/>
|
||||
<ColumnDefinition Width="Auto"/>
|
||||
<ColumnDefinition Name="AttributesColumn" Width="0"/>
|
||||
<ColumnDefinition Name="AttributesColumn" Width="250"/>
|
||||
</Grid.ColumnDefinitions>
|
||||
<Grid Margin="5" Grid.ColumnSpan="3">
|
||||
<Grid.ColumnDefinitions>
|
||||
|
|
|
@ -109,8 +109,10 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
|
||||
using (Variable.IfNot(playersFirstGame))
|
||||
{
|
||||
var prevGame = PrevGameIndices[game][gamePlayer];//, 0).Named("prevGame");
|
||||
var prevGamePlayerIndex = PrevGamePlayerIndices[game][gamePlayer];//, 0).Named("prevGamePlayerIndex");
|
||||
////var prevGame = Variable.Min(PrevGameIndices[game][gamePlayer], 0).Named("prevGame");
|
||||
////var prevGamePlayerIndex = Variable.Min(PrevGamePlayerIndices[game][gamePlayer], 0).Named("prevGamePlayerIndex");
|
||||
var prevGame = PrevGameIndices[game][gamePlayer];
|
||||
var prevGamePlayerIndex = PrevGamePlayerIndices[game][gamePlayer];
|
||||
Skills[game][gamePlayer] = Variable.GaussianFromMeanAndVariance(
|
||||
Skills[prevGame][prevGamePlayerIndex],
|
||||
dynamicsVariance
|
||||
|
|
|
@ -4337,6 +4337,11 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
Variable.ConstrainEqualRandom(b[indexObs], likelihood2[index]);
|
||||
}
|
||||
InferenceEngine engine = new InferenceEngine();
|
||||
engine.Compiler.Compiled += (sender, e) =>
|
||||
{
|
||||
// check for the inefficient replication warning
|
||||
Assert.Equal(1, e.Warnings.Count);
|
||||
};
|
||||
for (int trial = 0; trial < 2; trial++)
|
||||
{
|
||||
indexObs.ObservedValue = trial;
|
||||
|
@ -7173,7 +7178,7 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
engine.Compiler.Compiled += (sender, e) =>
|
||||
{
|
||||
// check for the inefficient replication warning
|
||||
Assert.True(e.Warnings.Count == 1);
|
||||
Assert.Equal(1, e.Warnings.Count);
|
||||
};
|
||||
DistributionArray<Bernoulli> bDist = engine.Infer<DistributionArray<Bernoulli>>(b);
|
||||
for (int i = 0; i < bDist.Count; i++)
|
||||
|
@ -7215,7 +7220,7 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
engine.Compiler.Compiled += (sender, e) =>
|
||||
{
|
||||
// check for the inefficient replication warning
|
||||
Assert.True(e.Warnings.Count == 1);
|
||||
Assert.Equal(1, e.Warnings.Count);
|
||||
};
|
||||
DistributionArray<Bernoulli> bDist = engine.Infer<DistributionArray<Bernoulli>>(b);
|
||||
for (int i = 0; i < bDist.Count; i++)
|
||||
|
@ -7231,6 +7236,47 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void EnterArrayElementsTest3()
|
||||
{
|
||||
Range item = new Range(2).Named("item");
|
||||
Range xitem = new Range(2).Named("xitem");
|
||||
VariableArray<bool> x = Variable.Array<bool>(xitem).Named("x");
|
||||
double xPrior = 0.3;
|
||||
x[xitem] = Variable.Bernoulli(xPrior).ForEach(xitem);
|
||||
VariableArray<int> indices = Variable.Array<int>(item).Named("indices");
|
||||
indices.ObservedValue = new int[] { 0, 1 };
|
||||
VariableArray<int> indices2 = Variable.Array<int>(item).Named("indices2");
|
||||
indices2.ObservedValue = new int[] { 1, 0 };
|
||||
VariableArray<bool> b = Variable.Array<bool>(item).Named("b");
|
||||
double bPrior = 0.1;
|
||||
double xLike = 0.4;
|
||||
using (Variable.ForEach(item))
|
||||
{
|
||||
b[item] = Variable.Bernoulli(bPrior);
|
||||
using (Variable.If(b[item]))
|
||||
{
|
||||
var index = Variable.Max(0, indices[item]);
|
||||
Variable.ConstrainEqualRandom(x[index], new Bernoulli(xLike));
|
||||
}
|
||||
}
|
||||
|
||||
InferenceEngine engine = new InferenceEngine();
|
||||
engine.ShowProgress = false;
|
||||
engine.Compiler.Compiled += (sender, e) =>
|
||||
{
|
||||
// check for the inefficient replication warning
|
||||
Assert.Equal(1, e.Warnings.Count);
|
||||
};
|
||||
DistributionArray<Bernoulli> bDist = engine.Infer<DistributionArray<Bernoulli>>(b);
|
||||
for (int i = 0; i < bDist.Count; i++)
|
||||
{
|
||||
double sumCondT = xPrior * xLike + (1 - xPrior) * (1 - xLike);
|
||||
double bPost = bPrior * sumCondT / (bPrior * sumCondT + (1 - bPrior));
|
||||
Assert.True(MMath.AbsDiff(bDist[i].GetProbTrue(), bPost, 1e-10) < 1e-10);
|
||||
}
|
||||
}
|
||||
|
||||
internal void FairCoinTest()
|
||||
{
|
||||
VariableArray<bool> tosses = Variable.Constant(new bool[] { true, true, true, true, true });
|
||||
|
|
|
@ -1333,7 +1333,7 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
engine.Compiler.Compiled += (sender, e) =>
|
||||
{
|
||||
// check for the inefficient replication warning
|
||||
Assert.True(e.Warnings.Count == 0);
|
||||
Assert.Equal(0, e.Warnings.Count);
|
||||
};
|
||||
engine.Infer(userThresholds);
|
||||
}
|
||||
|
@ -1359,7 +1359,7 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
engine.Compiler.Compiled += (sender, e) =>
|
||||
{
|
||||
// check for the inefficient replication warning
|
||||
Assert.True(e.Warnings.Count == 1);
|
||||
Assert.Equal(1, e.Warnings.Count);
|
||||
};
|
||||
var boolsActual = engine.Infer<IList<Bernoulli>>(bools);
|
||||
var boolsExpected = new BernoulliArray(C.SizeAsInt, i =>
|
||||
|
@ -1402,7 +1402,7 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
engine.Compiler.Compiled += (sender, e) =>
|
||||
{
|
||||
// check for the inefficient replication warning
|
||||
Assert.True(e.Warnings.Count == 1);
|
||||
Assert.Equal(1, e.Warnings.Count);
|
||||
};
|
||||
var boolsActual = engine.Infer<IList<Bernoulli>>(bools);
|
||||
var boolsExpected = new BernoulliArray(C.SizeAsInt, i =>
|
||||
|
@ -1438,7 +1438,7 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
engine.Compiler.Compiled += (sender, e) =>
|
||||
{
|
||||
// check for the inefficient replication warning
|
||||
Assert.True(e.Warnings.Count == 1);
|
||||
Assert.Equal(1, e.Warnings.Count);
|
||||
};
|
||||
var boolsActual = engine.Infer<IList<Bernoulli>>(bools);
|
||||
var boolsExpected = new BernoulliArray(C.SizeAsInt, i =>
|
||||
|
@ -1477,7 +1477,7 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
engine.Compiler.Compiled += (sender, e) =>
|
||||
{
|
||||
// check for the inefficient replication warning
|
||||
//Assert.True(e.Warnings.Count == 1);
|
||||
Assert.Equal(1, e.Warnings.Count);
|
||||
};
|
||||
var boolsActual = engine.Infer<IList<Bernoulli>>(bools);
|
||||
var boolsExpected = new BernoulliArray(C.SizeAsInt, i =>
|
||||
|
@ -1729,7 +1729,7 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
engine.Compiler.Compiled += (sender, e) =>
|
||||
{
|
||||
// check for the inefficient replication warning
|
||||
Assert.True(e.Warnings.Count == 1);
|
||||
Assert.Equal(1, e.Warnings.Count);
|
||||
};
|
||||
IDistribution<bool[]> post = engine.Infer<IDistribution<bool[]>>(bools);
|
||||
}
|
||||
|
@ -1866,7 +1866,7 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
engine.Compiler.Compiled += (sender, e) =>
|
||||
{
|
||||
// check for the inefficient replication warning
|
||||
Assert.True(e.Warnings.Count == 1);
|
||||
Assert.Equal(1, e.Warnings.Count);
|
||||
};
|
||||
IDistribution<bool[][]> weightsActual = engine.Infer<IDistribution<bool[][]>>(bools);
|
||||
}
|
||||
|
@ -2433,6 +2433,11 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
if (!(algorithm is GibbsSampling)) block.CloseBlock();
|
||||
|
||||
InferenceEngine engine = new InferenceEngine(algorithm);
|
||||
engine.Compiler.Compiled += (sender, e) =>
|
||||
{
|
||||
// check for the inefficient replication warning
|
||||
Assert.Equal(0, e.Warnings.Count);
|
||||
};
|
||||
indices.ObservedValue = new int[] {0};
|
||||
indicesLength.ObservedValue = indices.ObservedValue.Length;
|
||||
Bernoulli[][] arrayExpectedArray = new Bernoulli[item.SizeAsInt][];
|
||||
|
@ -2524,12 +2529,17 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
double xLike = 0.6;
|
||||
using (Variable.Switch(switchVar))
|
||||
{
|
||||
Variable<bool> x = array[indices[switchVar][xitem]].Named("x");
|
||||
Variable<bool> x = array[indices[switchVar][xitem]];
|
||||
Variable.ConstrainEqualRandom(x, new Bernoulli(xLike));
|
||||
}
|
||||
if (!(algorithm is GibbsSampling)) block.CloseBlock();
|
||||
|
||||
InferenceEngine engine = new InferenceEngine(algorithm);
|
||||
engine.Compiler.Compiled += (sender, e) =>
|
||||
{
|
||||
// check for the inefficient replication warning
|
||||
Assert.Equal(0, e.Warnings.Count);
|
||||
};
|
||||
indices.ObservedValue = new int[][] {new int[] {0}, new int[] {3}};
|
||||
indicesLength.ObservedValue = new int[] {indices.ObservedValue[0].Length, indices.ObservedValue[1].Length};
|
||||
Bernoulli[] arrayExpectedArray = new Bernoulli[item.SizeAsInt];
|
||||
|
|
|
@ -407,7 +407,6 @@ namespace Microsoft.ML.Probabilistic.Tests
|
|||
|
||||
InferenceEngine engine = new InferenceEngine();
|
||||
//engine.Algorithm = new VariationalMessagePassing();
|
||||
engine.Compiler.CatchExceptions = true;
|
||||
//engine.Compiler.UnrollLoops = true;
|
||||
//engine.Compiler.UseSerialSchedules = false;
|
||||
//engine.ResetOnObservedValueChanged = false;
|
||||
|
|
Загрузка…
Ссылка в новой задаче