OptimizedRNNStackNode: renamed some variables, renamed recurrentOp values to camelCase, added weight inference

Frank Seide 2016-08-24 17:17:23 -07:00
Parent 8a86da8f02
Commit d9c7e82031
12 changed files with 107 additions and 65 deletions

View file

@@ -193,7 +193,7 @@ BrainScriptNetworkBuilder = (new ComputationNetwork [
# Note: We reverse our input by running the recurrence from right to left.
encoderFunction = if useBidirectionalEncoder then BS.RNNs.RecurrentBirectionalLSTMPStack else BS.RNNs.RecurrentLSTMPStack
encoderFunction = if useBidirectionalEncoder then BS.RNNs.RecurrentBidirectionalLSTMPStack else BS.RNNs.RecurrentLSTMPStack
encoder = encoderFunction (encoderDims, cellDims=encoderDims, S(inputEmbedded), inputDim=inputEmbeddingDim,
previousHook=if useBidirectionalEncoder then BS.RNNs.PreviousHC else BS.RNNs.NextHC,
enableSelfStabilization=useStabilizer)

View file

@@ -501,9 +501,10 @@ PerDimMeanVarDeNormalization(dataVectorSequence, meanVector, invStdDevVector, ta
PerDimMeanVarNormalization (x, mean, invStdDev) = (x - mean) .* invStdDev
Reciprocal(z, tag='') = new ComputationNode [ operation = 'Reciprocal' ; inputs = z /*plus the function args*/ ]
//# the following is a temporary workaround until we have the C++ version
OptimizedRNNStack(weights, input, hiddenDims, numLayers=1, bidirectional=false, rnnMode='LSTM', tag='') = new ComputationNode [ operation = 'OptimizedRNNStack' ; recurrentOp = rnnMode; inputs = ( input : weights ) /*plus the function args*/ ]
# TODO: change hiddenDims to hiddenShape and pass as a TensorShape (currently, the node only supports rank-1 data)
OptimizedRNNStack(weights, input, hiddenDims, numLayers=1, bidirectional=false, recurrentOp='lstm', axis=-1, tag='') = new ComputationNode [ operation = 'OptimizedRNNStack' ; inputs = ( weights : input ) /*plus the function args*/ ]
# legacy:
RNNStack(x, W, hiddenSize=10, numLayers=1, bidirectional=false, rnnMode='LSTM', tag='') = OptimizedRNNStack(W, X, hiddenSize, numLayers=1, bidirectional=false, rnnMode='LSTM', tag='')
RNNStack(x, W, hiddenSize=10, numLayers=1, bidirectional=false, rnnMode='lstm', tag='') = OptimizedRNNStack(W, x, hiddenSize, numLayers=1, bidirectional=false, recurrentOp=rnnMode, tag='')
Scale(scalarScalingFactor, matrix, tag='') = new ComputationNode [ operation = 'Scale' ; inputs = (scalarScalingFactor : matrix) /*plus the function args*/ ]
# TODO: Scale = ElementTimes
ScatterPacked(cond, indexSequence, sourceData, tag='') = new ComputationNode [ operation = 'ScatterPacked' ; inputs = (cond : indexSequence : sourceData) /*plus the function args*/ ]
@@ -1045,7 +1046,7 @@ RNNs =
# a stack of recurrent LSTMs (bidirectional)
# TODO: Should we define layerDims as the total (sum of both forward and backward direction)?
RecurrentBirectionalLSTMPStack (layerDims, cellDims=layerDims, input, inputDim=input.dim, previousHook=PreviousHC, nextHook=NextHC, enableSelfStabilization=false) = [
RecurrentBidirectionalLSTMPStack (layerDims, cellDims=layerDims, input, inputDim=input.dim, previousHook=PreviousHC, nextHook=NextHC, enableSelfStabilization=false) = [
previousHook1 = previousHook ; nextHook1 = nextHook ; useStabilizer = enableSelfStabilization
layers[i:0..Length (layerDims)-1] =
[
@@ -1159,7 +1160,7 @@ RNNs =
# a stack of recurrent GRUs (bidirectional)
# TODO: Should we define layerDims as the total (sum of both forward and backward direction)?
RecurrentBirectionalGRUStack (layerDims, cellDims=layerDims, input, inputDim=input.dim, previousHook=PreviousHC, nextHook=NextHC, enableSelfStabilization=false) = [
RecurrentBidirectionalGRUStack (layerDims, cellDims=layerDims, input, inputDim=input.dim, previousHook=PreviousHC, nextHook=NextHC, enableSelfStabilization=false) = [
previousHook1 = previousHook ; nextHook1 = nextHook ; useStabilizer = enableSelfStabilization
layers[i:0..Length (layerDims)-1] =
[

View file

@@ -82,7 +82,7 @@ static shared_ptr<ComputationNode<ElemType>> CreateStandardNode(const std::wstri
else if (nodeType == OperationNameOf(NegateNode)) return New<NegateNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(NotEqualNode)) return New<NotEqualNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(NoiseContrastiveEstimationNode)) return New<NoiseContrastiveEstimationNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(OptimizedRNNStack)) return New<OptimizedRNNStack<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(OptimizedRNNStackNode)) return New<OptimizedRNNStackNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(PackedIndexNode)) return New<PackedIndexNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(PastValueNode)) return New<PastValueNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(PerDimMeanVarNormalizationNode)) return New<PerDimMeanVarNormalizationNode<ElemType>>(forward<_Types>(_Args)...);
@@ -127,7 +127,7 @@ static shared_ptr<ComputationNode<ElemType>> CreateStandardNode(const std::wstri
else if (nodeType == L"PerDimMeanVarNormalizationNode") return New<PerDimMeanVarNormalizationNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == L"PerDimMeanVarDeNormalizationNode") return New<PerDimMeanVarDeNormalizationNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == L"ReconcileMBLayout") return New<ReconcileDynamicAxisNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == L"RNN") return New<OptimizedRNNStack<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == L"RNN") return New<OptimizedRNNStackNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == L"RowElementTimes") return New<ElementTimesNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == L"RowSlice") return New<SliceNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == L"Scale") return New<ElementTimesNode<ElemType>>(forward<_Types>(_Args)...);

View file

@@ -25,20 +25,20 @@ namespace Microsoft { namespace MSR { namespace CNTK {
vector<size_t> numSequencesForFrame;
// -----------------------------------------------------------------------
// OptimizedRNNStack
// OptimizedRNNStackNode
// -----------------------------------------------------------------------
template<class ElemType>
OptimizedRNNStack<ElemType>::OptimizedRNNStack(DEVICEID_TYPE deviceId, const wstring& name)
OptimizedRNNStackNode<ElemType>::OptimizedRNNStackNode(DEVICEID_TYPE deviceId, const wstring& name)
: Base(deviceId, name),
m_rnnAttributes(0, 0, 0, L"LSTM", -1),
m_rnnAttributes(0, 0, 0, L"lstm", -1),
m_BackwardDataCalledYet(false)
{
}
// This constructor helps with BrainScript integration
template<class ElemType>
OptimizedRNNStack<ElemType>::OptimizedRNNStack(const ScriptableObjects::IConfigRecordPtr configp)
OptimizedRNNStackNode<ElemType>::OptimizedRNNStackNode(const ScriptableObjects::IConfigRecordPtr configp)
: Base(configp->Get(L"deviceId"), L"<placeholder>"),
m_rnnAttributes(configp->Get(L"bidirectional"), configp->Get(L"numLayers"), configp->Get(L"hiddenDims"), configp->Get(L"recurrentOp"), configp->Get(L"axis")),
m_BackwardDataCalledYet(false)
@@ -47,32 +47,32 @@ OptimizedRNNStack<ElemType>::OptimizedRNNStack(const ScriptableObjects::IConfigR
}
template<class ElemType>
/*virtual*/ void OptimizedRNNStack<ElemType>::CopyTo(ComputationNodeBasePtr nodeP, const std::wstring& newName, const CopyNodeFlags flags) const /*override*/
/*virtual*/ void OptimizedRNNStackNode<ElemType>::CopyTo(ComputationNodeBasePtr nodeP, const std::wstring& newName, const CopyNodeFlags flags) const /*override*/
{
Base::CopyTo(nodeP, newName, flags);
if (flags & CopyNodeFlags::copyNodeValue)
{
auto node = dynamic_pointer_cast<OptimizedRNNStack<ElemType>>(nodeP);
auto node = dynamic_pointer_cast<OptimizedRNNStackNode<ElemType>>(nodeP);
node->m_rnnAttributes = m_rnnAttributes;
}
}
template<class ElemType>
void OptimizedRNNStack<ElemType>::Save(File& fstream) const
void OptimizedRNNStackNode<ElemType>::Save(File& fstream) const
{
Base::Save(fstream);
m_rnnAttributes.Write(fstream);
}
template<class ElemType>
void OptimizedRNNStack<ElemType>::Load(File& fstream, size_t modelVersion)
void OptimizedRNNStackNode<ElemType>::Load(File& fstream, size_t modelVersion)
{
Base::Load(fstream, modelVersion);
m_rnnAttributes.Read(fstream, /*readAxis=*/ modelVersion >= CNTK_MODEL_VERSION_14);
}
template<class ElemType>
void OptimizedRNNStack<ElemType>::TransposeHelper(const MatrixBasePtr matX, const TensorShape &shapeX, MatrixBasePtr matY, TensorShape &shapeY)
void OptimizedRNNStackNode<ElemType>::TransposeHelper(const MatrixBasePtr matX, const TensorShape &shapeX, MatrixBasePtr matY, TensorShape &shapeY)
{
// This function transposes the second and third axes of the input (X), creating a transposed copy in the output (Y).
//
@@ -89,19 +89,19 @@ void OptimizedRNNStack<ElemType>::TransposeHelper(const MatrixBasePtr matX, cons
};
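For illustration only (not part of the commit), here is a minimal standalone sketch of the index arithmetic such an axis swap performs, assuming column-major storage as CNTK tensors use; the function name and the use of std::vector in place of the CNTK Matrix/TensorView API are hypothetical:
#include <cstddef>
#include <vector>
// Swap the second and third axes of a rank-3 tensor: dst(i, k, j) = src(i, j, k).
// src has shape [d0 x d1 x d2] (column-major); dst gets shape [d0 x d2 x d1].
void TransposeAxes12(const std::vector<float>& src, size_t d0, size_t d1, size_t d2,
                     std::vector<float>& dst)
{
    dst.resize(d0 * d1 * d2);
    for (size_t k = 0; k < d2; k++)
        for (size_t j = 0; j < d1; j++)
            for (size_t i = 0; i < d0; i++)
                dst[i + d0 * (k + d2 * j)] = src[i + d0 * (j + d1 * k)];
}
The node itself operates on Matrix/TensorView objects; the loop above only shows the index mapping.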
template<class ElemType>
void OptimizedRNNStack<ElemType>::ForwardProp(const FrameRange& fr)
void OptimizedRNNStackNode<ElemType>::ForwardProp(const FrameRange& fr)
{
// ComputationNode derived classes are guaranteed to have a MBLayout
if (!HasMBLayout())
{
LogicError("OptimizedRNNStack must operate on minibatches");
LogicError("OptimizedRNNStackNode must operate on minibatches");
}
// The parameters are stored in a column matrix
Matrix<ElemType>& paramW = Input(1)->Value();
MBLayoutPtr mb = GetMBLayout();
if (m_rnnAttributes.IsWindowedRecurrence())
if (m_rnnAttributes.IsSpatialRecurrence())
{
TensorView<ElemType> outputY = ValueTensorFor(SIZE_MAX, fr);
@@ -137,7 +137,7 @@ void OptimizedRNNStack<ElemType>::ForwardProp(const FrameRange& fr)
else
{
if (mb->GetNumTimeSteps() == 1)
RuntimeError("OptimizedRNNStack configured for sequence mode, but minibatch only has one time step.");
RuntimeError("OptimizedRNNStackNode configured for sequence mode, but minibatch only has one time step.");
shapeXT = TensorShape(Input(0)->GetTensorSliceFor(SIZE_MAX, fr));
shapeYT = TensorShape(this->GetTensorSliceFor(SIZE_MAX, fr));
@@ -155,7 +155,7 @@ void OptimizedRNNStack<ElemType>::ForwardProp(const FrameRange& fr)
}
template<class ElemType>
void OptimizedRNNStack<ElemType>::BackpropTo(const size_t inputIndex, const FrameRange& fr)
void OptimizedRNNStackNode<ElemType>::BackpropTo(const size_t inputIndex, const FrameRange& fr)
{
MBLayoutPtr mb = this->GetMBLayout();
@@ -164,7 +164,7 @@ void OptimizedRNNStack<ElemType>::BackpropTo(const size_t inputIndex, const Fram
{
Matrix<ElemType>& paramW = Input(1)->Value();
if (m_rnnAttributes.IsWindowedRecurrence())
if (m_rnnAttributes.IsSpatialRecurrence())
{
// To obey the data layout constraints of CuDnn, we take the derivative we're given,
// and transpose it before feeding to the interface.
@@ -191,7 +191,7 @@ void OptimizedRNNStack<ElemType>::BackpropTo(const size_t inputIndex, const Fram
else if (inputIndex == 0) // data
{
// all of the work was done above, where RNNBackwardData is called. Now, just unpack the result.
if (m_rnnAttributes.IsWindowedRecurrence())
if (m_rnnAttributes.IsSpatialRecurrence())
{
TensorShape tmp;
TransposeHelper(m_transposedDInput, shapeXT, Input(0)->GradientPtr(), tmp);
@@ -204,20 +204,23 @@ void OptimizedRNNStack<ElemType>::BackpropTo(const size_t inputIndex, const Fram
}
template<class ElemType>
void OptimizedRNNStack<ElemType>::Validate(bool isFinalValidationPass)
void OptimizedRNNStackNode<ElemType>::Validate(bool isFinalValidationPass)
{
// N.B.: I need both of these lines.
Base::Validate(isFinalValidationPass);
InferMBLayoutFromInputsForStandardCase(isFinalValidationPass);
// get tensor shapes
auto dimsA = Input(1)->GetSampleLayout().GetDims(); // data
auto dimsB = Input(0)->GetSampleLayout().GetDims(); // parameters
let& shapeA = Input(0)->GetSampleLayout(); // parameters
let& shapeB = Input(1)->GetSampleLayout(); // data
auto dimsA = shapeA.GetDims();
auto dimsB = shapeB.GetDims();
// data rank must match spatial/temporal recurrence mode
if (isFinalValidationPass &&
dimsA.size() != (m_rnnAttributes.IsWindowedRecurrence() ? 2 : 1))
dimsB.size() != (m_rnnAttributes.IsSpatialRecurrence() ? 2 : 1))
{
InvalidArgument("%ls: Input must have rank 1 for axis=-1 and rank 2 for axis=2.", NodeDescription().c_str());
InvalidArgument("%ls: Input [%s] must have rank 1 for axis=-1 and rank 2 for axis=2.", NodeDescription().c_str(), string(shapeB).c_str());
}
// validate and infer
@@ -229,6 +232,14 @@ void OptimizedRNNStack<ElemType>::Validate(bool isFinalValidationPass)
// output dims
dimsC[0] = (m_rnnAttributes.m_bidirectional ? 2 : 1) * m_rnnAttributes.m_hiddenSize;
// infer input size
// Note: Output dim is second axis, so say initOutputRank=-1 in the Parameters{} definition.
if (dimsA.size() == 2)
{
let numParameters = m_rnnAttributes.GetNumParameters(shapeB.GetNumElements());
Input(0)->ValidateInferInputDimsFrom(TensorShape(numParameters.first, numParameters.second));
}
// N.B. - this is the magical call, the reason for the function
// dimensions would be outputRank * numSamples * minibatch * time.
// This call establishes outputRank * numSamples, the rest will be filled in
@@ -238,7 +249,7 @@ void OptimizedRNNStack<ElemType>::Validate(bool isFinalValidationPass)
};
template<class ElemType>
void OptimizedRNNStack<ElemType>::PackSequencesForCuDNN(const Matrix<ElemType>& src, Matrix<ElemType>& dst, vector<size_t>& numSequencesForFrame)
void OptimizedRNNStackNode<ElemType>::PackSequencesForCuDNN(const Matrix<ElemType>& src, Matrix<ElemType>& dst, vector<size_t>& numSequencesForFrame)
{
MBLayoutPtr mb = this->GetMBLayout();
if (mb->HasSequenceBeyondBegin())
@@ -307,7 +318,7 @@ void OptimizedRNNStack<ElemType>::PackSequencesForCuDNN(const Matrix<ElemType>&
dst.DoGatherColumnsOf(0.0, *(this->m_packingIndex), src, 1.0);
}
template<class ElemType>
void OptimizedRNNStack<ElemType>::UnpackSequencesFromCuDNN(const Matrix<ElemType>& src, Matrix<ElemType>& dst)
void OptimizedRNNStackNode<ElemType>::UnpackSequencesFromCuDNN(const Matrix<ElemType>& src, Matrix<ElemType>& dst)
{
// this->scatter(beta,ndx,a,alpha) operation is defined as
// *this[:,idx[j]] = a[:,j] * alpha + *this[:,idx[j]] * beta
@@ -315,7 +326,7 @@ void OptimizedRNNStack<ElemType>::UnpackSequencesFromCuDNN(const Matrix<ElemType
}
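The packing and unpacking above are just column gathers/scatters driven by m_packingIndex. As a hedged illustration (hypothetical helper names, plain column-major host buffers rather than the CNTK Matrix API), the two operations referenced in these comments amount to:
#include <cstddef>
#include <vector>
// gather : dst[:, j]      = beta * dst[:, j]      + alpha * src[:, idx[j]]
void GatherColumns(float beta, const std::vector<size_t>& idx, const std::vector<float>& src,
                   float alpha, std::vector<float>& dst, size_t numRows)
{
    for (size_t j = 0; j < idx.size(); j++)
        for (size_t i = 0; i < numRows; i++)
            dst[j * numRows + i] = beta * dst[j * numRows + i] + alpha * src[idx[j] * numRows + i];
}
// scatter: dst[:, idx[j]] = beta * dst[:, idx[j]] + alpha * src[:, j]
void ScatterColumns(float beta, const std::vector<size_t>& idx, const std::vector<float>& src,
                    float alpha, std::vector<float>& dst, size_t numRows)
{
    for (size_t j = 0; j < idx.size(); j++)
        for (size_t i = 0; i < numRows; i++)
            dst[idx[j] * numRows + i] = beta * dst[idx[j] * numRows + i] + alpha * src[j * numRows + i];
}
With beta = 0 and alpha = 1, as in the DoGatherColumnsOf call above, the gather simply copies column idx[j] of the source into column j of the destination.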
template class OptimizedRNNStack<float>;
template class OptimizedRNNStack<double>;
template class OptimizedRNNStackNode<float>;
template class OptimizedRNNStackNode<double>;
}}}

View file

@@ -24,19 +24,19 @@
namespace Microsoft { namespace MSR { namespace CNTK {
// -----------------------------------------------------------------------
// OptimizedRNNStack (data, weights)
// OptimizedRNNStack (weights, data)
// -----------------------------------------------------------------------
template <class ElemType>
class OptimizedRNNStack : public ComputationNode<ElemType>, public NumInputs<2>
class OptimizedRNNStackNode : public ComputationNode<ElemType>, public NumInputs<2>
{
typedef ComputationNode<ElemType> Base; UsingComputationNodeMembersBoilerplate;
static const std::wstring TypeName() { return L"OptimizedRNN"; }
static const std::wstring TypeName() { return L"OptimizedRNNStack"; }
using Base::OperationName;
public:
OptimizedRNNStack(DEVICEID_TYPE deviceId, const wstring& name);
OptimizedRNNStack(const ScriptableObjects::IConfigRecordPtr configp);
OptimizedRNNStackNode(DEVICEID_TYPE deviceId, const wstring& name);
OptimizedRNNStackNode(const ScriptableObjects::IConfigRecordPtr configp);
virtual void CopyTo(ComputationNodeBasePtr nodeP, const std::wstring& newName, const CopyNodeFlags flags) const override;
virtual void Save(File& fstream) const;

View file

@@ -16,9 +16,9 @@ namespace Microsoft { namespace MSR { namespace CNTK {
class CuDnnDropout
{
CuDnn::ptr_t m_cudnn;
unsigned long long m_seed = 0xdeadbeefull;
unsigned long long m_seed = 1;
public:
CuDnnDropout(float dropout = 0.0f, unsigned long long seed = 0xdeadbeefull)
CuDnnDropout(float dropout = 0.0f, unsigned long long seed = 1)
: m_dropoutDesc(nullptr), m_cudnn(CuDnn::Instance())
{
CUDNN_CALL(cudnnCreateDropoutDescriptor(&m_dropoutDesc));
@@ -66,15 +66,11 @@ private:
cudnnRNNMode_t GetMode()
{
if (m_rnnAttributes.m_rnnMode == wstring(L"LSTM"))
return cudnnRNNMode_t::CUDNN_LSTM;
if (m_rnnAttributes.m_rnnMode == wstring(L"GRU"))
return cudnnRNNMode_t::CUDNN_GRU;
if (m_rnnAttributes.m_rnnMode == wstring(L"RNN_RELU"))
return cudnnRNNMode_t::CUDNN_RNN_RELU;
if (m_rnnAttributes.m_rnnMode == wstring(L"RNN_TANH"))
return cudnnRNNMode_t::CUDNN_RNN_TANH;
InvalidArgument("RNN Mode set to %ls, but supported values are LSTM, GRU, RNN_RELU, RNN_TANH.", m_rnnAttributes.m_rnnMode.c_str());
if (m_rnnAttributes.m_recurrentOp == wstring(L"lstm")) return cudnnRNNMode_t::CUDNN_LSTM;
else if (m_rnnAttributes.m_recurrentOp == wstring(L"gru")) return cudnnRNNMode_t::CUDNN_GRU;
else if (m_rnnAttributes.m_recurrentOp == wstring(L"rnnReLU")) return cudnnRNNMode_t::CUDNN_RNN_RELU;
else if (m_rnnAttributes.m_recurrentOp == wstring(L"rnnTanh")) return cudnnRNNMode_t::CUDNN_RNN_TANH;
else InvalidArgument("Unknown cell type. Supported values are 'lstm', 'gru', 'rnnReLU', 'rnnTanh'.", m_rnnAttributes.m_recurrentOp.c_str());
}
public:

View file

@@ -16,24 +16,52 @@ struct RnnAttributes
bool m_bidirectional;
size_t m_numLayers;
size_t m_hiddenSize;
wstring m_rnnMode;
wstring m_recurrentOp;
int m_axis;
bool IsWindowedRecurrence() const { return m_axis >= 0; }
bool IsSpatialRecurrence() const { return m_axis >= 0; }
RnnAttributes(bool bidirectional, size_t numLayers, size_t hiddenSize, const wstring& rnnMode, int axis) :
m_bidirectional(bidirectional), m_numLayers(numLayers), m_hiddenSize(hiddenSize), m_rnnMode(rnnMode), m_axis(axis)
RnnAttributes(bool bidirectional, size_t numLayers, size_t hiddenSize, const wstring& recurrentOp, int axis) :
m_bidirectional(bidirectional), m_numLayers(numLayers), m_hiddenSize(hiddenSize), m_recurrentOp(recurrentOp), m_axis(axis)
{
if (m_recurrentOp != wstring(L"lstm") && m_recurrentOp != wstring(L"gru") &&
m_recurrentOp != wstring(L"rnnReLU") && m_recurrentOp != wstring(L"rnnTanh"))
{
InvalidArgument("Unknown cell type '%ls'. Supported values are 'lstm', 'gru', 'rnnReLU', 'rnnTanh'.", m_recurrentOp.c_str());
}
if (m_axis != -1 && m_axis != 2)
InvalidArgument("OptimizedRNNStack: invalid 'axis' parameter %d, currently supported values are -1 and 2.", m_axis);
}
// compute the total number of parameters, for inference of weight matrix size
pair<size_t,size_t> GetNumParameters(size_t inputDim) const
{
const size_t bidirFactor = m_bidirectional ? 2 : 1;
const size_t numNetworks =
(m_recurrentOp == L"lstm" ) ? 4 :
(m_recurrentOp == L"gru" ) ? 3 :
/*else*/ 1;
size_t total = 0;
for (size_t i = 0; i < m_numLayers; i++)
{
size_t oneNetTotal =
numNetworks * m_hiddenSize // 1, 3, or 4 networks producing hidden-dim output
* (inputDim + m_hiddenSize) // each network has these two inputs
+ numNetworks * m_hiddenSize // biases
* 2; // for unknown reasons, cudnn5 uses 2 bias terms everywhere
total += oneNetTotal * bidirFactor; // 1 or 2 directions
inputDim = bidirFactor * m_hiddenSize; // next layer continues with this as input
}
return make_pair(m_hiddenSize, total / m_hiddenSize);
}
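As a worked example of this inference (hypothetical dimensions, not from the commit): a single unidirectional lstm layer with inputDim = 40 and hiddenSize = 512 has 4 gate networks of 512 x (40 + 512) weights plus 2 x 4 x 512 biases, i.e. 4*512*552 + 4096 = 1134592 parameters, so Validate() infers the weight Parameter as a [512 x 2216] tensor. A minimal standalone sketch of the same arithmetic:
#include <cstdio>
#include <utility>
// Mirror of GetNumParameters above for one unidirectional lstm layer
// (illustration only; the function name is hypothetical).
std::pair<size_t, size_t> LstmWeightShape(size_t inputDim, size_t hiddenSize)
{
    const size_t numNetworks = 4;                                      // lstm has 4 gate networks
    size_t total = numNetworks * hiddenSize * (inputDim + hiddenSize)  // weights
                 + numNetworks * hiddenSize * 2;                       // two bias vectors per network (cudnn5 style)
    return std::make_pair(hiddenSize, total / hiddenSize);
}
int main()
{
    auto shape = LstmWeightShape(/*inputDim=*/40, /*hiddenSize=*/512);
    std::printf("[%zu x %zu]\n", shape.first, shape.second);           // prints [512 x 2216]
    return 0;
}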
bool operator==(const RnnAttributes& other) const
{
return
m_bidirectional == other.m_bidirectional &&
m_numLayers == other.m_numLayers &&
m_hiddenSize == other.m_hiddenSize &&
m_rnnMode == other.m_rnnMode &&
m_recurrentOp == other.m_recurrentOp &&
m_axis == other.m_axis;
}
@@ -43,11 +71,17 @@ struct RnnAttributes
stream >> bidirectional; m_bidirectional = !!bidirectional;
stream >> m_numLayers;
stream >> m_hiddenSize;
stream >> m_rnnMode;
stream >> m_recurrentOp;
if (readAxis)
stream >> m_axis; // note: back compat for windowed models deliberately dropped
else
m_axis = -1;
stream >> m_axis;
else // legacy
{
m_axis = -1; // note: back compat for windowed models deliberately dropped
if (m_recurrentOp == wstring(L"LSTM")) m_recurrentOp = L"lstm"; // map names
else if (m_recurrentOp == wstring(L"GRU")) m_recurrentOp = L"gru";
else if (m_recurrentOp == wstring(L"RNN_RELU")) m_recurrentOp = L"rnnReLU";
else if (m_recurrentOp == wstring(L"RNN_TANH")) m_recurrentOp = L"rnnTanh";
}
}
void Write(File& stream) const
@@ -56,7 +90,7 @@ struct RnnAttributes
stream << bidirectional;
stream << m_numLayers;
stream << m_hiddenSize;
stream << m_rnnMode;
stream << m_recurrentOp;
stream << m_axis;
}

View file

@@ -34,11 +34,11 @@ speechTrain = {
# cudnn5 library
# Note: does not run in truncated mode
W = ParameterTensor {14704-8*(40-33):hiddenDim, init='heNormal', initValueScale=1/10} # -> change to 0:hiddenDim, outputRank=-1
W = ParameterTensor {hiddenDim:14704-8*(40-33), initOutputRank=-1, init='heNormal', initValueScale=1/10} # -> change to 0:hiddenDim, outputRank=-1
modelUsingCuDNN5 = Sequential
(
MeanVarNorm :
(_ => OptimizedRNNStack(W, _, hiddenDim, numLayers=numLSTMLayers, bidirectional=true, rnnMode='LSTM')) :
(_ => OptimizedRNNStack(W, _, hiddenDim, numLayers=numLSTMLayers, bidirectional=true)) :
DenseLayer {labelDim, init='heUniform', initValueScale=1/3}
)
@@ -144,7 +144,7 @@ speechTrain = {
// features
features = Input((1 : featDim), tag='feature') // TEST: Artificially reading data transposed
realFeatures = Transpose (features) // and swapping them back to (featDim:1), for testing Transpose()
realFeatures = FlattenDimensions (Transpose (features), 1, 2) // and swapping them back to (featDim:1), for testing Transpose()
feashift = RowSlice(featDim - baseFeatDim, baseFeatDim, realFeatures); # interface with a reader set up for frame mode
labels = Input(labelDim, tag='label')

View file

@@ -36,7 +36,7 @@ TrainTagger = {
# loss and metric
ce = CrossEntropyWithSoftmax (slotLabels, z)
errs = ClassificationError (slotLabels, z)
errs = ClassificationError (slotLabels, z)
featureNodes = (query)
labelNodes = (slotLabels)

View file

@@ -39,7 +39,7 @@ TrainTagger = {
# loss and metric
ce = CrossEntropyWithSoftmax (slotLabels, z)
errs = ClassificationError (slotLabels, z)
errs = ClassificationError (slotLabels, z)
featureNodes = (query)
labelNodes = (slotLabels)

View file

@@ -42,7 +42,7 @@ TrainTagger = {
# loss and metric
ce = CrossEntropyWithSoftmax (slotLabels, z)
errs = ClassificationError (slotLabels, z)
errs = ClassificationError (slotLabels, z)
featureNodes = (query)
labelNodes = (slotLabels)

View file

@@ -43,7 +43,7 @@ TrainTagger = {
# loss and metric
ce = CrossEntropyWithSoftmax (intentLabels, z)
errs = ClassificationError (intentLabels, z)
errs = ClassificationError (intentLabels, z)
featureNodes = (query)
labelNodes = (intentLabels)