Add resetRNN flag to Evaluate in ExtendedEval intf
This commit is contained in:
Родитель
f946c20f1b
Коммит
7447d55ed0
|
@ -49,12 +49,6 @@ public:
|
|||
// Free resources
|
||||
//
|
||||
virtual void Destroy() = 0;
|
||||
|
||||
//
|
||||
// Reset initial state of all Recurrence loops (RNNs) in the model.
|
||||
// Call this before processing the first sequence or whenever need to reset the memory cells to default value.
|
||||
//
|
||||
virtual void ResetState() = 0;
|
||||
};
|
||||
|
||||
// ------------------------------------------------------------------------
|
||||
|
@ -102,6 +96,11 @@ public:
|
|||
// happen during evaluation
|
||||
//
|
||||
virtual void Evaluate(std::map<std::wstring, std::vector<ElemType>*>& outputs) = 0;
|
||||
|
||||
//
|
||||
// Reset initial state of all Recurrence loops (RNNs) in the model.
|
||||
//
|
||||
virtual void ResetState() = 0;
|
||||
};
|
||||
|
||||
|
||||
|
@ -339,11 +338,23 @@ public:
|
|||
//
|
||||
virtual void ForwardPass(const Values<ElemType>& inputs, Values<ElemType>& output) = 0;
|
||||
|
||||
//
|
||||
// Same as above, and
|
||||
// resetRNN - flags whether to reset memory cells of RNN.
|
||||
//
|
||||
virtual void ForwardPass(const Values<ElemType>& inputs, Values<ElemType>& output, bool resetRNN) = 0;
|
||||
|
||||
//
|
||||
// Same as above, but takes references to static arrays instead of std::vector
|
||||
// (e.g. when vectors are manages by .net)
|
||||
//
|
||||
virtual void ForwardPass(const ValueRefs<ElemType>& inputs, ValueRefs<ElemType>& output) = 0;
|
||||
|
||||
//
|
||||
// Same as above, and
|
||||
// resetRNN - flags whether to reset memory cells of RNN.
|
||||
//
|
||||
virtual void ForwardPass(const ValueRefs<ElemType>& inputs, ValueRefs<ElemType>& output, bool resetRNN) = 0;
|
||||
};
|
||||
|
||||
template <typename ElemType>
|
||||
|
|
|
@ -310,7 +310,7 @@ VariableSchema CNTKEvalExtended<ElemType>::GetInputSchema() const
|
|||
|
||||
template<typename ElemType>
|
||||
template<template<typename> class ValueContainer>
|
||||
void CNTKEvalExtended<ElemType>::ForwardPassT(const std::vector<ValueBuffer<ElemType, ValueContainer> >& inputs, std::vector<ValueBuffer<ElemType, ValueContainer> >& outputs)
|
||||
void CNTKEvalExtended<ElemType>::ForwardPassT(const std::vector<ValueBuffer<ElemType, ValueContainer> >& inputs, std::vector<ValueBuffer<ElemType, ValueContainer> >& outputs, bool resetRNN)
|
||||
{
|
||||
if (!m_started)
|
||||
RuntimeError("ForwardPass() called before StartForwardEvaluation()");
|
||||
|
@ -360,10 +360,9 @@ void CNTKEvalExtended<ElemType>::ForwardPassT(const std::vector<ValueBuffer<Elem
|
|||
int numCols = type == MatrixType::DENSE ? buffer.m_buffer.size() / numRows : buffer.m_colIndices.size() - 1;
|
||||
assert(numCols >= 1);
|
||||
inputNode->GetMBLayout()->Init(1, numCols);
|
||||
inputNode->GetMBLayout()->AddSequence(0, 0, INT_MIN, numCols);
|
||||
|
||||
if (m_SeqBeginTimeMin < m_SeqBeginTime)
|
||||
m_SeqBeginTime--;
|
||||
|
||||
// INT_MIN is used to specify the lower bound of look-back step of recurrent nodes
|
||||
inputNode->GetMBLayout()->AddSequence(0, 0, resetRNN ? 0 : INT_MIN, numCols);
|
||||
|
||||
if (type == MatrixType::DENSE)
|
||||
matrix->SetValue(numRows, numCols, matrix->GetDeviceId(), buffer.m_buffer.data(), matrixFlagNormal);
|
||||
|
@ -415,13 +414,25 @@ void CNTKEvalExtended<ElemType>::ForwardPassT(const std::vector<ValueBuffer<Elem
|
|||
template<typename ElemType>
|
||||
void CNTKEvalExtended<ElemType>::ForwardPass(const Values<ElemType>& inputs, Values<ElemType>& outputs)
|
||||
{
|
||||
ForwardPassT(inputs, outputs);
|
||||
ForwardPassT(inputs, outputs, false);
|
||||
}
|
||||
|
||||
template<typename ElemType>
|
||||
void CNTKEvalExtended<ElemType>::ForwardPass(const Values<ElemType>& inputs, Values<ElemType>& outputs, bool resetRNN)
|
||||
{
|
||||
ForwardPassT(inputs, outputs, resetRNN);
|
||||
}
|
||||
|
||||
template<typename ElemType>
|
||||
void CNTKEvalExtended<ElemType>::ForwardPass(const ValueRefs<ElemType>& inputs, ValueRefs<ElemType>& outputs)
|
||||
{
|
||||
ForwardPassT(inputs, outputs);
|
||||
ForwardPassT(inputs, outputs, false);
|
||||
}
|
||||
|
||||
template<typename ElemType>
|
||||
void CNTKEvalExtended<ElemType>::ForwardPass(const ValueRefs<ElemType>& inputs, ValueRefs<ElemType>& outputs, bool resetRNN)
|
||||
{
|
||||
ForwardPassT(inputs, outputs, resetRNN);
|
||||
}
|
||||
|
||||
template <typename ElemType>
|
||||
|
@ -431,13 +442,6 @@ void CNTKEvalExtended<ElemType>::Destroy()
|
|||
delete this;
|
||||
}
|
||||
|
||||
template <typename ElemType>
|
||||
void CNTKEvalExtended<ElemType>::ResetState()
|
||||
{
|
||||
m_SeqBeginTime = 0;
|
||||
m_SeqBeginTimeMin = INT_MIN;
|
||||
}
|
||||
|
||||
template <typename ElemType>
|
||||
void EVAL_API GetEvalExtended(IEvaluateModelExtended<ElemType>** peval)
|
||||
{
|
||||
|
|
|
@ -39,7 +39,6 @@ public:
|
|||
virtual void CreateNetwork(const std::string& networkDescription);
|
||||
virtual void Init(const std::string& config);
|
||||
virtual void Destroy();
|
||||
virtual void ResetState() = 0;
|
||||
};
|
||||
|
||||
// ------------------------------------------------------------------------
|
||||
|
@ -92,9 +91,7 @@ class CNTKEvalExtended : public CNTKEvalBase<ElemType>, public IEvaluateModelExt
|
|||
{
|
||||
public:
|
||||
CNTKEvalExtended() : CNTKEvalBase<ElemType>(),
|
||||
m_started(false),
|
||||
m_SeqBeginTime(0),
|
||||
m_SeqBeginTimeMin(0){}
|
||||
m_started(false){}
|
||||
|
||||
virtual VariableSchema GetOutputSchema() const override;
|
||||
|
||||
|
@ -104,8 +101,12 @@ public:
|
|||
|
||||
virtual void ForwardPass(const Values<ElemType>& inputs, Values<ElemType>& output) override;
|
||||
|
||||
virtual void ForwardPass(const Values<ElemType>& inputs, Values<ElemType>& output, bool resetRNN) override;
|
||||
|
||||
virtual void ForwardPass(const ValueRefs<ElemType>& inputs, ValueRefs<ElemType>& output) override;
|
||||
|
||||
virtual void ForwardPass(const ValueRefs<ElemType>& inputs, ValueRefs<ElemType>& output, bool resetRNN) override;
|
||||
|
||||
virtual void Destroy() override;
|
||||
|
||||
virtual void CreateNetwork(const std::string& networkDescription) override
|
||||
|
@ -118,8 +119,6 @@ public:
|
|||
CNTKEvalBase<ElemType>::Init(config);
|
||||
}
|
||||
|
||||
virtual void ResetState() override;
|
||||
|
||||
private:
|
||||
static VariableLayout ToVariableLayout(const ComputationNodeBasePtr n);
|
||||
std::vector<ComputationNodeBasePtr> m_outputNodes;
|
||||
|
@ -130,13 +129,7 @@ private:
|
|||
|
||||
template<template<typename> class ValueContainer>
|
||||
void ForwardPassT(const std::vector < ValueBuffer<ElemType, ValueContainer> >& inputs,
|
||||
std::vector < ValueBuffer<ElemType, ValueContainer> >& outputs);
|
||||
std::vector < ValueBuffer<ElemType, ValueContainer> >& outputs, bool resetRNN);
|
||||
|
||||
// First time index in this minibatch. Note that this may be negative if the sequence started before this MB.
|
||||
int m_SeqBeginTime;
|
||||
|
||||
// The min possible value of the first time index in this minibatch.
|
||||
// For regular RNN/LSTM networks this should be -1.
|
||||
int m_SeqBeginTimeMin;
|
||||
};
|
||||
} } }
|
||||
|
|
Загрузка…
Ссылка в новой задаче