Run on all devices even if other device types fail (#186)

* Run on all device types even if the other device types fail

* Renamed models to model

* Fixed test failure so that last failure HRESULT from create session / bind / eval can be returned

* Made changes based on PR comments
This commit is contained in:
Ryan Lai 2019-03-06 11:16:02 -08:00 коммит произвёл GitHub
Родитель c208a75895
Коммит 42ee5b45de
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
1 изменённых файлов: 66 добавлений и 55 удалений

Просмотреть файл

@ -359,41 +359,13 @@ HRESULT EvaluateModel(const LearningModel& model, const CommandLineArgs& args, O
return S_OK; return S_OK;
} }
HRESULT EvaluateModels(const std::vector<std::wstring>& modelPaths, const std::vector<DeviceType>& deviceTypes, HRESULT EvaluateModelWithDeviceType(const LearningModel& model, const DeviceType deviceType,
const std::vector<InputBindingType>& inputBindingTypes, const std::vector<InputBindingType>& inputBindingTypes,
const std::vector<InputDataType>& inputDataTypes, const std::vector<InputDataType>& inputDataTypes,
const std::vector<DeviceCreationLocation> deviceCreationLocations, const CommandLineArgs& args, const std::vector<DeviceCreationLocation> deviceCreationLocations,
OutputHelper& output, Profiler<WINML_MODEL_TEST_PERF>& profiler) const CommandLineArgs& args, const std::wstring& modelPath, OutputHelper& output,
{ Profiler<WINML_MODEL_TEST_PERF>& profiler,
output.PrintHardwareInfo(); TensorFeatureDescriptor& tensorDescriptor)
for (const auto& path : modelPaths)
{
LearningModel model = nullptr;
try
{
model =
LoadModel(path, args.IsPerformanceCapture() || args.IsPerIterationCapture(), output, args, 0, profiler);
}
catch (hresult_error hr)
{
std::cout << hr.message().c_str() << std::endl;
return hr.code();
}
auto firstFeature = model.InputFeatures().First().Current();
auto tensorDescriptor = firstFeature.try_as<TensorFeatureDescriptor>();
// Map and Sequence bindings are not supported yet
if (!tensorDescriptor)
{
std::wcout << L"Model: " + path + L" has an input type that isn't supported by WinMLRunner yet."
<< std::endl;
continue;
}
for (const auto& deviceType : deviceTypes)
{ {
for (const auto& inputBindingType : inputBindingTypes) for (const auto& inputBindingType : inputBindingTypes)
{ {
@ -416,8 +388,8 @@ HRESULT EvaluateModels(const std::vector<std::wstring>& modelPaths, const std::v
} }
} }
HRESULT evalHResult = EvaluateModel(model, args, output, deviceType, inputBindingType, HRESULT evalHResult = EvaluateModel(model, args, output, deviceType, inputBindingType, inputDataType,
inputDataType, deviceCreationLocation, profiler); deviceCreationLocation, profiler);
if (FAILED(evalHResult)) if (FAILED(evalHResult))
{ {
@ -426,31 +398,70 @@ HRESULT EvaluateModels(const std::vector<std::wstring>& modelPaths, const std::v
if (args.IsPerformanceCapture()) if (args.IsPerformanceCapture())
{ {
output.PrintResults(profiler, args.NumIterations(), deviceType, inputBindingType, output.PrintResults(profiler, args.NumIterations(), deviceType, inputBindingType, inputDataType,
inputDataType, deviceCreationLocation, deviceCreationLocation, args.IsPerformanceConsoleOutputVerbose());
args.IsPerformanceConsoleOutputVerbose());
if (args.IsOutputPerf()) if (args.IsOutputPerf())
{ {
std::string deviceTypeStringified = TypeHelper::Stringify(deviceType); std::string deviceTypeStringified = TypeHelper::Stringify(deviceType);
std::string inputDataTypeStringified = TypeHelper::Stringify(inputDataType); std::string inputDataTypeStringified = TypeHelper::Stringify(inputDataType);
std::string inputBindingTypeStringified = TypeHelper::Stringify(inputBindingType); std::string inputBindingTypeStringified = TypeHelper::Stringify(inputBindingType);
std::string deviceCreationLocationStringified = std::string deviceCreationLocationStringified = TypeHelper::Stringify(deviceCreationLocation);
TypeHelper::Stringify(deviceCreationLocation); output.WritePerformanceDataToCSV(
output.WritePerformanceDataToCSV(profiler, args.NumIterations(), path, profiler, args.NumIterations(), modelPath, deviceTypeStringified, inputDataTypeStringified,
deviceTypeStringified, inputDataTypeStringified, inputBindingTypeStringified, deviceCreationLocationStringified);
inputBindingTypeStringified,
deviceCreationLocationStringified);
} }
} }
} }
} }
} }
return S_OK;
} }
HRESULT EvaluateModels(const std::vector<std::wstring>& modelPaths, const std::vector<DeviceType>& deviceTypes,
const std::vector<InputBindingType>& inputBindingTypes,
const std::vector<InputDataType>& inputDataTypes,
const std::vector<DeviceCreationLocation> deviceCreationLocations, const CommandLineArgs& args,
OutputHelper& output, Profiler<WINML_MODEL_TEST_PERF>& profiler)
{
output.PrintHardwareInfo();
HRESULT lastEvaluateModelResult = S_OK;
for (const auto& path : modelPaths)
{
LearningModel model = nullptr;
try
{
model =
LoadModel(path, args.IsPerformanceCapture() || args.IsPerIterationCapture(), output, args, 0, profiler);
}
catch (hresult_error hr)
{
std::cout << hr.message().c_str() << std::endl;
return hr.code();
}
auto firstFeature = model.InputFeatures().First().Current();
auto tensorDescriptor = firstFeature.try_as<TensorFeatureDescriptor>();
// Map and Sequence bindings are not supported yet
if (!tensorDescriptor)
{
std::wcout << L"Model: " + path + L" has an input type that isn't supported by WinMLRunner yet."
<< std::endl;
continue;
}
for (const auto& deviceType : deviceTypes)
{
HRESULT evaluateModelWithDeviceTypeResult =
EvaluateModelWithDeviceType(model, deviceType, inputBindingTypes, inputDataTypes,
deviceCreationLocations, args, path, output, profiler, tensorDescriptor);
if (FAILED(evaluateModelWithDeviceTypeResult))
{
lastEvaluateModelResult = evaluateModelWithDeviceTypeResult;
std::cout << "Run failed for DeviceType: " << TypeHelper::Stringify(deviceType) << std::endl;
}
}
model.Close(); model.Close();
} }
return lastEvaluateModelResult; // Return the last HRESULT that failed. Will return S_OK otherwise.
return S_OK;
} }
std::vector<InputDataType> FetchInputDataTypes(const CommandLineArgs& args) std::vector<InputDataType> FetchInputDataTypes(const CommandLineArgs& args)