From e8f44af5bf44a79f102678f5d7bb737cd6da3b52 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Tue, 17 Mar 2020 15:52:37 +0100
Subject: [PATCH] [generate] do_sample default back to False (#3298)

* change do_sample back

* None is a better default for a boolean argument

* adapt do_sample to True in test example

* make style
---
 examples/summarization/bart/evaluate_cnn.py |  1 -
 src/transformers/modeling_tf_utils.py       |  6 +++---
 src/transformers/modeling_utils.py          |  6 +++---
 tests/test_modeling_bart.py                 |  7 ++++++-
 tests/test_modeling_common.py               | 16 +++++++++-------
 tests/test_modeling_tf_common.py            | 16 +++++++++-------
 6 files changed, 30 insertions(+), 22 deletions(-)

diff --git a/examples/summarization/bart/evaluate_cnn.py b/examples/summarization/bart/evaluate_cnn.py
index fded7e51f..b6a2eb7bd 100644
--- a/examples/summarization/bart/evaluate_cnn.py
+++ b/examples/summarization/bart/evaluate_cnn.py
@@ -35,7 +35,6 @@ def generate_summaries(lns, out_file, batch_size=8, device=DEFAULT_DEVICE):
             min_length=min_length + 1,  # +1 from original because we start at step=1
             no_repeat_ngram_size=3,
             early_stopping=True,
-            do_sample=False,
             decoder_start_token_id=model.config.eos_token_ids[0],
         )
         dec = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries]
diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py
index a0247015b..088119ee3 100644
--- a/src/transformers/modeling_tf_utils.py
+++ b/src/transformers/modeling_tf_utils.py
@@ -460,8 +460,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
         input_ids=None,
         max_length=None,
         min_length=None,
-        do_sample=True,
-        early_stopping=False,
+        do_sample=None,
+        early_stopping=None,
         num_beams=None,
         temperature=None,
         top_k=None,
@@ -494,7 +494,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                 The max length of the sequence to be generated. Between 1 and infinity. Default to 20.

             do_sample: (`optional`) bool
-                If set to `False` greedy decoding is used. Otherwise sampling is used. Defaults to `True`.
+                If set to `False` greedy decoding is used. Otherwise sampling is used. Defaults to `False`.

             num_beams: (`optional`) int
                 Number of beams for beam search. Must be between 1 and infinity. 1 means no beam search. Default to 1.
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index e2c2ef1bf..8e2b9c499 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -656,8 +656,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         input_ids=None,
         max_length=None,
         min_length=None,
-        do_sample=True,
-        early_stopping=False,
+        do_sample=None,
+        early_stopping=None,
         num_beams=None,
         temperature=None,
         top_k=None,
@@ -691,7 +691,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                 The max length of the sequence to be generated. Between 1 and infinity. Default to 20.

             do_sample: (`optional`) bool
-                If set to `False` greedy decoding is used. Otherwise sampling is used. Defaults to `True`.
+                If set to `False` greedy decoding is used. Otherwise sampling is used. Defaults to `False`.

             num_beams: (`optional`) int
                 Number of beams for beam search. Must be between 1 and infinity. 1 means no beam search. Default to 1.
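Why the signature default becomes `None` rather than `False`: `generate` resolves a `None` argument against the model config, so a model-specific config can still opt into sampling while the library-wide config default (`do_sample=False`) yields greedy decoding. A minimal sketch of that fallback pattern, assuming a config object with `do_sample`/`early_stopping` attributes (the `SimpleNamespace` stand-in and `resolve_generation_flags` helper below are illustrative, not the library's code):

```python
from types import SimpleNamespace

# Illustrative stand-in for a model config; in transformers the config
# defaults do_sample and early_stopping to False.
config = SimpleNamespace(do_sample=False, early_stopping=False)


def resolve_generation_flags(do_sample=None, early_stopping=None, config=config):
    # None means "not set by the caller", so the config value wins;
    # an explicit True/False from the caller always takes precedence.
    do_sample = do_sample if do_sample is not None else config.do_sample
    early_stopping = early_stopping if early_stopping is not None else config.early_stopping
    return do_sample, early_stopping


assert resolve_generation_flags() == (False, False)  # config defaults: greedy decoding
assert resolve_generation_flags(do_sample=True) == (True, False)  # caller opts into sampling
```

The tests below are updated accordingly: every call that exercises sampling now passes `do_sample=True` explicitly.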
diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py
index 8f06152c6..77aed9eb6 100644
--- a/tests/test_modeling_bart.py
+++ b/tests/test_modeling_bart.py
@@ -285,7 +285,12 @@ class BartHeadTests(unittest.TestCase):
         max_length = 5
         new_input_ids = lm_model.generate(
-            input_ids.clone(), num_return_sequences=1, num_beams=2, no_repeat_ngram_size=3, max_length=max_length
+            input_ids.clone(),
+            do_sample=True,
+            num_return_sequences=1,
+            num_beams=2,
+            no_repeat_ngram_size=3,
+            max_length=max_length,
         )
         self.assertEqual(new_input_ids.shape, (input_ids.shape[0], max_length - 1))
         # TODO(SS): uneven length batches, empty inputs
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index bc7bc967e..23dee7947 100644
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -638,16 +638,16 @@ class ModelTesterMixin:

         if config.bos_token_id is None:
             with self.assertRaises(AssertionError):
-                model.generate(max_length=5)
+                model.generate(do_sample=True, max_length=5)
             # batch_size = 1
-            self._check_generated_tokens(model.generate(input_ids))
+            self._check_generated_tokens(model.generate(input_ids, do_sample=True))
             # batch_size = 1, num_beams > 1
-            self._check_generated_tokens(model.generate(input_ids, num_beams=3))
+            self._check_generated_tokens(model.generate(input_ids, do_sample=True, num_beams=3))
         else:
             # batch_size = 1
-            self._check_generated_tokens(model.generate(max_length=5))
+            self._check_generated_tokens(model.generate(do_sample=True, max_length=5))
             # batch_size = 1, num_beams > 1
-            self._check_generated_tokens(model.generate(max_length=5, num_beams=3))
+            self._check_generated_tokens(model.generate(do_sample=True, max_length=5, num_beams=3))

         with self.assertRaises(AssertionError):
             # generating multiple sequences when greedy no beam generation
@@ -659,12 +659,14 @@ class ModelTesterMixin:
                 model.generate(input_ids, do_sample=False, num_return_sequences=3, num_beams=2)

         # batch_size > 1, sample
-        self._check_generated_tokens(model.generate(input_ids, num_return_sequences=3))
+        self._check_generated_tokens(model.generate(input_ids, do_sample=True, num_return_sequences=3))
         # batch_size > 1, greedy
         self._check_generated_tokens(model.generate(input_ids, do_sample=False))
         # batch_size > 1, num_beams > 1, sample
-        self._check_generated_tokens(model.generate(input_ids, num_beams=3, num_return_sequences=3,))
+        self._check_generated_tokens(
+            model.generate(input_ids, do_sample=True, num_beams=3, num_return_sequences=3,)
+        )
         # batch_size > 1, num_beams > 1, greedy
         self._check_generated_tokens(
             model.generate(input_ids, do_sample=False, num_beams=3, num_return_sequences=3)
diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py
index 6887388d8..e2b7a0fa9 100644
--- a/tests/test_modeling_tf_common.py
+++ b/tests/test_modeling_tf_common.py
@@ -422,16 +422,16 @@ class TFModelTesterMixin:

         if config.bos_token_id is None:
             with self.assertRaises(AssertionError):
-                model.generate(max_length=5)
+                model.generate(do_sample=True, max_length=5)
             # batch_size = 1
-            self._check_generated_tokens(model.generate(input_ids))
+            self._check_generated_tokens(model.generate(input_ids, do_sample=True))
             # batch_size = 1, num_beams > 1
-            self._check_generated_tokens(model.generate(input_ids, num_beams=3))
+            self._check_generated_tokens(model.generate(input_ids, do_sample=True, num_beams=3))
         else:
             # batch_size = 1
-            self._check_generated_tokens(model.generate(max_length=5))
+            self._check_generated_tokens(model.generate(do_sample=True, max_length=5))
             # batch_size = 1, num_beams > 1
-            self._check_generated_tokens(model.generate(max_length=5, num_beams=3))
+            self._check_generated_tokens(model.generate(do_sample=True, max_length=5, num_beams=3))

         with self.assertRaises(AssertionError):
             # generating multiple sequences when greedy no beam generation
@@ -443,12 +443,14 @@
                 model.generate(input_ids, do_sample=False, num_return_sequences=3, num_beams=2)

         # batch_size > 1, sample
-        self._check_generated_tokens(model.generate(input_ids, num_return_sequences=3))
+        self._check_generated_tokens(model.generate(input_ids, do_sample=True, num_return_sequences=3))
         # batch_size > 1, greedy
         self._check_generated_tokens(model.generate(input_ids, do_sample=False))
         # batch_size > 1, num_beams > 1, sample
-        self._check_generated_tokens(model.generate(input_ids, num_beams=3, num_return_sequences=3,))
+        self._check_generated_tokens(
+            model.generate(input_ids, do_sample=True, num_beams=3, num_return_sequences=3,)
+        )
         # batch_size > 1, num_beams > 1, greedy
         self._check_generated_tokens(
             model.generate(input_ids, do_sample=False, num_beams=3, num_return_sequences=3)
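The caller-facing effect of the patch, as a hedged sketch (the GPT-2 checkpoint and prompt below are illustrative, not part of this PR): code that previously relied on the implicit sampling default must now pass `do_sample=True`, exactly as the updated tests do.

```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
input_ids = tokenizer.encode("The default decoding strategy", return_tensors="pt")

# After this change, omitting do_sample gives deterministic greedy decoding.
greedy = model.generate(input_ids, max_length=20)

# Sampling must now be requested explicitly, mirroring the updated tests.
torch.manual_seed(0)
sampled = model.generate(input_ids, do_sample=True, max_length=20, top_k=50)

print(tokenizer.decode(greedy[0], skip_special_tokens=True))
print(tokenizer.decode(sampled[0], skip_special_tokens=True))
```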