presidio-research/tests/test_flair_model.py

42 строки
1.2 KiB
Python
Исходник Обычный вид История

2020-05-27 15:21:33 +03:00
import pytest
from presidio_evaluator import InputSample
2021-04-26 12:40:05 +03:00
from presidio_evaluator.evaluation import Evaluator
2020-01-06 23:59:12 +03:00
# flair is an optional dependency ("Flair is not installed by default"), so
# this module must stay importable when the package is absent.
try:
    from flair.models import SequenceTagger
except ImportError:
    # Catch only a failed import (a bare ``except`` would also swallow
    # KeyboardInterrupt/SystemExit) and bind the name to None so it is always
    # defined; code that needs flair can check for None before using it.
    SequenceTagger = None
2020-01-06 23:59:12 +03:00
2021-04-26 12:40:05 +03:00
from presidio_evaluator.models.flair_model import FlairModel
2020-01-06 23:59:12 +03:00
import numpy as np
2021-04-26 12:40:05 +03:00
2020-01-06 23:59:12 +03:00
# no-unit because flair is not a dependency by default
2020-05-27 15:21:33 +03:00
@pytest.mark.skip(reason="Flair not installed by default")
def test_flair_simple():
    """Evaluate a flair tagger wrapped in FlairModel on the small dataset.

    Loads ``data/generated_small.json`` from the directory containing this
    file, runs a ``FlairModel`` restricted to the ``PERSON`` entity through
    ``Evaluator`` and checks the resulting scores. The test is skipped
    unconditionally (see the decorator's reason).
    """
    import os

    # Resolve the dataset relative to this test file, not the working dir.
    tests_dir = os.path.dirname(os.path.realpath(__file__))
    dataset_path = os.path.join(tests_dir, "data/generated_small.json")
    samples = InputSample.read_dataset_json(dataset_path)

    tagger = SequenceTagger.load("ner-ontonotes-fast")  # alternatively: .load('ner')

    wrapped_model = FlairModel(model=tagger, entities_to_keep=["PERSON"])
    evaluator = Evaluator(model=wrapped_model)
    per_sample_results = evaluator.evaluate_all(samples)
    scores = evaluator.calculate_score(per_sample_results)

    # PERSON is the only entity kept, and the test expects the aggregate PII
    # metrics to coincide with the PERSON-specific ones.
    check_close = np.testing.assert_almost_equal
    check_close(scores.pii_precision, scores.entity_precision_dict["PERSON"])
    check_close(scores.pii_recall, scores.entity_recall_dict["PERSON"])

    # The model must detect at least something correctly.
    assert scores.pii_recall > 0
    assert scores.pii_precision > 0