Ivan Rodriguez 2016-09-29 17:03:57 +02:00
Parent 391432ca77
Commit a5edcb3a41
31 changed files with 1176 additions and 776 deletions

View file

@@ -23,6 +23,7 @@ from .cntk_py import DeviceDescriptor, momentums_per_sample
DATATYPE = np.float32
class Trainer(cntk_py.Trainer):
'''
Trainer to train the specified `model` with the specified `training_loss`
@@ -37,6 +38,7 @@ class Trainer(cntk_py.Trainer):
eval_function (`:class:cntk.ops.Function`): evaluation function
parameter_learners (`list`): list of learners from `:cntk:cntk.learners`
'''
def __init__(self, model, loss_function, eval_function, parameter_learners):
if isinstance(model, cntk_py.Variable):
model = model.owner
@@ -45,7 +47,7 @@ class Trainer(cntk_py.Trainer):
if isinstance(eval_function, cntk_py.Variable):
eval_function = eval_function.owner
super(Trainer, self).__init__(model, loss_function, eval_function,
parameter_learners)
parameter_learners)
def train_minibatch(self, arguments, device=None):
'''
@@ -63,7 +65,7 @@ class Trainer(cntk_py.Trainer):
`bool`: `True` if updates have been performed
'''
if not device:
device=DeviceDescriptor.use_default_device()
device = DeviceDescriptor.use_default_device()
arguments = sanitize_var_map(arguments, add_batch_axis=True)
return super(Trainer, self).train_minibatch(arguments, device)
@@ -85,8 +87,7 @@ class Trainer(cntk_py.Trainer):
tested minibatch.
'''
if not device:
device=DeviceDescriptor.use_default_device()
device = DeviceDescriptor.use_default_device()
arguments = sanitize_var_map(arguments, add_batch_axis=True)
return super(Trainer, self).test_minibatch(arguments, device)
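# Illustrative usage sketch of the two methods above (not part of this
# commit; `z`, `ce`, `pe`, `learner`, `x`, `y`, `features`, `labels` are
# hypothetical placeholders for a model, criterion functions, a learner,
# input variables, and minibatch data):
#
# trainer = Trainer(z, ce, pe, [learner])
# trainer.train_minibatch({x: features, y: labels})  # one parameter update
# avg_error = trainer.test_minibatch({x: features, y: labels})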

View file

@@ -32,8 +32,8 @@ class MinibatchSource(cntk_py.MinibatchSource):
same name.
'''
return super(MinibatchSource, self).stream_info(name)
def get_next_minibatch(self, minibatch_size_in_samples, device = None):
def get_next_minibatch(self, minibatch_size_in_samples, device=None):
'''
Reads a minibatch that contains data for all input streams.
The minibatch size is specified in terms of #samples and/or #sequences for the primary input stream; a value of 0 for #samples/#sequences means unspecified.
@@ -47,9 +47,9 @@ class MinibatchSource(cntk_py.MinibatchSource):
if device is None:
device = cntk_py.DeviceDescriptor.use_default_device()
return super(MinibatchSource, self).get_next_minibatch(\
minibatch_size_in_samples,
minibatch_size_in_sequences, device)
return super(MinibatchSource, self).get_next_minibatch(
minibatch_size_in_samples,
minibatch_size_in_sequences, device)
def _py_dict_to_cntk_dict(py_dict):
@@ -60,23 +60,25 @@ def _py_dict_to_cntk_dict(py_dict):
Returns:
:class:`cntk_py.Dictionary`
'''
res = cntk_py.Dictionary();
for k,v in py_dict.items():
if isinstance(v,dict):
res = cntk_py.Dictionary()
for k, v in py_dict.items():
if isinstance(v, dict):
res[k] = cntk_py.DictionaryValueFromDict(_py_dict_to_cntk_dict(v))
#TODO: add support to list of lists ?
elif isinstance(v,list):
# TODO: add support to list of lists ?
elif isinstance(v, list):
l = list()
for e in v:
if isinstance(e,dict):
l.append(cntk_py.DictionaryValueFromDict(_py_dict_to_cntk_dict(e)))
if isinstance(e, dict):
l.append(cntk_py.DictionaryValueFromDict(
_py_dict_to_cntk_dict(e)))
else:
l.append(cntk_py.DictionaryValue(e))
res[k] = cntk_py.DictionaryValue(l)
else:
res[k] = cntk_py.DictionaryValue(v)
return res
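# Sketch of the conversion performed above (the config keys here are
# illustrative, not prescribed by this commit): nested dicts become
# DictionaryValueFromDict entries, lists become lists of DictionaryValue
# objects.
#
# config = {'epochSize': 30000,
#           'deserializers': [{'type': 'ImageDeserializer',
#                              'file': 'train_map.txt'}]}
# cntk_dict = _py_dict_to_cntk_dict(config)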
def minibatch_source(config):
'''
Instantiate the CNTK built-in composite minibatch source which is used to stream data into the network.
@@ -88,6 +90,7 @@ def minibatch_source(config):
cntk_dict = _py_dict_to_cntk_dict(config)
return cntk_py.create_composite_minibatch_source(cntk_dict)
class ReaderConfig(dict):
'''
Reader configuration.
@@ -98,13 +101,14 @@ class ReaderConfig(dict):
randomize (`bool`, default True): randomize images before every epoch
epoch_size (`int`): epoch size
'''
def __init__(self, deserializers=None, randomize=True, epoch_size=MAX_UI64):
self['epochSize'] = epoch_size
if not isinstance(deserializers, (list, tuple)):
deserializers = [deserializers]
self['deserializers'] = self.deserializers = deserializers or []
self['randomize'] = randomize;
self['randomize'] = randomize
def minibatch_source(self):
'''
@@ -117,6 +121,7 @@ class ReaderConfig(dict):
'''
return minibatch_source(self)
class Deserializer(dict):
'''
Base deserializer class that can be used in the `:class:ReaderConfig`.
@@ -124,9 +129,11 @@ class Deserializer(dict):
Args:
type (`str`): type of the deserializer
'''
def __init__(self, type):
self['type'] = type
class ImageDeserializer(Deserializer):
'''
This class configures the image reader that reads images and corresponding
@@ -143,6 +150,7 @@ class ImageDeserializer(Deserializer):
See also:
https://github.com/microsoft/cntk/wiki/Understanding-and-Extending-Readers
'''
def __init__(self, filename):
super(ImageDeserializer, self).__init__('ImageDeserializer')
self['file'] = filename
@@ -217,7 +225,7 @@ class ImageDeserializer(Deserializer):
'''
trans = {}
trans['type'] = 'Scale'
trans['width'] = width
trans['width'] = width
trans['height'] = height
trans['channels'] = channels
trans['interpolations'] = interpolations
@@ -248,6 +256,7 @@ class ImageDeserializer(Deserializer):
# similarly to ImageDeserializer
#
def text_format_minibatch_source(path, stream_configs, epoch_size=MAX_UI64):
'''
Creates a minibatch source from a CNTKTextFormatReader file.
@@ -264,7 +273,8 @@ def text_format_minibatch_source(path, stream_configs, epoch_size=MAX_UI64):
`:class:cntk.io.MinibatchSource`
'''
return cntk_py.text_format_minibatch_source(path, stream_configs,
epoch_size)
epoch_size)
class StreamConfiguration(cntk_py.StreamConfiguration):
'''
@@ -281,6 +291,6 @@ class StreamConfiguration(cntk_py.StreamConfiguration):
stream_alias (`str`, default ''): name of the stream in the file that is fed to the
`:func:cntk.io.text_format_minibatch_source`
'''
def __init__(self, name, dim, is_sparse=False, stream_alias=''):
return super(StreamConfiguration, self).__init__(name, dim, is_sparse, stream_alias)
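# Hedged usage sketch for the pieces above (the file name, dimensions,
# and alias are made up for illustration): a CNTKTextFormat source with
# a dense feature stream and a sparse label stream.
#
# mb_source = text_format_minibatch_source('train.ctf', [
#     StreamConfiguration('features', 1000),
#     StreamConfiguration('labels', 10, is_sparse=True, stream_alias='L')])
# mb = mb_source.get_next_minibatch(64)  # 64 samples from all streams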

View file

@@ -8,6 +8,7 @@ from . import sequence
from .functions import Function
from ..utils import sanitize_input, sanitize_shape, get_data_type, sanitize_axis, sanitize_dynamic_axes
def combine(operands, name=''):
'''
Create a new Function instance which just combines the outputs of the specified list of
@@ -35,9 +36,10 @@ def combine(operands, name=''):
return combine(converted_operands, name)
################################################################################
##########################################################################
# evaluation ops
################################################################################
##########################################################################
def cross_entropy_with_softmax(output_vector, target_vector, name=''):
'''
@@ -69,6 +71,7 @@ def cross_entropy_with_softmax(output_vector, target_vector, name=''):
target_vector = sanitize_input(target_vector, dtype)
return cross_entropy_with_softmax(output_vector, target_vector, name)
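# Worked NumPy example of what this criterion computes (illustrative,
# not part of the commit): softmax([1, 2, 3]) ~= [0.090, 0.245, 0.665],
# so a one-hot target on the last class gives -log(0.665) ~= 0.408.
import numpy as np
o = np.array([1., 2., 3.])
t = np.array([0., 0., 1.])
s = np.exp(o) / np.sum(np.exp(o))  # softmax of the output vector
loss = -np.sum(t * np.log(s))      # cross entropy, ~= 0.408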
def squared_error(output, target, name=''):
'''
This operation computes the sum of the squared difference between elements
@@ -98,6 +101,7 @@ def squared_error(output, target, name=''):
target = sanitize_input(target, dtype)
return squared_error(output, target, name)
def classification_error(output_vector, target_vector, name=''):
'''
This operation computes the classification error. It finds the index of the highest
@@ -127,9 +131,10 @@ def classification_error(output_vector, target_vector, name=''):
target_vector = sanitize_input(target_vector, dtype)
return classification_error(output_vector, target_vector, name)
################################################################################
##########################################################################
# convolution ops
################################################################################
##########################################################################
def convolution(convolution_map, operand, strides=(1,), sharing=[True],
auto_padding=[True], lower_pad=(0,), upper_pad=(0,), transpose=False,
@@ -176,12 +181,14 @@ def convolution(convolution_map, operand, strides=(1,), sharing=[True],
from cntk.cntk_py import convolution
operand = sanitize_input(operand)
return convolution(convolution_map, operand, tuple(reversed(strides)), sharing, auto_padding,
tuple(reversed(lower_pad)), tuple(reversed(upper_pad)), transpose,
max_temp_mem_size_in_samples, name)
tuple(reversed(lower_pad)), tuple(
reversed(upper_pad)), transpose,
max_temp_mem_size_in_samples, name)
from cntk.cntk_py import PoolingType_Max, PoolingType_Average
MAX_POOLING = PoolingType_Max
AVG_POOLING = PoolingType_Average
from cntk.cntk_py import PoolingType_Max,PoolingType_Average
MAX_POOLING=PoolingType_Max
AVG_POOLING=PoolingType_Average
def pooling(operand, pooling_type, pooling_window_shape, strides=(1,), auto_padding=[False],
lower_pad=(0,), upper_pad=(0,), name=''):
@@ -213,6 +220,7 @@ def pooling(operand, pooling_type, pooling_window_shape, strides=(1,), auto_padd
return pooling(operand, pooling_type, pooling_window_shape, strides, auto_padding,
lower_pad, upper_pad, name)
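# Hedged usage sketch (assumes an `img` variable of shape (1, 28, 28)
# defined elsewhere; not part of this commit): 2x2 max-pooling, stride 2.
#
# p = pooling(img, MAX_POOLING, pooling_window_shape=(2, 2), strides=(2, 2))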
def batch_normalization(operand, scale, bias, running_mean, running_inv_std, spatial,
normalization_time_constant=0, blend_time_constant=0,
epsilon=0.00001, use_cudnn_engine=False, name=''):
@@ -248,12 +256,13 @@ def batch_normalization(operand, scale, bias, running_mean, running_inv_std, spa
from cntk.cntk_py import batch_normalization
operand = sanitize_input(operand)
return batch_normalization(operand, scale, bias, running_mean, running_inv_std, spatial,
normalization_time_constant, blend_time_constant,
epsilon, use_cudnn_engine, name)
normalization_time_constant, blend_time_constant,
epsilon, use_cudnn_engine, name)
################################################################################
##########################################################################
# comparison ops
################################################################################
##########################################################################
def less(left, right, name=''):
'''
@@ -279,6 +288,7 @@ def less(left, right, name=''):
right = sanitize_input(right, dtype)
return less(left, right, name)
def equal(left, right, name=''):
'''
Elementwise 'equal' comparison of two tensors. Result is 1 if values are equal, 0 otherwise.
@@ -303,6 +313,7 @@ def equal(left, right, name=''):
right = sanitize_input(right, dtype)
return equal(left, right, name)
def greater(left, right, name=''):
'''
Elementwise 'greater' comparison of two tensors. Result is 1 if left > right else 0.
@@ -327,6 +338,7 @@ def greater(left, right, name=''):
right = sanitize_input(right, dtype)
return greater(left, right, name)
def greater_equal(left, right, name=''):
'''
Elementwise 'greater equal' comparison of two tensors. Result is 1 if left >= right else 0.
@@ -351,6 +363,7 @@ def greater_equal(left, right, name=''):
right = sanitize_input(right, dtype)
return greater_equal(left, right, name)
def not_equal(left, right, name=''):
'''
Elementwise 'not equal' comparison of two tensors. Result is 1 if left != right else 0.
@@ -375,6 +388,7 @@ def not_equal(left, right, name=''):
right = sanitize_input(right, dtype)
return not_equal(left, right, name)
def less_equal(left, right, name=''):
'''
Elementwise 'less equal' comparison of two tensors. Result is 1 if left <= right else 0.
@@ -399,9 +413,10 @@ def less_equal(left, right, name=''):
right = sanitize_input(right, dtype)
return less_equal(left, right, name)
################################################################################
##########################################################################
# linear ops
################################################################################
##########################################################################
def plus(left, right, name=''):
'''
@@ -429,6 +444,7 @@ def plus(left, right, name=''):
right = sanitize_input(right, dtype)
return plus(left, right, name)
def minus(left, right, name=''):
'''
The output of this operation is the left tensor minus the right tensor. It supports broadcasting.
@@ -457,6 +473,7 @@ def minus(left, right, name=''):
right = sanitize_input(right, dtype)
return minus(left, right, name)
def element_times(left, right, name=''):
'''
The output of this operation is the element-wise product of the two input
@@ -484,6 +501,7 @@ def element_times(left, right, name=''):
right = sanitize_input(right, dtype)
return element_times(left, right, name)
def element_divide(left, right, name=''):
'''
The output of this operation is the element-wise division of the two input
@@ -514,6 +532,7 @@ def element_divide(left, right, name=''):
right = sanitize_input(right, dtype)
return element_divide(left, right, name)
def times(left, right, output_rank=1, name=''):
'''
The output of this operation is the matrix product of the two input matrices.
@@ -560,9 +579,9 @@ def times(left, right, output_rank=1, name=''):
right = sanitize_input(right, dtype)
return times(right, left, output_rank, name)
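# Note that the wrapper above passes (right, left) into the underlying
# op. For dense 2-D inputs the forward value is the ordinary matrix
# product; a runnable NumPy check of the expected value (illustrative):
import numpy as np
a = np.array([[1., 2.], [3., 4.]])
b = np.array([[5., 6.], [7., 8.]])
print(np.dot(a, b))  # [[19. 22.] [43. 50.]]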
################################################################################
##########################################################################
# non_diff ops
################################################################################
##########################################################################
def floor(arg, name=''):
@@ -595,6 +614,7 @@ def floor(arg, name=''):
arg = sanitize_input(arg, get_data_type(arg))
return floor(arg, name)
def ceil(arg, name=''):
'''
The output of this operation is the element-wise value rounded to the smallest
@@ -618,6 +638,7 @@ def ceil(arg, name=''):
arg = sanitize_input(arg, get_data_type(arg))
return ceil(arg, name)
def round(arg, name=''):
'''
The output of this operation is the element-wise value rounded to the nearest integer.
@@ -651,11 +672,13 @@ def round(arg, name=''):
arg = sanitize_input(arg, get_data_type(arg))
return round(arg, name)
################################################################################
##########################################################################
# non_linear and nn ops
################################################################################
##########################################################################
# TODO: enable when it is exposed in c++
#TODO: enable when it is exposed in c++
def clip(x, min_value, max_value, name=''):
'''
Computes a tensor with all of its values clipped to fall
@@ -690,6 +713,7 @@ def clip(x, min_value, max_value, name=''):
max_value = sanitize_input(max_value, get_data_type(max_value))
return clip(x, min_value, max_value, name)
def relu(x, name=''):
'''
Rectified linear operation. Computes the element-wise rectified linear
@@ -711,6 +735,7 @@ def relu(x, name=''):
x = sanitize_input(x)
return re_lu(x, name)
def sigmoid(x, name=''):
'''
Computes the element-wise sigmoid of `x`:
@@ -733,6 +758,7 @@ def sigmoid(x, name=''):
x = sanitize_input(x)
return sigmoid(x, name)
def tanh(x, name=''):
'''
Computes the element-wise tanh of `x`:
@@ -754,6 +780,7 @@ def tanh(x, name=''):
x = sanitize_input(x)
return tanh(x, name)
def softmax(x, name=''):
'''
Squashes the input values `x` such that they add up to 1:
@@ -780,6 +807,7 @@ def softmax(x, name=''):
x = sanitize_input(x)
return softmax(x)
def hardmax(x, name=''):
'''
TBA
@@ -796,6 +824,7 @@ def hardmax(x, name=''):
x = sanitize_input(x)
return hardmax(x)
def hardmax(x, name=''):
'''
TBA
@@ -812,6 +841,7 @@ def hardmax(x, name=''):
x = sanitize_input(x)
return hardmax(x)
def exp(x, name=''):
'''
Computes the element-wise exponential of `x`:
@@ -832,6 +862,7 @@ def exp(x, name=''):
x = sanitize_input(x)
return exp(x, name)
def log(x, name=''):
'''
Computes the element-wise natural logarithm of `x`:
@@ -856,6 +887,7 @@ def log(x, name=''):
x = sanitize_input(x)
return log(x, name)
def sqrt(x, name=''):
'''
Computes the element-wise square-root of `x`:
@@ -880,6 +912,7 @@ def sqrt(x, name=''):
x = sanitize_input(x)
return sqrt(x, name)
def square(x, name=''):
'''
Computes the element-wise square of `x`:
@@ -898,6 +931,7 @@ def square(x, name=''):
x = sanitize_input(x)
return square(x, name)
def abs(x, name=''):
'''
Computes the element-wise absolute value of `x`:
@@ -918,6 +952,7 @@ def abs(x, name=''):
x = sanitize_input(x)
return abs(x, name)
def negate(x, name=''):
'''
Computes the element-wise negation of `x`:
@@ -938,6 +973,7 @@ def negate(x, name=''):
x = sanitize_input(x)
return negate(x, name)
def reciprocal(x, name=''):
'''
Computes the element-wise reciprocal of `x`:
@@ -956,6 +992,7 @@ def reciprocal(x, name=''):
x = sanitize_input(x)
return reciprocal(x, name)
def element_select(flag, value_if_true, value_if_false, name=''):
'''
Return either `value_if_true` or `value_if_false`, based on the value of `flag`.
@@ -981,13 +1018,14 @@ def element_select(flag, value_if_true, value_if_false, name=''):
value_if_false = sanitize_input(value_if_false)
return element_select(flag, value_if_true, value_if_false, name)
################################################################################
##########################################################################
# recurrent ops
################################################################################
##########################################################################
# TODO: add default value for initial_state. It should be a constant scalar
# (0.0), using the default device
def future_value(x, initial_state=None, time_step=1, name=''):
'''
This function returns the future value w.r.t. `x`. It is most often used when
@@ -1019,6 +1057,7 @@ def future_value(x, initial_state=None, time_step=1, name=''):
x = sanitize_input(x)
return future_value(x, initial_state, time_step, name)
def past_value(x, initial_state=None, time_step=1, name=''):
'''
This function returns the past value w.r.t. `x`. It is most often used when
@@ -1050,11 +1089,13 @@ def past_value(x, initial_state=None, time_step=1, name=''):
x = sanitize_input(x)
return past_value(x, initial_state, time_step, name)
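# Hedged sketch (assumes a sequence input `x` defined elsewhere; not
# part of this commit): a one-step delay with a zero scalar initial
# state, in line with the TODO above about a default initial_state.
#
# prev = past_value(x, initial_state=constant(value=0.0), time_step=1)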
################################################################################
##########################################################################
# reshaping ops
################################################################################
##########################################################################
# TODO: enable when it is exposed in c++
#TODO: enable when it is exposed in c++
def reshape(x, shape, name=''):
'''
Reinterpret input samples as having different tensor dimensions
@@ -1077,8 +1118,9 @@ def reshape(x, shape, name=''):
Returns:
:class:`cntk.Function`
'''
if np.any(np.asarray(shape)<0):
# TODO decide on whether -1 instead of 0 should be used to infer the dimension
if np.any(np.asarray(shape) < 0):
# TODO decide on whether -1 instead of 0 should be used to infer the
# dimension
raise ValueError('shape dimensions cannot be negative')
from cntk.cntk_py import reshape
@@ -1087,6 +1129,7 @@ def reshape(x, shape, name=''):
return reshape(x, shape, name)
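# Illustrative example (assumes a 6-element input `v` defined
# elsewhere): reinterpret the sample as a 2x3 tensor.
#
# r = reshape(v, (2, 3))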
def transpose(x, axis1=0, axis2=1, name=''):
'''
Reverses two axes of the tensor. The output tensor has the same data but with
@@ -1112,6 +1155,7 @@ def transpose(x, axis1=0, axis2=1, name=''):
axis2 = sanitize_axis(rank, axis2)
return transpose_axes(x, axis1, axis2, name)
def slice(x, axis, begin_index, end_index, name=''):
'''
Slice the input along an axis.
@@ -1169,7 +1213,9 @@ def slice(x, axis, begin_index, end_index, name=''):
axis = sanitize_axis(x.shape().rank(), axis)
return slice(x, axis, begin_index, end_index, name)
#TODO: enable when it is exposed in c++
# TODO: enable when it is exposed in c++
def splice(inputs, axis=0, name=''):
'''
Concatenate the input tensors along an axis.
@@ -1212,9 +1258,10 @@ def splice(inputs, axis=0, name=''):
return splice(inputs, axis, name)
################################################################################
##########################################################################
# reduction ops
################################################################################
##########################################################################
def reduce_sum(x, axis=None, name=''):
'''
@@ -1254,6 +1301,7 @@ def reduce_sum(x, axis=None, name=''):
axis = sanitize_axis(x.shape().rank(), axis)
return reduce_sum(x, axis, name)
def reduce_log_sum(x, axis, name=''):
'''
Computes the log sum of the input tensor's elements across the specified axis.
@@ -1274,6 +1322,7 @@ def reduce_log_sum(x, axis, name=''):
axis = sanitize_axis(x.shape().rank(), axis)
return reduce_log_sum(x, axis, name)
def reduce_mean(x, axis, name=''):
'''
Computes the mean of the input tensor's elements across the specified axis.
@@ -1294,6 +1343,7 @@ def reduce_mean(x, axis, name=''):
axis = sanitize_axis(x.shape().rank(), axis)
return reduce_mean(x, axis, name)
def reduce_max(x, axis, name=''):
'''
Computes the max of the input tensor's elements across the specified axis.
@@ -1314,6 +1364,7 @@ def reduce_max(x, axis, name=''):
axis = sanitize_axis(x.shape().rank(), axis)
return reduce_max(x, axis, name)
def reduce_min(x, axis, name=''):
'''
Computes the min of the input tensor's elements across the specified axis.
@@ -1334,9 +1385,10 @@ def reduce_min(x, axis, name=''):
axis = sanitize_axis(x.shape().rank(), axis)
return reduce_min(x, axis, name)
################################################################################
##########################################################################
# training ops
################################################################################
##########################################################################
def dropout(x, dropout_rate=0.0, name=''):
'''
@@ -1346,7 +1398,7 @@ def dropout(x, dropout_rate=0.0, name=''):
The output tensor has the same shape as `x`, but with `dropout_rate` of the
elements set to zero (dropped out).
Args:
x: input tensor
@@ -1356,7 +1408,7 @@ def dropout(x, dropout_rate=0.0, name=''):
Returns:
FIXME also in all of the other cases :class:`cntk.Function`
'''
if dropout_rate<0.0 or dropout_rate>=1.0:
if dropout_rate < 0.0 or dropout_rate >= 1.0:
raise ValueError('dropout_rate must be in the interval [0,1)')
from cntk.cntk_py import dropout
@@ -1364,18 +1416,20 @@ def dropout(x, dropout_rate=0.0, name=''):
return dropout(x, dropout_rate, name)
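# Hedged usage sketch (assumes a node `h` defined elsewhere; not part
# of this commit): zero out roughly half of the activations during
# training. Rates outside [0, 1) raise the ValueError above.
#
# h_drop = dropout(h, dropout_rate=0.5)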
################################################################################
##########################################################################
# variables_and_parameters ops
################################################################################
##########################################################################
from cntk.cntk_py import Axis, DeviceDescriptor
#TODO: expose output_variable as well ?
# TODO: expose output_variable as well ?
# TODO: if we end up using only factory methods, we should get rid of the
# class Variable in variables.py
#TODO: if we end up using only factory methods, we should get rid of the class Variable in variables.py
def input_variable(shape, data_type=np.float32, needs_gradient=True, is_sparse=False,
dynamic_axes = Axis.default_input_variable_dynamic_axes, name=''):
dynamic_axes=Axis.default_input_variable_dynamic_axes, name=''):
'''
It creates an input node.
@@ -1406,7 +1460,7 @@ def input_variable(shape, data_type=np.float32, needs_gradient=True, is_sparse=F
return input_variable(shape, is_sparse, dtype, needs_gradient, name, dynamic_axes)
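# Illustrative example (the shape and name are made up, not part of
# this commit): a dense 784-dimensional input with default dynamic axes.
#
# features = input_variable(shape=(784,), data_type=np.float32,
#                           name='features')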
def placeholder_variable(shape, dynamic_axes = [Axis.default_dynamic_axis(), Axis.default_batch_axis()]):
def placeholder_variable(shape, dynamic_axes=[Axis.default_dynamic_axis(), Axis.default_batch_axis()]):
'''
Creates a placeholder variable for recurrent networks. When the network's dynamic axes
are unfolded, the placeholder is assigned a variable along the corresponding dynamic axis.
@@ -1423,6 +1477,7 @@ def placeholder_variable(shape, dynamic_axes = [Axis.default_dynamic_axis(), Axi
dynamic_axes = sanitize_dynamic_axes(dynamic_axes)
return placeholder_variable(shape, dynamic_axes)
def parameter(shape=None, value=None, initializer=None, device=None, name=''):
'''
It creates a parameter tensor.
@@ -1444,7 +1499,7 @@ def parameter(shape=None, value=None, initializer=None, device=None, name=''):
from .variables import Parameter
if not device:
device=DeviceDescriptor.use_default_device()
device = DeviceDescriptor.use_default_device()
if np.isscalar(value) and not shape:
shape = ()
@@ -1457,6 +1512,7 @@ def parameter(shape=None, value=None, initializer=None, device=None, name=''):
return Parameter(shape, value, data_type, initializer, device, name)
def constant(shape=None, value=None, device=None, name=''):
'''
It creates a constant tensor initialized from a numpy array
@@ -1474,7 +1530,7 @@ def constant(shape=None, value=None, device=None, name=''):
'''
from .variables import Constant
if not device:
device=DeviceDescriptor.use_default_device()
device = DeviceDescriptor.use_default_device()
if np.isscalar(value) and not shape:
shape = ()
if isinstance(value, np.ndarray):
@@ -1486,11 +1542,12 @@ def constant(shape=None, value=None, device=None, name=''):
return Constant(shape, value, data_type, device, name)
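# Sketch of the two factories above (values are illustrative, not part
# of this commit); the shape is inferred from a NumPy value or a scalar:
#
# w = parameter(value=np.zeros((2, 3), dtype=np.float32))
# c = constant(value=3.0)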
################################################################################
##########################################################################
# normalization ops
################################################################################
##########################################################################
# TODO: ComputeInputPerDimMeansAndInvStdDevs
#TODO: ComputeInputPerDimMeansAndInvStdDevs
def per_dim_mean_variance_normalize(operand, mean, inv_stddev, name=''):
'''

View file

@@ -6,13 +6,15 @@
import numpy as np
from ...utils import sanitize_input, sanitize_shape, get_data_type
################################################################################
##########################################################################
# sequence ops
################################################################################
def is_first(operand, name = ''):
##########################################################################
def is_first(operand, name=''):
'''
TBA
Example:
TBA
Args:
@@ -20,15 +22,16 @@ def is_first(operand, name = ''):
name (str): the name of the node in the network
Returns:
:class:`cntk.Function`
'''
'''
from cntk.cntk_py import is_first
operand = sanitize_input(operand, get_data_type(operand))
return is_first(operand, name)
def is_last(operand, name = ''):
def is_last(operand, name=''):
'''
TBA
Example:
TBA
Args:
@@ -36,15 +39,16 @@ def is_last(operand, name = ''):
name (str): the name of the node in the network
Returns:
:class:`cntk.Function`
'''
'''
from cntk.cntk_py import is_last
operand = sanitize_input(operand, get_data_type(operand))
return is_last(operand, name)
def first(operand, name = ''):
def first(operand, name=''):
'''
TBA
Example:
TBA
Args:
@@ -52,15 +56,16 @@ def first(operand, name = ''):
name (str): the name of the node in the network
Returns:
:class:`cntk.Function`
'''
'''
from cntk.cntk_py import first
operand = sanitize_input(operand, get_data_type(operand))
return first(operand, name)
def last(operand, name = ''):
def last(operand, name=''):
'''
TBA
Example:
TBA
Args:
@@ -68,15 +73,16 @@ def last(operand, name = ''):
name (str): the name of the node in the network
Returns:
:class:`cntk.Function`
'''
'''
from cntk.cntk_py import last
operand = sanitize_input(operand, get_data_type(operand))
return last(operand, name)
def where(condition, name = ''):
def where(condition, name=''):
'''
TBA
Example:
TBA
Args:
@@ -84,15 +90,16 @@ def where(condition, name = ''):
name (str): the name of the node in the network
Returns:
:class:`cntk.Function`
'''
'''
from cntk.cntk_py import where
condition = sanitize_input(condition, get_data_type(condition))
return where(condition, name)
def gather(operand, condition, name = ''):
def gather(operand, condition, name=''):
'''
TBA
Example:
TBA
Args:
@@ -101,16 +108,17 @@ def gather(operand, condition, name = ''):
name (str): the name of the node in the network
Returns:
:class:`cntk.Function`
'''
'''
from cntk.cntk_py import gather
operand = sanitize_input(operand, get_data_type(operand))
condition = sanitize_input(condition, get_data_type(condition))
return gather(operand, condition, name)
def scatter(operand, condition, name = ''):
def scatter(operand, condition, name=''):
'''
TBA
Example:
TBA
Args:
@@ -119,16 +127,17 @@ def scatter(operand, condition, name = ''):
name (str): the name of the node in the network
Returns:
:class:`cntk.Function`
'''
'''
from cntk.cntk_py import scatter
operand = sanitize_input(operand, get_data_type(operand))
condition = sanitize_input(condition, get_data_type(condition))
return scatter(operand, condition, name)
def broadcast_as(operand, broadcast_as_operand, name = ''):
def broadcast_as(operand, broadcast_as_operand, name=''):
'''
TBA
Example:
TBA
Args:
@@ -137,8 +146,9 @@ def broadcast_as(operand, broadcast_as_operand, name = ''):
name (str): the name of the node in the network
Returns:
:class:`cntk.Function`
'''
'''
from cntk.cntk_py import broadcast_as
operand = sanitize_input(operand, get_data_type(operand))
broadcast_as_operand = sanitize_input(broadcast_as_operand, get_data_type(broadcast_as_operand))
broadcast_as_operand = sanitize_input(
broadcast_as_operand, get_data_type(broadcast_as_operand))
return broadcast_as(operand, broadcast_as_operand, name)
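# The docstrings above are still TBA; as a hedged sketch of the usual
# sequence-op semantics (assumes a sequence input `x` defined
# elsewhere), where/gather select sequence elements by a condition,
# e.g. keeping only each sequence's first step:
#
# firsts = gather(x, is_first(x))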

View file

@@ -16,7 +16,7 @@ from .ops_test_utils import _test_binary_op, AA, precision, PRECISION_TO_TYPE
TENSOR_PAIRS = [
([41., 42., 43., 42., 42., 42.], [42., 42., 42., 41., 42., 43.]),
]
]
from cntk import equal, less, less_equal, greater, greater_equal, not_equal
@@ -27,7 +27,7 @@ FUNCTIONS_TO_TEST = [
(greater, np.greater),
(greater_equal, np.greater_equal),
(not_equal, np.not_equal),
]
]
test_parameters = []
import itertools as itt
@@ -36,16 +36,18 @@ for functions_to_test, tensor_pairs in itt.product(FUNCTIONS_TO_TEST, TENSOR_PAI
left_op, right_op = tensor_pairs
test_parameters.append((cntk_func, numpy_func, left_op, right_op))
@pytest.mark.parametrize("cntk_function, numpy_function, left_operand, right_operand", test_parameters)
def test_op_comparison(left_operand, right_operand, cntk_function, numpy_function, device_id, precision):
dt = PRECISION_TO_TYPE[precision]
expected_forward = [numpy_function(AA([left_operand], dtype=dt ),AA([right_operand], dtype=dt))]
expected_forward = [numpy_function(
AA([left_operand], dtype=dt), AA([right_operand], dtype=dt))]
expected_backward = {
'left_arg': [[np.zeros_like(left_operand, dtype=dt)]],
'right_arg': [[np.zeros_like(left_operand, dtype=dt)]]
}
'left_arg': [[np.zeros_like(left_operand, dtype=dt)]],
'right_arg': [[np.zeros_like(left_operand, dtype=dt)]]
}
_test_binary_op(precision, device_id, cntk_function, left_operand, right_operand,
expected_forward, expected_backward)
expected_forward, expected_backward)

View file

@@ -18,9 +18,12 @@ TARGET_OUT_PAIRS = [
([[0., 0., 0., 1]], [[1., 2., 3., 4.]]),
([[0., 0., 0.5, 0.5]], [[1., 2., 3., 4.]]),
([[0., 0.4, 0.3, 0.3]], [[2., 1., 1., 4.]])
]
]
# TODO: Enable tests when 0d arrays are correctly handled for backward
# propagation e. g. array(29.0)
# TODO: Enable tests when 0d arrays are correctly handled for backward propagation e. g. array(29.0)
@pytest.mark.parametrize("target_vector, output_vector", TARGET_OUT_PAIRS)
def _test_op_cross_entropy_with_soft_max(output_vector, target_vector, device_id, precision):
dt = PRECISION_TO_TYPE[precision]
@@ -35,17 +38,18 @@ def _test_op_cross_entropy_with_soft_max(output_vector, target_vector, device_id
expected_forward = [-np.sum(t * np.log(s_max, dtype=dt), dtype=dt)]
s = np.sum(t, dtype=dt)
backward = np.subtract(s_max * s , t)
backward = np.subtract(s_max * s, t)
expected_backward = {
'left_arg': backward,
'right_arg': backward
}
'left_arg': backward,
'right_arg': backward
}
from .. import cross_entropy_with_softmax
_test_binary_op(precision, device_id, cross_entropy_with_softmax,
output_vector, target_vector,
expected_forward, expected_backward)
output_vector, target_vector,
expected_forward, expected_backward)
@pytest.mark.parametrize("target_vector, output_vector", TARGET_OUT_PAIRS)
def _test_op_squared_error(output_vector, target_vector, device_id, precision):
@@ -54,24 +58,26 @@ def _test_op_squared_error(output_vector, target_vector, device_id, precision):
o = AA(output_vector, dtype=dt)
t = AA(target_vector, dtype=dt)
expected_forward = np.sum((t-o)**2)
expected_forward = np.sum((t - o)**2)
expected_backward = {
'left_arg': 2*np.subtract(t, o),
'right_arg': 2*np.subtract(o, t)
}
'left_arg': 2 * np.subtract(t, o),
'right_arg': 2 * np.subtract(o, t)
}
from .. import squared_error
_test_binary_op(precision, device_id, squared_error,
output_vector, target_vector,
expected_forward, expected_backward, True)
output_vector, target_vector,
expected_forward, expected_backward, True)
TARGET_OUT_PAIRS_EP = [
([[1., 0., 0., 0]], [[1., 2., 3., 4.]]),
([[0., 0., 0., 1]], [[1., 2., 3., 4.]]),
]
]
# -- ErrorPrediction with softmax operation tests --
@pytest.mark.parametrize("target_vector, output_vector", TARGET_OUT_PAIRS_EP)
def _test_op_classification_error(output_vector, target_vector, device_id, precision):
dt = PRECISION_TO_TYPE[precision]
@@ -82,11 +88,11 @@ def _test_op_classification_error(output_vector, target_vector, device_id, preci
expected_forward = [np.argmax(t) != np.argmax(o)]
expected_backward = {
'left_arg': np.zeros_like(t),
'right_arg': np.zeros_like(t)
}
'left_arg': np.zeros_like(t),
'right_arg': np.zeros_like(t)
}
from .. import classification_error
_test_binary_op(precision, device_id, classification_error,
output_vector, target_vector,
expected_forward, expected_backward)
output_vector, target_vector,
expected_forward, expected_backward)

View file

@@ -23,12 +23,13 @@ TENSOR_PAIRS = [
# [[10., 20.], [30., 40.], [1., 2.]]),
# Adding two 3x2 inputs of sequence length 1
([[30.,40.], [1.,2.], [0.1, 0.2]], [[10,20], [3,4], [-0.5, -0.4]]),
([[30., 40.], [1., 2.], [0.1, 0.2]], [[10, 20], [3, 4], [-0.5, -0.4]]),
]
# -- plus operation tests --
TENSOR_PAIRS_SCALAR = TENSOR_PAIRS + [(left, np.random.rand()) for left,right
in TENSOR_PAIRS]
TENSOR_PAIRS_SCALAR = TENSOR_PAIRS + [(left, np.random.rand()) for left, right
in TENSOR_PAIRS]
@pytest.mark.parametrize("left_operand, right_operand", TENSOR_PAIRS_SCALAR)
def test_op_plus(left_operand, right_operand, device_id, precision):
@@ -36,128 +37,137 @@ def test_op_plus(left_operand, right_operand, device_id, precision):
if np.isscalar(right_operand):
expected_backward = {
'left_arg': [[[np.ones_like(x, dtype=PRECISION_TO_TYPE[precision]) for x in left_operand]]],
# gradients are accumulated
'right_arg': [[AA([left_operand]).size]]
}
'left_arg': [[[np.ones_like(x, dtype=PRECISION_TO_TYPE[precision]) for x in left_operand]]],
# gradients are accumulated
'right_arg': [[AA([left_operand]).size]]
}
else:
expected_backward = {
'left_arg': [[[np.ones_like(x, dtype=PRECISION_TO_TYPE[precision]) for x in left_operand]]],
'right_arg': [[[np.ones_like(x, dtype=PRECISION_TO_TYPE[precision]) for x in right_operand]]]
}
'left_arg': [[[np.ones_like(x, dtype=PRECISION_TO_TYPE[precision]) for x in left_operand]]],
'right_arg': [[[np.ones_like(x, dtype=PRECISION_TO_TYPE[precision]) for x in right_operand]]]
}
from .. import plus
_test_binary_op(precision, device_id, plus,
left_operand, right_operand,
expected_forward, expected_backward)
left_operand, right_operand,
expected_forward, expected_backward)
_test_binary_op(precision, device_id, '+',
left_operand, right_operand,
expected_forward, expected_backward)
left_operand, right_operand,
expected_forward, expected_backward)
SEQ_TENSOR_PAIRS = [
# two inputs each having sequences of length 1 and 2
([[[30.]], [[40], [50]]], # first batch with two sequences
[[[ 3.]], [[ 4], [ 5]]]), # second batch with two sequences
[[[3.]], [[4], [5]]]), # second batch with two sequences
([[[30., 0]], [[40, 1], [50, 2]]], # first batch with two sequences
[[[ 3., -10]], [[ 4, -20], [ 5, -30]]]), # second batch with two sequences
[[[3., -10]], [[4, -20], [5, -30]]]), # second batch with two sequences
]
@pytest.mark.parametrize("left_batch, right_batch", SEQ_TENSOR_PAIRS)
def test_op_plus_var_sequences_input_input(left_batch, right_batch, device_id, precision):
from .. import plus
assert len(left_batch) == len(right_batch)
expected_forward = [AA(left_batch[i]) + AA(right_batch[i]) \
for i in range(len(left_batch))]
expected_forward = [AA(left_batch[i]) + AA(right_batch[i])
for i in range(len(left_batch))]
expected_backward = {
'left': ones_like(left_batch, PRECISION_TO_TYPE[precision]),
'right': ones_like(right_batch, PRECISION_TO_TYPE[precision])
}
'left': ones_like(left_batch, PRECISION_TO_TYPE[precision]),
'right': ones_like(right_batch, PRECISION_TO_TYPE[precision])
}
left_value = [AA(sample, dtype=PRECISION_TO_TYPE[precision]) for sample in left_batch]
left_value = [AA(sample, dtype=PRECISION_TO_TYPE[precision])
for sample in left_batch]
left_shape = left_value[0][0].shape
right_value = [AA(sample, dtype=PRECISION_TO_TYPE[precision]) for sample in right_batch]
right_value = [AA(sample, dtype=PRECISION_TO_TYPE[precision])
for sample in right_batch]
right_shape = right_value[0][0].shape
a = I(shape=left_shape,
data_type=sanitize_dtype_cntk(PRECISION_TO_TYPE[precision]),
needs_gradient=True,
name='a')
data_type=sanitize_dtype_cntk(PRECISION_TO_TYPE[precision]),
needs_gradient=True,
name='a')
b = I(shape=right_shape,
data_type=sanitize_dtype_cntk(PRECISION_TO_TYPE[precision]),
needs_gradient=True,
name='b')
data_type=sanitize_dtype_cntk(PRECISION_TO_TYPE[precision]),
needs_gradient=True,
name='b')
input_op_input = plus(a, b)
forward_input = {a:left_value, b:right_value}
backward_input = { a: None, b: None }
expected_backward = { a: expected_backward['left'], b: expected_backward['right'], }
forward_input = {a: left_value, b: right_value}
backward_input = {a: None, b: None}
expected_backward = {a: expected_backward[
'left'], b: expected_backward['right'], }
unittest_helper(input_op_input,
forward_input, expected_forward,
expected_backward,
device_id, precision)
forward_input, expected_forward,
expected_backward,
device_id, precision)
# -- minus operation tests --
#TODO: enable once the function is exposed
# TODO: enable once the function is exposed
@pytest.mark.parametrize("left_operand, right_operand", TENSOR_PAIRS)
def test_op_minus(left_operand, right_operand, device_id, precision):
expected_forward = [AA([left_operand], dtype=PRECISION_TO_TYPE[precision]) - AA([right_operand], dtype=PRECISION_TO_TYPE[precision])]
expected_forward = [AA([left_operand], dtype=PRECISION_TO_TYPE[
precision]) - AA([right_operand], dtype=PRECISION_TO_TYPE[precision])]
expected_backward = {
'left_arg': [[[np.ones_like(x, dtype=PRECISION_TO_TYPE[precision]) for x in left_operand]]],
'right_arg': [[[-1*np.ones_like(x, dtype=PRECISION_TO_TYPE[precision]) for x in right_operand]]]
}
'left_arg': [[[np.ones_like(x, dtype=PRECISION_TO_TYPE[precision]) for x in left_operand]]],
'right_arg': [[[-1 * np.ones_like(x, dtype=PRECISION_TO_TYPE[precision]) for x in right_operand]]]
}
from .. import minus
_test_binary_op(precision, device_id, minus,
left_operand, right_operand,
expected_forward, expected_backward)
left_operand, right_operand,
expected_forward, expected_backward)
_test_binary_op(precision, device_id, '-',
left_operand, right_operand,
expected_forward, expected_backward)
left_operand, right_operand,
expected_forward, expected_backward)
# -- element times tests --
@pytest.mark.parametrize("left_operand, right_operand", TENSOR_PAIRS)
def test_op_element_times(left_operand, right_operand, device_id, precision):
expected_forward = [AA([left_operand]) * AA([right_operand])]
expected_backward = {
'left_arg': [[right_operand]],
'right_arg': [[left_operand]]
}
'left_arg': [[right_operand]],
'right_arg': [[left_operand]]
}
from .. import element_times
_test_binary_op(precision, device_id, element_times,
left_operand, right_operand,
expected_forward, expected_backward)
left_operand, right_operand,
expected_forward, expected_backward)
_test_binary_op(precision, device_id, '*',
left_operand, right_operand,
expected_forward, expected_backward)
left_operand, right_operand,
expected_forward, expected_backward)
# -- element divide tests --
#TODO: enable once the function is exposed
# TODO: enable once the function is exposed
@pytest.mark.parametrize("left_operand, right_operand", TENSOR_PAIRS)
def test_op_element_divide(left_operand, right_operand, device_id, precision):
expected_forward = [AA([left_operand]) / AA([right_operand])]
expected_backward = {
'left_arg': [[[np.ones_like(x) / x for x in right_operand]]],
'right_arg': [[-AA(left_operand, dtype=PRECISION_TO_TYPE[precision]) / AA(right_operand, dtype=PRECISION_TO_TYPE[precision])**2]]
}
'left_arg': [[[np.ones_like(x) / x for x in right_operand]]],
'right_arg': [[-AA(left_operand, dtype=PRECISION_TO_TYPE[precision]) / AA(right_operand, dtype=PRECISION_TO_TYPE[precision])**2]]
}
from .. import element_divide
_test_binary_op(precision, device_id, element_divide,
left_operand, right_operand,
expected_forward, expected_backward)
left_operand, right_operand,
expected_forward, expected_backward)
_test_binary_op(precision, device_id, '/',
left_operand, right_operand,
expected_forward, expected_backward)
left_operand, right_operand,
expected_forward, expected_backward)
# -- identity function tests --
@@ -167,9 +177,10 @@ IDENTITY_TENSORS = [
([[30.]]),
([[1.5, 2.1]]),
([[100., 200.], [300., 400.], [10., 20.]]),
([[30,40], [1,2], [0.1, 0.2]])
([[30, 40], [1, 2], [0.1, 0.2]])
]
@pytest.mark.parametrize("operand", IDENTITY_TENSORS)
def test_op_negate(operand, device_id, precision):
t = -1 * AA(operand, dtype=PRECISION_TO_TYPE[precision])
@@ -177,16 +188,16 @@ def test_op_negate(operand, device_id, precision):
expected_forward = [AA([t])]
expected_backward = {
'arg': [[-1*np.ones_like(operand, PRECISION_TO_TYPE[precision])]]
}
'arg': [[-1 * np.ones_like(operand, PRECISION_TO_TYPE[precision])]]
}
from cntk import negate
_test_unary_op(precision, device_id, negate, operand,
expected_forward, expected_backward)
expected_forward, expected_backward)
_test_unary_op(precision, device_id, '-', operand,
expected_forward, expected_backward)
expected_forward, expected_backward)
TIMES_PAIRS = [
([[30.]], [[10.]]),
@@ -196,10 +207,12 @@ TIMES_PAIRS = [
([[100., 200.], [300., 400.]], [[10., 20.], [20., 30.]])
]
#TODO: Handle sparse matrices
# TODO: Handle sparse matrices
@pytest.mark.parametrize("left_operand, right_operand", TIMES_PAIRS)
def test_op_times(left_operand, right_operand, device_id, precision,
left_matrix_type, right_matrix_type):
left_matrix_type, right_matrix_type):
dt_precision = PRECISION_TO_TYPE[precision]
a = AA(left_operand, dtype=dt_precision)
@@ -210,17 +223,17 @@ def test_op_times(left_operand, right_operand, device_id, precision,
assert len(a.shape) == len(b.shape) == 2
left_backward = np.zeros_like(a)
left_backward[:,:] = b.sum(axis = 1)
left_backward[:, :] = b.sum(axis=1)
right_backward = np.zeros_like(b)
right_backward[:,:] = np.transpose([a.sum(axis = 0)])
right_backward[:, :] = np.transpose([a.sum(axis=0)])
expected_backward = {
'left_arg': [[left_backward]],
'right_arg': [[right_backward]]
}
'left_arg': [[left_backward]],
'right_arg': [[right_backward]]
}
from cntk import times
_test_binary_op(precision, device_id, times,
left_operand, right_operand, expected_forward, expected_backward)
left_operand, right_operand, expected_forward, expected_backward)

View file

@@ -11,18 +11,20 @@ Unit tests for operations that are not differentiable.
from __future__ import division
import numpy as np
import pytest
from .ops_test_utils import unittest_helper, _test_unary_op, _test_binary_op, AA, I, precision, PRECISION_TO_TYPE
from .ops_test_utils import unittest_helper, _test_unary_op, _test_binary_op, AA, I, precision, PRECISION_TO_TYPE
TENSORS = [
([12.3, -12.3]),
([10.2, -10.2]),
([0.5, -0.5]),
([0.01, -0.01]),
([0.499, -0.499]),
([5.0, -5.0]),
([0.0]),
([[2.1, 9.9], [4.7, 5.3]])
([12.3, -12.3]),
([10.2, -10.2]),
([0.5, -0.5]),
([0.01, -0.01]),
([0.499, -0.499]),
([5.0, -5.0]),
([0.0]),
([[2.1, 9.9], [4.7, 5.3]])
]
@pytest.mark.parametrize("operand", TENSORS)
def test_op_floor(operand, device_id, precision):
operand = AA(operand)
@@ -30,12 +32,13 @@ def test_op_floor(operand, device_id, precision):
expected_forward = [[expected]]
expected_backward = {
'arg': [[np.zeros_like(expected)]],
}
'arg': [[np.zeros_like(expected)]],
}
from .. import floor
_test_unary_op(precision, device_id, floor, operand,
expected_forward, expected_backward)
expected_forward, expected_backward)
@pytest.mark.parametrize("operand", TENSORS)
def test_op_ceil(operand, device_id, precision):
@@ -44,48 +47,49 @@ def test_op_ceil(operand, device_id, precision):
expected_forward = [[expected]]
expected_backward = {
'arg': [[np.zeros_like(expected)]],
}
'arg': [[np.zeros_like(expected)]],
}
from .. import ceil
_test_unary_op(precision, device_id, ceil, operand,
expected_forward, expected_backward)
expected_forward, expected_backward)
# Manually setting the expectation since CNTK's round behaves differently than
# NumPy's round (see operator's docstring).
ROUND_TENSORS = [
([0.2, 1.3, 4.0, 5.5, 0.0],
[0.0, 1.0, 4.0, 6.0, 0.0]),
([0.2, 1.3, 4.0, 5.5, 0.0],
[0.0, 1.0, 4.0, 6.0, 0.0]),
([[0.6, 3.3], [1.9, 5.6]],
[[1.0, 3.0], [2.0, 6.0]]),
([[0.6, 3.3], [1.9, 5.6]],
[[1.0, 3.0], [2.0, 6.0]]),
([-5.5, -4.2, -3., -0.7, 0],
[-5.0, -4.0, -3., -1.0, 0]),
([-5.5, -4.2, -3., -0.7, 0],
[-5.0, -4.0, -3., -1.0, 0]),
([[-0.6, -4.3], [1.9, -3.2]],
[[-1.0, -4.0], [2.0, -3.0]]),
([[-0.6, -4.3], [1.9, -3.2]],
[[-1.0, -4.0], [2.0, -3.0]]),
# CNTK is always rounding up values starting at x.5, while numpy rounds
# to the nearest even value for half-integers
# Refer here: https://en.wikipedia.org/wiki/Rounding#Tie-breaking
# This test shows such values are not equal comparing numpy and CNTK
([0.5, 1.5, 2.5, 3.5],
# NumPy would round to
# [0.0, 2.0, 2.0, 4.0]))
# while CNTK rounds to
[1.0, 2.0, 3.0, 4.0])
# CNTK is always rounding up values starting at x.5, while numpy rounds
# to the nearest even value for half-integers
# Refer here: https://en.wikipedia.org/wiki/Rounding#Tie-breaking
# This test shows such values are not equal comparing numpy and CNTK
([0.5, 1.5, 2.5, 3.5],
# NumPy would round to
# [0.0, 2.0, 2.0, 4.0]))
# while CNTK rounds to
[1.0, 2.0, 3.0, 4.0])
]
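# Runnable NumPy demo of the tie-breaking difference described above
# (illustrative, not part of the commit):
import numpy as np
x = np.array([0.5, 1.5, 2.5, 3.5])
print(np.round(x))        # -> [0., 2., 2., 4.], round half to even
print(np.floor(x + 0.5))  # -> [1., 2., 3., 4.], round half up, as CNTK does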
@pytest.mark.parametrize("operand,expected", ROUND_TENSORS)
def test_op_round(operand, expected, device_id, precision):
operand, expected = AA(operand), AA(expected)
expected_forward = [[expected]]
expected_backward = {
'arg': [[np.zeros_like(expected)]],
}
'arg': [[np.zeros_like(expected)]],
}
from .. import round
_test_unary_op(precision, device_id, round, operand,
expected_forward, expected_backward)
expected_forward, expected_backward)

View file

@@ -15,60 +15,66 @@ import pytest
from .ops_test_utils import unittest_helper, _test_unary_op, _test_binary_op, AA, I, precision, PRECISION_TO_TYPE
EPS_IN_LOG = 1e-37 # 1e-37 is the highest guaranteed precision
BACKWARD_RESULST_FOR_LOG_EPS = 9.08782e+36 # the backward result returned by CNTK log() for epsilon
LOG_OF_EPS_IN_LOG = -85.1 # log(EPS_IN_LOG)
# the backward result returned by CNTK log() for epsilon
BACKWARD_RESULST_FOR_LOG_EPS = 9.08782e+36
LOG_OF_EPS_IN_LOG = -85.1 # log(EPS_IN_LOG)
CLIP_TUPLES = [
([1.0], [2.0], [1.5]), # value shouldn't be clipped; gradient is [1.0]
([1.0], [2.0], [0.5]), # value should be clipped to 1.0; gradient is [0.0]
([1.0], [2.0], [2.5]), # value should be clipped to 2.0; gradient is [0.0]
([1.0], [2.0], [1.5]), # value shouldn't be clipped; gradient is [1.0]
([1.0], [2.0], [0.5]), # value should be clipped to 1.0; gradient is [0.0]
([1.0], [2.0], [2.5]), # value should be clipped to 2.0; gradient is [0.0]
# should clip to [1.5, 2.0, 1.0]; gradient is [[1.0, 0.0, 0.0]]
([1.0], [2.0], [[1.5, 2.1, 0.9]]),
# should clip to [[1.0, 2.0], [1.0, 2.0], [1.5, 2.0]];
# gradient is [[0.0, 0.0], [1.0, 1.0], [1.0, 0.0]]
([1.0], [2.0], [[0.0, 3.0], [1.0, 2.0], [1.5, 2.5]]),
# test what happens if a user puts a higher "min" value than their "max" value
# should clip to [[5.0, 5.0, 5.0, 5.0, 5.0]] because min is evaluated first
# gradient should be all zeros: [[0.0, 0.0, 0.0, 0.0, 0.0]]
([5.0], [0.5], [[1.5, 2.1, 0.9, -1.0, -2.0]]),
# test a more complicated broadcasting scenario
([[1.5, 2.0], [2.5, 3.0]], [[-2.0, 2.5], [2.5, 3.5]], [[-1.0, 2.0], [3.0, 4.0]]),
]
]
@pytest.mark.parametrize("min_value, max_value, x", CLIP_TUPLES)
def test_op_clip(min_value, max_value, x, device_id, precision):
def test_op_clip(min_value, max_value, x, device_id, precision):
from .. import clip
expected_forward = [np.clip(AA([x], dtype=PRECISION_TO_TYPE[precision]), AA(min_value, dtype=PRECISION_TO_TYPE[precision]), AA(max_value, dtype=PRECISION_TO_TYPE[precision]))]
expected_forward = [np.clip(AA([x], dtype=PRECISION_TO_TYPE[precision]), AA(
min_value, dtype=PRECISION_TO_TYPE[precision]), AA(max_value, dtype=PRECISION_TO_TYPE[precision]))]
expected_backward = {
'arg': [[np.array(np.logical_not(np.logical_or(np.greater(x, max_value), np.less(x, min_value))), dtype=PRECISION_TO_TYPE[precision])]]
}
'arg': [[np.array(np.logical_not(np.logical_or(np.greater(x, max_value), np.less(x, min_value))), dtype=PRECISION_TO_TYPE[precision])]]
}
_test_unary_op(precision, device_id, clip, x,
expected_forward, expected_backward,
{'min_value': min_value, 'max_value': max_value})
expected_forward, expected_backward,
{'min_value': min_value, 'max_value': max_value})
TENSORS = [
([[0, -0.1]]),
([[-100, -10], [-1, -0.1], [-0.01, -0.001],
[0.001, 0.01], [0.1, 1], [10, 100]]),
]
@pytest.mark.parametrize("operand", TENSORS)
def test_op_sigmoid(operand, device_id, precision):
s = 1.0 / (1.0 + np.exp(-AA(operand, dtype=PRECISION_TO_TYPE[precision])))
expected_forward = [AA([s])]
expected_backward = {
'arg': [[s * (1 - s)]],
}
'arg': [[s * (1 - s)]],
}
from .. import sigmoid
_test_unary_op(precision, device_id, sigmoid, operand,
expected_forward, expected_backward)
expected_forward, expected_backward)
@pytest.mark.parametrize("operand", TENSORS)
def test_op_exp(operand, device_id, precision):
@@ -76,12 +82,13 @@ def test_op_exp(operand, device_id, precision):
expected_forward = [AA([e])]
expected_backward = {
'arg': expected_forward,
}
'arg': expected_forward,
}
from .. import exp
_test_unary_op(precision, device_id, exp, operand,
expected_forward, expected_backward)
expected_forward, expected_backward)
@pytest.mark.parametrize("operand", TENSORS)
def test_op_abs(operand, device_id, precision):
@@ -93,12 +100,12 @@ def test_op_abs(operand, device_id, precision):
backward = operand / np.abs(operand)
backward[np.isnan(backward)] = 0
expected_backward = {
'arg': [[backward]]
}
'arg': [[backward]]
}
from .. import abs
_test_unary_op(precision, device_id, abs, operand,
expected_forward, expected_backward)
expected_forward, expected_backward)
@pytest.mark.parametrize("operand", TENSORS)
@@ -107,14 +114,15 @@ def test_op_tanh(operand, device_id, precision):
expected_forward = [AA([t])]
expected_backward = {
'arg': [[1 - t**2]],
}
'arg': [[1 - t**2]],
}
from .. import tanh
_test_unary_op(precision, device_id, tanh, operand,
expected_forward, expected_backward)
expected_forward, expected_backward)
@pytest.mark.parametrize("shape", [(3,9), (10,20,30)])
@pytest.mark.parametrize("shape", [(3, 9), (10, 20, 30)])
@pytest.mark.parametrize("dropout_rate", [0.0, 0.2, 0.5, 0.8])
def test_op_dropout(shape, dropout_rate, device_id, precision):
from cntk import dropout
@@ -129,41 +137,43 @@ def test_op_dropout(shape, dropout_rate, device_id, precision):
value = np.ones(shape=shape, dtype=PRECISION_TO_TYPE[precision])
a = I(shape=value.shape,
data_type=sanitize_dtype_cntk(PRECISION_TO_TYPE[precision]),
needs_gradient=True,
name='a')
data_type=sanitize_dtype_cntk(PRECISION_TO_TYPE[precision]),
needs_gradient=True,
name='a')
dropout_node = dropout(a, dropout_rate=dropout_rate)
value.shape = (1,1) + value.shape
forward_input = {a:value}
value.shape = (1, 1) + value.shape
forward_input = {a: value}
forward, backward = eval(dropout_node,
precision,
cntk_device(device_id),
forward_input,
backward_pass=True)
forward, backward = eval(dropout_node,
precision,
cntk_device(device_id),
forward_input,
backward_pass=True)
resulted_non_zeros += np.count_nonzero(forward[dropout_node.output()])
resulted_non_zeros /= count
num_elements = np.multiply.reduce(shape)
expected_non_zeros = num_elements * (1-dropout_rate)
max_off = 0.2*num_elements
expected_non_zeros = num_elements * (1 - dropout_rate)
max_off = 0.2 * num_elements
assert(abs(resulted_non_zeros - expected_non_zeros) <
max_off)
assert(abs(resulted_non_zeros-expected_non_zeros) <
max_off)
@pytest.mark.parametrize("dropout_rate", [-0.1, 1.0, 100])
def test_op_dropout_bad_input(dropout_rate):
from cntk import dropout
from cntk.utils import eval, sanitize_dtype_cntk, cntk_device
a = I(shape=(1,2), data_type='float', needs_gradient=True, name='a')
a = I(shape=(1, 2), data_type='float', needs_gradient=True, name='a')
with pytest.raises(ValueError):
dropout_node = dropout(a, dropout_rate=dropout_rate)
@pytest.mark.parametrize("operand", TENSORS)
def test_op_sqrt(operand, device_id, precision):
t = np.sqrt(AA(operand, dtype=PRECISION_TO_TYPE[precision]))
@@ -173,27 +183,30 @@ def test_op_sqrt(operand, device_id, precision):
backward = 1 / (2 * t)
expected_backward = {
'arg': [[backward]]
}
'arg': [[backward]]
}
from cntk import sqrt
_test_unary_op(precision, device_id, sqrt, operand,
expected_forward, expected_backward)
expected_forward, expected_backward)
@pytest.mark.parametrize("operand", TENSORS)
def test_op_square(operand, device_id, precision):
s = AA(operand, dtype=PRECISION_TO_TYPE[precision]) * AA(operand, dtype=PRECISION_TO_TYPE[precision])
s = AA(operand, dtype=PRECISION_TO_TYPE[
precision]) * AA(operand, dtype=PRECISION_TO_TYPE[precision])
expected_forward = [AA([s])]
backward = 2 * AA(operand, dtype=PRECISION_TO_TYPE[precision])
expected_backward = {
'arg': [[backward]]
}
'arg': [[backward]]
}
from cntk import square
_test_unary_op(precision, device_id, square, operand,
expected_forward, expected_backward)
expected_forward, expected_backward)
@pytest.mark.parametrize("operand", TENSORS)
def test_op_log(operand, device_id, precision):
@@ -207,16 +220,17 @@ def test_op_log(operand, device_id, precision):
backward[np.isnan(backward)] = "inf"
backward[np.isinf(backward)] = BACKWARD_RESULST_FOR_LOG_EPS
backward[backward<=0] = BACKWARD_RESULST_FOR_LOG_EPS
backward[backward <= 0] = BACKWARD_RESULST_FOR_LOG_EPS
expected_backward = {
'arg': [[backward]]
}
'arg': [[backward]]
}
from cntk import log
_test_unary_op(precision, device_id, log, operand,
expected_forward, expected_backward)
expected_forward, expected_backward)
@pytest.mark.parametrize("operand", TENSORS)
def test_op_reciprocal(operand, device_id, precision):
@@ -227,13 +241,14 @@ def test_op_reciprocal(operand, device_id, precision):
backward = -1 * t * t
expected_backward = {
'arg': [[backward]]
}
'arg': [[backward]]
}
from cntk import reciprocal
_test_unary_op(precision, device_id, reciprocal, operand,
expected_forward, expected_backward)
expected_forward, expected_backward)
@pytest.mark.parametrize("operand", TENSORS)
def test_op_relu(operand, device_id, precision):
@@ -241,19 +256,20 @@ def test_op_relu(operand, device_id, precision):
expected_forward = [[np.maximum(np.zeros_like(t), t)]]
expected_backward = {
'arg' : [[AA(t > np.zeros_like(t), dtype = int)]]
}
'arg': [[AA(t > np.zeros_like(t), dtype=int)]]
}
from cntk import relu
_test_unary_op(precision, device_id, relu, operand,
expected_forward, expected_backward)
expected_forward, expected_backward)
SAMPLES = [ # 3 samples having 4 classes
[1, 1, 2, 3],
[0, 0, 0, 0],
[3, 3, 4, 4]
]
[1, 1, 2, 3],
[0, 0, 0, 0],
[3, 3, 4, 4]
]
@pytest.mark.parametrize("sample", SAMPLES)
def test_op_softmax(sample, device_id, precision):
@@ -267,7 +283,8 @@ def test_op_softmax(sample, device_id, precision):
expected_forward = [AA([forward])]
sample_length = len(forward)
grad = np.zeros((sample_length, sample_length), dtype=PRECISION_TO_TYPE[precision])
grad = np.zeros((sample_length, sample_length),
dtype=PRECISION_TO_TYPE[precision])
for i in range(sample_length):
for j in range(sample_length):
@@ -279,13 +296,14 @@ def test_op_softmax(sample, device_id, precision):
backward = grad.sum(axis=0)
expected_backward = {
'arg' : [[backward]]
}
'arg': [[backward]]
}
from cntk import softmax
_test_unary_op(precision, device_id, softmax, sample,
expected_forward, expected_backward)
expected_forward, expected_backward)
@pytest.mark.parametrize("sample", SAMPLES)
def test_op_hardmax(sample, device_id, precision):
@ -294,7 +312,7 @@ def test_op_hardmax(sample, device_id, precision):
forward = np.zeros_like(t, dtype=PRECISION_TO_TYPE[precision])
for i,x in enumerate(t):
for i, x in enumerate(t):
if x == t_max:
forward[i] = 1
break
@ -302,10 +320,10 @@ def test_op_hardmax(sample, device_id, precision):
expected_forward = [AA([forward])]
expected_backward = {
'arg' : [[np.zeros_like(forward)]]
}
'arg': [[np.zeros_like(forward)]]
}
from cntk import hardmax
_test_unary_op(precision, device_id, hardmax, sample,
expected_forward, expected_backward)
expected_forward, expected_backward)

View file

@ -19,60 +19,64 @@ from .. import constant, input_variable
I = input_variable
@pytest.fixture(params=["dense", "sparse"])
def left_matrix_type(request):
return request.param
@pytest.fixture(params=["dense", "sparse"])
def right_matrix_type(request):
return request.param
def _test_unary_op(precision, device_id, op_func,
value, expected_forward, expected_backward_all, op_param_dict=None):
value, expected_forward, expected_backward_all, op_param_dict=None):
value = AA(value, dtype=PRECISION_TO_TYPE[precision])
a = I(shape=value.shape,
data_type=sanitize_dtype_cntk(PRECISION_TO_TYPE[precision]),
needs_gradient=True,
name='a')
data_type=sanitize_dtype_cntk(PRECISION_TO_TYPE[precision]),
needs_gradient=True,
name='a')
# create batch
value.shape = (1,1) + value.shape
value.shape = (1, 1) + value.shape
if (type(op_func) == str):
input_op = eval('%s a'%op_func)
input_op = eval('%s a' % op_func)
elif op_param_dict:
input_op = op_func(a, **op_param_dict)
else:
input_op = op_func(a)
forward_input = {a:value}
expected_backward = { a: expected_backward_all['arg'], }
unittest_helper(input_op,
forward_input, expected_forward, expected_backward,
device_id=device_id, precision=precision)
forward_input = {a: value}
expected_backward = {a: expected_backward_all['arg'], }
unittest_helper(input_op,
forward_input, expected_forward, expected_backward,
device_id=device_id, precision=precision)
def _test_binary_op(precision, device_id, op_func, left_operand, right_operand,
expected_forward, expected_backward_all, only_input_variables=False):
left_value = AA(left_operand, dtype=PRECISION_TO_TYPE[precision])
expected_forward, expected_backward_all, only_input_variables=False):
left_value = AA(left_operand, dtype=PRECISION_TO_TYPE[precision])
right_value = AA(right_operand, dtype=PRECISION_TO_TYPE[precision])
a = I(shape=left_value.shape,
data_type=sanitize_dtype_cntk(precision),
needs_gradient=True,
name='a')
data_type=sanitize_dtype_cntk(precision),
needs_gradient=True,
name='a')
b = I(shape=right_value.shape,
data_type=sanitize_dtype_cntk(precision),
needs_gradient=True,
name='b')
data_type=sanitize_dtype_cntk(precision),
needs_gradient=True,
name='b')
if (type(op_func) == str):
input_op_constant = eval('a %s right_operand'%op_func)
constant_op_input = eval('left_operand %s b'%op_func)
input_op_input = eval('a %s b'%op_func)
input_op_constant = eval('a %s right_operand' % op_func)
constant_op_input = eval('left_operand %s b' % op_func)
input_op_input = eval('a %s b' % op_func)
else:
input_op_constant = op_func(a, right_value)
constant_op_input = op_func(left_value, b)
@ -80,38 +84,40 @@ def _test_binary_op(precision, device_id, op_func, left_operand, right_operand,
# create batch by wrapping the data point into a sequence of length one and
# putting it into a batch of one sample
left_value.shape = (1,1) + left_value.shape
right_value.shape = (1,1) + right_value.shape
left_value.shape = (1, 1) + left_value.shape
right_value.shape = (1, 1) + right_value.shape
forward_input = {a:left_value, b:right_value}
expected_backward = { a: expected_backward_all['left_arg'], b: expected_backward_all['right_arg'], }
forward_input = {a: left_value, b: right_value}
expected_backward = {a: expected_backward_all[
'left_arg'], b: expected_backward_all['right_arg'], }
unittest_helper(input_op_input,
forward_input, expected_forward, expected_backward,
device_id=device_id, precision=precision)
forward_input, expected_forward, expected_backward,
device_id=device_id, precision=precision)
if not only_input_variables:
forward_input = {a:left_value}
expected_backward = { a: expected_backward_all['left_arg'], }
forward_input = {a: left_value}
expected_backward = {a: expected_backward_all['left_arg'], }
unittest_helper(input_op_constant,
forward_input, expected_forward, expected_backward,
device_id=device_id, precision=precision)
forward_input, expected_forward, expected_backward,
device_id=device_id, precision=precision)
forward_input = {b:right_value}
expected_backward = { b: expected_backward_all['right_arg'], }
forward_input = {b: right_value}
expected_backward = {b: expected_backward_all['right_arg'], }
unittest_helper(constant_op_input,
forward_input, expected_forward, expected_backward,
device_id=device_id, precision=precision)
forward_input, expected_forward, expected_backward,
device_id=device_id, precision=precision)
def unittest_helper(root_node,
forward_input, expected_forward, expected_backward,
device_id=-1, precision="float"):
def unittest_helper(root_node,
forward_input, expected_forward, expected_backward,
device_id=-1, precision="float"):
backward_pass = expected_backward is not None
forward, backward = cntk_eval(root_node, precision, cntk_device(device_id),
forward_input, backward_pass)
forward_input, backward_pass)
# for forward we always expect only one result
assert len(forward)==1
assert len(forward) == 1
forward = list(forward.values())[0]
forward = np.atleast_1d(forward)
@ -120,19 +126,21 @@ def unittest_helper(root_node,
assert res.shape == AA(exp).shape
assert np.allclose(res, exp, atol=TOLERANCE_ABSOLUTE)
if expected_backward:
if expected_backward:
for key in expected_backward:
res, exp = backward[key], expected_backward[key]
if isinstance(res, list):
assert len(res) == len(exp)
for res_seq, exp_seq in zip (res, exp):
for res_seq, exp_seq in zip(res, exp):
assert res_seq.shape == AA(exp_seq).shape
assert np.allclose(res_seq, exp_seq, atol=TOLERANCE_ABSOLUTE)
assert np.allclose(
res_seq, exp_seq, atol=TOLERANCE_ABSOLUTE)
elif isinstance(res, np.ndarray):
assert res.shape == AA(exp).shape
assert np.allclose(res, exp, atol=TOLERANCE_ABSOLUTE)
def batch_dense_to_sparse(batch, dynamic_axis=''):
'''
Helper test function that converts a batch of dense tensors into sparse
@ -170,46 +178,48 @@ def batch_dense_to_sparse(batch, dynamic_axis=''):
t_indices = range(tensor.size)
t_values = tensor.ravel(order='F')
mask = t_values!=0
mask = t_values != 0
batch_indices.append(list(np.asarray(t_indices)[mask]))
batch_values.append(list(np.asarray(t_values)[mask]))
return batch_indices, batch_values, shapes_in_tensor.pop()
def test_batch_dense_to_sparse_full():
i, v, s = batch_dense_to_sparse(
[
[[1,2,3], [4,5,6]],
[[10,20,30], [40,50,60]],
])
[
[[1, 2, 3], [4, 5, 6]],
[[10, 20, 30], [40, 50, 60]],
])
assert i == [
[0, 1, 2, 3, 4, 5],
[0, 1, 2, 3, 4, 5],
]
[0, 1, 2, 3, 4, 5],
[0, 1, 2, 3, 4, 5],
]
assert v == [
[1,4,2,5,3,6],
[10,40,20,50,30,60]
]
assert s == (2,3)
[1, 4, 2, 5, 3, 6],
[10, 40, 20, 50, 30, 60]
]
assert s == (2, 3)
i, v, s = batch_dense_to_sparse([[1]])
assert i == [[0]]
assert v == [[1]]
assert s == (1,)
def test_batch_dense_to_sparse_zeros():
i, v, s = batch_dense_to_sparse(
[
[[1,2,3], [4,0,6]],
[[0,0,0], [40,50,60]],
])
[
[[1, 2, 3], [4, 0, 6]],
[[0, 0, 0], [40, 50, 60]],
])
assert i == [
[0, 1, 2, 4, 5],
[1, 3, 5],
]
[0, 1, 2, 4, 5],
[1, 3, 5],
]
assert v == [
[1,4,2,3,6],
[40,50,60]
]
assert s == (2,3)
[1, 4, 2, 3, 6],
[40, 50, 60]
]
assert s == (2, 3)

View file

@ -16,8 +16,9 @@ import cntk as C
from ...utils import sanitize_dtype_cntk, precision_numpy
EPS_IN_LOG = 1e-37 # 1e-37 is the highest guaranteed precision
BACKWARD_RESULST_FOR_LOG_EPS = 9.08782e+36 # the backward result returned by CNTK log() for epsilon
LOG_OF_EPS_IN_LOG = -85.1 # log(EPS_IN_LOG)
# the backward result returned by CNTK log() for epsilon
BACKWARD_RESULST_FOR_LOG_EPS = 9.08782e+36
LOG_OF_EPS_IN_LOG = -85.1 # log(EPS_IN_LOG)
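A quick sanity check of these constants, assuming inputs to log() are clamped to EPS_IN_LOG and that the gradient of log(x) is 1/x: log(1e-37) = -37 * ln(10) ≈ -85.2, matching LOG_OF_EPS_IN_LOG, while 1/1e-37 = 1e+37 is the order of magnitude of the clamped backward value above:

import numpy as np
eps = np.float32(1e-37)
np.log(eps)   # ~ -85.2
1.0 / eps     # ~ 1e+37; CNTK reports 9.08782e+36 for the clamped input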
RESHAPE_TEST_CASES = [
#(input_shape, output_shape, expected_output_shape)
@ -27,69 +28,75 @@ RESHAPE_TEST_CASES = [
((6, 1), (2, 3), (2, 3)),
((2, 3, 5), (5, 6), (5, 6)),
# now we test the feature that we can set one dimension of the output_shape to 0, meaning that its value is inferred
#FIXME 0 is for some reason not supported yet
#((2, 3, 5), (0, 6), (5, 6)),
# FIXME 0 is for some reason not supported yet
#((2, 3, 5), (0, 6), (5, 6)),
#((2, 3, 5), (5, 0), (5, 6)),
]
@pytest.mark.parametrize("input_shape, output_shape, expected_output_shape", RESHAPE_TEST_CASES)
def test_op_reshape(input_shape, output_shape, expected_output_shape, device_id, precision):
# Reshaping is just moving the input values to different indexes of the result tensor.
# If we computed the gradients on the unmodified tensor, reshape would get 1 for all inputs.
# For testing the gradients we want a different gradient for each input index; otherwise we
# could not tell whether they get wrongly permuted during the test. To this end we multiply
# the reshaping result with itself.
from ...utils import sanitize_dtype_cntk
from .. import reshape, element_times
num_tensor_elements = np.multiply.reduce(input_shape)
input_tensor = np.arange(num_tensor_elements, dtype=PRECISION_TO_TYPE[precision])
input_tensor = np.arange(
num_tensor_elements, dtype=PRECISION_TO_TYPE[precision])
input_reshaped = input_tensor.reshape(expected_output_shape)
a = I(shape=input_tensor.shape,
data_type=sanitize_dtype_cntk(PRECISION_TO_TYPE[precision]),
needs_gradient=True,
name='a')
data_type=sanitize_dtype_cntk(PRECISION_TO_TYPE[precision]),
needs_gradient=True,
name='a')
a_reshaped = reshape(a, output_shape)
input_op = element_times(a_reshaped, input_reshaped)
expected_forward = [[input_reshaped**2]]
expected_backward = { a: input_tensor }
expected_backward = {a: input_tensor}
# create batch
input_tensor.shape = (1,1) + input_tensor.shape
input_tensor.shape = (1, 1) + input_tensor.shape
forward_input = {a:input_tensor}
forward_input = {a: input_tensor}
unittest_helper(input_op,
forward_input, expected_forward, expected_backward,
device_id=device_id, precision=precision)
unittest_helper(input_op,
forward_input, expected_forward, expected_backward,
device_id=device_id, precision=precision)
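A small NumPy illustration of the trick described in the comment above (a sketch of the math, not test code): with y = reshape(a) * c, where c is a constant equal to the reshaped input, the gradient flowing back to a is c routed through the inverse reshape, i.e. the original input values, so every index receives a distinct gradient:

import numpy as np
a = np.arange(6.)               # distinct input values 0..5
c = a.reshape(2, 3)             # the constant multiplier used by the test
forward = c * c                 # == input_reshaped**2, as asserted above
backward = c.reshape(a.shape)   # == a: each input index gets its own gradient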
def test_op_reshape_bad_input():
from .. import reshape
a = I(shape=(4,5))
a = I(shape=(4, 5))
with pytest.raises(ValueError):
reshape(a, (-1,2,3))
reshape(a, (-1, 2, 3))
SLICE_TEST_CASES_STATIC = [
#(input_data, slice_params(beg_index, end_index,axis), expected_result)
([[1,2],[-3,4]], (1,2,0), [[-3,4]]),
([[1, 2], [-3, 4]], (1, 2, 0), [[-3, 4]]),
# FIXME slicing on axes >0 is not supported yet
# ([[1,2],[-3,4]], (1,2,1), [[2],[4]]),
]
@pytest.mark.parametrize("input_data, slice_params, expected_result",
SLICE_TEST_CASES_STATIC)
SLICE_TEST_CASES_STATIC)
def test_op_slice(input_data, slice_params, expected_result, device_id, precision):
input_data = AA(input_data, dtype=PRECISION_TO_TYPE[precision])
a = I(shape=input_data.shape,
data_type=sanitize_dtype_cntk(PRECISION_TO_TYPE[precision]),
needs_gradient=True,
name='a')
data_type=sanitize_dtype_cntk(PRECISION_TO_TYPE[precision]),
needs_gradient=True,
name='a')
def _ax_slices(x, beg_index, end_index, axis):
'''
@ -97,24 +104,24 @@ def test_op_slice(input_data, slice_params, expected_result, device_id, precisio
'''
ax_slices = []
for i in range(0, len(x.shape)):
if i==axis:
if i == axis:
if end_index >= x.shape[i]:
ax_slices.append([beg_index,])
ax_slices.append([beg_index, ])
else:
ax_slices.append([beg_index,end_index])
ax_slices.append([beg_index, end_index])
else:
ax_slices.append(slice(None)) # corresponds to ':'
ax_slices.append(slice(None)) # corresponds to ':'
return ax_slices
# slice using the overload
if False:  # FIXME remove once the overloads are in place
# slice using the operator
result = C.slice(a, *slice_params)
ax_slices = _ax_slices(a, *slice_params)
result = a[ax_slices]
unittest_helper(result, None, [[expected_result]], device_id=device_id,
precision=precision, clean_up=True, backward_pass=False)
unittest_helper(result, None, [[expected_result]], device_id=device_id,
precision=precision, clean_up=True, backward_pass=False)
# Backward pass test
# ==================
@ -126,48 +133,52 @@ def test_op_slice(input_data, slice_params, expected_result, device_id, precisio
res = np.zeros_like(x)
ax_slices = _ax_slices(x, beg_index, end_index, axis)
res[ax_slices] = x[ax_slices]
res[res!=0] = 1
res[res != 0] = 1
return res
expected_forward = [AA([expected_result], dtype=PRECISION_TO_TYPE[precision])]
expected_forward = [
AA([expected_result], dtype=PRECISION_TO_TYPE[precision])]
expected_backward = {
'arg': [[grad_slice(np.asarray(input_data), *slice_params)]]
}
'arg': [[grad_slice(np.asarray(input_data), *slice_params)]]
}
_test_unary_op(precision, device_id, C.slice, input_data,
expected_forward, expected_backward,
{ 'begin_index': slice_params[0],
'end_index': slice_params[1],
'axis': slice_params[2] })
expected_forward, expected_backward,
{'begin_index': slice_params[0],
'end_index': slice_params[1],
'axis': slice_params[2]})
SLICE_TEST_CASES_DYNAMIC = [
#(input_data, slice_params(beg_index, end_index), expected_result)
# Note that input_data contains sequences
([[[1,2,3]],[[-4,5,6]],[[7,8,9]]],
(0,2),
[[[1,2,3]],[[-4,5,6]]]),
([[[1,2,3],[11,12,13]],[[-4,5,6],[-14,15,16]],[[7,8,9],[17,18,19]]],
(0,2),
[[[1,2,3],[11,12,13]],[[-4,5,6],[-14,15,16]]]),
([[[1,2,3],[11,12,13]],[[-4,5,6],[-14,15,16]],[[7,8,9],[17,18,19]]],
(1,2),
[[[-4,5,6],[-14,15,16]]]),
([[[1, 2, 3]], [[-4, 5, 6]], [[7, 8, 9]]],
(0, 2),
[[[1, 2, 3]], [[-4, 5, 6]]]),
([[[1, 2, 3], [11, 12, 13]], [[-4, 5, 6], [-14, 15, 16]], [[7, 8, 9], [17, 18, 19]]],
(0, 2),
[[[1, 2, 3], [11, 12, 13]], [[-4, 5, 6], [-14, 15, 16]]]),
([[[1, 2, 3], [11, 12, 13]], [[-4, 5, 6], [-14, 15, 16]], [[7, 8, 9], [17, 18, 19]]],
(1, 2),
[[[-4, 5, 6], [-14, 15, 16]]]),
]
@pytest.mark.parametrize("input_data, slice_params, expected_result",
SLICE_TEST_CASES_DYNAMIC)
#FIXME enable once the ZeroesLike RuntimeError is fixed
SLICE_TEST_CASES_DYNAMIC)
# FIXME enable once the ZeroesLike RuntimeError is fixed
def test_op_slice_sequence(input_data, slice_params, expected_result, device_id, precision):
input_data = AA(input_data, dtype=PRECISION_TO_TYPE[precision])
t = C.Axis.new_unique_dynamic_axis('t')
sample_shape = input_data.shape[1:]
a = I(shape=sample_shape,
data_type=sanitize_dtype_cntk(PRECISION_TO_TYPE[precision]),
needs_gradient=True,
dynamic_axes=[C.Axis.default_batch_axis(), t],
name='a')
data_type=sanitize_dtype_cntk(PRECISION_TO_TYPE[precision]),
needs_gradient=True,
dynamic_axes=[C.Axis.default_batch_axis(), t],
name='a')
result = C.slice(a, axis=t, begin_index=slice_params[0], end_index=slice_params[1])
result = C.slice(a, axis=t, begin_index=slice_params[0],
                 end_index=slice_params[1])
def grad_slice(x, beg_index, end_index):
res = np.zeros_like(x)
@ -175,57 +186,62 @@ def test_op_slice_sequence(input_data, slice_params, expected_result, device_id,
return res
expected_gradient = grad_slice(np.asarray(input_data), *slice_params)
expected_forward = AA([expected_result], dtype=PRECISION_TO_TYPE[precision])
expected_backward = {
a: [grad_slice(np.asarray(input_data), *slice_params)]
}
# create batch
input_data.shape = (1,) + input_data.shape
forward_input = {a:input_data}
unittest_helper(result,
forward_input, expected_forward, expected_backward,
device_id=device_id, precision=precision)
expected_forward = AA(
[expected_result], dtype=PRECISION_TO_TYPE[precision])
expected_backward = {
a: [grad_slice(np.asarray(input_data), *slice_params)]
}
# create batch
input_data.shape = (1,) + input_data.shape
forward_input = {a: input_data}
unittest_helper(result,
forward_input, expected_forward, expected_backward,
device_id=device_id, precision=precision)
# FIXME once the overloads are in place, integrate test_op_slice_overload from
# F:\CNTKv2\contrib\Python\cntk\ops\tests\reshaping_test.py
SPLICE_TEST_CASES = [
#(input_data1, input_data2, axis, expected_result)
([1], [2], 0, [1,2]),
([[1,2],[4,5]], [[10,20],[30, 40],[50, 60]], 0,
[[1, 2],[4, 5],[10, 20],[30, 40],[50, 60]]),
([[1,2],[4,5]], [[10,20,30],[40, 50, 60]], 1,
[[1,2,10,20,30],[4,5,40,50,60]]),
([[[1,2],[3,4]],[[5,6],[7,8]]], [[10,20],[30,40]], 0,
[[[1,2],[3,4]],[[5,6],[7,8]],[[10,20],[30,40]]]),
([1], [2], 0, [1, 2]),
([[1, 2], [4, 5]], [[10, 20], [30, 40], [50, 60]], 0,
[[1, 2], [4, 5], [10, 20], [30, 40], [50, 60]]),
([[1, 2], [4, 5]], [[10, 20, 30], [40, 50, 60]], 1,
[[1, 2, 10, 20, 30], [4, 5, 40, 50, 60]]),
([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], [[10, 20], [30, 40]], 0,
[[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[10, 20], [30, 40]]]),
]
@pytest.mark.parametrize("input_data1, input_data2, axis, expected_result", SPLICE_TEST_CASES)
def test_op_splice(input_data1, input_data2, axis, expected_result, device_id, precision):
# FIXME This test currently fails in C++ with
# RuntimeError: Node 'splice_ab' (RowStack operation): Attempted to type-cast node to struct Microsoft::MSR::CNTK::INumInputs, which is not possible.
# FIXME This test currently fails in C++ with
# RuntimeError: Node 'splice_ab' (RowStack operation): Attempted to
# type-cast node to struct Microsoft::MSR::CNTK::INumInputs, which is not
# possible.
input_data1 = AA(input_data1, dtype=PRECISION_TO_TYPE[precision])
input_data2 = AA(input_data2, dtype=PRECISION_TO_TYPE[precision])
a = I(shape=input_data1.shape,
data_type=sanitize_dtype_cntk(PRECISION_TO_TYPE[precision]),
needs_gradient=True,
name='a')
data_type=sanitize_dtype_cntk(PRECISION_TO_TYPE[precision]),
needs_gradient=True,
name='a')
b = I(shape=input_data2.shape,
data_type=sanitize_dtype_cntk(PRECISION_TO_TYPE[precision]),
needs_gradient=True,
name='b')
data_type=sanitize_dtype_cntk(PRECISION_TO_TYPE[precision]),
needs_gradient=True,
name='b')
# create batch
input_data1.shape = (1,1) + input_data1.shape
input_data2.shape = (1,1) + input_data2.shape
input_data1.shape = (1, 1) + input_data1.shape
input_data2.shape = (1, 1) + input_data2.shape
# splice using the operator
root_op = C.splice((a, b), axis, name='splice_ab')
forward_input = {a:input_data1, b:input_data2}
forward_input = {a: input_data1, b: input_data2}
# Backward pass test
# ==================
@ -236,10 +252,10 @@ def test_op_splice(input_data1, input_data2, axis, expected_result, device_id, p
expected_forward = [[expected_result]]
expected_backward = {
a: grad_splice(np.asarray(input_data1)),
b: grad_splice(np.asarray(input_data2))
}
a: grad_splice(np.asarray(input_data1)),
b: grad_splice(np.asarray(input_data2))
}
unittest_helper(root_op,
forward_input, expected_forward, expected_backward,
device_id=device_id, precision=precision)
unittest_helper(root_op,
forward_input, expected_forward, expected_backward,
device_id=device_id, precision=precision)

View file

@ -4,11 +4,12 @@ from cntk import DATATYPE
from cntk.tensor import TensorOpsMixin
from .. import utils
FLOAT_32='float32'
FLOAT_32 = 'float32'
def _sanitize_value(shape, value, dtype, device):
np_dtype = utils.sanitize_dtype_numpy(dtype)
cntk_dtype = utils.sanitize_dtype_cntk(dtype)
cntk_dtype = utils.sanitize_dtype_cntk(dtype)
if value is None:
if shape is None:
@ -16,7 +17,7 @@ def _sanitize_value(shape, value, dtype, device):
shape = utils.sanitize_shape(shape)
ndav = utils.create_NDArrayView(shape, cntk_dtype, device)
else:
if not isinstance(value, np.ndarray) or value.dtype!=np_dtype:
if not isinstance(value, np.ndarray) or value.dtype != np_dtype:
if np.isscalar(value) and shape:
value = np.full(shape, value, dtype=np_dtype)
else:
@ -26,18 +27,23 @@ def _sanitize_value(shape, value, dtype, device):
return ndav
class Variable(TensorOpsMixin,Variable):
class Variable(TensorOpsMixin, Variable):
def __init__(self, shape=None, data_type=None, needs_gradient=False, is_sparse=False,
dynamic_axes = [Axis.default_dynamic_axis(), Axis.default_batch_axis()], name=''):
dynamic_axes=[Axis.default_dynamic_axis(), Axis.default_batch_axis()], name=''):
shape = utils.sanitize_shape(shape)
if data_type is None:
data_type = FLOAT_32
dtype = utils.sanitize_dtype_cntk(data_type)
super(Variable, self).__init__(shape, is_sparse, dtype, needs_gradient, name, dynamic_axes)
super(Variable, self).__init__(shape, is_sparse,
dtype, needs_gradient, name, dynamic_axes)
class Parameter(TensorOpsMixin, Parameter):
class Parameter(TensorOpsMixin,Parameter):
def __init__(self, shape=None, value=None, data_type=None,
initializer=None, device=None, name=''):
@ -59,7 +65,8 @@ class Parameter(TensorOpsMixin,Parameter):
super(Parameter, self).__init__(ndav, name)
class Constant(TensorOpsMixin,Constant):
class Constant(TensorOpsMixin, Constant):
def __init__(self, shape=None, value=None, data_type=None,
device=None, name=''):
@ -71,4 +78,3 @@ class Constant(TensorOpsMixin,Constant):
ndav = _sanitize_value(shape, value, data_type, device)
super(Constant, self).__init__(ndav, name)

View file

@ -1,6 +1,7 @@
import cntk.cntk_py as cntk_py
class PyCallback(cntk_py.Callback):
def __init__(self):
@ -8,13 +9,12 @@ class PyCallback(cntk_py.Callback):
def forward(self):
print("PyCallback.forward()")
1/0
1 / 0
def backward(self):
print("PyCallback.backward()")
def callback_test():
op = cntk_py.FunctionInCNTK()
@ -36,7 +36,7 @@ def callback_test():
op.delCallback()
if __name__=='__main__':
import time
if __name__ == '__main__':
import time
callback_test()
#cntk_py.exception_tester()
# cntk_py.exception_tester()

View file

@ -68,28 +68,32 @@ class TensorOpsMixin(object):
def __abs__(self):
from . import ops
return ops.abs(self)
def __neg__(self):
from . import ops
return ops.negate(self)
# TODO __lt__, __le__, __gt__, __ge__, __and__, __rand__, __or__, __ror__, __xor__, __rxor__, __pow__, __rpow__, __invert__
# TODO __lt__, __le__, __gt__, __ge__, __and__, __rand__, __or__, __ror__,
# __xor__, __rxor__, __pow__, __rpow__, __invert__
def __getitem__(self, key):
from . import ops
if isinstance(key, int):
# Case 1: e.g. data[3] -> key=3
return ops.slice(self, key, key+1, axis=0)
return ops.slice(self, key, key + 1, axis=0)
elif isinstance(key, slice):
# Case 2: e.g. data[2:4] -> key will be a slice object
if key.step is not None:
raise TypeError('step argument is not supported')
if not isinstance(key.stop, int):
raise TypeError('end index has to be of type int, not "%s"'%type(key.stop))
raise TypeError(
'end index has to be of type int, not "%s"' % type(key.stop))
if isinstance(key.start, int):
if key.stop<=key.start:
raise ValueError('end index has to be greater than start index')
if key.stop <= key.start:
raise ValueError(
'end index has to be greater than start index')
return ops.slice(self, key.start or 0, key.stop or 0, axis=0)
elif isinstance(key, (tuple, list)):
@ -101,18 +105,20 @@ class TensorOpsMixin(object):
for ax_counter, so in enumerate(key):
if isinstance(so, int):
# Proceed as case 1
node = ops.slice(node, so, so+1, axis=ax_counter)
node = ops.slice(node, so, so + 1, axis=ax_counter)
elif isinstance(so, slice):
# Proceed as case 2
if so.step is not None:
raise TypeError('step argument is not supported')
if isinstance(so.start, int) and isinstance(so.stop, int):
if so.stop<=so.start:
raise ValueError('end index has to be greater than start index')
if so.stop <= so.start:
raise ValueError(
'end index has to be greater than start index')
if so.start is None and so.stop is None:
continue
node = ops.slice(node, so.start or 0, so.stop or 0, axis=ax_counter)
node = ops.slice(node, so.start or 0,
so.stop or 0, axis=ax_counter)
elif isinstance(so, list):
# Case 3b: e.g. data[[0],[2,3]] aka "advanced indexing" ->
# so = ([0], [2,3])
@ -121,37 +127,44 @@ class TensorOpsMixin(object):
# we decided to have all shapes like data[0] in this case
for idx in so:
if not isinstance(idx, int):
raise IndexError('indices have to be of type int and not "%s"'%type(idx))
node = ops.slice(node, idx, idx+1, axis=ax_counter)
raise IndexError(
'indices have to be of type int and not "%s"' % type(idx))
node = ops.slice(node, idx, idx + 1, axis=ax_counter)
else:
raise IndexError('type "%s" is not supported as index'%type(so))
raise IndexError(
'type "%s" is not supported as index' % type(so))
return node
else:
raise TypeError('index must be int or slice, not {}'.format(type(key).__name__))
raise TypeError(
'index must be int or slice, not {}'.format(type(key).__name__))
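To make the three cases concrete, a hypothetical usage sketch (assuming a variable x created with static shape (4, 5); the names are illustrative):

x = input_variable((4, 5))
y1 = x[3]         # case 1: ops.slice(x, 3, 4, axis=0)
y2 = x[1:3]       # case 2: ops.slice(x, 1, 3, axis=0)
y3 = x[0, 2:4]    # case 3: slice along axis 0, then along axis 1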
AVAILABLE_TENSOR_OPS = ['abs', 'add', 'div', 'getitem', 'matmul', 'mul',
'radd', 'rdiv', 'rmatmul', 'rmul', 'rsub', 'rtruediv', 'sub',
'truediv', 'neg']
'radd', 'rdiv', 'rmatmul', 'rmul', 'rsub', 'rtruediv', 'sub',
'truediv', 'neg']
def _add_tensor_ops(klass):
for op_name in AVAILABLE_TENSOR_OPS:
overload_name = '__%s__'%op_name
overload_name = '__%s__' % op_name
if getattr(klass, overload_name, None):
raise ValueError('class "%s" already has operator overload "%s"'%\
(klass, overload_name))
raise ValueError('class "%s" already has operator overload "%s"' %
(klass, overload_name))
setattr(klass, overload_name, getattr(TensorOpsMixin, overload_name))
class EvalMixin(object):
def eval(self, input_map=None):
from .utils import eval as utils_eval
from . import DeviceDescriptor
device = DeviceDescriptor.cpu_device()
if len(self.outputs())!=1:
raise ValueError('only operators with exactly one output can be evaluated')
if len(self.outputs()) != 1:
raise ValueError(
'only operators with exactly one output can be evaluated')
if input_map is None:
input_map = {}
@ -164,8 +177,7 @@ def _add_eval(klass):
overload_name = 'eval'
if getattr(klass, overload_name, None):
raise ValueError('class "%s" already has operator overload "%s"'%\
(klass, overload_name))
setattr(klass, overload_name, getattr(EvalMixin, overload_name))
raise ValueError('class "%s" already has operator overload "%s"' %
(klass, overload_name))
setattr(klass, overload_name, getattr(EvalMixin, overload_name))

View file

@ -10,19 +10,23 @@ import pytest
from ..initializer import *
from .. import parameter, input_variable, momentums_per_sample
def _check(init, name):
p = parameter(shape=(10,20,5), initializer=init)
p = parameter(shape=(10, 20, 5), initializer=init)
assert np.allclose(np.average(p.value().to_numpy()), 0, atol=0.1), name
assert np.var(p.value().to_numpy()) > 0.01, name
def test_initializer_init():
_check(uniform_initializer(scale=10), 'uniform')
_check(gaussian_initializer(output_rank=1, filter_rank=2, scale=10), 'gaussian')
_check(gaussian_initializer(output_rank=1,
filter_rank=2, scale=10), 'gaussian')
_check(xavier_initializer(output_rank=1, filter_rank=2, scale=10), 'xavier')
_check(glorot_uniform_initializer(output_rank=1, filter_rank=2, scale=10), 'glorot_uniform')
_check(glorot_normal_initializer(output_rank=1, filter_rank=2, scale=10), 'glorot_normal')
_check(he_uniform_initializer(output_rank=1, filter_rank=2, scale=10), 'he_uniform')
_check(he_normal_initializer(output_rank=1, filter_rank=2, scale=10), 'he_normal')
_check(glorot_uniform_initializer(output_rank=1,
filter_rank=2, scale=10), 'glorot_uniform')
_check(glorot_normal_initializer(output_rank=1,
filter_rank=2, scale=10), 'glorot_normal')
_check(he_uniform_initializer(output_rank=1,
filter_rank=2, scale=10), 'he_uniform')
_check(he_normal_initializer(output_rank=1,
filter_rank=2, scale=10), 'he_normal')

View file

@ -14,18 +14,20 @@ import pytest
def test_learner_init():
# TODO Test functionality
i = input_variable(shape=(1,),
needs_gradient=True,
name='a')
needs_gradient=True,
name='a')
w = parameter(shape=(1,))
res = i*w
res = i * w
sgd_learner(res.parameters(), lr=0.1)
momentum_time_constant = 1100
momentum_per_sample = momentums_per_sample(math.exp(-1.0 / momentum_time_constant))
momentum_per_sample = momentums_per_sample(
math.exp(-1.0 / momentum_time_constant))
momentum_sgd_learner(res.parameters(), lr=0.1, momentums=momentum_per_sample)
momentum_sgd_learner(res.parameters(), lr=0.1,
momentums=momentum_per_sample)
nesterov_learner(res.parameters(), lr=0.1, momentums=momentum_per_sample)
@ -33,5 +35,5 @@ def test_learner_init():
fsadagrad_learner(res.parameters(), lr=0.1, momentums=momentum_per_sample)
gamma, inc, dec, max, min = [0.1]*5
gamma, inc, dec, max, min = [0.1] * 5
rmsprop_learner(res.parameters(), 0.1, gamma, inc, dec, max, min, True)

View file

@ -19,7 +19,7 @@ PRECISION_TO_TYPE = {'float': np.float32, 'double': np.float64}
AA = np.asarray
@pytest.fixture(params=["float", "double"])
def precision(request):
return request.param

View file

@ -11,6 +11,7 @@ import numpy as np
import scipy.sparse
from cntk import cntk_py
def precision_numpy(precision):
'''
Converts string precision to NumPy precision
@ -26,7 +27,8 @@ def precision_numpy(precision):
elif precision == 'double':
return np.float64
else:
raise ValueError('precision value: "%s" is not supported'%precision)
raise ValueError('precision value: "%s" is not supported' % precision)
def cntk_device(device_id):
'''
@ -38,11 +40,12 @@ def cntk_device(device_id):
Returns:
CNTK DeviceDescriptor
'''
if device_id==-1:
if device_id == -1:
return cntk_py.DeviceDescriptor.cpu_device()
else:
return cntk_py.DeviceDescriptor.gpu_device(device_id)
def cntk_to_numpy_shape(shape):
'''
Removes the dynamic axis and returns a tuple representing the NumPy shape.
@ -63,6 +66,7 @@ def cntk_to_numpy_shape(shape):
# cntk uses column major, thus we reverse the axes
return tuple(reversed(shape))
def is_string(value):
if sys.version_info.major < 3:
return isinstance(value, basestring)
@ -89,7 +93,7 @@ def dense_to_str(data):
def sparse_to_str(data):
return ' '.join('%s:%s'%(k,v) for k,v in sorted(data.items()))
return ' '.join('%s:%s' % (k, v) for k, v in sorted(data.items()))
def tensors_to_text_format(sample_idx, alias_tensor_map):
@ -136,6 +140,7 @@ def tensors_to_text_format(sample_idx, alias_tensor_map):
return '\n'.join(lines)
def is_tensor(data):
'''
Checks whether the data is a tensor, i.e. whether it is a NumPy array or a
@ -175,13 +180,15 @@ def is_tensor(data):
return True
def is_tensor_list(data):
'''
Checks whether the data is a CNTK sequence, which is expressed in Python as
a list of varying sized NumPy objects.
'''
is_list = isinstance(data, list)
return is_list and len(data) > 0 and isinstance(data[0], np.ndarray)
return is_list and len(data) > 0 and isinstance(data[0], np.ndarray)
def get_temp_filename(directory=None):
'''
@ -205,6 +212,7 @@ def get_temp_filename(directory=None):
return tf.name
def sanitize_shape(shape):
"""
If `shape` is scalar, create a tuple out of it and reverse it, as CNTK uses column-major order.
@ -213,6 +221,7 @@ def sanitize_shape(shape):
shape = (shape,)
return tuple(reversed(shape))
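For example, following the docstring above (an illustration, not part of the module):

sanitize_shape(5)        # -> (5,)
sanitize_shape((2, 3))   # -> (3, 2), reversed for CNTK's column-major layout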
def sanitize_input(arg, fallback_dtype=np.float32):
"""
Convert to Variable or Constant so that it can be passed as Variable to the CNTK
@ -238,22 +247,24 @@ def sanitize_input(arg, fallback_dtype=np.float32):
# or a Function?
# FIXME soon to be replaced by Function
#if isinstance(arg, (Function, cntk_py.Function)):
# if isinstance(arg, (Function, cntk_py.Function)):
if isinstance(arg, cntk_py.Function):
try:
return arg.output()
except RuntimeError:
raise ValueError('the argument has more than one output, please provide the one you want')
raise ValueError(
'the argument has more than one output, please provide the one you want')
# maybe a Python list that we can interpret as a NumPy array?
if isinstance(arg, list) and not arg:
raise ValueError('input is empty')
if not isinstance(arg, np.ndarray):
if not isinstance(arg, np.ndarray):
arg = np.asarray(arg, dtype=fallback_dtype)
return constant(value=arg)
def get_data_type(*args):
"""
Calculates the highest precision numpy datatype of the provided parameters. If
@ -266,7 +277,7 @@ def get_data_type(*args):
"""
dtypes = set()
if len(args)==1 and isinstance(args, cntk_py.Function):
if len(args) == 1 and isinstance(args, cntk_py.Function):
args = [args]
for arg in args:
@ -277,29 +288,31 @@ def get_data_type(*args):
dtypes.add(np.float32)
elif isinstance(arg, np.ndarray):
if arg.dtype not in (np.float32, np.float64):
raise ValueError('NumPy type "%s" is not supported'%arg.dtype)
raise ValueError(
'NumPy type "%s" is not supported' % arg.dtype)
dtypes.add(arg.dtype)
elif isinstance(arg, cntk_py.Function):
var_outputs = arg.outputs()
if len(var_outputs)>1:
raise ValueError('expected single output, but got %i'%len(var_outputs))
if len(var_outputs) > 1:
raise ValueError(
'expected single output, but got %i' % len(var_outputs))
var_output = var_outputs[0]
if cntk_py.DataType_Double == var_output.get_data_type():
dtypes.add(np.float64)
else:
# We don't know anything so we convert everything to float32. If it
# works, we know the type.
# works, we know the type.
# TODO figure out a better/faster way.
np.asarray(arg, dtype=np.float32)
dtypes.add(np.float32)
if np.float64 in dtypes:
return np.float64
else:
return np.float32
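For instance, mixing precisions should promote to the wider type (an illustrative call with plain NumPy inputs):

import numpy as np
get_data_type(np.zeros(2, dtype=np.float32),
              np.zeros(2, dtype=np.float64))   # -> np.float64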
def pad_to_dense(batch):
"""Appends the minimal required amount of zeroes at the end of each sample
in the batch so that it becomes rectangular. `batch` is assumed to be
@ -324,14 +337,16 @@ def pad_to_dense(batch):
# This is not the most efficient way of dealing with variable length
# sequences, but so far the only one supported. Once, ragged arrays are
# natively supported in CNTK, this will change.
Z = np.zeros((len(batch), max_seq_len)+(data_point.shape), dtype=data_point.dtype)
Z = np.zeros((len(batch), max_seq_len) +
(data_point.shape), dtype=data_point.dtype)
for idx, seq in enumerate(batch):
if seq[0].shape != data_point.shape:
raise ValueError('shape mismatch: expected %s but got '
' %s'%(str(data_point.shape), str(seq[0].shape)))
Z[idx, :len(seq)] += seq
' %s' % (str(data_point.shape), str(seq[0].shape)))
Z[idx, :len(seq)] += seq
return Z
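A minimal example of the padding behaviour (a sketch; the shapes follow from the code above):

import numpy as np
batch = [np.ones((1, 2)), np.ones((3, 2))]   # sequence lengths 1 and 3
Z = pad_to_dense(batch)                      # Z.shape == (2, 3, 2)
# Z[0, 1:] stays zero: the short sequence is padded at the end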
def sanitize_batch(batch, data_type=None, device=None):
"""
Convert to Value with `data_type`. If the samples in `batch` have different
@ -352,22 +367,22 @@ def sanitize_batch(batch, data_type=None, device=None):
num_seq = len(batch)
except TypeError:
raise ValueError('expected an object of type Value or a NumPy ' +
'array and not "%s"'%type(batch))
'array and not "%s"' % type(batch))
seq_lens = [len(seq) for seq in batch]
use_mask = len(set(seq_lens))!=1
use_mask = len(set(seq_lens)) != 1
if use_mask:
# If not all sequences are of the same length, we have to pad them to
# the same length and create a mask over the original data.
from cntk.cntk_py import NDMask
mask = NDMask((max(seq_lens), num_seq), device)
for idx, seq_len in enumerate(seq_lens):
mask.mask_section((seq_len, idx), (cntk_py.InferredDimension, 1))
mask.mask_section((seq_len, idx), (cntk_py.InferredDimension, 1))
# Then we pad the batch to rectangular shape
if isinstance(batch, list):
if len(batch)==0:
if len(batch) == 0:
raise ValueError('batch is empty')
batch = pad_to_dense(batch)
@ -389,7 +404,7 @@ def sanitize_batch(batch, data_type=None, device=None):
if len(cntk_shape) == 0:
raise ValueError('values should be an array of input samples')
'''
ndav = create_NDArrayView_from_NumPy(batch, device)
if use_mask:
@ -399,6 +414,7 @@ def sanitize_batch(batch, data_type=None, device=None):
return value
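Putting the pieces together, a ragged batch would be padded and masked roughly like this (a sketch; it assumes the default device is acceptable for the mask):

import numpy as np
batch = [np.asarray([[1.], [2.]], dtype=np.float32),   # length 2
         np.asarray([[3.]], dtype=np.float32)]         # length 1
val = sanitize_batch(batch)   # padded to (2, 2, 1); a mask marks the padding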
def sanitize_var_map(input_map, precision_numpy=None, device=None, add_batch_axis=False):
'''
Sanitizes a dictionary of `Variable`s to input data such that it can be
@ -423,8 +439,9 @@ def sanitize_var_map(input_map, precision_numpy=None, device=None, add_batch_axi
if isinstance(batch, np.ndarray):
if batch.dtype == np.int:
batch = batch.astype(np.float32)
if batch.dtype not in (np.float32, np.float64):
raise ValueError('only float32 and float64 are supported')
if batch.dtype not in (np.float32, np.float64):
raise ValueError(
'only float32 and float64 are supported')
batch = sanitize_batch(batch, precision_numpy, device)
else:
if is_tensor(batch):
@ -437,6 +454,7 @@ def sanitize_var_map(input_map, precision_numpy=None, device=None, add_batch_axi
return var_map
def remove_masked_elements(batch, mask):
'''
From a zero-padded `batch`, remove those entries that have a 0 in the
@ -449,7 +467,8 @@ def remove_masked_elements(batch, mask):
Returns:
a list of ndarrays
'''
return [seq[mask[idx]==1] for idx,seq in enumerate(batch)]
return [seq[mask[idx] == 1] for idx, seq in enumerate(batch)]
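For example (illustrative only):

import numpy as np
batch = np.asarray([[[1., 2.], [0., 0.]],
                    [[3., 4.], [5., 6.]]])   # zero-padded batch
mask = np.asarray([[1, 0], [1, 1]])
remove_masked_elements(batch, mask)
# -> [array([[1., 2.]]), array([[3., 4.], [5., 6.]])]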
def ones_like(batch, precision_numpy):
'''
@ -461,57 +480,67 @@ def ones_like(batch, precision_numpy):
'''
return [np.ones_like(sample, dtype=precision_numpy) for sample in batch]
def create_NDArrayView(shape, data_type=cntk_py.DataType_Float, dev=None):
shape = sanitize_shape(shape)
if not dev:
dev = cntk_py.DeviceDescriptor.use_default_device()
# FIXME only dense supported so far
view = cntk_py.NDArrayView(data_type, cntk_py.StorageFormat_Dense, shape, dev)
view = cntk_py.NDArrayView(
data_type, cntk_py.StorageFormat_Dense, shape, dev)
return view
def create_NDArrayView_from_NumPy(nd, dev=None):
ndav_cpu = cntk_py.NDArrayView(nd, cntk_py.DeviceDescriptor.cpu_device(), False)
ndav_cpu = cntk_py.NDArrayView(
nd, cntk_py.DeviceDescriptor.cpu_device(), False)
if not dev:
dev = cntk_py.DeviceDescriptor.use_default_device()
dev = cntk_py.DeviceDescriptor.use_default_device()
ndav = ensure_dev(ndav_cpu, dev)
return ndav
def create_Value_for_Variable(var, shape=None, dev=None, mask=None):
if not dev:
dev = cntk_py.DeviceDescriptor.cpu_device()
if shape is None:
shape = var.shape().dimensions()
view = cntk_py.NDArrayView(var.get_data_type(), cntk_py.StorageFormat_Dense, shape, dev)
view = cntk_py.NDArrayView(
var.get_data_type(), cntk_py.StorageFormat_Dense, shape, dev)
if mask:
value = cntk_py.Value(view, mask)
else:
value = cntk_py.Value(view)
return value
def create_Value(shape, data_type, dev):
value = cntk_py.Value(create_NDArrayView(shape, data_type, dev))
return value
def create_Value_from_NumPy(nd, dev):
view = create_NDArrayView_from_NumPy(nd, dev)
value = cntk_py.Value(view)
return value
def sanitize_dtype_numpy(dtype):
if dtype in ('float', 'float32', np.float32):
return np.float32
elif dtype in ('double', 'float64', np.float64):
return np.float64
else:
raise ValueError('data type "%s" is not supported'%dtype)
raise ValueError('data type "%s" is not supported' % dtype)
def sanitize_dtype_cntk(dtype):
def sanitize_dtype_cntk(dtype):
if dtype in (cntk_py.DataType_Float, cntk_py.DataType_Double,
cntk_py.DataType_Unknown):
cntk_py.DataType_Unknown):
return dtype
if dtype in ('float', 'float32', np.float32):
return cntk_py.DataType_Float
@ -520,7 +549,8 @@ def sanitize_dtype_cntk(dtype):
elif not dtype:
return cntk_py.DataType_Unknown
else:
raise ValueError('data type "%s" is not supported'%dtype)
raise ValueError('data type "%s" is not supported' % dtype)
def sanitize_axis(rank, axis):
if axis is None:
@ -532,6 +562,7 @@ def sanitize_axis(rank, axis):
else:
return axis
def sanitize_dynamic_axes(axes):
if axes is not cntk_py.Axis.default_input_variable_dynamic_axes:
if not type(axes) in (list, tuple):
@ -539,7 +570,8 @@ def sanitize_dynamic_axes(axes):
else:
axes = tuple(reversed(axes))
return axes
def get_train_loss(trainer):
'''
Fetch the train loss from the last minibatch and copy it to the CPU in case it is on the GPU.
@ -547,11 +579,12 @@ def get_train_loss(trainer):
trainer (:class:`Trainer`): the trainer used.
Returns:
the loss value
'''
'''
import copy
#we copy the value so swig does not destroy it when we leave the scope
# we copy the value so swig does not destroy it when we leave the scope
return copy.copy(trainer.previous_minibatch_loss_average())
def get_train_eval_criterion(trainer):
'''
Fetch the train evaluation criterion (e.g., classification error) from the last minibatch and copy it to the CPU in case it is on the GPU.
@ -559,29 +592,33 @@ def get_train_eval_criterion(trainer):
trainer (:class:`Trainer`): the trainer used.
Returns:
the criterion value
'''
'''
import copy
#we copy the value so swig does not destroy it when we leave the scope
# we copy the value so swig does not destroy it when we leave the scope
return copy.copy(trainer.previous_minibatch_evaluation_average())
def ensure_dev(ndav, dev):
if ndav.device() != dev:
ndav_on_target = create_NDArrayView(ndav.shape().dimensions(), data_type=ndav.get_data_type(), dev=dev)
ndav_on_target = create_NDArrayView(
ndav.shape().dimensions(), data_type=ndav.get_data_type(), dev=dev)
ndav_on_target.copy_from(ndav)
ndav = ndav_on_target
return ndav
def ensure_cpu(ndav):
return ensure_dev(ndav, cntk_py.DeviceDescriptor.cpu_device())
def eval(op, precision, device, input_map=None, backward_pass=False):
'''
Evaluates `op` on the data provided in `input_map`. This is useful
mainly for exploring the operators and for convenient unit testing.
Args:
op (:class:`Function`): operation to evaluate
precision (`str` or `None`): precision being 'float32', 'float64', or `None`, in which case it will be determined by inspecting the operator (costly)
@ -598,26 +635,28 @@ def eval(op, precision, device, input_map=None, backward_pass=False):
forward_in_var_map = sanitize_var_map(input_map, precision, device)
forward_out_var_map = {}
forward_out_var_map = {}
forward_retain = set()
for v in op.outputs():
forward_out_var_map[v] = None # will be populated in Forward()
forward_out_var_map[v] = None # will be populated in Forward()
forward_retain.add(v)
state = op.forward(forward_in_var_map, forward_out_var_map, device, forward_retain)
state = op.forward(forward_in_var_map,
forward_out_var_map, device, forward_retain)
forward_output = {}
forward_output_mask = {}
for v in op.outputs():
value = forward_out_var_map[v]
np_data = ensure_cpu(value.data()).to_numpy()
np_data = ensure_cpu(value.data()).to_numpy()
if value.mask():
np_data = remove_masked_elements(np_data, ensure_cpu(value.mask()).to_numpy())
np_data = remove_masked_elements(
np_data, ensure_cpu(value.mask()).to_numpy())
forward_output[v] = np_data
forward_output_mask[v] = value.mask()
if backward_pass:
root_gradients = {}
root_gradients = {}
for v, o in forward_output.items():
root_gradients[v] = ones_like(o, precision)
root_gradients = sanitize_var_map(root_gradients, precision, device)
@ -628,15 +667,13 @@ def eval(op, precision, device, input_map=None, backward_pass=False):
backward_output = {}
for var, value in backward_var_map.items():
np_data = ensure_cpu(value.data()).to_numpy()
np_data = ensure_cpu(value.data()).to_numpy()
if value.mask():
np_data = remove_masked_elements(np_data, ensure_cpu(value.mask()).to_numpy())
np_data = remove_masked_elements(
np_data, ensure_cpu(value.mask()).to_numpy())
backward_output[var] = np_data
return forward_output, backward_output
else:
return forward_output, None

Просмотреть файл

@ -15,7 +15,8 @@ from cntk.utils import *
AA = np.asarray
C = constant
# TODO: adapt to v2 when needed
@pytest.mark.parametrize("idx, alias_tensor_map, expected", [
(0, {'A': [object()]}, ValueError),
@ -59,6 +60,7 @@ def test_tensor_conversion_dense(idx, alias_tensor_map, expected):
def test_is_tensor(data, expected):
assert is_tensor(data) == expected
@pytest.mark.parametrize("data, expected", [
([], False),
([1], False),

Просмотреть файл

@ -1,6 +1,6 @@
# Copyright (c) Microsoft. All rights reserved.
# Licensed under the MIT license. See LICENSE.md file in the project root
# Licensed under the MIT license. See LICENSE.md file in the project root
# for full license information.
# ==============================================================================
@ -10,23 +10,25 @@ import pytest
collect_ignore = ["setup.py", "build"]
# content of conftest.py
_DEFAULT_DEVICE_ID=-1
_DEFAULT_DEVICE_ID = -1
def pytest_addoption(parser):
parser.addoption("--deviceid", action="append", default=[_DEFAULT_DEVICE_ID],
help="list of device ids to pass to test functions")
help="list of device ids to pass to test functions")
DEVICE_MAP = {
'auto': 'auto',
'cpu': -1,
'gpu': 0
}
'auto': 'auto',
'cpu': -1,
'gpu': 0
}
def pytest_generate_tests(metafunc):
if 'device_id' in metafunc.fixturenames:
def pytest_generate_tests(metafunc):
if 'device_id' in metafunc.fixturenames:
if (len(metafunc.config.option.deviceid)) > 1:
del metafunc.config.option.deviceid[0]
devices = set()
for elem in metafunc.config.option.deviceid:
try:
@ -35,13 +37,15 @@ def pytest_generate_tests(metafunc):
else:
devices.add(int(elem))
except ValueError:
raise RuntimeError("invalid deviceid value '{0}', please " +
"use integer values or 'auto'".format(elem))
raise RuntimeError("invalid deviceid value '{0}', please " +
"use integer values or 'auto'".format(elem))
metafunc.parametrize("device_id", devices)
import numpy
import cntk
@pytest.fixture(autouse=True)
def add_namespace(doctest_namespace):
doctest_namespace['np'] = numpy

View file

@ -214,17 +214,17 @@ htmlhelp_basename = 'CNTK15doc'
# -- Options for LaTeX output ---------------------------------------------
latex_elements = {
# The paper size ('letterpaper' or 'a4paper').
#'papersize': 'letterpaper',
# The paper size ('letterpaper' or 'a4paper').
#'papersize': 'letterpaper',
# The font size ('10pt', '11pt' or '12pt').
#'pointsize': '10pt',
# The font size ('10pt', '11pt' or '12pt').
#'pointsize': '10pt',
# Additional stuff for the LaTeX preamble.
#'preamble': '',
# Additional stuff for the LaTeX preamble.
#'preamble': '',
# Latex figure (float) alignment
#'figure_align': 'htbp',
# Latex figure (float) alignment
#'figure_align': 'htbp',
}
# Grouping the document tree into LaTeX files. List of tuples

View file

@ -20,17 +20,22 @@ TRAIN_MAP_FILENAME = 'train_map.txt'
MEAN_FILENAME = 'CIFAR-10_mean.xml'
# Instantiates the CNTK built-in minibatch source for reading images to be used for training the residual net
# The minibatch source is configured using a hierarchical dictionary of key:value pairs
# The minibatch source is configured using a hierarchical dictionary of
# key:value pairs
def create_mb_source(features_stream_name, labels_stream_name, image_height,
image_width, num_channels, num_classes, cifar_data_path):
image_width, num_channels, num_classes, cifar_data_path):
map_file = os.path.join(cifar_data_path, TRAIN_MAP_FILENAME)
mean_file = os.path.join(cifar_data_path, MEAN_FILENAME)
if not os.path.exists(map_file) or not os.path.exists(mean_file):
cifar_py3 = "" if sys.version_info.major < 3 else "_py3"
raise RuntimeError("File '%s' or '%s' do not exist. Please run CifarDownload%s.py and CifarConverter%s.py from CIFAR-10 to fetch them"%(map_file, mean_file, cifar_py3, cifar_py3))
cifar_py3 = "" if sys.version_info.major < 3 else "_py3"
raise RuntimeError("File '%s' or '%s' do not exist. Please run CifarDownload%s.py and CifarConverter%s.py from CIFAR-10 to fetch them" %
(map_file, mean_file, cifar_py3, cifar_py3))
image = ImageDeserializer(map_file)
<<<<<<< 391432ca77060ad88807339d773f288de6557c4a
image.map_features(features_stream_name,
[ImageDeserializer.crop(crop_type='Random', ratio=0.8,
jitter_type='uniRatio'),
@ -41,10 +46,27 @@ def create_mb_source(features_stream_name, labels_stream_name, image_height,
rc = ReaderConfig(image, epoch_size=sys.maxsize)
return rc.minibatch_source()
def get_projection_map(out_dim, in_dim):
if in_dim > out_dim:
raise ValueError("Can only project from lower to higher dimensionality")
raise ValueError(
"Can only project from lower to higher dimensionality")
projection_map_values = np.zeros(in_dim * out_dim, dtype=np.float32)
for i in range(0, in_dim):
@ -69,22 +91,32 @@ def resnet_classifer(input, num_classes):
conv1_w_scale = 0.26
c_map1 = 16
conv1 = conv_bn_relu_layer(input, c_map1, kernel_width, kernel_height, 1, 1, conv1_w_scale, conv_b_value, sc_value, bn_time_const)
rn1_1 = resnet_node2(conv1, c_map1, kernel_width, kernel_height, conv1_w_scale, conv_b_value, sc_value, bn_time_const)
rn1_2 = resnet_node2(rn1_1, c_map1, kernel_width, kernel_height, conv1_w_scale, conv_b_value, sc_value, bn_time_const)
rn1_3 = resnet_node2(rn1_2, c_map1, kernel_width, kernel_height, conv1_w_scale, conv_b_value, sc_value, bn_time_const)
conv1 = conv_bn_relu_layer(input, c_map1, kernel_width, kernel_height,
1, 1, conv1_w_scale, conv_b_value, sc_value, bn_time_const)
rn1_1 = resnet_node2(conv1, c_map1, kernel_width, kernel_height,
conv1_w_scale, conv_b_value, sc_value, bn_time_const)
rn1_2 = resnet_node2(rn1_1, c_map1, kernel_width, kernel_height,
conv1_w_scale, conv_b_value, sc_value, bn_time_const)
rn1_3 = resnet_node2(rn1_2, c_map1, kernel_width, kernel_height,
conv1_w_scale, conv_b_value, sc_value, bn_time_const)
c_map2 = 32
rn2_1_wProj=get_projection_map(c_map2, c_map1)
rn2_1 = resnet_node2_inc(rn1_3, c_map2, kernel_width, kernel_height, conv1_w_scale, conv_b_value, sc_value, bn_time_const, rn2_1_wProj)
rn2_2 = resnet_node2(rn2_1, c_map2, kernel_width, kernel_height, conv1_w_scale, conv_b_value, sc_value, bn_time_const)
rn2_3 = resnet_node2(rn2_2, c_map2, kernel_width, kernel_height, conv1_w_scale, conv_b_value, sc_value, bn_time_const)
rn2_1_wProj = get_projection_map(c_map2, c_map1)
rn2_1 = resnet_node2_inc(rn1_3, c_map2, kernel_width, kernel_height,
conv1_w_scale, conv_b_value, sc_value, bn_time_const, rn2_1_wProj)
rn2_2 = resnet_node2(rn2_1, c_map2, kernel_width, kernel_height,
conv1_w_scale, conv_b_value, sc_value, bn_time_const)
rn2_3 = resnet_node2(rn2_2, c_map2, kernel_width, kernel_height,
conv1_w_scale, conv_b_value, sc_value, bn_time_const)
c_map3 = 64
rn3_1_wProj=get_projection_map(c_map3, c_map2)
rn3_1 = resnet_node2_inc(rn2_3, c_map3, kernel_width, kernel_height, conv1_w_scale, conv_b_value, sc_value, bn_time_const, rn3_1_wProj)
rn3_2 = resnet_node2(rn3_1, c_map3, kernel_width, kernel_height, conv1_w_scale, conv_b_value, sc_value, bn_time_const)
rn3_3 = resnet_node2(rn3_2, c_map3, kernel_width, kernel_height, conv1_w_scale, conv_b_value, sc_value, bn_time_const)
rn3_1_wProj = get_projection_map(c_map3, c_map2)
rn3_1 = resnet_node2_inc(rn2_3, c_map3, kernel_width, kernel_height,
conv1_w_scale, conv_b_value, sc_value, bn_time_const, rn3_1_wProj)
rn3_2 = resnet_node2(rn3_1, c_map3, kernel_width, kernel_height,
conv1_w_scale, conv_b_value, sc_value, bn_time_const)
rn3_3 = resnet_node2(rn3_2, c_map3, kernel_width, kernel_height,
conv1_w_scale, conv_b_value, sc_value, bn_time_const)
# Global average pooling
poolw = 8
@ -92,27 +124,46 @@ def resnet_classifer(input, num_classes):
poolh_stride = 1
poolv_stride = 1
<<<<<<< 391432ca77060ad88807339d773f288de6557c4a
pool = pooling(rn3_3, AVG_POOLING, (1, poolh, poolw), (1, poolv_stride, poolh_stride))
out_times_params = parameter(shape=(c_map3, 1, 1, num_classes), initializer=glorot_uniform_initializer())
out_bias_params = parameter(shape=(num_classes), value=0)
=======
pool = pooling(rn3_3, AVG_POOLING, (1, poolh, poolw),
(1, poolv_stride, poolh_stride))
out_times_params = parameter(shape=(c_map3, 1, 1, num_classes))
out_bias_params = parameter(shape=(num_classes))
>>>>>>> Address comments in CR
t = times(pool, out_times_params)
return t + out_bias_params
# Trains a residual network model on the Cifar image dataset
<<<<<<< 391432ca77060ad88807339d773f288de6557c4a
def cifar_resnet(base_path):
=======
pool = pooling(rn3_3, AVG_POOLING, (1, poolh, poolw), (1, poolv_stride, poolh_stride))
out_times_params = parameter(shape=(c_map3, 1, 1, num_classes), initializer=glorot_uniform_initializer())
out_bias_params = parameter(shape=(num_classes), value=0)
image_height = 32
image_width = 32
num_channels = 3
num_classes = 10
feats_stream_name = 'features'
def cifar_resnet(base_path):
labels_stream_name = 'labels'
<<<<<<< 391432ca77060ad88807339d773f288de6557c4a
minibatch_source = create_mb_source(feats_stream_name, labels_stream_name,
image_height, image_width, num_channels, num_classes, base_path)
=======
minibatch_source = create_mb_source(feats_stream_name, labels_stream_name,
image_height, image_width, num_channels, num_classes)
>>>>>>> Address comments in CR
features_si = minibatch_source.stream_info(feats_stream_name)
labels_si = minibatch_source.stream_info(labels_stream_name)
# Input variables denoting the features and label data
image_input = input_variable((num_channels, image_height, image_width), features_si.m_element_type)
image_input = input_variable(
(num_channels, image_height, image_width), features_si.m_element_type)
label_var = input_variable((num_classes), features_si.m_element_type)
# Instantiate the resnet classification model
@ -123,27 +174,29 @@ def cifar_resnet(base_path):
# Instantiate the trainer object to drive the model training
trainer = Trainer(classifier_output, ce, pe,
[sgd_learner(classifier_output.parameters(), lr=0.0078125)])
[sgd_learner(classifier_output.parameters(), lr=0.0078125)])
# Get minibatches of images to train with and perform model training
mb_size = 32
training_progress_output_freq = 20
num_mbs = 1000
for i in range(0, num_mbs):
mb=minibatch_source.get_next_minibatch(mb_size)
mb = minibatch_source.get_next_minibatch(mb_size)
# Specify the mapping of input variables in the model to actual minibatch data to be trained with
arguments = {image_input : mb[features_si].m_data, label_var : mb[labels_si].m_data}
# Specify the mapping of input variables in the model to actual
# minibatch data to be trained with
arguments = {image_input: mb[features_si].m_data,
             label_var: mb[labels_si].m_data}
trainer.train_minibatch(arguments)
print_training_progress(trainer, i, training_progress_output_freq)
if __name__=='__main__':
if __name__ == '__main__':
# Specify the target device to be used for computing
target_device = DeviceDescriptor.gpu_device(0)
DeviceDescriptor.set_default_device(target_device)
base_path = os.path.normpath(os.path.join(\
*"../../../../Examples/Image/Miscellaneous/CIFAR-10/cifar-10-batches-py".split("/")))
base_path = os.path.normpath(os.path.join(
*"../../../../Examples/Image/Miscellaneous/CIFAR-10/cifar-10-batches-py".split("/")))
cifar_resnet(base_path)

View file

@ -14,14 +14,15 @@ abs_path = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(abs_path, "..", ".."))
from examples.common.nn import fully_connected_classifier_net, print_training_progress
TOLERANCE_ABSOLUTE = 1E-1
def check_path(path):
if not os.path.exists(path):
readme_file = os.path.normpath(os.path.join(os.path.dirname(path), "..", "README.md"))
raise RuntimeError("File '%s' does not exist. Please follow the instructions at %s to download and prepare it."%(path, readme_file))
readme_file = os.path.normpath(os.path.join(
os.path.dirname(path), "..", "README.md"))
raise RuntimeError(
"File '%s' does not exist. Please follow the instructions at %s to download and prepare it." % (path, readme_file))
# Creates and trains a feedforward classification model for MNIST images
def simple_mnist(debug_output=False):
input_dim = 784
num_output_classes = 10
@ -34,7 +35,8 @@ def simple_mnist(debug_output=False):
# Instantiate the feedforward classification model
scaled_input = element_times(constant((), 0.00390625), input)
netout = fully_connected_classifier_net(scaled_input, num_output_classes, hidden_layers_dim, num_hidden_layers, sigmoid)
netout = fully_connected_classifier_net(
scaled_input, num_output_classes, hidden_layers_dim, num_hidden_layers, sigmoid)
ce = cross_entropy_with_softmax(netout, label)
pe = classification_error(netout, label)
@ -47,12 +49,11 @@ def simple_mnist(debug_output=False):
labels_stream_name = 'labels'
mb_source = text_format_minibatch_source(path, [
StreamConfiguration( feature_stream_name, input_dim ),
StreamConfiguration( labels_stream_name, num_output_classes) ])
StreamConfiguration(feature_stream_name, input_dim),
StreamConfiguration(labels_stream_name, num_output_classes)])
features_si = mb_source.stream_info(feature_stream_name)
labels_si = mb_source.stream_info(labels_stream_name)
<<<<<<< 7ce41912e13b71716986658b3f62e7c8cff3b728
# Instantiate the trainer object to drive the model training
trainer = Trainer(netout, ce, pe, [sgd_learner(netout.parameters(),
lr=0.003125)])
@ -66,21 +67,24 @@ def simple_mnist(debug_output=False):
for i in range(0, int(num_minibatches_to_train)):
mb = mb_source.get_next_minibatch(minibatch_size)
# Specify the mapping of input variables in the model to actual minibatch data to be trained with
arguments = {input : mb[features_si].m_data, label : mb[labels_si].m_data}
# Specify the mapping of input variables in the model to actual
# minibatch data to be trained with
arguments = {input: mb[features_si].m_data,
label: mb[labels_si].m_data}
trainer.train_minibatch(arguments)
if debug_output:
print_training_progress(trainer, i, training_progress_output_freq)
# Load test data
rel_path = os.path.join(*"../../../../Examples/Image/MNIST/Data/Test-28x28_cntk_text.txt".split("/"))
rel_path = os.path.join(
*"../../../../Examples/Image/MNIST/Data/Test-28x28_cntk_text.txt".split("/"))
path = os.path.normpath(os.path.join(abs_path, rel_path))
check_path(path)
test_mb_source = text_format_minibatch_source(path, [
StreamConfiguration( feature_stream_name, input_dim ),
StreamConfiguration( labels_stream_name, num_output_classes ) ])
StreamConfiguration(feature_stream_name, input_dim),
StreamConfiguration(labels_stream_name, num_output_classes)])
features_si = test_mb_source.stream_info(feature_stream_name)
labels_si = test_mb_source.stream_info(labels_stream_name)
@@ -92,17 +96,23 @@ def simple_mnist(debug_output=False):
for i in range(0, int(num_minibatches_to_test)):
mb = test_mb_source.get_next_minibatch(test_minibatch_size)
# Specify the mapping of input variables in the model to actual minibatch data to be tested with
arguments = {input : mb[features_si].m_data, label : mb[labels_si].m_data}
# Specify the mapping of input variables in the model to actual
# minibatch data to be tested with
arguments = {input: mb[features_si].m_data,
label: mb[labels_si].m_data}
eval_error = trainer.test_minibatch(arguments)
test_result = test_result + eval_error
return test_result / num_minibatches_to_test # Average of evaluation errors of all test minibatches
# Average of evaluation errors of all test minibatches
return test_result / num_minibatches_to_test
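The value returned here is the plain mean of the per-minibatch error rates, which equals the overall test error only if every minibatch holds the same number of samples. A small standalone sketch of the accumulation under that assumption (the error numbers are illustrative):

# Per-minibatch evaluation errors (illustrative values).
minibatch_errors = [0.08, 0.10, 0.06, 0.12]

test_result = 0.0
for eval_error in minibatch_errors:
    test_result = test_result + eval_error

# Average of evaluation errors of all test minibatches.
average_error = test_result / len(minibatch_errors)
print(average_error)  # 0.09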
if __name__=='__main__':
if __name__ == '__main__':
# Specify the target device to be used for computing
#target_device = DeviceDescriptor.gpu_device(0)
#DeviceDescriptor.set_default_device(target_device)
target_device = DeviceDescriptor.gpu_device(0)
    # If this crashes, you probably don't have a GPU; try the CPU instead:
# target_device = DeviceDescriptor.cpu_device()
DeviceDescriptor.set_default_device(target_device)
accuracy = simple_mnist()
print("test: %f"%accuracy)
error = simple_mnist()
print("test: %f" % error)

View file

@@ -14,25 +14,26 @@ abs_path = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(abs_path, "..", ".."))
from examples.common.nn import fully_connected_classifier_net, print_training_progress
TOLERANCE_ABSOLUTE=1E-03
TOLERANCE_ABSOLUTE = 1E-03
# make sure we always get the same "randomness"
np.random.seed(0)
def generate_random_data(sample_size, feature_dim, num_classes):
# Create synthetic data using NumPy.
# Create synthetic data using NumPy.
Y = np.random.randint(size=(sample_size, 1), low=0, high=num_classes)
# Make sure that the data is separable
X = (np.random.randn(sample_size, feature_dim)+3) * (Y+1)
X = X.astype(np.float32)
# converting class 0 into the vector "1 0 0",
X = (np.random.randn(sample_size, feature_dim) + 3) * (Y + 1)
X = X.astype(np.float32)
# converting class 0 into the vector "1 0 0",
# class 1 into vector "0 1 0", ...
class_ind = [Y==class_number for class_number in range(num_classes)]
class_ind = [Y == class_number for class_number in range(num_classes)]
Y = np.asarray(np.hstack(class_ind), dtype=np.float32)
return X, Y
return X, Y
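The list-comprehension trick above builds one boolean column per class and hstacks them into a one-hot matrix. A standalone NumPy illustration of the same conversion:

import numpy as np

Y = np.array([[0], [2], [1]])  # three samples with class ids 0, 2, 1
num_classes = 3

# One (N, 1) boolean column per class; hstack glues them into (N, 3).
class_ind = [Y == class_number for class_number in range(num_classes)]
one_hot = np.asarray(np.hstack(class_ind), dtype=np.float32)

print(one_hot)  # rows: "1 0 0", "0 0 1", "0 1 0"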
# Creates and trains a feedforward classification model
def ffnet(debug_output=True):
input_dim = 2
num_output_classes = 2
@@ -44,43 +45,53 @@ def ffnet(debug_output=True):
label = input_variable((num_output_classes), np.float32)
# Instantiate the feedforward classification model
netout = fully_connected_classifier_net(input, num_output_classes, hidden_layers_dim, num_hidden_layers, sigmoid)
netout = fully_connected_classifier_net(
input, num_output_classes, hidden_layers_dim, num_hidden_layers, sigmoid)
ce = cross_entropy_with_softmax(netout, label)
pe = classification_error(netout, label)
# Instantiate the trainer object to drive the model training
trainer = Trainer(netout, ce, pe, [sgd_learner(netout.parameters(), lr=0.02)])
trainer = Trainer(
netout, ce, pe, [sgd_learner(netout.parameters(), lr=0.02)])
# Get minibatches of training data and perform model training
minibatch_size = 25
num_samples_per_sweep = 10000
num_sweeps_to_train_with = 2
num_minibatches_to_train = (num_samples_per_sweep * num_sweeps_to_train_with) / minibatch_size
num_minibatches_to_train = (
num_samples_per_sweep * num_sweeps_to_train_with) / minibatch_size
training_progress_output_freq = 20
for i in range(0, int(num_minibatches_to_train)):
features, labels = generate_random_data(minibatch_size, input_dim, num_output_classes)
# Specify the mapping of input variables in the model to actual minibatch data to be trained with
trainer.train_minibatch({input : features, label : labels})
features, labels = generate_random_data(
minibatch_size, input_dim, num_output_classes)
# Specify the mapping of input variables in the model to actual
# minibatch data to be trained with
trainer.train_minibatch({input: features, label: labels})
if debug_output:
print_training_progress(trainer, i, training_progress_output_freq)
test_features, test_labels = generate_random_data(minibatch_size, input_dim, num_output_classes)
avg_error = trainer.test_minibatch({input : test_features, label : test_labels})
test_features, test_labels = generate_random_data(
minibatch_size, input_dim, num_output_classes)
avg_error = trainer.test_minibatch(
{input: test_features, label: test_labels})
return avg_error
def test_accuracy(device_id):
def test_error(device_id):
from cntk.utils import cntk_device
DeviceDescriptor.set_default_device(cntk_device(device_id))
avg_error = ffnet(debug_output=False)
expected_avg_error = 0.12
assert np.allclose([avg_error], [expected_avg_error], atol=TOLERANCE_ABSOLUTE)
assert np.allclose(avg_error, expected_avg_error, atol=TOLERANCE_ABSOLUTE)
if __name__=='__main__':
if __name__ == '__main__':
# Specify the target device to be used for computing
#target_device = DeviceDescriptor.gpu_device(0)
#DeviceDescriptor.set_default_device(target_device)
target_device = DeviceDescriptor.gpu_device(0)
    # If this crashes, you probably don't have a GPU; try the CPU instead:
# target_device = DeviceDescriptor.cpu_device()
DeviceDescriptor.set_default_device(target_device)
accuracy = ffnet()
print("test: %f"%accuracy)
error = ffnet()
print("test: %f" % error)

View file

@@ -17,6 +17,8 @@ sys.path.append(os.path.join(abs_path, "..", ".."))
from examples.common.nn import LSTMP_component_with_self_stabilization, stabilize, linear_layer, print_training_progress
# Creates and trains a sequence to sequence translation model
def train_sequence_to_sequence_translator():
input_vocab_dim = 69
@@ -30,11 +32,13 @@ def train_sequence_to_sequence_translator():
input_seq_axis = Axis('inputAxis')
label_seq_axis = Axis('labelAxis')
input_dynamic_axes = [ batch_axis, input_seq_axis ]
raw_input = input_variable(shape=(input_vocab_dim), dynamic_axes = input_dynamic_axes)
input_dynamic_axes = [batch_axis, input_seq_axis]
raw_input = input_variable(
shape=(input_vocab_dim), dynamic_axes=input_dynamic_axes)
label_dynamic_axes = [ batch_axis, label_seq_axis ]
raw_labels = input_variable(shape=(label_vocab_dim), dynamic_axes = label_dynamic_axes)
label_dynamic_axes = [batch_axis, label_seq_axis]
raw_labels = input_variable(
shape=(label_vocab_dim), dynamic_axes=label_dynamic_axes)
# Instantiate the sequence to sequence translation model
input_sequence = raw_input
@@ -44,22 +48,27 @@ def train_sequence_to_sequence_translator():
label_sentence_start = sequence.first(raw_labels)
is_first_label = sequence.is_first(label_sequence)
label_sentence_start_scattered = sequence.scatter(label_sentence_start, is_first_label)
label_sentence_start_scattered = sequence.scatter(
label_sentence_start, is_first_label)
# Encoder
encoder_outputH = stabilize(input_sequence)
for i in range(0, num_layers):
(encoder_outputH, encoder_outputC) = LSTMP_component_with_self_stabilization(encoder_outputH.output(), hidden_dim, hidden_dim, future_value, future_value)
(encoder_outputH, encoder_outputC) = LSTMP_component_with_self_stabilization(
encoder_outputH.output(), hidden_dim, hidden_dim, future_value, future_value)
thought_vectorH = sequence.first(encoder_outputH)
thought_vectorC = sequence.first(encoder_outputC)
thought_vector_broadcastH = sequence.broadcast_as(thought_vectorH, label_sequence)
thought_vector_broadcastC = sequence.broadcast_as(thought_vectorC, label_sequence)
thought_vector_broadcastH = sequence.broadcast_as(
thought_vectorH, label_sequence)
thought_vector_broadcastC = sequence.broadcast_as(
thought_vectorC, label_sequence)
# Decoder
decoder_history_from_ground_truth = label_sequence
decoder_input = element_select(is_first_label, label_sentence_start_scattered, past_value(decoder_history_from_ground_truth))
    decoder_input = element_select(is_first_label, label_sentence_start_scattered,
                                   past_value(decoder_history_from_ground_truth))
decoder_outputH = stabilize(decoder_input)
for i in range(0, num_layers):
@@ -68,10 +77,13 @@ def train_sequence_to_sequence_translator():
recurrence_hookC = past_value
else:
isFirst = sequence.is_first(label_sequence)
recurrence_hookH = lambda operand: element_select(isFirst, thought_vector_broadcastH, past_value(operand))
recurrence_hookC = lambda operand: element_select(isFirst, thought_vector_broadcastC, past_value(operand))
recurrence_hookH = lambda operand: element_select(
isFirst, thought_vector_broadcastH, past_value(operand))
recurrence_hookC = lambda operand: element_select(
isFirst, thought_vector_broadcastC, past_value(operand))
(decoder_outputH, encoder_outputC) = LSTMP_component_with_self_stabilization(decoder_outputH.output(), hidden_dim, hidden_dim, recurrence_hookH, recurrence_hookC)
(decoder_outputH, encoder_outputC) = LSTMP_component_with_self_stabilization(
decoder_outputH.output(), hidden_dim, hidden_dim, recurrence_hookH, recurrence_hookC)
decoder_output = decoder_outputH
decoder_dim = hidden_dim
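The recurrence hooks above splice the encoder's thought vector into the first decoder step only: element_select yields the broadcast thought vector wherever the is-first flag is set, and the one-step-delayed decoder state everywhere else. A NumPy sketch of that selection semantics, with np.where standing in for element_select (all arrays are illustrative):

import numpy as np

is_first = np.array([[1], [0], [0], [0]])     # flag per time step
thought_vector = np.array([[5.0, 5.0]] * 4)   # broadcast encoder state
past_state = np.array([[0.0, 0.0],            # delayed decoder output
                       [1.0, 1.0],
                       [2.0, 2.0],
                       [3.0, 3.0]])

hook_output = np.where(is_first == 1, thought_vector, past_state)
print(hook_output)  # first step reads the encoder, the rest recur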
@@ -86,38 +98,42 @@ def train_sequence_to_sequence_translator():
feature_stream_name = 'features'
labels_stream_name = 'labels'
mb_source = text_format_minibatch_source(path, [
StreamConfiguration( feature_stream_name, input_vocab_dim, True, 'S0' ),
StreamConfiguration( labels_stream_name, label_vocab_dim, True, 'S1') ], 10000)
mb_source = text_format_minibatch_source(path, [
StreamConfiguration(feature_stream_name, input_vocab_dim, True, 'S0'),
StreamConfiguration(labels_stream_name, label_vocab_dim, True, 'S1')], 10000)
features_si = mb_source.stream_info(feature_stream_name)
labels_si = mb_source.stream_info(labels_stream_name)
# Instantiate the trainer object to drive the model training
lr = 0.007
momentum_time_constant = 1100
momentum_per_sample = momentums_per_sample(math.exp(-1.0 / momentum_time_constant))
momentum_per_sample = momentums_per_sample(
math.exp(-1.0 / momentum_time_constant))
clipping_threshold_per_sample = 2.3
gradient_clipping_with_truncation = True
trainer = Trainer(z, ce, errs, [momentum_sgd_learner(z.parameters(), lr, momentum_per_sample, clipping_threshold_per_sample, gradient_clipping_with_truncation)])
    trainer = Trainer(z, ce, errs, [momentum_sgd_learner(
        z.parameters(), lr, momentum_per_sample,
        clipping_threshold_per_sample, gradient_clipping_with_truncation)])
# Get minibatches of sequences to train with and perform model training
minibatch_size = 72
training_progress_output_freq = 10
while True:
mb = mb_source.get_next_minibatch(minibatch_size)
if len(mb) == 0:
if len(mb) == 0:
break
# Specify the mapping of input variables in the model to actual minibatch data to be trained with
arguments = {raw_input : mb[features_si].m_data, raw_labels : mb[labels_si].m_data}
# Specify the mapping of input variables in the model to actual
# minibatch data to be trained with
arguments = {raw_input: mb[features_si].m_data,
raw_labels: mb[labels_si].m_data}
trainer.train_minibatch(arguments)
print_training_progress(trainer, i, training_progress_output_freq)
i += 1
if __name__=='__main__':
if __name__ == '__main__':
# Specify the target device to be used for computing
target_device = DeviceDescriptor.cpu_device()
DeviceDescriptor.set_default_device(target_device)
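A note on the momentum setup inside train_sequence_to_sequence_translator: exp(-1/tc) converts a time constant tc into a per-sample momentum, so a sample's contribution decays to 1/e after roughly tc samples. A quick check of the arithmetic:

import math

momentum_time_constant = 1100
momentum_per_sample = math.exp(-1.0 / momentum_time_constant)
print(momentum_per_sample)                            # ~0.99909
print(momentum_per_sample ** momentum_time_constant)  # ~0.3679, i.e. 1/e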

View file

@@ -15,30 +15,32 @@ abs_path = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(abs_path, "..", ".."))
from examples.common.nn import LSTMP_component_with_self_stabilization, embedding, linear_layer, select_last, print_training_progress
TOLERANCE_ABSOLUTE=1E-2
# Defines the LSTM model for classifying sequences
def LSTM_sequence_classifer_net(input, num_output_classes, embedding_dim, LSTM_dim, cell_dim):
embedding_function = embedding(input, embedding_dim)
LSTM_function = LSTMP_component_with_self_stabilization(embedding_function.output(), LSTM_dim, cell_dim)[0]
LSTM_function = LSTMP_component_with_self_stabilization(
embedding_function.output(), LSTM_dim, cell_dim)[0]
thought_vector = select_last(LSTM_function)
return linear_layer(thought_vector, num_output_classes)
# Creates and trains a LSTM sequence classification model
def train_sequence_classifier():
input_dim = 2000;
cell_dim = 25;
hidden_dim = 25;
embedding_dim = 50;
num_output_classes = 5;
input_dim = 2000
cell_dim = 25
hidden_dim = 25
embedding_dim = 50
num_output_classes = 5
# Input variables denoting the features and label data
features = input_variable(shape=input_dim, is_sparse=True)
label = input_variable(num_output_classes, dynamic_axes = [Axis.default_batch_axis()])
label = input_variable(num_output_classes, dynamic_axes=[
Axis.default_batch_axis()])
# Instantiate the sequence classification model
classifier_output = LSTM_sequence_classifer_net(features, num_output_classes, embedding_dim, hidden_dim, cell_dim)
classifier_output = LSTM_sequence_classifer_net(
features, num_output_classes, embedding_dim, hidden_dim, cell_dim)
ce = cross_entropy_with_softmax(classifier_output, label)
pe = classification_error(classifier_output, label)
@@ -49,28 +51,30 @@ def train_sequence_classifier():
labels_stream_name = 'labels'
mb_source = text_format_minibatch_source(path, [
StreamConfiguration( feature_stream_name, input_dim, True, 'x' ),
StreamConfiguration( labels_stream_name, num_output_classes, False, 'y')], 0)
StreamConfiguration(feature_stream_name, input_dim, True, 'x'),
StreamConfiguration(labels_stream_name, num_output_classes, False, 'y')], 0)
features_si = mb_source.stream_info(features)
labels_si = mb_source.stream_info(label)
# Instantiate the trainer object to drive the model training
trainer = Trainer(classifier_output, ce, pe,
[sgd_learner(classifier_output.parameters(), lr=0.0005)])
[sgd_learner(classifier_output.parameters(), lr=0.0005)])
# Get minibatches of sequences to train with and perform model training
minibatch_size = 200
training_progress_output_freq = 10
i = 0;
i = 0
while True:
mb = mb_source.get_next_minibatch(minibatch_size)
if len(mb) == 0:
if len(mb) == 0:
break
# Specify the mapping of input variables in the model to actual minibatch data to be trained with
arguments = {features : mb[features_si].m_data, label : mb[labels_si].m_data}
# Specify the mapping of input variables in the model to actual
# minibatch data to be trained with
arguments = {features: mb[features_si].m_data,
label: mb[labels_si].m_data}
trainer.train_minibatch(arguments)
print_training_progress(trainer, i, training_progress_output_freq)
@@ -79,24 +83,18 @@ def train_sequence_classifier():
import copy
evaluation_average = copy.copy(trainer.previous_minibatch_evaluation_average())
evaluation_average = copy.copy(
trainer.previous_minibatch_evaluation_average())
loss_average = copy.copy(trainer.previous_minibatch_loss_average())
return evaluation_average, loss_average
def test_accuracy(device_id):
from cntk.utils import cntk_device
DeviceDescriptor.set_default_device(cntk_device(device_id))
evaluation_avg, loss_avg = train_sequence_classifier()
expected_avg = [0.1595744, 0.35799171]
assert np.allclose([evaluation_avg, loss_avg], expected_avg, atol=TOLERANCE_ABSOLUTE)
if __name__=='__main__':
if __name__ == '__main__':
# Specify the target device to be used for computing
#target_device = DeviceDescriptor.gpu_device(0)
#DeviceDescriptor.set_default_device(target_device)
target_device = DeviceDescriptor.gpu_device(0)
    # If this crashes, you probably don't have a GPU; try the CPU instead:
# target_device = DeviceDescriptor.cpu_device()
DeviceDescriptor.set_default_device(target_device)
accuracy, _ = train_sequence_classifier()
print("test: %f"%accuracy)
error, _ = train_sequence_classifier()
print("test: %f" % error)

View file

@@ -11,6 +11,7 @@ from cntk.ops import *
from cntk.utils import sanitize_dtype_cntk, get_train_eval_criterion, get_train_loss
from cntk.initializer import glorot_uniform_initializer
def linear_layer(input_var, output_dim):
try:
shape = input_var.shape()
@@ -25,11 +26,14 @@ def linear_layer(input_var, output_dim):
t = times(input_var, times_param)
return bias_param + t
def fully_connected_layer(input, output_dim, nonlinearity):
p = linear_layer(input, output_dim)
return nonlinearity(p)
# Defines a multilayer feedforward classification model
def fully_connected_classifier_net(input, num_output_classes, hidden_layer_dim, num_hidden_layers, nonlinearity):
r = fully_connected_layer(input, hidden_layer_dim, nonlinearity)
for i in range(1, num_hidden_layers):
@@ -37,6 +41,7 @@ def fully_connected_classifier_net(input, num_output_classes, hidden_layer_dim,
return linear_layer(r, num_output_classes)
def conv_bn_layer(input, out_feature_map_count, kernel_width, kernel_height, h_stride, v_stride, w_scale, b_value, sc_value, bn_time_const):
try:
shape = input.shape()
@@ -55,16 +60,22 @@ def conv_bn_layer(input, out_feature_map_count, kernel_width, kernel_height, h_s
running_invstd = constant((out_feature_map_count), 0.0)
return batch_normalization(conv_func, scale_params, bias_params, running_mean, running_invstd, True, bn_time_const, 0.0, 0.000000001)
def conv_bn_relu_layer(input, out_feature_map_count, kernel_width, kernel_height, h_stride, v_stride, w_scale, b_value, sc_value, bn_time_const):
conv_bn_function = conv_bn_layer(input, out_feature_map_count, kernel_width, kernel_height, h_stride, v_stride, w_scale, b_value, sc_value, bn_time_const)
conv_bn_function = conv_bn_layer(input, out_feature_map_count, kernel_width,
kernel_height, h_stride, v_stride, w_scale, b_value, sc_value, bn_time_const)
return relu(conv_bn_function)
def resnet_node2(input, out_feature_map_count, kernel_width, kernel_height, w_scale, b_value, sc_value, bn_time_const):
c1 = conv_bn_relu_layer(input, out_feature_map_count, kernel_width, kernel_height, 1, 1, w_scale, b_value, sc_value, bn_time_const)
c2 = conv_bn_layer(c1, out_feature_map_count, kernel_width, kernel_height, 1, 1, w_scale, b_value, sc_value, bn_time_const)
c1 = conv_bn_relu_layer(input, out_feature_map_count, kernel_width,
kernel_height, 1, 1, w_scale, b_value, sc_value, bn_time_const)
c2 = conv_bn_layer(c1, out_feature_map_count, kernel_width,
kernel_height, 1, 1, w_scale, b_value, sc_value, bn_time_const)
p = c2 + input
return relu(p)
def proj_layer(w_proj, input, h_stride, v_stride, b_value, sc_value, bn_time_const):
try:
shape = input.shape()
@@ -81,35 +92,43 @@ def proj_layer(w_proj, input, h_stride, v_stride, b_value, sc_value, bn_time_con
running_invstd = constant((out_feature_map_count), 0.0)
return batch_normalization(conv_func, scale_params, bias_params, running_mean, running_invstd, True, bn_time_const)
def resnet_node2_inc(input, out_feature_map_count, kernel_width, kernel_height, w_scale, b_value, sc_value, bn_time_const, w_proj):
c1 = conv_bn_relu_layer(input, out_feature_map_count, kernel_width, kernel_height, 2, 2, w_scale, b_value, sc_value, bn_time_const)
c2 = conv_bn_layer(c1, out_feature_map_count, kernel_width, kernel_height, 1, 1, w_scale, b_value, sc_value, bn_time_const)
c1 = conv_bn_relu_layer(input, out_feature_map_count, kernel_width,
kernel_height, 2, 2, w_scale, b_value, sc_value, bn_time_const)
c2 = conv_bn_layer(c1, out_feature_map_count, kernel_width,
kernel_height, 1, 1, w_scale, b_value, sc_value, bn_time_const)
c_proj = proj_layer(w_proj, input, 2, 2, b_value, sc_value, bn_time_const)
p = c2 + c_proj
return relu(p)
def embedding(input, embedding_dim):
input_dim = input.shape()[0];
input_dim = input.shape()[0]
embedding_parameters = parameter(shape=(input_dim, embedding_dim), initializer=glorot_uniform_initializer())
return times(input, embedding_parameters)
def select_last(operand):
return slice(operand, Axis.default_dynamic_axis(), -1, 0)
def stabilize(operand):
scalar_constant = 4.0
f = constant(sanitize_dtype_cntk(np.float32), scalar_constant);
f = constant(sanitize_dtype_cntk(np.float32), scalar_constant)
fInv = constant(sanitize_dtype_cntk(np.float32), 1.0 / scalar_constant)
beta = element_times(fInv, log(constant(sanitize_dtype_cntk(np.float32), 1.0) + exp(element_times(f, parameter(value=0.99537863)))))
    beta = element_times(fInv, log(constant(sanitize_dtype_cntk(np.float32), 1.0)
                                   + exp(element_times(f, parameter(value=0.99537863)))))
return element_times(beta, operand)
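The stabilizer scales its input by a learned scalar beta = (1/f) * ln(1 + exp(f * p)), a sharpened softplus that stays positive for any value of the parameter p. The initial p = 0.99537863 is the softplus inverse of 1 for f = 4, so beta starts out at exactly 1 and stabilization begins as an identity. A quick NumPy check:

import numpy as np

f = 4.0           # scalar_constant in stabilize()
p = 0.99537863    # initial value of the learned parameter

beta = (1.0 / f) * np.log(1.0 + np.exp(f * p))
print(beta)       # ~1.0: the parameter is initialized so beta == 1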
def LSTMP_cell_with_self_stabilization(input, prev_output, prev_cell_state):
input_dim = input.shape()[0]
output_dim = prev_output.shape()[0];
cell_dim = prev_cell_state.shape()[0];
output_dim = prev_output.shape()[0]
cell_dim = prev_cell_state.shape()[0]
Wxo = parameter(shape=(input_dim, cell_dim), initializer=glorot_uniform_initializer())
Wxi = parameter(shape=(input_dim, cell_dim), initializer=glorot_uniform_initializer())
@@ -193,21 +212,28 @@ def LSTMP_cell_with_self_stabilization(input, prev_output, prev_cell_state):
mt = element_times(ot, tanh(ct))
return (times(element_times(expsWmr, mt), Wmr), ct)
def LSTMP_component_with_self_stabilization(input, output_dim, cell_dim, recurrence_hookH = past_value, recurrence_hookC = past_value):
dh = placeholder_variable(shape=(output_dim), dynamic_axes=input.dynamic_axes())
dc = placeholder_variable(shape=(cell_dim), dynamic_axes=input.dynamic_axes())
def LSTMP_component_with_self_stabilization(input, output_dim, cell_dim, recurrence_hookH=past_value, recurrence_hookC=past_value):
dh = placeholder_variable(
shape=(output_dim), dynamic_axes=input.dynamic_axes())
dc = placeholder_variable(
shape=(cell_dim), dynamic_axes=input.dynamic_axes())
LSTMCell = LSTMP_cell_with_self_stabilization(input, dh, dc)
actualDh = recurrence_hookH(LSTMCell[0]);
actualDc = recurrence_hookC(LSTMCell[1]);
actualDh = recurrence_hookH(LSTMCell[0])
actualDc = recurrence_hookC(LSTMCell[1])
# Form the recurrence loop by replacing the dh and dc placeholders with the actualDh and actualDc
LSTMCell[0].replace_placeholders({ dh : actualDh.output(), dc : actualDc.output()})
# Form the recurrence loop by replacing the dh and dc placeholders with
# the actualDh and actualDc
LSTMCell[0].replace_placeholders(
{dh: actualDh.output(), dc: actualDc.output()})
return (LSTMCell[0], LSTMCell[1])
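The placeholder replacement above is how the recurrence gets closed: the cell is first wired against free placeholders dh/dc, the hooks wrap the cell's own outputs (by default a one-step delay), and replace_placeholders ties the loop shut. A hedged usage sketch with illustrative sizes, assuming input_variable, past_value and future_value from cntk.ops are in scope as they are in these examples:

# One stabilized LSTM layer over a hypothetical 50-dim sequence input;
# the default past_value hooks give the usual left-to-right recurrence.
seq_input = input_variable(shape=(50,))
lstm_h, lstm_c = LSTMP_component_with_self_stabilization(seq_input, 128, 128)

# A right-to-left pass swaps in future_value for both hooks, as the
# sequence-to-sequence encoder above does:
# LSTMP_component_with_self_stabilization(
#     seq_input, 128, 128, future_value, future_value)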
def print_training_progress(trainer, mb, frequency):
if mb%frequency == 0:
if mb % frequency == 0:
training_loss = get_train_loss(trainer)
eval_crit = get_train_eval_criterion(trainer)
print ("Minibatch: {}, Train Loss: {}, Train Evaluation Criterion: {}".format(mb, training_loss, eval_crit))
print("Minibatch: {}, Train Loss: {}, Train Evaluation Criterion: {}".format(
mb, training_loss, eval_crit))

View file

@@ -0,0 +1,20 @@
# Copyright (c) Microsoft. All rights reserved.
# Licensed under the MIT license. See LICENSE.md file in the project root
# for full license information.
# ==============================================================================
import numpy as np
from cntk import DeviceDescriptor
from examples.NumpyInterop.FeedForwardNet import ffnet
TOLERANCE_ABSOLUTE = 1E-03
def test_error(device_id):
#from cntk.utils import cntk_device
#DeviceDescriptor.set_default_device(cntk_device(device_id))
avg_error = ffnet(debug_output=False)
expected_avg_error = 0.12
assert np.allclose(avg_error, expected_avg_error, atol=TOLERANCE_ABSOLUTE)

View file

@@ -0,0 +1,22 @@
# Copyright (c) Microsoft. All rights reserved.
# Licensed under the MIT license. See LICENSE.md file in the project root
# for full license information.
# ==============================================================================
import numpy as np
from cntk import DeviceDescriptor
from examples.SequenceClassification.SequenceClassification import train_sequence_classifier
TOLERANCE_ABSOLUTE = 1E-2
def test_error(device_id):
#from cntk.utils import cntk_device
#DeviceDescriptor.set_default_device(cntk_device(device_id))
evaluation_avg, loss_avg = train_sequence_classifier()
expected_avg = [0.1595744, 0.35799171]
assert np.allclose([evaluation_avg, loss_avg],
expected_avg, atol=TOLERANCE_ABSOLUTE)

View file

@@ -0,0 +1,22 @@
# Copyright (c) Microsoft. All rights reserved.
# Licensed under the MIT license. See LICENSE.md file in the project root
# for full license information.
# ==============================================================================
import numpy as np
from cntk import DeviceDescriptor
from examples.MNIST.SimpleMNIST import simple_mnist
TOLERANCE_ABSOLUTE = 1E-1
def test_error(device_id):
#from cntk.utils import cntk_device
#DeviceDescriptor.set_default_device(cntk_device(device_id))
test_error = simple_mnist()
expected_test_error = 0.7
assert np.allclose([test_error], [expected_test_error],
atol=TOLERANCE_ABSOLUTE)

View file

@@ -23,20 +23,25 @@ else:
if IS_WINDOWS:
CNTK_LIB_PATH = os.path.join(CNTK_PATH, "x64", "Release")
else:
CNTK_LIB_PATH = os.path.join(CNTK_PATH, "build", "gpu", "release", "lib")
CNTK_LIB_PATH = os.path.join(
CNTK_PATH, "build", "gpu", "release", "lib")
print("Using CNTK sources at '%s'" % os.path.abspath(CNTK_SOURCE_PATH))
print("Using CNTK libs at '%s'" % os.path.abspath(CNTK_LIB_PATH))
print("Using CNTK sources at '%s'"%os.path.abspath(CNTK_SOURCE_PATH))
print("Using CNTK libs at '%s'"%os.path.abspath(CNTK_LIB_PATH))
def lib_path(fn):
return os.path.normpath(os.path.join(CNTK_LIB_PATH, fn))
def proj_lib_path(fn):
return os.path.normpath(os.path.join(PROJ_LIB_PATH, fn))
def strip_path(fn):
return os.path.split(fn)[1]
def strip_ext(fn):
return os.path.splitext(fn)[0]
@@ -44,19 +49,20 @@ if IS_WINDOWS:
libname_rt_ext = '.dll'
link_libs = [strip_ext(strip_path(fn)) for fn in
glob(os.path.join(CNTK_LIB_PATH, '*.lib'))]
glob(os.path.join(CNTK_LIB_PATH, '*.lib'))]
else:
link_libs=[
"cntklibrary-2.0",
"cntkmath"
link_libs = [
"cntklibrary-2.0",
"cntkmath"
]
libname_rt_ext = '.so'
rt_libs = [strip_path(fn) for fn in glob(os.path.join(CNTK_LIB_PATH,
'*'+libname_rt_ext))]
'*' + libname_rt_ext))]
# copy over the libraries to the cntk base directory so that the rpath is correctly set
# copy over the libraries to the cntk base directory so that the rpath is
# correctly set
if os.path.exists(PROJ_LIB_PATH):
shutil.rmtree(PROJ_LIB_PATH)
@@ -71,21 +77,21 @@ for fn in rt_libs:
rt_libs = [os.path.join('libs', fn) for fn in rt_libs]
extra_compile_args = [
"-DSWIG",
"-DUNICODE"
]
"-DSWIG",
"-DUNICODE"
]
if IS_WINDOWS:
extra_compile_args += [
"/EHsc",
"/DEBUG",
"/Zi",
"/EHsc",
"/EHsc",
"/DEBUG",
"/Zi",
"/EHsc",
]
runtime_library_dirs = []
else:
extra_compile_args += [
'--std=c++11',
'--std=c++11',
]
# Expecting the dependent libs (libcntklibrary-2.0.so, etc.) inside
@@ -96,49 +102,51 @@ else:
swig_source = os.path.join("cntk", "swig", "cntk_py_wrap.cxx")
if not os.path.exists(swig_source):
print("SWIG wrapper missing. Have you run SWIG already?")
sys.exit(1)
print("SWIG wrapper missing. Have you run SWIG already?")
sys.exit(1)
cntk_module = Extension(
name="_cntk_py",
name="_cntk_py",
sources=[swig_source],
sources=[swig_source],
libraries=link_libs,
library_dirs=[CNTK_LIB_PATH],
libraries=link_libs,
library_dirs=[CNTK_LIB_PATH],
runtime_library_dirs=runtime_library_dirs,
runtime_library_dirs=runtime_library_dirs,
include_dirs=[
os.path.join(CNTK_SOURCE_PATH, "CNTKv2LibraryDll", "API"),
os.path.join(CNTK_SOURCE_PATH, "Math"),
os.path.join(CNTK_SOURCE_PATH, "Common", "Include"),
numpy.get_include(),
],
include_dirs=[
os.path.join(CNTK_SOURCE_PATH, "CNTKv2LibraryDll", "API"),
os.path.join(CNTK_SOURCE_PATH, "Math"),
os.path.join(CNTK_SOURCE_PATH, "Common", "Include"),
numpy.get_include(),
],
extra_compile_args = extra_compile_args,
extra_compile_args=extra_compile_args,
language="c++",
)
language="c++",
)
# do not include tests and examples
packages = [x for x in find_packages() if x.startswith('cntk') and not x.startswith('cntk.swig')]
packages = [x for x in find_packages()
            if x.startswith('cntk') and not x.startswith('cntk.swig')]
if IS_WINDOWS:
# On Windows copy all runtime libs to the base folder of Python
kwargs = dict(data_files = [('.', [ os.path.join('cntk', lib) for lib in rt_libs ])])
kwargs = dict(
data_files=[('.', [os.path.join('cntk', lib) for lib in rt_libs])])
else:
# On Linux copy all runtime libs into the cntk/lib folder.
kwargs = dict(package_data = { 'cntk': rt_libs })
# On Linux copy all runtime libs into the cntk/lib folder.
kwargs = dict(package_data={'cntk': rt_libs})
setup(name="cntk",
setup(name="cntk",
version="2.0a2",
url="http://cntk.ai",
ext_modules = [cntk_module],
ext_modules=[cntk_module],
packages=packages,
#install_requires=[
# install_requires=[
# 'numpy>=1.11',
# 'scipy>=0.17'
#],
**kwargs
)
)
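The **kwargs expansion at the end lets a single setup() call carry the platform-specific packaging argument: data_files on Windows (runtime libs copied next to the Python executable) versus package_data on Linux (libs shipped inside the cntk package so the rpath resolves). A tiny standalone sketch of the dict-unpacking pattern (setup_stub and the file names are hypothetical stand-ins):

def setup_stub(**kw):
    # Stand-in for setuptools.setup, just to show what arrives.
    print(sorted(kw))

IS_WINDOWS = False  # illustrative
if IS_WINDOWS:
    kwargs = dict(data_files=[('.', ['cntk/libs/foo.dll'])])
else:
    kwargs = dict(package_data={'cntk': ['libs/libfoo.so']})

setup_stub(name="cntk", **kwargs)  # prints ['name', 'package_data']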