[Relay/TOPI][Op] Add erf intrinsic and op (#3702)
* add more ops
* stop vectorization for erf
* x
* cleanup
* fix
* add whitelist for vectorizable intrin
* add tf converter
* fix dense
* fix
* add missing intrin
* fix mxnet frontend
* fix nvptx

Parent: 6a377f77e8
Commit: 2f5b155ab5
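For orientation, a brief end-to-end sketch of what this change enables (illustrative only, not part of the commit; it assumes the 0.6-era Relay Python API that the rest of this diff uses, and mirrors the updated test_op_level1.py):

import numpy as np
import scipy.special
import tvm
from tvm import relay

# Build a one-op Relay function using the new erf op.
x = relay.var("x", shape=(4,), dtype="float32")
func = relay.Function([x], relay.erf(x))

# Evaluate it and compare against scipy's reference erf.
data = np.random.uniform(-1, 1, size=(4,)).astype("float32")
intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
out = intrp.evaluate(func)(data).asnumpy()
np.testing.assert_allclose(out, scipy.special.erf(data), rtol=1e-5)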
@@ -512,6 +512,7 @@ TVM_DLL Expr trunc(Expr x);
   }                                            \

 TVM_DECLARE_INTRIN_UNARY(exp);
+TVM_DECLARE_INTRIN_UNARY(erf);
 TVM_DECLARE_INTRIN_UNARY(tanh);
 TVM_DECLARE_INTRIN_UNARY(sigmoid);
 TVM_DECLARE_INTRIN_UNARY(sqrt);
@@ -556,6 +556,9 @@ class Call : public ExprNode {
            name == intrin_name);
   }

+  /*! \return Whether call node can be vectorized. */
+  bool is_vectorizable() const;
+
   static constexpr const char* _type_key = "Call";
   TVM_DECLARE_NODE_TYPE_INFO(Call, ExprNode);
@@ -571,6 +574,9 @@ class Call : public ExprNode {
   static constexpr const char* likely = "likely";
   static constexpr const char* glsl_texture_store = "glsl_texture_store";
   static constexpr const char* prefetch = "prefetch";
+
+  /*! \brief Vectorizable intrinsic list. */
+  static const char* vectorizable_intrinsics[];
 };

 /*!
@@ -211,6 +211,22 @@ def exp(x):
     return call_pure_intrin(x.dtype, "exp", x)


+def erf(x):
+    """Take gauss error function of the input x.
+
+    Parameters
+    ----------
+    x : Expr
+        Input argument.
+
+    Returns
+    -------
+    y : Expr
+        The result.
+    """
+    return call_pure_intrin(x.dtype, "erf", x)
+
+
 def tanh(x):
     """Take hyperbolic tanh of input x.
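A minimal sketch (not part of the commit, assuming the 0.6-era tvm.placeholder/tvm.compute API used elsewhere in this diff) of using the new tvm.erf intrinsic directly in a compute expression; the lowering rules registered further down dispatch it to an extern erf/erff call:

import tvm

n = tvm.var("n")
A = tvm.placeholder((n,), name="A", dtype="float32")
B = tvm.compute(A.shape, lambda i: tvm.erf(A[i]), name="B")

s = tvm.create_schedule(B.op)
# The printed IR should contain calls to the "erf" intrinsic, which the
# default/cuda/nvptx rules added in this commit lower to extern calls.
print(tvm.lower(s, [A, B], simple_mode=True))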
@@ -170,8 +170,8 @@ class Executor(object):
            return args

        if kwargs and not isinstance(expr, Function):
-            raise Exception("can only supply keyword parameters for a \
-                            relay.Function, found {0}".format(expr))
+            raise Exception("can only supply keyword parameters for a "
+                            "relay.Function, found {0}".format(expr))

        params = expr.params
        param_names = [p.name_hint for p in params]
@@ -182,16 +182,16 @@ class Executor(object):
            if i < num_of_args:
                if kwargs.get(name):
                    raise Exception(
-                        "duplicate argument supplied in \
-                        both positional args (at position: {0}), \
-                        and keyword argument (with name: {1})".format(i, name))
+                        "duplicate argument supplied in "
+                        "both positional args (at position: {0}), "
+                        "and keyword argument (with name: {1})".format(i, name))
            else:
                cargs.append(kwargs[name])

        if len(cargs) != len(params):
            raise Exception(
-                "insufficient arguments, expected" \
-                " {0}, provided {1}".format(len(cargs), len(params)))
+                "insufficient arguments, expected "
+                "{0}, provided {1}".format(len(cargs), len(params)))

        return tuple(cargs)

@@ -124,7 +124,16 @@ class StrAttrsDict(object):
        """
        if key in self.attrs:
            tshape = self.attrs[key]
-            return tuple(int(x.strip()) for x in tshape.strip('()[]').split(',') if x)
+            ret = []
+            for x in tshape.strip('()[]').split(','):
+                x = x.strip()
+                if not x:
+                    continue
+                if x == "None":
+                    ret.append(None)
+                else:
+                    ret.append(int(x))
+            return tuple(ret)
        if isinstance(default, RequiredAttr):
            raise AttributeError("Required attribute {} not found.".format(key))
        return default
@@ -55,10 +55,17 @@ def _mx_fully_connected(inputs, attrs):
     use_flatten = attrs.get_bool("flatten", True)
     if has_flatten and use_flatten:
         inputs[0] = _op.nn.batch_flatten(inputs[0])
+    data_shape = _infer_type(inputs[0]).checked_type.shape
+    if len(data_shape) > 2:
+        inputs[0] = _op.reverse_reshape(inputs[0], [-1, 0])
     res = _op.nn.dense(inputs[0], inputs[1], units=units)
     if use_bias:
         assert len(inputs) == 3
         res = _op.nn.bias_add(res, inputs[2], axis=-1)
+    if len(data_shape) > 2:
+        new_shape = data_shape[:-1]
+        new_shape.append(units)
+        res = _op.reshape(res, new_shape)
     return res

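For context, a hedged illustration (not from the commit; shapes are made up) of how _op.reverse_reshape(x, [-1, 0]) collapses the leading dimensions that FullyConnected folds into the batch axis when the input is more than 2-D:

from tvm import relay

x = relay.var("x", shape=(2, 3, 4), dtype="float32")
y = relay.reverse_reshape(x, newshape=(-1, 0))  # dims filled from the right: (2, 3, 4) -> (6, 4)
mod = relay.Module.from_expr(relay.Function([x], y))
# Type inference should report Tensor[(6, 4), float32] for the result.
print(relay.transform.InferType()(mod))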
@@ -241,8 +248,8 @@ def _mx_layer_norm(inputs, attrs):

 def _mx_slice(inputs, attrs):
     new_attrs = {}
-    begin = attrs.get_int_tuple('begin', None)
-    end = attrs.get_int_tuple('end', None)
+    begin = list(attrs.get_int_tuple('begin', None))
+    end = list(attrs.get_int_tuple('end', None))
     stride = attrs.get_int_tuple('step', None)
     if begin is None:
         raise tvm.error.OpAttributeRequired(
@@ -251,11 +258,12 @@ def _mx_slice(inputs, attrs):
         raise tvm.error.OpAttributeRequired(
             'Attribute "end" not found in operator Slice.')
-    if None in begin:
-        raise tvm.error.OpAttributeInvalid(
-            'Value None in attribute "begin" of operator Slice is not valid.')
-    if None in end:
-        raise tvm.error.OpAttributeInvalid(
-            'Value None in attribute "end" of operator Slice is not valid.')
+    data_shape = _infer_type(inputs[0]).checked_type.shape
+    for i, beg in enumerate(begin):
+        if beg is None:
+            assert end[i] is None
+            begin[i] = 0
+            end[i] = data_shape[i]
     new_attrs = {'begin': begin, 'end': end}
     if stride is not None:
         new_attrs['strides'] = stride
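A hedged sketch (not from the commit) of what this now accepts: MXNet's slice allows None entries in begin/end, and the frontend fills them from the inferred input shape instead of rejecting the graph:

import mxnet as mx
from tvm import relay

x = mx.sym.var("data")
sym = mx.sym.slice(x, begin=(0, None), end=(2, None))
mod, params = relay.frontend.from_mxnet(sym, shape={"data": (4, 6)})
# Expect a strided_slice covering the full extent of axis 1.
print(mod)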
@@ -497,7 +505,8 @@ def _mx_arange(inputs, attrs):
             'Attribute "repeat" is not supported in operator arange.')
     new_attrs = {}
     new_attrs["start"] = _expr.const(attrs.get_float("start", 0.0))
-    new_attrs["stop"] = _expr.const(attrs.get_float("stop"))
+    stop = attrs.get_str("stop", "None")
+    new_attrs["stop"] = None if stop == "None" else _expr.const(float(stop))
     new_attrs["step"] = _expr.const(attrs.get_float("step", 1.0))
     new_attrs["dtype"] = attrs.get_str("dtype", "float32")
     return _op.arange(**new_attrs)
@@ -910,6 +919,7 @@ def _mx_one_hot(inputs, attrs):
 _identity_list = [
     "log",
     "exp",
+    "erf",
     "sqrt",
     "floor",
     "ceil",
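A hedged sketch (not part of the commit): the "erf" entry in _identity_list lets the MXNet frontend map mx.sym.erf directly onto the new Relay op:

import mxnet as mx
from tvm import relay

x = mx.sym.var("data")
sym = mx.sym.erf(x)
mod, params = relay.frontend.from_mxnet(sym, shape={"data": (1, 8)})
# The printed module should contain a call to the Relay "erf" op.
print(mod)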
@@ -1261,6 +1261,7 @@ _convert_map = {
     'DepthToSpace' : _depth_to_space(),
     'Equal'        : _broadcast('equal'),
     'Elu'          : _elu(),
+    'Erf'          : AttrCvt('erf'),
     'Exp'          : AttrCvt('exp'),
     'ExpandDims'   : _expand_dims(),
     'Fill'         : _fill(),
@@ -30,6 +30,7 @@ register_schedule("log1p", schedule_broadcast)
 register_schedule("cos", schedule_broadcast)
 register_schedule("sin", schedule_broadcast)
 register_schedule("exp", schedule_broadcast)
+register_schedule("erf", schedule_broadcast)
 register_schedule("sqrt", schedule_broadcast)
 register_schedule("rsqrt", schedule_broadcast)
 register_schedule("sigmoid", schedule_broadcast)
@@ -92,6 +92,22 @@ def exp(data):
     return _make.exp(data)


+def erf(data):
+    """Compute elementwise error function of data.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data
+
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+    """
+    return _make.erf(data)
+
+
 def sqrt(data):
     """Compute elementwise sqrt of data.
@@ -31,6 +31,9 @@ namespace intrin {
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.exp")
 .set_body(DispatchExtern<FloatSuffix>);

+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.erf")
+.set_body(DispatchExtern<FloatSuffix>);
+
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log")
 .set_body(DispatchExtern<FloatSuffix>);
@@ -92,6 +92,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.round")
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp")
 .set_body(DispatchExtern<CUDAFastMath>);

+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.erf")
+.set_body(DispatchExtern<CUDAMath>);
+
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log")
 .set_body(DispatchExtern<CUDAFastMath>);
@@ -64,6 +64,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fabs")
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp")
 .set_body(DispatchExternLibDevice);

+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.erf")
+.set_body(DispatchExternLibDevice);
+
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fma")
 .set_body(DispatchExternLibDevice);
@@ -176,6 +176,22 @@ Expr Let::make(Var var, Expr value, Expr body) {
   return Expr(node);
 }

+const char* Call::vectorizable_intrinsics[] = {
+    "floor", "ceil", "sign", "trunc", "fabs", "round", "exp", "tanh", "sqrt",
+    "log", "sin", "cos", "pow", ir::Call::shift_left, ir::Call::shift_right,
+    ir::Call::likely, ir::Call::popcount
+};
+
+bool Call::is_vectorizable() const {
+  size_t cnt = sizeof(Call::vectorizable_intrinsics) / sizeof(char*);
+  for (size_t i = 0; i < cnt; ++i) {
+    if (name == Call::vectorizable_intrinsics[i]) {
+      return true;
+    }
+  }
+  return false;
+}
+
 Expr Call::make(DataType type,
                 std::string name,
                 Array<Expr> args,
@@ -268,16 +268,34 @@ class Vectorizer : public IRMutator {
     if (op->name == intrinsic::tvm_if_then_else) {
       return MutateIfThenElseExpr_(op, e);
     }
-    int lane = 0;
-    Array<Expr> new_args = MutateArray(op->args, &lane);
-
-    // normal code path.
-    if (op->args.same_as(new_args)) {
-      return e;
+    if (!op->is_vectorizable()) {
+      // Cannot vectorize this op
+      Array<Expr> new_args;
+      for (auto arg : op->args) {
+        auto new_arg = this->Mutate(arg);
+        if (new_arg.type().is_vector()) {
+          need_scalarize_ = true;
+          return e;
+        }
+        new_args.push_back(new_arg);
+      }
+      if (op->args.same_as(new_args)) {
+        return e;
+      } else {
+        return Call::make(
+            op->type, op->name, new_args, op->call_type, op->func, op->value_index);
+      }
     } else {
-      return Call::make(
-          op->type.with_lanes(lane), op->name, new_args,
-          op->call_type, op->func, op->value_index);
+      int lane = 0;
+      Array<Expr> new_args = MutateArray(op->args, &lane);
+      // normal code path.
+      if (op->args.same_as(new_args)) {
+        return e;
+      } else {
+        return Call::make(
+            op->type.with_lanes(lane), op->name, new_args,
+            op->call_type, op->func, op->value_index);
+      }
     }
   }
   // Load
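A hedged sketch (not from the commit) of the effect: because "erf" is deliberately left out of Call::vectorizable_intrinsics, vectorizing a loop over tvm.erf should fall back to the scalarized path above rather than emitting a vector erf call:

import tvm

A = tvm.placeholder((64,), name="A", dtype="float32")
B = tvm.compute(A.shape, lambda i: tvm.erf(A[i]), name="B")

s = tvm.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=8)
s[B].vectorize(xi)
# The lowered IR should keep the erf calls scalar (inner loop scalarized),
# whereas a whitelisted op such as exp in the same position would be vectorized.
print(tvm.lower(s, [A, B], simple_mode=True))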
@@ -85,6 +85,18 @@ RELAY_REGISTER_UNARY_OP("exp")
 .set_support_level(1)
 .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::exp));

+
+RELAY_REGISTER_UNARY_OP("erf")
+.describe(R"code(Returns the error function value for input array, computed element-wise.
+
+.. math::
+   \erf(x)
+
+)code" TVM_ADD_FILELINE)
+.set_support_level(1)
+.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::erf));
+
+
 RELAY_REGISTER_UNARY_OP("sqrt")
 .describe(R"code(Returns the sqrt input array, computed element-wise.
@@ -1844,6 +1844,14 @@ def test_forward_zeros_like():
     _test_forward_zeros_like((2, 3, 11), "float32")
     _test_forward_zeros_like((2, 3, 11), "float64")

+def test_forward_erf():
+    ishape = (1, 3, 10, 10)
+    inp_array = np.random.uniform(-5, 5, size=ishape).astype(np.float32)
+    with tf.Graph().as_default():
+        in1 = tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype)
+        tf.math.erf(in1)
+        compare_tf_with_tvm(inp_array, 'Placeholder:0', 'Erf:0')
+
 def _test_forward_reverse_v2(in_shape, axis, dtype):
     np_data = np.random.uniform(-10, 10, size=in_shape).astype(dtype)
     tf.reset_default_graph()
@@ -2244,6 +2252,7 @@ if __name__ == '__main__':
     test_forward_log_softmax()
     test_forward_bias_add()
     test_forward_zeros_like()
+    test_forward_erf()

     # Reductions
     test_forward_argminmax()
@@ -16,6 +16,7 @@
 # under the License.
 import numpy as np
 import tvm
+import scipy
 from tvm import relay
 from tvm.relay import transform
 from tvm.relay.testing import ctx_list
@@ -67,6 +68,7 @@ def test_unary_op():

     for opfunc, ref in [(tvm.relay.log, np.log),
                         (tvm.relay.exp, np.exp),
+                        (tvm.relay.erf, scipy.special.erf),
                         (tvm.relay.sqrt, np.sqrt),
                         (tvm.relay.rsqrt, rsqrt),
                         (tvm.relay.sigmoid, sigmoid),
@@ -46,6 +46,7 @@ using namespace tvm;
   }

 TOPI_DECLARE_UNARY_OP(exp);
+TOPI_DECLARE_UNARY_OP(erf);
 TOPI_DECLARE_UNARY_OP(sigmoid);
 TOPI_DECLARE_UNARY_OP(sqrt);
 TOPI_DECLARE_UNARY_OP(log);
@@ -74,6 +74,23 @@ def exp(x):
     return tvm.compute(x.shape, lambda *i: tvm.exp(x(*i)))


+@tvm.tag_scope(tag=tag.ELEMWISE)
+def erf(x):
+    """Take gauss error function of input x.
+
+    Parameters
+    ----------
+    x : tvm.Tensor
+        Input argument.
+
+    Returns
+    -------
+    y : tvm.Tensor
+        The result.
+    """
+    return tvm.compute(x.shape, lambda *i: tvm.erf(x(*i)))
+
+
 @tvm.tag_scope(tag=tag.ELEMWISE)
 def tanh(x):
     """Take hyperbolic tanh of input x.
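A hedged usage sketch (not from the commit, in the style of the updated test_topi_math.py) of the new topi.erf with the generic injective schedule:

import numpy as np
import scipy.special
import tvm
import topi

A = tvm.placeholder((1, 32), name="A", dtype="float32")
B = topi.erf(A)

with tvm.target.create("llvm"):
    s = topi.generic.schedule_injective(B)
func = tvm.build(s, [A, B], "llvm")

ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.uniform(-1, 1, size=(1, 32)).astype("float32"), ctx)
b = tvm.nd.array(np.zeros((1, 32), dtype="float32"), ctx)
func(a, b)
np.testing.assert_allclose(b.asnumpy(), scipy.special.erf(a.asnumpy()), rtol=1e-5)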
@@ -28,12 +28,19 @@ from ..util import traverse_inline, get_const_tuple

 @autotvm.register_topi_compute(nn.dense, "cpu", "direct")
 def _declaration_dense(cfg, data, weight, bias=None, out_dtype=None):
-    batch, _ = get_const_tuple(data.shape)
+    target = tvm.target.current_target()
+    if "cblas" in target.libs:
+        C = cblas.matmul(data, weight, False, True)
+        if bias is not None:
+            C = tvm.compute(C.shape, lambda i, j: C[i, j] + bias[j].astype(out_dtype),
+                            tag=tag.BROADCAST)
+        return C
+
+    M, _ = get_const_tuple(data.shape)
     # For small batch sizes, don't pack weight into cache-friendly layout
     # because of overhead in packing and limited reuse from batch dimension
     # TODO(icemelon9): use a more systematic way to determine which schedule to use
-    if batch <= 16:
+    if M <= 16:
         return _declaration_dense_nopack(cfg, data, weight, bias, out_dtype)
     return _declaration_dense_pack(cfg, data, weight, bias, out_dtype)
@@ -41,35 +48,31 @@ def _declaration_dense(cfg, data, weight, bias=None, out_dtype=None):
 # Declare dense compute with packing weight into cache-friendly layout
 @autotvm.register_topi_compute(nn.dense, "cpu", "direct_pack")
 def _declaration_dense_pack(cfg, data, weight, bias=None, out_dtype=None):
-    target = tvm.target.current_target()
-    if "cblas" in target.libs:
-        C = cblas.matmul(data, weight, False, True)
-    else:
-        if out_dtype is None:
-            out_dtype = data.dtype
-        batch, in_dim = get_const_tuple(data.shape)
-        out_dim, _ = get_const_tuple(weight.shape)
-        # create tuning space
-        cfg.define_split("tile_y", batch, num_outputs=3)
-        cfg.define_split("tile_x", out_dim, num_outputs=3)
-        cfg.define_split("tile_k", in_dim, num_outputs=2)
-        if cfg.is_fallback:
-            _default_dense_pack_config(cfg, batch, out_dim, in_dim)
+    if out_dtype is None:
+        out_dtype = data.dtype
+    M, K = get_const_tuple(data.shape)  # batch, in_dim
+    N, _ = get_const_tuple(weight.shape)  # out_dim
+    # create tuning space
+    cfg.define_split("tile_y", M, num_outputs=3)
+    cfg.define_split("tile_x", N, num_outputs=3)
+    cfg.define_split("tile_k", K, num_outputs=2)
+    if cfg.is_fallback:
+        _default_dense_pack_config(cfg, M, N, K)

-        packw_bn = cfg["tile_x"].size[-1]
-        packw_shape = (out_dim // packw_bn, in_dim, packw_bn)
-        packw = tvm.compute(packw_shape,
-                            lambda z, y, x: weight[z * packw_bn + x, y], name="packed_weight")
+    packw_bn = cfg["tile_x"].size[-1]
+    packw_shape = (N // packw_bn, K, packw_bn)
+    packw = tvm.compute(packw_shape,
+                        lambda z, y, x: weight[z * packw_bn + x, y], name="packed_weight")

-        k = tvm.reduce_axis((0, in_dim), name="k")
-        C = tvm.compute((batch, out_dim),
-                        lambda y, x: tvm.sum(
-                            data[y, k].astype(out_dtype) *
-                            packw[x // packw_bn, k, x % packw_bn].astype(out_dtype),
-                            axis=k),
-                        tag="dense_pack")
+    k = tvm.reduce_axis((0, K), name="k")
+    C = tvm.compute((M, N),
+                    lambda y, x: tvm.sum(
+                        data[y, k].astype(out_dtype) *
+                        packw[x // packw_bn, k, x % packw_bn].astype(out_dtype),
+                        axis=k),
+                    tag="dense_pack")
     if bias is not None:
-        C = tvm.compute((batch, out_dim), lambda i, j: C[i, j] + bias[j].astype(out_dtype),
+        C = tvm.compute((M, N), lambda i, j: C[i, j] + bias[j].astype(out_dtype),
                         tag=tag.BROADCAST)
     return C
@@ -77,34 +80,30 @@ def _declaration_dense_pack(cfg, data, weight, bias=None, out_dtype=None):
 # Declare dense compute without packing weight
 @autotvm.register_topi_compute(nn.dense, "cpu", "direct_nopack")
 def _declaration_dense_nopack(cfg, data, weight, bias=None, out_dtype=None):
-    target = tvm.target.current_target()
-    if "cblas" in target.libs:
-        C = cblas.matmul(data, weight, False, True)
-    else:
-        if out_dtype is None:
-            out_dtype = data.dtype
-        batch, in_dim = get_const_tuple(data.shape)
-        out_dim, _ = get_const_tuple(weight.shape)
-        # create tuning space
-        cfg.define_split("tile_x", out_dim, num_outputs=2)
-        cfg.define_split("tile_y", batch, num_outputs=2)
-        cfg.define_split("tile_k", in_dim, num_outputs=2)
-        if cfg.is_fallback:
-            _default_dense_nopack_config(cfg, batch, out_dim, in_dim)
+    if out_dtype is None:
+        out_dtype = data.dtype
+    M, K = get_const_tuple(data.shape)
+    N, _ = get_const_tuple(weight.shape)
+    # create tuning space
+    cfg.define_split("tile_y", M, num_outputs=2)
+    cfg.define_split("tile_x", N, num_outputs=2)
+    cfg.define_split("tile_k", K, num_outputs=2)
+    if cfg.is_fallback:
+        _default_dense_nopack_config(cfg, M, N, K)

-        vec = cfg["tile_k"].size[-1]
-        k = tvm.reduce_axis((0, in_dim // vec), "k")
-        CC = tvm.compute((batch, out_dim, vec),
-                         lambda z, y, x: tvm.sum(
-                             data[z, k * vec + x].astype(out_dtype) *
-                             weight[y, k * vec + x].astype(out_dtype), axis=k))
+    vec = cfg["tile_k"].size[-1]
+    k = tvm.reduce_axis((0, K // vec), "k")
+    CC = tvm.compute((M, N, vec),
+                     lambda z, y, x: tvm.sum(
+                         data[z, k * vec + x].astype(out_dtype) *
+                         weight[y, k * vec + x].astype(out_dtype), axis=k))

-        kk = tvm.reduce_axis((0, vec), "kk")
-        C = tvm.compute((batch, out_dim),
-                        lambda y, x: tvm.sum(CC[y, x, kk], axis=kk),
-                        tag="dense_nopack")
+    kk = tvm.reduce_axis((0, vec), "kk")
+    C = tvm.compute((M, N),
+                    lambda y, x: tvm.sum(CC[y, x, kk], axis=kk),
+                    tag="dense_nopack")
     if bias is not None:
-        C = tvm.compute((batch, out_dim), lambda i, j: C[i, j] + bias[j].astype(out_dtype),
+        C = tvm.compute((M, N), lambda i, j: C[i, j] + bias[j].astype(out_dtype),
                         tag=tag.BROADCAST)

     return C
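A hedged sketch (not from the commit) exercising the x86 dense schedule after the batch-to-M renaming; with M = 8 (<= 16) the heuristic above should pick the "nopack" template, while larger batches take the weight-packing template:

import tvm
import topi

data = tvm.placeholder((8, 512), name="data", dtype="float32")
weight = tvm.placeholder((1024, 512), name="weight", dtype="float32")

with tvm.target.create("llvm"):
    out = topi.nn.dense(data, weight)          # dispatched to the x86 compute
    s = topi.generic.schedule_dense([out])     # dispatched to the x86 schedule

func = tvm.build(s, [data, weight, out], "llvm")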
@@ -148,6 +148,11 @@ TVM_REGISTER_GLOBAL("topi.exp")
     *rv = exp(args[0]);
   });

+TVM_REGISTER_GLOBAL("topi.erf")
+.set_body([](TVMArgs args, TVMRetValue *rv) {
+    *rv = erf(args[0]);
+  });
+
 TVM_REGISTER_GLOBAL("topi.cos")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
     *rv = cos(args[0]);
@@ -157,7 +162,6 @@ TVM_REGISTER_GLOBAL("topi.sin")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
     *rv = sin(args[0]);
   });
-
 TVM_REGISTER_GLOBAL("topi.tanh")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
     *rv = tanh(args[0]);
@@ -36,6 +36,7 @@ def test_ewise():
         assert B.op.body[0].name == name

     test_apply(topi.exp, "exp")
+    test_apply(topi.erf, "erf")
     test_apply(topi.tanh, "tanh")
     test_apply(topi.sigmoid, "sigmoid")
     test_apply(topi.log, "log")
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import numpy as np
+import scipy
 import tvm
 import topi
 import topi.testing
@@ -86,6 +87,7 @@ def test_ewise():
     test_apply(topi.rsqrt, "rsqrt", lambda x: np.ones_like(x) / np.sqrt(x), 0, 100, skip_name_check=True)
     test_apply(topi.cos, "cos", np.cos, -2.0*np.pi, 2.0*np.pi)
     test_apply(topi.sin, "sin", np.sin, -2.0*np.pi, 2.0*np.pi)
+    test_apply(topi.erf, "erf", scipy.special.erf, -.1, .1, dtype="float32")


 def test_cast():