Merge remote-tracking branch 'origin/wilrich/miscAlpha2' into wilrich/miscAlpha2
This commit is contained in:
Коммит
b4bcc96e38
|
@ -24,6 +24,7 @@ class Trainer(cntk_py.Trainer):
|
|||
def __init__(self, model, loss_function, eval_function, parameter_learners):
|
||||
if isinstance(model, cntk_py.Variable):
|
||||
model = model.owner
|
||||
self.model = model
|
||||
if isinstance(loss_function, cntk_py.Variable):
|
||||
loss_function = loss_function.owner
|
||||
if isinstance(eval_function, cntk_py.Variable):
|
||||
|
@ -37,8 +38,10 @@ class Trainer(cntk_py.Trainer):
|
|||
Returns false if all parameter learners indicate end of learning (through their Update method's return value).
|
||||
|
||||
Args:
|
||||
arguments (dict): map from input variables to the data, data should be either numpy
|
||||
arrays or cntk.Value instances returned by a minibatch source
|
||||
arguments (`dict` or `list`): map from input variables to the data
|
||||
or list of inputs in the order that the function expects. Data
|
||||
should be either NumPy arrays or cntk.Value instances returned by a
|
||||
minibatch source.
|
||||
device (:class:`cntk.DeviceDescriptor`): the device descriptor that
|
||||
contains the type and id of the device on which the computation is
|
||||
to be performed.
|
||||
|
@ -48,7 +51,7 @@ class Trainer(cntk_py.Trainer):
|
|||
'''
|
||||
if not device:
|
||||
device=DeviceDescriptor.use_default_device()
|
||||
arguments = sanitize_var_map(arguments, add_batch_axis=True)
|
||||
arguments = sanitize_var_map(self.model.arguments(), arguments, add_batch_axis=True)
|
||||
|
||||
return super(Trainer, self).train_minibatch(arguments, device)
|
||||
|
||||
|
@ -59,8 +62,10 @@ class Trainer(cntk_py.Trainer):
|
|||
of samples.
|
||||
|
||||
Args:
|
||||
arguments (dict): map from input variables to the data, data should be either numpy
|
||||
arrays or cntk.Value instances returned by a minibatch source
|
||||
arguments (`dict` or `list`): map from input variables to the data
|
||||
or list of inputs in the order that the function expects. Data
|
||||
should be either NumPy arrays or cntk.Value instances returned by a
|
||||
minibatch source.
|
||||
device (:class:`cntk.DeviceDescriptor`): the device descriptor that
|
||||
contains the type and id of the device on which the computation is
|
||||
to be performed.
|
||||
|
@ -70,7 +75,7 @@ class Trainer(cntk_py.Trainer):
|
|||
'''
|
||||
if not device:
|
||||
device=DeviceDescriptor.use_default_device()
|
||||
arguments = sanitize_var_map(arguments, add_batch_axis=True)
|
||||
arguments = sanitize_var_map(self.model.arguments(), arguments, add_batch_axis=True)
|
||||
|
||||
return super(Trainer, self).test_minibatch(arguments, device)
|
||||
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
# Licensed under the MIT license. See LICENSE.md file in the project root
|
||||
# for full license information.
|
||||
# ==============================================================================
|
||||
|
@ -7,9 +6,11 @@
|
|||
import os
|
||||
import sys
|
||||
import numbers
|
||||
import collections
|
||||
import numpy as np
|
||||
import scipy.sparse
|
||||
from cntk import cntk_py
|
||||
from .persist import *
|
||||
|
||||
|
||||
def precision_numpy(precision):
|
||||
|
@ -415,42 +416,80 @@ def sanitize_batch(batch, data_type=None, device=None):
|
|||
return value
|
||||
|
||||
|
||||
def sanitize_var_map(op_arguments, arguments, precision_numpy=None, device=None, add_batch_axis=False):
    '''
    Sanitizes a dictionary of `Variable`s to input data such that it can be
    handed off to the `Forward` method.

    Args:
        op_arguments (sequence of `Variable`): arguments of the root function. In
         forward pass it is typically `op.arguments()`, in backward mode it is
         `op.outputs()`
        arguments (`dict` or `list`): map from input variables (or their
         names) to the data, or list of inputs in the order that the function
         expects, or a single input if the function only has one argument.
         Data should be either NumPy arrays or cntk.Value instances returned
         by a minibatch source.
        precision_numpy : `np.float32`, `np.float64`, or `None`
        device (`DeviceDescriptor` or `None`): CNTK DeviceDescriptor
        add_batch_axis (`bool`): data in `arguments` are single instances and a batch axis has to be added

    Returns:
        `dict` that maps variables to sanitized batches

    Raises:
        ValueError: if no data is given although the function expects
         arguments, if fewer inputs than expected are given, if `arguments`
         is of an unsupported type, if a name key is unknown or ambiguous, or
         if a NumPy array is neither integer, float32 nor float64.
    '''
    # Explicit emptiness test: `not arguments` raises for a multi-element
    # NumPy array, which is a legal single input.
    if arguments is None or (
            isinstance(arguments, (dict, list, tuple)) and len(arguments) == 0):
        if len(op_arguments) > 0:
            raise ValueError('function expects %i arguments'%len(op_arguments))
        return {}

    if len(op_arguments) == 1 and not isinstance(arguments, dict):
        # NOTE(review): the single input is handed back as-is, i.e. without
        # the batch sanitization applied below - confirm callers expect that.
        return { op_arguments[0] : arguments }

    if isinstance(arguments, list):
        # Positional inputs are matched to the function arguments by order.
        arguments = dict(zip(op_arguments, arguments))
    elif not isinstance(arguments, dict):
        raise ValueError('type "%s" is not supported'%type(arguments))

    if len(arguments) < len(op_arguments):
        raise ValueError('expected %i arguments, but got %i'%(len(op_arguments), len(arguments)))

    # Lookup tables for keys given as names. The map must go name -> variable
    # (it was previously built the other way round, so name lookups failed).
    name_counter = collections.Counter(var.name() for var in op_arguments)
    var_by_name = dict((var.name(), var) for var in op_arguments)

    # First pass: resolve name keys to their variables.
    resolved = []
    for var, batch in arguments.items():
        if isinstance(var, str):
            if name_counter[var] == 0:
                raise ValueError('variable with name "%s" does not exist in the network. Available variable names: %s'%(var, ", ".join(var_by_name)))
            elif name_counter[var] > 1:
                raise ValueError('node name "%s" is not unique'%var)

            var = var_by_name[var]
        resolved.append((var, batch))

    from ..cntk_py import Value

    # Second pass: convert everything that is not already a Value.
    var_map = {}
    for var, batch in resolved:
        if not isinstance(batch, Value):
            if add_batch_axis:
                batch = [batch]
            if isinstance(batch, np.ndarray):
                # `np.int` is no longer available in current NumPy releases;
                # integer data is promoted to float32.
                if np.issubdtype(batch.dtype, np.integer):
                    batch = batch.astype(np.float32)
                if batch.dtype not in (np.float32, np.float64):
                    raise ValueError('only float32 and float64 are supported')
                batch = sanitize_batch(batch, precision_numpy, device)
            else:
                if is_tensor(batch):
                    # Use a local default so that the caller's precision is
                    # not silently changed for the remaining entries.
                    tensor_dtype = np.float32 if precision_numpy is None else precision_numpy
                    batch = np.asarray(batch, dtype=tensor_dtype)
                    batch = create_Value_from_NumPy(batch, device)
                else:
                    batch = sanitize_batch(batch, precision_numpy, device)

        var_map[var] = batch

    return var_map
|
||||
|
||||
|
@ -605,7 +644,7 @@ def ensure_dev(ndav, dev):
|
|||
|
||||
return ndav
|
||||
|
||||
def eval(op, precision, device, input_map=None, backward_pass=False):
|
||||
def eval(op, precision, device, arguments=None, backward_pass=False):
|
||||
'''
|
||||
It evaluates `op` on the data provided by the reader. This is useful
|
||||
mainly to explore the operators and for convenient unit testing.
|
||||
|
@ -614,7 +653,10 @@ def eval(op, precision, device, input_map=None, backward_pass=False):
|
|||
op (:class:`Function`): operation to evaluate
|
||||
precision (`str` or `None`): precision being 'float32', 'float64', or `None`, in which case it will be determined by inspecting the operator (costly)
|
||||
device (:class:`cntk.DeviceDescriptor`): the device descriptor, whether it is CPU or GPU (and which one)
|
||||
input_map (`dict`): describes how to map inputs to the data in a data file using a number, NumPy array or reader object
|
||||
arguments (`dict` or `list`): map from input variables to the data
|
||||
or list of inputs in the order that the function expects. Data
|
||||
should be either NumPy arrays or cntk.Value instances returned by a
|
||||
minibatch source.
|
||||
backward_pass (`bool`, optional): whether a backward pass is performed
|
||||
|
||||
Returns:
|
||||
|
@ -624,7 +666,7 @@ def eval(op, precision, device, input_map=None, backward_pass=False):
|
|||
if precision is not None:
|
||||
precision = precision_numpy(precision)
|
||||
|
||||
forward_in_var_map = sanitize_var_map(input_map, precision, device)
|
||||
forward_in_var_map = sanitize_var_map(op.arguments(), arguments, precision, device)
|
||||
|
||||
forward_out_var_map = {}
|
||||
forward_retain = set()
|
||||
|
@ -649,7 +691,7 @@ def eval(op, precision, device, input_map=None, backward_pass=False):
|
|||
root_gradients = {}
|
||||
for v, o in forward_output.items():
|
||||
root_gradients[v] = ones_like(o, precision)
|
||||
root_gradients = sanitize_var_map(root_gradients, precision, device)
|
||||
root_gradients = sanitize_var_map(op.outputs(), root_gradients, precision, device)
|
||||
|
||||
backward_var_map = dict((var, None) for var in forward_in_var_map)
|
||||
|
||||
|
|
|
@ -0,0 +1,35 @@
|
|||
# Copyright (c) Microsoft. All rights reserved.
|
||||
# Licensed under the MIT license. See LICENSE.md file in the project root
|
||||
# for full license information.
|
||||
# ==============================================================================
|
||||
|
||||
from cntk import cntk_py
|
||||
|
||||
def save_model(root_op, filename):
    '''
    Save the network of `root_op` in `filename`.

    Args:
        root_op (`:class:cntk.functions.Function`): op of the graph to save
        filename (`str`): filename to store the model in
    '''
    # Delegates directly to the legacy-model writer of the cntk_py bindings.
    cntk_py.save_as_legacy_model(root_op, filename)
|
||||
|
||||
def load_model(data_type, filename, device=None):
    '''
    Load the network stored in `filename`, that has been saved using
    `:func:save_model`.

    Args:
        data_type ('float' or 'double', or NumPy type): data type of the operation
        filename (`str`): filename to load the model from
        device (:class:`cntk.DeviceDescriptor`, default to default device): instance of DeviceDescriptor

    Returns:
        root node
    '''
    # Imported here rather than at module level, as in the original.
    from cntk.utils import sanitize_dtype_cntk
    data_type = sanitize_dtype_cntk(data_type)
    if not device:
        device = cntk_py.DeviceDescriptor.use_default_device()
    # Pass the device on; previously it was resolved but never used, so a
    # caller-supplied device was silently ignored.
    return cntk_py.load_legacy_model(data_type, filename, device)
|
|
@ -0,0 +1,53 @@
|
|||
# Copyright (c) Microsoft. All rights reserved.
|
||||
# Licensed under the MIT license. See LICENSE.md file in the project root
|
||||
# for full license information.
|
||||
# ==============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from cntk.ops import *
|
||||
from cntk.utils import load_model, save_model
|
||||
|
||||
|
||||
def test_load_save_constant():
    # A constant-only graph must evaluate to the same result before and
    # after a save/load round trip.
    expected = [[[[5,15]]]]
    filename = 'c_plus_c.mod'

    root_node = constant(value=[1,3]) * 5
    assert np.allclose(root_node.eval(), expected)

    save_model(root_node, filename)
    restored_node = load_model('float', filename)

    assert np.allclose(restored_node.eval(), expected)
|
||||
|
||||
def test_load_save_inputs():
    # A graph with two inputs must evaluate to the same result before and
    # after a save/load round trip.
    input1 = [[[1,2]]]
    input2 = [[[[1],[2]]]]
    expected = [[[[2,3],[3,4]]]]

    i1 = input_variable((1,2), name='i1')
    i2 = input_variable((2,1), name='i2')
    root_node = plus(i1, i2)

    assert np.allclose(root_node.eval({i1: input1, i2: input2}), expected)

    filename = 'i_plus_c.mod'
    save_model(root_node, filename)
    loaded_node = load_model('float', filename)

    # Test specifying the input nodes by name
    # FIXME: node names are not properly saved on C++ yet
    if False:
        result_by_name = loaded_node.eval({'i1': input1, 'i2': input2})
        assert np.allclose(result_by_name, expected)

    # Test specifying the inputs by order
    result_by_order = loaded_node.eval([input1, input2])
    assert np.allclose(result_by_order, expected)
|
||||
|
Загрузка…
Ссылка в новой задаче