diff --git a/pyproject.toml b/pyproject.toml
index 84ff481..57e5577 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -67,3 +67,6 @@ all = [
 
 [tool.flit.scripts]
 recon = "recon.cli:app"
+
+[tool.flit.entrypoints."prodigy_recipes"]
+"recon.ner_correct" = "recon:prodigy_recipes.ner_correct"
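+# Note: Prodigy discovers third-party recipes through the "prodigy_recipes"
+# entry-point group above, so installing this package should expose
+# `prodigy recon.ner_correct ...` on the CLI without needing the -F flag.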
diff --git a/recon/__init__.py b/recon/__init__.py
index 462ff7d..7261874 100644
--- a/recon/__init__.py
+++ b/recon/__init__.py
@@ -8,3 +8,9 @@ from .insights import *
 from .loaders import read_json, read_jsonl
 from .stats import get_ner_stats
 from .validation import *
+
+try:
+    # This needs to be imported in order for the entry points to be loaded
+    from . import prodigy_recipes  # noqa: F401
+except ImportError:
+    pass
diff --git a/recon/constants.py b/recon/constants.py
index 3d05979..445f880 100644
--- a/recon/constants.py
+++ b/recon/constants.py
@@ -1 +1 @@
-NONE = "__NONE__"
+NONE = "NOT_LABELED"
diff --git a/recon/prodigy/__init__.py b/recon/prodigy/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/recon/prodigy/recipes.py b/recon/prodigy/recipes.py
new file mode 100644
index 0000000..6eceee0
--- /dev/null
+++ b/recon/prodigy/recipes.py
@@ -0,0 +1,94 @@
+import copy
+import random
+from collections import Counter, defaultdict
+from typing import Dict, Iterable, List, Optional, Union
+
+import catalogue
+import prodigy
+import spacy
+import srsly
+from prodigy.components.db import connect
+from prodigy.components.loaders import get_stream
+from prodigy.components.preprocess import add_tokens
+from prodigy.recipes.ner import get_labels_from_ner
+from prodigy.util import (
+    INPUT_HASH_ATTR,
+    TASK_HASH_ATTR,
+    get_labels,
+    log,
+    set_hashes,
+    split_string,
+)
+from recon.types import HardestExample
+from wasabi import msg
+
+
+def get_stream_from_hardest_examples(hardest_examples: List[HardestExample]):
+    for he in hardest_examples:
+        task = he.example.dict()
+        task['prediction_errors'] = [pe.dict() for pe in he.prediction_errors]
+        yield task
+
+
+@prodigy.recipe(
+    "recon.ner_correct",
+    dataset=("Dataset to save annotations to", "positional", None, str),
+    spacy_model=("Base model or blank:lang (e.g. blank:en) for blank model", "positional", None, str),
+    hardest_examples=("Data to annotate (file path or '-' to read from standard input)", "positional", None, str),
+    loader=("Loader (guessed from file extension if not set)", "option", "lo", str),
+    label=("Comma-separated label(s) to annotate or text file with one label per line", "option", "l", get_labels),
+    exclude=("Comma-separated list of dataset IDs whose annotations to exclude", "option", "e", split_string),
+)
+def ner_correct(
+    dataset: str,
+    spacy_model: str,
+    hardest_examples: List[HardestExample],
+    loader: Optional[str] = None,
+    label: Optional[List[str]] = None,
+    exclude: Optional[List[str]] = None,
+):
+    """
+    Stream a List of `recon.types.HardestExample` instances to prodigy
+    for review/correction. Uses the Prodigy blocks interface to display
+    prediction error information along with the ner view.
+    """
+    log("RECIPE: Starting recipe recon.ner_correct", locals())
+    if spacy_model.startswith("blank:"):
+        nlp = spacy.blank(spacy_model.replace("blank:", ""))
+    else:
+        nlp = spacy.load(spacy_model)
+    labels = label  # comma-separated list or path to text file
+    if not labels:
+        labels = get_labels_from_ner(nlp)
+        if not labels:
+            msg.fail("No --label argument set and no labels found in model", exits=1)
+        msg.text(f"Using {len(labels)} labels from model: {', '.join(labels)}")
+    log(f"RECIPE: Annotating with {len(labels)} labels", labels)
+    # stream = get_stream(source, None, loader, rehash=True, dedup=True, input_key="text")
+
+    stream = get_stream_from_hardest_examples(hardest_examples)
+    stream = add_tokens(nlp, stream)  # add "tokens" key to the tasks
+
+    html_template = """
+
+    """
+
+    return {
+        "view_id": "blocks",
+        "dataset": dataset,
+        "stream": stream,
+        "exclude": exclude,
+        "config": {
+            "lang": nlp.lang,
+            "labels": labels,
+            "exclude_by": "input",
+            "blocks": [
+                {"view_id": "ner_manual"},
+                {"view_id": "html", "field_rows": 3, "html_template": html_template},
+            ],
+        },
+    }
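+
+
+# Usage sketch (illustrative only -- assumes `hardest_examples` was produced
+# elsewhere, e.g. by recon's insights utilities, and that "ner_dataset" and
+# "./my_model" are placeholder names):
+#
+#     from recon.prodigy.recipes import ner_correct
+#
+#     components = ner_correct("ner_dataset", "./my_model", hardest_examples)
+#     # `components` is the blocks config dict returned above; Prodigy renders
+#     # it when the recipe is started via the prodigy CLI or prodigy.serve.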
diff --git a/recon/prodigy/templates/prediction_error.html b/recon/prodigy/templates/prediction_error.html
new file mode 100644
index 0000000..dbe0f95
--- /dev/null
+++ b/recon/prodigy/templates/prediction_error.html
@@ -0,0 +1,3 @@
+{% for pe in prediction_errors %}
+<li>{{ pe.text }}: {{ pe.true_label }} vs. {{ pe.pred_label }}</li>
+{% endfor %}
\ No newline at end of file
diff --git a/recon/prodigy_recipes.py b/recon/prodigy_recipes.py
new file mode 100644
index 0000000..6b6549e
--- /dev/null
+++ b/recon/prodigy_recipes.py
@@ -0,0 +1,220 @@
+import copy
+import random
+from collections import Counter, defaultdict
+from typing import Dict, Iterable, List, Optional, Union
+
+import catalogue
+import prodigy
+import spacy
+import srsly
+from prodigy.components.db import connect
+from prodigy.components.loaders import get_stream
+from prodigy.components.preprocess import add_tokens
+from prodigy.recipes.ner import get_labels_from_ner
+from prodigy.util import (
+    INPUT_HASH_ATTR,
+    TASK_HASH_ATTR,
+    get_labels,
+    log,
+    set_hashes,
+    split_string,
+)
+from recon.constants import NONE
+from recon.types import HardestExample
+from recon.validation import remove_overlapping_entities
+from wasabi import msg
+
+
+def make_span_hash(span: Dict):
+    return f"{span['text']}|||{span['start']}|||{span['end']}|||{span['label']}"
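+
+
+# Worked example of the hash format (illustrative values):
+#     make_span_hash({"text": "Apple", "start": 0, "end": 5, "label": "ORG"})
+# returns "Apple|||0|||5|||ORG", which lets gold and predicted spans be
+# compared by value in the stream builder below.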
+
+
+def get_stream_from_hardest_examples(nlp, hardest_examples: List[HardestExample]):
+    for he in hardest_examples:
+        task = he.example.dict()
+        task = list(add_tokens(nlp, [task]))[0]
+        gold_span_hashes = {make_span_hash(span): span for span in task['spans']}
+        predicted_example = None
+        for pe in he.prediction_errors:
+            for e in pe.examples:
+                if e.predicted.text == he.example.text:
+                    predicted_example = e.predicted
+        if predicted_example:
+            pthtml = []
+            predicted_example_task = predicted_example.dict()
+            predicted_example_task = list(add_tokens(nlp, [predicted_example_task]))[0]
+
+            for token in predicted_example_task['tokens']:
+                pthtml.append(f'{token["text"]} ')
+
+            pred_spans = predicted_example_task['spans']
+            pred_span_hashes = [make_span_hash(span) for span in predicted_example_task['spans']]
+
+            for gold_span_hash, gold_span in gold_span_hashes.items():
+                if gold_span_hash not in pred_span_hashes:
+                    # Gold span the model missed entirely: carry it through with
+                    # the NONE sentinel so it renders with the "missing" style.
+                    gold_span_miss = copy.deepcopy(gold_span)
+                    gold_span_miss['label'] = NONE
+                    pred_spans.append(gold_span_miss)
+
+            pred_spans = remove_overlapping_entities(sorted(pred_spans, key=lambda s: s["start"]))
+            # pred_spans = sorted(pred_spans, key=lambda s: s["start"])
+
+            # Walk the spans from last to first so inserting markup does not
+            # shift the token offsets of spans not yet processed.
+            i = len(pred_spans) - 1
+            while i >= 0:
+                span = pred_spans[i]
+                span_hash = make_span_hash(span)
+                if span_hash in gold_span_hashes:
+                    labelColorClass = "recon-pred-success-mark"
+                elif span['label'] == NONE:
+                    labelColorClass = "recon-pred-missing-mark"
+                else:
+                    labelColorClass = "recon-pred-error-mark"
+
+                pthtml = pthtml[:span['token_end'] + 1] + [f'<span class="recon-pred-label">{span["label"]}</span></span>'] + pthtml[span['token_end'] + 1:]
+                pthtml = pthtml[:span['token_start']] + [f'<span class="recon-pred {labelColorClass}">'] + pthtml[span['token_start']:]
+                i -= 1
+
+            task['html'] = f"""
+            <h2 class="recon-title">Recon Prediction Errors</h2>
+            <div class="recon-subtitle">
+                The following text shows the errors your model made on this example inline.
+                Correct the annotations above based on how well your model performed on this example.
+                If your labeling is correct, you might need to add more training examples in this domain.
+            </div>
+            <div class="recon-container">
+                {''.join(pthtml)}
+            </div>
+            """
+        task['prediction_errors'] = [pe.dict() for pe in he.prediction_errors]
+        yield task
+
+
+@prodigy.recipe(
+    "recon.ner_correct",
+    dataset=("Dataset to save annotations to", "positional", None, str),
+    spacy_model=("Base model or blank:lang (e.g. blank:en) for blank model", "positional", None, str),
+    hardest_examples=("Data to annotate (file path or '-' to read from standard input)", "positional", None, str),
+    loader=("Loader (guessed from file extension if not set)", "option", "lo", str),
+    label=("Comma-separated label(s) to annotate or text file with one label per line", "option", "l", get_labels),
+    exclude=("Comma-separated list of dataset IDs whose annotations to exclude", "option", "e", split_string),
+)
+def ner_correct(
+    dataset: str,
+    spacy_model: str,
+    hardest_examples: List[HardestExample],
+    loader: Optional[str] = None,
+    label: Optional[List[str]] = None,
+    exclude: Optional[List[str]] = None,
+):
+    """
+    Stream a List of `recon.types.HardestExample` instances to prodigy
+    for review/correction. Uses the Prodigy blocks interface to display
+    prediction error information along with the ner view.
+    """
+    log("RECIPE: Starting recipe recon.ner_correct", locals())
+    if spacy_model.startswith("blank:"):
+        nlp = spacy.blank(spacy_model.replace("blank:", ""))
+    else:
+        nlp = spacy.load(spacy_model)
+    labels = label  # comma-separated list or path to text file
+    if not labels:
+        labels = get_labels_from_ner(nlp)
+        if not labels:
+            msg.fail("No --label argument set and no labels found in model", exits=1)
+        msg.text(f"Using {len(labels)} labels from model: {', '.join(labels)}")
+    log(f"RECIPE: Annotating with {len(labels)} labels", labels)
+    # stream = get_stream(source, None, loader, rehash=True, dedup=True, input_key="text")
+
+    stream = get_stream_from_hardest_examples(nlp, hardest_examples)
+    stream = add_tokens(nlp, stream)  # add "tokens" key to the tasks
+
+    table_template = """
+    <p>All prediction errors for this example.</p>
+    <table>
+        <thead>
+            <tr><th>Text</th><th>True Label</th><th>Pred Label</th></tr>
+        </thead>
+        <tbody>
+            {{#prediction_errors}}
+            <tr><td>{{text}}</td><td>{{true_label}}</td><td>{{pred_label}}</td></tr>
+            {{/prediction_errors}}
+        </tbody>
+    </table>
+    """
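+
+    # The blocks UI below stacks three views per task: the editable ner_manual
+    # view, a plain html block that renders each task's "html" key (built in
+    # the stream above), and this table template, whose {{#prediction_errors}}
+    # Mustache section iterates the task's "prediction_errors" list.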
+
+    return {
+        "view_id": "blocks",
+        "dataset": dataset,
+        "stream": stream,
+        "exclude": exclude,
+        "config": {
+            "lang": nlp.lang,
+            "labels": labels,
+            "exclude_by": "input",
+            "blocks": [
+                {"view_id": "ner_manual"},
+                {"view_id": "html"},
+                {"view_id": "html", "field_rows": 3, "html_template": table_template},
+            ],
+            "global_css": """
+                .recon-title {
+                    text-align: left;
+                    margin-top: -80px;
+                }
+                .recon-subtitle {
+                    text-align: left;
+                    margin-top: -80px;
+                    white-space: normal;
+                }
+                .recon-container {
+                    text-align: left;
+                    line-height: 2;
+                    margin-top: -80px;
+                    white-space: pre-line;
+                }
+                .recon-pred {
+                    color: inherit;
+                    margin: 0 0.15em;
+                    display: inline;
+                    padding: 0.25em 0.4em;
+                    font-weight: bold;
+                    line-height: 1;
+                    -webkit-box-decoration-break: clone;
+                }
+                .recon-pred-success-mark {
+                    background: #00cc66;
+                }
+                .recon-pred-error-mark {
+                    background: #fc7683;
+                }
+                .recon-pred-missing-mark {
+                    background: #84b4c4;
+                }
+                .recon-pred-label {
+                    color: #583fcf;
+                    font-size: 0.675em;
+                    font-weight: bold;
+                    font-family: "Roboto Condensed", "Arial Narrow", sans-serif;
+                    margin-left: 8px;
+                    text-transform: uppercase;
+                    vertical-align: middle;
+                }
+            """,
+        },
+    }
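+
+
+# Shape note (illustrative only): each entry in a task's "prediction_errors"
+# is a serialized recon prediction error; the table template above assumes it
+# carries at least these keys:
+#
+#     {"text": "Apple", "true_label": "ORG", "pred_label": "NOT_LABELED"}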