Integrate zhouwang/pr1441 into master
This commit is contained in:
Коммит
9b877f11e0
|
@ -1744,7 +1744,7 @@ private:
|
||||||
|
|
||||||
CNTK_API static Variable Deserialize(const Dictionary& dictionary, const ::CNTK::DeviceDescriptor& device = DeviceDescriptor::UseDefaultDevice());
|
CNTK_API static Variable Deserialize(const Dictionary& dictionary, const ::CNTK::DeviceDescriptor& device = DeviceDescriptor::UseDefaultDevice());
|
||||||
|
|
||||||
void SetOwner(Function* ownerFunction);
|
void SetOwner(const std::weak_ptr<Function>& ownerFunction);
|
||||||
|
|
||||||
Variable CompositePreservingCopy(const std::shared_ptr<const Function>& composite) const;
|
Variable CompositePreservingCopy(const std::shared_ptr<const Function>& composite) const;
|
||||||
|
|
||||||
|
|
|
@ -22,9 +22,9 @@ namespace CNTK
|
||||||
for (auto outputVar : outputs)
|
for (auto outputVar : outputs)
|
||||||
{
|
{
|
||||||
if (outputVar.IsOutput() && !outputVar.Owner())
|
if (outputVar.IsOutput() && !outputVar.Owner())
|
||||||
outputVar.SetOwner(this);
|
outputVar.SetOwner(shared_from_this());
|
||||||
|
|
||||||
if (m_rootFunction == nullptr && outputVar.IsOutput() && outputVar.m_dataFields->m_ownerFunction == this)
|
if (m_rootFunction == nullptr && outputVar.IsOutput() && outputVar.Owner().get() == this)
|
||||||
{
|
{
|
||||||
// in case of a primitive function, set uid of output vars to owner function uid + "_Output_" + output index.
|
// in case of a primitive function, set uid of output vars to owner function uid + "_Output_" + output index.
|
||||||
outputVar.m_dataFields->m_uid = m_uid + L"_" + VariableKindName(outputVar.Kind()) + L"_" + std::to_wstring(m_outputs.size());
|
outputVar.m_dataFields->m_uid = m_uid + L"_" + VariableKindName(outputVar.Kind()) + L"_" + std::to_wstring(m_outputs.size());
|
||||||
|
|
|
@ -133,9 +133,6 @@ namespace CNTK
|
||||||
return !(*this == other);
|
return !(*this == other);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Dictionary::Dictionary()
|
Dictionary::Dictionary()
|
||||||
: m_dictionaryData(new unordered_map <wstring, DictionaryValue>)
|
: m_dictionaryData(new unordered_map <wstring, DictionaryValue>)
|
||||||
{
|
{
|
||||||
|
@ -425,7 +422,6 @@ namespace CNTK
|
||||||
return fd;
|
return fd;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
std::string ToString(const std::wstring& wstring)
|
std::string ToString(const std::wstring& wstring)
|
||||||
{
|
{
|
||||||
#ifdef _MSC_VER
|
#ifdef _MSC_VER
|
||||||
|
@ -518,6 +514,7 @@ namespace CNTK
|
||||||
|
|
||||||
return std::pair<size_t, size_t>(maxNumTimeSteps, numSequences);
|
return std::pair<size_t, size_t>(maxNumTimeSteps, numSequences);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*static*/ void Utils::VerifyVariableValueCompatibility(const Variable& var, const ValuePtr& value)
|
/*static*/ void Utils::VerifyVariableValueCompatibility(const Variable& var, const ValuePtr& value)
|
||||||
{
|
{
|
||||||
if (var.GetDataType() != value->GetDataType())
|
if (var.GetDataType() != value->GetDataType())
|
||||||
|
|
|
@ -30,6 +30,15 @@ namespace CNTK
|
||||||
NOT_IMPLEMENTED;
|
NOT_IMPLEMENTED;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline bool IsObjectExpired(std::weak_ptr<T> ptrToObject)
|
||||||
|
{
|
||||||
|
if ((ptrToObject.owner_before(std::weak_ptr<T>{}) || std::weak_ptr<T>{}.owner_before(ptrToObject)) && ptrToObject.expired())
|
||||||
|
return true;
|
||||||
|
else
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
inline DEVICEID_TYPE AsCNTKImplDeviceId(const DeviceDescriptor& device)
|
inline DEVICEID_TYPE AsCNTKImplDeviceId(const DeviceDescriptor& device)
|
||||||
{
|
{
|
||||||
if (device.Type() == DeviceKind::CPU)
|
if (device.Type() == DeviceKind::CPU)
|
||||||
|
|
|
@ -72,10 +72,7 @@ namespace CNTK
|
||||||
|
|
||||||
FunctionPtr Variable::Owner() const
|
FunctionPtr Variable::Owner() const
|
||||||
{
|
{
|
||||||
if (m_dataFields->m_ownerFunction != nullptr)
|
return m_dataFields->Owner();
|
||||||
return m_dataFields->m_ownerFunction->shared_from_this();
|
|
||||||
else
|
|
||||||
return nullptr;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Variable Variable::CompositePreservingCopy(const std::shared_ptr<const Function>& composite) const
|
Variable Variable::CompositePreservingCopy(const std::shared_ptr<const Function>& composite) const
|
||||||
|
@ -94,12 +91,12 @@ namespace CNTK
|
||||||
return copy;
|
return copy;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Variable::SetOwner(Function* ownerFunction)
|
void Variable::SetOwner(const std::weak_ptr<Function>& ownerFunction)
|
||||||
{
|
{
|
||||||
if (Kind() != VariableKind::Output)
|
if (Kind() != VariableKind::Output)
|
||||||
LogicError("Variable '%S' SetOwner: Owner can only be set for Output Variables", AsString().c_str());
|
LogicError("Variable '%S' SetOwner: Owner can only be set for Output Variables", AsString().c_str());
|
||||||
|
|
||||||
if (m_dataFields->m_ownerFunction != nullptr)
|
if (Owner() != nullptr)
|
||||||
LogicError("Variable '%S' SetOwner: An Output Variable whose owner has previously been set, cannot be reset.", AsString().c_str());
|
LogicError("Variable '%S' SetOwner: An Output Variable whose owner has previously been set, cannot be reset.", AsString().c_str());
|
||||||
|
|
||||||
m_dataFields->m_ownerFunction = ownerFunction;
|
m_dataFields->m_ownerFunction = ownerFunction;
|
||||||
|
@ -213,13 +210,24 @@ namespace CNTK
|
||||||
return wss.str();
|
return wss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
FunctionPtr VariableFields::Owner() const
|
||||||
|
{
|
||||||
|
if (IsObjectExpired(m_ownerFunction))
|
||||||
|
LogicError("The owner function of Variable '%S' is unexpectedly expired.", AsString().c_str());
|
||||||
|
|
||||||
|
auto ownerFunctionPtr = m_ownerFunction.lock();
|
||||||
|
if (ownerFunctionPtr != nullptr)
|
||||||
|
return ownerFunctionPtr->shared_from_this();
|
||||||
|
else
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
std::shared_ptr<VariableFields> VariableFields::Clone() const
|
std::shared_ptr<VariableFields> VariableFields::Clone() const
|
||||||
{
|
{
|
||||||
if (m_ownerFunction != nullptr)
|
if (Owner() != nullptr)
|
||||||
InvalidArgument("Output variable '%S' cannot be cloned.", AsString().c_str());
|
InvalidArgument("Output variable '%S' cannot be cloned.", AsString().c_str());
|
||||||
|
|
||||||
// Note: We do not clone m_blockFunctionVariableMapping
|
// Note: We do not clone m_blockFunctionVariableMapping
|
||||||
|
|
||||||
auto clone = MakeSharedObject<VariableFields>(m_shape,
|
auto clone = MakeSharedObject<VariableFields>(m_shape,
|
||||||
m_varKind,
|
m_varKind,
|
||||||
m_dataType,
|
m_dataType,
|
||||||
|
@ -364,7 +372,7 @@ namespace CNTK
|
||||||
}
|
}
|
||||||
|
|
||||||
Variable::Variable(const NDShape& shape, VariableKind varType, CNTK::DataType dataType, const NDArrayViewPtr& value, bool needsGradient, const std::vector<Axis>& dynamicAxes, bool isSparse, const std::wstring& name, const std::wstring& uid)
|
Variable::Variable(const NDShape& shape, VariableKind varType, CNTK::DataType dataType, const NDArrayViewPtr& value, bool needsGradient, const std::vector<Axis>& dynamicAxes, bool isSparse, const std::wstring& name, const std::wstring& uid)
|
||||||
: m_dataFields(MakeSharedObject<VariableFields>(shape, varType, dataType, nullptr, value, needsGradient, dynamicAxes, isSparse, name, uid))
|
: m_dataFields(MakeSharedObject<VariableFields>(shape, varType, dataType, std::weak_ptr<Function>(), value, needsGradient, dynamicAxes, isSparse, name, uid))
|
||||||
{}
|
{}
|
||||||
|
|
||||||
template <typename ElementType>
|
template <typename ElementType>
|
||||||
|
|
|
@ -19,7 +19,7 @@ namespace CNTK
|
||||||
NDShape m_shape;
|
NDShape m_shape;
|
||||||
VariableKind m_varKind;
|
VariableKind m_varKind;
|
||||||
::CNTK::DataType m_dataType;
|
::CNTK::DataType m_dataType;
|
||||||
Function* m_ownerFunction; // Variable does not keep the Function alive
|
std::weak_ptr<Function> m_ownerFunction;
|
||||||
std::unique_ptr<std::once_flag> m_initValueFlag;
|
std::unique_ptr<std::once_flag> m_initValueFlag;
|
||||||
NDArrayViewPtr m_value;
|
NDArrayViewPtr m_value;
|
||||||
std::unique_ptr<ParameterInitializer> m_valueInitializer;
|
std::unique_ptr<ParameterInitializer> m_valueInitializer;
|
||||||
|
@ -32,7 +32,7 @@ namespace CNTK
|
||||||
std::atomic<size_t> m_valueTimeStamp;
|
std::atomic<size_t> m_valueTimeStamp;
|
||||||
Variable m_blockFunctionVariableMapping;
|
Variable m_blockFunctionVariableMapping;
|
||||||
|
|
||||||
VariableFields(const NDShape& shape, VariableKind varType, ::CNTK::DataType type, Function* ownerFunction, const NDArrayViewPtr& value, bool needsGradient, const std::vector<Axis>& dynamicAxes, bool isSparse, const std::wstring& name, const std::wstring& uid)
|
VariableFields(const NDShape& shape, VariableKind varType, ::CNTK::DataType type, const std::weak_ptr<Function>& ownerFunction, const NDArrayViewPtr& value, bool needsGradient, const std::vector<Axis>& dynamicAxes, bool isSparse, const std::wstring& name, const std::wstring& uid)
|
||||||
: m_shape(shape), m_varKind(varType), m_dataType(type), m_ownerFunction(ownerFunction), m_value(value), m_needsGradient(needsGradient), m_dynamicAxes(dynamicAxes), m_isSparse(isSparse), m_name(name), m_uid(uid), m_valueTimeStamp(0)
|
: m_shape(shape), m_varKind(varType), m_dataType(type), m_ownerFunction(ownerFunction), m_value(value), m_needsGradient(needsGradient), m_dynamicAxes(dynamicAxes), m_isSparse(isSparse), m_name(name), m_uid(uid), m_valueTimeStamp(0)
|
||||||
{
|
{
|
||||||
if (value && (type != value->GetDataType()))
|
if (value && (type != value->GetDataType()))
|
||||||
|
@ -59,6 +59,7 @@ namespace CNTK
|
||||||
|
|
||||||
std::wstring AsString() const;
|
std::wstring AsString() const;
|
||||||
std::shared_ptr<VariableFields> Clone() const;
|
std::shared_ptr<VariableFields> Clone() const;
|
||||||
|
FunctionPtr Owner() const;
|
||||||
|
|
||||||
CNTK_API void SetValueInitialization(const ParameterInitializer& initializationConfig, const DeviceDescriptor& device);
|
CNTK_API void SetValueInitialization(const ParameterInitializer& initializationConfig, const DeviceDescriptor& device);
|
||||||
|
|
||||||
|
|
Загрузка…
Ссылка в новой задаче