Ability to choose mbsize based on stream
This commit is contained in:
Parent: bf3a37f394
Commit: 98ac0be1c2

Makefile (2 changes)
Makefile

@@ -342,6 +342,8 @@ READER_SRC =\
 	$(SOURCEDIR)/Readers/ReaderLib/FramePacker.cpp \
 	$(SOURCEDIR)/Readers/ReaderLib/ReaderBase.cpp \
 	$(SOURCEDIR)/Readers/ReaderLib/Indexer.cpp \
+	$(SOURCEDIR)/Readers/ReaderLib/MemoryBuffer.cpp \
+	$(SOURCEDIR)/Readers/ReaderLib/DataDeserializerBase.cpp \
 	$(SOURCEDIR)/Readers/ReaderLib/ChunkCache.cpp \
 	$(SOURCEDIR)/Readers/ReaderLib/ReaderUtil.cpp \

CNTKLibrary.h

@@ -5163,20 +5163,21 @@ namespace CNTK

     struct StreamConfiguration
     {
-        StreamConfiguration(const std::wstring& streamName, size_t dim, bool isSparse = false, const std::wstring& streamAlias = L"")
-            : m_streamName(streamName), m_dim(dim), m_isSparse(isSparse), m_streamAlias(streamAlias)
+        StreamConfiguration(const std::wstring& streamName, size_t dim, bool isSparse = false, const std::wstring& streamAlias = L"", bool definesMbSize = false)
+            : m_streamName(streamName), m_dim(dim), m_isSparse(isSparse), m_streamAlias(streamAlias), m_definesMbSize(definesMbSize)
         {}

         std::wstring m_streamName;
         size_t m_dim;
         bool m_isSparse;
         std::wstring m_streamAlias;
+        bool m_definesMbSize;
     };

     struct HTKFeatureConfiguration
     {
-        HTKFeatureConfiguration(const std::wstring& streamName, const std::wstring& scp, size_t dim, size_t left, size_t right, bool broadcast)
-            : m_streamName(streamName), m_dim(dim), m_scp(scp), m_left(left), m_right(right), m_broadcast(broadcast)
+        HTKFeatureConfiguration(const std::wstring& streamName, const std::wstring& scp, size_t dim, size_t left, size_t right, bool broadcast, bool definesMbSize = false)
+            : m_streamName(streamName), m_dim(dim), m_scp(scp), m_left(left), m_right(right), m_broadcast(broadcast), m_definesMbSize(definesMbSize)
         {}

         std::wstring m_streamName;
@@ -5185,6 +5186,7 @@ namespace CNTK
         size_t m_left;
         size_t m_right;
         bool m_broadcast;
+        bool m_definesMbSize;
     };

     typedef Dictionary ImageTransform;
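
Editor's note: to illustrate the API extension above, here is a minimal usage sketch. It is not part of this commit; the stream names, dimensions, and aliases are invented, and only the constructor signature shown in the diff is assumed.

    // The new trailing constructor argument is the definesMbSize flag:
    // minibatch size is counted in samples of the stream that sets it to true.
    #include <string>
    #include <vector>
    #include "CNTKLibrary.h"

    std::vector<CNTK::StreamConfiguration> MakeStreams()
    {
        return {
            // streamName, dim, isSparse, streamAlias (definesMbSize defaults to false)
            CNTK::StreamConfiguration(L"features", 1000, true, L"x"),
            // This (short) labels stream now drives how large a minibatch is.
            CNTK::StreamConfiguration(L"labels", 5, false, L"y", /*definesMbSize=*/true)
        };
    }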

@@ -376,6 +376,7 @@ namespace CNTK
     Dictionary stream;
     stream[L"dim"] = s.m_dim;
     stream[L"format"] = s.m_isSparse ? L"sparse" : L"dense";
+    stream[L"definesMBSize"] = s.m_definesMbSize;
     if (!s.m_streamAlias.empty())
         stream[L"alias"] = s.m_streamAlias;
     input[key] = stream;
@@ -394,6 +395,7 @@ namespace CNTK
     Dictionary stream;
     std::vector<DictionaryValue> ctxWindow = { DictionaryValue(s.m_left), DictionaryValue(s.m_right) };
     stream.Add(L"scpFile", s.m_scp, L"dim", s.m_dim, L"contextWindow", ctxWindow, L"expandToUtterance", s.m_broadcast);
+    stream[L"definesMBSize"] = s.m_definesMbSize;
     input[key] = stream;
 }
 htk.Add(L"type", L"HTKFeatureDeserializer", L"input", input);

TextConfigHelper.cpp

@@ -68,6 +68,7 @@ TextConfigHelper::TextConfigHelper(const ConfigParameters& config)
     stream.m_id = id++;
     stream.m_name = name;
     stream.m_sampleDimension = input2(L"dim");
+    stream.m_definesMbSize = input2(L"definesMBSize", false);
     string type = input2(L"format");

     if (AreEqualIgnoreCase(type, "dense"))

TextParser.cpp

@@ -37,7 +37,7 @@ template <class ElemType>
 class TextParser<ElemType>::TextDataChunk : public Chunk, public std::enable_shared_from_this<Chunk>
 {
 public:
-    explicit TextDataChunk(const ChunkDescriptor& descriptor, TextParser* parser);
+    explicit TextDataChunk(TextParser* parser);

     // Gets sequences by id.
     void GetSequence(size_t sequenceId, std::vector<SequenceDataPtr>& result) override;
@@ -45,9 +45,6 @@ public:
     // A map from sequence ids to the sequence data.
     std::vector<SequenceBuffer> m_sequenceMap;

-    // chunk id (copied from the descriptor)
-    ChunkIdType m_id;
-
     // a non-owned pointer to the parser that created this chunk
     TextParser* m_parser;
 };
@@ -76,6 +73,7 @@ TextParser(corpus, helper.GetFilePath(), helper.GetStreams(), primary)
 template <class ElemType>
 TextParser<ElemType>::TextParser(CorpusDescriptorPtr corpus, const std::wstring& filename, const vector<StreamDescriptor>& streams, bool primary) :
     DataDeserializerBase(primary),
+    m_streamDescriptors(streams),
     m_filename(filename),
     m_file(nullptr),
     m_streamInfos(streams.size()),
@@ -159,7 +157,7 @@ void TextParser<ElemType>::Initialize()
         fclose(m_file);
     m_file = fopenOrDie(m_filename, L"rbS");
 }

 if (funicode(m_file))
 {
     // Retrying won't help here, the file is UTF-16 encoded.
@@ -168,7 +166,21 @@ void TextParser<ElemType>::Initialize()
         "UTF-16 encoding is currently not supported.", m_filename.c_str());
 }

-m_indexer = make_unique<Indexer>(m_file, m_primary, m_skipSequenceIds, NAME_PREFIX, m_chunkSizeBytes);
+std::string mainStreamAlias = "";
+auto mainStream = std::find_if(m_streamDescriptors.begin(), m_streamDescriptors.end(), [](const StreamDescriptor& s) { return s.m_definesMbSize; });
+if (mainStream != m_streamDescriptors.end())
+{
+    mainStreamAlias = mainStream->m_alias;
+    set<wstring> streams;
+    for (auto s : m_streamDescriptors)
+        if (s.m_definesMbSize)
+            streams.insert(s.m_name);
+
+    if (streams.size() > 1)
+        RuntimeError("Only a single stream is allowed to define the minibatch size, but %zu found.", streams.size());
+}
+
+m_indexer = make_unique<Indexer>(m_file, m_primary, m_skipSequenceIds, NAME_PREFIX, m_chunkSizeBytes, mainStreamAlias);
 m_indexer->Build(m_corpus);
 });
@@ -191,14 +203,14 @@ ChunkDescriptions TextParser<ElemType>::GetChunkDescriptions()
 const auto& index = m_indexer->GetIndex();

 ChunkDescriptions result;
-result.reserve(index.m_chunks.size());
-for (auto const& chunk : index.m_chunks)
+result.reserve(index.Chunks().size());
+for (ChunkIdType i = 0; i < index.Chunks().size(); ++i)
 {
     result.push_back(shared_ptr<ChunkDescription>(
-        new ChunkDescription {
-            chunk.m_id,
-            chunk.m_numberOfSamples,
-            chunk.m_numberOfSequences
+        new ChunkDescription{
+            i,
+            index.Chunks()[i].NumSamples(),
+            index.Chunks()[i].Sequences().size()
         }));
 }
@@ -209,12 +221,12 @@ template <class ElemType>
 void TextParser<ElemType>::GetSequencesForChunk(ChunkIdType chunkId, std::vector<SequenceDescription>& result)
 {
     const auto& index = m_indexer->GetIndex();
-    const auto& chunk = index.m_chunks[chunkId];
-    result.reserve(chunk.m_sequences.size());
+    const auto& chunk = index.Chunks()[chunkId];
+    result.reserve(chunk.Sequences().size());

-    for (size_t sequenceIndex = 0; sequenceIndex < chunk.m_sequences.size(); ++sequenceIndex)
+    for (size_t sequenceIndex = 0; sequenceIndex < chunk.Sequences().size(); ++sequenceIndex)
     {
-        auto const& s = chunk.m_sequences[sequenceIndex];
+        auto const& s = chunk.Sequences()[sequenceIndex];
         result.push_back(
         {
             sequenceIndex,
@@ -226,10 +238,9 @@ void TextParser<ElemType>::GetSequencesForChunk(ChunkIdType chunkId, std::vector
 }

 template <class ElemType>
-TextParser<ElemType>::TextDataChunk::TextDataChunk(const ChunkDescriptor& descriptor, TextParser* parser) :
+TextParser<ElemType>::TextDataChunk::TextDataChunk(TextParser* parser) :
     m_parser(parser)
 {
-    m_id = descriptor.m_id;
 }

 template <class ElemType>
@@ -245,8 +256,8 @@ void TextParser<ElemType>::TextDataChunk::GetSequence(size_t sequenceId, std::ve
 template <class ElemType>
 ChunkPtr TextParser<ElemType>::GetChunk(ChunkIdType chunkId)
 {
-    const auto& chunkDescriptor = m_indexer->GetIndex().m_chunks[chunkId];
-    auto textChunk = make_shared<TextDataChunk>(chunkDescriptor, this);
+    const auto& chunkDescriptor = m_indexer->GetIndex().Chunks()[chunkId];
+    auto textChunk = make_shared<TextDataChunk>(this);

     attempt(m_numRetries, [this, &textChunk, &chunkDescriptor]()
     {
@@ -264,10 +275,10 @@ ChunkPtr TextParser<ElemType>::GetChunk(ChunkIdType chunkId)
 template <class ElemType>
 void TextParser<ElemType>::LoadChunk(TextChunkPtr& chunk, const ChunkDescriptor& descriptor)
 {
-    chunk->m_sequenceMap.resize(descriptor.m_sequences.size());
-    for (size_t sequenceIndex = 0; sequenceIndex < descriptor.m_sequences.size(); ++sequenceIndex)
+    chunk->m_sequenceMap.resize(descriptor.Sequences().size());
+    for (size_t sequenceIndex = 0; sequenceIndex < descriptor.Sequences().size(); ++sequenceIndex)
     {
-        const auto& sequenceDescriptor = descriptor.m_sequences[sequenceIndex];
+        const auto& sequenceDescriptor = descriptor.Sequences()[sequenceIndex];
         chunk->m_sequenceMap[sequenceIndex] = LoadSequence(sequenceDescriptor, descriptor.m_offset);
     }
 }
@@ -356,7 +367,9 @@ typename TextParser<ElemType>::SequenceBuffer TextParser<ElemType>::LoadSequence
 }

 size_t numRowsRead = 0, expectedRowCount = sequenceDsc.m_numberOfSamples;
-for (size_t i = 0; i < expectedRowCount; i++)
+bool checkExpectedAsMax = m_indexer->MainStream().empty();
+size_t rowNumber = 1;
+while (bytesToRead)
 {
     if ((TryReadRow(sequence, bytesToRead)))
     {
@@ -369,26 +382,24 @@ typename TextParser<ElemType>::SequenceBuffer TextParser<ElemType>::LoadSequence
             fprintf(stderr,
                 "WARNING: Could not read a row (# %" PRIu64 ")"
                 " while loading sequence (id = %" PRIu64 ") %ls.\n",
-                i + 1,
+                rowNumber,
                 sequenceDsc.m_key,
                 GetFileInfo().c_str());
         }
         IncrementNumberOfErrorsOrDie();
     }
-
-    if (!bytesToRead && numRowsRead < expectedRowCount)
-    {
-        if (ShouldWarn())
-        {
-            fprintf(stderr,
-                "WARNING: Exhausted all input"
-                " expected for the current sequence (id = %" PRIu64 ") %ls,"
-                " but only read %" PRIu64 " out of %" PRIu64 " expected rows.\n",
-                sequenceDsc.m_key,
-                GetFileInfo().c_str(), numRowsRead, expectedRowCount);
-        }
-        break;
-    }
+    rowNumber++;
 }

+if (ShouldWarn() && checkExpectedAsMax && numRowsRead < expectedRowCount)
+{
+    fprintf(stderr,
+        "WARNING: Exhausted all input"
+        " expected for the current sequence (id = %" PRIu64 ") %ls,"
+        " but only read %" PRIu64 " out of %" PRIu64 " expected rows.\n",
+        sequenceDsc.m_key,
+        GetFileInfo().c_str(), numRowsRead, expectedRowCount);
+}

 // Double check if there are empty input streams.
@@ -405,7 +416,7 @@ typename TextParser<ElemType>::SequenceBuffer TextParser<ElemType>::LoadSequence
     hasEmptyInputs = true;
 }

-if (sequence[i]->m_numberOfSamples > expectedRowCount)
+if (checkExpectedAsMax && sequence[i]->m_numberOfSamples > expectedRowCount)
 {
     hasDuplicateInputs = true;
     if (ShouldWarn())
@@ -430,7 +441,7 @@ typename TextParser<ElemType>::SequenceBuffer TextParser<ElemType>::LoadSequence
 {
     IncrementNumberOfErrorsOrDie();
 }
-else if (maxInputLength < expectedRowCount)
+else if (checkExpectedAsMax && maxInputLength < expectedRowCount)
 {
     if (ShouldWarn())
     {
@@ -1289,31 +1300,9 @@ std::wstring TextParser<ElemType>::GetFileInfo()
 }

 template <class ElemType>
-bool TextParser<ElemType>::GetSequenceDescriptionByKey(const KeyType& key, SequenceDescription& result)
+bool TextParser<ElemType>::GetSequenceDescriptionByKey(const KeyType& key, SequenceDescription& r)
 {
-    if (m_primary)
-        LogicError("Matching by sequence key is not supported for primary deserilalizer.");
-
-    const auto& keys = m_indexer->GetIndex().m_keyToSequenceInChunk;
-    auto sequenceLocation = keys.find(key.m_sequence);
-    if (sequenceLocation == keys.end())
-    {
-        return false;
-    }
-
-    const auto& index = m_indexer->GetIndex();
-
-    assert(sequenceLocation->second.first < index.m_chunks.size());
-    const auto& chunk = index.m_chunks[sequenceLocation->second.first];
-
-    assert(sequenceLocation->second.second < chunk.m_sequences.size());
-    const auto& sequence = chunk.m_sequences[sequenceLocation->second.second];
-
-    result.m_chunkId = sequenceLocation->second.first;
-    result.m_indexInChunk = sequenceLocation->second.second;
-    result.m_numberOfSamples = sequence.m_numberOfSamples;
-    result.m_key = key;
-    return true;
+    return DataDeserializerBase::GetSequenceDescriptionByKey(m_indexer->GetIndex(), key, r);
 }

 template class TextParser<float>;

TextParser.h

@@ -98,6 +98,7 @@ private:
 // into sequence data in a proper format.
 struct StreamInfo;
 std::vector<StreamInfo> m_streamInfos;
+std::vector<StreamDescriptor> m_streamDescriptors;

 size_t m_maxAliasLength;
 std::map<std::string, size_t> m_aliasToIdMap;

MLFDeserializer.cpp

@@ -76,12 +76,12 @@ protected:
     m_descriptor(descriptor),
     m_deserializer(deserializer)
 {
-    if (descriptor.m_sequences.empty() || !descriptor.m_byteSize)
+    if (descriptor.Sequences().empty() || !descriptor.SizeInBytes())
         LogicError("Empty chunks are not supported.");

     auto f = shared_ptr<FILE>(fopenOrDie(fileName, L"rbS"), [](FILE *f) { if (f) fclose(f); });
     size_t sizeInBytes =
-        descriptor.m_sequences.back().OffsetInChunk() + descriptor.m_sequences.back().SizeInBytes();
+        descriptor.Sequences().back().OffsetInChunk() + descriptor.Sequences().back().SizeInBytes();

     // Make sure we always have 0 at the end for buffer overrun.
     m_buffer.resize(sizeInBytes + 1);
@@ -97,7 +97,7 @@ protected:
     freadOrDie(m_buffer.data(), 1, sizeInBytes, f.get());

     // all sequences are valid by default.
-    m_valid.resize(m_descriptor.m_numberOfSequences, true);
+    m_valid.resize(m_descriptor.Sequences().size(), true);
 }

 string KeyOf(const SequenceDescriptor& s)
@@ -122,11 +122,11 @@ public:
 SequenceChunk(const MLFDeserializer& parent, const ChunkDescriptor& descriptor, const wstring& fileName, StateTablePtr states)
     : ChunkBase(parent, descriptor, fileName, states)
 {
-    m_sequences.resize(m_descriptor.m_numberOfSequences);
+    m_sequences.resize(m_descriptor.Sequences().size());

 #pragma omp parallel for schedule(dynamic)
-    for (int i = 0; i < descriptor.m_sequences.size(); ++i)
-        CacheSequence(descriptor.m_sequences[i], i);
+    for (int i = 0; i < descriptor.Sequences().size(); ++i)
+        CacheSequence(descriptor.Sequences()[i], i);

     CleanBuffer();
 }
@@ -172,7 +172,7 @@ public:
 }

 const auto& utterance = m_sequences[sequenceIndex];
-const auto& sequence = m_descriptor.m_sequences[sequenceIndex];
+const auto& sequence = m_descriptor.Sequences()[sequenceIndex];

 // Packing labels for the utterance into sparse sequence.
 vector<size_t> sequencePhoneBoundaries(m_deserializer.m_withPhoneBoundaries ? utterance.size() : 0);
@@ -213,12 +213,12 @@ public:
     : ChunkBase(parent, descriptor, fileName, states)
 {
     // Preallocate a big array for filling in class ids for the whole chunk.
-    m_classIds.resize(m_descriptor.m_numberOfSamples);
+    m_classIds.resize(m_descriptor.NumSamples());

     // Parse the data on different threads to avoid locking during GetSequence calls.
 #pragma omp parallel for schedule(dynamic)
-    for (int i = 0; i < descriptor.m_sequences.size(); ++i)
-        CacheSequence(descriptor.m_sequences[i], i);
+    for (int i = 0; i < descriptor.Sequences().size(); ++i)
+        CacheSequence(descriptor.Sequences()[i], i);

     CleanBuffer();
 }
@@ -228,11 +228,11 @@ public:
 size_t GetUtteranceForChunkFrameIndex(size_t frameIndex) const
 {
     auto result = upper_bound(
-        m_descriptor.m_sequenceOffsetInChunkInSamples.begin(),
-        m_descriptor.m_sequenceOffsetInChunkInSamples.end(),
+        m_descriptor.SequenceOffsetInSamples().begin(),
+        m_descriptor.SequenceOffsetInSamples().end(),
         frameIndex,
         [](size_t fi, const size_t& a) { return fi < a; });
-    return result - 1 - m_descriptor.m_sequenceOffsetInChunkInSamples.begin();
+    return result - 1 - m_descriptor.SequenceOffsetInSamples().begin();
 }

 void GetSequence(size_t sequenceIndex, vector<SequenceDataPtr>& result) override
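
Editor's note: a standalone sketch (not part of this commit) of the lookup performed by GetUtteranceForChunkFrameIndex above — given cumulative per-sequence sample offsets within a chunk, find which sequence contains a given frame. The custom comparator in the diff is equivalent to the default ordering used here.

    #include <algorithm>
    #include <cassert>
    #include <cstdint>
    #include <vector>

    // offsets[i] is the index of the first sample of sequence i, e.g. {0, 3, 7}
    // for sequences of length 3, 4, ... within one chunk.
    size_t UtteranceForFrame(const std::vector<uint32_t>& offsets, size_t frameIndex)
    {
        auto it = std::upper_bound(offsets.begin(), offsets.end(), frameIndex);
        return (it - offsets.begin()) - 1; // last offset <= frameIndex
    }

    int main()
    {
        std::vector<uint32_t> offsets = {0, 3, 7};
        assert(UtteranceForFrame(offsets, 0) == 0);
        assert(UtteranceForFrame(offsets, 3) == 1); // first frame of sequence 1
        assert(UtteranceForFrame(offsets, 8) == 2);
        return 0;
    }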
@@ -267,7 +267,7 @@ public:
     return;
 }

-auto startRange = m_classIds.begin() + m_descriptor.m_sequenceOffsetInChunkInSamples[index];
+auto startRange = m_classIds.begin() + m_descriptor.SequenceOffsetInSamples()[index];
 for (size_t i = 0; i < utterance.size(); ++i)
 {
     const auto& range = utterance[i];
@@ -385,18 +385,18 @@ void MLFDeserializer::InitializeChunkDescriptions(CorpusDescriptorPtr corpus, co

 // Build auxiliary for GetSequenceByKey.
 const auto& index = indexer->GetIndex();
-for (uint32_t chunkIndex = 0; chunkIndex < index.m_chunks.size(); ++chunkIndex)
+for (uint32_t chunkIndex = 0; chunkIndex < index.Chunks().size(); ++chunkIndex)
 {
-    const auto& chunk = index.m_chunks[chunkIndex];
+    const auto& chunk = index.Chunks()[chunkIndex];
     // Preparing chunk info that will be exposed to the outside.
-    for (uint32_t i = 0; i < chunk.m_sequences.size(); ++i)
+    for (uint32_t i = 0; i < chunk.Sequences().size(); ++i)
     {
-        const auto& sequence = chunk.m_sequences[i];
+        const auto& sequence = chunk.Sequences()[i];
         m_keyToChunkLocation.push_back(std::make_tuple(sequence.m_key, static_cast<ChunkIdType>(m_chunks.size()), i));
     }

-    totalNumSequences += chunk.m_numberOfSequences;
-    totalNumFrames += chunk.m_numberOfSamples;
+    totalNumSequences += chunk.Sequences().size();
+    totalNumFrames += chunk.NumSamples();
     m_chunkToFileIndex.insert(make_pair(&chunk, m_mlfFiles.size() - 1));
     m_chunks.push_back(&chunk);
     if (m_chunks.size() >= numeric_limits<ChunkIdType>::max())
@@ -458,8 +458,8 @@ ChunkDescriptions MLFDeserializer::GetChunkDescriptions()
 if (cd->m_id != i)
     RuntimeError("ChunkIdType overflow during creation of a chunk description.");

-cd->m_numberOfSequences = m_frameMode ? m_chunks[i]->m_numberOfSamples : m_chunks[i]->m_numberOfSequences;
-cd->m_numberOfSamples = m_chunks[i]->m_numberOfSamples;
+cd->m_numberOfSequences = m_frameMode ? m_chunks[i]->NumSamples() : m_chunks[i]->Sequences().size();
+cd->m_numberOfSamples = m_chunks[i]->NumSamples();
 chunks.push_back(cd);
 }
 return chunks;
@@ -502,15 +502,16 @@ bool MLFDeserializer::GetSequenceDescriptionByKey(const KeyType& key, SequenceDe
 auto chunkId = std::get<1>(*found);
 auto sequenceIndexInChunk = std::get<2>(*found);

 const auto* chunk = m_chunks[chunkId];
-const auto& sequence = chunk->m_sequences[sequenceIndexInChunk];
+const auto& sequence = chunk->Sequences()[sequenceIndexInChunk];

 result.m_chunkId = std::get<1>(*found);
 result.m_key = key;

 if (m_frameMode)
 {
-    result.m_indexInChunk = chunk->m_sequenceOffsetInChunkInSamples[sequenceIndexInChunk] + key.m_sample;
+    result.m_indexInChunk = chunk->SequenceOffsetInSamples()[sequenceIndexInChunk] + key.m_sample;
     result.m_numberOfSamples = 1;
 }
 else

Base64ImageDeserializer.cpp

@@ -29,13 +29,13 @@ namespace Microsoft { namespace MSR { namespace CNTK {
 if (ferror(m_deserializer.m_dataFile.get()) != 0)
     m_deserializer.m_dataFile.reset(fopenOrDie(m_deserializer.m_fileName.c_str(), L"rbS"), [](FILE* f) { if (f) fclose(f); });

-if (descriptor.m_sequences.empty() || !descriptor.m_byteSize)
+if (descriptor.Sequences().empty() || !descriptor.SizeInBytes())
     LogicError("Empty chunks are not supported.");

-m_buffer.resize(descriptor.m_byteSize + 1);
+m_buffer.resize(descriptor.SizeInBytes() + 1);

 // Make sure we always have 0 at the end for buffer overrun.
-m_buffer[descriptor.m_byteSize] = 0;
+m_buffer[descriptor.SizeInBytes()] = 0;
 m_chunkOffset = descriptor.m_offset;

 // Read chunk into memory.
@@ -43,7 +43,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
 if (rc)
     RuntimeError("Error seeking to position '%" PRId64 "' in the input file '%ls', error code '%d'", m_chunkOffset, m_deserializer.m_fileName.c_str(), rc);

-freadOrDie(m_buffer.data(), descriptor.m_byteSize, 1, m_deserializer.m_dataFile.get());
+freadOrDie(m_buffer.data(), descriptor.SizeInBytes(), 1, m_deserializer.m_dataFile.get());
 }

 std::string KeyOf(const SequenceDescriptor& s) const
@@ -56,7 +56,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
 const size_t innerSequenceIndex = m_deserializer.m_multiViewCrop ? sequenceIndex / ImageDeserializerBase::NumMultiViewCopies : sequenceIndex;
 const size_t copyId = m_deserializer.m_multiViewCrop ? sequenceIndex % ImageDeserializerBase::NumMultiViewCopies : 0;

-const auto& sequence = m_descriptor.m_sequences[innerSequenceIndex];
+const auto& sequence = m_descriptor.Sequences()[innerSequenceIndex];
 const size_t offset = sequence.OffsetInChunk();

 // m_buffer always ends on 0, so no overrun can happen.
@@ -160,13 +160,14 @@ namespace Microsoft { namespace MSR { namespace CNTK {
 // In case of multi crop the deserializer provides the same sequence NumMultiViewCopies times.
 size_t sequencesPerInitialSequence = m_multiViewCrop ? ImageDeserializerBase::NumMultiViewCopies : 1;
 ChunkDescriptions result;
-result.reserve(index.m_chunks.size() * sequencesPerInitialSequence);
-for (auto const& chunk : index.m_chunks)
+result.reserve(index.Chunks().size() * sequencesPerInitialSequence);
+for (uint32_t i = 0; i < index.Chunks().size(); ++i)
 {
+    const auto& chunk = index.Chunks()[i];
     auto c = std::make_shared<ChunkDescription>();
-    c->m_id = chunk.m_id;
-    assert(chunk.m_numberOfSamples == chunk.m_numberOfSequences);
-    c->m_numberOfSamples = c->m_numberOfSequences = chunk.m_numberOfSequences * sequencesPerInitialSequence;
+    c->m_id = i;
+    assert(chunk.NumSamples() == chunk.Sequences().size());
+    c->m_numberOfSamples = c->m_numberOfSequences = chunk.Sequences().size() * sequencesPerInitialSequence;
     result.push_back(c);
 }
 return result;
@@ -175,13 +176,13 @@ namespace Microsoft { namespace MSR { namespace CNTK {
 void Base64ImageDeserializer::GetSequencesForChunk(ChunkIdType chunkId, std::vector<SequenceDescription>& result)
 {
     const auto& index = m_indexer->GetIndex();
-    const auto& chunk = index.m_chunks[chunkId];
+    const auto& chunk = index.Chunks()[chunkId];
     size_t sequenceCopies = m_multiViewCrop ? NumMultiViewCopies : 1;
-    result.reserve(sequenceCopies * chunk.m_sequences.size());
+    result.reserve(sequenceCopies * chunk.Sequences().size());
     size_t currentId = 0;
-    for (uint32_t indexInChunk = 0; indexInChunk < chunk.m_sequences.size(); ++indexInChunk)
+    for (uint32_t indexInChunk = 0; indexInChunk < chunk.Sequences().size(); ++indexInChunk)
     {
-        auto const& s = chunk.m_sequences[indexInChunk];
+        auto const& s = chunk.Sequences()[indexInChunk];
         assert(currentId / sequenceCopies == indexInChunk);
         for (size_t i = 0; i < sequenceCopies; ++i)
         {
@@ -200,30 +201,12 @@ namespace Microsoft { namespace MSR { namespace CNTK {

 ChunkPtr Base64ImageDeserializer::GetChunk(ChunkIdType chunkId)
 {
-    const auto& chunkDescriptor = m_indexer->GetIndex().m_chunks[chunkId];
+    const auto& chunkDescriptor = m_indexer->GetIndex().Chunks()[chunkId];
     return make_shared<ImageChunk>(chunkDescriptor, *this);
 }

-bool Base64ImageDeserializer::GetSequenceDescriptionByKey(const KeyType& key, SequenceDescription& result)
+bool Base64ImageDeserializer::GetSequenceDescriptionByKey(const KeyType& key, SequenceDescription& r)
 {
-    const auto& index = m_indexer->GetIndex();
-
-    const auto& keys = index.m_keyToSequenceInChunk;
-    auto sequenceLocation = keys.find(key.m_sequence);
-    if (sequenceLocation == keys.end())
-        return false;
-
-    assert(sequenceLocation->second.first < index.m_chunks.size());
-    const auto& chunk = index.m_chunks[sequenceLocation->second.first];
-
-    assert(sequenceLocation->second.second < chunk.m_sequences.size());
-    const auto sequence = chunk.m_sequences[sequenceLocation->second.second];
-
-    result.m_chunkId = sequenceLocation->second.first;
-    result.m_indexInChunk = sequenceLocation->second.second;
-    result.m_key = { sequence.m_key, 0 };
-    result.m_numberOfSamples = sequence.m_numberOfSamples;
-    return true;
+    return DataDeserializerBase::GetSequenceDescriptionByKey(m_indexer->GetIndex(), key, r);
 }

 }}}

DataDeserializerBase.cpp (new file)

@@ -0,0 +1,35 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
+//
+#define _CRT_SECURE_NO_WARNINGS
+
+#include "DataDeserializerBase.h"
+#include "Indexer.h"
+
+namespace Microsoft { namespace MSR { namespace CNTK {
+
+bool DataDeserializerBase::GetSequenceDescriptionByKey(const Index& index, const KeyType& key, SequenceDescription& r)
+{
+    if (m_primary)
+        LogicError("Matching by sequence key is not supported for primary deserializer.");
+
+    auto sequenceLocation = index.GetSequenceByKey(key.m_sequence);
+    if (!std::get<0>(sequenceLocation))
+        return false;
+
+    r.m_chunkId = std::get<1>(sequenceLocation);
+    r.m_indexInChunk = std::get<2>(sequenceLocation);
+    r.m_key = key;
+
+    assert(r.m_chunkId < index.Chunks().size());
+    const auto& chunk = index.Chunks()[r.m_chunkId];
+
+    assert(r.m_indexInChunk < chunk.Sequences().size());
+    const auto& sequence = chunk.Sequences()[r.m_indexInChunk];
+
+    r.m_numberOfSamples = sequence.m_numberOfSamples;
+    return true;
+}
+
+}}}

DataDeserializerBase.h

@@ -9,6 +9,8 @@

 namespace Microsoft { namespace MSR { namespace CNTK {

+struct Index;
+
 // Base class for data deserializers.
 // Has a default implementation for a subset of methods.
 class DataDeserializerBase : public IDataDeserializer
@@ -33,6 +35,8 @@ protected:
     NOT_IMPLEMENTED;
 }

+bool GetSequenceDescriptionByKey(const Index& index, const KeyType& key, SequenceDescription& r);
+
 // Streams this data deserializer can produce.
 std::vector<StreamDescriptionPtr> m_streams;

Indexer.cpp

@@ -7,85 +7,50 @@
 #define _CRT_SECURE_NO_WARNINGS
 #include <inttypes.h>
 #include "Indexer.h"
+#include <boost/utility/string_ref.hpp>
+#include <boost/algorithm/string.hpp>

 using std::string;

-const static char ROW_DELIMITER = '\n';
-
 namespace Microsoft { namespace MSR { namespace CNTK {

-Indexer::Indexer(FILE* file, bool primary, bool skipSequenceIds, char streamPrefix, size_t chunkSize, size_t bufferSize) :
+Indexer::Indexer(FILE* file, bool primary, bool skipSequenceIds, char streamPrefix, size_t chunkSize, const std::string& mainStream, size_t bufferSize) :
     m_streamPrefix(streamPrefix),
-    m_bufferSize(bufferSize),
+    m_buffer(bufferSize, !mainStream.empty()),
     m_file(file),
-    m_fileOffsetStart(0),
-    m_fileOffsetEnd(0),
-    m_buffer(new char[bufferSize + 1]),
-    m_bufferStart(nullptr),
-    m_bufferEnd(nullptr),
-    m_pos(nullptr),
-    m_done(false),
     m_hasSequenceIds(!skipSequenceIds),
-    m_index(chunkSize, primary)
+    m_index(chunkSize, primary),
+    m_mainStream(mainStream)
 {
     if (m_file == nullptr)
     {
         RuntimeError("Input file not open for reading");
     }

     fseekOrDie(m_file, 0, SEEK_SET);
-}
-
-void Indexer::RefillBuffer()
-{
-    if (!m_done)
-    {
-        size_t bytesRead = fread(m_buffer.get(), 1, m_bufferSize, m_file);
-        if (bytesRead == (size_t)-1)
-            RuntimeError("Could not read from the input file.");
-        if (bytesRead == 0)
-        {
-            m_done = true;
-        }
-        else
-        {
-            m_fileOffsetStart = m_fileOffsetEnd;
-            m_fileOffsetEnd += bytesRead;
-            m_bufferStart = m_buffer.get();
-            m_pos = m_bufferStart;
-            m_bufferEnd = m_bufferStart + bytesRead;
-        }
-    }
+    m_fileSize = filesize(file);
 }

 void Indexer::BuildFromLines()
 {
-    assert(m_pos == m_bufferStart);
     m_hasSequenceIds = false;
     size_t lines = 0;
-    int64_t offset = GetFileOffset();
-    while (!m_done)
+    int64_t offset = m_buffer.GetFileOffset();
+    while (!m_buffer.Eof())
     {
-        m_pos = (char*)memchr(m_pos, ROW_DELIMITER, m_bufferEnd - m_pos);
-        if (m_pos)
+        auto pos = m_buffer.MoveToNextLine();
+        if (pos)
         {
             auto sequenceOffset = offset;
-            offset = GetFileOffset() + 1;
+            offset = m_buffer.GetFileOffset();
             m_index.AddSequence(SequenceDescriptor{ lines, 1 }, sequenceOffset, offset);
-            ++m_pos;
             ++lines;
         }
         else
-        {
-            RefillBuffer();
-        }
+            m_buffer.RefillFrom(m_file);
     }

-    if (offset < m_fileOffsetEnd)
+    if (offset < m_fileSize)
     {
         // There's a number of characters, not terminated by a newline,
         // add a sequence to the index, parser will have to deal with it.
-        m_index.AddSequence(SequenceDescriptor{ lines, 1 }, offset, m_fileOffsetEnd);
+        m_index.AddSequence(SequenceDescriptor{ lines, 1 }, offset, m_fileSize);
     }
 }
@@ -104,38 +69,31 @@ void Indexer::Build(CorpusDescriptorPtr corpus)
 else
     tryGetSequenceId = [this, corpus](size_t& id) { return TryGetSymbolicSequenceId(id, corpus->KeyToId); };

-m_index.Reserve(filesize(m_file));
+m_index.Reserve(m_fileSize);

-RefillBuffer(); // read the first block of data
-if (m_done)
-{
+m_buffer.RefillFrom(m_file);
+if (m_buffer.Eof())
     RuntimeError("Input file is empty");
-}

-if ((m_bufferEnd - m_bufferStart > 3) &&
-    (m_bufferStart[0] == '\xEF' && m_bufferStart[1] == '\xBB' && m_bufferStart[2] == '\xBF'))
-{
-    // input file contains UTF-8 BOM value, skip it.
-    m_pos += 3;
-    m_fileOffsetStart += 3;
-    m_bufferStart += 3;
-}
+m_buffer.SkipBOMIfPresent();

 // check the first byte and decide what to do next
-if (!m_hasSequenceIds || m_bufferStart[0] == m_streamPrefix)
+if (!m_hasSequenceIds || *m_buffer.m_current == m_streamPrefix)
 {
     // Skip sequence id parsing, treat lines as individual sequences
     // In this case the sequences do not have ids, they are assigned a line number.
     // If corpus expects to have sequence ids as symbolic names we throw.
     if (!corpus->IsNumericSequenceKeys())
-        RuntimeError("Corpus expects non-numeric sequence keys but the CTF input file does not have them.");
+        RuntimeError("Corpus expects non-numeric sequence keys present but the input file does not have them."
+                     "Please use the configuration to enable numeric keys instead.");

     BuildFromLines();
+    m_index.MapSequenceKeyToLocation();
     return;
 }

 size_t id = 0;
-int64_t offset = GetFileOffset();
+int64_t offset = m_buffer.GetFileOffset();
 // read the very first sequence id
 if (!tryGetSequenceId(id))
 {
@@ -145,13 +103,21 @@ void Indexer::Build(CorpusDescriptorPtr corpus)
 auto sequenceOffset = offset;
 size_t previousId = id;
 uint32_t numberOfSamples = 0;
-while (!m_done)
+while (!m_buffer.Eof())
 {
-    SkipLine(); // ignore whatever is left on this line.
-    offset = GetFileOffset(); // a new line starts at this offset;
-    numberOfSamples++;
+    if (!m_mainStream.empty())
+    {
+        if (SkipLineWithCheck())
+            numberOfSamples++;
+    }
+    else
+    {
+        SkipLine(); // ignore whatever is left on this line.
+        numberOfSamples++;
+    }

-    if (!m_done && tryGetSequenceId(id) && id != previousId)
+    offset = m_buffer.GetFileOffset(); // a new line starts at this offset;
+    if (!m_buffer.Eof() && tryGetSequenceId(id) && id != previousId)
     {
         // found a new sequence, which starts at the [offset] bytes into the file
         // adding the previous one to the index.
@@ -163,34 +129,56 @@ void Indexer::Build(CorpusDescriptorPtr corpus)
     }
 }

-m_index.AddSequence(SequenceDescriptor{ previousId, numberOfSamples }, sequenceOffset, m_fileOffsetEnd);
+m_index.AddSequence(SequenceDescriptor{ previousId, numberOfSamples }, sequenceOffset, m_fileSize);
+m_index.MapSequenceKeyToLocation();
 }

 void Indexer::SkipLine()
 {
-    while (!m_done)
+    while (!m_buffer.Eof())
     {
-        m_pos = (char*)memchr(m_pos, ROW_DELIMITER, m_bufferEnd - m_pos);
-        if (m_pos)
+        auto pos = m_buffer.MoveToNextLine();
+        if (pos)
         {
             // found a new-line character
-            if (++m_pos == m_bufferEnd)
-            {
-                RefillBuffer();
-            }
+            if (pos == m_buffer.End())
+                m_buffer.RefillFrom(m_file);
             return;
         }
-        RefillBuffer();
+        m_buffer.RefillFrom(m_file);
     }
 }

+bool Indexer::SkipLineWithCheck()
+{
+    auto currentLine = m_buffer.m_current;
+    auto pos = m_buffer.MoveToNextLine();
+    if (pos)
+    {
+        boost::string_ref s(currentLine, pos - currentLine);
+        bool found = s.find(m_mainStream) != boost::string_ref::npos;
+        if (pos == m_buffer.End())
+            m_buffer.RefillFrom(m_file);
+
+        return found;
+    }
+
+    if (currentLine != m_buffer.End())
+        RuntimeError("Unexpected end of line");
+
+    m_buffer.RefillFrom(m_file);
+    return false;
+}
+
 bool Indexer::TryGetNumericSequenceId(size_t& id)
 {
     bool found = false;
     id = 0;
-    while (!m_done)
+    while (!m_buffer.Eof())
     {
-        char c = *m_pos;
+        char c = *m_buffer.m_current;
         if (!isdigit(c))
        {
             // Stop as soon as there's a non-digit character
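
Editor's note: SkipLineWithCheck above makes a line count toward the sample total only if it contains the main stream's alias. A standalone sketch of that counting rule (not part of this commit; the alias "|y" and the data lines are invented):

    #include <cstddef>
    #include <iostream>
    #include <string>
    #include <vector>

    size_t CountSamples(const std::vector<std::string>& lines, const std::string& mainStreamAlias)
    {
        size_t samples = 0;
        for (const auto& line : lines)
            // Same check as the indexer: a plain substring search on the line.
            if (line.find(mainStreamAlias) != std::string::npos)
                ++samples;
        return samples;
    }

    int main()
    {
        std::vector<std::string> sequence = {
            "100 |x 1:1 |y 1 0 0 0 0", // has both streams: counted
            "100 |x 2:1",              // feature-only row: not counted
            "100 |x 3:1"
        };
        std::cout << CountSamples(sequence, "|y") << "\n"; // prints 1
        return 0;
    }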
@@ -205,10 +193,10 @@ bool Indexer::TryGetNumericSequenceId(size_t& id)
 }

 found = true;
-++m_pos;
+++m_buffer.m_current;

-if (m_pos == m_bufferEnd)
-    RefillBuffer();
+if (m_buffer.m_current == m_buffer.End())
+    m_buffer.RefillFrom(m_file);
 }

 // reached EOF without hitting the pipe character,
@@ -222,9 +210,9 @@ bool Indexer::TryGetSymbolicSequenceId(size_t& id, std::function<size_t(const st
 id = 0;
 std::string key;
 key.reserve(256);
-while (!m_done)
+while (!m_buffer.Eof())
 {
-    char c = *m_pos;
+    char c = *m_buffer.m_current;
     if (isspace(c))
     {
         if (found)
@@ -234,10 +222,10 @@ bool Indexer::TryGetSymbolicSequenceId(size_t& id, std::function<size_t(const st

 key += c;
 found = true;
-++m_pos;
+++m_buffer.m_current;

-if (m_pos == m_bufferEnd)
-    RefillBuffer();
+if (m_buffer.m_current == m_buffer.End())
+    m_buffer.RefillFrom(m_file);
 }

 // reached EOF without hitting the pipe character,
@@ -249,42 +237,56 @@ void Index::AddSequence(SequenceDescriptor&& sd, size_t startOffsetInFile, size_
 {
     sd.SetSize(endOffsetInFile - startOffsetInFile);

-    assert(!m_chunks.empty());
+    if (m_chunks.empty() || !m_chunks.back().HasSpaceFor(sd))
+    {
+        m_chunks.push_back({ m_maxChunkSize, startOffsetInFile });
+        if (std::numeric_limits<ChunkIdType>::max() < m_chunks.size())
+            RuntimeError("Maximum number of chunks exceeded.");
+    }
+
     ChunkDescriptor* chunk = &m_chunks.back();
-    if (chunk->m_byteSize > 0 && (chunk->m_byteSize + sd.m_byteSize) > m_maxChunkSize)
-    {
-        // If the size is exceeded, finalizing the current chunk
-        // and creating a new one.
-        chunk->m_sequences.shrink_to_fit();
-
-        m_chunks.push_back({});
-        chunk = &m_chunks.back();
-        chunk->m_id = (ChunkIdType)(m_chunks.size() - 1);
-        chunk->m_offset = startOffsetInFile;
-
-        if (CHUNKID_MAX < m_chunks.size())
-        {
-            RuntimeError("Maximum number of chunks exceeded");
-        }
-    }
-
-    if (m_trackFirstSamples) // Adding number of samples where the new sequence starts.
-        chunk->m_sequenceOffsetInChunkInSamples.push_back(static_cast<uint32_t>(chunk->m_numberOfSamples));
-
-    chunk->m_byteSize += sd.m_byteSize;
-    chunk->m_numberOfSequences++;
-    chunk->m_numberOfSamples += sd.m_numberOfSamples;
-    if (!m_primary)
-    {
-        auto location = std::make_pair(chunk->m_id, static_cast<uint32_t>(chunk->m_sequences.size()));
-        if (location.second != chunk->m_sequences.size())
-            RuntimeError("Number of sequences overflow the chunk capacity.");
-
-        m_keyToSequenceInChunk.insert(std::make_pair(sd.m_key, location));
-    }
-
     sd.SetOffsetInChunk(startOffsetInFile - chunk->m_offset);
-    chunk->m_sequences.push_back(sd);
+    chunk->AddSequence(std::move(sd), m_trackFirstSamples);
 }

+std::tuple<bool, uint32_t, uint32_t> Index::GetSequenceByKey(size_t key) const
+{
+    auto found = std::lower_bound(m_keyToSequenceInChunk.begin(), m_keyToSequenceInChunk.end(), key,
+        [](const std::tuple<size_t, size_t, size_t>& a, size_t b)
+        {
+            return std::get<0>(a) < b;
+        });
+
+    if (found == m_keyToSequenceInChunk.end() || std::get<0>(*found) != key)
+    {
+        return std::make_tuple(false, 0, 0);
+    }
+
+    return std::make_tuple(true, std::get<1>(*found), std::get<2>(*found));
+}
+
+void Index::MapSequenceKeyToLocation()
+{
+    if (m_primary)
+        return;
+
+    // Precalculate size of the mapping.
+    size_t numSequences = 0;
+    for (const auto& c : m_chunks)
+        numSequences += c.Sequences().size();
+
+    m_keyToSequenceInChunk.reserve(numSequences);
+
+    for (uint32_t i = 0; i < m_chunks.size(); i++)
+        for (uint32_t j = 0; j < m_chunks[i].Sequences().size(); j++)
+            m_keyToSequenceInChunk.emplace_back(m_chunks[i].Sequences()[j].m_key, i, j);
+
+    // Sort for fast retrieval afterwards
+    std::sort(m_keyToSequenceInChunk.begin(), m_keyToSequenceInChunk.end(),
+        [](const std::tuple<size_t, uint32_t, uint32_t>& a, const std::tuple<size_t, uint32_t, uint32_t>& b)
+        {
+            return std::get<0>(a) < std::get<0>(b);
+        });
+}
+
 }}}
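
Editor's note: the key lookup above replaces a std::map with a sorted vector of tuples built once after indexing, then binary-searched per query. A self-contained sketch of that pattern (not part of this commit):

    #include <algorithm>
    #include <cstdint>
    #include <iostream>
    #include <tuple>
    #include <vector>

    using Location = std::tuple<size_t, uint32_t, uint32_t>; // key, chunk, index in chunk

    bool Find(const std::vector<Location>& sorted, size_t key, uint32_t& chunk, uint32_t& index)
    {
        auto it = std::lower_bound(sorted.begin(), sorted.end(), key,
            [](const Location& a, size_t b) { return std::get<0>(a) < b; });
        if (it == sorted.end() || std::get<0>(*it) != key)
            return false;
        chunk = std::get<1>(*it);
        index = std::get<2>(*it);
        return true;
    }

    int main()
    {
        std::vector<Location> mapping = { {42, 0, 1}, {7, 1, 0}, {100, 1, 3} };
        std::sort(mapping.begin(), mapping.end()); // sort once, by key (first element)
        uint32_t chunk, index;
        if (Find(mapping, 42, chunk, index))
            std::cout << "key 42 -> chunk " << chunk << ", sequence " << index << "\n";
        return 0;
    }

A vector of tuples stores the mapping contiguously and avoids per-node allocation, which matters when an index holds millions of sequence keys.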

Indexer.h

@@ -9,6 +9,7 @@
 #include <vector>
 #include "DataDeserializer.h"
 #include "CorpusDescriptor.h"
+#include "MemoryBuffer.h"

 namespace Microsoft { namespace MSR { namespace CNTK {

@@ -53,16 +54,59 @@ private:

     uint32_t m_offsetInChunk; // sequence offset in the chunk (in bytes)
     uint32_t m_byteSize; // size in bytes

     friend struct Index;
+    friend class ChunkDescriptor;
 };

 // Chunk metadata, similar to the sequence descriptor above,
 // but used to facilitate indexing and retrieval of blobs of input data of
 // some user-specified size.
-struct ChunkDescriptor : ChunkDescription
+class ChunkDescriptor
 {
-    ChunkDescriptor() : ChunkDescription({}), m_byteSize(0), m_offset(0) {}
+    ChunkDescriptor() : m_maxSizeInBytes(0), m_offset(0) {}
+
+public:
+    const size_t m_maxSizeInBytes;
+
+    // offset of the chunk in bytes
+    const size_t m_offset;
+
+    ChunkDescriptor(size_t maxSizeInBytes, size_t startOffset)
+        : m_maxSizeInBytes(maxSizeInBytes), m_sizeInBytes(0),
+          m_offset(startOffset), m_numberOfSamples(0)
+    {}
+
+    bool HasSpaceFor(const SequenceDescriptor& sd) const
+    {
+        return m_sizeInBytes == 0 || m_sizeInBytes + sd.m_byteSize <= m_maxSizeInBytes;
+    }
+
+    void AddSequence(SequenceDescriptor&& sd, bool trackFirstSample = false)
+    {
+        assert(HasSpaceFor(sd));
+        if (trackFirstSample) // Adding number of samples where the new sequence starts.
+            m_sequenceOffsetInChunkInSamples.push_back(static_cast<uint32_t>(m_numberOfSamples));
+
+        m_sizeInBytes += sd.m_byteSize;
+        m_numberOfSamples += sd.m_numberOfSamples;
+        m_sequences.push_back(std::move(sd));
+
+        if (m_sizeInBytes >= m_maxSizeInBytes) // Last one, finalizing.
+            m_sequences.shrink_to_fit();
+
+        if (m_sequences.size() > std::numeric_limits<uint32_t>::max())
+            RuntimeError("Exceeded maximum number of sequences in a chunk");
+    }
+
+    size_t SizeInBytes() const { return m_sizeInBytes; }
+    size_t NumSamples() const { return m_numberOfSamples; }
+    const std::vector<SequenceDescriptor>& Sequences() const { return m_sequences; }
+
+    // Offset of first sample of each sequence from the beginning of the chunk.
+    const std::vector<uint32_t>& SequenceOffsetInSamples() const { return m_sequenceOffsetInChunkInSamples; }

+private:
     // TODO: if we don't want to keep the whole index
     // (metadata for all sequences in memory), we should not
     // leave this empty when building a chunk index, and only
@@ -70,8 +114,8 @@ struct ChunkDescriptor : ChunkDescription
     // (the indexer will have to do a second pass for this chunk).
     std::vector<SequenceDescriptor> m_sequences;

-    size_t m_offset; // offset of the chunk in bytes
-    size_t m_byteSize; // size in bytes
+    size_t m_numberOfSamples;
+    size_t m_sizeInBytes;

     // Offset of first sample of each sequence from the beginning of the chunk.
     // Optionally filled in by the indexer.
@@ -86,8 +130,16 @@ typedef shared_ptr<ChunkDescriptor> ChunkDescriptorPtr;
 // It also stores a mapping of keys into sequence descriptors.
 struct Index
 {
-    std::vector<ChunkDescriptor> m_chunks; // chunks
-    std::map<size_t, std::pair<uint32_t, uint32_t>> m_keyToSequenceInChunk; // sequence key -> <chunk index, sequence index in chunk>
+private:
+    std::vector<ChunkDescriptor> m_chunks;
+
+public:
+    const std::vector<ChunkDescriptor>& Chunks() const { return m_chunks; }
+
+    // Sorted dictionary of <sequence key, chunk index, sequence index in chunk>
+    // used for fast retrieval of sequence by key for non primary deserializers.
+    std::vector<std::tuple<size_t, uint32_t, uint32_t>> m_keyToSequenceInChunk;

     const size_t m_maxChunkSize; // maximum chunk size in bytes
     bool m_primary; // index for primary deserializer
     bool m_trackFirstSamples; // flag indicating whether to build index of first samples
@@ -97,7 +149,8 @@ struct Index

     Index(size_t chunkSize, bool primary, bool trackFirstSamples = false)
         : m_maxChunkSize(chunkSize), m_primary(primary), m_trackFirstSamples(trackFirstSamples)
-    {}
+    {
+    }

     // Adds sequence (metadata) to the index. Additionally, it
     // assigns an appropriate chunk id to the sequence descriptor,
@@ -109,11 +162,7 @@ struct Index
     void Reserve(size_t sizeInBytes)
     {
         if (m_maxChunkSize > 0)
-        {
             m_chunks.reserve((sizeInBytes + m_maxChunkSize - 1) / m_maxChunkSize);
-        }
-
-        m_chunks.push_back({});
     }

     // Checks if the index is empty.
@@ -122,6 +171,11 @@ struct Index
         return m_chunks.empty();
     }

+    // Returns true or false with chunk and sequence index depending if the key has been found.
+    std::tuple<bool, uint32_t, uint32_t> GetSequenceByKey(size_t key) const;
+
+    void MapSequenceKeyToLocation();
+
     DISABLE_COPY_AND_MOVE(Index);
 };

@@ -133,7 +187,7 @@ struct Index
 class Indexer
 {
 public:
-    Indexer(FILE* file, bool isPrimary, bool skipSequenceIds = false, char streamPrefix = '|', size_t chunkSize = 32 * 1024 * 1024, size_t bufferSize = 2 * 1024 * 1024);
+    Indexer(FILE* file, bool isPrimary, bool skipSequenceIds = false, char streamPrefix = '|', size_t chunkSize = 32 * 1024 * 1024, const std::string& mainStream = "", size_t bufferSize = 2 * 1024 * 1024);

     // Reads the input file, building an index of chunks and corresponding
     // sequences.
@@ -147,35 +201,33 @@ public:
     // (by passing skipSequenceIds = true to the constructor).
     bool HasSequenceIds() const { return m_hasSequenceIds; }

+    const std::string& MainStream() const
+    {
+        return m_mainStream;
+    }
+
 private:
     FILE* m_file;

-    int64_t m_fileOffsetStart;
-    int64_t m_fileOffsetEnd;
-
-    std::unique_ptr<char[]> m_buffer;
-    const size_t m_bufferSize;
-    const char* m_bufferStart;
-    const char* m_bufferEnd;
-    const char* m_pos; // buffer index
-
-    bool m_done; // true, when all input was processed
-
+    int64_t m_fileSize;
+    MemoryBuffer m_buffer;
     bool m_hasSequenceIds; // true, when input contains one sequence per line
                            // or when sequence id column was ignored during indexing.

+    // Stream that defines the size of the sequence.
+    std::string m_mainStream;
+
     // a collection of chunk descriptors and sequence keys.
     Index m_index;

     const char m_streamPrefix;

-    // fills up the buffer with data from file, all previously buffered data
-    // will be overwritten.
-    void RefillBuffer();
-
     // Moves the buffer position to the beginning of the next line.
     void SkipLine();

+    // Moves the buffer position to the beginning of the next line.
+    // Returns true if the current line contains m_mainStream.
+    bool SkipLineWithCheck();
+
     // Tries to get numeric sequence id.
     // Throws an exception if a non-numerical is read until the pipe character or
     // EOF is reached without hitting the pipe character.
@@ -187,15 +239,11 @@ private:
     // It reads a symbolic key and converts it to numeric id using provided keyToId function.
     bool TryGetSymbolicSequenceId(size_t& id, std::function<size_t(const std::string&)> keyToId);

-
     // Build a chunk/sequence index, treating each line as an individual sequence.
     // Does not do any sequence parsing, instead uses line number as
     // the corresponding sequence id.
     void BuildFromLines();

-    // Returns current offset in the input file (in bytes).
-    int64_t GetFileOffset() const { return m_fileOffsetStart + (m_pos - m_bufferStart); }
-
     DISABLE_COPY_AND_MOVE(Indexer);
 };

MemoryBuffer.cpp (new file)

@@ -0,0 +1,102 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
+//
+
+#define _CRT_SECURE_NO_WARNINGS
+#include "MemoryBuffer.h"
+#include <boost/utility/string_ref.hpp>
+#include <boost/algorithm/string.hpp>
+
+namespace Microsoft { namespace MSR { namespace CNTK {
+
+    using namespace std;
+
+    MemoryBuffer::MemoryBuffer(size_t maxSize, bool useCompleteLines) : m_maxSize(maxSize), m_useCompleteLines(useCompleteLines) {}
+
+    void MemoryBuffer::RefillFrom(FILE* file)
+    {
+        if (m_done)
+            return;
+
+        m_fileOffsetStart += m_data.size();
+        m_data.resize(m_maxSize);
+
+        if (!m_useCompleteLines)
+        {
+            size_t bytesRead = fread(m_data.data(), 1, m_maxSize, file);
+            if (bytesRead == (size_t)-1)
+                RuntimeError("Could not read from the input file.");
+            m_data.resize(bytesRead);
+            if (!bytesRead)
+                m_done = true;
+        }
+        else // Need to keep track of the last partial string.
+        {
+            if (m_lastPartialLineInBuffer.size() >= m_maxSize)
+                RuntimeError("Length of a sequence cannot exceed '%zu' bytes.", m_maxSize);
+
+            // Copy last partial line if it was left during the last read.
+            memcpy(&m_data[0], m_lastPartialLineInBuffer.data(), m_lastPartialLineInBuffer.size());
+
+            size_t bytesRead = fread(&m_data[0] + m_lastPartialLineInBuffer.size(), 1, m_data.size() - m_lastPartialLineInBuffer.size(), file);
+            if (bytesRead == (size_t)-1)
+                RuntimeError("Could not read from the input file.");
+
+            if (bytesRead == 0) // End of file reached.
+            {
+                boost::trim(m_lastPartialLineInBuffer);
+                if (!m_lastPartialLineInBuffer.empty())
+                    memcpy(&m_data[0], m_lastPartialLineInBuffer.data(), m_lastPartialLineInBuffer.size());
+                else
+                {
+                    m_done = true;
+                    m_data.clear();
+                }
+
+                m_lastPartialLineInBuffer.clear();
+                return;
+            }
+
+            size_t readBufferSize = m_lastPartialLineInBuffer.size() + bytesRead;
+
+            // Let's find the last LF.
+            int lastLF = 0;
+            {
+                // Let's find the latest \n if it exists.
+                for (lastLF = (int)readBufferSize - 1; lastLF >= 0; lastLF--)
+                {
+                    if (m_data[lastLF] == g_Row_Delimiter)
+                        break;
+                }
+
+                if (lastLF < 0)
+                    RuntimeError("Length of a sequence cannot exceed '%zu' bytes.", readBufferSize);
+            }
+
+            // Let's cut the buffer at the last EOL and save the partial string
+            // in m_lastPartialLineInBuffer.
+            auto logicalBufferSize = lastLF + 1;
+            auto lastPartialLineSize = readBufferSize - logicalBufferSize;
+
+            // Remember the last partial line.
+            m_lastPartialLineInBuffer.resize(lastPartialLineSize);
+            if (lastPartialLineSize)
+                memcpy(&m_lastPartialLineInBuffer[0], m_data.data() + logicalBufferSize, lastPartialLineSize);
+            m_data.resize(logicalBufferSize);
+        }
+
+        m_current = m_data.data();
+    }
+
+    void MemoryBuffer::SkipBOMIfPresent()
+    {
+        assert(m_current == m_data.data());
+        if ((m_data.size() > 3) &&
+            (m_data[0] == '\xEF' && m_data[1] == '\xBB' && m_data[2] == '\xBF'))
+        {
+            m_current += 3;
+        }
+    }
+
+}}}

MemoryBuffer.h (new file)

@@ -0,0 +1,67 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
+//
+
+#pragma once
+
+#include <stdint.h>
+#include <vector>
+#include "Basics.h"
+#include "ReaderConstants.h"
+
+namespace Microsoft { namespace MSR { namespace CNTK {
+
+    class MemoryBuffer
+    {
+    public:
+        MemoryBuffer(size_t maxSize, bool useCompleteLines = false);
+
+        // Pointer to the start of the buffer.
+        const char* Start() const { return m_data.data(); }
+
+        // Pointer to the end of the buffer.
+        const char* End() const { return m_data.data() + m_data.size(); }
+
+        // Current position in the buffer.
+        const char* m_current = 0;
+
+        // File offset that corresponds to the current position.
+        int64_t GetFileOffset() const { return m_fileOffsetStart + (m_current - Start()); }
+
+        // Skips UTF-8 BOM value, if it is present at the current position.
+        void SkipBOMIfPresent();
+
+        // Refills the buffer from the file.
+        void RefillFrom(FILE* file);
+
+        // Moves the current position to the next line.
+        // If no newline is present, returns null; otherwise returns the new position.
+        const char* MoveToNextLine()
+        {
+            m_current = (char*)memchr(m_current, g_Row_Delimiter, End() - m_current);
+            if (m_current)
+            {
+                ++m_current;
+                return m_current;
+            }
+            else
+            {
+                m_current = End();
+                return nullptr;
+            }
+        }
+
+        // Returns true if no more data is available.
+        bool Eof() const { return m_done; }
+
+    private:
+        const size_t m_maxSize;                // Max size of the buffer.
+        std::vector<char> m_data;              // Buffer.
+        int64_t m_fileOffsetStart = 0;         // Current file offset that the buffer is associated with.
+        bool m_done = false;                   // Flag indicating whether there is more data.
+        bool m_useCompleteLines;               // Flag indicating whether the buffer should only contain complete lines.
+        std::string m_lastPartialLineInBuffer; // Buffer preserving the last partial line between two sequential refills.
+    };
+
+}}}

DataDeserializer.h

@@ -78,6 +78,7 @@ struct StreamDescription
     ElementType m_elementType;     // Element type of the stream
     TensorShapePtr m_sampleLayout; // Layout of the sample for the stream
                                    // If not specified - can be specified per sequence
+    bool m_definesMbSize;          // Flag indicating whether the stream is defining the minibatch size
 };
 typedef std::shared_ptr<StreamDescription> StreamDescriptionPtr;

ReaderConstants.h

@@ -17,4 +17,6 @@ namespace Microsoft { namespace MSR { namespace CNTK {

 static size_t const g_4GB = 0x100000000L;

+const static char g_Row_Delimiter = '\n';
+
 }}}

ReaderLib.vcxproj

@@ -54,6 +54,7 @@
     <ClInclude Include="ChunkRandomizer.h" />
     <ClInclude Include="ExceptionCapture.h" />
    <ClInclude Include="Indexer.h" />
+    <ClInclude Include="MemoryBuffer.h" />
     <ClInclude Include="ReaderBase.h" />
     <ClInclude Include="ReaderConstants.h" />
     <ClInclude Include="SequenceData.h" />
@@ -83,7 +84,9 @@
     <ClCompile Include="Bundler.cpp" />
     <ClCompile Include="ChunkCache.cpp" />
     <ClCompile Include="ChunkRandomizer.cpp" />
+    <ClCompile Include="DataDeserializerBase.cpp" />
     <ClCompile Include="Indexer.cpp" />
+    <ClCompile Include="MemoryBuffer.cpp" />
     <ClCompile Include="NoRandomizer.cpp" />
     <ClCompile Include="BlockRandomizer.cpp" />
     <ClCompile Include="PackerBase.cpp" />

ReaderLib.vcxproj.filters

@@ -94,6 +94,9 @@
     <ClInclude Include="ReaderConstants.h">
       <Filter>Utils</Filter>
     </ClInclude>
+    <ClInclude Include="MemoryBuffer.h">
+      <Filter>Utils</Filter>
+    </ClInclude>
   </ItemGroup>
   <ItemGroup>
     <ClCompile Include="NoRandomizer.cpp">
@@ -138,6 +141,12 @@
     <ClCompile Include="ReaderUtil.cpp">
       <Filter>Utils</Filter>
     </ClCompile>
+    <ClCompile Include="DataDeserializerBase.cpp">
+      <Filter>Deserializers</Filter>
+    </ClCompile>
+    <ClCompile Include="MemoryBuffer.cpp">
+      <Filter>Utils</Filter>
+    </ClCompile>
   </ItemGroup>
   <ItemGroup>
     <Filter Include="Interfaces">

cntk/io/__init__.py (Python bindings)

@@ -504,11 +504,13 @@ def HTKFeatureDeserializer(streams):
         dimension = stream.dim
         scp_file = stream['scp']
         broadcast = stream['broadcast'] if 'broadcast' in stream else False
+        defines_mb_size = stream['defines_mb_size'] if 'defines_mb_size' in stream else False
         left_context, right_context = stream.context if 'context' in stream\
                                                      else (0, 0)
         htk_config = cntk_py.HTKFeatureConfiguration(stream_name, scp_file,
                                                      dimension, left_context,
-                                                     right_context, broadcast)
+                                                     right_context, broadcast,
+                                                     defines_mb_size)
         feat.append(htk_config)

     if len(feat) == 0:
@@ -626,6 +628,9 @@ def CTFDeserializer(filename, streams):

     Args:
         filename (str): file name containing the text input
+        streams: any dictionary-like object that contains a mapping from stream
+            names to :class:`StreamDef` objects. Each StreamDef object configures
+            an input stream.

     See also:
         :cntkwiki:`CNTKTextReader format <BrainScript-CNTKTextFormat-Reader>`
@@ -635,7 +640,7 @@ def CTFDeserializer(filename, streams):
         raise ValueError("CTFDeserializer: stream name for key %s must be "
                          "specified" % k)
     sc = [cntk_py.StreamConfiguration(
-        k, s.dim, s.is_sparse, s.stream_alias) for k, s in streams.items()]
+        k, s.dim, s.is_sparse, s.stream_alias, s['defines_mb_size']) for k, s in streams.items()]
     return cntk_py.ctf_deserializer(filename, sc)

 # TODO: this should be a private class; use StreamDef instead
@@ -654,16 +659,18 @@ class StreamConfiguration(cntk_py.StreamConfiguration):
         is_sparse (bool, defaults to `False`): whether the provided data is
             sparse (`False` by default)
         stream_alias (str, defaults to ''): name of the stream in the file
+        defines_mb_size (`bool`, defaults to `False`): whether this stream defines
+            the minibatch size.
     '''

-    def __init__(self, name, dim, is_sparse=False, stream_alias=''):
+    def __init__(self, name, dim, is_sparse=False, stream_alias='', defines_mb_size=False):
         return super(StreamConfiguration, self).__init__(name, dim, is_sparse,
-                                                         stream_alias)
+                                                         stream_alias, defines_mb_size)

 # stream definition for use in StreamDefs
 # returns a record { stream_alias, is_sparse, optional shape, optional transforms, optional context, optional scp, optional mlf }
 def StreamDef(field=None, shape=None, is_sparse=False, transforms=None,
-              context=None, scp=None, mlf=None, broadcast=None):
+              context=None, scp=None, mlf=None, broadcast=None, defines_mb_size=False):
     '''
     Configuration of a stream for use with the builtin Deserializers.
     The meanings of some configuration keys have a mild dependency on the
@@ -695,6 +702,8 @@ def StreamDef(field=None, shape=None, is_sparse=False, transforms=None,
         broadcast (`bool`, defaults to `None`): whether the features in this
             stream should be broadcast to the whole sequence (useful in e.g.
             ivectors with HTK)
+        defines_mb_size (`bool`, defaults to `False`): whether this stream defines
+            the minibatch size.
     '''
     config = dict(stream_alias=field, is_sparse=is_sparse)
     if shape is not None:
@@ -710,6 +719,8 @@ def StreamDef(field=None, shape=None, is_sparse=False, transforms=None,
         config['is_sparse'] = True
     if broadcast is not None:
         config['broadcast'] = broadcast
+    config['defines_mb_size'] = True if defines_mb_size else False
+
     return Record(**config)
 # TODO: we should always use 'shape' unless it is always rank-1 or a single rank's dimension
 # TODO: dim should be inferred from the file, at least for dense

cntk/io tests (Python)

@@ -830,3 +830,53 @@ def test_usermbsource_training(tmpdir):
     session.train()

     assert trainer.total_number_of_samples_seen == 20
+
+
+def test_minibatch_defined_by_labels(tmpdir):
+    tmpfile = _write_data(tmpdir, MBDATA_SPARSE)
+
+    input_dim = 1000
+    num_output_classes = 5
+
+    mb_source = MinibatchSource(CTFDeserializer(tmpfile, StreamDefs(
+        features=StreamDef(field='x', shape=input_dim, is_sparse=True),
+        labels=StreamDef(field='y', shape=num_output_classes, is_sparse=False, defines_mb_size=True)
+    )), randomize=False)
+
+    assert isinstance(mb_source, MinibatchSource)
+
+    features_si = mb_source.stream_info('features')
+    labels_si = mb_source.stream_info('labels')
+
+    mb = mb_source.next_minibatch(2)
+
+    features = mb[features_si]
+
+    # 2 samples, max seq len 4, 1000 dim
+    assert features.shape == (2, 4, input_dim)
+    assert features.end_of_sweep
+    assert features.num_sequences == 2
+    assert features.num_samples == 7
+    assert features.is_sparse
+
+    labels = mb[labels_si]
+    # 2 samples, max seq len 1, 5 dim
+    assert labels.shape == (2, 1, num_output_classes)
+    assert labels.end_of_sweep
+    assert labels.num_sequences == 2
+    assert labels.num_samples == 2
+    assert not labels.is_sparse
+
+    label_data = labels.asarray()
+    assert np.allclose(label_data,
+                       np.asarray([
+                           [[1., 0., 0., 0., 0.]],
+                           [[0., 1., 0., 0., 0.]]
+                       ]))
+
+    mb = mb_source.next_minibatch(3)
+    features = mb[features_si]
+    labels = mb[labels_si]
+
+    assert features.num_samples == 10
+    assert labels.num_samples == 3