[TOPI][IMAGE][RESIZE] Bilinear interpolation for resize and upsampling. (#1181)

This commit is contained in:
Siva 2018-06-14 21:23:49 +05:30 коммит произвёл Tianqi Chen
Родитель 758cb7519c
Коммит 76fa3ca4f6
22 изменённых файлов: 890 добавлений и 94 удалений

Просмотреть файл

@ -51,6 +51,7 @@ List of operators
topi.broadcast_div
topi.broadcast_maximum
topi.broadcast_minimum
topi.image.resize
List of schedules
@ -114,6 +115,10 @@ topi.nn
.. autofunction:: topi.nn.depthwise_conv2d_nchw
.. autofunction:: topi.nn.depthwise_conv2d_nhwc
topi.image
~~~~~~~~~~
.. autofunction:: topi.image.resize
topi.generic
~~~~~~~~~~~~

Просмотреть файл

@ -288,16 +288,22 @@ struct GlobalPool2DParam : public dmlc::Parameter<GlobalPool2DParam> {
struct UpSamplingParam : public dmlc::Parameter<UpSamplingParam> {
int scale;
std::string layout;
std::string method;
DMLC_DECLARE_PARAMETER(UpSamplingParam) {
DMLC_DECLARE_FIELD(scale)
.describe("upsampling scaling factor");
DMLC_DECLARE_FIELD(layout)
.set_default("NCHW")
.describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc."
.describe("Dimension ordering of data. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Convolution is applied on the 'H' and"
"dimensions respectively. Upsampling is applied on the 'H' and"
"'W' dimensions.");
DMLC_DECLARE_FIELD(method)
.set_default("NEAREST_NEIGHBOR")
.describe("Specify the mode to use for scaling."
"NEAREST_NEIGHBOR - Nearest Neighbor"
"BILINEAR - Bilinear Interpolation");
}
};

Просмотреть файл

@ -8,6 +8,7 @@ from . import nn
from . import transform
from . import reduction
from . import vision
from . import image
from .registry import OpPattern
from .registry import register_compute, register_schedule, register_pattern

Просмотреть файл

@ -0,0 +1,17 @@
# pylint: disable=invalid-name, unused-argument
"""Definition of image ops"""
from __future__ import absolute_import
import topi
import tvm
from . import registry as reg
from .registry import OpPattern
# resize
@reg.register_schedule("resize")
def schedule_resize(_, outs, target):
    """Build the schedule for the ``resize`` operator.

    The first positional argument is ignored. ``outs`` is handed to TOPI's
    generic injective schedule while ``target`` is the active target context.
    """
    target_ctx = tvm.target.create(target)
    with target_ctx:
        schedule = topi.generic.schedule_injective(outs)
    return schedule


reg.register_pattern("resize", OpPattern.INJECTIVE)

Просмотреть файл

@ -235,20 +235,10 @@ def schedule_global_avg_pool2d(_, outs, target):
reg.register_pattern("global_avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
@reg.register_compute("upsampling")
def compute_upsampling(attrs, inputs, _):
"""Compute definition of upsampling"""
scale = attrs.get_int("scale")
layout = attrs["layout"]
if layout:
assert layout == "NCHW" or layout == "NHWC"
return topi.nn.upsampling(inputs[0], scale, layout)
return topi.nn.upsampling(inputs[0], scale)
# upsampling
@reg.register_schedule("upsampling")
def schedule_upsampling(_, outs, target):
"""Compute definition of upsampling"""
"""Schedule definition of upsampling"""
with tvm.target.create(target):
return topi.generic.schedule_injective(outs)

Просмотреть файл

@ -0,0 +1,113 @@
/*!
* Copyright (c) 2017 by Contributors
* \file resize.cc
* \brief Property def of resize operators.
*/
#include <tvm/tvm.h>
#include <tvm/expr.h>
#include <tvm/packed_func_ext.h>
#include <nnvm/layout.h>
#include <nnvm/compiler/op_attr_types.h>
#include <nnvm/op.h>
#include <nnvm/node.h>
#include <nnvm/op_attr_types.h>
#include "../nn/nn_common.h"
#include "../op_common.h"
#include "../elemwise_op_common.h"
#include "topi/elemwise.h"
#include "topi/transform.h"
#include "topi/image/resize.h"
#include "resize.h"
namespace nnvm {
namespace top {
using tvm::Expr;
using tvm::Array;
using tvm::Tensor;
using nnvm::compiler::FTVMCompute;
DMLC_REGISTER_PARAMETER(ResizeParam);
/*!
 * \brief Infer the output shape of the resize operator.
 *
 * The input shape is converted to NCHW, its height and width are replaced by
 * param.size[0] and param.size[1], and the result is converted back to the
 * layout named in the operator parameters.
 *
 * \param attrs Node attributes holding a parsed ResizeParam.
 * \param in_shape Input shapes; exactly one entry is expected.
 * \param out_shape Output shapes; exactly one entry is expected.
 *
 * \return false when the input shape is not known yet, true otherwise.
 */
inline bool ResizeInferShape(const nnvm::NodeAttrs& attrs,
                             std::vector<TShape>* in_shape,
                             std::vector<TShape>* out_shape) {
  static const Layout kNCHW("NCHW");
  const ResizeParam& param = nnvm::get<ResizeParam>(attrs.parsed);
  CHECK_EQ(in_shape->size(), 1U);
  CHECK_EQ(out_shape->size(), 1U);
  CHECK_EQ(param.size.ndim(), 2U)
    << "resize expects size to hold exactly (out_height, out_width)";
  TShape dshape = (*in_shape)[0];
  if (dshape.ndim() == 0) return false;
  dshape = ConvertLayout(dshape, param.layout, kNCHW);

  // dshape is in NCHW order here whatever param.layout is, so height and
  // width always live at axes 2 and 3. Indexing by the original layout at this
  // point would overwrite the channel axis for NHWC inputs.
  TShape oshape = dshape;
  oshape[2] = param.size[0];
  oshape[3] = param.size[1];
  oshape = ConvertLayout(oshape, kNCHW, param.layout);
  NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape);
  return true;
}
inline bool ResizeLayout(const NodeAttrs& attrs,
std::vector<Layout> *in_layouts,
const std::vector<Layout> *last_in_layouts,
std::vector<Layout> *out_layouts) {
const ResizeParam& param = nnvm::get<ResizeParam>(attrs.parsed);
CHECK_EQ(in_layouts->size(), 1U);
CHECK_EQ(out_layouts->size(), 1U);
const Layout layout(param.layout);
NNVM_ASSIGN_LAYOUT(*in_layouts, 0, layout);
NNVM_ASSIGN_LAYOUT(*out_layouts, 0, layout);
return true;
}
// Registration of the "resize" operator: one 4D input, one output, with shape,
// type and layout inference plus the TOPI compute definition.
NNVM_REGISTER_OP(resize)
.describe(R"(Perform resize to input array with nearest neighbour or bilinear interpolation.
- **data**: data is 4D array of shape
(batch_size, channels, in_height, in_width) for NCHW
(batch_size, in_height, in_width, channels) for NHWC
- **out**: Output is 4D array of shape
for layout NCHW
(batch_size, channels, size[0], size[1])
for layout NHWC
(batch_size, size[0], size[1], channels)
)" NNVM_ADD_FILELINE)
.add_argument("data", "4D Tensor", "Input data.")
.add_arguments(ResizeParam::__FIELDS__())
.set_attr_parser(ParamParser<ResizeParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ResizeParam>)
.set_attr<FInferShape>("FInferShape", ResizeInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FCorrectLayout>("FCorrectLayout", ResizeLayout)
.set_num_outputs(1)
.set_num_inputs(1)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    const ResizeParam& resize_param = nnvm::get<ResizeParam>(attrs.parsed);
    // Height is axis 2 for NCHW and axis 1 for every other layout; width is
    // the axis right after it.
    const size_t height_axis = (resize_param.layout == "NCHW") ? 2 : 1;
    const size_t width_axis = height_axis + 1;

    Array<Expr> target_hw;
    target_hw.push_back(out_info[0]->shape[height_axis]);
    target_hw.push_back(out_info[0]->shape[width_axis]);

    Tensor resized = topi::image::resize(inputs[0], target_hw, resize_param.layout,
                                         resize_param.align_corners, resize_param.method);
    return Array<Tensor>{ resized };
})
.set_support_level(2);
} // namespace top
} // namespace nnvm

Просмотреть файл

@ -0,0 +1,45 @@
/*!
* Copyright (c) 2018 by Contributors
* \file resize.h
*/
#ifndef NNVM_TOP_IMAGE_RESIZE_H_
#define NNVM_TOP_IMAGE_RESIZE_H_
#include <string>
#include <vector>
#include <utility>
#include <iostream>
#include <sstream>
namespace nnvm {
namespace top {
/*! \brief Parameters of the resize operator. */
struct ResizeParam : public dmlc::Parameter<ResizeParam> {
  /*! \brief Output size; read as (size[0], size[1]) = (out_height, out_width) */
  TShape size;
  /*! \brief Layout of the data tensor, "NCHW" by default */
  std::string layout;
  /*! \brief Scaling method, "BILINEAR" by default */
  std::string method;
  /*! \brief Corner-alignment flag forwarded to the resize compute, false by default */
  bool align_corners;

  DMLC_DECLARE_PARAMETER(ResizeParam) {
    DMLC_DECLARE_FIELD(size)
      .describe("Output size");
    // Adjacent string literals are concatenated verbatim, so every fragment
    // that continues a sentence must carry its own trailing space.
    DMLC_DECLARE_FIELD(layout)
      .set_default("NCHW")
      .describe("Dimension ordering of data. Can be 'NCHW', 'NHWC', etc. "
                "'N', 'C', 'H', 'W' stands for batch, channel, height, and width "
                "dimensions respectively. Resize is applied on the 'H' and "
                "'W' dimensions.");
    DMLC_DECLARE_FIELD(method)
      .set_default("BILINEAR")
      .describe("Specify the mode to use for scaling. "
                "NEAREST_NEIGHBOR - Nearest Neighbor; "
                "BILINEAR - Bilinear Interpolation");
    DMLC_DECLARE_FIELD(align_corners)
      .set_default(false)
      .describe("Should be true to preserve the values at the corner pixels");
  }
};
} // namespace top
} // namespace nnvm
#endif // NNVM_TOP_IMAGE_RESIZE_H_

Просмотреть файл

@ -1,8 +1,12 @@
/*!
* Copyright (c) 2017 by Contributors
* \file pooling.cc
* \brief Property def of pooling operators.
* \file upsampling.cc
* \brief Property def of upsampling operators.
*/
#include <tvm/tvm.h>
#include <tvm/expr.h>
#include <nnvm/layout.h>
#include <nnvm/compiler/op_attr_types.h>
#include <nnvm/op.h>
#include <nnvm/node.h>
#include <nnvm/op_attr_types.h>
@ -10,27 +14,36 @@
#include "./nn_common.h"
#include "../op_common.h"
#include "../elemwise_op_common.h"
#include "topi/elemwise.h"
#include "topi/transform.h"
#include "topi/nn/upsampling.h"
namespace nnvm {
namespace top {
using tvm::Expr;
using tvm::Array;
using tvm::Tensor;
using nnvm::compiler::FTVMCompute;
DMLC_REGISTER_PARAMETER(UpSamplingParam);
inline bool UpSamplingInferShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape>* in_shape,
std::vector<TShape>* out_shape) {
std::vector<TShape>* in_shape,
std::vector<TShape>* out_shape) {
static const Layout kNCHW("NCHW");
const UpSamplingParam& param = nnvm::get<UpSamplingParam>(attrs.parsed);
CHECK_EQ(in_shape->size(), 1U);
CHECK_EQ(out_shape->size(), 1U);
TShape dshape = (*in_shape)[0];
if (dshape.ndim() == 0) return false;
dshape = ConvertLayout(dshape, param.layout, kNCHW);
TShape oshape = dshape;
oshape[2] = oshape[2] * param.scale;
oshape[3] = oshape[3] * param.scale;
oshape = ConvertLayout(oshape, kNCHW, param.layout);
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape);
return true;
}
@ -48,10 +61,18 @@ inline bool UpsamplingLayout(const NodeAttrs& attrs,
}
NNVM_REGISTER_OP(upsampling)
.describe(R"(Perform nearest neighbor upsampling to input array.
.describe(R"(Perform upsampling to input array with nearest neighbour or bilinear interpolation.
- **data**: Input is 4D array of shape (batch_size, channels, in_height, in_width).
- **out**: Output is 4D array of shape (batch_size, channels, in_height*scale, in_width*scale).
- **data**: data is 4D array of shape
(batch_size, channels, in_height, in_width) for NCHW
(batch_size, in_height, in_width, channels) for NHWC
- **out**: Output is 4D array of shape
for layout NCHW
(batch_size, channels, in_height*scale, in_width*scale)
for layout NHWC
(batch_size, in_height*scale, in_width*scale, channels)
)" NNVM_ADD_FILELINE)
.add_argument("data", "4D Tensor", "Input data.")
@ -63,6 +84,22 @@ NNVM_REGISTER_OP(upsampling)
.set_attr<FCorrectLayout>("FCorrectLayout", UpsamplingLayout)
.set_num_outputs(1)
.set_num_inputs(1)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const UpSamplingParam& param = nnvm::get<UpSamplingParam>(attrs.parsed);
Array<Expr> oshape;
if (param.layout == "NCHW") {
oshape.push_back(out_info[0]->shape[2]);
oshape.push_back(out_info[0]->shape[3]);
} else {
oshape.push_back(out_info[0]->shape[1]);
oshape.push_back(out_info[0]->shape[2]);
}
return Array<Tensor>{ topi::nn::upsampling(inputs[0], oshape, param.layout, param.method)};
})
.set_support_level(2);
} // namespace top

Просмотреть файл

@ -210,7 +210,7 @@ def test_global_avg_pool2d():
np.testing.assert_allclose(out.asnumpy(), b_np, rtol=1e-5)
def test_upsampling():
def test_upsampling_nearest_neighbor():
x = sym.Variable("x")
scale = 2
y = sym.upsampling(x, scale=scale, name="y")
@ -225,9 +225,46 @@ def test_upsampling():
data = tvm.nd.array(a_np)
m.run(x=data)
out = m.get_output(0, tvm.nd.empty(oshape, dtype))
b_np = topi.testing.upsampling_python(a_np, scale)
b_np = topi.testing.upsampling_python(a_np, scale, "NCHW")
np.testing.assert_allclose(out.asnumpy(), b_np, rtol=1e-5)
def test_upsampling_bilinear():
    """Compare bilinear ``upsampling`` in NCHW layout with the numpy reference."""
    scale = 2
    dtype = "float32"
    in_size = 32
    out_size = in_size * scale
    dshape = (1, 4, in_size, in_size)
    oshape = (1, 4, out_size, out_size)

    x = sym.Variable("x")
    y = sym.upsampling(x, scale=scale, method="BILINEAR", name="y", layout="NCHW")

    for target, ctx in ctx_list():
        graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape}, {"x": dtype})
        module = graph_runtime.create(graph, lib, ctx)
        a_np = np.random.uniform(size=dshape).astype(dtype)
        module.run(x=tvm.nd.array(a_np))
        out = module.get_output(0, tvm.nd.empty(oshape, dtype))
        expected = topi.testing.bilinear_resize_python(a_np, (out_size, out_size), "NCHW")
        np.testing.assert_allclose(out.asnumpy(), expected, rtol=1e-5, atol=1e-5)
def test_resize_bilinear():
    """Compare bilinear ``upsampling`` in NHWC layout with the numpy reference.

    NOTE(review): despite its name this test builds ``sym.upsampling``, not
    ``sym.resize`` -- confirm whether the resize symbol was meant to be covered.
    """
    scale = 2
    dtype = "float32"
    in_size = 32
    out_size = in_size * scale
    dshape = (1, in_size, in_size, 4)
    oshape = (1, out_size, out_size, 4)

    x = sym.Variable("x")
    y = sym.upsampling(x, scale=scale, method="BILINEAR", name="y", layout="NHWC")

    for target, ctx in ctx_list():
        graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape}, {"x": dtype})
        module = graph_runtime.create(graph, lib, ctx)
        a_np = np.random.uniform(size=dshape).astype(dtype)
        module.run(x=tvm.nd.array(a_np))
        out = module.get_output(0, tvm.nd.empty(oshape, dtype))
        expected = topi.testing.bilinear_resize_python(a_np, (out_size, out_size), "NHWC")
        np.testing.assert_allclose(out.asnumpy(), expected, rtol=1e-5, atol=1e-5)
if __name__ == "__main__":
test_conv2d()
@ -239,4 +276,6 @@ if __name__ == "__main__":
test_avg_pool2d_no_count_pad()
test_global_max_pool2d()
test_global_avg_pool2d()
test_upsampling()
test_upsampling_nearest_neighbor()
test_upsampling_bilinear()
test_resize_bilinear()

Просмотреть файл

@ -0,0 +1,316 @@
/*!
* Copyright (c) 2017 by Contributors
* \file topi/image/resize.h
* \brief image resize constructors
*/
#ifndef TOPI_IMAGE_RESIZE_H_
#define TOPI_IMAGE_RESIZE_H_
#include <string>
#include <vector>
#include <iterator>
#include <algorithm>
#include "topi/tags.h"
#include "topi/detail/ravel_unravel.h"
#include "topi/detail/constant_utils.h"
#include "tvm/tvm.h"
namespace topi {
namespace image {
using namespace tvm;
/*!
* \brief Resize given tensor to given shape using nearest neighbour for NHWC
*
* Output pixel (y, x) reads input pixel (y * in_height / out_height,
* x * in_width / out_width) using integer arithmetic, so the source index
* stays inside the input for any pair of sizes.
*
* \param input The input tensor.
* \param shape Output shape to resize to, as (out_height, out_width).
* \param align_corners Not used by this function.
* \param name Name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor resized to given shape
*/
inline Tensor resize_nearest_neighbor_nhwc(const Tensor& input,
                                           const Array<Expr>& shape,
                                           bool align_corners = false,
                                           std::string name = "tensor",
                                           std::string tag = kInjective) {
  Array<Expr> out_shape;
  out_shape.push_back(input->shape[0]);
  out_shape.push_back(shape[0]);
  out_shape.push_back(shape[1]);
  out_shape.push_back(input->shape[3]);

  const Expr in_height = input->shape[1];
  const Expr in_width = input->shape[2];

  return compute(
    out_shape, [&](const Array<Var>& indices) {
      // Multiply before dividing. A precomputed integer ratio (out / in)
      // truncates: it is 0 when shrinking (division by zero) and too small for
      // non-integer factors (reads past the input). For integer factors both
      // forms pick the same source pixel.
      Array<Expr> idx;
      idx.push_back(indices[0]);
      idx.push_back(indices[1] * in_height / shape[0]);
      idx.push_back(indices[2] * in_width / shape[1]);
      idx.push_back(indices[3]);

      return input(idx);
    }, name, tag);
}
/*!
* \brief Resize given tensor to given shape using nearest neighbour for NCHW
*
* Output pixel (y, x) reads input pixel (y * in_height / out_height,
* x * in_width / out_width) using integer arithmetic, so the source index
* stays inside the input for any pair of sizes.
*
* \param input The input tensor.
* \param shape Output shape to resize to, as (out_height, out_width).
* \param align_corners Not used by this function.
* \param name Name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor resized to given shape
*/
inline Tensor resize_nearest_neighbor_nchw(const Tensor& input,
                                           const Array<Expr>& shape,
                                           bool align_corners = false,
                                           std::string name = "tensor",
                                           std::string tag = kInjective) {
  Array<Expr> out_shape;
  out_shape.push_back(input->shape[0]);
  out_shape.push_back(input->shape[1]);
  out_shape.push_back(shape[0]);
  out_shape.push_back(shape[1]);

  const Expr in_height = input->shape[2];
  const Expr in_width = input->shape[3];

  return compute(
    out_shape, [&](const Array<Var>& indices) {
      // Multiply before dividing. A precomputed integer ratio (out / in)
      // truncates: it is 0 when shrinking (division by zero) and too small for
      // non-integer factors (reads past the input). For integer factors both
      // forms pick the same source pixel.
      Array<Expr> idx;
      idx.push_back(indices[0]);
      idx.push_back(indices[1]);
      idx.push_back(indices[2] * in_height / shape[0]);
      idx.push_back(indices[3] * in_width / shape[1]);

      return input(idx);
    }, name, tag);
}
/*!
* \brief Resize given tensor to given shape using nearest neighbour
*
* Dispatches on the layout: "NHWC" uses the NHWC kernel, every other value
* uses the NCHW kernel.
*
* \param input The input tensor.
* \param shape Output shape to resize to.
* \param layout input layout
* \param align_corners Must be false; nearest neighbour does not support it.
* \param name Name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor resized to given shape
*/
inline Tensor resize_nearest_neighbor(const Tensor& input,
                                      const Array<Expr>& shape,
                                      std::string layout = "NCHW",
                                      bool align_corners = false,
                                      std::string name = "tensor",
                                      std::string tag = kInjective) {
  CHECK_EQ(align_corners, false) << "Align corners not supported for nearest neighbour";

  // Forward name and tag so that caller-supplied values are not silently
  // replaced by the kernels' defaults.
  if (layout == "NHWC") {
    return resize_nearest_neighbor_nhwc(input, shape, align_corners, name, tag);
  } else {
    return resize_nearest_neighbor_nchw(input, shape, align_corners, name, tag);
  }
}
/*!
* \brief Resize given tensor to given shape using bilinear interpolation for NHWC
*
* The input and output height/width must be compile-time constants because the
* scaling ratios are computed on the host.
*
* \param input The input tensor.
* \param shape Output shape to resize to, as (out_height, out_width).
* \param align_corners Selects how the scaling ratio is computed (see the note
*  in the body).
* \param name Name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor resized to given shape
*/
inline Tensor resize_bilinear_nhwc(const Tensor& input,
                                   const Array<Expr>& shape,
                                   bool align_corners = false,
                                   std::string name = "tensor",
                                   std::string tag = kInjective) {
  Array<Expr> out_shape;
  out_shape.push_back(input->shape[0]);
  out_shape.push_back(shape[0]);
  out_shape.push_back(shape[1]);
  out_shape.push_back(input->shape[3]);

  Expr cone = make_const(Int(32), 1);

  auto in_height = as_const_int(input->shape[1]);
  auto in_width = as_const_int(input->shape[2]);
  auto out_height = as_const_int(shape[0]);
  auto out_width = as_const_int(shape[1]);
  // as_const_int yields a null pointer for non-constant expressions; fail with
  // a message instead of dereferencing it.
  CHECK(in_height != nullptr && in_width != nullptr &&
        out_height != nullptr && out_width != nullptr)
    << "Bilinear resize requires constant input and output height/width";

  // Float ratio as a constant expression. A zero denominator (single-pixel
  // output in the "minus one" form) maps to 0 so that every output pixel
  // samples source index 0 instead of producing inf/NaN.
  auto make_ratio = [](float num, float den) {
    return make_const(Float(32), den == 0.0f ? 0.0f : num / den);
  };

  // NOTE(review): the usual convention uses (in - 1) / (out - 1) when
  // align_corners is true and in / out otherwise; the selection below is the
  // opposite. It is kept as is because existing callers and the Python
  // reference implementation rely on it -- confirm the intended behaviour.
  Expr y_ratio;
  Expr x_ratio;
  if (align_corners) {
    y_ratio = make_ratio(static_cast<float>(*in_height), static_cast<float>(*out_height));
    x_ratio = make_ratio(static_cast<float>(*in_width), static_cast<float>(*out_width));
  } else {
    y_ratio = make_ratio(static_cast<float>(*in_height - 1),
                         static_cast<float>(*out_height - 1));
    x_ratio = make_ratio(static_cast<float>(*in_width - 1),
                         static_cast<float>(*out_width - 1));
  }

  // Largest valid row/column index, used to clamp the second sample point.
  Expr other_y = tvm::ir::Simplify(input->shape[1] - cone);
  Expr other_x = tvm::ir::Simplify(input->shape[2] - cone);

  return compute(
    out_shape, [&](const Array<Var>& indices) {
      // Top-left sample point: floor of the scaled output coordinate.
      auto y0 = HalideIR::Internal::Cast::make(Int(32), tvm::floor(y_ratio * indices[1]));
      auto x0 = HalideIR::Internal::Cast::make(Int(32), tvm::floor(x_ratio * indices[2]));

      // Bottom-right sample point, clamped to the last row/column.
      auto y1 = tvm::select(((y0 + cone) > other_y), other_y, (y0 + cone));
      auto x1 = tvm::select(((x0 + cone) > other_x), other_x, (x0 + cone));

      // Fractional offsets inside the cell, used as interpolation weights.
      auto h = (y_ratio * indices[1]) - y0;
      auto w = (x_ratio * indices[2]) - x0;

      auto A = input(indices[0], y0, x0, indices[3]);
      auto B = input(indices[0], y0, x1, indices[3]);
      auto C = input(indices[0], y1, x0, indices[3]);
      auto D = input(indices[0], y1, x1, indices[3]);

      return (A*(cone-w)*(cone-h) + B*(w)*(cone-h) + C*(h)*(cone-w) + D*w*h);
    }, name, tag);
}
/*!
* \brief Resize given tensor to given shape using bilinear interpolation for NCHW
*
* The input and output height/width must be compile-time constants because the
* scaling ratios are computed on the host.
*
* \param input The input tensor.
* \param shape Output shape to resize to, as (out_height, out_width).
* \param align_corners Selects how the scaling ratio is computed (see the note
*  in the body).
* \param name Name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor resized to given shape
*/
inline Tensor resize_bilinear_nchw(const Tensor& input,
                                   const Array<Expr>& shape,
                                   bool align_corners = false,
                                   std::string name = "tensor",
                                   std::string tag = kInjective) {
  Array<Expr> out_shape;
  out_shape.push_back(input->shape[0]);
  out_shape.push_back(input->shape[1]);
  out_shape.push_back(shape[0]);
  out_shape.push_back(shape[1]);

  Expr cone = make_const(Int(32), 1);

  auto in_height = as_const_int(input->shape[2]);
  auto in_width = as_const_int(input->shape[3]);
  auto out_height = as_const_int(shape[0]);
  auto out_width = as_const_int(shape[1]);
  // as_const_int yields a null pointer for non-constant expressions; fail with
  // a message instead of dereferencing it.
  CHECK(in_height != nullptr && in_width != nullptr &&
        out_height != nullptr && out_width != nullptr)
    << "Bilinear resize requires constant input and output height/width";

  // Float ratio as a constant expression. A zero denominator (single-pixel
  // output in the "minus one" form) maps to 0 so that every output pixel
  // samples source index 0 instead of producing inf/NaN.
  auto make_ratio = [](float num, float den) {
    return make_const(Float(32), den == 0.0f ? 0.0f : num / den);
  };

  // NOTE(review): the usual convention uses (in - 1) / (out - 1) when
  // align_corners is true and in / out otherwise; the selection below is the
  // opposite. It is kept as is because existing callers and the Python
  // reference implementation rely on it -- confirm the intended behaviour.
  Expr y_ratio;
  Expr x_ratio;
  if (align_corners) {
    y_ratio = make_ratio(static_cast<float>(*in_height), static_cast<float>(*out_height));
    x_ratio = make_ratio(static_cast<float>(*in_width), static_cast<float>(*out_width));
  } else {
    y_ratio = make_ratio(static_cast<float>(*in_height - 1),
                         static_cast<float>(*out_height - 1));
    x_ratio = make_ratio(static_cast<float>(*in_width - 1),
                         static_cast<float>(*out_width - 1));
  }

  // Largest valid row/column index, used to clamp the second sample point.
  Expr other_y = tvm::ir::Simplify(input->shape[2] - cone);
  Expr other_x = tvm::ir::Simplify(input->shape[3] - cone);

  return compute(
    out_shape, [&](const Array<Var>& indices) {
      // Top-left sample point: floor of the scaled output coordinate.
      auto y0 = HalideIR::Internal::Cast::make(Int(32), tvm::floor(y_ratio * indices[2]));
      auto x0 = HalideIR::Internal::Cast::make(Int(32), tvm::floor(x_ratio * indices[3]));

      // Bottom-right sample point, clamped to the last row/column.
      auto y1 = tvm::select(((y0 + cone) > other_y), other_y, (y0 + cone));
      auto x1 = tvm::select(((x0 + cone) > other_x), other_x, (x0 + cone));

      // Fractional offsets inside the cell, used as interpolation weights.
      auto h = (y_ratio * indices[2]) - y0;
      auto w = (x_ratio * indices[3]) - x0;

      auto A = input(indices[0], indices[1], y0, x0);
      auto B = input(indices[0], indices[1], y0, x1);
      auto C = input(indices[0], indices[1], y1, x0);
      auto D = input(indices[0], indices[1], y1, x1);

      return ((A*(cone-w)*(cone-h)) + (B*(w)*(cone-h)) + (C*(h)*(cone-w)) + (D*w*h));
    }, name, tag);
}
/*!
* \brief Resize given tensor to given shape using bilinear interpolation
*
* Dispatches on the layout ("NHWC" uses the NHWC kernel, every other value the
* NCHW kernel) and casts the interpolated result back to the input dtype.
*
* \param input The input tensor.
* \param shape Output shape to resize to.
* \param layout input layout
* \param align_corners To preserve centers of 4 corner pixels
* \param name Name of the interpolation stage
* \param tag The tag to mark the interpolation stage
*
* \return A Tensor resized to given shape
*/
inline Tensor resize_bilinear(const Tensor& input,
                              const Array<Expr>& shape,
                              std::string layout = "NCHW",
                              bool align_corners = false,
                              std::string name = "tensor",
                              std::string tag = kInjective) {
  // Forward name and tag so that caller-supplied values are not silently
  // replaced by the kernels' defaults.
  Tensor ret;
  if (layout == "NHWC") {
    ret = resize_bilinear_nhwc(input, shape, align_corners, name, tag);
  } else {
    ret = resize_bilinear_nchw(input, shape, align_corners, name, tag);
  }

  // The final cast keeps its own default name and tag.
  return cast(ret, input->dtype);
}
/*!
* \brief Resize given tensor to given shape
*
* \param input The input tensor.
* \param shape Output shape to resize to.
* \param layout input layout
* \param align_corners To preserve centers of 4 corner pixels
* \param mode Algorithm to use (NEAREST_NEIGHBOR / BILINEAR); any value other
*  than "NEAREST_NEIGHBOR" selects bilinear interpolation
* \param name Name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor resized to given shape
*/
inline Tensor resize(const Tensor& input,
                     const Array<Expr>& shape,
                     std::string layout = "NCHW",
                     bool align_corners = false,
                     std::string mode = "BILINEAR",
                     std::string name = "tensor",
                     std::string tag = kInjective) {
  // Forward name and tag so that caller-supplied values are not silently
  // replaced by the implementations' defaults.
  if (mode == "NEAREST_NEIGHBOR") {
    return resize_nearest_neighbor(input, shape, layout, align_corners, name, tag);
  } else {
    return resize_bilinear(input, shape, layout, align_corners, name, tag);
  }
}
} // namespace image
} // namespace topi
#endif // TOPI_IMAGE_RESIZE_H_

Просмотреть файл

@ -0,0 +1,44 @@
/*!
* Copyright (c) 2017 by Contributors
* \file topi/nn/upsampling.h
* \brief upsampling op constructors
*/
#ifndef TOPI_NN_UPSAMPLING_H_
#define TOPI_NN_UPSAMPLING_H_
#include <string>
#include <vector>
#include <iterator>
#include <algorithm>
#include "topi/image/resize.h"
namespace topi {
namespace nn {
using namespace tvm;
using namespace topi::image;
/*!
* \brief Upsample given tensor to given shape
*
* Thin wrapper over topi::image::resize that always passes
* align_corners = false.
*
* \param input The input tensor.
* \param shape Output shape to upsample.
* \param layout input layout
* \param mode Algorithm to use (NEAREST_NEIGHBOR / BILINEAR)
* \param name Name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor upsampled to given shape
*/
inline Tensor upsampling(const Tensor& input,
                         const Array<Expr> shape,
                         std::string layout = "NCHW",
                         std::string mode = "NEAREST_NEIGHBOR",
                         std::string name = "tensor",
                         std::string tag = kInjective) {
  // Forward name and tag so that caller-supplied values are not silently
  // replaced by resize's defaults.
  return resize(input, shape, layout, false, mode, name, tag);
}
} // namespace nn
} // namespace topi
#endif // TOPI_NN_UPSAMPLING_H_

Просмотреть файл

@ -30,6 +30,7 @@ from . import opengl
from . import util
from . import rocm
from . import vision
from . import image
# not import testing by default
# because testing can have extra deps that are not necessary
# we can import them from test cases explicitly

Просмотреть файл

@ -50,6 +50,8 @@ vision = _create_module("vision")
_init_api_prefix("topi.cpp.vision", "topi.vision")
yolo2 = _create_module("vision.yolo2")
_init_api_prefix("topi.cpp.vision.yolo2", "topi.vision.yolo2")
image = _create_module("image")
_init_api_prefix("topi.cpp.image", "topi.image")
class IntVector(object):
"""Handle to std::vector<int> instance """

Просмотреть файл

@ -0,0 +1,5 @@
# pylint: disable=wildcard-import
"""IMAGE network operators"""
from __future__ import absolute_import as _abs
from .resize import *

Просмотреть файл

@ -0,0 +1,33 @@
"""TVM operator input resize compute."""
from __future__ import absolute_import
import topi
def resize(data, size, layout="NCHW", align_corners=False, method="BILINEAR"):
    """Resize the spatial dimensions of ``data`` to ``size``.

    Parameters
    ----------
    data : tvm.Tensor
        4-D tensor with shape [batch, channel, in_height, in_width]
        or [batch, in_height, in_width, channel]

    size : Tuple
        Output resolution to scale to, as (out_height, out_width)

    layout : string, optional
        either "NCHW" or "NHWC"

    align_corners : Boolean, optional
        To preserve the values at the corner pixels

    method : {"BILINEAR", "NEAREST_NEIGHBOR"}
        Method to be used for resizing.

    Returns
    -------
    output : tvm.Tensor
        4-D with shape [batch, channel, size[0], size[1]]
        or [batch, size[0], size[1], channel]
    """
    resized = topi.cpp.image.resize(data, size, layout, align_corners, method)
    return resized

Просмотреть файл

@ -1,25 +1,28 @@
"""TVM operator upsampling compute."""
from __future__ import absolute_import
import tvm
from .. import util
import topi
def upsampling(data, scale, layout="NCHW"):
"""Perform nearest neighbor upsampling on the data.
Bilinear upsampling is not supported.
def upsampling(data, scale, layout="NCHW", method='NEAREST_NEIGHBOR'):
"""Perform upsampling on the data.
Nearest neighbor and bilinear upsampling are supported.
Parameters
----------
data : tvm.Tensor
4-D with shape [batch, channel, in_height, in_width]
inputs : tvm.Tensor
inputs is a 4-D tensor with shape
[batch, channel, in_height, in_width]
or [batch, in_height, in_width, channel]
scale: int
upsampling scaling factor
scale : int
Scaling factor
layout: string
layout : string, optional
either "NCHW" or "NHWC"
method : {"BILINEAR", "NEAREST_NEIGHBOR"}
Method to be used for upsampling.
Returns
-------
output : tvm.Tensor
@ -28,53 +31,10 @@ def upsampling(data, scale, layout="NCHW"):
"""
if layout == "NCHW":
return upsampling_nchw(data, scale)
out_shape = (data.shape[2] * scale, data.shape[3] * scale)
elif layout == "NHWC":
return upsampling_nhwc(data, scale)
out_shape = (data.shape[1] * scale, data.shape[2] * scale)
else:
raise ValueError("not support this layout {} yet".format(layout))
def upsampling_nchw(data, scale):
"""Perform nearest neighor upsampling on NCHW layout input.
Parameters
----------
data : tvm.Tensor
4-D with shape [batch, channel, in_height, in_width]
scale: int
upsampling scaling factor
Returns
-------
output : tvm.Tensor
4-D with shape [batch, channel, in_height*scale, in_width*scale]
"""
batch, channel, height, width = data.shape
out_height = util.simplify(height * scale)
out_width = util.simplify(width * scale)
return tvm.compute((batch, channel, out_height, out_width), \
lambda n, c, h, w: data[n, c, h/scale, w/scale])
def upsampling_nhwc(data, scale):
"""Perform nearest neighor upsampling on NHWC layout input.
Parameters
----------
data : tvm.Tensor
4-D with shape [batch, in_height, in_width, channel]
scale: int
upsampling scaling factor
"""
batch, height, width, channel = data.shape
out_height = util.simplify(height * scale)
out_width = util.simplify(width * scale)
return tvm.compute((batch, out_height, out_width, channel), \
lambda n, h, w, c: data[n, h/scale, w/scale, c])
return topi.cpp.nn.upsampling(data, out_shape, layout, method)

Просмотреть файл

@ -12,6 +12,7 @@ from .depthwise_conv2d_python import depthwise_conv2d_python_nchw, depthwise_con
from .dilate_python import dilate_python
from .softmax_python import softmax_python, log_softmax_python
from .upsampling_python import upsampling_python
from .bilinear_resize_python import bilinear_resize_python
from .reorg_python import reorg_python
from .region_python import region_python
from .shortcut_python import shortcut_python

Просмотреть файл

@ -0,0 +1,79 @@
# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
"""Bilinear Scale in python"""
import math
import numpy as np
def bilinear_weights(height, width, new_h, new_w, align_corners=False):
    """Precompute the sampling table used by the bilinear reference.

    Returns a float32 array of shape (new_h, new_w, 4). Entry [i][j] holds
    (src_y, src_x, y_diff, x_diff): the floored source coordinates for output
    pixel (i, j) followed by the fractional offsets from them.
    """
    if align_corners:
        y_ratio = np.float32(np.float32(height)/np.float32(new_h))
        x_ratio = np.float32(np.float32(width)/np.float32(new_w))
    else:
        y_ratio = np.float32(np.float32(height-1)/np.float32(new_h-1))
        x_ratio = np.float32(np.float32(width-1)/np.float32(new_w-1))

    table = np.empty([new_h, new_w, 4], dtype='float32')
    for row in range(new_h):
        for col in range(new_w):
            src_x = math.floor(x_ratio * col)
            src_y = math.floor(y_ratio * row)
            frac_x = np.float32((x_ratio * col) - src_x)
            frac_y = np.float32((y_ratio * row) - src_y)
            table[row][col] = [src_y, src_x, frac_y, frac_x]

    return table
def bilinear_resize_python(image, out_size, layout, align_corners=False):
    """Reference bilinear resize of a 4-D array, written with plain loops.

    ``out_size`` is (new_h, new_w). ``layout`` 'NHWC' is handled explicitly;
    any other value is treated as NCHW. The result is allocated with
    ``np.ones`` using numpy's default dtype.
    """
    new_h, new_w = out_size
    is_nhwc = layout == 'NHWC'

    if is_nhwc:
        batch, h, w, channel = image.shape
        scaled_image = np.ones((batch, new_h, new_w, channel))
    else:
        batch, channel, h, w = image.shape
        scaled_image = np.ones((batch, channel, new_h, new_w))

    weights = bilinear_weights(h, w, new_h, new_w, align_corners)

    def _at(b, c, y, x):
        """Read one input value, hiding the layout difference."""
        if is_nhwc:
            return image[b][y][x][c]
        return image[b][c][y][x]

    for b in range(batch):
        for c in range(channel):
            for j in range(new_h):
                for k in range(new_w):
                    entry = weights[j][k]
                    y0 = int(entry[0])
                    x0 = int(entry[1])
                    # Second sample point, clamped to the last row/column.
                    x1 = min((x0+1), (w-1))
                    y1 = min((y0+1), (h-1))
                    y_diff = entry[2]
                    x_diff = entry[3]

                    A = _at(b, c, y0, x0)
                    B = _at(b, c, y0, x1)
                    C = _at(b, c, y1, x0)
                    D = _at(b, c, y1, x1)

                    pixel = np.float32((A*(1-x_diff)*(1-y_diff) +
                                        B*(x_diff)*(1-y_diff) +
                                        C*(y_diff)*(1-x_diff) +
                                        D*(x_diff*y_diff)))

                    if is_nhwc:
                        scaled_image[b][j][k][c] = pixel
                    else:
                        scaled_image[b][c][j][k] = pixel

    return scaled_image

Просмотреть файл

@ -3,13 +3,26 @@
import numpy as np
def upsample_nearest(arr, scale):
    """ Populate the array by scale factor"""
    return arr.repeat(scale, axis=0).repeat(scale, axis=1)


def upsampling_python(data, scale, layout='NCHW'):
    """ Python version of scaling using nearest neighbour

    Parameters
    ----------
    data : numpy.ndarray
        4-D array in the given layout.

    scale : int
        Repeat factor applied to both height and width.

    layout : str, optional
        'NCHW' (default) or 'NHWC'; any other value raises ValueError.

    Returns
    -------
    output_np : numpy.ndarray
        Array with height and width multiplied by ``scale``, same dtype as
        ``data``.
    """
    ishape = data.shape
    if layout == 'NCHW':
        oshape = (ishape[0], ishape[1], ishape[2]*scale, ishape[3]*scale)
        output_np = np.zeros(oshape, dtype=data.dtype)
        for b in range(oshape[0]):
            for c in range(oshape[1]):
                output_np[b, c, :, :] = upsample_nearest(data[b, c, :, :], scale)
        return output_np
    elif layout == 'NHWC':
        # Width is axis 2 in NHWC; using axis 1 for both spatial sizes only
        # works for square inputs.
        oshape = (ishape[0], ishape[1]*scale, ishape[2]*scale, ishape[3])
        output_np = np.zeros(oshape, dtype=data.dtype)
        for b in range(oshape[0]):
            for c in range(oshape[3]):
                output_np[b, :, :, c] = upsample_nearest(data[b, :, :, c], scale)
        return output_np
    else:
        raise ValueError("not support this layout {} yet".format(layout))

Просмотреть файл

@ -23,8 +23,10 @@
#include <topi/nn/mapping.h>
#include <topi/nn/pooling.h>
#include <topi/nn/softmax.h>
#include <topi/nn/upsampling.h>
#include <topi/vision/reorg.h>
#include <topi/image/resize.h>
#include <topi/vision/yolo2/region.h>
#include <topi/generic/default.h>
#include <topi/generic/extern.h>
@ -285,6 +287,12 @@ TVM_REGISTER_GLOBAL("topi.strided_slice")
*rv = strided_slice(args[0], args[1], args[2], args[3]);
});
/* Ops from nn/upsampling.h */
TVM_REGISTER_GLOBAL("topi.nn.upsampling")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::upsampling(args[0], args[1], args[2], args[3]);
});
/* Ops from nn/batch_norm.h */
TVM_REGISTER_GLOBAL("topi.nn.batch_norm_inference")
.set_body([](TVMArgs args, TVMRetValue *rv) {
@ -366,10 +374,18 @@ TVM_REGISTER_GLOBAL("topi.vision.reorg")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = vision::reorg(args[0], args[1]);
});
// Registers a global function named "topi.vision.yolo2.region". Its body
// forwards the first six packed arguments, in order, to vision::yolo2::region
// and stores the call's result in *rv.
// NOTE(review): argument meanings come from vision::yolo2::region
// (topi/vision/yolo2/region.h), which is not visible here - confirm there.
TVM_REGISTER_GLOBAL("topi.vision.yolo2.region")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = vision::yolo2::region(args[0], args[1], args[2], args[3], args[4], args[5]);
});
/* Ops from image/resize.h */
// Registers a global function named "topi.image.resize". Its body forwards
// the first five packed arguments, in order, to image::resize and stores the
// call's result in *rv.
// NOTE(review): argument meanings come from image::resize
// (topi/image/resize.h), which is not visible here - confirm there.
TVM_REGISTER_GLOBAL("topi.image.resize")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = image::resize(args[0], args[1], args[2], args[3], args[4]);
});
/* Generic schedules */
TVM_REGISTER_GLOBAL("topi.generic.default_schedule")
.set_body([](TVMArgs args, TVMRetValue *rv) {

Просмотреть файл

@ -0,0 +1,57 @@
"""Test code for bilinear scale """
import numpy as np
import tvm
import topi
import topi.testing
import math
def verify_bilinear_scale(batch, in_channel, in_height, in_width, out_height, out_width, layout='NCHW', align_corners=False):
    """Compare topi.image.resize with the Python reference on random input.

    A float32 placeholder shaped for ``layout`` ('NCHW' or 'NHWC') is resized
    to (out_height, out_width).  For each of 'llvm', 'cuda' and 'vulkan' whose
    context exists, the op is built and run, and its output is compared with
    topi.testing.bilinear_resize_python using rtol = atol = 1e-3.  Any other
    layout raises NotImplementedError.
    """
    if layout == 'NCHW':
        in_shape = (batch, in_channel, in_height, in_width)
        out_shape = (batch, in_channel, out_height, out_width)
    elif layout == 'NHWC':
        in_shape = (batch, in_height, in_width, in_channel)
        out_shape = (batch, out_height, out_width, in_channel)
    else:
        raise NotImplementedError(
            'Layout not supported {} '.format(layout))

    A = tvm.placeholder(in_shape, name='A', dtype='float32')
    dtype = A.dtype
    a_np = np.random.uniform(size=in_shape).astype(dtype)

    target_size = (out_height, out_width)
    B = topi.image.resize(A, target_size, layout=layout, align_corners=align_corners)
    b_np = topi.testing.bilinear_resize_python(a_np, target_size, layout, align_corners)

    def check_device(device):
        # Build and run on one target; targets without a context are skipped.
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            s = topi.generic.schedule_injective(B)
        a = tvm.nd.array(a_np, ctx)
        b = tvm.nd.array(np.zeros(out_shape, dtype=dtype), ctx)
        f = tvm.build(s, [A, B], device)
        f(a, b)
        np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-3, atol=1e-3)

    for device in ['llvm', 'cuda', 'vulkan']:
        check_device(device)
def test_resize():
    """Run the bilinear resize check for both layouts, with and without align_corners."""
    for layout in ('NCHW', 'NHWC'):
        # Upscale 32x32 -> 50x50 with the default align_corners.
        verify_bilinear_scale(4, 16, 32, 32, 50, 50, layout)
        # Downscale 64x64 -> 20x20 with align_corners enabled.
        verify_bilinear_scale(6, 32, 64, 64, 20, 20, layout, True)


if __name__ == "__main__":
    test_resize()

Просмотреть файл

@ -5,14 +5,26 @@ import topi
import topi.testing
import math
def verify_upsampling(batch, in_channel, in_height, in_width, scale):
A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
B = topi.nn.upsampling(A, scale)
out_shape = (batch, in_channel, in_height*scale, in_width*scale)
dtype = A.dtype
def verify_upsampling(batch, in_channel, in_height, in_width, scale, layout='NCHW'):
a_np = np.random.uniform(size=(batch, in_channel, in_height, in_width)).astype(dtype)
b_np = topi.testing.upsampling_python(a_np, scale)
if layout == 'NCHW':
A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
dtype = A.dtype
out_shape = (batch, in_channel, in_height*scale, in_width*scale)
a_np = np.random.uniform(size=(batch, in_channel, in_height, in_width)).astype(dtype)
elif layout == 'NHWC':
A = tvm.placeholder((batch, in_height, in_width, in_channel), name='A')
dtype = A.dtype
out_shape = (batch, in_height*scale, in_width*scale, in_channel)
a_np = np.random.uniform(size=(batch, in_height, in_width, in_channel)).astype(dtype)
else:
raise NotImplementedError(
'Layout not supported {} '.format(layout))
B = topi.nn.upsampling(A, scale, layout=layout)
b_np = topi.testing.upsampling_python(a_np, scale, layout)
def check_device(device):
ctx = tvm.context(device, 0)
@ -33,8 +45,12 @@ def verify_upsampling(batch, in_channel, in_height, in_width, scale):
check_device(device)
def test_upsampling():
    """Run the upsampling check at scale 2 and scale 3 for NCHW, then for NHWC."""
    for layout in ('NCHW', 'NHWC'):
        verify_upsampling(8, 16, 32, 32, 2, layout)
        verify_upsampling(12, 32, 64, 64, 3, layout)


if __name__ == "__main__":
    test_upsampling()