Alexey Reznichenko 2016-07-13 09:37:06 +02:00
Parent 271476466a
Commit 1b0548fdde
10 changed files with 1301 additions and 243 deletions

View file

@@ -375,6 +375,8 @@ CNTKLIBRARY_SRC =\
$(SOURCEDIR)/CNTKv2LibraryDll/Utils.cpp \
$(SOURCEDIR)/CNTKv2LibraryDll/Value.cpp \
$(SOURCEDIR)/CNTKv2LibraryDll/Variable.cpp \
$(SOURCEDIR)/CNTKv2LibraryDll/Learner.cpp \
CNTKLIBRARY_SRC+=$(CNTK_COMMON_SRC)
CNTKLIBRARY_SRC+=$(COMPUTATION_NETWORK_LIB_SRC)

View file

@@ -285,6 +285,7 @@ namespace CNTK
class NDArrayView final : public std::enable_shared_from_this<NDArrayView>
{
friend class CompositeFunction;
friend class LearnerBase;
template <typename T, typename ...CtorArgTypes>
friend inline std::shared_ptr<T> MakeSharedObject(CtorArgTypes&& ...ctorArgs);
@@ -1396,4 +1397,342 @@ namespace CNTK
/// of the computation graph which can be "Combine"d to create a single Function with 2 outputs; viz. CrossEntropy loss and ClassificationError output.
///
CNTK_API FunctionPtr Combine(const std::initializer_list<FunctionPtr>& operands, const std::wstring& name = L"");
///
/// A serializable value represents one of:
/// a) Boolean
/// b) Signed long integer
/// c) Single and double precision floating point values
/// d) NDShape
/// e) vector<DictionaryValue>
///
/// TODO: we need to have native support for DictionaryValue<vector> and DictionaryValue<NDArrayView>.
class CNTK_API DictionaryValue final
{
public:
enum class Type : unsigned int
{
None,
Bool,
SizeT,
Float,
Double,
NDShape,
Vector
};
static const char* TypeName(Type type)
{
switch (type)
{
case Type::None:
return "None";
case Type::Bool:
return "Bool";
case Type::SizeT:
return "SizeT";
case Type::Float:
return "Float";
case Type::Double:
return "Double";
case Type::NDShape:
return "NDShape";
case Type::Vector:
return "Vector";
default:
LogicError("Unknown DictionaryValue::Type");
}
}
public:
DictionaryValue() : m_valueType(Type::None)
{
}
DictionaryValue(bool value) : m_valueType(GetValueType<bool>())
{
m_data.m_boolean = value;
}
DictionaryValue(size_t value) : m_valueType(GetValueType<size_t>())
{
m_data.m_sizeT = value;
}
DictionaryValue(float value) : m_valueType(GetValueType<float>())
{
m_data.m_float = value;
}
DictionaryValue(double value) : m_valueType(GetValueType<double>())
{
m_data.m_double = value;
}
template <typename T>
DictionaryValue(const T& value) : m_valueType(GetValueType<T>())
{
static_assert(std::is_same<T, NDShape>::value ||
std::is_same<T, std::vector<DictionaryValue>>::value,
"Unsupported ValueType");
AllocateDataPtr(value);
}
DictionaryValue(const DictionaryValue& other) : m_valueType(Type::Bool)
{
// The m_valueType must have been set to a non-ptr type to prevent an attempt to interpret
// the underlying uninitialized value as a ptr and free it.
*this = other;
}
DictionaryValue& operator=(const DictionaryValue& other)
{
if (this != &other)
{
FreeDataPtr();
m_valueType = other.m_valueType;
m_data = other.m_data;
if (other.m_valueType == Type::NDShape)
AllocateDataPtr(other.GetValue<NDShape>());
else if (other.m_valueType == Type::Vector)
AllocateDataPtr(other.GetValue<std::vector<DictionaryValue>>());
}
return *this;
}
~DictionaryValue()
{
FreeDataPtr();
}
template <typename T, typename std::enable_if<std::is_same<T, bool>::value>::type* = nullptr>
const T& GetValue() const
{
VerifyType<T>();
return m_data.m_boolean;
}
template <typename T, typename std::enable_if<std::is_same<T, size_t>::value>::type* = nullptr>
const T& GetValue() const
{
VerifyType<T>();
return m_data.m_sizeT;
}
template <typename T, typename std::enable_if<std::is_same<T, float>::value>::type* = nullptr>
const T& GetValue() const
{
VerifyType<T>();
return m_data.m_float;
}
template <typename T, typename std::enable_if<std::is_same<T, double>::value>::type* = nullptr>
const T& GetValue() const
{
VerifyType<T>();
return m_data.m_double;
}
template <typename T, typename std::enable_if<std::is_same<T, NDShape>::value || std::is_same<T, std::vector<DictionaryValue>>::value>::type* = nullptr>
const T& GetValue() const
{
VerifyType<T>();
return *(reinterpret_cast<T*>(m_data.m_ptr));
}
bool HasValue() const
{
return m_valueType != Type::None;
}
Type ValueType() const
{
return m_valueType;
}
friend CNTK_API Microsoft::MSR::CNTK::File& operator>>(Microsoft::MSR::CNTK::File& stream, DictionaryValue& us);
friend CNTK_API Microsoft::MSR::CNTK::File& operator<<(Microsoft::MSR::CNTK::File& stream, const DictionaryValue& us);
private:
template <typename T>
static Type GetValueType()
{
static_assert(std::is_same<T, bool>::value ||
std::is_same<T, size_t>::value ||
std::is_same<T, float>::value ||
std::is_same<T, double>::value ||
std::is_same<T, NDShape>::value ||
std::is_same<T, std::vector<DictionaryValue>>::value,
"Unsupported ValueType");
if (std::is_same<T, bool>::value) return Type::Bool;
if (std::is_same<T, size_t>::value) return Type::SizeT;
if (std::is_same<T, float>::value) return Type::Float;
if (std::is_same<T, double>::value) return Type::Double;
if (std::is_same<T, NDShape>::value) return Type::NDShape;
if (std::is_same<T, std::vector<DictionaryValue>>::value) return Type::Vector;
}
template <typename T>
void VerifyType() const
{
if (GetValueType<T>() != m_valueType)
RuntimeError("Reading a DictionaryValue as the wrong type; Reading as type %s when actual type is %s", typeid(T).name(), DictionaryValue::TypeName(m_valueType));
}
template <typename T>
void AllocateDataPtr(const T& value);
template <typename T>
void FreePtrAsType();
void FreeDataPtr();
Type m_valueType;
union ValueData
{
bool m_boolean;
size_t m_sizeT;
float m_float;
double m_double;
void* m_ptr;
} m_data;
const size_t version = 1;
};
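For orientation, a minimal usage sketch of the new DictionaryValue type (illustrative only, not part of this commit; it assumes nothing beyond the declarations above):

    #include <cassert>
    #include <vector>
    #include "CNTKLibrary.h"

    void DictionaryValueSketch()
    {
        CNTK::DictionaryValue empty;              // Type::None
        CNTK::DictionaryValue count(size_t(128)); // Type::SizeT
        CNTK::DictionaryValue rate(0.05);         // double overload -> Type::Double

        assert(!empty.HasValue());
        assert(count.GetValue<size_t>() == 128);
        assert(rate.ValueType() == CNTK::DictionaryValue::Type::Double);
        // rate.GetValue<float>() would fail VerifyType() and call RuntimeError.

        // Pointer-backed payloads (NDShape, vector<DictionaryValue>) are deep-copied
        // by the copy constructor/assignment via AllocateDataPtr.
        CNTK::DictionaryValue shape(CNTK::NDShape(std::vector<size_t>{ 3, 4 }));
        CNTK::DictionaryValue copy(shape);
        assert(copy.ValueType() == CNTK::DictionaryValue::Type::NDShape);
    }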
///
/// A type denoting a dictionary (keyed by Unicode strings) of serializable values (dynamically typed).
///
class CNTK_API Dictionary final
{
public:
Dictionary();
~Dictionary();
// Disallow copy construction and assignment
Dictionary(const Dictionary&) = delete; Dictionary& operator=(const Dictionary&) = delete;
Dictionary(Dictionary&& other);
Dictionary& operator=(Dictionary&& other);
DictionaryValue& operator[](const std::wstring& key)
{
return operator[](key.c_str());
}
DictionaryValue& operator[](const wchar_t* key);
DictionaryValue operator[](const std::wstring& key) const
{
return operator[](key.c_str());
}
DictionaryValue operator[](const wchar_t* key) const;
bool Contains(const std::wstring& key) const
{
return Contains(key.c_str());
}
bool Contains(const wchar_t* key) const;
friend CNTK_API Microsoft::MSR::CNTK::File& operator>>(Microsoft::MSR::CNTK::File& stream, Dictionary& us);
friend CNTK_API Microsoft::MSR::CNTK::File& operator<<(Microsoft::MSR::CNTK::File& stream, const Dictionary& us);
private:
std::unordered_map<std::wstring, DictionaryValue>* m_dictionaryData;
const size_t version = 1;
};
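Likewise, a small Dictionary sketch (hypothetical client code; Dictionary is move-only, so ownership transfer goes through std::move):

    #include <utility>
    #include "CNTKLibrary.h"

    void DictionarySketch()
    {
        CNTK::Dictionary state;
        state[L"learningRate"] = CNTK::DictionaryValue(0.005); // non-const operator[] gives a mutable slot for the key
        state[L"sampleCount"] = CNTK::DictionaryValue(size_t(0));

        if (state.Contains(L"learningRate"))
        {
            double lr = state[L"learningRate"].GetValue<double>();
            (void)lr;
        }

        CNTK::Dictionary moved(std::move(state)); // copy construction/assignment are deleted
    }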
///
/// Abstraction for learning a subset of parameters of a learnable function using first order gradient values.
/// E.g., momentum, AdaGrad, RMSProp, etc., are different types of learners with their own algorithms for
/// learning parameter values using first order gradients.
///
class Learner : public std::enable_shared_from_this<Learner>
{
public:
//
// Method to update the parameters associated with this learner. By returning false, this method indicates that
// learning has stopped for all of the parameters associated with this learner
//
CNTK_API virtual bool Update(const std::unordered_map<Variable, ValuePtr>& parameterValues,
const std::unordered_map<Variable, const ValuePtr>& gradientValues,
size_t trainingSampleCount) = 0;
///
/// Returns the set of parameters associated with this learner.
///
const std::unordered_set<Variable>& Parameters() const { return m_parameters; }
// TODO: move the following two methods into ISerializable interface, make
// Learner (and all other entities that need checkpointing capability) implement it.
///
/// Optionally overridable method to checkpoint the learner's state.
///
CNTK_API virtual Dictionary GetCheckpointState() const = 0;
///
/// Optionally overridable method to restore the learner's state from a previous checkpoint.
///
CNTK_API virtual void RestoreFromCheckpoint(const Dictionary& checkpoint) = 0;
virtual ~Learner()
{
}
protected:
Learner(const std::unordered_set<Variable>& parameters)
: m_parameters(parameters)
{
}
std::unordered_set<Variable> m_parameters;
};
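To make the contract concrete, here is a hypothetical no-op learner written against the interface above (a sketch, not code from this commit):

    #include <unordered_map>
    #include <unordered_set>
    #include "CNTKLibrary.h"

    class NullLearner final : public CNTK::Learner
    {
    public:
        explicit NullLearner(const std::unordered_set<CNTK::Variable>& parameters)
            : Learner(parameters) {}

        bool Update(const std::unordered_map<CNTK::Variable, CNTK::ValuePtr>& /*parameterValues*/,
                    const std::unordered_map<CNTK::Variable, const CNTK::ValuePtr>& /*gradientValues*/,
                    size_t /*trainingSampleCount*/) override
        {
            // A real learner would apply a gradient step to each parameter here.
            return true; // true: learning continues; false would signal it has stopped
        }

        CNTK::Dictionary GetCheckpointState() const override { return CNTK::Dictionary(); }
        void RestoreFromCheckpoint(const CNTK::Dictionary& /*checkpoint*/) override {}
    };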
///
/// Create an instance of the CNTK built-in SGD learner.
///
/// TODO: add additional SGD parameters here (a collection of learning rate values)
CNTK_API LearnerPtr SGDLearner(const std::unordered_set<Variable>& parameters,
const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice());
///
/// Create an instance of the CNTK built-in Momentum SGD learner.
///
/// TODO: add additional Momentum parameters here (a collection of momentum rate values)
CNTK_API LearnerPtr MomentumSGDLearner(const std::unordered_set<Variable>& parameters,
const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice());
///
/// Create an instance of the CNTK built-in Nesterov's accelerated SGD learner.
///
CNTK_API LearnerPtr NesterovLearner(const std::unordered_set<Variable>& parameters,
const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice());
///
/// Create an instance of the CNTK built-in AdaGrad learner.
///
CNTK_API LearnerPtr AdaGradLearner(const std::unordered_set<Variable>& parameters, bool needAveMultiplier = true,
const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice());
///
/// Create an instance of the CNTK built-in FSAdaGrad (improved AdaGrad) learner.
///
CNTK_API LearnerPtr FSAdaGradLearner(const std::unordered_set<Variable>& parameters,
const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice());
///
/// Create an instance of the CNTK built-in RMSProp learner.
///
CNTK_API LearnerPtr RMSPropLearner(const std::unordered_set<Variable>& parameters,
double gamma, double inc, double dec, double max, double min, bool needAveMultiplier = true,
const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice());
}
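Typical client code goes through these factories; a brief hedged sketch (in practice the parameter set would come from a Function's learnable parameters):

    #include <unordered_set>
    #include "CNTKLibrary.h"

    void LearnerFactorySketch(const std::unordered_set<CNTK::Variable>& parameters)
    {
        auto sgd = CNTK::SGDLearner(parameters); // default device
        auto rmsprop = CNTK::RMSPropLearner(parameters,
                                            /*gamma*/ 0.95, /*inc*/ 1.2, /*dec*/ 0.7,
                                            /*max*/ 10.0, /*min*/ 0.1);
        // Both are LearnerPtr (shared_ptr<Learner>); a training loop would then call
        // Update(parameterValues, gradientValues, minibatchSize) once per minibatch.
        (void)sgd; (void)rmsprop;
    }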

View file

@@ -47,6 +47,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
template <typename ElementType>
class ComputationNode;
class File;
}}}
// TODO: The following should be reconciled with the equivalent code in the CNTK implementation
@@ -158,4 +160,7 @@ namespace CNTK
class Function;
typedef std::shared_ptr<Function> FunctionPtr;
class Learner;
typedef std::shared_ptr<Learner> LearnerPtr;
}

View file

@@ -128,6 +128,7 @@
<ClInclude Include="API\CNTKLibrary.h" /> <ClInclude Include="API\CNTKLibrary.h" />
<ClInclude Include="API\CNTKLibraryInternals.h" /> <ClInclude Include="API\CNTKLibraryInternals.h" />
<ClInclude Include="Function.h" /> <ClInclude Include="Function.h" />
<ClInclude Include="Learner.h" />
<ClInclude Include="Utils.h" /> <ClInclude Include="Utils.h" />
<ClInclude Include="stdafx.h" /> <ClInclude Include="stdafx.h" />
<ClInclude Include="targetver.h" /> <ClInclude Include="targetver.h" />
@@ -140,6 +141,7 @@
</PrecompiledHeader>
</ClCompile>
<ClCompile Include="Function.cpp" />
<ClCompile Include="Learner.cpp" />
<ClCompile Include="NDArrayView.cpp" /> <ClCompile Include="NDArrayView.cpp" />
<ClCompile Include="NDMask.cpp" /> <ClCompile Include="NDMask.cpp" />
<ClCompile Include="stdafx.cpp"> <ClCompile Include="stdafx.cpp">

View file

@@ -10,6 +10,7 @@
<ClCompile Include="Variable.cpp" /> <ClCompile Include="Variable.cpp" />
<ClCompile Include="Utils.cpp" /> <ClCompile Include="Utils.cpp" />
<ClCompile Include="NDMask.cpp" /> <ClCompile Include="NDMask.cpp" />
<ClCompile Include="Learner.cpp" />
</ItemGroup>
<ItemGroup>
<ClInclude Include="stdafx.h" />
@@ -22,6 +23,7 @@
<Filter>API</Filter>
</ClInclude>
<ClInclude Include="Function.h" />
<ClInclude Include="Learner.h" />
</ItemGroup>
<ItemGroup>
<Filter Include="API">

View file

@@ -0,0 +1,464 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
#include "Learner.h"
#include "TensorView.h"
#include "Utils.h"
#define UPDATE_FUNCTION \
switch (smoothedGradientValue->Data()->GetDataType()) \
{ \
case DataType::Float: \
Update<float>(parameter, smoothedGradientValue, gradientValue, parameterValue, trainingSampleCount); \
break; \
case DataType::Double: \
Update<double>(parameter, smoothedGradientValue, gradientValue, parameterValue, trainingSampleCount); \
break; \
default: \
NOT_IMPLEMENTED; \
}
using namespace Microsoft::MSR::CNTK;
using namespace std;
namespace CNTK
{
template <typename ElementType>
/*static*/ shared_ptr<const Matrix<ElementType>> LearnerBase::GetMatrix(const NDArrayViewPtr arrayView)
{
return arrayView->GetMatrix<ElementType>();
}
template <typename ElementType>
/*static*/ shared_ptr<Matrix<ElementType>> LearnerBase::GetWritableMatrix(NDArrayViewPtr arrayView)
{
return arrayView->GetWritableMatrix<ElementType>();
}
template <typename ElementType>
/*static*/ const TensorView<ElementType>* LearnerBase::GetTensorView(const NDArrayViewPtr arrayView)
{
return arrayView->GetTensorView<ElementType>();
}
/*static*/ bool LearnerBase::HasNan(const ValuePtr& value, const char* name)
{
const auto& data = value->Data();
switch (data->GetDataType())
{
case DataType::Float:
return data->GetMatrix<float>()->HasNan(name);
case DataType::Double:
return data->GetMatrix<double>()->HasNan(name);
default:
LogicError("Unsupported DataType %s", DataTypeName(data->GetDataType()));
}
}
/*static*/ void LearnerBase::Print(const ValuePtr& value, const char* msg)
{
const auto& data = value->Data();
switch (data->GetDataType())
{
case DataType::Float:
data->GetMatrix<float>()->Print(msg);
break;
case DataType::Double:
data->GetMatrix<double>()->Print(msg);
break;
default:
LogicError("Unsupported DataType %s", DataTypeName(data->GetDataType()));
}
}
// Clipping gradients to prevent outliers.
template <typename ElementType>
void LearnerBase::ClipGradient(Matrix<ElementType>& gradient, size_t actualMBSize) const
{
if (m_additionalOptions.gradientClippingThresholdPerSample != numeric_limits<double>::infinity())
{
double maxGradientPerMB = m_additionalOptions.gradientClippingThresholdPerSample * actualMBSize;
if (m_additionalOptions.gradientClippingWithTruncation)
gradient.InplaceTruncate(ElementType(maxGradientPerMB));
else
{
// norm2 normalized
double gradientNorm = gradient.FrobeniusNorm();
if (gradientNorm > maxGradientPerMB)
{
double normFactor = maxGradientPerMB / gradientNorm;
gradient *= ElementType(normFactor);
}
}
}
}
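In math terms, with c = gradientClippingThresholdPerSample and minibatch size m, the two branches above are (LaTeX):

    % truncation branch: elementwise clamp
    g_i \leftarrow \max(-cm,\ \min(cm,\ g_i))
    % norm branch: rescale only when the Frobenius norm exceeds the budget
    g \leftarrow g \cdot \frac{cm}{\lVert g \rVert_F} \quad \text{if } \lVert g \rVert_F > cm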
// Performs additional preprocessing before calling the update method
// (gradient clipping and L2 regularization depending on the additional learning parameters).
template <typename ElementType>
void LearnerBase::PreProcess(const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t actualMBSize) const
{
const auto& gradientMatrix = gradientValue->Data()->GetWritableMatrix<ElementType>();
// clipping gradients to prevent outliers
ClipGradient<ElementType>(*gradientMatrix, actualMBSize);
// L2 regularizer
if (m_additionalOptions.l2RegularizationWeight > 0)
{
// multiply by actualMBSize so that it's invariant to minibatch size since learning rate is per sample
auto weight = ElementType(m_additionalOptions.l2RegularizationWeight * actualMBSize);
const auto& parameterMatrix = parameterValue->Data()->GetWritableMatrix<ElementType>();
Matrix<ElementType>::ScaleAndAdd(weight, *parameterMatrix, *gradientMatrix);
}
}
// Performs additional postprocessing after the update method has been executed
// (noise injection and L1 regularization specified by the additional learning parameters).
template <typename ElementType>
void LearnerBase::PostProcess(const Variable& parameter, const ValuePtr& gradientValue,
const ValuePtr& parameterValue, size_t actualMBSize) const
{
const auto& parameterMatrix = parameterValue->Data()->GetWritableMatrix<ElementType>();
if (m_additionalOptions.gaussianNoiseInjectionStdDev > 0)
{
const auto& gradientMatrix = gradientValue->Data()->GetWritableMatrix<ElementType>();
Matrix<ElementType> sgdUpdateNoise((DEVICEID_TYPE)parameterMatrix->GetDeviceId());
// get the gradient structure since gradient is sparse
sgdUpdateNoise.SetValue(*gradientMatrix);
auto noiseStdDev = ElementType(m_additionalOptions.gaussianNoiseInjectionStdDev);
// reset its value to random
sgdUpdateNoise.SetGaussianRandomValue(ElementType(0.0), noiseStdDev);
Matrix<ElementType>::ScaleAndAdd(ElementType(1.0), sgdUpdateNoise, *parameterMatrix);
}
// L1 regularizer with proximal gradient descent method
if (m_additionalOptions.l1RegularizationWeight > 0)
{
auto learningRate = ElementType(ParameterDependentLearningRate(parameter));
// multiply by actualMBSize so that it's invariant to minibatch size since learning rate is per sample
auto weight = ElementType(learningRate * m_additionalOptions.l1RegularizationWeight * actualMBSize);
parameterValue->Data()->GetWritableMatrix<ElementType>()->InplaceSoftThreshold(weight);
}
}
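Spelled out, PreProcess folds the L2 term into the gradient and PostProcess applies the L1 step as a proximal (soft-threshold) update; with parameter w, gradient g, per-sample rate \eta, minibatch size m, and weights \lambda_1, \lambda_2 (LaTeX):

    g \leftarrow g + \lambda_2 m \, w   % L2, via ScaleAndAdd in PreProcess
    w_i \leftarrow \operatorname{sign}(w_i)\,\max(\lvert w_i \rvert - \eta \lambda_1 m,\ 0)   % L1, via InplaceSoftThreshold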
template <typename ElementType>
/*static*/ TensorView<ElementType>* LearnerBase::GetWritableTensorView(NDArrayViewPtr arrayView)
{
return arrayView->GetWritableTensorView<ElementType>();
}
LearnerBase::LearnerBase(const unordered_set<Variable>& parameters, const DeviceDescriptor& device)
: Learner(parameters),
m_learningRatePerSample(0.0),
m_sampleCount(0)
{
const unordered_set<Variable>& parameterSet = parameters;
for (const auto& parameter : parameterSet)
{
// TODO: using the same device to allocate data for all smoothed gradients. Is this correct?
// Should the device be specified on the per-parameter basis?
NDArrayViewPtr view;
if (parameter.GetDataType() == DataType::Float)
{
view = MakeSharedObject<NDArrayView>(0.0f, parameter.Shape(), device);
}
else
{
view = MakeSharedObject<NDArrayView>(0.0, parameter.Shape(), device);
}
m_smoothedGradientValues.insert(make_pair(parameter, MakeSharedObject<Value>(view)));
m_additionalOptions.learningRateMultipliers.insert(make_pair(parameter, 1.0));
}
}
void LearnerBase::ResetSmoothedGradients()
{
for (const auto& parameter : Parameters())
{
const auto& smoothedGradientValue = m_smoothedGradientValues.at(parameter);
const auto& data = smoothedGradientValue->Data();
switch (data->GetDataType())
{
case DataType::Float:
data->SetValue(0.0f);
break;
case DataType::Double:
data->SetValue(0.0);
break;
default:
LogicError("Unsupported DataType %s", ::CNTK::DataTypeName(data->GetDataType()));
}
}
}
/*virtual*/ bool LearnerBase::Update(const unordered_map<Variable, ValuePtr>& parameterValues,
const unordered_map<Variable, const ValuePtr>& gradientValues,
size_t trainingSampleCount) /*override*/
{
// make sure trainingSampleCount is a valid value
assert(trainingSampleCount > 0);
for (const auto& parameter : Parameters())
{
const auto& smoothedGradientValue = m_smoothedGradientValues.at(parameter);
const auto& gradientValue = gradientValues.at(parameter);
const auto& parameterValue = parameterValues.at(parameter);
// TODO: make this a runtime parameter.
#if DUMPOUTPUT
LOGPRINTF(stderr, "Update_%ls\n", parameter.Name().c_str());
#endif
#ifdef _DEBUG
if (HasNan(smoothedGradientValue, "TrainOneEpoch/UpdateWeights/Learner::Update(): "))
LogicError("%ls has NaNs in smoothedGradient.", parameter.Name().c_str());
#endif
#if DUMPOUTPUT
LOGPRINTF(stderr, "learnRatePerSample=%0.8f, momentum=%0.8f, actualMBSize=%ld\n",
m_learningRatePerSample, m_momentumPerSample, trainingSampleCount);
LOGPRINTF(stderr, "GradUpdateType()=%s, GradientUpdateNoiseStd()=%0.8f\n",
LearnerType().c_str(), m_GaussianNoiseInjectStd);
Print(gradientValue, "Gradient Update");
Print(smoothedGradientValue, "Smoothed Gradient Input");
#endif
UPDATE_FUNCTION;
#if DUMPOUTPUT
Print(parameterValue, "Parameter Update");
#endif
#ifdef _DEBUG
if (HasNan(parameterValue, "TrainOneEpoch/UpdateWeights/Learner::Update(): "))
LogicError("%ls has NaNs in parameter values after parameter update.", parameter.Name().c_str());
#endif
}
m_sampleCount += trainingSampleCount;
return false;
}
template <typename ElementType>
void LearnerBase::Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const
{
PreProcess<ElementType>(gradientValue, parameterValue, trainingSampleCount);
Update(parameter, smoothedGradientValue, gradientValue, parameterValue, trainingSampleCount);
PostProcess<ElementType>(parameter, gradientValue, parameterValue, trainingSampleCount);
}
string LearnerBase::LearnerType() const
{
auto name = typeid(*this).name();
if (strncmp(name, "class ", 6) == 0)
{
// On Windows, the type name contains "class" prefix.
// Return the actual name, omitting the prefix.
return &name[6];
}
return name;
}
/*virtual*/ Dictionary LearnerBase::GetCheckpointState() const /*override*/
{
NOT_IMPLEMENTED; // Until the new checkpointing is fully fleshed out, nobody should be calling this.
Dictionary checkpoint;
for (const auto& parameter : Parameters())
{
// TODO: parameter name is not guaranteed to be unique. Instead, all serializable objects
// need to expose "UId" property -- a persistent unique internal name.
// Switch to UId as soon as it's available.
if (checkpoint.Contains(parameter.Name()))
{
LogicError("Parameter names must be unique");
}
const auto& smoothedGradientValue = m_smoothedGradientValues.at(parameter);
// Potentially, could store things like dimensions, element size, format, etc., but
// that seems to be redundant, since all of that is passed in the constructor.
checkpoint[parameter.Name()] = SerializeToVector(smoothedGradientValue->Data());
}
return checkpoint;
}
/*virtual*/ void LearnerBase::RestoreFromCheckpoint(const Dictionary& checkpoint) /*override*/
{
NOT_IMPLEMENTED; // Until the new checkpointing is fully fleshed out, nobody should be calling this.
for (const auto& parameter : Parameters())
{
if (!checkpoint.Contains(parameter.Name()))
{
LogicError("Checkpoint does not contain state for parameter %ls", parameter.Name().c_str());
}
const auto& smoothedGradientValue = m_smoothedGradientValues.at(parameter);
const DictionaryValue& state = checkpoint[parameter.Name()];
const auto& data = smoothedGradientValue->Data();
DeserializeFromVector(data, state.GetValue<vector<DictionaryValue>>());
}
}
/*virtual*/ void LearnerSGD::Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const /*override*/
{
UPDATE_FUNCTION;
}
template <typename ElementType>
void LearnerSGD::Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const
{
UNUSED(trainingSampleCount);
const auto& smoothedGradientMatrix = GetWritableMatrix<ElementType>(smoothedGradientValue->Data());
const auto& gradientMatrix = GetWritableMatrix<ElementType>(gradientValue->Data());
const auto& parameterMatrix = GetWritableMatrix<ElementType>(parameterValue->Data());
const auto& learningRate = ElementType(ParameterDependentLearningRate(parameter));
// TODO: break up the NormalGrad into 3 different functions, each with its own set of parameters
// (one for vanilla SGD, the other for momentum SGD, and the third one for NAG).
smoothedGradientMatrix->NormalGrad(*gradientMatrix, *parameterMatrix,
learningRate, ElementType(m_momentumPerSample), m_useNesterovAcceleration);
}
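For reference, the textbook updates this call covers, with smoothed gradient s, momentum \mu, and per-sample rate \eta (hedged: the exact scaling, e.g. a possible (1-\mu) factor on the gradient term, is defined inside Matrix::NormalGrad and may differ):

    s \leftarrow \mu s + \eta g, \qquad w \leftarrow w - s   % classical momentum; \mu = 0 gives vanilla SGD
    s \leftarrow \mu s + \eta g, \qquad w \leftarrow w - (\mu s + \eta g)   % Nesterov variant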
LearnerAdaGrad::LearnerAdaGrad(const unordered_set<Variable>& parameters, bool needAveMultiplier, const DeviceDescriptor& device)
: LearnerBase(parameters, device),
m_needAveMultiplier(needAveMultiplier)
{
}
/*virtual*/ void LearnerAdaGrad::Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const /*override*/
{
UPDATE_FUNCTION;
}
template <typename ElementType>
void LearnerAdaGrad::Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const
{
UNUSED(trainingSampleCount);
const auto& smoothedGradientMatrix = GetWritableMatrix<ElementType>(smoothedGradientValue->Data());
const auto& gradientMatrix = GetWritableMatrix<ElementType>(gradientValue->Data());
const auto& parameterMatrix = GetWritableMatrix<ElementType>(parameterValue->Data());
auto learningRate = ElementType(ParameterDependentLearningRate(parameter));
auto aveMultiplier = smoothedGradientMatrix->Adagrad(*gradientMatrix, m_needAveMultiplier);
Matrix<ElementType>::ScaleAndAdd(ElementType(-learningRate / aveMultiplier), *gradientMatrix, *parameterMatrix);
}
LearnerFSAdaGrad::LearnerFSAdaGrad(const unordered_set<Variable>& parameters, const DeviceDescriptor& device)
: LearnerMomentumSGD(parameters, device)
{
}
/*virtual*/ void LearnerFSAdaGrad::Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const /*override*/
{
UPDATE_FUNCTION;
}
template <typename ElementType>
void LearnerFSAdaGrad::Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const
{
UNUSED(trainingSampleCount);
const auto& smoothedGradientMatrix = GetWritableMatrix<ElementType>(smoothedGradientValue->Data());
const auto& gradientMatrix = GetWritableMatrix<ElementType>(gradientValue->Data());
const auto& parameterMatrix = GetWritableMatrix<ElementType>(parameterValue->Data());
//const double momentum = MomentumPerMB(m_momentumPerSample, trainingSampleCount);
auto learningRate = ElementType(ParameterDependentLearningRate(parameter));
smoothedGradientMatrix->FSAdagrad(trainingSampleCount, *gradientMatrix, *parameterMatrix,
learningRate, ElementType(m_momentumPerSample));
}
LearnerRMSProp::LearnerRMSProp(const unordered_set<Variable>& parameters,
double gamma, double inc, double dec, double max, double min,
bool needAveMultiplier, const DeviceDescriptor& device)
: LearnerBase(parameters, device),
m_gamma(gamma), m_inc(inc), m_dec(dec), m_max(max), m_min(min),
m_needAveMultiplier(needAveMultiplier)
{
}
/*virtual*/ void LearnerRMSProp::Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const /*override*/
{
UPDATE_FUNCTION;
}
template <typename ElementType>
void LearnerRMSProp::Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const
{
UNUSED(trainingSampleCount);
const auto& smoothedGradientMatrix = GetWritableMatrix<ElementType>(smoothedGradientValue->Data());
const auto& gradientMatrix = GetWritableMatrix<ElementType>(gradientValue->Data());
const auto& parameterMatrix = GetWritableMatrix<ElementType>(parameterValue->Data());
auto learningRate = ElementType(ParameterDependentLearningRate(parameter));
auto aveMultiplier = smoothedGradientMatrix->RmsProp(*gradientMatrix,
ElementType(m_gamma), ElementType(m_inc),
ElementType(m_max), ElementType(m_dec),
ElementType(m_min), m_needAveMultiplier);
Matrix<ElementType>::ScaleAndAdd(ElementType(-learningRate / aveMultiplier), *gradientMatrix, *parameterMatrix);
}
// Explicit template instantiations
template shared_ptr<Matrix<float>> LearnerBase::GetWritableMatrix<float>(const NDArrayViewPtr arrayView);
template shared_ptr<Matrix<double>> LearnerBase::GetWritableMatrix<double>(const NDArrayViewPtr arrayView);
LearnerPtr SGDLearner(const unordered_set<Variable>& parameters, const DeviceDescriptor& device)
{
return MakeSharedObject<LearnerSGD>(parameters, device);
}
LearnerPtr MomentumSGDLearner(const unordered_set<Variable>& parameters, const DeviceDescriptor& device)
{
return MakeSharedObject<LearnerMomentumSGD>(parameters, device);
}
LearnerPtr NesterovLearner(const unordered_set<Variable>& parameters, const DeviceDescriptor& device)
{
return MakeSharedObject<LearnerNesterov>(parameters, device);
}
LearnerPtr AdaGradLearner(const unordered_set<Variable>& parameters, bool needAveMultiplier, const DeviceDescriptor& device)
{
return MakeSharedObject<LearnerAdaGrad>(parameters, needAveMultiplier, device);
}
LearnerPtr FSAdaGradLearner(const unordered_set<Variable>& parameters, const DeviceDescriptor& device)
{
return MakeSharedObject<LearnerFSAdaGrad>(parameters, device);
}
LearnerPtr RMSPropLearner(const unordered_set<Variable>& parameters,
double gamma, double inc, double dec, double max, double min, bool needAveMultiplier,
const DeviceDescriptor& device)
{
return MakeSharedObject<LearnerRMSProp>(parameters, gamma, inc, dec, max, min, needAveMultiplier, device);
}
}

View file

@@ -0,0 +1,224 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
#include "stdafx.h"
#include "CNTKLibrary.h"
namespace CNTK
{
// A collection of additional options that are applicable for all standard learners
// (after these options are set, they retain their value for the entire lifespan of a learner).
struct AdditionalLearningOptions
{
double l1RegularizationWeight = 0.0;
double l2RegularizationWeight = 0.0;
double gaussianNoiseInjectionStdDev = 0.0;
bool gradientClippingWithTruncation = false;
double gradientClippingThresholdPerSample = std::numeric_limits<double>::infinity(); // infinity disables clipping (ClipGradient tests against infinity)
std::unordered_map<Variable, double> learningRateMultipliers;
};
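A fill-in sketch with arbitrary values (applied through LearnerBase::SetAdditionalOptions below; note that the LearnerBase constructor pre-populates learningRateMultipliers per parameter):

    #include "Learner.h"

    void OptionsSketch()
    {
        CNTK::AdditionalLearningOptions options;
        options.l2RegularizationWeight = 1e-4;            // scaled by the minibatch size in PreProcess
        options.gaussianNoiseInjectionStdDev = 0.0;       // > 0 enables noise injection in PostProcess
        options.gradientClippingThresholdPerSample = 5.0; // finite value enables ClipGradient
        options.gradientClippingWithTruncation = true;    // elementwise clamp instead of norm rescaling
    }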
// An abstract base class at the root of the standard learners hierarchy.
// It implements most of the learner functionality, except for the actual update function,
// and adds a few pre-/postprocessing methods (which are invoked before and after the update).
class LearnerBase : public Learner
{
public:
CNTK_API virtual bool Update(const std::unordered_map<Variable, ValuePtr>& parameterValues,
const std::unordered_map<Variable, const ValuePtr>& gradientValues,
size_t trainingSampleCount) override final;
CNTK_API virtual Dictionary GetCheckpointState() const override;
CNTK_API virtual void RestoreFromCheckpoint(const Dictionary& checkpoint) override;
CNTK_API void SetAdditionalOptions(const AdditionalLearningOptions& additionalOptions)
{
m_additionalOptions = additionalOptions;
}
// TODO: should this be called ResetMomentum?
// needed for BlockMomentumSGD to reset SGD momentum after aggregation.
CNTK_API void ResetSmoothedGradients();
// TODO: move learning rate and momentum scheduling and adjustment functionality
// inside the learner and drop these setters.
void SetLearningRate(double value) { m_learningRatePerSample = value; }
protected:
LearnerBase(const std::unordered_set<Variable>& parameters,
const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice());
virtual void Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const = 0;
double ParameterDependentLearningRate(const Variable& parameter) const
{
return m_learningRatePerSample * m_additionalOptions.learningRateMultipliers.at(parameter);
}
std::string LearnerType() const;
double m_learningRatePerSample;
AdditionalLearningOptions m_additionalOptions;
std::unordered_map<Variable, ValuePtr> m_smoothedGradientValues;
// The following four static protected methods expose private methods of NDArrayView class
// (which declares LearnerBase as friend class), so that they are available to subclasses.
template <typename ElementType>
static std::shared_ptr<const Microsoft::MSR::CNTK::Matrix<ElementType>> GetMatrix(const NDArrayViewPtr arrayView);
template <typename ElementType>
static std::shared_ptr<Microsoft::MSR::CNTK::Matrix<ElementType>> GetWritableMatrix(NDArrayViewPtr arrayView);
template <typename ElementType>
static const Microsoft::MSR::CNTK::TensorView<ElementType>* GetTensorView(const NDArrayViewPtr arrayView);
template <typename ElementType>
static Microsoft::MSR::CNTK::TensorView<ElementType>* GetWritableTensorView(NDArrayViewPtr arrayView);
template <typename ElementType>
void ClipGradient(Microsoft::MSR::CNTK::Matrix<ElementType>& gradient, size_t actualMBSize) const;
// Performs additional preprocessing before calling the update method
// (gradient clipping and L2 regularization depending on the additional learning parameters).
template <typename ElementType>
void PreProcess(const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t actualMBSize) const;
// Performs additional postprocessing after the update method has been executed
// (noise injection and L1 regularization specified by the additional learning parameters).
template <typename ElementType>
void PostProcess(const Variable& parameter, const ValuePtr& gradientValue,
const ValuePtr& parameterValue, size_t actualMBSize) const;
private:
// Templatized update function, it invokes preprocess and postprocess using the provided
// template parameter and also invokes virtual Update method implemented in one of the subclasses.
template <typename ElementType>
void Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const;
// TODO: make these functions friends of NDArrayView and move to Utils?
static bool HasNan(const ValuePtr& value, const char* name);
static void Print(const ValuePtr& value, const char* msg);
size_t m_sampleCount;
};
// Vanilla gradient descent optimization algorithm.
class LearnerSGD : public LearnerBase
{
public:
LearnerSGD(const std::unordered_set<Variable>& parameters,
const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice())
: LearnerBase(parameters, device),
m_momentumPerSample(0.0),
m_useNesterovAcceleration(false)
{
}
protected:
virtual void Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const override;
template <typename ElementType>
void Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const;
double m_momentumPerSample;
bool m_useNesterovAcceleration;
};
// SGD optimization with momentum.
class LearnerMomentumSGD : public LearnerSGD
{
public:
LearnerMomentumSGD(const std::unordered_set<Variable>& parameters,
const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice())
: LearnerSGD(parameters, device)
{
}
void SetMomentum(double value) { m_momentumPerSample = value; }
};
// Nesterov's accelerated gradient descent (NAG).
class LearnerNesterov : public LearnerSGD
{
public:
LearnerNesterov(const std::unordered_set<Variable>& parameters,
const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice())
: LearnerSGD(parameters, device)
{
m_useNesterovAcceleration = true;
}
};
class LearnerAdaGrad : public LearnerBase
{
public:
LearnerAdaGrad(const std::unordered_set<Variable>& parameters, bool needAveMultiplier,
const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice());
protected:
bool m_needAveMultiplier;
virtual void Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const override;
template <typename ElementType>
void Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const;
};
class LearnerFSAdaGrad : public LearnerMomentumSGD
{
public:
LearnerFSAdaGrad(const std::unordered_set<Variable>& parameters,
const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice());
protected:
virtual void Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const override;
template <typename ElementType>
void Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const;
};
class LearnerRMSProp : public LearnerBase
{
public:
LearnerRMSProp(const std::unordered_set<Variable>& parameters,
double gamma, double inc, double dec, double max, double min, bool needAveMultiplier,
const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice());
protected:
double m_gamma;
double m_inc;
double m_dec;
double m_max;
double m_min;
bool m_needAveMultiplier;
virtual void Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const override;
template <typename ElementType>
void Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const;
};
}

View file

@@ -338,8 +338,10 @@ namespace CNTK
template std::shared_ptr<const Matrix<float>> NDArrayView::GetMatrix(size_t rowColSplitPoint/* = AutoSelectRowColSplitPoint*/) const;
template std::shared_ptr<const Matrix<double>> NDArrayView::GetMatrix(size_t rowColSplitPoint/* = AutoSelectRowColSplitPoint*/) const;
template std::shared_ptr<Matrix<float>> NDArrayView::GetWritableMatrix<float>(size_t rowColSplitPoint/* = AutoSelectRowColSplitPoint*/);
template std::shared_ptr<Matrix<double>> NDArrayView::GetWritableMatrix<double>(size_t rowColSplitPoint/* = AutoSelectRowColSplitPoint*/);
template TensorView<float>* NDArrayView::GetWritableTensorView<float>();
template TensorView<double>* NDArrayView::GetWritableTensorView<double>();
template CNTK_API NDArrayView::NDArrayView(const NDShape& viewShape, const SparseIndexType* colStarts, const SparseIndexType* rowIndices, const float* nonZeroValues, size_t numNonZeroValues, const DeviceDescriptor& device, bool readOnly/* = false*/);
template CNTK_API NDArrayView::NDArrayView(const NDShape& viewShape, const SparseIndexType* colStarts, const SparseIndexType* rowIndices, const double* nonZeroValues, size_t numNonZeroValues, const DeviceDescriptor& device, bool readOnly/* = false*/);

View file

@@ -6,11 +6,138 @@
#include "stdafx.h" #include "stdafx.h"
#include "CNTKLibrary.h" #include "CNTKLibrary.h"
#include "Utils.h" #include "Utils.h"
#include "File.h"
using namespace std;
namespace CNTK
{
template <typename T>
void DictionaryValue::AllocateDataPtr(const T& value)
{
static_assert(is_same<T, NDShape>::value || is_same<T, vector<DictionaryValue>>::value, "AllocateDataPtr called with invalid type");
m_data.m_ptr = new T(value);
}
template <typename T>
void DictionaryValue::FreePtrAsType()
{
T* typedPtr = reinterpret_cast<T*>(m_data.m_ptr);
delete typedPtr;
m_data.m_ptr = nullptr;
}
void DictionaryValue::FreeDataPtr()
{
if (m_valueType == Type::NDShape)
FreePtrAsType<NDShape>();
else if (m_valueType == Type::Vector)
FreePtrAsType<vector<DictionaryValue>>();
}
Microsoft::MSR::CNTK::File& operator>>(Microsoft::MSR::CNTK::File& stream, DictionaryValue& us)
{
size_t version;
stream >> version;
stream >> us.m_valueType;
switch (us.ValueType())
{
case DictionaryValue::Type::Bool:
stream >> us.m_data.m_boolean;
break;
case DictionaryValue::Type::SizeT:
stream >> us.m_data.m_sizeT;
break;
case DictionaryValue::Type::Float:
stream >> us.m_data.m_float;
break;
case DictionaryValue::Type::Double:
stream >> us.m_data.m_double;
break;
case DictionaryValue::Type::NDShape:
{
size_t size;
stream >> size;
vector<size_t> dims(size);
for (auto i = 0; i < size; i++)
{
stream >> dims[i];
}
us.AllocateDataPtr(NDShape(dims));
break;
}
case DictionaryValue::Type::Vector:
{
size_t size;
stream >> size;
vector<DictionaryValue> values(size);
for (auto i = 0; i < size; i++)
{
stream >> values[i];
}
us.AllocateDataPtr(values);
break;
}
default:
NOT_IMPLEMENTED;
}
return stream;
}
Microsoft::MSR::CNTK::File& operator<<(Microsoft::MSR::CNTK::File& stream, const DictionaryValue& us)
{
stream << us.version;
stream << us.ValueType();
switch (us.ValueType())
{
case DictionaryValue::Type::Bool:
stream << us.m_data.m_boolean;
break;
case DictionaryValue::Type::SizeT:
stream << us.m_data.m_sizeT;
break;
case DictionaryValue::Type::Float:
stream << us.m_data.m_float;
break;
case DictionaryValue::Type::Double:
stream << us.m_data.m_double;
break;
case DictionaryValue::Type::NDShape:
{
NDShape* shapePtr = reinterpret_cast<NDShape*>(us.m_data.m_ptr);
auto size = shapePtr->NumAxes();
stream << size;
for (auto i = 0; i < size; i++)
{
stream << shapePtr->operator[](i);
}
break;
}
case DictionaryValue::Type::Vector:
{
vector<DictionaryValue>* vectorPtr =
reinterpret_cast<vector<DictionaryValue>*>(us.m_data.m_ptr);
auto size = vectorPtr->size();
stream << size;
for (auto i = 0; i < size; i++)
{
stream << vectorPtr->operator[](i);
}
break;
}
default:
NOT_IMPLEMENTED;
}
return stream;
}
Dictionary::Dictionary()
: m_dictionaryData(new unordered_map<wstring, DictionaryValue>)
{
}
@@ -22,7 +149,7 @@ namespace CNTK
Dictionary::Dictionary(Dictionary&& other)
: m_dictionaryData(nullptr)
{
*this = move(other);
}
Dictionary& Dictionary::operator=(Dictionary&& other)
@@ -51,4 +178,130 @@ namespace CNTK
{
return (m_dictionaryData->find(key) != m_dictionaryData->end());
}
Microsoft::MSR::CNTK::File& operator<<(Microsoft::MSR::CNTK::File& stream, const Dictionary& us)
{
stream << us.version;
stream << us.m_dictionaryData->size();
for (auto it = us.m_dictionaryData->begin(); it != us.m_dictionaryData->end(); ++it)
{
stream << it->first;
stream << it->second;
}
return stream;
}
Microsoft::MSR::CNTK::File& operator>>(Microsoft::MSR::CNTK::File& stream, Dictionary& us)
{
size_t version;
stream >> version;
size_t size;
stream >> size;
us.m_dictionaryData->reserve(size);
for (auto i = 0; i < size; i++)
{
wstring key;
stream >> key;
DictionaryValue value;
stream >> value;
us.m_dictionaryData->insert(make_pair(key, value));
}
return stream;
}
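A hypothetical round trip through these streaming operators (a sketch; the fileOptions* flags are assumed to come from CNTK's File.h and are not part of this commit):

    #include "CNTKLibrary.h"
    #include "File.h"

    void CheckpointRoundTripSketch()
    {
        using Microsoft::MSR::CNTK::File;
        using namespace Microsoft::MSR::CNTK; // assumed source of the fileOptions* flags

        CNTK::Dictionary original;
        original[L"epoch"] = CNTK::DictionaryValue(size_t(3));

        {
            File out(L"checkpoint.bin", fileOptionsWrite | fileOptionsBinary);
            out << original;
        } // destructor flushes and closes

        CNTK::Dictionary restored;
        File in(L"checkpoint.bin", fileOptionsRead | fileOptionsBinary);
        in >> restored;
    }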
template <typename T>
vector<DictionaryValue> SerializeToVector(const NDArrayViewPtr& viewPtr)
{
if (viewPtr->IsSparse())
{
LogicError("Sparse NDArrayView cannot be serialized into a vector.");
}
auto numElements = viewPtr->Shape().TotalSize();
vector<DictionaryValue> values(numElements);
NDArrayViewPtr cpuDataViewPtr = viewPtr;
if ((viewPtr->Device().Type() != DeviceKind::CPU))
{
cpuDataViewPtr = MakeSharedObject<NDArrayView>(viewPtr->GetDataType(), viewPtr->Shape(), DeviceDescriptor::CPUDevice());
cpuDataViewPtr->CopyFrom(*viewPtr);
}
const T* buffer = cpuDataViewPtr->DataBuffer<T>();
for (auto i = 0; i < numElements; ++i)
{
T v = buffer[i];
values[i] = DictionaryValue(v);
}
return values;
}
template <typename T>
void DeserializeFromVector(const NDArrayViewPtr& viewPtr, const vector<DictionaryValue>& values)
{
if (viewPtr->IsSparse())
{
LogicError("Sparse NDArrayView cannot be deserialized from a vector.");
}
auto numElements = viewPtr->Shape().TotalSize();
if (values.size() != numElements)
{
LogicError("Number of elements (%lu) in the deserialized representation does not match the expected value (%lu)",
values.size(), numElements);
}
NDArrayViewPtr cpuDataViewPtr = viewPtr;
if ((viewPtr->Device().Type() != DeviceKind::CPU))
{
cpuDataViewPtr = MakeSharedObject<NDArrayView>(viewPtr->GetDataType(), viewPtr->Shape(), DeviceDescriptor::CPUDevice());
}
T* buffer = cpuDataViewPtr->WritableDataBuffer<T>();
for (auto i = 0; i < numElements; ++i)
{
buffer[i] = values[i].GetValue<T>();
}
if ((viewPtr->Device().Type() != DeviceKind::CPU))
{
viewPtr->CopyFrom(*cpuDataViewPtr);
}
}
// TODO: we store the type info for every element in the vector, which is extremely redundant.
// Instead, it'd be nice to introduce some sort of DictionaryValueVector.
vector<DictionaryValue> SerializeToVector(const NDArrayViewPtr& viewPtr)
{
switch (viewPtr->GetDataType())
{
case DataType::Float:
return SerializeToVector<float>(viewPtr);
case DataType::Double:
return SerializeToVector<double>(viewPtr);
default:
LogicError("Unsupported DataType %s", DataTypeName(viewPtr->GetDataType()));
}
}
void DeserializeFromVector(const NDArrayViewPtr& viewPtr, const vector<DictionaryValue>& values)
{
switch (viewPtr->GetDataType())
{
case DataType::Float:
DeserializeFromVector<float>(viewPtr, values);
break;
case DataType::Double:
DeserializeFromVector<double>(viewPtr, values);
break;
default:
LogicError("Unsupported DataType %s", DataTypeName(viewPtr->GetDataType()));
}
}
template void DictionaryValue::AllocateDataPtr<NDShape>(const NDShape& value);
template void DictionaryValue::AllocateDataPtr<vector<DictionaryValue>>(const vector<DictionaryValue>& value);
}
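And a small dense round trip for the helpers above (illustrative; it reuses the NDArrayView constructors seen elsewhere in this commit):

    #include <vector>
    #include "CNTKLibrary.h"
    #include "Utils.h"

    void SerializeRoundTripSketch()
    {
        auto device = CNTK::DeviceDescriptor::CPUDevice();
        CNTK::NDShape shape(std::vector<size_t>{ 2, 3 });

        // Dense float view filled with a constant (scalar-fill constructor, as in LearnerBase).
        auto view = CNTK::MakeSharedObject<CNTK::NDArrayView>(0.5f, shape, device);

        std::vector<CNTK::DictionaryValue> blob = CNTK::SerializeToVector(view); // six Type::Float entries

        auto restored = CNTK::MakeSharedObject<CNTK::NDArrayView>(CNTK::DataType::Float, shape, device);
        CNTK::DeserializeFromVector(restored, blob);
    }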

View file

@@ -15,245 +15,6 @@ namespace CNTK
// Forward declarations
class Dictionary;
class DictionaryValue
{
public:
enum class Type : unsigned int
{
None,
Bool,
SizeT,
Double,
NDShape,
Vector
};
static const char* TypeName(Type type)
{
if (type == Type::None)
return "None";
else if (type == Type::Bool)
return "Bool";
else if (type == Type::SizeT)
return "SizeT";
else if (type == Type::Double)
return "Double";
else if (type == Type::NDShape)
return "NDShape";
else if (type == Type::Vector)
return "Vector";
else
LogicError("Unknown DictionaryValue::Type");
}
public:
DictionaryValue()
: m_valueType(Type::None)
{
}
DictionaryValue(bool value)
: m_valueType(GetValueType<bool>())
{
m_data.m_boolean = value;
}
DictionaryValue(size_t value)
: m_valueType(GetValueType<size_t>())
{
m_data.m_sizeT = value;
}
DictionaryValue(double value)
: m_valueType(GetValueType<double>())
{
m_data.m_double = value;
}
template <typename T>
DictionaryValue(const T& value)
: m_valueType(GetValueType<T>())
{
static_assert(std::is_same<T, NDShape>::value ||
std::is_same<T, std::vector<DictionaryValue>>::value,
"Unsupported ValueType");
AllocateDataPtr(value);
}
DictionaryValue(const DictionaryValue& other)
: m_valueType(Type::Bool)
{
// The m_valueType must have been set to a non-ptr type to prevent an attempt to interpret
// the underlying uninitialized value as a ptr and free it.
*this = other;
}
DictionaryValue& operator=(const DictionaryValue& other)
{
if (this != &other)
{
FreeDataPtr();
m_valueType = other.m_valueType;
m_data = other.m_data;
if (other.m_valueType == Type::NDShape)
AllocateDataPtr(other.GetValue<NDShape>());
else if (other.m_valueType == Type::Vector)
AllocateDataPtr(other.GetValue<std::vector<DictionaryValue>>());
}
return *this;
}
~DictionaryValue()
{
FreeDataPtr();
}
template <typename T, typename std::enable_if<std::is_same<T, bool>::value>::type* = nullptr>
const T& GetValue() const
{
VerifyType<T>();
return m_data.m_boolean;
}
template <typename T, typename std::enable_if<std::is_same<T, size_t>::value>::type* = nullptr>
const T& GetValue() const
{
VerifyType<T>();
return m_data.m_sizeT;
}
template <typename T, typename std::enable_if<std::is_same<T, double>::value>::type* = nullptr>
const T& GetValue() const
{
VerifyType<T>();
return m_data.m_double;
}
template <typename T, typename std::enable_if<std::is_same<T, NDShape>::value || std::is_same<T, std::vector<DictionaryValue>>::value>::type* = nullptr>
const T& GetValue() const
{
VerifyType<T>();
return *(reinterpret_cast<T*>(m_data.m_ptr));
}
bool HasValue() const
{
return m_valueType != Type::None;
}
Type ValueType() const
{
return m_valueType;
}
private:
template <typename T>
static Type GetValueType()
{
static_assert(std::is_same<T, bool>::value ||
std::is_same<T, size_t>::value ||
std::is_same<T, double>::value ||
std::is_same<T, NDShape>::value ||
std::is_same<T, std::vector<DictionaryValue>>::value ||
std::is_same<T, CNTK::Dictionary>::value,
"Unsupported ValueType");
if (std::is_same<T, bool>::value)
return Type::Bool;
else if (std::is_same<T, size_t>::value)
return Type::SizeT;
else if (std::is_same<T, double>::value)
return Type::Double;
else if (std::is_same<T, NDShape>::value)
return Type::NDShape;
else if (std::is_same<T, std::vector<DictionaryValue>>::value)
return Type::Vector;
}
template <typename T>
void VerifyType() const
{
if (GetValueType<T>() != m_valueType)
RuntimeError("Reading a DictionaryValue as the wrong type; Reading as type %s when actual type is %s", typeid(T).name(), DictionaryValue::TypeName(m_valueType));
}
template <typename T>
void AllocateDataPtr(const T& value)
{
static_assert(std::is_same<T, NDShape>::value || std::is_same<T, std::vector<DictionaryValue>>::value, "AllocateDataPtr called with invalid type");
m_data.m_ptr = new T(value);
}
template <typename T>
void FreePtrAsType()
{
T* typedPtr = reinterpret_cast<T*>(m_data.m_ptr);
delete typedPtr;
m_data.m_ptr = nullptr;
}
void FreeDataPtr()
{
if (m_valueType == Type::NDShape)
FreePtrAsType<NDShape>();
else if (m_valueType == Type::Vector)
FreePtrAsType<std::vector<DictionaryValue>>();
}
private:
Type m_valueType;
union ValueData
{
bool m_boolean;
size_t m_sizeT;
double m_double;
void* m_ptr;
} m_data;
};
class Dictionary
{
public:
Dictionary();
~Dictionary();
// Disallow copy construction and assignment
Dictionary(const Dictionary&) = delete; Dictionary& operator=(const Dictionary&) = delete;
Dictionary(Dictionary&& other);
Dictionary& operator=(Dictionary&& other);
DictionaryValue& operator[](const std::wstring& key)
{
return operator[](key.c_str());
}
DictionaryValue& operator[](const wchar_t* key);
DictionaryValue operator[](const std::wstring& key) const
{
return operator[](key.c_str());
}
DictionaryValue operator[](const wchar_t* key) const;
bool Contains(const std::wstring& key) const
{
return Contains(key.c_str());
}
bool Contains(const wchar_t* key) const;
private:
std::unordered_map<std::wstring, DictionaryValue>* m_dictionaryData;
};
// Helper to get the size of an element of the specified DataType
inline size_t ElementSize(DataType dataType)
{
@@ -363,4 +124,8 @@ namespace CNTK
{
return var.IsInput() && var.IsSparse();
}
std::vector<DictionaryValue> SerializeToVector(const NDArrayViewPtr& viewPtr);
void DeserializeFromVector(const NDArrayViewPtr& viewPtr, const std::vector<DictionaryValue>& values);
}