add expr simplify and canonical
This commit is contained in:
Родитель
77345051bd
Коммит
fc4ba79621
|
@ -1,9 +1,13 @@
|
|||
"""Base class of symbolic expression"""
|
||||
from __future__ import absolute_import as _abs
|
||||
from numbers import Number as _Number
|
||||
from . import op as _op
|
||||
from . import var_name as _name
|
||||
|
||||
__addop__ = None
|
||||
__subop__ = None
|
||||
__mulop__ = None
|
||||
__divop__ = None
|
||||
|
||||
class Expr(object):
|
||||
"""Base class of expression.
|
||||
|
||||
|
@ -20,28 +24,28 @@ class Expr(object):
|
|||
return ()
|
||||
|
||||
def __add__(self, other):
|
||||
return BinaryOpExpr(_op.add, self, other)
|
||||
return BinaryOpExpr(__addop__, self, other)
|
||||
|
||||
def __radd__(self, other):
|
||||
return self.__add__(other)
|
||||
|
||||
def __sub__(self, other):
|
||||
return BinaryOpExpr(_op.sub, self, other)
|
||||
return BinaryOpExpr(__subop__, self, other)
|
||||
|
||||
def __rsub__(self, other):
|
||||
return BinaryOpExpr(_op.sub, other, self)
|
||||
return BinaryOpExpr(__subop__, other, self)
|
||||
|
||||
def __mul__(self, other):
|
||||
return BinaryOpExpr(_op.mul, self, other)
|
||||
return BinaryOpExpr(__mulop__, self, other)
|
||||
|
||||
def __rmul__(self, other):
|
||||
return BinaryOpExpr(_op.mul, other, self)
|
||||
return BinaryOpExpr(__mulop__, other, self)
|
||||
|
||||
def __div__(self, other):
|
||||
return BinaryOpExpr(_op.div, self, other)
|
||||
return BinaryOpExpr(__divop__, self, other)
|
||||
|
||||
def __rdiv__(self, other):
|
||||
return BinaryOpExpr(_op.div, other, self)
|
||||
return BinaryOpExpr(__divop__, other, self)
|
||||
|
||||
def __truediv__(self, other):
|
||||
return self.__div__(other)
|
||||
|
@ -75,7 +79,8 @@ class Var(Expr):
|
|||
optional name to the var.
|
||||
"""
|
||||
def __init__(self, name=None):
|
||||
self.name = name if name else _name.NameManager.current.get(name)
|
||||
if name is None: name = 'i'
|
||||
self.name = _name.NameManager.current.get(name)
|
||||
|
||||
|
||||
class ConstExpr(Expr):
|
||||
|
@ -95,8 +100,7 @@ class BinaryOpExpr(Expr):
|
|||
def children(self):
|
||||
return (self.lhs, self.rhs)
|
||||
|
||||
_op.binary_op_cls = BinaryOpExpr
|
||||
|
||||
|
||||
class UnaryOpExpr(Expr):
|
||||
"""Unary operator expression."""
|
||||
def __init__(self, op, src):
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""Utilities to manipulate expression"""
|
||||
from __future__ import absolute_import as _abs
|
||||
from . import expr as _expr
|
||||
from . import op as _op
|
||||
|
||||
def expr_with_new_children(e, children):
|
||||
"""Returns same expr as e but with new children
|
||||
|
@ -48,6 +49,7 @@ def transform(e, f):
|
|||
result : return value of f
|
||||
The final result of transformation.
|
||||
"""
|
||||
assert isinstance(e, _expr.Expr)
|
||||
return f(e , [transform(c, f) for c in e.children()])
|
||||
|
||||
|
||||
|
@ -77,6 +79,32 @@ def format_str(expr):
|
|||
raise TypeError("Do not know how to handle type " + str(type(e)))
|
||||
return transform(expr, make_str)
|
||||
|
||||
def simplify(expr):
|
||||
"""simplify expression
|
||||
|
||||
Parameters
|
||||
----------
|
||||
expr : Expr
|
||||
Input expression
|
||||
|
||||
Returns
|
||||
-------
|
||||
e : Expr
|
||||
Simplified expression
|
||||
"""
|
||||
def canonical(e, result_children):
|
||||
if isinstance(e, _expr.BinaryOpExpr):
|
||||
return e.op.canonical(result_children[0], result_children[1])
|
||||
elif isinstance(e, _expr.UnaryOpExpr):
|
||||
return e.op.canonical(result_children[0])
|
||||
elif isinstance(e, _expr.ConstExpr):
|
||||
return {_op.constant_canonical_key: e.value}
|
||||
elif isinstance(e, _expr.Var):
|
||||
return {e: 1}
|
||||
else:
|
||||
raise TypeError("Do not know how to handle type " + str(type(e)))
|
||||
return _op.canonical_to_expr(transform(expr, canonical))
|
||||
|
||||
|
||||
def bind(expr, update_dict):
|
||||
"""Replace the variable in e by specification from kwarg
|
||||
|
|
|
@ -1,30 +1,122 @@
|
|||
from __future__ import absolute_import as _abs
|
||||
from . import expr as _expr
|
||||
|
||||
_binary_op_cls = None
|
||||
constant_canonical_key = '__constant__'
|
||||
|
||||
def canonical_to_expr(c):
|
||||
elements = []
|
||||
for k, v in sorted(c.items()):
|
||||
if k == constant_canonical_key:
|
||||
elements.append(_expr.const(v))
|
||||
elif v == 0:
|
||||
continue
|
||||
elif v == 1:
|
||||
elements.append(k)
|
||||
else:
|
||||
elements.append(k * v)
|
||||
if elements:
|
||||
expr = elements[0]
|
||||
for i in range(1, len(elements)):
|
||||
expr = expr + elements[i]
|
||||
return expr
|
||||
else:
|
||||
return _expr.const(0)
|
||||
|
||||
class BinaryOp(object):
|
||||
"""Base class of binary operator"""
|
||||
def __call__(self, lhs, rhs):
|
||||
return _binary_op_cls(self, lhs, rhs)
|
||||
return _expr.BinaryOpExpr(self, lhs, rhs)
|
||||
|
||||
class AddOp(BinaryOp):
|
||||
def format_str(self, lhs, rhs):
|
||||
return '(%s + %s)' % (lhs, rhs)
|
||||
|
||||
def canonical(self, lhs, rhs):
|
||||
lhs = lhs.copy()
|
||||
for k, v in rhs.items():
|
||||
if k in lhs:
|
||||
lhs[k] += v
|
||||
else:
|
||||
lhs[k] = v
|
||||
return lhs
|
||||
|
||||
class SubOp(BinaryOp):
|
||||
def format_str(self, lhs, rhs):
|
||||
return '(%s - %s)' % (lhs, rhs)
|
||||
|
||||
def canonical(self, lhs, rhs):
|
||||
lhs = lhs.copy()
|
||||
for k, v in rhs.items():
|
||||
if k in lhs:
|
||||
lhs[k] -= v
|
||||
else:
|
||||
lhs[k] = -v
|
||||
return lhs
|
||||
|
||||
class MulOp(BinaryOp):
|
||||
def format_str(self, lhs, rhs):
|
||||
return '(%s * %s)' % (lhs, rhs)
|
||||
|
||||
def canonical(self, lhs, rhs):
|
||||
elhs = canonical_to_expr(lhs)
|
||||
erhs = canonical_to_expr(rhs)
|
||||
if isinstance(erhs, _expr.ConstExpr):
|
||||
lhs = lhs.copy()
|
||||
for k, v in lhs.items():
|
||||
lhs[k] *= erhs.value
|
||||
return lhs
|
||||
if isinstance(elhs, _expr.ConstExpr):
|
||||
rhs = rhs.copy()
|
||||
for k, v in rhs.items():
|
||||
rhs[k] *= elhs.value
|
||||
return rhs
|
||||
return {elhs * erhs: 1}
|
||||
|
||||
class DivOp(BinaryOp):
|
||||
def format_str(self, lhs, rhs):
|
||||
return '(%s / %s)' % (lhs, rhs)
|
||||
|
||||
def canonical(self, lhs, rhs):
|
||||
erhs = canonical_to_expr(rhs)
|
||||
if isinstance(erhs, _expr.ConstExpr):
|
||||
lhs = lhs.copy()
|
||||
for k, v in lhs.items():
|
||||
lhs[k] /= erhs.value
|
||||
return lhs
|
||||
elhs = canonical_to_expr(lhs)
|
||||
return {elhs / erhs: 1}
|
||||
|
||||
class MaxOp(BinaryOp):
|
||||
def format_str(self, lhs, rhs):
|
||||
return 'max(%s, %s)' % (lhs, rhs)
|
||||
|
||||
def canonical(self, lhs, rhs):
|
||||
diff = SubOp().canonical(lhs, rhs)
|
||||
ediff = canonical_to_expr(diff)
|
||||
if isinstance(ediff, _expr.ConstExpr):
|
||||
return lhs if ediff.value >= 0 else rhs
|
||||
return {MaxOp()(lhs, rhs): 1}
|
||||
|
||||
class MinOp(BinaryOp):
|
||||
def format_str(self, lhs, rhs):
|
||||
return 'min(%s, %s)' % (lhs, rhs)
|
||||
|
||||
def canonical(self, lhs, rhs):
|
||||
diff = SubOp().canonical(lhs, rhs)
|
||||
ediff = canonical_to_expr(diff)
|
||||
if isinstance(ediff, _expr.ConstExpr):
|
||||
return rhs if ediff.value >= 0 else lhs
|
||||
return {MinOp()(lhs, rhs): 1}
|
||||
|
||||
|
||||
add = AddOp()
|
||||
sub = SubOp()
|
||||
mul = MulOp()
|
||||
div = DivOp()
|
||||
max = MaxOp()
|
||||
min = MinOp()
|
||||
|
||||
_expr.__addop__ = add
|
||||
_expr.__subop__ = sub
|
||||
_expr.__mulop__ = mul
|
||||
_expr.__divop__ = div
|
||||
|
|
|
@ -9,12 +9,24 @@ def test_bind():
|
|||
|
||||
|
||||
def test_basic():
|
||||
a= tvm.Var('a')
|
||||
a = tvm.Var('a')
|
||||
b = tvm.Var('b')
|
||||
c = a + b
|
||||
assert tvm.format_str(c) == '(%s + %s)' % (a.name, b.name)
|
||||
|
||||
def test_simplify():
|
||||
a = tvm.Var('a')
|
||||
b = tvm.Var('b')
|
||||
e1 = a * (2 + 1) + b * 1
|
||||
e2 = a * (2 + 1) - b * 1
|
||||
e3 = tvm.max(a * 3.3 + 5, 3 + 3.3 * a)
|
||||
e4 = a - a
|
||||
assert tvm.format_str(tvm.simplify(e1)) == '((%s * 3) + %s)' % (a.name, b.name)
|
||||
assert tvm.format_str(tvm.simplify(e2)) == '((%s * 3) + (%s * -1))' % (a.name, b.name)
|
||||
assert tvm.format_str(tvm.simplify(e3)) == '((%s * 3.3) + 5)' % (a.name)
|
||||
assert tvm.format_str(tvm.simplify(e4)) == '0'
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_simplify()
|
||||
test_basic()
|
||||
test_bind()
|
||||
|
|
Загрузка…
Ссылка в новой задаче