Pipeline model type check (#5679)
* Add model type check for pipelines
* Rename function
* Fix the init parameters
* Fix format
* Rollback unnecessary refactor
Parent
dc31a72f50
Commit
0befb51327
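
For orientation, here is a minimal standalone sketch of the logic this commit adds as `Pipeline.check_model_type`. Everything below is a simplified stand-in, not the real `transformers` objects: the library's model mappings pair config classes with model classes, and the check flattens such a mapping into a list of model class names before comparing against the wrapped model's class.

```python
# Standalone sketch of the check_model_type logic added in this commit.
# All names below are simplified stand-ins, not the real transformers classes.


class PipelineException(Exception):
    """Imitates the (task, model, reason) shape of transformers' PipelineException."""

    def __init__(self, task, model, reason):
        super().__init__(reason)
        self.task = task
        self.model = model


class BertForSequenceClassification:
    pass


class GPT2LMHeadModel:
    pass


# Real model mappings map config classes to model classes; a plain dict with a
# stand-in key is enough to show the flattening step below.
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = {"BertConfig": BertForSequenceClassification}


def check_model_type(task, model, supported_models):
    if not isinstance(supported_models, list):  # Create from a model mapping
        supported_models = [item[1].__name__ for item in supported_models.items()]
    if model.__class__.__name__ not in supported_models:
        raise PipelineException(
            task,
            model.__class__.__name__,
            f"The model '{model.__class__.__name__}' is not supported for {task}. "
            f"Supported models are {supported_models}",
        )


check_model_type("sentiment-analysis", BertForSequenceClassification(), MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING)  # passes

try:
    check_model_type("sentiment-analysis", GPT2LMHeadModel(), MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING)
except PipelineException as exc:
    print(exc)  # The model 'GPT2LMHeadModel' is not supported for sentiment-analysis. ...
```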
@@ -46,6 +46,10 @@ if is_tf_available():
         TFAutoModelForQuestionAnswering,
         TFAutoModelForTokenClassification,
         TFAutoModelWithLMHead,
+        TF_MODEL_WITH_LM_HEAD_MAPPING,
+        TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
+        TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
+        TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
     )
 
 if is_torch_available():
@@ -57,6 +61,11 @@ if is_torch_available():
         AutoModelForTokenClassification,
         AutoModelWithLMHead,
         AutoModelForSeq2SeqLM,
+        MODEL_WITH_LM_HEAD_MAPPING,
+        MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
+        MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
+        MODEL_FOR_QUESTION_ANSWERING_MAPPING,
+        MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
     )
 
 if TYPE_CHECKING:
@@ -396,6 +405,7 @@ class Pipeline(_ScikitCompat):
         if framework is None:
             framework = get_framework()
 
+        self.task = task
         self.model = model
         self.tokenizer = tokenizer
         self.modelcard = modelcard
@@ -469,6 +479,19 @@ class Pipeline(_ScikitCompat):
         """
         return {name: tensor.to(self.device) for name, tensor in inputs.items()}
 
+    def check_model_type(self, supported_models):
+        """
+        Check if the model class is in the supported class list of the pipeline.
+        """
+        if not isinstance(supported_models, list):  # Create from a model mapping
+            supported_models = [item[1].__name__ for item in supported_models.items()]
+        if self.model.__class__.__name__ not in supported_models:
+            raise PipelineException(
+                self.task,
+                self.model.base_model_prefix,
+                f"The model '{self.model.__class__.__name__}' is not supported for {self.task}. Supported models are {supported_models}",
+            )
+
     def _parse_and_tokenize(self, *args, padding=True, add_special_tokens=True, **kwargs):
         """
         Parse arguments and tokenize
@@ -615,6 +638,11 @@ class TextGenerationPipeline(Pipeline):
         "TFCTRLLMHeadModel",
     ]
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.check_model_type(self.ALLOWED_MODELS)
+
     # overriding _parse_and_tokenize to allow for unusual language-modeling tokenizer arguments
 
     def _parse_and_tokenize(self, *args, padding=True, add_special_tokens=True, **kwargs):
@@ -640,12 +668,6 @@ class TextGenerationPipeline(Pipeline):
     def __call__(
         self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
     ):
-        if self.model.__class__.__name__ not in self.ALLOWED_MODELS:
-            raise NotImplementedError(
-                "Generation is currently not supported for {}. Please select a model from {} for generation.".format(
-                    self.model.__class__.__name__, self.ALLOWED_MODELS
-                )
-            )
 
         text_inputs = self._args_parser(*args)
 
@@ -771,6 +793,12 @@ class TextClassificationPipeline(Pipeline):
     def __init__(self, return_all_scores: bool = False, **kwargs):
         super().__init__(**kwargs)
 
+        self.check_model_type(
+            TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
+            if self.framework == "tf"
+            else MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
+        )
+
         self.return_all_scores = return_all_scores
 
     def __call__(self, *args, **kwargs):
@@ -847,6 +875,8 @@ class FillMaskPipeline(Pipeline):
             task=task,
         )
 
+        self.check_model_type(TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_WITH_LM_HEAD_MAPPING)
+
         self.topk = topk
 
     def ensure_exactly_one_mask_token(self, masked_index: np.ndarray):
@@ -980,6 +1010,12 @@ class TokenClassificationPipeline(Pipeline):
             task=task,
         )
 
+        self.check_model_type(
+            TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
+            if self.framework == "tf"
+            else MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
+        )
+
         self._basic_tokenizer = BasicTokenizer(do_lower_case=False)
         self.ignore_labels = ignore_labels
         self.grouped_entities = grouped_entities
@@ -1220,6 +1256,10 @@ class QuestionAnsweringPipeline(Pipeline):
             **kwargs,
        )
 
+        self.check_model_type(
+            TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING if self.framework == "tf" else MODEL_FOR_QUESTION_ANSWERING_MAPPING
+        )
+
     @staticmethod
     def create_sample(
         question: Union[str, List[str]], context: Union[str, List[str]]
@@ -1483,9 +1523,13 @@ class SummarizationPipeline(Pipeline):
         on the associated CUDA device id.
     """
 
-    def __init__(self, **kwargs):
+    def __init__(self, *args, **kwargs):
         kwargs.update(task="summarization")
-        super().__init__(**kwargs)
+        super().__init__(*args, **kwargs)
+
+        self.check_model_type(
+            TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
+        )
 
     def __call__(
         self, *documents, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
@@ -1615,6 +1659,11 @@ class TranslationPipeline(Pipeline):
         on the associated CUDA device id.
     """
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.check_model_type(TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_WITH_LM_HEAD_MAPPING)
+
     def __call__(
         self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
     ):
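
Net effect: a pipeline constructed with a model class its task does not support now fails fast in `__init__` with a `PipelineException`, instead of surfacing later (previously only `TextGenerationPipeline` guarded against this, via a `NotImplementedError` raised in `__call__`, which this commit removes in favor of the shared check). A sketch of the new failure mode; the model name is illustrative and loading it downloads weights:

```python
from transformers import AutoModelWithLMHead, AutoTokenizer, pipeline

# GPT2LMHeadModel is a generation model, so it does not appear in
# MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.
model = AutoModelWithLMHead.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# With this commit, the mismatch is reported immediately: check_model_type
# raises a PipelineException while the pipeline is being constructed.
nlp = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer)
```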