Merge branch 'hlu/update_entailment_notebook_to_use_transformers' of https://github.com/Microsoft/NLP into hlu/update_entailment_notebook_to_use_transformers

Commit 6301686f52
@@ -22,24 +22,16 @@
    "source": [
    "# Before You Start\n",
    "\n",
-   "The running time shown in this notebook is running bert-large-cased on a Standard_NC24rs_v3 Azure Deep Learning Virtual Machine with 4 NVIDIA Tesla V100 GPUs. \n",
+   "It takes about 4 hours to fine-tune the `bert-large-cased` model on a Standard_NC24rs_v3 Azure Data Science Virtual Machine with 4 NVIDIA Tesla V100 GPUs. \n",
    "> **Tip:** If you want to run through the notebook quickly, you can set the **`QUICK_RUN`** flag in the cell below to **`True`** to run the notebook on a small subset of the data and a smaller number of epochs. \n",
    "\n",
-   "The table below provides some reference running time on different machine configurations. \n",
-   "\n",
-   "|QUICK_RUN|Machine Configuration|Running time|\n",
-   "|:---------|:----------------------|:------------|\n",
-   "|True|4 **CPU**s, 14GB memory| ~ 15 minutes|\n",
-   "|True|1 NVIDIA Tesla K80 GPU, 12GB GPU memory| ~ 5 minutes|\n",
-   "|False|1 NVIDIA Tesla K80 GPU, 12GB GPU memory| ~ 10.5 hours|\n",
-   "|False|4 NVIDIA Tesla V100 GPUs, 64GB GPU memory| ~ 2.5 hours|\n",
    "\n",
    "If you run into a CUDA out-of-memory error, try reducing `BATCH_SIZE` and `MAX_SEQ_LENGTH`, but note that model performance will be compromised. "
   ]
  },
  {
   "cell_type": "code",
-  "execution_count": 1,
+  "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
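The `QUICK_RUN` switch mentioned above is just a flag in the notebook's parameter cell. A minimal sketch of how it is typically wired up (the variable names here are illustrative; the question-answering notebook in this same commit uses `TRAIN_DATA_USED_PERCENT`):

    # Hedged sketch of a QUICK_RUN parameter cell; names are assumptions.
    QUICK_RUN = False

    TRAIN_DATA_USED_PERCENT = 1.0
    NUM_EPOCHS = 2
    if QUICK_RUN:
        TRAIN_DATA_USED_PERCENT = 0.001  # tiny data subset for a smoke test
        NUM_EPOCHS = 1                   # fewer epochs: minutes instead of hours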
@@ -56,31 +48,24 @@
    "To classify a sentence pair, we concatenate the tokens in both sentences and separate the sentences by the special [SEP] token. A [CLS] token is prepended to the token list and used as the aggregate sequence representation for the classification task. The NLI task essentially becomes a sequence classification task. For example, the figure below shows how [BERT](https://arxiv.org/abs/1810.04805) classifies sentence pairs. \n",
    "<img src=\"https://nlpbp.blob.core.windows.net/images/bert_two_sentence.PNG\">\n",
    "\n",
-   "We compare the training time and performance of three models: bert-base-cased, bert-large-cased, and xlnet-large-cased. The model used can be set in the **Configurations** section. "
+   "We compare the training time and performance of bert-large-cased and xlnet-large-cased. The model used can be set in the **Configurations** section. "
   ]
  },
  {
   "cell_type": "code",
-  "execution_count": 2,
+  "execution_count": null,
   "metadata": {
    "scrolled": false
   },
-  "outputs": [
-   {
-    "name": "stderr",
-    "output_type": "stream",
-    "text": [
-     "I1110 19:13:59.935610 140117887072000 file_utils.py:39] PyTorch version 1.2.0 available.\n",
-     "I1110 19:13:59.978967 140117887072000 modeling_xlnet.py:194] Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .\n"
-    ]
-   }
-  ],
+  "outputs": [],
   "source": [
    "import sys, os\n",
    "nlp_path = os.path.abspath('../../')\n",
    "if nlp_path not in sys.path:\n",
    "    sys.path.insert(0, nlp_path)\n",
-   "    \n",
+   "\n",
+   "import scrapbook as sb\n",
+   "\n",
    "from tempfile import TemporaryDirectory\n",
    "\n",
    "import numpy as np\n",
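To make the [CLS]/[SEP] packing described in this cell concrete, here is a hedged sketch using the Hugging Face transformers tokenizer directly (the notebook itself goes through the repo's `Processor` class; the premise/hypothesis pair below is made up):

    # Sketch: encode an NLI sentence pair the way BERT expects.
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-large-cased")
    encoded = tokenizer.encode_plus(
        "A soccer game with multiple males playing.",  # premise
        "Some men are playing a sport.",               # hypothesis
    )
    # input_ids: [CLS] premise [SEP] hypothesis [SEP]
    print(tokenizer.convert_ids_to_tokens(encoded["input_ids"]))
    # token_type_ids: 0 for premise tokens, 1 for hypothesis tokens
    print(encoded["token_type_ids"])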
@@ -104,39 +89,9 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 3,
+  "execution_count": null,
   "metadata": {},
-  "outputs": [
-   {
-    "data": {
-     "text/plain": [
-      "['bert-base-uncased',\n",
-      " 'bert-large-uncased',\n",
-      " 'bert-base-cased',\n",
-      " 'bert-large-cased',\n",
-      " 'bert-base-multilingual-uncased',\n",
-      " 'bert-base-multilingual-cased',\n",
-      " 'bert-base-chinese',\n",
-      " 'bert-base-german-cased',\n",
-      " 'bert-large-uncased-whole-word-masking',\n",
-      " 'bert-large-cased-whole-word-masking',\n",
-      " 'bert-large-uncased-whole-word-masking-finetuned-squad',\n",
-      " 'bert-large-cased-whole-word-masking-finetuned-squad',\n",
-      " 'bert-base-cased-finetuned-mrpc',\n",
-      " 'roberta-base',\n",
-      " 'roberta-large',\n",
-      " 'roberta-large-mnli',\n",
-      " 'xlnet-base-cased',\n",
-      " 'xlnet-large-cased',\n",
-      " 'distilbert-base-uncased',\n",
-      " 'distilbert-base-uncased-distilled-squad']"
-     ]
-    },
-    "execution_count": 3,
-    "metadata": {},
-    "output_type": "execute_result"
-   }
-  ],
+  "outputs": [],
   "source": [
    "SequenceClassifier.list_supported_models()"
   ]
@@ -150,7 +105,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 4,
+  "execution_count": null,
   "metadata": {
    "tags": [
     "parameters"
@@ -194,8 +149,7 @@
    "LABEL_COL = \"gold_label\"\n",
    "LABEL_COL_NUM = \"gold_label_num\"\n",
    "\n",
-   "CACHE_DIR = TemporaryDirectory().name\n",
-   "CACHE_DIR = \"./temp\""
+   "CACHE_DIR = TemporaryDirectory().name"
   ]
  },
  {
@@ -209,7 +163,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 5,
+  "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
@@ -220,7 +174,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 6,
+  "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
@@ -230,33 +184,9 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 7,
+  "execution_count": null,
   "metadata": {},
-  "outputs": [
-   {
-    "name": "stdout",
-    "output_type": "stream",
-    "text": [
-     "Training dataset size: 392702\n",
-     "Development (matched) dataset size: 9815\n",
-     "Development (mismatched) dataset size: 9832\n",
-     "\n",
-     "  gold_label                                          sentence1 \\\n",
-     "0    neutral  Conceptually cream skimming has two basic dime...   \n",
-     "1 entailment  you know during the season and i guess at at y...   \n",
-     "2 entailment  One of our number will carry out your instruct...   \n",
-     "3 entailment  How do you know? All this is their information...   \n",
-     "4    neutral  yeah i tell you what though if you go price so...   \n",
-     "\n",
-     "                                           sentence2  \n",
-     "0  Product and geography are what make cream skim...  \n",
-     "1  You lose the things to the following level if ...  \n",
-     "2  A member of my team will execute your orders w...  \n",
-     "3                  This information belongs to them.  \n",
-     "4           The tennis shoes have a range of prices.  \n"
-    ]
-   }
-  ],
+  "outputs": [],
   "source": [
    "print(\"Training dataset size: {}\".format(train_df.shape[0]))\n",
    "print(\"Development (matched) dataset size: {}\".format(dev_df_matched.shape[0]))\n",
@@ -267,7 +197,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 9,
+  "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
@@ -278,7 +208,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 10,
+  "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
@@ -293,25 +223,18 @@
   "metadata": {},
   "source": [
    "## Tokenize and Preprocess\n",
-   "Before training, we tokenize the sentence texts and convert them to lists of tokens. The following steps instantiate a BERT tokenizer given the language, and tokenize the text of the training and testing sets."
+   "Before training, we tokenize and preprocess the sentence texts to convert them into the format required by transformer model classes. \n",
+   "The `create_dataloader_from_df` method of the `Processor` class performs the following preprocessing steps and returns a Pytorch `DataLoader`:\n",
+   "* Tokenize input texts using the tokenizer of the pre-trained model specified by `model_name`. \n",
+   "* Convert the tokens into token indices corresponding to the tokenizer's vocabulary.\n",
+   "* Pad or truncate the token lists to the specified max length."
   ]
  },
  {
   "cell_type": "code",
-  "execution_count": 11,
+  "execution_count": null,
   "metadata": {},
-  "outputs": [
-   {
-    "name": "stderr",
-    "output_type": "stream",
-    "text": [
-     "I1110 19:14:11.376676 140117887072000 tokenization_utils.py:373] loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt from cache at ./temp/cee054f6aafe5e2cf816d2228704e326446785f940f5451a5b26033516a4ac3d.e13dbb970cb325137104fb2e5f36fe865f27746c6b526f6352861b1980eb80b1\n",
-     "100%|██████████| 392702/392702 [03:48<00:00, 1715.17it/s]\n",
-     "100%|██████████| 9815/9815 [00:05<00:00, 1797.48it/s]\n",
-     "100%|██████████| 9832/9832 [00:05<00:00, 1709.69it/s]\n"
-    ]
-   }
-  ],
+  "outputs": [],
   "source": [
    "processor = Processor(model_name=MODEL_NAME, cache_dir=CACHE_DIR, to_lower=TO_LOWER)\n",
    "train_dataloader = processor.create_dataloader_from_df(\n",
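Because `create_dataloader_from_df` returns a standard PyTorch `DataLoader`, the preprocessed data can be inspected directly. A hedged sketch (the per-batch tensor layout is an assumption based on the steps listed above, not something this diff documents):

    # Peek at one preprocessed batch; the exact field order is assumed.
    batch = next(iter(train_dataloader))
    print([t.shape for t in batch])  # e.g. token ids, masks, segment ids, labels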
@@ -341,21 +264,6 @@
    ")"
   ]
  },
- {
-  "cell_type": "markdown",
-  "metadata": {},
-  "source": [
-   "In addition, we perform the following preprocessing steps in the cell below:\n",
-   "\n",
-   "* Convert the tokens into token indices corresponding to the BERT tokenizer's vocabulary\n",
-   "* Add the special tokens [CLS] and [SEP] to mark the beginning and end of a sentence\n",
-   "* Pad or truncate the token lists to the specified max length\n",
-   "* Return mask lists that indicate paddings' positions\n",
-   "* Return token type id lists that indicate which sentence the tokens belong to\n",
-   "\n",
-   "*See the original [implementation](https://github.com/google-research/bert/blob/master/run_classifier.py) for more information on BERT's input format.*"
-  ]
- },
  {
   "cell_type": "markdown",
   "metadata": {},
|
@ -416,31 +324,9 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 14,
|
"execution_count": null,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [],
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Evaluating: 100%|██████████| 614/614 [04:53<00:00, 2.12it/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Prediction time : 0.082 hrs\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
"source": [
|
||||||
"with Timer() as t:\n",
|
"with Timer() as t:\n",
|
||||||
" predictions_matched = classifier.predict(dev_dataloader_matched)\n",
|
" predictions_matched = classifier.predict(dev_dataloader_matched)\n",
|
||||||
|
@@ -449,31 +335,9 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 15,
+  "execution_count": null,
   "metadata": {},
-  "outputs": [
-   {
-    "name": "stderr",
-    "output_type": "stream",
-    "text": [
-     "Evaluating: 100%|██████████| 615/615 [04:53<00:00,  2.12it/s]"
-    ]
-   },
-   {
-    "name": "stdout",
-    "output_type": "stream",
-    "text": [
-     "Prediction time : 0.082 hrs\n"
-    ]
-   },
-   {
-    "name": "stderr",
-    "output_type": "stream",
-    "text": [
-     "\n"
-    ]
-   }
-  ],
+  "outputs": [],
   "source": [
    "with Timer() as t:\n",
    "    predictions_mismatched = classifier.predict(dev_dataloader_mismatched)\n",
@@ -489,26 +353,9 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 16,
+  "execution_count": null,
   "metadata": {},
-  "outputs": [
-   {
-    "name": "stdout",
-    "output_type": "stream",
-    "text": [
-     "               precision    recall  f1-score   support\n",
-     "\n",
-     "contradiction      0.872     0.894     0.883      3213\n",
-     "   entailment      0.913     0.862     0.887      3479\n",
-     "      neutral      0.813     0.842     0.828      3123\n",
-     "\n",
-     "    micro avg      0.866     0.866     0.866      9815\n",
-     "    macro avg      0.866     0.866     0.866      9815\n",
-     " weighted avg      0.868     0.866     0.867      9815\n",
-     "\n"
-    ]
-   }
-  ],
+  "outputs": [],
   "source": [
    "predictions_matched = label_encoder.inverse_transform(predictions_matched)\n",
    "print(classification_report(dev_df_matched[LABEL_COL], predictions_matched, digits=3))"
@@ -516,28 +363,11 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 17,
+  "execution_count": null,
   "metadata": {
    "scrolled": true
   },
-  "outputs": [
-   {
-    "name": "stdout",
-    "output_type": "stream",
-    "text": [
-     "               precision    recall  f1-score   support\n",
-     "\n",
-     "contradiction      0.891     0.888     0.889      3240\n",
-     "   entailment      0.899     0.862     0.880      3463\n",
-     "      neutral      0.810     0.850     0.830      3129\n",
-     "\n",
-     "    micro avg      0.867     0.867     0.867      9832\n",
-     "    macro avg      0.867     0.867     0.866      9832\n",
-     " weighted avg      0.868     0.867     0.867      9832\n",
-     "\n"
-    ]
-   }
-  ],
+  "outputs": [],
   "source": [
    "predictions_mismatched = label_encoder.inverse_transform(predictions_mismatched)\n",
    "print(classification_report(dev_df_mismatched[LABEL_COL], predictions_mismatched, digits=3))"
@@ -559,6 +389,22 @@
    "|xlnet-large-cased|5.15 hrs|0.11 hrs|0.887|0.890|\n",
    "|bert-large-cased|4.01 hrs|0.08 hrs|0.867|0.867|"
   ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "metadata": {},
+  "outputs": [],
+  "source": [
+   "result_matched_dict = classification_report(dev_df_matched[LABEL_COL], predictions_matched, digits=3, output_dict=True)\n",
+   "result_mismatched_dict = classification_report(dev_df_mismatched[LABEL_COL], predictions_mismatched, digits=3, output_dict=True)\n",
+   "sb.glue(\"matched_precision\", result_matched_dict[\"weighted avg\"][\"precision\"])\n",
+   "sb.glue(\"matched_recall\", result_matched_dict[\"weighted avg\"][\"recall\"])\n",
+   "sb.glue(\"matched_f1\", result_matched_dict[\"weighted avg\"][\"f1-score\"])\n",
+   "sb.glue(\"mismatched_precision\", result_mismatched_dict[\"weighted avg\"][\"precision\"])\n",
+   "sb.glue(\"mismatched_recall\", result_mismatched_dict[\"weighted avg\"][\"recall\"])\n",
+   "sb.glue(\"mismatched_f1\", result_mismatched_dict[\"weighted avg\"][\"f1-score\"])"
+  ]
  }
 ],
 "metadata": {
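The `sb.glue` calls added above persist the weighted-average metrics inside the executed notebook, which is what lets automated tests check them later. A hedged sketch of reading them back (standard scrapbook API; the output notebook path is illustrative):

    # Sketch: recover glued metrics from an executed copy of the notebook.
    import scrapbook as sb

    nb = sb.read_notebook("entailment_output.ipynb")  # illustrative path
    print(nb.scraps["matched_f1"].data)
    print(nb.scraps["mismatched_f1"].data)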
@@ -63,7 +63,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 1,
+  "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
@@ -78,7 +78,7 @@
    "    sys.path.insert(0, nlp_path)\n",
    "\n",
    "from utils_nlp.dataset.squad import load_pandas_df\n",
-   "from utils_nlp.dataset.pytorch import QADataset\n",
+   "from utils_nlp.models.transformers.datasets import QADataset\n",
    "from utils_nlp.models.transformers.question_answering import (\n",
    "    QAProcessor,\n",
    "    AnswerExtractor\n",
@@ -175,6 +175,7 @@
    "DOC_STRIDE = 128\n",
    "PER_GPU_BATCH_SIZE = 4\n",
    "GRADIENT_ACCUMULATION_STEPS = 1\n",
+   "NUM_GPUS = torch.cuda.device_count()\n",
    "\n",
    "if QUICK_RUN:\n",
    "    TRAIN_DATA_USED_PERCENT = 0.001\n",
@@ -558,7 +559,7 @@
    "* Pad the concatenated token sequence to `max_seq_length` if it's shorter.\n",
    "* Convert the tokens into token indices corresponding to the tokenizer's vocabulary.\n",
    "\n",
-   "`QAProcessor.preprocess` returns a Pytorch TensorDataset. By default, it saves `cached_examples_train/test.jsonl` and `cached_features_train/test.jsonl` to `./cached_qa_features`. These files are required for postprocessing the predicted answer start and end indices into the final answer text. You can change the default file directory by specifying `feature_cache_dir`. "
+   "`QAProcessor.preprocess` returns a Pytorch DataLoader. By default, it saves `cached_examples_train/test.jsonl` and `cached_features_train/test.jsonl` to `./cached_qa_features`. These files are required for postprocessing the predicted answer start and end indices into the final answer text. You can change the default file directory by specifying `feature_cache_dir`. "
   ]
  },
  {
@@ -576,16 +577,20 @@
   ],
   "source": [
    "qa_processor = QAProcessor(model_name=MODEL_NAME, to_lower=DO_LOWER_CASE)\n",
-   "train_features = qa_processor.preprocess(\n",
+   "train_dataloader = qa_processor.preprocess(\n",
    "    train_dataset, \n",
+   "    batch_size=PER_GPU_BATCH_SIZE,\n",
+   "    num_gpus=NUM_GPUS,\n",
    "    is_training=True,\n",
    "    max_question_length=MAX_QUESTION_LENGTH,\n",
    "    max_seq_length=MAX_SEQ_LENGTH,\n",
    "    doc_stride=DOC_STRIDE\n",
    ")\n",
    "\n",
-   "dev_features = qa_processor.preprocess(\n",
+   "dev_dataloader = qa_processor.preprocess(\n",
    "    dev_dataset, \n",
+   "    batch_size=PER_GPU_BATCH_SIZE,\n",
+   "    num_gpus=NUM_GPUS,\n",
    "    is_training=False,\n",
    "    max_question_length=MAX_QUESTION_LENGTH,\n",
    "    max_seq_length=MAX_SEQ_LENGTH,\n",
@@ -616,10 +621,9 @@
   "outputs": [],
   "source": [
    "with Timer() as t:\n",
-   "    qa_extractor.fit(train_dataset=train_features,\n",
+   "    qa_extractor.fit(train_dataloader,\n",
    "                     num_epochs=NUM_EPOCHS,\n",
    "                     learning_rate=LEARNING_RATE,\n",
-   "                     per_gpu_batch_size=PER_GPU_BATCH_SIZE,\n",
    "                     gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,\n",
    "                     seed=RANDOM_SEED,\n",
    "                     cache_model=True)\n",
@@ -648,7 +652,7 @@
   }
  ],
  "source": [
-  "qa_results = qa_extractor.predict(dev_features, per_gpu_batch_size=PER_GPU_BATCH_SIZE)"
+  "qa_results = qa_extractor.predict(dev_dataloader)"
  ]
 },
 {
@@ -824,9 +828,9 @@
 "metadata": {
  "celltoolbar": "Tags",
  "kernelspec": {
-  "display_name": "Python [default]",
+  "display_name": "nlp_gpu",
   "language": "python",
-  "name": "python3"
+  "name": "nlp_gpu"
  },
  "language_info": {
   "codemirror_mode": {
@@ -838,7 +842,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-  "version": "3.5.5"
+  "version": "3.6.8"
  }
 },
 "nbformat": 4,
@@ -3,7 +3,7 @@

 import pytest
 import os
-from utils_nlp.dataset.pytorch import QADataset
+from utils_nlp.models.transformers.datasets import QADataset
 from utils_nlp.models.transformers.question_answering import (
     QAProcessor,
     AnswerExtractor,
@@ -11,6 +11,11 @@ from utils_nlp.models.transformers.question_answering import (
     CACHED_FEATURES_TEST_FILE,
 )

+import torch
+
+NUM_GPUS = max(1, torch.cuda.device_count())
+BATCH_SIZE = 8
+

 @pytest.fixture()
 def qa_test_data(qa_test_df, tmp):
@@ -61,6 +66,8 @@ def qa_test_data(qa_test_df, tmp):
     qa_processor_bert = QAProcessor()
     train_features_bert = qa_processor_bert.preprocess(
         train_dataset,
+        batch_size=BATCH_SIZE,
+        num_gpus=NUM_GPUS,
         is_training=True,
         max_question_length=16,
         max_seq_length=64,
@@ -70,6 +77,8 @@

     test_features_bert = qa_processor_bert.preprocess(
         test_dataset,
+        batch_size=BATCH_SIZE,
+        num_gpus=NUM_GPUS,
         is_training=False,
         max_question_length=16,
         max_seq_length=64,
@@ -80,6 +89,8 @@
     qa_processor_xlnet = QAProcessor(model_name="xlnet-base-cased")
     train_features_xlnet = qa_processor_xlnet.preprocess(
         train_dataset,
+        batch_size=BATCH_SIZE,
+        num_gpus=NUM_GPUS,
         is_training=True,
         max_question_length=16,
         max_seq_length=64,
@@ -89,6 +100,8 @@

     test_features_xlnet = qa_processor_xlnet.preprocess(
         test_dataset,
+        batch_size=BATCH_SIZE,
+        num_gpus=NUM_GPUS,
         is_training=False,
         max_question_length=16,
         max_seq_length=64,
@@ -99,6 +112,8 @@
     qa_processor_distilbert = QAProcessor(model_name="distilbert-base-uncased")
     train_features_distilbert = qa_processor_distilbert.preprocess(
         train_dataset,
+        batch_size=BATCH_SIZE,
+        num_gpus=NUM_GPUS,
         is_training=True,
         max_question_length=16,
         max_seq_length=64,
@@ -108,6 +123,8 @@

     test_features_distilbert = qa_processor_distilbert.preprocess(
         test_dataset,
+        batch_size=BATCH_SIZE,
+        num_gpus=NUM_GPUS,
         is_training=False,
         max_question_length=16,
         max_seq_length=64,
@@ -157,9 +174,7 @@ def test_QAProcessor(qa_test_data, tmp):
 def test_AnswerExtractor(qa_test_data, tmp):
     # test bert
     qa_extractor_bert = AnswerExtractor(cache_dir=tmp)
-    qa_extractor_bert.fit(
-        qa_test_data["train_features_bert"], cache_model=True, per_gpu_batch_size=8
-    )
+    qa_extractor_bert.fit(qa_test_data["train_features_bert"], cache_model=True)

     # test saving fine-tuned model
     model_output_dir = os.path.join(tmp, "fine_tuned")
|
||||||
qa_extractor_from_cache.predict(qa_test_data["test_features_bert"])
|
qa_extractor_from_cache.predict(qa_test_data["test_features_bert"])
|
||||||
|
|
||||||
qa_extractor_xlnet = AnswerExtractor(model_name="xlnet-base-cased", cache_dir=tmp)
|
qa_extractor_xlnet = AnswerExtractor(model_name="xlnet-base-cased", cache_dir=tmp)
|
||||||
qa_extractor_xlnet.fit(
|
qa_extractor_xlnet.fit(qa_test_data["train_features_xlnet"], cache_model=False)
|
||||||
qa_test_data["train_features_xlnet"], cache_model=False, per_gpu_batch_size=8
|
|
||||||
)
|
|
||||||
qa_extractor_xlnet.predict(qa_test_data["test_features_xlnet"])
|
qa_extractor_xlnet.predict(qa_test_data["test_features_xlnet"])
|
||||||
|
|
||||||
qa_extractor_distilbert = AnswerExtractor(model_name="distilbert-base-uncased", cache_dir=tmp)
|
qa_extractor_distilbert = AnswerExtractor(model_name="distilbert-base-uncased", cache_dir=tmp)
|
||||||
qa_extractor_distilbert.fit(
|
qa_extractor_distilbert.fit(qa_test_data["train_features_distilbert"], cache_model=False)
|
||||||
qa_test_data["train_features_distilbert"], cache_model=False, per_gpu_batch_size=8
|
|
||||||
)
|
|
||||||
qa_extractor_distilbert.predict(qa_test_data["test_features_distilbert"])
|
qa_extractor_distilbert.predict(qa_test_data["test_features_distilbert"])
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@@ -26,7 +26,8 @@ import math
 import jsonlines

 import torch
-from torch.utils.data import TensorDataset, SequentialSampler, DataLoader
+from torch.utils.data import TensorDataset, SequentialSampler, DataLoader, RandomSampler
+from torch.utils.data.distributed import DistributedSampler

 from transformers.tokenization_bert import BasicTokenizer, whitespace_tokenize
 from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BertForQuestionAnswering
@@ -40,11 +41,7 @@ from transformers.modeling_distilbert import (
 )

 from utils_nlp.common.pytorch_utils import get_device
-from utils_nlp.models.transformers.common import (
-    MAX_SEQ_LEN,
-    TOKENIZER_CLASS,
-    Transformer,
-)
+from utils_nlp.models.transformers.common import MAX_SEQ_LEN, TOKENIZER_CLASS, Transformer

 MODEL_CLASS = {}
 MODEL_CLASS.update({k: BertForQuestionAnswering for k in BERT_PRETRAINED_MODEL_ARCHIVE_MAP})
@@ -146,6 +143,9 @@ class QAProcessor:
         self,
         qa_dataset,
         is_training,
+        batch_size=32,
+        num_gpus=None,
+        distributed=False,
         max_question_length=64,
         max_seq_length=MAX_SEQ_LEN,
         doc_stride=128,
@@ -243,37 +243,42 @@ class QAProcessor:
         examples_writer.write_all(qa_examples_json)
         features_writer.write_all(features_json)

-        # TODO: maybe generalize the following code
-        input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
-        input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
-        segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
-        cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long)
-        p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)
-
-        if is_training:
-            start_positions = torch.tensor(
-                [f.start_position for f in features], dtype=torch.long
-            )
-            end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long)
-            qa_dataset = TensorDataset(
-                input_ids,
-                input_mask,
-                segment_ids,
-                start_positions,
-                end_positions,
-                cls_index,
-                p_mask,
-            )
-        else:
-            unique_id_all = torch.tensor(unique_id_all, dtype=torch.long)
-            qa_dataset = TensorDataset(
-                input_ids, input_mask, segment_ids, cls_index, p_mask, unique_id_all
-            )
-
         logger.info("QA examples are saved to {}".format(examples_file))
         logger.info("QA features are saved to {}".format(features_file))

-        return qa_dataset
+        # TODO: maybe generalize the following code
+        input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
+        input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
+        segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
+        cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long)
+        p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)
+
+        if is_training:
+            start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
+            end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long)
+            qa_dataset = TensorDataset(
+                input_ids,
+                input_mask,
+                segment_ids,
+                start_positions,
+                end_positions,
+                cls_index,
+                p_mask,
+            )
+        else:
+            unique_id_all = torch.tensor(unique_id_all, dtype=torch.long)
+            qa_dataset = TensorDataset(
+                input_ids, input_mask, segment_ids, cls_index, p_mask, unique_id_all
+            )
+
+        if num_gpus is not None:
+            batch_size = batch_size * max(1, num_gpus)
+        if distributed:
+            sampler = DistributedSampler(qa_dataset)
+        else:
+            sampler = RandomSampler(qa_dataset) if is_training else SequentialSampler(qa_dataset)
+
+        return DataLoader(qa_dataset, sampler=sampler, batch_size=batch_size)

     def postprocess(
         self,
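One behavioral detail of the rewritten `preprocess` worth noting: `batch_size` is treated as a per-GPU size and scaled by the visible device count, so callers keep `PER_GPU_BATCH_SIZE` fixed while the returned DataLoader yields larger global batches. A small illustration using the notebook's defaults:

    # Illustration of the scaling rule above (4 is the notebook's PER_GPU_BATCH_SIZE).
    batch_size, num_gpus = 4, 4
    effective_batch_size = batch_size * max(1, num_gpus)  # DataLoader batch_size = 16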
@@ -469,9 +474,8 @@ class AnswerExtractor(Transformer):

     def fit(
         self,
-        train_dataset,
+        train_dataloader,
         num_gpus=None,
-        per_gpu_batch_size=8,
         num_epochs=1,
         learning_rate=5e-5,
         max_grad_norm=1.0,
@@ -491,12 +495,10 @@ class AnswerExtractor(Transformer):
         Fine-tune pre-trained transformer models for question answering.

         Args:
-            train_dataset (QADataset): Training dataset of type
-                :class:`utils_nlp.dataset.pytorch.QADataset`.
+            train_dataloader (DataLoader): DataLoader for the training data.
             num_gpus (int, optional): The number of GPUs to use. If None, all available GPUs will
                 be used. If set to 0 or GPUs are not available, CPU device will
                 be used. Defaults to None.
-            per_gpu_batch_size (int, optional): Training batch size on each GPU. Defaults to 8.
             num_epochs (int, optional): Number of training epochs. Defaults to 1.
             learning_rate (float, optional): Learning rate of the AdamW optimizer. Defaults to
                 5e-5.
@@ -530,14 +532,13 @@ class AnswerExtractor(Transformer):

         self.model.to(device)
         super().fine_tune(
-            train_dataset=train_dataset,
+            train_dataloader=train_dataloader,
             get_inputs=QAProcessor.get_inputs,
             device=device,
             max_steps=max_steps,
             num_train_epochs=num_epochs,
             max_grad_norm=max_grad_norm,
             gradient_accumulation_steps=gradient_accumulation_steps,
-            per_gpu_train_batch_size=per_gpu_batch_size,
             n_gpu=num_gpus,
             weight_decay=weight_decay,
             learning_rate=learning_rate,
@@ -552,22 +553,13 @@ class AnswerExtractor(Transformer):
         if cache_model:
             self.save_model()

-    def predict(
-        self,
-        test_dataset,
-        per_gpu_batch_size=16,
-        num_gpus=None,
-        local_rank=-1,
-        verbose=True,
-    ):
+    def predict(self, test_dataloader, num_gpus=None, local_rank=-1, verbose=True):
         """
         Predicts answer start and end logits.

         Args:
-            test_dataset (QADataset): Testing dataset of type
-                :class:`utils_nlp.dataset.pytorch.QADataset`.
-            per_gpu_batch_size (int, optional): Testing batch size on each GPU. Defaults to 16.
+            test_dataloader (DataLoader): DataLoader for the testing data.
             num_gpus (int, optional): The number of GPUs to use. If None, all available GPUs will
                 be used. If set to 0 or GPUs are not available, CPU device will
                 be used. Defaults to None.
@@ -583,16 +575,12 @@ class AnswerExtractor(Transformer):
             return tensor.detach().cpu().tolist()

         device, num_gpus = get_device(num_gpus=num_gpus, local_rank=local_rank)
-        batch_size = per_gpu_batch_size * max(1, num_gpus)

         self.model.to(device)

         # score
         self.model.eval()

-        sampler = SequentialSampler(test_dataset)
-        test_dataloader = DataLoader(test_dataset, sampler=sampler, batch_size=batch_size)
-
         all_results = []
         for batch in tqdm(test_dataloader, desc="Evaluating", disable=not verbose):
             batch = tuple(t.to(device) for t in batch)