// CNTK/Source/Readers/HTKDeserializers/MLFDataDeserializer.cpp

//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
#include "stdafx.h"
#define __STDC_FORMAT_MACROS
#include <inttypes.h>
#include <limits>
#include "MLFDataDeserializer.h"
#include "ConfigHelper.h"
#include "SequenceData.h"
#include "../HTKMLFReader/htkfeatio.h"
#include "../HTKMLFReader/msra_mgram.h"
#include "latticearchive.h"
#include "StringUtil.h"
#undef max // max is defined in minwindef.h
namespace Microsoft { namespace MSR { namespace CNTK {
using namespace std;

// A constant 1 in each supported precision. Their addresses are handed out as the value
// buffer of the per-class one-hot label samples built in InitializeChunkDescriptions.
static float s_oneFloat{ 1.0f };
static double s_oneDouble{ 1.0 };
// Currently we only have a single mlf chunk that contains a vector of all labels.
2016-04-28 17:59:17 +03:00
// TODO: In the future MLF should be converted to a more compact format that is amenable to chunking.
class MLFDataDeserializer::MLFChunk : public Chunk
{
MLFDataDeserializer* m_parent;
public:
MLFChunk(MLFDataDeserializer* parent) : m_parent(parent)
{}
2016-04-15 16:06:46 +03:00
virtual void GetSequence(size_t sequenceId, vector<SequenceDataPtr>& result) override
{
m_parent->GetSequenceById(sequenceId, result);
}
};
// Per-utterance bookkeeping used while indexing the MLF: the generic sequence description plus
// the offset of the utterance's first frame label in m_classIds.
struct MLFUtterance : SequenceDescription
{
    // Initialized so a default-constructed description never carries an indeterminate value.
    size_t m_sequenceStart = 0;
};
// Creates a (non-primary) MLF deserializer from a new-style configuration that has exactly one
// entry under "input"; that entry's name becomes the stream name.
// Raises LogicError when asked to be primary or when the input count is not one, and
// RuntimeError when the input section names no stream or the label dimension cannot be
// represented as IndexType.
MLFDataDeserializer::MLFDataDeserializer(CorpusDescriptorPtr corpus, const ConfigParameters& cfg, bool primary)
{
    // TODO: This should be read in one place, potentially given by SGD.
    m_frameMode = (ConfigValue)cfg("frameMode", "true");

    // MLF cannot control chunking.
    if (primary)
    {
        LogicError("Mlf deserializer does not support primary mode - it cannot control chunking.");
    }

    argvector<ConfigValue> inputs = cfg("input");
    if (inputs.size() != 1)
    {
        LogicError("MLFDataDeserializer supports a single input stream only.");
    }

    // Anything other than "float" (case-insensitive) selects double precision.
    std::wstring precision = cfg(L"precision", L"float");
    m_elementType = AreEqualIgnoreCase(precision, L"float") ? ElementType::tfloat : ElementType::tdouble;

    ConfigParameters input = inputs.front();
    auto memberIds = input.GetMemberIds();
    if (memberIds.empty())
    {
        // Calling front() on an empty list would be undefined behavior.
        RuntimeError("MLFDataDeserializer: the input section does not define a stream.");
    }
    auto inputName = memberIds.front();
    ConfigParameters streamConfig = input(inputName);
    ConfigHelper config(streamConfig);
    size_t dimension = config.GetLabelDimension();

    // Class ids are stored as IndexType in the label data, so the dimension must fit into it
    // (the same check the other constructor performs).
    if (dimension > numeric_limits<IndexType>::max())
    {
        RuntimeError("Label dimension (%" PRIu64 ") exceeds the maximum allowed "
                     "value (%" PRIu64 ")\n", dimension, (size_t)numeric_limits<IndexType>::max());
    }

    wstring labelMappingFile = streamConfig(L"labelMappingFile", L"");
    InitializeChunkDescriptions(corpus, config, labelMappingFile, dimension);
    InitializeStream(inputName, dimension);
}
// Creates an MLF deserializer from a legacy per-stream label configuration; 'name' becomes the
// stream name. Raises RuntimeError when the label dimension cannot be represented as IndexType.
MLFDataDeserializer::MLFDataDeserializer(CorpusDescriptorPtr corpus, const ConfigParameters& labelConfig, const wstring& name)
{
    // frameMode is given once for the whole configuration, at a level above this deserializer's
    // own section, which is why it is looked up with Find instead of being read directly.
    m_frameMode = labelConfig.Find("frameMode", "true");

    ConfigHelper labelHelper(labelConfig);
    labelHelper.CheckLabelType();

    // Class ids end up stored as IndexType, so the label dimension has to be representable in it.
    const size_t labelDimension = labelHelper.GetLabelDimension();
    if (labelDimension > numeric_limits<IndexType>::max())
    {
        RuntimeError("Label dimension (%" PRIu64 ") exceeds the maximum allowed "
                     "value (%" PRIu64 ")\n", labelDimension, (size_t)numeric_limits<IndexType>::max());
    }

    // Anything other than "float" (case-insensitive) selects double precision.
    const wstring precision = labelConfig(L"precision", L"float");
    if (AreEqualIgnoreCase(precision, L"float"))
    {
        m_elementType = ElementType::tfloat;
    }
    else
    {
        m_elementType = ElementType::tdouble;
    }

    const wstring mappingFile = labelConfig(L"labelMappingFile", L"");
    InitializeChunkDescriptions(corpus, labelHelper, mappingFile, labelDimension);
    InitializeStream(name, labelDimension);
}
// Currently we create a single chunk only.
// Reads all MLF files and indexes the utterances that are also present in the corpus string registry:
//  - m_classIds:       class id of every frame, utterance after utterance;
//  - m_utteranceIndex: offset of each kept utterance's first frame in m_classIds, followed by a final
//                      end entry, so utterance k spans [m_utteranceIndex[k], m_utteranceIndex[k + 1]);
//  - m_keyToSequence:  corpus sequence id -> utterance index, SIZE_MAX where the MLF has no labels;
//  - m_categories:     one shared one-hot label sample per class (served in frame mode).
// Raises RuntimeError on non-contiguous timespans, class ids >= 'dimension', or over-long utterances.
void MLFDataDeserializer::InitializeChunkDescriptions(CorpusDescriptorPtr corpus, const ConfigHelper& config, const wstring& stateListPath, size_t dimension)
{
    // TODO: Similarly to the old reader, currently we assume all Mlfs will have same root name (key)
    // restrict MLF reader to these files--will make stuff much faster without having to use shortened input files
    // TODO: currently we do not use symbol and word tables.
    const msra::lm::CSymbolSet* wordTable = nullptr;
    unordered_map<const char*, int>* symbolTable = nullptr;
    vector<wstring> mlfPaths = config.GetMlfPaths();

    // TODO: Currently we still use the old IO module. This will be refactored later.
    const double htkTimeToFrame = 100000.0; // default is 10ms
    msra::asr::htkmlfreader<msra::asr::htkmlfentry, msra::lattices::lattice::htkmlfwordsequence> labels(mlfPaths, set<wstring>(), stateListPath, wordTable, symbolTable, htkTimeToFrame);

    // Make sure 'msra::asr::htkmlfreader' type has a move constructor
    static_assert(
        is_move_constructible<
            msra::asr::htkmlfreader<msra::asr::htkmlfentry,
                                    msra::lattices::lattice::htkmlfwordsequence>>::value,
        "Type 'msra::asr::htkmlfreader' should be move constructible!");

    MLFUtterance description;
    size_t numClasses = 0;
    size_t totalFrames = 0;

    const auto& stringRegistry = corpus->GetStringRegistry();

    // TODO resize m_keyToSequence with number of IDs from string registry
    for (const auto& l : labels)
    {
        // Currently the string registry contains only utterances described in scp.
        // So here we skip all others.
        size_t id = 0;
        if (!stringRegistry.TryGet(msra::strfun::utf8(l.first), id))
            continue;

        description.m_key.m_sequence = id;

        const auto& utterance = l.second;
        description.m_sequenceStart = m_classIds.size();
        uint32_t numberOfFrames = 0;

        foreach_index(i, utterance)
        {
            const auto& timespan = utterance[i];

            // Timespans must tile the utterance: start at frame 0 and follow each other without gaps.
            if ((i == 0 && timespan.firstframe != 0) ||
                (i > 0 && utterance[i - 1].firstframe + utterance[i - 1].numframes != timespan.firstframe))
            {
                RuntimeError("Labels are not in the consecutive order MLF in label set: %ls", l.first.c_str());
            }

            if (timespan.classid >= dimension)
            {
                // Printed as 64-bit unsigned values: casting to int could truncate a large dimension.
                RuntimeError("Class id %" PRIu64 " exceeds the model output dimension %" PRIu64 ".",
                             (size_t)timespan.classid, dimension);
            }

            if (timespan.classid != static_cast<msra::dbn::CLASSIDTYPE>(timespan.classid))
            {
                RuntimeError("CLASSIDTYPE has too few bits");
            }

            if (SEQUENCELEN_MAX < timespan.firstframe + timespan.numframes)
            {
                RuntimeError("Maximum number of sample per sequence exceeded.");
            }

            numClasses = max(numClasses, static_cast<size_t>(1u + timespan.classid));

            // One label entry per frame of the timespan, appended in a single call.
            // The explicit size_t count keeps overload resolution on the (count, value) form.
            m_classIds.insert(m_classIds.end(), static_cast<size_t>(timespan.numframes), timespan.classid);
            numberOfFrames += static_cast<uint32_t>(timespan.numframes);
        }

        description.m_numberOfSamples = numberOfFrames;
        m_utteranceIndex.push_back(totalFrames);
        totalFrames += numberOfFrames;

        if (m_keyToSequence.size() <= description.m_key.m_sequence)
        {
            m_keyToSequence.resize(description.m_key.m_sequence + 1, SIZE_MAX);
        }
        assert(m_keyToSequence[description.m_key.m_sequence] == SIZE_MAX);
        m_keyToSequence[description.m_key.m_sequence] = m_utteranceIndex.size() - 1;
        m_numberOfSequences++;
    }

    // Final end entry: the length of utterance k is m_utteranceIndex[k + 1] - m_utteranceIndex[k].
    m_utteranceIndex.push_back(totalFrames);

    m_totalNumberOfFrames = totalFrames;
    fprintf(stderr, "MLFDataDeserializer::MLFDataDeserializer: %" PRIu64 " utterances with %" PRIu64 " frames in %" PRIu64 " classes\n",
            m_numberOfSequences,
            m_totalNumberOfFrames,
            numClasses);

    // Initializing array of labels: one shared one-hot sample per class.
    // All indices are stored before any pointer into m_categoryIndices is taken, so no later
    // growth of that vector can invalidate the pointers kept by the samples.
    // NOTE(review): assumes m_categoryIndices is empty on entry (this runs once, from a constructor).
    m_categoryIndices.reserve(dimension);
    for (size_t i = 0; i < dimension; ++i)
    {
        m_categoryIndices.push_back(static_cast<IndexType>(i));
    }

    // Every category sample points at the same constant 1 of the configured precision.
    void* one = nullptr;
    if (m_elementType == ElementType::tfloat)
    {
        one = &s_oneFloat;
    }
    else
    {
        assert(m_elementType == ElementType::tdouble);
        one = &s_oneDouble;
    }

    m_categories.reserve(dimension);
    for (size_t i = 0; i < dimension; ++i)
    {
        auto category = make_shared<CategorySequenceData>();
        category->m_indices = &(m_categoryIndices[i]);
        category->m_nnzCounts.resize(1);
        category->m_nnzCounts[0] = 1;
        category->m_totalNnzCount = 1;
        category->m_numberOfSamples = 1;
        category->m_data = one;
        m_categories.push_back(category);
    }
}
void MLFDataDeserializer::InitializeStream(const wstring& name, size_t dimension)
{
// Initializing stream description - a single stream of MLF data.
StreamDescriptionPtr stream = make_shared<StreamDescription>();
stream->m_id = 0;
stream->m_name = name;
stream->m_sampleLayout = make_shared<TensorShape>(dimension);
stream->m_storageType = StorageType::sparse_csc;
stream->m_elementType = m_elementType;
m_streams.push_back(stream);
}
// NOTE(review): these two free-function declarations are neither defined nor called anywhere in the
// code visible in this file; they look like leftovers copied from another deserializer - confirm and remove.
void InitializeFeatureInformation();
void InitializeAugmentationWindow(ConfigHelper& config);
// Currently MLF has a single chunk.
// TODO: This will be changed when the deserializer properly supports chunking.
2016-03-01 13:11:49 +03:00
ChunkDescriptions MLFDataDeserializer::GetChunkDescriptions()
{
2016-04-15 16:06:46 +03:00
auto cd = make_shared<ChunkDescription>();
2016-03-21 12:50:00 +03:00
cd->m_id = 0;
cd->m_numberOfSequences = m_frameMode ? m_totalNumberOfFrames : m_numberOfSequences;
cd->m_numberOfSamples = m_totalNumberOfFrames;
2016-03-01 13:11:49 +03:00
return ChunkDescriptions{cd};
}
// Gets sequences for a particular chunk.
void MLFDataDeserializer::GetSequencesForChunk(ChunkIdType, vector<SequenceDescription>& result)
{
UNUSED(result);
LogicError("Mlf deserializer does not support primary mode - it cannot control chunking.");
}
// Returns the only chunk this deserializer has (id 0); it serves sequences through GetSequenceById.
// Raises LogicError for any other chunk id.
ChunkPtr MLFDataDeserializer::GetChunk(ChunkIdType chunkId)
{
    // Checked in every build: an assert alone would let release builds silently hand out
    // chunk 0 for an id that was never described by GetChunkDescriptions.
    if (chunkId != 0)
    {
        LogicError("MLFDataDeserializer: chunk %" PRIu64 " does not exist, only chunk 0 is available.", (size_t)chunkId);
    }
    return make_shared<MLFChunk>(this);
}
// Sparse labels for an utterance.
template <class ElemType>
struct MLFSequenceData : SparseSequenceData
{
2016-04-15 16:06:46 +03:00
vector<ElemType> m_values;
unique_ptr<IndexType[]> m_indicesPtr;
MLFSequenceData(size_t numberOfSamples) :
m_values(numberOfSamples, 1),
m_indicesPtr(new IndexType[numberOfSamples])
{
if (numberOfSamples > numeric_limits<IndexType>::max())
{
RuntimeError("Number of samples in an MLFSequence (%" PRIu64 ") "
"exceeds the maximum allowed value (%" PRIu64 ")\n",
numberOfSamples, (size_t)numeric_limits<IndexType>::max());
}
m_nnzCounts.resize(numberOfSamples, static_cast<IndexType>(1));
2016-05-27 18:02:21 +03:00
m_numberOfSamples = (uint32_t) numberOfSamples;
m_totalNnzCount = static_cast<IndexType>(numberOfSamples);
m_indices = m_indicesPtr.get();
2016-09-08 16:36:34 +03:00
}
const void* GetDataBuffer() override
{
return m_values.data();
}
};
void MLFDataDeserializer::GetSequenceById(size_t sequenceId, vector<SequenceDataPtr>& result)
{
if (m_frameMode)
{
size_t label = m_classIds[sequenceId];
assert(label < m_categories.size());
result.push_back(m_categories[label]);
}
else
{
// Packing labels for the utterance into sparse sequence.
size_t startFrameIndex = m_utteranceIndex[sequenceId];
size_t numberOfSamples = m_utteranceIndex[sequenceId + 1] - startFrameIndex;
SparseSequenceDataPtr s;
if (m_elementType == ElementType::tfloat)
{
2016-04-15 16:06:46 +03:00
s = make_shared<MLFSequenceData<float>>(numberOfSamples);
}
else
{
assert(m_elementType == ElementType::tdouble);
2016-04-15 16:06:46 +03:00
s = make_shared<MLFSequenceData<double>>(numberOfSamples);
}
for (size_t i = 0; i < numberOfSamples; i++)
{
size_t frameIndex = startFrameIndex + i;
size_t label = m_classIds[frameIndex];
s->m_indices[i] = static_cast<IndexType>(label);
}
result.push_back(s);
}
}
bool MLFDataDeserializer::GetSequenceDescriptionByKey(const KeyType& key, SequenceDescription& result)
{
auto sequenceId = key.m_sequence < m_keyToSequence.size() ? m_keyToSequence[key.m_sequence] : SIZE_MAX;
if (sequenceId == SIZE_MAX)
{
return false;
}
result.m_chunkId = 0;
result.m_key = key;
if (m_frameMode)
{
size_t index = m_utteranceIndex[sequenceId] + key.m_sample;
result.m_id = index;
result.m_numberOfSamples = 1;
}
else
{
assert(result.m_key.m_sample == 0);
result.m_id = sequenceId;
2016-05-27 18:02:21 +03:00
result.m_numberOfSamples = (uint32_t) (m_utteranceIndex[sequenceId + 1] - m_utteranceIndex[sequenceId]);
}
return true;
}
}}}