diff --git a/src/Compiler/Infer/CompilerAttributes/VariableInformation.cs b/src/Compiler/Infer/CompilerAttributes/VariableInformation.cs index cf7521aa..daea770e 100644 --- a/src/Compiler/Infer/CompilerAttributes/VariableInformation.cs +++ b/src/Compiler/Infer/CompilerAttributes/VariableInformation.cs @@ -411,8 +411,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Attributes IList> indices, IList> wildcardVars = null) { IExpression original = prototypeExpression; - int replaceCount = 0; - prototypeExpression = ReplaceIndexVars(context, prototypeExpression, indices, wildcardVars, ref replaceCount); + prototypeExpression = ReplaceIndexVars(context, prototypeExpression, indices, wildcardVars, out int replaceCount); int mpDepth = Util.GetArrayDepth(varType, Distribution.GetDomainType(prototypeExpression.GetExpressionType())); int indexingDepth = indices.Count; int wildcardBracket = 0; @@ -504,8 +503,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Attributes IExpression index = indices[i][j]; if (Recognizer.IsStaticMethod(index, new Func(GateAnalysisTransform.AnyIndex))) { - int replaceCount = 0; - sizeBracket.Add(ReplaceIndexVars(context, sizes[i][j], indices, wildcardVars, ref replaceCount)); + sizeBracket.Add(ReplaceIndexVars(context, sizes[i][j], indices, wildcardVars, out int replaceCount)); IVariableDeclaration v = indexVars[i][j]; if (wildcardVars != null) v = Recognizer.GetVariableDeclaration(wildcardVars[newIndexVars.Count][indexVarsBracket.Count]); else if (Recognizer.GetLoopForVariable(context, v) != null) @@ -556,8 +554,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Attributes IList replacementBracket = Builder.ExprCollection(); for (int j = 0; j < sizeBracket.Length; j++) { - int replaceCount = 0; - sizeBracket[j] = ReplaceIndexVars(context, sizes[i][j], replacements, wildcardVars, ref replaceCount); + sizeBracket[j] = ReplaceIndexVars(context, sizes[i][j], replacements, wildcardVars, out int replaceCount); if (replaceCount > 0) indexVarBracket[j] = GenerateLoopVar(context, "_a"); else if (indexVars.Count > i) indexVarBracket[j] = indexVars[i][j]; if (indexVarBracket[j] != null) replacementBracket.Add(Builder.VarRefExpr(indexVarBracket[j])); @@ -598,9 +595,17 @@ namespace Microsoft.ML.Probabilistic.Compiler.Attributes // substitute indices in the marginal prototype expression vi.marginalPrototypeExpression = GetMarginalPrototypeExpression(context, marginalPrototypeExpression, replacements, wildcardVars); } - InitialiseTo it = context.InputAttributes.Get(declaration); - if (it != null && copyInitializer) + if (copyInitializer) CopyInitialiser(); + ChannelTransform.setAllGroupRoots(context, arrayvd, false); + return arrayvd; + + void CopyInitialiser() { + InitialiseTo it = context.InputAttributes.Get(declaration); + if (it == null) + { + return; + } // if original array is indexed [i,j][k,l][m,n] and indices = [*,*][3,*] then // initExpr2 = new PlaceHolder[wildcard0,wildcard1] { new PlaceHolder[wildcard2] { new PlaceHolder[newIndexVar] { initExpr[wildcard0,wildcard1][3,wildcard2] } } } IExpression initExpr = it.initialMessagesExpression; @@ -643,8 +648,6 @@ namespace Microsoft.ML.Probabilistic.Compiler.Attributes } context.OutputAttributes.Set(arrayvd, new InitialiseTo(initExpr)); } - ChannelTransform.setAllGroupRoots(context, arrayvd, false); - return arrayvd; } internal static IExpression MakePlaceHolderArrayCreate(IExpression expr, IList indexVars) @@ -693,11 +696,12 @@ namespace Microsoft.ML.Probabilistic.Compiler.Attributes /// Any expression /// A list of lists of index expressions (one list for each indexing bracket). /// Expressions used to replace wildcards. May be null if there are no wildcards. - /// Incremented for each replacement. + /// The number of replacements. /// A new expression. internal IExpression ReplaceIndexVars(BasicTransformContext context, IExpression expr, IList> indices, - IList> wildcardIndices, ref int replaceCount) + IList> wildcardIndices, out int replaceCount) { + replaceCount = 0; Dictionary replacedIndexVars = new Dictionary(); int wildcardBracket = 0; for (int depth = 0; depth < indices.Count; depth++) diff --git a/src/Compiler/Infer/ModelCompiler.cs b/src/Compiler/Infer/ModelCompiler.cs index c1a9a5db..5fe3e1f3 100644 --- a/src/Compiler/Infer/ModelCompiler.cs +++ b/src/Compiler/Infer/ModelCompiler.cs @@ -946,7 +946,7 @@ namespace Microsoft.ML.Probabilistic.Compiler var lct2 = new LoopCuttingTransform(true); tc.AddTransform(lct2); tc.AddTransform(lct2); // run again to catch uses before declaration - if(OptimiseInferenceCode) + if (OptimiseInferenceCode) { // must run after HoistingTransform tc.AddTransform(new LoopRemovalTransform()); diff --git a/src/Compiler/Infer/Models/ModelBuilder.cs b/src/Compiler/Infer/Models/ModelBuilder.cs index d0cec737..75f0bc58 100644 --- a/src/Compiler/Infer/Models/ModelBuilder.cs +++ b/src/Compiler/Infer/Models/ModelBuilder.cs @@ -836,9 +836,10 @@ namespace Microsoft.ML.Probabilistic.Models for (int i = 0; i < indexVars.Length; i++) { IModelExpression expr = parent.indices[i]; - if (!(expr is Range)) + if (expr is Range range) + indexVars[i] = range.GetIndexDeclaration(); + else throw new Exception(parent + ".InitializeTo is not allowed since the indices are not ranges"); - indexVars[i] = ((Range)expr).GetIndexDeclaration(); } initExpr = VariableInformation.MakePlaceHolderArrayCreate(initExpr, indexVars); parent = (Variable)parent.ArrayVariable; diff --git a/src/Compiler/Infer/Transforms/ConstantFoldingTransform.cs b/src/Compiler/Infer/Transforms/ConstantFoldingTransform.cs index 3bf0b780..a8eff3ad 100644 --- a/src/Compiler/Infer/Transforms/ConstantFoldingTransform.cs +++ b/src/Compiler/Infer/Transforms/ConstantFoldingTransform.cs @@ -321,7 +321,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms { foreach (ConditionBinding ci in conditionContext) { - // each lhs has already been replaced, so we only need to compare for equality + // each subexpression has already been replaced, so we only need to compare for equality here if (expr.Equals(ci.lhs)) return ci.rhs; } } diff --git a/src/Compiler/Infer/Transforms/DependencyAnalysisTransform.cs b/src/Compiler/Infer/Transforms/DependencyAnalysisTransform.cs index 693bdd52..108419ff 100644 --- a/src/Compiler/Infer/Transforms/DependencyAnalysisTransform.cs +++ b/src/Compiler/Infer/Transforms/DependencyAnalysisTransform.cs @@ -810,10 +810,10 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms if (parameter.Name == "length") { IExpression arg = ioce.Arguments[argIndex]; - if (arg is ILiteralExpression) + if (arg is ILiteralExpression ile) { - object argValue = ((ILiteralExpression)arg).Value; - if (argValue is int && (int)argValue == 0) + object argValue = ile.Value; + if (argValue is int i && i == 0) dependencyInformation.IsUniform = true; } break; @@ -1063,18 +1063,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms // arguments in all dependencies with the correct model expressions. IExpression MapDependency(IStatement ist) { - return ReplaceArgs(((IExpressionStatement)ist).Expression); - - // Replace the parameter expressions in an expression with the corresponding model expressions - IExpression ReplaceArgs(IExpression iExpression) - { - foreach (KeyValuePair kvp in parameterToExpressionMap) - { - int repCount = 0; - iExpression = Builder.ReplaceExpression(iExpression, kvp.Key, kvp.Value, ref repCount); - } - return iExpression; - } + return Builder.ReplaceSubexpressions(((IExpressionStatement)ist).Expression, parameterToExpressionMap); } } diff --git a/src/Compiler/Infer/Transforms/DepthCloningTransform.cs b/src/Compiler/Infer/Transforms/DepthCloningTransform.cs index 0b4f98f1..badc91fa 100644 --- a/src/Compiler/Infer/Transforms/DepthCloningTransform.cs +++ b/src/Compiler/Infer/Transforms/DepthCloningTransform.cs @@ -148,7 +148,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms } private static void AddCopyStatements(ICollection stmts, VariableInformation varInfo, int indexingDepth, IExpression lhs, IExpression rhs, - int bracket = 0, Dictionary replacements = null) + int bracket = 0, Dictionary replacements = null) { if (indexingDepth == bracket) { @@ -164,7 +164,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms int[] sizes = Util.ArrayInit(varInfo.sizes[bracket].Length, i => (int)((ILiteralExpression)varInfo.sizes[bracket][i]).Value); ForEachLiteralIndex(sizes, index => { - IExpression[] bracketIndices = Util.ArrayInit(index.Length, i => Builder.LiteralExpr(index[i])); + ILiteralExpression[] bracketIndices = Util.ArrayInit(index.Length, i => Builder.LiteralExpr(index[i])); var newLhs = Builder.ArrayIndex(lhs, bracketIndices); var newRhs = Builder.ArrayIndex(rhs, bracketIndices); for (int dim = 0; dim < index.Length; dim++) @@ -180,7 +180,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms IReadOnlyList replacedSizes = varInfo.sizes[bracket]; if(replacements != null) { - replacedSizes = Util.ArrayInit(replacedSizes.Count, i => Replace(replacedSizes[i], replacements)); + replacedSizes = Util.ArrayInit(replacedSizes.Count, i => Builder.ReplaceSubexpressions(replacedSizes[i], replacements)); } IForStatement innerForStatement; var fs = Builder.NestedForStmt(varInfo.indexVars[bracket], replacedSizes, out innerForStatement); @@ -194,15 +194,6 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms } } - private static IExpression Replace(IExpression expr, IReadOnlyDictionary replacements) - { - foreach(var entry in replacements) - { - expr = Builder.ReplaceExpression(expr, entry.Key, entry.Value); - } - return expr; - } - private static void ForEachLiteralIndex(int[] sizes, Action action) { int[] strides = StringUtil.ArrayStrides(sizes); diff --git a/src/Compiler/Infer/Transforms/ForwardBackwardTransform.cs b/src/Compiler/Infer/Transforms/ForwardBackwardTransform.cs index c00db364..f04ef2cd 100644 --- a/src/Compiler/Infer/Transforms/ForwardBackwardTransform.cs +++ b/src/Compiler/Infer/Transforms/ForwardBackwardTransform.cs @@ -166,10 +166,9 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms } foreach (var stmt in outputBlock) { - if (stmt is IWhileStatement) + // recursively collect the child block + if (stmt is IWhileStatement iws) { - // recursively collect the child block - IWhileStatement iws = (IWhileStatement)stmt; CollectTransformedStmts(iws.Body.Statements, replacementsInContext); // merge the child replacements into this block's replacements var childReplacements = replacementsInContext[iws.Body.Statements]; @@ -208,9 +207,8 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms var replacements = replacementsInContext[outputBlock]; foreach (var stmt in outputBlock) { - if (stmt is IWhileStatement) + if (stmt is IWhileStatement iws) { - IWhileStatement iws = (IWhileStatement)stmt; // merge this block's replacements into the child's replacements var childReplacements = replacementsInContext[iws.Body.Statements]; foreach (var entry in replacements) @@ -314,7 +312,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms OffsetInfo newOffsetInfo = new OffsetInfo(); foreach (var offset in offsetInfo) { - if (CanKeepOffsetDependency(ssinfo, reversedLoopVars, reversedLoopVarsOther, offset)) + if (CanKeepOffsetDependency(reversedLoopVarsOther, offset)) newOffsetInfo.Add(offset); else changed = true; @@ -334,19 +332,19 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms { di.offsetIndexOf.Add(pair.Key, pair.Value); } - } - private bool CanKeepOffsetDependency(SerialSchedulingInfo ssinfo, Set reversedLoopVars, Set reversedLoopVarsOther, Offset offset) - { - foreach (var loopVar in ssinfo.loopInfos.Select(info => info.loopVar)) + bool CanKeepOffsetDependency(Set reversedLoopVarsOther, Offset offset) { - bool compatible = (reversedLoopVars.Contains(loopVar) == reversedLoopVarsOther.Contains(loopVar)); - if (!compatible) - return false; - if (offset.loopVar == loopVar) - break; + foreach (var loopVar in ssinfo.loopInfos.Select(info => info.loopVar)) + { + bool compatible = (reversedLoopVars.Contains(loopVar) == reversedLoopVarsOther.Contains(loopVar)); + if (!compatible) + return false; + if (offset.loopVar == loopVar) + break; + } + return true; } - return true; } /// diff --git a/src/Compiler/Infer/Transforms/GateAnalysisTransform.cs b/src/Compiler/Infer/Transforms/GateAnalysisTransform.cs index 50c426ba..07c866d0 100644 --- a/src/Compiler/Infer/Transforms/GateAnalysisTransform.cs +++ b/src/Compiler/Infer/Transforms/GateAnalysisTransform.cs @@ -574,17 +574,16 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms } /// - /// Apply all bindings to expr, in order. + /// Apply all bindings to expr. /// - /// + /// Must not chain, i.e. no lhs appears in any rhs. /// /// internal static IExpression ReplaceExpression(IEnumerable bindings, IExpression expr) { foreach (ConditionBinding binding in bindings) { - int replaceCount = 0; - expr = Builder.ReplaceExpression(expr, binding.lhs, binding.rhs, ref replaceCount); + expr = Builder.ReplaceExpression(expr, binding.lhs, binding.rhs); } return expr; } diff --git a/src/Compiler/Infer/Transforms/GateTransform.cs b/src/Compiler/Infer/Transforms/GateTransform.cs index c9211b02..3dc0db3f 100644 --- a/src/Compiler/Infer/Transforms/GateTransform.cs +++ b/src/Compiler/Infer/Transforms/GateTransform.cs @@ -516,6 +516,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms out replaced); if (isDef && !replaced) { + // This call is valid because bindings is formed from condition expressions and these have been fully replaced by ConstantFoldingTransform. IExpression boundDef = GateAnalysisTransform.ReplaceExpression(bindings, definedExpression.Expression); Error($"{expr} doesn\'t match bound GateBlock def: {boundDef}"); @@ -586,8 +587,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms })); RecordReplacement(expr, toClone, !isSubset); - int replaceCount = 0; - expr = Builder.ReplaceExpression(expr, toReplace, clone, ref replaceCount); + expr = Builder.ReplaceExpression(expr, toReplace, clone); replaced = true; } diff --git a/src/Compiler/Infer/Transforms/IndexingTransform.cs b/src/Compiler/Infer/Transforms/IndexingTransform.cs index ba6e4c63..0e714c69 100644 --- a/src/Compiler/Infer/Transforms/IndexingTransform.cs +++ b/src/Compiler/Infer/Transforms/IndexingTransform.cs @@ -565,8 +565,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms int depth = Recognizer.GetIndexingDepth(index); IExpression resultSize = indexInfo.sizes[depth][0]; var indices = Recognizer.GetIndices(index); - int replaceCount = 0; - resultSize = indexInfo.ReplaceIndexVars(context, resultSize, indices, null, ref replaceCount); + resultSize = indexInfo.ReplaceIndexVars(context, resultSize, indices, null, out int replaceCount); indexInfo.DefineIndexVarsUpToDepth(context, depth + 1); IVariableDeclaration resultIndex = indexInfo.indexVars[depth][0]; Type arrayType = arrayExpr.GetExpressionType(); diff --git a/src/Compiler/Infer/Transforms/MessageTransform.cs b/src/Compiler/Infer/Transforms/MessageTransform.cs index ce05150f..bf15145d 100644 --- a/src/Compiler/Infer/Transforms/MessageTransform.cs +++ b/src/Compiler/Infer/Transforms/MessageTransform.cs @@ -1864,8 +1864,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms List> indices = Recognizer.GetIndices(outputLhs); if (indices.Count > 0) { - int replaceCount = 0; - mpe = channelVarInfo.ReplaceIndexVars(context, mpe, indices, null, ref replaceCount); + mpe = channelVarInfo.ReplaceIndexVars(context, mpe, indices, null, out int replaceCount); } if (mai.isDistribution) { @@ -2122,13 +2121,16 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms { if (iace.Type.DotNetType.Equals(typeof(PlaceHolder)) && iace.Initializer != null && iace.Initializer.Expressions.Count == 1) { + // VariableInformation.DeriveArrayVariable has created an initializer expression of the form: + // new PlaceHolder[wildcard0,wildcard1] { initExpr[wildcard0,wildcard1] } IExpression initExpr = iace.Initializer.Expressions[0]; // replace index variables with the given indices + Dictionary replacements = new Dictionary(); for (int dim = 0; dim < iace.Dimensions.Count; dim++) { - initExpr = Builder.ReplaceExpression(initExpr, iace.Dimensions[dim], iaie.Indices[dim]); + replacements.Add(iace.Dimensions[dim], iaie.Indices[dim]); } - return initExpr; + return Builder.ReplaceSubexpressions(initExpr, replacements); } } return Builder.ArrayIndex(target, iaie.Indices); @@ -2275,8 +2277,7 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms IVariableDeclaration indexVar = indexVars[i]; IParameterDeclaration param = Builder.Param(indexVar.Name, typeof(int)); iame.Parameters.Add(param); - int replaceCount = 0; - elementInit = Builder.ReplaceExpression(elementInit, Builder.VarRefExpr(indexVar), Builder.ParamRef(param), ref replaceCount); + elementInit = Builder.ReplaceExpression(elementInit, Builder.VarRefExpr(indexVar), Builder.ParamRef(param)); } iame.Body.Statements.Add(Builder.Return(elementInit)); return iame; diff --git a/src/Compiler/Infer/Transforms/Scheduler.cs b/src/Compiler/Infer/Transforms/Scheduler.cs index b339e3d0..d0383e9e 100644 --- a/src/Compiler/Infer/Transforms/Scheduler.cs +++ b/src/Compiler/Infer/Transforms/Scheduler.cs @@ -3848,8 +3848,6 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms private void LabelEdgesWithOffsets(NodeIndex node) { - var obc = descendantOffset[node]; - var keys = this.loopVarsOfNode[node]; foreach (EdgeIndex edge in g.EdgesInto(node)) { if (direction[edge] == Direction.Unknown && !deletedEdges.Contains(edge)) diff --git a/src/Compiler/Infer/Transforms/SchedulingTransform.cs b/src/Compiler/Infer/Transforms/SchedulingTransform.cs index 85007c7e..4faa23bc 100644 --- a/src/Compiler/Infer/Transforms/SchedulingTransform.cs +++ b/src/Compiler/Infer/Transforms/SchedulingTransform.cs @@ -238,10 +238,14 @@ namespace Microsoft.ML.Probabilistic.Compiler.Transforms // this will fill in groupOf and loopInfoOfGroup BuildGroups(inputStmts, -1); g = new DependencyGraph(context, flatStmts, ignoreMissingNodes: true, ignoreRequirements: false, deleteCancels: true); - g.getTargetIndex = delegate (NodeIndex node) + bool replaceTargetIndex = false; + if (replaceTargetIndex) { - return new DependencyGraph.TargetIndex(loopMergingInfo.GetIndexOf(flatStmts[node])); - }; + g.getTargetIndex = delegate (NodeIndex node) + { + return new DependencyGraph.TargetIndex(loopMergingInfo.GetIndexOf(flatStmts[node])); + }; + } if (compiler.UseSerialSchedules && !compiler.UseExperimentalSerialSchedules) { bool anyDeleted; diff --git a/src/Compiler/TransformFramework/CodeBuilder.cs b/src/Compiler/TransformFramework/CodeBuilder.cs index 18f67396..fb15df15 100644 --- a/src/Compiler/TransformFramework/CodeBuilder.cs +++ b/src/Compiler/TransformFramework/CodeBuilder.cs @@ -1589,7 +1589,7 @@ namespace Microsoft.ML.Probabilistic.Compiler } /// - /// Finds and replaces one expression with another expression in a given expression + /// Finds and replaces one expression with another expression, everywhere it occurs /// /// The expression /// The expression to be found @@ -1598,58 +1598,97 @@ namespace Microsoft.ML.Probabilistic.Compiler /// The resulting expression public IExpression ReplaceExpression(IExpression expr, IExpression exprFind, IExpression exprReplace, ref int replaceCount) { - if (expr == null) return expr; - else if (expr.Equals(exprFind)) + int localReplaceCount = 0; + IExpression result = ReplaceSubexpressions(expr, e => { - replaceCount++; - return exprReplace; + if (e.Equals(exprFind)) + { + localReplaceCount++; + return exprReplace; + } + else + { + return null; + } + }); + replaceCount = localReplaceCount; + return result; + } + + /// + /// Replaces all subexpressions of an expression + /// + /// The expression + /// Subexpressions to find and replace + /// The replaced expression + public IExpression ReplaceSubexpressions(IExpression expr, IReadOnlyDictionary replacements) + { + return ReplaceSubexpressions(expr, e => + { + replacements.TryGetValue(e, out IExpression replacement); + return replacement; + }); + } + + /// + /// Replaces all subexpressions of an expression + /// + /// The expression + /// Returns a new expression or null for no replacement + /// The replaced expression + public IExpression ReplaceSubexpressions(IExpression expr, Func replace) + { + if (expr == null) return expr; + var replaced = replace(expr); + if (replaced != null) + { + return replaced; } + else if ((expr is IVariableDeclarationExpression) || + (expr is IVariableReferenceExpression) || + (expr is ILiteralExpression) || + (expr is IDefaultExpression) || + (expr is IArgumentReferenceExpression) || + (expr is IThisReferenceExpression)) return expr; else if (expr is IArrayIndexerExpression iaie) { IArrayIndexerExpression aie = ArrayIndxrExpr(); - foreach (IExpression ind in iaie.Indices) aie.Indices.Add(ReplaceExpression(ind, exprFind, exprReplace, ref replaceCount)); - aie.Target = ReplaceExpression(iaie.Target, exprFind, exprReplace, ref replaceCount); + foreach (IExpression ind in iaie.Indices) aie.Indices.Add(ReplaceSubexpressions(ind, replace)); + aie.Target = ReplaceSubexpressions(iaie.Target, replace); return aie; } else if (expr is IPropertyIndexerExpression ipie) { IPropertyIndexerExpression pie = PropIndxrExpr(); - foreach (IExpression ind in ipie.Indices) pie.Indices.Add(ReplaceExpression(ind, exprFind, exprReplace, ref replaceCount)); - pie.Target = (IPropertyReferenceExpression)ReplaceExpression(ipie.Target, exprFind, exprReplace, ref replaceCount); + foreach (IExpression ind in ipie.Indices) pie.Indices.Add(ReplaceSubexpressions(ind, replace)); + pie.Target = (IPropertyReferenceExpression)ReplaceSubexpressions(ipie.Target, replace); return pie; } else if (expr is ICastExpression ice) { - return CastExpr(ReplaceExpression(ice.Expression, exprFind, exprReplace, ref replaceCount), ice.TargetType); + return CastExpr(ReplaceSubexpressions(ice.Expression, replace), ice.TargetType); } else if (expr is ICheckedExpression iche) { - return CheckedExpr(ReplaceExpression(iche.Expression, exprFind, exprReplace, ref replaceCount)); + return CheckedExpr(ReplaceSubexpressions(iche.Expression, replace)); } - else if ( - (expr is IVariableDeclarationExpression) || - (expr is IVariableReferenceExpression) || - (expr is ILiteralExpression) || - (expr is IDefaultExpression) || - (expr is IArgumentReferenceExpression)) return expr; else if (expr is IPropertyReferenceExpression ipre) { - IExpression target = ReplaceExpression(ipre.Target, exprFind, exprReplace, ref replaceCount); + IExpression target = ReplaceSubexpressions(ipre.Target, replace); if (target == ipre.Target) return ipre; IPropertyReferenceExpression pre = PropRefExpr(); pre.Property = ipre.Property; pre.Target = target; return pre; } - else if (expr is IArrayCreateExpression) + else if (expr is IArrayCreateExpression iace) { - IArrayCreateExpression iace = expr as IArrayCreateExpression; var ace = ArrayCreateExpr(); ace.Type = iace.Type; - ace.Initializer = ReplaceExpression(iace.Initializer, exprFind, exprReplace, ref replaceCount) as IBlockExpression; + ace.Initializer = ReplaceSubexpressions(iace.Initializer, replace) as IBlockExpression; foreach (IExpression dim in iace.Dimensions) { - ace.Dimensions.Add(ReplaceExpression(dim, exprFind, exprReplace, ref replaceCount)); + ace.Dimensions.Add(ReplaceSubexpressions(dim, replace)); } return ace; } @@ -1658,7 +1697,7 @@ namespace Microsoft.ML.Probabilistic.Compiler IBlockExpression be = BlockExpr(); foreach (IExpression e in ible.Expressions) { - be.Expressions.Add(ReplaceExpression(e, exprFind, exprReplace, ref replaceCount)); + be.Expressions.Add(ReplaceSubexpressions(e, replace)); } return be; } @@ -1668,7 +1707,7 @@ namespace Microsoft.ML.Probabilistic.Compiler mie.Method = imie.Method; foreach (IExpression arg in imie.Arguments) { - mie.Arguments.Add(ReplaceExpression(arg, exprFind, exprReplace, ref replaceCount)); + mie.Arguments.Add(ReplaceSubexpressions(arg, replace)); } return mie; } @@ -1679,9 +1718,9 @@ namespace Microsoft.ML.Probabilistic.Compiler oce.Type = ioce.Type; foreach (IExpression arg in ioce.Arguments) { - oce.Arguments.Add(ReplaceExpression(arg, exprFind, exprReplace, ref replaceCount)); + oce.Arguments.Add(ReplaceSubexpressions(arg, replace)); } - oce.Initializer = (IBlockExpression)ReplaceExpression(ioce.Initializer, exprFind, exprReplace, ref replaceCount); + oce.Initializer = (IBlockExpression)ReplaceSubexpressions(ioce.Initializer, replace); return oce; } else if (expr is IAnonymousMethodExpression iame) @@ -1695,11 +1734,11 @@ namespace Microsoft.ML.Probabilistic.Compiler IStatement st = ist; if (ist is IExpressionStatement ies) { - st = ExprStatement(ReplaceExpression(ies.Expression, exprFind, exprReplace, ref replaceCount)); + st = ExprStatement(ReplaceSubexpressions(ies.Expression, replace)); } else if (ist is IMethodReturnStatement imrs) { - st = Return(ReplaceExpression(imrs.Expression, exprFind, exprReplace, ref replaceCount)); + st = Return(ReplaceSubexpressions(imrs.Expression, replace)); } ame.Body.Statements.Add(st); } @@ -1709,33 +1748,29 @@ namespace Microsoft.ML.Probabilistic.Compiler { IUnaryExpression ue = UnaryExpr(); ue.Operator = iue.Operator; - ue.Expression = ReplaceExpression(iue.Expression, exprFind, exprReplace, ref replaceCount); + ue.Expression = ReplaceSubexpressions(iue.Expression, replace); return ue; } else if (expr is IBinaryExpression ibe) { IBinaryExpression be = BinaryExpr(); be.Operator = ibe.Operator; - be.Left = ReplaceExpression(ibe.Left, exprFind, exprReplace, ref replaceCount); - be.Right = ReplaceExpression(ibe.Right, exprFind, exprReplace, ref replaceCount); + be.Left = ReplaceSubexpressions(ibe.Left, replace); + be.Right = ReplaceSubexpressions(ibe.Right, replace); return be; } else if (expr is IMethodReferenceExpression imre) { - var target = ReplaceExpression(imre.Target, exprFind, exprReplace, ref replaceCount); + var target = ReplaceSubexpressions(imre.Target, replace); return MethodRefExpr(imre.Method, target); } - else if (expr is IThisReferenceExpression) - { - return expr; - } else if (expr is IAddressOutExpression iaoe) { IAddressOutExpression aoe = AddrOutExpr(); - aoe.Expression = ReplaceExpression(iaoe.Expression, exprFind, exprReplace, ref replaceCount); + aoe.Expression = ReplaceSubexpressions(iaoe.Expression, replace); return aoe; } - else throw new NotImplementedException("Unhandled expression type in ReplaceExpression(): " + expr.GetType()); + else throw new NotImplementedException("Unhandled expression type in ReplaceSubexpressions(): " + expr.GetType()); } ///