"write" command can now apply a labelMappingFile to print classification results as text labels
This commit is contained in:
Родитель
6ea2892967
Коммит
343893e882
|
@ -256,10 +256,15 @@ void DoWriteOutput(const ConfigParameters& config)
|
|||
if (config.Exists("format"))
|
||||
{
|
||||
ConfigParameters formatConfig(config(L"format"));
|
||||
string type = formatConfig(L"type", "");
|
||||
if (type == "real") formattingOptions.isCategoryLabel = false;
|
||||
else if (type == "category") formattingOptions.isCategoryLabel = true;
|
||||
else if (type != "") InvalidArgument("write: type must be 'real' or 'category'");
|
||||
if (formatConfig.ExistsCurrent("type")) // do not inherit 'type' from outer block
|
||||
{
|
||||
string type = formatConfig(L"type");
|
||||
if (type == "real") formattingOptions.isCategoryLabel = false;
|
||||
else if (type == "category") formattingOptions.isCategoryLabel = true;
|
||||
else InvalidArgument("write: type must be 'real' or 'category'");
|
||||
if (formattingOptions.isCategoryLabel)
|
||||
formattingOptions.labelMappingFile = formatConfig(L"labelMappingFile", L"");
|
||||
}
|
||||
formattingOptions.transpose = formatConfig(L"transpose", formattingOptions.transpose);
|
||||
formattingOptions.prologue = formatConfig(L"prologue", formattingOptions.prologue);
|
||||
formattingOptions.epilogue = formatConfig(L"epilogue", formattingOptions.epilogue);
|
||||
|
|
|
@ -249,6 +249,31 @@ public:
|
|||
// This function does not quite fit here, but it fits elsewhere even worse. TODO: change to use File class!
|
||||
template <class ElemType>
|
||||
static vector<ElemType> LoadMatrixFromTextFile(const std::wstring& filePath, size_t& /*out*/ numRows, size_t& /*out*/ numCols);
|
||||
|
||||
// Read a label file.
|
||||
// A label file is a sequence of text lines with one token per line, where each line maps a string to an index, starting with 0.
|
||||
// This function allows spaces inside the word name, but trims surrounding spaces.
|
||||
// TODO: Move this to class File, as this is similar in nature to LoadMatrixFromTextFile().
|
||||
template <class LabelType>
|
||||
static void LoadLabelFile(const std::wstring& filePath, std::vector<LabelType>& retLabels)
|
||||
{
|
||||
File file(filePath, fileOptionsRead | fileOptionsText);
|
||||
|
||||
LabelType str;
|
||||
retLabels.clear();
|
||||
while (!file.IsEOF())
|
||||
{
|
||||
file.GetLine(str);
|
||||
if (str.empty())
|
||||
if (file.IsEOF())
|
||||
break;
|
||||
else
|
||||
RuntimeError("LoadLabelFile: Invalid empty line in label file.");
|
||||
|
||||
retLabels.push_back(trim(str));
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
}}}
|
||||
|
|
|
@ -310,6 +310,7 @@ void SequenceReader<ElemType>::UpdateDataVariables()
|
|||
}
|
||||
}
|
||||
|
||||
// TODO: move this to class File as well
|
||||
template <class ElemType>
|
||||
void SequenceReader<ElemType>::WriteLabelFile()
|
||||
{
|
||||
|
@ -339,29 +340,6 @@ void SequenceReader<ElemType>::WriteLabelFile()
|
|||
}
|
||||
}
|
||||
|
||||
// a label file is a sequence of text lines with one token per line
|
||||
// This function allows spaces inside the word name, but trims surrounding spaces.
|
||||
// TODO: Move this to class File, as this is similar in nature to LoadMatrixFromTextFile().
|
||||
template <class ElemType>
|
||||
void SequenceReader<ElemType>::LoadLabelFile(const std::wstring& filePath, std::vector<LabelType>& retLabels)
|
||||
{
|
||||
File file(filePath, fileOptionsRead | fileOptionsText);
|
||||
|
||||
string str;
|
||||
retLabels.clear();
|
||||
while (!file.IsEOF())
|
||||
{
|
||||
file.GetLine(str);
|
||||
if (str.empty())
|
||||
if (file.IsEOF())
|
||||
break;
|
||||
else
|
||||
RuntimeError("LoadLabelFile: Invalid empty line in label file.");
|
||||
|
||||
retLabels.push_back(trim(str));
|
||||
}
|
||||
}
|
||||
|
||||
// Destroy - cleanup and remove this class
|
||||
// NOTE: this destroys the object, and it can't be used past this point
|
||||
template <class ElemType>
|
||||
|
@ -482,7 +460,7 @@ void SequenceReader<ElemType>::InitFromConfig(const ConfigRecordType& readerConf
|
|||
std::wstring labelPath = labelConfig(L"labelMappingFile");
|
||||
if (fexists(labelPath))
|
||||
{
|
||||
LoadLabelFile(labelPath, arrayLabels);
|
||||
File::LoadLabelFile(labelPath, arrayLabels);
|
||||
for (int i = 0; i < arrayLabels.size(); ++i)
|
||||
{
|
||||
LabelType label = arrayLabels[i];
|
||||
|
@ -1480,7 +1458,7 @@ void BatchSequenceReader<ElemType>::InitFromConfig(const ConfigRecordType& reade
|
|||
std::wstring labelPath = labelConfig(L"labelMappingFile");
|
||||
if (File::Exists(labelPath))
|
||||
{
|
||||
LoadLabelFile(labelPath, arrayLabels);
|
||||
File::LoadLabelFile(labelPath, arrayLabels);
|
||||
// build the two-way mapping tables
|
||||
for (int i = 0; i < arrayLabels.size(); ++i)
|
||||
{
|
||||
|
|
|
@ -221,7 +221,6 @@ protected:
|
|||
size_t RecordsToRead(size_t mbStartSample, bool tail = false);
|
||||
void ReleaseMemory();
|
||||
void WriteLabelFile();
|
||||
void LoadLabelFile(const std::wstring& filePath, std::vector<LabelType>& retLabels);
|
||||
|
||||
LabelIdType GetIdFromLabel(const std::string& label, LabelInfo& labelInfo);
|
||||
bool CheckIdFromLabel(const std::string& labelValue, const LabelInfo& labelInfo, unsigned& labelId);
|
||||
|
@ -316,7 +315,6 @@ public:
|
|||
using Base::labelInfoIn;
|
||||
using Base::nwords;
|
||||
using Base::ReadClassInfo;
|
||||
using Base::LoadLabelFile;
|
||||
using Base::word4idx;
|
||||
using Base::idx4word;
|
||||
using Base::idx4cnt;
|
||||
|
|
|
@ -109,11 +109,13 @@ public:
|
|||
// clean up
|
||||
}
|
||||
|
||||
// pass this to WriteOutput() (to file-path, below) to specify how the output should be formatted
|
||||
struct WriteFormattingOptions
|
||||
{
|
||||
// How to interpret the data:
|
||||
bool isCategoryLabel; // true: find max value in column and output the index instead of the entire vector
|
||||
bool transpose; // true: one line per sample, each sample (column vector) forms one line; false: one column per sample
|
||||
bool isCategoryLabel; // true: find max value in column and output the index instead of the entire vector
|
||||
std::wstring labelMappingFile; // optional dictionary for pretty-printing category labels
|
||||
bool transpose; // true: one line per sample, each sample (column vector) forms one line; false: one column per sample
|
||||
// The following strings are interspersed with the data:
|
||||
// overall
|
||||
std::string prologue; // print this at the start (e.g. a global header or opening bracket)
|
||||
|
@ -133,11 +135,13 @@ public:
|
|||
{ }
|
||||
};
|
||||
|
||||
// TODO: Remove code dup with above function
|
||||
// E.g. create a shared function that takes the actual writing operation as a lambda.
|
||||
// TODO: Remove code dup with above function by creating a fake Writer object and then calling the other function.
|
||||
void WriteOutput(IDataReader<ElemType>& dataReader, size_t mbSize, std::wstring outputPath, const std::vector<std::wstring>& outputNodeNames, const WriteFormattingOptions & formattingOptions, size_t numOutputSamples = requestDataSize)
|
||||
{
|
||||
File::MakeIntermediateDirs(outputPath);
|
||||
// load a label mapping if requested
|
||||
std::vector<std::string> labelMapping;
|
||||
if (formattingOptions.isCategoryLabel && !formattingOptions.labelMappingFile.empty())
|
||||
File::LoadLabelFile(formattingOptions.labelMappingFile, labelMapping);
|
||||
|
||||
// specify output nodes and files
|
||||
std::vector<ComputationNodeBasePtr> outputNodes;
|
||||
|
@ -155,6 +159,8 @@ public:
|
|||
outputNodes.push_back(m_net->GetNodeFromName(outputNodeNames[i]));
|
||||
}
|
||||
|
||||
// open output files
|
||||
File::MakeIntermediateDirs(outputPath);
|
||||
std::map<ComputationNodeBasePtr, shared_ptr<File>> outputStreams; // TODO: why does unique_ptr not work here? Complains about non-existent default_delete()
|
||||
for (auto & onode : outputNodes)
|
||||
{
|
||||
|
@ -191,7 +197,8 @@ public:
|
|||
fprintfOrDie(f, "%s", formattingOptions.prologue.c_str());
|
||||
}
|
||||
|
||||
std::string valueFormatString = "%" + formattingOptions.precisionFormat + "f"; // format string used in fprintf() for formatting the values
|
||||
char formatChar = !formattingOptions.isCategoryLabel ? 'f' : !formattingOptions.labelMappingFile.empty() ? 's' : 'u';
|
||||
std::string valueFormatString = "%" + formattingOptions.precisionFormat + formatChar; // format string used in fprintf() for formatting the values
|
||||
|
||||
size_t actualMBSize;
|
||||
while (DataReaderHelpers::GetMinibatchIntoNetwork(dataReader, m_net, nullptr, false, false, inputMatrices, actualMBSize))
|
||||
|
@ -201,13 +208,6 @@ public:
|
|||
|
||||
for (auto & onode : outputNodes)
|
||||
{
|
||||
FILE * f = *outputStreams[onode];
|
||||
|
||||
// sequence separator
|
||||
if (numMBsRun > 0 && !formattingOptions.sequenceSeparator.empty())
|
||||
fprintfOrDie(f, "%s", formattingOptions.sequenceSeparator.c_str());
|
||||
fprintfOrDie(f, "%s", formattingOptions.sequencePrologue.c_str());
|
||||
|
||||
// compute the node value
|
||||
// Note: Intermediate values are memoized, so in case of multiple output nodes, we only compute what has not been computed already.
|
||||
m_net->ForwardProp(onode);
|
||||
|
@ -217,11 +217,22 @@ public:
|
|||
outputValues.CopyToArray(tempArray, tempArraySize);
|
||||
ElemType* pCurValue = tempArray;
|
||||
|
||||
// sequence separator
|
||||
FILE * f = *outputStreams[onode];
|
||||
if (numMBsRun > 0 && !formattingOptions.sequenceSeparator.empty())
|
||||
fprintfOrDie(f, "%s", formattingOptions.sequenceSeparator.c_str());
|
||||
fprintfOrDie(f, "%s", formattingOptions.sequencePrologue.c_str());
|
||||
|
||||
// output it according to our format specification
|
||||
size_t T = outputValues.GetNumCols();
|
||||
size_t dim = outputValues.GetNumRows();
|
||||
if (formattingOptions.isCategoryLabel)
|
||||
{
|
||||
if (formatChar == 's') // verify label dimension
|
||||
{
|
||||
if (dim != labelMapping.size())
|
||||
InvalidArgument("write: Row dimension %d does not match number of entries %d in labelMappingFile '%ls'", (int)dim, (int)labelMapping.size(), formattingOptions.labelMappingFile.c_str());
|
||||
}
|
||||
// update the matrix in-place from one-hot (or max) to index
|
||||
// find the max in each column
|
||||
foreach_column(j, outputValues)
|
||||
|
@ -253,8 +264,23 @@ public:
|
|||
{
|
||||
if (i > 0)
|
||||
fprintfOrDie(f, "%s", formattingOptions.elementSeparator.c_str());
|
||||
double val = pCurValue[i * istride + j * jstride];
|
||||
fprintfOrDie(f, valueFormatString.c_str(), (double)val);
|
||||
if (formatChar == 'f') // print as real number
|
||||
{
|
||||
double dval = pCurValue[i * istride + j * jstride];
|
||||
fprintfOrDie(f, valueFormatString.c_str(), dval);
|
||||
}
|
||||
else if (formatChar == 'u') // print category as integer index
|
||||
{
|
||||
unsigned int uval = (unsigned int) pCurValue[i * istride + j * jstride];
|
||||
fprintfOrDie(f, valueFormatString.c_str(), uval);
|
||||
}
|
||||
else if (formatChar == 's') // print category as a label string
|
||||
{
|
||||
size_t uval = (size_t) pCurValue[i * istride + j * jstride];
|
||||
assert(uval < labelMapping.size());
|
||||
const char * sval = labelMapping[uval].c_str();
|
||||
fprintfOrDie(f, valueFormatString.c_str(), sval);
|
||||
}
|
||||
}
|
||||
}
|
||||
fprintfOrDie(f, "%s", formattingOptions.sequenceEpilogue.c_str());
|
||||
|
|
Загрузка…
Ссылка в новой задаче