CNTK v2 library: Fixed handling of references to existing network matrix storage handed out from Forward/Backward
Parent: 1f23f6e161
Commit: c7c2547f55
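
A note on the mechanism this commit introduces (a simplified sketch with invented names, not the actual CNTK implementation): when the caller passes no output storage, Forward/Backward hands back a packed Value that aliases the network's internal matrices, and the next call overwrites that memory. The fix tracks each such Value through a weak reference and erases it before the storage is reused, so stale outputs can no longer silently observe overwritten data. The same pattern in Python:

    import weakref

    class PackedValueLike:
        """Stand-in for PackedValue: may alias network-owned storage."""
        def __init__(self, storage):
            self.storage = storage          # an alias, not a copy
        def erase(self):
            self.storage = None             # drop the alias

    class FunctionLike:
        """Stand-in for CompositeFunction with reusable network storage."""
        def __init__(self):
            self._network_storage = [0.0] * 4
            self._handed_out_refs = []      # like m_existingNetworkStorageReferences

        def forward(self):
            # Invalidate aliases handed out by earlier calls before the storage
            # is overwritten (ClearExistingOutputOrGradientStorageReferences).
            for ref in self._handed_out_refs:
                value = ref()               # like weak_ptr::lock()
                if value is not None:
                    value.erase()
            self._handed_out_refs = []

            value = PackedValueLike(self._network_storage)
            self._handed_out_refs.append(weakref.ref(value))
            return value

    f = FunctionLike()
    v1 = f.forward()
    v2 = f.forward()                        # v1 must stop aliasing the storage
    assert v1.storage is None
    assert v2.storage is not None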

CNTK.sln (2 changed lines)

@@ -1509,7 +1509,7 @@ Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "CPPExtensibilityExamples",
 EndProject
 Project("{888888A0-9F3D-457C-B088-3A5042F75D52}") = "PythonExamples", "Examples\PythonExamples.pyproj", "{292FF4EE-D9DD-4BA7-85F7-6A22148D1E01}"
 EndProject
-Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "CPPExtensibilityExamplesLibrary", "Examples\Extensibility\CPP\CPPExtensibilityExamplesLibrary.vcxproj", "{4CF94A50-0D17-432A-8B5A-8458E91C44A6}"
+Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "CPPExtensibilityExamplesLibrary", "Examples\Extensibility\CPPLib\CPPExtensibilityExamplesLibrary.vcxproj", "{4CF94A50-0D17-432A-8B5A-8458E91C44A6}"
 	ProjectSection(ProjectDependencies) = postProject
 		{E5606ECE-48CA-4464-BB12-09D81D02B9EF} = {E5606ECE-48CA-4464-BB12-09D81D02B9EF}
 	EndProjectSection

@@ -1,4 +1,4 @@
-#include "UserMatrixMultiplicationOp.h"
+#include "../CPP/UserMatrixMultiplicationOp.h"
 
 using namespace CNTK;

@@ -26,7 +26,7 @@
     <ClCompile Include="CPPExtensibilityExamplesLibrary.cpp" />
   </ItemGroup>
   <ItemGroup>
-    <ClInclude Include="UserMatrixMultiplicationOp.h" />
+    <ClInclude Include="../CPP/UserMatrixMultiplicationOp.h" />
   </ItemGroup>
   <PropertyGroup Label="Globals">
     <ProjectGuid>{4cf94a50-0d17-432a-8b5a-8458e91c44a6}</ProjectGuid>

@@ -20,7 +20,7 @@
     </ClCompile>
   </ItemGroup>
   <ItemGroup>
-    <ClInclude Include="UserMatrixMultiplicationOp.h">
+    <ClInclude Include="../CPP/UserMatrixMultiplicationOp.h">
       <Filter>Header Files</Filter>
     </ClInclude>
   </ItemGroup>

Makefile (2 changed lines)

@@ -519,7 +519,7 @@ $(CNTKLIBRARY_LIB): $(CNTKLIBRARY_OBJ) | $(CNTKMATH_LIB)
 ########################################
 
 CPP_EXTENSIBILITY_EXAMPLES_LIBRARY_SRC =\
-	$(SOURCEDIR)/../Examples/Extensibility/CPP/CPPExtensibilityExamplesLibrary.cpp \
+	$(SOURCEDIR)/../Examples/Extensibility/CPPLib/CPPExtensibilityExamplesLibrary.cpp \
 
 CPP_EXTENSIBILITY_EXAMPLES_LIBRARY_OBJ := $(patsubst %.cpp, $(OBJDIR)/%.o, $(CPP_EXTENSIBILITY_EXAMPLES_LIBRARY_SRC))

@@ -2571,6 +2571,8 @@ namespace CNTK
         ///
         virtual void CopyFrom(const Value& source);
 
+        virtual void Erase();
+
         ///
         /// Unpacks sequences in 'this' Value as a vector of NDArrayView objects, each representing a sequence in the
         /// batch of sequences that 'this' Value object contains data for.

@@ -215,6 +215,10 @@ namespace CNTK
     class UserFunctionFactory;
     typedef std::shared_ptr<UserFunctionFactory> UserFunctionFactoryPtr;
 
+    class PackedValue;
+    typedef std::shared_ptr<PackedValue> PackedValuePtr;
+    typedef std::weak_ptr<PackedValue> PackedValueWeakPtr;
+
     struct MinibatchSourceConfig;
 
     namespace Internal

@@ -1426,7 +1426,16 @@ namespace CNTK
     {
         // Now copy the Forward values of output nodes from the network to outputs' Value objects
         for (auto outputVarValuePair : outputs)
-            GetNodeOutputOrGradient(outputVarValuePair.first, outputs[outputVarValuePair.first], m_variableToNodeMap.at(outputVarValuePair.first), false /*getGradient*/);
+        {
+            auto& valuePtr = outputs[outputVarValuePair.first];
+            auto node = m_variableToNodeMap.at(outputVarValuePair.first);
+            bool noValueStorageProvided = (valuePtr == nullptr);
+            GetNodeOutputOrGradient(outputVarValuePair.first, valuePtr, node, false /*getGradient*/);
+
+            auto packedVarValue = std::dynamic_pointer_cast<PackedValue>(valuePtr);
+            if (noValueStorageProvided && packedVarValue && packedVarValue->IsPacked())
+                m_existingNetworkStorageReferences.push_back(packedVarValue);
+        }
     }
 
     void CompositeFunction::GetNetworkGradients(std::unordered_map<Variable, ValuePtr>& gradients)

@@ -1452,7 +1461,13 @@ namespace CNTK
             LogicError("Function '%S': Backpropagated gradient value cannot be read from a Variable '%S' whose ComputationNode has NeedsGradient set to false.",
                        AsString().c_str(), gradientVarValuePair.first.AsString().c_str());
 
-        GetNodeOutputOrGradient(gradientVarValuePair.first, gradients[gradientVarValuePair.first], computationNodePtr, true /*getGradient*/);
+        auto& valuePtr = gradients[gradientVarValuePair.first];
+        bool noValueStorageProvided = (valuePtr == nullptr);
+        GetNodeOutputOrGradient(gradientVarValuePair.first, valuePtr, computationNodePtr, true /*getGradient*/);
+
+        auto packedVarValue = std::dynamic_pointer_cast<PackedValue>(valuePtr);
+        if (noValueStorageProvided && packedVarValue && packedVarValue->IsPacked())
+            m_existingNetworkStorageReferences.push_back(packedVarValue);
     }
 }
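
Why both hunks guard on noValueStorageProvided (my reading, sketched with hypothetical stand-in names): when the caller supplies its own Value, GetNodeOutputOrGradient fills caller-owned storage, so nothing aliases the network and nothing needs tracking; only a Value the library creates on the caller's behalf can alias network matrices. A self-contained sketch of that split:

    import weakref

    class AliasValue:
        # Hypothetical stand-in for a packed Value aliasing network memory.
        def __init__(self, storage):
            self.data = storage

    class OwnedValue:
        # Hypothetical stand-in for caller-provided output storage.
        def __init__(self):
            self.data = None

    def get_output_or_gradient(provided, network_storage, tracked_refs):
        if provided is None:
            alias = AliasValue(network_storage)        # aliases network memory
            tracked_refs.append(weakref.ref(alias))    # erase before next Forward
            return alias
        provided.data = list(network_storage)          # deep copy; caller owns it
        return provided

    refs = []
    storage = [1.0, 2.0]
    aliased = get_output_or_gradient(None, storage, refs)
    copied = get_output_or_gradient(OwnedValue(), storage, refs)
    assert aliased.data is storage and copied.data is not storage
    assert len(refs) == 1                              # only the alias is tracked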

@@ -1592,7 +1607,8 @@ namespace CNTK
         for (auto& backpropRoot : m_currentBackpropRoots)
             m_variableToNodeMap.at(backpropRoot)->SetEvalTimeStampOutdatedWrtAll();
 
         // TODO: Verify that values were supplied for all inputs that requested outputs depend on
+        // Free any previous references to the matrix storage associated with the outputsToEvaluate
+        ClearExistingOutputOrGradientStorageReferences();
 
         ScopedNetworkOperationMode modeGuard(m_computationNetwork, outputsToRetainBackwardStateFor.empty() ? NetworkOperationMode::inferring : NetworkOperationMode::training);

@@ -10,6 +10,7 @@
 #include "PrimitiveFunction.h"
 #include "ComputationNetwork.h"
 #include "BackCompat.h"
+#include "Value.h"
 
 namespace CNTK
 {

@@ -309,6 +310,18 @@ namespace CNTK
 
         std::unordered_map<Variable, uint64_t> GetCurrentBackpropRootsTimeStamps() const;
 
+        void ClearExistingOutputOrGradientStorageReferences()
+        {
+            for (auto& existingStorageWeakReference : m_existingNetworkStorageReferences)
+            {
+                auto existingStorageReference = existingStorageWeakReference.lock();
+                if (existingStorageReference)
+                    existingStorageReference->Erase();
+            }
+
+            m_existingNetworkStorageReferences.clear();
+        }
+
     private:
 
         // Set of all primitive functions in the graph underlying 'this' Function. Also keeps the primitive Function objects alive
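
A note on the weak references (a design reading, with a standalone illustration): holding weak pointers means the function never extends the lifetime of Values the caller has already released; lock() yields nothing for those, and only still-live Values get erased. Python's weakref module shows the same semantics (this relies on CPython's immediate refcounting for the timing of collection):

    import weakref

    class V:
        def erase(self):
            print("erased live value")

    live, dropped = V(), V()
    refs = [weakref.ref(live), weakref.ref(dropped)]
    del dropped                      # caller released this Value

    for r in refs:
        v = r()                      # like weak_ptr::lock()
        if v is not None:
            v.erase()                # runs only for the still-live value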

@@ -323,6 +336,9 @@ namespace CNTK
 
         Microsoft::MSR::CNTK::ComputationNetworkPtr m_computationNetwork;
 
+        // Keeps track of any references to network output/gradient storage handed out so far
+        std::vector<PackedValueWeakPtr> m_existingNetworkStorageReferences;
+
         // The backpropRoots specified in the most recent 'Forward' call on 'this' Function.
         // This indicates for which of its roots 'this' Function has retained the required intermediate
         // states from the previous Forward call to be able to backpropagate gradients backwards from in

@@ -487,6 +487,11 @@ namespace CNTK
         return dynamicAxes[0].Name();
     }
 
+    bool IsPackedValue(const ValuePtr& value)
+    {
+        auto packedValue = dynamic_pointer_cast<PackedValue>(value);
+        return (packedValue != nullptr) && packedValue->IsPacked();
+    }
+
     std::pair<size_t, size_t> GetNumTimeStepsAndSequences(const NDShape& maskShape, size_t numDynamicAxes)
     {
         size_t maxNumTimeSteps = 1;

@@ -520,10 +525,8 @@ namespace CNTK
         if (var.GetDataType() != value->GetDataType())
             LogicError("The Variable '%S' DataType %s does not match the corresponding Value's DataType %s", var.AsString().c_str(), DataTypeName(var.GetDataType()), DataTypeName(value->GetDataType()));
 
-        auto packedValue = dynamic_cast<PackedValue*>(value.get());
-        bool isPackedValue = (packedValue != nullptr) && packedValue->IsPacked();
-
         // TODO: Is supplying dense data for an Input variable tagged as sparse a fatal error, even for packed value objects?
+        bool isPackedValue = IsPackedValue(value);
         if (!isPackedValue)
         {
             if (IsSparseInput(var) && !value->IsSparse())

@@ -589,6 +589,8 @@ namespace CNTK
         return ShapeRowColSplitPoint(var.Shape(), var.IsSparse());
     }
 
+    bool IsPackedValue(const ValuePtr& value);
+
     // Helper class to manage a collection of learners.
     class Learners
     {

@@ -404,6 +404,12 @@ namespace CNTK
     {
     }
 
+    /*virtual*/ void Value::Erase()
+    {
+        m_data = nullptr;
+        m_mask = nullptr;
+    }
+
     /*virtual*/ NDArrayViewPtr Value::Data() const
     {
         // TODO: Check if this is a derived type and throw an exception in that case

@@ -35,6 +35,17 @@ namespace CNTK
 
         void Unpack() const;
 
+        void Erase() override
+        {
+            if (IsPacked())
+            {
+                m_packedData = nullptr;
+                m_packedDataLayout = nullptr;
+            }
+            else
+                Value::Erase();
+        }
+
         const NDShape& Shape() const override { return m_unpackedShape; }
         DeviceDescriptor Device() const override { return m_isPacked ? m_packedData->Device() : Value::Device(); }
         DataType GetDataType() const override { return m_isPacked ? m_packedData->GetDataType() : Value::GetDataType(); }
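
For context, an assumption-labeled sketch of the dual storage forms behind the override above: a PackedValue holds either still-packed data (m_packedData plus m_packedDataLayout) or data already unpacked into the base class's m_data/m_mask, and Erase must release whichever form is currently active:

    class ValueSketch:
        # Simplified stand-in for Value: unpacked data plus mask.
        def __init__(self):
            self.data, self.mask = object(), object()
        def erase(self):
            self.data = self.mask = None

    class PackedValueSketch(ValueSketch):
        # Simplified stand-in for PackedValue: packed data plus its layout.
        def __init__(self):
            super().__init__()
            self.packed_data, self.packed_layout = object(), object()
        def is_packed(self):
            return self.packed_data is not None
        def erase(self):
            if self.is_packed():
                self.packed_data = self.packed_layout = None   # release packed form
            else:
                super().erase()                                # release unpacked form

    pv = PackedValueSketch()
    pv.erase()
    assert pv.packed_data is None and pv.data is not None      # packed form released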

@@ -421,3 +421,34 @@ def test_transpose_0d_1d_operands():
     x2 = C.input(2)
     with pytest.raises(ValueError):
         transpose_1d = C.transpose(x2)
+
+
+def test_eval_again_with_prev_outputs_live(device_id):
+    x = C.input(2)
+    dev = cntk_device(device_id)
+    w1 = C.parameter(init=np.asarray([1], dtype=np.float32), device=dev)
+    w2 = C.parameter(init=np.asarray([-1], dtype=np.float32), device=dev)
+    out1 = x + w1
+    out2 = x + w2
+    op = C.combine([out1, out2])
+
+    result1 = op.eval({x : np.asarray([2, 5], dtype=np.float32)}, device=dev)
+    result2 = op.eval({x : np.asarray([[-1, 4], [-4, 7]], dtype=np.float32)}, device=dev)
+    assert np.array_equal(result2[out1.output], [[0, 5], [-3, 8]])
+    assert np.array_equal(result2[out2.output], [[-2, 3], [-5, 6]])
+
+    result1 = op.eval({x : np.asarray([2, 5], dtype=np.float32)}, device=dev, as_numpy=False)
+    result2 = op.eval({x : np.asarray([[-1, 4], [-4, 7]], dtype=np.float32)}, device=dev, as_numpy=False)
+    assert np.array_equal(result2[out1.output].asarray(), [[0, 5], [-3, 8]])
+    assert np.array_equal(result2[out2.output].asarray(), [[-2, 3], [-5, 6]])
+
+    grad_op = out1 + out2
+    grad1 = grad_op.grad({x : np.asarray([2, 5], dtype=np.float32)}, wrt=[w1, w2], device=dev)
+    grad2 = grad_op.grad({x : np.asarray([[-1, 4], [-4, 7]], dtype=np.float32)}, wrt=[w1, w2], device=dev)
+    assert np.array_equal(grad2[w1], [4])
+    assert np.array_equal(grad2[w2], [4])
+
+    grad1 = grad_op.grad({x : np.asarray([2, 5], dtype=np.float32)}, wrt=[w1, w2], device=dev, as_numpy=False)
+    grad2 = grad_op.grad({x : np.asarray([[-1, 4], [-4, 7]], dtype=np.float32)}, wrt=[w1, w2], device=dev, as_numpy=False)
+    assert np.array_equal(grad2[w1].asarray(), [4])
+    assert np.array_equal(grad2[w2].asarray(), [4])
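
A usage-level reading of the test above (hedged, built only from calls the test itself uses): result1 and grad1 are deliberately kept alive across the second eval/grad call. With as_numpy=False those results wrap the network's own storage, so the fix erases their internal storage references before the second call reuses that memory, instead of letting the stale objects silently observe overwritten data. The same pattern outside the test harness:

    # Assumes a working cntk (v2.x) install; mirrors the calls used in the test.
    import numpy as np
    import cntk as C

    x = C.input(2)
    op = C.combine([x + 1, x - 1])

    # 'live' stays alive while the second eval reuses the network's storage;
    # after the fix its internal reference is erased rather than left dangling.
    live = op.eval({x: np.asarray([2, 5], dtype=np.float32)}, as_numpy=False)
    fresh = op.eval({x: np.asarray([3, 4], dtype=np.float32)}, as_numpy=False)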