зеркало из https://github.com/microsoft/reconner.git
114 строки
4.1 KiB
Python
114 строки
4.1 KiB
Python
from collections import defaultdict
|
|
from copy import deepcopy
|
|
from typing import Any, Dict, List, Set, Tuple
|
|
|
|
from spacy.language import Language
|
|
|
|
from .operations import operation
|
|
from .types import Example, Span, TransformationCallbacks
|
|
|
|
|
|
@operation("recon.v1.upcase_labels")
|
|
def upcase_labels(example: Example) -> Example:
|
|
"""Convert all span labels to uppercase to normalize
|
|
|
|
Args:
|
|
example (Example): Input Example
|
|
|
|
Returns:
|
|
Example: Example with fixed labels
|
|
"""
|
|
for s in example.spans:
|
|
s.label = s.label.upper()
|
|
return example
|
|
|
|
|
|
@operation("recon.v1.filter_overlaps")
|
|
def filter_overlaps(example: Example) -> Example:
|
|
"""Filter overlapping entity spans by picking the longest one.
|
|
|
|
Args:
|
|
example (Example): Input Example
|
|
|
|
Returns:
|
|
List[Example]: Example with fixed overlaps
|
|
"""
|
|
annotations: List[Span] = sorted(example.spans, key=lambda s: s.start)
|
|
filtered_annotations = remove_overlapping_entities(annotations)
|
|
example.spans = filtered_annotations
|
|
|
|
return example
|
|
|
|
|
|
def select_subset_of_overlapping_chain(chain: List[Span]) -> List[Span]:
|
|
"""
|
|
Select the subset of entities in an overlapping chain to return by greedily choosing the
|
|
longest entity in the chain until there are no entities remaining
|
|
"""
|
|
sorted_chain = sorted(chain, key=lambda s: s.end - s.start, reverse=True)
|
|
selections_from_chain: List[Span] = []
|
|
chain_index = 0
|
|
# dump the current chain by greedily keeping the longest entity that doesn't overlap
|
|
while chain_index < len(sorted_chain):
|
|
entity = sorted_chain[chain_index]
|
|
match_found = False
|
|
for already_selected_entity in selections_from_chain:
|
|
max_start = max(entity.start, already_selected_entity.start)
|
|
min_end = min(entity.end, already_selected_entity.end)
|
|
if len(range(max_start, min_end)) > 0:
|
|
match_found = True
|
|
break
|
|
|
|
if not match_found:
|
|
selections_from_chain.append(entity)
|
|
|
|
chain_index += 1
|
|
|
|
return selections_from_chain
|
|
|
|
|
|
def remove_overlapping_entities(sorted_spans: List[Span]) -> List[Span]:
|
|
"""
|
|
Removes overlapping entities from the entity set, by greedilytaking the longest
|
|
entity from each overlapping chain. The input list of entities should be sorted
|
|
and follow the spacy format.
|
|
"""
|
|
spans_without_overlap: List[Span] = []
|
|
current_overlapping_chain: List[Span] = []
|
|
current_overlapping_chain_start = 0
|
|
current_overlapping_chain_end = 0
|
|
for i, current_entity in enumerate(sorted_spans):
|
|
current_entity = sorted_spans[i]
|
|
current_entity_start = current_entity.start
|
|
current_entity_end = current_entity.end
|
|
|
|
if len(current_overlapping_chain) == 0:
|
|
current_overlapping_chain.append(current_entity)
|
|
current_overlapping_chain_start = current_entity_start
|
|
current_overlapping_chain_end = current_entity_end
|
|
else:
|
|
min_end = min(current_entity_end, current_overlapping_chain_end)
|
|
max_start = max(current_entity_start, current_overlapping_chain_start)
|
|
if min_end - max_start > 0:
|
|
current_overlapping_chain.append(current_entity)
|
|
current_overlapping_chain_start = min(
|
|
current_entity_start, current_overlapping_chain_start
|
|
)
|
|
current_overlapping_chain_end = max(
|
|
current_entity_end, current_overlapping_chain_end
|
|
)
|
|
else:
|
|
selections_from_chain: List[Span] = select_subset_of_overlapping_chain(
|
|
current_overlapping_chain
|
|
)
|
|
|
|
current_overlapping_chain = []
|
|
spans_without_overlap.extend(selections_from_chain)
|
|
current_overlapping_chain.append(current_entity)
|
|
current_overlapping_chain_start = current_entity_start
|
|
current_overlapping_chain_end = current_entity_end
|
|
|
|
spans_without_overlap.extend(select_subset_of_overlapping_chain(current_overlapping_chain))
|
|
|
|
return sorted(spans_without_overlap, key=lambda x: x.start)
|