BayesPointMachineClassifier.CreateGaussianPriorBinaryClassifier is public. (#376)

* Added Variable.Max(int,int).
* Compiler warns about excess memory consumption in more cases when it should, and fewer cases when it shouldn't.
* TransformBrowser shows attributes by default.
* Updated FactorDocs
This commit is contained in:
Tom Minka 2021-12-10 23:37:00 +00:00 коммит произвёл GitHub
Родитель cc132f77f8
Коммит 3c825f2d29
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
14 изменённых файлов: 326 добавлений и 264 удалений

Просмотреть файл

@ -3465,6 +3465,17 @@ namespace Microsoft.ML.Probabilistic.Models
return Variable<double>.Factor<double, double>(System.Math.Max, a, b);
}
/// <summary>
/// Returns an int variable which is the maximum of two int variables
/// </summary>
/// <param name="a">The first variable</param>
/// <param name="b">The second variable</param>
/// <returns>A new variable</returns>
public static Variable<int> Max(Variable<int> a, Variable<int> b)
{
return Variable<int>.Factor(System.Math.Max, a, b);
}
/// <summary>
/// Returns a double variable which is the minimum of two double variables
/// </summary>

Просмотреть файл

@ -591,72 +591,6 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
return expr;
}
/// <summary>
/// Returns an expression equal to expr1 and expr2 under their respective bindings, or null if the expressions are not equal.
/// </summary>
/// <param name="expr1"></param>
/// <param name="bindings1"></param>
/// <param name="expr2"></param>
/// <param name="bindings2"></param>
/// <returns></returns>
private static IExpression Unify(
IExpression expr1,
IEnumerable<IReadOnlyCollection<ConditionBinding>> bindings1,
IExpression expr2,
IEnumerable<IReadOnlyCollection<ConditionBinding>> bindings2)
{
if (expr1.Equals(expr2))
{
return expr1;
}
foreach (IReadOnlyCollection<ConditionBinding> binding2 in bindings2)
{
IExpression expr1b = ReplaceExpression(binding2, expr1);
if (expr1b.Equals(expr2))
{
return expr1;
}
}
foreach (IReadOnlyCollection<ConditionBinding> binding1 in bindings1)
{
IExpression expr2b = ReplaceExpression(binding1, expr2);
if (expr2b.Equals(expr1))
{
return expr2;
}
}
bool lift = false;
if (lift)
{
IExpression lifted1 = GetLiftedExpression(expr1, bindings1);
IExpression lifted2 = GetLiftedExpression(expr2, bindings2);
if (lifted1 != null && lifted1.Equals(lifted2)) return lifted1;
}
return Builder.StaticMethod(new Func<int>(GateAnalysisTransform.AnyIndex));
}
private static IExpression GetLiftedExpression(IExpression expr, IEnumerable<IReadOnlyCollection<ConditionBinding>> bindings)
{
IExpression lifted = null;
foreach (IReadOnlyCollection<ConditionBinding> binding in bindings)
{
IExpression lhs = null;
foreach (ConditionBinding b in binding)
{
if (b.rhs.Equals(expr)) lhs = b.lhs;
}
if (lifted == null)
{
lifted = lhs;
}
else if (lhs == null || !lifted.Equals(lhs))
{
return null;
}
}
return lifted;
}
/// <summary>
/// Used only in ArrayIndexerExpressions, to represent a wildcard.
/// </summary>
@ -747,6 +681,65 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
result.Bindings.AddRange(eb2.Bindings);
}
return result;
// Returns an expression equal to expr1 and expr2 under their respective bindings, or null if the expressions are not equal.
IExpression Unify(
IExpression expr1,
IEnumerable<IReadOnlyCollection<ConditionBinding>> bindings1,
IExpression expr2,
IEnumerable<IReadOnlyCollection<ConditionBinding>> bindings2)
{
if (expr1.Equals(expr2))
{
return expr1;
}
foreach (IReadOnlyCollection<ConditionBinding> binding2 in bindings2)
{
IExpression expr1b = ReplaceExpression(binding2, expr1);
if (expr1b.Equals(expr2))
{
return expr1;
}
}
foreach (IReadOnlyCollection<ConditionBinding> binding1 in bindings1)
{
IExpression expr2b = ReplaceExpression(binding1, expr2);
if (expr2b.Equals(expr1))
{
return expr2;
}
}
bool lift = false;
if (lift)
{
IExpression lifted1 = GetLiftedExpression(expr1, bindings1);
IExpression lifted2 = GetLiftedExpression(expr2, bindings2);
if (lifted1 != null && lifted1.Equals(lifted2)) return lifted1;
}
return Builder.StaticMethod(new Func<int>(GateAnalysisTransform.AnyIndex));
IExpression GetLiftedExpression(IExpression expr, IEnumerable<IReadOnlyCollection<ConditionBinding>> bindings)
{
IExpression lifted = null;
foreach (IReadOnlyCollection<ConditionBinding> binding in bindings)
{
IExpression lhs = null;
foreach (ConditionBinding b in binding)
{
if (b.rhs.Equals(expr)) lhs = b.lhs;
}
if (lifted == null)
{
lifted = lhs;
}
else if (lhs == null || !lifted.Equals(lhs))
{
return null;
}
}
return lifted;
}
}
}
}

Просмотреть файл

@ -14,6 +14,7 @@ using Microsoft.ML.Probabilistic.Utilities;
using System.Linq;
using Microsoft.ML.Probabilistic.Algorithms;
using Microsoft.ML.Probabilistic.Models.Attributes;
using Microsoft.ML.Probabilistic.Models;
namespace Microsoft.ML.Probabilistic.Compiler.Transforms
{
@ -82,11 +83,13 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
}
string warningText =
"This model will consume a lot of memory due to the following mix of indexing expressions inside of a conditional: {0}";
"This model will consume excess memory due to the following mix of indexing expressions inside of a conditional: {0}";
foreach (var entry in inefficientReplacements)
{
if (entry.Value != null)
{
Warning(string.Format(warningText, StringUtil.CollectionToString(entry.Value, ", ")));
}
}
}
@ -554,28 +557,28 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
if (toReplace != null)
{
var currentBindings = bindings.Take(conditionContextIndex + 1).ToList();
IExpression clone = ci.GetClone(context, toClone, currentBindings, this, start, conditionContextIndex,
isDef);
IExpression clone = ci.GetClone(context, toClone, currentBindings, this, start, conditionContextIndex, isDef);
clone = Builder.JaggedArrayIndex(clone, indices);
if (true)
// check if indices contains an expression that is not a top-level loop variable or loop local
// check for dependence on variables tagged with GateBlock that do not depend on a inner lop
bool isSubset = indices.Any(bracket => bracket.Any(index =>
{
// check if indices contains an expression that is not a top-level loop variable or loop local
bool isSubset = false;
foreach (var bracket in indices)
// Based on GateAnalysisTransform.ContainsLocalVars
bool containsLocalVars = Recognizer.GetVariables(index).Any(indexVar =>
context.InputAttributes.Get<GateBlock>(indexVar) == ci.gateBlock);
if (containsLocalVars)
{
foreach (var index in bracket)
{
var indexVar = Recognizer.GetVariableDeclaration(index);
if (indexVar == null)
{
isSubset = true;
break;
}
}
return Recognizer.GetVariables(index).Any(indexVar =>
context.InputAttributes.Get<GateBlock>(indexVar) == ci.gateBlock &&
Recognizer.GetLoopForVariable(context, indexVar) == null);
}
RecordReplacement(expr, toClone, !isSubset);
}
else
{
var indexVarDecl = Recognizer.GetVariableDeclaration(index);
return indexVarDecl == null;
}
}));
RecordReplacement(expr, toClone, !isSubset);
int replaceCount = 0;
expr = Builder.ReplaceExpression(expr, toReplace, clone, ref replaceCount);
@ -614,7 +617,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
/// </summary>
/// <param name="expr1"></param>
/// <param name="expr2"></param>
/// <param name="indices"></param>
/// <param name="indices">Contains indices of <paramref name="expr1"/> that were replaced by wildcards in <paramref name="expr2"/></param>
/// <returns></returns>
private static IExpression GetMatchingPrefix(IExpression expr1, IExpression expr2,
out List<IEnumerable<IExpression>> indices)
@ -891,7 +894,10 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
if (Builder.ContainsExpression(expr, conditionLhs))
context.Error($"Internal: expr ({expr}) contains the conditionLhs ({conditionLhs})");
ClonedVarInfo cvi = GetClonedVarInfo(context, eb, isDef);
if (cvi == null) return expr;
if (cvi == null)
{
return expr;
}
if (cvi.arrayDecl == null)
{
var extraBindings = IndexingTransform.FilterBindingSet(eb.Bindings,
@ -929,9 +935,8 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
/// <returns></returns>
public IExpression ReplaceAnyItem(BasicTransformContext context, IExpression expr, List<IList<IExpression>> indices)
{
if (expr is IArrayIndexerExpression)
if (expr is IArrayIndexerExpression iaie)
{
IArrayIndexerExpression iaie = (IArrayIndexerExpression) expr;
IExpression result = ReplaceAnyItem(context, iaie.Target, indices);
IList<IExpression> newIndices = Builder.ExprCollection();
IList<IExpression> allIndices = Builder.ExprCollection();
@ -1090,8 +1095,8 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
// This is done by attaching the attribute GateExitingVariable to the clones array.
// ChannelTransform later reads this attribute.
bool useExitRandom =
((algorithm is VariationalMessagePassing) &&
((VariationalMessagePassing) algorithm).UseGateExitRandom);
(algorithm is VariationalMessagePassing vmp) &&
vmp.UseGateExitRandom;
// if using Gate.ExitRandom, the clones should be marked as GateExiting variables
if (useExitRandom)
context.OutputAttributes.Set(cvi.arrayDecl, new VariationalMessagePassing.GateExitRandomVariable());
@ -1245,13 +1250,12 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
internal string ToString(IExpression expr)
{
if (expr is IVariableReferenceExpression)
if (expr is IVariableReferenceExpression ivre)
{
return ((IVariableReferenceExpression) expr).Variable.Resolve().Name;
return ivre.Variable.Resolve().Name;
}
else if (expr is IArrayIndexerExpression)
else if (expr is IArrayIndexerExpression iaie)
{
IArrayIndexerExpression iaie = (IArrayIndexerExpression) expr;
StringBuilder sb = new StringBuilder(ToString(iaie.Target));
foreach (IExpression indExpr in iaie.Indices)
sb.Append("_" + ToString(indExpr));
@ -1330,10 +1334,9 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
for (int i = 0; i < statements.Count; i++)
{
IStatement s = statements[i];
if (s is IForStatement)
if (s is IForStatement ifs)
{
// Recursively wrap the body statements.
IForStatement ifs = (IForStatement) s;
IForStatement fs = Builder.ForStmt();
fs.Condition = ifs.Condition;
fs.Increment = ifs.Increment;
@ -1344,10 +1347,9 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
currentConditionStatement = null;
continue;
}
if (s is IRepeatStatement)
if (s is IRepeatStatement irs)
{
// Recursively wrap the body statements.
IRepeatStatement irs = (IRepeatStatement) s;
IRepeatStatement rs = Builder.RepeatStmt();
rs.Count = irs.Count;
rs.Body = WrapBlockWithConditionals(context, irs.Body);
@ -1356,10 +1358,9 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
currentConditionStatement = null;
continue;
}
if (s is IConditionStatement && !CodeRecognizer.IsStochastic(context, ((IConditionStatement) s).Condition))
if (s is IConditionStatement ics && !CodeRecognizer.IsStochastic(context, ics.Condition))
{
// Recursively wrap the body statements.
IConditionStatement ics = (IConditionStatement) s;
IConditionStatement cs = Builder.CondStmt();
cs.Condition = ics.Condition;
cs.Then = WrapBlockWithConditionals(context, ics.Then);
@ -1611,8 +1612,8 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
if (caseNumbers == null) return true;
if (definitionCount == 0) return true;
IExpression dimension = caseNumbers.Dimensions[0];
if (!(dimension is ILiteralExpression)) return true;
return (definitionCount == (int) ((ILiteralExpression) dimension).Value);
if (dimension is ILiteralExpression ile) return definitionCount == (int)ile.Value;
return true;
}
public override string ToString()

Просмотреть файл

@ -92,14 +92,13 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
bool extraLiteralsAreZero = true;
int parentIndex = context.InputStack.Count - 2;
object parent = context.GetAncestor(parentIndex);
while (parent is IArrayIndexerExpression)
while (parent is IArrayIndexerExpression parent_iaie)
{
IArrayIndexerExpression parent_iaie = (IArrayIndexerExpression)parent;
foreach (IExpression index in parent_iaie.Indices)
{
if (index is ILiteralExpression)
if (index is ILiteralExpression ile)
{
int value = (int)((ILiteralExpression)index).Value;
int value = (int)ile.Value;
if (value != 0)
{
extraLiteralsAreZero = false;
@ -154,73 +153,74 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
var stmtsAfter = Builder.StmtCollection();
// does the expression have the form array[indices[k]][indices2[k]][indices3[k]]?
if (newvd == null && UseGetItems && iaie.Target is IArrayIndexerExpression &&
iaie.Indices.Count == 1 && iaie.Indices[0] is IArrayIndexerExpression)
if (newvd == null && UseGetItems && iaie.Indices.Count == 1)
{
IArrayIndexerExpression index3 = (IArrayIndexerExpression)iaie.Indices[0];
IArrayIndexerExpression iaie2 = (IArrayIndexerExpression)iaie.Target;
if (index3.Indices.Count == 1 && index3.Indices[0] is IVariableReferenceExpression &&
iaie2.Target is IArrayIndexerExpression &&
iaie2.Indices.Count == 1 && iaie2.Indices[0] is IArrayIndexerExpression)
if (iaie.Target is IArrayIndexerExpression iaie2 &&
iaie.Indices[0] is IArrayIndexerExpression index3 &&
index3.Indices.Count == 1 &&
index3.Indices[0] is IVariableReferenceExpression innerIndex3 &&
iaie2.Target is IArrayIndexerExpression iaie3 &&
iaie2.Indices.Count == 1 &&
iaie2.Indices[0] is IArrayIndexerExpression index2 &&
index2.Indices.Count == 1 &&
index2.Indices[0] is IVariableReferenceExpression innerIndex2 &&
innerIndex2.Equals(innerIndex3) &&
iaie3.Indices.Count == 1 &&
iaie3.Indices[0] is IArrayIndexerExpression index &&
index.Indices.Count == 1 &&
index.Indices[0] is IVariableReferenceExpression innerIndex &&
innerIndex.Equals(innerIndex2))
{
IArrayIndexerExpression index2 = (IArrayIndexerExpression)iaie2.Indices[0];
IArrayIndexerExpression iaie3 = (IArrayIndexerExpression)iaie2.Target;
if (index2.Indices.Count == 1 && index2.Indices[0] is IVariableReferenceExpression &&
iaie3.Indices.Count == 1 && iaie3.Indices[0] is IArrayIndexerExpression)
IForStatement innerLoop = Recognizer.GetLoopForVariable(context, innerIndex);
if (innerLoop != null &&
AreLoopsDisjoint(innerLoop, iaie3.Target, index.Target))
{
IArrayIndexerExpression index = (IArrayIndexerExpression)iaie3.Indices[0];
IVariableReferenceExpression innerIndex = (IVariableReferenceExpression)index.Indices[0];
IForStatement innerLoop = Recognizer.GetLoopForVariable(context, innerIndex);
if (index.Indices.Count == 1 && index2.Indices[0].Equals(innerIndex)
&& index3.Indices[0].Equals(innerIndex) &&
innerLoop != null && AreLoopsDisjoint(innerLoop, iaie3.Target, index.Target))
// expression has the form array[indices[k]][indices2[k]][indices3[k]]
if (isDef)
{
// expression has the form array[indices[k]][indices2[k]][indices3[k]]
if (isDef)
{
Error("fancy indexing not allowed on left hand side");
return iaie;
}
WarnIfLocal(index.Target, iaie3.Target, iaie);
WarnIfLocal(index2.Target, iaie3.Target, iaie);
WarnIfLocal(index3.Target, iaie3.Target, iaie);
containers = RemoveReferencesTo(containers, innerIndex);
IExpression loopSize = Recognizer.LoopSizeExpression(innerLoop);
var indices = Recognizer.GetIndices(iaie);
// Build name of replacement variable from index values
StringBuilder sb = new StringBuilder("_item");
AppendIndexString(sb, iaie3);
AppendIndexString(sb, iaie2);
AppendIndexString(sb, iaie);
string name = ToString(iaie3.Target) + sb.ToString();
VariableInformation varInfo = VariableInformation.GetVariableInformation(context, baseVar);
newvd = varInfo.DeriveArrayVariable(stmts, context, name, loopSize, Recognizer.GetVariableDeclaration(innerIndex), indices);
if (!context.InputAttributes.Has<DerivedVariable>(newvd))
context.InputAttributes.Set(newvd, new DerivedVariable());
IExpression getItems = Builder.StaticGenericMethod(new Func<IReadOnlyList<IReadOnlyList<IReadOnlyList<PlaceHolder>>>, IReadOnlyList<int>, IReadOnlyList<int>, IReadOnlyList<int>, PlaceHolder[]>(Collection.GetItemsFromDeepJagged),
new Type[] { tp }, iaie3.Target, index.Target, index2.Target, index3.Target);
context.InputAttributes.CopyObjectAttributesTo<Algorithm>(baseVar, context.OutputAttributes, getItems);
stmts.Add(Builder.AssignStmt(Builder.VarRefExpr(newvd), getItems));
newExpr = Builder.ArrayIndex(Builder.VarRefExpr(newvd), innerIndex);
rhsExpr = getItems;
Error("fancy indexing not allowed on left hand side");
return iaie;
}
WarnIfLocal(index.Target, iaie3.Target, iaie);
WarnIfLocal(index2.Target, iaie3.Target, iaie);
WarnIfLocal(index3.Target, iaie3.Target, iaie);
containers = RemoveReferencesTo(containers, innerIndex);
IExpression loopSize = Recognizer.LoopSizeExpression(innerLoop);
var indices = Recognizer.GetIndices(iaie);
// Build name of replacement variable from index values
StringBuilder sb = new StringBuilder("_item");
AppendIndexString(sb, iaie3);
AppendIndexString(sb, iaie2);
AppendIndexString(sb, iaie);
string name = ToString(iaie3.Target) + sb.ToString();
VariableInformation varInfo = VariableInformation.GetVariableInformation(context, baseVar);
newvd = varInfo.DeriveArrayVariable(stmts, context, name, loopSize, Recognizer.GetVariableDeclaration(innerIndex), indices);
if (!context.InputAttributes.Has<DerivedVariable>(newvd))
context.InputAttributes.Set(newvd, new DerivedVariable());
IExpression getItems = Builder.StaticGenericMethod(new Func<IReadOnlyList<IReadOnlyList<IReadOnlyList<PlaceHolder>>>, IReadOnlyList<int>, IReadOnlyList<int>, IReadOnlyList<int>, PlaceHolder[]>(Collection.GetItemsFromDeepJagged),
new Type[] { tp }, iaie3.Target, index.Target, index2.Target, index3.Target);
context.InputAttributes.CopyObjectAttributesTo<Algorithm>(baseVar, context.OutputAttributes, getItems);
stmts.Add(Builder.AssignStmt(Builder.VarRefExpr(newvd), getItems));
newExpr = Builder.ArrayIndex(Builder.VarRefExpr(newvd), innerIndex);
rhsExpr = getItems;
}
}
}
// does the expression have the form array[indices[k]][indices2[k]]?
if (newvd == null && UseGetItems && iaie.Target is IArrayIndexerExpression &&
iaie.Indices.Count == 1 && iaie.Indices[0] is IArrayIndexerExpression)
if (newvd == null && UseGetItems && iaie.Indices.Count == 1)
{
IArrayIndexerExpression index2 = (IArrayIndexerExpression)iaie.Indices[0];
IArrayIndexerExpression target = (IArrayIndexerExpression)iaie.Target;
if (index2.Indices.Count == 1 && index2.Indices[0] is IVariableReferenceExpression &&
target.Indices.Count == 1 && target.Indices[0] is IArrayIndexerExpression)
if (iaie.Target is IArrayIndexerExpression target &&
iaie.Indices[0] is IArrayIndexerExpression index2 &&
index2.Indices.Count == 1 &&
index2.Indices[0] is IVariableReferenceExpression innerIndex &&
target.Indices.Count == 1 &&
target.Indices[0] is IArrayIndexerExpression index)
{
IVariableReferenceExpression innerIndex = (IVariableReferenceExpression)index2.Indices[0];
IArrayIndexerExpression index = (IArrayIndexerExpression)target.Indices[0];
IForStatement innerLoop = Recognizer.GetLoopForVariable(context, innerIndex);
if (index.Indices.Count == 1 && index.Indices[0].Equals(innerIndex) &&
innerLoop != null && AreLoopsDisjoint(innerLoop, target.Target, index.Target))
if (index.Indices.Count == 1 &&
index.Indices[0].Equals(innerIndex) &&
innerLoop != null &&
AreLoopsDisjoint(innerLoop, target.Target, index.Target))
{
// expression has the form array[indices[k]][indices2[k]]
if (isDef)
@ -233,17 +233,17 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
var indexTarget = index.Target;
var index2Target = index2.Target;
// check if the index array is jagged, i.e. array[indices[k][j]]
while (indexTarget is IArrayIndexerExpression && index2Target is IArrayIndexerExpression)
while (indexTarget is IArrayIndexerExpression indexTargetExpr &&
index2Target is IArrayIndexerExpression index2TargetExpr)
{
IArrayIndexerExpression indexTargetExpr = (IArrayIndexerExpression)indexTarget;
IArrayIndexerExpression index2TargetExpr = (IArrayIndexerExpression)index2Target;
if (indexTargetExpr.Indices.Count == 1 && indexTargetExpr.Indices[0] is IVariableReferenceExpression &&
index2TargetExpr.Indices.Count == 1 && index2TargetExpr.Indices[0] is IVariableReferenceExpression)
if (indexTargetExpr.Indices.Count == 1 &&
indexTargetExpr.Indices[0] is IVariableReferenceExpression innerIndexTarget &&
index2TargetExpr.Indices.Count == 1 &&
index2TargetExpr.Indices[0] is IVariableReferenceExpression innerIndex2Target)
{
IVariableReferenceExpression innerIndexTarget = (IVariableReferenceExpression)indexTargetExpr.Indices[0];
IVariableReferenceExpression innerIndex2Target = (IVariableReferenceExpression)index2TargetExpr.Indices[0];
IForStatement indexTargetLoop = Recognizer.GetLoopForVariable(context, innerIndexTarget);
if (indexTargetLoop != null && AreLoopsDisjoint(indexTargetLoop, target.Target, indexTargetExpr.Target) &&
if (indexTargetLoop != null &&
AreLoopsDisjoint(indexTargetLoop, target.Target, indexTargetExpr.Target) &&
innerIndexTarget.Equals(innerIndex2Target))
{
innerLoops.Add(indexTargetLoop);
@ -319,76 +319,76 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
string name = ToString(iaie.Target) + sb.ToString();
// does the expression have the form array[indices[k]]?
if (UseGetItems && iaie.Indices.Count == 1 && iaie.Indices[0] is IArrayIndexerExpression)
if (UseGetItems &&
iaie.Indices.Count == 1 &&
iaie.Indices[0] is IArrayIndexerExpression index &&
index.Indices.Count == 1 &&
index.Indices[0] is IVariableReferenceExpression innerIndex)
{
IArrayIndexerExpression index = (IArrayIndexerExpression)iaie.Indices[0];
if (index.Indices.Count == 1 && index.Indices[0] is IVariableReferenceExpression)
// expression has the form array[indices[k]]
IForStatement innerLoop = Recognizer.GetLoopForVariable(context, innerIndex);
if (innerLoop != null &&
AreLoopsDisjoint(innerLoop, iaie.Target, index.Target))
{
// expression has the form array[indices[k]]
IVariableReferenceExpression innerIndex = (IVariableReferenceExpression)index.Indices[0];
IForStatement innerLoop = Recognizer.GetLoopForVariable(context, innerIndex);
if (innerLoop != null && AreLoopsDisjoint(innerLoop, iaie.Target, index.Target))
if (isDef)
{
if (isDef)
Error("fancy indexing not allowed on left hand side");
return iaie;
}
var innerLoops = new List<IForStatement>();
innerLoops.Add(innerLoop);
var indexTarget = index.Target;
// check if the index array is jagged, i.e. array[indices[k][j]]
while (indexTarget is IArrayIndexerExpression index2)
{
if (index2.Indices.Count == 1 &&
index2.Indices[0] is IVariableReferenceExpression innerIndex2)
{
Error("fancy indexing not allowed on left hand side");
return iaie;
}
var innerLoops = new List<IForStatement>();
innerLoops.Add(innerLoop);
var indexTarget = index.Target;
// check if the index array is jagged, i.e. array[indices[k][j]]
while (indexTarget is IArrayIndexerExpression)
{
IArrayIndexerExpression index2 = (IArrayIndexerExpression)indexTarget;
if (index2.Indices.Count == 1 && index2.Indices[0] is IVariableReferenceExpression)
IForStatement innerLoop2 = Recognizer.GetLoopForVariable(context, innerIndex2);
if (innerLoop2 != null &&
AreLoopsDisjoint(innerLoop2, iaie.Target, index2.Target))
{
IVariableReferenceExpression innerIndex2 = (IVariableReferenceExpression)index2.Indices[0];
IForStatement innerLoop2 = Recognizer.GetLoopForVariable(context, innerIndex2);
if (innerLoop2 != null && AreLoopsDisjoint(innerLoop2, iaie.Target, index2.Target))
{
innerLoops.Add(innerLoop2);
indexTarget = index2.Target;
// This limit must match the number of handled cases below.
if (innerLoops.Count == 3) break;
}
else
break;
innerLoops.Add(innerLoop2);
indexTarget = index2.Target;
// This limit must match the number of handled cases below.
if (innerLoops.Count == 3) break;
}
else
break;
}
WarnIfLocal(indexTarget, iaie.Target, originalExpr);
innerLoops.Reverse();
var loopSizes = innerLoops.ListSelect(ifs => new[] { Recognizer.LoopSizeExpression(ifs) });
var newIndexVars = innerLoops.ListSelect(ifs => new[] { Recognizer.LoopVariable(ifs) });
newvd = varInfo.DeriveArrayVariable(stmts, context, name, loopSizes, newIndexVars, indices);
if (!context.InputAttributes.Has<DerivedVariable>(newvd))
context.InputAttributes.Set(newvd, new DerivedVariable());
IExpression getItems;
if (innerLoops.Count == 1)
{
getItems = Builder.StaticGenericMethod(new Func<IReadOnlyList<PlaceHolder>, IReadOnlyList<int>, PlaceHolder[]>(Collection.GetItems),
new Type[] { tp }, iaie.Target, indexTarget);
}
else if (innerLoops.Count == 2)
{
getItems = Builder.StaticGenericMethod(new Func<IReadOnlyList<PlaceHolder>, IReadOnlyList<IReadOnlyList<int>>, PlaceHolder[][]>(Collection.GetJaggedItems),
new Type[] { tp }, iaie.Target, indexTarget);
}
else if (innerLoops.Count == 3)
{
getItems = Builder.StaticGenericMethod(new Func<IReadOnlyList<PlaceHolder>, IReadOnlyList<IReadOnlyList<IReadOnlyList<int>>>, PlaceHolder[][][]>(Collection.GetDeepJaggedItems),
new Type[] { tp }, iaie.Target, indexTarget);
}
else
throw new NotImplementedException($"innerLoops.Count = {innerLoops.Count}");
context.InputAttributes.CopyObjectAttributesTo<Algorithm>(baseVar, context.OutputAttributes, getItems);
stmts.Add(Builder.AssignStmt(Builder.VarRefExpr(newvd), getItems));
var newIndices = newIndexVars.ListSelect(ivds => Util.ArrayInit(ivds.Length, i => Builder.VarRefExpr(ivds[i])));
newExpr = Builder.JaggedArrayIndex(Builder.VarRefExpr(newvd), newIndices);
rhsExpr = getItems;
break;
}
WarnIfLocal(indexTarget, iaie.Target, originalExpr);
innerLoops.Reverse();
var loopSizes = innerLoops.ListSelect(ifs => new[] { Recognizer.LoopSizeExpression(ifs) });
var newIndexVars = innerLoops.ListSelect(ifs => new[] { Recognizer.LoopVariable(ifs) });
newvd = varInfo.DeriveArrayVariable(stmts, context, name, loopSizes, newIndexVars, indices);
if (!context.InputAttributes.Has<DerivedVariable>(newvd))
context.InputAttributes.Set(newvd, new DerivedVariable());
IExpression getItems;
if (innerLoops.Count == 1)
{
getItems = Builder.StaticGenericMethod(new Func<IReadOnlyList<PlaceHolder>, IReadOnlyList<int>, PlaceHolder[]>(Collection.GetItems),
new Type[] { tp }, iaie.Target, indexTarget);
}
else if (innerLoops.Count == 2)
{
getItems = Builder.StaticGenericMethod(new Func<IReadOnlyList<PlaceHolder>, IReadOnlyList<IReadOnlyList<int>>, PlaceHolder[][]>(Collection.GetJaggedItems),
new Type[] { tp }, iaie.Target, indexTarget);
}
else if (innerLoops.Count == 3)
{
getItems = Builder.StaticGenericMethod(new Func<IReadOnlyList<PlaceHolder>, IReadOnlyList<IReadOnlyList<IReadOnlyList<int>>>, PlaceHolder[][][]>(Collection.GetDeepJaggedItems),
new Type[] { tp }, iaie.Target, indexTarget);
}
else
throw new NotImplementedException($"innerLoops.Count = {innerLoops.Count}");
context.InputAttributes.CopyObjectAttributesTo<Algorithm>(baseVar, context.OutputAttributes, getItems);
stmts.Add(Builder.AssignStmt(Builder.VarRefExpr(newvd), getItems));
var newIndices = newIndexVars.ListSelect(ivds => Util.ArrayInit(ivds.Length, i => Builder.VarRefExpr(ivds[i])));
newExpr = Builder.JaggedArrayIndex(Builder.VarRefExpr(newvd), newIndices);
rhsExpr = getItems;
}
}
if (newvd == null)

Просмотреть файл

@ -337,7 +337,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
}
if (embeddedLoopIndices.Contains(currentLoop))
{
string warningText = "This model will consume a lot of memory due to the indexing expression {0} inside of a loop over {1}. Try simplifying this expression in your model, perhaps by creating auxiliary index arrays.";
string warningText = "This model will consume excess memory due to the indexing expression {0} inside of a loop over {1}. Try simplifying this expression in your model, perhaps by creating auxiliary index arrays.";
Warning(string.Format(warningText, originalExpr, loopVar.Name));
}
// split expr into a target and extra indices, where target will be replicated and extra indices will be added later

Просмотреть файл

@ -511,7 +511,7 @@ namespace Microsoft.ML.Probabilistic.Learners
/// <typeparam name="TLabelSource">The type of a source of labels.</typeparam>
/// <param name="mapping">The mapping used for accessing data in the native format.</param>
/// <returns>The binary Bayes point machine classifier instance.</returns>
internal static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, bool, Bernoulli, GaussianBayesPointMachineClassifierTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<bool>>
public static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, bool, Bernoulli, GaussianBayesPointMachineClassifierTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<bool>>
CreateGaussianPriorBinaryClassifier<TInstanceSource, TInstance, TLabelSource>(
IBayesPointMachineClassifierMapping<TInstanceSource, TInstance, TLabelSource, bool> mapping)
{
@ -532,7 +532,7 @@ namespace Microsoft.ML.Probabilistic.Learners
/// <typeparam name="TLabelSource">The type of a source of labels.</typeparam>
/// <param name="mapping">The mapping used for accessing data in the native format.</param>
/// <returns>The multi-class Bayes point machine classifier instance.</returns>
internal static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, int, Discrete, GaussianBayesPointMachineClassifierTrainingSettings, MulticlassBayesPointMachineClassifierPredictionSettings<int>>
public static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, int, Discrete, GaussianBayesPointMachineClassifierTrainingSettings, MulticlassBayesPointMachineClassifierPredictionSettings<int>>
CreateGaussianPriorMulticlassClassifier<TInstanceSource, TInstance, TLabelSource>(
IBayesPointMachineClassifierMapping<TInstanceSource, TInstance, TLabelSource, int> mapping)
{
@ -554,7 +554,7 @@ namespace Microsoft.ML.Probabilistic.Learners
/// <typeparam name="TLabel">The type of a label.</typeparam>
/// <param name="mapping">The mapping used for accessing data in the standard format.</param>
/// <returns>The binary Bayes point machine classifier instance.</returns>
internal static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, IDictionary<TLabel, double>, GaussianBayesPointMachineClassifierTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<TLabel>>
public static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, IDictionary<TLabel, double>, GaussianBayesPointMachineClassifierTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<TLabel>>
CreateGaussianPriorBinaryClassifier<TInstanceSource, TInstance, TLabelSource, TLabel>(
IClassifierMapping<TInstanceSource, TInstance, TLabelSource, TLabel, Vector> mapping)
{
@ -576,7 +576,7 @@ namespace Microsoft.ML.Probabilistic.Learners
/// <typeparam name="TLabel">The type of a label.</typeparam>
/// <param name="mapping">The mapping used for accessing data in the standard format.</param>
/// <returns>The multi-class Bayes point machine classifier instance.</returns>
internal static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, IDictionary<TLabel, double>, GaussianBayesPointMachineClassifierTrainingSettings, MulticlassBayesPointMachineClassifierPredictionSettings<TLabel>>
public static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, IDictionary<TLabel, double>, GaussianBayesPointMachineClassifierTrainingSettings, MulticlassBayesPointMachineClassifierPredictionSettings<TLabel>>
CreateGaussianPriorMulticlassClassifier<TInstanceSource, TInstance, TLabelSource, TLabel>(
IClassifierMapping<TInstanceSource, TInstance, TLabelSource, TLabel, Vector> mapping)
{
@ -603,7 +603,7 @@ namespace Microsoft.ML.Probabilistic.Learners
/// <typeparam name="TLabelDistribution">The type of a distribution over labels.</typeparam>
/// <param name="fileName">The file name.</param>
/// <returns>The deserialized binary Bayes point machine classifier object.</returns>
internal static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, GaussianBayesPointMachineClassifierTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<TLabel>>
public static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, GaussianBayesPointMachineClassifierTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<TLabel>>
LoadGaussianPriorBinaryClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution>(string fileName)
{
return Utilities.Load<IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, GaussianBayesPointMachineClassifierTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<TLabel>>>(fileName);
@ -621,7 +621,7 @@ namespace Microsoft.ML.Probabilistic.Learners
/// <param name="stream">The stream.</param>
/// <param name="formatter">The formatter.</param>
/// <returns>The deserialized binary Bayes point machine classifier object.</returns>
internal static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, GaussianBayesPointMachineClassifierTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<TLabel>>
public static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, GaussianBayesPointMachineClassifierTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<TLabel>>
LoadGaussianPriorBinaryClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution>(Stream stream, IFormatter formatter)
{
return Utilities.Load<IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, GaussianBayesPointMachineClassifierTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<TLabel>>>(stream, formatter);
@ -638,7 +638,7 @@ namespace Microsoft.ML.Probabilistic.Learners
/// <typeparam name="TLabelDistribution">The type of a distribution over labels.</typeparam>
/// <param name="fileName">The file name.</param>
/// <returns>The deserialized multi-class Bayes point machine classifier object.</returns>
internal static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, GaussianBayesPointMachineClassifierTrainingSettings, MulticlassBayesPointMachineClassifierPredictionSettings<TLabel>>
public static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, GaussianBayesPointMachineClassifierTrainingSettings, MulticlassBayesPointMachineClassifierPredictionSettings<TLabel>>
LoadBackwardCompatibleGaussianPriorMulticlassClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution>(string fileName)
{
return Utilities.Load<IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, GaussianBayesPointMachineClassifierTrainingSettings, MulticlassBayesPointMachineClassifierPredictionSettings<TLabel>>>(fileName);
@ -656,7 +656,7 @@ namespace Microsoft.ML.Probabilistic.Learners
/// <param name="stream">The stream.</param>
/// <param name="formatter">The formatter.</param>
/// <returns>The deserialized multi-class Bayes point machine classifier object.</returns>
internal static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, GaussianBayesPointMachineClassifierTrainingSettings, MulticlassBayesPointMachineClassifierPredictionSettings<TLabel>>
public static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, GaussianBayesPointMachineClassifierTrainingSettings, MulticlassBayesPointMachineClassifierPredictionSettings<TLabel>>
LoadBackwardCompatibleGaussianPriorMulticlassClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution>(Stream stream, IFormatter formatter)
{
return Utilities.Load<IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, TLabelDistribution, GaussianBayesPointMachineClassifierTrainingSettings, MulticlassBayesPointMachineClassifierPredictionSettings<TLabel>>>(stream, formatter);
@ -676,7 +676,7 @@ namespace Microsoft.ML.Probabilistic.Learners
/// <param name="reader">The reader to a stream of a serialized binary Bayes point machine classifier.</param>
/// <param name="mapping">The mapping used for accessing data in the native format.</param>
/// <returns>The binary Bayes point machine classifier instance.</returns>
internal static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, bool, Bernoulli, GaussianBayesPointMachineClassifierTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<bool>>
public static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, bool, Bernoulli, GaussianBayesPointMachineClassifierTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<bool>>
LoadBackwardCompatibleGaussianPriorBinaryClassifier<TInstanceSource, TInstance, TLabelSource>(
IReader reader, IBayesPointMachineClassifierMapping<TInstanceSource, TInstance, TLabelSource, bool> mapping)
{
@ -703,7 +703,7 @@ namespace Microsoft.ML.Probabilistic.Learners
/// <param name="fileName">The name of the file of a serialized binary Bayes point machine classifier.</param>
/// <param name="mapping">The mapping used for accessing data in the native format.</param>
/// <returns>The binary Bayes point machine classifier instance.</returns>
internal static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, bool, Bernoulli, GaussianBayesPointMachineClassifierTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<bool>>
public static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, bool, Bernoulli, GaussianBayesPointMachineClassifierTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<bool>>
LoadBackwardCompatibleGaussianPriorBinaryClassifier<TInstanceSource, TInstance, TLabelSource>(
string fileName, IBayesPointMachineClassifierMapping<TInstanceSource, TInstance, TLabelSource, bool> mapping)
{
@ -728,7 +728,7 @@ namespace Microsoft.ML.Probabilistic.Learners
/// <param name="reader">The reader to a stream of a serialized multi-class Bayes point machine classifier.</param>
/// <param name="mapping">The mapping used for accessing data in the native format.</param>
/// <returns>The multi-class Bayes point machine classifier instance.</returns>
internal static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, int, Discrete, GaussianBayesPointMachineClassifierTrainingSettings, MulticlassBayesPointMachineClassifierPredictionSettings<int>>
public static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, int, Discrete, GaussianBayesPointMachineClassifierTrainingSettings, MulticlassBayesPointMachineClassifierPredictionSettings<int>>
LoadBackwardCompatibleGaussianPriorMulticlassClassifier<TInstanceSource, TInstance, TLabelSource>(
IReader reader, IBayesPointMachineClassifierMapping<TInstanceSource, TInstance, TLabelSource, int> mapping)
{
@ -755,7 +755,7 @@ namespace Microsoft.ML.Probabilistic.Learners
/// <param name="fileName">The name of the file of a serialized multi-class Bayes point machine classifier.</param>
/// <param name="mapping">The mapping used for accessing data in the native format.</param>
/// <returns>The multi-class Bayes point machine classifier instance.</returns>
internal static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, int, Discrete, GaussianBayesPointMachineClassifierTrainingSettings, MulticlassBayesPointMachineClassifierPredictionSettings<int>>
public static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, int, Discrete, GaussianBayesPointMachineClassifierTrainingSettings, MulticlassBayesPointMachineClassifierPredictionSettings<int>>
LoadBackwardCompatibleGaussianPriorMulticlassClassifier<TInstanceSource, TInstance, TLabelSource>(
string fileName, IBayesPointMachineClassifierMapping<TInstanceSource, TInstance, TLabelSource, int> mapping)
{
@ -781,7 +781,7 @@ namespace Microsoft.ML.Probabilistic.Learners
/// <param name="reader">The reader to a stream of a serialized binary Bayes point machine classifier.</param>
/// <param name="mapping">The mapping used for accessing data in the standard format.</param>
/// <returns>The binary Bayes point machine classifier instance.</returns>
internal static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, IDictionary<TLabel, double>, GaussianBayesPointMachineClassifierTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<TLabel>>
public static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, IDictionary<TLabel, double>, GaussianBayesPointMachineClassifierTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<TLabel>>
LoadBackwardCompatibleGaussianPriorBinaryClassifier<TInstanceSource, TInstance, TLabelSource, TLabel>(
IReader reader, IClassifierMapping<TInstanceSource, TInstance, TLabelSource, TLabel, Vector> mapping)
{
@ -809,7 +809,7 @@ namespace Microsoft.ML.Probabilistic.Learners
/// <param name="fileName">The name of the file of a serialized binary Bayes point machine classifier.</param>
/// <param name="mapping">The mapping used for accessing data in the standard format.</param>
/// <returns>The binary Bayes point machine classifier instance.</returns>
internal static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, IDictionary<TLabel, double>, GaussianBayesPointMachineClassifierTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<TLabel>>
public static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, IDictionary<TLabel, double>, GaussianBayesPointMachineClassifierTrainingSettings, BinaryBayesPointMachineClassifierPredictionSettings<TLabel>>
LoadBackwardCompatibleGaussianPriorBinaryClassifier<TInstanceSource, TInstance, TLabelSource, TLabel>(
string fileName, IClassifierMapping<TInstanceSource, TInstance, TLabelSource, TLabel, Vector> mapping)
{
@ -835,7 +835,7 @@ namespace Microsoft.ML.Probabilistic.Learners
/// <param name="reader">The reader to a stream of a serialized multi-class Bayes point machine classifier.</param>
/// <param name="mapping">The mapping used for accessing data in the standard format.</param>
/// <returns>The multi-class Bayes point machine classifier instance.</returns>
internal static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, IDictionary<TLabel, double>, GaussianBayesPointMachineClassifierTrainingSettings, MulticlassBayesPointMachineClassifierPredictionSettings<TLabel>>
public static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, IDictionary<TLabel, double>, GaussianBayesPointMachineClassifierTrainingSettings, MulticlassBayesPointMachineClassifierPredictionSettings<TLabel>>
LoadBackwardCompatibleGaussianPriorMulticlassClassifier<TInstanceSource, TInstance, TLabelSource, TLabel>(
IReader reader, IClassifierMapping<TInstanceSource, TInstance, TLabelSource, TLabel, Vector> mapping)
{
@ -863,7 +863,7 @@ namespace Microsoft.ML.Probabilistic.Learners
/// <param name="fileName">The name of the file of a serialized multi-class Bayes point machine classifier.</param>
/// <param name="mapping">The mapping used for accessing data in the standard format.</param>
/// <returns>The multi-class Bayes point machine classifier instance.</returns>
internal static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, IDictionary<TLabel, double>, GaussianBayesPointMachineClassifierTrainingSettings, MulticlassBayesPointMachineClassifierPredictionSettings<TLabel>>
public static IBayesPointMachineClassifier<TInstanceSource, TInstance, TLabelSource, TLabel, IDictionary<TLabel, double>, GaussianBayesPointMachineClassifierTrainingSettings, MulticlassBayesPointMachineClassifierPredictionSettings<TLabel>>
LoadBackwardCompatibleGaussianPriorMulticlassClassifier<TInstanceSource, TInstance, TLabelSource, TLabel>(
string fileName, IClassifierMapping<TInstanceSource, TInstance, TLabelSource, TLabel, Vector> mapping)
{

Просмотреть файл

@ -17,7 +17,7 @@ namespace Microsoft.ML.Probabilistic.Learners.BayesPointMachineClassifierInterna
/// These settings cannot be modified after training.
/// </remarks>
[Serializable]
internal class GaussianBayesPointMachineClassifierAdvancedTrainingSettings : ICustomSerializable
public class GaussianBayesPointMachineClassifierAdvancedTrainingSettings : ICustomSerializable
{
/// <summary>
/// The current custom binary serialization version of the <see cref="GaussianBayesPointMachineClassifierAdvancedTrainingSettings"/> class.

Просмотреть файл

@ -14,7 +14,7 @@ namespace Microsoft.ML.Probabilistic.Learners.BayesPointMachineClassifierInterna
/// Settings which affect training of the Bayes point machine classifier with <see cref="Gaussian"/> prior distributions over weights.
/// </summary>
[Serializable]
internal class GaussianBayesPointMachineClassifierTrainingSettings : BayesPointMachineClassifierTrainingSettings
public class GaussianBayesPointMachineClassifierTrainingSettings : BayesPointMachineClassifierTrainingSettings
{
/// <summary>
/// The current serialization version of <see cref="GaussianBayesPointMachineClassifierTrainingSettings"/>.

Просмотреть файл

@ -18805,7 +18805,7 @@
</remarks>
</message_doc>
</message_op_class>
<message_op_class name="DoubleIsBetweenOp">
<message_op_class name="IsBetweenGaussianOp">
<doc>
<summary>Provides outgoing messages for <see cref="Factor.IsBetween(double, double, double)" />, given random arguments to the function.</summary>
</doc>
@ -19212,7 +19212,7 @@
</remarks>
</message_doc>
</message_op_class>
<message_op_class name="TruncatedGaussianIsBetweenOp">
<message_op_class name="IsBetweenTruncatedGaussianOp">
<doc>
<summary>Provides outgoing messages for <see cref="Factor.IsBetween(double, double, double)" />, given random arguments to the function.</summary>
</doc>

Просмотреть файл

@ -27,7 +27,7 @@
<Grid.ColumnDefinitions>
<ColumnDefinition Width="3*"/>
<ColumnDefinition Width="Auto"/>
<ColumnDefinition Name="AttributesColumn" Width="0"/>
<ColumnDefinition Name="AttributesColumn" Width="250"/>
</Grid.ColumnDefinitions>
<Grid Margin="5" Grid.ColumnSpan="3">
<Grid.ColumnDefinitions>

Просмотреть файл

@ -109,8 +109,10 @@ namespace Microsoft.ML.Probabilistic.Tests
using (Variable.IfNot(playersFirstGame))
{
var prevGame = PrevGameIndices[game][gamePlayer];//, 0).Named("prevGame");
var prevGamePlayerIndex = PrevGamePlayerIndices[game][gamePlayer];//, 0).Named("prevGamePlayerIndex");
////var prevGame = Variable.Min(PrevGameIndices[game][gamePlayer], 0).Named("prevGame");
////var prevGamePlayerIndex = Variable.Min(PrevGamePlayerIndices[game][gamePlayer], 0).Named("prevGamePlayerIndex");
var prevGame = PrevGameIndices[game][gamePlayer];
var prevGamePlayerIndex = PrevGamePlayerIndices[game][gamePlayer];
Skills[game][gamePlayer] = Variable.GaussianFromMeanAndVariance(
Skills[prevGame][prevGamePlayerIndex],
dynamicsVariance

Просмотреть файл

@ -4337,6 +4337,11 @@ namespace Microsoft.ML.Probabilistic.Tests
Variable.ConstrainEqualRandom(b[indexObs], likelihood2[index]);
}
InferenceEngine engine = new InferenceEngine();
engine.Compiler.Compiled += (sender, e) =>
{
// check for the inefficient replication warning
Assert.Equal(1, e.Warnings.Count);
};
for (int trial = 0; trial < 2; trial++)
{
indexObs.ObservedValue = trial;
@ -7173,7 +7178,7 @@ namespace Microsoft.ML.Probabilistic.Tests
engine.Compiler.Compiled += (sender, e) =>
{
// check for the inefficient replication warning
Assert.True(e.Warnings.Count == 1);
Assert.Equal(1, e.Warnings.Count);
};
DistributionArray<Bernoulli> bDist = engine.Infer<DistributionArray<Bernoulli>>(b);
for (int i = 0; i < bDist.Count; i++)
@ -7215,7 +7220,7 @@ namespace Microsoft.ML.Probabilistic.Tests
engine.Compiler.Compiled += (sender, e) =>
{
// check for the inefficient replication warning
Assert.True(e.Warnings.Count == 1);
Assert.Equal(1, e.Warnings.Count);
};
DistributionArray<Bernoulli> bDist = engine.Infer<DistributionArray<Bernoulli>>(b);
for (int i = 0; i < bDist.Count; i++)
@ -7231,6 +7236,47 @@ namespace Microsoft.ML.Probabilistic.Tests
}
}
[Fact]
public void EnterArrayElementsTest3()
{
Range item = new Range(2).Named("item");
Range xitem = new Range(2).Named("xitem");
VariableArray<bool> x = Variable.Array<bool>(xitem).Named("x");
double xPrior = 0.3;
x[xitem] = Variable.Bernoulli(xPrior).ForEach(xitem);
VariableArray<int> indices = Variable.Array<int>(item).Named("indices");
indices.ObservedValue = new int[] { 0, 1 };
VariableArray<int> indices2 = Variable.Array<int>(item).Named("indices2");
indices2.ObservedValue = new int[] { 1, 0 };
VariableArray<bool> b = Variable.Array<bool>(item).Named("b");
double bPrior = 0.1;
double xLike = 0.4;
using (Variable.ForEach(item))
{
b[item] = Variable.Bernoulli(bPrior);
using (Variable.If(b[item]))
{
var index = Variable.Max(0, indices[item]);
Variable.ConstrainEqualRandom(x[index], new Bernoulli(xLike));
}
}
InferenceEngine engine = new InferenceEngine();
engine.ShowProgress = false;
engine.Compiler.Compiled += (sender, e) =>
{
// check for the inefficient replication warning
Assert.Equal(1, e.Warnings.Count);
};
DistributionArray<Bernoulli> bDist = engine.Infer<DistributionArray<Bernoulli>>(b);
for (int i = 0; i < bDist.Count; i++)
{
double sumCondT = xPrior * xLike + (1 - xPrior) * (1 - xLike);
double bPost = bPrior * sumCondT / (bPrior * sumCondT + (1 - bPrior));
Assert.True(MMath.AbsDiff(bDist[i].GetProbTrue(), bPost, 1e-10) < 1e-10);
}
}
internal void FairCoinTest()
{
VariableArray<bool> tosses = Variable.Constant(new bool[] { true, true, true, true, true });

Просмотреть файл

@ -1333,7 +1333,7 @@ namespace Microsoft.ML.Probabilistic.Tests
engine.Compiler.Compiled += (sender, e) =>
{
// check for the inefficient replication warning
Assert.True(e.Warnings.Count == 0);
Assert.Equal(0, e.Warnings.Count);
};
engine.Infer(userThresholds);
}
@ -1359,7 +1359,7 @@ namespace Microsoft.ML.Probabilistic.Tests
engine.Compiler.Compiled += (sender, e) =>
{
// check for the inefficient replication warning
Assert.True(e.Warnings.Count == 1);
Assert.Equal(1, e.Warnings.Count);
};
var boolsActual = engine.Infer<IList<Bernoulli>>(bools);
var boolsExpected = new BernoulliArray(C.SizeAsInt, i =>
@ -1402,7 +1402,7 @@ namespace Microsoft.ML.Probabilistic.Tests
engine.Compiler.Compiled += (sender, e) =>
{
// check for the inefficient replication warning
Assert.True(e.Warnings.Count == 1);
Assert.Equal(1, e.Warnings.Count);
};
var boolsActual = engine.Infer<IList<Bernoulli>>(bools);
var boolsExpected = new BernoulliArray(C.SizeAsInt, i =>
@ -1438,7 +1438,7 @@ namespace Microsoft.ML.Probabilistic.Tests
engine.Compiler.Compiled += (sender, e) =>
{
// check for the inefficient replication warning
Assert.True(e.Warnings.Count == 1);
Assert.Equal(1, e.Warnings.Count);
};
var boolsActual = engine.Infer<IList<Bernoulli>>(bools);
var boolsExpected = new BernoulliArray(C.SizeAsInt, i =>
@ -1477,7 +1477,7 @@ namespace Microsoft.ML.Probabilistic.Tests
engine.Compiler.Compiled += (sender, e) =>
{
// check for the inefficient replication warning
//Assert.True(e.Warnings.Count == 1);
Assert.Equal(1, e.Warnings.Count);
};
var boolsActual = engine.Infer<IList<Bernoulli>>(bools);
var boolsExpected = new BernoulliArray(C.SizeAsInt, i =>
@ -1729,7 +1729,7 @@ namespace Microsoft.ML.Probabilistic.Tests
engine.Compiler.Compiled += (sender, e) =>
{
// check for the inefficient replication warning
Assert.True(e.Warnings.Count == 1);
Assert.Equal(1, e.Warnings.Count);
};
IDistribution<bool[]> post = engine.Infer<IDistribution<bool[]>>(bools);
}
@ -1866,7 +1866,7 @@ namespace Microsoft.ML.Probabilistic.Tests
engine.Compiler.Compiled += (sender, e) =>
{
// check for the inefficient replication warning
Assert.True(e.Warnings.Count == 1);
Assert.Equal(1, e.Warnings.Count);
};
IDistribution<bool[][]> weightsActual = engine.Infer<IDistribution<bool[][]>>(bools);
}
@ -2433,6 +2433,11 @@ namespace Microsoft.ML.Probabilistic.Tests
if (!(algorithm is GibbsSampling)) block.CloseBlock();
InferenceEngine engine = new InferenceEngine(algorithm);
engine.Compiler.Compiled += (sender, e) =>
{
// check for the inefficient replication warning
Assert.Equal(0, e.Warnings.Count);
};
indices.ObservedValue = new int[] {0};
indicesLength.ObservedValue = indices.ObservedValue.Length;
Bernoulli[][] arrayExpectedArray = new Bernoulli[item.SizeAsInt][];
@ -2524,12 +2529,17 @@ namespace Microsoft.ML.Probabilistic.Tests
double xLike = 0.6;
using (Variable.Switch(switchVar))
{
Variable<bool> x = array[indices[switchVar][xitem]].Named("x");
Variable<bool> x = array[indices[switchVar][xitem]];
Variable.ConstrainEqualRandom(x, new Bernoulli(xLike));
}
if (!(algorithm is GibbsSampling)) block.CloseBlock();
InferenceEngine engine = new InferenceEngine(algorithm);
engine.Compiler.Compiled += (sender, e) =>
{
// check for the inefficient replication warning
Assert.Equal(0, e.Warnings.Count);
};
indices.ObservedValue = new int[][] {new int[] {0}, new int[] {3}};
indicesLength.ObservedValue = new int[] {indices.ObservedValue[0].Length, indices.ObservedValue[1].Length};
Bernoulli[] arrayExpectedArray = new Bernoulli[item.SizeAsInt];

Просмотреть файл

@ -407,7 +407,6 @@ namespace Microsoft.ML.Probabilistic.Tests
InferenceEngine engine = new InferenceEngine();
//engine.Algorithm = new VariationalMessagePassing();
engine.Compiler.CatchExceptions = true;
//engine.Compiler.UnrollLoops = true;
//engine.Compiler.UseSerialSchedules = false;
//engine.ResetOnObservedValueChanged = false;