[TOP] Add dense, batchnorm (#22)
* [TOP] Add dense, batchnorm * update tvm
This commit is contained in:
Родитель
a2ab3d83a7
Коммит
215693df04
|
@ -44,11 +44,14 @@ using TOpPattern = int;
|
|||
* \brief Computation description interface
|
||||
* \param attrs The attribute of the node.
|
||||
* \param inputs The input tensors(placeholders)
|
||||
* \param out_info Tensors holding shape/type information about output,
|
||||
& these are always placeholders.
|
||||
* \return The output description of the tensor.
|
||||
*/
|
||||
using FTVMCompute = std::function<
|
||||
Array<Tensor>
|
||||
(const NodeAttrs& attrs, const Array<Tensor>& inputs)>;
|
||||
Array<Tensor>(const NodeAttrs& attrs,
|
||||
const Array<Tensor>& inputs,
|
||||
const Array<Tensor>& out_info)>;
|
||||
|
||||
/*!
|
||||
* \brief Build the computation schedule for
|
||||
|
|
|
@ -115,9 +115,12 @@ def optimize(graph, shape, dtype="float32"):
|
|||
"""
|
||||
# pylint: disable=unused-argument
|
||||
cfg = BuildConfig.current
|
||||
graph = graph_attr.set_shape_inputs(graph, shape)
|
||||
graph = graph.apply("InferShape")
|
||||
if graph.json_attr("shape_num_unknown_nodes"):
|
||||
raise ValueError("InferShape fails..")
|
||||
if cfg.opt_level >= OPT_PASS_LEVEL["SimplifyBatchNormInference"]:
|
||||
graph = graph_attr.set_shape_inputs(graph, shape)
|
||||
graph = graph.apply(["InferShape", "SimplifyBatchNormInference"])
|
||||
graph = graph.apply("SimplifyBatchNormInference")
|
||||
return graph
|
||||
|
||||
|
||||
|
@ -164,6 +167,12 @@ def build(graph, target, shape, dtype="float32", params=None):
|
|||
cfg = BuildConfig.current
|
||||
graph = graph if isinstance(graph, _graph.Graph) else _graph.create(graph)
|
||||
shape, dtype = _update_shape_dtype(shape, dtype, params)
|
||||
# Initial pass do shape type inference
|
||||
ishape, _ = graph_util.infer_shape(graph, **shape)
|
||||
shape.update(zip(graph.index.input_names, ishape))
|
||||
if not isinstance(dtype, str):
|
||||
idtype, _ = graph_util.infer_dtype(graph, **dtype)
|
||||
dtype.update(zip(graph.index.input_names, idtype))
|
||||
# Apply optimization
|
||||
graph = optimize(graph, shape, dtype)
|
||||
# Precompute prune
|
||||
|
|
|
@ -5,8 +5,10 @@ import tvm
|
|||
class OpPattern(object):
|
||||
ELEM_WISE = 0
|
||||
BROADCAST = 1
|
||||
# Complex means we can fuse elemwise to it
|
||||
COMPLEX = 2
|
||||
EXTERN = 2
|
||||
# Extern means the op is not fusable
|
||||
EXTERN = 3
|
||||
|
||||
_register_compute = tvm.get_global_func("nnvm._register_compute")
|
||||
_register_schedule = tvm.get_global_func("nnvm._register_schedule")
|
||||
|
|
|
@ -2,3 +2,4 @@
|
|||
from .attr_dict import AttrDict
|
||||
from . import tensor
|
||||
from . import nn
|
||||
from . import transform
|
||||
|
|
|
@ -1,30 +1,37 @@
|
|||
# pylint: disable=invalid-name, unused-argument
|
||||
"""Definition of nn ops"""
|
||||
from __future__ import absolute_import
|
||||
|
||||
import tvm
|
||||
import topi
|
||||
from topi.util import get_const_int
|
||||
from .tensor import schedule_elemwise
|
||||
from .tensor import _fschedule_broadcast
|
||||
from ..compiler import registry as reg
|
||||
from ..compiler import OpPattern
|
||||
|
||||
# relu
|
||||
@reg.register_compute("relu")
|
||||
def compute_relu(_, inputs):
|
||||
def compute_relu(attrs, inputs, _):
|
||||
"""Compute definition of relu"""
|
||||
return topi.nn.relu(inputs[0])
|
||||
|
||||
@reg.register_schedule("relu")
|
||||
def schedule_relu(_, outs, target):
|
||||
"""Schedule definition of relu"""
|
||||
return schedule_elemwise(_, outs, target)
|
||||
|
||||
reg.register_schedule("relu", _fschedule_broadcast)
|
||||
reg.register_pattern("relu", OpPattern.ELEM_WISE)
|
||||
|
||||
|
||||
# flatten
|
||||
@reg.register_compute("flatten")
|
||||
def compute_flatten(attrs, inputs, _):
|
||||
"""Compute definition of flatten"""
|
||||
return topi.nn.flatten(inputs[0])
|
||||
|
||||
reg.register_schedule("flatten", _fschedule_broadcast)
|
||||
reg.register_pattern("flatten", OpPattern.COMPLEX)
|
||||
|
||||
|
||||
# softmax
|
||||
@reg.register_compute("softmax")
|
||||
def compute_softmax(attrs, inputs):
|
||||
def compute_softmax(attrs, inputs, _):
|
||||
"""Compute definition of softmax"""
|
||||
axis = attrs.get_int("axis")
|
||||
assert axis == -1, "only support axis == -1 for now"
|
||||
|
@ -38,12 +45,34 @@ def schedule_softmax(_, outs, target):
|
|||
# naive schedule
|
||||
return tvm.create_schedule([x.op for x in outs])
|
||||
|
||||
reg.register_pattern("softmax", OpPattern.COMPLEX)
|
||||
# Mark softmax as extern as we do not fuse it in call cases
|
||||
reg.register_pattern("softmax", OpPattern.EXTERN)
|
||||
|
||||
|
||||
# dense
|
||||
@reg.register_compute("dense")
|
||||
def compute_dense(attrs, inputs, _):
|
||||
"""Compute definition of dense"""
|
||||
if attrs.get_bool("use_bias"):
|
||||
return topi.nn.fully_connected_with_bias(
|
||||
inputs[0], inputs[1], inputs[2])
|
||||
return topi.nn.fully_connected(inputs[0], inputs[1])
|
||||
|
||||
@reg.register_schedule("dense")
|
||||
def schedule_dense(_, outs, target):
|
||||
"""Schedule definition of dense"""
|
||||
if target == "cuda":
|
||||
raise ValueError("fully_connected not yet implemented")
|
||||
# naive schedule
|
||||
return tvm.create_schedule([x.op for x in outs])
|
||||
|
||||
# register extern for now, change me when fusion is enabled.
|
||||
reg.register_pattern("dense", OpPattern.EXTERN)
|
||||
|
||||
|
||||
# conv
|
||||
@reg.register_compute("conv2d")
|
||||
def compute_conv2d(attrs, inputs):
|
||||
def compute_conv2d(attrs, inputs, _):
|
||||
"""Compute definition of conv2d"""
|
||||
padding = attrs.get_int_tuple("padding")
|
||||
strides = attrs.get_int_tuple("strides")
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# pylint: disable=invalid-name
|
||||
# pylint: disable=invalid-name, unused-argument
|
||||
"""Tensor ops"""
|
||||
from __future__ import absolute_import
|
||||
|
||||
|
@ -8,15 +8,6 @@ import topi.cuda
|
|||
from ..compiler import registry as reg
|
||||
from ..compiler import OpPattern
|
||||
|
||||
def schedule_elemwise(_, outs, target):
|
||||
"""Generic schedule for elemwise operation"""
|
||||
if target == "cuda":
|
||||
return topi.cuda.schedule_elemwise(outs)
|
||||
assert target.startswith("llvm")
|
||||
s = tvm.create_schedule([x.op for x in outs])
|
||||
tvm.schedule.AutoInlineInjective(s)
|
||||
return s
|
||||
|
||||
def _schedule_broadcast(_, outs, target):
|
||||
"""Generic schedule for binary bcast"""
|
||||
if target == "cuda":
|
||||
|
@ -29,7 +20,7 @@ def _schedule_broadcast(_, outs, target):
|
|||
def _compute_binary_scalar(f):
|
||||
"""auxiliary function"""
|
||||
@tvm.tag_scope("ewise")
|
||||
def _compute(attrs, x):
|
||||
def _compute(attrs, x, _):
|
||||
x = x[0]
|
||||
scalar = attrs.get_float("scalar")
|
||||
scalar = tvm.const(scalar, x.dtype)
|
||||
|
@ -37,58 +28,132 @@ def _compute_binary_scalar(f):
|
|||
return _compute
|
||||
|
||||
|
||||
def _compute_unary(f):
|
||||
"""auxiliary function"""
|
||||
def _compute(attrs, x, _):
|
||||
return f(x[0])
|
||||
return _compute
|
||||
|
||||
|
||||
def _compute_binary(f):
|
||||
"""auxiliary function"""
|
||||
def _compute(attrs, x, _):
|
||||
return f(x[0], x[1])
|
||||
return _compute
|
||||
|
||||
|
||||
_fschedule_broadcast = tvm.convert(_schedule_broadcast)
|
||||
|
||||
# exp
|
||||
reg.register_compute("exp",
|
||||
lambda _, x: topi.exp(x[0]))
|
||||
reg.register_compute("exp", _compute_unary(topi.exp))
|
||||
reg.register_pattern("exp", OpPattern.ELEM_WISE)
|
||||
reg.register_schedule("exp", _fschedule_broadcast)
|
||||
|
||||
# sqrt
|
||||
reg.register_compute("sqrt", _compute_unary(topi.sqrt))
|
||||
reg.register_pattern("sqrt", OpPattern.ELEM_WISE)
|
||||
reg.register_schedule("sqrt", _fschedule_broadcast)
|
||||
|
||||
# log
|
||||
reg.register_compute("log",
|
||||
lambda _, x: topi.log(x[0]))
|
||||
reg.register_compute("log", _compute_unary(topi.log))
|
||||
reg.register_pattern("log", OpPattern.ELEM_WISE)
|
||||
reg.register_schedule("log", _fschedule_broadcast)
|
||||
|
||||
# tanh
|
||||
reg.register_compute("tanh",
|
||||
lambda _, x: topi.tanh(x[0]))
|
||||
reg.register_compute("tanh", _compute_unary(topi.tanh))
|
||||
reg.register_pattern("tanh", OpPattern.ELEM_WISE)
|
||||
reg.register_schedule("tanh", _fschedule_broadcast)
|
||||
|
||||
# negative
|
||||
reg.register_compute("negative", _compute_unary(topi.negative))
|
||||
reg.register_pattern("negative", OpPattern.ELEM_WISE)
|
||||
reg.register_schedule("negative", _fschedule_broadcast)
|
||||
|
||||
# sigmoid
|
||||
reg.register_compute("sigmoid",
|
||||
lambda _, x: topi.sigmoid(x[0]))
|
||||
reg.register_compute("sigmoid", _compute_unary(topi.sigmoid))
|
||||
reg.register_pattern("sigmoid", OpPattern.ELEM_WISE)
|
||||
reg.register_schedule("sigmoid", _fschedule_broadcast)
|
||||
|
||||
# add scalar
|
||||
# add_scalar
|
||||
reg.register_compute("__add_scalar__",
|
||||
_compute_binary_scalar(lambda x, y: x + y))
|
||||
reg.register_pattern("__add_scalar__", OpPattern.ELEM_WISE)
|
||||
reg.register_schedule("__add_scalar__", _fschedule_broadcast)
|
||||
|
||||
# sub_calar
|
||||
reg.register_compute("__sub_scalar__",
|
||||
_compute_binary_scalar(lambda x, y: x - y))
|
||||
reg.register_pattern("__sub_scalar__", OpPattern.ELEM_WISE)
|
||||
reg.register_schedule("__sub_scalar__", _fschedule_broadcast)
|
||||
|
||||
# rsub_scalar
|
||||
reg.register_compute("__rsub_scalar__",
|
||||
_compute_binary_scalar(lambda x, y: y - x))
|
||||
reg.register_pattern("__rsub_scalar__", OpPattern.ELEM_WISE)
|
||||
reg.register_schedule("__rsub_scalar__", _fschedule_broadcast)
|
||||
|
||||
# mul_scalar
|
||||
reg.register_compute("__mul_scalar__",
|
||||
_compute_binary_scalar(lambda x, y: x * y))
|
||||
reg.register_pattern("__mul_scalar__", OpPattern.ELEM_WISE)
|
||||
reg.register_schedule("__mul_scalar__", _fschedule_broadcast)
|
||||
|
||||
# div_scalar
|
||||
reg.register_compute("__div_scalar__",
|
||||
_compute_binary_scalar(lambda x, y: x / y))
|
||||
reg.register_pattern("__div_scalar__", OpPattern.ELEM_WISE)
|
||||
reg.register_schedule("__div_scalar__", _fschedule_broadcast)
|
||||
|
||||
# rdiv_scalar
|
||||
reg.register_compute("__rdiv_scalar__",
|
||||
_compute_binary_scalar(lambda x, y: y / x))
|
||||
reg.register_pattern("__rdiv_scalar__", OpPattern.ELEM_WISE)
|
||||
reg.register_schedule("__rdiv_scalar__", _fschedule_broadcast)
|
||||
|
||||
# elemwise_add
|
||||
reg.register_compute("elemwise_add", _compute_binary(topi.broadcast_add))
|
||||
reg.register_pattern("elemwise_add", OpPattern.BROADCAST)
|
||||
reg.register_schedule("elemwise_add", _fschedule_broadcast)
|
||||
|
||||
# elemwise_sub
|
||||
reg.register_compute("elemwise_sub", _compute_binary(topi.broadcast_sub))
|
||||
reg.register_pattern("elemwise_sub", OpPattern.BROADCAST)
|
||||
reg.register_schedule("elemwise_sub", _fschedule_broadcast)
|
||||
|
||||
# elemwise_mul
|
||||
reg.register_compute("elemwise_mul", _compute_binary(topi.broadcast_mul))
|
||||
reg.register_pattern("elemwise_mul", OpPattern.BROADCAST)
|
||||
reg.register_schedule("elemwise_mul", _fschedule_broadcast)
|
||||
|
||||
# elemwise_div
|
||||
reg.register_compute("elemwise_div", _compute_binary(topi.broadcast_div))
|
||||
reg.register_pattern("elemwise_div", OpPattern.BROADCAST)
|
||||
reg.register_schedule("elemwise_div", _fschedule_broadcast)
|
||||
|
||||
# broadcast_add
|
||||
reg.register_compute("broadcast_add",
|
||||
lambda _, x: topi.broadcast_add(x[0], x[1]))
|
||||
reg.register_compute("broadcast_add", _compute_binary(topi.broadcast_add))
|
||||
reg.register_pattern("broadcast_add", OpPattern.BROADCAST)
|
||||
reg.register_schedule("broadcast_add", _fschedule_broadcast)
|
||||
|
||||
# broadcast_sub
|
||||
reg.register_compute("broadcast_sub",
|
||||
lambda _, x: topi.broadcast_sub(x[0], x[1]))
|
||||
reg.register_compute("broadcast_sub", _compute_binary(topi.broadcast_sub))
|
||||
reg.register_pattern("broadcast_sub", OpPattern.BROADCAST)
|
||||
reg.register_schedule("broadcast_sub", _fschedule_broadcast)
|
||||
|
||||
# broadcast_mul
|
||||
reg.register_compute("broadcast_mul",
|
||||
lambda _, x: topi.broadcast_mul(x[0], x[1]))
|
||||
reg.register_compute("broadcast_mul", _compute_binary(topi.broadcast_mul))
|
||||
reg.register_pattern("broadcast_mul", OpPattern.BROADCAST)
|
||||
reg.register_schedule("broadcast_mul", _fschedule_broadcast)
|
||||
|
||||
# broadcast_div
|
||||
reg.register_compute("broadcast_div",
|
||||
lambda _, x: topi.broadcast_div(x[0], x[1]))
|
||||
reg.register_compute("broadcast_div", _compute_binary(topi.broadcast_div))
|
||||
reg.register_pattern("broadcast_div", OpPattern.BROADCAST)
|
||||
reg.register_schedule("broadcast_div", _fschedule_broadcast)
|
||||
|
||||
# broadcast_to
|
||||
@reg.register_compute("broadcast_to")
|
||||
def compute_softmax(attrs, inputs, out_info):
|
||||
"""Compute definition of softmax"""
|
||||
return topi.broadcast_to(inputs[0], shape=out_info[0].shape)
|
||||
reg.register_pattern("broadcast_to", OpPattern.BROADCAST)
|
||||
reg.register_schedule("broadcast_to", _fschedule_broadcast)
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
# pylint: disable=invalid-name, unused-argument
|
||||
"""Tensor transformation ops"""
|
||||
from __future__ import absolute_import
|
||||
|
||||
import tvm
|
||||
from .tensor import _fschedule_broadcast
|
||||
from ..compiler import registry as reg
|
||||
from ..compiler import OpPattern
|
||||
|
||||
# Need add reshape, transpose
|
||||
|
||||
def _flatten_index(indices, shape):
|
||||
"""flatten the index to 1D"""
|
||||
idx = 0
|
||||
for i, value in enumerate(shape):
|
||||
if i != 0:
|
||||
idx *= value
|
||||
idx = idx + indices[i]
|
||||
return idx
|
||||
|
||||
# reshape
|
||||
@reg.register_compute("reshape")
|
||||
def compute_reshape(attrs, inputs, out_info):
|
||||
"""Compute definition of softmax"""
|
||||
# TODO(sxj) add support for general reshape
|
||||
assert len(inputs[0].shape) == 1, "Only support 1d input for now"
|
||||
oshape = out_info[0].shape
|
||||
x = inputs[0]
|
||||
return tvm.compute(oshape, lambda *i: x(_flatten_index(i, oshape)))
|
||||
reg.register_pattern("reshape", OpPattern.COMPLEX)
|
||||
reg.register_schedule("reshape", _fschedule_broadcast)
|
|
@ -261,7 +261,7 @@ nnvm::Graph GraphFuse(nnvm::Graph g) {
|
|||
if (inode.source->is_variable()) continue;
|
||||
int root_id = group_vec[nid];
|
||||
FuseEntry& fe = fuse_vec[root_id];
|
||||
Array<Tensor> inputs;
|
||||
Array<Tensor> inputs, out_info;
|
||||
// input loading
|
||||
for (const auto& e : inode.inputs) {
|
||||
if (group_vec[e.node_id] != root_id) {
|
||||
|
@ -274,11 +274,21 @@ nnvm::Graph GraphFuse(nnvm::Graph g) {
|
|||
inputs.push_back(t);
|
||||
}
|
||||
}
|
||||
// output hint
|
||||
for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) {
|
||||
Array<Expr> shape;
|
||||
for (int64_t x : shape_vec[idx.entry_id(nid, i)]) {
|
||||
CHECK_LE(x, static_cast<int64_t>(std::numeric_limits<int>::max()));
|
||||
shape.push_back(make_const(Int(32), x));
|
||||
}
|
||||
out_info.push_back(
|
||||
placeholder(shape,
|
||||
TVMType2Type(dltype_vec[idx.entry_id(nid, i)])));
|
||||
}
|
||||
// get default
|
||||
Array<Tensor> out = fcompute[inode.source->op()](
|
||||
inode.source->attrs, inputs);
|
||||
inode.source->attrs, inputs, out_info);
|
||||
CHECK_EQ(out.size(), inode.source->num_outputs());
|
||||
|
||||
// schedule on root node, and use master's schedule
|
||||
if (nid != root_id) {
|
||||
for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) {
|
||||
|
@ -312,6 +322,7 @@ nnvm::Graph GraphFuse(nnvm::Graph g) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
tvm::runtime::Module module = fbuild(funcs, target);
|
||||
// Final step: Remap the node, with given attribute
|
||||
const nnvm::Op* tvm_op = nnvm::Op::Get("tvm_op");
|
||||
|
|
|
@ -67,9 +67,11 @@ TVM_REGISTER_GLOBAL("nnvm._register_compute")
|
|||
// Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown
|
||||
PackedFunc* f = new PackedFunc(args[1].operator PackedFunc());
|
||||
Op& op = ::dmlc::Registry<nnvm::Op>::Get()->__REGISTER_OR_GET__(args[0]);
|
||||
auto fcompute = [f](const NodeAttrs& attrs, const Array<Tensor>& inputs)
|
||||
auto fcompute = [f](const NodeAttrs& attrs,
|
||||
const Array<Tensor>& inputs,
|
||||
const Array<Tensor>& out_info)
|
||||
-> Array<Tensor> {
|
||||
TVMRetValue ret = (*f)(GetAttrDict(attrs), inputs);
|
||||
TVMRetValue ret = (*f)(GetAttrDict(attrs), inputs, out_info);
|
||||
if ((*ret.ptr<std::shared_ptr<tvm::Node> >())->derived_from<tvm::TensorNode>()) {
|
||||
return {ret.operator Tensor()};
|
||||
} else {
|
||||
|
|
|
@ -21,7 +21,7 @@ BatchNormToInferUnpack(const nnvm::NodeAttrs& attrs,
|
|||
nnvm::NodeEntry beta,
|
||||
nnvm::NodeEntry moving_mean,
|
||||
nnvm::NodeEntry moving_var,
|
||||
int data_dim) {
|
||||
TShape dshape) {
|
||||
CHECK(attrs.op);
|
||||
static const Op* bn_op = Op::Get("batch_norm");
|
||||
CHECK(attrs.op == bn_op);
|
||||
|
@ -57,19 +57,12 @@ BatchNormToInferUnpack(const nnvm::NodeAttrs& attrs,
|
|||
shift = MakeNode(
|
||||
"elemwise_add", bn_name + "_add_beta", {shift, beta});
|
||||
}
|
||||
// reshape to nhwc
|
||||
// use broaodcast to reshape
|
||||
std::ostringstream oshape;
|
||||
oshape << "(";
|
||||
for (int i = 0; i < data_dim; ++i) {
|
||||
if (i != 0) oshape << ", ";
|
||||
if (i == param.axis) {
|
||||
oshape << "-1";
|
||||
} else {
|
||||
oshape << "1";
|
||||
}
|
||||
for (dim_t i = 0; i < dshape.ndim(); ++i) {
|
||||
dshape[i] = (i != param.axis) ? 1 : -1;
|
||||
}
|
||||
oshape << ")";
|
||||
|
||||
oshape << dshape;
|
||||
scale = MakeNode("reshape", bn_name + "_sc_reshape",
|
||||
{scale}, {{"shape", oshape.str()}});
|
||||
shift = MakeNode("reshape", bn_name + "_sh_reshape",
|
||||
|
@ -98,7 +91,7 @@ Graph SimplifyBatchNormInference(nnvm::Graph src) {
|
|||
n->inputs[2],
|
||||
n->inputs[3],
|
||||
n->inputs[4],
|
||||
shape_vec[idx.entry_id(nid, 0)].ndim());
|
||||
shape_vec[idx.entry_id(nid, 0)]);
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
|
|
|
@ -73,7 +73,7 @@ void PrintGraphIR_(Graph src,
|
|||
AttrPrinter fp = GetVectorPrinter(src, key);
|
||||
auto fprint = [&idx, key, fp](
|
||||
uint32_t nid, std::ostream& os) { // NOLINT(*)
|
||||
os << key << "=";
|
||||
os << ", " << key << "=";
|
||||
fp(idx.entry_id(nid, 0), os);
|
||||
};
|
||||
trigger.push_back(fprint);
|
||||
|
|
|
@ -5,13 +5,13 @@ from nnvm.compiler import graph_util, graph_attr
|
|||
|
||||
def test_simplify_batchnorm():
|
||||
def simple_bn(x, gamma, beta, moving_mean, moving_var,
|
||||
axis=1, epsilon=1e-5, dim=2):
|
||||
axis=1, epsilon=1e-5, shape=None):
|
||||
# expect = (x - moving_mean) / sym.sqrt(moving_var + eps) * gamma + beta
|
||||
scale = sym.elemwise_mul(1 / sym.sqrt(moving_var + epsilon), gamma)
|
||||
shift = sym.elemwise_add(
|
||||
sym.elemwise_mul(sym.negative(moving_mean), scale), beta)
|
||||
shape = [-1 if i == axis else 1 for i in range(len(shape))]
|
||||
# for 2D
|
||||
shape = tuple(1 if i != axis else -1 for i in range(dim))
|
||||
scale = sym.reshape(scale, shape=shape)
|
||||
shift = sym.reshape(shift, shape=shape)
|
||||
return x * scale + shift
|
||||
|
@ -26,15 +26,14 @@ def test_simplify_batchnorm():
|
|||
moving_var = sym.Variable("moving_var")
|
||||
moving_mean = sym.Variable("moving_mean")
|
||||
y1, y2 = x, x
|
||||
|
||||
ishape = {"x": tuple(10 for i in range(dim))}
|
||||
for i in range(nstep):
|
||||
y1 = sym.batch_norm(
|
||||
y1 + 1, gamma, beta, moving_mean, moving_var, epsilon=eps, axis=axis)
|
||||
y2 = simple_bn(y2 + 1, gamma, beta, moving_mean, moving_var,
|
||||
epsilon=eps, axis=axis, dim=dim)
|
||||
epsilon=eps, axis=axis, shape=ishape["x"])
|
||||
g = nnvm.graph.create(y1)
|
||||
g2 = nnvm.graph.create(y2)
|
||||
ishape = {"x": tuple(10 for i in range(dim))}
|
||||
graph_attr.set_shape_inputs(g, ishape)
|
||||
g1 = g.apply("InferShape").apply("SimplifyBatchNormInference")
|
||||
# Some prints for debug
|
||||
|
|
|
@ -6,19 +6,10 @@ import nnvm.symbol as sym
|
|||
import nnvm.compiler
|
||||
import nnvm.runtime
|
||||
|
||||
USE_GPU=True
|
||||
def ctx_list():
|
||||
res = [("llvm", tvm.cpu(0)), ("cuda", tvm.gpu(0))]
|
||||
return [x for x in res if x[1].exist]
|
||||
|
||||
def default_target():
|
||||
if USE_GPU:
|
||||
return 'cuda'
|
||||
else:
|
||||
return 'llvm'
|
||||
|
||||
def default_ctx():
|
||||
if USE_GPU:
|
||||
return tvm.gpu(0)
|
||||
else:
|
||||
return tvm.cpu(0)
|
||||
|
||||
def test_relu():
|
||||
x = sym.Variable("x")
|
||||
|
@ -26,20 +17,21 @@ def test_relu():
|
|||
dtype = "float32"
|
||||
dshape = (1, 3, 32, 32)
|
||||
oshape = dshape
|
||||
graph, lib, _ = nnvm.compiler.build(y, default_target(), {"x": dshape})
|
||||
m = nnvm.runtime.create(graph, lib, default_ctx())
|
||||
# get member functions
|
||||
set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
|
||||
# set input
|
||||
data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
|
||||
set_input("x", data)
|
||||
# execute
|
||||
run()
|
||||
# get output
|
||||
out = tvm.nd.empty(oshape, dtype)
|
||||
get_output(0, out)
|
||||
y_np = np.maximum(data.asnumpy(), 0.0)
|
||||
np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)
|
||||
for target, ctx in ctx_list():
|
||||
graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
|
||||
m = nnvm.runtime.create(graph, lib, ctx)
|
||||
# get member functions
|
||||
set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
|
||||
# set input
|
||||
data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
|
||||
set_input("x", data)
|
||||
# execute
|
||||
run()
|
||||
# get output
|
||||
out = tvm.nd.empty(oshape, dtype)
|
||||
get_output(0, out)
|
||||
y_np = np.maximum(data.asnumpy(), 0.0)
|
||||
np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)
|
||||
|
||||
|
||||
def test_exp():
|
||||
|
@ -48,20 +40,21 @@ def test_exp():
|
|||
dtype = "float32"
|
||||
dshape = (1, 3, 32, 32)
|
||||
oshape = dshape
|
||||
graph, lib, _ = nnvm.compiler.build(y, default_target(), {"x": dshape})
|
||||
m = nnvm.runtime.create(graph, lib, default_ctx())
|
||||
# get member functions
|
||||
set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
|
||||
# set input
|
||||
data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
|
||||
set_input("x", data)
|
||||
# execute
|
||||
run()
|
||||
# get output
|
||||
out = tvm.nd.empty(oshape, dtype)
|
||||
get_output(0, out)
|
||||
y_np = np.exp(data.asnumpy())
|
||||
np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)
|
||||
for target, ctx in ctx_list():
|
||||
graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
|
||||
m = nnvm.runtime.create(graph, lib, ctx)
|
||||
# get member functions
|
||||
set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
|
||||
# set input
|
||||
data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
|
||||
set_input("x", data)
|
||||
# execute
|
||||
run()
|
||||
# get output
|
||||
out = tvm.nd.empty(oshape, dtype)
|
||||
get_output(0, out)
|
||||
y_np = np.exp(data.asnumpy())
|
||||
np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)
|
||||
|
||||
|
||||
def test_log():
|
||||
|
@ -70,21 +63,22 @@ def test_log():
|
|||
dtype = "float32"
|
||||
dshape = (1, 3, 32, 32)
|
||||
oshape = dshape
|
||||
with nnvm.compiler.build_config(opt_level=1):
|
||||
graph, lib, _ = nnvm.compiler.build(y, default_target(), {"x": dshape})
|
||||
m = nnvm.runtime.create(graph, lib, default_ctx())
|
||||
# get member functions
|
||||
set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
|
||||
# set input
|
||||
data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
|
||||
set_input("x", data)
|
||||
# execute
|
||||
run()
|
||||
# get output
|
||||
out = tvm.nd.empty(oshape, dtype)
|
||||
get_output(0, out)
|
||||
y_np = np.log(data.asnumpy())
|
||||
np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)
|
||||
for target, ctx in ctx_list():
|
||||
with nnvm.compiler.build_config(opt_level=1):
|
||||
graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
|
||||
m = nnvm.runtime.create(graph, lib, ctx)
|
||||
# get member functions
|
||||
set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
|
||||
# set input
|
||||
data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
|
||||
set_input("x", data)
|
||||
# execute
|
||||
run()
|
||||
# get output
|
||||
out = tvm.nd.empty(oshape, dtype)
|
||||
get_output(0, out)
|
||||
y_np = np.log(data.asnumpy())
|
||||
np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)
|
||||
|
||||
|
||||
def test_tanh():
|
||||
|
@ -93,21 +87,22 @@ def test_tanh():
|
|||
dtype = "float32"
|
||||
dshape = (1, 3, 32, 32)
|
||||
oshape = dshape
|
||||
with nnvm.compiler.build_config(opt_level=1):
|
||||
graph, lib, _ = nnvm.compiler.build(y, default_target(), {"x": dshape})
|
||||
m = nnvm.runtime.create(graph, lib, default_ctx())
|
||||
# get member functions
|
||||
set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
|
||||
# set input
|
||||
data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
|
||||
set_input("x", data)
|
||||
# execute
|
||||
run()
|
||||
# get output
|
||||
out = tvm.nd.empty(oshape, dtype)
|
||||
get_output(0, out)
|
||||
y_np = np.sinh(data.asnumpy()) / np.cosh(data.asnumpy())
|
||||
np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)
|
||||
for target, ctx in ctx_list():
|
||||
with nnvm.compiler.build_config(opt_level=1):
|
||||
graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
|
||||
m = nnvm.runtime.create(graph, lib, ctx)
|
||||
# get member functions
|
||||
set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
|
||||
# set input
|
||||
data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
|
||||
set_input("x", data)
|
||||
# execute
|
||||
run()
|
||||
# get output
|
||||
out = tvm.nd.empty(oshape, dtype)
|
||||
get_output(0, out)
|
||||
y_np = np.sinh(data.asnumpy()) / np.cosh(data.asnumpy())
|
||||
np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)
|
||||
|
||||
|
||||
def test_sigmoid():
|
||||
|
@ -116,20 +111,21 @@ def test_sigmoid():
|
|||
dtype = "float32"
|
||||
dshape = (1, 3, 32, 32)
|
||||
oshape = dshape
|
||||
graph, lib, _ = nnvm.compiler.build(y, default_target(), {"x": dshape})
|
||||
m = nnvm.runtime.create(graph, lib, default_ctx())
|
||||
# get member functions
|
||||
set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
|
||||
# set input
|
||||
data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
|
||||
set_input("x", data)
|
||||
# execute
|
||||
run()
|
||||
# get output
|
||||
out = tvm.nd.empty(oshape, dtype)
|
||||
get_output(0, out)
|
||||
y_np = 1.0 / (1.0 + np.exp(-data.asnumpy()))
|
||||
np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)
|
||||
for target, ctx in ctx_list():
|
||||
graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
|
||||
m = nnvm.runtime.create(graph, lib, ctx)
|
||||
# get member functions
|
||||
set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
|
||||
# set input
|
||||
data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
|
||||
set_input("x", data)
|
||||
# execute
|
||||
run()
|
||||
# get output
|
||||
out = tvm.nd.empty(oshape, dtype)
|
||||
get_output(0, out)
|
||||
y_np = 1.0 / (1.0 + np.exp(-data.asnumpy()))
|
||||
np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)
|
||||
|
||||
|
||||
def test_softmax():
|
||||
|
@ -138,24 +134,79 @@ def test_softmax():
|
|||
dtype = "float32"
|
||||
dshape = (10, 1000)
|
||||
oshape = dshape
|
||||
with nnvm.compiler.build_config(opt_level=1):
|
||||
graph, lib, _ = nnvm.compiler.build(y, default_target(), {"x": dshape})
|
||||
m = nnvm.runtime.create(graph, lib, default_ctx())
|
||||
# get member functions
|
||||
set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
|
||||
# set input
|
||||
data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
|
||||
set_input("x", data)
|
||||
# execute
|
||||
run()
|
||||
# get output
|
||||
out = tvm.nd.empty(oshape, dtype)
|
||||
get_output(0, out)
|
||||
y_np = topi.testing.softmax_python(data.asnumpy())
|
||||
np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)
|
||||
for target, ctx in ctx_list():
|
||||
with nnvm.compiler.build_config(opt_level=1):
|
||||
graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
|
||||
m = nnvm.runtime.create(graph, lib, ctx)
|
||||
# get member functions
|
||||
set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
|
||||
# set input
|
||||
data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
|
||||
set_input("x", data)
|
||||
# execute
|
||||
run()
|
||||
# get output
|
||||
out = tvm.nd.empty(oshape, dtype)
|
||||
get_output(0, out)
|
||||
y_np = topi.testing.softmax_python(data.asnumpy())
|
||||
np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)
|
||||
|
||||
|
||||
def test_dense():
|
||||
x = sym.Variable("x")
|
||||
y = sym.dense(x, units=3, name="dense")
|
||||
y = sym.flatten(y)
|
||||
dtype = "float32"
|
||||
shape = {
|
||||
"x" : (10, 100),
|
||||
"dense_weight" : (3, 100),
|
||||
"dense_bias" : (3,),
|
||||
}
|
||||
graph, lib, _ = nnvm.compiler.build(y, "llvm", shape)
|
||||
m = nnvm.runtime.create(graph, lib, tvm.cpu(0))
|
||||
x_np = np.random.uniform(size=shape["x"]).astype(dtype)
|
||||
w_np = np.random.uniform(size=shape["dense_weight"]).astype(dtype)
|
||||
b_np = np.random.uniform(size=shape["dense_bias"]).astype(dtype)
|
||||
res = tvm.nd.empty((10, 3))
|
||||
m.run(x=x_np, dense_weight=w_np, dense_bias=b_np)
|
||||
m.get_output(0, res)
|
||||
res_np = np.dot(x_np, w_np.T) + b_np
|
||||
np.testing.assert_allclose(
|
||||
res.asnumpy(), res_np, atol=1e-5, rtol=1e-5)
|
||||
|
||||
|
||||
def test_batchnorm():
|
||||
x = sym.Variable("x")
|
||||
beta = sym.Variable("beta")
|
||||
gamma = sym.Variable("gamma")
|
||||
moving_var = sym.Variable("moving_var")
|
||||
moving_mean = sym.Variable("moving_mean")
|
||||
shape = (10, 20)
|
||||
eps = 1e-5
|
||||
dtype = "float32"
|
||||
y = sym.batch_norm(
|
||||
x, gamma, beta, moving_mean, moving_var, epsilon=eps)
|
||||
|
||||
for target, ctx in ctx_list():
|
||||
graph, lib, _ = nnvm.compiler.build(y, "llvm", {"x": shape})
|
||||
m = nnvm.runtime.create(graph, lib, tvm.cpu(0))
|
||||
x_np = np.random.uniform(size=shape).astype(dtype)
|
||||
mean_np = np.random.uniform(size=shape[1]).astype(dtype)
|
||||
var_np = np.random.uniform(size=shape[1]).astype(dtype)
|
||||
gamma_np = np.random.uniform(size=shape[1]).astype(dtype)
|
||||
beta_np = np.random.uniform(size=shape[1]).astype(dtype)
|
||||
res = tvm.nd.empty(shape)
|
||||
m.run(x=x_np, moving_mean=mean_np, moving_var=var_np,
|
||||
gamma=gamma_np, beta=beta_np)
|
||||
m.get_output(0, res)
|
||||
res_np = (x_np - mean_np) / np.sqrt(var_np + eps) * gamma_np + beta_np
|
||||
np.testing.assert_allclose(
|
||||
res.asnumpy(), res_np, atol=1e-5, rtol=1e-5)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_batchnorm()
|
||||
test_dense()
|
||||
test_relu()
|
||||
test_exp()
|
||||
test_log()
|
||||
|
|
|
@ -6,19 +6,9 @@ import nnvm.symbol as sym
|
|||
import nnvm.compiler
|
||||
import nnvm.runtime
|
||||
|
||||
USE_GPU=True
|
||||
|
||||
def default_target():
|
||||
if USE_GPU:
|
||||
return 'cuda'
|
||||
else:
|
||||
return 'llvm'
|
||||
|
||||
def default_ctx():
|
||||
if USE_GPU:
|
||||
return tvm.gpu(0)
|
||||
else:
|
||||
return tvm.cpu(0)
|
||||
def ctx_list():
|
||||
res = [("llvm", tvm.cpu(0)), ("cuda", tvm.gpu(0))]
|
||||
return [x for x in res if x[1].exist]
|
||||
|
||||
def test_conv2d():
|
||||
x = sym.Variable("x")
|
||||
|
@ -29,23 +19,24 @@ def test_conv2d():
|
|||
kshape = (10, 3, 3, 3)
|
||||
oshape = (1, 10, 18, 18)
|
||||
shape_dict = {"x": dshape}
|
||||
graph, lib, _ = nnvm.compiler.build(y, default_target(), shape_dict)
|
||||
m = nnvm.runtime.create(graph, lib, default_ctx())
|
||||
# get member functions
|
||||
set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
|
||||
# set input
|
||||
data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
|
||||
kernel = tvm.nd.array(np.random.uniform(size=kshape).astype(dtype))
|
||||
set_input("x", data)
|
||||
set_input("y_weight", kernel)
|
||||
# execute
|
||||
run()
|
||||
# get output
|
||||
out = tvm.nd.empty(oshape, dtype)
|
||||
get_output(0, out)
|
||||
c_np = topi.testing.conv2d_nchw_python(
|
||||
data.asnumpy(), kernel.asnumpy(), 1, 1)
|
||||
np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)
|
||||
for target, ctx in ctx_list():
|
||||
graph, lib, _ = nnvm.compiler.build(y, target, shape_dict)
|
||||
m = nnvm.runtime.create(graph, lib, ctx)
|
||||
# get member functions
|
||||
set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
|
||||
# set input
|
||||
data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
|
||||
kernel = tvm.nd.array(np.random.uniform(size=kshape).astype(dtype))
|
||||
set_input("x", data)
|
||||
set_input("y_weight", kernel)
|
||||
# execute
|
||||
run()
|
||||
# get output
|
||||
out = tvm.nd.empty(oshape, dtype)
|
||||
get_output(0, out)
|
||||
c_np = topi.testing.conv2d_nchw_python(
|
||||
data.asnumpy(), kernel.asnumpy(), 1, 1)
|
||||
np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)
|
||||
|
||||
|
||||
def test_grouped_conv2d():
|
||||
|
@ -57,23 +48,24 @@ def test_grouped_conv2d():
|
|||
kshape = (32, 1, 3, 3)
|
||||
oshape = (1, 32, 18, 18)
|
||||
shape_dict = {"x": dshape}
|
||||
graph, lib, _ = nnvm.compiler.build(y, default_target(), shape_dict)
|
||||
m = nnvm.runtime.create(graph, lib, default_ctx())
|
||||
# get member functions
|
||||
set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
|
||||
# set input
|
||||
data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
|
||||
kernel = tvm.nd.array(np.random.uniform(size=kshape).astype(dtype))
|
||||
set_input("x", data)
|
||||
set_input("y_weight", kernel)
|
||||
# execute
|
||||
run()
|
||||
# get output
|
||||
out = tvm.nd.empty(oshape, dtype)
|
||||
get_output(0, out)
|
||||
c_np = topi.testing.depthwise_conv2d_python_nchw(
|
||||
data.asnumpy(), kernel.asnumpy(), (1,1), 'SAME')
|
||||
np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)
|
||||
for target, ctx in ctx_list():
|
||||
graph, lib, _ = nnvm.compiler.build(y, target, shape_dict)
|
||||
m = nnvm.runtime.create(graph, lib, ctx)
|
||||
# get member functions
|
||||
set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
|
||||
# set input
|
||||
data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
|
||||
kernel = tvm.nd.array(np.random.uniform(size=kshape).astype(dtype))
|
||||
set_input("x", data)
|
||||
set_input("y_weight", kernel)
|
||||
# execute
|
||||
run()
|
||||
# get output
|
||||
out = tvm.nd.empty(oshape, dtype)
|
||||
get_output(0, out)
|
||||
c_np = topi.testing.depthwise_conv2d_python_nchw(
|
||||
data.asnumpy(), kernel.asnumpy(), (1,1), 'SAME')
|
||||
np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Загрузка…
Ссылка в новой задаче