Add resetRNN flag to Evaluate in ExtendedEval intf

This commit is contained in:
Vadim Mazalov 2016-09-20 13:46:17 -07:00
Родитель f946c20f1b
Коммит 7447d55ed0
3 изменённых файлов: 41 добавлений и 33 удалений

Просмотреть файл

@ -49,12 +49,6 @@ public:
// Free resources
//
virtual void Destroy() = 0;
//
// Reset initial state of all Recurrence loops (RNNs) in the model.
// Call this before processing the first sequence, or whenever the memory cells need to be reset to their default values.
//
virtual void ResetState() = 0;
};
// ------------------------------------------------------------------------
@ -102,6 +96,11 @@ public:
// happen during evaluation
//
virtual void Evaluate(std::map<std::wstring, std::vector<ElemType>*>& outputs) = 0;
//
// Reset initial state of all Recurrence loops (RNNs) in the model.
//
virtual void ResetState() = 0;
};
@ -339,11 +338,23 @@ public:
//
virtual void ForwardPass(const Values<ElemType>& inputs, Values<ElemType>& output) = 0;
//
// Same as above, with one additional parameter:
// resetRNN - if true, the memory cells of the RNN are reset.
//
virtual void ForwardPass(const Values<ElemType>& inputs, Values<ElemType>& output, bool resetRNN) = 0;
//
// Same as above, but takes references to static arrays instead of std::vector
// (e.g. when the vectors are managed by .NET)
//
virtual void ForwardPass(const ValueRefs<ElemType>& inputs, ValueRefs<ElemType>& output) = 0;
//
// Same as above, with one additional parameter:
// resetRNN - if true, the memory cells of the RNN are reset.
//
virtual void ForwardPass(const ValueRefs<ElemType>& inputs, ValueRefs<ElemType>& output, bool resetRNN) = 0;
};
template <typename ElemType>

Просмотреть файл

@ -310,7 +310,7 @@ VariableSchema CNTKEvalExtended<ElemType>::GetInputSchema() const
template<typename ElemType>
template<template<typename> class ValueContainer>
void CNTKEvalExtended<ElemType>::ForwardPassT(const std::vector<ValueBuffer<ElemType, ValueContainer> >& inputs, std::vector<ValueBuffer<ElemType, ValueContainer> >& outputs)
void CNTKEvalExtended<ElemType>::ForwardPassT(const std::vector<ValueBuffer<ElemType, ValueContainer> >& inputs, std::vector<ValueBuffer<ElemType, ValueContainer> >& outputs, bool resetRNN)
{
if (!m_started)
RuntimeError("ForwardPass() called before StartForwardEvaluation()");
@ -360,10 +360,9 @@ void CNTKEvalExtended<ElemType>::ForwardPassT(const std::vector<ValueBuffer<Elem
int numCols = type == MatrixType::DENSE ? buffer.m_buffer.size() / numRows : buffer.m_colIndices.size() - 1;
assert(numCols >= 1);
inputNode->GetMBLayout()->Init(1, numCols);
inputNode->GetMBLayout()->AddSequence(0, 0, INT_MIN, numCols);
if (m_SeqBeginTimeMin < m_SeqBeginTime)
m_SeqBeginTime--;
// INT_MIN is used to specify the lower bound of the look-back step of recurrent nodes.
inputNode->GetMBLayout()->AddSequence(0, 0, resetRNN ? 0 : INT_MIN, numCols);
if (type == MatrixType::DENSE)
matrix->SetValue(numRows, numCols, matrix->GetDeviceId(), buffer.m_buffer.data(), matrixFlagNormal);
@ -415,13 +414,25 @@ void CNTKEvalExtended<ElemType>::ForwardPassT(const std::vector<ValueBuffer<Elem
// Forward pass over Values<ElemType> buffers without an explicit reset flag.
// Delegates to ForwardPassT, passing false for resetRNN.
// NOTE(review): the two-argument and three-argument ForwardPassT calls below look like the
// removed and added lines of a diff hunk whose +/- markers were lost; confirm that only the
// three-argument call is meant to remain.
template<typename ElemType>
void CNTKEvalExtended<ElemType>::ForwardPass(const Values<ElemType>& inputs, Values<ElemType>& outputs)
{
ForwardPassT(inputs, outputs);
ForwardPassT(inputs, outputs, false);
}
// Forward-pass overload for Values<ElemType> buffers that lets the caller choose
// whether the recurrent state is reset: resetRNN is handed to ForwardPassT unchanged.
template <typename ElemType>
void CNTKEvalExtended<ElemType>::ForwardPass(const Values<ElemType>& inputs,
                                             Values<ElemType>& outputs,
                                             bool resetRNN)
{
    this->ForwardPassT(inputs, outputs, resetRNN);
}
// Forward pass over ValueRefs<ElemType> buffers without an explicit reset flag.
// Delegates to ForwardPassT, passing false for resetRNN.
// NOTE(review): the two-argument and three-argument ForwardPassT calls below look like the
// removed and added lines of a diff hunk whose +/- markers were lost; confirm that only the
// three-argument call is meant to remain.
template<typename ElemType>
void CNTKEvalExtended<ElemType>::ForwardPass(const ValueRefs<ElemType>& inputs, ValueRefs<ElemType>& outputs)
{
ForwardPassT(inputs, outputs);
ForwardPassT(inputs, outputs, false);
}
// Forward-pass overload for ValueRefs<ElemType> buffers that lets the caller choose
// whether the recurrent state is reset: resetRNN is handed to ForwardPassT unchanged.
template <typename ElemType>
void CNTKEvalExtended<ElemType>::ForwardPass(const ValueRefs<ElemType>& inputs,
                                             ValueRefs<ElemType>& outputs,
                                             bool resetRNN)
{
    this->ForwardPassT(inputs, outputs, resetRNN);
}
template <typename ElemType>
@ -431,13 +442,6 @@ void CNTKEvalExtended<ElemType>::Destroy()
delete this;
}
// Puts the sequence-begin bookkeeping back to its reset values:
// the begin time index becomes 0 and its minimum becomes INT_MIN.
template <typename ElemType>
void CNTKEvalExtended<ElemType>::ResetState()
{
    this->m_SeqBeginTimeMin = INT_MIN;
    this->m_SeqBeginTime = 0;
}
template <typename ElemType>
void EVAL_API GetEvalExtended(IEvaluateModelExtended<ElemType>** peval)
{

Просмотреть файл

@ -39,7 +39,6 @@ public:
virtual void CreateNetwork(const std::string& networkDescription);
virtual void Init(const std::string& config);
virtual void Destroy();
virtual void ResetState() = 0;
};
// ------------------------------------------------------------------------
@ -92,9 +91,7 @@ class CNTKEvalExtended : public CNTKEvalBase<ElemType>, public IEvaluateModelExt
{
public:
CNTKEvalExtended() : CNTKEvalBase<ElemType>(),
m_started(false),
m_SeqBeginTime(0),
m_SeqBeginTimeMin(0){}
m_started(false){}
virtual VariableSchema GetOutputSchema() const override;
@ -104,8 +101,12 @@ public:
virtual void ForwardPass(const Values<ElemType>& inputs, Values<ElemType>& output) override;
virtual void ForwardPass(const Values<ElemType>& inputs, Values<ElemType>& output, bool resetRNN) override;
virtual void ForwardPass(const ValueRefs<ElemType>& inputs, ValueRefs<ElemType>& output) override;
virtual void ForwardPass(const ValueRefs<ElemType>& inputs, ValueRefs<ElemType>& output, bool resetRNN) override;
virtual void Destroy() override;
virtual void CreateNetwork(const std::string& networkDescription) override
@ -118,8 +119,6 @@ public:
CNTKEvalBase<ElemType>::Init(config);
}
virtual void ResetState() override;
private:
static VariableLayout ToVariableLayout(const ComputationNodeBasePtr n);
std::vector<ComputationNodeBasePtr> m_outputNodes;
@ -130,13 +129,7 @@ private:
template<template<typename> class ValueContainer>
void ForwardPassT(const std::vector < ValueBuffer<ElemType, ValueContainer> >& inputs,
std::vector < ValueBuffer<ElemType, ValueContainer> >& outputs);
std::vector < ValueBuffer<ElemType, ValueContainer> >& outputs, bool resetRNN);
// First time index in this minibatch. Note that this may be negative if the sequence started before this MB.
int m_SeqBeginTime;
// The min possible value of the first time index in this minibatch.
// For regular RNN/LSTM networks this should be -1.
int m_SeqBeginTimeMin;
};
} } }