Review comments
This commit is contained in:
Родитель
83c152f5c1
Коммит
832655a154
|
@ -48,8 +48,6 @@ public:
|
|||
// Free resources
|
||||
//
|
||||
virtual void Destroy() = 0;
|
||||
|
||||
virtual void ResetState() = 0;
|
||||
};
|
||||
|
||||
// ------------------------------------------------------------------------
|
||||
|
@ -97,6 +95,8 @@ public:
|
|||
// happen during evaluation
|
||||
//
|
||||
virtual void Evaluate(std::map<std::wstring, std::vector<ElemType>*>& outputs) = 0;
|
||||
|
||||
virtual void ResetState() = 0;
|
||||
};
|
||||
|
||||
|
||||
|
|
|
@ -498,9 +498,8 @@ public:
|
|||
return outputNodes;
|
||||
}
|
||||
|
||||
// collect all input nodes that outputNodes depend on
|
||||
// TODO: This is rather generic, we should move this to a shared place. DataReaderHelpers.h?
|
||||
std::vector<ComputationNodeBasePtr> InputNodesFor(const std::vector<std::wstring>& outputNodeNames)
|
||||
// Collect all input nodes that outputNodes depend on.
|
||||
std::vector<ComputationNodeBasePtr> InputNodesForOutputs(const std::vector<std::wstring>& outputNodeNames)
|
||||
{
|
||||
// use map to remove duplicated items
|
||||
auto outputNodes = OutputNodesByName(outputNodeNames);
|
||||
|
|
|
@ -32,7 +32,7 @@ bool g_shareNodeValueMatrices = false;
|
|||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
|
||||
template <class ElemType>
|
||||
template <typename ElemType>
|
||||
void CNTKEvalBase<ElemType>::Init(const std::string& config)
|
||||
{
|
||||
m_config.Parse(config);
|
||||
|
@ -44,7 +44,7 @@ void CNTKEvalBase<ElemType>::Init(const std::string& config)
|
|||
|
||||
// CreateNetwork - create a network based on the network description
|
||||
// networkDescription - network description
|
||||
template <class ElemType>
|
||||
template <typename ElemType>
|
||||
void CNTKEvalBase<ElemType>::CreateNetwork(const std::string& networkDescription)
|
||||
{
|
||||
ConfigParameters config;
|
||||
|
@ -62,7 +62,7 @@ void CNTKEvalBase<ElemType>::CreateNetwork(const std::string& networkDescription
|
|||
|
||||
// Destroy - cleanup and remove this class
|
||||
// NOTE: this destroys the object, and it can't be used past this point
|
||||
template <class ElemType>
|
||||
template <typename ElemType>
|
||||
void CNTKEvalBase<ElemType>::Destroy()
|
||||
{
|
||||
// cleanup everything
|
||||
|
@ -74,7 +74,7 @@ void CNTKEvalBase<ElemType>::Destroy()
|
|||
// Basic interface
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
template <class ElemType>
|
||||
template <typename ElemType>
|
||||
void EVAL_API GetEval(IEvaluateModel<ElemType>** peval)
|
||||
{
|
||||
*peval = new CNTKEval<ElemType>();
|
||||
|
@ -93,7 +93,7 @@ extern "C" EVAL_API void GetEvalD(IEvaluateModel<double>** peval)
|
|||
// dimensions - map from name of node to dimension of the node, will be appended to for Input/Output scenarios
|
||||
// nodeGroup - type of node we are requesting (input/output/specified)
|
||||
// NOTE: when nodeGroup==specified the dimensions map is expected to be populated with the string names of the nodes requested, dimensions will be modified return the current value.
|
||||
template <class ElemType>
|
||||
template <typename ElemType>
|
||||
void CNTKEval<ElemType>::GetNodeDimensions(std::map<std::wstring, size_t>& dimensions, NodeGroup nodeGroup)
|
||||
{
|
||||
if (m_net == NULL)
|
||||
|
@ -145,7 +145,7 @@ void CNTKEval<ElemType>::GetNodeDimensions(std::map<std::wstring, size_t>& dimen
|
|||
|
||||
// StartEvaluateMinibatchLoop - Prepare network for Evaluate() calls.
|
||||
// ouputNodeName - name of node that will be evaluated
|
||||
template <class ElemType>
|
||||
template <typename ElemType>
|
||||
void CNTKEval<ElemType>::StartEvaluateMinibatchLoop(const std::wstring& outputNodeName)
|
||||
{
|
||||
m_net->StartEvaluateMinibatchLoop(m_net->GetNodeFromName(outputNodeName));
|
||||
|
@ -154,7 +154,7 @@ void CNTKEval<ElemType>::StartEvaluateMinibatchLoop(const std::wstring& outputNo
|
|||
// Evaluate - Evalute using the model with the given inputs and outputs
|
||||
// inputs - map from node name to input vector
|
||||
// outputs - map from node name to output vector, outputs vectors need to be preallocated by caller, sizing will happen during evaluation
|
||||
template <class ElemType>
|
||||
template <typename ElemType>
|
||||
void CNTKEval<ElemType>::Evaluate(std::map<std::wstring, std::vector<ElemType>*>& inputs, std::map<std::wstring, std::vector<ElemType>*>& outputs)
|
||||
{
|
||||
size_t minibatchSize = m_config(L"minibatchSize", (size_t) 10240);
|
||||
|
@ -191,7 +191,7 @@ void CNTKEval<ElemType>::Evaluate(std::map<std::wstring, std::vector<ElemType>*>
|
|||
|
||||
// Evaluate - Evalute using the model with the given inputs and outputs
|
||||
// outputs - map from node name to output vector, outputs vectors need to be preallocated by caller, sizing will happen during evaluation
|
||||
template <class ElemType>
|
||||
template <typename ElemType>
|
||||
void CNTKEval<ElemType>::Evaluate(std::map<std::wstring, std::vector<ElemType>*>& outputs)
|
||||
{
|
||||
// get the evaluation names from the output string
|
||||
|
@ -215,7 +215,7 @@ void CNTKEval<ElemType>::Evaluate(std::map<std::wstring, std::vector<ElemType>*>
|
|||
}
|
||||
|
||||
|
||||
template <class ElemType>
|
||||
template <typename ElemType>
|
||||
void CNTKEval<ElemType>::Destroy()
|
||||
{
|
||||
CNTKEvalBase<ElemType>::Destroy();
|
||||
|
@ -256,7 +256,7 @@ void CNTKEvalExtended<ElemType>::StartForwardEvaluation(std::vector<wstring> out
|
|||
m_scopedNetworkOperationMode = make_shared<ScopedNetworkOperationMode>(m_net, NetworkOperationMode::inferring);
|
||||
// allocate memory for forward computation
|
||||
m_outputNodes = m_net->OutputNodesByName(outputNodeNames);
|
||||
m_inputNodes = m_net->InputNodesFor(outputNodeNames);
|
||||
m_inputNodes = m_net->InputNodesForOutputs(outputNodeNames);
|
||||
// allocate memory for forward computation
|
||||
m_net->AllocateAllMatrices({}, m_outputNodes, nullptr);
|
||||
m_net->StartEvaluateMinibatchLoop(m_outputNodes);
|
||||
|
@ -282,7 +282,7 @@ VariableSchema CNTKEvalExtended<ElemType>::GetInputSchema() const
|
|||
if (nodes.size() == 0)
|
||||
{
|
||||
// Default to all nodes
|
||||
nodes = m_net->InputNodesFor({});
|
||||
nodes = m_net->InputNodesForOutputs({});
|
||||
}
|
||||
|
||||
for (const auto& n : nodes)
|
||||
|
@ -354,14 +354,14 @@ void CNTKEvalExtended<ElemType>::ForwardPass(const Variables<ElemType>& inputs,
|
|||
}
|
||||
}
|
||||
|
||||
template <class ElemType>
|
||||
template <typename ElemType>
|
||||
void CNTKEvalExtended<ElemType>::Destroy()
|
||||
{
|
||||
CNTKEvalBase<ElemType>::Destroy();
|
||||
delete this;
|
||||
}
|
||||
|
||||
template <class ElemType>
|
||||
template <typename ElemType>
|
||||
void EVAL_API GetEvalExtended(IEvaluateModelExtended<ElemType>** peval)
|
||||
{
|
||||
*peval = new CNTKEvalExtended<ElemType>();
|
||||
|
|
|
@ -22,7 +22,7 @@
|
|||
|
||||
namespace Microsoft { namespace MSR { namespace CNTK {
|
||||
|
||||
template <class ElemType>
|
||||
template <typename ElemType>
|
||||
class CNTKEvalBase : public IEvaluateModelBase<ElemType>
|
||||
{
|
||||
protected:
|
||||
|
@ -39,13 +39,12 @@ public:
|
|||
virtual void CreateNetwork(const std::string& networkDescription);
|
||||
virtual void Init(const std::string& config);
|
||||
virtual void Destroy();
|
||||
virtual void ResetState() {};
|
||||
};
|
||||
|
||||
// ------------------------------------------------------------------------
|
||||
// Basic interface
|
||||
// ------------------------------------------------------------------------
|
||||
template <class ElemType>
|
||||
template <typename ElemType>
|
||||
class CNTKEval : public CNTKEvalBase<ElemType>, public IEvaluateModel<ElemType>
|
||||
{
|
||||
EvalReader<ElemType>* m_reader;
|
||||
|
@ -87,7 +86,7 @@ public:
|
|||
// ------------------------------------------------------------------------
|
||||
// Extended interface
|
||||
// ------------------------------------------------------------------------
|
||||
template <class ElemType>
|
||||
template <typename ElemType>
|
||||
class CNTKEvalExtended : public CNTKEvalBase<ElemType>, public IEvaluateModelExtended<ElemType>
|
||||
{
|
||||
virtual VariableSchema GetOutputSchema() const override;
|
||||
|
@ -109,8 +108,6 @@ class CNTKEvalExtended : public CNTKEvalBase<ElemType>, public IEvaluateModelExt
|
|||
{
|
||||
CNTKEvalBase<ElemType>::Init(config);
|
||||
}
|
||||
|
||||
virtual void ResetState() override { }
|
||||
private:
|
||||
static VariableLayout ToVariableLayout(const ComputationNodeBasePtr n);
|
||||
std::vector<ComputationNodeBasePtr> m_outputNodes;
|
||||
|
|
|
@ -15,12 +15,6 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
|
||||
/*static*/ struct DataReaderHelpers
|
||||
{
|
||||
// -------------------------------------------------------------------
|
||||
// GetMinibatchIntoNetwork() -- get one minibatch from Reader (this->trainSetDataReader) into Network (this->net)
|
||||
// Returns false if no data is read. In that case, no other return value can be expected to contain meaningful values (e.g. actualMBSize will be unchanged).
|
||||
// Sets actualMBSize to the number of matrix columns. Note that 0 is a valid value to be returned for actualMBSize, caller must handle that correctly.
|
||||
// -------------------------------------------------------------------
|
||||
|
||||
template <class ElemType>
|
||||
static void NotifyChangedNodes(ComputationNetworkPtr net, StreamMinibatchInputs& inputMatrices)
|
||||
{
|
||||
|
@ -37,6 +31,11 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
node->NotifyFunctionValuesMBSizeModified();
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// GetMinibatchIntoNetwork() -- get one minibatch from Reader (this->trainSetDataReader) into Network (this->net)
|
||||
// Returns false if no data is read. In that case, no other return value can be expected to contain meaningful values (e.g. actualMBSize will be unchanged).
|
||||
// Sets actualMBSize to the number of matrix columns. Note that 0 is a valid value to be returned for actualMBSize, caller must handle that correctly.
|
||||
// -------------------------------------------------------------------
|
||||
// Note: This will go away with the redesigned reader interface.
|
||||
// TODO: callers of this often do ComputationNetwork::BumpEvalTimeStamp(featureNodes) and also for labels; we should eliminate the need for this.
|
||||
template <class ElemType>
|
||||
|
|
|
@ -43,7 +43,7 @@ public:
|
|||
fprintf(stderr, "OutputNodeNames are not specified, using the default outputnodes.\n");
|
||||
|
||||
std::vector<ComputationNodeBasePtr> outputNodes = m_net->OutputNodesByName(outputNodeNames);
|
||||
std::vector<ComputationNodeBasePtr> inputNodes = m_net->InputNodesFor(outputNodeNames);
|
||||
std::vector<ComputationNodeBasePtr> inputNodes = m_net->InputNodesForOutputs(outputNodeNames);
|
||||
|
||||
// allocate memory for forward computation
|
||||
m_net->AllocateAllMatrices({}, outputNodes, nullptr);
|
||||
|
@ -156,7 +156,7 @@ public:
|
|||
ScopedNetworkOperationMode modeGuard(m_net, nodeUnitTest ? NetworkOperationMode::training : NetworkOperationMode::inferring);
|
||||
|
||||
std::vector<ComputationNodeBasePtr> outputNodes = m_net->OutputNodesByName(outputNodeNames);
|
||||
std::vector<ComputationNodeBasePtr> inputNodes = m_net->InputNodesFor(outputNodeNames);
|
||||
std::vector<ComputationNodeBasePtr> inputNodes = m_net->InputNodesForOutputs(outputNodeNames);
|
||||
std::vector<ComputationNodePtr> gradientNodes;
|
||||
std::vector<ComputationNodeBasePtr> allOutputNodes = outputNodes;
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче