make if statements cleaner for prepare_inputs_for_generation
This commit is contained in:
Родитель
d039c679d2
Коммит
365ccd0af2
|
@ -491,8 +491,9 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
|
|||
return self.lm_head
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, **kwargs):
|
||||
# inputs_ids should only be composed of last token if past is in kwargs and defined
|
||||
input_ids = input_ids[:, -1].unsqueeze(-1) if 'past' in kwargs and kwargs['past'] else input_ids
|
||||
# only last token for inputs_ids if past is defined in kwargs
|
||||
if 'past' in kwargs and kwargs['past']:
|
||||
input_ids = input_ids[:, -1].unsqueeze(-1)
|
||||
|
||||
inputs = {"input_ids": input_ids}
|
||||
inputs.update(kwargs)
|
||||
|
|
|
@ -560,8 +560,9 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
|||
return self.lm_head
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, **kwargs):
|
||||
# inputs_ids should only be composed of last token if past is in kwargs and defined
|
||||
input_ids = input_ids[:, -1].unsqueeze(-1) if 'past' in kwargs and kwargs['past'] else input_ids
|
||||
# only last token for inputs_ids if past is defined in kwargs
|
||||
if 'past' in kwargs and kwargs['past']:
|
||||
input_ids = input_ids[:, -1].unsqueeze(-1)
|
||||
|
||||
inputs = {"input_ids": input_ids}
|
||||
inputs.update(kwargs)
|
||||
|
|
|
@ -540,8 +540,8 @@ class PreTrainedModel(nn.Module):
|
|||
return {"input_ids": input_ids}
|
||||
|
||||
def _do_output_past(self, outputs):
|
||||
# TODO: might be better to write a self.do_output_past method for each individual class as is done for
|
||||
# prepare_inputs_for_generation
|
||||
# TODO: might be better to write a self.do_output_past method for each
|
||||
# individual class as is done for prepare_inputs_for_generation
|
||||
has_output_past = hasattr(self.config, 'output_past') and self.config.output_past
|
||||
has_multiple_outputs = len(outputs) > 1
|
||||
has_mem_len = hasattr(self, 'mem_len')
|
||||
|
|
Загрузка…
Ссылка в новой задаче