CNTK/Source/EvalDll/CNTKEval.cpp

186 строки
5.8 KiB
C++
Исходник Обычный вид История

2014-08-30 03:21:42 +04:00
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
// CNTKEval.cpp : Defines the exported functions for the CNTK DLL.
//
#include "stdafx.h"
#define EVAL_EXPORTS // creating the exports here
#include "Eval.h"
#include "CNTKEval.h"
#include "CPUMatrix.h" // for SetNumThreads()
#include "SimpleOutputWriter.h"
#ifdef LEAKDETECT
#include <vld.h> // leak detection
#endif
#include "BestGpu.h"
#include "MPIWrapper.h"
#include <stdexcept> // for std::logic_error
2014-08-30 03:21:42 +04:00
// TODO: Temporary mechanism to enable memory sharing for
// node output value matrices. This will go away when the
// sharing is ready to be enabled by default
// Starts out false; CNTKEval<ElemType>::Init() assigns it from the
// "shareNodeValueMatrices" configuration entry (read with a fallback of false).
bool g_shareNodeValueMatrices = false;
2014-08-30 03:21:42 +04:00
namespace Microsoft { namespace MSR { namespace CNTK {
template <class ElemType>
2014-08-30 03:21:42 +04:00
void EVAL_API GetEval(IEvaluateModel<ElemType>** peval)
{
*peval = new CNTKEval<ElemType>();
}
// GetEvalF - C-linkage export that creates a single-precision (float) evaluator
// peval - receives the pointer to the newly created evaluator
extern "C" EVAL_API void GetEvalF(IEvaluateModel<float>** peval)
{
    GetEval<float>(peval);
}
// GetEvalD - C-linkage export that creates a double-precision (double) evaluator
// peval - receives the pointer to the newly created evaluator
extern "C" EVAL_API void GetEvalD(IEvaluateModel<double>** peval)
{
    GetEval<double>(peval);
}
template <class ElemType>
2014-08-30 03:21:42 +04:00
void CNTKEval<ElemType>::Init(const std::string& config)
{
m_start = 0;
2014-08-30 03:21:42 +04:00
m_config.Parse(config);
if (m_config.Exists("modelPath"))
{
std::wstring path = m_config("modelPath");
LoadModel(path);
}
size_t nThreads = m_config("numCPUThreads", "1");
CPUMatrix<ElemType>::SetNumThreads(nThreads);
g_shareNodeValueMatrices = m_config(L"shareNodeValueMatrices", false);
2014-08-30 03:21:42 +04:00
}
// Destroy - cleanup and remove this class
// NOTE: this destroys the object, and it can't be used past this point
template <class ElemType>
2014-08-30 03:21:42 +04:00
void CNTKEval<ElemType>::Destroy()
{
// cleanup everything
m_net.reset();
2014-08-30 03:21:42 +04:00
delete m_reader;
delete m_writer;
delete this;
}
// LoadModel - load a model from the specified path
// modelFileName - file holding the model to load
template <class ElemType>
2014-08-30 03:21:42 +04:00
void CNTKEval<ElemType>::LoadModel(const std::wstring& modelFileName)
{
DEVICEID_TYPE deviceId = DeviceFromConfig(m_config);
fprintf(stderr, "DeviceID=%d\n", (int) deviceId);
m_net = ComputationNetwork::CreateFromFile<ElemType>(deviceId, modelFileName);
2014-08-30 03:21:42 +04:00
}
// GetNodeDimensions - Get the node dimensions of the specified nodes
// dimensions - map from name of node to dimension of the node, will be appended to for Input/Output scenarios
// nodeGroup - type of node we are requesting (input/output/specified)
// NOTE: when nodeGroup==specified the dimensions map is expected to be populated with the string names of the nodes requested, dimensions will be modified return the current value.
template <class ElemType>
2014-08-30 03:21:42 +04:00
void CNTKEval<ElemType>::GetNodeDimensions(std::map<std::wstring, size_t>& dimensions, NodeGroup nodeGroup)
{
if (m_net == NULL)
{
for (auto iter = dimensions.begin(); iter != dimensions.end(); iter++)
2014-08-30 03:21:42 +04:00
iter->second = 0;
return;
}
const auto& outputNodes = m_net->OutputNodes();
2014-08-30 03:21:42 +04:00
switch (nodeGroup)
{
case nodeInput:
{
auto& nodes = m_net->InputNodes(outputNodes[0]);
for (auto& node : nodes)
2014-08-30 03:21:42 +04:00
{
std::wstring name = node->NodeName();
size_t size = node->GetSampleMatrixNumRows();
2014-08-30 03:21:42 +04:00
dimensions[name] = size;
}
break;
}
2014-08-30 03:21:42 +04:00
case nodeOutput:
{
const auto& nodes = outputNodes;
for (auto& node : nodes)
2014-08-30 03:21:42 +04:00
{
std::wstring name = node->NodeName();
size_t size = node->GetSampleMatrixNumRows();
2014-08-30 03:21:42 +04:00
dimensions[name] = size;
}
break;
}
2014-08-30 03:21:42 +04:00
case nodeSpecified:
for (auto iter = dimensions.begin(); iter != dimensions.end(); iter++)
2014-08-30 03:21:42 +04:00
{
auto node = m_net->GetNodeFromName(iter->first);
iter->second = node->GetSampleMatrixNumRows();
2014-08-30 03:21:42 +04:00
}
break;
}
}
// StartEvaluateMinibatchLoop - Prepare network for Evaluate() calls.
// outputNodeName - name of node that will be evaluated
template <class ElemType>
void CNTKEval<ElemType>::StartEvaluateMinibatchLoop(const std::wstring& outputNodeName)
{
    // resolve the name to a node first, then hand that node to the network
    const auto rootNode = m_net->GetNodeFromName(outputNodeName);
    m_net->StartEvaluateMinibatchLoop(rootNode);
}
2014-08-30 03:21:42 +04:00
// Evaluate - Evalute using the model with the given inputs and outputs
// inputs - map from node name to input vector
// outputs - map from node name to output vector, outputs vectors need to be preallocated by caller, sizing will happen during evaluation
template <class ElemType>
2014-08-30 03:21:42 +04:00
void CNTKEval<ElemType>::Evaluate(std::map<std::wstring, std::vector<ElemType>*>& inputs, std::map<std::wstring, std::vector<ElemType>*>& outputs)
{
size_t minibatchSize = m_config(L"minibatchSize", (size_t) 10240);
2014-08-30 03:21:42 +04:00
// get the evaluation names from the output string
vector<wstring> outNodeNames;
ConfigParameters config;
// config["deviceId"] = to_string(m_net->GetDeviceId());
2014-08-30 03:21:42 +04:00
// create the reader if necessary
if (m_reader == nullptr)
{
m_reader = new EvalReader<ElemType>(config);
}
// now set the data in the reader
GetNodeDimensions(m_dimensions, nodeInput);
m_reader->SetData(&inputs, &m_dimensions);
m_reader->SetBoundary(m_start);
2014-08-30 03:21:42 +04:00
// create the reader if necessary
if (m_writer == nullptr)
{
m_writer = new EvalWriter<ElemType>(config);
}
// now set the data in the reader
GetNodeDimensions(m_dimensions, nodeOutput);
m_writer->SetData(&outputs, &m_dimensions);
// call the evaluator
SimpleOutputWriter<ElemType> eval(m_net);
2014-08-30 03:21:42 +04:00
eval.WriteOutput(*m_reader, minibatchSize, *m_writer, outNodeNames);
}
// ResetState - Reset the cell state when we get start of an utterance
template <class ElemType>
2014-08-30 03:21:42 +04:00
void CNTKEval<ElemType>::ResetState()
{
m_start = 1 - m_start;
2014-08-30 03:21:42 +04:00
}
// explicit instantiations for the element types we expect to be used:
// GetEvalD creates the double variant and GetEvalF the float variant
template class CNTKEval<double>;
template class CNTKEval<float>;
} } }