Fix Inconsistent NER Grouping (Pipeline) (#4987)
* Add B/I handling to grouping
* Add fix to include separate entity as last token
* Move last_idx definition outside loop
* Use first entity in entity group as reference for entity type
* Add test cases
* Take out extra class accidentally added
* Return tf ner grouped test to original
* Take out redundant last entity
* Get last_idx safely

Co-authored-by: ColleterVi <36503688+ColleterVi@users.noreply.github.com>

* Fix first entity comment
* Create separate functions for group_sub_entities and group_entities (splitting the __call__ method into testable functions)
* Take out unnecessary last_idx
* Remove additional forward-pass test
* Move token classification basic tests to separate class
* Move token classification basic tests back to MonoColumnInputTestCase
* Move base NER tests to NerPipelineTests
* Take out unused kwargs
* Add back mandatory_keys argument
* Add unitary tests for group_entities in _test_ner_pipeline
* Fix last entity handling
* Fix grouping function used
* Add typing to group_sub_entities and group_entities

Co-authored-by: ColleterVi <36503688+ColleterVi@users.noreply.github.com>
Parent: 82ce8488bb
Commit: 0cc4eae0e6
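What the fix means in practice: with grouped_entities=True, adjacent tokens tagged with the same entity type under the B-/I- scheme should now be merged into a single entity group. A minimal usage sketch of the fixed behavior (not part of the diff; the checkpoint name is an illustrative assumption):

from transformers import pipeline

# Checkpoint assumed for illustration; any token-classification model works.
nlp = pipeline(
    task="ner",
    model="dbmdz/bert-large-cased-finetuned-conll03-english",
    tokenizer="dbmdz/bert-large-cased-finetuned-conll03-english",
    grouped_entities=True,
)

print(nlp("My name is Wolfgang and I live in Berlin"))
# Expected shape after this fix: one dict per entity group, each with
# "entity_group", "score" and "word" keys, e.g. a single PER group for
# "Wolfgang" and a single LOC group for "Berlin", rather than one dict
# per word piece.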
src/transformers/pipelines.py

@@ -1003,8 +1003,6 @@ class TokenClassificationPipeline(Pipeline):
             labels_idx = score.argmax(axis=-1)
 
             entities = []
-            entity_groups = []
-            entity_group_disagg = []
             # Filter to labels not in `self.ignore_labels`
             filtered_labels_idx = [
                 (idx, label_idx)
@@ -1020,37 +1018,13 @@ class TokenClassificationPipeline(Pipeline):
                     "entity": self.model.config.id2label[label_idx],
                     "index": idx,
                 }
-                last_idx, _ = filtered_labels_idx[-1]
-                if self.grouped_entities:
-                    if not entity_group_disagg:
-                        entity_group_disagg += [entity]
-                        if idx == last_idx:
-                            entity_groups += [self.group_entities(entity_group_disagg)]
-                        continue
-
-                    # If the current entity is similar and adjacent to the previous entity, append it to the disaggregated entity group
-                    if (
-                        entity["entity"] == entity_group_disagg[-1]["entity"]
-                        and entity["index"] == entity_group_disagg[-1]["index"] + 1
-                    ):
-                        entity_group_disagg += [entity]
-                        # Group the entities at the last entity
-                        if idx == last_idx:
-                            entity_groups += [self.group_entities(entity_group_disagg)]
-                    # If the current entity is different from the previous entity, aggregate the disaggregated entity group
-                    else:
-                        entity_groups += [self.group_entities(entity_group_disagg)]
-                        entity_group_disagg = [entity]
-
                 entities += [entity]
 
-            # Ensure if an entity is the latest one in the sequence it gets appended to the output
-            if len(entity_group_disagg) > 0:
-                entity_groups.append(self.group_entities(entity_group_disagg))
-
-            # Append
+            # Append grouped entities
             if self.grouped_entities:
-                answers += [entity_groups]
+                answers += [self.group_entities(entities)]
+            # Append ungrouped entities
             else:
                 answers += [entities]
 
@@ -1058,12 +1032,12 @@ class TokenClassificationPipeline(Pipeline):
             return answers[0]
         return answers
 
-    def group_entities(self, entities):
+    def group_sub_entities(self, entities: List[dict]) -> dict:
         """
-        Returns grouped entities
+        Returns grouped sub entities
         """
-        # Get the last entity in the entity group
-        entity = entities[-1]["entity"]
+        # Get the first entity in the entity group
+        entity = entities[0]["entity"]
         scores = np.mean([entity["score"] for entity in entities])
         tokens = [entity["word"] for entity in entities]
 
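To make the group_sub_entities change concrete: the group label now comes from the first sub-token rather than the last, the score is the mean over sub-tokens, and the word pieces are joined by the pipeline's tokenizer. A standalone sketch under those assumptions, with the tokenizer join approximated by a WordPiece-style merge and sample sub-tokens borrowed from the test fixtures added below:

import numpy as np

def group_sub_entities_sketch(entities: list) -> dict:
    entity = entities[0]["entity"]  # the first sub-token sets the group label
    score = np.mean([e["score"] for e in entities])
    tokens = [e["word"] for e in entities]
    # Stand-in for self.tokenizer.convert_tokens_to_string(tokens):
    word = "".join(t[2:] if t.startswith("##") else " " + t for t in tokens).strip()
    return {"entity_group": entity, "score": float(score), "word": word}

print(group_sub_entities_sketch([
    {"entity": "B-PER", "index": 1, "score": 0.9994944930076599, "word": "Cons"},
    {"entity": "B-PER", "index": 2, "score": 0.8025449514389038, "word": "##uelo"},
]))
# {'entity_group': 'B-PER', 'score': 0.9010197222232819, 'word': 'Consuelo'}

Using the first entity as the reference means a group that starts with a "B-PER" token reports "B-PER" even when its remaining sub-tokens carry "I-PER" tags.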
@@ -1074,6 +1048,45 @@ class TokenClassificationPipeline(Pipeline):
         }
         return entity_group
 
+    def group_entities(self, entities: List[dict]) -> List[dict]:
+        """
+        Returns grouped entities
+        """
+
+        entity_groups = []
+        entity_group_disagg = []
+
+        if entities:
+            last_idx = entities[-1]["index"]
+
+        for entity in entities:
+            is_last_idx = entity["index"] == last_idx
+            if not entity_group_disagg:
+                entity_group_disagg += [entity]
+                if is_last_idx:
+                    entity_groups += [self.group_sub_entities(entity_group_disagg)]
+                continue
+
+            # If the current entity is similar and adjacent to the previous entity, append it to the disaggregated entity group
+            # The split is meant to account for the "B" and "I" suffixes
+            if (
+                entity["entity"].split("-")[-1] == entity_group_disagg[-1]["entity"].split("-")[-1]
+                and entity["index"] == entity_group_disagg[-1]["index"] + 1
+            ):
+                entity_group_disagg += [entity]
+                # Group the entities at the last entity
+                if is_last_idx:
+                    entity_groups += [self.group_sub_entities(entity_group_disagg)]
+            # If the current entity is different from the previous entity, aggregate the disaggregated entity group
+            else:
+                entity_groups += [self.group_sub_entities(entity_group_disagg)]
+                entity_group_disagg = [entity]
+                # If it's the last entity, add it to the entity groups
+                if is_last_idx:
+                    entity_groups += [self.group_sub_entities(entity_group_disagg)]
+
+        return entity_groups
+
 
 NerPipeline = TokenClassificationPipeline
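The core of the new group_entities method is its adjacency test: two token-level predictions belong to the same group when their entity types match after stripping the "B-"/"I-" scheme prefix and their indices are consecutive. A standalone sketch of just that predicate, with illustrative inputs:

def same_group(prev: dict, cur: dict) -> bool:
    # Compare entity types with the scheme prefix stripped, e.g. "B-PER" -> "PER".
    return (
        cur["entity"].split("-")[-1] == prev["entity"].split("-")[-1]
        and cur["index"] == prev["index"] + 1
    )

a = {"entity": "B-PER", "index": 1}
b = {"entity": "I-PER", "index": 2}
c = {"entity": "I-ORG", "index": 7}
print(same_group(a, b))  # True: both PER, consecutive indices
print(same_group(b, c))  # False: different type, and a gap in indices

Under the old check on the full label string, "B-PER" followed by "I-PER" failed the comparison, so a single mention was split into two groups; comparing only the type after the "-" is what fixes the inconsistent grouping.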
tests/test_pipelines.py

@@ -38,6 +38,7 @@ expected_fill_mask_result = [
         {"sequence": "<s>The largest city in France is Lyon</s>", "score": 0.21112334728240967, "token": 12790},
     ],
 ]
 
 SUMMARIZATION_KWARGS = dict(num_beams=2, min_length=2, max_length=5)
 
@@ -156,34 +157,6 @@ class MonoColumnInputTestCase(unittest.TestCase):
 
         self.assertRaises(Exception, nlp, invalid_inputs)
 
-    @require_torch
-    def test_torch_ner(self):
-        mandatory_keys = {"entity", "word", "score"}
-        for model_name in NER_FINETUNED_MODELS:
-            nlp = pipeline(task="ner", model=model_name, tokenizer=model_name)
-            self._test_mono_column_pipeline(nlp, VALID_INPUTS, mandatory_keys)
-
-    @require_torch
-    def test_ner_grouped(self):
-        mandatory_keys = {"entity_group", "word", "score"}
-        for model_name in NER_FINETUNED_MODELS:
-            nlp = pipeline(task="ner", model=model_name, tokenizer=model_name, grouped_entities=True)
-            self._test_mono_column_pipeline(nlp, VALID_INPUTS, mandatory_keys)
-
-    @require_tf
-    def test_tf_ner(self):
-        mandatory_keys = {"entity", "word", "score"}
-        for model_name in NER_FINETUNED_MODELS:
-            nlp = pipeline(task="ner", model=model_name, tokenizer=model_name, framework="tf")
-            self._test_mono_column_pipeline(nlp, VALID_INPUTS, mandatory_keys)
-
-    @require_tf
-    def test_tf_ner_grouped(self):
-        mandatory_keys = {"entity_group", "word", "score"}
-        for model_name in NER_FINETUNED_MODELS:
-            nlp = pipeline(task="ner", model=model_name, tokenizer=model_name, framework="tf", grouped_entities=True)
-            self._test_mono_column_pipeline(nlp, VALID_INPUTS, mandatory_keys)
-
     @require_torch
     def test_torch_sentiment_analysis(self):
         mandatory_keys = {"label", "score"}
@@ -393,6 +366,100 @@ class QAPipelineTests(unittest.TestCase):
             self._test_qa_pipeline(nlp)
 
 
+class NerPipelineTests(unittest.TestCase):
+    def _test_ner_pipeline(
+        self, nlp: Pipeline, output_keys: Iterable[str],
+    ):
+
+        ungrouped_ner_inputs = [
+            [
+                {"entity": "B-PER", "index": 1, "score": 0.9994944930076599, "word": "Cons"},
+                {"entity": "B-PER", "index": 2, "score": 0.8025449514389038, "word": "##uelo"},
+                {"entity": "I-PER", "index": 3, "score": 0.9993102550506592, "word": "Ara"},
+                {"entity": "I-PER", "index": 4, "score": 0.9993743896484375, "word": "##új"},
+                {"entity": "I-PER", "index": 5, "score": 0.9992871880531311, "word": "##o"},
+                {"entity": "I-PER", "index": 6, "score": 0.9993029236793518, "word": "No"},
+                {"entity": "I-PER", "index": 7, "score": 0.9981776475906372, "word": "##guera"},
+                {"entity": "B-PER", "index": 15, "score": 0.9998136162757874, "word": "Andrés"},
+                {"entity": "I-PER", "index": 16, "score": 0.999740719795227, "word": "Pas"},
+                {"entity": "I-PER", "index": 17, "score": 0.9997414350509644, "word": "##tran"},
+                {"entity": "I-PER", "index": 18, "score": 0.9996136426925659, "word": "##a"},
+                {"entity": "B-ORG", "index": 28, "score": 0.9989739060401917, "word": "Far"},
+                {"entity": "I-ORG", "index": 29, "score": 0.7188422083854675, "word": "##c"},
+            ],
+            [
+                {"entity": "I-PER", "index": 1, "score": 0.9968166351318359, "word": "En"},
+                {"entity": "I-PER", "index": 2, "score": 0.9957635998725891, "word": "##zo"},
+                {"entity": "I-ORG", "index": 7, "score": 0.9986497163772583, "word": "UN"},
+            ],
+        ]
+        expected_grouped_ner_results = [
+            [
+                {"entity_group": "B-PER", "score": 0.9710702640669686, "word": "Consuelo Araújo Noguera"},
+                {"entity_group": "B-PER", "score": 0.9997273534536362, "word": "Andrés Pastrana"},
+                {"entity_group": "B-ORG", "score": 0.8589080572128296, "word": "Farc"},
+            ],
+            [
+                {"entity_group": "I-PER", "score": 0.9962901175022125, "word": "Enzo"},
+                {"entity_group": "I-ORG", "score": 0.9986497163772583, "word": "UN"},
+            ],
+        ]
+
+        self.assertIsNotNone(nlp)
+
+        mono_result = nlp(VALID_INPUTS[0])
+        self.assertIsInstance(mono_result, list)
+        self.assertIsInstance(mono_result[0], (dict, list))
+
+        if isinstance(mono_result[0], list):
+            mono_result = mono_result[0]
+
+        for key in output_keys:
+            self.assertIn(key, mono_result[0])
+
+        multi_result = [nlp(input) for input in VALID_INPUTS]
+        self.assertIsInstance(multi_result, list)
+        self.assertIsInstance(multi_result[0], (dict, list))
+
+        if isinstance(multi_result[0], list):
+            multi_result = multi_result[0]
+
+        for result in multi_result:
+            for key in output_keys:
+                self.assertIn(key, result)
+
+        for ungrouped_input, grouped_result in zip(ungrouped_ner_inputs, expected_grouped_ner_results):
+            self.assertEqual(nlp.group_entities(ungrouped_input), grouped_result)
+
+    @require_torch
+    def test_torch_ner(self):
+        mandatory_keys = {"entity", "word", "score"}
+        for model_name in NER_FINETUNED_MODELS:
+            nlp = pipeline(task="ner", model=model_name, tokenizer=model_name)
+            self._test_ner_pipeline(nlp, mandatory_keys)
+
+    @require_torch
+    def test_ner_grouped(self):
+        mandatory_keys = {"entity_group", "word", "score"}
+        for model_name in NER_FINETUNED_MODELS:
+            nlp = pipeline(task="ner", model=model_name, tokenizer=model_name, grouped_entities=True)
+            self._test_ner_pipeline(nlp, mandatory_keys)
+
+    @require_tf
+    def test_tf_ner(self):
+        mandatory_keys = {"entity", "word", "score"}
+        for model_name in NER_FINETUNED_MODELS:
+            nlp = pipeline(task="ner", model=model_name, tokenizer=model_name, framework="tf")
+            self._test_ner_pipeline(nlp, mandatory_keys)
+
+    @require_tf
+    def test_tf_ner_grouped(self):
+        mandatory_keys = {"entity_group", "word", "score"}
+        for model_name in NER_FINETUNED_MODELS:
+            nlp = pipeline(task="ner", model=model_name, tokenizer=model_name, framework="tf", grouped_entities=True)
+            self._test_ner_pipeline(nlp, mandatory_keys)
+
+
 class PipelineCommonTests(unittest.TestCase):
     pipelines = SUPPORTED_TASKS.keys()
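Since group_entities is now a separate method on the pipeline, grouping can be unit-tested on pre-computed token predictions without a second forward pass, which is exactly what the fixture loop above does. A hedged sketch of the same pattern, reusing the second fixture (the checkpoint name is an illustrative assumption; the grouping logic itself does not depend on it):

from transformers import pipeline

nlp = pipeline(task="ner", model="dbmdz/bert-large-cased-finetuned-conll03-english")

tokens = [
    {"entity": "I-PER", "index": 1, "score": 0.9968166351318359, "word": "En"},
    {"entity": "I-PER", "index": 2, "score": 0.9957635998725891, "word": "##zo"},
    {"entity": "I-ORG", "index": 7, "score": 0.9986497163772583, "word": "UN"},
]
print(nlp.group_entities(tokens))
# Per the expected fixture above:
# [{"entity_group": "I-PER", "score": 0.9962901175022125, "word": "Enzo"},
#  {"entity_group": "I-ORG", "score": 0.9986497163772583, "word": "UN"}]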