Mirror of https://github.com/mozilla/marian-dev.git
Revisited fillBatches() and optimized it a little;
added a try-catch to nail down the EAGAIN error observed on Philly; temporarily changed the fillBatches() refill criterion to a fixed threshold of 100 (should make little difference); bug fix: Corpus::next() now checks for an inconsistent end of data across streams; minor fix: the MPI rank in log messages is now padded to the same number of digits for all ranks, for more readable logs
Parent: 85c1d869f6
Commit: 96646ec443
@@ -0,0 +1,2 @@
+# git should never touch line endings
+* -text
@@ -20,7 +20,7 @@ public:
   typedef typename DataSet::batch_ptr BatchPtr;

   typedef typename DataSet::sample sample;
-  typedef std::vector<sample> samples;
+  typedef std::vector<sample> samples; // @TODO: type names should be capitalized

 protected:
   Ptr<DataSet> data_;
@@ -43,10 +43,11 @@ private:
   mutable std::condition_variable loadCondition_;
   bool loadReady_{true};

+  // this runs on a bg thread; sequencing is handled by caller, but locking is done in here
   void fillBatches(bool shuffle = true) {
+    LOG(info, "fillBatches entered");
     typedef typename sample::value_type Item;
-    auto itemCmp
-        = [](const Item& sa, const Item& sb) { return sa.size() < sb.size(); };
+    auto itemCmp = [](const Item& sa, const Item& sb) { return sa.size() < sb.size(); }; // sort by element length, not content
     auto cmpSrc = [itemCmp](const sample& a, const sample& b) {
       return std::lexicographical_compare(
@@ -58,12 +59,12 @@ private:
           a.rbegin(), a.rend(), b.rbegin(), b.rend(), itemCmp);
     };

-    auto cmpNone = [](const sample& a, const sample& b) { return &a < &b; };
+    auto cmpNone = [](const sample& a, const sample& b) { return &a < &b; }; // instead sort by address, so we have something to work with

     typedef std::function<bool(const sample&, const sample&)> cmp_type;
     typedef std::priority_queue<sample, samples, cmp_type> sample_queue;

-    std::unique_ptr<sample_queue> maxiBatch;
+    std::unique_ptr<sample_queue> maxiBatch; // priority queue, shortest first

     if(options_->has("maxi-batch-sort")) {
       if(options_->get<std::string>("maxi-batch-sort") == "src")
@@ -89,84 +90,98 @@ private:
       ++current_;
     }
     size_t sets = 0;
-    while(current_ != data_->end() && maxiBatch->size() < maxSize) {
+    try {
+    LOG(info, "begin read lines, current size {}", maxiBatch->size());
+    while(current_ != data_->end() && maxiBatch->size() < maxSize) { // loop over data
       maxiBatch->push(*current_);
       sets = current_->size();
-      ++current_;
+      // do not consume more than required for the maxi batch, as that would
+      // delay line-by-line translation by one sentence
+      bool last = maxiBatch->size() == maxSize;
+      if(!last)
+        ++current_; // this actually reads the next line and pre-processes it
     }
+    LOG(info, "end read lines, current size {}", maxiBatch->size());
+    // @TODO: Consider using MPI at this point to parallelize parsing.
+    }
+    catch(const std::exception& e) {
+      LOG(info, "exception caught while reading: {}", e.what());
+      logCallStack(0);
+      throw;
+    }

     // construct the actual batches and place them in the queue
     samples batchVector;
-    int currentWords = 0;
-    std::vector<size_t> lengths(sets, 0);
+    size_t currentWords = 0;
+    std::vector<size_t> lengths(sets, 0); // records maximum length observed within current batch

     std::vector<BatchPtr> tempBatches;
+    tempBatches.reserve(10000); // (should be enough in most cases; not critical)

-    // while there are sentences in the queue
-    while(!maxiBatch->empty()) {
+    // process all loaded sentences in order of increasing length
+    // @TODO: we could just use a vector and do a sort() here; would make the cost more explicit
+    LOG(info, "begin form batches, #batches = {}", maxiBatch->size());
+    const size_t mbWords = options_->get<size_t>("mini-batch-words", 0);
+    const bool useDynamicBatching = options_->has("mini-batch-fit");
+    while(!maxiBatch->empty()) { // while there are sentences in the queue
       // push item onto batch
       batchVector.push_back(maxiBatch->top());
-      currentWords += (int)batchVector.back()[0].size();
-      maxiBatch->pop();
-
-      // Batch size based on sentences
-      bool makeBatch = batchVector.size() == maxBatchSize;
-
-      // Batch size based on words
-      if(options_->has("mini-batch-words")) {
-        int mbWords = options_->get<int>("mini-batch-words");
-        if(mbWords > 0)
-          makeBatch = currentWords > mbWords;
-      }
-
-      if(options_->has("mini-batch-fit")) {
-        // Dynamic batching
+      maxiBatch->pop(); // fetch next-shortest
+
+      // have we reached a sufficient amount of data to form a batch?
+      bool makeBatch;
+      if(useDynamicBatching) { // batch size based on dynamic batching
         if(stats_) {
           for(size_t i = 0; i < sets; ++i)
             if(batchVector.back()[i].size() > lengths[i])
-              lengths[i] = batchVector.back()[i].size();
+              lengths[i] = batchVector.back()[i].size(); // record max lengths so far

-          maxBatchSize = stats_->getBatchSize(lengths);
+          maxBatchSize = stats_->getBatchSize(lengths); // note: to speed this up, we could cache the iterator; we call it with growing sentence length

-          makeBatch = batchVector.size() >= maxBatchSize;
+          // if the last added sentence caused a bump then we likely have bad padding, so rather move it into the next batch
           if(batchVector.size() > maxBatchSize) {
             maxiBatch->push(batchVector.back());
             batchVector.pop_back();
+            makeBatch = true;
+          } else {
+            makeBatch = batchVector.size() == maxBatchSize;
           }
         }
       }
+      else if(mbWords > 0) {
+        currentWords += batchVector.back()[0].size(); // count words based on the first stream = source --@TODO: shouldn't we count based on labels?
+        makeBatch = currentWords > mbWords; // batch size based on words
+      }
+      else
+        makeBatch = batchVector.size() == maxBatchSize; // batch size based on sentences

-      // if batch has desired size create a real batch
+      // if we reached the desired batch size then create a real batch
       if(makeBatch) {
         tempBatches.push_back(data_->toBatch(batchVector));

         // prepare for next batch
         batchVector.clear();
         currentWords = 0;
-        lengths.clear();
-        lengths.resize(sets, 0);
+        lengths.assign(sets, 0);
       }
     }

     // turn rest into batch
     if(!batchVector.empty())
       tempBatches.push_back(data_->toBatch(batchVector));
+    LOG(info, "end form batches, #tempBatches = {}", tempBatches.size());

     if(shuffle) {
       // shuffle the batches
       std::shuffle(tempBatches.begin(), tempBatches.end(), eng_);
     }
+    LOG(info, "end shuffling batches, #tempBatches = {}", tempBatches.size());

     // put batches onto queue
     // exclusive lock
     std::unique_lock<std::mutex> lock(loadMutex_);
-    for(const auto& batch : tempBatches)
+    LOG(info, "begin pushing batches (this is after lock), #tempBatches = {}", tempBatches.size());
+    for(const auto& batch : tempBatches) // @TODO: use insert()
       bufferedBatches_.push_back(batch);
+    LOG(info, "fillBatches completed, bufferedBatches.size = {}", bufferedBatches_.size());
   }

 public:
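Note: to make the batching scheme above easier to follow, here is a minimal, self-contained sketch of the same idea (illustrative only, not marian code; the sentences and the threshold are made up). Sentences are buffered in a priority queue and popped shortest-first, so each mini-batch groups sentences of similar length and wastes little padding.

#include <iostream>
#include <queue>
#include <string>
#include <vector>

int main() {
  // comparator: longer sentences sink, so top() always yields the shortest
  auto longerFirst = [](const std::string& a, const std::string& b) {
    return a.size() > b.size();
  };
  std::priority_queue<std::string, std::vector<std::string>, decltype(longerFirst)>
      maxiBatch(longerFirst);
  for(const char* s : {"a rather long sentence", "short", "a mid-size one"})
    maxiBatch.push(s);

  const size_t maxBatchSize = 2; // made-up threshold
  std::vector<std::string> batchVector;
  while(!maxiBatch.empty()) {
    batchVector.push_back(maxiBatch.top()); // fetch next-shortest sentence
    maxiBatch.pop();
    if(batchVector.size() == maxBatchSize) { // emit a full mini-batch
      std::cout << "batch of " << batchVector.size() << " sentences\n";
      batchVector.clear();
    }
  }
  if(!batchVector.empty()) // turn the rest into a final batch
    std::cout << "batch of " << batchVector.size() << " sentences\n";
}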
@@ -195,8 +210,9 @@ public:
     currentBatch_ = bufferedBatches_.front();

     if(loadReady_
-       && (int)bufferedBatches_.size()
-              <= std::max(options_->get<int>("maxi-batch") / 5, 1)) {
+       && (int)bufferedBatches_.size()
+              <= 100/*std::max(options_->get<int>("maxi-batch") / 5, 1)*/ // @TODO: rather, pull Marcin's proper fix
+       ) {
       {
         std::unique_lock<std::mutex> lock(loadMutex_);
         loadReady_ = false;
@@ -209,7 +225,7 @@ public:
           loadReady_ = true;
           loadCondition_.notify_all();
         })
-          .detach();
+        .detach();
     }

     std::unique_lock<std::mutex> lock(loadMutex_);
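Note: the refill logic above is a single-producer prefetch pattern. A stripped-down sketch of that pattern (illustrative; class and member names are made up, while in marian the buffer, mutex, and condition variable live in BatchGenerator):

#include <condition_variable>
#include <deque>
#include <mutex>
#include <thread>

class Prefetcher {
  std::deque<int> buffered_;       // stands in for bufferedBatches_
  std::mutex mutex_;
  std::condition_variable cond_;
  bool loadReady_{true};

public:
  // called from the consumer thread when the buffer runs low
  void refillInBackground() {
    {
      std::unique_lock<std::mutex> lock(mutex_);
      loadReady_ = false;              // block a second concurrent refill
    }
    std::thread([this] {
      std::deque<int> fresh{1, 2, 3};  // produce outside the lock
      std::unique_lock<std::mutex> lock(mutex_);
      buffered_.insert(buffered_.end(), fresh.begin(), fresh.end());
      loadReady_ = true;
      cond_.notify_all();              // wake anyone waiting for data
    }).detach();
  }

  void waitUntilReady() {
    std::unique_lock<std::mutex> lock(mutex_);
    cond_.wait(lock, [this] { return loadReady_; });
  }
};

Detaching the thread keeps the consumer responsive; the loadReady_ flag ensures at most one refill is in flight at a time.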
@@ -17,7 +17,8 @@ public:
   BatchStats() { }

   size_t getBatchSize(const std::vector<size_t>& lengths) {
-    auto it = map_.lower_bound(lengths);
+    // find the first item where all item.first[i] >= lengths[i], i.e. that can fit sentence tuples of lengths[]
+    auto it = map_.lower_bound(lengths); // typ. 20 items, ~4..5 steps
     for(size_t i = 0; i < lengths.size(); ++i)
       while(it != map_.end() && it->first[i] < lengths[i])
         it++;
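Note: for intuition, the lookup in getBatchSize() behaves like this standalone example (illustrative values; in marian, map_ is filled from actual memory measurements). lower_bound() on the lexicographically ordered keys gives a starting point, and the loop advances until every dimension fits:

#include <cassert>
#include <map>
#include <vector>

int main() {
  std::map<std::vector<size_t>, size_t> map_; // per-stream max lengths -> batch size
  map_[{10, 10}] = 512;
  map_[{20, 20}] = 256;
  map_[{50, 50}] = 64;

  std::vector<size_t> lengths = {12, 18};
  auto it = map_.lower_bound(lengths); // first key not lexicographically less
  for(size_t i = 0; i < lengths.size(); ++i)
    while(it != map_.end() && it->first[i] < lengths[i])
      ++it;
  assert(it != map_.end() && it->second == 256); // {20,20} fits {12,18}
  return 0;
}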
@@ -18,8 +18,7 @@ Corpus::Corpus(std::vector<std::string> paths,
     : CorpusBase(paths, vocabs, options) {}

 SentenceTuple Corpus::next() {
-  bool cont = true;
-  while(cont) {
+  for (;;) { // (this is a retry loop for skipping invalid sentences)
     // get index of the current sentence
     size_t curId = pos_;
     // if corpus has been shuffled, ids_ contains sentence indexes
@@ -29,11 +28,13 @@ SentenceTuple Corpus::next() {

     // fill up the sentence tuple with sentences from all input files
     SentenceTuple tup(curId);
+    size_t eofsHit = 0;
     for(size_t i = 0; i < files_.size(); ++i) {
       std::string line;

-      if(io::getline(*files_[i], line)) {
-        if(i > 0 && i == alignFileIdx_) {
+      bool gotLine = io::getline(*files_[i], line);
+      if(gotLine) {
+        if(i > 0 && i == alignFileIdx_) { // @TODO: alignFileIdx == 0 possible?
           addAlignmentToSentenceTuple(line, tup);
         } else if(i > 0 && i == weightFileIdx_) {
           addWeightsToSentenceTuple(line, tup);
@@ -41,23 +42,23 @@ SentenceTuple Corpus::next() {
           addWordsToSentenceTuple(line, i, tup);
         }
       }
+      else
+        eofsHit++;
     }

-    // continue only if each input file provides an example
-    size_t expectedSize = files_.size();
-    if(weightFileIdx_ > 0)
-      expectedSize -= 1;
-    if(alignFileIdx_ > 0)
-      expectedSize -= 1;
-    cont = tup.size() == expectedSize;
+    if (eofsHit == files_.size())
+      return SentenceTuple(0);
+    ABORT_IF(eofsHit != 0, "not all input files have the same number of lines");

-    // continue if all sentences are no longer than maximum allowed length
-    if(cont && std::all_of(tup.begin(), tup.end(), [=](const Words& words) {
+    // check if all streams are valid, that is, non-empty and no longer than maximum allowed length
+    if(std::all_of(tup.begin(), tup.end(), [=](const Words& words) {
          return words.size() > 0 && words.size() <= maxLength_;
        }))
       return tup;
+
+    // otherwise skip this sentence and try the next one
+    // @TODO: tail recursion?
   }
-  return SentenceTuple(0);
 }

 void Corpus::shuffle() {
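Note: the invariant the new end-of-data check enforces is easier to see in isolation. A minimal sketch (hypothetical helper, not marian's API): reading one line from each parallel stream must either succeed for every file (a full tuple) or fail for every file (clean end of data); anything in between means the corpora are out of sync.

#include <cstddef>
#include <fstream>
#include <stdexcept>
#include <string>
#include <vector>

bool nextTuple(std::vector<std::ifstream>& files, std::vector<std::string>& tup) {
  tup.clear();
  size_t eofsHit = 0;
  for(auto& f : files) {
    std::string line;
    if(std::getline(f, line))
      tup.push_back(line);
    else
      eofsHit++;  // this stream ran out of data
  }
  if(eofsHit == files.size())
    return false; // all streams ended together: clean EOF
  if(eofsHit != 0)
    throw std::runtime_error("not all input files have the same number of lines");
  return true;    // got one line from every stream
}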
@@ -35,6 +35,7 @@ TextInput::TextInput(std::vector<std::string> inputs,
 }

 SentenceTuple TextInput::next() {
+  // @TODO: This code mixes two patterns (while and early exit). Fix that.
   bool cont = true;
   while(cont) {
     // get index of the current sentence
@@ -192,6 +192,7 @@ void Adam::load(const std::string& name,
   }
   ABORT_IF(vMt.size() != vVt.size(), "mt and vt have different sizes??");

+  LOG(info, "loading Adam params"); // @TODO: delete this
   scatterFn(vMt,
       [&](size_t localDeviceIndex, std::vector<float>::const_iterator begin, std::vector<float>::const_iterator end) {
         auto opt = std::dynamic_pointer_cast<Adam>(opts[localDeviceIndex]);
@@ -211,6 +212,7 @@ void Adam::load(const std::string& name,
         auto opt = std::dynamic_pointer_cast<Adam>(opts[id]);
         opt->vt_->set(std::vector<float>(begin, end));
       });
+  LOG(info, "done loading Adam params"); // @TODO: delete this
 }

 void Adam::save(const std::string& name,
@@ -28,9 +28,10 @@ void Device::reserve(size_t size) {
     std::vector<uint8_t> temp(size_);
     CUDA_CHECK(cudaMemcpy(temp.data(), data_, size_, cudaMemcpyDeviceToHost));
     CUDA_CHECK(cudaFree(data_));
-    LOG(info, "[memory] Re-allocating {} bytes on device {}", size, deviceId_.no);
+    LOG(info, "[memory] Re-allocating from {} to {} bytes on device {}", size_, size, deviceId_.no);
     CUDA_CHECK(cudaMalloc(&data_, size));
     CUDA_CHECK(cudaMemcpy(data_, temp.data(), size_, cudaMemcpyHostToDevice));
+    logCallStack(0); // @TODO: remove this
   } else {
     // No data_ yet: Just alloc.
     LOG(info, "[memory] Allocating {} bytes in device {}", size, deviceId_.no);
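Note: for reference, the grow-and-copy pattern that reserve() implements looks like this in isolation (a sketch assuming only the CUDA runtime API; the helper name and the CUDA_CHECK macro here are illustrative stand-ins):

#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <vector>
#include <cuda_runtime.h>

#define CUDA_CHECK(expr) do { cudaError_t rc = (expr); \
  if(rc != cudaSuccess) { std::fprintf(stderr, "%s\n", cudaGetErrorString(rc)); std::abort(); } } while(0)

// grow a device buffer, staging its contents through host memory
void growDeviceBuffer(uint8_t*& data, size_t oldSize, size_t newSize) {
  std::vector<uint8_t> temp(oldSize);
  CUDA_CHECK(cudaMemcpy(temp.data(), data, oldSize, cudaMemcpyDeviceToHost));
  CUDA_CHECK(cudaFree(data));
  CUDA_CHECK(cudaMalloc(&data, newSize));
  CUDA_CHECK(cudaMemcpy(data, temp.data(), oldSize, cudaMemcpyHostToDevice));
}

A device-to-device copy into a second allocation would avoid the host round trip, at the cost of briefly holding both buffers in GPU memory.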
@@ -81,8 +81,13 @@ public:
     MPI_Comm_rank(MPI_COMM_WORLD, &my_rank_);

     // patch logging pattern to include the MPI rank, so that we can associate error messages with nodes
-    if (numMPIProcesses() > 1)
-      switchtoMultinodeLogging(std::to_string(MPIWrapper::myMPIRank()));
+    if (numMPIProcesses() > 1) {
+      std::string rankStr = std::to_string(MPIWrapper::myMPIRank());
+      std::string maxRankStr = std::to_string(MPIWrapper::numMPIProcesses() - 1);
+      while (rankStr.size() < maxRankStr.size()) // pad so that logs across MPI processes line up nicely
+        rankStr.insert(rankStr.begin(), ' ');
+      switchtoMultinodeLogging(rankStr);
+    }

     // log hostnames in order, and test
     for (size_t r = 0; r < numMPIProcesses(); r++) {
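Note: the padding logic is worth seeing with concrete values (a standalone sketch; the function name is made up). With 16 processes the widest rank is "15", so rank 3 becomes " 3" and all log prefixes share the same width:

#include <iostream>
#include <string>

std::string paddedRank(size_t rank, size_t numProcesses) {
  std::string rankStr = std::to_string(rank);
  std::string maxRankStr = std::to_string(numProcesses - 1); // widest possible rank
  while(rankStr.size() < maxRankStr.size())
    rankStr.insert(rankStr.begin(), ' '); // left-pad with spaces
  return rankStr;
}

int main() {
  std::cout << '[' << paddedRank(3, 16) << "]\n";  // prints "[ 3]"
  std::cout << '[' << paddedRank(12, 16) << "]\n"; // prints "[12]"
}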
@@ -198,12 +198,12 @@ void SyncGraphGroup::update(Ptr<data::Batch> batch) /*override*/ {
   // cost across all local devices
   // @TODO: We should report cost aggregated over all MPI processes.
   float cost = 0;
-  for(auto& c : localDeviceCosts)
+  for(auto& c : localDeviceCosts) // localDeviceCosts is already summed up over delay steps
     cost += c;
   // extrapolate cost across MPI processes
   // @TODO: This is a crude estimate. Rather, we should aggregate cost across all GPUs correctly; cf. gradient trick described above.
   // @TODO: If this is too crude, we can also resurrect the code from f68433 to loop over the local batches,
-  //  and then determine a correction factor based on actual counts. They are very close though across MPI processes.
+  // and then determine a correction factor based on actual counts. They are very close though across MPI processes.
   cost *= mpi_->numMPIProcesses();

   // if cost is average-based, we need to turn the sum over devices into an average as well
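Note: concretely, if each of 4 local GPUs reports a cost of 25, the local sum is 100; running under 4 MPI processes, the extrapolated global cost is 100 × 4 = 400. The estimate is exact only insofar as every process sees batches of similar size, which is what the @TODO comments above are about.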