Add functions to allow ILearningModelFeatureValues to be plumbed through as input through WinMLRunner static library (#327)

* Make changes to allow protobuf input through commandline args

* use protobuf helper

* Use feature values instead of protobuf

* Commandline args changes

* added clear method

* Naming changes

* Formatting

* DOn't modify git modules
This commit is contained in:
Ryan Lai 2020-07-16 17:06:57 -07:00 коммит произвёл GitHub
Родитель 283c62201e
Коммит ad214438fb
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
4 изменённых файлов: 27 добавлений и 9 удалений

@ -0,0 +1 @@
Subproject commit ccbf49e59f6bef897a94595e2213263b37d64ff3

Просмотреть файл

@ -600,3 +600,7 @@ void CommandLineArgs::AddPerformanceFileMetadata(const std::string& key, const s
cleanedValue.erase(std::remove_copy(value.begin(), value.end(), cleanedValue.begin(), ','), cleanedValue.end());
m_perfFileMetadata.push_back(std::make_pair(key, cleanedValue));
}
void CommandLineArgs::AddProvidedInputFeatureValue(const ILearningModelFeatureValue& input)
{
m_providedInputFeatureValues.push_back(input);
}

Просмотреть файл

@ -45,6 +45,10 @@ public:
BitmapInterpolationMode AutoScaleInterpMode() const { return m_autoScaleInterpMode; }
const std::vector<std::wstring>& ImagePaths() const { return m_imagePaths; }
const std::vector<ILearningModelFeatureValue>& ProvidedInputFeatureValues() const
{
return m_providedInputFeatureValues;
}
const std::wstring& CsvPath() const { return m_csvData; }
const std::wstring& OutputPath() const { return m_perfOutputPath; }
const std::wstring& FolderPath() const { return m_modelFolderPath; }
@ -92,11 +96,11 @@ public:
bool IsGarbageInput() const
{
// When there is no image or csv input provided, then garbage input binding is used.
return m_imagePaths.empty() && m_csvData.empty();
return m_imagePaths.empty() && m_csvData.empty() && m_providedInputFeatureValues.empty();
}
bool IsCSVInput() const { return m_imagePaths.empty() && !m_csvData.empty(); }
bool IsImageInput() const { return !m_imagePaths.empty() && m_csvData.empty(); }
bool InputFeatureValuesProvided() const { return !m_providedInputFeatureValues.empty(); }
uint32_t NumIterations() const { return m_numIterations; }
uint32_t NumLoadIterations() const { return m_numLoadIterations; }
uint32_t NumSessionCreationIterations() const { return m_numSessionIterations; }
@ -140,6 +144,8 @@ public:
void SetSessionCreationIterations(const uint32_t iterations) { m_numSessionIterations = iterations; }
void SetLoadIterations(const uint32_t iterations) { m_numLoadIterations = iterations; }
void AddPerformanceFileMetadata(const std::string& key, const std::string& value);
void AddProvidedInputFeatureValue(const ILearningModelFeatureValue& input);
void ClearProvidedInputFeatureValues() { m_providedInputFeatureValues.clear(); };
void SetGarbageDataMaxValue(const uint32_t value) { m_garbageDataMaxValue = value; }
// Stop iterating when total time of iterations after the first iteration exceeds time limit.
@ -185,6 +191,7 @@ private:
std::wstring m_modelFolderPath;
std::wstring m_modelPath;
std::vector<std::wstring> m_imagePaths;
std::vector<ILearningModelFeatureValue> m_providedInputFeatureValues;
std::wstring m_inputImageFolderPath;
std::wstring m_csvData;
std::wstring m_inputData;

Просмотреть файл

@ -221,17 +221,23 @@ HRESULT BindInputs(LearningModelBinding& context, const LearningModelSession& se
bool captureIterationPerf = args.IsPerformanceCapture() || args.IsPerIterationCapture();
std::vector<ILearningModelFeatureValue> inputFeatures;
try
if (args.InputFeatureValuesProvided())
{
inputFeatures = GenerateInputFeatures(session.Model(), args, inputBindingType, inputDataType, device, iteration, imagePath);
inputFeatures = args.ProvidedInputFeatureValues();
}
catch (hresult_error hr)
else
{
std::wcout << "\nGenerating Input Features [FAILED]" << std::endl;
std::wcout << hr.message().c_str() << std::endl;
return hr.code();
try
{
inputFeatures = GenerateInputFeatures(session.Model(), args, inputBindingType, inputDataType, device, iteration, imagePath);
}
catch (hresult_error hr)
{
std::wcout << "\nGenerating Input Features [FAILED]" << std::endl;
std::wcout << hr.message().c_str() << std::endl;
return hr.code();
}
}
HRESULT bindInputResult =
BindInputFeatures(session.Model(), context, inputFeatures, args, output, captureIterationPerf, iteration, profiler);