зеркало из https://github.com/dotnet/infer.git
Fixed DependencyAnalysisTransform and SchedulingTransform (#405)
This commit is contained in:
Родитель
55d26a7138
Коммит
5156776f4c
|
@ -411,8 +411,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Attributes
|
|||
IList<IList<IExpression>> indices, IList<IList<IExpression>> wildcardVars = null)
|
||||
{
|
||||
IExpression original = prototypeExpression;
|
||||
int replaceCount = 0;
|
||||
prototypeExpression = ReplaceIndexVars(context, prototypeExpression, indices, wildcardVars, ref replaceCount);
|
||||
prototypeExpression = ReplaceIndexVars(context, prototypeExpression, indices, wildcardVars, out int replaceCount);
|
||||
int mpDepth = Util.GetArrayDepth(varType, Distribution.GetDomainType(prototypeExpression.GetExpressionType()));
|
||||
int indexingDepth = indices.Count;
|
||||
int wildcardBracket = 0;
|
||||
|
@ -504,8 +503,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Attributes
|
|||
IExpression index = indices[i][j];
|
||||
if (Recognizer.IsStaticMethod(index, new Func<int>(GateAnalysisTransform.AnyIndex)))
|
||||
{
|
||||
int replaceCount = 0;
|
||||
sizeBracket.Add(ReplaceIndexVars(context, sizes[i][j], indices, wildcardVars, ref replaceCount));
|
||||
sizeBracket.Add(ReplaceIndexVars(context, sizes[i][j], indices, wildcardVars, out int replaceCount));
|
||||
IVariableDeclaration v = indexVars[i][j];
|
||||
if (wildcardVars != null) v = Recognizer.GetVariableDeclaration(wildcardVars[newIndexVars.Count][indexVarsBracket.Count]);
|
||||
else if (Recognizer.GetLoopForVariable(context, v) != null)
|
||||
|
@ -556,8 +554,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Attributes
|
|||
IList<IExpression> replacementBracket = Builder.ExprCollection();
|
||||
for (int j = 0; j < sizeBracket.Length; j++)
|
||||
{
|
||||
int replaceCount = 0;
|
||||
sizeBracket[j] = ReplaceIndexVars(context, sizes[i][j], replacements, wildcardVars, ref replaceCount);
|
||||
sizeBracket[j] = ReplaceIndexVars(context, sizes[i][j], replacements, wildcardVars, out int replaceCount);
|
||||
if (replaceCount > 0) indexVarBracket[j] = GenerateLoopVar(context, "_a");
|
||||
else if (indexVars.Count > i) indexVarBracket[j] = indexVars[i][j];
|
||||
if (indexVarBracket[j] != null) replacementBracket.Add(Builder.VarRefExpr(indexVarBracket[j]));
|
||||
|
@ -598,9 +595,17 @@ namespace Microsoft.ML.Probabilistic.Compiler.Attributes
|
|||
// substitute indices in the marginal prototype expression
|
||||
vi.marginalPrototypeExpression = GetMarginalPrototypeExpression(context, marginalPrototypeExpression, replacements, wildcardVars);
|
||||
}
|
||||
InitialiseTo it = context.InputAttributes.Get<InitialiseTo>(declaration);
|
||||
if (it != null && copyInitializer)
|
||||
if (copyInitializer) CopyInitialiser();
|
||||
ChannelTransform.setAllGroupRoots(context, arrayvd, false);
|
||||
return arrayvd;
|
||||
|
||||
void CopyInitialiser()
|
||||
{
|
||||
InitialiseTo it = context.InputAttributes.Get<InitialiseTo>(declaration);
|
||||
if (it == null)
|
||||
{
|
||||
return;
|
||||
}
|
||||
// if original array is indexed [i,j][k,l][m,n] and indices = [*,*][3,*] then
|
||||
// initExpr2 = new PlaceHolder[wildcard0,wildcard1] { new PlaceHolder[wildcard2] { new PlaceHolder[newIndexVar] { initExpr[wildcard0,wildcard1][3,wildcard2] } } }
|
||||
IExpression initExpr = it.initialMessagesExpression;
|
||||
|
@ -643,8 +648,6 @@ namespace Microsoft.ML.Probabilistic.Compiler.Attributes
|
|||
}
|
||||
context.OutputAttributes.Set(arrayvd, new InitialiseTo(initExpr));
|
||||
}
|
||||
ChannelTransform.setAllGroupRoots(context, arrayvd, false);
|
||||
return arrayvd;
|
||||
}
|
||||
|
||||
internal static IExpression MakePlaceHolderArrayCreate(IExpression expr, IList<IVariableDeclaration[]> indexVars)
|
||||
|
@ -693,11 +696,12 @@ namespace Microsoft.ML.Probabilistic.Compiler.Attributes
|
|||
/// <param name="expr">Any expression</param>
|
||||
/// <param name="indices">A list of lists of index expressions (one list for each indexing bracket).</param>
|
||||
/// <param name="wildcardIndices">Expressions used to replace wildcards. May be null if there are no wildcards.</param>
|
||||
/// <param name="replaceCount">Incremented for each replacement.</param>
|
||||
/// <param name="replaceCount">The number of replacements.</param>
|
||||
/// <returns>A new expression.</returns>
|
||||
internal IExpression ReplaceIndexVars(BasicTransformContext context, IExpression expr, IList<IList<IExpression>> indices,
|
||||
IList<IList<IExpression>> wildcardIndices, ref int replaceCount)
|
||||
IList<IList<IExpression>> wildcardIndices, out int replaceCount)
|
||||
{
|
||||
replaceCount = 0;
|
||||
Dictionary<IVariableDeclaration, IExpression> replacedIndexVars = new Dictionary<IVariableDeclaration, IExpression>();
|
||||
int wildcardBracket = 0;
|
||||
for (int depth = 0; depth < indices.Count; depth++)
|
||||
|
|
|
@ -946,7 +946,7 @@ namespace Microsoft.ML.Probabilistic.Compiler
|
|||
var lct2 = new LoopCuttingTransform(true);
|
||||
tc.AddTransform(lct2);
|
||||
tc.AddTransform(lct2); // run again to catch uses before declaration
|
||||
if(OptimiseInferenceCode)
|
||||
if (OptimiseInferenceCode)
|
||||
{
|
||||
// must run after HoistingTransform
|
||||
tc.AddTransform(new LoopRemovalTransform());
|
||||
|
|
|
@ -836,9 +836,10 @@ namespace Microsoft.ML.Probabilistic.Models
|
|||
for (int i = 0; i < indexVars.Length; i++)
|
||||
{
|
||||
IModelExpression expr = parent.indices[i];
|
||||
if (!(expr is Range))
|
||||
if (expr is Range range)
|
||||
indexVars[i] = range.GetIndexDeclaration();
|
||||
else
|
||||
throw new Exception(parent + ".InitializeTo is not allowed since the indices are not ranges");
|
||||
indexVars[i] = ((Range)expr).GetIndexDeclaration();
|
||||
}
|
||||
initExpr = VariableInformation.MakePlaceHolderArrayCreate(initExpr, indexVars);
|
||||
parent = (Variable)parent.ArrayVariable;
|
||||
|
|
|
@ -321,7 +321,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
{
|
||||
foreach (ConditionBinding ci in conditionContext)
|
||||
{
|
||||
// each lhs has already been replaced, so we only need to compare for equality
|
||||
// each subexpression has already been replaced, so we only need to compare for equality here
|
||||
if (expr.Equals(ci.lhs)) return ci.rhs;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -810,10 +810,10 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
if (parameter.Name == "length")
|
||||
{
|
||||
IExpression arg = ioce.Arguments[argIndex];
|
||||
if (arg is ILiteralExpression)
|
||||
if (arg is ILiteralExpression ile)
|
||||
{
|
||||
object argValue = ((ILiteralExpression)arg).Value;
|
||||
if (argValue is int && (int)argValue == 0)
|
||||
object argValue = ile.Value;
|
||||
if (argValue is int i && i == 0)
|
||||
dependencyInformation.IsUniform = true;
|
||||
}
|
||||
break;
|
||||
|
@ -1063,18 +1063,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
// arguments in all dependencies with the correct model expressions.
|
||||
IExpression MapDependency(IStatement ist)
|
||||
{
|
||||
return ReplaceArgs(((IExpressionStatement)ist).Expression);
|
||||
|
||||
// Replace the parameter expressions in an expression with the corresponding model expressions
|
||||
IExpression ReplaceArgs(IExpression iExpression)
|
||||
{
|
||||
foreach (KeyValuePair<IExpression, IExpression> kvp in parameterToExpressionMap)
|
||||
{
|
||||
int repCount = 0;
|
||||
iExpression = Builder.ReplaceExpression(iExpression, kvp.Key, kvp.Value, ref repCount);
|
||||
}
|
||||
return iExpression;
|
||||
}
|
||||
return Builder.ReplaceSubexpressions(((IExpressionStatement)ist).Expression, parameterToExpressionMap);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -148,7 +148,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
}
|
||||
|
||||
private static void AddCopyStatements(ICollection<IStatement> stmts, VariableInformation varInfo, int indexingDepth, IExpression lhs, IExpression rhs,
|
||||
int bracket = 0, Dictionary<IExpression,IExpression> replacements = null)
|
||||
int bracket = 0, Dictionary<IExpression, IExpression> replacements = null)
|
||||
{
|
||||
if (indexingDepth == bracket)
|
||||
{
|
||||
|
@ -164,7 +164,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
int[] sizes = Util.ArrayInit(varInfo.sizes[bracket].Length, i => (int)((ILiteralExpression)varInfo.sizes[bracket][i]).Value);
|
||||
ForEachLiteralIndex(sizes, index =>
|
||||
{
|
||||
IExpression[] bracketIndices = Util.ArrayInit(index.Length, i => Builder.LiteralExpr(index[i]));
|
||||
ILiteralExpression[] bracketIndices = Util.ArrayInit(index.Length, i => Builder.LiteralExpr(index[i]));
|
||||
var newLhs = Builder.ArrayIndex(lhs, bracketIndices);
|
||||
var newRhs = Builder.ArrayIndex(rhs, bracketIndices);
|
||||
for (int dim = 0; dim < index.Length; dim++)
|
||||
|
@ -180,7 +180,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
IReadOnlyList<IExpression> replacedSizes = varInfo.sizes[bracket];
|
||||
if(replacements != null)
|
||||
{
|
||||
replacedSizes = Util.ArrayInit(replacedSizes.Count, i => Replace(replacedSizes[i], replacements));
|
||||
replacedSizes = Util.ArrayInit(replacedSizes.Count, i => Builder.ReplaceSubexpressions(replacedSizes[i], replacements));
|
||||
}
|
||||
IForStatement innerForStatement;
|
||||
var fs = Builder.NestedForStmt(varInfo.indexVars[bracket], replacedSizes, out innerForStatement);
|
||||
|
@ -194,15 +194,6 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
}
|
||||
}
|
||||
|
||||
private static IExpression Replace(IExpression expr, IReadOnlyDictionary<IExpression,IExpression> replacements)
|
||||
{
|
||||
foreach(var entry in replacements)
|
||||
{
|
||||
expr = Builder.ReplaceExpression(expr, entry.Key, entry.Value);
|
||||
}
|
||||
return expr;
|
||||
}
|
||||
|
||||
private static void ForEachLiteralIndex(int[] sizes, Action<int[]> action)
|
||||
{
|
||||
int[] strides = StringUtil.ArrayStrides(sizes);
|
||||
|
|
|
@ -166,10 +166,9 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
}
|
||||
foreach (var stmt in outputBlock)
|
||||
{
|
||||
if (stmt is IWhileStatement)
|
||||
// recursively collect the child block
|
||||
if (stmt is IWhileStatement iws)
|
||||
{
|
||||
// recursively collect the child block
|
||||
IWhileStatement iws = (IWhileStatement)stmt;
|
||||
CollectTransformedStmts(iws.Body.Statements, replacementsInContext);
|
||||
// merge the child replacements into this block's replacements
|
||||
var childReplacements = replacementsInContext[iws.Body.Statements];
|
||||
|
@ -208,9 +207,8 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
var replacements = replacementsInContext[outputBlock];
|
||||
foreach (var stmt in outputBlock)
|
||||
{
|
||||
if (stmt is IWhileStatement)
|
||||
if (stmt is IWhileStatement iws)
|
||||
{
|
||||
IWhileStatement iws = (IWhileStatement)stmt;
|
||||
// merge this block's replacements into the child's replacements
|
||||
var childReplacements = replacementsInContext[iws.Body.Statements];
|
||||
foreach (var entry in replacements)
|
||||
|
@ -314,7 +312,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
OffsetInfo newOffsetInfo = new OffsetInfo();
|
||||
foreach (var offset in offsetInfo)
|
||||
{
|
||||
if (CanKeepOffsetDependency(ssinfo, reversedLoopVars, reversedLoopVarsOther, offset))
|
||||
if (CanKeepOffsetDependency(reversedLoopVarsOther, offset))
|
||||
newOffsetInfo.Add(offset);
|
||||
else
|
||||
changed = true;
|
||||
|
@ -334,19 +332,19 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
{
|
||||
di.offsetIndexOf.Add(pair.Key, pair.Value);
|
||||
}
|
||||
}
|
||||
|
||||
private bool CanKeepOffsetDependency(SerialSchedulingInfo ssinfo, Set<IVariableDeclaration> reversedLoopVars, Set<IVariableDeclaration> reversedLoopVarsOther, Offset offset)
|
||||
{
|
||||
foreach (var loopVar in ssinfo.loopInfos.Select(info => info.loopVar))
|
||||
bool CanKeepOffsetDependency(Set<IVariableDeclaration> reversedLoopVarsOther, Offset offset)
|
||||
{
|
||||
bool compatible = (reversedLoopVars.Contains(loopVar) == reversedLoopVarsOther.Contains(loopVar));
|
||||
if (!compatible)
|
||||
return false;
|
||||
if (offset.loopVar == loopVar)
|
||||
break;
|
||||
foreach (var loopVar in ssinfo.loopInfos.Select(info => info.loopVar))
|
||||
{
|
||||
bool compatible = (reversedLoopVars.Contains(loopVar) == reversedLoopVarsOther.Contains(loopVar));
|
||||
if (!compatible)
|
||||
return false;
|
||||
if (offset.loopVar == loopVar)
|
||||
break;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
|
|
@ -574,17 +574,16 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
}
|
||||
|
||||
/// <summary>
|
||||
/// Apply all bindings to expr, in order.
|
||||
/// Apply all bindings to expr.
|
||||
/// </summary>
|
||||
/// <param name="bindings"></param>
|
||||
/// <param name="bindings">Must not chain, i.e. no lhs appears in any rhs.</param>
|
||||
/// <param name="expr"></param>
|
||||
/// <returns></returns>
|
||||
internal static IExpression ReplaceExpression(IEnumerable<ConditionBinding> bindings, IExpression expr)
|
||||
{
|
||||
foreach (ConditionBinding binding in bindings)
|
||||
{
|
||||
int replaceCount = 0;
|
||||
expr = Builder.ReplaceExpression(expr, binding.lhs, binding.rhs, ref replaceCount);
|
||||
expr = Builder.ReplaceExpression(expr, binding.lhs, binding.rhs);
|
||||
}
|
||||
return expr;
|
||||
}
|
||||
|
|
|
@ -516,6 +516,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
out replaced);
|
||||
if (isDef && !replaced)
|
||||
{
|
||||
// This call is valid because bindings is formed from condition expressions and these have been fully replaced by ConstantFoldingTransform.
|
||||
IExpression boundDef =
|
||||
GateAnalysisTransform.ReplaceExpression(bindings, definedExpression.Expression);
|
||||
Error($"{expr} doesn\'t match bound GateBlock def: {boundDef}");
|
||||
|
@ -586,8 +587,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
}));
|
||||
RecordReplacement(expr, toClone, !isSubset);
|
||||
|
||||
int replaceCount = 0;
|
||||
expr = Builder.ReplaceExpression(expr, toReplace, clone, ref replaceCount);
|
||||
expr = Builder.ReplaceExpression(expr, toReplace, clone);
|
||||
replaced = true;
|
||||
}
|
||||
|
||||
|
|
|
@ -565,8 +565,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
int depth = Recognizer.GetIndexingDepth(index);
|
||||
IExpression resultSize = indexInfo.sizes[depth][0];
|
||||
var indices = Recognizer.GetIndices(index);
|
||||
int replaceCount = 0;
|
||||
resultSize = indexInfo.ReplaceIndexVars(context, resultSize, indices, null, ref replaceCount);
|
||||
resultSize = indexInfo.ReplaceIndexVars(context, resultSize, indices, null, out int replaceCount);
|
||||
indexInfo.DefineIndexVarsUpToDepth(context, depth + 1);
|
||||
IVariableDeclaration resultIndex = indexInfo.indexVars[depth][0];
|
||||
Type arrayType = arrayExpr.GetExpressionType();
|
||||
|
|
|
@ -1864,8 +1864,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
List<IList<IExpression>> indices = Recognizer.GetIndices(outputLhs);
|
||||
if (indices.Count > 0)
|
||||
{
|
||||
int replaceCount = 0;
|
||||
mpe = channelVarInfo.ReplaceIndexVars(context, mpe, indices, null, ref replaceCount);
|
||||
mpe = channelVarInfo.ReplaceIndexVars(context, mpe, indices, null, out int replaceCount);
|
||||
}
|
||||
if (mai.isDistribution)
|
||||
{
|
||||
|
@ -2122,13 +2121,16 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
{
|
||||
if (iace.Type.DotNetType.Equals(typeof(PlaceHolder)) && iace.Initializer != null && iace.Initializer.Expressions.Count == 1)
|
||||
{
|
||||
// VariableInformation.DeriveArrayVariable has created an initializer expression of the form:
|
||||
// new PlaceHolder[wildcard0,wildcard1] { initExpr[wildcard0,wildcard1] }
|
||||
IExpression initExpr = iace.Initializer.Expressions[0];
|
||||
// replace index variables with the given indices
|
||||
Dictionary<IExpression, IExpression> replacements = new Dictionary<IExpression, IExpression>();
|
||||
for (int dim = 0; dim < iace.Dimensions.Count; dim++)
|
||||
{
|
||||
initExpr = Builder.ReplaceExpression(initExpr, iace.Dimensions[dim], iaie.Indices[dim]);
|
||||
replacements.Add(iace.Dimensions[dim], iaie.Indices[dim]);
|
||||
}
|
||||
return initExpr;
|
||||
return Builder.ReplaceSubexpressions(initExpr, replacements);
|
||||
}
|
||||
}
|
||||
return Builder.ArrayIndex(target, iaie.Indices);
|
||||
|
@ -2275,8 +2277,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
IVariableDeclaration indexVar = indexVars[i];
|
||||
IParameterDeclaration param = Builder.Param(indexVar.Name, typeof(int));
|
||||
iame.Parameters.Add(param);
|
||||
int replaceCount = 0;
|
||||
elementInit = Builder.ReplaceExpression(elementInit, Builder.VarRefExpr(indexVar), Builder.ParamRef(param), ref replaceCount);
|
||||
elementInit = Builder.ReplaceExpression(elementInit, Builder.VarRefExpr(indexVar), Builder.ParamRef(param));
|
||||
}
|
||||
iame.Body.Statements.Add(Builder.Return(elementInit));
|
||||
return iame;
|
||||
|
|
|
@ -3848,8 +3848,6 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
|
||||
private void LabelEdgesWithOffsets(NodeIndex node)
|
||||
{
|
||||
var obc = descendantOffset[node];
|
||||
var keys = this.loopVarsOfNode[node];
|
||||
foreach (EdgeIndex edge in g.EdgesInto(node))
|
||||
{
|
||||
if (direction[edge] == Direction.Unknown && !deletedEdges.Contains(edge))
|
||||
|
|
|
@ -238,10 +238,14 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms
|
|||
// this will fill in groupOf and loopInfoOfGroup
|
||||
BuildGroups(inputStmts, -1);
|
||||
g = new DependencyGraph(context, flatStmts, ignoreMissingNodes: true, ignoreRequirements: false, deleteCancels: true);
|
||||
g.getTargetIndex = delegate (NodeIndex node)
|
||||
bool replaceTargetIndex = false;
|
||||
if (replaceTargetIndex)
|
||||
{
|
||||
return new DependencyGraph.TargetIndex(loopMergingInfo.GetIndexOf(flatStmts[node]));
|
||||
};
|
||||
g.getTargetIndex = delegate (NodeIndex node)
|
||||
{
|
||||
return new DependencyGraph.TargetIndex(loopMergingInfo.GetIndexOf(flatStmts[node]));
|
||||
};
|
||||
}
|
||||
if (compiler.UseSerialSchedules && !compiler.UseExperimentalSerialSchedules)
|
||||
{
|
||||
bool anyDeleted;
|
||||
|
|
|
@ -1589,7 +1589,7 @@ namespace Microsoft.ML.Probabilistic.Compiler
|
|||
}
|
||||
|
||||
/// <summary>
|
||||
/// Finds and replaces one expression with another expression in a given expression
|
||||
/// Finds and replaces one expression with another expression, everywhere it occurs
|
||||
/// </summary>
|
||||
/// <param name="expr">The expression</param>
|
||||
/// <param name="exprFind">The expression to be found</param>
|
||||
|
@ -1598,58 +1598,97 @@ namespace Microsoft.ML.Probabilistic.Compiler
|
|||
/// <returns>The resulting expression</returns>
|
||||
public IExpression ReplaceExpression(IExpression expr, IExpression exprFind, IExpression exprReplace, ref int replaceCount)
|
||||
{
|
||||
if (expr == null) return expr;
|
||||
else if (expr.Equals(exprFind))
|
||||
int localReplaceCount = 0;
|
||||
IExpression result = ReplaceSubexpressions(expr, e =>
|
||||
{
|
||||
replaceCount++;
|
||||
return exprReplace;
|
||||
if (e.Equals(exprFind))
|
||||
{
|
||||
localReplaceCount++;
|
||||
return exprReplace;
|
||||
}
|
||||
else
|
||||
{
|
||||
return null;
|
||||
}
|
||||
});
|
||||
replaceCount = localReplaceCount;
|
||||
return result;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Replaces all subexpressions of an expression
|
||||
/// </summary>
|
||||
/// <param name="expr">The expression</param>
|
||||
/// <param name="replacements">Subexpressions to find and replace</param>
|
||||
/// <returns>The replaced expression</returns>
|
||||
public IExpression ReplaceSubexpressions(IExpression expr, IReadOnlyDictionary<IExpression, IExpression> replacements)
|
||||
{
|
||||
return ReplaceSubexpressions(expr, e =>
|
||||
{
|
||||
replacements.TryGetValue(e, out IExpression replacement);
|
||||
return replacement;
|
||||
});
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Replaces all subexpressions of an expression
|
||||
/// </summary>
|
||||
/// <param name="expr">The expression</param>
|
||||
/// <param name="replace">Returns a new expression or null for no replacement</param>
|
||||
/// <returns>The replaced expression</returns>
|
||||
public IExpression ReplaceSubexpressions(IExpression expr, Func<IExpression, IExpression> replace)
|
||||
{
|
||||
if (expr == null) return expr;
|
||||
var replaced = replace(expr);
|
||||
if (replaced != null)
|
||||
{
|
||||
return replaced;
|
||||
}
|
||||
else if ((expr is IVariableDeclarationExpression) ||
|
||||
(expr is IVariableReferenceExpression) ||
|
||||
(expr is ILiteralExpression) ||
|
||||
(expr is IDefaultExpression) ||
|
||||
(expr is IArgumentReferenceExpression) ||
|
||||
(expr is IThisReferenceExpression)) return expr;
|
||||
else if (expr is IArrayIndexerExpression iaie)
|
||||
{
|
||||
IArrayIndexerExpression aie = ArrayIndxrExpr();
|
||||
foreach (IExpression ind in iaie.Indices) aie.Indices.Add(ReplaceExpression(ind, exprFind, exprReplace, ref replaceCount));
|
||||
aie.Target = ReplaceExpression(iaie.Target, exprFind, exprReplace, ref replaceCount);
|
||||
foreach (IExpression ind in iaie.Indices) aie.Indices.Add(ReplaceSubexpressions(ind, replace));
|
||||
aie.Target = ReplaceSubexpressions(iaie.Target, replace);
|
||||
return aie;
|
||||
}
|
||||
else if (expr is IPropertyIndexerExpression ipie)
|
||||
{
|
||||
IPropertyIndexerExpression pie = PropIndxrExpr();
|
||||
foreach (IExpression ind in ipie.Indices) pie.Indices.Add(ReplaceExpression(ind, exprFind, exprReplace, ref replaceCount));
|
||||
pie.Target = (IPropertyReferenceExpression)ReplaceExpression(ipie.Target, exprFind, exprReplace, ref replaceCount);
|
||||
foreach (IExpression ind in ipie.Indices) pie.Indices.Add(ReplaceSubexpressions(ind, replace));
|
||||
pie.Target = (IPropertyReferenceExpression)ReplaceSubexpressions(ipie.Target, replace);
|
||||
return pie;
|
||||
}
|
||||
else if (expr is ICastExpression ice)
|
||||
{
|
||||
return CastExpr(ReplaceExpression(ice.Expression, exprFind, exprReplace, ref replaceCount), ice.TargetType);
|
||||
return CastExpr(ReplaceSubexpressions(ice.Expression, replace), ice.TargetType);
|
||||
}
|
||||
else if (expr is ICheckedExpression iche)
|
||||
{
|
||||
return CheckedExpr(ReplaceExpression(iche.Expression, exprFind, exprReplace, ref replaceCount));
|
||||
return CheckedExpr(ReplaceSubexpressions(iche.Expression, replace));
|
||||
}
|
||||
else if (
|
||||
(expr is IVariableDeclarationExpression) ||
|
||||
(expr is IVariableReferenceExpression) ||
|
||||
(expr is ILiteralExpression) ||
|
||||
(expr is IDefaultExpression) ||
|
||||
(expr is IArgumentReferenceExpression)) return expr;
|
||||
else if (expr is IPropertyReferenceExpression ipre)
|
||||
{
|
||||
IExpression target = ReplaceExpression(ipre.Target, exprFind, exprReplace, ref replaceCount);
|
||||
IExpression target = ReplaceSubexpressions(ipre.Target, replace);
|
||||
if (target == ipre.Target) return ipre;
|
||||
IPropertyReferenceExpression pre = PropRefExpr();
|
||||
pre.Property = ipre.Property;
|
||||
pre.Target = target;
|
||||
return pre;
|
||||
}
|
||||
else if (expr is IArrayCreateExpression)
|
||||
else if (expr is IArrayCreateExpression iace)
|
||||
{
|
||||
IArrayCreateExpression iace = expr as IArrayCreateExpression;
|
||||
var ace = ArrayCreateExpr();
|
||||
ace.Type = iace.Type;
|
||||
ace.Initializer = ReplaceExpression(iace.Initializer, exprFind, exprReplace, ref replaceCount) as IBlockExpression;
|
||||
ace.Initializer = ReplaceSubexpressions(iace.Initializer, replace) as IBlockExpression;
|
||||
foreach (IExpression dim in iace.Dimensions)
|
||||
{
|
||||
ace.Dimensions.Add(ReplaceExpression(dim, exprFind, exprReplace, ref replaceCount));
|
||||
ace.Dimensions.Add(ReplaceSubexpressions(dim, replace));
|
||||
}
|
||||
return ace;
|
||||
}
|
||||
|
@ -1658,7 +1697,7 @@ namespace Microsoft.ML.Probabilistic.Compiler
|
|||
IBlockExpression be = BlockExpr();
|
||||
foreach (IExpression e in ible.Expressions)
|
||||
{
|
||||
be.Expressions.Add(ReplaceExpression(e, exprFind, exprReplace, ref replaceCount));
|
||||
be.Expressions.Add(ReplaceSubexpressions(e, replace));
|
||||
}
|
||||
return be;
|
||||
}
|
||||
|
@ -1668,7 +1707,7 @@ namespace Microsoft.ML.Probabilistic.Compiler
|
|||
mie.Method = imie.Method;
|
||||
foreach (IExpression arg in imie.Arguments)
|
||||
{
|
||||
mie.Arguments.Add(ReplaceExpression(arg, exprFind, exprReplace, ref replaceCount));
|
||||
mie.Arguments.Add(ReplaceSubexpressions(arg, replace));
|
||||
}
|
||||
return mie;
|
||||
}
|
||||
|
@ -1679,9 +1718,9 @@ namespace Microsoft.ML.Probabilistic.Compiler
|
|||
oce.Type = ioce.Type;
|
||||
foreach (IExpression arg in ioce.Arguments)
|
||||
{
|
||||
oce.Arguments.Add(ReplaceExpression(arg, exprFind, exprReplace, ref replaceCount));
|
||||
oce.Arguments.Add(ReplaceSubexpressions(arg, replace));
|
||||
}
|
||||
oce.Initializer = (IBlockExpression)ReplaceExpression(ioce.Initializer, exprFind, exprReplace, ref replaceCount);
|
||||
oce.Initializer = (IBlockExpression)ReplaceSubexpressions(ioce.Initializer, replace);
|
||||
return oce;
|
||||
}
|
||||
else if (expr is IAnonymousMethodExpression iame)
|
||||
|
@ -1695,11 +1734,11 @@ namespace Microsoft.ML.Probabilistic.Compiler
|
|||
IStatement st = ist;
|
||||
if (ist is IExpressionStatement ies)
|
||||
{
|
||||
st = ExprStatement(ReplaceExpression(ies.Expression, exprFind, exprReplace, ref replaceCount));
|
||||
st = ExprStatement(ReplaceSubexpressions(ies.Expression, replace));
|
||||
}
|
||||
else if (ist is IMethodReturnStatement imrs)
|
||||
{
|
||||
st = Return(ReplaceExpression(imrs.Expression, exprFind, exprReplace, ref replaceCount));
|
||||
st = Return(ReplaceSubexpressions(imrs.Expression, replace));
|
||||
}
|
||||
ame.Body.Statements.Add(st);
|
||||
}
|
||||
|
@ -1709,33 +1748,29 @@ namespace Microsoft.ML.Probabilistic.Compiler
|
|||
{
|
||||
IUnaryExpression ue = UnaryExpr();
|
||||
ue.Operator = iue.Operator;
|
||||
ue.Expression = ReplaceExpression(iue.Expression, exprFind, exprReplace, ref replaceCount);
|
||||
ue.Expression = ReplaceSubexpressions(iue.Expression, replace);
|
||||
return ue;
|
||||
}
|
||||
else if (expr is IBinaryExpression ibe)
|
||||
{
|
||||
IBinaryExpression be = BinaryExpr();
|
||||
be.Operator = ibe.Operator;
|
||||
be.Left = ReplaceExpression(ibe.Left, exprFind, exprReplace, ref replaceCount);
|
||||
be.Right = ReplaceExpression(ibe.Right, exprFind, exprReplace, ref replaceCount);
|
||||
be.Left = ReplaceSubexpressions(ibe.Left, replace);
|
||||
be.Right = ReplaceSubexpressions(ibe.Right, replace);
|
||||
return be;
|
||||
}
|
||||
else if (expr is IMethodReferenceExpression imre)
|
||||
{
|
||||
var target = ReplaceExpression(imre.Target, exprFind, exprReplace, ref replaceCount);
|
||||
var target = ReplaceSubexpressions(imre.Target, replace);
|
||||
return MethodRefExpr(imre.Method, target);
|
||||
}
|
||||
else if (expr is IThisReferenceExpression)
|
||||
{
|
||||
return expr;
|
||||
}
|
||||
else if (expr is IAddressOutExpression iaoe)
|
||||
{
|
||||
IAddressOutExpression aoe = AddrOutExpr();
|
||||
aoe.Expression = ReplaceExpression(iaoe.Expression, exprFind, exprReplace, ref replaceCount);
|
||||
aoe.Expression = ReplaceSubexpressions(iaoe.Expression, replace);
|
||||
return aoe;
|
||||
}
|
||||
else throw new NotImplementedException("Unhandled expression type in ReplaceExpression(): " + expr.GetType());
|
||||
else throw new NotImplementedException("Unhandled expression type in ReplaceSubexpressions(): " + expr.GetType());
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
|
Загрузка…
Ссылка в новой задаче