BlockRandomizer::GetNextSequenceIds(): fix to return the right chunks

This commit is contained in:
Mark Hillebrand 2016-03-02 13:22:32 +01:00
Родитель 6061180a93
Коммит ea8ba640ac
2 изменённых файлов: 34 добавлений и 1 удалений

Просмотреть файл

@ -424,7 +424,7 @@ bool BlockRandomizer::GetNextSequenceIds(size_t sampleCount, std::vector<size_t>
const auto& seqDesc = m_randomTimeline[m_sequencePositionInSweep]; const auto& seqDesc = m_randomTimeline[m_sequencePositionInSweep];
originalIds.push_back(seqDesc.m_id); originalIds.push_back(seqDesc.m_id);
const auto & currentChunk = m_randomizedChunks[GetChunkIndexForSequencePosition(seqDesc.m_id)]; const auto & currentChunk = m_randomizedChunks[GetChunkIndexForSequencePosition(m_sequencePositionInSweep)];
const size_t windowBegin = currentChunk.m_windowBegin; const size_t windowBegin = currentChunk.m_windowBegin;
const size_t windowEnd = currentChunk.m_windowEnd; const size_t windowEnd = currentChunk.m_windowEnd;

Просмотреть файл

@ -173,6 +173,39 @@ BOOST_AUTO_TEST_CASE(BlockRandomizerOneEpoch)
actual.begin(), actual.end()); actual.begin(), actual.end());
} }
BOOST_AUTO_TEST_CASE(BlockRandomizerOneEpochSmallWindow)
{
std::vector<float> data { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0 };
auto mockDeserializer = std::make_shared<MockDeserializer>(5, 2, data);
auto randomizer = std::make_shared<BlockRandomizer>(0, 10, mockDeserializer);
EpochConfiguration epochConfiguration;
epochConfiguration.m_numberOfWorkers = 1;
epochConfiguration.m_workerRank = 0;
epochConfiguration.m_minibatchSizeInSamples = 0;
epochConfiguration.m_totalEpochSizeInSamples = 10;
epochConfiguration.m_epochIndex = 0;
randomizer->StartEpoch(epochConfiguration);
std::vector<float> expected { 9.0, 8.0, 3.0, 6.0, 2.0, 1.0, 4.0, 7.0, 5.0, 0.0 };
std::vector<float> actual;
for (int i = 0; i < 11; i++)
{
Sequences sequences = randomizer->GetNextSequences(1);
BOOST_CHECK_EQUAL(sequences.m_data.size(), 1 - (i / 10));
if (i < 10)
{
auto data = reinterpret_cast<DenseSequenceData&>(*sequences.m_data[0][0]);
BOOST_CHECK_EQUAL(data.m_numberOfSamples, 1);
actual.push_back(*((float*)data.m_data));
}
BOOST_CHECK_EQUAL(sequences.m_endOfEpoch, (9 <= i));
}
BOOST_CHECK_EQUAL_COLLECTIONS(expected.begin(), expected.end(),
actual.begin(), actual.end());
}
BOOST_AUTO_TEST_CASE(BlockRandomizerOneEpochLegacyRandomization) BOOST_AUTO_TEST_CASE(BlockRandomizerOneEpochLegacyRandomization)
{ {
std::vector<float> data { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0 }; std::vector<float> data { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0 };