CNTK v2 library: Use dynamic loading of CompositeDataReader instead of a static dependency

Parent: 5b1c217688
Commit: 4164856b1f
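Summary of the change (as reflected in the hunks below): CNTKv2LibraryDll previously linked CompositeDataReader.lib and declared a build-order dependency on the CompositeDataReader project. This commit removes that static dependency; MinibatchSource now loads the reader module at runtime via Plugin().Load, resolving a new extern "C" factory, CreateCompositeDataReader, exported by the reader, and holds the result through the abstract Microsoft::MSR::CNTK::Reader interface. Along the way it guards the RuntimeError/LogicError/InvalidArgument macros against redefinition, fixes several printf-style format-argument bugs (missing .c_str() on wide strings), adds one source file each to the library and test source lists in the Makefile, and extends the V2 library test script to stage MNIST and Simple2d data before running.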
CNTK.sln (1 changed line)

@@ -1118,7 +1118,6 @@ EndProject
 Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "CNTKv2LibraryDll", "Source\CNTKv2LibraryDll\CNTKv2LibraryDll.vcxproj", "{E5606ECE-48CA-4464-BB12-09D81D02B9EF}"
 	ProjectSection(ProjectDependencies) = postProject
 		{928ABD1B-4D3B-4017-AEF1-0FA1B4467513} = {928ABD1B-4D3B-4017-AEF1-0FA1B4467513}
 		{7B7A563D-AA8E-4660-A805-D50235A02120} = {7B7A563D-AA8E-4660-A805-D50235A02120}
 		{60BDB847-D0C4-4FD3-A947-0C15C08BCDB5} = {60BDB847-D0C4-4FD3-A947-0C15C08BCDB5}
 		{86883653-8A61-4038-81A0-2379FAE4200A} = {86883653-8A61-4038-81A0-2379FAE4200A}
 		{F0A9637C-20DA-42F0-83D4-23B4704DE602} = {F0A9637C-20DA-42F0-83D4-23B4704DE602}
Makefile (6 changed lines)

@@ -371,13 +371,14 @@ CNTKLIBRARY_SRC =\
 	$(SOURCEDIR)/CNTKv2LibraryDll/BackCompat.cpp \
 	$(SOURCEDIR)/CNTKv2LibraryDll/Common.cpp \
 	$(SOURCEDIR)/CNTKv2LibraryDll/Function.cpp \
+	$(SOURCEDIR)/CNTKv2LibraryDll/MinibatchSource.cpp \
 	$(SOURCEDIR)/CNTKv2LibraryDll/NDArrayView.cpp \
 	$(SOURCEDIR)/CNTKv2LibraryDll/NDMask.cpp \
 	$(SOURCEDIR)/CNTKv2LibraryDll/Trainer.cpp \
 	$(SOURCEDIR)/CNTKv2LibraryDll/Utils.cpp \
 	$(SOURCEDIR)/CNTKv2LibraryDll/Value.cpp \
 	$(SOURCEDIR)/CNTKv2LibraryDll/Variable.cpp \
 	$(SOURCEDIR)/CNTKv2LibraryDll/Learner.cpp \
 
 CNTKLIBRARY_SRC+=$(CNTK_COMMON_SRC)
 CNTKLIBRARY_SRC+=$(COMPUTATION_NETWORK_LIB_SRC)

@@ -408,6 +409,7 @@ CNTKLIBRARY_TESTS_SRC =\
 	Tests/UnitTests/V2LibraryTests/NDArrayViewTests.cpp \
 	Tests/UnitTests/V2LibraryTests/RecurrentFunctionTests.cpp \
 	Tests/UnitTests/V2LibraryTests/TensorTests.cpp \
+	Tests/UnitTests/V2LibraryTests/TrainerTests.cpp \
 
 CNTKLIBRARY_TESTS:=$(BINDIR)/v2librarytests
 CNTKLIBRARY_TESTS_OBJ := $(patsubst %.cu, $(OBJDIR)/%.o, $(patsubst %.cpp, $(OBJDIR)/%.o, $(CNTKLIBRARY_TESTS_SRC)))
@@ -102,9 +102,15 @@ namespace CNTK
 
     // RuntimeError - throw a std::runtime_error with a formatted error string
 #ifndef _MSC_VER // gcc __attribute__((format(printf())) does not percolate through variadic templates; so must go the macro route
+#ifndef RuntimeError
 #define RuntimeError ThrowFormatted<std::runtime_error>
+#endif
+#ifndef LogicError
 #define LogicError ThrowFormatted<std::logic_error>
+#endif
+#ifndef InvalidArgument
 #define InvalidArgument ThrowFormatted<std::invalid_argument>
+#endif
 #else
 template <class... _Types>
 __declspec_noreturn inline void RuntimeError(const char* format, _Types&&... _Args)
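The same guard pattern recurs in a second header below (hunk @@ -84,9 +84,15). The apparent motivation: on non-MSVC compilers both the public CNTK v2 API header and an internal common header define RuntimeError/LogicError/InvalidArgument as macros, so whichever is included second would otherwise redefine them; wrapping each definition in #ifndef lets the two headers coexist in one translation unit.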
@@ -59,7 +59,7 @@ namespace CNTK
         // TODO: Currently only default dynamic axis is supported
         const std::wstring defaultCNTKDynamicAxisName = L"";
         if (inputNode->GetRequestedDynamicAxis() != defaultCNTKDynamicAxisName)
-            LogicError("Found dynamic axis named '%S' while currently only default dynamic axis named '%S' is supported!", node->GetMBLayout()->GetAxisName(), defaultCNTKDynamicAxisName);
+            LogicError("Found dynamic axis named '%S' while currently only default dynamic axis named '%S' is supported!", node->GetMBLayout()->GetAxisName(), defaultCNTKDynamicAxisName.c_str());
 
         var = Variable(varShape, isSparse, AsDataType<ElementType>(), node->GetLearningRateMultiplier() != 0, node->GetName());
     }
@@ -56,7 +56,7 @@
   </PropertyGroup>
   <ItemDefinitionGroup>
     <ClCompile>
-      <AdditionalIncludeDirectories>.\API;$(SolutionDir)Source\SGDLib;$(SolutionDir)Source\Readers\ReaderLib;$(SolutionDir)Source\Readers\CompositeDataReader;$(SolutionDir)Source\ComputationNetworkLib;$(SolutionDir)Source\SequenceTrainingLib;$(SolutionDir)Source\Math;$(SolutionDir)Source\Common\Include;$(SolutionDir)Source\CNTK\BrainScript;$(SolutionDir)Source\ActionsLib;$(MSMPI_INC);$(NvmlInclude)</AdditionalIncludeDirectories>
+      <AdditionalIncludeDirectories>.\API;$(SolutionDir)Source\SGDLib;$(SolutionDir)Source\Readers\ReaderLib;$(SolutionDir)Source\ComputationNetworkLib;$(SolutionDir)Source\SequenceTrainingLib;$(SolutionDir)Source\Math;$(SolutionDir)Source\Common\Include;$(SolutionDir)Source\CNTK\BrainScript;$(SolutionDir)Source\ActionsLib;$(MSMPI_INC);$(NvmlInclude)</AdditionalIncludeDirectories>
     </ClCompile>
     <Link>
       <AdditionalLibraryDirectories>$(SolutionDir)Source\ComputationNetworkLib;$(SolutionDir)Source\Math;$(MSMPI_LIB64);$(SolutionDir)$(Platform)\$(Configuration);$(NvmlLibPath)</AdditionalLibraryDirectories>

@@ -75,7 +75,7 @@
     <Link>
       <SubSystem>Console</SubSystem>
       <GenerateDebugInformation>true</GenerateDebugInformation>
-      <AdditionalDependencies>ComputationNetworkLib.lib; Math.lib; Common.lib; SequenceTrainingLib.lib; ReaderLib.lib; CompositeDataReader.lib; kernel32.lib; user32.lib; shell32.lib; %(AdditionalDependencies)</AdditionalDependencies>
+      <AdditionalDependencies>ComputationNetworkLib.lib; Math.lib; Common.lib; SequenceTrainingLib.lib; ReaderLib.lib; kernel32.lib; user32.lib; shell32.lib; %(AdditionalDependencies)</AdditionalDependencies>
       <DelayLoadDLLs>Math.dll; nvml.dll; $(CudaRuntimeDll)</DelayLoadDLLs>
     </Link>
   </ItemDefinitionGroup>

@@ -99,7 +99,7 @@
       <GenerateDebugInformation>true</GenerateDebugInformation>
       <EnableCOMDATFolding>true</EnableCOMDATFolding>
       <OptimizeReferences>true</OptimizeReferences>
-      <AdditionalDependencies>ComputationNetworkLib.lib; Math.lib; Common.lib; ReaderLib.lib; CompositeDataReader.lib; kernel32.lib; user32.lib; shell32.lib; SequenceTrainingLib.lib; %(AdditionalDependencies)</AdditionalDependencies>
+      <AdditionalDependencies>ComputationNetworkLib.lib; Math.lib; Common.lib; ReaderLib.lib; kernel32.lib; user32.lib; shell32.lib; SequenceTrainingLib.lib; %(AdditionalDependencies)</AdditionalDependencies>
       <Profile>true</Profile>
       <DelayLoadDLLs>Math.dll; nvml.dll; $(CudaRuntimeDll)</DelayLoadDLLs>
     </Link>
@@ -8,7 +8,6 @@
 #include "Utils.h"
 #include "Config.h"
 #include "MinibatchSource.h"
-#include "CompositeDataReader.h"
 #include "HeapMemoryProvider.h"
 #include "ReaderShim.h"
 #include "Function.h"

@@ -39,7 +38,9 @@ namespace CNTK
 
         m_epochSize = configuration[epochSizeConfigurationKey].GetValue<size_t>();
 
-        m_compositeDataReader.reset(new CompositeDataReader(config, std::make_shared<HeapMemoryProvider>()));
+        typedef Reader*(*CreateCompositeDataReaderProc)(const ConfigParameters* parameters);
+        CreateCompositeDataReaderProc createReaderProc = (CreateCompositeDataReaderProc)Plugin().Load(L"CompositeDataReader", "CreateCompositeDataReader");
+        m_compositeDataReader.reset(createReaderProc(&config));
 
         auto compositeDataReaderStreamDescs = m_compositeDataReader->GetStreamDescriptions();
         for (auto streamDesc : compositeDataReaderStreamDescs)
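For context, the replacement above is the classic load-a-factory-from-a-plugin pattern: resolve an unmangled symbol in a dynamically loaded module, cast it to the factory's function-pointer type, and construct the object through it. The sketch below is a hypothetical POSIX-only illustration of that pattern using dlopen/dlsym, not CNTK's actual Plugin helper (which wraps LoadLibrary on Windows and dlopen on Linux); the library file name and the error handling are assumptions.

// Minimal sketch of dynamic factory loading (hypothetical, POSIX-only).
#include <dlfcn.h>     // dlopen, dlsym, dlerror
#include <memory>
#include <stdexcept>
#include <string>

struct ConfigParameters;                        // opaque here; owned by the reader library
struct Reader { virtual ~Reader() = default; }; // abstract interface the consumer programs against

// Factory signature matching the extern "C" export added in Exports.cpp below.
typedef Reader* (*CreateCompositeDataReaderProc)(const ConfigParameters* parameters);

std::shared_ptr<Reader> LoadCompositeDataReader(const ConfigParameters& config)
{
    // Resolve the shared library at runtime instead of linking against it
    // (assumed name; the real module name is platform-dependent).
    void* module = dlopen("libCompositeDataReader.so", RTLD_NOW);
    if (!module)
        throw std::runtime_error(std::string("dlopen failed: ") + dlerror());

    // Look up the unmangled factory symbol; extern "C" keeps the name stable.
    auto createReaderProc = (CreateCompositeDataReaderProc)dlsym(module, "CreateCompositeDataReader");
    if (!createReaderProc)
        throw std::runtime_error(std::string("dlsym failed: ") + dlerror());

    return std::shared_ptr<Reader>(createReaderProc(&config));
}

Because the consumer only sees the abstract Reader interface and a C function name, it no longer needs the reader's headers or import library at build time, which is exactly what the header and project-file hunks in this commit remove.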
@@ -8,7 +8,7 @@
 #include "stdafx.h"
 #include "CNTKLibrary.h"
 #include "Utils.h"
-#include "CompositeDataReader.h"
+#include "Reader.h"
 
 namespace CNTK
 {

@@ -23,7 +23,7 @@ namespace CNTK
 
     private:
         std::unordered_set<StreamInfo> m_streamInfos;
-        std::shared_ptr<Microsoft::MSR::CNTK::CompositeDataReader> m_compositeDataReader;
+        std::shared_ptr<Microsoft::MSR::CNTK::Reader> m_compositeDataReader;
         bool m_startNewEpoch;
         size_t m_nextEpochIndex;
         size_t m_prevMinibatchSize;
@@ -21,7 +21,7 @@ namespace CNTK
         {
             auto insertRetVal = learnerParameters.insert(parameter);
             if (!insertRetVal.second)
-                InvalidArgument("Trainer::Trainer: Parameter named %S is covered by 2 different learners", parameter.Name());
+                InvalidArgument("Trainer::Trainer: Parameter named %S is covered by 2 different learners", parameter.Name().c_str());
         }
     }
@@ -308,4 +308,11 @@ namespace CNTK
 template void DictionaryValue::AllocateDataPtr<NDShape>(const NDShape& value);
 template void DictionaryValue::AllocateDataPtr<vector<DictionaryValue>>(const vector<DictionaryValue>& value);
 template void DictionaryValue::AllocateDataPtr<wstring>(const wstring& value);
+template void DictionaryValue::AllocateDataPtr<Dictionary>(const Dictionary& value);
+
+template void DictionaryValue::FreePtrAsType<NDShape>();
+template void DictionaryValue::FreePtrAsType<vector<DictionaryValue>>();
+template void DictionaryValue::FreePtrAsType<wstring>();
+template void DictionaryValue::FreePtrAsType<Dictionary>();
+
 }
@@ -84,9 +84,15 @@ __declspec_noreturn static inline void ThrowFormatted(const char* format, ...)
 
 // RuntimeError - throw a std::runtime_error with a formatted error string
 #ifndef _MSC_VER // gcc __attribute__((format(printf())) does not percolate through variadic templates; so must go the macro route
+#ifndef RuntimeError
 #define RuntimeError ThrowFormatted<std::runtime_error>
+#endif
+#ifndef LogicError
 #define LogicError ThrowFormatted<std::logic_error>
+#endif
+#ifndef InvalidArgument
 #define InvalidArgument ThrowFormatted<std::invalid_argument>
+#endif
 #else
 template <class... _Types>
 __declspec_noreturn static inline void RuntimeError(const char* format, _Types&&... _Args)
@@ -8,7 +8,6 @@
 #include <map>
 #include <string>
 #include <future>
 #define DATAREADER_EXPORTS
 #include "DataReader.h"
 #include "Reader.h"
 #include "Transformer.h"

@@ -56,7 +55,7 @@ struct Minibatch;
 class CompositeDataReader : public Reader, protected Plugin
 {
 public:
-    DATAREADER_API CompositeDataReader(const ConfigParameters& parameters, MemoryProviderPtr provider);
+    CompositeDataReader(const ConfigParameters& parameters, MemoryProviderPtr provider);
 
     // Describes the streams this reader produces.
     std::vector<StreamDescriptionPtr> GetStreamDescriptions() override;
@@ -45,4 +45,9 @@ extern "C" DATAREADER_API void GetReaderD(IDataReader** preader)
     *preader = new CompositeReaderShim<double>(factory);
 }
 
+extern "C" DATAREADER_API Reader* CreateCompositeDataReader(const ConfigParameters* parameters)
+{
+    return new CompositeDataReader(*parameters, std::make_shared<HeapMemoryProvider>());
+}
+
 }}}
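Note that the new factory is declared extern "C": this suppresses C++ name mangling, so the exported symbol is literally "CreateCompositeDataReader", which is what allows Plugin().Load (GetProcAddress or dlsym under the hood) to resolve it by string name. Returning the abstract Reader* also keeps the concrete CompositeDataReader type out of the consumer, matching the m_compositeDataReader type change above.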
@@ -1,7 +1,39 @@
 #!/bin/bash
 
-if [ "$OS" == "Windows_NT" ]; then
-  $TEST_BIN_DIR/V2LibraryTests.exe || exit $?
-else
-  $TEST_BIN_DIR/v2librarytests || exit $?
-fi
+. $TEST_ROOT_DIR/run-test-common
+
+# This test uses a large dataset which is not part of the CNTK repository itself
+# We use the dataset from an external location specified using an environment variable
+if [[ "$CNTK_EXTERNAL_TESTDATA_SOURCE_DIRECTORY" == "" || ! -d "$CNTK_EXTERNAL_TESTDATA_SOURCE_DIRECTORY" ]]; then
+  echo 'This test uses external data that is not part of the CNTK repository. Environment variable CNTK_EXTERNAL_TESTDATA_SOURCE_DIRECTORY must be set to point to the external test data location'
+  exit 1
+fi
+
+if [ "$OS" == "Windows_NT" ]; then
+  DataSourceDir=`cygpath -au $CNTK_EXTERNAL_TESTDATA_SOURCE_DIRECTORY`/Image/MNIST/v0
+else
+  DataSourceDir=$CNTK_EXTERNAL_TESTDATA_SOURCE_DIRECTORY/Image/MNIST/v0
+fi
+
+# Copy the test data to the test run directory
+DataDir=$TEST_RUN_DIR/TestData
+mkdir $DataDir
+cp -R $DataSourceDir/Train-28x28_cntk_text.txt $DataDir || exit $?
+cp -R $TEST_DIR/../../../../Examples/Other/Simple2d/Data/SimpleDataTrain_cntk_text.txt $DataDir || exit $?
+
+pushd $DataDir
+
+if [ "$OS" == "Windows_NT" ]; then
+  $TEST_BIN_DIR/V2LibraryTests.exe
+else
+  $TEST_BIN_DIR/v2librarytests
+fi
+
+ExitCode=$?
+
+# Delete the test data
+popd
+rm -rf $DataDir
+
+exit $ExitCode
@@ -1,4 +1,4 @@
-#!/usr/bin/env python
+#!/usr/bin/env python
 # ----------------------------------------------------------
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # ---------------------------------------------------------
@@ -24,34 +24,52 @@ inline void FloatingPointVectorCompare(const std::vector<ElementType>& first, co
 #pragma warning(push)
 #pragma warning(disable: 4996)
 
+#ifndef _MSC_VER
+#include <unistd.h>
+static inline std::string wtocharpath(const wchar_t *p)
+{
+    size_t len = wcslen(p);
+    std::string buf;
+    buf.resize(2 * len + 1);            // max: 1 wchar => 2 mb chars
+    ::wcstombs(&buf[0], p, buf.size()); // note: technically it is forbidden to stomp over std::strings 0 terminator, but it is known to work in all implementations
+    buf.resize(strlen(&buf[0]));        // set size correctly for shorter strings
+    return buf;
+}
+
+static inline int _wunlink(const wchar_t *p)
+{
+    return unlink(wtocharpath(p).c_str());
+}
+#endif
+
 template <typename ElementType>
 inline void SaveAndReloadModel(CNTK::FunctionPtr& functionPtr, const std::vector<CNTK::Variable*>& variables, const CNTK::DeviceDescriptor& device)
 {
     static std::wstring s_tempModelPath = L"feedForward.net";
 
     if ((_wunlink(s_tempModelPath.c_str()) != 0) && (errno != ENOENT))
-        std::runtime_error("Error deleting temp model file 'feedForward.net'");
+        RuntimeError("Error deleting file '%ls': %s", s_tempModelPath.c_str(), strerror(errno));
 
-    std::unordered_map<std::wstring, Variable*> inputVarNames;
-    std::unordered_map<std::wstring, Variable*> outputVarNames;
+    std::unordered_map<std::wstring, CNTK::Variable*> inputVarNames;
+    std::unordered_map<std::wstring, CNTK::Variable*> outputVarNames;
 
     for (auto varPtr : variables)
     {
         auto retVal = varPtr->IsOutput() ? outputVarNames.insert({ varPtr->Owner()->Name(), varPtr }) : inputVarNames.insert({ varPtr->Name(), varPtr });
         if (!retVal.second)
-            std::runtime_error("SaveAndReloadModel: Multiple variables having same name cannot be restored after save and reload");
+            RuntimeError("SaveAndReloadModel: Multiple variables having same name cannot be restored after save and reload");
     }
 
-    SaveAsLegacyModel<ElementType>(functionPtr, s_tempModelPath);
-    functionPtr = LoadLegacyModel<ElementType>(s_tempModelPath, device);
+    CNTK::SaveAsLegacyModel<ElementType>(functionPtr, s_tempModelPath);
+    functionPtr = CNTK::LoadLegacyModel<ElementType>(s_tempModelPath, device);
 
     if (_wunlink(s_tempModelPath.c_str()) != 0)
-        std::runtime_error("Error deleting temp model file 'feedForward.net'");
+        RuntimeError("Error deleting file '%ls': %s", s_tempModelPath.c_str(), strerror(errno));
 
     auto inputs = functionPtr->Inputs();
     for (auto inputVarInfo : inputVarNames)
     {
-        auto newInputVar = *(std::find_if(inputs.begin(), inputs.end(), [inputVarInfo](const Variable& var) {
+        auto newInputVar = *(std::find_if(inputs.begin(), inputs.end(), [inputVarInfo](const CNTK::Variable& var) {
             return (var.Name() == inputVarInfo.first);
         }));

@@ -61,7 +79,7 @@ inline void SaveAndReloadModel(CNTK::FunctionPtr& functionPtr, const std::vector
     auto outputs = functionPtr->Outputs();
     for (auto outputVarInfo : outputVarNames)
     {
-        auto newOutputVar = *(std::find_if(outputs.begin(), outputs.end(), [outputVarInfo](const Variable& var) {
+        auto newOutputVar = *(std::find_if(outputs.begin(), outputs.end(), [outputVarInfo](const CNTK::Variable& var) {
            return (var.Owner()->Name() == outputVarInfo.first);
        }));
@@ -6,6 +6,43 @@ using namespace CNTK;
 
 using namespace std::placeholders;
 
+MinibatchSourcePtr CreateTextMinibatchSource(const std::wstring& filePath, size_t featureDim, size_t labelDim, size_t epochSize)
+{
+    Dictionary featuresStreamConfig;
+    featuresStreamConfig[L"dim"] = featureDim;
+    featuresStreamConfig[L"format"] = L"dense";
+
+    Dictionary labelsStreamConfig;
+    labelsStreamConfig[L"dim"] = labelDim;
+    labelsStreamConfig[L"format"] = L"dense";
+
+    Dictionary inputStreamsConfig;
+    inputStreamsConfig[L"features"] = featuresStreamConfig;
+    inputStreamsConfig[L"labels"] = labelsStreamConfig;
+
+    Dictionary deserializerConfiguration;
+    deserializerConfiguration[L"type"] = L"CNTKTextFormatDeserializer";
+    deserializerConfiguration[L"module"] = L"CNTKTextFormatReader";
+    deserializerConfiguration[L"file"] = filePath;
+    deserializerConfiguration[L"input"] = inputStreamsConfig;
+
+    Dictionary minibatchSourceConfiguration;
+    minibatchSourceConfiguration[L"epochSize"] = epochSize;
+    minibatchSourceConfiguration[L"deserializers"] = std::vector<DictionaryValue>({ deserializerConfiguration });
+
+    return CreateCompositeMinibatchSource(minibatchSourceConfiguration);
+}
+
+float PrevMinibatchTrainingLossValue(const Trainer& trainer)
+{
+    float trainLossValue = 0.0;
+    auto prevMBTrainingLossValue = trainer.PreviousMinibatchTrainingLossValue()->Data();
+    NDArrayView cpuTrainLossValue(prevMBTrainingLossValue->Shape(), &trainLossValue, 1, DeviceDescriptor::CPUDevice());
+    cpuTrainLossValue.CopyFrom(*prevMBTrainingLossValue);
+
+    return trainLossValue;
+}
+
 void TrainSimpleFeedForwardClassifer(const DeviceDescriptor& device)
 {
     const size_t inputDim = 2;

@@ -34,52 +71,28 @@ void TrainSimpleFeedForwardClassifer(const DeviceDescriptor& device)
     const size_t numSweepsToTrainWith = 2;
     const size_t numMinibatchesToTrain = (numSamplesPerSweep * numSweepsToTrainWith) / minibatchSize;
 
-    Dictionary featuresStreamConfig;
-    featuresStreamConfig[L"dim"] = 2ULL;
-    featuresStreamConfig[L"format"] = L"dense";
-
-    Dictionary labelsStreamConfig;
-    labelsStreamConfig[L"dim"] = 2ULL;
-    labelsStreamConfig[L"format"] = L"dense";
-
-    Dictionary inputStreamsConfig;
-    inputStreamsConfig[L"features"] = featuresStreamConfig;
-    inputStreamsConfig[L"labels"] = labelsStreamConfig;
-
-    Dictionary deserializerConfiguration;
-    deserializerConfiguration[L"type"] = L"CNTKTextFormatDeserializer";
-    deserializerConfiguration[L"module"] = L"CNTKTextFormatReader";
-    deserializerConfiguration[L"file"] = L"SimpleDataTrain_cntk_text.txt";
-    deserializerConfiguration[L"input"] = inputStreamsConfig;
-
-    Dictionary minibatchSourceConfiguration;
-    minibatchSourceConfiguration[L"epochSize"] = numSamplesPerSweep;
-    minibatchSourceConfiguration[L"deserializers"] = std::vector<DictionaryValue>({ deserializerConfiguration });
-
-    auto minibatchSource = CreateCompositeMinibatchSource(minibatchSourceConfiguration);
+    auto minibatchSource = CreateTextMinibatchSource(L"SimpleDataTrain_cntk_text.txt", (size_t)2, (size_t)2, numSamplesPerSweep);
 
     auto streamInfos = minibatchSource->StreamInfos();
-    auto featureStreamInfo = std::find_if(streamInfos.begin(), streamInfos.end(), [](const StreamInfo& streamInfo) {
-        return (streamInfo.m_name == L"features");
-    });
-    auto labelStreamInfo = std::find_if(streamInfos.begin(), streamInfos.end(), [](const StreamInfo& streamInfo) {
-        return (streamInfo.m_name == L"labels");
-    });
+    auto featureStreamInfo = std::find_if(streamInfos.begin(), streamInfos.end(), [](const StreamInfo& streamInfo) { return (streamInfo.m_name == L"features"); });
+    auto labelStreamInfo = std::find_if(streamInfos.begin(), streamInfos.end(), [](const StreamInfo& streamInfo) { return (streamInfo.m_name == L"labels"); });
 
     double learningRatePerSample = 0.02;
     Trainer trainer(oneHiddenLayerClassifier, trainingLoss, { SGDLearner(oneHiddenLayerClassifier->Parameters(), learningRatePerSample) });
     std::unordered_map<StreamInfo, std::pair<size_t, ValuePtr>> minibatchData = { { *featureStreamInfo, { minibatchSize, nullptr } }, { *labelStreamInfo, { minibatchSize, nullptr } } };
+    size_t outputFrequencyInMinibatches = 20;
     for (size_t i = 0; i < numMinibatchesToTrain; ++i)
     {
         minibatchSource->GetNextMinibatch(minibatchData);
         trainer.TrainMinibatch({ { input, minibatchData[*featureStreamInfo].second }, { labels, minibatchData[*labelStreamInfo].second } }, device);
-        float trainLossValue = 0.0;
-        auto prevMBTrainingLossValue = trainer.PreviousMinibatchTrainingLossValue()->Data();
-        NDArrayView cpuTrainLossValue(prevMBTrainingLossValue->Shape(), &trainLossValue, 1, DeviceDescriptor::CPUDevice());
-        cpuTrainLossValue.CopyFrom(*prevMBTrainingLossValue);
-        printf("Minibatch %d: CrossEntropy loss = %.8g\n", i, trainLossValue);
+
+        if ((i % outputFrequencyInMinibatches) == 0)
+        {
+            float trainLossValue = PrevMinibatchTrainingLossValue(trainer);
+            printf("Minibatch %d: CrossEntropy loss = %.8g\n", (int)i, trainLossValue);
+        }
     }
 }
 
 void TrainMNISTClassifier(const DeviceDescriptor& device)
 {

@@ -105,30 +118,7 @@ void TrainMNISTClassifier(const DeviceDescriptor& device)
     const size_t numSweepsToTrainWith = 3;
     const size_t numMinibatchesToTrain = (numSamplesPerSweep * numSweepsToTrainWith) / minibatchSize;
 
-    Dictionary featuresStreamConfig;
-    featuresStreamConfig[L"dim"] = 784ULL;
-    featuresStreamConfig[L"format"] = L"dense";
-
-    Dictionary labelsStreamConfig;
-    labelsStreamConfig[L"dim"] = 10ULL;
-    labelsStreamConfig[L"format"] = L"dense";
-
-    Dictionary inputStreamsConfig;
-    inputStreamsConfig[L"features"] = featuresStreamConfig;
-    inputStreamsConfig[L"labels"] = labelsStreamConfig;
-
-    Dictionary deserializerConfiguration;
-    deserializerConfiguration[L"type"] = L"CNTKTextFormatDeserializer";
-    deserializerConfiguration[L"module"] = L"CNTKTextFormatReader";
-    deserializerConfiguration[L"file"] = L"Train-28x28_cntk_text.txt";
-    deserializerConfiguration[L"input"] = inputStreamsConfig;
-
-    Dictionary minibatchSourceConfiguration;
-    minibatchSourceConfiguration[L"randomize"] = true;
-    minibatchSourceConfiguration[L"epochSize"] = numSamplesPerSweep;
-    minibatchSourceConfiguration[L"deserializers"] = std::vector<DictionaryValue>({ deserializerConfiguration });
-
-    auto minibatchSource = CreateCompositeMinibatchSource(minibatchSourceConfiguration);
+    auto minibatchSource = CreateTextMinibatchSource(L"Train-28x28_cntk_text.txt", (size_t)784, (size_t)10, numSamplesPerSweep);
 
     auto streamInfos = minibatchSource->StreamInfos();
     auto featureStreamInfo = std::find_if(streamInfos.begin(), streamInfos.end(), [](const StreamInfo& streamInfo) {

@@ -141,15 +131,17 @@ void TrainMNISTClassifier(const DeviceDescriptor& device)
     double learningRatePerSample = 0.003125;
     Trainer trainer(oneHiddenLayerClassifier, trainingLoss, { SGDLearner(oneHiddenLayerClassifier->Parameters(), learningRatePerSample) });
     std::unordered_map<StreamInfo, std::pair<size_t, ValuePtr>> minibatchData = { { *featureStreamInfo, { minibatchSize, nullptr } }, { *labelStreamInfo, { minibatchSize, nullptr } } };
+    size_t outputFrequencyInMinibatches = 20;
     for (size_t i = 0; i < numMinibatchesToTrain; ++i)
    {
         minibatchSource->GetNextMinibatch(minibatchData);
         trainer.TrainMinibatch({ { input, minibatchData[*featureStreamInfo].second }, { labels, minibatchData[*labelStreamInfo].second } }, device);
-        float trainLossValue = 0.0;
-        auto prevMBTrainingLossValue = trainer.PreviousMinibatchTrainingLossValue()->Data();
-        NDArrayView cpuTrainLossValue(prevMBTrainingLossValue->Shape(), &trainLossValue, 1, DeviceDescriptor::CPUDevice());
-        cpuTrainLossValue.CopyFrom(*prevMBTrainingLossValue);
-        printf("Minibatch %d: CrossEntropy loss = %.8g\n", i, trainLossValue);
+
+        if ((i % outputFrequencyInMinibatches) == 0)
+        {
+            float trainLossValue = PrevMinibatchTrainingLossValue(trainer);
+            printf("Minibatch %d: CrossEntropy loss = %.8g\n", (int)i, trainLossValue);
+        }
     }
 }
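The test changes factor the duplicated minibatch-source configuration into a shared CreateTextMinibatchSource helper and the loss readback into PrevMinibatchTrainingLossValue, print the loss only every outputFrequencyInMinibatches (20) minibatches instead of on every iteration, and fix the printf calls by casting the size_t loop index to int for the %d format specifier.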