Add UTs for device selection API

This commit is contained in:
Alexey Reznichenko 2016-09-27 16:05:53 +02:00
Родитель 2e2cb6430f
Коммит 2f53509644
8 изменённых файлов: 63 добавлений и 21 удалений

Просмотреть файл

@ -422,6 +422,7 @@ CNTKLIBRARY_TESTS_SRC =\
Tests/UnitTests/V2LibraryTests/FunctionTests.cpp \
Tests/UnitTests/V2LibraryTests/SequenceClassification.cpp \
Tests/UnitTests/V2LibraryTests/Seq2Seq.cpp \
Tests/UnitTests/V2LibraryTests/DeviceSelectionTests.cpp \
Examples/Evaluation/CPPEvalV2Client/EvalMultithreads.cpp \
CNTKLIBRARY_TESTS:=$(BINDIR)/v2librarytests

Просмотреть файл

@ -5,6 +5,7 @@
#include "stdafx.h"
#include "CNTKLibrary.h"
#include "Utils.h"
#include "BestGpu.h"
#include <mutex>
#include <algorithm>
@ -62,8 +63,7 @@ namespace CNTK
auto selectedDevice = DefaultDevice();
if (!alreadyFrozen)
{
auto id = selectedDevice.Type() == DeviceKind::CPU ? CPUDEVICE : selectedDevice.Id();
Microsoft::MSR::CNTK::OnDeviceSelected(id);
Microsoft::MSR::CNTK::OnDeviceSelected(AsCNTKImplDeviceId(selectedDevice));
}
return selectedDevice;
}

Просмотреть файл

@ -32,7 +32,7 @@ namespace CNTK
inline DEVICEID_TYPE AsCNTKImplDeviceId(const DeviceDescriptor& device)
{
if (device.Type() == DeviceKind::CPU)
return -1;
return CPUDEVICE;
else if (device.Type() == DeviceKind::GPU)
return device.Id();
else

Просмотреть файл

@ -0,0 +1,46 @@
#include "CNTKLibrary.h"
#include "Common.h"
using namespace CNTK;
void DeviceSelectionTests()
{
auto cpuDevice = DeviceDescriptor::CPUDevice();
DeviceDescriptor::SetDefaultDevice(cpuDevice);
assert(DeviceDescriptor::DefaultDevice() == cpuDevice);
auto bestDevice = DeviceDescriptor::BestDevice();
DeviceDescriptor::SetDefaultDevice(bestDevice);
assert(DeviceDescriptor::DefaultDevice() == bestDevice);
if (bestDevice != cpuDevice)
{
DeviceDescriptor::SetDefaultDevice(cpuDevice);
}
assert(DeviceDescriptor::UseDefaultDevice() == cpuDevice);
VerifyException([&cpuDevice]() {
DeviceDescriptor::SetDefaultDevice(cpuDevice);
}, "Was able to invoke SetDefaultDevice() after UseDefaultDevice().");
// Invoke BestDevice after releasing the lock in UseDefaultDevice().
bestDevice = DeviceDescriptor::BestDevice();
const auto& allDevices = DeviceDescriptor::AllDevices();
#ifdef CPUONLY
assert(allDevices.size() == 1);
#endif
auto numGpuDevices = allDevices.size() - 1;
VerifyException([&numGpuDevices]() {
DeviceDescriptor::GPUDevice((unsigned int)numGpuDevices);
}, "Was able to create GPU device descriptor with invalid id.");
assert(find(allDevices.begin(), allDevices.end(), bestDevice) != allDevices.end());
assert(allDevices.back() == cpuDevice);
}

Просмотреть файл

@ -15,6 +15,7 @@ void SerializationTests();
void LearnerTests();
void TrainSequenceToSequenceTranslator();
void EvalMultiThreadsWithNewNetwork(const DeviceDescriptor&, const int);
void DeviceSelectionTests();
int main()
{
@ -42,6 +43,8 @@ int main()
EvalMultiThreadsWithNewNetwork(DeviceDescriptor::GPUDevice(0), 2);
#endif
DeviceSelectionTests();
fprintf(stderr, "\nCNTKv2Library tests: Passed\n");
fflush(stderr);
}

Просмотреть файл

@ -1,4 +1,5 @@
#include "CNTKLibrary.h"
#include "Common.h"
#include <functional>
#include <array>
@ -103,31 +104,18 @@ void TestNDArrayView(size_t numAxes, const DeviceDescriptor& device)
throw std::runtime_error("The buffers underlying the alias view and the view it is an alias of are different!");
// Test readonliness
auto verifyException = [](const std::function<void()>& functionToTest) {
bool error = false;
try
{
functionToTest();
}
catch (const std::exception&)
{
error = true;
}
if (!error)
throw std::runtime_error("Was incorrectly able to get a writable buffer pointer from a readonly view");
};
auto errorMsg = "Was incorrectly able to get a writable buffer pointer from a readonly view";
// Should not be able to get the WritableDataBuffer for a read-only view
verifyException([&aliasView]() {
VerifyException([&aliasView]() {
ElementType* aliasViewBuffer = aliasView->WritableDataBuffer<ElementType>();
aliasViewBuffer;
});
}, errorMsg);
// Should not be able to copy into a read-only view
verifyException([&aliasView, &dataView]() {
VerifyException([&aliasView, &dataView]() {
aliasView->CopyFrom(*dataView);
});
}, errorMsg);
}
template <typename ElementType>

Просмотреть файл

@ -111,6 +111,7 @@
<ItemGroup>
<ClCompile Include="..\..\..\Examples\Evaluation\CPPEvalV2Client\EvalMultithreads.cpp" />
<ClCompile Include="CifarResNet.cpp" />
<ClCompile Include="DeviceSelectionTests.cpp" />
<ClCompile Include="LearnerTests.cpp" />
<ClCompile Include="Seq2Seq.cpp" />
<ClCompile Include="SerializationTests.cpp" />

Просмотреть файл

@ -54,6 +54,9 @@
<ClCompile Include="..\..\..\Examples\Evaluation\CPPEvalV2Client\EvalMultithreads.cpp">
<Filter>Source Files</Filter>
</ClCompile>
<ClCompile Include="DeviceSelectionTests.cpp">
<Filter>Source Files</Filter>
</ClCompile>
</ItemGroup>
<ItemGroup>
<ClInclude Include="Common.h">