Merge branch 'master' into ebarsoum/ImageHandsOn
For checkin...
This commit is contained in:
Commit 496c67299d
@@ -22,7 +22,7 @@ rootDir = "."
 dataDir = "$rootDir$/data/"
 outputDir = "$rootDir$/Output"
 
-modelPath = "$outputDir$/Fast-RCNN"
+modelPath = "$outputDir$/Fast-RCNN.model"
 #stderr = "$outputDir$/Fast-RCNN.log"
 
 ImageH = 1000

@@ -20,7 +20,4 @@ from .io import *
 from .persist import load_model, save_model
 from .device import *
 
-# TODO wrap
-from .cntk_py import momentums_per_sample
-
 DATATYPE = np.float32

@@ -1252,7 +1252,7 @@ StreamInformation.__eq__ = lambda a,b: a.m_name==b.m_name and a.m_id==b.m_id and
 # in case of multiple outputs return the function, not the variable
 def get_output_and_keep_reference(self):
     variable = self._output()
-    variable.owner = self
+    variable.__owner = self
     return variable
 Function.output = lambda self:get_output_and_keep_reference(self)

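Why the output keeps a reference to its producer, in miniature: the Python proxy for the owning Function must stay reachable while its output Variable is in use, or the wrapper could be garbage-collected underneath it. A minimal standalone sketch (class and helper are hypothetical stand-ins; note there is no name mangling because the assignment happens in a module-level function, not inside a class body):

    import gc
    import weakref

    class Producer:                          # stands in for the wrapped Function
        pass

    def get_output_and_keep_reference(producer):
        variable = type('Variable', (), {})()
        variable.__owner = producer          # plain attribute, keeps producer alive
        return variable

    p = Producer()
    v = get_output_and_keep_reference(p)
    r = weakref.ref(p)
    del p
    gc.collect()
    assert r() is not None                   # still reachable through v.__owner
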
@@ -0,0 +1,60 @@
+# Copyright (c) Microsoft. All rights reserved.
+# Licensed under the MIT license. See LICENSE.md file in the project root
+# for full license information.
+# ==============================================================================
+
+def dfs_walk(node, visitor):
+    '''
+    Generic function that walks through the graph starting at ``node`` and
+    uses function ``visitor`` on each node to check whether it should be
+    returned.
+
+    Args:
+        node (graph node): the node to start the journey from
+        visitor (Python function or lambda): function that takes a node as
+         argument and returns ``True`` if that node should be returned.
+
+    Returns:
+        List of nodes for which ``visitor`` was ``True``
+    '''
+    stack = [node]
+    accum = []
+    visited = set()
+
+    while stack:
+        node = stack.pop()
+        if node in visited:
+            continue
+
+        try:
+            # Function node
+            stack.extend(node.root_function.inputs)
+        except AttributeError:
+            # OutputVariable node
+            try:
+                if node.is_output:
+                    stack.append(node.owner)
+            except AttributeError:
+                pass
+
+        if visitor(node):
+            accum.append(node)
+
+        visited.add(node)
+
+    return accum
+
+
+def find_nodes_by_name(node, node_name):
+    '''
+    Finds nodes in the graph starting from `node` and doing a depth-first
+    search.
+
+    Args:
+        node (graph node): the node to start the journey from
+        node_name (`str`): name for which we are searching nodes
+
+    Returns:
+        List of nodes having the specified name
+    '''
+    return dfs_walk(node, lambda x: x.name == node_name)

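A short usage sketch for the new helpers, mirroring the test added later in this commit (imports assume the package layout of that test, i.e. a cntk.graph module and ops from cntk.ops):

    from cntk.graph import find_nodes_by_name
    from cntk.ops import input_variable, plus

    i1 = input_variable(shape=(2, 3), name='i1')
    i2 = input_variable(shape=(2, 3), name='i2')
    root = plus(i1, i2, name='op1')

    # depth-first walk from the root, collecting every node named 'i1'
    nodes = find_nodes_by_name(root, 'i1')
    assert len(nodes) == 1 and nodes[0].name == 'i1'
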
@@ -90,11 +90,11 @@ class Function(cntk_py.Function):
         return self(other)
 
     def __getattr__(self, name):
-        if name in self.__dict__:
+        try:
             return self.__dict__[name]
-
-        if len(self.outputs) == 1:
-            return getattr(self.output, name)
+        except KeyError:
+            if len(self.outputs) == 1:
+                return getattr(self.output, name)
 
         raise AttributeError("'%s' object has no attribute '%s'" %
                              (type(self), name))

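The effect of the rewritten lookup, sketched: names missing from the instance dict now fall through to the function's single output instead of failing, so a one-output Function exposes its output Variable's properties directly (ops as used in this commit's tests):

    from cntk.ops import input_variable, plus

    x = input_variable(shape=(3,), name='x')
    f = plus(x, x)                 # a Function with exactly one output

    # the first lookup goes through Function.__getattr__, which catches the
    # KeyError and forwards to f.output; the second reaches it explicitly
    assert f.shape == f.output.shape
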
@@ -413,3 +413,12 @@ class Function(cntk_py.Function):
         The primitive function at the root of the graph of functions underlying this function.
         '''
         return super(Function, self).root_function()
+
+    @property
+    @typemap
+    def uid(self):
+        '''
+        The internally generated unique name of the function.
+        '''
+        return super(Function, self).uid()
+

@@ -3,41 +3,18 @@ from cntk import cntk_py, utils
 from ..tensor import TensorOpsMixin
 from ..utils import typemap, sanitize_precision, sanitize_value, sanitize_dtype_cntk
 
 
-class Variable(TensorOpsMixin, cntk_py.Variable):
+class VariableMixin:
     '''
-    Denotes a symbolic entity corresponding to the inputs and outputs of a Function.
-
-    Args:
-        shape (`tuple`): the shape of this variable.
-        dtype (`np.float32` or `np.float64`): data type of the values that will be bound to this variable.
-         Default is np.float32
-        needs_gradient (`bool`): if set to True any expression that contains this variable
-         will also be differentiated with respect to this variable.
-        is_sparse(`bool`): whether this is a sparse or dense input (or output)
-        dynamic_axes(`list` of :class:`cntk.axis.Axis`): the dynamic axes of this variable. These
-         express dimensions that can vary across examples or minibatches.
-        name(`str`): an optional name for this parameter.
+    Standard properties for :class:`Variable` and its derived classes
+    :class:`Parameter` and :class:`Constant`.
     '''
-    def __init__(self, shape=None, dtype=None, needs_gradient=False, is_sparse=False,
-                 dynamic_axes=[cntk_py.Axis.default_dynamic_axis(), cntk_py.Axis.default_batch_axis()], name=''):
-        shape = utils.sanitize_shape(shape)
-
-        if dtype is None:
-            dtype = np.float32
-
-        cntk_dtype = utils.sanitize_dtype_cntk(dtype)
-
-        super(Variable, self).__init__(shape, is_sparse, cntk_dtype,
-                                       needs_gradient, name, dynamic_axes)
-
     @property
     @typemap
     def dynamic_axes(self):
         '''
         The dynamic axes of this variable.
         '''
-        return super(Variable, self).dynamic_axes()
+        return super().dynamic_axes()
 
     @property
     def dtype(self):

@@ -51,56 +28,56 @@ class Variable(TensorOpsMixin, cntk_py.Variable):
         '''
         Whether this variable is a constant.
         '''
-        return super(Variable, self).is_constant()
+        return super().is_constant()
 
     @property
     def is_input(self):
         '''
         Whether this variable is an input.
         '''
-        return super(Variable, self).is_input()
+        return super().is_input()
 
     @property
     def is_output(self):
         '''
         Whether this variable is an output.
         '''
-        return super(Variable, self).is_output()
+        return super().is_output()
 
     @property
     def is_parameter(self):
         '''
         Whether this variable is a parameter.
         '''
-        return super(Variable, self).is_parameter()
+        return super().is_parameter()
 
     @property
     def is_placeholder(self):
         '''
         Whether this variable is a placeholder.
         '''
-        return super(Variable, self).is_placeholder()
+        return super().is_placeholder()
 
     @property
     def is_sparse(self):
         '''
         Whether this variable is sparse.
         '''
-        return super(Variable, self).is_sparse()
+        return super().is_sparse()
 
     @property
     def name(self):
         '''
         The name of this variable.
         '''
-        return super(Variable, self).name()
+        return super().name()
 
     @property
     def needs_gradient(self):
         '''
         Whether this variable needs gradients.
         '''
-        return super(Variable, self).needs_gradient()
+        return super().needs_gradient()
 
     @property
     @typemap

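Every edit in this block is the same mechanical change: Python 3's zero-argument super() replaces the explicit two-argument form. A self-contained sketch of the equivalence (toy classes, not from this repo):

    class Base:
        def kind(self):
            return 'base'

    class Derived(Base):
        def kind_explicit(self):
            return super(Derived, self).kind()   # old spelling, Python 2 compatible

        def kind_implicit(self):
            return super().kind()                # new spelling, Python 3 only

    assert Derived().kind_explicit() == Derived().kind_implicit() == 'base'
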
@@ -110,23 +87,51 @@ class Variable(TensorOpsMixin, cntk_py.Variable):
         '''
         if self.is_output == False:
             raise RuntimeError('called owner() on a variable that is not an output variable')
-        return super(Variable, self).owner()
+        return super().owner()
 
     @property
     def shape(self):
         '''
         The shape of this variable as a tuple.
         '''
-        return super(Variable, self).shape().dimensions()
+        return super().shape().dimensions()
 
     @property
     def uid(self):
         '''
         The internally generated unique name of the variable.
         '''
-        return super(Variable, self).uid()
+        return super().uid()
 
-class Parameter(TensorOpsMixin, cntk_py.Parameter):
+
+class Variable(VariableMixin, TensorOpsMixin, cntk_py.Variable):
+    '''
+    Denotes a symbolic entity corresponding to the inputs and outputs of a Function.
+
+    Args:
+        shape (`tuple`): the shape of this variable.
+        data_type (`np.float32 or np.float64`): data type of the values that will be bound to this variable.
+         Default is np.float32
+        needs_gradient (`bool`): if set to True any expression that contains this variable
+         will also be differentiated with respect to this variable.
+        is_sparse(`bool`): whether this is a sparse or dense input (or output)
+        dynamic_axes(`list` of :class:`cntk.axis.Axis`): the dynamic axes of this variable. These
+         express dimensions that can vary across examples or minibatches.
+        name(`str`): an optional name for this parameter.
+    '''
+    def __init__(self, shape=None, data_type=None, needs_gradient=False, is_sparse=False,
+                 dynamic_axes=[cntk_py.Axis.default_dynamic_axis(), cntk_py.Axis.default_batch_axis()], name=''):
+        shape = utils.sanitize_shape(shape)
+
+        if data_type is None:
+            data_type = np.float32
+        dtype = utils.sanitize_dtype_cntk(data_type)
+
+        super().__init__(shape, is_sparse, dtype, needs_gradient, name,
+                         dynamic_axes)
+
+
+class Parameter(VariableMixin, TensorOpsMixin, cntk_py.Parameter):
     '''
     A trainable parameter. It can be a scalar, vector, matrix, or tensor
     of floating point numbers that can be modified by a training

@@ -146,7 +151,7 @@ class Parameter(TensorOpsMixin, cntk_py.Parameter):
     Parameters are Variables and therefore they inherit all their methods.
     '''
    def __init__(self, shape=None, init=None, dtype=None,
-            device=None, name=''):
+                 device=None, name=''):
 
         if dtype is None:
             if isinstance(init, np.ndarray):

@@ -159,11 +164,11 @@ class Parameter(TensorOpsMixin, cntk_py.Parameter):
 
         if isinstance(init, (np.ndarray, list, float, int)):
             ndav = sanitize_value(shape, init, dtype, device)
-            super(Parameter, self).__init__(ndav, name)
+            super().__init__(ndav, name)
         else:
             shape = utils.sanitize_shape(shape)
-            cntk_dtype = utils.sanitize_dtype_cntk(dtype)
-            super(Parameter, self).__init__(shape, cntk_dtype, init,
+            cntk_dtype = utils.sanitize_dtype_cntk(dtype)
+            super().__init__(shape, cntk_dtype, init,
                     device, name)
 
     @property

@@ -171,23 +176,9 @@ class Parameter(TensorOpsMixin, cntk_py.Parameter):
         '''
         NumPy array of the value
         '''
-        return super(Parameter, self).value().to_numpy()
+        return super().value().to_numpy()
-
-    @property
-    def shape(self):
-        '''
-        The shape of this parameter as a tuple.
-        '''
-        return super(Parameter, self).shape().dimensions()
-
-    @property
-    def dtype(self):
-        '''
-        The NumPy type of this variable.
-        '''
-        return sanitize_precision(self.get_data_type())
 
-class Constant(TensorOpsMixin, cntk_py.Constant):
+class Constant(VariableMixin, TensorOpsMixin, cntk_py.Constant):
     '''
     A constant value. It can be a scalar, vector, matrix, or tensor
     of floating point numbers that cannot be modified.

@@ -210,49 +201,16 @@ class Constant(TensorOpsMixin, cntk_py.Constant):
             dtype = np.float32
 
         if np.isscalar(value):
-            super(Constant, self).__init__(utils.sanitize_shape(shape), sanitize_dtype_cntk(dtype), value)
+            super().__init__(utils.sanitize_shape(shape), sanitize_dtype_cntk(dtype), value)
         else:
             ndav = sanitize_value(shape, value, dtype, device)
-            super(Constant, self).__init__(ndav, name)
-            #ndav = sanitize_value(shape, value, dtype, device)
-            #super(Constant, self).__init__(ndav, name)
+            super().__init__(ndav, name)
-
-
-
-        ##ndav = sanitize_value(shape, value, data_type, device)
-        ##super(Constant, self).__init__(ndav, name)
-        #
-        ## from Parameter: [fseide]
-        #if isinstance(value, (np.ndarray, list, float, int)):
-        #    ndav = sanitize_value(shape, value, data_type, device)
-        #    super(Constant, self).__init__(ndav, name)
-        #else:
-        #    shape = utils.sanitize_shape(shape)
-        #    data_type = utils.sanitize_dtype_cntk(data_type)
-        #    super(Constant, self).__init__(shape, data_type, value,
-        #                                   device, name)
-
-
-
-        #TODO how to expose Scalar ?
     @property
     def value(self):
         '''
         NumPy array of the value
         '''
-        return super(Constant, self).value().to_numpy()
-
-    @property
-    def shape(self):
-        '''
-        The shape of this constant as tuple.
-        '''
-        return super(Constant, self).shape().dimensions()
-
-    @property
-    def dtype(self):
-        '''
-        The NumPy type of this variable.
-        '''
-        return sanitize_precision(self.get_data_type())
+        return super().value().to_numpy()

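The net effect of the refactoring above, sketched: properties such as name, shape and uid are defined once on VariableMixin and inherited by Variable, Parameter and Constant instead of being duplicated in each class. Illustrative only; the module path and the Constant constructor arguments are assumptions based on this diff:

    import numpy as np
    from cntk.ops.variables import Constant, Parameter, Variable

    w = Parameter(shape=(2, 3), init=0.5, name='w')
    c = Constant(value=np.zeros((2, 3), dtype=np.float32), name='zeros')
    v = Variable(shape=(2, 3), name='v')

    for x in (w, c, v):
        print(x.name, x.shape, x.uid)    # all three served by VariableMixin
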
@@ -3,6 +3,7 @@
 # for full license information.
 # ==============================================================================
 
+import numpy as np
 from cntk import cntk_py
 from .utils.swig_helper import typemap
 from cntk.device import use_default_device

@@ -20,21 +21,23 @@ def save_model(root_op, filename, use_legacy_format=True):
     root_op.save_model(filename, use_legacy_format)
 
 @typemap
-def load_model(data_type, filename, device=None):
+def load_model(filename, dtype=np.float32, device=None):
     '''
     Load the network in ``filename``, that has been saved using
     `:func:save_model`.
 
     Args:
-        data_type ('float' or 'double', or NumPy type): data type of the operation
         filename (`str`): filename to load the model from
-        device (:class:`cntk.device.DeviceDescriptor`, default to default device): instance of DeviceDescriptor
+        dtype ('float', 'double', or NumPy type, default ``np.float32``): data
+         type of the operation
+        device (:class:`cntk.DeviceDescriptor`, default is the default device):
+         instance of DeviceDescriptor
 
     Returns:
         root node
     '''
     from cntk.utils import sanitize_dtype_cntk
-    data_type = sanitize_dtype_cntk(data_type)
+    dtype = sanitize_dtype_cntk(dtype)
     if not device:
         device = use_default_device()
-    return cntk_py.Function.load_model(data_type, filename, device)
+    return cntk_py.Function.load_model(dtype, filename, device)

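Call-site impact of the new load_model signature, sketched (the model file name below is hypothetical): the data type moves from a required leading positional to an optional keyword defaulting to np.float32, which is what lets the tests further down drop their 'float' argument.

    import numpy as np
    from cntk.persist import load_model

    root = load_model('model.bin')                      # float32 by default
    root64 = load_model('model.bin', dtype=np.float64)  # explicit double
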
@@ -0,0 +1,59 @@
+# Copyright (c) Microsoft. All rights reserved.
+# Licensed under the MIT license. See LICENSE.md file in the project root
+# for full license information.
+# ==============================================================================
+
+import numpy as np
+from ..graph import *
+from ..ops import *
+from ..axis import Axis
+
+def _graph_dict():
+    # This function creates a graph that has no real meaning other than
+    # providing something to traverse.
+    d = {}
+
+    batch_axis = Axis.default_batch_axis()
+    input_seq_axis = Axis('ia')
+    input_dynamic_axes = [batch_axis, input_seq_axis]
+
+    d['i1'] = input_variable(shape=(2,3), dynamic_axes=input_dynamic_axes, name='i1')
+    d['i2'] = input_variable(shape=(2,3), dynamic_axes=input_dynamic_axes, name='i2')
+
+    d['p1'] = parameter(shape=(3,2), name='p1')
+
+
+    d['op1'] = plus(d['i1'], d['i2'], name='op1')
+    d['op2'] = times(d['op1'], d['p1'], name='op2')
+
+    #d['slice'] = slice(d['i2'], Axis.default_dynamic_axis(), 0, 3)
+    #label_sentence_start = sequence.first(raw_labels)
+
+    # no name
+    d['p2'] = parameter(shape=(2,2))
+
+    # duplicate names
+    d['op3a'] = plus(d['op2'], d['p2'], name='op3')
+    d['op3b'] = plus(d['op3a'], d['p2'], name='op3')
+
+    d['first'] = sequence.first(d['op3b'], name='past')
+
+    d['root'] = d['first']
+
+    return d
+
+
+def test_find_nodes():
+    d = _graph_dict()
+
+    for name in ['i1', 'i2', 'p1', 'op1', 'op2', 'past']:
+        n = find_nodes_by_name(d['root'], name)
+        assert len(n) == 1, name
+        assert n[0].name == name, name
+
+    n = find_nodes_by_name(d['root'], 'op3')
+    assert len(n) == 2, 'op3'
+    assert n[0].name == 'op3' and n[1].name == 'op3', 'op3'
+
+    none = find_nodes_by_name(d['root'], 'none')
+    assert none == []

@@ -8,7 +8,7 @@ import numpy as np
 import pytest
 
 from ..initializer import *
-from .. import parameter, input_variable, momentums_per_sample
+from .. import parameter
 
 
 def _check(init, name):

@@ -21,7 +21,7 @@ def test_load_save_constant(tmpdir):
     filename = str(tmpdir / 'c_plus_c.mod')
     save_model(root_node, filename)
 
-    loaded_node = load_model('float', filename)
+    loaded_node = load_model(filename)
     loaded_result = loaded_node.eval()
     assert np.allclose(loaded_result, expected)
 

@@ -37,7 +37,7 @@ def test_load_save_input(tmpdir):
     filename = str(tmpdir / 'i_plus_c_0.mod')
     save_model(root_node, filename)
 
-    loaded_node = load_model('float', filename)
+    loaded_node = load_model(filename)
 
     # Test specifying the input node names by order
     loaded_result = loaded_node.eval([input1])

@@ -57,7 +57,7 @@ def test_load_save_inputs(tmpdir):
     filename = str(tmpdir / 'i_plus_i_0.mod')
     save_model(root_node, filename)
 
-    loaded_node = load_model('float', filename)
+    loaded_node = load_model(filename)
 
     # Test specifying the input nodes by name
     loaded_result = loaded_node.eval({'i1': input1, 'i2': input2})

@@ -75,7 +75,7 @@ def test_load_save_unique_input(tmpdir):
     filename = str(tmpdir / 'i_plus_0.mod')
     save_model(root_node, filename)
 
-    loaded_node = load_model('float', filename)
+    loaded_node = load_model(filename)
 
     # Test specifying the only value for a unique input
     loaded_result = loaded_node.eval(input1)

@@ -33,6 +33,9 @@ class Trainer(cntk_py.Trainer):
         loss_function = sanitize_function(loss_function)
         eval_function = sanitize_function(eval_function)
 
+        if not isinstance(parameter_learners, list):
+            parameter_learners = [parameter_learners]
+
         super(Trainer, self).__init__(model, loss_function, eval_function,
                                       parameter_learners)
 

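What the three added lines buy at call sites, in isolation: a single learner is now wrapped into a one-element list before reaching the underlying trainer, which is why the examples below drop the brackets around their learners. A minimal sketch of the normalization (standalone helper, hypothetical name):

    def _as_learner_list(parameter_learners):
        # mirror of the added lines: accept a single learner or a list
        if not isinstance(parameter_learners, list):
            parameter_learners = [parameter_learners]
        return parameter_learners

    assert _as_learner_list('sgd') == ['sgd']
    assert _as_learner_list(['sgd']) == ['sgd']
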
@@ -131,7 +131,7 @@ def train_and_evaluate(reader_train, reader_test, max_epochs):
     lr_schedule = learning_rate_schedule(lr_per_sample, units=epoch_size)
     learner = momentum_sgd(z.parameters, lr_schedule, momentum_per_sample,
                            l2_regularization_weight = l2_reg_weight)
-    trainer = Trainer(z, ce, pe, [learner])
+    trainer = Trainer(z, ce, pe, learner)
 
     # define mapping from reader streams to network inputs
     input_map = {

@@ -154,7 +154,7 @@ def train_and_evaluate(reader_train, reader_test, max_epochs):
     lr_schedule = learning_rate_schedule(lr_per_sample, units=epoch_size)
     learner = momentum_sgd(z.parameters, lr_schedule, momentum_per_sample,
                            l2_regularization_weight = l2_reg_weight)
-    trainer = Trainer(z, ce, pe, [learner])
+    trainer = Trainer(z, ce, pe, learner)
 
     # define mapping from reader streams to network inputs
     input_map = {

@@ -61,8 +61,7 @@ def simple_mnist(debug_output=False):
     labels_si = mb_source[labels_stream_name]
 
     # Instantiate the trainer object to drive the model training
-    trainer = Trainer(netout, ce, pe, [sgd(netout.parameters,
-        lr=0.003125)])
+    trainer = Trainer(netout, ce, pe, sgd(netout.parameters, lr=0.003125))
 
     # Get minibatches of images to train with and perform model training
     minibatch_size = 64

@@ -53,7 +53,7 @@ def ffnet():
     pe = classification_error(netout, label)
 
     # Instantiate the trainer object to drive the model training
-    trainer = Trainer(netout, ce, pe, [sgd(netout.parameters, lr=0.005)])
+    trainer = Trainer(netout, ce, pe, sgd(netout.parameters, lr=0.02))
 
     # Get minibatches of training data and perform model training
     minibatch_size = 25

@@ -116,9 +116,8 @@ def sequence_to_sequence_translator(debug_output=False, run_test=False):
     m_schedule = momentum_schedule(momentum_time_constant)
     clipping_threshold_per_sample = 2.3
     gradient_clipping_with_truncation = True
 
-    trainer = Trainer(z, ce, errs, [momentum_sgd(
-        z.parameters, lr, m_schedule, clipping_threshold_per_sample, gradient_clipping_with_truncation)])
+    learner = momentum_sgd(z.parameters, lr, m_schedule, clipping_threshold_per_sample, gradient_clipping_with_truncation)
+    trainer = Trainer(z, ce, errs, learner)
 
     # setup data
     rel_path = r"../../../../Examples/SequenceToSequence/CMUDict/Data/cmudict-0.7b.train-dev-20-21.ctf"

@@ -62,7 +62,7 @@ def train_sequence_classifier(debug_output=False):
 
     # Instantiate the trainer object to drive the model training
     trainer = Trainer(classifier_output, ce, pe,
-                      [sgd(classifier_output.parameters, lr=0.0005)])
+                      sgd(classifier_output.parameters, lr=0.0005))
 
     # Get minibatches of sequences to train with and perform model training
     minibatch_size = 200