Fix the reader library bug that occurs when the number of samples in a chunk is less than the number of workers

This commit is contained in:
Jaliya Ekanayake 2018-04-07 23:18:46 -07:00
Родитель d4781c16ed
Коммит 82c8c6972b
4 изменённых файлов: 91 добавлений и 15 удалений

Просмотреть файл

@ -51,6 +51,21 @@ void LocalTimelineRandomizerBase::StartEpoch(const EpochConfiguration& config)
Refill();
}
// Refills m_window synchronously: records the current inner state, waits for a
// prefetch to finish (starting one first if none is pending) and then passes
// the window to RefillSequenceWindow. Unlike Refill(), this function does not
// schedule the next asynchronous prefetch.
void LocalTimelineRandomizerBase::RefillCurrentWindowNow()
{
    // Capture the inner state before the window contents are replaced.
    m_currentState = GetInnerState();

    // There has to be a prefetch to wait on; launch one if nothing is pending.
    const bool prefetchPending = m_prefetch.valid();
    if (!prefetchPending)
        m_prefetch = std::async(std::launch::async, [this]() { Prefetch(); });

    // Block until the prefetch has completed, then refill the window.
    m_prefetch.get();
    RefillSequenceWindow(m_window);
}
void LocalTimelineRandomizerBase::Refill()
{
// Fill the expandable window.
@ -68,16 +83,7 @@ void LocalTimelineRandomizerBase::Refill()
// - current state of the base class
// - state of the inherited class before the current window is asked
// - position in current window
m_currentState = GetInnerState();
// Make sure there is no outstanding prefetch.
if (!m_prefetch.valid())
m_prefetch = std::async(std::launch::async, [this]() { Prefetch(); });
m_prefetch.get();
RefillSequenceWindow(m_window);
RefillCurrentWindowNow();
// Issue the next prefetch
m_prefetch = std::async(std::launch::async, [this]() { Prefetch(); });
@ -108,10 +114,16 @@ void LocalTimelineRandomizerBase::GetNextSequenceDescriptions(size_t maxSampleCo
if (maxSampleCount > std::numeric_limits<int>::max())
RuntimeError("The size of a minibatch cannot exceed max int.");
// The underlying randomizer should always fill data,
// in case it cannot we report the error.
if (m_window.m_sequences.empty())
RuntimeError("Could not read any data.");
// This randomizer operates on the local timeline, so a chunk may contain no data
// for some of the workers. In that case, we return an empty set of sequences.
if (m_window.m_sequences.empty())
{
m_sequenceBuffer.clear();
m_chunkBuffer.clear();
// Set the end-of-epoch flag (true when the current batch is last in an epoch).
result.m_endOfEpoch = IsEndReached();
return;
}
size_t samplesLoaded = 0;
bool atLeastOneSequenceNeeded = true;
@ -190,7 +202,13 @@ Sequences LocalTimelineRandomizerBase::GetNextSequences(size_t /*ignoring global
}
if (m_sequenceBuffer.size() == 0) // No data
{
// Refill one more chunk, but does not issue the next async prefetch.
// If the next chunk has more sequences, then the regular Refill
// will be called inside GetNextSequenceDescriptions method.
RefillCurrentWindowNow();
return result;
}
// Lets actually fetch data.
result.m_data.resize(GetStreamDescriptions().size(), std::vector<SequenceDataPtr>(m_sequenceBuffer.size()));

Просмотреть файл

@ -113,6 +113,9 @@ private:
// Refills the current window of sequences.
void Refill();
// Refills the current window synchronously: saves the inner state, waits for the
// prefetch to complete (launching one first if none is pending) and then refills
// the sequence window. Does not issue the next async refill.
void RefillCurrentWindowNow();
// Gets next sequences not exceeding localSampleCount for this worker and globalSampleCount across workers.
void GetNextSequenceDescriptions(size_t maxSampleCount, Sequences& result);

Просмотреть файл

@ -3814,6 +3814,9 @@ Test module "ReaderTests" has passed with:
Test case "ReaderLibTests/CheckGetCurrentCursorForRandomizers" has passed with:
16 assertions out of 16 passed
Test case "ReaderLibTests/LTNoRandomizerMultiWorker" has passed with:
8 assertions out of 8 passed
Test case "ReaderLibTests/LTNoRandomizerCheckNoDuplicateSequence" has passed with:
42 assertions out of 42 passed
@ -3927,4 +3930,4 @@ Test module "ReaderTests" has passed with:
Test case "TextInputIndexBuilderTests/Index_64MB_with_caching_check_perf" has passed
Test case "TextInputIndexBuilderTests/Index_1GB_with_caching_check_perf" has passed
Test case "TextInputIndexBuilderTests/Index_1GB_with_caching_check_perf" has passed

Просмотреть файл

@ -988,6 +988,58 @@ BOOST_AUTO_TEST_CASE(CheckGetCurrentCursorForRandomizers)
test(noRandomizer, epochSize);
}
// Checks LTNoRandomizer when a chunk holds fewer sequences than there are workers.
// Two chunks of three sequences each (global indices [0,5]) are read by four workers;
// a sequence belongs to the worker whose rank equals the sequence index mod 4, so
// some workers find nothing for them in one of the chunks and must get an empty
// result for that read instead of an error.
BOOST_AUTO_TEST_CASE(LTNoRandomizerMultiWorker)
{
    auto num_chunks = 2;
    auto num_sequences_per_chunk = 3;
    size_t num_workers = 4;

    vector<float> data(num_chunks * num_sequences_per_chunk);
    iota(data.begin(), data.end(), 0.0f);

    // The rank is a size_t so that it is compared against num_workers without
    // mixing signed and unsigned operands.
    for (size_t w = 0; w < num_workers; ++w)
    {
        auto mockDeserializer = make_shared<MockDeserializer>(num_chunks, num_sequences_per_chunk, data);
        auto randomizer = make_shared<LTNoRandomizer>(mockDeserializer);

        EpochConfiguration epochConfiguration;
        epochConfiguration.m_numberOfWorkers = num_workers;
        epochConfiguration.m_workerRank = w;
        epochConfiguration.m_minibatchSizeInSamples = 2;
        epochConfiguration.m_totalEpochSizeInSamples = data.size();
        epochConfiguration.m_epochIndex = 0;
        randomizer->StartEpoch(epochConfiguration);

        // Expected size of m_data after the first and the second read.
        // Workers 0 and 1 get one sequence from each chunk.
        size_t expectedFirst = 1;
        size_t expectedSecond = 1;
        if (w == 2)
        {
            // Worker 2 gets one sequence from the first chunk but none from the
            // second: of the indices [0,5] taken mod 4, only index 2 matches.
            expectedSecond = 0;
        }
        else if (w > 2)
        {
            // Worker 3 (the 4th worker) gets nothing from the first chunk,
            // but gets one sequence from the second chunk.
            expectedFirst = 0;
        }

        Sequences sequences = randomizer->GetNextSequences(1, 1);
        BOOST_CHECK_EQUAL(sequences.m_data.size(), expectedFirst);

        sequences = randomizer->GetNextSequences(1, 1);
        BOOST_CHECK_EQUAL(sequences.m_data.size(), expectedSecond);
    }
}
// Check that each worker reads unique sequences. A bug was causing duplicate sequences in workers.
BOOST_AUTO_TEST_CASE(LTNoRandomizerCheckNoDuplicateSequence)
{