From 96646ec443d2e611389ebfc8005b4d3249848354 Mon Sep 17 00:00:00 2001 From: Frank Seide Date: Mon, 8 Oct 2018 13:29:16 -0700 Subject: [PATCH] revisited fillBatches() and optimized it a little; added try-catch to nail down that EAGAIN observed in Philly; temporarily changed the fillBatches() criterion to threshold of 100 (should make little difference); bug fix: Corpus::next() should check for inconsistent end of data across streams; minor fix: MPI rank in log is now padded to same #digits for all ranks, for better readable logs --- .gitattributes | 2 + src/data/batch_generator.h | 90 ++++++++++++++++++------------- src/data/batch_stats.h | 3 +- src/data/corpus.cpp | 29 +++++----- src/data/text_input.cpp | 1 + src/optimizers/optimizers.cpp | 2 + src/tensors/gpu/device.cu | 3 +- src/training/communicator.cpp | 9 +++- src/training/graph_group_sync.cpp | 4 +- 9 files changed, 86 insertions(+), 57 deletions(-) create mode 100644 .gitattributes mode change 100644 => 100755 src/data/text_input.cpp diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 00000000..e5b09931 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +# git should never touch line endings +* -text diff --git a/src/data/batch_generator.h b/src/data/batch_generator.h index d4e95290..9b99015d 100755 --- a/src/data/batch_generator.h +++ b/src/data/batch_generator.h @@ -20,7 +20,7 @@ public: typedef typename DataSet::batch_ptr BatchPtr; typedef typename DataSet::sample sample; - typedef std::vector samples; + typedef std::vector samples; // @TODO: type names should be capitalized protected: Ptr data_; @@ -43,10 +43,11 @@ private: mutable std::condition_variable loadCondition_; bool loadReady_{true}; + // this runs on a bg thread; sequencing is handled by caller, but locking is done in here void fillBatches(bool shuffle = true) { + LOG(info, "fillBatches entered"); typedef typename sample::value_type Item; - auto itemCmp - = [](const Item& sa, const Item& sb) { return sa.size() < sb.size(); }; + auto itemCmp = [](const Item& sa, const Item& sb) { return sa.size() < sb.size(); }; // sort by element length, not content auto cmpSrc = [itemCmp](const sample& a, const sample& b) { return std::lexicographical_compare( @@ -58,12 +59,12 @@ private: a.rbegin(), a.rend(), b.rbegin(), b.rend(), itemCmp); }; - auto cmpNone = [](const sample& a, const sample& b) { return &a < &b; }; + auto cmpNone = [](const sample& a, const sample& b) { return &a < &b; }; // instead sort by address, so we have something to work with typedef std::function cmp_type; typedef std::priority_queue sample_queue; - std::unique_ptr maxiBatch; + std::unique_ptr maxiBatch; // priority queue, shortest first if(options_->has("maxi-batch-sort")) { if(options_->get("maxi-batch-sort") == "src") @@ -89,84 +90,98 @@ private: ++current_; } size_t sets = 0; - while(current_ != data_->end() && maxiBatch->size() < maxSize) { + try { + LOG(info, "begin read lines, current size {}", maxiBatch->size()); + while(current_ != data_->end() && maxiBatch->size() < maxSize) { // loop over data maxiBatch->push(*current_); sets = current_->size(); // do not consume more than required for the maxi batch as this causes // that line-by-line translation is delayed by one sentence bool last = maxiBatch->size() == maxSize; if(!last) - ++current_; + ++current_; // this actually reads the next line and pre-processes it + } + LOG(info, "end read lines, current size {}", maxiBatch->size()); + // @TODO: Consider using MPI at this point to parallelize parsing. + } + catch (const std::exception & e) { + LOG("exception caught while reading: {}", e.what()); + logCallStack(0); + throw; } + // construct the actual batches and place them in the queue samples batchVector; - int currentWords = 0; - std::vector lengths(sets, 0); + size_t currentWords = 0; + std::vector lengths(sets, 0); // records maximum length observed within current batch std::vector tempBatches; + tempBatches.reserve(10000); // (should be enough in most cases; not critical) - // while there are sentences in the queue - while(!maxiBatch->empty()) { + // process all loaded sentences in order of increasing length + // @TODO: we could just use a vector and do a sort() here; would make the cost more explicit + LOG(info, "begin form batches, #batches = {}", maxiBatch->size()); + const size_t mbWords = options_->get("mini-batch-words", 0); + const bool useDynamicBatching = options_->has("mini-batch-fit"); + while(!maxiBatch->empty()) { // while there are sentences in the queue // push item onto batch batchVector.push_back(maxiBatch->top()); - currentWords += (int)batchVector.back()[0].size(); - maxiBatch->pop(); + maxiBatch->pop(); // fetch next-shortest - // Batch size based on sentences - bool makeBatch = batchVector.size() == maxBatchSize; - - // Batch size based on words - if(options_->has("mini-batch-words")) { - int mbWords = options_->get("mini-batch-words"); - if(mbWords > 0) - makeBatch = currentWords > mbWords; - } - - if(options_->has("mini-batch-fit")) { - // Dynamic batching + // have we reached sufficient amount of data to form a batch? + bool makeBatch; + if(useDynamicBatching) { // batch size based on dynamic batching if(stats_) { for(size_t i = 0; i < sets; ++i) if(batchVector.back()[i].size() > lengths[i]) - lengths[i] = batchVector.back()[i].size(); + lengths[i] = batchVector.back()[i].size(); // record max lengths so far - maxBatchSize = stats_->getBatchSize(lengths); + maxBatchSize = stats_->getBatchSize(lengths); // note: to speed this up, we could cache the iterator. We call it with growing sentence length. + makeBatch = batchVector.size() >= maxBatchSize; + // if last added sentence caused a bump then we likely have bad padding, so rather move it into the next batch if(batchVector.size() > maxBatchSize) { maxiBatch->push(batchVector.back()); batchVector.pop_back(); - makeBatch = true; - } else { - makeBatch = batchVector.size() == maxBatchSize; } } } + else if(mbWords > 0) { + currentWords += batchVector.back()[0].size(); // count words based on first stream =source --@TODO: shouldn't we count based on labels? + makeBatch = currentWords > mbWords; // Batch size based on sentences + } + else + makeBatch = batchVector.size() == maxBatchSize; // Batch size based on words - // if batch has desired size create a real batch + // if we reached the desired batch size then create a real batch if(makeBatch) { tempBatches.push_back(data_->toBatch(batchVector)); // prepare for next batch batchVector.clear(); currentWords = 0; - lengths.clear(); - lengths.resize(sets, 0); + lengths.assign(sets, 0); } } // turn rest into batch if(!batchVector.empty()) tempBatches.push_back(data_->toBatch(batchVector)); + LOG(info, "end form batches, #tempBatches = {}", tempBatches.size()); if(shuffle) { // shuffle the batches std::shuffle(tempBatches.begin(), tempBatches.end(), eng_); } + LOG(info, "end shuffling batches, #tempBatches = {}", tempBatches.size()); // put batches onto queue // exclusive lock std::unique_lock lock(loadMutex_); - for(const auto& batch : tempBatches) + LOG(info, "begin pushing batches (this is after lock), #tempBatches = {}", tempBatches.size()); + for(const auto& batch : tempBatches) // @TODO: use insert() bufferedBatches_.push_back(batch); + LOG(info, "fillBatches completed, bufferedBatches.size = {}", bufferedBatches_.size()); } public: @@ -195,8 +210,9 @@ public: currentBatch_ = bufferedBatches_.front(); if(loadReady_ - && (int)bufferedBatches_.size() - <= std::max(options_->get("maxi-batch") / 5, 1)) { + && (int)bufferedBatches_.size() + <= 100/*std::max(options_->get("maxi-batch") / 5, 1)*/ // @TODO: rather, pull Marcin's proper fix + ) { { std::unique_lock lock(loadMutex_); loadReady_ = false; @@ -209,7 +225,7 @@ public: loadReady_ = true; loadCondition_.notify_all(); }) - .detach(); + .detach(); } std::unique_lock lock(loadMutex_); diff --git a/src/data/batch_stats.h b/src/data/batch_stats.h index 2590b5b0..846d6109 100755 --- a/src/data/batch_stats.h +++ b/src/data/batch_stats.h @@ -17,7 +17,8 @@ public: BatchStats() { } size_t getBatchSize(const std::vector& lengths) { - auto it = map_.lower_bound(lengths); + // find the first item where all item.first[i] >= lengths[i], i.e. that can fit sentence tuples of lengths[] + auto it = map_.lower_bound(lengths); // typ. 20 items, ~4..5 steps for(size_t i = 0; i < lengths.size(); ++i) while(it != map_.end() && it->first[i] < lengths[i]) it++; diff --git a/src/data/corpus.cpp b/src/data/corpus.cpp index 86dd9000..7f563039 100755 --- a/src/data/corpus.cpp +++ b/src/data/corpus.cpp @@ -18,8 +18,7 @@ Corpus::Corpus(std::vector paths, : CorpusBase(paths, vocabs, options) {} SentenceTuple Corpus::next() { - bool cont = true; - while(cont) { + for (;;) { // (this is a retry loop for skipping invalid sentences) // get index of the current sentence size_t curId = pos_; // if corpus has been shuffled, ids_ contains sentence indexes @@ -29,11 +28,13 @@ SentenceTuple Corpus::next() { // fill up the sentence tuple with sentences from all input files SentenceTuple tup(curId); + size_t eofsHit = 0; for(size_t i = 0; i < files_.size(); ++i) { std::string line; - if(io::getline(*files_[i], line)) { - if(i > 0 && i == alignFileIdx_) { + bool gotLine = io::getline(*files_[i], line); + if(gotLine) { + if(i > 0 && i == alignFileIdx_) { // @TODO: alignFileIdx == 0 possible? addAlignmentToSentenceTuple(line, tup); } else if(i > 0 && i == weightFileIdx_) { addWeightsToSentenceTuple(line, tup); @@ -41,23 +42,23 @@ SentenceTuple Corpus::next() { addWordsToSentenceTuple(line, i, tup); } } + else + eofsHit++; } - // continue only if each input file provides an example - size_t expectedSize = files_.size(); - if(weightFileIdx_ > 0) - expectedSize -= 1; - if(alignFileIdx_ > 0) - expectedSize -= 1; - cont = tup.size() == expectedSize; + if (eofsHit == files_.size()) + return SentenceTuple(0); + ABORT_IF(eofsHit != 0, "not all input files have the same number of lines"); - // continue if all sentences are no longer than maximum allowed length - if(cont && std::all_of(tup.begin(), tup.end(), [=](const Words& words) { + // check if all streams are valid, that is, non-empty and no longer than maximum allowed length + if(std::all_of(tup.begin(), tup.end(), [=](const Words& words) { return words.size() > 0 && words.size() <= maxLength_; })) return tup; + + // otherwise skip this sentence and try the next one + // @TODO: tail recursion? } - return SentenceTuple(0); } void Corpus::shuffle() { diff --git a/src/data/text_input.cpp b/src/data/text_input.cpp old mode 100644 new mode 100755 index b2d00556..0d484766 --- a/src/data/text_input.cpp +++ b/src/data/text_input.cpp @@ -35,6 +35,7 @@ TextInput::TextInput(std::vector inputs, } SentenceTuple TextInput::next() { + // @TODO: This code mixes two patterns (while and early exit). Fix that. bool cont = true; while(cont) { // get index of the current sentence diff --git a/src/optimizers/optimizers.cpp b/src/optimizers/optimizers.cpp index e694e869..4e859fc5 100755 --- a/src/optimizers/optimizers.cpp +++ b/src/optimizers/optimizers.cpp @@ -192,6 +192,7 @@ void Adam::load(const std::string& name, } ABORT_IF(vMt.size() != vVt.size(), "mt and vt have different sizes??"); + LOG(info, "loading Adam params"); // @TODO: delete this scatterFn(vMt, [&](size_t localDeviceIndex, std::vector::const_iterator begin, std::vector::const_iterator end) { auto opt = std::dynamic_pointer_cast(opts[localDeviceIndex]); @@ -211,6 +212,7 @@ void Adam::load(const std::string& name, auto opt = std::dynamic_pointer_cast(opts[id]); opt->vt_->set(std::vector(begin, end)); }); + LOG(info, "done loading Adam params"); // @TODO: delete this } void Adam::save(const std::string& name, diff --git a/src/tensors/gpu/device.cu b/src/tensors/gpu/device.cu index 9638bfe9..0ec0b1f9 100755 --- a/src/tensors/gpu/device.cu +++ b/src/tensors/gpu/device.cu @@ -28,9 +28,10 @@ void Device::reserve(size_t size) { std::vector temp(size_); CUDA_CHECK(cudaMemcpy(temp.data(), data_, size_, cudaMemcpyDeviceToHost)); CUDA_CHECK(cudaFree(data_)); - LOG(info, "[memory] Re-allocating {} bytes on device {}", size, deviceId_.no); + LOG(info, "[memory] Re-allocating from {} to {} bytes on device {}", size_, size, deviceId_.no); CUDA_CHECK(cudaMalloc(&data_, size)); CUDA_CHECK(cudaMemcpy(data_, temp.data(), size_, cudaMemcpyHostToDevice)); + logCallStack(0); // @TODO: remove this } else { // No data_ yet: Just alloc. LOG(info, "[memory] Allocating {} bytes in device {}", size, deviceId_.no); diff --git a/src/training/communicator.cpp b/src/training/communicator.cpp index 0ee6ebd6..dfabcf05 100755 --- a/src/training/communicator.cpp +++ b/src/training/communicator.cpp @@ -81,8 +81,13 @@ public: MPI_Comm_rank(MPI_COMM_WORLD, &my_rank_); // patch logging pattern to include the MPI rank, so that we can associate error messages with nodes - if (numMPIProcesses() > 1) - switchtoMultinodeLogging(std::to_string(MPIWrapper::myMPIRank())); + if (numMPIProcesses() > 1) { + std::string rankStr = std::to_string(MPIWrapper::myMPIRank()); + std::string maxRankStr = std::to_string(MPIWrapper::numMPIProcesses() -1); + while (rankStr.size() < maxRankStr.size()) // pad so that logs across MPI processes line up nicely + rankStr.insert(rankStr.begin(), ' '); + switchtoMultinodeLogging(rankStr); + } // log hostnames in order, and test for (size_t r = 0; r < numMPIProcesses(); r++) { diff --git a/src/training/graph_group_sync.cpp b/src/training/graph_group_sync.cpp index 484179d9..6229284f 100755 --- a/src/training/graph_group_sync.cpp +++ b/src/training/graph_group_sync.cpp @@ -198,12 +198,12 @@ void SyncGraphGroup::update(Ptr batch) /*override*/ { // cost across all local devices // @TODO: We should report cost aggregated over all MPI processes. float cost = 0; - for(auto& c : localDeviceCosts) + for(auto& c : localDeviceCosts) // localDeviceCosts is already summed up over delay steps cost += c; // extrapolate cost across MPI processes // @TODO: This is a crude estimate. Rather, we should aggregate cost across all GPUs correctly; cf. gradient trick described above. // @TODO: If this is too crude, we can also resurrect the code from f68433 to loop over the local batches, - // and then determine a correction factor based on actual counts. They are very close though across MPI processes. + // and then determine a correction factor based on actual counts. They are very close though across MPI processes. cost *= mpi_->numMPIProcesses(); // if cost is average-based, we need to turn the sum over devices into an average as well