Integrate zhouwang/pr1441 into master

This commit is contained in:
Project Philly 2017-03-18 04:31:50 -07:00
Родитель d2f57f7266 079518dbd9
Коммит 9b877f11e0
6 изменённых файлов: 34 добавлений и 19 удалений

Просмотреть файл

@ -1744,7 +1744,7 @@ private:
CNTK_API static Variable Deserialize(const Dictionary& dictionary, const ::CNTK::DeviceDescriptor& device = DeviceDescriptor::UseDefaultDevice()); CNTK_API static Variable Deserialize(const Dictionary& dictionary, const ::CNTK::DeviceDescriptor& device = DeviceDescriptor::UseDefaultDevice());
void SetOwner(Function* ownerFunction); void SetOwner(const std::weak_ptr<Function>& ownerFunction);
Variable CompositePreservingCopy(const std::shared_ptr<const Function>& composite) const; Variable CompositePreservingCopy(const std::shared_ptr<const Function>& composite) const;

Просмотреть файл

@ -22,9 +22,9 @@ namespace CNTK
for (auto outputVar : outputs) for (auto outputVar : outputs)
{ {
if (outputVar.IsOutput() && !outputVar.Owner()) if (outputVar.IsOutput() && !outputVar.Owner())
outputVar.SetOwner(this); outputVar.SetOwner(shared_from_this());
if (m_rootFunction == nullptr && outputVar.IsOutput() && outputVar.m_dataFields->m_ownerFunction == this) if (m_rootFunction == nullptr && outputVar.IsOutput() && outputVar.Owner().get() == this)
{ {
// in case of a primitive function, set uid of output vars to owner function uid + "_Output_" + output index. // in case of a primitive function, set uid of output vars to owner function uid + "_Output_" + output index.
outputVar.m_dataFields->m_uid = m_uid + L"_" + VariableKindName(outputVar.Kind()) + L"_" + std::to_wstring(m_outputs.size()); outputVar.m_dataFields->m_uid = m_uid + L"_" + VariableKindName(outputVar.Kind()) + L"_" + std::to_wstring(m_outputs.size());

Просмотреть файл

@ -133,9 +133,6 @@ namespace CNTK
return !(*this == other); return !(*this == other);
} }
Dictionary::Dictionary() Dictionary::Dictionary()
: m_dictionaryData(new unordered_map <wstring, DictionaryValue>) : m_dictionaryData(new unordered_map <wstring, DictionaryValue>)
{ {
@ -425,7 +422,6 @@ namespace CNTK
return fd; return fd;
} }
std::string ToString(const std::wstring& wstring) std::string ToString(const std::wstring& wstring)
{ {
#ifdef _MSC_VER #ifdef _MSC_VER
@ -518,6 +514,7 @@ namespace CNTK
return std::pair<size_t, size_t>(maxNumTimeSteps, numSequences); return std::pair<size_t, size_t>(maxNumTimeSteps, numSequences);
} }
/*static*/ void Utils::VerifyVariableValueCompatibility(const Variable& var, const ValuePtr& value) /*static*/ void Utils::VerifyVariableValueCompatibility(const Variable& var, const ValuePtr& value)
{ {
if (var.GetDataType() != value->GetDataType()) if (var.GetDataType() != value->GetDataType())

Просмотреть файл

@ -30,6 +30,15 @@ namespace CNTK
NOT_IMPLEMENTED; NOT_IMPLEMENTED;
} }
template <typename T>
inline bool IsObjectExpired(std::weak_ptr<T> ptrToObject)
{
if ((ptrToObject.owner_before(std::weak_ptr<T>{}) || std::weak_ptr<T>{}.owner_before(ptrToObject)) && ptrToObject.expired())
return true;
else
return false;
}
inline DEVICEID_TYPE AsCNTKImplDeviceId(const DeviceDescriptor& device) inline DEVICEID_TYPE AsCNTKImplDeviceId(const DeviceDescriptor& device)
{ {
if (device.Type() == DeviceKind::CPU) if (device.Type() == DeviceKind::CPU)

Просмотреть файл

@ -72,10 +72,7 @@ namespace CNTK
FunctionPtr Variable::Owner() const FunctionPtr Variable::Owner() const
{ {
if (m_dataFields->m_ownerFunction != nullptr) return m_dataFields->Owner();
return m_dataFields->m_ownerFunction->shared_from_this();
else
return nullptr;
} }
Variable Variable::CompositePreservingCopy(const std::shared_ptr<const Function>& composite) const Variable Variable::CompositePreservingCopy(const std::shared_ptr<const Function>& composite) const
@ -94,12 +91,12 @@ namespace CNTK
return copy; return copy;
} }
void Variable::SetOwner(Function* ownerFunction) void Variable::SetOwner(const std::weak_ptr<Function>& ownerFunction)
{ {
if (Kind() != VariableKind::Output) if (Kind() != VariableKind::Output)
LogicError("Variable '%S' SetOwner: Owner can only be set for Output Variables", AsString().c_str()); LogicError("Variable '%S' SetOwner: Owner can only be set for Output Variables", AsString().c_str());
if (m_dataFields->m_ownerFunction != nullptr) if (Owner() != nullptr)
LogicError("Variable '%S' SetOwner: An Output Variable whose owner has previously been set, cannot be reset.", AsString().c_str()); LogicError("Variable '%S' SetOwner: An Output Variable whose owner has previously been set, cannot be reset.", AsString().c_str());
m_dataFields->m_ownerFunction = ownerFunction; m_dataFields->m_ownerFunction = ownerFunction;
@ -213,13 +210,24 @@ namespace CNTK
return wss.str(); return wss.str();
} }
FunctionPtr VariableFields::Owner() const
{
if (IsObjectExpired(m_ownerFunction))
LogicError("The owner function of Variable '%S' is unexpectedly expired.", AsString().c_str());
auto ownerFunctionPtr = m_ownerFunction.lock();
if (ownerFunctionPtr != nullptr)
return ownerFunctionPtr->shared_from_this();
else
return nullptr;
}
std::shared_ptr<VariableFields> VariableFields::Clone() const std::shared_ptr<VariableFields> VariableFields::Clone() const
{ {
if (m_ownerFunction != nullptr) if (Owner() != nullptr)
InvalidArgument("Output variable '%S' cannot be cloned.", AsString().c_str()); InvalidArgument("Output variable '%S' cannot be cloned.", AsString().c_str());
// Note: We do not clone m_blockFunctionVariableMapping // Note: We do not clone m_blockFunctionVariableMapping
auto clone = MakeSharedObject<VariableFields>(m_shape, auto clone = MakeSharedObject<VariableFields>(m_shape,
m_varKind, m_varKind,
m_dataType, m_dataType,
@ -364,7 +372,7 @@ namespace CNTK
} }
Variable::Variable(const NDShape& shape, VariableKind varType, CNTK::DataType dataType, const NDArrayViewPtr& value, bool needsGradient, const std::vector<Axis>& dynamicAxes, bool isSparse, const std::wstring& name, const std::wstring& uid) Variable::Variable(const NDShape& shape, VariableKind varType, CNTK::DataType dataType, const NDArrayViewPtr& value, bool needsGradient, const std::vector<Axis>& dynamicAxes, bool isSparse, const std::wstring& name, const std::wstring& uid)
: m_dataFields(MakeSharedObject<VariableFields>(shape, varType, dataType, nullptr, value, needsGradient, dynamicAxes, isSparse, name, uid)) : m_dataFields(MakeSharedObject<VariableFields>(shape, varType, dataType, std::weak_ptr<Function>(), value, needsGradient, dynamicAxes, isSparse, name, uid))
{} {}
template <typename ElementType> template <typename ElementType>

Просмотреть файл

@ -19,7 +19,7 @@ namespace CNTK
NDShape m_shape; NDShape m_shape;
VariableKind m_varKind; VariableKind m_varKind;
::CNTK::DataType m_dataType; ::CNTK::DataType m_dataType;
Function* m_ownerFunction; // Variable does not keep the Function alive std::weak_ptr<Function> m_ownerFunction;
std::unique_ptr<std::once_flag> m_initValueFlag; std::unique_ptr<std::once_flag> m_initValueFlag;
NDArrayViewPtr m_value; NDArrayViewPtr m_value;
std::unique_ptr<ParameterInitializer> m_valueInitializer; std::unique_ptr<ParameterInitializer> m_valueInitializer;
@ -32,7 +32,7 @@ namespace CNTK
std::atomic<size_t> m_valueTimeStamp; std::atomic<size_t> m_valueTimeStamp;
Variable m_blockFunctionVariableMapping; Variable m_blockFunctionVariableMapping;
VariableFields(const NDShape& shape, VariableKind varType, ::CNTK::DataType type, Function* ownerFunction, const NDArrayViewPtr& value, bool needsGradient, const std::vector<Axis>& dynamicAxes, bool isSparse, const std::wstring& name, const std::wstring& uid) VariableFields(const NDShape& shape, VariableKind varType, ::CNTK::DataType type, const std::weak_ptr<Function>& ownerFunction, const NDArrayViewPtr& value, bool needsGradient, const std::vector<Axis>& dynamicAxes, bool isSparse, const std::wstring& name, const std::wstring& uid)
: m_shape(shape), m_varKind(varType), m_dataType(type), m_ownerFunction(ownerFunction), m_value(value), m_needsGradient(needsGradient), m_dynamicAxes(dynamicAxes), m_isSparse(isSparse), m_name(name), m_uid(uid), m_valueTimeStamp(0) : m_shape(shape), m_varKind(varType), m_dataType(type), m_ownerFunction(ownerFunction), m_value(value), m_needsGradient(needsGradient), m_dynamicAxes(dynamicAxes), m_isSparse(isSparse), m_name(name), m_uid(uid), m_valueTimeStamp(0)
{ {
if (value && (type != value->GetDataType())) if (value && (type != value->GetDataType()))
@ -59,6 +59,7 @@ namespace CNTK
std::wstring AsString() const; std::wstring AsString() const;
std::shared_ptr<VariableFields> Clone() const; std::shared_ptr<VariableFields> Clone() const;
FunctionPtr Owner() const;
CNTK_API void SetValueInitialization(const ParameterInitializer& initializationConfig, const DeviceDescriptor& device); CNTK_API void SetValueInitialization(const ParameterInitializer& initializationConfig, const DeviceDescriptor& device);