diff --git a/.gitignore b/.gitignore
index 14b513d..3a48715 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,7 +10,7 @@ dist
site
.coverage
htmlcov
-examples/cs_data_test.ipynb
+examples/cs_data*.ipynb
examples/skills_data*.ipynb
**/.ipynb_checkpoints
**/.pytest_cache
diff --git a/docs/overrides/main.html b/docs/overrides/main.html
new file mode 100644
index 0000000..37f949a
--- /dev/null
+++ b/docs/overrides/main.html
@@ -0,0 +1,32 @@
+{% extends "base.html" %}
+
+
+{% block extrahead %}
+
+
+ {% if page and page.meta and page.meta.title %}
+
+ {% elif page and page.title and not page.is_homepage %}
+
+ {% else %}
+
+ {% endif %}
+
+
+
+
+
+
+
+ {% if page and page.meta and page.meta.title %}
+
+ {% elif page and page.title and not page.is_homepage %}
+
+ {% else %}
+
+ {% endif %}
+
+
+
+
+{% endblock %}
\ No newline at end of file
diff --git a/examples/2.0_operations.ipynb b/examples/2.0_operations.ipynb
index 8ee3914..ea8f392 100644
--- a/examples/2.0_operations.ipynb
+++ b/examples/2.0_operations.ipynb
@@ -20,7 +20,7 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": 2,
"metadata": {},
"outputs": [
{
@@ -62,17 +62,9 @@
},
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": 3,
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "The history saving thread hit an unexpected error (DatabaseError('database disk image is malformed',)).History will not be written to the database.\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"import copy\n",
"from pprint import pprint\n",
@@ -101,24 +93,26 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "OrderedDict([('rename_labels', ),\n",
+ "OrderedDict([('rename_labels', ),\n",
" ('fix_annotations',\n",
- " ),\n",
+ " ),\n",
+ " ('strip_annotations',\n",
+ " ),\n",
" ('fix_tokenization_and_spacing',\n",
- " ),\n",
- " ('add_tokens', ),\n",
- " ('upcase_labels', ),\n",
+ " ),\n",
+ " ('add_tokens', ),\n",
+ " ('upcase_labels', ),\n",
" ('filter_overlaps',\n",
- " )])"
+ " )])"
]
},
- "execution_count": 2,
+ "execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
@@ -129,7 +123,7 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
@@ -143,27 +137,27 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- ""
+ "[]"
]
},
- "execution_count": 4,
+ "execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "fix_tokenization_and_spacing"
+ "fix_tokenization_and_spacing.pre"
]
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
@@ -172,16 +166,16 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "[]"
+ "[OperationState(name='fix_tokenization_and_spacing', batch=False, args=[], kwargs={}, status=, ts=datetime.datetime(2020, 4, 15, 12, 45, 53, 973976), examples_added=0, examples_removed=0, examples_changed=0, transformations=[])]"
]
},
- "execution_count": 6,
+ "execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
@@ -192,16 +186,49 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 9,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "recon.v1.spacy \n",
+ "dict_keys([1196406995714445350, 1975826118262010565, 2208035061381934376, 1473978031617727200, 106335390642092038, 500601151867807297, 440832847742549781, 706460250993890974, 1115219443209707540, 1949519895665837812, 935299783623680754, 2163307383884749416, 403532604991006089, 440153399298716595, 2165419558369486553, 709069367829232207, 1164040277870844331, 1001820030585517207, 1520504525369429473, 1143349283593732378, 772909226859630929, 2186683589276212162, 1439246817467080996, 646205256634488830, 391054967234424417, 1069073211256351973, 1631116673795041161, 1853287415515327485, 1312008281566775001, 1850119932490491100, 2248886061479488760, 1392728192912329521, 1436426349557743173, 502406100547088088, 1280147727121249306, 1761578724182773272, 1539877667996420026, 31812586317021947, 1301439328206016166, 384384728390475553, 769173937489945255, 2288129318297025329, 1242041334273589045, 228629945225414711, 1098488717506082294, 282492229720054716, 380310756566959607, 1924565152306637272, 376684774014885964, 2304838004040464281, 103529869190846028, 1988264538232162424, 1217488862723372728, 1083422037887251876, 552794652063744234, 1991149014422385820, 1804412783475653815, 133559379123924178, 630284206817307426, 1128000514323767167, 533842790538656977, 1783218316333109775, 44327064662381491, 590102417292709537, 1328189097588562120, 1471741353434857039, 1395082368088541827, 2020857152038051220, 1515538653321808267, 2909653324313924, 1142573215943909380, 1348307833267793209, 1424011768500181956, 2183248012838820930, 802489642660632463, 1860169144513739726, 435413361717193409, 606453138874978909, 1580995208213325947, 185933608830505150, 940811210756437104, 220881773549728518, 284313563161040515, 2085769956561363246, 1765438062702589944, 1090253479434315853, 1256422104425125455, 2233857738026586523, 553002533348785626, 603521638447945005, 626642907161470527, 322564103237696284, 205403610460430855, 999682930106592044, 87620785170841013, 1398471957127296361, 1869198887032894170, 1384780638327235009, 2270385796303996220, 1209709110032413943, 176860716858529675, 1241234826570026284, 200276835658424828, 2026883444237770066, 1495583689335203153, 1995249007653203494])\n",
+ "recon.v1.spacy \n",
+ "dict_keys([1700529846029343910, 1442725918004243918, 1615314956284386067, 2292918397458640594, 936174842952091019, 1416495214555748544, 1470704679076406034, 416936013533939646, 102976739755988675, 328680753745285627, 1821910156486108522, 461746148819085540, 1316906644314897545, 480190416079738296, 832098851410421791, 2077939734169057805, 499328831071479316, 2095072877850172291, 1709939900352147844, 59689883316342687, 398891871713859070, 1627140869294591678, 310028046539117771, 1862984016626204930, 515420873264934608, 198250555976187584, 1674255301498414544, 2207684247935563565, 39776830632428320, 967824718652465343, 1717026350721933397, 546998027537219826, 327268413308598687, 1070784568438262286, 1191963312554833233, 283812158810932089, 1600875993361830372, 827400666469217446, 833324646184074177, 141634872600563553, 868035574974736930, 710098414400391884, 2214163065793631862, 1960217053886089487, 413733258957413667, 574101210792060298, 128584703305360341, 1229099326012548773, 152681987489235101, 1203120265425271744, 1210731897437002224, 1708971935759553743, 1082408608430937670, 1587584412664023332, 1288634136647605965, 1559990870835400553, 727956066452843778, 411021709465579531, 712627598286018477, 745384867524222597, 754114973842762845, 1250213657925559709, 1060453348244319975, 1909832121214376215, 142120949158349466, 2127651655156229446, 79606520449659671, 1201338150913378032, 883878955589494298, 1476557811892588710, 75316633542116270, 1116956250434915335, 764974836186572417, 1468246759446807473, 2112306765897272231, 2193368486573054880, 1829112984191980085, 1250450046952360826, 1634755933488298800, 1602844456389187640, 863282461614727969, 2062363953964471355, 1809453615774236214, 235578488679793708, 1501219999274819459, 1842768724828632935, 1617507151150411878, 1694156056316982174, 1416023339598949511, 1712731616395876010, 1796257318551092056, 189257385865390542, 2118382680464518054, 234444508079345572, 1335636780743303396, 687445399934304164, 1068178454877698177, 935890424003835932, 1520292643375647582, 1492858771249319517, 57842038071894975, 1425233669197905973, 1605065109739879831, 3032625132909434, 1059081288433944966, 1384780638327235009, 2270385796303996220, 1209709110032413943, 176860716858529675, 1241234826570026284])\n",
+ "recon.v1.spacy \n",
+ "dict_keys([1367973413700956518, 2029160499876569560, 1607811024963276245, 489279466128008280, 1415495710082308731, 2255889546845585358, 530003636929667978, 1835925174907410759, 1041359394819446102, 861318538180278693, 583083590587376325, 105665357616818193, 1559441277659343605, 993295836099910212, 697271199090081718, 1305803167822276105, 1501964998077241217, 1983023552699738193, 131613684448168329, 1499203256596183286, 1913310551616924816, 805448350284417457, 1542760328774421006, 647271954588451270, 986139625200759836, 350516643100518264, 1076604938120777862, 2262171711793193596, 122887233971179058, 1061866805993556892, 1787769822543877726, 43799087513494440, 216202532025443194, 102509625659950763, 1005529631534994242, 1541086328188921886, 1425565207195997244, 455885524989955913, 389017461808763837, 435254507021985749, 1485603496641415809, 2260855597686721254, 2191677898864625980, 132765292597080193, 1432390711654223784, 650836541803377396, 1524192431628044429, 1598691181628504187, 1957197914153068094, 926372278586414398, 1242024804406792689, 600297596622501654, 1612138149287315133, 1452781579416233781, 991080122349479433, 611942000611137234, 896909537791521797, 1294785973424865062, 367331397880364536, 578218064576825402, 179908893921063561, 1800926667122710822, 276853797656703690, 1581244857368993418, 1978762199449998997, 226275956992877221, 1283816090281245606, 1581031192357922359, 1921383998015280342, 1242326926490269961, 999978380577142972, 6802486761397437, 1199172863138105586, 1972779204238002222, 1082706630997395533, 1737571819712865172, 58089873693865820, 717335074039265265, 1940859792894958469, 2247458127698710144, 1824060206597705914, 1070135875738874749, 2149059511880502851, 1323080716800541486, 1501301622567185490, 1698455696044143889, 749766778723534242, 4193966078694182, 750447672412097139, 961663597043273946, 1017453684411282930, 2157922784386036570, 1474272284242173255, 334900450619140659, 752986196974172619, 210946257216823596])\n"
+ ]
+ }
+ ],
"source": [
"corpus.apply_(\"fix_tokenization_and_spacing\")"
]
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[OperationState(name='fix_tokenization_and_spacing', batch=False, args=[], kwargs={}, status=, ts=datetime.datetime(2020, 4, 15, 12, 45, 53, 973976), examples_added=0, examples_removed=0, examples_changed=0, transformations=[])]"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "corpus._train.operations"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
diff --git a/mkdocs.yml b/mkdocs.yml
index 4b5599f..03e3a7a 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -9,6 +9,7 @@ theme:
accent: 'deep-orange'
logo: 'img/drone-white.svg'
favicon: 'img/favicon.ico'
+ custom_dir: docs/overrides
repo_name: 'microsoft/reconner'
repo_url: 'https://github.com/microsoft/reconner'
diff --git a/recon/corrections.py b/recon/corrections.py
index 9645fc1..f0eb5a0 100644
--- a/recon/corrections.py
+++ b/recon/corrections.py
@@ -74,3 +74,34 @@ def fix_annotations(
print(line)
return example
+
+
+@operation("strip_annotations")
+def strip_annotations(example: Example, strip_chars: List[str] = [".", "!", "?", "-", ":", " "]) -> Example:
+ """Strip punctuation and spaces from start and end of annotations.
+ These characters are almost always a mistake and will confuse a model
+
+ Args:
+ example (Example): Input Example
+ strip_chars (List[str], optional): Characters to strip.
+
+ Returns:
+ Example: Example with stripped spans
+ """
+
+ for s in example.spans:
+ for ch in strip_chars:
+ if s.text.startswith(ch):
+ ch = s.text[0]
+
+ while ch in strip_chars:
+ s.text = s.text[1:]
+ s.start += 1
+ ch = s.text[0]
+ elif s.text.endswith(ch):
+ ch = s.text[-1]
+ while ch in strip_chars:
+ s.text = s.text[:-1]
+ ch = s.text[-1]
+ s.end -= 1
+ return example
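
For reference, a minimal before/after sketch of what strip_annotations is intended to do to a single span. The Example and Span constructor arguments are assumptions about recon.types, and it further assumes the decorated function can still be called directly on one Example; applying it by name through Corpus.apply_("strip_annotations"), as in the notebook, is the path the diff itself demonstrates.

# Illustrative only; Example/Span construction and the direct call are assumptions.
from recon.corrections import strip_annotations
from recon.types import Example, Span

example = Example(
    text="Visit Seattle! today",
    spans=[Span(text="Seattle! ", start=6, end=15, label="GPE")],
)

fixed = strip_annotations(example)
span = fixed.spans[0]
print(span.text, span.start, span.end)  # Seattle 6 13 -- the trailing "! " is stripped
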
diff --git a/recon/preprocess.py b/recon/preprocess.py
new file mode 100644
index 0000000..a6225e5
--- /dev/null
+++ b/recon/preprocess.py
@@ -0,0 +1,37 @@
+from typing import Any, Dict, Iterator, List
+import spacy
+from spacy.language import Language
+from spacy.tokens import Doc
+
+from .types import Example
+
+
+class PreProcessor(object):
+ name = "recon.v1.preprocess"
+ def __init__(self):
+ super().__init__()
+        self._cache: Dict[Any, Any] = {}
+
+ def __call__(self, data: List[Example]) -> Iterator[Any]:
+ raise NotImplementedError
+
+
+class SpacyPreProcessor(PreProcessor):
+ name = "recon.v1.spacy"
+
+ def __init__(self, nlp: Language):
+ super().__init__()
+ self._nlp = nlp
+
+ def __call__(self, data: List[Example]) -> Iterator[Any]:
+        # Cache Docs by text: only pipe texts that haven't been processed before,
+        # then re-insert cached Docs at their original positions below.
+        unseen_texts = (e.text for e in data if e.text not in self._cache)
+        seen_texts = [(i, e.text) for i, e in enumerate(data) if e.text in self._cache]
+
+ docs = list(self._nlp.pipe(unseen_texts))
+ for doc in docs:
+ self._cache[doc.text] = doc
+ for idx, st in seen_texts:
+ docs.insert(idx, self._cache[st])
+
+ return (doc for doc in docs)
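
A short sketch of using SpacyPreProcessor on its own, outside the operation framework. spacy.blank("en") stands in for whatever tokenizers.get("default")() returns, and Example(text=..., spans=[]) is an assumption about the recon.types constructor.

# Standalone usage sketch; the names noted above are assumptions.
import spacy

from recon.preprocess import SpacyPreProcessor
from recon.types import Example

nlp = spacy.blank("en")                  # stand-in for the registered default tokenizer
pre = SpacyPreProcessor(nlp)

data = [
    Example(text="Apple opened a store in Seattle", spans=[]),
    Example(text="Quadling Country", spans=[]),
]

docs = list(pre(data))                   # one Doc per Example, cached by text
print([len(doc) for doc in docs])        # a second call with the same data hits the cache
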
diff --git a/recon/tokenization.py b/recon/tokenization.py
index 2b6e666..017e418 100644
--- a/recon/tokenization.py
+++ b/recon/tokenization.py
@@ -1,11 +1,12 @@
import copy
from collections import defaultdict
-from typing import Any, Dict, List, Set, Tuple
+from typing import Any, Dict, List, Set, Tuple, Union
from spacy.language import Language
from .dataset import Dataset
from .operations import op_iter, operation
+from .preprocess import SpacyPreProcessor
from .registry import tokenizers
from .types import (
Example,
@@ -18,14 +19,17 @@ from .types import (
)
-@operation("fix_tokenization_and_spacing", batch=True)
+tokenizer = tokenizers.get("default")
+nlp = tokenizer()
+spacy_pre_processor = SpacyPreProcessor(nlp)
+
+
+@operation("fix_tokenization_and_spacing", pre=[spacy_pre_processor])
def fix_tokenization_and_spacing(
- examples: List[Example],
+ example: Example,
*,
- callbacks: TransformationCallbacks = TransformationCallbacks(),
- tokenizer: str = "default",
- verbose: bool = False,
-) -> List[Example]:
+ preprocessed_outputs: Dict[str, Any] = {}
+) -> Union[Example, None]:
"""Fix tokenization and spacing issues where there are annotation spans that
don't fall on a token boundary. This can happen if annotations are done at the
character level, not the token level. Often, when scraping web text it's easy to
@@ -39,133 +43,114 @@ def fix_tokenization_and_spacing(
Returns:
-        List[Example]: List of examples with fixed tokenization
+        Union[Example, None]: The example with fixed tokenization, or None if it can't be fixed
"""
- fixed_examples = []
- tokenization_errors: List[Tuple[Example, Span]] = []
- unfixable_examples: Set[str] = set()
- nlp = tokenizers.get("default")()
- texts = (e.text for e in examples)
+ doc = preprocessed_outputs["recon.v1.spacy"]
- with nlp.disable_pipes(*nlp.pipe_names):
- for (orig_example_hash, example), doc in zip(op_iter(examples), nlp.pipe(texts)):
- doc = nlp.make_doc(example.text)
+ tokens = []
+ token_starts = {}
+ token_ends = {}
- tokens = []
- token_starts = {}
- token_ends = {}
+ for t in doc:
+ start = t.idx
+ end = t.idx + len(t)
+ tokens.append(Token(text=t.text, start=start, end=end, id=t.i))
+ token_starts[start] = t
+ token_ends[end] = t
- for t in doc:
- start = t.idx
- end = t.idx + len(t)
- tokens.append(Token(text=t.text, start=start, end=end, id=t.i))
- token_starts[start] = t
- token_ends[end] = t
+ spans_to_increment: Dict[int, int] = defaultdict(int)
+ for span_i, span in enumerate(example.spans):
+ if span.start in token_starts and span.end in token_ends:
+ # Aligns to token boundaries, nothing to change here
+ continue
- spans_to_increment: Dict[int, int] = defaultdict(int)
- for span_i, span in enumerate(example.spans):
- if span.start in token_starts and span.end in token_ends:
- # Aligns to token boundaries, nothing to change here
- continue
-
- if span.start in token_starts and span.end not in token_ends:
- # Span start aligns to token_start but end doesn't
- # e.g. [customer][PERSONTYPE]s but should be annotated as [customers][PERSONTYPE]
- tokenization_errors.append((example, span))
- # print("BAD END")
- if span.end + 1 in token_ends:
- # Likely off by 1 annotation
- # e.g. [customer][PERSONTYPE]s but should be annotated as [customers][PERSONTYPE]
- span.end += 1
- span.text = example.text[span.start : span.end]
- # print("SPAN CORRECTED OFF BY 1", example.text, span)
- elif span.end - 1 in token_ends:
- span.end -= 1
- span.text = example.text[span.start : span.end]
- else:
- # Likely bad tokenization
- # e.g. [Quadling][GPE]Country should be split to [Quadling][GPE] Country
- for j in range(span_i + 1, len(example.spans)):
- spans_to_increment[j] += 1
- fe_text = example.text
-
- split_start = span.start
- if (
- len(spans_to_increment) > 1
- and span_i != list(spans_to_increment.keys())[0]
- ):
- split_start += spans_to_increment.get(span_i, 0)
- split_end = span.end
- if (
- len(spans_to_increment) > 1
- and span_i != list(spans_to_increment.keys())[0]
- ):
- split_end += spans_to_increment.get(span_i, 0)
- new_text = f"{fe_text[:split_start]}{span.text} {fe_text[split_end:]}"
-
- example.text = new_text
-
- elif span.start not in token_starts and span.end in token_ends:
- # Bad tokenization
- # e.g. with[Raymond][PERSON] but text should be split to with [Raymond][PERSON]
- # print("BAD START", span.text)
- tokenization_errors.append((example, span))
- for j in range(span_i, len(example.spans)):
- spans_to_increment[j] += 1
-
- fe_text = example.text
-
- split_start = span.start
- if len(spans_to_increment) > 1 and span_i != list(spans_to_increment.keys())[0]:
- split_start += spans_to_increment.get(span_i, 0)
- split_end = span.end
- if len(spans_to_increment) > 1 and span_i != list(spans_to_increment.keys())[0]:
- split_end += spans_to_increment.get(span_i, 0)
-
- new_text = f"{fe_text[:split_start]} {span.text}{fe_text[split_end:]}"
- example.text = new_text
- else:
- # Something is super fucked up.
- # print("SPAN CORRECTED OFF BY 1 unfixable", example.text, span)
- before = span.start
- after = span.end
- tokenization_errors.append((example, span))
-
- # if (before >= 0 and after < len(span.text) and span[before] not in token_starts and span[before] != ' ' and span[after] not in token_ends and span[after] != ' '):
- # fe_text = example.text
- # new_text = f"{fe_text[:span.start]} {span.text}{fe_text[span.end:]}"
- # spans_to_increment[span_i] += 1
- # for j in range(span_i + 1, len(example.spans)):
- # spans_to_increment[j] += 2
- # else:
- unfixable_examples.add(example.text)
- break
-
- # Increment the start and end characters for each span
- for span_i, count in spans_to_increment.items():
- example.spans[span_i].start += count
- example.spans[span_i].end += count
-
- if example.text not in unfixable_examples:
- callbacks.track_example(orig_example_hash, example)
- fixed_examples.append(example)
+ if span.start in token_starts and span.end not in token_ends:
+ # Span start aligns to token_start but end doesn't
+ # e.g. [customer][PERSONTYPE]s but should be annotated as [customers][PERSONTYPE]
+ # tokenization_errors.append((example, span))
+ # print("BAD END")
+ if span.end + 1 in token_ends:
+ # Likely off by 1 annotation
+ # e.g. [customer][PERSONTYPE]s but should be annotated as [customers][PERSONTYPE]
+ span.end += 1
+ span.text = example.text[span.start : span.end]
+ # print("SPAN CORRECTED OFF BY 1", example.text, span)
+ elif span.end - 1 in token_ends:
+ span.end -= 1
+ span.text = example.text[span.start : span.end]
else:
- callbacks.remove_example(orig_example_hash)
+ # Likely bad tokenization
+ # e.g. [Quadling][GPE]Country should be split to [Quadling][GPE] Country
+ for j in range(span_i + 1, len(example.spans)):
+ spans_to_increment[j] += 1
+ fe_text = example.text
- if tokenization_errors and verbose:
- print(f"Found {len(tokenization_errors)} tokenization errors.")
- print(f"Found {len(unfixable_examples)} unfixable tokenization errors.")
+ split_start = span.start
+ if (
+ len(spans_to_increment) > 1
+ and span_i != list(spans_to_increment.keys())[0]
+ ):
+ split_start += spans_to_increment.get(span_i, 0)
+ split_end = span.end
+ if (
+ len(spans_to_increment) > 1
+ and span_i != list(spans_to_increment.keys())[0]
+ ):
+ split_end += spans_to_increment.get(span_i, 0)
+ new_text = f"{fe_text[:split_start]}{span.text} {fe_text[split_end:]}"
- return fixed_examples
+ example.text = new_text
+
+ elif span.start not in token_starts and span.end in token_ends:
+ # Bad tokenization
+ # e.g. with[Raymond][PERSON] but text should be split to with [Raymond][PERSON]
+ # print("BAD START", span.text)
+ # tokenization_errors.append((example, span))
+ for j in range(span_i, len(example.spans)):
+ spans_to_increment[j] += 1
+
+ fe_text = example.text
+
+ split_start = span.start
+ if len(spans_to_increment) > 1 and span_i != list(spans_to_increment.keys())[0]:
+ split_start += spans_to_increment.get(span_i, 0)
+ split_end = span.end
+ if len(spans_to_increment) > 1 and span_i != list(spans_to_increment.keys())[0]:
+ split_end += spans_to_increment.get(span_i, 0)
+
+ new_text = f"{fe_text[:split_start]} {span.text}{fe_text[split_end:]}"
+ example.text = new_text
+ else:
+ # Something is super fucked up.
+ # print("SPAN CORRECTED OFF BY 1 unfixable", example.text, span)
+ before = span.start
+ after = span.end
+ # tokenization_errors.append((example, span))
+
+ # if (before >= 0 and after < len(span.text) and span[before] not in token_starts and span[before] != ' ' and span[after] not in token_ends and span[after] != ' '):
+ # fe_text = example.text
+ # new_text = f"{fe_text[:span.start]} {span.text}{fe_text[span.end:]}"
+ # spans_to_increment[span_i] += 1
+ # for j in range(span_i + 1, len(example.spans)):
+ # spans_to_increment[j] += 2
+ # else:
+ # unfixable_examples.add(example.text)
+ # break
+ return None
+
+ # Increment the start and end characters for each span
+ for span_i, count in spans_to_increment.items():
+ example.spans[span_i].start += count
+ example.spans[span_i].end += count
+
+ return example
-@operation("add_tokens", batch=True)
+@operation("add_tokens", pre=[spacy_pre_processor])
def add_tokens(
- examples: List[Example],
+ example: Example,
*,
- callbacks: TransformationCallbacks = TransformationCallbacks(),
- tokenizer: str = "default",
- verbose: bool = False,
-) -> List[Example]:
+ preprocessed_outputs: Dict[str, Any]
+) -> Union[Example, None]:
"""Add tokens to each Example
Args:
@@ -176,44 +161,27 @@ def add_tokens(
Returns:
-        List[Example]: List of examples with tokens
+        Union[Example, None]: The example with tokens, or None if a span can't be aligned to token boundaries
"""
- output_examples: List[Example] = []
- tokenization_errors: List[Tuple[Example, Span]] = []
- unfixable_examples: Set[str] = set()
- nlp = tokenizers.get(tokenizer)()
- texts = (e.text for e in examples)
+ doc = preprocessed_outputs["recon.v1.spacy"]
- with nlp.disable_pipes(*nlp.pipe_names):
- for (orig_example_hash, example), doc in zip(op_iter(examples), nlp.pipe(texts)):
- tokens = []
- token_starts = {}
- token_ends = {}
+ tokens = []
+ token_starts = {}
+ token_ends = {}
- for t in doc:
- start = t.idx
- end = t.idx + len(t)
- tokens.append(Token(text=t.text, start=start, end=end, id=t.i))
- token_starts[start] = t
- token_ends[end] = t
+ for t in doc:
+ start = t.idx
+ end = t.idx + len(t)
+ tokens.append(Token(text=t.text, start=start, end=end, id=t.i))
+ token_starts[start] = t
+ token_ends[end] = t
- example.tokens = tokens
+ example.tokens = tokens
- for span in example.spans:
- if span.start in token_starts and span.end in token_ends:
- span.token_start = token_starts[span.start].i
- span.token_end = token_ends[span.end].i
+ for span in example.spans:
+ if span.start in token_starts and span.end in token_ends:
+ span.token_start = token_starts[span.start].i
+ span.token_end = token_ends[span.end].i
- if span.token_start is None or span.token_end is None:
- tokenization_errors.append((example, span))
- unfixable_examples.add(example.text)
+ if span.token_start is None or span.token_end is None:
+ return None
- if example.text not in unfixable_examples:
- callbacks.track_example(orig_example_hash, example)
- output_examples.append(example)
- else:
- callbacks.remove_example(orig_example_hash)
-
- if verbose:
- print(f"Found {len(tokenization_errors)} tokenization errors.")
- print(f"Found {len(unfixable_examples)} unfixable examples.")
-
- return output_examples
+ return example
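
To make the new operation contract concrete, here is a sketch of a custom per-example operation written against the same pattern this diff establishes: a pre-processor declared via pre=[...], the shared spaCy Doc pulled from preprocessed_outputs under the pre-processor's name, and None returned to drop an example. The operation name count_tokens and its behaviour are invented for illustration.

# Hypothetical operation following the pattern above; count_tokens is made up.
from typing import Any, Dict, Union

from recon.operations import operation
from recon.tokenization import spacy_pre_processor
from recon.types import Example


@operation("count_tokens", pre=[spacy_pre_processor])
def count_tokens(
    example: Example, *, preprocessed_outputs: Dict[str, Any] = {}
) -> Union[Example, None]:
    """Drop examples whose text produces no tokens at all."""
    doc = preprocessed_outputs["recon.v1.spacy"]  # Doc computed once by the pre-processor
    if len(doc) == 0:
        return None  # returning None removes the example from the dataset
    return example

Registered this way, it would be applied with corpus.apply_("count_tokens"), exactly like the built-in operations in the notebook above.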