remove reader requirement from test and predict actions and update mnist example
This commit is contained in:
Родитель
fefbbf9e64
Коммит
6b1438edc7
|
@ -103,12 +103,6 @@ class AbstractContext(object, metaclass=ABCMeta):
|
|||
import shutil
|
||||
shutil.rmtree(self.directory)
|
||||
|
||||
def to_config(self):
|
||||
'''
|
||||
Generates the CNTK configuration for the root node.
|
||||
'''
|
||||
return self.root_node.to_config()
|
||||
|
||||
def _generate_train_config(self, optimizer, reader, override_existing):
|
||||
'''
|
||||
Generates the configuration file for the train action.
|
||||
|
@ -128,7 +122,7 @@ class AbstractContext(object, metaclass=ABCMeta):
|
|||
|
||||
tmpl = open(CNTK_TRAIN_TEMPLATE_PATH, "r").read()
|
||||
model_filename = os.path.join(model_dir, self.name)
|
||||
description, has_inputs, readers = self.to_config()
|
||||
description, has_inputs, readers = self.root_node.to_config()
|
||||
if reader:
|
||||
readers.append(reader)
|
||||
|
||||
|
@ -147,10 +141,18 @@ class AbstractContext(object, metaclass=ABCMeta):
|
|||
'''
|
||||
tmpl = open(CNTK_TEST_TEMPLATE_PATH, "r").read()
|
||||
model_filename = os.path.join(self.directory, 'Models', self.name)
|
||||
|
||||
# if no reader is passed generate the reader from the network
|
||||
if reader:
|
||||
reader_config = reader.generate_config()
|
||||
else:
|
||||
description, has_inputs, readers = self.root_node.to_config()
|
||||
reader_config = '\n'.join(r.generate_config() for r in readers)
|
||||
|
||||
tmpl_dict = {
|
||||
'DevideId': self.device_id,
|
||||
'ModelPath': model_filename,
|
||||
'Reader': reader.generate_config(),
|
||||
'Reader': reader_config,
|
||||
}
|
||||
return tmpl % tmpl_dict
|
||||
|
||||
|
@ -162,11 +164,19 @@ class AbstractContext(object, metaclass=ABCMeta):
|
|||
tmpl = open(CNTK_PREDICT_TEMPLATE_PATH, "r").read()
|
||||
model_filename = os.path.join(self.directory, 'Models', self.name)
|
||||
output_filename_base = os.path.join(self.directory, 'Outputs', self.name)
|
||||
|
||||
# if no reader is passed generate the reader from the network
|
||||
if reader:
|
||||
reader_config = reader.generate_config()
|
||||
else:
|
||||
description, has_inputs, readers = self.root_node.to_config()
|
||||
reader_config = '\n'.join(r.generate_config() for r in readers)
|
||||
|
||||
tmpl_dict = {
|
||||
'DevideId': self.device_id,
|
||||
'ModelPath': model_filename,
|
||||
'PredictOutputFile': output_filename_base,
|
||||
'Reader': reader.generate_config(),
|
||||
'Reader': reader_config,
|
||||
}
|
||||
return tmpl % tmpl_dict
|
||||
|
||||
|
@ -204,7 +214,7 @@ class AbstractContext(object, metaclass=ABCMeta):
|
|||
return tmpl % tmpl_dict
|
||||
|
||||
@abstractmethod
|
||||
def train(self, reader=None):
|
||||
def train(self, optimizer, reader=None, override_existing = True):
|
||||
'''
|
||||
Abstract method for the action train.
|
||||
:param reader: the reader to use for this action. Alternatively, you
|
||||
|
@ -288,7 +298,7 @@ class Context(AbstractContext):
|
|||
:param override_existing: if the folder exists already override it
|
||||
'''
|
||||
config_content = self._generate_train_config(optimizer, reader, override_existing)
|
||||
output = self._call_cntk(CNTK_TRAIN_CONFIG_FILENAME, config_content)
|
||||
return self._call_cntk(CNTK_TRAIN_CONFIG_FILENAME, config_content)
|
||||
|
||||
def test(self, reader=None):
|
||||
'''
|
||||
|
@ -297,7 +307,7 @@ class Context(AbstractContext):
|
|||
can attach a reader directly to the input node.
|
||||
'''
|
||||
config_content = self._generate_test_config(reader)
|
||||
output = self._call_cntk(CNTK_TEST_CONFIG_FILENAME, config_content)
|
||||
return self._call_cntk(CNTK_TEST_CONFIG_FILENAME, config_content)
|
||||
|
||||
def predict(self, reader=None):
|
||||
'''
|
||||
|
@ -308,7 +318,7 @@ class Context(AbstractContext):
|
|||
Returns the predicted output
|
||||
'''
|
||||
config_content = self._generate_predict_config(reader)
|
||||
output = self._call_cntk(CNTK_PREDICT_CONFIG_FILENAME, config_content)
|
||||
return self._call_cntk(CNTK_PREDICT_CONFIG_FILENAME, config_content)
|
||||
|
||||
'''
|
||||
Regular expression to parse the shape information of the nodes out of
|
||||
|
|
|
@ -29,27 +29,23 @@ if (__name__ == "__main__"):
|
|||
hidden_dim=200
|
||||
|
||||
training_filename=os.path.join("Data", "Train-28x28.txt")
|
||||
test_filename=os.path.join("Data", "Test-28x28.txt")
|
||||
|
||||
features = Input(feat_dim, var_name='features')
|
||||
features.attach_uci_fast_reader(training_filename, 1)
|
||||
|
||||
feat_scale = Constant(0.00390625)
|
||||
feats_scaled = Scale(feat_scale, features)
|
||||
|
||||
labels = Input(label_dim, tag='label', var_name='labels')
|
||||
labels.attach_uci_fast_reader(training_filename, 0, True, 1, os.path.join("Data", "labelsmap.txt"))
|
||||
|
||||
h1 = dnn_sigmoid_layer(feat_dim, hidden_dim, feats_scaled, 1)
|
||||
out = dnn_layer(hidden_dim, label_dim, h1, 1)
|
||||
out.tag = 'output'
|
||||
|
||||
ec = CrossEntropyWithSoftmax(labels, out)
|
||||
ec.tag = 'criterion'
|
||||
|
||||
# Build the reader
|
||||
r = UCIFastReader(filename=training_filename)
|
||||
|
||||
# Add the input node to the reader
|
||||
r.add_input(features, 1, feat_dim)
|
||||
r.add_input(labels, 0, 1, label_dim, os.path.join("Data", "labelsmap.txt"))
|
||||
ec.tag = 'criterion'
|
||||
|
||||
# Build the optimizer (settings are scaled down)
|
||||
my_sgd = SGD(epoch_size = 600, minibatch_size = 32, learning_ratesPerMB = 0.1, max_epochs = 5, momentum_per_mb = 0)
|
||||
|
@ -57,7 +53,10 @@ if (__name__ == "__main__"):
|
|||
# Create a context or re-use if already there
|
||||
with Context('mnist_one_layer', root_node= ec, clean_up=False) as ctx:
|
||||
# CNTK actions
|
||||
ctx.train(my_sgd, r)
|
||||
r["FileName"] = os.path.join("Data", "Test-28x28.txt")
|
||||
ctx.test(r)
|
||||
ctx.predict(r)
|
||||
#ctx.train(my_sgd)
|
||||
features.attach_uci_fast_reader(test_filename, 1)
|
||||
labels.attach_uci_fast_reader(test_filename, 0, True, 1, os.path.join("Data", "labelsmap.txt"))
|
||||
ctx.predict()
|
||||
ctx.test()
|
||||
ctx.predict()
|
||||
|
||||
|
|
|
@ -147,15 +147,15 @@ class ComputationNode(object):
|
|||
inputs_param = [p_value]
|
||||
|
||||
input_nodes_vars = []
|
||||
for p_value in inputs_param:
|
||||
if p_value in unrolled_nodes:
|
||||
for pv in inputs_param:
|
||||
if pv in unrolled_nodes:
|
||||
# we have seen this node already, so just retrieve its
|
||||
# name
|
||||
child_var = unrolled_nodes[p_value]
|
||||
child_var = unrolled_nodes[pv]
|
||||
else:
|
||||
child_var, node_counter, child_desc = p_value._to_config_recursively(
|
||||
child_var, node_counter, child_desc = pv._to_config_recursively(
|
||||
desc, unrolled_nodes, inputs, readers, node_counter)
|
||||
unrolled_nodes[p_value] = child_var
|
||||
unrolled_nodes[pv] = child_var
|
||||
input_nodes_vars.append(child_var)
|
||||
|
||||
param_variable_names.append(_tuple_to_cntk_shape(input_nodes_vars))
|
||||
|
@ -164,7 +164,7 @@ class ComputationNode(object):
|
|||
self._param_to_brainscript(p_name, p_value))
|
||||
|
||||
if self.reader:
|
||||
readers.append(self.reader)
|
||||
readers.add(self.reader)
|
||||
|
||||
if self._is_input():
|
||||
inputs.add(self)
|
||||
|
@ -188,7 +188,7 @@ class ComputationNode(object):
|
|||
'''
|
||||
unrolled_nodes = {}
|
||||
inputs=set()
|
||||
readers=[]
|
||||
readers=set()
|
||||
var_name, node_counter, desc = self._to_config_recursively(
|
||||
desc=[],
|
||||
unrolled_nodes=unrolled_nodes,
|
||||
|
@ -198,13 +198,14 @@ class ComputationNode(object):
|
|||
return var_name, node_counter, desc, len(inputs)>0, readers
|
||||
|
||||
def _dedupe_readers(self, readers):
|
||||
import copy
|
||||
readers_map = {}
|
||||
for r in readers:
|
||||
filename = r['FileName']
|
||||
if filename in readers_map:
|
||||
readers_map[filename].inputs_def.extend(r.inputs_def)
|
||||
else:
|
||||
readers_map[filename] = r
|
||||
readers_map[filename] = copy.deepcopy(r)
|
||||
|
||||
return [r for r in readers_map.values()]
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче