Integrating LogReg into test framework; allow for multiple root nodes

Willi Richert 2016-04-04 18:28:10 +02:00
Parent 24802a4b75
Commit 1aefbad8b7
7 changed files with 1678 additions and 1055 deletions
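In outline, the change lets a Context drive several root nodes at once and makes test() hand back the parsed final metrics instead of raw CNTK output. The following is a condensed, non-authoritative sketch of the new usage, stitched together from the logistic-regression demo further down in this diff (file names are the demo's data files, and the metric keys come from the demo's test_logreg assertions):

from cntk import *

# Two-dimensional feature input and three-class label input, as in the demo below.
X = Input(2)
y = Input(3)

W = LearnableParameter(3, 2)
b = LearnableParameter(3, 1)
out = Times(W, X) + b
out.tag = 'output'

ce = CrossEntropyWithSoftmax(y, out, var_name='crit_node')
ce.tag = 'criterion'
ev = SquareError(y, out, var_name='eval_node')
ev.tag = 'eval'

my_sgd = SGD(epoch_size=0, minibatch_size=25, learning_ratesPerMB=0.1, max_epochs=3)

# root_nodes now accepts a list of top nodes (a single node still works).
with Context('demo', root_nodes=[ce, ev], clean_up=False) as ctx:
    X.attach_uci_fast_reader('Train-3Classes.txt', 0)
    y.attach_uci_fast_reader('Train-3Classes.txt', 2, True, 1, 'SimpleMapping-3Classes.txt')
    ctx.train(my_sgd)

    X.attach_uci_fast_reader('Test-3Classes.txt', 0)
    y.attach_uci_fast_reader('Test-3Classes.txt', 2, True, 1, 'SimpleMapping-3Classes.txt')
    result = ctx.test()
    # result is now a dict, e.g. {'SamplesSeen': 500, 'Perplexity': 1.2216067,
    #                             'crit_node': 0.20016696, 'eval_node': 13.779223}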

View file

@ -9,7 +9,7 @@ import shutil as sh
from cntk.graph import ComputationNode
from cntk.ops.cntk1 import NewReshape
from cntk.utils import CNTK_EXECUTABLE_PATH, MODEL_INDENTATION
from .utils import cntk_to_numpy_shape
from .utils import cntk_to_numpy_shape, dedupe_readers
CNTK_TEMPLATE_DIR = os.path.join(os.path.dirname(__file__), "templates")
CNTK_TRAIN_TEMPLATE_PATH = os.path.join(
@ -58,7 +58,7 @@ class AbstractContext(object, metaclass=ABCMeta):
def __init__(self, name,
graph=None,
device_id=-1,
root_node=None,
root_nodes=None,
clean_up=True,
node_unit_test=False):
'''
@ -67,7 +67,7 @@ class AbstractContext(object, metaclass=ABCMeta):
:param name: context name
:param graph: the computational graph to be used for training, testing and prediction
:param device_id: whether to use CPU or a specific GPU. -1 for CPU larger values
:param root_node: the top node of the graph
:param root_nodes: list of the top nodes of the graph, or a single node
:param clean_up: whether the temporary directory should be removed when the context is left
are the GPUs indices.
:param node_unit_test: set to True if you want to output the gradient of a node (backward pass)
@ -90,7 +90,10 @@ class AbstractContext(object, metaclass=ABCMeta):
self.device_id = device_id
self.clean_up = clean_up
self.input_nodes = set()
self.root_node = root_node
if root_nodes is None:
self.root_nodes = None
else:
self.root_nodes = root_nodes if isinstance(root_nodes, list) else [root_nodes]
self.node_unit_test= node_unit_test
def __enter__(self):
@ -104,6 +107,42 @@ class AbstractContext(object, metaclass=ABCMeta):
if self.clean_up:
sh.rmtree(self.directory)
def _generate_config(self, root_nodes=None):
'''
Helper function to create a configuration incorporating all root nodes
'''
has_inputs = False
desc = []
inputs = set()
readers = set()
unrolled_nodes = {}
node_counter = 0
dep_inputs = tuple()
reconciled_cache = {}
if root_nodes is None:
root_nodes = self.root_nodes
elif not isinstance(root_nodes, list):
root_nodes = [root_nodes]
for root_node in root_nodes:
var_name, node_counter, _desc, _has_inputs, _readers, _dep_inputs = \
root_node._to_config(desc,
unrolled_nodes,
inputs,
readers,
dep_inputs,
node_counter, reconciled_cache)
has_inputs |= _has_inputs
readers |= _readers
dep_inputs += _dep_inputs
description = "\n".join(desc)
return description, has_inputs, dedupe_readers(readers)
def _generate_train_config(self, optimizer, reader, override_existing):
'''
Generates the configuration file for the train action.
@ -124,7 +163,7 @@ class AbstractContext(object, metaclass=ABCMeta):
tmpl = open(CNTK_TRAIN_TEMPLATE_PATH, "r").read()
model_filename = os.path.join(model_dir, self.name)
description, has_inputs, readers = self.root_node.to_config()
description, has_inputs, readers = self._generate_config()
if reader:
readers.append(reader)
@ -148,7 +187,7 @@ class AbstractContext(object, metaclass=ABCMeta):
if reader:
reader_config = reader.generate_config()
else:
description, has_inputs, readers = self.root_node.to_config()
description, has_inputs, readers = self._generate_config()
reader_config = '\n'.join(r.generate_config() for r in readers)
tmpl_dict = {
@ -172,7 +211,7 @@ class AbstractContext(object, metaclass=ABCMeta):
if reader:
reader_config = reader.generate_config()
else:
description, has_inputs, readers = self.root_node.to_config()
description, has_inputs, readers = self._generate_config()
reader_config = '\n'.join(r.generate_config() for r in readers)
tmpl_dict = {
@ -183,17 +222,17 @@ class AbstractContext(object, metaclass=ABCMeta):
}
return tmpl % tmpl_dict
def _generate_eval_config(self, root_node, reader):
def _generate_eval_config(self, root_nodes, reader):
'''
Generates the configuration file for write action.
:param root_node: the node to evaluate.
:param root_nodes: the node or list of nodes to evaluate.
:param reader: the reader used to load the data, None if the network does not have input
'''
model_description, has_input, readers = root_node.to_config()
description, has_inputs, readers = self._generate_config(root_nodes)
if reader:
readers.append(reader)
if not has_input and not readers:
if not has_inputs and not readers:
# add dummy input to keep CNTK happy
# TODO relieve this requirement on CNTK side
data = [[1, 2], [3, 4]]
@ -203,7 +242,7 @@ class AbstractContext(object, metaclass=ABCMeta):
from .ops.cntk1 import Input
dummy_input_node = Input(2, var_name='dummy_node')
reader.add_input(dummy_input_node, 0, 2)
model_description += "\n" + " "*MODEL_INDENTATION + "dummy_node = Input(2, tag='output')"
description += "\n" + " "*MODEL_INDENTATION + "dummy_node = Input(2, tag='output')"
readers.append(reader)
tmpl = open(CNTK_EVAL_TEMPLATE_PATH, "r").read()
@ -212,7 +251,7 @@ class AbstractContext(object, metaclass=ABCMeta):
'DevideId': self.device_id,
'NodeUnitTest': self.node_unit_test,
'OutputFile': output_filename,
'ModelDescription': model_description,
'ModelDescription': description,
'Reader': '\n'.join(r.generate_config() for r in readers),
}
return tmpl % tmpl_dict
@ -317,7 +356,10 @@ class Context(AbstractContext):
can attach a reader directly to the input node.
'''
config_content = self._generate_test_config(reader)
return self._call_cntk(CNTK_TEST_CONFIG_FILENAME, config_content)
output = self._call_cntk(CNTK_TEST_CONFIG_FILENAME, config_content)
return Context._parse_test_result(output)
def predict(self, reader=None):
'''
@ -422,6 +464,41 @@ class Context(AbstractContext):
return list_of_tensors
TEST_RESULT_REGEX = re.compile('(?P<name>[^:]+): [^=]+ = (?P<number>[0-9.]+)')
@staticmethod
def _parse_test_result(output):
result = {}
PREAMPLE = 'Final Results: Minibatch[1-1]: '
for line in output.splitlines():
if not line.startswith(PREAMPLE):
continue
line = line[len(PREAMPLE):]
if not line.startswith('SamplesSeen = '):
raise ValueError('expected SamplesSeen but got "%s"'%line)
line = line[len('SamplesSeen = '):]
number_ends = line.index(' ')
result['SamplesSeen'] = int(line[:number_ends])
line = line[number_ends:]
perplexity_idx = line.index('Perplexity = ')
result['Perplexity'] = float(line[perplexity_idx+len('Perplexity = '):])
line = line[:perplexity_idx]
mo = Context.TEST_RESULT_REGEX.match(line)
while mo:
result[mo.group('name').strip()] = float(mo.group('number').strip())
line = line[mo.span()[1]:]
mo = Context.TEST_RESULT_REGEX.match(line)
return result
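To make the parsing concrete, the single 'Final Results' summary line that CNTK prints after a test pass (the same line exercised by the new unit test further down in this diff) is flattened into a plain dict; a minimal sketch:

output = ('Final Results: Minibatch[1-1]: SamplesSeen = 500 '
          'v8: SquareError/Sample = 13.779223 '
          'v7: CrossEntropyWithSoftmax/Sample = 0.20016696 '
          'Perplexity = 1.2216067 ')

result = Context._parse_test_result(output)
# {'SamplesSeen': 500, 'v8': 13.779223, 'v7': 0.20016696, 'Perplexity': 1.2216067}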
def _calc_expected_shape_and_size(self, node, data, shapes):
'''
Calculates the expected shape and size from the CNTK output and the

View file

@ -0,0 +1,500 @@
3.854499 4.163941 1
1.058121 1.204858 0
1.870621 1.284107 0
1.134650 1.651822 0
5.420541 4.557660 1
6.042731 3.375708 1
5.667109 2.811728 1
0.232070 1.814821 0
-0.647150 -1.612478 0
2.626172 5.321667 1
1.359751 2.056849 0
3.534476 6.011925 1
4.871508 2.245406 1
4.977201 6.092787 1
1.597508 2.110568 0
2.099170 0.073616 0
0.638281 -0.171881 0
4.606747 4.092115 1
5.168790 4.673153 1
5.084637 4.435160 1
3.379607 2.765107 1
3.992242 2.799751 1
1.807728 0.205914 0
1.946180 0.303569 0
0.218267 1.301271 0
4.932840 2.117177 1
3.739489 2.458558 1
1.597743 -2.192362 0
3.582005 3.350572 1
3.930642 5.733507 1
5.747863 3.739415 1
-0.631374 2.314482 0
0.866484 0.363432 0
0.293501 0.347385 0
4.544393 4.699040 1
-0.242005 0.926520 0
3.637198 5.238140 1
-0.269463 1.525586 0
0.682529 -0.703649 0
3.562643 -0.126556 0
2.671530 3.729066 1
4.034716 3.458366 1
5.401503 3.117191 1
1.157177 1.183186 0
0.778963 1.394348 0
4.599715 2.297663 1
4.532568 4.568362 1
1.785478 -0.213185 0
4.617391 4.230360 1
5.672957 3.668370 1
4.267738 5.390780 1
0.707751 2.955391 0
0.791275 1.654795 0
1.760541 0.976920 0
4.543920 2.222765 1
4.515881 6.199021 1
3.645005 3.611395 1
0.965049 1.737265 0
-1.779455 1.595554 0
-0.484797 -0.559924 0
2.944180 4.429239 1
3.326649 4.412622 1
4.275101 2.143945 1
1.173035 0.641844 0
4.003884 3.176954 1
1.960240 -0.244709 0
0.320283 2.115552 0
2.303185 3.047043 1
0.993086 0.074009 0
5.599144 3.857344 1
5.325894 3.931000 1
2.840053 4.781688 1
4.142453 3.405830 1
1.084043 1.589581 0
2.795705 2.319276 1
1.980552 0.717780 0
1.875956 -0.571905 0
2.013802 1.694811 0
4.690795 2.183334 1
4.321816 1.876459 1
4.088717 4.394346 1
4.991936 4.299770 1
2.592315 4.783210 1
0.703270 2.541733 0
0.467768 -0.007592 0
1.694096 -0.570847 0
2.255603 0.663395 0
1.300394 1.518341 0
4.354786 4.501928 1
1.474162 0.603113 0
1.340782 0.637653 0
-0.351240 0.501893 0
4.918587 5.366305 1
2.242199 -0.916682 0
-0.161858 0.448384 0
1.659615 1.524191 0
3.072670 1.703225 0
0.003256 -0.306702 0
-1.792094 1.193539 0
7.200298 3.962190 1
4.220305 4.190289 1
4.096599 3.264797 1
-0.674145 0.751491 0
3.215213 4.549768 1
1.522988 3.311437 0
4.393445 1.822070 1
1.991048 1.429309 0
4.741012 3.169984 1
2.563678 1.798587 0
3.310656 3.600789 1
0.559119 -0.193984 0
3.182626 3.279566 1
0.145061 1.428861 0
5.748625 2.766672 1
1.612338 -0.441931 0
0.521950 0.355267 0
4.284910 3.874950 1
4.911425 3.054658 1
2.946163 0.502614 0
4.381390 2.600999 1
0.585791 -0.528432 0
1.329802 -0.076910 0
0.860040 1.153562 0
0.930515 -0.257435 0
2.775174 0.751338 0
2.429059 0.615483 0
2.546002 1.132210 0
5.059000 3.423829 1
1.303533 0.013015 0
2.160149 -0.400779 0
5.038046 3.027673 1
4.583471 5.379319 1
5.608845 2.082021 1
3.406426 3.326734 1
4.267102 3.866177 1
1.799669 0.489094 0
1.807634 2.029468 0
1.536463 1.053052 0
5.653295 3.369125 1
2.493326 0.794542 0
1.528977 0.961929 0
1.973016 0.696162 0
2.283974 0.198255 0
5.227293 4.395268 1
5.302484 4.021613 1
6.223076 4.537934 1
1.460204 -1.055539 0
2.985097 4.228990 1
1.685054 0.499576 0
0.521659 0.510605 0
1.891089 1.284388 0
4.620926 3.662371 1
1.613905 -0.770152 0
6.007418 4.755721 1
0.798078 -0.304557 0
5.242706 2.099872 1
1.518268 -0.858963 0
3.733642 4.244483 1
0.970367 -1.534686 0
1.334952 2.250191 0
2.252214 3.343515 1
3.982213 4.457969 1
5.086620 3.180442 1
0.005277 0.197319 0
2.999128 2.909942 1
2.412666 2.046286 0
2.044537 3.416533 1
2.650439 3.372171 1
2.480446 1.327368 0
4.824915 5.603495 1
0.759204 0.531043 0
1.965476 1.372763 0
1.000248 1.208139 0
1.979980 -0.446807 0
0.528053 1.178535 0
5.442396 3.969797 1
-0.145691 1.375993 0
1.336725 -0.006089 0
5.291797 3.250537 1
4.286453 1.117735 1
-0.928654 -0.925485 0
3.332391 2.603963 1
3.215562 4.756808 1
1.610967 0.830856 0
2.174433 3.501271 1
4.848584 4.251824 1
0.810184 1.152021 0
4.873924 4.517936 1
1.915303 1.649095 0
1.623343 -0.081105 0
1.944076 0.482732 0
2.442956 1.254540 0
-1.002581 1.265333 0
0.959354 0.678516 0
-0.478621 2.502554 0
3.357642 2.993470 1
5.741979 2.958477 1
4.474261 3.260622 1
3.587932 4.572091 1
1.274866 0.695311 0
4.557162 4.754880 1
0.557867 0.280893 0
1.832047 -2.162059 0
3.904049 5.257427 1
3.225019 3.845294 1
4.451218 4.125344 1
3.138143 2.869685 1
4.451703 3.430654 1
0.124060 1.422203 0
4.692774 5.156611 1
0.735314 0.375099 0
0.727577 1.158726 0
0.643469 0.283426 0
5.126834 1.929468 1
-0.172361 2.982370 0
3.957745 1.561874 1
5.563733 3.417080 1
5.181533 1.465063 1
5.843654 5.040710 1
0.761570 0.171094 0
3.163795 3.940869 1
2.435362 1.047614 0
2.524330 3.602348 1
4.200838 3.267377 1
4.249560 2.926280 1
0.060257 0.295729 0
1.528257 1.651867 0
2.030978 1.566011 0
4.065243 4.375190 1
1.406204 0.238570 0
1.229776 1.186559 0
2.295681 1.883864 0
3.966570 4.293142 1
1.713323 0.534886 0
0.772032 -0.096214 0
3.392854 5.195064 1
5.063653 2.749764 1
1.410392 1.694554 0
0.540269 0.376759 0
4.103946 3.870140 1
5.132739 3.079176 1
2.524063 0.486934 0
0.046403 1.452778 0
1.705593 0.243750 0
1.621902 0.203138 0
-0.420733 0.589060 0
2.887145 2.621849 1
5.545509 4.473069 1
0.326439 -0.162102 0
0.906097 -0.018566 0
3.398280 5.125843 1
0.833088 -0.808535 0
4.535285 4.133511 1
1.781705 4.123651 1
4.345894 3.355084 1
4.770073 3.007432 1
2.537267 3.813503 1
0.994347 2.567949 0
0.337262 -0.224479 0
4.936596 3.107819 1
2.177957 -0.544641 0
3.434811 2.806362 1
3.172973 4.378089 1
4.015349 3.000845 1
3.640748 3.917499 1
5.432434 4.092587 1
4.701984 4.063092 1
3.978015 3.584431 1
5.029923 2.346036 1
4.939017 3.209084 1
3.999592 2.747525 1
5.233483 4.877698 1
2.260049 1.023384 0
-1.149943 1.257165 0
-0.026270 0.468090 0
5.155107 4.620842 1
4.179414 4.807546 1
2.560286 0.526253 0
5.843334 1.439470 1
4.417442 4.483117 1
4.354138 4.496168 1
0.873730 2.230023 0
4.531298 4.944164 1
2.010164 -0.358403 0
1.165044 1.376602 0
1.451538 -0.197779 0
-1.751961 0.210820 0
2.431281 3.878465 1
3.311168 3.697618 1
2.324742 -0.330745 0
1.447031 1.028776 0
0.711003 2.631227 0
4.872934 3.406132 1
2.419345 0.297983 0
0.437814 2.851194 0
3.105758 4.098041 1
5.310168 3.519401 1
1.218607 -1.505891 0
6.053827 2.848790 1
3.475758 3.352349 1
0.911730 -0.213069 0
1.255973 0.089677 0
4.152711 3.871858 1
3.003909 3.288998 1
0.291281 1.124965 0
2.155017 0.550642 0
3.494102 0.710991 0
4.376613 2.330150 1
4.707851 6.179972 1
0.614240 -0.243535 0
1.130049 0.870765 0
3.994615 2.855247 1
1.556420 0.106179 0
3.182309 5.121422 1
2.315933 0.418897 0
1.797904 0.633645 0
4.012446 3.887718 1
2.106849 3.776831 1
4.477828 3.989422 1
2.871290 4.610706 1
5.317459 5.621137 1
2.265963 -0.095395 0
2.963642 2.804267 1
5.859384 3.673343 1
6.365340 3.541960 1
1.450987 0.721751 0
4.641593 2.436289 1
-0.126649 0.101750 0
1.835293 1.594895 0
2.121195 0.152643 0
1.881799 1.169974 0
2.421852 -0.089441 0
0.110206 -1.491046 0
6.200556 4.284843 1
3.545593 5.217408 1
3.365187 2.790974 1
6.493131 5.311132 1
0.800791 0.229630 0
4.975666 4.214251 1
1.562586 0.181976 0
0.899273 0.003180 0
6.064242 3.482802 1
1.777259 2.498596 0
5.479965 5.168898 1
4.671380 3.356556 1
1.730588 0.417775 0
2.463118 -0.305587 0
3.967679 0.361350 0
0.164925 -0.167591 0
4.777002 3.088492 1
2.049808 3.096552 0
1.416130 -1.043606 0
0.318913 -1.539956 0
6.004351 2.521442 1
2.969229 3.311301 1
0.879291 0.094171 0
5.290177 5.198102 1
-0.305314 0.826116 0
2.091880 -1.176581 0
2.816867 2.875016 1
0.486424 -1.055319 0
3.012812 4.530291 1
1.137009 1.323397 0
0.088114 -0.353501 0
1.174005 0.188025 0
1.928114 1.398347 0
0.128505 1.430034 0
2.021187 0.577234 0
1.361335 0.394605 0
5.125811 4.221355 1
0.260733 1.758422 0
2.106970 0.305971 0
3.675850 5.051226 1
2.105405 0.240527 0
3.072167 3.130910 1
0.987479 0.036861 0
-0.271382 0.094250 0
4.703495 2.620398 1
3.005831 2.220124 1
5.072896 1.477152 1
4.443991 3.679157 1
0.845034 0.419956 0
4.698964 3.109439 1
1.766144 0.595496 0
2.046076 0.433007 0
0.874663 1.010155 0
4.939031 5.340021 1
3.881158 3.072467 1
2.928763 4.160337 1
5.582289 4.805588 1
3.180992 3.459563 1
-0.486820 -0.074926 0
4.091057 2.402846 1
4.915464 4.543850 1
1.492434 0.588755 0
2.594011 0.332043 0
0.317571 -0.525159 0
3.936029 4.312181 1
1.918811 -0.659594 0
2.657582 0.028525 0
4.637282 3.562483 1
-0.097472 1.250080 0
1.340281 -1.399129 0
4.330372 3.140502 1
4.358103 3.760854 1
3.897352 4.806873 1
4.962704 4.692459 1
1.667918 -0.134096 0
4.929650 1.727842 1
2.434315 3.000448 1
1.179167 1.894836 0
0.190498 0.655592 0
3.408802 4.843020 1
4.497565 3.844998 1
-0.501596 1.561013 0
4.158981 4.875362 1
4.017462 4.655003 1
3.319263 3.462037 1
2.635572 1.022114 0
2.638164 5.051437 1
4.875001 3.592322 1
-0.276607 0.800369 0
4.351591 3.321136 1
3.699848 3.317014 1
4.947319 4.252134 1
4.146336 2.162761 1
5.231704 5.477804 1
3.302101 3.994218 1
-0.249349 2.069960 0
4.705134 3.921461 1
4.652980 4.287917 1
3.937259 -0.334385 0
3.257619 2.758094 1
0.994191 3.135344 0
4.649768 2.123305 1
1.634135 0.241517 0
1.682542 2.057739 1
5.163117 4.467304 1
4.638594 4.141250 1
1.392605 0.635603 0
4.319784 2.965064 1
1.872466 1.566002 0
4.230714 5.179026 1
2.635294 3.470599 1
0.988464 0.943613 0
0.897546 0.129141 0
3.370731 2.019838 0
1.424812 0.081647 0
5.961444 3.372419 1
2.839070 0.926229 0
0.279132 1.607793 0
5.351031 3.693640 1
2.637437 1.951445 0
-0.179258 0.349339 0
3.246295 1.013459 0
5.839643 4.556761 1
1.435225 0.937185 0
0.500440 0.348246 0
4.948782 4.994416 1
0.810541 0.456830 0
5.098827 4.142789 1
2.365307 0.729496 0
-0.117730 0.891913 0
0.485735 0.513485 0
0.680270 1.486851 0
1.143053 0.227480 0
6.615446 4.561501 1
1.016051 1.862106 0
0.668177 -0.212610 0
2.906047 2.415627 1
5.576097 5.068683 1
1.315063 -0.040980 0
5.375285 3.306877 1
4.549934 3.805014 1
1.189238 0.661279 0
4.156567 3.280736 1
2.061355 1.090958 0
4.499387 3.640263 1
3.503883 1.015591 0
0.390200 -1.037188 0
2.922873 4.696711 1
1.803928 3.846808 1
0.907921 -2.139287 0
1.640739 0.592793 0
5.108193 3.194757 1
4.297873 4.034234 1
4.832678 4.073469 1
4.391764 3.557895 1
2.006343 0.836557 0
0.351400 1.534742 0
4.933823 2.937944 1
3.926482 2.073712 1
5.382385 4.818642 1
4.739010 3.213326 1
0.026227 0.177150 0
5.001353 3.300961 1
5.022782 2.921902 1
4.225051 4.534986 1
3.745148 -0.169000 0
5.891838 2.817417 1

File diff not shown because of its size.

View file

@ -4,27 +4,47 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', '..'))
from cntk import *
if (__name__ == "__main__"):
X = Input(2)
X.attach_uci_fast_reader("Train-3Classes.txt", 0)
cur_dir = os.path.dirname(__file__)
# Using data from https://github.com/Microsoft/CNTK/wiki/Tutorial
train_file = os.path.join(cur_dir, "Train-3Classes.txt")
test_file = os.path.join(cur_dir, "Test-3Classes.txt")
mapping_file = os.path.join(cur_dir, "SimpleMapping-3Classes.txt")
def train_eval_logreg(criterion_name, eval_name):
X = Input(2)
y = Input(3)
y.attach_uci_fast_reader(
"Train-3Classes.txt", 2, True, 1, "SimpleMapping-3Classes.txt")
W = LearnableParameter(3, 2)
b = LearnableParameter(3, 1)
out = Times(W, X) + b
out.tag = 'output'
ce = CrossEntropyWithSoftmax(y, out)
ce = CrossEntropyWithSoftmax(y, out, var_name=criterion_name)
ce.tag = 'criterion'
eval = SquareError(y, out, var_name=eval_name)
eval.tag = 'eval'
my_sgd = SGD(
epoch_size=0, minibatch_size=25, learning_ratesPerMB=0.1, max_epochs=3)
with Context('demo', root_node=ce, clean_up=False) as ctx:
ctx.train(my_sgd, None)
with Context('demo', root_nodes=[ce,eval], clean_up=False) as ctx:
X.attach_uci_fast_reader(train_file, 0)
y.attach_uci_fast_reader(train_file, 2, True, 1, mapping_file)
ctx.train(my_sgd)
X.attach_uci_fast_reader(test_file, 0)
y.attach_uci_fast_reader(test_file, 2, True, 1, mapping_file)
result = ctx.test()
result = ctx.eval(out)
print(result.argmax(axis=1))
return result
def test_logreg():
result = train_eval_logreg('crit_node', 'eval_node')
assert result['SamplesSeen'] == 500
assert result['Perplexity'] == 1.2216067
assert result['eval_node'] == 13.779223
assert result['crit_node'] == 0.20016696
if __name__ == "__main__":
print(train_eval_logreg())

View file

@ -9,7 +9,7 @@ class sparse(object):
return hasattr(obj, 'todense')
from .utils import MODEL_INDENTATION
from .utils import numpy_to_cntk_shape
from .utils import numpy_to_cntk_shape, dedupe_readers
def _tuple_to_cntk_shape(shape):
return ':'.join(str(v) for v in shape)
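As a quick illustration of the shape helper visible in the surrounding context (it is untouched by this commit): shapes are serialized into CNTK's colon-separated form.

assert _tuple_to_cntk_shape((3,)) == '3'
assert _tuple_to_cntk_shape((3, 1)) == '3:1'
assert _tuple_to_cntk_shape((2, 3, 4)) == '2:3:4'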
@ -153,9 +153,8 @@ class ComputationNode(object):
return is_loop_node and p_name == 'input' and isinstance(p_value, str)
def _to_config_recursively(self, desc, unrolled_nodes, inputs,
readers, dep_inputs=None, node_counter=0):
if dep_inputs is None:
dep_inputs = tuple()
readers, dep_inputs, node_counter,
reconciled_cache):
param_variable_names = []
# In case we have multiple unreconciled inputs, we will reconcile each
@ -186,7 +185,7 @@ class ComputationNode(object):
else:
child_var, node_counter, child_desc, child_dep_inputs = pv._to_config_recursively(
desc, unrolled_nodes, inputs, readers,
dep_inputs, node_counter)
dep_inputs, node_counter, reconciled_cache)
unrolled_nodes[pv] = child_var, dep_inputs
@ -200,10 +199,16 @@ class ComputationNode(object):
first_unreconciled_input = pv
else:
pv = ReconcileMBLayout(pv, first_unreconciled_input)
child_var, node_counter, child_desc, dep_inputs = pv._to_config_recursively(
desc, unrolled_nodes, inputs, readers,
dep_inputs, node_counter)
if (pv, first_unreconciled_input) in reconciled_cache:
child_var, dep_inputs = reconciled_cache[(pv, first_unreconciled_input)]
else:
unrec_pv = pv
pv = ReconcileMBLayout(unrec_pv, first_unreconciled_input)
child_var, node_counter, child_desc, dep_inputs = pv._to_config_recursively(
desc, unrolled_nodes, inputs, readers,
dep_inputs, node_counter,
reconciled_cache)
reconciled_cache[(unrec_pv, first_unreconciled_input)] = pv.var_name, dep_inputs
unrolled_nodes[pv] = child_var, dep_inputs
@ -246,41 +251,38 @@ class ComputationNode(object):
return self.var_name, node_counter, desc, dep_inputs
def _to_config(self):
def _to_config(self, description, unrolled_nodes, inputs, readers,
dep_inputs, node_counter, reconciled_cache):
'''
Helper method to generate the CNTK configuration for this node.
'''
unrolled_nodes = {}
inputs = set()
readers = set()
var_name, node_counter, desc, dep_inputs = self._to_config_recursively(
desc=[],
description,
unrolled_nodes=unrolled_nodes,
inputs=inputs,
readers=readers)
readers=readers,
dep_inputs=dep_inputs,
node_counter=node_counter,
reconciled_cache=reconciled_cache)
return var_name, node_counter, desc, len(inputs) > 0, readers
def _dedupe_readers(self, readers):
import copy
readers_map = {}
for r in readers:
filename = r['FileName']
if filename in readers_map:
readers_map[filename].inputs_def.extend(r.inputs_def)
else:
readers_map[filename] = copy.deepcopy(r)
return [r for r in readers_map.values()]
return var_name, node_counter, desc, len(inputs) > 0, readers, dep_inputs
def to_config(self):
'''
Generate CNTK configuration for this node including the configuration
for all dependent child nodes.
'''
var_name, node_counter, desc, has_inputs, readers = self._to_config()
var_name, node_counter, desc, has_inputs, readers, dep_inputs = \
self._to_config(description=[],
unrolled_nodes={},
inputs=set(),
readers=set(),
dep_inputs=tuple(),
node_counter=0,
reconciled_cache={})
return "\n".join(desc), has_inputs, self._dedupe_readers(readers)
return "\n".join(desc), has_inputs, dedupe_readers(readers)
class InputComputationNodeBase(ComputationNode, metaclass=ABCMeta):

View file

@ -55,7 +55,7 @@ Validating --> v5 = Plus (v3, v4) : [3 {1} x *], [3 x 1 {1,3}] -> [3 x 1 {1,3} x
assert Context._parse_shapes_from_output(output) == expected
def test_parse_result_output_1():
def test_parse_eval_result_output_1():
output = '''\
0 |w.shape 1 1
0 |w 60.000000
@ -68,3 +68,14 @@ def test_parse_result_output_1():
for res, exp in zip(list_of_tensors, expected):
assert np.allclose(res, np.asarray(exp))
def test_parse_test_result_output():
output = '''\
Final Results: Minibatch[1-1]: SamplesSeen = 500 v8: SquareError/Sample = 13.779223 v7: CrossEntropyWithSoftmax/Sample = 0.20016696 Perplexity = 1.2216067 '''
result = Context._parse_test_result(output)
assert result['SamplesSeen'] == 500
assert result['Perplexity'] == 1.2216067
assert result['v8'] == 13.779223
assert result['v7'] == 0.20016696
assert len(result) == 4

View file

@ -40,3 +40,16 @@ def cntk_to_numpy_shape(shape):
shape = (1,)
return shape
def dedupe_readers(readers):
import copy
readers_map = {}
for r in readers:
filename = r['FileName']
if filename in readers_map:
readers_map[filename].inputs_def.extend(r.inputs_def)
else:
readers_map[filename] = copy.deepcopy(r)
return [r for r in readers_map.values()]
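A self-contained sketch of the intended merge behaviour. The FakeReader class below is purely a stand-in that mimics the two things dedupe_readers touches (item access to 'FileName' and an inputs_def list); it is not part of the real reader API.

from cntk.utils import dedupe_readers

class FakeReader:
    # Stand-in for a reader object; only 'FileName' lookup and inputs_def matter here.
    def __init__(self, filename, inputs_def):
        self._params = {'FileName': filename}
        self.inputs_def = list(inputs_def)

    def __getitem__(self, key):
        return self._params[key]

features = FakeReader('Train-3Classes.txt', ['features'])
labels = FakeReader('Train-3Classes.txt', ['labels'])

merged = dedupe_readers([features, labels])
assert len(merged) == 1                                # one reader per file name
assert merged[0].inputs_def == ['features', 'labels']  # input definitions merged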