[TOPI] Fix declaration for different dtypes (#546)
This commit is contained in:
Родитель
b384cd4a8d
Коммит
b20678b03d
|
@ -18,6 +18,7 @@ For example, you can use addexp.a to get the left operand of an Add node.
|
|||
from __future__ import absolute_import as _abs
|
||||
from ._ffi.node import NodeBase, register_node
|
||||
from . import make as _make
|
||||
from . import _api_internal
|
||||
|
||||
class ExprOp(object):
|
||||
def __add__(self, other):
|
||||
|
@ -60,7 +61,8 @@ class ExprOp(object):
|
|||
return _make.Mod(self, other)
|
||||
|
||||
def __neg__(self):
|
||||
return self.__mul__(-1)
|
||||
neg_one = _api_internal._const(-1, self.dtype)
|
||||
return self.__mul__(neg_one)
|
||||
|
||||
def __lshift__(self, other):
|
||||
return _make.Call(self.dtype, "shift_left", [self, other], Call.PureIntrinsic, None, 0)
|
||||
|
|
|
@ -17,7 +17,7 @@ def relu(x):
|
|||
y : tvm.Tensor
|
||||
The result.
|
||||
"""
|
||||
return tvm.compute(x.shape, lambda *i: tvm.max(x(*i), 0))
|
||||
return tvm.compute(x.shape, lambda *i: tvm.max(x(*i), tvm.const(0, x.dtype)))
|
||||
|
||||
|
||||
@tvm.tag_scope(tag=tag.ELEMWISE)
|
||||
|
|
|
@ -38,7 +38,7 @@ def global_pool(data, pool_type):
|
|||
tvm.sum(data[n, c, dheight, dwidth], axis=[dheight, dwidth]), \
|
||||
tag="global_pool_sum")
|
||||
return tvm.compute((batch, channel, 1, 1), lambda n, c, h, w: \
|
||||
tsum[n, c, h, w] / (height*width), \
|
||||
tsum[n, c, h, w] / (height*width).astype(tsum.dtype), \
|
||||
tag=tag.ELEMWISE)
|
||||
else:
|
||||
raise ValueError("Pool type should be 'avg' or 'max'.")
|
||||
|
|
Загрузка…
Ссылка в новой задаче