Adding GetCurrentSamplePosition to the reader interface
This commit is contained in:
Родитель
3c459d347e
Коммит
4d7d430359
|
@ -224,6 +224,13 @@ void DataReader::StartDistributedMinibatchLoop(size_t mbSize, size_t epoch, size
|
|||
}
|
||||
}
|
||||
|
||||
size_t DataReader::GetCurrentSamplePosition()
|
||||
{
|
||||
// BUGBUG: composition of old readers is not supported.
|
||||
// Returning just for the last reader.
|
||||
return m_dataReaders[m_ioNames.back()]->GetCurrentSamplePosition();
|
||||
}
|
||||
|
||||
// GetMinibatch - Get the next minibatch (features and labels)
|
||||
// matrices - [in] a map with named matrix types (i.e. 'features', 'labels') mapped to the corresponding matrix,
|
||||
// [out] each matrix resized if necessary containing data.
|
||||
|
|
|
@ -244,6 +244,9 @@ public:
|
|||
{
|
||||
return true;
|
||||
};
|
||||
|
||||
// Gets current sample position on the global timeline.
|
||||
virtual size_t GetCurrentSamplePosition() = 0;
|
||||
|
||||
virtual void StartDistributedMinibatchLoop(size_t mbSize, size_t epoch, size_t subsetNum, size_t numSubsets, size_t requestedEpochSamples = requestDataSize)
|
||||
{
|
||||
|
@ -416,6 +419,8 @@ public:
|
|||
}
|
||||
virtual ~DataReader();
|
||||
|
||||
size_t GetCurrentSamplePosition() override;
|
||||
|
||||
// StartMinibatchLoop - Startup a minibatch loop
|
||||
// mbSize - [in] size of the minibatch (number of frames, etc.)
|
||||
// epoch - [in] epoch number for this loop
|
||||
|
|
|
@ -100,6 +100,11 @@ public:
|
|||
{
|
||||
}
|
||||
|
||||
size_t GetCurrentSamplePosition() override
|
||||
{
|
||||
NOT_IMPLEMENTED;
|
||||
}
|
||||
|
||||
// StartMinibatchLoop - Startup a minibatch loop
|
||||
// mbSize - [in] size of the minibatch (number of frames, etc.)
|
||||
// epoch - [in] epoch number for this loop
|
||||
|
|
|
@ -609,6 +609,11 @@ public:
|
|||
{
|
||||
NOT_IMPLEMENTED;
|
||||
};
|
||||
|
||||
size_t GetCurrentSamplePosition() override
|
||||
{
|
||||
return m_mbStartSample;
|
||||
}
|
||||
};
|
||||
|
||||
template <class ElemType>
|
||||
|
|
|
@ -182,5 +182,12 @@ public:
|
|||
{
|
||||
NOT_IMPLEMENTED;
|
||||
}
|
||||
|
||||
size_t GetCurrentSamplePosition() override
|
||||
{
|
||||
// We do not support adaptive minibatch for this reader,
|
||||
// CTF should be used instead.
|
||||
NOT_IMPLEMENTED;
|
||||
}
|
||||
};
|
||||
} } }
|
||||
|
|
|
@ -197,6 +197,9 @@ public:
|
|||
void SetSentenceEnd(int /*actualMbSize*/){};
|
||||
void SetRandomSeed(int){NOT_IMPLEMENTED};
|
||||
|
||||
//bool RequireSentenceSeg() const override { return !m_frameMode; };
|
||||
size_t GetCurrentSamplePosition() override
|
||||
{
|
||||
return m_mbiter->currentmbstartframe();
|
||||
}
|
||||
};
|
||||
} } }
|
||||
|
|
|
@ -212,6 +212,11 @@ public:
|
|||
{
|
||||
pMBLayout->CopyFrom(m_pMBLayout);
|
||||
}
|
||||
|
||||
size_t GetCurrentSamplePosition() override
|
||||
{
|
||||
return m_mbiter->currentmbstartframe();
|
||||
}
|
||||
//bool RequireSentenceSeg() const override { return !m_framemode; };
|
||||
};
|
||||
} } }
|
||||
|
|
|
@ -286,7 +286,10 @@ public:
|
|||
|
||||
virtual bool DataEnd();
|
||||
|
||||
//int GetSentenceEndIdFromOutputLabel() { return -1; };
|
||||
size_t GetCurrentSamplePosition() override
|
||||
{
|
||||
return m_mbStartSample;
|
||||
}
|
||||
};
|
||||
|
||||
template <class ElemType>
|
||||
|
|
|
@ -191,8 +191,10 @@ public:
|
|||
|
||||
virtual bool GetData(const std::wstring& sectionName, size_t numRecords, void* data, size_t& dataBufferSize, size_t recordStart = 0);
|
||||
|
||||
//public:
|
||||
// int GetSentenceEndIdFromOutputLabel();
|
||||
size_t GetCurrentSamplePosition() override
|
||||
{
|
||||
return m_mbStartSample;
|
||||
}
|
||||
};
|
||||
|
||||
template <class ElemType>
|
||||
|
|
|
@ -294,6 +294,13 @@ public:
|
|||
void SetNbrSlicesEachRecurrentIter(const size_t){};
|
||||
void SetSentenceEndInBatch(std::vector<size_t>& /*sentenceEnd*/){};
|
||||
|
||||
size_t GetCurrentSamplePosition() override
|
||||
{
|
||||
// We do not support adaptive minibatch for this reader,
|
||||
// CTF should be used instead.
|
||||
NOT_IMPLEMENTED;
|
||||
}
|
||||
|
||||
private:
|
||||
#if DEBUG
|
||||
marker_series* reader_series;
|
||||
|
|
|
@ -52,6 +52,11 @@ BlockRandomizer::BlockRandomizer(
|
|||
}
|
||||
}
|
||||
|
||||
size_t BlockRandomizer::GetCurrentSamplePosition()
|
||||
{
|
||||
return m_globalSamplePosition;
|
||||
}
|
||||
|
||||
// Start a new epoch.
|
||||
void BlockRandomizer::StartEpoch(const EpochConfiguration& config)
|
||||
{
|
||||
|
|
|
@ -63,6 +63,9 @@ public:
|
|||
return m_deserializer->GetStreamDescriptions();
|
||||
}
|
||||
|
||||
// Returns current position in the global timeline. The returned value is in samples.
|
||||
size_t GetCurrentSamplePosition() override;
|
||||
|
||||
~BlockRandomizer()
|
||||
{
|
||||
if (m_prefetch.valid())
|
||||
|
|
|
@ -14,7 +14,6 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
|
||||
NoRandomizer::NoRandomizer(IDataDeserializerPtr deserializer, bool multithreadedGetNextSequences)
|
||||
: m_deserializer(deserializer),
|
||||
m_samplePositionInEpoch(0),
|
||||
m_currentChunkPosition(CHUNKID_MAX),
|
||||
m_globalSamplePosition(0),
|
||||
m_totalNumberOfSamples(0),
|
||||
|
@ -58,7 +57,7 @@ void NoRandomizer::StartEpoch(const EpochConfiguration& config)
|
|||
m_config.m_totalEpochSizeInSamples = m_totalNumberOfSamples;
|
||||
}
|
||||
|
||||
m_samplePositionInEpoch = 0;
|
||||
m_currentSequencePositionInChunk = 0;
|
||||
m_globalSamplePosition = m_config.m_totalEpochSizeInSamples * config.m_epochIndex;
|
||||
size_t sweepSamplePosition = m_globalSamplePosition % m_totalNumberOfSamples;
|
||||
|
||||
|
@ -77,35 +76,24 @@ void NoRandomizer::StartEpoch(const EpochConfiguration& config)
|
|||
}
|
||||
|
||||
// Moving current sequence inside the chunk to match the sample offset.
|
||||
// Currently linear, happens only at the border of epochs.
|
||||
size_t sampleOffsetInsideChunk = sweepSamplePosition - m_chunkSampleOffset[m_currentChunkPosition];
|
||||
size_t numberOfSamples = 0;
|
||||
size_t sequenceId = 0;
|
||||
|
||||
// Currently linear, happens only at the border of epochs.
|
||||
for (size_t i = 0; i < m_sequenceWindow.size(); ++i)
|
||||
while (m_currentSequencePositionInChunk < m_sequenceWindow.size() &&
|
||||
numberOfSamples < sampleOffsetInsideChunk)
|
||||
{
|
||||
size_t sequenceSize = m_sequenceWindow[i].m_numberOfSamples;
|
||||
if (sequenceSize + numberOfSamples > sampleOffsetInsideChunk)
|
||||
{
|
||||
// We have found our sequence.
|
||||
break;
|
||||
}
|
||||
|
||||
numberOfSamples += sequenceSize;
|
||||
sequenceId++;
|
||||
numberOfSamples += m_sequenceWindow[m_currentSequencePositionInChunk].m_numberOfSamples;
|
||||
m_currentSequencePositionInChunk++;
|
||||
}
|
||||
|
||||
m_currentSequencePositionInChunk = sequenceId;
|
||||
// Updating the global position
|
||||
m_globalSamplePosition = (m_globalSamplePosition - sweepSamplePosition) + m_chunkSampleOffset[m_currentChunkPosition] + numberOfSamples;
|
||||
assert(m_chunkDescriptions[m_currentChunkPosition]->m_numberOfSequences > m_currentSequencePositionInChunk);
|
||||
};
|
||||
|
||||
// Moving the cursor to the next sequence. Possibly updating the chunk information if needed.
|
||||
void NoRandomizer::MoveToNextSequence()
|
||||
{
|
||||
SequenceDescription& sequence = m_sequenceWindow[m_currentSequencePositionInChunk];
|
||||
m_samplePositionInEpoch += sequence.m_numberOfSamples;
|
||||
m_globalSamplePosition += sequence.m_numberOfSamples;
|
||||
|
||||
if (m_currentSequencePositionInChunk + 1 >= m_chunkDescriptions[m_currentChunkPosition]->m_numberOfSequences)
|
||||
{
|
||||
// Moving to the next chunk.
|
||||
|
@ -135,6 +123,7 @@ std::vector<SequenceDescription> NoRandomizer::GetNextSequenceDescriptions(size_
|
|||
const SequenceDescription& sequence = m_sequenceWindow[m_currentSequencePositionInChunk];
|
||||
result.push_back(sequence);
|
||||
samples -= (int)sequence.m_numberOfSamples;
|
||||
m_globalSamplePosition += sequence.m_numberOfSamples;
|
||||
|
||||
MoveToNextSequence();
|
||||
}
|
||||
|
@ -143,10 +132,15 @@ std::vector<SequenceDescription> NoRandomizer::GetNextSequenceDescriptions(size_
|
|||
return result;
|
||||
}
|
||||
|
||||
size_t NoRandomizer::GetCurrentSamplePosition()
|
||||
{
|
||||
return m_globalSamplePosition;
|
||||
}
|
||||
|
||||
Sequences NoRandomizer::GetNextSequences(size_t sampleCount)
|
||||
{
|
||||
Sequences result;
|
||||
if (m_config.m_totalEpochSizeInSamples <= m_samplePositionInEpoch)
|
||||
if (m_globalSamplePosition >= m_config.m_totalEpochSizeInSamples * (m_config.m_epochIndex + 1))
|
||||
{
|
||||
result.m_endOfEpoch = true;
|
||||
return result;
|
||||
|
|
|
@ -27,6 +27,8 @@ public:
|
|||
return m_deserializer->GetStreamDescriptions();
|
||||
}
|
||||
|
||||
size_t GetCurrentSamplePosition() override;
|
||||
|
||||
private:
|
||||
// Gets next sequence descriptions with total size less than sampleCount.
|
||||
std::vector<SequenceDescription> GetNextSequenceDescriptions(size_t sampleCount);
|
||||
|
@ -75,9 +77,6 @@ private:
|
|||
// TODO: possible recalculate it base on samplePositionInEpoch.
|
||||
size_t m_globalSamplePosition;
|
||||
|
||||
// Current sample position in the epoch.
|
||||
size_t m_samplePositionInEpoch;
|
||||
|
||||
// Total number of samples in the sweep.
|
||||
size_t m_totalNumberOfSamples;
|
||||
};
|
||||
|
|
|
@ -98,7 +98,6 @@ struct Minibatch
|
|||
|
||||
//////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Main Reader interface. The border interface between the CNTK and reader libraries.
|
||||
// TODO: Expect to change in a little bit: stream matrices provided by the network as input.
|
||||
//////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
class Reader
|
||||
{
|
||||
|
@ -109,6 +108,12 @@ public:
|
|||
// Starts a new epoch with the provided configuration
|
||||
virtual void StartEpoch(const EpochConfiguration& config, const std::map<std::wstring, int>& inputDescriptions) = 0;
|
||||
|
||||
// Returns current position in the global timeline. The returned value is in samples.
|
||||
// TODO: Currently in case of sequence to sequence training,
|
||||
// TODO: the logical sequence size in samples = max(constitutuing sequences among all streams)
|
||||
// TODO: This will change in the future.
|
||||
virtual size_t GetCurrentSamplePosition() = 0;
|
||||
|
||||
// Reads a minibatch that contains data across all streams.
|
||||
virtual Minibatch ReadMinibatch() = 0;
|
||||
|
||||
|
|
|
@ -63,4 +63,9 @@ Minibatch ReaderBase::ReadMinibatch()
|
|||
return m_packer->ReadMinibatch();
|
||||
}
|
||||
|
||||
size_t ReaderBase::GetCurrentSamplePosition()
|
||||
{
|
||||
return m_sequenceEnumerator->GetCurrentSamplePosition();
|
||||
}
|
||||
|
||||
}}}
|
||||
|
|
|
@ -27,6 +27,9 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
// Reads a single minibatch.
|
||||
Minibatch ReadMinibatch() override;
|
||||
|
||||
// Returns current position in the global timeline. The returned value is in samples.
|
||||
size_t GetCurrentSamplePosition() override;
|
||||
|
||||
virtual ~ReaderBase() = 0;
|
||||
|
||||
protected:
|
||||
|
|
|
@ -120,6 +120,7 @@ void ReaderShim<ElemType>::StartDistributedMinibatchLoop(
|
|||
|
||||
m_endOfEpoch = false;
|
||||
m_reader->StartEpoch(config, inputDescriptions);
|
||||
m_currentSamplePosition = m_reader->GetCurrentSamplePosition();
|
||||
|
||||
auto localCurrentDataTransferIndex = m_currentDataTransferIndex;
|
||||
// Starting the prefetch task. There is always a single async read in flight.
|
||||
|
@ -184,6 +185,10 @@ bool ReaderShim<ElemType>::GetMinibatch(StreamMinibatchInputs& matrices)
|
|||
auto result = m_prefetchTask.get();
|
||||
|
||||
// Ok, prefetch is done.
|
||||
|
||||
// Let's update our sample position.
|
||||
m_currentSamplePosition = m_reader->GetCurrentSamplePosition();
|
||||
|
||||
m_endOfEpoch = result.m_isEndOfEpoch;
|
||||
if (m_endOfEpoch && !result.m_isDataAvailable)
|
||||
{
|
||||
|
@ -339,6 +344,12 @@ size_t ReaderShim<ElemType>::GetNumParallelSequencesForFixingBPTTMode()
|
|||
return m_numParallelSequences;
|
||||
}
|
||||
|
||||
template <class ElemType>
|
||||
size_t ReaderShim<ElemType>::GetCurrentSamplePosition()
|
||||
{
|
||||
return m_currentSamplePosition;
|
||||
}
|
||||
|
||||
template class ReaderShim<float>;
|
||||
template class ReaderShim<double>;
|
||||
} } }
|
||||
|
|
|
@ -82,6 +82,8 @@ public:
|
|||
|
||||
virtual size_t GetNumParallelSequencesForFixingBPTTMode() override;
|
||||
|
||||
virtual size_t GetCurrentSamplePosition() override;
|
||||
|
||||
private:
|
||||
struct PrefetchResult
|
||||
{
|
||||
|
@ -126,6 +128,11 @@ private:
|
|||
// Device id.
|
||||
int m_deviceId;
|
||||
|
||||
// Current sample position of the reader on the global timeline.
|
||||
// We have to remember the value locally before starting prefetch.
|
||||
// The value is updated only from the main thread (in StartEpoch/GetMinibatch)
|
||||
size_t m_currentSamplePosition;
|
||||
|
||||
static void FillMatrixFromStream(
|
||||
StorageType type,
|
||||
Matrix<ElemType>* matrix,
|
||||
|
|
|
@ -44,6 +44,9 @@ public:
|
|||
// Gets next sequences up to a maximum count of samples.
|
||||
virtual Sequences GetNextSequences(size_t sampleCount) = 0;
|
||||
|
||||
// Returns current position in the global timeline. The returned value is in samples.
|
||||
virtual size_t GetCurrentSamplePosition() = 0;
|
||||
|
||||
virtual ~SequenceEnumerator()
|
||||
{
|
||||
}
|
||||
|
|
|
@ -41,6 +41,12 @@ public:
|
|||
m_outputStreams = transformedStreams;
|
||||
}
|
||||
|
||||
// Returns current position in the global timeline. The returned value is in samples.
|
||||
size_t GetCurrentSamplePosition() override
|
||||
{
|
||||
return m_sequenceProvider->GetCurrentSamplePosition();
|
||||
}
|
||||
|
||||
// Sets configuration for the current epoch.
|
||||
// Some transformers can change their config based on the epoch.
|
||||
virtual void StartEpoch(const EpochConfiguration &config) override
|
||||
|
|
|
@ -94,5 +94,12 @@ public:
|
|||
RuntimeError("GetData not supported in SparsePCReader");
|
||||
};
|
||||
virtual bool DataEnd();
|
||||
|
||||
size_t GetCurrentSamplePosition() override
|
||||
{
|
||||
// We do not support adaptive minibatch for this reader,
|
||||
// CTF should be used instead.
|
||||
NOT_IMPLEMENTED;
|
||||
}
|
||||
};
|
||||
} } }
|
||||
|
|
|
@ -131,6 +131,7 @@ public:
|
|||
{
|
||||
InitFromConfig(config);
|
||||
}
|
||||
|
||||
virtual void Destroy();
|
||||
UCIFastReader()
|
||||
{
|
||||
|
@ -181,5 +182,10 @@ public:
|
|||
{
|
||||
NOT_IMPLEMENTED;
|
||||
}
|
||||
|
||||
size_t GetCurrentSamplePosition() override
|
||||
{
|
||||
return m_mbStartSample;
|
||||
}
|
||||
};
|
||||
} } }
|
||||
|
|
|
@ -163,6 +163,47 @@ void BlockRandomizerInstantiateTest(bool prefetch)
|
|||
auto randomizer = make_shared<BlockRandomizer>(0, SIZE_MAX, mockDeserializer, prefetch, BlockRandomizer::DecimationMode::chunk, false);
|
||||
}
|
||||
|
||||
BOOST_AUTO_TEST_CASE(CheckCurrentCursorForRandomizers)
|
||||
{
|
||||
size_t chunkSizeInSamples = 10000;
|
||||
size_t sweepNumberOfSamples = 500000;
|
||||
uint32_t maxSequenceLength = 300;
|
||||
size_t randomizationWindow = chunkSizeInSamples * 5;
|
||||
auto deserializer = make_shared<SequentialDeserializer>(0, chunkSizeInSamples, sweepNumberOfSamples, maxSequenceLength);
|
||||
|
||||
auto blockRandomizer = make_shared<BlockRandomizer>(0, randomizationWindow, deserializer, true, BlockRandomizer::DecimationMode::chunk, false);
|
||||
auto noRandomizer = make_shared<NoRandomizer>(deserializer, false);
|
||||
|
||||
auto test = [](SequenceEnumeratorPtr r, size_t epochSize)
|
||||
{
|
||||
auto firstEpoch = ReadFullEpoch(r, epochSize, 0);
|
||||
auto firstCursor = r->GetCurrentSamplePosition();
|
||||
BOOST_CHECK_EQUAL(firstCursor, firstEpoch.size());
|
||||
|
||||
auto secondEpoch = ReadFullEpoch(r, epochSize, 1);
|
||||
auto secondCursor = r->GetCurrentSamplePosition();
|
||||
BOOST_CHECK_EQUAL(secondCursor - firstCursor, secondEpoch.size());
|
||||
|
||||
auto thirdEpoch = ReadFullEpoch(r, epochSize, 2);
|
||||
auto thirdCursor = r->GetCurrentSamplePosition();
|
||||
BOOST_CHECK_EQUAL(thirdCursor - secondCursor, thirdEpoch.size());
|
||||
|
||||
auto anotherSecondEpoch = ReadFullEpoch(r, epochSize, 1);
|
||||
auto anotherSecondCursor = r->GetCurrentSamplePosition();
|
||||
|
||||
BOOST_CHECK_EQUAL(anotherSecondCursor, secondCursor);
|
||||
};
|
||||
|
||||
// Inside sweep
|
||||
size_t epochSize = 50000;
|
||||
test(blockRandomizer, epochSize);
|
||||
test(noRandomizer, epochSize);
|
||||
|
||||
// Between sweeps
|
||||
epochSize = (size_t)(sweepNumberOfSamples / 1.5);
|
||||
test(blockRandomizer, epochSize);
|
||||
test(noRandomizer, epochSize);
|
||||
}
|
||||
|
||||
BOOST_AUTO_TEST_CASE(RandRollbackToEarlierEpochBetweenSweeps)
|
||||
{
|
||||
|
|
Загрузка…
Ссылка в новой задаче