CNTK v2 library: Fix broken python examples and a bug in python Parameter construction wrapper

Amit Agarwal 2016-09-30 03:46:13 -07:00
Parent 9bad39ecc3
Commit d0d0de6eee
9 changed files with 58 additions and 62 deletions
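The bulk of the Python-side changes follows one pattern: parameters in the example networks are now created with an explicit initializer= or value= argument instead of being left uninitialized. A minimal sketch of that pattern, assuming the cntk Python package built from this commit (the linear_layer below condenses the version in the examples further down, taking input_dim explicitly for brevity):

from cntk.ops import parameter, times
from cntk.initializer import glorot_uniform_initializer

def linear_layer(input_var, input_dim, output_dim):
    # Weights now carry an explicit Glorot-uniform initializer; the broken
    # examples previously created them with no initializer at all.
    times_param = parameter(shape=(input_dim, output_dim),
                            initializer=glorot_uniform_initializer())
    # Biases now carry an explicit starting value.
    bias_param = parameter(shape=(output_dim), value=0)
    return bias_param + times(input_var, times_param)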

View file

@@ -750,7 +750,7 @@ namespace CNTK
        ///
        /// Destruct 'this' Value object.
        ///
-       CNTK_API virtual ~Value();
+       virtual ~Value();

        ///
        /// Returns the descriptor of the device that 'this' Value resides on
@@ -796,28 +796,28 @@ namespace CNTK
        ///
        /// Returns the NDArrayView object corresponding to the data contents of 'this value object.
        ///
-       CNTK_API virtual NDArrayViewPtr Data() const;
+       virtual NDArrayViewPtr Data() const;

        ///
        /// Returns the NDMask object corresponding to the mask associated with 'this value object.
        ///
-       CNTK_API virtual NDMaskPtr Mask() const;
+       virtual NDMaskPtr Mask() const;

        ///
        /// Creates a new Value with newly allocated storage on the same device as 'this' Value and copies 'this' Value's contents into the newly allocated Value.
        ///
-       CNTK_API virtual ValuePtr DeepClone(bool readOnly = false) const;
+       virtual ValuePtr DeepClone(bool readOnly = false) const;

        ///
        /// Creates a new Value which is an alias of 'this' Value.
        ///
-       CNTK_API virtual ValuePtr Alias(bool readOnly = false) const;
+       virtual ValuePtr Alias(bool readOnly = false) const;

        ///
        /// Copies the contents of the 'source' Value to 'this' Value.
        /// The shapes of the 'source' Value's data and mask must be identical to 'this' Value's data and mask.
        ///
-       CNTK_API virtual void CopyFrom(const Value& source);
+       virtual void CopyFrom(const Value& source);

    private:
        // Disallow copy and move construction and assignment
@@ -2025,10 +2025,10 @@ namespace CNTK
        /// and the user is responsible for ensuring that the contents of the inputs and outputs are unchanged until after any uses of the BackPropState instance
        /// for backpropagating gradients through this function.
        ///
-       CNTK_API virtual BackPropStatePtr Forward(const std::unordered_map<Variable, ValuePtr>& arguments,
+       virtual BackPropStatePtr Forward(const std::unordered_map<Variable, ValuePtr>& arguments,
                                                  std::unordered_map<Variable, ValuePtr>& outputs,
                                                  const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice(),
                                                  const std::unordered_set<Variable>& outputsToRetainBackwardStateFor = {}) = 0;

        ///
        /// Backpropagates supplied 'rootGradientValues' for one or more of the output variables of the Function, to produce gradient Values
@@ -2039,9 +2039,9 @@ namespace CNTK
        /// The 'state' parameter is an instance of an BackPropState instance obtained from a previous call to the Forward method on 'this; Function for the
        /// computation that this gradient backpropagation corresponds to.
        ///
-       CNTK_API virtual void Backward(const BackPropStatePtr& state,
+       virtual void Backward(const BackPropStatePtr& state,
                                       const std::unordered_map<Variable, ValuePtr>& rootGradientValues,
                                       std::unordered_map<Variable, ValuePtr>& backPropagatedGradientValuesForInputs) = 0;

    public:
@@ -2621,7 +2621,7 @@ namespace CNTK
        // Method to update the parameters associated with this learner. By returning false, this method indicates that
        // learning has stopped for all of the parameters associated with this learner
        //
-       CNTK_API virtual bool Update(const std::unordered_map<Parameter, NDArrayViewPtr>& gradientValues, size_t trainingSampleCount) = 0;
+       virtual bool Update(const std::unordered_map<Parameter, NDArrayViewPtr>& gradientValues, size_t trainingSampleCount) = 0;

        ///
        /// Returns the set of parameters associated with this learner.
@@ -2633,7 +2633,7 @@ namespace CNTK
        ///
        // TODO: move the following two methods into ISerializable interface, make
        // Learner (and all other entities that need checkpointing capability) implement it.
-       CNTK_API virtual Dictionary GetCheckpointState() const
+       virtual Dictionary GetCheckpointState() const
        {
            Dictionary baseCheckpointState;
            baseCheckpointState[LearningRateAttributeName] = m_learningRate;
@@ -2644,7 +2644,7 @@ namespace CNTK
        ///
        /// Optionally overridable method to restore the learner's state from a previous checkpoint.
        ///
-       CNTK_API virtual void RestoreFromCheckpoint(const Dictionary& checkpoint)
+       virtual void RestoreFromCheckpoint(const Dictionary& checkpoint)
        {
            if (checkpoint.Contains(LearningRateAttributeName))
                m_learningRate = checkpoint[LearningRateAttributeName].Value<double>();
@@ -2655,8 +2655,8 @@ namespace CNTK
        ///
        virtual ~Learner() {}

-       CNTK_API virtual void ResetLearningRate(double learningRate) { m_learningRate = learningRate; }
-       CNTK_API virtual double LearningRate() const { return m_learningRate; }
+       virtual void ResetLearningRate(double learningRate) { m_learningRate = learningRate; }
+       virtual double LearningRate() const { return m_learningRate; }

    protected:
        Learner(const std::vector<Parameter>& parameters, double learningRate)

View file

@@ -112,7 +112,7 @@ FunctionPtr ResNetClassifier(Variable input, size_t numOutputClasses, const Devi
    auto pool = Pooling(rn3_3, PoolingType::Average, { poolW, poolH, 1 }, { poolhStride, poolvStride, 1 });

    // Output DNN layer
-   auto outTimesParams = Parameter(NDArrayView::RandomNormal<float>({ numOutputClasses, 1, 1, cMap3 }, 0.0, fc1WScale, 1, device));
+   auto outTimesParams = Parameter({ numOutputClasses, 1, 1, cMap3 }, DataType::Float, GlorotUniformInitializer(1, 0, fc1WScale), device);
    auto outBiasParams = Parameter({ numOutputClasses }, (float)fc1BValue, device);

    return Plus(Times(outTimesParams, pool), outBiasParams, outputName);

View file

@@ -123,10 +123,10 @@ inline CNTK::FunctionPtr FullyConnectedLinearLayer(CNTK::Variable input, size_t
    assert(input.Shape().Rank() == 1);
    size_t inputDim = input.Shape()[0];

-   auto timesParam = CNTK::Parameter(CNTK::NDArrayView::RandomUniform<float>({ outputDim, inputDim }, -0.05, 0.05, 1, device));
+   auto timesParam = CNTK::Parameter({ outputDim, inputDim }, CNTK::DataType::Float, CNTK::GlorotUniformInitializer(), device);
    auto timesFunction = CNTK::Times(timesParam, input);

-   auto plusParam = CNTK::Parameter(CNTK::NDArrayView::RandomUniform<float>({ outputDim }, -0.05, 0.05, 1, device));
+   auto plusParam = CNTK::Parameter({ outputDim }, 0.0f, device);
    return CNTK::Plus(plusParam, timesFunction, outputName);
}
@@ -159,11 +159,11 @@ std::pair<CNTK::FunctionPtr, CNTK::FunctionPtr> LSTMPCellWithSelfStabilization(C
    unsigned long seed = 1;
    auto createProjectionParam = [device, &seed](size_t outputDim, size_t inputDim) {
-       return CNTK::Parameter({ outputDim, inputDim }, CNTK::AsDataType<ElementType>(), CNTK::UniformInitializer(1, seed++), device);
+       return CNTK::Parameter({ outputDim, inputDim }, CNTK::AsDataType<ElementType>(), CNTK::GlorotUniformInitializer(1, 0, 1, seed++), device);
    };

    auto createDiagWeightParam = [device, &seed](size_t dim) {
-       return CNTK::Parameter({ dim }, CNTK::AsDataType<ElementType>(), CNTK::UniformInitializer(1, seed++), device);
+       return CNTK::Parameter({ dim }, CNTK::AsDataType<ElementType>(), CNTK::GlorotUniformInitializer(1, 0, 1, seed++), device);
    };

    auto stabilizedPrevOutput = Stabilize<ElementType>(prevOutput, device);

View file

@@ -38,8 +38,8 @@ void TrainSequenceToSequenceTranslator(const DeviceDescriptor& device, bool useS
    bool forceEmbedding = useSparseInputs;

    /* Embeddings */
-   auto inputEmbeddingWeights = Parameter(NDArrayView::RandomUniform<float>({ inputEmbeddingDim, inputVocabDim }, -0.05, 0.05, 1, device));
-   auto labelEmbeddingWeights = Parameter(NDArrayView::RandomUniform<float>({ labelEmbeddingDim, labelVocabDim }, -0.05, 0.05, 1, device));
+   auto inputEmbeddingWeights = Parameter({ inputEmbeddingDim, inputVocabDim }, DataType::Float, GlorotUniformInitializer(), device);
+   auto labelEmbeddingWeights = Parameter({ labelEmbeddingDim, labelVocabDim }, DataType::Float, GlorotUniformInitializer(), device);

    auto inputEmbedding = (!forceEmbedding && (inputVocabDim <= inputEmbeddingDim)) ? inputSequence : Times(inputEmbeddingWeights, inputSequence);
    auto labelEmbedding = (!forceEmbedding && (labelVocabDim <= labelEmbeddingDim)) ? labelSequence : Times(labelEmbeddingWeights, labelSequence);
@@ -111,7 +111,7 @@ void TrainSequenceToSequenceTranslator(const DeviceDescriptor& device, bool useS
    auto decoderDim = hiddenDim;

    /* Softmax output layer */
-   auto outputLayerProjWeights = Parameter(NDArrayView::RandomUniform<float>({ labelVocabDim, decoderDim }, -0.05, 0.05, 1, device));
+   auto outputLayerProjWeights = Parameter({ labelVocabDim, decoderDim }, DataType::Float, GlorotUniformInitializer(), device);
    auto biasWeights = Parameter({ labelVocabDim }, 0.0f, device);

    auto z = Plus(Times(outputLayerProjWeights, Stabilize<float>(decoderOutput, device)), biasWeights, L"classifierOutput");

View file

@@ -11,7 +11,7 @@ FunctionPtr Embedding(const Variable& input, size_t embeddingDim, const DeviceDe
    assert(input.Shape().Rank() == 1);
    size_t inputDim = input.Shape()[0];

-   auto embeddingParameters = Parameter(CNTK::NDArrayView::RandomUniform<float>({ embeddingDim, inputDim }, -0.05, 0.05, 1, device));
+   auto embeddingParameters = Parameter({ embeddingDim, inputDim }, DataType::Float, GlorotUniformInitializer(), device);
    return Times(embeddingParameters, input);
}

View file

@@ -48,6 +48,7 @@ class Parameter(TensorOpsMixin,Parameter):
            data_type = str(value.dtype)

        if initializer is not None:
+           shape = utils.sanitize_shape(shape)
            data_type = utils.sanitize_dtype_cntk(data_type)
            super(Parameter, self).__init__(shape, data_type, initializer,
                    device, name)
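The added shape = utils.sanitize_shape(shape) line is the Parameter-construction fix named in the commit title: on the initializer code path the requested shape is now normalized before being passed to the superclass __init__, the same way the dtype already is. A minimal usage sketch, again assuming the cntk package built from this commit (the shapes are only illustrative):

from cntk.ops import parameter
from cntk.initializer import glorot_uniform_initializer

# Initializer-based construction: the plain Python tuple passed as shape is now
# run through utils.sanitize_shape() inside the wrapper before reaching the
# underlying Parameter constructor.
w = parameter(shape=(64, 32), initializer=glorot_uniform_initializer())

# Value-based construction is untouched by this hunk.
b = parameter(shape=(64,), value=0)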

View file

@@ -10,6 +10,7 @@ import os
 from cntk import Trainer, sgd_learner, DeviceDescriptor
 from cntk.ops import input_variable, constant, parameter, cross_entropy_with_softmax, combine, classification_error, times, pooling, AVG_POOLING
 from cntk.io import ReaderConfig, ImageDeserializer
+from cntk.initializer import glorot_uniform_initializer

 abs_path = os.path.dirname(os.path.abspath(__file__))
 sys.path.append(os.path.join(abs_path, "..", ".."))
@@ -30,23 +31,16 @@ def create_mb_source(features_stream_name, labels_stream_name, image_height,
        raise RuntimeError("File '%s' or '%s' do not exist. Please run CifarDownload%s.py and CifarConverter%s.py from CIFAR-10 to fetch them"%(map_file, mean_file, cifar_py3, cifar_py3))

    image = ImageDeserializer(map_file)
-   image.map_features(feature_name,
+   image.map_features(features_stream_name,
            [ImageDeserializer.crop(crop_type='Random', ratio=0.8,
                jitter_type='uniRatio'),
            ImageDeserializer.scale(width=image_width, height=image_height,
                channels=num_channels, interpolations='linear'),
            ImageDeserializer.mean(mean_file)])
-   image.map_labels(label_name, num_classes)
+   image.map_labels(labels_stream_name, num_classes)

    rc = ReaderConfig(image, epoch_size=sys.maxsize)
-   input_streams_config = {features_stream_name: features_stream_config, labels_stream_name: labels_stream_config}
-   deserializer_config = {"type" : "ImageDeserializer", "file" : map_file, "input" : input_streams_config}
-   minibatch_config = {"epochSize" : sys.maxsize, "deserializers" : [deserializer_config]}
-   print(minibatch_config)
-   return minibatch_source(minibatch_config)
+   return rc.minibatch_source()

 def get_projection_map(out_dim, in_dim):
    if in_dim > out_dim:
@@ -99,13 +93,13 @@ def resnet_classifer(input, num_classes):
    poolv_stride = 1
    pool = pooling(rn3_3, AVG_POOLING, (1, poolh, poolw), (1, poolv_stride, poolh_stride))

-   out_times_params = parameter(shape=(c_map3, 1, 1, num_classes))
-   out_bias_params = parameter(shape=(num_classes))
+   out_times_params = parameter(shape=(c_map3, 1, 1, num_classes), initializer=glorot_uniform_initializer())
+   out_bias_params = parameter(shape=(num_classes), value=0)
    t = times(pool, out_times_params)
    return t + out_bias_params

 # Trains a residual network model on the Cifar image dataset
-def cifar_resnet():
+def cifar_resnet(base_path):
    image_height = 32
    image_width = 32
    num_channels = 3
@@ -113,7 +107,7 @@ def cifar_resnet():
    feats_stream_name = 'features'
    labels_stream_name = 'labels'
    minibatch_source = create_mb_source(feats_stream_name, labels_stream_name,
-                       image_height, image_width, num_channels, num_classes)
+                       image_height, image_width, num_channels, num_classes, base_path)
    features_si = minibatch_source.stream_info(feats_stream_name)
    labels_si = minibatch_source.stream_info(labels_stream_name)

View file

@@ -47,7 +47,7 @@ def simple_mnist():
    labels_si = mb_source.stream_info(labels_stream_name)

    # Instantiate the trainer object to drive the model training
-   trainer = Trainer(netout, ce, pe, [sgd_learner(netout.owner.parameters(),
+   trainer = Trainer(netout, ce, pe, [sgd_learner(netout.parameters(),
                      lr=0.003125)])

    # Get minibatches of images to train with and perform model training

View file

@@ -9,6 +9,7 @@ import sys
 import os
 from cntk.ops import *
 from cntk.utils import sanitize_dtype_cntk, get_train_eval_criterion, get_train_loss
+from cntk.initializer import glorot_uniform_initializer

 def linear_layer(input_var, output_dim):
    try:
@@ -18,8 +19,8 @@ def linear_layer(input_var, output_dim):
        shape = input_var.shape()

    input_dim = shape[0]
-   times_param = parameter(shape=(input_dim, output_dim))
-   bias_param = parameter(shape=(output_dim))
+   times_param = parameter(shape=(input_dim, output_dim), initializer=glorot_uniform_initializer())
+   bias_param = parameter(shape=(output_dim), value=0)

    t = times(input_var, times_param)
    return bias_param + t
@@ -44,12 +45,12 @@ def conv_bn_layer(input, out_feature_map_count, kernel_width, kernel_height, h_s
        shape = input_var.shape()

    num_in_channels = shape[0]
    #TODO: use RandomNormal to initialize, needs to be exposed in the python api
-   conv_params = parameter(shape=(num_in_channels, kernel_height, kernel_width, out_feature_map_count))
+   conv_params = parameter(shape=(num_in_channels, kernel_height, kernel_width, out_feature_map_count), initializer=glorot_uniform_initializer(output_rank=-1, filter_rank=2))
    conv_func = convolution(conv_params, input, (num_in_channels, v_stride, h_stride))

    #TODO: initialize using b_value and sc_value, needs to be exposed in the python api
-   bias_params = parameter(shape=(out_feature_map_count))
-   scale_params = parameter(shape=(out_feature_map_count))
+   bias_params = parameter(shape=(out_feature_map_count), value=b_value)
+   scale_params = parameter(shape=(out_feature_map_count), value=sc_value)
    running_mean = constant((out_feature_map_count), 0.0)
    running_invstd = constant((out_feature_map_count), 0.0)
    return batch_normalization(conv_func, scale_params, bias_params, running_mean, running_invstd, True, bn_time_const, 0.0, 0.000000001)
@@ -74,8 +75,8 @@ def proj_layer(w_proj, input, h_stride, v_stride, b_value, sc_value, bn_time_con
    conv_func = convolution(w_proj, input, (num_in_channels, v_stride, h_stride))
    out_feature_map_count = w_proj.shape()[-1];

    #TODO: initialize using b_value and sc_value, needs to be exposed in the python api
-   bias_params = parameter(shape=(out_feature_map_count))
-   scale_params = parameter(shape=(out_feature_map_count))
+   bias_params = parameter(shape=(out_feature_map_count), value=b_value)
+   scale_params = parameter(shape=(out_feature_map_count), value=sc_value)
    running_mean = constant((out_feature_map_count), 0.0)
    running_invstd = constant((out_feature_map_count), 0.0)
    return batch_normalization(conv_func, scale_params, bias_params, running_mean, running_invstd, True, bn_time_const)
@@ -91,7 +92,7 @@ def resnet_node2_inc(input, out_feature_map_count, kernel_width, kernel_height,
 def embedding(input, embedding_dim):
    input_dim = input.shape()[0];

-   embedding_parameters = parameter(shape=(input_dim, embedding_dim))
+   embedding_parameters = parameter(shape=(input_dim, embedding_dim), initializer=glorot_uniform_initializer())
    return times(input, embedding_parameters)

 def select_last(operand):
@@ -110,28 +111,28 @@ def LSTMP_cell_with_self_stabilization(input, prev_output, prev_cell_state):
    output_dim = prev_output.shape()[0];
    cell_dim = prev_cell_state.shape()[0];

-   Wxo = parameter(shape=(input_dim, cell_dim))
-   Wxi = parameter(shape=(input_dim, cell_dim))
-   Wxf = parameter(shape=(input_dim, cell_dim))
-   Wxc = parameter(shape=(input_dim, cell_dim))
+   Wxo = parameter(shape=(input_dim, cell_dim), initializer=glorot_uniform_initializer())
+   Wxi = parameter(shape=(input_dim, cell_dim), initializer=glorot_uniform_initializer())
+   Wxf = parameter(shape=(input_dim, cell_dim), initializer=glorot_uniform_initializer())
+   Wxc = parameter(shape=(input_dim, cell_dim), initializer=glorot_uniform_initializer())

    Bo = parameter(shape=(cell_dim), value=0)
    Bc = parameter(shape=(cell_dim), value=0)
    Bi = parameter(shape=(cell_dim), value=0)
    Bf = parameter(shape=(cell_dim), value=0)

-   Whi = parameter(shape=(output_dim, cell_dim))
-   Wci = parameter(shape=(cell_dim))
-   Whf = parameter(shape=(output_dim, cell_dim))
-   Wcf = parameter(shape=(cell_dim))
-   Who = parameter(shape=(output_dim, cell_dim))
-   Wco = parameter(shape=(cell_dim))
-   Whc = parameter(shape=(output_dim, cell_dim))
-   Wmr = parameter(shape=(cell_dim, output_dim))
+   Whi = parameter(shape=(output_dim, cell_dim), initializer=glorot_uniform_initializer())
+   Wci = parameter(shape=(cell_dim), initializer=glorot_uniform_initializer())
+   Whf = parameter(shape=(output_dim, cell_dim), initializer=glorot_uniform_initializer())
+   Wcf = parameter(shape=(cell_dim), initializer=glorot_uniform_initializer())
+   Who = parameter(shape=(output_dim, cell_dim), initializer=glorot_uniform_initializer())
+   Wco = parameter(shape=(cell_dim), initializer=glorot_uniform_initializer())
+   Whc = parameter(shape=(output_dim, cell_dim), initializer=glorot_uniform_initializer())
+   Wmr = parameter(shape=(cell_dim, output_dim), initializer=glorot_uniform_initializer())

    # Stabilization by routing input through an extra scalar parameter
    sWxo = parameter(value=0)