CNTK/MachineLearning/cn/SimpleRecurrentNetEvaluator.h

#pragma once
#include "ComputationNetwork.h"
#include "DataReader.h"
#include <vector>
#include <string>
#include <stdexcept>
#include "basetypes.h"
#include "fileutil.h"
#include "commandArgUtil.h"
#include <Windows.h>
#include <WinBase.h>
#include <fstream>
using namespace std;
namespace Microsoft { namespace MSR { namespace CNTK {
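// Evaluator for a simple recurrent network: runs the network forward over a minibatch
// and reports per-sample cross entropy and evaluation error (see Evaluate below).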
template<class ElemType>
class RecurrentNetEvaluator : public SimpleEvaluator<ElemType>
{
public:
    RecurrentNetEvaluator(ComputationNetwork<ElemType>& net, const size_t numMBsToShowResult = 100)
        : SimpleEvaluator<ElemType>(net, numMBsToShowResult)
    {
    }
    // Returns the error rate; evalSetCrossEntropy receives the average cross entropy per sample.
    // ElemType Evaluate(IDataReader<ElemType>& dataReader, size_t mbSize, ElemType &evalSetCrossEntropy, const wchar_t* output=nullptr, size_t testSize=requestDataSize)
    ElemType Evaluate(IDataReader<ElemType>& dataReader, ElemType& evalSetCrossEntropy, const wchar_t* output = nullptr)
    {
        std::vector<ComputationNodePtr> FeatureNodes = m_net.FeatureNodes();
        std::vector<ComputationNodePtr> labelNodes = m_net.LabelNodes();
        std::vector<ComputationNodePtr> evaluationNodes = m_net.EvaluationNodes();
        std::list<ComputationNodePtr> crossEntropyNodes = m_net.GetNodesWithType(L"CrossEntropyWithSoftmax");

        if (crossEntropyNodes.size() == 0)
        {
            throw runtime_error("No CrossEntropyWithSoftmax node found\n");
        }
        if (evaluationNodes.size() == 0)
        {
            throw runtime_error("No Evaluation node found\n");
        }
        if (crossEntropyNodes.size() > 1)
        {
            throw runtime_error("Evaluate() does not yet support reading multiple CrossEntropyWithSoftmax nodes\n");
        }
        if (evaluationNodes.size() > 1)
        {
            throw runtime_error("Evaluate() does not yet support reading multiple Evaluation nodes\n");
        }
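        // Map each input node's name to its value matrix so a data reader can fill
        // features and labels in place (unused until the reader loop below is re-enabled).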
        std::map<std::wstring, Matrix<ElemType>*> inputMatrices;
        for (size_t i = 0; i < FeatureNodes.size(); i++)
        {
            inputMatrices[FeatureNodes[i]->NodeName()] = &FeatureNodes[i]->FunctionValues();
        }
        for (size_t i = 0; i < labelNodes.size(); i++)
        {
            inputMatrices[labelNodes[i]->NodeName()] = &labelNodes[i]->FunctionValues();
        }
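        // Accumulators for the overall run and for the last reporting block; the
        // reader-driven minibatch loop below is currently replaced by synthetic data.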
        // dataReader.StartMinibatchLoop(mbSize, 0, testSize);
        ElemType epochEvalError = 0;
        ElemType epochCrossEntropy = 0;
        size_t totalEpochSamples = 0;

        ElemType prevEpochEvalError = 0;
        ElemType prevEpochCrossEntropy = 0;
        size_t prevTotalEpochSamples = 0;
        size_t prevStart = 1;

        size_t numSamples = 0;
        ElemType crossEntropy = 0;
        ElemType evalError = 0;

        ofstream outputStream;
        if (output)
        {
            outputStream.open(output);
        }

        size_t numMBsRun = 0;
        size_t actualMBSize = 0;
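        // No reader input yet: fill the feature/label matrices with one synthetic
        // sentence of 10 samples (see GenerateOneSentence below).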
        GenerateOneSentence(FeatureNodes, labelNodes, 10);

        // while (dataReader.GetMinibatch(inputMatrices))
        while (true)
        {
            actualMBSize = labelNodes[0]->FunctionValues().GetNumCols();

            for (size_t i = 0; i < FeatureNodes.size(); i++)
            {
                FeatureNodes[i]->UpdateEvalTimeStamp();
            }
            for (size_t i = 0; i < labelNodes.size(); i++)
            {
                labelNodes[i]->UpdateEvalTimeStamp();
            }

            size_t npos = 0;
            for (auto nodeIter = crossEntropyNodes.begin(); nodeIter != crossEntropyNodes.end() && npos < 100; nodeIter++)
            {
                m_net.Evaluate(evaluationNodes[npos]);
                ElemType mbEvalError = evaluationNodes[npos]->FunctionValues().Get00Element(); // criterion node should be a scalar
                epochEvalError += mbEvalError;

                ComputationNodePtr crossEntropyNode = (*nodeIter);
                m_net.Evaluate(crossEntropyNode);
                ElemType mbCrossEntropy = crossEntropyNode->FunctionValues().Get00Element(); // criterion node should be a scalar
                epochCrossEntropy += mbCrossEntropy;

                totalEpochSamples += actualMBSize;
                npos++; // advance to the evaluation node paired with the next cross entropy node
            }
            break; // only a single (synthetic) minibatch is evaluated for now
        }
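        // Report the raw accumulated cross entropy, then optionally dump the first
        // output node's values to the requested file (one matrix column per line).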
        cout << "entropy = " << epochCrossEntropy << endl;

        if (outputStream.is_open())
        {
            // TODO: add support to dump multiple outputs
            ComputationNodePtr outputNode = m_net.OutputNodes()[0];
            foreach_column(j, outputNode->FunctionValues())
            {
                foreach_row(i, outputNode->FunctionValues())
                {
                    outputStream << outputNode->FunctionValues()(i, j) << " ";
                }
                outputStream << endl;
            }
        }
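        // Per-reporting-block and overall per-sample statistics; only one minibatch is
        // evaluated at present, so the block [1-1] covers the whole run.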
        numMBsRun++;

        // show final grouping of output
        numSamples = totalEpochSamples - prevTotalEpochSamples;
        crossEntropy = epochCrossEntropy - prevEpochCrossEntropy;
        evalError = epochEvalError - prevEpochEvalError;
        fprintf(stderr, "Minibatch[%lu-%lu]: Samples Evaluated = %lu EvalErr Per Sample = %.8g Loss Per Sample = %.8g\n",
                prevStart, numMBsRun, numSamples, evalError / numSamples, crossEntropy / numSamples);

        // final statistics
        epochEvalError /= (ElemType)totalEpochSamples;
        epochCrossEntropy /= (ElemType)totalEpochSamples;
        fprintf(stderr, "Overall: Samples Evaluated = %lu EvalErr Per Sample = %.8g Loss Per Sample = %.8g\n",
                totalEpochSamples, epochEvalError, epochCrossEntropy);

        if (outputStream.is_open())
        {
            outputStream.close();
        }

        evalSetCrossEntropy = epochCrossEntropy;
        return epochEvalError;
    }
    ElemType Evaluate(IDataReader<ElemType>& dataReader, size_t mbSize, const wchar_t* output = nullptr, size_t testSize = requestDataSize)
    {
        // mbSize and testSize are kept for interface compatibility; they are unused while
        // the reader-driven minibatch loop above is disabled.
        ElemType tmpCrossEntropy;
        return Evaluate(dataReader, tmpCrossEntropy, output);
    }
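    // Generates one synthetic sentence of nbrSamples samples: random feature values and a
    // one-hot label per position (hot index = position). Used above in place of real reader
    // input while the minibatch loop is disabled; assumes the node matrices already have at
    // least nbrSamples columns (one column per sample).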
    bool GenerateOneSentence(
        std::vector<ComputationNodePtr>& FeatureNodes,
        std::vector<ComputationNodePtr>& labelNodes,
        size_t nbrSamples)
    {
        for (size_t i = 0; i < nbrSamples; i++)
        {
            // sample i occupies column i of the (single) feature and label input matrices
            for (size_t d = 0; d < FeatureNodes[0]->FunctionValues().GetNumRows(); d++)
            {
                FeatureNodes[0]->FunctionValues()(d, i) = (ElemType)rand();
            }
            for (size_t d = 0; d < labelNodes[0]->FunctionValues().GetNumRows(); d++)
            {
                labelNodes[0]->FunctionValues()(d, i) = (ElemType)((d == i) ? 1 : 0);
            }
        }
        return true;
    }
};
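// Minimal usage sketch (hypothetical driver code; `net`, `reader`, `deviceId`, and the exact
// network-loading calls are assumptions and may differ from the surrounding tooling):
//
//     ComputationNetwork<float> net(deviceId);
//     net.LoadFromFile(modelPath);
//     RecurrentNetEvaluator<float> eval(net);
//     float crossEntropy = 0;
//     float errRate = eval.Evaluate(reader, crossEntropy, L"outputs.txt");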
}}}