diff --git a/examples/named_entity_recognition/ner_wikigold_transformer.ipynb b/examples/named_entity_recognition/ner_wikigold_transformer.ipynb index 7f10fb3..84cddf9 100644 --- a/examples/named_entity_recognition/ner_wikigold_transformer.ipynb +++ b/examples/named_entity_recognition/ner_wikigold_transformer.ipynb @@ -53,7 +53,8 @@ "from utils_nlp.dataset.ner_utils import read_conll_file\n", "from utils_nlp.dataset.url_utils import maybe_download\n", "from utils_nlp.models.transformers.named_entity_recognition import (\n", - " TokenClassificationProcessor, TokenClassifier)\n" + " TokenClassificationProcessor, TokenClassifier)\n", + "from utils_nlp.models.transformers.named_entity_recognition import supported_models as SUPPORTED_MODELS" ] }, { @@ -126,7 +127,7 @@ "\n", "# model configurations\n", "NUM_TRAIN_EPOCHS = 5\n", - "MODEL_NAME = \"bert-base-cased\"\n", + "MODEL_NAME = \"distilbert-base-cased\"\n", "DO_LOWER_CASE = False\n", "MAX_SEQ_LENGTH = 200\n", "TRAILING_PIECE_TAG = \"X\"\n", @@ -139,6 +140,118 @@ " NUM_TRAIN_EPOCHS = 1" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Models that can be used for token classification task" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
supported models
0albert-base-v1
1albert-base-v2
2albert-large-v1
3albert-large-v2
4albert-xlarge-v1
......
65xlm-roberta-large-finetuned-conll02-spanish
66xlm-roberta-large-finetuned-conll03-english
67xlm-roberta-large-finetuned-conll03-german
68xlnet-base-cased
69xlnet-large-cased
\n", + "

70 rows × 1 columns

\n", + "
" + ], + "text/plain": [ + " supported models\n", + "0 albert-base-v1\n", + "1 albert-base-v2\n", + "2 albert-large-v1\n", + "3 albert-large-v2\n", + "4 albert-xlarge-v1\n", + ".. ...\n", + "65 xlm-roberta-large-finetuned-conll02-spanish\n", + "66 xlm-roberta-large-finetuned-conll03-english\n", + "67 xlm-roberta-large-finetuned-conll03-german\n", + "68 xlnet-base-cased\n", + "69 xlnet-large-cased\n", + "\n", + "[70 rows x 1 columns]" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pd.DataFrame({\"supported models\": SUPPORTED_MODELS})" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -152,15 +265,29 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 96.0/96.0 [00:00<00:00, 4.02kKB/s]" + ] + }, { "name": "stdout", "output_type": "stream", "text": [ "Maximum sequence length is: 144\n" ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] } ], "source": [ @@ -194,7 +321,128 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
sentencelabels
0[The, origin, of, Agotes, (, or, Cagots, ), is...[O, O, O, I-MISC, O, O, I-MISC, O, O, O, O]
1[-DOCSTART-][O]
2[It, provides, full, -, and, part-time, polyte...[O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ...
3[Since, she, was, the, daughter, of, the, grea...[O, O, O, O, O, O, O, O, I-MISC, O, O, O, I-MI...
4[The, goals, were, two, posts, ,, with, no, cr...[O, O, O, O, O, O, O, O, O, O]
5[At, one, point, ,, so, many, orders, had, bee...[O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ...
6[Left, camp, in, July, 1972, ,, and, was, deal...[O, O, O, O, O, O, O, O, O, O, O, I-ORG, I-ORG...
7[She, fled, again, to, Abra, ,, where, she, wa...[O, O, O, O, I-LOC, O, O, O, O, O, O]
8[As, the, younger, sibling, ,, Ben, was, const...[O, O, O, O, O, I-PER, O, O, O, O, O, O, O, O,...
9[Milepost, 1, :, granite, masonry, arch, over,...[O, O, O, O, O, O, O, I-LOC, I-LOC, O]
\n", + "
" + ], + "text/plain": [ + " sentence \\\n", + "0 [The, origin, of, Agotes, (, or, Cagots, ), is... \n", + "1 [-DOCSTART-] \n", + "2 [It, provides, full, -, and, part-time, polyte... \n", + "3 [Since, she, was, the, daughter, of, the, grea... \n", + "4 [The, goals, were, two, posts, ,, with, no, cr... \n", + "5 [At, one, point, ,, so, many, orders, had, bee... \n", + "6 [Left, camp, in, July, 1972, ,, and, was, deal... \n", + "7 [She, fled, again, to, Abra, ,, where, she, wa... \n", + "8 [As, the, younger, sibling, ,, Ben, was, const... \n", + "9 [Milepost, 1, :, granite, masonry, arch, over,... \n", + "\n", + " labels \n", + "0 [O, O, O, I-MISC, O, O, I-MISC, O, O, O, O] \n", + "1 [O] \n", + "2 [O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ... \n", + "3 [O, O, O, O, O, O, O, O, I-MISC, O, O, O, I-MI... \n", + "4 [O, O, O, O, O, O, O, O, O, O] \n", + "5 [O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ... \n", + "6 [O, O, O, O, O, O, O, O, O, O, O, I-ORG, I-ORG... \n", + "7 [O, O, O, O, I-LOC, O, O, O, O, O, O] \n", + "8 [O, O, O, O, O, I-PER, O, O, O, O, O, O, O, O,... \n", + "9 [O, O, O, O, O, O, O, I-LOC, I-LOC, O] " + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Show example sentences from input\n", + "pd.DataFrame({\"sentence\": sentence_list, \"labels\": labels_list}).head(10)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -297,12 +545,13 @@ "10 Caloi I-ORG" ] }, - "execution_count": 5, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ + "# Show example tokens from input\n", "pd.DataFrame({\"token\": train_sentence_list[0], \"label\": train_labels_list[0]}).head(11)" ] }, @@ -323,14 +572,62 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 9, "metadata": {}, "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "ea57217fe6394812af03defcdaffe4db", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(IntProgress(value=0, description='Downloading', max=411, style=ProgressStyle(description_width=…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "00884141779a4ddead34204d5ea01b41", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(IntProgress(value=0, description='Downloading', max=213450, style=ProgressStyle(description_wid…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:root:Token lists with length > 512 will be truncated\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, { "name": "stderr", "output_type": "stream", "text": [ - "WARNING:root:Token lists with length > 512 will be truncated\n", "WARNING:root:Token lists with length > 512 will be truncated\n" ] } @@ -376,18 +673,18 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "d7c19dfe849b4bb3b195e792b0ccc809", + "model_id": "7cd3a9259b5c42638e8580f9fbae27db", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Downloading', max=435779157, style=ProgressStyle(description_…" + 
"HBox(children=(IntProgress(value=0, description='Downloading', max=263273408, style=ProgressStyle(description_…" ] }, "metadata": {}, @@ -397,22 +694,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/media/bleik2/backup/miniconda3/envs/nlp_gpu/lib/python3.6/site-packages/torch/nn/parallel/_functions.py:61: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n", - " warnings.warn('Was asked to gather along dimension 0, but all '\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Training time : 0.075 hrs\n" + "\n", + "Training time : 0.060 hrs\n" ] } ], @@ -453,14 +736,14 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "Scoring: 100%|██████████| 18/18 [00:08<00:00, 2.49it/s]" + "Scoring: 100%|██████████| 35/35 [00:06<00:00, 6.14it/s]" ] }, { @@ -498,7 +781,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -514,7 +797,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -523,13 +806,13 @@ "text": [ " precision recall f1-score support\n", "\n", - " MISC 0.68 0.67 0.68 221\n", - " LOC 0.79 0.85 0.82 317\n", - " ORG 0.73 0.81 0.76 274\n", - " PER 0.92 0.93 0.92 257\n", + " ORG 0.72 0.76 0.74 274\n", + " MISC 0.67 0.73 0.70 221\n", + " LOC 0.79 0.84 0.81 317\n", + " PER 0.90 0.93 0.92 257\n", "\n", - "micro avg 0.78 0.82 0.80 1069\n", - "macro avg 0.78 0.82 0.80 1069\n", + "micro avg 0.76 0.82 0.79 1069\n", + "macro avg 0.77 0.82 0.79 1069\n", "\n" ] } @@ -559,7 +842,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -567,7 +850,7 @@ "output_type": "stream", "text": [ "WARNING:root:Token lists with length > 512 will be truncated\n", - "Scoring: 100%|██████████| 1/1 [00:00<00:00, 7.56it/s]" + "Scoring: 100%|██████████| 1/1 [00:00<00:00, 25.31it/s]" ] }, { @@ -646,13 +929,13 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "application/scrapbook.scrap.json+json": { - "data": 0.78, + "data": 0.77, "encoder": "json", "name": "precision", "version": 1 @@ -688,7 +971,7 @@ { "data": { "application/scrapbook.scrap.json+json": { - "data": 0.8, + "data": 0.79, "encoder": "json", "name": "f1", "version": 1 @@ -711,11 +994,18 @@ "sb.glue(\"recall\", float(report_splits[3]))\n", "sb.glue(\"f1\", float(report_splits[4]))" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { "kernelspec": { - "display_name": "nlp_gpu", + "display_name": "Python (nlp_gpu)", "language": "python", "name": "nlp_gpu" }, diff --git a/utils_nlp/dataset/ner_utils.py b/utils_nlp/dataset/ner_utils.py index 347d596..3e8a5e9 100644 --- a/utils_nlp/dataset/ner_utils.py +++ b/utils_nlp/dataset/ner_utils.py @@ -39,7 +39,6 @@ def preprocess_conll(text, sep="\t"): s_split_split = [t.split(sep) for t in s_split] sentence_list.append([t[0] for t in s_split_split if len(t) > 1]) labels_list.append([t[1] for t in s_split_split if len(t) > 1]) - if len(s_split_split) > max_seq_len: max_seq_len = len(s_split_split) print("Maximum sequence length is: {0}".format(max_seq_len)) diff --git 
a/utils_nlp/dataset/wikigold.py b/utils_nlp/dataset/wikigold.py index 48a242d..ab3fee8 100644 --- a/utils_nlp/dataset/wikigold.py +++ b/utils_nlp/dataset/wikigold.py @@ -189,8 +189,7 @@ def load_dataset( label_map = TokenClassificationProcessor.create_label_map( label_lists=train_df["labels"], trailing_piece_tag=trailing_piece_tag ) - - train_dataset = processor.preprocess_for_bert( + train_dataset = processor.preprocess( text=train_df["sentence"], max_len=max_len, labels=train_df["labels"], @@ -198,7 +197,7 @@ def load_dataset( trailing_piece_tag=trailing_piece_tag, ) - test_dataset = processor.preprocess_for_bert( + test_dataset = processor.preprocess( text=test_df["sentence"], max_len=max_len, labels=test_df["labels"],