"write" command can now apply a labelMappingFile to print classification results as text labels

This commit is contained in:
Frank Seide 2016-02-15 11:45:54 -08:00
Родитель 6ea2892967
Коммит 343893e882
5 изменённых файлов: 78 добавлений и 46 удалений

Просмотреть файл

@ -256,10 +256,15 @@ void DoWriteOutput(const ConfigParameters& config)
if (config.Exists("format"))
{
ConfigParameters formatConfig(config(L"format"));
string type = formatConfig(L"type", "");
if (type == "real") formattingOptions.isCategoryLabel = false;
else if (type == "category") formattingOptions.isCategoryLabel = true;
else if (type != "") InvalidArgument("write: type must be 'real' or 'category'");
if (formatConfig.ExistsCurrent("type")) // do not inherit 'type' from outer block
{
string type = formatConfig(L"type");
if (type == "real") formattingOptions.isCategoryLabel = false;
else if (type == "category") formattingOptions.isCategoryLabel = true;
else InvalidArgument("write: type must be 'real' or 'category'");
if (formattingOptions.isCategoryLabel)
formattingOptions.labelMappingFile = formatConfig(L"labelMappingFile", L"");
}
formattingOptions.transpose = formatConfig(L"transpose", formattingOptions.transpose);
formattingOptions.prologue = formatConfig(L"prologue", formattingOptions.prologue);
formattingOptions.epilogue = formatConfig(L"epilogue", formattingOptions.epilogue);

Просмотреть файл

@ -249,6 +249,31 @@ public:
// This function does not quite fit here, but it fits elsewhere even worse. TODO: change to use File class!
template <class ElemType>
static vector<ElemType> LoadMatrixFromTextFile(const std::wstring& filePath, size_t& /*out*/ numRows, size_t& /*out*/ numCols);
// Read a label file.
// A label file is a sequence of text lines with one token per line, where each line maps a string to an index, starting with 0.
// This function allows spaces inside the word name, but trims surrounding spaces.
// TODO: Move this to class File, as this is similar in nature to LoadMatrixFromTextFile().
template <class LabelType>
static void LoadLabelFile(const std::wstring& filePath, std::vector<LabelType>& retLabels)
{
File file(filePath, fileOptionsRead | fileOptionsText);
LabelType str;
retLabels.clear();
while (!file.IsEOF())
{
file.GetLine(str);
if (str.empty())
if (file.IsEOF())
break;
else
RuntimeError("LoadLabelFile: Invalid empty line in label file.");
retLabels.push_back(trim(str));
}
}
};
}}}

Просмотреть файл

@ -310,6 +310,7 @@ void SequenceReader<ElemType>::UpdateDataVariables()
}
}
// TODO: move this to class File as well
template <class ElemType>
void SequenceReader<ElemType>::WriteLabelFile()
{
@ -339,29 +340,6 @@ void SequenceReader<ElemType>::WriteLabelFile()
}
}
// a label file is a sequence of text lines with one token per line
// This function allows spaces inside the word name, but trims surrounding spaces.
// TODO: Move this to class File, as this is similar in nature to LoadMatrixFromTextFile().
template <class ElemType>
void SequenceReader<ElemType>::LoadLabelFile(const std::wstring& filePath, std::vector<LabelType>& retLabels)
{
File file(filePath, fileOptionsRead | fileOptionsText);
string str;
retLabels.clear();
while (!file.IsEOF())
{
file.GetLine(str);
if (str.empty())
if (file.IsEOF())
break;
else
RuntimeError("LoadLabelFile: Invalid empty line in label file.");
retLabels.push_back(trim(str));
}
}
// Destroy - cleanup and remove this class
// NOTE: this destroys the object, and it can't be used past this point
template <class ElemType>
@ -482,7 +460,7 @@ void SequenceReader<ElemType>::InitFromConfig(const ConfigRecordType& readerConf
std::wstring labelPath = labelConfig(L"labelMappingFile");
if (fexists(labelPath))
{
LoadLabelFile(labelPath, arrayLabels);
File::LoadLabelFile(labelPath, arrayLabels);
for (int i = 0; i < arrayLabels.size(); ++i)
{
LabelType label = arrayLabels[i];
@ -1480,7 +1458,7 @@ void BatchSequenceReader<ElemType>::InitFromConfig(const ConfigRecordType& reade
std::wstring labelPath = labelConfig(L"labelMappingFile");
if (File::Exists(labelPath))
{
LoadLabelFile(labelPath, arrayLabels);
File::LoadLabelFile(labelPath, arrayLabels);
// build the two-way mapping tables
for (int i = 0; i < arrayLabels.size(); ++i)
{

Просмотреть файл

@ -221,7 +221,6 @@ protected:
size_t RecordsToRead(size_t mbStartSample, bool tail = false);
void ReleaseMemory();
void WriteLabelFile();
void LoadLabelFile(const std::wstring& filePath, std::vector<LabelType>& retLabels);
LabelIdType GetIdFromLabel(const std::string& label, LabelInfo& labelInfo);
bool CheckIdFromLabel(const std::string& labelValue, const LabelInfo& labelInfo, unsigned& labelId);
@ -316,7 +315,6 @@ public:
using Base::labelInfoIn;
using Base::nwords;
using Base::ReadClassInfo;
using Base::LoadLabelFile;
using Base::word4idx;
using Base::idx4word;
using Base::idx4cnt;

Просмотреть файл

@ -109,11 +109,13 @@ public:
// clean up
}
// pass this to WriteOutput() (to file-path, below) to specify how the output should be formatted
struct WriteFormattingOptions
{
// How to interpret the data:
bool isCategoryLabel; // true: find max value in column and output the index instead of the entire vector
bool transpose; // true: one line per sample, each sample (column vector) forms one line; false: one column per sample
bool isCategoryLabel; // true: find max value in column and output the index instead of the entire vector
std::wstring labelMappingFile; // optional dictionary for pretty-printing category labels
bool transpose; // true: one line per sample, each sample (column vector) forms one line; false: one column per sample
// The following strings are interspersed with the data:
// overall
std::string prologue; // print this at the start (e.g. a global header or opening bracket)
@ -133,11 +135,13 @@ public:
{ }
};
// TODO: Remove code dup with above function
// E.g. create a shared function that takes the actual writing operation as a lambda.
// TODO: Remove code dup with above function by creating a fake Writer object and then calling the other function.
void WriteOutput(IDataReader<ElemType>& dataReader, size_t mbSize, std::wstring outputPath, const std::vector<std::wstring>& outputNodeNames, const WriteFormattingOptions & formattingOptions, size_t numOutputSamples = requestDataSize)
{
File::MakeIntermediateDirs(outputPath);
// load a label mapping if requested
std::vector<std::string> labelMapping;
if (formattingOptions.isCategoryLabel && !formattingOptions.labelMappingFile.empty())
File::LoadLabelFile(formattingOptions.labelMappingFile, labelMapping);
// specify output nodes and files
std::vector<ComputationNodeBasePtr> outputNodes;
@ -155,6 +159,8 @@ public:
outputNodes.push_back(m_net->GetNodeFromName(outputNodeNames[i]));
}
// open output files
File::MakeIntermediateDirs(outputPath);
std::map<ComputationNodeBasePtr, shared_ptr<File>> outputStreams; // TODO: why does unique_ptr not work here? Complains about non-existent default_delete()
for (auto & onode : outputNodes)
{
@ -191,7 +197,8 @@ public:
fprintfOrDie(f, "%s", formattingOptions.prologue.c_str());
}
std::string valueFormatString = "%" + formattingOptions.precisionFormat + "f"; // format string used in fprintf() for formatting the values
char formatChar = !formattingOptions.isCategoryLabel ? 'f' : !formattingOptions.labelMappingFile.empty() ? 's' : 'u';
std::string valueFormatString = "%" + formattingOptions.precisionFormat + formatChar; // format string used in fprintf() for formatting the values
size_t actualMBSize;
while (DataReaderHelpers::GetMinibatchIntoNetwork(dataReader, m_net, nullptr, false, false, inputMatrices, actualMBSize))
@ -201,13 +208,6 @@ public:
for (auto & onode : outputNodes)
{
FILE * f = *outputStreams[onode];
// sequence separator
if (numMBsRun > 0 && !formattingOptions.sequenceSeparator.empty())
fprintfOrDie(f, "%s", formattingOptions.sequenceSeparator.c_str());
fprintfOrDie(f, "%s", formattingOptions.sequencePrologue.c_str());
// compute the node value
// Note: Intermediate values are memoized, so in case of multiple output nodes, we only compute what has not been computed already.
m_net->ForwardProp(onode);
@ -217,11 +217,22 @@ public:
outputValues.CopyToArray(tempArray, tempArraySize);
ElemType* pCurValue = tempArray;
// sequence separator
FILE * f = *outputStreams[onode];
if (numMBsRun > 0 && !formattingOptions.sequenceSeparator.empty())
fprintfOrDie(f, "%s", formattingOptions.sequenceSeparator.c_str());
fprintfOrDie(f, "%s", formattingOptions.sequencePrologue.c_str());
// output it according to our format specification
size_t T = outputValues.GetNumCols();
size_t dim = outputValues.GetNumRows();
if (formattingOptions.isCategoryLabel)
{
if (formatChar == 's') // verify label dimension
{
if (dim != labelMapping.size())
InvalidArgument("write: Row dimension %d does not match number of entries %d in labelMappingFile '%ls'", (int)dim, (int)labelMapping.size(), formattingOptions.labelMappingFile.c_str());
}
// update the matrix in-place from one-hot (or max) to index
// find the max in each column
foreach_column(j, outputValues)
@ -253,8 +264,23 @@ public:
{
if (i > 0)
fprintfOrDie(f, "%s", formattingOptions.elementSeparator.c_str());
double val = pCurValue[i * istride + j * jstride];
fprintfOrDie(f, valueFormatString.c_str(), (double)val);
if (formatChar == 'f') // print as real number
{
double dval = pCurValue[i * istride + j * jstride];
fprintfOrDie(f, valueFormatString.c_str(), dval);
}
else if (formatChar == 'u') // print category as integer index
{
unsigned int uval = (unsigned int) pCurValue[i * istride + j * jstride];
fprintfOrDie(f, valueFormatString.c_str(), uval);
}
else if (formatChar == 's') // print category as a label string
{
size_t uval = (size_t) pCurValue[i * istride + j * jstride];
assert(uval < labelMapping.size());
const char * sval = labelMapping[uval].c_str();
fprintfOrDie(f, valueFormatString.c_str(), sval);
}
}
}
fprintfOrDie(f, "%s", formattingOptions.sequenceEpilogue.c_str());