FRCN Python perf better than BS on AlexNet

This commit is contained in:
Philipp Kranen 2017-01-24 11:10:49 +01:00
Родитель 517d6c06d6
Коммит 8b64aafe8f
5 изменённых файлов: 162 добавлений и 168 удалений

Просмотреть файл

@ -17,6 +17,7 @@ image_sets = ["train", "test"]
boAddSelectiveSearchROIs = True
boAddRoisOnGrid = True
####################################
# Main
####################################

Просмотреть файл

@ -5,237 +5,220 @@
# ==============================================================================
from __future__ import print_function
import numpy as np
import os, sys, importlib
from cntk import * # Trainer, load_model, UnitType
from cntk.device import cpu, set_default_device
from cntk.learner import sgd
from cntk import Trainer, UnitType, load_model
from cntk.blocks import Placeholder, Constant
from cntk.layers import Dense
from cntk.graph import find_by_name, plot
from cntk.initializer import glorot_uniform
from cntk.io import ReaderConfig, ImageDeserializer, CTFDeserializer
from cntk.learner import momentum_sgd, learning_rate_schedule, momentum_as_time_constant_schedule
from cntk.ops import input_variable, constant, parameter, cross_entropy_with_softmax, classification_error, times, combine
from cntk.ops import input_variable, parameter, cross_entropy_with_softmax, classification_error, times, combine
from cntk.ops import roipooling
from cntk.ops.functions import CloneMethod
from cntk.io import ReaderConfig, ImageDeserializer, CTFDeserializer, StreamConfiguration
from cntk.initializer import glorot_uniform
from cntk.graph import find_by_name, depth_first_search
import PARAMETERS
locals().update(importlib.import_module("PARAMETERS").__dict__)
from cntk.utils import log_number_of_parameters, ProgressPrinter
from PARAMETERS import *
import numpy as np
import os, sys
###############################################################
###############################################################
abs_path = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(abs_path, "..", ".."))
TRAIN_MAP_FILENAME = 'train.txt'
TEST_MAP_FILENAME = 'test.txt'
ROIS_FILENAME_POSTFIX = '.rois.txt'
ROILABELS_FILENAME_POSTFIX = '.roilabels.txt'
# file and stream names
train_map_filename = 'train.txt'
test_map_filename = 'test.txt'
rois_filename_postfix = '.rois.txt'
roilabels_filename_postfix = '.roilabels.txt'
features_stream_name = 'features'
roi_stream_name = 'rois'
label_stream_name = 'roiLabels'
# from PARAMETERS.py
base_path = cntkFilesDir
num_channels = 3
image_height = cntk_padHeight
image_width = cntk_padWidth
num_classes = nrClasses
num_rois = cntk_nrRois
epoch_size = cntk_num_train_images
num_test_images = cntk_num_test_images
mb_size = cntk_mb_size
max_epochs = cntk_max_epochs
momentum_time_constant = cntk_momentum_time_constant
# model specific variables
use_model = "VGG"
if (use_model == "AlexNet"):
# model specific variables (only AlexNet for now)
base_model = "AlexNet"
if base_model == "AlexNet":
model_file = "../../../../../PretrainedModels/AlexNet.model"
feature_node_name = "features"
conv5_node_name = "conv5.y" # "z.x._.x._.x.x_output"
pool3_node_name = "pool3" # "z.x._.x._.x_output"
h2d_node_name = "h2_d" # "z.x_output"
elif (use_model == "VGG"):
model_file = "../../../../../PretrainedModels/VGG19_legacy.model"
feature_node_name = "data"
conv5_node_name = "relu5_4" # "z.x._.x._.x.x_output"
pool3_node_name = "pool5" # "z.x._.x._.x_output"
h2d_node_name = "drop7" # "z.x_output"
# Helper to print all node names
def print_all_node_names(model_file, is_BrainScript=True):
loaded_model = load_model(model_file)
if is_BrainScript:
loaded_model = combine([loaded_model.outputs[0]])
node_list = depth_first_search(loaded_model, lambda x: True) #x.is_output)
print("printing node information in the format")
print("node name (tensor shape)")
for node in node_list:
print(node.name, node.shape)
def print_training_progress(trainer, mb, frequency):
if mb % frequency == 0:
training_loss = get_train_loss(trainer)
eval_crit = get_train_eval_criterion(trainer)
print("Minibatch: {}, Train Loss: {}, Train Evaluation Criterion: {}".format(
mb, training_loss, eval_crit))
last_conv_node_name = "conv5.y"
pool_node_name = "pool3"
last_hidden_node_name = "h2_d"
roi_dim = 6
else:
raise ValueError('unknown base model: %s' % base_model)
###############################################################
###############################################################
# Instantiates a composite minibatch source for reading images, roi coordinates and roi labels for training Fast R-CNN
# The minibatch source is configured using a hierarchical dictionary of key:value pairs
def create_mb_source(features_stream_name, rois_stream_name, labels_stream_name, image_height,
image_width, num_channels, num_classes, num_rois, data_path, data_set):
rois_dim = 4 * num_rois
label_dim = num_classes * num_rois
def create_mb_source(img_height, img_width, img_channels, n_classes, n_rois, data_path, data_set):
rois_dim = 4 * n_rois
label_dim = n_classes * n_rois
path = os.path.normpath(os.path.join(abs_path, data_path))
if (data_set == 'test'):
map_file = os.path.join(path, TEST_MAP_FILENAME)
if data_set == 'test':
map_file = os.path.join(path, test_map_filename)
else:
map_file = os.path.join(path, TRAIN_MAP_FILENAME)
roi_file = os.path.join(path, data_set + ROIS_FILENAME_POSTFIX)
label_file = os.path.join(path, data_set + ROILABELS_FILENAME_POSTFIX)
map_file = os.path.join(path, train_map_filename)
roi_file = os.path.join(path, data_set + rois_filename_postfix)
label_file = os.path.join(path, data_set + roilabels_filename_postfix)
if not os.path.exists(map_file) or not os.path.exists(roi_file) or not os.path.exists(label_file):
raise RuntimeError("File '%s', '%s' or '%s' does not exist. Please run install_fastrcnn.py from Examples/Image/Detection/FastRCNN to fetch them" %
raise RuntimeError("File '%s', '%s' or '%s' does not exist. "
"Please run install_fastrcnn.py from Examples/Image/Detection/FastRCNN to fetch them" %
(map_file, roi_file, label_file))
# read images
# ??? do we still need 'transpose'?
image_source = ImageDeserializer(map_file)
image_source.ignore_labels()
image_source.map_features(features_stream_name,
[ImageDeserializer.scale(width=image_width, height=image_height, channels=num_channels,
scale_mode="pad", pad_value=114, interpolations='linear')])
[ImageDeserializer.scale(width=img_width, height=img_height, channels=img_channels,
scale_mode="pad", pad_value=114, interpolations='linear')])
# read rois and labels
roi_source = CTFDeserializer(roi_file)
roi_source.map_input(rois_stream_name, dim=rois_dim, format="dense")
roi_source.map_input(roi_stream_name, dim=rois_dim, format="dense")
label_source = CTFDeserializer(label_file)
label_source.map_input(labels_stream_name, dim=label_dim, format="dense")
label_source.map_input(label_stream_name, dim=label_dim, format="dense")
# define a composite reader
rc = ReaderConfig([image_source, roi_source, label_source], epoch_size=sys.maxsize)
rc = ReaderConfig([image_source, roi_source, label_source], epoch_size=sys.maxsize, randomize=data_set == "train")
return rc.minibatch_source()
# Defines the Fast R-CNN network model for detecting objects in images
def frcn_predictor(features, rois, num_classes):
# Load the pretrained model and find nodes
def frcn_predictor(features, rois, n_classes):
# Load the pretrained classification net and find nodes
loaded_model = load_model(model_file)
feature_node = find_by_name(loaded_model, feature_node_name)
conv5_node = find_by_name(loaded_model, conv5_node_name)
pool3_node = find_by_name(loaded_model, pool3_node_name)
h2d_node = find_by_name(loaded_model, h2d_node_name)
conv_node = find_by_name(loaded_model, last_conv_node_name)
pool_node = find_by_name(loaded_model, pool_node_name)
last_node = find_by_name(loaded_model, last_hidden_node_name)
# Clone the conv layers of the network, i.e. from the input features up to the output of the 5th conv layer
print("Cloning conv layers for %s model (%s to %s)" % (use_model, feature_node_name, conv5_node_name))
conv_layers = combine([conv5_node.owner]).clone(CloneMethod.freeze, {feature_node: Placeholder()})
# Clone the conv layers and the fully connected layers of the network
conv_layers = combine([conv_node.owner]).clone(CloneMethod.freeze, {feature_node: Placeholder()})
fc_layers = combine([last_node.owner]).clone(CloneMethod.clone, {pool_node: Placeholder()})
#import pdb
#pdb.set_trace()
# Clone the fully connected layers, i.e. from the output of the last pooling layer to the output of the last dense layer
print("Cloning fc layers for %s model (%s to %s)" % (use_model, pool3_node_name, h2d_node_name))
fc_layers = combine([h2d_node.owner]).clone(CloneMethod.clone, {pool3_node: Placeholder()})
# create Fast R-CNN model
# Create the Fast R-CNN model
feat_norm = features - Constant(114)
conv_out = conv_layers(feat_norm)
roi_out = roipooling(conv_out, rois, (6,6)) # rename to roi_max_pooling
roi_out = roipooling(conv_out, rois, (roi_dim, roi_dim))
fc_out = fc_layers(roi_out)
# z = Dense((rois[0], num_classes), map_rank=1)(fc_out) --> map_rank=1 is not yet supported
W = parameter(shape=(4096, num_classes), init=glorot_uniform())
b = parameter(shape=(num_classes), init=0)
# z = Dense(rois[0], num_classes, map_rank=1)(fc_out) # --> map_rank=1 is not yet supported
W = parameter(shape=(4096, n_classes), init=glorot_uniform())
b = parameter(shape=n_classes, init=0)
z = times(fc_out, W) + b
return z
# Trains a Fast R-CNN network model on the grocery image dataset
def frcn_grocery(base_path, debug_output=False):
num_channels = 3
image_height = cntk_padHeight # from PARAMETERS.py
image_width = cntk_padWidth # from PARAMETERS.py
num_classes = nrClasses # from PARAMETERS.py
num_rois = cntk_nrRois # from PARAMETERS.py
feats_stream_name = 'features'
rois_stream_name = 'rois'
labels_stream_name = 'roiLabels'
# Trains a Fast R-CNN model
def train_fast_rcnn(debug_output=False):
if debug_output:
print("Storing graphs and intermediate models to %s." % os.path.join(abs_path, "Output"))
#####
# training
minibatch_source = create_mb_source(feats_stream_name, rois_stream_name, labels_stream_name,
image_height, image_width, num_channels, num_classes, num_rois, base_path, "train")
features_si = minibatch_source[feats_stream_name]
rois_si = minibatch_source[rois_stream_name]
labels_si = minibatch_source[labels_stream_name]
# Create the minibatch source
minibatch_source = create_mb_source(image_height, image_width, num_channels,
num_classes, num_rois, base_path, "train")
# Input variables denoting features, rois and label data
image_input = input_variable((num_channels, image_height, image_width), features_si.m_element_type)
roi_input = input_variable((num_rois, 4), rois_si.m_element_type)
label_input = input_variable((num_rois, num_classes), labels_si.m_element_type)
image_input = input_variable((num_channels, image_height, image_width))
roi_input = input_variable((num_rois, 4))
label_input = input_variable((num_rois, num_classes))
# Instantiate the Fast R-CNN prediction model
# define mapping from reader streams to network inputs
input_map = {
image_input: minibatch_source[features_stream_name],
roi_input: minibatch_source[roi_stream_name],
label_input: minibatch_source[label_stream_name]
}
# Instantiate the Fast R-CNN prediction model and loss function
frcn_output = frcn_predictor(image_input, roi_input, num_classes)
ce = cross_entropy_with_softmax(frcn_output, label_input, axis=1)
pe = classification_error(frcn_output, label_input, axis=1)
if debug_output:
plot(frcn_output, os.path.join(abs_path, "Output", "graph_frcn.png"))
# Set learning parameters
epoch_size = 25 # for now we manually specify epoch size
mb_size = 1
max_epochs = 17
momentum_time_constant = -mb_size/np.log(0.9)
l2_reg_weight = 0.0005
lr_per_mb = [0.00001] * 10 + [0.000001] * 5 + [0.0000001]
lr_schedule = learning_rate_schedule(lr_per_mb, unit=UnitType.minibatch)
lr_per_sample = [0.00001] * 10 + [0.000001] * 5 + [0.0000001]
lr_schedule = learning_rate_schedule(lr_per_sample, unit=UnitType.sample)
mm_schedule = momentum_as_time_constant_schedule(momentum_time_constant)
# Instantiate the trainer object to drive the model training
# Instantiate the trainer object
learner = momentum_sgd(frcn_output.parameters, lr_schedule, mm_schedule, l2_regularization_weight=l2_reg_weight)
trainer = Trainer(frcn_output, ce, pe, learner)
# Get minibatches of images to train with and perform model training
training_progress_output_freq = int(epoch_size / mb_size)
num_mbs = int(epoch_size * max_epochs / mb_size) # 17 epochs * 25 images / 1 mbSize
# Get minibatches of images and perform model training
print("Training Fast R-CNN model for %s epochs." % max_epochs)
log_number_of_parameters(frcn_output)
progress_printer = ProgressPrinter(tag='Training', num_epochs=max_epochs)
for epoch in range(max_epochs): # loop over epochs
sample_count = 0
while sample_count < epoch_size: # loop over minibatches in the epoch
data = minibatch_source.next_minibatch(min(mb_size, epoch_size-sample_count), input_map=input_map)
trainer.train_minibatch(data) # update model with it
sample_count += trainer.previous_minibatch_sample_count # count samples processed so far
progress_printer.update_with_trainer(trainer, with_metric=True) # log progress
if debug_output:
training_progress_output_freq = training_progress_output_freq / 10
progress_printer.epoch_summary(with_metric=True)
if debug_output:
frcn_output.save_model(os.path.join(abs_path, "Output", "frcn_py_%s.model" % (epoch+1)))
# Main training loop
for i in range(0, num_mbs):
mb = minibatch_source.next_minibatch(mb_size)
return frcn_output
# Specify the mapping of input variables in the model to actual minibatch data to be trained with
arguments = {
image_input: mb[features_si],
roi_input: mb[rois_si],
label_input: mb[labels_si]
}
trainer.train_minibatch(arguments)
print_training_progress(trainer, i, training_progress_output_freq)
#####
# testing
test_minibatch_source = create_mb_source(feats_stream_name, rois_stream_name, labels_stream_name,
image_height, image_width, num_channels, num_classes, num_rois, base_path, "test")
features_si = test_minibatch_source[feats_stream_name]
rois_si = test_minibatch_source[rois_stream_name]
mb_size = 1
num_mbs = 5
# Tests a Fast R-CNN model
def test_fast_rcnn(model):
test_minibatch_source = create_mb_source(image_height, image_width, num_channels,
num_classes, num_rois, base_path, "test")
input_map = {
model.arguments[0]: test_minibatch_source[features_stream_name],
model.arguments[1]: test_minibatch_source[roi_stream_name],
}
# evaluate test images and write netwrok output to file
print("Evaluating Fast R-CNN model for %s images." % num_test_images)
results_file_path = base_path + "test.z"
with open(results_file_path, 'wb') as results_file:
for i in range(0, num_mbs):
mb = test_minibatch_source.next_minibatch(mb_size)
# Specify the mapping of input variables in the model to actual minibatch data to be tested with
arguments = {
image_input: mb[features_si],
roi_input: mb[rois_si],
}
output = trainer.model.eval(arguments)
out_values = output[0,0].flatten()
for i in range(0, num_test_images):
data = test_minibatch_source.next_minibatch(1, input_map=input_map)
output = model.eval(data)
out_values = output[0, 0].flatten()
np.savetxt(results_file, out_values[np.newaxis], fmt="%.6f")
if (i+1) % 100 == 0:
print("Evaluated %s images.." % (i+1))
return
#if __name__ == '__main__':
# Specify the target device to be used for computing, if you do not want to
# use the best available one, e.g.
# set_default_device(cpu())
os.chdir(cntkFilesDir)
print_all_node_names(model_file)
# The main method trains and evaluates a Fast R-CNN model.
# If a trained model is already available it is loaded an no training will be performed.
if __name__ == '__main__':
os.chdir(base_path)
model_path = os.path.join(abs_path, "Output", "frcn_py.model")
frcn_grocery(cntkFilesDir)
# Train only is no model exists yet
if os.path.exists(model_path):
print("Loading existing model from %s" % model_path)
trained_model = load_model(model_path)
else:
trained_model = train_fast_rcnn()
trained_model.save_model(model_path)
print("Stored trained model at %s" % model_path)
# Evaluate the test set
test_fast_rcnn(trained_model)

Просмотреть файл

@ -15,7 +15,7 @@ datasetName = "grocery"
# default parameters
############################
# cntk params
cntk_nrRois = 100 # 2000 # how many ROIs to zero-pad
cntk_nrRois = 100 # how many ROIs to zero-pad. Use 100 to get quick result. Use 2000 to get good results.
cntk_padWidth = 1000
cntk_padHeight = 1000
@ -49,6 +49,11 @@ train_posOverlapThres = 0.5 # threshold for marking ROIs as positive.
nmsThreshold = 0.3 # Non-Maxima suppression threshold (in range [0,1]).
# The lower the more ROIs will be combined. Used in 5_evaluateResults and 5_visualizeResults.
cntk_num_train_images = -1 # set per data set below
cntk_num_test_images = -1 # set per data set below
cntk_mb_size = -1 # set per data set below
cntk_max_epochs = -1 # set per data set below
cntk_momentum_time_constant = -1 # set per data set below
############################
# project-specific parameters
@ -66,6 +71,11 @@ if datasetName.startswith("grocery"):
# model training / scoring
classifier = 'nn'
cntk_num_train_images = 25
cntk_num_test_images = 5
cntk_mb_size = 5
cntk_max_epochs = 20
cntk_momentum_time_constant = 10
# postprocessing
nmsThreshold = 0.01
@ -90,6 +100,11 @@ elif datasetName.startswith("pascalVoc"):
# use cntk_nrRois = 4000. more than 99% of the test images have less than 4000 rois, but 50% more than 2000
# model training / scoring
classifier = 'nn'
cntk_num_train_images = 5011
cntk_num_test_images = 4952
cntk_mb_size = 2
cntk_max_epochs = 17
cntk_momentum_time_constant = 20
# database
imdbs = dict()

Просмотреть файл

Просмотреть файл

@ -50,14 +50,9 @@ Train = {
action = "train"
BrainScriptNetworkBuilder = {
# using ndl model:
#network = BS.Network.Load ("../../../../../PretrainedModels/AlexNet.model")
#convLayers = BS.Network.CloneFunction(network.features, network.conv5_y, parameters = "constant")
#fcLayers = BS.Network.CloneFunction(network.pool3, network.h2_d)
# using brain scipt model
network = BS.Network.Load ("../../../../../PretrainedModels/AlexNetBS.model")
convLayers = BS.Network.CloneFunction(network.features, network.z_x___x___x_x, parameters = "constant") # network.features, network.z.x._.x._.x.x
fcLayers = BS.Network.CloneFunction(network.z_x___x___x, network.z_x) # network.z.x._.x._.x, network.z.x
network = BS.Network.Load ("../../../../../PretrainedModels/AlexNet.model")
convLayers = BS.Network.CloneFunction(network.features, network.conv5_y, parameters = "constant")
fcLayers = BS.Network.CloneFunction(network.pool3, network.h2_d)
model (features, rois) = {
featNorm = features - 114
@ -95,7 +90,7 @@ Train = {
maxEpochs = 17
learningRatesPerSample=0.00001*10:0.000001*5:0.0000001
momentumPerMB=0.9
momentumAsTimeConstant = 20
gradUpdateType=None
L2RegWeight=0.0005
dropoutRate=0.5