[Relay] Alter Op Layout (#2150)

* [RELAY] Finish alter op pass
* [RELAY] AlterOpLayout Pass
* fix broadcast operators
* fix broadcast operators
* fix broadcast operators
* Support concatenate
* address comments
* address comments
* add comments
* rebase

This commit is contained in:
Parent: 4bf1fd8c44
Commit: 2a5656bf80
@@ -1 +1 @@
-Subproject commit e4a4c02764d37c9c3db0d64c4996651a3ef9513c
+Subproject commit a08e26e5a97f4ef4d566a42f6c78704b3f9c7b8a
@@ -105,6 +105,7 @@ struct Conv2DTransposeAttrs : public tvm::AttrsNode<Conv2DTransposeAttrs> {
   int groups;
   std::string data_layout;
   std::string weight_layout;
+  std::string out_layout;
   DataType out_dtype;

   TVM_DECLARE_ATTRS(Conv2DTransposeAttrs, "relay.attrs.Conv2DTransposeAttrs") {
@@ -139,6 +140,10 @@ struct Conv2DTransposeAttrs : public tvm::AttrsNode<Conv2DTransposeAttrs> {
         .describe("Dimension ordering of data and weight. Can be 'OIHW', 'OIHW16o16i', etc."
                   "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width"
                   "dimensions respectively.");
+    TVM_ATTR_FIELD(out_layout).set_default("")
+        .describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc."
+                  "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
+                  "dimensions respectively. Default to be same as input layout.");
     TVM_ATTR_FIELD(out_dtype)
         .set_default(NullValue<DataType>())
         .describe("Output data type, set to explicit type under mixed precision setting");
@@ -164,6 +164,19 @@ struct ClipAttrs : public tvm::AttrsNode<ClipAttrs> {
   }
 };

+
+struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> {
+  std::string src_layout;
+  std::string dst_layout;
+
+  TVM_DECLARE_ATTRS(LayoutTransformAttrs, "relay.attrs.LayoutTransformAttrs") {
+    TVM_ATTR_FIELD(src_layout)
+        .describe("The source layout of the tensor. (e.g. NCHW)");
+    TVM_ATTR_FIELD(dst_layout)
+        .describe("The destination layout of the tensor. (e.g. NCHW16c)");
+  }
+};
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_ATTRS_TRANSFORM_H_
@@ -459,7 +459,7 @@ inline const TTypeNode* ExprNode::type_as() const {
   static_assert(std::is_base_of<TypeNode, TTypeNode>::value,
                 "TType must be a special case of type");
   CHECK(checked_type_.defined())
-      << "Type inference for this Expr has not completed";
+      << "Type inference for this Expr has not completed. Try to call infer_type pass.";
   const TTypeNode* node = checked_type_.as<TTypeNode>();
   CHECK(node != nullptr)
       << "Expected type to be " << TTypeNode::_type_key
@@ -86,6 +86,21 @@ using FTVMSchedule = runtime::TypedPackedFunc<
                     const Array<Tensor>& outs,
                     const Target& target)>;

+/*!
+ * \brief Alternate the layout of operators or replace the
+ *  operator with other expressions. This function will be invoked
+ *  in AlterOpLayout pass.
+ * \param attrs The attribute of the original node.
+ * \param inputs The input symbols of the original node.
+ * \param tinfos An array of placeholders, use for getting the inferred shape
+ *               and dtype of the inputs.
+ * \return new_expr The modified expression.
+ */
+using FTVMAlterOpLayout = runtime::TypedPackedFunc<
+  Expr(const Attrs& attrs,
+       const Array<Expr>& args,
+       const Array<Tensor>& tinfos)>;
+
 /*!
  * \brief Forward rewriting rule for a specific op.
  *
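For orientation, a minimal Python-side sketch of a callback matching this hook's contract (attrs, inputs, tinfos -> new expression). The conv2d target, the blocked layout, and the "return None means keep the op" convention are illustrative assumptions, not part of this hunk:

from tvm import relay

def alter_conv2d_layout(attrs, inputs, tinfos):
    """Sketch of a function with the FTVMAlterOpLayout signature."""
    data, weight = inputs          # input expressions of the original call
    data_info = tinfos[0]          # placeholder carrying the inferred shape/dtype of `data`
    print(data_info.shape, data_info.dtype)
    new_attrs = {k: attrs[k] for k in attrs.keys()}   # Attrs behaves like a read-only mapping
    new_attrs["weight_layout"] = "OIHW16i"            # hypothetical blocked layout
    return relay.nn.conv2d(data, weight, **new_attrs)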
@@ -8,6 +8,7 @@

 #include <tvm/relay/module.h>
 #include <tvm/relay/expr.h>
+#include <tvm/relay/op_attr_types.h>
 #include <string>

 namespace tvm {
@@ -173,6 +174,21 @@ Expr ForwardRewrite(const Expr& expr,
                     std::function<NodeRef(const Call&)> fcontext = nullptr,
                     std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);

+/*!
+ * \brief Apply rewrite rules to rewrite the expr in post DFS order.
+ * \param expr The expression.
+ * \param rewrite_func The rewrite func that will apply to all operators.
+ * \param fcontext Additional callback to provide context argument for each call node.
+ * \param fmulti_ref_trigger Transformation function to be called when
+ *  an Expr consumed by multiple callers.
+ * \return The rewritten expression.
+ */
+Expr ForwardRewrite(const Expr& expr,
+                    const FForwardRewrite& rewrite_func,
+                    std::function<NodeRef(const Call&)> fcontext = nullptr,
+                    std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);
+
+
 /*! \brief A hashing structure in the style of std::hash. */
 struct StructuralHash {
   /*! \brief Hash a Relay type.
@@ -13,6 +13,7 @@ from . import container
 from . import schedule
 from . import module
 from . import node
+from . import attrs
 from . import ir_builder
 from . import target
 from . import generic
@@ -0,0 +1,40 @@
+""" TVM Attribute module, which is mainly used for defining attributes of operators"""
+from ._ffi.node import NodeBase, register_node as _register_tvm_node
+from ._ffi.function import _init_api
+from . import _api_internal
+
+
+@_register_tvm_node
+class Attrs(NodeBase):
+    """Attribute node, which is mainly use for defining attributes of relay operators.
+
+    Used by function registered in python side, such as compute, schedule and alter_layout.
+    Attrs is passed as the first argument to these functions.
+    """
+    def list_field_info(self):
+        """ Get fields information
+
+        Returns
+        -------
+        infos: list of AttrFieldInfo
+            List of field information
+        """
+        return _api_internal._AttrsListFieldInfo(self)
+
+    def keys(self):
+        """Get list of names in the attribute.
+
+        Returns
+        -------
+        keys : list of str
+            List of keys
+        """
+        fields = self.list_field_info()
+        for field in fields:
+            yield field.name
+
+    def __getitem__(self, item):
+        return self.__getattr__(item)
+
+
+_init_api("tvm.attrs")
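A short usage sketch of the new Attrs node. The helper name is hypothetical; it only shows how keys() and __getitem__ are meant to be used from a registered callback:

def inspect_attrs(attrs):
    """Print every attribute field of an operator's Attrs node."""
    for name in attrs.keys():          # keys() yields the field names lazily
        print(name, "=", attrs[name])  # __getitem__ forwards to attribute access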
@@ -21,6 +21,20 @@ def register_relay_node(type_key=None):
     return _register_tvm_node(type_key)


+def register_relay_attr_node(type_key=None):
+    """register relay attribute node
+
+    Parameters
+    ----------
+    type_key : str or cls
+        The type key of the node
+    """
+    if not isinstance(type_key, str):
+        return _register_tvm_node(
+            "relay.attrs." + type_key.__name__)(type_key)
+    return _register_tvm_node(type_key)
+
+
 class RelayNode(NodeBase):
     """Base class of all relay node."""
     def astext(self, show_meta_data=True, annotate=None):
@@ -17,6 +17,7 @@ OPT_PASS_LEVEL = {
     "FoldConstant": 2,
     "CombineParallelConv2D": 3,
     "FoldScaleAxis": 3,
+    "AlterOpLayout": 3,
 }

 class BuildConfig(object):
@@ -157,6 +158,13 @@ def optimize(func, params=None):

     if cfg.pass_enabled("FoldConstant"):
         func = ir_pass.fold_constant(func)

+    if cfg.pass_enabled("AlterOpLayout"):
+        func = ir_pass.infer_type(func)
+        func = ir_pass.canonicalize_ops(func)
+        func = ir_pass.infer_type(func)
+        func = ir_pass.alter_op_layout(func)
+
     return func
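Since AlterOpLayout is registered at optimization level 3, it is picked up automatically when building; a minimal usage sketch (func, params and the target are placeholders supplied by an earlier frontend import step):

from tvm import relay

with relay.build_config(opt_level=3):   # level 3 enables AlterOpLayout per the table above
    graph, lib, params = relay.build(func, target="llvm", params=params)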
@@ -191,6 +191,23 @@ def simplify_inference(expr):
     return _ir_pass.simplify_inference(expr)


+def canonicalize_ops(expr):
+    """ Canonicalize special operators to basic operators.
+    This can simplify latter analysis. (e.g. Expand bias_add to expand_dims and broadcast_add.)
+
+    Parameters
+    ----------
+    e: tvm.relay.Expr
+        The input Expression
+
+    Returns
+    -------
+    result: tvm.relay.Expr
+        An expression without bias_add
+    """
+    return _ir_pass.canonicalize_ops(expr)
+
+
 def dead_code_elimination(expr):
     """ Remove expressions which does not effect the program result (dead code).

@@ -321,3 +338,22 @@ def combine_parallel_conv2d(expr):
         Transformed expression
     """
     return _ir_pass.CombineParallelConv2D(expr)
+
+
+def alter_op_layout(expr):
+    """Alternate the layouts of operators or replace primitive operators with
+    other expressions.
+    This pass can be used for computing convolution in custom layouts or
+    other general weight pre-transformation.
+
+    Parameters
+    ----------
+    expr : tvm.relay.Expr
+        The input expression.
+
+    Returns
+    -------
+    transformed_expr : tvm.relay.Expr
+        Transformed expression with alternated layout.
+    """
+    return _ir_pass.AlterOpLayout(expr)
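The new passes can also be driven by hand, mirroring the sequence used in optimize() above; a small sketch on a toy convolution (shapes are illustrative):

from tvm import relay
from tvm.relay import ir_pass

x = relay.var("x", shape=(1, 64, 56, 56))
w = relay.var("w", shape=(64, 64, 3, 3))
y = relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1))
func = relay.Function([x, w], y)

func = ir_pass.infer_type(func)
func = ir_pass.canonicalize_ops(func)   # e.g. expands bias_add
func = ir_pass.infer_type(func)
func = ir_pass.alter_op_layout(func)    # applies registered FTVMAlterOpLayout hooks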
@@ -1,7 +1,8 @@
 #pylint: disable=wildcard-import, redefined-builtin
 """Relay core operators."""
 # operator defs
-from .op import get, register, register_schedule, register_compute, Op
+from .op import get, register, register_schedule, register_compute, register_alter_op_layout, \
+    Op

 # Operators
 from .reduce import *
@@ -10,6 +11,7 @@ from .transform import *
 from . import nn
 from . import image
 from . import vision
+from . import op_attrs

 # operator registry
 from . import _tensor
@@ -80,12 +80,3 @@ def clip_compute(attrs, inputs, output_type, target):
     return [topi.clip(inputs[0], attrs.a_min, attrs.a_max)]

 register_schedule("clip", schedule_elemwise)
-register_pattern("clip", OpPattern.ELEMWISE)
-
-# concatenate
-@register_compute("concatenate")
-def concatenate_compute(attrs, inputs, output_type, target):
-    return [topi.concatenate(inputs, axis=attrs.axis)]
-
-register_schedule("concatenate", schedule_injective)
-register_pattern("concatenate", OpPattern.INJECTIVE)
@@ -1,8 +1,10 @@
 """Backend compiler related feature registration"""
-# pylint: disable=invalid-name
+# pylint: disable=invalid-name,unused-argument
 from __future__ import absolute_import
+import topi
 from . import op as _reg
 from ._reduce import _schedule_reduce
+from .op import schedule_injective, OpPattern

 schedule_injective = _reg.schedule_injective
 schedule_broadcast = _reg.schedule_injective
@@ -15,10 +17,22 @@ _reg.register_schedule("reshape", schedule_injective)
 _reg.register_schedule("reshape_like", schedule_injective)
 _reg.register_schedule("full", schedule_injective)
 _reg.register_schedule("full_like", schedule_injective)
-_reg.register_schedule("cast", schedule_broadcast)
+_reg.register_schedule("cast", schedule_injective)
 _reg.register_schedule("strided_slice", schedule_injective)
 _reg.register_schedule("slice_like", schedule_injective)
 _reg.register_schedule("split", schedule_injective)
 _reg.register_schedule("take", schedule_injective)
 _reg.register_schedule("transpose", schedule_injective)
 _reg.register_schedule("where", schedule_broadcast)
+
+# layout_transform
+_reg.register_schedule("layout_transform", schedule_injective)
+_reg.register_pattern("layout_transform", OpPattern.INJECTIVE)
+
+# concatenate
+@_reg.register_compute("concatenate")
+def concatenate_compute(attrs, inputs, output_type, target):
+    return [topi.concatenate(inputs, axis=attrs.axis)]
+
+_reg.register_schedule("concatenate", schedule_injective)
+_reg.register_pattern("concatenate", OpPattern.INJECTIVE)
@@ -107,7 +107,7 @@ def register_schedule(op_name, schedule=None, level=10):
     op_name : str
         The name of the op.

-    schedule : function
+    schedule : function (attrs: Attrs, outs: List[Tensor], target: Target) -> sch: Schedule
         The schedule function.

     level : int
@@ -124,7 +124,8 @@ def register_compute(op_name, compute=None, level=10):
     op_name : str
         The name of the op.

-    compute : function
+    compute : function (attrs: Attrs, inputs: List[Tensor], out_type: Type, target:Target)
+                        -> List[Tensor]
         The compute function.

     level : int
@@ -133,6 +134,23 @@ def register_compute(op_name, compute=None, level=10):
     return register(op_name, "FTVMCompute", compute, level)


+def register_alter_op_layout(op_name, alter_layout=None, level=10):
+    """Register alter op layout function for an op
+
+    Parameters
+    ----------
+    op_name : str
+        The name of the operator
+
+    alter_layout: function (attrs: Attrs, inputs: List[Expr]) -> new_expr: Expr
+        The function for changing the layout or replacing the operator
+
+    level : int
+        The priority level
+    """
+    return register(op_name, "FTVMAlterOpLayout", alter_layout, level)
+
+
 def register_pattern(op_name, pattern, level=10):
     """Register operator pattern for an op.

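A hedged registration sketch using the new helper as a decorator (register() supports decorator use when the function argument is omitted); the op name is real, but the chosen blocked layouts are illustrative and not mandated by this commit:

from tvm import relay
from tvm.relay.op import register_alter_op_layout

@register_alter_op_layout("nn.conv2d")
def _alter_conv2d_layout(attrs, inputs, tinfos):
    data, weight = inputs
    new_attrs = {k: attrs[k] for k in attrs.keys()}
    new_attrs["data_layout"] = "NCHW16c"      # hypothetical blocked layouts
    new_attrs["weight_layout"] = "OIHW16i16o"
    return relay.nn.conv2d(data, weight, **new_attrs)

The AlterOpLayout pass then inserts layout_transform calls around the rewritten operator so that neighbouring operators still see the layouts they expect.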
@@ -0,0 +1,14 @@
+"""The attributes node used for Relay operators"""
+
+from ...attrs import Attrs
+from ..base import register_relay_attr_node
+
+@register_relay_attr_node
+class Conv2DAttrs(Attrs):
+    """Attribute of a Convolution Operator"""
+    pass
+
+@register_relay_attr_node
+class GlobalPool2DAttrs(Attrs):
+    """Attribute of a Global 2D Pooling Operator"""
+    pass
|
||||||
The computed result.
|
The computed result.
|
||||||
"""
|
"""
|
||||||
return _make.slice_like(data, shape_like, axes)
|
return _make.slice_like(data, shape_like, axes)
|
||||||
|
|
||||||
|
|
||||||
|
def layout_transform(data, src_layout, dst_layout):
|
||||||
|
"""Transform the layout of a tensor
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
data : relay.Expr
|
||||||
|
The source tensor to be transformed
|
||||||
|
|
||||||
|
src_layout: str
|
||||||
|
The source layout. (e.g NCHW)
|
||||||
|
|
||||||
|
dst_layout: str
|
||||||
|
The destination layout. (e.g. NCHW16c)
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
ret : relay.Expr
|
||||||
|
The transformed tensor.
|
||||||
|
"""
|
||||||
|
return _make.layout_transform(data, src_layout, dst_layout)
|
||||||
|
|
|
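A quick usage sketch of the new operator; it assumes the function is re-exported at the usual relay level like the other transform ops:

from tvm import relay

x = relay.var("x", shape=(1, 64, 56, 56))          # NCHW
y = relay.layout_transform(x, "NCHW", "NCHW16c")   # pack C=64 into 4 blocks of 16
# The inferred result type is Tensor[(1, 4, 56, 56, 16)]: the channel axis is
# split into an outer "C" dimension and an inner "16c" sub-dimension.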
@@ -3,6 +3,7 @@
  * \file attrs.cc
  */
 #include <tvm/attrs.h>
+#include <tvm/api_registry.h>
 #include "attr_functor.h"

 namespace tvm {
@@ -321,4 +322,9 @@ bool DictAttrsNode::ContentEqual(const Node* other, AttrsEqual equal) const {
   return equal(this->dict, static_cast<const DictAttrsNode*>(other)->dict);
 }

+TVM_REGISTER_API("_AttrsListFieldInfo")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    *ret = args[0].operator Attrs()->ListFieldInfo();
+  });
+
 }  // namespace tvm
@@ -185,7 +185,7 @@ class Layout : public NodeRef {
         CHECK_GT(block_size, 0);
         new_layout << block_size;
       }
-      new_layout << layout_simplified[i]->value;
+      new_layout << static_cast<char>(layout_simplified[i]->value);
     }
     return Layout(new_layout.str());
   }
@@ -241,6 +241,16 @@ class Layout : public NodeRef {
     return operator->()->layout_simplified.size();
   }

+  /*! \return number of super dimensions */
+  size_t ndim_super() const {
+    size_t ct = 0;
+    for (auto x : operator->()->layout_simplified) {
+      if (IsSuperdim(x))
+        ct++;
+    }
+    return ct;
+  }
+
   /*!
    * \brief The description of the \p i-th dimension.
    * If it is a sub-dimension, the size will be returned as well,
@@ -327,6 +337,17 @@ class Layout : public NodeRef {
     return operator->()->name == rhs->name;
   }

+  /*!
+   * \brief allow output string of layout to ostream
+   * \param os the output stream
+   * \param l the layout
+   * \return the ostream
+   */
+  friend std::ostream& operator<<(std::ostream& os, const Layout& l) {
+    os << l.name();
+    return os;
+  }
+
   using ContainerType = LayoutNode;

  private:
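To make the super/sub-dimension vocabulary concrete, here is a small NumPy sketch (independent of the Layout class above) of how an NCHW buffer maps to NCHW16c:

import numpy as np

n, c, h, w, bc = 1, 64, 56, 56, 16
x = np.random.rand(n, c, h, w).astype("float32")            # NCHW
# NCHW16c: split C into C//16 super blocks and a 16-wide sub-dimension "16c".
packed = x.reshape(n, c // bc, bc, h, w).transpose(0, 1, 3, 4, 2)
assert packed.shape == (1, 4, 56, 56, 16)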
@@ -7,11 +7,13 @@
 #include <tvm/relay/attrs/nn.h>
 #include <vector>

+#include "../../pass/alter_op_layout.h"
 #include "../layout.h"

 namespace tvm {
 namespace relay {

+// relay.nn.conv2d
 TVM_REGISTER_NODE_TYPE(Conv2DAttrs);

 bool Conv2DRel(const Array<Type>& types,
@@ -101,6 +103,20 @@ bool Conv2DRel(const Array<Type>& types,
   return true;
 }

+template<typename T>
+Array<Array<Layout> > Conv2DInferCorrectLayout(
+    const Attrs& attrs,
+    const Array<Layout>& new_in_layouts,
+    const Array<Layout>& old_in_layouts,
+    const Array<Array<IndexExpr>> &old_in_shapes) {
+  const T* params = attrs.as<T>();
+  Layout out_layout(params->out_layout);
+
+  // We always make other operators to fit the layouts of convolution layers
+  // So this inference ignores all inputs
+  return Array<Array<Layout> >{{params->data_layout, params->weight_layout},
+                               {out_layout.defined() ? out_layout : params->data_layout}};
+}
+
 // Positional relay function to create conv2d operator
 // used by frontend FFI.
@@ -156,10 +172,11 @@ with the layer input to produce a tensor of outputs.
 .add_argument("data", "Tensor", "The input tensor.")
 .add_argument("weight", "Tensor", "The weight tensor.")
 .set_support_level(2)
-.add_type_rel("Conv2D", Conv2DRel);
+.add_type_rel("Conv2D", Conv2DRel)
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout", Conv2DInferCorrectLayout<Conv2DAttrs>);


-// Conv2DTranspose
+// relay.nn.conv2d_transpose
 TVM_REGISTER_NODE_TYPE(Conv2DTransposeAttrs);

 bool Conv2DTransposeRel(const Array<Type>& types,
@@ -185,6 +202,12 @@ bool Conv2DTransposeRel(const Array<Type>& types,
       << "Conv only support kernel layouts that are convertible from OIHW."
       << " But got "<< kernel_layout;

+  Layout out_layout(param->out_layout);
+  if (!out_layout.defined()) out_layout = in_layout;
+  CHECK(out_layout.Convertible(kNCHW))
+    << "Conv only support output layouts that are convertible from NCHW."
+    << " But got " << out_layout;
+
   IndexExpr channels, dilated_ksize_y, dilated_ksize_x;

   auto dshape_nchw = ConvertLayout(data->shape, in_layout, kNCHW);
@@ -241,7 +264,7 @@ bool Conv2DTransposeRel(const Array<Type>& types,
   if (out_dtype.bits() == 0) {
     out_dtype = data->dtype;
   }
-  oshape = ConvertLayout(oshape, kNCHW, in_layout);
+  oshape = ConvertLayout(oshape, kNCHW, out_layout);
   reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
   return true;
 }
@@ -307,6 +330,8 @@ v (batch_size, channels, out_height, out_width) if `layout` is `NCHW`
 .add_argument("data", "Tensor", "The input tensor.")
 .add_argument("weight", "Tensor", "The weight tensor.")
 .set_support_level(2)
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
+                               Conv2DInferCorrectLayout<Conv2DTransposeAttrs>)
 .add_type_rel("Conv2DTranspose", Conv2DTransposeRel);

 }  // namespace relay
@@ -12,12 +12,14 @@
 #include <topi/nn/flatten.h>
 #include <vector>
 #include "../type_relations.h"
+#include "../../pass/alter_op_layout.h"
 #include "../op_common.h"
 #include "../layout.h"

 namespace tvm {
 namespace relay {

+// relay.nn.bias_add
 TVM_REGISTER_NODE_TYPE(BiasAddAttrs);

 bool BiasAddRel(const Array<Type>& types,
@@ -74,6 +76,7 @@ RELAY_REGISTER_OP("nn.bias_add")
 .add_type_rel("BiasAdd", BiasAddRel);


+// relay.nn.dense
 TVM_REGISTER_NODE_TYPE(DenseAttrs);


@@ -143,6 +146,8 @@ RELAY_REGISTER_OP("nn.dense")
 .set_support_level(1)
 .add_type_rel("Dense", DenseRel);

+// relay.leaky_relu
+TVM_REGISTER_NODE_TYPE(LeakyReluAttrs);

 // Positional relay function to create leaky relu operator used by frontend FFI.
 Expr MakeLeakyRelu(Expr data,
@@ -171,6 +176,7 @@ RELAY_REGISTER_OP("nn.leaky_relu")
 .add_argument("data", "Tensor", "Input data.")
 .set_support_level(3)
 .add_type_rel("Identity", IdentityRel)
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
 .set_attr<FTVMCompute>(
   "FTVMCompute", [](const Attrs& attrs,
                     const Array<Tensor>& inputs,
@@ -181,6 +187,7 @@ RELAY_REGISTER_OP("nn.leaky_relu")
 });


+// relay.prelu
 TVM_REGISTER_NODE_TYPE(PReluAttrs);

 bool PReluRel(const Array<Type>& types,
@@ -235,6 +242,7 @@ where :math:`*` is an channelwise multiplication for each sample in the batch.
 .add_argument("alpha", "Tensor", "Input channelwise alpha.")
 .set_support_level(3)
 .add_type_rel("PRelu", PReluRel)
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
 .set_attr<FTVMCompute>(
   "FTVMCompute", [](const Attrs& attrs,
                     const Array<Tensor>& inputs,
@@ -245,6 +253,9 @@ where :math:`*` is an channelwise multiplication for each sample in the batch.
 });


+// relay.softmax
+TVM_REGISTER_NODE_TYPE(SoftmaxAttrs);
+
 TVM_REGISTER_API("relay.op.nn._make.softmax")
 .set_body([](const TVMArgs& args, TVMRetValue* rv) {
     auto make_func = [](Expr data, int axis) {
@@ -282,6 +293,7 @@ RELAY_REGISTER_OP("nn.softmax")
 });


+// relay.nn.log_softmax
 TVM_REGISTER_API("relay.op.nn._make.log_softmax")
 .set_body([](const TVMArgs& args, TVMRetValue* rv) {
     auto make_func = [](Expr data, int axis) {
@@ -321,8 +333,7 @@ RELAY_REGISTER_OP("nn.log_softmax")
 });


-// BatchFlatten
-
+// relay.nn.batch_flatten
 bool BatchFlattenRel(const Array<Type>& types,
                      int num_inputs,
                      const Attrs& attrs,
@@ -410,6 +421,7 @@ RELAY_REGISTER_OP("nn.relu")
 .add_argument("data", "Tensor", "The input tensor.")
 .set_support_level(1)
 .add_type_rel("Identity", IdentityRel)
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
 .set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
                                          const Array<Tensor>& inputs,
                                          const Type& out_type,
@@ -460,6 +472,7 @@ centered at that value (zero padding is added where necessary).
 .set_num_inputs(1)
 .add_argument("data", "Tensor", "The input tensor.")
 .set_support_level(2)
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
 .add_type_rel("Identity", IdentityRel);


@@ -495,6 +508,7 @@ Normalizes along dimension axis using an L2 norm
 .set_num_inputs(1)
 .add_argument("data", "Tensor", "The input tensor.")
 .set_support_level(2)
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
 .add_type_rel("Identity", IdentityRel);

 // Dropout
@@ -538,6 +552,7 @@ The whole array is rescaled by ``1/(1-p)`` to keep the expected sum of the input
 .set_num_inputs(1)
 .add_argument("data", "Tensor", "Input to which dropout will be applied.")
 .set_support_level(1)
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
 .add_type_rel("Dropout", DropoutRel);

 // batch_norm
@@ -1,87 +1,88 @@
 /*!
  * Copyright (c) 2018 by Contributors
  * \file pad.cc
  * \brief Implementation of operator pad
  */
 #include <tvm/ir_operator.h>
 #include <tvm/relay/op.h>
 #include <tvm/relay/attrs/nn.h>
 #include <vector>
 #include "../layout.h"

 namespace tvm {
 namespace relay {

+// relay.nn.pad
 TVM_REGISTER_NODE_TYPE(PadAttrs);

 bool PadRel(const Array<Type>& types,
             int num_inputs,
             const Attrs& attrs,
             const TypeReporter& reporter) {
   CHECK_EQ(types.size(), 2);
   const auto* data = types[0].as<TensorTypeNode>();
   if (data == nullptr) return false;

   const PadAttrs* param = attrs.as<PadAttrs>();
   CHECK(param != nullptr);

   // check that pad widths match lengths
   CHECK(data->shape.size() == param->pad_width.size())
     << "There should be as many pad width pairs as shape dimensions "
     << "but the shape has " << data->shape.size() << " dimensions "
     << "and there are " << param->pad_width.size() << " pad width pairs.";

   // each pad width element should be a pair of positive integers
   std::vector<IndexExpr> oshape;
   for (size_t i = 0; i < param->pad_width.size(); i++) {
     CHECK(param->pad_width[i].size() == 2)
       << "Each pad width element should be a pair but at index " << i
       << " there are " << param->pad_width[i].size() << " elements.";

     auto width1 = as_const_int(param->pad_width[i][0]);
     auto width2 = as_const_int(param->pad_width[i][1]);
     CHECK(width1 != nullptr);
     CHECK(width2 != nullptr);

     CHECK(*width1 >= 0)
       << "Param width elements should be positive but first pad width at "
       << "index " << i << " is " << *width1 << ".";
     CHECK(*width2 >= 0)
       << "Param width elements should be positive but first pad width at "
       << "index " << i << " is " << *width2 << ".";

     auto padding = make_const(data->shape[i].type(), *width1 + *width2);
     oshape.push_back(data->shape[i] + padding);
   }

   reporter->Assign(types[1], TensorTypeNode::make(Array<IndexExpr>(oshape),
                                                   data->dtype));
   return true;
 }

 // Handler to create a call to the padding op used by front-end FFI
 Expr MakePad(Expr data, Array<Array<IndexExpr> > pad_width, double pad_value) {
   auto attrs = make_node<PadAttrs>();
   attrs->pad_value = pad_value;
   attrs->pad_width = std::move(pad_width);
   static const Op& op = Op::Get("nn.pad");
   return CallNode::make(op, {data}, Attrs(attrs), {});
 }

 TVM_REGISTER_API("relay.op.nn._make.pad")
 .set_body([](const TVMArgs& args, TVMRetValue* rv) {
     runtime::detail::unpack_call<Expr, 3>(MakePad, args, rv);
   });

 RELAY_REGISTER_OP("nn.pad")
 .describe(R"code(Pad for n-D tensor.

 )code" TVM_ADD_FILELINE)
 .set_attrs_type_key("relay.attrs.PadAttrs")
 .set_num_inputs(1)
 .add_argument("data", "Tensor", "The input tensor.")
 .set_support_level(2)
 .add_type_rel("Pad", PadRel);

 }  // namespace relay
 }  // namespace tvm
@@ -9,13 +9,39 @@
 #include <topi/nn/pooling.h>
 #include <vector>
 #include "../layout.h"
+#include "../../pass/alter_op_layout.h"

 namespace tvm {
 namespace relay {

+// relay.nn.max_pool2d & relay.nn.avg_pool2d
 TVM_REGISTER_NODE_TYPE(MaxPool2DAttrs);
 TVM_REGISTER_NODE_TYPE(AvgPool2DAttrs);

+template <typename T>
+Array<Array<Layout> > Pool2DInferCorrectLayout(
+    const Attrs& attrs,
+    const Array<Layout>& new_in_layouts,
+    const Array<Layout>& old_in_layouts,
+    const Array<Array<IndexExpr>> &old_in_shapes) {
+  // NOTE: Discard "const" qualifier here.
+  T *params = const_cast<T*>(attrs.as<T>());
+
+  if (new_in_layouts.defined()) {
+    CHECK_EQ(new_in_layouts.size(), 1);
+
+    Layout raw_layout(params->layout);
+    Layout input = new_in_layouts[0];
+    if (input.Indexof('W') == raw_layout.Indexof('W') &&
+        input.Indexof('H') == raw_layout.Indexof('H') &&
+        !input.Contains('w') && !input.Contains('h')) {
+      params->layout = input.name();  // modify self to follow the input layout
+    }
+  }
+
+  return Array<Array<Layout> >{{params->layout}, {params->layout}};
+}
+
 template <typename AttrType>
 bool Pool2DRel(const Array<Type>& types,
                int num_inputs,
@@ -163,6 +189,7 @@ RELAY_REGISTER_OP("nn.max_pool2d")
 .add_argument("data", "Tensor", "The input tensor.")
 .set_support_level(2)
 .add_type_rel("MaxPool2D", Pool2DRel<MaxPool2DAttrs>)
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout", Pool2DInferCorrectLayout<MaxPool2DAttrs>)
 .set_attr<FTVMCompute>("FTVMCompute", Pool2DCompute<MaxPool2DAttrs, topi::nn::kMaxPool>);


@@ -219,9 +246,10 @@ Average pooling operation for one dimensional data.
 .add_argument("data", "Tensor", "The input tensor.")
 .set_support_level(2)
 .add_type_rel("AvgPool2D", Pool2DRel<AvgPool2DAttrs>)
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout", Pool2DInferCorrectLayout<AvgPool2DAttrs>)
 .set_attr<FTVMCompute>("FTVMCompute", Pool2DCompute<AvgPool2DAttrs, topi::nn::kAvgPool>);

-// Global Pool
+// relay.nn.global_pool_2d & relay.nn.max_pool_2d
 TVM_REGISTER_NODE_TYPE(GlobalPool2DAttrs);

 bool GlobalPool2DRel(const Array<Type>& types,
@@ -247,8 +275,9 @@ bool GlobalPool2DRel(const Array<Type>& types,

   const auto hidx = layout.Indexof('H');
   const auto widx = layout.Indexof('W');
-  std::vector<IndexExpr> oshape({dshape[0], dshape[1], dshape[2], dshape[3]});
-  oshape[hidx] = oshape[widx] = 1;
+  Array<IndexExpr> oshape(dshape);
+  oshape.Set(hidx, 1);
+  oshape.Set(widx, 1);

   // assign output type
   reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
@@ -307,6 +336,8 @@ RELAY_REGISTER_OP("nn.global_avg_pool2d")
 .add_argument("data", "Tensor", "The input tensor.")
 .set_support_level(2)
 .add_type_rel("GlobalAvgPool2D", GlobalPool2DRel)
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
+                               Pool2DInferCorrectLayout<GlobalPool2DAttrs>)
 .set_attr<FTVMCompute>("FTVMCompute", GlobalPool2DCompute<topi::nn::kAvgPool>);

 // GlobalMaxPool
@@ -338,6 +369,8 @@ RELAY_REGISTER_OP("nn.global_max_pool2d")
 .add_argument("data", "Tensor", "The input tensor.")
 .set_support_level(2)
 .add_type_rel("GlobalMaxPool2D", GlobalPool2DRel)
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
+                               Pool2DInferCorrectLayout<GlobalPool2DAttrs>)
 .set_attr<FTVMCompute>("FTVMCompute", GlobalPool2DCompute<topi::nn::kMaxPool>);

 }  // namespace relay
@@ -11,6 +11,7 @@
 #include <tvm/relay/op.h>
 #include <tvm/relay/op_attr_types.h>
 #include <vector>
+#include "../pass/alter_op_layout.h"

 namespace tvm {
 namespace relay {
@@ -32,21 +33,24 @@ inline std::vector<T> AsVector(const Array<T> &array) {
  * We make the decision to always only expose positional argument.
  * We will do rewrapping in the frontend to support language
  * sugars such as keyword arguments and default value.
- *
- * \param Prefix the prefix of the registry, for example, "relay.op._make.".
- *
  * \param OpName the name of registry.
  */
-#define RELAY_REGISTER_UNARY_OP(Prefix, OpName) \
-  TVM_REGISTER_API(Prefix OpName) \
+#define RELAY_REGISTER_UNARY_OP(OpName) \
+  TVM_REGISTER_API("relay.op._make." OpName) \
   .set_body_typed<Expr(Expr)>([](Expr data) { \
       static const Op& op = Op::Get(OpName); \
     return CallNode::make(op, {data}, Attrs(), {}); \
     }); \
   RELAY_REGISTER_OP(OpName) \
   .set_num_inputs(1) \
   .add_argument("data", "Tensor", "The input tensor.") \
-  .set_attr<TOpPattern>("TOpPattern", kElemWise)
+  .add_type_rel("Identity", IdentityRel) \
+  .set_attr<TOpPattern>("TOpPattern", kElemWise) \
+  .set_attr<TOpIsStateful>("TOpIsStateful", false) \
+  .set_attr<FInferCorrectLayout>("FInferCorrectLayout", \
+                                 ElemwiseArbitraryLayout) \


 /*! Quick helper macro
  * - Expose a positional make function to construct the node.
@@ -56,12 +60,10 @@ inline std::vector<T> AsVector(const Array<T> &array) {
  * We will do rewrapping in the frontend to support language
  * sugars such as keyword arguments and default value.
  *
- * \param Prefix the prefix of the registry, for example, "relay.op._make.".
- *
  * \param OpName the name of registry.
  */
-#define RELAY_REGISTER_BINARY_OP(Prefix, OpName) \
-  TVM_REGISTER_API(Prefix OpName) \
+#define RELAY_REGISTER_BINARY_OP(OpName) \
+  TVM_REGISTER_API("relay.op._make." OpName) \
   .set_body_typed<Expr(Expr, Expr)>([](Expr lhs, Expr rhs) { \
       static const Op& op = Op::Get(OpName); \
     return CallNode::make(op, {lhs, rhs}, Attrs(), {}); \
@@ -72,7 +74,26 @@ inline std::vector<T> AsVector(const Array<T> &array) {
   .add_argument("rhs", "Tensor", "The right hand side tensor.") \
   .add_type_rel("Broadcast", BroadcastRel) \
   .set_attr<TOpPattern>("TOpPattern", kBroadcast) \
-  .set_attr<TOpIsStateful>("TOpIsStateful", false)
+  .set_attr<TOpIsStateful>("TOpIsStateful", false) \
+  .set_attr<FInferCorrectLayout>("FInferCorrectLayout", \
+                                 BinaryBroadcastLayout)
+
+// Comparisons
+#define RELAY_REGISTER_CMP_OP(OpName) \
+  TVM_REGISTER_API("relay.op._make." OpName) \
+  .set_body_typed<Expr(Expr, Expr)>([](Expr lhs, Expr rhs) { \
+      static const Op& op = Op::Get(OpName); \
+    return CallNode::make(op, {lhs, rhs}, Attrs(), {}); \
+    }); \
+  RELAY_REGISTER_OP(OpName) \
+  .set_num_inputs(2) \
+  .add_argument("lhs", "Tensor", "The left hand side tensor.") \
+  .add_argument("rhs", "Tensor", "The right hand side tensor.") \
+  .add_type_rel("BroadcastComp", BroadcastCompRel) \
+  .set_attr<TOpPattern>("TOpPattern", kBroadcast) \
+  .set_attr<TOpIsStateful>("TOpIsStateful", false) \
+  .set_attr<FInferCorrectLayout>("FInferCorrectLayout", \
+                                 BinaryBroadcastLayout)
+
 } // namespace relay
 } // namespace tvm
@@ -23,71 +23,65 @@ namespace relay {


 // Addition
-RELAY_REGISTER_BINARY_OP("relay.op._make.", "add")
+RELAY_REGISTER_BINARY_OP("add")
 .describe("Elementwise add with with broadcasting")
 .set_support_level(1)
 .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::add));

 // Subtraction
-RELAY_REGISTER_BINARY_OP("relay.op._make.", "subtract")
+RELAY_REGISTER_BINARY_OP("subtract")
 .describe("Elementwise substract with broadcasting")
 .set_support_level(1)
 .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::subtract));

 // Right shift
-RELAY_REGISTER_BINARY_OP("relay.op._make.", "right_shift")
+RELAY_REGISTER_BINARY_OP("right_shift")
 .describe("Elementwise right shift with broadcasting")
 .set_support_level(4)
 .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::right_shift));

-RELAY_REGISTER_BINARY_OP("relay.op._make.", "left_shift")
+
+RELAY_REGISTER_BINARY_OP("left_shift")
 .describe("Elementwise left shift with broadcasting")
 .set_support_level(4)
 .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::left_shift));

-RELAY_REGISTER_BINARY_OP("relay.op._make.", "maximum")
+
+RELAY_REGISTER_BINARY_OP("maximum")
 .describe("Elementwise maximum of two tensors with broadcasting")
 .set_support_level(4)
 .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::maximum));

-RELAY_REGISTER_BINARY_OP("relay.op._make.", "minimum")
+
+RELAY_REGISTER_BINARY_OP("minimum")
 .describe("Elementwise minimum of two tensors with broadcasting")
 .set_support_level(4)
 .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::minimum));

-RELAY_REGISTER_BINARY_OP("relay.op._make.", "divide")
+
+RELAY_REGISTER_BINARY_OP("divide")
 .describe("Elementwise divide with broadcasting")
 .set_support_level(1)
 .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::divide));

-RELAY_REGISTER_BINARY_OP("relay.op._make.", "multiply")
+
+RELAY_REGISTER_BINARY_OP("multiply")
 .describe("Elementwise multiply with broadcasting")
 .set_support_level(1)
 .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::multiply));

-RELAY_REGISTER_BINARY_OP("relay.op._make.", "power")
+
+RELAY_REGISTER_BINARY_OP("power")
 .describe("Elementwise power with broadcasting")
 .set_support_level(4)
 .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::power));

-RELAY_REGISTER_BINARY_OP("relay.op._make.", "mod")
+
+RELAY_REGISTER_BINARY_OP("mod")
 .describe("Elementwise mod with broadcasting")
 .set_support_level(1)
 .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::mod));

-// Comparisons
-#define RELAY_REGISTER_CMP_OP(OpName) \
-  TVM_REGISTER_API("relay.op._make." OpName) \
-  .set_body_typed<Expr(Expr, Expr)>([](Expr lhs, Expr rhs) { \
-      static const Op& op = Op::Get(OpName); \
-    return CallNode::make(op, {lhs, rhs}, Attrs(), {}); \
-    }); \
-  RELAY_REGISTER_OP(OpName) \
-  .set_num_inputs(2) \
-  .add_argument("lhs", "Tensor", "The left hand side tensor.") \
-  .add_argument("rhs", "Tensor", "The right hand side tensor.") \
-  .add_type_rel("BroadcastComp", BroadcastCompRel) \
-  .set_attr<TOpPattern>("TOpPattern", kBroadcast)
-
 
 RELAY_REGISTER_CMP_OP("equal")
 .describe("Elementwise equal compare with broadcasting")
@@ -11,9 +11,12 @@
#include <topi/elemwise.h>
#include <topi/broadcast.h>
#include <topi/reduction.h>
+#include <topi/nn.h>
#include <vector>
#include "../op_common.h"
#include "../../../arithmetic/compute_expr.h"
+#include "../../pass/alter_op_layout.h"
+#include "../layout.h"

namespace tvm {
namespace relay {
@@ -156,6 +159,7 @@ RELAY_REGISTER_OP("expand_dims")
.set_attr<FTVMCompute>("FTVMCompute", ExpandDimsCompute)
.set_attr<TOpPattern>("TOpPattern", kBroadcast);

+// relay.concatenate
TVM_REGISTER_NODE_TYPE(ConcatenateAttrs);

bool ConcatenateRel(const Array<Type>& types,
@@ -201,6 +205,42 @@ bool ConcatenateRel(const Array<Type>& types,
  return true;
}

+Array<Array<Layout>> ConcatenateLayout(
+    const Attrs& attrs,
+    const Array<Layout>& new_in_layouts,
+    const Array<Layout>& old_in_layouts,
+    const Array<Array<IndexExpr>> &old_in_shapes) {
+  const ConcatenateAttrs* param = attrs.as<ConcatenateAttrs>();
+
+  size_t axis = param->axis < 0 ? param->axis + old_in_shapes[0].size() :
+                static_cast<size_t>(param->axis);
+
+  Layout ret;
+  if (new_in_layouts.defined()) {  // this function is called after some operators are altered
+    Layout::LayoutDim concate_dim = old_in_layouts[0][axis];
+    for (size_t i = 0; i < new_in_layouts.size(); ++i) {
+      if (new_in_layouts[i].ndim() > axis &&
+          new_in_layouts[i][axis] == concate_dim) {
+        ret = new_in_layouts[i];
+        break;
+      }
+    }
+  } else {  // this function is called on the original correct relay ir
+    for (size_t i = 0; i < old_in_layouts.size(); ++i) {
+      if (old_in_layouts[i].defined()) {
+        ret = old_in_layouts[i];
+        break;
+      }
+    }
+
+    if (ret.ndim() <= axis || Layout::IsSubdim(ret[axis])) {
+      return Array<Array<Layout> > {{Layout::Undef()}, {Layout::Undef()}};
+    }
+  }
+
+  return Array<Array<Layout> > {Array<Layout>(old_in_layouts.size(), ret), {ret}};
+}
+
Expr MakeConcatenate(Expr data,
                     int axis) {
  auto attrs = make_node<ConcatenateAttrs>();
@@ -226,7 +266,8 @@ RELAY_REGISTER_OP("concatenate")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input list of tensors.")
.set_support_level(1)
-.add_type_rel("Concatenate", ConcatenateRel);
+.add_type_rel("Concatenate", ConcatenateRel)
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConcatenateLayout);

/* relay.transpose */
TVM_REGISTER_NODE_TYPE(TransposeAttrs);
@@ -323,7 +364,6 @@ RELAY_REGISTER_OP("transpose")
.set_attr<TOpPattern>("TOpPattern", kInjective);

/* relay.reshape */

TVM_REGISTER_NODE_TYPE(ReshapeAttrs);

bool ReshapeRel(const Array<Type>& types,
@@ -1252,7 +1292,7 @@ Examples::
.set_attr<TOpPattern>("TOpPattern", kInjective);


-// Split
+// relay.split
TVM_REGISTER_NODE_TYPE(SplitAttrs);

bool SplitRel(const Array<Type>& types,
@@ -1367,6 +1407,7 @@ the entries indicate where along axis the array is split.
.set_attr<TOpPattern>("TOpPattern", kInjective);


+// relay.slice_like
TVM_REGISTER_NODE_TYPE(SliceLikeAttrs);

/*!
@@ -1513,5 +1554,104 @@ RELAY_REGISTER_OP("slice_like")
.set_attr<FTVMCompute>("FTVMCompute", SliceLikeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

+// relay.layout_transform
+Array<Tensor> LayoutTransformCompute(const Attrs& attrs,
+                                     const Array<Tensor>& inputs,
+                                     const Type& out_type,
+                                     const Target& target) {
+  const LayoutTransformAttrs *param = attrs.as<LayoutTransformAttrs>();
+  CHECK(param != nullptr);
+
+  Layout src_layout(param->src_layout);
+  Layout dst_layout(param->dst_layout);
+
+  if (src_layout.Equals(dst_layout)) {
+    return Array<Tensor>{ inputs[0] };
+  }
+
+  CHECK(src_layout.defined() && dst_layout.defined())
+    << "cannot convert from/to undefined layout";
+  CHECK(src_layout.Convertible(dst_layout))
+    << "cannot convert from " << param->src_layout << " to " << param->dst_layout;
+
+  const auto& out_shape = ConvertLayout(inputs[0]->shape, src_layout, dst_layout);
+  return Array<Tensor> {
+    topi::layout_transform(inputs[0], out_shape, [&](const Array<tvm::Var>& dst_indices) {
+      std::vector<tvm::Expr> dst_to_src_indices;
+      for (size_t i = 0; i < src_layout.ndim(); ++i) {
+        Layout::LayoutDim src_axis = src_layout[i];
+        int dst_major_pos = dst_layout.Indexof(Layout::ToSuperdim(src_axis));
+        int dst_minor_pos = dst_layout.Indexof(Layout::ToSubdim(src_axis));
+        int32_t src_factor = static_cast<int32_t>(src_layout.Subsizeof(src_axis));
+        int32_t dst_factor = static_cast<int32_t>(dst_layout.Subsizeof(src_axis));
+
+        tvm::Expr src_index(dst_indices[dst_major_pos]);
+        if (dst_minor_pos >= 0) {
+          CHECK_GT(dst_factor, 0);
+          src_index = src_index * dst_factor + dst_indices[dst_minor_pos];
+        }
+        if (Layout::IsSuperdim(src_axis) && src_factor > 0) {
+          src_index = src_index / src_factor;
+        } else if (Layout::IsSubdim(src_axis) && src_factor > 0) {
+          src_index = src_index % src_factor;
+        }
+        dst_to_src_indices.push_back(src_index);
+      }
+      return Array<tvm::Expr>(dst_to_src_indices);
+    })
+  };
+}
+
+bool LayoutTransformRel(const Array<Type>& types,
+                        int num_inputs,
+                        const Attrs& attrs,
+                        const TypeReporter& reporter) {
+  const auto* data = types[0].as<TensorTypeNode>();
+  CHECK(data != nullptr);
+  const LayoutTransformAttrs* params = attrs.as<LayoutTransformAttrs>();
+
+  Layout src_layout(params->src_layout);
+  Layout dst_layout(params->dst_layout);
+
+  CHECK(src_layout.defined() && dst_layout.defined())
+    << "cannot convert from/to undefined layout";
+  CHECK(src_layout.Convertible(dst_layout))
+    << "cannot convert from " << params->src_layout << " to " << params->dst_layout;
+
+  const auto& out_shape = ConvertLayout(data->shape, src_layout, dst_layout);
+  reporter->Assign(types[1], TensorTypeNode::make(out_shape, data->dtype));
+  return true;
+}
+
+Expr MakeLayoutTransform(Expr data,
+                         std::string src_layout,
+                         std::string dst_layout) {
+  auto attrs = make_node<LayoutTransformAttrs>();
+  attrs->src_layout = std::move(src_layout);
+  attrs->dst_layout = std::move(dst_layout);
+  static const Op& op = Op::Get("layout_transform");
+  return CallNode::make(op, {data}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_API("relay.op._make.layout_transform")
+.set_body([](const TVMArgs& args, TVMRetValue* rv) {
+  runtime::detail::unpack_call<Expr, 3>(MakeLayoutTransform, args, rv);
+});
+
+RELAY_REGISTER_OP("layout_transform")
+.describe(R"code(Transform the input data layout.
+
+For transforming from NCHW to N16cHWC, the `__layout_transform__` operator reshapes
+the input array by output[n, c, h, w, C] = data[n, C*16+c, h, w]
+
+)code" TVM_ADD_FILELINE)
+.set_attrs_type_key("relay.attrs.LayoutTransformAttrs")
+.set_num_inputs(1)
+.add_argument("data", "Tensor", "The input tensor.")
+.add_type_rel("layout_transform", LayoutTransformRel)
+.set_support_level(5)
+.set_attr<FTVMCompute>("FTVMCompute", LayoutTransformCompute);
+
}  // namespace relay
}  // namespace tvm
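A minimal Python sketch of the new operator from the frontend, mirroring the calls in the tests further below (relay.layout_transform is the binding registered above; the example only builds and type-checks the expression, it does not execute it):

from tvm import relay
from tvm.relay.ir_pass import infer_type

# Split the channel axis of an NCHW tensor into blocks of 16: NCHW -> NCHW16c.
x = relay.var("x", shape=(1, 64, 56, 56))
y = relay.layout_transform(x, "NCHW", "NCHW16c")
func = infer_type(relay.Function([x], y))
# LayoutTransformRel infers Tensor[(1, 4, 56, 56, 16), float32]:
# the 64 channels become 4 outer blocks of 16 inner channels.
print(func.ret_type)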
@@ -22,7 +22,7 @@ namespace relay {
  } \

-RELAY_REGISTER_UNARY_OP("relay.op._make.", "log")
+RELAY_REGISTER_UNARY_OP("log")
.describe(R"code(Returns the log input array, computed element-wise.

.. math::
@@ -30,11 +30,10 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "log")

)code" TVM_ADD_FILELINE)
.set_support_level(1)
-.add_type_rel("Identity", IdentityRel)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log));


-RELAY_REGISTER_UNARY_OP("relay.op._make.", "exp")
+RELAY_REGISTER_UNARY_OP("exp")
.describe(R"code(Returns the exp input array, computed element-wise.

.. math::
@@ -42,36 +41,30 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "exp")

)code" TVM_ADD_FILELINE)
.set_support_level(1)
-.add_type_rel("Identity", IdentityRel)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::exp));

-RELAY_REGISTER_UNARY_OP("relay.op._make.", "sqrt")
+RELAY_REGISTER_UNARY_OP("sqrt")
.describe(R"code(Returns the sqrt input array, computed element-wise.

.. math::
  sqrt(x)

)code" TVM_ADD_FILELINE)
.set_support_level(1)
-.add_type_rel("Identity", IdentityRel)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sqrt));


-RELAY_REGISTER_UNARY_OP("relay.op._make.", "zeros_like")
+RELAY_REGISTER_UNARY_OP("zeros_like")
.describe(R"code(Returns an array of zeros, with same type and shape as the input.
)code" TVM_ADD_FILELINE)
-.set_support_level(1)
-.add_type_rel("Identity", IdentityRel);
+.set_support_level(4);

-RELAY_REGISTER_UNARY_OP("relay.op._make.", "ones_like")
+RELAY_REGISTER_UNARY_OP("ones_like")
.describe(R"code(Returns an array of ones, with same type and shape as the input.
)code" TVM_ADD_FILELINE)
-.set_support_level(1)
-.add_type_rel("Identity", IdentityRel);
+.set_support_level(4);

-RELAY_REGISTER_UNARY_OP("relay.op._make.", "sigmoid")
+RELAY_REGISTER_UNARY_OP("sigmoid")
.describe(R"code(Returns the sigmoid input array, computed element-wise.

.. math::
@@ -79,48 +72,47 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "sigmoid")

)code" TVM_ADD_FILELINE)
.set_support_level(1)
-.add_type_rel("Identity", IdentityRel)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sigmoid));


-RELAY_REGISTER_UNARY_OP("relay.op._make.", "copy")
+RELAY_REGISTER_UNARY_OP("copy")
.describe(R"code(Copy a tensor.
)code" TVM_ADD_FILELINE)
.set_support_level(3)
-.add_type_rel("Identity", IdentityRel)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::identity));

// relay.clip
TVM_REGISTER_NODE_TYPE(ClipAttrs);

TVM_REGISTER_API("relay.op._make.clip")
.set_body_typed<Expr(Expr, double, double)>([](Expr a, double a_min, double a_max) {
    auto attrs = make_node<ClipAttrs>();
    attrs->a_min = a_min;
    attrs->a_max = a_max;
    static const Op& op = Op::Get("clip");
    return CallNode::make(op, {a}, Attrs(attrs), {});
  });

RELAY_REGISTER_OP("clip")
.describe(R"code(Clip tensor values.
This function takes a tensor, a minimum value `a_min`, and a maximum value `a_max`, and returns a clipped tensor where all values below `a_min` are set to `a_min` and all values above `a_max` are set to `a_max`. `a_min` and `a_max` are cast to the tensor's dtype.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
-.add_argument("tensor", "Tensor", "The input tensor.")
-.set_support_level(3)
-.add_type_rel("Clip", IdentityRel);
+.add_argument("data", "Tensor", "The input tensor.")
+.add_type_rel("Identity", IdentityRel)
+.set_attr<TOpPattern>("TOpPattern", kElemWise)
+.set_attr<TOpIsStateful>("TOpIsStateful", false)
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
+.set_support_level(3);

-RELAY_REGISTER_UNARY_OP("relay.op._make.", "floor")
+RELAY_REGISTER_UNARY_OP("floor")
.describe(R"code(Returns the floor of input array, computed element-wise.
)code" TVM_ADD_FILELINE)
.set_support_level(3)
-.add_type_rel("Identity", IdentityRel)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::floor));


-RELAY_REGISTER_UNARY_OP("relay.op._make.", "ceil")
+RELAY_REGISTER_UNARY_OP("ceil")
.describe(R"code(Returns the ceil of input array, computed element-wise.

.. math::
@@ -128,11 +120,10 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "ceil")

)code" TVM_ADD_FILELINE)
.set_support_level(3)
-.add_type_rel("Identity", IdentityRel)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::ceil));


-RELAY_REGISTER_UNARY_OP("relay.op._make.", "trunc")
+RELAY_REGISTER_UNARY_OP("trunc")
.describe(R"code(Returns the trunc of input array, computed element-wise.

.. math::
@@ -140,11 +131,9 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "trunc")

)code" TVM_ADD_FILELINE)
.set_support_level(3)
-.add_type_rel("Identity", IdentityRel)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::trunc));

-RELAY_REGISTER_UNARY_OP("relay.op._make.", "round")
+RELAY_REGISTER_UNARY_OP("round")
.describe(R"code(Returns the round of input array, computed element-wise.

.. math::
@@ -152,11 +141,10 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "round")

)code" TVM_ADD_FILELINE)
.set_support_level(3)
-.add_type_rel("Identity", IdentityRel)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::round));


-RELAY_REGISTER_UNARY_OP("relay.op._make.", "abs")
+RELAY_REGISTER_UNARY_OP("abs")
.describe(R"code(Returns the abs of input array, computed element-wise.

.. math::
@@ -164,11 +152,10 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "abs")

)code" TVM_ADD_FILELINE)
.set_support_level(3)
-.add_type_rel("Identity", IdentityRel)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::abs));


-RELAY_REGISTER_UNARY_OP("relay.op._make.", "tanh")
+RELAY_REGISTER_UNARY_OP("tanh")
.describe(R"code(Returns the tanh of input array, computed element-wise.

.. math::
@@ -176,11 +163,10 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "tanh")

)code" TVM_ADD_FILELINE)
.set_support_level(1)
-.add_type_rel("Identity", IdentityRel)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::tanh));


-RELAY_REGISTER_UNARY_OP("relay.op._make.", "negative")
+RELAY_REGISTER_UNARY_OP("negative")
.describe(R"code(Returns the numeric negative of input array, computed element-wise.

.. math::
@@ -188,7 +174,6 @@ RELAY_REGISTER_UNARY_OP("relay.op._make.", "negative")

)code" TVM_ADD_FILELINE)
.set_support_level(3)
-.add_type_rel("Identity", IdentityRel)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::negative));

}  // namespace relay
@@ -0,0 +1,312 @@
/*!
 * Copyright (c) 2018 by Contributors
 * \file alter_op_layout.cc
 * \brief Alter the layouts of operators or replace primitive operators with
          other expressions. This pass can be used for computing convolution in
          custom layouts or other general weight pre-transformation.
 */
#include <tvm/relay/pass.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/tvm.h>
#include <tuple>
#include <vector>
#include <functional>
#include <string>

#include "alter_op_layout.h"

namespace tvm {
namespace relay {

namespace alter_op_layout {

// Make a transform CallNode
Expr TransformLayout(Expr raw, Layout src_layout, Layout dst_layout) {
  if (src_layout.Equals(dst_layout)) { return raw; }
  CHECK(src_layout.defined() && dst_layout.defined())
    << "Cannot insert layout transform because there are undefined layouts";
  CHECK(src_layout.Convertible(dst_layout))
    << "Cannot insert layout transform because there are inconvertible layouts: "
    << src_layout << " v.s. " << dst_layout;
  static auto &transform_op = Op::Get("layout_transform");
  NodePtr<LayoutTransformAttrs> attrs = make_node<LayoutTransformAttrs>();
  attrs->src_layout = src_layout.name();
  attrs->dst_layout = dst_layout.name();
  Call transform = CallNode::make(transform_op, {raw}, Attrs{attrs});
  return transform;
}

// Memorize layout transforms so we can reuse internal transformed nodes
class TransformMemorizerNode : public Node {
 public:
  // map from (Expr, src_layout, dst_layout) to transformed Expr
  using TransformKey = std::tuple<const Node*, std::string, std::string>;
  struct key_hash : public std::unary_function<TransformKey, std::size_t> {
    std::size_t operator()(const TransformKey& k) const {
      return dmlc::HashCombine<std::string>(dmlc::HashCombine<std::string>(
        std::hash<const Node*>()(std::get<0>(k)), std::get<1>(k)), (std::get<2>(k)));
    }
  };

  std::unordered_map<TransformKey, Expr, key_hash> memo;
  static constexpr const char *_type_key = "relay.alter_op_layout.TransformMemorizerNode";
  TVM_DECLARE_NODE_TYPE_INFO(TransformMemorizerNode, Node);
};

class TransformMemorizer : public NodeRef {
 public:
  TransformMemorizer() {}
  explicit TransformMemorizer(NodePtr<Node> n) : NodeRef(n) {}

  TransformMemorizerNode* operator->() {
    return static_cast<TransformMemorizerNode*>(node_.get());
  }

  // Transform layout with memorizer
  Expr Transform(Expr raw, const Layout& src_layout, const Layout& dst_layout) {
    if (src_layout.Equals(dst_layout)) { return raw; }

    std::tuple<const Node*, std::string, std::string> key =
      std::make_tuple<>(raw.get(), src_layout.name(), dst_layout.name());
    auto& memo = operator->()->memo;

    auto iter = memo.find(key);
    if (iter != memo.end()) {
      return iter->second;
    } else {
      Expr transform = TransformLayout(raw, src_layout, dst_layout);
      memo[key] = transform;
      return transform;
    }
  }

  using ContainerType = TransformMemorizerNode;
};

// TempExprNode during layout transform
// Instances of this expr will be Realized to normal exprs ultimately
class LayoutAlternatedExprNode : public TempExprNode {
 public:
  Expr value;
  Layout old_layout;
  Layout new_layout;
  TransformMemorizer memorizer;

  Expr Realize() const final {
    // NOTE: use a copy to discard the "const" qualifier
    TransformMemorizer tmp_memorizer = memorizer;
    // fall back to the old layout
    return tmp_memorizer.Transform(value, new_layout, old_layout);
  }

  void VisitAttrs(AttrVisitor *v) final {
    v->Visit("value", &value);
    v->Visit("old_layout", &old_layout);
    v->Visit("new_layout", &new_layout);
  }

  static constexpr const char *_type_key = "relay.alter_op_layout.LayoutAlternatedExprNode";
  TVM_DECLARE_NODE_TYPE_INFO(LayoutAlternatedExprNode, TempExprNode);
};

RELAY_DEFINE_NODE_REF(LayoutAlternatedExpr, LayoutAlternatedExprNode, TempExpr);

// Call registered FInferCorrectLayout of an op.
// Parameters are the same as the parameters for FInferCorrectLayout
// Returns inferred_input_layout, inferred_output_layout, success
std::tuple<Array<Layout>, Array<Layout>, bool> CallInfer(
    const Call& call,
    const Array<Layout>& new_in_layouts,
    const Array<Layout>& old_in_layouts,
    const Array<Array<IndexExpr> > &old_in_shapes) {
  static auto finfer_layout = Op::GetAttr<FInferCorrectLayout>("FInferCorrectLayout");

  Op op = Downcast<Op>(call->op);
  if (finfer_layout.count(op)) {
    Array<Array<Layout> > inferred_layouts;
    inferred_layouts = finfer_layout[op](call->attrs, new_in_layouts,
                                         old_in_layouts, old_in_shapes);
    CHECK_EQ(inferred_layouts.size(), 2)
      << "FInferCorrectLayout should return an array with size of 2";
    for (auto x : inferred_layouts) {
      for (auto y : x) {
        if (!y.defined()) {  // inference fails
          return std::make_tuple<>(Array<Layout>(nullptr), Array<Layout>(nullptr), false);
        }
      }
    }
    return std::make_tuple<>(inferred_layouts[0], inferred_layouts[1], true);
  } else {
    return std::make_tuple<>(Array<Layout>(nullptr), Array<Layout>(nullptr), false);
  }
}

// Call registered FTVMAlterOpLayout of an op
// Returns the altered expression
Call CallAlter(const Call& ref_call,
               const std::vector<Expr>& new_args) {
  static auto falter_layout = Op::GetAttr<FTVMAlterOpLayout>("FTVMAlterOpLayout");
  Op op = Downcast<Op>(ref_call->op);

  Expr new_e;
  bool modified = false;
  if (falter_layout.count(op)) {
    tvm::Array<tvm::Tensor> tinfos;
    for (auto expr : ref_call->args) {
      auto ttype = expr->type_as<TensorTypeNode>();
      tinfos.push_back(tvm::placeholder(ttype->shape, ttype->dtype));
    }
    Expr altered_value = falter_layout[op](ref_call->attrs, new_args, tinfos);
    if (altered_value.defined()) {
      new_e = altered_value;
      modified = true;
    }
  }
  if (!modified) {
    new_e = CallNode::make(ref_call->op, new_args,
                           ref_call->attrs, ref_call->type_args);
  }

  const CallNode *new_call = new_e.as<CallNode>();
  CHECK(new_call) << "Can only replace the original operator with another call node";
  return GetRef<Call>(new_call);
}

Expr AlterOpLayoutRewrite(const Call &ref_call,
                          const Array<Expr> &new_args,
                          const NodeRef& ctx) {
  std::vector<LayoutAlternatedExpr> inputs;
  std::vector<Expr> normal_new_args;
  Array<Array<IndexExpr> > input_shapes;

  // NOTE: discard the "const" qualifier
  TransformMemorizer memorizer = Downcast<TransformMemorizer>(ctx);

  // fill incomplete state and expand tuples
  for (auto new_arg : new_args) {
    auto push_back_one_arg = [&](Expr arg) {
      // We always expect LayoutAlternatedExpr.
      // This is used to convert a normal Expr to LayoutAlternatedExpr.
      if (const LayoutAlternatedExprNode *inp = arg.as<LayoutAlternatedExprNode>()) {
        inputs.push_back(GetRef<LayoutAlternatedExpr>(inp));
        normal_new_args.push_back(inp->value);
      } else {
        auto inode = make_node<LayoutAlternatedExprNode>();
        inode->value = arg;
        inode->memorizer = memorizer;
        inputs.push_back(LayoutAlternatedExpr(inode));
        normal_new_args.push_back(arg);
      }
    };

    if (new_arg->is_type<TupleNode>()) {
      Tuple tuple_new_arg = Downcast<Tuple>(new_arg);
      for (auto x : tuple_new_arg->fields) {
        push_back_one_arg(x);
      }
    } else {
      push_back_one_arg(new_arg);
    }
  }

  // old_in, new_in = state[inputs]
  Array<Layout> old_in, old_out, new_in, new_out, new_in2;
  for (auto inp : inputs) {
    old_in.push_back(inp->old_layout);
    new_in.push_back(inp->new_layout);
  }

  for (auto arg : ref_call->args) {
    if (arg->is_type<TupleNode>()) {  // expand tuple
      Tuple tuple_arg = Downcast<Tuple>(arg);
      for (auto x : tuple_arg->fields) {
        input_shapes.push_back(x->type_as<TensorTypeNode>()->shape);
      }
    } else {
      input_shapes.push_back(arg->type_as<TensorTypeNode>()->shape);
    }
  }

  // old_in, old_out = op.infer(old_in)
  bool success = false;
  std::tie(old_in, old_out, success) = CallInfer(ref_call,
                                                 Array<Layout>(nullptr),
                                                 old_in, input_shapes);
  if (!success) { return Expr(nullptr); }
  CHECK_EQ(old_in.size(), new_in.size());

  // if new_in == 'undef': new_in = old_in
  for (size_t i = 0; i < new_in.size(); ++i) {
    if (!new_in[i].defined()) {
      new_in.Set(i, old_in[i]);
    }
  }

  // new_op = alter(op)
  Call new_call = CallAlter(ref_call, normal_new_args);

  // new_in2, new_out = op.infer(new_in)
  if (new_call->op->is_type<OpNode>()) {
    success = false;
    std::tie(new_in2, new_out, success) = CallInfer(new_call, new_in, old_in, input_shapes);
    if (!success) { return Expr(nullptr); }
  } else {
    return Expr(nullptr);
  }

  CHECK_EQ(new_out.size(), old_out.size())
    << "The number of output nodes should keep the same during alter_op_layout";
  CHECK_EQ(new_in.size(), new_in2.size())
    << "The number of input nodes should keep the same during alter_op_layout";

  // if (new_in != new_in2): insert transform (new_in -> new_in2)
  Array<Expr> transformed_args;
  for (size_t i = 0; i < inputs.size(); ++i) {
    transformed_args.push_back(memorizer.Transform(new_call->args[i], new_in[i], new_in2[i]));
  }

  // state[node] = (old_out, new_out)
  CHECK(ref_call->checked_type_.defined())
    << "Call infer_type pass before alter_op_layout pass";

  if (ref_call->checked_type()->is_type<TupleTypeNode>()) {
    Expr tuple_output = CallNode::make(new_call->op, transformed_args,
                                       new_call->attrs, new_call->type_args);
    Array<Expr> fields;
    for (size_t i = 0; i < new_out.size(); ++i) {
      auto rnode = make_node<LayoutAlternatedExprNode>();
      rnode->value = TupleGetItemNode::make(tuple_output, i);
      rnode->old_layout = old_out[i];
      rnode->new_layout = new_out[i];
      rnode->memorizer = memorizer;
      fields.push_back(Expr(rnode));
    }
    return TupleNode::make(fields);
  } else {
    auto rnode = make_node<LayoutAlternatedExprNode>();
    CHECK_EQ(new_out.size(), 1);
    rnode->value = CallNode::make(new_call->op, transformed_args,
                                  new_call->attrs, new_call->type_args);
    rnode->old_layout = old_out[0];
    rnode->new_layout = new_out[0];
    rnode->memorizer = memorizer;
    return Expr(rnode);
  }
}

TVM_REGISTER_API("relay._ir_pass.AlterOpLayout")
.set_body([](TVMArgs args, TVMRetValue *ret) {
  TransformMemorizer transformMemorizer(make_node<TransformMemorizerNode>());
  auto fcontext = [&](const Call& call) -> NodeRef {
    return transformMemorizer;
  };

  *ret = ForwardRewrite(args[0], AlterOpLayoutRewrite, fcontext);
});

}  // namespace alter_op_layout

}  // namespace relay
}  // namespace tvm
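A sketch of driving this pass from Python, as the tests later in this commit do (infer_type must run first because AlterOpLayoutRewrite requires checked_type_ on every call node; the level value used here is an arbitrary choice for illustration):

from tvm import relay
from tvm.relay.op import register_alter_op_layout
from tvm.relay.ir_pass import infer_type, alter_op_layout

# FTVMAlterOpLayout hook: ask conv2d to take its data in the blocked NCHW16c layout.
@register_alter_op_layout("nn.conv2d", level=110)
def alter_conv2d(attrs, inputs, tinfos):
    data, weight = inputs
    new_attrs = dict(attrs)
    new_attrs["data_layout"] = "NCHW16c"
    return relay.nn.conv2d(data, weight, **new_attrs)

x = relay.var("x", shape=(1, 64, 56, 56))
w = relay.var("w", shape=(64, 64, 3, 3))
y = relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1))
func = relay.Function([x, w], y)

func = infer_type(func)       # fills checked_type_, required by the rewrite
func = alter_op_layout(func)  # inserts layout_transform nodes around the conv2d
print(infer_type(func))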
@@ -0,0 +1,119 @@
/*!
 * Copyright (c) 2018 by Contributors
 * \file alter_op_layout.h
 * \brief Alter the layouts of operators or replace primitive operators with
          other expressions. This pass can be used for computing convolution in
          custom layouts or other general weight pre-transformation.
 */

#ifndef TVM_RELAY_PASS_ALTER_OP_LAYOUT_H_
#define TVM_RELAY_PASS_ALTER_OP_LAYOUT_H_

#include <tvm/relay/expr.h>

#include "../op/layout.h"

namespace tvm {
namespace relay {

/*!
 * \brief Infer & correct function of node layout. See \p Layout for the layout convention.
 * \param attrs The attributes of the node.
 * \param new_in_layouts The layouts of input arguments after alter_op_layout.
 *                       This can be undefined, which means we call this function before altering
 *                       any operators.
 * \param old_in_layouts The layouts of input arguments before alter_op_layout.
 * \param old_in_shapes The shapes of old input arguments.
 * \return inferred_layout An array of two elements that are the inferred input layouts and
 *                         inferred output layouts.
 */
using FInferCorrectLayout = runtime::TypedPackedFunc<
    Array<Array<Layout>>(const Attrs& attrs,
                         const Array<Layout>& new_in_layouts,
                         const Array<Layout>& old_in_layouts,
                         const Array<Array<IndexExpr>> &old_in_shapes)>;

/*! \brief take arbitrary input layout and copy to output */
inline Array<Array<Layout> > ElemwiseArbitraryLayout(const Attrs& attrs,
                                                     const Array<Layout>& new_in_layouts,
                                                     const Array<Layout>& old_in_layouts,
                                                     const Array<Array<IndexExpr>> &old_in_shapes) {
  Layout ret;

  if (new_in_layouts.defined()) {
    CHECK_GE(new_in_layouts.size(), 1);
    ret = new_in_layouts[0];
  } else {
    for (size_t i = 0; i < old_in_layouts.size(); ++i) {
      if (old_in_layouts[i].defined()) {
        ret = old_in_layouts[i];
        break;
      }
    }
  }

  return Array<Array<Layout> >{Array<Layout>(old_in_layouts.size(), ret), {ret}};
}

/*! \brief Infer layout for binary broadcast operators */
inline Array<Array<Layout> > BinaryBroadcastLayout(const Attrs& attrs,
                                                   const Array<Layout>& new_in_layouts,
                                                   const Array<Layout>& old_in_layouts,
                                                   const Array<Array<IndexExpr>> &old_in_shapes) {
  Array<Layout> layouts;

  if (new_in_layouts.defined()) {
    layouts.assign(new_in_layouts.begin(), new_in_layouts.end());
  } else {
    layouts.assign(old_in_layouts.begin(), old_in_layouts.end());
  }

  if (!layouts[0].defined() && !layouts[1].defined()) {
    // both undefined, inference fails
    return Array<Array<Layout> > {{Layout::Undef()}, {Layout::Undef()}};
  } else if (!layouts[0].defined() || !layouts[1].defined()) {
    // only one is defined, use shape information to help infer
    int defined_idx = layouts[0].defined() ? 0 : 1;
    int undef_idx = 1 - defined_idx;

    if (old_in_shapes[defined_idx].size() >= old_in_shapes[undef_idx].size()) {
      layouts.Set(undef_idx,
                  layouts[defined_idx].Sublayout(
                      old_in_shapes[defined_idx].size() - old_in_shapes[undef_idx].size(),
                      old_in_shapes[undef_idx].size()));
      return Array<Array<Layout> > {layouts, {layouts[defined_idx]}};
    } else {
      // we only know the tensor with the smaller number of dimensions,
      // so we cannot infer the final broadcasted output.
      // inference fails in this case.
      return Array<Array<Layout> > {{Layout::Undef()}, {Layout::Undef()}};
    }
  } else {
    // try to broadcast the tensors to the larger dimension
    int large_idx = layouts[0].ndim_super() >= layouts[1].ndim_super() ? 0 : 1;
    int small_idx = 1 - large_idx;
    Layout ret = layouts[large_idx];

    // extract the common part
    size_t i = layouts[large_idx].ndim();
    for (; i != 0; --i) {
      auto dim = layouts[large_idx][i-1];
      if (!layouts[small_idx].Contains(Layout::ToSuperdim(dim))) {
        break;
      }
    }

    Layout common_part = layouts[large_idx].Sublayout(i, layouts[large_idx].ndim() - i);
    if (!layouts[small_idx].Convertible(common_part)) {  // fail
      return Array<Array<Layout> > {{Layout::Undef()}, {Layout::Undef()}};
    }

    layouts.Set(small_idx, common_part);
    return Array<Array<Layout> > {layouts, {ret}};
  }
}

}  // namespace relay
}  // namespace tvm

#endif  // TVM_RELAY_PASS_ALTER_OP_LAYOUT_H_
@@ -0,0 +1,46 @@
/*!
 * Copyright (c) 2018 by Contributors
 * \file canonicalize_ops.cc
 * \brief Canonicalize special operators to basic operators.
          This can simplify later analysis (e.g. expand bias_add to expand_dims and broadcast_add).
 */
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/attrs/nn.h>
#include "pattern_util.h"

namespace tvm {
namespace relay {

class BiasAddSimplifier : public ExprMutator {
 public:
  Expr VisitExpr_(const CallNode* n) {
    static const Op& bias_add = Op::Get("nn.bias_add");
    auto new_n = ExprMutator::VisitExpr_(n);
    if (n->op.same_as(bias_add)) {
      Call call = Downcast<Call>(new_n);
      CHECK_EQ(call->args.size(), 2);
      const BiasAddAttrs* param = call->attrs.as<BiasAddAttrs>();

      auto ttype = call->args[0]->type_as<TensorTypeNode>();
      size_t n_dim = ttype->shape.size();
      Expr expanded_bias = ExpandBiasToMatchAxis(call->args[1], n_dim, {param->axis});
      Expr ret = Add(call->args[0], expanded_bias);
      ret->checked_type_ = n->checked_type_;
      return ret;
    }
    return new_n;
  }
};

Expr CanonicalizeOps(const Expr& e) {
  return BiasAddSimplifier().Mutate(e);
}

TVM_REGISTER_API("relay._ir_pass.canonicalize_ops")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = CanonicalizeOps(args[0]);
  });

}  // namespace relay
}  // namespace tvm
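A small Python sketch of what this simplifier does, using the canonicalize_ops binding registered above (the alter-layout tests below run it before alter_op_layout): nn.bias_add is expanded into expand_dims plus a broadcast add, so later passes only need to reason about ordinary broadcast operators.

from tvm import relay
from tvm.relay.ir_pass import infer_type, canonicalize_ops

x = relay.var("x", shape=(1, 64, 56, 56))
bias = relay.var("bias", shape=(64,))
y = relay.nn.bias_add(x, bias)   # bias_add broadcasts bias over axis 1 (channels)
func = infer_type(relay.Function([x, bias], y))

# After the pass the body is add(x, expand_dims(bias, axis=1, num_newaxis=2)).
func = canonicalize_ops(func)
print(func)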
@@ -29,11 +29,11 @@ using runtime::TypedPackedFunc;
// FoldScaleAxis algorithm:
//
// The general idea is to transform Expr to tuple of
-// (value, axes, scale), where the final result satiesfies:
+// (value, axes, scale), where the final result satisfies:
//
// result = value
// for i, k in enumerate(axes):
-//    k-ith dimension of result *= i-th dimension of scale
+//    k-th dimension of result *= i-th dimension of scale
//
// Then we can propagate this signal along and fold the scale if necessary.
// However, it is possible that certain scale may never be consumed
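The (value, axes, scale) invariant described in the comment above, spelled out as a small NumPy sketch (illustration only; the names here are not part of the pass, and axes is assumed to be sorted):

import numpy as np

def reconstruct(value, axes, scale):
    # scale carries one dimension per entry of axes, so that
    # scale.shape[i] == value.shape[axes[i]]; broadcasting it back over
    # those axes recovers the tensor that the triple represents.
    shape = [1] * value.ndim
    for i, k in enumerate(axes):
        shape[k] = scale.shape[i]
    return value * scale.reshape(shape)

# A (1, 4, 8, 8) feature map whose channel axis (axis 1) carries a scale.
value = np.random.randn(1, 4, 8, 8).astype("float32")
scale = np.array([1.0, 2.0, 0.5, 3.0], dtype="float32")
assert np.allclose(reconstruct(value, [1], scale),
                   value * scale.reshape(1, 4, 1, 1))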
@@ -42,13 +42,20 @@ class TempRealizer : private ExprMutator {

class ForwardRewriter : private ExprMutator {
 public:
-  ForwardRewriter(const OpMap<FForwardRewrite>& rewrite_map,
+  ForwardRewriter(const OpMap<FForwardRewrite>* rewrite_map,
                  std::function<NodeRef(const Call&)> fcontext,
                  std::function<Expr(const Expr&)> fmulti_ref_trigger)
      : rewrite_map_(rewrite_map),
        fcontext_(fcontext),
-        fmulti_ref_trigger_(fmulti_ref_trigger) {
-  }
+        fmulti_ref_trigger_(fmulti_ref_trigger) {}
+
+  ForwardRewriter(const FForwardRewrite* rewrite_func,
+                  std::function<NodeRef(const Call&)> fcontext,
+                  std::function<Expr(const Expr&)> fmulti_ref_trigger)
+      : rewrite_func_(rewrite_func),
+        fcontext_(fcontext),
+        fmulti_ref_trigger_(fmulti_ref_trigger) {}
+

  // Transform expression.
  Expr Rewrite(Expr expr) {
@@ -60,8 +67,9 @@ class ForwardRewriter : private ExprMutator {

 private:
  // The rewrite rule.
-  const OpMap<FForwardRewrite>& rewrite_map_;
+  const OpMap<FForwardRewrite>* rewrite_map_{nullptr};
+  const FForwardRewrite* rewrite_func_{nullptr};
  // The context.
  std::function<NodeRef(const Call&)> fcontext_{nullptr};
  // The multiple reference trigger
  std::function<Expr(const Expr&)> fmulti_ref_trigger_{nullptr};
@@ -104,9 +112,31 @@ class ForwardRewriter : private ExprMutator {
    }
  }

+  Expr VisitExpr_(const TupleNode* op) final {
+    tvm::Array<Expr> fields;
+    bool all_fields_unchanged = true;
+    for (auto field : op->fields) {
+      auto new_field = this->GetTempExpr(field);
+      fields.push_back(new_field);
+      all_fields_unchanged &= new_field.same_as(field);
+    }
+
+    if (all_fields_unchanged) {
+      return GetRef<Expr>(op);
+    } else {
+      return TupleNode::make(fields);
+    }
+  }
+
  Expr VisitExpr_(const CallNode* call_node) final {
    const Call& ref_call = GetRef<Call>(call_node);
-    PackedFunc frewrite = rewrite_map_.get(call_node->op, nullptr);
+    PackedFunc frewrite;
+    if (rewrite_func_) {
+      frewrite = *rewrite_func_;
+    } else {
+      CHECK(rewrite_map_);
+      frewrite = rewrite_map_->get(call_node->op, nullptr);
+    }
+
    auto new_op = this->Mutate(call_node->op);
    bool unchanged = call_node->op.same_as(new_op);
@@ -147,9 +177,16 @@ Expr ForwardRewrite(const Expr& expr,
                    std::function<NodeRef(const Call&)> fcontext,
                    std::function<Expr(const Expr&)> fmulti_ref_trigger) {
  auto rewrite_map = Op::GetAttr<FForwardRewrite>(rewrite_map_name);
-  return ForwardRewriter(rewrite_map,
-                         fcontext,
-                         fmulti_ref_trigger).Rewrite(expr);
+  return ForwardRewriter(&rewrite_map, fcontext, fmulti_ref_trigger).Rewrite(expr);
}

+Expr ForwardRewrite(const Expr& expr,
+                    const FForwardRewrite& rewrite_func,
+                    std::function<NodeRef(const Call&)> fcontext,
+                    std::function<Expr(const Expr&)> fmulti_ref_trigger) {
+  return ForwardRewriter(&rewrite_func, fcontext, fmulti_ref_trigger).Rewrite(expr);
+}
+
}  // namespace relay
}  // namespace tvm
@@ -73,7 +73,7 @@ inline bool MatchBroadcastToLeftAxes(const TensorTypeNode* tlhs,
 * the target Tensor on the specified axis via broadcasting rule.
 *
 * \param bias The bias.
- * \param target_ndim target dimension.
+ * \param target_ndim Target dimension.
 * \param axes The axis on the output we want to match on.
 */
inline Expr ExpandBiasToMatchAxis(Expr bias,
@@ -0,0 +1,316 @@
"""Test alter op layout pass"""

from tvm import relay
from tvm.relay.op import register_alter_op_layout
from tvm.relay.ir_pass import *


def test_alter_op():
    """Test directly replacing an operator with a new one"""
    def before():
        x = relay.var("x", shape=(1, 64, 56, 56))
        weight = relay.var('weight', shape=(64, 64, 3, 3))
        y = relay.nn.conv2d(x, weight,
                            channels=64,
                            kernel_size=(3, 3),
                            padding=(1, 1))
        y = relay.nn.relu(y)
        y = relay.Function([x, weight], y)
        return y

    @register_alter_op_layout("nn.conv2d", level=100)
    def alter_conv2d(attrs, inputs, tinfos):
        data, weight = inputs
        weight = relay.multiply(weight, relay.const(2.0))
        return relay.nn.conv2d(data, weight, **attrs)

    def expected():
        x = relay.var("x", shape=(1, 64, 56, 56))
        weight = relay.var('weight', shape=(64, 64, 3, 3))
        y = relay.nn.conv2d(x, relay.multiply(weight, relay.const(2.0)),
                            channels=64,
                            kernel_size=(3, 3),
                            padding=(1, 1))
        y = relay.nn.relu(y)
        y = relay.Function([x, weight], y)
        return y

    a = before()
    a = infer_type(a)
    a = alter_op_layout(a)

    b = expected()
    b = infer_type(b)

    assert(alpha_equal(a, b))


def test_alter_return_none():
    """Test doing nothing by returning None"""
    def before():
        x = relay.var("x", shape=(1, 64, 56, 56))
        y = relay.nn.global_max_pool2d(x)
        y = relay.Function([x], y)
        return y

    called = [False]

    @register_alter_op_layout("nn.global_max_pool2d", level=101)
    def alter_conv2d(attrs, inputs, tinfos):
        called[0] = True
        return None

    a = before()
    a = infer_type(a)
    a = alter_op_layout(a)

    b = before()
    b = infer_type(b)
    assert(alpha_equal(a, b))
    assert(called[0])


def test_alter_layout():
    """Test altering the layout of a conv2d.
    The layout of broadcast operators and the weight should be changed accordingly.
    """
    def before():
        x = relay.var("x", shape=(1, 64, 56, 56))
        bias = relay.var("bias")
        weight = relay.var("weight")
        y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
        y = relay.nn.bias_add(y, bias)
        # a useless tuple, which will be eliminated
        y = relay.Tuple([y])[0]
        y = relay.nn.relu(y)
        y = relay.nn.batch_flatten(y)
        y = relay.Function(free_vars(y), y)
        return y

    @register_alter_op_layout("nn.conv2d", level=102)
    def alter_conv2d(attrs, inputs, tinfos):
        data, weight = inputs
        new_attrs = dict(attrs)
        new_attrs['data_layout'] = 'NCHW16c'
        new_attrs['weight_layout'] = 'OIHW16i'
        return relay.nn.conv2d(data, weight, **new_attrs)

    def expected():
        x = relay.var("x", shape=(1, 64, 56, 56))
        bias = relay.var("bias", shape=(64,))
        weight = relay.var("weight", shape=(64, 64, 3, 3))

        y = relay.layout_transform(x, "NCHW", "NCHW16c")
        w = relay.layout_transform(weight, "OIHW", "OIHW16i")
        y = relay.nn.conv2d(y, w,
                            channels=64,
                            kernel_size=(3, 3),
                            padding=(1, 1),
                            weight_layout="OIHW16i",
                            data_layout="NCHW16c")
        b = relay.expand_dims(bias, axis=1, num_newaxis=2)
        b = relay.layout_transform(b, "CHW", "CHW16c")
        y = relay.add(y, b)

        y = relay.nn.relu(y)
        y = relay.layout_transform(y, "NCHW16c", "NCHW")
        y = relay.nn.batch_flatten(y)
        y = relay.Function(free_vars(y), y)
        return y

    a = before()
    a = infer_type(a)
    a = canonicalize_ops(a)
    a = infer_type(a)
    a = alter_op_layout(a)
    a = infer_type(a)

    b = expected()
    b = infer_type(b)

    assert(alpha_equal(a, b))


def test_alter_layout_dual_path():
    """Test altering the layout with two outputs.
    One path continues to use the new layout while the other path falls back to the old layout.
    """
    def before():
        x = relay.var("x", shape=(1, 64, 56, 56))
        weight1 = relay.var('weight1')
        weight2 = relay.var('weight2')
        y = relay.nn.conv2d(x, weight1,
                            channels=32,
                            kernel_size=(3, 3),
                            padding=(1, 1))
        y = relay.nn.relu(y)
        y1 = relay.nn.conv2d(y, weight2,
                             channels=32,
                             kernel_size=(3, 3),
                             padding=(1, 1))
        y1 = relay.nn.relu(y1)
        y2 = relay.nn.batch_flatten(y)
        ret = relay.Tuple([y1, y2])
        y = relay.Function(free_vars(ret), ret)
        return y

    @register_alter_op_layout("nn.conv2d", level=103)
    def alter_conv2d(attrs, inputs, tinfos):
        data, weight = inputs
        new_attrs = dict(attrs)
        new_attrs['data_layout'] = 'NCHW16c'
        return relay.nn.conv2d(data, weight, **new_attrs)

    def expected():
        x = relay.var("x", shape=(1, 64, 56, 56))
        weight1 = relay.var('weight1')
        weight2 = relay.var('weight2')
        y = relay.layout_transform(x, "NCHW", "NCHW16c")
        y = relay.nn.conv2d(y, weight1,
                            channels=32,
                            kernel_size=(3, 3),
                            padding=(1, 1),
                            data_layout="NCHW16c")
        y = relay.nn.relu(y)
        y1 = relay.nn.conv2d(y, weight2,
                             channels=32,
                             kernel_size=(3, 3),
                             padding=(1, 1),
                             data_layout='NCHW16c')
        y1 = relay.nn.relu(y1)
        y1 = relay.layout_transform(y1, "NCHW16c", "NCHW")
        y2 = relay.layout_transform(y, "NCHW16c", "NCHW")
        y2 = relay.nn.batch_flatten(y2)
        ret = relay.Tuple([y1, y2])
        y = relay.Function(free_vars(ret), ret)
        return y

    a = before()
    a = infer_type(a)
    a = alter_op_layout(a)
    a = infer_type(a)

    b = expected()
    b = infer_type(b)

    assert(alpha_equal(a, b))


def test_alter_layout_resnet():
    """Test altering the layout of a residual block.
    This also tests the elimination of duplicated transformations:
    if the same transformation is applied to the same node twice, only one transform is created.
    """
    def before():
        x = relay.var("x", shape=(1, 64, 56, 56))
        weight1 = relay.var('weight1')
        weight2 = relay.var('weight2')
        y = relay.nn.conv2d(x, weight1,
                            channels=32,
                            kernel_size=(3, 3),
                            padding=(1, 1))
        y = relay.nn.relu(y)
        y2 = relay.nn.conv2d(x, weight2,
                             channels=32,
                             kernel_size=(1, 1))
        y2 = relay.nn.relu(y2)
        y = y + y2
        y = relay.nn.global_max_pool2d(y)
        return relay.Function(free_vars(y), y)

    @register_alter_op_layout("nn.conv2d", level=104)
    def alter_conv2d(attrs, inputs, tinfos):
        data, weight = inputs
        new_attrs = dict(attrs)
        new_attrs['data_layout'] = 'NCHW16c'
        return relay.nn.conv2d(data, weight, **new_attrs)

    def expected():
        x = relay.var("x", shape=(1, 64, 56, 56))
        weight1 = relay.var('weight1')
        weight2 = relay.var('weight2')
        x = relay.layout_transform(x, "NCHW", "NCHW16c")
        y = relay.nn.conv2d(x, weight1,
                            channels=32,
                            kernel_size=(3, 3),
                            padding=(1, 1),
                            data_layout="NCHW16c")
        y = relay.nn.relu(y)
        y2 = relay.nn.conv2d(x, weight2,
                             channels=32,
                             kernel_size=(1, 1),
                             data_layout='NCHW16c')
        y2 = relay.nn.relu(y2)
        y = y + y2
        y = relay.nn.global_max_pool2d(y, layout="NCHW16c")
        y = relay.layout_transform(y, "NCHW16c", "NCHW")
        return relay.Function(free_vars(y), y)

    a = before()
    a = infer_type(a)
    a = alter_op_layout(a)
    a = infer_type(a)

    b = expected()
    b = infer_type(b)

    assert(alpha_equal(a, b))


def test_alter_layout_broadcast_op():
    """Test broadcast operators"""
    def before():
        x = relay.var("x", shape=(1, 64, 56, 56))
        bias = relay.var("bias", shape=(64,))
        scale = relay.var("scale", shape=(64, 1, 1))
        weight = relay.var("weight")
        y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
        y = relay.nn.bias_add(y, bias)  # test broadcasting to lhs
        y = relay.multiply(scale, y)    # test broadcasting to rhs
        y = relay.Function(free_vars(y), y)
        return y

    @register_alter_op_layout("nn.conv2d", level=102)
    def alter_conv2d(attrs, inputs, tinfos):
        data, weight = inputs
        new_attrs = dict(attrs)
        new_attrs['data_layout'] = 'NCHW16c'
        return relay.nn.conv2d(data, weight, **new_attrs)

    def expected():
        x = relay.var("x", shape=(1, 64, 56, 56))
        bias = relay.var("bias", shape=(64,))
        scale = relay.var("scale", shape=(64, 1, 1))
        weight = relay.var("weight")
        x = relay.layout_transform(x, "NCHW", "NCHW16c")
        bias = relay.expand_dims(bias, 1, 2)
        bias = relay.layout_transform(bias, "CHW", "CHW16c")
        scale = relay.layout_transform(scale, "CHW", "CHW16c")
        y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1),
                            data_layout="NCHW16c")
        y = relay.add(y, bias)          # test broadcasting to lhs
        y = relay.multiply(scale, y)    # test broadcasting to rhs
        y = relay.layout_transform(y, "NCHW16c", "NCHW")
        y = relay.Function(free_vars(y), y)
        return y

    a = before()
    a = infer_type(a)
    a = canonicalize_ops(a)
    a = infer_type(a)
    a = alter_op_layout(a)
    a = infer_type(a)

    b = expected()
    b = infer_type(b)

    assert(alpha_equal(a, b))


if __name__ == "__main__":
    test_alter_op()
    test_alter_return_none()
    test_alter_layout()
    test_alter_layout_dual_path()
    test_alter_layout_resnet()
    test_alter_layout_broadcast_op()
@@ -448,6 +448,7 @@ inline tvm::Tensor group_conv2d_ngchw(const tvm::Tensor& I,
}

using FLayoutIndicesTransform = std::function<Array<Expr>(const Array<Var>& indices)>;

/*!
 * \brief Transform the layout according to the mapping function \p to_src_indices.
 * \param src the source input.