* Add model type check for pipelines

* Rename function

* Fix the init parameters

* Fix format

* Roll back unnecessary refactor
Kevin Canwen Xu 2020-07-12 12:34:21 +08:00 committed by GitHub
Parent dc31a72f50
Commit 0befb51327
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
1 changed file with 57 additions and 8 deletions
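The change makes every task-specific pipeline validate its model's class up front, in __init__, instead of failing later inside __call__. A minimal sketch of the new behavior (the checkpoint name is illustrative, not part of the diff): a bare DistilBertModel has no sequence-classification head, so its class is absent from MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING and the sentiment-analysis pipeline now rejects it at construction time.

from transformers import AutoModel, AutoTokenizer, pipeline
from transformers.pipelines import PipelineException

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
# AutoModel loads the bare encoder (DistilBertModel), which has no
# sequence-classification head, so check_model_type raises immediately.
bare_model = AutoModel.from_pretrained("distilbert-base-uncased")
try:
    pipeline("sentiment-analysis", model=bare_model, tokenizer=tokenizer)
except PipelineException as exc:
    print(exc)  # "The model 'DistilBertModel' is not supported for sentiment-analysis. ..."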


@@ -46,6 +46,10 @@ if is_tf_available():
         TFAutoModelForQuestionAnswering,
         TFAutoModelForTokenClassification,
         TFAutoModelWithLMHead,
+        TF_MODEL_WITH_LM_HEAD_MAPPING,
+        TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
+        TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
+        TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
     )

 if is_torch_available():
@@ -57,6 +61,11 @@ if is_torch_available():
         AutoModelForTokenClassification,
         AutoModelWithLMHead,
         AutoModelForSeq2SeqLM,
+        MODEL_WITH_LM_HEAD_MAPPING,
+        MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
+        MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
+        MODEL_FOR_QUESTION_ANSWERING_MAPPING,
+        MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
     )

 if TYPE_CHECKING:
@@ -396,6 +405,7 @@ class Pipeline(_ScikitCompat):
         if framework is None:
             framework = get_framework()

+        self.task = task
         self.model = model
         self.tokenizer = tokenizer
         self.modelcard = modelcard
@@ -469,6 +479,19 @@ class Pipeline(_ScikitCompat):
         """
         return {name: tensor.to(self.device) for name, tensor in inputs.items()}

+    def check_model_type(self, supported_models):
+        """
+        Check if the model class is in the supported class list of the pipeline.
+        """
+        if not isinstance(supported_models, list):  # Create from a model mapping
+            supported_models = [item[1].__name__ for item in supported_models.items()]
+        if self.model.__class__.__name__ not in supported_models:
+            raise PipelineException(
+                self.task,
+                self.model.base_model_prefix,
+                f"The model '{self.model.__class__.__name__}' is not supported for {self.task}. Supported models are {supported_models}",
+            )
+
     def _parse_and_tokenize(self, *args, padding=True, add_special_tokens=True, **kwargs):
         """
         Parse arguments and tokenize
@@ -615,6 +638,11 @@ class TextGenerationPipeline(Pipeline):
         "TFCTRLLMHeadModel",
     ]

+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.check_model_type(self.ALLOWED_MODELS)
+
     # overriding _parse_and_tokenize to allow for unusual language-modeling tokenizer arguments
     def _parse_and_tokenize(self, *args, padding=True, add_special_tokens=True, **kwargs):
@@ -640,12 +668,6 @@ class TextGenerationPipeline(Pipeline):
     def __call__(
         self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
     ):
-        if self.model.__class__.__name__ not in self.ALLOWED_MODELS:
-            raise NotImplementedError(
-                "Generation is currently not supported for {}. Please select a model from {} for generation.".format(
-                    self.model.__class__.__name__, self.ALLOWED_MODELS
-                )
-            )

         text_inputs = self._args_parser(*args)
@@ -771,6 +793,12 @@ class TextClassificationPipeline(Pipeline):
     def __init__(self, return_all_scores: bool = False, **kwargs):
         super().__init__(**kwargs)

+        self.check_model_type(
+            TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
+            if self.framework == "tf"
+            else MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
+        )
+
         self.return_all_scores = return_all_scores

     def __call__(self, *args, **kwargs):
@@ -847,6 +875,8 @@ class FillMaskPipeline(Pipeline):
             task=task,
         )

+        self.check_model_type(TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_WITH_LM_HEAD_MAPPING)
+
         self.topk = topk

     def ensure_exactly_one_mask_token(self, masked_index: np.ndarray):
@@ -980,6 +1010,12 @@ class TokenClassificationPipeline(Pipeline):
             task=task,
         )

+        self.check_model_type(
+            TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
+            if self.framework == "tf"
+            else MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
+        )
+
         self._basic_tokenizer = BasicTokenizer(do_lower_case=False)
         self.ignore_labels = ignore_labels
         self.grouped_entities = grouped_entities
@@ -1220,6 +1256,10 @@ class QuestionAnsweringPipeline(Pipeline):
             **kwargs,
         )

+        self.check_model_type(
+            TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING if self.framework == "tf" else MODEL_FOR_QUESTION_ANSWERING_MAPPING
+        )
+
     @staticmethod
     def create_sample(
         question: Union[str, List[str]], context: Union[str, List[str]]
@@ -1483,9 +1523,13 @@ class SummarizationPipeline(Pipeline):
     on the associated CUDA device id.
     """

-    def __init__(self, **kwargs):
+    def __init__(self, *args, **kwargs):
         kwargs.update(task="summarization")
-        super().__init__(**kwargs)
+        super().__init__(*args, **kwargs)
+
+        self.check_model_type(
+            TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
+        )

     def __call__(
         self, *documents, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
@@ -1615,6 +1659,11 @@ class TranslationPipeline(Pipeline):
     on the associated CUDA device id.
     """

+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.check_model_type(TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_WITH_LM_HEAD_MAPPING)
+
     def __call__(
         self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
     ):
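Note that check_model_type accepts both forms used above: TextGenerationPipeline passes an explicit list of class names (ALLOWED_MODELS), while every other pipeline passes one of the MODEL_FOR_*/TF_MODEL_FOR_* mappings, whose values are the supported model classes. A standalone sketch of that normalization step (function name hypothetical, for illustration only):

def supported_class_names(supported_models):
    # A model mapping (e.g. MODEL_WITH_LM_HEAD_MAPPING) maps config classes to
    # model classes; a plain list already contains class names as strings.
    if isinstance(supported_models, list):
        return supported_models
    return [model_class.__name__ for model_class in supported_models.values()]

supported_class_names(["GPT2LMHeadModel", "CTRLLMHeadModel"])  # returned unchanged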