[generate] do_sample default back to False (#3298)
* change do_samples back * None better default as boolean * adapt do_sample to True in test example * make style
This commit is contained in:
Родитель
2187c49f5c
Коммит
e8f44af5bf
|
@ -35,7 +35,6 @@ def generate_summaries(lns, out_file, batch_size=8, device=DEFAULT_DEVICE):
|
|||
min_length=min_length + 1, # +1 from original because we start at step=1
|
||||
no_repeat_ngram_size=3,
|
||||
early_stopping=True,
|
||||
do_sample=False,
|
||||
decoder_start_token_id=model.config.eos_token_ids[0],
|
||||
)
|
||||
dec = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries]
|
||||
|
|
|
@ -460,8 +460,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||
input_ids=None,
|
||||
max_length=None,
|
||||
min_length=None,
|
||||
do_sample=True,
|
||||
early_stopping=False,
|
||||
do_sample=None,
|
||||
early_stopping=None,
|
||||
num_beams=None,
|
||||
temperature=None,
|
||||
top_k=None,
|
||||
|
@ -494,7 +494,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||
The max length of the sequence to be generated. Between 1 and infinity. Default to 20.
|
||||
|
||||
do_sample: (`optional`) bool
|
||||
If set to `False` greedy decoding is used. Otherwise sampling is used. Defaults to `True`.
|
||||
If set to `False` greedy decoding is used. Otherwise sampling is used. Defaults to `False`.
|
||||
|
||||
num_beams: (`optional`) int
|
||||
Number of beams for beam search. Must be between 1 and infinity. 1 means no beam search. Default to 1.
|
||||
|
|
|
@ -656,8 +656,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||
input_ids=None,
|
||||
max_length=None,
|
||||
min_length=None,
|
||||
do_sample=True,
|
||||
early_stopping=False,
|
||||
do_sample=None,
|
||||
early_stopping=None,
|
||||
num_beams=None,
|
||||
temperature=None,
|
||||
top_k=None,
|
||||
|
@ -691,7 +691,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||
The max length of the sequence to be generated. Between 1 and infinity. Default to 20.
|
||||
|
||||
do_sample: (`optional`) bool
|
||||
If set to `False` greedy decoding is used. Otherwise sampling is used. Defaults to `True`.
|
||||
If set to `False` greedy decoding is used. Otherwise sampling is used. Defaults to `False`.
|
||||
|
||||
num_beams: (`optional`) int
|
||||
Number of beams for beam search. Must be between 1 and infinity. 1 means no beam search. Default to 1.
|
||||
|
|
|
@ -285,7 +285,12 @@ class BartHeadTests(unittest.TestCase):
|
|||
|
||||
max_length = 5
|
||||
new_input_ids = lm_model.generate(
|
||||
input_ids.clone(), num_return_sequences=1, num_beams=2, no_repeat_ngram_size=3, max_length=max_length
|
||||
input_ids.clone(),
|
||||
do_sample=True,
|
||||
num_return_sequences=1,
|
||||
num_beams=2,
|
||||
no_repeat_ngram_size=3,
|
||||
max_length=max_length,
|
||||
)
|
||||
self.assertEqual(new_input_ids.shape, (input_ids.shape[0], max_length - 1))
|
||||
# TODO(SS): uneven length batches, empty inputs
|
||||
|
|
|
@ -638,16 +638,16 @@ class ModelTesterMixin:
|
|||
|
||||
if config.bos_token_id is None:
|
||||
with self.assertRaises(AssertionError):
|
||||
model.generate(max_length=5)
|
||||
model.generate(do_sample=True, max_length=5)
|
||||
# batch_size = 1
|
||||
self._check_generated_tokens(model.generate(input_ids))
|
||||
self._check_generated_tokens(model.generate(input_ids, do_sample=True))
|
||||
# batch_size = 1, num_beams > 1
|
||||
self._check_generated_tokens(model.generate(input_ids, num_beams=3))
|
||||
self._check_generated_tokens(model.generate(input_ids, do_sample=True, num_beams=3))
|
||||
else:
|
||||
# batch_size = 1
|
||||
self._check_generated_tokens(model.generate(max_length=5))
|
||||
self._check_generated_tokens(model.generate(do_sample=True, max_length=5))
|
||||
# batch_size = 1, num_beams > 1
|
||||
self._check_generated_tokens(model.generate(max_length=5, num_beams=3))
|
||||
self._check_generated_tokens(model.generate(do_sample=True, max_length=5, num_beams=3))
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
# generating multiple sequences when greedy no beam generation
|
||||
|
@ -659,12 +659,14 @@ class ModelTesterMixin:
|
|||
model.generate(input_ids, do_sample=False, num_return_sequences=3, num_beams=2)
|
||||
|
||||
# batch_size > 1, sample
|
||||
self._check_generated_tokens(model.generate(input_ids, num_return_sequences=3))
|
||||
self._check_generated_tokens(model.generate(input_ids, do_sample=True, num_return_sequences=3))
|
||||
# batch_size > 1, greedy
|
||||
self._check_generated_tokens(model.generate(input_ids, do_sample=False))
|
||||
|
||||
# batch_size > 1, num_beams > 1, sample
|
||||
self._check_generated_tokens(model.generate(input_ids, num_beams=3, num_return_sequences=3,))
|
||||
self._check_generated_tokens(
|
||||
model.generate(input_ids, do_sample=True, num_beams=3, num_return_sequences=3,)
|
||||
)
|
||||
# batch_size > 1, num_beams > 1, greedy
|
||||
self._check_generated_tokens(
|
||||
model.generate(input_ids, do_sample=False, num_beams=3, num_return_sequences=3)
|
||||
|
|
|
@ -422,16 +422,16 @@ class TFModelTesterMixin:
|
|||
|
||||
if config.bos_token_id is None:
|
||||
with self.assertRaises(AssertionError):
|
||||
model.generate(max_length=5)
|
||||
model.generate(do_sample=True, max_length=5)
|
||||
# batch_size = 1
|
||||
self._check_generated_tokens(model.generate(input_ids))
|
||||
self._check_generated_tokens(model.generate(input_ids, do_sample=True))
|
||||
# batch_size = 1, num_beams > 1
|
||||
self._check_generated_tokens(model.generate(input_ids, num_beams=3))
|
||||
self._check_generated_tokens(model.generate(input_ids, do_sample=True, num_beams=3))
|
||||
else:
|
||||
# batch_size = 1
|
||||
self._check_generated_tokens(model.generate(max_length=5))
|
||||
self._check_generated_tokens(model.generate(do_sample=True, max_length=5))
|
||||
# batch_size = 1, num_beams > 1
|
||||
self._check_generated_tokens(model.generate(max_length=5, num_beams=3))
|
||||
self._check_generated_tokens(model.generate(do_sample=True, max_length=5, num_beams=3))
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
# generating multiple sequences when greedy no beam generation
|
||||
|
@ -443,12 +443,14 @@ class TFModelTesterMixin:
|
|||
model.generate(input_ids, do_sample=False, num_return_sequences=3, num_beams=2)
|
||||
|
||||
# batch_size > 1, sample
|
||||
self._check_generated_tokens(model.generate(input_ids, num_return_sequences=3))
|
||||
self._check_generated_tokens(model.generate(input_ids, do_sample=True, num_return_sequences=3))
|
||||
# batch_size > 1, greedy
|
||||
self._check_generated_tokens(model.generate(input_ids, do_sample=False))
|
||||
|
||||
# batch_size > 1, num_beams > 1, sample
|
||||
self._check_generated_tokens(model.generate(input_ids, num_beams=3, num_return_sequences=3,))
|
||||
self._check_generated_tokens(
|
||||
model.generate(input_ids, do_sample=True, num_beams=3, num_return_sequences=3,)
|
||||
)
|
||||
# batch_size > 1, num_beams > 1, greedy
|
||||
self._check_generated_tokens(
|
||||
model.generate(input_ids, do_sample=False, num_beams=3, num_return_sequences=3)
|
||||
|
|
Загрузка…
Ссылка в новой задаче