diff --git a/src/transformers/pipelines.py b/src/transformers/pipelines.py index 246a8afe4..e49cc3311 100755 --- a/src/transformers/pipelines.py +++ b/src/transformers/pipelines.py @@ -1391,6 +1391,10 @@ class SummarizationPipeline(Pipeline): on the associated CUDA device id. """ + def __init__(self, **kwargs): + kwargs.update(task="summarization") + super().__init__(**kwargs) + def __call__( self, *documents, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs ):