prepare context for grad test
This commit is contained in:
Родитель
80882eabe5
Коммит
242202444d
|
@ -58,7 +58,8 @@ class AbstractContext(object, metaclass=ABCMeta):
|
|||
graph=None,
|
||||
device_id=-1,
|
||||
root_node=None,
|
||||
clean_up=True):
|
||||
clean_up=True,
|
||||
node_unit_test=False):
|
||||
'''
|
||||
AbstractContext Constructer
|
||||
|
||||
|
@ -68,6 +69,7 @@ class AbstractContext(object, metaclass=ABCMeta):
|
|||
:param root_node: the top node of the graph
|
||||
:param clean_up: whether the temporary directory should be removed when the context is left
|
||||
are the GPUs indices.
|
||||
:param node_unit_test: set to True if you want to output the gradient of a node (backward pass)
|
||||
|
||||
'''
|
||||
if isinstance(name, str):
|
||||
|
@ -88,6 +90,7 @@ class AbstractContext(object, metaclass=ABCMeta):
|
|||
self.clean_up = clean_up
|
||||
self.input_nodes = set()
|
||||
self.root_node = root_node
|
||||
self.node_unit_test= node_unit_test
|
||||
|
||||
def __enter__(self):
|
||||
_CONTEXT[self.name] = self
|
||||
|
@ -206,6 +209,7 @@ class AbstractContext(object, metaclass=ABCMeta):
|
|||
output_filename = os.path.join(self.directory, CNTK_OUTPUT_FILENAME)
|
||||
tmpl_dict = {
|
||||
'DevideId': self.device_id,
|
||||
'NodeUnitTest': self.node_unit_test,
|
||||
'OutputFile': output_filename,
|
||||
'ModelDescription': model_description,
|
||||
'Reader': '\n'.join(r.generate_config() for r in readers),
|
||||
|
|
|
@ -6,6 +6,7 @@ deviceId=%(DevideId)s
|
|||
|
||||
Eval=[
|
||||
action="write"
|
||||
nodeUnitTest=%(NodeUnitTest)s
|
||||
run=BrainScriptNetworkBuilder
|
||||
|
||||
BrainScriptNetworkBuilder=[
|
||||
|
|
Загрузка…
Ссылка в новой задаче