This commit is contained in:
jeanfad 2016-03-29 22:32:20 +02:00
Родитель 80882eabe5
Коммит 242202444d
2 изменённых файлов: 6 добавлений и 1 удалений

Просмотреть файл

@ -58,7 +58,8 @@ class AbstractContext(object, metaclass=ABCMeta):
graph=None,
device_id=-1,
root_node=None,
clean_up=True):
clean_up=True,
node_unit_test=False):
'''
AbstractContext Constructer
@ -68,6 +69,7 @@ class AbstractContext(object, metaclass=ABCMeta):
:param root_node: the top node of the graph
:param clean_up: whether the temporary directory should be removed when the context is left
are the GPUs indices.
:param node_unit_test: set to True if you want to output the gradient of a node (backward pass)
'''
if isinstance(name, str):
@ -88,6 +90,7 @@ class AbstractContext(object, metaclass=ABCMeta):
self.clean_up = clean_up
self.input_nodes = set()
self.root_node = root_node
self.node_unit_test= node_unit_test
def __enter__(self):
_CONTEXT[self.name] = self
@ -206,6 +209,7 @@ class AbstractContext(object, metaclass=ABCMeta):
output_filename = os.path.join(self.directory, CNTK_OUTPUT_FILENAME)
tmpl_dict = {
'DevideId': self.device_id,
'NodeUnitTest': self.node_unit_test,
'OutputFile': output_filename,
'ModelDescription': model_description,
'Reader': '\n'.join(r.generate_config() for r in readers),

Просмотреть файл

@ -6,6 +6,7 @@ deviceId=%(DevideId)s
Eval=[
action="write"
nodeUnitTest=%(NodeUnitTest)s
run=BrainScriptNetworkBuilder
BrainScriptNetworkBuilder=[