Improving shape handling
Parent: 4089b37869
Commit: 90b124c93a
@@ -9,6 +9,7 @@ import shutil as sh
 from cntk.graph import ComputationNode
 from cntk.ops.cntk1 import NewReshape
 from cntk.utils import CNTK_EXECUTABLE_PATH, MODEL_INDENTATION
+from .utils import cntk_to_numpy_shape

 CNTK_TEMPLATE_DIR = os.path.join(os.path.dirname(__file__), "templates")
 CNTK_TRAIN_TEMPLATE_PATH = os.path.join(
@@ -276,21 +277,25 @@ class Context(AbstractContext):
         retrieve the node shapes.
         '''
         filename = os.path.join(self.directory, config_file_name)
-        with open(os.path.join(self.directory, filename), "w") as out:
+        with open(os.path.join(self.directory, filename), 'w') as out:
             out.write(config_content)

         try:
-            output = subprocess.check_output(
-                [CNTK_EXECUTABLE_PATH, "configFile=%s" % filename],
+            output_bytes = subprocess.check_output(
+                [CNTK_EXECUTABLE_PATH, 'configFile=%s' % filename],
                 stderr=subprocess.STDOUT)
+            output = output_bytes.decode('utf-8')
+            with open(os.path.join(self.directory, 'cntk.log'), 'w') as log:
+                log.write(output)

         except subprocess.CalledProcessError as e:
-            print(e.output.decode("utf-8"), file=open('error.txt', 'w'))
+            print(e.output.decode('utf-8'), file=open('error.txt', 'w'))
             raise

         if not output:
             raise ValueError('no output returned')

-        return output.decode("utf-8")
+        return output

     def train(self, optimizer, reader=None, override_existing=True):
         '''
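Note: the rewrite above exists because `subprocess.check_output` returns `bytes` on Python 3; decoding once at the call site lets both the new `cntk.log` file and every caller work with `str`. A minimal sketch of the same pattern, with illustrative names that are not part of the commit:

import subprocess

def run_and_log(cmd, log_path):
    # check_output returns bytes; decode once so all consumers see str
    output = subprocess.check_output(cmd, stderr=subprocess.STDOUT).decode('utf-8')
    with open(log_path, 'w') as log:
        log.write(output)
    return output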
@@ -344,29 +349,123 @@ class Context(AbstractContext):
             if not mo:
                 continue
             var_name, shape = mo.group('var_name'), mo.group('shape')
             # In Debug mode, additional stride information is printed
+            shape = Context.SHAPE_STRIDE_REGEX.sub('', shape)

             shape_list = []
-            for x in Context.SHAPE_STRIDE_REGEX.sub('', shape).split('x'):
+            for x in shape.split('x'):
                 x = x.strip()
-                if x != '*':
+                if x == '*':
+                    shape_list.append(np.NaN)
+                else:
                     shape_list.append(int(x))

             var_shape[var_name] = tuple(shape_list)

         return var_shape

-    def _eval(self, node, reader):
-        # FIXME manually setting the tag to output might have side-effects
-        node.tag = 'output'
-        config_content = self._generate_eval_config(node, reader)
-        output = self._call_cntk(CNTK_EVAL_CONFIG_FILENAME, config_content)
-        shapes = Context._parse_shapes_from_output(output)
-
-        out_name = os.path.join(
-            self.directory, CNTK_OUTPUT_FILENAME + '.' + node.var_name)
-        data = np.loadtxt(out_name)
-
-        return data, shapes
+    @staticmethod
+    def _parse_result_output(output):
+        '''
+        Assuming the data has been output using the output format in the
+        configuration
+
+        format = [
+            # %x = shape, %d = sequenceId
+            sequencePrologue=%d\t|w.shape %x\n%d\t|w\s
+            sampleSeparator=\n%d\t|w\s
+            elementSeparator=\s
+        ]
+
+        this method will parse the output of the form
+
+        0 |w.shape 1 1
+        0 |w 60.000000
+        1 |w.shape 1 2
+        1 |w 22.000000
+        1 |w 24.000000
+
+        and return a list of tensors.
+        '''
+
+        last_seq_idx = None
+        list_of_tensors = []
+        tensor_seq = []
+        shape = None
+        for line in output.splitlines():
+            parts = line.split('|')
+
+            seq_idx = parts[0].strip()
+            payload = parts[1]
+            info, *data = payload.split(' ')
+
+            if seq_idx != last_seq_idx:
+                if not info == 'w.shape':
+                    raise ValueError('expected shape information, but got "%s"' % line)
+
+                if tensor_seq:
+                    list_of_tensors.append(np.asarray(tensor_seq))
+                    tensor_seq = []
+
+                last_seq_idx = seq_idx
+
+                shape = cntk_to_numpy_shape(data)
+
+                continue
+            else:
+                data = np.asarray(data, dtype=float).reshape(shape)
+
+            tensor_seq.append(data)
+
+        list_of_tensors.append(np.asarray(tensor_seq))
+
+        return list_of_tensors
+
+    def _calc_expected_shape_and_size(self, node, data, shapes):
+        '''
+        Calculates the expected shape and size from the CNTK output and the
+        retrieved data.
+
+        :param node: the node that was evaluated.
+        :param data: the resulting data from `eval()`
+        :param shapes: dictionary of node names to shape tuples
+
+        Returns the expected size and shape
+        '''
+
+        # We got a single-dimensional array back, so we have to check whether
+        # we need to reshape it based on CNTK's shape output.
+
+        expected_shape = np.asarray(shapes[node.var_name])
+
+        if sum(np.isnan(expected_shape)) > 1:
+            raise ValueError("for node '%s' we received shape '%s', but "
+                             "at most one dimension can be left unspecified."
+                             % (node.var_name, expected_shape))
+
+        expected_size = np.multiply.reduce(expected_shape[~np.isnan(expected_shape)])
+        if sum(np.isnan(expected_shape)) == 1:
+            if data.size == expected_size:
+                # We received all the data we need, so we have sequences of
+                # length 1. For convenience, we ignore it.
+                expected_shape = expected_shape[~np.isnan(expected_shape)]
+
+            elif data.size > expected_size:
+                # We can fill in the missing dimension
+                missing_dimension = data.size / expected_size
+                if int(missing_dimension) != missing_dimension:
+                    raise ValueError('could not infer the missing dimensions')
+
+                expected_shape[np.isnan(expected_shape)] = missing_dimension
+                expected_size = np.multiply.reduce(expected_shape)
+                # Now we have expected_size == data.size
+            else:
+                raise ValueError('unable to retrieve expected size')
+
+        # Move last dimension to the beginning: this is the time dimension
+        #expected_shape = np.roll(expected_shape, 1)
+
+        return expected_shape, expected_size

     def eval(self, node, reader=None):
         '''
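Note: the new `_parse_result_output` walks the written output line by line; a `|w.shape` record opens a new sequence and announces the per-sample shape, and each following `|w` record contributes one sample reshaped accordingly. A condensed, runnable sketch of the same parsing idea (simplified: the shape conversion is inlined instead of calling `cntk_to_numpy_shape`, and the NaN handling of the real method is skipped):

import numpy as np

def parse_output(output):
    tensors, seq, shape = [], [], None
    for line in output.splitlines():
        _, payload = line.split('|')
        info, *data = payload.split(' ')
        if info == 'w.shape':
            if seq:  # a new sequence starts; close the previous one
                tensors.append(np.asarray(seq))
                seq = []
            # reverse the col-major dims and drop the sequence dimension
            dims = tuple(int(x) for x in reversed(data))
            shape = dims[1:] or (1,)
        else:
            seq.append(np.asarray(data, dtype=float).reshape(shape))
    if seq:
        tensors.append(np.asarray(seq))
    return tensors

sample = ('0 |w.shape 1 1\n0 |w 60.000000\n'
          '1 |w.shape 1 2\n1 |w 22.000000\n1 |w 24.000000')
for t in parse_output(sample):
    print(t.shape)  # (1, 1), then (2, 1)

In `_calc_expected_shape_and_size` the inference is plain arithmetic: if CNTK reports shape (3, NaN) and 12 values arrive, the single unspecified dimension must be 12 / 3 = 4; more than one NaN is rejected because the factorization would be ambiguous.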
@@ -383,24 +482,22 @@ class Context(AbstractContext):
             raise ValueError(
                 'node is not of type ComputationNode, but %s' % type(node))

-        data, shapes = self._eval(node, reader)
+        # Taking note of the original tag of this node to restore it later
+        orig_node_tag = node.tag if hasattr(node, 'tag') else None
+        node.tag = 'output'

-        expected_size = np.multiply.reduce(shapes[node.var_name])
-        expected_shape = shapes[node.var_name]
+        config_content = self._generate_eval_config(node, reader)
+        output = self._call_cntk(CNTK_EVAL_CONFIG_FILENAME, config_content)

-        receieved_all = data.size == expected_size
-        if not receieved_all:
-            # For some reason the CNTK write action has issues with multi-row
-            # output. So we have to CNTK reshape it to one row and do it again,
-            # but then NumPy reshape using node's expected shape.
+        node.tag = orig_node_tag

-            reshaped = NewReshape(node, expected_size)
-            data, _ = self._eval(reshaped, reader)
+        shapes = Context._parse_shapes_from_output(output)

-        if not (len(expected_shape) == 2 and expected_shape[1] == 1):
-            # CNTK outputs e.g. [2 x 1] although it is just a vector.
-            # TODO find better way to distinguis between
-            data = data.reshape(expected_shape)
+        out_name = os.path.join(
+            self.directory, CNTK_OUTPUT_FILENAME + '.' + node.var_name)
+        #data = np.loadtxt(out_name)
+        result_content = open(out_name).read()
+        data = Context._parse_result_output(result_content)

         return data
@@ -26,5 +26,5 @@ if (__name__ == "__main__"):
     with Context('demo', root_node=ce, clean_up=False) as ctx:
         ctx.train(my_sgd, None)

-        #result = ctx.eval(out)
-        # print(result.argmax(axis=1))
+        result = ctx.eval(out)
+        print(result.argmax(axis=1))
@@ -9,6 +9,7 @@ class sparse(object):
         return hasattr(obj, 'todense')

 from .utils import MODEL_INDENTATION
+from .utils import numpy_to_cntk_shape

 def _tuple_to_cntk_shape(shape):
     return ':'.join(str(v) for v in shape)
@@ -305,36 +306,44 @@ from .reader import UCIFastReader, CNTKTextFormatReader
 # redefine some operators to work with NumPy and sequences as input


-def _dense_seq_to_str(seq):
-    return ' '.join(seq.astype(np.str))
+def _dense_to_str(data):
+    return ' '.join(data.ravel().astype(np.str))


-def _sparse_seq_to_str(seq):
-    # return ' '.join('%s:%s'%(k,seq[k]) for k in sorted(seq.items()))
+def _sparse_to_str(data):
+    # return ' '.join('%s:%s'%(k,data[k]) for k in sorted(data.items()))
     raise NotImplementedError


-def _seq_to_text_format(sequences, alias):
+def _tensor_to_text_format(idx, alias, tensor, has_sequence_dimension=True):
     '''
-    `sequences` is a NumPy array
+    Converts a NumPy array representing a tensor of one input into a format
+    that is readable by `CNTKTextReader`.
+
+    :param `alias`: alias to be used in the temporary file
+    :param `tensor`: a NumPy array having sequence as its innermost dimension
     '''
-    if not alias or not isinstance(alias, str):
+    if not alias:
         raise ValueError('alias is missing')

-    first_elem = sequences[0]
-    if isinstance(first_elem, np.ndarray):
-        seq_to_str = _dense_seq_to_str
-    elif sparse.issparse(first_elem):
-        seq_to_str = _sparse_seq_to_str
+    if isinstance(tensor, np.ndarray):
+        to_str = _dense_to_str
+    elif sparse.issparse(tensor):
+        raise ValueError('sparse is not yet supported')
+        #to_str = _sparse_to_str
     else:
-        raise ValueError(
-            'sequence elements have to be of type numpy.ndarray (dense) or dictionary (sparse), you gave "%s"' % str(first_elem))
+        raise ValueError('sequence elements have to be of type numpy.ndarray'
+                         ' (dense) or dictionary (sparse), you gave "%s"' %
+                         str(type(tensor)))

-    lines = []
-    for idx, seq in enumerate(sequences):
-        lines.append('%i|%s %s' % (idx, alias, seq_to_str(seq)))
+    if has_sequence_dimension:
+        num_seq_elements = tensor.shape[0]
+        lines = []
+        for seq_idx in range(0, num_seq_elements):
+            lines.append('%i\t|%s %s' % (idx, alias, to_str(tensor[seq_idx])))

-    return '\n'.join(lines)
+        return '\n'.join(lines)
+    else:
+        return '%i\t|%s %s' % (idx, alias, to_str(tensor))


 def _get_constant_node(value, **kw):
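Note: to make the new converter's contract concrete, with a sequence dimension each row of the tensor becomes its own `idx\t|alias ...` line, and without one the whole tensor is flattened onto a single line. The dense path, restated as a standalone sketch (not imported from the package):

import numpy as np

def dense_to_str(data):
    return ' '.join(data.ravel().astype(str))

tensor = np.asarray([[1, 0, 0, 0], [0, 0, 1, 0]])
# has_sequence_dimension=True: one line per sequence element
print('\n'.join('%i\t|W %s' % (0, dense_to_str(row)) for row in tensor))
# has_sequence_dimension=False: the whole tensor on one line
print('%i\t|W %s' % (0, dense_to_str(tensor)))  # 0 |W 1 0 0 0 0 0 1 0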
@@ -344,7 +353,7 @@ def _get_constant_node(value, **kw):

     To be as generic as possible, we
     - flatten the data
-    - initialize a LearnableParameter operator with it
+    - initialize a ParameterTensor operator with it
     - ensure that the graph does not backprop to it.
     - finally, reshape it.
     '''
@@ -364,43 +373,41 @@ def _get_constant_node(value, **kw):
         prefix='_param_', suffix='.txt', dir=get_context().directory, delete=False)
     tf.close()

-    if isinstance(value, list):
+    if isinstance(value, list) or np.isscalar(value):
         value = np.asarray(value)

-    if len(value.shape) == 1:
-        # 1D list: interpret as one scalar per sample
-        value = value[:, np.newaxis]
-
     if sparse.issparse(value):
         raise ValueError('only dense data is supported')

     with open(tf.name, 'w') as f:
-        # TODO value.ravel() ?
-        np.ndarray.flatten(value).tofile(f, sep='\n')
-
-    size = np.multiply.reduce(value.shape[:])
-
-    # The var_name specified by the user should be set to the operator that
-    # is finally returned, which is the shape node.
-    var_name = kw.pop('var_name', None)
+        value.ravel().tofile(f, sep='\n')

     from cntk.reader import CNTKTextFormatReader
-    param_node = cntk1_ops.LearnableParameter(
-        size,
-        1,
+
+    cntk_shape = numpy_to_cntk_shape(value.shape)
+
+    dims = np.multiply.reduce(cntk_shape)
+
+    # TODO switch to ConstantTensor once it is in the core.bs file
+    node = cntk1_ops.ParameterTensor(
+        dims=dims,
         learningRateMultiplier=0.0,
         init='fromFile',
         initFromFilePath=tf.name,
         **kw)

-    reshape_node = cntk1_ops.NewReshape(param_node,
-                                        dims=value.shape,
-                                        var_name=var_name)
+    if len(cntk_shape) > 1:
+        node = cntk1_ops.NewReshape(node, dims=cntk_shape)

-    return reshape_node
+    return node


-def _get_input_node(value, **kw):
+def _get_input_node(list_of_tensors, has_sequence_dimension, **kw):
+    '''
+    :param list_of_tensors: list of tensors potentially having sequences of
+    different lengths.
+    '''

     # FIXME We need to better manage the context. How can we get hold
     # of the overall context without having to always pass it
     # explicitly?
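Note: the rewrite encodes a constant as a `ParameterTensor` initialized from a file and frozen via `learningRateMultiplier=0.0`, with a `NewReshape` restoring the multi-dimensional shape (CNTK reads back a flat vector). The data side of that round trip can be checked without CNTK:

import numpy as np
import tempfile

value = np.asarray([[1., 2.], [3., 4.]])

# serialize exactly as the commit does: flattened, one value per line
tf = tempfile.NamedTemporaryFile(suffix='.txt', delete=False)
tf.close()
value.ravel().tofile(tf.name, sep='\n')

# init='fromFile' hands this flat vector to ParameterTensor; the
# NewReshape node then restores the original dimensions
restored = np.fromfile(tf.name, sep='\n').reshape(value.shape)
assert np.allclose(restored, value)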
@@ -415,13 +422,6 @@ def _get_input_node(value, **kw):
         dir=get_context().directory, delete=False)
     tf.close()

-    if isinstance(value, list):
-        value = np.asarray(value)
-
-    if len(value.shape) == 1:
-        # 1D list: interpret as one scalar per sample
-        value = value[:, np.newaxis]
-
     if 'alias' in kw:
         alias = kw['alias']
         del kw['alias']  # don't confuse with constructor's parameters
@@ -429,22 +429,51 @@ def _get_input_node(value, **kw):
         # TODO make sure we don't have clashes
         alias = '_I_%i' % np.random.randint(1000)

+    shapes = set()
     with open(tf.name, 'w') as f:
-        f.write(_seq_to_text_format(value, alias))
+        for idx, tensor in enumerate(list_of_tensors):
+            if isinstance(tensor, list):
+                tensor = np.asarray(tensor)
+
+            if has_sequence_dimension:
+                # collecting the shapes ignoring the sequence dimension
+                shapes.add(tensor.shape[1:])
+            else:
+                shapes.add(tensor.shape)
+
+            f.write(_tensor_to_text_format(idx, alias, tensor,
+                                           has_sequence_dimension) + '\n')
+
+    # ignoring the sequence dimension, all shapes should be equal
+    if len(shapes) != 1:
+        raise ValueError('except for the sequence dimensions all shapes '
+                         'should be the same - instead we have: %s' %
+                         (", ".join(str(s) for s in shapes)))
+
+    # shapes now contains only one shape, which has the sequence dimension
+    # removed.
+    value_shape = shapes.pop()
+
+    cntk_shape = numpy_to_cntk_shape(value_shape)

     from cntk.reader import CNTKTextFormatReader
-    input_node = cntk1_ops.Input(value.shape, **kw)
-    input_node.reader = CNTKTextFormatReader(tf.name)
-    # In case we have the shape (2,3), which will be initialized at Input() as
-    # '2:3', we have 2*3 = 6 dimensions when flattened out for the reader. Note
-    # that the first dimension is the sample.
-    dims = np.multiply.reduce(value.shape[:])
-    input_node.reader.add_input(input_node, alias, dims)
-
-    return input_node
+    # In case we have the shape (2,3) and assuming we have only sequences of
+    # length 1, the input will be initialized with dim=3 (column major)
+    # followed by a reshape node that has the dims '2:3'. So we have 2*3 = 6
+    # dimensions when flattened out for the reader.
+    dims = int(np.multiply.reduce(cntk_shape))
+    node = cntk1_ops.Input(dims, **kw)
+    node.reader = CNTKTextFormatReader(tf.name)
+    node.reader.add_input(node, alias, dims)
+
+    if len(cntk_shape) > 1:
+        node = cntk1_ops.NewReshape(node,
+                                    dims=cntk_shape)
+
+    return node


-def is_sequence(data):
+def is_tensor_list(data):
     '''
     Checks whether the data is a CNTK sequence, which is expressed in Python as
     a list of varying sized NumPy objects.
@@ -455,7 +484,10 @@ def is_sequence(data):

 def is_tensor(data):
     '''
-    Checks whether the data is a tensor.
+    Checks whether the data is a tensor, i.e. whether it is a NumPy array or a
+    list of NumPy arrays.
+
+    :param `data`: data to check
     '''
     if isinstance(data, np.ndarray):
         return True
@@ -479,48 +511,29 @@ def is_tensor(data):
     return True


-def input(value, **kw):
+def input(value, has_sequence_dimension=True, **kw):
     '''
-    Defining Input as a factory override that creates either a Constant()
-    operator or an Input() operator based on the type of the `value`.
-
-    In case the `value` is a scalar, a normal CNTK Constant() operator is
-    returned.
-
-    In case the `value` is a list of NumPy arrays, a CNTK Input() operator is
-    returned, interpreting every element as a sequence of tensors.
-
-    In case the `value` is a NumPy array or list of lists, a CNTK Input()
-    operator is returned, interpreting it as a dense tensor.
-
-    Non-scalar values are interpreted as sparse when they contain a colon.
+    Create an input node.
+
+    :param `value`: a list of NumPy tensors. Currently, only dense tensors
+    are supported. Sparse will come soon by the power of scipy.
+    :param `has_sequence_dimension`: if True, the outermost dimension is
+    treated as the sequence dimension. If False, it will wrap each sample
+    into its own 1-dimensional array.
+    :param `alias`: optional alias to be used when serializing the data
+    into an intermediate file
     '''
-    if is_sequence(value) or is_tensor(value):
-        return _get_input_node(value, **kw)
+    if is_tensor_list(value) or is_tensor(value):
+        return _get_input_node(value, has_sequence_dimension, **kw)
     else:
         raise ValueError('value type is not supported: %s' % type(value))


 def constant(value, **kw):
     '''
-    Defining Constant as a factory override that creates either a Constant()
-    operator or an Input() operator based on the type of the `value`.
-
-    In case the `value` is a scalar, a normal CNTK Constant() operator is
-    returned.
-
-    In case the `value` is a list of NumPy arrays, a CNTK Input() operator is
-    returned, interpreting every element as a sequence of tensors.
-
-    In case the `value` is a NumPy array or list of lists, a CNTK Input()
-    operator is returned, interpreting it as a dense tensor.
-
-    Non-scalar values are interpreted as sparse when they contain a colon.
+    Creates a constant tensor node around `value`.
     '''
-    if np.isscalar(value):
-        return cntk1_ops.Constant(value, **kw)
+    if np.isscalar(value) or is_tensor(value):
+        return _get_constant_node(value, **kw)
     else:
-        if is_tensor(value):
-            return _get_constant_node(value, **kw)
-        else:
-            raise ValueError('value type is not supported: %s' % type(value))
+        raise ValueError('value type is not supported: %s' % type(value))
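Note: hypothetical usage of the two factories after this change, assuming the `cntk` package from this commit is importable (values are made up):

import numpy as np

# a batch of two sequences of lengths 1 and 2, sample shape (1,)
x = input([np.asarray([[30]]),
           np.asarray([[11], [12]])], has_sequence_dimension=True)

# a dense 2x2 tensor, wrapped into a frozen ParameterTensor plus reshape
w = constant([[1, 2], [3, 4]])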
@@ -58,6 +58,7 @@ class UCIFastReader(AbstractReader):
     """
     template = '''\
 reader = [
+    traceLevel = 2
     readerType = "%(ReaderType)s"
     file = "%(FileName)s"
     randomize = "none"
@@ -132,10 +133,12 @@ class CNTKTextFormatReader(AbstractReader):
     def generate_config(self):
         """Generate the reader configuration block
         """
-        template = ''' reader = [
-            readerType = "%(ReaderType)s"
-            file = "%(FileName)s"
-        '''
+        template = '''
+        reader = [
+            traceLevel = 2
+            readerType = "%(ReaderType)s"
+            file = "%(FileName)s"
+        '''

         if self.inputs_def is not None:
             template += '''
@@ -154,15 +157,15 @@ class CNTKTextFormatReader(AbstractReader):
                 a = input_alias

             template += '''
-        {0}=[
-            alias = "{1}"
-            dim = {2}
-            format = "{3}"
-        ]'''.format(name, a, dim, format)
+            {0}=[
+                alias = "{1}"
+                dim = {2}
+                format = "{3}"
+            ]'''.format(name, a, dim, format)

         template += '''
-    ]
+        ]
     ]
 '''
         return template % self
@@ -15,6 +15,14 @@ Eval=[

     %(Reader)s

+    format = [
+        # %%x = shape, %%d = sequenceId
+        sequencePrologue=%%d\t|w.shape %%x\n%%d\t|w\s
+        sampleSeparator=\n%%d\t|w\s
+        elementSeparator=\s
+    ]
+
     outputPath = "%(OutputFile)s"
 ]
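Note on the doubled percent signs: this template is rendered with Python %-substitution (see `return template % self` in the reader code above), so a literal % that must survive into the CNTK config - as in the %x and %d of the output format - has to be written %%. A minimal illustration:

template = '''
format = [
    sequencePrologue=%%d\\t|w.shape %%x
]
outputPath = "%(OutputFile)s"
'''
print(template % {'OutputFile': 'out.txt'})
# %%d and %%x come out as %d and %x; %(OutputFile)s is substituted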
@@ -1,7 +1,8 @@
 import numpy as np
 from ..context import *


-def test_parse_shapes():
+def test_parse_shapes_1():
     output = '''\
 FormNestedNetwork: WARNING: Was called twice for v3 Plus operation
@@ -26,7 +27,7 @@ Post-processing network complete.
 '''

     expected = {
-        'dummy_node': (2,),
+        'dummy_node': (2, np.NaN),
         'v0': (4, 1),
         'v1': (2, 2),
         'v2': (1, 1),
@@ -34,3 +35,36 @@ Post-processing network complete.
     }

     assert Context._parse_shapes_from_output(output) == expected
+
+
+def test_parse_shapes_2():
+    output = '''\
+Validating --> v1 = LearnableParameter() : -> [3 x 2 {1,3}]
+Validating --> v2 = InputValue() : -> [2 {1} x *]
+Validating --> v3 = Times (v1, v2) : [3 x 2 {1,3}], [2 {1} x *] -> [3 {1} x *]
+Validating --> v4 = LearnableParameter() : -> [3 x 1 {1,3}]
+Validating --> v5 = Plus (v3, v4) : [3 {1} x *], [3 x 1 {1,3}] -> [3 x 1 {1,3} x *]
+'''
+
+    expected = {
+        'v1': (3, 2),
+        'v2': (2, np.NaN),
+        'v3': (3, np.NaN),
+        'v4': (3, 1),
+        'v5': (3, 1, np.NaN),
+    }
+
+    assert Context._parse_shapes_from_output(output) == expected
+
+
+def test_parse_result_output_1():
+    output = '''\
+0 |w.shape 1 1
+0 |w 60.000000
+1 |w.shape 1 2
+1 |w 22.000000
+1 |w 24.000000'''
+    list_of_tensors = Context._parse_result_output(output)
+    expected = [[[60]], [[22], [24]]]
+    assert len(list_of_tensors) == len(expected)
+    for res, exp in zip(list_of_tensors, expected):
+        assert np.allclose(res, np.asarray(exp))
@@ -1,6 +1,6 @@
 from ..context import get_new_context, _CONTEXT
 from ..graph import *
-from ..graph import _seq_to_text_format
+from ..graph import _tensor_to_text_format

 import pytest
@@ -56,74 +56,50 @@ def _to_list(desc):
     return [line.strip() for line in desc.split('\n')]


 @pytest.mark.parametrize("root_node, expected", [
     (C(2, var_name='c0'), ["c0 = Constant(2, rows=1, cols=1)"]),
     # Input should behave as Constant in case of scalars
     (I([1, 2], var_name='i1'), ["i1 = Input(2:1, tag='feature')"]),
     (Plus(C(0), C(1)),
      ["v0 = Constant(0, rows=1, cols=1)", "v1 = Constant(1, rows=1, cols=1)", "v2 = Plus(v0, v1)"]),
 ])
 def test_description(root_node, expected):
     description, has_inputs, readers = root_node.to_config()
     assert _to_list(description) == expected


 def test_graph_with_same_node_twice():
     v0 = C(1)
     root_node = Plus(v0, v0)
     expected = ['v0 = Constant(1, rows=1, cols=1)', 'v1 = Plus(v0, v0)']
     description, has_inputs, readers = root_node.to_config()
-    assert _to_list(description) == expected
     assert readers == []
+    assert len(_to_list(description)) == 2


-@pytest.mark.parametrize("alias, data, expected", [
-    ('', [A([1, 0]), A([0, 0, 1, 0])], ValueError),  # no alias given
-    ('A', [object()], ValueError),
+@pytest.mark.parametrize("alias, idx, data, expected", [
+    ('', 0, [A([1, 0]), A([0, 0, 1, 0])], ValueError),  # no alias given
+    ('A', 0, [object()], ValueError),
 ])
-def test_sequence_conversion_exceptions(alias, data, expected):
+def test_tensor_conversion_exceptions(alias, idx, data, expected):
     with pytest.raises(expected):
-        _seq_to_text_format(data, alias=alias)
+        _tensor_to_text_format(idx, alias, data)


 def test_constant_var_name():
     var_name = 'NODE'
     node = C([A([])], var_name=var_name)
     assert node.var_name == var_name


-@pytest.mark.parametrize("alias, data, expected", [
-    ('W', [A([])], """\
-0|W \
-"""),
-    ('W', [A([1, 0]), A([0, 0, 1, 0])], """\
-0|W 1 0
-1|W 0 0 1 0\
-"""),
+@pytest.mark.parametrize("alias, idx, data, expected", [
+    ('W', 0, A([]), "0 |W "),
+    ('W', 0, A([[1, 0, 0, 0], [0, 0, 1, 0]]), """\
+0 |W 1 0 0 0 0 0 1 0\
+"""),
 ])
-def test_sequence_conversion_dense(alias, data, expected):
-    assert _seq_to_text_format(data, alias=alias) == expected
+def test_tensor_conversion_dense(alias, idx, data, expected):
+    assert _tensor_to_text_format(idx, alias, data,
+                                  has_sequence_dimension=False) == expected

+if False:
     @pytest.mark.parametrize("alias, data, expected", [
-        ('W', [A({})], """\
-0|W \
-"""),
+        ('W', [A({})], ""),
         ('W', [{3: 1, 50: 1, 2: 0}, {1: -5}], """\
-0|W 2:0 3:1 50:1
-1|W 1:-5\
+0 |W 2:0 3:1 50:1
+1 |W 1:-5\
 """),
     ])
-    def test_sequence_conversion_sparse(alias, data, expected):
+    def test_tensor_conversion_sparse(alias, data, expected):
         # We use the dictionary in data to create a SciPy sparse dictionary of
         # keys, which we then feed to the converter.
         dok_data = []
-        for data_elem in data:
+        for idx, data_elem in enumerate(data):
             d = scipy.sparse.dok_matrix((100, 1))
             for k, v in data_elem.items():
                 d[k] = v
             dok_data.append(d)
-        assert _seq_to_text_format(dok_data, alias=alias) == expected
+        assert _tensor_to_text_format(idx, alias, dok_data) == expected


 @pytest.mark.parametrize("data, expected", [
@@ -148,8 +124,8 @@ def test_is_tensor(data, expected):
     ([A([1, 2])], True),
     ([A([1, 2]), A([])], True),
 ])
-def test_is_sequence(data, expected):
-    assert is_sequence(data) == expected
+def test_is_tensor_list(data, expected):
+    assert is_tensor_list(data) == expected

 def test_loose_coupling():
     from cntk.ops.cntk1 import PastValue
@@ -16,35 +16,77 @@ def _test(root_node, expected, clean_up=True):
     ctx.clean_up = clean_up
     assert not ctx.input_nodes
     result = ctx.eval(root_node)
-    expected = AA(expected)
-    assert result.shape == expected.shape or result.shape == (
-        1, 1) and expected.shape == ()
-    assert np.allclose(result, expected)

-_VALUES = [0, [[1, 2], [3, 4]], [10.1, -20.2], 1.1]
+    assert len(result) == len(expected)
+    for res, exp in zip(result, expected):
+        assert np.allclose(res, exp)
+        assert res.shape == AA(exp).shape

-@pytest.fixture(scope="module", params=_VALUES)
-def left_arg(request):
+#C_VALUES = [0, [[1, 2], [3, 4]], [10.1, -20.2], 1.1]
+C_VALUES = [0, [[1, 2], [3, 4]]]
+
+@pytest.fixture(scope="module", params=C_VALUES)
+def c_arg(request):
     return request.param

-right_arg = left_arg
+c_left_arg = c_arg
+c_right_arg = c_arg

-def test_op_add(left_arg, right_arg):
-    expected = AA(left_arg) + AA(right_arg)
-    _test(C(left_arg) + right_arg, expected)
-    _test(C(left_arg) + C(right_arg), expected)
-    _test(left_arg + C(right_arg), expected)
-    _test(left_arg + C(left_arg) + right_arg, left_arg+expected)
+if False:
+    def test_op_add_constant(c_left_arg, c_right_arg):
+        expected = [AA(c_left_arg) + AA(c_right_arg)]
+        _test(C(c_left_arg) + c_right_arg, expected, False)
+        _test(C(c_left_arg) + C(c_right_arg), expected)
+        _test(c_left_arg + C(c_right_arg), expected)
+        _test(c_left_arg + C(c_left_arg) + c_right_arg, c_left_arg+expected)

-def test_op_minus(left_arg, right_arg):
-    expected = AA(left_arg) - AA(right_arg)
-    _test(C(left_arg) - right_arg, expected)
-    _test(C(left_arg) - C(right_arg), expected)
-    _test(left_arg - C(right_arg), expected)
-    _test(left_arg - C(left_arg) + right_arg, left_arg-expected)
+    def test_op_minus_constant(c_left_arg, c_right_arg):
+        expected = [AA(c_left_arg) - AA(c_right_arg)]
+        _test(C(c_left_arg) - c_right_arg, expected)
+        _test(C(c_left_arg) - C(c_right_arg), expected)
+        _test(c_left_arg - C(c_right_arg), expected)
+        _test(c_left_arg - C(c_left_arg) + c_right_arg, c_left_arg-expected)

-def test_op_times(left_arg, right_arg):
-    expected = AA(left_arg) * AA(right_arg)
-    _test(C(left_arg) * right_arg, expected)
-    _test(C(left_arg) * C(right_arg), expected)
-    _test(left_arg * C(right_arg), expected)
+    def test_op_times_constant(c_left_arg, c_right_arg):
+        expected = [AA(c_left_arg) * AA(c_right_arg)]
+        _test(C(c_left_arg) * c_right_arg, expected)
+        _test(C(c_left_arg) * C(c_right_arg), expected)
+        _test(c_left_arg * C(c_right_arg), expected)

+# Testing inputs
+
+@pytest.mark.parametrize("left_arg, right_arg", [
+    ([30], [10]),
+    ([[30]], [[10]]),
+    ([[1.5, 2.1]], [[10, 20]]),
+    # Adding two 3x2 inputs of sequence length 1
+    ([[30, 40], [1, 2], [0.1, 0.2]], [[10, 20], [3, 4], [-0.5, -0.4]]),
+    ([5], [[30, 40], [1, 2]]),
+])
+def test_op_add_input_constant(left_arg, right_arg):
+    expected = AA(left_arg) + AA(right_arg)
+    # sequence of 1 element, since we have has_sequence_dimension=False
+    expected = [expected]
+    # batch of one sample
+    expected = [expected]
+    _test(I([left_arg], has_sequence_dimension=False) + right_arg, expected, False)
+    _test(left_arg + I([right_arg], has_sequence_dimension=False), expected, False)
+
+@pytest.mark.parametrize("left_arg, right_arg", [
+    ([
+        [[30]],        # 1st element has (1,) sequence of length 1
+        [[11], [12]]   # 2nd element has (1,) sequence of length 2
+     ],
+     2),
+    ([
+        [[33, 22]],             # 1st element has (1x2) sequence of length 1
+        [[11, 12], [1.1, 2.2]]  # 2nd element has (1x2) sequence of length 2
+     ],
+     2),
+])
+def test_op_mul_input_seq(left_arg, right_arg):
+    expected = [AA(elem)*right_arg for elem in left_arg]
+    result = I(left_arg, has_sequence_dimension=True) * right_arg
+    _test(result, expected, False)
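Note: a reading aid for the expected values in these tests - results now come back as nested lists, outermost the batch, then the sequence, then the sample tensor, which is why `test_op_add_input_constant` wraps `expected` twice. Schematically, with illustrative values:

import numpy as np

sample = np.asarray([[1.5, 2.1]]) + np.asarray([[10, 20]])  # one sample
expected = [[sample]]  # [batch [sequence [sample]]]
assert np.allclose(expected[0][0], [[11.5, 22.1]])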
@@ -4,7 +4,7 @@ import pytest
 from ..reader import *
 from ..graph import *
 from ..context import *
-from .. import cntk1_ops
+from ..ops import cntk1 as cntk1_ops

 allclose = np.testing.assert_allclose
@@ -20,6 +20,7 @@ def test_NumPyReader(tmpdir):

     with get_new_context() as ctx:
         result = ctx.eval(out, reader)
-        assert np.all(result == np.asarray(data) + 2)
+        for r, d in zip(result, data):
+            assert np.all(r == np.asarray(d) + 2)

 # TODO test other readers
@@ -9,3 +9,34 @@ CNTK_EXECUTABLE_PATH = os.environ['CNTK_EXECUTABLE_PATH']
 # Indent model description by how many spaces
 MODEL_INDENTATION = 8

+
+def numpy_to_cntk_shape(shape):
+    '''
+    Converts the NumPy shape (row major) to the CNTK shape (column major).
+
+    :param shape: NumPy shape tuple
+
+    Returns a tuple that can be ':'.join()ed to a CNTK dimension.
+    '''
+    if not shape:
+        # in case of a scalar
+        return (1,)
+
+    return tuple(reversed(shape))
+
+
+def cntk_to_numpy_shape(shape):
+    '''
+    Converts col-major to row-major and removes the sequence dimension.
+
+    :param shape: CNTK shape iterable
+
+    Returns a tuple that describes the NumPy shape of a tensor
+    '''
+
+    shape = tuple(int(s) for s in reversed(shape))
+
+    shape = shape[1:]
+    if not shape:
+        shape = (1,)
+
+    return shape
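Note: the expected behavior of the two helpers above, restated as a quick sanity check (the functions are condensed copies from the diff so the snippet runs standalone):

def numpy_to_cntk_shape(shape):
    return tuple(reversed(shape)) if shape else (1,)

def cntk_to_numpy_shape(shape):
    shape = tuple(int(s) for s in reversed(shape))[1:]
    return shape or (1,)

assert numpy_to_cntk_shape((2, 3)) == (3, 2)           # row- to col-major
assert numpy_to_cntk_shape(()) == (1,)                 # scalar
assert cntk_to_numpy_shape(['2', '3', '1']) == (3, 2)  # seq dim dropped
assert cntk_to_numpy_shape(['5']) == (1,)              # only seq dim left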