Upsampling op support (#298)
* add nnvm upsampling symbol * add upsampling mxnet frontend * add doc for upsampling op * cleanup upsampling test * minor fix * use schedule_injective for upsampling * upgrade tvm
This commit is contained in:
Родитель
d059dbabf0
Коммит
6a0fb6efbf
|
@ -257,6 +257,15 @@ struct GlobalPool2DParam : public dmlc::Parameter<GlobalPool2DParam> {
|
|||
}
|
||||
};
|
||||
|
||||
struct UpSamplingParam : public dmlc::Parameter<UpSamplingParam> {
|
||||
int scale;
|
||||
|
||||
DMLC_DECLARE_PARAMETER(UpSamplingParam) {
|
||||
DMLC_DECLARE_FIELD(scale)
|
||||
.describe("upsampling scaling factor");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace top
|
||||
} // namespace nnvm
|
||||
|
||||
|
|
|
@ -190,6 +190,12 @@ def _softmax_output(inputs, attrs):
|
|||
new_attrs['axis'] = 1
|
||||
return _get_nnvm_op(op_name)(inputs[0], **new_attrs)
|
||||
|
||||
def _upsampling(inputs, attrs):
|
||||
scale = attrs.get('scale')
|
||||
new_attrs = {'scale':int(scale)}
|
||||
return _get_nnvm_op('upsampling')(inputs[0], **new_attrs)
|
||||
|
||||
|
||||
_identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__',
|
||||
'__div_symbol__', '__mul_scalar__', '__mul_symbol__',
|
||||
'__pow_scalar__', '__rdiv_scalar__', '__rpow_scalar__',
|
||||
|
@ -231,6 +237,7 @@ _convert_map = {
|
|||
'min_axis' : _rename('min'),
|
||||
'reshape' : _reshape,
|
||||
'sum_axis' : _rename('sum'),
|
||||
'UpSampling' : _upsampling
|
||||
}
|
||||
|
||||
def _convert_symbol(op_name, inputs, attrs,
|
||||
|
|
|
@ -250,3 +250,18 @@ def schedule_global_avg_pool2d(_, outs, target):
|
|||
return topi.generic.schedule_global_pool(outs)
|
||||
|
||||
reg.register_pattern("global_avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
|
||||
|
||||
|
||||
@reg.register_compute("upsampling")
|
||||
def compute_upsampling(attrs, inputs, _):
|
||||
"""Compute definition of upsampling"""
|
||||
scale = attrs.get_int("scale")
|
||||
return topi.nn.upsampling(inputs[0], scale)
|
||||
|
||||
@reg.register_schedule("upsampling")
|
||||
def schedule_upsampling(_, outs, target):
|
||||
"""Compute definition of upsampling"""
|
||||
with tvm.target.create(target):
|
||||
return topi.generic.schedule_injective(outs)
|
||||
|
||||
reg.register_pattern("upsampling", OpPattern.OUT_ELEMWISE_FUSABLE)
|
||||
|
|
|
@ -0,0 +1,52 @@
|
|||
/*!
|
||||
* Copyright (c) 2017 by Contributors
|
||||
* \file pooling.cc
|
||||
* \brief Property def of pooling operators.
|
||||
*/
|
||||
#include <nnvm/op.h>
|
||||
#include <nnvm/node.h>
|
||||
#include <nnvm/op_attr_types.h>
|
||||
#include <nnvm/top/nn.h>
|
||||
#include "./nn_common.h"
|
||||
#include "../op_common.h"
|
||||
#include "../elemwise_op_common.h"
|
||||
|
||||
namespace nnvm {
|
||||
namespace top {
|
||||
|
||||
DMLC_REGISTER_PARAMETER(UpSamplingParam);
|
||||
|
||||
inline bool UpSamplingInferShape(const nnvm::NodeAttrs& attrs,
|
||||
std::vector<TShape>* in_shape,
|
||||
std::vector<TShape>* out_shape) {
|
||||
const UpSamplingParam& param = nnvm::get<UpSamplingParam>(attrs.parsed);
|
||||
CHECK_EQ(in_shape->size(), 1U);
|
||||
CHECK_EQ(out_shape->size(), 1U);
|
||||
TShape dshape = (*in_shape)[0];
|
||||
if (dshape.ndim() == 0) return false;
|
||||
TShape oshape = dshape;
|
||||
oshape[2] = oshape[2] * param.scale;
|
||||
oshape[3] = oshape[3] * param.scale;
|
||||
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape);
|
||||
return true;
|
||||
}
|
||||
|
||||
NNVM_REGISTER_OP(upsampling)
|
||||
.describe(R"(Perform nearest neighbor upsampling to input array.
|
||||
|
||||
- **data**: Input is 4D array of shape (batch_size, channels, in_height, in_width).
|
||||
- **out**: Output is 4D array of shape (batch_size, channels, in_height*scale, in_width*scale).
|
||||
|
||||
)" NNVM_ADD_FILELINE)
|
||||
.add_argument("data", "4D Tensor", "Input data.")
|
||||
.add_arguments(UpSamplingParam::__FIELDS__())
|
||||
.set_attr_parser(ParamParser<UpSamplingParam>)
|
||||
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<UpSamplingParam>)
|
||||
.set_attr<FInferShape>("FInferShape", UpSamplingInferShape)
|
||||
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
|
||||
.set_num_outputs(1)
|
||||
.set_num_inputs(1)
|
||||
.set_support_level(2);
|
||||
|
||||
} // namespace top
|
||||
} // namespace nnvm
|
|
@ -148,6 +148,25 @@ def test_global_avg_pool2d():
|
|||
np.testing.assert_allclose(out.asnumpy(), b_np, rtol=1e-5)
|
||||
|
||||
|
||||
def test_upsampling():
|
||||
x = sym.Variable("x")
|
||||
scale = 2
|
||||
y = sym.upsampling(x, scale=scale, name="y")
|
||||
dtype = "float32"
|
||||
dshape = (1, 16, 32, 32)
|
||||
oshape = (1, 16, 32*scale, 32*scale)
|
||||
shape_dict = {"x": dshape}
|
||||
for target, ctx in ctx_list():
|
||||
graph, lib, _ = nnvm.compiler.build(y, target, shape_dict)
|
||||
m = graph_runtime.create(graph, lib, ctx)
|
||||
a_np = np.random.uniform(size=dshape).astype(dtype)
|
||||
data = tvm.nd.array(a_np)
|
||||
m.run(x=data)
|
||||
out = m.get_output(0, tvm.nd.empty(oshape, dtype))
|
||||
b_np = topi.testing.upsampling_python(a_np, scale)
|
||||
np.testing.assert_allclose(out.asnumpy(), b_np, rtol=1e-5)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_conv2d()
|
||||
test_grouped_conv2d()
|
||||
|
@ -156,3 +175,4 @@ if __name__ == "__main__":
|
|||
test_avg_pool2d()
|
||||
test_global_max_pool2d()
|
||||
test_global_avg_pool2d()
|
||||
test_upsampling()
|
||||
|
|
Загрузка…
Ссылка в новой задаче