[OP] enable binary op
This commit is contained in:
Родитель
1a7fb9f969
Коммит
8de0a08330
|
@ -6,6 +6,7 @@
|
|||
#ifndef TVM_OP_H_
|
||||
#define TVM_OP_H_
|
||||
|
||||
#include <dmlc/registry.h>
|
||||
#include <string>
|
||||
#include "./expr.h"
|
||||
|
||||
|
@ -14,6 +15,8 @@ namespace tvm {
|
|||
/*! \brief binary operator */
|
||||
class BinaryOp {
|
||||
public:
|
||||
// virtual destructor
|
||||
virtual ~BinaryOp() {}
|
||||
/*! \return the function name to be called in binary op */
|
||||
virtual const char* FunctionName() const = 0;
|
||||
/*!
|
||||
|
@ -23,6 +26,11 @@ class BinaryOp {
|
|||
* \return the result expr
|
||||
*/
|
||||
Expr operator()(Expr lhs, Expr rhs) const;
|
||||
/*!
|
||||
* \brief get binary op by name
|
||||
* \param name name of operator
|
||||
*/
|
||||
static const BinaryOp* Get(const char* name);
|
||||
};
|
||||
|
||||
|
||||
|
@ -37,6 +45,11 @@ class UnaryOp {
|
|||
* \return the result expr
|
||||
*/
|
||||
Expr operator()(Expr src) const;
|
||||
/*!
|
||||
* \brief get unary op by name
|
||||
* \param name name of operator
|
||||
*/
|
||||
static const UnaryOp* Get(const char* name);
|
||||
};
|
||||
|
||||
|
||||
|
@ -45,7 +58,6 @@ class AddOp : public BinaryOp {
|
|||
const char* FunctionName() const override {
|
||||
return "+";
|
||||
}
|
||||
static AddOp* Get();
|
||||
};
|
||||
|
||||
|
||||
|
@ -54,7 +66,6 @@ class SubOp : public BinaryOp {
|
|||
const char* FunctionName() const override {
|
||||
return "-";
|
||||
}
|
||||
static SubOp* Get();
|
||||
};
|
||||
|
||||
|
||||
|
@ -63,7 +74,6 @@ class MulOp : public BinaryOp {
|
|||
const char* FunctionName() const override {
|
||||
return "*";
|
||||
}
|
||||
static MulOp* Get();
|
||||
};
|
||||
|
||||
|
||||
|
@ -72,7 +82,6 @@ class DivOp : public BinaryOp {
|
|||
const char* FunctionName() const override {
|
||||
return "/";
|
||||
}
|
||||
static DivOp* Get();
|
||||
};
|
||||
|
||||
|
||||
|
@ -81,7 +90,6 @@ class MaxOp : public BinaryOp {
|
|||
const char* FunctionName() const override {
|
||||
return "max";
|
||||
}
|
||||
static MaxOp* Get();
|
||||
};
|
||||
|
||||
|
||||
|
@ -90,32 +98,57 @@ class MinOp : public BinaryOp {
|
|||
const char* FunctionName() const override {
|
||||
return "min";
|
||||
}
|
||||
static MinOp* Get();
|
||||
};
|
||||
|
||||
#define DEFINE_OP_OVERLOAD(OpChar, OpName) \
|
||||
#define DEFINE_BINARY_OP_OVERLOAD(OpChar) \
|
||||
inline Expr operator OpChar (Expr lhs, Expr rhs) { \
|
||||
return (*OpName::Get())(lhs, rhs); \
|
||||
static const BinaryOp* op = BinaryOp::Get(#OpChar); \
|
||||
return (*op)(lhs, rhs); \
|
||||
}
|
||||
|
||||
#define DEFINE_BINARY_OP_FUNCTION(FuncName, OpName) \
|
||||
inline Expr FuncName(Expr lhs, Expr rhs) { \
|
||||
return (*OpName::Get())(lhs, rhs); \
|
||||
#define DEFINE_BINARY_OP_FUNCTION(FuncName) \
|
||||
inline Expr FuncName(Expr lhs, Expr rhs) { \
|
||||
static const BinaryOp* op = BinaryOp::Get(#FuncName); \
|
||||
return (*op)(lhs, rhs); \
|
||||
}
|
||||
|
||||
DEFINE_OP_OVERLOAD(+, AddOp);
|
||||
DEFINE_OP_OVERLOAD(-, SubOp);
|
||||
DEFINE_OP_OVERLOAD(*, MulOp);
|
||||
DEFINE_OP_OVERLOAD(/, DivOp);
|
||||
DEFINE_BINARY_OP_OVERLOAD(+);
|
||||
DEFINE_BINARY_OP_OVERLOAD(-);
|
||||
DEFINE_BINARY_OP_OVERLOAD(*);
|
||||
DEFINE_BINARY_OP_OVERLOAD(/);
|
||||
|
||||
DEFINE_BINARY_OP_FUNCTION(max, MaxOp);
|
||||
DEFINE_BINARY_OP_FUNCTION(min, MinOp);
|
||||
DEFINE_BINARY_OP_FUNCTION(max);
|
||||
DEFINE_BINARY_OP_FUNCTION(min);
|
||||
|
||||
// overload negation
|
||||
inline Expr operator-(Expr src) {
|
||||
return src * (-1);
|
||||
}
|
||||
|
||||
// template of op registry
|
||||
template<typename Op>
|
||||
struct OpReg {
|
||||
std::string name;
|
||||
std::unique_ptr<Op> op;
|
||||
inline OpReg& set(Op* op) {
|
||||
this->op.reset(op);
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
using UnaryOpReg = OpReg<UnaryOp>;
|
||||
using BinaryOpReg = OpReg<BinaryOp>;
|
||||
|
||||
#define TVM_REGISTER_BINARY_OP(FunctionName, TypeName) \
|
||||
static DMLC_ATTRIBUTE_UNUSED ::tvm::BinaryOpReg & __make_ ## _BinOp_ ## TypeName = \
|
||||
::dmlc::Registry<::tvm::BinaryOpReg>::Get()->__REGISTER_OR_GET__(#FunctionName) \
|
||||
.set(new TypeName())
|
||||
|
||||
#define TVM_REGISTER_UNARY_OP(FunctionName, TypeName) \
|
||||
static DMLC_ATTRIBUTE_UNUSED ::tvm::BinaryOpReg & __make_ ## _BinOp_ ## TypeName = \
|
||||
::dmlc::Registry<::tvm::UnaryOpReg>::Get()->__REGISTER_OR_GET__(#FunctionName) \
|
||||
.set(new TypeName())
|
||||
|
||||
} // namespace tvm
|
||||
|
||||
#endif // TVM_OP_H_
|
||||
|
|
|
@ -1,7 +1,41 @@
|
|||
from ._ctypes._api import NodeBase, register_node
|
||||
from .function import binary_op
|
||||
from ._function_internal import _binary_op
|
||||
|
||||
class Expr(NodeBase):
|
||||
pass
|
||||
def __add__(self, other):
|
||||
return binary_op('+', self, other)
|
||||
|
||||
def __radd__(self, other):
|
||||
return self.__add__(other)
|
||||
|
||||
def __sub__(self, other):
|
||||
return binary_op('-', self, other)
|
||||
|
||||
def __rsub__(self, other):
|
||||
return binary_op('-', other, self)
|
||||
|
||||
def __mul__(self, other):
|
||||
return binary_op('*', self, other)
|
||||
|
||||
def __rmul__(self, other):
|
||||
return binary_op('*', other, self)
|
||||
|
||||
def __div__(self, other):
|
||||
return binary_op('/', self, other)
|
||||
|
||||
def __rdiv__(self, other):
|
||||
return binary_op('/', other, self)
|
||||
|
||||
def __truediv__(self, other):
|
||||
return self.__div__(other)
|
||||
|
||||
def __rtruediv__(self, other):
|
||||
return self.__rdiv__(other)
|
||||
|
||||
def __neg__(self):
|
||||
return self.__mul__(-1)
|
||||
|
||||
|
||||
@register_node("VarNode")
|
||||
class Var(Expr):
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
from __future__ import absolute_import as _abs
|
||||
from numbers import Number as _Number
|
||||
from ._ctypes._api import _init_function_module
|
||||
import _function_internal
|
||||
from .import _function_internal
|
||||
|
||||
int32 = 1
|
||||
float32 = 2
|
||||
|
@ -18,4 +20,57 @@ def Var(name="tindex", dtype=int32):
|
|||
return _function_internal._Var(name, dtype)
|
||||
|
||||
|
||||
def _symbol(value):
|
||||
"""Convert a value to expression."""
|
||||
if isinstance(value, _Number):
|
||||
return constant(value)
|
||||
else:
|
||||
return value
|
||||
|
||||
|
||||
def binary_op(op, lhs, rhs):
|
||||
"""Binary operator given op lhs and rhs
|
||||
|
||||
Parameters
|
||||
----------
|
||||
op : str
|
||||
The operator string
|
||||
|
||||
lhs : Expr/number
|
||||
The left operand
|
||||
|
||||
rhs : Expr/number
|
||||
The right operand
|
||||
"""
|
||||
return _function_internal._binary_op(op, _symbol(lhs), _symbol(rhs))
|
||||
|
||||
|
||||
def max(lhs, rhs):
|
||||
"""Max of two expressions
|
||||
|
||||
Parameters
|
||||
----------
|
||||
lhs : Expr/number
|
||||
The left operand
|
||||
|
||||
rhs : Expr/number
|
||||
The right operand
|
||||
"""
|
||||
return binary_op("max", lhs, rhs)
|
||||
|
||||
|
||||
def min(lhs, rhs):
|
||||
"""Min of two expressions
|
||||
|
||||
Parameters
|
||||
----------
|
||||
lhs : Expr/number
|
||||
The left operand
|
||||
|
||||
rhs : Expr/number
|
||||
The right operand
|
||||
"""
|
||||
return binary_op("max", lhs, rhs)
|
||||
|
||||
|
||||
_init_function_module("tvm.cpp")
|
||||
|
|
|
@ -16,6 +16,7 @@ namespace tvm {
|
|||
using ArgStack = const std::vector<APIVariantValue>;
|
||||
using RetValue = APIVariantValue;
|
||||
|
||||
// expression logic x
|
||||
TVM_REGISTER_API(_Var)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
*ret = Var(args.at(0),
|
||||
|
@ -24,21 +25,28 @@ TVM_REGISTER_API(_Var)
|
|||
.add_argument("name", "str", "name of the var")
|
||||
.add_argument("dtype", "int", "data type of var");
|
||||
|
||||
|
||||
TVM_REGISTER_API(max)
|
||||
TVM_REGISTER_API(constant)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
*ret = max(args.at(0), args.at(1));
|
||||
})
|
||||
.add_argument("lhs", "Expr", "left operand")
|
||||
.add_argument("rhs", "Expr", "right operand");
|
||||
|
||||
TVM_REGISTER_API(min)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
*ret = min(args.at(0), args.at(1));
|
||||
if (args.at(0).type_id == kLong) {
|
||||
*ret = IntConstant(args.at(0));
|
||||
} else if (args.at(0).type_id == kDouble) {
|
||||
*ret = FloatConstant(args.at(0));
|
||||
} else {
|
||||
LOG(FATAL) << "only accept int or float";
|
||||
}
|
||||
})
|
||||
.add_argument("src", "Number", "source number");
|
||||
|
||||
TVM_REGISTER_API(_binary_op)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
CHECK(args.at(0).type_id == kStr);
|
||||
*ret = (*BinaryOp::Get(args.at(0).str.c_str()))(args.at(1), args.at(2));
|
||||
})
|
||||
.add_argument("op", "str", "operator")
|
||||
.add_argument("lhs", "Expr", "left operand")
|
||||
.add_argument("rhs", "Expr", "right operand");
|
||||
|
||||
// transformations
|
||||
TVM_REGISTER_API(format_str)
|
||||
.set_body([](const ArgStack& args, RetValue *ret) {
|
||||
std::ostringstream os;
|
||||
|
|
|
@ -5,6 +5,12 @@
|
|||
#include <tvm/op.h>
|
||||
#include <tvm/expr_node.h>
|
||||
|
||||
namespace dmlc {
|
||||
DMLC_REGISTRY_ENABLE(::tvm::BinaryOpReg);
|
||||
DMLC_REGISTRY_ENABLE(::tvm::UnaryOpReg);
|
||||
} // namespace dmlc
|
||||
|
||||
|
||||
namespace tvm {
|
||||
|
||||
Expr BinaryOp::operator()(Expr lhs, Expr rhs) const {
|
||||
|
@ -14,17 +20,18 @@ Expr BinaryOp::operator()(Expr lhs, Expr rhs) const {
|
|||
return Expr(std::move(nptr));
|
||||
}
|
||||
|
||||
#define DEFINE_SINGLETON_GET(TypeName) \
|
||||
TypeName* TypeName::Get() { \
|
||||
static TypeName inst; \
|
||||
return &inst; \
|
||||
}
|
||||
const BinaryOp* BinaryOp::Get(const char* name) {
|
||||
const auto* op = dmlc::Registry<BinaryOpReg>::Find(name);
|
||||
CHECK(op != nullptr) << "cannot find " << name;
|
||||
return op->op.get();
|
||||
}
|
||||
|
||||
TVM_REGISTER_BINARY_OP(+, AddOp);
|
||||
TVM_REGISTER_BINARY_OP(-, SubOp);
|
||||
TVM_REGISTER_BINARY_OP(*, MulOp);
|
||||
TVM_REGISTER_BINARY_OP(/, DivOp);
|
||||
TVM_REGISTER_BINARY_OP(max, MaxOp);
|
||||
TVM_REGISTER_BINARY_OP(min, MinOp);
|
||||
|
||||
DEFINE_SINGLETON_GET(AddOp);
|
||||
DEFINE_SINGLETON_GET(SubOp);
|
||||
DEFINE_SINGLETON_GET(MulOp);
|
||||
DEFINE_SINGLETON_GET(DivOp);
|
||||
DEFINE_SINGLETON_GET(MaxOp);
|
||||
DEFINE_SINGLETON_GET(MinOp);
|
||||
|
||||
} // namespace tvm
|
||||
|
|
|
@ -3,8 +3,8 @@ from tvm import cpp as tvm
|
|||
def test_basic():
|
||||
a = tvm.Var('a')
|
||||
b = tvm.Var('b')
|
||||
z = tvm.max(a, b)
|
||||
assert tvm.format_str(z) == 'max(%s, %s)' % (a.name, b.name)
|
||||
c = a + b
|
||||
assert tvm.format_str(c) == '(%s + %s)' % (a.name, b.name)
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_basic()
|
||||
|
|
Загрузка…
Ссылка в новой задаче