Mirror of https://github.com/microsoft/fastseq.git
Simplify ngram block algo (#18)
Simplify the ngram blocking algorithm; bump speed from 11.0 to 14.8 samples/s. Before the change: generate all ngram pairs, then pick the banned tokens by looking up the ngram formed by the last n-1 tokens. After the change: generate the banned tokens directly. For example, if the previously generated tokens are 1 2 3 4 2 3, the token that needs to be banned is 4. Before the change, the code generates every pair in a dict {"1 2": 3, "2 3": 4, "3 4": 2, "4 2": 3} and looks up "2 3", finally finding that 4 should be banned. After the change, it puts 4 directly into a list and bans it.
This commit is contained in:
Parent d07633eef6
Commit e645065aa4
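To make the before/after concrete, here is a minimal, self-contained sketch of the two strategies from the commit message (hypothetical helper names, plain Python lists; not the repository code, which works on batched tensors per beam):

```python
def banned_via_ngram_dict(tokens, n):
    """Before the change: build a dict mapping every (n-1)-gram prefix to the
    tokens that followed it, then look up the last n-1 generated tokens."""
    gen_ngrams = {}
    for ngram in zip(*[tokens[i:] for i in range(n)]):
        prefix = tuple(ngram[:-1])
        gen_ngrams[prefix] = gen_ngrams.get(prefix, []) + [ngram[-1]]
    return gen_ngrams.get(tuple(tokens[len(tokens) - n + 1:]), [])

def banned_direct(tokens, n):
    """After the change: compare each historical window of n-1 tokens with the
    current suffix and collect the token that followed any matching window."""
    banned = []
    check_start_pos = len(tokens) - (n - 1)
    for i in range(check_start_pos):
        if tokens[i:i + n - 1] == tokens[check_start_pos:]:
            banned.append(tokens[i + n - 1])
    return banned

# The worked example from the commit message: 4 must be banned after "... 2 3".
tokens = [1, 2, 3, 4, 2, 3]
assert banned_via_ngram_dict(tokens, 3) == [4]
assert banned_direct(tokens, 3) == [4]
```

Both return the same banned set, but the direct scan skips building and hashing an ngram dict per beam at every decoding step, which is presumably where the measured speedup comes from.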
@@ -29,10 +29,10 @@ FastSeq provides efficient implementations of the popular sequence models with h
 - CNN daily mail val data, NVIDIA-V100-16GB

-| BatchSize        | 32            | 64             | 128            |
-|:----------------:|:-------------:|:--------------:|:--------------:|
-| fairseq-0.9.0    | 2.3 samples/s | OOM            | OOM            |
-| above + fastseq  | 6.1 samples/s | 8.7 samples/s  | 11.0 samples/s |
+| BatchSize        | 32            | 64             | 128            |
+|:----------------:|:-------------:|:--------------:|:--------------:|
+| fairseq-0.9.0    | 2.7 samples/s | OOM            | OOM            |
+| above + fastseq  | 9.0 samples/s | 12.5 samples/s | 14.5 samples/s |

 with setting:
@@ -644,18 +644,18 @@ class SequenceGeneratorV2(SequenceGenerator):
             if self.no_repeat_ngram_size > 0:
                 # for each beam and batch sentence, generate a list of previous ngrams
-                gen_ngrams = [{} for bbsz_idx in range(bsz * beam_size)]
-                cpu_tokens = tokens.cpu()[:, :step + 1]
+                banned_list = [[] for bbsz_idx in range(bsz * beam_size)]
+                cpu_tokens = tokens.cpu()[:, :step + 1].numpy()
+                check_start_pos = step + 2 - self.no_repeat_ngram_size
                 for bbsz_idx in range(bsz * beam_size):
-                    gen_tokens = cpu_tokens[bbsz_idx].tolist()
-                    for ngram in zip(*[
-                            gen_tokens[i:]
-                            for i in range(self.no_repeat_ngram_size)
-                    ]):
-                        if ngram[-1] != self.pad:
-                            gen_ngrams[bbsz_idx][tuple(ngram[:-1])] = \
-                                gen_ngrams[bbsz_idx].get(tuple(ngram[:-1]), [])\
-                                + [ngram[-1]]
+                    for i in range(check_start_pos):
+                        is_banned = True
+                        for k in range(self.no_repeat_ngram_size - 1):
+                            if cpu_tokens[bbsz_idx, i + k] != cpu_tokens[bbsz_idx, check_start_pos + k]:
+                                is_banned = False
+                                break
+                        if is_banned:
+                            banned_list[bbsz_idx].append(cpu_tokens[bbsz_idx, i + self.no_repeat_ngram_size - 1])

             # Record attention scores
             if avg_attn_scores is not None:
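The triple loop added above is the whole algorithm: for each beam row, compare every historical window of n-1 tokens against the current suffix starting at check_start_pos. As an aside, the same per-row scan can be written with numpy broadcasting; this is only an illustration under the assumptions n >= 2 and numpy >= 1.20 (for sliding_window_view), not part of the commit:

```python
import numpy as np

def banned_for_row(row, n):
    # row: 1-D array of generated token ids for one beam (length step + 1)
    check_start_pos = len(row) - (n - 1)  # same as step + 2 - n in the diff
    if check_start_pos <= 0:
        return []
    prefix = row[check_start_pos:]  # the last n - 1 generated tokens
    # every historical window of n - 1 tokens that precedes a candidate token
    windows = np.lib.stride_tricks.sliding_window_view(row[:-1], n - 1)
    matches = (windows == prefix).all(axis=1)
    # the token immediately after each matching window is banned
    return row[np.nonzero(matches)[0] + n - 1].tolist()

assert banned_for_row(np.array([1, 2, 3, 4, 2, 3]), 3) == [4]
```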
@@ -676,14 +676,8 @@ class SequenceGeneratorV2(SequenceGenerator):
             def calculate_banned_tokens(bbsz_idx):
                 # before decoding the next token, prevent decoding of ngrams that have already appeared
-                ngram_index = tuple(
-                    cpu_tokens[bbsz_idx,
-                               step + 2 - self.no_repeat_ngram_size:step +
-                               1].tolist())
-                banned_tokens_per_sample = gen_ngrams[bbsz_idx].get(
-                    ngram_index, [])
                 banned_tokens_per_sample = [
-                    (bbsz_idx, t) for t in banned_tokens_per_sample
+                    (bbsz_idx, t) for t in banned_list[bbsz_idx]
                 ]
                 return banned_tokens_per_sample
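After the change, calculate_banned_tokens still returns (bbsz_idx, token) pairs, so the downstream masking logic is untouched. A hedged sketch of how such pairs are typically consumed in fairseq-style generators (assumed caller shape and name; the actual call site is outside this diff):

```python
import math

def apply_banned_tokens(lprobs, banned_pairs):
    # lprobs: (bsz * beam_size, vocab_size) tensor of next-token log-probs
    # banned_pairs: list of (bbsz_idx, token) tuples from calculate_banned_tokens
    for bbsz_idx, token in banned_pairs:
        # int() because cpu_tokens is now a numpy array, so token may be
        # a numpy scalar rather than a Python int
        lprobs[bbsz_idx, int(token)] = -math.inf
    return lprobs
```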