This commit is contained in:
Yao Wang 2018-03-29 09:27:37 -07:00 коммит произвёл Tianqi Chen
Родитель bd40bcd1a2
Коммит f4789db696
18 изменённых файлов: 1042 добавлений и 134 удалений

Просмотреть файл

@ -28,7 +28,6 @@ This level enables fully connected multi-layer perceptron.
:nosignatures:
nnvm.symbol.dense
nnvm.symbol.matmul
nnvm.symbol.relu
nnvm.symbol.tanh
nnvm.symbol.sigmoid
@ -40,12 +39,6 @@ This level enables fully connected multi-layer perceptron.
nnvm.symbol.elemwise_mul
nnvm.symbol.elemwise_div
nnvm.symbol.elemwise_sum
nnvm.symbol.full
nnvm.symbol.full_like
nnvm.symbol.ones
nnvm.symbol.ones_like
nnvm.symbol.zeros
nnvm.symbol.zeros_like
nnvm.symbol.flatten
nnvm.symbol.concatenate
nnvm.symbol.expand_dims
@ -57,7 +50,6 @@ This level enables fully connected multi-layer perceptron.
nnvm.symbol.log_softmax
nnvm.symbol.pad
nnvm.symbol.block_grad
nnvm.symbol.indicator
**Level 2: Convolutions**
@ -81,8 +73,6 @@ This level enables typical convnet models.
:nosignatures:
nnvm.symbol.reshape
nnvm.symbol.reshape_like
nnvm.symbol.expand_like
nnvm.symbol.copy
nnvm.symbol.negative
nnvm.symbol.leaky_relu
@ -109,11 +99,21 @@ This level enables typical convnet models.
nnvm.symbol.broadcast_sub
nnvm.symbol.broadcast_mul
nnvm.symbol.broadcast_div
nnvm.symbol.clip
nnvm.symbol.greater
nnvm.symbol.less
nnvm.symbol.expand_like
nnvm.symbol.reshape_like
nnvm.symbol.full
nnvm.symbol.full_like
nnvm.symbol.ones
nnvm.symbol.ones_like
nnvm.symbol.zeros
nnvm.symbol.zeros_like
Detailed Definitions
--------------------
.. autofunction:: nnvm.symbol.dense
.. autofunction:: nnvm.symbol.matmul
.. autofunction:: nnvm.symbol.relu
.. autofunction:: nnvm.symbol.tanh
.. autofunction:: nnvm.symbol.sigmoid
@ -125,12 +125,6 @@ Detailed Definitions
.. autofunction:: nnvm.symbol.elemwise_mul
.. autofunction:: nnvm.symbol.elemwise_div
.. autofunction:: nnvm.symbol.elemwise_sum
.. autofunction:: nnvm.symbol.full
.. autofunction:: nnvm.symbol.full_like
.. autofunction:: nnvm.symbol.ones
.. autofunction:: nnvm.symbol.ones_like
.. autofunction:: nnvm.symbol.zeros
.. autofunction:: nnvm.symbol.zeros_like
.. autofunction:: nnvm.symbol.flatten
.. autofunction:: nnvm.symbol.concatenate
.. autofunction:: nnvm.symbol.expand_dims
@ -142,7 +136,6 @@ Detailed Definitions
.. autofunction:: nnvm.symbol.log_softmax
.. autofunction:: nnvm.symbol.pad
.. autofunction:: nnvm.symbol.block_grad
.. autofunction:: nnvm.symbol.indicator
.. autofunction:: nnvm.symbol.conv2d
.. autofunction:: nnvm.symbol.conv2d_transpose
@ -152,8 +145,6 @@ Detailed Definitions
.. autofunction:: nnvm.symbol.global_avg_pool2d
.. autofunction:: nnvm.symbol.reshape
.. autofunction:: nnvm.symbol.reshape_like
.. autofunction:: nnvm.symbol.expand_like
.. autofunction:: nnvm.symbol.copy
.. autofunction:: nnvm.symbol.negative
.. autofunction:: nnvm.symbol.leaky_relu
@ -175,3 +166,14 @@ Detailed Definitions
.. autofunction:: nnvm.symbol.broadcast_sub
.. autofunction:: nnvm.symbol.broadcast_mul
.. autofunction:: nnvm.symbol.broadcast_div
.. autofunction:: nnvm.symbol.clip
.. autofunction:: nnvm.symbol.greater
.. autofunction:: nnvm.symbol.less
.. autofunction:: nnvm.symbol.expand_like
.. autofunction:: nnvm.symbol.reshape_like
.. autofunction:: nnvm.symbol.full
.. autofunction:: nnvm.symbol.full_like
.. autofunction:: nnvm.symbol.ones
.. autofunction:: nnvm.symbol.ones_like
.. autofunction:: nnvm.symbol.zeros
.. autofunction:: nnvm.symbol.zeros_like

Просмотреть файл

@ -241,6 +241,16 @@ struct MatMulParam : public dmlc::Parameter<MatMulParam> {
}
};
struct ClipParam : public dmlc::Parameter<ClipParam> {
double a_min, a_max;
DMLC_DECLARE_PARAMETER(ClipParam) {
DMLC_DECLARE_FIELD(a_min)
.describe("Minimum value such that value smaller then this will be clipped.");
DMLC_DECLARE_FIELD(a_max)
.describe("Maximum value such that value larger then this will be clipped.");
}
};
} // namespace top
} // namespace nnvm

Просмотреть файл

@ -54,6 +54,9 @@ OpHandle = ctypes.c_void_p
SymbolHandle = ctypes.c_void_p
GraphHandle = ctypes.c_void_p
# Global dict of str to symbol to initialize variables
_all_var_init = {}
#----------------------------
# helper function definition
#----------------------------

Просмотреть файл

@ -4,9 +4,12 @@ from __future__ import absolute_import as _abs
import logging
import tvm
from tvm.contrib import graph_runtime
from . import graph_attr, graph_util
from .. import graph as _graph
from .. import symbol as sym
from .._base import _all_var_init
OPT_PASS_LEVEL = {
"SimplifyInference": 0,
@ -201,6 +204,9 @@ def build(graph, target=None, shape=None, dtype="float32", params=None, target_h
By default, llvm is used if it is enabled,
otherwise a stackvm intepreter is used.
initialize : bool, optional
Whether to initialize variables in global dict _all_var_init.
Returns
-------
graph : Graph
@ -230,6 +236,10 @@ def build(graph, target=None, shape=None, dtype="float32", params=None, target_h
if not isinstance(dtype, str):
idtype, _ = graph_util.infer_dtype(graph, **dtype)
dtype.update(zip(graph.index.input_names, idtype))
# Initialize all variables specified in _all_var_init
init_var = {}
if _all_var_init:
init_var = initialize_variables(shape, dtype)
# Apply optimization
graph = optimize(graph, shape, dtype)
# Precompute prune
@ -250,6 +260,11 @@ def build(graph, target=None, shape=None, dtype="float32", params=None, target_h
with target:
graph = graph.apply("GraphFusePartition").apply("GraphFuseCompile")
libmod = graph_attr._move_out_module(graph, "module")
# Write variable initial values into params
if init_var:
if params is None:
params = {}
params.update(init_var)
return graph, libmod, params
@ -329,3 +344,45 @@ def precompute_prune(graph, params):
with tvm.build_config(auto_unroll_max_step=0):
out_arrs = _run_graph(pre_graph, params)
return graph, dict(zip(out_names, out_arrs))
def initialize_variables(ishape, idtype):
""" Initialize variables stored in _all_var_init dictionary.
Parameters
----------
ishape : dict of str to tuple of int
The input shape to the graph
idtype : str or dict of str to str
The input types to the graph
Returns
-------
init_var : dict of str to tvm.ndarray
"""
symbol_init_dict = {}
const_init_dict = {}
init_var = {}
for key, value in _all_var_init.items():
if isinstance(value, sym.Symbol):
symbol_init_dict[key] = value
else:
const_init_dict[key] = tvm.nd.array(value)
# Make sure variables are initialized only once.
_all_var_init.clear()
if symbol_init_dict:
# Create dummy params to run initialization graph
params = {}
for name, shape in ishape.items():
dtype = idtype if isinstance(idtype, str) else idtype[name]
params[name] = tvm.nd.empty(shape, dtype, ctx=tvm.cpu())
init_group_sym = sym.Group(symbol_init_dict.values())
graph = _graph.create(init_group_sym)
with tvm.build_config(auto_unroll_max_step=0):
init_values = _run_graph(graph, params)
init_var.update(dict(zip(symbol_init_dict.keys(), init_values)))
init_var.update(const_init_dict)
for name, data in init_var.items():
ishape[name] = data.shape
return init_var

Просмотреть файл

@ -0,0 +1,58 @@
# pylint: disable=too-few-public-methods, no-member
"""API for scheduling learning rate."""
from .. import symbol as sym
class LRScheduler(object):
"""Base class of a learning rate scheduler.
A scheduler returns a new learning rate based on the number of updates that have
been performed.
Parameters
----------
base_lr : float, optional
The initial learning rate.
"""
def __init__(self, base_lr=0.01, name='LRScheduler'):
self.name = name
self.base_lr = base_lr
def __call__(self, num_update):
"""Return a new learning rate based on number of updates.
Parameters
----------
num_update: nnvm Symbol
the number of updates applied to weight.
"""
raise NotImplementedError("__call__ method must be overridden.")
class FactorScheduler(LRScheduler):
"""Reduce the learning rate by a factor for every *n* steps.
It returns a new learning rate by::
base_lr * pow(factor, num_update/step)
Parameters
----------
step : int
Changes the learning rate for every n updates.
factor : float, optional
The factor to change the learning rate.
stop_factor_lr : float, optional
Stop updating the learning rate if it is less than this value.
"""
def __init__(self, step, factor=1, stop_factor_lr=1e-8, name='FactorScheduler', **kwargs):
super(FactorScheduler, self).__init__(name=name, **kwargs)
if step < 1:
raise ValueError("Schedule step must be greater or equal than 1 round")
if factor > 1.0:
raise ValueError("Factor must be no more than 1 to make lr reduce")
self.step = step
self.factor = factor
self.stop_factor_lr = stop_factor_lr
def __call__(self, num_update):
updated_lr = self.base_lr * self.factor ** (num_update / self.step)
return sym.clip(updated_lr, a_min=self.stop_factor_lr, a_max=self.base_lr)

Просмотреть файл

@ -0,0 +1,131 @@
# pylint: disable=invalid-name, no-member, too-few-public-methods, too-many-arguments, too-many-locals, protected-access
"""Optimizer API"""
from . import graph_util
from .. import symbol as sym
class Optimizer(object):
"""Base class inherited by all optimizers.
Parameters
----------
learning_rate : float, optional
The initial learning rate.
lr_scheduler : LRScheduler, optional
The learning rate scheduler.
rescale_grad : float, optional
Multiply the gradient with `rescale_grad` before updating. Often
choose to be ``1.0/batch_size``.
clip_gradient : float, optional
Clip the gradient by projecting onto the box ``[-clip_gradient, clip_gradient]``.
wd : float, optional
The weight decay (or L2 regularization) coefficient. Modifies objective
by adding a penalty for having large weights.
name : string, optional
The name of optimizer.
"""
def __init__(self, learning_rate=0.01, lr_scheduler=None,
rescale_grad=1, clip_gradient=None, wd=0, name="Optimizer"):
self.name = name
self.lr = learning_rate
self.lr_scheduler = lr_scheduler
self.rescale_grad = rescale_grad
self.clip_gradient = clip_gradient
self.wd = wd
init_update_t = sym.Variable(name+'_t', init=sym.zeros(shape=(1,), dtype="int32"))
self.update_t = sym._assign(init_update_t, init_update_t + 1)
def minimize(self, obj, var=None):
"""Minimize given obj symbol respect to var. If var is not set, all input
variables of obj will be used.
Parameters
----------
obj : nnvm Symbol or list of nnvm Symbols
Symbols to be minimized.
var : nnvm Symbol or list of nnvm Symbols, optional
Symbols the gradient respect to.
Returns
-------
group_sym : nnvm Symbol
Group symbol represents update symbols.
"""
raise NotImplementedError()
def _get_lr(self):
"""Gets the learning rate with learning rate scheduler.
Returns
-------
lr : float
Learning rate.
"""
if self.lr_scheduler is not None:
lr = self.lr_scheduler(self.update_t)
else:
lr = self.lr
return lr
class SGD(Optimizer):
"""The SGD optimizer
"""
def __init__(self, name='SGD', **kwargs):
super(SGD, self).__init__(name=name, **kwargs)
def minimize(self, obj, var=None):
variables = var or obj.list_input_variables()
if not isinstance(variables, list):
variables = [variables]
grads = graph_util.gradients(obj, variables)
updates = []
lr_t = self._get_lr()
for v, g in zip(variables, grads):
g = self.rescale_grad * g
if self.clip_gradient is not None:
g = sym.clip(g, a_min=-1 * self.clip_gradient, a_max=self.clip_gradient)
updates.append(sym._assign(v, v - lr_t * (g + self.wd * v)))
return sym.Group(updates)
class Adam(Optimizer):
"""The Adam optimizer.
This class implements the optimizer described in *Adam: A Method for
Stochastic Optimization*, available at http://arxiv.org/abs/1412.6980.
"""
def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999,
epsilon=1e-8, name='Adam', **kwargs):
super(Adam, self).__init__(learning_rate=learning_rate, name=name, **kwargs)
self.beta1 = beta1
self.beta2 = beta2
self.epsilon = epsilon
self.m = []
self.v = []
def minimize(self, obj, var=None):
variables = var or obj.list_input_variables()
if not isinstance(variables, list):
variables = [variables]
grads = graph_util.gradients(obj, variables)
updates = []
for i, v in enumerate(variables):
self.m.append(sym.Variable(self.name + '_m' + str(i), init=sym.zeros_like(v)))
self.v.append(sym.Variable(self.name + '_v' + str(i), init=sym.zeros_like(v)))
rate = sym.sqrt(1 - self.beta2 ** self.update_t) / (1 - self.beta1 ** self.update_t)
lr_t = self._get_lr() * rate
for variable, g, m, v in zip(variables, grads, self.m, self.v):
g = self.rescale_grad * g
if self.clip_gradient is not None:
g = sym.clip(g, a_min=-1 * self.clip_gradient, a_max=self.clip_gradient)
update_m = sym._assign(m, self.beta1 * m + (1 - self.beta1) * g)
update_v = sym._assign(v, self.beta2 * v + (1 - self.beta2) * g * g)
update_var = sym._assign(variable, variable - lr_t * (update_m / (sym.sqrt(update_v) \
+ self.epsilon) + self.wd * variable))
updates.append(update_var)
return sym.Group(updates)

Просмотреть файл

@ -1,4 +1,4 @@
# pylint: disable=invalid-name, unused-import
# pylint: disable=invalid-name, unused-import, protected-access
"""Symbolic graph construction API.
This namespace contains most of the registered operators.
@ -8,10 +8,12 @@ from __future__ import absolute_import as _abs
import sys as _sys
import os as _os
import ctypes as _ctypes
from numbers import Number as _Number
import numpy as np
from . import _base
from ._base import _LIB, check_call as _check_call, _FFI_MODE
from ._base import _LIB, check_call as _check_call, _FFI_MODE, _all_var_init
from .attribute import AttrScope
from . import _symbol_internal as _internal
@ -309,13 +311,19 @@ class Symbol(SymbolBase):
self.handle, deps.handle))
def Variable(name, **kwargs):
def Variable(name, init=None, **kwargs):
"""Create a symbolic variable with specified name.
Parameters
----------
name : str
Name of the variable.
init : Symbol or numpy.ndarray
Symbol or numpy ndarray of initial value for the variable.
Note that for symbolic initialization value, it must be able
to be defined through InferShape, such as sym.zeros_like(v),
in which v is an input or parameter. Otherwise, pass a numpy
ndarray instead.
kwargs : dict of string -> string
Additional attributes to set on the variable.
@ -333,6 +341,11 @@ def Variable(name, **kwargs):
attr = AttrScope.current.get(kwargs)
if attr:
ret._set_attr(**attr)
if init is not None:
if not isinstance(init, (Symbol, np.ndarray)):
raise TypeError('Expect a Symbol or numpy ndarray'
'for variable `init`')
_all_var_init[name] = init
return ret

Просмотреть файл

@ -123,6 +123,21 @@ class AttrDict(object):
else:
raise ValueError("Wrong bool format for key %s" % key)
def get_string(self, key):
"""Get string from attr dict
Parameters
----------
key : str
The attr key
Returns
-------
value : str
The result value
"""
return self[key]
def __repr__(self):
return str({k : self[k] for k in self.keys()})

Просмотреть файл

@ -143,3 +143,95 @@ reg.register_schedule("broadcast_div", _fschedule_broadcast)
# broadcast_to
reg.register_pattern("broadcast_to", OpPattern.BROADCAST)
reg.register_schedule("broadcast_to", _fschedule_broadcast)
# clip
reg.register_pattern("clip", OpPattern.ELEMWISE)
reg.register_schedule("clip", _fschedule_elemwise)
# elemwise sum
@reg.register_compute("elemwise_sum")
def compute_elemwise_sum(attrs, inputs, _):
"""Compute definition of elemwise sum"""
num_args = attrs.get_int("num_args")
assert num_args == len(inputs), "Number of tensors does not match num_args."
return topi.tensor.elemwise_sum(inputs, num_args)
reg.register_pattern("elemwise_sum", OpPattern.ELEMWISE)
reg.register_schedule("elemwise_sum", _fschedule_elemwise)
# full
@reg.register_compute("full")
def compute_full(attrs, inputs, _):
"""Compute definition of full"""
shape = attrs.get_int_tuple("shape")
dtype = attrs.get_string("dtype")
fill_value = attrs.get_float("fill_value")
return topi.tensor.full(shape, dtype, fill_value)
reg.register_pattern("full", OpPattern.OUT_ELEMWISE_FUSABLE)
reg.register_schedule("full", _fschedule_elemwise)
# full_like
@reg.register_compute("full_like")
def compute_full_like(attrs, inputs, _):
"""Compute definition of full_like"""
fill_value = attrs.get_float("fill_value")
return topi.tensor.full_like(inputs[0], fill_value)
reg.register_pattern("full_like", OpPattern.ELEMWISE)
reg.register_schedule("full_like", _fschedule_elemwise)
# zeros
@reg.register_compute("zeros")
def compute_zeros(attrs, inputs, _):
"""Compute definition of zeros"""
shape = attrs.get_int_tuple("shape")
dtype = attrs.get_string("dtype")
return topi.tensor.full(shape, dtype, 0)
reg.register_pattern("zeros", OpPattern.OUT_ELEMWISE_FUSABLE)
reg.register_schedule("zeros", _fschedule_elemwise)
# zeros_like
@reg.register_compute("zeros_like")
def compute_zeros_like(_, inputs, out_info):
"""Compute definition of zeros_like"""
return topi.tensor.full_like(inputs[0], 0)
reg.register_pattern("zeros_like", OpPattern.ELEMWISE)
reg.register_schedule("zeros_like", _fschedule_elemwise)
# ones
@reg.register_compute("ones")
def compute_ones(attrs, inputs, _):
"""Compute definition of ones"""
shape = attrs.get_int_tuple("shape")
dtype = attrs.get_string("dtype")
#tvm.tensor.Tensor()
return topi.tensor.full(shape, dtype, 1)
reg.register_pattern("ones", OpPattern.OUT_ELEMWISE_FUSABLE)
reg.register_schedule("ones", _fschedule_elemwise)
# ones_like
@reg.register_compute("ones_like")
def compute_ones_like(_, inputs, out_info):
"""Compute definition of ones_like"""
return topi.tensor.full_like(inputs[0], 1)
reg.register_pattern("ones_like", OpPattern.ELEMWISE)
reg.register_schedule("ones_like", _fschedule_elemwise)
# greater
@reg.register_compute("greater")
def compute_greater(_, inputs, out_info):
"""Compute definition of greater"""
return topi.tensor.greater(inputs[0], inputs[1], 'float32')
reg.register_pattern("greater", OpPattern.ELEMWISE)
reg.register_schedule("greater", _fschedule_elemwise)
# less
@reg.register_compute("less")
def compute_less(_, inputs, out_info):
"""Compute definition of less"""
return topi.tensor.less(inputs[0], inputs[1], 'float32')
reg.register_pattern("less", OpPattern.ELEMWISE)
reg.register_schedule("less", _fschedule_elemwise)
# block_grad
reg.register_compute("block_grad", _compute_unary(topi.identity))
reg.register_pattern("block_grad", OpPattern.ELEMWISE)
reg.register_schedule("block_grad", _fschedule_elemwise)

Просмотреть файл

@ -2,6 +2,7 @@
"""Tensor transformation ops"""
from __future__ import absolute_import
import topi
from .tensor import _fschedule_broadcast, _fschedule_injective
from . import registry as reg
from .registry import OpPattern
@ -10,6 +11,32 @@ from .registry import OpPattern
reg.register_pattern("expand_dims", OpPattern.BROADCAST)
reg.register_schedule("expand_dims", _fschedule_broadcast)
# expand_like
@reg.register_compute("expand_like")
def compute_expand_like(attrs, inputs, _):
"""Compute definition of expand_like"""
exclude = attrs.get_bool("exclude")
axis = attrs.get_int_tuple("axis")
if exclude:
exclude_axis = (axis,) if isinstance(axis, int) else axis
axis = []
for item in range(len(inputs[1].shape)):
if item not in exclude_axis:
axis.append(item)
axis = tuple(axis)
return topi.transform.expand_like(inputs[0], inputs[1], axis)
reg.register_pattern("expand_like", OpPattern.BROADCAST)
reg.register_schedule("expand_like", _fschedule_broadcast)
# reshape_like
@reg.register_compute("reshape_like")
def compute_reshape_like(attrs, inputs, out_info):
"""Compute definition of reshape_like"""
return topi.reshape(inputs[0], inputs[1].shape)
reg.register_pattern("reshape_like", OpPattern.INJECTIVE)
reg.register_schedule("reshape_like", _fschedule_injective)
# transpose
reg.register_pattern("transpose", OpPattern.INJECTIVE)
reg.register_schedule("transpose", _fschedule_injective)

Просмотреть файл

@ -130,15 +130,14 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(relu)
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) {
// y = relu(x)
// grad = indicator(x > 0)
NodeEntry zero = MakeNode("zeros_like", n->attrs.name + "_grad_zero",
// grad = indicator(x > 0) * ograd
NodeEntry sub0 = MakeNode("zeros_like", n->attrs.name + "_sub0",
{n->inputs[0]});
NodeEntry sub1 = MakeNode("greater", n->attrs.name + "_sub1",
{n->inputs[0], sub0}, {{"exclude", "true"}});
return std::vector<NodeEntry>{
MakeNode("elemwise_mul", n->attrs.name + "_grad", {
ograds[0],
MakeNode("greater", n->attrs.name + "_grad_mask",
{n->inputs[0], zero}, {{"exclude", "true"}})
})
MakeNode("elemwise_mul", n->attrs.name + "_grad",
{ograds[0], sub1})
};
})
.set_support_level(1);
@ -358,23 +357,21 @@ NNVM_REGISTER_OP(log_softmax)
// grad_x = sum(grad_x, keepdim, axis)
// grad_x = neg grad_x
// grad_x = grad_x + ones_like(grad_x)
// grad_x = expand_dims(grad_x, axis)
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(n->attrs.parsed);
NodeEntry output = NodeEntry{n, 0, 0};
NodeEntry sub0 = MakeNode("elemwise_mul", n->attrs.name + "_grad_sub0", {ograds[0], output});
NodeEntry sub1 = MakeNode("sum", n->attrs.name + "_grad_sub1", {sub0},
{{"axis", std::to_string(param.axis)}, {"keepdims", "true"}});
NodeEntry sub2 = MakeNode("negative", n->attrs.name + "_grad_sub2", {sub1});
NodeEntry sub3 = MakeNode("ones_like", n->attrs.name + "_grad_sub3", {sub2});
NodeEntry sub4 = MakeNode("elemwise_add", n->attrs.name + "_grad_sub4", {sub2, sub3});
NodeEntry sub2 = MakeNode("full_like", n->attrs.name + "_grad_sub2", {n->inputs[0]},
{{"fill_value", "-1"}});
NodeEntry sub3 = MakeNode("broadcast_mul", n->attrs.name + "_grad_sub3", {sub1, sub2});
return std::vector<NodeEntry> {
MakeNode("expand_like", n->attrs.name + "_grad", {sub4, output},
{{"axis", std::to_string(param.axis)}})
MakeNode("elemwise_add", n->attrs.name + "_grad", {sub3, ograds[0]})
};
})
.set_support_level(1);
// leaky_rlu
// leaky_relu
DMLC_REGISTER_PARAMETER(LeakyReLUParam);
NNVM_REGISTER_OP(leaky_relu)
@ -407,14 +404,15 @@ NNVM_REGISTER_OP(leaky_relu)
NodeEntry zero = MakeNode("zeros_like", n->attrs.name + "_grad_zero",
{n->inputs[0]});
NodeEntry sub0 = MakeNode("greater", n->attrs.name + "_pos_grad",
{n->inputs[0], zero}, {{"exclude", "true"}});
{n->inputs[0], zero});
NodeEntry sub1 = MakeNode("less", n->attrs.name + "_neg_grad",
{n->inputs[0], zero}, {{"exclude", "true"}});
{n->inputs[0], zero});
NodeEntry sub2 = MakeNode("__mul_scalar__", n->attrs.name + "_neg_mul_2",
{sub1},
{{"scalar", std::to_string(param.alpha)}});
NodeEntry sub3 = MakeNode("elemwise_add", n->attrs.name + "_sub3", {sub0, sub2});
return std::vector<NodeEntry>{
MakeNode("elemwise_add", n->attrs.name + "_add_grad", {sub0, sub2})
MakeNode("elemwise_mul", n->attrs.name + "_grad", {ograds[0], sub3})
};
})
.set_support_level(1);

Просмотреть файл

@ -190,7 +190,10 @@ NNVM_REGISTER_ELEMWISE_BINARY_OP(elemwise_add)
// y = n0 + n1
// grad_0 = grad_y
// grad_1 = grad_y
return std::vector<NodeEntry>{ograds[0], ograds[0]};
return std::vector<NodeEntry>{ MakeNode("copy", n->attrs.name + "_grad_0",
{ograds[0]}),
MakeNode("copy", n->attrs.name + "_grad_0",
{ograds[0]}) };
});
NNVM_REGISTER_ELEMWISE_BINARY_OP(elemwise_sub)
@ -311,7 +314,8 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(copy)
const std::vector<NodeEntry>& ograds){
// y = copy(n0)
// grad_0 = grad_y
return std::vector<NodeEntry>{ograds[0]};
return std::vector<NodeEntry>{ MakeNode("copy", n->attrs.name + "_grad_0",
{ograds[0]}) };
});
DMLC_REGISTER_PARAMETER(InitOpParam);
@ -329,7 +333,7 @@ NNVM_REGISTER_INIT_OP(full)
.add_arguments(InitOpWithScalarParam::__FIELDS__())
.set_attr<FInferShape>("FInferShape", ZeroShape<InitOpWithScalarParam>)
.set_attr<FInferType>("FInferType", ZeroType<InitOpWithScalarParam>)
.set_support_level(1);
.set_support_level(4);
NNVM_REGISTER_INIT_OP(zeros)
.describe(R"code(Fill target with zeros
@ -341,7 +345,7 @@ NNVM_REGISTER_INIT_OP(zeros)
.add_arguments(InitOpParam::__FIELDS__())
.set_attr<FInferShape>("FInferShape", ZeroShape<InitOpParam>)
.set_attr<FInferType>("FInferType", ZeroType<InitOpParam>)
.set_support_level(1);
.set_support_level(4);
NNVM_REGISTER_INIT_OP(ones)
.describe(R"code(Fill target with ones
@ -353,7 +357,7 @@ NNVM_REGISTER_INIT_OP(ones)
.add_arguments(InitOpParam::__FIELDS__())
.set_attr<FInferShape>("FInferShape", ZeroShape<InitOpParam>)
.set_attr<FInferType>("FInferType", ZeroType<InitOpParam>)
.set_support_level(1);
.set_support_level(4);
// full_like
NNVM_REGISTER_INIT_LIKE_OP(full_like)
@ -364,21 +368,21 @@ as the input array
.add_arguments(FillValueParam::__FIELDS__())
.set_attr_parser(ParamParser<FillValueParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<FillValueParam>)
.set_support_level(1);
.set_support_level(4);
NNVM_REGISTER_INIT_LIKE_OP(zeros_like)
.describe(R"code(Return an array of zeros with the same shape and type
as the input array.
)code")
.set_support_level(1);
.set_support_level(4);
NNVM_REGISTER_INIT_LIKE_OP(ones_like)
.describe(R"code(Return an array of ones with the same shape and type
as the input array.
)code")
.set_support_level(1);
.set_support_level(4);
// unary scalar op
DMLC_REGISTER_PARAMETER(ScalarParam);
@ -415,7 +419,8 @@ NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__add_scalar__)
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds){
return std::vector<NodeEntry>{ograds[0]};
return std::vector<NodeEntry>{ MakeNode("copy", n->attrs.name + "_grad_0",
{ograds[0]}) };
});
NNVM_REGISTER_ELEMWISE_BINARY_SCALAR(__sub_scalar__)
@ -601,10 +606,11 @@ NNVM_REGISTER_ELEMWISE_REDUCE_OP(elemwise_sum)
CHECK_EQ(ograds.size(), 1);
std::vector<NodeEntry> ret;
for (size_t i = 0; i < n->inputs.size(); i++) {
ret.push_back(ograds[0]);
ret.push_back(MakeNode("copy", n->attrs.name + "_grad_0", {ograds[0]}));
}
return ret;
});
})
.set_support_level(4);
NNVM_REGISTER_ELEMWISE_UNARY_OP(block_grad)
.describe(R"code(Blocks gradient computation for input.
@ -614,7 +620,8 @@ NNVM_REGISTER_ELEMWISE_UNARY_OP(block_grad)
"FInplaceIdentity", [](const NodeAttrs& attrs){
return std::vector<bool>{true};
})
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.set_support_level(4);
DMLC_REGISTER_PARAMETER(IndicatorParam);
@ -628,7 +635,7 @@ with 1.0 if (left > right), otherwise 0.0 element-wise.
.add_argument("rhs", "Tensor", "Second input")
.set_num_inputs(2)
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<2, 1>)
.set_support_level(1);
.set_support_level(4);
NNVM_REGISTER_INDICATOR_OP(less)
@ -640,7 +647,7 @@ with 1.0 if (left < right), otherwise 0.0 element-wise.
.add_argument("rhs", "Tensor", "Second input")
.set_num_inputs(2)
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<2, 1>)
.set_support_level(1);
.set_support_level(4);
NNVM_REGISTER_INDICATOR_OP(_max_mask)
.describe(R"code(Function that returns a mask tensor
@ -668,5 +675,73 @@ with 1.0 if the value is minimum over given axes, otherwise 0.0 element-wise.
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_support_level(1);
DMLC_REGISTER_PARAMETER(ClipParam);
NNVM_REGISTER_OP(clip)
.describe(R"doc(Clips (limits) the values in an array.
Given an interval, values outside the interval are clipped to the interval edges.
Clipping ``x`` between `a_min` and `a_x` would be::
clip(x, a_min, a_max) = max(min(x, a_max), a_min))
Example::
x = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
clip(x,1,8) = [ 1., 1., 2., 3., 4., 5., 6., 7., 8., 8.]
)doc" NNVM_ADD_FILELINE)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser(ParamParser<ClipParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ClipParam>)
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const ClipParam params = get<ClipParam>(attrs.parsed);
return Array<Tensor>{
topi::clip(inputs[0], tvm::make_const(tvm::Float(32), params.a_min),
tvm::make_const(tvm::Float(32), params.a_max)) };
})
.add_argument("data", "NDArray-or-Symbol", "Input array.")
.add_arguments(ClipParam::__FIELDS__())
.set_attr<nnvm::FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds){
// y = clip(x, a_min, a_max)
// min_mask = greater_equal(x, a_min*ones_like(x))
// => ones_like(x) - less(x, a_min)
// max_mask = less_equal(x, a_max*ones_like(x))
// => ones_like(x) - greater(x, a_max)
// grad_x = min_mask * max_mask * grad_y
CHECK_EQ(ograds.size(), 1);
NodeEntry sub0 = MakeNode("ones_like", n->attrs.name + "_grad_sub_0",
{n->inputs[0]});
// min_mask
NodeEntry sub1 = MakeNode("__mul_scalar__", n->attrs.name + "_grad_sub_1",
{sub0}, {{"scalar", n->attrs.dict["a_min"]}});
NodeEntry sub2 = MakeNode("less", n->attrs.name + "_grad_sub_2",
{n->inputs[0], sub1});
NodeEntry sub3 = MakeNode("elemwise_sub", n->attrs.name + "_grad_sub_3",
{sub0, sub2});
// max_mask
NodeEntry sub4 = MakeNode("__mul_scalar__", n->attrs.name + "_grad_sub_4",
{sub0}, {{"scalar", n->attrs.dict["a_max"]}});
NodeEntry sub5 = MakeNode("greater", n->attrs.name + "_grad_sub_5",
{n->inputs[0], sub4});
NodeEntry sub6 = MakeNode("elemwise_sub", n->attrs.name + "_grad_sub_6",
{sub0, sub5});
// min_mask * max_mask
NodeEntry sub7 = MakeNode("elemwise_mul", n->attrs.name + "_grad_sub_7",
{sub3, sub6});
return std::vector<NodeEntry>{
MakeNode("elemwise_mul", n->attrs.name + "_grad",
{sub7, ograds[0]})
};
})
.set_support_level(4);
} // namespace top
} // namespace nnvm

Просмотреть файл

@ -137,7 +137,20 @@ Example::
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed);
auto axis = ShapeToArray(param.axis);
Array<Expr> axis;
if (param.exclude) {
std::set<dim_t> exclude_axis;
for (dim_t i = 0; i < param.axis.ndim(); ++i) {
exclude_axis.insert(param.axis[i]);
}
for (dim_t i = 0; i < inputs[0].ndim(); ++i) {
if (exclude_axis.count(i) == 0) {
axis.push_back(make_const(Int(32), i));
}
}
} else {
axis = ShapeToArray(param.axis);
}
return Array<Tensor>{
topi::sum(inputs[0], axis, param.keepdims) };
})
@ -150,7 +163,6 @@ Example::
MakeNode("expand_like", n->attrs.name + "_grad",
{ograds[0], n->inputs[0]},
{{"axis", axis.str()},
{"keepdims", std::to_string(param.keepdims)},
{"exclude", std::to_string(param.exclude)}})
};
});

Просмотреть файл

@ -48,6 +48,15 @@ This is an experimental operator.
.set_attr<FInplaceOption>(
"FInplaceOption", [](const NodeAttrs& attrs) {
return std::vector<std::pair<int, int> >{{1, 0}};
})
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds){
return std::vector<NodeEntry>{
MakeNode("zeros_like", n->attrs.name + "_zero_grad",
{n->inputs[0]}),
ograds[0]
};
});
} // namespace top

Просмотреть файл

@ -229,29 +229,24 @@ will return a new array with shape ``(2,5,3,4)``.
NNVM_REGISTER_OP(expand_like)
.describe(R"code(Expand an input array with the shape of second array.
This operation can always be composed of unsqueezing and expanding dims.
Examples::
input = [ 12. 19. 27.]
input.shape = (3,)
new_shape_array = [[[1,2],[2,3],[1,3]],
[[1,4],[4,3],[5,2]],
[[7,1],[7,2],[7,3]]]
new_shape_array.shape = (3, 3, 2)
expand_like(input, [1,2], new_shape_array) =
[[[12,12],[12,12],[12,12]],
[[19,19],[19,19],[19,19]],
[[27,27],[27,27],[27,27]]]
)code" NNVM_ADD_FILELINE)
.add_argument("input", "Tensor", "Source input")
.add_argument("shape_like", "Tensor", "Input with new shape")
.add_arguments(ReduceParam::__FIELDS__())
.set_attr_parser(ParamParser<ReduceParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ReduceParam>)
.add_arguments(IndicatorParam::__FIELDS__())
.set_attr_parser(ParamParser<IndicatorParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<IndicatorParam>)
.set_attr<nnvm::FInferShape>("FInferShape", AssignOutputAttr<TShape, 1, 0>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
.set_num_inputs(2)
@ -259,7 +254,7 @@ Examples::
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) {
const ReduceParam& param = nnvm::get<ReduceParam>(n->attrs.parsed);
const IndicatorParam& param = nnvm::get<IndicatorParam>(n->attrs.parsed);
std::ostringstream axis;
axis << param.axis;
@ -267,11 +262,11 @@ Examples::
MakeNode("sum", n->attrs.name + "_grad",
{ograds[0]},
{{"axis", axis.str()},
{"keepdims", std::to_string(param.keepdims)},
{"exclude", std::to_string(param.exclude)}})
{"exclude", std::to_string(param.exclude)}}),
MakeNode("zeros_like", n->attrs.name + "_zero_grad", {n->inputs[1]})
};
})
.set_support_level(1);
})
.set_support_level(4);
// split
DMLC_REGISTER_PARAMETER(SplitParam);
@ -564,13 +559,10 @@ The significance of each is explained below:
NNVM_REGISTER_OP(reshape_like)
.describe(R"code(Reshapes the input array by the size of another array.
For an input array with shape ``(d1, d2, ..., dk)``, `reshape_like` operation reshapes
the input array into an output array with the same shape as the second input array.
.. note::
Sizes for both array should be compatible.
)code" NNVM_ADD_FILELINE)
.add_argument("data", "Tensor", "Input data.")
.add_argument("shape_like", "Tensor", "Input data.")
@ -589,10 +581,12 @@ the input array into an output array with the same shape as the second input arr
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) {
return MakeGradNode("reshape_like", n,
{ograds[0], n->inputs[0]});
return std::vector<NodeEntry>{
MakeNode("reshape_like", n->attrs.name + "_grad", {ograds[0], n->inputs[0]}),
MakeNode("zeros_like", n->attrs.name + "_zero_grad", { n->inputs[1]})
};
})
.set_support_level(3);
.set_support_level(4);
// squeeze
DMLC_REGISTER_PARAMETER(SqueezeParam);
@ -680,7 +674,8 @@ Examples::
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) {
return std::vector<NodeEntry>{
MakeNode("reshape_like", n->attrs.name + "_grad", {n->inputs[0]})
MakeNode("reshape_like", n->attrs.name + "_grad",
{ograds[0], n->inputs[0]})
};
})
.set_support_level(1);

Просмотреть файл

@ -0,0 +1,118 @@
import numpy as np
import tvm
import nnvm
import nnvm.compiler.optimizer as optimizer
import nnvm.compiler.lr_scheduler as lr_scheduler
from nnvm.testing.config import ctx_list
from tvm.contrib import graph_runtime
def helper(symbol, inputs, params, update_func, run_times, target, ctx, dtype="float32"):
ishapes = {}
np_inputs = {}
params_dict = {}
for (name, shape, s) in inputs:
ishapes.update({name: shape})
np_inputs.update({name: np.random.uniform(size=shape).astype(dtype)})
for (name, shape, s) in params:
np_inputs.update({name: np.random.uniform(size=shape).astype(dtype)})
params_dict.update({name: np_inputs[name]})
graph, lib, rt_params = nnvm.compiler.build(symbol, target, shape=ishapes)
m = graph_runtime.create(graph, lib, ctx)
m.set_input(**np_inputs)
m.set_input(**rt_params)
for _ in range(run_times):
m.run()
y_np = update_func(**np_inputs)
out = m.get_output(0, tvm.nd.empty(y_np.shape, dtype))
np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)
def test_sgd():
for target, ctx in ctx_list():
data = nnvm.sym.Variable("data")
weight = nnvm.sym.Variable("weight")
out = nnvm.sym.elemwise_mul(data, weight ** 2)
dshape = (1, 2, 3)
wshape = dshape
base_lr = 0.1
lr_factor = 0.5
rescale_grad = 0.2
wd = 0.1
clip_gradient = 0.25
scheduler = lr_scheduler.FactorScheduler(base_lr=base_lr, step=1, factor=lr_factor)
opt = optimizer.SGD(learning_rate=base_lr, lr_scheduler=scheduler,
rescale_grad=rescale_grad, clip_gradient=clip_gradient,
wd=wd)
opt_sym = opt.minimize(out, var=weight)
inputs = [("data", dshape, data)]
params = [("weight", wshape, weight)]
def update_func(data, weight):
gradient_0 = data * 2 * weight * rescale_grad
gradient_0 = np.clip(gradient_0, -clip_gradient, clip_gradient)
weight_0 = weight - base_lr * lr_factor * (gradient_0 + wd * weight)
gradient_1 = data * 2 * weight_0 * rescale_grad
gradient_1 = np.clip(gradient_1, -clip_gradient, clip_gradient)
weight_1 = weight_0 - base_lr * (lr_factor ** 2) * (gradient_1 + wd * weight_0)
return weight_1
helper(opt_sym, inputs, params, update_func, 2, target, ctx)
def test_adam():
for target, ctx in ctx_list():
data = nnvm.sym.Variable("data")
weight = nnvm.sym.Variable("weight")
out = nnvm.sym.elemwise_mul(data, weight ** 2)
dshape = (1, 2, 3)
wshape = dshape
base_lr = 0.1
beta1 = 0.9
beta2 = 0.999
epsilon = 1e-8
lr_factor = 0.5
rescale_grad = 0.2
wd = 0.1
clip_gradient = 0.25
scheduler = lr_scheduler.FactorScheduler(base_lr=base_lr, step=1, factor=lr_factor)
opt = optimizer.Adam(learning_rate=base_lr, beta1=beta1, beta2=beta2, epsilon=epsilon,
lr_scheduler=scheduler, rescale_grad=rescale_grad,
clip_gradient=clip_gradient, wd=wd)
opt_sym = opt.minimize(out, var=weight)
inputs = [("data", dshape, data)]
params = [("weight", wshape, weight)]
def update_func(data, weight):
rate_0 = np.sqrt(1 - beta2) / (1 - beta1)
lr_0 = base_lr * lr_factor * rate_0
gradient_0 = data * 2 * weight * rescale_grad
gradient_0 = np.clip(gradient_0, -clip_gradient, clip_gradient)
m_0 = (1 - beta1) * gradient_0
v_0 = (1 - beta2) * (gradient_0 ** 2)
weight_0 = weight - lr_0 * (m_0 / (np.sqrt(v_0) + epsilon) + wd * weight)
rate_1 = np.sqrt(1 - beta2 ** 2) / (1 - beta1 ** 2)
lr_1 = base_lr * (lr_factor ** 2) * rate_1
gradient_1 = data * 2 * weight_0 * rescale_grad
gradient_1 = np.clip(gradient_1, -clip_gradient, clip_gradient)
m_1 = beta1 * m_0 + (1 - beta1) * gradient_1
v_1 = beta2 * v_0 + (1 - beta2) * (gradient_1 ** 2)
weight_1 = weight_0 - lr_1 * (m_1 / (np.sqrt(v_1) + epsilon) + wd * weight_0)
return weight_1
helper(opt_sym, inputs, params, update_func, 2, target, ctx)
if __name__ == "__main__":
test_sgd()
test_adam()

Просмотреть файл

@ -8,15 +8,14 @@ from nnvm.testing.config import ctx_list
def helper(symbol, inputs, dtype,
np_forward, np_backward=None):
np_forward, np_backward=None, need_input=True, need_head_grads=True):
ishapes = {}
input_syms = []
np_inputs = {}
for (k, v) in inputs.items():
ishapes.update({k: v[0]})
np_inputs.update({k: np.random.uniform(size=v[0]).astype(dtype)})
if len(v) > 1:
input_syms.append(v[1])
for (name, shape, s) in inputs:
ishapes.update({name: shape})
np_inputs.update({name: np.random.uniform(size=shape).astype(dtype)})
input_syms.append(s)
for target, ctx in ctx_list():
graph, lib, _ = nnvm.compiler.build(symbol, target, ishapes)
@ -25,23 +24,26 @@ def helper(symbol, inputs, dtype,
y_np = np_forward(**np_inputs)
out = m.get_output(0, tvm.nd.empty(y_np.shape, dtype))
np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)
# backward
if np_backward:
graph._set_symbol_list_attr("grad_ys", symbol)
for x in input_syms:
graph._set_symbol_list_attr("grad_xs", x)
graph._set_symbol_list_attr("grad_ys_out_grad", sym.Variable("head_grads"))
graph._set_symbol_list_attr("grad_xs", input_syms)
graph._set_symbol_list_attr("grad_ys_out_grad", sym.Variable("head_grads", shape=y_np.shape))
graph = graph.apply("Gradient")
ishapes.update({"head_grads": y_np.shape})
graph, lib, _ = nnvm.compiler.build(graph, target, ishapes)
m = graph_runtime.create(graph, lib, ctx)
head_grads = np.random.uniform(size=y_np.shape).astype(dtype)
y_np = head_grads * np_backward(**np_inputs)
m.run(head_grads=head_grads, **np_inputs)
out = m.get_output(0, tvm.nd.empty(y_np.shape, dtype))
np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)
y_np = np_backward(head_grads=head_grads, **np_inputs)
b_inputs = {}
if need_input:
b_inputs.update(np_inputs)
if need_head_grads:
b_inputs.update({"head_grads":head_grads})
m.run(**b_inputs)
for i in range(len(y_np)):
out = m.get_output(i, tvm.nd.empty(y_np[i].shape, dtype))
np.testing.assert_allclose(out.asnumpy(), y_np[i], atol=1e-5, rtol=1e-5)
def test_relu():
@ -52,10 +54,15 @@ def test_relu():
x = (x < 0) * x * 0.3 + (x > 0) * x - 0.2
return (x > 0) * x
def backward(head_grads, x):
sub = (x < 0) * x * 0.3 + (x > 0) * x - 0.2
return [(sub > 0).astype("float") * \
((x > 0).astype("float") + 0.3 * (x < 0).astype("float")) * head_grads]
dtype = "float32"
dshape = (1, 3, 32, 32)
inputs = {'x': (dshape, x)}
helper(y, inputs, dtype, forward)
inputs = [('x', dshape, x)]
helper(y, inputs, dtype, forward, backward)
def test_sym_scalar_pow():
@ -66,12 +73,12 @@ def test_sym_scalar_pow():
def forward(x):
return x**scalar
def backward(x):
return scalar * x**(scalar - 1)
def backward(head_grads, x):
return [scalar * x**(scalar - 1) * head_grads]
dtype = "float32"
dshape = (1, 3, 32, 32)
inputs = {'x': (dshape, x)}
inputs = [('x', dshape, x)]
helper(y, inputs, dtype, forward, backward)
@ -83,12 +90,12 @@ def test_scalar_sym_pow():
def forward(x):
return scalar**x
def backward(x):
return np.log(scalar) * scalar**x
def backward(head_grads, x):
return [np.log(scalar) * scalar**x * head_grads]
dtype = "float32"
dshape = (1, 3, 32, 32)
inputs = {'x': (dshape, x)}
inputs = [('x', dshape, x)]
helper(y, inputs, dtype, forward, backward)
@ -99,12 +106,12 @@ def test_exp():
def forward(x):
return np.exp(x)
def backward(x):
return np.exp(x)
def backward(head_grads, x):
return [np.exp(x) * head_grads]
dtype = "float32"
dshape = (1, 3, 32, 32)
inputs = {'x': (dshape, x)}
inputs = [('x', dshape, x)]
helper(y, inputs, dtype, forward, backward)
@ -115,12 +122,12 @@ def test_log():
def forward(x):
return np.log(x)
def backward(x):
return 1. / x
def backward(head_grads, x):
return [1. / x * head_grads]
dtype = "float32"
dshape = (1, 3, 32, 32)
inputs = {'x': (dshape, x)}
inputs = [('x', dshape, x)]
helper(y, inputs, dtype, forward, backward)
@ -131,13 +138,13 @@ def test_tanh():
def forward(x):
return np.sinh(x) / np.cosh(x)
def backward(x):
def backward(head_grads, x):
y_np = forward(x)
return (1 - y_np**2)
return [(1 - y_np**2) * head_grads]
dtype = "float32"
dshape = (1, 3, 32, 32)
inputs = {'x': (dshape, x)}
inputs = [('x', dshape, x)]
helper(y, inputs, dtype, forward, backward)
@ -148,13 +155,13 @@ def test_sigmoid():
def forward(x):
return 1.0 / (1.0 + np.exp(-x))
def backward(x):
def backward(head_grads, x):
y_np = forward(x)
return y_np *(1 - y_np)
return [y_np *(1 - y_np) * head_grads]
dtype = "float32"
dshape = (1, 3, 32, 32)
inputs = {'x': (dshape, x)}
inputs = [('x', dshape, x)]
helper(y, inputs, dtype, forward, backward)
@ -165,10 +172,15 @@ def test_softmax():
def forward(x):
return topi.testing.softmax_python(x)
def backward(head_grads, x):
y = topi.testing.softmax_python(x)
grad = y * (head_grads - np.sum(y * head_grads, axis=1, keepdims=True))
return [grad]
dtype = "float32"
dshape = (10, 1000)
inputs = {'x': (dshape, x)}
helper(y, inputs, dtype, forward)
inputs = [('x', dshape, x)]
helper(y, inputs, dtype, forward), backward
def test_log_softmax():
@ -178,26 +190,32 @@ def test_log_softmax():
def forward(x):
return topi.testing.log_softmax_python(x)
def backward(head_grads, x):
y = topi.testing.log_softmax_python(x)
grad = head_grads - np.sum(y * head_grads, axis=1, keepdims=True)
return [grad]
dtype = "float32"
dshape = (10, 1000)
inputs = {'x': (dshape, x)}
helper(y, inputs, dtype, forward)
inputs = [('x', dshape, x)]
helper(y, inputs, dtype, forward, backward)
def test_dense():
x = sym.Variable("x")
y = sym.dense(x, units=3, name="dense")
x = sym.Variable("x", shape=(10, 100))
w = sym.Variable("dense_weight", shape=(3, 100))
b = sym.Variable("dense_bias", shape=(3,))
y = sym.dense(x, w, b, use_bias=True, units=3, name="dense")
y = sym.flatten(y)
def forward(x, dense_weight, dense_bias):
return np.dot(x, dense_weight.T) + dense_bias
dtype = "float32"
inputs = {
'x': ((10, 100), x),
'dense_weight': ((3, 100),),
'dense_bias': ((3,),)
}
inputs = [
('x', (10, 100), x),
('dense_weight', (3, 100), w),
('dense_bias', (3,), b)
]
helper(y, inputs, dtype, forward)
@ -215,13 +233,13 @@ def test_batchnorm():
return (x - moving_mean) / np.sqrt(moving_var + eps) * gamma + beta
dtype = "float32"
inputs = {
'x': ((10, 20), x),
'gamma': ((20,),),
'beta': ((20,),),
'moving_mean': ((20,),),
'moving_var': ((20,),)
}
inputs = [
('x', (10, 20), x),
('gamma', (20,), gamma),
('beta', (20,), beta),
('moving_mean', (20,), moving_var),
('moving_var', (20,), moving_mean)
]
helper(y, inputs, dtype, forward)
@ -283,9 +301,12 @@ def verify_squeeze(dshape, axis):
def forward(x):
return np.squeeze(x, axis=axis) + 1
def backward(head_grads, x):
return [np.reshape(head_grads, x.shape)]
dtype = "float32"
inputs = {'x': (dshape, x)}
helper(y, inputs, dtype, forward)
inputs = [('x', dshape, x)]
helper(y, inputs, dtype, forward, backward)
def test_squeeze():
@ -304,7 +325,7 @@ def test_pad():
mode='constant', constant_values=1.)
dtype = "float32"
inputs = {'x': ((1, 3, 28, 28), x)}
inputs = [('x', (1, 3, 28, 28), x)]
helper(y, inputs, dtype, forward)

Просмотреть файл

@ -6,6 +6,46 @@ import nnvm.symbol as sym
import nnvm.compiler
from nnvm.testing.config import ctx_list
def helper(symbol, inputs, dtype,
np_forward, np_backward=None, need_input=True, need_head_grads=True):
ishapes = {}
input_syms = []
np_inputs = {}
for (name, shape, s) in inputs:
ishapes.update({name: shape})
np_inputs.update({name: np.random.uniform(size=shape).astype(dtype)})
input_syms.append(s)
for target, ctx in ctx_list():
graph, lib, _ = nnvm.compiler.build(symbol, target, ishapes)
m = graph_runtime.create(graph, lib, ctx)
m.run(**np_inputs)
y_np = np_forward(**np_inputs)
out = m.get_output(0, tvm.nd.empty(y_np.shape, dtype))
np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)
# backward
if np_backward:
graph._set_symbol_list_attr("grad_ys", symbol)
graph._set_symbol_list_attr("grad_xs", input_syms)
graph._set_symbol_list_attr("grad_ys_out_grad", sym.Variable("head_grads", shape=y_np.shape))
graph = graph.apply("Gradient")
ishapes.update({"head_grads": y_np.shape})
graph, lib, _ = nnvm.compiler.build(graph, target, ishapes)
m = graph_runtime.create(graph, lib, ctx)
head_grads = np.random.uniform(size=y_np.shape).astype(dtype)
y_np = np_backward(head_grads=head_grads, **np_inputs)
b_inputs = {}
if need_input:
b_inputs.update(np_inputs)
if need_head_grads:
b_inputs.update({"head_grads":head_grads})
m.run(**b_inputs)
for i in range(len(y_np)):
out = m.get_output(i, tvm.nd.empty(y_np[i].shape, dtype))
np.testing.assert_allclose(out.asnumpy(), y_np[i], atol=1e-5, rtol=1e-5)
def verify_transpose(dshape, axes):
x = sym.Variable("x")
if axes:
@ -66,13 +106,245 @@ def verify_reshape(dshape, oshape):
out = m.get_output(0, tvm.nd.empty(out_np.shape))
np.testing.assert_allclose(out.asnumpy(), out_np, atol=1e-5, rtol=1e-5)
def test_reshape():
verify_reshape((2, 3, 4), (-1, 2, 1))
verify_reshape((2, 3, 4), (8, 3))
verify_reshape((4, 7), (2, 7, 2))
def test_clip():
x = sym.Variable("x")
a_min=0.2
a_max=0.75
y = sym.clip(x, a_min=a_min, a_max=a_max)
def forward(x):
return np.clip(x, a_min=a_min, a_max=a_max)
def backward(head_grads, x):
mask1 = np.greater_equal(x, a_min).astype("float")
mask2 = np.less_equal(x, a_max).astype("float")
return [head_grads * mask1 * mask2]
dtype = "float32"
inputs = [('x', (3, 4, 5), x)]
helper(y, inputs, dtype, forward, backward)
def test_greater():
l = sym.Variable("l")
r = sym.Variable("r")
y = sym.greater(l, r)
def forward(l, r):
return np.greater(l, r).astype("float32")
def backward(head_grads, l, r):
return [np.zeros_like(l)]
dtype = "float32"
inputs = [('l', (3, 4, 5), l),
('r', (3, 4, 5), r)]
helper(y, inputs, dtype, forward, backward, need_head_grads=False)
def test_less():
l = sym.Variable("l")
r = sym.Variable("r")
y = sym.less(l, r)
def forward(l, r):
return np.less(l, r).astype("float32")
def backward(head_grads, l, r):
return [np.zeros_like(l)]
dtype = "float32"
inputs = [('l', (3, 4, 5), l),
('r', (3, 4, 5), r)]
helper(y, inputs, dtype, forward, backward, need_head_grads=False)
def test_reshape_like():
x = sym.Variable("x")
y = sym.Variable("y")
z = sym.reshape_like(x, y)
def forward(x, y):
return np.reshape(x, y.shape)
def backward(head_grads, x, y):
return [np.reshape(head_grads, x.shape),
np.zeros_like(y)]
dtype = "float32"
inputs = [('x', (3, 4, 5), x),
('y', (5, 4, 3), y)]
helper(z, inputs, dtype, forward, backward)
def verify_expand_like(in_shape, out_shape, axis, exclude):
x = sym.Variable("x")
y = sym.Variable("y")
z = sym.expand_like(x, y, axis=axis, exclude=exclude)
def forward(x, y):
odim = len(out_shape)
real_axis = [i if i >= 0 else i + odim for i in axis]
real_axis = sorted(real_axis)
if exclude:
real_axis = list(set(range(odim)) - set(real_axis))
for i in real_axis:
x = np.expand_dims(x, i).astype(x.dtype)
for i in real_axis:
x = np.concatenate([x]*out_shape[i], axis=i).astype(x.dtype)
return x
def backward(head_grads, x, y):
odim = len(out_shape)
real_axis = [i if i >= 0 else i + odim for i in axis]
real_axis = sorted(real_axis)
if exclude:
real_axis = list(set(range(odim)) - set(real_axis))
return [np.sum(head_grads, axis=tuple(real_axis)),
np.zeros_like(y)]
dtype = "float32"
inputs = [('x', in_shape, x),
('y', out_shape, y)]
helper(z, inputs, dtype, forward, backward, need_input=False)
def test_expand_like():
verify_expand_like((3,), (3, 2), [1], False)
verify_expand_like((2,), (2, 3), [1], False)
verify_expand_like((3, 4), (3, 5, 4), [1], False)
verify_expand_like((5, 7), (5, 6, 7, 8), [0, 2], True)
def verify_elemwise_sum(num_args):
s = [sym.Variable("input" + str(i)) for i in range(num_args)]
y = sym.elemwise_sum(*s, num_args=num_args)
def forward(**inputs):
return np.sum(np.array(list(inputs.values())), axis=0)
def backward(head_grads, **inputs):
return [head_grads] * num_args
dtype = "float32"
inputs = [("input" + str(i), (3, 4, 5), s[i])
for i in range(num_args)]
helper(y, inputs, dtype, forward, backward, need_input=False)
def test_elemwise_sum():
verify_elemwise_sum(1)
verify_elemwise_sum(5)
verify_elemwise_sum(7)
def test_block_grad():
x = sym.Variable("x")
y = sym.block_grad(x)
def forward(x):
return x
def backward(head_grads, x):
return [np.zeros_like(head_grads)]
dtype = "float32"
inputs = [('x', (3, 4, 5), x)]
helper(y, inputs, dtype, forward, backward, need_head_grads=False)
def test_full():
shape = (3, 4, 5)
value = 7
dtype = "float32"
for target, ctx in ctx_list():
data = sym.Variable("data", dtype=dtype)
# full_like
s = sym.full_like(data=data, fill_value=value, name="s")
graph, lib, _ = nnvm.compiler.build(s, target, {"data": shape})
m = graph_runtime.create(graph, lib, ctx)
m.run(data=np.random.uniform(size=shape).astype(dtype))
out = m.get_output(0, tvm.nd.empty(shape, dtype=dtype))
np.testing.assert_allclose(
out.asnumpy(),
np.full(shape, fill_value=value, dtype=dtype),
atol=1e-5, rtol=1e-5)
# ones_like
s = sym.ones_like(data=data, fill_value=value, name="s")
graph, lib, _ = nnvm.compiler.build(s, target, {"data": shape})
m = graph_runtime.create(graph, lib, ctx)
m.run(data=np.random.uniform(size=shape).astype(dtype))
out = m.get_output(0, tvm.nd.empty(shape, dtype=dtype))
np.testing.assert_allclose(
out.asnumpy(),
np.full(shape, fill_value=1, dtype=dtype),
atol=1e-5, rtol=1e-5)
# zeros_like
s = sym.zeros_like(data=data, fill_value=value, name="s")
graph, lib, _ = nnvm.compiler.build(s, target, {"data": shape})
m = graph_runtime.create(graph, lib, ctx)
m.run(data=np.random.uniform(size=shape).astype(dtype))
out = m.get_output(0, tvm.nd.empty(shape, dtype=dtype))
np.testing.assert_allclose(
out.asnumpy(),
np.full(shape, fill_value=0, dtype=dtype),
atol=1e-5, rtol=1e-5)
# full
s = sym.full(shape=shape, dtype=dtype, fill_value=value, name="s")
graph, lib, _ = nnvm.compiler.build(s, target)
m = graph_runtime.create(graph, lib, ctx)
m.run()
out = m.get_output(0, tvm.nd.empty(shape, dtype=dtype))
np.testing.assert_allclose(
out.asnumpy(),
np.full(shape, fill_value=value, dtype=dtype),
atol=1e-5, rtol=1e-5)
# ones
s = sym.ones(shape=shape, dtype=dtype, name="s")
graph, lib, _ = nnvm.compiler.build(s, target)
m = graph_runtime.create(graph, lib, ctx)
m.run()
out = m.get_output(0, tvm.nd.empty(shape, dtype=dtype))
np.testing.assert_allclose(
out.asnumpy(),
np.full(shape, fill_value=1, dtype=dtype),
atol=1e-5, rtol=1e-5)
# zeros
s = sym.zeros(shape=shape, dtype=dtype, name="s")
graph, lib, _ = nnvm.compiler.build(s, target)
m = graph_runtime.create(graph, lib, ctx)
m.run()
out = m.get_output(0, tvm.nd.empty(shape, dtype=dtype))
np.testing.assert_allclose(
out.asnumpy(),
np.full(shape, fill_value=0, dtype=dtype),
atol=1e-5, rtol=1e-5)
if __name__ == "__main__":
test_reshape()
test_reduce()
test_tranpose()
test_clip()
test_greater()
test_less()
test_reshape_like()
test_expand_like()
test_elemwise_sum()
test_block_grad()
test_full()
print(nnvm.compiler.engine.dump())