FRCN Python perf better than BS on AlexNet
This commit is contained in:
Родитель
517d6c06d6
Коммит
8b64aafe8f
|
@ -17,6 +17,7 @@ image_sets = ["train", "test"]
|
|||
boAddSelectiveSearchROIs = True
|
||||
boAddRoisOnGrid = True
|
||||
|
||||
|
||||
####################################
|
||||
# Main
|
||||
####################################
|
||||
|
|
|
@ -5,237 +5,220 @@
|
|||
# ==============================================================================
|
||||
|
||||
from __future__ import print_function
|
||||
import numpy as np
|
||||
import os, sys, importlib
|
||||
from cntk import * # Trainer, load_model, UnitType
|
||||
from cntk.device import cpu, set_default_device
|
||||
from cntk.learner import sgd
|
||||
from cntk import Trainer, UnitType, load_model
|
||||
from cntk.blocks import Placeholder, Constant
|
||||
from cntk.layers import Dense
|
||||
from cntk.graph import find_by_name, plot
|
||||
from cntk.initializer import glorot_uniform
|
||||
from cntk.io import ReaderConfig, ImageDeserializer, CTFDeserializer
|
||||
from cntk.learner import momentum_sgd, learning_rate_schedule, momentum_as_time_constant_schedule
|
||||
from cntk.ops import input_variable, constant, parameter, cross_entropy_with_softmax, classification_error, times, combine
|
||||
from cntk.ops import input_variable, parameter, cross_entropy_with_softmax, classification_error, times, combine
|
||||
from cntk.ops import roipooling
|
||||
from cntk.ops.functions import CloneMethod
|
||||
from cntk.io import ReaderConfig, ImageDeserializer, CTFDeserializer, StreamConfiguration
|
||||
from cntk.initializer import glorot_uniform
|
||||
from cntk.graph import find_by_name, depth_first_search
|
||||
import PARAMETERS
|
||||
locals().update(importlib.import_module("PARAMETERS").__dict__)
|
||||
from cntk.utils import log_number_of_parameters, ProgressPrinter
|
||||
from PARAMETERS import *
|
||||
import numpy as np
|
||||
import os, sys
|
||||
|
||||
###############################################################
|
||||
###############################################################
|
||||
abs_path = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(os.path.join(abs_path, "..", ".."))
|
||||
|
||||
TRAIN_MAP_FILENAME = 'train.txt'
|
||||
TEST_MAP_FILENAME = 'test.txt'
|
||||
ROIS_FILENAME_POSTFIX = '.rois.txt'
|
||||
ROILABELS_FILENAME_POSTFIX = '.roilabels.txt'
|
||||
# file and stream names
|
||||
train_map_filename = 'train.txt'
|
||||
test_map_filename = 'test.txt'
|
||||
rois_filename_postfix = '.rois.txt'
|
||||
roilabels_filename_postfix = '.roilabels.txt'
|
||||
features_stream_name = 'features'
|
||||
roi_stream_name = 'rois'
|
||||
label_stream_name = 'roiLabels'
|
||||
|
||||
# from PARAMETERS.py
|
||||
base_path = cntkFilesDir
|
||||
num_channels = 3
|
||||
image_height = cntk_padHeight
|
||||
image_width = cntk_padWidth
|
||||
num_classes = nrClasses
|
||||
num_rois = cntk_nrRois
|
||||
epoch_size = cntk_num_train_images
|
||||
num_test_images = cntk_num_test_images
|
||||
mb_size = cntk_mb_size
|
||||
max_epochs = cntk_max_epochs
|
||||
momentum_time_constant = cntk_momentum_time_constant
|
||||
|
||||
# model specific variables
|
||||
use_model = "VGG"
|
||||
|
||||
if (use_model == "AlexNet"):
|
||||
# model specific variables (only AlexNet for now)
|
||||
base_model = "AlexNet"
|
||||
if base_model == "AlexNet":
|
||||
model_file = "../../../../../PretrainedModels/AlexNet.model"
|
||||
feature_node_name = "features"
|
||||
conv5_node_name = "conv5.y" # "z.x._.x._.x.x_output"
|
||||
pool3_node_name = "pool3" # "z.x._.x._.x_output"
|
||||
h2d_node_name = "h2_d" # "z.x_output"
|
||||
elif (use_model == "VGG"):
|
||||
model_file = "../../../../../PretrainedModels/VGG19_legacy.model"
|
||||
feature_node_name = "data"
|
||||
conv5_node_name = "relu5_4" # "z.x._.x._.x.x_output"
|
||||
pool3_node_name = "pool5" # "z.x._.x._.x_output"
|
||||
h2d_node_name = "drop7" # "z.x_output"
|
||||
|
||||
|
||||
# Helper to print all node names
|
||||
def print_all_node_names(model_file, is_BrainScript=True):
|
||||
loaded_model = load_model(model_file)
|
||||
if is_BrainScript:
|
||||
loaded_model = combine([loaded_model.outputs[0]])
|
||||
node_list = depth_first_search(loaded_model, lambda x: True) #x.is_output)
|
||||
print("printing node information in the format")
|
||||
print("node name (tensor shape)")
|
||||
for node in node_list:
|
||||
print(node.name, node.shape)
|
||||
|
||||
|
||||
def print_training_progress(trainer, mb, frequency):
|
||||
if mb % frequency == 0:
|
||||
training_loss = get_train_loss(trainer)
|
||||
eval_crit = get_train_eval_criterion(trainer)
|
||||
print("Minibatch: {}, Train Loss: {}, Train Evaluation Criterion: {}".format(
|
||||
mb, training_loss, eval_crit))
|
||||
last_conv_node_name = "conv5.y"
|
||||
pool_node_name = "pool3"
|
||||
last_hidden_node_name = "h2_d"
|
||||
roi_dim = 6
|
||||
else:
|
||||
raise ValueError('unknown base model: %s' % base_model)
|
||||
###############################################################
|
||||
###############################################################
|
||||
|
||||
|
||||
# Instantiates a composite minibatch source for reading images, roi coordinates and roi labels for training Fast R-CNN
|
||||
# The minibatch source is configured using a hierarchical dictionary of key:value pairs
|
||||
def create_mb_source(features_stream_name, rois_stream_name, labels_stream_name, image_height,
|
||||
image_width, num_channels, num_classes, num_rois, data_path, data_set):
|
||||
rois_dim = 4 * num_rois
|
||||
label_dim = num_classes * num_rois
|
||||
def create_mb_source(img_height, img_width, img_channels, n_classes, n_rois, data_path, data_set):
|
||||
rois_dim = 4 * n_rois
|
||||
label_dim = n_classes * n_rois
|
||||
|
||||
path = os.path.normpath(os.path.join(abs_path, data_path))
|
||||
if (data_set == 'test'):
|
||||
map_file = os.path.join(path, TEST_MAP_FILENAME)
|
||||
if data_set == 'test':
|
||||
map_file = os.path.join(path, test_map_filename)
|
||||
else:
|
||||
map_file = os.path.join(path, TRAIN_MAP_FILENAME)
|
||||
roi_file = os.path.join(path, data_set + ROIS_FILENAME_POSTFIX)
|
||||
label_file = os.path.join(path, data_set + ROILABELS_FILENAME_POSTFIX)
|
||||
map_file = os.path.join(path, train_map_filename)
|
||||
roi_file = os.path.join(path, data_set + rois_filename_postfix)
|
||||
label_file = os.path.join(path, data_set + roilabels_filename_postfix)
|
||||
|
||||
if not os.path.exists(map_file) or not os.path.exists(roi_file) or not os.path.exists(label_file):
|
||||
raise RuntimeError("File '%s', '%s' or '%s' does not exist. Please run install_fastrcnn.py from Examples/Image/Detection/FastRCNN to fetch them" %
|
||||
raise RuntimeError("File '%s', '%s' or '%s' does not exist. "
|
||||
"Please run install_fastrcnn.py from Examples/Image/Detection/FastRCNN to fetch them" %
|
||||
(map_file, roi_file, label_file))
|
||||
|
||||
# read images
|
||||
# ??? do we still need 'transpose'?
|
||||
image_source = ImageDeserializer(map_file)
|
||||
image_source.ignore_labels()
|
||||
image_source.map_features(features_stream_name,
|
||||
[ImageDeserializer.scale(width=image_width, height=image_height, channels=num_channels,
|
||||
scale_mode="pad", pad_value=114, interpolations='linear')])
|
||||
[ImageDeserializer.scale(width=img_width, height=img_height, channels=img_channels,
|
||||
scale_mode="pad", pad_value=114, interpolations='linear')])
|
||||
|
||||
# read rois and labels
|
||||
roi_source = CTFDeserializer(roi_file)
|
||||
roi_source.map_input(rois_stream_name, dim=rois_dim, format="dense")
|
||||
roi_source.map_input(roi_stream_name, dim=rois_dim, format="dense")
|
||||
label_source = CTFDeserializer(label_file)
|
||||
label_source.map_input(labels_stream_name, dim=label_dim, format="dense")
|
||||
label_source.map_input(label_stream_name, dim=label_dim, format="dense")
|
||||
|
||||
# define a composite reader
|
||||
rc = ReaderConfig([image_source, roi_source, label_source], epoch_size=sys.maxsize)
|
||||
rc = ReaderConfig([image_source, roi_source, label_source], epoch_size=sys.maxsize, randomize=data_set == "train")
|
||||
return rc.minibatch_source()
|
||||
|
||||
|
||||
# Defines the Fast R-CNN network model for detecting objects in images
|
||||
def frcn_predictor(features, rois, num_classes):
|
||||
# Load the pretrained model and find nodes
|
||||
def frcn_predictor(features, rois, n_classes):
|
||||
# Load the pretrained classification net and find nodes
|
||||
loaded_model = load_model(model_file)
|
||||
feature_node = find_by_name(loaded_model, feature_node_name)
|
||||
conv5_node = find_by_name(loaded_model, conv5_node_name)
|
||||
pool3_node = find_by_name(loaded_model, pool3_node_name)
|
||||
h2d_node = find_by_name(loaded_model, h2d_node_name)
|
||||
conv_node = find_by_name(loaded_model, last_conv_node_name)
|
||||
pool_node = find_by_name(loaded_model, pool_node_name)
|
||||
last_node = find_by_name(loaded_model, last_hidden_node_name)
|
||||
|
||||
# Clone the conv layers of the network, i.e. from the input features up to the output of the 5th conv layer
|
||||
print("Cloning conv layers for %s model (%s to %s)" % (use_model, feature_node_name, conv5_node_name))
|
||||
conv_layers = combine([conv5_node.owner]).clone(CloneMethod.freeze, {feature_node: Placeholder()})
|
||||
# Clone the conv layers and the fully connected layers of the network
|
||||
conv_layers = combine([conv_node.owner]).clone(CloneMethod.freeze, {feature_node: Placeholder()})
|
||||
fc_layers = combine([last_node.owner]).clone(CloneMethod.clone, {pool_node: Placeholder()})
|
||||
|
||||
#import pdb
|
||||
#pdb.set_trace()
|
||||
|
||||
# Clone the fully connected layers, i.e. from the output of the last pooling layer to the output of the last dense layer
|
||||
print("Cloning fc layers for %s model (%s to %s)" % (use_model, pool3_node_name, h2d_node_name))
|
||||
fc_layers = combine([h2d_node.owner]).clone(CloneMethod.clone, {pool3_node: Placeholder()})
|
||||
|
||||
# create Fast R-CNN model
|
||||
# Create the Fast R-CNN model
|
||||
feat_norm = features - Constant(114)
|
||||
conv_out = conv_layers(feat_norm)
|
||||
roi_out = roipooling(conv_out, rois, (6,6)) # rename to roi_max_pooling
|
||||
roi_out = roipooling(conv_out, rois, (roi_dim, roi_dim))
|
||||
fc_out = fc_layers(roi_out)
|
||||
|
||||
# z = Dense((rois[0], num_classes), map_rank=1)(fc_out) --> map_rank=1 is not yet supported
|
||||
W = parameter(shape=(4096, num_classes), init=glorot_uniform())
|
||||
b = parameter(shape=(num_classes), init=0)
|
||||
# z = Dense(rois[0], num_classes, map_rank=1)(fc_out) # --> map_rank=1 is not yet supported
|
||||
W = parameter(shape=(4096, n_classes), init=glorot_uniform())
|
||||
b = parameter(shape=n_classes, init=0)
|
||||
z = times(fc_out, W) + b
|
||||
|
||||
return z
|
||||
|
||||
|
||||
# Trains a Fast R-CNN network model on the grocery image dataset
|
||||
def frcn_grocery(base_path, debug_output=False):
|
||||
num_channels = 3
|
||||
image_height = cntk_padHeight # from PARAMETERS.py
|
||||
image_width = cntk_padWidth # from PARAMETERS.py
|
||||
num_classes = nrClasses # from PARAMETERS.py
|
||||
num_rois = cntk_nrRois # from PARAMETERS.py
|
||||
feats_stream_name = 'features'
|
||||
rois_stream_name = 'rois'
|
||||
labels_stream_name = 'roiLabels'
|
||||
# Trains a Fast R-CNN model
|
||||
def train_fast_rcnn(debug_output=False):
|
||||
if debug_output:
|
||||
print("Storing graphs and intermediate models to %s." % os.path.join(abs_path, "Output"))
|
||||
|
||||
#####
|
||||
# training
|
||||
minibatch_source = create_mb_source(feats_stream_name, rois_stream_name, labels_stream_name,
|
||||
image_height, image_width, num_channels, num_classes, num_rois, base_path, "train")
|
||||
features_si = minibatch_source[feats_stream_name]
|
||||
rois_si = minibatch_source[rois_stream_name]
|
||||
labels_si = minibatch_source[labels_stream_name]
|
||||
# Create the minibatch source
|
||||
minibatch_source = create_mb_source(image_height, image_width, num_channels,
|
||||
num_classes, num_rois, base_path, "train")
|
||||
|
||||
# Input variables denoting features, rois and label data
|
||||
image_input = input_variable((num_channels, image_height, image_width), features_si.m_element_type)
|
||||
roi_input = input_variable((num_rois, 4), rois_si.m_element_type)
|
||||
label_input = input_variable((num_rois, num_classes), labels_si.m_element_type)
|
||||
image_input = input_variable((num_channels, image_height, image_width))
|
||||
roi_input = input_variable((num_rois, 4))
|
||||
label_input = input_variable((num_rois, num_classes))
|
||||
|
||||
# Instantiate the Fast R-CNN prediction model
|
||||
# define mapping from reader streams to network inputs
|
||||
input_map = {
|
||||
image_input: minibatch_source[features_stream_name],
|
||||
roi_input: minibatch_source[roi_stream_name],
|
||||
label_input: minibatch_source[label_stream_name]
|
||||
}
|
||||
|
||||
# Instantiate the Fast R-CNN prediction model and loss function
|
||||
frcn_output = frcn_predictor(image_input, roi_input, num_classes)
|
||||
|
||||
ce = cross_entropy_with_softmax(frcn_output, label_input, axis=1)
|
||||
pe = classification_error(frcn_output, label_input, axis=1)
|
||||
if debug_output:
|
||||
plot(frcn_output, os.path.join(abs_path, "Output", "graph_frcn.png"))
|
||||
|
||||
# Set learning parameters
|
||||
epoch_size = 25 # for now we manually specify epoch size
|
||||
mb_size = 1
|
||||
max_epochs = 17
|
||||
momentum_time_constant = -mb_size/np.log(0.9)
|
||||
l2_reg_weight = 0.0005
|
||||
|
||||
lr_per_mb = [0.00001] * 10 + [0.000001] * 5 + [0.0000001]
|
||||
lr_schedule = learning_rate_schedule(lr_per_mb, unit=UnitType.minibatch)
|
||||
lr_per_sample = [0.00001] * 10 + [0.000001] * 5 + [0.0000001]
|
||||
lr_schedule = learning_rate_schedule(lr_per_sample, unit=UnitType.sample)
|
||||
mm_schedule = momentum_as_time_constant_schedule(momentum_time_constant)
|
||||
|
||||
# Instantiate the trainer object to drive the model training
|
||||
# Instantiate the trainer object
|
||||
learner = momentum_sgd(frcn_output.parameters, lr_schedule, mm_schedule, l2_regularization_weight=l2_reg_weight)
|
||||
trainer = Trainer(frcn_output, ce, pe, learner)
|
||||
|
||||
# Get minibatches of images to train with and perform model training
|
||||
training_progress_output_freq = int(epoch_size / mb_size)
|
||||
num_mbs = int(epoch_size * max_epochs / mb_size) # 17 epochs * 25 images / 1 mbSize
|
||||
# Get minibatches of images and perform model training
|
||||
print("Training Fast R-CNN model for %s epochs." % max_epochs)
|
||||
log_number_of_parameters(frcn_output)
|
||||
progress_printer = ProgressPrinter(tag='Training', num_epochs=max_epochs)
|
||||
for epoch in range(max_epochs): # loop over epochs
|
||||
sample_count = 0
|
||||
while sample_count < epoch_size: # loop over minibatches in the epoch
|
||||
data = minibatch_source.next_minibatch(min(mb_size, epoch_size-sample_count), input_map=input_map)
|
||||
trainer.train_minibatch(data) # update model with it
|
||||
sample_count += trainer.previous_minibatch_sample_count # count samples processed so far
|
||||
progress_printer.update_with_trainer(trainer, with_metric=True) # log progress
|
||||
|
||||
if debug_output:
|
||||
training_progress_output_freq = training_progress_output_freq / 10
|
||||
progress_printer.epoch_summary(with_metric=True)
|
||||
if debug_output:
|
||||
frcn_output.save_model(os.path.join(abs_path, "Output", "frcn_py_%s.model" % (epoch+1)))
|
||||
|
||||
# Main training loop
|
||||
for i in range(0, num_mbs):
|
||||
mb = minibatch_source.next_minibatch(mb_size)
|
||||
return frcn_output
|
||||
|
||||
# Specify the mapping of input variables in the model to actual minibatch data to be trained with
|
||||
arguments = {
|
||||
image_input: mb[features_si],
|
||||
roi_input: mb[rois_si],
|
||||
label_input: mb[labels_si]
|
||||
}
|
||||
trainer.train_minibatch(arguments)
|
||||
print_training_progress(trainer, i, training_progress_output_freq)
|
||||
|
||||
#####
|
||||
# testing
|
||||
test_minibatch_source = create_mb_source(feats_stream_name, rois_stream_name, labels_stream_name,
|
||||
image_height, image_width, num_channels, num_classes, num_rois, base_path, "test")
|
||||
features_si = test_minibatch_source[feats_stream_name]
|
||||
rois_si = test_minibatch_source[rois_stream_name]
|
||||
|
||||
mb_size = 1
|
||||
num_mbs = 5
|
||||
# Tests a Fast R-CNN model
|
||||
def test_fast_rcnn(model):
|
||||
test_minibatch_source = create_mb_source(image_height, image_width, num_channels,
|
||||
num_classes, num_rois, base_path, "test")
|
||||
input_map = {
|
||||
model.arguments[0]: test_minibatch_source[features_stream_name],
|
||||
model.arguments[1]: test_minibatch_source[roi_stream_name],
|
||||
}
|
||||
|
||||
# evaluate test images and write netwrok output to file
|
||||
print("Evaluating Fast R-CNN model for %s images." % num_test_images)
|
||||
results_file_path = base_path + "test.z"
|
||||
with open(results_file_path, 'wb') as results_file:
|
||||
for i in range(0, num_mbs):
|
||||
mb = test_minibatch_source.next_minibatch(mb_size)
|
||||
|
||||
# Specify the mapping of input variables in the model to actual minibatch data to be tested with
|
||||
arguments = {
|
||||
image_input: mb[features_si],
|
||||
roi_input: mb[rois_si],
|
||||
}
|
||||
output = trainer.model.eval(arguments)
|
||||
out_values = output[0,0].flatten()
|
||||
for i in range(0, num_test_images):
|
||||
data = test_minibatch_source.next_minibatch(1, input_map=input_map)
|
||||
output = model.eval(data)
|
||||
out_values = output[0, 0].flatten()
|
||||
np.savetxt(results_file, out_values[np.newaxis], fmt="%.6f")
|
||||
if (i+1) % 100 == 0:
|
||||
print("Evaluated %s images.." % (i+1))
|
||||
|
||||
return
|
||||
|
||||
#if __name__ == '__main__':
|
||||
# Specify the target device to be used for computing, if you do not want to
|
||||
# use the best available one, e.g.
|
||||
# set_default_device(cpu())
|
||||
|
||||
os.chdir(cntkFilesDir)
|
||||
print_all_node_names(model_file)
|
||||
# The main method trains and evaluates a Fast R-CNN model.
|
||||
# If a trained model is already available it is loaded an no training will be performed.
|
||||
if __name__ == '__main__':
|
||||
os.chdir(base_path)
|
||||
model_path = os.path.join(abs_path, "Output", "frcn_py.model")
|
||||
|
||||
frcn_grocery(cntkFilesDir)
|
||||
# Train only is no model exists yet
|
||||
if os.path.exists(model_path):
|
||||
print("Loading existing model from %s" % model_path)
|
||||
trained_model = load_model(model_path)
|
||||
else:
|
||||
trained_model = train_fast_rcnn()
|
||||
trained_model.save_model(model_path)
|
||||
print("Stored trained model at %s" % model_path)
|
||||
|
||||
# Evaluate the test set
|
||||
test_fast_rcnn(trained_model)
|
||||
|
|
|
@ -15,7 +15,7 @@ datasetName = "grocery"
|
|||
# default parameters
|
||||
############################
|
||||
# cntk params
|
||||
cntk_nrRois = 100 # 2000 # how many ROIs to zero-pad
|
||||
cntk_nrRois = 100 # how many ROIs to zero-pad. Use 100 to get quick result. Use 2000 to get good results.
|
||||
cntk_padWidth = 1000
|
||||
cntk_padHeight = 1000
|
||||
|
||||
|
@ -49,6 +49,11 @@ train_posOverlapThres = 0.5 # threshold for marking ROIs as positive.
|
|||
nmsThreshold = 0.3 # Non-Maxima suppression threshold (in range [0,1]).
|
||||
# The lower the more ROIs will be combined. Used in 5_evaluateResults and 5_visualizeResults.
|
||||
|
||||
cntk_num_train_images = -1 # set per data set below
|
||||
cntk_num_test_images = -1 # set per data set below
|
||||
cntk_mb_size = -1 # set per data set below
|
||||
cntk_max_epochs = -1 # set per data set below
|
||||
cntk_momentum_time_constant = -1 # set per data set below
|
||||
|
||||
############################
|
||||
# project-specific parameters
|
||||
|
@ -66,6 +71,11 @@ if datasetName.startswith("grocery"):
|
|||
|
||||
# model training / scoring
|
||||
classifier = 'nn'
|
||||
cntk_num_train_images = 25
|
||||
cntk_num_test_images = 5
|
||||
cntk_mb_size = 5
|
||||
cntk_max_epochs = 20
|
||||
cntk_momentum_time_constant = 10
|
||||
|
||||
# postprocessing
|
||||
nmsThreshold = 0.01
|
||||
|
@ -90,6 +100,11 @@ elif datasetName.startswith("pascalVoc"):
|
|||
# use cntk_nrRois = 4000. more than 99% of the test images have less than 4000 rois, but 50% more than 2000
|
||||
# model training / scoring
|
||||
classifier = 'nn'
|
||||
cntk_num_train_images = 5011
|
||||
cntk_num_test_images = 4952
|
||||
cntk_mb_size = 2
|
||||
cntk_max_epochs = 17
|
||||
cntk_momentum_time_constant = 20
|
||||
|
||||
# database
|
||||
imdbs = dict()
|
||||
|
|
|
@ -50,14 +50,9 @@ Train = {
|
|||
action = "train"
|
||||
|
||||
BrainScriptNetworkBuilder = {
|
||||
# using ndl model:
|
||||
#network = BS.Network.Load ("../../../../../PretrainedModels/AlexNet.model")
|
||||
#convLayers = BS.Network.CloneFunction(network.features, network.conv5_y, parameters = "constant")
|
||||
#fcLayers = BS.Network.CloneFunction(network.pool3, network.h2_d)
|
||||
# using brain scipt model
|
||||
network = BS.Network.Load ("../../../../../PretrainedModels/AlexNetBS.model")
|
||||
convLayers = BS.Network.CloneFunction(network.features, network.z_x___x___x_x, parameters = "constant") # network.features, network.z.x._.x._.x.x
|
||||
fcLayers = BS.Network.CloneFunction(network.z_x___x___x, network.z_x) # network.z.x._.x._.x, network.z.x
|
||||
network = BS.Network.Load ("../../../../../PretrainedModels/AlexNet.model")
|
||||
convLayers = BS.Network.CloneFunction(network.features, network.conv5_y, parameters = "constant")
|
||||
fcLayers = BS.Network.CloneFunction(network.pool3, network.h2_d)
|
||||
|
||||
model (features, rois) = {
|
||||
featNorm = features - 114
|
||||
|
@ -95,7 +90,7 @@ Train = {
|
|||
maxEpochs = 17
|
||||
|
||||
learningRatesPerSample=0.00001*10:0.000001*5:0.0000001
|
||||
momentumPerMB=0.9
|
||||
momentumAsTimeConstant = 20
|
||||
gradUpdateType=None
|
||||
L2RegWeight=0.0005
|
||||
dropoutRate=0.5
|
||||
|
|
Загрузка…
Ссылка в новой задаче