Add functions to allow ILearningModelFeatureValues to be plumbed through as input through WinMLRunner static library (#327)
* Make changes to allow protobuf input through commandline args * use protobuf helper * Use feature values instead of protobuf * Commandline args changes * added clear method * Naming changes * Formatting * DOn't modify git modules
This commit is contained in:
Родитель
283c62201e
Коммит
ad214438fb
|
@ -0,0 +1 @@
|
|||
Subproject commit ccbf49e59f6bef897a94595e2213263b37d64ff3
|
|
@ -600,3 +600,7 @@ void CommandLineArgs::AddPerformanceFileMetadata(const std::string& key, const s
|
|||
cleanedValue.erase(std::remove_copy(value.begin(), value.end(), cleanedValue.begin(), ','), cleanedValue.end());
|
||||
m_perfFileMetadata.push_back(std::make_pair(key, cleanedValue));
|
||||
}
|
||||
void CommandLineArgs::AddProvidedInputFeatureValue(const ILearningModelFeatureValue& input)
|
||||
{
|
||||
m_providedInputFeatureValues.push_back(input);
|
||||
}
|
|
@ -45,6 +45,10 @@ public:
|
|||
BitmapInterpolationMode AutoScaleInterpMode() const { return m_autoScaleInterpMode; }
|
||||
|
||||
const std::vector<std::wstring>& ImagePaths() const { return m_imagePaths; }
|
||||
const std::vector<ILearningModelFeatureValue>& ProvidedInputFeatureValues() const
|
||||
{
|
||||
return m_providedInputFeatureValues;
|
||||
}
|
||||
const std::wstring& CsvPath() const { return m_csvData; }
|
||||
const std::wstring& OutputPath() const { return m_perfOutputPath; }
|
||||
const std::wstring& FolderPath() const { return m_modelFolderPath; }
|
||||
|
@ -92,11 +96,11 @@ public:
|
|||
bool IsGarbageInput() const
|
||||
{
|
||||
// When there is no image or csv input provided, then garbage input binding is used.
|
||||
return m_imagePaths.empty() && m_csvData.empty();
|
||||
return m_imagePaths.empty() && m_csvData.empty() && m_providedInputFeatureValues.empty();
|
||||
}
|
||||
bool IsCSVInput() const { return m_imagePaths.empty() && !m_csvData.empty(); }
|
||||
bool IsImageInput() const { return !m_imagePaths.empty() && m_csvData.empty(); }
|
||||
|
||||
bool InputFeatureValuesProvided() const { return !m_providedInputFeatureValues.empty(); }
|
||||
uint32_t NumIterations() const { return m_numIterations; }
|
||||
uint32_t NumLoadIterations() const { return m_numLoadIterations; }
|
||||
uint32_t NumSessionCreationIterations() const { return m_numSessionIterations; }
|
||||
|
@ -140,6 +144,8 @@ public:
|
|||
void SetSessionCreationIterations(const uint32_t iterations) { m_numSessionIterations = iterations; }
|
||||
void SetLoadIterations(const uint32_t iterations) { m_numLoadIterations = iterations; }
|
||||
void AddPerformanceFileMetadata(const std::string& key, const std::string& value);
|
||||
void AddProvidedInputFeatureValue(const ILearningModelFeatureValue& input);
|
||||
void ClearProvidedInputFeatureValues() { m_providedInputFeatureValues.clear(); };
|
||||
void SetGarbageDataMaxValue(const uint32_t value) { m_garbageDataMaxValue = value; }
|
||||
|
||||
// Stop iterating when total time of iterations after the first iteration exceeds time limit.
|
||||
|
@ -185,6 +191,7 @@ private:
|
|||
std::wstring m_modelFolderPath;
|
||||
std::wstring m_modelPath;
|
||||
std::vector<std::wstring> m_imagePaths;
|
||||
std::vector<ILearningModelFeatureValue> m_providedInputFeatureValues;
|
||||
std::wstring m_inputImageFolderPath;
|
||||
std::wstring m_csvData;
|
||||
std::wstring m_inputData;
|
||||
|
|
|
@ -221,17 +221,23 @@ HRESULT BindInputs(LearningModelBinding& context, const LearningModelSession& se
|
|||
bool captureIterationPerf = args.IsPerformanceCapture() || args.IsPerIterationCapture();
|
||||
|
||||
std::vector<ILearningModelFeatureValue> inputFeatures;
|
||||
try
|
||||
if (args.InputFeatureValuesProvided())
|
||||
{
|
||||
inputFeatures = GenerateInputFeatures(session.Model(), args, inputBindingType, inputDataType, device, iteration, imagePath);
|
||||
inputFeatures = args.ProvidedInputFeatureValues();
|
||||
}
|
||||
catch (hresult_error hr)
|
||||
else
|
||||
{
|
||||
std::wcout << "\nGenerating Input Features [FAILED]" << std::endl;
|
||||
std::wcout << hr.message().c_str() << std::endl;
|
||||
return hr.code();
|
||||
try
|
||||
{
|
||||
inputFeatures = GenerateInputFeatures(session.Model(), args, inputBindingType, inputDataType, device, iteration, imagePath);
|
||||
}
|
||||
catch (hresult_error hr)
|
||||
{
|
||||
std::wcout << "\nGenerating Input Features [FAILED]" << std::endl;
|
||||
std::wcout << hr.message().c_str() << std::endl;
|
||||
return hr.code();
|
||||
}
|
||||
}
|
||||
|
||||
HRESULT bindInputResult =
|
||||
BindInputFeatures(session.Model(), context, inputFeatures, args, output, captureIterationPerf, iteration, profiler);
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче