Integrate zhouwang/pr1441 into master

This commit is contained in:
Project Philly 2017-03-18 04:31:50 -07:00
Родитель d2f57f7266 079518dbd9
Коммит 9b877f11e0
6 изменённых файлов: 34 добавлений и 19 удалений

Просмотреть файл

@ -1744,7 +1744,7 @@ private:
CNTK_API static Variable Deserialize(const Dictionary& dictionary, const ::CNTK::DeviceDescriptor& device = DeviceDescriptor::UseDefaultDevice());
void SetOwner(Function* ownerFunction);
void SetOwner(const std::weak_ptr<Function>& ownerFunction);
Variable CompositePreservingCopy(const std::shared_ptr<const Function>& composite) const;
@ -5282,4 +5282,4 @@ namespace std
return std::hash<size_t>()(x.m_globalRank);
}
};
}
}

Просмотреть файл

@ -22,9 +22,9 @@ namespace CNTK
for (auto outputVar : outputs)
{
if (outputVar.IsOutput() && !outputVar.Owner())
outputVar.SetOwner(this);
outputVar.SetOwner(shared_from_this());
if (m_rootFunction == nullptr && outputVar.IsOutput() && outputVar.m_dataFields->m_ownerFunction == this)
if (m_rootFunction == nullptr && outputVar.IsOutput() && outputVar.Owner().get() == this)
{
// in case of a primitive function, set uid of output vars to owner function uid + "_Output_" + output index.
outputVar.m_dataFields->m_uid = m_uid + L"_" + VariableKindName(outputVar.Kind()) + L"_" + std::to_wstring(m_outputs.size());

Просмотреть файл

@ -133,9 +133,6 @@ namespace CNTK
return !(*this == other);
}
Dictionary::Dictionary()
: m_dictionaryData(new unordered_map <wstring, DictionaryValue>)
{
@ -425,7 +422,6 @@ namespace CNTK
return fd;
}
std::string ToString(const std::wstring& wstring)
{
#ifdef _MSC_VER
@ -518,6 +514,7 @@ namespace CNTK
return std::pair<size_t, size_t>(maxNumTimeSteps, numSequences);
}
/*static*/ void Utils::VerifyVariableValueCompatibility(const Variable& var, const ValuePtr& value)
{
if (var.GetDataType() != value->GetDataType())

Просмотреть файл

@ -30,6 +30,15 @@ namespace CNTK
NOT_IMPLEMENTED;
}
template <typename T>
inline bool IsObjectExpired(std::weak_ptr<T> ptrToObject)
{
if ((ptrToObject.owner_before(std::weak_ptr<T>{}) || std::weak_ptr<T>{}.owner_before(ptrToObject)) && ptrToObject.expired())
return true;
else
return false;
}
inline DEVICEID_TYPE AsCNTKImplDeviceId(const DeviceDescriptor& device)
{
if (device.Type() == DeviceKind::CPU)

Просмотреть файл

@ -72,10 +72,7 @@ namespace CNTK
FunctionPtr Variable::Owner() const
{
if (m_dataFields->m_ownerFunction != nullptr)
return m_dataFields->m_ownerFunction->shared_from_this();
else
return nullptr;
return m_dataFields->Owner();
}
Variable Variable::CompositePreservingCopy(const std::shared_ptr<const Function>& composite) const
@ -94,12 +91,12 @@ namespace CNTK
return copy;
}
void Variable::SetOwner(Function* ownerFunction)
void Variable::SetOwner(const std::weak_ptr<Function>& ownerFunction)
{
if (Kind() != VariableKind::Output)
LogicError("Variable '%S' SetOwner: Owner can only be set for Output Variables", AsString().c_str());
if (m_dataFields->m_ownerFunction != nullptr)
if (Owner() != nullptr)
LogicError("Variable '%S' SetOwner: An Output Variable whose owner has previously been set, cannot be reset.", AsString().c_str());
m_dataFields->m_ownerFunction = ownerFunction;
@ -213,13 +210,24 @@ namespace CNTK
return wss.str();
}
FunctionPtr VariableFields::Owner() const
{
if (IsObjectExpired(m_ownerFunction))
LogicError("The owner function of Variable '%S' is unexpectedly expired.", AsString().c_str());
auto ownerFunctionPtr = m_ownerFunction.lock();
if (ownerFunctionPtr != nullptr)
return ownerFunctionPtr->shared_from_this();
else
return nullptr;
}
std::shared_ptr<VariableFields> VariableFields::Clone() const
{
if (m_ownerFunction != nullptr)
if (Owner() != nullptr)
InvalidArgument("Output variable '%S' cannot be cloned.", AsString().c_str());
// Note: We do not clone m_blockFunctionVariableMapping
auto clone = MakeSharedObject<VariableFields>(m_shape,
m_varKind,
m_dataType,
@ -364,7 +372,7 @@ namespace CNTK
}
Variable::Variable(const NDShape& shape, VariableKind varType, CNTK::DataType dataType, const NDArrayViewPtr& value, bool needsGradient, const std::vector<Axis>& dynamicAxes, bool isSparse, const std::wstring& name, const std::wstring& uid)
: m_dataFields(MakeSharedObject<VariableFields>(shape, varType, dataType, nullptr, value, needsGradient, dynamicAxes, isSparse, name, uid))
: m_dataFields(MakeSharedObject<VariableFields>(shape, varType, dataType, std::weak_ptr<Function>(), value, needsGradient, dynamicAxes, isSparse, name, uid))
{}
template <typename ElementType>

Просмотреть файл

@ -19,7 +19,7 @@ namespace CNTK
NDShape m_shape;
VariableKind m_varKind;
::CNTK::DataType m_dataType;
Function* m_ownerFunction; // Variable does not keep the Function alive
std::weak_ptr<Function> m_ownerFunction;
std::unique_ptr<std::once_flag> m_initValueFlag;
NDArrayViewPtr m_value;
std::unique_ptr<ParameterInitializer> m_valueInitializer;
@ -32,7 +32,7 @@ namespace CNTK
std::atomic<size_t> m_valueTimeStamp;
Variable m_blockFunctionVariableMapping;
VariableFields(const NDShape& shape, VariableKind varType, ::CNTK::DataType type, Function* ownerFunction, const NDArrayViewPtr& value, bool needsGradient, const std::vector<Axis>& dynamicAxes, bool isSparse, const std::wstring& name, const std::wstring& uid)
VariableFields(const NDShape& shape, VariableKind varType, ::CNTK::DataType type, const std::weak_ptr<Function>& ownerFunction, const NDArrayViewPtr& value, bool needsGradient, const std::vector<Axis>& dynamicAxes, bool isSparse, const std::wstring& name, const std::wstring& uid)
: m_shape(shape), m_varKind(varType), m_dataType(type), m_ownerFunction(ownerFunction), m_value(value), m_needsGradient(needsGradient), m_dynamicAxes(dynamicAxes), m_isSparse(isSparse), m_name(name), m_uid(uid), m_valueTimeStamp(0)
{
if (value && (type != value->GetDataType()))
@ -59,6 +59,7 @@ namespace CNTK
std::wstring AsString() const;
std::shared_ptr<VariableFields> Clone() const;
FunctionPtr Owner() const;
CNTK_API void SetValueInitialization(const ParameterInitializer& initializationConfig, const DeviceDescriptor& device);