[RELAY][OP] conv2d, ShapeExpr->IndexExpr (#1798)

This commit is contained in:
Tianqi Chen 2018-10-03 21:58:25 -07:00 committed by GitHub
Parent 147f3ad526
Commit 9afde69b84
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
28 changed files with 1039 additions and 37 deletions

@@ -56,6 +56,22 @@ namespace tvm {
__fvisit__(#FieldName, &FieldName)
/*!
* \brief Create a NodeRef type that represents null.
* \tparam TNodeRef the type to be created.
* \return An instance that will represent None.
*/
template<typename TNodeRef>
inline TNodeRef NullValue() {
return TNodeRef(NodePtr<Node>(nullptr));
}
template<>
inline Type NullValue<Type>() {
return Type(Type::Handle, 0, 0);
}
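// Editor's note: a minimal sketch (not part of this commit) of how
// NullValue is consumed. An optional attribute defaults to a null
// reference, and callers test defined() to tell "unset" apart from a
// concrete value, as the new ConvAttrs/Conv2DRel code does for channels
// (IndexExpr is relay's alias for tvm::Expr introduced in this commit):
//
//   IndexExpr channels = NullValue<IndexExpr>();
//   CHECK(!channels.defined());          // "unset" is observable
//   channels = make_const(Int(32), 8);
//   CHECK(channels.defined());           // now carries a concrete value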
/*! \brief Error thrown during attribute checking. */
struct AttrError : public dmlc::Error {
/*!

@@ -114,7 +114,7 @@ inline TNodeRef TVMArgValue::AsNodeRef() const {
static_assert(
std::is_base_of<NodeRef, TNodeRef>::value,
"Conversion only works for NodeRef");
if (type_code_ == kNull) return TNodeRef();
if (type_code_ == kNull) return TNodeRef(NodePtr<Node>(nullptr));
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
NodePtr<Node>& sptr = *ptr<NodePtr<Node> >();
CHECK(NodeTypeChecker<TNodeRef>::Check(sptr.get()))

@@ -0,0 +1,72 @@
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/attrs/nn.h
* \brief Auxiliary attributes for nn operators.
*/
#ifndef TVM_RELAY_ATTRS_NN_H_
#define TVM_RELAY_ATTRS_NN_H_
#include <tvm/attrs.h>
#include <string>
namespace tvm {
namespace relay {
/*! \brief Attributes used in convolution operators */
struct ConvAttrs : public tvm::AttrsNode<ConvAttrs> {
Array<IndexExpr> strides;
Array<IndexExpr> padding;
Array<IndexExpr> dilation;
int groups;
IndexExpr channels;
Array<IndexExpr> kernel_size;
std::string data_layout;
std::string weight_layout;
std::string out_layout;
DataType out_dtype;
TVM_DECLARE_ATTRS(ConvAttrs, "relay.attrs.ConvAttrs") {
TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1}))
.describe("Specifies the strides of the convolution.");
TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0}))
.describe("If padding is non-zero, then the input is implicitly zero-padded "
"on both sides by the given number of points.");
TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1, 1}))
.describe("Specifies the dilation rate to use for dilated convolution.");
TVM_ATTR_FIELD(groups).set_default(1)
.describe("Controls the connections between inputs and outputs. "
"At groups=1, all inputs are convolved to all outputs. "
"At groups=2, the operation becomes equivalent to having two convolution "
"layers side by side, each seeing half the input channels and producing "
"half the output channels, with both subsequently concatenated.");
TVM_ATTR_FIELD(channels)
.describe("The number of output channels in the convolution. "
"If it is not set, it is inferred from the weight shape.")
.set_default(NullValue<IndexExpr>());
TVM_ATTR_FIELD(kernel_size)
.describe("Specifies the dimensions of the convolution window.")
.set_default(NullValue<Array<IndexExpr> >());
TVM_ATTR_FIELD(data_layout).set_default("NCHW")
.describe("Dimension ordering of the input data. Can be 'NCHW', 'NHWC', etc. "
"'N', 'C', 'H', 'W' stand for batch, channel, height, and width "
"dimensions respectively. Convolution is applied on the 'H' and "
"'W' dimensions.");
TVM_ATTR_FIELD(weight_layout).set_default("OIHW")
.describe("Dimension ordering of the weight. Can be 'OIHW', 'OIHW16o16i', etc. "
"'O', 'I', 'H', 'W' stand for num_filter, input_channel, height, and width "
"dimensions respectively.");
TVM_ATTR_FIELD(out_layout).set_default("__undef__")
.describe("Dimension ordering of the output. Can be 'NCHW', 'NHWC', etc. "
"'N', 'C', 'H', 'W' stand for batch, channel, height, and width "
"dimensions respectively. Defaults to the same layout as the input.");
// use 0 bits to indicate none.
TVM_ATTR_FIELD(out_dtype)
.set_default(Int(0))
.describe("Output data type, set to explicit type under mixed precision setting");
}
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_NN_H_

@@ -37,7 +37,7 @@ using DataType = ::tvm::Type;
/*!
* \brief Symbolic expression for tensor shape.
*/
using ShapeExpr = ::tvm::Expr;
using IndexExpr = ::tvm::Expr;
/*!
* \brief Hash function for nodes.

@@ -286,7 +286,9 @@ class CallNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_);
}
TVM_DLL static Call make(Expr op, Array<Expr> args, Attrs attrs = Attrs(),
TVM_DLL static Call make(Expr op,
Array<Expr> args,
Attrs attrs = Attrs(),
Array<Type> ty_args = Array<Type>());
static constexpr const char* _type_key = "relay.Call";

@@ -70,9 +70,9 @@ class TensorTypeNode : public BaseTensorTypeNode {
public:
/*!
* \brief The shape of the tensor,
* represented by ShapeExpr(tvm::Expr).
* represented by IndexExpr(tvm::Expr).
*/
Array<ShapeExpr> shape;
Array<IndexExpr> shape;
/*! \brief The content data type */
DataType dtype;
@@ -82,7 +82,7 @@ class TensorTypeNode : public BaseTensorTypeNode {
v->Visit("span", &span);
}
TVM_DLL static TensorType make(Array<ShapeExpr> shape, DataType dtype);
TVM_DLL static TensorType make(Array<IndexExpr> shape, DataType dtype);
/*! \brief Construct a scalar containing elements of dtype. */
TVM_DLL static TensorType Scalar(DataType dtype);
@@ -273,8 +273,10 @@ class TypeReporterNode : public Node {
* \brief assert shape expression equals each other.
* \param lhs The left operand.
* \param rhs The right operand.
* \return false if the assertion can be proven to have failed,
* true if the solver can still proceed.
*/
TVM_DLL virtual void AssertEQ(const ShapeExpr& lhs, const ShapeExpr& rhs) = 0;
TVM_DLL virtual bool AssertEQ(const IndexExpr& lhs, const IndexExpr& rhs) = 0;
// solver is not serializable.
void VisitAttrs(tvm::AttrVisitor* v) final {}

@@ -521,6 +521,12 @@ class TVMArgValue : public TVMPODValue_ {
if (type_code_ == kStr) {
return String2TVMType(operator std::string());
}
// None type
if (type_code_ == kNull) {
TVMType t;
t.code = kHandle; t.bits = 0; t.lanes = 0;
return t;
}
TVM_CHECK_TYPE_CODE(type_code_, kTVMType);
return value_.v_type;
}
@@ -878,6 +884,7 @@ inline std::ostream& operator<<(std::ostream& os, TVMType t) { // NOLINT(*)
#endif
inline std::string TVMType2String(TVMType t) {
if (t.bits == 0) return "";
#ifndef _LIBCPP_SGX_NO_IOSTREAMS
std::ostringstream os;
os << t;
@@ -896,6 +903,11 @@ inline std::string TVMType2String(TVMType t) {
inline TVMType String2TVMType(std::string s) {
TVMType t;
// handle None type
if (s.length() == 0) {
t.bits = 0; t.lanes = 0; t.code = kHandle;
return t;
}
t.bits = 32; t.lanes = 1;
const char* scan;
if (s.substr(0, 3) == "int") {
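For clarity, here is a small round-trip sketch (an editor's illustration, not part of the commit) of the None encoding added above: bits == 0 serializes to the empty string, and the empty string parses back to the same kHandle sentinel.

#include <tvm/runtime/packed_func.h>
#include <dmlc/logging.h>

void RoundTripNone() {
  TVMType none;
  none.code = kHandle; none.bits = 0; none.lanes = 0;
  // bits == 0 serializes to the empty string per TVMType2String above...
  CHECK_EQ(tvm::runtime::TVMType2String(none), "");
  // ...and the empty string parses back to the same sentinel.
  TVMType parsed = tvm::runtime::String2TVMType("");
  CHECK(parsed.code == kHandle && parsed.bits == 0 && parsed.lanes == 0);
}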

@@ -9,6 +9,7 @@ from . import ir_builder
# Operators
from .op import Op
from .op.tensor import *
from .op import nn
# Span
Span = base.Span

@@ -11,17 +11,19 @@ class Environment(NodeBase):
options and more.
"""
def __init__(self, funcs):
def __init__(self, funcs=None):
"""Construct an environment.
Parameters
----------
funcs: list of relay.Function
funcs : dict, optional
Map of global var to Function.
Returns
-------
env : :py:class:`~relay.env.Environment`
A new environment.
"""
funcs = funcs if funcs else {}
self.__init_handle_by_constructor__(_make.Environment, funcs)
def add(self, var, func):

@@ -6,10 +6,26 @@ Exposes an interface for configuring the passes and scripting
them in Python.
"""
from . import _ir_pass
# Expose checking expression, should rename to infer_type.
# pylint: disable=invalid-name
check_expr = _ir_pass.check_expr
def infer_type(env, expr):
"""Infer the type of expr under the context of env
Parameters
----------
env : relay.Environment
The global environment.
expr : relay.Expr
The input expression.
Returns
-------
checked_expr : relay.Expr
The checked expression.
"""
return _ir_pass.infer_type(env, expr)
well_formed = _ir_pass.well_formed

@@ -5,6 +5,8 @@ from .op import get, register, Op
# Operators
from .tensor import *
from . import nn
# operator registry
from . import _tensor

python/tvm/relay/op/nn.py (new file)

@@ -0,0 +1,54 @@
"""Neural network operations."""
from __future__ import absolute_import as _abs
from . import _make
def conv2d(data,
weight,
strides=(1, 1),
padding=(0, 0),
dilation=(1, 1),
groups=1,
channels=None,
kernel_size=None,
data_layout="NCHW",
weight_layout="OIHW",
out_layout="",
out_dtype=""):
"""Two dimensional convolution operator.
Parameters
----------
data : relay.Expr
The input data to the operator.
weight : relay.Expr
The weight expressions.
strides : tuple of int, optional
The strides of convolution.
padding : tuple of int, optional
The padding of convolution on both sides of inputs.
dilation : tuple of int, optional
Specifies the dilation rate to be used for dilated convolution.
groups : int, optional
Number of groups for grouped convolution.
channels : int, optional
Number of output channels of this convolution. If not set, it is
inferred from the weight shape.
kernel_size : tuple of int, optional
Specifies the dimensions of the convolution window. If not set, they
are inferred from the weight shape.
data_layout : str, optional
Layout of the input.
weight_layout : str, optional
Layout of the weight.
out_layout : str, optional
Layout of the output.
out_dtype : str, optional
Specifies the output data type for mixed precision conv2d.
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.conv2d(data, weight, strides, padding, dilation,
groups, channels, kernel_size, data_layout,
weight_layout, out_layout, out_dtype)

@@ -117,6 +117,9 @@ Operation ComputeOpNode::make(std::string name,
Map<std::string, NodeRef> attrs,
Array<IterVar> axis,
Array<Expr> body) {
if (!attrs.defined()) {
attrs = Map<std::string, NodeRef>();
}
auto n = make_node<ComputeOpNode>();
n->name = std::move(name);
n->tag = std::move(tag);

@@ -43,6 +43,9 @@ Operation ExternOpNode::make(std::string name,
Array<Buffer> input_placeholders,
Array<Buffer> output_placeholders,
Stmt body) {
if (!attrs.defined()) {
attrs = Map<std::string, NodeRef>();
}
auto n = make_node<ExternOpNode>();
n->name = std::move(name);
n->tag = std::move(tag);

@@ -51,6 +51,9 @@ Operation ScanOpNode::make(std::string name,
Array<Tensor> update,
Array<Tensor> state_placeholder,
Array<Tensor> inputs) {
if (!attrs.defined()) {
attrs = Map<std::string, NodeRef>();
}
auto n = make_node<ScanOpNode>();
CHECK_EQ(init.size(), update.size());
CHECK_EQ(init.size(), state_placeholder.size());

@@ -418,6 +418,19 @@ bool Equal(const Stmt& lhs, const Stmt& rhs) {
}
bool Equal(const Expr& lhs, const Expr& rhs) {
// quick pass for constant expressions.
if (const int64_t *a = as_const_int(lhs)) {
if (const int64_t *b = as_const_int(rhs)) {
return a[0] == b[0];
}
}
if (!lhs.defined()) {
if (rhs.defined()) return false;
if (!rhs.defined()) return true;
} else {
if (!rhs.defined()) return false;
}
// deep comparison.
return IRDeepCompare().Equal(lhs, rhs);
}
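The new fast paths can be exercised directly; a short editor's sketch using make_const and tvm::ir::Equal as they appear elsewhere in this commit:

void EqualFastPaths() {
  using tvm::Expr;
  Expr a = tvm::make_const(tvm::Int(32), 3);
  Expr b = tvm::make_const(tvm::Int(32), 3);
  CHECK(tvm::ir::Equal(a, b));              // constant path, no deep compare
  Expr undef_x, undef_y;                    // default-constructed: undefined
  CHECK(tvm::ir::Equal(undef_x, undef_y));  // both undefined -> equal
  CHECK(!tvm::ir::Equal(undef_x, a));       // undefined vs defined -> unequal
}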

@@ -11,7 +11,7 @@ namespace relay {
using tvm::IRPrinter;
using namespace tvm::runtime;
TensorType TensorTypeNode::make(Array<ShapeExpr> shape, DataType dtype) {
TensorType TensorTypeNode::make(Array<IndexExpr> shape, DataType dtype) {
NodePtr<TensorTypeNode> n = make_node<TensorTypeNode>();
n->shape = std::move(shape);
n->dtype = std::move(dtype);
@@ -24,7 +24,7 @@ TensorType TensorTypeNode::Scalar(DataType dtype) {
TVM_REGISTER_API("relay._make.TensorType")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Array<ShapeExpr> shape = args[0];
Array<IndexExpr> shape = args[0];
*ret = TensorTypeNode::make(shape, args[1]);
});

@@ -0,0 +1,158 @@
/*!
* Copyright (c) 2018 by Contributors
* \file convolution.cc
* \brief Convolution operators
*/
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h>
#include <vector>
#include "layout.h"
namespace tvm {
namespace relay {
TVM_REGISTER_NODE_TYPE(ConvAttrs);
bool Conv2DRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
const auto* weight = types[1].as<TensorTypeNode>();
if (data == nullptr) return false;
static const Layout kNCHW("NCHW");
static const Layout kOIHW("OIHW");
const ConvAttrs* param = attrs.as<ConvAttrs>();
CHECK(param != nullptr);
const Layout in_layout(param->data_layout);
const Layout kernel_layout(param->weight_layout);
CHECK(in_layout.convertible(kNCHW))
<< "Conv only supports input layouts that are convertible from NCHW."
<< " But got " << in_layout;
CHECK(kernel_layout.convertible(kOIHW))
<< "Conv only supports kernel layouts that are convertible from OIHW."
<< " But got " << kernel_layout;
Layout out_layout(param->out_layout);
if (!out_layout.defined()) out_layout = in_layout;
CHECK(out_layout.convertible(kNCHW))
<< "Conv only supports output layouts that are convertible from NCHW."
<< " But got " << out_layout;
IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
// infer weight if the kernel_size and channels are defined
if (param->kernel_size.defined() && param->channels.defined()) {
CHECK_EQ(param->kernel_size.size(), 2);
CHECK_EQ(param->dilation.size(), 2);
std::vector<IndexExpr> wshape(
{param->channels / param->groups,
data->shape[1] / param->groups,
param->kernel_size[0],
param->kernel_size[1]});
wshape = ConvertLayout(wshape, kOIHW, kernel_layout);
wshape[kernel_layout.indexof('O')] *= param->groups;
channels = param->channels;
dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
// assign result to reporter
reporter->Assign(types[1], TensorTypeNode::make(wshape, data->dtype));
} else {
// use weight to infer the conv shape.
if (weight == nullptr) return false;
auto wshape = ConvertLayout(weight->shape, kernel_layout, kOIHW);
if (param->kernel_size.defined()) {
CHECK_EQ(param->kernel_size.size(), 2);
// check the size
CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) &&
reporter->AssertEQ(param->kernel_size[1], wshape[3]))
<< "Conv2D: shape of weight is inconsistent with kernel_size, "
<< " kernel_size=" << param->kernel_size
<< " wshape=" << Array<IndexExpr>(wshape);
}
if (param->channels.defined()) {
CHECK(reporter->AssertEQ(param->channels, wshape[0]))
<< "Conv2D: shape of weight is inconsistent with channels, "
<< " channels=" << param->channels
<< " wshape=" << Array<IndexExpr>(wshape);
}
CHECK(reporter->AssertEQ(data->shape[1] / param->groups, wshape[1]));
channels = wshape[0];
dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0];
dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
}
// dilation
std::vector<IndexExpr> oshape({data->shape[0], channels, 0, 0});
oshape[2] = (data->shape[2] + param->padding[0] * 2 - dilated_ksize_y) / param->strides[0] + 1;
oshape[3] = (data->shape[3] + param->padding[1] * 2 - dilated_ksize_x) / param->strides[1] + 1;
DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
out_dtype = data->dtype;
}
oshape = ConvertLayout(oshape, kNCHW, out_layout);
// assign output type
reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
return true;
}
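// Editor's worked example of the shape arithmetic above (plain integers;
// the real code manipulates IndexExpr): for NCHW data (n, 10, 224, 224)
// with kernel_size=(3, 3), padding=(1, 1), strides=(1, 1), dilation=(1, 1)
// and channels=2,
//   dilated_ksize = 1 + (3 - 1) * 1 = 3
//   out_h = (224 + 2 * 1 - 3) / 1 + 1 = 224
// so the reported output type is (n, 2, 224, 224); with padding=(0, 0) the
// spatial dims shrink to (224 - 3) / 1 + 1 = 222, matching the
// (n, 2, 222, 222) case in the new type-inference test.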
// Positional relay function to create conv2d operator
// used by frontend FFI.
Expr MakeConv2D(Expr data,
Expr weight,
Array<IndexExpr> strides,
Array<IndexExpr> padding,
Array<IndexExpr> dilation,
int groups,
IndexExpr channels,
Array<IndexExpr> kernel_size,
std::string data_layout,
std::string weight_layout,
std::string out_layout,
DataType out_dtype) {
auto attrs = make_node<ConvAttrs>();
attrs->strides = std::move(strides);
attrs->padding = std::move(padding);
attrs->dilation = std::move(dilation);
attrs->groups = groups;
attrs->channels = channels;
attrs->kernel_size = kernel_size;
attrs->data_layout = std::move(data_layout);
attrs->weight_layout = std::move(weight_layout);
attrs->out_layout = std::move(out_layout);
attrs->out_dtype = std::move(out_dtype);
static const Op& op = Op::Get("conv2d");
return CallNode::make(op, {data, weight}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay.op._make.conv2d")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 12>(MakeConv2D, args, rv);
});
RELAY_REGISTER_OP("conv2d")
.describe(R"code(2D convolution layer (e.g. spatial convolution over images).
This layer creates a convolution kernel that is convolved
with the layer input to produce a tensor of outputs.
- **data**: This depends on the `layout` parameter. Input is 4D array of shape
(batch_size, in_channels, height, width) if `layout` is `NCHW`.
- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1])
- **out**: This depends on the `layout` parameter. Output is 4D array of shape
(batch_size, channels, out_height, out_width) if `layout` is `NCHW`.
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("weight", "Tensor", "The weight tensor.")
.set_support_level(2)
.add_type_rel("Conv2D", Conv2DRel);
} // namespace relay
} // namespace tvm
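As a usage note, the positional maker can also be called directly from C++; an editor's sketch (assuming the tvm::relay namespace) mirroring what the Python wrapper in python/tvm/relay/op/nn.py forwards through the FFI:

Expr ExampleConv2D(Expr data, Expr weight) {
  // 3x3 convolution, padding 1, 2 output channels, int32 accumulation.
  return MakeConv2D(data, weight,
                    {1, 1},      // strides
                    {1, 1},      // padding
                    {1, 1},      // dilation
                    1,           // groups
                    2,           // channels
                    {3, 3},      // kernel_size
                    "NCHW",      // data_layout
                    "OIHW",      // weight_layout
                    "",          // out_layout (defaults to data_layout)
                    Int(32));    // out_dtype
}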

src/relay/op/nn/layout.h (new file)

@@ -0,0 +1,538 @@
/*!
* Copyright (c) 2018 by Contributors
* \file relay/op/nn/layout.h
* \brief Layout expression.
*
* This file is adapted from its nnvm counterpart and will keep evolving
* toward the new layout system.
*
* The layout is composed of upper cases, lower cases and numbers,
* where upper case indicates a (super-)dimension and
* the corresponding lower case with factor size indicates the split (sub-)dimension.
* For example, NCHW16c can describe a 5-D tensor of
* [batch_size, channel, height, width, channel_block].
* Here sub-dimension channel_block=16 is the split of super-dimension C (channel).
*/
#ifndef TVM_RELAY_OP_NN_LAYOUT_H_
#define TVM_RELAY_OP_NN_LAYOUT_H_
#include <string>
#include <sstream>
#include <vector>
#include <utility>
#include <algorithm>
namespace tvm {
namespace relay {
/*! \brief layout auxiliary structure */
class Layout {
public:
using LayoutDim = char;
/*! \brief default constructor */
Layout() : name_("__undef__") {} // NOLINT(*)
/*!
* \brief construct from a string.
* \param layout input in layout convention:
* upper case indicates a dimension and
* the corresponding lower case with factor size
* indicates the split dimension.
* return undefined layout if "__undef__" is passed.
*/
Layout(const std::string& layout) { // NOLINT(*)
if (layout.length() != 0) {
parse(layout);
} else {
parse("__undef__");
}
}
/*!
* \brief copy constructor from another layout
* \param s the source layout
*/
Layout(const Layout& s) { // NOLINT(*)
this->parse(s.name_);
}
/*!
* \brief move constructor from Layout
* \param src the source layout
*/
Layout(Layout&& src) { // NOLINT(*)
this->swap(src);
}
/*!
* \brief assignment from another layout.
* \param src source layout
* \return reference of self
*/
Layout& operator=(const Layout& src) {
this->parse(src.name_);
return *this;
}
/*!
* \brief assignment from rvalue of another layout.
* \param src source layout
* \return reference of self
*/
Layout& operator=(Layout&& src) {
Layout(std::move(src)).swap(*this); // NOLINT(*)
return *this;
}
/*!
* \brief assignment from string.
* \param src source layout
* \return reference of self
*/
Layout& operator=(const std::string& src) {
this->parse(src);
return *this;
}
/*!
* \return whether two layout equals
* \param s the layout to compare against
*/
bool operator==(const Layout& s) const {
return name_ == s.name_;
}
/*!
* \return whether two layout not equal
* \param s the layout to compare against
*/
bool operator!=(const Layout& s) const {
return !(*this == s);
}
/*!
* \brief Append the current layout by another.
* \param other the layout to be appended
* \return a new layout
*/
Layout operator+(const Layout& other) const {
if (!this->defined() && !other.defined()) {
return Layout::Undef();
} else if (!this->defined()) {
return other;
} else if (!other.defined()) {
return *this;
}
return Layout(this->name_ + other.name_);
}
/*!
* \brief Check whether a given dimension is a super-dimension.
* \param dim input dimension
* \return Whether a given dimension is a super-dimension.
*/
static bool is_superdim(LayoutDim dim) {
return dim >= 'A' && dim <= 'Z';
}
/*!
* \brief Check whether a given dimension is a sub-dimension.
* \param dim input dimension
* \return Whether a given dimension is a sub-dimension.
*/
static bool is_subdim(LayoutDim dim) {
return dim >= 'a' && dim <= 'z';
}
/*!
* \brief Convert a given dimension to super-dimension.
* \param dim input dimension
* \return The converted description.
*/
static LayoutDim to_superdim(LayoutDim dim) {
if (is_subdim(dim)) {
return dim - 'a' + 'A';
}
return dim;
}
/*!
* \brief Convert a given dimension to sub-dimension.
* \param dim input dimension
* \return The converted description.
*/
static LayoutDim to_subdim(LayoutDim dim) {
if (is_superdim(dim)) {
return dim - 'A' + 'a';
}
return dim;
}
/*!
* \brief Return an undefined layout.
* \return a (global) undefined layout.
*/
static const Layout& Undef() {
static Layout undef;
return undef;
}
/*!
* \brief Swap current object with other
* \param other another object to be swapped.
*/
void swap(Layout& other) { // NOLINT(*)
std::swap(name_, other.name_);
std::swap(superdim_pos_, other.superdim_pos_);
std::swap(subdim_pos_, other.subdim_pos_);
std::swap(subdim_size_, other.subdim_size_);
std::swap(layout_simplified_, other.layout_simplified_);
}
/*!
* \brief Two layouts are convertible only if
* they have same set of super-dimensions.
* e.g., NCHW, NCHW16c, NHWC are convertible between each other,
* but NCHW, CHW, OIHW are not.
* \param dst the target layout
* \return Whether can be converted to dst layout.
*/
bool convertible(const Layout &dst) const {
if (!this->defined() || !dst.defined()) return false;
for (size_t i = 0; i < kUniqueDim; ++i) {
if ((superdim_pos_[i] >= 0 && dst.superdim_pos_[i] < 0) ||
(superdim_pos_[i] < 0 && dst.superdim_pos_[i] >= 0)) {
return false;
}
}
return true;
}
/*!
* \brief Returns a sublayout which is the portion of the object
* that starts at dimension \p pos and spans \p len dimensions
* (or until the end of the layout, whichever comes first).
* \param pos The start position.
* \param len The length of the sub-layout.
* \return A newly constructed Layout object.
*/
Layout sublayout(size_t pos, size_t len) const {
if (pos > ndim()) return Layout::Undef();
if (pos + len > ndim()) len = ndim() - pos;
if (len == 0) return Layout::Undef();
std::ostringstream new_layout;
for (size_t i = pos; i < pos + len; ++i) {
if (is_subdim(layout_simplified_[i])) {
auto block_size = this->subsizeof(layout_simplified_[i]);
CHECK_GT(block_size, 0);
new_layout << block_size;
}
new_layout << layout_simplified_[i];
}
return Layout(new_layout.str());
}
/*! \return A newly constructed reversed Layout object. */
Layout reverse() const {
if (!this->defined()) return Layout::Undef();
std::ostringstream new_layout;
for (int64_t i = this->ndim() - 1; i >= 0; --i) {
if (is_subdim(layout_simplified_[i])) {
auto block_size = this->subsizeof(layout_simplified_[i]);
CHECK_GT(block_size, 0);
new_layout << block_size;
}
new_layout << layout_simplified_[i];
}
return Layout(new_layout.str());
}
/*!
* \brief Split \p dim by \p size and put the sub-dimension to position \p target_pos.
* \param dim The source dimension to be split. It must be a super-dimension.
* \param target_pos The target position of the newly split sub-dimension.
* \param size size of the sub-dimension.
* \return A newly constructed Layout object.
*/
Layout split(LayoutDim dim, size_t target_pos, uint32_t size) const {
CHECK(target_pos <= this->ndim()) << "Invalid split position "
<< target_pos << " for layout " << name_;
CHECK(is_superdim(dim)) << "Cannot split a sub-dimension " << dim;
CHECK(this->contains(dim)) << "Axis " << dim << " does not exist in " << name_;
CHECK(!this->contains(to_subdim(dim))) << "Dimension " << dim
<< " has already been split in "
<< name_;
CHECK(size > 0) << "Invalid split size " << size;
std::ostringstream new_layout;
for (size_t i = 0; i <= this->ndim(); ++i) {
if (i == target_pos) {
new_layout << size << Layout::to_subdim(dim);
}
if (i == this->ndim()) break;
new_layout << this->at(i);
}
Layout x(new_layout.str());
return x;
}
using iterator = std::vector<LayoutDim>::const_iterator;
using reverse_iterator = std::vector<LayoutDim>::const_reverse_iterator;
/*! \return begin iterator */
iterator begin() const {
return layout_simplified_.begin();
}
/*! \return end iterator */
iterator end() const {
return layout_simplified_.end();
}
/*! \return rbegin iterator */
reverse_iterator rbegin() const {
return layout_simplified_.rbegin();
}
/*! \return rend iterator */
reverse_iterator rend() const {
return layout_simplified_.rend();
}
/*! \return number of dimensions */
size_t ndim() const {
return layout_simplified_.size();
}
/*!
* \brief The description of the \p i-th dimension.
* If it is a sub-dimension, the size will be returned as well,
* e.g., 16c. Otherwise a single character is returned, e.g., C.
* \param i The position
* \return the description of the dimension.
*/
std::string at(size_t i) const {
CHECK_LT(i, this->ndim()) << "position " << i
<< " exceeds ndim=" << this->ndim();
std::ostringstream repr;
if (is_subdim(layout_simplified_[i])) {
auto factor = subsizeof(layout_simplified_[i]);
CHECK_GT(factor, 0);
repr << factor;
}
repr << layout_simplified_[i];
return repr.str();
}
/*!
* \brief return the index of the input dimension.
* If it is not found in the layout or the layout is undefined,
* return -1.
* \param dim the input dimension.
* \return the index or -1 if not found.
*/
int32_t indexof(LayoutDim dim) const {
if (!this->defined()) return -1;
else if (is_superdim(dim)) return superdim_pos_[dim - 'A'];
else if (is_subdim(dim)) return subdim_pos_[dim - 'a'];
return -1;
}
/*!
* \param dim the input super-dimension or sub-dimension.
* \return the size of the sub-dimension of \p dim (if \p dim is a super-dimension),
* or the size of \p dim itself (if \p dim is a sub-dimension).
* Return -1 if \p dim is not in the layout or the layout is undefined.
*/
int64_t subsizeof(LayoutDim dim) const {
CHECK(is_superdim(dim) || is_subdim(dim)) << "Invalid dim " << dim;
if (!this->defined() || !this->contains(to_subdim(dim))) {
return -1;
}
int idx = to_subdim(dim) - 'a';
return subdim_size_[idx];
}
/*!
* \brief Whether the layout contains a dimension.
* \param dim dimension to be checked.
* \return Whether the layout contains the dimension.
*/
bool contains(LayoutDim dim) const {
if (is_superdim(dim)) {
return superdim_pos_[dim-'A'] >= 0;
} else if (is_subdim(dim)) {
return subdim_pos_[dim-'a'] >= 0;
}
return false;
}
LayoutDim operator[](size_t i) const {
return layout_simplified_[i];
}
/*! \return whether the layout is defined */
bool defined() const {
return name_ != "__undef__";
}
/*! \return the string description of the layout */
const std::string& name() const {
return name_;
}
/*!
* \brief Write layout in JSON format.
* \param writer JSONWriter
*/
void Save(dmlc::JSONWriter* writer) const {
writer->Write(name_);
}
/*!
* \brief Load layout from JSON.
* \param reader JSONReader
*/
void Load(dmlc::JSONReader* reader) {
std::string tmp;
reader->Read(&tmp);
this->parse(tmp);
}
/*!
* \brief allow output string of layout to ostream
* \param os the output stream
* \param l the layout
* \return the ostream
*/
friend std::ostream& operator<<(std::ostream& os, const Layout& l) {
os << l.name_;
return os;
}
private:
static const uint32_t kUniqueDim = 26;
std::string name_;
int32_t superdim_pos_[kUniqueDim];
int32_t subdim_pos_[kUniqueDim];
int64_t subdim_size_[kUniqueDim];
std::vector<LayoutDim> layout_simplified_;
void parse(const std::string& layout) {
name_ = layout;
std::fill_n(superdim_pos_, kUniqueDim, -1);
std::fill_n(subdim_pos_, kUniqueDim, -1);
std::fill_n(subdim_size_, kUniqueDim, -1);
layout_simplified_.clear();
if (layout == "__undef__") return;
int32_t factor = 0;
uint32_t curr = 0;
for (size_t i = 0; i < layout.size(); ++i) {
const LayoutDim c = layout.at(i);
if (is_superdim(c)) {
int pos = c - 'A';
CHECK_EQ(factor, 0) << "Invalid layout " << layout
<< ": invalid factor size " << factor
<< " before dimension " << c;
CHECK_EQ(superdim_pos_[pos], -1) << "Invalid layout " << layout
<< ": duplicate dimension " << c;
superdim_pos_[pos] = curr++;
layout_simplified_.push_back(c);
} else if (is_subdim(c)) {
int pos = c - 'a';
CHECK_GT(factor, 0) << "Invalid layout " << layout << ": invalid factor size "
<< factor << " for dimension " << c;
CHECK_EQ(subdim_pos_[pos], -1) << "Invalid layout " << layout
<< ": duplicate dimension " << c;
CHECK_EQ(subdim_size_[pos], -1) << "Invalid layout " << layout
<< ": duplicate dimension " << c;
subdim_pos_[pos] = curr++;
subdim_size_[pos] = factor;
layout_simplified_.push_back(c);
factor = 0;
} else if (c >= '0' && c <= '9') {
CHECK(factor >= 0) << "Invalid layout " << layout << ": _ is adjacent to a number.";
factor = factor * 10 + c - '0';
} else {
LOG(FATAL) << "Invalid layout " << layout;
}
}
CHECK(!layout_simplified_.empty()) << "Invalid layout " << layout;
for (LayoutDim dim : layout_simplified_) {
CHECK(is_superdim(dim) || superdim_pos_[dim-'a'] >= 0)
<< "Invalid layout " << layout << ": missing axis "
<< static_cast<char>(dim - 'a' + 'A');
}
}
};
/*!
* \brief Convert shape in src_layout to shape in dst_layout
* \param src original shape
* \param src_layout layout of original shape
* \param dst_layout target layout
* \return shape in target layout
*/
inline std::vector<IndexExpr> ConvertLayout(
std::vector<IndexExpr> src,
const Layout& src_layout,
const Layout& dst_layout) {
CHECK_EQ(src_layout.ndim(), src.size());
if (src_layout == dst_layout) {
return src;
} else if (!src_layout.defined()) {
LOG(FATAL) << "cannot convert undefined layout to " << dst_layout;
} else if (!dst_layout.defined()) {
LOG(FATAL) << "cannot convert " << src_layout << " to undefined layout";
}
CHECK(src_layout.convertible(dst_layout))
<< "cannot convert from "
<< src_layout << " to " << dst_layout;
std::vector<IndexExpr> dst(dst_layout.ndim());
for (size_t i = 0; i < src_layout.ndim(); ++i) {
Layout::LayoutDim src_dim = src_layout[i];
if (Layout::is_superdim(src_dim)) {
int dst_major_pos = dst_layout.indexof(Layout::to_superdim(src_dim));
int dst_minor_pos = dst_layout.indexof(Layout::to_subdim(src_dim));
int src_minor_pos = src_layout.indexof(Layout::to_subdim(src_dim));
int src_factor = src_layout.subsizeof(src_dim);
int dst_factor = dst_layout.subsizeof(src_dim);
IndexExpr src_dim_size = src[i];
if (src_minor_pos >= 0) {
const int64_t* minor_size = as_const_int(src[src_minor_pos]);
CHECK(minor_size == nullptr ||
src_factor == minor_size[0])
<< "src shape " << Array<IndexExpr>(src)
<< " does not agree with layout "
<< src_layout;
src_dim_size *= src_factor;
}
dst[dst_major_pos] = src_dim_size;
if (dst_minor_pos >= 0) {
CHECK_GT(dst_factor, 0);
if (const int64_t* const_src_dim_size = as_const_int(src_dim_size)) {
CHECK_LE(dst_factor, const_src_dim_size[0])
<< "Converting " << Array<IndexExpr>(src)
<< " from " << src_layout
<< " to " << dst_layout
<< ": cannot split dimension size of "
<< src_dim_size << " by " << dst_factor;
}
dst[dst_major_pos] /= dst_factor;
dst[dst_minor_pos] = dst_factor;
}
}
}
return dst;
}
inline std::vector<IndexExpr> ConvertLayout(
const Array<IndexExpr>& src,
const Layout& src_layout,
const Layout& dst_layout) {
std::vector<IndexExpr> ret(src.size());
for (size_t i = 0; i < src.size(); ++i) {
ret[i] = src[i];
}
return ConvertLayout(ret, src_layout, dst_layout);
}
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_OP_NN_LAYOUT_H_
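To make the layout conventions concrete, an editor's sketch of the helpers above (the assertions follow the parsing rules documented in this header):

#include <vector>
#include <dmlc/logging.h>
#include "layout.h"

void LayoutExamples() {
  using tvm::relay::Layout;
  Layout nchw("NCHW");
  Layout blocked("NCHW16c");             // channel split into blocks of 16
  CHECK(nchw.convertible(blocked));      // same set of super-dimensions
  CHECK_EQ(blocked.indexof('c'), 4);     // position of the sub-dimension
  CHECK_EQ(blocked.subsizeof('C'), 16);  // block size of the split channel
  CHECK(nchw.split('C', 4, 16) == blocked);

  // (1, 32, 224, 224) in NCHW -> (1, 2, 224, 224, 16) in NCHW16c,
  // since 32 channels form 2 blocks of 16.
  std::vector<tvm::relay::IndexExpr> src = {1, 32, 224, 224};
  std::vector<tvm::relay::IndexExpr> dst =
      tvm::relay::ConvertLayout(src, nchw, blocked);
  CHECK_EQ(dst.size(), 5);
}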

@@ -69,8 +69,8 @@ Type ConcreteBroadcast(const TensorType& t1,
rev_sh2++;
}
Array<ShapeExpr> larger;
Array<ShapeExpr> smaller;
Array<IndexExpr> larger;
Array<IndexExpr> smaller;
for (int i = 0; i < (full_len - suffix_len); i++) {
smaller.push_back(make_const(tvm::Int(64), 1));
@@ -93,7 +93,7 @@ CHECK_EQ(larger.size(), smaller.size());
CHECK_EQ(larger.size(), smaller.size());
Array<ShapeExpr> out_shape;
Array<IndexExpr> out_shape;
for (size_t i = 0; i < smaller.size(); i++) {
auto left = smaller[i].as<tvm::ir::IntImm>();
auto right = larger[i].as<tvm::ir::IntImm>();

@@ -1,8 +1,9 @@
/*!
* Copyright (c) 2018 by Contributors
* \file src/tvm/relay/pass/alpha_eq.cc
* \brief Compute the set of variables not bound in the expression.
* \brief The structural equivalence comparison.
*/
#include <tvm/ir_pass.h>
#include <tvm/relay/expr_functor.h>
#include "./type_visitor.h"
#include "tvm/relay/pass.h"
@@ -19,9 +20,23 @@ struct TypeAlphaEq : TypeVisitor<const Type&> {
TypeAlphaEq() : eq_map(), equal(true) {}
void DataTypeEqual(const DataType& dt1, const DataType& dt2) {
equal = equal && dt1 == dt2;
if (dt1 != dt2) {
equal = false;
}
}
void ShapeEqual(const Array<IndexExpr>& s1, const Array<IndexExpr>& s2) {
if (s1.size() != s2.size()) {
equal = false;
return;
}
for (size_t i = 0; i < s1.size(); ++i) {
if (!tvm::ir::Equal(s1[i], s2[i])) {
equal = false;
return;
}
}
}
void ShapeEqual(Array<ShapeExpr> s1, Array<ShapeExpr> s2) {}
void VisitType_(const TensorTypeNode *tt1, const Type& t2) final {
if (const TensorTypeNode *tt2 = t2.as<TensorTypeNode>()) {

@@ -354,8 +354,8 @@ Expr TypeInferencer::Infer(Expr expr) {
return Resolver(type_map_, &solver_).VisitExpr(expr);
}
Expr InferType(const Environment& env, const Expr& e) {
return TypeInferencer(env).Infer(e);
Expr InferType(const Environment& env, const Expr& expr) {
return TypeInferencer(env).Infer(expr);
}
Expr InferType(const Environment& env,
@@ -370,11 +370,9 @@ Expr InferType(const Environment& env,
return func_ret;
}
TVM_REGISTER_API("relay._ir_pass.check_expr")
TVM_REGISTER_API("relay._ir_pass.infer_type")
.set_body([](TVMArgs args, TVMRetValue* ret) {
Environment env = args[0];
Expr e = args[1];
*ret = InferType(env, e);
*ret = InferType(args[0], args[1]);
});
} // namespace relay

@@ -18,8 +18,13 @@ class TypeSolver::Reporter : public TypeReporterNode {
solver_->Unify(dst, src);
}
void AssertEQ(const ShapeExpr& lhs, const ShapeExpr& rhs) final {
// TODO(tqchen)
bool AssertEQ(const IndexExpr& lhs, const IndexExpr& rhs) final {
// fail fast on the provably-unequal constant case.
IndexExpr diff = lhs - rhs;
if (const int64_t* pdiff = as_const_int(diff)) {
return pdiff[0] == 0;
}
return true;
}
private:
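Behaviorally (editor's sketch, given a TypeReporter rep):

  rep->AssertEQ(4, 4);   // true:  4 - 4 folds to the constant 0
  rep->AssertEQ(4, 5);   // false: 4 - 5 folds to -1, fail fast
  rep->AssertEQ(n, 4);   // true:  n - 4 stays symbolic, defer to the solver

This is what Conv2DRel relies on when comparing kernel_size against a possibly symbolic weight shape.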

@@ -8,7 +8,6 @@ ib = IRBuilder()
def show(e):
r = debug_print(ib.env, e)
assert r is not None
# print(r) # uncomment this line to debug
def test_constant():

@@ -0,0 +1,62 @@
import tvm
from tvm import relay
def test_conv2d_infer_type():
# symbolic in batch dimension
ib = relay.ir_builder.IRBuilder()
n, c, h, w = tvm.var("n"), 10, 224, 224
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
w = ib.param("w", relay.ty.IncompleteType())
with ib.function(x, w) as func:
ib.ret(relay.nn.conv2d(x.var, w.var,
kernel_size=(3, 3),
padding=(1, 1),
channels=2))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type()
assert ftype.ret_type == relay.ty.TensorType(
(n, 2, 224, 224), "float32")
assert ftype.arg_types[1] == relay.ty.TensorType(
(2, 10, 3, 3), "float32")
# infer by shape of w, mixed precision
ib = relay.ir_builder.IRBuilder()
n, c, h, w = tvm.var("n"), 10, 224, 224
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "int8"))
w = ib.param("w", relay.ty.TensorType((2, 10, 3, 3), "int8"))
with ib.function(x, w) as func:
ib.ret(relay.nn.conv2d(x.var, w.var, out_dtype="int32"))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type()
assert ftype.ret_type == relay.ty.TensorType(
(n, 2, 222, 222), "int32")
# Infer with a different layout
ib = relay.ir_builder.IRBuilder()
n, c, h, w = 4, 32, 224, 224
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "int8"))
w = ib.param("w", relay.ty.IncompleteType())
with ib.function(x, w) as func:
ib.ret(relay.nn.conv2d(x.var, w.var,
kernel_size=(3, 3),
padding=(1, 1),
channels=16,
data_layout="NCHW4n4c",
weight_layout="OIHW4o4i",
out_dtype="int32"))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type()
assert ftype.ret_type == relay.ty.TensorType(
(1, 4, 224, 224, 4, 4), "int32")
assert ftype.arg_types[1] == relay.ty.TensorType(
(4, 8, 3, 3, 4, 4), "int8")
if __name__ == "__main__":
test_conv2d_infer_type()

@@ -0,0 +1,17 @@
import tvm
from tvm import relay
def test_type_alpha_eq():
t1 = relay.ty.TensorType((3, 4), "float32")
t2 = relay.ty.TensorType((3, 4), "float32")
t3 = relay.ty.TensorType((3, 4, 5), "float32")
assert t1 == t2
assert t1 != t3
t1 = relay.ty.TensorType((), "float32")
t2 = relay.ty.TensorType((), "float32")
assert t1 == t2
if __name__ == "__main__":
test_type_alpha_eq()

@@ -3,7 +3,7 @@
"""
import tvm
import numpy as np
from tvm.relay.ir_pass import check_expr
from tvm.relay.ir_pass import infer_type
from tvm.relay.ir_builder import IRBuilder, func_type
from tvm.relay.ir_builder import scalar_type, convert, tensor_type
from tvm.relay.env import Environment
@@ -11,8 +11,11 @@ from tvm.relay.op import log, add, equal, subtract, concat
from tvm.relay.expr import Function
def assert_has_type(expr, typ, env=Environment({})):
checked_expr = check_expr(env, expr)
assert checked_expr.checked_type() == typ
checked_expr = infer_type(env, expr)
checked_type = checked_expr.checked_type()
if checked_type != typ:
raise RuntimeError("Type mismatch %s vs %s" % (
checked_type, typ))
def assert_decl_has_type(env, name, typ):
@@ -47,6 +50,7 @@ def test_add_op():
}
"""
b = IRBuilder()
x = b.param('x', tensor_type(5, 5, 5))
y = b.param('y', tensor_type(5, 5, 5))
with b.function(x, y) as func:
@@ -71,8 +75,9 @@ def test_add_broadcast_op():
b.ret(add(x.var, y.var))
b.ret(func)
prog, env = b.get()
ttype = tensor_type(5, 5, 5)
expected_ty = func_type([ttype, ttype], ttype)
expected_ty = func_type([tensor_type(10, 4), tensor_type(5, 10, 1)],
tensor_type(5, 10, 4))
assert_has_type(func.to_func(), expected_ty)
def test_dual_op():
@@ -89,7 +94,9 @@ def test_dual_op():
t1 = b.let('t1', log(x))
t2 = b.let('t2', add(t1, x))
b.ret(t2)
assert_has_type(func.to_func(), func_type(['float32'], 'float32'))
assert_has_type(func.to_func(),
func_type([tensor_type(10, 10)], tensor_type(10, 10)))
def test_decl():
@@ -152,12 +159,12 @@ def test_concat():
assert_decl_has_type(ib.env, try_concat2, fn_ty)
if __name__ == "__main__":
test_recursion()
test_dual_op()
test_recursion()
test_monomorphic_let()
test_single_op()
test_add_op()
test_add_broadcast_op()
test_dual_op()
test_decl()
test_concat()

@@ -59,6 +59,7 @@ inline int64_t GetConstInt(Expr expr) {
*/
inline std::vector<int> GetConstIntValues(Array<Expr> exprs, const std::string& var_name) {
std::vector<int> result;
if (!exprs.defined()) return result;
for (auto expr : exprs) {
CHECK(IsConstInt(expr)) << "All elements of " << var_name << " must be constant integers";
result.push_back(GetConstInt(expr));
@@ -77,6 +78,7 @@ inline std::vector<int> GetConstIntValues(Array<Expr> exprs, const std::string&
*/
inline std::vector<int64_t> GetConstInt64Values(Array<Expr> exprs, const std::string& var_name) {
std::vector<int64_t> result;
if (!exprs.defined()) return result;
for (auto expr : exprs) {
CHECK(IsConstInt(expr)) << "All elements of " << var_name << " must be constant integers";
result.push_back(GetConstInt(expr));