diff --git a/CNTK.sln b/CNTK.sln
index 69057f36f..b6bd02890 100644
--- a/CNTK.sln
+++ b/CNTK.sln
@@ -1509,7 +1509,7 @@ Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "CPPExtensibilityExamples",
EndProject
Project("{888888A0-9F3D-457C-B088-3A5042F75D52}") = "PythonExamples", "Examples\PythonExamples.pyproj", "{292FF4EE-D9DD-4BA7-85F7-6A22148D1E01}"
EndProject
-Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "CPPExtensibilityExamplesLibrary", "Examples\Extensibility\CPP\CPPExtensibilityExamplesLibrary.vcxproj", "{4CF94A50-0D17-432A-8B5A-8458E91C44A6}"
+Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "CPPExtensibilityExamplesLibrary", "Examples\Extensibility\CPPLib\CPPExtensibilityExamplesLibrary.vcxproj", "{4CF94A50-0D17-432A-8B5A-8458E91C44A6}"
ProjectSection(ProjectDependencies) = postProject
{E5606ECE-48CA-4464-BB12-09D81D02B9EF} = {E5606ECE-48CA-4464-BB12-09D81D02B9EF}
EndProjectSection
diff --git a/Examples/Extensibility/CPP/CPPExtensibilityExamplesLibrary.cpp b/Examples/Extensibility/CPPLib/CPPExtensibilityExamplesLibrary.cpp
similarity index 86%
rename from Examples/Extensibility/CPP/CPPExtensibilityExamplesLibrary.cpp
rename to Examples/Extensibility/CPPLib/CPPExtensibilityExamplesLibrary.cpp
index 4dd28d693..aeb2946e9 100644
--- a/Examples/Extensibility/CPP/CPPExtensibilityExamplesLibrary.cpp
+++ b/Examples/Extensibility/CPPLib/CPPExtensibilityExamplesLibrary.cpp
@@ -1,4 +1,4 @@
-#include "UserMatrixMultiplicationOp.h"
+#include "../CPP/UserMatrixMultiplicationOp.h"
using namespace CNTK;
diff --git a/Examples/Extensibility/CPP/CPPExtensibilityExamplesLibrary.vcxproj b/Examples/Extensibility/CPPLib/CPPExtensibilityExamplesLibrary.vcxproj
similarity index 98%
rename from Examples/Extensibility/CPP/CPPExtensibilityExamplesLibrary.vcxproj
rename to Examples/Extensibility/CPPLib/CPPExtensibilityExamplesLibrary.vcxproj
index b48031d5d..efb948e83 100644
--- a/Examples/Extensibility/CPP/CPPExtensibilityExamplesLibrary.vcxproj
+++ b/Examples/Extensibility/CPPLib/CPPExtensibilityExamplesLibrary.vcxproj
@@ -26,7 +26,7 @@
-
+
{4cf94a50-0d17-432a-8b5a-8458e91c44a6}
diff --git a/Examples/Extensibility/CPP/CPPExtensibilityExamplesLibrary.vcxproj.filters b/Examples/Extensibility/CPPLib/CPPExtensibilityExamplesLibrary.vcxproj.filters
similarity index 94%
rename from Examples/Extensibility/CPP/CPPExtensibilityExamplesLibrary.vcxproj.filters
rename to Examples/Extensibility/CPPLib/CPPExtensibilityExamplesLibrary.vcxproj.filters
index e6ac6e470..58d0a7626 100644
--- a/Examples/Extensibility/CPP/CPPExtensibilityExamplesLibrary.vcxproj.filters
+++ b/Examples/Extensibility/CPPLib/CPPExtensibilityExamplesLibrary.vcxproj.filters
@@ -20,7 +20,7 @@
- <ClInclude Include="UserMatrixMultiplicationOp.h">
+ <ClInclude Include="..\CPP\UserMatrixMultiplicationOp.h">
<Filter>Header Files</Filter>
diff --git a/Makefile b/Makefile
index 6a98dfa24..7ba74da52 100644
--- a/Makefile
+++ b/Makefile
@@ -519,7 +519,7 @@ $(CNTKLIBRARY_LIB): $(CNTKLIBRARY_OBJ) | $(CNTKMATH_LIB)
########################################
CPP_EXTENSIBILITY_EXAMPLES_LIBRARY_SRC =\
- $(SOURCEDIR)/../Examples/Extensibility/CPP/CPPExtensibilityExamplesLibrary.cpp \
+ $(SOURCEDIR)/../Examples/Extensibility/CPPLib/CPPExtensibilityExamplesLibrary.cpp \
CPP_EXTENSIBILITY_EXAMPLES_LIBRARY_OBJ := $(patsubst %.cpp, $(OBJDIR)/%.o, $(CPP_EXTENSIBILITY_EXAMPLES_LIBRARY_SRC))
diff --git a/Source/CNTKv2LibraryDll/API/CNTKLibrary.h b/Source/CNTKv2LibraryDll/API/CNTKLibrary.h
index 9636e0be1..71e9ee09a 100644
--- a/Source/CNTKv2LibraryDll/API/CNTKLibrary.h
+++ b/Source/CNTKv2LibraryDll/API/CNTKLibrary.h
@@ -2571,6 +2571,8 @@ namespace CNTK
///
virtual void CopyFrom(const Value& source);
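+
+ ///
+ /// Releases the underlying data and mask storage references held by 'this' Value object.
+ ///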
+ virtual void Erase();
+
///
/// Unpacks sequences in 'this' Value as a vector of NDArrayView objects, each representing a sequence in the
/// batch of sequences that 'this' Value object contains data for.
diff --git a/Source/CNTKv2LibraryDll/API/CNTKLibraryInternals.h b/Source/CNTKv2LibraryDll/API/CNTKLibraryInternals.h
index 30938b8bc..ccc4713ae 100644
--- a/Source/CNTKv2LibraryDll/API/CNTKLibraryInternals.h
+++ b/Source/CNTKv2LibraryDll/API/CNTKLibraryInternals.h
@@ -215,6 +215,10 @@ namespace CNTK
class UserFunctionFactory;
typedef std::shared_ptr<UserFunctionFactory> UserFunctionFactoryPtr;
+ class PackedValue;
+ typedef std::shared_ptr<PackedValue> PackedValuePtr;
+ typedef std::weak_ptr<PackedValue> PackedValueWeakPtr;
+
struct MinibatchSourceConfig;
namespace Internal
diff --git a/Source/CNTKv2LibraryDll/CompositeFunction.cpp b/Source/CNTKv2LibraryDll/CompositeFunction.cpp
index e94ea8315..0170105b5 100755
--- a/Source/CNTKv2LibraryDll/CompositeFunction.cpp
+++ b/Source/CNTKv2LibraryDll/CompositeFunction.cpp
@@ -1426,7 +1426,16 @@ namespace CNTK
{
// Now copy the Forward values of output nodes from the network to outputs' Value objects
for (auto outputVarValuePair : outputs)
- GetNodeOutputOrGradient(outputVarValuePair.first, outputs[outputVarValuePair.first], m_variableToNodeMap.at(outputVarValuePair.first), false /*getGradient*/);
+ {
+ auto& valuePtr = outputs[outputVarValuePair.first];
+ auto node = m_variableToNodeMap.at(outputVarValuePair.first);
+ bool noValueStorageProvided = (valuePtr == nullptr);
+ GetNodeOutputOrGradient(outputVarValuePair.first, valuePtr, node, false /*getGradient*/);
+
+ auto packedVarValue = std::dynamic_pointer_cast<PackedValue>(valuePtr);
+ if (noValueStorageProvided && packedVarValue && packedVarValue->IsPacked())
+ m_existingNetworkStorageReferences.push_back(packedVarValue);
+ }
}
void CompositeFunction::GetNetworkGradients(std::unordered_map<Variable, ValuePtr>& gradients)
@@ -1452,7 +1461,13 @@ namespace CNTK
LogicError("Function '%S': Backpropagated gradient value cannot be read from a Variable '%S' whose ComputationNode has NeedsGradient set to false.",
AsString().c_str(), gradientVarValuePair.first.AsString().c_str());
- GetNodeOutputOrGradient(gradientVarValuePair.first, gradients[gradientVarValuePair.first], computationNodePtr, true /*getGradient*/);
+ auto& valuePtr = gradients[gradientVarValuePair.first];
+ bool noValueStorageProvided = (valuePtr == nullptr);
+ GetNodeOutputOrGradient(gradientVarValuePair.first, valuePtr, computationNodePtr, true /*getGradient*/);
+
+ auto packedVarValue = std::dynamic_pointer_cast<PackedValue>(valuePtr);
+ if (noValueStorageProvided && packedVarValue && packedVarValue->IsPacked())
+ m_existingNetworkStorageReferences.push_back(packedVarValue);
}
}
@@ -1592,7 +1607,8 @@ namespace CNTK
for (auto& backpropRoot : m_currentBackpropRoots)
m_variableToNodeMap.at(backpropRoot)->SetEvalTimeStampOutdatedWrtAll();
- // TODO: Verify that values were supplied for all inputs that requested outputs depend on
+ // Free any previous references to the matrix storage associated with the outputsToEvaluate
+ ClearExistingOutputOrGradientStorageReferences();
ScopedNetworkOperationMode modeGuard(m_computationNetwork, outputsToRetainBackwardStateFor.empty() ? NetworkOperationMode::inferring : NetworkOperationMode::training);
diff --git a/Source/CNTKv2LibraryDll/CompositeFunction.h b/Source/CNTKv2LibraryDll/CompositeFunction.h
index 433e2d41f..b170567ae 100644
--- a/Source/CNTKv2LibraryDll/CompositeFunction.h
+++ b/Source/CNTKv2LibraryDll/CompositeFunction.h
@@ -10,6 +10,7 @@
#include "PrimitiveFunction.h"
#include "ComputationNetwork.h"
#include "BackCompat.h"
+#include "Value.h"
namespace CNTK
{
@@ -309,6 +310,18 @@ namespace CNTK
std::unordered_map GetCurrentBackpropRootsTimeStamps() const;
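+ // Erases any output/gradient Values previously handed out that still alias the network's storage, then clears the tracked references.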
+ void ClearExistingOutputOrGradientStorageReferences()
+ {
+ for (auto& existingStorageWeakReference : m_existingNetworkStorageReferences)
+ {
+ auto existingStorageReference = existingStorageWeakReference.lock();
+ if (existingStorageReference)
+ existingStorageReference->Erase();
+ }
+
+ m_existingNetworkStorageReferences.clear();
+ }
+
private:
// Set of all primitive functions in the graph underlying 'this' Function. Also keeps the primitive Function objects alive
@@ -323,6 +336,9 @@ namespace CNTK
Microsoft::MSR::CNTK::ComputationNetworkPtr m_computationNetwork;
+ // Keeps track of references to network output/gradient storage handed out so far
+ std::vector<PackedValueWeakPtr> m_existingNetworkStorageReferences;
+
// The backpropRoots specified in the most recent 'Forward' call on 'this' Function.
// This indicates for which of its roots 'this' Function has retained the required intermediate
// states from the previous Forward call to be able to backpropagate gradients backwards from in
diff --git a/Source/CNTKv2LibraryDll/Utils.cpp b/Source/CNTKv2LibraryDll/Utils.cpp
index 7a360b3ac..b0643fad0 100644
--- a/Source/CNTKv2LibraryDll/Utils.cpp
+++ b/Source/CNTKv2LibraryDll/Utils.cpp
@@ -487,6 +487,11 @@ namespace CNTK
return dynamicAxes[0].Name();
}
+ bool IsPackedValue(const ValuePtr& value)
+ {
+ auto packedValue = dynamic_pointer_cast<PackedValue>(value);
+ return (packedValue != nullptr) && packedValue->IsPacked();
+ }
std::pair<size_t, size_t> GetNumTimeStepsAndSequences(const NDShape& maskShape, size_t numDynamicAxes)
{
size_t maxNumTimeSteps = 1;
@@ -520,10 +525,8 @@ namespace CNTK
if (var.GetDataType() != value->GetDataType())
LogicError("The Variable '%S' DataType %s does not match the corresponding Value's DataType %s", var.AsString().c_str(), DataTypeName(var.GetDataType()), DataTypeName(value->GetDataType()));
- auto packedValue = dynamic_cast<PackedValue*>(value.get());
- bool isPackedValue = (packedValue != nullptr) && packedValue->IsPacked();
-
// TODO: Is supplying dense data for an Input variable tagged as sparse, a fatal error even for packed value objects?
+ bool isPackedValue = IsPackedValue(value);
if (!isPackedValue)
{
if (IsSparseInput(var) && !value->IsSparse())
diff --git a/Source/CNTKv2LibraryDll/Utils.h b/Source/CNTKv2LibraryDll/Utils.h
index d68039baf..5c20a3d4b 100644
--- a/Source/CNTKv2LibraryDll/Utils.h
+++ b/Source/CNTKv2LibraryDll/Utils.h
@@ -589,6 +589,8 @@ namespace CNTK
return ShapeRowColSplitPoint(var.Shape(), var.IsSparse());
}
+ bool IsPackedValue(const ValuePtr& value);
+
// Helper class to manage a collection of learners.
class Learners
{
diff --git a/Source/CNTKv2LibraryDll/Value.cpp b/Source/CNTKv2LibraryDll/Value.cpp
index a8010da6b..abaacdfd1 100644
--- a/Source/CNTKv2LibraryDll/Value.cpp
+++ b/Source/CNTKv2LibraryDll/Value.cpp
@@ -404,6 +404,12 @@ namespace CNTK
{
}
+ /*virtual*/ void Value::Erase()
+ {
+ m_data = nullptr;
+ m_mask = nullptr;
+ }
+
/*virtual*/ NDArrayViewPtr Value::Data() const
{
// TODO: Check if this is a derived type and throw an exception in that case
diff --git a/Source/CNTKv2LibraryDll/Value.h b/Source/CNTKv2LibraryDll/Value.h
index 46da84929..0632ea31f 100644
--- a/Source/CNTKv2LibraryDll/Value.h
+++ b/Source/CNTKv2LibraryDll/Value.h
@@ -35,6 +35,17 @@ namespace CNTK
void Unpack() const;
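+ // If still packed, releases the packed data and its layout; otherwise falls back to Value::Erase().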
+ void Erase() override
+ {
+ if (IsPacked())
+ {
+ m_packedData = nullptr;
+ m_packedDataLayout = nullptr;
+ }
+ else
+ Value::Erase();
+ }
+
const NDShape& Shape() const override { return m_unpackedShape; }
DeviceDescriptor Device() const override { return m_isPacked ? m_packedData->Device() : Value::Device(); }
DataType GetDataType() const override { return m_isPacked ? m_packedData->GetDataType() : Value::GetDataType(); }
diff --git a/bindings/python/cntk/ops/tests/function_tests.py b/bindings/python/cntk/ops/tests/function_tests.py
index aedda17f3..1e827e6dc 100644
--- a/bindings/python/cntk/ops/tests/function_tests.py
+++ b/bindings/python/cntk/ops/tests/function_tests.py
@@ -421,3 +421,34 @@ def test_transpose_0d_1d_operands():
x2 = C.input(2)
with pytest.raises(ValueError):
transpose_1d = C.transpose(x2)
+
+
+def test_eval_again_with_prev_outputs_live(device_id):
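+ # Re-evaluating while results from a previous eval/grad call are still alive must return correct data,
+ # both as numpy copies and as Value objects that alias network storage (as_numpy=False).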
+ x = C.input(2)
+ dev = cntk_device(device_id)
+ w1 = C.parameter(init=np.asarray([1], dtype=np.float32), device=dev)
+ w2 = C.parameter(init=np.asarray([-1], dtype=np.float32), device=dev)
+ out1 = x + w1
+ out2 = x + w2
+ op = C.combine([out1, out2])
+
+ result1 = op.eval({x : np.asarray([2, 5], dtype=np.float32)}, device=dev)
+ result2 = op.eval({x : np.asarray([[-1, 4], [-4, 7]], dtype=np.float32)}, device=dev)
+ assert np.array_equal(result2[out1.output], [[0, 5], [-3, 8]])
+ assert np.array_equal(result2[out2.output], [[-2, 3], [-5, 6]])
+
+ result1 = op.eval({x : np.asarray([2, 5], dtype=np.float32)}, device=dev, as_numpy=False)
+ result2 = op.eval({x : np.asarray([[-1, 4], [-4, 7]], dtype=np.float32)}, device=dev, as_numpy=False)
+ assert np.array_equal(result2[out1.output].asarray(), [[0, 5], [-3, 8]])
+ assert np.array_equal(result2[out2.output].asarray(), [[-2, 3], [-5, 6]])
+
+ grad_op = out1 + out2
+ grad1 = grad_op.grad({x : np.asarray([2, 5], dtype=np.float32)}, wrt=[w1, w2], device=dev)
+ grad2 = grad_op.grad({x : np.asarray([[-1, 4], [-4, 7]], dtype=np.float32)}, wrt=[w1, w2], device=dev)
+ assert np.array_equal(grad2[w1], [4])
+ assert np.array_equal(grad2[w2], [4])
+
+ grad1 = grad_op.grad({x : np.asarray([2, 5], dtype=np.float32)}, wrt=[w1, w2], device=dev, as_numpy=False)
+ grad2 = grad_op.grad({x : np.asarray([[-1, 4], [-4, 7]], dtype=np.float32)}, wrt=[w1, w2], device=dev, as_numpy=False)
+ assert np.array_equal(grad2[w1].asarray(), [4])
+ assert np.array_equal(grad2[w2].asarray(), [4])
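
For reference, here is a standalone sketch, outside the diff and not CNTK code, of the storage-tracking pattern these changes introduce: the evaluator hands out outputs that alias its internal storage, records only weak references to them, and erases those references at the start of the next forward pass so previously returned outputs never silently alias overwritten storage. All names in the sketch (PackedBuffer, Evaluator, Forward) are hypothetical.

#include <iostream>
#include <memory>
#include <vector>

class PackedBuffer
{
public:
    explicit PackedBuffer(std::shared_ptr<std::vector<float>> storage) : m_storage(std::move(storage)) {}

    bool IsPacked() const { return m_storage != nullptr; }

    // Analogous to Value::Erase()/PackedValue::Erase(): drop the reference to the
    // evaluator-owned storage so this object no longer aliases it.
    void Erase() { m_storage = nullptr; }

    const std::vector<float>* Data() const { return m_storage.get(); }

private:
    std::shared_ptr<std::vector<float>> m_storage;
};

class Evaluator
{
public:
    // Returns an output that aliases the evaluator's internal buffer.
    std::shared_ptr<PackedBuffer> Forward(float x)
    {
        // Invalidate outputs handed out by previous Forward calls, analogous to
        // ClearExistingOutputOrGradientStorageReferences().
        for (auto& weakRef : m_handedOutOutputs)
            if (auto ref = weakRef.lock())
                ref->Erase();
        m_handedOutOutputs.clear();

        *m_internalBuffer = { x, x + 1.0f };  // overwrite the shared internal storage
        auto output = std::make_shared<PackedBuffer>(m_internalBuffer);
        m_handedOutOutputs.push_back(output); // track with a weak reference only
        return output;
    }

private:
    std::shared_ptr<std::vector<float>> m_internalBuffer = std::make_shared<std::vector<float>>();
    std::vector<std::weak_ptr<PackedBuffer>> m_handedOutOutputs;
};

int main()
{
    Evaluator eval;
    auto first = eval.Forward(1.0f);
    auto second = eval.Forward(10.0f);

    // 'first' was erased before the buffer was reused, so it reports "not packed"
    // instead of silently pointing at overwritten data.
    std::cout << std::boolalpha
              << "first still packed: " << first->IsPacked() << "\n"   // false
              << "second value[0]: " << (*second->Data())[0] << "\n";  // 10
    return 0;
}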