import unittest
from typing import Iterable

from transformers import pipeline

from .utils import require_tf, require_torch

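# Each entry below is a (tokenizer, model, config) triple, unpacked as such by
# the tests; a config of None lets the pipeline fall back to the model's own
# default configuration.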
QA_FINETUNED_MODELS = {
    ("bert-base-uncased", "bert-large-uncased-whole-word-masking-finetuned-squad", None),
    ("bert-base-cased", "bert-large-cased-whole-word-masking-finetuned-squad", None),
    ("bert-base-uncased", "distilbert-base-uncased-distilled-squad", None),
}

TF_QA_FINETUNED_MODELS = {
    ("bert-base-uncased", "bert-large-uncased-whole-word-masking-finetuned-squad", None),
    ("bert-base-cased", "bert-large-cased-whole-word-masking-finetuned-squad", None),
    ("bert-base-uncased", "distilbert-base-uncased-distilled-squad", None),
}

TF_NER_FINETUNED_MODELS = {
    (
        "bert-base-cased",
        "dbmdz/bert-large-cased-finetuned-conll03-english",
        "dbmdz/bert-large-cased-finetuned-conll03-english",
    )
}

NER_FINETUNED_MODELS = {
    (
        "bert-base-cased",
        "dbmdz/bert-large-cased-finetuned-conll03-english",
        "dbmdz/bert-large-cased-finetuned-conll03-english",
    )
}

FEATURE_EXTRACT_FINETUNED_MODELS = {
    ("bert-base-cased", "bert-base-cased", None),
    # ('xlnet-base-cased', 'xlnet-base-cased', None),  # Disabled for now as it crashes on TF2
    ("distilbert-base-uncased", "distilbert-base-uncased", None),
}

TF_FEATURE_EXTRACT_FINETUNED_MODELS = {
    ("bert-base-cased", "bert-base-cased", None),
    # ('xlnet-base-cased', 'xlnet-base-cased', None),  # Disabled for now as it crashes on TF2
    ("distilbert-base-uncased", "distilbert-base-uncased", None),
}

TF_TEXT_CLASSIF_FINETUNED_MODELS = {
    (
        "bert-base-uncased",
        "distilbert-base-uncased-finetuned-sst-2-english",
        "distilbert-base-uncased-finetuned-sst-2-english",
    )
}

TEXT_CLASSIF_FINETUNED_MODELS = {
    (
        "bert-base-uncased",
        "distilbert-base-uncased-finetuned-sst-2-english",
        "distilbert-base-uncased-finetuned-sst-2-english",
    )
}


class MonoColumnInputTestCase(unittest.TestCase):
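    # Shared harness for single-input pipelines: a single string must yield a
    # list of result dicts, a list of strings one result per input, every
    # result must expose the task's mandatory keys, and invalid inputs must
    # raise.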
    def _test_mono_column_pipeline(self, nlp, valid_inputs: list, invalid_inputs: list, output_keys: Iterable[str]):
        self.assertIsNotNone(nlp)

        mono_result = nlp(valid_inputs[0])
        self.assertIsInstance(mono_result, list)
        self.assertIsInstance(mono_result[0], (dict, list))

        if isinstance(mono_result[0], list):
            mono_result = mono_result[0]

        for key in output_keys:
            self.assertIn(key, mono_result[0])

        multi_result = nlp(valid_inputs)
        self.assertIsInstance(multi_result, list)
        self.assertIsInstance(multi_result[0], (dict, list))

        if isinstance(multi_result[0], list):
            multi_result = multi_result[0]

        for result in multi_result:
            for key in output_keys:
                self.assertIn(key, result)

        self.assertRaises(Exception, nlp, invalid_inputs)

    @require_torch
    def test_ner(self):
        mandatory_keys = {"entity", "word", "score"}
        valid_inputs = ["HuggingFace is solving NLP one commit at a time.", "HuggingFace is based in New-York & Paris"]
        invalid_inputs = [None]
        for tokenizer, model, config in NER_FINETUNED_MODELS:
            nlp = pipeline(task="ner", model=model, config=config, tokenizer=tokenizer)
            self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, mandatory_keys)

    @require_tf
    def test_tf_ner(self):
        mandatory_keys = {"entity", "word", "score"}
        valid_inputs = ["HuggingFace is solving NLP one commit at a time.", "HuggingFace is based in New-York & Paris"]
        invalid_inputs = [None]
        for tokenizer, model, config in TF_NER_FINETUNED_MODELS:
            nlp = pipeline(task="ner", model=model, config=config, tokenizer=tokenizer)
            self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, mandatory_keys)

    @require_torch
    def test_sentiment_analysis(self):
        mandatory_keys = {"label"}
        valid_inputs = ["HuggingFace is solving NLP one commit at a time.", "HuggingFace is based in New-York & Paris"]
        invalid_inputs = [None]
        for tokenizer, model, config in TEXT_CLASSIF_FINETUNED_MODELS:
            nlp = pipeline(task="sentiment-analysis", model=model, config=config, tokenizer=tokenizer)
            self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, mandatory_keys)

    @require_tf
    def test_tf_sentiment_analysis(self):
        mandatory_keys = {"label"}
        valid_inputs = ["HuggingFace is solving NLP one commit at a time.", "HuggingFace is based in New-York & Paris"]
        invalid_inputs = [None]
        for tokenizer, model, config in TF_TEXT_CLASSIF_FINETUNED_MODELS:
            nlp = pipeline(task="sentiment-analysis", model=model, config=config, tokenizer=tokenizer)
            self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, mandatory_keys)

    @require_torch
    def test_features_extraction(self):
        valid_inputs = ["HuggingFace is solving NLP one commit at a time.", "HuggingFace is based in New-York & Paris"]
        invalid_inputs = [None]
        for tokenizer, model, config in FEATURE_EXTRACT_FINETUNED_MODELS:
            # Use the feature-extraction task here, not sentiment-analysis:
            # these models carry no classification head and the test checks no output keys.
            nlp = pipeline(task="feature-extraction", model=model, config=config, tokenizer=tokenizer)
            self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, {})

    @require_tf
    def test_tf_features_extraction(self):
        valid_inputs = ["HuggingFace is solving NLP one commit at a time.", "HuggingFace is based in New-York & Paris"]
        invalid_inputs = [None]
        for tokenizer, model, config in TF_FEATURE_EXTRACT_FINETUNED_MODELS:
            # Same fix as the PyTorch variant: build a feature-extraction pipeline.
            nlp = pipeline(task="feature-extraction", model=model, config=config, tokenizer=tokenizer)
            self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, {})


class MultiColumnInputTestCase(unittest.TestCase):
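    # Shared harness for dict-shaped inputs (e.g. question/context pairs): one
    # sample must yield a single dict, a list of samples a list of dicts, and
    # malformed samples must raise both individually and as a batch.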
    def _test_multicolumn_pipeline(self, nlp, valid_inputs: list, invalid_inputs: list, output_keys: Iterable[str]):
        self.assertIsNotNone(nlp)

        mono_result = nlp(valid_inputs[0])
        self.assertIsInstance(mono_result, dict)

        for key in output_keys:
            self.assertIn(key, mono_result)

        multi_result = nlp(valid_inputs)
        self.assertIsInstance(multi_result, list)
        self.assertIsInstance(multi_result[0], dict)

        for result in multi_result:
            for key in output_keys:
                self.assertIn(key, result)

        self.assertRaises(Exception, nlp, invalid_inputs[0])
        self.assertRaises(Exception, nlp, invalid_inputs)

    @require_torch
    def test_question_answering(self):
        mandatory_output_keys = {"score", "answer", "start", "end"}
        valid_samples = [
            {"question": "Where was HuggingFace founded ?", "context": "HuggingFace was founded in Paris."},
            {
                "question": "In what field is HuggingFace working ?",
                "context": "HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP.",
            },
        ]
        invalid_samples = [
            {"question": "", "context": "This is a test to try the empty question edge case"},
            {"question": None, "context": "This is a test to try the empty question edge case"},
            {"question": "What does it do with an empty context ?", "context": ""},
            {"question": "What does it do with an empty context ?", "context": None},
        ]

        for tokenizer, model, config in QA_FINETUNED_MODELS:
            nlp = pipeline(task="question-answering", model=model, config=config, tokenizer=tokenizer)
            self._test_multicolumn_pipeline(nlp, valid_samples, invalid_samples, mandatory_output_keys)

    @require_tf
    def test_tf_question_answering(self):
        mandatory_output_keys = {"score", "answer", "start", "end"}
        valid_samples = [
            {"question": "Where was HuggingFace founded ?", "context": "HuggingFace was founded in Paris."},
            {
                "question": "In what field is HuggingFace working ?",
                "context": "HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP.",
            },
        ]
        invalid_samples = [
            {"question": "", "context": "This is a test to try the empty question edge case"},
            {"question": None, "context": "This is a test to try the empty question edge case"},
            {"question": "What does it do with an empty context ?", "context": ""},
            {"question": "What does it do with an empty context ?", "context": None},
        ]

        for tokenizer, model, config in TF_QA_FINETUNED_MODELS:
            nlp = pipeline(task="question-answering", model=model, config=config, tokenizer=tokenizer)
            self._test_multicolumn_pipeline(nlp, valid_samples, invalid_samples, mandatory_output_keys)
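

# Convenience entry point (an addition, not part of the original suite). The
# relative import at the top means this file must be run with its package
# importable, e.g. `python -m tests.test_pipelines` (path assumed) or simply
# via pytest.
if __name__ == "__main__":
    unittest.main()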