From 82c8c6972b2a5f124d2a217189ffa536a833751c Mon Sep 17 00:00:00 2001 From: Jaliya Ekanayake Date: Sat, 7 Apr 2018 23:18:46 -0700 Subject: [PATCH] Fixing the reader lib bug when number of samples in a chunk is less than the number of workers --- .../ReaderLib/LocalTimelineRandomizerBase.cpp | 46 +++++++++++----- .../ReaderLib/LocalTimelineRandomizerBase.h | 3 ++ .../UnitTests/ReaderTests/baseline.txt | 5 +- .../UnitTests/ReaderTests/ReaderLibTests.cpp | 52 +++++++++++++++++++ 4 files changed, 91 insertions(+), 15 deletions(-) diff --git a/Source/Readers/ReaderLib/LocalTimelineRandomizerBase.cpp b/Source/Readers/ReaderLib/LocalTimelineRandomizerBase.cpp index bdff65bea..92b8ea321 100644 --- a/Source/Readers/ReaderLib/LocalTimelineRandomizerBase.cpp +++ b/Source/Readers/ReaderLib/LocalTimelineRandomizerBase.cpp @@ -51,6 +51,21 @@ void LocalTimelineRandomizerBase::StartEpoch(const EpochConfiguration& config) Refill(); } +void LocalTimelineRandomizerBase::RefillCurrentWindowNow() +{ + m_currentState = GetInnerState(); + + // Make sure there is no outstanding prefetch. + if (!m_prefetch.valid()) + { + m_prefetch = std::async(std::launch::async, [this]() { Prefetch(); }); + } + + m_prefetch.get(); + + RefillSequenceWindow(m_window); +} + void LocalTimelineRandomizerBase::Refill() { // Fill the expandable window. @@ -68,16 +83,7 @@ void LocalTimelineRandomizerBase::Refill() // - current state of the base class // - state of the inherited class before the current window is asked // - position in current window - - m_currentState = GetInnerState(); - - // Make sure there is no outstanding prefetch. - if (!m_prefetch.valid()) - m_prefetch = std::async(std::launch::async, [this]() { Prefetch(); }); - - m_prefetch.get(); - - RefillSequenceWindow(m_window); + RefillCurrentWindowNow(); // Issue the next prefetch m_prefetch = std::async(std::launch::async, [this]() { Prefetch(); }); @@ -108,10 +114,16 @@ void LocalTimelineRandomizerBase::GetNextSequenceDescriptions(size_t maxSampleCo if (maxSampleCount > std::numeric_limits::max()) RuntimeError("The size of a minibatch cannot exceed max int."); - // The underlying randomizer should always fill data, - // in case it cannot we report the error. - if (m_window.m_sequences.empty()) - RuntimeError("Could not read any data."); + // This randomizer operates on the local time-line. So there could be chunks with no data + // for all workers. In that case, we return an empty sequences. + if (m_window.m_sequences.empty()) + { + m_sequenceBuffer.clear(); + m_chunkBuffer.clear(); + // Set the end-of-epoch flag (true when the current batch is last in an epoch). + result.m_endOfEpoch = IsEndReached(); + return; + } size_t samplesLoaded = 0; bool atLeastOneSequenceNeeded = true; @@ -190,7 +202,13 @@ Sequences LocalTimelineRandomizerBase::GetNextSequences(size_t /*ignoring global } if (m_sequenceBuffer.size() == 0) // No data + { + // Refill one more chunk, but does not issue the next async prefetch. + // If the next chunk has more sequences, then the regular Refill + // will be called inside GetNextSequenceDescriptions method. + RefillCurrentWindowNow(); return result; + } // Lets actually fetch data. result.m_data.resize(GetStreamDescriptions().size(), std::vector(m_sequenceBuffer.size())); diff --git a/Source/Readers/ReaderLib/LocalTimelineRandomizerBase.h b/Source/Readers/ReaderLib/LocalTimelineRandomizerBase.h index 35429e580..d8d8cd58f 100644 --- a/Source/Readers/ReaderLib/LocalTimelineRandomizerBase.h +++ b/Source/Readers/ReaderLib/LocalTimelineRandomizerBase.h @@ -113,6 +113,9 @@ private: // Refills the current window of sequences. void Refill(); + // Refill and wait for data. Does not issue the next async refill. + void RefillCurrentWindowNow(); + // Gets next sequences not exceeding localSampleCount for this worker and globalSampleCount across workers. void GetNextSequenceDescriptions(size_t maxSampleCount, Sequences& result); diff --git a/Tests/EndToEndTests/UnitTests/ReaderTests/baseline.txt b/Tests/EndToEndTests/UnitTests/ReaderTests/baseline.txt index 12f833d93..d036be81b 100644 --- a/Tests/EndToEndTests/UnitTests/ReaderTests/baseline.txt +++ b/Tests/EndToEndTests/UnitTests/ReaderTests/baseline.txt @@ -3814,6 +3814,9 @@ Test module "ReaderTests" has passed with: Test case "ReaderLibTests/CheckGetCurrentCursorForRandomizers" has passed with: 16 assertions out of 16 passed + Test case "ReaderLibTests/LTNoRandomizerMultiWorker" has passed with: + 8 assertions out of 8 passed + Test case "ReaderLibTests/LTNoRandomizerCheckNoDuplicateSequence" has passed with: 42 assertions out of 42 passed @@ -3927,4 +3930,4 @@ Test module "ReaderTests" has passed with: Test case "TextInputIndexBuilderTests/Index_64MB_with_caching_check_perf" has passed - Test case "TextInputIndexBuilderTests/Index_1GB_with_caching_check_perf" has passed \ No newline at end of file + Test case "TextInputIndexBuilderTests/Index_1GB_with_caching_check_perf" has passed diff --git a/Tests/UnitTests/ReaderTests/ReaderLibTests.cpp b/Tests/UnitTests/ReaderTests/ReaderLibTests.cpp index f6ac5bc84..13f75aeb1 100644 --- a/Tests/UnitTests/ReaderTests/ReaderLibTests.cpp +++ b/Tests/UnitTests/ReaderTests/ReaderLibTests.cpp @@ -988,6 +988,58 @@ BOOST_AUTO_TEST_CASE(CheckGetCurrentCursorForRandomizers) test(noRandomizer, epochSize); } +BOOST_AUTO_TEST_CASE(LTNoRandomizerMultiWorker) +{ + auto num_chunks = 2; + auto num_sequences_per_chunk = 3; + size_t num_workers = 4; + + vector data(num_chunks * num_sequences_per_chunk); + iota(data.begin(), data.end(), 0.0f); + + for (int w = 0; w < num_workers; ++w) + { + auto mockDeserializer = make_shared(num_chunks, num_sequences_per_chunk, data); + auto randomizer = make_shared(mockDeserializer); + + EpochConfiguration epochConfiguration; + epochConfiguration.m_numberOfWorkers = num_workers; + epochConfiguration.m_workerRank = w; + epochConfiguration.m_minibatchSizeInSamples = 2; + epochConfiguration.m_totalEpochSizeInSamples = data.size(); + epochConfiguration.m_epochIndex = 0; + randomizer->StartEpoch(epochConfiguration); + + if (w < 2) + { + // Worker 0 and 1 will get two sequences. + Sequences sequences = randomizer->GetNextSequences(1, 1); + BOOST_CHECK_EQUAL(sequences.m_data.size(), 1); + sequences = randomizer->GetNextSequences(1, 1); + BOOST_CHECK_EQUAL(sequences.m_data.size(), 1); + } + else if (w == 2) + { + // Worker 2 will get only one sequence from the first + // chunk, but not from the second chunk. There are 6 sequences + // with indices [0,5], we take mod 4, which matches only for 2. + Sequences sequences = randomizer->GetNextSequences(1, 1); + BOOST_CHECK_EQUAL(sequences.m_data.size(), 1); + sequences = randomizer->GetNextSequences(1, 1); + BOOST_CHECK(sequences.m_data.empty()); + } + else + { + // Worker 3 (4th worker) will not get any sequence from the + // first chunk, but gets one from the second chunk. + Sequences sequences = randomizer->GetNextSequences(1, 1); + BOOST_CHECK(sequences.m_data.empty()); + sequences = randomizer->GetNextSequences(1, 1); + BOOST_CHECK_EQUAL(sequences.m_data.size(), 1); + } + } +} + // Check that each worker reads unique sequences. A bug was causing duplicate sequences in workers. BOOST_AUTO_TEST_CASE(LTNoRandomizerCheckNoDuplicateSequence) {