Merge commit '6822bb5e0a694a9a23c749b0d629c65484e6219a' into wilrich/miscAlpha2
This commit is contained in:
Коммит
27684e284d
|
@ -19,16 +19,17 @@ from examples.common.nn import conv_bn_relu_layer, conv_bn_layer, resnet_node2,
|
||||||
|
|
||||||
TRAIN_MAP_FILENAME = 'train_map.txt'
|
TRAIN_MAP_FILENAME = 'train_map.txt'
|
||||||
MEAN_FILENAME = 'CIFAR-10_mean.xml'
|
MEAN_FILENAME = 'CIFAR-10_mean.xml'
|
||||||
|
TEST_MAP_FILENAME = 'test_map.txt'
|
||||||
|
|
||||||
# Instantiates the CNTK built-in minibatch source for reading images to be used for training the residual net
|
# Instantiates the CNTK built-in minibatch source for reading images to be used for training the residual net
|
||||||
# The minibatch source is configured using a hierarchical dictionary of
|
# The minibatch source is configured using a hierarchical dictionary of key:value pairs
|
||||||
# key:value pairs
|
|
||||||
|
|
||||||
|
|
||||||
def create_mb_source(features_stream_name, labels_stream_name, image_height,
|
def create_mb_source(features_stream_name, labels_stream_name, image_height,
|
||||||
image_width, num_channels, num_classes, cifar_data_path):
|
image_width, num_channels, num_classes, cifar_data_path):
|
||||||
map_file = os.path.join(cifar_data_path, TRAIN_MAP_FILENAME)
|
|
||||||
mean_file = os.path.join(cifar_data_path, MEAN_FILENAME)
|
path = os.path.normpath(os.path.join(abs_path, cifar_data_path))
|
||||||
|
map_file = os.path.join(path, TRAIN_MAP_FILENAME)
|
||||||
|
mean_file = os.path.join(path, MEAN_FILENAME)
|
||||||
|
|
||||||
if not os.path.exists(map_file) or not os.path.exists(mean_file):
|
if not os.path.exists(map_file) or not os.path.exists(mean_file):
|
||||||
cifar_py3 = "" if sys.version_info.major < 3 else "_py3"
|
cifar_py3 = "" if sys.version_info.major < 3 else "_py3"
|
||||||
|
@ -36,7 +37,6 @@ def create_mb_source(features_stream_name, labels_stream_name, image_height,
|
||||||
(map_file, mean_file, cifar_py3, cifar_py3))
|
(map_file, mean_file, cifar_py3, cifar_py3))
|
||||||
|
|
||||||
image = ImageDeserializer(map_file)
|
image = ImageDeserializer(map_file)
|
||||||
<<<<<<< 391432ca77060ad88807339d773f288de6557c4a
|
|
||||||
image.map_features(features_stream_name,
|
image.map_features(features_stream_name,
|
||||||
[ImageDeserializer.crop(crop_type='Random', ratio=0.8,
|
[ImageDeserializer.crop(crop_type='Random', ratio=0.8,
|
||||||
jitter_type='uniRatio'),
|
jitter_type='uniRatio'),
|
||||||
|
@ -47,23 +47,32 @@ def create_mb_source(features_stream_name, labels_stream_name, image_height,
|
||||||
|
|
||||||
rc = ReaderConfig(image, epoch_size=sys.maxsize)
|
rc = ReaderConfig(image, epoch_size=sys.maxsize)
|
||||||
return rc.minibatch_source()
|
return rc.minibatch_source()
|
||||||
=======
|
|
||||||
image.map_features(feature_name,
|
def create_test_mb_source(features_stream_name, labels_stream_name, image_height,
|
||||||
|
image_width, num_channels, num_classes, cifar_data_path):
|
||||||
|
|
||||||
|
path = os.path.normpath(os.path.join(abs_path, cifar_data_path))
|
||||||
|
|
||||||
|
map_file = os.path.join(path, TEST_MAP_FILENAME)
|
||||||
|
mean_file = os.path.join(path, MEAN_FILENAME)
|
||||||
|
|
||||||
|
if not os.path.exists(map_file) or not os.path.exists(mean_file):
|
||||||
|
cifar_py3 = "" if sys.version_info.major < 3 else "_py3"
|
||||||
|
raise RuntimeError("File '%s' or '%s' do not exist. Please run CifarDownload%s.py and CifarConverter%s.py from CIFAR-10 to fetch them" %
|
||||||
|
(map_file, mean_file, cifar_py3, cifar_py3))
|
||||||
|
|
||||||
|
image = ImageDeserializer(map_file)
|
||||||
|
image.map_features(features_stream_name,
|
||||||
[ImageDeserializer.crop(crop_type='Random', ratio=0.8,
|
[ImageDeserializer.crop(crop_type='Random', ratio=0.8,
|
||||||
jitter_type='uniRatio'),
|
jitter_type='uniRatio'),
|
||||||
ImageDeserializer.scale(width=image_width, height=image_height,
|
ImageDeserializer.scale(width=image_width, height=image_height,
|
||||||
channels=num_channels, interpolations='linear'),
|
channels=num_channels, interpolations='linear'),
|
||||||
ImageDeserializer.mean(mean_file)])
|
ImageDeserializer.mean(mean_file)])
|
||||||
image.map_labels(label_name, num_classes)
|
image.map_labels(labels_stream_name, num_classes)
|
||||||
|
|
||||||
rc = ReaderConfig(image, epoch_size=sys.maxsize)
|
rc = ReaderConfig(image, epoch_size=sys.maxsize)
|
||||||
|
|
||||||
input_streams_config = {
|
|
||||||
features_stream_name: features_stream_config, labels_stream_name: labels_stream_config}
|
|
||||||
deserializer_config = {"type": "ImageDeserializer",
|
|
||||||
"file": map_file, "input": input_streams_config}
|
|
||||||
return rc.minibatch_source()
|
return rc.minibatch_source()
|
||||||
>>>>>>> Address comments in CR
|
|
||||||
def get_projection_map(out_dim, in_dim):
|
def get_projection_map(out_dim, in_dim):
|
||||||
if in_dim > out_dim:
|
if in_dim > out_dim:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -125,40 +134,23 @@ def resnet_classifer(input, num_classes):
|
||||||
poolh_stride = 1
|
poolh_stride = 1
|
||||||
poolv_stride = 1
|
poolv_stride = 1
|
||||||
|
|
||||||
<<<<<<< 391432ca77060ad88807339d773f288de6557c4a
|
|
||||||
pool = pooling(rn3_3, AVG_POOLING, (1, poolh, poolw), (1, poolv_stride, poolh_stride))
|
pool = pooling(rn3_3, AVG_POOLING, (1, poolh, poolw), (1, poolv_stride, poolh_stride))
|
||||||
out_times_params = parameter(shape=(c_map3, 1, 1, num_classes), initializer=glorot_uniform())
|
out_times_params = parameter(shape=(c_map3, 1, 1, num_classes), initializer=glorot_uniform())
|
||||||
out_bias_params = parameter(shape=(num_classes), value=0)
|
out_bias_params = parameter(shape=(num_classes), value=0)
|
||||||
=======
|
|
||||||
pool = pooling(rn3_3, AVG_POOLING, (1, poolh, poolw),
|
|
||||||
(1, poolv_stride, poolh_stride))
|
|
||||||
out_times_params = parameter(shape=(c_map3, 1, 1, num_classes))
|
|
||||||
out_bias_params = parameter(shape=(num_classes))
|
|
||||||
>>>>>>> Address comments in CR
|
|
||||||
t = times(pool, out_times_params)
|
t = times(pool, out_times_params)
|
||||||
return t + out_bias_params
|
return t + out_bias_params
|
||||||
|
|
||||||
# Trains a residual network model on the Cifar image dataset
|
# Trains a residual network model on the Cifar image dataset
|
||||||
<<<<<<< 391432ca77060ad88807339d773f288de6557c4a
|
def cifar_resnet(base_path, debug_output=False):
|
||||||
def cifar_resnet(base_path):
|
|
||||||
=======
|
|
||||||
|
|
||||||
pool = pooling(rn3_3, AVG_POOLING, (1, poolh, poolw), (1, poolv_stride, poolh_stride))
|
|
||||||
out_times_params = parameter(shape=(c_map3, 1, 1, num_classes), initializer=glorot_uniform_initializer())
|
|
||||||
out_bias_params = parameter(shape=(num_classes), value=0)
|
|
||||||
image_height = 32
|
image_height = 32
|
||||||
image_width = 32
|
image_width = 32
|
||||||
num_channels = 3
|
num_channels = 3
|
||||||
num_classes = 10
|
num_classes = 10
|
||||||
def cifar_resnet(base_path):
|
feats_stream_name = 'features'
|
||||||
labels_stream_name = 'labels'
|
labels_stream_name = 'labels'
|
||||||
<<<<<<< 391432ca77060ad88807339d773f288de6557c4a
|
|
||||||
minibatch_source = create_mb_source(feats_stream_name, labels_stream_name,
|
minibatch_source = create_mb_source(feats_stream_name, labels_stream_name,
|
||||||
image_height, image_width, num_channels, num_classes, base_path)
|
image_height, image_width, num_channels, num_classes, base_path)
|
||||||
=======
|
|
||||||
minibatch_source = create_mb_source(feats_stream_name, labels_stream_name,
|
|
||||||
image_height, image_width, num_channels, num_classes)
|
|
||||||
>>>>>>> Address comments in CR
|
|
||||||
features_si = minibatch_source.stream_info(feats_stream_name)
|
features_si = minibatch_source.stream_info(feats_stream_name)
|
||||||
labels_si = minibatch_source.stream_info(labels_stream_name)
|
labels_si = minibatch_source.stream_info(labels_stream_name)
|
||||||
|
|
||||||
|
@ -181,6 +173,7 @@ def cifar_resnet(base_path):
|
||||||
mb_size = 32
|
mb_size = 32
|
||||||
training_progress_output_freq = 20
|
training_progress_output_freq = 20
|
||||||
num_mbs = 1000
|
num_mbs = 1000
|
||||||
|
|
||||||
for i in range(0, num_mbs):
|
for i in range(0, num_mbs):
|
||||||
mb = minibatch_source.get_next_minibatch(mb_size)
|
mb = minibatch_source.get_next_minibatch(mb_size)
|
||||||
|
|
||||||
|
@ -190,8 +183,30 @@ def cifar_resnet(base_path):
|
||||||
features_si].m_data, label_var: mb[labels_si].m_data}
|
features_si].m_data, label_var: mb[labels_si].m_data}
|
||||||
trainer.train_minibatch(arguments)
|
trainer.train_minibatch(arguments)
|
||||||
|
|
||||||
|
if debug_output:
|
||||||
print_training_progress(trainer, i, training_progress_output_freq)
|
print_training_progress(trainer, i, training_progress_output_freq)
|
||||||
|
|
||||||
|
test_minibatch_source = create_test_mb_source(feats_stream_name, labels_stream_name,
|
||||||
|
image_height, image_width, num_channels, num_classes, base_path)
|
||||||
|
features_si = test_minibatch_source.stream_info(feats_stream_name)
|
||||||
|
labels_si = test_minibatch_source.stream_info(labels_stream_name)
|
||||||
|
|
||||||
|
mb_size = 64
|
||||||
|
num_mbs = 300
|
||||||
|
|
||||||
|
total_error = 0.0
|
||||||
|
for i in range(0, num_mbs):
|
||||||
|
mb = test_minibatch_source.get_next_minibatch(mb_size)
|
||||||
|
|
||||||
|
# Specify the mapping of input variables in the model to actual
|
||||||
|
# minibatch data to be trained with
|
||||||
|
arguments = {image_input: mb[
|
||||||
|
features_si].m_data, label_var: mb[labels_si].m_data}
|
||||||
|
error = trainer.test_minibatch(arguments)
|
||||||
|
total_error += error
|
||||||
|
|
||||||
|
return total_error / num_mbs
|
||||||
|
|
||||||
# Place holder for real test
|
# Place holder for real test
|
||||||
def test_TODO_remove_me(device_id):
|
def test_TODO_remove_me(device_id):
|
||||||
#FIXME: need a backdoor to work around the limitation of changing the default device not possible
|
#FIXME: need a backdoor to work around the limitation of changing the default device not possible
|
||||||
|
@ -215,4 +230,6 @@ if __name__ == '__main__':
|
||||||
base_path = os.path.normpath(os.path.join(
|
base_path = os.path.normpath(os.path.join(
|
||||||
*"../../../../Examples/Image/Miscellaneous/CIFAR-10/cifar-10-batches-py".split("/")))
|
*"../../../../Examples/Image/Miscellaneous/CIFAR-10/cifar-10-batches-py".split("/")))
|
||||||
|
|
||||||
|
os.chdir(os.path.join(base_path, '..'))
|
||||||
|
|
||||||
cifar_resnet(base_path)
|
cifar_resnet(base_path)
|
||||||
|
|
|
@ -15,8 +15,6 @@ abs_path = os.path.dirname(os.path.abspath(__file__))
|
||||||
sys.path.append(os.path.join(abs_path, "..", ".."))
|
sys.path.append(os.path.join(abs_path, "..", ".."))
|
||||||
from examples.common.nn import fully_connected_classifier_net, print_training_progress
|
from examples.common.nn import fully_connected_classifier_net, print_training_progress
|
||||||
|
|
||||||
TOLERANCE_ABSOLUTE = 1E-03
|
|
||||||
|
|
||||||
# make sure we get always the same "randomness"
|
# make sure we get always the same "randomness"
|
||||||
np.random.seed(0)
|
np.random.seed(0)
|
||||||
|
|
||||||
|
@ -35,7 +33,7 @@ def generate_random_data(sample_size, feature_dim, num_classes):
|
||||||
|
|
||||||
# Creates and trains a feedforward classification model
|
# Creates and trains a feedforward classification model
|
||||||
|
|
||||||
def ffnet(debug_output=True):
|
def ffnet(debug_output=False):
|
||||||
input_dim = 2
|
input_dim = 2
|
||||||
num_output_classes = 2
|
num_output_classes = 2
|
||||||
num_hidden_layers = 2
|
num_hidden_layers = 2
|
||||||
|
@ -77,7 +75,7 @@ def ffnet(debug_output=True):
|
||||||
{input: test_features, label: test_labels})
|
{input: test_features, label: test_labels})
|
||||||
return avg_error
|
return avg_error
|
||||||
|
|
||||||
def test_error(device_id):
|
def test_error_TODO(device_id):
|
||||||
#FIXME: need a backdoor to work around the limitation of changing the default device not possible
|
#FIXME: need a backdoor to work around the limitation of changing the default device not possible
|
||||||
#from cntk.utils import cntk_device
|
#from cntk.utils import cntk_device
|
||||||
#DeviceDescriptor.set_default_device(cntk_device(device_id))
|
#DeviceDescriptor.set_default_device(cntk_device(device_id))
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
# Copyright (c) Microsoft. All rights reserved.
|
# Copyright (c) Microsoft. All rights reserved.
|
||||||
|
|
||||||
# Licensed under the MIT license. See LICENSE.md file in the project root
|
# Licensed under the MIT license. See LICENSE.md file in the project root
|
||||||
# for full license information.
|
# for full license information.
|
||||||
|
@ -19,8 +19,7 @@ from examples.common.nn import LSTMP_component_with_self_stabilization, stabiliz
|
||||||
|
|
||||||
# Creates and trains a sequence to sequence translation model
|
# Creates and trains a sequence to sequence translation model
|
||||||
|
|
||||||
|
def sequence_to_sequence_translator(debug_output=False):
|
||||||
def train_sequence_to_sequence_translator():
|
|
||||||
|
|
||||||
input_vocab_dim = 69
|
input_vocab_dim = 69
|
||||||
label_vocab_dim = 69
|
label_vocab_dim = 69
|
||||||
|
@ -94,6 +93,16 @@ def train_sequence_to_sequence_translator():
|
||||||
ce = cross_entropy_with_softmax(z, label_sequence)
|
ce = cross_entropy_with_softmax(z, label_sequence)
|
||||||
errs = classification_error(z, label_sequence)
|
errs = classification_error(z, label_sequence)
|
||||||
|
|
||||||
|
# Instantiate the trainer object to drive the model training
|
||||||
|
lr = 0.007
|
||||||
|
momentum_time_constant = 1100
|
||||||
|
momentum_per_sample = momentums_per_sample(
|
||||||
|
math.exp(-1.0 / momentum_time_constant))
|
||||||
|
clipping_threshold_per_sample = 2.3
|
||||||
|
gradient_clipping_with_truncation = True
|
||||||
|
|
||||||
|
trainer = Trainer(z, ce, errs, [momentum_sgd(z.parameters(), lr, momentum_per_sample, clipping_threshold_per_sample, gradient_clipping_with_truncation)])
|
||||||
|
|
||||||
rel_path = r"../../../../Examples/SequenceToSequence/CMUDict/Data/cmudict-0.7b.train-dev-20-21.ctf"
|
rel_path = r"../../../../Examples/SequenceToSequence/CMUDict/Data/cmudict-0.7b.train-dev-20-21.ctf"
|
||||||
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), rel_path)
|
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), rel_path)
|
||||||
feature_stream_name = 'features'
|
feature_stream_name = 'features'
|
||||||
|
@ -105,16 +114,6 @@ def train_sequence_to_sequence_translator():
|
||||||
features_si = mb_source.stream_info(feature_stream_name)
|
features_si = mb_source.stream_info(feature_stream_name)
|
||||||
labels_si = mb_source.stream_info(labels_stream_name)
|
labels_si = mb_source.stream_info(labels_stream_name)
|
||||||
|
|
||||||
# Instantiate the trainer object to drive the model training
|
|
||||||
lr = 0.007
|
|
||||||
momentum_time_constant = 1100
|
|
||||||
momentum_per_sample = momentums_per_sample(
|
|
||||||
math.exp(-1.0 / momentum_time_constant))
|
|
||||||
clipping_threshold_per_sample = 2.3
|
|
||||||
gradient_clipping_with_truncation = True
|
|
||||||
|
|
||||||
trainer = Trainer(z, ce, errs, [momentum_sgd(z.parameters(), lr, momentum_per_sample, clipping_threshold_per_sample, gradient_clipping_with_truncation)])
|
|
||||||
|
|
||||||
# Get minibatches of sequences to train with and perform model training
|
# Get minibatches of sequences to train with and perform model training
|
||||||
minibatch_size = 72
|
minibatch_size = 72
|
||||||
training_progress_output_freq = 10
|
training_progress_output_freq = 10
|
||||||
|
@ -129,13 +128,51 @@ def train_sequence_to_sequence_translator():
|
||||||
raw_labels: mb[labels_si].m_data}
|
raw_labels: mb[labels_si].m_data}
|
||||||
trainer.train_minibatch(arguments)
|
trainer.train_minibatch(arguments)
|
||||||
|
|
||||||
|
if debug_output:
|
||||||
print_training_progress(trainer, i, training_progress_output_freq)
|
print_training_progress(trainer, i, training_progress_output_freq)
|
||||||
|
|
||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
|
rel_path = r"../../../../Examples/SequenceToSequence/CMUDict/Data/cmudict-0.7b.test.ctf"
|
||||||
|
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), rel_path)
|
||||||
|
|
||||||
|
test_mb_source = text_format_minibatch_source(path, [
|
||||||
|
StreamConfiguration(feature_stream_name, input_vocab_dim, True, 'S0'),
|
||||||
|
StreamConfiguration(labels_stream_name, label_vocab_dim, True, 'S1')], 10000)
|
||||||
|
features_si = test_mb_source.stream_info(feature_stream_name)
|
||||||
|
labels_si = test_mb_source.stream_info(labels_stream_name)
|
||||||
|
|
||||||
|
# choose this to be big enough for the longest sentence
|
||||||
|
train_minibatch_size = 1024
|
||||||
|
|
||||||
|
# Get minibatches of sequences to test and perform testing
|
||||||
|
i = 0
|
||||||
|
total_error = 0.0
|
||||||
|
while True:
|
||||||
|
mb = test_mb_source.get_next_minibatch(train_minibatch_size)
|
||||||
|
if len(mb) == 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Specify the mapping of input variables in the model to actual
|
||||||
|
# minibatch data to be tested with
|
||||||
|
arguments = {raw_input: mb[features_si].m_data,
|
||||||
|
raw_labels: mb[labels_si].m_data}
|
||||||
|
mb_error = trainer.test_minibatch(arguments)
|
||||||
|
|
||||||
|
total_error += mb_error
|
||||||
|
|
||||||
|
if debug_output:
|
||||||
|
print("Minibatch {}, Error {} ".format(i, mb_error))
|
||||||
|
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
# Average of evaluation errors of all test minibatches
|
||||||
|
return total_error / i
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# Specify the target device to be used for computing
|
# Specify the target device to be used for computing
|
||||||
target_device = DeviceDescriptor.cpu_device()
|
target_device = DeviceDescriptor.cpu_device()
|
||||||
DeviceDescriptor.set_default_device(target_device)
|
DeviceDescriptor.set_default_device(target_device)
|
||||||
|
|
||||||
train_sequence_to_sequence_translator()
|
error = sequence_to_sequence_translator()
|
||||||
|
print("test: %f" % error)
|
||||||
|
|
|
@ -0,0 +1,29 @@
|
||||||
|
# Copyright (c) Microsoft. All rights reserved.
|
||||||
|
|
||||||
|
# Licensed under the MIT license. See LICENSE.md file in the project root
|
||||||
|
# for full license information.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
from cntk import DeviceDescriptor
|
||||||
|
from cntk.io import ReaderConfig, ImageDeserializer
|
||||||
|
|
||||||
|
from examples.CifarResNet.CifarResNet import cifar_resnet
|
||||||
|
|
||||||
|
TOLERANCE_ABSOLUTE = 1E-1
|
||||||
|
|
||||||
|
def test_cifar_resnet_error(device_id):
|
||||||
|
target_device = DeviceDescriptor.gpu_device(0)
|
||||||
|
DeviceDescriptor.set_default_device(target_device)
|
||||||
|
|
||||||
|
base_path = os.path.normpath(os.path.join(
|
||||||
|
*"../../../../Examples/Image/Miscellaneous/CIFAR-10/cifar-10-batches-py".split("/")))
|
||||||
|
|
||||||
|
os.chdir(os.path.join(base_path, '..'))
|
||||||
|
|
||||||
|
test_error = cifar_resnet(base_path)
|
||||||
|
expected_test_error = 0.7
|
||||||
|
|
||||||
|
assert np.allclose(test_error, expected_test_error,
|
||||||
|
atol=TOLERANCE_ABSOLUTE)
|
|
@ -11,7 +11,7 @@ from examples.NumpyInterop.FeedForwardNet import ffnet
|
||||||
|
|
||||||
TOLERANCE_ABSOLUTE = 1E-03
|
TOLERANCE_ABSOLUTE = 1E-03
|
||||||
|
|
||||||
def test_error(device_id):
|
def test_ffnet_error(device_id):
|
||||||
#from cntk.utils import cntk_device
|
#from cntk.utils import cntk_device
|
||||||
#DeviceDescriptor.set_default_device(cntk_device(device_id))
|
#DeviceDescriptor.set_default_device(cntk_device(device_id))
|
||||||
|
|
||||||
|
|
|
@ -11,7 +11,7 @@ from examples.SequenceClassification.SequenceClassification import train_sequenc
|
||||||
|
|
||||||
TOLERANCE_ABSOLUTE = 1E-2
|
TOLERANCE_ABSOLUTE = 1E-2
|
||||||
|
|
||||||
def test_error(device_id):
|
def test_seq_classification_error(device_id):
|
||||||
#from cntk.utils import cntk_device
|
#from cntk.utils import cntk_device
|
||||||
#DeviceDescriptor.set_default_device(cntk_device(device_id))
|
#DeviceDescriptor.set_default_device(cntk_device(device_id))
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,21 @@
|
||||||
|
# Copyright (c) Microsoft. All rights reserved.
|
||||||
|
|
||||||
|
# Licensed under the MIT license. See LICENSE.md file in the project root
|
||||||
|
# for full license information.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from cntk import DeviceDescriptor
|
||||||
|
|
||||||
|
from examples.Sequence2Sequence.Sequence2Sequence import sequence_to_sequence_translator
|
||||||
|
|
||||||
|
TOLERANCE_ABSOLUTE = 1E-1
|
||||||
|
|
||||||
|
def test_sequence_to_sequence(device_id):
|
||||||
|
#from cntk.utils import cntk_device
|
||||||
|
#DeviceDescriptor.set_default_device(cntk_device(device_id))
|
||||||
|
|
||||||
|
error = sequence_to_sequence_translator()
|
||||||
|
expected_error = 0.758458
|
||||||
|
|
||||||
|
assert np.allclose(error, expected_error, atol=TOLERANCE_ABSOLUTE)
|
|
@ -11,12 +11,12 @@ from examples.MNIST.SimpleMNIST import simple_mnist
|
||||||
|
|
||||||
TOLERANCE_ABSOLUTE = 1E-1
|
TOLERANCE_ABSOLUTE = 1E-1
|
||||||
|
|
||||||
def test_error(device_id):
|
def test_simple_mnist_error(device_id):
|
||||||
#from cntk.utils import cntk_device
|
#from cntk.utils import cntk_device
|
||||||
#DeviceDescriptor.set_default_device(cntk_device(device_id))
|
#DeviceDescriptor.set_default_device(cntk_device(device_id))
|
||||||
|
|
||||||
test_error = simple_mnist()
|
test_error = simple_mnist()
|
||||||
expected_test_error = 0.7
|
expected_test_error = 0.09
|
||||||
|
|
||||||
assert np.allclose([test_error], [expected_test_error],
|
assert np.allclose(test_error, expected_test_error,
|
||||||
atol=TOLERANCE_ABSOLUTE)
|
atol=TOLERANCE_ABSOLUTE)
|
||||||
|
|
Загрузка…
Ссылка в новой задаче