diff --git a/Examples/SequenceToSequence/CMUDict/Config/G2P.cntk b/Examples/SequenceToSequence/CMUDict/Config/G2P.cntk
index 6c2b32fdd..724581646 100644
--- a/Examples/SequenceToSequence/CMUDict/Config/G2P.cntk
+++ b/Examples/SequenceToSequence/CMUDict/Config/G2P.cntk
@@ -193,7 +193,7 @@ BrainScriptNetworkBuilder = (new ComputationNetwork [

     # Note: We reverse our input by running the recurrence from right to left.

-    encoderFunction = if useBidirectionalEncoder then BS.RNNs.RecurrentBirectionalLSTMPStack else BS.RNNs.RecurrentLSTMPStack
+    encoderFunction = if useBidirectionalEncoder then BS.RNNs.RecurrentBidirectionalLSTMPStack else BS.RNNs.RecurrentLSTMPStack
     encoder = encoderFunction (encoderDims, cellDims=encoderDims, S(inputEmbedded), inputDim=inputEmbeddingDim, previousHook=if useBidirectionalEncoder then BS.RNNs.PreviousHC else BS.RNNs.NextHC, enableSelfStabilization=useStabilizer)
diff --git a/Source/CNTK/BrainScript/CNTKCoreLib/CNTK.core.bs b/Source/CNTK/BrainScript/CNTKCoreLib/CNTK.core.bs
index 00d4091dd..f18b55af4 100644
--- a/Source/CNTK/BrainScript/CNTKCoreLib/CNTK.core.bs
+++ b/Source/CNTK/BrainScript/CNTKCoreLib/CNTK.core.bs
@@ -501,9 +501,10 @@ PerDimMeanVarDeNormalization(dataVectorSequence, meanVector, invStdDevVector, ta
 PerDimMeanVarNormalization (x, mean, invStdDev) = (x - mean) .* invStdDev
 Reciprocal(z, tag='') = new ComputationNode [ operation = 'Reciprocal' ; inputs = z /*plus the function args*/ ]
 //# the following is a temporary workaround until we have the C++ version
-OptimizedRNNStack(weights, input, hiddenDims, numLayers=1, bidirectional=false, rnnMode='LSTM', tag='') = new ComputationNode [ operation = 'OptimizedRNNStack' ; recurrentOp = rnnMode; inputs = ( input : weights ) /*plus the function args*/ ]
+# TODO: change hiddenDims to hiddenShape and pass as a TensorShape (currently, the node only supports rank-1 data)
+OptimizedRNNStack(weights, input, hiddenDims, numLayers=1, bidirectional=false, recurrentOp='lstm', axis=-1, tag='') = new ComputationNode [ operation = 'OptimizedRNNStack' ; inputs = ( weights : input ) /*plus the function args*/ ]
 # legacy:
-RNNStack(x, W, hiddenSize=10, numLayers=1, bidirectional=false, rnnMode='LSTM', tag='') = OptimizedRNNStack(W, X, hiddenSize, numLayers=1, bidirectional=false, rnnMode='LSTM', tag='')
+RNNStack(x, W, hiddenSize=10, numLayers=1, bidirectional=false, rnnMode='lstm', tag='') = OptimizedRNNStack(W, x, hiddenSize, numLayers=1, bidirectional=false, recurrentOp=rnnMode, tag='')
 Scale(scalarScalingFactor, matrix, tag='') = new ComputationNode [ operation = 'Scale' ; inputs = (scalarScalingFactor : matrix) /*plus the function args*/ ] # TODO: Scale = ElementTimes
 ScatterPacked(cond, indexSequence, sourceData, tag='') = new ComputationNode [ operation = 'ScatterPacked' ; inputs = (cond : indexSequence : sourceData) /*plus the function args*/ ]
@@ -1045,7 +1046,7 @@ RNNs =

     # a stack of recurrent LSTMs (bidirectional)
     # TODO: Should we define layerDims as the total (sum of both forward and backward direction)?
-    RecurrentBirectionalLSTMPStack (layerDims, cellDims=layerDims, input, inputDim=input.dim, previousHook=PreviousHC, nextHook=NextHC, enableSelfStabilization=false) = [
+    RecurrentBidirectionalLSTMPStack (layerDims, cellDims=layerDims, input, inputDim=input.dim, previousHook=PreviousHC, nextHook=NextHC, enableSelfStabilization=false) = [
         previousHook1 = previousHook ; nextHook1 = nextHook ; useStabilizer = enableSelfStabilization
         layers[i:0..Length (layerDims)-1] = [
@@ -1159,7 +1160,7 @@ RNNs =

     # a stack of recurrent GRUs (bidirectional)
     # TODO: Should we define layerDims as the total (sum of both forward and backward direction)?
-    RecurrentBirectionalGRUStack (layerDims, cellDims=layerDims, input, inputDim=input.dim, previousHook=PreviousHC, nextHook=NextHC, enableSelfStabilization=false) = [
+    RecurrentBidirectionalGRUStack (layerDims, cellDims=layerDims, input, inputDim=input.dim, previousHook=PreviousHC, nextHook=NextHC, enableSelfStabilization=false) = [
         previousHook1 = previousHook ; nextHook1 = nextHook ; useStabilizer = enableSelfStabilization
         layers[i:0..Length (layerDims)-1] = [
diff --git a/Source/ComputationNetworkLib/ComputationNetworkBuilder.cpp b/Source/ComputationNetworkLib/ComputationNetworkBuilder.cpp
index 6b2617ddf..b69271a47 100644
--- a/Source/ComputationNetworkLib/ComputationNetworkBuilder.cpp
+++ b/Source/ComputationNetworkLib/ComputationNetworkBuilder.cpp
@@ -82,7 +82,7 @@ static shared_ptr<ComputationNode<ElemType>> CreateStandardNode(const std::wstri
     else if (nodeType == OperationNameOf(NegateNode))                       return New<NegateNode<ElemType>>(forward<_Types>(_Args)...);
     else if (nodeType == OperationNameOf(NotEqualNode))                     return New<NotEqualNode<ElemType>>(forward<_Types>(_Args)...);
     else if (nodeType == OperationNameOf(NoiseContrastiveEstimationNode))   return New<NoiseContrastiveEstimationNode<ElemType>>(forward<_Types>(_Args)...);
-    else if (nodeType == OperationNameOf(OptimizedRNNStack))                return New<OptimizedRNNStack<ElemType>>(forward<_Types>(_Args)...);
+    else if (nodeType == OperationNameOf(OptimizedRNNStackNode))            return New<OptimizedRNNStackNode<ElemType>>(forward<_Types>(_Args)...);
     else if (nodeType == OperationNameOf(PackedIndexNode))                  return New<PackedIndexNode<ElemType>>(forward<_Types>(_Args)...);
     else if (nodeType == OperationNameOf(PastValueNode))                    return New<PastValueNode<ElemType>>(forward<_Types>(_Args)...);
     else if (nodeType == OperationNameOf(PerDimMeanVarNormalizationNode))   return New<PerDimMeanVarNormalizationNode<ElemType>>(forward<_Types>(_Args)...);
@@ -127,7 +127,7 @@ static shared_ptr<ComputationNode<ElemType>> CreateStandardNode(const std::wstri
     else if (nodeType == L"PerDimMeanVarNormalizationNode")   return New<PerDimMeanVarNormalizationNode<ElemType>>(forward<_Types>(_Args)...);
     else if (nodeType == L"PerDimMeanVarDeNormalizationNode") return New<PerDimMeanVarDeNormalizationNode<ElemType>>(forward<_Types>(_Args)...);
     else if (nodeType == L"ReconcileMBLayout")                 return New>(forward<_Types>(_Args)...);
-    else if (nodeType == L"RNN")                               return New<OptimizedRNNStack<ElemType>>(forward<_Types>(_Args)...);
+    else if (nodeType == L"RNN")                               return New<OptimizedRNNStackNode<ElemType>>(forward<_Types>(_Args)...);
     else if (nodeType == L"RowElementTimes")                   return New>(forward<_Types>(_Args)...);
     else if (nodeType == L"RowSlice")                          return New>(forward<_Types>(_Args)...);
     else if (nodeType == L"Scale")                             return New>(forward<_Types>(_Args)...);
diff --git a/Source/ComputationNetworkLib/RNNNodes.cpp b/Source/ComputationNetworkLib/RNNNodes.cpp
index e9c751d36..09836a223 100644
--- a/Source/ComputationNetworkLib/RNNNodes.cpp
+++ b/Source/ComputationNetworkLib/RNNNodes.cpp
@@ -25,20 +25,20 @@ namespace Microsoft { namespace MSR { namespace CNTK {
 vector<size_t> numSequencesForFrame;

 // -----------------------------------------------------------------------
-// OptimizedRNNStack
+// OptimizedRNNStackNode
 // -----------------------------------------------------------------------

 template <class ElemType>
-OptimizedRNNStack<ElemType>::OptimizedRNNStack(DEVICEID_TYPE deviceId, const wstring& name)
+OptimizedRNNStackNode<ElemType>::OptimizedRNNStackNode(DEVICEID_TYPE deviceId, const wstring& name)
     : Base(deviceId, name),
-      m_rnnAttributes(0, 0, 0, L"LSTM", -1),
+      m_rnnAttributes(0, 0, 0, L"lstm", -1),
       m_BackwardDataCalledYet(false)
 {
 }

 // This constructor helps with BrainScript integration
 template <class ElemType>
-OptimizedRNNStack<ElemType>::OptimizedRNNStack(const ScriptableObjects::IConfigRecordPtr configp)
+OptimizedRNNStackNode<ElemType>::OptimizedRNNStackNode(const ScriptableObjects::IConfigRecordPtr configp)
     : Base(configp->Get(L"deviceId"), L""),
       m_rnnAttributes(configp->Get(L"bidirectional"), configp->Get(L"numLayers"), configp->Get(L"hiddenDims"), configp->Get(L"recurrentOp"), configp->Get(L"axis")),
       m_BackwardDataCalledYet(false)
@@ -47,32 +47,32 @@ OptimizedRNNStack<ElemType>::OptimizedRNNStack(const ScriptableObjects::IConfigR
 }

 template <class ElemType>
-/*virtual*/ void OptimizedRNNStack<ElemType>::CopyTo(ComputationNodeBasePtr nodeP, const std::wstring& newName, const CopyNodeFlags flags) const /*override*/
+/*virtual*/ void OptimizedRNNStackNode<ElemType>::CopyTo(ComputationNodeBasePtr nodeP, const std::wstring& newName, const CopyNodeFlags flags) const /*override*/
 {
     Base::CopyTo(nodeP, newName, flags);
     if (flags & CopyNodeFlags::copyNodeValue)
     {
-        auto node = dynamic_pointer_cast<OptimizedRNNStack<ElemType>>(nodeP);
+        auto node = dynamic_pointer_cast<OptimizedRNNStackNode<ElemType>>(nodeP);
         node->m_rnnAttributes = m_rnnAttributes;
     }
 }

 template <class ElemType>
-void OptimizedRNNStack<ElemType>::Save(File& fstream) const
+void OptimizedRNNStackNode<ElemType>::Save(File& fstream) const
 {
     Base::Save(fstream);
     m_rnnAttributes.Write(fstream);
 }

 template <class ElemType>
-void OptimizedRNNStack<ElemType>::Load(File& fstream, size_t modelVersion)
+void OptimizedRNNStackNode<ElemType>::Load(File& fstream, size_t modelVersion)
 {
     Base::Load(fstream, modelVersion);
     m_rnnAttributes.Read(fstream, /*readAxis=*/ modelVersion >= CNTK_MODEL_VERSION_14);
 }

 template <class ElemType>
-void OptimizedRNNStack<ElemType>::TransposeHelper(const MatrixBasePtr matX, const TensorShape &shapeX, MatrixBasePtr matY, TensorShape &shapeY)
+void OptimizedRNNStackNode<ElemType>::TransposeHelper(const MatrixBasePtr matX, const TensorShape &shapeX, MatrixBasePtr matY, TensorShape &shapeY)
 {
     // This function transposes the second and third axes of the input (X), creating a transposed copy in the output (Y).
 //
@@ -89,19 +89,19 @@ void OptimizedRNNStack<ElemType>::TransposeHelper(const MatrixBasePtr matX, cons
 };

 template <class ElemType>
-void OptimizedRNNStack<ElemType>::ForwardProp(const FrameRange& fr)
+void OptimizedRNNStackNode<ElemType>::ForwardProp(const FrameRange& fr)
 {
     // ComputationNode derived classes are guaranteed to have a MBLayout
     if (!HasMBLayout())
     {
-        LogicError("OptimizedRNNStack must operate on minibatches");
+        LogicError("OptimizedRNNStackNode must operate on minibatches");
     }

     // The parameters are stored in a column matrix
     Matrix<ElemType>& paramW = Input(1)->Value();

     MBLayoutPtr mb = GetMBLayout();
-    if (m_rnnAttributes.IsWindowedRecurrence())
+    if (m_rnnAttributes.IsSpatialRecurrence())
     {
         TensorView<ElemType> outputY = ValueTensorFor(SIZE_MAX, fr);
@@ -137,7 +137,7 @@ void OptimizedRNNStack<ElemType>::ForwardProp(const FrameRange& fr)
     else
     {
         if (mb->GetNumTimeSteps() == 1)
-            RuntimeError("OptimizedRNNStack configured for sequence mode, but minibatch only has one time step.");
+            RuntimeError("OptimizedRNNStackNode configured for sequence mode, but minibatch only has one time step.");

         shapeXT = TensorShape(Input(0)->GetTensorSliceFor(SIZE_MAX, fr));
         shapeYT = TensorShape(this->GetTensorSliceFor(SIZE_MAX, fr));
@@ -155,7 +155,7 @@ void OptimizedRNNStack<ElemType>::ForwardProp(const FrameRange& fr)
 }

 template <class ElemType>
-void OptimizedRNNStack<ElemType>::BackpropTo(const size_t inputIndex, const FrameRange& fr)
+void OptimizedRNNStackNode<ElemType>::BackpropTo(const size_t inputIndex, const FrameRange& fr)
 {
     MBLayoutPtr mb = this->GetMBLayout();

@@ -164,7 +164,7 @@ void OptimizedRNNStack<ElemType>::BackpropTo(const size_t inputIndex, const Fram
     {
         Matrix<ElemType>& paramW = Input(1)->Value();

-        if (m_rnnAttributes.IsWindowedRecurrence())
+        if (m_rnnAttributes.IsSpatialRecurrence())
         {
             // To obey the data layout constraints of CuDnn, we take the derivative we're given,
             // and transpose it before feeding to the interface.
@@ -191,7 +191,7 @@ void OptimizedRNNStack<ElemType>::BackpropTo(const size_t inputIndex, const Fram
     else if (inputIndex == 0) // data
     {
         // all of the work was done above, where RNNBackwardData is called. Now, just unpack the result.
-        if (m_rnnAttributes.IsWindowedRecurrence())
+        if (m_rnnAttributes.IsSpatialRecurrence())
         {
             TensorShape tmp;
             TransposeHelper(m_transposedDInput, shapeXT, Input(0)->GradientPtr(), tmp);
@@ -204,20 +204,23 @@ void OptimizedRNNStack<ElemType>::BackpropTo(const size_t inputIndex, const Fram
 }

 template <class ElemType>
-void OptimizedRNNStack<ElemType>::Validate(bool isFinalValidationPass)
+void OptimizedRNNStackNode<ElemType>::Validate(bool isFinalValidationPass)
 {
     // N.B.: I need both of these lines.
     Base::Validate(isFinalValidationPass);
     InferMBLayoutFromInputsForStandardCase(isFinalValidationPass);

     // get tensor shapes
-    auto dimsA = Input(1)->GetSampleLayout().GetDims(); // data
-    auto dimsB = Input(0)->GetSampleLayout().GetDims(); // parameters
+    let& shapeA = Input(0)->GetSampleLayout(); // parameters
+    let& shapeB = Input(1)->GetSampleLayout(); // data
+    auto dimsA = shapeA.GetDims();
+    auto dimsB = shapeB.GetDims();
+    // data rank must match spatial/temporal recurrence mode
     if (isFinalValidationPass &&
-        dimsA.size() != (m_rnnAttributes.IsWindowedRecurrence() ? 2 : 1))
+        dimsB.size() != (m_rnnAttributes.IsSpatialRecurrence() ? 2 : 1))
     {
-        InvalidArgument("%ls: Input must have rank 1 for axis=-1 and rank 2 for axis=2.", NodeDescription().c_str());
+        InvalidArgument("%ls: Input [%s] must have rank 1 for axis=-1 and rank 2 for axis=2.", NodeDescription().c_str(), string(shapeB).c_str());
     }

     // validate and infer
@@ -229,6 +232,14 @@ void OptimizedRNNStack<ElemType>::Validate(bool isFinalValidationPass)
     // output dims
     dimsC[0] = (m_rnnAttributes.m_bidirectional ? 2 : 1) * m_rnnAttributes.m_hiddenSize;
+    // infer input size
+    // Note: Output dim is second axis, so say initOutputRank=-1 in the Parameters{} definition.
+    if (dimsA.size() == 2)
+    {
+        let numParameters = m_rnnAttributes.GetNumParameters(shapeB.GetNumElements());
+        Input(0)->ValidateInferInputDimsFrom(TensorShape(numParameters.first, numParameters.second));
+    }
+
     // N.B. - this is the magical call, the reason for the function
     // dimensions would be outputRank * numSamples * minibatch * time.
     // This call establishes outputRank * numSamples, the rest will be filled in
@@ -238,7 +249,7 @@ void OptimizedRNNStack<ElemType>::Validate(bool isFinalValidationPass)
 };

 template <class ElemType>
-void OptimizedRNNStack<ElemType>::PackSequencesForCuDNN(const Matrix<ElemType>& src, Matrix<ElemType>& dst, vector<size_t>& numSequencesForFrame)
+void OptimizedRNNStackNode<ElemType>::PackSequencesForCuDNN(const Matrix<ElemType>& src, Matrix<ElemType>& dst, vector<size_t>& numSequencesForFrame)
 {
     MBLayoutPtr mb = this->GetMBLayout();
     if (mb->HasSequenceBeyondBegin())
@@ -307,7 +318,7 @@ void OptimizedRNNStack<ElemType>::PackSequencesForCuDNN(const Matrix<ElemType>&
     dst.DoGatherColumnsOf(0.0, *(this->m_packingIndex), src, 1.0);
 }

 template <class ElemType>
-void OptimizedRNNStack<ElemType>::UnpackSequencesFromCuDNN(const Matrix<ElemType>& src, Matrix<ElemType>& dst)
+void OptimizedRNNStackNode<ElemType>::UnpackSequencesFromCuDNN(const Matrix<ElemType>& src, Matrix<ElemType>& dst)
 {
     // this->scatter(beta,ndx,a,alpha) operation is defined as
     // *this[:,idx[j]] = a[:,j] * alpha + *this[:,idx[j]] * beta
@@ -315,7 +326,7 @@ void OptimizedRNNStack<ElemType>::UnpackSequencesFromCuDNN(const Matrix<ElemType

-template class OptimizedRNNStack<float>;
-template class OptimizedRNNStack<double>;
+template class OptimizedRNNStackNode<float>;
+template class OptimizedRNNStackNode<double>;

 }}}
diff --git a/Source/ComputationNetworkLib/RNNNodes.h b/Source/ComputationNetworkLib/RNNNodes.h
index 74c7c1125..9b6b49e6b 100644
--- a/Source/ComputationNetworkLib/RNNNodes.h
+++ b/Source/ComputationNetworkLib/RNNNodes.h
@@ -24,19 +24,19 @@ namespace Microsoft { namespace MSR { namespace CNTK {

 // -----------------------------------------------------------------------
-// OptimizedRNNStack (data, weights)
+// OptimizedRNNStack (weights, data)
 // -----------------------------------------------------------------------

 template <class ElemType>
-class OptimizedRNNStack : public ComputationNode<ElemType>, public NumInputs<2>
+class OptimizedRNNStackNode : public ComputationNode<ElemType>, public NumInputs<2>
 {
     typedef ComputationNode<ElemType> Base; UsingComputationNodeMembersBoilerplate;
-    static const std::wstring TypeName() { return L"OptimizedRNN"; }
+    static const std::wstring TypeName() { return L"OptimizedRNNStack"; }
     using Base::OperationName;

 public:
-    OptimizedRNNStack(DEVICEID_TYPE deviceId, const wstring& name);
-    OptimizedRNNStack(const ScriptableObjects::IConfigRecordPtr configp);
+    OptimizedRNNStackNode(DEVICEID_TYPE deviceId, const wstring& name);
+    OptimizedRNNStackNode(const ScriptableObjects::IConfigRecordPtr configp);

     virtual void CopyTo(ComputationNodeBasePtr nodeP, const std::wstring& newName, const CopyNodeFlags flags) const override;
     virtual void Save(File& fstream) const;
diff --git a/Source/Math/CuDnnRNN.h b/Source/Math/CuDnnRNN.h
index 6b70cf4bb..40fe279cc 100644
--- a/Source/Math/CuDnnRNN.h
+++ b/Source/Math/CuDnnRNN.h
@@ -16,9 +16,9 @@ namespace Microsoft { namespace MSR { namespace CNTK {
 class CuDnnDropout
 {
     CuDnn::ptr_t m_cudnn;
-    unsigned long long m_seed = 0xdeadbeefull;
+    unsigned long long m_seed = 1;
 public:
-    CuDnnDropout(float dropout = 0.0f, unsigned long long seed = 0xdeadbeefull)
+    CuDnnDropout(float dropout = 0.0f, unsigned long long seed = 1)
         : m_dropoutDesc(nullptr), m_cudnn(CuDnn::Instance())
     {
         CUDNN_CALL(cudnnCreateDropoutDescriptor(&m_dropoutDesc));
@@ -66,15 +66,11 @@
 private:
     cudnnRNNMode_t GetMode()
     {
-        if (m_rnnAttributes.m_rnnMode == wstring(L"LSTM"))
-            return cudnnRNNMode_t::CUDNN_LSTM;
-        if (m_rnnAttributes.m_rnnMode == wstring(L"GRU"))
-            return cudnnRNNMode_t::CUDNN_GRU;
-        if (m_rnnAttributes.m_rnnMode == wstring(L"RNN_RELU"))
-            return cudnnRNNMode_t::CUDNN_RNN_RELU;
-        if (m_rnnAttributes.m_rnnMode == wstring(L"RNN_TANH"))
-            return cudnnRNNMode_t::CUDNN_RNN_TANH;
-        InvalidArgument("RNN Mode set to %ls, but supported values are LSTM, GRU, RNN_RELU, RNN_TANH.", m_rnnAttributes.m_rnnMode.c_str());
+        if      (m_rnnAttributes.m_recurrentOp == wstring(L"lstm"))    return cudnnRNNMode_t::CUDNN_LSTM;
+        else if (m_rnnAttributes.m_recurrentOp == wstring(L"gru"))     return cudnnRNNMode_t::CUDNN_GRU;
+        else if (m_rnnAttributes.m_recurrentOp == wstring(L"rnnReLU")) return cudnnRNNMode_t::CUDNN_RNN_RELU;
+        else if (m_rnnAttributes.m_recurrentOp == wstring(L"rnnTanh")) return cudnnRNNMode_t::CUDNN_RNN_TANH;
+        else InvalidArgument("Unknown cell type '%ls'. Supported values are 'lstm', 'gru', 'rnnReLU', 'rnnTanh'.", m_rnnAttributes.m_recurrentOp.c_str());
     }

 public:
diff --git a/Source/Math/RNNCommon.h b/Source/Math/RNNCommon.h
index 38e76de7f..c3a7ef7c2 100644
--- a/Source/Math/RNNCommon.h
+++ b/Source/Math/RNNCommon.h
@@ -16,24 +16,52 @@ struct RnnAttributes
     bool m_bidirectional;
     size_t m_numLayers;
     size_t m_hiddenSize;
-    wstring m_rnnMode;
+    wstring m_recurrentOp;
     int m_axis;

-    bool IsWindowedRecurrence() const { return m_axis >= 0; }
+    bool IsSpatialRecurrence() const { return m_axis >= 0; }

-    RnnAttributes(bool bidirectional, size_t numLayers, size_t hiddenSize, const wstring& rnnMode, int axis) :
-        m_bidirectional(bidirectional), m_numLayers(numLayers), m_hiddenSize(hiddenSize), m_rnnMode(rnnMode), m_axis(axis)
+    RnnAttributes(bool bidirectional, size_t numLayers, size_t hiddenSize, const wstring& recurrentOp, int axis) :
+        m_bidirectional(bidirectional), m_numLayers(numLayers), m_hiddenSize(hiddenSize), m_recurrentOp(recurrentOp), m_axis(axis)
     {
+        if (m_recurrentOp != wstring(L"lstm") && m_recurrentOp != wstring(L"gru") &&
+            m_recurrentOp != wstring(L"rnnReLU") && m_recurrentOp != wstring(L"rnnTanh"))
+        {
+            InvalidArgument("Unknown cell type '%ls'. Supported values are 'lstm', 'gru', 'rnnReLU', 'rnnTanh'.", m_recurrentOp.c_str());
+        }
+        if (m_axis != -1 && m_axis != 2) InvalidArgument("OptimizedRNNStack: invalid 'axis' parameter %d, currently supported values are -1 and 2.", m_axis);
     }

+    // compute the total number of parameters, for inference of weight matrix size
+    pair<size_t, size_t> GetNumParameters(size_t inputDim) const
+    {
+        const size_t bidirFactor = m_bidirectional ? 2 : 1;
+        const size_t numNetworks =
+            (m_recurrentOp == L"lstm") ? 4 :
+            (m_recurrentOp == L"gru")  ? 3 :
+            /*else*/                     1;
+        size_t total = 0;
+        for (size_t i = 0; i < m_numLayers; i++)
+        {
+            size_t oneNetTotal =
+                numNetworks * m_hiddenSize        // 1, 3, or 4 networks producing hidden-dim output
+                * (inputDim + m_hiddenSize)       // each network has these two inputs
+                + numNetworks * m_hiddenSize      // biases
+                  * 2;                            // for unknown reasons, cudnn5 uses 2 bias terms everywhere
+            total += oneNetTotal * bidirFactor;   // 1 or 2 directions
+            inputDim = bidirFactor * m_hiddenSize; // next layer continues with this as input
+        }
+        return make_pair(m_hiddenSize, total / m_hiddenSize);
+    }
+
     bool operator==(const RnnAttributes& other) const
     {
         return m_bidirectional == other.m_bidirectional &&
                m_numLayers     == other.m_numLayers &&
                m_hiddenSize    == other.m_hiddenSize &&
-               m_rnnMode       == other.m_rnnMode &&
+               m_recurrentOp   == other.m_recurrentOp &&
                m_axis          == other.m_axis;
     }

@@ -43,11 +71,17 @@ struct RnnAttributes
         stream >> bidirectional; m_bidirectional = !!bidirectional;
         stream >> m_numLayers;
         stream >> m_hiddenSize;
-        stream >> m_rnnMode;
+        stream >> m_recurrentOp;
         if (readAxis)
-            stream >> m_axis; // note: back compat for windowed models deliberately dropped
-        else
-            m_axis = -1;
+            stream >> m_axis;
+        else // legacy
+        {
+            m_axis = -1; // note: back compat for windowed models deliberately dropped
+            if      (m_recurrentOp == wstring(L"LSTM"))     m_recurrentOp = L"lstm"; // map names
+            else if (m_recurrentOp == wstring(L"GRU"))      m_recurrentOp = L"gru";
+            else if (m_recurrentOp == wstring(L"RNN_RELU")) m_recurrentOp = L"rnnReLU";
+            else if (m_recurrentOp == wstring(L"RNN_TANH")) m_recurrentOp = L"rnnTanh";
+        }
     }

     void Write(File& stream) const
@@ -56,7 +90,7 @@
         stream << bidirectional;
         stream << m_numLayers;
         stream << m_hiddenSize;
-        stream << m_rnnMode;
+        stream << m_recurrentOp;
         stream << m_axis;
     }
diff --git a/Tests/EndToEndTests/Speech/LSTM/cntk.cntk b/Tests/EndToEndTests/Speech/LSTM/cntk.cntk
index d2d5739cd..d4c883d76 100644
--- a/Tests/EndToEndTests/Speech/LSTM/cntk.cntk
+++ b/Tests/EndToEndTests/Speech/LSTM/cntk.cntk
@@ -34,11 +34,11 @@ speechTrain = {
         # cudnn5 library
         # Note: does not run in truncated mode
-        W = ParameterTensor {14704-8*(40-33):hiddenDim, init='heNormal', initValueScale=1/10} # -> change to 0:hiddenDim, outputRank=-1
+        W = ParameterTensor {hiddenDim:14704-8*(40-33), initOutputRank=-1, init='heNormal', initValueScale=1/10} # -> change to 0:hiddenDim, outputRank=-1

         modelUsingCuDNN5 = Sequential (
             MeanVarNorm :
-            (_ => OptimizedRNNStack(W, _, hiddenDim, numLayers=numLSTMLayers, bidirectional=true, rnnMode='LSTM')) :
+            (_ => OptimizedRNNStack(W, _, hiddenDim, numLayers=numLSTMLayers, bidirectional=true)) :
             DenseLayer {labelDim, init='heUniform', initValueScale=1/3}
         )
@@ -144,7 +144,7 @@ speechTrain = {
         // features
         features = Input((1 : featDim), tag='feature') // TEST: Artificially reading data transposed
-        realFeatures = Transpose (features) // and swapping them back to (featDim:1), for testing Transpose()
+        realFeatures = FlattenDimensions (Transpose (features), 1, 2) // and swapping them back to (featDim:1), for testing Transpose()
         feashift = RowSlice(featDim - baseFeatDim, baseFeatDim, realFeatures); # interface with a reader set up for frame mode

         labels = Input(labelDim, tag='label')
diff --git a/Tutorials/SLUHandsOn/SLUHandsOn_Solution1.cntk b/Tutorials/SLUHandsOn/SLUHandsOn_Solution1.cntk
index f8c07626c..7e2bcadc0 100644
--- a/Tutorials/SLUHandsOn/SLUHandsOn_Solution1.cntk
+++ b/Tutorials/SLUHandsOn/SLUHandsOn_Solution1.cntk
@@ -36,7 +36,7 @@ TrainTagger = {

         # loss and metric
         ce = CrossEntropyWithSoftmax (slotLabels, z)
-        errs = ClassificationError (slotLabels, z)
+        errs = ClassificationError (slotLabels, z)

         featureNodes = (query)
         labelNodes = (slotLabels)
diff --git a/Tutorials/SLUHandsOn/SLUHandsOn_Solution2.cntk b/Tutorials/SLUHandsOn/SLUHandsOn_Solution2.cntk
index 4311cabb1..77acbd061 100644
--- a/Tutorials/SLUHandsOn/SLUHandsOn_Solution2.cntk
+++ b/Tutorials/SLUHandsOn/SLUHandsOn_Solution2.cntk
@@ -39,7 +39,7 @@ TrainTagger = {

         # loss and metric
         ce = CrossEntropyWithSoftmax (slotLabels, z)
-        errs = ClassificationError (slotLabels, z)
+        errs = ClassificationError (slotLabels, z)

         featureNodes = (query)
         labelNodes = (slotLabels)
diff --git a/Tutorials/SLUHandsOn/SLUHandsOn_Solution3.cntk b/Tutorials/SLUHandsOn/SLUHandsOn_Solution3.cntk
index 84cda08e3..caaa8d3ce 100644
--- a/Tutorials/SLUHandsOn/SLUHandsOn_Solution3.cntk
+++ b/Tutorials/SLUHandsOn/SLUHandsOn_Solution3.cntk
@@ -42,7 +42,7 @@ TrainTagger = {

         # loss and metric
         ce = CrossEntropyWithSoftmax (slotLabels, z)
-        errs = ClassificationError (slotLabels, z)
+        errs = ClassificationError (slotLabels, z)

         featureNodes = (query)
         labelNodes = (slotLabels)
diff --git a/Tutorials/SLUHandsOn/SLUHandsOn_Solution4.cntk b/Tutorials/SLUHandsOn/SLUHandsOn_Solution4.cntk
index 7a02a35f8..b4bb55e0b 100644
--- a/Tutorials/SLUHandsOn/SLUHandsOn_Solution4.cntk
+++ b/Tutorials/SLUHandsOn/SLUHandsOn_Solution4.cntk
@@ -43,7 +43,7 @@ TrainTagger = {

         # loss and metric
         ce = CrossEntropyWithSoftmax (intentLabels, z)
-        errs = ClassificationError (intentLabels, z)
+        errs = ClassificationError (intentLabels, z)

         featureNodes = (query)
         labelNodes = (intentLabels)