[TOPI][RELAY] Add op Size (#3094)
This commit is contained in:
Родитель
be26083620
Коммит
313bc9de4d
|
@ -97,6 +97,7 @@ List of operators
|
|||
topi.repeat
|
||||
topi.tile
|
||||
topi.shape
|
||||
topi.ndarray_size
|
||||
topi.layout_transform
|
||||
topi.image.resize
|
||||
topi.argsort
|
||||
|
@ -165,6 +166,7 @@ topi
|
|||
.. autofunction:: topi.repeat
|
||||
.. autofunction:: topi.tile
|
||||
.. autofunction:: topi.shape
|
||||
.. autofunction:: topi.ndarray_size
|
||||
.. autofunction:: topi.layout_transform
|
||||
.. autofunction:: topi.argsort
|
||||
.. autofunction:: topi.topk
|
||||
|
|
|
@ -186,6 +186,7 @@ This level support backpropagation of broadcast operators. It is temporary.
|
|||
tvm.relay.collapse_sum_like
|
||||
tvm.relay.slice_like
|
||||
tvm.relay.shape_of
|
||||
tvm.relay.contrib.ndarray_size
|
||||
tvm.relay.layout_transform
|
||||
tvm.relay.device_copy
|
||||
tvm.relay.annotation.on_device
|
||||
|
@ -320,6 +321,7 @@ Level 10 Definitions
|
|||
.. autofunction:: tvm.relay.collapse_sum_like
|
||||
.. autofunction:: tvm.relay.slice_like
|
||||
.. autofunction:: tvm.relay.shape_of
|
||||
.. autofunction:: tvm.relay.contrib.ndarray_size
|
||||
.. autofunction:: tvm.relay.layout_transform
|
||||
.. autofunction:: tvm.relay.device_copy
|
||||
.. autofunction:: tvm.relay.annotation.on_device
|
||||
|
|
|
@ -59,7 +59,7 @@ class OperationNode : public ir::FunctionBaseNode {
|
|||
std::string name;
|
||||
/*! \brief optional tag of the operation */
|
||||
std::string tag;
|
||||
/*! \brief addtitional attributes of the operation*/
|
||||
/*! \brief additional attributes of the operation*/
|
||||
Map<std::string, NodeRef> attrs;
|
||||
/*! \return name of the operation */
|
||||
const std::string& func_name() const final {
|
||||
|
|
|
@ -287,6 +287,17 @@ struct SequenceMaskAttrs : public tvm::AttrsNode<SequenceMaskAttrs> {
|
|||
}
|
||||
}; // struct SequenceMaskAttrs.
|
||||
|
||||
/*! \brief Attributes for ndarray_size operator */
|
||||
struct NdarraySizeAttrs : public tvm::AttrsNode<NdarraySizeAttrs> {
|
||||
DataType dtype;
|
||||
|
||||
TVM_DECLARE_ATTRS(NdarraySizeAttrs, "relay.attrs.NdarraySizeAttrs") {
|
||||
TVM_ATTR_FIELD(dtype)
|
||||
.describe("Target data type")
|
||||
.set_default(NullValue<DataType>());
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace relay
|
||||
} // namespace tvm
|
||||
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_
|
||||
|
|
|
@ -275,7 +275,7 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None):
|
|||
The name hint of the tensor
|
||||
|
||||
tag: str, optional
|
||||
Additonal tag information about the compute.
|
||||
Additional tag information about the compute.
|
||||
|
||||
attrs: dict, optional
|
||||
The additional auxiliary attributes about the compute.
|
||||
|
|
|
@ -1383,6 +1383,7 @@ _convert_map = {
|
|||
'Shape' : _shape(),
|
||||
'Sigmoid' : AttrCvt('sigmoid'),
|
||||
'Sign' : AttrCvt('sign'),
|
||||
'Size' : AttrCvt('ndarray_size'),
|
||||
'Slice' : _slice(),
|
||||
'Softmax' : _softmax(),
|
||||
'Softplus' : _softplus(),
|
||||
|
|
|
@ -20,7 +20,7 @@ from __future__ import absolute_import
|
|||
|
||||
import topi
|
||||
from .. import op as reg
|
||||
from ..op import OpPattern
|
||||
from ..op import schedule_injective, OpPattern
|
||||
|
||||
|
||||
# adaptive_max_pool2d
|
||||
|
@ -41,3 +41,6 @@ def schedule_adaptive_avg_pool2d(_, outs, target):
|
|||
return topi.generic.schedule_adaptive_pool(outs)
|
||||
|
||||
reg.register_pattern("contrib.adaptive_avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
|
||||
|
||||
# relay.contrib.ndarray_size
|
||||
reg.register_schedule("contrib.ndarray_size", schedule_injective)
|
||||
|
|
|
@ -111,3 +111,21 @@ def adaptive_avg_pool2d(data,
|
|||
"""
|
||||
output_size = [] or output_size
|
||||
return _make.adaptive_avg_pool2d(data, output_size, layout)
|
||||
|
||||
def ndarray_size(data, dtype="int32"):
|
||||
"""Get number of elements of input tensor.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : tvm.relay.Expr
|
||||
The input tensor.
|
||||
|
||||
dtype : str, optional
|
||||
The target data type.
|
||||
|
||||
Returns
|
||||
-------
|
||||
result : tvm.relay.Expr
|
||||
The number of elements of input tensor.
|
||||
"""
|
||||
return _make.ndarray_size(data, dtype)
|
||||
|
|
|
@ -279,5 +279,54 @@ RELAY_REGISTER_OP("shape_of")
|
|||
.set_support_level(10)
|
||||
.set_attr<FTVMCompute>("FTVMCompute", ShapeOfCompute);
|
||||
|
||||
|
||||
TVM_REGISTER_NODE_TYPE(NdarraySizeAttrs);
|
||||
|
||||
bool NdarraySizeRel(const Array<Type>& types,
|
||||
int num_inputs,
|
||||
const Attrs& attrs,
|
||||
const TypeReporter& reporter) {
|
||||
CHECK_EQ(num_inputs, 1);
|
||||
auto tt = types[0].as<TensorTypeNode>();
|
||||
CHECK(tt != nullptr);
|
||||
const auto* param = attrs.as<NdarraySizeAttrs>();
|
||||
CHECK(param != nullptr);
|
||||
reporter->Assign(types[1], TensorTypeNode::make({1}, param->dtype));
|
||||
return true;
|
||||
}
|
||||
|
||||
Array<Tensor> NdarraySizeCompute(const Attrs& attrs,
|
||||
const Array<Tensor>& inputs,
|
||||
const Type& out_type,
|
||||
const Target& target) {
|
||||
CHECK_EQ(inputs.size(), 1);
|
||||
const auto* param = attrs.as<NdarraySizeAttrs>();
|
||||
CHECK(param != nullptr);
|
||||
return Array<Tensor>{topi::ndarray_size(inputs[0], param->dtype)};
|
||||
}
|
||||
|
||||
TVM_REGISTER_API("relay.op.contrib._make.ndarray_size")
|
||||
.set_body_typed<Expr(Expr, DataType)>([](Expr data, DataType dtype) {
|
||||
auto attrs = make_node<NdarraySizeAttrs>();
|
||||
attrs->dtype = dtype;
|
||||
static const Op& op = Op::Get("contrib.ndarray_size");
|
||||
return CallNode::make(op, {data}, Attrs(attrs), {});
|
||||
});
|
||||
|
||||
RELAY_REGISTER_OP("contrib.ndarray_size")
|
||||
.describe(R"code(Returns a tensor representing the number of elements of input tensor.
|
||||
|
||||
)code" TVM_ADD_FILELINE)
|
||||
.set_num_inputs(1)
|
||||
.set_attrs_type_key("relay.attrs.NdarraySizeAttrs")
|
||||
.add_argument("data", "Tensor", "The input tensor.")
|
||||
.add_type_rel("NdarraySize", NdarraySizeRel)
|
||||
.set_attr<TOpIsStateful>("TOpIsStateful", false)
|
||||
.set_attr<TOpPattern>("TOpPattern", kInjective)
|
||||
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
|
||||
ElemwiseArbitraryLayout)
|
||||
.set_support_level(10)
|
||||
.set_attr<FTVMCompute>("FTVMCompute", NdarraySizeCompute);
|
||||
|
||||
} // namespace relay
|
||||
} // namespace tvm
|
||||
|
|
|
@ -1933,6 +1933,22 @@ def test_forward_mean():
|
|||
check_mean((10, 8, 16, 32), axis=(2, 3))
|
||||
check_mean((10, 8, 16, 32), axis=(1, 2), keepdims=True)
|
||||
|
||||
#######################################################################
|
||||
# Size
|
||||
# ----
|
||||
def test_forward_size():
|
||||
def check_size(ishape):
|
||||
np_input = np.random.uniform(size=ishape).astype(np.float32)
|
||||
with tf.Graph().as_default():
|
||||
input = tf.placeholder(shape=np_input.shape, dtype=np_input.dtype, name='input')
|
||||
tf.size(input, name='size')
|
||||
compare_tf_with_tvm([np_input], ['input:0'], 'size:0')
|
||||
|
||||
if tf.__version__ < LooseVersion('1.1'):
|
||||
check_size((10, 8, 16, 32))
|
||||
check_size((10,))
|
||||
check_size(())
|
||||
|
||||
#######################################################################
|
||||
# All, Max, Min
|
||||
# -------------
|
||||
|
@ -2087,6 +2103,7 @@ if __name__ == '__main__':
|
|||
test_forward_depthtospace()
|
||||
test_forward_squeeze()
|
||||
test_forward_pack()
|
||||
test_forward_size()
|
||||
test_forward_broadcast_to()
|
||||
test_forward_fill()
|
||||
test_forward_crop()
|
||||
|
|
|
@ -215,6 +215,23 @@ def test_shape_of():
|
|||
tvm.testing.assert_allclose(op_res.asnumpy(),
|
||||
np.array(shape).astype('int32'))
|
||||
|
||||
def test_ndarray_size():
|
||||
def verify_ndarray_size(shape):
|
||||
x = relay.var("x", shape=shape)
|
||||
func = relay.Function([x], relay.op.contrib.ndarray_size(x))
|
||||
func = run_infer_type(func)
|
||||
|
||||
x_data = np.random.uniform(size=shape).astype("float32")
|
||||
ref_res = np.size(x_data)
|
||||
for target, ctx in ctx_list():
|
||||
for kind in ["graph", "debug"]:
|
||||
intrp = relay.create_executor(kind, ctx=ctx, target=target)
|
||||
op_res = intrp.evaluate(func)(x_data)
|
||||
tvm.testing.assert_allclose(op_res.asnumpy(),
|
||||
ref_res)
|
||||
verify_ndarray_size((2, 3, 5))
|
||||
verify_ndarray_size((2, 3, 5, 7))
|
||||
|
||||
def verify_adaptive_pool2d(dshape, out_size, pool_type, layout="NCHW", dtype="float32"):
|
||||
def start_index(index, odim, idim):
|
||||
return int(np.floor(index * idim / odim))
|
||||
|
@ -288,3 +305,5 @@ if __name__ == "__main__":
|
|||
test_batch_matmul()
|
||||
test_shape_of()
|
||||
test_sequence_mask()
|
||||
test_ndarray_size()
|
||||
|
||||
|
|
|
@ -1223,5 +1223,28 @@ inline Tensor shape(const Tensor& src,
|
|||
}, name, tag);
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Get the size of input tensor.
|
||||
* \param src the input tensor.
|
||||
* \param dtype the type of the elements in the tensor.
|
||||
* \param name output tensor name.
|
||||
* \param tag output tensor tag.
|
||||
* \return Tensor of input shape.
|
||||
*/
|
||||
inline Tensor ndarray_size(const Tensor& src,
|
||||
const Type& dtype,
|
||||
const std::string& name = "ndarray_size",
|
||||
const std::string& tag = kInjective) {
|
||||
int ndim = static_cast<int>(src->shape.size());
|
||||
Array<Expr> out_ndarray_size = {1};
|
||||
return compute(out_ndarray_size, [&](const Array<Var>& indices) {
|
||||
Expr ret = 1;
|
||||
for (int i = 0; i < ndim; ++i) {
|
||||
ret *= src->shape[i];
|
||||
}
|
||||
return tvm::cast(dtype, ret);
|
||||
}, name, tag);
|
||||
}
|
||||
|
||||
} // namespace topi
|
||||
#endif // TOPI_TRANSFORM_H_
|
||||
|
|
|
@ -425,7 +425,7 @@ def shape(array, dtype="int32"):
|
|||
Parameters
|
||||
----------
|
||||
array : tvm.Tensor
|
||||
The source tenosr.
|
||||
The source tensor.
|
||||
|
||||
dtype : str, optional
|
||||
The target data type.
|
||||
|
@ -477,3 +477,22 @@ def sequence_mask(data, valid_length, mask_value=0, axis=0):
|
|||
"only support data.ndim >= 2, received data.shape = {}".format(data.shape)
|
||||
assert axis == 0 or axis == 1, "only support axis = 0, 1, received axis = {}".format(axis)
|
||||
return cpp.sequence_mask(data, valid_length, mask_value, axis)
|
||||
|
||||
|
||||
def ndarray_size(array, dtype="int32"):
|
||||
"""Get the number of elements of input array
|
||||
|
||||
Parameters
|
||||
----------
|
||||
array : tvm.Tensor
|
||||
The source tensor.
|
||||
|
||||
dtype : str, optional
|
||||
The target data type.
|
||||
|
||||
Returns
|
||||
-------
|
||||
result : tvm.Tensor
|
||||
The resulting tensor.
|
||||
"""
|
||||
return cpp.ndarray_size(array, dtype)
|
||||
|
|
|
@ -311,6 +311,11 @@ TVM_REGISTER_GLOBAL("topi.shape")
|
|||
*rv = shape(args[0], args[1]);
|
||||
});
|
||||
|
||||
TVM_REGISTER_GLOBAL("topi.ndarray_size")
|
||||
.set_body([](TVMArgs args, TVMRetValue *rv) {
|
||||
*rv = ndarray_size(args[0], args[1]);
|
||||
});
|
||||
|
||||
TVM_REGISTER_GLOBAL("topi.split")
|
||||
.set_body([](TVMArgs args, TVMRetValue *rv) {
|
||||
if (args[1].type_code() == kDLInt || args[1].type_code() == kDLUInt) {
|
||||
|
|
|
@ -649,6 +649,33 @@ def test_sequence_mask():
|
|||
for backend in get_all_backend():
|
||||
check_device(backend)
|
||||
|
||||
def test_ndarray_size():
|
||||
in_shape = (5, 11, 7)
|
||||
dtype = "int32"
|
||||
A = tvm.placeholder(shape=in_shape, dtype="float32", name="A")
|
||||
B = topi.ndarray_size(A, dtype)
|
||||
|
||||
input = np.random.uniform(size=in_shape).astype(A.dtype)
|
||||
output = np.asarray(np.size(input)).astype(dtype)
|
||||
|
||||
def check_device(device):
|
||||
ctx = tvm.context(device, 0)
|
||||
if not ctx.exist:
|
||||
print("Skip because %s is not enabled" % device)
|
||||
return
|
||||
tvm_input = tvm.nd.array(input, ctx=ctx)
|
||||
tvm_output = tvm.nd.empty((1,), ctx=ctx, dtype=B.dtype)
|
||||
print("Running on target: %s" % device)
|
||||
with tvm.target.create(device):
|
||||
s = topi.generic.schedule_injective(B)
|
||||
f = tvm.build(s, [A, B], device, name="ndarray_size")
|
||||
f(tvm_input, tvm_output)
|
||||
tvm.testing.assert_allclose(tvm_output.asnumpy(), output)
|
||||
|
||||
for backend in get_all_backend():
|
||||
check_device(backend)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_strided_slice()
|
||||
test_concatenate()
|
||||
|
@ -668,3 +695,4 @@ if __name__ == "__main__":
|
|||
test_tile()
|
||||
test_shape()
|
||||
test_sequence_mask()
|
||||
test_ndarray_size()
|
||||
|
|
Загрузка…
Ссылка в новой задаче