default models for tf and pt - update tests
Parent: 7f74084528
Commit: db0795b5d0
@@ -776,7 +776,10 @@ SUPPORTED_TASKS = {
         'tf': TFAutoModel if is_tf_available() else None,
         'pt': AutoModel if is_torch_available() else None,
         'default': {
-            'model': 'distilbert-base-uncased',
+            'model': {
+                'pt': 'distilbert-base-uncased',
+                'tf': 'distilbert-base-uncased',
+            },
             'config': None,
             'tokenizer': 'distilbert-base-uncased'
         }
@@ -786,7 +789,10 @@ SUPPORTED_TASKS = {
         'tf': TFAutoModelForSequenceClassification if is_tf_available() else None,
         'pt': AutoModelForSequenceClassification if is_torch_available() else None,
         'default': {
-            'model': 'https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-finetuned-sst-2-english-pytorch_model.bin',
+            'model': {
+                'pt': 'https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-finetuned-sst-2-english-pytorch_model.bin',
+                'tf': 'https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-finetuned-sst-2-english-tf_model.h5',
+            },
             'config': 'https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-finetuned-sst-2-english-config.json',
             'tokenizer': 'distilbert-base-uncased'
         }
@@ -796,7 +802,10 @@ SUPPORTED_TASKS = {
         'tf': TFAutoModelForTokenClassification if is_tf_available() else None,
         'pt': AutoModelForTokenClassification if is_torch_available() else None,
         'default': {
-            'model': 'https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-finetuned-conll03-english-pytorch_model.bin',
+            'model': {
+                'pt': 'https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-finetuned-conll03-english-pytorch_model.bin',
+                'tf': 'https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-finetuned-conll03-english-tf_model.h5',
+            },
             'config': 'https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-finetuned-conll03-english-config.json',
             'tokenizer': 'bert-large-cased'
         }
@@ -806,7 +815,10 @@ SUPPORTED_TASKS = {
         'tf': TFAutoModelForQuestionAnswering if is_tf_available() else None,
         'pt': AutoModelForQuestionAnswering if is_torch_available() else None,
         'default': {
-            'model': 'distilbert-base-uncased-distilled-squad',
+            'model': {
+                'pt': 'distilbert-base-uncased-distilled-squad',
+                'tf': 'distilbert-base-uncased-distilled-squad',
+            },
             'config': None,
             'tokenizer': 'distilbert-base-uncased'
         }
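Note (not part of the commit): the four hunks above replace each task's single default checkpoint with a small per-framework mapping. A minimal sketch of the new layout and the lookup it enables, assuming the framework identifier is either 'pt' or 'tf'; the task_defaults name is illustrative only:

# Illustrative sketch only -- mirrors the new 'default' layout introduced above.
task_defaults = {
    'model': {
        'pt': 'distilbert-base-uncased',   # checkpoint used with the PyTorch backend
        'tf': 'distilbert-base-uncased',   # checkpoint used with the TensorFlow 2 backend
    },
    'config': None,
    'tokenizer': 'distilbert-base-uncased',
}

framework = 'tf'                           # assumed to be 'pt' or 'tf'
models, config, tokenizer = tuple(task_defaults.values())
model = models[framework]                  # -> 'distilbert-base-uncased'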
@@ -843,7 +855,8 @@ def pipeline(task: str, model: Optional = None,

     # Use default model/config/tokenizer for the task if no model is provided
     if model is None:
-        model, config, tokenizer = tuple(targeted_task['default'].values())
+        models, config, tokenizer = tuple(targeted_task['default'].values())
+        model = models[framework]

     # Try to infer tokenizer from model or config name (if provided as str)
     if tokenizer is None:
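Usage sketch (an assumption, not shown in the diff): with the models[framework] lookup above, calling the pipeline factory without an explicit model should resolve the checkpoint matching the active backend rather than a single hard-coded one:

from transformers import pipeline

# No model passed: the task's 'default' entry supplies the checkpoint, and the
# new models[framework] lookup picks the PyTorch or TensorFlow variant.
nlp = pipeline(task='sentiment-analysis')
print(nlp('HuggingFace is solving NLP one commit at a time.'))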
@@ -11,6 +11,20 @@ QA_FINETUNED_MODELS = {
     ('bert-base-uncased', 'distilbert-base-uncased-distilled-squad', None)
 }

+TF_QA_FINETUNED_MODELS = {
+    ('bert-base-uncased', 'bert-large-uncased-whole-word-masking-finetuned-squad', None),
+    ('bert-base-cased', 'bert-large-cased-whole-word-masking-finetuned-squad', None),
+    ('bert-base-uncased', 'distilbert-base-uncased-distilled-squad', None)
+}
+
+TF_NER_FINETUNED_MODELS = {
+    (
+        'bert-base-cased',
+        'https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-finetuned-conll03-english-tf_model.h5',
+        'https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-finetuned-conll03-english-config.json'
+    )
+}
+
 NER_FINETUNED_MODELS = {
     (
         'bert-base-cased',
@@ -25,6 +39,20 @@ FEATURE_EXTRACT_FINETUNED_MODELS = {
     ('distilbert-base-uncased', 'distilbert-base-uncased', None)
 }

+TF_FEATURE_EXTRACT_FINETUNED_MODELS = {
+    ('bert-base-cased', 'bert-base-cased', None),
+    # ('xlnet-base-cased', 'xlnet-base-cased', None), # Disabled for now as it crash for TF2
+    ('distilbert-base-uncased', 'distilbert-base-uncased', None)
+}
+
+TF_TEXT_CLASSIF_FINETUNED_MODELS = {
+    (
+        'bert-base-uncased',
+        'https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-finetuned-sst-2-english-tf_model.h5',
+        'https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-finetuned-sst-2-english-config.json'
+    )
+}
+
 TEXT_CLASSIF_FINETUNED_MODELS = {
     (
         'bert-base-uncased',
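Each entry in these new TF_* sets is a (tokenizer, model, config) triple, with the model pointing at a *-tf_model.h5 checkpoint and, where needed, an explicit config URL. An illustrative sketch (not part of the commit) of how one triple is consumed, mirroring the test loops switched over in the hunks below:

from transformers import pipeline

# One TF_TEXT_CLASSIF_FINETUNED_MODELS entry, unpacked the same way the test loop does.
tokenizer, model, config = (
    'bert-base-uncased',
    'https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-finetuned-sst-2-english-tf_model.h5',
    'https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-finetuned-sst-2-english-config.json',
)
nlp = pipeline(task='sentiment-analysis', model=model, config=config, tokenizer=tokenizer)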
@@ -75,7 +103,7 @@ class MonoColumnInputTestCase(unittest.TestCase):
         mandatory_keys = {'entity', 'word', 'score'}
         valid_inputs = ['HuggingFace is solving NLP one commit at a time.', 'HuggingFace is based in New-York & Paris']
         invalid_inputs = [None]
-        for tokenizer, model, config in NER_FINETUNED_MODELS:
+        for tokenizer, model, config in TF_NER_FINETUNED_MODELS:
             nlp = pipeline(task='ner', model=model, config=config, tokenizer=tokenizer)
             self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, mandatory_keys)
@@ -93,7 +121,7 @@ class MonoColumnInputTestCase(unittest.TestCase):
         mandatory_keys = {'label'}
         valid_inputs = ['HuggingFace is solving NLP one commit at a time.', 'HuggingFace is based in New-York & Paris']
         invalid_inputs = [None]
-        for tokenizer, model, config in TEXT_CLASSIF_FINETUNED_MODELS:
+        for tokenizer, model, config in TF_TEXT_CLASSIF_FINETUNED_MODELS:
             nlp = pipeline(task='sentiment-analysis', model=model, config=config, tokenizer=tokenizer)
             self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, mandatory_keys)
@@ -109,7 +137,7 @@ class MonoColumnInputTestCase(unittest.TestCase):
     def test_tf_features_extraction(self):
         valid_inputs = ['HuggingFace is solving NLP one commit at a time.', 'HuggingFace is based in New-York & Paris']
         invalid_inputs = [None]
-        for tokenizer, model, config in FEATURE_EXTRACT_FINETUNED_MODELS:
+        for tokenizer, model, config in TF_FEATURE_EXTRACT_FINETUNED_MODELS:
             nlp = pipeline(task='sentiment-analysis', model=model, config=config, tokenizer=tokenizer)
             self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, {})
@@ -173,7 +201,7 @@ class MultiColumnInputTestCase(unittest.TestCase):
             {'question': 'What is does with empty context ?', 'context': None},
         ]

-        for tokenizer, model, config in QA_FINETUNED_MODELS:
+        for tokenizer, model, config in TF_QA_FINETUNED_MODELS:
             nlp = pipeline(task='question-answering', model=model, config=config, tokenizer=tokenizer)
             self._test_multicolumn_pipeline(nlp, valid_samples, invalid_samples, mandatory_output_keys)