get rid of default context
This commit is contained in:
Родитель
b519e5de40
Коммит
eab9bfdc66
|
@ -44,7 +44,7 @@ CNTK_OUTPUT_FILENAME = "out"
|
|||
_CONTEXT = {}
|
||||
|
||||
|
||||
def get_context(handle='default'):
|
||||
def get_context(handle):
|
||||
# TODO: we need more sanity in the model handling here
|
||||
if handle not in _CONTEXT:
|
||||
_CONTEXT[handle] = LocalExecutionContext(handle)
|
||||
|
@ -249,7 +249,7 @@ class AbstractContext(with_metaclass(ABCMeta, object)):
|
|||
tmpl_dict = {
|
||||
'ActionName': action_name,
|
||||
'ModelDescription': description,
|
||||
'Reader': input_map._to_config_description(),
|
||||
'Reader': input_map._to_config_description(self.directory),
|
||||
'SGD': training_params._to_config_description(),
|
||||
}
|
||||
|
||||
|
@ -282,7 +282,7 @@ class AbstractContext(with_metaclass(ABCMeta, object)):
|
|||
|
||||
tmpl_dict = {
|
||||
'ActionName': action_name,
|
||||
'Reader': input_map._to_config_description(),
|
||||
'Reader': input_map._to_config_description(self.directory),
|
||||
}
|
||||
return "{0}\n{1}".format(g_params, tmpl % tmpl_dict)
|
||||
|
||||
|
@ -310,7 +310,7 @@ class AbstractContext(with_metaclass(ABCMeta, object)):
|
|||
tmpl_dict = {
|
||||
'ActionName': action_name,
|
||||
'OutputFile': self.output_filename_base,
|
||||
'Reader': input_map._to_config_description(),
|
||||
'Reader': input_map._to_config_description(self.directory),
|
||||
}
|
||||
return "{0}\n{1}".format(g_params, tmpl % tmpl_dict)
|
||||
|
||||
|
@ -355,7 +355,7 @@ class AbstractContext(with_metaclass(ABCMeta, object)):
|
|||
'NodeUnitTest': node_unit_test,
|
||||
'OutputFile': self.output_filename_base,
|
||||
'ModelDescription': description,
|
||||
'Reader': input_map._to_config_description(),
|
||||
'Reader': input_map._to_config_description(self.directory),
|
||||
}
|
||||
return "{0}\n{1}".format(g_params, tmpl % tmpl_dict)
|
||||
|
||||
|
@ -423,7 +423,7 @@ class LocalExecutionContext(AbstractContext):
|
|||
|
||||
if not output:
|
||||
raise ValueError('no output returned')
|
||||
|
||||
|
||||
return output
|
||||
|
||||
'''
|
||||
|
|
|
@ -493,7 +493,7 @@ class InputMap(object):
|
|||
def is_empty(self):
|
||||
return not self.has_mapped() and not self.has_unmapped()
|
||||
|
||||
def _to_config_description(self):
|
||||
def _to_config_description(self, directory=None):
|
||||
if self.reader is None:
|
||||
if not self.unmapped_nodes:
|
||||
# No inputs in the graph
|
||||
|
@ -504,8 +504,10 @@ class InputMap(object):
|
|||
|
||||
from .context import get_context
|
||||
from .utils import get_temp_filename
|
||||
filename = get_temp_filename(get_context().directory)
|
||||
|
||||
if not directory:
|
||||
filename = get_temp_filename(get_context().directory)
|
||||
else:
|
||||
filename = get_temp_filename(directory)
|
||||
if len(self.node_map) > 0:
|
||||
raise ValueError('you cannot have inputs initialized with '+
|
||||
'NumPy arrays together with inputs that are ' +
|
||||
|
|
|
@ -65,14 +65,3 @@ def test_serialize_unmapped_node(tmpdir):
|
|||
|
||||
with open(tmpfile, 'r') as f:
|
||||
assert f.read() == expected
|
||||
|
||||
def test_cntk_eval():
|
||||
|
||||
import cntk
|
||||
import cntk.ops
|
||||
|
||||
result = cntk.eval(cntk.ops.floor([.4]))
|
||||
np.allclose(result, [0])
|
||||
|
||||
result = cntk.eval(cntk.ops.floor([[.4]]))
|
||||
np.allclose(result, [[0]])
|
||||
|
|
|
@ -25,12 +25,12 @@ def eval(node):
|
|||
NumPy array containing the result
|
||||
"""
|
||||
|
||||
from cntk.context import get_context
|
||||
from cntk.context import get_new_context
|
||||
from cntk.ops import input_numpy, constant
|
||||
from cntk.graph import ComputationNode
|
||||
|
||||
# call a helper method to get a context
|
||||
ctx = get_context()
|
||||
ctx = get_new_context()
|
||||
first = True
|
||||
|
||||
# The params are passed as arryas, e.g. plus([1,2], [3,4]), and we need to
|
||||
|
|
Загрузка…
Ссылка в новой задаче