This commit is contained in:
Clemens Marschner 2016-05-19 17:41:17 +02:00
Родитель 83c152f5c1
Коммит 832655a154
6 изменённых файлов: 27 добавлений и 32 удалений

Просмотреть файл

@ -48,8 +48,6 @@ public:
// Free resources
//
virtual void Destroy() = 0;
virtual void ResetState() = 0;
};
// ------------------------------------------------------------------------
@ -97,6 +95,8 @@ public:
// happen during evaluation
//
virtual void Evaluate(std::map<std::wstring, std::vector<ElemType>*>& outputs) = 0;
virtual void ResetState() = 0;
};

Просмотреть файл

@ -498,9 +498,8 @@ public:
return outputNodes;
}
// collect all input nodes that outputNodes depend on
// TODO: This is rather generic, we should move this to a shared place. DataReaderHelpers.h?
std::vector<ComputationNodeBasePtr> InputNodesFor(const std::vector<std::wstring>& outputNodeNames)
// Collect all input nodes that outputNodes depend on.
std::vector<ComputationNodeBasePtr> InputNodesForOutputs(const std::vector<std::wstring>& outputNodeNames)
{
// use map to remove duplicated items
auto outputNodes = OutputNodesByName(outputNodeNames);

Просмотреть файл

@ -32,7 +32,7 @@ bool g_shareNodeValueMatrices = false;
namespace Microsoft { namespace MSR { namespace CNTK {
template <class ElemType>
template <typename ElemType>
void CNTKEvalBase<ElemType>::Init(const std::string& config)
{
m_config.Parse(config);
@ -44,7 +44,7 @@ void CNTKEvalBase<ElemType>::Init(const std::string& config)
// CreateNetwork - create a network based on the network description
// networkDescription - network description
template <class ElemType>
template <typename ElemType>
void CNTKEvalBase<ElemType>::CreateNetwork(const std::string& networkDescription)
{
ConfigParameters config;
@ -62,7 +62,7 @@ void CNTKEvalBase<ElemType>::CreateNetwork(const std::string& networkDescription
// Destroy - cleanup and remove this class
// NOTE: this destroys the object, and it can't be used past this point
template <class ElemType>
template <typename ElemType>
void CNTKEvalBase<ElemType>::Destroy()
{
// cleanup everything
@ -74,7 +74,7 @@ void CNTKEvalBase<ElemType>::Destroy()
// Basic interface
// ----------------------------------------------------------------------------
template <class ElemType>
template <typename ElemType>
void EVAL_API GetEval(IEvaluateModel<ElemType>** peval)
{
*peval = new CNTKEval<ElemType>();
@ -93,7 +93,7 @@ extern "C" EVAL_API void GetEvalD(IEvaluateModel<double>** peval)
// dimensions - map from name of node to dimension of the node, will be appended to for Input/Output scenarios
// nodeGroup - type of node we are requesting (input/output/specified)
// NOTE: when nodeGroup==specified the dimensions map is expected to be populated with the string names of the nodes requested, dimensions will be modified return the current value.
template <class ElemType>
template <typename ElemType>
void CNTKEval<ElemType>::GetNodeDimensions(std::map<std::wstring, size_t>& dimensions, NodeGroup nodeGroup)
{
if (m_net == NULL)
@ -145,7 +145,7 @@ void CNTKEval<ElemType>::GetNodeDimensions(std::map<std::wstring, size_t>& dimen
// StartEvaluateMinibatchLoop - Prepare network for Evaluate() calls.
// ouputNodeName - name of node that will be evaluated
template <class ElemType>
template <typename ElemType>
void CNTKEval<ElemType>::StartEvaluateMinibatchLoop(const std::wstring& outputNodeName)
{
m_net->StartEvaluateMinibatchLoop(m_net->GetNodeFromName(outputNodeName));
@ -154,7 +154,7 @@ void CNTKEval<ElemType>::StartEvaluateMinibatchLoop(const std::wstring& outputNo
// Evaluate - Evalute using the model with the given inputs and outputs
// inputs - map from node name to input vector
// outputs - map from node name to output vector, outputs vectors need to be preallocated by caller, sizing will happen during evaluation
template <class ElemType>
template <typename ElemType>
void CNTKEval<ElemType>::Evaluate(std::map<std::wstring, std::vector<ElemType>*>& inputs, std::map<std::wstring, std::vector<ElemType>*>& outputs)
{
size_t minibatchSize = m_config(L"minibatchSize", (size_t) 10240);
@ -191,7 +191,7 @@ void CNTKEval<ElemType>::Evaluate(std::map<std::wstring, std::vector<ElemType>*>
// Evaluate - Evalute using the model with the given inputs and outputs
// outputs - map from node name to output vector, outputs vectors need to be preallocated by caller, sizing will happen during evaluation
template <class ElemType>
template <typename ElemType>
void CNTKEval<ElemType>::Evaluate(std::map<std::wstring, std::vector<ElemType>*>& outputs)
{
// get the evaluation names from the output string
@ -215,7 +215,7 @@ void CNTKEval<ElemType>::Evaluate(std::map<std::wstring, std::vector<ElemType>*>
}
template <class ElemType>
template <typename ElemType>
void CNTKEval<ElemType>::Destroy()
{
CNTKEvalBase<ElemType>::Destroy();
@ -256,7 +256,7 @@ void CNTKEvalExtended<ElemType>::StartForwardEvaluation(std::vector<wstring> out
m_scopedNetworkOperationMode = make_shared<ScopedNetworkOperationMode>(m_net, NetworkOperationMode::inferring);
// allocate memory for forward computation
m_outputNodes = m_net->OutputNodesByName(outputNodeNames);
m_inputNodes = m_net->InputNodesFor(outputNodeNames);
m_inputNodes = m_net->InputNodesForOutputs(outputNodeNames);
// allocate memory for forward computation
m_net->AllocateAllMatrices({}, m_outputNodes, nullptr);
m_net->StartEvaluateMinibatchLoop(m_outputNodes);
@ -282,7 +282,7 @@ VariableSchema CNTKEvalExtended<ElemType>::GetInputSchema() const
if (nodes.size() == 0)
{
// Default to all nodes
nodes = m_net->InputNodesFor({});
nodes = m_net->InputNodesForOutputs({});
}
for (const auto& n : nodes)
@ -354,14 +354,14 @@ void CNTKEvalExtended<ElemType>::ForwardPass(const Variables<ElemType>& inputs,
}
}
template <class ElemType>
template <typename ElemType>
void CNTKEvalExtended<ElemType>::Destroy()
{
CNTKEvalBase<ElemType>::Destroy();
delete this;
}
template <class ElemType>
template <typename ElemType>
void EVAL_API GetEvalExtended(IEvaluateModelExtended<ElemType>** peval)
{
*peval = new CNTKEvalExtended<ElemType>();

Просмотреть файл

@ -22,7 +22,7 @@
namespace Microsoft { namespace MSR { namespace CNTK {
template <class ElemType>
template <typename ElemType>
class CNTKEvalBase : public IEvaluateModelBase<ElemType>
{
protected:
@ -39,13 +39,12 @@ public:
virtual void CreateNetwork(const std::string& networkDescription);
virtual void Init(const std::string& config);
virtual void Destroy();
virtual void ResetState() {};
};
// ------------------------------------------------------------------------
// Basic interface
// ------------------------------------------------------------------------
template <class ElemType>
template <typename ElemType>
class CNTKEval : public CNTKEvalBase<ElemType>, public IEvaluateModel<ElemType>
{
EvalReader<ElemType>* m_reader;
@ -87,7 +86,7 @@ public:
// ------------------------------------------------------------------------
// Extended interface
// ------------------------------------------------------------------------
template <class ElemType>
template <typename ElemType>
class CNTKEvalExtended : public CNTKEvalBase<ElemType>, public IEvaluateModelExtended<ElemType>
{
virtual VariableSchema GetOutputSchema() const override;
@ -109,8 +108,6 @@ class CNTKEvalExtended : public CNTKEvalBase<ElemType>, public IEvaluateModelExt
{
CNTKEvalBase<ElemType>::Init(config);
}
virtual void ResetState() override { }
private:
static VariableLayout ToVariableLayout(const ComputationNodeBasePtr n);
std::vector<ComputationNodeBasePtr> m_outputNodes;

Просмотреть файл

@ -15,12 +15,6 @@ namespace Microsoft { namespace MSR { namespace CNTK {
/*static*/ struct DataReaderHelpers
{
// -------------------------------------------------------------------
// GetMinibatchIntoNetwork() -- get one minibatch from Reader (this->trainSetDataReader) into Network (this->net)
// Returns false if no data is read. In that case, no other return value can be expected to contain meaningful values (e.g. actualMBSize will be unchanged).
// Sets actualMBSize to the number of matrix columns. Note that 0 is a valid value to be returned for actualMBSize, caller must handle that correctly.
// -------------------------------------------------------------------
template <class ElemType>
static void NotifyChangedNodes(ComputationNetworkPtr net, StreamMinibatchInputs& inputMatrices)
{
@ -37,6 +31,11 @@ namespace Microsoft { namespace MSR { namespace CNTK {
node->NotifyFunctionValuesMBSizeModified();
}
// -------------------------------------------------------------------
// GetMinibatchIntoNetwork() -- get one minibatch from Reader (this->trainSetDataReader) into Network (this->net)
// Returns false if no data is read. In that case, no other return value can be expected to contain meaningful values (e.g. actualMBSize will be unchanged).
// Sets actualMBSize to the number of matrix columns. Note that 0 is a valid value to be returned for actualMBSize, caller must handle that correctly.
// -------------------------------------------------------------------
// Note: This will go away with the redesigned reader interface.
// TODO: callers of this often do ComputationNetwork::BumpEvalTimeStamp(featureNodes) and also for labels; we should eliminate the need for this.
template <class ElemType>

Просмотреть файл

@ -43,7 +43,7 @@ public:
fprintf(stderr, "OutputNodeNames are not specified, using the default outputnodes.\n");
std::vector<ComputationNodeBasePtr> outputNodes = m_net->OutputNodesByName(outputNodeNames);
std::vector<ComputationNodeBasePtr> inputNodes = m_net->InputNodesFor(outputNodeNames);
std::vector<ComputationNodeBasePtr> inputNodes = m_net->InputNodesForOutputs(outputNodeNames);
// allocate memory for forward computation
m_net->AllocateAllMatrices({}, outputNodes, nullptr);
@ -156,7 +156,7 @@ public:
ScopedNetworkOperationMode modeGuard(m_net, nodeUnitTest ? NetworkOperationMode::training : NetworkOperationMode::inferring);
std::vector<ComputationNodeBasePtr> outputNodes = m_net->OutputNodesByName(outputNodeNames);
std::vector<ComputationNodeBasePtr> inputNodes = m_net->InputNodesFor(outputNodeNames);
std::vector<ComputationNodeBasePtr> inputNodes = m_net->InputNodesForOutputs(outputNodeNames);
std::vector<ComputationNodePtr> gradientNodes;
std::vector<ComputationNodeBasePtr> allOutputNodes = outputNodes;