Update notebooks, evaluation logic, and experiment tracking
This commit is contained in:
Parent
fa7e1d637a
Commit
ec7d7ac50a
|
@@ -188,4 +188,5 @@ datasets/
|
|||
/model-outputs/
|
||||
/data
|
||||
|
||||
.spacy
|
||||
*.spacy
|
||||
*.pickle
|
|
@@ -9,19 +9,23 @@
|
|||
"outputs": [],
|
||||
"source": [
|
||||
"import datetime\n",
|
||||
"import pandas as pd\n",
|
||||
"import numpy as np\n",
|
||||
"import pprint\n",
|
||||
"from faker import Faker\n",
|
||||
"from collections import Counter\n",
|
||||
"\n",
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
"\n",
|
||||
"from presidio_evaluator import InputSample\n",
|
||||
"from presidio_evaluator.data_generator import PresidioDataGenerator\n",
|
||||
"from presidio_evaluator.data_generator.faker_extensions import (\n",
|
||||
" RecordsFaker, \n",
|
||||
" IpAddressProvider, \n",
|
||||
" NationalityProvider, \n",
|
||||
" OrganizationProvider, \n",
|
||||
" UsDriverLicenseProvider, \n",
|
||||
" AddressProviderNew\n",
|
||||
" RecordsFaker,\n",
|
||||
" IpAddressProvider,\n",
|
||||
" NationalityProvider,\n",
|
||||
" OrganizationProvider,\n",
|
||||
" UsDriverLicenseProvider,\n",
|
||||
" AgeProvider,\n",
|
||||
" AddressProviderNew,\n",
|
||||
" PhoneNumberProviderNew,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
|
@@ -65,12 +69,10 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from presidio_evaluator.data_generator import PresidioDataGenerator\n",
|
||||
"\n",
|
||||
"sentence_templates = [\n",
|
||||
" \"My name is {{name}}\",\n",
|
||||
" \"Please send it to {{address}}\",\n",
|
||||
" \"I just moved to {{city}} from {{country}}\"\n",
|
||||
" \"I just moved to {{city}} from {{country}}\",\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
|
@@ -83,8 +85,7 @@
|
|||
"\n",
|
||||
"# Print the spans of the first sample\n",
|
||||
"print(fake_records[0].fake)\n",
|
||||
"print(fake_records[0].spans)\n",
|
||||
"\n"
|
||||
"print(fake_records[0].spans)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@@ -115,14 +116,16 @@
|
|||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"number_of_samples = 1000\n",
|
||||
"number_of_samples = 1500\n",
|
||||
"cur_time = datetime.date.today().strftime(\"%B_%d_%Y\")\n",
|
||||
"\n",
|
||||
"output_file = f\"../data/generated_size_{number_of_samples}_date_{cur_time}.json\"\n",
|
||||
"output_conll = f\"../data/generated_size_{number_of_samples}_date_{cur_time}.tsv\"\n",
|
||||
"\n",
|
||||
"templates_file_path = \"../presidio_evaluator/data_generator/raw_data/templates.txt\"\n",
|
||||
"fake_name_generator_file = \"../presidio_evaluator/data_generator/raw_data/FakeNameGenerator.com_3000.csv\"\n",
|
||||
"fake_name_generator_file = (\n",
|
||||
" \"../presidio_evaluator/data_generator/raw_data/FakeNameGenerator.com_3000.csv\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"lower_case_ratio = 0.05"
|
||||
]
|
||||
|
@@ -159,7 +162,7 @@
|
|||
"fake_name_generator_df = pd.read_csv(fake_name_generator_file)\n",
|
||||
"\n",
|
||||
"# Update to match existing templates\n",
|
||||
"PresidioDataGenerator.update_fake_name_generator_df(fake_name_generator_df)\n"
|
||||
"PresidioDataGenerator.update_fake_name_generator_df(fake_name_generator_df)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@@ -178,7 +181,7 @@
|
|||
"outputs": [],
|
||||
"source": [
|
||||
"# Create RecordsFaker (extension which handles records instead of independent values) and add additional specific providers\n",
|
||||
"fake = RecordsFaker(fake_name_generator_df, locale=\"en_US\")"
|
||||
"fake = RecordsFaker(records=fake_name_generator_df, locale=\"en_US\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@@ -194,11 +197,13 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"fake.add_provider(IpAddressProvider) # Both Ipv4 and IPv6 IP addresses\n",
|
||||
"fake.add_provider(NationalityProvider) # Read countries + nationalities from file\n",
|
||||
"fake.add_provider(OrganizationProvider) # Read organization names from file\n",
|
||||
"fake.add_provider(UsDriverLicenseProvider) # Read US driver license numbers from file\n",
|
||||
"fake.add_provider(AddressProviderNew) # Extend the default address formats Faker supports"
|
||||
"fake.add_provider(IpAddressProvider) # Both Ipv4 and IPv6 IP addresses\n",
|
||||
"fake.add_provider(NationalityProvider) # Read countries + nationalities from file\n",
|
||||
"fake.add_provider(OrganizationProvider) # Read organization names from file\n",
|
||||
"fake.add_provider(UsDriverLicenseProvider) # Read US driver license numbers from file\n",
|
||||
"fake.add_provider(AgeProvider) # Age values (unavailable on Faker)\n",
|
||||
"fake.add_provider(AddressProviderNew) # Extend the default address formats\n",
|
||||
"fake.add_provider(PhoneNumberProviderNew) # Extend the default phone number formats"
|
||||
]
|
||||
},
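For context, a minimal sketch of how a Faker provider like the ones registered above is built. MinimalAgeProvider is illustrative only; the real AgeProvider lives in presidio_evaluator.data_generator.faker_extensions and reads its values from data files.

from faker import Faker
from faker.providers import BaseProvider

class MinimalAgeProvider(BaseProvider):
    # Adds a new fake.age() method, mirroring what AgeProvider contributes.
    def age(self) -> int:
        return self.random_int(min=0, max=104)

fake = Faker()
fake.add_provider(MinimalAgeProvider)
print(fake.age())  # e.g. 42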
|
||||
{
|
||||
|
@@ -223,13 +228,16 @@
|
|||
"outputs": [],
|
||||
"source": [
|
||||
"# Create Presidio Data Generator\n",
|
||||
"data_generator = PresidioDataGenerator(custom_faker=fake, lower_case_ratio=lower_case_ratio)\n",
|
||||
"data_generator = PresidioDataGenerator(\n",
|
||||
" custom_faker=fake, lower_case_ratio=lower_case_ratio\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Create entity aliases (e.g. if faker supports \"name\" but templates contain \"person\").\n",
|
||||
"data_generator.add_provider_alias(provider_name=\"name\", new_name=\"person\")\n",
|
||||
"data_generator.add_provider_alias(provider_name=\"credit_card_number\", new_name=\"credit_card\")\n",
|
||||
"data_generator.add_provider_alias(provider_name=\"date_of_birth\", new_name=\"birthday\")\n",
|
||||
"\n"
|
||||
"data_generator.add_provider_alias(\n",
|
||||
" provider_name=\"credit_card_number\", new_name=\"credit_card\"\n",
|
||||
")\n",
|
||||
"data_generator.add_provider_alias(provider_name=\"date_of_birth\", new_name=\"birthday\")"
|
||||
]
|
||||
},
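Aside: the aliases above matter when template placeholders differ from Faker method names. Faker resolves {{...}} tokens by calling the provider method of the same name, so after add_provider_alias(provider_name="name", new_name="person"), a {{person}} placeholder renders through Faker's name() provider. A minimal sketch of the underlying mechanism, using only built-in providers:

from faker import Faker

fake = Faker()
# {{name}} resolves by calling fake.name(); the alias makes {{person}}
# resolve the same way inside PresidioDataGenerator templates.
print(fake.parse("My name is {{name}}"))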
|
||||
{
|
||||
|
@@ -278,13 +286,15 @@
|
|||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from collections import Counter\n",
|
||||
"count_per_template_id = Counter([sample.template_id for sample in fake_records])\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"print(f\"Total: {sum(count_per_template_id.values())}\")\n",
|
||||
"print(f\"Mean numbers of records per template: {sum(count_per_template_id.values())/len(count_per_template_id)}\")\n",
|
||||
"print(f\"Median numbers of records per template: {np.median(list(count_per_template_id.values()))}\")\n",
|
||||
"print(\n",
|
||||
" f\"Avg number of records per template: {np.mean(list(count_per_template_id.values()))}\"\n",
|
||||
")\n",
|
||||
"print(\n",
|
||||
" f\"Median number of records per template: {np.median(list(count_per_template_id.values()))}\"\n",
|
||||
")\n",
|
||||
"print(f\"Std: {np.std(list(count_per_template_id.values()))}\")"
|
||||
]
|
||||
},
|
||||
|
@@ -313,7 +323,7 @@
|
|||
"for record in fake_records:\n",
|
||||
" count_per_entity.update(Counter([span.type for span in record.spans]))\n",
|
||||
"\n",
|
||||
"count_per_entity\n"
|
||||
"count_per_entity"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@@ -324,7 +334,7 @@
|
|||
}
|
||||
},
|
||||
"source": [
|
||||
"#### Translate tags to Presidio's supported entities (optional)"
|
||||
"#### Translate tags from Faker's to Presidio's (optional)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@@ -337,44 +347,51 @@
|
|||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"translator = {'person': \"PERSON\",\n",
|
||||
" 'ip_address': \"IP_ADDRESS\",\n",
|
||||
" 'us_driver_license': \"US_DRIVER_LICENSE\",\n",
|
||||
" 'organization': \"ORGANIZATION\",\n",
|
||||
" 'name_female': \"PERSON\",\n",
|
||||
" 'address': \"ADDRESS\",\n",
|
||||
" 'country': \"LOCATION\",\n",
|
||||
" 'credit_card_number': \"CREDIT_CARD\",\n",
|
||||
" 'city': \"LOCATION\",\n",
|
||||
" 'street_name': \"ADDRESS\",\n",
|
||||
" 'building_number': \"ADDRESS\",\n",
|
||||
" 'name': \"PERSON\",\n",
|
||||
" 'iban': \"IBAN_CODE\",\n",
|
||||
" 'last_name': \"PERSON\",\n",
|
||||
" 'last_name_male': \"PERSON\",\n",
|
||||
" 'last_name_female': \"PERSON\",\n",
|
||||
" 'first_name': \"PERSON\",\n",
|
||||
" 'first_name_male': \"PERSON\",\n",
|
||||
" 'first_name_female': \"PERSON\",\n",
|
||||
" 'phone_number': \"PHONE_NUMBER\",\n",
|
||||
" 'url': \"DOMAIN_NAME\",\n",
|
||||
" 'ssn': \"US_SSN\",\n",
|
||||
" 'email': \"EMAIL_ADDRESS\",\n",
|
||||
" 'date_time': \"DATE_TIME\",\n",
|
||||
" 'date_of_birth': \"DATE_TIME\",\n",
|
||||
" 'day_of_week': \"DATE_TIME\",\n",
|
||||
" 'name_male': \"PERSON\",\n",
|
||||
" 'prefix_male': \"TITLE\",\n",
|
||||
" 'prefix_female': \"TITLE\",\n",
|
||||
" 'prefix': \"TITLE\",\n",
|
||||
" 'nationality': \"LOCATION\",\n",
|
||||
" 'first_name_nonbinary': \"PERSON\",\n",
|
||||
" 'postcode': \"ADDRESS\",\n",
|
||||
" 'secondary_address': \"ADDRESS\",\n",
|
||||
" 'company': \"ORGANIZATION\",\n",
|
||||
" 'job': \"TITLE\",\n",
|
||||
" 'zipcode': \"ADDRESS\",\n",
|
||||
" 'state_abbr': \"ADDRESS\"}\n",
|
||||
"translator = {\n",
|
||||
" \"person\": \"PERSON\",\n",
|
||||
" \"ip_address\": \"IP_ADDRESS\",\n",
|
||||
" \"us_driver_license\": \"US_DRIVER_LICENSE\",\n",
|
||||
" \"organization\": \"ORGANIZATION\",\n",
|
||||
" \"name_female\": \"PERSON\",\n",
|
||||
" \"address\": \"STREET_ADDRESS\",\n",
|
||||
" \"country\": \"GPE\",\n",
|
||||
" \"state\": \"GPE\",\n",
|
||||
" \"credit_card_number\": \"CREDIT_CARD\",\n",
|
||||
" \"city\": \"GPE\",\n",
|
||||
" \"street_name\": \"STREET_ADDRESS\",\n",
|
||||
" \"building_number\": \"STREET_ADDRESS\",\n",
|
||||
" \"name\": \"PERSON\",\n",
|
||||
" \"iban\": \"IBAN_CODE\",\n",
|
||||
" \"last_name\": \"PERSON\",\n",
|
||||
" \"last_name_male\": \"PERSON\",\n",
|
||||
" \"last_name_female\": \"PERSON\",\n",
|
||||
" \"first_name\": \"PERSON\",\n",
|
||||
" \"first_name_male\": \"PERSON\",\n",
|
||||
" \"first_name_female\": \"PERSON\",\n",
|
||||
" \"phone_number\": \"PHONE_NUMBER\",\n",
|
||||
" \"url\": \"DOMAIN_NAME\",\n",
|
||||
" \"ssn\": \"US_SSN\",\n",
|
||||
" \"email\": \"EMAIL_ADDRESS\",\n",
|
||||
" \"date_time\": \"DATE_TIME\",\n",
|
||||
" \"date_of_birth\": \"DATE_TIME\",\n",
|
||||
" \"day_of_week\": \"DATE_TIME\",\n",
|
||||
" \"year\": \"DATE_TIME\",\n",
|
||||
" \"name_male\": \"PERSON\",\n",
|
||||
" \"prefix_male\": \"TITLE\",\n",
|
||||
" \"prefix_female\": \"TITLE\",\n",
|
||||
" \"prefix\": \"TITLE\",\n",
|
||||
" \"nationality\": \"NRP\",\n",
|
||||
" \"nation_woman\": \"NRP\",\n",
|
||||
" \"nation_man\": \"NRP\",\n",
|
||||
" \"nation_plural\": \"NRP\",\n",
|
||||
" \"first_name_nonbinary\": \"PERSON\",\n",
|
||||
" \"postcode\": \"STREET_ADDRESS\",\n",
|
||||
" \"secondary_address\": \"STREET_ADDRESS\",\n",
|
||||
" \"job\": \"TITLE\",\n",
|
||||
" \"zipcode\": \"ZIP_CODE\",\n",
|
||||
" \"state_abbr\": \"GPE\",\n",
|
||||
" \"age\": \"AGE\",\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"count_per_entity_new = Counter()\n",
|
||||
"for record in fake_records:\n",
|
||||
|
@@ -382,8 +399,7 @@
|
|||
" span.type = translator[span.type]\n",
|
||||
" count_per_entity_new[span.type] += 1\n",
|
||||
"\n",
|
||||
"count_per_entity_new\n",
|
||||
"\n"
|
||||
"count_per_entity_new"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@@ -404,7 +420,10 @@
|
|||
"outputs": [],
|
||||
"source": [
|
||||
"%%time\n",
|
||||
"input_samples = [InputSample.from_faker_spans_result(faker_spans_result=fake_record) for fake_record in fake_records]"
|
||||
"input_samples = [\n",
|
||||
" InputSample.from_faker_spans_result(faker_spans_result=fake_record)\n",
|
||||
" for fake_record in fake_records\n",
|
||||
"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@@ -465,7 +484,7 @@
|
|||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"conll.to_csv(output_conll,sep=\"\\t\")"
|
||||
"conll.to_csv(output_conll, sep=\"\\t\")"
|
||||
]
|
||||
},
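A quick sanity check on the exported CoNLL file, reading the TSV back with the same separator:

import pandas as pd

conll_df = pd.read_csv(output_conll, sep="\t")
print(conll_df.head())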
|
||||
{
|
||||
|
@@ -524,4 +543,4 @@
|
|||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
}
|
|
@@ -33,6 +33,7 @@
|
|||
"from collections import Counter\n",
|
||||
"\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"\n",
|
||||
"%matplotlib inline"
|
||||
]
|
||||
},
|
||||
|
@@ -49,7 +50,10 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"pii_df = pd.read_csv(\"../presidio_evaluator/data_generator/raw_data/FakeNameGenerator.com_3000.csv\",encoding=\"utf-8\")"
|
||||
"pii_df = pd.read_csv(\n",
|
||||
" \"../presidio_evaluator/data_generator/raw_data/FakeNameGenerator.com_3000.csv\",\n",
|
||||
" encoding=\"utf-8\",\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@@ -83,12 +87,15 @@
|
|||
"source": [
|
||||
"from wordcloud import WordCloud\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def series_to_wordcloud(series):\n",
|
||||
" freqs = series.value_counts()\n",
|
||||
" wordcloud = WordCloud(background_color='white',width=800,height=400).generate_from_frequencies(freqs)\n",
|
||||
" wordcloud = WordCloud(\n",
|
||||
" background_color=\"white\", width=800, height=400\n",
|
||||
" ).generate_from_frequencies(freqs)\n",
|
||||
" fig = plt.figure(figsize=(16, 8))\n",
|
||||
" plt.suptitle(\"{} word cloud\".format(series.name))\n",
|
||||
" plt.imshow(wordcloud, interpolation='bilinear')\n",
|
||||
" plt.imshow(wordcloud, interpolation=\"bilinear\")\n",
|
||||
" plt.axis(\"off\")"
|
||||
]
|
||||
},
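A usage sketch for the helper above; "Country" is an assumed column name in the FakeNameGenerator CSV:

series_to_wordcloud(pii_df["Country"])  # column name assumed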
|
||||
|
@@ -150,7 +157,7 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"synth = InputSample.read_dataset_json(\"../data/synth_dataset.json\")"
|
||||
"synth = InputSample.read_dataset_json(\"../data/synth_dataset_v2.json\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@@ -159,13 +166,19 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def get_entity_values_from_sample(sample,entity_types):\n",
|
||||
" name_entities = [span.entity_value for span in sample.spans if span.entity_type in entity_types]\n",
|
||||
"def get_entity_values_from_sample(sample, entity_types):\n",
|
||||
" name_entities = [\n",
|
||||
" span.entity_value for span in sample.spans if span.entity_type in entity_types\n",
|
||||
" ]\n",
|
||||
" return name_entities\n",
|
||||
" \n",
|
||||
"names = [get_entity_values_from_sample(sample,['PERSON','FIRST_NAME','LAST_NAME']) for sample in synth]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"names = [\n",
|
||||
" get_entity_values_from_sample(sample, [\"PERSON\", \"FIRST_NAME\", \"LAST_NAME\"])\n",
|
||||
" for sample in synth\n",
|
||||
"]\n",
|
||||
"names = [item for sublist in names for item in sublist]\n",
|
||||
"series_to_wordcloud(pd.Series(names,name='PERSON, FIRST_NAME, LAST_NAME'))"
|
||||
"series_to_wordcloud(pd.Series(names, name=\"PERSON, FIRST_NAME, LAST_NAME\"))"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@@ -174,9 +187,9 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"countries = [get_entity_values_from_sample(sample,['LOCATION']) for sample in synth]\n",
|
||||
"countries = [get_entity_values_from_sample(sample, [\"LOCATION\"]) for sample in synth]\n",
|
||||
"countries = [item for sublist in countries for item in sublist]\n",
|
||||
"series_to_wordcloud(pd.Series(countries,name='LOCATION'))"
|
||||
"series_to_wordcloud(pd.Series(countries, name=\"LOCATION\"))"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@@ -185,9 +198,9 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"orgs = [get_entity_values_from_sample(sample,['ORGANIZATION']) for sample in synth]\n",
|
||||
"orgs = [get_entity_values_from_sample(sample, [\"ORGANIZATION\"]) for sample in synth]\n",
|
||||
"orgs = [item for sublist in orgs for item in sublist]\n",
|
||||
"series_to_wordcloud(pd.Series(orgs,name='ORGANIZATION'))"
|
||||
"series_to_wordcloud(pd.Series(orgs, name=\"ORGANIZATION\"))"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@@ -215,15 +228,6 @@
|
|||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.12"
|
||||
},
|
||||
"pycharm": {
|
||||
"stem_cell": {
|
||||
"cell_type": "raw",
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"source": []
|
||||
}
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
|
|
@@ -1,163 +0,0 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Train/Test/Validation split of input samples. \n",
|
||||
"This notebook shows how train/test/split is being made on a List[InputSample]\n",
|
||||
"\n",
|
||||
"This is different from the normal split since we don't want sentences generated from the same pattern to be in more than one set. (Applicable only if the dataset was generated from templates)"
|
||||
]
|
||||
},
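The idea behind this (now removed) notebook, as a self-contained sketch: shuffle template ids rather than sentences, so every sample generated from a given template lands in exactly one fold. The template_id attribute on the samples is assumed from the surrounding code.

import random
from collections import defaultdict

def split_by_template(samples, ratios, seed=42):
    by_template = defaultdict(list)
    for s in samples:
        by_template[s.template_id].append(s)
    ids = list(by_template)
    random.Random(seed).shuffle(ids)
    folds, start = [], 0
    for i, ratio in enumerate(ratios):
        # The last fold takes the remainder so no template is dropped.
        end = len(ids) if i == len(ratios) - 1 else start + round(ratio * len(ids))
        folds.append([s for tid in ids[start:end] for s in by_template[tid]])
        start = end
    return folds

# train, test, validation = split_by_template(all_samples, [0.7, 0.2, 0.1])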
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from presidio_evaluator import InputSample\n",
|
||||
"from presidio_evaluator.validation import split_dataset, save_to_json\n",
|
||||
"from datetime import date\n",
|
||||
"\n",
|
||||
"%reload_ext autoreload"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Load full dataset"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"all_samples = InputSample.read_dataset_json(\"../data/synth_dataset.json\")\n",
|
||||
"print(len(all_samples))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Split to train/test/dev"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"TRAIN_TEST_VAL_RATIOS = [0.7,0.2,0.1]\n",
|
||||
"\n",
|
||||
"train, test, validation = split_dataset(all_samples,TRAIN_TEST_VAL_RATIOS)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Train/Test only (no validation)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#TRAIN_TEST_RATIOS = [0.7,0.3]\n",
|
||||
"#train,test = split_dataset(all_sampleTRAIN_TEST_RATIOSEST_RATIOS)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Save the different sets to files"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"DATE_DATE = date.today().strftime(\"%b-%d-%Y\")\n",
|
||||
"\n",
|
||||
"save_to_json(train,\"../data/train_{}.json\".format(DATE_DATE))\n",
|
||||
"save_to_json(test,\"../data/test_{}.json\".format(DATE_DATE))\n",
|
||||
"save_to_json(validation,\"../data/validation_{}.json\".format(DATE_DATE))\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(len(train))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(len(test))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(len(validation))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"assert len(train) + len(test) + len(validation) == len(all_samples)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "presidio",
|
||||
"language": "python",
|
||||
"name": "presidio"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
|
@@ -36,7 +36,7 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"all_samples = InputSample.read_dataset_json(\"../data/synth_dataset.json\")\n",
|
||||
"all_samples = InputSample.read_dataset_json(\"../data/synth_dataset_v2.json\")\n",
|
||||
"print(len(all_samples))"
|
||||
]
|
||||
},
|
||||
|
@@ -53,9 +53,9 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"TRAIN_TEST_VAL_RATIOS = [0.7,0.2,0.1]\n",
|
||||
"TRAIN_TEST_VAL_RATIOS = [0.7, 0.2, 0.1]\n",
|
||||
"\n",
|
||||
"train, test, validation = split_dataset(all_samples,TRAIN_TEST_VAL_RATIOS)\n"
|
||||
"train, test, validation = split_dataset(all_samples, TRAIN_TEST_VAL_RATIOS)"
|
||||
]
|
||||
},
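A quick check that the folds roughly honor the requested ratios (approximate, since whole templates, not individual sentences, are assigned to folds):

for name, fold in zip(("train", "test", "validation"), (train, test, validation)):
    print(f"{name}: {len(fold)} samples ({len(fold) / len(all_samples):.0%})")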
|
||||
{
|
||||
|
@@ -71,9 +71,8 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"#TRAIN_TEST_RATIOS = [0.7,0.3]\n",
|
||||
"#train,test = split_dataset(all_sampleTRAIN_TEST_RATIOSEST_RATIOS)"
|
||||
"# TRAIN_TEST_RATIOS = [0.7,0.3]\n",
|
||||
"# train,test = split_dataset(all_sampleTRAIN_TEST_RATIOSEST_RATIOS)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@@ -91,9 +90,9 @@
|
|||
"source": [
|
||||
"DATE_DATE = date.today().strftime(\"%b-%d-%Y\")\n",
|
||||
"\n",
|
||||
"save_to_json(train,\"../data/train_{}.json\".format(DATE_DATE))\n",
|
||||
"save_to_json(test,\"../data/test_{}.json\".format(DATE_DATE))\n",
|
||||
"save_to_json(validation,\"../data/validation_{}.json\".format(DATE_DATE))\n"
|
||||
"save_to_json(train, \"../data/train_{}.json\".format(DATE_DATE))\n",
|
||||
"save_to_json(test, \"../data/test_{}.json\".format(DATE_DATE))\n",
|
||||
"save_to_json(validation, \"../data/validation_{}.json\".format(DATE_DATE))"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@@ -132,6 +131,13 @@
|
|||
"assert len(train) + len(test) + len(validation) == len(all_samples)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
|
|
|
@@ -1,701 +1,266 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from presidio_evaluator import InputSample\n",
|
||||
"from presidio_evaluator.evaluation import ModelError, Evaluator\n",
|
||||
"from presidio_evaluator.models import PresidioAnalyzerWrapper\n",
|
||||
"from presidio_analyzer import AnalyzerEngine\n",
|
||||
"from collections import Counter\n",
|
||||
"\n",
|
||||
"import pandas as pd\n",
|
||||
"\n",
|
||||
"%load_ext autoreload\n",
|
||||
"%autoreload 2\n",
|
||||
"\n",
|
||||
"pd.options.display.max_columns = None\n",
|
||||
"pd.options.display.width=None"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "847acd88",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Evaluate Presidio Analyzer\n",
|
||||
"This notebook runs the PresidioAnalyzerEvaluator class on top of synthetic data.\n",
|
||||
"\n",
|
||||
"One can perform the following changes:\n",
|
||||
"1. Replace the synthetic data creation with real data or with other type of synthetic data\n",
|
||||
"2. Adapt the Presidio `AnalyzerEngine` to a specific engine with a different set of recognizers or configured to be used on different languages\n",
|
||||
"\n",
|
||||
"\n"
|
||||
]
|
||||
},
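A minimal sketch for change #2 above: extend the AnalyzerEngine with a custom deny-list recognizer before wrapping it for evaluation. Passing the engine into PresidioAnalyzerWrapper via an analyzer_engine parameter is an assumption about the wrapper's API.

from presidio_analyzer import AnalyzerEngine, PatternRecognizer

analyzer = AnalyzerEngine()
# Deny-list recognizer for honorifics; TITLE is an illustrative entity name.
titles = PatternRecognizer(supported_entity="TITLE", deny_list=["Mr.", "Mrs.", "Dr."])
analyzer.registry.add_recognizer(titles)
# presidio = PresidioAnalyzerWrapper(analyzer_engine=analyzer)  # assumed parameter name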
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### A. Read dataset for evaluation"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\r",
|
||||
"tokenizing input: 0%| | 0/3000 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"loading model en_core_web_sm\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tokenizing input: 100%|███████████████████████████████████████████████████████████| 3000/3000 [00:25<00:00, 118.48it/s]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Read 3000 samples\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"Full text: I want to update my primary and secondary address to the same: 19 Ingelbrecht Knudssøns gate 222\n",
|
||||
" Suite 598\n",
|
||||
" OSLO\n",
|
||||
" Bangladesh\n",
|
||||
"Spans: [Type: ADDRESS, value: 19 Ingelbrecht Knudssøns gate 222\n",
|
||||
" Suite 598\n",
|
||||
" OSLO\n",
|
||||
" Bangladesh, start: 63, end: 125]\n",
|
||||
"Tokens: I want to update my primary and secondary address to the same: 19 Ingelbrecht Knudssøns gate 222\n",
|
||||
" Suite 598\n",
|
||||
" OSLO\n",
|
||||
" Bangladesh\n",
|
||||
"Tags: ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'ADDRESS', 'ADDRESS', 'ADDRESS', 'ADDRESS', 'ADDRESS', 'ADDRESS', 'ADDRESS', 'ADDRESS', 'ADDRESS', 'ADDRESS', 'ADDRESS', 'ADDRESS']"
|
||||
]
|
||||
},
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"input_samples = InputSample.read_dataset_json(\"../data/synth_dataset.json\")\n",
|
||||
"print(\"Read {} samples\".format(len(input_samples)))\n",
|
||||
"input_samples[0]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### B. Descriptive statistics"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"Counter({'ADDRESS': 1512,\n",
|
||||
" 'LOCATION': 817,\n",
|
||||
" 'PHONE_NUMBER': 264,\n",
|
||||
" 'PERSON': 1800,\n",
|
||||
" 'CREDIT_CARD': 313,\n",
|
||||
" 'IBAN_CODE': 41,\n",
|
||||
" 'US_SSN': 44,\n",
|
||||
" 'ORGANIZATION': 448,\n",
|
||||
" 'DOMAIN_NAME': 40,\n",
|
||||
" 'EMAIL_ADDRESS': 71,\n",
|
||||
" 'PREFIX': 43,\n",
|
||||
" 'DATE_TIME': 112,\n",
|
||||
" 'TITLE': 23,\n",
|
||||
" 'IP_ADDRESS': 24,\n",
|
||||
" 'US_DRIVER_LICENSE': 13})"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"flatten = lambda l: [item for sublist in l for item in sublist]\n",
|
||||
"\n",
|
||||
"count_per_entity = Counter([span.entity_type for span in flatten([input_sample.spans for input_sample in input_samples])])\n",
|
||||
"count_per_entity"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### C. Remove entities not supported by Presidio"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'CREDIT_CARD',\n",
|
||||
" 'DATE_TIME',\n",
|
||||
" 'DOMAIN_NAME',\n",
|
||||
" 'EMAIL_ADDRESS',\n",
|
||||
" 'IBAN_CODE',\n",
|
||||
" 'IP_ADDRESS',\n",
|
||||
" 'LOCATION',\n",
|
||||
" 'ORGANIZATION',\n",
|
||||
" 'PERSON',\n",
|
||||
" 'PHONE_NUMBER',\n",
|
||||
" 'US_DRIVER_LICENSE',\n",
|
||||
" 'US_SSN'}"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"entities_to_ignore = {\"ADDRESS\", \"TITLE\", \"PREFIX\"}\n",
|
||||
"entities_to_keep = set(count_per_entity.keys()) - entities_to_ignore\n",
|
||||
"entities_to_keep"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### D. Run the presidio-evaluator framework with Presidio's API as the 'model' at test"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Entities supported by this Presidio Analyzer instance:\n",
|
||||
"MEDICAL_LICENSE, DOMAIN_NAME, UK_NHS, AU_ACN, CRYPTO, CREDIT_CARD, AU_ABN, US_ITIN, LOCATION, NRP, US_DRIVER_LICENSE, PHONE_NUMBER, PERSON, AU_TFN, EMAIL_ADDRESS, US_SSN, IP_ADDRESS, US_PASSPORT, US_BANK_NUMBER, SG_NRIC_FIN, AU_MEDICARE, IBAN_CODE, DATE_TIME\n",
|
||||
"Entity ORGANIZATION is not supported by this instance of Presidio Analyzer Engine\n",
|
||||
"Added ORGANIZATION as a supported entity from spaCy/Stanza\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Evaluating <class 'presidio_evaluator.evaluation.evaluator.Evaluator'>: 100%|██████| 3000/3000 [00:31<00:00, 95.34it/s]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"presidio = PresidioAnalyzerWrapper(entities_to_keep=list(entities_to_keep))\n",
|
||||
"evaluator = Evaluator(model=presidio)\n",
|
||||
"evaluted_samples = evaluator.evaluate_all(input_samples)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### D. Extract statistics\n",
|
||||
"- Presicion, recall and F measure are calculated based on a PII/Not PII binary classification per token.\n",
|
||||
"- Specific entity recall and precision are calculated on the specific PII entity level."
|
||||
]
|
||||
},
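A tiny worked example of the token-level binary view described above:

# tokens:     ["My", "name", "is", "Dana", "Smith"]
# annotated:  [ O,    O,     O,    PII,    PII   ]
# predicted:  [ O,    PII,   O,    PII,    O     ]
tp, fp, fn = 1, 1, 1
precision = tp / (tp + fp)  # 0.5
recall = tp / (tp + fn)  # 0.5
f_measure = 2 * precision * recall / (precision + recall)  # 0.5
print(precision, recall, f_measure)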
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"evaluation_result = evaluator.calculate_score(evaluted_samples)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" Entity Precision Recall Number of samples\n",
|
||||
" CREDIT_CARD 100.00% 100.00% 2728\n",
|
||||
" DATE_TIME 14.72% 89.14% 40\n",
|
||||
" DOMAIN_NAME 100.00% 82.50% 41\n",
|
||||
" EMAIL_ADDRESS 100.00% 100.00% 313\n",
|
||||
" IBAN_CODE 100.00% 90.24% 1114\n",
|
||||
" IP_ADDRESS 91.18% 83.78% 71\n",
|
||||
" LOCATION 53.84% 35.91% 220\n",
|
||||
" ORGANIZATION 24.24% 53.62% 897\n",
|
||||
" PERSON 68.44% 82.73% 37\n",
|
||||
" PHONE_NUMBER 99.40% 48.07% 14\n",
|
||||
" US_DRIVER_LICENSE 88.89% 57.14% 1034\n",
|
||||
" US_SSN 98.62% 97.73% 267\n",
|
||||
"PII F measure: 0.7567568887066222\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"evaluation_result.print()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### G. Analyze wrong predictions"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"errors = evaluation_result.model_errors"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Most common false positive tokens:\n",
|
||||
"[('\\n', 202), ('the', 110), ('\\n ', 96), ('last', 68), ('year', 48)]\n",
|
||||
"Example sentence with each FP token:\n",
|
||||
"how do i change my address to unit 9456 box 8731\n",
|
||||
"dpo ap 71610 for post mail?\n",
|
||||
"Muslija began writing as a teenager, publishing her first story, \"The Dimensions of a Shadow\", in 1950 while studying English and journalism at the University of El Tanque.\n",
|
||||
"As promised, here's Božica's address:\n",
|
||||
"\n",
|
||||
"99 Sahankatu 77\n",
|
||||
"Ortovero\n",
|
||||
", SV\n",
|
||||
" Nigeria 21148\n",
|
||||
"At my suggestion, one morning over breakfast, she agreed, and on the last Sunday before Labor Day we returned to Los Angeles by helicopter.\n",
|
||||
"Ewan spent a year at BBC as the assistant to Aaron Panina, and the following year at Sanders-Gill in Seguin, which later became Weather Decision Technologies in 1965.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"ModelError.most_common_fp_tokens(errors,n=5)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>error_type</th>\n",
|
||||
" <th>annotation</th>\n",
|
||||
" <th>prediction</th>\n",
|
||||
" <th>token</th>\n",
|
||||
" <th>full_text</th>\n",
|
||||
" <th>0</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>FP</td>\n",
|
||||
" <td>O</td>\n",
|
||||
" <td>DATE_TIME</td>\n",
|
||||
" <td>8731</td>\n",
|
||||
" <td>how do i change my address to unit 9456 box 87...</td>\n",
|
||||
" <td>None</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>FP</td>\n",
|
||||
" <td>O</td>\n",
|
||||
" <td>DATE_TIME</td>\n",
|
||||
" <td>\\n</td>\n",
|
||||
" <td>how do i change my address to unit 9456 box 87...</td>\n",
|
||||
" <td>None</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>FP</td>\n",
|
||||
" <td>O</td>\n",
|
||||
" <td>DATE_TIME</td>\n",
|
||||
" <td>dpo</td>\n",
|
||||
" <td>how do i change my address to unit 9456 box 87...</td>\n",
|
||||
" <td>None</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>FP</td>\n",
|
||||
" <td>O</td>\n",
|
||||
" <td>DATE_TIME</td>\n",
|
||||
" <td>ap</td>\n",
|
||||
" <td>how do i change my address to unit 9456 box 87...</td>\n",
|
||||
" <td>None</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>FP</td>\n",
|
||||
" <td>O</td>\n",
|
||||
" <td>DATE_TIME</td>\n",
|
||||
" <td>71610</td>\n",
|
||||
" <td>how do i change my address to unit 9456 box 87...</td>\n",
|
||||
" <td>None</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>...</th>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1224</th>\n",
|
||||
" <td>FP</td>\n",
|
||||
" <td>O</td>\n",
|
||||
" <td>DATE_TIME</td>\n",
|
||||
" <td>this</td>\n",
|
||||
" <td>My card 5115922521155230 is expiring this mont...</td>\n",
|
||||
" <td>None</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1225</th>\n",
|
||||
" <td>FP</td>\n",
|
||||
" <td>O</td>\n",
|
||||
" <td>DATE_TIME</td>\n",
|
||||
" <td>month</td>\n",
|
||||
" <td>My card 5115922521155230 is expiring this mont...</td>\n",
|
||||
" <td>None</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1226</th>\n",
|
||||
" <td>FP</td>\n",
|
||||
" <td>O</td>\n",
|
||||
" <td>DATE_TIME</td>\n",
|
||||
" <td>33649</td>\n",
|
||||
" <td>As promised, here's Zlata's address:\\n\\n29 Rue...</td>\n",
|
||||
" <td>None</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1227</th>\n",
|
||||
" <td>FP</td>\n",
|
||||
" <td>O</td>\n",
|
||||
" <td>DATE_TIME</td>\n",
|
||||
" <td>2</td>\n",
|
||||
" <td>Follow up with Edward Baranova in 2 months.</td>\n",
|
||||
" <td>None</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1228</th>\n",
|
||||
" <td>FP</td>\n",
|
||||
" <td>O</td>\n",
|
||||
" <td>DATE_TIME</td>\n",
|
||||
" <td>months</td>\n",
|
||||
" <td>Follow up with Edward Baranova in 2 months.</td>\n",
|
||||
" <td>None</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"<p>1229 rows × 6 columns</p>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" error_type annotation prediction token \\\n",
|
||||
"0 FP O DATE_TIME 8731 \n",
|
||||
"1 FP O DATE_TIME \\n \n",
|
||||
"2 FP O DATE_TIME dpo \n",
|
||||
"3 FP O DATE_TIME ap \n",
|
||||
"4 FP O DATE_TIME 71610 \n",
|
||||
"... ... ... ... ... \n",
|
||||
"1224 FP O DATE_TIME this \n",
|
||||
"1225 FP O DATE_TIME month \n",
|
||||
"1226 FP O DATE_TIME 33649 \n",
|
||||
"1227 FP O DATE_TIME 2 \n",
|
||||
"1228 FP O DATE_TIME months \n",
|
||||
"\n",
|
||||
" full_text 0 \n",
|
||||
"0 how do i change my address to unit 9456 box 87... None \n",
|
||||
"1 how do i change my address to unit 9456 box 87... None \n",
|
||||
"2 how do i change my address to unit 9456 box 87... None \n",
|
||||
"3 how do i change my address to unit 9456 box 87... None \n",
|
||||
"4 how do i change my address to unit 9456 box 87... None \n",
|
||||
"... ... ... \n",
|
||||
"1224 My card 5115922521155230 is expiring this mont... None \n",
|
||||
"1225 My card 5115922521155230 is expiring this mont... None \n",
|
||||
"1226 As promised, here's Zlata's address:\\n\\n29 Rue... None \n",
|
||||
"1227 Follow up with Edward Baranova in 2 months. None \n",
|
||||
"1228 Follow up with Edward Baranova in 2 months. None \n",
|
||||
"\n",
|
||||
"[1229 rows x 6 columns]"
|
||||
]
|
||||
},
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"fps_df = ModelError.get_fps_dataframe(errors,entity='DATE_TIME')\n",
|
||||
"if fps_df is not None:\n",
|
||||
" fps_df[['full_text','token','prediction']]\n",
|
||||
"fps_df"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>error_type</th>\n",
|
||||
" <th>annotation</th>\n",
|
||||
" <th>prediction</th>\n",
|
||||
" <th>token</th>\n",
|
||||
" <th>full_text</th>\n",
|
||||
" <th>0</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>Wrong entity</td>\n",
|
||||
" <td>PHONE_NUMBER</td>\n",
|
||||
" <td>DATE_TIME</td>\n",
|
||||
" <td>0910</td>\n",
|
||||
" <td>Terry Cardoso PhD\\n\\n65 Bodbysund 61\\n Suite 5...</td>\n",
|
||||
" <td>None</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>Wrong entity</td>\n",
|
||||
" <td>PHONE_NUMBER</td>\n",
|
||||
" <td>DATE_TIME</td>\n",
|
||||
" <td>-</td>\n",
|
||||
" <td>Terry Cardoso PhD\\n\\n65 Bodbysund 61\\n Suite 5...</td>\n",
|
||||
" <td>None</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>Wrong entity</td>\n",
|
||||
" <td>PHONE_NUMBER</td>\n",
|
||||
" <td>DATE_TIME</td>\n",
|
||||
" <td>5877671</td>\n",
|
||||
" <td>Terry Cardoso PhD\\n\\n65 Bodbysund 61\\n Suite 5...</td>\n",
|
||||
" <td>None</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>Wrong entity</td>\n",
|
||||
" <td>PHONE_NUMBER</td>\n",
|
||||
" <td>DATE_TIME</td>\n",
|
||||
" <td>-</td>\n",
|
||||
" <td>Terry Cardoso PhD\\n\\n65 Bodbysund 61\\n Suite 5...</td>\n",
|
||||
" <td>None</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>Wrong entity</td>\n",
|
||||
" <td>PHONE_NUMBER</td>\n",
|
||||
" <td>DATE_TIME</td>\n",
|
||||
" <td>4466x8827</td>\n",
|
||||
" <td>Terry Cardoso PhD\\n\\n65 Bodbysund 61\\n Suite 5...</td>\n",
|
||||
" <td>None</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>...</th>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>532</th>\n",
|
||||
" <td>FN</td>\n",
|
||||
" <td>PHONE_NUMBER</td>\n",
|
||||
" <td>O</td>\n",
|
||||
" <td>81</td>\n",
|
||||
" <td>Kelly Björgvinsdóttir\\nAdaptive\\n63 Via Verban...</td>\n",
|
||||
" <td>None</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>533</th>\n",
|
||||
" <td>FN</td>\n",
|
||||
" <td>PHONE_NUMBER</td>\n",
|
||||
" <td>O</td>\n",
|
||||
" <td>21</td>\n",
|
||||
" <td>Laura Gorski\\nMinistry Of Agriculture\\n07 57 a...</td>\n",
|
||||
" <td>None</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>534</th>\n",
|
||||
" <td>FN</td>\n",
|
||||
" <td>PHONE_NUMBER</td>\n",
|
||||
" <td>O</td>\n",
|
||||
" <td>232</td>\n",
|
||||
" <td>Laura Gorski\\nMinistry Of Agriculture\\n07 57 a...</td>\n",
|
||||
" <td>None</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>535</th>\n",
|
||||
" <td>FN</td>\n",
|
||||
" <td>PHONE_NUMBER</td>\n",
|
||||
" <td>O</td>\n",
|
||||
" <td>945</td>\n",
|
||||
" <td>Laura Gorski\\nMinistry Of Agriculture\\n07 57 a...</td>\n",
|
||||
" <td>None</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>536</th>\n",
|
||||
" <td>FN</td>\n",
|
||||
" <td>PHONE_NUMBER</td>\n",
|
||||
" <td>O</td>\n",
|
||||
" <td>1338</td>\n",
|
||||
" <td>Laura Gorski\\nMinistry Of Agriculture\\n07 57 a...</td>\n",
|
||||
" <td>None</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"<p>537 rows × 6 columns</p>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" error_type annotation prediction token \\\n",
|
||||
"0 Wrong entity PHONE_NUMBER DATE_TIME 0910 \n",
|
||||
"1 Wrong entity PHONE_NUMBER DATE_TIME - \n",
|
||||
"2 Wrong entity PHONE_NUMBER DATE_TIME 5877671 \n",
|
||||
"3 Wrong entity PHONE_NUMBER DATE_TIME - \n",
|
||||
"4 Wrong entity PHONE_NUMBER DATE_TIME 4466x8827 \n",
|
||||
".. ... ... ... ... \n",
|
||||
"532 FN PHONE_NUMBER O 81 \n",
|
||||
"533 FN PHONE_NUMBER O 21 \n",
|
||||
"534 FN PHONE_NUMBER O 232 \n",
|
||||
"535 FN PHONE_NUMBER O 945 \n",
|
||||
"536 FN PHONE_NUMBER O 1338 \n",
|
||||
"\n",
|
||||
" full_text 0 \n",
|
||||
"0 Terry Cardoso PhD\\n\\n65 Bodbysund 61\\n Suite 5... None \n",
|
||||
"1 Terry Cardoso PhD\\n\\n65 Bodbysund 61\\n Suite 5... None \n",
|
||||
"2 Terry Cardoso PhD\\n\\n65 Bodbysund 61\\n Suite 5... None \n",
|
||||
"3 Terry Cardoso PhD\\n\\n65 Bodbysund 61\\n Suite 5... None \n",
|
||||
"4 Terry Cardoso PhD\\n\\n65 Bodbysund 61\\n Suite 5... None \n",
|
||||
".. ... ... \n",
|
||||
"532 Kelly Björgvinsdóttir\\nAdaptive\\n63 Via Verban... None \n",
|
||||
"533 Laura Gorski\\nMinistry Of Agriculture\\n07 57 a... None \n",
|
||||
"534 Laura Gorski\\nMinistry Of Agriculture\\n07 57 a... None \n",
|
||||
"535 Laura Gorski\\nMinistry Of Agriculture\\n07 57 a... None \n",
|
||||
"536 Laura Gorski\\nMinistry Of Agriculture\\n07 57 a... None \n",
|
||||
"\n",
|
||||
"[537 rows x 6 columns]"
|
||||
]
|
||||
},
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"fns_df = ModelError.get_fns_dataframe(errors,entity='PHONE_NUMBER')\n",
|
||||
"fns_df"
|
||||
"Evaluate Presidio Analyzer using the Presidio Evaluator framework"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ae85cae9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
"source": [
|
||||
"from pathlib import Path\n",
|
||||
"from copy import deepcopy\n",
|
||||
"from pprint import pprint\n",
|
||||
"from collections import Counter\n",
|
||||
"\n",
|
||||
"from presidio_evaluator import InputSample\n",
|
||||
"from presidio_evaluator.evaluation import Evaluator, ModelError\n",
|
||||
"from presidio_evaluator.models import PresidioAnalyzerWrapper\n",
|
||||
"from presidio_evaluator.experiment_tracking import get_experiment_tracker\n",
|
||||
"\n",
|
||||
"import pandas as pd\n",
|
||||
"\n",
|
||||
"pd.set_option('display.max_columns', None) \n",
|
||||
"pd.set_option('display.max_rows', None) \n",
|
||||
"pd.set_option('display.max_colwidth', None)\n",
|
||||
"\n",
|
||||
"%reload_ext autoreload\n",
|
||||
"%autoreload 2"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "736fdd23",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Select data for evaluation:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f4cbd55c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"dataset_name = \"synth_dataset_v2.json\"\n",
|
||||
"dataset = InputSample.read_dataset_json(Path(Path.cwd().parent, \"data\", dataset_name))\n",
|
||||
"print(len(dataset))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c164ea07",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"entity_counter = Counter()\n",
|
||||
"for sample in dataset:\n",
|
||||
" for tag in sample.tags:\n",
|
||||
" entity_counter[tag] += 1"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "77aedae6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(\"Count per entity:\")\n",
|
||||
"pprint(entity_counter.most_common())\n",
|
||||
"\n",
|
||||
"print(\"\\nExample sentence:\")\n",
|
||||
"print(dataset[1])\n",
|
||||
"\n",
|
||||
"print(\"\\nMin and max number of tokens in dataset:\")\n",
|
||||
"print(f\"Min: {min([len(sample.tokens) for sample in dataset])}, \" \\\n",
|
||||
" f\"Max: {max([len(sample.tokens) for sample in dataset])}\")\n",
|
||||
"\n",
|
||||
"print(\"\\nMin and max sentence length in dataset:\")\n",
|
||||
"print(f\"Min: {min([len(sample.full_text) for sample in dataset])}, \" \\\n",
|
||||
" f\"Max: {max([len(sample.full_text) for sample in dataset])}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "aae4c379",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Run evaluation:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "cf65af8f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(\"Evaluating Presidio Analyzer\")\n",
|
||||
"\n",
|
||||
"experiment = get_experiment_tracker()\n",
|
||||
"model_name = \"Presidio Analyzer\"\n",
|
||||
"model = PresidioAnalyzerWrapper()\n",
|
||||
"\n",
|
||||
"evaluator = Evaluator(model=model)\n",
|
||||
"dataset = Evaluator.align_entity_types(\n",
|
||||
" deepcopy(dataset), entities_mapping=PresidioAnalyzerWrapper.presidio_entities_map\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"evaluation_results = evaluator.evaluate_all(dataset)\n",
|
||||
"results = evaluator.calculate_score(evaluation_results)\n",
|
||||
"\n",
|
||||
"# update params tracking\n",
|
||||
"params = {\"dataset_name\": dataset_name, \"model_name\": model_name}\n",
|
||||
"params.update(model.to_log())\n",
|
||||
"experiment.log_parameters(params)\n",
|
||||
"experiment.log_dataset_hash(dataset)\n",
|
||||
"experiment.log_metrics(results.to_log())\n",
|
||||
"entities, confmatrix = results.to_confusion_matrix()\n",
|
||||
"experiment.log_confusion_matrix(matrix=confmatrix, labels=entities)\n",
|
||||
"\n",
|
||||
"print(\"Confusion matrix:\")\n",
|
||||
"print(pd.DataFrame(confmatrix, columns=entities, index=entities))\n",
|
||||
"\n",
|
||||
"print(\"Precision and recall\")\n",
|
||||
"print(results)\n",
|
||||
"\n",
|
||||
"# end experiment\n",
|
||||
"experiment.end()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "070f8287",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Results analysis"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "dd04db3e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sent = 'I am taiwanese but I live in Cambodia.'\n",
|
||||
"#sent = input(\"Enter sentence: \")\n",
|
||||
"model.predict(InputSample(full_text=sent))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "08ae9bda",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"errors = results.model_errors"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "819eb905",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### False positives"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "98f4802e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"1. Most false positive tokens:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "640037af",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ModelError.most_common_fp_tokens(errors)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "03432506",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"fps_df = ModelError.get_fps_dataframe(errors, entity=[\"LOCATION\"])\n",
|
||||
"fps_df[[\"full_text\", \"token\", \"prediction\"]]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "d0852513",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"2. False negative examples"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "afae40fc",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ModelError.most_common_fn_tokens(errors, n=50, entity=[\"PERSON\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "44ed6416",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"More FN analysis"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "7abfcbe9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"fns_df = ModelError.get_fns_dataframe(errors,entity=['PHONE_NUMBER'])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9ae73b2e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"fns_df[[\"full_text\", \"token\", \"annotation\", \"prediction\"]]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "24c8be1d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(\"All errors:\\n\")\n",
|
||||
"[print(error,\"\\n\") for error in errors]"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
@@ -718,5 +283,5 @@
|
|||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
"nbformat_minor": 5
|
||||
}
|
|
@@ -95,8 +95,8 @@
|
|||
"source": [
|
||||
"original_text = \"Hi my name is Doug Funny and this is my website: https://www.dougf.io\"\n",
|
||||
"\n",
|
||||
"presidio_response = analyzer.analyze(original_text,language='en')\n",
|
||||
"presidio_response\n"
|
||||
"presidio_response = analyzer.analyze(original_text, language=\"en\")\n",
|
||||
"presidio_response"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@@ -126,7 +126,9 @@
|
|||
"source": [
|
||||
"# Simple pseudonymization\n",
|
||||
"\n",
|
||||
"pseudonymizer.pseudonymize(original_text=original_text, presidio_response=presidio_response,count=5)"
|
||||
"pseudonymizer.pseudonymize(\n",
|
||||
" original_text=original_text, presidio_response=presidio_response, count=5\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@@ -159,13 +161,15 @@
|
|||
"\n",
|
||||
"text = \"Our son R2D2 used to work in Germany\"\n",
|
||||
"\n",
|
||||
"response = analyzer.analyze(text=text,language='en')\n",
|
||||
"response = analyzer.analyze(text=text, language=\"en\")\n",
|
||||
"print(f\"Presidio' response: {response}\")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"fake_samples = pseudonymizer.pseudonymize(original_text=text,presidio_response=response,count=5)\n",
|
||||
"fake_samples = pseudonymizer.pseudonymize(\n",
|
||||
" original_text=text, presidio_response=response, count=5\n",
|
||||
")\n",
|
||||
"print(f\"-------------\\nFake examples:\\n\")\n",
|
||||
"print(*fake_samples, sep = \"\\n\")"
|
||||
"print(*fake_samples, sep=\"\\n\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
|
|
@@ -28,6 +28,7 @@
|
|||
"outputs": [],
|
||||
"source": [
|
||||
"from presidio_evaluator import InputSample\n",
|
||||
"\n",
|
||||
"%reload_ext autoreload"
|
||||
]
|
||||
},
|
||||
|
@@ -41,7 +42,7 @@
|
|||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"DATA_DATE = 'Dec-19-2021'"
|
||||
"DATA_DATE = \"Dec-19-2021\""
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@@ -93,7 +94,7 @@
|
|||
"source": [
|
||||
"data_path = \"../../data/{}_{}.json\"\n",
|
||||
"\n",
|
||||
"train_samples = InputSample.read_dataset_json(data_path.format(\"train\",DATA_DATE))\n",
|
||||
"train_samples = InputSample.read_dataset_json(data_path.format(\"train\", DATA_DATE))\n",
|
||||
"print(\"Read {} samples\".format(len(train_samples)))"
|
||||
]
|
||||
},
|
||||
|
@@ -216,7 +217,9 @@
|
|||
}
|
||||
],
|
||||
"source": [
|
||||
"spacy_train = InputSample.create_spacy_dataset(dataset=train_tagged, output_path = \"train.spacy\")\n"
|
||||
"spacy_train = InputSample.create_spacy_dataset(\n",
|
||||
" dataset=train_tagged, output_path=\"train.spacy\"\n",
|
||||
")"
|
||||
]
|
||||
},
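To verify the emitted file, it can be read back with spaCy's DocBin API (a sketch; assumes spaCy v3):

import spacy
from spacy.tokens import DocBin

nlp = spacy.blank("en")
docs = list(DocBin().from_disk("train.spacy").get_docs(nlp.vocab))
print(f"{len(docs)} docs; entities in the first: {docs[0].ents}")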
|
||||
{
|
||||
|
@@ -225,7 +228,7 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"entities_spacy = [x[1]['entities'] for x in spacy_train]\n",
|
||||
"entities_spacy = [x[1][\"entities\"] for x in spacy_train]\n",
|
||||
"entities_spacy_flat = []\n",
|
||||
"for samp in entities_spacy:\n",
|
||||
" for ent in samp:\n",
|
||||
|
@@ -262,10 +265,10 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for fold in (\"test\",\"validation\"):\n",
|
||||
" dataset = InputSample.read_dataset_json(data_path.format(fold,DATA_DATE))\n",
|
||||
"for fold in (\"test\", \"validation\"):\n",
|
||||
" dataset = InputSample.read_dataset_json(data_path.format(fold, DATA_DATE))\n",
|
||||
" print(f\"Read {len(dataset)} samples for {fold}\")\n",
|
||||
" InputSample.create_spacy_dataset(dataset=dataset, output_path = f\"{fold}.spacy\")"
|
||||
" InputSample.create_spacy_dataset(dataset=dataset, output_path=f\"{fold}.spacy\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
|
@@ -2,37 +2,46 @@
|
|||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5f6d37dc",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Evaluate CRF models for person names, orgs and locations using the Presidio Evaluator framework"
|
||||
"Evaluate Conditional Random Field models using the Presidio Evaluator framework"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "51a9b95a",
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"is_executing": false
|
||||
"is_executing": true
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from pathlib import Path\n",
|
||||
"from copy import deepcopy\n",
|
||||
"from pprint import pprint\n",
|
||||
"from collections import Counter\n",
|
||||
"\n",
|
||||
"from presidio_evaluator import InputSample\n",
|
||||
"from presidio_evaluator.evaluation import Evaluator, ModelError\n",
|
||||
"import spacy\n",
|
||||
"from presidio_evaluator.experiment_tracking import get_experiment_tracker\n",
|
||||
"from presidio_evaluator.models import CRFModel\n",
|
||||
"\n",
|
||||
"import pandas as pd\n",
|
||||
"import pickle\n",
|
||||
"\n",
|
||||
"pd.set_option('display.width', 10000)\n",
|
||||
"pd.set_option('display.max_columns', None) \n",
|
||||
"pd.set_option('display.max_rows', None) \n",
|
||||
"pd.set_option('display.max_colwidth', None)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"%reload_ext autoreload\n",
|
||||
"%autoreload 2"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a0d2d772",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Select data for evaluation:"
|
||||
|
@@ -41,79 +50,65 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "29c21b97",
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"is_executing": false
|
||||
"is_executing": true
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"synth_samples = InputSample.read_dataset_json(\"../../data/synth_dataset.json\")\n",
|
||||
"print(len(synth_samples))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"DATASET = synth_samples"
|
||||
"DATA_DATE = \"Jan-15-2022\"\n",
|
||||
"dataset = InputSample.read_dataset_json(\"../../data/test_{}.json\".format(DATA_DATE))\n",
|
||||
"print(len(dataset))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "955614fe",
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"is_executing": false
|
||||
"is_executing": true
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from collections import Counter\n",
|
||||
"entity_counter = Counter()\n",
|
||||
"for sample in DATASET:\n",
|
||||
" for tag in sample.tags:\n",
|
||||
" entity_counter[tag]+=1"
|
||||
"for sample in dataset:\n",
|
||||
" for t>ag in sample.tags:\n",
|
||||
" entity_counter[tag] += 1"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f423493c",
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"is_executing": false
|
||||
"is_executing": true
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"entity_counter"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"is_executing": false
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"DATASET[1]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"is_executing": false
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#max length sentence\n",
|
||||
"max([len(sample.tokens) for sample in DATASET])\n"
|
||||
"print(\"Count per entity:\")\n",
|
||||
"pprint(entity_counter.most_common())\n",
|
||||
"\n",
|
||||
"print(\"\\nExample sentence:\")\n",
|
||||
"print(dataset[1])\n",
|
||||
"\n",
|
||||
"print(\"\\nMin and max number of tokens in dataset:\")\n",
|
||||
"print(f\"Min: {min([len(sample.tokens) for sample in dataset])}, \" \\\n",
|
||||
" f\"Max: {max([len(sample.tokens) for sample in dataset])}\")\n",
|
||||
"\n",
|
||||
"print(\"\\nMin and max sentence length in dataset:\")\n",
|
||||
"print(f\"Min: {min([len(sample.full_text) for sample in dataset])}, \" \\\n",
|
||||
" f\"Max: {max([len(sample.full_text) for sample in dataset])}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4523af6e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Select models for evaluation:"
|
||||
|
@ -122,20 +117,22 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c6cbb5ba",
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"is_executing": false
|
||||
"is_executing": true
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"crf_vanilla = \"../../models/crf.pickle\"\n",
|
||||
" \n",
|
||||
"# Assuming there exists a trained CRF model\n",
|
||||
"models = [crf_vanilla]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f683700f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Run evaluation on all models:"
|
||||
|
@ -144,74 +141,88 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f987a404",
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"is_executing": false
|
||||
"is_executing": true
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from presidio_evaluator.models import CRFModel\n",
|
||||
"\n",
|
||||
"for model in models:\n",
|
||||
"for model_path in models:\n",
|
||||
" print(\"-----------------------------------\")\n",
|
||||
" print(\"Evaluating model {}\".format(model))\n",
|
||||
" crf_model = CRFModel(model_pickle_path=model)\n",
|
||||
" evaluator = Evaluator(model=crf_model)\n",
|
||||
" evaluation_results = evaluator.evaluate_all(DATASET)\n",
|
||||
" scores = evaluator.calculate_score(evaluation_results)\n",
|
||||
" \n",
|
||||
" print(f\"Evaluating model {model_path}\")\n",
|
||||
" experiment = get_experiment_tracker()\n",
|
||||
"\n",
|
||||
" model = CRFModel(model_pickle_path=model_path)\n",
|
||||
" evaluator = Evaluator(model=model)\n",
|
||||
" evaluation_results = evaluator.evaluate_all(deepcopy(dataset))\n",
|
||||
" results = evaluator.calculate_score(evaluation_results)\n",
|
||||
"\n",
|
||||
" # update params tracking\n",
|
||||
" params = {\"dataset_name\":dataset_name, \"model_name\": model_path}\n",
|
||||
" params.update(model.to_log())\n",
|
||||
" experiment.log_parameters(params)\n",
|
||||
" experiment.log_dataset_hash(dataset)\n",
|
||||
" experiment.log_metrics(results.to_log())\n",
|
||||
" entities, confmatrix = results.to_confusion_matrix()\n",
|
||||
" experiment.log_confusion_matrix(matrix=confmatrix, labels=entities)\n",
|
||||
"\n",
|
||||
" print(\"Confusion matrix:\")\n",
|
||||
" print(scores.results)\n",
|
||||
" print(pd.DataFrame(confmatrix, columns=entities, index=entities))\n",
|
||||
"\n",
|
||||
" print(\"Precision and recall\")\n",
|
||||
" scores.print()"
|
||||
" print(results)\n",
|
||||
"\n",
|
||||
" # end experiment\n",
|
||||
" experiment.end()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4c35e63e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Custom evaluation of the model"
|
||||
"### Results analysis"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"id": "f8c53388",
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"is_executing": true
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Try out the model\n",
|
||||
"def sent_to_features(model_path,sent):\n",
|
||||
" \"\"\"\n",
|
||||
" Translates a sentence into a prediction using a saved CRF model\n",
|
||||
" \"\"\"\n",
|
||||
" \n",
|
||||
" with open(model_path, 'rb') as f:\n",
|
||||
" model = pickle.load(f)\n",
|
||||
" \n",
|
||||
" tokenizer = spacy.blank('en')\n",
|
||||
" tokens = tokenizer(sent)\n",
|
||||
" tags = ['O' for token in tokens] # Placeholder: Not used but required. \n",
|
||||
" metadata = {'template_id':1,'Gender':'1','Country':'2'} #Placeholder: Not used but required\n",
|
||||
" input_sample = InputSample(full_text=sent,masked=\"\",spans=None,tokens=tokens,tags=tags,metadata=metadata,create_tags_from_span=False,)\n",
|
||||
"\n",
|
||||
" return CRFModel.crf_predict(input_sample, model)"
|
||||
"sent = 'I am taiwanese but I live in Cambodia.'\n",
|
||||
"#sent = input(\"Enter sentence: \")\n",
|
||||
"model.predict(InputSample(full_text=sent))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "7dde59a3",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Error Analysis"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9b55fc5e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"SENTENCE = \"Michael is American\"\n",
|
||||
"\n",
|
||||
"sent_to_features(model_path=crf_vanilla, sent=SENTENCE)"
|
||||
"errors = results.model_errors"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "179f1d88",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### False positives"
|
||||
|
@ -219,6 +230,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "478bd674",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"1. Most false positive tokens:"
|
||||
|
@ -227,61 +239,45 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"is_executing": false
|
||||
}
|
||||
},
|
||||
"id": "e2bcb56d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"errors = scores.model_errors\n",
|
||||
"\n",
|
||||
"ModelError.most_common_fp_tokens(errors)#[model_error for model_error in errors if model_error.error_type =='FP']\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"2. review false positives for entity 'PERSON'"
|
||||
"ModelError.most_common_fp_tokens(errors)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"is_executing": false
|
||||
}
|
||||
},
|
||||
"id": "e7f222d7",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"fps_df = ModelError.get_fps_dataframe(errors,entity='PERSON')\n",
|
||||
"fps_df[['full_text','token','prediction']]"
|
||||
"fps_df = ModelError.get_fps_dataframe(errors, entity=[\"GPE\"])\n",
|
||||
"fps_df[[\"full_text\", \"token\", \"prediction\"]]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a85823e8",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### False negative examples"
|
||||
"2. False negative examples"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"is_executing": false
|
||||
}
|
||||
},
|
||||
"id": "600d5e62",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ModelError.most_common_fn_tokens(errors,n=50, entity='PERSON')"
|
||||
"ModelError.most_common_fn_tokens(errors, n=50, entity=[\"PERSON\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5b38b31d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"More FN analysis"
|
||||
|
@ -290,32 +286,38 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"is_executing": false
|
||||
}
|
||||
},
|
||||
"id": "7f9f2ee6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"fns_df = ModelError.get_fns_dataframe(errors,entity='PERSON')"
|
||||
"fns_df = ModelError.get_fns_dataframe(errors, entity=['GPE'])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"is_executing": false
|
||||
}
|
||||
},
|
||||
"id": "59aab6fa",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"fns_df[['full_text','token','annotation','prediction']]"
|
||||
"fns_df[[\"full_text\", \"token\", \"annotation\", \"prediction\"]]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b793a561",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(\"All errors:\\n\")\n",
|
||||
"[print(error,\"\\n\") for error in errors]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "82aec145",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
|
@ -341,5 +343,5 @@
|
|||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
"nbformat_minor": 5
|
||||
}
|
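The evaluation loop above touches only a small tracker surface: `log_parameters`, `log_dataset_hash`, `log_metrics`, `log_confusion_matrix`, and `end`. Below is a minimal stand-in that satisfies those calls; the real tracker lives in `presidio_evaluator.experiment_tracking`, and the output file name here is an assumption for illustration only:

```python
import json
from typing import Dict, List


class ConsoleExperimentTracker:
    """Hypothetical drop-in tracker: collects everything in memory, dumps JSON on end()."""

    def __init__(self):
        self.data = {"params": {}, "metrics": {}}

    def log_parameters(self, params: Dict):
        self.data["params"].update(params)

    def log_dataset_hash(self, dataset):
        # hash of the string form; the real tracker may fingerprint differently
        self.data["params"]["dataset_hash"] = hash(str(dataset))

    def log_metrics(self, metrics: Dict):
        self.data["metrics"].update(metrics)

    def log_confusion_matrix(self, matrix: List[List[int]], labels: List[str]):
        self.data["confusion_matrix"] = {"labels": labels, "matrix": matrix}

    def end(self):
        with open("experiment.json", "w") as f:
            json.dump(self.data, f, indent=2, default=str)
```
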
|
@ -2,6 +2,7 @@
|
|||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "097c5857",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Evaluate Flair models for person names, orgs and locations using the Presidio Evaluator framework"
|
||||
|
@ -10,21 +11,33 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"is_executing": false
|
||||
}
|
||||
},
|
||||
"id": "8835e626",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from presidio_evaluator.evaluation import Evaluator, ModelError\n",
|
||||
"from pathlib import Path\n",
|
||||
"from copy import deepcopy\n",
|
||||
"from pprint import pprint\n",
|
||||
"from collections import Counter\n",
|
||||
"\n",
|
||||
"from presidio_evaluator import InputSample\n",
|
||||
"from presidio_evaluator.evaluation import Evaluator, ModelError\n",
|
||||
"from presidio_evaluator.experiment_tracking import get_experiment_tracker\n",
|
||||
"from presidio_evaluator.models import FlairModel\n",
|
||||
"\n",
|
||||
"import pandas as pd\n",
|
||||
"\n",
|
||||
"pd.set_option('display.max_columns', None) \n",
|
||||
"pd.set_option('display.max_rows', None) \n",
|
||||
"pd.set_option('display.max_colwidth', None)\n",
|
||||
"\n",
|
||||
"%reload_ext autoreload\n",
|
||||
"%autoreload 2"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f036de59",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Select data for evaluation:"
|
||||
|
@ -33,107 +46,53 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"is_executing": false
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"synth_samples = InputSample.read_dataset_json(\"../../data/synth_dataset.json\")\n",
|
||||
"print(len(synth_samples))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"is_executing": false
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"Map entity types"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3d8f14ab",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"presidio_entities_map = {\n",
|
||||
" \"PERSON\": \"PER\",\n",
|
||||
" \"EMAIL_ADDRESS\": \"O\",\n",
|
||||
" \"CREDIT_CARD\": \"O\",\n",
|
||||
" \"FIRST_NAME\": \"PER\",\n",
|
||||
" \"PHONE_NUMBER\": \"O\",\n",
|
||||
" \"BIRTHDAY\": \"O\",\n",
|
||||
" \"DATE_TIME\": \"O\",\n",
|
||||
" \"DOMAIN_NAME\": \"O\",\n",
|
||||
" \"CITY\": \"LOC\",\n",
|
||||
" \"ADDRESS\": \"LOC\",\n",
|
||||
" \"NATIONALITY\": \"LOC\",\n",
|
||||
" \"LOCATION\": \"LOC\",\n",
|
||||
" \"IBAN_CODE\": \"O\",\n",
|
||||
" \"US_DRIVER_LICENSE\": \"O\",\n",
|
||||
" \"URL\": \"O\",\n",
|
||||
" \"US_SSN\": \"O\",\n",
|
||||
" \"IP_ADDRESS\": \"O\",\n",
|
||||
" \"ORGANIZATION\": \"ORG\",\n",
|
||||
" \"TITLE\" : \"O\", # skipping evaluation of titles\n",
|
||||
" \"PREFIX\" : \"O\",\n",
|
||||
" \"O\": \"O\",\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"synth_samples = Evaluator.align_entity_types(synth_samples, presidio_entities_map)"
|
||||
"dataset_name = \"synth_dataset_v2.json\"\n",
|
||||
"dataset = InputSample.read_dataset_json(Path(Path.cwd().parent.parent, \"data\", dataset_name))\n",
|
||||
"print(len(dataset))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"is_executing": false
|
||||
}
|
||||
},
|
||||
"id": "7605f540",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from collections import Counter\n",
|
||||
"entity_counter = Counter()\n",
|
||||
"for sample in synth_samples:\n",
|
||||
"for sample in dataset:\n",
|
||||
" for tag in sample.tags:\n",
|
||||
" entity_counter[tag]+=1"
|
||||
" entity_counter[tag] += 1"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"is_executing": false
|
||||
}
|
||||
},
|
||||
"id": "5c693b46",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"entity_counter"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"is_executing": false
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#max length sentence\n",
|
||||
"max([len(sample.tokens) for sample in synth_samples])"
|
||||
"print(\"Count per entity:\")\n",
|
||||
"pprint(entity_counter.most_common())\n",
|
||||
"\n",
|
||||
"print(\"\\nExample sentence:\")\n",
|
||||
"print(dataset[1])\n",
|
||||
"\n",
|
||||
"print(\"\\nMin and max number of tokens in dataset:\")\n",
|
||||
"print(f\"Min: {min([len(sample.tokens) for sample in dataset])}, \" \\\n",
|
||||
" f\"Max: {max([len(sample.tokens) for sample in dataset])}\")\n",
|
||||
"\n",
|
||||
"print(\"\\nMin and max sentence length in dataset:\")\n",
|
||||
"print(f\"Min: {min([len(sample.full_text) for sample in dataset])}, \" \\\n",
|
||||
" f\"Max: {max([len(sample.full_text) for sample in dataset])}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "19920db4",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Select models for evaluation:"
|
||||
|
@ -142,57 +101,102 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"is_executing": false
|
||||
}
|
||||
},
|
||||
"id": "89cf0e43",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"flair_ner = 'ner'\n",
|
||||
"flair_ner_fast = 'ner-fast'\n",
|
||||
"flair_ontonotes = 'ner-ontonotes-fast'\n",
|
||||
"models = [flair_ner, flair_ner_fast]"
|
||||
"flair_ner = \"ner-english\"\n",
|
||||
"flair_ner_fast = \"ner-english-fast\"\n",
|
||||
"flair_ontonotes_fast = \"ner-english-ontonotes-fast\"\n",
|
||||
"flair_ontonotes_large = \"ner-english-ontonotes-large\"\n",
|
||||
"models = [flair_ner, flair_ner_fast, flair_ontonotes_fast ,flair_ner_fast, flair_ontonotes_large]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "266b59ca",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Run evaluation on all models:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"is_executing": true
|
||||
}
|
||||
},
|
||||
"id": "21d1e3dd",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from presidio_evaluator.models import FlairModel\n",
|
||||
"\n",
|
||||
"for model in models:\n",
|
||||
"for model_name in models:\n",
|
||||
" print(\"-----------------------------------\")\n",
|
||||
" print(\"Evaluating model {}\".format(model))\n",
|
||||
" flair_model = FlairModel(model_path=model)\n",
|
||||
" evaluator = Evaluator(model=flair_model)\n",
|
||||
" evaluation_results = evaluator.evaluate_all(synth_samples)\n",
|
||||
" scores = evaluator.calculate_score(evaluation_results)\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" print(f\"Evaluating model {model_name}\")\n",
|
||||
" experiment = get_experiment_tracker()\n",
|
||||
"\n",
|
||||
" model = FlairModel(model_path=model_name)\n",
|
||||
" evaluator = Evaluator(model=model)\n",
|
||||
" evaluation_results = evaluator.evaluate_all(deepcopy(dataset))\n",
|
||||
" results = evaluator.calculate_score(evaluation_results)\n",
|
||||
"\n",
|
||||
" # update params tracking\n",
|
||||
" params = {\"dataset_name\":dataset_name, \"model_name\": model_name}\n",
|
||||
" params.update(model.to_log())\n",
|
||||
" experiment.log_parameters(params)\n",
|
||||
" experiment.log_dataset_hash(dataset)\n",
|
||||
" experiment.log_metrics(results.to_log())\n",
|
||||
" entities, confmatrix = results.to_confusion_matrix()\n",
|
||||
" experiment.log_confusion_matrix(matrix=confmatrix, labels=entities)\n",
|
||||
"\n",
|
||||
" print(\"Confusion matrix:\")\n",
|
||||
" print(scores.results)\n",
|
||||
" print(pd.DataFrame(confmatrix, columns=entities, index=entities))\n",
|
||||
"\n",
|
||||
" print(\"Precision and recall\")\n",
|
||||
" scores.print()\n",
|
||||
" errors = scores.model_errors\n"
|
||||
" print(results)\n",
|
||||
"\n",
|
||||
" # end experiment\n",
|
||||
" experiment.end()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0fd5e41b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Custom evaluation"
|
||||
"### Results analysis"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3ebe0c5e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sent = 'I am taiwanese but I live in Cambodia.'\n",
|
||||
"#sent = input(\"Enter sentence: \")\n",
|
||||
"model.predict(InputSample(full_text=sent))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "aa0a5160",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Error Analysis"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "77556203",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"errors = results.model_errors"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0d895d3c",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### False positives"
|
||||
|
@ -200,6 +204,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "fdd63274",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"1. Most false positive tokens:"
|
||||
|
@ -208,34 +213,27 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"is_executing": false
|
||||
}
|
||||
},
|
||||
"id": "5616bf20",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"errors = scores.model_errors\n",
|
||||
"\n",
|
||||
"ModelError.most_common_fp_tokens(errors)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"is_executing": false
|
||||
}
|
||||
},
|
||||
"id": "fcd4918a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"fps_df = ModelError.get_fps_dataframe(errors,entity=['PERSON'])\n",
|
||||
"fps_df[['full_text','token','prediction']]"
|
||||
"fps_df = ModelError.get_fps_dataframe(errors, entity=[\"GPE\"])\n",
|
||||
"fps_df[[\"full_text\", \"token\", \"prediction\"]]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f8e875a2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"2. False negative examples"
|
||||
|
@ -244,18 +242,17 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"is_executing": false
|
||||
}
|
||||
},
|
||||
"id": "f2826099",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ModelError.most_common_fn_tokens(errors,n=50, entity=['PER'])"
|
||||
"errors = scores.model_errors\n",
|
||||
"ModelError.most_common_fn_tokens(errors, n=50, entity=[\"PERSON\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e37fd0cf",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"More FN analysis"
|
||||
|
@ -264,33 +261,38 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"is_executing": false
|
||||
}
|
||||
},
|
||||
"id": "266fb7d1",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"fns_df = ModelError.get_fns_dataframe(errors,entity=['PERSON'])"
|
||||
"fns_df = ModelError.get_fns_dataframe(errors, entity=['GPE'])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"is_executing": false,
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
"id": "35a11568",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"fns_df[['full_text','token','annotation','prediction']]"
|
||||
"fns_df[[\"full_text\", \"token\", \"annotation\", \"prediction\"]]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "70a5c832",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(\"All errors:\\n\")\n",
|
||||
"[print(error,\"\\n\") for error in errors]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "fda55c14",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
|
@ -313,17 +315,8 @@
|
|||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.12"
|
||||
},
|
||||
"pycharm": {
|
||||
"stem_cell": {
|
||||
"cell_type": "raw",
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"source": []
|
||||
}
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
|
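`FlairModel` is constructed above from a model name string; Flair resolves such names through `SequenceTagger.load`, downloading the weights on first use. A short sketch of tagging with Flair directly, outside the evaluator (the sentence is illustrative):

```python
from flair.data import Sentence
from flair.models import SequenceTagger

# "ner-english-fast" is one of the names listed in the notebook above
tagger = SequenceTagger.load("ner-english-fast")

sentence = Sentence("George Washington went to Washington.")
tagger.predict(sentence)
for entity in sentence.get_spans("ner"):
    print(entity)
```
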
|
|
@ -10,20 +10,28 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"is_executing": false
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from pathlib import Path\n",
|
||||
"from copy import deepcopy\n",
|
||||
"from pprint import pprint\n",
|
||||
"from collections import Counter\n",
|
||||
"\n",
|
||||
"from presidio_evaluator import InputSample\n",
|
||||
"from presidio_evaluator.evaluation import Evaluator, ModelError\n",
|
||||
"from presidio_evaluator.models import SpacyModel\n",
|
||||
"from presidio_evaluator.experiment_tracking import get_experiment_tracker\n",
|
||||
"\n",
|
||||
"import pandas as pd\n",
|
||||
"import spacy\n",
|
||||
"\n",
|
||||
"from presidio_evaluator.evaluation import Evaluator, ModelError\n",
|
||||
"from presidio_evaluator import InputSample\n",
|
||||
"pd.set_option(\"display.max_columns\", None)\n",
|
||||
"pd.set_option(\"display.max_rows\", None)\n",
|
||||
"pd.set_option(\"display.max_colwidth\", None)\n",
|
||||
"\n",
|
||||
"%reload_ext autoreload\n",
|
||||
"%autoreload 2\n",
|
||||
"\n"
|
||||
"%autoreload 2"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -41,73 +49,51 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"is_executing": false
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"synth_samples = InputSample.read_dataset_json(\"../../data/synth_dataset.json\")\n",
|
||||
"print(len(synth_samples))\n",
|
||||
"DATASET = synth_samples"
|
||||
"dataset_name = \"synth_dataset_v2.json\"\n",
|
||||
"dataset = InputSample.read_dataset_json(\n",
|
||||
" Path(Path.cwd().parent.parent, \"data\", dataset_name)\n",
|
||||
")\n",
|
||||
"print(len(dataset))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"is_executing": false
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from collections import Counter\n",
|
||||
"entity_counter = Counter()\n",
|
||||
"for sample in DATASET:\n",
|
||||
" for span in sample.spans:\n",
|
||||
" entity_counter[span.entity_type]+=1"
|
||||
"for sample in dataset:\n",
|
||||
" for tag in sample.tags:\n",
|
||||
" entity_counter[tag] += 1"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"is_executing": false
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"entity_counter"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"is_executing": false
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"DATASET[1]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"is_executing": false
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#max length sentence\n",
|
||||
"max([len(sample.tokens) for sample in DATASET])"
|
||||
"print(\"Count per entity:\")\n",
|
||||
"pprint(entity_counter.most_common())\n",
|
||||
"\n",
|
||||
"print(\"\\nExample sentence:\")\n",
|
||||
"print(dataset[1])\n",
|
||||
"\n",
|
||||
"print(\"\\nMin and max number of tokens in dataset:\")\n",
|
||||
"print(\n",
|
||||
" f\"Min: {min([len(sample.tokens) for sample in dataset])}, \"\n",
|
||||
" f\"Max: {max([len(sample.tokens) for sample in dataset])}\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(\"\\nMin and max sentence length in dataset:\")\n",
|
||||
"print(\n",
|
||||
" f\"Min: {min([len(sample.full_text) for sample in dataset])}, \"\n",
|
||||
" f\"Max: {max([len(sample.full_text) for sample in dataset])}\"\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -120,14 +106,10 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"is_executing": false
|
||||
}
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"models = [\"en_core_web_lg\", \"en_core_web_trf\"]"
|
||||
"models = [\"en_core_web_sm\", \"en_core_web_lg\", \"en_core_web_trf\"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -147,24 +129,37 @@
|
|||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from presidio_evaluator.models import SpacyModel\n",
|
||||
"for model_name in models:\n",
|
||||
" experiment = get_experiment_tracker()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"for model in models:\n",
|
||||
" print(\"-----------------------------------\")\n",
|
||||
" print(\"Evaluating model {}\".format(model))\n",
|
||||
" nlp = spacy.load(model)\n",
|
||||
" spacy_model = SpacyModel(model=nlp,entities_to_keep=['PERSON', 'GPE', 'ORG'])\n",
|
||||
" evaluator = Evaluator(model=spacy_model)\n",
|
||||
" evaluation_results = evaluator.evaluate_all(DATASET)\n",
|
||||
" scores = evaluator.calculate_score(evaluation_results)\n",
|
||||
" \n",
|
||||
" print(f\"Evaluating model {model_name}\")\n",
|
||||
"\n",
|
||||
" nlp = spacy.load(model_name)\n",
|
||||
" model = SpacyModel(\n",
|
||||
" model=nlp, entities_to_keep=[\"PERSON\", \"GPE\", \"ORG\", \"NORP\"]\n",
|
||||
" )\n",
|
||||
" evaluator = Evaluator(model=model)\n",
|
||||
" evaluation_results = evaluator.evaluate_all(deepcopy(dataset))\n",
|
||||
" results = evaluator.calculate_score(evaluation_results)\n",
|
||||
"\n",
|
||||
" # update params tracking\n",
|
||||
" params = {\"dataset_name\":dataset_name, \"model_name\": model_name}\n",
|
||||
" params.update(model.to_log())\n",
|
||||
" experiment.log_parameters(params)\n",
|
||||
" experiment.log_dataset_hash(dataset)\n",
|
||||
" experiment.log_metrics(results.to_log())\n",
|
||||
" entities, confmatrix = results.to_confusion_matrix()\n",
|
||||
" experiment.log_confusion_matrix(matrix=confmatrix, labels=entities)\n",
|
||||
"\n",
|
||||
" print(\"Confusion matrix:\")\n",
|
||||
" print(scores.results)\n",
|
||||
" print(pd.DataFrame(confmatrix, columns=entities, index=entities))\n",
|
||||
"\n",
|
||||
" print(\"Precision and recall\")\n",
|
||||
" scores.print()\n",
|
||||
" errors = scores.model_errors"
|
||||
" print(results)\n",
|
||||
"\n",
|
||||
" # end experiment\n",
|
||||
" experiment.end()"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -179,13 +174,13 @@
|
|||
"execution_count": null,
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"is_executing": false
|
||||
"is_executing": true
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#evaluate custom sentences (if exists)\n",
|
||||
"#nlp = spacy.load(spacy_ft_100)\n"
|
||||
"# evaluate custom sentences (if exists)\n",
|
||||
"# nlp = spacy.load(spacy_ft_100)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -200,16 +195,23 @@
|
|||
"execution_count": null,
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"is_executing": false
|
||||
"is_executing": true
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sent = 'I am taiwanese but I live in Cambodia.'\n",
|
||||
"#sent = input(\"Enter sentence: \")\n",
|
||||
"sent = 'David is talking loudly'\n",
|
||||
"doc = nlp(sent)\n",
|
||||
"for ent in doc.ents:\n",
|
||||
" print(\"Entity = {} value = {}\".format(ent.label_,ent.text))"
|
||||
"model.predict(InputSample(full_text=sent))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"errors = results.model_errors"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -231,7 +233,7 @@
|
|||
"execution_count": null,
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"is_executing": false
|
||||
"is_executing": true
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
|
@ -244,13 +246,13 @@
|
|||
"execution_count": null,
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"is_executing": false
|
||||
"is_executing": true
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"fps_df = ModelError.get_fps_dataframe(errors,entity=['GPE'])\n",
|
||||
"fps_df[['full_text','token','prediction']]"
|
||||
"fps_df = ModelError.get_fps_dataframe(errors, entity=[\"GPE\"])\n",
|
||||
"fps_df[[\"full_text\", \"token\", \"prediction\"]]"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -265,13 +267,13 @@
|
|||
"execution_count": null,
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"is_executing": false
|
||||
"is_executing": true
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"errors = scores.model_errors\n",
|
||||
"ModelError.most_common_fn_tokens(errors,n=50, entity=['PERSON'])"
|
||||
"\n",
|
||||
"ModelError.most_common_fn_tokens(errors, n=50, entity=[\"PERSON\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -286,12 +288,12 @@
|
|||
"execution_count": null,
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"is_executing": false
|
||||
"is_executing": true
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"fns_df = ModelError.get_fns_dataframe(errors,entity=['GPE'])"
|
||||
"fns_df = ModelError.get_fns_dataframe(errors, entity=[\"GPE\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -299,28 +301,36 @@
|
|||
"execution_count": null,
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"is_executing": false
|
||||
"is_executing": true
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"fns_df[['full_text','token','annotation','prediction']]"
|
||||
"fns_df[[\"full_text\", \"token\", \"annotation\", \"prediction\"]]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"is_executing": true
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(\"All errors:\\n\")\n",
|
||||
"[print(error,\"\\n\") for error in errors]"
|
||||
"[print(error, \"\\n\") for error in errors]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"is_executing": true
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
|
|
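The spaCy notebook assumes the listed pipelines are already installed; `spacy.load` raises `OSError` otherwise. A small convenience sketch that fetches missing pipelines programmatically (`spacy.cli.download` is the same entry point as `python -m spacy download`):

```python
import spacy

for model_name in ["en_core_web_sm", "en_core_web_lg", "en_core_web_trf"]:
    try:
        spacy.load(model_name)
    except OSError:
        # Not installed yet: download once, then it loads normally
        spacy.cli.download(model_name)
```
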
|
@ -0,0 +1,300 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2c0696fc",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Evaluate Stanza models for person names, orgs and locations using the Presidio Evaluator framework"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "15ba6110",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from pathlib import Path\n",
|
||||
"from copy import deepcopy\n",
|
||||
"from pprint import pprint\n",
|
||||
"from collections import Counter\n",
|
||||
"\n",
|
||||
"from presidio_evaluator import InputSample\n",
|
||||
"from presidio_evaluator.evaluation import Evaluator, ModelError\n",
|
||||
"from presidio_evaluator.models import StanzaModel\n",
|
||||
"from presidio_evaluator.experiment_tracking import get_experiment_tracker\n",
|
||||
"\n",
|
||||
"import pandas as pd\n",
|
||||
"import spacy\n",
|
||||
"\n",
|
||||
"pd.set_option(\"display.max_columns\", None)\n",
|
||||
"pd.set_option(\"display.max_rows\", None)\n",
|
||||
"pd.set_option(\"display.max_colwidth\", None)\n",
|
||||
"\n",
|
||||
"%reload_ext autoreload\n",
|
||||
"%autoreload 2"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "d57f0008",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Select data for evaluation:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ea4440e0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"dataset_name = \"synth_dataset_v2.json\"\n",
|
||||
"dataset = InputSample.read_dataset_json(Path(Path.cwd().parent.parent, \"data\", dataset_name))\n",
|
||||
"print(len(dataset))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "0fa78fe9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"entity_counter = Counter()\n",
|
||||
"for sample in dataset:\n",
|
||||
" for tag in sample.tags:\n",
|
||||
" entity_counter[tag] += 1"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "91a91b4f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(\"Count per entity:\")\n",
|
||||
"pprint(entity_counter.most_common())\n",
|
||||
"\n",
|
||||
"print(\"\\nExample sentence:\")\n",
|
||||
"print(dataset[1])\n",
|
||||
"\n",
|
||||
"print(\"\\nMin and max number of tokens in dataset:\")\n",
|
||||
"print(f\"Min: {min([len(sample.tokens) for sample in dataset])}, \" \\\n",
|
||||
" f\"Max: {max([len(sample.tokens) for sample in dataset])}\")\n",
|
||||
"\n",
|
||||
"print(\"\\nMin and max sentence length in dataset:\")\n",
|
||||
"print(f\"Min: {min([len(sample.full_text) for sample in dataset])}, \" \\\n",
|
||||
" f\"Max: {max([len(sample.full_text) for sample in dataset])}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "d2065a27",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Select models for evaluation:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f323b611",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"models = [\"en\"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ffd65bf9",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Run evaluation on all models:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "bdeafa78",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for model_name in models:\n",
|
||||
" experiment = get_experiment_tracker()\n",
|
||||
" print(\"-----------------------------------\")\n",
|
||||
" print(f\"Evaluating model {model_name}\")\n",
|
||||
" \n",
|
||||
" model = StanzaModel(model_name=model_name, entities_to_keep=['PERSON', 'GPE', 'ORG', 'NORP'])\n",
|
||||
" evaluator = Evaluator(model=model)\n",
|
||||
" evaluation_results = evaluator.evaluate_all(deepcopy(dataset))\n",
|
||||
" results = evaluator.calculate_score(evaluation_results)\n",
|
||||
"\n",
|
||||
" # update params tracking\n",
|
||||
" params = {\"dataset_name\":dataset_name, \"model_name\": model_name}\n",
|
||||
" params.update(model.to_log())\n",
|
||||
" experiment.log_parameters(params)\n",
|
||||
" experiment.log_dataset_hash(dataset)\n",
|
||||
" experiment.log_metrics(results.to_log())\n",
|
||||
" entities, confmatrix = results.to_confusion_matrix()\n",
|
||||
" experiment.log_confusion_matrix(matrix=confmatrix, labels=entities)\n",
|
||||
"\n",
|
||||
" print(\"Confusion matrix:\")\n",
|
||||
" print(pd.DataFrame(confmatrix, columns=entities, index=entities))\n",
|
||||
"\n",
|
||||
" print(\"Precision and recall\")\n",
|
||||
" print(results)\n",
|
||||
"\n",
|
||||
" # end experiment\n",
|
||||
" experiment.end()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "cb01101d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Results analysis"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "5f407b40",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sent = 'I am taiwanese but I live in Cambodia.'\n",
|
||||
"#sent = input(\"Enter sentence: \")\n",
|
||||
"model.predict(InputSample(full_text=sent))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b8c1e391",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### False positives"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5ce4d351",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"1. Most false positive tokens:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6cd00bec",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ModelError.most_common_fp_tokens(errors=results.model_errors)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "82c5aca4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"fps_df = ModelError.get_fps_dataframe(errors=results.model_errors, entity=[\"NORP\"])\n",
|
||||
"fps_df[[\"full_text\", \"token\", \"prediction\"]]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4b5879f3",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"2. False negative examples"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "340e5509",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ModelError.most_common_fn_tokens(errors=results.model_errors, n=50, entity=[\"PERSON\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1e2d693c",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"More FN analysis"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "bf3c08ed",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"fns_df = ModelError.get_fns_dataframe(errors=results.model_errors,entity=['GPE'])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8ad26e71",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"fns_df[[\"full_text\", \"token\", \"annotation\", \"prediction\"]]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "7b481676",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(\"All errors:\\n\")\n",
|
||||
"[print(error,\"\\n\") for error in results.model_errors]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9399e426",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "presidio",
|
||||
"language": "python",
|
||||
"name": "presidio"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
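`StanzaModel` above is built from the bare language code `"en"`, which Stanza resolves to a downloadable pipeline. Fetching it once up front avoids a resource error at evaluation time; a sketch of direct use, assuming the `stanza` package is installed:

```python
import stanza

# One-time download of the English pipeline referenced by StanzaModel
stanza.download("en")

# Direct NER, independent of the evaluator, to eyeball the raw output
nlp = stanza.Pipeline("en", processors="tokenize,ner")
doc = nlp("Angela Merkel visited Microsoft in Seattle.")
print([(ent.text, ent.type) for ent in doc.ents])
```
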
|
@ -32,7 +32,7 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"DATA_DATE = \"Dec-22-2021\""
|
||||
"DATA_DATE = \"Jan-15-2022\""
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -52,8 +52,12 @@
|
|||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_samples = InputSample.read_dataset_json(\"../../data/train_{}.json\".format(DATA_DATE))\n",
|
||||
"test_samples = InputSample.read_dataset_json(\"../../data/test_{}.json\".format(DATA_DATE))"
|
||||
"train_samples = InputSample.read_dataset_json(\n",
|
||||
" \"../../data/train_{}.json\".format(DATA_DATE)\n",
|
||||
")\n",
|
||||
"test_samples = InputSample.read_dataset_json(\n",
|
||||
" \"../../data/test_{}.json\".format(DATA_DATE)\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -67,7 +71,11 @@
|
|||
"outputs": [],
|
||||
"source": [
|
||||
"train_tagged = [sample for sample in train_samples if len(sample.spans) > 0]\n",
|
||||
"print(\"Kept {} train samples after removal of non-tagged samples\".format(len(train_tagged)))\n",
|
||||
"print(\n",
|
||||
" \"Kept {} train samples after removal of non-tagged samples\".format(\n",
|
||||
" len(train_tagged)\n",
|
||||
" )\n",
|
||||
")\n",
|
||||
"train_data = InputSample.create_conll_dataset(train_tagged)\n",
|
||||
"\n",
|
||||
"test_data = InputSample.create_conll_dataset(test_samples)\n",
|
||||
|
@ -81,8 +89,12 @@
|
|||
"outputs": [],
|
||||
"source": [
|
||||
"# Turn every sentence into a list of lists (list of tokens + pos + label)\n",
|
||||
"test_sents=test_data.groupby('sentence')[['text','pos','label']].apply(lambda x: x.values.tolist())\n",
|
||||
"train_sents=train_data.groupby('sentence')[['text','pos','label']].apply(lambda x: x.values.tolist())\n"
|
||||
"test_sents = test_data.groupby(\"sentence\")[[\"text\", \"pos\", \"label\"]].apply(\n",
|
||||
" lambda x: x.values.tolist()\n",
|
||||
")\n",
|
||||
"train_sents = train_data.groupby(\"sentence\")[[\"text\", \"pos\", \"label\"]].apply(\n",
|
||||
" lambda x: x.values.tolist()\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -123,11 +135,7 @@
|
|||
"source": [
|
||||
"%%time\n",
|
||||
"crf = sklearn_crfsuite.CRF(\n",
|
||||
" algorithm='lbfgs',\n",
|
||||
" c1=0.1,\n",
|
||||
" c2=0.1,\n",
|
||||
" max_iterations=100,\n",
|
||||
" all_possible_transitions=True\n",
|
||||
" algorithm=\"lbfgs\", c1=0.1, c2=0.1, max_iterations=100, all_possible_transitions=True\n",
|
||||
")\n",
|
||||
"crf.fit(X_train, y_train)"
|
||||
]
|
||||
|
@ -150,9 +158,8 @@
|
|||
"\n",
|
||||
"os.makedirs(\"../../models/\", exist_ok=True)\n",
|
||||
"\n",
|
||||
"with open(\"../../models/crf.pickle\",'wb') as f:\n",
|
||||
" pickle.dump(crf, f,protocol=pickle.HIGHEST_PROTOCOL)\n",
|
||||
" "
|
||||
"with open(\"../../models/crf.pickle\", \"wb\") as f:\n",
|
||||
" pickle.dump(crf, f, protocol=pickle.HIGHEST_PROTOCOL)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -168,7 +175,7 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open(\"../../models/crf.pickle\", 'rb') as f:\n",
|
||||
"with open(\"../../models/crf.pickle\", \"rb\") as f:\n",
|
||||
" crf = pickle.load(f)"
|
||||
]
|
||||
},
|
||||
|
@ -186,7 +193,7 @@
|
|||
"outputs": [],
|
||||
"source": [
|
||||
"labels = list(crf.classes_)\n",
|
||||
"labels.remove('O')\n",
|
||||
"labels.remove(\"O\")\n",
|
||||
"labels"
|
||||
]
|
||||
},
|
||||
|
@ -197,8 +204,7 @@
|
|||
"outputs": [],
|
||||
"source": [
|
||||
"y_pred = crf.predict(X_test)\n",
|
||||
"metrics.flat_f1_score(y_test, y_pred,\n",
|
||||
" average='weighted', labels=labels)"
|
||||
"metrics.flat_f1_score(y_test, y_pred, average=\"weighted\", labels=labels)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -219,13 +225,10 @@
|
|||
"outputs": [],
|
||||
"source": [
|
||||
"# group B and I results\n",
|
||||
"sorted_labels = sorted(\n",
|
||||
" labels,\n",
|
||||
" key=lambda name: (name[1:], name[0])\n",
|
||||
")\n",
|
||||
"print(metrics.flat_classification_report(\n",
|
||||
" y_test, y_pred, labels=sorted_labels, digits=3\n",
|
||||
"))"
|
||||
"sorted_labels = sorted(labels, key=lambda name: (name[1:], name[0]))\n",
|
||||
"print(\n",
|
||||
" metrics.flat_classification_report(y_test, y_pred, labels=sorted_labels, digits=3)\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
|
@@ -0,0 +1,159 @@
---
jupyter:
  jupytext:
    text_representation:
      extension: .md
      format_name: markdown
      format_version: '1.3'
    jupytext_version: 1.13.6
  kernelspec:
    display_name: presidio
    language: python
    name: presidio
---

Evaluate ...XYZ... models for person names, orgs and locations using the Presidio Evaluator framework

```python
from pathlib import Path
from copy import deepcopy
from pprint import pprint
from collections import Counter

from presidio_evaluator import InputSample
from presidio_evaluator.evaluation import Evaluator, ModelError
from presidio_evaluator.experiment_tracking import get_experiment_tracker
from presidio_evaluator.models import ...Model...

import pandas as pd

pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)
pd.set_option('display.max_colwidth', None)

%reload_ext autoreload
%autoreload 2
```

Select data for evaluation:

```python
dataset_name = "synth_dataset_v2.json"
dataset = InputSample.read_dataset_json(Path(Path.cwd().parent.parent, "data", dataset_name))
print(len(dataset))
```

```python
entity_counter = Counter()
for sample in dataset:
    for tag in sample.tags:
        entity_counter[tag] += 1
```

```python
print("Count per entity:")
pprint(entity_counter.most_common())

print("\nExample sentence:")
print(dataset[1])

print("\nMin and max number of tokens in dataset:")
print(f"Min: {min([len(sample.tokens) for sample in dataset])}, " \
      f"Max: {max([len(sample.tokens) for sample in dataset])}")

print("\nMin and max sentence length in dataset:")
print(f"Min: {min([len(sample.full_text) for sample in dataset])}, " \
      f"Max: {max([len(sample.full_text) for sample in dataset])}")
```

Select models for evaluation:

```python
models = [...MODEL NAMES...]
```

Run evaluation on all models:

```python
for model_name in models:
    print("-----------------------------------")
    print(f"Evaluating model {model_name}")
    experiment = get_experiment_tracker()

    model = Model(..., entities_to_keep=['PERSON', 'GPE', 'ORG', 'NORP'])
    evaluator = Evaluator(model=model)
    evaluation_results = evaluator.evaluate_all(deepcopy(dataset))
    results = evaluator.calculate_score(evaluation_results)

    # update params tracking
    params = {"dataset_name": dataset_name, "model_name": model_name}
    params.update(model.to_log())
    experiment.log_parameters(params)
    experiment.log_dataset_hash(dataset)
    experiment.log_metrics(results.to_log())
    entities, confmatrix = results.to_confusion_matrix()
    experiment.log_confusion_matrix(matrix=confmatrix, labels=entities)

    print("Confusion matrix:")
    print(pd.DataFrame(confmatrix, columns=entities, index=entities))

    print("Precision and recall")
    print(results)

    # end experiment
    experiment.end()
```

### Results analysis

```python
sent = 'I am taiwanese but I live in Cambodia.'
#sent = input("Enter sentence: ")
model.predict(InputSample(full_text=sent))
```

### Error Analysis

```python
errors = results.model_errors
```

#### False positives

1. Most false positive tokens:

```python
ModelError.most_common_fp_tokens(errors)
```

```python
fps_df = ModelError.get_fps_dataframe(errors, entity=["GPE"])
fps_df[["full_text", "token", "prediction"]]
```

2. False negative examples

```python
errors = results.model_errors
ModelError.most_common_fn_tokens(errors, n=50, entity=["PERSON"])
```

More FN analysis

```python
fns_df = ModelError.get_fns_dataframe(errors, entity=['GPE'])
```

```python
fns_df[["full_text", "token", "annotation", "prediction"]]
```

```python
print("All errors:\n")
[print(error,"\n") for error in errors]
```

```python

```

@@ -1,4 +1,4 @@
from .span_to_tag import span_to_tag, tokenize
from .span_to_tag import span_to_tag, tokenize, io_to_scheme
from .data_objects import Span, InputSample
from .validation import (
    split_dataset,
@@ -7,11 +7,16 @@ from .validation import (
    group_by_template,
    save_to_json,
)
from .experiment_tracker import ExperimentTracker

from dotenv import load_dotenv

load_dotenv()  # take environment variables from .env.


__all__ = [
    "span_to_tag",
    "tokenize",
    "io_to_scheme",
    "Span",
    "InputSample",
    "split_dataset",
@@ -19,5 +24,4 @@ __all__ = [
    "get_samples_by_pattern",
    "group_by_template",
    "save_to_json",
    "ExperimentTracker",
]

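The new `load_dotenv()` call means tracker credentials and switches can be kept in a `.env` file at the repository root instead of the shell environment. An illustration only; the key name below is a hypothetical placeholder, not a variable this package is known to read:

```python
import os

import presidio_evaluator  # importing the package runs load_dotenv()

# With a local .env containing e.g. TRACKER_API_KEY=abc123 (hypothetical key):
print(os.getenv("TRACKER_API_KEY"))
```
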

@@ -231,9 +231,9 @@ class InputSample(object):
        labels = span_to_tag(
            scheme=scheme,
            text=self.full_text,
            tag=tags,
            start=start_indices,
            end=end_indices,
            tags=tags,
            starts=start_indices,
            ends=end_indices,
            tokens=tokens,
        )
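With the rename from `tag`/`start`/`end` to `tags`/`starts`/`ends`, external callers of `span_to_tag` must move to the new keywords. A hedged sketch of a call under the new signature; the `tokenize` helper and the `"IO"` scheme value are assumptions based on the surrounding code:

```python
from presidio_evaluator import span_to_tag, tokenize

text = "My name is Dan"
tokens = tokenize(text)  # assumed to return spaCy tokens for `text`

# One PERSON span covering "Dan" (characters 11 to 14)
io_tags = span_to_tag(
    scheme="IO",
    text=text,
    tags=["PERSON"],
    starts=[11],
    ends=[14],
    tokens=tokens,
)
print(io_tags)  # expected to resemble ['O', 'O', 'O', 'PERSON']
```
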

@@ -526,8 +526,8 @@ class InputSample(object):
            for tag in self.tags
        ]
        for span in self.spans:
            if span.entity_value in PRESIDIO_SPACY_ENTITIES:
                span.entity_value = PRESIDIO_SPACY_ENTITIES[span.entity_value]
            if span.entity_value in dictionary:
                span.entity_value = dictionary[span.entity_value]
            elif ignore_unknown:
                span.entity_value = "O"

@@ -1,15 +1,23 @@
import json
from collections import Counter
from typing import List, Optional
from typing import List, Optional, Dict, Tuple

from presidio_evaluator.evaluation import ModelError


class EvaluationResult(object):
class EvaluationResult:
    def __init__(
        self,
        results: Counter,
        model_errors: Optional[List[ModelError]] = None,
        text: str = None,
        text: Optional[str] = None,
        pii_recall: Optional[float] = None,
        pii_precision: Optional[float] = None,
        pii_f: Optional[float] = None,
        n: Optional[int] = None,
        entity_recall_dict: Optional[Dict[str, float]] = None,
        entity_precision_dict: Optional[Dict[str, float]] = None,
        n_dict: Optional[Dict[str, int]] = None,
    ):
        """
        Holds the output of a comparison between ground truth and predicted
@@ -17,45 +25,59 @@ class EvaluationResult(object):
        with structure {(actual, predicted) : count}
        :param model_errors: List of specific model errors for further inspection
        :param text: sample's full text (if used for one sample)
        :param pii_recall: Recall for all entities (PII or not)
        :param pii_precision: Precision for all entities (PII or not)
        :param pii_f: F measure for all entities (PII or not)
        :param n: Number of total entity tokens
        :param entity_recall_dict: Recall per entity
        :param entity_precision_dict: Precision per entity
        :param n_dict: Number of tokens per entity
        """

        self.results = results
        self.model_errors = model_errors
        self.text = text

        self.pii_recall = None
        self.pii_precision = None
        self.pii_f = None
        self.entity_recall_dict = None
        self.entity_precision_dict = None
        self.n = None
        self.pii_recall = pii_recall
        self.pii_precision = pii_precision
        self.pii_f = pii_f
        self.n = n
        self.entity_recall_dict = entity_recall_dict
        self.entity_precision_dict = entity_precision_dict
        self.n_dict = n_dict

    def print(self):
    def __str__(self):
        return_str = ""
        if not self.entity_precision_dict or not self.entity_recall_dict:
            return json.dumps(self.results)

        recall_dict = dict(sorted(self.entity_recall_dict.items()))
        precision_dict = dict(sorted(self.entity_precision_dict.items()))
        entities = self.n_dict.keys()

        recall_dict["PII"] = self.pii_recall
        precision_dict["PII"] = self.pii_precision

        entities = recall_dict.keys()
        recall = recall_dict.values()
        precision = precision_dict.values()
        n = self.n.values()

        row_format = "{:>30}{:>30.2%}{:>30.2%}{:>30}"
        header_format = "{:>30}" * 4
        print(
        row_format = "{:>20}{:>20.2%}{:>20.2%}{:>20}"
        header_format = "{:>20}" * 4
        return_str += str(
            header_format.format(
                *("Entity", "Precision", "Recall", "Number of samples")
            )
        )
        for entity, precision, recall, n in zip(entities, precision, recall, n):
            print(row_format.format(entity, precision, recall, n))
        for entity in entities:
            return_str += "\n" + row_format.format(
                entity,
                self.entity_precision_dict[entity],
                self.entity_recall_dict[entity],
                self.n_dict[entity],
            )

        print("PII F measure: {}".format(self.pii_f))
        # add PII values
        return_str += "\n" + row_format.format(
            "PII",
            self.pii_precision,
            self.pii_recall,
            self.n,
        )

        return_str += f"\nPII F measure: {self.pii_f:.2%}"
        return return_str

    def __repr__(self):
        return f"stats={self.results}"
@@ -73,11 +95,17 @@
        )
        if self.entity_recall_dict:
            metrics_dict.update(
                {
                    f"{ent}_recall": v
                    for (ent, v) in self.entity_recall_dict.items()
                }
                {f"{ent}_recall": v for (ent, v) in self.entity_recall_dict.items()}
            )
        if self.n:
            metrics_dict.update(self.n)
            metrics_dict.update(self.n_dict)
        return metrics_dict

    def to_confusion_matrix(self) -> Tuple[List[str], List[List[int]]]:
        entities = sorted(list(set(self.n_dict.keys()).union("O")))
        confusion_matrix = [[0] * len(entities) for _ in range(len(entities))]
        for i, actual in enumerate(entities):
            for j, predicted in enumerate(entities):
                confusion_matrix[i][j] = self.results[(actual, predicted)]

        return entities, confusion_matrix
|
||||
|
|
|
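For reference, the constructor-driven results object and its new helpers can be exercised directly; a minimal sketch with made-up counts and metric values:

    from collections import Counter
    from presidio_evaluator.evaluation import EvaluationResult

    results = Counter({("O", "O"): 30, ("PERSON", "PERSON"): 2, ("PERSON", "O"): 1})
    res = EvaluationResult(
        results=results,
        pii_precision=1.0,
        pii_recall=2 / 3,
        pii_f=0.8,
        entity_precision_dict={"PERSON": 1.0},
        entity_recall_dict={"PERSON": 2 / 3},
        n_dict={"PERSON": 3},
        n=3,
    )
    print(res)  # __str__ renders the per-entity precision/recall table
    labels, matrix = res.to_confusion_matrix()  # (["O", "PERSON"], [[30, 0], [1, 2]])
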
@@ -6,7 +6,7 @@ from tqdm import tqdm

 from presidio_evaluator import InputSample
 from presidio_evaluator.evaluation import EvaluationResult, ModelError
-from presidio_evaluator.models import BaseModel, PresidioAnalyzerWrapper
+from presidio_evaluator.models import BaseModel


 class Evaluator:
@@ -81,34 +81,34 @@ class Evaluator:
             if prediction[i] == "O":
                 mistakes.append(
                     ModelError(
-                        "FN",
-                        new_annotation[i],
-                        prediction[i],
-                        tokens[i],
-                        input_sample.full_text,
-                        input_sample.metadata,
+                        error_type="FN",
+                        annotation=new_annotation[i],
+                        prediction=prediction[i],
+                        token=tokens[i],
+                        full_text=input_sample.full_text,
+                        metadata=input_sample.metadata,
                     )
                 )
             elif new_annotation[i] == "O":
                 mistakes.append(
                     ModelError(
-                        "FP",
-                        new_annotation[i],
-                        prediction[i],
-                        tokens[i],
-                        input_sample.full_text,
-                        input_sample.metadata,
+                        error_type="FP",
+                        annotation=new_annotation[i],
+                        prediction=prediction[i],
+                        token=tokens[i],
+                        full_text=input_sample.full_text,
+                        metadata=input_sample.metadata,
                     )
                 )
             else:
                 mistakes.append(
                     ModelError(
-                        "Wrong entity",
-                        new_annotation[i],
-                        prediction[i],
-                        tokens[i],
-                        input_sample.full_text,
-                        input_sample.metadata,
+                        error_type="Wrong entity",
+                        annotation=new_annotation[i],
+                        prediction=prediction[i],
+                        token=tokens[i],
+                        full_text=input_sample.full_text,
+                        metadata=input_sample.metadata,
                    )
                 )

@@ -117,6 +117,8 @@ class Evaluator:
     def _adjust_per_entities(self, tags):
         if self.entities_to_keep:
             return [tag if tag in self.entities_to_keep else "O" for tag in tags]
         else:
             return tags

     @staticmethod
     def _to_io(tags):
@@ -140,8 +142,22 @@ class Evaluator:

     def evaluate_all(self, dataset: List[InputSample]) -> List[EvaluationResult]:
         evaluation_results = []
-        for sample in tqdm(dataset, desc="Evaluating {}".format(self.__class__)):
+        if self.model.entity_mapping:
+            print(f"Mapping entity values using this dictionary: {self.model.entity_mapping}")
+        for sample in tqdm(dataset, desc=f"Evaluating {self.model.__class__}"):
+
+            # Align tag values to the ones expected by the model
+            self.model.align_entity_types(sample)
+
+            # Predict
+            prediction = self.model.predict(sample)
+
+            # Remove entities not requested
+            prediction = self.model.filter_tags_in_supported_entities(prediction)
+
+            # Switch to requested labeling scheme (IO/BIO/BILUO)
+            prediction = self.model.to_scheme(prediction)
+
+            evaluation_result = self.evaluate_sample(
+                sample=sample, prediction=prediction
+            )
@@ -287,13 +303,17 @@ class Evaluator:
             if res.model_errors:
                 errors.extend(res.model_errors)

-        evaluation_result = EvaluationResult(results=all_results, model_errors=errors)
-        evaluation_result.pii_precision = pii_precision
-        evaluation_result.pii_recall = pii_recall
-        evaluation_result.entity_recall_dict = entity_recall
-        evaluation_result.entity_precision_dict = entity_precision
-        evaluation_result.pii_f = pii_f_beta
-        evaluation_result.n = n
+        evaluation_result = EvaluationResult(
+            results=all_results,
+            model_errors=errors,
+            pii_precision=pii_precision,
+            pii_recall=pii_recall,
+            entity_recall_dict=entity_recall,
+            entity_precision_dict=entity_precision,
+            n_dict=n,
+            pii_f=pii_f_beta,
+            n=sum(n.values()),
+        )

         return evaluation_result

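End to end, the reworked evaluate_all (align gold tags → predict → filter unsupported entities → re-apply the labeling scheme) is driven as below; a hedged sketch that assumes SpacyModel's model_name argument and an already-loaded dataset of InputSample objects:

    from presidio_evaluator.evaluation import Evaluator
    from presidio_evaluator.models import SpacyModel

    model = SpacyModel(model_name="en_core_web_sm", entities_to_keep=["PERSON"])
    evaluator = Evaluator(model=model)
    evaluation_results = evaluator.evaluate_all(dataset)  # dataset: List[InputSample]
    scores = evaluator.calculate_score(evaluation_results)
    print(scores)  # formatted by EvaluationResult.__str__
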
@@ -0,0 +1,23 @@
+import os
+
+from .experiment_tracker import ExperimentTracker
+
+try:
+    from comet_ml import Experiment
+except ImportError:
+    Experiment = None
+
+
+def get_experiment_tracker():
+    framework = os.environ.get("tracking_framework", None)
+    if not framework or not Experiment:
+        return ExperimentTracker()
+    elif framework == "comet":
+        return Experiment(
+            api_key=os.environ.get("API_KEY"),
+            project_name=os.environ.get("PROJECT_NAME"),
+            workspace=os.environ.get("WORKSPACE"),
+        )
+
+
+__all__ = ["ExperimentTracker", "get_experiment_tracker"]

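The selector only returns a Comet experiment when the tracking_framework environment variable is set to "comet" and comet_ml is importable; otherwise it silently falls back to the local tracker. A small sketch (values are illustrative):

    import os

    os.environ["tracking_framework"] = "comet"  # leave unset to get ExperimentTracker
    os.environ["API_KEY"] = "<comet-api-key>"   # placeholder

    experiment = get_experiment_tracker()
    experiment.log_dataset_info("synth dataset v2")  # same call on either tracker
    experiment.end()
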
@@ -1,6 +1,6 @@
-import hashlib
 import json
-from typing import Dict
+import time
+from typing import Dict, List


 class ExperimentTracker:
@@ -8,7 +8,8 @@ class ExperimentTracker:
         self.parameters = None
         self.metrics = None
         self.dataset_info = None
-        self.dataset_hash = None
+        self.confusion_matrix = None
+        self.labels = None

     def log_parameter(self, key: str, value: object):
         self.parameters[key] = value
@@ -24,17 +25,29 @@ class ExperimentTracker:
         for k, v in metrics.items():
             self.log_metric(k, v)

-    def log_dataset_hash(self, dataset_hash):
-        self.dataset_hash = dataset_hash
+    def log_dataset_hash(self, data: str):
+        pass

-    def log_dataset_info(self, dataset_info):
-        self.dataset_info = dataset_info
+    def log_dataset_info(self, name: str):
+        self.dataset_info = name

     def __str__(self):
         return json.dumps(self.__dict__)

+    def log_confusion_matrix(
+        self,
+        matrix: List[List[int]],
+        labels: List[str],
+    ):
+        self.confusion_matrix = matrix
+        self.labels = labels
+
     def start(self):
         pass

     def end(self):
-        pass
+        datetime_val = time.strftime("%Y%m%d-%H%M%S")
+        filename = f"experiment_{datetime_val}.json"
+        print(f"saving experiment data to {filename}")
+        with open(filename, "w") as json_file:
+            json.dump(self.__dict__, json_file)

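to_confusion_matrix and log_confusion_matrix are built to plug into each other; a sketch, where scores is the EvaluationResult returned by Evaluator.calculate_score:

    tracker = ExperimentTracker()
    tracker.start()
    entities, confmatrix = scores.to_confusion_matrix()
    tracker.log_confusion_matrix(matrix=confmatrix, labels=entities)
    tracker.end()  # serializes the tracker state to experiment_<timestamp>.json
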
@@ -1,14 +1,15 @@
 from abc import ABC, abstractmethod
-from typing import List, Dict
+from typing import List, Dict, Optional

-from presidio_evaluator import InputSample
+from presidio_evaluator import InputSample, io_to_scheme


 class BaseModel(ABC):
     def __init__(
         self,
-        labeling_scheme: str = "BILUO",
+        labeling_scheme: str = "BIO",
         entities_to_keep: List[str] = None,
+        entity_mapping: Optional[Dict[str, str]] = None,
         verbose: bool = False,
     ):

@@ -18,12 +19,16 @@ class BaseModel(ABC):
         entities are ignored. If None, none are filtered
         :param labeling_scheme: Used to translate (if needed)
         the prediction to a specific scheme (IO, BIO/IOB, BILUO)
+        :param entity_mapping: Dictionary for mapping this model's input and output with the expected.
+        Keys should be the input entity types (from the input dataset),
+        values should be the model's supported entity types.
         :param verbose: Whether to print more debug info

         """
         self.entities = entities_to_keep
         self.labeling_scheme = labeling_scheme
+        self.entity_mapping = entity_mapping
         self.verbose = verbose

     @abstractmethod
@@ -31,11 +36,64 @@ class BaseModel(ABC):
         """
         Abstract. Returns the predicted tokens/spans from the evaluated model
         :param sample: Sample to be evaluated
-        :return: if self.use spans: list of spans
-        if not self.use_spans: tags in self.labeling_scheme format
+        :return: List of tags in self.labeling_scheme format
         """
         pass

+    def align_entity_types(self, sample: InputSample) -> None:
+        """
+        Translates the sample's tags to the ones requested by the model
+        :param sample: Input sample
+        :return: None
+        """
+        if self.entity_mapping:
+            sample.translate_input_sample_tags(dictionary=self.entity_mapping)
+
+    def align_prediction_types(self, tags: List[str]) -> List[str]:
+        """
+        Turns the model's output from the model tags to the input tags.
+        :param tags: List of tags (entity names in IO or "O")
+        :return: New tags
+        """
+        if not self.entity_mapping:
+            return tags
+
+        inverse_mapping = {v: k for k, v in self.entity_mapping.items()}
+        new_tags = [
+            InputSample.translate_tag(
+                tag, dictionary=inverse_mapping, ignore_unknown=True
+            )
+            for tag in tags
+        ]
+        return new_tags
+
+    def filter_tags_in_supported_entities(self, tags: List[str]) -> List[str]:
+        """
+        Replaces tags of unwanted entities with O.
+        :param tags: List of tags
+        :return: List of tags where tags not in self.entities are considered "O"
+        """
+        if not self.entities:
+            return tags
+        return [tag if self._tag_in_entities(tag) else "O" for tag in tags]
+
+    def to_scheme(self, tags: List[str]):
+        """
+        Translates IO tags to BIO/BILUO based on the input labeling_scheme
+        :param tags: Current tags in IO
+        :return: Tags in labeling scheme
+        """
+
+        io_tags = [self._to_io(tag) for tag in tags]
+
+        return io_to_scheme(io_tags=io_tags, scheme=self.labeling_scheme)
+
+    @staticmethod
+    def _to_io(tag):
+        if "-" in tag:
+            return tag[2:]
+        return tag
+
     def to_log(self) -> Dict:
         """
         Returns a dictionary of parameters for logging purposes.
@@ -45,3 +103,15 @@ class BaseModel(ABC):
             "labeling_scheme": self.labeling_scheme,
             "entities_to_keep": self.entities,
         }

+    def _tag_in_entities(self, tag: str):
+        if not self.entities:
+            return True
+
+        if tag == "O":
+            return True
+
+        if tag[1] == "-":  # BIO/BILUO
+            return tag[2:] in self.entities
+        else:  # IO
+            return tag in self.entities

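With the new plumbing in BaseModel, concrete models only need to implement predict and return IO tags; mapping, filtering and scheme conversion are inherited. A toy subclass, purely illustrative:

    from typing import List

    from presidio_evaluator import InputSample
    from presidio_evaluator.models import BaseModel


    class UppercaseOrgDetector(BaseModel):
        """Toy model: tag fully-uppercase tokens as ORG, in IO format."""

        def predict(self, sample: InputSample) -> List[str]:
            return ["ORG" if t.text.isupper() else "O" for t in sample.tokens]


    toy = UppercaseOrgDetector(labeling_scheme="BILUO", entities_to_keep=["ORG"])
    # Evaluator.evaluate_all will then call filter_tags_in_supported_entities
    # and to_scheme on this model's output.
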
@@ -1,5 +1,5 @@
 import pickle
-from typing import List
+from typing import List, Dict

 from presidio_evaluator import InputSample
 from presidio_evaluator.models import BaseModel
@@ -11,10 +11,12 @@ class CRFModel(BaseModel):
         model_pickle_path: str = "../models/crf.pickle",
         entities_to_keep: List[str] = None,
         verbose: bool = False,
+        entity_mapping: Dict[str, str] = None,
     ):
         super().__init__(
             entities_to_keep=entities_to_keep,
             verbose=verbose,
+            entity_mapping=entity_mapping,
         )

         if model_pickle_path is None:
@@ -26,12 +28,8 @@ class CRFModel(BaseModel):
     def predict(self, sample: InputSample) -> List[str]:
         tags = CRFModel.crf_predict(sample, self.model)

-        if self.entities:
-            tags = [tag for tag in tags if tag in self.entities]
-
         if len(tags) != len(sample.tokens):
             print("mismatch between previous tokens and new tokens")
-        # translated_tags = sample.rename_from_spacy_tags(tags)
         return tags

     @staticmethod

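Note what the deletion in CRFModel.predict implies: per-model filtering moved into the shared pipeline, and the semantics changed for the better. The removed list comprehension dropped filtered tags entirely (shortening the list and triggering the token-mismatch warning), whereas BaseModel.filter_tags_in_supported_entities keeps one tag per token:

    # removed from CRFModel.predict: shrinks the list when a tag is filtered out
    tags = [tag for tag in tags if tag in self.entities]

    # shared replacement in BaseModel: same length, unwanted entities become "O"
    tags = [tag if self._tag_in_entities(tag) else "O" for tag in tags]
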
@@ -1,7 +1,9 @@
-from typing import List, Optional, Dict
+from typing import List, Dict

 import spacy

+from presidio_evaluator.data_objects import PRESIDIO_SPACY_ENTITIES
+
 try:
     from flair.data import Sentence
     from flair.models import SequenceTagger
@@ -9,8 +11,7 @@ try:
 except ImportError:
     print("Flair is not installed by default")

-from presidio_evaluator.data_objects import PRESIDIO_SPACY_ENTITIES
-from presidio_evaluator import InputSample
+from presidio_evaluator import InputSample, tokenize, span_to_tag
 from presidio_evaluator.models import BaseModel


@@ -21,6 +22,7 @@ class FlairModel(BaseModel):
         model_path: str = None,
         entities_to_keep: List[str] = None,
         verbose: bool = False,
+        entity_mapping: Dict[str, str] = PRESIDIO_SPACY_ENTITIES,
     ):
         """
         Evaluator for Flair models
@@ -33,6 +35,7 @@ class FlairModel(BaseModel):
         super().__init__(
             entities_to_keep=entities_to_keep,
             verbose=verbose,
+            entity_mapping=entity_mapping,
         )
         if model is None:
             if model_path is None:
@@ -48,23 +51,34 @@ class FlairModel(BaseModel):
         sentence = Sentence(text=sample.full_text, use_tokenizer=self.spacy_tokenizer)
         self.model.predict(sentence)

-        tags = self.get_tags_from_sentence(sentence)
-        if len(tags) != len(sample.tokens):
-            print("mismatch between previous tokens and new tokens")
+        ents = sentence.get_spans("ner")
+        if ents:
+            tags, texts, start, end = zip(
+                *[(ent.tag, ent.text, ent.start_pos, ent.end_pos) for ent in ents]
+            )

-        if self.entities:
-            tags = [tag for tag in tags if tag in self.entities]
+            tags = [
+                tag if tag != "PER" else "PERSON" for tag in tags
+            ]  # Flair's tag for PERSON is PER

+            # Flair tokens might not be consistent with spaCy's tokens (even when using the spacy tokenizer)
+            # Use spacy tokenization to maintain consistency with other models:
+            if not sample.tokens:
+                sample.tokens = tokenize(sample.full_text)
+
+            # Create tags (label per token) based on flair spans and spacy tokens
+            tags = span_to_tag(
+                scheme="IO",
+                text=sample.full_text,
+                starts=start,
+                ends=end,
+                tags=tags,
+                tokens=sample.tokens,
+            )
+        else:
+            tags = ["O" for _ in range(len(sample.tokens))]
+
+        if len(tags) != len(sample.tokens):
+            print("mismatch between input tokens and new tokens")
+
         return tags

     @staticmethod
     def get_tags_from_sentence(sentence):
         tags = []
         for token in sentence:
             tags.append(token.get_tag("ner").value)

         new_tags = []
         for tag in tags:
             new_tags.append("PERSON" if tag == "PER" else tag)

         return new_tags

@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import List, Optional, Dict

 from presidio_analyzer import AnalyzerEngine

@@ -15,6 +15,7 @@ class PresidioAnalyzerWrapper(BaseModel):
         labeling_scheme: str = "BIO",
         score_threshold: float = 0.4,
         language: str = "en",
+        entity_mapping: Optional[Dict[str, str]] = None,
     ):
         """
         Evaluation wrapper for the Presidio Analyzer
@@ -24,6 +25,7 @@ class PresidioAnalyzerWrapper(BaseModel):
             entities_to_keep=entities_to_keep,
             verbose=verbose,
             labeling_scheme=labeling_scheme,
+            entity_mapping=entity_mapping,
         )
         self.score_threshold = score_threshold
         self.language = language
@@ -53,19 +55,20 @@ class PresidioAnalyzerWrapper(BaseModel):
             scores.append(res.score)

         response_tags = span_to_tag(
-            scheme=self.labeling_scheme,
+            scheme="IO",
             text=sample.full_text,
-            start=starts,
-            end=ends,
+            starts=starts,
+            ends=ends,
             tokens=sample.tokens,
             scores=scores,
-            tag=tags,
+            tags=tags,
         )
         return response_tags

     # Mapping between dataset entities and Presidio entities. Key: Dataset entity, Value: Presidio entity
     presidio_entities_map = {
         "PERSON": "PERSON",
         "GPE": "LOCATION",
         "EMAIL_ADDRESS": "EMAIL_ADDRESS",
         "CREDIT_CARD": "CREDIT_CARD",
         "FIRST_NAME": "PERSON",
@@ -84,8 +87,12 @@ class PresidioAnalyzerWrapper(BaseModel):
         "IP_ADDRESS": "IP_ADDRESS",
         "ORGANIZATION": "ORG",
         "US_DRIVER_LICENSE": "US_DRIVER_LICENSE",
-        "TITLE": "O",
-        "PREFIX": "O",
+        "NRP": "NRP",
+        "TITLE": "O",  # not supported
+        "PREFIX": "O",  # not supported
+        "STREET_ADDRESS": "O",  # not supported
+        "ZIP_CODE": "O",  # not supported
+        "AGE": "O",  # not supported
         "O": "O",
     }

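presidio_entities_map is meant to be handed to the new entity_mapping parameter, translating dataset entity types (keys) into Presidio-supported ones (values) before evaluation. A hedged sketch, assuming the wrapper constructs its own AnalyzerEngine by default:

    from presidio_evaluator.models import PresidioAnalyzerWrapper

    wrapper = PresidioAnalyzerWrapper(
        entity_mapping=PresidioAnalyzerWrapper.presidio_entities_map,
        entities_to_keep=["PERSON", "LOCATION", "EMAIL_ADDRESS"],
    )
    # wrapper.predict(sample) now returns IO tags; the Evaluator re-applies
    # the requested labeling scheme via model.to_scheme.
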
@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Optional, Dict

 from presidio_analyzer import EntityRecognizer
 from presidio_analyzer.nlp_engine import NlpEngine
@@ -16,6 +16,7 @@ class PresidioRecognizerWrapper(BaseModel):
         entities_to_keep: List[str] = None,
         labeling_scheme: str = "BILUO",
         with_nlp_artifacts: bool = False,
+        entity_mapping: Optional[Dict[str, str]] = None,
         verbose: bool = False,
     ):
         """
@@ -32,6 +33,7 @@ class PresidioRecognizerWrapper(BaseModel):
             entities_to_keep=entities_to_keep,
             verbose=verbose,
             labeling_scheme=labeling_scheme,
+            entity_mapping=entity_mapping,
         )
         self.with_nlp_artifacts = with_nlp_artifacts
         self.recognizer = recognizer
@@ -62,14 +64,12 @@ class PresidioRecognizerWrapper(BaseModel):
             tags.append(res.entity_type)
             scores.append(res.score)
         response_tags = span_to_tag(
-            scheme=self.labeling_scheme,
+            scheme="IO",
             text=sample.full_text,
-            start=starts,
-            end=ends,
-            tag=tags,
+            starts=starts,
+            ends=ends,
+            tags=tags,
             tokens=sample.tokens,
             scores=scores,
         )
         if len(sample.tags) == 0:
             sample.tags = ["O" for _ in response_tags]
         return response_tags

@@ -1,9 +1,9 @@
-from typing import List
+from typing import List, Optional, Dict

 import spacy

-from presidio_evaluator.data_objects import PRESIDIO_SPACY_ENTITIES
 from presidio_evaluator import InputSample
+from presidio_evaluator.data_objects import PRESIDIO_SPACY_ENTITIES
 from presidio_evaluator.models import BaseModel


@@ -15,12 +15,13 @@ class SpacyModel(BaseModel):
         entities_to_keep: List[str] = None,
         verbose: bool = False,
         labeling_scheme: str = "BIO",
-        translate_to_spacy_entities=True,
+        entity_mapping: Optional[Dict[str, str]] = PRESIDIO_SPACY_ENTITIES,
     ):
         super().__init__(
             entities_to_keep=entities_to_keep,
             verbose=verbose,
             labeling_scheme=labeling_scheme,
+            entity_mapping=entity_mapping,
         )

         if model is None:
@@ -30,18 +31,7 @@ class SpacyModel(BaseModel):
         else:
             self.model = model

-        self.translate_to_spacy_entities = translate_to_spacy_entities
-        if self.translate_to_spacy_entities:
-            print(
-                "Translating entites using this dictionary: {}".format(
-                    PRESIDIO_SPACY_ENTITIES
-                )
-            )
-
     def predict(self, sample: InputSample) -> List[str]:
-        if self.translate_to_spacy_entities:
-            sample.translate_input_sample_tags()
-
         doc = self.model(sample.full_text)
         tags = self.get_tags_from_doc(doc)
         if len(doc) != len(sample.tokens):

@@ -1,8 +1,9 @@
-from typing import List
+from typing import List, Optional, Dict

 import spacy

-from presidio_evaluator import InputSample, Span, span_to_tag, tokenize
+from presidio_evaluator import InputSample, span_to_tag, tokenize
+from presidio_evaluator.data_objects import PRESIDIO_SPACY_ENTITIES

 try:
     import spacy_stanza
@@ -20,7 +21,7 @@ class StanzaModel(SpacyModel):
         entities_to_keep: List[str] = None,
         verbose: bool = False,
         labeling_scheme: str = "BIO",
-        translate_to_spacy_entities=True,
+        entity_mapping: Optional[Dict[str, str]] = PRESIDIO_SPACY_ENTITIES,
     ):
         if not model and not model_name:
             raise ValueError("Either model_name or model object must be supplied")
@@ -35,12 +36,10 @@ class StanzaModel(SpacyModel):
             entities_to_keep=entities_to_keep,
             verbose=verbose,
             labeling_scheme=labeling_scheme,
-            translate_to_spacy_entities=translate_to_spacy_entities,
+            entity_mapping=entity_mapping,
         )

     def predict(self, sample: InputSample) -> List[str]:
-        if self.translate_to_spacy_entities:
-            sample.translate_input_sample_tags()
-
         doc = self.model(sample.full_text)
         if doc.ents:
@@ -57,10 +56,10 @@ class StanzaModel(SpacyModel):
             tags = span_to_tag(
                 scheme=self.labeling_scheme,
                 text=sample.full_text,
-                start=start,
-                end=end,
-                tag=tags,
-                tokens=sample.tokens
+                starts=start,
+                ends=end,
+                tags=tags,
+                tokens=sample.tokens,
             )
         else:
             tags = ["O" for _ in range(len(sample.tokens))]

@@ -1,7 +1,7 @@
 from typing import List, Optional

 import spacy
-from spacy.tokens import Token
+from spacy.tokens import Doc

 loaded_spacy = {}

@@ -13,16 +13,13 @@ def get_spacy(loaded_spacy=loaded_spacy, model_version="en_core_web_sm"):
     return loaded_spacy[model_version]


-def tokenize(text, model_version="en_core_web_sm"):
+def tokenize(text, model_version="en_core_web_sm") -> Doc:
     return get_spacy(model_version=model_version)(text)


-def _get_detailed_tags(scheme, cur_tags):
+def _get_detailed_tags_for_span(scheme: str, cur_tags: List[str]) -> List[str]:
     """
-    Replaces IO tags (e.g. PERSON PERSON) with IOB/BIO/BILOU tags
-    :param cur_tags:
-    :param scheme:
-    :return:
+    Replace IO tags (e.g. O PERSON PERSON) with IOB/BIO/BILOU tags.
     """

     if all([tag == "O" for tag in cur_tags]):
@@ -31,21 +28,21 @@ def _get_detailed_tags(scheme, cur_tags):
     return_tags = []
     if len(cur_tags) == 1:
         if scheme == "BILOU":
-            return_tags.append("U-{}".format(cur_tags[0]))
+            return_tags.append(f"U-{cur_tags[0]}")
         else:
-            return_tags.append("I-{}".format(cur_tags[0]))
+            return_tags.append(f"I-{cur_tags[0]}")
     elif len(cur_tags) > 0:
         tg = cur_tags[0]
         for j in range(0, len(cur_tags)):
             if j == 0:
-                return_tags.append("B-{}".format(tg))
+                return_tags.append(f"B-{tg}")
             elif j == len(cur_tags) - 1:
                 if scheme == "BILOU":
-                    return_tags.append("L-{}".format(tg))
+                    return_tags.append(f"L-{tg}")
                 else:
-                    return_tags.append("I-{}".format(tg))
+                    return_tags.append(f"I-{tg}")
             else:
-                return_tags.append("I-{}".format(tg))
+                return_tags.append(f"I-{tg}")
     return return_tags


@@ -105,11 +102,11 @@ def _handle_overlaps(start, end, tag, score):
 def span_to_tag(
     scheme: str,
     text: str,
-    start: List[int],
-    end: List[int],
-    tag: List[str],
+    starts: List[int],
+    ends: List[int],
+    tags: List[str],
     scores: Optional[List[float]] = None,
-    tokens: Optional[List[spacy.tokens.Token]] = None,
+    tokens: Optional[Doc] = None,
 ) -> List[str]:
     """
     Turns a list of start and end values with corresponding labels, into a NER
@@ -117,18 +114,18 @@ def span_to_tag(
     :param scheme: labeling scheme, either BILOU, BIO/IOB or IO
     :param text: input text
     :param tokens: text tokenized to tokens
-    :param start: list of indices where entities in the text start
-    :param end: list of indices where entities in the text end
-    :param tag: list of entity names
+    :param starts: list of indices where entities in the text start
+    :param ends: list of indices where entities in the text end
+    :param tags: list of entity names
     :param scores: score of tag (confidence)
     :return: list of strings, representing either BILOU or BIO for the input
     """

     if not scores:
         # assume all scores are of equal weight
-        scores = [0.5 for start in start]
+        scores = [0.5 for start in starts]

-    start, end, tag, scores = _handle_overlaps(start, end, tag, scores)
+    starts, ends, tags, scores = _handle_overlaps(starts, ends, tags, scores)

     if not tokens:
         tokens = tokenize(text)
@@ -136,22 +133,22 @@ def span_to_tag(
     io_tags = []
     for token in tokens:
         found = False
-        for span_index in range(0, len(start)):
+        for span_index in range(0, len(starts)):
             span_start_in_token = (
-                token.idx <= start[span_index] <= token.idx + len(token.text)
+                token.idx <= starts[span_index] <= token.idx + len(token.text)
             )
             span_end_in_token = (
-                token.idx <= end[span_index] <= token.idx + len(token.text)
+                token.idx <= ends[span_index] <= token.idx + len(token.text)
             )
             if (
-                start[span_index] <= token.idx < end[span_index]
+                starts[span_index] <= token.idx < ends[span_index]
             ):  # token start is between start and end
-                io_tags.append(tag[span_index])
+                io_tags.append(tags[span_index])
                 found = True
             elif (
                 span_start_in_token and span_end_in_token
             ):  # span is within token boundaries (special case)
-                io_tags.append(tag[span_index])
+                io_tags.append(tags[span_index])
                 found = True
             if found:
                 break
@@ -161,8 +158,15 @@ def span_to_tag(

     if scheme == "IO":
         return io_tags
+    else:
+        return io_to_scheme(io_tags, scheme)

-    # Set tagging based on scheme (BIO/IOB or BILOU)
+
+def io_to_scheme(io_tags: List[str], scheme: str) -> List[str]:
+    """Set tagging based on scheme (BIO/IOB or BILOU).
+    :param io_tags: List of tags in IO (e.g. O O O PERSON PERSON O)
+    :param scheme: Requested scheme (IO, BILUO or BIO)
+    """
     current_tag = ""
     span_index = 0
     changes = []
@@ -172,13 +176,11 @@ def span_to_tag(
             span_index += 1
             current_tag = io_tag
     changes.append(len(io_tags))

     new_return_tags = []
     for i in range(len(changes) - 1):
         new_return_tags.extend(
-            _get_detailed_tags(
+            _get_detailed_tags_for_span(
                 scheme=scheme, cur_tags=io_tags[changes[i] : changes[i + 1]]
             )
         )

     return new_return_tags

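With the renamed keyword arguments, a worked example of the span-to-token conversion and of the newly extracted io_to_scheme helper (offsets computed for this exact string):

    from presidio_evaluator import span_to_tag, io_to_scheme

    text = "My name is Dan"
    span_to_tag(scheme="BILOU", text=text, starts=[11], ends=[14], tags=["PERSON"])
    # -> ["O", "O", "O", "U-PERSON"]: a single-token span gets U- under BILOU

    io_to_scheme(io_tags=["O", "PERSON", "PERSON"], scheme="BIO")
    # -> ["O", "B-PERSON", "I-PERSON"]
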
@@ -16,4 +16,5 @@ presidio_analyzer
 presidio_anonymizer
 requests>=2.25.1
 xmltodict>=0.12.0
+python-dotenv
 https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.2.0/en_core_web_sm-3.2.0.tar.gz#egg=en_core_web_sm

setup.py
@@ -18,7 +18,7 @@ setup(
     long_description_content_type="text/markdown",
     version=__version__,
     packages=find_packages(exclude=["tests"]),
-    url="https://www.github.com/microsoft/presidio",
+    url="https://www.github.com/microsoft/presidio-research",
     license="MIT",
     description="PII dataset generator, model evaluator for Presidio and PII data in general",
     data_files=[
@@ -46,5 +46,6 @@ setup(
         "schwifty",
         "faker",
         "sklearn_crfsuite",
+        "python-dotenv",
     ],
 )

@@ -0,0 +1,49 @@
+from collections import Counter
+
+import pytest
+
+from presidio_evaluator.evaluation import EvaluationResult, Evaluator
+from tests.mocks import (
+    MockTokensModel,
+)
+
+
+@pytest.fixture(scope="session")
+def scores():
+    results = Counter(
+        {
+            ("O", "O"): 30,
+            ("ANIMAL", "ANIMAL"): 4,
+            ("ANIMAL", "O"): 2,
+            ("O", "ANIMAL"): 1,
+            ("PERSON", "PERSON"): 2,
+        }
+    )
+    model = MockTokensModel(prediction=None)
+    evaluator = Evaluator(model=model)
+    evaluation_result = EvaluationResult(results=results)
+
+    return evaluator.calculate_score([evaluation_result])
+
+
+def test_to_confusion_matrix(scores):
+    entities, confmatrix = scores.to_confusion_matrix()
+    assert "O" in entities
+    assert "PERSON" in entities
+    assert "ANIMAL" in entities
+    assert confmatrix == [[4, 2, 0], [1, 30, 0], [0, 0, 2]]
+
+
+def test_str(scores):
+    return_str = str(scores)
+    assert (
+        "              PERSON             100.00%             100.00%                   2"
+        in return_str
+    )
+    assert (
+        "              ANIMAL              80.00%              66.67%                   6"
+        in return_str
+    )
+    assert (
+        "                 PII              85.71%              75.00%                   8" in return_str
+    )

@@ -5,8 +5,8 @@ from presidio_evaluator.evaluation import Evaluator

 try:
     from flair.models import SequenceTagger
-except:
-    ImportError("Flair is not installed by default")
+except ImportError:
+    pytest.skip("Flair not available", allow_module_level=True)


 from presidio_evaluator.models.flair_model import FlairModel
@@ -15,7 +15,6 @@ import numpy as np


 # no-unit because flair is not a dependency by default
-@pytest.mark.skip(reason="Flair not installed by default")
 def test_flair_simple():
     import os

@@ -24,9 +23,8 @@ def test_flair_simple():
         os.path.join(dir_path, "data/generated_small.json")
     )

-    model = SequenceTagger.load("ner-ontonotes-fast")  # .load('ner')
-
-    flair_model = FlairModel(model=model, entities_to_keep=["PERSON"])
+    flair_model = FlairModel(model_path="ner", entities_to_keep=["PERSON"])
     evaluator = Evaluator(model=flair_model)
     evaluation_results = evaluator.evaluate_all(input_samples)
     scores = evaluator.calculate_score(evaluation_results)

@@ -115,8 +115,8 @@ def test_span_to_bio_multiple_entities():

     tag = ["NAME", "NAME"]

-    bilou = span_to_tag(scheme=BIO_SCHEME, text=text, start=start,
-                        end=end, tag=tag)
+    bilou = span_to_tag(scheme=BIO_SCHEME, text=text, starts=start,
+                        ends=end, tags=tag)

     print(bilou)

@@ -180,7 +180,7 @@ def test_overlapping_entities_second_embedded_in_first_with_lower_score():
     expected = ['O', 'O', 'O', 'O', 'O', 'PHONE_NUMBER', 'PHONE_NUMBER',
                 'PHONE_NUMBER', 'PHONE_NUMBER',
                 'O', 'O', 'O', 'O']
-    io = span_to_tag(scheme=IO_SCHEME, text=text, start=start, end=end, tag=tag, scores=scores)
+    io = span_to_tag(scheme=IO_SCHEME, text=text, starts=start, ends=end, tags=tag, scores=scores)
     assert io == expected

@@ -193,7 +193,7 @@ def test_overlapping_entities_second_embedded_in_first_has_higher_score():
     expected = ['O', 'O', 'O', 'O', 'O', 'PHONE_NUMBER', 'US_PHONE_NUMBER',
                 'PHONE_NUMBER', 'PHONE_NUMBER',
                 'O', 'O', 'O', 'O']
-    io = span_to_tag(scheme=IO_SCHEME, text=text, start=start, end=end, tag=tag, scores=scores)
+    io = span_to_tag(scheme=IO_SCHEME, text=text, starts=start, ends=end, tags=tag, scores=scores)
     assert io == expected

@@ -206,7 +206,7 @@ def test_overlapping_entities_second_embedded_in_first_has_lower_score():
     expected = ['O', 'O', 'O', 'O', 'O', 'PHONE_NUMBER', 'PHONE_NUMBER',
                 'PHONE_NUMBER', 'PHONE_NUMBER',
                 'O', 'O', 'O', 'O']
-    io = span_to_tag(scheme=IO_SCHEME, text=text, start=start, end=end, tag=tag, scores=scores)
+    io = span_to_tag(scheme=IO_SCHEME, text=text, starts=start, ends=end, tags=tag, scores=scores)
     assert io == expected

@@ -218,7 +218,7 @@ def test_overlapping_entities_pyramid():
     tag = ["A1", "B2", "C3"]
     expected = ['O', 'O', 'O', 'O', 'O', 'A1', 'B2', 'C3', 'B2',
                 'A1', 'O', 'O', 'O', 'O']
-    io = span_to_tag(scheme=IO_SCHEME, text=text, start=start, end=end, tag=tag, scores=scores)
+    io = span_to_tag(scheme=IO_SCHEME, text=text, starts=start, ends=end, tags=tag, scores=scores)
     assert io == expected

@@ -232,6 +232,6 @@ def test_token_contains_span():
     scores = [1.0]
     tag = ["DOMAIN_NAME"]
     expected = ["O", "O", "O", "DOMAIN_NAME"]
-    io = span_to_tag(scheme=IO_SCHEME, text=text, start=start, end=end, tag=tag, scores=scores)
+    io = span_to_tag(scheme=IO_SCHEME, text=text, starts=start, ends=end, tags=tag, scores=scores)
     assert io == expected
     # fmt: on