Fixing the reader lib bug when number of samples in a chunk is less than the number of workers
This commit is contained in:
Родитель
d4781c16ed
Коммит
82c8c6972b
|
@ -51,6 +51,21 @@ void LocalTimelineRandomizerBase::StartEpoch(const EpochConfiguration& config)
|
|||
Refill();
|
||||
}
|
||||
|
||||
void LocalTimelineRandomizerBase::RefillCurrentWindowNow()
|
||||
{
|
||||
m_currentState = GetInnerState();
|
||||
|
||||
// Make sure there is no outstanding prefetch.
|
||||
if (!m_prefetch.valid())
|
||||
{
|
||||
m_prefetch = std::async(std::launch::async, [this]() { Prefetch(); });
|
||||
}
|
||||
|
||||
m_prefetch.get();
|
||||
|
||||
RefillSequenceWindow(m_window);
|
||||
}
|
||||
|
||||
void LocalTimelineRandomizerBase::Refill()
|
||||
{
|
||||
// Fill the expandable window.
|
||||
|
@ -68,16 +83,7 @@ void LocalTimelineRandomizerBase::Refill()
|
|||
// - current state of the base class
|
||||
// - state of the inherited class before the current window is asked
|
||||
// - position in current window
|
||||
|
||||
m_currentState = GetInnerState();
|
||||
|
||||
// Make sure there is no outstanding prefetch.
|
||||
if (!m_prefetch.valid())
|
||||
m_prefetch = std::async(std::launch::async, [this]() { Prefetch(); });
|
||||
|
||||
m_prefetch.get();
|
||||
|
||||
RefillSequenceWindow(m_window);
|
||||
RefillCurrentWindowNow();
|
||||
|
||||
// Issue the next prefetch
|
||||
m_prefetch = std::async(std::launch::async, [this]() { Prefetch(); });
|
||||
|
@ -108,10 +114,16 @@ void LocalTimelineRandomizerBase::GetNextSequenceDescriptions(size_t maxSampleCo
|
|||
if (maxSampleCount > std::numeric_limits<int>::max())
|
||||
RuntimeError("The size of a minibatch cannot exceed max int.");
|
||||
|
||||
// The underlying randomizer should always fill data,
|
||||
// in case it cannot we report the error.
|
||||
if (m_window.m_sequences.empty())
|
||||
RuntimeError("Could not read any data.");
|
||||
// This randomizer operates on the local time-line. So there could be chunks with no data
|
||||
// for all workers. In that case, we return an empty sequences.
|
||||
if (m_window.m_sequences.empty())
|
||||
{
|
||||
m_sequenceBuffer.clear();
|
||||
m_chunkBuffer.clear();
|
||||
// Set the end-of-epoch flag (true when the current batch is last in an epoch).
|
||||
result.m_endOfEpoch = IsEndReached();
|
||||
return;
|
||||
}
|
||||
|
||||
size_t samplesLoaded = 0;
|
||||
bool atLeastOneSequenceNeeded = true;
|
||||
|
@ -190,7 +202,13 @@ Sequences LocalTimelineRandomizerBase::GetNextSequences(size_t /*ignoring global
|
|||
}
|
||||
|
||||
if (m_sequenceBuffer.size() == 0) // No data
|
||||
{
|
||||
// Refill one more chunk, but does not issue the next async prefetch.
|
||||
// If the next chunk has more sequences, then the regular Refill
|
||||
// will be called inside GetNextSequenceDescriptions method.
|
||||
RefillCurrentWindowNow();
|
||||
return result;
|
||||
}
|
||||
|
||||
// Lets actually fetch data.
|
||||
result.m_data.resize(GetStreamDescriptions().size(), std::vector<SequenceDataPtr>(m_sequenceBuffer.size()));
|
||||
|
|
|
@ -113,6 +113,9 @@ private:
|
|||
// Refills the current window of sequences.
|
||||
void Refill();
|
||||
|
||||
// Refill and wait for data. Does not issue the next async refill.
|
||||
void RefillCurrentWindowNow();
|
||||
|
||||
// Gets next sequences not exceeding localSampleCount for this worker and globalSampleCount across workers.
|
||||
void GetNextSequenceDescriptions(size_t maxSampleCount, Sequences& result);
|
||||
|
||||
|
|
|
@ -3814,6 +3814,9 @@ Test module "ReaderTests" has passed with:
|
|||
Test case "ReaderLibTests/CheckGetCurrentCursorForRandomizers" has passed with:
|
||||
16 assertions out of 16 passed
|
||||
|
||||
Test case "ReaderLibTests/LTNoRandomizerMultiWorker" has passed with:
|
||||
8 assertions out of 8 passed
|
||||
|
||||
Test case "ReaderLibTests/LTNoRandomizerCheckNoDuplicateSequence" has passed with:
|
||||
42 assertions out of 42 passed
|
||||
|
||||
|
@ -3927,4 +3930,4 @@ Test module "ReaderTests" has passed with:
|
|||
|
||||
Test case "TextInputIndexBuilderTests/Index_64MB_with_caching_check_perf" has passed
|
||||
|
||||
Test case "TextInputIndexBuilderTests/Index_1GB_with_caching_check_perf" has passed
|
||||
Test case "TextInputIndexBuilderTests/Index_1GB_with_caching_check_perf" has passed
|
||||
|
|
|
@ -988,6 +988,58 @@ BOOST_AUTO_TEST_CASE(CheckGetCurrentCursorForRandomizers)
|
|||
test(noRandomizer, epochSize);
|
||||
}
|
||||
|
||||
BOOST_AUTO_TEST_CASE(LTNoRandomizerMultiWorker)
|
||||
{
|
||||
auto num_chunks = 2;
|
||||
auto num_sequences_per_chunk = 3;
|
||||
size_t num_workers = 4;
|
||||
|
||||
vector<float> data(num_chunks * num_sequences_per_chunk);
|
||||
iota(data.begin(), data.end(), 0.0f);
|
||||
|
||||
for (int w = 0; w < num_workers; ++w)
|
||||
{
|
||||
auto mockDeserializer = make_shared<MockDeserializer>(num_chunks, num_sequences_per_chunk, data);
|
||||
auto randomizer = make_shared<LTNoRandomizer>(mockDeserializer);
|
||||
|
||||
EpochConfiguration epochConfiguration;
|
||||
epochConfiguration.m_numberOfWorkers = num_workers;
|
||||
epochConfiguration.m_workerRank = w;
|
||||
epochConfiguration.m_minibatchSizeInSamples = 2;
|
||||
epochConfiguration.m_totalEpochSizeInSamples = data.size();
|
||||
epochConfiguration.m_epochIndex = 0;
|
||||
randomizer->StartEpoch(epochConfiguration);
|
||||
|
||||
if (w < 2)
|
||||
{
|
||||
// Worker 0 and 1 will get two sequences.
|
||||
Sequences sequences = randomizer->GetNextSequences(1, 1);
|
||||
BOOST_CHECK_EQUAL(sequences.m_data.size(), 1);
|
||||
sequences = randomizer->GetNextSequences(1, 1);
|
||||
BOOST_CHECK_EQUAL(sequences.m_data.size(), 1);
|
||||
}
|
||||
else if (w == 2)
|
||||
{
|
||||
// Worker 2 will get only one sequence from the first
|
||||
// chunk, but not from the second chunk. There are 6 sequences
|
||||
// with indices [0,5], we take mod 4, which matches only for 2.
|
||||
Sequences sequences = randomizer->GetNextSequences(1, 1);
|
||||
BOOST_CHECK_EQUAL(sequences.m_data.size(), 1);
|
||||
sequences = randomizer->GetNextSequences(1, 1);
|
||||
BOOST_CHECK(sequences.m_data.empty());
|
||||
}
|
||||
else
|
||||
{
|
||||
// Worker 3 (4th worker) will not get any sequence from the
|
||||
// first chunk, but gets one from the second chunk.
|
||||
Sequences sequences = randomizer->GetNextSequences(1, 1);
|
||||
BOOST_CHECK(sequences.m_data.empty());
|
||||
sequences = randomizer->GetNextSequences(1, 1);
|
||||
BOOST_CHECK_EQUAL(sequences.m_data.size(), 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check that each worker reads unique sequences. A bug was causing duplicate sequences in workers.
|
||||
BOOST_AUTO_TEST_CASE(LTNoRandomizerCheckNoDuplicateSequence)
|
||||
{
|
||||
|
|
Загрузка…
Ссылка в новой задаче