CNTK v2 library: Fixed handling of references to existing network matrix storage handed out from Forward/Backward

This commit is contained in:
Amit Agarwal 2017-04-13 16:51:13 -07:00
Parent 1f23f6e161
Commit c7c2547f55
14 changed files: 102 additions and 11 deletions
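Editor's note on the mechanism, inferred from the hunks below: when a caller supplies no storage for an output or gradient (a null ValuePtr), Forward/Backward can hand back a PackedValue that aliases the network's internal matrix storage, and a subsequent Forward would reuse those matrices underneath the caller. This commit records a weak reference to each such handed-out Value and erases any still-live ones before the storage is reused. A hedged usage sketch against the CNTK v2 C++ API ('func', 'input', and the batch Values are illustrative placeholders):

#include "CNTKLibrary.h"
#include <unordered_map>

using namespace CNTK;

void ForwardTwice(const FunctionPtr& func, const Variable& input,
                  const ValuePtr& batch1, const ValuePtr& batch2)
{
    auto output = func->Output();
    auto device = DeviceDescriptor::CPUDevice();

    // Passing nullptr asks the library to hand out storage itself; the
    // returned Value may alias network-internal matrices.
    std::unordered_map<Variable, ValuePtr> outputs = { { output, nullptr } };
    func->Forward({ { input, batch1 } }, outputs, device);
    ValuePtr first = outputs[output]; // kept alive across the next call

    // The second Forward reuses that internal storage. After this fix, any
    // still-live reference handed out earlier is Erase()'d first, instead of
    // being left pointing at silently mutated matrices.
    outputs[output] = nullptr;
    func->Forward({ { input, batch2 } }, outputs, device);
}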

View file

@@ -1509,7 +1509,7 @@ Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "CPPExtensibilityExamples",
EndProject
Project("{888888A0-9F3D-457C-B088-3A5042F75D52}") = "PythonExamples", "Examples\PythonExamples.pyproj", "{292FF4EE-D9DD-4BA7-85F7-6A22148D1E01}"
EndProject
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "CPPExtensibilityExamplesLibrary", "Examples\Extensibility\CPP\CPPExtensibilityExamplesLibrary.vcxproj", "{4CF94A50-0D17-432A-8B5A-8458E91C44A6}"
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "CPPExtensibilityExamplesLibrary", "Examples\Extensibility\CPPLib\CPPExtensibilityExamplesLibrary.vcxproj", "{4CF94A50-0D17-432A-8B5A-8458E91C44A6}"
ProjectSection(ProjectDependencies) = postProject
{E5606ECE-48CA-4464-BB12-09D81D02B9EF} = {E5606ECE-48CA-4464-BB12-09D81D02B9EF}
EndProjectSection

View file

@@ -1,4 +1,4 @@
#include "UserMatrixMultiplicationOp.h"
#include "../CPP/UserMatrixMultiplicationOp.h"
using namespace CNTK;

View file

@@ -26,7 +26,7 @@
<ClCompile Include="CPPExtensibilityExamplesLibrary.cpp" />
</ItemGroup>
<ItemGroup>
<ClInclude Include="UserMatrixMultiplicationOp.h" />
<ClInclude Include="../CPP/UserMatrixMultiplicationOp.h" />
</ItemGroup>
<PropertyGroup Label="Globals">
<ProjectGuid>{4cf94a50-0d17-432a-8b5a-8458e91c44a6}</ProjectGuid>

View file

@@ -20,7 +20,7 @@
</ClCompile>
</ItemGroup>
<ItemGroup>
<ClInclude Include="UserMatrixMultiplicationOp.h">
<ClInclude Include="../CPP/UserMatrixMultiplicationOp.h">
<Filter>Header Files</Filter>
</ClInclude>
</ItemGroup>

View file

@@ -519,7 +519,7 @@ $(CNTKLIBRARY_LIB): $(CNTKLIBRARY_OBJ) | $(CNTKMATH_LIB)
########################################
CPP_EXTENSIBILITY_EXAMPLES_LIBRARY_SRC =\
-$(SOURCEDIR)/../Examples/Extensibility/CPP/CPPExtensibilityExamplesLibrary.cpp \
+$(SOURCEDIR)/../Examples/Extensibility/CPPLib/CPPExtensibilityExamplesLibrary.cpp \
CPP_EXTENSIBILITY_EXAMPLES_LIBRARY_OBJ := $(patsubst %.cpp, $(OBJDIR)/%.o, $(CPP_EXTENSIBILITY_EXAMPLES_LIBRARY_SRC))

View file

@@ -2571,6 +2571,8 @@ namespace CNTK
///
virtual void CopyFrom(const Value& source);
+virtual void Erase();
///
/// Unpacks sequences in 'this' Value as a vector of NDArrayView objects, each representing a sequence in the
/// batch of sequences that 'this' Value object contains data for.

View file

@@ -215,6 +215,10 @@ namespace CNTK
class UserFunctionFactory;
typedef std::shared_ptr<UserFunctionFactory> UserFunctionFactoryPtr;
+class PackedValue;
+typedef std::shared_ptr<PackedValue> PackedValuePtr;
+typedef std::weak_ptr<PackedValue> PackedValueWeakPtr;
struct MinibatchSourceConfig;
namespace Internal

View file

@@ -1426,7 +1426,16 @@ namespace CNTK
{
// Now copy the Forward values of output nodes from the network to outputs' Value objects
for (auto outputVarValuePair : outputs)
-GetNodeOutputOrGradient(outputVarValuePair.first, outputs[outputVarValuePair.first], m_variableToNodeMap.at(outputVarValuePair.first), false /*getGradient*/);
+{
+    auto& valuePtr = outputs[outputVarValuePair.first];
+    auto node = m_variableToNodeMap.at(outputVarValuePair.first);
+    bool noValueStorageProvided = (valuePtr == nullptr);
+    GetNodeOutputOrGradient(outputVarValuePair.first, valuePtr, node, false /*getGradient*/);
+    auto packedVarValue = std::dynamic_pointer_cast<PackedValue>(valuePtr);
+    if (noValueStorageProvided && packedVarValue && packedVarValue->IsPacked())
+        m_existingNetworkStorageReferences.push_back(packedVarValue);
+}
}
void CompositeFunction::GetNetworkGradients(std::unordered_map<Variable, ValuePtr>& gradients)
@@ -1452,7 +1461,13 @@
LogicError("Function '%S': Backpropagated gradient value cannot be read from a Variable '%S' whose ComputationNode has NeedsGradient set to false.",
AsString().c_str(), gradientVarValuePair.first.AsString().c_str());
-GetNodeOutputOrGradient(gradientVarValuePair.first, gradients[gradientVarValuePair.first], computationNodePtr, true /*getGradient*/);
+auto& valuePtr = gradients[gradientVarValuePair.first];
+bool noValueStorageProvided = (valuePtr == nullptr);
+GetNodeOutputOrGradient(gradientVarValuePair.first, valuePtr, computationNodePtr, true /*getGradient*/);
+auto packedVarValue = std::dynamic_pointer_cast<PackedValue>(valuePtr);
+if (noValueStorageProvided && packedVarValue && packedVarValue->IsPacked())
+    m_existingNetworkStorageReferences.push_back(packedVarValue);
}
}
@@ -1592,7 +1607,8 @@
for (auto& backpropRoot : m_currentBackpropRoots)
m_variableToNodeMap.at(backpropRoot)->SetEvalTimeStampOutdatedWrtAll();
// TODO: Verify that values were supplied for all inputs that requested outputs depend on
+// Free any previous references to the matrix storage associated with the outputsToEvaluate
+ClearExistingOutputOrGradientStorageReferences();
ScopedNetworkOperationMode modeGuard(m_computationNetwork, outputsToRetainBackwardStateFor.empty() ? NetworkOperationMode::inferring : NetworkOperationMode::training);
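Distilled, the hunks above implement one bookkeeping pattern: record a weak reference whenever network-owned storage is handed out, and invalidate any survivors at the start of the next Forward before the underlying matrices are overwritten. A minimal standalone sketch with simplified stand-in types (not the actual CNTK classes):

#include <memory>
#include <vector>

struct PackedStorage
{
    bool erased = false;
    void Erase() { erased = true; } // stand-in for dropping aliased matrix buffers
};

class Evaluator
{
    std::vector<std::weak_ptr<PackedStorage>> m_handedOut;

public:
    std::shared_ptr<PackedStorage> Forward()
    {
        // Invalidate previously handed-out views before this evaluation
        // reuses the storage they alias.
        for (auto& weakRef : m_handedOut)
            if (auto ref = weakRef.lock())
                ref->Erase();
        m_handedOut.clear();

        auto result = std::make_shared<PackedStorage>();
        m_handedOut.push_back(result); // remember what was handed out
        return result;
    }
};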

View file

@@ -10,6 +10,7 @@
#include "PrimitiveFunction.h"
#include "ComputationNetwork.h"
#include "BackCompat.h"
#include "Value.h"
namespace CNTK
{
@@ -309,6 +310,18 @@
std::unordered_map<Variable, uint64_t> GetCurrentBackpropRootsTimeStamps() const;
+void ClearExistingOutputOrGradientStorageReferences()
+{
+    for (auto& existingStorageWeakReference : m_existingNetworkStorageReferences)
+    {
+        auto existingStorageReference = existingStorageWeakReference.lock();
+        if (existingStorageReference)
+            existingStorageReference->Erase();
+    }
+    m_existingNetworkStorageReferences.clear();
+}
private:
// Set of all primitive functions in the graph underlying 'this' Function. Also keeps the primitive Function objects alive
@@ -323,6 +336,9 @@
Microsoft::MSR::CNTK::ComputationNetworkPtr m_computationNetwork;
+// Map to keep track of any references to network output/gradient storage handed out so far
+std::vector<PackedValueWeakPtr> m_existingNetworkStorageReferences;
// The backpropRoots specified in the most recent 'Forward' call on 'this' Function.
// This indicates for which of its roots 'this' Function has retained the required intermediate
// states from the previous Forward call to be able to backpropagate gradients backwards from in
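Why a PackedValueWeakPtr rather than a strong pointer: the tracking list must be able to invalidate Values the caller still holds, without pinning alive Values the caller has already released. A small self-contained illustration in plain C++ (generic names, not CNTK types):

#include <cassert>
#include <memory>
#include <vector>

struct Tracked
{
    bool erased = false;
    void Erase() { erased = true; }
};

int main()
{
    std::vector<std::weak_ptr<Tracked>> handedOut;

    auto stillHeld = std::make_shared<Tracked>();
    handedOut.push_back(stillHeld);

    {
        auto released = std::make_shared<Tracked>();
        handedOut.push_back(released);
    } // destroyed here; the weak_ptr did not extend its lifetime

    for (auto& weakRef : handedOut)
        if (auto ref = weakRef.lock()) // expired entries are simply skipped
            ref->Erase();
    handedOut.clear();

    assert(stillHeld->erased);
    return 0;
}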

View file

@@ -487,6 +487,11 @@ namespace CNTK
return dynamicAxes[0].Name();
}
+bool IsPackedValue(const ValuePtr& value)
+{
+    auto packedValue = dynamic_pointer_cast<PackedValue>(value);
+    return (packedValue != nullptr) && packedValue->IsPacked();
+}
std::pair<size_t, size_t> GetNumTimeStepsAndSequences(const NDShape& maskShape, size_t numDynamicAxes)
{
size_t maxNumTimeSteps = 1;
@@ -520,10 +525,8 @@
if (var.GetDataType() != value->GetDataType())
LogicError("The Variable '%S' DataType %s does not match the corresponding Value's DataType %s", var.AsString().c_str(), DataTypeName(var.GetDataType()), DataTypeName(value->GetDataType()));
-auto packedValue = dynamic_cast<PackedValue*>(value.get());
-bool isPackedValue = (packedValue != nullptr) && packedValue->IsPacked();
// TODO: Is supplying dense data for an Input variable tagged as sparse, a fatal error even for packed value objects?
+bool isPackedValue = IsPackedValue(value);
if (!isPackedValue)
{
if (IsSparseInput(var) && !value->IsSparse())

View file

@@ -589,6 +589,8 @@ namespace CNTK
return ShapeRowColSplitPoint(var.Shape(), var.IsSparse());
}
+bool IsPackedValue(const ValuePtr& value);
// Helper class to manage a collection of learners.
class Learners
{

View file

@@ -404,6 +404,12 @@ namespace CNTK
{
}
+/*virtual*/ void Value::Erase()
+{
+    m_data = nullptr;
+    m_mask = nullptr;
+}
/*virtual*/ NDArrayViewPtr Value::Data() const
{
// TODO: Check if this is a derived type and throw an exception in that case

View file

@@ -35,6 +35,17 @@ namespace CNTK
void Unpack() const;
+void Erase() override
+{
+    if (IsPacked())
+    {
+        m_packedData = nullptr;
+        m_packedDataLayout = nullptr;
+    }
+    else
+        Value::Erase();
+}
const NDShape& Shape() const override { return m_unpackedShape; }
DeviceDescriptor Device() const override { return m_isPacked ? m_packedData->Device() : Value::Device(); }
DataType GetDataType() const override { return m_isPacked ? m_packedData->GetDataType() : Value::GetDataType(); }
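The override reflects PackedValue's two representations: while still packed, the data lives in m_packedData/m_packedDataLayout; once Unpack() has run, it lives in the base class's m_data/m_mask, so Erase must clear whichever is live. A compact sketch of that dispatch with stand-in types (illustrative, not the CNTK classes):

#include <cassert>
#include <memory>

struct Buffer {};

struct ValueSketch
{
    std::shared_ptr<Buffer> m_data = std::make_shared<Buffer>();
    virtual ~ValueSketch() = default;
    virtual void Erase() { m_data = nullptr; }
};

struct PackedValueSketch : ValueSketch
{
    std::shared_ptr<Buffer> m_packedData = std::make_shared<Buffer>();
    bool IsPacked() const { return m_packedData != nullptr; }
    void Unpack() { m_packedData = nullptr; } // pretend data moved to m_data
    void Erase() override
    {
        if (IsPacked())
            m_packedData = nullptr; // drop the packed alias only
        else
            ValueSketch::Erase();   // already unpacked: clear base storage
    }
};

int main()
{
    PackedValueSketch packed;
    packed.Erase();              // packed path clears the packed members
    assert(!packed.IsPacked());

    PackedValueSketch unpacked;
    unpacked.Unpack();
    static_cast<ValueSketch&>(unpacked).Erase(); // virtual dispatch hits the override
    assert(unpacked.m_data == nullptr);          // unpacked path clears base storage
    return 0;
}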

View file

@@ -421,3 +421,34 @@ def test_transpose_0d_1d_operands():
x2 = C.input(2)
with pytest.raises(ValueError):
transpose_1d = C.transpose(x2)
+
+
+def test_eval_again_with_prev_outputs_live(device_id):
+    x = C.input(2)
+    dev = cntk_device(device_id)
+    w1 = C.parameter(init=np.asarray([1], dtype=np.float32), device=dev)
+    w2 = C.parameter(init=np.asarray([-1], dtype=np.float32), device=dev)
+    out1 = x + w1
+    out2 = x + w2
+    op = C.combine([out1, out2])
+
+    result1 = op.eval({x : np.asarray([2, 5], dtype=np.float32)}, device=dev)
+    result2 = op.eval({x : np.asarray([[-1, 4], [-4, 7]], dtype=np.float32)}, device=dev)
+    assert np.array_equal(result2[out1.output], [[0, 5], [-3, 8]])
+    assert np.array_equal(result2[out2.output], [[-2, 3], [-5, 6]])
+
+    result1 = op.eval({x : np.asarray([2, 5], dtype=np.float32)}, device=dev, as_numpy=False)
+    result2 = op.eval({x : np.asarray([[-1, 4], [-4, 7]], dtype=np.float32)}, device=dev, as_numpy=False)
+    assert np.array_equal(result2[out1.output].asarray(), [[0, 5], [-3, 8]])
+    assert np.array_equal(result2[out2.output].asarray(), [[-2, 3], [-5, 6]])
+
+    grad_op = out1 + out2
+    grad1 = grad_op.grad({x : np.asarray([2, 5], dtype=np.float32)}, wrt=[w1, w2], device=dev)
+    grad2 = grad_op.grad({x : np.asarray([[-1, 4], [-4, 7]], dtype=np.float32)}, wrt=[w1, w2], device=dev)
+    assert np.array_equal(grad2[w1], [4])
+    assert np.array_equal(grad2[w2], [4])
+
+    grad1 = grad_op.grad({x : np.asarray([2, 5], dtype=np.float32)}, wrt=[w1, w2], device=dev, as_numpy=False)
+    grad2 = grad_op.grad({x : np.asarray([[-1, 4], [-4, 7]], dtype=np.float32)}, wrt=[w1, w2], device=dev, as_numpy=False)
+    assert np.array_equal(grad2[w1].asarray(), [4])
+    assert np.array_equal(grad2[w2].asarray(), [4])