From c7c2547f559bc5b7268af20ad89cea3d4431203a Mon Sep 17 00:00:00 2001 From: Amit Agarwal Date: Thu, 13 Apr 2017 16:51:13 -0700 Subject: [PATCH] CNTK v2 library: Fixed handling of references of existing network matrix storage handed out from Forward/Backward --- CNTK.sln | 2 +- .../CPPExtensibilityExamplesLibrary.cpp | 2 +- .../CPPExtensibilityExamplesLibrary.vcxproj | 2 +- ...tensibilityExamplesLibrary.vcxproj.filters | 2 +- Makefile | 2 +- Source/CNTKv2LibraryDll/API/CNTKLibrary.h | 2 ++ .../API/CNTKLibraryInternals.h | 4 +++ Source/CNTKv2LibraryDll/CompositeFunction.cpp | 22 +++++++++++-- Source/CNTKv2LibraryDll/CompositeFunction.h | 16 ++++++++++ Source/CNTKv2LibraryDll/Utils.cpp | 9 ++++-- Source/CNTKv2LibraryDll/Utils.h | 2 ++ Source/CNTKv2LibraryDll/Value.cpp | 6 ++++ Source/CNTKv2LibraryDll/Value.h | 11 +++++++ .../python/cntk/ops/tests/function_tests.py | 31 +++++++++++++++++++ 14 files changed, 102 insertions(+), 11 deletions(-) rename Examples/Extensibility/{CPP => CPPLib}/CPPExtensibilityExamplesLibrary.cpp (86%) rename Examples/Extensibility/{CPP => CPPLib}/CPPExtensibilityExamplesLibrary.vcxproj (98%) rename Examples/Extensibility/{CPP => CPPLib}/CPPExtensibilityExamplesLibrary.vcxproj.filters (94%) diff --git a/CNTK.sln b/CNTK.sln index 69057f36f..b6bd02890 100644 --- a/CNTK.sln +++ b/CNTK.sln @@ -1509,7 +1509,7 @@ Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "CPPExtensibilityExamples", EndProject Project("{888888A0-9F3D-457C-B088-3A5042F75D52}") = "PythonExamples", "Examples\PythonExamples.pyproj", "{292FF4EE-D9DD-4BA7-85F7-6A22148D1E01}" EndProject -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "CPPExtensibilityExamplesLibrary", "Examples\Extensibility\CPP\CPPExtensibilityExamplesLibrary.vcxproj", "{4CF94A50-0D17-432A-8B5A-8458E91C44A6}" +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "CPPExtensibilityExamplesLibrary", "Examples\Extensibility\CPPLib\CPPExtensibilityExamplesLibrary.vcxproj", "{4CF94A50-0D17-432A-8B5A-8458E91C44A6}"
ProjectSection(ProjectDependencies) = postProject {E5606ECE-48CA-4464-BB12-09D81D02B9EF} = {E5606ECE-48CA-4464-BB12-09D81D02B9EF} EndProjectSection diff --git a/Examples/Extensibility/CPP/CPPExtensibilityExamplesLibrary.cpp b/Examples/Extensibility/CPPLib/CPPExtensibilityExamplesLibrary.cpp similarity index 86% rename from Examples/Extensibility/CPP/CPPExtensibilityExamplesLibrary.cpp rename to Examples/Extensibility/CPPLib/CPPExtensibilityExamplesLibrary.cpp index 4dd28d693..aeb2946e9 100644 --- a/Examples/Extensibility/CPP/CPPExtensibilityExamplesLibrary.cpp +++ b/Examples/Extensibility/CPPLib/CPPExtensibilityExamplesLibrary.cpp @@ -1,4 +1,4 @@ -#include "UserMatrixMultiplicationOp.h" +#include "../CPP/UserMatrixMultiplicationOp.h" using namespace CNTK; diff --git a/Examples/Extensibility/CPP/CPPExtensibilityExamplesLibrary.vcxproj b/Examples/Extensibility/CPPLib/CPPExtensibilityExamplesLibrary.vcxproj similarity index 98% rename from Examples/Extensibility/CPP/CPPExtensibilityExamplesLibrary.vcxproj rename to Examples/Extensibility/CPPLib/CPPExtensibilityExamplesLibrary.vcxproj index b48031d5d..efb948e83 100644 --- a/Examples/Extensibility/CPP/CPPExtensibilityExamplesLibrary.vcxproj +++ b/Examples/Extensibility/CPPLib/CPPExtensibilityExamplesLibrary.vcxproj @@ -26,7 +26,7 @@ - + {4cf94a50-0d17-432a-8b5a-8458e91c44a6} diff --git a/Examples/Extensibility/CPP/CPPExtensibilityExamplesLibrary.vcxproj.filters b/Examples/Extensibility/CPPLib/CPPExtensibilityExamplesLibrary.vcxproj.filters similarity index 94% rename from Examples/Extensibility/CPP/CPPExtensibilityExamplesLibrary.vcxproj.filters rename to Examples/Extensibility/CPPLib/CPPExtensibilityExamplesLibrary.vcxproj.filters index e6ac6e470..58d0a7626 100644 --- a/Examples/Extensibility/CPP/CPPExtensibilityExamplesLibrary.vcxproj.filters +++ b/Examples/Extensibility/CPPLib/CPPExtensibilityExamplesLibrary.vcxproj.filters @@ -20,7 +20,7 @@ - + Header Files diff --git a/Makefile b/Makefile index 6a98dfa24..7ba74da52 
100644 --- a/Makefile +++ b/Makefile @@ -519,7 +519,7 @@ $(CNTKLIBRARY_LIB): $(CNTKLIBRARY_OBJ) | $(CNTKMATH_LIB) ######################################## CPP_EXTENSIBILITY_EXAMPLES_LIBRARY_SRC =\ - $(SOURCEDIR)/../Examples/Extensibility/CPP/CPPExtensibilityExamplesLibrary.cpp \ + $(SOURCEDIR)/../Examples/Extensibility/CPPLib/CPPExtensibilityExamplesLibrary.cpp \ CPP_EXTENSIBILITY_EXAMPLES_LIBRARY_OBJ := $(patsubst %.cpp, $(OBJDIR)/%.o, $(CPP_EXTENSIBILITY_EXAMPLES_LIBRARY_SRC)) diff --git a/Source/CNTKv2LibraryDll/API/CNTKLibrary.h b/Source/CNTKv2LibraryDll/API/CNTKLibrary.h index 9636e0be1..71e9ee09a 100644 --- a/Source/CNTKv2LibraryDll/API/CNTKLibrary.h +++ b/Source/CNTKv2LibraryDll/API/CNTKLibrary.h @@ -2571,6 +2571,8 @@ namespace CNTK /// virtual void CopyFrom(const Value& source); + virtual void Erase(); + /// /// Unpacks sequences in 'this' Value as a vector of NDArrayView objects, each represeting a sequence in the /// batch of sequences that 'this' Value object contains data for. 
diff --git a/Source/CNTKv2LibraryDll/API/CNTKLibraryInternals.h b/Source/CNTKv2LibraryDll/API/CNTKLibraryInternals.h index 30938b8bc..ccc4713ae 100644 --- a/Source/CNTKv2LibraryDll/API/CNTKLibraryInternals.h +++ b/Source/CNTKv2LibraryDll/API/CNTKLibraryInternals.h @@ -215,6 +215,10 @@ namespace CNTK class UserFunctionFactory; typedef std::shared_ptr UserFunctionFactoryPtr; + class PackedValue; + typedef std::shared_ptr PackedValuePtr; + typedef std::weak_ptr PackedValueWeakPtr; + struct MinibatchSourceConfig; namespace Internal diff --git a/Source/CNTKv2LibraryDll/CompositeFunction.cpp b/Source/CNTKv2LibraryDll/CompositeFunction.cpp index e94ea8315..0170105b5 100755 --- a/Source/CNTKv2LibraryDll/CompositeFunction.cpp +++ b/Source/CNTKv2LibraryDll/CompositeFunction.cpp @@ -1426,7 +1426,16 @@ namespace CNTK { // Now copy the Forward values of output nodes from the network to outputs' Value objects for (auto outputVarValuePair : outputs) - GetNodeOutputOrGradient(outputVarValuePair.first, outputs[outputVarValuePair.first], m_variableToNodeMap.at(outputVarValuePair.first), false /*getGradient*/); + { + auto& valuePtr = outputs[outputVarValuePair.first]; + auto node = m_variableToNodeMap.at(outputVarValuePair.first); + bool noValueStrorageProvided = (valuePtr == nullptr); + GetNodeOutputOrGradient(outputVarValuePair.first, valuePtr, node, false /*getGradient*/); + + auto packedVarValue = std::dynamic_pointer_cast(valuePtr); + if (noValueStrorageProvided && packedVarValue && packedVarValue->IsPacked()) + m_existingNetworkStorageReferences.push_back(packedVarValue); + } } void CompositeFunction::GetNetworkGradients(std::unordered_map& gradients) @@ -1452,7 +1461,13 @@ namespace CNTK LogicError("Function '%S': Backpropagated gradient value cannot be read from a Variable '%S' whose ComputationNode has NeedsGradient set to false.", AsString().c_str(), gradientVarValuePair.first.AsString().c_str()); - GetNodeOutputOrGradient(gradientVarValuePair.first, 
gradients[gradientVarValuePair.first], computationNodePtr, true /*getGradient*/); + auto& valuePtr = gradients[gradientVarValuePair.first]; + bool noValueStrorageProvided = (valuePtr == nullptr); + GetNodeOutputOrGradient(gradientVarValuePair.first, valuePtr, computationNodePtr, true /*getGradient*/); + + auto packedVarValue = std::dynamic_pointer_cast(valuePtr); + if (noValueStrorageProvided && packedVarValue && packedVarValue->IsPacked()) + m_existingNetworkStorageReferences.push_back(packedVarValue); } } @@ -1592,7 +1607,8 @@ namespace CNTK for (auto& backpropRoot : m_currentBackpropRoots) m_variableToNodeMap.at(backpropRoot)->SetEvalTimeStampOutdatedWrtAll(); - // TODO: Verify that values were supplied for all inputs that requested outputs depend on + // Free any previous references to the matrix storage associated with the outputsToEvaluate + ClearExistingOutputOrGradientStorageReferences(); ScopedNetworkOperationMode modeGuard(m_computationNetwork, outputsToRetainBackwardStateFor.empty() ? NetworkOperationMode::inferring : NetworkOperationMode::training); diff --git a/Source/CNTKv2LibraryDll/CompositeFunction.h b/Source/CNTKv2LibraryDll/CompositeFunction.h index 433e2d41f..b170567ae 100644 --- a/Source/CNTKv2LibraryDll/CompositeFunction.h +++ b/Source/CNTKv2LibraryDll/CompositeFunction.h @@ -10,6 +10,7 @@ #include "PrimitiveFunction.h" #include "ComputationNetwork.h" #include "BackCompat.h" +#include "Value.h" namespace CNTK { @@ -309,6 +310,18 @@ namespace CNTK std::unordered_map GetCurrentBackpropRootsTimeStamps() const; + void ClearExistingOutputOrGradientStorageReferences() + { + for (auto& existingStorageWeakReference : m_existingNetworkStorageReferences) + { + auto existingStorageReference = existingStorageWeakReference.lock(); + if (existingStorageReference) + existingStorageReference->Erase(); + } + + m_existingNetworkStorageReferences.clear(); + } + private: // Set of all primitive functions in the graph underlying 'this' Function. 
Also keeps the primitive Function objects alive @@ -323,6 +336,9 @@ namespace CNTK Microsoft::MSR::CNTK::ComputationNetworkPtr m_computationNetwork; + // List to keep track of any references to network output/gradient storage handed out so far + std::vector m_existingNetworkStorageReferences; + // The backpropRoots sepecified in the most recent 'Forward' call on 'this' Function. // This indicates for which of its roots has 'this' Function retained required intermediate // states from the previos Forward call to be able to backpropagate gradients backwards from in diff --git a/Source/CNTKv2LibraryDll/Utils.cpp b/Source/CNTKv2LibraryDll/Utils.cpp index 7a360b3ac..b0643fad0 100644 --- a/Source/CNTKv2LibraryDll/Utils.cpp +++ b/Source/CNTKv2LibraryDll/Utils.cpp @@ -487,6 +487,11 @@ namespace CNTK return dynamicAxes[0].Name(); } + bool IsPackedValue(const ValuePtr& value) + { + auto packedValue = dynamic_pointer_cast(value); + return (packedValue != nullptr) && packedValue->IsPacked(); + } std::pair GetNumTimeStepsAndSequences(const NDShape& maskShape, size_t numDynamicAxes) { size_t maxNumTimeSteps = 1; @@ -520,10 +525,8 @@ namespace CNTK if (var.GetDataType() != value->GetDataType()) LogicError("The Variable '%S' DataType %s does not match the corresponding Value's DataType %s", var.AsString().c_str(), DataTypeName(var.GetDataType()), DataTypeName(value->GetDataType())); - auto packedValue = dynamic_cast(value.get()); - bool isPackedValue = (packedValue != nullptr) && packedValue->IsPacked(); - // TODO: Is supplying dense data for an Input variable tagged as sparse, a fatal error even for packed value objects?
+ bool isPackedValue = IsPackedValue(value); if (!isPackedValue) { if (IsSparseInput(var) && !value->IsSparse()) diff --git a/Source/CNTKv2LibraryDll/Utils.h b/Source/CNTKv2LibraryDll/Utils.h index d68039baf..5c20a3d4b 100644 --- a/Source/CNTKv2LibraryDll/Utils.h +++ b/Source/CNTKv2LibraryDll/Utils.h @@ -589,6 +589,8 @@ namespace CNTK return ShapeRowColSplitPoint(var.Shape(), var.IsSparse()); } + bool IsPackedValue(const ValuePtr& value); + // Helper class to manage a collection of learners. class Learners { diff --git a/Source/CNTKv2LibraryDll/Value.cpp b/Source/CNTKv2LibraryDll/Value.cpp index a8010da6b..abaacdfd1 100644 --- a/Source/CNTKv2LibraryDll/Value.cpp +++ b/Source/CNTKv2LibraryDll/Value.cpp @@ -404,6 +404,12 @@ namespace CNTK { } + /*virtual*/ void Value::Erase() + { + m_data = nullptr; + m_mask = nullptr; + } + /*virtual*/ NDArrayViewPtr Value::Data() const { // TODO: Check if this is a derived type and throw an exception in that case diff --git a/Source/CNTKv2LibraryDll/Value.h b/Source/CNTKv2LibraryDll/Value.h index 46da84929..0632ea31f 100644 --- a/Source/CNTKv2LibraryDll/Value.h +++ b/Source/CNTKv2LibraryDll/Value.h @@ -35,6 +35,17 @@ namespace CNTK void Unpack() const; + void Erase() override + { + if (IsPacked()) + { + m_packedData = nullptr; + m_packedDataLayout = nullptr; + } + else + Value::Erase(); + } + const NDShape& Shape() const override { return m_unpackedShape; } DeviceDescriptor Device() const override { return m_isPacked ? m_packedData->Device() : Value::Device(); } DataType GetDataType() const override { return m_isPacked ? 
m_packedData->GetDataType() : Value::GetDataType(); } diff --git a/bindings/python/cntk/ops/tests/function_tests.py b/bindings/python/cntk/ops/tests/function_tests.py index aedda17f3..1e827e6dc 100644 --- a/bindings/python/cntk/ops/tests/function_tests.py +++ b/bindings/python/cntk/ops/tests/function_tests.py @@ -421,3 +421,34 @@ def test_transpose_0d_1d_operands(): x2 = C.input(2) with pytest.raises(ValueError): transpose_1d = C.transpose(x2) + + +def test_eval_again_with_prev_outputs_live(device_id): + x = C.input(2) + dev = cntk_device(device_id) + w1 = C.parameter(init=np.asarray([1], dtype=np.float32), device=dev) + w2 = C.parameter(init=np.asarray([-1], dtype=np.float32), device=dev) + out1 = x + w1 + out2 = x + w2 + op = C.combine([out1, out2]) + + result1 = op.eval({x : np.asarray([2, 5], dtype=np.float32)}, device=dev) + result2 = op.eval({x : np.asarray([[-1, 4], [-4, 7]], dtype=np.float32)}, device=dev) + assert np.array_equal(result2[out1.output], [[0, 5], [-3, 8]]) + assert np.array_equal(result2[out2.output], [[-2, 3], [-5, 6]]) + + result1 = op.eval({x : np.asarray([2, 5], dtype=np.float32)}, device=dev, as_numpy=False) + result2 = op.eval({x : np.asarray([[-1, 4], [-4, 7]], dtype=np.float32)}, device=dev, as_numpy=False) + assert np.array_equal(result2[out1.output].asarray(), [[0, 5], [-3, 8]]) + assert np.array_equal(result2[out2.output].asarray(), [[-2, 3], [-5, 6]]) + + grad_op = out1 + out2 + grad1 = grad_op.grad({x : np.asarray([2, 5], dtype=np.float32)}, wrt=[w1, w2], device=dev) + grad2 = grad_op.grad({x : np.asarray([[-1, 4], [-4, 7]], dtype=np.float32)}, wrt=[w1, w2], device=dev) + assert np.array_equal(grad2[w1], [4]) + assert np.array_equal(grad2[w2], [4]) + + grad1 = grad_op.grad({x : np.asarray([2, 5], dtype=np.float32)}, wrt=[w1, w2], device=dev, as_numpy=False) + grad2 = grad_op.grad({x : np.asarray([[-1, 4], [-4, 7]], dtype=np.float32)}, wrt=[w1, w2], device=dev, as_numpy=False) + assert np.array_equal(grad2[w1].asarray(), [4]) + 
assert np.array_equal(grad2[w2].asarray(), [4])