diff --git a/.gitignore b/.gitignore
index 14b513d..3a48715 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,7 +10,7 @@ dist
 site
 .coverage
 htmlcov
-examples/cs_data_test.ipynb
+examples/cs_data*.ipynb
 examples/skills_data*.ipynb
 **/.ipynb_checkpoints
 **/.pytest_cache
diff --git a/docs/overrides/main.html b/docs/overrides/main.html
new file mode 100644
index 0000000..37f949a
--- /dev/null
+++ b/docs/overrides/main.html
@@ -0,0 +1,32 @@
+{% extends "base.html" %}
+
+
+{% block extrahead %}
+
+
+  {% if page and page.meta and page.meta.title %}
+
+  {% elif page and page.title and not page.is_homepage %}
+
+  {% else %}
+
+  {% endif %}
+
+
+
+
+
+
+
+  {% if page and page.meta and page.meta.title %}
+
+  {% elif page and page.title and not page.is_homepage %}
+
+  {% else %}
+
+  {% endif %}
+
+
+
+
+{% endblock %}
\ No newline at end of file
diff --git a/examples/2.0_operations.ipynb b/examples/2.0_operations.ipynb
index 8ee3914..ea8f392 100644
--- a/examples/2.0_operations.ipynb
+++ b/examples/2.0_operations.ipynb
@@ -20,7 +20,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": 2,
    "metadata": {},
    "outputs": [
     {
@@ -62,17 +62,9 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": 3,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "The history saving thread hit an unexpected error (DatabaseError('database disk image is malformed',)).History will not be written to the database.\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "import copy\n",
     "from pprint import pprint\n",
@@ -101,24 +93,26 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 4,
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "OrderedDict([('rename_labels', ),\n",
+       "OrderedDict([('rename_labels', ),\n",
        "             ('fix_annotations',\n",
-       "              ),\n",
+       "              ),\n",
+       "             ('strip_annotations',\n",
+       "              ),\n",
        "             ('fix_tokenization_and_spacing',\n",
-       "              ),\n",
-       "             ('add_tokens', ),\n",
-       "             ('upcase_labels', ),\n",
+       "              ),\n",
+       "             ('add_tokens', ),\n",
+       "             ('upcase_labels', ),\n",
        "             ('filter_overlaps',\n",
-       "              )])"
+       "              )])"
       ]
      },
-     "execution_count": 2,
+     "execution_count": 4,
      "metadata": {},
      "output_type": "execute_result"
     }
    ],
    "source": [
@@ -129,7 +123,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 5,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -143,27 +137,27 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 6,
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
-       ""
+       "[]"
      ]
     },
-     "execution_count": 4,
+     "execution_count": 6,
      "metadata": {},
      "output_type": "execute_result"
     }
    ],
    "source": [
-    "fix_tokenization_and_spacing"
+    "fix_tokenization_and_spacing.pre"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 7,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -172,16 +166,16 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 12,
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "[]"
+       "[OperationState(name='fix_tokenization_and_spacing', batch=False, args=[], kwargs={}, status=, ts=datetime.datetime(2020, 4, 15, 12, 45, 53, 973976), examples_added=0, examples_removed=0, examples_changed=0, transformations=[])]"
      ]
     },
-     "execution_count": 6,
+     "execution_count": 12,
      "metadata": {},
      "output_type": "execute_result"
     }
    ],
    "source": [
@@ -192,16 +186,49 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 9,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "recon.v1.spacy \n",
+ "dict_keys([1196406995714445350, 1975826118262010565, 2208035061381934376, 1473978031617727200, 106335390642092038, 500601151867807297, 440832847742549781, 706460250993890974, 1115219443209707540, 1949519895665837812, 935299783623680754, 2163307383884749416, 403532604991006089, 440153399298716595, 2165419558369486553, 709069367829232207, 1164040277870844331, 1001820030585517207, 1520504525369429473, 1143349283593732378, 772909226859630929, 2186683589276212162, 1439246817467080996, 646205256634488830, 391054967234424417, 1069073211256351973, 1631116673795041161, 1853287415515327485, 1312008281566775001, 1850119932490491100, 2248886061479488760, 1392728192912329521, 1436426349557743173, 502406100547088088, 1280147727121249306, 1761578724182773272, 1539877667996420026, 31812586317021947, 1301439328206016166, 384384728390475553, 769173937489945255, 2288129318297025329, 1242041334273589045, 228629945225414711, 1098488717506082294, 282492229720054716, 380310756566959607, 1924565152306637272, 376684774014885964, 2304838004040464281, 103529869190846028, 1988264538232162424, 1217488862723372728, 1083422037887251876, 552794652063744234, 1991149014422385820, 1804412783475653815, 133559379123924178, 630284206817307426, 1128000514323767167, 533842790538656977, 1783218316333109775, 44327064662381491, 590102417292709537, 1328189097588562120, 1471741353434857039, 1395082368088541827, 2020857152038051220, 1515538653321808267, 2909653324313924, 1142573215943909380, 1348307833267793209, 1424011768500181956, 2183248012838820930, 802489642660632463, 1860169144513739726, 435413361717193409, 606453138874978909, 1580995208213325947, 185933608830505150, 940811210756437104, 220881773549728518, 284313563161040515, 2085769956561363246, 1765438062702589944, 1090253479434315853, 1256422104425125455, 2233857738026586523, 553002533348785626, 603521638447945005, 626642907161470527, 322564103237696284, 205403610460430855, 999682930106592044, 87620785170841013, 1398471957127296361, 1869198887032894170, 1384780638327235009, 2270385796303996220, 1209709110032413943, 176860716858529675, 1241234826570026284, 200276835658424828, 2026883444237770066, 1495583689335203153, 1995249007653203494])\n", + "recon.v1.spacy \n", + "dict_keys([1700529846029343910, 1442725918004243918, 1615314956284386067, 2292918397458640594, 936174842952091019, 1416495214555748544, 1470704679076406034, 416936013533939646, 102976739755988675, 328680753745285627, 1821910156486108522, 461746148819085540, 1316906644314897545, 480190416079738296, 832098851410421791, 2077939734169057805, 499328831071479316, 2095072877850172291, 1709939900352147844, 59689883316342687, 398891871713859070, 1627140869294591678, 310028046539117771, 1862984016626204930, 515420873264934608, 198250555976187584, 1674255301498414544, 2207684247935563565, 39776830632428320, 967824718652465343, 1717026350721933397, 546998027537219826, 327268413308598687, 1070784568438262286, 1191963312554833233, 283812158810932089, 1600875993361830372, 827400666469217446, 833324646184074177, 141634872600563553, 868035574974736930, 710098414400391884, 2214163065793631862, 1960217053886089487, 413733258957413667, 574101210792060298, 128584703305360341, 1229099326012548773, 152681987489235101, 1203120265425271744, 1210731897437002224, 1708971935759553743, 1082408608430937670, 1587584412664023332, 1288634136647605965, 1559990870835400553, 727956066452843778, 411021709465579531, 712627598286018477, 745384867524222597, 754114973842762845, 1250213657925559709, 1060453348244319975, 1909832121214376215, 
142120949158349466, 2127651655156229446, 79606520449659671, 1201338150913378032, 883878955589494298, 1476557811892588710, 75316633542116270, 1116956250434915335, 764974836186572417, 1468246759446807473, 2112306765897272231, 2193368486573054880, 1829112984191980085, 1250450046952360826, 1634755933488298800, 1602844456389187640, 863282461614727969, 2062363953964471355, 1809453615774236214, 235578488679793708, 1501219999274819459, 1842768724828632935, 1617507151150411878, 1694156056316982174, 1416023339598949511, 1712731616395876010, 1796257318551092056, 189257385865390542, 2118382680464518054, 234444508079345572, 1335636780743303396, 687445399934304164, 1068178454877698177, 935890424003835932, 1520292643375647582, 1492858771249319517, 57842038071894975, 1425233669197905973, 1605065109739879831, 3032625132909434, 1059081288433944966, 1384780638327235009, 2270385796303996220, 1209709110032413943, 176860716858529675, 1241234826570026284])\n", + "recon.v1.spacy \n", + "dict_keys([1367973413700956518, 2029160499876569560, 1607811024963276245, 489279466128008280, 1415495710082308731, 2255889546845585358, 530003636929667978, 1835925174907410759, 1041359394819446102, 861318538180278693, 583083590587376325, 105665357616818193, 1559441277659343605, 993295836099910212, 697271199090081718, 1305803167822276105, 1501964998077241217, 1983023552699738193, 131613684448168329, 1499203256596183286, 1913310551616924816, 805448350284417457, 1542760328774421006, 647271954588451270, 986139625200759836, 350516643100518264, 1076604938120777862, 2262171711793193596, 122887233971179058, 1061866805993556892, 1787769822543877726, 43799087513494440, 216202532025443194, 102509625659950763, 1005529631534994242, 1541086328188921886, 1425565207195997244, 455885524989955913, 389017461808763837, 435254507021985749, 1485603496641415809, 2260855597686721254, 2191677898864625980, 132765292597080193, 1432390711654223784, 650836541803377396, 1524192431628044429, 1598691181628504187, 1957197914153068094, 926372278586414398, 1242024804406792689, 600297596622501654, 1612138149287315133, 1452781579416233781, 991080122349479433, 611942000611137234, 896909537791521797, 1294785973424865062, 367331397880364536, 578218064576825402, 179908893921063561, 1800926667122710822, 276853797656703690, 1581244857368993418, 1978762199449998997, 226275956992877221, 1283816090281245606, 1581031192357922359, 1921383998015280342, 1242326926490269961, 999978380577142972, 6802486761397437, 1199172863138105586, 1972779204238002222, 1082706630997395533, 1737571819712865172, 58089873693865820, 717335074039265265, 1940859792894958469, 2247458127698710144, 1824060206597705914, 1070135875738874749, 2149059511880502851, 1323080716800541486, 1501301622567185490, 1698455696044143889, 749766778723534242, 4193966078694182, 750447672412097139, 961663597043273946, 1017453684411282930, 2157922784386036570, 1474272284242173255, 334900450619140659, 752986196974172619, 210946257216823596])\n" + ] + } + ], "source": [ "corpus.apply_(\"fix_tokenization_and_spacing\")" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[OperationState(name='fix_tokenization_and_spacing', batch=False, args=[], kwargs={}, status=, ts=datetime.datetime(2020, 4, 15, 12, 45, 53, 973976), examples_added=0, examples_removed=0, examples_changed=0, transformations=[])]" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "corpus._train.operations" + ] 
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
diff --git a/mkdocs.yml b/mkdocs.yml
index 4b5599f..03e3a7a 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -9,6 +9,7 @@ theme:
     accent: 'deep-orange'
   logo: 'img/drone-white.svg'
   favicon: 'img/favicon.ico'
+  custom_dir: docs/overrides
 
 repo_name: 'microsoft/reconner'
 repo_url: 'https://github.com/microsoft/reconner'
diff --git a/recon/corrections.py b/recon/corrections.py
index 9645fc1..f0eb5a0 100644
--- a/recon/corrections.py
+++ b/recon/corrections.py
@@ -74,3 +74,29 @@ def fix_annotations(
             print(line)
 
     return example
+
+
+@operation("strip_annotations")
+def strip_annotations(
+    example: Example, strip_chars: List[str] = [".", "!", "?", "-", ":", " "]
+) -> Example:
+    """Strip punctuation and spaces from the start and end of annotation spans.
+    These characters are almost always a mistake and will confuse a model.
+
+    Args:
+        example (Example): Input Example
+        strip_chars (List[str], optional): Characters to strip from span boundaries.
+
+    Returns:
+        Example: Example with stripped spans
+    """
+    for s in example.spans:
+        # Strip from the front, then from the back of each span. The emptiness
+        # checks guard against spans that consist entirely of strip_chars.
+        while s.text and s.text[0] in strip_chars:
+            s.text = s.text[1:]
+            s.start += 1
+        while s.text and s.text[-1] in strip_chars:
+            s.text = s.text[:-1]
+            s.end -= 1
+    return example
diff --git a/recon/preprocess.py b/recon/preprocess.py
new file mode 100644
index 0000000..a6225e5
--- /dev/null
+++ b/recon/preprocess.py
@@ -0,0 +1,41 @@
+from typing import Any, Dict, Iterator, List
+
+import spacy
+from spacy.language import Language
+from spacy.tokens import Doc
+
+from .types import Example
+
+
+class PreProcessor(object):
+    name = "recon.v1.preprocess"
+
+    def __init__(self):
+        super().__init__()
+        self._cache: Dict[Any, Any] = {}
+
+    def __call__(self, data: List[Example]) -> Iterator[Any]:
+        raise NotImplementedError
+
+
+class SpacyPreProcessor(PreProcessor):
+    name = "recon.v1.spacy"
+
+    def __init__(self, nlp: Language):
+        super().__init__()
+        self._nlp = nlp
+
+    def __call__(self, data: List[Example]) -> Iterator[Any]:
+        # Key the cache on the example text so membership checks, inserts and
+        # lookups all agree, and materialize the "seen" list before the cache
+        # is updated below.
+        unseen_texts = [e.text for e in data if e.text not in self._cache]
+        seen_texts = [(i, e.text) for i, e in enumerate(data) if e.text in self._cache]
+
+        docs = list(self._nlp.pipe(unseen_texts))
+        for doc in docs:
+            self._cache[doc.text] = doc
+        # Re-insert cached Docs at their original positions.
+        for idx, st in seen_texts:
+            docs.insert(idx, self._cache[st])
+
+        return (doc for doc in docs)
diff --git a/recon/tokenization.py b/recon/tokenization.py
index 2b6e666..017e418 100644
--- a/recon/tokenization.py
+++ b/recon/tokenization.py
@@ -1,11 +1,12 @@
 import copy
 from collections import defaultdict
-from typing import Any, Dict, List, Set, Tuple
+from typing import Any, Dict, List, Set, Tuple, Union
 
 from spacy.language import Language
 
 from .dataset import Dataset
 from .operations import op_iter, operation
+from .preprocess import SpacyPreProcessor
 from .registry import tokenizers
 from .types import (
     Example,
@@ -18,14 +19,17 @@ from .types import (
 )
 
 
-@operation("fix_tokenization_and_spacing", batch=True)
+tokenizer = tokenizers.get("default")
+nlp = tokenizer()
+spacy_pre_processor = SpacyPreProcessor(nlp)
+
+
+@operation("fix_tokenization_and_spacing", pre=[spacy_pre_processor])
 def fix_tokenization_and_spacing(
-    examples: List[Example],
+    example: Example,
     *,
-    callbacks: TransformationCallbacks = TransformationCallbacks(),
-    tokenizer: str = "default",
-    verbose: bool = False,
-) -> List[Example]:
+    preprocessed_outputs: Dict[str, Any] = {}
+) -> Union[Example, None]:
     """Fix tokenization and spacing issues where there are annotation spans
     that don't fall on a token boundary. This can happen if annotations are done
     at the character level, not the token level. Often, when scraping web text it's easy to
@@ -39,133 +43,114 @@
     Returns:
         List[Example]: List of examples with fixed tokenization
     """
-    fixed_examples = []
-    tokenization_errors: List[Tuple[Example, Span]] = []
-    unfixable_examples: Set[str] = set()
-    nlp = tokenizers.get("default")()
-    texts = (e.text for e in examples)
+    doc = preprocessed_outputs["recon.v1.spacy"]
 
-    with nlp.disable_pipes(*nlp.pipe_names):
-        for (orig_example_hash, example), doc in zip(op_iter(examples), nlp.pipe(texts)):
-            doc = nlp.make_doc(example.text)
+    tokens = []
+    token_starts = {}
+    token_ends = {}
 
-            tokens = []
-            token_starts = {}
-            token_ends = {}
+    for t in doc:
+        start = t.idx
+        end = t.idx + len(t)
+        tokens.append(Token(text=t.text, start=start, end=end, id=t.i))
+        token_starts[start] = t
+        token_ends[end] = t
 
-            for t in doc:
-                start = t.idx
-                end = t.idx + len(t)
-                tokens.append(Token(text=t.text, start=start, end=end, id=t.i))
-                token_starts[start] = t
-                token_ends[end] = t
+    spans_to_increment: Dict[int, int] = defaultdict(int)
+    for span_i, span in enumerate(example.spans):
+        if span.start in token_starts and span.end in token_ends:
+            # Aligns to token boundaries, nothing to change here
+            continue
 
-            spans_to_increment: Dict[int, int] = defaultdict(int)
-            for span_i, span in enumerate(example.spans):
-                if span.start in token_starts and span.end in token_ends:
-                    # Aligns to token boundaries, nothing to change here
-                    continue
-
-                if span.start in token_starts and span.end not in token_ends:
-                    # Span start aligns to token_start but end doesn't
-                    # e.g. [customer][PERSONTYPE]s but should be annotated as [customers][PERSONTYPE]
-                    tokenization_errors.append((example, span))
-                    # print("BAD END")
-                    if span.end + 1 in token_ends:
-                        # Likely off by 1 annotation
-                        # e.g. [customer][PERSONTYPE]s but should be annotated as [customers][PERSONTYPE]
-                        span.end += 1
-                        span.text = example.text[span.start : span.end]
-                        # print("SPAN CORRECTED OFF BY 1", example.text, span)
-                    elif span.end - 1 in token_ends:
-                        span.end -= 1
-                        span.text = example.text[span.start : span.end]
-                    else:
-                        # Likely bad tokenization
-                        # e.g. [Quadling][GPE]Country should be split to [Quadling][GPE] Country
-                        for j in range(span_i + 1, len(example.spans)):
-                            spans_to_increment[j] += 1
-                        fe_text = example.text
-
-                        split_start = span.start
-                        if (
-                            len(spans_to_increment) > 1
-                            and span_i != list(spans_to_increment.keys())[0]
-                        ):
-                            split_start += spans_to_increment.get(span_i, 0)
-                        split_end = span.end
-                        if (
-                            len(spans_to_increment) > 1
-                            and span_i != list(spans_to_increment.keys())[0]
-                        ):
-                            split_end += spans_to_increment.get(span_i, 0)
-                        new_text = f"{fe_text[:split_start]}{span.text} {fe_text[split_end:]}"
-
-                        example.text = new_text
-
-                elif span.start not in token_starts and span.end in token_ends:
-                    # Bad tokenization
-                    # e.g. with[Raymond][PERSON] but text should be split to with [Raymond][PERSON]
-                    # print("BAD START", span.text)
-                    tokenization_errors.append((example, span))
-                    for j in range(span_i, len(example.spans)):
-                        spans_to_increment[j] += 1
-
-                    fe_text = example.text
-
-                    split_start = span.start
-                    if len(spans_to_increment) > 1 and span_i != list(spans_to_increment.keys())[0]:
-                        split_start += spans_to_increment.get(span_i, 0)
-                    split_end = span.end
-                    if len(spans_to_increment) > 1 and span_i != list(spans_to_increment.keys())[0]:
-                        split_end += spans_to_increment.get(span_i, 0)
-
-                    new_text = f"{fe_text[:split_start]} {span.text}{fe_text[split_end:]}"
-                    example.text = new_text
-                else:
-                    # Something is super fucked up.
-                    # print("SPAN CORRECTED OFF BY 1 unfixable", example.text, span)
-                    before = span.start
-                    after = span.end
-                    tokenization_errors.append((example, span))
-
-                    # if (before >= 0 and after < len(span.text) and span[before] not in token_starts and span[before] != ' ' and span[after] not in token_ends and span[after] != ' '):
-                    #     fe_text = example.text
-                    #     new_text = f"{fe_text[:span.start]} {span.text}{fe_text[span.end:]}"
-                    #     spans_to_increment[span_i] += 1
-                    #     for j in range(span_i + 1, len(example.spans)):
-                    #         spans_to_increment[j] += 2
-                    # else:
-                    unfixable_examples.add(example.text)
-                    break
-
-            # Increment the start and end characters for each span
-            for span_i, count in spans_to_increment.items():
-                example.spans[span_i].start += count
-                example.spans[span_i].end += count
-
-            if example.text not in unfixable_examples:
-                callbacks.track_example(orig_example_hash, example)
-                fixed_examples.append(example)
+        if span.start in token_starts and span.end not in token_ends:
+            # Span start aligns to token_start but end doesn't
+            # e.g. [customer][PERSONTYPE]s but should be annotated as [customers][PERSONTYPE]
+            # tokenization_errors.append((example, span))
+            # print("BAD END")
+            if span.end + 1 in token_ends:
+                # Likely off by 1 annotation
+                # e.g. [customer][PERSONTYPE]s but should be annotated as [customers][PERSONTYPE]
+                span.end += 1
+                span.text = example.text[span.start : span.end]
+                # print("SPAN CORRECTED OFF BY 1", example.text, span)
+            elif span.end - 1 in token_ends:
+                span.end -= 1
+                span.text = example.text[span.start : span.end]
             else:
-                callbacks.remove_example(orig_example_hash)
+                # Likely bad tokenization
+                # e.g. [Quadling][GPE]Country should be split to [Quadling][GPE] Country
+                for j in range(span_i + 1, len(example.spans)):
+                    spans_to_increment[j] += 1
+                fe_text = example.text
 
-    if tokenization_errors and verbose:
-        print(f"Found {len(tokenization_errors)} tokenization errors.")
-        print(f"Found {len(unfixable_examples)} unfixable tokenization errors.")
+                split_start = span.start
+                if (
+                    len(spans_to_increment) > 1
+                    and span_i != list(spans_to_increment.keys())[0]
+                ):
+                    split_start += spans_to_increment.get(span_i, 0)
+                split_end = span.end
+                if (
+                    len(spans_to_increment) > 1
+                    and span_i != list(spans_to_increment.keys())[0]
+                ):
+                    split_end += spans_to_increment.get(span_i, 0)
+                new_text = f"{fe_text[:split_start]}{span.text} {fe_text[split_end:]}"
 
-    return fixed_examples
+                example.text = new_text
+
+        elif span.start not in token_starts and span.end in token_ends:
+            # Bad tokenization
+            # e.g. with[Raymond][PERSON] but text should be split to with [Raymond][PERSON]
+            # print("BAD START", span.text)
+            # tokenization_errors.append((example, span))
+            for j in range(span_i, len(example.spans)):
+                spans_to_increment[j] += 1
+
+            fe_text = example.text
+
+            split_start = span.start
+            if len(spans_to_increment) > 1 and span_i != list(spans_to_increment.keys())[0]:
+                split_start += spans_to_increment.get(span_i, 0)
+            split_end = span.end
+            if len(spans_to_increment) > 1 and span_i != list(spans_to_increment.keys())[0]:
+                split_end += spans_to_increment.get(span_i, 0)
+
+            new_text = f"{fe_text[:split_start]} {span.text}{fe_text[split_end:]}"
+            example.text = new_text
+        else:
+            # Something is super fucked up.
+            # print("SPAN CORRECTED OFF BY 1 unfixable", example.text, span)
+            before = span.start
+            after = span.end
+            # tokenization_errors.append((example, span))
+
+            # if (before >= 0 and after < len(span.text) and span[before] not in token_starts and span[before] != ' ' and span[after] not in token_ends and span[after] != ' '):
+            #     fe_text = example.text
+            #     new_text = f"{fe_text[:span.start]} {span.text}{fe_text[span.end:]}"
+            #     spans_to_increment[span_i] += 1
+            #     for j in range(span_i + 1, len(example.spans)):
+            #         spans_to_increment[j] += 2
+            # else:
+            #     unfixable_examples.add(example.text)
+            #     break
+            return None
+
+    # Increment the start and end characters for each span
+    for span_i, count in spans_to_increment.items():
+        example.spans[span_i].start += count
+        example.spans[span_i].end += count
+
+    return example
 
 
-@operation("add_tokens", batch=True)
+@operation("add_tokens", pre=[spacy_pre_processor])
 def add_tokens(
-    examples: List[Example],
+    example: Example,
     *,
-    callbacks: TransformationCallbacks = TransformationCallbacks(),
-    tokenizer: str = "default",
-    verbose: bool = False,
-) -> List[Example]:
+    preprocessed_outputs: Dict[str, Any]
+) -> Union[Example, None]:
     """Add tokens to each Example
 
     Args:
@@ -176,44 +161,27 @@
     Returns:
         List[Example]: List of examples with tokens
     """
-    output_examples: List[Example] = []
-    tokenization_errors: List[Tuple[Example, Span]] = []
-    unfixable_examples: Set[str] = set()
-    nlp = tokenizers.get(tokenizer)()
-    texts = (e.text for e in examples)
+    doc = preprocessed_outputs["recon.v1.spacy"]
 
-    with nlp.disable_pipes(*nlp.pipe_names):
-        for (orig_example_hash, example), doc in zip(op_iter(examples), nlp.pipe(texts)):
-            tokens = []
-            token_starts = {}
-            token_ends = {}
+    tokens = []
+    token_starts = {}
+    token_ends = {}
 
-            for t in doc:
-                start = t.idx
-                end = t.idx + len(t)
-                tokens.append(Token(text=t.text, start=start, end=end, id=t.i))
-                token_starts[start] = t
-                token_ends[end] = t
+    for t in doc:
+        start = t.idx
+        end = t.idx + len(t)
+        tokens.append(Token(text=t.text, start=start, end=end, id=t.i))
+        token_starts[start] = t
+        token_ends[end] = t
 
-            example.tokens = tokens
+    example.tokens = tokens
 
-            for span in example.spans:
-                if span.start in token_starts and span.end in token_ends:
-                    span.token_start = token_starts[span.start].i
-                    span.token_end = token_ends[span.end].i
+    for span in example.spans:
+        if span.start in token_starts and span.end in token_ends:
+            span.token_start = token_starts[span.start].i
+            span.token_end = token_ends[span.end].i
 
-                if span.token_start is None or span.token_end is None:
-                    tokenization_errors.append((example, span))
-                    unfixable_examples.add(example.text)
+        if span.token_start is None or span.token_end is None:
+            return None
 
-            if example.text not in unfixable_examples:
-                callbacks.track_example(orig_example_hash, example)
-                output_examples.append(example)
-            else:
-                callbacks.remove_example(orig_example_hash)
-
-    if verbose:
-        print(f"Found {len(tokenization_errors)} tokenization errors.")
-        print(f"Found {len(unfixable_examples)} unfixable examples.")
-
-    return output_examples
+    return example
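
Below is a short, hedged usage sketch (not part of the diff) of how the pieces added above are expected to fit together. It follows the recon API as exercised in examples/2.0_operations.ipynb (loading a Corpus and calling apply_); the import paths, the dataset path, and the corpus.train attribute are illustrative assumptions rather than content taken from this diff.

# Sketch only: assumes recon.corpus.Corpus and Corpus.from_disk exist as in the
# published recon package; the data path below is hypothetical.
from recon.corpus import Corpus
from recon.preprocess import SpacyPreProcessor
from recon.registry import tokenizers

corpus = Corpus.from_disk("./examples/data")  # hypothetical path

# Operations registered with pre=[spacy_pre_processor] receive a cached spaCy
# Doc through preprocessed_outputs["recon.v1.spacy"], so each text is
# tokenized once and the Doc is reused across operations.
corpus.apply_("fix_tokenization_and_spacing")
corpus.apply_("add_tokens")
corpus.apply_("strip_annotations")  # new correction added in recon/corrections.py

# The pre-processor can also be called directly on a list of Examples.
nlp = tokenizers.get("default")()
pre = SpacyPreProcessor(nlp)
docs = list(pre(corpus.train))  # assumes corpus.train yields List[Example]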