diff --git a/tests/conftest.py b/tests/conftest.py index a216a94..bedcebf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,6 +3,7 @@ from typing import Dict, List import pytest from recon.corpus import Corpus +from recon.dataset import Dataset from recon.loaders import read_jsonl from recon.preprocess import SpacyPreProcessor from recon.recognizer import SpacyEntityRecognizer @@ -31,7 +32,7 @@ def test_texts(): @pytest.fixture() def example_data() -> Dict[str, List[Example]]: """Fixture to load example train/dev/test data that has inconsistencies. - + Returns: Dict[str, List[Example]]: Dataset containing the train/dev/test split """ @@ -46,7 +47,7 @@ def example_data() -> Dict[str, List[Example]]: @pytest.fixture() def example_corpus() -> Corpus: """Fixture to load example train/dev/test data that has inconsistencies. - + Returns: Corpus: Example data """ @@ -57,7 +58,7 @@ def example_corpus() -> Corpus: @pytest.fixture() def example_corpus_processed() -> Corpus: """Fixture to load example train/dev/test data that has inconsistencies. 
- + Returns: Corpus: Example data """ diff --git a/tests/test_operations.py b/tests/test_operations.py index 8c14c1f..e167e2b 100644 --- a/tests/test_operations.py +++ b/tests/test_operations.py @@ -1,4 +1,23 @@ +import pytest +from recon.dataset import Dataset from recon.operations import Operation, operation, registry +from recon.types import Example, Span + + +@pytest.fixture() +def ds(): + ds = Dataset( + name="test", + data=[ + Example( + text="this is a test example with something else", + spans=[Span(text="something", start=28, end=37, label="TEST_ENTITY")], + ) + ], + ) + ds.apply_("recon.v1.add_tokens") + + return ds def test_operation_init(): @@ -8,3 +27,67 @@ def test_operation_init(): assert "test_operation" in registry.operations assert isinstance(registry.operations.get("test_operation"), Operation) + + +def test_change_operation(ds): + @operation("change_annotation") + def operation_test(example): + example.spans[0].text = "something else" + example.spans[0].end = 42 + + return example + + assert "change_annotation" in registry.operations + assert isinstance(registry.operations.get("change_annotation"), Operation) + + assert len(ds.operations) == 1 + ds.apply_("change_annotation") + assert len(ds.operations) == 2 + + assert ds.operations[1].name == "change_annotation" + assert len(ds.example_store) == 3 + + assert len(ds) == 1 + + +def test_add_operation(ds): + @operation("add_and_change_example") + def operation_test(example): + new_example = Example(text="this is a test", spans=[]) + + example.spans[0].text = "something else" + example.spans[0].end = 42 + + return [new_example, example] + + assert "add_and_change_example" in registry.operations + assert isinstance(registry.operations.get("add_and_change_example"), Operation) + + assert len(ds.operations) == 1 + ds.apply_("add_and_change_example") + assert len(ds.operations) == 2 + + assert ds.operations[1].name == "add_and_change_example" + assert len(ds.example_store) == 4 + + assert len(ds) == 2 
def test_remove_operation(ds):
    """Check that an operation returning ``None`` removes the example from the dataset.

    Args:
        ds: Single-example dataset fixture that has already had
            ``recon.v1.add_tokens`` applied, so it starts with one recorded operation.
    """

    @operation("remove_example")
    def operation_test(example):
        # Returning None (rather than an Example) is the "remove" case under test.
        return None

    # The decorator must have registered the function as an Operation.
    assert "remove_example" in registry.operations
    assert isinstance(registry.operations.get("remove_example"), Operation)

    # Applying the operation appends one entry to the dataset's operation history.
    assert len(ds.operations) == 1
    ds.apply_("remove_example")
    assert len(ds.operations) == 2

    assert ds.operations[1].name == "remove_example"
    # NOTE(review): the store is expected to still hold 2 entries although the
    # dataset itself is empty -- presumably earlier versions of the example are
    # retained; confirm against the example store implementation.
    assert len(ds.example_store) == 2

    # The only example was removed, so the dataset is now empty.
    assert len(ds) == 0