[RELAY][OP] conv2d, ShapeExpr->IndexExpr (#1798)
This commit is contained in:
Родитель
147f3ad526
Коммит
9afde69b84
|
@ -56,6 +56,22 @@ namespace tvm {
|
|||
__fvisit__(#FieldName, &FieldName)
|
||||
|
||||
|
||||
/*!
|
||||
* \brief Create a NodeRef type that represents null.
|
||||
* \tparam TNodeRef the type to be created.
|
||||
* \return A instance that will represent None.
|
||||
*/
|
||||
template<typename TNodeRef>
|
||||
inline TNodeRef NullValue() {
|
||||
return TNodeRef(NodePtr<Node>(nullptr));
|
||||
}
|
||||
|
||||
template<>
|
||||
inline Type NullValue<Type>() {
|
||||
return Type(Type::Handle, 0, 0);
|
||||
}
|
||||
|
||||
|
||||
/*! \brief Error thrown during attribute checking. */
|
||||
struct AttrError : public dmlc::Error {
|
||||
/*!
|
||||
|
|
|
@ -114,7 +114,7 @@ inline TNodeRef TVMArgValue::AsNodeRef() const {
|
|||
static_assert(
|
||||
std::is_base_of<NodeRef, TNodeRef>::value,
|
||||
"Conversion only works for NodeRef");
|
||||
if (type_code_ == kNull) return TNodeRef();
|
||||
if (type_code_ == kNull) return TNodeRef(NodePtr<Node>(nullptr));
|
||||
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
|
||||
NodePtr<Node>& sptr = *ptr<NodePtr<Node> >();
|
||||
CHECK(NodeTypeChecker<TNodeRef>::Check(sptr.get()))
|
||||
|
|
|
@ -0,0 +1,72 @@
|
|||
/*!
|
||||
* Copyright (c) 2018 by Contributors
|
||||
* \file tvm/relay/attrs/nn.h
|
||||
* \brief Auxiliary attributes for nn operators.
|
||||
*/
|
||||
#ifndef TVM_RELAY_ATTRS_NN_H_
|
||||
#define TVM_RELAY_ATTRS_NN_H_
|
||||
|
||||
#include <tvm/attrs.h>
|
||||
#include <string>
|
||||
|
||||
namespace tvm {
|
||||
namespace relay {
|
||||
|
||||
/*! \brief Attributes used in convolution operators */
|
||||
struct ConvAttrs : public tvm::AttrsNode<ConvAttrs> {
|
||||
Array<IndexExpr> strides;
|
||||
Array<IndexExpr> padding;
|
||||
Array<IndexExpr> dilation;
|
||||
int groups;
|
||||
IndexExpr channels;
|
||||
Array<IndexExpr> kernel_size;
|
||||
std::string data_layout;
|
||||
std::string weight_layout;
|
||||
std::string out_layout;
|
||||
DataType out_dtype;
|
||||
|
||||
TVM_DECLARE_ATTRS(ConvAttrs, "relay.attrs.ConvAttrs") {
|
||||
TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1}))
|
||||
.describe("Specifies the strides of the convolution.");
|
||||
TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0}))
|
||||
.describe("If padding is non-zero, then the input is implicitly zero-padded"
|
||||
"on both sides for padding number of points");
|
||||
TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1, 1}))
|
||||
.describe("Specifies the dilation rate to use for dilated convolution.");
|
||||
TVM_ATTR_FIELD(groups).set_default(1)
|
||||
.describe("Controls the connections between inputs and outputs."
|
||||
"At groups=1, all inputs are convolved to all outputs."
|
||||
"At groups=2, the operation becomes equivalent to having two convolution"
|
||||
"layers side by side, each seeing half the input channels, and producing"
|
||||
"half the output channels, and both subsequently concatenated.");
|
||||
TVM_ATTR_FIELD(channels)
|
||||
.describe("The number of output channels in the convolution."
|
||||
" If it is not set, inferred by shape of the weight.")
|
||||
.set_default(NullValue<IndexExpr>());
|
||||
TVM_ATTR_FIELD(kernel_size)
|
||||
.describe("Specifies the dimensions of the convolution window.")
|
||||
.set_default(NullValue<Array<IndexExpr> >());
|
||||
TVM_ATTR_FIELD(data_layout).set_default("NCHW")
|
||||
.describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
|
||||
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
|
||||
"dimensions respectively. Convolution is applied on the 'H' and"
|
||||
"'W' dimensions.");
|
||||
TVM_ATTR_FIELD(weight_layout).set_default("OIHW")
|
||||
.describe("Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc."
|
||||
"'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width"
|
||||
"dimensions respectively.");
|
||||
TVM_ATTR_FIELD(out_layout).set_default("__undef__")
|
||||
.describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc."
|
||||
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
|
||||
"dimensions respectively. Default to be same as input layout.");
|
||||
|
||||
// use 0 bits to indicate none.
|
||||
TVM_ATTR_FIELD(out_dtype)
|
||||
.set_default(Int(0))
|
||||
.describe("Output data type, set to explicit type under mixed precision setting");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace relay
|
||||
} // namespace tvm
|
||||
#endif // TVM_RELAY_ATTRS_NN_H_
|
|
@ -37,7 +37,7 @@ using DataType = ::tvm::Type;
|
|||
/*!
|
||||
* \brief Symbolic expression for tensor shape.
|
||||
*/
|
||||
using ShapeExpr = ::tvm::Expr;
|
||||
using IndexExpr = ::tvm::Expr;
|
||||
|
||||
/*!
|
||||
* \brief Hash function for nodes.
|
||||
|
|
|
@ -286,7 +286,9 @@ class CallNode : public ExprNode {
|
|||
v->Visit("_checked_type_", &checked_type_);
|
||||
}
|
||||
|
||||
TVM_DLL static Call make(Expr op, Array<Expr> args, Attrs attrs = Attrs(),
|
||||
TVM_DLL static Call make(Expr op,
|
||||
Array<Expr> args,
|
||||
Attrs attrs = Attrs(),
|
||||
Array<Type> ty_args = Array<Type>());
|
||||
|
||||
static constexpr const char* _type_key = "relay.Call";
|
||||
|
|
|
@ -70,9 +70,9 @@ class TensorTypeNode : public BaseTensorTypeNode {
|
|||
public:
|
||||
/*!
|
||||
* \brief The shape of the tensor,
|
||||
* represented by ShapeExpr(tvm::Expr).
|
||||
* represented by IndexExpr(tvm::Expr).
|
||||
*/
|
||||
Array<ShapeExpr> shape;
|
||||
Array<IndexExpr> shape;
|
||||
/*! \brief The content data type */
|
||||
DataType dtype;
|
||||
|
||||
|
@ -82,7 +82,7 @@ class TensorTypeNode : public BaseTensorTypeNode {
|
|||
v->Visit("span", &span);
|
||||
}
|
||||
|
||||
TVM_DLL static TensorType make(Array<ShapeExpr> shape, DataType dtype);
|
||||
TVM_DLL static TensorType make(Array<IndexExpr> shape, DataType dtype);
|
||||
|
||||
/*! \brief Construct an scalar containing elements of dtype. */
|
||||
TVM_DLL static TensorType Scalar(DataType dtype);
|
||||
|
@ -273,8 +273,10 @@ class TypeReporterNode : public Node {
|
|||
* \brief assert shape expression equals each other.
|
||||
* \param lhs The left operand.
|
||||
* \param rhs The right operand.
|
||||
* \return false if assertation can be proven to have failed
|
||||
* true if solver can still proceed.
|
||||
*/
|
||||
TVM_DLL virtual void AssertEQ(const ShapeExpr& lhs, const ShapeExpr& rhs) = 0;
|
||||
TVM_DLL virtual bool AssertEQ(const IndexExpr& lhs, const IndexExpr& rhs) = 0;
|
||||
|
||||
// solver is not serializable.
|
||||
void VisitAttrs(tvm::AttrVisitor* v) final {}
|
||||
|
|
|
@ -521,6 +521,12 @@ class TVMArgValue : public TVMPODValue_ {
|
|||
if (type_code_ == kStr) {
|
||||
return String2TVMType(operator std::string());
|
||||
}
|
||||
// None type
|
||||
if (type_code_ == kNull) {
|
||||
TVMType t;
|
||||
t.code = kHandle; t.bits = 0; t.lanes = 0;
|
||||
return t;
|
||||
}
|
||||
TVM_CHECK_TYPE_CODE(type_code_, kTVMType);
|
||||
return value_.v_type;
|
||||
}
|
||||
|
@ -878,6 +884,7 @@ inline std::ostream& operator<<(std::ostream& os, TVMType t) { // NOLINT(*)
|
|||
#endif
|
||||
|
||||
inline std::string TVMType2String(TVMType t) {
|
||||
if (t.bits == 0) return "";
|
||||
#ifndef _LIBCPP_SGX_NO_IOSTREAMS
|
||||
std::ostringstream os;
|
||||
os << t;
|
||||
|
@ -896,6 +903,11 @@ inline std::string TVMType2String(TVMType t) {
|
|||
|
||||
inline TVMType String2TVMType(std::string s) {
|
||||
TVMType t;
|
||||
// handle None type
|
||||
if (s.length() == 0) {
|
||||
t.bits = 0; t.lanes = 0; t.code = kHandle;
|
||||
return t;
|
||||
}
|
||||
t.bits = 32; t.lanes = 1;
|
||||
const char* scan;
|
||||
if (s.substr(0, 3) == "int") {
|
||||
|
|
|
@ -9,6 +9,7 @@ from . import ir_builder
|
|||
# Operators
|
||||
from .op import Op
|
||||
from .op.tensor import *
|
||||
from .op import nn
|
||||
|
||||
# Span
|
||||
Span = base.Span
|
||||
|
|
|
@ -11,17 +11,19 @@ class Environment(NodeBase):
|
|||
options and more.
|
||||
"""
|
||||
|
||||
def __init__(self, funcs):
|
||||
def __init__(self, funcs=None):
|
||||
"""Construct an environment.
|
||||
|
||||
Parameters
|
||||
------
|
||||
funcs: list of relay.Function
|
||||
funcs : optional, dict
|
||||
Map of global var to Function
|
||||
|
||||
Returns
|
||||
------
|
||||
env: A new environment containing :py:class:`~relay.env.Environment`.
|
||||
"""
|
||||
funcs = funcs if funcs else {}
|
||||
self.__init_handle_by_constructor__(_make.Environment, funcs)
|
||||
|
||||
def add(self, var, func):
|
||||
|
|
|
@ -6,10 +6,26 @@ Exposes an interface for configuring the passes and scripting
|
|||
them in Python.
|
||||
"""
|
||||
from . import _ir_pass
|
||||
|
||||
# Expose checking expression, should rename to infer_type.
|
||||
# pylint: disable=invalid-name
|
||||
check_expr = _ir_pass.check_expr
|
||||
|
||||
def infer_type(env, expr):
|
||||
"""Infer the type of expr under the context of env
|
||||
|
||||
Parameters
|
||||
----------
|
||||
env : relay.Environment
|
||||
The global environmemt.
|
||||
|
||||
expr : relay.Expr
|
||||
The input expression.
|
||||
|
||||
Returns
|
||||
-------
|
||||
checked_expr : relay.Expr
|
||||
The checked expression.
|
||||
"""
|
||||
return _ir_pass.infer_type(env, expr)
|
||||
|
||||
|
||||
well_formed = _ir_pass.well_formed
|
||||
|
||||
|
|
|
@ -5,6 +5,8 @@ from .op import get, register, Op
|
|||
|
||||
# Operators
|
||||
from .tensor import *
|
||||
from . import nn
|
||||
|
||||
|
||||
# operator registry
|
||||
from . import _tensor
|
||||
|
|
|
@ -0,0 +1,54 @@
|
|||
"""Neural network operations."""
|
||||
from __future__ import absolute_import as _abs
|
||||
from . import _make
|
||||
|
||||
|
||||
def conv2d(data,
|
||||
weight,
|
||||
strides=(1, 1),
|
||||
padding=(0, 0),
|
||||
dilation=(1, 1),
|
||||
groups=1,
|
||||
channels=None,
|
||||
kernel_size=None,
|
||||
data_layout="NCHW",
|
||||
weight_layout="OIHW",
|
||||
out_layout="",
|
||||
out_dtype=""):
|
||||
"""Two dimensional convolution operator.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : relay.Expr
|
||||
The input data to the operator.
|
||||
|
||||
weight : relay.Expr
|
||||
The weight expressions.
|
||||
|
||||
strides : tuple of int, optional
|
||||
The strides of convoltution.
|
||||
|
||||
padding : tuple of int, optional
|
||||
The padding of convolution on both sides of inputs.
|
||||
|
||||
dilation : tuple of int, optional
|
||||
Specifies the dilation rate to be used for dilated convolution.
|
||||
|
||||
groups : int, optional
|
||||
Number of groups for grouped convolution.
|
||||
|
||||
data_layout : str, optional
|
||||
Layout of the input.
|
||||
|
||||
weight_layout : str, optional
|
||||
Layout of the weight.
|
||||
|
||||
out_layout : str, optional
|
||||
Layout of the output.
|
||||
|
||||
out_dtype : str, optional
|
||||
Specifies the output data type for mixed precision conv2d.
|
||||
"""
|
||||
return _make.conv2d(data, weight, strides, padding, dilation,
|
||||
groups, channels, kernel_size, data_layout,
|
||||
weight_layout, out_layout, out_dtype)
|
|
@ -117,6 +117,9 @@ Operation ComputeOpNode::make(std::string name,
|
|||
Map<std::string, NodeRef> attrs,
|
||||
Array<IterVar> axis,
|
||||
Array<Expr> body) {
|
||||
if (!attrs.defined()) {
|
||||
attrs = Map<std::string, NodeRef>();
|
||||
}
|
||||
auto n = make_node<ComputeOpNode>();
|
||||
n->name = std::move(name);
|
||||
n->tag = std::move(tag);
|
||||
|
|
|
@ -43,6 +43,9 @@ Operation ExternOpNode::make(std::string name,
|
|||
Array<Buffer> input_placeholders,
|
||||
Array<Buffer> output_placeholders,
|
||||
Stmt body) {
|
||||
if (!attrs.defined()) {
|
||||
attrs = Map<std::string, NodeRef>();
|
||||
}
|
||||
auto n = make_node<ExternOpNode>();
|
||||
n->name = std::move(name);
|
||||
n->tag = std::move(tag);
|
||||
|
|
|
@ -51,6 +51,9 @@ Operation ScanOpNode::make(std::string name,
|
|||
Array<Tensor> update,
|
||||
Array<Tensor> state_placeholder,
|
||||
Array<Tensor> inputs) {
|
||||
if (!attrs.defined()) {
|
||||
attrs = Map<std::string, NodeRef>();
|
||||
}
|
||||
auto n = make_node<ScanOpNode>();
|
||||
CHECK_EQ(init.size(), update.size());
|
||||
CHECK_EQ(init.size(), state_placeholder.size());
|
||||
|
|
|
@ -418,6 +418,19 @@ bool Equal(const Stmt& lhs, const Stmt& rhs) {
|
|||
}
|
||||
|
||||
bool Equal(const Expr& lhs, const Expr& rhs) {
|
||||
// quick pass for constant expressions.
|
||||
if (const int64_t *a = as_const_int(lhs)) {
|
||||
if (const int64_t *b = as_const_int(rhs)) {
|
||||
return a[0] == b[0];
|
||||
}
|
||||
}
|
||||
if (!lhs.defined()) {
|
||||
if (rhs.defined()) return false;
|
||||
if (!rhs.defined()) return true;
|
||||
} else {
|
||||
if (!rhs.defined()) return false;
|
||||
}
|
||||
// deep comparison.
|
||||
return IRDeepCompare().Equal(lhs, rhs);
|
||||
}
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@ namespace relay {
|
|||
using tvm::IRPrinter;
|
||||
using namespace tvm::runtime;
|
||||
|
||||
TensorType TensorTypeNode::make(Array<ShapeExpr> shape, DataType dtype) {
|
||||
TensorType TensorTypeNode::make(Array<IndexExpr> shape, DataType dtype) {
|
||||
NodePtr<TensorTypeNode> n = make_node<TensorTypeNode>();
|
||||
n->shape = std::move(shape);
|
||||
n->dtype = std::move(dtype);
|
||||
|
@ -24,7 +24,7 @@ TensorType TensorTypeNode::Scalar(DataType dtype) {
|
|||
|
||||
TVM_REGISTER_API("relay._make.TensorType")
|
||||
.set_body([](TVMArgs args, TVMRetValue *ret) {
|
||||
Array<ShapeExpr> shape = args[0];
|
||||
Array<IndexExpr> shape = args[0];
|
||||
*ret = TensorTypeNode::make(shape, args[1]);
|
||||
});
|
||||
|
||||
|
|
|
@ -0,0 +1,158 @@
|
|||
/*!
|
||||
* Copyright (c) 2018 by Contributors
|
||||
* \file convolution.cc
|
||||
* \brief Convolution operators
|
||||
*/
|
||||
#include <tvm/relay/op.h>
|
||||
#include <tvm/relay/attrs/nn.h>
|
||||
#include <vector>
|
||||
#include "layout.h"
|
||||
|
||||
namespace tvm {
|
||||
namespace relay {
|
||||
|
||||
TVM_REGISTER_NODE_TYPE(ConvAttrs);
|
||||
|
||||
bool Conv2DRel(const Array<Type>& types,
|
||||
int num_inputs,
|
||||
const Attrs& attrs,
|
||||
const TypeReporter& reporter) {
|
||||
CHECK_EQ(types.size(), 3);
|
||||
const auto* data = types[0].as<TensorTypeNode>();
|
||||
const auto* weight = types[1].as<TensorTypeNode>();
|
||||
if (data == nullptr) return false;
|
||||
|
||||
static const Layout kNCHW("NCHW");
|
||||
static const Layout kOIHW("OIHW");
|
||||
|
||||
const ConvAttrs* param = attrs.as<ConvAttrs>();
|
||||
CHECK(param != nullptr);
|
||||
const Layout in_layout(param->data_layout);
|
||||
const Layout kernel_layout(param->weight_layout);
|
||||
CHECK(in_layout.convertible(kNCHW))
|
||||
<< "Conv only support input layouts that are convertible from NCHW."
|
||||
<< " But got " << in_layout;
|
||||
CHECK(kernel_layout.convertible(kOIHW))
|
||||
<< "Conv only support kernel layouts that are convertible from OIHW."
|
||||
<< " But got "<< kernel_layout;
|
||||
|
||||
Layout out_layout(param->out_layout);
|
||||
if (!out_layout.defined()) out_layout = in_layout;
|
||||
CHECK(out_layout.convertible(kNCHW))
|
||||
<< "Conv only support output layouts that are convertible from NCHW."
|
||||
<< " But got " << out_layout;
|
||||
|
||||
IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
|
||||
// infer weight if the kernel_size and channels are defined
|
||||
if (param->kernel_size.defined() && param->channels.defined()) {
|
||||
CHECK_EQ(param->kernel_size.size(), 2);
|
||||
CHECK_EQ(param->dilation.size(), 2);
|
||||
std::vector<IndexExpr> wshape(
|
||||
{param->channels / param->groups,
|
||||
data->shape[1] / param->groups,
|
||||
param->kernel_size[0],
|
||||
param->kernel_size[1]});
|
||||
wshape = ConvertLayout(wshape, kOIHW, kernel_layout);
|
||||
wshape[kernel_layout.indexof('O')] *= param->groups;
|
||||
channels = param->channels;
|
||||
dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
|
||||
dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
|
||||
// assign result to reporter
|
||||
reporter->Assign(types[1], TensorTypeNode::make(wshape, data->dtype));
|
||||
} else {
|
||||
// use weight to infer the conv shape.
|
||||
if (weight == nullptr) return false;
|
||||
auto wshape = ConvertLayout(weight->shape, kernel_layout, kOIHW);
|
||||
if (param->kernel_size.defined()) {
|
||||
CHECK_EQ(param->kernel_size.size(), 2);
|
||||
// check the size
|
||||
CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) &&
|
||||
reporter->AssertEQ(param->kernel_size[1], wshape[3]))
|
||||
<< "Conv2D: shape of weight is inconsistent with kernel_size, "
|
||||
<< " kernel_size=" << param->kernel_size
|
||||
<< " wshape=" << Array<IndexExpr>(wshape);
|
||||
}
|
||||
if (param->channels.defined()) {
|
||||
CHECK(reporter->AssertEQ(param->channels, wshape[0]))
|
||||
<< "Conv2D: shape of weight is inconsistent with channels, "
|
||||
<< " channels=" << param->channels
|
||||
<< " wshape=" << Array<IndexExpr>(wshape);
|
||||
}
|
||||
CHECK(reporter->AssertEQ(data->shape[1] / param->groups, wshape[1]));
|
||||
channels = wshape[0];
|
||||
dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0];
|
||||
dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
|
||||
}
|
||||
// dilation
|
||||
std::vector<IndexExpr> oshape({data->shape[0], channels, 0, 0});
|
||||
|
||||
oshape[2] = (data->shape[2] + param->padding[0] * 2 - dilated_ksize_y) / param->strides[0] + 1;
|
||||
oshape[3] = (data->shape[3] + param->padding[1] * 2 - dilated_ksize_x) / param->strides[1] + 1;
|
||||
DataType out_dtype = param->out_dtype;
|
||||
if (out_dtype.bits() == 0) {
|
||||
out_dtype = data->dtype;
|
||||
}
|
||||
oshape = ConvertLayout(oshape, kNCHW, out_layout);
|
||||
// assign output type
|
||||
reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
// Positional relay function to create conv2d operator
|
||||
// used by frontend FFI.
|
||||
Expr MakeConv2D(Expr data,
|
||||
Expr weight,
|
||||
Array<IndexExpr> strides,
|
||||
Array<IndexExpr> padding,
|
||||
Array<IndexExpr> dilation,
|
||||
int groups,
|
||||
IndexExpr channels,
|
||||
Array<IndexExpr> kernel_size,
|
||||
std::string data_layout,
|
||||
std::string weight_layout,
|
||||
std::string out_layout,
|
||||
DataType out_dtype) {
|
||||
auto attrs = make_node<ConvAttrs>();
|
||||
attrs->strides = std::move(strides);
|
||||
attrs->padding = std::move(padding);
|
||||
attrs->dilation = std::move(dilation);
|
||||
attrs->groups = groups;
|
||||
attrs->channels = channels;
|
||||
attrs->kernel_size = kernel_size;
|
||||
attrs->data_layout = std::move(data_layout);
|
||||
attrs->weight_layout = std::move(weight_layout);
|
||||
attrs->out_layout = std::move(out_layout);
|
||||
attrs->out_dtype = std::move(out_dtype);
|
||||
static const Op& op = Op::Get("conv2d");
|
||||
return CallNode::make(op, {data, weight}, Attrs(attrs), {});
|
||||
}
|
||||
|
||||
|
||||
TVM_REGISTER_API("relay.op._make.conv2d")
|
||||
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
|
||||
runtime::detail::unpack_call<Expr, 12>(MakeConv2D, args, rv);
|
||||
});
|
||||
|
||||
|
||||
RELAY_REGISTER_OP("conv2d")
|
||||
.describe(R"code(2D convolution layer (e.g. spatial convolution over images).
|
||||
|
||||
This layer creates a convolution kernel that is convolved
|
||||
with the layer input to produce a tensor of outputs.
|
||||
|
||||
- **data**: This depends on the `layout` parameter. Input is 4D array of shape
|
||||
(batch_size, in_channels, height, width) if `layout` is `NCHW`.
|
||||
- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1])
|
||||
- **out**: This depends on the `layout` parameter. Output is 4D array of shape
|
||||
(batch_size, channels, out_height, out_width) if `layout` is `NCHW`.
|
||||
|
||||
)code" TVM_ADD_FILELINE)
|
||||
.set_num_inputs(2)
|
||||
.add_argument("data", "Tensor", "The input tensor.")
|
||||
.add_argument("weight", "Tensor", "The weight tensor.")
|
||||
.set_support_level(2)
|
||||
.add_type_rel("Conv2D", Conv2DRel);
|
||||
|
||||
} // namespace relay
|
||||
} // namespace tvm
|
|
@ -0,0 +1,538 @@
|
|||
/*!
|
||||
* Copyright (c) 2018 by Contributors
|
||||
* \file relay/op/nn/layout.h
|
||||
* \brief Layout expression.
|
||||
*
|
||||
* This file is adapted from its nnvm counterpart and will keep involving
|
||||
* to the new layout system
|
||||
*
|
||||
* The layout is composed of upper cases, lower cases and numbers,
|
||||
* where upper case indicates a (super-)dimension and
|
||||
* the corresponding lower case with factor size indicates the split (sub-)dimension.
|
||||
* For example, NCHW16c can describe a 5-D tensor of
|
||||
* [batch_size, channel, height, width, channel_block].
|
||||
* Here sub-dimension channel_block=16 is the split of super-dimension C (channel).
|
||||
*/
|
||||
#ifndef TVM_RELAY_OP_NN_LAYOUT_H_
|
||||
#define TVM_RELAY_OP_NN_LAYOUT_H_
|
||||
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include <algorithm>
|
||||
|
||||
namespace tvm {
|
||||
namespace relay {
|
||||
|
||||
/*! \brief layout auxiliary structure */
|
||||
class Layout {
|
||||
public:
|
||||
using LayoutDim = char;
|
||||
|
||||
/*! \brief default constructor */
|
||||
Layout() : name_("__undef__") {} // NOLINT(*)
|
||||
|
||||
/*!
|
||||
* \brief construct from a string.
|
||||
* \param layout input in layout convention:
|
||||
* upper case indicates a dimension and
|
||||
* the corresponding lower case with factor size
|
||||
* indicates the split dimension.
|
||||
* return undefined layout if "__undef__" is passed.
|
||||
*/
|
||||
Layout(const std::string& layout) { // NOLINT(*)
|
||||
if (layout.length() != 0) {
|
||||
parse(layout);
|
||||
} else {
|
||||
parse("__undef__");
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief copy constructor from another layout
|
||||
* \param s the source layout
|
||||
*/
|
||||
Layout(const Layout& s) { // NOLINT(*)
|
||||
this->parse(s.name_);
|
||||
}
|
||||
/*!
|
||||
* \brief move constructor from Layout
|
||||
* \param src the source layout
|
||||
*/
|
||||
Layout(Layout&& src) { // NOLINT(*)
|
||||
this->swap(src);
|
||||
}
|
||||
/*!
|
||||
* \brief assignment from another layout.
|
||||
* \param src source layout
|
||||
* \return reference of self
|
||||
*/
|
||||
Layout& operator=(const Layout& src) {
|
||||
this->parse(src.name_);
|
||||
return *this;
|
||||
}
|
||||
/*!
|
||||
* \brief assignment from rvalue of another layout.
|
||||
* \param src source layout
|
||||
* \return reference of self
|
||||
*/
|
||||
Layout& operator=(Layout&& src) {
|
||||
Layout(std::move(src)).swap(*this); // NOLINT(*)
|
||||
return *this;
|
||||
}
|
||||
/*!
|
||||
* \brief assignment from string.
|
||||
* \param src source layout
|
||||
* \return reference of self
|
||||
*/
|
||||
Layout& operator=(const std::string& src) {
|
||||
this->parse(src);
|
||||
return *this;
|
||||
}
|
||||
/*!
|
||||
* \return whether two layout equals
|
||||
* \param s the layout to compare against
|
||||
*/
|
||||
bool operator==(const Layout& s) const {
|
||||
return name_ == s.name_;
|
||||
}
|
||||
/*!
|
||||
* \return whether two layout not equal
|
||||
* \param s the layout to compare against
|
||||
*/
|
||||
bool operator!=(const Layout& s) const {
|
||||
return !(*this == s);
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Append the current layout by another.
|
||||
* @param other the layout to be appended
|
||||
* @return a new layout
|
||||
*/
|
||||
Layout operator+(const Layout& other) const {
|
||||
if (!this->defined() && !other.defined()) {
|
||||
return Layout::Undef();
|
||||
} else if (!this->defined()) {
|
||||
return other;
|
||||
} else if (!other.defined()) {
|
||||
return *this;
|
||||
}
|
||||
return Layout(this->name_ + other.name_);
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Check whether a given dimension is a super-dimension.
|
||||
* \param dim input dimension
|
||||
* \return Whether a given dimension is a super-dimension.
|
||||
*/
|
||||
static bool is_superdim(LayoutDim dim) {
|
||||
return dim >= 'A' && dim <= 'Z';
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Check whether a given dimension is a sub-dimension.
|
||||
* \param dim input dimension
|
||||
* \return Whether a given dimension is a sub-dimension.
|
||||
*/
|
||||
static bool is_subdim(LayoutDim dim) {
|
||||
return dim >= 'a' && dim <= 'z';
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Convert a given dimension to super-dimension.
|
||||
* \param dim input dimension
|
||||
* \return The converted description.
|
||||
*/
|
||||
static LayoutDim to_superdim(LayoutDim dim) {
|
||||
if (is_subdim(dim)) {
|
||||
return dim - 'a' + 'A';
|
||||
}
|
||||
return dim;
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Convert a given dimension to sub-dimension.
|
||||
* \param dim input dimension
|
||||
* \return The converted description.
|
||||
*/
|
||||
static LayoutDim to_subdim(LayoutDim dim) {
|
||||
if (is_superdim(dim)) {
|
||||
return dim - 'A' + 'a';
|
||||
}
|
||||
return dim;
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Return an undefined layout.
|
||||
* \return a (global) undefined layout.
|
||||
*/
|
||||
static const Layout& Undef() {
|
||||
static Layout undef;
|
||||
return undef;
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Swap current object with other
|
||||
* \param other another object to be swapped.
|
||||
*/
|
||||
void swap(Layout& other) { // NOLINT(*)
|
||||
std::swap(name_, other.name_);
|
||||
std::swap(superdim_pos_, other.superdim_pos_);
|
||||
std::swap(subdim_pos_, other.subdim_pos_);
|
||||
std::swap(subdim_size_, other.subdim_size_);
|
||||
std::swap(layout_simplified_, other.layout_simplified_);
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Two layouts are convertible only if
|
||||
* they have same set of super-dimensions.
|
||||
* e.g., NCHW, NCHW16c, NHWC are convertible between each other,
|
||||
* but NCHW, CHW, OIHW are not.
|
||||
* \param dst the target layout
|
||||
* \return Whether can be converted to dst layout.
|
||||
*/
|
||||
bool convertible(const Layout &dst) const {
|
||||
if (!this->defined() || !dst.defined()) return false;
|
||||
for (size_t i = 0; i < kUniqueDim; ++i) {
|
||||
if ((superdim_pos_[i] >= 0 && dst.superdim_pos_[i] < 0) ||
|
||||
(superdim_pos_[i] < 0 && dst.superdim_pos_[i] >= 0)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Returns a sublayout which is the portion of the object
|
||||
* that starts at dimension \p pos and spans \p len dimensions
|
||||
* (or until the end of the layout, whichever comes first).
|
||||
* \param pos The start position.
|
||||
* \param len The length of the sub-layout.
|
||||
* \return A newly constructed Layout object.
|
||||
*/
|
||||
Layout sublayout(size_t pos, size_t len) const {
|
||||
if (pos > ndim()) return Layout::Undef();
|
||||
if (pos + len > ndim()) len = ndim() - pos;
|
||||
if (len == 0) return Layout::Undef();
|
||||
std::ostringstream new_layout;
|
||||
for (size_t i = pos; i < pos + len; ++i) {
|
||||
if (is_subdim(layout_simplified_[i])) {
|
||||
auto block_size = this->subsizeof(layout_simplified_[i]);
|
||||
CHECK_GT(block_size, 0);
|
||||
new_layout << block_size;
|
||||
}
|
||||
new_layout << layout_simplified_[i];
|
||||
}
|
||||
return Layout(new_layout.str());
|
||||
}
|
||||
|
||||
/*! \return A newly constructed reversed Layout object. */
|
||||
Layout reverse() const {
|
||||
if (!this->defined()) return Layout::Undef();
|
||||
std::ostringstream new_layout;
|
||||
for (int64_t i = this->ndim() - 1; i >= 0; --i) {
|
||||
if (is_subdim(layout_simplified_[i])) {
|
||||
auto block_size = this->subsizeof(layout_simplified_[i]);
|
||||
CHECK_GT(block_size, 0);
|
||||
new_layout << block_size;
|
||||
}
|
||||
new_layout << layout_simplified_[i];
|
||||
}
|
||||
return Layout(new_layout.str());
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Split \p dim by \p size and put the sub-dimension to position \p target_pos.
|
||||
* \param dim The source dimension to be split. It must be a super-dimension.
|
||||
* \param target_pos The target position of the newly split sub-dimension.
|
||||
* \param size size of the sub-dimension.
|
||||
* \return A newly constructed Layout object.
|
||||
*/
|
||||
Layout split(LayoutDim dim, size_t target_pos, uint32_t size) const {
|
||||
CHECK(target_pos <= this->ndim()) << "Invalid split position "
|
||||
<< target_pos << " for layout " << name_;
|
||||
CHECK(is_superdim(dim)) << "Cannot split a sub-dimension " << dim;
|
||||
CHECK(this->contains(dim)) << "Axis " << dim << " does not exist in " << name_;
|
||||
CHECK(!this->contains(to_subdim(dim))) << "Dimension " << dim
|
||||
<< " has already been split in "
|
||||
<< name_;
|
||||
CHECK(size > 0) << "Invalid split size " << size;
|
||||
std::ostringstream new_layout;
|
||||
for (size_t i = 0; i <= this->ndim(); ++i) {
|
||||
if (i == target_pos) {
|
||||
new_layout << size << Layout::to_subdim(dim);
|
||||
}
|
||||
if (i == this->ndim()) break;
|
||||
new_layout << this->at(i);
|
||||
}
|
||||
Layout x(new_layout.str());
|
||||
return x;
|
||||
}
|
||||
|
||||
using iterator = std::vector<LayoutDim>::const_iterator;
|
||||
using reverse_iterator = std::vector<LayoutDim>::const_reverse_iterator;
|
||||
|
||||
/*! \return begin iterator */
|
||||
iterator begin() const {
|
||||
return layout_simplified_.begin();
|
||||
}
|
||||
/*! \return end iterator */
|
||||
iterator end() const {
|
||||
return layout_simplified_.end();
|
||||
}
|
||||
/*! \return rbegin iterator */
|
||||
reverse_iterator rbegin() const {
|
||||
return layout_simplified_.rbegin();
|
||||
}
|
||||
/*! \return rend iterator */
|
||||
reverse_iterator rend() const {
|
||||
return layout_simplified_.rend();
|
||||
}
|
||||
|
||||
/*! \return number of dimensions */
|
||||
size_t ndim() const {
|
||||
return layout_simplified_.size();
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief The description of the \p i-th dimension.
|
||||
* If it is a sub-dimension, the size will be returned as well,
|
||||
* e.g., 16c. Otherwise a single character is returned, e.g., C.
|
||||
* \param i The position
|
||||
* \return the description of the dimension.
|
||||
*/
|
||||
std::string at(size_t i) const {
|
||||
CHECK_LT(i, this->ndim()) << "position " << i
|
||||
<< " exceeds ndim=" << this->ndim();
|
||||
std::ostringstream repr;
|
||||
if (is_subdim(layout_simplified_[i])) {
|
||||
auto factor = subsizeof(layout_simplified_[i]);
|
||||
CHECK_GT(factor, 0);
|
||||
repr << factor;
|
||||
}
|
||||
repr << layout_simplified_[i];
|
||||
return repr.str();
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief return the index of the input dimension.
|
||||
* If it is not found in the layout or the layout is undefined,
|
||||
* return -1.
|
||||
* \param dim the input dimension.
|
||||
* \return the index or -1 if not found.
|
||||
*/
|
||||
int32_t indexof(LayoutDim dim) const {
|
||||
if (!this->defined()) return -1;
|
||||
else if (is_superdim(dim)) return superdim_pos_[dim - 'A'];
|
||||
else if (is_subdim(dim)) return subdim_pos_[dim - 'a'];
|
||||
return -1;
|
||||
}
|
||||
|
||||
/*!
|
||||
* \param dim the input super-dimension or sub-dimension.
|
||||
* \return the size of the sub-dimension of \p dim (if \p dim is a super-dimension),
|
||||
* or the size of \p dim itself (if \p dim is a sub-dimension).
|
||||
* Return -1 if \p dim is not in the layout or the layout is undefined.
|
||||
*/
|
||||
int64_t subsizeof(LayoutDim dim) const {
|
||||
CHECK(is_superdim(dim) || is_subdim(dim)) << "Invalid dim " << dim;
|
||||
if (!this->defined() || !this->contains(to_subdim(dim))) {
|
||||
return -1;
|
||||
}
|
||||
int idx = to_subdim(dim) - 'a';
|
||||
return subdim_size_[idx];
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Whether the layout contains a dimension.
|
||||
* \param dim dimension to be checked.
|
||||
* \return Whether the layout contains the dimension.
|
||||
*/
|
||||
bool contains(LayoutDim dim) const {
|
||||
if (is_superdim(dim)) {
|
||||
return superdim_pos_[dim-'A'] >= 0;
|
||||
} else if (is_subdim(dim)) {
|
||||
return subdim_pos_[dim-'a'] >= 0;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
LayoutDim operator[](size_t i) const {
|
||||
return layout_simplified_[i];
|
||||
}
|
||||
|
||||
/*! \return whether the layout is defined */
|
||||
bool defined() const {
|
||||
return name_ != "__undef__";
|
||||
}
|
||||
|
||||
/*! \return the string description of the layout */
|
||||
const std::string& name() const {
|
||||
return name_;
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Write layout in JSON format.
|
||||
* \param writer JSONWriter
|
||||
*/
|
||||
void Save(dmlc::JSONWriter* writer) const {
|
||||
writer->Write(name_);
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Load layout from JSON.
|
||||
* \param reader JSONReader
|
||||
*/
|
||||
void Load(dmlc::JSONReader* reader) {
|
||||
std::string tmp;
|
||||
reader->Read(&tmp);
|
||||
this->parse(tmp);
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief allow output string of layout to ostream
|
||||
* \param os the output stream
|
||||
* \param l the layout
|
||||
* \return the ostream
|
||||
*/
|
||||
friend std::ostream& operator<<(std::ostream& os, const Layout& l) {
|
||||
os << l.name_;
|
||||
return os;
|
||||
}
|
||||
|
||||
private:
|
||||
static const uint32_t kUniqueDim = 26;
|
||||
|
||||
std::string name_;
|
||||
int32_t superdim_pos_[kUniqueDim];
|
||||
int32_t subdim_pos_[kUniqueDim];
|
||||
int64_t subdim_size_[kUniqueDim];
|
||||
std::vector<LayoutDim> layout_simplified_;
|
||||
|
||||
void parse(const std::string& layout) {
|
||||
name_ = layout;
|
||||
std::fill_n(superdim_pos_, kUniqueDim, -1);
|
||||
std::fill_n(subdim_pos_, kUniqueDim, -1);
|
||||
std::fill_n(subdim_size_, kUniqueDim, -1);
|
||||
layout_simplified_.clear();
|
||||
|
||||
if (layout == "__undef__") return;
|
||||
|
||||
int32_t factor = 0;
|
||||
uint32_t curr = 0;
|
||||
for (size_t i = 0; i < layout.size(); ++i) {
|
||||
const LayoutDim c = layout.at(i);
|
||||
if (is_superdim(c)) {
|
||||
int pos = c - 'A';
|
||||
CHECK_EQ(factor, 0) << "Invalid layout " << layout
|
||||
<< ": invalid factor size " << factor
|
||||
<< " before dimension " << c;
|
||||
CHECK_EQ(superdim_pos_[pos], -1) << "Invalid layout " << layout
|
||||
<< ": duplicate dimension " << c;
|
||||
superdim_pos_[pos] = curr++;
|
||||
layout_simplified_.push_back(c);
|
||||
} else if (is_subdim(c)) {
|
||||
int pos = c - 'a';
|
||||
CHECK_GT(factor, 0) << "Invalid layout " << layout << ": invalid factor size "
|
||||
<< factor << " for dimension " << c;
|
||||
CHECK_EQ(subdim_pos_[pos], -1) << "Invalid layout " << layout
|
||||
<< ": duplicate dimension " << c;
|
||||
CHECK_EQ(subdim_size_[pos], -1) << "Invalid layout " << layout
|
||||
<< ": duplicate dimension " << c;
|
||||
subdim_pos_[pos] = curr++;
|
||||
subdim_size_[pos] = factor;
|
||||
layout_simplified_.push_back(c);
|
||||
factor = 0;
|
||||
} else if (c >= '0' && c <= '9') {
|
||||
CHECK(factor >= 0) << "Invalid layout " << layout << ": _ is adjacent to a number.";
|
||||
factor = factor * 10 + c - '0';
|
||||
} else {
|
||||
LOG(FATAL) << "Invalid layout " << layout;
|
||||
}
|
||||
}
|
||||
CHECK(!layout_simplified_.empty()) << "Invalid layout " << layout;
|
||||
for (LayoutDim dim : layout_simplified_) {
|
||||
CHECK(is_superdim(dim) || superdim_pos_[dim-'a'] >= 0)
|
||||
<< "Invalid layout " << layout << ": missing axis "
|
||||
<< static_cast<char>(dim - 'a' + 'A');
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief Convert shape in src_layout to shape in dst_layout
|
||||
* \param src original shape
|
||||
* \param src_layout layout of original shape
|
||||
* \param dst_layout target layout
|
||||
* \return shape in target layout
|
||||
*/
|
||||
inline std::vector<IndexExpr> ConvertLayout(
|
||||
std::vector<IndexExpr> src,
|
||||
const Layout& src_layout,
|
||||
const Layout& dst_layout) {
|
||||
CHECK_EQ(src_layout.ndim(), src.size());
|
||||
if (src_layout == dst_layout) {
|
||||
return src;
|
||||
} else if (!src_layout.defined()) {
|
||||
LOG(FATAL) << "cannot convert undefined layout to " << dst_layout;
|
||||
} else if (!dst_layout.defined()) {
|
||||
LOG(FATAL) << "cannot convert " << src_layout << " to undefined layout";
|
||||
}
|
||||
|
||||
CHECK(src_layout.convertible(dst_layout))
|
||||
<< "cannot convert from "
|
||||
<< src_layout << " to " << dst_layout;
|
||||
|
||||
std::vector<IndexExpr> dst(dst_layout.ndim());
|
||||
for (size_t i = 0; i < src_layout.ndim(); ++i) {
|
||||
Layout::LayoutDim src_dim = src_layout[i];
|
||||
if (Layout::is_superdim(src_dim)) {
|
||||
int dst_major_pos = dst_layout.indexof(Layout::to_superdim(src_dim));
|
||||
int dst_minor_pos = dst_layout.indexof(Layout::to_subdim(src_dim));
|
||||
int src_minor_pos = src_layout.indexof(Layout::to_subdim(src_dim));
|
||||
int src_factor = src_layout.subsizeof(src_dim);
|
||||
int dst_factor = dst_layout.subsizeof(src_dim);
|
||||
IndexExpr src_dim_size = src[i];
|
||||
|
||||
if (src_minor_pos >= 0) {
|
||||
const int64_t* minor_size = as_const_int(src[src_minor_pos]);
|
||||
CHECK(minor_size == nullptr &&
|
||||
src_factor == minor_size[0])
|
||||
<< "src shape " << Array<IndexExpr>(src)
|
||||
<< " does not agree with layout "
|
||||
<< src_layout;
|
||||
src_dim_size *= src_factor;
|
||||
}
|
||||
dst[dst_major_pos] = src_dim_size;
|
||||
if (dst_minor_pos >= 0) {
|
||||
CHECK_GT(dst_factor, 0);
|
||||
if (const int64_t* const_src_dim_size = as_const_int(src_dim_size)) {
|
||||
CHECK_LE(dst_factor, const_src_dim_size[0])
|
||||
<< "Converting " << Array<IndexExpr>(src)
|
||||
<< " from " << src_layout
|
||||
<< " to " << dst_layout
|
||||
<< ": cannot split dimension size of "
|
||||
<< src_dim_size << " by " << dst_factor;
|
||||
}
|
||||
dst[dst_major_pos] /= dst_factor;
|
||||
dst[dst_minor_pos] = dst_factor;
|
||||
}
|
||||
}
|
||||
}
|
||||
return dst;
|
||||
}
|
||||
|
||||
inline std::vector<IndexExpr> ConvertLayout(
|
||||
const Array<IndexExpr>& src,
|
||||
const Layout& src_layout,
|
||||
const Layout& dst_layout) {
|
||||
std::vector<IndexExpr> ret(src.size());
|
||||
for (size_t i = 0; i < src.size(); ++i) {
|
||||
ret[i] = src[i];
|
||||
}
|
||||
return ConvertLayout(ret, src_layout, dst_layout);
|
||||
}
|
||||
|
||||
} // namespace relay
|
||||
} // namespace tvm
|
||||
#endif // TVM_RELAY_OP_NN_LAYOUT_H_
|
|
@ -69,8 +69,8 @@ Type ConcreteBroadcast(const TensorType& t1,
|
|||
rev_sh2++;
|
||||
}
|
||||
|
||||
Array<ShapeExpr> larger;
|
||||
Array<ShapeExpr> smaller;
|
||||
Array<IndexExpr> larger;
|
||||
Array<IndexExpr> smaller;
|
||||
|
||||
for (int i = 0; i < (full_len - suffix_len); i++) {
|
||||
smaller.push_back(make_const(tvm::Int(64), 1));
|
||||
|
@ -93,7 +93,7 @@ Type ConcreteBroadcast(const TensorType& t1,
|
|||
|
||||
CHECK_EQ(larger.size(), smaller.size());
|
||||
|
||||
Array<ShapeExpr> out_shape;
|
||||
Array<IndexExpr> out_shape;
|
||||
for (size_t i = 0; i < smaller.size(); i++) {
|
||||
auto left = smaller[i].as<tvm::ir::IntImm>();
|
||||
auto right = larger[i].as<tvm::ir::IntImm>();
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
/*!
|
||||
* Copyright (c) 2018 by Contributors
|
||||
* \file src/tvm/relay/pass/alpha_eq.cc
|
||||
* \brief Compute the set of variables not bound in the expression.
|
||||
* \brief The structral equivalence comparison.
|
||||
*/
|
||||
#include <tvm/ir_pass.h>
|
||||
#include <tvm/relay/expr_functor.h>
|
||||
#include "./type_visitor.h"
|
||||
#include "tvm/relay/pass.h"
|
||||
|
@ -19,9 +20,23 @@ struct TypeAlphaEq : TypeVisitor<const Type&> {
|
|||
TypeAlphaEq() : eq_map(), equal(true) {}
|
||||
|
||||
void DataTypeEqual(const DataType& dt1, const DataType& dt2) {
|
||||
equal = equal && dt1 == dt2;
|
||||
if (dt1 != dt2) {
|
||||
equal = false;
|
||||
}
|
||||
}
|
||||
|
||||
void ShapeEqual(const Array<IndexExpr>& s1, const Array<IndexExpr>& s2) {
|
||||
if (s1.size() != s2.size()) {
|
||||
equal = false;
|
||||
return;
|
||||
}
|
||||
for (size_t i = 0; i < s1.size(); ++i) {
|
||||
if (!tvm::ir::Equal(s1[i], s2[i])) {
|
||||
equal = false;
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
void ShapeEqual(Array<ShapeExpr> s1, Array<ShapeExpr> s2) {}
|
||||
|
||||
void VisitType_(const TensorTypeNode *tt1, const Type& t2) final {
|
||||
if (const TensorTypeNode *tt2 = t2.as<TensorTypeNode>()) {
|
||||
|
|
|
@ -354,8 +354,8 @@ Expr TypeInferencer::Infer(Expr expr) {
|
|||
return Resolver(type_map_, &solver_).VisitExpr(expr);
|
||||
}
|
||||
|
||||
Expr InferType(const Environment& env, const Expr& e) {
|
||||
return TypeInferencer(env).Infer(e);
|
||||
Expr InferType(const Environment& env, const Expr& expr) {
|
||||
return TypeInferencer(env).Infer(expr);
|
||||
}
|
||||
|
||||
Expr InferType(const Environment& env,
|
||||
|
@ -370,11 +370,9 @@ Expr InferType(const Environment& env,
|
|||
return func_ret;
|
||||
}
|
||||
|
||||
TVM_REGISTER_API("relay._ir_pass.check_expr")
|
||||
TVM_REGISTER_API("relay._ir_pass.infer_type")
|
||||
.set_body([](TVMArgs args, TVMRetValue* ret) {
|
||||
Environment env = args[0];
|
||||
Expr e = args[1];
|
||||
*ret = InferType(env, e);
|
||||
*ret = InferType(args[0], args[1]);
|
||||
});
|
||||
|
||||
} // namespace relay
|
||||
|
|
|
@ -18,8 +18,13 @@ class TypeSolver::Reporter : public TypeReporterNode {
|
|||
solver_->Unify(dst, src);
|
||||
}
|
||||
|
||||
void AssertEQ(const ShapeExpr& lhs, const ShapeExpr& rhs) final {
|
||||
// TODO(tqchen)
|
||||
bool AssertEQ(const IndexExpr& lhs, const IndexExpr& rhs) final {
|
||||
// early warning constant case.
|
||||
IndexExpr diff = lhs - rhs;
|
||||
if (const int64_t* pdiff = as_const_int(diff)) {
|
||||
return pdiff[0] == 0;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
|
@ -8,7 +8,6 @@ ib = IRBuilder()
|
|||
def show(e):
|
||||
r = debug_print(ib.env, e)
|
||||
assert r is not None
|
||||
# print(r) # uncomment this line to debug
|
||||
|
||||
|
||||
def test_constant():
|
||||
|
|
|
@ -0,0 +1,62 @@
|
|||
import tvm
|
||||
from tvm import relay
|
||||
|
||||
|
||||
def test_conv2d_infer_type():
|
||||
# symbolic in batch dimension
|
||||
ib = relay.ir_builder.IRBuilder()
|
||||
n, c, h, w = tvm.var("n"), 10, 224, 224
|
||||
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
|
||||
w = ib.param("w", relay.ty.IncompleteType())
|
||||
|
||||
with ib.function(x, w) as func:
|
||||
ib.ret(relay.nn.conv2d(x.var, w.var,
|
||||
kernel_size=(3, 3),
|
||||
padding=(1, 1),
|
||||
channels=2))
|
||||
ib.ret(func)
|
||||
func = relay.ir_pass.infer_type(ib.env, func.to_func())
|
||||
ftype = func.checked_type()
|
||||
assert ftype.ret_type == relay.ty.TensorType(
|
||||
(n, 2, 224, 224), "float32")
|
||||
assert ftype.arg_types[1] == relay.ty.TensorType(
|
||||
(2, 10, 3, 3), "float32")
|
||||
|
||||
# infer by shape of w, mixed precision
|
||||
ib = relay.ir_builder.IRBuilder()
|
||||
n, c, h, w = tvm.var("n"), 10, 224, 224
|
||||
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "int8"))
|
||||
w = ib.param("w", relay.ty.TensorType((2, 10, 3, 3), "int8"))
|
||||
with ib.function(x, w) as func:
|
||||
ib.ret(relay.nn.conv2d(x.var, w.var, out_dtype="int32"))
|
||||
ib.ret(func)
|
||||
func = relay.ir_pass.infer_type(ib.env, func.to_func())
|
||||
ftype = func.checked_type()
|
||||
assert ftype.ret_type == relay.ty.TensorType(
|
||||
(n, 2, 222, 222), "int32")
|
||||
|
||||
# Infer with a different layout
|
||||
ib = relay.ir_builder.IRBuilder()
|
||||
n, c, h, w = 4, 32, 224, 224
|
||||
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "int8"))
|
||||
w = ib.param("w", relay.ty.IncompleteType())
|
||||
with ib.function(x, w) as func:
|
||||
ib.ret(relay.nn.conv2d(x.var, w.var,
|
||||
kernel_size=(3, 3),
|
||||
padding=(1, 1),
|
||||
channels=16,
|
||||
data_layout="NCHW4n4c",
|
||||
weight_layout="OIHW4o4i",
|
||||
out_dtype="int32"))
|
||||
ib.ret(func)
|
||||
func = relay.ir_pass.infer_type(ib.env, func.to_func())
|
||||
ftype = func.checked_type()
|
||||
assert ftype.ret_type == relay.ty.TensorType(
|
||||
(1, 4, 224, 224, 4, 4), "int32")
|
||||
assert ftype.arg_types[1] == relay.ty.TensorType(
|
||||
(4, 8, 3, 3, 4, 4), "int8")
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_conv2d_infer_type()
|
|
@ -0,0 +1,17 @@
|
|||
import tvm
|
||||
from tvm import relay
|
||||
|
||||
def test_type_alpha_eq():
|
||||
t1 = relay.ty.TensorType((3, 4), "float32")
|
||||
t2 = relay.ty.TensorType((3, 4), "float32")
|
||||
t3 = relay.ty.TensorType((3, 4, 5), "float32")
|
||||
assert t1 == t2
|
||||
assert t1 != t3
|
||||
|
||||
t1 = relay.ty.TensorType((), "float32")
|
||||
t2 = relay.ty.TensorType((), "float32")
|
||||
assert t1 == t2
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_type_alpha_eq()
|
|
@ -3,7 +3,7 @@
|
|||
"""
|
||||
import tvm
|
||||
import numpy as np
|
||||
from tvm.relay.ir_pass import check_expr
|
||||
from tvm.relay.ir_pass import infer_type
|
||||
from tvm.relay.ir_builder import IRBuilder, func_type
|
||||
from tvm.relay.ir_builder import scalar_type, convert, tensor_type
|
||||
from tvm.relay.env import Environment
|
||||
|
@ -11,8 +11,11 @@ from tvm.relay.op import log, add, equal, subtract, concat
|
|||
from tvm.relay.expr import Function
|
||||
|
||||
def assert_has_type(expr, typ, env=Environment({})):
|
||||
checked_expr = check_expr(env, expr)
|
||||
assert checked_expr.checked_type() == typ
|
||||
checked_expr = infer_type(env, expr)
|
||||
checked_type = checked_expr.checked_type()
|
||||
if checked_type != typ:
|
||||
raise RuntimeError("Type mismatch %s vs %s" % (
|
||||
checked_type, typ))
|
||||
|
||||
|
||||
def assert_decl_has_type(env, name, typ):
|
||||
|
@ -47,6 +50,7 @@ def test_add_op():
|
|||
}
|
||||
"""
|
||||
b = IRBuilder()
|
||||
|
||||
x = b.param('x', tensor_type(5, 5, 5))
|
||||
y = b.param('y', tensor_type(5, 5, 5))
|
||||
with b.function(x, y) as func:
|
||||
|
@ -71,8 +75,9 @@ def test_add_broadcast_op():
|
|||
b.ret(add(x.var, y.var))
|
||||
b.ret(func)
|
||||
prog, env = b.get()
|
||||
ttype = tensor_type(5, 5, 5)
|
||||
expected_ty = func_type([ttype, ttype], ttype)
|
||||
|
||||
expected_ty = func_type([tensor_type(10, 4), tensor_type(5, 10, 1)],
|
||||
tensor_type(5, 10, 4))
|
||||
assert_has_type(func.to_func(), expected_ty)
|
||||
|
||||
def test_dual_op():
|
||||
|
@ -89,7 +94,9 @@ def test_dual_op():
|
|||
t1 = b.let('t1', log(x))
|
||||
t2 = b.let('t2', add(t1, x))
|
||||
b.ret(t2)
|
||||
assert_has_type(func.to_func(), func_type(['float32'], 'float32'))
|
||||
|
||||
assert_has_type(func.to_func(),
|
||||
func_type([tensor_type(10, 10)], tensor_type(10, 10)))
|
||||
|
||||
|
||||
def test_decl():
|
||||
|
@ -152,12 +159,12 @@ def test_concat():
|
|||
assert_decl_has_type(ib.env, try_concat2, fn_ty)
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_recursion()
|
||||
test_dual_op()
|
||||
|
||||
test_recursion()
|
||||
test_monomorphic_let()
|
||||
test_single_op()
|
||||
test_add_op()
|
||||
test_add_broadcast_op()
|
||||
test_dual_op()
|
||||
test_decl()
|
||||
test_concat()
|
||||
|
|
|
@ -59,6 +59,7 @@ inline int64_t GetConstInt(Expr expr) {
|
|||
*/
|
||||
inline std::vector<int> GetConstIntValues(Array<Expr> exprs, const std::string& var_name) {
|
||||
std::vector<int> result;
|
||||
if (!exprs.defined()) return result;
|
||||
for (auto expr : exprs) {
|
||||
CHECK(IsConstInt(expr)) << "All elements of " << var_name << " must be constant integers";
|
||||
result.push_back(GetConstInt(expr));
|
||||
|
@ -77,6 +78,7 @@ inline std::vector<int> GetConstIntValues(Array<Expr> exprs, const std::string&
|
|||
*/
|
||||
inline std::vector<int64_t> GetConstInt64Values(Array<Expr> exprs, const std::string& var_name) {
|
||||
std::vector<int64_t> result;
|
||||
if (!exprs.defined()) return result;
|
||||
for (auto expr : exprs) {
|
||||
CHECK(IsConstInt(expr)) << "All elements of " << var_name << " must be constant integers";
|
||||
result.push_back(GetConstInt(expr));
|
||||
|
|
Загрузка…
Ссылка в новой задаче