Integrate chazhang/classification_py into master
This commit is contained in:
Commit dd496e5111
@@ -4,7 +4,7 @@ command = TrainConvNet:Eval

 precision = "float"; traceLevel = 1 ; deviceId = "auto"

-rootDir = "../.." ; dataDir = "$rootDir$/DataSets/CIFAR-10" ;
+rootDir = "../../.." ; dataDir = "$rootDir$/DataSets/CIFAR-10" ;
 outputDir = "./Output" ;

 modelPath = "$outputDir$/Models/ConvNet_CIFAR10"
@@ -29,8 +29,8 @@ TrainConvNet = {
 ConvolutionalLayer {64, (3:3), pad = true} : ReLU :
 ConvolutionalLayer {64, (3:3), pad = true} : ReLU :
 MaxPoolingLayer {(3:3), stride = (2:2)} :
-DenseLayer {256} : Dropout : ReLU :
-DenseLayer {128} : Dropout : ReLU :
+DenseLayer {256} : ReLU : Dropout :
+DenseLayer {128} : ReLU : Dropout :
 LinearLayer {labelDim}
 )

@@ -58,7 +58,7 @@ TrainConvNet = {
 minibatchSize = 64

 learningRatesPerSample = 0.0015625*10:0.00046875*10:0.00015625
-momentumAsTimeConstant = 0*20:6400
+momentumAsTimeConstant = 0*20:607.44
 maxEpochs = 30
 L2RegWeight = 0.002
 dropoutRate = 0*5:0.5
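A note on the schedule syntax above: `value*count` pairs separated by `:` assign a value per epoch, so `dropoutRate = 0*5:0.5` means no dropout for the first 5 epochs and 0.5 afterwards, and `momentumAsTimeConstant = 0*20:607.44` means no momentum for 20 epochs, then a time constant of 607.44. The new constant is not arbitrary: under CNTK's time-constant convention, a time constant `tau` corresponds to a per-sample momentum of `exp(-1/tau)`. A hedged sanity check, mirroring the `-minibatch_size/np.log(0.9)` expression in the Python example added later in this commit:

```python
import math

minibatch_size = 64
tau = -minibatch_size / math.log(0.9)             # time constant equivalent to momentum 0.9 per minibatch
print(round(tau, 2))                              # 607.44, the value in the config above
print(round(math.exp(-minibatch_size / tau), 4))  # 0.9, the effective per-minibatch momentum
```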
@@ -69,8 +69,8 @@ TrainConvNet = {
 reader = {
 readerType = "CNTKTextFormatReader"
 file = "$DataDir$/Train_cntk_text.txt"
 randomize = true
 keepDataInMemory = true # cache all data in memory
 input = {
 features = { dim = 3072 ; format = "dense" }
 labels = { dim = 10 ; format = "dense" }

@@ -4,7 +4,7 @@ command = TrainConvNet:Eval

 precision = "float"; traceLevel = 1 ; deviceId = "auto"

-rootDir = "../.." ; dataDir = "$rootDir$/DataSets/CIFAR-10" ;
+rootDir = "../../.." ; dataDir = "$rootDir$/DataSets/CIFAR-10" ;
 outputDir = "./Output" ;

 modelPath = "$outputDir$/Models/ConvNet_CIFAR10_DataAug"
@@ -57,10 +57,10 @@ TrainConvNet = {
 minibatchSize = 64

 learningRatesPerSample = 0.0015625*20:0.00046875*20:0.00015625*20:0.000046875*10:0.000015625
-momentumAsTimeConstant = 0*20:600*20:6400
+momentumAsTimeConstant = 0*20:600*20:1200
 maxEpochs = 80
 L2RegWeight = 0.002
-dropoutRate = 0*5:0.5
+dropoutRate = 0.5

 numMBsToShowResult = 100
 }

@@ -4,7 +4,7 @@ command = trainNetwork:testNetwork

 precision = "float"; traceLevel = 1 ; deviceId = "auto"

-rootDir = "../.." ; dataDir = "$rootDir$/DataSets/MNIST" ;
+rootDir = "../../.." ; dataDir = "$rootDir$/DataSets/MNIST" ;
 outputDir = "./Output" ;

 modelPath = "$outputDir$/Models/ConvNet_MNIST"
@@ -0,0 +1,34 @@
# CNTK Examples: Image/Classification/ConvNet

## BrainScript

### ConvNet_MNIST.cntk

Our first example applies a CNN to the MNIST dataset. The network we use contains three convolution layers and two dense layers. Dropout is applied after the first dense layer. No data augmentation is used in this example. We start the training with no momentum, and add momentum after training for 5 epochs. Please refer to the CNTK configuration file [ConvNet_MNIST.cntk](./ConvNet_MNIST.cntk) for more details.

Run the example from the current folder using:

`cntk configFile=ConvNet_MNIST.cntk`

The network achieves an error rate of `0.5%`, which is very good considering no data augmentation is used. This accuracy is comparable to, if not better than, that of many other vanilla CNN implementations (http://yann.lecun.com/exdb/mnist/).

### ConvNet_CIFAR10.cntk

The second example applies a CNN to the CIFAR-10 dataset. The network contains four convolution layers and three dense layers. Max pooling is performed after every two convolution layers. Dropout is applied after the first two dense layers. No data augmentation is used. Please refer to the CNTK configuration file [ConvNet_CIFAR10.cntk](./ConvNet_CIFAR10.cntk) for more details.

Run the example from the current folder using:

`cntk configFile=ConvNet_CIFAR10.cntk`

The network achieves an error rate of around `18%` after 30 epochs. This is comparable to the network published by [cuda-convnet](https://code.google.com/p/cuda-convnet/), which has 18% error with no data augmentation. One difference is that we do not use a `local response normalization layer`; this layer type is now rarely used in state-of-the-art deep learning networks.

### ConvNet_CIFAR10_DataAug.cntk

The third example uses the same CNN as the previous example, but improves on it by adding data augmentation to training. For this purpose, we use the `ImageReader` instead of the `CNTKTextFormatReader` to load the data. The ImageReader currently supports crop, flip, scale, color jittering, and mean subtraction.
For a reference on the image reader and its transforms, please check [here](https://github.com/Microsoft/CNTK/wiki/Image-reader).
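As a rough illustration of the same transform pipeline, here is a condensed sketch of the `create_reader` function from the `ConvNet_CIFAR10_DataAug.py` file added in this commit (the function name and paths here are placeholders):

```python
from cntk.io import MinibatchSource, ImageDeserializer, StreamDef, StreamDefs

def make_image_reader(map_file, mean_file, train, num_classes=10):
    transforms = []
    if train:  # random crop with jitter only at training time
        transforms += [ImageDeserializer.crop(crop_type='Random', ratio=0.8, jitter_type='uniRatio')]
    transforms += [
        ImageDeserializer.scale(width=32, height=32, channels=3, interpolations='linear'),
        ImageDeserializer.mean(mean_file)  # per-pixel mean subtraction
    ]
    return MinibatchSource(ImageDeserializer(map_file, StreamDefs(
        features = StreamDef(field='image', transforms=transforms),
        labels   = StreamDef(field='label', shape=num_classes))))
```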

Run the example from the current folder using:

`cntk configFile=ConvNet_CIFAR10_DataAug.cntk`

As seen in the CNTK configuration file [ConvNet_CIFAR10_DataAug.cntk](./ConvNet_CIFAR10_DataAug.cntk), we use a fixed crop ratio of `0.8` and scale the image to `32x32` pixels for training. Since all training images are pre-padded to `40x40` pixels, effectively we only perform a translation transform without scaling. The error rate of the network on test data is around `14%`, which is a lot better than the previous model.
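A quick check of those numbers (our arithmetic, not a statement from the original text): a crop ratio of `0.8` on a `40x40` padded image yields a `0.8 * 40 = 32` pixel crop, and scaling a `32x32` crop to `32x32` is a no-op, so the augmentation amounts to a random translation of up to `8` pixels in each direction.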

@@ -0,0 +1,141 @@
# Copyright (c) Microsoft. All rights reserved.

# Licensed under the MIT license. See LICENSE.md file in the project root
# for full license information.
# ==============================================================================

import numpy as np
import sys
import os
from cntk import Trainer, persist
from cntk.utils import *
from cntk.layers import *
from cntk.models import Sequential, LayerStack
from cntk.io import MinibatchSource, CTFDeserializer, StreamDef, StreamDefs, INFINITELY_REPEAT, FULL_DATA_SWEEP
from cntk.learner import momentum_sgd, learning_rate_schedule, momentum_schedule, momentum_as_time_constant_schedule
from cntk.ops import input_variable, cross_entropy_with_softmax, classification_error, relu, minus, element_times, constant
from _cntk_py import set_computation_network_trace_level

# Paths relative to current python file.
abs_path   = os.path.dirname(os.path.abspath(__file__))
data_path  = os.path.join(abs_path, "..", "..", "..", "Datasets", "CIFAR-10")
model_path = os.path.join(abs_path, "Models")

# Define the reader for both training and evaluation action.
def create_reader(path, is_training, input_dim, label_dim):
    return MinibatchSource(CTFDeserializer(path, StreamDefs(
        features = StreamDef(field='features', shape=input_dim),
        labels   = StreamDef(field='labels', shape=label_dim)
    )), randomize=is_training, epoch_size=INFINITELY_REPEAT if is_training else FULL_DATA_SWEEP)
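# Hedged note on the input (not in the original file): Train_cntk_text.txt uses
# the CNTK text format, roughly one sample per line with named streams, e.g.
#   |features 221 186 ... (3072 values) |labels 0 0 1 0 0 0 0 0 0 0
# matching the dense dims declared in the .cntk reader section earlier in this commit.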


# Creates and trains a feedforward classification model for CIFAR-10 images
def convnet_cifar10(debug_output=False):
    set_computation_network_trace_level(0)

    image_height = 32
    image_width  = 32
    num_channels = 3
    input_dim = image_height * image_width * num_channels
    num_output_classes = 10

    # Input variables denoting the features and label data
    input_var = input_variable((num_channels, image_height, image_width), np.float32)
    label_var = input_variable(num_output_classes, np.float32)

    # Instantiate the feedforward classification model
    input_removemean = minus(input_var, constant(128))
    scaled_input = element_times(constant(0.00390625), input_removemean)
    with default_options(activation=relu, pad=True):
        z = Sequential([
            LayerStack(2, lambda: [
                Convolution((3,3), 64),
                Convolution((3,3), 64),
                MaxPooling((3,3), (2,2))
            ]),
            LayerStack(2, lambda i: [
                Dense([256,128][i]),
                Dropout(0.5)
            ]),
            Dense(num_output_classes, activation=None)
        ])(scaled_input)

    ce = cross_entropy_with_softmax(z, label_var)
    pe = classification_error(z, label_var)

    reader_train = create_reader(os.path.join(data_path, 'Train_cntk_text.txt'), True, input_dim, num_output_classes)

    # training config
    epoch_size = 50000  # for now we manually specify epoch size
    minibatch_size = 64

    # Set learning parameters
    lr_per_sample = [0.0015625]*10 + [0.00046875]*10 + [0.00015625]
    lr_schedule = learning_rate_schedule(lr_per_sample, epoch_size=epoch_size)
    momentum_time_constant = [0]*20 + [-minibatch_size/np.log(0.9)]  # no momentum for 20 epochs, then ~607.44 (momentum 0.9 per minibatch)
    mm_schedule = momentum_as_time_constant_schedule(momentum_time_constant, epoch_size=epoch_size)
    l2_reg_weight = 0.002

    # Instantiate the trainer object to drive the model training
    learner = momentum_sgd(z.parameters, lr_schedule, mm_schedule, l2_regularization_weight=l2_reg_weight)
    trainer = Trainer(z, ce, pe, learner)

    # define mapping from reader streams to network inputs
    input_map = {
        input_var: reader_train.streams.features,
        label_var: reader_train.streams.labels
    }

    log_number_of_parameters(z) ; print()
    progress_printer = ProgressPrinter(tag='Training')

    # Get minibatches of images to train with and perform model training
    max_epochs = 30
    for epoch in range(max_epochs):                      # loop over epochs
        sample_count = 0
        while sample_count < epoch_size:                 # loop over minibatches in the epoch
            data = reader_train.next_minibatch(min(minibatch_size, epoch_size - sample_count), input_map=input_map)  # fetch minibatch
            trainer.train_minibatch(data)                # update model with it
            sample_count += data[label_var].num_samples  # count samples processed so far
            progress_printer.update_with_trainer(trainer, with_metric=True)  # log progress
        progress_printer.epoch_summary(with_metric=True)
        persist.save_model(z, os.path.join(model_path, "ConvNet_CIFAR10_{}.dnn".format(epoch)))

    # Load test data
    reader_test = create_reader(os.path.join(data_path, 'Test_cntk_text.txt'), False, input_dim, num_output_classes)

    input_map = {
        input_var: reader_test.streams.features,
        label_var: reader_test.streams.labels
    }

    # Test data for trained model
    epoch_size = 10000
    minibatch_size = 16

    # process minibatches and evaluate the model
    metric_numer = 0
    metric_denom = 0
    sample_count = 0
    minibatch_index = 0

    while sample_count < epoch_size:
        current_minibatch = min(minibatch_size, epoch_size - sample_count)
        # Fetch the next test minibatch
        data = reader_test.next_minibatch(current_minibatch, input_map=input_map)
        # accumulate the error on this minibatch
        metric_numer += trainer.test_minibatch(data) * current_minibatch
        metric_denom += current_minibatch
        # Keep track of the number of samples processed so far.
        sample_count += data[label_var].num_samples
        minibatch_index += 1

    print("")
    print("Final Results: Minibatch[1-{}]: errs = {:0.2f}% * {}".format(minibatch_index+1, (metric_numer*100.0)/metric_denom, metric_denom))
    print("")

    return metric_numer/metric_denom

if __name__=='__main__':
    convnet_cifar10()

@@ -0,0 +1,150 @@
# Copyright (c) Microsoft. All rights reserved.

# Licensed under the MIT license. See LICENSE.md file in the project root
# for full license information.
# ==============================================================================

import os
import math
import numpy as np
from cntk.utils import *
from cntk.layers import *
from cntk.models import Sequential, LayerStack
from cntk.ops import input_variable, cross_entropy_with_softmax, classification_error, relu, element_times, constant
from cntk.io import MinibatchSource, ImageDeserializer, StreamDef, StreamDefs
from cntk import Trainer, persist, cntk_py
from cntk.learner import momentum_sgd, learning_rate_schedule, momentum_schedule, momentum_as_time_constant_schedule
from _cntk_py import set_computation_network_trace_level

# Paths relative to current python file.
abs_path   = os.path.dirname(os.path.abspath(__file__))
data_path  = os.path.join(abs_path, "..", "..", "..", "Datasets", "CIFAR-10")
model_path = os.path.join(abs_path, "Models")

# model dimensions
image_height = 32
image_width  = 32
num_channels = 3  # RGB
num_classes  = 10

# Define the reader for both training and evaluation action.
def create_reader(map_file, mean_file, train, distributed_communicator=None):
    if not os.path.exists(map_file) or not os.path.exists(mean_file):
        raise RuntimeError("File '%s' or '%s' does not exist. Please run install_cifar10.py from DataSets/CIFAR-10 to fetch them" %
                           (map_file, mean_file))

    # transformation pipeline for the features has jitter/crop only when training
    transforms = []
    if train:
        transforms += [
            ImageDeserializer.crop(crop_type='Random', ratio=0.8, jitter_type='uniRatio')  # train uses jitter
        ]
    transforms += [
        ImageDeserializer.scale(width=image_width, height=image_height, channels=num_channels, interpolations='linear'),
        ImageDeserializer.mean(mean_file)
    ]
    # deserializer
    return MinibatchSource(ImageDeserializer(map_file, StreamDefs(
        features = StreamDef(field='image', transforms=transforms),  # first column in map file is referred to as 'image'
        labels   = StreamDef(field='label', shape=num_classes))),    # and second as 'label'
        distributed_communicator=distributed_communicator)
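# A hedged note on inputs (not in the original file): train_map.txt / test_map.txt
# are plain-text map files where each line pairs an image path with a numeric
# label (tab-separated), e.g.
#   .../train/00001.png<TAB>5
# while CIFAR-10_mean.xml stores the per-pixel means subtracted by the transform above.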


# Train and evaluate the network.
def convnet_cifar10_dataaug(reader_train, reader_test):
    set_computation_network_trace_level(0)

    # Input variables denoting the features and label data
    input_var = input_variable((num_channels, image_height, image_width))
    label_var = input_variable((num_classes))

    # apply model to input
    scaled_input = element_times(constant(0.00390625), input_var)
    with default_options(activation=relu, pad=True):
        z = Sequential([
            LayerStack(2, lambda: [
                Convolution((3,3), 64),
                Convolution((3,3), 64),
                MaxPooling((3,3), (2,2))
            ]),
            LayerStack(2, lambda i: [
                Dense([256,128][i]),
                Dropout(0.5)
            ]),
            Dense(num_classes, activation=None)
        ])(scaled_input)

    # loss and metric
    ce = cross_entropy_with_softmax(z, label_var)
    pe = classification_error(z, label_var)

    # training config
    epoch_size = 50000  # for now we manually specify epoch size
    minibatch_size = 64

    # Set learning parameters
    lr_per_sample = [0.0015625]*20 + [0.00046875]*20 + [0.00015625]*20 + [0.000046875]*10 + [0.000015625]
    lr_schedule = learning_rate_schedule(lr_per_sample, epoch_size=epoch_size)
    momentum_time_constant = [0]*20 + [600]*20 + [1200]
    mm_schedule = momentum_as_time_constant_schedule(momentum_time_constant, epoch_size=epoch_size)
    l2_reg_weight = 0.002

    # trainer object
    learner = momentum_sgd(z.parameters, lr_schedule, mm_schedule,
                           l2_regularization_weight=l2_reg_weight)
    trainer = Trainer(z, ce, pe, learner)

    # define mapping from reader streams to network inputs
    input_map = {
        input_var: reader_train.streams.features,
        label_var: reader_train.streams.labels
    }

    log_number_of_parameters(z) ; print()
    progress_printer = ProgressPrinter(tag='Training')

    # perform model training
    max_epochs = 80
    for epoch in range(max_epochs):                      # loop over epochs
        sample_count = 0
        while sample_count < epoch_size:                 # loop over minibatches in the epoch
            data = reader_train.next_minibatch(min(minibatch_size, epoch_size-sample_count), input_map=input_map)  # fetch minibatch
            trainer.train_minibatch(data)                # update model with it
            sample_count += data[label_var].num_samples  # count samples processed so far
            progress_printer.update_with_trainer(trainer, with_metric=True)  # log progress
        progress_printer.epoch_summary(with_metric=True)
        persist.save_model(z, os.path.join(model_path, "ConvNet_CIFAR10_DataAug_{}.dnn".format(epoch)))

    ### Evaluation action
    epoch_size = 10000
    minibatch_size = 16

    # process minibatches and evaluate the model
    metric_numer = 0
    metric_denom = 0
    sample_count = 0
    minibatch_index = 0

    while sample_count < epoch_size:
        current_minibatch = min(minibatch_size, epoch_size - sample_count)
        # Fetch the next test minibatch
        data = reader_test.next_minibatch(current_minibatch, input_map=input_map)
        # accumulate the error on this minibatch
        metric_numer += trainer.test_minibatch(data) * current_minibatch
        metric_denom += current_minibatch
        # Keep track of the number of samples processed so far.
        sample_count += data[label_var].num_samples
        minibatch_index += 1

    print("")
    print("Final Results: Minibatch[1-{}]: errs = {:0.2f}% * {}".format(minibatch_index+1, (metric_numer*100.0)/metric_denom, metric_denom))
    print("")

    return metric_numer/metric_denom

if __name__=='__main__':
    reader_train = create_reader(os.path.join(data_path, 'train_map.txt'), os.path.join(data_path, 'CIFAR-10_mean.xml'), True)
    reader_test  = create_reader(os.path.join(data_path, 'test_map.txt'), os.path.join(data_path, 'CIFAR-10_mean.xml'), False)

    convnet_cifar10_dataaug(reader_train, reader_test)

@@ -0,0 +1,130 @@
# Copyright (c) Microsoft. All rights reserved.

# Licensed under the MIT license. See LICENSE.md file in the project root
# for full license information.
# ==============================================================================

import numpy as np
import sys
import os
from cntk import Trainer, persist
from cntk.utils import *
from cntk.layers import *
from cntk.io import MinibatchSource, CTFDeserializer, StreamDef, StreamDefs, INFINITELY_REPEAT, FULL_DATA_SWEEP
from cntk.learner import momentum_sgd, learning_rate_schedule
from cntk.ops import input_variable, cross_entropy_with_softmax, classification_error, relu, element_times, constant

# Paths relative to current python file.
abs_path   = os.path.dirname(os.path.abspath(__file__))
data_path  = os.path.join(abs_path, "..", "..", "..", "Datasets", "MNIST")
model_path = os.path.join(abs_path, "Models")

# Define the reader for both training and evaluation action.
def create_reader(path, is_training, input_dim, label_dim):
    return MinibatchSource(CTFDeserializer(path, StreamDefs(
        features = StreamDef(field='features', shape=input_dim),
        labels   = StreamDef(field='labels', shape=label_dim)
    )), randomize=is_training, epoch_size=INFINITELY_REPEAT if is_training else FULL_DATA_SWEEP)


# Creates and trains a feedforward classification model for MNIST images
def convnet_mnist(debug_output=False):
    image_height = 28
    image_width  = 28
    num_channels = 1
    input_dim = image_height * image_width * num_channels
    num_output_classes = 10

    # Input variables denoting the features and label data
    input_var = input_variable((num_channels, image_height, image_width), np.float32)
    label_var = input_variable(num_output_classes, np.float32)

    # Instantiate the feedforward classification model
    scaled_input = element_times(constant(0.00390625), input_var)
    with default_options(activation=relu, pad=False):
        conv1 = Convolution((5,5), 32, pad=True)(scaled_input)
        pool1 = MaxPooling((3,3), (2,2))(conv1)
        conv2 = Convolution((3,3), 48)(pool1)
        pool2 = MaxPooling((3,3), (2,2))(conv2)
        conv3 = Convolution((3,3), 64)(pool2)
        f4    = Dense(96)(conv3)
        drop4 = Dropout(0.5)(f4)
        z     = Dense(num_output_classes, activation=None)(drop4)

    ce = cross_entropy_with_softmax(z, label_var)
    pe = classification_error(z, label_var)

    reader_train = create_reader(os.path.join(data_path, 'Train-28x28_cntk_text.txt'), True, input_dim, num_output_classes)

    # training config
    epoch_size = 60000  # for now we manually specify epoch size
    minibatch_size = 128

    # Set learning parameters
    lr_per_sample = [0.001]*10 + [0.0005]*10 + [0.0001]
    lr_schedule = learning_rate_schedule(lr_per_sample, epoch_size)
    momentum_time_constant = [0]*5 + [1024]  # no momentum for 5 epochs, then time constant 1024

    # Instantiate the trainer object to drive the model training
    learner = momentum_sgd(z.parameters, lr_schedule, momentum_time_constant)
    trainer = Trainer(z, ce, pe, learner)

    # define mapping from reader streams to network inputs
    input_map = {
        input_var: reader_train.streams.features,
        label_var: reader_train.streams.labels
    }

    log_number_of_parameters(z) ; print()
    progress_printer = ProgressPrinter(tag='Training')

    # Get minibatches of images to train with and perform model training
    max_epochs = 40
    for epoch in range(max_epochs):                      # loop over epochs
        sample_count = 0
        while sample_count < epoch_size:                 # loop over minibatches in the epoch
            data = reader_train.next_minibatch(min(minibatch_size, epoch_size - sample_count), input_map=input_map)  # fetch minibatch
            trainer.train_minibatch(data)                # update model with it
            sample_count += data[label_var].num_samples  # count samples processed so far
            progress_printer.update_with_trainer(trainer, with_metric=True)  # log progress
        progress_printer.epoch_summary(with_metric=True)
        persist.save_model(z, os.path.join(model_path, "ConvNet_MNIST_{}.dnn".format(epoch)))

    # Load test data
    reader_test = create_reader(os.path.join(data_path, 'Test-28x28_cntk_text.txt'), False, input_dim, num_output_classes)

    input_map = {
        input_var: reader_test.streams.features,
        label_var: reader_test.streams.labels
    }

    # Test data for trained model
    epoch_size = 10000
    minibatch_size = 1024

    # process minibatches and evaluate the model
    metric_numer = 0
    metric_denom = 0
    sample_count = 0
    minibatch_index = 0

    while sample_count < epoch_size:
        current_minibatch = min(minibatch_size, epoch_size - sample_count)
        # Fetch the next test minibatch
        data = reader_test.next_minibatch(current_minibatch, input_map=input_map)
        # accumulate the error on this minibatch
        metric_numer += trainer.test_minibatch(data) * current_minibatch
        metric_denom += current_minibatch
        # Keep track of the number of samples processed so far.
        sample_count += data[label_var].num_samples
        minibatch_index += 1

    print("")
    print("Final Results: Minibatch[1-{}]: errs = {:0.2f}% * {}".format(minibatch_index+1, (metric_numer*100.0)/metric_denom, metric_denom))
    print("")

    return metric_numer/metric_denom

if __name__=='__main__':
    convnet_mnist()

@@ -0,0 +1,34 @@
# CNTK Examples: Image/Classification/ConvNet

## Python

### ConvNet_MNIST.py

Our first example applies a CNN to the MNIST dataset. The network we use contains three convolution layers and two dense layers. Dropout is applied after the first dense layer. No data augmentation is used in this example.

Run the example from the current folder using:

`python ConvNet_MNIST.py`

The network achieves an error rate of around `0.5%`, which is very good considering no data augmentation is used. This accuracy is comparable to, if not better than, that of many other vanilla CNN implementations (http://yann.lecun.com/exdb/mnist/).

### ConvNet_CIFAR10.py

The second example applies a CNN to the CIFAR-10 dataset. The network contains four convolution layers and three dense layers. Max pooling is performed after every two convolution layers. Dropout is applied after the first two dense layers. No data augmentation is used.

Run the example from the current folder using:

`python ConvNet_CIFAR10.py`

The network achieves an error rate of around `18%` after 30 epochs. This is comparable to the network published by [cuda-convnet](https://code.google.com/p/cuda-convnet/), which has 18% error with no data augmentation. One difference is that we do not use a `local response normalization layer`; this layer type is now rarely used in state-of-the-art deep learning networks.

### ConvNet_CIFAR10_DataAug.py

The third example uses the same CNN as the previous example, but improves on it by adding data augmentation to training. For this purpose, we use the `ImageDeserializer` instead of the `CTFDeserializer` to load the data. The image deserializer currently supports crop, flip, scale, color jittering, and mean subtraction.
For a reference on the image deserializer and its transforms, please check [here](https://www.cntk.ai/pythondocs/cntk.io.html?highlight=imagedeserializer#cntk.io.ImageDeserializer).

Run the example from the current folder using:

`python ConvNet_CIFAR10_DataAug.py`

We use a fixed crop ratio of `0.8` and scale the image to `32x32` pixels for training. Since all training images are pre-padded to `40x40` pixels, effectively we only perform a translation transform without scaling. The error rate of the network on test data is around `14%`, which is a lot better than the previous model.
@@ -4,7 +4,7 @@

 |Data:     |The MNIST dataset (http://yann.lecun.com/exdb/mnist/) of handwritten digits and the CIFAR-10 dataset (http://www.cs.toronto.edu/~kriz/cifar.html) for image classification.
 |:---------|:---
-|Purpose   |This folder contains a number of examples that demonstrate the usage of BrainScript to define convolutional neural networks for image classification.
+|Purpose   |This folder contains a number of examples that demonstrate how to use CNTK to define convolutional neural networks for image classification.
 |Network   |Convolutional neural networks.
 |Training  |Stochastic gradient descent with momentum.
 |Comments  |See below.

@@ -13,39 +13,14 @@

 ### Getting the data

-we use the MNIST and CIFAR-10 datasets to demonstrate how to train a `convolutional neural network (CNN)`. CNN has been one of the most popular neural networks for image-related tasks. A very well-known early work on CNN is the [LeNet](http://yann.lecun.com/exdb/publis/pdf/lecun-01a.pdf). In 2012 Alex Krizhevsky, Ilya Sutskever, and Geoffrey Hinton won the ILSVRC-2012 competition using a [CNN architecture](https://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf). And most state-of-the-art neural networks on image classification tasks today adopts a modified CNN architecture, such as [VGG](../VGG), [GoogLeNet](../GoogLeNet), [ResNet](../ResNet), etc.
+We use the MNIST and CIFAR-10 datasets to demonstrate how to train a `convolutional neural network (CNN)`. CNN has been one of the most popular neural networks for image-related tasks. A very well-known early work on CNN is the [LeNet](http://yann.lecun.com/exdb/publis/pdf/lecun-01a.pdf). In 2012 Alex Krizhevsky, Ilya Sutskever, and Geoffrey Hinton won the ILSVRC-2012 competition using a CNN architecture, [AlexNet](https://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf). And most state-of-the-art neural networks on image classification tasks today adopt a modified CNN architecture, such as [VGG](../VGG), [GoogLeNet](../GoogLeNet), [ResNet](../ResNet), etc.

 MNIST and CIFAR-10 datasets are not included in the CNTK distribution but can be easily downloaded and converted by following the instructions in [DataSets/MNIST](../../DataSets/MNIST) and [DataSets/CIFAR-10](../../DataSets/CIFAR-10). We recommend you keep the downloaded data in the respective folders, as the configuration files in this folder assume that by default.

 ## Details

-### ConvNet_MNIST.cntk
+We offer multiple CNN examples, including one for the MNIST dataset, and two for the CIFAR-10 dataset (one with and one without data augmentation). For details, please click the respective links below.

-Our first example applies CNN on the MNIST dataset. The network we use contains three convolution layers and two dense layers. Dropout is applied after the first dense layer. No data augmentation is used in this example. We start the training with no momentum, and add momentum after training for 5 epochs. Please refer to the cntk configuration file [ConvNet_MNIST.cntk](./ConvNet_MNIST.cntk) for more details.
+### [Python](./Python)

-Run the example from the current folder using:
-
-`cntk configFile=ConvNet_MNIST.cntk`
-
-The network achieves an error rate of `0.5%`, which is very good considering no data augmentation is used. This accuracy is comparable, if not better, than many other vanilla CNN implementations (http://yann.lecun.com/exdb/mnist/).
-
-### ConvNet_CIFAR10.cntk
-
-The second exmaple applies CNN on the CIFAR-10 dataset. The network contains four convolution layers and three dense layers. Max pooling is conducted for every two convolution layers. Dropout is applied after the first two dense layers. No data augmentation is used. Please refer to the cntk configuration file [ConvNet_CIFAR10.cntk](./ConvNet_CIFAR10.cntk) for more details.
-
-Run the example from the current folder using:
-
-`cntk configFile=ConvNet_CIFAR10.cntk`
-
-The network achieves an error rate of `18.51%` after 30 epochs. This is comparable to the network published by [cuda-convnet](https://code.google.com/p/cuda-convnet/), which has 18% error with no data augmentation. One difference is that we do not use a `local response normalization layer`. This layer type is now rarely used in most state-of-the-art deep learning networks.
-
-### ConvNet_CIFAR10_DataAug.cntk
-
-The third example uses the same CNN as the previous example, but it improves by adding data augmentation to training. For this purpose, we use the `ImageReader` instead of the `CNTKTextFormatReader` to load the data. The ImageReader currently supports crop, flip, scale, color jittering, and mean subtraction.
-For a reference on image reader and transforms, please check [here](https://github.com/Microsoft/CNTK/wiki/Image-reader).
-
-Run the example from the current folder using:
-
-`cntk configFile=ConvNet_CIFAR10_DataAug.cntk`
-
-As seen in the cntk configuration file [ConvNet_CIFAR10_DataAug.cntk](./ConvNet_CIFAR10_DataAug.cntk), we use a fix crop ratio of `0.8` and scale the image to `32x32` pixels for training. Since all training images are pre-padded to `40x40` pixels, effectively we only perfrom translation transform without scaling. The accuracy of the network on test data is `14.39%`, which is a lot better than the previous model.
+### [BrainScript](./BrainScript)

@@ -0,0 +1,42 @@
# CNTK Examples: Image/Classification/ResNet

## BrainScript

### ResNet20_CIFAR10.cntk

Our first example applies a relatively shallow ResNet to the CIFAR-10 dataset. We strictly follow the [ResNet paper](http://arxiv.org/abs/1512.03385) for the network architecture. That is, the network has a first layer of `3x3` convolutions, followed by `6n` layers of `3x3` convolutions on feature maps of size `{32, 16, 8}` respectively, with `2n` layers for each feature map size. Note that for ResNet20 we have `n=3`. The network ends with a global average pooling, a 10-way fully-connected layer, and softmax. [Batch normalization](https://arxiv.org/abs/1502.03167) is applied everywhere except the last fully-connected layer.
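As a quick sanity check on the naming (our arithmetic, not text from the paper): counting the first convolution, the `6n` stacked layers, and the final fully-connected layer gives `1 + 6n + 1 = 6n + 2` weight layers, so `n=3` yields ResNet20 and `n=18` yields ResNet110.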

We use a fixed crop ratio of `0.8` and scale the image to `32x32` pixels for training. Since all training images are pre-padded to `40x40` pixels, effectively we only perform a translation transform without scaling. Run the example from the current folder using:

`cntk configFile=ResNet20_CIFAR10.cntk`

The network achieves an error rate of about `8.2%`, which is lower than the number reported in the original paper.

### ResNet110_CIFAR10.cntk

In this example we increase the depth of the ResNet to 110 layers. That is, we set `n=18`. Only very minor changes are made to the CNTK configuration file. To run this example, use:

`cntk configFile=ResNet110_CIFAR10.cntk`

The network achieves an error rate of about `6.3%`.

### ResNet50_ImageNet1K.cntk

This is an example using a 50-layer ResNet to train on the ILSVRC2012 dataset. Compared with the CIFAR-10 examples, we introduce bottleneck blocks to reduce the amount of computation: the two `3x3` convolutions are replaced by a `1x1` convolution, bottlenecked to 1/4 of the feature maps, followed by a `3x3` convolution, and then a `1x1` convolution again with the same number of feature maps as the input.
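To make the bottleneck idea concrete, here is a hedged Python sketch (not part of this commit) built from the `conv_bn`/`conv_bn_relu` helpers defined in `resnet_models.py` later in this commit; the function name and the identity shortcut are illustrative only:

```python
def resnet_bottleneck(input, out_num_filters, inter_num_filters):
    c1 = conv_bn_relu(input, (1,1), inter_num_filters)  # 1x1: reduce to 1/4 of the feature maps
    c2 = conv_bn_relu(c1, (3,3), inter_num_filters)     # 3x3: convolve on the reduced maps
    c3 = conv_bn(c2, (1,1), out_num_filters)            # 1x1: restore the output feature maps
    return relu(c3 + input)                             # identity shortcut, as in resnet_basic
```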

Run the example from the current folder using:

`cntk configFile=ResNet50_ImageNet1K.cntk`

### ResNet101_ImageNet1K.cntk

Increase the depth of the ResNet to 101 layers:

`cntk configFile=ResNet101_ImageNet1K.cntk`

### ResNet152_ImageNet1K.cntk

Further increase the depth of the ResNet to 152 layers:

`cntk configFile=ResNet152_ImageNet1K.cntk`

@@ -8,6 +8,7 @@ precision = "float"; traceLevel = 1 ; deviceId = "auto"
 rootDir = "." ; configDir = "$RootDir$" ; dataDir = "$RootDir$" ;
 outputDir = "$rootDir$/Output" ;

+meanDir = "$rootDir$/../../../DataSets/ImageNet"
 modelPath = "$outputDir$/Models/ResNet_101"
 stderr = "$outputDir$/ResNet_101_BS_out"

@@ -113,7 +114,7 @@ TrainNetwork = {
 jitterType = "UniRatio"
 cropRatio = 0.46666:0.875
 hflip = true
-meanFile = "$ConfigDir$/ImageNet1K_mean.xml"
+meanFile = "$meanDir$/ImageNet1K_mean.xml"
 }
 labels = {
 labelDim = 1000

@@ -130,7 +131,7 @@ TrainNetwork = {
 channels = 3
 cropType = "Center"
 cropRatio = 0.875
-meanFile = "$ConfigDir$/ImageNet1K_mean.xml"
+meanFile = "$meanDir$/ImageNet1K_mean.xml"
 }
 labels = {
 labelDim = 1000

@@ -158,7 +159,7 @@ BNStatistics = {
 hflip = true
 cropRatio = 0.46666:0.875
 jitterType = "UniRatio"
-meanFile = "$ConfigDir$/ImageNet1K_mean.xml"
+meanFile = "$meanDir$/ImageNet1K_mean.xml"
 }
 labels = {
 labelDim = 1000

@@ -182,7 +183,7 @@ Eval = {
 channels = 3
 cropType = "Center"
 cropRatio = 0.875
-meanFile = "$ConfigDir$/ImageNet1K_mean.xml"
+meanFile = "$meanDir$/ImageNet1K_mean.xml"
 }
 labels = {
 labelDim = 1000

@@ -4,7 +4,7 @@ command = TrainConvNet:Eval

 precision = "float"; traceLevel = 1 ; deviceId = "auto"

-rootDir = "../.." ; configDir = "./" ; dataDir = "$rootDir$/DataSets/CIFAR-10" ;
+rootDir = "../../.." ; configDir = "./" ; dataDir = "$rootDir$/DataSets/CIFAR-10" ;
 outputDir = "./Output" ;

 modelPath = "$outputDir$/Models/ResNet110_CIFAR10_DataAug"

@@ -8,6 +8,7 @@ precision = "float"; traceLevel = 1 ; deviceId = "auto"
 rootDir = "." ; configDir = "$RootDir$" ; dataDir = "$RootDir$" ;
 outputDir = "$rootDir$/Output" ;

+meanDir = "$rootDir$/../../../DataSets/ImageNet"
 modelPath = "$outputDir$/Models/ResNet_152"
 stderr = "$outputDir$/ResNet_152_BS_out"

@@ -113,7 +114,7 @@ TrainNetwork = {
 jitterType = "UniRatio"
 cropRatio = 0.46666:0.875
 hflip = true
-meanFile = "$ConfigDir$/ImageNet1K_mean.xml"
+meanFile = "$meanDir$/ImageNet1K_mean.xml"
 }
 labels = {
 labelDim = 1000

@@ -130,7 +131,7 @@ TrainNetwork = {
 channels = 3
 cropType = "Center"
 cropRatio = 0.875
-meanFile = "$ConfigDir$/ImageNet1K_mean.xml"
+meanFile = "$meanDir$/ImageNet1K_mean.xml"
 }
 labels = {
 labelDim = 1000

@@ -158,7 +159,7 @@ BNStatistics = {
 hflip = true
 cropRatio = 0.46666:0.875
 jitterType = "UniRatio"
-meanFile = "$ConfigDir$/ImageNet1K_mean.xml"
+meanFile = "$meanDir$/ImageNet1K_mean.xml"
 }
 labels = {
 labelDim = 1000

@@ -182,7 +183,7 @@ Eval = {
 channels = 3
 cropType = "Center"
 cropRatio = 0.875
-meanFile = "$ConfigDir$/ImageNet1K_mean.xml"
+meanFile = "$meanDir$/ImageNet1K_mean.xml"
 }
 labels = {
 labelDim = 1000

@@ -4,7 +4,7 @@ command = TrainConvNet:Eval

 precision = "float"; traceLevel = 1 ; deviceId = "auto"

-rootDir = "../.." ; configDir = "./" ; dataDir = "$rootDir$/DataSets/CIFAR-10" ;
+rootDir = "../../.." ; configDir = "./" ; dataDir = "$rootDir$/DataSets/CIFAR-10" ;
 outputDir = "./Output" ;

 modelPath = "$outputDir$/Models/ResNet20_CIFAR10_DataAug"

@@ -8,6 +8,7 @@ precision = "float"; traceLevel = 1 ; deviceId = "auto"
 rootDir = "." ; configDir = "$RootDir$" ; dataDir = "$RootDir$" ;
 outputDir = "$rootDir$/Output" ;

+meanDir = "$rootDir$/../../../DataSets/ImageNet"
 modelPath = "$outputDir$/Models/ResNet_50"
 stderr = "$outputDir$/ResNet_50_BS_out"

@@ -113,7 +114,7 @@ TrainNetwork = {
 jitterType = "UniRatio"
 cropRatio = 0.46666:0.875
 hflip = true
-meanFile = "$ConfigDir$/ImageNet1K_mean.xml"
+meanFile = "$meanDir$/ImageNet1K_mean.xml"
 }
 labels = {
 labelDim = 1000

@@ -130,7 +131,7 @@ TrainNetwork = {
 channels = 3
 cropType = "Center"
 cropRatio = 0.875
-meanFile = "$ConfigDir$/ImageNet1K_mean.xml"
+meanFile = "$meanDir$/ImageNet1K_mean.xml"
 }
 labels = {
 labelDim = 1000

@@ -159,7 +160,7 @@ BNStatistics = {
 hflip = true
 cropRatio = 0.46666:0.875
 jitterType = "UniRatio"
-meanFile = "$ConfigDir$/ImageNet1K_mean.xml"
+meanFile = "$meanDir$/ImageNet1K_mean.xml"
 }
 labels = {
 labelDim = 1000

@@ -183,7 +184,7 @@ Eval = {
 channels = 3
 cropType = "Center"
 cropRatio = 0.875
-meanFile = "$ConfigDir$/ImageNet1K_mean.xml"
+meanFile = "$meanDir$/ImageNet1K_mean.xml"
 }
 labels = {
 labelDim = 1000

@@ -0,0 +1,14 @@
# CNTK Examples: Image/Classification/ResNet

## Python

### TrainResNet_CIFAR10.py

This example applies ResNet to the CIFAR-10 dataset. We strictly follow the [ResNet paper](http://arxiv.org/abs/1512.03385) for the network architecture. That is, the network has a first layer of `3x3` convolutions, followed by `6n` layers of `3x3` convolutions on feature maps of size `{32, 16, 8}` respectively, with `2n` layers for each feature map size. For ResNet20 we have `n=3`; for ResNet110 we have `n=18`. The network ends with a global average pooling, a 10-way fully-connected layer, and softmax. [Batch normalization](https://arxiv.org/abs/1502.03167) is applied everywhere except the last fully-connected layer.

We use a fixed crop ratio of `0.8` and scale the image to `32x32` pixels for training. Since all training images are pre-padded to `40x40` pixels, effectively we only perform a translation transform without scaling. Run the example from the current folder using:

`python TrainResNet_CIFAR10.py resnet20`
`python TrainResNet_CIFAR10.py resnet110`

for ResNet20 and ResNet110, respectively. The ResNet20 network achieves an error rate of about `8.2%`, and the ResNet110 network achieves an error rate of about `6.3%`.

@@ -0,0 +1,154 @@
# Copyright (c) Microsoft. All rights reserved.

# Licensed under the MIT license. See LICENSE.md file in the project root
# for full license information.
# ==============================================================================

import os
import sys
import math
import numpy as np

from cntk.utils import *
from cntk.ops import input_variable, cross_entropy_with_softmax, classification_error
from cntk.io import MinibatchSource, ImageDeserializer, StreamDef, StreamDefs
from cntk import Trainer, persist, cntk_py
from cntk.learner import momentum_sgd, learning_rate_schedule, momentum_as_time_constant_schedule
from _cntk_py import set_computation_network_trace_level

from resnet_models import *

# Paths relative to current python file.
abs_path   = os.path.dirname(os.path.abspath(__file__))
data_path  = os.path.join(abs_path, "..", "..", "..", "Datasets", "CIFAR-10")
model_path = os.path.join(abs_path, "Models")

# model dimensions
image_height = 32
image_width  = 32
num_channels = 3  # RGB
num_classes  = 10

# Define the reader for both training and evaluation action.
def create_reader(map_file, mean_file, train):
    if not os.path.exists(map_file) or not os.path.exists(mean_file):
        raise RuntimeError("File '%s' or '%s' does not exist. Please run install_cifar10.py from DataSets/CIFAR-10 to fetch them" %
                           (map_file, mean_file))

    # transformation pipeline for the features has jitter/crop only when training
    transforms = []
    if train:
        transforms += [
            ImageDeserializer.crop(crop_type='Random', ratio=0.8, jitter_type='uniRatio')  # train uses jitter
        ]
    transforms += [
        ImageDeserializer.scale(width=image_width, height=image_height, channels=num_channels, interpolations='linear'),
        ImageDeserializer.mean(mean_file)
    ]
    # deserializer
    return MinibatchSource(ImageDeserializer(map_file, StreamDefs(
        features = StreamDef(field='image', transforms=transforms),  # first column in map file is referred to as 'image'
        labels   = StreamDef(field='label', shape=num_classes))))    # and second as 'label'


# Train and evaluate the network.
def train_and_evaluate(reader_train, reader_test, network_name):

    set_computation_network_trace_level(0)

    # Input variables denoting the features and label data
    input_var = input_variable((num_channels, image_height, image_width))
    label_var = input_variable((num_classes))

    # create model, and configure learning parameters
    if network_name == 'resnet20':
        z = create_cifar10_model(input_var, 3, num_classes)
        lr_per_mb = [1.0]*80 + [0.1]*40 + [0.01]
    elif network_name == 'resnet110':
        z = create_cifar10_model(input_var, 18, num_classes)
        lr_per_mb = [0.1]*1 + [1.0]*80 + [0.1]*40 + [0.01]  # brief warm-up epoch for the deeper network
    else:
        raise RuntimeError("Unknown model name!")

    # loss and metric
    ce = cross_entropy_with_softmax(z, label_var)
    pe = classification_error(z, label_var)

    # shared training parameters
    epoch_size = 50000  # for now we manually specify epoch size
    minibatch_size = 128
    max_epochs = 160
    momentum_time_constant = -minibatch_size/np.log(0.9)  # i.e. per-minibatch momentum 0.9
    l2_reg_weight = 0.0001

    # Set learning parameters
    lr_per_sample = [lr/minibatch_size for lr in lr_per_mb]
    lr_schedule = learning_rate_schedule(lr_per_sample, epoch_size=epoch_size)
    mm_schedule = momentum_as_time_constant_schedule(momentum_time_constant)
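    # Worked example of the conversions above (our arithmetic, for orientation):
    # lr_per_mb 1.0 at minibatch_size 128 -> 1.0/128 = 0.0078125 per sample, and
    # momentum_time_constant = -128/log(0.9) ~= 1214.9, i.e. exp(-128/1214.9) == 0.9
    # per 128-sample minibatch.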

    # trainer object
    learner = momentum_sgd(z.parameters, lr_schedule, mm_schedule,
                           l2_regularization_weight=l2_reg_weight)
    trainer = Trainer(z, ce, pe, learner)

    # define mapping from reader streams to network inputs
    input_map = {
        input_var: reader_train.streams.features,
        label_var: reader_train.streams.labels
    }

    log_number_of_parameters(z) ; print()
    progress_printer = ProgressPrinter(tag='Training')

    # perform model training
    for epoch in range(max_epochs):                      # loop over epochs
        sample_count = 0
        while sample_count < epoch_size:                 # loop over minibatches in the epoch
            data = reader_train.next_minibatch(min(minibatch_size, epoch_size-sample_count), input_map=input_map)  # fetch minibatch
            trainer.train_minibatch(data)                # update model with it
            sample_count += data[label_var].num_samples  # count samples processed so far
            progress_printer.update_with_trainer(trainer, with_metric=True)  # log progress
        progress_printer.epoch_summary(with_metric=True)
        persist.save_model(z, os.path.join(model_path, network_name + "_{}.dnn".format(epoch)))

    # Evaluation parameters
    epoch_size = 10000
    minibatch_size = 16

    # process minibatches and evaluate the model
    metric_numer = 0
    metric_denom = 0
    sample_count = 0
    minibatch_index = 0

    while sample_count < epoch_size:
        current_minibatch = min(minibatch_size, epoch_size - sample_count)
        # Fetch the next test minibatch
        data = reader_test.next_minibatch(current_minibatch, input_map=input_map)
        # accumulate the error on this minibatch
        metric_numer += trainer.test_minibatch(data) * current_minibatch
        metric_denom += current_minibatch
        # Keep track of the number of samples processed so far.
        sample_count += data[label_var].num_samples
        minibatch_index += 1

    print("")
    print("Final Results: Minibatch[1-{}]: errs = {:0.2f}% * {}".format(minibatch_index+1, (metric_numer*100.0)/metric_denom, metric_denom))
    print("")

    return metric_numer/metric_denom

if __name__=='__main__':
    reader_train = create_reader(os.path.join(data_path, 'train_map.txt'), os.path.join(data_path, 'CIFAR-10_mean.xml'), True)
    reader_test  = create_reader(os.path.join(data_path, 'test_map.txt'), os.path.join(data_path, 'CIFAR-10_mean.xml'), False)

    if ('--help' in sys.argv) or ('-h' in sys.argv):
        print("Trains a neural network on CIFAR-10 using CNTK.")
        print("Usage: %s [MODEL_NAME]" % sys.argv[0])
        print("MODEL_NAME: name of the model to be trained. Available models: 'resnet20' or 'resnet110'.")
    else:
        if len(sys.argv) != 2:
            print("Please specify MODEL_NAME.")
        else:
            network_name = sys.argv[1]
            train_and_evaluate(reader_train, reader_test, network_name)

@@ -0,0 +1,61 @@
# Copyright (c) Microsoft. All rights reserved.

# Licensed under the MIT license. See LICENSE.md file in the project root
# for full license information.
# ==============================================================================

from cntk.initializer import he_normal
from cntk.layers import AveragePooling, BatchNormalization, Convolution, Dense
from cntk.ops import element_times, relu

#
# ResNet building blocks
#
def conv_bn(input, filter_size, num_filters, strides=(1,1), init=he_normal()):
    c = Convolution(filter_size, num_filters, activation=None, init=init, pad=True, strides=strides, bias=False)(input)
    r = BatchNormalization(map_rank=1, normalization_time_constant=4096, use_cntk_engine=False)(c)
    return r

def conv_bn_relu(input, filter_size, num_filters, strides=(1,1), init=he_normal()):
    r = conv_bn(input, filter_size, num_filters, strides, init)
    return relu(r)

def resnet_basic(input, num_filters):
    c1 = conv_bn_relu(input, (3,3), num_filters)
    c2 = conv_bn(c1, (3,3), num_filters)
    p = c2 + input  # identity shortcut
    return relu(p)

def resnet_basic_inc(input, num_filters, strides=(2,2)):
    c1 = conv_bn_relu(input, (3,3), num_filters, strides)
    c2 = conv_bn(c1, (3,3), num_filters)
    s = conv_bn(input, (1,1), num_filters, strides)  # projection shortcut to match the new shape
    p = c2 + s
    return relu(p)

def resnet_basic_stack(input, num_stack_layers, num_filters):
    assert (num_stack_layers >= 0)
    l = input
    for _ in range(num_stack_layers):
        l = resnet_basic(l, num_filters)
    return l

#
# Defines the residual network model for classifying images
#
def create_cifar10_model(input, num_stack_layers, num_classes):
    c_map = [16, 32, 64]

    conv = conv_bn_relu(input, (3,3), c_map[0])
    r1 = resnet_basic_stack(conv, num_stack_layers, c_map[0])

    r2_1 = resnet_basic_inc(r1, c_map[1])
    r2_2 = resnet_basic_stack(r2_1, num_stack_layers-1, c_map[1])

    r3_1 = resnet_basic_inc(r2_2, c_map[2])
    r3_2 = resnet_basic_stack(r3_1, num_stack_layers-1, c_map[2])

    # Global average pooling and output
    pool = AveragePooling(filter_shape=(8,8))(r3_2)
    z = Dense(num_classes)(pool)
    return z
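
# Shape sanity check (our note, not from the original file): the CIFAR-10 input
# is 3x32x32; each resnet_basic_inc above uses stride (2,2), so the feature maps
# go 32x32 -> 16x16 -> 8x8, and AveragePooling(filter_shape=(8,8)) spans the
# whole final map, i.e. it acts as global average pooling.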

@@ -1,12 +1,10 @@
 # CNTK Examples: Image/Classification/ResNet

-The recipes above are in BrainScript. [For Python click here](https://github.com/Microsoft/CNTK/blob/master/bindings/python/tutorials/CNTK_201A_CIFAR-10_DataLoader.ipynb).
-
 ## Overview

 |Data: |The CIFAR-10 dataset (http://www.cs.toronto.edu/~kriz/cifar.html) and the ILSVRC2012 dataset (http://www.image-net.org/challenges/LSVRC/2012/) for image classification.
 |:---------|:---
-|Purpose |This folder contains a number of examples that demonstrate the usage of BrainScript to define residual network (http://arxiv.org/abs/1512.03385) for image classification.
+|Purpose |This folder contains a number of examples that demonstrate how to use CNTK to define residual networks (http://arxiv.org/abs/1512.03385) for image classification.
 |Network |Deep convolutional residual networks (ResNet).
 |Training |Stochastic gradient descent with momentum.
 |Comments |See below.

@@ -20,43 +18,8 @@ CIFAR-10 and ILSVRC2012 datasets are not included in the CNTK distribution. The

 ## Details

-### ResNet20_CIFAR10.cntk
+We offer multiple ResNet examples, including ResNet20 and ResNet110 for the CIFAR-10 dataset, and ResNet50, ResNet101 and ResNet152 for the ILSVRC2012 dataset (BrainScript only at this moment). For details, please click the respective links below.

-Our first example applies a relatively shallow ResNet on the CIFAR-10 dataset. We strictly follow the [ResNet paper](http://arxiv.org/abs/1512.03385) for the network architecture. That is, the network has a first layer of `3x3` convolutions, followed by `6n` layers with `3x3` convolution on the feature maps of size `{32, 16, 8}` respectively, with `2n` layers for each feature map size. Note for ResNet20, we have `n=3`. The network ends with a global average pooling, a 10-way fully-connected layer, and softmax. [Batch normalization](https://arxiv.org/abs/1502.03167) is applied everywhere except the last fully-connected layer.
+### [Python](./Python)

-Other than the network architecture, the CIFAR-10 dataset is augmented with random translation, identical to that in [GettingStarted/ConvNet_CIFAR10_DataAug.cntk](../../GettingStarted/ConvNet_CIFAR10_DataAug.cntk). Please refer to the cntk configuration file [ResNet20_CIFAR10.cntk](./ResNet20_CIFAR10.cntk) for more details.
-
-Run the example from the current folder using:
-
-`cntk configFile=ResNet20_CIFAR10.cntk`
-
-The network achieves an error rate of about `8.2%`, which is lower than the number reported in the original paper.
-
-### ResNet110_CIFAR10.cntk
-
-In this example we increase the depth of the ResNet to 110 layers. That is, we set `n=18`. Only very minor changes are made to the CNTK configuration file. To run this example, use:
-
-`cntk configFile=ResNet110_CIFAR10.cntk`
-
-The network achieves an error rate of about `6.2-6.5%`.
-
-### ResNet50_ImageNet1K.cntk
-
-This is an example using a 50-layer ResNet to train on ILSVRC2012 datasets. Compared with the CIFAR-10 examples, we introduced bottleneck blocks to reduce the amount of computation by replacing the two `3x3` convolutions by a `1x1` convolution, bottlenecked to 1/4 of feature maps, followed by a `3x3` convolution, and then a `1x1` convolution again, with the same number feature maps as input.
-
-Run the example from the current folder using:
-
-`cntk configFile=ResNet50_ImageNet1K.cntk`
-
-### ResNet101_ImageNet1K.cntk
-
-Increase the depth of the ResNet to 101 layers:
-
-`cntk configFile=ResNet101_ImageNet1K.cntk`
-
-### ResNet152_ImageNet1K.cntk
-
-Further increase the depth of the ResNet to 152 layers:
-
-`cntk configFile=ResNet152_ImageNet1K.cntk`
+### [BrainScript](./BrainScript)