[cleanup] generate_beam_search comments (#5115)

This commit is contained in:
Sam Shleifer 2020-06-18 16:30:24 -04:00 коммит произвёл GitHub
Родитель ca2d0f98c4
Коммит 3d3e605aff
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файлов: 14 добавлений и 18 удалений

Просмотреть файл

@ -1219,9 +1219,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
if len(next_sent_beam) == num_beams:
break
# Check if were done so that we can save a pad step if all(done)
# Check if we are done so that we can save a pad step if all(done)
done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
tf.reduce_max(next_scores[batch_idx]).numpy(), cur_len=cur_len
tf.reduce_max(next_scores[batch_idx]).numpy(), cur_len
)
# update next beam content
@ -1509,7 +1509,7 @@ class BeamHypotheses(object):
else:
self.worst_score = min(score, self.worst_score)
def is_done(self, best_sum_logprobs, cur_len=None):
def is_done(self, best_sum_logprobs, cur_len):
"""
If there are enough hypotheses and that none of the hypotheses being generated
can become better than the worst one in the heap, then we are done with this sentence.
@ -1520,8 +1520,6 @@ class BeamHypotheses(object):
elif self.early_stopping:
return True
else:
if cur_len is None:
cur_len = self.max_length
cur_score = best_sum_logprobs / cur_len ** self.length_penalty
ret = self.worst_score >= cur_score
return ret

Просмотреть файл

@ -1462,7 +1462,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# for each sentence
for batch_idx in range(batch_size):
# if we are done with this sentence
# if we are done with this sentence, add a pad token
if done[batch_idx]:
assert (
len(generated_hyps[batch_idx]) >= num_beams
@ -1473,7 +1473,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch
continue
# next sentence beam content
# next sentence beam content, this will get added to next_batch_beam
next_sent_beam = []
# next tokens for this sentence
@ -1485,7 +1485,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
token_id = beam_token_id % vocab_size
effective_beam_id = batch_idx * num_beams + beam_id
# add to generated hypotheses if end of sentence or last iteration
# add to generated hypotheses if end of sentence
if (eos_token_id is not None) and (token_id.item() == eos_token_id):
# if beam_token does not belong to top num_beams tokens, it should not be added
is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
@ -1495,22 +1495,22 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
input_ids[effective_beam_id].clone(), beam_token_score.item(),
)
else:
# add next predicted token if it is not eos_token
# add next predicted token since it is not eos_token
next_sent_beam.append((beam_token_score, token_id, effective_beam_id))
# the beam for next step is full
# once the beam for next step is full, don't add more tokens to it.
if len(next_sent_beam) == num_beams:
break
# Check if were done so that we can save a pad step if all(done)
# Check if we are done so that we can save a pad step if all(done)
done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
next_scores[batch_idx].max().item(), cur_len=cur_len
next_scores[batch_idx].max().item(), cur_len
)
# update next beam content
assert len(next_sent_beam) == num_beams, "Beam should always be full"
next_batch_beam.extend(next_sent_beam)
assert len(next_batch_beam) == num_beams * (batch_idx + 1)
assert len(next_batch_beam) == num_beams * (batch_idx + 1), "We should have added num_beams each step"
# stop when we are done with each sentence
if all(done):
@ -1537,7 +1537,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
)
# finalize all open beam hypotheses and end to generated hypotheses
# finalize all open beam hypotheses and add to generated hypotheses
for batch_idx in range(batch_size):
if done[batch_idx]:
continue
@ -1576,7 +1576,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
sent_lengths[effective_batch_idx] = len(best_hyp)
best.append(best_hyp)
# shorter batches are filled with pad_token
# shorter batches are padded
if sent_lengths.min().item() != sent_lengths.max().item():
assert pad_token_id is not None, "`Pad_token_id` has to be defined"
sent_max_len = min(sent_lengths.max().item() + 1, max_length)
@ -1731,7 +1731,7 @@ class BeamHypotheses(object):
else:
self.worst_score = min(score, self.worst_score)
def is_done(self, best_sum_logprobs, cur_len=None):
def is_done(self, best_sum_logprobs, cur_len):
"""
If there are enough hypotheses and that none of the hypotheses being generated
can become better than the worst one in the heap, then we are done with this sentence.
@ -1742,8 +1742,6 @@ class BeamHypotheses(object):
elif self.early_stopping:
return True
else:
if cur_len is None:
cur_len = self.max_length
cur_score = best_sum_logprobs / cur_len ** self.length_penalty
ret = self.worst_score >= cur_score
return ret