This commit is contained in:
Ivan Rodriguez 2016-09-29 17:03:57 +02:00
Родитель 75a582f1d3
Коммит 6822bb5e0a
3 изменённых файлов: 90 добавлений и 5 удалений

Просмотреть файл

@ -18,14 +18,17 @@ from examples.common.nn import conv_bn_relu_layer, conv_bn_layer, resnet_node2,
TRAIN_MAP_FILENAME = 'train_map.txt'
MEAN_FILENAME = 'CIFAR-10_mean.xml'
TEST_MAP_FILENAME = 'test_map.txt'
# Instantiates the CNTK built-in minibatch source for reading images to be used for training the residual net
# The minibatch source is configured using a hierarchical dictionary of key:value pairs
def create_mb_source(features_stream_name, labels_stream_name, image_height,
image_width, num_channels, num_classes, cifar_data_path):
map_file = os.path.join(cifar_data_path, TRAIN_MAP_FILENAME)
mean_file = os.path.join(cifar_data_path, MEAN_FILENAME)
path = os.path.normpath(os.path.join(abs_path, cifar_data_path))
map_file = os.path.join(path, TRAIN_MAP_FILENAME)
mean_file = os.path.join(path, MEAN_FILENAME)
if not os.path.exists(map_file) or not os.path.exists(mean_file):
cifar_py3 = "" if sys.version_info.major < 3 else "_py3"
@ -41,6 +44,34 @@ def create_mb_source(features_stream_name, labels_stream_name, image_height,
ImageDeserializer.mean(mean_file)])
image.map_labels(labels_stream_name, num_classes)
rc = ReaderConfig(image, epoch_size=sys.maxsize)
return rc.minibatch_source()
def create_test_mb_source(features_stream_name, labels_stream_name, image_height,
image_width, num_channels, num_classes, cifar_data_path):
path = os.path.normpath(os.path.join(abs_path, cifar_data_path))
map_file = os.path.join(path, TEST_MAP_FILENAME)
mean_file = os.path.join(path, MEAN_FILENAME)
if not os.path.exists(map_file) or not os.path.exists(mean_file):
cifar_py3 = "" if sys.version_info.major < 3 else "_py3"
raise RuntimeError("File '%s' or '%s' do not exist. Please run CifarDownload%s.py and CifarConverter%s.py from CIFAR-10 to fetch them" %
(map_file, mean_file, cifar_py3, cifar_py3))
image = ImageDeserializer(map_file)
image.map_features(features_stream_name,
[ImageDeserializer.crop(crop_type='Random', ratio=0.8,
jitter_type='uniRatio'),
ImageDeserializer.scale(width=image_width, height=image_height,
channels=num_channels, interpolations='linear'),
ImageDeserializer.mean(mean_file)])
image.map_labels(labels_stream_name, num_classes)
rc = ReaderConfig(image, epoch_size=sys.maxsize)
return rc.minibatch_source()
def get_projection_map(out_dim, in_dim):
if in_dim > out_dim:
raise ValueError(
@ -109,7 +140,7 @@ def resnet_classifer(input, num_classes):
return t + out_bias_params
# Trains a residual network model on the Cifar image dataset
def cifar_resnet(base_path):
def cifar_resnet(base_path, debug_output=False):
image_height = 32
image_width = 32
num_channels = 3
@ -141,6 +172,7 @@ def cifar_resnet(base_path):
mb_size = 32
training_progress_output_freq = 20
num_mbs = 1000
for i in range(0, num_mbs):
mb = minibatch_source.get_next_minibatch(mb_size)
@ -150,7 +182,29 @@ def cifar_resnet(base_path):
features_si].m_data, label_var: mb[labels_si].m_data}
trainer.train_minibatch(arguments)
print_training_progress(trainer, i, training_progress_output_freq)
if debug_output:
print_training_progress(trainer, i, training_progress_output_freq)
test_minibatch_source = create_test_mb_source(feats_stream_name, labels_stream_name,
image_height, image_width, num_channels, num_classes, base_path)
features_si = test_minibatch_source.stream_info(feats_stream_name)
labels_si = test_minibatch_source.stream_info(labels_stream_name)
mb_size = 64
num_mbs = 300
total_error = 0.0
for i in range(0, num_mbs):
mb = test_minibatch_source.get_next_minibatch(mb_size)
# Specify the mapping of input variables in the model to actual
# minibatch data to be trained with
arguments = {image_input: mb[
features_si].m_data, label_var: mb[labels_si].m_data}
error = trainer.test_minibatch(arguments)
total_error += error
return total_error / num_mbs
if __name__ == '__main__':
# Specify the target device to be used for computing
@ -160,4 +214,6 @@ if __name__ == '__main__':
base_path = os.path.normpath(os.path.join(
*"../../../../Examples/Image/Miscellaneous/CIFAR-10/cifar-10-batches-py".split("/")))
os.chdir(os.path.join(base_path, '..'))
cifar_resnet(base_path)

Просмотреть файл

@ -0,0 +1,29 @@
# Copyright (c) Microsoft. All rights reserved.
# Licensed under the MIT license. See LICENSE.md file in the project root
# for full license information.
# ==============================================================================
import numpy as np
import os
from cntk import DeviceDescriptor
from cntk.io import ReaderConfig, ImageDeserializer
from examples.CifarResNet.CifarResNet import cifar_resnet
TOLERANCE_ABSOLUTE = 1E-1
def test_cifar_resnet_error(device_id):
target_device = DeviceDescriptor.gpu_device(0)
DeviceDescriptor.set_default_device(target_device)
base_path = os.path.normpath(os.path.join(
*"../../../../Examples/Image/Miscellaneous/CIFAR-10/cifar-10-batches-py".split("/")))
os.chdir(os.path.join(base_path, '..'))
test_error = cifar_resnet(base_path)
expected_test_error = 0.7
assert np.allclose(test_error, expected_test_error,
atol=TOLERANCE_ABSOLUTE)

Просмотреть файл

@ -16,7 +16,7 @@ def test_simple_mnist_error(device_id):
#DeviceDescriptor.set_default_device(cntk_device(device_id))
test_error = simple_mnist()
expected_test_error = 0.7
expected_test_error = 0.09
assert np.allclose(test_error, expected_test_error,
atol=TOLERANCE_ABSOLUTE)