Refactor how rewrite.py passes delayed match application (for typing & clarity) (#921)
* Missing import * Using a dict literal so Dict type * The Wrong Type with @awf * Clarify typing on matcher * Extra type annotation * filter_term.py type checks now * Be more explicit for mypy * Drop the set change, didn't really help anyway * SKETCH, doesn't run. Can we solve this with a closure instead of kwargs approach * Put visitors back. We might want this too, but keep separate * Clarify operation * Carried through change to match untested not tidied up * Keep rule for reference https://github.com/microsoft/knossos-ksc/pull/921#pullrequestreview-701034632 * Rename back to Match now it's clear * Add link to mypy lack of support * Refactor for clarity Revisit if profiler gets hot here with @awf * Clear comment * Comma fixes for Alan * Don't need the encapsulation now
This commit is contained in:
Родитель
432e72f6e1
Коммит
8790b5ae72
|
@ -12,6 +12,7 @@ from typing import (
|
|||
Mapping,
|
||||
Optional,
|
||||
Tuple,
|
||||
Callable,
|
||||
)
|
||||
|
||||
from pyrsistent import pmap
|
||||
|
@ -62,14 +63,9 @@ from ksc.visitors import ExprTransformer
|
|||
@dataclass(frozen=True)
|
||||
class Match:
|
||||
rule: "RuleMatcher"
|
||||
apply_rewrite: Callable[[], Expr]
|
||||
ewp: ExprWithPath
|
||||
|
||||
# Anything the RuleMatcher needs to pass from matching to rewriting.
|
||||
rule_specific_data: Mapping[str, Any] = pmap()
|
||||
|
||||
def apply_rewrite(self):
|
||||
return self.rule.apply_at(self.ewp, **self.rule_specific_data)
|
||||
|
||||
@property
|
||||
def path(self):
|
||||
return self.ewp.path
|
||||
|
@ -112,7 +108,7 @@ class AbstractMatcher(ABC):
|
|||
yield from self._matches_with_env(ch, env)
|
||||
|
||||
@abstractmethod
|
||||
def matches_here(self, ewp: ExprWithPath, env: Environment,) -> Iterator[Match]:
|
||||
def matches_here(self, ewp: ExprWithPath, env: Environment) -> Iterator[Match]:
|
||||
""" Return any matches which rewrite the topmost node of the specified subtree """
|
||||
|
||||
|
||||
|
@ -143,10 +139,6 @@ class RuleMatcher(AbstractMatcher):
|
|||
""" If a RuleMatcher (instance or subclass) returns true, indicates that it might match
|
||||
any Expr which is a Call, regardless of the value of get_filter_term() on that Expr. """
|
||||
|
||||
@abstractmethod
|
||||
def apply_at(self, ewp: ExprWithPath, **kwargs) -> Expr:
|
||||
""" Applies this rule at the specified <path> within <expr>. kwargs are any stored in the Match's rule_specific_data field. """
|
||||
|
||||
@abstractmethod
|
||||
def matches_for_possible_expr(
|
||||
self, ewp: ExprWithPath, env: Environment,
|
||||
|
@ -154,7 +146,7 @@ class RuleMatcher(AbstractMatcher):
|
|||
""" Returns any 'Match's acting on the topmost node of the specified Expr, given that <get_filter_term(expr)>
|
||||
is of one of <self.possible_filter_terms>. """
|
||||
|
||||
def matches_here(self, ewp: ExprWithPath, env: Environment,) -> Iterator[Match]:
|
||||
def matches_here(self, ewp: ExprWithPath, env: Environment) -> Iterator[Match]:
|
||||
if get_filter_term(ewp.expr) in self.possible_filter_terms or (
|
||||
isinstance(ewp.expr, Call) and self.may_match_any_call
|
||||
):
|
||||
|
@ -181,7 +173,7 @@ class RuleSet(AbstractMatcher):
|
|||
for term in rule.possible_filter_terms:
|
||||
self._rules_by_filter_term.setdefault(term, []).append(rule)
|
||||
|
||||
def matches_here(self, ewp: ExprWithPath, env: Environment,) -> Iterator[Match]:
|
||||
def matches_here(self, ewp: ExprWithPath, env: Environment) -> Iterator[Match]:
|
||||
possible_rules: Iterable[RuleMatcher] = self._rules_by_filter_term.get(
|
||||
get_filter_term(ewp.expr), []
|
||||
)
|
||||
|
@ -195,28 +187,28 @@ class RuleSet(AbstractMatcher):
|
|||
class inline_var(RuleMatcher):
|
||||
possible_filter_terms = frozenset([Var])
|
||||
|
||||
def apply_at(self, ewp: ExprWithPath, binding_location: Path) -> Expr:
|
||||
# binding_location comes from the Match.
|
||||
# Note there is an alternative design, where we don't store any "rule_specific_data" in the Match.
|
||||
# Thus, at application time (here), we would have to first do an extra traversal all the way down path_to_var, to identify which variable to inline (and its binding location).
|
||||
# (Followed by the same traversal as here, that does renaming-to-avoid-capture from the binding location to the variable usage.)
|
||||
assert ewp.path[: len(binding_location)] == binding_location
|
||||
return replace_subtree(
|
||||
ewp.root,
|
||||
binding_location,
|
||||
Const(0.0), # Nothing to avoid capturing in outer call
|
||||
lambda _zero, let: replace_subtree(
|
||||
let, ewp.path[len(binding_location) :], let.rhs
|
||||
), # No applicator; renaming will prevent capturing let.rhs, so just insert that
|
||||
)
|
||||
|
||||
def matches_for_possible_expr(
|
||||
self, ewp: ExprWithPath, env: Environment,
|
||||
) -> Iterator[Match]:
|
||||
assert isinstance(ewp.expr, Var)
|
||||
binding_location = env.let_vars.get(ewp.expr.name)
|
||||
if binding_location is not None:
|
||||
yield Match(self, ewp, {"binding_location": binding_location})
|
||||
|
||||
if ewp.expr.name not in env.let_vars:
|
||||
return
|
||||
|
||||
binding_location: Path = env.let_vars[ewp.expr.name]
|
||||
|
||||
def apply() -> Expr:
|
||||
assert ewp.path[: len(binding_location)] == binding_location
|
||||
return replace_subtree(
|
||||
ewp.root,
|
||||
binding_location,
|
||||
Const(0.0), # Nothing to avoid capturing in outer call
|
||||
lambda _zero, let: replace_subtree(
|
||||
let, ewp.path[len(binding_location) :], let.rhs
|
||||
), # No applicator; renaming will prevent capturing let.rhs, so just insert that
|
||||
)
|
||||
|
||||
yield Match(ewp=ewp, rule=self, apply_rewrite=apply)
|
||||
|
||||
|
||||
@singleton
|
||||
|
@ -227,56 +219,61 @@ class inline_call(RuleMatcher):
|
|||
def matches_for_possible_expr(
|
||||
self, ewp: ExprWithPath, env: Environment
|
||||
) -> Iterator[Match]:
|
||||
func_def: Optional[Def] = env.defs.get(ewp.expr.name)
|
||||
if func_def is not None:
|
||||
yield Match(self, ewp, {"func_def": func_def})
|
||||
if ewp.expr.name not in env.defs:
|
||||
return
|
||||
|
||||
def apply_at(self, ewp: ExprWithPath, func_def: Def) -> Expr:
|
||||
# func_def comes from the Match.
|
||||
def apply_here(const_zero, call_node):
|
||||
call_arg = (
|
||||
call_node.args[0]
|
||||
if len(call_node.args) == 1
|
||||
else make_prim_call(StructuredName.from_str("tuple"), call_node.args)
|
||||
)
|
||||
return (
|
||||
Let(func_def.args[0], call_arg, func_def.body)
|
||||
if len(func_def.args) == 1
|
||||
else untuple_one_let(Let(func_def.args, call_arg, func_def.body))
|
||||
)
|
||||
func_def: Def = env.defs[ewp.expr.name]
|
||||
|
||||
arg_names = frozenset([arg.name for arg in func_def.args])
|
||||
assert func_def.body.free_vars_.issubset(
|
||||
arg_names
|
||||
) # Sets may not be equal, if some args unused.
|
||||
# There is thus nothing in the function body that could be captured by 'let's around the callsite.
|
||||
# Thus, TODO: simplify interface to replace_subtree: only inline_let requires capture-avoidance, and
|
||||
# there is no need to support both capture-avoidance + applicator in the same call to replace_subtree.
|
||||
# In the meantime, the 0.0 here (as elsewhere) indicates there are no variables to avoid capturing.
|
||||
return replace_subtree(ewp.root, ewp.path, Const(0.0), apply_here)
|
||||
def apply() -> Expr:
|
||||
def apply_here(const_zero, call_node):
|
||||
call_arg = (
|
||||
call_node.args[0]
|
||||
if len(call_node.args) == 1
|
||||
else make_prim_call(
|
||||
StructuredName.from_str("tuple"), call_node.args
|
||||
)
|
||||
)
|
||||
return (
|
||||
Let(func_def.args[0], call_arg, func_def.body)
|
||||
if len(func_def.args) == 1
|
||||
else untuple_one_let(Let(func_def.args, call_arg, func_def.body))
|
||||
)
|
||||
|
||||
arg_names = frozenset([arg.name for arg in func_def.args])
|
||||
assert func_def.body.free_vars_.issubset(
|
||||
arg_names
|
||||
) # Sets may not be equal, if some args unused.
|
||||
# There is thus nothing in the function body that could be captured by 'let's around the callsite.
|
||||
# Thus, TODO: simplify interface to replace_subtree: only inline_let requires capture-avoidance, and
|
||||
# there is no need to support both capture-avoidance + applicator in the same call to replace_subtree.
|
||||
# In the meantime, the 0.0 here (as elsewhere) indicates there are no variables to avoid capturing.
|
||||
return replace_subtree(ewp.root, ewp.path, Const(0.0), apply_here)
|
||||
|
||||
yield Match(ewp=ewp, rule=self, apply_rewrite=apply)
|
||||
|
||||
|
||||
@singleton
|
||||
class delete_let(RuleMatcher):
|
||||
possible_filter_terms = frozenset([Let])
|
||||
|
||||
def apply_at(self, ewp: ExprWithPath) -> Expr:
|
||||
def apply_here(const_zero: Expr, let_node: Expr) -> Expr:
|
||||
assert const_zero == Const(0.0) # Passed to replace_subtree below
|
||||
assert let_node is ewp.expr
|
||||
assert isinstance(let_node, Let)
|
||||
assert let_node.vars.name not in let_node.body.free_vars_
|
||||
return let_node.body
|
||||
|
||||
# The constant just has no free variables that we want to avoid being captured
|
||||
return replace_subtree(ewp.root, ewp.path, Const(0.0), apply_here)
|
||||
|
||||
def matches_for_possible_expr(
|
||||
self, ewp: ExprWithPath, env: Environment
|
||||
) -> Iterator[Match]:
|
||||
assert isinstance(ewp.expr, Let)
|
||||
if ewp.vars.name not in ewp.body.free_vars_:
|
||||
yield Match(self, ewp)
|
||||
|
||||
def apply() -> Expr:
|
||||
def apply_here(const_zero: Expr, let_node: Expr) -> Expr:
|
||||
assert const_zero == Const(0.0) # Passed to replace_subtree below
|
||||
assert let_node is ewp.expr
|
||||
assert isinstance(let_node, Let)
|
||||
assert let_node.vars.name not in let_node.body.free_vars_
|
||||
return let_node.body
|
||||
|
||||
# The constant just has no free variables that we want to avoid being captured
|
||||
return replace_subtree(ewp.root, ewp.path, Const(0.0), apply_here)
|
||||
|
||||
yield Match(ewp=ewp, rule=self, apply_rewrite=apply)
|
||||
|
||||
|
||||
###############################################################################
|
||||
|
@ -330,22 +327,23 @@ class ParsedRuleMatcher(RuleMatcher):
|
|||
# the result will then be replacement[subst].
|
||||
substs = find_template_subst(self._rule.template, ewp.expr, self._arg_types)
|
||||
if substs is not None and self._side_conditions(**substs):
|
||||
yield Match(self, ewp, substs)
|
||||
|
||||
def apply_at(self, ewp: ExprWithPath, **substs: VariableSubstitution) -> Expr:
|
||||
def apply_here(const_zero: Expr, target: Expr) -> Expr:
|
||||
assert const_zero == Const(0.0) # Passed to replace_subtree below
|
||||
assert are_alpha_equivalent(
|
||||
SubstPattern.visit(self._rule.template, substs), target
|
||||
) # Note this traverses, so expensive.
|
||||
result = SubstPattern.visit(self._rule.replacement, substs)
|
||||
# Types copied from the template (down to the variables, and the subject-expr's types from there).
|
||||
# So there should be no need for any further type-propagation.
|
||||
assert result.type_ == target.type_
|
||||
return result
|
||||
def apply() -> Expr:
|
||||
def apply_here(const_zero: Expr, target: Expr) -> Expr:
|
||||
assert const_zero == Const(0.0) # Passed to replace_subtree below
|
||||
assert are_alpha_equivalent(
|
||||
SubstPattern.visit(self._rule.template, substs), target
|
||||
) # Note this traverses, so expensive.
|
||||
result = SubstPattern.visit(self._rule.replacement, substs)
|
||||
# Types copied from the template (down to the variables, and the subject-expr's types from there).
|
||||
# So there should be no need for any further type-propagation.
|
||||
assert result.type_ == target.type_
|
||||
return result
|
||||
|
||||
# The constant just has no free variables that we want to avoid being captured
|
||||
return replace_subtree(ewp.root, ewp.path, Const(0.0), apply_here)
|
||||
# The constant just has no free variables that we want to avoid being captured
|
||||
return replace_subtree(ewp.root, ewp.path, Const(0.0), apply_here)
|
||||
|
||||
yield Match(ewp=ewp, rule=self, apply_rewrite=apply)
|
||||
|
||||
|
||||
def _combine_substs(
|
||||
|
@ -509,7 +507,7 @@ class SubstPattern(ExprTransformer):
|
|||
target_var, var_names_to_exprs = _maybe_add_binder_to_subst(
|
||||
l.arg, var_names_to_exprs, [l.body]
|
||||
)
|
||||
return Lam(target_var, self.visit(l.body, var_names_to_exprs), type=l.type_,)
|
||||
return Lam(target_var, self.visit(l.body, var_names_to_exprs), type=l.type_)
|
||||
|
||||
|
||||
def parse_rule_str(ks_str, symtab, **kwargs):
|
||||
|
|
|
@ -22,20 +22,25 @@ class ConstantFolder(RuleMatcher):
|
|||
def possible_filter_terms(self) -> FrozenSet[FilterTerm]:
|
||||
return frozenset([self._name])
|
||||
|
||||
def apply_at(self, ewp: ExprWithPath, **kwargs) -> Expr:
|
||||
def apply_here(const_zero: Expr, subtree: Expr):
|
||||
assert const_zero == Const(0.0) # Payload passed to replace_subtree below
|
||||
assert isinstance(subtree, Call) and subtree.name == self._name
|
||||
return Const(self._native_impl(*[arg.value for arg in subtree.args]))
|
||||
|
||||
return replace_subtree(ewp.root, ewp.path, Const(0.0), apply_here)
|
||||
|
||||
def matches_for_possible_expr(
|
||||
self, ewp: ExprWithPath, env: Environment,
|
||||
) -> Iterator[Match]:
|
||||
assert isinstance(ewp.expr, Call) and ewp.name == self._name
|
||||
if all(isinstance(arg, Const) for arg in ewp.expr.args):
|
||||
yield Match(self, ewp)
|
||||
|
||||
def apply() -> Expr:
|
||||
def apply_here(const_zero: Expr, subtree: Expr):
|
||||
assert const_zero == Const(
|
||||
0.0
|
||||
) # Payload passed to replace_subtree below
|
||||
assert isinstance(subtree, Call) and subtree.name == self._name
|
||||
return Const(
|
||||
self._native_impl(*[arg.value for arg in subtree.args])
|
||||
)
|
||||
|
||||
return replace_subtree(ewp.root, ewp.path, Const(0.0), apply_here)
|
||||
|
||||
yield Match(ewp=ewp, rule=self, apply_rewrite=apply)
|
||||
|
||||
|
||||
constant_folding_rules = [ConstantFolder(sn, func) for sn, func in native_impls.items()]
|
||||
|
|
Загрузка…
Ссылка в новой задаче