[TOPI][IMAGE][RESIZE] Bilinear interpolation for resize and upsampling. (#1181)
This commit is contained in:
Родитель
758cb7519c
Коммит
76fa3ca4f6
|
@ -51,6 +51,7 @@ List of operators
|
|||
topi.broadcast_div
|
||||
topi.broadcast_maximum
|
||||
topi.broadcast_minimum
|
||||
topi.image.resize
|
||||
|
||||
|
||||
List of schedules
|
||||
|
@ -114,6 +115,10 @@ topi.nn
|
|||
.. autofunction:: topi.nn.depthwise_conv2d_nchw
|
||||
.. autofunction:: topi.nn.depthwise_conv2d_nhwc
|
||||
|
||||
topi.image
|
||||
~~~~~~~~~~
|
||||
.. autofunction:: topi.image.resize
|
||||
|
||||
|
||||
topi.generic
|
||||
~~~~~~~~~~~~
|
||||
|
|
|
@ -288,16 +288,22 @@ struct GlobalPool2DParam : public dmlc::Parameter<GlobalPool2DParam> {
|
|||
struct UpSamplingParam : public dmlc::Parameter<UpSamplingParam> {
|
||||
int scale;
|
||||
std::string layout;
|
||||
std::string method;
|
||||
|
||||
DMLC_DECLARE_PARAMETER(UpSamplingParam) {
|
||||
DMLC_DECLARE_FIELD(scale)
|
||||
.describe("upsampling scaling factor");
|
||||
DMLC_DECLARE_FIELD(layout)
|
||||
.set_default("NCHW")
|
||||
.describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc."
|
||||
.describe("Dimension ordering of data. Can be 'NCHW', 'NHWC', etc."
|
||||
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
|
||||
"dimensions respectively. Convolution is applied on the 'H' and"
|
||||
"dimensions respectively. Upsampling is applied on the 'H' and"
|
||||
"'W' dimensions.");
|
||||
DMLC_DECLARE_FIELD(method)
|
||||
.set_default("NEAREST_NEIGHBOR")
|
||||
.describe("Specify the mode to use for scaling."
|
||||
"NEAREST_NEIGHBOR - Nearest Neighbor"
|
||||
"BILINEAR - Bilinear Interpolation");
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -8,6 +8,7 @@ from . import nn
|
|||
from . import transform
|
||||
from . import reduction
|
||||
from . import vision
|
||||
from . import image
|
||||
|
||||
from .registry import OpPattern
|
||||
from .registry import register_compute, register_schedule, register_pattern
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
# pylint: disable=invalid-name, unused-argument
|
||||
"""Definition of image ops"""
|
||||
from __future__ import absolute_import
|
||||
|
||||
import topi
|
||||
import tvm
|
||||
from . import registry as reg
|
||||
from .registry import OpPattern
|
||||
|
||||
# resize
|
||||
@reg.register_schedule("resize")
|
||||
def schedule_resize(_, outs, target):
|
||||
"""Schedule definition of resize"""
|
||||
with tvm.target.create(target):
|
||||
return topi.generic.schedule_injective(outs)
|
||||
|
||||
reg.register_pattern("resize", OpPattern.INJECTIVE)
|
|
@ -235,20 +235,10 @@ def schedule_global_avg_pool2d(_, outs, target):
|
|||
|
||||
reg.register_pattern("global_avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
|
||||
|
||||
|
||||
@reg.register_compute("upsampling")
|
||||
def compute_upsampling(attrs, inputs, _):
|
||||
"""Compute definition of upsampling"""
|
||||
scale = attrs.get_int("scale")
|
||||
layout = attrs["layout"]
|
||||
if layout:
|
||||
assert layout == "NCHW" or layout == "NHWC"
|
||||
return topi.nn.upsampling(inputs[0], scale, layout)
|
||||
return topi.nn.upsampling(inputs[0], scale)
|
||||
|
||||
# upsampling
|
||||
@reg.register_schedule("upsampling")
|
||||
def schedule_upsampling(_, outs, target):
|
||||
"""Compute definition of upsampling"""
|
||||
"""Schedule definition of upsampling"""
|
||||
with tvm.target.create(target):
|
||||
return topi.generic.schedule_injective(outs)
|
||||
|
||||
|
|
|
@ -0,0 +1,113 @@
|
|||
/*!
|
||||
* Copyright (c) 2017 by Contributors
|
||||
* \file resize.cc
|
||||
* \brief Property def of resize operators.
|
||||
*/
|
||||
#include <tvm/tvm.h>
|
||||
#include <tvm/expr.h>
|
||||
#include <tvm/packed_func_ext.h>
|
||||
#include <nnvm/layout.h>
|
||||
#include <nnvm/compiler/op_attr_types.h>
|
||||
#include <nnvm/op.h>
|
||||
#include <nnvm/node.h>
|
||||
#include <nnvm/op_attr_types.h>
|
||||
#include "../nn/nn_common.h"
|
||||
#include "../op_common.h"
|
||||
#include "../elemwise_op_common.h"
|
||||
#include "topi/elemwise.h"
|
||||
#include "topi/transform.h"
|
||||
#include "topi/image/resize.h"
|
||||
#include "resize.h"
|
||||
|
||||
namespace nnvm {
|
||||
namespace top {
|
||||
using tvm::Expr;
|
||||
using tvm::Array;
|
||||
using tvm::Tensor;
|
||||
using nnvm::compiler::FTVMCompute;
|
||||
|
||||
DMLC_REGISTER_PARAMETER(ResizeParam);
|
||||
|
||||
inline bool ResizeInferShape(const nnvm::NodeAttrs& attrs,
|
||||
std::vector<TShape>* in_shape,
|
||||
std::vector<TShape>* out_shape) {
|
||||
static const Layout kNCHW("NCHW");
|
||||
const ResizeParam& param = nnvm::get<ResizeParam>(attrs.parsed);
|
||||
CHECK_EQ(in_shape->size(), 1U);
|
||||
CHECK_EQ(out_shape->size(), 1U);
|
||||
TShape dshape = (*in_shape)[0];
|
||||
if (dshape.ndim() == 0) return false;
|
||||
dshape = ConvertLayout(dshape, param.layout, kNCHW);
|
||||
|
||||
TShape oshape = dshape;
|
||||
if (param.layout == "NCHW") {
|
||||
oshape[2] = param.size[0];
|
||||
oshape[3] = param.size[1];
|
||||
} else {
|
||||
oshape[1] = param.size[0];
|
||||
oshape[2] = param.size[1];
|
||||
}
|
||||
oshape = ConvertLayout(oshape, kNCHW, param.layout);
|
||||
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
inline bool ResizeLayout(const NodeAttrs& attrs,
|
||||
std::vector<Layout> *in_layouts,
|
||||
const std::vector<Layout> *last_in_layouts,
|
||||
std::vector<Layout> *out_layouts) {
|
||||
const ResizeParam& param = nnvm::get<ResizeParam>(attrs.parsed);
|
||||
CHECK_EQ(in_layouts->size(), 1U);
|
||||
CHECK_EQ(out_layouts->size(), 1U);
|
||||
const Layout layout(param.layout);
|
||||
NNVM_ASSIGN_LAYOUT(*in_layouts, 0, layout);
|
||||
NNVM_ASSIGN_LAYOUT(*out_layouts, 0, layout);
|
||||
return true;
|
||||
}
|
||||
|
||||
NNVM_REGISTER_OP(resize)
|
||||
.describe(R"(Perform resize to input array with nearest neighbour or bilinear interpolation.
|
||||
|
||||
- **data**: data is 4D array of shape
|
||||
(batch_size, channels, in_height, in_width) for NCHW
|
||||
(batch_size, in_height, in_width, channels) for NHWC
|
||||
|
||||
- **out**: Output is 4D array of shape
|
||||
for layout NCHW
|
||||
(batch_size, channels, size[0], size[1])
|
||||
|
||||
for layout NHWC
|
||||
(batch_size, size[0], size[1], channels)
|
||||
|
||||
)" NNVM_ADD_FILELINE)
|
||||
.add_argument("data", "4D Tensor", "Input data.")
|
||||
.add_arguments(ResizeParam::__FIELDS__())
|
||||
.set_attr_parser(ParamParser<ResizeParam>)
|
||||
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ResizeParam>)
|
||||
.set_attr<FInferShape>("FInferShape", ResizeInferShape)
|
||||
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
|
||||
.set_attr<FCorrectLayout>("FCorrectLayout", ResizeLayout)
|
||||
.set_num_outputs(1)
|
||||
.set_num_inputs(1)
|
||||
.set_attr<FTVMCompute>(
|
||||
"FTVMCompute", [](const NodeAttrs& attrs,
|
||||
const Array<Tensor>& inputs,
|
||||
const Array<Tensor>& out_info) {
|
||||
const ResizeParam& param = nnvm::get<ResizeParam>(attrs.parsed);
|
||||
Array<Expr> oshape;
|
||||
if (param.layout == "NCHW") {
|
||||
oshape.push_back(out_info[0]->shape[2]);
|
||||
oshape.push_back(out_info[0]->shape[3]);
|
||||
} else {
|
||||
oshape.push_back(out_info[0]->shape[1]);
|
||||
oshape.push_back(out_info[0]->shape[2]);
|
||||
}
|
||||
|
||||
return Array<Tensor>{ topi::image::resize(inputs[0], oshape, param.layout,
|
||||
param.align_corners, param.method)};
|
||||
})
|
||||
.set_support_level(2);
|
||||
|
||||
} // namespace top
|
||||
} // namespace nnvm
|
|
@ -0,0 +1,45 @@
|
|||
/*!
|
||||
* Copyright (c) 2018 by Contributors
|
||||
* \file resize.h
|
||||
*/
|
||||
#ifndef NNVM_TOP_IMAGE_RESIZE_H_
|
||||
#define NNVM_TOP_IMAGE_RESIZE_H_
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
namespace nnvm {
|
||||
namespace top {
|
||||
|
||||
struct ResizeParam : public dmlc::Parameter<ResizeParam> {
|
||||
TShape size;
|
||||
std::string layout;
|
||||
std::string method;
|
||||
bool align_corners;
|
||||
|
||||
DMLC_DECLARE_PARAMETER(ResizeParam) {
|
||||
DMLC_DECLARE_FIELD(size)
|
||||
.describe("Output size");
|
||||
DMLC_DECLARE_FIELD(layout)
|
||||
.set_default("NCHW")
|
||||
.describe("Dimension ordering of data. Can be 'NCHW', 'NHWC', etc."
|
||||
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
|
||||
"dimensions respectively. Resize is applied on the 'H' and"
|
||||
"'W' dimensions.");
|
||||
DMLC_DECLARE_FIELD(method)
|
||||
.set_default("BILINEAR")
|
||||
.describe("Specify the mode to use for scaling."
|
||||
"NEAREST_NEIGHBOR - Nearest Neighbor"
|
||||
"BILINEAR - Bilinear Interpolation");
|
||||
DMLC_DECLARE_FIELD(align_corners)
|
||||
.set_default(false)
|
||||
.describe("Should be true to preserve the values at the corner pixels");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace top
|
||||
} // namespace nnvm
|
||||
#endif // NNVM_TOP_IMAGE_RESIZE_H_
|
|
@ -1,8 +1,12 @@
|
|||
/*!
|
||||
* Copyright (c) 2017 by Contributors
|
||||
* \file pooling.cc
|
||||
* \brief Property def of pooling operators.
|
||||
* \file upsampling.cc
|
||||
* \brief Property def of upsampling operators.
|
||||
*/
|
||||
#include <tvm/tvm.h>
|
||||
#include <tvm/expr.h>
|
||||
#include <nnvm/layout.h>
|
||||
#include <nnvm/compiler/op_attr_types.h>
|
||||
#include <nnvm/op.h>
|
||||
#include <nnvm/node.h>
|
||||
#include <nnvm/op_attr_types.h>
|
||||
|
@ -10,27 +14,36 @@
|
|||
#include "./nn_common.h"
|
||||
#include "../op_common.h"
|
||||
#include "../elemwise_op_common.h"
|
||||
#include "topi/elemwise.h"
|
||||
#include "topi/transform.h"
|
||||
#include "topi/nn/upsampling.h"
|
||||
|
||||
namespace nnvm {
|
||||
namespace top {
|
||||
using tvm::Expr;
|
||||
using tvm::Array;
|
||||
using tvm::Tensor;
|
||||
using nnvm::compiler::FTVMCompute;
|
||||
|
||||
DMLC_REGISTER_PARAMETER(UpSamplingParam);
|
||||
|
||||
inline bool UpSamplingInferShape(const nnvm::NodeAttrs& attrs,
|
||||
std::vector<TShape>* in_shape,
|
||||
std::vector<TShape>* out_shape) {
|
||||
std::vector<TShape>* in_shape,
|
||||
std::vector<TShape>* out_shape) {
|
||||
static const Layout kNCHW("NCHW");
|
||||
const UpSamplingParam& param = nnvm::get<UpSamplingParam>(attrs.parsed);
|
||||
CHECK_EQ(in_shape->size(), 1U);
|
||||
CHECK_EQ(out_shape->size(), 1U);
|
||||
TShape dshape = (*in_shape)[0];
|
||||
if (dshape.ndim() == 0) return false;
|
||||
|
||||
dshape = ConvertLayout(dshape, param.layout, kNCHW);
|
||||
TShape oshape = dshape;
|
||||
oshape[2] = oshape[2] * param.scale;
|
||||
oshape[3] = oshape[3] * param.scale;
|
||||
oshape = ConvertLayout(oshape, kNCHW, param.layout);
|
||||
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -48,10 +61,18 @@ inline bool UpsamplingLayout(const NodeAttrs& attrs,
|
|||
}
|
||||
|
||||
NNVM_REGISTER_OP(upsampling)
|
||||
.describe(R"(Perform nearest neighbor upsampling to input array.
|
||||
.describe(R"(Perform upsampling to input array with nearest neighbour or bilinear interpolation.
|
||||
|
||||
- **data**: Input is 4D array of shape (batch_size, channels, in_height, in_width).
|
||||
- **out**: Output is 4D array of shape (batch_size, channels, in_height*scale, in_width*scale).
|
||||
- **data**: data is 4D array of shape
|
||||
(batch_size, channels, in_height, in_width) for NCHW
|
||||
(batch_size, in_height, in_width, channels) for NHWC
|
||||
|
||||
- **out**: Output is 4D array of shape
|
||||
for layout NCHW
|
||||
(batch_size, channels, in_height*scale, in_width*scale)
|
||||
|
||||
for layout NHWC
|
||||
(batch_size, in_height*scale, in_width*scale, channels)
|
||||
|
||||
)" NNVM_ADD_FILELINE)
|
||||
.add_argument("data", "4D Tensor", "Input data.")
|
||||
|
@ -63,6 +84,22 @@ NNVM_REGISTER_OP(upsampling)
|
|||
.set_attr<FCorrectLayout>("FCorrectLayout", UpsamplingLayout)
|
||||
.set_num_outputs(1)
|
||||
.set_num_inputs(1)
|
||||
.set_attr<FTVMCompute>(
|
||||
"FTVMCompute", [](const NodeAttrs& attrs,
|
||||
const Array<Tensor>& inputs,
|
||||
const Array<Tensor>& out_info) {
|
||||
const UpSamplingParam& param = nnvm::get<UpSamplingParam>(attrs.parsed);
|
||||
Array<Expr> oshape;
|
||||
if (param.layout == "NCHW") {
|
||||
oshape.push_back(out_info[0]->shape[2]);
|
||||
oshape.push_back(out_info[0]->shape[3]);
|
||||
} else {
|
||||
oshape.push_back(out_info[0]->shape[1]);
|
||||
oshape.push_back(out_info[0]->shape[2]);
|
||||
}
|
||||
|
||||
return Array<Tensor>{ topi::nn::upsampling(inputs[0], oshape, param.layout, param.method)};
|
||||
})
|
||||
.set_support_level(2);
|
||||
|
||||
} // namespace top
|
||||
|
|
|
@ -210,7 +210,7 @@ def test_global_avg_pool2d():
|
|||
np.testing.assert_allclose(out.asnumpy(), b_np, rtol=1e-5)
|
||||
|
||||
|
||||
def test_upsampling():
|
||||
def test_upsampling_nearest_neighbor():
|
||||
x = sym.Variable("x")
|
||||
scale = 2
|
||||
y = sym.upsampling(x, scale=scale, name="y")
|
||||
|
@ -225,9 +225,46 @@ def test_upsampling():
|
|||
data = tvm.nd.array(a_np)
|
||||
m.run(x=data)
|
||||
out = m.get_output(0, tvm.nd.empty(oshape, dtype))
|
||||
b_np = topi.testing.upsampling_python(a_np, scale)
|
||||
b_np = topi.testing.upsampling_python(a_np, scale, "NCHW")
|
||||
np.testing.assert_allclose(out.asnumpy(), b_np, rtol=1e-5)
|
||||
|
||||
def test_upsampling_bilinear():
|
||||
x = sym.Variable("x")
|
||||
scale = 2
|
||||
y = sym.upsampling(x, scale=scale, method="BILINEAR", name="y", layout="NCHW")
|
||||
dtype = "float32"
|
||||
dshape = (1, 4, 32, 32)
|
||||
oshape = (1, 4, 32*scale, 32*scale)
|
||||
shape_dict = {"x": dshape}
|
||||
dtype_dict = {"x": dtype}
|
||||
for target, ctx in ctx_list():
|
||||
graph, lib, _ = nnvm.compiler.build(y, target, shape_dict, dtype_dict)
|
||||
m = graph_runtime.create(graph, lib, ctx)
|
||||
a_np = np.random.uniform(size=dshape).astype(dtype)
|
||||
data = tvm.nd.array(a_np)
|
||||
m.run(x=data)
|
||||
out = m.get_output(0, tvm.nd.empty(oshape, dtype))
|
||||
b_np = topi.testing.bilinear_resize_python(a_np, (32*scale, 32*scale), "NCHW")
|
||||
np.testing.assert_allclose(out.asnumpy(), b_np, rtol=1e-5, atol=1e-5)
|
||||
|
||||
def test_resize_bilinear():
|
||||
x = sym.Variable("x")
|
||||
scale = 2
|
||||
y = sym.upsampling(x, scale=scale, method="BILINEAR", name="y", layout="NHWC")
|
||||
dtype = "float32"
|
||||
dshape = (1, 32, 32, 4)
|
||||
oshape = (1, 32*scale, 32*scale, 4)
|
||||
shape_dict = {"x": dshape}
|
||||
dtype_dict = {"x": dtype}
|
||||
for target, ctx in ctx_list():
|
||||
graph, lib, _ = nnvm.compiler.build(y, target, shape_dict, dtype_dict)
|
||||
m = graph_runtime.create(graph, lib, ctx)
|
||||
a_np = np.random.uniform(size=dshape).astype(dtype)
|
||||
data = tvm.nd.array(a_np)
|
||||
m.run(x=data)
|
||||
out = m.get_output(0, tvm.nd.empty(oshape, dtype))
|
||||
b_np = topi.testing.bilinear_resize_python(a_np, (32*scale, 32*scale), "NHWC")
|
||||
np.testing.assert_allclose(out.asnumpy(), b_np, rtol=1e-5, atol=1e-5)
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_conv2d()
|
||||
|
@ -239,4 +276,6 @@ if __name__ == "__main__":
|
|||
test_avg_pool2d_no_count_pad()
|
||||
test_global_max_pool2d()
|
||||
test_global_avg_pool2d()
|
||||
test_upsampling()
|
||||
test_upsampling_nearest_neighbor()
|
||||
test_upsampling_bilinear()
|
||||
test_resize_bilinear()
|
||||
|
|
|
@ -0,0 +1,316 @@
|
|||
/*!
|
||||
* Copyright (c) 2017 by Contributors
|
||||
* \file topi/image/resize.h
|
||||
* \brief image resize constructors
|
||||
*/
|
||||
#ifndef TOPI_IMAGE_RESIZE_H_
|
||||
#define TOPI_IMAGE_RESIZE_H_
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <iterator>
|
||||
#include <algorithm>
|
||||
|
||||
#include "topi/tags.h"
|
||||
#include "topi/detail/ravel_unravel.h"
|
||||
#include "topi/detail/constant_utils.h"
|
||||
#include "tvm/tvm.h"
|
||||
|
||||
namespace topi {
|
||||
namespace image {
|
||||
using namespace tvm;
|
||||
|
||||
/*!
|
||||
* \brief Resize given tensor to given shape using nearest neighbour for NHWC
|
||||
*
|
||||
* \param input The input tensor.
|
||||
* \param shape Output shape to resize to.
|
||||
* \param align_corners To preserve centers of 4 corner pixels
|
||||
* \param name Name of the operation
|
||||
* \param tag The tag to mark the operation
|
||||
*
|
||||
* \return A Tensor resized to given shape
|
||||
*/
|
||||
inline Tensor resize_nearest_neighbor_nhwc(const Tensor& input,
|
||||
const Array<Expr>& shape,
|
||||
bool align_corners = false,
|
||||
std::string name = "tensor",
|
||||
std::string tag = kInjective) {
|
||||
Array<Expr> out_shape;
|
||||
out_shape.push_back(input->shape[0]);
|
||||
out_shape.push_back(shape[0]);
|
||||
out_shape.push_back(shape[1]);
|
||||
out_shape.push_back(input->shape[3]);
|
||||
|
||||
Expr h_ratio = shape[0] / input->shape[1];
|
||||
Expr w_ratio = shape[1] / input->shape[2];
|
||||
|
||||
return compute(
|
||||
out_shape, [&](const Array<Var>& indices) {
|
||||
Array<Expr> idx;
|
||||
idx.push_back(indices[0]);
|
||||
idx.push_back(indices[1] / h_ratio);
|
||||
idx.push_back(indices[2] / w_ratio);
|
||||
idx.push_back(indices[3]);
|
||||
|
||||
return input(idx);
|
||||
}, name, tag);
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Resize given tensor to given shape using nearest neighbour for NCHW
|
||||
*
|
||||
* \param input The input tensor.
|
||||
* \param shape Output shape to resize to.
|
||||
* \param align_corners To preserve centers of 4 corner pixels
|
||||
* \param name Name of the operation
|
||||
* \param tag The tag to mark the operation
|
||||
*
|
||||
* \return A Tensor resized to given shape
|
||||
*/
|
||||
inline Tensor resize_nearest_neighbor_nchw(const Tensor& input,
|
||||
const Array<Expr>& shape,
|
||||
bool align_corners = false,
|
||||
std::string name = "tensor",
|
||||
std::string tag = kInjective) {
|
||||
Array<Expr> out_shape;
|
||||
out_shape.push_back(input->shape[0]);
|
||||
out_shape.push_back(input->shape[1]);
|
||||
out_shape.push_back(shape[0]);
|
||||
out_shape.push_back(shape[1]);
|
||||
|
||||
Expr h_ratio = shape[0] / input->shape[2];
|
||||
Expr w_ratio = shape[1] / input->shape[3];
|
||||
|
||||
return compute(
|
||||
out_shape, [&](const Array<Var>& indices) {
|
||||
Array<Expr> idx;
|
||||
idx.push_back(indices[0]);
|
||||
idx.push_back(indices[1]);
|
||||
idx.push_back(indices[2] / h_ratio);
|
||||
idx.push_back(indices[3] / w_ratio);
|
||||
|
||||
return input(idx);
|
||||
}, name, tag);
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Resize given tensor to given shape using nearest neighbour
|
||||
*
|
||||
* \param input The input tensor.
|
||||
* \param shape Output shape to resize to.
|
||||
* \param layout input layout
|
||||
* \param align_corners To preserve centers of 4 corner pixels
|
||||
* \param name Name of the operation
|
||||
* \param tag The tag to mark the operation
|
||||
*
|
||||
* \return A Tensor resized to given shape
|
||||
*/
|
||||
inline Tensor resize_nearest_neighbor(const Tensor& input,
|
||||
const Array<Expr>& shape,
|
||||
std::string layout = "NCHW",
|
||||
bool align_corners = false,
|
||||
std::string name = "tensor",
|
||||
std::string tag = kInjective) {
|
||||
CHECK_EQ(align_corners, false) << "Align corners not supported for nearest neighbour";
|
||||
|
||||
if (layout == "NHWC") {
|
||||
return resize_nearest_neighbor_nhwc(input, shape, align_corners);
|
||||
} else {
|
||||
return resize_nearest_neighbor_nchw(input, shape, align_corners);
|
||||
}
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Resize given tensor to given shape using bilinear interpolation for NHWC
|
||||
*
|
||||
* \param input The input tensor.
|
||||
* \param shape Output shape to resize to.
|
||||
* \param align_corners To preserve centers of 4 corner pixels
|
||||
* \param name Name of the operation
|
||||
* \param tag The tag to mark the operation
|
||||
*
|
||||
* \return A Tensor resized to given shape
|
||||
*/
|
||||
inline Tensor resize_bilinear_nhwc(const Tensor& input,
|
||||
const Array<Expr>& shape,
|
||||
bool align_corners = false,
|
||||
std::string name = "tensor",
|
||||
std::string tag = kInjective) {
|
||||
Array<Expr> out_shape;
|
||||
out_shape.push_back(input->shape[0]);
|
||||
out_shape.push_back(shape[0]);
|
||||
out_shape.push_back(shape[1]);
|
||||
out_shape.push_back(input->shape[3]);
|
||||
|
||||
Expr cone = make_const(Int(32), 1);
|
||||
|
||||
auto in_height = as_const_int(input->shape[1]);
|
||||
auto in_width = as_const_int(input->shape[2]);
|
||||
auto out_height = as_const_int(shape[0]);
|
||||
auto out_width = as_const_int(shape[1]);
|
||||
|
||||
Expr y_ratio;
|
||||
Expr x_ratio;
|
||||
|
||||
if (align_corners) {
|
||||
y_ratio = make_const(Float(32), (static_cast<float>(*in_height) /
|
||||
static_cast<float>(*out_height)));
|
||||
x_ratio = make_const(Float(32), (static_cast<float>(*in_width) /
|
||||
static_cast<float>(*out_width)));
|
||||
} else {
|
||||
y_ratio = make_const(Float(32), (static_cast<float>(*in_height - 1) /
|
||||
static_cast<float>(*out_height - 1)));
|
||||
x_ratio = make_const(Float(32), (static_cast<float>(*in_width - 1) /
|
||||
static_cast<float>(*out_width - 1)));
|
||||
}
|
||||
|
||||
Expr other_y = tvm::ir::Simplify(input->shape[1] - cone);
|
||||
Expr other_x = tvm::ir::Simplify(input->shape[2] - cone);
|
||||
|
||||
return compute(
|
||||
out_shape, [&](const Array<Var>& indices) {
|
||||
auto y0 = HalideIR::Internal::Cast::make(Int(32), tvm::floor(y_ratio * indices[1]));
|
||||
auto x0 = HalideIR::Internal::Cast::make(Int(32), tvm::floor(x_ratio * indices[2]));
|
||||
|
||||
auto y1 = tvm::select(((y0 + cone) > other_y), other_y, (y0 + cone));
|
||||
auto x1 = tvm::select(((x0 + cone) > other_x), other_x, (x0 + cone));
|
||||
|
||||
auto h = (y_ratio * indices[1]) - y0;
|
||||
auto w = (x_ratio * indices[2]) - x0;;
|
||||
|
||||
auto A = input(indices[0], y0, x0, indices[3]);
|
||||
auto B = input(indices[0], y0, x1, indices[3]);
|
||||
auto C = input(indices[0], y1, x0, indices[3]);
|
||||
auto D = input(indices[0], y1, x1, indices[3]);
|
||||
|
||||
return (A*(cone-w)*(cone-h) + B*(w)*(cone-h) + C*(h)*(cone-w) + D*w*h);
|
||||
}, name, tag);
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Resize given tensor to given shape using bilinear interpolation for NCHW
|
||||
*
|
||||
* \param input The input tensor.
|
||||
* \param shape Output shape to resize to.
|
||||
* \param align_corners To preserve centers of 4 corner pixels
|
||||
* \param name Name of the operation
|
||||
* \param tag The tag to mark the operation
|
||||
*
|
||||
* \return A Tensor resized to given shape
|
||||
*/
|
||||
inline Tensor resize_bilinear_nchw(const Tensor& input,
|
||||
const Array<Expr>& shape,
|
||||
bool align_corners = false,
|
||||
std::string name = "tensor",
|
||||
std::string tag = kInjective) {
|
||||
Array<Expr> out_shape;
|
||||
out_shape.push_back(input->shape[0]);
|
||||
out_shape.push_back(input->shape[1]);
|
||||
out_shape.push_back(shape[0]);
|
||||
out_shape.push_back(shape[1]);
|
||||
|
||||
Expr cone = make_const(Int(32), 1);
|
||||
|
||||
auto in_height = as_const_int(input->shape[2]);
|
||||
auto in_width = as_const_int(input->shape[3]);
|
||||
auto out_height = as_const_int(shape[0]);
|
||||
auto out_width = as_const_int(shape[1]);
|
||||
|
||||
Expr y_ratio;
|
||||
Expr x_ratio;
|
||||
|
||||
if (align_corners) {
|
||||
y_ratio = make_const(Float(32), (static_cast<float>(*in_height) /
|
||||
static_cast<float>(*out_height)));
|
||||
x_ratio = make_const(Float(32), (static_cast<float>(*in_width) /
|
||||
static_cast<float>(*out_width)));
|
||||
} else {
|
||||
y_ratio = make_const(Float(32), (static_cast<float>(*in_height - 1) /
|
||||
static_cast<float>(*out_height - 1)));
|
||||
x_ratio = make_const(Float(32), (static_cast<float>(*in_width - 1) /
|
||||
static_cast<float>(*out_width - 1)));
|
||||
}
|
||||
|
||||
Expr other_y = tvm::ir::Simplify(input->shape[2] - cone);
|
||||
Expr other_x = tvm::ir::Simplify(input->shape[3] - cone);
|
||||
|
||||
return compute(
|
||||
out_shape, [&](const Array<Var>& indices) {
|
||||
auto y0 = HalideIR::Internal::Cast::make(Int(32), tvm::floor(y_ratio * indices[2]));
|
||||
auto x0 = HalideIR::Internal::Cast::make(Int(32), tvm::floor(x_ratio * indices[3]));
|
||||
|
||||
auto y1 = tvm::select(((y0 + cone) > other_y), other_y, (y0 + cone));
|
||||
auto x1 = tvm::select(((x0 + cone) > other_x), other_x, (x0 + cone));
|
||||
|
||||
auto h = (y_ratio * indices[2]) - y0;
|
||||
auto w = (x_ratio * indices[3]) - x0;;
|
||||
|
||||
auto A = input(indices[0], indices[1], y0, x0);
|
||||
auto B = input(indices[0], indices[1], y0, x1);
|
||||
auto C = input(indices[0], indices[1], y1, x0);
|
||||
auto D = input(indices[0], indices[1], y1, x1);
|
||||
|
||||
return ((A*(cone-w)*(cone-h)) + (B*(w)*(cone-h)) + (C*(h)*(cone-w)) + (D*w*h));
|
||||
}, name, tag);
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Resize given tensor to given shape using bilinear interpolation
|
||||
*
|
||||
* \param input The input tensor.
|
||||
* \param shape Output shape to resize to.
|
||||
* \param layout input layout
|
||||
* \param align_corners To preserve centers of 4 corner pixels
|
||||
* \param name Name of the operation
|
||||
* \param tag The tag to mark the operation
|
||||
*
|
||||
* \return A Tensor resized to given shape
|
||||
*/
|
||||
inline Tensor resize_bilinear(const Tensor& input,
|
||||
const Array<Expr>& shape,
|
||||
std::string layout = "NCHW",
|
||||
bool align_corners = false,
|
||||
std::string name = "tensor",
|
||||
std::string tag = kInjective) {
|
||||
Tensor ret;
|
||||
|
||||
if (layout == "NHWC") {
|
||||
ret = resize_bilinear_nhwc(input, shape, align_corners);
|
||||
} else {
|
||||
ret = resize_bilinear_nchw(input, shape, align_corners);
|
||||
}
|
||||
|
||||
return cast(ret, input->dtype);
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Resize given tensor to given shape
|
||||
*
|
||||
* \param input The input tensor.
|
||||
* \param shape Output shape to resize to.
|
||||
* \param layout input layout
|
||||
* \param align_corners To preserve centers of 4 corner pixels
|
||||
* \param mode Angorithm to use (NEAREST_NEIGHBOR / BILINEAR)
|
||||
* \param name Name of the operation
|
||||
* \param tag The tag to mark the operation
|
||||
*
|
||||
* \return A Tensor resized to given shape
|
||||
*/
|
||||
inline Tensor resize(const Tensor& input,
|
||||
const Array<Expr>& shape,
|
||||
std::string layout = "NCHW",
|
||||
bool align_corners = false,
|
||||
std::string mode = "BILINEAR",
|
||||
std::string name = "tensor",
|
||||
std::string tag = kInjective) {
|
||||
if (mode == "NEAREST_NEIGHBOR") {
|
||||
return resize_nearest_neighbor(input, shape, layout, align_corners);
|
||||
} else {
|
||||
return resize_bilinear(input, shape, layout, align_corners);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace image
|
||||
} // namespace topi
|
||||
#endif // TOPI_IMAGE_RESIZE_H_
|
|
@ -0,0 +1,44 @@
|
|||
/*!
|
||||
* Copyright (c) 2017 by Contributors
|
||||
* \file topi/nn/upsampling.h
|
||||
* \brief upsampling op constructors
|
||||
*/
|
||||
#ifndef TOPI_NN_UPSAMPLING_H_
|
||||
#define TOPI_NN_UPSAMPLING_H_
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <iterator>
|
||||
#include <algorithm>
|
||||
|
||||
#include "topi/image/resize.h"
|
||||
|
||||
namespace topi {
|
||||
namespace nn {
|
||||
using namespace tvm;
|
||||
using namespace topi::image;
|
||||
|
||||
/*!
|
||||
* \brief Upsample given tensor to given shape
|
||||
*
|
||||
* \param input The input tensor.
|
||||
* \param shape Output shape to upsample.
|
||||
* \param layout input layout
|
||||
* \param mode Angorithm to use (NEAREST_NEIGHBOR / BILINEAR)
|
||||
* \param name Name of the operation
|
||||
* \param tag The tag to mark the operation
|
||||
*
|
||||
* \return A Tensor upsampled to given shape
|
||||
*/
|
||||
inline Tensor upsampling(const Tensor& input,
|
||||
const Array<Expr> shape,
|
||||
std::string layout = "NCHW",
|
||||
std::string mode = "NEAREST_NEIGHBOR",
|
||||
std::string name = "tensor",
|
||||
std::string tag = kInjective) {
|
||||
return resize(input, shape, layout, false, mode);
|
||||
}
|
||||
|
||||
} // namespace nn
|
||||
} // namespace topi
|
||||
#endif // TOPI_NN_UPSAMPLING_H_
|
|
@ -30,6 +30,7 @@ from . import opengl
|
|||
from . import util
|
||||
from . import rocm
|
||||
from . import vision
|
||||
from . import image
|
||||
# not import testing by default
|
||||
# because testing can have extra deps that are not necessary
|
||||
# we can import them from test cases explicitly
|
||||
|
|
|
@ -50,6 +50,8 @@ vision = _create_module("vision")
|
|||
_init_api_prefix("topi.cpp.vision", "topi.vision")
|
||||
yolo2 = _create_module("vision.yolo2")
|
||||
_init_api_prefix("topi.cpp.vision.yolo2", "topi.vision.yolo2")
|
||||
image = _create_module("image")
|
||||
_init_api_prefix("topi.cpp.image", "topi.image")
|
||||
|
||||
class IntVector(object):
|
||||
"""Handle to std::vector<int> instance """
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
# pylint: disable=wildcard-import
|
||||
"""IMAGE network operators"""
|
||||
from __future__ import absolute_import as _abs
|
||||
|
||||
from .resize import *
|
|
@ -0,0 +1,33 @@
|
|||
"""TVM operator input resize compute."""
|
||||
from __future__ import absolute_import
|
||||
import topi
|
||||
|
||||
def resize(data, size, layout="NCHW", align_corners=False, method="BILINEAR"):
|
||||
"""Perform resize operation on the data.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inputs : tvm.Tensor
|
||||
inputs is a 4-D tensor with shape
|
||||
[batch, channel, in_height, in_width]
|
||||
or [batch, in_height, in_width, channel]
|
||||
|
||||
size: Tuple
|
||||
Output resolution scale to
|
||||
|
||||
layout: string, optional
|
||||
either "NCHW" or "NHWC"
|
||||
|
||||
align_corners: Boolean, optional
|
||||
To preserve the values at the corner pixels
|
||||
|
||||
method: {"BILINEAR", "NEAREST_NEIGHBOR"}
|
||||
Method to be used for resizing.
|
||||
|
||||
Returns
|
||||
-------
|
||||
output : tvm.Tensor
|
||||
4-D with shape [batch, channel, in_height*scale, in_width*scale]
|
||||
or [batch, in_height*scale, in_width*scale, channel]
|
||||
"""
|
||||
return topi.cpp.image.resize(data, size, layout, align_corners, method)
|
|
@ -1,25 +1,28 @@
|
|||
"""TVM operator upsampling compute."""
|
||||
from __future__ import absolute_import
|
||||
import tvm
|
||||
from .. import util
|
||||
import topi
|
||||
|
||||
|
||||
def upsampling(data, scale, layout="NCHW"):
|
||||
"""Perform nearest neighbor upsampling on the data.
|
||||
Bilinear upsampling is not supported.
|
||||
def upsampling(data, scale, layout="NCHW", method='NEAREST_NEIGHBOR'):
|
||||
"""Perform upsampling on the data.
|
||||
Nearest neighbor and bilinear upsampling are supported.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : tvm.Tensor
|
||||
4-D with shape [batch, channel, in_height, in_width]
|
||||
inputs : tvm.Tensor
|
||||
inputs is a 4-D tensor with shape
|
||||
[batch, channel, in_height, in_width]
|
||||
or [batch, in_height, in_width, channel]
|
||||
|
||||
scale: int
|
||||
upsampling scaling factor
|
||||
scale : int
|
||||
Scaling factor
|
||||
|
||||
layout: string
|
||||
layout : string, optional
|
||||
either "NCHW" or "NHWC"
|
||||
|
||||
method : {"BILINEAR", "NEAREST_NEIGHBOR"}
|
||||
Method to be used for upsampling.
|
||||
|
||||
Returns
|
||||
-------
|
||||
output : tvm.Tensor
|
||||
|
@ -28,53 +31,10 @@ def upsampling(data, scale, layout="NCHW"):
|
|||
"""
|
||||
|
||||
if layout == "NCHW":
|
||||
return upsampling_nchw(data, scale)
|
||||
out_shape = (data.shape[2] * scale, data.shape[3] * scale)
|
||||
elif layout == "NHWC":
|
||||
return upsampling_nhwc(data, scale)
|
||||
out_shape = (data.shape[1] * scale, data.shape[2] * scale)
|
||||
else:
|
||||
raise ValueError("not support this layout {} yet".format(layout))
|
||||
|
||||
|
||||
def upsampling_nchw(data, scale):
|
||||
"""Perform nearest neighor upsampling on NCHW layout input.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : tvm.Tensor
|
||||
4-D with shape [batch, channel, in_height, in_width]
|
||||
|
||||
scale: int
|
||||
upsampling scaling factor
|
||||
|
||||
Returns
|
||||
-------
|
||||
output : tvm.Tensor
|
||||
4-D with shape [batch, channel, in_height*scale, in_width*scale]
|
||||
"""
|
||||
batch, channel, height, width = data.shape
|
||||
out_height = util.simplify(height * scale)
|
||||
out_width = util.simplify(width * scale)
|
||||
|
||||
return tvm.compute((batch, channel, out_height, out_width), \
|
||||
lambda n, c, h, w: data[n, c, h/scale, w/scale])
|
||||
|
||||
|
||||
def upsampling_nhwc(data, scale):
|
||||
"""Perform nearest neighor upsampling on NHWC layout input.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : tvm.Tensor
|
||||
4-D with shape [batch, in_height, in_width, channel]
|
||||
|
||||
scale: int
|
||||
upsampling scaling factor
|
||||
|
||||
"""
|
||||
|
||||
batch, height, width, channel = data.shape
|
||||
out_height = util.simplify(height * scale)
|
||||
out_width = util.simplify(width * scale)
|
||||
|
||||
return tvm.compute((batch, out_height, out_width, channel), \
|
||||
lambda n, h, w, c: data[n, h/scale, w/scale, c])
|
||||
return topi.cpp.nn.upsampling(data, out_shape, layout, method)
|
||||
|
|
|
@ -12,6 +12,7 @@ from .depthwise_conv2d_python import depthwise_conv2d_python_nchw, depthwise_con
|
|||
from .dilate_python import dilate_python
|
||||
from .softmax_python import softmax_python, log_softmax_python
|
||||
from .upsampling_python import upsampling_python
|
||||
from .bilinear_resize_python import bilinear_resize_python
|
||||
from .reorg_python import reorg_python
|
||||
from .region_python import region_python
|
||||
from .shortcut_python import shortcut_python
|
||||
|
|
|
@ -0,0 +1,79 @@
|
|||
# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
|
||||
"""Bilinear Scale in python"""
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
def bilinear_weights(height, width, new_h, new_w, align_corners=False):
|
||||
""" Helper function to generate weights for bilinear scaling """
|
||||
|
||||
if align_corners:
|
||||
x_ratio = np.float32(np.float32(width)/np.float32(new_w))
|
||||
y_ratio = np.float32(np.float32(height)/np.float32(new_h))
|
||||
else:
|
||||
x_ratio = np.float32(np.float32(width-1)/np.float32(new_w-1))
|
||||
y_ratio = np.float32(np.float32(height-1)/np.float32(new_h-1))
|
||||
|
||||
def _bilinear_interpolation(y, x):
|
||||
x_coord = math.floor(x_ratio * x)
|
||||
y_coord = math.floor(y_ratio * y)
|
||||
x_diff = np.float32((x_ratio * x) - x_coord)
|
||||
y_diff = np.float32((y_ratio * y) - y_coord)
|
||||
|
||||
return [y_coord, x_coord, y_diff, x_diff]
|
||||
|
||||
# weights to hold (srcx, srcy, x_diff, y_diff) for each out value.
|
||||
weights = np.empty([new_h, new_w, 4], dtype='float32')
|
||||
|
||||
for i in range(new_h):
|
||||
for j in range(new_w):
|
||||
weights[i][j] = _bilinear_interpolation(i, j)
|
||||
return weights
|
||||
|
||||
def bilinear_resize_python(image, out_size, layout, align_corners=False):
|
||||
""" Bilinear scaling using python"""
|
||||
(new_h, new_w) = out_size
|
||||
|
||||
if layout == 'NHWC':
|
||||
(batch, h, w, channel) = image.shape
|
||||
scaled_image = np.ones((batch, new_h, new_w, channel))
|
||||
else:
|
||||
(batch, channel, h, w) = image.shape
|
||||
scaled_image = np.ones((batch, channel, new_h, new_w))
|
||||
|
||||
weights = bilinear_weights(h, w, new_h, new_w, align_corners)
|
||||
|
||||
for b in range(batch):
|
||||
for i in range(channel):
|
||||
for j in range(new_h):
|
||||
for k in range(new_w):
|
||||
y0 = int(weights[j][k][0])
|
||||
x0 = int(weights[j][k][1])
|
||||
|
||||
x1 = min((x0+1), (w-1))
|
||||
y1 = min((y0+1), (h-1))
|
||||
|
||||
y_diff = weights[j][k][2]
|
||||
x_diff = weights[j][k][3]
|
||||
|
||||
if layout == 'NHWC':
|
||||
A = image[b][y0][x0][i]
|
||||
B = image[b][y0][x1][i]
|
||||
C = image[b][y1][x0][i]
|
||||
D = image[b][y1][x1][i]
|
||||
else:
|
||||
A = image[b][i][y0][x0]
|
||||
B = image[b][i][y0][x1]
|
||||
C = image[b][i][y1][x0]
|
||||
D = image[b][i][y1][x1]
|
||||
|
||||
pixel = np.float32((A*(1-x_diff)*(1-y_diff) +
|
||||
B*(x_diff)*(1-y_diff) +
|
||||
C*(y_diff)*(1-x_diff) +
|
||||
D*(x_diff*y_diff)))
|
||||
|
||||
if layout == 'NHWC':
|
||||
scaled_image[b][j][k][i] = pixel
|
||||
else:
|
||||
scaled_image[b][i][j][k] = pixel
|
||||
|
||||
return scaled_image
|
|
@ -3,13 +3,26 @@
|
|||
import numpy as np
|
||||
|
||||
def upsample_nearest(arr, scale):
|
||||
""" Populate the array by scale factor"""
|
||||
return arr.repeat(scale, axis=0).repeat(scale, axis=1)
|
||||
|
||||
def upsampling_python(data, scale):
|
||||
def upsampling_python(data, scale, layout='NCHW'):
|
||||
""" Python version of scaling using nearest neighbour """
|
||||
|
||||
ishape = data.shape
|
||||
oshape = (ishape[0], ishape[1], ishape[2]*scale, ishape[3]*scale)
|
||||
output_np = np.zeros(oshape, dtype=data.dtype)
|
||||
for b in range(oshape[0]):
|
||||
for c in range(oshape[1]):
|
||||
output_np[b, c, :, :] = upsample_nearest(data[b, c, :, :], scale)
|
||||
return output_np
|
||||
if layout == 'NCHW':
|
||||
oshape = (ishape[0], ishape[1], ishape[2]*scale, ishape[3]*scale)
|
||||
output_np = np.zeros(oshape, dtype=data.dtype)
|
||||
for b in range(oshape[0]):
|
||||
for c in range(oshape[1]):
|
||||
output_np[b, c, :, :] = upsample_nearest(data[b, c, :, :], scale)
|
||||
return output_np
|
||||
elif layout == 'NHWC':
|
||||
oshape = (ishape[0], ishape[1]*scale, ishape[1]*scale, ishape[3])
|
||||
output_np = np.zeros(oshape, dtype=data.dtype)
|
||||
for b in range(oshape[0]):
|
||||
for c in range(oshape[3]):
|
||||
output_np[b, :, :, c] = upsample_nearest(data[b, :, :, c], scale)
|
||||
return output_np
|
||||
else:
|
||||
raise ValueError("not support this layout {} yet".format(layout))
|
||||
|
|
|
@ -23,8 +23,10 @@
|
|||
#include <topi/nn/mapping.h>
|
||||
#include <topi/nn/pooling.h>
|
||||
#include <topi/nn/softmax.h>
|
||||
#include <topi/nn/upsampling.h>
|
||||
|
||||
#include <topi/vision/reorg.h>
|
||||
#include <topi/image/resize.h>
|
||||
#include <topi/vision/yolo2/region.h>
|
||||
#include <topi/generic/default.h>
|
||||
#include <topi/generic/extern.h>
|
||||
|
@ -285,6 +287,12 @@ TVM_REGISTER_GLOBAL("topi.strided_slice")
|
|||
*rv = strided_slice(args[0], args[1], args[2], args[3]);
|
||||
});
|
||||
|
||||
/* Ops from nn/upsampling.h */
|
||||
TVM_REGISTER_GLOBAL("topi.nn.upsampling")
|
||||
.set_body([](TVMArgs args, TVMRetValue *rv) {
|
||||
*rv = nn::upsampling(args[0], args[1], args[2], args[3]);
|
||||
});
|
||||
|
||||
/* Ops from nn/batch_norm.h */
|
||||
TVM_REGISTER_GLOBAL("topi.nn.batch_norm_inference")
|
||||
.set_body([](TVMArgs args, TVMRetValue *rv) {
|
||||
|
@ -366,10 +374,18 @@ TVM_REGISTER_GLOBAL("topi.vision.reorg")
|
|||
.set_body([](TVMArgs args, TVMRetValue *rv) {
|
||||
*rv = vision::reorg(args[0], args[1]);
|
||||
});
|
||||
|
||||
TVM_REGISTER_GLOBAL("topi.vision.yolo2.region")
|
||||
.set_body([](TVMArgs args, TVMRetValue *rv) {
|
||||
*rv = vision::yolo2::region(args[0], args[1], args[2], args[3], args[4], args[5]);
|
||||
});
|
||||
|
||||
/* Ops from image/resize.h */
|
||||
TVM_REGISTER_GLOBAL("topi.image.resize")
|
||||
.set_body([](TVMArgs args, TVMRetValue *rv) {
|
||||
*rv = image::resize(args[0], args[1], args[2], args[3], args[4]);
|
||||
});
|
||||
|
||||
/* Generic schedules */
|
||||
TVM_REGISTER_GLOBAL("topi.generic.default_schedule")
|
||||
.set_body([](TVMArgs args, TVMRetValue *rv) {
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
"""Test code for bilinear scale """
|
||||
import numpy as np
|
||||
import tvm
|
||||
import topi
|
||||
import topi.testing
|
||||
import math
|
||||
|
||||
def verify_bilinear_scale(batch, in_channel, in_height, in_width, out_height, out_width, layout='NCHW', align_corners=False):
|
||||
|
||||
if layout == 'NCHW':
|
||||
A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A', dtype='float32')
|
||||
dtype = A.dtype
|
||||
out_shape = (batch, in_channel, out_height, out_width)
|
||||
a_np = np.random.uniform(size=(batch, in_channel, in_height, in_width)).astype(dtype)
|
||||
elif layout == 'NHWC':
|
||||
A = tvm.placeholder((batch, in_height, in_width, in_channel), name='A', dtype='float32')
|
||||
dtype = A.dtype
|
||||
out_shape = (batch, out_height, out_width, in_channel)
|
||||
a_np = np.random.uniform(size=(batch, in_height, in_width, in_channel)).astype(dtype)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
'Layout not supported {} '.format(layout))
|
||||
|
||||
B = topi.image.resize(A, (out_height, out_width), layout=layout, align_corners=align_corners)
|
||||
|
||||
b_np = topi.testing.bilinear_resize_python(a_np, (out_height, out_width), layout, align_corners)
|
||||
|
||||
def check_device(device):
|
||||
ctx = tvm.context(device, 0)
|
||||
if not ctx.exist:
|
||||
print("Skip because %s is not enabled" % device)
|
||||
return
|
||||
print("Running on target: %s" % device)
|
||||
with tvm.target.create(device):
|
||||
s = topi.generic.schedule_injective(B)
|
||||
a = tvm.nd.array(a_np, ctx)
|
||||
b = tvm.nd.array(np.zeros(out_shape, dtype=dtype), ctx)
|
||||
f = tvm.build(s, [A, B], device)
|
||||
f(a, b)
|
||||
|
||||
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-3, atol=1e-3)
|
||||
|
||||
for device in ['llvm', 'cuda', 'vulkan']:
|
||||
check_device(device)
|
||||
|
||||
def test_resize():
|
||||
# Scale NCHW
|
||||
verify_bilinear_scale(4, 16, 32, 32, 50, 50, 'NCHW')
|
||||
# Scale NCHW + Align Corners
|
||||
verify_bilinear_scale(6, 32, 64, 64, 20, 20, 'NCHW', True)
|
||||
# Scale NHWC
|
||||
verify_bilinear_scale(4, 16, 32, 32, 50, 50, "NHWC")
|
||||
# Scale NHWC + Align Corners
|
||||
verify_bilinear_scale(6, 32, 64, 64, 20, 20, "NHWC", True)
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_resize()
|
|
@ -5,14 +5,26 @@ import topi
|
|||
import topi.testing
|
||||
import math
|
||||
|
||||
def verify_upsampling(batch, in_channel, in_height, in_width, scale):
|
||||
A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
|
||||
B = topi.nn.upsampling(A, scale)
|
||||
out_shape = (batch, in_channel, in_height*scale, in_width*scale)
|
||||
dtype = A.dtype
|
||||
def verify_upsampling(batch, in_channel, in_height, in_width, scale, layout='NCHW'):
|
||||
|
||||
a_np = np.random.uniform(size=(batch, in_channel, in_height, in_width)).astype(dtype)
|
||||
b_np = topi.testing.upsampling_python(a_np, scale)
|
||||
|
||||
if layout == 'NCHW':
|
||||
A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
|
||||
dtype = A.dtype
|
||||
out_shape = (batch, in_channel, in_height*scale, in_width*scale)
|
||||
a_np = np.random.uniform(size=(batch, in_channel, in_height, in_width)).astype(dtype)
|
||||
elif layout == 'NHWC':
|
||||
A = tvm.placeholder((batch, in_height, in_width, in_channel), name='A')
|
||||
dtype = A.dtype
|
||||
out_shape = (batch, in_height*scale, in_width*scale, in_channel)
|
||||
a_np = np.random.uniform(size=(batch, in_height, in_width, in_channel)).astype(dtype)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
'Layout not supported {} '.format(layout))
|
||||
|
||||
B = topi.nn.upsampling(A, scale, layout=layout)
|
||||
|
||||
b_np = topi.testing.upsampling_python(a_np, scale, layout)
|
||||
|
||||
def check_device(device):
|
||||
ctx = tvm.context(device, 0)
|
||||
|
@ -33,8 +45,12 @@ def verify_upsampling(batch, in_channel, in_height, in_width, scale):
|
|||
check_device(device)
|
||||
|
||||
def test_upsampling():
|
||||
# NCHW
|
||||
verify_upsampling(8, 16, 32, 32, 2)
|
||||
verify_upsampling(12, 32, 64, 64, 3)
|
||||
# NHWC
|
||||
verify_upsampling(8, 16, 32, 32, 2, "NHWC")
|
||||
verify_upsampling(12, 32, 64, 64, 3, "NHWC")
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_upsampling()
|
||||
|
|
Загрузка…
Ссылка в новой задаче