Merge pull request #586 from microsoft/bleik/add-models
update utils and examples
This commit is contained in:
Коммит
e02e3b5525
|
@ -15,37 +15,6 @@
|
||||||
"# Named Entity Recognition Using Transformer Model"
|
"# Named Entity Recognition Using Transformer Model"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"## Before You Start\n",
|
|
||||||
"\n",
|
|
||||||
"The running time shown in this notebook is on a Standard_NC6 Azure Deep Learning Virtual Machine with 1 NVIDIA Tesla K80 GPU. \n",
|
|
||||||
"> **Tip**: If you want to run through the notebook quickly, you can set the **`QUICK_RUN`** flag in the cell below to **`True`** to run the notebook on a small subset of the data and a smaller number of epochs. \n",
|
|
||||||
"\n",
|
|
||||||
"The table below provides some reference running time on different machine configurations. \n",
|
|
||||||
"\n",
|
|
||||||
"|QUICK_RUN|Machine Configurations|Running time|\n",
|
|
||||||
"|:---------|:----------------------|:------------|\n",
|
|
||||||
"|True|4 **CPU**s, 14GB memory| ~ 2 minutes|\n",
|
|
||||||
"|False|4 **CPU**s, 14GB memory| ~1.5 hours|\n",
|
|
||||||
"|True|1 NVIDIA Tesla K80 GPUs, 12GB GPU memory| ~ 1 minute|\n",
|
|
||||||
"|False|1 NVIDIA Tesla K80 GPUs, 12GB GPU memory| ~ 7 minutes |\n",
|
|
||||||
"\n",
|
|
||||||
"If you run into CUDA out-of-memory error or the jupyter kernel dies constantly, try reducing the `BATCH_SIZE` and `MAX_SEQ_LENGTH`, but note that model performance will be compromised. "
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"## Set QUICK_RUN = True to run the notebook on a small subset of data and a smaller number of epochs.\n",
|
|
||||||
"QUICK_RUN = False"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
|
@ -61,29 +30,30 @@
|
||||||
"<img src=\"https://nlpbp.blob.core.windows.net/images/bert_architecture.png\">"
|
"<img src=\"https://nlpbp.blob.core.windows.net/images/bert_architecture.png\">"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"## Preparation"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 1,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"import sys\n",
|
|
||||||
"import os\n",
|
"import os\n",
|
||||||
|
"import random\n",
|
||||||
|
"import string\n",
|
||||||
|
"import sys\n",
|
||||||
|
"from tempfile import TemporaryDirectory\n",
|
||||||
|
"\n",
|
||||||
|
"import pandas as pd\n",
|
||||||
"import scrapbook as sb\n",
|
"import scrapbook as sb\n",
|
||||||
"import torch\n",
|
"import torch\n",
|
||||||
"\n",
|
|
||||||
"from tempfile import TemporaryDirectory\n",
|
|
||||||
"from utils_nlp.dataset import wikigold\n",
|
|
||||||
"from utils_nlp.common.timer import Timer\n",
|
|
||||||
"from seqeval.metrics import classification_report\n",
|
"from seqeval.metrics import classification_report\n",
|
||||||
"from utils_nlp.models.transformers.named_entity_recognition import TokenClassifier"
|
"from sklearn.model_selection import train_test_split\n",
|
||||||
|
"from utils_nlp.common.pytorch_utils import dataloader_from_dataset\n",
|
||||||
|
"from utils_nlp.common.timer import Timer\n",
|
||||||
|
"from utils_nlp.dataset import wikigold\n",
|
||||||
|
"from utils_nlp.dataset.ner_utils import read_conll_file\n",
|
||||||
|
"from utils_nlp.dataset.url_utils import maybe_download\n",
|
||||||
|
"from utils_nlp.models.transformers.named_entity_recognition import (\n",
|
||||||
|
" TokenClassificationProcessor, TokenClassifier)\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -93,9 +63,38 @@
|
||||||
"## Configuration"
|
"## Configuration"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"The running time shown in this notebook is on a Standard_NC12 Azure Virtual Machine with 2 NVIDIA Tesla K80 GPUs. \n",
|
||||||
|
"> **Tip**: If you want to run through the notebook quickly, you can set the **`QUICK_RUN`** flag in the cell below to **`True`** to run the notebook on a small subset of the data and a smaller number of epochs. \n",
|
||||||
|
"\n",
|
||||||
|
"The table below provides some reference running time on different machine configurations. \n",
|
||||||
|
"\n",
|
||||||
|
"|QUICK_RUN|Machine Configurations|Running time|\n",
|
||||||
|
"|:---------|:----------------------|:------------|\n",
|
||||||
|
"|True|4 CPUs, 14GB memory| ~ 2 minutes|\n",
|
||||||
|
"|False|4 CPUs, 14GB memory| ~1.5 hours|\n",
|
||||||
|
"|True|1 NVIDIA Tesla K80 GPUs, 12GB GPU memory| ~ 1 minute|\n",
|
||||||
|
"|False|1 NVIDIA Tesla K80 GPUs, 12GB GPU memory| ~ 7 minutes |\n",
|
||||||
|
"\n",
|
||||||
|
"If you run into CUDA out-of-memory error or the jupyter kernel dies constantly, try reducing the `BATCH_SIZE` and `MAX_SEQ_LENGTH`, but note that model performance will be compromised. "
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 2,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Set QUICK_RUN = True to run the notebook on a small subset of data and a smaller number of epochs.\n",
|
||||||
|
"QUICK_RUN = False"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"tags": [
|
"tags": [
|
||||||
"parameters"
|
"parameters"
|
||||||
|
@ -103,22 +102,17 @@
|
||||||
},
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
"# Wikigold dataset\n",
|
||||||
|
"DATA_URL = (\n",
|
||||||
|
" \"https://raw.githubusercontent.com/juand-r/entity-recognition-datasets\"\n",
|
||||||
|
" \"/master/data/wikigold/CONLL-format/data/wikigold.conll.txt\"\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
"# fraction of the dataset used for testing\n",
|
"# fraction of the dataset used for testing\n",
|
||||||
"TEST_DATA_FRACTION = 0.3\n",
|
"TEST_DATA_FRACTION = 0.3\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# sub-sampling ratio for training\n",
|
"# sub-sampling ratio\n",
|
||||||
"TRAIN_SAMPLE_RATIO = 1\n",
|
"SAMPLE_RATIO = 1\n",
|
||||||
"\n",
|
|
||||||
"# sub-sampling ratio for testing\n",
|
|
||||||
"TEST_SAMPLE_RATIO = 1\n",
|
|
||||||
"\n",
|
|
||||||
"NUM_TRAIN_EPOCHS = 5\n",
|
|
||||||
"\n",
|
|
||||||
"# update variables for quick run option\n",
|
|
||||||
"if QUICK_RUN:\n",
|
|
||||||
" TRAIN_SAMPLE_RATIO = 0.1\n",
|
|
||||||
" TEST_SAMPLE_RATIO = 0.1\n",
|
|
||||||
" NUM_TRAIN_EPOCHS = 1\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"# the data path used to save the downloaded data file\n",
|
"# the data path used to save the downloaded data file\n",
|
||||||
"DATA_PATH = TemporaryDirectory().name\n",
|
"DATA_PATH = TemporaryDirectory().name\n",
|
||||||
|
@ -131,16 +125,18 @@
|
||||||
"torch.manual_seed(RANDOM_SEED)\n",
|
"torch.manual_seed(RANDOM_SEED)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# model configurations\n",
|
"# model configurations\n",
|
||||||
|
"NUM_TRAIN_EPOCHS = 5\n",
|
||||||
"MODEL_NAME = \"bert-base-cased\"\n",
|
"MODEL_NAME = \"bert-base-cased\"\n",
|
||||||
"DO_LOWER_CASE = False\n",
|
"DO_LOWER_CASE = False\n",
|
||||||
"MAX_SEQ_LENGTH = 200\n",
|
"MAX_SEQ_LENGTH = 200\n",
|
||||||
"TRAILING_PIECE_TAG = \"X\"\n",
|
"TRAILING_PIECE_TAG = \"X\"\n",
|
||||||
"DEVICE = \"cuda\"\n",
|
"NUM_GPUS = None # uses all if available\n",
|
||||||
|
"BATCH_SIZE = 16\n",
|
||||||
"\n",
|
"\n",
|
||||||
"if torch.cuda.is_available():\n",
|
"# update variables for quick run option\n",
|
||||||
" BATCH_SIZE = 16\n",
|
"if QUICK_RUN:\n",
|
||||||
"else:\n",
|
" SAMPLE_RATIO = 0.1\n",
|
||||||
" BATCH_SIZE = 8\n"
|
" NUM_TRAIN_EPOCHS = 1"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -151,45 +147,275 @@
|
||||||
"\n",
|
"\n",
|
||||||
"The dataset used in this notebook is the [wikigold dataset](https://www.aclweb.org/anthology/W09-3302). The wikigold dataset consists of 145 mannually labelled Wikipedia articles, including 1841 sentences and 40k tokens in total. The dataset can be directly downloaded from [here](https://github.com/juand-r/entity-recognition-datasets/tree/master/data/wikigold). \n",
|
"The dataset used in this notebook is the [wikigold dataset](https://www.aclweb.org/anthology/W09-3302). The wikigold dataset consists of 145 mannually labelled Wikipedia articles, including 1841 sentences and 40k tokens in total. The dataset can be directly downloaded from [here](https://github.com/juand-r/entity-recognition-datasets/tree/master/data/wikigold). \n",
|
||||||
"\n",
|
"\n",
|
||||||
"A helper function `load_dataset` downloads the raw wikigold data, splits it into training and testing datasets (also sub-sampling if the sampling ratio is smaller than 1.0), and then process for the transformer model. Everything is done in one function call, and you can use the processed training and testing Pytorch datasets to fine tune the model and evaluate the performance of the model."
|
"In the following cell, we download the data file, parse the tokens and labels, sample a given number of sentences, and split the dataset for training and testing."
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 4,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Maximum sequence length is: 144\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"train_dataloader, test_dataloader, label_map, test_dataset = wikigold.load_dataset(\n",
|
"# download data\n",
|
||||||
" local_path=DATA_PATH,\n",
|
"file_name = DATA_URL.split(\"/\")[-1] # a name for the downloaded file\n",
|
||||||
" test_fraction=TEST_DATA_FRACTION,\n",
|
"maybe_download(DATA_URL, file_name, DATA_PATH)\n",
|
||||||
" random_seed=RANDOM_SEED,\n",
|
"data_file = os.path.join(DATA_PATH, file_name)\n",
|
||||||
" train_sample_ratio=TRAIN_SAMPLE_RATIO,\n",
|
"\n",
|
||||||
" test_sample_ratio=TEST_SAMPLE_RATIO,\n",
|
"# parse CoNll file\n",
|
||||||
" model_name=MODEL_NAME,\n",
|
"sentence_list, labels_list = read_conll_file(data_file, sep=\" \")\n",
|
||||||
" to_lower=DO_LOWER_CASE,\n",
|
"\n",
|
||||||
" cache_dir=CACHE_DIR,\n",
|
"# sub-sample (optional)\n",
|
||||||
" max_len=MAX_SEQ_LENGTH,\n",
|
"random.seed(RANDOM_SEED)\n",
|
||||||
" trailing_piece_tag=TRAILING_PIECE_TAG,\n",
|
"sample_size = int(SAMPLE_RATIO * len(sentence_list))\n",
|
||||||
" batch_size=BATCH_SIZE,\n",
|
"sentence_list, labels_list = list(\n",
|
||||||
" num_gpus=None\n",
|
" zip(*random.sample(list(zip(sentence_list, labels_list)), k=sample_size))\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"# train-test split\n",
|
||||||
|
"train_sentence_list, test_sentence_list, train_labels_list, test_labels_list = train_test_split(\n",
|
||||||
|
" sentence_list, labels_list, test_size=TEST_DATA_FRACTION, random_state=RANDOM_SEED\n",
|
||||||
")"
|
")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"The following is an example input sentence of the training set."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/html": [
|
||||||
|
"<div>\n",
|
||||||
|
"<style scoped>\n",
|
||||||
|
" .dataframe tbody tr th:only-of-type {\n",
|
||||||
|
" vertical-align: middle;\n",
|
||||||
|
" }\n",
|
||||||
|
"\n",
|
||||||
|
" .dataframe tbody tr th {\n",
|
||||||
|
" vertical-align: top;\n",
|
||||||
|
" }\n",
|
||||||
|
"\n",
|
||||||
|
" .dataframe thead th {\n",
|
||||||
|
" text-align: right;\n",
|
||||||
|
" }\n",
|
||||||
|
"</style>\n",
|
||||||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
||||||
|
" <thead>\n",
|
||||||
|
" <tr style=\"text-align: right;\">\n",
|
||||||
|
" <th></th>\n",
|
||||||
|
" <th>token</th>\n",
|
||||||
|
" <th>label</th>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" </thead>\n",
|
||||||
|
" <tbody>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>0</th>\n",
|
||||||
|
" <td>In</td>\n",
|
||||||
|
" <td>O</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>1</th>\n",
|
||||||
|
" <td>1999</td>\n",
|
||||||
|
" <td>O</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>2</th>\n",
|
||||||
|
" <td>,</td>\n",
|
||||||
|
" <td>O</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>3</th>\n",
|
||||||
|
" <td>the</td>\n",
|
||||||
|
" <td>O</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>4</th>\n",
|
||||||
|
" <td>Caloi</td>\n",
|
||||||
|
" <td>I-PER</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>5</th>\n",
|
||||||
|
" <td>family</td>\n",
|
||||||
|
" <td>O</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>6</th>\n",
|
||||||
|
" <td>sold</td>\n",
|
||||||
|
" <td>O</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>7</th>\n",
|
||||||
|
" <td>the</td>\n",
|
||||||
|
" <td>O</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>8</th>\n",
|
||||||
|
" <td>majority</td>\n",
|
||||||
|
" <td>O</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>9</th>\n",
|
||||||
|
" <td>of</td>\n",
|
||||||
|
" <td>O</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>10</th>\n",
|
||||||
|
" <td>Caloi</td>\n",
|
||||||
|
" <td>I-ORG</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" </tbody>\n",
|
||||||
|
"</table>\n",
|
||||||
|
"</div>"
|
||||||
|
],
|
||||||
|
"text/plain": [
|
||||||
|
" token label\n",
|
||||||
|
"0 In O\n",
|
||||||
|
"1 1999 O\n",
|
||||||
|
"2 , O\n",
|
||||||
|
"3 the O\n",
|
||||||
|
"4 Caloi I-PER\n",
|
||||||
|
"5 family O\n",
|
||||||
|
"6 sold O\n",
|
||||||
|
"7 the O\n",
|
||||||
|
"8 majority O\n",
|
||||||
|
"9 of O\n",
|
||||||
|
"10 Caloi I-ORG"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 5,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"pd.DataFrame({\"token\": train_sentence_list[0], \"label\": train_labels_list[0]}).head(11)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"> If your data is unlabeled, try using an annotation tool to simplify the process of labeling. The example [here](../annotation/Doccano.md) introduces [Doccanno](https://github.com/chakki-works/doccano) and shows how it can be used for NER annotation."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Create PyTorch Datasets and Dataloaders\n",
|
||||||
|
"Given the tokenized input and corresponding labels, we use a custom processer to convert our input lists into a PyTorch dataset that can be used with our token classifier. Next, we create PyTorch dataloaders for training and testing."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 6,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"WARNING:root:Token lists with length > 512 will be truncated\n",
|
||||||
|
"WARNING:root:Token lists with length > 512 will be truncated\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"processor = TokenClassificationProcessor(model_name=MODEL_NAME, to_lower=DO_LOWER_CASE, cache_dir=CACHE_DIR)\n",
|
||||||
|
"\n",
|
||||||
|
"label_map = TokenClassificationProcessor.create_label_map(\n",
|
||||||
|
" label_lists=labels_list, trailing_piece_tag=TRAILING_PIECE_TAG\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"train_dataset = processor.preprocess(\n",
|
||||||
|
" text=train_sentence_list,\n",
|
||||||
|
" max_len=MAX_SEQ_LENGTH,\n",
|
||||||
|
" labels=train_labels_list,\n",
|
||||||
|
" label_map=label_map,\n",
|
||||||
|
" trailing_piece_tag=TRAILING_PIECE_TAG,\n",
|
||||||
|
")\n",
|
||||||
|
"train_dataloader = dataloader_from_dataset(\n",
|
||||||
|
" train_dataset, batch_size=BATCH_SIZE, num_gpus=NUM_GPUS, shuffle=True, distributed=False\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"test_dataset = processor.preprocess(\n",
|
||||||
|
" text=test_sentence_list,\n",
|
||||||
|
" max_len=MAX_SEQ_LENGTH,\n",
|
||||||
|
" labels=test_labels_list,\n",
|
||||||
|
" label_map=label_map,\n",
|
||||||
|
" trailing_piece_tag=TRAILING_PIECE_TAG,\n",
|
||||||
|
")\n",
|
||||||
|
"test_dataloader = dataloader_from_dataset(\n",
|
||||||
|
" test_dataset, batch_size=BATCH_SIZE, num_gpus=NUM_GPUS, shuffle=False, distributed=False\n",
|
||||||
|
")\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"## Train Model\n",
|
"## Train Model\n",
|
||||||
"\n",
|
"\n",
|
||||||
"There are two steps to train a NER model using pretrained transformer model: 1). instantiate a TokenClassifier class which is a wrapper of the transformer using BERT architecture, and 2), fit the model using the preprocessed training dataset. The member method `fit` of TokenClassifier class is used to fine tune the model."
|
"There are two steps to train a NER model using pretrained transformer model: 1) Instantiate a TokenClassifier class which is a wrapper of a transformer-based network, and 2) Fit the model using the preprocessed training dataloader. The member method `fit` of TokenClassifier class is used to fine-tune the model."
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 7,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "d7c19dfe849b4bb3b195e792b0ccc809",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"HBox(children=(IntProgress(value=0, description='Downloading', max=435779157, style=ProgressStyle(description_…"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"/media/bleik2/backup/miniconda3/envs/nlp_gpu/lib/python3.6/site-packages/torch/nn/parallel/_functions.py:61: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
||||||
|
" warnings.warn('Was asked to gather along dimension 0, but all '\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Training time : 0.075 hrs\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"# Instantiate a TokenClassifier class for NER using pretrained transformer model\n",
|
"# Instantiate a TokenClassifier class for NER using pretrained transformer model\n",
|
||||||
"model = TokenClassifier(\n",
|
"model = TokenClassifier(\n",
|
||||||
|
@ -203,13 +429,13 @@
|
||||||
" model.fit(\n",
|
" model.fit(\n",
|
||||||
" train_dataloader=train_dataloader,\n",
|
" train_dataloader=train_dataloader,\n",
|
||||||
" num_epochs=NUM_TRAIN_EPOCHS,\n",
|
" num_epochs=NUM_TRAIN_EPOCHS,\n",
|
||||||
" num_gpus=None,\n",
|
" num_gpus=NUM_GPUS,\n",
|
||||||
" local_rank=-1,\n",
|
" local_rank=-1,\n",
|
||||||
" weight_decay=0.0,\n",
|
" weight_decay=0.0,\n",
|
||||||
" learning_rate=5e-5,\n",
|
" learning_rate=5e-5,\n",
|
||||||
" adam_epsilon=1e-8,\n",
|
" adam_epsilon=1e-8,\n",
|
||||||
" warmup_steps=0,\n",
|
" warmup_steps=0,\n",
|
||||||
" verbose=True,\n",
|
" verbose=False,\n",
|
||||||
" seed=RANDOM_SEED\n",
|
" seed=RANDOM_SEED\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
@ -227,9 +453,31 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 8,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Scoring: 100%|██████████| 18/18 [00:08<00:00, 2.49it/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Prediction time : 0.002 hrs\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"with Timer() as t:\n",
|
"with Timer() as t:\n",
|
||||||
" preds = model.predict(\n",
|
" preds = model.predict(\n",
|
||||||
|
@ -245,12 +493,12 @@
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"Get the true token labels of the testing dataset. "
|
"Get the true token labels of the testing dataset:"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 9,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -266,9 +514,26 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 10,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
" precision recall f1-score support\n",
|
||||||
|
"\n",
|
||||||
|
" MISC 0.68 0.67 0.68 221\n",
|
||||||
|
" LOC 0.79 0.85 0.82 317\n",
|
||||||
|
" ORG 0.73 0.81 0.76 274\n",
|
||||||
|
" PER 0.92 0.93 0.92 257\n",
|
||||||
|
"\n",
|
||||||
|
"micro avg 0.78 0.82 0.80 1069\n",
|
||||||
|
"macro avg 0.78 0.82 0.80 1069\n",
|
||||||
|
"\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"predicted_labels = model.get_predicted_token_labels(\n",
|
"predicted_labels = model.get_predicted_token_labels(\n",
|
||||||
" predictions=preds,\n",
|
" predictions=preds,\n",
|
||||||
|
@ -284,6 +549,94 @@
|
||||||
"print(report)"
|
"print(report)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Score Example Sentences\n",
|
||||||
|
"Finally, we test the model on some random input sentences."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 11,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"WARNING:root:Token lists with length > 512 will be truncated\n",
|
||||||
|
"Scoring: 100%|██████████| 1/1 [00:00<00:00, 7.56it/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"\n",
|
||||||
|
" Is it true that Jane works at Microsoft?\n",
|
||||||
|
" tokens labels\n",
|
||||||
|
"0 Is O\n",
|
||||||
|
"1 it O\n",
|
||||||
|
"2 true O\n",
|
||||||
|
"3 that O\n",
|
||||||
|
"4 Jane I-PER\n",
|
||||||
|
"5 works O\n",
|
||||||
|
"6 at O\n",
|
||||||
|
"7 Microsoft? I-ORG\n",
|
||||||
|
"\n",
|
||||||
|
" Joe now lives in Copenhagen.\n",
|
||||||
|
" tokens labels\n",
|
||||||
|
"0 Joe I-PER\n",
|
||||||
|
"1 now O\n",
|
||||||
|
"2 lives O\n",
|
||||||
|
"3 in O\n",
|
||||||
|
"4 Copenhagen. I-LOC\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# test\n",
|
||||||
|
"sample_text = [ \n",
|
||||||
|
" \"Is it true that Jane works at Microsoft?\",\n",
|
||||||
|
" \"Joe now lives in Copenhagen.\"\n",
|
||||||
|
"]\n",
|
||||||
|
"sample_tokens = [x.split() for x in sample_text]\n",
|
||||||
|
"\n",
|
||||||
|
"sample_dataset = processor.preprocess(\n",
|
||||||
|
" text=sample_tokens,\n",
|
||||||
|
" max_len=MAX_SEQ_LENGTH,\n",
|
||||||
|
" labels=None,\n",
|
||||||
|
" label_map=label_map,\n",
|
||||||
|
" trailing_piece_tag=TRAILING_PIECE_TAG,\n",
|
||||||
|
")\n",
|
||||||
|
"sample_dataloader = dataloader_from_dataset(\n",
|
||||||
|
" sample_dataset, batch_size=BATCH_SIZE, num_gpus=None, shuffle=False, distributed=False\n",
|
||||||
|
")\n",
|
||||||
|
"preds = model.predict(\n",
|
||||||
|
" test_dataloader=sample_dataloader,\n",
|
||||||
|
" num_gpus=None,\n",
|
||||||
|
" verbose=True\n",
|
||||||
|
")\n",
|
||||||
|
"predicted_labels = model.get_predicted_token_labels(\n",
|
||||||
|
" predictions=preds,\n",
|
||||||
|
" label_map=label_map,\n",
|
||||||
|
" dataset=sample_dataset\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"for i in range(len(sample_text)):\n",
|
||||||
|
" print(\"\\n\", sample_text[i])\n",
|
||||||
|
" print(pd.DataFrame({\"tokens\": sample_tokens[i] , \"labels\":predicted_labels[i]})) "
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
|
@ -293,9 +646,64 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 12,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/scrapbook.scrap.json+json": {
|
||||||
|
"data": 0.78,
|
||||||
|
"encoder": "json",
|
||||||
|
"name": "precision",
|
||||||
|
"version": 1
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"metadata": {
|
||||||
|
"scrapbook": {
|
||||||
|
"data": true,
|
||||||
|
"display": false,
|
||||||
|
"name": "precision"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/scrapbook.scrap.json+json": {
|
||||||
|
"data": 0.82,
|
||||||
|
"encoder": "json",
|
||||||
|
"name": "recall",
|
||||||
|
"version": 1
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"metadata": {
|
||||||
|
"scrapbook": {
|
||||||
|
"data": true,
|
||||||
|
"display": false,
|
||||||
|
"name": "recall"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/scrapbook.scrap.json+json": {
|
||||||
|
"data": 0.8,
|
||||||
|
"encoder": "json",
|
||||||
|
"name": "f1",
|
||||||
|
"version": 1
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"metadata": {
|
||||||
|
"scrapbook": {
|
||||||
|
"data": true,
|
||||||
|
"display": false,
|
||||||
|
"name": "f1"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"output_type": "display_data"
|
||||||
|
}
|
||||||
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"report_splits = report.split('\\n')[-2].split()\n",
|
"report_splits = report.split('\\n')[-2].split()\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
@ -306,11 +714,10 @@
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"celltoolbar": "Tags",
|
|
||||||
"kernelspec": {
|
"kernelspec": {
|
||||||
"display_name": "Python 3.6 - AzureML",
|
"display_name": "nlp_gpu",
|
||||||
"language": "python",
|
"language": "python",
|
||||||
"name": "python3-azureml"
|
"name": "nlp_gpu"
|
||||||
},
|
},
|
||||||
"language_info": {
|
"language_info": {
|
||||||
"codemirror_mode": {
|
"codemirror_mode": {
|
||||||
|
@ -326,5 +733,5 @@
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
"nbformat_minor": 2
|
"nbformat_minor": 4
|
||||||
}
|
}
|
||||||
|
|
|
@ -112,7 +112,7 @@ def test_wikigold(tmp_path):
|
||||||
|
|
||||||
|
|
||||||
def test_ner_utils(ner_utils_test_data):
|
def test_ner_utils(ner_utils_test_data):
|
||||||
output = preprocess_conll(ner_utils_test_data["input"])
|
output = preprocess_conll(ner_utils_test_data["input"], sep=" ")
|
||||||
assert output == ner_utils_test_data["expected_output"]
|
assert output == ner_utils_test_data["expected_output"]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -85,7 +85,9 @@ def qa_test_data(qa_test_df, tmp_module):
|
||||||
)
|
)
|
||||||
|
|
||||||
# xlnet
|
# xlnet
|
||||||
qa_processor_xlnet = QAProcessor(model_name="xlnet-base-cased", cache_dir=tmp_module)
|
qa_processor_xlnet = QAProcessor(
|
||||||
|
model_name="xlnet-base-cased", cache_dir=tmp_module
|
||||||
|
)
|
||||||
train_features_xlnet = qa_processor_xlnet.preprocess(
|
train_features_xlnet = qa_processor_xlnet.preprocess(
|
||||||
train_dataset,
|
train_dataset,
|
||||||
is_training=True,
|
is_training=True,
|
||||||
|
@ -148,13 +150,19 @@ def test_QAProcessor(qa_test_data, tmp_module):
|
||||||
]:
|
]:
|
||||||
qa_processor = QAProcessor(model_name=model_name, cache_dir=tmp_module)
|
qa_processor = QAProcessor(model_name=model_name, cache_dir=tmp_module)
|
||||||
qa_processor.preprocess(
|
qa_processor.preprocess(
|
||||||
qa_test_data["train_dataset"], is_training=True, feature_cache_dir=tmp_module,
|
qa_test_data["train_dataset"],
|
||||||
|
is_training=True,
|
||||||
|
feature_cache_dir=tmp_module,
|
||||||
)
|
)
|
||||||
qa_processor.preprocess(
|
qa_processor.preprocess(
|
||||||
qa_test_data["train_dataset_list"], is_training=True, feature_cache_dir=tmp_module,
|
qa_test_data["train_dataset_list"],
|
||||||
|
is_training=True,
|
||||||
|
feature_cache_dir=tmp_module,
|
||||||
)
|
)
|
||||||
qa_processor.preprocess(
|
qa_processor.preprocess(
|
||||||
qa_test_data["test_dataset"], is_training=False, feature_cache_dir=tmp_module,
|
qa_test_data["test_dataset"],
|
||||||
|
is_training=False,
|
||||||
|
feature_cache_dir=tmp_module,
|
||||||
)
|
)
|
||||||
|
|
||||||
# test unsupported model type
|
# test unsupported model type
|
||||||
|
@ -188,7 +196,9 @@ def test_AnswerExtractor(qa_test_data, tmp_module):
|
||||||
# bert
|
# bert
|
||||||
qa_extractor_bert = AnswerExtractor(cache_dir=tmp_module)
|
qa_extractor_bert = AnswerExtractor(cache_dir=tmp_module)
|
||||||
train_loader_bert = dataloader_from_dataset(qa_test_data["train_features_bert"])
|
train_loader_bert = dataloader_from_dataset(qa_test_data["train_features_bert"])
|
||||||
test_loader_bert = dataloader_from_dataset(qa_test_data["test_features_bert"], shuffle=False)
|
test_loader_bert = dataloader_from_dataset(
|
||||||
|
qa_test_data["test_features_bert"], shuffle=False
|
||||||
|
)
|
||||||
qa_extractor_bert.fit(train_loader_bert, verbose=False, cache_model=True)
|
qa_extractor_bert.fit(train_loader_bert, verbose=False, cache_model=True)
|
||||||
|
|
||||||
# test saving fine-tuned model
|
# test saving fine-tuned model
|
||||||
|
@ -203,13 +213,19 @@ def test_AnswerExtractor(qa_test_data, tmp_module):
|
||||||
|
|
||||||
# xlnet
|
# xlnet
|
||||||
train_loader_xlnet = dataloader_from_dataset(qa_test_data["train_features_xlnet"])
|
train_loader_xlnet = dataloader_from_dataset(qa_test_data["train_features_xlnet"])
|
||||||
test_loader_xlnet = dataloader_from_dataset(qa_test_data["test_features_xlnet"], shuffle=False)
|
test_loader_xlnet = dataloader_from_dataset(
|
||||||
qa_extractor_xlnet = AnswerExtractor(model_name="xlnet-base-cased", cache_dir=tmp_module)
|
qa_test_data["test_features_xlnet"], shuffle=False
|
||||||
|
)
|
||||||
|
qa_extractor_xlnet = AnswerExtractor(
|
||||||
|
model_name="xlnet-base-cased", cache_dir=tmp_module
|
||||||
|
)
|
||||||
qa_extractor_xlnet.fit(train_loader_xlnet, verbose=False, cache_model=False)
|
qa_extractor_xlnet.fit(train_loader_xlnet, verbose=False, cache_model=False)
|
||||||
qa_extractor_xlnet.predict(test_loader_xlnet, verbose=False)
|
qa_extractor_xlnet.predict(test_loader_xlnet, verbose=False)
|
||||||
|
|
||||||
# distilbert
|
# distilbert
|
||||||
train_loader_xlnet = dataloader_from_dataset(qa_test_data["train_features_distilbert"])
|
train_loader_xlnet = dataloader_from_dataset(
|
||||||
|
qa_test_data["train_features_distilbert"]
|
||||||
|
)
|
||||||
test_loader_xlnet = dataloader_from_dataset(
|
test_loader_xlnet = dataloader_from_dataset(
|
||||||
qa_test_data["test_features_distilbert"], shuffle=False
|
qa_test_data["test_features_distilbert"], shuffle=False
|
||||||
)
|
)
|
||||||
|
|
|
@ -23,7 +23,7 @@ def test_token_classifier_fit_predict(tmpdir, ner_test_data):
|
||||||
)
|
)
|
||||||
|
|
||||||
# test fit, no warmup
|
# test fit, no warmup
|
||||||
train_dataset = processor.preprocess_for_bert(
|
train_dataset = processor.preprocess(
|
||||||
text=ner_test_data["INPUT_TEXT"],
|
text=ner_test_data["INPUT_TEXT"],
|
||||||
max_len=max_seq_len,
|
max_len=max_seq_len,
|
||||||
labels=ner_test_data["INPUT_LABELS"],
|
labels=ner_test_data["INPUT_LABELS"],
|
||||||
|
|
|
@ -81,7 +81,7 @@ PIP_BASE = {
|
||||||
"https://github.com/explosion/spacy-models/releases/download/"
|
"https://github.com/explosion/spacy-models/releases/download/"
|
||||||
"en_core_web_sm-2.1.0/en_core_web_sm-2.1.0.tar.gz"
|
"en_core_web_sm-2.1.0/en_core_web_sm-2.1.0.tar.gz"
|
||||||
),
|
),
|
||||||
"transformers": "transformers==2.5.0",
|
"transformers": "transformers==2.9.0",
|
||||||
"gensim": "gensim>=3.7.0",
|
"gensim": "gensim>=3.7.0",
|
||||||
"nltk": "nltk>=3.4",
|
"nltk": "nltk>=3.4",
|
||||||
"seqeval": "seqeval>=0.0.12",
|
"seqeval": "seqeval>=0.0.12",
|
||||||
|
|
|
@ -4,10 +4,9 @@
|
||||||
"""Common helper functions for preprocessing Named Entity Recognition (NER) datasets."""
|
"""Common helper functions for preprocessing Named Entity Recognition (NER) datasets."""
|
||||||
|
|
||||||
|
|
||||||
def preprocess_conll(text, data_type=""):
|
def preprocess_conll(text, sep="\t"):
|
||||||
"""
|
"""
|
||||||
Helper function converting data in conll format to word lists
|
Converts data in CoNLL format to word and label lists.
|
||||||
and token label lists.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text (str): Text string in conll format, e.g.
|
text (str): Text string in conll format, e.g.
|
||||||
|
@ -20,8 +19,8 @@ def preprocess_conll(text, data_type=""):
|
||||||
of I-ORG
|
of I-ORG
|
||||||
Minnesota I-ORG
|
Minnesota I-ORG
|
||||||
. O"
|
. O"
|
||||||
data_type (str, optional): String that briefly describes the data,
|
sep (str, optional): Column separator
|
||||||
e.g. "train"
|
Defaults to \t
|
||||||
Returns:
|
Returns:
|
||||||
tuple:
|
tuple:
|
||||||
(list of word lists, list of token label lists)
|
(list of word lists, list of token label lists)
|
||||||
|
@ -37,11 +36,29 @@ def preprocess_conll(text, data_type=""):
|
||||||
# split each sentence string into "word label" pairs
|
# split each sentence string into "word label" pairs
|
||||||
s_split = s.split("\n")
|
s_split = s.split("\n")
|
||||||
# split "word label" pairs
|
# split "word label" pairs
|
||||||
s_split_split = [t.split() for t in s_split]
|
s_split_split = [t.split(sep) for t in s_split]
|
||||||
sentence_list.append([t[0] for t in s_split_split if len(t) > 1])
|
sentence_list.append([t[0] for t in s_split_split if len(t) > 1])
|
||||||
labels_list.append([t[1] for t in s_split_split if len(t) > 1])
|
labels_list.append([t[1] for t in s_split_split if len(t) > 1])
|
||||||
|
|
||||||
if len(s_split_split) > max_seq_len:
|
if len(s_split_split) > max_seq_len:
|
||||||
max_seq_len = len(s_split_split)
|
max_seq_len = len(s_split_split)
|
||||||
print("Maximum sequence length in the {0} data is: {1}".format(data_type, max_seq_len))
|
print("Maximum sequence length is: {0}".format(max_seq_len))
|
||||||
return sentence_list, labels_list
|
return sentence_list, labels_list
|
||||||
|
|
||||||
|
|
||||||
|
def read_conll_file(file_path, sep="\t", encoding=None):
|
||||||
|
"""
|
||||||
|
Reads a data file in CoNLL format and returns word and label lists.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path (str): Data file path.
|
||||||
|
sep (str, optional): Column separator. Defaults to "\t".
|
||||||
|
encoding (str): File encoding used when reading the file.
|
||||||
|
Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(list, list): A tuple of word and label lists (list of lists).
|
||||||
|
"""
|
||||||
|
with open(file_path, encoding=encoding) as f:
|
||||||
|
data = f.read()
|
||||||
|
return preprocess_conll(data, sep=sep)
|
||||||
|
|
|
@ -18,7 +18,9 @@ from utils_nlp.common.pytorch_utils import dataloader_from_dataset
|
||||||
from utils_nlp.dataset.ner_utils import preprocess_conll
|
from utils_nlp.dataset.ner_utils import preprocess_conll
|
||||||
from utils_nlp.dataset.url_utils import maybe_download
|
from utils_nlp.dataset.url_utils import maybe_download
|
||||||
from utils_nlp.models.transformers.common import MAX_SEQ_LEN
|
from utils_nlp.models.transformers.common import MAX_SEQ_LEN
|
||||||
from utils_nlp.models.transformers.named_entity_recognition import TokenClassificationProcessor
|
from utils_nlp.models.transformers.named_entity_recognition import (
|
||||||
|
TokenClassificationProcessor,
|
||||||
|
)
|
||||||
|
|
||||||
URL = (
|
URL = (
|
||||||
"https://raw.githubusercontent.com/juand-r/entity-recognition-datasets"
|
"https://raw.githubusercontent.com/juand-r/entity-recognition-datasets"
|
||||||
|
@ -68,7 +70,9 @@ def load_train_test_dfs(local_cache_path="./", test_fraction=0.5, random_seed=No
|
||||||
train_sentence_list = sentence_list[test_sentence_count:]
|
train_sentence_list = sentence_list[test_sentence_count:]
|
||||||
train_labels_list = labels_list[test_sentence_count:]
|
train_labels_list = labels_list[test_sentence_count:]
|
||||||
|
|
||||||
train_df = pd.DataFrame({"sentence": train_sentence_list, "labels": train_labels_list})
|
train_df = pd.DataFrame(
|
||||||
|
{"sentence": train_sentence_list, "labels": train_labels_list}
|
||||||
|
)
|
||||||
|
|
||||||
test_df = pd.DataFrame({"sentence": test_sentence_list, "labels": test_labels_list})
|
test_df = pd.DataFrame({"sentence": test_sentence_list, "labels": test_labels_list})
|
||||||
|
|
||||||
|
@ -152,7 +156,9 @@ def load_dataset(
|
||||||
"""
|
"""
|
||||||
|
|
||||||
train_df, test_df = load_train_test_dfs(
|
train_df, test_df = load_train_test_dfs(
|
||||||
local_cache_path=local_path, test_fraction=test_fraction, random_seed=random_seed
|
local_cache_path=local_path,
|
||||||
|
test_fraction=test_fraction,
|
||||||
|
random_seed=random_seed,
|
||||||
)
|
)
|
||||||
|
|
||||||
if train_sample_ratio > 1.0:
|
if train_sample_ratio > 1.0:
|
||||||
|
@ -160,7 +166,9 @@ def load_dataset(
|
||||||
logging.warning("Setting the training sample ratio to 1.0")
|
logging.warning("Setting the training sample ratio to 1.0")
|
||||||
elif train_sample_ratio < 0:
|
elif train_sample_ratio < 0:
|
||||||
logging.error("Invalid training sample ration: {}".format(train_sample_ratio))
|
logging.error("Invalid training sample ration: {}".format(train_sample_ratio))
|
||||||
raise ValueError("Invalid training sample ration: {}".format(train_sample_ratio))
|
raise ValueError(
|
||||||
|
"Invalid training sample ration: {}".format(train_sample_ratio)
|
||||||
|
)
|
||||||
|
|
||||||
if test_sample_ratio > 1.0:
|
if test_sample_ratio > 1.0:
|
||||||
test_sample_ratio = 1.0
|
test_sample_ratio = 1.0
|
||||||
|
@ -174,7 +182,9 @@ def load_dataset(
|
||||||
if test_sample_ratio < 1.0:
|
if test_sample_ratio < 1.0:
|
||||||
test_df = test_df.sample(frac=test_sample_ratio).reset_index(drop=True)
|
test_df = test_df.sample(frac=test_sample_ratio).reset_index(drop=True)
|
||||||
|
|
||||||
processor = TokenClassificationProcessor(model_name=model_name, to_lower=to_lower, cache_dir=cache_dir)
|
processor = TokenClassificationProcessor(
|
||||||
|
model_name=model_name, to_lower=to_lower, cache_dir=cache_dir
|
||||||
|
)
|
||||||
|
|
||||||
label_map = TokenClassificationProcessor.create_label_map(
|
label_map = TokenClassificationProcessor.create_label_map(
|
||||||
label_lists=train_df["labels"], trailing_piece_tag=trailing_piece_tag
|
label_lists=train_df["labels"], trailing_piece_tag=trailing_piece_tag
|
||||||
|
@ -197,11 +207,19 @@ def load_dataset(
|
||||||
)
|
)
|
||||||
|
|
||||||
train_dataloader = dataloader_from_dataset(
|
train_dataloader = dataloader_from_dataset(
|
||||||
train_dataset, batch_size=batch_size, num_gpus=num_gpus, shuffle=True, distributed=False
|
train_dataset,
|
||||||
|
batch_size=batch_size,
|
||||||
|
num_gpus=num_gpus,
|
||||||
|
shuffle=True,
|
||||||
|
distributed=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
test_dataloader = dataloader_from_dataset(
|
test_dataloader = dataloader_from_dataset(
|
||||||
test_dataset, batch_size=batch_size, num_gpus=num_gpus, shuffle=False, distributed=False
|
test_dataset,
|
||||||
|
batch_size=batch_size,
|
||||||
|
num_gpus=num_gpus,
|
||||||
|
shuffle=False,
|
||||||
|
distributed=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
return (train_dataloader, test_dataloader, label_map, test_dataset)
|
return (train_dataloader, test_dataloader, label_map, test_dataset)
|
||||||
|
|
|
@ -5,34 +5,29 @@
|
||||||
# This script reuses some code from https://github.com/huggingface/transformers/
|
# This script reuses some code from https://github.com/huggingface/transformers/
|
||||||
# Add to noticefile
|
# Add to noticefile
|
||||||
|
|
||||||
from collections import namedtuple
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
from tqdm import tqdm
|
from collections import namedtuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import (
|
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
||||||
DataLoader,
|
|
||||||
SequentialSampler,
|
|
||||||
RandomSampler,
|
|
||||||
)
|
|
||||||
|
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
from transformers import BertModel
|
from tqdm import tqdm
|
||||||
|
from transformers import AutoTokenizer, BertModel
|
||||||
|
|
||||||
from utils_nlp.common.pytorch_utils import (
|
from utils_nlp.common.pytorch_utils import (
|
||||||
compute_training_steps,
|
compute_training_steps,
|
||||||
get_device,
|
|
||||||
get_amp,
|
get_amp,
|
||||||
|
get_device,
|
||||||
move_model_to_device,
|
move_model_to_device,
|
||||||
parallelize_model,
|
parallelize_model,
|
||||||
)
|
)
|
||||||
from utils_nlp.eval import compute_rouge_python
|
from utils_nlp.eval import compute_rouge_python
|
||||||
from utils_nlp.models.transformers.common import TOKENIZER_CLASS, Transformer
|
|
||||||
from utils_nlp.models.transformers.bertsum import model_builder
|
from utils_nlp.models.transformers.bertsum import model_builder
|
||||||
from utils_nlp.models.transformers.bertsum.model_builder import AbsSummarizer
|
from utils_nlp.models.transformers.bertsum.model_builder import AbsSummarizer
|
||||||
from utils_nlp.models.transformers.bertsum.predictor import build_predictor
|
from utils_nlp.models.transformers.bertsum.predictor import build_predictor
|
||||||
|
from utils_nlp.models.transformers.common import Transformer
|
||||||
|
|
||||||
MODEL_CLASS = {"bert-base-uncased": BertModel}
|
MODEL_CLASS = {"bert-base-uncased": BertModel}
|
||||||
|
|
||||||
|
@ -134,8 +129,11 @@ class BertSumAbsProcessor:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.tokenizer = TOKENIZER_CLASS[self.model_name].from_pretrained(
|
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||||
self.model_name, do_lower_case=to_lower, cache_dir=cache_dir
|
model_name,
|
||||||
|
do_lower_case=to_lower,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
output_loading_info=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.symbols = {
|
self.symbols = {
|
||||||
|
@ -156,7 +154,7 @@ class BertSumAbsProcessor:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def list_supported_models():
|
def list_supported_models():
|
||||||
return list(MODEL_CLASS.keys())
|
return list(MODEL_CLASS)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def model_name(self):
|
def model_name(self):
|
||||||
|
@ -184,7 +182,7 @@ class BertSumAbsProcessor:
|
||||||
also contains the target ids and the number of tokens
|
also contains the target ids and the number of tokens
|
||||||
in the target and target text.
|
in the target and target text.
|
||||||
device (torch.device): A PyTorch device.
|
device (torch.device): A PyTorch device.
|
||||||
model_name (bool, optional): Model name used to format the inputs.
|
model_name (bool): Model name used to format the inputs.
|
||||||
train_mode (bool, optional): Training mode flag.
|
train_mode (bool, optional): Training mode flag.
|
||||||
Defaults to True.
|
Defaults to True.
|
||||||
|
|
||||||
|
@ -403,7 +401,8 @@ class BertSumAbs(Transformer):
|
||||||
check MODEL_CLASS for supported models. Defaults to "bert-base-uncased".
|
check MODEL_CLASS for supported models. Defaults to "bert-base-uncased".
|
||||||
finetune_bert (bool, option): Whether the bert model in the encoder is
|
finetune_bert (bool, option): Whether the bert model in the encoder is
|
||||||
finetune or not. Defaults to True.
|
finetune or not. Defaults to True.
|
||||||
cache_dir (str, optional): Directory to cache the tokenizer. Defaults to ".".
|
cache_dir (str, optional): Directory to cache the tokenizer.
|
||||||
|
Defaults to ".".
|
||||||
label_smoothing (float, optional): The amount of label smoothing.
|
label_smoothing (float, optional): The amount of label smoothing.
|
||||||
Value range is [0, 1]. Defaults to 0.1.
|
Value range is [0, 1]. Defaults to 0.1.
|
||||||
test (bool, optional): Whether the class is initiated for test or not.
|
test (bool, optional): Whether the class is initiated for test or not.
|
||||||
|
@ -412,13 +411,11 @@ class BertSumAbs(Transformer):
|
||||||
max_pos_length (int, optional): maximum postional embedding length for the
|
max_pos_length (int, optional): maximum postional embedding length for the
|
||||||
input. Defaults to 768.
|
input. Defaults to 768.
|
||||||
"""
|
"""
|
||||||
|
model = MODEL_CLASS[model_name].from_pretrained(
|
||||||
super().__init__(
|
model_name, cache_dir=cache_dir, num_labels=0, output_loading_info=False
|
||||||
model_class=MODEL_CLASS,
|
|
||||||
model_name=model_name,
|
|
||||||
num_labels=0,
|
|
||||||
cache_dir=cache_dir,
|
|
||||||
)
|
)
|
||||||
|
super().__init__(model_name=model_name, model=model, cache_dir=cache_dir)
|
||||||
|
|
||||||
if model_name not in self.list_supported_models():
|
if model_name not in self.list_supported_models():
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Model name {} is not supported by BertSumAbs. "
|
"Model name {} is not supported by BertSumAbs. "
|
||||||
|
@ -616,10 +613,7 @@ class BertSumAbs(Transformer):
|
||||||
)
|
)
|
||||||
|
|
||||||
train_dataloader = DataLoader(
|
train_dataloader = DataLoader(
|
||||||
train_dataset,
|
train_dataset, sampler=sampler, batch_size=batch_size, collate_fn=collate_fn
|
||||||
sampler=sampler,
|
|
||||||
batch_size=batch_size,
|
|
||||||
collate_fn=collate_fn,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# compute the max number of training steps
|
# compute the max number of training steps
|
||||||
|
|
|
@ -13,7 +13,7 @@ from torch.utils.data.distributed import DistributedSampler
|
||||||
|
|
||||||
from transformers import RobertaConfig, BertConfig
|
from transformers import RobertaConfig, BertConfig
|
||||||
|
|
||||||
from utils_nlp.models.transformers.common import TOKENIZER_CLASS, Transformer
|
from utils_nlp.models.transformers.common import Transformer
|
||||||
from utils_nlp.common.pytorch_utils import (
|
from utils_nlp.common.pytorch_utils import (
|
||||||
get_device,
|
get_device,
|
||||||
move_model_to_device,
|
move_model_to_device,
|
||||||
|
@ -52,8 +52,7 @@ MODEL_CLASS.update(
|
||||||
{k: BertForSequenceToSequence for k in MINILM_PRETRAINED_MODEL_ARCHIVE_MAP}
|
{k: BertForSequenceToSequence for k in MINILM_PRETRAINED_MODEL_ARCHIVE_MAP}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
TOKENIZER_CLASS = {}
|
||||||
|
|
||||||
TOKENIZER_CLASS.update({k: UnilmTokenizer for k in UNILM_PRETRAINED_CONFIG_ARCHIVE_MAP})
|
TOKENIZER_CLASS.update({k: UnilmTokenizer for k in UNILM_PRETRAINED_CONFIG_ARCHIVE_MAP})
|
||||||
TOKENIZER_CLASS.update({k: MinilmTokenizer for k in MINILM_PRETRAINED_CONFIG_ARCHIVE_MAP})
|
TOKENIZER_CLASS.update({k: MinilmTokenizer for k in MINILM_PRETRAINED_CONFIG_ARCHIVE_MAP})
|
||||||
|
|
||||||
|
|
|
@ -14,32 +14,14 @@ import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from transformers import AdamW, get_linear_schedule_with_warmup
|
from transformers import AdamW, get_linear_schedule_with_warmup
|
||||||
from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
|
||||||
from transformers.modeling_distilbert import DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
|
||||||
from transformers.modeling_roberta import ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
|
|
||||||
from transformers.modeling_xlnet import XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
|
|
||||||
from transformers.tokenization_bert import BertTokenizer
|
|
||||||
from transformers.tokenization_distilbert import DistilBertTokenizer
|
|
||||||
from transformers.tokenization_roberta import RobertaTokenizer
|
|
||||||
from transformers.tokenization_xlnet import XLNetTokenizer
|
|
||||||
|
|
||||||
from utils_nlp.common.pytorch_utils import (
|
from utils_nlp.common.pytorch_utils import (
|
||||||
|
get_amp,
|
||||||
get_device,
|
get_device,
|
||||||
move_model_to_device,
|
move_model_to_device,
|
||||||
get_amp,
|
|
||||||
parallelize_model,
|
parallelize_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
TOKENIZER_CLASS = {}
|
|
||||||
TOKENIZER_CLASS.update({k: BertTokenizer for k in BERT_PRETRAINED_MODEL_ARCHIVE_MAP})
|
|
||||||
TOKENIZER_CLASS.update(
|
|
||||||
{k: RobertaTokenizer for k in ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP}
|
|
||||||
)
|
|
||||||
TOKENIZER_CLASS.update({k: XLNetTokenizer for k in XLNET_PRETRAINED_MODEL_ARCHIVE_MAP})
|
|
||||||
TOKENIZER_CLASS.update(
|
|
||||||
{k: DistilBertTokenizer for k in DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP}
|
|
||||||
)
|
|
||||||
|
|
||||||
MAX_SEQ_LEN = 512
|
MAX_SEQ_LEN = 512
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -48,35 +30,14 @@ logger = logging.getLogger(__name__)
|
||||||
class Transformer:
|
class Transformer:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_class,
|
model_name,
|
||||||
model_name="bert-base-cased",
|
model,
|
||||||
num_labels=2,
|
cache_dir,
|
||||||
cache_dir=".",
|
|
||||||
load_model_from_dir=None,
|
|
||||||
):
|
):
|
||||||
|
|
||||||
if model_name not in self.list_supported_models():
|
|
||||||
raise ValueError(
|
|
||||||
"Model name {0} is not supported by {1}. "
|
|
||||||
"Call '{1}.list_supported_models()' to get all supported model "
|
|
||||||
"names.".format(model_name, self.__class__.__name__)
|
|
||||||
)
|
|
||||||
self._model_name = model_name
|
self._model_name = model_name
|
||||||
self._model_type = model_name.split("-")[0]
|
self._model_type = model_name.split("-")[0]
|
||||||
|
self.model = model
|
||||||
self.cache_dir = cache_dir
|
self.cache_dir = cache_dir
|
||||||
self.load_model_from_dir = load_model_from_dir
|
|
||||||
if load_model_from_dir is None:
|
|
||||||
self.model = model_class[model_name].from_pretrained(
|
|
||||||
model_name,
|
|
||||||
cache_dir=cache_dir,
|
|
||||||
num_labels=num_labels,
|
|
||||||
output_loading_info=False,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.info("Loading cached model from {}".format(load_model_from_dir))
|
|
||||||
self.model = model_class[model_name].from_pretrained(
|
|
||||||
load_model_from_dir, num_labels=num_labels, output_loading_info=False
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def model_name(self):
|
def model_name(self):
|
||||||
|
@ -241,7 +202,8 @@ class Transformer:
|
||||||
if isinstance(outputs, tuple):
|
if isinstance(outputs, tuple):
|
||||||
loss = outputs[0]
|
loss = outputs[0]
|
||||||
else:
|
else:
|
||||||
# Accomondate models based on older versions of Transformers, e.g. UniLM
|
# Accomondate models based on older versions of Transformers,
|
||||||
|
# e.g. UniLM
|
||||||
loss = outputs
|
loss = outputs
|
||||||
|
|
||||||
if num_gpus > 1:
|
if num_gpus > 1:
|
||||||
|
@ -317,7 +279,7 @@ class Transformer:
|
||||||
saved_model_path = os.path.join(
|
saved_model_path = os.path.join(
|
||||||
self.cache_dir, f"{self.model_name}_step_{global_step}.pt"
|
self.cache_dir, f"{self.model_name}_step_{global_step}.pt"
|
||||||
)
|
)
|
||||||
self.save_model(global_step, saved_model_path)
|
self.save_model(saved_model_path)
|
||||||
if validation_function:
|
if validation_function:
|
||||||
validation_log = validation_function(self)
|
validation_log = validation_function(self)
|
||||||
logger.info(validation_log)
|
logger.info(validation_log)
|
||||||
|
|
|
@ -12,14 +12,20 @@ from multiprocessing import Pool, cpu_count
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import (
|
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
||||||
DataLoader,
|
|
||||||
SequentialSampler,
|
|
||||||
RandomSampler,
|
|
||||||
)
|
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
from transformers import BertModel, DistilBertModel
|
from transformers import AutoTokenizer, BertModel, DistilBertModel
|
||||||
|
|
||||||
|
from utils_nlp.common.pytorch_utils import (
|
||||||
|
compute_training_steps,
|
||||||
|
get_device,
|
||||||
|
move_model_to_device,
|
||||||
|
parallelize_model,
|
||||||
|
)
|
||||||
|
from utils_nlp.dataset.sentence_selection import combination_selection, greedy_selection
|
||||||
|
from utils_nlp.models.transformers.abstractive_summarization_bertsum import (
|
||||||
|
fit_to_block_size,
|
||||||
|
)
|
||||||
|
|
||||||
from utils_nlp.models.transformers.bertsum import model_builder
|
from utils_nlp.models.transformers.bertsum import model_builder
|
||||||
from utils_nlp.models.transformers.bertsum.data_loader import (
|
from utils_nlp.models.transformers.bertsum.data_loader import (
|
||||||
|
@ -32,17 +38,7 @@ from utils_nlp.models.transformers.bertsum.dataset import (
|
||||||
ExtSumProcessedIterableDataset,
|
ExtSumProcessedIterableDataset,
|
||||||
)
|
)
|
||||||
from utils_nlp.models.transformers.bertsum.model_builder import BertSumExt
|
from utils_nlp.models.transformers.bertsum.model_builder import BertSumExt
|
||||||
from utils_nlp.common.pytorch_utils import (
|
from utils_nlp.models.transformers.common import Transformer
|
||||||
compute_training_steps,
|
|
||||||
get_device,
|
|
||||||
move_model_to_device,
|
|
||||||
parallelize_model,
|
|
||||||
)
|
|
||||||
from utils_nlp.dataset.sentence_selection import combination_selection, greedy_selection
|
|
||||||
from utils_nlp.models.transformers.common import TOKENIZER_CLASS, Transformer
|
|
||||||
from utils_nlp.models.transformers.abstractive_summarization_bertsum import (
|
|
||||||
fit_to_block_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
MODEL_CLASS = {
|
MODEL_CLASS = {
|
||||||
"bert-base-uncased": BertModel,
|
"bert-base-uncased": BertModel,
|
||||||
|
@ -302,7 +298,7 @@ def parallel_preprocess(input_data, preprocess, num_pool=-1):
|
||||||
p = Pool(num_pool)
|
p = Pool(num_pool)
|
||||||
|
|
||||||
results = p.map(
|
results = p.map(
|
||||||
preprocess, input_data, chunksize=min(1, int(len(input_data) / num_pool)),
|
preprocess, input_data, chunksize=min(1, int(len(input_data) / num_pool))
|
||||||
)
|
)
|
||||||
p.close()
|
p.close()
|
||||||
p.join()
|
p.join()
|
||||||
|
@ -347,8 +343,11 @@ class ExtSumProcessor:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.tokenizer = TOKENIZER_CLASS[self.model_name].from_pretrained(
|
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||||
self.model_name, do_lower_case=to_lower, cache_dir=cache_dir
|
model_name,
|
||||||
|
do_lower_case=to_lower,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
output_loading_info=False,
|
||||||
)
|
)
|
||||||
self.sep_vid = self.tokenizer.vocab["[SEP]"]
|
self.sep_vid = self.tokenizer.vocab["[SEP]"]
|
||||||
self.cls_vid = self.tokenizer.vocab["[CLS]"]
|
self.cls_vid = self.tokenizer.vocab["[CLS]"]
|
||||||
|
@ -361,7 +360,7 @@ class ExtSumProcessor:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def list_supported_models():
|
def list_supported_models():
|
||||||
return list(TOKENIZER_CLASS.keys())
|
return list(MODEL_CLASS)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def model_name(self):
|
def model_name(self):
|
||||||
|
@ -389,7 +388,7 @@ class ExtSumProcessor:
|
||||||
text. If train_model is True, it also contains the labels and target
|
text. If train_model is True, it also contains the labels and target
|
||||||
text.
|
text.
|
||||||
device (torch.device): A PyTorch device.
|
device (torch.device): A PyTorch device.
|
||||||
model_name (bool, optional): Model name used to format the inputs.
|
model_name (bool): Model name used to format the inputs.
|
||||||
train_mode (bool, optional): Training mode flag.
|
train_mode (bool, optional): Training mode flag.
|
||||||
Defaults to True.
|
Defaults to True.
|
||||||
|
|
||||||
|
@ -500,7 +499,6 @@ class ExtSumProcessor:
|
||||||
|
|
||||||
if len(src) == 0:
|
if len(src) == 0:
|
||||||
raise ValueError("source doesn't have any sentences")
|
raise ValueError("source doesn't have any sentences")
|
||||||
return None
|
|
||||||
|
|
||||||
original_src_txt = [" ".join(s) for s in src]
|
original_src_txt = [" ".join(s) for s in src]
|
||||||
# no filtering for prediction
|
# no filtering for prediction
|
||||||
|
@ -588,12 +586,11 @@ class ExtractiveSummarizer(Transformer):
|
||||||
Defaults to ".".
|
Defaults to ".".
|
||||||
"""
|
"""
|
||||||
|
|
||||||
super().__init__(
|
model = MODEL_CLASS[model_name].from_pretrained(
|
||||||
model_class=MODEL_CLASS,
|
model_name, cache_dir=cache_dir, num_labels=0, output_loading_info=False
|
||||||
model_name=model_name,
|
|
||||||
num_labels=0,
|
|
||||||
cache_dir=cache_dir,
|
|
||||||
)
|
)
|
||||||
|
super().__init__(model_name=model_name, model=model, cache_dir=cache_dir)
|
||||||
|
|
||||||
if model_name not in self.list_supported_models():
|
if model_name not in self.list_supported_models():
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Model name {} is not supported by ExtractiveSummarizer. "
|
"Model name {} is not supported by ExtractiveSummarizer. "
|
||||||
|
@ -621,7 +618,7 @@ class ExtractiveSummarizer(Transformer):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def list_supported_models():
|
def list_supported_models():
|
||||||
return list(MODEL_CLASS.keys())
|
return list(MODEL_CLASS)
|
||||||
|
|
||||||
def fit(
|
def fit(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -7,32 +7,21 @@ from collections import Iterable
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import TensorDataset
|
from torch.utils.data import TensorDataset
|
||||||
from transformers.modeling_bert import (
|
from transformers import (
|
||||||
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||||
BertForTokenClassification,
|
AutoConfig,
|
||||||
)
|
AutoModelForTokenClassification,
|
||||||
from transformers.modeling_distilbert import (
|
AutoTokenizer,
|
||||||
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
|
||||||
DistilBertForTokenClassification,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
from utils_nlp.common.pytorch_utils import compute_training_steps
|
from utils_nlp.common.pytorch_utils import compute_training_steps
|
||||||
from utils_nlp.models.transformers.common import (
|
from utils_nlp.models.transformers.common import MAX_SEQ_LEN, Transformer
|
||||||
MAX_SEQ_LEN,
|
|
||||||
TOKENIZER_CLASS,
|
|
||||||
Transformer,
|
|
||||||
)
|
|
||||||
|
|
||||||
TC_MODEL_CLASS = {}
|
supported_models = [
|
||||||
TC_MODEL_CLASS.update(
|
list(x.pretrained_config_archive_map)
|
||||||
{k: BertForTokenClassification for k in BERT_PRETRAINED_MODEL_ARCHIVE_MAP}
|
for x in MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
|
||||||
)
|
]
|
||||||
TC_MODEL_CLASS.update(
|
supported_models = sorted([x for y in supported_models for x in y])
|
||||||
{
|
|
||||||
k: DistilBertForTokenClassification
|
|
||||||
for k in DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TokenClassificationProcessor:
|
class TokenClassificationProcessor:
|
||||||
|
@ -52,7 +41,7 @@ class TokenClassificationProcessor:
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.to_lower = to_lower
|
self.to_lower = to_lower
|
||||||
self.cache_dir = cache_dir
|
self.cache_dir = cache_dir
|
||||||
self.tokenizer = TOKENIZER_CLASS[model_name].from_pretrained(
|
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||||
model_name,
|
model_name,
|
||||||
do_lower_case=to_lower,
|
do_lower_case=to_lower,
|
||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
|
@ -68,7 +57,7 @@ class TokenClassificationProcessor:
|
||||||
batch (tuple): A tuple containing input ids, attention mask,
|
batch (tuple): A tuple containing input ids, attention mask,
|
||||||
segment ids, and labels tensors.
|
segment ids, and labels tensors.
|
||||||
device (torch.device): A PyTorch device.
|
device (torch.device): A PyTorch device.
|
||||||
model_name (bool, optional): Model name used to format the inputs.
|
model_name (bool): Model name used to format the inputs.
|
||||||
train_mode (bool, optional): Training mode flag.
|
train_mode (bool, optional): Training mode flag.
|
||||||
Defaults to True.
|
Defaults to True.
|
||||||
|
|
||||||
|
@ -77,7 +66,7 @@ class TokenClassificationProcessor:
|
||||||
Labels are only returned when train_mode is True.
|
Labels are only returned when train_mode is True.
|
||||||
"""
|
"""
|
||||||
batch = tuple(t.to(device) for t in batch)
|
batch = tuple(t.to(device) for t in batch)
|
||||||
if model_name.split("-")[0] in ["bert", "distilbert"]:
|
if model_name in supported_models:
|
||||||
if train_mode:
|
if train_mode:
|
||||||
inputs = {
|
inputs = {
|
||||||
"input_ids": batch[0],
|
"input_ids": batch[0],
|
||||||
|
@ -110,17 +99,15 @@ class TokenClassificationProcessor:
|
||||||
dict: A dictionary object to map a label (str) to an ID (int).
|
dict: A dictionary object to map a label (str) to an ID (int).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
label_set = set()
|
unique_labels = sorted(set([x for y in label_lists for x in y]))
|
||||||
for labels in label_lists:
|
label_map = {label: i for i, label in enumerate(unique_labels)}
|
||||||
label_set.update(labels)
|
|
||||||
|
|
||||||
label_map = {label: i for i, label in enumerate(label_set)}
|
if trailing_piece_tag not in unique_labels:
|
||||||
|
label_map[trailing_piece_tag] = len(unique_labels)
|
||||||
|
|
||||||
if trailing_piece_tag not in label_set:
|
|
||||||
label_map[trailing_piece_tag] = len(label_set)
|
|
||||||
return label_map
|
return label_map
|
||||||
|
|
||||||
def preprocess_for_bert(
|
def preprocess(
|
||||||
self,
|
self,
|
||||||
text,
|
text,
|
||||||
max_len=MAX_SEQ_LEN,
|
max_len=MAX_SEQ_LEN,
|
||||||
|
@ -187,6 +174,10 @@ class TokenClassificationProcessor:
|
||||||
)
|
)
|
||||||
max_len = MAX_SEQ_LEN
|
max_len = MAX_SEQ_LEN
|
||||||
|
|
||||||
|
logging.warn(
|
||||||
|
"Token lists with length > {} will be truncated".format(MAX_SEQ_LEN)
|
||||||
|
)
|
||||||
|
|
||||||
if not _is_iterable_but_not_string(text):
|
if not _is_iterable_but_not_string(text):
|
||||||
# The input text must be an non-string Iterable
|
# The input text must be an non-string Iterable
|
||||||
raise ValueError("Input text must be an iterable and not a string.")
|
raise ValueError("Input text must be an iterable and not a string.")
|
||||||
|
@ -233,11 +224,6 @@ class TokenClassificationProcessor:
|
||||||
new_tokens.append(sub_word)
|
new_tokens.append(sub_word)
|
||||||
|
|
||||||
if len(new_tokens) > max_len:
|
if len(new_tokens) > max_len:
|
||||||
logging.warn(
|
|
||||||
"Text after tokenization with length {} has been truncated".format(
|
|
||||||
len(new_tokens)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
new_tokens = new_tokens[:max_len]
|
new_tokens = new_tokens[:max_len]
|
||||||
new_labels = new_labels[:max_len]
|
new_labels = new_labels[:max_len]
|
||||||
input_ids = self.tokenizer.convert_tokens_to_ids(new_tokens)
|
input_ids = self.tokenizer.convert_tokens_to_ids(new_tokens)
|
||||||
|
@ -269,16 +255,16 @@ class TokenClassificationProcessor:
|
||||||
|
|
||||||
if label_available:
|
if label_available:
|
||||||
td = TensorDataset(
|
td = TensorDataset(
|
||||||
torch.tensor(input_ids_all, dtype=torch.long),
|
torch.LongTensor(input_ids_all),
|
||||||
torch.tensor(input_mask_all, dtype=torch.long),
|
torch.LongTensor(input_mask_all),
|
||||||
torch.tensor(trailing_token_mask_all, dtype=torch.long),
|
torch.LongTensor(trailing_token_mask_all),
|
||||||
torch.tensor(label_ids_all, dtype=torch.long),
|
torch.LongTensor(label_ids_all),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
td = TensorDataset(
|
td = TensorDataset(
|
||||||
torch.tensor(input_ids_all, dtype=torch.long),
|
torch.LongTensor(input_ids_all),
|
||||||
torch.tensor(input_mask_all, dtype=torch.long),
|
torch.LongTensor(input_mask_all),
|
||||||
torch.tensor(trailing_token_mask_all, dtype=torch.long),
|
torch.LongTensor(trailing_token_mask_all),
|
||||||
)
|
)
|
||||||
return td
|
return td
|
||||||
|
|
||||||
|
@ -297,16 +283,17 @@ class TokenClassifier(Transformer):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model_name="bert-base-cased", num_labels=2, cache_dir="."):
|
def __init__(self, model_name="bert-base-cased", num_labels=2, cache_dir="."):
|
||||||
super().__init__(
|
config = AutoConfig.from_pretrained(
|
||||||
model_class=TC_MODEL_CLASS,
|
model_name, num_labels=num_labels, cache_dir=cache_dir
|
||||||
model_name=model_name,
|
|
||||||
num_labels=num_labels,
|
|
||||||
cache_dir=cache_dir,
|
|
||||||
)
|
)
|
||||||
|
model = AutoModelForTokenClassification.from_pretrained(
|
||||||
|
model_name, cache_dir=cache_dir, config=config, output_loading_info=False
|
||||||
|
)
|
||||||
|
super().__init__(model_name=model_name, model=model, cache_dir=cache_dir)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def list_supported_models():
|
def list_supported_models():
|
||||||
return list(TC_MODEL_CLASS)
|
return supported_models
|
||||||
|
|
||||||
def fit(
|
def fit(
|
||||||
self,
|
self,
|
||||||
|
@ -398,7 +385,9 @@ class TokenClassifier(Transformer):
|
||||||
|
|
||||||
# init scheduler
|
# init scheduler
|
||||||
scheduler = Transformer.get_default_scheduler(
|
scheduler = Transformer.get_default_scheduler(
|
||||||
optimizer=self.optimizer, warmup_steps=warmup_steps, num_training_steps=max_steps
|
optimizer=self.optimizer,
|
||||||
|
warmup_steps=warmup_steps,
|
||||||
|
num_training_steps=max_steps,
|
||||||
)
|
)
|
||||||
|
|
||||||
# fine tune
|
# fine tune
|
||||||
|
|
|
@ -27,6 +27,8 @@ import jsonlines
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import TensorDataset
|
from torch.utils.data import TensorDataset
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from transformers import AutoTokenizer
|
||||||
from transformers.modeling_albert import (
|
from transformers.modeling_albert import (
|
||||||
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
AlbertForQuestionAnswering,
|
AlbertForQuestionAnswering,
|
||||||
|
@ -51,19 +53,21 @@ from utils_nlp.common.pytorch_utils import (
|
||||||
move_model_to_device,
|
move_model_to_device,
|
||||||
parallelize_model,
|
parallelize_model,
|
||||||
)
|
)
|
||||||
from utils_nlp.models.transformers.common import (
|
from utils_nlp.models.transformers.common import MAX_SEQ_LEN, Transformer
|
||||||
MAX_SEQ_LEN,
|
|
||||||
TOKENIZER_CLASS,
|
|
||||||
Transformer,
|
|
||||||
)
|
|
||||||
|
|
||||||
MODEL_CLASS = {}
|
MODEL_CLASS = {}
|
||||||
MODEL_CLASS.update({k: BertForQuestionAnswering for k in BERT_PRETRAINED_MODEL_ARCHIVE_MAP})
|
MODEL_CLASS.update(
|
||||||
MODEL_CLASS.update({k: XLNetForQuestionAnswering for k in XLNET_PRETRAINED_MODEL_ARCHIVE_MAP})
|
{k: BertForQuestionAnswering for k in BERT_PRETRAINED_MODEL_ARCHIVE_MAP}
|
||||||
|
)
|
||||||
|
MODEL_CLASS.update(
|
||||||
|
{k: XLNetForQuestionAnswering for k in XLNET_PRETRAINED_MODEL_ARCHIVE_MAP}
|
||||||
|
)
|
||||||
MODEL_CLASS.update(
|
MODEL_CLASS.update(
|
||||||
{k: DistilBertForQuestionAnswering for k in DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP}
|
{k: DistilBertForQuestionAnswering for k in DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP}
|
||||||
)
|
)
|
||||||
MODEL_CLASS.update({k: AlbertForQuestionAnswering for k in ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP})
|
MODEL_CLASS.update(
|
||||||
|
{k: AlbertForQuestionAnswering for k in ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP}
|
||||||
|
)
|
||||||
|
|
||||||
# cached files during preprocessing
|
# cached files during preprocessing
|
||||||
# these are used in postprocessing to generate the final answer texts
|
# these are used in postprocessing to generate the final answer texts
|
||||||
|
@ -103,11 +107,18 @@ class QAProcessor:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, model_name="bert-base-cased", to_lower=False, custom_tokenize=None, cache_dir=".",
|
self,
|
||||||
|
model_name="bert-base-cased",
|
||||||
|
to_lower=False,
|
||||||
|
custom_tokenize=None,
|
||||||
|
cache_dir=".",
|
||||||
):
|
):
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.tokenizer = TOKENIZER_CLASS[model_name].from_pretrained(
|
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||||
model_name, do_lower_case=to_lower, cache_dir=cache_dir, output_loading_info=False,
|
model_name,
|
||||||
|
do_lower_case=to_lower,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
output_loading_info=False,
|
||||||
)
|
)
|
||||||
self.do_lower_case = to_lower
|
self.do_lower_case = to_lower
|
||||||
self.custom_tokenize = custom_tokenize
|
self.custom_tokenize = custom_tokenize
|
||||||
|
@ -218,7 +229,9 @@ class QAProcessor:
|
||||||
os.makedirs(feature_cache_dir)
|
os.makedirs(feature_cache_dir)
|
||||||
|
|
||||||
if is_training and not qa_dataset.actual_answer_available:
|
if is_training and not qa_dataset.actual_answer_available:
|
||||||
raise Exception("answer_start and answer_text must be provided for training data.")
|
raise Exception(
|
||||||
|
"answer_start and answer_text must be provided for training data."
|
||||||
|
)
|
||||||
|
|
||||||
if is_training:
|
if is_training:
|
||||||
examples_file = os.path.join(feature_cache_dir, CACHED_EXAMPLES_TRAIN_FILE)
|
examples_file = os.path.join(feature_cache_dir, CACHED_EXAMPLES_TRAIN_FILE)
|
||||||
|
@ -245,7 +258,10 @@ class QAProcessor:
|
||||||
qa_examples.append(qa_example_cur)
|
qa_examples.append(qa_example_cur)
|
||||||
|
|
||||||
qa_examples_json.append(
|
qa_examples_json.append(
|
||||||
{"qa_id": qa_example_cur.qa_id, "doc_tokens": qa_example_cur.doc_tokens}
|
{
|
||||||
|
"qa_id": qa_example_cur.qa_id,
|
||||||
|
"doc_tokens": qa_example_cur.doc_tokens,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
features_cur = _create_qa_features(
|
features_cur = _create_qa_features(
|
||||||
|
@ -289,8 +305,12 @@ class QAProcessor:
|
||||||
p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.long)
|
p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.long)
|
||||||
|
|
||||||
if is_training:
|
if is_training:
|
||||||
start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
|
start_positions = torch.tensor(
|
||||||
end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long)
|
[f.start_position for f in features], dtype=torch.long
|
||||||
|
)
|
||||||
|
end_positions = torch.tensor(
|
||||||
|
[f.end_position for f in features], dtype=torch.long
|
||||||
|
)
|
||||||
qa_dataset = TensorDataset(
|
qa_dataset = TensorDataset(
|
||||||
input_ids,
|
input_ids,
|
||||||
input_mask,
|
input_mask,
|
||||||
|
@ -421,7 +441,9 @@ class QAProcessor:
|
||||||
return final_answers, answer_probs, nbest_answers
|
return final_answers, answer_probs, nbest_answers
|
||||||
|
|
||||||
|
|
||||||
QAResult_ = collections.namedtuple("QAResult", ["unique_id", "start_logits", "end_logits"])
|
QAResult_ = collections.namedtuple(
|
||||||
|
"QAResult", ["unique_id", "start_logits", "end_logits"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# create a wrapper class so that we can add docstrings
|
# create a wrapper class so that we can add docstrings
|
||||||
|
@ -503,15 +525,15 @@ class AnswerExtractor(Transformer):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model_name="bert-base-cased", cache_dir=".", load_model_from_dir=None):
|
def __init__(
|
||||||
|
self, model_name="bert-base-cased", cache_dir=".", load_model_from_dir=None
|
||||||
super().__init__(
|
):
|
||||||
model_class=MODEL_CLASS,
|
model = MODEL_CLASS[model_name].from_pretrained(
|
||||||
model_name=model_name,
|
model_name if load_model_from_dir is None else load_model_from_dir,
|
||||||
num_labels=2,
|
|
||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
load_model_from_dir=load_model_from_dir,
|
output_loading_info=False,
|
||||||
)
|
)
|
||||||
|
super().__init__(model_name=model_name, model=model, cache_dir=cache_dir)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def list_supported_models():
|
def list_supported_models():
|
||||||
|
@ -613,7 +635,9 @@ class AnswerExtractor(Transformer):
|
||||||
|
|
||||||
# inin scheduler
|
# inin scheduler
|
||||||
scheduler = Transformer.get_default_scheduler(
|
scheduler = Transformer.get_default_scheduler(
|
||||||
optimizer=self.optimizer, warmup_steps=warmup_steps, num_training_steps=max_steps
|
optimizer=self.optimizer,
|
||||||
|
warmup_steps=warmup_steps,
|
||||||
|
num_training_steps=max_steps,
|
||||||
)
|
)
|
||||||
|
|
||||||
# fine tune
|
# fine tune
|
||||||
|
@ -668,13 +692,19 @@ class AnswerExtractor(Transformer):
|
||||||
|
|
||||||
# parallelize model
|
# parallelize model
|
||||||
self.model = parallelize_model(
|
self.model = parallelize_model(
|
||||||
model=self.model, device=device, num_gpus=num_gpus, gpu_ids=gpu_ids, local_rank=-1,
|
model=self.model,
|
||||||
|
device=device,
|
||||||
|
num_gpus=num_gpus,
|
||||||
|
gpu_ids=gpu_ids,
|
||||||
|
local_rank=-1,
|
||||||
)
|
)
|
||||||
|
|
||||||
all_results = []
|
all_results = []
|
||||||
for batch in tqdm(test_dataloader, desc="Evaluating", disable=not verbose):
|
for batch in tqdm(test_dataloader, desc="Evaluating", disable=not verbose):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
inputs = QAProcessor.get_inputs(batch, device, self.model_name, train_mode=False)
|
inputs = QAProcessor.get_inputs(
|
||||||
|
batch, device, self.model_name, train_mode=False
|
||||||
|
)
|
||||||
outputs = self.model(**inputs)
|
outputs = self.model(**inputs)
|
||||||
unique_id_tensor = batch[5]
|
unique_id_tensor = batch[5]
|
||||||
|
|
||||||
|
@ -865,7 +895,9 @@ def postprocess_bert_answer(
|
||||||
# Sort by the sum of the start and end logits in ascending order,
|
# Sort by the sum of the start and end logits in ascending order,
|
||||||
# so that the first element is the most probable answer
|
# so that the first element is the most probable answer
|
||||||
prelim_predictions = sorted(
|
prelim_predictions = sorted(
|
||||||
prelim_predictions, key=lambda x: (x.start_logit + x.end_logit), reverse=True,
|
prelim_predictions,
|
||||||
|
key=lambda x: (x.start_logit + x.end_logit),
|
||||||
|
reverse=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
seen_predictions = {}
|
seen_predictions = {}
|
||||||
|
@ -890,7 +922,9 @@ def postprocess_bert_answer(
|
||||||
tok_text = " ".join(tok_text.split())
|
tok_text = " ".join(tok_text.split())
|
||||||
orig_text = " ".join(orig_tokens)
|
orig_text = " ".join(orig_tokens)
|
||||||
|
|
||||||
final_text = _get_final_text(tok_text, orig_text, do_lower_case, verbose_logging)
|
final_text = _get_final_text(
|
||||||
|
tok_text, orig_text, do_lower_case, verbose_logging
|
||||||
|
)
|
||||||
if final_text in seen_predictions:
|
if final_text in seen_predictions:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
@ -901,7 +935,9 @@ def postprocess_bert_answer(
|
||||||
|
|
||||||
nbest.append(
|
nbest.append(
|
||||||
_NbestPrediction(
|
_NbestPrediction(
|
||||||
text=final_text, start_logit=pred.start_logit, end_logit=pred.end_logit,
|
text=final_text,
|
||||||
|
start_logit=pred.start_logit,
|
||||||
|
end_logit=pred.end_logit,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
# if we didn't include the empty option in the n-best, include it
|
# if we didn't include the empty option in the n-best, include it
|
||||||
|
@ -916,7 +952,9 @@ def postprocess_bert_answer(
|
||||||
# In very rare edge cases we could only have single null prediction.
|
# In very rare edge cases we could only have single null prediction.
|
||||||
# So we just create a nonce prediction in this case to avoid failure.
|
# So we just create a nonce prediction in this case to avoid failure.
|
||||||
if len(nbest) == 1:
|
if len(nbest) == 1:
|
||||||
nbest.insert(0, _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
|
nbest.insert(
|
||||||
|
0, _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0)
|
||||||
|
)
|
||||||
|
|
||||||
# In very rare edge cases we could have no valid predictions. So we
|
# In very rare edge cases we could have no valid predictions. So we
|
||||||
# just create a nonce prediction in this case to avoid failure.
|
# just create a nonce prediction in this case to avoid failure.
|
||||||
|
@ -956,7 +994,9 @@ def postprocess_bert_answer(
|
||||||
else:
|
else:
|
||||||
# predict "" iff the null score - the score of best non-null > threshold
|
# predict "" iff the null score - the score of best non-null > threshold
|
||||||
score_diff = (
|
score_diff = (
|
||||||
score_null - best_non_null_entry.start_logit - (best_non_null_entry.end_logit)
|
score_null
|
||||||
|
- best_non_null_entry.start_logit
|
||||||
|
- (best_non_null_entry.end_logit)
|
||||||
)
|
)
|
||||||
scores_diff_json[example["qa_id"]] = score_diff
|
scores_diff_json[example["qa_id"]] = score_diff
|
||||||
if score_diff > null_score_diff_threshold:
|
if score_diff > null_score_diff_threshold:
|
||||||
|
@ -1129,7 +1169,9 @@ def postprocess_xlnet_answer(
|
||||||
)
|
)
|
||||||
|
|
||||||
prelim_predictions = sorted(
|
prelim_predictions = sorted(
|
||||||
prelim_predictions, key=lambda x: (x.start_logit + x.end_logit), reverse=True,
|
prelim_predictions,
|
||||||
|
key=lambda x: (x.start_logit + x.end_logit),
|
||||||
|
reverse=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
seen_predictions = {}
|
seen_predictions = {}
|
||||||
|
@ -1172,7 +1214,9 @@ def postprocess_xlnet_answer(
|
||||||
|
|
||||||
nbest.append(
|
nbest.append(
|
||||||
_NbestPrediction(
|
_NbestPrediction(
|
||||||
text=final_text, start_logit=pred.start_logit, end_logit=pred.end_logit,
|
text=final_text,
|
||||||
|
start_logit=pred.start_logit,
|
||||||
|
end_logit=pred.end_logit,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1300,7 +1344,9 @@ def _create_qa_example(qa_input, is_training):
|
||||||
|
|
||||||
if _is_iterable_but_not_string(a_start):
|
if _is_iterable_but_not_string(a_start):
|
||||||
if not _is_iterable_but_not_string(a_text):
|
if not _is_iterable_but_not_string(a_text):
|
||||||
raise Exception("The answer text must be a list when answer start is a list.")
|
raise Exception(
|
||||||
|
"The answer text must be a list when answer start is a list."
|
||||||
|
)
|
||||||
if len(a_start) != 1 and is_training and not impossible:
|
if len(a_start) != 1 and is_training and not impossible:
|
||||||
raise Exception("For training, each question should have exactly 1 answer.")
|
raise Exception("For training, each question should have exactly 1 answer.")
|
||||||
a_start = a_start[0]
|
a_start = a_start[0]
|
||||||
|
@ -1323,7 +1369,9 @@ def _create_qa_example(qa_input, is_training):
|
||||||
cleaned_answer_text = " ".join(whitespace_tokenize(a_text))
|
cleaned_answer_text = " ".join(whitespace_tokenize(a_text))
|
||||||
if actual_text.find(cleaned_answer_text) == -1:
|
if actual_text.find(cleaned_answer_text) == -1:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Could not find answer: '%s' vs. '%s'", actual_text, cleaned_answer_text,
|
"Could not find answer: '%s' vs. '%s'",
|
||||||
|
actual_text,
|
||||||
|
cleaned_answer_text,
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
|
@ -1549,7 +1597,10 @@ def _create_qa_features(
|
||||||
else:
|
else:
|
||||||
tok_end_position = len(all_doc_tokens) - 1
|
tok_end_position = len(all_doc_tokens) - 1
|
||||||
(tok_start_position, tok_end_position) = _improve_answer_span(
|
(tok_start_position, tok_end_position) = _improve_answer_span(
|
||||||
all_doc_tokens, tok_start_position, tok_end_position, example.orig_answer_text,
|
all_doc_tokens,
|
||||||
|
tok_start_position,
|
||||||
|
tok_end_position,
|
||||||
|
example.orig_answer_text,
|
||||||
)
|
)
|
||||||
|
|
||||||
# The -3 accounts for [CLS], [SEP] and [SEP]
|
# The -3 accounts for [CLS], [SEP] and [SEP]
|
||||||
|
@ -1583,7 +1634,8 @@ def _create_qa_features(
|
||||||
|
|
||||||
# p_mask: mask with 1 for token than cannot be in the answer
|
# p_mask: mask with 1 for token than cannot be in the answer
|
||||||
# (0 for token which can be in an answer)
|
# (0 for token which can be in an answer)
|
||||||
# Original TF implem also keep the classification token (set to 0) (not sure why...)
|
# Original TF implem also keep the classification token (set to 0)
|
||||||
|
# (not sure why...)
|
||||||
# TODO: Should we set p_mask = 1 for cls token?
|
# TODO: Should we set p_mask = 1 for cls token?
|
||||||
p_mask = []
|
p_mask = []
|
||||||
|
|
||||||
|
@ -1612,9 +1664,11 @@ def _create_qa_features(
|
||||||
split_token_index = doc_span.start + i
|
split_token_index = doc_span.start + i
|
||||||
token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
|
token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
|
||||||
|
|
||||||
## TODO: maybe this can be improved to compute
|
# TODO: maybe this can be improved to compute
|
||||||
# is_max_context for each token only once.
|
# is_max_context for each token only once.
|
||||||
is_max_context = _check_is_max_context(doc_spans, doc_span_index, split_token_index)
|
is_max_context = _check_is_max_context(
|
||||||
|
doc_spans, doc_span_index, split_token_index
|
||||||
|
)
|
||||||
token_is_max_context[len(tokens)] = is_max_context
|
token_is_max_context[len(tokens)] = is_max_context
|
||||||
tokens.append(all_doc_tokens[split_token_index])
|
tokens.append(all_doc_tokens[split_token_index])
|
||||||
if model_type == "xlnet":
|
if model_type == "xlnet":
|
||||||
|
@ -1720,10 +1774,13 @@ def _create_qa_features(
|
||||||
# -------------------------------------------------------------------------------------------------
|
# -------------------------------------------------------------------------------------------------
|
||||||
# Post processing helper functions
|
# Post processing helper functions
|
||||||
_PrelimPrediction = collections.namedtuple(
|
_PrelimPrediction = collections.namedtuple(
|
||||||
"PrelimPrediction", ["feature_index", "start_index", "end_index", "start_logit", "end_logit"],
|
"PrelimPrediction",
|
||||||
|
["feature_index", "start_index", "end_index", "start_logit", "end_logit"],
|
||||||
)
|
)
|
||||||
|
|
||||||
_NbestPrediction = collections.namedtuple("NbestPrediction", ["text", "start_logit", "end_logit"])
|
_NbestPrediction = collections.namedtuple(
|
||||||
|
"NbestPrediction", ["text", "start_logit", "end_logit"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
|
def _get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
|
||||||
|
@ -1786,7 +1843,9 @@ def _get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
|
||||||
if len(orig_ns_text) != len(tok_ns_text):
|
if len(orig_ns_text) != len(tok_ns_text):
|
||||||
if verbose_logging:
|
if verbose_logging:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Length not equal after stripping spaces: '%s' vs '%s'", orig_ns_text, tok_ns_text,
|
"Length not equal after stripping spaces: '%s' vs '%s'",
|
||||||
|
orig_ns_text,
|
||||||
|
tok_ns_text,
|
||||||
)
|
)
|
||||||
return orig_text
|
return orig_text
|
||||||
|
|
||||||
|
|
|
@ -2,54 +2,22 @@
|
||||||
# Licensed under the MIT License.
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from transformers.modeling_albert import (
|
from transformers import (
|
||||||
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||||
AlbertForSequenceClassification,
|
AutoConfig,
|
||||||
)
|
AutoModelForSequenceClassification,
|
||||||
from transformers.modeling_bert import (
|
AutoTokenizer,
|
||||||
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
|
||||||
BertForSequenceClassification,
|
|
||||||
)
|
|
||||||
from transformers.modeling_distilbert import (
|
|
||||||
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
|
||||||
DistilBertForSequenceClassification,
|
|
||||||
)
|
|
||||||
from transformers.modeling_roberta import (
|
|
||||||
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
|
||||||
RobertaForSequenceClassification,
|
|
||||||
)
|
|
||||||
from transformers.modeling_xlnet import (
|
|
||||||
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
|
|
||||||
XLNetForSequenceClassification,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
from utils_nlp.common.pytorch_utils import compute_training_steps
|
from utils_nlp.common.pytorch_utils import compute_training_steps
|
||||||
from utils_nlp.models.transformers.common import (
|
from utils_nlp.models.transformers.common import MAX_SEQ_LEN, Transformer
|
||||||
MAX_SEQ_LEN,
|
|
||||||
TOKENIZER_CLASS,
|
|
||||||
Transformer,
|
|
||||||
)
|
|
||||||
from utils_nlp.models.transformers.datasets import SCDataSet, SPCDataSet
|
from utils_nlp.models.transformers.datasets import SCDataSet, SPCDataSet
|
||||||
|
|
||||||
MODEL_CLASS = {}
|
supported_models = [
|
||||||
MODEL_CLASS.update(
|
list(x.pretrained_config_archive_map)
|
||||||
{k: BertForSequenceClassification for k in BERT_PRETRAINED_MODEL_ARCHIVE_MAP}
|
for x in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
|
||||||
)
|
]
|
||||||
MODEL_CLASS.update(
|
supported_models = sorted([x for y in supported_models for x in y])
|
||||||
{k: RobertaForSequenceClassification for k in ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP}
|
|
||||||
)
|
|
||||||
MODEL_CLASS.update(
|
|
||||||
{k: XLNetForSequenceClassification for k in XLNET_PRETRAINED_MODEL_ARCHIVE_MAP}
|
|
||||||
)
|
|
||||||
MODEL_CLASS.update(
|
|
||||||
{
|
|
||||||
k: DistilBertForSequenceClassification
|
|
||||||
for k in DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
|
||||||
}
|
|
||||||
)
|
|
||||||
MODEL_CLASS.update(
|
|
||||||
{k: AlbertForSequenceClassification for k in ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Processor:
|
class Processor:
|
||||||
|
@ -68,7 +36,10 @@ class Processor:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model_name="bert-base-cased", to_lower=False, cache_dir="."):
|
def __init__(self, model_name="bert-base-cased", to_lower=False, cache_dir="."):
|
||||||
self.tokenizer = TOKENIZER_CLASS[model_name].from_pretrained(
|
self.model_name = model_name
|
||||||
|
self.to_lower = to_lower
|
||||||
|
self.cache_dir = cache_dir
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||||
model_name,
|
model_name,
|
||||||
do_lower_case=to_lower,
|
do_lower_case=to_lower,
|
||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
|
@ -84,7 +55,7 @@ class Processor:
|
||||||
batch (tuple): A tuple containing input ids, attention mask,
|
batch (tuple): A tuple containing input ids, attention mask,
|
||||||
segment ids, and labels tensors.
|
segment ids, and labels tensors.
|
||||||
device (torch.device): A PyTorch device.
|
device (torch.device): A PyTorch device.
|
||||||
model_name (bool, optional): Model name used to format the inputs.
|
model_name (bool): Model name used to format the inputs.
|
||||||
train_mode (bool, optional): Training mode flag.
|
train_mode (bool, optional): Training mode flag.
|
||||||
Defaults to True.
|
Defaults to True.
|
||||||
|
|
||||||
|
@ -93,13 +64,7 @@ class Processor:
|
||||||
Labels are only returned when train_mode is True.
|
Labels are only returned when train_mode is True.
|
||||||
"""
|
"""
|
||||||
batch = tuple(t.to(device) for t in batch)
|
batch = tuple(t.to(device) for t in batch)
|
||||||
if model_name.split("-")[0] in [
|
if model_name in supported_models:
|
||||||
"bert",
|
|
||||||
"xlnet",
|
|
||||||
"roberta",
|
|
||||||
"distilbert",
|
|
||||||
"albert",
|
|
||||||
]:
|
|
||||||
if train_mode:
|
if train_mode:
|
||||||
inputs = {
|
inputs = {
|
||||||
"input_ids": batch[0],
|
"input_ids": batch[0],
|
||||||
|
@ -109,8 +74,8 @@ class Processor:
|
||||||
else:
|
else:
|
||||||
inputs = {"input_ids": batch[0], "attention_mask": batch[1]}
|
inputs = {"input_ids": batch[0], "attention_mask": batch[1]}
|
||||||
|
|
||||||
# distilbert doesn't support segment ids
|
# distilbert, bart don't support segment ids
|
||||||
if model_name.split("-")[0] not in ["distilbert"]:
|
if model_name.split("-")[0] not in ["distilbert", "bart"]:
|
||||||
inputs["token_type_ids"] = batch[2]
|
inputs["token_type_ids"] = batch[2]
|
||||||
|
|
||||||
return inputs
|
return inputs
|
||||||
|
@ -244,16 +209,17 @@ class Processor:
|
||||||
|
|
||||||
class SequenceClassifier(Transformer):
|
class SequenceClassifier(Transformer):
|
||||||
def __init__(self, model_name="bert-base-cased", num_labels=2, cache_dir="."):
|
def __init__(self, model_name="bert-base-cased", num_labels=2, cache_dir="."):
|
||||||
super().__init__(
|
config = AutoConfig.from_pretrained(
|
||||||
model_class=MODEL_CLASS,
|
model_name, num_labels=num_labels, cache_dir=cache_dir
|
||||||
model_name=model_name,
|
|
||||||
num_labels=num_labels,
|
|
||||||
cache_dir=cache_dir,
|
|
||||||
)
|
)
|
||||||
|
model = AutoModelForSequenceClassification.from_pretrained(
|
||||||
|
model_name, cache_dir=cache_dir, config=config, output_loading_info=False
|
||||||
|
)
|
||||||
|
super().__init__(model_name=model_name, model=model, cache_dir=cache_dir)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def list_supported_models():
|
def list_supported_models():
|
||||||
return list(MODEL_CLASS)
|
return supported_models
|
||||||
|
|
||||||
def fit(
|
def fit(
|
||||||
self,
|
self,
|
||||||
|
@ -345,7 +311,9 @@ class SequenceClassifier(Transformer):
|
||||||
|
|
||||||
# init scheduler
|
# init scheduler
|
||||||
scheduler = Transformer.get_default_scheduler(
|
scheduler = Transformer.get_default_scheduler(
|
||||||
optimizer=self.optimizer, warmup_steps=warmup_steps, num_training_steps=max_steps
|
optimizer=self.optimizer,
|
||||||
|
warmup_steps=warmup_steps,
|
||||||
|
num_training_steps=max_steps,
|
||||||
)
|
)
|
||||||
|
|
||||||
# fine tune
|
# fine tune
|
||||||
|
|
Загрузка…
Ссылка в новой задаче