This commit is contained in:
jeanfad 2016-05-13 00:11:49 +02:00
Родитель b519e5de40
Коммит eab9bfdc66
4 изменённых файлов: 13 добавлений и 22 удалений

Просмотреть файл

@ -44,7 +44,7 @@ CNTK_OUTPUT_FILENAME = "out"
_CONTEXT = {}
def get_context(handle='default'):
def get_context(handle):
# TODO: we need more sanity in the model handling here
if handle not in _CONTEXT:
_CONTEXT[handle] = LocalExecutionContext(handle)
@ -249,7 +249,7 @@ class AbstractContext(with_metaclass(ABCMeta, object)):
tmpl_dict = {
'ActionName': action_name,
'ModelDescription': description,
'Reader': input_map._to_config_description(),
'Reader': input_map._to_config_description(self.directory),
'SGD': training_params._to_config_description(),
}
@ -282,7 +282,7 @@ class AbstractContext(with_metaclass(ABCMeta, object)):
tmpl_dict = {
'ActionName': action_name,
'Reader': input_map._to_config_description(),
'Reader': input_map._to_config_description(self.directory),
}
return "{0}\n{1}".format(g_params, tmpl % tmpl_dict)
@ -310,7 +310,7 @@ class AbstractContext(with_metaclass(ABCMeta, object)):
tmpl_dict = {
'ActionName': action_name,
'OutputFile': self.output_filename_base,
'Reader': input_map._to_config_description(),
'Reader': input_map._to_config_description(self.directory),
}
return "{0}\n{1}".format(g_params, tmpl % tmpl_dict)
@ -355,7 +355,7 @@ class AbstractContext(with_metaclass(ABCMeta, object)):
'NodeUnitTest': node_unit_test,
'OutputFile': self.output_filename_base,
'ModelDescription': description,
'Reader': input_map._to_config_description(),
'Reader': input_map._to_config_description(self.directory),
}
return "{0}\n{1}".format(g_params, tmpl % tmpl_dict)
@ -423,7 +423,7 @@ class LocalExecutionContext(AbstractContext):
if not output:
raise ValueError('no output returned')
return output
'''

Просмотреть файл

@ -493,7 +493,7 @@ class InputMap(object):
def is_empty(self):
return not self.has_mapped() and not self.has_unmapped()
def _to_config_description(self):
def _to_config_description(self, directory=None):
if self.reader is None:
if not self.unmapped_nodes:
# No inputs in the graph
@ -504,8 +504,10 @@ class InputMap(object):
from .context import get_context
from .utils import get_temp_filename
filename = get_temp_filename(get_context().directory)
if not directory:
filename = get_temp_filename(get_context().directory)
else:
filename = get_temp_filename(directory)
if len(self.node_map) > 0:
raise ValueError('you cannot have inputs initialized with '+
'NumPy arrays together with inputs that are ' +

Просмотреть файл

@ -65,14 +65,3 @@ def test_serialize_unmapped_node(tmpdir):
with open(tmpfile, 'r') as f:
assert f.read() == expected
def test_cntk_eval():
import cntk
import cntk.ops
result = cntk.eval(cntk.ops.floor([.4]))
np.allclose(result, [0])
result = cntk.eval(cntk.ops.floor([[.4]]))
np.allclose(result, [[0]])

Просмотреть файл

@ -25,12 +25,12 @@ def eval(node):
NumPy array containing the result
"""
from cntk.context import get_context
from cntk.context import get_new_context
from cntk.ops import input_numpy, constant
from cntk.graph import ComputationNode
# call a helper method to get a context
ctx = get_context()
ctx = get_new_context()
first = True
# The params are passed as arryas, e.g. plus([1,2], [3,4]), and we need to