[cleanup] generate_beam_search comments (#5115)
This commit is contained in:
Родитель
ca2d0f98c4
Коммит
3d3e605aff
|
@ -1219,9 +1219,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||
if len(next_sent_beam) == num_beams:
|
||||
break
|
||||
|
||||
# Check if were done so that we can save a pad step if all(done)
|
||||
# Check if we are done so that we can save a pad step if all(done)
|
||||
done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
|
||||
tf.reduce_max(next_scores[batch_idx]).numpy(), cur_len=cur_len
|
||||
tf.reduce_max(next_scores[batch_idx]).numpy(), cur_len
|
||||
)
|
||||
|
||||
# update next beam content
|
||||
|
@ -1509,7 +1509,7 @@ class BeamHypotheses(object):
|
|||
else:
|
||||
self.worst_score = min(score, self.worst_score)
|
||||
|
||||
def is_done(self, best_sum_logprobs, cur_len=None):
|
||||
def is_done(self, best_sum_logprobs, cur_len):
|
||||
"""
|
||||
If there are enough hypotheses and that none of the hypotheses being generated
|
||||
can become better than the worst one in the heap, then we are done with this sentence.
|
||||
|
@ -1520,8 +1520,6 @@ class BeamHypotheses(object):
|
|||
elif self.early_stopping:
|
||||
return True
|
||||
else:
|
||||
if cur_len is None:
|
||||
cur_len = self.max_length
|
||||
cur_score = best_sum_logprobs / cur_len ** self.length_penalty
|
||||
ret = self.worst_score >= cur_score
|
||||
return ret
|
||||
|
|
|
@ -1462,7 +1462,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||
# for each sentence
|
||||
for batch_idx in range(batch_size):
|
||||
|
||||
# if we are done with this sentence
|
||||
# if we are done with this sentence, add a pad token
|
||||
if done[batch_idx]:
|
||||
assert (
|
||||
len(generated_hyps[batch_idx]) >= num_beams
|
||||
|
@ -1473,7 +1473,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||
next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch
|
||||
continue
|
||||
|
||||
# next sentence beam content
|
||||
# next sentence beam content, this will get added to next_batch_beam
|
||||
next_sent_beam = []
|
||||
|
||||
# next tokens for this sentence
|
||||
|
@ -1485,7 +1485,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||
token_id = beam_token_id % vocab_size
|
||||
|
||||
effective_beam_id = batch_idx * num_beams + beam_id
|
||||
# add to generated hypotheses if end of sentence or last iteration
|
||||
# add to generated hypotheses if end of sentence
|
||||
if (eos_token_id is not None) and (token_id.item() == eos_token_id):
|
||||
# if beam_token does not belong to top num_beams tokens, it should not be added
|
||||
is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
|
||||
|
@ -1495,22 +1495,22 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||
input_ids[effective_beam_id].clone(), beam_token_score.item(),
|
||||
)
|
||||
else:
|
||||
# add next predicted token if it is not eos_token
|
||||
# add next predicted token since it is not eos_token
|
||||
next_sent_beam.append((beam_token_score, token_id, effective_beam_id))
|
||||
|
||||
# the beam for next step is full
|
||||
# once the beam for next step is full, don't add more tokens to it.
|
||||
if len(next_sent_beam) == num_beams:
|
||||
break
|
||||
|
||||
# Check if were done so that we can save a pad step if all(done)
|
||||
# Check if we are done so that we can save a pad step if all(done)
|
||||
done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
|
||||
next_scores[batch_idx].max().item(), cur_len=cur_len
|
||||
next_scores[batch_idx].max().item(), cur_len
|
||||
)
|
||||
|
||||
# update next beam content
|
||||
assert len(next_sent_beam) == num_beams, "Beam should always be full"
|
||||
next_batch_beam.extend(next_sent_beam)
|
||||
assert len(next_batch_beam) == num_beams * (batch_idx + 1)
|
||||
assert len(next_batch_beam) == num_beams * (batch_idx + 1), "We should have added num_beams each step"
|
||||
|
||||
# stop when we are done with each sentence
|
||||
if all(done):
|
||||
|
@ -1537,7 +1537,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
|
||||
)
|
||||
|
||||
# finalize all open beam hypotheses and end to generated hypotheses
|
||||
# finalize all open beam hypotheses and add to generated hypotheses
|
||||
for batch_idx in range(batch_size):
|
||||
if done[batch_idx]:
|
||||
continue
|
||||
|
@ -1576,7 +1576,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||
sent_lengths[effective_batch_idx] = len(best_hyp)
|
||||
best.append(best_hyp)
|
||||
|
||||
# shorter batches are filled with pad_token
|
||||
# shorter batches are padded
|
||||
if sent_lengths.min().item() != sent_lengths.max().item():
|
||||
assert pad_token_id is not None, "`Pad_token_id` has to be defined"
|
||||
sent_max_len = min(sent_lengths.max().item() + 1, max_length)
|
||||
|
@ -1731,7 +1731,7 @@ class BeamHypotheses(object):
|
|||
else:
|
||||
self.worst_score = min(score, self.worst_score)
|
||||
|
||||
def is_done(self, best_sum_logprobs, cur_len=None):
|
||||
def is_done(self, best_sum_logprobs, cur_len):
|
||||
"""
|
||||
If there are enough hypotheses and that none of the hypotheses being generated
|
||||
can become better than the worst one in the heap, then we are done with this sentence.
|
||||
|
@ -1742,8 +1742,6 @@ class BeamHypotheses(object):
|
|||
elif self.early_stopping:
|
||||
return True
|
||||
else:
|
||||
if cur_len is None:
|
||||
cur_len = self.max_length
|
||||
cur_score = best_sum_logprobs / cur_len ** self.length_penalty
|
||||
ret = self.worst_score >= cur_score
|
||||
return ret
|
||||
|
|
Загрузка…
Ссылка в новой задаче