Simplify the ngram blocking algorithm; bump speed from 11.0 to 14.8 samples/s.

Before the change: generate all ngram pairs into a dict, then pick banned tokens by looking up the ngram formed by the last n-1 tokens.
After the change: generate the banned tokens directly.

For example, suppose the previously generated tokens are 1 2 3 4 2 3 and the ngram size is 3; the token that needs to be banned is 4.
Before the change, the code generated all pairs into a dict {"1 2": 3, "2 3": 4, "3 4": 2, "4 2": 3}, looked up "2 3", and found that 4 should be banned.
After the change, the code puts 4 directly into a list of banned tokens and bans it.
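The before/after logic can be sketched side by side. The function names and the standalone, list-based shape below are illustrative only; the actual fastseq code operates per beam on tensors inside the generator:

```python
def banned_via_ngram_dict(tokens, n):
    """Old approach: build a dict mapping every (n-1)-gram prefix to the
    tokens that followed it, then look up the last n-1 tokens."""
    gen_ngrams = {}
    for ngram in zip(*[tokens[i:] for i in range(n)]):
        gen_ngrams.setdefault(tuple(ngram[:-1]), []).append(ngram[-1])
    last_prefix = tuple(tokens[len(tokens) - (n - 1):])
    return gen_ngrams.get(last_prefix, [])

def banned_directly(tokens, n):
    """New approach: compare each historical (n-1)-gram against the trailing
    n-1 tokens and collect the token that followed a match, skipping the dict."""
    banned = []
    check_start_pos = len(tokens) - (n - 1)
    for i in range(check_start_pos):
        if tokens[i:i + n - 1] == tokens[check_start_pos:]:
            banned.append(tokens[i + n - 1])
    return banned

# The example from the commit message: with ngram size 3 and history
# 1 2 3 4 2 3, the trailing bigram "2 3" was once followed by 4, so 4 is banned.
print(banned_via_ngram_dict([1, 2, 3, 4, 2, 3], 3))  # [4]
print(banned_directly([1, 2, 3, 4, 2, 3], 3))        # [4]
```

Both return the same banned tokens; the second avoids building and hashing a prefix tuple for every position on every decoding step.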
This commit is contained in:
Yu Yan 2020-08-31 16:09:40 -07:00 committed by GitHub
Parent d07633eef6
Commit e645065aa4
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
2 changed files: 16 additions and 22 deletions


@@ -29,10 +29,10 @@ FastSeq provides efficient implementations of the popular sequence models with h
 - CNN daily mail val data, NVIDIA-V100-16GB
-| BatchSize | 32 | 64 | 128 |
-|:----------------:|:-------------:|:--------------:|:--------------:|
-| fairseq-0.9.0 | 2.3 samples/s | OOM | OOM |
-| above + fastseq | 6.1 samples/s | 8.7 samples/s | 11.0 samples/s |
+| BatchSize | 32 | 64 | 128 |
+|:----------------:|:-------------:|:---------------:|:--------------:|
+| fairseq-0.9.0 | 2.7 samples/s | OOM | OOM |
+| above + fastseq | 9.0 samples/s | 12.5 samples/s | 14.5 samples/s |
 with setting:


@@ -644,18 +644,18 @@ class SequenceGeneratorV2(SequenceGenerator):
         if self.no_repeat_ngram_size > 0:
             # for each beam and batch sentence, generate a list of previous ngrams
-            gen_ngrams = [{} for bbsz_idx in range(bsz * beam_size)]
-            cpu_tokens = tokens.cpu()[:, :step + 1]
+            banned_list = [[] for bbsz_idx in range(bsz * beam_size)]
+            cpu_tokens = tokens.cpu()[:, :step + 1].numpy()
+            check_start_pos = step + 2 - self.no_repeat_ngram_size
             for bbsz_idx in range(bsz * beam_size):
-                gen_tokens = cpu_tokens[bbsz_idx].tolist()
-                for ngram in zip(*[
-                        gen_tokens[i:]
-                        for i in range(self.no_repeat_ngram_size)
-                ]):
-                    if ngram[-1] != self.pad:
-                        gen_ngrams[bbsz_idx][tuple(ngram[:-1])] = \
-                            gen_ngrams[bbsz_idx].get(tuple(ngram[:-1]), [])\
-                            + [ngram[-1]]
+                for i in range(check_start_pos):
+                    is_banned = True
+                    for k in range(self.no_repeat_ngram_size - 1):
+                        if cpu_tokens[bbsz_idx, i + k] != cpu_tokens[bbsz_idx, check_start_pos + k]:
+                            is_banned = False
+                            break
+                    if is_banned:
+                        banned_list[bbsz_idx].append(cpu_tokens[bbsz_idx, i + self.no_repeat_ngram_size - 1])

         # Record attention scores
         if avg_attn_scores is not None:
@@ -676,14 +676,8 @@
             def calculate_banned_tokens(bbsz_idx):
                 # before decoding the next token, prevent decoding of ngrams that have already appeared
-                ngram_index = tuple(
-                    cpu_tokens[bbsz_idx,
-                               step + 2 - self.no_repeat_ngram_size:step +
-                               1].tolist())
-                banned_tokens_per_sample = gen_ngrams[bbsz_idx].get(
-                    ngram_index, [])
                 banned_tokens_per_sample = [
-                    (bbsz_idx, t) for t in banned_tokens_per_sample
+                    (bbsz_idx, t) for t in banned_list[bbsz_idx]
                 ]
                 return banned_tokens_per_sample
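With the banned tokens precomputed per beam, the helper in this second hunk reduces to pairing the flat beam/batch index with each banned token. A minimal standalone sketch of that reduced helper, taking banned_list as an explicit argument for illustration rather than closing over it as the diff does:

```python
def calculate_banned_tokens(banned_list, bbsz_idx):
    # Pair the flat beam/batch index with every token banned for that beam,
    # producing the (bbsz_idx, token) tuples the generator consumes downstream.
    return [(bbsz_idx, t) for t in banned_list[bbsz_idx]]

# One entry per beam: beam 0 must not emit 4, beam 1 has no bans.
banned = [[4], [], [7, 9]]
print(calculate_banned_tokens(banned, 2))  # [(2, 7), (2, 9)]
```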