// CNTK/Source/EvalDll/CNTKEval.h
//
// 136 lines
// 4.4 KiB
// C++

//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
// CNTKEval.h - Include file for the CNTK Evaluation DLL
//
// NOTICE: This is a public interface for evaluating models in CNTK.
// Changes to this interface may affect other projects, such as Argon and LatGen,
// and therefore need to be communicated to those groups.
//
#pragma once
#include <string>
#include <map>
#include <vector>
#include "Eval.h"
#include "EvalReader.h"
#include "EvalWriter.h"
#include "ComputationNetwork.h"
namespace Microsoft { namespace MSR { namespace CNTK {
// CNTKEvalBase - members shared by the evaluation classes declared in this header.
// It stores a ConfigParameters object and a network handle, and declares the
// CreateNetwork/Init/Destroy entry points; their bodies are not in this header.
template <typename ElemType>
class CNTKEvalBase : public IEvaluateModelBase<ElemType>
{
protected:
    // Shared-pointer alias for nodes of this element type.
    using ComputationNodePtr = shared_ptr<ComputationNode<ElemType>>;

    ConfigParameters m_config;   // configuration object for this evaluator
    ComputationNetworkPtr m_net; // network handle; set to nullptr at construction

    // Protected constructor: only derived classes can create instances.
    // The evaluator starts out without a network.
    CNTKEvalBase()
        : m_net(nullptr)
    {
    }

public:
    // CreateNetwork - create a network based on the network description
    // networkDescription - network description
    virtual void CreateNetwork(const std::string& networkDescription);

    // Init - initialization entry point taking a configuration string
    // config - configuration text (format is defined by the implementation)
    virtual void Init(const std::string& config);

    // Destroy - tear-down entry point; what is released is determined by the
    // out-of-line definition and by the overrides in derived classes.
    virtual void Destroy();
};
// ------------------------------------------------------------------------
// Basic interface
// ------------------------------------------------------------------------
// CNTKEval - evaluator exposing the basic IEvaluateModel interface on top of
// the shared CNTKEvalBase state.
template <typename ElemType>
class CNTKEval : public CNTKEvalBase<ElemType>, public IEvaluateModel<ElemType>
{
    // Raw pointers, null after construction. They are assigned and released
    // outside this header (presumably in the .cpp implementation - verify there).
    EvalReader<ElemType>* m_reader;
    EvalWriter<ElemType>* m_writer;
    std::map<std::wstring, size_t> m_dimensions;
    // Flag toggled between 0 and 1 by ResetState() and zeroed by Init().
    // NOTE(review): its consumer is not visible in this header - confirm meaning in the .cpp.
    size_t m_start;

public:
    // Every scalar member is given a defined value here. In particular m_start
    // starts at 0 so that calling ResetState() before Init() does not read an
    // indeterminate value.
    CNTKEval() : CNTKEvalBase<ElemType>(), m_reader(nullptr), m_writer(nullptr), m_start(0) {}

    // GetNodeDimensions - fill 'dimensions' for the nodes of the requested group
    // dimensions - output map keyed by wide-string name
    // nodeGroup - which group of nodes to report on
    virtual void GetNodeDimensions(std::map<std::wstring, size_t>& dimensions, NodeGroup nodeGroup);

    // StartEvaluateMinibatchLoop - entry point taking the name of an output node
    virtual void StartEvaluateMinibatchLoop(const std::wstring& outputNodeName);

    // Evaluate - overload taking both input and output maps (name -> pointer to value vector)
    virtual void Evaluate(std::map<std::wstring, std::vector<ElemType>*>& inputs, std::map<std::wstring, std::vector<ElemType>*>& outputs);

    // Evaluate - overload taking only the output map
    virtual void Evaluate(std::map<std::wstring, std::vector<ElemType>*>& outputs);

    // Destroy - overrides the base tear-down; defined out of line
    virtual void Destroy() override;

    // CreateNetwork - forwards unchanged to the shared base implementation
    virtual void CreateNetwork(const std::string& networkDescription) override
    {
        CNTKEvalBase<ElemType>::CreateNetwork(networkDescription);
    }

    // Init - runs the shared base initialization, then resets m_start to 0
    virtual void Init(const std::string& config) override
    {
        CNTKEvalBase<ElemType>::Init(config);
        m_start = 0;
    }

    // ResetState - flips m_start: 0 becomes 1 and 1 becomes 0
    virtual void ResetState() override
    {
        m_start = 1 - m_start;
    }
};
// ------------------------------------------------------------------------
// Extended interface
// ------------------------------------------------------------------------
// CNTKEvalExtended - evaluator exposing the IEvaluateModelExtended interface
// on top of the shared CNTKEvalBase state.
template <typename ElemType>
class CNTKEvalExtended : public CNTKEvalBase<ElemType>, public IEvaluateModelExtended<ElemType>
{
public:
    // Starts in the "not started" state.
    CNTKEvalExtended() : CNTKEvalBase<ElemType>(), m_started(false)
    {
    }

    // Schema queries and evaluation start; all defined out of line.
    VariableSchema GetOutputSchema() const override;
    void StartForwardEvaluation(const std::vector<wstring>& outputs) override;
    VariableSchema GetInputSchema() const override;

    // ForwardPass overloads: one pair for Values buffers and one for ValueRefs
    // buffers, each with and without an explicit resetRNN flag.
    void ForwardPass(const Values<ElemType>& inputs, Values<ElemType>& output) override;
    void ForwardPass(const Values<ElemType>& inputs, Values<ElemType>& output, bool resetRNN) override;
    void ForwardPass(const ValueRefs<ElemType>& inputs, ValueRefs<ElemType>& output) override;
    void ForwardPass(const ValueRefs<ElemType>& inputs, ValueRefs<ElemType>& output, bool resetRNN) override;

    // Destroy - overrides the base tear-down; defined out of line
    void Destroy() override;

    // CreateNetwork - forwards unchanged to the shared base implementation
    void CreateNetwork(const std::string& networkDescription) override
    {
        CNTKEvalBase<ElemType>::CreateNetwork(networkDescription);
    }

    // Init - forwards unchanged to the shared base implementation
    void Init(const std::string& config) override
    {
        CNTKEvalBase<ElemType>::Init(config);
    }

private:
    // Builds a VariableLayout from a node handle; defined out of line.
    static VariableLayout ToVariableLayout(const ComputationNodeBasePtr n);

    std::vector<ComputationNodeBasePtr> m_outputNodes;
    std::shared_ptr<ScopedNetworkOperationMode> m_scopedNetworkOperationMode;
    std::vector<ComputationNodeBasePtr> m_inputNodes;
    StreamMinibatchInputs m_inputMatrices;
    // False after construction.
    // NOTE(review): presumably set by StartForwardEvaluation - confirm in the .cpp.
    bool m_started;

    // Helper templated on the value container type.
    // NOTE(review): presumably the common implementation behind the ForwardPass
    // overloads - confirm in the .cpp.
    template <template <typename> class ValueContainer>
    void ForwardPassT(const std::vector<ValueBuffer<ElemType, ValueContainer>>& inputs,
                      std::vector<ValueBuffer<ElemType, ValueContainer>>& outputs,
                      bool resetRNN);
};
} } }