remove reader requirement from test and predict actions and update mnist example

This commit is contained in:
jeanfad 2016-03-24 18:49:09 +01:00 коммит произвёл Willi Richert
Родитель fefbbf9e64
Коммит 6b1438edc7
3 изменённых файлов: 43 добавлений и 33 удалений

Просмотреть файл

@ -103,12 +103,6 @@ class AbstractContext(object, metaclass=ABCMeta):
import shutil
shutil.rmtree(self.directory)
def to_config(self):
'''
Generates the CNTK configuration for the root node.
'''
return self.root_node.to_config()
def _generate_train_config(self, optimizer, reader, override_existing):
'''
Generates the configuration file for the train action.
@ -128,7 +122,7 @@ class AbstractContext(object, metaclass=ABCMeta):
tmpl = open(CNTK_TRAIN_TEMPLATE_PATH, "r").read()
model_filename = os.path.join(model_dir, self.name)
description, has_inputs, readers = self.to_config()
description, has_inputs, readers = self.root_node.to_config()
if reader:
readers.append(reader)
@ -147,10 +141,18 @@ class AbstractContext(object, metaclass=ABCMeta):
'''
tmpl = open(CNTK_TEST_TEMPLATE_PATH, "r").read()
model_filename = os.path.join(self.directory, 'Models', self.name)
# if no reader is passed generate the reader from the network
if reader:
reader_config = reader.generate_config()
else:
description, has_inputs, readers = self.root_node.to_config()
reader_config = '\n'.join(r.generate_config() for r in readers)
tmpl_dict = {
'DevideId': self.device_id,
'ModelPath': model_filename,
'Reader': reader.generate_config(),
'Reader': reader_config,
}
return tmpl % tmpl_dict
@ -162,11 +164,19 @@ class AbstractContext(object, metaclass=ABCMeta):
tmpl = open(CNTK_PREDICT_TEMPLATE_PATH, "r").read()
model_filename = os.path.join(self.directory, 'Models', self.name)
output_filename_base = os.path.join(self.directory, 'Outputs', self.name)
# if no reader is passed generate the reader from the network
if reader:
reader_config = reader.generate_config()
else:
description, has_inputs, readers = self.root_node.to_config()
reader_config = '\n'.join(r.generate_config() for r in readers)
tmpl_dict = {
'DevideId': self.device_id,
'ModelPath': model_filename,
'PredictOutputFile': output_filename_base,
'Reader': reader.generate_config(),
'Reader': reader_config,
}
return tmpl % tmpl_dict
@ -204,7 +214,7 @@ class AbstractContext(object, metaclass=ABCMeta):
return tmpl % tmpl_dict
@abstractmethod
def train(self, reader=None):
def train(self, optimizer, reader=None, override_existing = True):
'''
Abstract method for the action train.
:param reader: the reader to use for this action. Alternatively, you
@ -288,7 +298,7 @@ class Context(AbstractContext):
:param override_existing: if the folder exists already override it
'''
config_content = self._generate_train_config(optimizer, reader, override_existing)
output = self._call_cntk(CNTK_TRAIN_CONFIG_FILENAME, config_content)
return self._call_cntk(CNTK_TRAIN_CONFIG_FILENAME, config_content)
def test(self, reader=None):
'''
@ -297,7 +307,7 @@ class Context(AbstractContext):
can attach a reader directly to the input node.
'''
config_content = self._generate_test_config(reader)
output = self._call_cntk(CNTK_TEST_CONFIG_FILENAME, config_content)
return self._call_cntk(CNTK_TEST_CONFIG_FILENAME, config_content)
def predict(self, reader=None):
'''
@ -308,7 +318,7 @@ class Context(AbstractContext):
Returns the predicted output
'''
config_content = self._generate_predict_config(reader)
output = self._call_cntk(CNTK_PREDICT_CONFIG_FILENAME, config_content)
return self._call_cntk(CNTK_PREDICT_CONFIG_FILENAME, config_content)
'''
Regular expression to parse the shape information of the nodes out of

Просмотреть файл

@ -29,27 +29,23 @@ if (__name__ == "__main__"):
hidden_dim=200
training_filename=os.path.join("Data", "Train-28x28.txt")
test_filename=os.path.join("Data", "Test-28x28.txt")
features = Input(feat_dim, var_name='features')
features.attach_uci_fast_reader(training_filename, 1)
feat_scale = Constant(0.00390625)
feats_scaled = Scale(feat_scale, features)
labels = Input(label_dim, tag='label', var_name='labels')
labels.attach_uci_fast_reader(training_filename, 0, True, 1, os.path.join("Data", "labelsmap.txt"))
h1 = dnn_sigmoid_layer(feat_dim, hidden_dim, feats_scaled, 1)
out = dnn_layer(hidden_dim, label_dim, h1, 1)
out.tag = 'output'
ec = CrossEntropyWithSoftmax(labels, out)
ec.tag = 'criterion'
# Build the reader
r = UCIFastReader(filename=training_filename)
# Add the input node to the reader
r.add_input(features, 1, feat_dim)
r.add_input(labels, 0, 1, label_dim, os.path.join("Data", "labelsmap.txt"))
ec.tag = 'criterion'
# Build the optimizer (settings are scaled down)
my_sgd = SGD(epoch_size = 600, minibatch_size = 32, learning_ratesPerMB = 0.1, max_epochs = 5, momentum_per_mb = 0)
@ -57,7 +53,10 @@ if (__name__ == "__main__"):
# Create a context or re-use if already there
with Context('mnist_one_layer', root_node= ec, clean_up=False) as ctx:
# CNTK actions
ctx.train(my_sgd, r)
r["FileName"] = os.path.join("Data", "Test-28x28.txt")
ctx.test(r)
ctx.predict(r)
#ctx.train(my_sgd)
features.attach_uci_fast_reader(test_filename, 1)
labels.attach_uci_fast_reader(test_filename, 0, True, 1, os.path.join("Data", "labelsmap.txt"))
ctx.predict()
ctx.test()
ctx.predict()

Просмотреть файл

@ -147,15 +147,15 @@ class ComputationNode(object):
inputs_param = [p_value]
input_nodes_vars = []
for p_value in inputs_param:
if p_value in unrolled_nodes:
for pv in inputs_param:
if pv in unrolled_nodes:
# we have seen this node already, so just retrieve its
# name
child_var = unrolled_nodes[p_value]
child_var = unrolled_nodes[pv]
else:
child_var, node_counter, child_desc = p_value._to_config_recursively(
child_var, node_counter, child_desc = pv._to_config_recursively(
desc, unrolled_nodes, inputs, readers, node_counter)
unrolled_nodes[p_value] = child_var
unrolled_nodes[pv] = child_var
input_nodes_vars.append(child_var)
param_variable_names.append(_tuple_to_cntk_shape(input_nodes_vars))
@ -164,7 +164,7 @@ class ComputationNode(object):
self._param_to_brainscript(p_name, p_value))
if self.reader:
readers.append(self.reader)
readers.add(self.reader)
if self._is_input():
inputs.add(self)
@ -188,7 +188,7 @@ class ComputationNode(object):
'''
unrolled_nodes = {}
inputs=set()
readers=[]
readers=set()
var_name, node_counter, desc = self._to_config_recursively(
desc=[],
unrolled_nodes=unrolled_nodes,
@ -198,13 +198,14 @@ class ComputationNode(object):
return var_name, node_counter, desc, len(inputs)>0, readers
def _dedupe_readers(self, readers):
import copy
readers_map = {}
for r in readers:
filename = r['FileName']
if filename in readers_map:
readers_map[filename].inputs_def.extend(r.inputs_def)
else:
readers_map[filename] = r
readers_map[filename] = copy.deepcopy(r)
return [r for r in readers_map.values()]