Merge remote-tracking branch 'origin/wilrich/miscAlpha2' into wilrich/miscAlpha2

Author: Mark Hillebrand 2016-09-30 22:40:27 +02:00
Parents: 0196086f9d 93c0643312
Commit: b4bcc96e38
4 changed files: 170 additions and 35 deletions

View file

@@ -24,6 +24,7 @@ class Trainer(cntk_py.Trainer):
     def __init__(self, model, loss_function, eval_function, parameter_learners):
         if isinstance(model, cntk_py.Variable):
             model = model.owner
+        self.model = model
         if isinstance(loss_function, cntk_py.Variable):
             loss_function = loss_function.owner
         if isinstance(eval_function, cntk_py.Variable):
@@ -37,8 +38,10 @@ class Trainer(cntk_py.Trainer):
         Returns `False` if all parameter learners indicate end of learning (through their Update method's return value).

         Args:
-            arguments (dict): map from input variables to the data, data should be either numpy
-             arrays or cntk.Value instances returned by a minibatch source
+            arguments (`dict` or `list`): map from input variables to the data,
+             or list of inputs in the order that the function expects. Data
+             should be either NumPy arrays or cntk.Value instances returned by a
+             minibatch source.
             device (:class:`cntk.DeviceDescriptor`): the device descriptor that
              contains the type and id of the device on which the computation is
              to be performed.
@@ -48,7 +51,7 @@ class Trainer(cntk_py.Trainer):
         '''
         if not device:
             device=DeviceDescriptor.use_default_device()
-        arguments = sanitize_var_map(arguments, add_batch_axis=True)
+        arguments = sanitize_var_map(self.model.arguments(), arguments, add_batch_axis=True)

         return super(Trainer, self).train_minibatch(arguments, device)
@@ -59,8 +62,10 @@ class Trainer(cntk_py.Trainer):
         of samples.

         Args:
-            arguments (dict): map from input variables to the data, data should be either numpy
-             arrays or cntk.Value instances returned by a minibatch source
+            arguments (`dict` or `list`): map from input variables to the data,
+             or list of inputs in the order that the function expects. Data
+             should be either NumPy arrays or cntk.Value instances returned by a
+             minibatch source.
             device (:class:`cntk.DeviceDescriptor`): the device descriptor that
              contains the type and id of the device on which the computation is
              to be performed.
@@ -70,7 +75,7 @@ class Trainer(cntk_py.Trainer):
         '''
         if not device:
             device=DeviceDescriptor.use_default_device()
-        arguments = sanitize_var_map(arguments, add_batch_axis=True)
+        arguments = sanitize_var_map(self.model.arguments(), arguments, add_batch_axis=True)

         return super(Trainer, self).test_minibatch(arguments, device)
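With this change, `train_minibatch` and `test_minibatch` resolve their data against `self.model.arguments()`, so callers can pass either a dict or an ordered list. A minimal sketch of both calling styles, assuming `trainer` was built over a model with exactly two input variables, `features` and `labels` (all names and shapes here are illustrative, not part of the diff):

import numpy as np

# Assumed setup (not shown in the diff): `trainer` is a Trainer whose model
# has two input variables, `features` and `labels`.
x = np.asarray([[1.0, 2.0]], dtype=np.float32)
y = np.asarray([[0.0, 1.0]], dtype=np.float32)

# Dict style: map each input variable to its minibatch data.
trainer.train_minibatch({features: x, labels: y})

# List style: data in the order reported by model.arguments().
trainer.train_minibatch([x, y])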

View file

@@ -1,5 +1,4 @@
 # Copyright (c) Microsoft. All rights reserved.
 # Licensed under the MIT license. See LICENSE.md file in the project root
 # for full license information.
 # ==============================================================================
@@ -7,9 +6,11 @@
 import os
 import sys
 import numbers
+import collections

 import numpy as np
 import scipy.sparse
 from cntk import cntk_py
+from .persist import *

 def precision_numpy(precision):
@@ -415,42 +416,80 @@ def sanitize_batch(batch, data_type=None, device=None):
     return value

-def sanitize_var_map(input_map, precision_numpy=None, device=None, add_batch_axis=False):
+def sanitize_var_map(op_arguments, arguments, precision_numpy=None, device=None, add_batch_axis=False):
     '''
     Sanitizes a dictionary of `Variable`s to input data such that it can be
     handed off to the `Forward` method.

     Args:
-        input_map (`dict`): `Variable` to input (NumPy array or simple list of lists)
+        op_arguments (`list`): arguments of the root function. In the
+         forward pass this is typically `op.arguments()`; in the backward
+         pass it is `op.outputs()`.
+        arguments (`dict` or `list`): map from input variables to the data,
+         or list of inputs in the order that the function expects, or a
+         single input if the function has only one argument. Data should be
+         either NumPy arrays or cntk.Value instances returned by a minibatch
+         source.
         precision_numpy: `np.float32`, `np.float64`, or `None`
         device (`DeviceDescriptor` or `None`): CNTK DeviceDescriptor
-        add_batch_axis (`bool`): data in `input_map` are single instances and a batch axis has to be added
+        add_batch_axis (`bool`): data in `arguments` are single instances and a batch axis has to be added

     Returns:
         `dict` that maps variables to sanitized batches
     '''
-    var_map = {}
-    if input_map:
-        for var, batch in input_map.items():
-            from ..cntk_py import Value
-            if not isinstance(batch, Value):
-                if add_batch_axis:
-                    batch = [batch]
-                if isinstance(batch, np.ndarray):
-                    if batch.dtype == np.int:
-                        batch = batch.astype(np.float32)
-                    if batch.dtype not in (np.float32, np.float64):
-                        raise ValueError(
-                            'only float32 and float64 are supported')
-                    batch = sanitize_batch(batch, precision_numpy, device)
-                else:
-                    if is_tensor(batch):
-                        batch = np.asarray(batch, dtype=precision_numpy)
-                        batch = create_Value_from_NumPy(batch, device)
-                    else:
-                        batch = sanitize_batch(batch, precision_numpy, device)
-            var_map[var] = batch
+    if not arguments:
+        if len(op_arguments) > 0:
+            raise ValueError('function expects %i arguments' % len(op_arguments))
+        return {}
+
+    if len(op_arguments) == 1 and not isinstance(arguments, dict):
+        return {op_arguments[0]: arguments}
+
+    if isinstance(arguments, dict):
+        arg_names = [var.name() for var in op_arguments]
+        name_counter = collections.Counter(arg_names)
+        # Keyed by name, so that string keys below can be resolved back to
+        # their Variable.
+        var_name_map = dict((var.name(), var) for var in op_arguments)
+    elif isinstance(arguments, list):
+        arguments = dict(zip(op_arguments, arguments))
+    else:
+        raise ValueError('type "%s" is not supported' % type(arguments))
+
+    if len(arguments) < len(op_arguments):
+        raise ValueError('expected %i arguments, but got %i' % (len(op_arguments), len(arguments)))
+
+    var_map = {}
+    for var, batch in arguments.items():
+        if isinstance(var, str):
+            if name_counter[var] == 0:
+                raise ValueError('variable with name "%s" does not exist in the network. Available variable names: %s' % (var, ", ".join(var_name_map)))
+            elif name_counter[var] > 1:
+                raise ValueError('node name "%s" is not unique' % var)
+            var = var_name_map[var]
+
+        from ..cntk_py import Value
+        if not isinstance(batch, Value):
+            if add_batch_axis:
+                batch = [batch]
+            if isinstance(batch, np.ndarray):
+                if batch.dtype == np.int:
+                    batch = batch.astype(np.float32)
+                if batch.dtype not in (np.float32, np.float64):
+                    raise ValueError('only float32 and float64 are supported')
+                batch = sanitize_batch(batch, precision_numpy, device)
+            else:
+                if is_tensor(batch):
+                    if precision_numpy is None:
+                        precision_numpy = np.float32
+                    batch = np.asarray(batch, dtype=precision_numpy)
+                    batch = create_Value_from_NumPy(batch, device)
+                else:
+                    batch = sanitize_batch(batch, precision_numpy, device)
+
+        var_map[var] = batch
+
     return var_map
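The string-key resolution above hinges on `var_name_map` being keyed by name, so that the lookup `var = var_name_map[var]` can translate a string key back to its `Variable`. A self-contained, pure-Python sketch of the same resolution logic, with a toy `Var` class standing in for `cntk_py.Variable` (all names here are illustrative):

import collections

class Var(object):
    # Toy stand-in for cntk_py.Variable; it only carries a name.
    def __init__(self, name):
        self._name = name
    def name(self):
        return self._name

def resolve_keys(op_arguments, arguments):
    # Mirror of the duplicate/missing-name checks in sanitize_var_map:
    # string keys are translated back to their Variable objects.
    name_counter = collections.Counter(var.name() for var in op_arguments)
    # Keyed by name, so a string key can be mapped back to its Variable.
    var_name_map = dict((var.name(), var) for var in op_arguments)
    resolved = {}
    for var, batch in arguments.items():
        if isinstance(var, str):
            if name_counter[var] == 0:
                raise ValueError('variable with name "%s" does not exist' % var)
            elif name_counter[var] > 1:
                raise ValueError('node name "%s" is not unique' % var)
            var = var_name_map[var]
        resolved[var] = batch
    return resolved

i1, i2 = Var('i1'), Var('i2')
resolved = resolve_keys([i1, i2], {'i1': [1, 2], i2: [3, 4]})
assert resolved[i1] == [1, 2] and resolved[i2] == [3, 4]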
@@ -605,7 +644,7 @@ def ensure_dev(ndav, dev):
     return ndav

-def eval(op, precision, device, input_map=None, backward_pass=False):
+def eval(op, precision, device, arguments=None, backward_pass=False):
     '''
     Evaluates `op` on the data provided by the reader. This is useful
     mainly to explore the operators and for convenient unit testing.
@@ -614,7 +653,10 @@ def eval(op, precision, device, input_map=None, backward_pass=False):
         op (:class:`Function`): operation to evaluate
         precision (`str` or `None`): precision being 'float32', 'float64', or `None`, in which case it will be determined by inspecting the operator (costly)
         device (:class:`cntk.DeviceDescriptor`): the device descriptor, whether it is CPU or GPU (and which one)
-        input_map (`dict`): describes how to map inputs to the data in a data file using a number, NumPy array or reader object
+        arguments (`dict` or `list`): map from input variables to the data,
+         or list of inputs in the order that the function expects. Data
+         should be either NumPy arrays or cntk.Value instances returned by a
+         minibatch source.
         backward_pass (`bool`, optional): whether a backward pass is performed

     Returns:
@@ -624,7 +666,7 @@ def eval(op, precision, device, input_map=None, backward_pass=False):
     if precision is not None:
         precision = precision_numpy(precision)

-    forward_in_var_map = sanitize_var_map(input_map, precision, device)
+    forward_in_var_map = sanitize_var_map(op.arguments(), arguments, precision, device)

     forward_out_var_map = {}
     forward_retain = set()
@@ -649,7 +691,7 @@ def eval(op, precision, device, input_map=None, backward_pass=False):
     root_gradients = {}
     for v, o in forward_output.items():
         root_gradients[v] = ones_like(o, precision)
-    root_gradients = sanitize_var_map(root_gradients, precision, device)
+    root_gradients = sanitize_var_map(op.outputs(), root_gradients, precision, device)

     backward_var_map = dict((var, None) for var in forward_in_var_map)
View file

@@ -0,0 +1,35 @@
+# Copyright (c) Microsoft. All rights reserved.
+# Licensed under the MIT license. See LICENSE.md file in the project root
+# for full license information.
+# ==============================================================================
+
+from cntk import cntk_py
+
+def save_model(root_op, filename):
+    '''
+    Save the network rooted at `root_op` to the file `filename`.
+
+    Args:
+        root_op (:class:`cntk.functions.Function`): root op of the graph to save
+        filename (`str`): filename to store the model in
+    '''
+    cntk_py.save_as_legacy_model(root_op, filename)
+
+def load_model(data_type, filename, device=None):
+    '''
+    Load a network from `filename`, which has been saved using
+    :func:`save_model`.
+
+    Args:
+        data_type ('float', 'double', or NumPy type): data type of the operation
+        filename (`str`): filename to load the model from
+        device (:class:`cntk.DeviceDescriptor`, defaults to the default
+         device): instance of DeviceDescriptor
+
+    Returns:
+        root node
+    '''
+    from cntk.utils import sanitize_dtype_cntk
+    data_type = sanitize_dtype_cntk(data_type)
+    if not device:
+        device = cntk_py.DeviceDescriptor.use_default_device()
+    # Pass the resolved device to the loader (assumes the binding accepts
+    # the optional device argument); otherwise it would go unused.
+    return cntk_py.load_legacy_model(data_type, filename, device)
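A round-trip sketch of the new persistence helpers; this mirrors what the unit test below verifies. Note that `load_model` requires the data type up front, since the legacy format does not supply it (the filename here is illustrative):

import numpy as np
from cntk.ops import constant
from cntk.utils import save_model, load_model

# Build a trivial graph, save it, and load it back as float32.
root = constant(value=[1, 3]) * 5
save_model(root, 'roundtrip.mod')   # illustrative filename
loaded = load_model('float', 'roundtrip.mod')
assert np.allclose(loaded.eval(), root.eval())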

View file

@@ -0,0 +1,53 @@
+# Copyright (c) Microsoft. All rights reserved.
+# Licensed under the MIT license. See LICENSE.md file in the project root
+# for full license information.
+# ==============================================================================
+
+import numpy as np
+import pytest
+
+from cntk.ops import *
+from cntk.utils import load_model, save_model
+
+def test_load_save_constant():
+    c = constant(value=[1, 3])
+    root_node = c * 5
+
+    result = root_node.eval()
+    expected = [[[[5, 15]]]]
+    assert np.allclose(result, expected)
+
+    filename = 'c_plus_c.mod'
+    save_model(root_node, filename)
+
+    loaded_node = load_model('float', filename)
+    loaded_result = loaded_node.eval()
+    assert np.allclose(loaded_result, expected)
+
+def test_load_save_inputs():
+    i1 = input_variable((1, 2), name='i1')
+    i2 = input_variable((2, 1), name='i2')
+    root_node = plus(i1, i2)
+
+    input1 = [[[1, 2]]]
+    input2 = [[[[1], [2]]]]
+    result = root_node.eval({i1: input1, i2: input2})
+    expected = [[[[2, 3], [3, 4]]]]
+    assert np.allclose(result, expected)
+
+    filename = 'i_plus_c.mod'
+    save_model(root_node, filename)
+
+    loaded_node = load_model('float', filename)
+
+    # Test specifying the input nodes by name
+    # FIXME: node names are not properly saved on the C++ side yet
+    if False:
+        loaded_result = loaded_node.eval({'i1': input1, 'i2': input2})
+        assert np.allclose(loaded_result, expected)
+
+    # Test specifying the inputs by order
+    loaded_result = loaded_node.eval([input1, input2])
+    assert np.allclose(loaded_result, expected)