Integrating LogReg into test framework; allow for multiple root nodes
This commit is contained in:
Родитель
24802a4b75
Коммит
1aefbad8b7
|
@ -9,7 +9,7 @@ import shutil as sh
|
|||
from cntk.graph import ComputationNode
|
||||
from cntk.ops.cntk1 import NewReshape
|
||||
from cntk.utils import CNTK_EXECUTABLE_PATH, MODEL_INDENTATION
|
||||
from .utils import cntk_to_numpy_shape
|
||||
from .utils import cntk_to_numpy_shape, dedupe_readers
|
||||
|
||||
CNTK_TEMPLATE_DIR = os.path.join(os.path.dirname(__file__), "templates")
|
||||
CNTK_TRAIN_TEMPLATE_PATH = os.path.join(
|
||||
|
@ -58,7 +58,7 @@ class AbstractContext(object, metaclass=ABCMeta):
|
|||
def __init__(self, name,
|
||||
graph=None,
|
||||
device_id=-1,
|
||||
root_node=None,
|
||||
root_nodes=None,
|
||||
clean_up=True,
|
||||
node_unit_test=False):
|
||||
'''
|
||||
|
@ -67,7 +67,7 @@ class AbstractContext(object, metaclass=ABCMeta):
|
|||
:param name: context name
|
||||
:param graph: the computational graph to be used for training, testing and prediction
|
||||
:param device_id: whether to use CPU or a specific GPU. -1 for CPU larger values
|
||||
:param root_node: the top node of the graph
|
||||
:param root_nodes: list of top nodes of the graph or single node itself
|
||||
:param clean_up: whether the temporary directory should be removed when the context is left
|
||||
are the GPUs indices.
|
||||
:param node_unit_test: set to True if you want to output the gradient of a node (backward pass)
|
||||
|
@ -90,7 +90,10 @@ class AbstractContext(object, metaclass=ABCMeta):
|
|||
self.device_id = device_id
|
||||
self.clean_up = clean_up
|
||||
self.input_nodes = set()
|
||||
self.root_node = root_node
|
||||
if root_nodes is None:
|
||||
self.root_nodes = None
|
||||
else:
|
||||
self.root_nodes = root_nodes if isinstance(root_nodes, list) else [root_nodes]
|
||||
self.node_unit_test= node_unit_test
|
||||
|
||||
def __enter__(self):
|
||||
|
@ -104,6 +107,42 @@ class AbstractContext(object, metaclass=ABCMeta):
|
|||
if self.clean_up:
|
||||
sh.rmtree(self.directory)
|
||||
|
||||
def _generate_config(self, root_nodes=None):
|
||||
'''
|
||||
Helper function to create a configuration incorporating all root nodes
|
||||
'''
|
||||
has_inputs = False
|
||||
|
||||
desc = []
|
||||
inputs = set()
|
||||
readers = set()
|
||||
unrolled_nodes = {}
|
||||
node_counter = 0
|
||||
dep_inputs = tuple()
|
||||
reconciled_cache = {}
|
||||
|
||||
if root_nodes is None:
|
||||
root_nodes = self.root_nodes
|
||||
elif not isinstance(root_nodes, list):
|
||||
root_nodes = [root_nodes]
|
||||
|
||||
for root_node in root_nodes:
|
||||
var_name, node_counter, _desc, _has_inputs, _readers, _dep_inputs = \
|
||||
root_node._to_config(desc,
|
||||
unrolled_nodes,
|
||||
inputs,
|
||||
readers,
|
||||
dep_inputs,
|
||||
node_counter, reconciled_cache)
|
||||
|
||||
has_inputs |= _has_inputs
|
||||
readers |= _readers
|
||||
dep_inputs += _dep_inputs
|
||||
|
||||
description = "\n".join(desc)
|
||||
|
||||
return description, has_inputs, dedupe_readers(readers)
|
||||
|
||||
def _generate_train_config(self, optimizer, reader, override_existing):
|
||||
'''
|
||||
Generates the configuration file for the train action.
|
||||
|
@ -124,7 +163,7 @@ class AbstractContext(object, metaclass=ABCMeta):
|
|||
|
||||
tmpl = open(CNTK_TRAIN_TEMPLATE_PATH, "r").read()
|
||||
model_filename = os.path.join(model_dir, self.name)
|
||||
description, has_inputs, readers = self.root_node.to_config()
|
||||
description, has_inputs, readers = self._generate_config()
|
||||
if reader:
|
||||
readers.append(reader)
|
||||
|
||||
|
@ -148,7 +187,7 @@ class AbstractContext(object, metaclass=ABCMeta):
|
|||
if reader:
|
||||
reader_config = reader.generate_config()
|
||||
else:
|
||||
description, has_inputs, readers = self.root_node.to_config()
|
||||
description, has_inputs, readers = self._generate_config()
|
||||
reader_config = '\n'.join(r.generate_config() for r in readers)
|
||||
|
||||
tmpl_dict = {
|
||||
|
@ -172,7 +211,7 @@ class AbstractContext(object, metaclass=ABCMeta):
|
|||
if reader:
|
||||
reader_config = reader.generate_config()
|
||||
else:
|
||||
description, has_inputs, readers = self.root_node.to_config()
|
||||
description, has_inputs, readers = self._generate_config()
|
||||
reader_config = '\n'.join(r.generate_config() for r in readers)
|
||||
|
||||
tmpl_dict = {
|
||||
|
@ -183,17 +222,17 @@ class AbstractContext(object, metaclass=ABCMeta):
|
|||
}
|
||||
return tmpl % tmpl_dict
|
||||
|
||||
def _generate_eval_config(self, root_node, reader):
|
||||
def _generate_eval_config(self, root_nodes, reader):
|
||||
'''
|
||||
Generates the configuration file for write action.
|
||||
:param root_node: the node to evaluate.
|
||||
:param root_nodes: the node to evaluate.
|
||||
:param reader: the reader used to load the data, None if the network does not have input
|
||||
'''
|
||||
model_description, has_input, readers = root_node.to_config()
|
||||
description, has_inputs, readers = self._generate_config(root_nodes)
|
||||
if reader:
|
||||
readers.append(reader)
|
||||
|
||||
if not has_input and not readers:
|
||||
if not has_inputs and not readers:
|
||||
# add dummy input to keep CNTK happy
|
||||
# TODO relieve this requirement on CNTK side
|
||||
data = [[1, 2], [3, 4]]
|
||||
|
@ -203,7 +242,7 @@ class AbstractContext(object, metaclass=ABCMeta):
|
|||
from .ops.cntk1 import Input
|
||||
dummy_input_node = Input(2, var_name='dummy_node')
|
||||
reader.add_input(dummy_input_node, 0, 2)
|
||||
model_description += "\n" + " "*MODEL_INDENTATION + "dummy_node = Input(2, tag='output')"
|
||||
description += "\n" + " "*MODEL_INDENTATION + "dummy_node = Input(2, tag='output')"
|
||||
readers.append(reader)
|
||||
|
||||
tmpl = open(CNTK_EVAL_TEMPLATE_PATH, "r").read()
|
||||
|
@ -212,7 +251,7 @@ class AbstractContext(object, metaclass=ABCMeta):
|
|||
'DevideId': self.device_id,
|
||||
'NodeUnitTest': self.node_unit_test,
|
||||
'OutputFile': output_filename,
|
||||
'ModelDescription': model_description,
|
||||
'ModelDescription': description,
|
||||
'Reader': '\n'.join(r.generate_config() for r in readers),
|
||||
}
|
||||
return tmpl % tmpl_dict
|
||||
|
@ -317,7 +356,10 @@ class Context(AbstractContext):
|
|||
can attach a reader directly to the input node.
|
||||
'''
|
||||
config_content = self._generate_test_config(reader)
|
||||
return self._call_cntk(CNTK_TEST_CONFIG_FILENAME, config_content)
|
||||
output = self._call_cntk(CNTK_TEST_CONFIG_FILENAME, config_content)
|
||||
|
||||
return Context._parse_test_result(output)
|
||||
|
||||
|
||||
def predict(self, reader=None):
|
||||
'''
|
||||
|
@ -422,6 +464,41 @@ class Context(AbstractContext):
|
|||
|
||||
return list_of_tensors
|
||||
|
||||
TEST_RESULT_REGEX = re.compile('(?P<name>[^:]+): [^=]+ = (?P<number>[0-9.]+)')
|
||||
|
||||
@staticmethod
|
||||
def _parse_test_result(output):
|
||||
result = {}
|
||||
|
||||
PREAMPLE = 'Final Results: Minibatch[1-1]: '
|
||||
for line in output.splitlines():
|
||||
|
||||
if not line.startswith(PREAMPLE):
|
||||
continue
|
||||
|
||||
line = line[len(PREAMPLE):]
|
||||
|
||||
if not line.startswith('SamplesSeen = '):
|
||||
raise ValueError('expected SamplesSeen but got "%s"'%line)
|
||||
|
||||
line = line[len('SamplesSeen = '):]
|
||||
number_ends = line.index(' ')
|
||||
result['SamplesSeen'] = int(line[:number_ends])
|
||||
line = line[number_ends:]
|
||||
|
||||
perplexity_idx = line.index('Perplexity = ')
|
||||
result['Perplexity'] = float(line[perplexity_idx+len('Perplexity = '):])
|
||||
|
||||
line = line[:perplexity_idx]
|
||||
|
||||
mo = Context.TEST_RESULT_REGEX.match(line)
|
||||
while mo:
|
||||
result[mo.group('name').strip()] = float(mo.group('number').strip())
|
||||
line = line[mo.span()[1]:]
|
||||
mo = Context.TEST_RESULT_REGEX.match(line)
|
||||
|
||||
return result
|
||||
|
||||
def _calc_expected_shape_and_size(self, node, data, shapes):
|
||||
'''
|
||||
Calculates the expected shape and size from the CNTK output and the
|
||||
|
|
|
@ -0,0 +1,500 @@
|
|||
3.854499 4.163941 1
|
||||
1.058121 1.204858 0
|
||||
1.870621 1.284107 0
|
||||
1.134650 1.651822 0
|
||||
5.420541 4.557660 1
|
||||
6.042731 3.375708 1
|
||||
5.667109 2.811728 1
|
||||
0.232070 1.814821 0
|
||||
-0.647150 -1.612478 0
|
||||
2.626172 5.321667 1
|
||||
1.359751 2.056849 0
|
||||
3.534476 6.011925 1
|
||||
4.871508 2.245406 1
|
||||
4.977201 6.092787 1
|
||||
1.597508 2.110568 0
|
||||
2.099170 0.073616 0
|
||||
0.638281 -0.171881 0
|
||||
4.606747 4.092115 1
|
||||
5.168790 4.673153 1
|
||||
5.084637 4.435160 1
|
||||
3.379607 2.765107 1
|
||||
3.992242 2.799751 1
|
||||
1.807728 0.205914 0
|
||||
1.946180 0.303569 0
|
||||
0.218267 1.301271 0
|
||||
4.932840 2.117177 1
|
||||
3.739489 2.458558 1
|
||||
1.597743 -2.192362 0
|
||||
3.582005 3.350572 1
|
||||
3.930642 5.733507 1
|
||||
5.747863 3.739415 1
|
||||
-0.631374 2.314482 0
|
||||
0.866484 0.363432 0
|
||||
0.293501 0.347385 0
|
||||
4.544393 4.699040 1
|
||||
-0.242005 0.926520 0
|
||||
3.637198 5.238140 1
|
||||
-0.269463 1.525586 0
|
||||
0.682529 -0.703649 0
|
||||
3.562643 -0.126556 0
|
||||
2.671530 3.729066 1
|
||||
4.034716 3.458366 1
|
||||
5.401503 3.117191 1
|
||||
1.157177 1.183186 0
|
||||
0.778963 1.394348 0
|
||||
4.599715 2.297663 1
|
||||
4.532568 4.568362 1
|
||||
1.785478 -0.213185 0
|
||||
4.617391 4.230360 1
|
||||
5.672957 3.668370 1
|
||||
4.267738 5.390780 1
|
||||
0.707751 2.955391 0
|
||||
0.791275 1.654795 0
|
||||
1.760541 0.976920 0
|
||||
4.543920 2.222765 1
|
||||
4.515881 6.199021 1
|
||||
3.645005 3.611395 1
|
||||
0.965049 1.737265 0
|
||||
-1.779455 1.595554 0
|
||||
-0.484797 -0.559924 0
|
||||
2.944180 4.429239 1
|
||||
3.326649 4.412622 1
|
||||
4.275101 2.143945 1
|
||||
1.173035 0.641844 0
|
||||
4.003884 3.176954 1
|
||||
1.960240 -0.244709 0
|
||||
0.320283 2.115552 0
|
||||
2.303185 3.047043 1
|
||||
0.993086 0.074009 0
|
||||
5.599144 3.857344 1
|
||||
5.325894 3.931000 1
|
||||
2.840053 4.781688 1
|
||||
4.142453 3.405830 1
|
||||
1.084043 1.589581 0
|
||||
2.795705 2.319276 1
|
||||
1.980552 0.717780 0
|
||||
1.875956 -0.571905 0
|
||||
2.013802 1.694811 0
|
||||
4.690795 2.183334 1
|
||||
4.321816 1.876459 1
|
||||
4.088717 4.394346 1
|
||||
4.991936 4.299770 1
|
||||
2.592315 4.783210 1
|
||||
0.703270 2.541733 0
|
||||
0.467768 -0.007592 0
|
||||
1.694096 -0.570847 0
|
||||
2.255603 0.663395 0
|
||||
1.300394 1.518341 0
|
||||
4.354786 4.501928 1
|
||||
1.474162 0.603113 0
|
||||
1.340782 0.637653 0
|
||||
-0.351240 0.501893 0
|
||||
4.918587 5.366305 1
|
||||
2.242199 -0.916682 0
|
||||
-0.161858 0.448384 0
|
||||
1.659615 1.524191 0
|
||||
3.072670 1.703225 0
|
||||
0.003256 -0.306702 0
|
||||
-1.792094 1.193539 0
|
||||
7.200298 3.962190 1
|
||||
4.220305 4.190289 1
|
||||
4.096599 3.264797 1
|
||||
-0.674145 0.751491 0
|
||||
3.215213 4.549768 1
|
||||
1.522988 3.311437 0
|
||||
4.393445 1.822070 1
|
||||
1.991048 1.429309 0
|
||||
4.741012 3.169984 1
|
||||
2.563678 1.798587 0
|
||||
3.310656 3.600789 1
|
||||
0.559119 -0.193984 0
|
||||
3.182626 3.279566 1
|
||||
0.145061 1.428861 0
|
||||
5.748625 2.766672 1
|
||||
1.612338 -0.441931 0
|
||||
0.521950 0.355267 0
|
||||
4.284910 3.874950 1
|
||||
4.911425 3.054658 1
|
||||
2.946163 0.502614 0
|
||||
4.381390 2.600999 1
|
||||
0.585791 -0.528432 0
|
||||
1.329802 -0.076910 0
|
||||
0.860040 1.153562 0
|
||||
0.930515 -0.257435 0
|
||||
2.775174 0.751338 0
|
||||
2.429059 0.615483 0
|
||||
2.546002 1.132210 0
|
||||
5.059000 3.423829 1
|
||||
1.303533 0.013015 0
|
||||
2.160149 -0.400779 0
|
||||
5.038046 3.027673 1
|
||||
4.583471 5.379319 1
|
||||
5.608845 2.082021 1
|
||||
3.406426 3.326734 1
|
||||
4.267102 3.866177 1
|
||||
1.799669 0.489094 0
|
||||
1.807634 2.029468 0
|
||||
1.536463 1.053052 0
|
||||
5.653295 3.369125 1
|
||||
2.493326 0.794542 0
|
||||
1.528977 0.961929 0
|
||||
1.973016 0.696162 0
|
||||
2.283974 0.198255 0
|
||||
5.227293 4.395268 1
|
||||
5.302484 4.021613 1
|
||||
6.223076 4.537934 1
|
||||
1.460204 -1.055539 0
|
||||
2.985097 4.228990 1
|
||||
1.685054 0.499576 0
|
||||
0.521659 0.510605 0
|
||||
1.891089 1.284388 0
|
||||
4.620926 3.662371 1
|
||||
1.613905 -0.770152 0
|
||||
6.007418 4.755721 1
|
||||
0.798078 -0.304557 0
|
||||
5.242706 2.099872 1
|
||||
1.518268 -0.858963 0
|
||||
3.733642 4.244483 1
|
||||
0.970367 -1.534686 0
|
||||
1.334952 2.250191 0
|
||||
2.252214 3.343515 1
|
||||
3.982213 4.457969 1
|
||||
5.086620 3.180442 1
|
||||
0.005277 0.197319 0
|
||||
2.999128 2.909942 1
|
||||
2.412666 2.046286 0
|
||||
2.044537 3.416533 1
|
||||
2.650439 3.372171 1
|
||||
2.480446 1.327368 0
|
||||
4.824915 5.603495 1
|
||||
0.759204 0.531043 0
|
||||
1.965476 1.372763 0
|
||||
1.000248 1.208139 0
|
||||
1.979980 -0.446807 0
|
||||
0.528053 1.178535 0
|
||||
5.442396 3.969797 1
|
||||
-0.145691 1.375993 0
|
||||
1.336725 -0.006089 0
|
||||
5.291797 3.250537 1
|
||||
4.286453 1.117735 1
|
||||
-0.928654 -0.925485 0
|
||||
3.332391 2.603963 1
|
||||
3.215562 4.756808 1
|
||||
1.610967 0.830856 0
|
||||
2.174433 3.501271 1
|
||||
4.848584 4.251824 1
|
||||
0.810184 1.152021 0
|
||||
4.873924 4.517936 1
|
||||
1.915303 1.649095 0
|
||||
1.623343 -0.081105 0
|
||||
1.944076 0.482732 0
|
||||
2.442956 1.254540 0
|
||||
-1.002581 1.265333 0
|
||||
0.959354 0.678516 0
|
||||
-0.478621 2.502554 0
|
||||
3.357642 2.993470 1
|
||||
5.741979 2.958477 1
|
||||
4.474261 3.260622 1
|
||||
3.587932 4.572091 1
|
||||
1.274866 0.695311 0
|
||||
4.557162 4.754880 1
|
||||
0.557867 0.280893 0
|
||||
1.832047 -2.162059 0
|
||||
3.904049 5.257427 1
|
||||
3.225019 3.845294 1
|
||||
4.451218 4.125344 1
|
||||
3.138143 2.869685 1
|
||||
4.451703 3.430654 1
|
||||
0.124060 1.422203 0
|
||||
4.692774 5.156611 1
|
||||
0.735314 0.375099 0
|
||||
0.727577 1.158726 0
|
||||
0.643469 0.283426 0
|
||||
5.126834 1.929468 1
|
||||
-0.172361 2.982370 0
|
||||
3.957745 1.561874 1
|
||||
5.563733 3.417080 1
|
||||
5.181533 1.465063 1
|
||||
5.843654 5.040710 1
|
||||
0.761570 0.171094 0
|
||||
3.163795 3.940869 1
|
||||
2.435362 1.047614 0
|
||||
2.524330 3.602348 1
|
||||
4.200838 3.267377 1
|
||||
4.249560 2.926280 1
|
||||
0.060257 0.295729 0
|
||||
1.528257 1.651867 0
|
||||
2.030978 1.566011 0
|
||||
4.065243 4.375190 1
|
||||
1.406204 0.238570 0
|
||||
1.229776 1.186559 0
|
||||
2.295681 1.883864 0
|
||||
3.966570 4.293142 1
|
||||
1.713323 0.534886 0
|
||||
0.772032 -0.096214 0
|
||||
3.392854 5.195064 1
|
||||
5.063653 2.749764 1
|
||||
1.410392 1.694554 0
|
||||
0.540269 0.376759 0
|
||||
4.103946 3.870140 1
|
||||
5.132739 3.079176 1
|
||||
2.524063 0.486934 0
|
||||
0.046403 1.452778 0
|
||||
1.705593 0.243750 0
|
||||
1.621902 0.203138 0
|
||||
-0.420733 0.589060 0
|
||||
2.887145 2.621849 1
|
||||
5.545509 4.473069 1
|
||||
0.326439 -0.162102 0
|
||||
0.906097 -0.018566 0
|
||||
3.398280 5.125843 1
|
||||
0.833088 -0.808535 0
|
||||
4.535285 4.133511 1
|
||||
1.781705 4.123651 1
|
||||
4.345894 3.355084 1
|
||||
4.770073 3.007432 1
|
||||
2.537267 3.813503 1
|
||||
0.994347 2.567949 0
|
||||
0.337262 -0.224479 0
|
||||
4.936596 3.107819 1
|
||||
2.177957 -0.544641 0
|
||||
3.434811 2.806362 1
|
||||
3.172973 4.378089 1
|
||||
4.015349 3.000845 1
|
||||
3.640748 3.917499 1
|
||||
5.432434 4.092587 1
|
||||
4.701984 4.063092 1
|
||||
3.978015 3.584431 1
|
||||
5.029923 2.346036 1
|
||||
4.939017 3.209084 1
|
||||
3.999592 2.747525 1
|
||||
5.233483 4.877698 1
|
||||
2.260049 1.023384 0
|
||||
-1.149943 1.257165 0
|
||||
-0.026270 0.468090 0
|
||||
5.155107 4.620842 1
|
||||
4.179414 4.807546 1
|
||||
2.560286 0.526253 0
|
||||
5.843334 1.439470 1
|
||||
4.417442 4.483117 1
|
||||
4.354138 4.496168 1
|
||||
0.873730 2.230023 0
|
||||
4.531298 4.944164 1
|
||||
2.010164 -0.358403 0
|
||||
1.165044 1.376602 0
|
||||
1.451538 -0.197779 0
|
||||
-1.751961 0.210820 0
|
||||
2.431281 3.878465 1
|
||||
3.311168 3.697618 1
|
||||
2.324742 -0.330745 0
|
||||
1.447031 1.028776 0
|
||||
0.711003 2.631227 0
|
||||
4.872934 3.406132 1
|
||||
2.419345 0.297983 0
|
||||
0.437814 2.851194 0
|
||||
3.105758 4.098041 1
|
||||
5.310168 3.519401 1
|
||||
1.218607 -1.505891 0
|
||||
6.053827 2.848790 1
|
||||
3.475758 3.352349 1
|
||||
0.911730 -0.213069 0
|
||||
1.255973 0.089677 0
|
||||
4.152711 3.871858 1
|
||||
3.003909 3.288998 1
|
||||
0.291281 1.124965 0
|
||||
2.155017 0.550642 0
|
||||
3.494102 0.710991 0
|
||||
4.376613 2.330150 1
|
||||
4.707851 6.179972 1
|
||||
0.614240 -0.243535 0
|
||||
1.130049 0.870765 0
|
||||
3.994615 2.855247 1
|
||||
1.556420 0.106179 0
|
||||
3.182309 5.121422 1
|
||||
2.315933 0.418897 0
|
||||
1.797904 0.633645 0
|
||||
4.012446 3.887718 1
|
||||
2.106849 3.776831 1
|
||||
4.477828 3.989422 1
|
||||
2.871290 4.610706 1
|
||||
5.317459 5.621137 1
|
||||
2.265963 -0.095395 0
|
||||
2.963642 2.804267 1
|
||||
5.859384 3.673343 1
|
||||
6.365340 3.541960 1
|
||||
1.450987 0.721751 0
|
||||
4.641593 2.436289 1
|
||||
-0.126649 0.101750 0
|
||||
1.835293 1.594895 0
|
||||
2.121195 0.152643 0
|
||||
1.881799 1.169974 0
|
||||
2.421852 -0.089441 0
|
||||
0.110206 -1.491046 0
|
||||
6.200556 4.284843 1
|
||||
3.545593 5.217408 1
|
||||
3.365187 2.790974 1
|
||||
6.493131 5.311132 1
|
||||
0.800791 0.229630 0
|
||||
4.975666 4.214251 1
|
||||
1.562586 0.181976 0
|
||||
0.899273 0.003180 0
|
||||
6.064242 3.482802 1
|
||||
1.777259 2.498596 0
|
||||
5.479965 5.168898 1
|
||||
4.671380 3.356556 1
|
||||
1.730588 0.417775 0
|
||||
2.463118 -0.305587 0
|
||||
3.967679 0.361350 0
|
||||
0.164925 -0.167591 0
|
||||
4.777002 3.088492 1
|
||||
2.049808 3.096552 0
|
||||
1.416130 -1.043606 0
|
||||
0.318913 -1.539956 0
|
||||
6.004351 2.521442 1
|
||||
2.969229 3.311301 1
|
||||
0.879291 0.094171 0
|
||||
5.290177 5.198102 1
|
||||
-0.305314 0.826116 0
|
||||
2.091880 -1.176581 0
|
||||
2.816867 2.875016 1
|
||||
0.486424 -1.055319 0
|
||||
3.012812 4.530291 1
|
||||
1.137009 1.323397 0
|
||||
0.088114 -0.353501 0
|
||||
1.174005 0.188025 0
|
||||
1.928114 1.398347 0
|
||||
0.128505 1.430034 0
|
||||
2.021187 0.577234 0
|
||||
1.361335 0.394605 0
|
||||
5.125811 4.221355 1
|
||||
0.260733 1.758422 0
|
||||
2.106970 0.305971 0
|
||||
3.675850 5.051226 1
|
||||
2.105405 0.240527 0
|
||||
3.072167 3.130910 1
|
||||
0.987479 0.036861 0
|
||||
-0.271382 0.094250 0
|
||||
4.703495 2.620398 1
|
||||
3.005831 2.220124 1
|
||||
5.072896 1.477152 1
|
||||
4.443991 3.679157 1
|
||||
0.845034 0.419956 0
|
||||
4.698964 3.109439 1
|
||||
1.766144 0.595496 0
|
||||
2.046076 0.433007 0
|
||||
0.874663 1.010155 0
|
||||
4.939031 5.340021 1
|
||||
3.881158 3.072467 1
|
||||
2.928763 4.160337 1
|
||||
5.582289 4.805588 1
|
||||
3.180992 3.459563 1
|
||||
-0.486820 -0.074926 0
|
||||
4.091057 2.402846 1
|
||||
4.915464 4.543850 1
|
||||
1.492434 0.588755 0
|
||||
2.594011 0.332043 0
|
||||
0.317571 -0.525159 0
|
||||
3.936029 4.312181 1
|
||||
1.918811 -0.659594 0
|
||||
2.657582 0.028525 0
|
||||
4.637282 3.562483 1
|
||||
-0.097472 1.250080 0
|
||||
1.340281 -1.399129 0
|
||||
4.330372 3.140502 1
|
||||
4.358103 3.760854 1
|
||||
3.897352 4.806873 1
|
||||
4.962704 4.692459 1
|
||||
1.667918 -0.134096 0
|
||||
4.929650 1.727842 1
|
||||
2.434315 3.000448 1
|
||||
1.179167 1.894836 0
|
||||
0.190498 0.655592 0
|
||||
3.408802 4.843020 1
|
||||
4.497565 3.844998 1
|
||||
-0.501596 1.561013 0
|
||||
4.158981 4.875362 1
|
||||
4.017462 4.655003 1
|
||||
3.319263 3.462037 1
|
||||
2.635572 1.022114 0
|
||||
2.638164 5.051437 1
|
||||
4.875001 3.592322 1
|
||||
-0.276607 0.800369 0
|
||||
4.351591 3.321136 1
|
||||
3.699848 3.317014 1
|
||||
4.947319 4.252134 1
|
||||
4.146336 2.162761 1
|
||||
5.231704 5.477804 1
|
||||
3.302101 3.994218 1
|
||||
-0.249349 2.069960 0
|
||||
4.705134 3.921461 1
|
||||
4.652980 4.287917 1
|
||||
3.937259 -0.334385 0
|
||||
3.257619 2.758094 1
|
||||
0.994191 3.135344 0
|
||||
4.649768 2.123305 1
|
||||
1.634135 0.241517 0
|
||||
1.682542 2.057739 1
|
||||
5.163117 4.467304 1
|
||||
4.638594 4.141250 1
|
||||
1.392605 0.635603 0
|
||||
4.319784 2.965064 1
|
||||
1.872466 1.566002 0
|
||||
4.230714 5.179026 1
|
||||
2.635294 3.470599 1
|
||||
0.988464 0.943613 0
|
||||
0.897546 0.129141 0
|
||||
3.370731 2.019838 0
|
||||
1.424812 0.081647 0
|
||||
5.961444 3.372419 1
|
||||
2.839070 0.926229 0
|
||||
0.279132 1.607793 0
|
||||
5.351031 3.693640 1
|
||||
2.637437 1.951445 0
|
||||
-0.179258 0.349339 0
|
||||
3.246295 1.013459 0
|
||||
5.839643 4.556761 1
|
||||
1.435225 0.937185 0
|
||||
0.500440 0.348246 0
|
||||
4.948782 4.994416 1
|
||||
0.810541 0.456830 0
|
||||
5.098827 4.142789 1
|
||||
2.365307 0.729496 0
|
||||
-0.117730 0.891913 0
|
||||
0.485735 0.513485 0
|
||||
0.680270 1.486851 0
|
||||
1.143053 0.227480 0
|
||||
6.615446 4.561501 1
|
||||
1.016051 1.862106 0
|
||||
0.668177 -0.212610 0
|
||||
2.906047 2.415627 1
|
||||
5.576097 5.068683 1
|
||||
1.315063 -0.040980 0
|
||||
5.375285 3.306877 1
|
||||
4.549934 3.805014 1
|
||||
1.189238 0.661279 0
|
||||
4.156567 3.280736 1
|
||||
2.061355 1.090958 0
|
||||
4.499387 3.640263 1
|
||||
3.503883 1.015591 0
|
||||
0.390200 -1.037188 0
|
||||
2.922873 4.696711 1
|
||||
1.803928 3.846808 1
|
||||
0.907921 -2.139287 0
|
||||
1.640739 0.592793 0
|
||||
5.108193 3.194757 1
|
||||
4.297873 4.034234 1
|
||||
4.832678 4.073469 1
|
||||
4.391764 3.557895 1
|
||||
2.006343 0.836557 0
|
||||
0.351400 1.534742 0
|
||||
4.933823 2.937944 1
|
||||
3.926482 2.073712 1
|
||||
5.382385 4.818642 1
|
||||
4.739010 3.213326 1
|
||||
0.026227 0.177150 0
|
||||
5.001353 3.300961 1
|
||||
5.022782 2.921902 1
|
||||
4.225051 4.534986 1
|
||||
3.745148 -0.169000 0
|
||||
5.891838 2.817417 1
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -4,27 +4,47 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', '..'))
|
|||
|
||||
from cntk import *
|
||||
|
||||
if (__name__ == "__main__"):
|
||||
X = Input(2)
|
||||
X.attach_uci_fast_reader("Train-3Classes.txt", 0)
|
||||
cur_dir = os.path.dirname(__file__)
|
||||
|
||||
# Using data from https://github.com/Microsoft/CNTK/wiki/Tutorial
|
||||
train_file = os.path.join(cur_dir, "Train-3Classes.txt")
|
||||
test_file = os.path.join(cur_dir, "Test-3Classes.txt")
|
||||
mapping_file = os.path.join(cur_dir, "SimpleMapping-3Classes.txt")
|
||||
|
||||
def train_eval_logreg(criterion_name, eval_name):
|
||||
X = Input(2)
|
||||
y = Input(3)
|
||||
y.attach_uci_fast_reader(
|
||||
"Train-3Classes.txt", 2, True, 1, "SimpleMapping-3Classes.txt")
|
||||
|
||||
W = LearnableParameter(3, 2)
|
||||
b = LearnableParameter(3, 1)
|
||||
|
||||
out = Times(W, X) + b
|
||||
out.tag = 'output'
|
||||
ce = CrossEntropyWithSoftmax(y, out)
|
||||
ce = CrossEntropyWithSoftmax(y, out, var_name=criterion_name)
|
||||
ce.tag = 'criterion'
|
||||
eval = SquareError(y, out, var_name=eval_name)
|
||||
eval.tag = 'eval'
|
||||
|
||||
my_sgd = SGD(
|
||||
epoch_size=0, minibatch_size=25, learning_ratesPerMB=0.1, max_epochs=3)
|
||||
|
||||
with Context('demo', root_node=ce, clean_up=False) as ctx:
|
||||
ctx.train(my_sgd, None)
|
||||
with Context('demo', root_nodes=[ce,eval], clean_up=False) as ctx:
|
||||
X.attach_uci_fast_reader(train_file, 0)
|
||||
y.attach_uci_fast_reader(train_file, 2, True, 1, mapping_file)
|
||||
ctx.train(my_sgd)
|
||||
|
||||
X.attach_uci_fast_reader(test_file, 0)
|
||||
y.attach_uci_fast_reader(test_file, 2, True, 1, mapping_file)
|
||||
result = ctx.test()
|
||||
|
||||
result = ctx.eval(out)
|
||||
print(result.argmax(axis=1))
|
||||
return result
|
||||
|
||||
def test_logreg():
|
||||
result = train_eval_logreg('crit_node', 'eval_node')
|
||||
assert result['SamplesSeen'] == 500
|
||||
assert result['Perplexity'] == 1.2216067
|
||||
assert result['eval_node'] == 13.779223
|
||||
assert result['crit_node'] == 0.20016696
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(train_eval_logreg())
|
||||
|
|
|
@ -9,7 +9,7 @@ class sparse(object):
|
|||
return hasattr(obj, 'todense')
|
||||
|
||||
from .utils import MODEL_INDENTATION
|
||||
from .utils import numpy_to_cntk_shape
|
||||
from .utils import numpy_to_cntk_shape, dedupe_readers
|
||||
|
||||
def _tuple_to_cntk_shape(shape):
|
||||
return ':'.join(str(v) for v in shape)
|
||||
|
@ -153,9 +153,8 @@ class ComputationNode(object):
|
|||
return is_loop_node and p_name == 'input' and isinstance(p_value, str)
|
||||
|
||||
def _to_config_recursively(self, desc, unrolled_nodes, inputs,
|
||||
readers, dep_inputs=None, node_counter=0):
|
||||
if dep_inputs is None:
|
||||
dep_inputs = tuple()
|
||||
readers, dep_inputs, node_counter,
|
||||
reconciled_cache):
|
||||
|
||||
param_variable_names = []
|
||||
# In case we have multiple unreconciled inputs, we will reconcile each
|
||||
|
@ -186,7 +185,7 @@ class ComputationNode(object):
|
|||
else:
|
||||
child_var, node_counter, child_desc, child_dep_inputs = pv._to_config_recursively(
|
||||
desc, unrolled_nodes, inputs, readers,
|
||||
dep_inputs, node_counter)
|
||||
dep_inputs, node_counter, reconciled_cache)
|
||||
|
||||
unrolled_nodes[pv] = child_var, dep_inputs
|
||||
|
||||
|
@ -200,10 +199,16 @@ class ComputationNode(object):
|
|||
first_unreconciled_input = pv
|
||||
|
||||
else:
|
||||
pv = ReconcileMBLayout(pv, first_unreconciled_input)
|
||||
child_var, node_counter, child_desc, dep_inputs = pv._to_config_recursively(
|
||||
desc, unrolled_nodes, inputs, readers,
|
||||
dep_inputs, node_counter)
|
||||
if (pv, first_unreconciled_input) in reconciled_cache:
|
||||
child_var, dep_inputs = reconciled_cache[(pv, first_unreconciled_input)]
|
||||
else:
|
||||
unrec_pv = pv
|
||||
pv = ReconcileMBLayout(unrec_pv, first_unreconciled_input)
|
||||
child_var, node_counter, child_desc, dep_inputs = pv._to_config_recursively(
|
||||
desc, unrolled_nodes, inputs, readers,
|
||||
dep_inputs, node_counter,
|
||||
reconciled_cache)
|
||||
reconciled_cache[(unrec_pv, first_unreconciled_input)] = pv.var_name, dep_inputs
|
||||
|
||||
unrolled_nodes[pv] = child_var, dep_inputs
|
||||
|
||||
|
@ -246,41 +251,38 @@ class ComputationNode(object):
|
|||
|
||||
return self.var_name, node_counter, desc, dep_inputs
|
||||
|
||||
def _to_config(self):
|
||||
def _to_config(self, description, unrolled_nodes, inputs, readers,
|
||||
dep_inputs, node_counter, reconciled_cache):
|
||||
'''
|
||||
Helper method to generate the CNTK configuration for this node.
|
||||
'''
|
||||
unrolled_nodes = {}
|
||||
inputs = set()
|
||||
readers = set()
|
||||
|
||||
var_name, node_counter, desc, dep_inputs = self._to_config_recursively(
|
||||
desc=[],
|
||||
description,
|
||||
unrolled_nodes=unrolled_nodes,
|
||||
inputs=inputs,
|
||||
readers=readers)
|
||||
readers=readers,
|
||||
dep_inputs=dep_inputs,
|
||||
node_counter=node_counter,
|
||||
reconciled_cache=reconciled_cache)
|
||||
|
||||
return var_name, node_counter, desc, len(inputs) > 0, readers
|
||||
|
||||
def _dedupe_readers(self, readers):
|
||||
import copy
|
||||
readers_map = {}
|
||||
for r in readers:
|
||||
filename = r['FileName']
|
||||
if filename in readers_map:
|
||||
readers_map[filename].inputs_def.extend(r.inputs_def)
|
||||
else:
|
||||
readers_map[filename] = copy.deepcopy(r)
|
||||
|
||||
return [r for r in readers_map.values()]
|
||||
return var_name, node_counter, desc, len(inputs) > 0, readers, dep_inputs
|
||||
|
||||
def to_config(self):
|
||||
'''
|
||||
Generate CNTK configuration for this node including the configuration
|
||||
for all dependent child nodes.
|
||||
'''
|
||||
var_name, node_counter, desc, has_inputs, readers = self._to_config()
|
||||
var_name, node_counter, desc, has_inputs, readers, dep_inputs = \
|
||||
self._to_config(description=[],
|
||||
unrolled_nodes={},
|
||||
inputs=set(),
|
||||
readers=set(),
|
||||
dep_inputs=tuple(),
|
||||
node_counter=0,
|
||||
reconciled_cache={})
|
||||
|
||||
return "\n".join(desc), has_inputs, self._dedupe_readers(readers)
|
||||
return "\n".join(desc), has_inputs, dedupe_readers(readers)
|
||||
|
||||
|
||||
class InputComputationNodeBase(ComputationNode, metaclass=ABCMeta):
|
||||
|
|
|
@ -55,7 +55,7 @@ Validating --> v5 = Plus (v3, v4) : [3 {1} x *], [3 x 1 {1,3}] -> [3 x 1 {1,3} x
|
|||
|
||||
assert Context._parse_shapes_from_output(output) == expected
|
||||
|
||||
def test_parse_result_output_1():
|
||||
def test_parse_eval_result_output_1():
|
||||
output = '''\
|
||||
0 |w.shape 1 1
|
||||
0 |w 60.000000
|
||||
|
@ -68,3 +68,14 @@ def test_parse_result_output_1():
|
|||
for res, exp in zip(list_of_tensors, expected):
|
||||
assert np.allclose(res, np.asarray(exp))
|
||||
|
||||
|
||||
def test_parse_test_result_output():
|
||||
output = '''\
|
||||
Final Results: Minibatch[1-1]: SamplesSeen = 500 v8: SquareError/Sample = 13.779223 v7: CrossEntropyWithSoftmax/Sample = 0.20016696 Perplexity = 1.2216067 '''
|
||||
result = Context._parse_test_result(output)
|
||||
|
||||
assert result['SamplesSeen'] == 500
|
||||
assert result['Perplexity'] == 1.2216067
|
||||
assert result['v8'] == 13.779223
|
||||
assert result['v7'] == 0.20016696
|
||||
assert len(result) == 4
|
||||
|
|
|
@ -40,3 +40,16 @@ def cntk_to_numpy_shape(shape):
|
|||
shape = (1,)
|
||||
|
||||
return shape
|
||||
|
||||
def dedupe_readers(readers):
|
||||
import copy
|
||||
readers_map = {}
|
||||
for r in readers:
|
||||
filename = r['FileName']
|
||||
if filename in readers_map:
|
||||
readers_map[filename].inputs_def.extend(r.inputs_def)
|
||||
else:
|
||||
readers_map[filename] = copy.deepcopy(r)
|
||||
|
||||
return [r for r in readers_map.values()]
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче