Enable fastseq generate CLI to handle the empty shard correctly (#82)

* Handle the empty shard

* Remove unused variable
This commit is contained in:
Fei Hu 2021-01-15 20:46:46 -08:00 коммит произвёл GitHub
Родитель 4fbb2211b9
Коммит c5549069cb
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
1 изменённых файлов: 7 добавлений и 4 удалений

Просмотреть файл

@ -60,7 +60,7 @@ class IOProcess(Process):
self.args = args
self.message_queue = message_queue
self.has_target = True
self.has_target = False
def run(self):
while True:
@ -71,6 +71,7 @@ class IOProcess(Process):
self.scorer.add_string(t, h)
else:
self.scorer.add(t, h)
self.has_target = True
elif msg == GENERATE_FINISHED:
if self.has_target:
print('| Generate {} with beam={}: {}'.format(
@ -124,7 +125,6 @@ class PostProcess(Process):
self.task = task
self.data_queue = data_queue
self.message_queue = message_queue
self.has_target = True
if args.decode_hypothesis:
self.tokenizer = encoders.build_tokenizer(args)
self.bpe = encoders.build_bpe(args)
@ -402,10 +402,13 @@ def main_v1(args):
for p in p_list:
p.join()
sent_throught = num_sentences / gen_timer.sum if num_sentences > 0 else 0
tokens_throught = 1. / gen_timer.avg if num_sentences > 0 else 0
message_queue.put(
'| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'. # pylint: disable=line-too-long
format(num_sentences, gen_timer.n, gen_timer.sum,
num_sentences / gen_timer.sum, 1. / gen_timer.avg))
format(num_sentences, gen_timer.n, gen_timer.sum, sent_throught,
tokens_throught))
message_queue.put(GENERATE_FINISHED)
io_process.join()