246 строки
10 KiB
Python
246 строки
10 KiB
Python
# Copyright (c) Microsoft. All rights reserved.
|
|
|
|
# Licensed under the MIT license. See LICENSE.md file in the project root
|
|
# for full license information.
|
|
# ==============================================================================
|
|
|
|
from .. import cntk_py
|
|
import numpy as np
|
|
from cntk import NDArrayView
|
|
from ..cntk_py import DictionaryValueFromDict, DictionaryValue, Dictionary, DictionaryValueFromNDArrayView
|
|
|
|
# Tuple of cntk_py classes, meant to be passed to isinstance() to recognize
# arguments that are either a Variable or a Function.
_VARIABLE_OR_FUNCTION = (cntk_py.Variable, cntk_py.Function)
|
|
|
|
def get_data_type(*args):
    """
    Calculates the highest precision numpy data type of the provided parameters.
    If a parameter is a Function instance, the type is taken from its single
    output. Placeholders are ignored in the type determination.

    Types reported by CNTK objects take precedence: if at least one CNTK
    object contributed a type, the result is the widest of those types and
    the NumPy arguments are not consulted. Otherwise the widest type found
    among the NumPy (or NumPy-convertible) arguments is returned.

    Args:
        args (number, list, NumPy array, :class:`~cntk.variables.Variable`, or :class:`~cntk.ops.functions.Function`): input

    Returns:
        np.float32, np.float64, np.float16 or None (when no argument
        contributed a type, e.g. no arguments or only placeholders)

    Raises:
        ValueError: if a NumPy array/scalar has a dtype other than float16,
            float32 or float64, or if a Function does not have exactly one
            output.
    """
    from ..variables import Variable

    cntk_dtypes = set()
    numpy_dtypes = set()

    # NOTE: an earlier version tested ``isinstance(args, _VARIABLE_OR_FUNCTION)``
    # here. ``args`` is always the tuple built from ``*args``, so that test
    # could never succeed; the dead branch has been removed.

    for arg in args:
        if isinstance(arg, Variable) and arg.is_placeholder == True:
            continue

        if isinstance(arg,
                      (cntk_py.Variable, cntk_py.Value, cntk_py.NDArrayView)):
            np_type = _data_type_to_numpy(arg.get_data_type())
            if np_type is not None:
                cntk_dtypes.add(np_type)

        elif isinstance(arg, (np.ndarray, np.inexact)):
            # https://docs.scipy.org/doc/numpy/reference/arrays.scalars.html
            # NumPy integer scalars are not np.inexact, so they fall through
            # to the generic branch below and count as np.float32.
            # Arrays and inexact scalars must already be a supported float type.
            if arg.dtype not in (np.float32, np.float64, np.float16):
                raise ValueError(
                    'NumPy type "%s" is not supported' % arg.dtype)
            numpy_dtypes.add(arg.dtype.type)

        elif isinstance(arg, _VARIABLE_OR_FUNCTION):
            var_outputs = arg.outputs
            # Exactly one output is required; zero outputs would otherwise
            # surface as an uninformative IndexError on the lookup below.
            if len(var_outputs) != 1:
                raise ValueError(
                    'expected single output, but got %i' % len(var_outputs))

            np_type = _data_type_to_numpy(var_outputs[0].get_data_type())
            if np_type is not None:
                cntk_dtypes.add(np_type)

        else:
            # We don't know anything so we convert everything to float32. If it
            # works, we know the type.
            # TODO figure out a better/faster way.
            np.asarray(arg, dtype=np.float32)
            numpy_dtypes.add(np.float32)

    # CNTK-reported types win over NumPy-derived ones.
    candidates = cntk_dtypes if cntk_dtypes else numpy_dtypes

    # Probe from widest to narrowest so the highest precision is returned.
    for np_type in (np.float64, np.float32, np.float16):
        if np_type in candidates:
            return np_type

    return None


def _data_type_to_numpy(data_type):
    """
    Maps a ``cntk_py`` data type constant to the matching NumPy float type.

    Returns None if ``data_type`` is none of Double, Float or Float16.
    """
    if cntk_py.DataType_Double == data_type:
        return np.float64
    if cntk_py.DataType_Float == data_type:
        return np.float32
    if cntk_py.DataType_Float16 == data_type:
        return np.float16
    return None
|
|
|
|
def get_python_function_arguments(f):
    '''
    Helper to get the parameter names and annotations of a Python function.

    Only positional parameters without a default value are reported; any
    parameter that has a default is assumed to be left unspecified. This
    allows to, e.g., accept max(a, b, *more, name='') as a binary function.

    Args:
        f: the Python function to inspect

    Returns:
        tuple ``(arg_names, annotations)``: the list of required positional
        parameter names and the annotations mapping of ``f``.
    '''
    import sys

    if sys.version_info.major < 3:
        # Python 2 has no getfullargspec(); emulate its result shape.
        def getfullargspec(f):
            from inspect import getargspec
            from ..variables import Record

            annotations = getattr(f, '__annotations__', {})
            spec = getargspec(f)
            return Record(args=spec.args,
                          varargs=spec.varargs,
                          varkw=spec.keywords,
                          defaults=spec.defaults,
                          kwonlyargs=[],
                          kwonlydefaults=None,
                          annotations=annotations)
    else:
        from inspect import getfullargspec

    spec = getfullargspec(f)
    required_names = spec.args

    # "if this tuple has n elements, they correspond to the last n elements
    # listed in args"
    default_values = spec.defaults
    if default_values:
        # Functions with default arguments are allowed, but those arguments
        # always keep their defaults since CNTK Functions do not support this.
        required_count = len(required_names) - len(default_values)
        required_names = required_names[:required_count]

    return (required_names, spec.annotations)
|
|
|
|
def map_function_arguments(params, params_dict, *args, **kwargs):
    '''
    Helper to determine the argument map for use with various call operations.
    Returns a dictionary from parameters to whatever arguments are passed.
    Accepted are both positional and keyword arguments.
    This mimics Python's argument interpretation, except that keyword arguments are not optional.
    This does not require the arguments to be Variables or Functions. It is also called by train_minibatch() and @Signature.

    Args:
        params: ordered parameters; positional arguments are paired with
         them in this order
        params_dict: lookup from parameter name to the parameter object that
         is used as key in the result
        *args: positional arguments
        **kwargs: keyword arguments, matched by name through ``params_dict``

    Returns:
        dict mapping each parameter to the argument bound to it

    Raises:
        TypeError: a keyword does not name a known parameter
        SyntaxError: a parameter receives both a positional and a keyword value
        AssertionError: the number of bound parameters differs from ``len(params)``
    '''
    # positional arguments bind to the parameters in declaration order;
    # zip() stops at the shorter of the two sequences
    bound = {param: value for param, value in zip(params, args)}

    # keyword arguments are resolved by name
    for name, value in kwargs.items():
        if name not in params_dict:
            raise TypeError("got an unexpected keyword argument '%s'" % name)
        target = params_dict[name]
        if target in bound:
            raise SyntaxError("got multiple values for argument '%s'" % name)
        bound[target] = value

    # every parameter must have received exactly one value
    assert len(bound) == len(params)

    return bound
|
|
|
|
def _ones_like(batch, precision):
    '''
    Builds a batch with the same layout as ``batch`` in which every value
    is 1.

    Args:
        batch (list of NumPy arrays): a list of sequences, which are NumPy arrays
        precision: passed through ``sanitize_precision`` to obtain the dtype
         of the generated arrays

    Returns:
        list with one all-ones NumPy array per entry of ``batch``
    '''
    from cntk.internal import sanitize_precision

    ones = []
    for sequence in batch:
        dtype = sanitize_precision(precision)
        ones.append(np.ones_like(sequence, dtype=dtype))
    return ones
|
|
|
|
|
|
def eval(op, arguments=None, precision=None, device=None, backward_pass=False, expected_backward=None):
    '''
    Evaluates ``op`` on the provided data by running a forward pass and,
    optionally, a backward pass. This is useful mainly to explore the
    operators and for convenient unit testing.

    Args:
        op (:class:`~cntk.ops.functions.Function`): operation to evaluate
        arguments: maps variables to their input data; it is handed to
         ``op.forward`` unchanged. The interpretation depends on the input
         type:

           * `dict`: keys are input variable or names, and values are the input data.

           * any other type: if node has a unique input, ``arguments`` is mapped to this input.
             For nodes with more than one input, only `dict` is allowed.

         In both cases, every sample in the data will be interpreted
         as a new sequence. To mark samples as continuations of the
         previous sequence, specify ``arguments`` as `tuple`: the
         first element will be used as ``arguments``, and the second one will
         be used as a list of bools, denoting whether a sequence is a new
         one (`True`) or a continuation of the previous one (`False`).
         Data should be either NumPy arrays or a
         :class:`~cntk.io.MinibatchData` instance.
        precision (str or None): precision being 'float32', 'float64', 'float16', or
         None; it selects the dtype of the all-ones root gradients and is
         only used when ``backward_pass`` is set
        device (:class:`~cntk.device.DeviceDescriptor`, default None): device
         this value should be put on
        backward_pass (`bool`, optional): whether a backward pass is performed
        expected_backward (`dict` or None): keys are variables for which to
         compute a backward output. By default (None) all entries from
         'arguments' are used

    Returns:
        tuple ``(forward_output, backward_output)``: the result of the
        forward pass and the result of the backward pass;
        ``backward_output`` is None when ``backward_pass`` is False.
    '''
    if not backward_pass:
        # forward only: no outputs need to be retained for a backward pass
        _, forward_output = op.forward(
            arguments, op.outputs, None, device=device)
        return forward_output, None

    # keep all outputs alive so that gradients can be propagated from them
    state, forward_output = op.forward(arguments, op.outputs, op.outputs,
                                       device=device)

    if expected_backward is None:
        expected_backward = arguments

    # seed the backward pass with a gradient of 1 for every output value
    root_gradients = {}
    for variable, value in forward_output.items():
        root_gradients[variable] = _ones_like(value, precision)

    backward_output = op.backward(state, root_gradients, expected_backward)

    return forward_output, backward_output
|
|
|
|
def _to_cntk_dict_value(py_value):
    '''
    Wraps a Python value into a CNTK dictionary value.

    Dictionaries and lists are converted recursively, NumPy arrays are first
    turned into a dense NDArrayView, training parameter schedules use their
    dedicated factory, and ``None`` yields an empty ``DictionaryValue``.
    Anything else is handed to the ``DictionaryValue`` constructor as is.

    Args:
        py_value: the Python value to convert

    Returns:
        the converted value as produced by the matching ``cntk_py`` factory
    '''
    if py_value is None:
        return DictionaryValue()

    if isinstance(py_value, dict):
        nested = _py_dict_to_cntk_dict(py_value)
        return DictionaryValueFromDict(nested)

    if isinstance(py_value, list):
        converted_items = [_to_cntk_dict_value(item) for item in py_value]
        return DictionaryValue(converted_items)

    if isinstance(py_value, np.ndarray):
        view = NDArrayView.from_dense(py_value)
        return DictionaryValueFromNDArrayView(view)

    if isinstance(py_value, cntk_py.training_double_parameter_schedule):
        return cntk_py.DictionaryValueFromTrainingDoubleParameterSchedule(py_value)

    return DictionaryValue(py_value)
|
|
|
|
def _py_dict_to_cntk_dict(py_dict):
    '''
    Builds a CNTK Dictionary from a Python dictionary. Every value is
    converted with ``_to_cntk_dict_value``, so nested dictionaries and lists
    are handled recursively.

    Args:
        py_dict (dict): a dictionary to be converted.

    Returns:
        cntk_py.Dictionary:
        A :class:`~cntk.cntk_py.Dictionary` that has been converted from the input `dict`
    '''
    cntk_dict = Dictionary()
    for key in py_dict:
        cntk_dict[key] = _to_cntk_dict_value(py_dict[key])
    return cntk_dict
|