Merge pull request #603 from microsoft/shchikke
Changes NER notebook to use the dataset utilities
This commit is contained in:
Commit
5bd7127f54
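At a glance, the notebook now pulls its data handling from the shared `utils_nlp` dataset utilities instead of inline helpers. Below is a minimal sketch of the resulting flow; the download URL, separator, and `work_directory` are illustrative assumptions, and only the import lines are taken verbatim from the diff that follows.

```python
# Sketch of the updated notebook's data flow. Everything not shown in the
# diff hunks below (URL, sep, work_directory) is an assumption.
import pandas as pd

from utils_nlp.dataset.ner_utils import read_conll_file
from utils_nlp.dataset.url_utils import maybe_download
from utils_nlp.models.transformers.named_entity_recognition import (
    TokenClassificationProcessor, TokenClassifier)
from utils_nlp.models.transformers.named_entity_recognition import supported_models as SUPPORTED_MODELS

# The new cells list every transformer model usable for token classification.
print(pd.DataFrame({"supported models": SUPPORTED_MODELS}))

# maybe_download caches the file locally; read_conll_file splits the
# CoNLL-formatted text into parallel token/label lists and reports the
# maximum sequence length (the "Maximum sequence length is: 144" output below).
data_path = maybe_download("https://example.com/wikigold.conll",  # hypothetical URL
                           work_directory=".")
sentence_list, labels_list = read_conll_file(data_path, sep="\t")
```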
@@ -53,7 +53,8 @@
 "from utils_nlp.dataset.ner_utils import read_conll_file\n",
 "from utils_nlp.dataset.url_utils import maybe_download\n",
 "from utils_nlp.models.transformers.named_entity_recognition import (\n",
-" TokenClassificationProcessor, TokenClassifier)\n"
+" TokenClassificationProcessor, TokenClassifier)\n",
+"from utils_nlp.models.transformers.named_entity_recognition import supported_models as SUPPORTED_MODELS"
 ]
 },
 {
@@ -126,7 +127,7 @@
 "\n",
 "# model configurations\n",
 "NUM_TRAIN_EPOCHS = 5\n",
-"MODEL_NAME = \"bert-base-cased\"\n",
+"MODEL_NAME = \"distilbert-base-cased\"\n",
 "DO_LOWER_CASE = False\n",
 "MAX_SEQ_LENGTH = 200\n",
 "TRAILING_PIECE_TAG = \"X\"\n",
@@ -139,6 +140,118 @@
 " NUM_TRAIN_EPOCHS = 1"
 ]
 },
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"### Models that can be used for token classification task"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 4,
+"metadata": {},
+"outputs": [
+{
+"data": {
+"text/html": [
+"<div>\n",
+"<style scoped>\n",
+" .dataframe tbody tr th:only-of-type {\n",
+" vertical-align: middle;\n",
+" }\n",
+"\n",
+" .dataframe tbody tr th {\n",
+" vertical-align: top;\n",
+" }\n",
+"\n",
+" .dataframe thead th {\n",
+" text-align: right;\n",
+" }\n",
+"</style>\n",
+"<table border=\"1\" class=\"dataframe\">\n",
+" <thead>\n",
+" <tr style=\"text-align: right;\">\n",
+" <th></th>\n",
+" <th>supported models</th>\n",
+" </tr>\n",
+" </thead>\n",
+" <tbody>\n",
+" <tr>\n",
+" <th>0</th>\n",
+" <td>albert-base-v1</td>\n",
+" </tr>\n",
+" <tr>\n",
+" <th>1</th>\n",
+" <td>albert-base-v2</td>\n",
+" </tr>\n",
+" <tr>\n",
+" <th>2</th>\n",
+" <td>albert-large-v1</td>\n",
+" </tr>\n",
+" <tr>\n",
+" <th>3</th>\n",
+" <td>albert-large-v2</td>\n",
+" </tr>\n",
+" <tr>\n",
+" <th>4</th>\n",
+" <td>albert-xlarge-v1</td>\n",
+" </tr>\n",
+" <tr>\n",
+" <th>...</th>\n",
+" <td>...</td>\n",
+" </tr>\n",
+" <tr>\n",
+" <th>65</th>\n",
+" <td>xlm-roberta-large-finetuned-conll02-spanish</td>\n",
+" </tr>\n",
+" <tr>\n",
+" <th>66</th>\n",
+" <td>xlm-roberta-large-finetuned-conll03-english</td>\n",
+" </tr>\n",
+" <tr>\n",
+" <th>67</th>\n",
+" <td>xlm-roberta-large-finetuned-conll03-german</td>\n",
+" </tr>\n",
+" <tr>\n",
+" <th>68</th>\n",
+" <td>xlnet-base-cased</td>\n",
+" </tr>\n",
+" <tr>\n",
+" <th>69</th>\n",
+" <td>xlnet-large-cased</td>\n",
+" </tr>\n",
+" </tbody>\n",
+"</table>\n",
+"<p>70 rows × 1 columns</p>\n",
+"</div>"
+],
+"text/plain": [
+" supported models\n",
+"0 albert-base-v1\n",
+"1 albert-base-v2\n",
+"2 albert-large-v1\n",
+"3 albert-large-v2\n",
+"4 albert-xlarge-v1\n",
+".. ...\n",
+"65 xlm-roberta-large-finetuned-conll02-spanish\n",
+"66 xlm-roberta-large-finetuned-conll03-english\n",
+"67 xlm-roberta-large-finetuned-conll03-german\n",
+"68 xlnet-base-cased\n",
+"69 xlnet-large-cased\n",
+"\n",
+"[70 rows x 1 columns]"
+]
+},
+"execution_count": 4,
+"metadata": {},
+"output_type": "execute_result"
+}
+],
+"source": [
+"pd.DataFrame({\"supported models\": SUPPORTED_MODELS})"
+]
+},
 {
 "cell_type": "markdown",
 "metadata": {},
@@ -152,15 +265,29 @@
 },
 {
 "cell_type": "code",
-"execution_count": 4,
+"execution_count": 5,
 "metadata": {},
 "outputs": [
+{
+"name": "stderr",
+"output_type": "stream",
+"text": [
+"100%|██████████| 96.0/96.0 [00:00<00:00, 4.02kKB/s]"
+]
+},
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
 "Maximum sequence length is: 144\n"
 ]
 },
+{
+"name": "stderr",
+"output_type": "stream",
+"text": [
+"\n"
+]
+}
 ],
 "source": [
@@ -194,7 +321,128 @@
 },
 {
 "cell_type": "code",
-"execution_count": 5,
+"execution_count": 7,
+"metadata": {},
+"outputs": [
+{
+"data": {
+"text/html": [
+"<div>\n",
+"<style scoped>\n",
+" .dataframe tbody tr th:only-of-type {\n",
+" vertical-align: middle;\n",
+" }\n",
+"\n",
+" .dataframe tbody tr th {\n",
+" vertical-align: top;\n",
+" }\n",
+"\n",
+" .dataframe thead th {\n",
+" text-align: right;\n",
+" }\n",
+"</style>\n",
+"<table border=\"1\" class=\"dataframe\">\n",
+" <thead>\n",
+" <tr style=\"text-align: right;\">\n",
+" <th></th>\n",
+" <th>sentence</th>\n",
+" <th>labels</th>\n",
+" </tr>\n",
+" </thead>\n",
+" <tbody>\n",
+" <tr>\n",
+" <th>0</th>\n",
+" <td>[The, origin, of, Agotes, (, or, Cagots, ), is...</td>\n",
+" <td>[O, O, O, I-MISC, O, O, I-MISC, O, O, O, O]</td>\n",
+" </tr>\n",
+" <tr>\n",
+" <th>1</th>\n",
+" <td>[-DOCSTART-]</td>\n",
+" <td>[O]</td>\n",
+" </tr>\n",
+" <tr>\n",
+" <th>2</th>\n",
+" <td>[It, provides, full, -, and, part-time, polyte...</td>\n",
+" <td>[O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ...</td>\n",
+" </tr>\n",
+" <tr>\n",
+" <th>3</th>\n",
+" <td>[Since, she, was, the, daughter, of, the, grea...</td>\n",
+" <td>[O, O, O, O, O, O, O, O, I-MISC, O, O, O, I-MI...</td>\n",
+" </tr>\n",
+" <tr>\n",
+" <th>4</th>\n",
+" <td>[The, goals, were, two, posts, ,, with, no, cr...</td>\n",
+" <td>[O, O, O, O, O, O, O, O, O, O]</td>\n",
+" </tr>\n",
+" <tr>\n",
+" <th>5</th>\n",
+" <td>[At, one, point, ,, so, many, orders, had, bee...</td>\n",
+" <td>[O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ...</td>\n",
+" </tr>\n",
+" <tr>\n",
+" <th>6</th>\n",
+" <td>[Left, camp, in, July, 1972, ,, and, was, deal...</td>\n",
+" <td>[O, O, O, O, O, O, O, O, O, O, O, I-ORG, I-ORG...</td>\n",
+" </tr>\n",
+" <tr>\n",
+" <th>7</th>\n",
+" <td>[She, fled, again, to, Abra, ,, where, she, wa...</td>\n",
+" <td>[O, O, O, O, I-LOC, O, O, O, O, O, O]</td>\n",
+" </tr>\n",
+" <tr>\n",
+" <th>8</th>\n",
+" <td>[As, the, younger, sibling, ,, Ben, was, const...</td>\n",
+" <td>[O, O, O, O, O, I-PER, O, O, O, O, O, O, O, O,...</td>\n",
+" </tr>\n",
+" <tr>\n",
+" <th>9</th>\n",
+" <td>[Milepost, 1, :, granite, masonry, arch, over,...</td>\n",
+" <td>[O, O, O, O, O, O, O, I-LOC, I-LOC, O]</td>\n",
+" </tr>\n",
+" </tbody>\n",
+"</table>\n",
+"</div>"
+],
+"text/plain": [
+" sentence \\\n",
+"0 [The, origin, of, Agotes, (, or, Cagots, ), is... \n",
+"1 [-DOCSTART-] \n",
+"2 [It, provides, full, -, and, part-time, polyte... \n",
+"3 [Since, she, was, the, daughter, of, the, grea... \n",
+"4 [The, goals, were, two, posts, ,, with, no, cr... \n",
+"5 [At, one, point, ,, so, many, orders, had, bee... \n",
+"6 [Left, camp, in, July, 1972, ,, and, was, deal... \n",
+"7 [She, fled, again, to, Abra, ,, where, she, wa... \n",
+"8 [As, the, younger, sibling, ,, Ben, was, const... \n",
+"9 [Milepost, 1, :, granite, masonry, arch, over,... \n",
+"\n",
+" labels \n",
+"0 [O, O, O, I-MISC, O, O, I-MISC, O, O, O, O] \n",
+"1 [O] \n",
+"2 [O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ... \n",
+"3 [O, O, O, O, O, O, O, O, I-MISC, O, O, O, I-MI... \n",
+"4 [O, O, O, O, O, O, O, O, O, O] \n",
+"5 [O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ... \n",
+"6 [O, O, O, O, O, O, O, O, O, O, O, I-ORG, I-ORG... \n",
+"7 [O, O, O, O, I-LOC, O, O, O, O, O, O] \n",
+"8 [O, O, O, O, O, I-PER, O, O, O, O, O, O, O, O,... \n",
+"9 [O, O, O, O, O, O, O, I-LOC, I-LOC, O] "
+]
+},
+"execution_count": 7,
+"metadata": {},
+"output_type": "execute_result"
+}
+],
+"source": [
+"# Show example sentences from input\n",
+"pd.DataFrame({\"sentence\": sentence_list, \"labels\": labels_list}).head(10)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 8,
 "metadata": {},
 "outputs": [
 {
@@ -297,12 +545,13 @@
 "10 Caloi I-ORG"
 ]
 },
-"execution_count": 5,
+"execution_count": 8,
 "metadata": {},
 "output_type": "execute_result"
 }
 ],
 "source": [
+"# Show example tokens from input\n",
 "pd.DataFrame({\"token\": train_sentence_list[0], \"label\": train_labels_list[0]}).head(11)"
 ]
 },
@@ -323,14 +572,62 @@
 },
 {
 "cell_type": "code",
-"execution_count": 6,
+"execution_count": 9,
 "metadata": {},
 "outputs": [
+{
+"data": {
+"application/vnd.jupyter.widget-view+json": {
+"model_id": "ea57217fe6394812af03defcdaffe4db",
+"version_major": 2,
+"version_minor": 0
+},
+"text/plain": [
+"HBox(children=(IntProgress(value=0, description='Downloading', max=411, style=ProgressStyle(description_width=…"
+]
+},
+"metadata": {},
+"output_type": "display_data"
+},
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"\n"
+]
+},
+{
+"data": {
+"application/vnd.jupyter.widget-view+json": {
+"model_id": "00884141779a4ddead34204d5ea01b41",
+"version_major": 2,
+"version_minor": 0
+},
+"text/plain": [
+"HBox(children=(IntProgress(value=0, description='Downloading', max=213450, style=ProgressStyle(description_wid…"
+]
+},
+"metadata": {},
+"output_type": "display_data"
+},
+{
+"name": "stderr",
+"output_type": "stream",
+"text": [
+"WARNING:root:Token lists with length > 512 will be truncated\n"
+]
+},
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"\n"
+]
+},
 {
 "name": "stderr",
 "output_type": "stream",
 "text": [
+"WARNING:root:Token lists with length > 512 will be truncated\n",
 "WARNING:root:Token lists with length > 512 will be truncated\n"
 ]
 }
@@ -376,18 +673,18 @@
 },
 {
 "cell_type": "code",
-"execution_count": 7,
+"execution_count": 10,
 "metadata": {},
 "outputs": [
 {
 "data": {
 "application/vnd.jupyter.widget-view+json": {
-"model_id": "d7c19dfe849b4bb3b195e792b0ccc809",
+"model_id": "7cd3a9259b5c42638e8580f9fbae27db",
 "version_major": 2,
 "version_minor": 0
 },
 "text/plain": [
-"HBox(children=(IntProgress(value=0, description='Downloading', max=435779157, style=ProgressStyle(description_…"
+"HBox(children=(IntProgress(value=0, description='Downloading', max=263273408, style=ProgressStyle(description_…"
 ]
 },
 "metadata": {},
@@ -397,22 +694,8 @@
-"name": "stdout",
-"output_type": "stream",
-"text": [
-"\n"
-]
-},
-{
-"name": "stderr",
-"output_type": "stream",
-"text": [
-"/media/bleik2/backup/miniconda3/envs/nlp_gpu/lib/python3.6/site-packages/torch/nn/parallel/_functions.py:61: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
-" warnings.warn('Was asked to gather along dimension 0, but all '\n"
-]
-},
-{
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"Training time : 0.075 hrs\n"
+"\n",
+"Training time : 0.060 hrs\n"
 ]
 }
 ],
@@ -453,14 +736,14 @@
 },
 {
 "cell_type": "code",
-"execution_count": 8,
+"execution_count": 11,
 "metadata": {},
 "outputs": [
 {
 "name": "stderr",
 "output_type": "stream",
 "text": [
-"Scoring: 100%|██████████| 18/18 [00:08<00:00, 2.49it/s]"
+"Scoring: 100%|██████████| 35/35 [00:06<00:00, 6.14it/s]"
 ]
 },
 {
@@ -498,7 +781,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 9,
+"execution_count": 12,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -514,7 +797,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 10,
+"execution_count": 13,
 "metadata": {},
 "outputs": [
 {
@@ -523,13 +806,13 @@
 "text": [
 " precision recall f1-score support\n",
 "\n",
-" MISC 0.68 0.67 0.68 221\n",
-" LOC 0.79 0.85 0.82 317\n",
-" ORG 0.73 0.81 0.76 274\n",
-" PER 0.92 0.93 0.92 257\n",
+" ORG 0.72 0.76 0.74 274\n",
+" MISC 0.67 0.73 0.70 221\n",
+" LOC 0.79 0.84 0.81 317\n",
+" PER 0.90 0.93 0.92 257\n",
 "\n",
-"micro avg 0.78 0.82 0.80 1069\n",
-"macro avg 0.78 0.82 0.80 1069\n",
+"micro avg 0.76 0.82 0.79 1069\n",
+"macro avg 0.77 0.82 0.79 1069\n",
 "\n"
 ]
 }
@@ -559,7 +842,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 11,
+"execution_count": 14,
 "metadata": {},
 "outputs": [
 {
@@ -567,7 +850,7 @@
 "output_type": "stream",
 "text": [
 "WARNING:root:Token lists with length > 512 will be truncated\n",
-"Scoring: 100%|██████████| 1/1 [00:00<00:00, 7.56it/s]"
+"Scoring: 100%|██████████| 1/1 [00:00<00:00, 25.31it/s]"
 ]
 },
 {
@@ -646,13 +929,13 @@
 },
 {
 "cell_type": "code",
-"execution_count": 12,
+"execution_count": 15,
 "metadata": {},
 "outputs": [
 {
 "data": {
 "application/scrapbook.scrap.json+json": {
-"data": 0.78,
+"data": 0.77,
 "encoder": "json",
 "name": "precision",
 "version": 1
@@ -688,7 +971,7 @@
 {
 "data": {
 "application/scrapbook.scrap.json+json": {
-"data": 0.8,
+"data": 0.79,
 "encoder": "json",
 "name": "f1",
 "version": 1
@@ -711,11 +994,18 @@
 "sb.glue(\"recall\", float(report_splits[3]))\n",
 "sb.glue(\"f1\", float(report_splits[4]))"
 ]
 },
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": []
+}
 ],
 "metadata": {
 "kernelspec": {
-"display_name": "nlp_gpu",
+"display_name": "Python (nlp_gpu)",
 "language": "python",
 "name": "nlp_gpu"
 },
@@ -39,7 +39,6 @@ def preprocess_conll(text, sep="\t"):
         s_split_split = [t.split(sep) for t in s_split]
         sentence_list.append([t[0] for t in s_split_split if len(t) > 1])
         labels_list.append([t[1] for t in s_split_split if len(t) > 1])
-
         if len(s_split_split) > max_seq_len:
             max_seq_len = len(s_split_split)
     print("Maximum sequence length is: {0}".format(max_seq_len))
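For context, this hunk sits inside the loop of `preprocess_conll` that splits CoNLL text into token and label lists while tracking the longest sentence. Below is a sketch of the whole function under that assumption; only the lines shown in the hunk above are verbatim.

```python
def preprocess_conll(text, sep="\t"):
    # Assumed surrounding code; the loop body matches the hunk above.
    sentence_list = []
    labels_list = []
    max_seq_len = 0
    # Sentences in CoNLL files are assumed separated by blank lines,
    # with one "token<sep>label" pair per line.
    for s in text.strip().split("\n\n"):
        s_split = s.split("\n")
        s_split_split = [t.split(sep) for t in s_split]
        sentence_list.append([t[0] for t in s_split_split if len(t) > 1])
        labels_list.append([t[1] for t in s_split_split if len(t) > 1])
        if len(s_split_split) > max_seq_len:
            max_seq_len = len(s_split_split)
    print("Maximum sequence length is: {0}".format(max_seq_len))
    return sentence_list, labels_list
```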
@@ -189,8 +189,7 @@ def load_dataset(
     label_map = TokenClassificationProcessor.create_label_map(
         label_lists=train_df["labels"], trailing_piece_tag=trailing_piece_tag
     )
-
-    train_dataset = processor.preprocess_for_bert(
+    train_dataset = processor.preprocess(
         text=train_df["sentence"],
         max_len=max_len,
         labels=train_df["labels"],
@@ -198,7 +197,7 @@ def load_dataset(
         trailing_piece_tag=trailing_piece_tag,
     )
 
-    test_dataset = processor.preprocess_for_bert(
+    test_dataset = processor.preprocess(
         text=test_df["sentence"],
         max_len=max_len,
         labels=test_df["labels"],
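The last two hunks are the same mechanical rename: the processor's `preprocess_for_bert` method is now called `preprocess`. A hedged sketch of a call site after the change; the processor construction and constant values are assumptions, and only the keyword arguments mirror the hunks.

```python
# Assumed setup around the renamed call; argument names follow the hunks above.
processor = TokenClassificationProcessor(
    model_name="distilbert-base-cased", to_lower=False  # constructor args assumed
)
label_map = TokenClassificationProcessor.create_label_map(
    label_lists=train_df["labels"], trailing_piece_tag="X"
)
train_dataset = processor.preprocess(  # formerly preprocess_for_bert
    text=train_df["sentence"],
    max_len=200,
    labels=train_df["labels"],
    label_map=label_map,               # assumed; the hunk truncates before this
    trailing_piece_tag="X",
)
```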