add Function::FindByName(), Function::FindAllWithName; Add tests;

This commit is contained in:
Zhou Wang 2017-02-03 16:26:12 +01:00
Родитель 5b8d122a4b
Коммит 8176e0fccf
6 изменённых файлов: 249 добавлений и 38 удалений

Просмотреть файл

@ -2799,7 +2799,7 @@ namespace CNTK
}
///
/// Returns the set of all Constant variables of 'this' Function.
/// Returns the set of all Placeholder variables of 'this' Function.
///
std::vector<Variable> Placeholders() const
{
@ -2809,6 +2809,41 @@ namespace CNTK
}
///
/// Find a function with the given name in the Function graph underlying 'this' Function.
/// If more than one function with the same name, an exception is thrown.
/// If nestedSearchInsideBlockFunction is true, all functions inside block functions are also searched for the given name.
///
FunctionPtr FindByName(const std::wstring& name, bool nestedSearchInsideBlockFunction = false)
{
FunctionPtr foundFunction = nullptr;
PreorderTraverseFunctions(RootFunction(), [&foundFunction, &name](const FunctionPtr& function) {
if (name.compare(function->Name()) == 0)
{
if (foundFunction != nullptr)
RuntimeError("Multiple functions with the same name are found in the Function graph underlying 'this' Function.");
else
foundFunction = function;
}
}, nestedSearchInsideBlockFunction);
return foundFunction;
}
///
/// Find a list of functions with the given name in the Function graph underlying 'this' Function.
/// If nestedSearchInsideBlockFunction is true, all functions inside block functions are also searched for the given name.
///
std::vector<FunctionPtr> FindAllWithName(const std::wstring& name, bool nestedSearchInsideBlockFunction = false)
{
std::vector<FunctionPtr> foundFunctions;
PreorderTraverseFunctions(RootFunction(), [&foundFunctions, &name](const FunctionPtr& function) {
if (name.compare(function->Name()) == 0)
foundFunctions.push_back(function);
}, nestedSearchInsideBlockFunction);
return foundFunctions;
}
/// Returns the dictionary of attributes of 'this' Function
///
const Dictionary& Attributes() const { return m_attributes; }
@ -2855,6 +2890,38 @@ namespace CNTK
///
CNTK_API Function(const std::vector<Variable>& inputs, Dictionary&& functionConfig, const std::wstring& name = L"", const std::wstring& uid = Internal::GenerateUid(L"UserDefinedFunction"));
template <typename FunctionType>
static void PreorderTraverseFunctions(const FunctionPtr& rootFunction, const FunctionType& functor, bool traverseInsideBlockFunction = false)
{
std::unordered_set<FunctionPtr> visitedFunctions;
PreorderTraverseFunctions(rootFunction, visitedFunctions, functor, traverseInsideBlockFunction);
}
// Recursively traverses the Function graph underlying the 'rootFunction' invoking the provided functor for all visited nodes in the graph.
template <typename FunctionType>
static void PreorderTraverseFunctions(const FunctionPtr& rootFunction, std::unordered_set<FunctionPtr>& visitedFunctions, const FunctionType& functor, bool traverseInsideBlockFunction = false)
{
visitedFunctions.insert(rootFunction);
functor(rootFunction);
if (traverseInsideBlockFunction && rootFunction->IsBlock())
{
PreorderTraverseFunctions(rootFunction->BlockRoot(), visitedFunctions, functor, traverseInsideBlockFunction);
}
else
{
std::vector<Variable> rootFunctionInputs = rootFunction->Inputs();
for (const auto& rootInput : rootFunctionInputs)
{
if (rootInput.IsOutput() && visitedFunctions.find(rootInput.Owner()) == visitedFunctions.end())
{
const auto& function = rootInput.Owner();
PreorderTraverseFunctions(function, visitedFunctions, functor, traverseInsideBlockFunction);
}
}
}
}
/// Restores the state of the 'this' Function in place using the provided dictionary.
/// Structurally, 'this' Function graph has to be identical to the state captured in the dictionary.
CNTK_API virtual void RestoreFromCheckpoint(const Dictionary& dictionary);

Просмотреть файл

@ -121,31 +121,6 @@ namespace CNTK
return CompositeFunctionOpName;
}
template <typename FunctionType>
static void PreorderTraverseFunctions(const FunctionPtr& rootFunction, const FunctionType& functor)
{
std::unordered_set<FunctionPtr> visitedFunctions;
PreorderTraverseFunctions(rootFunction, visitedFunctions, functor);
}
// Recursively traverses the Function graph underlying the 'rootFunction' invoking the provided functor for all visited nodes in the graph.
template <typename FunctionType>
static void PreorderTraverseFunctions(const FunctionPtr& rootFunction, std::unordered_set<FunctionPtr>& visitedFunctions, const FunctionType& functor)
{
visitedFunctions.insert(rootFunction);
functor(rootFunction);
std::vector<Variable> rootFunctionInputs = rootFunction->Inputs();
for (const auto& rootInput : rootFunctionInputs)
{
if (rootInput.IsOutput() && visitedFunctions.find(rootInput.Owner()) == visitedFunctions.end())
{
const auto& function = rootInput.Owner();
PreorderTraverseFunctions(function, visitedFunctions, functor);
}
}
}
template <typename FunctionType>
static void PreorderTraverseVariables(const FunctionPtr& rootFunction, const FunctionType& functor, bool pythonOperandOrder = false)
{

Просмотреть файл

@ -774,12 +774,179 @@ void TestOuputVariableName(const DeviceDescriptor& device)
static_cast<unsigned long>(output->Output().Shape().TotalSize()));
}
void CheckFindByNameResult(FunctionPtr actual, FunctionPtr expected)
{
if (actual == nullptr)
{
if (expected != nullptr)
ReportFailure("The expected function '%S' has not been found.", expected->Name().c_str());
}
else
{
if (expected == nullptr)
ReportFailure("Found a function '%S', but null is expected.", actual->Name().c_str());
else if (expected->Name().compare(actual->Name()) != 0)
ReportFailure("The found function '%S' does have the same name as the exepected one '%S'", actual->Name().c_str(), expected->Name().c_str());
}
}
void CheckFindAllWithNameResult(std::vector<FunctionPtr> actual, std::wstring expectedName, size_t expectedSize)
{
if (actual.size() != expectedSize)
ReportFailure("The number of found functions does not match the expected number.");
else
{
for (size_t i = 0; i < actual.size(); i++)
{
if (actual[i]->Name().compare(expectedName) != 0)
ReportFailure("The found function '%S' does have the same name as the exepected one '%S'", actual[i]->Name().c_str(), expectedName.c_str());
}
}
}
void TestFindName()
{
size_t inputDim = 10;
size_t outputDim = 20;
const std::wstring timesFuncName = L"TimesFunc";
const std::wstring plusFuncName = L"PlusFunc";
const std::wstring anotherPlusFuncName = L"AnotherPlusFunc";
const std::wstring minusFuncName = L"MinusFunc";
const std::wstring anotherMinusFuncName = L"AnotherMinusFunc";
const std::wstring blockFuncName = L"BlockFunc";
const std::wstring nonExistingFuncName = L"NonExistingFunc";
const std::wstring nestedBlockFuncName = L"NestedBlockFunc";
const std::wstring emptyFuncName = L"";
const std::wstring placeholderName = L"inputPlaceholder";
const std::wstring variableName = L"features";
const std::wstring aliasFuncName = L"aliasFunc";
auto inputVar1 = InputVariable({ inputDim }, DataType::Float, L"features");
auto inputPlaceholder1 = PlaceholderVariable(L"inputPlaceholder");
auto timesParam = CNTK::Parameter(CNTK::NDArrayView::RandomUniform<float>({ outputDim, inputDim }, -0.05, 0.05, 1, DeviceDescriptor::DefaultDevice()));
auto timesFunc1 = CNTK::Times(timesParam, inputPlaceholder1, timesFuncName);
auto plusFunc1 = CNTK::Plus(Constant::Scalar(2.0f), timesFunc1, plusFuncName);
auto plusFunc2 = CNTK::Plus(Constant::Scalar(2.0f), plusFunc1, plusFuncName);
auto emptyNameFunc1 = CNTK::Plus(plusFunc1, plusFunc2);
auto minusFunc1 = CNTK::Minus(plusFunc2, emptyNameFunc1, minusFuncName);
// Test FindByName for the case without any block function
CheckFindByNameResult(minusFunc1->FindByName(timesFuncName), timesFunc1);
CheckFindByNameResult(minusFunc1->FindByName(minusFuncName), minusFunc1);
CheckFindByNameResult(minusFunc1->FindByName(emptyFuncName), emptyNameFunc1);
CheckFindByNameResult(minusFunc1->FindByName(nonExistingFuncName), nullptr);
CheckFindByNameResult(minusFunc1->FindByName(placeholderName), nullptr);
VerifyException([&minusFunc1, &plusFuncName]() {
minusFunc1->FindByName(plusFuncName);
}, "The expected exception has not been caugth: multiple functions with the same name.");
// Test FindAllWithName for the case without any block function
CheckFindAllWithNameResult(minusFunc1->FindAllWithName(timesFuncName), timesFuncName, 1);
CheckFindAllWithNameResult(minusFunc1->FindAllWithName(minusFuncName), minusFuncName, 1);
CheckFindAllWithNameResult(minusFunc1->FindAllWithName(emptyFuncName), emptyFuncName, 1);
CheckFindAllWithNameResult(minusFunc1->FindAllWithName(nonExistingFuncName), nonExistingFuncName, 0);
CheckFindAllWithNameResult(minusFunc1->FindAllWithName(placeholderName), placeholderName, 0);
CheckFindAllWithNameResult(minusFunc1->FindAllWithName(plusFuncName), plusFuncName, 2);
// Build a block function
auto blockFunc = CNTK::AsBlock(std::move(minusFunc1), { { inputPlaceholder1, inputVar1 } }, L"TimesPlusMinus", blockFuncName);
// Build a nested block function
auto inputPlaceholder2 = PlaceholderVariable(L"inputPlaceholder");
auto inputPlaceholder3 = PlaceholderVariable(L"inputPlaceholder");
auto inputVar2 = InputVariable({ outputDim }, DataType::Float, L"features");
auto anotherMinusFunc1 = CNTK::Minus(inputPlaceholder2, Constant::Scalar(3.0f), anotherMinusFuncName);
auto plusFunc3 = CNTK::Plus(Constant::Scalar(3.0f), anotherMinusFunc1, plusFuncName);
auto cloneBlockFunc = blockFunc->Clone(ParameterCloningMethod::Clone, { { inputVar1, inputPlaceholder3 } });
auto minusFunc2 = CNTK::Minus(cloneBlockFunc, plusFunc3, minusFuncName);
auto plusFunc4 = CNTK::Plus(minusFunc2, Constant::Scalar(3.0f), plusFuncName);
auto nestedBlockFunc = CNTK::AsBlock(std::move(plusFunc4), { { inputPlaceholder2, inputVar2 },{ inputPlaceholder3, inputVar1 } }, L"NestedBlock", nestedBlockFuncName);
// Build a function having both block and nested block functions
auto inputVar3 = InputVariable({ outputDim }, DataType::Float, variableName);
auto plusFunc5 = CNTK::Plus(inputVar3, blockFunc, plusFuncName);
auto anotherPlusFunc1 = CNTK::Plus(plusFunc5, nestedBlockFunc, anotherPlusFuncName);
auto minusFunc3 = CNTK::Minus(anotherPlusFunc1, Constant::Scalar(3.0f), minusFuncName);
// Test FindByName with block functions, nestedSearchInsideBlockFunction is false.
CheckFindByNameResult(minusFunc3->FindByName(anotherPlusFuncName), anotherPlusFunc1);
CheckFindByNameResult(minusFunc3->FindByName(anotherMinusFuncName), nullptr);
CheckFindByNameResult(minusFunc3->FindByName(nonExistingFuncName), nullptr);
CheckFindByNameResult(minusFunc3->FindByName(variableName), nullptr);
CheckFindByNameResult(minusFunc3->FindByName(nestedBlockFuncName), nestedBlockFunc);
CheckFindByNameResult(minusFunc3->FindByName(plusFuncName), plusFunc5);
CheckFindByNameResult(minusFunc3->FindByName(timesFuncName), nullptr);
CheckFindByNameResult(minusFunc3->FindByName(emptyFuncName), nullptr);
CheckFindByNameResult(minusFunc3->FindByName(minusFuncName), minusFunc3);
CheckFindByNameResult(minusFunc3->FindByName(blockFuncName), blockFunc);
// Test FindByName with block funcitons, nestedSearchInsideBlockFunction is true
CheckFindByNameResult(minusFunc3->FindByName(anotherPlusFuncName, true), anotherPlusFunc1);
CheckFindByNameResult(minusFunc3->FindByName(anotherMinusFuncName, true), anotherMinusFunc1);
CheckFindByNameResult(minusFunc3->FindByName(nonExistingFuncName, true), nullptr);
CheckFindByNameResult(minusFunc3->FindByName(variableName, true), nullptr);
CheckFindByNameResult(minusFunc3->FindByName(nestedBlockFuncName, true), nestedBlockFunc);
VerifyException([&minusFunc3, &plusFuncName]() {
minusFunc3->FindByName(plusFuncName, true);
}, "The expected exception has not been caugth: multiple functions with the same name.");
VerifyException([&minusFunc3, &timesFuncName]() {
minusFunc3->FindByName(timesFuncName, true);
}, "The expected exception has not been caugth: multiple functions with the same name.");
VerifyException([&minusFunc3, &emptyFuncName]() {
minusFunc3->FindByName(emptyFuncName, true);
}, "The expected exception has not been caugth: multiple functions with the same name.");
VerifyException([&minusFunc3, &minusFuncName]() {
minusFunc3->FindByName(minusFuncName, true);
}, "The expected exception has not been caugth: multiple functions with the same name.");
VerifyException([&minusFunc3, &blockFuncName]() {
minusFunc3->FindByName(blockFuncName, true);
}, "The expected exception has not been caugth: multiple functions with the same name.");
// Test FindAllWithName with block functions, nestedSearchInsideBlockFunction is false.
CheckFindAllWithNameResult(minusFunc3->FindAllWithName(anotherPlusFuncName), anotherPlusFuncName, 1);
CheckFindAllWithNameResult(minusFunc3->FindAllWithName(anotherMinusFuncName), anotherMinusFuncName, 0);
CheckFindAllWithNameResult(minusFunc3->FindAllWithName(nonExistingFuncName), nonExistingFuncName, 0);
CheckFindAllWithNameResult(minusFunc3->FindAllWithName(variableName), variableName, 0);
CheckFindAllWithNameResult(minusFunc3->FindAllWithName(nestedBlockFuncName), nestedBlockFuncName, 1);
CheckFindAllWithNameResult(minusFunc3->FindAllWithName(plusFuncName), plusFuncName, 1);
CheckFindAllWithNameResult(minusFunc3->FindAllWithName(timesFuncName), timesFuncName, 0);
CheckFindAllWithNameResult(minusFunc3->FindAllWithName(emptyFuncName), emptyFuncName, 0);
CheckFindAllWithNameResult(minusFunc3->FindAllWithName(minusFuncName), minusFuncName, 1);
CheckFindAllWithNameResult(minusFunc3->FindAllWithName(blockFuncName), blockFuncName, 1);
// Test FindAllWithName with block funcitons, nestedSearchInsideBlockFunction is true
CheckFindAllWithNameResult(minusFunc3->FindAllWithName(anotherPlusFuncName, true), anotherPlusFuncName, 1);
CheckFindAllWithNameResult(minusFunc3->FindAllWithName(anotherMinusFuncName, true), anotherMinusFuncName, 1);
CheckFindAllWithNameResult(minusFunc3->FindAllWithName(nonExistingFuncName, true), nonExistingFuncName, 0);
CheckFindAllWithNameResult(minusFunc3->FindAllWithName(variableName, true), variableName, 0);
CheckFindAllWithNameResult(minusFunc3->FindAllWithName(nestedBlockFuncName, true), nestedBlockFuncName, 1);
CheckFindAllWithNameResult(minusFunc3->FindAllWithName(plusFuncName, true), plusFuncName, 7);
CheckFindAllWithNameResult(minusFunc3->FindAllWithName(timesFuncName, true), timesFuncName, 2);
CheckFindAllWithNameResult(minusFunc3->FindAllWithName(emptyFuncName,true), emptyFuncName, 2);
CheckFindAllWithNameResult(minusFunc3->FindAllWithName(minusFuncName, true), minusFuncName, 4);
CheckFindAllWithNameResult(minusFunc3->FindAllWithName(blockFuncName, true), blockFuncName, 2);
// Test alias
auto aliasFunc1 = Alias(anotherPlusFunc1, aliasFuncName);
// The Alias does not really create an alias for the function, but indeed create a new function having alias as name.
// The new function is not a part of existing graph, except it is explicitly referenced in the graph.
// TODO: change the tests when Alias is a real alias of a function.
CheckFindByNameResult(minusFunc3->FindByName(aliasFuncName), nullptr);
CheckFindAllWithNameResult(minusFunc3->FindAllWithName(aliasFuncName, true), aliasFuncName, 0);
auto minusFunc4 = CNTK::Minus(aliasFunc1, minusFunc3, minusFuncName);
CheckFindByNameResult(minusFunc4->FindByName(aliasFuncName), aliasFunc1);
CheckFindAllWithNameResult(minusFunc4->FindAllWithName(aliasFuncName, true), aliasFuncName, 1);
}
void FunctionTests()
{
fprintf(stderr, "\nFunctionTests..\n");
TestSplice();
TestFindName();
TestChangingParameterValues<float>(2, DeviceDescriptor::CPUDevice());
if (IsGPUAvailable())
TestChangingParameterValues<double>(3, DeviceDescriptor::GPUDevice(0));

Просмотреть файл

@ -643,35 +643,35 @@ void ValueCopyToExceptionsTest(const DeviceDescriptor& device)
auto sampleVariable = CreateVariable<float>(NDShape::Unknown, 0);
VerifyException([&val, &sampleVariable, &output]() {
val->CopyVariableValueTo(sampleVariable, output);
}, "The expected exception has not been caugth: It is not supported that the outputVariable has a unknown shape or inferred dimension.");
}, "The expected exception has not been caught: It is not supported that the outputVariable has a unknown shape or inferred dimension.");
// Test variable having shape with InferredDimentsion.
sampleVariable = CreateVariable<float>(NDShape(2), 0);
VerifyException([&val, &sampleVariable, &output]() {
val->CopyVariableValueTo(sampleVariable, output);
}, "The expected exception has not been caugth: It is not supported that the outputVariable has a unknown shape or inferred dimension.");
}, "The expected exception has not been caught: It is not supported that the outputVariable has a unknown shape or inferred dimension.");
// Test variable having incorrect data type.
sampleVariable = CreateVariable<double>(sampleShape, 0);
VerifyException([&val, &sampleVariable, &output]() {
val->CopyVariableValueTo(sampleVariable, output);
}, "The expected exception has not been caugth: The outputVariable has a different data type than the Value object.");
}, "The expected exception has not been caught: The outputVariable has a different data type than the Value object.");
sampleVariable = CreateVariable<double>(sampleOneHotShape, 0);
VerifyException([&val, &sampleVariable, &outputInOneHot]() {
val->CopyVariableValueTo(sampleVariable, outputInOneHot);
}, "The expected exception has not been caugth: The outputVariable has a different data type than the Value object.");
}, "The expected exception has not been caught: The outputVariable has a different data type than the Value object.");
// Test output buffer having incorrect data type.
sampleVariable = CreateVariable<float>(sampleShape, 0);
VerifyException([&val, &sampleVariable, &outputInDouble]() {
val->CopyVariableValueTo(sampleVariable, outputInDouble);
}, "The expected exception has not been caugth: The specified ElementType Double does not match the DataType Float");
}, "The expected exception has not been caught: The specified ElementType Double does not match the DataType Float");
// Test the first axis when using one-hot format.
VerifyException([&val, &sampleVariable, &outputInOneHot]() {
val->CopyVariableValueTo(sampleVariable, outputInOneHot);
}, "The expected exception has not been caugth: The outputVariable's leading axis dimensionality must equal the total size of the variable for sparse data.");
}, "The expected exception has not been caught: The outputVariable's leading axis dimensionality must equal the total size of the variable for sparse data.");
}
@ -760,12 +760,12 @@ void CreateBatchTestDense(const DeviceDescriptor device, bool readOnly)
vector<ElementType> wrongBatch(sampleSize * 2 - 1, 0);
VerifyException([&sampleShape, &wrongBatch, &device, &readOnly]() {
Value::CreateBatch(sampleShape, wrongBatch, device, readOnly);
}, "The expected exception has not been caugth: The number of data is not a multiple of the sample size.");
}, "The expected exception has not been caught: The number of data is not a multiple of the sample size.");
auto emptyBatch = vector<ElementType>(0);
VerifyException([&sampleShape, &emptyBatch, &device, &readOnly]() {
Value::CreateBatch(sampleShape, emptyBatch, device, readOnly);
}, "The expected exception has not been caugth: The number of sequences is 0");
}, "The expected exception has not been caught: The number of sequences is 0");
}
template <typename ElementType>
@ -803,12 +803,12 @@ void CreateSequenceTestDense(const DeviceDescriptor device, bool readOnly)
vector<ElementType> wrongSeq(sampleSize * 2 - 1, 0);
VerifyException([&sampleShape, &wrongSeq, &device, &readOnly]() {
Value::CreateSequence(sampleShape, wrongSeq, device, readOnly);
}, "The expected exception has not been caugth: The number of data is not a multiple of the sample size.");
}, "The expected exception has not been caught: The number of data is not a multiple of the sample size.");
auto emptySeq = vector<ElementType>(0);
VerifyException([&sampleShape, &emptySeq, &device, &readOnly]() {
Value::CreateSequence(sampleShape, emptySeq, device, readOnly);
}, "The expected exception has not been caugth: The sequence length is 0");
}, "The expected exception has not been caught: The sequence length is 0");
}
@ -889,7 +889,7 @@ void CreateBatchTestOneHot(const DeviceDescriptor device, bool readOnly)
auto emptyBatch = vector<size_t>(0);
VerifyException([&dimSize, &emptyBatch, &device, &readOnly]() {
Value::CreateBatch<ElementType>(dimSize, emptyBatch, device, readOnly);
}, "The expected exception has not been caugth: The number of sequences is 0");
}, "The expected exception has not been caught: The number of sequences is 0");
}
template <typename ElementType>
@ -928,7 +928,7 @@ void CreateSequenceTestOneHot(const DeviceDescriptor device, bool readOnly)
auto emptySeq = vector<size_t>(0);
VerifyException([&dimSize, &emptySeq, &device, &readOnly]() {
Value::CreateSequence<ElementType>(dimSize, emptySeq, device, readOnly);
}, "The expected exception has not been caugth: The sequences length is 0");
}, "The expected exception has not been caught: The sequences length is 0");
}

Просмотреть файл

@ -96,6 +96,7 @@
<Compile Include="SwigProxyClasses\Variable.cs" />
<Compile Include="SwigProxyClasses\VariableKind.cs" />
<Compile Include="SwigProxyClasses\VariableVector.cs" />
<Compile Include="SwigProxyClasses\FunctionPtrVector.cs" />
</ItemGroup>
<Import Project="$(MSBuildToolsPath)\Microsoft.CSharp.targets" />
<Target Name="Build" Condition="Exists('$(SWIG_PATH)\swig.exe')" DependsOnTargets="$(BuildDependsOn)" />

Просмотреть файл

@ -42,6 +42,7 @@
%template(DeviceDescriptorVector) std::vector<CNTK::DeviceDescriptor>;
%template(UnorderedMapVariableValuePtr) std::unordered_map<CNTK::Variable, std::shared_ptr<CNTK::Value>>;
%template(UnorderedMapVariableVariable) std::unordered_map<CNTK::Variable, CNTK::Variable>;
%template(FunctionPtrVector) std::vector<std::shared_ptr<CNTK::Function>>;
%template() std::vector<bool>;
%template() std::pair<size_t, double>;