Cifar test added
This commit is contained in:
Родитель
75a582f1d3
Коммит
6822bb5e0a
|
@ -18,14 +18,17 @@ from examples.common.nn import conv_bn_relu_layer, conv_bn_layer, resnet_node2,
|
|||
|
||||
TRAIN_MAP_FILENAME = 'train_map.txt'
|
||||
MEAN_FILENAME = 'CIFAR-10_mean.xml'
|
||||
TEST_MAP_FILENAME = 'test_map.txt'
|
||||
|
||||
# Instantiates the CNTK built-in minibatch source for reading images to be used for training the residual net
|
||||
# The minibatch source is configured using a hierarchical dictionary of key:value pairs
|
||||
|
||||
def create_mb_source(features_stream_name, labels_stream_name, image_height,
|
||||
image_width, num_channels, num_classes, cifar_data_path):
|
||||
map_file = os.path.join(cifar_data_path, TRAIN_MAP_FILENAME)
|
||||
mean_file = os.path.join(cifar_data_path, MEAN_FILENAME)
|
||||
|
||||
path = os.path.normpath(os.path.join(abs_path, cifar_data_path))
|
||||
map_file = os.path.join(path, TRAIN_MAP_FILENAME)
|
||||
mean_file = os.path.join(path, MEAN_FILENAME)
|
||||
|
||||
if not os.path.exists(map_file) or not os.path.exists(mean_file):
|
||||
cifar_py3 = "" if sys.version_info.major < 3 else "_py3"
|
||||
|
@ -41,6 +44,34 @@ def create_mb_source(features_stream_name, labels_stream_name, image_height,
|
|||
ImageDeserializer.mean(mean_file)])
|
||||
image.map_labels(labels_stream_name, num_classes)
|
||||
|
||||
rc = ReaderConfig(image, epoch_size=sys.maxsize)
|
||||
return rc.minibatch_source()
|
||||
|
||||
def create_test_mb_source(features_stream_name, labels_stream_name, image_height,
|
||||
image_width, num_channels, num_classes, cifar_data_path):
|
||||
|
||||
path = os.path.normpath(os.path.join(abs_path, cifar_data_path))
|
||||
|
||||
map_file = os.path.join(path, TEST_MAP_FILENAME)
|
||||
mean_file = os.path.join(path, MEAN_FILENAME)
|
||||
|
||||
if not os.path.exists(map_file) or not os.path.exists(mean_file):
|
||||
cifar_py3 = "" if sys.version_info.major < 3 else "_py3"
|
||||
raise RuntimeError("File '%s' or '%s' do not exist. Please run CifarDownload%s.py and CifarConverter%s.py from CIFAR-10 to fetch them" %
|
||||
(map_file, mean_file, cifar_py3, cifar_py3))
|
||||
|
||||
image = ImageDeserializer(map_file)
|
||||
image.map_features(features_stream_name,
|
||||
[ImageDeserializer.crop(crop_type='Random', ratio=0.8,
|
||||
jitter_type='uniRatio'),
|
||||
ImageDeserializer.scale(width=image_width, height=image_height,
|
||||
channels=num_channels, interpolations='linear'),
|
||||
ImageDeserializer.mean(mean_file)])
|
||||
image.map_labels(labels_stream_name, num_classes)
|
||||
|
||||
rc = ReaderConfig(image, epoch_size=sys.maxsize)
|
||||
return rc.minibatch_source()
|
||||
|
||||
def get_projection_map(out_dim, in_dim):
|
||||
if in_dim > out_dim:
|
||||
raise ValueError(
|
||||
|
@ -109,7 +140,7 @@ def resnet_classifer(input, num_classes):
|
|||
return t + out_bias_params
|
||||
|
||||
# Trains a residual network model on the Cifar image dataset
|
||||
def cifar_resnet(base_path):
|
||||
def cifar_resnet(base_path, debug_output=False):
|
||||
image_height = 32
|
||||
image_width = 32
|
||||
num_channels = 3
|
||||
|
@ -141,6 +172,7 @@ def cifar_resnet(base_path):
|
|||
mb_size = 32
|
||||
training_progress_output_freq = 20
|
||||
num_mbs = 1000
|
||||
|
||||
for i in range(0, num_mbs):
|
||||
mb = minibatch_source.get_next_minibatch(mb_size)
|
||||
|
||||
|
@ -150,7 +182,29 @@ def cifar_resnet(base_path):
|
|||
features_si].m_data, label_var: mb[labels_si].m_data}
|
||||
trainer.train_minibatch(arguments)
|
||||
|
||||
print_training_progress(trainer, i, training_progress_output_freq)
|
||||
if debug_output:
|
||||
print_training_progress(trainer, i, training_progress_output_freq)
|
||||
|
||||
test_minibatch_source = create_test_mb_source(feats_stream_name, labels_stream_name,
|
||||
image_height, image_width, num_channels, num_classes, base_path)
|
||||
features_si = test_minibatch_source.stream_info(feats_stream_name)
|
||||
labels_si = test_minibatch_source.stream_info(labels_stream_name)
|
||||
|
||||
mb_size = 64
|
||||
num_mbs = 300
|
||||
|
||||
total_error = 0.0
|
||||
for i in range(0, num_mbs):
|
||||
mb = test_minibatch_source.get_next_minibatch(mb_size)
|
||||
|
||||
# Specify the mapping of input variables in the model to actual
|
||||
# minibatch data to be trained with
|
||||
arguments = {image_input: mb[
|
||||
features_si].m_data, label_var: mb[labels_si].m_data}
|
||||
error = trainer.test_minibatch(arguments)
|
||||
total_error += error
|
||||
|
||||
return total_error / num_mbs
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Specify the target device to be used for computing
|
||||
|
@ -160,4 +214,6 @@ if __name__ == '__main__':
|
|||
base_path = os.path.normpath(os.path.join(
|
||||
*"../../../../Examples/Image/Miscellaneous/CIFAR-10/cifar-10-batches-py".split("/")))
|
||||
|
||||
os.chdir(os.path.join(base_path, '..'))
|
||||
|
||||
cifar_resnet(base_path)
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
# Licensed under the MIT license. See LICENSE.md file in the project root
|
||||
# for full license information.
|
||||
# ==============================================================================
|
||||
|
||||
import numpy as np
|
||||
import os
|
||||
from cntk import DeviceDescriptor
|
||||
from cntk.io import ReaderConfig, ImageDeserializer
|
||||
|
||||
from examples.CifarResNet.CifarResNet import cifar_resnet
|
||||
|
||||
TOLERANCE_ABSOLUTE = 1E-1
|
||||
|
||||
def test_cifar_resnet_error(device_id):
|
||||
target_device = DeviceDescriptor.gpu_device(0)
|
||||
DeviceDescriptor.set_default_device(target_device)
|
||||
|
||||
base_path = os.path.normpath(os.path.join(
|
||||
*"../../../../Examples/Image/Miscellaneous/CIFAR-10/cifar-10-batches-py".split("/")))
|
||||
|
||||
os.chdir(os.path.join(base_path, '..'))
|
||||
|
||||
test_error = cifar_resnet(base_path)
|
||||
expected_test_error = 0.7
|
||||
|
||||
assert np.allclose(test_error, expected_test_error,
|
||||
atol=TOLERANCE_ABSOLUTE)
|
|
@ -16,7 +16,7 @@ def test_simple_mnist_error(device_id):
|
|||
#DeviceDescriptor.set_default_device(cntk_device(device_id))
|
||||
|
||||
test_error = simple_mnist()
|
||||
expected_test_error = 0.7
|
||||
expected_test_error = 0.09
|
||||
|
||||
assert np.allclose(test_error, expected_test_error,
|
||||
atol=TOLERANCE_ABSOLUTE)
|
||||
|
|
Загрузка…
Ссылка в новой задаче