Fixed DependencyAnalysisTransform and SchedulingTransform (#405)

This commit is contained in:
Tom Minka 2022-05-29 18:40:37 +01:00 коммит произвёл GitHub
Родитель 55d26a7138
Коммит 5156776f4c
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
14 изменённых файлов: 135 добавлений и 116 удалений

Просмотреть файл

@ -411,8 +411,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Attributes
IList<IList<IExpression>> indices, IList<IList<IExpression>> wildcardVars = null)
{
IExpression original = prototypeExpression;
int replaceCount = 0;
prototypeExpression = ReplaceIndexVars(context, prototypeExpression, indices, wildcardVars, ref replaceCount);
prototypeExpression = ReplaceIndexVars(context, prototypeExpression, indices, wildcardVars, out int replaceCount);
int mpDepth = Util.GetArrayDepth(varType, Distribution.GetDomainType(prototypeExpression.GetExpressionType()));
int indexingDepth = indices.Count;
int wildcardBracket = 0;
@ -504,8 +503,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Attributes
IExpression index = indices[i][j];
if (Recognizer.IsStaticMethod(index, new Func<int>(GateAnalysisTransform.AnyIndex)))
{
int replaceCount = 0;
sizeBracket.Add(ReplaceIndexVars(context, sizes[i][j], indices, wildcardVars, ref replaceCount));
sizeBracket.Add(ReplaceIndexVars(context, sizes[i][j], indices, wildcardVars, out int replaceCount));
IVariableDeclaration v = indexVars[i][j];
if (wildcardVars != null) v = Recognizer.GetVariableDeclaration(wildcardVars[newIndexVars.Count][indexVarsBracket.Count]);
else if (Recognizer.GetLoopForVariable(context, v) != null)
@ -556,8 +554,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Attributes
IList<IExpression> replacementBracket = Builder.ExprCollection();
for (int j = 0; j < sizeBracket.Length; j++)
{
int replaceCount = 0;
sizeBracket[j] = ReplaceIndexVars(context, sizes[i][j], replacements, wildcardVars, ref replaceCount);
sizeBracket[j] = ReplaceIndexVars(context, sizes[i][j], replacements, wildcardVars, out int replaceCount);
if (replaceCount > 0) indexVarBracket[j] = GenerateLoopVar(context, "_a");
else if (indexVars.Count > i) indexVarBracket[j] = indexVars[i][j];
if (indexVarBracket[j] != null) replacementBracket.Add(Builder.VarRefExpr(indexVarBracket[j]));
@ -598,9 +595,17 @@ namespace Microsoft.ML.Probabilistic.Compiler.Attributes
// substitute indices in the marginal prototype expression
vi.marginalPrototypeExpression = GetMarginalPrototypeExpression(context, marginalPrototypeExpression, replacements, wildcardVars);
}
InitialiseTo it = context.InputAttributes.Get<InitialiseTo>(declaration);
if (it != null && copyInitializer)
if (copyInitializer) CopyInitialiser();
ChannelTransform.setAllGroupRoots(context, arrayvd, false);
return arrayvd;
void CopyInitialiser()
{
InitialiseTo it = context.InputAttributes.Get<InitialiseTo>(declaration);
if (it == null)
{
return;
}
// if original array is indexed [i,j][k,l][m,n] and indices = [*,*][3,*] then
// initExpr2 = new PlaceHolder[wildcard0,wildcard1] { new PlaceHolder[wildcard2] { new PlaceHolder[newIndexVar] { initExpr[wildcard0,wildcard1][3,wildcard2] } } }
IExpression initExpr = it.initialMessagesExpression;
@ -643,8 +648,6 @@ namespace Microsoft.ML.Probabilistic.Compiler.Attributes
}
context.OutputAttributes.Set(arrayvd, new InitialiseTo(initExpr));
}
ChannelTransform.setAllGroupRoots(context, arrayvd, false);
return arrayvd;
}
internal static IExpression MakePlaceHolderArrayCreate(IExpression expr, IList<IVariableDeclaration[]> indexVars)
@ -693,11 +696,12 @@ namespace Microsoft.ML.Probabilistic.Compiler.Attributes
/// <param name="expr">Any expression</param>
/// <param name="indices">A list of lists of index expressions (one list for each indexing bracket).</param>
/// <param name="wildcardIndices">Expressions used to replace wildcards. May be null if there are no wildcards.</param>
/// <param name="replaceCount">Incremented for each replacement.</param>
/// <param name="replaceCount">The number of replacements.</param>
/// <returns>A new expression.</returns>
internal IExpression ReplaceIndexVars(BasicTransformContext context, IExpression expr, IList<IList<IExpression>> indices,
IList<IList<IExpression>> wildcardIndices, ref int replaceCount)
IList<IList<IExpression>> wildcardIndices, out int replaceCount)
{
replaceCount = 0;
Dictionary<IVariableDeclaration, IExpression> replacedIndexVars = new Dictionary<IVariableDeclaration, IExpression>();
int wildcardBracket = 0;
for (int depth = 0; depth < indices.Count; depth++)

Просмотреть файл

@ -946,7 +946,7 @@ namespace Microsoft.ML.Probabilistic.Compiler
var lct2 = new LoopCuttingTransform(true);
tc.AddTransform(lct2);
tc.AddTransform(lct2); // run again to catch uses before declaration
if(OptimiseInferenceCode)
if (OptimiseInferenceCode)
{
// must run after HoistingTransform
tc.AddTransform(new LoopRemovalTransform());

Просмотреть файл

@ -836,9 +836,10 @@ namespace Microsoft.ML.Probabilistic.Models
for (int i = 0; i < indexVars.Length; i++)
{
IModelExpression expr = parent.indices[i];
if (!(expr is Range))
if (expr is Range range)
indexVars[i] = range.GetIndexDeclaration();
else
throw new Exception(parent + ".InitializeTo is not allowed since the indices are not ranges");
indexVars[i] = ((Range)expr).GetIndexDeclaration();
}
initExpr = VariableInformation.MakePlaceHolderArrayCreate(initExpr, indexVars);
parent = (Variable)parent.ArrayVariable;

Просмотреть файл

@ -321,7 +321,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
{
foreach (ConditionBinding ci in conditionContext)
{
// each lhs has already been replaced, so we only need to compare for equality
// each subexpression has already been replaced, so we only need to compare for equality here
if (expr.Equals(ci.lhs)) return ci.rhs;
}
}

Просмотреть файл

@ -810,10 +810,10 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
if (parameter.Name == "length")
{
IExpression arg = ioce.Arguments[argIndex];
if (arg is ILiteralExpression)
if (arg is ILiteralExpression ile)
{
object argValue = ((ILiteralExpression)arg).Value;
if (argValue is int && (int)argValue == 0)
object argValue = ile.Value;
if (argValue is int i && i == 0)
dependencyInformation.IsUniform = true;
}
break;
@ -1063,18 +1063,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
// arguments in all dependencies with the correct model expressions.
IExpression MapDependency(IStatement ist)
{
return ReplaceArgs(((IExpressionStatement)ist).Expression);
// Replace the parameter expressions in an expression with the corresponding model expressions
IExpression ReplaceArgs(IExpression iExpression)
{
foreach (KeyValuePair<IExpression, IExpression> kvp in parameterToExpressionMap)
{
int repCount = 0;
iExpression = Builder.ReplaceExpression(iExpression, kvp.Key, kvp.Value, ref repCount);
}
return iExpression;
}
return Builder.ReplaceSubexpressions(((IExpressionStatement)ist).Expression, parameterToExpressionMap);
}
}

Просмотреть файл

@ -148,7 +148,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
}
private static void AddCopyStatements(ICollection<IStatement> stmts, VariableInformation varInfo, int indexingDepth, IExpression lhs, IExpression rhs,
int bracket = 0, Dictionary<IExpression,IExpression> replacements = null)
int bracket = 0, Dictionary<IExpression, IExpression> replacements = null)
{
if (indexingDepth == bracket)
{
@ -164,7 +164,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
int[] sizes = Util.ArrayInit(varInfo.sizes[bracket].Length, i => (int)((ILiteralExpression)varInfo.sizes[bracket][i]).Value);
ForEachLiteralIndex(sizes, index =>
{
IExpression[] bracketIndices = Util.ArrayInit(index.Length, i => Builder.LiteralExpr(index[i]));
ILiteralExpression[] bracketIndices = Util.ArrayInit(index.Length, i => Builder.LiteralExpr(index[i]));
var newLhs = Builder.ArrayIndex(lhs, bracketIndices);
var newRhs = Builder.ArrayIndex(rhs, bracketIndices);
for (int dim = 0; dim < index.Length; dim++)
@ -180,7 +180,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
IReadOnlyList<IExpression> replacedSizes = varInfo.sizes[bracket];
if(replacements != null)
{
replacedSizes = Util.ArrayInit(replacedSizes.Count, i => Replace(replacedSizes[i], replacements));
replacedSizes = Util.ArrayInit(replacedSizes.Count, i => Builder.ReplaceSubexpressions(replacedSizes[i], replacements));
}
IForStatement innerForStatement;
var fs = Builder.NestedForStmt(varInfo.indexVars[bracket], replacedSizes, out innerForStatement);
@ -194,15 +194,6 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
}
}
private static IExpression Replace(IExpression expr, IReadOnlyDictionary<IExpression,IExpression> replacements)
{
foreach(var entry in replacements)
{
expr = Builder.ReplaceExpression(expr, entry.Key, entry.Value);
}
return expr;
}
private static void ForEachLiteralIndex(int[] sizes, Action<int[]> action)
{
int[] strides = StringUtil.ArrayStrides(sizes);

Просмотреть файл

@ -166,10 +166,9 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
}
foreach (var stmt in outputBlock)
{
if (stmt is IWhileStatement)
// recursively collect the child block
if (stmt is IWhileStatement iws)
{
// recursively collect the child block
IWhileStatement iws = (IWhileStatement)stmt;
CollectTransformedStmts(iws.Body.Statements, replacementsInContext);
// merge the child replacements into this block's replacements
var childReplacements = replacementsInContext[iws.Body.Statements];
@ -208,9 +207,8 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
var replacements = replacementsInContext[outputBlock];
foreach (var stmt in outputBlock)
{
if (stmt is IWhileStatement)
if (stmt is IWhileStatement iws)
{
IWhileStatement iws = (IWhileStatement)stmt;
// merge this block's replacements into the child's replacements
var childReplacements = replacementsInContext[iws.Body.Statements];
foreach (var entry in replacements)
@ -314,7 +312,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
OffsetInfo newOffsetInfo = new OffsetInfo();
foreach (var offset in offsetInfo)
{
if (CanKeepOffsetDependency(ssinfo, reversedLoopVars, reversedLoopVarsOther, offset))
if (CanKeepOffsetDependency(reversedLoopVarsOther, offset))
newOffsetInfo.Add(offset);
else
changed = true;
@ -334,19 +332,19 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
{
di.offsetIndexOf.Add(pair.Key, pair.Value);
}
}
private bool CanKeepOffsetDependency(SerialSchedulingInfo ssinfo, Set<IVariableDeclaration> reversedLoopVars, Set<IVariableDeclaration> reversedLoopVarsOther, Offset offset)
{
foreach (var loopVar in ssinfo.loopInfos.Select(info => info.loopVar))
bool CanKeepOffsetDependency(Set<IVariableDeclaration> reversedLoopVarsOther, Offset offset)
{
bool compatible = (reversedLoopVars.Contains(loopVar) == reversedLoopVarsOther.Contains(loopVar));
if (!compatible)
return false;
if (offset.loopVar == loopVar)
break;
foreach (var loopVar in ssinfo.loopInfos.Select(info => info.loopVar))
{
bool compatible = (reversedLoopVars.Contains(loopVar) == reversedLoopVarsOther.Contains(loopVar));
if (!compatible)
return false;
if (offset.loopVar == loopVar)
break;
}
return true;
}
return true;
}
/// <summary>

Просмотреть файл

@ -574,17 +574,16 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
}
/// <summary>
/// Apply all bindings to expr, in order.
/// Apply all bindings to expr.
/// </summary>
/// <param name="bindings"></param>
/// <param name="bindings">Must not chain, i.e. no lhs appears in any rhs.</param>
/// <param name="expr"></param>
/// <returns></returns>
internal static IExpression ReplaceExpression(IEnumerable<ConditionBinding> bindings, IExpression expr)
{
foreach (ConditionBinding binding in bindings)
{
int replaceCount = 0;
expr = Builder.ReplaceExpression(expr, binding.lhs, binding.rhs, ref replaceCount);
expr = Builder.ReplaceExpression(expr, binding.lhs, binding.rhs);
}
return expr;
}

Просмотреть файл

@ -516,6 +516,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
out replaced);
if (isDef && !replaced)
{
// This call is valid because bindings is formed from condition expressions and these have been fully replaced by ConstantFoldingTransform.
IExpression boundDef =
GateAnalysisTransform.ReplaceExpression(bindings, definedExpression.Expression);
Error($"{expr} doesn\'t match bound GateBlock def: {boundDef}");
@ -586,8 +587,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
}));
RecordReplacement(expr, toClone, !isSubset);
int replaceCount = 0;
expr = Builder.ReplaceExpression(expr, toReplace, clone, ref replaceCount);
expr = Builder.ReplaceExpression(expr, toReplace, clone);
replaced = true;
}

Просмотреть файл

@ -565,8 +565,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
int depth = Recognizer.GetIndexingDepth(index);
IExpression resultSize = indexInfo.sizes[depth][0];
var indices = Recognizer.GetIndices(index);
int replaceCount = 0;
resultSize = indexInfo.ReplaceIndexVars(context, resultSize, indices, null, ref replaceCount);
resultSize = indexInfo.ReplaceIndexVars(context, resultSize, indices, null, out int replaceCount);
indexInfo.DefineIndexVarsUpToDepth(context, depth + 1);
IVariableDeclaration resultIndex = indexInfo.indexVars[depth][0];
Type arrayType = arrayExpr.GetExpressionType();

Просмотреть файл

@ -1864,8 +1864,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
List<IList<IExpression>> indices = Recognizer.GetIndices(outputLhs);
if (indices.Count > 0)
{
int replaceCount = 0;
mpe = channelVarInfo.ReplaceIndexVars(context, mpe, indices, null, ref replaceCount);
mpe = channelVarInfo.ReplaceIndexVars(context, mpe, indices, null, out int replaceCount);
}
if (mai.isDistribution)
{
@ -2122,13 +2121,16 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
{
if (iace.Type.DotNetType.Equals(typeof(PlaceHolder)) && iace.Initializer != null && iace.Initializer.Expressions.Count == 1)
{
// VariableInformation.DeriveArrayVariable has created an initializer expression of the form:
// new PlaceHolder[wildcard0,wildcard1] { initExpr[wildcard0,wildcard1] }
IExpression initExpr = iace.Initializer.Expressions[0];
// replace index variables with the given indices
Dictionary<IExpression, IExpression> replacements = new Dictionary<IExpression, IExpression>();
for (int dim = 0; dim < iace.Dimensions.Count; dim++)
{
initExpr = Builder.ReplaceExpression(initExpr, iace.Dimensions[dim], iaie.Indices[dim]);
replacements.Add(iace.Dimensions[dim], iaie.Indices[dim]);
}
return initExpr;
return Builder.ReplaceSubexpressions(initExpr, replacements);
}
}
return Builder.ArrayIndex(target, iaie.Indices);
@ -2275,8 +2277,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
IVariableDeclaration indexVar = indexVars[i];
IParameterDeclaration param = Builder.Param(indexVar.Name, typeof(int));
iame.Parameters.Add(param);
int replaceCount = 0;
elementInit = Builder.ReplaceExpression(elementInit, Builder.VarRefExpr(indexVar), Builder.ParamRef(param), ref replaceCount);
elementInit = Builder.ReplaceExpression(elementInit, Builder.VarRefExpr(indexVar), Builder.ParamRef(param));
}
iame.Body.Statements.Add(Builder.Return(elementInit));
return iame;

Просмотреть файл

@ -3848,8 +3848,6 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
private void LabelEdgesWithOffsets(NodeIndex node)
{
var obc = descendantOffset[node];
var keys = this.loopVarsOfNode[node];
foreach (EdgeIndex edge in g.EdgesInto(node))
{
if (direction[edge] == Direction.Unknown && !deletedEdges.Contains(edge))

Просмотреть файл

@ -238,10 +238,14 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
// this will fill in groupOf and loopInfoOfGroup
BuildGroups(inputStmts, -1);
g = new DependencyGraph(context, flatStmts, ignoreMissingNodes: true, ignoreRequirements: false, deleteCancels: true);
g.getTargetIndex = delegate (NodeIndex node)
bool replaceTargetIndex = false;
if (replaceTargetIndex)
{
return new DependencyGraph.TargetIndex(loopMergingInfo.GetIndexOf(flatStmts[node]));
};
g.getTargetIndex = delegate (NodeIndex node)
{
return new DependencyGraph.TargetIndex(loopMergingInfo.GetIndexOf(flatStmts[node]));
};
}
if (compiler.UseSerialSchedules && !compiler.UseExperimentalSerialSchedules)
{
bool anyDeleted;

Просмотреть файл

@ -1589,7 +1589,7 @@ namespace Microsoft.ML.Probabilistic.Compiler
}
/// <summary>
/// Finds and replaces one expression with another expression in a given expression
/// Finds and replaces one expression with another expression, everywhere it occurs
/// </summary>
/// <param name="expr">The expression</param>
/// <param name="exprFind">The expression to be found</param>
@ -1598,58 +1598,97 @@ namespace Microsoft.ML.Probabilistic.Compiler
/// <returns>The resulting expression</returns>
public IExpression ReplaceExpression(IExpression expr, IExpression exprFind, IExpression exprReplace, ref int replaceCount)
{
if (expr == null) return expr;
else if (expr.Equals(exprFind))
int localReplaceCount = 0;
IExpression result = ReplaceSubexpressions(expr, e =>
{
replaceCount++;
return exprReplace;
if (e.Equals(exprFind))
{
localReplaceCount++;
return exprReplace;
}
else
{
return null;
}
});
replaceCount = localReplaceCount;
return result;
}
/// <summary>
/// Replaces all subexpressions of an expression
/// </summary>
/// <param name="expr">The expression</param>
/// <param name="replacements">Subexpressions to find and replace</param>
/// <returns>The replaced expression</returns>
public IExpression ReplaceSubexpressions(IExpression expr, IReadOnlyDictionary<IExpression, IExpression> replacements)
{
return ReplaceSubexpressions(expr, e =>
{
replacements.TryGetValue(e, out IExpression replacement);
return replacement;
});
}
/// <summary>
/// Replaces all subexpressions of an expression
/// </summary>
/// <param name="expr">The expression</param>
/// <param name="replace">Returns a new expression or null for no replacement</param>
/// <returns>The replaced expression</returns>
public IExpression ReplaceSubexpressions(IExpression expr, Func<IExpression, IExpression> replace)
{
if (expr == null) return expr;
var replaced = replace(expr);
if (replaced != null)
{
return replaced;
}
else if ((expr is IVariableDeclarationExpression) ||
(expr is IVariableReferenceExpression) ||
(expr is ILiteralExpression) ||
(expr is IDefaultExpression) ||
(expr is IArgumentReferenceExpression) ||
(expr is IThisReferenceExpression)) return expr;
else if (expr is IArrayIndexerExpression iaie)
{
IArrayIndexerExpression aie = ArrayIndxrExpr();
foreach (IExpression ind in iaie.Indices) aie.Indices.Add(ReplaceExpression(ind, exprFind, exprReplace, ref replaceCount));
aie.Target = ReplaceExpression(iaie.Target, exprFind, exprReplace, ref replaceCount);
foreach (IExpression ind in iaie.Indices) aie.Indices.Add(ReplaceSubexpressions(ind, replace));
aie.Target = ReplaceSubexpressions(iaie.Target, replace);
return aie;
}
else if (expr is IPropertyIndexerExpression ipie)
{
IPropertyIndexerExpression pie = PropIndxrExpr();
foreach (IExpression ind in ipie.Indices) pie.Indices.Add(ReplaceExpression(ind, exprFind, exprReplace, ref replaceCount));
pie.Target = (IPropertyReferenceExpression)ReplaceExpression(ipie.Target, exprFind, exprReplace, ref replaceCount);
foreach (IExpression ind in ipie.Indices) pie.Indices.Add(ReplaceSubexpressions(ind, replace));
pie.Target = (IPropertyReferenceExpression)ReplaceSubexpressions(ipie.Target, replace);
return pie;
}
else if (expr is ICastExpression ice)
{
return CastExpr(ReplaceExpression(ice.Expression, exprFind, exprReplace, ref replaceCount), ice.TargetType);
return CastExpr(ReplaceSubexpressions(ice.Expression, replace), ice.TargetType);
}
else if (expr is ICheckedExpression iche)
{
return CheckedExpr(ReplaceExpression(iche.Expression, exprFind, exprReplace, ref replaceCount));
return CheckedExpr(ReplaceSubexpressions(iche.Expression, replace));
}
else if (
(expr is IVariableDeclarationExpression) ||
(expr is IVariableReferenceExpression) ||
(expr is ILiteralExpression) ||
(expr is IDefaultExpression) ||
(expr is IArgumentReferenceExpression)) return expr;
else if (expr is IPropertyReferenceExpression ipre)
{
IExpression target = ReplaceExpression(ipre.Target, exprFind, exprReplace, ref replaceCount);
IExpression target = ReplaceSubexpressions(ipre.Target, replace);
if (target == ipre.Target) return ipre;
IPropertyReferenceExpression pre = PropRefExpr();
pre.Property = ipre.Property;
pre.Target = target;
return pre;
}
else if (expr is IArrayCreateExpression)
else if (expr is IArrayCreateExpression iace)
{
IArrayCreateExpression iace = expr as IArrayCreateExpression;
var ace = ArrayCreateExpr();
ace.Type = iace.Type;
ace.Initializer = ReplaceExpression(iace.Initializer, exprFind, exprReplace, ref replaceCount) as IBlockExpression;
ace.Initializer = ReplaceSubexpressions(iace.Initializer, replace) as IBlockExpression;
foreach (IExpression dim in iace.Dimensions)
{
ace.Dimensions.Add(ReplaceExpression(dim, exprFind, exprReplace, ref replaceCount));
ace.Dimensions.Add(ReplaceSubexpressions(dim, replace));
}
return ace;
}
@ -1658,7 +1697,7 @@ namespace Microsoft.ML.Probabilistic.Compiler
IBlockExpression be = BlockExpr();
foreach (IExpression e in ible.Expressions)
{
be.Expressions.Add(ReplaceExpression(e, exprFind, exprReplace, ref replaceCount));
be.Expressions.Add(ReplaceSubexpressions(e, replace));
}
return be;
}
@ -1668,7 +1707,7 @@ namespace Microsoft.ML.Probabilistic.Compiler
mie.Method = imie.Method;
foreach (IExpression arg in imie.Arguments)
{
mie.Arguments.Add(ReplaceExpression(arg, exprFind, exprReplace, ref replaceCount));
mie.Arguments.Add(ReplaceSubexpressions(arg, replace));
}
return mie;
}
@ -1679,9 +1718,9 @@ namespace Microsoft.ML.Probabilistic.Compiler
oce.Type = ioce.Type;
foreach (IExpression arg in ioce.Arguments)
{
oce.Arguments.Add(ReplaceExpression(arg, exprFind, exprReplace, ref replaceCount));
oce.Arguments.Add(ReplaceSubexpressions(arg, replace));
}
oce.Initializer = (IBlockExpression)ReplaceExpression(ioce.Initializer, exprFind, exprReplace, ref replaceCount);
oce.Initializer = (IBlockExpression)ReplaceSubexpressions(ioce.Initializer, replace);
return oce;
}
else if (expr is IAnonymousMethodExpression iame)
@ -1695,11 +1734,11 @@ namespace Microsoft.ML.Probabilistic.Compiler
IStatement st = ist;
if (ist is IExpressionStatement ies)
{
st = ExprStatement(ReplaceExpression(ies.Expression, exprFind, exprReplace, ref replaceCount));
st = ExprStatement(ReplaceSubexpressions(ies.Expression, replace));
}
else if (ist is IMethodReturnStatement imrs)
{
st = Return(ReplaceExpression(imrs.Expression, exprFind, exprReplace, ref replaceCount));
st = Return(ReplaceSubexpressions(imrs.Expression, replace));
}
ame.Body.Statements.Add(st);
}
@ -1709,33 +1748,29 @@ namespace Microsoft.ML.Probabilistic.Compiler
{
IUnaryExpression ue = UnaryExpr();
ue.Operator = iue.Operator;
ue.Expression = ReplaceExpression(iue.Expression, exprFind, exprReplace, ref replaceCount);
ue.Expression = ReplaceSubexpressions(iue.Expression, replace);
return ue;
}
else if (expr is IBinaryExpression ibe)
{
IBinaryExpression be = BinaryExpr();
be.Operator = ibe.Operator;
be.Left = ReplaceExpression(ibe.Left, exprFind, exprReplace, ref replaceCount);
be.Right = ReplaceExpression(ibe.Right, exprFind, exprReplace, ref replaceCount);
be.Left = ReplaceSubexpressions(ibe.Left, replace);
be.Right = ReplaceSubexpressions(ibe.Right, replace);
return be;
}
else if (expr is IMethodReferenceExpression imre)
{
var target = ReplaceExpression(imre.Target, exprFind, exprReplace, ref replaceCount);
var target = ReplaceSubexpressions(imre.Target, replace);
return MethodRefExpr(imre.Method, target);
}
else if (expr is IThisReferenceExpression)
{
return expr;
}
else if (expr is IAddressOutExpression iaoe)
{
IAddressOutExpression aoe = AddrOutExpr();
aoe.Expression = ReplaceExpression(iaoe.Expression, exprFind, exprReplace, ref replaceCount);
aoe.Expression = ReplaceSubexpressions(iaoe.Expression, replace);
return aoe;
}
else throw new NotImplementedException("Unhandled expression type in ReplaceExpression(): " + expr.GetType());
else throw new NotImplementedException("Unhandled expression type in ReplaceSubexpressions(): " + expr.GetType());
}
/// <summary>