Add t5 to pipeline(task='summarization') (#3413)
* solve conflicts * move warnings below * incorporate changes * add pad_to_max_length to pipelines * add bug fix for T5 beam search * add prefix patterns * make style * fix conflicts * adapt pipelines for task specific parameters * improve docstring * remove unused patterns
This commit is contained in:
Родитель
ffcffebe85
Коммит
9c683ef01e
|
@ -380,3 +380,14 @@ class PretrainedConfig(object):
|
|||
"""
|
||||
with open(json_file_path, "w", encoding="utf-8") as writer:
|
||||
writer.write(self.to_json_string())
|
||||
|
||||
def update(self, config_dict: Dict):
|
||||
"""
|
||||
Updates attributes of this class
|
||||
with attributes from `config_dict`.
|
||||
|
||||
Args:
|
||||
:obj:`Dict[str, any]`: Dictionary of attributes that shall be updated for this class.
|
||||
"""
|
||||
for key, value in config_dict.items():
|
||||
setattr(self, key, value)
|
||||
|
|
|
@ -999,10 +999,12 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||
# set eos token prob to zero if min_length is not reached
|
||||
if eos_token_id is not None and cur_len < min_length:
|
||||
# create eos_token_id boolean mask
|
||||
num_batch_hypotheses = batch_size * num_beams
|
||||
|
||||
is_token_logit_eos_token = tf.convert_to_tensor(
|
||||
[True if token is eos_token_id else False for token in range(vocab_size)], dtype=tf.bool
|
||||
)
|
||||
eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size])
|
||||
eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [num_batch_hypotheses, vocab_size])
|
||||
|
||||
scores = set_tensor_by_indices_to_value(scores, eos_token_indices_mask, -float("inf"))
|
||||
|
||||
|
|
|
@ -31,6 +31,7 @@ from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, AutoConfig
|
|||
from .configuration_bart import BartConfig
|
||||
from .configuration_distilbert import DistilBertConfig
|
||||
from .configuration_roberta import RobertaConfig
|
||||
from .configuration_t5 import T5Config
|
||||
from .configuration_utils import PretrainedConfig
|
||||
from .configuration_xlm import XLMConfig
|
||||
from .data import SquadExample, squad_convert_examples_to_features
|
||||
|
@ -60,7 +61,6 @@ if is_torch_available():
|
|||
AutoModelForTokenClassification,
|
||||
AutoModelWithLMHead,
|
||||
)
|
||||
from .modeling_bart import BartForConditionalGeneration
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -336,6 +336,7 @@ class Pipeline(_ScikitCompat):
|
|||
tokenizer: PreTrainedTokenizer,
|
||||
modelcard: Optional[ModelCard] = None,
|
||||
framework: Optional[str] = None,
|
||||
task: str = "",
|
||||
args_parser: ArgumentHandler = None,
|
||||
device: int = -1,
|
||||
binary_output: bool = False,
|
||||
|
@ -356,6 +357,11 @@ class Pipeline(_ScikitCompat):
|
|||
if self.framework == "pt" and self.device.type == "cuda":
|
||||
self.model = self.model.to(self.device)
|
||||
|
||||
# Update config with task specific parameters
|
||||
task_specific_params = self.model.config.task_specific_params
|
||||
if task_specific_params is not None and task in task_specific_params:
|
||||
self.model.config.update(task_specific_params.get(task))
|
||||
|
||||
def save_pretrained(self, save_directory):
|
||||
"""
|
||||
Save the pipeline's model and tokenizer to the specified save_directory
|
||||
|
@ -420,7 +426,7 @@ class Pipeline(_ScikitCompat):
|
|||
"""
|
||||
args = ["input_ids", "attention_mask"]
|
||||
|
||||
if not isinstance(self.model.config, (DistilBertConfig, XLMConfig, RobertaConfig, BartConfig)):
|
||||
if not isinstance(self.model.config, (DistilBertConfig, XLMConfig, RobertaConfig, BartConfig, T5Config)):
|
||||
args += ["token_type_ids"]
|
||||
|
||||
# PR #1548 (CLI) There is an issue with attention_mask
|
||||
|
@ -432,14 +438,18 @@ class Pipeline(_ScikitCompat):
|
|||
else:
|
||||
return {k: [feature[k] for feature in features] for k in args}
|
||||
|
||||
def _parse_and_tokenize(self, *texts, **kwargs):
|
||||
def _parse_and_tokenize(self, *texts, pad_to_max_length=False, **kwargs):
|
||||
"""
|
||||
Parse arguments and tokenize
|
||||
"""
|
||||
# Parse arguments
|
||||
inputs = self._args_parser(*texts, **kwargs)
|
||||
inputs = self.tokenizer.batch_encode_plus(
|
||||
inputs, add_special_tokens=True, return_tensors=self.framework, max_length=self.tokenizer.max_len
|
||||
inputs,
|
||||
add_special_tokens=True,
|
||||
return_tensors=self.framework,
|
||||
max_length=self.tokenizer.max_len,
|
||||
pad_to_max_length=pad_to_max_length,
|
||||
)
|
||||
|
||||
# Filter out features not available on specific models
|
||||
|
@ -520,6 +530,7 @@ class FeatureExtractionPipeline(Pipeline):
|
|||
framework: Optional[str] = None,
|
||||
args_parser: ArgumentHandler = None,
|
||||
device: int = -1,
|
||||
task: str = "",
|
||||
):
|
||||
super().__init__(
|
||||
model=model,
|
||||
|
@ -529,6 +540,7 @@ class FeatureExtractionPipeline(Pipeline):
|
|||
args_parser=args_parser,
|
||||
device=device,
|
||||
binary_output=True,
|
||||
task=task,
|
||||
)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
|
@ -625,6 +637,7 @@ class FillMaskPipeline(Pipeline):
|
|||
args_parser: ArgumentHandler = None,
|
||||
device: int = -1,
|
||||
topk=5,
|
||||
task: str = "",
|
||||
):
|
||||
super().__init__(
|
||||
model=model,
|
||||
|
@ -634,6 +647,7 @@ class FillMaskPipeline(Pipeline):
|
|||
args_parser=args_parser,
|
||||
device=device,
|
||||
binary_output=True,
|
||||
task=task,
|
||||
)
|
||||
|
||||
self.topk = topk
|
||||
|
@ -725,6 +739,7 @@ class NerPipeline(Pipeline):
|
|||
device: int = -1,
|
||||
binary_output: bool = False,
|
||||
ignore_labels=["O"],
|
||||
task: str = "",
|
||||
):
|
||||
super().__init__(
|
||||
model=model,
|
||||
|
@ -734,6 +749,7 @@ class NerPipeline(Pipeline):
|
|||
args_parser=args_parser,
|
||||
device=device,
|
||||
binary_output=binary_output,
|
||||
task=task,
|
||||
)
|
||||
|
||||
self._basic_tokenizer = BasicTokenizer(do_lower_case=False)
|
||||
|
@ -896,6 +912,7 @@ class QuestionAnsweringPipeline(Pipeline):
|
|||
modelcard: Optional[ModelCard] = None,
|
||||
framework: Optional[str] = None,
|
||||
device: int = -1,
|
||||
task: str = "",
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
|
@ -905,6 +922,7 @@ class QuestionAnsweringPipeline(Pipeline):
|
|||
framework=framework,
|
||||
args_parser=QuestionAnsweringArgumentHandler(),
|
||||
device=device,
|
||||
task=task,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
@ -1111,12 +1129,16 @@ class SummarizationPipeline(Pipeline):
|
|||
|
||||
Usage::
|
||||
|
||||
# use bart in pytorch
|
||||
summarizer = pipeline("summarization")
|
||||
summarizer("Sam Shleifer writes the best docstring examples in the whole world.")
|
||||
summarizer("Sam Shleifer writes the best docstring examples in the whole world.", min_length=5, max_length=20)
|
||||
|
||||
# use t5 in tf
|
||||
summarizer = pipeline("summarization", model="t5-base", tokenizer="t5-base", framework="tf")
|
||||
summarizer("Sam Shleifer writes the best docstring examples in the whole world.", min_length=5, max_length=20)
|
||||
|
||||
Supported Models:
|
||||
The models that this pipeline can use are models that have been fine-tuned on a summarization task, which is
|
||||
currently only ``BartForConditionalGeneration.from_pretrained('bart-large-cnn')``
|
||||
The models that this pipeline can use are models that have been fine-tuned on a summarization task, which is currently, '`bart-large-cnn`', '`t5-small`', '`t5-base`', '`t5-large`', '`t5-3b`', '`t5-11b`'.
|
||||
|
||||
Arguments:
|
||||
model (:obj:`str` or :obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`, `optional`, defaults to :obj:`None`):
|
||||
|
@ -1147,17 +1169,8 @@ class SummarizationPipeline(Pipeline):
|
|||
on the associated CUDA device id.
|
||||
"""
|
||||
|
||||
task = "summarization"
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
*documents,
|
||||
return_tensors=False,
|
||||
return_text=True,
|
||||
max_length=142,
|
||||
min_length=21,
|
||||
clean_up_tokenization_spaces=False,
|
||||
**generate_kwargs
|
||||
self, *documents, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
|
@ -1165,10 +1178,6 @@ class SummarizationPipeline(Pipeline):
|
|||
return_text: (bool, default=True) whether to add a decoded "summary_text" to each result
|
||||
return_tensors: (bool, default=False) whether to return the raw "summary_token_ids" to each result
|
||||
|
||||
max_length: (`optional`) int
|
||||
The max length of the sequence to be generated. Does not include tokens in input_ids.
|
||||
min_len: (`optional`) int
|
||||
no_repeat_ngram_size: (`optional`) int. ban ngrams of this length from being repeated in the generated text
|
||||
clean_up_tokenization_spaces: (`optional`) bool whether to include extra spaces in the output
|
||||
**generate_kwargs: extra kwargs passed to `self.model.generate`_
|
||||
|
||||
|
@ -1180,19 +1189,60 @@ class SummarizationPipeline(Pipeline):
|
|||
|
||||
"""
|
||||
assert return_tensors or return_text, "You must specify return_tensors=True or return_text=True"
|
||||
if self.framework == "tf":
|
||||
raise NotImplementedError("Tensorflow not supported")
|
||||
with self.device_placement():
|
||||
inputs = self._parse_and_tokenize(*documents)
|
||||
inputs = self.ensure_tensor_on_device(**inputs)
|
||||
summaries = self.model.generate(
|
||||
inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
max_length=max_length,
|
||||
min_length=min_length,
|
||||
do_sample=False,
|
||||
**generate_kwargs,
|
||||
assert len(documents) > 0, "Please provide a document to summarize"
|
||||
|
||||
if self.framework == "tf" and "BartForConditionalGeneration" in self.model.__class__.__name__:
|
||||
raise NotImplementedError(
|
||||
"Tensorflow is not yet supported for Bart. Please consider using T5, e.g. `t5-base`"
|
||||
)
|
||||
|
||||
prefix = self.model.config.prefix if self.model.config.prefix is not None else ""
|
||||
|
||||
if isinstance(documents[0], list):
|
||||
assert (
|
||||
self.tokenizer.pad_token_id is not None
|
||||
), "Please make sure that the tokenizer has a pad_token_id when using a batch input"
|
||||
|
||||
documents = ([prefix + document for document in documents[0]],)
|
||||
pad_to_max_length = True
|
||||
|
||||
elif isinstance(documents[0], str):
|
||||
documents = (prefix + documents[0],)
|
||||
pad_to_max_length = False
|
||||
else:
|
||||
raise ValueError(
|
||||
" `documents[0]`: {} have the wrong format. The should be either of type `str` or type `list`".format(
|
||||
documents[0]
|
||||
)
|
||||
)
|
||||
|
||||
with self.device_placement():
|
||||
inputs = self._parse_and_tokenize(*documents, pad_to_max_length=pad_to_max_length)
|
||||
|
||||
if self.framework == "pt":
|
||||
inputs = self.ensure_tensor_on_device(**inputs)
|
||||
input_length = inputs["input_ids"].shape[-1]
|
||||
elif self.framework == "tf":
|
||||
input_length = tf.shape(inputs["input_ids"])[-1]
|
||||
|
||||
if input_length < self.model.config.min_length // 2:
|
||||
logger.warning(
|
||||
"Your min_length is set to {}, but you input_length is only {}. You might consider decreasing min_length in config and insert config manually".format(
|
||||
self.model.config.min_length, input_length
|
||||
)
|
||||
)
|
||||
|
||||
if input_length < self.model.config.max_length:
|
||||
logger.warning(
|
||||
"Your max_length is set to {}, but you input_length is only {}. You might consider decreasing max_length in config and insert config manually".format(
|
||||
self.model.config.max_length, input_length
|
||||
)
|
||||
)
|
||||
|
||||
summaries = self.model.generate(
|
||||
inputs["input_ids"], attention_mask=inputs["attention_mask"], **generate_kwargs,
|
||||
)
|
||||
|
||||
results = []
|
||||
for summary in summaries:
|
||||
record = {}
|
||||
|
@ -1266,8 +1316,8 @@ SUPPORTED_TASKS = {
|
|||
},
|
||||
"summarization": {
|
||||
"impl": SummarizationPipeline,
|
||||
"pt": BartForConditionalGeneration if is_torch_available() else None,
|
||||
"tf": None,
|
||||
"tf": TFAutoModelWithLMHead if is_tf_available() else None,
|
||||
"pt": AutoModelWithLMHead if is_torch_available() else None,
|
||||
"default": {
|
||||
"model": {"pt": "bart-large-cnn", "tf": None},
|
||||
"config": None,
|
||||
|
@ -1361,7 +1411,7 @@ def pipeline(
|
|||
framework = framework or get_framework(model)
|
||||
|
||||
targeted_task = SUPPORTED_TASKS[task]
|
||||
task, model_class = targeted_task["impl"], targeted_task[framework]
|
||||
task_class, model_class = targeted_task["impl"], targeted_task[framework]
|
||||
|
||||
# Use default model/config/tokenizer for the task if no model is provided
|
||||
if model is None:
|
||||
|
@ -1422,4 +1472,4 @@ def pipeline(
|
|||
)
|
||||
model = model_class.from_pretrained(model, config=config, **model_kwargs)
|
||||
|
||||
return task(model=model, tokenizer=tokenizer, modelcard=modelcard, framework=framework, **kwargs)
|
||||
return task_class(model=model, tokenizer=tokenizer, modelcard=modelcard, framework=framework, task=task, **kwargs)
|
||||
|
|
|
@ -78,6 +78,9 @@ TF_FILL_MASK_FINETUNED_MODELS = [
|
|||
(("distilroberta-base", {"use_fast": False}), "distilroberta-base", None),
|
||||
]
|
||||
|
||||
SUMMARIZATION_FINETUNED_MODELS = {("bart-large-cnn", "bart-large-cnn"), ("t5-small", "t5-small")}
|
||||
TF_SUMMARIZATION_FINETUNED_MODELS = {("t5-small", "t5-small")}
|
||||
|
||||
|
||||
class MonoColumnInputTestCase(unittest.TestCase):
|
||||
def _test_mono_column_pipeline(
|
||||
|
@ -252,10 +255,22 @@ class MonoColumnInputTestCase(unittest.TestCase):
|
|||
valid_inputs = ["A string like this", ["list of strings entry 1", "list of strings v2"]]
|
||||
invalid_inputs = [4, "<mask>"]
|
||||
mandatory_keys = ["summary_text"]
|
||||
nlp = pipeline(task="summarization")
|
||||
self._test_mono_column_pipeline(
|
||||
nlp, valid_inputs, invalid_inputs, mandatory_keys,
|
||||
)
|
||||
for model, tokenizer in SUMMARIZATION_FINETUNED_MODELS:
|
||||
nlp = pipeline(task="summarization", model=model, tokenizer=tokenizer)
|
||||
self._test_mono_column_pipeline(
|
||||
nlp, valid_inputs, invalid_inputs, mandatory_keys,
|
||||
)
|
||||
|
||||
@require_tf
|
||||
def test_tf_summarization(self):
|
||||
valid_inputs = ["A string like this", ["list of strings entry 1", "list of strings v2"]]
|
||||
invalid_inputs = [4, "<mask>"]
|
||||
mandatory_keys = ["summary_text"]
|
||||
for model, tokenizer in TF_SUMMARIZATION_FINETUNED_MODELS:
|
||||
nlp = pipeline(task="summarization", model=model, tokenizer=tokenizer, framework="tf")
|
||||
self._test_mono_column_pipeline(
|
||||
nlp, valid_inputs, invalid_inputs, mandatory_keys,
|
||||
)
|
||||
|
||||
|
||||
class MultiColumnInputTestCase(unittest.TestCase):
|
||||
|
|
Загрузка…
Ссылка в новой задаче