BlockRandomizer::GetNextSequenceIds(): fix to return the right chunks
This commit is contained in:
Родитель
6061180a93
Коммит
ea8ba640ac
|
@ -424,7 +424,7 @@ bool BlockRandomizer::GetNextSequenceIds(size_t sampleCount, std::vector<size_t>
|
|||
const auto& seqDesc = m_randomTimeline[m_sequencePositionInSweep];
|
||||
originalIds.push_back(seqDesc.m_id);
|
||||
|
||||
const auto & currentChunk = m_randomizedChunks[GetChunkIndexForSequencePosition(seqDesc.m_id)];
|
||||
const auto & currentChunk = m_randomizedChunks[GetChunkIndexForSequencePosition(m_sequencePositionInSweep)];
|
||||
const size_t windowBegin = currentChunk.m_windowBegin;
|
||||
const size_t windowEnd = currentChunk.m_windowEnd;
|
||||
|
||||
|
|
|
@ -173,6 +173,39 @@ BOOST_AUTO_TEST_CASE(BlockRandomizerOneEpoch)
|
|||
actual.begin(), actual.end());
|
||||
}
|
||||
|
||||
BOOST_AUTO_TEST_CASE(BlockRandomizerOneEpochSmallWindow)
|
||||
{
|
||||
std::vector<float> data { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0 };
|
||||
auto mockDeserializer = std::make_shared<MockDeserializer>(5, 2, data);
|
||||
|
||||
auto randomizer = std::make_shared<BlockRandomizer>(0, 10, mockDeserializer);
|
||||
|
||||
EpochConfiguration epochConfiguration;
|
||||
epochConfiguration.m_numberOfWorkers = 1;
|
||||
epochConfiguration.m_workerRank = 0;
|
||||
epochConfiguration.m_minibatchSizeInSamples = 0;
|
||||
epochConfiguration.m_totalEpochSizeInSamples = 10;
|
||||
epochConfiguration.m_epochIndex = 0;
|
||||
randomizer->StartEpoch(epochConfiguration);
|
||||
|
||||
std::vector<float> expected { 9.0, 8.0, 3.0, 6.0, 2.0, 1.0, 4.0, 7.0, 5.0, 0.0 };
|
||||
std::vector<float> actual;
|
||||
for (int i = 0; i < 11; i++)
|
||||
{
|
||||
Sequences sequences = randomizer->GetNextSequences(1);
|
||||
BOOST_CHECK_EQUAL(sequences.m_data.size(), 1 - (i / 10));
|
||||
if (i < 10)
|
||||
{
|
||||
auto data = reinterpret_cast<DenseSequenceData&>(*sequences.m_data[0][0]);
|
||||
BOOST_CHECK_EQUAL(data.m_numberOfSamples, 1);
|
||||
actual.push_back(*((float*)data.m_data));
|
||||
}
|
||||
BOOST_CHECK_EQUAL(sequences.m_endOfEpoch, (9 <= i));
|
||||
}
|
||||
BOOST_CHECK_EQUAL_COLLECTIONS(expected.begin(), expected.end(),
|
||||
actual.begin(), actual.end());
|
||||
}
|
||||
|
||||
BOOST_AUTO_TEST_CASE(BlockRandomizerOneEpochLegacyRandomization)
|
||||
{
|
||||
std::vector<float> data { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0 };
|
||||
|
|
Загрузка…
Ссылка в новой задаче