This commit is contained in:
patrickvonplaten 2019-12-23 22:15:06 +01:00
Родитель 7bb4271291
Коммит eeaa402cd4
1 изменённых файлов: 2 добавлений и 2 удалений

Просмотреть файл

@ -731,7 +731,7 @@ class PreTrainedModel(nn.Module):
outputs = self(**model_inputs)
next_token_logits = outputs[0][:, -1, :]
# if model has past, then set the past parameter to speed up decoding
# if model has past, then set the past variable to speed up decoding
if self._has_past(outputs):
past = outputs[1]
@ -818,7 +818,7 @@ class PreTrainedModel(nn.Module):
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
scores = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
# if model has past, then set the past parameter to speed up decoding
# if model has past, then set the past variable to speed up decoding
if self._has_past(outputs):
past = outputs[1]