Steven S. Lyubomirsky 2018-10-11 21:57:51 -07:00 committed by Tianqi Chen
Parent 493fc04028
Commit 7e8a876737
5 changed files: 151 additions and 3 deletions

View file

@@ -55,6 +55,7 @@ This level enables typical convnet models.
tvm.relay.nn.global_avg_pool2d
tvm.relay.nn.upsampling
tvm.relay.nn.batch_flatten
tvm.relay.nn.pad
tvm.relay.nn.lrn
tvm.relay.nn.l2_normalize

View file

@@ -223,8 +223,19 @@ struct UpSamplingAttrs : public tvm::AttrsNode<UpSamplingAttrs> {
}
};
/*! \brief Attributes used for the padding operator */
struct PadAttrs : public tvm::AttrsNode<PadAttrs> {
double pad_value;
Array<Array<IndexExpr> > pad_width;
TVM_DECLARE_ATTRS(PadAttrs, "relay.attrs.PadAttrs") {
TVM_ATTR_FIELD(pad_value).set_default(0.0)
.describe("Specifies the strides of the convolution.");
TVM_ATTR_FIELD(pad_width)
.describe("Number of values padded to the edges of each axis, "
"in the format of ((before_1, after_1), ..., (before_N, after_N))");
}
};
/*! \brief Attributes for LRN operator */
struct LRNAttrs : public tvm::AttrsNode<LRNAttrs> {

View file

@@ -429,7 +429,6 @@ def batch_flatten(data):
"""
return _make.batch_flatten(data)
def relu(data):
"""Rectified linear unit.
@@ -449,6 +448,32 @@ def relu(data):
return _make.relu(data)
def pad(data,
pad_width,
pad_value=0.0):
r"""Padding
This operator takes in a tensor and pads each axis by the specified
widths using the specified value.
Parameters
----------
data: relay.Expr
The input data to the operator
pad_width: tuple of <tuple of <int>>, required
Number of values padded to the edges of each axis, in the format
of ((before_1, after_1), ..., (before_N, after_N))
pad_value: float, optional, default=0.0
The value used for padding
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.pad(data, pad_width, pad_value)
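For intuition, the constant padding this operator performs matches numpy.pad in "constant" mode; a minimal standalone sketch (the NumPy usage is illustrative only, not part of this commit):

import numpy as np

# Pad a 2x3 input with one row before/after axis 0 and two
# columns before/after axis 1, filling with pad_value=0.0.
x = np.arange(6, dtype="float32").reshape(2, 3)
padded = np.pad(x, pad_width=((1, 1), (2, 2)),
                mode="constant", constant_values=0.0)
assert padded.shape == (4, 7)  # (2 + 1 + 1, 3 + 2 + 2)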
def lrn(data, size=5, axis=1, bias=2, alpha=.00001, beta=0.75):
"""This operator takes data as input and does local response normalization.
@@ -484,9 +509,9 @@ def lrn(data, size=5, axis=1, bias=2, alpha=.00001, beta=0.75):
result : relay.Expr
The computed result.
"""
return _make.lrn(data, size, axis, alpha, beta, bias)
def l2_normalize(data, eps, axis=None):
"""Perform L2 normalization on the input data

src/relay/op/nn/pad.cc (new file, 86 additions)
View file

@@ -0,0 +1,86 @@
/*!
* Copyright (c) 2018 by Contributors
* \file pad.cc
* \brief Implementation of operator pad
*/
#include <tvm/ir_operator.h>
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h>
#include <vector>
#include "layout.h"
namespace tvm {
namespace relay {
TVM_REGISTER_NODE_TYPE(PadAttrs);
bool PadRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;
const PadAttrs* param = attrs.as<PadAttrs>();
CHECK(param != nullptr);
// check that pad widths match lengths
CHECK(data->shape.size() == param->pad_width.size())
<< "There should be as many pad width pairs as shape dimensions "
<< "but the shape has " << data->shape.size() << " dimensions "
<< "and there are " << param->pad_width.size() << " pad width pairs.";
// each pad width element should be a pair of non-negative integers
std::vector<IndexExpr> oshape;
for (size_t i = 0; i < param->pad_width.size(); i++) {
CHECK(param->pad_width[i].size() == 2)
<< "Each pad width element should be a pair but at index " << i
<< " there are " << param->pad_width[i].size() << " elements.";
auto width1 = as_const_int(param->pad_width[i][0]);
auto width2 = as_const_int(param->pad_width[i][1]);
CHECK(width1 != nullptr);
CHECK(width2 != nullptr);
CHECK(*width1 >= 0)
<< "Param width elements should be positive but first pad width at "
<< "index " << i << " is " << *width1 << ".";
CHECK(*width2 >= 0)
<< "Param width elements should be positive but first pad width at "
<< "index " << i << " is " << *width2 << ".";
auto padding = make_const(data->shape[i].type(), *width1 + *width2);
oshape.push_back(data->shape[i] + padding);
}
reporter->Assign(types[1], TensorTypeNode::make(Array<IndexExpr>(oshape),
data->dtype));
return true;
}
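The shape rule PadRel enforces is out_i = in_i + before_i + after_i for each axis i; a pure-Python sketch of the same computation, for illustration only (pad_out_shape is a hypothetical helper, not part of this commit):

def pad_out_shape(in_shape, pad_width):
    # Mirrors PadRel: one (before, after) pair per axis, both
    # non-negative; each output dim grows by their sum.
    assert len(in_shape) == len(pad_width)
    out = []
    for dim, (before, after) in zip(in_shape, pad_width):
        assert before >= 0 and after >= 0
        out.append(dim + before + after)
    return tuple(out)

assert pad_out_shape((1, 2, 3, 4),
                     ((1, 1), (2, 2), (3, 3), (4, 4))) == (3, 6, 9, 12)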
// Handler to create a call to the padding op used by front-end FFI
Expr MakePad(Expr data, Array<Array<IndexExpr> > pad_width, double pad_value) {
auto attrs = make_node<PadAttrs>();
attrs->pad_value = pad_value;
attrs->pad_width = std::move(pad_width);
static const Op& op = Op::Get("nn.pad");
return CallNode::make(op, {data}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay.op.nn._make.pad")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 3>(MakePad, args, rv);
});
RELAY_REGISTER_OP("nn.pad")
.describe(R"code(Pad for n-D tensor.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
.add_type_rel("Pad", PadRel);
} // namespace relay
} // namespace tvm

View file

@@ -196,10 +196,35 @@ def test_flatten_infer_type():
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType((d1, ((2*d3)*3)), "float32")
def test_pad_infer_type():
# entirely concrete case
ib = relay.ir_builder.IRBuilder()
n, c, h, w = 1, 2, 3, 4
t = ib.param("t", relay.TensorType((n, c, h, w), "float32"))
with ib.function(t) as func:
ib.ret(relay.nn.pad(t.var, ((1, 1), (2, 2), (3, 3), (4, 4))))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.TensorType((3, 6, 9, 12), "float32")
# some symbolic values
ib = relay.ir_builder.IRBuilder()
n, c, h, w = tvm.var("n"), 2, 3, tvm.var("w")
t = ib.param("t", relay.TensorType((n, c, h, w), "float32"))
with ib.function(t) as func:
ib.ret(relay.nn.pad(t.var, ((1, 1), (2, 2), (3, 3), (4, 4))))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.TensorType((n + 2, 6, 9, w + 8), "float32")
if __name__ == "__main__":
test_conv2d_infer_type()
test_pool2d_infer_type()
test_upsampling_infer_type()
test_flatten_infer_type()
test_pad_infer_type()
test_conv2d_transpose_infer_type()