From 75a582f1d3edab3555d483eb70af5dae0b824538 Mon Sep 17 00:00:00 2001 From: Ivan Rodriguez Date: Thu, 29 Sep 2016 17:03:57 +0200 Subject: [PATCH 1/2] Add sequence to sequence example --- .../examples/CifarResNet/CifarResNet.py | 45 +----------- bindings/python/examples/MNIST/SimpleMNIST.py | 2 +- .../examples/NumpyInterop/FeedForwardNet.py | 13 +--- .../Sequence2Sequence/Sequence2Sequence.py | 69 ++++++++++++++----- .../examples/test/feed_forward_net_test.py | 2 +- .../test/sequence_classification_test.py | 2 +- .../test/sequence_to_sequence_test.py | 21 ++++++ .../python/examples/test/simple_mnist_test.py | 4 +- 8 files changed, 83 insertions(+), 75 deletions(-) create mode 100644 bindings/python/examples/test/sequence_to_sequence_test.py diff --git a/bindings/python/examples/CifarResNet/CifarResNet.py b/bindings/python/examples/CifarResNet/CifarResNet.py index a556896d4..d536e4da1 100644 --- a/bindings/python/examples/CifarResNet/CifarResNet.py +++ b/bindings/python/examples/CifarResNet/CifarResNet.py @@ -20,9 +20,7 @@ TRAIN_MAP_FILENAME = 'train_map.txt' MEAN_FILENAME = 'CIFAR-10_mean.xml' # Instantiates the CNTK built-in minibatch source for reading images to be used for training the residual net -# The minibatch source is configured using a hierarchical dictionary of -# key:value pairs - +# The minibatch source is configured using a hierarchical dictionary of key:value pairs def create_mb_source(features_stream_name, labels_stream_name, image_height, image_width, num_channels, num_classes, cifar_data_path): @@ -35,7 +33,6 @@ def create_mb_source(features_stream_name, labels_stream_name, image_height, (map_file, mean_file, cifar_py3, cifar_py3)) image = ImageDeserializer(map_file) -<<<<<<< 391432ca77060ad88807339d773f288de6557c4a image.map_features(features_stream_name, [ImageDeserializer.crop(crop_type='Random', ratio=0.8, jitter_type='uniRatio'), @@ -44,25 +41,6 @@ def create_mb_source(features_stream_name, labels_stream_name, image_height, ImageDeserializer.mean(mean_file)]) image.map_labels(labels_stream_name, num_classes) - rc = ReaderConfig(image, epoch_size=sys.maxsize) - return rc.minibatch_source() -======= - image.map_features(feature_name, - [ImageDeserializer.crop(crop_type='Random', ratio=0.8, - jitter_type='uniRatio'), - ImageDeserializer.scale(width=image_width, height=image_height, - channels=num_channels, interpolations='linear'), - ImageDeserializer.mean(mean_file)]) - image.map_labels(label_name, num_classes) - - rc = ReaderConfig(image, epoch_size=sys.maxsize) - - input_streams_config = { - features_stream_name: features_stream_config, labels_stream_name: labels_stream_config} - deserializer_config = {"type": "ImageDeserializer", - "file": map_file, "input": input_streams_config} - return rc.minibatch_source() ->>>>>>> Address comments in CR def get_projection_map(out_dim, in_dim): if in_dim > out_dim: raise ValueError( @@ -124,40 +102,23 @@ def resnet_classifer(input, num_classes): poolh_stride = 1 poolv_stride = 1 -<<<<<<< 391432ca77060ad88807339d773f288de6557c4a pool = pooling(rn3_3, AVG_POOLING, (1, poolh, poolw), (1, poolv_stride, poolh_stride)) out_times_params = parameter(shape=(c_map3, 1, 1, num_classes), initializer=glorot_uniform_initializer()) out_bias_params = parameter(shape=(num_classes), value=0) -======= - pool = pooling(rn3_3, AVG_POOLING, (1, poolh, poolw), - (1, poolv_stride, poolh_stride)) - out_times_params = parameter(shape=(c_map3, 1, 1, num_classes)) - out_bias_params = parameter(shape=(num_classes)) ->>>>>>> Address comments in CR t = times(pool, out_times_params) return t + out_bias_params # Trains a residual network model on the Cifar image dataset -<<<<<<< 391432ca77060ad88807339d773f288de6557c4a def cifar_resnet(base_path): -======= - - pool = pooling(rn3_3, AVG_POOLING, (1, poolh, poolw), (1, poolv_stride, poolh_stride)) - out_times_params = parameter(shape=(c_map3, 1, 1, num_classes), initializer=glorot_uniform_initializer()) - out_bias_params = parameter(shape=(num_classes), value=0) image_height = 32 image_width = 32 num_channels = 3 num_classes = 10 -def cifar_resnet(base_path): + feats_stream_name = 'features' labels_stream_name = 'labels' -<<<<<<< 391432ca77060ad88807339d773f288de6557c4a + minibatch_source = create_mb_source(feats_stream_name, labels_stream_name, image_height, image_width, num_channels, num_classes, base_path) -======= - minibatch_source = create_mb_source(feats_stream_name, labels_stream_name, - image_height, image_width, num_channels, num_classes) ->>>>>>> Address comments in CR features_si = minibatch_source.stream_info(feats_stream_name) labels_si = minibatch_source.stream_info(labels_stream_name) diff --git a/bindings/python/examples/MNIST/SimpleMNIST.py b/bindings/python/examples/MNIST/SimpleMNIST.py index 38644711d..0ee92c610 100644 --- a/bindings/python/examples/MNIST/SimpleMNIST.py +++ b/bindings/python/examples/MNIST/SimpleMNIST.py @@ -48,7 +48,7 @@ def simple_mnist(debug_output=False): feature_stream_name = 'features' labels_stream_name = 'labels' - mb_source = text_format_minibatch_source(path, [ + mb_source = text_format_minibatch_source(path, [ StreamConfiguration(feature_stream_name, input_dim), StreamConfiguration(labels_stream_name, num_output_classes)]) features_si = mb_source.stream_info(feature_stream_name) diff --git a/bindings/python/examples/NumpyInterop/FeedForwardNet.py b/bindings/python/examples/NumpyInterop/FeedForwardNet.py index 48f6f1d67..faf7d4a62 100644 --- a/bindings/python/examples/NumpyInterop/FeedForwardNet.py +++ b/bindings/python/examples/NumpyInterop/FeedForwardNet.py @@ -14,8 +14,6 @@ abs_path = os.path.dirname(os.path.abspath(__file__)) sys.path.append(os.path.join(abs_path, "..", "..")) from examples.common.nn import fully_connected_classifier_net, print_training_progress -TOLERANCE_ABSOLUTE = 1E-03 - # make sure we get always the same "randomness" np.random.seed(0) @@ -34,7 +32,7 @@ def generate_random_data(sample_size, feature_dim, num_classes): # Creates and trains a feedforward classification model -def ffnet(debug_output=True): +def ffnet(debug_output=False): input_dim = 2 num_output_classes = 2 num_hidden_layers = 2 @@ -77,15 +75,6 @@ def ffnet(debug_output=True): {input: test_features, label: test_labels}) return avg_error - -def test_error(device_id): - from cntk.utils import cntk_device - DeviceDescriptor.set_default_device(cntk_device(device_id)) - - avg_error = ffnet(debug_output=False) - expected_avg_error = 0.12 - assert np.allclose(avg_error, expected_avg_error, atol=TOLERANCE_ABSOLUTE) - if __name__ == '__main__': # Specify the target device to be used for computing target_device = DeviceDescriptor.gpu_device(0) diff --git a/bindings/python/examples/Sequence2Sequence/Sequence2Sequence.py b/bindings/python/examples/Sequence2Sequence/Sequence2Sequence.py index 5e2baabb7..9e67a2cce 100644 --- a/bindings/python/examples/Sequence2Sequence/Sequence2Sequence.py +++ b/bindings/python/examples/Sequence2Sequence/Sequence2Sequence.py @@ -1,4 +1,4 @@ -# Copyright (c) Microsoft. All rights reserved. +# Copyright (c) Microsoft. All rights reserved. # Licensed under the MIT license. See LICENSE.md file in the project root # for full license information. @@ -18,8 +18,7 @@ from examples.common.nn import LSTMP_component_with_self_stabilization, stabiliz # Creates and trains a sequence to sequence translation model - -def train_sequence_to_sequence_translator(): +def sequence_to_sequence_translator(debug_output=False): input_vocab_dim = 69 label_vocab_dim = 69 @@ -93,17 +92,6 @@ def train_sequence_to_sequence_translator(): ce = cross_entropy_with_softmax(z, label_sequence) errs = classification_error(z, label_sequence) - rel_path = r"../../../../Examples/SequenceToSequence/CMUDict/Data/cmudict-0.7b.train-dev-20-21.ctf" - path = os.path.join(os.path.dirname(os.path.abspath(__file__)), rel_path) - feature_stream_name = 'features' - labels_stream_name = 'labels' - - mb_source = text_format_minibatch_source(path, [ - StreamConfiguration(feature_stream_name, input_vocab_dim, True, 'S0'), - StreamConfiguration(labels_stream_name, label_vocab_dim, True, 'S1')], 10000) - features_si = mb_source.stream_info(feature_stream_name) - labels_si = mb_source.stream_info(labels_stream_name) - # Instantiate the trainer object to drive the model training lr = 0.007 momentum_time_constant = 1100 @@ -115,6 +103,17 @@ def train_sequence_to_sequence_translator(): trainer = Trainer(z, ce, errs, [momentum_sgd_learner(z.parameters( ), lr, momentum_per_sample, clipping_threshold_per_sample, gradient_clipping_with_truncation)]) + rel_path = r"../../../../Examples/SequenceToSequence/CMUDict/Data/cmudict-0.7b.train-dev-20-21.ctf" + path = os.path.join(os.path.dirname(os.path.abspath(__file__)), rel_path) + feature_stream_name = 'features' + labels_stream_name = 'labels' + + mb_source = text_format_minibatch_source(path, [ + StreamConfiguration(feature_stream_name, input_vocab_dim, True, 'S0'), + StreamConfiguration(labels_stream_name, label_vocab_dim, True, 'S1')], 10000) + features_si = mb_source.stream_info(feature_stream_name) + labels_si = mb_source.stream_info(labels_stream_name) + # Get minibatches of sequences to train with and perform model training minibatch_size = 72 training_progress_output_freq = 10 @@ -129,13 +128,51 @@ def train_sequence_to_sequence_translator(): raw_labels: mb[labels_si].m_data} trainer.train_minibatch(arguments) - print_training_progress(trainer, i, training_progress_output_freq) + if debug_output: + print_training_progress(trainer, i, training_progress_output_freq) + + i += 1 + + rel_path = r"../../../../Examples/SequenceToSequence/CMUDict/Data/cmudict-0.7b.test.ctf" + path = os.path.join(os.path.dirname(os.path.abspath(__file__)), rel_path) + + test_mb_source = text_format_minibatch_source(path, [ + StreamConfiguration(feature_stream_name, input_vocab_dim, True, 'S0'), + StreamConfiguration(labels_stream_name, label_vocab_dim, True, 'S1')], 10000) + features_si = test_mb_source.stream_info(feature_stream_name) + labels_si = test_mb_source.stream_info(labels_stream_name) + + # choose this to be big enough for the longest sentence + train_minibatch_size = 1024 + + # Get minibatches of sequences to test and perform testing + i = 0 + total_error = 0.0 + while True: + mb = test_mb_source.get_next_minibatch(train_minibatch_size) + if len(mb) == 0: + break + + # Specify the mapping of input variables in the model to actual + # minibatch data to be tested with + arguments = {raw_input: mb[features_si].m_data, + raw_labels: mb[labels_si].m_data} + mb_error = trainer.test_minibatch(arguments) + + total_error += mb_error + + if debug_output: + print("Minibatch {}, Error {} ".format(i, mb_error)) i += 1 + # Average of evaluation errors of all test minibatches + return total_error / i + if __name__ == '__main__': # Specify the target device to be used for computing target_device = DeviceDescriptor.cpu_device() DeviceDescriptor.set_default_device(target_device) - train_sequence_to_sequence_translator() + error = sequence_to_sequence_translator() + print("test: %f" % error) diff --git a/bindings/python/examples/test/feed_forward_net_test.py b/bindings/python/examples/test/feed_forward_net_test.py index 1d8ea22c2..89d32da9e 100644 --- a/bindings/python/examples/test/feed_forward_net_test.py +++ b/bindings/python/examples/test/feed_forward_net_test.py @@ -11,7 +11,7 @@ from examples.NumpyInterop.FeedForwardNet import ffnet TOLERANCE_ABSOLUTE = 1E-03 -def test_error(device_id): +def test_ffnet_error(device_id): #from cntk.utils import cntk_device #DeviceDescriptor.set_default_device(cntk_device(device_id)) diff --git a/bindings/python/examples/test/sequence_classification_test.py b/bindings/python/examples/test/sequence_classification_test.py index 4a95e45a1..eff05e21a 100644 --- a/bindings/python/examples/test/sequence_classification_test.py +++ b/bindings/python/examples/test/sequence_classification_test.py @@ -11,7 +11,7 @@ from examples.SequenceClassification.SequenceClassification import train_sequenc TOLERANCE_ABSOLUTE = 1E-2 -def test_error(device_id): +def test_seq_classification_error(device_id): #from cntk.utils import cntk_device #DeviceDescriptor.set_default_device(cntk_device(device_id)) diff --git a/bindings/python/examples/test/sequence_to_sequence_test.py b/bindings/python/examples/test/sequence_to_sequence_test.py new file mode 100644 index 000000000..57687af58 --- /dev/null +++ b/bindings/python/examples/test/sequence_to_sequence_test.py @@ -0,0 +1,21 @@ +# Copyright (c) Microsoft. All rights reserved. + +# Licensed under the MIT license. See LICENSE.md file in the project root +# for full license information. +# ============================================================================== + +import numpy as np +from cntk import DeviceDescriptor + +from examples.Sequence2Sequence.Sequence2Sequence import sequence_to_sequence_translator + +TOLERANCE_ABSOLUTE = 1E-1 + +def test_sequence_to_sequence(device_id): + #from cntk.utils import cntk_device + #DeviceDescriptor.set_default_device(cntk_device(device_id)) + + error = sequence_to_sequence_translator() + expected_error = 0.758458 + + assert np.allclose(error, expected_error, atol=TOLERANCE_ABSOLUTE) diff --git a/bindings/python/examples/test/simple_mnist_test.py b/bindings/python/examples/test/simple_mnist_test.py index e546b9af7..7b26c4aa8 100644 --- a/bindings/python/examples/test/simple_mnist_test.py +++ b/bindings/python/examples/test/simple_mnist_test.py @@ -11,12 +11,12 @@ from examples.MNIST.SimpleMNIST import simple_mnist TOLERANCE_ABSOLUTE = 1E-1 -def test_error(device_id): +def test_simple_mnist_error(device_id): #from cntk.utils import cntk_device #DeviceDescriptor.set_default_device(cntk_device(device_id)) test_error = simple_mnist() expected_test_error = 0.7 - assert np.allclose([test_error], [expected_test_error], + assert np.allclose(test_error, expected_test_error, atol=TOLERANCE_ABSOLUTE) From 6822bb5e0a694a9a23c749b0d629c65484e6219a Mon Sep 17 00:00:00 2001 From: Ivan Rodriguez Date: Thu, 29 Sep 2016 17:03:57 +0200 Subject: [PATCH 2/2] Cifar test added --- .../examples/CifarResNet/CifarResNet.py | 64 +++++++++++++++++-- .../python/examples/test/cifar_resnet_test.py | 29 +++++++++ .../python/examples/test/simple_mnist_test.py | 2 +- 3 files changed, 90 insertions(+), 5 deletions(-) create mode 100644 bindings/python/examples/test/cifar_resnet_test.py diff --git a/bindings/python/examples/CifarResNet/CifarResNet.py b/bindings/python/examples/CifarResNet/CifarResNet.py index d536e4da1..ebac6932a 100644 --- a/bindings/python/examples/CifarResNet/CifarResNet.py +++ b/bindings/python/examples/CifarResNet/CifarResNet.py @@ -18,14 +18,17 @@ from examples.common.nn import conv_bn_relu_layer, conv_bn_layer, resnet_node2, TRAIN_MAP_FILENAME = 'train_map.txt' MEAN_FILENAME = 'CIFAR-10_mean.xml' +TEST_MAP_FILENAME = 'test_map.txt' # Instantiates the CNTK built-in minibatch source for reading images to be used for training the residual net # The minibatch source is configured using a hierarchical dictionary of key:value pairs def create_mb_source(features_stream_name, labels_stream_name, image_height, image_width, num_channels, num_classes, cifar_data_path): - map_file = os.path.join(cifar_data_path, TRAIN_MAP_FILENAME) - mean_file = os.path.join(cifar_data_path, MEAN_FILENAME) + + path = os.path.normpath(os.path.join(abs_path, cifar_data_path)) + map_file = os.path.join(path, TRAIN_MAP_FILENAME) + mean_file = os.path.join(path, MEAN_FILENAME) if not os.path.exists(map_file) or not os.path.exists(mean_file): cifar_py3 = "" if sys.version_info.major < 3 else "_py3" @@ -41,6 +44,34 @@ def create_mb_source(features_stream_name, labels_stream_name, image_height, ImageDeserializer.mean(mean_file)]) image.map_labels(labels_stream_name, num_classes) + rc = ReaderConfig(image, epoch_size=sys.maxsize) + return rc.minibatch_source() + +def create_test_mb_source(features_stream_name, labels_stream_name, image_height, + image_width, num_channels, num_classes, cifar_data_path): + + path = os.path.normpath(os.path.join(abs_path, cifar_data_path)) + + map_file = os.path.join(path, TEST_MAP_FILENAME) + mean_file = os.path.join(path, MEAN_FILENAME) + + if not os.path.exists(map_file) or not os.path.exists(mean_file): + cifar_py3 = "" if sys.version_info.major < 3 else "_py3" + raise RuntimeError("File '%s' or '%s' do not exist. Please run CifarDownload%s.py and CifarConverter%s.py from CIFAR-10 to fetch them" % + (map_file, mean_file, cifar_py3, cifar_py3)) + + image = ImageDeserializer(map_file) + image.map_features(features_stream_name, + [ImageDeserializer.crop(crop_type='Random', ratio=0.8, + jitter_type='uniRatio'), + ImageDeserializer.scale(width=image_width, height=image_height, + channels=num_channels, interpolations='linear'), + ImageDeserializer.mean(mean_file)]) + image.map_labels(labels_stream_name, num_classes) + + rc = ReaderConfig(image, epoch_size=sys.maxsize) + return rc.minibatch_source() + def get_projection_map(out_dim, in_dim): if in_dim > out_dim: raise ValueError( @@ -109,7 +140,7 @@ def resnet_classifer(input, num_classes): return t + out_bias_params # Trains a residual network model on the Cifar image dataset -def cifar_resnet(base_path): +def cifar_resnet(base_path, debug_output=False): image_height = 32 image_width = 32 num_channels = 3 @@ -141,6 +172,7 @@ def cifar_resnet(base_path): mb_size = 32 training_progress_output_freq = 20 num_mbs = 1000 + for i in range(0, num_mbs): mb = minibatch_source.get_next_minibatch(mb_size) @@ -150,7 +182,29 @@ def cifar_resnet(base_path): features_si].m_data, label_var: mb[labels_si].m_data} trainer.train_minibatch(arguments) - print_training_progress(trainer, i, training_progress_output_freq) + if debug_output: + print_training_progress(trainer, i, training_progress_output_freq) + + test_minibatch_source = create_test_mb_source(feats_stream_name, labels_stream_name, + image_height, image_width, num_channels, num_classes, base_path) + features_si = test_minibatch_source.stream_info(feats_stream_name) + labels_si = test_minibatch_source.stream_info(labels_stream_name) + + mb_size = 64 + num_mbs = 300 + + total_error = 0.0 + for i in range(0, num_mbs): + mb = test_minibatch_source.get_next_minibatch(mb_size) + + # Specify the mapping of input variables in the model to actual + # minibatch data to be trained with + arguments = {image_input: mb[ + features_si].m_data, label_var: mb[labels_si].m_data} + error = trainer.test_minibatch(arguments) + total_error += error + + return total_error / num_mbs if __name__ == '__main__': # Specify the target device to be used for computing @@ -160,4 +214,6 @@ if __name__ == '__main__': base_path = os.path.normpath(os.path.join( *"../../../../Examples/Image/Miscellaneous/CIFAR-10/cifar-10-batches-py".split("/"))) + os.chdir(os.path.join(base_path, '..')) + cifar_resnet(base_path) diff --git a/bindings/python/examples/test/cifar_resnet_test.py b/bindings/python/examples/test/cifar_resnet_test.py new file mode 100644 index 000000000..abe085b25 --- /dev/null +++ b/bindings/python/examples/test/cifar_resnet_test.py @@ -0,0 +1,29 @@ +# Copyright (c) Microsoft. All rights reserved. + +# Licensed under the MIT license. See LICENSE.md file in the project root +# for full license information. +# ============================================================================== + +import numpy as np +import os +from cntk import DeviceDescriptor +from cntk.io import ReaderConfig, ImageDeserializer + +from examples.CifarResNet.CifarResNet import cifar_resnet + +TOLERANCE_ABSOLUTE = 1E-1 + +def test_cifar_resnet_error(device_id): + target_device = DeviceDescriptor.gpu_device(0) + DeviceDescriptor.set_default_device(target_device) + + base_path = os.path.normpath(os.path.join( + *"../../../../Examples/Image/Miscellaneous/CIFAR-10/cifar-10-batches-py".split("/"))) + + os.chdir(os.path.join(base_path, '..')) + + test_error = cifar_resnet(base_path) + expected_test_error = 0.7 + + assert np.allclose(test_error, expected_test_error, + atol=TOLERANCE_ABSOLUTE) diff --git a/bindings/python/examples/test/simple_mnist_test.py b/bindings/python/examples/test/simple_mnist_test.py index 7b26c4aa8..cbd0c693a 100644 --- a/bindings/python/examples/test/simple_mnist_test.py +++ b/bindings/python/examples/test/simple_mnist_test.py @@ -16,7 +16,7 @@ def test_simple_mnist_error(device_id): #DeviceDescriptor.set_default_device(cntk_device(device_id)) test_error = simple_mnist() - expected_test_error = 0.7 + expected_test_error = 0.09 assert np.allclose(test_error, expected_test_error, atol=TOLERANCE_ABSOLUTE)