Merge pull request #603 from microsoft/shchikke

Changes the NER notebook to use the dataset utilities
This commit is contained in:
Said Bleik 2020-06-23 11:22:34 -04:00 committed by GitHub
Parent 82ed6f58bb 5c690efc9b
Commit 5bd7127f54
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 334 additions and 46 deletions
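In short, the notebook now loads and parses the CoNLL-formatted data through the shared dataset utilities instead of notebook-local helper code. A minimal sketch of that pattern, assuming read_conll_file returns parallel lists of token and label sequences; the URL and file name are placeholders, and only the two imports are confirmed by the diff below:

from utils_nlp.dataset.ner_utils import read_conll_file
from utils_nlp.dataset.url_utils import maybe_download

# Fetch the data file only if it is not already cached locally
# (placeholder URL, not the notebook's actual data source).
data_file = maybe_download("https://example.com/wikigold.conll.txt")

# Parse token<TAB>label lines into parallel lists of sentences and label
# sequences (assumed return shape, matching how the notebook uses
# sentence_list / labels_list below).
sentence_list, labels_list = read_conll_file(data_file, sep="\t")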

View file

@@ -53,7 +53,8 @@
"from utils_nlp.dataset.ner_utils import read_conll_file\n",
"from utils_nlp.dataset.url_utils import maybe_download\n",
"from utils_nlp.models.transformers.named_entity_recognition import (\n",
" TokenClassificationProcessor, TokenClassifier)\n"
" TokenClassificationProcessor, TokenClassifier)\n",
"from utils_nlp.models.transformers.named_entity_recognition import supported_models as SUPPORTED_MODELS"
]
},
{
@@ -126,7 +127,7 @@
"\n",
"# model configurations\n",
"NUM_TRAIN_EPOCHS = 5\n",
"MODEL_NAME = \"bert-base-cased\"\n",
"MODEL_NAME = \"distilbert-base-cased\"\n",
"DO_LOWER_CASE = False\n",
"MAX_SEQ_LENGTH = 200\n",
"TRAILING_PIECE_TAG = \"X\"\n",
@@ -139,6 +140,118 @@
" NUM_TRAIN_EPOCHS = 1"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Models that can be used for token classification task"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>supported models</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>albert-base-v1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>albert-base-v2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>albert-large-v1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>albert-large-v2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>albert-xlarge-v1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>65</th>\n",
" <td>xlm-roberta-large-finetuned-conll02-spanish</td>\n",
" </tr>\n",
" <tr>\n",
" <th>66</th>\n",
" <td>xlm-roberta-large-finetuned-conll03-english</td>\n",
" </tr>\n",
" <tr>\n",
" <th>67</th>\n",
" <td>xlm-roberta-large-finetuned-conll03-german</td>\n",
" </tr>\n",
" <tr>\n",
" <th>68</th>\n",
" <td>xlnet-base-cased</td>\n",
" </tr>\n",
" <tr>\n",
" <th>69</th>\n",
" <td>xlnet-large-cased</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>70 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" supported models\n",
"0 albert-base-v1\n",
"1 albert-base-v2\n",
"2 albert-large-v1\n",
"3 albert-large-v2\n",
"4 albert-xlarge-v1\n",
".. ...\n",
"65 xlm-roberta-large-finetuned-conll02-spanish\n",
"66 xlm-roberta-large-finetuned-conll03-english\n",
"67 xlm-roberta-large-finetuned-conll03-german\n",
"68 xlnet-base-cased\n",
"69 xlnet-large-cased\n",
"\n",
"[70 rows x 1 columns]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pd.DataFrame({\"supported models\": SUPPORTED_MODELS})"
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -152,15 +265,29 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 96.0/96.0 [00:00<00:00, 4.02kKB/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Maximum sequence length is: 144\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
@@ -194,7 +321,128 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>sentence</th>\n",
" <th>labels</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>[The, origin, of, Agotes, (, or, Cagots, ), is...</td>\n",
" <td>[O, O, O, I-MISC, O, O, I-MISC, O, O, O, O]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>[-DOCSTART-]</td>\n",
" <td>[O]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>[It, provides, full, -, and, part-time, polyte...</td>\n",
" <td>[O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>[Since, she, was, the, daughter, of, the, grea...</td>\n",
" <td>[O, O, O, O, O, O, O, O, I-MISC, O, O, O, I-MI...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>[The, goals, were, two, posts, ,, with, no, cr...</td>\n",
" <td>[O, O, O, O, O, O, O, O, O, O]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>[At, one, point, ,, so, many, orders, had, bee...</td>\n",
" <td>[O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>[Left, camp, in, July, 1972, ,, and, was, deal...</td>\n",
" <td>[O, O, O, O, O, O, O, O, O, O, O, I-ORG, I-ORG...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>[She, fled, again, to, Abra, ,, where, she, wa...</td>\n",
" <td>[O, O, O, O, I-LOC, O, O, O, O, O, O]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>[As, the, younger, sibling, ,, Ben, was, const...</td>\n",
" <td>[O, O, O, O, O, I-PER, O, O, O, O, O, O, O, O,...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>[Milepost, 1, :, granite, masonry, arch, over,...</td>\n",
" <td>[O, O, O, O, O, O, O, I-LOC, I-LOC, O]</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" sentence \\\n",
"0 [The, origin, of, Agotes, (, or, Cagots, ), is... \n",
"1 [-DOCSTART-] \n",
"2 [It, provides, full, -, and, part-time, polyte... \n",
"3 [Since, she, was, the, daughter, of, the, grea... \n",
"4 [The, goals, were, two, posts, ,, with, no, cr... \n",
"5 [At, one, point, ,, so, many, orders, had, bee... \n",
"6 [Left, camp, in, July, 1972, ,, and, was, deal... \n",
"7 [She, fled, again, to, Abra, ,, where, she, wa... \n",
"8 [As, the, younger, sibling, ,, Ben, was, const... \n",
"9 [Milepost, 1, :, granite, masonry, arch, over,... \n",
"\n",
" labels \n",
"0 [O, O, O, I-MISC, O, O, I-MISC, O, O, O, O] \n",
"1 [O] \n",
"2 [O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ... \n",
"3 [O, O, O, O, O, O, O, O, I-MISC, O, O, O, I-MI... \n",
"4 [O, O, O, O, O, O, O, O, O, O] \n",
"5 [O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ... \n",
"6 [O, O, O, O, O, O, O, O, O, O, O, I-ORG, I-ORG... \n",
"7 [O, O, O, O, I-LOC, O, O, O, O, O, O] \n",
"8 [O, O, O, O, O, I-PER, O, O, O, O, O, O, O, O,... \n",
"9 [O, O, O, O, O, O, O, I-LOC, I-LOC, O] "
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Show example sentences from input\n",
"pd.DataFrame({\"sentence\": sentence_list, \"labels\": labels_list}).head(10)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
@@ -297,12 +545,13 @@
"10 Caloi I-ORG"
]
},
"execution_count": 5,
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Show example tokens from input\n",
"pd.DataFrame({\"token\": train_sentence_list[0], \"label\": train_labels_list[0]}).head(11)"
]
},
@@ -323,14 +572,62 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ea57217fe6394812af03defcdaffe4db",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, description='Downloading', max=411, style=ProgressStyle(description_width=…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "00884141779a4ddead34204d5ea01b41",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, description='Downloading', max=213450, style=ProgressStyle(description_wid…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:root:Token lists with length > 512 will be truncated\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:root:Token lists with length > 512 will be truncated\n",
"WARNING:root:Token lists with length > 512 will be truncated\n"
]
}
@@ -376,18 +673,18 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d7c19dfe849b4bb3b195e792b0ccc809",
"model_id": "7cd3a9259b5c42638e8580f9fbae27db",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, description='Downloading', max=435779157, style=ProgressStyle(description_…"
"HBox(children=(IntProgress(value=0, description='Downloading', max=263273408, style=ProgressStyle(description_…"
]
},
"metadata": {},
@@ -397,22 +694,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/media/bleik2/backup/miniconda3/envs/nlp_gpu/lib/python3.6/site-packages/torch/nn/parallel/_functions.py:61: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
" warnings.warn('Was asked to gather along dimension 0, but all '\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training time : 0.075 hrs\n"
"\n",
"Training time : 0.060 hrs\n"
]
}
],
@@ -453,14 +736,14 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Scoring: 100%|██████████| 18/18 [00:08<00:00, 2.49it/s]"
"Scoring: 100%|██████████| 35/35 [00:06<00:00, 6.14it/s]"
]
},
{
@@ -498,7 +781,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
@@ -514,7 +797,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 13,
"metadata": {},
"outputs": [
{
@@ -523,13 +806,13 @@
"text": [
" precision recall f1-score support\n",
"\n",
" MISC 0.68 0.67 0.68 221\n",
" LOC 0.79 0.85 0.82 317\n",
" ORG 0.73 0.81 0.76 274\n",
" PER 0.92 0.93 0.92 257\n",
" ORG 0.72 0.76 0.74 274\n",
" MISC 0.67 0.73 0.70 221\n",
" LOC 0.79 0.84 0.81 317\n",
" PER 0.90 0.93 0.92 257\n",
"\n",
"micro avg 0.78 0.82 0.80 1069\n",
"macro avg 0.78 0.82 0.80 1069\n",
"micro avg 0.76 0.82 0.79 1069\n",
"macro avg 0.77 0.82 0.79 1069\n",
"\n"
]
}
@@ -559,7 +842,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 14,
"metadata": {},
"outputs": [
{
@@ -567,7 +850,7 @@
"output_type": "stream",
"text": [
"WARNING:root:Token lists with length > 512 will be truncated\n",
"Scoring: 100%|██████████| 1/1 [00:00<00:00, 7.56it/s]"
"Scoring: 100%|██████████| 1/1 [00:00<00:00, 25.31it/s]"
]
},
{
@@ -646,13 +929,13 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"application/scrapbook.scrap.json+json": {
"data": 0.78,
"data": 0.77,
"encoder": "json",
"name": "precision",
"version": 1
@@ -688,7 +971,7 @@
{
"data": {
"application/scrapbook.scrap.json+json": {
"data": 0.8,
"data": 0.79,
"encoder": "json",
"name": "f1",
"version": 1
@@ -711,11 +994,18 @@
"sb.glue(\"recall\", float(report_splits[3]))\n",
"sb.glue(\"f1\", float(report_splits[4]))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "nlp_gpu",
"display_name": "Python (nlp_gpu)",
"language": "python",
"name": "nlp_gpu"
},
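Pulling the cells above together, the preprocessing-and-training flow the updated notebook follows is roughly the sketch below. The create_label_map and preprocess calls mirror the ones visible in this diff; the two constructors' arguments are assumptions, not confirmed here:

processor = TokenClassificationProcessor(model_name=MODEL_NAME, to_lower=DO_LOWER_CASE)

# Build the label-to-id map from the training labels (call shown in this diff).
label_map = TokenClassificationProcessor.create_label_map(
    label_lists=train_labels_list, trailing_piece_tag=TRAILING_PIECE_TAG
)

# Tokenize and encode; this is the renamed preprocess() call this PR introduces.
train_dataset = processor.preprocess(
    text=train_sentence_list,
    max_len=MAX_SEQ_LENGTH,
    labels=train_labels_list,
    label_map=label_map,
    trailing_piece_tag=TRAILING_PIECE_TAG,
)

# Assumed constructor arguments; not shown in this diff.
model = TokenClassifier(model_name=MODEL_NAME, num_labels=len(label_map))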

View file

@@ -39,7 +39,6 @@ def preprocess_conll(text, sep="\t"):
s_split_split = [t.split(sep) for t in s_split]
sentence_list.append([t[0] for t in s_split_split if len(t) > 1])
labels_list.append([t[1] for t in s_split_split if len(t) > 1])
if len(s_split_split) > max_seq_len:
max_seq_len = len(s_split_split)
print("Maximum sequence length is: {0}".format(max_seq_len))

View file

@@ -189,8 +189,7 @@ def load_dataset(
label_map = TokenClassificationProcessor.create_label_map(
label_lists=train_df["labels"], trailing_piece_tag=trailing_piece_tag
)
train_dataset = processor.preprocess_for_bert(
train_dataset = processor.preprocess(
text=train_df["sentence"],
max_len=max_len,
labels=train_df["labels"],
@@ -198,7 +197,7 @@
trailing_piece_tag=trailing_piece_tag,
)
test_dataset = processor.preprocess_for_bert(
test_dataset = processor.preprocess(
text=test_df["sentence"],
max_len=max_len,
labels=test_df["labels"],