CNTK v2 library: Use dynamic loading of CompositeDataReader instead of a static dependency

This commit is contained in:
Amit Agarwal 2016-07-24 13:02:55 -07:00 committed by Amit
Parent 5b1c217688
Commit 4164856b1f
16 changed files with 158 additions and 91 deletions
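
In outline: CNTKv2LibraryDll previously constructed CompositeDataReader directly, which required the reader's header and import library at build time. After this change the library resolves an extern "C" factory from the CompositeDataReader module at runtime, so only the abstract Reader interface is needed. Condensed from the hunks below:

    // Before: static dependency on the concrete reader class.
    #include "CompositeDataReader.h"
    m_compositeDataReader.reset(new CompositeDataReader(config, std::make_shared<HeapMemoryProvider>()));

    // After: resolve the factory by name at runtime; no compile-time dependency.
    typedef Reader*(*CreateCompositeDataReaderProc)(const ConfigParameters* parameters);
    CreateCompositeDataReaderProc createReaderProc =
        (CreateCompositeDataReaderProc)Plugin().Load(L"CompositeDataReader", "CreateCompositeDataReader");
    m_compositeDataReader.reset(createReaderProc(&config));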

View File

@@ -1118,7 +1118,6 @@ EndProject
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "CNTKv2LibraryDll", "Source\CNTKv2LibraryDll\CNTKv2LibraryDll.vcxproj", "{E5606ECE-48CA-4464-BB12-09D81D02B9EF}"
ProjectSection(ProjectDependencies) = postProject
{928ABD1B-4D3B-4017-AEF1-0FA1B4467513} = {928ABD1B-4D3B-4017-AEF1-0FA1B4467513}
{7B7A563D-AA8E-4660-A805-D50235A02120} = {7B7A563D-AA8E-4660-A805-D50235A02120}
{60BDB847-D0C4-4FD3-A947-0C15C08BCDB5} = {60BDB847-D0C4-4FD3-A947-0C15C08BCDB5}
{86883653-8A61-4038-81A0-2379FAE4200A} = {86883653-8A61-4038-81A0-2379FAE4200A}
{F0A9637C-20DA-42F0-83D4-23B4704DE602} = {F0A9637C-20DA-42F0-83D4-23B4704DE602}

View File

@@ -371,13 +371,14 @@ CNTKLIBRARY_SRC =\
$(SOURCEDIR)/CNTKv2LibraryDll/BackCompat.cpp \
$(SOURCEDIR)/CNTKv2LibraryDll/Common.cpp \
$(SOURCEDIR)/CNTKv2LibraryDll/Function.cpp \
$(SOURCEDIR)/CNTKv2LibraryDll/MinibatchSource.cpp \
$(SOURCEDIR)/CNTKv2LibraryDll/NDArrayView.cpp \
$(SOURCEDIR)/CNTKv2LibraryDll/NDMask.cpp \
$(SOURCEDIR)/CNTKv2LibraryDll/Trainer.cpp \
$(SOURCEDIR)/CNTKv2LibraryDll/Utils.cpp \
$(SOURCEDIR)/CNTKv2LibraryDll/Value.cpp \
$(SOURCEDIR)/CNTKv2LibraryDll/Variable.cpp \
$(SOURCEDIR)/CNTKv2LibraryDll/Learner.cpp \
CNTKLIBRARY_SRC+=$(CNTK_COMMON_SRC)
CNTKLIBRARY_SRC+=$(COMPUTATION_NETWORK_LIB_SRC)
@@ -408,6 +409,7 @@ CNTKLIBRARY_TESTS_SRC =\
Tests/UnitTests/V2LibraryTests/NDArrayViewTests.cpp \
Tests/UnitTests/V2LibraryTests/RecurrentFunctionTests.cpp \
Tests/UnitTests/V2LibraryTests/TensorTests.cpp \
Tests/UnitTests/V2LibraryTests/TrainerTests.cpp \
CNTKLIBRARY_TESTS:=$(BINDIR)/v2librarytests
CNTKLIBRARY_TESTS_OBJ := $(patsubst %.cu, $(OBJDIR)/%.o, $(patsubst %.cpp, $(OBJDIR)/%.o, $(CNTKLIBRARY_TESTS_SRC)))

View File

@@ -102,9 +102,15 @@ namespace CNTK
// RuntimeError - throw a std::runtime_error with a formatted error string
#ifndef _MSC_VER // gcc __attribute__((format(printf())) does not percolate through variadic templates; so must go the macro route
#ifndef RuntimeError
#define RuntimeError ThrowFormatted<std::runtime_error>
#endif
#ifndef LogicError
#define LogicError ThrowFormatted<std::logic_error>
#endif
#ifndef InvalidArgument
#define InvalidArgument ThrowFormatted<std::invalid_argument>
#endif
#else
template <class... _Types>
__declspec_noreturn inline void RuntimeError(const char* format, _Types&&... _Args)
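
A note on the macro route shown above: on non-MSVC builds RuntimeError, LogicError, and InvalidArgument expand to ThrowFormatted<E>, which forwards printf-style arguments. %S therefore expects a raw wide-character pointer, which is why several call sites in this commit gain .c_str(). A brief illustration of the call pattern (axisName and its value are hypothetical):

    std::wstring axisName = L"myAxis";                              // hypothetical value
    LogicError("Found dynamic axis named '%S'", axisName.c_str());  // correct: raw wchar_t* through varargs
    // LogicError("Found dynamic axis named '%S'", axisName);       // wrong: passing std::wstring through
    //                                                              // varargs is undefined behavior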

View File

@@ -59,7 +59,7 @@ namespace CNTK
// TODO: Currently only default dynamic axis is supported
const std::wstring defaultCNTKDynamicAxisName = L"";
if (inputNode->GetRequestedDynamicAxis() != defaultCNTKDynamicAxisName)
LogicError("Found dynamic axis named '%S' while currently only default dynamic axis named '%S' is supported!", node->GetMBLayout()->GetAxisName(), defaultCNTKDynamicAxisName);
LogicError("Found dynamic axis named '%S' while currently only default dynamic axis named '%S' is supported!", node->GetMBLayout()->GetAxisName(), defaultCNTKDynamicAxisName.c_str());
var = Variable(varShape, isSparse, AsDataType<ElementType>(), node->GetLearningRateMultiplier() != 0, node->GetName());
}

View File

@@ -56,7 +56,7 @@
</PropertyGroup>
<ItemDefinitionGroup>
<ClCompile>
<AdditionalIncludeDirectories>.\API;$(SolutionDir)Source\SGDLib;$(SolutionDir)Source\Readers\ReaderLib;$(SolutionDir)Source\Readers\CompositeDataReader;$(SolutionDir)Source\ComputationNetworkLib;$(SolutionDir)Source\SequenceTrainingLib;$(SolutionDir)Source\Math;$(SolutionDir)Source\Common\Include;$(SolutionDir)Source\CNTK\BrainScript;$(SolutionDir)Source\ActionsLib;$(MSMPI_INC);$(NvmlInclude)</AdditionalIncludeDirectories>
<AdditionalIncludeDirectories>.\API;$(SolutionDir)Source\SGDLib;$(SolutionDir)Source\Readers\ReaderLib;$(SolutionDir)Source\ComputationNetworkLib;$(SolutionDir)Source\SequenceTrainingLib;$(SolutionDir)Source\Math;$(SolutionDir)Source\Common\Include;$(SolutionDir)Source\CNTK\BrainScript;$(SolutionDir)Source\ActionsLib;$(MSMPI_INC);$(NvmlInclude)</AdditionalIncludeDirectories>
</ClCompile>
<Link>
<AdditionalLibraryDirectories>$(SolutionDir)Source\ComputationNetworkLib;$(SolutionDir)Source\Math;$(MSMPI_LIB64);$(SolutionDir)$(Platform)\$(Configuration);$(NvmlLibPath)</AdditionalLibraryDirectories>
@@ -75,7 +75,7 @@
<Link>
<SubSystem>Console</SubSystem>
<GenerateDebugInformation>true</GenerateDebugInformation>
<AdditionalDependencies>ComputationNetworkLib.lib; Math.lib; Common.lib; SequenceTrainingLib.lib; ReaderLib.lib; CompositeDataReader.lib; kernel32.lib; user32.lib; shell32.lib; %(AdditionalDependencies)</AdditionalDependencies>
<AdditionalDependencies>ComputationNetworkLib.lib; Math.lib; Common.lib; SequenceTrainingLib.lib; ReaderLib.lib; kernel32.lib; user32.lib; shell32.lib; %(AdditionalDependencies)</AdditionalDependencies>
<DelayLoadDLLs>Math.dll; nvml.dll; $(CudaRuntimeDll)</DelayLoadDLLs>
</Link>
</ItemDefinitionGroup>
@@ -99,7 +99,7 @@
<GenerateDebugInformation>true</GenerateDebugInformation>
<EnableCOMDATFolding>true</EnableCOMDATFolding>
<OptimizeReferences>true</OptimizeReferences>
<AdditionalDependencies>ComputationNetworkLib.lib; Math.lib; Common.lib; ReaderLib.lib; CompositeDataReader.lib; kernel32.lib; user32.lib; shell32.lib; SequenceTrainingLib.lib; %(AdditionalDependencies)</AdditionalDependencies>
<AdditionalDependencies>ComputationNetworkLib.lib; Math.lib; Common.lib; ReaderLib.lib; kernel32.lib; user32.lib; shell32.lib; SequenceTrainingLib.lib; %(AdditionalDependencies)</AdditionalDependencies>
<Profile>true</Profile>
<DelayLoadDLLs>Math.dll; nvml.dll; $(CudaRuntimeDll)</DelayLoadDLLs>
</Link>

View File

@@ -8,7 +8,6 @@
#include "Utils.h"
#include "Config.h"
#include "MinibatchSource.h"
#include "CompositeDataReader.h"
#include "HeapMemoryProvider.h"
#include "ReaderShim.h"
#include "Function.h"
@@ -39,7 +38,9 @@ namespace CNTK
m_epochSize = configuration[epochSizeConfigurationKey].GetValue<size_t>();
m_compositeDataReader.reset(new CompositeDataReader(config, std::make_shared<HeapMemoryProvider>()));
typedef Reader*(*CreateCompositeDataReaderProc)(const ConfigParameters* parameters);
CreateCompositeDataReaderProc createReaderProc = (CreateCompositeDataReaderProc)Plugin().Load(L"CompositeDataReader", "CreateCompositeDataReader");
m_compositeDataReader.reset(createReaderProc(&config));
auto compositeDataReaderStreamDescs = m_compositeDataReader->GetStreamDescriptions();
for (auto streamDesc : compositeDataReaderStreamDescs)
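
Plugin().Load is CNTK's portable wrapper over LoadLibrary/GetProcAddress on Windows and dlopen/dlsym elsewhere. For readers unfamiliar with the mechanics, here is a self-contained POSIX sketch of what such a lookup does; the Reader stand-in, the LoadReader helper, the library path, and the "CreateReader" symbol name are all illustrative, not CNTK's actual API:

    #include <dlfcn.h>
    #include <memory>
    #include <stdexcept>

    struct Reader { virtual ~Reader() = default; };            // stand-in interface
    typedef Reader* (*CreateReaderProc)(const char* config);   // must match the exporting DLL exactly

    std::shared_ptr<Reader> LoadReader(const char* libPath, const char* config)
    {
        void* module = dlopen(libPath, RTLD_NOW);              // load the shared library at runtime
        if (!module)
            throw std::runtime_error(dlerror());

        auto proc = (CreateReaderProc)dlsym(module, "CreateReader"); // resolve the extern "C" factory by name
        if (!proc)
            throw std::runtime_error(dlerror());

        return std::shared_ptr<Reader>(proc(config));          // wrap the raw pointer, as the code above does
    }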

View File

@@ -8,7 +8,7 @@
#include "stdafx.h"
#include "CNTKLibrary.h"
#include "Utils.h"
#include "CompositeDataReader.h"
#include "Reader.h"
namespace CNTK
{
@@ -23,7 +23,7 @@ namespace CNTK
private:
std::unordered_set<StreamInfo> m_streamInfos;
std::shared_ptr<Microsoft::MSR::CNTK::CompositeDataReader> m_compositeDataReader;
std::shared_ptr<Microsoft::MSR::CNTK::Reader> m_compositeDataReader;
bool m_startNewEpoch;
size_t m_nextEpochIndex;
size_t m_prevMinibatchSize;

View File

@@ -21,7 +21,7 @@ namespace CNTK
{
auto insertRetVal = learnerParameters.insert(parameter);
if (!insertRetVal.second)
InvalidArgument("Trainer::Trainer: Parameter named %S is covered by 2 different learners", parameter.Name());
InvalidArgument("Trainer::Trainer: Parameter named %S is covered by 2 different learners", parameter.Name().c_str());
}
}

View File

@@ -308,4 +308,11 @@ namespace CNTK
template void DictionaryValue::AllocateDataPtr<NDShape>(const NDShape& value);
template void DictionaryValue::AllocateDataPtr<vector<DictionaryValue>>(const vector<DictionaryValue>& value);
template void DictionaryValue::AllocateDataPtr<wstring>(const wstring& value);
template void DictionaryValue::AllocateDataPtr<Dictionary>(const Dictionary& value);
template void DictionaryValue::FreePtrAsType<NDShape>();
template void DictionaryValue::FreePtrAsType<vector<DictionaryValue>>();
template void DictionaryValue::FreePtrAsType<wstring>();
template void DictionaryValue::FreePtrAsType<Dictionary>();
}
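
The added lines above are explicit template instantiations: they force the compiler to emit AllocateDataPtr<Dictionary> and the FreePtrAsType specializations into this translation unit, so code in other files (and clients of the DLL) can link against them without seeing the template definitions. A generic illustration of the idiom, with a hypothetical function name:

    #include <string>

    template <typename T>
    void AllocateDataPtrExample(const T& value)   // definition lives in one .cpp only
    {
        T* copy = new T(value);                   // simplified stand-in for the real AllocateDataPtr
        delete copy;
    }

    // Explicit instantiations: emit the bodies here so other translation units
    // can declare and call these specializations without the definition.
    template void AllocateDataPtrExample<int>(const int&);
    template void AllocateDataPtrExample<std::wstring>(const std::wstring&);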

View File

@@ -84,9 +84,15 @@ __declspec_noreturn static inline void ThrowFormatted(const char* format, ...)
// RuntimeError - throw a std::runtime_error with a formatted error string
#ifndef _MSC_VER // gcc __attribute__((format(printf())) does not percolate through variadic templates; so must go the macro route
#ifndef RuntimeError
#define RuntimeError ThrowFormatted<std::runtime_error>
#endif
#ifndef LogicError
#define LogicError ThrowFormatted<std::logic_error>
#endif
#ifndef InvalidArgument
#define InvalidArgument ThrowFormatted<std::invalid_argument>
#endif
#else
template <class... _Types>
__declspec_noreturn static inline void RuntimeError(const char* format, _Types&&... _Args)

View File

@@ -8,7 +8,6 @@
#include <map>
#include <string>
#include <future>
#define DATAREADER_EXPORTS
#include "DataReader.h"
#include "Reader.h"
#include "Transformer.h"
@@ -56,7 +55,7 @@ struct Minibatch;
class CompositeDataReader : public Reader, protected Plugin
{
public:
DATAREADER_API CompositeDataReader(const ConfigParameters& parameters, MemoryProviderPtr provider);
CompositeDataReader(const ConfigParameters& parameters, MemoryProviderPtr provider);
// Describes the streams this reader produces.
std::vector<StreamDescriptionPtr> GetStreamDescriptions() override;

View File

@@ -45,4 +45,9 @@ extern "C" DATAREADER_API void GetReaderD(IDataReader** preader)
*preader = new CompositeReaderShim<double>(factory);
}
extern "C" DATAREADER_API Reader* CreateCompositeDataReader(const ConfigParameters* parameters)
{
return new CompositeDataReader(*parameters, std::make_shared<HeapMemoryProvider>());
}
}}}
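
The extern "C" linkage on the new factory is what makes the runtime lookup work: it suppresses C++ name mangling, so the symbol is exported under the literal name "CreateCompositeDataReader" that Plugin().Load passes to GetProcAddress/dlsym. A brief illustration (the struct declarations are stand-ins for the real Reader.h/Config.h types, and the mangled name shown is just a GCC example):

    struct Reader;                // stand-ins for the real types
    struct ConfigParameters;

    // Exported as the literal symbol "CreateCompositeDataReader":
    extern "C" Reader* CreateCompositeDataReader(const ConfigParameters* parameters);

    // With C++ linkage the same signature would instead export a compiler-specific
    // mangled name, e.g. _Z25CreateCompositeDataReaderPK16ConfigParameters (GCC),
    // which a lookup by plain name would never find.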

View File

@@ -1,7 +1,39 @@
#!/bin/bash
if [ "$OS" == "Windows_NT" ]; then
$TEST_BIN_DIR/V2LibraryTests.exe || exit $?
else
$TEST_BIN_DIR/v2librarytests || exit $?
. $TEST_ROOT_DIR/run-test-common
# This test uses a large dataset which is not part of the CNTK repository itself
# We use the dataset from an external location specified using an environment variable
if [[ "$CNTK_EXTERNAL_TESTDATA_SOURCE_DIRECTORY" == "" || ! -d "$CNTK_EXTERNAL_TESTDATA_SOURCE_DIRECTORY" ]]; then
echo 'This test uses external data that is not part of the CNTK repository. Environment variable CNTK_EXTERNAL_TESTDATA_SOURCE_DIRECTORY must be set to point to the external test data location'
exit 1
fi
if [ "$OS" == "Windows_NT" ]; then
DataSourceDir=`cygpath -au $CNTK_EXTERNAL_TESTDATA_SOURCE_DIRECTORY`/Image/MNIST/v0
else
DataSourceDir=$CNTK_EXTERNAL_TESTDATA_SOURCE_DIRECTORY/Image/MNIST/v0
fi
# Copy the test data to the test run directory
DataDir=$TEST_RUN_DIR/TestData
mkdir $DataDir
cp -R $DataSourceDir/Train-28x28_cntk_text.txt $DataDir || exit $?
cp -R $TEST_DIR/../../../../Examples/Other/Simple2d/Data/SimpleDataTrain_cntk_text.txt $DataDir || exit $?
pushd $DataDir
if [ "$OS" == "Windows_NT" ]; then
$TEST_BIN_DIR/V2LibraryTests.exe
else
$TEST_BIN_DIR/v2librarytests
fi
ExitCode=$?
# Delete the test data
popd
rm -rf $DataDir
exit $ExitCode

View File

@@ -1,4 +1,4 @@
#!/usr/bin/env python
# ----------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

View File

@@ -24,34 +24,52 @@ inline void FloatingPointVectorCompare(const std::vector<ElementType>& first, co
#pragma warning(push)
#pragma warning(disable: 4996)
#ifndef _MSC_VER
#include <unistd.h>
static inline std::string wtocharpath(const wchar_t *p)
{
size_t len = wcslen(p);
std::string buf;
buf.resize(2 * len + 1); // max: 1 wchar => 2 mb chars
::wcstombs(&buf[0], p, buf.size()); // note: technically it is forbidden to stomp over std::strings 0 terminator, but it is known to work in all implementations
buf.resize(strlen(&buf[0])); // set size correctly for shorter strings
return buf;
}
static inline int _wunlink(const wchar_t *p)
{
return unlink(wtocharpath(p).c_str());
}
#endif
template <typename ElementType>
inline void SaveAndReloadModel(CNTK::FunctionPtr& functionPtr, const std::vector<CNTK::Variable*>& variables, const CNTK::DeviceDescriptor& device)
{
static std::wstring s_tempModelPath = L"feedForward.net";
if ((_wunlink(s_tempModelPath.c_str()) != 0) && (errno != ENOENT))
RuntimeError("Error deleting file '%ls': %s", s_tempModelPath.c_str(), strerror(errno));
std::runtime_error("Error deleting temp model file 'feedForward.net'");
std::unordered_map<std::wstring, Variable*> inputVarNames;
std::unordered_map<std::wstring, Variable*> outputVarNames;
std::unordered_map<std::wstring, CNTK::Variable*> inputVarNames;
std::unordered_map<std::wstring, CNTK::Variable*> outputVarNames;
for (auto varPtr : variables)
{
auto retVal = varPtr->IsOutput() ? outputVarNames.insert({ varPtr->Owner()->Name(), varPtr }) : inputVarNames.insert({ varPtr->Name(), varPtr });
if (!retVal.second)
RuntimeError("SaveAndReloadModel: Multiple variables having same name cannot be restored after save and reload");
std::runtime_error("SaveAndReloadModel: Multiple variables having same name cannot be restored after save and reload");
}
SaveAsLegacyModel<ElementType>(functionPtr, s_tempModelPath);
functionPtr = LoadLegacyModel<ElementType>(s_tempModelPath, device);
CNTK::SaveAsLegacyModel<ElementType>(functionPtr, s_tempModelPath);
functionPtr = CNTK::LoadLegacyModel<ElementType>(s_tempModelPath, device);
if (_wunlink(s_tempModelPath.c_str()) != 0)
RuntimeError("Error deleting file '%ls': %s", s_tempModelPath.c_str(), strerror(errno));
std::runtime_error("Error deleting temp model file 'feedForward.net'");
auto inputs = functionPtr->Inputs();
for (auto inputVarInfo : inputVarNames)
{
auto newInputVar = *(std::find_if(inputs.begin(), inputs.end(), [inputVarInfo](const Variable& var) {
auto newInputVar = *(std::find_if(inputs.begin(), inputs.end(), [inputVarInfo](const CNTK::Variable& var) {
return (var.Name() == inputVarInfo.first);
}));
@@ -61,7 +79,7 @@ inline void SaveAndReloadModel(CNTK::FunctionPtr& functionPtr, const std::vector
auto outputs = functionPtr->Outputs();
for (auto outputVarInfo : outputVarNames)
{
auto newOutputVar = *(std::find_if(outputs.begin(), outputs.end(), [outputVarInfo](const Variable& var) {
auto newOutputVar = *(std::find_if(outputs.begin(), outputs.end(), [outputVarInfo](const CNTK::Variable& var) {
return (var.Owner()->Name() == outputVarInfo.first);
}));

View File

@@ -6,6 +6,43 @@ using namespace CNTK;
using namespace std::placeholders;
MinibatchSourcePtr CreateTextMinibatchSource(const std::wstring& filePath, size_t featureDim, size_t labelDim, size_t epochSize)
{
Dictionary featuresStreamConfig;
featuresStreamConfig[L"dim"] = featureDim;
featuresStreamConfig[L"format"] = L"dense";
Dictionary labelsStreamConfig;
labelsStreamConfig[L"dim"] = labelDim;
labelsStreamConfig[L"format"] = L"dense";
Dictionary inputStreamsConfig;
inputStreamsConfig[L"features"] = featuresStreamConfig;
inputStreamsConfig[L"labels"] = labelsStreamConfig;
Dictionary deserializerConfiguration;
deserializerConfiguration[L"type"] = L"CNTKTextFormatDeserializer";
deserializerConfiguration[L"module"] = L"CNTKTextFormatReader";
deserializerConfiguration[L"file"] = filePath;
deserializerConfiguration[L"input"] = inputStreamsConfig;
Dictionary minibatchSourceConfiguration;
minibatchSourceConfiguration[L"epochSize"] = epochSize;
minibatchSourceConfiguration[L"deserializers"] = std::vector<DictionaryValue>({ deserializerConfiguration });
return CreateCompositeMinibatchSource(minibatchSourceConfiguration);
}
float PrevMinibatchTrainingLossValue(const Trainer& trainer)
{
float trainLossValue = 0.0;
auto prevMBTrainingLossValue = trainer.PreviousMinibatchTrainingLossValue()->Data();
NDArrayView cpuTrainLossValue(prevMBTrainingLossValue->Shape(), &trainLossValue, 1, DeviceDescriptor::CPUDevice());
cpuTrainLossValue.CopyFrom(*prevMBTrainingLossValue);
return trainLossValue;
}
void TrainSimpleFeedForwardClassifer(const DeviceDescriptor& device)
{
const size_t inputDim = 2;
@@ -34,52 +71,28 @@ void TrainSimpleFeedForwardClassifer(const DeviceDescriptor& device)
const size_t numSweepsToTrainWith = 2;
const size_t numMinibatchesToTrain = (numSamplesPerSweep * numSweepsToTrainWith) / minibatchSize;
Dictionary featuresStreamConfig;
featuresStreamConfig[L"dim"] = 2ULL;
featuresStreamConfig[L"format"] = L"dense";
Dictionary labelsStreamConfig;
labelsStreamConfig[L"dim"] = 2ULL;
labelsStreamConfig[L"format"] = L"dense";
Dictionary inputStreamsConfig;
inputStreamsConfig[L"features"] = featuresStreamConfig;
inputStreamsConfig[L"labels"] = labelsStreamConfig;
Dictionary deserializerConfiguration;
deserializerConfiguration[L"type"] = L"CNTKTextFormatDeserializer";
deserializerConfiguration[L"module"] = L"CNTKTextFormatReader";
deserializerConfiguration[L"file"] = L"SimpleDataTrain_cntk_text.txt";
deserializerConfiguration[L"input"] = inputStreamsConfig;
Dictionary minibatchSourceConfiguration;
minibatchSourceConfiguration[L"epochSize"] = numSamplesPerSweep;
minibatchSourceConfiguration[L"deserializers"] = std::vector<DictionaryValue>({ deserializerConfiguration });
auto minibatchSource = CreateCompositeMinibatchSource(minibatchSourceConfiguration);
auto minibatchSource = CreateTextMinibatchSource(L"SimpleDataTrain_cntk_text.txt", (size_t)2, (size_t)2, numSamplesPerSweep);
auto streamInfos = minibatchSource->StreamInfos();
auto featureStreamInfo = std::find_if(streamInfos.begin(), streamInfos.end(), [](const StreamInfo& streamInfo) {
return (streamInfo.m_name == L"features");
});
auto labelStreamInfo = std::find_if(streamInfos.begin(), streamInfos.end(), [](const StreamInfo& streamInfo) {
return (streamInfo.m_name == L"labels");
});
auto featureStreamInfo = std::find_if(streamInfos.begin(), streamInfos.end(), [](const StreamInfo& streamInfo) { return (streamInfo.m_name == L"features"); });
auto labelStreamInfo = std::find_if(streamInfos.begin(), streamInfos.end(), [](const StreamInfo& streamInfo) { return (streamInfo.m_name == L"labels"); });
double learningRatePerSample = 0.02;
Trainer trainer(oneHiddenLayerClassifier, trainingLoss, { SGDLearner(oneHiddenLayerClassifier->Parameters(), learningRatePerSample) });
std::unordered_map<StreamInfo, std::pair<size_t, ValuePtr>> minibatchData = { { *featureStreamInfo, { minibatchSize, nullptr } }, { *labelStreamInfo, { minibatchSize, nullptr } } };
size_t outputFrequencyInMinibatches = 20;
for (size_t i = 0; i < numMinibatchesToTrain; ++i)
{
minibatchSource->GetNextMinibatch(minibatchData);
trainer.TrainMinibatch({ { input, minibatchData[*featureStreamInfo].second }, { labels, minibatchData[*labelStreamInfo].second } }, device);
float trainLossValue = 0.0;
auto prevMBTrainingLossValue = trainer.PreviousMinibatchTrainingLossValue()->Data();
NDArrayView cpuTrainLossValue(prevMBTrainingLossValue->Shape(), &trainLossValue, 1, DeviceDescriptor::CPUDevice());
cpuTrainLossValue.CopyFrom(*prevMBTrainingLossValue);
printf("Minibatch %d: CrossEntropy loss = %.8g\n", i, trainLossValue);
if ((i % outputFrequencyInMinibatches) == 0)
{
float trainLossValue = PrevMinibatchTrainingLossValue(trainer);
printf("Minibatch %d: CrossEntropy loss = %.8g\n", (int)i, trainLossValue);
}
}
}
void TrainMNISTClassifier(const DeviceDescriptor& device)
{
@@ -105,30 +118,7 @@ void TrainMNISTClassifier(const DeviceDescriptor& device)
const size_t numSweepsToTrainWith = 3;
const size_t numMinibatchesToTrain = (numSamplesPerSweep * numSweepsToTrainWith) / minibatchSize;
Dictionary featuresStreamConfig;
featuresStreamConfig[L"dim"] = 784ULL;
featuresStreamConfig[L"format"] = L"dense";
Dictionary labelsStreamConfig;
labelsStreamConfig[L"dim"] = 10ULL;
labelsStreamConfig[L"format"] = L"dense";
Dictionary inputStreamsConfig;
inputStreamsConfig[L"features"] = featuresStreamConfig;
inputStreamsConfig[L"labels"] = labelsStreamConfig;
Dictionary deserializerConfiguration;
deserializerConfiguration[L"type"] = L"CNTKTextFormatDeserializer";
deserializerConfiguration[L"module"] = L"CNTKTextFormatReader";
deserializerConfiguration[L"file"] = L"Train-28x28_cntk_text.txt";
deserializerConfiguration[L"input"] = inputStreamsConfig;
Dictionary minibatchSourceConfiguration;
minibatchSourceConfiguration[L"randomize"] = true;
minibatchSourceConfiguration[L"epochSize"] = numSamplesPerSweep;
minibatchSourceConfiguration[L"deserializers"] = std::vector<DictionaryValue>({ deserializerConfiguration });
auto minibatchSource = CreateCompositeMinibatchSource(minibatchSourceConfiguration);
auto minibatchSource = CreateTextMinibatchSource(L"Train-28x28_cntk_text.txt", (size_t)784, (size_t)10, numSamplesPerSweep);
auto streamInfos = minibatchSource->StreamInfos();
auto featureStreamInfo = std::find_if(streamInfos.begin(), streamInfos.end(), [](const StreamInfo& streamInfo) {
@@ -141,15 +131,17 @@ void TrainMNISTClassifier(const DeviceDescriptor& device)
double learningRatePerSample = 0.003125;
Trainer trainer(oneHiddenLayerClassifier, trainingLoss, { SGDLearner(oneHiddenLayerClassifier->Parameters(), learningRatePerSample) });
std::unordered_map<StreamInfo, std::pair<size_t, ValuePtr>> minibatchData = { { *featureStreamInfo, { minibatchSize, nullptr } }, { *labelStreamInfo, { minibatchSize, nullptr } } };
size_t outputFrequencyInMinibatches = 20;
for (size_t i = 0; i < numMinibatchesToTrain; ++i)
{
minibatchSource->GetNextMinibatch(minibatchData);
trainer.TrainMinibatch({ { input, minibatchData[*featureStreamInfo].second }, { labels, minibatchData[*labelStreamInfo].second } }, device);
float trainLossValue = 0.0;
auto prevMBTrainingLossValue = trainer.PreviousMinibatchTrainingLossValue()->Data();
NDArrayView cpuTrainLossValue(prevMBTrainingLossValue->Shape(), &trainLossValue, 1, DeviceDescriptor::CPUDevice());
cpuTrainLossValue.CopyFrom(*prevMBTrainingLossValue);
printf("Minibatch %d: CrossEntropy loss = %.8g\n", i, trainLossValue);
if ((i % outputFrequencyInMinibatches) == 0)
{
float trainLossValue = PrevMinibatchTrainingLossValue(trainer);
printf("Minibatch %d: CrossEntropy loss = %.8g\n", (int)i, trainLossValue);
}
}
}