Integrate zhouwang/pr1441 into master
This commit is contained in:
Коммит
9b877f11e0
|
@ -1744,7 +1744,7 @@ private:
|
|||
|
||||
CNTK_API static Variable Deserialize(const Dictionary& dictionary, const ::CNTK::DeviceDescriptor& device = DeviceDescriptor::UseDefaultDevice());
|
||||
|
||||
void SetOwner(Function* ownerFunction);
|
||||
void SetOwner(const std::weak_ptr<Function>& ownerFunction);
|
||||
|
||||
Variable CompositePreservingCopy(const std::shared_ptr<const Function>& composite) const;
|
||||
|
||||
|
@ -5282,4 +5282,4 @@ namespace std
|
|||
return std::hash<size_t>()(x.m_globalRank);
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
|
|
@ -22,9 +22,9 @@ namespace CNTK
|
|||
for (auto outputVar : outputs)
|
||||
{
|
||||
if (outputVar.IsOutput() && !outputVar.Owner())
|
||||
outputVar.SetOwner(this);
|
||||
outputVar.SetOwner(shared_from_this());
|
||||
|
||||
if (m_rootFunction == nullptr && outputVar.IsOutput() && outputVar.m_dataFields->m_ownerFunction == this)
|
||||
if (m_rootFunction == nullptr && outputVar.IsOutput() && outputVar.Owner().get() == this)
|
||||
{
|
||||
// in case of a primitive function, set uid of output vars to owner function uid + "_Output_" + output index.
|
||||
outputVar.m_dataFields->m_uid = m_uid + L"_" + VariableKindName(outputVar.Kind()) + L"_" + std::to_wstring(m_outputs.size());
|
||||
|
|
|
@ -133,9 +133,6 @@ namespace CNTK
|
|||
return !(*this == other);
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
Dictionary::Dictionary()
|
||||
: m_dictionaryData(new unordered_map <wstring, DictionaryValue>)
|
||||
{
|
||||
|
@ -425,7 +422,6 @@ namespace CNTK
|
|||
return fd;
|
||||
}
|
||||
|
||||
|
||||
std::string ToString(const std::wstring& wstring)
|
||||
{
|
||||
#ifdef _MSC_VER
|
||||
|
@ -518,6 +514,7 @@ namespace CNTK
|
|||
|
||||
return std::pair<size_t, size_t>(maxNumTimeSteps, numSequences);
|
||||
}
|
||||
|
||||
/*static*/ void Utils::VerifyVariableValueCompatibility(const Variable& var, const ValuePtr& value)
|
||||
{
|
||||
if (var.GetDataType() != value->GetDataType())
|
||||
|
|
|
@ -30,6 +30,15 @@ namespace CNTK
|
|||
NOT_IMPLEMENTED;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline bool IsObjectExpired(std::weak_ptr<T> ptrToObject)
|
||||
{
|
||||
if ((ptrToObject.owner_before(std::weak_ptr<T>{}) || std::weak_ptr<T>{}.owner_before(ptrToObject)) && ptrToObject.expired())
|
||||
return true;
|
||||
else
|
||||
return false;
|
||||
}
|
||||
|
||||
inline DEVICEID_TYPE AsCNTKImplDeviceId(const DeviceDescriptor& device)
|
||||
{
|
||||
if (device.Type() == DeviceKind::CPU)
|
||||
|
|
|
@ -72,10 +72,7 @@ namespace CNTK
|
|||
|
||||
FunctionPtr Variable::Owner() const
|
||||
{
|
||||
if (m_dataFields->m_ownerFunction != nullptr)
|
||||
return m_dataFields->m_ownerFunction->shared_from_this();
|
||||
else
|
||||
return nullptr;
|
||||
return m_dataFields->Owner();
|
||||
}
|
||||
|
||||
Variable Variable::CompositePreservingCopy(const std::shared_ptr<const Function>& composite) const
|
||||
|
@ -94,12 +91,12 @@ namespace CNTK
|
|||
return copy;
|
||||
}
|
||||
|
||||
void Variable::SetOwner(Function* ownerFunction)
|
||||
void Variable::SetOwner(const std::weak_ptr<Function>& ownerFunction)
|
||||
{
|
||||
if (Kind() != VariableKind::Output)
|
||||
LogicError("Variable '%S' SetOwner: Owner can only be set for Output Variables", AsString().c_str());
|
||||
|
||||
if (m_dataFields->m_ownerFunction != nullptr)
|
||||
if (Owner() != nullptr)
|
||||
LogicError("Variable '%S' SetOwner: An Output Variable whose owner has previously been set, cannot be reset.", AsString().c_str());
|
||||
|
||||
m_dataFields->m_ownerFunction = ownerFunction;
|
||||
|
@ -213,13 +210,24 @@ namespace CNTK
|
|||
return wss.str();
|
||||
}
|
||||
|
||||
FunctionPtr VariableFields::Owner() const
|
||||
{
|
||||
if (IsObjectExpired(m_ownerFunction))
|
||||
LogicError("The owner function of Variable '%S' is unexpectedly expired.", AsString().c_str());
|
||||
|
||||
auto ownerFunctionPtr = m_ownerFunction.lock();
|
||||
if (ownerFunctionPtr != nullptr)
|
||||
return ownerFunctionPtr->shared_from_this();
|
||||
else
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::shared_ptr<VariableFields> VariableFields::Clone() const
|
||||
{
|
||||
if (m_ownerFunction != nullptr)
|
||||
if (Owner() != nullptr)
|
||||
InvalidArgument("Output variable '%S' cannot be cloned.", AsString().c_str());
|
||||
|
||||
// Note: We do not clone m_blockFunctionVariableMapping
|
||||
|
||||
auto clone = MakeSharedObject<VariableFields>(m_shape,
|
||||
m_varKind,
|
||||
m_dataType,
|
||||
|
@ -364,7 +372,7 @@ namespace CNTK
|
|||
}
|
||||
|
||||
Variable::Variable(const NDShape& shape, VariableKind varType, CNTK::DataType dataType, const NDArrayViewPtr& value, bool needsGradient, const std::vector<Axis>& dynamicAxes, bool isSparse, const std::wstring& name, const std::wstring& uid)
|
||||
: m_dataFields(MakeSharedObject<VariableFields>(shape, varType, dataType, nullptr, value, needsGradient, dynamicAxes, isSparse, name, uid))
|
||||
: m_dataFields(MakeSharedObject<VariableFields>(shape, varType, dataType, std::weak_ptr<Function>(), value, needsGradient, dynamicAxes, isSparse, name, uid))
|
||||
{}
|
||||
|
||||
template <typename ElementType>
|
||||
|
|
|
@ -19,7 +19,7 @@ namespace CNTK
|
|||
NDShape m_shape;
|
||||
VariableKind m_varKind;
|
||||
::CNTK::DataType m_dataType;
|
||||
Function* m_ownerFunction; // Variable does not keep the Function alive
|
||||
std::weak_ptr<Function> m_ownerFunction;
|
||||
std::unique_ptr<std::once_flag> m_initValueFlag;
|
||||
NDArrayViewPtr m_value;
|
||||
std::unique_ptr<ParameterInitializer> m_valueInitializer;
|
||||
|
@ -32,7 +32,7 @@ namespace CNTK
|
|||
std::atomic<size_t> m_valueTimeStamp;
|
||||
Variable m_blockFunctionVariableMapping;
|
||||
|
||||
VariableFields(const NDShape& shape, VariableKind varType, ::CNTK::DataType type, Function* ownerFunction, const NDArrayViewPtr& value, bool needsGradient, const std::vector<Axis>& dynamicAxes, bool isSparse, const std::wstring& name, const std::wstring& uid)
|
||||
VariableFields(const NDShape& shape, VariableKind varType, ::CNTK::DataType type, const std::weak_ptr<Function>& ownerFunction, const NDArrayViewPtr& value, bool needsGradient, const std::vector<Axis>& dynamicAxes, bool isSparse, const std::wstring& name, const std::wstring& uid)
|
||||
: m_shape(shape), m_varKind(varType), m_dataType(type), m_ownerFunction(ownerFunction), m_value(value), m_needsGradient(needsGradient), m_dynamicAxes(dynamicAxes), m_isSparse(isSparse), m_name(name), m_uid(uid), m_valueTimeStamp(0)
|
||||
{
|
||||
if (value && (type != value->GetDataType()))
|
||||
|
@ -59,6 +59,7 @@ namespace CNTK
|
|||
|
||||
std::wstring AsString() const;
|
||||
std::shared_ptr<VariableFields> Clone() const;
|
||||
FunctionPtr Owner() const;
|
||||
|
||||
CNTK_API void SetValueInitialization(const ParameterInitializer& initializationConfig, const DeviceDescriptor& device);
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче