diff --git a/Makefile b/Makefile
index fbb6314df..9583ce1c3 100644
--- a/Makefile
+++ b/Makefile
@@ -375,6 +375,8 @@ CNTKLIBRARY_SRC =\
 	$(SOURCEDIR)/CNTKv2LibraryDll/Utils.cpp \
 	$(SOURCEDIR)/CNTKv2LibraryDll/Value.cpp \
 	$(SOURCEDIR)/CNTKv2LibraryDll/Variable.cpp \
+	$(SOURCEDIR)/CNTKv2LibraryDll/Learner.cpp \
+
 CNTKLIBRARY_SRC+=$(CNTK_COMMON_SRC)
 CNTKLIBRARY_SRC+=$(COMPUTATION_NETWORK_LIB_SRC)

diff --git a/Source/CNTKv2LibraryDll/API/CNTKLibrary.h b/Source/CNTKv2LibraryDll/API/CNTKLibrary.h
index b73d471f5..b38e4fcb7 100644
--- a/Source/CNTKv2LibraryDll/API/CNTKLibrary.h
+++ b/Source/CNTKv2LibraryDll/API/CNTKLibrary.h
@@ -285,6 +285,7 @@ namespace CNTK
     class NDArrayView final : public std::enable_shared_from_this<NDArrayView>
     {
         friend class CompositeFunction;
+        friend class LearnerBase;

         template <typename T, typename ...CtorArgTypes>
         friend inline std::shared_ptr<T> MakeSharedObject(CtorArgTypes&& ...ctorArgs);
@@ -1396,4 +1397,342 @@ namespace CNTK
     /// of the computation graph which can be "Combine"d to create a single Function with 2 outputs; viz. CrossEntropy loss and ClassificationError output.
     ///
     CNTK_API FunctionPtr Combine(const std::initializer_list<FunctionPtr>& operands, const std::wstring& name = L"");
+
+    ///
+    /// A serializable value represents one of:
+    /// a) Boolean
+    /// b) Signed long integer
+    /// c) Single and double precision floating point values
+    /// d) NDShape
+    /// e) vector<DictionaryValue>
+    ///
+    /// TODO: we need to have native support for DictionaryValue<vector> and DictionaryValue<Dictionary>.
+    class CNTK_API DictionaryValue final
+    {
+    public:
+        enum class Type : unsigned int
+        {
+            None,
+            Bool,
+            SizeT,
+            Float,
+            Double,
+            NDShape,
+            Vector
+        };
+
+        static const char* TypeName(Type type)
+        {
+            switch (type)
+            {
+            case Type::None:
+                return "None";
+            case Type::Bool:
+                return "Bool";
+            case Type::SizeT:
+                return "SizeT";
+            case Type::Float:
+                return "Float";
+            case Type::Double:
+                return "Double";
+            case Type::NDShape:
+                return "NDShape";
+            case Type::Vector:
+                return "Vector";
+            default:
+                LogicError("Unknown DictionaryValue::Type");
+            }
+        }
+
+    public:
+        DictionaryValue() : m_valueType(Type::None)
+        {
+        }
+
+        DictionaryValue(bool value) : m_valueType(GetValueType<bool>())
+        {
+            m_data.m_boolean = value;
+        }
+
+        DictionaryValue(size_t value) : m_valueType(GetValueType<size_t>())
+        {
+            m_data.m_sizeT = value;
+        }
+
+        DictionaryValue(float value) : m_valueType(GetValueType<float>())
+        {
+            m_data.m_float = value;
+        }
+
+        DictionaryValue(double value) : m_valueType(GetValueType<double>())
+        {
+            m_data.m_double = value;
+        }
+
+        template <typename T>
+        DictionaryValue(const T& value) : m_valueType(GetValueType<T>())
+        {
+            static_assert(std::is_same<T, NDShape>::value ||
+                          std::is_same<T, std::vector<DictionaryValue>>::value,
+                          "Unsupported ValueType");
+
+            AllocateDataPtr(value);
+        }
+
+        DictionaryValue(const DictionaryValue& other) : m_valueType(Type::Bool)
+        {
+            // The m_valueType must have been set to a non-ptr type to prevent an attempt to interpret
+            // the underlying uninitialized value as a ptr and free it.
+            *this = other;
+        }
+
+        DictionaryValue& operator=(const DictionaryValue& other)
+        {
+            if (this != &other)
+            {
+                FreeDataPtr();
+
+                m_valueType = other.m_valueType;
+                m_data = other.m_data;
+
+                if (other.m_valueType == Type::NDShape)
+                    AllocateDataPtr(other.GetValue<NDShape>());
+                else if (other.m_valueType == Type::Vector)
+                    AllocateDataPtr(other.GetValue<std::vector<DictionaryValue>>());
+            }
+
+            return *this;
+        }
+
+        ~DictionaryValue()
+        {
+            FreeDataPtr();
+        }
+
+        template <typename T, typename std::enable_if<std::is_same<T, bool>::value>::type* = nullptr>
+        const T& GetValue() const
+        {
+            VerifyType<T>();
+            return m_data.m_boolean;
+        }
+
+        template <typename T, typename std::enable_if<std::is_same<T, size_t>::value>::type* = nullptr>
+        const T& GetValue() const
+        {
+            VerifyType<T>();
+            return m_data.m_sizeT;
+        }
+
+        template <typename T, typename std::enable_if<std::is_same<T, float>::value>::type* = nullptr>
+        const T& GetValue() const
+        {
+            VerifyType<T>();
+            return m_data.m_float;
+        }
+
+        template <typename T, typename std::enable_if<std::is_same<T, double>::value>::type* = nullptr>
+        const T& GetValue() const
+        {
+            VerifyType<T>();
+            return m_data.m_double;
+        }
+
+        template <typename T, typename std::enable_if<std::is_same<T, NDShape>::value || std::is_same<T, std::vector<DictionaryValue>>::value>::type* = nullptr>
+        const T& GetValue() const
+        {
+            VerifyType<T>();
+            return *(reinterpret_cast<T*>(m_data.m_ptr));
+        }
+
+        bool HasValue() const
+        {
+            return m_valueType != Type::None;
+        }
+
+        Type ValueType() const
+        {
+            return m_valueType;
+        }
+
+        friend CNTK_API Microsoft::MSR::CNTK::File& operator>>(Microsoft::MSR::CNTK::File& stream, DictionaryValue& us);
+        friend CNTK_API Microsoft::MSR::CNTK::File& operator<<(Microsoft::MSR::CNTK::File& stream, const DictionaryValue& us);
+
+    private:
+        template <typename T>
+        static Type GetValueType()
+        {
+            static_assert(std::is_same<T, bool>::value ||
+                          std::is_same<T, size_t>::value ||
+                          std::is_same<T, float>::value ||
+                          std::is_same<T, double>::value ||
+                          std::is_same<T, NDShape>::value ||
+                          std::is_same<T, std::vector<DictionaryValue>>::value,
+                          "Unsupported ValueType");
+
+            if (std::is_same<T, bool>::value) return Type::Bool;
+            if (std::is_same<T, size_t>::value) return Type::SizeT;
+            if (std::is_same<T, float>::value) return Type::Float;
+            if (std::is_same<T, double>::value) return Type::Double;
+            if (std::is_same<T, NDShape>::value) return Type::NDShape;
+            if (std::is_same<T, std::vector<DictionaryValue>>::value) return Type::Vector;
+        }
+
+        template <typename T>
+        void VerifyType() const
+        {
+            if (GetValueType<T>() != m_valueType)
+                RuntimeError("Reading a DictionaryValue as the wrong type; Reading as type %s when actual type is %s", typeid(T).name(), DictionaryValue::TypeName(m_valueType));
+        }
+
+        template <typename T>
+        void AllocateDataPtr(const T& value);
+
+        template <typename T>
+        void FreePtrAsType();
+
+        void FreeDataPtr();
+
+        Type m_valueType;
+
+        union ValueData
+        {
+            bool m_boolean;
+            size_t m_sizeT;
+            float m_float;
+            double m_double;
+            void* m_ptr;
+        } m_data;
+
+        const size_t version = 1;
+    };
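
    // Illustrative usage sketch (editorial addition, not part of the patch): DictionaryValue is a
    // tagged union; the constructor records the runtime type and GetValue<T>() verifies it on read.
    //
    //     DictionaryValue flag(true);           // tag = Type::Bool
    //     DictionaryValue lr(0.005);            // tag = Type::Double
    //     double d = lr.GetValue<double>();     // OK: requested type matches the tag
    //     size_t s = lr.GetValue<size_t>();     // RuntimeError: wrong type requested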
+
+    ///
+    /// A type denoting a dictionary (keyed by Unicode strings) of serializable values (dynamically typed).
+    ///
+    class CNTK_API Dictionary final
+    {
+    public:
+        Dictionary();
+        ~Dictionary();
+
+        // Disallow copy construction and assignment
+        Dictionary(const Dictionary&) = delete; Dictionary& operator=(const Dictionary&) = delete;
+
+        Dictionary(Dictionary&& other);
+        Dictionary& operator=(Dictionary&& other);
+
+        DictionaryValue& operator[](const std::wstring& key)
+        {
+            return operator[](key.c_str());
+        }
+
+        DictionaryValue& operator[](const wchar_t* key);
+
+        DictionaryValue operator[](const std::wstring& key) const
+        {
+            return operator[](key.c_str());
+        }
+
+        DictionaryValue operator[](const wchar_t* key) const;
+
+        bool Contains(const std::wstring& key) const
+        {
+            return Contains(key.c_str());
+        }
+
+        bool Contains(const wchar_t* key) const;
+
+        friend CNTK_API Microsoft::MSR::CNTK::File& operator>>(Microsoft::MSR::CNTK::File& stream, Dictionary& us);
+        friend CNTK_API Microsoft::MSR::CNTK::File& operator<<(Microsoft::MSR::CNTK::File& stream, const Dictionary& us);
+
+    private:
+        std::unordered_map<std::wstring, DictionaryValue>* m_dictionaryData;
+        const size_t version = 1;
+    };
+
+    ///
+    /// Abstraction for learning a subset of parameters of a learnable function using first order gradient values.
+    /// E.g., momentum, AdaGrad, RMSProp, etc. are different types of learners with their own algorithms for
+    /// learning parameter values using first order gradients.
+    ///
+    class Learner : public std::enable_shared_from_this<Learner>
+    {
+    public:
+        //
+        // Method to update the parameters associated with this learner. By returning false, this method indicates that
+        // learning has stopped for all of the parameters associated with this learner.
+        //
+        CNTK_API virtual bool Update(const std::unordered_map<Variable, ValuePtr>& parameterValues,
+                                     const std::unordered_map<Variable, ValuePtr>& gradientValues,
+                                     size_t trainingSampleCount) = 0;
+
+        ///
+        /// Returns the set of parameters associated with this learner.
+        ///
+        const std::unordered_set<Variable>& Parameters() const { return m_parameters; }
+
+        // TODO: move the following two methods into ISerializable interface, make
+        // Learner (and all other entities that need checkpointing capability) implement it.
+        ///
+        /// Optionally overridable method to checkpoint the learner's state.
+        ///
+        CNTK_API virtual Dictionary GetCheckpointState() const = 0;
+
+        ///
+        /// Optionally overridable method to restore the learner's state from a previous checkpoint.
+        ///
+        CNTK_API virtual void RestoreFromCheckpoint(const Dictionary& checkpoint) = 0;
+
+        virtual ~Learner()
+        {
+        }
+
+    protected:
+        Learner(const std::unordered_set<Variable>& parameters)
+            : m_parameters(parameters)
+        {
+        }
+
+        std::unordered_set<Variable> m_parameters;
+    };
+
+    ///
+    /// Create an instance of the CNTK built-in SGD learner.
+    ///
+    /// TODO: add additional SGD parameters here (a collection of learning rate values)
+    CNTK_API LearnerPtr SGDLearner(const std::unordered_set<Variable>& parameters,
+                                   const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice());
+
+    ///
+    /// Create an instance of the CNTK built-in Momentum SGD learner.
+    ///
+    /// TODO: add additional Momentum parameters here (a collection of momentum rate values)
+    CNTK_API LearnerPtr MomentumSGDLearner(const std::unordered_set<Variable>& parameters,
+                                           const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice());
+
+    ///
+    /// Create an instance of the CNTK built-in Nesterov's accelerated SGD learner.
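    // Illustrative usage sketch (editorial addition, not part of the patch; the parameter and
    // gradient maps are assumptions, typically produced by a Function's Forward/Backward):
    //
    //     std::unordered_set<Variable> parameters = ...;   // learnable parameters
    //     LearnerPtr learner = SGDLearner(parameters);     // or MomentumSGDLearner(parameters), ...
    //     std::unordered_map<Variable, ValuePtr> parameterValues = ...;
    //     std::unordered_map<Variable, ValuePtr> gradientValues = ...;
    //     bool keepLearning = learner->Update(parameterValues, gradientValues, minibatchSampleCount);
    //     // a false return value means learning has stopped for this learner's parameters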
+ /// + CNTK_API LearnerPtr NesterovLearner(const std::unordered_set& parameters, + const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice()); + + /// + /// Create an instance of the CNTK built-in AdaGrad learner. + /// + CNTK_API LearnerPtr AdaGradLearner(const std::unordered_set& parameters, bool needAveMultiplier = true, + const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice()); + + /// + /// Create an instance of the CNTK built-in FSAdaGrad (improved AdaGrad) learner. + /// + CNTK_API LearnerPtr FSAdaGradLearner(const std::unordered_set& parameters, + const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice()); + + /// + /// Create an instance of the CNTK built-in RMSProp learner. + /// + CNTK_API LearnerPtr RMSPropLearner(const std::unordered_set& parameters, + double gamma, double inc, double dec, double max, double min, bool needAveMultiplier = true, + const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice()); } diff --git a/Source/CNTKv2LibraryDll/API/CNTKLibraryInternals.h b/Source/CNTKv2LibraryDll/API/CNTKLibraryInternals.h index 0ca54673e..fe7bcd45b 100644 --- a/Source/CNTKv2LibraryDll/API/CNTKLibraryInternals.h +++ b/Source/CNTKv2LibraryDll/API/CNTKLibraryInternals.h @@ -47,6 +47,8 @@ namespace Microsoft { namespace MSR { namespace CNTK { template class ComputationNode; + + class File; }}} // TODO: The following should be reconciled with the equivalent code in the CNTK implementation @@ -158,4 +160,7 @@ namespace CNTK class Function; typedef std::shared_ptr FunctionPtr; + + class Learner; + typedef std::shared_ptr LearnerPtr; } diff --git a/Source/CNTKv2LibraryDll/CNTKv2LibraryDll.vcxproj b/Source/CNTKv2LibraryDll/CNTKv2LibraryDll.vcxproj index 0ebcfe835..bf18f452a 100644 --- a/Source/CNTKv2LibraryDll/CNTKv2LibraryDll.vcxproj +++ b/Source/CNTKv2LibraryDll/CNTKv2LibraryDll.vcxproj @@ -128,6 +128,7 @@ + @@ -140,6 +141,7 @@ + diff --git a/Source/CNTKv2LibraryDll/CNTKv2LibraryDll.vcxproj.filters b/Source/CNTKv2LibraryDll/CNTKv2LibraryDll.vcxproj.filters index 0dcd63972..1d2b139d1 100644 --- a/Source/CNTKv2LibraryDll/CNTKv2LibraryDll.vcxproj.filters +++ b/Source/CNTKv2LibraryDll/CNTKv2LibraryDll.vcxproj.filters @@ -10,6 +10,7 @@ + @@ -22,6 +23,7 @@ API + diff --git a/Source/CNTKv2LibraryDll/Learner.cpp b/Source/CNTKv2LibraryDll/Learner.cpp new file mode 100644 index 000000000..75ba85870 --- /dev/null +++ b/Source/CNTKv2LibraryDll/Learner.cpp @@ -0,0 +1,464 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE.md file in the project root for full license information. 
+// + +#include "Learner.h" +#include "TensorView.h" +#include "Utils.h" + +#define UPDATE_FUNCTION \ + switch (smoothedGradientValue->Data()->GetDataType()) \ + { \ + case DataType::Float: \ + Update(parameter, smoothedGradientValue, gradientValue, parameterValue, trainingSampleCount); \ + break; \ + case DataType::Double: \ + Update(parameter, smoothedGradientValue, gradientValue, parameterValue, trainingSampleCount); \ + break; \ + default: \ + NOT_IMPLEMENTED; \ + } + + +using namespace Microsoft::MSR::CNTK; +using namespace std; + +namespace CNTK +{ + template + /*static*/ shared_ptr> LearnerBase::GetMatrix(const NDArrayViewPtr arrayView) + { + return arrayView->GetMatrix(); + } + + template + /*static*/ shared_ptr> LearnerBase::GetWritableMatrix(NDArrayViewPtr arrayView) + { + return arrayView->GetWritableMatrix(); + } + + template + /*static*/ const TensorView* LearnerBase::GetTensorView(const NDArrayViewPtr arrayView) + { + return arrayView->GetTensorView(); + } + + /*static*/ bool LearnerBase::HasNan(const ValuePtr& value, const char* name) + { + const auto& data = value->Data(); + switch (data->GetDataType()) + { + case DataType::Float: + return data->GetMatrix()->HasNan(name); + case DataType::Double: + return data->GetMatrix()->HasNan(name); + default: + LogicError("Unsupported DataType %s", DataTypeName(data->GetDataType())); + } + } + + /*static*/ void LearnerBase::Print(const ValuePtr& value, const char* msg) + { + const auto& data = value->Data(); + switch (data->GetDataType()) + { + case DataType::Float: + data->GetMatrix()->Print(msg); + break; + case DataType::Double: + data->GetMatrix()->Print(msg); + break; + default: + LogicError("Unsupported DataType %s", DataTypeName(data->GetDataType())); + } + } + + // Clipping gradients to prevent outliers, + template + void LearnerBase::ClipGradient(Matrix& gradient, size_t actualMBSize) const + { + if (m_additionalOptions.gradientClippingThresholdPerSample != numeric_limits::infinity()) + { + double maxGradientPerMB = m_additionalOptions.gradientClippingThresholdPerSample * actualMBSize; + if (m_additionalOptions.gradientClippingWithTruncation) + gradient.InplaceTruncate(ElementType(maxGradientPerMB)); + else + { + // norm2 normalized + double gradientNorm = gradient.FrobeniusNorm(); + if (gradientNorm > maxGradientPerMB) + { + double normFactor = maxGradientPerMB / gradientNorm; + gradient *= ElementType(normFactor); + } + } + } + } + + // Performs additional preprocessing before calling the update method + // (gradient clipping and L2 regularization depending on the additional learning parameters). + template + void LearnerBase::PreProcess(const ValuePtr& gradientValue,const ValuePtr& parameterValue, size_t actualMBSize) const + { + const auto& gradientMatrix = gradientValue->Data()->GetWritableMatrix(); + + // clipping gradients to prevent outliers + ClipGradient(*gradientMatrix, actualMBSize); + + // L2 regularizer + if (m_additionalOptions.l2RegularizationWeight > 0) + { + // multiply by actualMBSize so that it's invariant to minibatch size since learning rate is per sample + auto weight = ElementType(m_additionalOptions.l2RegularizationWeight * actualMBSize); + const auto& parameterMatrix = parameterValue->Data()->GetWritableMatrix(); + Matrix::ScaleAndAdd(weight, *parameterMatrix, *gradientMatrix); + } + } + + // Performs additional postprocessing after the update method has been executed + // (noise injection and L1 regularization specified by the additional learning parameters). 
+ template + void LearnerBase::PostProcess(const Variable& parameter, const ValuePtr& gradientValue, + const ValuePtr& parameterValue, size_t actualMBSize) const + { + const auto& parameterMatrix = parameterValue->Data()->GetWritableMatrix(); + if (m_additionalOptions.gaussianNoiseInjectionStdDev > 0) + { + const auto& gradientMatrix = gradientValue->Data()->GetWritableMatrix(); + + Matrix sgdUpdateNoise((DEVICEID_TYPE)parameterMatrix->GetDeviceId()); + + // get the gradient structure since gradient is sparse + sgdUpdateNoise.SetValue(*gradientMatrix); + + auto noiseStdDev = ElementType(m_additionalOptions.gaussianNoiseInjectionStdDev); + + // reset its value to random + sgdUpdateNoise.SetGaussianRandomValue(ElementType(0.0), noiseStdDev); + + Matrix::ScaleAndAdd(ElementType(1.0), sgdUpdateNoise, *parameterMatrix); + } + + // L1 regularizer with proximal gradient descent method + if (m_additionalOptions.l1RegularizationWeight > 0) + { + auto learningRate = ElementType(ParameterDependentLearningRate(parameter)); + // multiply by actualMBSize so that it's invariant to minibatch size since learning rate is per sample + auto weight = ElementType(learningRate * m_additionalOptions.l1RegularizationWeight * actualMBSize); + parameterValue->Data()->GetWritableMatrix()->InplaceSoftThreshold(weight); + } + } + + template + /*static*/ TensorView* LearnerBase::GetWritableTensorView(NDArrayViewPtr arrayView) + { + return arrayView->GetWritableTensorView(); + } + + LearnerBase::LearnerBase(const unordered_set& parameters, const DeviceDescriptor& device) + : Learner(parameters), + m_learningRatePerSample(0.0), + m_sampleCount(0) + { + const unordered_set& parameterSet = parameters; + for (const auto& parameter : parameterSet) + { + // TODO: using the same device to allocate data for all smoothed gradients. Is this correct? + // Should the device be specified on the per-parameter basis? + NDArrayViewPtr view; + if (parameter.GetDataType() == DataType::Float) + { + view = MakeSharedObject(0.0f, parameter.Shape(), device); + } + else + { + view = MakeSharedObject(0.0, parameter.Shape(), device); + } + + m_smoothedGradientValues.insert(make_pair(parameter, MakeSharedObject(view))); + m_additionalOptions.learningRateMultipliers.insert(make_pair(parameter, 1.0)); + } + } + + void LearnerBase::ResetSmoothedGradients() + { + for (const auto& parameter : Parameters()) + { + const auto& smoothedGradientValue = m_smoothedGradientValues.at(parameter); + const auto& data = smoothedGradientValue->Data(); + switch (data->GetDataType()) + { + case DataType::Float: + data->SetValue(0.0f); + break; + case DataType::Double: + data->SetValue(0.0); + break; + default: + LogicError("Unsupported DataType %s", ::CNTK::DataTypeName(data->GetDataType())); + } + } + } + + /*virtual*/ bool LearnerBase::Update(const unordered_map& parameterValues, + const unordered_map& gradientValues, + size_t trainingSampleCount) /*override*/ + { + // make sure trainingSampleCount is a valid value + assert(trainingSampleCount > 0); + + for (const auto& parameter : Parameters()) + { + const auto& smoothedGradientValue = m_smoothedGradientValues.at(parameter); + const auto& gradientValue = gradientValues.at(parameter); + const auto& parameterValue = parameterValues.at(parameter); + +// TODO: make this a runtime parameter. 
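    // (Editorial note) UPDATE_FUNCTION, defined at the top of this file, dispatches on the
    // element type of the smoothed gradient; e.g. for DataType::Float it expands to
    //
    //     Update<float>(parameter, smoothedGradientValue, gradientValue, parameterValue, trainingSampleCount);
    //
    // which runs the private templatized Update below: PreProcess -> subclass Update -> PostProcess.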
+#if DUMPOUTPUT + LOGPRINTF(stderr, "Update_%ls\n", parameter.Name().c_str()); +#endif + +#ifdef _DEBUG + if (HasNan(smoothedGradientValue, "TrainOneEpoch/UpdateWeights/Learner::Update(): ")) + LogicError("%ls has NaNs in smoothedGradient.", parameter.Name().c_str()); +#endif + +#if DUMPOUTPUT + LOGPRINTF(stderr, "learnRatePerSample=%0.8f, momentum=%0.8f, actualMBSize=%ld\n", + m_learningRatePerSample, m_momentumPerSample, trainingSampleCount); + LOGPRINTF(stderr, "GradUpdateType()=%s, GradientUpdateNoiseStd()=%0.8f\n", + LearnerType().c_str(), m_GaussianNoiseInjectStd); + Print(gradientValue, "Gradient Update"); + Print(smoothedGradientValue, "Smoothed Gradient Input"); +#endif + UPDATE_FUNCTION; + +#if DUMPOUTPUT + Print(parameterValue, "Parameter Update"); +#endif + +#ifdef _DEBUG + if (HasNan(parameterValue, "TrainOneEpoch/UpdateWeights/Learner::Update(): ")) + LogicError("%ls has NaNs in parameter values after parameter update.", parameter.Name().c_str()); +#endif + } + m_sampleCount += trainingSampleCount; + return false; + } + + template + void LearnerBase::Update(const Variable& parameter, const ValuePtr& smoothedGradientValue, + const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const + { + PreProcess(gradientValue, parameterValue, trainingSampleCount); + Update(parameter, smoothedGradientValue, gradientValue, parameterValue, trainingSampleCount); + PostProcess(parameter, gradientValue, parameterValue, trainingSampleCount); + } + + string LearnerBase::LearnerType() const + { + auto name = typeid(*this).name(); + if (strncmp(name, "class ", 6) == 0) + { + // On Windows, the type name contains "class" prefix. + // Return the actual name, omitting the prefix. + return &name[6]; + } + return name; + } + + /*virtual*/ Dictionary LearnerBase::GetCheckpointState() const /*override*/ + { + NOT_IMPLEMENTED; // Until the new checkpointing is fully fleshed out, nobody should be calling this. + Dictionary checkpoint; + + for (const auto& parameter : Parameters()) + { + // TODO: parameter name is not guaranteed to be unique. Instead, all serializable objects + // need to expose "UId" property -- a persistent unique internal name. + // Switch to UId as soon as it's available. + if (checkpoint.Contains(parameter.Name())) + { + LogicError("Parameter names must be unique"); + } + const auto& smoothedGradientValue = m_smoothedGradientValues.at(parameter); + + // Potentially, could store things like dimensions, element size, format, etc., but + // that seems to be redundant, since all of that is passed in the constructor. + checkpoint[parameter.Name()] = SerializeToVector(smoothedGradientValue->Data()); + } + return checkpoint; + } + + /*virtual*/ void LearnerBase::RestoreFromCheckpoint(const Dictionary& checkpoint) /*override*/ + { + NOT_IMPLEMENTED; // Until the new checkpointing is fully fleshed out, nobody should be calling this. 
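    // (Editorial sketch) Once implemented, the intended checkpoint round trip pairs
    // GetCheckpointState above with this method, via the File stream operators declared
    // in CNTKLibrary.h:
    //
    //     Dictionary state = learner->GetCheckpointState();
    //     file << state;                  // Microsoft::MSR::CNTK::File
    //     ...
    //     Dictionary restored;
    //     file >> restored;
    //     learner->RestoreFromCheckpoint(restored);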
+ for (const auto& parameter : Parameters()) + { + if (!checkpoint.Contains(parameter.Name())) + { + LogicError("Checkpoint does not contain state for parameter %ls", parameter.Name().c_str()); + } + const auto& smoothedGradientValue = m_smoothedGradientValues.at(parameter); + + const DictionaryValue& state = checkpoint[parameter.Name()]; + + const auto& data = smoothedGradientValue->Data(); + + DeserializeFromVector(data, state.GetValue>()); + } + } + + /*virtual*/ void LearnerSGD::Update(const Variable& parameter, const ValuePtr& smoothedGradientValue, + const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const /*override*/ + { + UPDATE_FUNCTION; + } + + template + void LearnerSGD::Update(const Variable& parameter, const ValuePtr& smoothedGradientValue, + const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const + { + UNUSED(trainingSampleCount); + + const auto& smoothedGradientMatrix = GetWritableMatrix(smoothedGradientValue->Data()); + const auto& gradientMatrix = GetWritableMatrix(gradientValue->Data()); + const auto& parameterMatrix = GetWritableMatrix(parameterValue->Data()); + + const auto& learningRate = ElementType(ParameterDependentLearningRate(parameter)); + + // TODO: break up the NormalGrad into 3 different functions, each with its own set of parameters + // (one for vanilla SGD, the other for momentum SGD, and the third one for NAG). + smoothedGradientMatrix->NormalGrad(*gradientMatrix, *parameterMatrix, + learningRate, ElementType(m_momentumPerSample), m_useNesterovAcceleration); + } + + LearnerAdaGrad::LearnerAdaGrad(const unordered_set& parameters, bool needAveMultiplier, const DeviceDescriptor& device) + : LearnerBase(parameters, device), + m_needAveMultiplier(needAveMultiplier) + { + } + + /*virtual*/ void LearnerAdaGrad::Update(const Variable& parameter, const ValuePtr& smoothedGradientValue, + const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const /*override*/ + { + UPDATE_FUNCTION; + } + + template + void LearnerAdaGrad::Update(const Variable& parameter, const ValuePtr& smoothedGradientValue, + const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const + { + UNUSED(trainingSampleCount); + + const auto& smoothedGradientMatrix = GetWritableMatrix(smoothedGradientValue->Data()); + const auto& gradientMatrix = GetWritableMatrix(gradientValue->Data()); + const auto& parameterMatrix = GetWritableMatrix(parameterValue->Data()); + + auto learningRate = ElementType(ParameterDependentLearningRate(parameter)); + + auto aveMultiplier = smoothedGradientMatrix->Adagrad(*gradientMatrix, m_needAveMultiplier); + Matrix::ScaleAndAdd(ElementType(-learningRate / aveMultiplier), *gradientMatrix, *parameterMatrix); + } + + LearnerFSAdaGrad::LearnerFSAdaGrad(const unordered_set& parameters, const DeviceDescriptor& device) + : LearnerMomentumSGD(parameters, device) + { + } + + /*virtual*/ void LearnerFSAdaGrad::Update(const Variable& parameter, const ValuePtr& smoothedGradientValue, + const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const /*override*/ + { + UPDATE_FUNCTION; + } + + template + void LearnerFSAdaGrad::Update(const Variable& parameter, const ValuePtr& smoothedGradientValue, + const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const + { + UNUSED(trainingSampleCount); + + const auto& smoothedGradientMatrix = 
GetWritableMatrix(smoothedGradientValue->Data()); + const auto& gradientMatrix = GetWritableMatrix(gradientValue->Data()); + const auto& parameterMatrix = GetWritableMatrix(parameterValue->Data()); + + //const double momentum = MomentumPerMB(m_momentumPerSample, trainingSampleCount); + + auto learningRate = ElementType(ParameterDependentLearningRate(parameter)); + + smoothedGradientMatrix->FSAdagrad(trainingSampleCount, *gradientMatrix, *parameterMatrix, + learningRate, ElementType(m_momentumPerSample)); + } + + LearnerRMSProp::LearnerRMSProp(const unordered_set& parameters, + double gamma, double inc, double dec, double max, double min, + bool needAveMultiplier, const DeviceDescriptor& device) + : LearnerBase(parameters, device), + m_gamma(gamma), m_inc(inc), m_dec(dec), m_max(max), m_min(min), + m_needAveMultiplier(needAveMultiplier) + { + } + + /*virtual*/ void LearnerRMSProp::Update(const Variable& parameter, const ValuePtr& smoothedGradientValue, + const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const /*override*/ + { + UPDATE_FUNCTION; + } + + template + void LearnerRMSProp::Update(const Variable& parameter, const ValuePtr& smoothedGradientValue, + const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const + { + UNUSED(trainingSampleCount); + + const auto& smoothedGradientMatrix = GetWritableMatrix(smoothedGradientValue->Data()); + const auto& gradientMatrix = GetWritableMatrix(gradientValue->Data()); + const auto& parameterMatrix = GetWritableMatrix(parameterValue->Data()); + + auto learningRate = ElementType(ParameterDependentLearningRate(parameter)); + + auto aveMultiplier = smoothedGradientMatrix->RmsProp(*gradientMatrix, + ElementType(m_gamma), ElementType(m_inc), + ElementType(m_max), ElementType(m_dec), + ElementType(m_min), m_needAveMultiplier); + Matrix::ScaleAndAdd(ElementType(-learningRate / aveMultiplier), *gradientMatrix, *parameterMatrix); + } + + // Explicit template instantiations + template shared_ptr> LearnerBase::GetWritableMatrix(const NDArrayViewPtr arrayView); + template shared_ptr> LearnerBase::GetWritableMatrix(const NDArrayViewPtr arrayView); + + LearnerPtr SGDLearner(const unordered_set& parameters, const DeviceDescriptor& device) + { + return MakeSharedObject(parameters, device); + } + + LearnerPtr MomentumSGDLearner(const unordered_set& parameters, const DeviceDescriptor& device) + { + return MakeSharedObject(parameters, device); + } + + LearnerPtr NesterovLearner(const unordered_set& parameters, const DeviceDescriptor& device) + { + return MakeSharedObject(parameters, device); + } + + LearnerPtr AdaGradLearner(const unordered_set& parameters, bool needAveMultiplier, const DeviceDescriptor& device) + { + return MakeSharedObject(parameters, needAveMultiplier, device); + } + + LearnerPtr FSAdaGradLearner(const unordered_set& parameters, const DeviceDescriptor& device) + { + return MakeSharedObject(parameters, device); + } + + LearnerPtr RMSPropLearner(const unordered_set& parameters, + double gamma, double inc, double dec, double max, double min, bool needAveMultiplier, + const DeviceDescriptor& device) + { + return MakeSharedObject(parameters, gamma, inc, dec, max, min, needAveMultiplier, device); + } + +} \ No newline at end of file diff --git a/Source/CNTKv2LibraryDll/Learner.h b/Source/CNTKv2LibraryDll/Learner.h new file mode 100644 index 000000000..568ec2948 --- /dev/null +++ b/Source/CNTKv2LibraryDll/Learner.h @@ -0,0 +1,224 @@ +// +// Copyright (c) Microsoft. 
All rights reserved.
+// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
+//
+
+#include "stdafx.h"
+#include "CNTKLibrary.h"
+
+namespace CNTK
+{
+    // A collection of additional options that are applicable for all standard learners
+    // (after these options are set, they retain their value for the entire lifespan of a learner).
+    struct AdditionalLearningOptions
+    {
+        double l1RegularizationWeight = 0.0;
+        double l2RegularizationWeight = 0.0;
+        double gaussianNoiseInjectionStdDev = 0.0;
+        bool gradientClippingWithTruncation = false;
+        double gradientClippingThresholdPerSample = std::numeric_limits<double>::infinity();
+        std::unordered_map<Variable, double> learningRateMultipliers;
+    };
+
+    // An abstract base class at the root of the standard learners hierarchy
+    // It implements most of the learner functionality, except for the actual update function,
+    // and adds a few pre-/postprocessing methods (which are invoked before and after the update).
+    class LearnerBase : public Learner
+    {
+    public:
+
+        CNTK_API virtual bool Update(const std::unordered_map<Variable, ValuePtr>& parameterValues,
+                                     const std::unordered_map<Variable, ValuePtr>& gradientValues,
+                                     size_t trainingSampleCount) override final;
+
+        CNTK_API virtual Dictionary GetCheckpointState() const override;
+
+        CNTK_API virtual void RestoreFromCheckpoint(const Dictionary& checkpoint) override;
+
+        CNTK_API void SetAdditionalOptions(const AdditionalLearningOptions& additionalOptions)
+        {
+            m_additionalOptions = additionalOptions;
+        }
+
+        // TODO: should this be called ResetMomentum?
+        // needed for BlockMomentumSGD to reset SGD momentum after aggregation.
+        CNTK_API void ResetSmoothedGradients();
+
+        // TODO: move learning rate and momentum scheduling and adjustment functionality
+        // inside the learner and drop these setters.
+        void SetLearningRate(double value) { m_learningRatePerSample = value; }
+
+    protected:
+        LearnerBase(const std::unordered_set<Variable>& parameters,
+                    const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice());
+
+        virtual void Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
+                            const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const = 0;
+
+        double ParameterDependentLearningRate(const Variable& parameter) const
+        {
+            return m_learningRatePerSample * m_additionalOptions.learningRateMultipliers.at(parameter);
+        }
+
+        std::string LearnerType() const;
+
+        double m_learningRatePerSample;
+
+        AdditionalLearningOptions m_additionalOptions;
+
+        std::unordered_map<Variable, ValuePtr> m_smoothedGradientValues;
+
+        // The following four static protected methods expose private methods of NDArrayView class
+        // (which declares LearnerBase as friend class), so that they are available to subclasses.
+        template <typename ElementType>
+        static std::shared_ptr<Microsoft::MSR::CNTK::Matrix<ElementType>> GetMatrix(const NDArrayViewPtr arrayView);
+
+        template <typename ElementType>
+        static std::shared_ptr<Microsoft::MSR::CNTK::Matrix<ElementType>> GetWritableMatrix(NDArrayViewPtr arrayView);
+
+        template <typename ElementType>
+        static const Microsoft::MSR::CNTK::TensorView<ElementType>* GetTensorView(const NDArrayViewPtr arrayView);
+
+        template <typename ElementType>
+        static Microsoft::MSR::CNTK::TensorView<ElementType>* GetWritableTensorView(NDArrayViewPtr arrayView);
+
+        template <typename ElementType>
+        void ClipGradient(Microsoft::MSR::CNTK::Matrix<ElementType>& gradient, size_t actualMBSize) const;
+
+        // Performs additional preprocessing before calling the update method
+        // (gradient clipping and L2 regularization depending on the additional learning parameters).
+        template <typename ElementType>
+        void PreProcess(const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t actualMBSize) const;
+
+        // Performs additional postprocessing after the update method has been executed
+        // (noise injection and L1 regularization specified by the additional learning parameters).
+        template <typename ElementType>
+        void PostProcess(const Variable& parameter, const ValuePtr& gradientValue,
+                         const ValuePtr& parameterValue, size_t actualMBSize) const;
+    private:
+        // Templatized update function, it invokes preprocess and postprocess using the provided
+        // template parameter and also invokes virtual Update method implemented in one of the subclasses.
+        template <typename ElementType>
+        void Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
+                    const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const;
+
+        // TODO: make these functions friends of NDArrayView and move to Utils?
+        static bool HasNan(const ValuePtr& value, const char* name);
+        static void Print(const ValuePtr& value, const char* msg);
+
+        size_t m_sampleCount;
+    };
+
+    // Vanilla gradient descent optimization algorithm.
+    class LearnerSGD : public LearnerBase
+    {
+    public:
+
+        LearnerSGD(const std::unordered_set<Variable>& parameters,
+                   const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice())
+            : LearnerBase(parameters, device),
+            m_momentumPerSample(0.0),
+            m_useNesterovAcceleration(false)
+        {
+        }
+
+    protected:
+
+        virtual void Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
+                            const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const override;
+
+        template <typename ElementType>
+        void Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
+                    const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const;
+
+        double m_momentumPerSample;
+        bool m_useNesterovAcceleration;
+    };
+
+    // SGD optimization with momentum.
+    class LearnerMomentumSGD : public LearnerSGD
+    {
+    public:
+
+        LearnerMomentumSGD(const std::unordered_set<Variable>& parameters,
+                           const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice())
+            : LearnerSGD(parameters, device)
+        {
+        }
+
+        void SetMomentum(double value) { m_momentumPerSample = value; }
+    };
+
+    // Nesterov's accelerated gradient descent.
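    // (Editorial note) The class below sets m_useNesterovAcceleration, selecting NormalGrad's
    // Nesterov path; heavily simplified, the classical NAG recurrence it corresponds to is
    // (assumed form, not verified against Matrix::NormalGrad):
    //
    //     v <- momentum * v - lr * g          // v: smoothed gradient buffer
    //     p <- p + momentum * v - lr * g      // p: parameter value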
+ class LearnerNesterov : public LearnerSGD + { + public: + + LearnerNesterov(const std::unordered_set& parameters, + const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice()) + : LearnerSGD(parameters, device) + { + m_useNesterovAcceleration = true; + } + }; + + class LearnerAdaGrad : public LearnerBase + { + public: + + LearnerAdaGrad(const std::unordered_set& parameters, bool needAveMultiplier, + const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice()); + + protected: + bool m_needAveMultiplier; + + virtual void Update(const Variable& parameter, const ValuePtr& smoothedGradientValue, + const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const override; + + template + void Update(const Variable& parameter, const ValuePtr& smoothedGradientValue, + const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const; + }; + + class LearnerFSAdaGrad : public LearnerMomentumSGD + { + public: + + LearnerFSAdaGrad(const std::unordered_set& parameters, + const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice()); + + protected: + + virtual void Update(const Variable& parameter, const ValuePtr& smoothedGradientValue, + const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const override; + + template + void Update(const Variable& parameter, const ValuePtr& smoothedGradientValue, + const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const; + }; + + class LearnerRMSProp : public LearnerBase + { + public: + + LearnerRMSProp(const std::unordered_set& parameters, + double gamma, double inc, double dec, double max, double min, bool needAveMultiplier, + const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice()); + + protected: + + double m_gamma; + double m_inc; + double m_dec; + double m_max; + double m_min; + bool m_needAveMultiplier; + + virtual void Update(const Variable& parameter, const ValuePtr& smoothedGradientValue, + const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const override; + + template + void Update(const Variable& parameter, const ValuePtr& smoothedGradientValue, + const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const; + }; +} \ No newline at end of file diff --git a/Source/CNTKv2LibraryDll/NDArrayView.cpp b/Source/CNTKv2LibraryDll/NDArrayView.cpp index c7f1d2973..1a4ed1ac1 100644 --- a/Source/CNTKv2LibraryDll/NDArrayView.cpp +++ b/Source/CNTKv2LibraryDll/NDArrayView.cpp @@ -338,8 +338,10 @@ namespace CNTK template std::shared_ptr> NDArrayView::GetMatrix(size_t rowColSplitPoint/* = AutoSelectRowColSplitPoint*/) const; template std::shared_ptr> NDArrayView::GetMatrix(size_t rowColSplitPoint/* = AutoSelectRowColSplitPoint*/) const; - template std::shared_ptr> NDArrayView::GetWritableMatrix(size_t rowColSplitPoint/* = AutoSelectRowColSplitPoint*/); - template std::shared_ptr> NDArrayView::GetWritableMatrix(size_t rowColSplitPoint/* = AutoSelectRowColSplitPoint*/); + template std::shared_ptr> NDArrayView::GetWritableMatrix(size_t rowColSplitPoint/* = AutoSelectRowColSplitPoint*/); + template std::shared_ptr> NDArrayView::GetWritableMatrix(size_t rowColSplitPoint/* = AutoSelectRowColSplitPoint*/); + template TensorView* NDArrayView::GetWritableTensorView(); + template TensorView* NDArrayView::GetWritableTensorView(); template CNTK_API NDArrayView::NDArrayView(const NDShape& viewShape, const SparseIndexType* colStarts, const 
SparseIndexType* rowIndices, const float* nonZeroValues, size_t numNonZeroValues, const DeviceDescriptor& device, bool readOnly/* = false*/); template CNTK_API NDArrayView::NDArrayView(const NDShape& viewShape, const SparseIndexType* colStarts, const SparseIndexType* rowIndices, const double* nonZeroValues, size_t numNonZeroValues, const DeviceDescriptor& device, bool readOnly/* = false*/); diff --git a/Source/CNTKv2LibraryDll/Utils.cpp b/Source/CNTKv2LibraryDll/Utils.cpp index 19b208b83..75114ddb6 100644 --- a/Source/CNTKv2LibraryDll/Utils.cpp +++ b/Source/CNTKv2LibraryDll/Utils.cpp @@ -6,11 +6,138 @@ #include "stdafx.h" #include "CNTKLibrary.h" #include "Utils.h" +#include "File.h" + +using namespace std; namespace CNTK { + template + void DictionaryValue::AllocateDataPtr(const T& value) + { + static_assert(is_same::value || is_same>::value, "AllocateDataPtr called with invalid type"); + m_data.m_ptr = new T(value); + } + + template + void DictionaryValue::FreePtrAsType() + { + T* typedPtr = reinterpret_cast(m_data.m_ptr); + delete typedPtr; + + m_data.m_ptr = nullptr; + } + + void DictionaryValue::FreeDataPtr() + { + if (m_valueType == Type::NDShape) + FreePtrAsType(); + else if (m_valueType == Type::Vector) + FreePtrAsType>(); + } + + Microsoft::MSR::CNTK::File& operator>>(Microsoft::MSR::CNTK::File& stream, DictionaryValue& us) + { + size_t version; + stream >> version; + + stream >> us.m_valueType; + + switch (us.ValueType()) + { + case DictionaryValue::Type::Bool: + stream >> us.m_data.m_boolean; + break; + case DictionaryValue::Type::SizeT: + stream >> us.m_data.m_sizeT; + break; + case DictionaryValue::Type::Float: + stream >> us.m_data.m_float; + break; + case DictionaryValue::Type::Double: + stream >> us.m_data.m_double; + break; + case DictionaryValue::Type::NDShape: + { + size_t size; + stream >> size; + vector dims(size); + for (auto i = 0; i < size; i++) + { + stream >> dims[i]; + } + us.AllocateDataPtr(NDShape(dims)); + break; + } + case DictionaryValue::Type::Vector: + { + size_t size; + stream >> size; + vector values(size); + for (auto i = 0; i < size; i++) + { + stream >> values[i]; + } + us.AllocateDataPtr(values); + break; + } + default: + NOT_IMPLEMENTED; + } + return stream; + } + + Microsoft::MSR::CNTK::File& operator<<(Microsoft::MSR::CNTK::File& stream, const DictionaryValue& us) + { + stream << us.version; + + stream << us.ValueType(); + + switch (us.ValueType()) + { + case DictionaryValue::Type::Bool: + stream << us.m_data.m_boolean; + break; + case DictionaryValue::Type::SizeT: + stream << us.m_data.m_sizeT; + break; + case DictionaryValue::Type::Float: + stream << us.m_data.m_float; + break; + case DictionaryValue::Type::Double: + stream << us.m_data.m_double; + break; + case DictionaryValue::Type::NDShape: + { + NDShape* shapePtr = reinterpret_cast(us.m_data.m_ptr); + auto size = shapePtr->NumAxes(); + stream << size; + for (auto i = 0; i < size; i++) + { + stream << shapePtr->operator[](i); + } + break; + } + case DictionaryValue::Type::Vector: + { + vector* vectorPtr = + reinterpret_cast*>(us.m_data.m_ptr); + auto size = vectorPtr->size(); + stream << size; + for (auto i = 0; i < size; i++) + { + stream << vectorPtr->operator[](i); + } + break; + } + default: + NOT_IMPLEMENTED; + } + return stream; + } + Dictionary::Dictionary() - : m_dictionaryData(new std::unordered_map < std::wstring, DictionaryValue>) + : m_dictionaryData(new unordered_map ) { } @@ -22,7 +149,7 @@ namespace CNTK Dictionary::Dictionary(Dictionary&& other) : m_dictionaryData(nullptr) { - 
*this = std::move(other); + *this = move(other); } Dictionary& Dictionary::operator=(Dictionary&& other) @@ -51,4 +178,130 @@ namespace CNTK { return (m_dictionaryData->find(key) != m_dictionaryData->end()); } + + Microsoft::MSR::CNTK::File& operator<<(Microsoft::MSR::CNTK::File& stream, const Dictionary& us) + { + stream << us.version; + stream << us.m_dictionaryData->size(); + for (auto it = us.m_dictionaryData->begin(); it != us.m_dictionaryData->end(); ++it) + { + stream << it->first; + stream << it->second; + } + return stream; + } + + Microsoft::MSR::CNTK::File& operator>>(Microsoft::MSR::CNTK::File& stream, Dictionary& us) + { + size_t version; + stream >> version; + size_t size; + stream >> size; + us.m_dictionaryData->reserve(size); + for (auto i = 0; i < size; i++) + { + wstring key; + stream >> key; + DictionaryValue value; + stream >> value; + us.m_dictionaryData->insert(make_pair(key, value)); + } + return stream; + } + + template + vector SerializeToVector(const NDArrayViewPtr& viewPtr) + { + if (viewPtr->IsSparse()) + { + LogicError("Sparse NDArrayView cannot be serialized into a vector."); + } + + auto numElements = viewPtr->Shape().TotalSize(); + + vector values(numElements); + + NDArrayViewPtr cpuDataViewPtr = viewPtr; + if ((viewPtr->Device().Type() != DeviceKind::CPU)) + { + cpuDataViewPtr = MakeSharedObject(viewPtr->GetDataType(), viewPtr->Shape(), DeviceDescriptor::CPUDevice()); + cpuDataViewPtr->CopyFrom(*viewPtr); + } + + const T* buffer = cpuDataViewPtr->DataBuffer(); + for (auto i = 0; i < numElements; ++i) + { + T v = buffer[i]; + values[i] = DictionaryValue(v); + } + + return values; + } + + template + void DeserializeFromVector(const NDArrayViewPtr& viewPtr, const vector& values) + { + if (viewPtr->IsSparse()) + { + LogicError("Sparse NDArrayView cannot be deserialized from a vector."); + } + + auto numElements = viewPtr->Shape().TotalSize(); + + if (values.size() != numElements) + { + LogicError("Number of elements (%lu) in the deserialized representation does not match the expected value (%lu)", + values.size(), numElements); + } + + NDArrayViewPtr cpuDataViewPtr = viewPtr; + if ((viewPtr->Device().Type() != DeviceKind::CPU)) + { + cpuDataViewPtr = MakeSharedObject(viewPtr->GetDataType(), viewPtr->Shape(), DeviceDescriptor::CPUDevice()); + } + + T* buffer = cpuDataViewPtr->WritableDataBuffer(); + for (auto i = 0; i < numElements; ++i) + { + buffer[i] = values[i].GetValue(); + } + + if ((viewPtr->Device().Type() != DeviceKind::CPU)) + { + viewPtr->CopyFrom(*cpuDataViewPtr); + } + } + + // TODO: we store the type info for every element in the vector, which is extremely redundant. + // Instead, it'd be nice to introduce some sort of DictionaryValueVector. 
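    // (Editorial sketch) Intended round trip for a learner's smoothed-gradient state:
    //
    //     std::vector<DictionaryValue> blob = SerializeToVector(view);   // copies to CPU first if needed
    //     ...                                                            // persisted/restored inside a Dictionary
    //     DeserializeFromVector(view, blob);                             // element counts must match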
+ vector SerializeToVector(const NDArrayViewPtr& viewPtr) + { + switch (viewPtr->GetDataType()) + { + case DataType::Float: + return SerializeToVector(viewPtr); + case DataType::Double: + return SerializeToVector(viewPtr); + default: + LogicError("Unsupported DataType %s", DataTypeName(viewPtr->GetDataType())); + } + } + + void DeserializeFromVector(const NDArrayViewPtr& viewPtr, const vector& values) + { + switch (viewPtr->GetDataType()) + { + case DataType::Float: + DeserializeFromVector(viewPtr, values); + break; + case DataType::Double: + DeserializeFromVector(viewPtr, values); + break; + default: + LogicError("Unsupported DataType %s", DataTypeName(viewPtr->GetDataType())); + } + } + + template void DictionaryValue::AllocateDataPtr(const NDShape& value); + template void DictionaryValue::AllocateDataPtr>(const vector& value); } diff --git a/Source/CNTKv2LibraryDll/Utils.h b/Source/CNTKv2LibraryDll/Utils.h index 68aa651fb..c18d8e4d2 100644 --- a/Source/CNTKv2LibraryDll/Utils.h +++ b/Source/CNTKv2LibraryDll/Utils.h @@ -15,245 +15,6 @@ namespace CNTK // Forward declarations class Dictionary; - class DictionaryValue - { - public: - enum class Type : unsigned int - { - None, - Bool, - SizeT, - Double, - NDShape, - Vector - }; - - static const char* TypeName(Type type) - { - if (type == Type::None) - return "None"; - else if (type == Type::Bool) - return "Bool"; - else if (type == Type::SizeT) - return "SizeT"; - else if (type == Type::Double) - return "Double"; - else if (type == Type::NDShape) - return "NDShape"; - else if (type == Type::Vector) - return "Vector"; - else - LogicError("Unknown DictionaryValue::Type"); - } - - public: - DictionaryValue() - : m_valueType(Type::None) - { - } - - DictionaryValue(bool value) - : m_valueType(GetValueType()) - { - m_data.m_boolean = value; - } - - DictionaryValue(size_t value) - : m_valueType(GetValueType()) - { - m_data.m_sizeT = value; - } - - DictionaryValue(double value) - : m_valueType(GetValueType()) - { - m_data.m_double = value; - } - - template - DictionaryValue(const T& value) - : m_valueType(GetValueType()) - { - static_assert(std::is_same::value || - std::is_same>::value, - "Unsupported ValueType"); - - AllocateDataPtr(value); - } - - DictionaryValue(const DictionaryValue& other) - : m_valueType(Type::Bool) - { - // The m_valueType must hvae been set to a non-ptr type to prevent an attempt to interpret - // the underlying underlying uninitialized value as a ptr and free it. 
- *this = other; - } - - DictionaryValue& operator=(const DictionaryValue& other) - { - if (this != &other) - { - FreeDataPtr(); - - m_valueType = other.m_valueType; - m_data = other.m_data; - - if (other.m_valueType == Type::NDShape) - AllocateDataPtr(other.GetValue()); - else if (other.m_valueType == Type::Vector) - AllocateDataPtr(other.GetValue>()); - } - - return *this; - } - - ~DictionaryValue() - { - FreeDataPtr(); - } - - template ::value>::type* = nullptr> - const T& GetValue() const - { - VerifyType(); - return m_data.m_boolean; - } - - template ::value>::type* = nullptr> - const T& GetValue() const - { - VerifyType(); - return m_data.m_sizeT; - } - - template ::value>::type* = nullptr> - const T& GetValue() const - { - VerifyType(); - return m_data.m_double; - } - - template ::value || std::is_same>::value>::type* = nullptr> - const T& GetValue() const - { - VerifyType(); - return *(reinterpret_cast(m_data.m_ptr)); - } - - bool HasValue() const - { - return m_valueType != Type::None; - } - - Type ValueType() const - { - return m_valueType; - } - - private: - template - static Type GetValueType() - { - static_assert(std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same>::value || - std::is_same::value, - "Unsupported ValueType"); - - if (std::is_same::value) - return Type::Bool; - else if (std::is_same::value) - return Type::SizeT; - else if (std::is_same::value) - return Type::Double; - else if (std::is_same::value) - return Type::NDShape; - else if (std::is_same>::value) - return Type::Vector; - } - - template - void VerifyType() const - { - if (GetValueType() != m_valueType) - RuntimeError("Reading a DictionaryValue as the wrong type; Reading as type %s when actual type is %s", typeid(T).name(), DictionaryValue::TypeName(m_valueType)); - } - - template - void AllocateDataPtr(const T& value) - { - static_assert(std::is_same::value || std::is_same>::value, "AllocateDataPtr called with invalid type"); - m_data.m_ptr = new T(value); - } - - template - void FreePtrAsType() - { - T* typedPtr = reinterpret_cast(m_data.m_ptr); - delete typedPtr; - - m_data.m_ptr = nullptr; - } - - void FreeDataPtr() - { - if (m_valueType == Type::NDShape) - FreePtrAsType(); - else if (m_valueType == Type::Vector) - FreePtrAsType>(); - } - - private: - Type m_valueType; - - union ValueData - { - bool m_boolean; - size_t m_sizeT; - double m_double; - void* m_ptr; - } m_data; - }; - - class Dictionary - { - public: - Dictionary(); - ~Dictionary(); - - // Disallow copy contruction and assignment - Dictionary(const Dictionary&) = delete; Dictionary& operator=(const Dictionary&) = delete; - - Dictionary(Dictionary&& other); - Dictionary& operator=(Dictionary&& other); - - DictionaryValue& operator[](const std::wstring& key) - { - return operator[](key.c_str()); - } - - DictionaryValue& operator[](const wchar_t* key); - - DictionaryValue operator[](const std::wstring& key) const - { - return operator[](key.c_str()); - } - - DictionaryValue operator[](const wchar_t* key) const; - - bool Contains(const std::wstring& key) const - { - return Contains(key.c_str()); - } - - bool Contains(const wchar_t* key) const; - - private: - std::unordered_map* m_dictionaryData; - }; - // Helper to get the size of an element of the specified DataType inline size_t ElementSize(DataType dataType) { @@ -363,4 +124,8 @@ namespace CNTK { return var.IsInput() && var.IsSparse(); } + + std::vector SerializeToVector(const NDArrayViewPtr& viewPtr); + + void 
DeserializeFromVector(const NDArrayViewPtr& viewPtr, const std::vector& values); }
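
Taken together, a minimal driving loop for the new Learner API might look like the sketch below (editorial addition, not part of the patch; it assumes the parameter/gradient value maps come from a Function's Forward/Backward calls and that a learning rate has been configured internally):

    #include "CNTKLibrary.h"

    using namespace CNTK;

    // One training step: a false return value indicates that learning has stopped
    // for all of this learner's parameters.
    bool TrainingStep(const LearnerPtr& learner,
                      const std::unordered_map<Variable, ValuePtr>& parameterValues,
                      const std::unordered_map<Variable, ValuePtr>& gradientValues,
                      size_t minibatchSampleCount)
    {
        return learner->Update(parameterValues, gradientValues, minibatchSampleCount);
    }

    // Learner construction picks one of the factories added in CNTKLibrary.h, e.g.:
    //     LearnerPtr learner = MomentumSGDLearner(parameters);   // or SGDLearner, NesterovLearner, ...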