CPP implementation of L2Norm and LRN ops (#1157)
This commit is contained in:
Parent
464c8c2667
Commit
fb88b74e4a
@ -368,6 +368,41 @@ struct NMSParam : public dmlc::Parameter<NMSParam> {
  }
};

struct LRNParam : public dmlc::Parameter<LRNParam> {
  int size;
  int axis;
  float alpha;
  float beta;
  float bias;

  DMLC_DECLARE_PARAMETER(LRNParam) {
    DMLC_DECLARE_FIELD(size)
      .describe("The size of the local region to be considered for normalization.");
    DMLC_DECLARE_FIELD(axis)
      .describe("Input data layout channel axis.");
    DMLC_DECLARE_FIELD(alpha)
      .describe("The scaling parameter.");
    DMLC_DECLARE_FIELD(beta)
      .describe("The exponent parameter.");
    DMLC_DECLARE_FIELD(bias)
      .describe("The offset parameter.");
  }
  // constants
  static const constexpr int kData = 0;
};

struct L2NormalizeParam : public dmlc::Parameter<L2NormalizeParam> {
  float eps;
  Tuple<int> axis;

  DMLC_DECLARE_PARAMETER(L2NormalizeParam) {
    DMLC_DECLARE_FIELD(eps)
      .describe("float type epsilon value.");
    DMLC_DECLARE_FIELD(axis)
      .describe("Axis over which the normalization is applied.");
  }
};

}  // namespace top
}  // namespace nnvm
@ -243,3 +243,36 @@ def schedule_upsampling(_, outs, target):
        return topi.generic.schedule_injective(outs)

reg.register_pattern("upsampling", OpPattern.INJECTIVE)

@reg.register_compute("lrn")
def compute_lrn(attrs, inputs, _):
    """Compute definition of lrn"""
    size = attrs.get_int("size")
    axis = attrs.get_int("axis")
    alpha = attrs.get_float("alpha")
    beta = attrs.get_float("beta")
    bias = attrs.get_float("bias")
    return topi.nn.lrn(inputs[0], size, axis, alpha, beta, bias)

@reg.register_schedule("lrn")
def schedule_lrn(attrs, outs, target):
    """Schedule definition of lrn"""
    with tvm.target.create(target):
        return topi.generic.schedule_lrn(outs)

reg.register_pattern("lrn", OpPattern.OPAQUE)

@reg.register_compute("l2_normalize")
def compute_l2_normalize(attrs, inputs, _):
    """Compute definition of l2 normalize"""
    eps = attrs.get_float("eps")
    axis = attrs.get_int_tuple("axis")
    return topi.nn.l2_normalize(inputs[0], eps, axis)

@reg.register_schedule("l2_normalize")
def schedule_l2_normalize(attrs, outs, target):
    """Schedule definition of l2 normalize"""
    with tvm.target.create(target):
        return topi.generic.schedule_l2_normalize(outs)

reg.register_pattern("l2_normalize", OpPattern.OUT_ELEMWISE_FUSABLE)
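For orientation, a minimal end-to-end usage sketch of the two newly registered ops from the NNVM symbol frontend; it mirrors the tests added later in this change, and the shape and parameter values here are illustrative only:

import numpy as np
import tvm
import nnvm.compiler
import nnvm.symbol as sym
from tvm.contrib import graph_runtime

ishape = (1, 3, 20, 20)  # illustrative NCHW input shape
x = sym.Variable("x")
y = sym.lrn(x, size=3, axis=1, bias=1.0, alpha=1.0, beta=0.5)
# y = sym.l2_normalize(x, eps=0.001, axis=(1,))  # the other new op

graph, lib, _ = nnvm.compiler.build(y, "llvm", {"x": ishape})
m = graph_runtime.create(graph, lib, tvm.cpu(0))
m.run(x=np.random.uniform(size=ishape).astype("float32"))
out = m.get_output(0, tvm.nd.empty(ishape))  # output keeps the input shape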
@ -712,5 +712,52 @@ the input array by output[n, c, h, w, C] = data[n, C*16+c, h, w]
})
.set_support_level(1);

DMLC_REGISTER_PARAMETER(LRNParam);

inline bool LRNInferShape(const nnvm::NodeAttrs& attrs,
                          std::vector<TShape>* in_shape,
                          std::vector<TShape>* out_shape) {
  TShape dshape = (*in_shape)[0];
  TShape oshape = dshape;

  NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape);
  return true;
}

NNVM_REGISTER_OP(lrn)
.describe(R"code(LRN layer)code" NNVM_ADD_FILELINE)
.add_argument("data", "4D Tensor", "Input data.")
.set_attr_parser(ParamParser<LRNParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<LRNParam>)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<FInferShape>("FInferShape", LRNInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_support_level(1);

DMLC_REGISTER_PARAMETER(L2NormalizeParam);

inline bool L2NormalizeInferShape(const nnvm::NodeAttrs& attrs,
                                  std::vector<TShape>* in_shape,
                                  std::vector<TShape>* out_shape) {
  TShape dshape = (*in_shape)[0];
  TShape oshape = dshape;

  NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape);
  return true;
}

NNVM_REGISTER_OP(l2_normalize)
.describe(R"code(L2NORMALIZE layer)code" NNVM_ADD_FILELINE)
.add_argument("data", "4D Tensor", "Input data.")
.set_attr_parser(ParamParser<L2NormalizeParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<L2NormalizeParam>)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<FInferShape>("FInferShape", L2NormalizeInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseArbitraryLayout<1, 1>)
.set_support_level(1);

}  // namespace top
}  // namespace nnvm
@ -6,7 +6,6 @@ import nnvm.symbol as sym
import nnvm.compiler
from nnvm.testing.config import ctx_list


def helper(symbol, inputs, dtype,
           np_forward, np_backward=None, need_input=True, need_head_grads=True):
    ishapes = {}

@ -365,6 +364,65 @@ def test_pad():
    inputs = [('x', (1, 3, 28, 28), x)]
    helper(y, inputs, dtype, forward)

def verify_lrn(ishape, size, axis, bias, alpha, beta):
    x = sym.Variable("x")
    y = sym.lrn(x, size=size, axis=axis, bias=bias, alpha=alpha, beta=beta)
    dtype = "float32"
    x_np = np.random.uniform(size=ishape).astype(dtype)

    for target, ctx in ctx_list():
        graph, lib, _ = nnvm.compiler.build(y, target, {"x": ishape})
        m = graph_runtime.create(graph, lib, ctx)
        m.run(x=x_np)
        out = m.get_output(0, tvm.nd.empty(ishape))
        out_np = topi.testing.lrn_python(x_np, size, axis, bias, alpha, beta)
        np.testing.assert_allclose(out.asnumpy(), out_np, atol=1e-5, rtol=1e-5)

    # Checking LRN op followed by elementwise op relu
    z = sym.relu(y)
    x_np = np.random.uniform(low=-10.0, high=10.0, size=ishape).astype(dtype)
    for target, ctx in ctx_list():
        graph, lib, _ = nnvm.compiler.build(z, target, {"x": ishape})
        m = graph_runtime.create(graph, lib, ctx)
        m.run(x=x_np)
        out = m.get_output(0, tvm.nd.empty(ishape))
        out_np = topi.testing.lrn_python(x_np, size, axis, bias, alpha, beta)
        out_np = (out_np > 0) * out_np
        np.testing.assert_allclose(out.asnumpy(), out_np, atol=1e-5, rtol=1e-5)

def verify_l2_normalize(ishape, eps, axis):
    x = sym.Variable("x")
    y = sym.l2_normalize(x, eps=eps, axis=axis)
    dtype = "float32"
    x_np = np.random.uniform(size=ishape).astype(dtype)

    for target, ctx in ctx_list():
        graph, lib, _ = nnvm.compiler.build(y, target, {"x": ishape})
        m = graph_runtime.create(graph, lib, ctx)
        m.run(x=x_np)
        out = m.get_output(0, tvm.nd.empty(ishape))
        out_np = topi.testing.l2_normalize_python(x_np, eps, axis)
        np.testing.assert_allclose(out.asnumpy(), out_np, atol=1e-5, rtol=1e-5)

    # Checking L2 normalization op followed by elementwise op relu
    z = sym.relu(y)
    x_np = np.random.uniform(low=-10.0, high=10.0, size=ishape).astype(dtype)
    for target, ctx in ctx_list():
        graph, lib, _ = nnvm.compiler.build(z, target, {"x": ishape})
        m = graph_runtime.create(graph, lib, ctx)
        m.run(x=x_np)
        out = m.get_output(0, tvm.nd.empty(ishape))
        out_np = topi.testing.l2_normalize_python(x_np, eps, axis)
        out_np = (out_np > 0) * out_np
        np.testing.assert_allclose(out.asnumpy(), out_np, atol=1e-5, rtol=1e-5)

def test_lrn():
    verify_lrn((1, 3, 20, 20), 3, 1, 1.0, 1.0, 0.5)
    verify_lrn((1, 3, 20, 20), 3, 1, 2.0, 1.0, 0.75)

def test_l2_normalize():
    verify_l2_normalize((1, 3, 20, 20), 0.001, (1,))
    verify_l2_normalize((1, 3, 20, 20), 0.001, (1, 2))

if __name__ == "__main__":
    test_split()

@ -384,3 +442,5 @@ if __name__ == "__main__":
    test_softmax()
    test_squeeze()
    test_pad()
    test_lrn()
    test_l2_normalize()
@ -0,0 +1,106 @@
/*!
 * Copyright (c) 2018 by Contributors
 * \file cuda/normalization.h
 * \brief CUDA schedule for LRN and l2 normalization operations
 */
#ifndef TOPI_CUDA_NORMALIZATION_H_
#define TOPI_CUDA_NORMALIZATION_H_

#include "tvm/tvm.h"
#include "tvm/build_module.h"
#include "topi/tags.h"

namespace topi {
using namespace tvm;
namespace cuda {
/*!
 * \brief Create a CUDA schedule for LRN
 *
 * \param target The target to generate a schedule for.
 * \param outs The output tensors.
 *
 * \return A schedule for the given ops.
 */
inline Schedule schedule_lrn(const Target &target, const Array<Tensor>& outs) {
  Array<Operation> out_ops;
  for (auto t : outs) {
    out_ops.push_back(t->op);
  }
  Schedule s = create_schedule(out_ops);
  int num_thread = 64;
  IterVar block_x = tvm::thread_axis(Range(), "blockIdx.x");
  IterVar thread_x = tvm::thread_axis(Range(0, num_thread), "threadIdx.x");
  Tensor lrn = outs[0];
  Tensor sqr_sum_up = lrn->op->InputTensors()[1];
  Tensor sqr_sum = sqr_sum_up->op->InputTensors()[0];
  Tensor set_pad = sqr_sum->op->InputTensors()[0];
  s[set_pad].bind(set_pad->op.as<ComputeOpNode>()->axis[0], block_x);
  IterVar rxk = sqr_sum->op.as<ComputeOpNode>()->reduce_axis[0];
  IterVar xko, xki;
  s[sqr_sum].split(rxk, num_thread, &xko, &xki);
  Tensor srf = s.rfactor(sqr_sum, xki)[0];
  s[sqr_sum].bind(s[sqr_sum]->op.as<ComputeOpNode>()->axis[0], block_x);
  s[sqr_sum].bind(s[sqr_sum]->op.as<ComputeOpNode>()->reduce_axis[0], thread_x);
  s[srf].compute_at(s[sqr_sum], s[sqr_sum]->op.as<ComputeOpNode>()->reduce_axis[0]);
  s[sqr_sum_up].bind(sqr_sum_up->op.as<ComputeOpNode>()->axis[0], block_x);
  IterVar xto, xti;
  s[lrn].split_by_nparts(lrn->op.as<ComputeOpNode>()->axis[1], num_thread, &xto, &xti);
  s[lrn].bind(lrn->op.as<ComputeOpNode>()->axis[0], block_x);
  s[lrn].bind(xto, thread_x);

  return s;
}

/*!
 * \brief Create a CUDA schedule for L2 normalization
 *
 * \param target The target to generate a schedule for.
 * \param outs The output tensors.
 *
 * \return A schedule for the given ops.
 */
inline Schedule schedule_l2_normalize(const Target &target, const Array<Tensor>& outs) {
  Array<Operation> out_ops;
  for (auto t : outs) {
    out_ops.push_back(t->op);
  }
  Schedule s = create_schedule(out_ops);

  std::function<void(Operation)> traverse;
  traverse = [&](const Operation& op) {
    // Inline all one-to-one-mapping operators except the last stage (output)
    if (is_injective(op->tag) || op->tag == "l2_normalize") {
      if (!detail::contains(s->outputs, op)) {
        s[op].compute_inline();
      }
      for (auto tensor : op->InputTensors()) {
        if (tensor->op->InputTensors().size() > 0) {
          traverse(tensor->op);
        }
      }
    } else if (op->tag == "comm_reduce") {
      ScheduleReduce(target, op, s, false);
      for (auto tensor : op->InputTensors()) {
        traverse(tensor->op);
      }
    } else {
      LOG(ERROR) << "Unsupported operator " << op->tag;
    }
  };

  traverse(outs[0]->op);
  int num_thread = 64;
  Tensor l2_normalize = outs[0];
  IterVar block_x = tvm::thread_axis(Range(), "blockIdx.x");
  IterVar thread_x = tvm::thread_axis(Range(0, num_thread), "threadIdx.x");
  IterVar xto, xti;
  s[l2_normalize].split_by_nparts(l2_normalize->op.as<ComputeOpNode>()->axis[1],
                                  num_thread, &xto, &xti);
  s[l2_normalize].bind(l2_normalize->op.as<ComputeOpNode>()->axis[0], block_x);
  s[l2_normalize].bind(xto, thread_x);
  return s;
}
}  // namespace cuda
}  // namespace topi
#endif  // TOPI_CUDA_NORMALIZATION_H_
@ -0,0 +1,46 @@
/*!
 * Copyright (c) 2018 by Contributors
 * \brief l2 normalization op constructions
 * \file nn/l2_normalize.h
 */
#ifndef TOPI_NN_L2_NORMALIZE_H_
#define TOPI_NN_L2_NORMALIZE_H_

#include <string>
#include <algorithm>
#include "topi/tags.h"
#include "tvm/tvm.h"
namespace topi {
namespace nn {
using namespace tvm;

/*!
 * \brief L2 normalization inference operator
 *
 * \param data The input tensor. 4-D with shape [batch, channel, height, width]
 * \param eps Epsilon to prevent div by 0
 * \param axis Axes over which the normalization is applied
 * \param name The name of the operation
 * \param tag The tag to mark the operation
 *
 * \return A Tensor whose op member is the l2 normalization operation
 */
inline Tensor l2_normalize(const Tensor& data,
                           float eps,
                           const Array<Expr>& axis,
                           std::string name = "tensor",
                           std::string tag = "l2_normalize") {
  CHECK_EQ(data->shape.size(), 4) << "L2 normalization requires 4-D input";
  auto input_shape = data->shape;
  Tensor dot_value = pow(data, static_cast<float>(2.0));
  Tensor sum_value = topi::sum(dot_value, axis, true);
  Tensor expand_sum = topi::broadcast_to(sum_value, input_shape);
  return topi::broadcast_div(data,
                             topi::sqrt(tvm::compute(expand_sum->shape,
                                                     [&](const Array<Var>& i){
                                                       return (max(expand_sum(i), eps));
                                                     }, name = name, tag = tag)));
}
}  // namespace nn
}  // namespace topi
#endif  // TOPI_NN_L2_NORMALIZE_H_
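To make the tensor-expression code above easier to follow, here is a NumPy restatement of the same computation; it mirrors the topi.testing.l2_normalize_python reference added later in this change and is meant only as an illustrative sketch (the helper name is hypothetical):

import numpy as np

def l2_normalize_ref(x, eps, axis):
    # sum of squares over the requested axes, kept broadcastable
    sqr_sum = np.sum(x * x, axis=axis, keepdims=True)
    # clamp by eps before the square root, then divide elementwise
    return x / np.sqrt(np.maximum(np.broadcast_to(sqr_sum, x.shape), eps))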
@ -0,0 +1,76 @@
/*!
 * Copyright (c) 2018 by Contributors
 * \brief local response normalization op constructions
 * \file nn/local_response_norm.h
 */
#ifndef TOPI_NN_LOCAL_RESPONSE_NORM_H_
#define TOPI_NN_LOCAL_RESPONSE_NORM_H_

#include <string>

#include "topi/tags.h"
#include "tvm/tvm.h"

namespace topi {
namespace nn {
using namespace tvm;

/*!
 * \brief Local response normalization inference operator
 *
 * \param data The input tensor. 4-D shape NCHW or NHWC
 * \param size Integer to define normalization window size
 * \param axis Input data layout channel axis
 * \param alpha Float scaling factor
 * \param beta Exponent value
 * \param bias Offset to avoid dividing by zero
 * \param name The name of the operation
 * \param tag The tag to mark the operation
 *
 * \return A Tensor whose op member is the Local response normalization operation
 */
inline Tensor lrn(const Tensor& data,
                  int size,
                  int axis = 1,
                  float alpha = 0.0001,
                  float beta = 0.75,
                  float bias = 2,
                  std::string name = "tensor",
                  std::string tag = kBroadcast) {
  CHECK_EQ(data->shape.size(), 4) << "LRN requires 4-D input";
  CHECK_EQ(size % 2, 1) << "size should be odd number";
  CHECK(axis == 1 || axis == 3) << "axis should be 1 or 3 for NCHW and NHWC";
  auto input_shape = data->shape;
  Array<Expr> pad_before{ 0, 0, 0, 0};
  Array<Expr> pad_after{ 0, 0, 0, 0};
  pad_before.Set(axis, static_cast<Expr>(size/2));
  pad_after.Set(axis, static_cast<Expr>(size/2));
  auto pad_data = pad(data, pad_before, pad_after, 0, "pad_data");
  auto rxs = tvm::reduce_axis(Range(0, size), "rxs");
  Tensor sqr_sum;
  if (axis == 1) {
    sqr_sum = tvm::compute(input_shape,
                           [&](Var i, Var l, Var j, Var k) {
                             return tvm::sum(pad_data(i, l + rxs, j, k) *
                                             pad_data(i, l + rxs, j, k),
                                             {rxs});
                           });
  } else if (axis == 3) {
    sqr_sum = tvm::compute(input_shape,
                           [&](Var i, Var l, Var j, Var k) {
                             return tvm::sum(pad_data(i, l, j, k + rxs) *
                                             pad_data(i, l, j, k + rxs),
                                             {rxs});
                           });
  }
  auto sqrt_sum_up = tvm::compute(input_shape,
                                  [&](Var i, Var j, Var k, Var l) {
                                    return tvm::pow(bias +
                                                    (alpha * sqr_sum(i, j, k, l) / size),
                                                    beta);
                                  });
  return topi::broadcast_div(data, sqrt_sum_up);
}
}  // namespace nn
}  // namespace topi
#endif  // TOPI_NN_LOCAL_RESPONSE_NORM_H_
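As a companion to the compute above, an illustrative NumPy restatement of LRN for the NCHW case (axis == 1); it mirrors the topi.testing.lrn_python reference added later in this change, and the helper name is hypothetical:

import numpy as np

def lrn_nchw_ref(x, size, alpha, beta, bias):
    radius = size // 2
    # zero-pad along the channel axis, matching pad_data above
    padded = np.pad(x, ((0, 0), (radius, radius), (0, 0), (0, 0)), mode="constant")
    sqr_sum = np.zeros_like(x)
    for c in range(x.shape[1]):
        window = padded[:, c:c + size, :, :]      # size channels centered on c
        sqr_sum[:, c, :, :] = np.sum(window * window, axis=1)
    return x / np.power(bias + alpha * sqr_sum / size, beta)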
@ -0,0 +1,41 @@
/*!
 * Copyright (c) 2018 by Contributors
 * \file rocm/normalization.h
 * \brief rocm schedule for LRN and l2 normalization operations
 */
#ifndef TOPI_ROCM_NORMALIZATION_H_
#define TOPI_ROCM_NORMALIZATION_H_

#include "tvm/tvm.h"
#include "tvm/build_module.h"
#include "topi/tags.h"

namespace topi {
using namespace tvm;
namespace rocm {
/*!
 * \brief Create a rocm schedule for LRN
 *
 * \param target The target to generate a schedule for.
 * \param outs The output tensors.
 *
 * \return A schedule for the given ops.
 */
inline Schedule schedule_lrn(const Target &target, const Array<Tensor>& outs) {
  return topi::cuda::schedule_lrn(target, outs);
}

/*!
 * \brief Create a rocm schedule for L2 Normalization
 *
 * \param target The target to generate a schedule for.
 * \param outs The output tensors.
 *
 * \return A schedule for the given ops.
 */
inline Schedule schedule_l2_normalize(const Target &target, const Array<Tensor>& outs) {
  return topi::cuda::schedule_l2_normalize(target, outs);
}
}  // namespace rocm
}  // namespace topi
#endif  // TOPI_ROCM_NORMALIZATION_H_
@ -17,4 +17,4 @@ from .conv2d_transpose_nchw import schedule_conv2d_transpose_nchw
from .extern import schedule_extern
from .vision import schedule_region
from .vision import schedule_reorg
from .nn import schedule_lrn, schedule_l2norm
from .nn import schedule_lrn, schedule_l2_normalize
@ -4,8 +4,7 @@ from __future__ import absolute_import as _abs

import tvm
from .. import generic
from .. import tag
from .reduction import _schedule_reduce
from .. import cpp

@generic.schedule_lrn.register(["cuda"])
def schedule_lrn(outs):

@ -22,37 +21,18 @@ def schedule_lrn(outs):
    sch: Schedule
        The computation schedule for the op.
    """
    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
    s = tvm.create_schedule([x.op for x in outs])
    num_thread = 64
    block_x = tvm.thread_axis("blockIdx.x")
    thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
    target = tvm.target.current_target(allow_none=False)
    cpp_target = cpp.TEST_create_target(target.target_name)
    return cpp.cuda.schedule_lrn(cpp_target, outs)

    lrn = outs[0]
    sqr_sum_up = lrn.op.input_tensors[1]
    sqr_sum = sqr_sum_up.op.input_tensors[0]
    set_pad = sqr_sum.op.input_tensors[0]
    s[set_pad].bind(set_pad.op.axis[0], block_x)
    rxk = sqr_sum.op.reduce_axis[0]
    _, xki = s[sqr_sum].split(rxk, factor=num_thread)
    srf = s.rfactor(sqr_sum, xki)
    s[sqr_sum].bind(s[sqr_sum].op.axis[0], block_x)
    s[sqr_sum].bind(s[sqr_sum].op.reduce_axis[0], thread_x)
    s[srf].compute_at(s[sqr_sum], s[sqr_sum].op.reduce_axis[0])
    s[sqr_sum_up].bind(sqr_sum_up.op.axis[0], block_x)
    xto, _ = s[lrn].split(lrn.op.axis[1], nparts=num_thread)
    s[lrn].bind(lrn.op.axis[0], block_x)
    s[lrn].bind(xto, thread_x)
    return s

@generic.schedule_l2norm.register(["cuda"])
def schedule_l2norm(outs):
    """Schedule for L2norm
@generic.schedule_l2_normalize.register(["cuda"])
def schedule_l2_normalize(outs):
    """Schedule for L2 normalize

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of L2norm
        The computation graph description of L2 normalize
        in the format of an array of tensors.

    Returns

@ -60,32 +40,6 @@ def schedule_l2norm(outs):
    sch: Schedule
        The computation schedule for the op.
    """
    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
    s = tvm.create_schedule([x.op for x in outs])

    def traverse(OP):
        '''inline all one-to-one-mapping operators
        except the last stage (output)'''
        if tag.is_injective(OP.tag) or OP.tag == 'l2norm':
            if OP not in s.outputs:
                s[OP].compute_inline()
            for tensor in OP.input_tensors:
                if tensor.op.input_tensors:
                    traverse(tensor.op)
        elif OP.tag == 'comm_reduce':
            _schedule_reduce(OP, s, is_idx_reduce=False)
            for tensor in OP.input_tensors:
                traverse(tensor.op)
        else:
            raise RuntimeError("Unsupported operator tag: %s" % OP.tag)
    traverse(outs[0].op)

    num_thread = 64
    l2norm = outs[0]
    block_x = tvm.thread_axis("blockIdx.x")
    thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
    xto, _ = s[l2norm].split(l2norm.op.axis[1], nparts=num_thread)
    s[l2norm].bind(l2norm.op.axis[0], block_x)
    s[l2norm].bind(xto, thread_x)

    return s
    target = tvm.target.current_target(allow_none=False)
    cpp_target = cpp.TEST_create_target(target.target_name)
    return cpp.cuda.schedule_l2_normalize(cpp_target, outs)
@ -2,7 +2,7 @@
"""Generic nn operators"""
from __future__ import absolute_import as _abs
import tvm

from .. import cpp

def _default_schedule(outs, auto_inline):
    """Default schedule for llvm."""

@ -273,17 +273,18 @@ def schedule_lrn(outs):
    sch: Schedule
        The computation schedule for the op.
    """
    return _default_schedule(outs, False)

    target = tvm.target.current_target(allow_none=False)
    cpp_target = cpp.TEST_create_target(target.target_name)
    return cpp.generic.default_schedule(cpp_target, outs, False)

@tvm.target.generic_func
def schedule_l2norm(outs):
    """Schedule for l2norm
def schedule_l2_normalize(outs):
    """Schedule for l2 normalize

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of l2norm
        The computation graph description of l2 normalize
        in the format of an array of tensors.

    Returns

@ -291,4 +292,6 @@ def schedule_l2norm(outs):
    sch: Schedule
        The computation schedule for the op.
    """
    return _default_schedule(outs, False)
    target = tvm.target.current_target(allow_none=False)
    cpp_target = cpp.TEST_create_target(target.target_name)
    return cpp.generic.default_schedule(cpp_target, outs, False)
@ -16,4 +16,4 @@ from .conv2d_transpose import *
from .bnn import *
from .upsampling import *
from .local_response_norm import *
from .l2_norm import *
from .l2_normalize import *
@ -1,35 +0,0 @@
# pylint: disable=invalid-name
"""TVM operator for l2norm"""
from __future__ import absolute_import
import tvm
import topi

@tvm.target.generic_func
def l2norm_instance(data, eps, axis=None):
    """Perform L2norm on the input data

    For axis=None, y(i, j) = x(i, j) / sqrt(max(sum(x^2), eps))

    Parameters
    ----------
    data : tvm.Tensor
        4-D with NCHW or NHWC layout

    eps : float
        epsilon value

    axis : list of int
        axis over the normalization applied

    Returns
    -------
    output : tvm.Tensor
        4-D output with same shape
    """
    assert len(data.shape) == 4, "only support 4-dim lrn"
    dot_value = topi.cpp.pow(data, 2.0)
    sum_value = topi.sum(dot_value, axis=axis, keepdims=True)
    expand_sum = topi.broadcast_to(sum_value, data.shape)
    return topi.broadcast_div(data, topi.sqrt(\
        tvm.compute(expand_sum.shape, lambda i, j, k, l:\
        tvm.max(expand_sum[i, j, k, l], eps), tag='l2norm')))
@ -0,0 +1,29 @@
# pylint: disable=invalid-name
"""TVM operator for l2 normalize"""
from __future__ import absolute_import
import tvm
from .. import cpp

@tvm.target.generic_func
def l2_normalize(data, eps, axis=None):
    """Perform L2 normalization on the input data

    For axis=None, y(i, j) = x(i, j) / sqrt(max(sum(x^2), eps))

    Parameters
    ----------
    data : tvm.Tensor
        4-D with NCHW or NHWC layout

    eps : float
        epsilon value

    axis : list of int
        axis over which the normalization is applied

    Returns
    -------
    output : tvm.Tensor
        4-D output with same shape
    """
    return cpp.nn.l2_normalize(data, eps, axis)
@ -2,8 +2,7 @@
"""TVM operator for local response norm compute."""
from __future__ import absolute_import
import tvm
import topi
from .pad import pad
from .. import cpp

@tvm.target.generic_func
def lrn(data, size, axis=1, alpha=0.0001, beta=0.75, bias=2):

@ -42,27 +41,4 @@ def lrn(data, size, axis=1, alpha=0.0001, beta=0.75, bias=2):
    output : tvm.Tensor
        4-D output with same shape
    """
    assert len(data.shape) == 4, "only support 4-dim lrn"
    assert (size % 2) == 1, "size should be odd number"
    assert (axis == 1) or (axis == 3), "axis should be 1 or 3 for NCHW and NHWC"
    # Add padding on left & right of size radius first
    pad_after = pad_before = [0, 0, 0, 0]
    pad_after[axis] = pad_before[axis] = (size//2)
    pad_data = pad(data, pad_before, pad_after, name="pad_data")

    rxs = tvm.reduce_axis((0, size), name='rxs')
    if axis == 1:
        # NCHW layout
        sqr_sum = tvm.compute(data.shape, lambda i, j, k, l: tvm.sum(
            pad_data[i, j + rxs, k, l] * pad_data[i, j + rxs, k, l],
            axis=rxs))
    elif axis == 3:
        # NHWC layout
        sqr_sum = tvm.compute(data.shape, lambda i, j, k, l: tvm.sum(
            pad_data[i, j, k, l + rxs] * pad_data[i, j, k, l + rxs],
            axis=rxs))

    sqr_sum_up = tvm.compute(data.shape, lambda i, j, k, l: tvm.power(
        (bias + (alpha * sqr_sum[i, j, k, l] / size)), beta))

    return topi.broadcast_div(data, sqr_sum_up)
    return cpp.nn.lrn(data, size, axis, alpha, beta, bias)
@ -1,13 +1,18 @@
"""scheduler for normalization functions on rocm backend"""
from __future__ import absolute_import as _abs

import topi
import tvm
from .. import generic
from .. import cpp

@generic.schedule_lrn.register(["rocm", "gpu"])
def schedule_lrn(outs):
    return topi.cuda.schedule_lrn(outs)
    target = tvm.target.current_target(allow_none=False)
    cpp_target = cpp.TEST_create_target(target.target_name)
    return cpp.rocm.schedule_lrn(cpp_target, outs)

@generic.schedule_l2norm.register(["rocm", "gpu"])
def schedule_l2norm(outs):
    return topi.cuda.schedule_l2norm(outs)
@generic.schedule_l2_normalize.register(["rocm", "gpu"])
def schedule_l2_normalize(outs):
    target = tvm.target.current_target(allow_none=False)
    cpp_target = cpp.TEST_create_target(target.target_name)
    return cpp.rocm.schedule_l2_normalize(cpp_target, outs)
@ -16,3 +16,5 @@ from .bilinear_resize_python import bilinear_resize_python
from .reorg_python import reorg_python
from .region_python import region_python
from .shortcut_python import shortcut_python
from .lrn_python import lrn_python
from .l2_normalize_python import l2_normalize_python
@ -0,0 +1,27 @@
# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
"""L2 normalize in python"""
import numpy as np

def l2_normalize_python(a_np, eps, axis=None):
    """L2 normalize operator in NCHW layout.

    Parameters
    ----------
    a_np : numpy.ndarray
        4-D with shape [batch, in_channel, in_height, in_width]

    eps : float
        epsilon constant value
    axis : list of int
        axis over which the normalization is applied

    Returns
    -------
    l2_normalize_out : np.ndarray
        4-D with shape [batch, out_channel, out_height, out_width]
    """
    dot_value = np.power(a_np, 2.0)
    sqr_sum = np.sum(dot_value, axis, keepdims=True)
    sqrt_sum = np.sqrt(np.maximum(np.broadcast_to(sqr_sum, a_np.shape), eps))
    l2_normalize_out = np.divide(a_np, sqrt_sum)
    return l2_normalize_out
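A quick, illustrative sanity check one could run against the reference above: with a small eps, the result should have unit L2 norm along the reduced axes (the shape and eps values here are made up for the example):

import numpy as np

a = np.random.uniform(size=(1, 3, 20, 20)).astype("float32")
out = l2_normalize_python(a, 1e-6, axis=(1,))
norms = np.sqrt(np.sum(out * out, axis=1))   # per-pixel channel norm
assert np.allclose(norms, 1.0, atol=1e-3)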
@ -0,0 +1,53 @@
# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
"""LRN in python"""
from itertools import product
import numpy as np

def lrn_python(a_np, size, axis, bias, alpha, beta):
    """Local response normalization operator in NCHW layout.

    Parameters
    ----------
    a_np : numpy.ndarray
        4-D with shape [batch, in_channel, in_height, in_width]

    size : int
        normalization window size

    axis : int
        input data layout channel axis

    bias : float
        offset to avoid dividing by 0. constant value

    alpha : float
        constant value

    beta : float
        exponent constant value

    Returns
    -------
    lrn_out : np.ndarray
        4-D with shape [batch, out_channel, out_height, out_width]
    """
    radius = size // 2
    sqr_sum = np.zeros(shape=a_np.shape).astype(a_np.dtype)
    for i, j, k, l in product(*[range(_axis) for _axis in a_np.shape]):
        axis_size = a_np.shape[axis]
        if axis == 1:
            # NCHW layout
            sum_start = j-radius if j-radius >= 0 else 0
            sum_end = j+radius+1 if j+radius+1 < axis_size else axis_size
            sqr_sum[i, j, k, l] = sum(a_np[i, sum_start:sum_end, k, l] * \
                                      a_np[i, sum_start:sum_end, k, l])
        elif axis == 3:
            # NHWC layout
            sum_start = l-radius if l-radius >= 0 else 0
            sum_end = l+radius+1 if l+radius+1 < axis_size else axis_size
            sqr_sum[i, j, k, l] = sum(a_np[i, j, k, sum_start:sum_end] * \
                                      a_np[i, j, k, sum_start:sum_end])

    sqr_sum_up = np.power((bias + (alpha * sqr_sum / size)), beta)
    lrn_out = np.divide(a_np, sqr_sum_up)
    return lrn_out
@ -24,6 +24,8 @@
#include <topi/nn/pooling.h>
#include <topi/nn/softmax.h>
#include <topi/nn/upsampling.h>
#include <topi/nn/l2_normalize.h>
#include <topi/nn/local_response_norm.h>

#include <topi/vision/reorg.h>
#include <topi/image/resize.h>

@ -39,6 +41,7 @@
#include <topi/cuda/reduction.h>
#include <topi/cuda/softmax.h>
#include <topi/cuda/vision.h>
#include <topi/cuda/normalization.h>

#include <topi/x86/bnn.h>
#include <topi/x86/default.h>

@ -46,6 +49,7 @@

#include <topi/rocm/dense.h>
#include <topi/rocm/vision.h>
#include <topi/rocm/normalization.h>

namespace topi {

@ -359,6 +363,20 @@ TVM_REGISTER_GLOBAL("topi.nn.log_softmax")
  *rv = nn::log_softmax(args[0]);
  });

/* Ops from nn/l2_normalize.h */
TVM_REGISTER_GLOBAL("topi.nn.l2_normalize")
.set_body([](TVMArgs args, TVMRetValue *rv) {
  *rv = nn::l2_normalize(args[0], static_cast<double>(args[1]), args[2]);
  });

TVM_REGISTER_GLOBAL("topi.nn.lrn")
.set_body([](TVMArgs args, TVMRetValue *rv) {
  *rv = nn::lrn(args[0], args[1], args[2],
                static_cast<double>(args[3]),
                static_cast<double>(args[4]),
                static_cast<double>(args[5]));
  });

TVM_REGISTER_GLOBAL("topi.vision.reorg")
.set_body([](TVMArgs args, TVMRetValue *rv) {
  *rv = vision::reorg(args[0], args[1]);

@ -435,6 +453,17 @@ TVM_REGISTER_GLOBAL("topi.rocm.schedule_region")
.set_body([](TVMArgs args, TVMRetValue *rv) {
  *rv = topi::rocm::schedule_region(args[0], args[1]);
  });

TVM_REGISTER_GLOBAL("topi.rocm.schedule_lrn")
.set_body([](TVMArgs args, TVMRetValue *rv) {
  *rv = topi::rocm::schedule_lrn(args[0], args[1]);
  });

TVM_REGISTER_GLOBAL("topi.rocm.schedule_l2_normalize")
.set_body([](TVMArgs args, TVMRetValue *rv) {
  *rv = topi::rocm::schedule_l2_normalize(args[0], args[1]);
  });

/* CUDA schedules */
TVM_REGISTER_GLOBAL("topi.cuda.dense_cuda")
.set_body([](TVMArgs args, TVMRetValue *rv) {

@ -481,6 +510,16 @@ TVM_REGISTER_GLOBAL("topi.cuda.schedule_region")
  *rv = topi::cuda::schedule_region(args[0], args[1]);
  });

TVM_REGISTER_GLOBAL("topi.cuda.schedule_lrn")
.set_body([](TVMArgs args, TVMRetValue *rv) {
  *rv = topi::cuda::schedule_lrn(args[0], args[1]);
  });

TVM_REGISTER_GLOBAL("topi.cuda.schedule_l2_normalize")
.set_body([](TVMArgs args, TVMRetValue *rv) {
  *rv = topi::cuda::schedule_l2_normalize(args[0], args[1]);
  });

/*! \brief Builder function for instantiating schedules. */
using FTVMScheduleBuilder = std::function<
  tvm::Schedule(const tvm::Target& target, const tvm::Array<tvm::Tensor>& outs)>;
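These packed-function registrations are what the Python side reaches through the topi.cpp FFI namespace; a minimal illustrative call sequence, mirroring the python_cpp tests added at the end of this change (the shape, parameter values, and "llvm" target are examples only):

import tvm
import topi

A = tvm.placeholder((1, 3, 20, 20), name="A")
B = topi.cpp.nn.lrn(A, 3, 1, 0.0001, 0.75, 2.0)             # topi.nn.lrn registered above
target = topi.cpp.TEST_create_target("llvm")
s = topi.cpp.generic.default_schedule(target, [B], False)   # generic fallback schedule
f = tvm.build(s, [A, B], "llvm")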
@ -1,44 +1,18 @@
"""Test code for L2 norm"""
"""Test code for L2 normalization"""
import numpy as np
import tvm
import topi
from topi.util import get_const_tuple
import topi.testing

def l2norm_instance_python(a_np, eps, axis=None):
    """L2 norm operator in NCHW layout.
def verify_l2_normalize(ishape, eps, axis=None):

    Parameters
    ----------
    a_np : numpy.ndarray
        4-D with shape [batch, in_channel, in_height, in_width]

    eps : float
        epsilon constant value
    axis : list of int
        axis over the normalization applied

    Returns
    -------
    l2norm_out : np.ndarray
        4-D with shape [batch, out_channel, out_height, out_width]
    """
    batch, axis1, axis2, axis3 = a_np.shape
    sqr_sum = np.zeros(shape=(batch,)).astype(a_np.dtype)
    sqrt_sum = np.zeros(shape=(batch,)).astype(a_np.dtype)
    l2norm_out = np.zeros(shape=a_np.shape).astype(a_np.dtype)
    dot_value = np.power(a_np, 2.0)
    sqr_sum = np.sum(dot_value, axis, keepdims=True)
    sqrt_sum = np.sqrt(np.maximum(np.broadcast_to(sqr_sum, a_np.shape), eps))
    return np.divide(a_np, sqrt_sum)

def verify_l2norm(n, c, h, w, eps, axis=None):

    A = tvm.placeholder((n, c, h, w), name='A')
    B = topi.nn.l2norm_instance(A, eps, axis)
    A = tvm.placeholder(ishape, name='A')
    B = topi.nn.l2_normalize(A, eps, axis)
    dtype = A.dtype

    a_np = np.random.uniform(size=(n, c, h, w)).astype(dtype)
    b_np = l2norm_instance_python(a_np, eps, axis)
    a_np = np.random.uniform(size=ishape).astype(dtype)
    b_np = topi.testing.l2_normalize_python(a_np, eps, axis)

    def check_device(device):
        ctx = tvm.context(device, 0)

@ -47,7 +21,10 @@ def verify_l2norm(n, c, h, w, eps, axis=None):
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            s = topi.generic.schedule_l2norm(B)
            if device == 'llvm':
                s = topi.generic.schedule_l2_normalize([B])
            else:
                s = topi.cuda.schedule_l2_normalize([B])
        a = tvm.nd.array(a_np, ctx)
        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
        f = tvm.build(s, [A, B], device)

@ -57,14 +34,14 @@ def verify_l2norm(n, c, h, w, eps, axis=None):
    for device in ['llvm', 'cuda', 'opencl', 'metal', 'rocm', 'vulkan']:
        check_device(device)

def test_l2norm():
    verify_l2norm(1, 3, 20, 20, 0.001)
    verify_l2norm(1, 3, 20, 20, 0.001, 1)
    verify_l2norm(1, 3, 20, 20, 0.001, (1, 2))
    verify_l2norm(1, 3, 20, 20, 0.001, (2, 3))
    verify_l2norm(1, 3, 20, 20, 0.001, (0, 3))
    verify_l2norm(1, 3, 20, 20, 0.001, (0, 2, 3))
def test_l2_normalize():
    verify_l2_normalize((1, 3, 20, 20), 0.001)
    verify_l2_normalize((1, 3, 20, 20), 0.001, (1,))
    verify_l2_normalize((1, 3, 20, 20), 0.001, (1, 2))
    verify_l2_normalize((1, 3, 20, 20), 0.001, (2, 3))
    verify_l2_normalize((1, 3, 20, 20), 0.001, (0, 3))
    verify_l2_normalize((1, 3, 20, 20), 0.001, (0, 2, 3))


if __name__ == "__main__":
    test_l2norm()
    test_l2_normalize()
@ -3,63 +3,7 @@ import numpy as np
import tvm
import topi
from topi.util import get_const_tuple

def lrn_python(a_np, size, axis, bias, alpha, beta):
    """Local response norm operator in NCHW layout.

    Parameters
    ----------
    a_np : numpy.ndarray
        4-D with shape [batch, in_channel, in_height, in_width]

    size : int
        normalization window size

    axis : int
        input data layout channel axis

    bias : float
        offset to avoid dividing by 0. constant value

    alpha : float
        constant value

    beta : float
        exponent constant value

    Returns
    -------
    b_np : np.ndarray
        4-D with shape [batch, out_channel, out_height, out_width]
    """
    axis0, axis1, axis2, axis3 = a_np.shape
    radius = size // 2
    sqr_sum = np.zeros(shape=a_np.shape).astype(a_np.dtype)
    sqr_sum_up = np.zeros(shape=a_np.shape).astype(a_np.dtype)
    lrn_out = np.zeros(shape=a_np.shape).astype(a_np.dtype)
    def sum_dot_values(i, j, k, l):
        axis_size = a_np.shape[axis]
        if (axis == 1):
            # NCHW layout
            sum_start = j-radius if j-radius >= 0 else 0
            sum_end = j+radius+1 if j+radius+1 < axis_size else axis_size
            sqr_sum[i, j, k, l] = sum(a_np[i, sum_start:sum_end, k, l] * \
                                      a_np[i, sum_start:sum_end, k, l])
        elif (axis == 3):
            # NHWC layout
            sum_start = l-radius if l-radius >= 0 else 0
            sum_end = l+radius+1 if l+radius+1 < axis_size else axis_size
            sqr_sum[i, j, k, l] = sum(a_np[i, j, k, sum_start:sum_end] * \
                                      a_np[i, j, k, sum_start:sum_end])

    for i in range(axis0):
        for j in range(axis1):
            for k in range(axis2):
                for l in range(axis3):
                    sum_dot_values(i, j, k, l)

    sqr_sum_up = np.power((bias + (alpha * sqr_sum / size)), beta)
    return np.divide(a_np, sqr_sum_up)
import topi.testing

def verify_lrn(shape, size, axis, bias, alpha, beta):
    A = tvm.placeholder(shape, name='A')

@ -67,16 +11,19 @@ def verify_lrn(shape, size, axis, bias, alpha, beta):
    dtype = A.dtype

    a_np = np.random.uniform(size=shape).astype(dtype)
    b_np = lrn_python(a_np, size, axis, bias, alpha, beta)
    b_np = topi.testing.lrn_python(a_np, size, axis, bias, alpha, beta)

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            s = topi.generic.schedule_lrn(B)
            if device == 'llvm':
                s = topi.generic.schedule_lrn([B])
            else:
                s = topi.cuda.schedule_lrn([B])
        ctx = tvm.context(device, 0)
        a = tvm.nd.array(a_np, ctx)
        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
        f = tvm.build(s, [A, B], device)

@ -87,9 +34,9 @@ def verify_lrn(shape, size, axis, bias, alpha, beta):
        check_device(device)

def test_lrn():
    verify_lrn((1, 3, 5, 5), 3, 1, 1, 1, 0.5)
    verify_lrn((1, 3, 5, 5), 3, 3, 1, 1, 0.5)
    verify_lrn((1, 3, 20, 20), 3, 1, 2, 1, 0.75)
    verify_lrn((1, 3, 5, 5), 3, 1, 1.0, 1.0, 0.5)
    verify_lrn((1, 3, 5, 5), 3, 3, 1.0, 1.0, 0.5)
    verify_lrn((1, 3, 20, 20), 3, 1, 2.0, 1.0, 0.75)

if __name__ == "__main__":
    test_lrn()
@ -0,0 +1,48 @@
"""Test code for l2 normalization"""
import numpy as np
import tvm
import topi
import logging
from topi.util import get_const_tuple
import topi.testing

def verify_l2_normalize(shape, eps, axis=None):
    '''Verify l2 normalization operator by comparing outputs from tvm and numpy implementation'''
    A = tvm.placeholder(shape, name='A')
    B = topi.cpp.nn.l2_normalize(A, eps, axis)
    dtype = A.dtype

    a_np = np.random.uniform(size=shape).astype(dtype)
    b_np = topi.testing.l2_normalize_python(a_np, eps, axis)

    def check_device(device):
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        target = topi.cpp.TEST_create_target(device)
        if device == "llvm":
            s = topi.cpp.generic.default_schedule(target, [B], False)
        else:
            s = topi.cpp.cuda.schedule_l2_normalize(target, [B])
        ctx = tvm.context(device, 0)
        a = tvm.nd.array(a_np, ctx)
        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
        func = tvm.build(s, [A, B], device, name="l2_normalize")
        func(a, b)
        np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)

    for device in ['cuda', 'opencl', 'metal', 'rocm', 'llvm']:
        check_device(device)

def test_l2_normalize():
    verify_l2_normalize((1, 3, 20, 20), 0.001)
    verify_l2_normalize((1, 3, 20, 20), 0.001, (1,))
    verify_l2_normalize((1, 3, 20, 20), 0.001, (1, 2))
    verify_l2_normalize((1, 3, 20, 20), 0.001, (2, 3))
    verify_l2_normalize((1, 3, 20, 20), 0.001, (0, 3))
    verify_l2_normalize((1, 3, 20, 20), 0.001, (0, 2, 3))

if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    test_l2_normalize()
@ -0,0 +1,44 @@
"""Test code for LRN"""
import numpy as np
import tvm
import topi
import logging
from topi.util import get_const_tuple
import topi.testing

def verify_lrn(shape, size, axis, bias, alpha, beta):
    '''Verify Local response normalization operator by comparing outputs from tvm and numpy implementation'''
    A = tvm.placeholder(shape, name='A')
    B = topi.cpp.nn.lrn(A, size, axis, alpha, beta, bias)
    dtype = A.dtype

    a_np = np.random.uniform(size=shape).astype(dtype)
    b_np = topi.testing.lrn_python(a_np, size, axis, bias, alpha, beta)
    def check_device(device):
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        target = topi.cpp.TEST_create_target(device)
        if device == "llvm":
            s = topi.cpp.generic.default_schedule(target, [B], False)
        else:
            s = topi.cpp.cuda.schedule_lrn(target, [B])
        ctx = tvm.context(device, 0)
        a = tvm.nd.array(a_np, ctx)
        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
        f = tvm.build(s, [A, B], device)
        f(a, b)
        np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-1)

    for device in ['cuda', 'opencl', 'metal', 'rocm', 'llvm']:
        check_device(device)

def test_lrn():
    verify_lrn((1, 3, 5, 5), 3, 3, 1.0, 1.0, 0.5)
    verify_lrn((1, 3, 5, 5), 3, 3, 1.0, 1.0, 0.5)
    verify_lrn((1, 3, 20, 20), 3, 1, 2.0, 1.0, 0.75)

if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    test_lrn()