2016-01-10 05:13:31 +03:00
|
|
|
//
|
2016-01-18 11:36:17 +03:00
|
|
|
// Copyright (c) Microsoft. All rights reserved.
|
|
|
|
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
|
|
|
|
//
|
|
|
|
// EvalActions.cpp -- CNTK evaluation-related actions
|
2016-01-10 05:13:31 +03:00
|
|
|
//
|
|
|
|
|
2016-01-18 11:36:14 +03:00
|
|
|
#define _CRT_NONSTDC_NO_DEPRECATE // make VS accept POSIX functions without _
|
2016-01-10 05:13:31 +03:00
|
|
|
|
|
|
|
#include "stdafx.h"
|
|
|
|
#include "Basics.h"
|
|
|
|
#include "Actions.h"
|
|
|
|
#include "ComputationNetwork.h"
|
|
|
|
#include "ComputationNode.h"
|
|
|
|
#include "DataReader.h"
|
|
|
|
#include "DataWriter.h"
|
|
|
|
#include "Config.h"
|
|
|
|
#include "SimpleEvaluator.h"
|
|
|
|
#include "SimpleOutputWriter.h"
|
2016-04-13 09:09:08 +03:00
|
|
|
#include "Criterion.h"
|
2016-01-10 05:13:31 +03:00
|
|
|
#include "BestGpu.h"
|
|
|
|
#include "ScriptableObjects.h"
|
|
|
|
#include "BrainScriptEvaluator.h"
|
|
|
|
|
|
|
|
#include <string>
|
|
|
|
#include <chrono>
|
|
|
|
#include <algorithm>
|
|
|
|
#include <vector>
|
|
|
|
#include <iostream>
|
|
|
|
#include <queue>
|
|
|
|
#include <set>
|
|
|
|
#include <memory>
|
|
|
|
|
|
|
|
#ifndef let
|
|
|
|
#define let const auto
|
|
|
|
#endif
|
|
|
|
|
|
|
|
using namespace std;
|
|
|
|
using namespace Microsoft::MSR;
|
|
|
|
using namespace Microsoft::MSR::CNTK;
|
|
|
|
|
2016-10-02 17:00:55 +03:00
|
|
|
bool GetDistributedMBReadingDefaultValue(const ConfigParameters& config, const IDataReader& reader)
|
|
|
|
{
|
|
|
|
// Return 'true' if we're running a parallel training with a v2 reader, 'false' otherwise.
|
|
|
|
return (MPIWrapper::GetInstance() != nullptr && !reader.IsLegacyReader());
|
|
|
|
}
|
|
|
|
|
2016-01-10 05:13:31 +03:00
|
|
|
// ===========================================================================
|
|
|
|
// DoEvalBase() - implements CNTK "eval" command
|
|
|
|
// ===========================================================================
|
|
|
|
|
|
|
|
template <typename ElemType>
|
2016-02-29 06:01:07 +03:00
|
|
|
static void DoEvalBase(const ConfigParameters& config, IDataReader& reader)
|
2016-01-10 05:13:31 +03:00
|
|
|
{
|
2016-06-07 04:22:10 +03:00
|
|
|
//DEVICEID_TYPE deviceId = DeviceFromConfig(config);
|
2016-01-10 05:13:31 +03:00
|
|
|
ConfigArray minibatchSize = config(L"minibatchSize", "40960");
|
|
|
|
size_t epochSize = config(L"epochSize", "0");
|
|
|
|
if (epochSize == 0)
|
|
|
|
{
|
|
|
|
epochSize = requestDataSize;
|
|
|
|
}
|
|
|
|
wstring modelPath = config(L"modelPath");
|
|
|
|
intargvector mbSize = minibatchSize;
|
|
|
|
|
2016-09-05 01:56:55 +03:00
|
|
|
int traceLevel = config(L"traceLevel", 0);
|
2016-01-10 05:13:31 +03:00
|
|
|
size_t numMBsToShowResult = config(L"numMBsToShowResult", "100");
|
2016-06-02 14:46:10 +03:00
|
|
|
size_t firstMBsToShowResult = config(L"firstMBsToShowResult", "0");
|
2016-02-25 12:56:59 +03:00
|
|
|
size_t maxSamplesInRAM = config(L"maxSamplesInRAM", (size_t)SIZE_MAX);
|
|
|
|
size_t numSubminiBatches = config(L"numSubminibatches", (size_t)1);
|
2016-01-10 05:13:31 +03:00
|
|
|
|
2016-10-02 17:00:55 +03:00
|
|
|
bool enableDistributedMBReading = config(L"distributedMBReading", GetDistributedMBReadingDefaultValue(config, reader));
|
2016-04-26 05:13:47 +03:00
|
|
|
|
2016-06-07 04:22:10 +03:00
|
|
|
vector<wstring> evalNodeNamesVector;
|
|
|
|
|
|
|
|
let net = GetModelFromConfig<ConfigParameters, ElemType>(config, L"evalNodeNames", evalNodeNamesVector);
|
|
|
|
|
2016-03-30 01:32:54 +03:00
|
|
|
// set tracing flags
|
|
|
|
net->EnableNodeTracing(config(L"traceNodeNamesReal", ConfigParameters::Array(stringargvector())),
|
|
|
|
config(L"traceNodeNamesCategory", ConfigParameters::Array(stringargvector())),
|
|
|
|
config(L"traceNodeNamesSparse", ConfigParameters::Array(stringargvector())));
|
|
|
|
|
2016-06-02 14:46:10 +03:00
|
|
|
SimpleEvaluator<ElemType> eval(net, MPIWrapper::GetInstance(), enableDistributedMBReading, numMBsToShowResult,
|
2016-06-07 04:22:10 +03:00
|
|
|
firstMBsToShowResult, traceLevel, maxSamplesInRAM, numSubminiBatches);
|
2016-01-10 05:13:31 +03:00
|
|
|
eval.Evaluate(&reader, evalNodeNamesVector, mbSize[0], epochSize);
|
|
|
|
}
|
|
|
|
|
|
|
|
template <typename ElemType>
|
|
|
|
void DoEval(const ConfigParameters& config)
|
|
|
|
{
|
2016-01-23 00:58:47 +03:00
|
|
|
// test
|
2016-01-10 05:13:31 +03:00
|
|
|
ConfigParameters readerConfig(config(L"reader"));
|
|
|
|
readerConfig.Insert("traceLevel", config(L"traceLevel", "0"));
|
2016-09-20 01:57:47 +03:00
|
|
|
if (!readerConfig.ExistsCurrent(L"randomize"))
|
|
|
|
{
|
|
|
|
readerConfig.Insert("randomize", "None");
|
|
|
|
}
|
2016-01-10 05:13:31 +03:00
|
|
|
|
2016-02-29 06:01:07 +03:00
|
|
|
DataReader testDataReader(readerConfig);
|
|
|
|
DoEvalBase<ElemType>(config, testDataReader);
|
2016-01-10 05:13:31 +03:00
|
|
|
}
|
|
|
|
|
|
|
|
// Explicit instantiations for the two element types CNTK supports.
template void DoEval<double>(const ConfigParameters& config);
template void DoEval<float>(const ConfigParameters& config);
|
|
|
|
|
|
|
|
// ===========================================================================
|
|
|
|
// DoCrossValidate() - implements CNTK "cv" command
|
|
|
|
// ===========================================================================
|
|
|
|
|
|
|
|
template <typename ElemType>
|
|
|
|
void DoCrossValidate(const ConfigParameters& config)
|
|
|
|
{
|
2016-01-23 00:58:47 +03:00
|
|
|
// test
|
2016-01-10 05:13:31 +03:00
|
|
|
ConfigParameters readerConfig(config(L"reader"));
|
|
|
|
readerConfig.Insert("traceLevel", config(L"traceLevel", "0"));
|
|
|
|
|
|
|
|
DEVICEID_TYPE deviceId = DeviceFromConfig(config);
|
|
|
|
ConfigArray minibatchSize = config(L"minibatchSize", "40960");
|
|
|
|
size_t epochSize = config(L"epochSize", "0");
|
|
|
|
if (epochSize == 0)
|
|
|
|
{
|
|
|
|
epochSize = requestDataSize;
|
|
|
|
}
|
|
|
|
wstring modelPath = config(L"modelPath");
|
|
|
|
intargvector mbSize = minibatchSize;
|
|
|
|
|
|
|
|
ConfigArray cvIntervalConfig = config(L"crossValidationInterval");
|
|
|
|
intargvector cvInterval = cvIntervalConfig;
|
|
|
|
|
|
|
|
size_t sleepSecondsBetweenRuns = config(L"sleepTimeBetweenRuns", "0");
|
|
|
|
|
2016-09-05 01:56:55 +03:00
|
|
|
int traceLevel = config(L"traceLevel", 0);
|
2016-01-10 05:13:31 +03:00
|
|
|
size_t numMBsToShowResult = config(L"numMBsToShowResult", "100");
|
2016-06-06 11:09:36 +03:00
|
|
|
size_t firstMBsToShowResult = config(L"firstMBsToShowResult", "0");
|
2016-04-13 09:09:08 +03:00
|
|
|
size_t maxSamplesInRAM = config(L"maxSamplesInRAM", (size_t)SIZE_MAX);
|
|
|
|
size_t numSubminiBatches = config(L"numSubminibatches", (size_t)1);
|
2016-01-10 05:13:31 +03:00
|
|
|
|
|
|
|
ConfigArray evalNodeNames = config(L"evalNodeNames", "");
|
|
|
|
vector<wstring> evalNodeNamesVector;
|
|
|
|
for (int i = 0; i < evalNodeNames.size(); ++i)
|
|
|
|
{
|
|
|
|
evalNodeNamesVector.push_back(evalNodeNames[i]);
|
|
|
|
}
|
|
|
|
|
2016-04-13 09:09:08 +03:00
|
|
|
std::vector<std::vector<EpochCriterion>> cvErrorResults;
|
2016-01-10 05:13:31 +03:00
|
|
|
std::vector<std::wstring> cvModels;
|
|
|
|
|
2016-02-29 06:01:07 +03:00
|
|
|
DataReader cvDataReader(readerConfig);
|
2016-01-10 05:13:31 +03:00
|
|
|
|
2016-10-02 17:00:55 +03:00
|
|
|
bool enableDistributedMBReading = config(L"distributedMBReading", GetDistributedMBReadingDefaultValue(config, cvDataReader));
|
|
|
|
|
2016-01-10 05:13:31 +03:00
|
|
|
bool finalModelEvaluated = false;
|
|
|
|
for (size_t i = cvInterval[0]; i <= cvInterval[2]; i += cvInterval[1])
|
|
|
|
{
|
|
|
|
wstring cvModelPath = msra::strfun::wstrprintf(L"%ls.%lld", modelPath.c_str(), i);
|
|
|
|
|
|
|
|
if (!fexists(cvModelPath))
|
|
|
|
{
|
2016-04-13 09:09:08 +03:00
|
|
|
fprintf(stderr, "Model %ls does not exist.\n", cvModelPath.c_str());
|
2016-01-10 05:13:31 +03:00
|
|
|
if (finalModelEvaluated || !fexists(modelPath))
|
|
|
|
continue; // file missing
|
|
|
|
else
|
|
|
|
{
|
|
|
|
cvModelPath = modelPath;
|
|
|
|
finalModelEvaluated = true;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
cvModels.push_back(cvModelPath);
|
|
|
|
auto net = ComputationNetwork::CreateFromFile<ElemType>(deviceId, cvModelPath);
|
2016-06-07 04:22:10 +03:00
|
|
|
// BUGBUG: ^^ Should use GetModelFromConfig()
|
2016-10-17 15:07:22 +03:00
|
|
|
|
2016-06-06 11:09:36 +03:00
|
|
|
SimpleEvaluator<ElemType> eval(net, MPIWrapper::GetInstance(), enableDistributedMBReading, numMBsToShowResult,
|
|
|
|
firstMBsToShowResult, traceLevel, maxSamplesInRAM, numSubminiBatches);
|
2016-01-10 05:13:31 +03:00
|
|
|
|
2016-04-13 09:09:08 +03:00
|
|
|
fprintf(stderr, "Model %ls --> \n", cvModelPath.c_str());
|
2016-01-10 05:13:31 +03:00
|
|
|
auto evalErrors = eval.Evaluate(&cvDataReader, evalNodeNamesVector, mbSize[0], epochSize);
|
|
|
|
cvErrorResults.push_back(evalErrors);
|
|
|
|
|
|
|
|
::Sleep(1000 * sleepSecondsBetweenRuns);
|
|
|
|
}
|
|
|
|
|
2016-01-23 00:58:47 +03:00
|
|
|
// find best model
|
2016-01-10 05:13:31 +03:00
|
|
|
if (cvErrorResults.size() == 0)
|
|
|
|
LogicError("No model is evaluated.");
|
|
|
|
|
2016-04-13 09:09:08 +03:00
|
|
|
vector<double> minErrors;
|
|
|
|
vector<int> minErrIds;
|
|
|
|
vector<EpochCriterion> evalErrors = cvErrorResults[0];
|
2016-01-10 05:13:31 +03:00
|
|
|
for (int i = 0; i < evalErrors.size(); ++i)
|
|
|
|
{
|
2016-04-13 09:09:08 +03:00
|
|
|
minErrors.push_back(evalErrors[i].Average());
|
2016-01-10 05:13:31 +03:00
|
|
|
minErrIds.push_back(0);
|
|
|
|
}
|
|
|
|
|
2016-01-18 11:36:14 +03:00
|
|
|
for (int i = 0; i < cvErrorResults.size(); i++)
|
2016-01-10 05:13:31 +03:00
|
|
|
{
|
|
|
|
evalErrors = cvErrorResults[i];
|
2016-01-18 11:36:14 +03:00
|
|
|
for (int j = 0; j < evalErrors.size(); j++)
|
2016-01-10 05:13:31 +03:00
|
|
|
{
|
2016-04-13 09:09:08 +03:00
|
|
|
if (evalErrors[j].Average() < minErrors[j])
|
2016-01-10 05:13:31 +03:00
|
|
|
{
|
2016-04-13 09:09:08 +03:00
|
|
|
minErrors[j] = evalErrors[j].Average();
|
2016-01-10 05:13:31 +03:00
|
|
|
minErrIds[j] = i;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
fprintf(stderr, "Best models:\n");
|
|
|
|
fprintf(stderr, "------------\n");
|
|
|
|
for (int i = 0; i < minErrors.size(); ++i)
|
|
|
|
fprintf(stderr, "Based on Err[%d]: Best model = %ls with min err %.8g\n", i, cvModels[minErrIds[i]].c_str(), minErrors[i]);
|
|
|
|
}
|
|
|
|
|
|
|
|
// Explicit instantiations for the two element types CNTK supports.
template void DoCrossValidate<float>(const ConfigParameters& config);
template void DoCrossValidate<double>(const ConfigParameters& config);
|
|
|
|
|
|
|
|
// ===========================================================================
|
|
|
|
// DoWriteOutput() - implements CNTK "write" command
|
|
|
|
// ===========================================================================
|
|
|
|
|
|
|
|
template <typename ElemType>
|
|
|
|
void DoWriteOutput(const ConfigParameters& config)
|
|
|
|
{
|
|
|
|
ConfigParameters readerConfig(config(L"reader"));
|
2016-01-23 00:58:47 +03:00
|
|
|
readerConfig.Insert("randomize", "None"); // we don't want randomization when output results
|
2016-01-10 05:13:31 +03:00
|
|
|
|
2016-02-29 06:01:07 +03:00
|
|
|
DataReader testDataReader(readerConfig);
|
2016-01-10 05:13:31 +03:00
|
|
|
|
|
|
|
ConfigArray minibatchSize = config(L"minibatchSize", "2048");
|
|
|
|
intargvector mbSize = minibatchSize;
|
|
|
|
|
|
|
|
size_t epochSize = config(L"epochSize", "0");
|
|
|
|
if (epochSize == 0)
|
|
|
|
{
|
|
|
|
epochSize = requestDataSize;
|
|
|
|
}
|
|
|
|
|
|
|
|
vector<wstring> outputNodeNamesVector;
|
|
|
|
|
2016-06-07 04:22:10 +03:00
|
|
|
let net = GetModelFromConfig<ConfigParameters, ElemType>(config, L"outputNodeNames", outputNodeNamesVector);
|
2016-01-10 05:13:31 +03:00
|
|
|
|
2016-03-30 01:32:54 +03:00
|
|
|
// set tracing flags
|
|
|
|
net->EnableNodeTracing(config(L"traceNodeNamesReal", ConfigParameters::Array(stringargvector())),
|
|
|
|
config(L"traceNodeNamesCategory", ConfigParameters::Array(stringargvector())),
|
|
|
|
config(L"traceNodeNamesSparse", ConfigParameters::Array(stringargvector())));
|
|
|
|
|
2016-01-10 05:13:31 +03:00
|
|
|
SimpleOutputWriter<ElemType> writer(net, 1);
|
|
|
|
|
|
|
|
if (config.Exists("writer"))
|
|
|
|
{
|
|
|
|
ConfigParameters writerConfig(config(L"writer"));
|
2016-03-18 21:25:11 +03:00
|
|
|
bool writerUnittest = writerConfig(L"unittest", "false");
|
2016-02-29 06:01:07 +03:00
|
|
|
DataWriter testDataWriter(writerConfig);
|
2016-03-18 21:25:11 +03:00
|
|
|
writer.WriteOutput(testDataReader, mbSize[0], testDataWriter, outputNodeNamesVector, epochSize, writerUnittest);
|
2016-01-10 05:13:31 +03:00
|
|
|
}
|
|
|
|
else if (config.Exists("outputPath"))
|
|
|
|
{
|
2016-02-15 21:43:03 +03:00
|
|
|
wstring outputPath = config(L"outputPath");
|
2016-04-03 09:29:40 +03:00
|
|
|
WriteFormattingOptions formattingOptions(config);
|
2016-03-18 21:25:11 +03:00
|
|
|
bool nodeUnitTest = config(L"nodeUnitTest", "false");
|
|
|
|
writer.WriteOutput(testDataReader, mbSize[0], outputPath, outputNodeNamesVector, formattingOptions, epochSize, nodeUnitTest);
|
2016-01-10 05:13:31 +03:00
|
|
|
}
|
2016-02-15 21:43:03 +03:00
|
|
|
else
|
|
|
|
InvalidArgument("write command: You must specify either 'writer'or 'outputPath'");
|
2016-01-10 05:13:31 +03:00
|
|
|
}
|
|
|
|
|
|
|
|
// Explicit instantiations for the two element types CNTK supports.
template void DoWriteOutput<float>(const ConfigParameters& config);
template void DoWriteOutput<double>(const ConfigParameters& config);
|