Adding caching to preprocessing and adding docs

Kabir Khan 2020-04-19 14:21:29 -07:00
Parent c44a6e23a0
Commit 00f2a64733
7 changed files with 285 additions and 189 deletions

2
.gitignore vendored
View file

@@ -10,7 +10,7 @@ dist
site
.coverage
htmlcov
examples/cs_data_test.ipynb
examples/cs_data*.ipynb
examples/skills_data*.ipynb
**/.ipynb_checkpoints
**/.pytest_cache

32
docs/overrides/main.html Normal file
View file

@@ -0,0 +1,32 @@
{% extends "base.html" %}
<!-- Extrahead -->
{% block extrahead %}
<meta property="og:url" content="{{ page.canonical_url }}">
{% if page and page.meta and page.meta.title %}
<meta property="og:title" content="{{ page.meta.title }}">
{% elif page and page.title and not page.is_homepage %}
<meta property="og:title" content="{{ page.title }} - {{ config.site_name }}">
{% else %}
<meta property="og:title" content="{{ config.site_name }}">
{% endif %}
<meta property="og:description" content="{{ config.site_description }}">
<meta property="og:image" content="https://microsoft.github.io/reconner/img/recon.svg">
<meta property="og:image:alt" content="Recon NER">
<meta name="twitter:card" content="summary_large_image">
<meta name="twitter:site" content="@kabir_khan14">
<meta name="twitter:creator" content="@kabir_khan14">
{% if page and page.meta and page.meta.title %}
<meta property="twitter:title" content="{{ page.meta.title }}">
{% elif page and page.title and not page.is_homepage %}
<meta property="twitter:title" content="{{ page.title }} - {{ config.site_name }}">
{% else %}
<meta property="twitter:title" content="{{ config.site_name }}">
{% endif %}
<meta name="twitter:description" content="{{ config.site_description }}">
<meta name="twitter:image" content="https://microsoft.github.io/reconner/img/recon.svg">
<meta name="twitter:image:alt" content="Recon NER">
{% endblock %}

View file

@@ -20,7 +20,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 2,
"metadata": {},
"outputs": [
{
@@ -62,17 +62,9 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The history saving thread hit an unexpected error (DatabaseError('database disk image is malformed',)).History will not be written to the database.\n"
]
}
],
"outputs": [],
"source": [
"import copy\n",
"from pprint import pprint\n",
@@ -101,24 +93,26 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"OrderedDict([('rename_labels', <recon.operations.Operation at 0x7fdfe9f2a6a0>),\n",
"OrderedDict([('rename_labels', <recon.operations.Operation at 0x7f2d7478ea58>),\n",
" ('fix_annotations',\n",
" <recon.operations.Operation at 0x7fdfe9f2a6d8>),\n",
" <recon.operations.Operation at 0x7f2d7478ea90>),\n",
" ('strip_annotations',\n",
" <recon.operations.Operation at 0x7f2d7478eac8>),\n",
" ('fix_tokenization_and_spacing',\n",
" <recon.operations.Operation at 0x7fdfe040e320>),\n",
" ('add_tokens', <recon.operations.Operation at 0x7fdfe040e358>),\n",
" ('upcase_labels', <recon.operations.Operation at 0x7fdfe040e4a8>),\n",
" <recon.operations.Operation at 0x7f2d26453320>),\n",
" ('add_tokens', <recon.operations.Operation at 0x7f2d26453198>),\n",
" ('upcase_labels', <recon.operations.Operation at 0x7f2d264533c8>),\n",
" ('filter_overlaps',\n",
" <recon.operations.Operation at 0x7fdfe040e4e0>)])"
" <recon.operations.Operation at 0x7f2d26453358>)])"
]
},
"execution_count": 2,
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
@@ -129,7 +123,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
@@ -143,27 +137,27 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<recon.operations.Operation at 0x7fdfe040e320>"
"[<recon.preprocess.SpacyPreProcessor at 0x7f2d264532b0>]"
]
},
"execution_count": 4,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"fix_tokenization_and_spacing"
"fix_tokenization_and_spacing.pre"
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
@@ -172,16 +166,16 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[]"
"[OperationState(name='fix_tokenization_and_spacing', batch=False, args=[], kwargs={}, status=<OperationStatus.COMPLETED: 'COMPLETED'>, ts=datetime.datetime(2020, 4, 15, 12, 45, 53, 973976), examples_added=0, examples_removed=0, examples_changed=0, transformations=[])]"
]
},
"execution_count": 6,
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
@@ -192,16 +186,49 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 9,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"recon.v1.spacy <generator object Language.pipe at 0x7f2d21a9ee60>\n",
"dict_keys([1196406995714445350, 1975826118262010565, 2208035061381934376, 1473978031617727200, 106335390642092038, 500601151867807297, 440832847742549781, 706460250993890974, 1115219443209707540, 1949519895665837812, 935299783623680754, 2163307383884749416, 403532604991006089, 440153399298716595, 2165419558369486553, 709069367829232207, 1164040277870844331, 1001820030585517207, 1520504525369429473, 1143349283593732378, 772909226859630929, 2186683589276212162, 1439246817467080996, 646205256634488830, 391054967234424417, 1069073211256351973, 1631116673795041161, 1853287415515327485, 1312008281566775001, 1850119932490491100, 2248886061479488760, 1392728192912329521, 1436426349557743173, 502406100547088088, 1280147727121249306, 1761578724182773272, 1539877667996420026, 31812586317021947, 1301439328206016166, 384384728390475553, 769173937489945255, 2288129318297025329, 1242041334273589045, 228629945225414711, 1098488717506082294, 282492229720054716, 380310756566959607, 1924565152306637272, 376684774014885964, 2304838004040464281, 103529869190846028, 1988264538232162424, 1217488862723372728, 1083422037887251876, 552794652063744234, 1991149014422385820, 1804412783475653815, 133559379123924178, 630284206817307426, 1128000514323767167, 533842790538656977, 1783218316333109775, 44327064662381491, 590102417292709537, 1328189097588562120, 1471741353434857039, 1395082368088541827, 2020857152038051220, 1515538653321808267, 2909653324313924, 1142573215943909380, 1348307833267793209, 1424011768500181956, 2183248012838820930, 802489642660632463, 1860169144513739726, 435413361717193409, 606453138874978909, 1580995208213325947, 185933608830505150, 940811210756437104, 220881773549728518, 284313563161040515, 2085769956561363246, 1765438062702589944, 1090253479434315853, 1256422104425125455, 2233857738026586523, 553002533348785626, 603521638447945005, 626642907161470527, 322564103237696284, 205403610460430855, 999682930106592044, 87620785170841013, 1398471957127296361, 1869198887032894170, 1384780638327235009, 2270385796303996220, 1209709110032413943, 176860716858529675, 1241234826570026284, 200276835658424828, 2026883444237770066, 1495583689335203153, 1995249007653203494])\n",
"recon.v1.spacy <generator object Language.pipe at 0x7f2d217df5c8>\n",
"dict_keys([1700529846029343910, 1442725918004243918, 1615314956284386067, 2292918397458640594, 936174842952091019, 1416495214555748544, 1470704679076406034, 416936013533939646, 102976739755988675, 328680753745285627, 1821910156486108522, 461746148819085540, 1316906644314897545, 480190416079738296, 832098851410421791, 2077939734169057805, 499328831071479316, 2095072877850172291, 1709939900352147844, 59689883316342687, 398891871713859070, 1627140869294591678, 310028046539117771, 1862984016626204930, 515420873264934608, 198250555976187584, 1674255301498414544, 2207684247935563565, 39776830632428320, 967824718652465343, 1717026350721933397, 546998027537219826, 327268413308598687, 1070784568438262286, 1191963312554833233, 283812158810932089, 1600875993361830372, 827400666469217446, 833324646184074177, 141634872600563553, 868035574974736930, 710098414400391884, 2214163065793631862, 1960217053886089487, 413733258957413667, 574101210792060298, 128584703305360341, 1229099326012548773, 152681987489235101, 1203120265425271744, 1210731897437002224, 1708971935759553743, 1082408608430937670, 1587584412664023332, 1288634136647605965, 1559990870835400553, 727956066452843778, 411021709465579531, 712627598286018477, 745384867524222597, 754114973842762845, 1250213657925559709, 1060453348244319975, 1909832121214376215, 142120949158349466, 2127651655156229446, 79606520449659671, 1201338150913378032, 883878955589494298, 1476557811892588710, 75316633542116270, 1116956250434915335, 764974836186572417, 1468246759446807473, 2112306765897272231, 2193368486573054880, 1829112984191980085, 1250450046952360826, 1634755933488298800, 1602844456389187640, 863282461614727969, 2062363953964471355, 1809453615774236214, 235578488679793708, 1501219999274819459, 1842768724828632935, 1617507151150411878, 1694156056316982174, 1416023339598949511, 1712731616395876010, 1796257318551092056, 189257385865390542, 2118382680464518054, 234444508079345572, 1335636780743303396, 687445399934304164, 1068178454877698177, 935890424003835932, 1520292643375647582, 1492858771249319517, 57842038071894975, 1425233669197905973, 1605065109739879831, 3032625132909434, 1059081288433944966, 1384780638327235009, 2270385796303996220, 1209709110032413943, 176860716858529675, 1241234826570026284])\n",
"recon.v1.spacy <generator object Language.pipe at 0x7f2d21a9ef68>\n",
"dict_keys([1367973413700956518, 2029160499876569560, 1607811024963276245, 489279466128008280, 1415495710082308731, 2255889546845585358, 530003636929667978, 1835925174907410759, 1041359394819446102, 861318538180278693, 583083590587376325, 105665357616818193, 1559441277659343605, 993295836099910212, 697271199090081718, 1305803167822276105, 1501964998077241217, 1983023552699738193, 131613684448168329, 1499203256596183286, 1913310551616924816, 805448350284417457, 1542760328774421006, 647271954588451270, 986139625200759836, 350516643100518264, 1076604938120777862, 2262171711793193596, 122887233971179058, 1061866805993556892, 1787769822543877726, 43799087513494440, 216202532025443194, 102509625659950763, 1005529631534994242, 1541086328188921886, 1425565207195997244, 455885524989955913, 389017461808763837, 435254507021985749, 1485603496641415809, 2260855597686721254, 2191677898864625980, 132765292597080193, 1432390711654223784, 650836541803377396, 1524192431628044429, 1598691181628504187, 1957197914153068094, 926372278586414398, 1242024804406792689, 600297596622501654, 1612138149287315133, 1452781579416233781, 991080122349479433, 611942000611137234, 896909537791521797, 1294785973424865062, 367331397880364536, 578218064576825402, 179908893921063561, 1800926667122710822, 276853797656703690, 1581244857368993418, 1978762199449998997, 226275956992877221, 1283816090281245606, 1581031192357922359, 1921383998015280342, 1242326926490269961, 999978380577142972, 6802486761397437, 1199172863138105586, 1972779204238002222, 1082706630997395533, 1737571819712865172, 58089873693865820, 717335074039265265, 1940859792894958469, 2247458127698710144, 1824060206597705914, 1070135875738874749, 2149059511880502851, 1323080716800541486, 1501301622567185490, 1698455696044143889, 749766778723534242, 4193966078694182, 750447672412097139, 961663597043273946, 1017453684411282930, 2157922784386036570, 1474272284242173255, 334900450619140659, 752986196974172619, 210946257216823596])\n"
]
}
],
"source": [
"corpus.apply_(\"fix_tokenization_and_spacing\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[OperationState(name='fix_tokenization_and_spacing', batch=False, args=[], kwargs={}, status=<OperationStatus.COMPLETED: 'COMPLETED'>, ts=datetime.datetime(2020, 4, 15, 12, 45, 53, 973976), examples_added=0, examples_removed=0, examples_changed=0, transformations=[])]"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"corpus._train.operations"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [

View file

@@ -9,6 +9,7 @@ theme:
accent: 'deep-orange'
logo: 'img/drone-white.svg'
favicon: 'img/favicon.ico'
custom_dir: docs/overrides
repo_name: 'microsoft/reconner'
repo_url: 'https://github.com/microsoft/reconner'

View file

@@ -74,3 +74,34 @@ def fix_annotations(
print(line)
return example
@operation("strip_annotations")
def strip_annotations(example: Example, strip_chars: List[str] = [".", "!", "?", "-", ":", " "]) -> Example:
"""Strip punctuation and spaces from start and end of annotations.
These characters are almost always a mistake and will confuse a model
Args:
example (Example): Input Example
strip_chars (List[str], optional): Characters to strip.
Returns:
Example: Example with stripped spans
"""
for s in example.spans:
for ch in strip_chars:
if s.text.startswith(ch):
ch = s.text[0]
while ch in strip_chars:
s.text = s.text[1:]
s.start += 1
ch = s.text[0]
elif s.text.endswith(ch):
ch = s.text[-1]
while ch in strip_chars:
s.text = s.text[:-1]
ch = s.text[-1]
s.end -= 1
return example
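
For reference, a minimal sketch of what strip_annotations does to span offsets. The Span and Example classes below are simplified stand-ins for recon.types (only the attributes the operation touches are modeled), and the sentence and offsets are illustrative assumptions:

from dataclasses import dataclass, field
from typing import List

@dataclass
class Span:
    # simplified stand-in for recon.types.Span
    text: str
    start: int
    end: int
    label: str

@dataclass
class Example:
    # simplified stand-in for recon.types.Example
    text: str
    spans: List[Span] = field(default_factory=list)

example = Example(
    text="I work at Microsoft. It is in Redmond.",
    spans=[Span(text="Microsoft. ", start=10, end=21, label="ORG")],
)

# Running this through the registered operation (e.g. via
# corpus.apply_("strip_annotations")) should shrink the span to the bare
# entity: text="Microsoft", start=10, end=19.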

37
recon/preprocess.py Normal file
View file

@@ -0,0 +1,37 @@
from itertools import chain
from typing import Any, Dict, Iterator, List

import spacy
from spacy.language import Language
from spacy.tokens import Doc

from .types import Example


class PreProcessor(object):
    name = "recon.v1.preprocess"

    def __init__(self):
        super().__init__()
        # Cache of previously computed outputs, keyed by example text
        self._cache: Dict[Any, Any] = {}

    def __call__(self, data: List[Example]) -> Iterator[Any]:
        raise NotImplementedError


class SpacyPreProcessor(PreProcessor):
    name = "recon.v1.spacy"

    def __init__(self, nlp: Language):
        super().__init__()
        self._nlp = nlp

    def __call__(self, data: List[Example]) -> Iterator[Any]:
        # Partition examples into cache misses and hits up front so that filling
        # the cache below does not change which examples count as already seen.
        unseen_texts = [e.text for e in data if e.text not in self._cache]
        seen_texts = [(i, e.text) for i, e in enumerate(data) if e.text in self._cache]

        # Only run the spaCy pipeline on texts that haven't been processed yet.
        docs = list(self._nlp.pipe(unseen_texts))
        for doc in docs:
            self._cache[doc.text] = doc

        # Re-insert cached Docs at their original positions (ascending index).
        for idx, st in seen_texts:
            docs.insert(idx, self._cache[st])
        return (doc for doc in docs)
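
A rough usage sketch of the cached pre-processor as reconstructed above. spacy.blank("en") and the _Ex stand-in are illustrative assumptions, not part of the commit; only a .text attribute is needed here:

import spacy

from recon.preprocess import SpacyPreProcessor

class _Ex:
    # minimal stand-in for recon.types.Example; only .text is used in this sketch
    def __init__(self, text: str):
        self.text = text

nlp = spacy.blank("en")  # any spaCy Language object works
pre = SpacyPreProcessor(nlp)
data = [_Ex("I work at Microsoft."), _Ex("Quadling Country is fictional.")]

docs_first = list(pre(data))   # both texts go through nlp.pipe
docs_second = list(pre(data))  # both texts now come straight from the cache
assert [d.text for d in docs_first] == [d.text for d in docs_second]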

View file

@@ -1,11 +1,12 @@
import copy
from collections import defaultdict
from typing import Any, Dict, List, Set, Tuple
from typing import Any, Dict, List, Set, Tuple, Union
from spacy.language import Language
from .dataset import Dataset
from .operations import op_iter, operation
from .preprocess import SpacyPreProcessor
from .registry import tokenizers
from .types import (
Example,
@@ -18,14 +19,17 @@ from .types import (
)
@operation("fix_tokenization_and_spacing", batch=True)
tokenizer = tokenizers.get("default")
nlp = tokenizer()
spacy_pre_processor = SpacyPreProcessor(nlp)
@operation("fix_tokenization_and_spacing", pre=[spacy_pre_processor])
def fix_tokenization_and_spacing(
examples: List[Example],
example: Example,
*,
callbacks: TransformationCallbacks = TransformationCallbacks(),
tokenizer: str = "default",
verbose: bool = False,
) -> List[Example]:
preprocessed_outputs: Dict[str, Any] = {}
) -> Union[Example, None]:
"""Fix tokenization and spacing issues where there are annotation spans that
don't fall on a token boundary. This can happen if annotations are done at the
character level, not the token level. Often, when scraping web text it's easy to
@@ -39,133 +43,114 @@ def fix_tokenization_and_spacing(
Returns:
List[Example]: List of examples with fixed tokenization
"""
fixed_examples = []
tokenization_errors: List[Tuple[Example, Span]] = []
unfixable_examples: Set[str] = set()
nlp = tokenizers.get("default")()
texts = (e.text for e in examples)
doc = preprocessed_outputs["recon.v1.spacy"]
with nlp.disable_pipes(*nlp.pipe_names):
for (orig_example_hash, example), doc in zip(op_iter(examples), nlp.pipe(texts)):
doc = nlp.make_doc(example.text)
tokens = []
token_starts = {}
token_ends = {}
tokens = []
token_starts = {}
token_ends = {}
for t in doc:
start = t.idx
end = t.idx + len(t)
tokens.append(Token(text=t.text, start=start, end=end, id=t.i))
token_starts[start] = t
token_ends[end] = t
for t in doc:
start = t.idx
end = t.idx + len(t)
tokens.append(Token(text=t.text, start=start, end=end, id=t.i))
token_starts[start] = t
token_ends[end] = t
spans_to_increment: Dict[int, int] = defaultdict(int)
for span_i, span in enumerate(example.spans):
if span.start in token_starts and span.end in token_ends:
# Aligns to token boundaries, nothing to change here
continue
spans_to_increment: Dict[int, int] = defaultdict(int)
for span_i, span in enumerate(example.spans):
if span.start in token_starts and span.end in token_ends:
# Aligns to token boundaries, nothing to change here
continue
if span.start in token_starts and span.end not in token_ends:
# Span start aligns to token_start but end doesn't
# e.g. [customer][PERSONTYPE]s but should be annotated as [customers][PERSONTYPE]
tokenization_errors.append((example, span))
# print("BAD END")
if span.end + 1 in token_ends:
# Likely off by 1 annotation
# e.g. [customer][PERSONTYPE]s but should be annotated as [customers][PERSONTYPE]
span.end += 1
span.text = example.text[span.start : span.end]
# print("SPAN CORRECTED OFF BY 1", example.text, span)
elif span.end - 1 in token_ends:
span.end -= 1
span.text = example.text[span.start : span.end]
else:
# Likely bad tokenization
# e.g. [Quadling][GPE]Country should be split to [Quadling][GPE] Country
for j in range(span_i + 1, len(example.spans)):
spans_to_increment[j] += 1
fe_text = example.text
split_start = span.start
if (
len(spans_to_increment) > 1
and span_i != list(spans_to_increment.keys())[0]
):
split_start += spans_to_increment.get(span_i, 0)
split_end = span.end
if (
len(spans_to_increment) > 1
and span_i != list(spans_to_increment.keys())[0]
):
split_end += spans_to_increment.get(span_i, 0)
new_text = f"{fe_text[:split_start]}{span.text} {fe_text[split_end:]}"
example.text = new_text
elif span.start not in token_starts and span.end in token_ends:
# Bad tokenization
# e.g. with[Raymond][PERSON] but text should be split to with [Raymond][PERSON]
# print("BAD START", span.text)
tokenization_errors.append((example, span))
for j in range(span_i, len(example.spans)):
spans_to_increment[j] += 1
fe_text = example.text
split_start = span.start
if len(spans_to_increment) > 1 and span_i != list(spans_to_increment.keys())[0]:
split_start += spans_to_increment.get(span_i, 0)
split_end = span.end
if len(spans_to_increment) > 1 and span_i != list(spans_to_increment.keys())[0]:
split_end += spans_to_increment.get(span_i, 0)
new_text = f"{fe_text[:split_start]} {span.text}{fe_text[split_end:]}"
example.text = new_text
else:
# Something is super fucked up.
# print("SPAN CORRECTED OFF BY 1 unfixable", example.text, span)
before = span.start
after = span.end
tokenization_errors.append((example, span))
# if (before >= 0 and after < len(span.text) and span[before] not in token_starts and span[before] != ' ' and span[after] not in token_ends and span[after] != ' '):
# fe_text = example.text
# new_text = f"{fe_text[:span.start]} {span.text}{fe_text[span.end:]}"
# spans_to_increment[span_i] += 1
# for j in range(span_i + 1, len(example.spans)):
# spans_to_increment[j] += 2
# else:
unfixable_examples.add(example.text)
break
# Increment the start and end characters for each span
for span_i, count in spans_to_increment.items():
example.spans[span_i].start += count
example.spans[span_i].end += count
if example.text not in unfixable_examples:
callbacks.track_example(orig_example_hash, example)
fixed_examples.append(example)
if span.start in token_starts and span.end not in token_ends:
# Span start aligns to token_start but end doesn't
# e.g. [customer][PERSONTYPE]s but should be annotated as [customers][PERSONTYPE]
# tokenization_errors.append((example, span))
# print("BAD END")
if span.end + 1 in token_ends:
# Likely off by 1 annotation
# e.g. [customer][PERSONTYPE]s but should be annotated as [customers][PERSONTYPE]
span.end += 1
span.text = example.text[span.start : span.end]
# print("SPAN CORRECTED OFF BY 1", example.text, span)
elif span.end - 1 in token_ends:
span.end -= 1
span.text = example.text[span.start : span.end]
else:
callbacks.remove_example(orig_example_hash)
# Likely bad tokenization
# e.g. [Quadling][GPE]Country should be split to [Quadling][GPE] Country
for j in range(span_i + 1, len(example.spans)):
spans_to_increment[j] += 1
fe_text = example.text
if tokenization_errors and verbose:
print(f"Found {len(tokenization_errors)} tokenization errors.")
print(f"Found {len(unfixable_examples)} unfixable tokenization errors.")
split_start = span.start
if (
len(spans_to_increment) > 1
and span_i != list(spans_to_increment.keys())[0]
):
split_start += spans_to_increment.get(span_i, 0)
split_end = span.end
if (
len(spans_to_increment) > 1
and span_i != list(spans_to_increment.keys())[0]
):
split_end += spans_to_increment.get(span_i, 0)
new_text = f"{fe_text[:split_start]}{span.text} {fe_text[split_end:]}"
return fixed_examples
example.text = new_text
elif span.start not in token_starts and span.end in token_ends:
# Bad tokenization
# e.g. with[Raymond][PERSON] but text should be split to with [Raymond][PERSON]
# print("BAD START", span.text)
# tokenization_errors.append((example, span))
for j in range(span_i, len(example.spans)):
spans_to_increment[j] += 1
fe_text = example.text
split_start = span.start
if len(spans_to_increment) > 1 and span_i != list(spans_to_increment.keys())[0]:
split_start += spans_to_increment.get(span_i, 0)
split_end = span.end
if len(spans_to_increment) > 1 and span_i != list(spans_to_increment.keys())[0]:
split_end += spans_to_increment.get(span_i, 0)
new_text = f"{fe_text[:split_start]} {span.text}{fe_text[split_end:]}"
example.text = new_text
else:
# Something is super fucked up.
# print("SPAN CORRECTED OFF BY 1 unfixable", example.text, span)
before = span.start
after = span.end
# tokenization_errors.append((example, span))
# if (before >= 0 and after < len(span.text) and span[before] not in token_starts and span[before] != ' ' and span[after] not in token_ends and span[after] != ' '):
# fe_text = example.text
# new_text = f"{fe_text[:span.start]} {span.text}{fe_text[span.end:]}"
# spans_to_increment[span_i] += 1
# for j in range(span_i + 1, len(example.spans)):
# spans_to_increment[j] += 2
# else:
# unfixable_examples.add(example.text)
# break
return None
# Increment the start and end characters for each span
for span_i, count in spans_to_increment.items():
example.spans[span_i].start += count
example.spans[span_i].end += count
return example
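
To make the off-by-one case described in the comments above concrete, a small worked example (text and offsets are illustrative):

# Hypothetical annotation that misses the trailing "s":
#   text = "Our customers love it"
#   span = Span(text="customer", start=4, end=12, label="PERSONTYPE")
# The token "customers" covers characters 4-13, so span.end (12) is not a
# token end, but span.end + 1 (13) is. fix_tokenization_and_spacing therefore
# bumps span.end to 13 and re-slices span.text to "customers".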
@operation("add_tokens", batch=True)
@operation("add_tokens", pre=[spacy_pre_processor])
def add_tokens(
examples: List[Example],
example: Example,
*,
callbacks: TransformationCallbacks = TransformationCallbacks(),
tokenizer: str = "default",
verbose: bool = False,
) -> List[Example]:
preprocessed_outputs: Dict[str, Any]
) -> Union[Example, None]:
"""Add tokens to each Example
Args:
@@ -176,44 +161,27 @@ def add_tokens(
Returns:
List[Example]: List of examples with tokens
"""
output_examples: List[Example] = []
tokenization_errors: List[Tuple[Example, Span]] = []
unfixable_examples: Set[str] = set()
nlp = tokenizers.get(tokenizer)()
texts = (e.text for e in examples)
doc = preprocessed_outputs["recon.v1.spacy"]
with nlp.disable_pipes(*nlp.pipe_names):
for (orig_example_hash, example), doc in zip(op_iter(examples), nlp.pipe(texts)):
tokens = []
token_starts = {}
token_ends = {}
tokens = []
token_starts = {}
token_ends = {}
for t in doc:
start = t.idx
end = t.idx + len(t)
tokens.append(Token(text=t.text, start=start, end=end, id=t.i))
token_starts[start] = t
token_ends[end] = t
for t in doc:
start = t.idx
end = t.idx + len(t)
tokens.append(Token(text=t.text, start=start, end=end, id=t.i))
token_starts[start] = t
token_ends[end] = t
example.tokens = tokens
example.tokens = tokens
for span in example.spans:
if span.start in token_starts and span.end in token_ends:
span.token_start = token_starts[span.start].i
span.token_end = token_ends[span.end].i
for span in example.spans:
if span.start in token_starts and span.end in token_ends:
span.token_start = token_starts[span.start].i
span.token_end = token_ends[span.end].i
if span.token_start is None or span.token_end is None:
tokenization_errors.append((example, span))
unfixable_examples.add(example.text)
if span.token_start is None or span.token_end is None:
return None
if example.text not in unfixable_examples:
callbacks.track_example(orig_example_hash, example)
output_examples.append(example)
else:
callbacks.remove_example(orig_example_hash)
if verbose:
print(f"Found {len(tokenization_errors)} tokenization errors.")
print(f"Found {len(unfixable_examples)} unfixable examples.")
return output_examples
return example
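
Finally, a hedged sketch of how a custom operation might opt into the new pre-processing hook. The operation name add_lower_tokens is hypothetical, and the sketch assumes, based on the updated signatures above, that the operation runner passes each pre-processor's output to the function through the preprocessed_outputs keyword argument, keyed by the pre-processor's name:

from typing import Any, Dict, Union

from recon.operations import operation
from recon.preprocess import SpacyPreProcessor
from recon.registry import tokenizers
from recon.types import Example, Token

nlp = tokenizers.get("default")()
spacy_pre_processor = SpacyPreProcessor(nlp)

@operation("add_lower_tokens", pre=[spacy_pre_processor])
def add_lower_tokens(
    example: Example, *, preprocessed_outputs: Dict[str, Any] = {}
) -> Union[Example, None]:
    """Toy operation: attach lower-cased tokens using the cached spaCy Doc."""
    doc = preprocessed_outputs["recon.v1.spacy"]
    example.tokens = [
        Token(text=t.lower_, start=t.idx, end=t.idx + len(t), id=t.i)
        for t in doc
    ]
    return example

Applied through a Corpus (e.g. corpus.apply_("add_lower_tokens")), each example's text is tokenized at most once thanks to the SpacyPreProcessor cache.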