OptimizedRNNStackNode: renamed some variables, renamed recurrent ops to camelCase, added weight inference
Parent: 8a86da8f02
Commit: d9c7e82031
@@ -193,7 +193,7 @@ BrainScriptNetworkBuilder = (new ComputationNetwork [
     # Note: We reverse our input by running the recurrence from right to left.
-    encoderFunction = if useBidirectionalEncoder then BS.RNNs.RecurrentBirectionalLSTMPStack else BS.RNNs.RecurrentLSTMPStack
+    encoderFunction = if useBidirectionalEncoder then BS.RNNs.RecurrentBidirectionalLSTMPStack else BS.RNNs.RecurrentLSTMPStack
     encoder = encoderFunction (encoderDims, cellDims=encoderDims, S(inputEmbedded), inputDim=inputEmbeddingDim,
                                previousHook=if useBidirectionalEncoder then BS.RNNs.PreviousHC else BS.RNNs.NextHC,
                                enableSelfStabilization=useStabilizer)
@@ -501,9 +501,10 @@ PerDimMeanVarDeNormalization(dataVectorSequence, meanVector, invStdDevVector, ta
 PerDimMeanVarNormalization (x, mean, invStdDev) = (x - mean) .* invStdDev
 Reciprocal(z, tag='') = new ComputationNode [ operation = 'Reciprocal' ; inputs = z /*plus the function args*/ ]
 //# the following is a temporary workaround until we have the C++ version
-OptimizedRNNStack(weights, input, hiddenDims, numLayers=1, bidirectional=false, rnnMode='LSTM', tag='') = new ComputationNode [ operation = 'OptimizedRNNStack' ; recurrentOp = rnnMode; inputs = ( input : weights ) /*plus the function args*/ ]
+# TODO: change hiddenDims to hiddenShape and pass as a TensorShape (currently, the node only supports rank-1 data)
+OptimizedRNNStack(weights, input, hiddenDims, numLayers=1, bidirectional=false, recurrentOp='lstm', axis=-1, tag='') = new ComputationNode [ operation = 'OptimizedRNNStack' ; inputs = ( weights : input ) /*plus the function args*/ ]
 # legacy:
-RNNStack(x, W, hiddenSize=10, numLayers=1, bidirectional=false, rnnMode='LSTM', tag='') = OptimizedRNNStack(W, X, hiddenSize, numLayers=1, bidirectional=false, rnnMode='LSTM', tag='')
+RNNStack(x, W, hiddenSize=10, numLayers=1, bidirectional=false, rnnMode='lstm', tag='') = OptimizedRNNStack(W, x, hiddenSize, numLayers=1, bidirectional=false, recurrentOp=rnnMode, tag='')
 Scale(scalarScalingFactor, matrix, tag='') = new ComputationNode [ operation = 'Scale' ; inputs = (scalarScalingFactor : matrix) /*plus the function args*/ ]
 # TODO: Scale = ElementTimes
 ScatterPacked(cond, indexSequence, sourceData, tag='') = new ComputationNode [ operation = 'ScatterPacked' ; inputs = (cond : indexSequence : sourceData) /*plus the function args*/ ]
@@ -1045,7 +1046,7 @@ RNNs =
     # a stack of recurrent LSTMs (bidirectional)
     # TODO: Should we define layerDims as the total (sum of both forward and backward direction)?
-    RecurrentBirectionalLSTMPStack (layerDims, cellDims=layerDims, input, inputDim=input.dim, previousHook=PreviousHC, nextHook=NextHC, enableSelfStabilization=false) = [
+    RecurrentBidirectionalLSTMPStack (layerDims, cellDims=layerDims, input, inputDim=input.dim, previousHook=PreviousHC, nextHook=NextHC, enableSelfStabilization=false) = [
         previousHook1 = previousHook ; nextHook1 = nextHook ; useStabilizer = enableSelfStabilization
         layers[i:0..Length (layerDims)-1] =
         [
@@ -1159,7 +1160,7 @@ RNNs =
     # a stack of recurrent GRUs (bidirectional)
     # TODO: Should we define layerDims as the total (sum of both forward and backward direction)?
-    RecurrentBirectionalGRUStack (layerDims, cellDims=layerDims, input, inputDim=input.dim, previousHook=PreviousHC, nextHook=NextHC, enableSelfStabilization=false) = [
+    RecurrentBidirectionalGRUStack (layerDims, cellDims=layerDims, input, inputDim=input.dim, previousHook=PreviousHC, nextHook=NextHC, enableSelfStabilization=false) = [
         previousHook1 = previousHook ; nextHook1 = nextHook ; useStabilizer = enableSelfStabilization
         layers[i:0..Length (layerDims)-1] =
         [
@@ -82,7 +82,7 @@ static shared_ptr<ComputationNode<ElemType>> CreateStandardNode(const std::wstri
     else if (nodeType == OperationNameOf(NegateNode)) return New<NegateNode<ElemType>>(forward<_Types>(_Args)...);
     else if (nodeType == OperationNameOf(NotEqualNode)) return New<NotEqualNode<ElemType>>(forward<_Types>(_Args)...);
     else if (nodeType == OperationNameOf(NoiseContrastiveEstimationNode)) return New<NoiseContrastiveEstimationNode<ElemType>>(forward<_Types>(_Args)...);
-    else if (nodeType == OperationNameOf(OptimizedRNNStack)) return New<OptimizedRNNStack<ElemType>>(forward<_Types>(_Args)...);
+    else if (nodeType == OperationNameOf(OptimizedRNNStackNode)) return New<OptimizedRNNStackNode<ElemType>>(forward<_Types>(_Args)...);
     else if (nodeType == OperationNameOf(PackedIndexNode)) return New<PackedIndexNode<ElemType>>(forward<_Types>(_Args)...);
     else if (nodeType == OperationNameOf(PastValueNode)) return New<PastValueNode<ElemType>>(forward<_Types>(_Args)...);
     else if (nodeType == OperationNameOf(PerDimMeanVarNormalizationNode)) return New<PerDimMeanVarNormalizationNode<ElemType>>(forward<_Types>(_Args)...);
@@ -127,7 +127,7 @@ static shared_ptr<ComputationNode<ElemType>> CreateStandardNode(const std::wstri
     else if (nodeType == L"PerDimMeanVarNormalizationNode") return New<PerDimMeanVarNormalizationNode<ElemType>>(forward<_Types>(_Args)...);
     else if (nodeType == L"PerDimMeanVarDeNormalizationNode") return New<PerDimMeanVarDeNormalizationNode<ElemType>>(forward<_Types>(_Args)...);
     else if (nodeType == L"ReconcileMBLayout") return New<ReconcileDynamicAxisNode<ElemType>>(forward<_Types>(_Args)...);
-    else if (nodeType == L"RNN") return New<OptimizedRNNStack<ElemType>>(forward<_Types>(_Args)...);
+    else if (nodeType == L"RNN") return New<OptimizedRNNStackNode<ElemType>>(forward<_Types>(_Args)...);
     else if (nodeType == L"RowElementTimes") return New<ElementTimesNode<ElemType>>(forward<_Types>(_Args)...);
     else if (nodeType == L"RowSlice") return New<SliceNode<ElemType>>(forward<_Types>(_Args)...);
     else if (nodeType == L"Scale") return New<ElementTimesNode<ElemType>>(forward<_Types>(_Args)...);
@@ -25,20 +25,20 @@ namespace Microsoft { namespace MSR { namespace CNTK {
     vector<size_t> numSequencesForFrame;

 // -----------------------------------------------------------------------
-// OptimizedRNNStack
+// OptimizedRNNStackNode
 // -----------------------------------------------------------------------

 template<class ElemType>
-OptimizedRNNStack<ElemType>::OptimizedRNNStack(DEVICEID_TYPE deviceId, const wstring& name)
+OptimizedRNNStackNode<ElemType>::OptimizedRNNStackNode(DEVICEID_TYPE deviceId, const wstring& name)
     : Base(deviceId, name),
-    m_rnnAttributes(0, 0, 0, L"LSTM", -1),
+    m_rnnAttributes(0, 0, 0, L"lstm", -1),
     m_BackwardDataCalledYet(false)
 {
 }

 // This constructor helps with BrainScript integration
 template<class ElemType>
-OptimizedRNNStack<ElemType>::OptimizedRNNStack(const ScriptableObjects::IConfigRecordPtr configp)
+OptimizedRNNStackNode<ElemType>::OptimizedRNNStackNode(const ScriptableObjects::IConfigRecordPtr configp)
     : Base(configp->Get(L"deviceId"), L"<placeholder>"),
     m_rnnAttributes(configp->Get(L"bidirectional"), configp->Get(L"numLayers"), configp->Get(L"hiddenDims"), configp->Get(L"recurrentOp"), configp->Get(L"axis")),
     m_BackwardDataCalledYet(false)
@@ -47,32 +47,32 @@ OptimizedRNNStack<ElemType>::OptimizedRNNStack(const ScriptableObjects::IConfigR
 }

 template<class ElemType>
-/*virtual*/ void OptimizedRNNStack<ElemType>::CopyTo(ComputationNodeBasePtr nodeP, const std::wstring& newName, const CopyNodeFlags flags) const /*override*/
+/*virtual*/ void OptimizedRNNStackNode<ElemType>::CopyTo(ComputationNodeBasePtr nodeP, const std::wstring& newName, const CopyNodeFlags flags) const /*override*/
 {
     Base::CopyTo(nodeP, newName, flags);
     if (flags & CopyNodeFlags::copyNodeValue)
     {
-        auto node = dynamic_pointer_cast<OptimizedRNNStack<ElemType>>(nodeP);
+        auto node = dynamic_pointer_cast<OptimizedRNNStackNode<ElemType>>(nodeP);
         node->m_rnnAttributes = m_rnnAttributes;
     }
 }

 template<class ElemType>
-void OptimizedRNNStack<ElemType>::Save(File& fstream) const
+void OptimizedRNNStackNode<ElemType>::Save(File& fstream) const
 {
     Base::Save(fstream);
     m_rnnAttributes.Write(fstream);
 }

 template<class ElemType>
-void OptimizedRNNStack<ElemType>::Load(File& fstream, size_t modelVersion)
+void OptimizedRNNStackNode<ElemType>::Load(File& fstream, size_t modelVersion)
 {
     Base::Load(fstream, modelVersion);
     m_rnnAttributes.Read(fstream, /*readAxis=*/ modelVersion >= CNTK_MODEL_VERSION_14);
 }

 template<class ElemType>
-void OptimizedRNNStack<ElemType>::TransposeHelper(const MatrixBasePtr matX, const TensorShape &shapeX, MatrixBasePtr matY, TensorShape &shapeY)
+void OptimizedRNNStackNode<ElemType>::TransposeHelper(const MatrixBasePtr matX, const TensorShape &shapeX, MatrixBasePtr matY, TensorShape &shapeY)
 {
     // This function transposes the second and third axes of the input (X), creating a transposed copy in the output (Y).
     //
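Aside (not part of the diff): the comment above describes what TransposeHelper does, namely swapping the second and third axes of a column-major block before handing it to cuDNN. A minimal standalone sketch of that axis swap on a plain buffer, using hypothetical names and std::vector instead of the Matrix/TensorShape types:

```cpp
#include <cstddef>
#include <vector>

// Illustrative only: given x laid out column-major with shape (d0, d1, d2),
// produce y with shape (d0, d2, d1), i.e. the second and third axes swapped.
std::vector<float> TransposeAxes23(const std::vector<float>& x, size_t d0, size_t d1, size_t d2)
{
    std::vector<float> y(x.size());
    for (size_t k = 0; k < d2; k++)
        for (size_t j = 0; j < d1; j++)
            for (size_t i = 0; i < d0; i++)
                // element (i, j, k) of x becomes element (i, k, j) of y
                y[i + k * d0 + j * d0 * d2] = x[i + j * d0 + k * d0 * d1];
    return y;
}
```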
@@ -89,19 +89,19 @@ void OptimizedRNNStack<ElemType>::TransposeHelper(const MatrixBasePtr matX, cons
 };

 template<class ElemType>
-void OptimizedRNNStack<ElemType>::ForwardProp(const FrameRange& fr)
+void OptimizedRNNStackNode<ElemType>::ForwardProp(const FrameRange& fr)
 {
     // ComputationNode derived classes are guaranteed to have a MBLayout
     if (!HasMBLayout())
     {
-        LogicError("OptimizedRNNStack must operate on minibatches");
+        LogicError("OptimizedRNNStackNode must operate on minibatches");
     }

     // The parameters are stored in a column matrix
     Matrix<ElemType>& paramW = Input(1)->Value();

     MBLayoutPtr mb = GetMBLayout();
-    if (m_rnnAttributes.IsWindowedRecurrence())
+    if (m_rnnAttributes.IsSpatialRecurrence())
     {
         TensorView<ElemType> outputY = ValueTensorFor(SIZE_MAX, fr);
@@ -137,7 +137,7 @@ void OptimizedRNNStack<ElemType>::ForwardProp(const FrameRange& fr)
     else
     {
         if (mb->GetNumTimeSteps() == 1)
-            RuntimeError("OptimizedRNNStack configured for sequence mode, but minibatch only has one time step.");
+            RuntimeError("OptimizedRNNStackNode configured for sequence mode, but minibatch only has one time step.");

         shapeXT = TensorShape(Input(0)->GetTensorSliceFor(SIZE_MAX, fr));
         shapeYT = TensorShape(this->GetTensorSliceFor(SIZE_MAX, fr));
@@ -155,7 +155,7 @@ void OptimizedRNNStack<ElemType>::ForwardProp(const FrameRange& fr)
 }

 template<class ElemType>
-void OptimizedRNNStack<ElemType>::BackpropTo(const size_t inputIndex, const FrameRange& fr)
+void OptimizedRNNStackNode<ElemType>::BackpropTo(const size_t inputIndex, const FrameRange& fr)
 {
     MBLayoutPtr mb = this->GetMBLayout();
@@ -164,7 +164,7 @@ void OptimizedRNNStack<ElemType>::BackpropTo(const size_t inputIndex, const Fram
     {
         Matrix<ElemType>& paramW = Input(1)->Value();

-        if (m_rnnAttributes.IsWindowedRecurrence())
+        if (m_rnnAttributes.IsSpatialRecurrence())
         {
             // To obey the data layout constraints of CuDnn, we take the derivative we're given,
             // and transpose it before feeding to the interface.
@@ -191,7 +191,7 @@ void OptimizedRNNStack<ElemType>::BackpropTo(const size_t inputIndex, const Fram
     else if (inputIndex == 0) // data
     {
         // all of the work was done above, where RNNBackwardData is called. Now, just unpack the result.
-        if (m_rnnAttributes.IsWindowedRecurrence())
+        if (m_rnnAttributes.IsSpatialRecurrence())
         {
             TensorShape tmp;
             TransposeHelper(m_transposedDInput, shapeXT, Input(0)->GradientPtr(), tmp);
@@ -204,20 +204,23 @@ void OptimizedRNNStack<ElemType>::BackpropTo(const size_t inputIndex, const Fram
 }

 template<class ElemType>
-void OptimizedRNNStack<ElemType>::Validate(bool isFinalValidationPass)
+void OptimizedRNNStackNode<ElemType>::Validate(bool isFinalValidationPass)
 {
     // N.B.: I need both of these lines.
     Base::Validate(isFinalValidationPass);
     InferMBLayoutFromInputsForStandardCase(isFinalValidationPass);

     // get tensor shapes
-    auto dimsA = Input(1)->GetSampleLayout().GetDims(); // data
-    auto dimsB = Input(0)->GetSampleLayout().GetDims(); // parameters
+    let& shapeA = Input(0)->GetSampleLayout(); // parameters
+    let& shapeB = Input(1)->GetSampleLayout(); // data
+    auto dimsA = shapeA.GetDims();
+    auto dimsB = shapeB.GetDims();

+    // data rank must match spatial/temporal recurrence mode
     if (isFinalValidationPass &&
-        dimsA.size() != (m_rnnAttributes.IsWindowedRecurrence() ? 2 : 1))
+        dimsB.size() != (m_rnnAttributes.IsSpatialRecurrence() ? 2 : 1))
     {
-        InvalidArgument("%ls: Input must have rank 1 for axis=-1 and rank 2 for axis=2.", NodeDescription().c_str());
+        InvalidArgument("%ls: Input [%s] must have rank 1 for axis=-1 and rank 2 for axis=2.", NodeDescription().c_str(), string(shapeB).c_str());
     }

     // validate and infer
@@ -229,6 +232,14 @@ void OptimizedRNNStack<ElemType>::Validate(bool isFinalValidationPass)
     // output dims
     dimsC[0] = (m_rnnAttributes.m_bidirectional ? 2 : 1) * m_rnnAttributes.m_hiddenSize;

+    // infer input size
+    // Note: Output dim is second axis, so say initOutputRank=-1 in the Parameters{} definition.
+    if (dimsA.size() == 2)
+    {
+        let numParameters = m_rnnAttributes.GetNumParameters(shapeB.GetNumElements());
+        Input(0)->ValidateInferInputDimsFrom(TensorShape(numParameters.first, numParameters.second));
+    }
+
     // N.B. - this is the magical call, the reason for the function
     // dimensions would be outputRank * numSamples * minibatch * time.
     // This call establishes outputRank * numSamples, the rest will be filled in
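Aside (not part of the diff): the block added above is the commit's "weight inference" — when the parameter input was declared rank-2 with a free dimension, its shape is filled in from GetNumParameters (defined further down, in RnnAttributes). A minimal standalone sketch of that fill-in step, with hypothetical names and a plain struct in place of TensorShape/ValidateInferInputDimsFrom:

```cpp
#include <cstddef>
#include <utility>

// Hypothetical stand-in for a rank-2 parameter shape; 0 marks a dimension
// that was left to be inferred (cf. initOutputRank=-1 on the BrainScript side).
struct ParamShape { size_t rows = 0, cols = 0; };

// Fill in any free dimension from the (rows, cols) pair that
// RnnAttributes::GetNumParameters computes from the data dimension.
void InferParamShape(ParamShape& declared, std::pair<size_t, size_t> required)
{
    if (declared.rows == 0) declared.rows = required.first;
    if (declared.cols == 0) declared.cols = required.second;
}
```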
@@ -238,7 +249,7 @@ void OptimizedRNNStack<ElemType>::Validate(bool isFinalValidationPass)
 };

 template<class ElemType>
-void OptimizedRNNStack<ElemType>::PackSequencesForCuDNN(const Matrix<ElemType>& src, Matrix<ElemType>& dst, vector<size_t>& numSequencesForFrame)
+void OptimizedRNNStackNode<ElemType>::PackSequencesForCuDNN(const Matrix<ElemType>& src, Matrix<ElemType>& dst, vector<size_t>& numSequencesForFrame)
 {
     MBLayoutPtr mb = this->GetMBLayout();
     if (mb->HasSequenceBeyondBegin())
@@ -307,7 +318,7 @@ void OptimizedRNNStack<ElemType>::PackSequencesForCuDNN(const Matrix<ElemType>&
     dst.DoGatherColumnsOf(0.0, *(this->m_packingIndex), src, 1.0);
 }
 template<class ElemType>
-void OptimizedRNNStack<ElemType>::UnpackSequencesFromCuDNN(const Matrix<ElemType>& src, Matrix<ElemType>& dst)
+void OptimizedRNNStackNode<ElemType>::UnpackSequencesFromCuDNN(const Matrix<ElemType>& src, Matrix<ElemType>& dst)
 {
     // this->scatter(beta,ndx,a,alpha) operation is defined as
     // *this[:,idx[j]] = a[:,j] * alpha + *this[:,idx[j]] * beta
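Aside (not part of the diff): the scatter semantics quoted in the comment above can be spelled out on plain containers. A minimal sketch, with hypothetical names, of dst.scatter(beta, idx, a, alpha) operating column by column:

```cpp
#include <cstddef>
#include <vector>

// Illustrative only: dst[:, idx[j]] = a[:, j] * alpha + dst[:, idx[j]] * beta,
// where each inner vector holds one column.
void ScatterColumns(std::vector<std::vector<float>>& dst,
                    const std::vector<std::vector<float>>& a,
                    const std::vector<size_t>& idx,
                    float alpha, float beta)
{
    for (size_t j = 0; j < a.size(); j++)
        for (size_t r = 0; r < a[j].size(); r++)
            dst[idx[j]][r] = a[j][r] * alpha + dst[idx[j]][r] * beta;
}
```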
@@ -315,7 +326,7 @@ void OptimizedRNNStack<ElemType>::UnpackSequencesFromCuDNN(const Matrix<ElemType
 }

-template class OptimizedRNNStack<float>;
-template class OptimizedRNNStack<double>;
+template class OptimizedRNNStackNode<float>;
+template class OptimizedRNNStackNode<double>;

 }}}
@@ -24,19 +24,19 @@
 namespace Microsoft { namespace MSR { namespace CNTK {

 // -----------------------------------------------------------------------
-// OptimizedRNNStack (data, weights)
+// OptimizedRNNStack (weights, data)
 // -----------------------------------------------------------------------

 template <class ElemType>
-class OptimizedRNNStack : public ComputationNode<ElemType>, public NumInputs<2>
+class OptimizedRNNStackNode : public ComputationNode<ElemType>, public NumInputs<2>
 {
     typedef ComputationNode<ElemType> Base; UsingComputationNodeMembersBoilerplate;
-    static const std::wstring TypeName() { return L"OptimizedRNN"; }
+    static const std::wstring TypeName() { return L"OptimizedRNNStack"; }
     using Base::OperationName;

 public:
-    OptimizedRNNStack(DEVICEID_TYPE deviceId, const wstring& name);
-    OptimizedRNNStack(const ScriptableObjects::IConfigRecordPtr configp);
+    OptimizedRNNStackNode(DEVICEID_TYPE deviceId, const wstring& name);
+    OptimizedRNNStackNode(const ScriptableObjects::IConfigRecordPtr configp);

     virtual void CopyTo(ComputationNodeBasePtr nodeP, const std::wstring& newName, const CopyNodeFlags flags) const override;
     virtual void Save(File& fstream) const;
@@ -16,9 +16,9 @@ namespace Microsoft { namespace MSR { namespace CNTK {
 class CuDnnDropout
 {
     CuDnn::ptr_t m_cudnn;
-    unsigned long long m_seed = 0xdeadbeefull;
+    unsigned long long m_seed = 1;
 public:
-    CuDnnDropout(float dropout = 0.0f, unsigned long long seed = 0xdeadbeefull)
+    CuDnnDropout(float dropout = 0.0f, unsigned long long seed = 1)
         : m_dropoutDesc(nullptr), m_cudnn(CuDnn::Instance())
     {
         CUDNN_CALL(cudnnCreateDropoutDescriptor(&m_dropoutDesc));
@@ -66,15 +66,11 @@ private:

     cudnnRNNMode_t GetMode()
     {
-        if (m_rnnAttributes.m_rnnMode == wstring(L"LSTM"))
-            return cudnnRNNMode_t::CUDNN_LSTM;
-        if (m_rnnAttributes.m_rnnMode == wstring(L"GRU"))
-            return cudnnRNNMode_t::CUDNN_GRU;
-        if (m_rnnAttributes.m_rnnMode == wstring(L"RNN_RELU"))
-            return cudnnRNNMode_t::CUDNN_RNN_RELU;
-        if (m_rnnAttributes.m_rnnMode == wstring(L"RNN_TANH"))
-            return cudnnRNNMode_t::CUDNN_RNN_TANH;
-        InvalidArgument("RNN Mode set to %ls, but supported values are LSTM, GRU, RNN_RELU, RNN_TANH.", m_rnnAttributes.m_rnnMode.c_str());
+        if      (m_rnnAttributes.m_recurrentOp == wstring(L"lstm"))    return cudnnRNNMode_t::CUDNN_LSTM;
+        else if (m_rnnAttributes.m_recurrentOp == wstring(L"gru"))     return cudnnRNNMode_t::CUDNN_GRU;
+        else if (m_rnnAttributes.m_recurrentOp == wstring(L"rnnReLU")) return cudnnRNNMode_t::CUDNN_RNN_RELU;
+        else if (m_rnnAttributes.m_recurrentOp == wstring(L"rnnTanh")) return cudnnRNNMode_t::CUDNN_RNN_TANH;
+        else InvalidArgument("Unknown cell type '%ls'. Supported values are 'lstm', 'gru', 'rnnReLU', 'rnnTanh'.", m_rnnAttributes.m_recurrentOp.c_str());
     }

 public:
@@ -16,24 +16,52 @@ struct RnnAttributes
     bool m_bidirectional;
     size_t m_numLayers;
     size_t m_hiddenSize;
-    wstring m_rnnMode;
+    wstring m_recurrentOp;
     int m_axis;
-    bool IsWindowedRecurrence() const { return m_axis >= 0; }
+    bool IsSpatialRecurrence() const { return m_axis >= 0; }

-    RnnAttributes(bool bidirectional, size_t numLayers, size_t hiddenSize, const wstring& rnnMode, int axis) :
-        m_bidirectional(bidirectional), m_numLayers(numLayers), m_hiddenSize(hiddenSize), m_rnnMode(rnnMode), m_axis(axis)
+    RnnAttributes(bool bidirectional, size_t numLayers, size_t hiddenSize, const wstring& recurrentOp, int axis) :
+        m_bidirectional(bidirectional), m_numLayers(numLayers), m_hiddenSize(hiddenSize), m_recurrentOp(recurrentOp), m_axis(axis)
     {
+        if (m_recurrentOp != wstring(L"lstm") && m_recurrentOp != wstring(L"gru") &&
+            m_recurrentOp != wstring(L"rnnReLU") && m_recurrentOp != wstring(L"rnnTanh"))
+        {
+            InvalidArgument("Unknown cell type '%ls'. Supported values are 'lstm', 'gru', 'rnnReLU', 'rnnTanh'.", m_recurrentOp.c_str());
+        }
+
         if (m_axis != -1 && m_axis != 2)
             InvalidArgument("OptimizedRNNStack: invalid 'axis' parameter %d, currently supported values are -1 and 2.", m_axis);
     }

+    // compute the total number of parameters, for inference of weight matrix size
+    pair<size_t,size_t> GetNumParameters(size_t inputDim) const
+    {
+        const size_t bidirFactor = m_bidirectional ? 2 : 1;
+        const size_t numNetworks =
+            (m_recurrentOp == L"lstm") ? 4 :
+            (m_recurrentOp == L"gru")  ? 3 :
+            /*else*/                     1;
+        size_t total = 0;
+        for (size_t i = 0; i < m_numLayers; i++)
+        {
+            size_t oneNetTotal =
+                numNetworks * m_hiddenSize     // 1, 3, or 4 networks producing hidden-dim output
+                * (inputDim + m_hiddenSize)    // each network has these two inputs
+                + numNetworks * m_hiddenSize   // biases
+                * 2;                           // for unknown reasons, cudnn5 uses 2 bias terms everywhere
+            total += oneNetTotal * bidirFactor;    // 1 or 2 directions
+            inputDim = bidirFactor * m_hiddenSize; // next layer continues with this as input
+        }
+        return make_pair(m_hiddenSize, total / m_hiddenSize);
+    }
+
     bool operator==(const RnnAttributes& other) const
     {
         return
             m_bidirectional == other.m_bidirectional &&
             m_numLayers == other.m_numLayers &&
             m_hiddenSize == other.m_hiddenSize &&
-            m_rnnMode == other.m_rnnMode &&
+            m_recurrentOp == other.m_recurrentOp &&
             m_axis == other.m_axis;
     }
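Aside (not part of the diff): the counting in GetNumParameters above can be checked by hand. For a single unidirectional LSTM layer (numNetworks = 4, bidirFactor = 1), writing I for the input dimension and H for the hidden dimension (placeholder symbols, not names from the code):

```latex
\begin{align*}
\text{oneNetTotal} &= 4H(I + H) + 4H \cdot 2 \\
\text{total}       &= \text{oneNetTotal} \cdot 1 \\
\text{GetNumParameters} &= \left(H,\ \tfrac{\text{total}}{H}\right) = \left(H,\ 4(I + H) + 8\right)
\end{align*}
```

For example, I = 100 and H = 256 give total = 4·256·356 + 2·4·256 = 366,592 parameters, so the weight matrix is inferred as 256 × 1432. A bidirectional or multi-layer stack accumulates the same count per direction and layer, with each subsequent layer's input dimension set to bidirFactor·H.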
@@ -43,11 +71,17 @@ struct RnnAttributes
         stream >> bidirectional; m_bidirectional = !!bidirectional;
         stream >> m_numLayers;
         stream >> m_hiddenSize;
-        stream >> m_rnnMode;
+        stream >> m_recurrentOp;
         if (readAxis)
-            stream >> m_axis; // note: back compat for windowed models deliberately dropped
-        else
-            m_axis = -1;
+            stream >> m_axis;
+        else // legacy
+        {
+            m_axis = -1; // note: back compat for windowed models deliberately dropped
+            if (m_recurrentOp == wstring(L"LSTM")) m_recurrentOp = L"lstm"; // map names
+            else if (m_recurrentOp == wstring(L"GRU")) m_recurrentOp = L"gru";
+            else if (m_recurrentOp == wstring(L"RNN_RELU")) m_recurrentOp = L"rnnReLU";
+            else if (m_recurrentOp == wstring(L"RNN_TANH")) m_recurrentOp = L"rnnTanh";
+        }
     }

     void Write(File& stream) const
@@ -56,7 +90,7 @@ struct RnnAttributes
         stream << bidirectional;
         stream << m_numLayers;
         stream << m_hiddenSize;
-        stream << m_rnnMode;
+        stream << m_recurrentOp;
         stream << m_axis;
     }
@@ -34,11 +34,11 @@ speechTrain = {
     # cudnn5 library
     # Note: does not run in truncated mode
-    W = ParameterTensor {14704-8*(40-33):hiddenDim, init='heNormal', initValueScale=1/10} # -> change to 0:hiddenDim, outputRank=-1
+    W = ParameterTensor {hiddenDim:14704-8*(40-33), initOutputRank=-1, init='heNormal', initValueScale=1/10} # -> change to 0:hiddenDim, outputRank=-1
     modelUsingCuDNN5 = Sequential
     (
         MeanVarNorm :
-        (_ => OptimizedRNNStack(W, _, hiddenDim, numLayers=numLSTMLayers, bidirectional=true, rnnMode='LSTM')) :
+        (_ => OptimizedRNNStack(W, _, hiddenDim, numLayers=numLSTMLayers, bidirectional=true)) :
         DenseLayer {labelDim, init='heUniform', initValueScale=1/3}
     )
@@ -144,7 +144,7 @@ speechTrain = {
     // features
     features = Input((1 : featDim), tag='feature') // TEST: Artificially reading data transposed
-    realFeatures = Transpose (features) // and swapping them back to (featDim:1), for testing Transpose()
+    realFeatures = FlattenDimensions (Transpose (features), 1, 2) // and swapping them back to (featDim:1), for testing Transpose()
     feashift = RowSlice(featDim - baseFeatDim, baseFeatDim, realFeatures); # interface with a reader set up for frame mode
     labels = Input(labelDim, tag='label')
@@ -36,7 +36,7 @@ TrainTagger = {
     # loss and metric
     ce = CrossEntropyWithSoftmax (slotLabels, z)
-    errs = ClassificationError (slotLabels, z)
+    errs = ClassificationError (slotLabels, z)

     featureNodes = (query)
     labelNodes = (slotLabels)
@@ -39,7 +39,7 @@ TrainTagger = {
     # loss and metric
     ce = CrossEntropyWithSoftmax (slotLabels, z)
-    errs = ClassificationError (slotLabels, z)
+    errs = ClassificationError (slotLabels, z)

     featureNodes = (query)
     labelNodes = (slotLabels)
@@ -42,7 +42,7 @@ TrainTagger = {
     # loss and metric
     ce = CrossEntropyWithSoftmax (slotLabels, z)
-    errs = ClassificationError (slotLabels, z)
+    errs = ClassificationError (slotLabels, z)

     featureNodes = (query)
     labelNodes = (slotLabels)
@@ -43,7 +43,7 @@ TrainTagger = {
     # loss and metric
     ce = CrossEntropyWithSoftmax (intentLabels, z)
-    errs = ClassificationError (intentLabels, z)
+    errs = ClassificationError (intentLabels, z)

     featureNodes = (query)
     labelNodes = (intentLabels)