Adding GetCurrentSamplePosition to the reader interface

This commit is contained in:
Eldar Akchurin 2016-09-28 13:31:02 +02:00
Родитель 3c459d347e
Коммит 4d7d430359
24 изменённых файлов: 173 добавлений и 29 удалений

Просмотреть файл

@ -224,6 +224,13 @@ void DataReader::StartDistributedMinibatchLoop(size_t mbSize, size_t epoch, size
}
}
// Reports the current sample position for this composite reader.
// BUGBUG: composition of old readers is not supported, so only the reader
// registered under the last IO name is consulted.
size_t DataReader::GetCurrentSamplePosition()
{
    const auto& lastIoName = m_ioNames.back();
    return m_dataReaders[lastIoName]->GetCurrentSamplePosition();
}
// GetMinibatch - Get the next minibatch (features and labels)
// matrices - [in] a map with named matrix types (i.e. 'features', 'labels') mapped to the corresponding matrix,
// [out] each matrix resized if necessary containing data.

Просмотреть файл

@ -244,6 +244,9 @@ public:
{
return true;
};
// Gets current sample position on the global timeline.
virtual size_t GetCurrentSamplePosition() = 0;
virtual void StartDistributedMinibatchLoop(size_t mbSize, size_t epoch, size_t subsetNum, size_t numSubsets, size_t requestedEpochSamples = requestDataSize)
{
@ -416,6 +419,8 @@ public:
}
virtual ~DataReader();
size_t GetCurrentSamplePosition() override;
// StartMinibatchLoop - Startup a minibatch loop
// mbSize - [in] size of the minibatch (number of frames, etc.)
// epoch - [in] epoch number for this loop

Просмотреть файл

@ -100,6 +100,11 @@ public:
{
}
// Override of the current-sample-position query.
// This reader does not provide a position: the body only invokes
// NOT_IMPLEMENTED and has no return statement.
size_t GetCurrentSamplePosition() override
{
NOT_IMPLEMENTED;
}
// StartMinibatchLoop - Startup a minibatch loop
// mbSize - [in] size of the minibatch (number of frames, etc.)
// epoch - [in] epoch number for this loop

Просмотреть файл

@ -609,6 +609,11 @@ public:
{
NOT_IMPLEMENTED;
};
// Override of the current-sample-position query; returns m_mbStartSample.
// NOTE(review): presumably the start sample of the current minibatch — confirm
// against where m_mbStartSample is updated.
size_t GetCurrentSamplePosition() override
{
return m_mbStartSample;
}
};
template <class ElemType>

Просмотреть файл

@ -182,5 +182,12 @@ public:
{
NOT_IMPLEMENTED;
}
// Override of the current-sample-position query.
// No position is returned: the body only invokes NOT_IMPLEMENTED.
size_t GetCurrentSamplePosition() override
{
// We do not support adaptive minibatch for this reader,
// CTF should be used instead.
NOT_IMPLEMENTED;
}
};
} } }

Просмотреть файл

@ -197,6 +197,9 @@ public:
void SetSentenceEnd(int /*actualMbSize*/){};
void SetRandomSeed(int){NOT_IMPLEMENTED};
//bool RequireSentenceSeg() const override { return !m_frameMode; };
// Override of the current-sample-position query.
// Forwards to m_mbiter->currentmbstartframe(), i.e. the value is whatever
// start frame m_mbiter reports for its current minibatch.
// NOTE(review): assumes m_mbiter is non-null when this is called — confirm.
size_t GetCurrentSamplePosition() override
{
return m_mbiter->currentmbstartframe();
}
};
} } }

Просмотреть файл

@ -212,6 +212,11 @@ public:
{
pMBLayout->CopyFrom(m_pMBLayout);
}
// Override of the current-sample-position query.
// Forwards to m_mbiter->currentmbstartframe(), i.e. the value is whatever
// start frame m_mbiter reports for its current minibatch.
// NOTE(review): assumes m_mbiter is non-null when this is called — confirm.
size_t GetCurrentSamplePosition() override
{
return m_mbiter->currentmbstartframe();
}
//bool RequireSentenceSeg() const override { return !m_framemode; };
};
} } }

Просмотреть файл

@ -286,7 +286,10 @@ public:
virtual bool DataEnd();
//int GetSentenceEndIdFromOutputLabel() { return -1; };
// Override of the current-sample-position query; returns m_mbStartSample.
// NOTE(review): presumably the start sample of the current minibatch — confirm
// against where m_mbStartSample is updated.
size_t GetCurrentSamplePosition() override
{
return m_mbStartSample;
}
};
template <class ElemType>

Просмотреть файл

@ -191,8 +191,10 @@ public:
virtual bool GetData(const std::wstring& sectionName, size_t numRecords, void* data, size_t& dataBufferSize, size_t recordStart = 0);
//public:
// int GetSentenceEndIdFromOutputLabel();
// Override of the current-sample-position query; returns m_mbStartSample.
// NOTE(review): presumably the start sample of the current minibatch — confirm
// against where m_mbStartSample is updated.
size_t GetCurrentSamplePosition() override
{
return m_mbStartSample;
}
};
template <class ElemType>

Просмотреть файл

@ -294,6 +294,13 @@ public:
void SetNbrSlicesEachRecurrentIter(const size_t){};
void SetSentenceEndInBatch(std::vector<size_t>& /*sentenceEnd*/){};
// Override of the current-sample-position query.
// No position is returned: the body only invokes NOT_IMPLEMENTED.
size_t GetCurrentSamplePosition() override
{
// We do not support adaptive minibatch for this reader,
// CTF should be used instead.
NOT_IMPLEMENTED;
}
private:
#if DEBUG
marker_series* reader_series;

Просмотреть файл

@ -52,6 +52,11 @@ BlockRandomizer::BlockRandomizer(
}
}
// Returns the randomizer's current position on the global timeline, in samples,
// as tracked by m_globalSamplePosition.
size_t BlockRandomizer::GetCurrentSamplePosition()
{
return m_globalSamplePosition;
}
// Start a new epoch.
void BlockRandomizer::StartEpoch(const EpochConfiguration& config)
{

Просмотреть файл

@ -63,6 +63,9 @@ public:
return m_deserializer->GetStreamDescriptions();
}
// Returns current position in the global timeline. The returned value is in samples.
size_t GetCurrentSamplePosition() override;
~BlockRandomizer()
{
if (m_prefetch.valid())

Просмотреть файл

@ -14,7 +14,6 @@ namespace Microsoft { namespace MSR { namespace CNTK {
NoRandomizer::NoRandomizer(IDataDeserializerPtr deserializer, bool multithreadedGetNextSequences)
: m_deserializer(deserializer),
m_samplePositionInEpoch(0),
m_currentChunkPosition(CHUNKID_MAX),
m_globalSamplePosition(0),
m_totalNumberOfSamples(0),
@ -58,7 +57,7 @@ void NoRandomizer::StartEpoch(const EpochConfiguration& config)
m_config.m_totalEpochSizeInSamples = m_totalNumberOfSamples;
}
m_samplePositionInEpoch = 0;
m_currentSequencePositionInChunk = 0;
m_globalSamplePosition = m_config.m_totalEpochSizeInSamples * config.m_epochIndex;
size_t sweepSamplePosition = m_globalSamplePosition % m_totalNumberOfSamples;
@ -77,35 +76,24 @@ void NoRandomizer::StartEpoch(const EpochConfiguration& config)
}
// Moving current sequence inside the chunk to match the sample offset.
// Currently linear, happens only at the border of epochs.
size_t sampleOffsetInsideChunk = sweepSamplePosition - m_chunkSampleOffset[m_currentChunkPosition];
size_t numberOfSamples = 0;
size_t sequenceId = 0;
// Currently linear, happens only at the border of epochs.
for (size_t i = 0; i < m_sequenceWindow.size(); ++i)
while (m_currentSequencePositionInChunk < m_sequenceWindow.size() &&
numberOfSamples < sampleOffsetInsideChunk)
{
size_t sequenceSize = m_sequenceWindow[i].m_numberOfSamples;
if (sequenceSize + numberOfSamples > sampleOffsetInsideChunk)
{
// We have found our sequence.
break;
}
numberOfSamples += sequenceSize;
sequenceId++;
numberOfSamples += m_sequenceWindow[m_currentSequencePositionInChunk].m_numberOfSamples;
m_currentSequencePositionInChunk++;
}
m_currentSequencePositionInChunk = sequenceId;
// Updating the global position
m_globalSamplePosition = (m_globalSamplePosition - sweepSamplePosition) + m_chunkSampleOffset[m_currentChunkPosition] + numberOfSamples;
assert(m_chunkDescriptions[m_currentChunkPosition]->m_numberOfSequences > m_currentSequencePositionInChunk);
};
// Moving the cursor to the next sequence. Possibly updating the chunk information if needed.
void NoRandomizer::MoveToNextSequence()
{
SequenceDescription& sequence = m_sequenceWindow[m_currentSequencePositionInChunk];
m_samplePositionInEpoch += sequence.m_numberOfSamples;
m_globalSamplePosition += sequence.m_numberOfSamples;
if (m_currentSequencePositionInChunk + 1 >= m_chunkDescriptions[m_currentChunkPosition]->m_numberOfSequences)
{
// Moving to the next chunk.
@ -135,6 +123,7 @@ std::vector<SequenceDescription> NoRandomizer::GetNextSequenceDescriptions(size_
const SequenceDescription& sequence = m_sequenceWindow[m_currentSequencePositionInChunk];
result.push_back(sequence);
samples -= (int)sequence.m_numberOfSamples;
m_globalSamplePosition += sequence.m_numberOfSamples;
MoveToNextSequence();
}
@ -143,10 +132,15 @@ std::vector<SequenceDescription> NoRandomizer::GetNextSequenceDescriptions(size_
return result;
}
// Returns the current position on the global timeline, in samples,
// as tracked by m_globalSamplePosition.
size_t NoRandomizer::GetCurrentSamplePosition()
{
return m_globalSamplePosition;
}
Sequences NoRandomizer::GetNextSequences(size_t sampleCount)
{
Sequences result;
if (m_config.m_totalEpochSizeInSamples <= m_samplePositionInEpoch)
if (m_globalSamplePosition >= m_config.m_totalEpochSizeInSamples * (m_config.m_epochIndex + 1))
{
result.m_endOfEpoch = true;
return result;

Просмотреть файл

@ -27,6 +27,8 @@ public:
return m_deserializer->GetStreamDescriptions();
}
size_t GetCurrentSamplePosition() override;
private:
// Gets next sequence descriptions with total size less than sampleCount.
std::vector<SequenceDescription> GetNextSequenceDescriptions(size_t sampleCount);
@ -75,9 +77,6 @@ private:
// TODO: possibly recalculate it based on samplePositionInEpoch.
size_t m_globalSamplePosition;
// Current sample position in the epoch.
size_t m_samplePositionInEpoch;
// Total number of samples in the sweep.
size_t m_totalNumberOfSamples;
};

Просмотреть файл

@ -98,7 +98,6 @@ struct Minibatch
//////////////////////////////////////////////////////////////////////////////////////////////////
// Main Reader interface. The border interface between the CNTK and reader libraries.
// TODO: Expect to change in a little bit: stream matrices provided by the network as input.
//////////////////////////////////////////////////////////////////////////////////////////////////
class Reader
{
@ -109,6 +108,12 @@ public:
// Starts a new epoch with the provided configuration
virtual void StartEpoch(const EpochConfiguration& config, const std::map<std::wstring, int>& inputDescriptions) = 0;
// Returns current position in the global timeline. The returned value is in samples.
// TODO: Currently in case of sequence to sequence training,
// TODO: the logical sequence size in samples = max(constituting sequences among all streams)
// TODO: This will change in the future.
virtual size_t GetCurrentSamplePosition() = 0;
// Reads a minibatch that contains data across all streams.
virtual Minibatch ReadMinibatch() = 0;

Просмотреть файл

@ -63,4 +63,9 @@ Minibatch ReaderBase::ReadMinibatch()
return m_packer->ReadMinibatch();
}
// Returns the current position in samples by delegating to the
// sequence enumerator held in m_sequenceEnumerator.
// NOTE(review): assumes m_sequenceEnumerator has already been set up — confirm.
size_t ReaderBase::GetCurrentSamplePosition()
{
return m_sequenceEnumerator->GetCurrentSamplePosition();
}
}}}

Просмотреть файл

@ -27,6 +27,9 @@ namespace Microsoft { namespace MSR { namespace CNTK {
// Reads a single minibatch.
Minibatch ReadMinibatch() override;
// Returns current position in the global timeline. The returned value is in samples.
size_t GetCurrentSamplePosition() override;
virtual ~ReaderBase() = 0;
protected:

Просмотреть файл

@ -120,6 +120,7 @@ void ReaderShim<ElemType>::StartDistributedMinibatchLoop(
m_endOfEpoch = false;
m_reader->StartEpoch(config, inputDescriptions);
m_currentSamplePosition = m_reader->GetCurrentSamplePosition();
auto localCurrentDataTransferIndex = m_currentDataTransferIndex;
// Starting the prefetch task. There is always a single async read in flight.
@ -184,6 +185,10 @@ bool ReaderShim<ElemType>::GetMinibatch(StreamMinibatchInputs& matrices)
auto result = m_prefetchTask.get();
// Ok, prefetch is done.
// Let's update our sample position.
m_currentSamplePosition = m_reader->GetCurrentSamplePosition();
m_endOfEpoch = result.m_isEndOfEpoch;
if (m_endOfEpoch && !result.m_isDataAvailable)
{
@ -339,6 +344,12 @@ size_t ReaderShim<ElemType>::GetNumParallelSequencesForFixingBPTTMode()
return m_numParallelSequences;
}
// Returns the sample position cached in m_currentSamplePosition.
// The underlying reader is not queried here; the cached value is refreshed
// from m_reader in StartDistributedMinibatchLoop and in GetMinibatch.
template <class ElemType>
size_t ReaderShim<ElemType>::GetCurrentSamplePosition()
{
return m_currentSamplePosition;
}
template class ReaderShim<float>;
template class ReaderShim<double>;
} } }

Просмотреть файл

@ -82,6 +82,8 @@ public:
virtual size_t GetNumParallelSequencesForFixingBPTTMode() override;
virtual size_t GetCurrentSamplePosition() override;
private:
struct PrefetchResult
{
@ -126,6 +128,11 @@ private:
// Device id.
int m_deviceId;
// Current sample position of the reader on the global timeline.
// We have to remember the value locally before starting prefetch.
// The value is updated only from the main thread (in StartEpoch/GetMinibatch)
size_t m_currentSamplePosition;
static void FillMatrixFromStream(
StorageType type,
Matrix<ElemType>* matrix,

Просмотреть файл

@ -44,6 +44,9 @@ public:
// Gets next sequences up to a maximum count of samples.
virtual Sequences GetNextSequences(size_t sampleCount) = 0;
// Returns current position in the global timeline. The returned value is in samples.
virtual size_t GetCurrentSamplePosition() = 0;
virtual ~SequenceEnumerator()
{
}

Просмотреть файл

@ -41,6 +41,12 @@ public:
m_outputStreams = transformedStreams;
}
// Returns current position in the global timeline. The returned value is in samples.
// The value is obtained by forwarding the call to m_sequenceProvider.
size_t GetCurrentSamplePosition() override
{
return m_sequenceProvider->GetCurrentSamplePosition();
}
// Sets configuration for the current epoch.
// Some transformers can change their config based on the epoch.
virtual void StartEpoch(const EpochConfiguration &config) override

Просмотреть файл

@ -94,5 +94,12 @@ public:
RuntimeError("GetData not supported in SparsePCReader");
};
virtual bool DataEnd();
// Override of the current-sample-position query.
// No position is returned: the body only invokes NOT_IMPLEMENTED.
size_t GetCurrentSamplePosition() override
{
// We do not support adaptive minibatch for this reader,
// CTF should be used instead.
NOT_IMPLEMENTED;
}
};
} } }

Просмотреть файл

@ -131,6 +131,7 @@ public:
{
InitFromConfig(config);
}
virtual void Destroy();
UCIFastReader()
{
@ -181,5 +182,10 @@ public:
{
NOT_IMPLEMENTED;
}
// Override of the current-sample-position query; returns m_mbStartSample.
// NOTE(review): presumably the start sample of the current minibatch — confirm
// against where m_mbStartSample is updated.
size_t GetCurrentSamplePosition() override
{
return m_mbStartSample;
}
};
} } }

Просмотреть файл

@ -163,6 +163,47 @@ void BlockRandomizerInstantiateTest(bool prefetch)
auto randomizer = make_shared<BlockRandomizer>(0, SIZE_MAX, mockDeserializer, prefetch, BlockRandomizer::DecimationMode::chunk, false);
}
// Verifies that GetCurrentSamplePosition() of both randomizers advances by
// exactly the amount of data returned for each epoch, and that re-reading an
// earlier epoch brings the cursor back to the value seen the first time.
BOOST_AUTO_TEST_CASE(CheckCurrentCursorForRandomizers)
{
    // Parameters of the synthetic corpus handed to the sequential deserializer.
    size_t chunkSizeInSamples = 10000;
    size_t sweepNumberOfSamples = 500000;
    uint32_t maxSequenceLength = 300;
    // The block randomizer gets a window worth five chunks of samples.
    size_t randomizationWindow = 5 * chunkSizeInSamples;

    auto deserializer = make_shared<SequentialDeserializer>(0, chunkSizeInSamples, sweepNumberOfSamples, maxSequenceLength);
    auto blockRandomizer = make_shared<BlockRandomizer>(0, randomizationWindow, deserializer, true, BlockRandomizer::DecimationMode::chunk, false);
    auto noRandomizer = make_shared<NoRandomizer>(deserializer, false);

    // Reads epochs 0, 1 and 2 in order and then epoch 1 once more, checking the
    // cursor reported by the enumerator after every read.
    auto verifyCursor = [](SequenceEnumeratorPtr enumerator, size_t samplesPerEpoch)
    {
        auto epoch0 = ReadFullEpoch(enumerator, samplesPerEpoch, 0);
        auto cursorAfterEpoch0 = enumerator->GetCurrentSamplePosition();
        BOOST_CHECK_EQUAL(cursorAfterEpoch0, epoch0.size());

        auto epoch1 = ReadFullEpoch(enumerator, samplesPerEpoch, 1);
        auto cursorAfterEpoch1 = enumerator->GetCurrentSamplePosition();
        BOOST_CHECK_EQUAL(cursorAfterEpoch1 - cursorAfterEpoch0, epoch1.size());

        auto epoch2 = ReadFullEpoch(enumerator, samplesPerEpoch, 2);
        auto cursorAfterEpoch2 = enumerator->GetCurrentSamplePosition();
        BOOST_CHECK_EQUAL(cursorAfterEpoch2 - cursorAfterEpoch1, epoch2.size());

        // Rolling back to epoch 1 must end on the same cursor as the first pass.
        auto epoch1Again = ReadFullEpoch(enumerator, samplesPerEpoch, 1);
        auto cursorAfterEpoch1Again = enumerator->GetCurrentSamplePosition();
        BOOST_CHECK_EQUAL(cursorAfterEpoch1Again, cursorAfterEpoch1);
    };

    // Inside sweep: with 50000 samples per epoch, epochs 0-2 all fit within
    // the first 500000-sample sweep.
    size_t epochSize = 50000;
    verifyCursor(blockRandomizer, epochSize);
    verifyCursor(noRandomizer, epochSize);

    // Between sweeps: with sweep / 1.5 samples per epoch, epoch 1 straddles
    // the boundary between the first and second sweep.
    epochSize = static_cast<size_t>(sweepNumberOfSamples / 1.5);
    verifyCursor(blockRandomizer, epochSize);
    verifyCursor(noRandomizer, epochSize);
}
BOOST_AUTO_TEST_CASE(RandRollbackToEarlierEpochBetweenSweeps)
{