[TOP] Add dense, batchnorm (#22)

* [TOP] Add dense, batchnorm
* update tvm

Parent: a2ab3d83a7
Commit: 215693df04
@@ -44,11 +44,14 @@ using TOpPattern = int;
  * \brief Computation description interface
  * \param attrs The attribute of the node.
  * \param inputs The input tensors(placeholders)
+ * \param out_info Tensors holding shape/type information about output,
+ *                 these are always placeholders.
  * \return The output description of the tensor.
  */
 using FTVMCompute = std::function<
-  Array<Tensor>
-  (const NodeAttrs& attrs, const Array<Tensor>& inputs)>;
+  Array<Tensor>(const NodeAttrs& attrs,
+                const Array<Tensor>& inputs,
+                const Array<Tensor>& out_info)>;
 
 /*!
  * \brief Build the computation schedule for
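The extra out_info argument threads placeholders carrying the inferred output shape and dtype into every compute function. A minimal Python-side sketch of the new calling convention (the op name and the absolute import path are illustrative assumptions, not part of this commit; inside the package the code in this diff uses "from ..compiler import registry as reg"):

import topi
from nnvm.compiler import registry as reg  # assumed absolute import path

@reg.register_compute("myexp")  # hypothetical op name, for illustration only
def compute_myexp(attrs, inputs, out_info):
    # inputs: placeholder tensors for the op's inputs
    # out_info: placeholder tensors describing the inferred outputs
    return topi.exp(inputs[0])

Ops such as broadcast_to and reshape later in this diff read out_info[0].shape instead of recomputing the output shape from attributes.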
@@ -115,9 +115,12 @@ def optimize(graph, shape, dtype="float32"):
     """
     # pylint: disable=unused-argument
     cfg = BuildConfig.current
+    graph = graph_attr.set_shape_inputs(graph, shape)
+    graph = graph.apply("InferShape")
+    if graph.json_attr("shape_num_unknown_nodes"):
+        raise ValueError("InferShape fails..")
     if cfg.opt_level >= OPT_PASS_LEVEL["SimplifyBatchNormInference"]:
-        graph = graph_attr.set_shape_inputs(graph, shape)
-        graph = graph.apply(["InferShape", "SimplifyBatchNormInference"])
+        graph = graph.apply("SimplifyBatchNormInference")
     return graph
 
 
@@ -164,6 +167,12 @@ def build(graph, target, shape, dtype="float32", params=None):
     cfg = BuildConfig.current
     graph = graph if isinstance(graph, _graph.Graph) else _graph.create(graph)
     shape, dtype = _update_shape_dtype(shape, dtype, params)
+    # Initial pass do shape type inference
+    ishape, _ = graph_util.infer_shape(graph, **shape)
+    shape.update(zip(graph.index.input_names, ishape))
+    if not isinstance(dtype, str):
+        idtype, _ = graph_util.infer_dtype(graph, **dtype)
+        dtype.update(zip(graph.index.input_names, idtype))
     # Apply optimization
     graph = optimize(graph, shape, dtype)
     # Precompute prune
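build() now runs shape/dtype inference up front and optimize() applies SimplifyBatchNormInference only when the configured opt_level permits it. A usage sketch following the tests further down in this diff (the symbol and input shape are illustrative):

import nnvm.symbol as sym
import nnvm.compiler

x = sym.Variable("x")
y = sym.log(x)
# build under an explicit optimization level, as the updated tests do
with nnvm.compiler.build_config(opt_level=1):
    graph, lib, _ = nnvm.compiler.build(y, "llvm", {"x": (1, 3, 32, 32)})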
@@ -5,8 +5,10 @@ import tvm
 class OpPattern(object):
     ELEM_WISE = 0
     BROADCAST = 1
+    # Complex means we can fuse elemwise to it
     COMPLEX = 2
-    EXTERN = 2
+    # Extern means the op is not fusable
+    EXTERN = 3
 
 _register_compute = tvm.get_global_func("nnvm._register_compute")
 _register_schedule = tvm.get_global_func("nnvm._register_schedule")
@@ -2,3 +2,4 @@
 from .attr_dict import AttrDict
 from . import tensor
 from . import nn
+from . import transform
@@ -1,30 +1,37 @@
+# pylint: disable=invalid-name, unused-argument
 """Definition of nn ops"""
 from __future__ import absolute_import
 
 import tvm
 import topi
 from topi.util import get_const_int
-from .tensor import schedule_elemwise
+from .tensor import _fschedule_broadcast
 from ..compiler import registry as reg
 from ..compiler import OpPattern
 
 # relu
 @reg.register_compute("relu")
-def compute_relu(_, inputs):
+def compute_relu(attrs, inputs, _):
     """Compute definition of relu"""
     return topi.nn.relu(inputs[0])
 
-@reg.register_schedule("relu")
-def schedule_relu(_, outs, target):
-    """Schedule definition of relu"""
-    return schedule_elemwise(_, outs, target)
+reg.register_schedule("relu", _fschedule_broadcast)
 
 reg.register_pattern("relu", OpPattern.ELEM_WISE)
 
 
+# flatten
+@reg.register_compute("flatten")
+def compute_flatten(attrs, inputs, _):
+    """Compute definition of flatten"""
+    return topi.nn.flatten(inputs[0])
+
+reg.register_schedule("flatten", _fschedule_broadcast)
+reg.register_pattern("flatten", OpPattern.COMPLEX)
+
+
 # softmax
 @reg.register_compute("softmax")
-def compute_softmax(attrs, inputs):
+def compute_softmax(attrs, inputs, _):
     """Compute definition of softmax"""
     axis = attrs.get_int("axis")
     assert axis == -1, "only support axis == -1 for now"
@@ -38,12 +45,34 @@ def schedule_softmax(_, outs, target):
     # naive schedule
     return tvm.create_schedule([x.op for x in outs])
 
-reg.register_pattern("softmax", OpPattern.COMPLEX)
+# Mark softmax as extern as we do not fuse it in call cases
+reg.register_pattern("softmax", OpPattern.EXTERN)
 
 
+# dense
+@reg.register_compute("dense")
+def compute_dense(attrs, inputs, _):
+    """Compute definition of dense"""
+    if attrs.get_bool("use_bias"):
+        return topi.nn.fully_connected_with_bias(
+            inputs[0], inputs[1], inputs[2])
+    return topi.nn.fully_connected(inputs[0], inputs[1])
+
+@reg.register_schedule("dense")
+def schedule_dense(_, outs, target):
+    """Schedule definition of dense"""
+    if target == "cuda":
+        raise ValueError("fully_connected not yet implemented")
+    # naive schedule
+    return tvm.create_schedule([x.op for x in outs])
+
+# register extern for now, change me when fusion is enabled.
+reg.register_pattern("dense", OpPattern.EXTERN)
+
+
 # conv
 @reg.register_compute("conv2d")
-def compute_conv2d(attrs, inputs):
+def compute_conv2d(attrs, inputs, _):
     """Compute definition of conv2d"""
     padding = attrs.get_int_tuple("padding")
     strides = attrs.get_int_tuple("strides")
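Numerically, the dense op registered above is the usual fully connected layer, y = dot(x, W.T) + b, which is exactly what test_dense checks later in this diff. A small NumPy sketch of that reference computation (shapes follow the test):

import numpy as np

x = np.random.uniform(size=(10, 100)).astype("float32")  # data
w = np.random.uniform(size=(3, 100)).astype("float32")   # dense_weight, units=3
b = np.random.uniform(size=(3,)).astype("float32")       # dense_bias
y_ref = np.dot(x, w.T) + b                                # reference used by test_dense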
@@ -1,4 +1,4 @@
-# pylint: disable=invalid-name
+# pylint: disable=invalid-name, unused-argument
 """Tensor ops"""
 from __future__ import absolute_import
 
@@ -8,15 +8,6 @@ import topi.cuda
 from ..compiler import registry as reg
 from ..compiler import OpPattern
 
-def schedule_elemwise(_, outs, target):
-    """Generic schedule for elemwise operation"""
-    if target == "cuda":
-        return topi.cuda.schedule_elemwise(outs)
-    assert target.startswith("llvm")
-    s = tvm.create_schedule([x.op for x in outs])
-    tvm.schedule.AutoInlineInjective(s)
-    return s
-
 def _schedule_broadcast(_, outs, target):
     """Generic schedule for binary bcast"""
     if target == "cuda":
@@ -29,7 +20,7 @@ def _schedule_broadcast(_, outs, target):
 def _compute_binary_scalar(f):
     """auxiliary function"""
     @tvm.tag_scope("ewise")
-    def _compute(attrs, x):
+    def _compute(attrs, x, _):
         x = x[0]
         scalar = attrs.get_float("scalar")
         scalar = tvm.const(scalar, x.dtype)
@@ -37,58 +28,132 @@ def _compute_binary_scalar(f):
     return _compute
 
 
+def _compute_unary(f):
+    """auxiliary function"""
+    def _compute(attrs, x, _):
+        return f(x[0])
+    return _compute
+
+
+def _compute_binary(f):
+    """auxiliary function"""
+    def _compute(attrs, x, _):
+        return f(x[0], x[1])
+    return _compute
+
+
 _fschedule_broadcast = tvm.convert(_schedule_broadcast)
 
 # exp
-reg.register_compute("exp",
-                     lambda _, x: topi.exp(x[0]))
+reg.register_compute("exp", _compute_unary(topi.exp))
 reg.register_pattern("exp", OpPattern.ELEM_WISE)
 reg.register_schedule("exp", _fschedule_broadcast)
 
+# sqrt
+reg.register_compute("sqrt", _compute_unary(topi.sqrt))
+reg.register_pattern("sqrt", OpPattern.ELEM_WISE)
+reg.register_schedule("sqrt", _fschedule_broadcast)
+
 # log
-reg.register_compute("log",
-                     lambda _, x: topi.log(x[0]))
+reg.register_compute("log", _compute_unary(topi.log))
 reg.register_pattern("log", OpPattern.ELEM_WISE)
 reg.register_schedule("log", _fschedule_broadcast)
 
 # tanh
-reg.register_compute("tanh",
-                     lambda _, x: topi.tanh(x[0]))
+reg.register_compute("tanh", _compute_unary(topi.tanh))
 reg.register_pattern("tanh", OpPattern.ELEM_WISE)
 reg.register_schedule("tanh", _fschedule_broadcast)
 
+# negative
+reg.register_compute("negative", _compute_unary(topi.negative))
+reg.register_pattern("negative", OpPattern.ELEM_WISE)
+reg.register_schedule("negative", _fschedule_broadcast)
+
 # sigmoid
-reg.register_compute("sigmoid",
-                     lambda _, x: topi.sigmoid(x[0]))
+reg.register_compute("sigmoid", _compute_unary(topi.sigmoid))
 reg.register_pattern("sigmoid", OpPattern.ELEM_WISE)
 reg.register_schedule("sigmoid", _fschedule_broadcast)
 
-# add scalar
+# add_scalar
 reg.register_compute("__add_scalar__",
                      _compute_binary_scalar(lambda x, y: x + y))
 reg.register_pattern("__add_scalar__", OpPattern.ELEM_WISE)
 reg.register_schedule("__add_scalar__", _fschedule_broadcast)
 
+# sub_calar
+reg.register_compute("__sub_scalar__",
+                     _compute_binary_scalar(lambda x, y: x - y))
+reg.register_pattern("__sub_scalar__", OpPattern.ELEM_WISE)
+reg.register_schedule("__sub_scalar__", _fschedule_broadcast)
+
+# rsub_scalar
+reg.register_compute("__rsub_scalar__",
+                     _compute_binary_scalar(lambda x, y: y - x))
+reg.register_pattern("__rsub_scalar__", OpPattern.ELEM_WISE)
+reg.register_schedule("__rsub_scalar__", _fschedule_broadcast)
+
+# mul_scalar
+reg.register_compute("__mul_scalar__",
+                     _compute_binary_scalar(lambda x, y: x * y))
+reg.register_pattern("__mul_scalar__", OpPattern.ELEM_WISE)
+reg.register_schedule("__mul_scalar__", _fschedule_broadcast)
+
+# div_scalar
+reg.register_compute("__div_scalar__",
+                     _compute_binary_scalar(lambda x, y: x / y))
+reg.register_pattern("__div_scalar__", OpPattern.ELEM_WISE)
+reg.register_schedule("__div_scalar__", _fschedule_broadcast)
+
+# rdiv_scalar
+reg.register_compute("__rdiv_scalar__",
+                     _compute_binary_scalar(lambda x, y: y / x))
+reg.register_pattern("__rdiv_scalar__", OpPattern.ELEM_WISE)
+reg.register_schedule("__rdiv_scalar__", _fschedule_broadcast)
+
+# elemwise_add
+reg.register_compute("elemwise_add", _compute_binary(topi.broadcast_add))
+reg.register_pattern("elemwise_add", OpPattern.BROADCAST)
+reg.register_schedule("elemwise_add", _fschedule_broadcast)
+
+# elemwise_sub
+reg.register_compute("elemwise_sub", _compute_binary(topi.broadcast_sub))
+reg.register_pattern("elemwise_sub", OpPattern.BROADCAST)
+reg.register_schedule("elemwise_sub", _fschedule_broadcast)
+
+# elemwise_mul
+reg.register_compute("elemwise_mul", _compute_binary(topi.broadcast_mul))
+reg.register_pattern("elemwise_mul", OpPattern.BROADCAST)
+reg.register_schedule("elemwise_mul", _fschedule_broadcast)
+
+# elemwise_div
+reg.register_compute("elemwise_div", _compute_binary(topi.broadcast_div))
+reg.register_pattern("elemwise_div", OpPattern.BROADCAST)
+reg.register_schedule("elemwise_div", _fschedule_broadcast)
+
 # broadcast_add
-reg.register_compute("broadcast_add",
-                     lambda _, x: topi.broadcast_add(x[0], x[1]))
+reg.register_compute("broadcast_add", _compute_binary(topi.broadcast_add))
 reg.register_pattern("broadcast_add", OpPattern.BROADCAST)
 reg.register_schedule("broadcast_add", _fschedule_broadcast)
 
 # broadcast_sub
-reg.register_compute("broadcast_sub",
-                     lambda _, x: topi.broadcast_sub(x[0], x[1]))
+reg.register_compute("broadcast_sub", _compute_binary(topi.broadcast_sub))
 reg.register_pattern("broadcast_sub", OpPattern.BROADCAST)
 reg.register_schedule("broadcast_sub", _fschedule_broadcast)
 
 # broadcast_mul
-reg.register_compute("broadcast_mul",
-                     lambda _, x: topi.broadcast_mul(x[0], x[1]))
+reg.register_compute("broadcast_mul", _compute_binary(topi.broadcast_mul))
 reg.register_pattern("broadcast_mul", OpPattern.BROADCAST)
 reg.register_schedule("broadcast_mul", _fschedule_broadcast)
 
 # broadcast_div
-reg.register_compute("broadcast_div",
-                     lambda _, x: topi.broadcast_div(x[0], x[1]))
+reg.register_compute("broadcast_div", _compute_binary(topi.broadcast_div))
 reg.register_pattern("broadcast_div", OpPattern.BROADCAST)
 reg.register_schedule("broadcast_div", _fschedule_broadcast)
+
+# broadcast_to
+@reg.register_compute("broadcast_to")
+def compute_softmax(attrs, inputs, out_info):
+    """Compute definition of softmax"""
+    return topi.broadcast_to(inputs[0], shape=out_info[0].shape)
+reg.register_pattern("broadcast_to", OpPattern.BROADCAST)
+reg.register_schedule("broadcast_to", _fschedule_broadcast)
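_compute_unary and _compute_binary only adapt a topi function to the new three-argument compute signature. For instance, _compute_unary(topi.exp) behaves like the sketch below (the function name is illustrative):

import topi

def _exp_compute(attrs, inputs, out_info):
    # attrs and out_info are ignored; the topi op is applied to the first input
    return topi.exp(inputs[0])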
@@ -0,0 +1,31 @@
+# pylint: disable=invalid-name, unused-argument
+"""Tensor transformation ops"""
+from __future__ import absolute_import
+
+import tvm
+from .tensor import _fschedule_broadcast
+from ..compiler import registry as reg
+from ..compiler import OpPattern
+
+# Need add reshape, transpose
+
+def _flatten_index(indices, shape):
+    """flatten the index to 1D"""
+    idx = 0
+    for i, value in enumerate(shape):
+        if i != 0:
+            idx *= value
+        idx = idx + indices[i]
+    return idx
+
+# reshape
+@reg.register_compute("reshape")
+def compute_reshape(attrs, inputs, out_info):
+    """Compute definition of softmax"""
+    # TODO(sxj) add support for general reshape
+    assert len(inputs[0].shape) == 1, "Only support 1d input for now"
+    oshape = out_info[0].shape
+    x = inputs[0]
+    return tvm.compute(oshape, lambda *i: x(_flatten_index(i, oshape)))
+reg.register_pattern("reshape", OpPattern.COMPLEX)
+reg.register_schedule("reshape", _fschedule_broadcast)
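_flatten_index computes a row-major linear offset, which is how the 1-D input is addressed with the multi-dimensional output index. A NumPy check of the same arithmetic (the shape and index are illustrative):

import numpy as np

shape = (2, 3, 4)
index = (1, 2, 3)
offset = 0
for i, value in enumerate(shape):
    if i != 0:
        offset *= value
    offset = offset + index[i]
# row-major (C-order) offset: 1*12 + 2*4 + 3 == 23
assert offset == np.ravel_multi_index(index, shape)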
@@ -261,7 +261,7 @@ nnvm::Graph GraphFuse(nnvm::Graph g) {
     if (inode.source->is_variable()) continue;
     int root_id = group_vec[nid];
     FuseEntry& fe = fuse_vec[root_id];
-    Array<Tensor> inputs;
+    Array<Tensor> inputs, out_info;
     // input loading
     for (const auto& e : inode.inputs) {
       if (group_vec[e.node_id] != root_id) {
@@ -274,11 +274,21 @@ nnvm::Graph GraphFuse(nnvm::Graph g) {
         inputs.push_back(t);
       }
     }
+    // output hint
+    for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) {
+      Array<Expr> shape;
+      for (int64_t x : shape_vec[idx.entry_id(nid, i)]) {
+        CHECK_LE(x, static_cast<int64_t>(std::numeric_limits<int>::max()));
+        shape.push_back(make_const(Int(32), x));
+      }
+      out_info.push_back(
+          placeholder(shape,
+                      TVMType2Type(dltype_vec[idx.entry_id(nid, i)])));
+    }
     // get default
     Array<Tensor> out = fcompute[inode.source->op()](
-        inode.source->attrs, inputs);
+        inode.source->attrs, inputs, out_info);
     CHECK_EQ(out.size(), inode.source->num_outputs());
 
     // schedule on root node, and use master's schedule
     if (nid != root_id) {
       for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) {
@@ -312,6 +322,7 @@ nnvm::Graph GraphFuse(nnvm::Graph g) {
       }
     }
   }
+
   tvm::runtime::Module module = fbuild(funcs, target);
   // Final step: Remap the node, with given attribute
   const nnvm::Op* tvm_op = nnvm::Op::Get("tvm_op");
@@ -67,9 +67,11 @@ TVM_REGISTER_GLOBAL("nnvm._register_compute")
     // Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown
     PackedFunc* f = new PackedFunc(args[1].operator PackedFunc());
     Op& op = ::dmlc::Registry<nnvm::Op>::Get()->__REGISTER_OR_GET__(args[0]);
-    auto fcompute = [f](const NodeAttrs& attrs, const Array<Tensor>& inputs)
+    auto fcompute = [f](const NodeAttrs& attrs,
+                        const Array<Tensor>& inputs,
+                        const Array<Tensor>& out_info)
         -> Array<Tensor> {
-      TVMRetValue ret = (*f)(GetAttrDict(attrs), inputs);
+      TVMRetValue ret = (*f)(GetAttrDict(attrs), inputs, out_info);
       if ((*ret.ptr<std::shared_ptr<tvm::Node> >())->derived_from<tvm::TensorNode>()) {
         return {ret.operator Tensor()};
       } else {
@@ -21,7 +21,7 @@ BatchNormToInferUnpack(const nnvm::NodeAttrs& attrs,
                        nnvm::NodeEntry beta,
                        nnvm::NodeEntry moving_mean,
                        nnvm::NodeEntry moving_var,
-                       int data_dim) {
+                       TShape dshape) {
   CHECK(attrs.op);
   static const Op* bn_op = Op::Get("batch_norm");
   CHECK(attrs.op == bn_op);
@@ -57,19 +57,12 @@ BatchNormToInferUnpack(const nnvm::NodeAttrs& attrs,
     shift = MakeNode(
         "elemwise_add", bn_name + "_add_beta", {shift, beta});
   }
-  // reshape to nhwc
+  // use broaodcast to reshape
   std::ostringstream oshape;
-  oshape << "(";
-  for (int i = 0; i < data_dim; ++i) {
-    if (i != 0) oshape << ", ";
-    if (i == param.axis) {
-      oshape << "-1";
-    } else {
-      oshape << "1";
-    }
+  for (dim_t i = 0; i < dshape.ndim(); ++i) {
+    dshape[i] = (i != param.axis) ? 1 : -1;
   }
-  oshape << ")";
+  oshape << dshape;
 
   scale = MakeNode("reshape", bn_name + "_sc_reshape",
                    {scale}, {{"shape", oshape.str()}});
   shift = MakeNode("reshape", bn_name + "_sh_reshape",
@@ -98,7 +91,7 @@ Graph SimplifyBatchNormInference(nnvm::Graph src) {
           n->inputs[2],
           n->inputs[3],
          n->inputs[4],
-          shape_vec[idx.entry_id(nid, 0)].ndim());
+          shape_vec[idx.entry_id(nid, 0)]);
       return true;
     } else {
       return false;
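At inference time batch_norm collapses into one broadcast multiply and one broadcast add: scale = gamma / sqrt(var + eps) and shift = beta - mean * scale. A NumPy sketch of the algebraic identity the pass relies on, mirroring the formula in the tests below (shapes follow test_batchnorm):

import numpy as np

x = np.random.uniform(size=(10, 20))
gamma, beta, mean, var = (np.random.uniform(size=(20,)) for _ in range(4))
eps = 1e-5

bn = (x - mean) / np.sqrt(var + eps) * gamma + beta  # batch_norm at inference
scale = gamma / np.sqrt(var + eps)
shift = beta - mean * scale
np.testing.assert_allclose(bn, x * scale + shift, rtol=1e-6)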
@@ -73,7 +73,7 @@ void PrintGraphIR_(Graph src,
     AttrPrinter fp = GetVectorPrinter(src, key);
     auto fprint = [&idx, key, fp](
         uint32_t nid, std::ostream& os) {  // NOLINT(*)
-      os << key << "=";
+      os << ", " << key << "=";
       fp(idx.entry_id(nid, 0), os);
     };
     trigger.push_back(fprint);
@@ -5,13 +5,13 @@ from nnvm.compiler import graph_util, graph_attr
 
 def test_simplify_batchnorm():
     def simple_bn(x, gamma, beta, moving_mean, moving_var,
-                  axis=1, epsilon=1e-5, dim=2):
+                  axis=1, epsilon=1e-5, shape=None):
         # expect = (x - moving_mean) / sym.sqrt(moving_var + eps) * gamma + beta
         scale = sym.elemwise_mul(1 / sym.sqrt(moving_var + epsilon), gamma)
         shift = sym.elemwise_add(
             sym.elemwise_mul(sym.negative(moving_mean), scale), beta)
+        shape = [-1 if i == axis else 1 for i in range(len(shape))]
         # for 2D
-        shape = tuple(1 if i != axis else -1 for i in range(dim))
         scale = sym.reshape(scale, shape=shape)
         shift = sym.reshape(shift, shape=shape)
         return x * scale + shift
@@ -26,15 +26,14 @@ def test_simplify_batchnorm():
     moving_var = sym.Variable("moving_var")
     moving_mean = sym.Variable("moving_mean")
     y1, y2 = x, x
+    ishape = {"x": tuple(10 for i in range(dim))}
     for i in range(nstep):
         y1 = sym.batch_norm(
             y1 + 1, gamma, beta, moving_mean, moving_var, epsilon=eps, axis=axis)
         y2 = simple_bn(y2 + 1, gamma, beta, moving_mean, moving_var,
-                       epsilon=eps, axis=axis, dim=dim)
+                       epsilon=eps, axis=axis, shape=ishape["x"])
     g = nnvm.graph.create(y1)
     g2 = nnvm.graph.create(y2)
-    ishape = {"x": tuple(10 for i in range(dim))}
     graph_attr.set_shape_inputs(g, ishape)
     g1 = g.apply("InferShape").apply("SimplifyBatchNormInference")
     # Some prints for debug
@@ -6,19 +6,10 @@ import nnvm.symbol as sym
 import nnvm.compiler
 import nnvm.runtime
 
-USE_GPU=True
-
-def default_target():
-    if USE_GPU:
-        return 'cuda'
-    else:
-        return 'llvm'
-
-def default_ctx():
-    if USE_GPU:
-        return tvm.gpu(0)
-    else:
-        return tvm.cpu(0)
+def ctx_list():
+    res = [("llvm", tvm.cpu(0)), ("cuda", tvm.gpu(0))]
+    return [x for x in res if x[1].exist]
 
 def test_relu():
     x = sym.Variable("x")
@@ -26,20 +17,21 @@ def test_relu():
     dtype = "float32"
     dshape = (1, 3, 32, 32)
     oshape = dshape
-    graph, lib, _ = nnvm.compiler.build(y, default_target(), {"x": dshape})
-    m = nnvm.runtime.create(graph, lib, default_ctx())
-    # get member functions
-    set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
-    # set input
-    data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
-    set_input("x", data)
-    # execute
-    run()
-    # get output
-    out = tvm.nd.empty(oshape, dtype)
-    get_output(0, out)
-    y_np = np.maximum(data.asnumpy(), 0.0)
-    np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)
+    for target, ctx in ctx_list():
+        graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
+        m = nnvm.runtime.create(graph, lib, ctx)
+        # get member functions
+        set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
+        # set input
+        data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
+        set_input("x", data)
+        # execute
+        run()
+        # get output
+        out = tvm.nd.empty(oshape, dtype)
+        get_output(0, out)
+        y_np = np.maximum(data.asnumpy(), 0.0)
+        np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)
 
 
 def test_exp():
@@ -48,20 +40,21 @@ def test_exp():
     dtype = "float32"
     dshape = (1, 3, 32, 32)
     oshape = dshape
-    graph, lib, _ = nnvm.compiler.build(y, default_target(), {"x": dshape})
-    m = nnvm.runtime.create(graph, lib, default_ctx())
-    # get member functions
-    set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
-    # set input
-    data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
-    set_input("x", data)
-    # execute
-    run()
-    # get output
-    out = tvm.nd.empty(oshape, dtype)
-    get_output(0, out)
-    y_np = np.exp(data.asnumpy())
-    np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)
+    for target, ctx in ctx_list():
+        graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
+        m = nnvm.runtime.create(graph, lib, ctx)
+        # get member functions
+        set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
+        # set input
+        data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
+        set_input("x", data)
+        # execute
+        run()
+        # get output
+        out = tvm.nd.empty(oshape, dtype)
+        get_output(0, out)
+        y_np = np.exp(data.asnumpy())
+        np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)
 
 
 def test_log():
@@ -70,21 +63,22 @@ def test_log():
     dtype = "float32"
     dshape = (1, 3, 32, 32)
     oshape = dshape
-    with nnvm.compiler.build_config(opt_level=1):
-        graph, lib, _ = nnvm.compiler.build(y, default_target(), {"x": dshape})
-    m = nnvm.runtime.create(graph, lib, default_ctx())
-    # get member functions
-    set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
-    # set input
-    data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
-    set_input("x", data)
-    # execute
-    run()
-    # get output
-    out = tvm.nd.empty(oshape, dtype)
-    get_output(0, out)
-    y_np = np.log(data.asnumpy())
-    np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)
+    for target, ctx in ctx_list():
+        with nnvm.compiler.build_config(opt_level=1):
+            graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
+        m = nnvm.runtime.create(graph, lib, ctx)
+        # get member functions
+        set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
+        # set input
+        data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
+        set_input("x", data)
+        # execute
+        run()
+        # get output
+        out = tvm.nd.empty(oshape, dtype)
+        get_output(0, out)
+        y_np = np.log(data.asnumpy())
+        np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)
 
 
 def test_tanh():
@@ -93,21 +87,22 @@ def test_tanh():
     dtype = "float32"
     dshape = (1, 3, 32, 32)
     oshape = dshape
-    with nnvm.compiler.build_config(opt_level=1):
-        graph, lib, _ = nnvm.compiler.build(y, default_target(), {"x": dshape})
-    m = nnvm.runtime.create(graph, lib, default_ctx())
-    # get member functions
-    set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
-    # set input
-    data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
-    set_input("x", data)
-    # execute
-    run()
-    # get output
-    out = tvm.nd.empty(oshape, dtype)
-    get_output(0, out)
-    y_np = np.sinh(data.asnumpy()) / np.cosh(data.asnumpy())
-    np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)
+    for target, ctx in ctx_list():
+        with nnvm.compiler.build_config(opt_level=1):
+            graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
+        m = nnvm.runtime.create(graph, lib, ctx)
+        # get member functions
+        set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
+        # set input
+        data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
+        set_input("x", data)
+        # execute
+        run()
+        # get output
+        out = tvm.nd.empty(oshape, dtype)
+        get_output(0, out)
+        y_np = np.sinh(data.asnumpy()) / np.cosh(data.asnumpy())
+        np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)
 
 
 def test_sigmoid():
@@ -116,20 +111,21 @@ def test_sigmoid():
     dtype = "float32"
     dshape = (1, 3, 32, 32)
    oshape = dshape
-    graph, lib, _ = nnvm.compiler.build(y, default_target(), {"x": dshape})
-    m = nnvm.runtime.create(graph, lib, default_ctx())
-    # get member functions
-    set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
-    # set input
-    data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
-    set_input("x", data)
-    # execute
-    run()
-    # get output
-    out = tvm.nd.empty(oshape, dtype)
-    get_output(0, out)
-    y_np = 1.0 / (1.0 + np.exp(-data.asnumpy()))
-    np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)
+    for target, ctx in ctx_list():
+        graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
+        m = nnvm.runtime.create(graph, lib, ctx)
+        # get member functions
+        set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
+        # set input
+        data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
+        set_input("x", data)
+        # execute
+        run()
+        # get output
+        out = tvm.nd.empty(oshape, dtype)
+        get_output(0, out)
+        y_np = 1.0 / (1.0 + np.exp(-data.asnumpy()))
+        np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)
 
 
 def test_softmax():
@@ -138,24 +134,79 @@ def test_softmax():
     dtype = "float32"
     dshape = (10, 1000)
     oshape = dshape
-    with nnvm.compiler.build_config(opt_level=1):
-        graph, lib, _ = nnvm.compiler.build(y, default_target(), {"x": dshape})
-    m = nnvm.runtime.create(graph, lib, default_ctx())
-    # get member functions
-    set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
-    # set input
-    data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
-    set_input("x", data)
-    # execute
-    run()
-    # get output
-    out = tvm.nd.empty(oshape, dtype)
-    get_output(0, out)
-    y_np = topi.testing.softmax_python(data.asnumpy())
-    np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)
+    for target, ctx in ctx_list():
+        with nnvm.compiler.build_config(opt_level=1):
+            graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
+        m = nnvm.runtime.create(graph, lib, ctx)
+        # get member functions
+        set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
+        # set input
+        data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
+        set_input("x", data)
+        # execute
+        run()
+        # get output
+        out = tvm.nd.empty(oshape, dtype)
+        get_output(0, out)
+        y_np = topi.testing.softmax_python(data.asnumpy())
+        np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)
 
 
+def test_dense():
+    x = sym.Variable("x")
+    y = sym.dense(x, units=3, name="dense")
+    y = sym.flatten(y)
+    dtype = "float32"
+    shape = {
+        "x" : (10, 100),
+        "dense_weight" : (3, 100),
+        "dense_bias" : (3,),
+    }
+    graph, lib, _ = nnvm.compiler.build(y, "llvm", shape)
+    m = nnvm.runtime.create(graph, lib, tvm.cpu(0))
+    x_np = np.random.uniform(size=shape["x"]).astype(dtype)
+    w_np = np.random.uniform(size=shape["dense_weight"]).astype(dtype)
+    b_np = np.random.uniform(size=shape["dense_bias"]).astype(dtype)
+    res = tvm.nd.empty((10, 3))
+    m.run(x=x_np, dense_weight=w_np, dense_bias=b_np)
+    m.get_output(0, res)
+    res_np = np.dot(x_np, w_np.T) + b_np
+    np.testing.assert_allclose(
+        res.asnumpy(), res_np, atol=1e-5, rtol=1e-5)
+
+
+def test_batchnorm():
+    x = sym.Variable("x")
+    beta = sym.Variable("beta")
+    gamma = sym.Variable("gamma")
+    moving_var = sym.Variable("moving_var")
+    moving_mean = sym.Variable("moving_mean")
+    shape = (10, 20)
+    eps = 1e-5
+    dtype = "float32"
+    y = sym.batch_norm(
+        x, gamma, beta, moving_mean, moving_var, epsilon=eps)
+
+    for target, ctx in ctx_list():
+        graph, lib, _ = nnvm.compiler.build(y, "llvm", {"x": shape})
+        m = nnvm.runtime.create(graph, lib, tvm.cpu(0))
+        x_np = np.random.uniform(size=shape).astype(dtype)
+        mean_np = np.random.uniform(size=shape[1]).astype(dtype)
+        var_np = np.random.uniform(size=shape[1]).astype(dtype)
+        gamma_np = np.random.uniform(size=shape[1]).astype(dtype)
+        beta_np = np.random.uniform(size=shape[1]).astype(dtype)
+        res = tvm.nd.empty(shape)
+        m.run(x=x_np, moving_mean=mean_np, moving_var=var_np,
+              gamma=gamma_np, beta=beta_np)
+        m.get_output(0, res)
+        res_np = (x_np - mean_np) / np.sqrt(var_np + eps) * gamma_np + beta_np
+        np.testing.assert_allclose(
+            res.asnumpy(), res_np, atol=1e-5, rtol=1e-5)
+
+
 if __name__ == "__main__":
+    test_batchnorm()
+    test_dense()
     test_relu()
     test_exp()
     test_log()
@@ -6,19 +6,9 @@ import nnvm.symbol as sym
 import nnvm.compiler
 import nnvm.runtime
 
-USE_GPU=True
-
-def default_target():
-    if USE_GPU:
-        return 'cuda'
-    else:
-        return 'llvm'
-
-def default_ctx():
-    if USE_GPU:
-        return tvm.gpu(0)
-    else:
-        return tvm.cpu(0)
+def ctx_list():
+    res = [("llvm", tvm.cpu(0)), ("cuda", tvm.gpu(0))]
+    return [x for x in res if x[1].exist]
 
 def test_conv2d():
     x = sym.Variable("x")
@@ -29,23 +19,24 @@ def test_conv2d():
     kshape = (10, 3, 3, 3)
     oshape = (1, 10, 18, 18)
     shape_dict = {"x": dshape}
-    graph, lib, _ = nnvm.compiler.build(y, default_target(), shape_dict)
-    m = nnvm.runtime.create(graph, lib, default_ctx())
-    # get member functions
-    set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
-    # set input
-    data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
-    kernel = tvm.nd.array(np.random.uniform(size=kshape).astype(dtype))
-    set_input("x", data)
-    set_input("y_weight", kernel)
-    # execute
-    run()
-    # get output
-    out = tvm.nd.empty(oshape, dtype)
-    get_output(0, out)
-    c_np = topi.testing.conv2d_nchw_python(
-        data.asnumpy(), kernel.asnumpy(), 1, 1)
-    np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)
+    for target, ctx in ctx_list():
+        graph, lib, _ = nnvm.compiler.build(y, target, shape_dict)
+        m = nnvm.runtime.create(graph, lib, ctx)
+        # get member functions
+        set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
+        # set input
+        data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
+        kernel = tvm.nd.array(np.random.uniform(size=kshape).astype(dtype))
+        set_input("x", data)
+        set_input("y_weight", kernel)
+        # execute
+        run()
+        # get output
+        out = tvm.nd.empty(oshape, dtype)
+        get_output(0, out)
+        c_np = topi.testing.conv2d_nchw_python(
+            data.asnumpy(), kernel.asnumpy(), 1, 1)
+        np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)
 
 
 def test_grouped_conv2d():
@@ -57,23 +48,24 @@ def test_grouped_conv2d():
     kshape = (32, 1, 3, 3)
     oshape = (1, 32, 18, 18)
     shape_dict = {"x": dshape}
-    graph, lib, _ = nnvm.compiler.build(y, default_target(), shape_dict)
-    m = nnvm.runtime.create(graph, lib, default_ctx())
-    # get member functions
-    set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
-    # set input
-    data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
-    kernel = tvm.nd.array(np.random.uniform(size=kshape).astype(dtype))
-    set_input("x", data)
-    set_input("y_weight", kernel)
-    # execute
-    run()
-    # get output
-    out = tvm.nd.empty(oshape, dtype)
-    get_output(0, out)
-    c_np = topi.testing.depthwise_conv2d_python_nchw(
-        data.asnumpy(), kernel.asnumpy(), (1,1), 'SAME')
-    np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)
+    for target, ctx in ctx_list():
+        graph, lib, _ = nnvm.compiler.build(y, target, shape_dict)
+        m = nnvm.runtime.create(graph, lib, ctx)
+        # get member functions
+        set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
+        # set input
+        data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
+        kernel = tvm.nd.array(np.random.uniform(size=kshape).astype(dtype))
+        set_input("x", data)
+        set_input("y_weight", kernel)
+        # execute
+        run()
+        # get output
+        out = tvm.nd.empty(oshape, dtype)
+        get_output(0, out)
+        c_np = topi.testing.depthwise_conv2d_python_nchw(
+            data.asnumpy(), kernel.asnumpy(), (1,1), 'SAME')
+        np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)
 
 
 if __name__ == "__main__":