2022-01-24 13:00:12 +03:00
|
|
|
import sys
|
|
|
|
|
2020-05-27 15:21:33 +03:00
|
|
|
import pytest
|
|
|
|
|
2021-04-26 12:40:05 +03:00
|
|
|
from presidio_evaluator.evaluation import Evaluator
|
2022-01-24 13:00:12 +03:00
|
|
|
from tests.conftest import assert_model_results_gt
|
2021-04-26 12:40:05 +03:00
|
|
|
from presidio_evaluator.models.flair_model import FlairModel
|
2020-01-06 23:59:12 +03:00
|
|
|
|
2020-05-27 15:21:33 +03:00
|
|
|
|
2022-01-24 13:00:12 +03:00
|
|
|
@pytest.mark.slow
def test_flair_simple(small_dataset):
    """Smoke-test FlairModel end to end on the ``small_dataset`` fixture.

    Builds a FlairModel from the ``"ner"`` model path, keeping only the
    ``PERSON`` entity. It runs the model through ``Evaluator.evaluate_all``,
    scores the results with ``Evaluator.calculate_score``, and checks the
    PERSON result via ``assert_model_results_gt`` with a threshold of 0.

    The test is skipped when the Flair library cannot be imported.
    """
    # Check for Flair explicitly instead of testing ``"flair" in sys.modules``.
    # The sys.modules check only passes if an earlier import happened to load
    # flair as a side effect, so the test could be skipped even when Flair is
    # installed. importorskip imports it directly, or skips if unavailable.
    pytest.importorskip("flair")

    flair_model = FlairModel(model_path="ner", entities_to_keep=["PERSON"])
    evaluator = Evaluator(model=flair_model)

    evaluation_results = evaluator.evaluate_all(small_dataset)
    scores = evaluator.calculate_score(evaluation_results)

    # NOTE(review): assert_model_results_gt is defined in tests/conftest.py.
    # It presumably asserts that the PERSON metrics in `scores` exceed 0.
    # Confirm the exact check there.
    assert_model_results_gt(scores, "PERSON", 0)
|