Add v2 Learners (standalone)

Parent: 271476466a
Commit: 1b0548fdde

Makefile | 2 ++

@@ -375,6 +375,8 @@ CNTKLIBRARY_SRC =\
 	$(SOURCEDIR)/CNTKv2LibraryDll/Utils.cpp \
 	$(SOURCEDIR)/CNTKv2LibraryDll/Value.cpp \
 	$(SOURCEDIR)/CNTKv2LibraryDll/Variable.cpp \
+	$(SOURCEDIR)/CNTKv2LibraryDll/Learner.cpp \

 CNTKLIBRARY_SRC+=$(CNTK_COMMON_SRC)
 CNTKLIBRARY_SRC+=$(COMPUTATION_NETWORK_LIB_SRC)

@@ -285,6 +285,7 @@ namespace CNTK
     class NDArrayView final : public std::enable_shared_from_this<NDArrayView>
     {
         friend class CompositeFunction;
+        friend class LearnerBase;

         template <typename T, typename ...CtorArgTypes>
         friend inline std::shared_ptr<T> MakeSharedObject(CtorArgTypes&& ...ctorArgs);
@@ -1396,4 +1397,342 @@ namespace CNTK
     /// of the computation graph which can be "Combine"d to create a single Function with 2 outputs; viz. CrossEntropy loss and ClassificationError output.
     ///
     CNTK_API FunctionPtr Combine(const std::initializer_list<FunctionPtr>& operands, const std::wstring& name = L"");
+
+    ///
+    /// A serializable value represents one of:
+    /// a) Boolean
+    /// b) Signed long integer
+    /// c) Single and double precision floating point values
+    /// d) NDShape
+    /// e) vector<DictionaryValue>
+    ///
+    /// TODO: we need to have native support for DictionaryValue<vector> and DictionaryValue<NDArrayView>.
+    class CNTK_API DictionaryValue final
+    {
+    public:
+        enum class Type : unsigned int
+        {
+            None,
+            Bool,
+            SizeT,
+            Float,
+            Double,
+            NDShape,
+            Vector
+        };
+
+        static const char* TypeName(Type type)
+        {
+            switch (type)
+            {
+            case Type::None:
+                return "None";
+            case Type::Bool:
+                return "Bool";
+            case Type::SizeT:
+                return "SizeT";
+            case Type::Float:
+                return "Float";
+            case Type::Double:
+                return "Double";
+            case Type::NDShape:
+                return "NDShape";
+            case Type::Vector:
+                return "Vector";
+            default:
+                LogicError("Unknown DictionaryValue::Type");
+            }
+        }
+
+    public:
+        DictionaryValue() : m_valueType(Type::None)
+        {
+        }
+
+        DictionaryValue(bool value) : m_valueType(GetValueType<bool>())
+        {
+            m_data.m_boolean = value;
+        }
+
+        DictionaryValue(size_t value) : m_valueType(GetValueType<size_t>())
+        {
+            m_data.m_sizeT = value;
+        }
+
+        DictionaryValue(float value) : m_valueType(GetValueType<float>())
+        {
+            m_data.m_float = value;
+        }
+
+        DictionaryValue(double value) : m_valueType(GetValueType<double>())
+        {
+            m_data.m_double = value;
+        }
+
+        template <typename T>
+        DictionaryValue(const T& value) : m_valueType(GetValueType<T>())
+        {
+            static_assert(std::is_same<T, NDShape>::value ||
+                          std::is_same<T, std::vector<DictionaryValue>>::value,
+                          "Unsupported ValueType");
+
+            AllocateDataPtr(value);
+        }
+
+        DictionaryValue(const DictionaryValue& other) : m_valueType(Type::Bool)
+        {
+            // The m_valueType must have been set to a non-ptr type to prevent an attempt to interpret
+            // the underlying uninitialized value as a ptr and free it.
+            *this = other;
+        }
+
+        DictionaryValue& operator=(const DictionaryValue& other)
+        {
+            if (this != &other)
+            {
+                FreeDataPtr();
+
+                m_valueType = other.m_valueType;
+                m_data = other.m_data;
+
+                if (other.m_valueType == Type::NDShape)
+                    AllocateDataPtr(other.GetValue<NDShape>());
+                else if (other.m_valueType == Type::Vector)
+                    AllocateDataPtr(other.GetValue<std::vector<DictionaryValue>>());
+            }
+
+            return *this;
+        }
+
+        ~DictionaryValue()
+        {
+            FreeDataPtr();
+        }
+
+        template <typename T, typename std::enable_if<std::is_same<T, bool>::value>::type* = nullptr>
+        const T& GetValue() const
+        {
+            VerifyType<T>();
+            return m_data.m_boolean;
+        }
+
+        template <typename T, typename std::enable_if<std::is_same<T, size_t>::value>::type* = nullptr>
+        const T& GetValue() const
+        {
+            VerifyType<T>();
+            return m_data.m_sizeT;
+        }
+
+        template <typename T, typename std::enable_if<std::is_same<T, float>::value>::type* = nullptr>
+        const T& GetValue() const
+        {
+            VerifyType<T>();
+            return m_data.m_float;
+        }
+
+        template <typename T, typename std::enable_if<std::is_same<T, double>::value>::type* = nullptr>
+        const T& GetValue() const
+        {
+            VerifyType<T>();
+            return m_data.m_double;
+        }
+
+        template <typename T, typename std::enable_if<std::is_same<T, NDShape>::value || std::is_same<T, std::vector<DictionaryValue>>::value>::type* = nullptr>
+        const T& GetValue() const
+        {
+            VerifyType<T>();
+            return *(reinterpret_cast<T*>(m_data.m_ptr));
+        }
+
+        bool HasValue() const
+        {
+            return m_valueType != Type::None;
+        }
+
+        Type ValueType() const
+        {
+            return m_valueType;
+        }
+
+        friend CNTK_API Microsoft::MSR::CNTK::File& operator>>(Microsoft::MSR::CNTK::File& stream, DictionaryValue& us);
+        friend CNTK_API Microsoft::MSR::CNTK::File& operator<<(Microsoft::MSR::CNTK::File& stream, const DictionaryValue& us);
+
+    private:
+        template <typename T>
+        static Type GetValueType()
+        {
+            static_assert(std::is_same<T, bool>::value ||
+                          std::is_same<T, size_t>::value ||
+                          std::is_same<T, float>::value ||
+                          std::is_same<T, double>::value ||
+                          std::is_same<T, NDShape>::value ||
+                          std::is_same<T, std::vector<DictionaryValue>>::value,
+                          "Unsupported ValueType");
+
+            if (std::is_same<T, bool>::value) return Type::Bool;
+            if (std::is_same<T, size_t>::value) return Type::SizeT;
+            if (std::is_same<T, float>::value) return Type::Float;
+            if (std::is_same<T, double>::value) return Type::Double;
+            if (std::is_same<T, NDShape>::value) return Type::NDShape;
+            if (std::is_same<T, std::vector<DictionaryValue>>::value) return Type::Vector;
+        }
+
+        template <typename T>
+        void VerifyType() const
+        {
+            if (GetValueType<T>() != m_valueType)
+                RuntimeError("Reading a DictionaryValue as the wrong type; Reading as type %s when actual type is %s", typeid(T).name(), DictionaryValue::TypeName(m_valueType));
+        }
+
+        template <typename T>
+        void AllocateDataPtr(const T& value);
+
+        template <typename T>
+        void FreePtrAsType();
+
+        void FreeDataPtr();
+
+        Type m_valueType;
+
+        union ValueData
+        {
+            bool m_boolean;
+            size_t m_sizeT;
+            float m_float;
+            double m_double;
+            void* m_ptr;
+        } m_data;
+
+        const size_t version = 1;
+    };
+
+    ///
+    /// A type denoting a dictionary (keyed by Unicode strings) of serializable values (dynamically typed).
+    ///
+    class CNTK_API Dictionary final
+    {
+    public:
+        Dictionary();
+        ~Dictionary();
+
+        // Disallow copy construction and assignment
+        Dictionary(const Dictionary&) = delete; Dictionary& operator=(const Dictionary&) = delete;
+
+        Dictionary(Dictionary&& other);
+        Dictionary& operator=(Dictionary&& other);
+
+        DictionaryValue& operator[](const std::wstring& key)
+        {
+            return operator[](key.c_str());
+        }
+
+        DictionaryValue& operator[](const wchar_t* key);
+
+        DictionaryValue operator[](const std::wstring& key) const
+        {
+            return operator[](key.c_str());
+        }
+
+        DictionaryValue operator[](const wchar_t* key) const;
+
+        bool Contains(const std::wstring& key) const
+        {
+            return Contains(key.c_str());
+        }
+
+        bool Contains(const wchar_t* key) const;
+
+        friend CNTK_API Microsoft::MSR::CNTK::File& operator>>(Microsoft::MSR::CNTK::File& stream, Dictionary& us);
+        friend CNTK_API Microsoft::MSR::CNTK::File& operator<<(Microsoft::MSR::CNTK::File& stream, const Dictionary& us);
+
+    private:
+        std::unordered_map<std::wstring, DictionaryValue>* m_dictionaryData;
+        const size_t version = 1;
+    };
+
+    ///
+    /// Abstraction for learning a subset of parameters of a learnable function using first order gradient values.
+    /// E.g., momentum, AdaGrad, RMSProp, etc. are different types of learners with their own algorithms for
+    /// learning parameter values using first order gradients.
+    ///
+    class Learner : public std::enable_shared_from_this<Learner>
+    {
+    public:
+        //
+        // Method to update the parameters associated with this learner. By returning false, this method indicates that
+        // learning has stopped for all of the parameters associated with this learner.
+        //
+        CNTK_API virtual bool Update(const std::unordered_map<Variable, ValuePtr>& parameterValues,
+                                     const std::unordered_map<Variable, const ValuePtr>& gradientValues,
+                                     size_t trainingSampleCount) = 0;
+
+        ///
+        /// Returns the set of parameters associated with this learner.
+        ///
+        const std::unordered_set<Variable>& Parameters() const { return m_parameters; }
+
+        // TODO: move the following two methods into ISerializable interface, make
+        // Learner (and all other entities that need checkpointing capability) implement it.
+        ///
+        /// Optionally overridable method to checkpoint the learner's state.
+        ///
+        CNTK_API virtual Dictionary GetCheckpointState() const = 0;
+
+        ///
+        /// Optionally overridable method to restore the learner's state from a previous checkpoint.
+        ///
+        CNTK_API virtual void RestoreFromCheckpoint(const Dictionary& checkpoint) = 0;
+
+        virtual ~Learner()
+        {
+        }
+
+    protected:
+        Learner(const std::unordered_set<Variable>& parameters)
+            : m_parameters(parameters)
+        {
+        }
+
+        std::unordered_set<Variable> m_parameters;
+    };
+
+    ///
+    /// Create an instance of the CNTK built-in SGD learner.
+    ///
+    /// TODO: add additional SGD parameters here (a collection of learning rate values)
+    CNTK_API LearnerPtr SGDLearner(const std::unordered_set<Variable>& parameters,
+                                   const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice());
+
+    ///
+    /// Create an instance of the CNTK built-in Momentum SGD learner.
+    ///
+    /// TODO: add additional Momentum parameters here (a collection of momentum rate values)
+    CNTK_API LearnerPtr MomentumSGDLearner(const std::unordered_set<Variable>& parameters,
+                                           const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice());
+
+    ///
+    /// Create an instance of the CNTK built-in Nesterov's accelerated SGD learner.
+    ///
+    CNTK_API LearnerPtr NesterovLearner(const std::unordered_set<Variable>& parameters,
+                                        const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice());
+
+    ///
+    /// Create an instance of the CNTK built-in AdaGrad learner.
+    ///
+    CNTK_API LearnerPtr AdaGradLearner(const std::unordered_set<Variable>& parameters, bool needAveMultiplier = true,
+                                       const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice());
+
+    ///
+    /// Create an instance of the CNTK built-in FSAdaGrad (improved AdaGrad) learner.
+    ///
+    CNTK_API LearnerPtr FSAdaGradLearner(const std::unordered_set<Variable>& parameters,
+                                         const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice());
+
+    ///
+    /// Create an instance of the CNTK built-in RMSProp learner.
+    ///
+    CNTK_API LearnerPtr RMSPropLearner(const std::unordered_set<Variable>& parameters,
+                                       double gamma, double inc, double dec, double max, double min, bool needAveMultiplier = true,
+                                       const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice());
 }

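For illustration, a minimal usage sketch of the DictionaryValue/Dictionary API declared above. This is not part of the commit; it assumes the surrounding v2 library headers are available, and the key names are hypothetical:

#include "CNTKLibrary.h"

void DictionaryUsageSketch()
{
    CNTK::Dictionary config;
    config[L"epochs"] = CNTK::DictionaryValue(size_t(10)); // stored as Type::SizeT
    config[L"learningRate"] = CNTK::DictionaryValue(0.01); // stored as Type::Double

    if (config.Contains(L"epochs"))
    {
        // GetValue<T>() verifies the runtime type tag and raises a RuntimeError on mismatch.
        size_t epochs = config[L"epochs"].GetValue<size_t>();
        (void)epochs;
    }
}
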
@@ -47,6 +47,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
 template <typename ElementType>
 class ComputationNode;
+
+class File;
 }}}

 // TODO: The following should be reconciled with the equivalent code in the CNTK implementation

@@ -158,4 +160,7 @@ namespace CNTK
     class Function;
    typedef std::shared_ptr<Function> FunctionPtr;
+
+    class Learner;
+    typedef std::shared_ptr<Learner> LearnerPtr;
 }

@@ -128,6 +128,7 @@
     <ClInclude Include="API\CNTKLibrary.h" />
     <ClInclude Include="API\CNTKLibraryInternals.h" />
     <ClInclude Include="Function.h" />
+    <ClInclude Include="Learner.h" />
     <ClInclude Include="Utils.h" />
     <ClInclude Include="stdafx.h" />
     <ClInclude Include="targetver.h" />

@@ -140,6 +141,7 @@
     </PrecompiledHeader>
     </ClCompile>
     <ClCompile Include="Function.cpp" />
+    <ClCompile Include="Learner.cpp" />
     <ClCompile Include="NDArrayView.cpp" />
     <ClCompile Include="NDMask.cpp" />
     <ClCompile Include="stdafx.cpp">

@@ -10,6 +10,7 @@
     <ClCompile Include="Variable.cpp" />
     <ClCompile Include="Utils.cpp" />
     <ClCompile Include="NDMask.cpp" />
+    <ClCompile Include="Learner.cpp" />
   </ItemGroup>
   <ItemGroup>
     <ClInclude Include="stdafx.h" />

@@ -22,6 +23,7 @@
       <Filter>API</Filter>
     </ClInclude>
     <ClInclude Include="Function.h" />
+    <ClInclude Include="Learner.h" />
   </ItemGroup>
   <ItemGroup>
     <Filter Include="API">

@@ -0,0 +1,464 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
+//
+
+#include "Learner.h"
+#include "TensorView.h"
+#include "Utils.h"
+
+#define UPDATE_FUNCTION \
+    switch (smoothedGradientValue->Data()->GetDataType()) \
+    { \
+    case DataType::Float: \
+        Update<float>(parameter, smoothedGradientValue, gradientValue, parameterValue, trainingSampleCount); \
+        break; \
+    case DataType::Double: \
+        Update<double>(parameter, smoothedGradientValue, gradientValue, parameterValue, trainingSampleCount); \
+        break; \
+    default: \
+        NOT_IMPLEMENTED; \
+    }
+
+using namespace Microsoft::MSR::CNTK;
+using namespace std;
+
+namespace CNTK
+{
+    template <typename ElementType>
+    /*static*/ shared_ptr<const Matrix<ElementType>> LearnerBase::GetMatrix(const NDArrayViewPtr arrayView)
+    {
+        return arrayView->GetMatrix<ElementType>();
+    }
+
+    template <typename ElementType>
+    /*static*/ shared_ptr<Matrix<ElementType>> LearnerBase::GetWritableMatrix(NDArrayViewPtr arrayView)
+    {
+        return arrayView->GetWritableMatrix<ElementType>();
+    }
+
+    template <typename ElementType>
+    /*static*/ const TensorView<ElementType>* LearnerBase::GetTensorView(const NDArrayViewPtr arrayView)
+    {
+        return arrayView->GetTensorView<ElementType>();
+    }
+
+    /*static*/ bool LearnerBase::HasNan(const ValuePtr& value, const char* name)
+    {
+        const auto& data = value->Data();
+        switch (data->GetDataType())
+        {
+        case DataType::Float:
+            return data->GetMatrix<float>()->HasNan(name);
+        case DataType::Double:
+            return data->GetMatrix<double>()->HasNan(name);
+        default:
+            LogicError("Unsupported DataType %s", DataTypeName(data->GetDataType()));
+        }
+    }
+
+    /*static*/ void LearnerBase::Print(const ValuePtr& value, const char* msg)
+    {
+        const auto& data = value->Data();
+        switch (data->GetDataType())
+        {
+        case DataType::Float:
+            data->GetMatrix<float>()->Print(msg);
+            break;
+        case DataType::Double:
+            data->GetMatrix<double>()->Print(msg);
+            break;
+        default:
+            LogicError("Unsupported DataType %s", DataTypeName(data->GetDataType()));
+        }
+    }
+
+    // Clipping gradients to prevent outliers.
+    template <typename ElementType>
+    void LearnerBase::ClipGradient(Matrix<ElementType>& gradient, size_t actualMBSize) const
+    {
+        if (m_additionalOptions.gradientClippingThresholdPerSample != numeric_limits<double>::infinity())
+        {
+            double maxGradientPerMB = m_additionalOptions.gradientClippingThresholdPerSample * actualMBSize;
+            if (m_additionalOptions.gradientClippingWithTruncation)
+                gradient.InplaceTruncate(ElementType(maxGradientPerMB));
+            else
+            {
+                // norm2 normalized
+                double gradientNorm = gradient.FrobeniusNorm();
+                if (gradientNorm > maxGradientPerMB)
+                {
+                    double normFactor = maxGradientPerMB / gradientNorm;
+                    gradient *= ElementType(normFactor);
+                }
+            }
+        }
+    }
+
+    // Performs additional preprocessing before calling the update method
+    // (gradient clipping and L2 regularization depending on the additional learning parameters).
+    template <typename ElementType>
+    void LearnerBase::PreProcess(const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t actualMBSize) const
+    {
+        const auto& gradientMatrix = gradientValue->Data()->GetWritableMatrix<ElementType>();
+
+        // clipping gradients to prevent outliers
+        ClipGradient<ElementType>(*gradientMatrix, actualMBSize);
+
+        // L2 regularizer
+        if (m_additionalOptions.l2RegularizationWeight > 0)
+        {
+            // multiply by actualMBSize so that it's invariant to minibatch size since learning rate is per sample
+            auto weight = ElementType(m_additionalOptions.l2RegularizationWeight * actualMBSize);
+            const auto& parameterMatrix = parameterValue->Data()->GetWritableMatrix<ElementType>();
+            Matrix<ElementType>::ScaleAndAdd(weight, *parameterMatrix, *gradientMatrix);
+        }
+    }
+
+    // Performs additional postprocessing after the update method has been executed
+    // (noise injection and L1 regularization specified by the additional learning parameters).
+    template <typename ElementType>
+    void LearnerBase::PostProcess(const Variable& parameter, const ValuePtr& gradientValue,
+                                  const ValuePtr& parameterValue, size_t actualMBSize) const
+    {
+        const auto& parameterMatrix = parameterValue->Data()->GetWritableMatrix<ElementType>();
+        if (m_additionalOptions.gaussianNoiseInjectionStdDev > 0)
+        {
+            const auto& gradientMatrix = gradientValue->Data()->GetWritableMatrix<ElementType>();
+
+            Matrix<ElementType> sgdUpdateNoise((DEVICEID_TYPE)parameterMatrix->GetDeviceId());
+
+            // get the gradient structure since gradient is sparse
+            sgdUpdateNoise.SetValue(*gradientMatrix);
+
+            auto noiseStdDev = ElementType(m_additionalOptions.gaussianNoiseInjectionStdDev);
+
+            // reset its value to random
+            sgdUpdateNoise.SetGaussianRandomValue(ElementType(0.0), noiseStdDev);
+
+            Matrix<ElementType>::ScaleAndAdd(ElementType(1.0), sgdUpdateNoise, *parameterMatrix);
+        }
+
+        // L1 regularizer with proximal gradient descent method
+        if (m_additionalOptions.l1RegularizationWeight > 0)
+        {
+            auto learningRate = ElementType(ParameterDependentLearningRate(parameter));
+            // multiply by actualMBSize so that it's invariant to minibatch size since learning rate is per sample
+            auto weight = ElementType(learningRate * m_additionalOptions.l1RegularizationWeight * actualMBSize);
+            parameterValue->Data()->GetWritableMatrix<ElementType>()->InplaceSoftThreshold(weight);
+        }
+    }
+
+    template <typename ElementType>
+    /*static*/ TensorView<ElementType>* LearnerBase::GetWritableTensorView(NDArrayViewPtr arrayView)
+    {
+        return arrayView->GetWritableTensorView<ElementType>();
+    }
+
+    LearnerBase::LearnerBase(const unordered_set<Variable>& parameters, const DeviceDescriptor& device)
+        : Learner(parameters),
+        m_learningRatePerSample(0.0),
+        m_sampleCount(0)
+    {
+        const unordered_set<Variable>& parameterSet = parameters;
+        for (const auto& parameter : parameterSet)
+        {
+            // TODO: using the same device to allocate data for all smoothed gradients. Is this correct?
+            // Should the device be specified on the per-parameter basis?
+            NDArrayViewPtr view;
+            if (parameter.GetDataType() == DataType::Float)
+            {
+                view = MakeSharedObject<NDArrayView>(0.0f, parameter.Shape(), device);
+            }
+            else
+            {
+                view = MakeSharedObject<NDArrayView>(0.0, parameter.Shape(), device);
+            }
+
+            m_smoothedGradientValues.insert(make_pair(parameter, MakeSharedObject<Value>(view)));
+            m_additionalOptions.learningRateMultipliers.insert(make_pair(parameter, 1.0));
+        }
+    }
+
+    void LearnerBase::ResetSmoothedGradients()
+    {
+        for (const auto& parameter : Parameters())
+        {
+            const auto& smoothedGradientValue = m_smoothedGradientValues.at(parameter);
+            const auto& data = smoothedGradientValue->Data();
+            switch (data->GetDataType())
+            {
+            case DataType::Float:
+                data->SetValue(0.0f);
+                break;
+            case DataType::Double:
+                data->SetValue(0.0);
+                break;
+            default:
+                LogicError("Unsupported DataType %s", ::CNTK::DataTypeName(data->GetDataType()));
+            }
+        }
+    }
+
+    /*virtual*/ bool LearnerBase::Update(const unordered_map<Variable, ValuePtr>& parameterValues,
+                                         const unordered_map<Variable, const ValuePtr>& gradientValues,
+                                         size_t trainingSampleCount) /*override*/
+    {
+        // make sure trainingSampleCount is a valid value
+        assert(trainingSampleCount > 0);
+
+        for (const auto& parameter : Parameters())
+        {
+            const auto& smoothedGradientValue = m_smoothedGradientValues.at(parameter);
+            const auto& gradientValue = gradientValues.at(parameter);
+            const auto& parameterValue = parameterValues.at(parameter);
+
+            // TODO: make this a runtime parameter.
+#if DUMPOUTPUT
+            LOGPRINTF(stderr, "Update_%ls\n", parameter.Name().c_str());
+#endif
+
+#ifdef _DEBUG
+            if (HasNan(smoothedGradientValue, "TrainOneEpoch/UpdateWeights/Learner::Update(): "))
+                LogicError("%ls has NaNs in smoothedGradient.", parameter.Name().c_str());
+#endif
+
+#if DUMPOUTPUT
+            LOGPRINTF(stderr, "learnRatePerSample=%0.8f, momentum=%0.8f, actualMBSize=%ld\n",
+                      m_learningRatePerSample, m_momentumPerSample, trainingSampleCount);
+            LOGPRINTF(stderr, "GradUpdateType()=%s, GradientUpdateNoiseStd()=%0.8f\n",
+                      LearnerType().c_str(), m_GaussianNoiseInjectStd);
+            Print(gradientValue, "Gradient Update");
+            Print(smoothedGradientValue, "Smoothed Gradient Input");
+#endif
+            UPDATE_FUNCTION;
+
+#if DUMPOUTPUT
+            Print(parameterValue, "Parameter Update");
+#endif
+
+#ifdef _DEBUG
+            if (HasNan(parameterValue, "TrainOneEpoch/UpdateWeights/Learner::Update(): "))
+                LogicError("%ls has NaNs in parameter values after parameter update.", parameter.Name().c_str());
+#endif
+        }
+        m_sampleCount += trainingSampleCount;
+        return false;
+    }
+
+    template <typename ElementType>
+    void LearnerBase::Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
+                             const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const
+    {
+        PreProcess<ElementType>(gradientValue, parameterValue, trainingSampleCount);
+        Update(parameter, smoothedGradientValue, gradientValue, parameterValue, trainingSampleCount);
+        PostProcess<ElementType>(parameter, gradientValue, parameterValue, trainingSampleCount);
+    }
+
+    string LearnerBase::LearnerType() const
+    {
+        auto name = typeid(*this).name();
+        if (strncmp(name, "class ", 6) == 0)
+        {
+            // On Windows, the type name contains "class" prefix.
+            // Return the actual name, omitting the prefix.
+            return &name[6];
+        }
+        return name;
+    }
+
+    /*virtual*/ Dictionary LearnerBase::GetCheckpointState() const /*override*/
+    {
+        NOT_IMPLEMENTED; // Until the new checkpointing is fully fleshed out, nobody should be calling this.
+        Dictionary checkpoint;
+
+        for (const auto& parameter : Parameters())
+        {
+            // TODO: parameter name is not guaranteed to be unique. Instead, all serializable objects
+            // need to expose "UId" property -- a persistent unique internal name.
+            // Switch to UId as soon as it's available.
+            if (checkpoint.Contains(parameter.Name()))
+            {
+                LogicError("Parameter names must be unique");
+            }
+            const auto& smoothedGradientValue = m_smoothedGradientValues.at(parameter);
+
+            // Potentially, could store things like dimensions, element size, format, etc., but
+            // that seems to be redundant, since all of that is passed in the constructor.
+            checkpoint[parameter.Name()] = SerializeToVector(smoothedGradientValue->Data());
+        }
+        return checkpoint;
+    }
+
+    /*virtual*/ void LearnerBase::RestoreFromCheckpoint(const Dictionary& checkpoint) /*override*/
+    {
+        NOT_IMPLEMENTED; // Until the new checkpointing is fully fleshed out, nobody should be calling this.
+        for (const auto& parameter : Parameters())
+        {
+            if (!checkpoint.Contains(parameter.Name()))
+            {
+                LogicError("Checkpoint does not contain state for parameter %ls", parameter.Name().c_str());
+            }
+            const auto& smoothedGradientValue = m_smoothedGradientValues.at(parameter);
+
+            const DictionaryValue& state = checkpoint[parameter.Name()];
+
+            const auto& data = smoothedGradientValue->Data();
+
+            DeserializeFromVector(data, state.GetValue<vector<DictionaryValue>>());
+        }
+    }
+
+    /*virtual*/ void LearnerSGD::Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
+                                        const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const /*override*/
+    {
+        UPDATE_FUNCTION;
+    }
+
+    template <typename ElementType>
+    void LearnerSGD::Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
+                            const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const
+    {
+        UNUSED(trainingSampleCount);
+
+        const auto& smoothedGradientMatrix = GetWritableMatrix<ElementType>(smoothedGradientValue->Data());
+        const auto& gradientMatrix = GetWritableMatrix<ElementType>(gradientValue->Data());
+        const auto& parameterMatrix = GetWritableMatrix<ElementType>(parameterValue->Data());
+
+        const auto& learningRate = ElementType(ParameterDependentLearningRate(parameter));
+
+        // TODO: break up the NormalGrad into 3 different functions, each with its own set of parameters
+        // (one for vanilla SGD, the other for momentum SGD, and the third one for NAG).
+        smoothedGradientMatrix->NormalGrad(*gradientMatrix, *parameterMatrix,
+                                           learningRate, ElementType(m_momentumPerSample), m_useNesterovAcceleration);
+    }
+
+    LearnerAdaGrad::LearnerAdaGrad(const unordered_set<Variable>& parameters, bool needAveMultiplier, const DeviceDescriptor& device)
+        : LearnerBase(parameters, device),
+        m_needAveMultiplier(needAveMultiplier)
+    {
+    }
+
+    /*virtual*/ void LearnerAdaGrad::Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
+                                            const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const /*override*/
+    {
+        UPDATE_FUNCTION;
+    }
+
+    template <typename ElementType>
+    void LearnerAdaGrad::Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
+                                const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const
+    {
+        UNUSED(trainingSampleCount);
+
+        const auto& smoothedGradientMatrix = GetWritableMatrix<ElementType>(smoothedGradientValue->Data());
+        const auto& gradientMatrix = GetWritableMatrix<ElementType>(gradientValue->Data());
+        const auto& parameterMatrix = GetWritableMatrix<ElementType>(parameterValue->Data());
+
+        auto learningRate = ElementType(ParameterDependentLearningRate(parameter));
+
+        auto aveMultiplier = smoothedGradientMatrix->Adagrad(*gradientMatrix, m_needAveMultiplier);
+        Matrix<ElementType>::ScaleAndAdd(ElementType(-learningRate / aveMultiplier), *gradientMatrix, *parameterMatrix);
+    }
+
+    LearnerFSAdaGrad::LearnerFSAdaGrad(const unordered_set<Variable>& parameters, const DeviceDescriptor& device)
+        : LearnerMomentumSGD(parameters, device)
+    {
+    }
+
+    /*virtual*/ void LearnerFSAdaGrad::Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
+                                              const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const /*override*/
+    {
+        UPDATE_FUNCTION;
+    }
+
+    template <typename ElementType>
+    void LearnerFSAdaGrad::Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
+                                  const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const
+    {
+        UNUSED(trainingSampleCount);
+
+        const auto& smoothedGradientMatrix = GetWritableMatrix<ElementType>(smoothedGradientValue->Data());
+        const auto& gradientMatrix = GetWritableMatrix<ElementType>(gradientValue->Data());
+        const auto& parameterMatrix = GetWritableMatrix<ElementType>(parameterValue->Data());
+
+        //const double momentum = MomentumPerMB(m_momentumPerSample, trainingSampleCount);
+
+        auto learningRate = ElementType(ParameterDependentLearningRate(parameter));
+
+        smoothedGradientMatrix->FSAdagrad(trainingSampleCount, *gradientMatrix, *parameterMatrix,
+                                          learningRate, ElementType(m_momentumPerSample));
+    }
+
+    LearnerRMSProp::LearnerRMSProp(const unordered_set<Variable>& parameters,
+                                   double gamma, double inc, double dec, double max, double min,
+                                   bool needAveMultiplier, const DeviceDescriptor& device)
+        : LearnerBase(parameters, device),
+        m_gamma(gamma), m_inc(inc), m_dec(dec), m_max(max), m_min(min),
+        m_needAveMultiplier(needAveMultiplier)
+    {
+    }
+
+    /*virtual*/ void LearnerRMSProp::Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
+                                            const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const /*override*/
+    {
+        UPDATE_FUNCTION;
+    }
+
+    template <typename ElementType>
+    void LearnerRMSProp::Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
+                                const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const
+    {
+        UNUSED(trainingSampleCount);
+
+        const auto& smoothedGradientMatrix = GetWritableMatrix<ElementType>(smoothedGradientValue->Data());
+        const auto& gradientMatrix = GetWritableMatrix<ElementType>(gradientValue->Data());
+        const auto& parameterMatrix = GetWritableMatrix<ElementType>(parameterValue->Data());
+
+        auto learningRate = ElementType(ParameterDependentLearningRate(parameter));
+
+        auto aveMultiplier = smoothedGradientMatrix->RmsProp(*gradientMatrix,
+                                                             ElementType(m_gamma), ElementType(m_inc),
+                                                             ElementType(m_max), ElementType(m_dec),
+                                                             ElementType(m_min), m_needAveMultiplier);
+        Matrix<ElementType>::ScaleAndAdd(ElementType(-learningRate / aveMultiplier), *gradientMatrix, *parameterMatrix);
+    }
+
+    // Explicit template instantiations
+    template shared_ptr<Matrix<float>> LearnerBase::GetWritableMatrix<float>(const NDArrayViewPtr arrayView);
+    template shared_ptr<Matrix<double>> LearnerBase::GetWritableMatrix<double>(const NDArrayViewPtr arrayView);
+
+    LearnerPtr SGDLearner(const unordered_set<Variable>& parameters, const DeviceDescriptor& device)
+    {
+        return MakeSharedObject<LearnerSGD>(parameters, device);
+    }
+
+    LearnerPtr MomentumSGDLearner(const unordered_set<Variable>& parameters, const DeviceDescriptor& device)
+    {
+        return MakeSharedObject<LearnerMomentumSGD>(parameters, device);
+    }
+
+    LearnerPtr NesterovLearner(const unordered_set<Variable>& parameters, const DeviceDescriptor& device)
+    {
+        return MakeSharedObject<LearnerNesterov>(parameters, device);
+    }
+
+    LearnerPtr AdaGradLearner(const unordered_set<Variable>& parameters, bool needAveMultiplier, const DeviceDescriptor& device)
+    {
+        return MakeSharedObject<LearnerAdaGrad>(parameters, needAveMultiplier, device);
+    }
+
+    LearnerPtr FSAdaGradLearner(const unordered_set<Variable>& parameters, const DeviceDescriptor& device)
+    {
+        return MakeSharedObject<LearnerFSAdaGrad>(parameters, device);
+    }
+
+    LearnerPtr RMSPropLearner(const unordered_set<Variable>& parameters,
+                              double gamma, double inc, double dec, double max, double min, bool needAveMultiplier,
+                              const DeviceDescriptor& device)
+    {
+        return MakeSharedObject<LearnerRMSProp>(parameters, gamma, inc, dec, max, min, needAveMultiplier, device);
+    }
+}

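The UPDATE_FUNCTION macro above implements a simple runtime-to-compile-time type dispatch: the DataType tag stored with the data selects the statically typed Update<ElementType> overload. A self-contained sketch of the same pattern, with illustrative names only:

#include <cstdio>

enum class DataType { Float, Double };

template <typename ElementType>
void UpdateImpl()
{
    // In the learner code this would operate on Matrix<ElementType>.
    std::printf("update with %zu-byte elements\n", sizeof(ElementType));
}

void Update(DataType type)
{
    switch (type)
    {
    case DataType::Float:  UpdateImpl<float>();  break;
    case DataType::Double: UpdateImpl<double>(); break;
    }
}

int main()
{
    Update(DataType::Float);  // dispatches to UpdateImpl<float>
    Update(DataType::Double); // dispatches to UpdateImpl<double>
}
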
@@ -0,0 +1,224 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
+//
+
+#include "stdafx.h"
+#include "CNTKLibrary.h"
+
+namespace CNTK
+{
+    // A collection of additional options that are applicable for all standard learners
+    // (after these options are set, they retain their value for the entire lifespan of a learner).
+    struct AdditionalLearningOptions
+    {
+        double l1RegularizationWeight = 0.0;
+        double l2RegularizationWeight = 0.0;
+        double gaussianNoiseInjectionStdDev = 0.0;
+        bool gradientClippingWithTruncation = false;
+        double gradientClippingThresholdPerSample = 0.0;
+        std::unordered_map<Variable, double> learningRateMultipliers;
+    };
+
+    // An abstract base class at the root of the standard learners hierarchy.
+    // It implements most of the learner functionality, except for the actual update function,
+    // and adds a few pre-/postprocessing methods (which are invoked before and after the update).
+    class LearnerBase : public Learner
+    {
+    public:
+
+        CNTK_API virtual bool Update(const std::unordered_map<Variable, ValuePtr>& parameterValues,
+                                     const std::unordered_map<Variable, const ValuePtr>& gradientValues,
+                                     size_t trainingSampleCount) override final;
+
+        CNTK_API virtual Dictionary GetCheckpointState() const override;
+
+        CNTK_API virtual void RestoreFromCheckpoint(const Dictionary& checkpoint) override;
+
+        CNTK_API void SetAdditionalOptions(const AdditionalLearningOptions& additionalOptions)
+        {
+            m_additionalOptions = additionalOptions;
+        }
+
+        // TODO: should this be called ResetMomentum?
+        // Needed for BlockMomentumSGD to reset SGD momentum after aggregation.
+        CNTK_API void ResetSmoothedGradients();
+
+        // TODO: move learning rate and momentum scheduling and adjustment functionality
+        // inside the learner and drop these setters.
+        void SetLearningRate(double value) { m_learningRatePerSample = value; }
+
+    protected:
+        LearnerBase(const std::unordered_set<Variable>& parameters,
+                    const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice());
+
+        virtual void Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
+                            const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const = 0;
+
+        double ParameterDependentLearningRate(const Variable& parameter) const
+        {
+            return m_learningRatePerSample * m_additionalOptions.learningRateMultipliers.at(parameter);
+        }
+
+        std::string LearnerType() const;
+
+        double m_learningRatePerSample;
+
+        AdditionalLearningOptions m_additionalOptions;
+
+        std::unordered_map<Variable, ValuePtr> m_smoothedGradientValues;
+
+        // The following four static protected methods expose private methods of NDArrayView class
+        // (which declares LearnerBase as friend class), so that they are available to subclasses.
+        template <typename ElementType>
+        static std::shared_ptr<const Microsoft::MSR::CNTK::Matrix<ElementType>> GetMatrix(const NDArrayViewPtr arrayView);
+
+        template <typename ElementType>
+        static std::shared_ptr<Microsoft::MSR::CNTK::Matrix<ElementType>> GetWritableMatrix(NDArrayViewPtr arrayView);
+
+        template <typename ElementType>
+        static const Microsoft::MSR::CNTK::TensorView<ElementType>* GetTensorView(const NDArrayViewPtr arrayView);
+
+        template <typename ElementType>
+        static Microsoft::MSR::CNTK::TensorView<ElementType>* GetWritableTensorView(NDArrayViewPtr arrayView);
+
+        template <typename ElementType>
+        void ClipGradient(Microsoft::MSR::CNTK::Matrix<ElementType>& gradient, size_t actualMBSize) const;
+
+        // Performs additional preprocessing before calling the update method
+        // (gradient clipping and L2 regularization depending on the additional learning parameters).
+        template <typename ElementType>
+        void PreProcess(const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t actualMBSize) const;
+
+        // Performs additional postprocessing after the update method has been executed
+        // (noise injection and L1 regularization specified by the additional learning parameters).
+        template <typename ElementType>
+        void PostProcess(const Variable& parameter, const ValuePtr& gradientValue,
+                         const ValuePtr& parameterValue, size_t actualMBSize) const;
+    private:
+        // Templatized update function; it invokes preprocess and postprocess using the provided
+        // template parameter and also invokes the virtual Update method implemented in one of the subclasses.
+        template <typename ElementType>
+        void Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
+                    const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const;
+
+        // TODO: make these functions friends of NDArrayView and move to Utils?
+        static bool HasNan(const ValuePtr& value, const char* name);
+        static void Print(const ValuePtr& value, const char* msg);
+
+        size_t m_sampleCount;
+    };
+
+    // Vanilla gradient descent optimization algorithm.
+    class LearnerSGD : public LearnerBase
+    {
+    public:
+
+        LearnerSGD(const std::unordered_set<Variable>& parameters,
+                   const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice())
+            : LearnerBase(parameters, device),
+            m_momentumPerSample(0.0),
+            m_useNesterovAcceleration(false)
+        {
+        }
+
+    protected:
+
+        virtual void Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
+                            const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const override;
+
+        template <typename ElementType>
+        void Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
+                    const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const;
+
+        double m_momentumPerSample;
+        bool m_useNesterovAcceleration;
+    };
+
+    // SGD optimization with momentum.
+    class LearnerMomentumSGD : public LearnerSGD
+    {
+    public:
+
+        LearnerMomentumSGD(const std::unordered_set<Variable>& parameters,
+                           const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice())
+            : LearnerSGD(parameters, device)
+        {
+        }
+
+        void SetMomentum(double value) { m_momentumPerSample = value; }
+    };
+
+    // Nesterov's accelerated gradient descent.
+    class LearnerNesterov : public LearnerSGD
+    {
+    public:
+
+        LearnerNesterov(const std::unordered_set<Variable>& parameters,
+                        const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice())
+            : LearnerSGD(parameters, device)
+        {
+            m_useNesterovAcceleration = true;
+        }
+    };
+
+    class LearnerAdaGrad : public LearnerBase
+    {
+    public:
+
+        LearnerAdaGrad(const std::unordered_set<Variable>& parameters, bool needAveMultiplier,
+                       const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice());
+
+    protected:
+        bool m_needAveMultiplier;
+
+        virtual void Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
+                            const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const override;
+
+        template <typename ElementType>
+        void Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
+                    const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const;
+    };
+
+    class LearnerFSAdaGrad : public LearnerMomentumSGD
+    {
+    public:
+
+        LearnerFSAdaGrad(const std::unordered_set<Variable>& parameters,
+                         const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice());
+
+    protected:
+
+        virtual void Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
+                            const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const override;
+
+        template <typename ElementType>
+        void Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
+                    const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const;
+    };
+
+    class LearnerRMSProp : public LearnerBase
+    {
+    public:
+
+        LearnerRMSProp(const std::unordered_set<Variable>& parameters,
+                       double gamma, double inc, double dec, double max, double min, bool needAveMultiplier,
+                       const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice());
+
+    protected:
+
+        double m_gamma;
+        double m_inc;
+        double m_dec;
+        double m_max;
+        double m_min;
+        bool m_needAveMultiplier;
+
+        virtual void Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
+                            const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const override;
+
+        template <typename ElementType>
+        void Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
+                    const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const;
+    };
+}

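Putting the pieces together, a hypothetical training-step sketch using only the API introduced in this commit; the maps of parameter and gradient Values are assumed to be produced elsewhere by the training loop:

#include "CNTKLibrary.h"

void TrainingStepSketch(const std::unordered_set<CNTK::Variable>& parameters,
                        std::unordered_map<CNTK::Variable, CNTK::ValuePtr>& parameterValues,
                        std::unordered_map<CNTK::Variable, const CNTK::ValuePtr>& gradientValues,
                        size_t minibatchSampleCount)
{
    // Create a built-in SGD learner over the parameter set on the default device.
    CNTK::LearnerPtr learner = CNTK::SGDLearner(parameters);

    // Apply one update; a return value of false signals that learning has stopped.
    bool keepLearning = learner->Update(parameterValues, gradientValues, minibatchSampleCount);
    (void)keepLearning;
}
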
@@ -338,8 +338,10 @@ namespace CNTK
     template std::shared_ptr<const Matrix<float>> NDArrayView::GetMatrix(size_t rowColSplitPoint/* = AutoSelectRowColSplitPoint*/) const;
     template std::shared_ptr<const Matrix<double>> NDArrayView::GetMatrix(size_t rowColSplitPoint/* = AutoSelectRowColSplitPoint*/) const;

-    template std::shared_ptr<Matrix<float>> NDArrayView::GetWritableMatrix(size_t rowColSplitPoint/* = AutoSelectRowColSplitPoint*/);
-    template std::shared_ptr<Matrix<double>> NDArrayView::GetWritableMatrix(size_t rowColSplitPoint/* = AutoSelectRowColSplitPoint*/);
+    template std::shared_ptr<Matrix<float>> NDArrayView::GetWritableMatrix<float>(size_t rowColSplitPoint/* = AutoSelectRowColSplitPoint*/);
+    template std::shared_ptr<Matrix<double>> NDArrayView::GetWritableMatrix<double>(size_t rowColSplitPoint/* = AutoSelectRowColSplitPoint*/);
+    template TensorView<float>* NDArrayView::GetWritableTensorView<float>();
+    template TensorView<double>* NDArrayView::GetWritableTensorView<double>();

     template CNTK_API NDArrayView::NDArrayView(const NDShape& viewShape, const SparseIndexType* colStarts, const SparseIndexType* rowIndices, const float* nonZeroValues, size_t numNonZeroValues, const DeviceDescriptor& device, bool readOnly/* = false*/);
     template CNTK_API NDArrayView::NDArrayView(const NDShape& viewShape, const SparseIndexType* colStarts, const SparseIndexType* rowIndices, const double* nonZeroValues, size_t numNonZeroValues, const DeviceDescriptor& device, bool readOnly/* = false*/);

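The explicit instantiations above exist because these member templates are defined in a .cpp file rather than in the header; without them, callers in other translation units would compile but fail to link. A minimal self-contained example of the same idiom, with hypothetical names:

// twice.h -- declaration only; the definition lives in twice.cpp.
template <typename T> T Twice(T value);

// twice.cpp -- definition plus explicit instantiations for the supported types.
template <typename T> T Twice(T value) { return value + value; }
template float  Twice<float>(float);
template double Twice<double>(double);
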
@@ -6,11 +6,138 @@
 #include "stdafx.h"
 #include "CNTKLibrary.h"
 #include "Utils.h"
+#include "File.h"
+
+using namespace std;

 namespace CNTK
 {
+    template <typename T>
+    void DictionaryValue::AllocateDataPtr(const T& value)
+    {
+        static_assert(is_same<T, NDShape>::value || is_same<T, vector<DictionaryValue>>::value, "AllocateDataPtr called with invalid type");
+        m_data.m_ptr = new T(value);
+    }
+
+    template <typename T>
+    void DictionaryValue::FreePtrAsType()
+    {
+        T* typedPtr = reinterpret_cast<T*>(m_data.m_ptr);
+        delete typedPtr;
+
+        m_data.m_ptr = nullptr;
+    }
+
+    void DictionaryValue::FreeDataPtr()
+    {
+        if (m_valueType == Type::NDShape)
+            FreePtrAsType<NDShape>();
+        else if (m_valueType == Type::Vector)
+            FreePtrAsType<vector<DictionaryValue>>();
+    }
+
+    Microsoft::MSR::CNTK::File& operator>>(Microsoft::MSR::CNTK::File& stream, DictionaryValue& us)
+    {
+        size_t version;
+        stream >> version;
+
+        stream >> us.m_valueType;
+
+        switch (us.ValueType())
+        {
+        case DictionaryValue::Type::Bool:
+            stream >> us.m_data.m_boolean;
+            break;
+        case DictionaryValue::Type::SizeT:
+            stream >> us.m_data.m_sizeT;
+            break;
+        case DictionaryValue::Type::Float:
+            stream >> us.m_data.m_float;
+            break;
+        case DictionaryValue::Type::Double:
+            stream >> us.m_data.m_double;
+            break;
+        case DictionaryValue::Type::NDShape:
+        {
+            size_t size;
+            stream >> size;
+            vector<size_t> dims(size);
+            for (auto i = 0; i < size; i++)
+            {
+                stream >> dims[i];
+            }
+            us.AllocateDataPtr(NDShape(dims));
+            break;
+        }
+        case DictionaryValue::Type::Vector:
+        {
+            size_t size;
+            stream >> size;
+            vector<DictionaryValue> values(size);
+            for (auto i = 0; i < size; i++)
+            {
+                stream >> values[i];
+            }
+            us.AllocateDataPtr(values);
+            break;
+        }
+        default:
+            NOT_IMPLEMENTED;
+        }
+        return stream;
+    }
+
+    Microsoft::MSR::CNTK::File& operator<<(Microsoft::MSR::CNTK::File& stream, const DictionaryValue& us)
+    {
+        stream << us.version;
+
+        stream << us.ValueType();
+
+        switch (us.ValueType())
+        {
+        case DictionaryValue::Type::Bool:
+            stream << us.m_data.m_boolean;
+            break;
+        case DictionaryValue::Type::SizeT:
+            stream << us.m_data.m_sizeT;
+            break;
+        case DictionaryValue::Type::Float:
+            stream << us.m_data.m_float;
+            break;
+        case DictionaryValue::Type::Double:
+            stream << us.m_data.m_double;
+            break;
+        case DictionaryValue::Type::NDShape:
+        {
+            NDShape* shapePtr = reinterpret_cast<NDShape*>(us.m_data.m_ptr);
+            auto size = shapePtr->NumAxes();
+            stream << size;
+            for (auto i = 0; i < size; i++)
+            {
+                stream << shapePtr->operator[](i);
+            }
+            break;
+        }
+        case DictionaryValue::Type::Vector:
+        {
+            vector<DictionaryValue>* vectorPtr =
+                reinterpret_cast<vector<DictionaryValue>*>(us.m_data.m_ptr);
+            auto size = vectorPtr->size();
+            stream << size;
+            for (auto i = 0; i < size; i++)
+            {
+                stream << vectorPtr->operator[](i);
+            }
+            break;
+        }
+        default:
+            NOT_IMPLEMENTED;
+        }
+        return stream;
+    }
+
     Dictionary::Dictionary()
-        : m_dictionaryData(new std::unordered_map < std::wstring, DictionaryValue>)
+        : m_dictionaryData(new unordered_map <wstring, DictionaryValue>)
     {
     }

@@ -22,7 +149,7 @@ namespace CNTK
    Dictionary::Dictionary(Dictionary&& other)
        : m_dictionaryData(nullptr)
    {
-        *this = std::move(other);
+        *this = move(other);
    }

    Dictionary& Dictionary::operator=(Dictionary&& other)
@@ -51,4 +178,130 @@ namespace CNTK
    {
        return (m_dictionaryData->find(key) != m_dictionaryData->end());
    }

    Microsoft::MSR::CNTK::File& operator<<(Microsoft::MSR::CNTK::File& stream, const Dictionary& us)
    {
        stream << us.version;
        stream << us.m_dictionaryData->size();
        for (auto it = us.m_dictionaryData->begin(); it != us.m_dictionaryData->end(); ++it)
        {
            stream << it->first;
            stream << it->second;
        }
        return stream;
    }

    Microsoft::MSR::CNTK::File& operator>>(Microsoft::MSR::CNTK::File& stream, Dictionary& us)
    {
        size_t version;
        stream >> version;
        size_t size;
        stream >> size;
        us.m_dictionaryData->reserve(size);
        for (size_t i = 0; i < size; i++)
        {
            wstring key;
            stream >> key;
            DictionaryValue value;
            stream >> value;
            us.m_dictionaryData->insert(make_pair(key, value));
        }
        return stream;
    }

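For Dictionary itself the on-disk layout is simply a length-prefixed list of pairs; since no terminator or checksum follows, the reader must consume entries in exactly the order they were written:

    version : size_t
    count   : size_t
    count x (key : wstring, value : DictionaryValue)
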
    template <typename T>
    vector<DictionaryValue> SerializeToVector(const NDArrayViewPtr& viewPtr)
    {
        if (viewPtr->IsSparse())
        {
            LogicError("Sparse NDArrayView cannot be serialized into a vector.");
        }

        auto numElements = viewPtr->Shape().TotalSize();

        vector<DictionaryValue> values(numElements);

        NDArrayViewPtr cpuDataViewPtr = viewPtr;
        if ((viewPtr->Device().Type() != DeviceKind::CPU))
        {
            cpuDataViewPtr = MakeSharedObject<NDArrayView>(viewPtr->GetDataType(), viewPtr->Shape(), DeviceDescriptor::CPUDevice());
            cpuDataViewPtr->CopyFrom(*viewPtr);
        }

        const T* buffer = cpuDataViewPtr->DataBuffer<T>();
        for (size_t i = 0; i < numElements; ++i)
        {
            T v = buffer[i];
            values[i] = DictionaryValue(v);
        }

        return values;
    }

    template <typename T>
    void DeserializeFromVector(const NDArrayViewPtr& viewPtr, const vector<DictionaryValue>& values)
    {
        if (viewPtr->IsSparse())
        {
            LogicError("Sparse NDArrayView cannot be deserialized from a vector.");
        }

        auto numElements = viewPtr->Shape().TotalSize();

        if (values.size() != numElements)
        {
            LogicError("Number of elements (%zu) in the deserialized representation does not match the expected value (%zu)",
                       values.size(), numElements);
        }

        NDArrayViewPtr cpuDataViewPtr = viewPtr;
        if ((viewPtr->Device().Type() != DeviceKind::CPU))
        {
            cpuDataViewPtr = MakeSharedObject<NDArrayView>(viewPtr->GetDataType(), viewPtr->Shape(), DeviceDescriptor::CPUDevice());
        }

        T* buffer = cpuDataViewPtr->WritableDataBuffer<T>();
        for (size_t i = 0; i < numElements; ++i)
        {
            buffer[i] = values[i].GetValue<T>();
        }

        if ((viewPtr->Device().Type() != DeviceKind::CPU))
        {
            viewPtr->CopyFrom(*cpuDataViewPtr);
        }
    }

    // TODO: we store the type info for every element in the vector, which is extremely redundant.
    // Instead, it'd be nice to introduce some sort of DictionaryValueVector.
    vector<DictionaryValue> SerializeToVector(const NDArrayViewPtr& viewPtr)
    {
        switch (viewPtr->GetDataType())
        {
        case DataType::Float:
            return SerializeToVector<float>(viewPtr);
        case DataType::Double:
            return SerializeToVector<double>(viewPtr);
        default:
            LogicError("Unsupported DataType %s", DataTypeName(viewPtr->GetDataType()));
        }
    }

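One possible shape for the DictionaryValueVector mentioned in the TODO above, sketched purely as an illustration (the class and all of its members are hypothetical, not part of this commit): record the element Type once and keep an untagged payload, so serialization writes one tag per vector rather than one per element.

    // Hypothetical sketch, not part of this commit.
    class DictionaryValueVector
    {
    public:
        explicit DictionaryValueVector(DictionaryValue::Type elementType)
            : m_elementType(elementType)
        {}

        DictionaryValue::Type ElementType() const { return m_elementType; }
        void Append(double value) { m_payload.push_back(value); }
        size_t Size() const { return m_payload.size(); }
        double operator[](size_t i) const { return m_payload[i]; }

    private:
        DictionaryValue::Type m_elementType; // serialized once for the whole vector
        vector<double> m_payload;            // raw values, no per-element tags
    };
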
    void DeserializeFromVector(const NDArrayViewPtr& viewPtr, const vector<DictionaryValue>& values)
    {
        switch (viewPtr->GetDataType())
        {
        case DataType::Float:
            DeserializeFromVector<float>(viewPtr, values);
            break;
        case DataType::Double:
            DeserializeFromVector<double>(viewPtr, values);
            break;
        default:
            LogicError("Unsupported DataType %s", DataTypeName(viewPtr->GetDataType()));
        }
    }

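Together the typed templates and these dispatchers give a dense-tensor round trip through the DictionaryValue representation. A usage sketch (the constructor arguments mirror the MakeSharedObject call inside SerializeToVector; the shape and the fill step are illustrative assumptions):

    // Round-trip a small float tensor through the DictionaryValue representation.
    NDArrayViewPtr view = MakeSharedObject<NDArrayView>(DataType::Float, NDShape(vector<size_t>({ 2, 3 })), DeviceDescriptor::CPUDevice());
    // ... fill view->WritableDataBuffer<float>() with 6 values ...
    vector<DictionaryValue> blob = SerializeToVector(view); // one tagged entry per element
    DeserializeFromVector(view, blob);                      // writes the same values back in place
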
    template void DictionaryValue::AllocateDataPtr<NDShape>(const NDShape& value);
    template void DictionaryValue::AllocateDataPtr<vector<DictionaryValue>>(const vector<DictionaryValue>& value);
}

@@ -15,245 +15,6 @@ namespace CNTK
    // Forward declarations
    class Dictionary;

    class DictionaryValue
    {
    public:
        enum class Type : unsigned int
        {
            None,
            Bool,
            SizeT,
            Double,
            NDShape,
            Vector
        };

        static const char* TypeName(Type type)
        {
            if (type == Type::None)
                return "None";
            else if (type == Type::Bool)
                return "Bool";
            else if (type == Type::SizeT)
                return "SizeT";
            else if (type == Type::Double)
                return "Double";
            else if (type == Type::NDShape)
                return "NDShape";
            else if (type == Type::Vector)
                return "Vector";
            else
                LogicError("Unknown DictionaryValue::Type");
        }

    public:
        DictionaryValue()
            : m_valueType(Type::None)
        {
        }

        DictionaryValue(bool value)
            : m_valueType(GetValueType<bool>())
        {
            m_data.m_boolean = value;
        }

        DictionaryValue(size_t value)
            : m_valueType(GetValueType<size_t>())
        {
            m_data.m_sizeT = value;
        }

        DictionaryValue(double value)
            : m_valueType(GetValueType<double>())
        {
            m_data.m_double = value;
        }

        template <typename T>
        DictionaryValue(const T& value)
            : m_valueType(GetValueType<T>())
        {
            static_assert(std::is_same<T, NDShape>::value ||
                          std::is_same<T, std::vector<DictionaryValue>>::value,
                          "Unsupported ValueType");

            AllocateDataPtr(value);
        }

        DictionaryValue(const DictionaryValue& other)
            : m_valueType(Type::Bool)
        {
            // m_valueType must have been set to a non-pointer type to prevent an attempt to interpret
            // the underlying uninitialized value as a pointer and free it.
            *this = other;
        }

        DictionaryValue& operator=(const DictionaryValue& other)
        {
            if (this != &other)
            {
                FreeDataPtr();

                m_valueType = other.m_valueType;
                m_data = other.m_data;

                if (other.m_valueType == Type::NDShape)
                    AllocateDataPtr(other.GetValue<NDShape>());
                else if (other.m_valueType == Type::Vector)
                    AllocateDataPtr(other.GetValue<std::vector<DictionaryValue>>());
            }

            return *this;
        }

        ~DictionaryValue()
        {
            FreeDataPtr();
        }

        template <typename T, typename std::enable_if<std::is_same<T, bool>::value>::type* = nullptr>
        const T& GetValue() const
        {
            VerifyType<T>();
            return m_data.m_boolean;
        }

        template <typename T, typename std::enable_if<std::is_same<T, size_t>::value>::type* = nullptr>
        const T& GetValue() const
        {
            VerifyType<T>();
            return m_data.m_sizeT;
        }

        template <typename T, typename std::enable_if<std::is_same<T, double>::value>::type* = nullptr>
        const T& GetValue() const
        {
            VerifyType<T>();
            return m_data.m_double;
        }

        template <typename T, typename std::enable_if<std::is_same<T, NDShape>::value || std::is_same<T, std::vector<DictionaryValue>>::value>::type* = nullptr>
        const T& GetValue() const
        {
            VerifyType<T>();
            return *(reinterpret_cast<T*>(m_data.m_ptr));
        }

        bool HasValue() const
        {
            return m_valueType != Type::None;
        }

        Type ValueType() const
        {
            return m_valueType;
        }

    private:
        template <typename T>
        static Type GetValueType()
        {
            static_assert(std::is_same<T, bool>::value ||
                          std::is_same<T, size_t>::value ||
                          std::is_same<T, double>::value ||
                          std::is_same<T, NDShape>::value ||
                          std::is_same<T, std::vector<DictionaryValue>>::value ||
                          std::is_same<T, CNTK::Dictionary>::value,
                          "Unsupported ValueType");

            if (std::is_same<T, bool>::value)
                return Type::Bool;
            else if (std::is_same<T, size_t>::value)
                return Type::SizeT;
            else if (std::is_same<T, double>::value)
                return Type::Double;
            else if (std::is_same<T, NDShape>::value)
                return Type::NDShape;
            else if (std::is_same<T, std::vector<DictionaryValue>>::value)
                return Type::Vector;
        }

        template <typename T>
        void VerifyType() const
        {
            if (GetValueType<T>() != m_valueType)
                RuntimeError("Reading a DictionaryValue as the wrong type; reading as type %s when the actual type is %s", typeid(T).name(), DictionaryValue::TypeName(m_valueType));
        }

        template <typename T>
        void AllocateDataPtr(const T& value)
        {
            static_assert(std::is_same<T, NDShape>::value || std::is_same<T, std::vector<DictionaryValue>>::value, "AllocateDataPtr called with invalid type");
            m_data.m_ptr = new T(value);
        }

        template <typename T>
        void FreePtrAsType()
        {
            T* typedPtr = reinterpret_cast<T*>(m_data.m_ptr);
            delete typedPtr;

            m_data.m_ptr = nullptr;
        }

        void FreeDataPtr()
        {
            if (m_valueType == Type::NDShape)
                FreePtrAsType<NDShape>();
            else if (m_valueType == Type::Vector)
                FreePtrAsType<std::vector<DictionaryValue>>();
        }

    private:
        Type m_valueType;

        union ValueData
        {
            bool m_boolean;
            size_t m_sizeT;
            double m_double;
            void* m_ptr;
        } m_data;
    };

    class Dictionary
    {
    public:
        Dictionary();
        ~Dictionary();

        // Disallow copy construction and assignment
        Dictionary(const Dictionary&) = delete; Dictionary& operator=(const Dictionary&) = delete;

        Dictionary(Dictionary&& other);
        Dictionary& operator=(Dictionary&& other);

        DictionaryValue& operator[](const std::wstring& key)
        {
            return operator[](key.c_str());
        }

        DictionaryValue& operator[](const wchar_t* key);

        DictionaryValue operator[](const std::wstring& key) const
        {
            return operator[](key.c_str());
        }

        DictionaryValue operator[](const wchar_t* key) const;

        bool Contains(const std::wstring& key) const
        {
            return Contains(key.c_str());
        }

        bool Contains(const wchar_t* key) const;

    private:
        std::unordered_map<std::wstring, DictionaryValue>* m_dictionaryData;
    };

    // Helper to get the size of an element of the specified DataType
    inline size_t ElementSize(DataType dataType)
    {
@@ -363,4 +124,8 @@ namespace CNTK
    {
        return var.IsInput() && var.IsSparse();
    }

    std::vector<DictionaryValue> SerializeToVector(const NDArrayViewPtr& viewPtr);

    void DeserializeFromVector(const NDArrayViewPtr& viewPtr, const std::vector<DictionaryValue>& values);
}