Refactor how rewrite.py passes delayed match application (for typing & clarity) (#921)

* Missing import

* Using a dict literal so Dict type

* The Wrong Type
with @awf

* Clarify typing on matcher

* Extra type annotation

* filter_term.py type checks now

* Be more explicit for mypy

* Drop the set change,
didn't really help anyway

* SKETCH, doesn't run.
Can we solve this with a closure instead of kwargs approach

* Put visitors back.
We might want this too, but keep separate

* Clarify operation

* Carried through change to match
untested
not tidied up

* Keep rule for reference
https://github.com/microsoft/knossos-ksc/pull/921#pullrequestreview-701034632

* Rename back to Match now it's clear

* Add link to mypy lack of support

* Refactor for clarity
Revisit if profiler gets hot here
with @awf

* Clear comment

* Comma fixes for Alan

* Don't need the encapsulation now
This commit is contained in:
Colin Gravill 2021-07-14 17:44:59 +01:00 коммит произвёл GitHub
Родитель 432e72f6e1
Коммит 8790b5ae72
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файлов: 95 добавлений и 92 удалений

Просмотреть файл

@ -12,6 +12,7 @@ from typing import (
Mapping,
Optional,
Tuple,
Callable,
)
from pyrsistent import pmap
@ -62,14 +63,9 @@ from ksc.visitors import ExprTransformer
@dataclass(frozen=True)
class Match:
rule: "RuleMatcher"
apply_rewrite: Callable[[], Expr]
ewp: ExprWithPath
# Anything the RuleMatcher needs to pass from matching to rewriting.
rule_specific_data: Mapping[str, Any] = pmap()
def apply_rewrite(self):
return self.rule.apply_at(self.ewp, **self.rule_specific_data)
@property
def path(self):
return self.ewp.path
@ -112,7 +108,7 @@ class AbstractMatcher(ABC):
yield from self._matches_with_env(ch, env)
@abstractmethod
def matches_here(self, ewp: ExprWithPath, env: Environment,) -> Iterator[Match]:
def matches_here(self, ewp: ExprWithPath, env: Environment) -> Iterator[Match]:
""" Return any matches which rewrite the topmost node of the specified subtree """
@ -143,10 +139,6 @@ class RuleMatcher(AbstractMatcher):
""" If a RuleMatcher (instance or subclass) returns true, indicates that it might match
any Expr which is a Call, regardless of the value of get_filter_term() on that Expr. """
@abstractmethod
def apply_at(self, ewp: ExprWithPath, **kwargs) -> Expr:
""" Applies this rule at the specified <path> within <expr>. kwargs are any stored in the Match's rule_specific_data field. """
@abstractmethod
def matches_for_possible_expr(
self, ewp: ExprWithPath, env: Environment,
@ -154,7 +146,7 @@ class RuleMatcher(AbstractMatcher):
""" Returns any 'Match's acting on the topmost node of the specified Expr, given that <get_filter_term(expr)>
is of one of <self.possible_filter_terms>. """
def matches_here(self, ewp: ExprWithPath, env: Environment,) -> Iterator[Match]:
def matches_here(self, ewp: ExprWithPath, env: Environment) -> Iterator[Match]:
if get_filter_term(ewp.expr) in self.possible_filter_terms or (
isinstance(ewp.expr, Call) and self.may_match_any_call
):
@ -181,7 +173,7 @@ class RuleSet(AbstractMatcher):
for term in rule.possible_filter_terms:
self._rules_by_filter_term.setdefault(term, []).append(rule)
def matches_here(self, ewp: ExprWithPath, env: Environment,) -> Iterator[Match]:
def matches_here(self, ewp: ExprWithPath, env: Environment) -> Iterator[Match]:
possible_rules: Iterable[RuleMatcher] = self._rules_by_filter_term.get(
get_filter_term(ewp.expr), []
)
@ -195,28 +187,28 @@ class RuleSet(AbstractMatcher):
class inline_var(RuleMatcher):
possible_filter_terms = frozenset([Var])
def apply_at(self, ewp: ExprWithPath, binding_location: Path) -> Expr:
# binding_location comes from the Match.
# Note there is an alternative design, where we don't store any "rule_specific_data" in the Match.
# Thus, at application time (here), we would have to first do an extra traversal all the way down path_to_var, to identify which variable to inline (and its binding location).
# (Followed by the same traversal as here, that does renaming-to-avoid-capture from the binding location to the variable usage.)
assert ewp.path[: len(binding_location)] == binding_location
return replace_subtree(
ewp.root,
binding_location,
Const(0.0), # Nothing to avoid capturing in outer call
lambda _zero, let: replace_subtree(
let, ewp.path[len(binding_location) :], let.rhs
), # No applicator; renaming will prevent capturing let.rhs, so just insert that
)
def matches_for_possible_expr(
self, ewp: ExprWithPath, env: Environment,
) -> Iterator[Match]:
assert isinstance(ewp.expr, Var)
binding_location = env.let_vars.get(ewp.expr.name)
if binding_location is not None:
yield Match(self, ewp, {"binding_location": binding_location})
if ewp.expr.name not in env.let_vars:
return
binding_location: Path = env.let_vars[ewp.expr.name]
def apply() -> Expr:
assert ewp.path[: len(binding_location)] == binding_location
return replace_subtree(
ewp.root,
binding_location,
Const(0.0), # Nothing to avoid capturing in outer call
lambda _zero, let: replace_subtree(
let, ewp.path[len(binding_location) :], let.rhs
), # No applicator; renaming will prevent capturing let.rhs, so just insert that
)
yield Match(ewp=ewp, rule=self, apply_rewrite=apply)
@singleton
@ -227,56 +219,61 @@ class inline_call(RuleMatcher):
def matches_for_possible_expr(
self, ewp: ExprWithPath, env: Environment
) -> Iterator[Match]:
func_def: Optional[Def] = env.defs.get(ewp.expr.name)
if func_def is not None:
yield Match(self, ewp, {"func_def": func_def})
if ewp.expr.name not in env.defs:
return
def apply_at(self, ewp: ExprWithPath, func_def: Def) -> Expr:
# func_def comes from the Match.
def apply_here(const_zero, call_node):
call_arg = (
call_node.args[0]
if len(call_node.args) == 1
else make_prim_call(StructuredName.from_str("tuple"), call_node.args)
)
return (
Let(func_def.args[0], call_arg, func_def.body)
if len(func_def.args) == 1
else untuple_one_let(Let(func_def.args, call_arg, func_def.body))
)
func_def: Def = env.defs[ewp.expr.name]
arg_names = frozenset([arg.name for arg in func_def.args])
assert func_def.body.free_vars_.issubset(
arg_names
) # Sets may not be equal, if some args unused.
# There is thus nothing in the function body that could be captured by 'let's around the callsite.
# Thus, TODO: simplify interface to replace_subtree: only inline_let requires capture-avoidance, and
# there is no need to support both capture-avoidance + applicator in the same call to replace_subtree.
# In the meantime, the 0.0 here (as elsewhere) indicates there are no variables to avoid capturing.
return replace_subtree(ewp.root, ewp.path, Const(0.0), apply_here)
def apply() -> Expr:
def apply_here(const_zero, call_node):
call_arg = (
call_node.args[0]
if len(call_node.args) == 1
else make_prim_call(
StructuredName.from_str("tuple"), call_node.args
)
)
return (
Let(func_def.args[0], call_arg, func_def.body)
if len(func_def.args) == 1
else untuple_one_let(Let(func_def.args, call_arg, func_def.body))
)
arg_names = frozenset([arg.name for arg in func_def.args])
assert func_def.body.free_vars_.issubset(
arg_names
) # Sets may not be equal, if some args unused.
# There is thus nothing in the function body that could be captured by 'let's around the callsite.
# Thus, TODO: simplify interface to replace_subtree: only inline_let requires capture-avoidance, and
# there is no need to support both capture-avoidance + applicator in the same call to replace_subtree.
# In the meantime, the 0.0 here (as elsewhere) indicates there are no variables to avoid capturing.
return replace_subtree(ewp.root, ewp.path, Const(0.0), apply_here)
yield Match(ewp=ewp, rule=self, apply_rewrite=apply)
@singleton
class delete_let(RuleMatcher):
possible_filter_terms = frozenset([Let])
def apply_at(self, ewp: ExprWithPath) -> Expr:
def apply_here(const_zero: Expr, let_node: Expr) -> Expr:
assert const_zero == Const(0.0) # Passed to replace_subtree below
assert let_node is ewp.expr
assert isinstance(let_node, Let)
assert let_node.vars.name not in let_node.body.free_vars_
return let_node.body
# The constant just has no free variables that we want to avoid being captured
return replace_subtree(ewp.root, ewp.path, Const(0.0), apply_here)
def matches_for_possible_expr(
self, ewp: ExprWithPath, env: Environment
) -> Iterator[Match]:
assert isinstance(ewp.expr, Let)
if ewp.vars.name not in ewp.body.free_vars_:
yield Match(self, ewp)
def apply() -> Expr:
def apply_here(const_zero: Expr, let_node: Expr) -> Expr:
assert const_zero == Const(0.0) # Passed to replace_subtree below
assert let_node is ewp.expr
assert isinstance(let_node, Let)
assert let_node.vars.name not in let_node.body.free_vars_
return let_node.body
# The constant just has no free variables that we want to avoid being captured
return replace_subtree(ewp.root, ewp.path, Const(0.0), apply_here)
yield Match(ewp=ewp, rule=self, apply_rewrite=apply)
###############################################################################
@ -330,22 +327,23 @@ class ParsedRuleMatcher(RuleMatcher):
# the result will then be replacement[subst].
substs = find_template_subst(self._rule.template, ewp.expr, self._arg_types)
if substs is not None and self._side_conditions(**substs):
yield Match(self, ewp, substs)
def apply_at(self, ewp: ExprWithPath, **substs: VariableSubstitution) -> Expr:
def apply_here(const_zero: Expr, target: Expr) -> Expr:
assert const_zero == Const(0.0) # Passed to replace_subtree below
assert are_alpha_equivalent(
SubstPattern.visit(self._rule.template, substs), target
) # Note this traverses, so expensive.
result = SubstPattern.visit(self._rule.replacement, substs)
# Types copied from the template (down to the variables, and the subject-expr's types from there).
# So there should be no need for any further type-propagation.
assert result.type_ == target.type_
return result
def apply() -> Expr:
def apply_here(const_zero: Expr, target: Expr) -> Expr:
assert const_zero == Const(0.0) # Passed to replace_subtree below
assert are_alpha_equivalent(
SubstPattern.visit(self._rule.template, substs), target
) # Note this traverses, so expensive.
result = SubstPattern.visit(self._rule.replacement, substs)
# Types copied from the template (down to the variables, and the subject-expr's types from there).
# So there should be no need for any further type-propagation.
assert result.type_ == target.type_
return result
# The constant just has no free variables that we want to avoid being captured
return replace_subtree(ewp.root, ewp.path, Const(0.0), apply_here)
# The constant just has no free variables that we want to avoid being captured
return replace_subtree(ewp.root, ewp.path, Const(0.0), apply_here)
yield Match(ewp=ewp, rule=self, apply_rewrite=apply)
def _combine_substs(
@ -509,7 +507,7 @@ class SubstPattern(ExprTransformer):
target_var, var_names_to_exprs = _maybe_add_binder_to_subst(
l.arg, var_names_to_exprs, [l.body]
)
return Lam(target_var, self.visit(l.body, var_names_to_exprs), type=l.type_,)
return Lam(target_var, self.visit(l.body, var_names_to_exprs), type=l.type_)
def parse_rule_str(ks_str, symtab, **kwargs):

Просмотреть файл

@ -22,20 +22,25 @@ class ConstantFolder(RuleMatcher):
def possible_filter_terms(self) -> FrozenSet[FilterTerm]:
return frozenset([self._name])
def apply_at(self, ewp: ExprWithPath, **kwargs) -> Expr:
def apply_here(const_zero: Expr, subtree: Expr):
assert const_zero == Const(0.0) # Payload passed to replace_subtree below
assert isinstance(subtree, Call) and subtree.name == self._name
return Const(self._native_impl(*[arg.value for arg in subtree.args]))
return replace_subtree(ewp.root, ewp.path, Const(0.0), apply_here)
def matches_for_possible_expr(
self, ewp: ExprWithPath, env: Environment,
) -> Iterator[Match]:
assert isinstance(ewp.expr, Call) and ewp.name == self._name
if all(isinstance(arg, Const) for arg in ewp.expr.args):
yield Match(self, ewp)
def apply() -> Expr:
def apply_here(const_zero: Expr, subtree: Expr):
assert const_zero == Const(
0.0
) # Payload passed to replace_subtree below
assert isinstance(subtree, Call) and subtree.name == self._name
return Const(
self._native_impl(*[arg.value for arg in subtree.args])
)
return replace_subtree(ewp.root, ewp.path, Const(0.0), apply_here)
yield Match(ewp=ewp, rule=self, apply_rewrite=apply)
constant_folding_rules = [ConstantFolder(sn, func) for sn, func in native_impls.items()]