add Function::FindByName(), Function::FindAllWithName; Add tests;
This commit is contained in:
Родитель
5b8d122a4b
Коммит
8176e0fccf
|
@ -2799,7 +2799,7 @@ namespace CNTK
|
|||
}
|
||||
|
||||
///
|
||||
/// Returns the set of all Constant variables of 'this' Function.
|
||||
/// Returns the set of all Placeholder variables of 'this' Function.
|
||||
///
|
||||
std::vector<Variable> Placeholders() const
|
||||
{
|
||||
|
@ -2809,6 +2809,41 @@ namespace CNTK
|
|||
}
|
||||
|
||||
///
|
||||
/// Find a function with the given name in the Function graph underlying 'this' Function.
|
||||
/// If more than one function with the same name, an exception is thrown.
|
||||
/// If nestedSearchInsideBlockFunction is true, all functions inside block functions are also searched for the given name.
|
||||
///
|
||||
FunctionPtr FindByName(const std::wstring& name, bool nestedSearchInsideBlockFunction = false)
|
||||
{
|
||||
FunctionPtr foundFunction = nullptr;
|
||||
PreorderTraverseFunctions(RootFunction(), [&foundFunction, &name](const FunctionPtr& function) {
|
||||
if (name.compare(function->Name()) == 0)
|
||||
{
|
||||
if (foundFunction != nullptr)
|
||||
RuntimeError("Multiple functions with the same name are found in the Function graph underlying 'this' Function.");
|
||||
else
|
||||
foundFunction = function;
|
||||
}
|
||||
}, nestedSearchInsideBlockFunction);
|
||||
|
||||
return foundFunction;
|
||||
}
|
||||
|
||||
///
|
||||
/// Find a list of functions with the given name in the Function graph underlying 'this' Function.
|
||||
/// If nestedSearchInsideBlockFunction is true, all functions inside block functions are also searched for the given name.
|
||||
///
|
||||
std::vector<FunctionPtr> FindAllWithName(const std::wstring& name, bool nestedSearchInsideBlockFunction = false)
|
||||
{
|
||||
std::vector<FunctionPtr> foundFunctions;
|
||||
PreorderTraverseFunctions(RootFunction(), [&foundFunctions, &name](const FunctionPtr& function) {
|
||||
if (name.compare(function->Name()) == 0)
|
||||
foundFunctions.push_back(function);
|
||||
}, nestedSearchInsideBlockFunction);
|
||||
|
||||
return foundFunctions;
|
||||
}
|
||||
|
||||
/// Returns the dictionary of attributes of 'this' Function
|
||||
///
|
||||
const Dictionary& Attributes() const { return m_attributes; }
|
||||
|
@ -2855,6 +2890,38 @@ namespace CNTK
|
|||
///
|
||||
CNTK_API Function(const std::vector<Variable>& inputs, Dictionary&& functionConfig, const std::wstring& name = L"", const std::wstring& uid = Internal::GenerateUid(L"UserDefinedFunction"));
|
||||
|
||||
template <typename FunctionType>
|
||||
static void PreorderTraverseFunctions(const FunctionPtr& rootFunction, const FunctionType& functor, bool traverseInsideBlockFunction = false)
|
||||
{
|
||||
std::unordered_set<FunctionPtr> visitedFunctions;
|
||||
PreorderTraverseFunctions(rootFunction, visitedFunctions, functor, traverseInsideBlockFunction);
|
||||
}
|
||||
|
||||
// Recursively traverses the Function graph underlying the 'rootFunction' invoking the provided functor for all visited nodes in the graph.
|
||||
template <typename FunctionType>
|
||||
static void PreorderTraverseFunctions(const FunctionPtr& rootFunction, std::unordered_set<FunctionPtr>& visitedFunctions, const FunctionType& functor, bool traverseInsideBlockFunction = false)
|
||||
{
|
||||
visitedFunctions.insert(rootFunction);
|
||||
functor(rootFunction);
|
||||
|
||||
if (traverseInsideBlockFunction && rootFunction->IsBlock())
|
||||
{
|
||||
PreorderTraverseFunctions(rootFunction->BlockRoot(), visitedFunctions, functor, traverseInsideBlockFunction);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::vector<Variable> rootFunctionInputs = rootFunction->Inputs();
|
||||
for (const auto& rootInput : rootFunctionInputs)
|
||||
{
|
||||
if (rootInput.IsOutput() && visitedFunctions.find(rootInput.Owner()) == visitedFunctions.end())
|
||||
{
|
||||
const auto& function = rootInput.Owner();
|
||||
PreorderTraverseFunctions(function, visitedFunctions, functor, traverseInsideBlockFunction);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Restores the state of the 'this' Function in place using the provided dictionary.
|
||||
/// Structurally, 'this' Function graph has to be identical to the state captured in the dictionary.
|
||||
CNTK_API virtual void RestoreFromCheckpoint(const Dictionary& dictionary);
|
||||
|
|
|
@ -121,31 +121,6 @@ namespace CNTK
|
|||
return CompositeFunctionOpName;
|
||||
}
|
||||
|
||||
template <typename FunctionType>
|
||||
static void PreorderTraverseFunctions(const FunctionPtr& rootFunction, const FunctionType& functor)
|
||||
{
|
||||
std::unordered_set<FunctionPtr> visitedFunctions;
|
||||
PreorderTraverseFunctions(rootFunction, visitedFunctions, functor);
|
||||
}
|
||||
|
||||
// Recursively traverses the Function graph underlying the 'rootFunction' invoking the provided functor for all visited nodes in the graph.
|
||||
template <typename FunctionType>
|
||||
static void PreorderTraverseFunctions(const FunctionPtr& rootFunction, std::unordered_set<FunctionPtr>& visitedFunctions, const FunctionType& functor)
|
||||
{
|
||||
visitedFunctions.insert(rootFunction);
|
||||
functor(rootFunction);
|
||||
|
||||
std::vector<Variable> rootFunctionInputs = rootFunction->Inputs();
|
||||
for (const auto& rootInput : rootFunctionInputs)
|
||||
{
|
||||
if (rootInput.IsOutput() && visitedFunctions.find(rootInput.Owner()) == visitedFunctions.end())
|
||||
{
|
||||
const auto& function = rootInput.Owner();
|
||||
PreorderTraverseFunctions(function, visitedFunctions, functor);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename FunctionType>
|
||||
static void PreorderTraverseVariables(const FunctionPtr& rootFunction, const FunctionType& functor, bool pythonOperandOrder = false)
|
||||
{
|
||||
|
|
|
@ -774,12 +774,179 @@ void TestOuputVariableName(const DeviceDescriptor& device)
|
|||
static_cast<unsigned long>(output->Output().Shape().TotalSize()));
|
||||
}
|
||||
|
||||
void CheckFindByNameResult(FunctionPtr actual, FunctionPtr expected)
|
||||
{
|
||||
if (actual == nullptr)
|
||||
{
|
||||
if (expected != nullptr)
|
||||
ReportFailure("The expected function '%S' has not been found.", expected->Name().c_str());
|
||||
}
|
||||
else
|
||||
{
|
||||
if (expected == nullptr)
|
||||
ReportFailure("Found a function '%S', but null is expected.", actual->Name().c_str());
|
||||
else if (expected->Name().compare(actual->Name()) != 0)
|
||||
ReportFailure("The found function '%S' does have the same name as the exepected one '%S'", actual->Name().c_str(), expected->Name().c_str());
|
||||
}
|
||||
}
|
||||
|
||||
void CheckFindAllWithNameResult(std::vector<FunctionPtr> actual, std::wstring expectedName, size_t expectedSize)
|
||||
{
|
||||
if (actual.size() != expectedSize)
|
||||
ReportFailure("The number of found functions does not match the expected number.");
|
||||
else
|
||||
{
|
||||
for (size_t i = 0; i < actual.size(); i++)
|
||||
{
|
||||
if (actual[i]->Name().compare(expectedName) != 0)
|
||||
ReportFailure("The found function '%S' does have the same name as the exepected one '%S'", actual[i]->Name().c_str(), expectedName.c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void TestFindName()
|
||||
{
|
||||
size_t inputDim = 10;
|
||||
size_t outputDim = 20;
|
||||
const std::wstring timesFuncName = L"TimesFunc";
|
||||
const std::wstring plusFuncName = L"PlusFunc";
|
||||
const std::wstring anotherPlusFuncName = L"AnotherPlusFunc";
|
||||
const std::wstring minusFuncName = L"MinusFunc";
|
||||
const std::wstring anotherMinusFuncName = L"AnotherMinusFunc";
|
||||
const std::wstring blockFuncName = L"BlockFunc";
|
||||
const std::wstring nonExistingFuncName = L"NonExistingFunc";
|
||||
const std::wstring nestedBlockFuncName = L"NestedBlockFunc";
|
||||
const std::wstring emptyFuncName = L"";
|
||||
const std::wstring placeholderName = L"inputPlaceholder";
|
||||
const std::wstring variableName = L"features";
|
||||
const std::wstring aliasFuncName = L"aliasFunc";
|
||||
|
||||
auto inputVar1 = InputVariable({ inputDim }, DataType::Float, L"features");
|
||||
|
||||
auto inputPlaceholder1 = PlaceholderVariable(L"inputPlaceholder");
|
||||
auto timesParam = CNTK::Parameter(CNTK::NDArrayView::RandomUniform<float>({ outputDim, inputDim }, -0.05, 0.05, 1, DeviceDescriptor::DefaultDevice()));
|
||||
auto timesFunc1 = CNTK::Times(timesParam, inputPlaceholder1, timesFuncName);
|
||||
auto plusFunc1 = CNTK::Plus(Constant::Scalar(2.0f), timesFunc1, plusFuncName);
|
||||
auto plusFunc2 = CNTK::Plus(Constant::Scalar(2.0f), plusFunc1, plusFuncName);
|
||||
auto emptyNameFunc1 = CNTK::Plus(plusFunc1, plusFunc2);
|
||||
auto minusFunc1 = CNTK::Minus(plusFunc2, emptyNameFunc1, minusFuncName);
|
||||
|
||||
// Test FindByName for the case without any block function
|
||||
CheckFindByNameResult(minusFunc1->FindByName(timesFuncName), timesFunc1);
|
||||
CheckFindByNameResult(minusFunc1->FindByName(minusFuncName), minusFunc1);
|
||||
CheckFindByNameResult(minusFunc1->FindByName(emptyFuncName), emptyNameFunc1);
|
||||
CheckFindByNameResult(minusFunc1->FindByName(nonExistingFuncName), nullptr);
|
||||
CheckFindByNameResult(minusFunc1->FindByName(placeholderName), nullptr);
|
||||
VerifyException([&minusFunc1, &plusFuncName]() {
|
||||
minusFunc1->FindByName(plusFuncName);
|
||||
}, "The expected exception has not been caugth: multiple functions with the same name.");
|
||||
|
||||
// Test FindAllWithName for the case without any block function
|
||||
CheckFindAllWithNameResult(minusFunc1->FindAllWithName(timesFuncName), timesFuncName, 1);
|
||||
CheckFindAllWithNameResult(minusFunc1->FindAllWithName(minusFuncName), minusFuncName, 1);
|
||||
CheckFindAllWithNameResult(minusFunc1->FindAllWithName(emptyFuncName), emptyFuncName, 1);
|
||||
CheckFindAllWithNameResult(minusFunc1->FindAllWithName(nonExistingFuncName), nonExistingFuncName, 0);
|
||||
CheckFindAllWithNameResult(minusFunc1->FindAllWithName(placeholderName), placeholderName, 0);
|
||||
CheckFindAllWithNameResult(minusFunc1->FindAllWithName(plusFuncName), plusFuncName, 2);
|
||||
|
||||
// Build a block function
|
||||
auto blockFunc = CNTK::AsBlock(std::move(minusFunc1), { { inputPlaceholder1, inputVar1 } }, L"TimesPlusMinus", blockFuncName);
|
||||
|
||||
// Build a nested block function
|
||||
auto inputPlaceholder2 = PlaceholderVariable(L"inputPlaceholder");
|
||||
auto inputPlaceholder3 = PlaceholderVariable(L"inputPlaceholder");
|
||||
auto inputVar2 = InputVariable({ outputDim }, DataType::Float, L"features");
|
||||
auto anotherMinusFunc1 = CNTK::Minus(inputPlaceholder2, Constant::Scalar(3.0f), anotherMinusFuncName);
|
||||
auto plusFunc3 = CNTK::Plus(Constant::Scalar(3.0f), anotherMinusFunc1, plusFuncName);
|
||||
auto cloneBlockFunc = blockFunc->Clone(ParameterCloningMethod::Clone, { { inputVar1, inputPlaceholder3 } });
|
||||
auto minusFunc2 = CNTK::Minus(cloneBlockFunc, plusFunc3, minusFuncName);
|
||||
auto plusFunc4 = CNTK::Plus(minusFunc2, Constant::Scalar(3.0f), plusFuncName);
|
||||
auto nestedBlockFunc = CNTK::AsBlock(std::move(plusFunc4), { { inputPlaceholder2, inputVar2 },{ inputPlaceholder3, inputVar1 } }, L"NestedBlock", nestedBlockFuncName);
|
||||
|
||||
// Build a function having both block and nested block functions
|
||||
auto inputVar3 = InputVariable({ outputDim }, DataType::Float, variableName);
|
||||
auto plusFunc5 = CNTK::Plus(inputVar3, blockFunc, plusFuncName);
|
||||
auto anotherPlusFunc1 = CNTK::Plus(plusFunc5, nestedBlockFunc, anotherPlusFuncName);
|
||||
auto minusFunc3 = CNTK::Minus(anotherPlusFunc1, Constant::Scalar(3.0f), minusFuncName);
|
||||
|
||||
// Test FindByName with block functions, nestedSearchInsideBlockFunction is false.
|
||||
CheckFindByNameResult(minusFunc3->FindByName(anotherPlusFuncName), anotherPlusFunc1);
|
||||
CheckFindByNameResult(minusFunc3->FindByName(anotherMinusFuncName), nullptr);
|
||||
CheckFindByNameResult(minusFunc3->FindByName(nonExistingFuncName), nullptr);
|
||||
CheckFindByNameResult(minusFunc3->FindByName(variableName), nullptr);
|
||||
CheckFindByNameResult(minusFunc3->FindByName(nestedBlockFuncName), nestedBlockFunc);
|
||||
CheckFindByNameResult(minusFunc3->FindByName(plusFuncName), plusFunc5);
|
||||
CheckFindByNameResult(minusFunc3->FindByName(timesFuncName), nullptr);
|
||||
CheckFindByNameResult(minusFunc3->FindByName(emptyFuncName), nullptr);
|
||||
CheckFindByNameResult(minusFunc3->FindByName(minusFuncName), minusFunc3);
|
||||
CheckFindByNameResult(minusFunc3->FindByName(blockFuncName), blockFunc);
|
||||
|
||||
// Test FindByName with block funcitons, nestedSearchInsideBlockFunction is true
|
||||
CheckFindByNameResult(minusFunc3->FindByName(anotherPlusFuncName, true), anotherPlusFunc1);
|
||||
CheckFindByNameResult(minusFunc3->FindByName(anotherMinusFuncName, true), anotherMinusFunc1);
|
||||
CheckFindByNameResult(minusFunc3->FindByName(nonExistingFuncName, true), nullptr);
|
||||
CheckFindByNameResult(minusFunc3->FindByName(variableName, true), nullptr);
|
||||
CheckFindByNameResult(minusFunc3->FindByName(nestedBlockFuncName, true), nestedBlockFunc);
|
||||
VerifyException([&minusFunc3, &plusFuncName]() {
|
||||
minusFunc3->FindByName(plusFuncName, true);
|
||||
}, "The expected exception has not been caugth: multiple functions with the same name.");
|
||||
VerifyException([&minusFunc3, ×FuncName]() {
|
||||
minusFunc3->FindByName(timesFuncName, true);
|
||||
}, "The expected exception has not been caugth: multiple functions with the same name.");
|
||||
VerifyException([&minusFunc3, &emptyFuncName]() {
|
||||
minusFunc3->FindByName(emptyFuncName, true);
|
||||
}, "The expected exception has not been caugth: multiple functions with the same name.");
|
||||
VerifyException([&minusFunc3, &minusFuncName]() {
|
||||
minusFunc3->FindByName(minusFuncName, true);
|
||||
}, "The expected exception has not been caugth: multiple functions with the same name.");
|
||||
VerifyException([&minusFunc3, &blockFuncName]() {
|
||||
minusFunc3->FindByName(blockFuncName, true);
|
||||
}, "The expected exception has not been caugth: multiple functions with the same name.");
|
||||
|
||||
// Test FindAllWithName with block functions, nestedSearchInsideBlockFunction is false.
|
||||
CheckFindAllWithNameResult(minusFunc3->FindAllWithName(anotherPlusFuncName), anotherPlusFuncName, 1);
|
||||
CheckFindAllWithNameResult(minusFunc3->FindAllWithName(anotherMinusFuncName), anotherMinusFuncName, 0);
|
||||
CheckFindAllWithNameResult(minusFunc3->FindAllWithName(nonExistingFuncName), nonExistingFuncName, 0);
|
||||
CheckFindAllWithNameResult(minusFunc3->FindAllWithName(variableName), variableName, 0);
|
||||
CheckFindAllWithNameResult(minusFunc3->FindAllWithName(nestedBlockFuncName), nestedBlockFuncName, 1);
|
||||
CheckFindAllWithNameResult(minusFunc3->FindAllWithName(plusFuncName), plusFuncName, 1);
|
||||
CheckFindAllWithNameResult(minusFunc3->FindAllWithName(timesFuncName), timesFuncName, 0);
|
||||
CheckFindAllWithNameResult(minusFunc3->FindAllWithName(emptyFuncName), emptyFuncName, 0);
|
||||
CheckFindAllWithNameResult(minusFunc3->FindAllWithName(minusFuncName), minusFuncName, 1);
|
||||
CheckFindAllWithNameResult(minusFunc3->FindAllWithName(blockFuncName), blockFuncName, 1);
|
||||
|
||||
// Test FindAllWithName with block funcitons, nestedSearchInsideBlockFunction is true
|
||||
CheckFindAllWithNameResult(minusFunc3->FindAllWithName(anotherPlusFuncName, true), anotherPlusFuncName, 1);
|
||||
CheckFindAllWithNameResult(minusFunc3->FindAllWithName(anotherMinusFuncName, true), anotherMinusFuncName, 1);
|
||||
CheckFindAllWithNameResult(minusFunc3->FindAllWithName(nonExistingFuncName, true), nonExistingFuncName, 0);
|
||||
CheckFindAllWithNameResult(minusFunc3->FindAllWithName(variableName, true), variableName, 0);
|
||||
CheckFindAllWithNameResult(minusFunc3->FindAllWithName(nestedBlockFuncName, true), nestedBlockFuncName, 1);
|
||||
CheckFindAllWithNameResult(minusFunc3->FindAllWithName(plusFuncName, true), plusFuncName, 7);
|
||||
CheckFindAllWithNameResult(minusFunc3->FindAllWithName(timesFuncName, true), timesFuncName, 2);
|
||||
CheckFindAllWithNameResult(minusFunc3->FindAllWithName(emptyFuncName,true), emptyFuncName, 2);
|
||||
CheckFindAllWithNameResult(minusFunc3->FindAllWithName(minusFuncName, true), minusFuncName, 4);
|
||||
CheckFindAllWithNameResult(minusFunc3->FindAllWithName(blockFuncName, true), blockFuncName, 2);
|
||||
|
||||
// Test alias
|
||||
auto aliasFunc1 = Alias(anotherPlusFunc1, aliasFuncName);
|
||||
// The Alias does not really create an alias for the function, but indeed create a new function having alias as name.
|
||||
// The new function is not a part of existing graph, except it is explicitly referenced in the graph.
|
||||
// TODO: change the tests when Alias is a real alias of a function.
|
||||
CheckFindByNameResult(minusFunc3->FindByName(aliasFuncName), nullptr);
|
||||
CheckFindAllWithNameResult(minusFunc3->FindAllWithName(aliasFuncName, true), aliasFuncName, 0);
|
||||
auto minusFunc4 = CNTK::Minus(aliasFunc1, minusFunc3, minusFuncName);
|
||||
CheckFindByNameResult(minusFunc4->FindByName(aliasFuncName), aliasFunc1);
|
||||
CheckFindAllWithNameResult(minusFunc4->FindAllWithName(aliasFuncName, true), aliasFuncName, 1);
|
||||
}
|
||||
|
||||
void FunctionTests()
|
||||
{
|
||||
fprintf(stderr, "\nFunctionTests..\n");
|
||||
|
||||
TestSplice();
|
||||
|
||||
TestFindName();
|
||||
|
||||
TestChangingParameterValues<float>(2, DeviceDescriptor::CPUDevice());
|
||||
if (IsGPUAvailable())
|
||||
TestChangingParameterValues<double>(3, DeviceDescriptor::GPUDevice(0));
|
||||
|
|
|
@ -643,35 +643,35 @@ void ValueCopyToExceptionsTest(const DeviceDescriptor& device)
|
|||
auto sampleVariable = CreateVariable<float>(NDShape::Unknown, 0);
|
||||
VerifyException([&val, &sampleVariable, &output]() {
|
||||
val->CopyVariableValueTo(sampleVariable, output);
|
||||
}, "The expected exception has not been caugth: It is not supported that the outputVariable has a unknown shape or inferred dimension.");
|
||||
}, "The expected exception has not been caught: It is not supported that the outputVariable has a unknown shape or inferred dimension.");
|
||||
|
||||
// Test variable having shape with InferredDimentsion.
|
||||
sampleVariable = CreateVariable<float>(NDShape(2), 0);
|
||||
VerifyException([&val, &sampleVariable, &output]() {
|
||||
val->CopyVariableValueTo(sampleVariable, output);
|
||||
}, "The expected exception has not been caugth: It is not supported that the outputVariable has a unknown shape or inferred dimension.");
|
||||
}, "The expected exception has not been caught: It is not supported that the outputVariable has a unknown shape or inferred dimension.");
|
||||
|
||||
// Test variable having incorrect data type.
|
||||
sampleVariable = CreateVariable<double>(sampleShape, 0);
|
||||
VerifyException([&val, &sampleVariable, &output]() {
|
||||
val->CopyVariableValueTo(sampleVariable, output);
|
||||
}, "The expected exception has not been caugth: The outputVariable has a different data type than the Value object.");
|
||||
}, "The expected exception has not been caught: The outputVariable has a different data type than the Value object.");
|
||||
|
||||
sampleVariable = CreateVariable<double>(sampleOneHotShape, 0);
|
||||
VerifyException([&val, &sampleVariable, &outputInOneHot]() {
|
||||
val->CopyVariableValueTo(sampleVariable, outputInOneHot);
|
||||
}, "The expected exception has not been caugth: The outputVariable has a different data type than the Value object.");
|
||||
}, "The expected exception has not been caught: The outputVariable has a different data type than the Value object.");
|
||||
|
||||
// Test output buffer having incorrect data type.
|
||||
sampleVariable = CreateVariable<float>(sampleShape, 0);
|
||||
VerifyException([&val, &sampleVariable, &outputInDouble]() {
|
||||
val->CopyVariableValueTo(sampleVariable, outputInDouble);
|
||||
}, "The expected exception has not been caugth: The specified ElementType Double does not match the DataType Float");
|
||||
}, "The expected exception has not been caught: The specified ElementType Double does not match the DataType Float");
|
||||
|
||||
// Test the first axis when using one-hot format.
|
||||
VerifyException([&val, &sampleVariable, &outputInOneHot]() {
|
||||
val->CopyVariableValueTo(sampleVariable, outputInOneHot);
|
||||
}, "The expected exception has not been caugth: The outputVariable's leading axis dimensionality must equal the total size of the variable for sparse data.");
|
||||
}, "The expected exception has not been caught: The outputVariable's leading axis dimensionality must equal the total size of the variable for sparse data.");
|
||||
}
|
||||
|
||||
|
||||
|
@ -760,12 +760,12 @@ void CreateBatchTestDense(const DeviceDescriptor device, bool readOnly)
|
|||
vector<ElementType> wrongBatch(sampleSize * 2 - 1, 0);
|
||||
VerifyException([&sampleShape, &wrongBatch, &device, &readOnly]() {
|
||||
Value::CreateBatch(sampleShape, wrongBatch, device, readOnly);
|
||||
}, "The expected exception has not been caugth: The number of data is not a multiple of the sample size.");
|
||||
}, "The expected exception has not been caught: The number of data is not a multiple of the sample size.");
|
||||
|
||||
auto emptyBatch = vector<ElementType>(0);
|
||||
VerifyException([&sampleShape, &emptyBatch, &device, &readOnly]() {
|
||||
Value::CreateBatch(sampleShape, emptyBatch, device, readOnly);
|
||||
}, "The expected exception has not been caugth: The number of sequences is 0");
|
||||
}, "The expected exception has not been caught: The number of sequences is 0");
|
||||
}
|
||||
|
||||
template <typename ElementType>
|
||||
|
@ -803,12 +803,12 @@ void CreateSequenceTestDense(const DeviceDescriptor device, bool readOnly)
|
|||
vector<ElementType> wrongSeq(sampleSize * 2 - 1, 0);
|
||||
VerifyException([&sampleShape, &wrongSeq, &device, &readOnly]() {
|
||||
Value::CreateSequence(sampleShape, wrongSeq, device, readOnly);
|
||||
}, "The expected exception has not been caugth: The number of data is not a multiple of the sample size.");
|
||||
}, "The expected exception has not been caught: The number of data is not a multiple of the sample size.");
|
||||
|
||||
auto emptySeq = vector<ElementType>(0);
|
||||
VerifyException([&sampleShape, &emptySeq, &device, &readOnly]() {
|
||||
Value::CreateSequence(sampleShape, emptySeq, device, readOnly);
|
||||
}, "The expected exception has not been caugth: The sequence length is 0");
|
||||
}, "The expected exception has not been caught: The sequence length is 0");
|
||||
}
|
||||
|
||||
|
||||
|
@ -889,7 +889,7 @@ void CreateBatchTestOneHot(const DeviceDescriptor device, bool readOnly)
|
|||
auto emptyBatch = vector<size_t>(0);
|
||||
VerifyException([&dimSize, &emptyBatch, &device, &readOnly]() {
|
||||
Value::CreateBatch<ElementType>(dimSize, emptyBatch, device, readOnly);
|
||||
}, "The expected exception has not been caugth: The number of sequences is 0");
|
||||
}, "The expected exception has not been caught: The number of sequences is 0");
|
||||
}
|
||||
|
||||
template <typename ElementType>
|
||||
|
@ -928,7 +928,7 @@ void CreateSequenceTestOneHot(const DeviceDescriptor device, bool readOnly)
|
|||
auto emptySeq = vector<size_t>(0);
|
||||
VerifyException([&dimSize, &emptySeq, &device, &readOnly]() {
|
||||
Value::CreateSequence<ElementType>(dimSize, emptySeq, device, readOnly);
|
||||
}, "The expected exception has not been caugth: The sequences length is 0");
|
||||
}, "The expected exception has not been caught: The sequences length is 0");
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -96,6 +96,7 @@
|
|||
<Compile Include="SwigProxyClasses\Variable.cs" />
|
||||
<Compile Include="SwigProxyClasses\VariableKind.cs" />
|
||||
<Compile Include="SwigProxyClasses\VariableVector.cs" />
|
||||
<Compile Include="SwigProxyClasses\FunctionPtrVector.cs" />
|
||||
</ItemGroup>
|
||||
<Import Project="$(MSBuildToolsPath)\Microsoft.CSharp.targets" />
|
||||
<Target Name="Build" Condition="Exists('$(SWIG_PATH)\swig.exe')" DependsOnTargets="$(BuildDependsOn)" />
|
||||
|
|
|
@ -42,6 +42,7 @@
|
|||
%template(DeviceDescriptorVector) std::vector<CNTK::DeviceDescriptor>;
|
||||
%template(UnorderedMapVariableValuePtr) std::unordered_map<CNTK::Variable, std::shared_ptr<CNTK::Value>>;
|
||||
%template(UnorderedMapVariableVariable) std::unordered_map<CNTK::Variable, CNTK::Variable>;
|
||||
%template(FunctionPtrVector) std::vector<std::shared_ptr<CNTK::Function>>;
|
||||
|
||||
%template() std::vector<bool>;
|
||||
%template() std::pair<size_t, double>;
|
||||
|
|
Загрузка…
Ссылка в новой задаче