Add t5 to pipeline(task='summarization') (#3413)

* solve conflicts

* move warnings below

* incorporate changes

* add pad_to_max_length to pipelines

* add bug fix for T5 beam search

* add prefix patterns

* make style

* fix conflicts

* adapt pipelines for task specific parameters

* improve docstring

* remove unused patterns
Committed by: Patrick von Platen, 2020-03-26 11:03:13 +01:00 (committed via GitHub)
Parent: ffcffebe85
Commit: 9c683ef01e
4 changed files with 120 additions and 42 deletions

View file

@@ -380,3 +380,14 @@ class PretrainedConfig(object):
"""
with open(json_file_path, "w", encoding="utf-8") as writer:
writer.write(self.to_json_string())
def update(self, config_dict: Dict):
"""
Updates attributes of this class
with attributes from `config_dict`.
Args:
config_dict (:obj:`Dict[str, Any]`): Dictionary of attributes that shall be updated for this class.
"""
for key, value in config_dict.items():
setattr(self, key, value)
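For context, a minimal sketch of how the new `update` method can be used; the model name and attribute values here are illustrative:

from transformers import AutoConfig

# Override several generation attributes at once instead of
# calling setattr for each one individually.
config = AutoConfig.from_pretrained("t5-small")
config.update({"max_length": 200, "min_length": 30, "early_stopping": True})
assert config.max_length == 200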

View file

@@ -999,10 +999,12 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
# set eos token prob to zero if min_length is not reached
if eos_token_id is not None and cur_len < min_length:
# create eos_token_id boolean mask
num_batch_hypotheses = batch_size * num_beams
is_token_logit_eos_token = tf.convert_to_tensor(
[token == eos_token_id for token in range(vocab_size)], dtype=tf.bool
)
eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size])
eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [num_batch_hypotheses, vocab_size])
scores = set_tensor_by_indices_to_value(scores, eos_token_indices_mask, -float("inf"))
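To make the shape fix concrete, here is a standalone sketch of the masking step, with `tf.where` standing in for the repo's `set_tensor_by_indices_to_value` helper and illustrative dimensions. During beam search `scores` holds one row per hypothesis, i.e. `batch_size * num_beams` rows, so the mask must be broadcast to that shape rather than to `batch_size`:

import tensorflow as tf

batch_size, num_beams, vocab_size, eos_token_id = 2, 3, 10, 1
num_batch_hypotheses = batch_size * num_beams

# One row of next-token scores per beam hypothesis.
scores = tf.zeros([num_batch_hypotheses, vocab_size])
is_token_logit_eos_token = tf.convert_to_tensor(
    [token == eos_token_id for token in range(vocab_size)], dtype=tf.bool
)
# Broadcasting to [batch_size, vocab_size] here was the bug being fixed.
eos_token_indices_mask = tf.broadcast_to(
    is_token_logit_eos_token, [num_batch_hypotheses, vocab_size]
)
scores = tf.where(eos_token_indices_mask, -float("inf") * tf.ones_like(scores), scores)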

View file

@@ -31,6 +31,7 @@ from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, AutoConfig
from .configuration_bart import BartConfig
from .configuration_distilbert import DistilBertConfig
from .configuration_roberta import RobertaConfig
from .configuration_t5 import T5Config
from .configuration_utils import PretrainedConfig
from .configuration_xlm import XLMConfig
from .data import SquadExample, squad_convert_examples_to_features
@@ -60,7 +61,6 @@ if is_torch_available():
AutoModelForTokenClassification,
AutoModelWithLMHead,
)
from .modeling_bart import BartForConditionalGeneration
logger = logging.getLogger(__name__)
@@ -336,6 +336,7 @@ class Pipeline(_ScikitCompat):
tokenizer: PreTrainedTokenizer,
modelcard: Optional[ModelCard] = None,
framework: Optional[str] = None,
task: str = "",
args_parser: ArgumentHandler = None,
device: int = -1,
binary_output: bool = False,
@@ -356,6 +357,11 @@ class Pipeline(_ScikitCompat):
if self.framework == "pt" and self.device.type == "cuda":
self.model = self.model.to(self.device)
# Update config with task specific parameters
task_specific_params = self.model.config.task_specific_params
if task_specific_params is not None and task in task_specific_params:
self.model.config.update(task_specific_params.get(task))
def save_pretrained(self, save_directory):
"""
Save the pipeline's model and tokenizer to the specified save_directory
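For reference, a sketch of what the config merge above amounts to for a T5 checkpoint, whose config ships a `task_specific_params` block; the exact values are illustrative and may differ from the released config:

from transformers import AutoConfig

config = AutoConfig.from_pretrained("t5-small")
# T5 configs carry per-task generation settings, roughly of the form:
# {"summarization": {"prefix": "summarize: ", "max_length": 200, ...},
#  "translation_en_to_de": {"prefix": "translate English to German: ", ...}}
task = "summarization"
task_specific_params = config.task_specific_params
if task_specific_params is not None and task in task_specific_params:
    # The same merge the Pipeline constructor now performs.
    config.update(task_specific_params.get(task))
print(config.prefix)  # e.g. "summarize: "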
@@ -420,7 +426,7 @@ class Pipeline(_ScikitCompat):
"""
args = ["input_ids", "attention_mask"]
if not isinstance(self.model.config, (DistilBertConfig, XLMConfig, RobertaConfig, BartConfig)):
if not isinstance(self.model.config, (DistilBertConfig, XLMConfig, RobertaConfig, BartConfig, T5Config)):
args += ["token_type_ids"]
# PR #1548 (CLI) There is an issue with attention_mask
@@ -432,14 +438,18 @@ class Pipeline(_ScikitCompat):
else:
return {k: [feature[k] for feature in features] for k in args}
def _parse_and_tokenize(self, *texts, **kwargs):
def _parse_and_tokenize(self, *texts, pad_to_max_length=False, **kwargs):
"""
Parse arguments and tokenize
"""
# Parse arguments
inputs = self._args_parser(*texts, **kwargs)
inputs = self.tokenizer.batch_encode_plus(
inputs, add_special_tokens=True, return_tensors=self.framework, max_length=self.tokenizer.max_len
inputs,
add_special_tokens=True,
return_tensors=self.framework,
max_length=self.tokenizer.max_len,
pad_to_max_length=pad_to_max_length,
)
# Filter out features not available on specific models
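A sketch of what the new `pad_to_max_length` flag changes at the tokenizer level, using the `batch_encode_plus` API of this era; the model name and texts are illustrative. Without padding, a batch of unequal-length texts cannot be stacked into a single rectangular tensor:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
batch = tokenizer.batch_encode_plus(
    ["a short document", "a somewhat longer document to summarize today"],
    add_special_tokens=True,
    return_tensors="pt",
    max_length=tokenizer.max_len,
    pad_to_max_length=True,
)
# Both sequences now share one padded length, so they stack into one tensor;
# zeros in the attention mask mark the padded positions.
print(batch["input_ids"].shape)
print(batch["attention_mask"][0])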
@@ -520,6 +530,7 @@ class FeatureExtractionPipeline(Pipeline):
framework: Optional[str] = None,
args_parser: ArgumentHandler = None,
device: int = -1,
task: str = "",
):
super().__init__(
model=model,
@@ -529,6 +540,7 @@ class FeatureExtractionPipeline(Pipeline):
args_parser=args_parser,
device=device,
binary_output=True,
task=task,
)
def __call__(self, *args, **kwargs):
@@ -625,6 +637,7 @@ class FillMaskPipeline(Pipeline):
args_parser: ArgumentHandler = None,
device: int = -1,
topk=5,
task: str = "",
):
super().__init__(
model=model,
@@ -634,6 +647,7 @@ class FillMaskPipeline(Pipeline):
args_parser=args_parser,
device=device,
binary_output=True,
task=task,
)
self.topk = topk
@@ -725,6 +739,7 @@ class NerPipeline(Pipeline):
device: int = -1,
binary_output: bool = False,
ignore_labels=["O"],
task: str = "",
):
super().__init__(
model=model,
@@ -734,6 +749,7 @@ class NerPipeline(Pipeline):
args_parser=args_parser,
device=device,
binary_output=binary_output,
task=task,
)
self._basic_tokenizer = BasicTokenizer(do_lower_case=False)
@@ -896,6 +912,7 @@ class QuestionAnsweringPipeline(Pipeline):
modelcard: Optional[ModelCard] = None,
framework: Optional[str] = None,
device: int = -1,
task: str = "",
**kwargs
):
super().__init__(
@@ -905,6 +922,7 @@ class QuestionAnsweringPipeline(Pipeline):
framework=framework,
args_parser=QuestionAnsweringArgumentHandler(),
device=device,
task=task,
**kwargs,
)
@@ -1111,12 +1129,16 @@ class SummarizationPipeline(Pipeline):
Usage::
# use bart in pytorch
summarizer = pipeline("summarization")
summarizer("Sam Shleifer writes the best docstring examples in the whole world.")
summarizer("Sam Shleifer writes the best docstring examples in the whole world.", min_length=5, max_length=20)
# use t5 in tf
summarizer = pipeline("summarization", model="t5-base", tokenizer="t5-base", framework="tf")
summarizer("Sam Shleifer writes the best docstring examples in the whole world.", min_length=5, max_length=20)
Supported Models:
The models that this pipeline can use are models that have been fine-tuned on a summarization task, which is
currently only ``BartForConditionalGeneration.from_pretrained('bart-large-cnn')``
The models that this pipeline can use are models that have been fine-tuned on a summarization task; currently: '`bart-large-cnn`', '`t5-small`', '`t5-base`', '`t5-large`', '`t5-3b`', '`t5-11b`'.
Arguments:
model (:obj:`str` or :obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`, `optional`, defaults to :obj:`None`):
@@ -1147,17 +1169,8 @@ class SummarizationPipeline(Pipeline):
on the associated CUDA device id.
"""
task = "summarization"
def __call__(
self,
*documents,
return_tensors=False,
return_text=True,
max_length=142,
min_length=21,
clean_up_tokenization_spaces=False,
**generate_kwargs
self, *documents, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
):
r"""
Args:
@@ -1165,10 +1178,6 @@ class SummarizationPipeline(Pipeline):
return_text: (bool, default=True) whether to add a decoded "summary_text" to each result
return_tensors: (bool, default=False) whether to return the raw "summary_token_ids" to each result
max_length: (`optional`) int
The max length of the sequence to be generated. Does not include tokens in input_ids.
min_len: (`optional`) int
no_repeat_ngram_size: (`optional`) int. ban ngrams of this length from being repeated in the generated text
clean_up_tokenization_spaces: (`optional`) bool whether to clean up the potential extra spaces in the output text
**generate_kwargs: extra kwargs passed to `self.model.generate`_
@@ -1180,19 +1189,60 @@ class SummarizationPipeline(Pipeline):
"""
assert return_tensors or return_text, "You must specify return_tensors=True or return_text=True"
if self.framework == "tf":
raise NotImplementedError("Tensorflow not supported")
with self.device_placement():
inputs = self._parse_and_tokenize(*documents)
inputs = self.ensure_tensor_on_device(**inputs)
summaries = self.model.generate(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
max_length=max_length,
min_length=min_length,
do_sample=False,
**generate_kwargs,
assert len(documents) > 0, "Please provide a document to summarize"
if self.framework == "tf" and "BartForConditionalGeneration" in self.model.__class__.__name__:
raise NotImplementedError(
"Tensorflow is not yet supported for Bart. Please consider using T5, e.g. `t5-base`"
)
prefix = self.model.config.prefix if self.model.config.prefix is not None else ""
if isinstance(documents[0], list):
assert (
self.tokenizer.pad_token_id is not None
), "Please make sure that the tokenizer has a pad_token_id when using a batch input"
documents = ([prefix + document for document in documents[0]],)
pad_to_max_length = True
elif isinstance(documents[0], str):
documents = (prefix + documents[0],)
pad_to_max_length = False
else:
raise ValueError(
" `documents[0]`: {} has the wrong format. It should be either of type `str` or of type `list`".format(
documents[0]
)
)
with self.device_placement():
inputs = self._parse_and_tokenize(*documents, pad_to_max_length=pad_to_max_length)
if self.framework == "pt":
inputs = self.ensure_tensor_on_device(**inputs)
input_length = inputs["input_ids"].shape[-1]
elif self.framework == "tf":
input_length = tf.shape(inputs["input_ids"])[-1]
if input_length < self.model.config.min_length // 2:
logger.warning(
"Your min_length is set to {}, but your input_length is only {}. You might consider decreasing min_length manually, e.g. by passing min_length to the pipeline call.".format(
self.model.config.min_length, input_length
)
)
if input_length < self.model.config.max_length:
logger.warning(
"Your max_length is set to {}, but your input_length is only {}. You might consider decreasing max_length manually, e.g. by passing max_length to the pipeline call.".format(
self.model.config.max_length, input_length
)
)
summaries = self.model.generate(
inputs["input_ids"], attention_mask=inputs["attention_mask"], **generate_kwargs,
)
results = []
for summary in summaries:
record = {}
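Putting the pieces together, an end-to-end sketch of the new behavior; model choice and input texts are illustrative. For T5 the task prefix is prepended automatically from `config.prefix`, and a list input takes the padded batch path guarded by the `pad_token_id` assert above:

from transformers import pipeline

summarizer = pipeline("summarization", model="t5-small", tokenizer="t5-small")

# Single string: tokenized without padding.
print(summarizer("A long article ...", min_length=5, max_length=20))

# Batch of strings: tokenized with pad_to_max_length=True.
results = summarizer(
    ["First document ...", "A second, somewhat longer document ..."],
    min_length=5,
    max_length=20,
)
print(results[0]["summary_text"])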
@@ -1266,8 +1316,8 @@ SUPPORTED_TASKS = {
},
"summarization": {
"impl": SummarizationPipeline,
"pt": BartForConditionalGeneration if is_torch_available() else None,
"tf": None,
"tf": TFAutoModelWithLMHead if is_tf_available() else None,
"pt": AutoModelWithLMHead if is_torch_available() else None,
"default": {
"model": {"pt": "bart-large-cnn", "tf": None},
"config": None,
@@ -1361,7 +1411,7 @@ def pipeline(
framework = framework or get_framework(model)
targeted_task = SUPPORTED_TASKS[task]
task, model_class = targeted_task["impl"], targeted_task[framework]
task_class, model_class = targeted_task["impl"], targeted_task[framework]
# Use default model/config/tokenizer for the task if no model is provided
if model is None:
@@ -1422,4 +1472,4 @@ def pipeline(
)
model = model_class.from_pretrained(model, config=config, **model_kwargs)
return task(model=model, tokenizer=tokenizer, modelcard=modelcard, framework=framework, **kwargs)
return task_class(model=model, tokenizer=tokenizer, modelcard=modelcard, framework=framework, task=task, **kwargs)
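The rename to `task_class` matters because `task` is now the plain string forwarded to the pipeline constructor; a short sketch of the observable effect, with attribute values that depend on the released config:

from transformers import pipeline

nlp = pipeline("summarization", model="t5-small", tokenizer="t5-small")
# The constructor merged task_specific_params["summarization"] into the
# model config, so task-specific generation defaults are now visible on it.
print(nlp.model.config.prefix)     # e.g. "summarize: "
print(nlp.model.config.num_beams)  # e.g. 4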

View file

@@ -78,6 +78,9 @@ TF_FILL_MASK_FINETUNED_MODELS = [
(("distilroberta-base", {"use_fast": False}), "distilroberta-base", None),
]
SUMMARIZATION_FINETUNED_MODELS = {("bart-large-cnn", "bart-large-cnn"), ("t5-small", "t5-small")}
TF_SUMMARIZATION_FINETUNED_MODELS = {("t5-small", "t5-small")}
class MonoColumnInputTestCase(unittest.TestCase):
def _test_mono_column_pipeline(
@@ -252,10 +255,22 @@ class MonoColumnInputTestCase(unittest.TestCase):
valid_inputs = ["A string like this", ["list of strings entry 1", "list of strings v2"]]
invalid_inputs = [4, "<mask>"]
mandatory_keys = ["summary_text"]
nlp = pipeline(task="summarization")
self._test_mono_column_pipeline(
nlp, valid_inputs, invalid_inputs, mandatory_keys,
)
for model, tokenizer in SUMMARIZATION_FINETUNED_MODELS:
nlp = pipeline(task="summarization", model=model, tokenizer=tokenizer)
self._test_mono_column_pipeline(
nlp, valid_inputs, invalid_inputs, mandatory_keys,
)
@require_tf
def test_tf_summarization(self):
valid_inputs = ["A string like this", ["list of strings entry 1", "list of strings v2"]]
invalid_inputs = [4, "<mask>"]
mandatory_keys = ["summary_text"]
for model, tokenizer in TF_SUMMARIZATION_FINETUNED_MODELS:
nlp = pipeline(task="summarization", model=model, tokenizer=tokenizer, framework="tf")
self._test_mono_column_pipeline(
nlp, valid_inputs, invalid_inputs, mandatory_keys,
)
class MultiColumnInputTestCase(unittest.TestCase):