Merge branch 'master' into ebarsoum/ImageHandsOn

For checkin...
Emad Barsoum 2016-10-24 23:36:27 -07:00
Parent 56b2c6826f b60acf9e55
Commit 496c67299d
17 changed files: 210 additions and 123 deletions

View file

@@ -22,7 +22,7 @@ rootDir = "."
dataDir = "$rootDir$/data/"
outputDir = "$rootDir$/Output"
modelPath = "$outputDir$/Fast-RCNN"
modelPath = "$outputDir$/Fast-RCNN.model"
#stderr = "$outputDir$/Fast-RCNN.log"
ImageH = 1000

View file

@@ -20,7 +20,4 @@ from .io import *
from .persist import load_model, save_model
from .device import *
# TODO wrap
from .cntk_py import momentums_per_sample
DATATYPE = np.float32

View file

@@ -1252,7 +1252,7 @@ StreamInformation.__eq__ = lambda a,b: a.m_name==b.m_name and a.m_id==b.m_id and
# in case of multiple outputs return the function, not the variable
def get_output_and_keep_reference(self):
    variable = self._output()
    variable.owner = self
    variable.__owner = self
    return variable

Function.output = lambda self:get_output_and_keep_reference(self)
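The rename to a private attribute matters because Variable defines an 'owner' property (see the variables.py changes below); the attribute's only job is to keep the producing Function alive. A minimal sketch of that keep-alive pattern, using hypothetical stand-in classes rather than the real CNTK types:

import weakref

class Function(object):
    def _output(self):
        return Variable()

class Variable(object):
    pass

def get_output_and_keep_reference(fn):
    variable = fn._output()
    # Storing the function on the variable keeps it alive for as long as
    # the variable is; '__owner' avoids shadowing the 'owner' property
    # that real CNTK variables define.
    variable.__owner = fn
    return variable

f = Function()
v = get_output_and_keep_reference(f)
r = weakref.ref(f)
del f
assert r() is not None  # v's strong reference keeps the function alive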

View file

@@ -0,0 +1,60 @@
# Copyright (c) Microsoft. All rights reserved.
# Licensed under the MIT license. See LICENSE.md file in the project root
# for full license information.
# ==============================================================================

def dfs_walk(node, visitor):
    '''
    Generic function that walks through the graph starting at ``node`` and
    uses function ``visitor`` on each node to check whether it should be
    returned.

    Args:
        node (graph node): the node to start the journey from
        visitor (Python function or lambda): function that takes a node as
         argument and returns ``True`` if that node should be returned.

    Returns:
        List of nodes, for which ``visitor`` was ``True``
    '''
    stack = [node]
    accum = []
    visited = set()

    while stack:
        node = stack.pop()
        if node in visited:
            continue

        try:
            # Function node
            stack.extend(node.root_function.inputs)
        except AttributeError:
            # OutputVariable node
            try:
                if node.is_output:
                    stack.append(node.owner)
            except AttributeError:
                pass

        if visitor(node):
            accum.append(node)

        visited.add(node)

    return accum

def find_nodes_by_name(node, node_name):
    '''
    Finds nodes in the graph starting from `node` and doing a depth-first
    search.

    Args:
        node (graph node): the node to start the journey from
        node_name (`str`): name for which we are searching nodes

    Returns:
        List of nodes having the specified name
    '''
    return dfs_walk(node, lambda x: x.name == node_name)
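A quick usage sketch for the helpers above; the import path assumes this file lands as cntk/graph.py, and the op names are taken from the test added later in this commit:

# Usage sketch (assumed import paths; ops names from the test below).
from cntk.graph import find_nodes_by_name
from cntk.ops import input_variable, parameter, plus, times

i = input_variable(shape=(2, 3), name='features')
w = parameter(shape=(3, 2), name='weights')
root = times(plus(i, i, name='sum'), w, name='proj')

# dfs_walk visits every reachable node once; the name predicate collects
# all matches, so duplicate names yield a list with several entries.
assert [n.name for n in find_nodes_by_name(root, 'weights')] == ['weights']
assert find_nodes_by_name(root, 'missing') == []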

View file

@@ -90,11 +90,11 @@ class Function(cntk_py.Function):
            return self(other)

    def __getattr__(self, name):
        if name in self.__dict__:
        try:
            return self.__dict__[name]
        if len(self.outputs) == 1:
            return getattr(self.output, name)
        except KeyError:
            if len(self.outputs) == 1:
                return getattr(self.output, name)

        raise AttributeError("'%s' object has no attribute '%s'" %
                             (type(self), name))
@@ -413,3 +413,12 @@ class Function(cntk_py.Function):
        The primitive function at the root of the graph of functions underlying this function.
        '''
        return super(Function, self).root_function()

    @property
    @typemap
    def uid(self):
        '''
        The internally generated unique name of the function.
        '''
        return super(Function, self).uid()
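The __getattr__ fix above changes the fallback path: a KeyError from self.__dict__ now forwards the lookup to the single output instead of escaping. A small sketch of the resulting behavior (op names as used elsewhere in this commit):

from cntk.ops import input_variable, plus

x = input_variable(shape=(3,), name='x')
f = plus(x, x)

# 'shape' lives on the output variable, not on the Function itself;
# the try/except forwards the lookup rather than raising.
assert f.shape == (3,)
assert f.output.shape == (3,)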

View file

@@ -3,41 +3,18 @@ from cntk import cntk_py, utils
from ..tensor import TensorOpsMixin
from ..utils import typemap, sanitize_precision, sanitize_value, sanitize_dtype_cntk

class Variable(TensorOpsMixin, cntk_py.Variable):
class VariableMixin:
    '''
    Denotes a symbolic entity corresponding to the inputs and outputs of a Function.

    Args:
        shape (`tuple`): the shape of this variable.
        dtype (`np.float32` or `np.float64`): data type of the values that will be bound to this variable.
         Default is np.float32
        needs_gradient (`bool`): if set to True any expression that contains this variable
         will also be differentiated with respect to this variable.
        is_sparse(`bool`): whether this is a sparse or dense input (or output)
        dynamic_axes(`list` of :class:`cntk.axis.Axis`): the dynamic axes of this variable. These
         express dimensions that can vary across examples or minibatches.
        name(`str`): an optional name for this parameter.
    Standard properties for :class:`Variable` and its derived classes
    :class:`Parameter` and :class:`Constant`.
    '''
    def __init__(self, shape=None, dtype=None, needs_gradient=False, is_sparse=False,
                 dynamic_axes=[cntk_py.Axis.default_dynamic_axis(), cntk_py.Axis.default_batch_axis()], name=''):
        shape = utils.sanitize_shape(shape)

        if dtype is None:
            dtype = np.float32
        cntk_dtype = utils.sanitize_dtype_cntk(dtype)

        super(Variable, self).__init__(shape, is_sparse, cntk_dtype,
                                       needs_gradient, name, dynamic_axes)

    @property
    @typemap
    def dynamic_axes(self):
        '''
        The dynamic axes of this variable.
        '''
        return super(Variable, self).dynamic_axes()
        return super().dynamic_axes()

    @property
    def dtype(self):
@@ -51,56 +28,56 @@ class Variable(TensorOpsMixin, cntk_py.Variable):
        '''
        Whether this variable is a constant.
        '''
        return super(Variable, self).is_constant()
        return super().is_constant()

    @property
    def is_input(self):
        '''
        Whether this variable is an input.
        '''
        return super(Variable, self).is_input()
        return super().is_input()

    @property
    def is_output(self):
        '''
        Whether this variable is an output.
        '''
        return super(Variable, self).is_output()
        return super().is_output()

    @property
    def is_parameter(self):
        '''
        Whether this variable is a parameter.
        '''
        return super(Variable, self).is_parameter()
        return super().is_parameter()

    @property
    def is_placeholder(self):
        '''
        Whether this variable is a placeholder.
        '''
        return super(Variable, self).is_placeholder()
        return super().is_placeholder()

    @property
    def is_sparse(self):
        '''
        Whether this variable is sparse.
        '''
        return super(Variable, self).is_sparse()
        return super().is_sparse()

    @property
    def name(self):
        '''
        The name of this variable.
        '''
        return super(Variable, self).name()
        return super().name()

    @property
    def needs_gradient(self):
        '''
        Whether this variable needs gradients.
        '''
        return super(Variable, self).needs_gradient()
        return super().needs_gradient()

    @property
    @typemap
@@ -110,23 +87,51 @@ class Variable(TensorOpsMixin, cntk_py.Variable):
        '''
        if self.is_output == False:
            raise RuntimeError('called owner() on a variable that is not an output variable')
        return super(Variable, self).owner()
        return super().owner()

    @property
    def shape(self):
        '''
        The shape of this variable as a tuple.
        '''
        return super(Variable, self).shape().dimensions()
        return super().shape().dimensions()

    @property
    def uid(self):
        '''
        The internally generated unique name of the variable.
        '''
        return super(Variable, self).uid()
        return super().uid()

class Parameter(TensorOpsMixin, cntk_py.Parameter):
class Variable(VariableMixin, TensorOpsMixin, cntk_py.Variable):
    '''
    Denotes a symbolic entity corresponding to the inputs and outputs of a Function.

    Args:
        shape (`tuple`): the shape of this variable.
        data_type (`np.float32 or np.float64`): data type of the values that will be bound to this variable.
         Default is np.float32
        needs_gradient (`bool`): if set to True any expression that contains this variable
         will also be differentiated with respect to this variable.
        is_sparse(`bool`): whether this is a sparse or dense input (or output)
        dynamic_axes(`list` of :class:`cntk.axis.Axis`): the dynamic axes of this variable. These
         express dimensions that can vary across examples or minibatches.
        name(`str`): an optional name for this parameter.
    '''
    def __init__(self, shape=None, data_type=None, needs_gradient=False, is_sparse=False,
                 dynamic_axes=[cntk_py.Axis.default_dynamic_axis(), cntk_py.Axis.default_batch_axis()], name=''):
        shape = utils.sanitize_shape(shape)

        if data_type is None:
            data_type = np.float32
        dtype = utils.sanitize_dtype_cntk(data_type)

        super().__init__(shape, is_sparse, dtype, needs_gradient, name,
                         dynamic_axes)

class Parameter(VariableMixin, TensorOpsMixin, cntk_py.Parameter):
    '''
    A trainable parameter. It can be a scalar, vector, matrix, or tensor
    of floating point numbers that can be modified by a training
@@ -146,7 +151,7 @@ class Parameter(TensorOpsMixin, cntk_py.Parameter):
    Parameters are Variables and therefore they inherit all their methods.
    '''
    def __init__(self, shape=None, init=None, dtype=None,
            device=None, name=''):
                 device=None, name=''):

        if dtype is None:
            if isinstance(init, np.ndarray):
@@ -159,11 +164,11 @@ class Parameter(TensorOpsMixin, cntk_py.Parameter):
        if isinstance(init, (np.ndarray, list, float, int)):
            ndav = sanitize_value(shape, init, dtype, device)
            super(Parameter, self).__init__(ndav, name)
            super().__init__(ndav, name)
        else:
            shape = utils.sanitize_shape(shape)
            cntk_dtype = utils.sanitize_dtype_cntk(dtype)
            super(Parameter, self).__init__(shape, cntk_dtype, init,
            cntk_dtype = utils.sanitize_dtype_cntk(dtype)
            super().__init__(shape, cntk_dtype, init,
                             device, name)

    @property
@@ -171,23 +176,9 @@ class Parameter(TensorOpsMixin, cntk_py.Parameter):
        '''
        NumPy array of the value
        '''
        return super(Parameter, self).value().to_numpy()
        return super().value().to_numpy()

    @property
    def shape(self):
        '''
        The shape of this parameter as a tuple.
        '''
        return super(Parameter, self).shape().dimensions()

    @property
    def dtype(self):
        '''
        The NumPy type of this variable.
        '''
        return sanitize_precision(self.get_data_type())

class Constant(TensorOpsMixin, cntk_py.Constant):
class Constant(VariableMixin, TensorOpsMixin, cntk_py.Constant):
    '''
    A constant value. It can be a scalar, vector, matrix, or tensor
    of floating point numbers that cannot be modified.
@@ -210,49 +201,16 @@ class Constant(TensorOpsMixin, cntk_py.Constant):
            dtype = np.float32

        if np.isscalar(value):
            super(Constant, self).__init__(utils.sanitize_shape(shape), sanitize_dtype_cntk(dtype), value)
            super().__init__(utils.sanitize_shape(shape), sanitize_dtype_cntk(dtype), value)
        else:
            ndav = sanitize_value(shape, value, dtype, device)
            super(Constant, self).__init__(ndav, name)
            #ndav = sanitize_value(shape, value, dtype, device)
            #super(Constant, self).__init__(ndav, name)
            super().__init__(ndav, name)
        ##ndav = sanitize_value(shape, value, data_type, device)
        ##super(Constant, self).__init__(ndav, name)
        #
        ## from Parameter: [fseide]
        #if isinstance(value, (np.ndarray, list, float, int)):
        #    ndav = sanitize_value(shape, value, data_type, device)
        #    super(Constant, self).__init__(ndav, name)
        #else:
        #    shape = utils.sanitize_shape(shape)
        #    data_type = utils.sanitize_dtype_cntk(data_type)
        #    super(Constant, self).__init__(shape, data_type, value,
        #                                   device, name)

        #TODO how to expose Scalar ?

    @property
    def value(self):
        '''
        NumPy array of the value
        '''
        return super(Constant, self).value().to_numpy()

    @property
    def shape(self):
        '''
        The shape of this constant as tuple.
        '''
        return super(Constant, self).shape().dimensions()

    @property
    def dtype(self):
        '''
        The NumPy type of this variable.
        '''
        return sanitize_precision(self.get_data_type())
        return super().value().to_numpy()
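The net effect of the refactoring: shape, uid, is_parameter and the other read-only properties are written once on VariableMixin and inherited by Variable, Parameter and Constant, instead of being duplicated per class. A hedged sketch (module path and Constant constructor usage assumed from the relative imports and partial diff above):

import numpy as np
from cntk.ops.variables import Parameter, Constant

p = Parameter(shape=(2, 3), init=0.5, name='w')
c = Constant(value=np.ones((2, 3), dtype=np.float32))

# Both lookups resolve to the single VariableMixin.shape property,
# which previously existed as per-class duplicates.
assert p.shape == (2, 3)
assert c.shape == (2, 3)
assert p.is_parameter and c.is_constant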

View file

@@ -3,6 +3,7 @@
# for full license information.
# ==============================================================================

import numpy as np
from cntk import cntk_py
from .utils.swig_helper import typemap
from cntk.device import use_default_device
@@ -20,21 +21,23 @@ def save_model(root_op, filename, use_legacy_format=True):
    root_op.save_model(filename, use_legacy_format)

@typemap
def load_model(data_type, filename, device=None):
def load_model(filename, dtype=np.float32, device=None):
    '''
    Load the network in ``filename`` that has been saved using
    :func:`save_model`.

    Args:
        data_type ('float' or 'double', or NumPy type): data type of the operation
        filename (`str`): filename to load the model from
        device (:class:`cntk.device.DeviceDescriptor`, default to default device): instance of DeviceDescriptor
        dtype ('float', 'double', or NumPy type, default ``np.float32``): data
         type of the operation
        device (:class:`cntk.DeviceDescriptor`, default is the default device):
         instance of DeviceDescriptor

    Returns:
        root node
    '''
    from cntk.utils import sanitize_dtype_cntk
    data_type = sanitize_dtype_cntk(data_type)
    dtype = sanitize_dtype_cntk(dtype)
    if not device:
        device = use_default_device()
    return cntk_py.Function.load_model(data_type, filename, device)
    return cntk_py.Function.load_model(dtype, filename, device)
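The corresponding call-site change, mirrored by the tests below: the filename now comes first and the data type is optional. A sketch along the lines of test_load_save_constant (constant/plus import paths assumed):

import numpy as np
from cntk import load_model, save_model
from cntk.ops import constant, plus

root = plus(constant(value=np.float32(2)), constant(value=np.float32(3)))
save_model(root, 'c_plus_c.mod')

loaded = load_model('c_plus_c.mod')   # was: load_model('float', 'c_plus_c.mod')
assert np.allclose(loaded.eval(), 5.0)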

View file

@@ -0,0 +1,59 @@
# Copyright (c) Microsoft. All rights reserved.
# Licensed under the MIT license. See LICENSE.md file in the project root
# for full license information.
# ==============================================================================

import numpy as np

from ..graph import *
from ..ops import *
from ..axis import Axis

def _graph_dict():
    # This function creates a graph that has no real meaning other than
    # providing something to traverse.
    d = {}

    batch_axis = Axis.default_batch_axis()
    input_seq_axis = Axis('ia')
    input_dynamic_axes = [batch_axis, input_seq_axis]

    d['i1'] = input_variable(shape=(2,3), dynamic_axes=input_dynamic_axes, name='i1')
    d['i2'] = input_variable(shape=(2,3), dynamic_axes=input_dynamic_axes, name='i2')

    d['p1'] = parameter(shape=(3,2), name='p1')

    d['op1'] = plus(d['i1'], d['i2'], name='op1')
    d['op2'] = times(d['op1'], d['p1'], name='op2')

    #d['slice'] = slice(d['i2'], Axis.default_dynamic_axis(), 0, 3)
    #label_sentence_start = sequence.first(raw_labels)

    # no name
    d['p2'] = parameter(shape=(2,2))

    # duplicate names
    d['op3a'] = plus(d['op2'], d['p2'], name='op3')
    d['op3b'] = plus(d['op3a'], d['p2'], name='op3')

    d['first'] = sequence.first(d['op3b'], name='past')

    d['root'] = d['first']

    return d

def test_find_nodes():
    d = _graph_dict()

    for name in ['i1', 'i2', 'p1', 'op1', 'op2', 'past']:
        n = find_nodes_by_name(d['root'], name)
        assert len(n) == 1, name
        assert n[0].name == name, name

    n = find_nodes_by_name(d['root'], 'op3')
    assert len(n) == 2, 'op3'
    assert n[0].name == 'op3' and n[1].name == 'op3', 'op3'

    none = find_nodes_by_name(d['root'], 'none')
    assert none == []

View file

@@ -8,7 +8,7 @@ import numpy as np
import pytest

from ..initializer import *
from .. import parameter, input_variable, momentums_per_sample
from .. import parameter

def _check(init, name):

View file

@@ -21,7 +21,7 @@ def test_load_save_constant(tmpdir):
    filename = str(tmpdir / 'c_plus_c.mod')
    save_model(root_node, filename)

    loaded_node = load_model('float', filename)
    loaded_node = load_model(filename)
    loaded_result = loaded_node.eval()
    assert np.allclose(loaded_result, expected)
@@ -37,7 +37,7 @@ def test_load_save_input(tmpdir):
    filename = str(tmpdir / 'i_plus_c_0.mod')
    save_model(root_node, filename)

    loaded_node = load_model('float', filename)
    loaded_node = load_model(filename)

    # Test specifying the input node names by order
    loaded_result = loaded_node.eval([input1])
@@ -57,7 +57,7 @@ def test_load_save_inputs(tmpdir):
    filename = str(tmpdir / 'i_plus_i_0.mod')
    save_model(root_node, filename)

    loaded_node = load_model('float', filename)
    loaded_node = load_model(filename)

    # Test specifying the input nodes by name
    loaded_result = loaded_node.eval({'i1': input1, 'i2': input2})
@@ -75,7 +75,7 @@ def test_load_save_unique_input(tmpdir):
    filename = str(tmpdir / 'i_plus_0.mod')
    save_model(root_node, filename)

    loaded_node = load_model('float', filename)
    loaded_node = load_model(filename)

    # Test specifying the only value for a unique input
    loaded_result = loaded_node.eval(input1)

View file

@@ -33,6 +33,9 @@ class Trainer(cntk_py.Trainer):
        loss_function = sanitize_function(loss_function)
        eval_function = sanitize_function(eval_function)

        if not isinstance(parameter_learners, list):
            parameter_learners = [parameter_learners]

        super(Trainer, self).__init__(model, loss_function, eval_function,
                                      parameter_learners)
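Every example below is updated for this convenience, so here is the shape of the change once: a bare learner no longer needs the list wrapper. The model setup in this sketch is a hypothetical minimal example (module paths assumed):

from cntk import Trainer
from cntk.learner import sgd
from cntk.ops import (input_variable, parameter, times,
                      cross_entropy_with_softmax, classification_error)

features = input_variable(shape=(4,), name='features')
label = input_variable(shape=(2,), name='label')
z = times(features, parameter(shape=(4, 2)))
ce = cross_entropy_with_softmax(z, label)
pe = classification_error(z, label)

# Both spellings are now accepted; the bare learner is wrapped internally.
trainer = Trainer(z, ce, pe, sgd(z.parameters, lr=0.005))
trainer = Trainer(z, ce, pe, [sgd(z.parameters, lr=0.005)])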

View file

@@ -131,7 +131,7 @@ def train_and_evaluate(reader_train, reader_test, max_epochs):
    lr_schedule = learning_rate_schedule(lr_per_sample, units=epoch_size)
    learner = momentum_sgd(z.parameters, lr_schedule, momentum_per_sample,
                           l2_regularization_weight = l2_reg_weight)
    trainer = Trainer(z, ce, pe, [learner])
    trainer = Trainer(z, ce, pe, learner)

    # define mapping from reader streams to network inputs
    input_map = {

View file

@@ -154,7 +154,7 @@ def train_and_evaluate(reader_train, reader_test, max_epochs):
    lr_schedule = learning_rate_schedule(lr_per_sample, units=epoch_size)
    learner = momentum_sgd(z.parameters, lr_schedule, momentum_per_sample,
                           l2_regularization_weight = l2_reg_weight)
    trainer = Trainer(z, ce, pe, [learner])
    trainer = Trainer(z, ce, pe, learner)

    # define mapping from reader streams to network inputs
    input_map = {

View file

@@ -61,8 +61,7 @@ def simple_mnist(debug_output=False):
    labels_si = mb_source[labels_stream_name]

    # Instantiate the trainer object to drive the model training
    trainer = Trainer(netout, ce, pe, [sgd(netout.parameters,
                                           lr=0.003125)])
    trainer = Trainer(netout, ce, pe, sgd(netout.parameters, lr=0.003125))

    # Get minibatches of images to train with and perform model training
    minibatch_size = 64

View file

@@ -53,7 +53,7 @@ def ffnet():
    pe = classification_error(netout, label)

    # Instantiate the trainer object to drive the model training
    trainer = Trainer(netout, ce, pe, [sgd(netout.parameters, lr=0.005)])
    trainer = Trainer(netout, ce, pe, sgd(netout.parameters, lr=0.02))

    # Get minibatches of training data and perform model training
    minibatch_size = 25

View file

@@ -116,9 +116,8 @@ def sequence_to_sequence_translator(debug_output=False, run_test=False):
    m_schedule = momentum_schedule(momentum_time_constant)
    clipping_threshold_per_sample = 2.3
    gradient_clipping_with_truncation = True
    trainer = Trainer(z, ce, errs, [momentum_sgd(
        z.parameters, lr, m_schedule, clipping_threshold_per_sample, gradient_clipping_with_truncation)])
    learner = momentum_sgd(z.parameters, lr, m_schedule, clipping_threshold_per_sample, gradient_clipping_with_truncation)
    trainer = Trainer(z, ce, errs, learner)

    # setup data
    rel_path = r"../../../../Examples/SequenceToSequence/CMUDict/Data/cmudict-0.7b.train-dev-20-21.ctf"

View file

@@ -62,7 +62,7 @@ def train_sequence_classifier(debug_output=False):
    # Instantiate the trainer object to drive the model training
    trainer = Trainer(classifier_output, ce, pe,
                      [sgd(classifier_output.parameters, lr=0.0005)])
                      sgd(classifier_output.parameters, lr=0.0005))

    # Get minibatches of sequences to train with and perform model training
    minibatch_size = 200