2020-05-27 15:21:33 +03:00
|
|
|
import pytest
|
|
|
|
|
2021-12-26 18:37:45 +03:00
|
|
|
from presidio_evaluator import InputSample
|
2021-04-26 12:40:05 +03:00
|
|
|
from presidio_evaluator.evaluation import Evaluator
|
|
|
|
|
2020-01-06 23:59:12 +03:00
|
|
|
# flair is not installed by default; tolerate its absence so this module can
# still be imported (the test that needs SequenceTagger is marked as skipped).
try:
    from flair.models import SequenceTagger
except ImportError:
    # Only a failed import is expected here. Report it instead of building an
    # exception object that is never raised, and let any other error propagate.
    print("Flair is not installed by default")
|
2020-01-06 23:59:12 +03:00
|
|
|
|
2021-12-26 18:37:45 +03:00
|
|
|
|
2021-04-26 12:40:05 +03:00
|
|
|
from presidio_evaluator.models.flair_model import FlairModel
|
2020-01-06 23:59:12 +03:00
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
2021-04-26 12:40:05 +03:00
|
|
|
|
2020-01-06 23:59:12 +03:00
|
|
|
# no-unit: not run as a regular unit test because flair is not a dependency by default
|
2020-05-27 15:21:33 +03:00
|
|
|
@pytest.mark.skip(reason="Flair not installed by default")
def test_flair_simple():
    """Evaluate a flair tagger wrapped in FlairModel on the small generated dataset.

    Only the "PERSON" entity is kept, so the overall PII precision/recall are
    compared against the per-entity "PERSON" values, and both overall scores
    are required to be positive.
    """
    import os

    # Resolve the dataset relative to this test file.
    tests_dir = os.path.dirname(os.path.realpath(__file__))
    dataset_path = os.path.join(tests_dir, "data/generated_small.json")
    samples = InputSample.read_dataset_json(dataset_path)

    tagger = SequenceTagger.load("ner-ontonotes-fast")  # .load('ner')

    wrapped_model = FlairModel(model=tagger, entities_to_keep=["PERSON"])
    evaluator = Evaluator(model=wrapped_model)
    results = evaluator.evaluate_all(samples)
    scores = evaluator.calculate_score(results)

    # Overall PII metrics are checked against the "PERSON" per-entity metrics.
    overall_precision = scores.pii_precision
    person_precision = scores.entity_precision_dict["PERSON"]
    np.testing.assert_almost_equal(overall_precision, person_precision)

    overall_recall = scores.pii_recall
    person_recall = scores.entity_recall_dict["PERSON"]
    np.testing.assert_almost_equal(overall_recall, person_recall)

    assert scores.pii_recall > 0
    assert scores.pii_precision > 0
|