Add v2 Learners (standalone)

Parent 271476466a
Commit 1b0548fdde

Makefile (+2 lines)
Makefile
@@ -375,6 +375,8 @@ CNTKLIBRARY_SRC =\
 	$(SOURCEDIR)/CNTKv2LibraryDll/Utils.cpp \
 	$(SOURCEDIR)/CNTKv2LibraryDll/Value.cpp \
 	$(SOURCEDIR)/CNTKv2LibraryDll/Variable.cpp \
+	$(SOURCEDIR)/CNTKv2LibraryDll/Learner.cpp \
 
 CNTKLIBRARY_SRC+=$(CNTK_COMMON_SRC)
 CNTKLIBRARY_SRC+=$(COMPUTATION_NETWORK_LIB_SRC)
CNTKv2LibraryDll/API/CNTKLibrary.h
@@ -285,6 +285,7 @@ namespace CNTK
     class NDArrayView final : public std::enable_shared_from_this<NDArrayView>
     {
         friend class CompositeFunction;
+        friend class LearnerBase;
 
         template <typename T, typename ...CtorArgTypes>
         friend inline std::shared_ptr<T> MakeSharedObject(CtorArgTypes&& ...ctorArgs);
@@ -1396,4 +1397,342 @@ namespace CNTK
    /// of the computation graph which can be "Combine"d to create a single Function with 2 outputs; viz. CrossEntropy loss and ClassificationError output.
    ///
    CNTK_API FunctionPtr Combine(const std::initializer_list<FunctionPtr>& operands, const std::wstring& name = L"");

    ///
    /// A serializable value represents one of:
    /// a) Boolean
    /// b) Unsigned integer (size_t)
    /// c) Single and double precision floating point values
    /// d) NDShape
    /// e) vector<DictionaryValue>
    ///
    /// TODO: we need to have native support for DictionaryValue<vector> and DictionaryValue<NDArrayView>.
    class CNTK_API DictionaryValue final
    {
    public:
        enum class Type : unsigned int
        {
            None,
            Bool,
            SizeT,
            Float,
            Double,
            NDShape,
            Vector
        };

        static const char* TypeName(Type type)
        {
            switch (type)
            {
            case Type::None:
                return "None";
            case Type::Bool:
                return "Bool";
            case Type::SizeT:
                return "SizeT";
            case Type::Float:
                return "Float";
            case Type::Double:
                return "Double";
            case Type::NDShape:
                return "NDShape";
            case Type::Vector:
                return "Vector";
            default:
                LogicError("Unknown DictionaryValue::Type");
            }
        }

    public:
        DictionaryValue() : m_valueType(Type::None)
        {
        }

        DictionaryValue(bool value) : m_valueType(GetValueType<bool>())
        {
            m_data.m_boolean = value;
        }

        DictionaryValue(size_t value) : m_valueType(GetValueType<size_t>())
        {
            m_data.m_sizeT = value;
        }

        DictionaryValue(float value) : m_valueType(GetValueType<float>())
        {
            m_data.m_float = value;
        }

        DictionaryValue(double value) : m_valueType(GetValueType<double>())
        {
            m_data.m_double = value;
        }

        template <typename T>
        DictionaryValue(const T& value) : m_valueType(GetValueType<T>())
        {
            static_assert(std::is_same<T, NDShape>::value ||
                          std::is_same<T, std::vector<DictionaryValue>>::value,
                          "Unsupported ValueType");

            AllocateDataPtr(value);
        }

        DictionaryValue(const DictionaryValue& other) : m_valueType(Type::Bool)
        {
            // m_valueType must be initialized to a non-ptr type here to prevent the assignment
            // below from interpreting the underlying uninitialized value as a ptr and freeing it.
            *this = other;
        }

        DictionaryValue& operator=(const DictionaryValue& other)
        {
            if (this != &other)
            {
                FreeDataPtr();

                m_valueType = other.m_valueType;
                m_data = other.m_data;

                if (other.m_valueType == Type::NDShape)
                    AllocateDataPtr(other.GetValue<NDShape>());
                else if (other.m_valueType == Type::Vector)
                    AllocateDataPtr(other.GetValue<std::vector<DictionaryValue>>());
            }

            return *this;
        }

        ~DictionaryValue()
        {
            FreeDataPtr();
        }

        template <typename T, typename std::enable_if<std::is_same<T, bool>::value>::type* = nullptr>
        const T& GetValue() const
        {
            VerifyType<T>();
            return m_data.m_boolean;
        }

        template <typename T, typename std::enable_if<std::is_same<T, size_t>::value>::type* = nullptr>
        const T& GetValue() const
        {
            VerifyType<T>();
            return m_data.m_sizeT;
        }

        template <typename T, typename std::enable_if<std::is_same<T, float>::value>::type* = nullptr>
        const T& GetValue() const
        {
            VerifyType<T>();
            return m_data.m_float;
        }

        template <typename T, typename std::enable_if<std::is_same<T, double>::value>::type* = nullptr>
        const T& GetValue() const
        {
            VerifyType<T>();
            return m_data.m_double;
        }

        template <typename T, typename std::enable_if<std::is_same<T, NDShape>::value || std::is_same<T, std::vector<DictionaryValue>>::value>::type* = nullptr>
        const T& GetValue() const
        {
            VerifyType<T>();
            return *(reinterpret_cast<T*>(m_data.m_ptr));
        }

        bool HasValue() const
        {
            return m_valueType != Type::None;
        }

        Type ValueType() const
        {
            return m_valueType;
        }

        friend CNTK_API Microsoft::MSR::CNTK::File& operator>>(Microsoft::MSR::CNTK::File& stream, DictionaryValue& us);
        friend CNTK_API Microsoft::MSR::CNTK::File& operator<<(Microsoft::MSR::CNTK::File& stream, const DictionaryValue& us);

    private:
        template <typename T>
        static Type GetValueType()
        {
            static_assert(std::is_same<T, bool>::value ||
                          std::is_same<T, size_t>::value ||
                          std::is_same<T, float>::value ||
                          std::is_same<T, double>::value ||
                          std::is_same<T, NDShape>::value ||
                          std::is_same<T, std::vector<DictionaryValue>>::value,
                          "Unsupported ValueType");

            if (std::is_same<T, bool>::value) return Type::Bool;
            if (std::is_same<T, size_t>::value) return Type::SizeT;
            if (std::is_same<T, float>::value) return Type::Float;
            if (std::is_same<T, double>::value) return Type::Double;
            if (std::is_same<T, NDShape>::value) return Type::NDShape;
            if (std::is_same<T, std::vector<DictionaryValue>>::value) return Type::Vector;
        }

        template <typename T>
        void VerifyType() const
        {
            if (GetValueType<T>() != m_valueType)
                RuntimeError("Reading a DictionaryValue as the wrong type; reading as type %s when the actual type is %s", typeid(T).name(), DictionaryValue::TypeName(m_valueType));
        }

        template <typename T>
        void AllocateDataPtr(const T& value);

        template <typename T>
        void FreePtrAsType();

        void FreeDataPtr();

        Type m_valueType;

        union ValueData
        {
            bool m_boolean;
            size_t m_sizeT;
            float m_float;
            double m_double;
            void* m_ptr;
        } m_data;

        const size_t version = 1;
    };

    ///
    /// A type denoting a dictionary (keyed by Unicode strings) of serializable values (dynamically typed).
    ///
    class CNTK_API Dictionary final
    {
    public:
        Dictionary();
        ~Dictionary();

        // Disallow copy construction and assignment
        Dictionary(const Dictionary&) = delete; Dictionary& operator=(const Dictionary&) = delete;

        Dictionary(Dictionary&& other);
        Dictionary& operator=(Dictionary&& other);

        DictionaryValue& operator[](const std::wstring& key)
        {
            return operator[](key.c_str());
        }

        DictionaryValue& operator[](const wchar_t* key);

        DictionaryValue operator[](const std::wstring& key) const
        {
            return operator[](key.c_str());
        }

        DictionaryValue operator[](const wchar_t* key) const;

        bool Contains(const std::wstring& key) const
        {
            return Contains(key.c_str());
        }

        bool Contains(const wchar_t* key) const;

        friend CNTK_API Microsoft::MSR::CNTK::File& operator>>(Microsoft::MSR::CNTK::File& stream, Dictionary& us);
        friend CNTK_API Microsoft::MSR::CNTK::File& operator<<(Microsoft::MSR::CNTK::File& stream, const Dictionary& us);

    private:
        std::unordered_map<std::wstring, DictionaryValue>* m_dictionaryData;
        const size_t version = 1;
    };

    ///
    /// Abstraction for learning a subset of parameters of a learnable Function using first-order gradient values.
    /// E.g., momentum, AdaGrad, RMSProp, etc. are different types of learners with their own algorithms for
    /// learning parameter values using first-order gradients.
    ///
    class Learner : public std::enable_shared_from_this<Learner>
    {
    public:
        //
        // Method to update the parameters associated with this learner. By returning false, this method indicates that
        // learning has stopped for all of the parameters associated with this learner.
        //
        CNTK_API virtual bool Update(const std::unordered_map<Variable, ValuePtr>& parameterValues,
                                     const std::unordered_map<Variable, const ValuePtr>& gradientValues,
                                     size_t trainingSampleCount) = 0;

        ///
        /// Returns the set of parameters associated with this learner.
        ///
        const std::unordered_set<Variable>& Parameters() const { return m_parameters; }

        // TODO: move the following two methods into an ISerializable interface, and make
        // Learner (and all other entities that need checkpointing capability) implement it.
        ///
        /// Optionally overridable method to checkpoint the learner's state.
        ///
        CNTK_API virtual Dictionary GetCheckpointState() const = 0;

        ///
        /// Optionally overridable method to restore the learner's state from a previous checkpoint.
        ///
        CNTK_API virtual void RestoreFromCheckpoint(const Dictionary& checkpoint) = 0;

        virtual ~Learner()
        {
        }

    protected:
        Learner(const std::unordered_set<Variable>& parameters)
            : m_parameters(parameters)
        {
        }

        std::unordered_set<Variable> m_parameters;
    };

    ///
    /// Create an instance of the CNTK built-in SGD learner.
    ///
    /// TODO: add additional SGD parameters here (a collection of learning rate values)
    CNTK_API LearnerPtr SGDLearner(const std::unordered_set<Variable>& parameters,
                                   const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice());

    ///
    /// Create an instance of the CNTK built-in Momentum SGD learner.
    ///
    /// TODO: add additional Momentum parameters here (a collection of momentum rate values)
    CNTK_API LearnerPtr MomentumSGDLearner(const std::unordered_set<Variable>& parameters,
                                           const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice());

    ///
    /// Create an instance of the CNTK built-in Nesterov's accelerated SGD learner.
    ///
    CNTK_API LearnerPtr NesterovLearner(const std::unordered_set<Variable>& parameters,
                                        const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice());

    ///
    /// Create an instance of the CNTK built-in AdaGrad learner.
    ///
    CNTK_API LearnerPtr AdaGradLearner(const std::unordered_set<Variable>& parameters, bool needAveMultiplier = true,
                                       const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice());

    ///
    /// Create an instance of the CNTK built-in FSAdaGrad (improved AdaGrad) learner.
    ///
    CNTK_API LearnerPtr FSAdaGradLearner(const std::unordered_set<Variable>& parameters,
                                         const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice());

    ///
    /// Create an instance of the CNTK built-in RMSProp learner.
    ///
    CNTK_API LearnerPtr RMSPropLearner(const std::unordered_set<Variable>& parameters,
                                       double gamma, double inc, double dec, double max, double min, bool needAveMultiplier = true,
                                       const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice());
}
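A quick sketch of how the new DictionaryValue/Dictionary pair is meant to be used (illustrative only, not part of the diff; the key names are made up). GetValue<T>() is checked at runtime, so reading back with the wrong type raises a RuntimeError:

#include "CNTKLibrary.h"

void DictionarySketch()
{
    using namespace CNTK;

    Dictionary state;
    state[L"sampleCount"] = DictionaryValue(size_t(4096));                      // tagged as Type::SizeT
    state[L"shape"] = DictionaryValue(NDShape(std::vector<size_t>({ 28, 28 }))); // heap-allocated via AllocateDataPtr

    size_t count = state[L"sampleCount"].GetValue<size_t>();   // ok
    // state[L"sampleCount"].GetValue<double>();               // would raise RuntimeError

    if (state.Contains(L"shape"))
    {
        const NDShape& shape = state[L"shape"].GetValue<NDShape>();
        (void)shape;
    }
    (void)count;
}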
CNTKv2LibraryDll/API/CNTKLibraryInternals.h
@@ -47,6 +47,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
 
     template <typename ElementType>
     class ComputationNode;
 
+    class File;
 }}}
 
 // TODO: The following should be reconciled with the equivalent code in the CNTK implementation
@@ -158,4 +160,7 @@ namespace CNTK
 
     class Function;
     typedef std::shared_ptr<Function> FunctionPtr;
+
+    class Learner;
+    typedef std::shared_ptr<Learner> LearnerPtr;
 }
CNTKv2LibraryDll/CNTKv2LibraryDll.vcxproj
@@ -128,6 +128,7 @@
     <ClInclude Include="API\CNTKLibrary.h" />
     <ClInclude Include="API\CNTKLibraryInternals.h" />
     <ClInclude Include="Function.h" />
+    <ClInclude Include="Learner.h" />
     <ClInclude Include="Utils.h" />
     <ClInclude Include="stdafx.h" />
     <ClInclude Include="targetver.h" />
@@ -140,6 +141,7 @@
     </PrecompiledHeader>
     </ClCompile>
     <ClCompile Include="Function.cpp" />
+    <ClCompile Include="Learner.cpp" />
     <ClCompile Include="NDArrayView.cpp" />
     <ClCompile Include="NDMask.cpp" />
     <ClCompile Include="stdafx.cpp">
CNTKv2LibraryDll/CNTKv2LibraryDll.vcxproj.filters
@@ -10,6 +10,7 @@
     <ClCompile Include="Variable.cpp" />
     <ClCompile Include="Utils.cpp" />
     <ClCompile Include="NDMask.cpp" />
+    <ClCompile Include="Learner.cpp" />
   </ItemGroup>
   <ItemGroup>
     <ClInclude Include="stdafx.h" />
@@ -22,6 +23,7 @@
       <Filter>API</Filter>
     </ClInclude>
     <ClInclude Include="Function.h" />
+    <ClInclude Include="Learner.h" />
   </ItemGroup>
   <ItemGroup>
     <Filter Include="API">
CNTKv2LibraryDll/Learner.cpp (new file)
@@ -0,0 +1,464 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//

#include "Learner.h"
#include "TensorView.h"
#include "Utils.h"

#define UPDATE_FUNCTION \
    switch (smoothedGradientValue->Data()->GetDataType()) \
    { \
    case DataType::Float: \
        Update<float>(parameter, smoothedGradientValue, gradientValue, parameterValue, trainingSampleCount); \
        break; \
    case DataType::Double: \
        Update<double>(parameter, smoothedGradientValue, gradientValue, parameterValue, trainingSampleCount); \
        break; \
    default: \
        NOT_IMPLEMENTED; \
    }

using namespace Microsoft::MSR::CNTK;
using namespace std;

namespace CNTK
{
    template <typename ElementType>
    /*static*/ shared_ptr<const Matrix<ElementType>> LearnerBase::GetMatrix(const NDArrayViewPtr arrayView)
    {
        return arrayView->GetMatrix<ElementType>();
    }

    template <typename ElementType>
    /*static*/ shared_ptr<Matrix<ElementType>> LearnerBase::GetWritableMatrix(NDArrayViewPtr arrayView)
    {
        return arrayView->GetWritableMatrix<ElementType>();
    }

    template <typename ElementType>
    /*static*/ const TensorView<ElementType>* LearnerBase::GetTensorView(const NDArrayViewPtr arrayView)
    {
        return arrayView->GetTensorView<ElementType>();
    }

    /*static*/ bool LearnerBase::HasNan(const ValuePtr& value, const char* name)
    {
        const auto& data = value->Data();
        switch (data->GetDataType())
        {
        case DataType::Float:
            return data->GetMatrix<float>()->HasNan(name);
        case DataType::Double:
            return data->GetMatrix<double>()->HasNan(name);
        default:
            LogicError("Unsupported DataType %s", DataTypeName(data->GetDataType()));
        }
    }

    /*static*/ void LearnerBase::Print(const ValuePtr& value, const char* msg)
    {
        const auto& data = value->Data();
        switch (data->GetDataType())
        {
        case DataType::Float:
            data->GetMatrix<float>()->Print(msg);
            break;
        case DataType::Double:
            data->GetMatrix<double>()->Print(msg);
            break;
        default:
            LogicError("Unsupported DataType %s", DataTypeName(data->GetDataType()));
        }
    }

    // Clipping gradients to prevent outliers.
    template <typename ElementType>
    void LearnerBase::ClipGradient(Matrix<ElementType>& gradient, size_t actualMBSize) const
    {
        if (m_additionalOptions.gradientClippingThresholdPerSample != numeric_limits<double>::infinity())
        {
            double maxGradientPerMB = m_additionalOptions.gradientClippingThresholdPerSample * actualMBSize;
            if (m_additionalOptions.gradientClippingWithTruncation)
                gradient.InplaceTruncate(ElementType(maxGradientPerMB));
            else
            {
                // norm2 normalized
                double gradientNorm = gradient.FrobeniusNorm();
                if (gradientNorm > maxGradientPerMB)
                {
                    double normFactor = maxGradientPerMB / gradientNorm;
                    gradient *= ElementType(normFactor);
                }
            }
        }
    }

    // Performs additional preprocessing before calling the update method
    // (gradient clipping and L2 regularization depending on the additional learning parameters).
    template <typename ElementType>
    void LearnerBase::PreProcess(const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t actualMBSize) const
    {
        const auto& gradientMatrix = gradientValue->Data()->GetWritableMatrix<ElementType>();

        // clipping gradients to prevent outliers
        ClipGradient<ElementType>(*gradientMatrix, actualMBSize);

        // L2 regularizer
        if (m_additionalOptions.l2RegularizationWeight > 0)
        {
            // multiply by actualMBSize so that it's invariant to minibatch size, since the learning rate is per sample
            auto weight = ElementType(m_additionalOptions.l2RegularizationWeight * actualMBSize);
            const auto& parameterMatrix = parameterValue->Data()->GetWritableMatrix<ElementType>();
            Matrix<ElementType>::ScaleAndAdd(weight, *parameterMatrix, *gradientMatrix);
        }
    }

    // Performs additional postprocessing after the update method has been executed
    // (noise injection and L1 regularization specified by the additional learning parameters).
    template <typename ElementType>
    void LearnerBase::PostProcess(const Variable& parameter, const ValuePtr& gradientValue,
                                  const ValuePtr& parameterValue, size_t actualMBSize) const
    {
        const auto& parameterMatrix = parameterValue->Data()->GetWritableMatrix<ElementType>();
        if (m_additionalOptions.gaussianNoiseInjectionStdDev > 0)
        {
            const auto& gradientMatrix = gradientValue->Data()->GetWritableMatrix<ElementType>();

            Matrix<ElementType> sgdUpdateNoise((DEVICEID_TYPE)parameterMatrix->GetDeviceId());

            // get the gradient structure, since the gradient is sparse
            sgdUpdateNoise.SetValue(*gradientMatrix);

            auto noiseStdDev = ElementType(m_additionalOptions.gaussianNoiseInjectionStdDev);

            // reset its value to random
            sgdUpdateNoise.SetGaussianRandomValue(ElementType(0.0), noiseStdDev);

            Matrix<ElementType>::ScaleAndAdd(ElementType(1.0), sgdUpdateNoise, *parameterMatrix);
        }

        // L1 regularizer with proximal gradient descent method
        if (m_additionalOptions.l1RegularizationWeight > 0)
        {
            auto learningRate = ElementType(ParameterDependentLearningRate(parameter));
            // multiply by actualMBSize so that it's invariant to minibatch size, since the learning rate is per sample
            auto weight = ElementType(learningRate * m_additionalOptions.l1RegularizationWeight * actualMBSize);
            parameterValue->Data()->GetWritableMatrix<ElementType>()->InplaceSoftThreshold(weight);
        }
    }

    template <typename ElementType>
    /*static*/ TensorView<ElementType>* LearnerBase::GetWritableTensorView(NDArrayViewPtr arrayView)
    {
        return arrayView->GetWritableTensorView<ElementType>();
    }

    LearnerBase::LearnerBase(const unordered_set<Variable>& parameters, const DeviceDescriptor& device)
        : Learner(parameters),
        m_learningRatePerSample(0.0),
        m_sampleCount(0)
    {
        const unordered_set<Variable>& parameterSet = parameters;
        for (const auto& parameter : parameterSet)
        {
            // TODO: we are using the same device to allocate data for all smoothed gradients. Is this correct?
            // Should the device be specified on a per-parameter basis?
            NDArrayViewPtr view;
            if (parameter.GetDataType() == DataType::Float)
            {
                view = MakeSharedObject<NDArrayView>(0.0f, parameter.Shape(), device);
            }
            else
            {
                view = MakeSharedObject<NDArrayView>(0.0, parameter.Shape(), device);
            }

            m_smoothedGradientValues.insert(make_pair(parameter, MakeSharedObject<Value>(view)));
            m_additionalOptions.learningRateMultipliers.insert(make_pair(parameter, 1.0));
        }
    }

    void LearnerBase::ResetSmoothedGradients()
    {
        for (const auto& parameter : Parameters())
        {
            const auto& smoothedGradientValue = m_smoothedGradientValues.at(parameter);
            const auto& data = smoothedGradientValue->Data();
            switch (data->GetDataType())
            {
            case DataType::Float:
                data->SetValue(0.0f);
                break;
            case DataType::Double:
                data->SetValue(0.0);
                break;
            default:
                LogicError("Unsupported DataType %s", ::CNTK::DataTypeName(data->GetDataType()));
            }
        }
    }

    /*virtual*/ bool LearnerBase::Update(const unordered_map<Variable, ValuePtr>& parameterValues,
                                         const unordered_map<Variable, const ValuePtr>& gradientValues,
                                         size_t trainingSampleCount) /*override*/
    {
        // make sure trainingSampleCount is a valid value
        assert(trainingSampleCount > 0);

        for (const auto& parameter : Parameters())
        {
            const auto& smoothedGradientValue = m_smoothedGradientValues.at(parameter);
            const auto& gradientValue = gradientValues.at(parameter);
            const auto& parameterValue = parameterValues.at(parameter);

            // TODO: make this a runtime parameter.
#if DUMPOUTPUT
            LOGPRINTF(stderr, "Update_%ls\n", parameter.Name().c_str());
#endif

#ifdef _DEBUG
            if (HasNan(smoothedGradientValue, "TrainOneEpoch/UpdateWeights/Learner::Update(): "))
                LogicError("%ls has NaNs in smoothedGradient.", parameter.Name().c_str());
#endif

#if DUMPOUTPUT
            LOGPRINTF(stderr, "learnRatePerSample=%0.8f, momentum=%0.8f, actualMBSize=%ld\n",
                      m_learningRatePerSample, m_momentumPerSample, trainingSampleCount);
            LOGPRINTF(stderr, "GradUpdateType()=%s, GradientUpdateNoiseStd()=%0.8f\n",
                      LearnerType().c_str(), m_GaussianNoiseInjectStd);
            Print(gradientValue, "Gradient Update");
            Print(smoothedGradientValue, "Smoothed Gradient Input");
#endif
            UPDATE_FUNCTION;

#if DUMPOUTPUT
            Print(parameterValue, "Parameter Update");
#endif

#ifdef _DEBUG
            if (HasNan(parameterValue, "TrainOneEpoch/UpdateWeights/Learner::Update(): "))
                LogicError("%ls has NaNs in parameter values after the parameter update.", parameter.Name().c_str());
#endif
        }
        m_sampleCount += trainingSampleCount;
        return false;
    }

    template <typename ElementType>
    void LearnerBase::Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
                             const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const
    {
        PreProcess<ElementType>(gradientValue, parameterValue, trainingSampleCount);
        Update(parameter, smoothedGradientValue, gradientValue, parameterValue, trainingSampleCount);
        PostProcess<ElementType>(parameter, gradientValue, parameterValue, trainingSampleCount);
    }

    string LearnerBase::LearnerType() const
    {
        auto name = typeid(*this).name();
        if (strncmp(name, "class ", 6) == 0)
        {
            // On Windows, the type name contains a "class " prefix.
            // Return the actual name, omitting the prefix.
            return &name[6];
        }
        return name;
    }

    /*virtual*/ Dictionary LearnerBase::GetCheckpointState() const /*override*/
    {
        NOT_IMPLEMENTED; // Until the new checkpointing is fully fleshed out, nobody should be calling this.
        Dictionary checkpoint;

        for (const auto& parameter : Parameters())
        {
            // TODO: parameter names are not guaranteed to be unique. Instead, all serializable objects
            // need to expose a "UId" property -- a persistent unique internal name.
            // Switch to UId as soon as it's available.
            if (checkpoint.Contains(parameter.Name()))
            {
                LogicError("Parameter names must be unique");
            }
            const auto& smoothedGradientValue = m_smoothedGradientValues.at(parameter);

            // Potentially, we could store things like dimensions, element size, format, etc., but
            // that seems to be redundant, since all of that is passed in the constructor.
            checkpoint[parameter.Name()] = SerializeToVector(smoothedGradientValue->Data());
        }
        return checkpoint;
    }

    /*virtual*/ void LearnerBase::RestoreFromCheckpoint(const Dictionary& checkpoint) /*override*/
    {
        NOT_IMPLEMENTED; // Until the new checkpointing is fully fleshed out, nobody should be calling this.
        for (const auto& parameter : Parameters())
        {
            if (!checkpoint.Contains(parameter.Name()))
            {
                LogicError("Checkpoint does not contain state for parameter %ls", parameter.Name().c_str());
            }
            const auto& smoothedGradientValue = m_smoothedGradientValues.at(parameter);

            const DictionaryValue& state = checkpoint[parameter.Name()];

            const auto& data = smoothedGradientValue->Data();

            DeserializeFromVector(data, state.GetValue<vector<DictionaryValue>>());
        }
    }

    /*virtual*/ void LearnerSGD::Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
                                        const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const /*override*/
    {
        UPDATE_FUNCTION;
    }

    template <typename ElementType>
    void LearnerSGD::Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
                            const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const
    {
        UNUSED(trainingSampleCount);

        const auto& smoothedGradientMatrix = GetWritableMatrix<ElementType>(smoothedGradientValue->Data());
        const auto& gradientMatrix = GetWritableMatrix<ElementType>(gradientValue->Data());
        const auto& parameterMatrix = GetWritableMatrix<ElementType>(parameterValue->Data());

        const auto& learningRate = ElementType(ParameterDependentLearningRate(parameter));

        // TODO: break up NormalGrad into 3 different functions, each with its own set of parameters
        // (one for vanilla SGD, another for momentum SGD, and a third one for NAG).
        smoothedGradientMatrix->NormalGrad(*gradientMatrix, *parameterMatrix,
                                           learningRate, ElementType(m_momentumPerSample), m_useNesterovAcceleration);
    }

    LearnerAdaGrad::LearnerAdaGrad(const unordered_set<Variable>& parameters, bool needAveMultiplier, const DeviceDescriptor& device)
        : LearnerBase(parameters, device),
        m_needAveMultiplier(needAveMultiplier)
    {
    }

    /*virtual*/ void LearnerAdaGrad::Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
                                            const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const /*override*/
    {
        UPDATE_FUNCTION;
    }

    template <typename ElementType>
    void LearnerAdaGrad::Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
                                const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const
    {
        UNUSED(trainingSampleCount);

        const auto& smoothedGradientMatrix = GetWritableMatrix<ElementType>(smoothedGradientValue->Data());
        const auto& gradientMatrix = GetWritableMatrix<ElementType>(gradientValue->Data());
        const auto& parameterMatrix = GetWritableMatrix<ElementType>(parameterValue->Data());

        auto learningRate = ElementType(ParameterDependentLearningRate(parameter));

        auto aveMultiplier = smoothedGradientMatrix->Adagrad(*gradientMatrix, m_needAveMultiplier);
        Matrix<ElementType>::ScaleAndAdd(ElementType(-learningRate / aveMultiplier), *gradientMatrix, *parameterMatrix);
    }

    LearnerFSAdaGrad::LearnerFSAdaGrad(const unordered_set<Variable>& parameters, const DeviceDescriptor& device)
        : LearnerMomentumSGD(parameters, device)
    {
    }

    /*virtual*/ void LearnerFSAdaGrad::Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
                                              const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const /*override*/
    {
        UPDATE_FUNCTION;
    }

    template <typename ElementType>
    void LearnerFSAdaGrad::Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
                                  const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const
    {
        UNUSED(trainingSampleCount);

        const auto& smoothedGradientMatrix = GetWritableMatrix<ElementType>(smoothedGradientValue->Data());
        const auto& gradientMatrix = GetWritableMatrix<ElementType>(gradientValue->Data());
        const auto& parameterMatrix = GetWritableMatrix<ElementType>(parameterValue->Data());

        //const double momentum = MomentumPerMB(m_momentumPerSample, trainingSampleCount);

        auto learningRate = ElementType(ParameterDependentLearningRate(parameter));

        smoothedGradientMatrix->FSAdagrad(trainingSampleCount, *gradientMatrix, *parameterMatrix,
                                          learningRate, ElementType(m_momentumPerSample));
    }

    LearnerRMSProp::LearnerRMSProp(const unordered_set<Variable>& parameters,
                                   double gamma, double inc, double dec, double max, double min,
                                   bool needAveMultiplier, const DeviceDescriptor& device)
        : LearnerBase(parameters, device),
        m_gamma(gamma), m_inc(inc), m_dec(dec), m_max(max), m_min(min),
        m_needAveMultiplier(needAveMultiplier)
    {
    }

    /*virtual*/ void LearnerRMSProp::Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
                                            const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const /*override*/
    {
        UPDATE_FUNCTION;
    }

    template <typename ElementType>
    void LearnerRMSProp::Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
                                const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const
    {
        UNUSED(trainingSampleCount);

        const auto& smoothedGradientMatrix = GetWritableMatrix<ElementType>(smoothedGradientValue->Data());
        const auto& gradientMatrix = GetWritableMatrix<ElementType>(gradientValue->Data());
        const auto& parameterMatrix = GetWritableMatrix<ElementType>(parameterValue->Data());

        auto learningRate = ElementType(ParameterDependentLearningRate(parameter));

        auto aveMultiplier = smoothedGradientMatrix->RmsProp(*gradientMatrix,
                                                             ElementType(m_gamma), ElementType(m_inc),
                                                             ElementType(m_max), ElementType(m_dec),
                                                             ElementType(m_min), m_needAveMultiplier);
        Matrix<ElementType>::ScaleAndAdd(ElementType(-learningRate / aveMultiplier), *gradientMatrix, *parameterMatrix);
    }

    // Explicit template instantiations
    template shared_ptr<Matrix<float>> LearnerBase::GetWritableMatrix<float>(const NDArrayViewPtr arrayView);
    template shared_ptr<Matrix<double>> LearnerBase::GetWritableMatrix<double>(const NDArrayViewPtr arrayView);

    LearnerPtr SGDLearner(const unordered_set<Variable>& parameters, const DeviceDescriptor& device)
    {
        return MakeSharedObject<LearnerSGD>(parameters, device);
    }

    LearnerPtr MomentumSGDLearner(const unordered_set<Variable>& parameters, const DeviceDescriptor& device)
    {
        return MakeSharedObject<LearnerMomentumSGD>(parameters, device);
    }

    LearnerPtr NesterovLearner(const unordered_set<Variable>& parameters, const DeviceDescriptor& device)
    {
        return MakeSharedObject<LearnerNesterov>(parameters, device);
    }

    LearnerPtr AdaGradLearner(const unordered_set<Variable>& parameters, bool needAveMultiplier, const DeviceDescriptor& device)
    {
        return MakeSharedObject<LearnerAdaGrad>(parameters, needAveMultiplier, device);
    }

    LearnerPtr FSAdaGradLearner(const unordered_set<Variable>& parameters, const DeviceDescriptor& device)
    {
        return MakeSharedObject<LearnerFSAdaGrad>(parameters, device);
    }

    LearnerPtr RMSPropLearner(const unordered_set<Variable>& parameters,
                              double gamma, double inc, double dec, double max, double min, bool needAveMultiplier,
                              const DeviceDescriptor& device)
    {
        return MakeSharedObject<LearnerRMSProp>(parameters, gamma, inc, dec, max, min, needAveMultiplier, device);
    }
}
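For reference, the clipping and regularization steps above amount to the following per-minibatch updates. This is a sketch read off the code: tau is gradientClippingThresholdPerSample, m is actualMBSize, eta the per-sample learning rate, and lambda_1/lambda_2 the regularization weights; Matrix::NormalGrad is defined elsewhere in the library, so the classical momentum form shown for it is an assumption.

\begin{align*}
g &\leftarrow g \cdot \min\!\left(1, \tfrac{\tau m}{\lVert g \rVert_F}\right) && \text{norm-based clipping (non-truncation branch of ClipGradient)}\\
g &\leftarrow g + \lambda_2 m w && \text{L2 term folded into the gradient (PreProcess)}\\
s &\leftarrow \mu s + \eta g, \quad w \leftarrow w - s && \text{assumed classical momentum step (NormalGrad)}\\
w &\leftarrow \operatorname{sign}(w)\,\max(\lvert w \rvert - \eta \lambda_1 m,\, 0) && \text{L1 via soft-thresholding (PostProcess)}
\end{align*}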
CNTKv2LibraryDll/Learner.h (new file)
@@ -0,0 +1,224 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//

#include "stdafx.h"
#include "CNTKLibrary.h"

namespace CNTK
{
    // A collection of additional options that are applicable for all standard learners
    // (after these options are set, they retain their value for the entire lifespan of a learner).
    struct AdditionalLearningOptions
    {
        double l1RegularizationWeight = 0.0;
        double l2RegularizationWeight = 0.0;
        double gaussianNoiseInjectionStdDev = 0.0;
        bool gradientClippingWithTruncation = false;
        double gradientClippingThresholdPerSample = 0.0;
        std::unordered_map<Variable, double> learningRateMultipliers;
    };

    // An abstract base class at the root of the standard learners hierarchy.
    // It implements most of the learner functionality, except for the actual update function,
    // and adds a few pre-/postprocessing methods (which are invoked before and after the update).
    class LearnerBase : public Learner
    {
    public:
        CNTK_API virtual bool Update(const std::unordered_map<Variable, ValuePtr>& parameterValues,
                                     const std::unordered_map<Variable, const ValuePtr>& gradientValues,
                                     size_t trainingSampleCount) override final;

        CNTK_API virtual Dictionary GetCheckpointState() const override;

        CNTK_API virtual void RestoreFromCheckpoint(const Dictionary& checkpoint) override;

        CNTK_API void SetAdditionalOptions(const AdditionalLearningOptions& additionalOptions)
        {
            m_additionalOptions = additionalOptions;
        }

        // TODO: should this be called ResetMomentum?
        // Needed by BlockMomentumSGD to reset the SGD momentum after aggregation.
        CNTK_API void ResetSmoothedGradients();

        // TODO: move learning rate and momentum scheduling and adjustment functionality
        // inside the learner and drop these setters.
        void SetLearningRate(double value) { m_learningRatePerSample = value; }

    protected:
        LearnerBase(const std::unordered_set<Variable>& parameters,
                    const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice());

        virtual void Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
                            const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const = 0;

        double ParameterDependentLearningRate(const Variable& parameter) const
        {
            return m_learningRatePerSample * m_additionalOptions.learningRateMultipliers.at(parameter);
        }

        std::string LearnerType() const;

        double m_learningRatePerSample;

        AdditionalLearningOptions m_additionalOptions;

        std::unordered_map<Variable, ValuePtr> m_smoothedGradientValues;

        // The following four static protected methods expose private methods of the NDArrayView class
        // (which declares LearnerBase as a friend class), so that they are available to subclasses.
        template <typename ElementType>
        static std::shared_ptr<const Microsoft::MSR::CNTK::Matrix<ElementType>> GetMatrix(const NDArrayViewPtr arrayView);

        template <typename ElementType>
        static std::shared_ptr<Microsoft::MSR::CNTK::Matrix<ElementType>> GetWritableMatrix(NDArrayViewPtr arrayView);

        template <typename ElementType>
        static const Microsoft::MSR::CNTK::TensorView<ElementType>* GetTensorView(const NDArrayViewPtr arrayView);

        template <typename ElementType>
        static Microsoft::MSR::CNTK::TensorView<ElementType>* GetWritableTensorView(NDArrayViewPtr arrayView);

        template <typename ElementType>
        void ClipGradient(Microsoft::MSR::CNTK::Matrix<ElementType>& gradient, size_t actualMBSize) const;

        // Performs additional preprocessing before calling the update method
        // (gradient clipping and L2 regularization depending on the additional learning parameters).
        template <typename ElementType>
        void PreProcess(const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t actualMBSize) const;

        // Performs additional postprocessing after the update method has been executed
        // (noise injection and L1 regularization specified by the additional learning parameters).
        template <typename ElementType>
        void PostProcess(const Variable& parameter, const ValuePtr& gradientValue,
                         const ValuePtr& parameterValue, size_t actualMBSize) const;

    private:
        // Templatized update function; it invokes preprocess and postprocess using the provided
        // template parameter and also invokes the virtual Update method implemented in one of the subclasses.
        template <typename ElementType>
        void Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
                    const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const;

        // TODO: make these functions friends of NDArrayView and move them to Utils?
        static bool HasNan(const ValuePtr& value, const char* name);
        static void Print(const ValuePtr& value, const char* msg);

        size_t m_sampleCount;
    };

    // Vanilla gradient descent optimization algorithm.
    class LearnerSGD : public LearnerBase
    {
    public:
        LearnerSGD(const std::unordered_set<Variable>& parameters,
                   const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice())
            : LearnerBase(parameters, device),
            m_momentumPerSample(0.0),
            m_useNesterovAcceleration(false)
        {
        }

    protected:
        virtual void Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
                            const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const override;

        template <typename ElementType>
        void Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
                    const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const;

        double m_momentumPerSample;
        bool m_useNesterovAcceleration;
    };

    // SGD optimization with momentum.
    class LearnerMomentumSGD : public LearnerSGD
    {
    public:
        LearnerMomentumSGD(const std::unordered_set<Variable>& parameters,
                           const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice())
            : LearnerSGD(parameters, device)
        {
        }

        void SetMomentum(double value) { m_momentumPerSample = value; }
    };

    // Nesterov's accelerated gradient descent.
    class LearnerNesterov : public LearnerSGD
    {
    public:
        LearnerNesterov(const std::unordered_set<Variable>& parameters,
                        const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice())
            : LearnerSGD(parameters, device)
        {
            m_useNesterovAcceleration = true;
        }
    };

    class LearnerAdaGrad : public LearnerBase
    {
    public:
        LearnerAdaGrad(const std::unordered_set<Variable>& parameters, bool needAveMultiplier,
                       const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice());

    protected:
        bool m_needAveMultiplier;

        virtual void Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
                            const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const override;

        template <typename ElementType>
        void Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
                    const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const;
    };

    class LearnerFSAdaGrad : public LearnerMomentumSGD
    {
    public:
        LearnerFSAdaGrad(const std::unordered_set<Variable>& parameters,
                         const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice());

    protected:
        virtual void Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
                            const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const override;

        template <typename ElementType>
        void Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
                    const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const;
    };

    class LearnerRMSProp : public LearnerBase
    {
    public:
        LearnerRMSProp(const std::unordered_set<Variable>& parameters,
                       double gamma, double inc, double dec, double max, double min, bool needAveMultiplier,
                       const DeviceDescriptor& device = DeviceDescriptor::DefaultDevice());

    protected:
        double m_gamma;
        double m_inc;
        double m_dec;
        double m_max;
        double m_min;
        bool m_needAveMultiplier;

        virtual void Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
                            const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const override;

        template <typename ElementType>
        void Update(const Variable& parameter, const ValuePtr& smoothedGradientValue,
                    const ValuePtr& gradientValue, const ValuePtr& parameterValue, size_t trainingSampleCount) const;
    };
}
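A minimal caller-side sketch of the new public surface (illustrative, not part of the commit; how the parameter/gradient Value maps are produced by a Function's backward pass is outside this diff):

#include "CNTKLibrary.h"

void TrainOneMinibatch(const std::unordered_set<CNTK::Variable>& parameters,
                       const std::unordered_map<CNTK::Variable, CNTK::ValuePtr>& parameterValues,
                       const std::unordered_map<CNTK::Variable, const CNTK::ValuePtr>& gradientValues,
                       size_t minibatchSize)
{
    // One learner owns the full parameter set; per-learner partitioning is also possible.
    CNTK::LearnerPtr learner = CNTK::SGDLearner(parameters);

    // NB: m_learningRatePerSample defaults to 0, so a real caller would first set a
    // non-zero rate via LearnerBase::SetLearningRate (the factory returns the base LearnerPtr).

    // One update step per minibatch; a false return signals that learning
    // has stopped for this learner's parameters.
    bool keepLearning = learner->Update(parameterValues, gradientValues, minibatchSize);
    (void)keepLearning;
}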
CNTKv2LibraryDll/NDArrayView.cpp
@@ -338,8 +338,10 @@ namespace CNTK
     template std::shared_ptr<const Matrix<float>> NDArrayView::GetMatrix(size_t rowColSplitPoint/* = AutoSelectRowColSplitPoint*/) const;
     template std::shared_ptr<const Matrix<double>> NDArrayView::GetMatrix(size_t rowColSplitPoint/* = AutoSelectRowColSplitPoint*/) const;
 
-    template std::shared_ptr<Matrix<float>> NDArrayView::GetWritableMatrix(size_t rowColSplitPoint/* = AutoSelectRowColSplitPoint*/);
-    template std::shared_ptr<Matrix<double>> NDArrayView::GetWritableMatrix(size_t rowColSplitPoint/* = AutoSelectRowColSplitPoint*/);
+    template std::shared_ptr<Matrix<float>> NDArrayView::GetWritableMatrix<float>(size_t rowColSplitPoint/* = AutoSelectRowColSplitPoint*/);
+    template std::shared_ptr<Matrix<double>> NDArrayView::GetWritableMatrix<double>(size_t rowColSplitPoint/* = AutoSelectRowColSplitPoint*/);
+    template TensorView<float>* NDArrayView::GetWritableTensorView<float>();
+    template TensorView<double>* NDArrayView::GetWritableTensorView<double>();
 
     template CNTK_API NDArrayView::NDArrayView(const NDShape& viewShape, const SparseIndexType* colStarts, const SparseIndexType* rowIndices, const float* nonZeroValues, size_t numNonZeroValues, const DeviceDescriptor& device, bool readOnly/* = false*/);
     template CNTK_API NDArrayView::NDArrayView(const NDShape& viewShape, const SparseIndexType* colStarts, const SparseIndexType* rowIndices, const double* nonZeroValues, size_t numNonZeroValues, const DeviceDescriptor& device, bool readOnly/* = false*/);
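These explicit instantiations are needed because the member templates are defined in the .cpp file rather than in the header; without them, client code (here, the new LearnerBase) would fail to link against the float/double specializations. A generic illustration of the pattern, with made-up names:

// widget.cpp -- the template body is not visible to other translation units,
// so each specialization used elsewhere must be instantiated explicitly.
template <typename T>
T Twice(T value) { return value + value; }

template int Twice<int>(int);          // emits Twice<int> into this object file
template double Twice<double>(double); // likewise for double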
CNTKv2LibraryDll/Utils.cpp
@@ -6,11 +6,138 @@
#include "stdafx.h"
#include "CNTKLibrary.h"
#include "Utils.h"
#include "File.h"

using namespace std;

namespace CNTK
{
    template <typename T>
    void DictionaryValue::AllocateDataPtr(const T& value)
    {
        static_assert(is_same<T, NDShape>::value || is_same<T, vector<DictionaryValue>>::value, "AllocateDataPtr called with invalid type");
        m_data.m_ptr = new T(value);
    }

    template <typename T>
    void DictionaryValue::FreePtrAsType()
    {
        T* typedPtr = reinterpret_cast<T*>(m_data.m_ptr);
        delete typedPtr;

        m_data.m_ptr = nullptr;
    }

    void DictionaryValue::FreeDataPtr()
    {
        if (m_valueType == Type::NDShape)
            FreePtrAsType<NDShape>();
        else if (m_valueType == Type::Vector)
            FreePtrAsType<vector<DictionaryValue>>();
    }

    Microsoft::MSR::CNTK::File& operator>>(Microsoft::MSR::CNTK::File& stream, DictionaryValue& us)
    {
        size_t version;
        stream >> version;

        stream >> us.m_valueType;

        switch (us.ValueType())
        {
        case DictionaryValue::Type::Bool:
            stream >> us.m_data.m_boolean;
            break;
        case DictionaryValue::Type::SizeT:
            stream >> us.m_data.m_sizeT;
            break;
        case DictionaryValue::Type::Float:
            stream >> us.m_data.m_float;
            break;
        case DictionaryValue::Type::Double:
            stream >> us.m_data.m_double;
            break;
        case DictionaryValue::Type::NDShape:
        {
            size_t size;
            stream >> size;
            vector<size_t> dims(size);
            for (auto i = 0; i < size; i++)
            {
                stream >> dims[i];
            }
            us.AllocateDataPtr(NDShape(dims));
            break;
        }
        case DictionaryValue::Type::Vector:
        {
            size_t size;
            stream >> size;
            vector<DictionaryValue> values(size);
            for (auto i = 0; i < size; i++)
            {
                stream >> values[i];
            }
            us.AllocateDataPtr(values);
            break;
        }
        default:
            NOT_IMPLEMENTED;
        }
        return stream;
    }

    Microsoft::MSR::CNTK::File& operator<<(Microsoft::MSR::CNTK::File& stream, const DictionaryValue& us)
    {
        stream << us.version;

        stream << us.ValueType();

        switch (us.ValueType())
        {
        case DictionaryValue::Type::Bool:
            stream << us.m_data.m_boolean;
            break;
        case DictionaryValue::Type::SizeT:
            stream << us.m_data.m_sizeT;
            break;
        case DictionaryValue::Type::Float:
            stream << us.m_data.m_float;
            break;
        case DictionaryValue::Type::Double:
            stream << us.m_data.m_double;
            break;
        case DictionaryValue::Type::NDShape:
        {
            NDShape* shapePtr = reinterpret_cast<NDShape*>(us.m_data.m_ptr);
            auto size = shapePtr->NumAxes();
            stream << size;
            for (auto i = 0; i < size; i++)
            {
                stream << shapePtr->operator[](i);
            }
            break;
        }
        case DictionaryValue::Type::Vector:
        {
            vector<DictionaryValue>* vectorPtr =
                reinterpret_cast<vector<DictionaryValue>*>(us.m_data.m_ptr);
            auto size = vectorPtr->size();
            stream << size;
            for (auto i = 0; i < size; i++)
            {
                stream << vectorPtr->operator[](i);
            }
            break;
        }
        default:
            NOT_IMPLEMENTED;
        }
        return stream;
    }

    Dictionary::Dictionary()
-        : m_dictionaryData(new std::unordered_map < std::wstring, DictionaryValue>)
+        : m_dictionaryData(new unordered_map<wstring, DictionaryValue>)
    {
    }
@@ -22,7 +149,7 @@ namespace CNTK
     Dictionary::Dictionary(Dictionary&& other)
         : m_dictionaryData(nullptr)
     {
-        *this = std::move(other);
+        *this = move(other);
     }
 
     Dictionary& Dictionary::operator=(Dictionary&& other)
@@ -51,4 +178,130 @@ namespace CNTK
    {
        return (m_dictionaryData->find(key) != m_dictionaryData->end());
    }

    Microsoft::MSR::CNTK::File& operator<<(Microsoft::MSR::CNTK::File& stream, const Dictionary& us)
    {
        stream << us.version;
        stream << us.m_dictionaryData->size();
        for (auto it = us.m_dictionaryData->begin(); it != us.m_dictionaryData->end(); ++it)
        {
            stream << it->first;
            stream << it->second;
        }
        return stream;
    }

    Microsoft::MSR::CNTK::File& operator>>(Microsoft::MSR::CNTK::File& stream, Dictionary& us)
    {
        size_t version;
        stream >> version;
        size_t size;
        stream >> size;
        us.m_dictionaryData->reserve(size);
        for (auto i = 0; i < size; i++)
        {
            wstring key;
            stream >> key;
            DictionaryValue value;
            stream >> value;
            us.m_dictionaryData->insert(make_pair(key, value));
        }
        return stream;
    }

    template <typename T>
    vector<DictionaryValue> SerializeToVector(const NDArrayViewPtr& viewPtr)
    {
        if (viewPtr->IsSparse())
        {
            LogicError("Sparse NDArrayView cannot be serialized into a vector.");
        }

        auto numElements = viewPtr->Shape().TotalSize();

        vector<DictionaryValue> values(numElements);

        NDArrayViewPtr cpuDataViewPtr = viewPtr;
        if ((viewPtr->Device().Type() != DeviceKind::CPU))
        {
            cpuDataViewPtr = MakeSharedObject<NDArrayView>(viewPtr->GetDataType(), viewPtr->Shape(), DeviceDescriptor::CPUDevice());
            cpuDataViewPtr->CopyFrom(*viewPtr);
        }

        const T* buffer = cpuDataViewPtr->DataBuffer<T>();
        for (auto i = 0; i < numElements; ++i)
        {
            T v = buffer[i];
            values[i] = DictionaryValue(v);
        }

        return values;
    }

    template <typename T>
    void DeserializeFromVector(const NDArrayViewPtr& viewPtr, const vector<DictionaryValue>& values)
    {
        if (viewPtr->IsSparse())
        {
            LogicError("Sparse NDArrayView cannot be deserialized from a vector.");
        }

        auto numElements = viewPtr->Shape().TotalSize();

        if (values.size() != numElements)
        {
            LogicError("Number of elements (%lu) in the deserialized representation does not match the expected value (%lu)",
                       values.size(), numElements);
        }

        NDArrayViewPtr cpuDataViewPtr = viewPtr;
        if ((viewPtr->Device().Type() != DeviceKind::CPU))
        {
            cpuDataViewPtr = MakeSharedObject<NDArrayView>(viewPtr->GetDataType(), viewPtr->Shape(), DeviceDescriptor::CPUDevice());
        }

        T* buffer = cpuDataViewPtr->WritableDataBuffer<T>();
        for (auto i = 0; i < numElements; ++i)
        {
            buffer[i] = values[i].GetValue<T>();
        }

        if ((viewPtr->Device().Type() != DeviceKind::CPU))
        {
            viewPtr->CopyFrom(*cpuDataViewPtr);
        }
    }

    // TODO: we store the type info for every element in the vector, which is extremely redundant.
    // Instead, it'd be nice to introduce some sort of DictionaryValueVector.
    vector<DictionaryValue> SerializeToVector(const NDArrayViewPtr& viewPtr)
    {
        switch (viewPtr->GetDataType())
        {
        case DataType::Float:
            return SerializeToVector<float>(viewPtr);
        case DataType::Double:
            return SerializeToVector<double>(viewPtr);
        default:
            LogicError("Unsupported DataType %s", DataTypeName(viewPtr->GetDataType()));
        }
    }

    void DeserializeFromVector(const NDArrayViewPtr& viewPtr, const vector<DictionaryValue>& values)
    {
        switch (viewPtr->GetDataType())
        {
        case DataType::Float:
            DeserializeFromVector<float>(viewPtr, values);
            break;
        case DataType::Double:
            DeserializeFromVector<double>(viewPtr, values);
            break;
        default:
            LogicError("Unsupported DataType %s", DataTypeName(viewPtr->GetDataType()));
        }
    }

    template void DictionaryValue::AllocateDataPtr<NDShape>(const NDShape& value);
    template void DictionaryValue::AllocateDataPtr<vector<DictionaryValue>>(const vector<DictionaryValue>& value);
}
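A hedged round-trip sketch built only on the helpers above (the two NDArrayViews are assumed to be dense, of the same shape and element type, and created elsewhere; the key name is made up):

#include "CNTKLibrary.h"
#include "Utils.h"

void CheckpointRoundTripSketch(const CNTK::NDArrayViewPtr& source, const CNTK::NDArrayViewPtr& target)
{
    using namespace CNTK;

    // One DictionaryValue per element -- redundant, as the TODO above notes.
    std::vector<DictionaryValue> blob = SerializeToVector(source);

    // Keyed storage, mirroring what LearnerBase::GetCheckpointState does per parameter.
    Dictionary checkpoint;
    checkpoint[L"smoothedGradient"] = DictionaryValue(blob);

    // Restore into another view; element count and data type must match.
    DeserializeFromVector(target, checkpoint[L"smoothedGradient"].GetValue<std::vector<DictionaryValue>>());
}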
CNTKv2LibraryDll/Utils.h
@@ -15,245 +15,6 @@ namespace CNTK
    // Forward declarations
    class Dictionary;

    // (the following block is removed by this commit; the definitions move, extended, into API/CNTKLibrary.h)

    class DictionaryValue
    {
    public:
        enum class Type : unsigned int
        {
            None,
            Bool,
            SizeT,
            Double,
            NDShape,
            Vector
        };

        static const char* TypeName(Type type)
        {
            if (type == Type::None)
                return "None";
            else if (type == Type::Bool)
                return "Bool";
            else if (type == Type::SizeT)
                return "SizeT";
            else if (type == Type::Double)
                return "Double";
            else if (type == Type::NDShape)
                return "NDShape";
            else if (type == Type::Vector)
                return "Vector";
            else
                LogicError("Unknown DictionaryValue::Type");
        }

    public:
        DictionaryValue()
            : m_valueType(Type::None)
        {
        }

        DictionaryValue(bool value)
            : m_valueType(GetValueType<bool>())
        {
            m_data.m_boolean = value;
        }

        DictionaryValue(size_t value)
            : m_valueType(GetValueType<size_t>())
        {
            m_data.m_sizeT = value;
        }

        DictionaryValue(double value)
            : m_valueType(GetValueType<double>())
        {
            m_data.m_double = value;
        }

        template <typename T>
        DictionaryValue(const T& value)
            : m_valueType(GetValueType<T>())
        {
            static_assert(std::is_same<T, NDShape>::value ||
                          std::is_same<T, std::vector<DictionaryValue>>::value,
                          "Unsupported ValueType");

            AllocateDataPtr(value);
        }

        DictionaryValue(const DictionaryValue& other)
            : m_valueType(Type::Bool)
        {
            // m_valueType must have been set to a non-ptr type to prevent an attempt to interpret
            // the underlying uninitialized value as a ptr and free it.
            *this = other;
        }

        DictionaryValue& operator=(const DictionaryValue& other)
        {
            if (this != &other)
            {
                FreeDataPtr();

                m_valueType = other.m_valueType;
                m_data = other.m_data;

                if (other.m_valueType == Type::NDShape)
                    AllocateDataPtr(other.GetValue<NDShape>());
                else if (other.m_valueType == Type::Vector)
                    AllocateDataPtr(other.GetValue<std::vector<DictionaryValue>>());
            }

            return *this;
        }

        ~DictionaryValue()
        {
            FreeDataPtr();
        }

        template <typename T, typename std::enable_if<std::is_same<T, bool>::value>::type* = nullptr>
        const T& GetValue() const
        {
            VerifyType<T>();
            return m_data.m_boolean;
        }

        template <typename T, typename std::enable_if<std::is_same<T, size_t>::value>::type* = nullptr>
        const T& GetValue() const
        {
            VerifyType<T>();
            return m_data.m_sizeT;
        }

        template <typename T, typename std::enable_if<std::is_same<T, double>::value>::type* = nullptr>
        const T& GetValue() const
        {
            VerifyType<T>();
            return m_data.m_double;
        }

        template <typename T, typename std::enable_if<std::is_same<T, NDShape>::value || std::is_same<T, std::vector<DictionaryValue>>::value>::type* = nullptr>
        const T& GetValue() const
        {
            VerifyType<T>();
            return *(reinterpret_cast<T*>(m_data.m_ptr));
        }

        bool HasValue() const
        {
            return m_valueType != Type::None;
        }

        Type ValueType() const
        {
            return m_valueType;
        }

    private:
        template <typename T>
        static Type GetValueType()
        {
            static_assert(std::is_same<T, bool>::value ||
                          std::is_same<T, size_t>::value ||
                          std::is_same<T, double>::value ||
                          std::is_same<T, NDShape>::value ||
                          std::is_same<T, std::vector<DictionaryValue>>::value ||
                          std::is_same<T, CNTK::Dictionary>::value,
                          "Unsupported ValueType");

            if (std::is_same<T, bool>::value)
                return Type::Bool;
            else if (std::is_same<T, size_t>::value)
                return Type::SizeT;
            else if (std::is_same<T, double>::value)
                return Type::Double;
            else if (std::is_same<T, NDShape>::value)
                return Type::NDShape;
            else if (std::is_same<T, std::vector<DictionaryValue>>::value)
                return Type::Vector;
        }

        template <typename T>
        void VerifyType() const
        {
            if (GetValueType<T>() != m_valueType)
                RuntimeError("Reading a DictionaryValue as the wrong type; reading as type %s when the actual type is %s", typeid(T).name(), DictionaryValue::TypeName(m_valueType));
        }

        template <typename T>
        void AllocateDataPtr(const T& value)
        {
            static_assert(std::is_same<T, NDShape>::value || std::is_same<T, std::vector<DictionaryValue>>::value, "AllocateDataPtr called with invalid type");
            m_data.m_ptr = new T(value);
        }

        template <typename T>
        void FreePtrAsType()
        {
            T* typedPtr = reinterpret_cast<T*>(m_data.m_ptr);
            delete typedPtr;

            m_data.m_ptr = nullptr;
        }

        void FreeDataPtr()
        {
            if (m_valueType == Type::NDShape)
                FreePtrAsType<NDShape>();
            else if (m_valueType == Type::Vector)
                FreePtrAsType<std::vector<DictionaryValue>>();
        }

    private:
        Type m_valueType;

        union ValueData
        {
            bool m_boolean;
            size_t m_sizeT;
            double m_double;
            void* m_ptr;
        } m_data;
    };

    class Dictionary
    {
    public:
        Dictionary();
        ~Dictionary();

        // Disallow copy construction and assignment
        Dictionary(const Dictionary&) = delete; Dictionary& operator=(const Dictionary&) = delete;

        Dictionary(Dictionary&& other);
        Dictionary& operator=(Dictionary&& other);

        DictionaryValue& operator[](const std::wstring& key)
        {
            return operator[](key.c_str());
        }

        DictionaryValue& operator[](const wchar_t* key);

        DictionaryValue operator[](const std::wstring& key) const
        {
            return operator[](key.c_str());
        }

        DictionaryValue operator[](const wchar_t* key) const;

        bool Contains(const std::wstring& key) const
        {
            return Contains(key.c_str());
        }

        bool Contains(const wchar_t* key) const;

    private:
        std::unordered_map<std::wstring, DictionaryValue>* m_dictionaryData;
    };

    // (end of removed block)

    // Helper to get the size of an element of the specified DataType
    inline size_t ElementSize(DataType dataType)
    {
@@ -363,4 +124,8 @@ namespace CNTK
     {
         return var.IsInput() && var.IsSparse();
     }
+
+    std::vector<DictionaryValue> SerializeToVector(const NDArrayViewPtr& viewPtr);
+
+    void DeserializeFromVector(const NDArrayViewPtr& viewPtr, const std::vector<DictionaryValue>& values);
 }