Merge pull request #48 from Robbie-Palmer/tokenize_with_any_model
Tokenize with any model
Commit 6146d2ef25
@@ -1,12 +1,10 @@
-from typing import List, Set, Dict
+from typing import List
 
 from presidio_analyzer import RecognizerResult
 from presidio_anonymizer import AnonymizerEngine
 
 from presidio_evaluator.data_generator import PresidioDataGenerator
 
-import pandas as pd
-
 
 class PresidioPseudonymization(PresidioDataGenerator):
     def __init__(self, map_to_presidio_entities: bool = True, **kwargs):
@@ -39,7 +39,7 @@ PRESIDIO_SPACY_ENTITIES = {
 
 class Span:
     """
-    Holds information about the start, end, type nad value
+    Holds information about the start, end, type and value
     of an entity in a text
     """
 
@@ -126,6 +126,7 @@ class InputSample(object):
         tokens: Optional[Doc] = None,
         tags: Optional[List[str]] = None,
         create_tags_from_span=False,
+        token_model_version="en_core_web_sm",
         scheme="IO",
         metadata=None,
         template_id=None,
@@ -142,6 +143,7 @@ class InputSample(object):
         :param tokens: spaCy Doc object
         :param tags: list of strings representing the label for each token,
         given the scheme
+        :param token_model_version: The name of the model to use for tokenization if no tokens provided
         :param metadata: A dictionary of additional metadata on the sample,
         in the English (or other language) vocabulary
         :param template_id: Original template (utterance) of sample, in case it was generated # noqa
@@ -162,7 +164,7 @@ class InputSample(object):
         self.template_id = template_id
 
         if create_tags_from_span:
-            tokens, tags = self.get_tags(scheme)
+            tokens, tags = self.get_tags(scheme, token_model_version)
             self.tokens = tokens
             self.tags = tags
         else:
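
A minimal sketch of how the new constructor argument might be used, assuming InputSample and Span are importable from the presidio_evaluator package and that an alternative spaCy pipeline such as en_core_web_lg is installed; the sample text and span offsets are invented for illustration:

from presidio_evaluator import InputSample, Span

# "Zoe" occupies characters 11-14 of the text below (end index exclusive).
sample = InputSample(
    full_text="My name is Zoe",
    spans=[Span(entity_type="PERSON", entity_value="Zoe", start_position=11, end_position=14)],
    create_tags_from_span=True,            # triggers get_tags() during construction
    token_model_version="en_core_web_lg",  # tokenizer model is now configurable
    scheme="BILUO",
)
print(sample.tokens, sample.tags)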
@@ -217,16 +219,16 @@ class InputSample(object):
         }
 
     @classmethod
-    def from_json(cls, data):
+    def from_json(cls, data, **kwargs):
         if "spans" in data:
             data["spans"] = [Span.from_json(span) for span in data["spans"]]
-        return cls(**data, create_tags_from_span=True)
+        return cls(**data, create_tags_from_span=True, **kwargs)
 
-    def get_tags(self, scheme="IOB"):
+    def get_tags(self, scheme="IOB", model_version="en_core_web_sm"):
         start_indices = [span.start_position for span in self.spans]
         end_indices = [span.end_position for span in self.spans]
         tags = [span.entity_type for span in self.spans]
-        tokens = tokenize(self.full_text)
+        tokens = tokenize(self.full_text, model_version)
 
         labels = span_to_tag(
             scheme=scheme,
@@ -276,6 +278,7 @@ class InputSample(object):
         dataset: Union[List["InputSample"], List[FakerSpansResult]],
         translate_tags=False,
         to_bio=True,
+        token_model_version="en_core_web_sm"
     ) -> pd.DataFrame:
 
         if len(dataset) <= 1:
@@ -284,7 +287,7 @@ class InputSample(object):
         if isinstance(dataset[0], FakerSpansResult):
             dataset = [
                 InputSample.from_faker_spans_result(
-                    record, create_tags_from_span=True, scheme="BILUO"
+                    record, create_tags_from_span=True, scheme="BILUO", token_model_version=token_model_version
                 )
                 for record in tqdm(dataset, desc="Translating spans into tokens")
             ]
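
If this hunk belongs to the DataFrame conversion helper (assumed here to be InputSample.create_conll_dataset; the method name is not visible in this diff), the tokenization model can now be propagated into the span-to-token translation. A hedged usage sketch with a placeholder dataset path:

import pandas as pd

from presidio_evaluator import InputSample

samples = InputSample.read_dataset_json("dataset.json")  # placeholder path
conll_df: pd.DataFrame = InputSample.create_conll_dataset(
    dataset=samples,
    to_bio=True,
    token_model_version="en_core_web_lg",  # new argument; applied where records still need tokenizing
)
print(conll_df.head())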
@@ -548,7 +551,7 @@ class InputSample(object):
 
     @staticmethod
     def read_dataset_json(
-        filepath: Union[Path, str] = None, length: Optional[int] = None
+        filepath: Union[Path, str] = None, length: Optional[int] = None, **kwargs
     ) -> List["InputSample"]:
         """
         Reads an existing dataset, stored in json into a list of InputSample objects
@@ -563,7 +566,7 @@ class InputSample(object):
             dataset = dataset[:length]
 
         input_samples = [
-            InputSample.from_json(row) for row in tqdm(dataset, desc="tokenizing input")
+            InputSample.from_json(row, **kwargs) for row in tqdm(dataset, desc="tokenizing input")
         ]
 
         return input_samples
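
Because read_dataset_json now forwards **kwargs through from_json into the constructor, a stored dataset can be re-tokenized with any installed spaCy pipeline. A sketch under that assumption (the file path and model name below are placeholders):

from presidio_evaluator import InputSample

samples = InputSample.read_dataset_json(
    "synth_dataset.json",                  # placeholder path
    token_model_version="en_core_web_lg",  # forwarded via **kwargs to each InputSample
)
print(len(samples), samples[0].tags[:5])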
@@ -5,8 +5,8 @@ from typing import Dict, List
 
 class ExperimentTracker:
     def __init__(self):
-        self.parameters = None
-        self.metrics = None
+        self.parameters = dict()
+        self.metrics = dict()
         self.dataset_info = None
         self.confusion_matrix = None
         self.labels = None
@@ -15,14 +15,14 @@ class ExperimentTracker:
         self.parameters[key] = value
 
     def log_parameters(self, parameters: Dict):
-        for k, v in parameters.values():
+        for k, v in parameters.items():
             self.log_parameter(k, v)
 
     def log_metric(self, key: str, value: object):
         self.metrics[key] = value
 
     def log_metrics(self, metrics: Dict):
-        for k, v in metrics.values():
+        for k, v in metrics.items():
             self.log_metric(k, v)
 
     def log_dataset_hash(self, data: str):
@@ -49,5 +49,5 @@ class ExperimentTracker:
         datetime_val = time.strftime("%Y%m%d-%H%M%S")
         filename = f"experiment_{datetime_val}.json"
         print(f"saving experiment data to {filename}")
-        with open(filename) as json_file:
+        with open(filename, 'w') as json_file:
             json.dump(self.__dict__, json_file)
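
A short sketch of the ExperimentTracker behaviour these fixes enable: initialising parameters and metrics as dict() plus iterating with .items() makes bulk logging work, and opening the file with 'w' makes the JSON dump actually write. The import path and the name of the save method (end()) are assumptions, not taken from this diff:

from presidio_evaluator.experiment_tracker import ExperimentTracker

tracker = ExperimentTracker()
tracker.log_parameters({"token_model": "en_core_web_lg", "scheme": "BILUO"})
tracker.log_metrics({"pii_precision": 0.93, "pii_recall": 0.88})
tracker.end()  # assumed save method; writes experiment_<timestamp>.json as in the hunk above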
@@ -107,6 +107,7 @@ def span_to_tag(
     tags: List[str],
     scores: Optional[List[float]] = None,
     tokens: Optional[Doc] = None,
+    token_model_version: str = "en_core_web_sm"
 ) -> List[str]:
     """
     Turns a list of start and end values with corresponding labels, into a NER
@@ -118,6 +119,7 @@ def span_to_tag(
     :param ends: list of indices where entities in the text end
     :param tags: list of entity names
     :param scores: score of tag (confidence)
+    :param token_model_version: version of the model used for tokenization if no tokens provided
     :return: list of strings, representing either BILUO or BIO for the input
     """
 
@@ -128,7 +130,7 @@ def span_to_tag(
     starts, ends, tags, scores = _handle_overlaps(starts, ends, tags, scores)
 
     if not tokens:
-        tokens = tokenize(text)
+        tokens = tokenize(text, token_model_version)
 
     io_tags = []
     for token in tokens:
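
A sketch of calling span_to_tag with the new token_model_version parameter; the keyword names are inferred from the docstring and call sites in these hunks, while the import path and the example offsets are assumptions:

from presidio_evaluator import span_to_tag

text = "My name is Zoe"
bio_tags = span_to_tag(
    scheme="BIO",
    text=text,
    starts=[11],
    ends=[14],
    tags=["PERSON"],
    token_model_version="en_core_web_sm",  # tokenizer used because no tokens are passed in
)
print(bio_tags)  # expected along the lines of ['O', 'O', 'O', 'B-PERSON']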