Add UTs for device selection API
This commit is contained in:
Родитель
2e2cb6430f
Коммит
2f53509644
1
Makefile
1
Makefile
|
@ -422,6 +422,7 @@ CNTKLIBRARY_TESTS_SRC =\
|
|||
Tests/UnitTests/V2LibraryTests/FunctionTests.cpp \
|
||||
Tests/UnitTests/V2LibraryTests/SequenceClassification.cpp \
|
||||
Tests/UnitTests/V2LibraryTests/Seq2Seq.cpp \
|
||||
Tests/UnitTests/V2LibraryTests/DeviceSelectionTests.cpp \
|
||||
Examples/Evaluation/CPPEvalV2Client/EvalMultithreads.cpp \
|
||||
|
||||
CNTKLIBRARY_TESTS:=$(BINDIR)/v2librarytests
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
|
||||
#include "stdafx.h"
|
||||
#include "CNTKLibrary.h"
|
||||
#include "Utils.h"
|
||||
#include "BestGpu.h"
|
||||
#include <mutex>
|
||||
#include <algorithm>
|
||||
|
@ -62,8 +63,7 @@ namespace CNTK
|
|||
auto selectedDevice = DefaultDevice();
|
||||
if (!alreadyFrozen)
|
||||
{
|
||||
auto id = selectedDevice.Type() == DeviceKind::CPU ? CPUDEVICE : selectedDevice.Id();
|
||||
Microsoft::MSR::CNTK::OnDeviceSelected(id);
|
||||
Microsoft::MSR::CNTK::OnDeviceSelected(AsCNTKImplDeviceId(selectedDevice));
|
||||
}
|
||||
return selectedDevice;
|
||||
}
|
||||
|
|
|
@ -32,7 +32,7 @@ namespace CNTK
|
|||
inline DEVICEID_TYPE AsCNTKImplDeviceId(const DeviceDescriptor& device)
|
||||
{
|
||||
if (device.Type() == DeviceKind::CPU)
|
||||
return -1;
|
||||
return CPUDEVICE;
|
||||
else if (device.Type() == DeviceKind::GPU)
|
||||
return device.Id();
|
||||
else
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
#include "CNTKLibrary.h"
|
||||
#include "Common.h"
|
||||
|
||||
using namespace CNTK;
|
||||
|
||||
void DeviceSelectionTests()
|
||||
{
|
||||
auto cpuDevice = DeviceDescriptor::CPUDevice();
|
||||
DeviceDescriptor::SetDefaultDevice(cpuDevice);
|
||||
|
||||
assert(DeviceDescriptor::DefaultDevice() == cpuDevice);
|
||||
|
||||
auto bestDevice = DeviceDescriptor::BestDevice();
|
||||
DeviceDescriptor::SetDefaultDevice(bestDevice);
|
||||
|
||||
assert(DeviceDescriptor::DefaultDevice() == bestDevice);
|
||||
|
||||
if (bestDevice != cpuDevice)
|
||||
{
|
||||
DeviceDescriptor::SetDefaultDevice(cpuDevice);
|
||||
}
|
||||
|
||||
assert(DeviceDescriptor::UseDefaultDevice() == cpuDevice);
|
||||
|
||||
VerifyException([&cpuDevice]() {
|
||||
DeviceDescriptor::SetDefaultDevice(cpuDevice);
|
||||
}, "Was able to invoke SetDefaultDevice() after UseDefaultDevice().");
|
||||
|
||||
// Invoke BestDevice after releasing the lock in UseDefaultDevice().
|
||||
bestDevice = DeviceDescriptor::BestDevice();
|
||||
|
||||
const auto& allDevices = DeviceDescriptor::AllDevices();
|
||||
|
||||
#ifdef CPUONLY
|
||||
assert(allDevices.size() == 1);
|
||||
#endif
|
||||
auto numGpuDevices = allDevices.size() - 1;
|
||||
|
||||
VerifyException([&numGpuDevices]() {
|
||||
DeviceDescriptor::GPUDevice((unsigned int)numGpuDevices);
|
||||
}, "Was able to create GPU device descriptor with invalid id.");
|
||||
|
||||
assert(find(allDevices.begin(), allDevices.end(), bestDevice) != allDevices.end());
|
||||
assert(allDevices.back() == cpuDevice);
|
||||
|
||||
}
|
|
@ -15,6 +15,7 @@ void SerializationTests();
|
|||
void LearnerTests();
|
||||
void TrainSequenceToSequenceTranslator();
|
||||
void EvalMultiThreadsWithNewNetwork(const DeviceDescriptor&, const int);
|
||||
void DeviceSelectionTests();
|
||||
|
||||
int main()
|
||||
{
|
||||
|
@ -42,6 +43,8 @@ int main()
|
|||
EvalMultiThreadsWithNewNetwork(DeviceDescriptor::GPUDevice(0), 2);
|
||||
#endif
|
||||
|
||||
DeviceSelectionTests();
|
||||
|
||||
fprintf(stderr, "\nCNTKv2Library tests: Passed\n");
|
||||
fflush(stderr);
|
||||
}
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
#include "CNTKLibrary.h"
|
||||
#include "Common.h"
|
||||
#include <functional>
|
||||
#include <array>
|
||||
|
||||
|
@ -103,31 +104,18 @@ void TestNDArrayView(size_t numAxes, const DeviceDescriptor& device)
|
|||
throw std::runtime_error("The buffers underlying the alias view and the view it is an alias of are different!");
|
||||
|
||||
// Test readonliness
|
||||
auto verifyException = [](const std::function<void()>& functionToTest) {
|
||||
bool error = false;
|
||||
try
|
||||
{
|
||||
functionToTest();
|
||||
}
|
||||
catch (const std::exception&)
|
||||
{
|
||||
error = true;
|
||||
}
|
||||
|
||||
if (!error)
|
||||
throw std::runtime_error("Was incorrectly able to get a writable buffer pointer from a readonly view");
|
||||
};
|
||||
auto errorMsg = "Was incorrectly able to get a writable buffer pointer from a readonly view";
|
||||
|
||||
// Should not be able to get the WritableDataBuffer for a read-only view
|
||||
verifyException([&aliasView]() {
|
||||
VerifyException([&aliasView]() {
|
||||
ElementType* aliasViewBuffer = aliasView->WritableDataBuffer<ElementType>();
|
||||
aliasViewBuffer;
|
||||
});
|
||||
}, errorMsg);
|
||||
|
||||
// Should not be able to copy into a read-only view
|
||||
verifyException([&aliasView, &dataView]() {
|
||||
VerifyException([&aliasView, &dataView]() {
|
||||
aliasView->CopyFrom(*dataView);
|
||||
});
|
||||
}, errorMsg);
|
||||
}
|
||||
|
||||
template <typename ElementType>
|
||||
|
|
|
@ -111,6 +111,7 @@
|
|||
<ItemGroup>
|
||||
<ClCompile Include="..\..\..\Examples\Evaluation\CPPEvalV2Client\EvalMultithreads.cpp" />
|
||||
<ClCompile Include="CifarResNet.cpp" />
|
||||
<ClCompile Include="DeviceSelectionTests.cpp" />
|
||||
<ClCompile Include="LearnerTests.cpp" />
|
||||
<ClCompile Include="Seq2Seq.cpp" />
|
||||
<ClCompile Include="SerializationTests.cpp" />
|
||||
|
|
|
@ -54,6 +54,9 @@
|
|||
<ClCompile Include="..\..\..\Examples\Evaluation\CPPEvalV2Client\EvalMultithreads.cpp">
|
||||
<Filter>Source Files</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="DeviceSelectionTests.cpp">
|
||||
<Filter>Source Files</Filter>
|
||||
</ClCompile>
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<ClInclude Include="Common.h">
|
||||
|
|
Загрузка…
Ссылка в новой задаче