diff --git a/examples/entailment/entailment_multinli_transformers.ipynb b/examples/entailment/entailment_multinli_transformers.ipynb
index 47bfcec..62a5fbb 100644
--- a/examples/entailment/entailment_multinli_transformers.ipynb
+++ b/examples/entailment/entailment_multinli_transformers.ipynb
@@ -22,24 +22,16 @@
    "source": [
     "# Before You Start\n",
     "\n",
-    "The running time shown in this notebook is running bert-large-cased on a Standard_NC24rs_v3 Azure Deep Learning Virtual Machine with 4 NVIDIA Tesla V100 GPUs. \n",
+    "It takes about 4 hours to fine-tune the `bert-large-cased` model on a Standard_NC24rs_v3 Azure Data Science Virtual Machine with 4 NVIDIA Tesla V100 GPUs. \n",
     "> **Tip:** If you want to run through the notebook quickly, you can set the **`QUICK_RUN`** flag in the cell below to **`True`** to run the notebook on a small subset of the data and a smaller number of epochs. \n",
     "\n",
-    "The table below provides some reference running time on different machine configurations. \n",
-    "\n",
-    "|QUICK_RUN|Machine Configurations|Running time|\n",
-    "|:---------|:----------------------|:------------|\n",
-    "|True|4 **CPU**s, 14GB memory| ~ 15 minutes|\n",
-    "|True|1 NVIDIA Tesla K80 GPUs, 12GB GPU memory| ~ 5 minutes|\n",
-    "|False|1 NVIDIA Tesla K80 GPUs, 12GB GPU memory| ~ 10.5 hours|\n",
-    "|False|4 NVIDIA Tesla V100 GPUs, 64GB GPU memory| ~ 2.5 hours|\n",
     "\n",
     "If you run into a CUDA out-of-memory error, try reducing the `BATCH_SIZE` and `MAX_SEQ_LENGTH`, but note that model performance will be compromised. "
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -56,31 +48,24 @@
    "To classify a sentence pair, we concatenate the tokens in both sentences and separate the sentences by the special [SEP] token. A [CLS] token is prepended to the token list and used as the aggregate sequence representation for the classification task. The NLI task essentially becomes a sequence classification task. For example, the figure below shows how [BERT](https://arxiv.org/abs/1810.04805) classifies sentence pairs. \n",
    "\n",
    "\n",
-    "We compare the training time and performance of three models: bert-base-cased, bert-large-cased, and xlnet-large-cased. The model used can be set in the **Configurations** section. "
+    "We compare the training time and performance of `bert-large-cased` and `xlnet-large-cased`. The model used can be set in the **Configurations** section. 
" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": { "scrolled": false }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "I1110 19:13:59.935610 140117887072000 file_utils.py:39] PyTorch version 1.2.0 available.\n", - "I1110 19:13:59.978967 140117887072000 modeling_xlnet.py:194] Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .\n" - ] - } - ], + "outputs": [], "source": [ "import sys, os\n", "nlp_path = os.path.abspath('../../')\n", "if nlp_path not in sys.path:\n", " sys.path.insert(0, nlp_path)\n", - " \n", + "\n", + "import scrapbook as sb\n", + "\n", "from tempfile import TemporaryDirectory\n", "\n", "import numpy as np\n", @@ -104,39 +89,9 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['bert-base-uncased',\n", - " 'bert-large-uncased',\n", - " 'bert-base-cased',\n", - " 'bert-large-cased',\n", - " 'bert-base-multilingual-uncased',\n", - " 'bert-base-multilingual-cased',\n", - " 'bert-base-chinese',\n", - " 'bert-base-german-cased',\n", - " 'bert-large-uncased-whole-word-masking',\n", - " 'bert-large-cased-whole-word-masking',\n", - " 'bert-large-uncased-whole-word-masking-finetuned-squad',\n", - " 'bert-large-cased-whole-word-masking-finetuned-squad',\n", - " 'bert-base-cased-finetuned-mrpc',\n", - " 'roberta-base',\n", - " 'roberta-large',\n", - " 'roberta-large-mnli',\n", - " 'xlnet-base-cased',\n", - " 'xlnet-large-cased',\n", - " 'distilbert-base-uncased',\n", - " 'distilbert-base-uncased-distilled-squad']" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "SequenceClassifier.list_supported_models()" ] @@ -150,7 +105,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": { "tags": [ "parameters" @@ -194,8 +149,7 @@ "LABEL_COL = \"gold_label\"\n", "LABEL_COL_NUM = \"gold_label_num\"\n", "\n", - "CACHE_DIR = TemporaryDirectory().name\n", - "CACHE_DIR = \"./temp\"" + "CACHE_DIR = TemporaryDirectory().name" ] }, { @@ -209,7 +163,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -220,7 +174,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -230,33 +184,9 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Training dataset size: 392702\n", - "Development (matched) dataset size: 9815\n", - "Development (mismatched) dataset size: 9832\n", - "\n", - " gold_label sentence1 \\\n", - "0 neutral Conceptually cream skimming has two basic dime... \n", - "1 entailment you know during the season and i guess at at y... \n", - "2 entailment One of our number will carry out your instruct... \n", - "3 entailment How do you know? All this is their information... \n", - "4 neutral yeah i tell you what though if you go price so... \n", - "\n", - " sentence2 \n", - "0 Product and geography are what make cream skim... \n", - "1 You lose the things to the following level if ... \n", - "2 A member of my team will execute your orders w... \n", - "3 This information belongs to them. \n", - "4 The tennis shoes have a range of prices. 
\n" - ] - } - ], + "outputs": [], "source": [ "print(\"Training dataset size: {}\".format(train_df.shape[0]))\n", "print(\"Development (matched) dataset size: {}\".format(dev_df_matched.shape[0]))\n", @@ -267,7 +197,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -278,7 +208,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -293,25 +223,18 @@ "metadata": {}, "source": [ "## Tokenize and Preprocess\n", - "Before training, we tokenize the sentence texts and convert them to lists of tokens. The following steps instantiate a BERT tokenizer given the language, and tokenize the text of the training and testing sets." + "Before training, we tokenize and preprocess the sentence texts to convert them into the format required by transformer model classes. \n", + "The `create_dataloader_from_df` method of the `Processor` class performs the following preprocessing steps and returns a Pytorch `DataLoader`\n", + "* Tokenize input texts using the tokenizer of the pre-trained model specified by `model_name`. \n", + "* Convert the tokens into token indices corresponding to the tokenizer's vocabulary.\n", + "* Pad or truncate the token lists to the specified max length." ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "I1110 19:14:11.376676 140117887072000 tokenization_utils.py:373] loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt from cache at ./temp/cee054f6aafe5e2cf816d2228704e326446785f940f5451a5b26033516a4ac3d.e13dbb970cb325137104fb2e5f36fe865f27746c6b526f6352861b1980eb80b1\n", - "100%|██████████| 392702/392702 [03:48<00:00, 1715.17it/s]\n", - "100%|██████████| 9815/9815 [00:05<00:00, 1797.48it/s]\n", - "100%|██████████| 9832/9832 [00:05<00:00, 1709.69it/s]\n" - ] - } - ], + "outputs": [], "source": [ "processor = Processor(model_name=MODEL_NAME, cache_dir=CACHE_DIR, to_lower=TO_LOWER)\n", "train_dataloader = processor.create_dataloader_from_df(\n", @@ -341,21 +264,6 @@ ")" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "In addition, we perform the following preprocessing steps in the cell below:\n", - "\n", - "* Convert the tokens into token indices corresponding to the BERT tokenizer's vocabulary\n", - "* Add the special tokens [CLS] and [SEP] to mark the beginning and end of a sentence\n", - "* Pad or truncate the token lists to the specified max length\n", - "* Return mask lists that indicate paddings' positions\n", - "* Return token type id lists that indicate which sentence the tokens belong to\n", - "\n", - "*See the original [implementation](https://github.com/google-research/bert/blob/master/run_classifier.py) for more information on BERT's input format.*" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -416,31 +324,9 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Evaluating: 100%|██████████| 614/614 [04:53<00:00, 2.12it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Prediction time : 0.082 hrs\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" - ] - } - ], + "outputs": [], "source": [ "with Timer() as t:\n", " predictions_matched = 
classifier.predict(dev_dataloader_matched)\n", @@ -449,31 +335,9 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Evaluating: 100%|██████████| 615/615 [04:53<00:00, 2.12it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Prediction time : 0.082 hrs\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" - ] - } - ], + "outputs": [], "source": [ "with Timer() as t:\n", " predictions_mismatched = classifier.predict(dev_dataloader_mismatched)\n", @@ -489,26 +353,9 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - " precision recall f1-score support\n", - "\n", - "contradiction 0.872 0.894 0.883 3213\n", - " entailment 0.913 0.862 0.887 3479\n", - " neutral 0.813 0.842 0.828 3123\n", - "\n", - " micro avg 0.866 0.866 0.866 9815\n", - " macro avg 0.866 0.866 0.866 9815\n", - " weighted avg 0.868 0.866 0.867 9815\n", - "\n" - ] - } - ], + "outputs": [], "source": [ "predictions_matched = label_encoder.inverse_transform(predictions_matched)\n", "print(classification_report(dev_df_matched[LABEL_COL], predictions_matched, digits=3))" @@ -516,28 +363,11 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": null, "metadata": { "scrolled": true }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - " precision recall f1-score support\n", - "\n", - "contradiction 0.891 0.888 0.889 3240\n", - " entailment 0.899 0.862 0.880 3463\n", - " neutral 0.810 0.850 0.830 3129\n", - "\n", - " micro avg 0.867 0.867 0.867 9832\n", - " macro avg 0.867 0.867 0.866 9832\n", - " weighted avg 0.868 0.867 0.867 9832\n", - "\n" - ] - } - ], + "outputs": [], "source": [ "predictions_mismatched = label_encoder.inverse_transform(predictions_mismatched)\n", "print(classification_report(dev_df_mismatched[LABEL_COL], predictions_mismatched, digits=3))" @@ -559,6 +389,22 @@ "|xlnet-large-cased|5.15 hrs|0.11 hrs|0.887|0.890|\n", "|bert-large-cased|4.01 hrs|0.08 hrs|0.867|0.867|" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "result_matched_dict = classification_report(dev_df_matched[LABEL_COL], predictions_matched, digits=3, output_dict=True)\n", + "result_mismatched_dict = classification_report(dev_df_mismatched[LABEL_COL], predictions_mismatched, digits=3, output_dict=True)\n", + "sb.glue(\"matched_precision\", result_matched_dict[\"weighted avg\"][\"precision\"])\n", + "sb.glue(\"matched_recall\", result_matched_dict[\"weighted avg\"][\"recall\"])\n", + "sb.glue(\"matched_f1\", result_matched_dict[\"weighted avg\"][\"f1-score\"])\n", + "sb.glue(\"mismatched_precision\", result_mismatched_dict[\"weighted avg\"][\"precision\"])\n", + "sb.glue(\"mismatched_recall\", result_mismatched_dict[\"weighted avg\"][\"recall\"])\n", + "sb.glue(\"mismatched_f1\", result_mismatched_dict[\"weighted avg\"][\"f1-score\"])" + ] } ], "metadata": { diff --git a/examples/question_answering/question_answering_squad_transformers.ipynb b/examples/question_answering/question_answering_squad_transformers.ipynb index 97bca7d..c89cf91 100644 --- a/examples/question_answering/question_answering_squad_transformers.ipynb +++ b/examples/question_answering/question_answering_squad_transformers.ipynb @@ -63,7 +63,7 @@ }, { "cell_type": "code", 
- "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -78,7 +78,7 @@ " sys.path.insert(0, nlp_path)\n", "\n", "from utils_nlp.dataset.squad import load_pandas_df\n", - "from utils_nlp.dataset.pytorch import QADataset\n", + "from utils_nlp.models.transformers.datasets import QADataset\n", "from utils_nlp.models.transformers.question_answering import (\n", " QAProcessor,\n", " AnswerExtractor\n", @@ -175,6 +175,7 @@ "DOC_STRIDE = 128\n", "PER_GPU_BATCH_SIZE = 4\n", "GRADIENT_ACCUMULATION_STEPS = 1\n", + "NUM_GPUS = torch.cuda.device_count()\n", "\n", "if QUICK_RUN:\n", " TRAIN_DATA_USED_PERCENT = 0.001\n", @@ -558,7 +559,7 @@ "* Pad the concatenated token sequence to `max_seq_length` if it's shorter.\n", "* Convert the tokens into token indices corresponding to the tokenizer's vocabulary.\n", "\n", - "`QAProcessor.preprocess` returns a Pytorch TensorDataset. By default, it saves `cached_examples_train/test.jsonl` and `cached_features_train/test.jsonl` to `./cached_qa_features`. These files are required by postprocessing the predicted answer start and end indices to get the final answer text. You can change the default file directory by specifying `feature_cache_dir`. " + "`QAProcessor.preprocess` returns a Pytorch Dataloader. By default, it saves `cached_examples_train/test.jsonl` and `cached_features_train/test.jsonl` to `./cached_qa_features`. These files are required by postprocessing the predicted answer start and end indices to get the final answer text. You can change the default file directory by specifying `feature_cache_dir`. " ] }, { @@ -576,16 +577,20 @@ ], "source": [ "qa_processor = QAProcessor(model_name=MODEL_NAME, to_lower=DO_LOWER_CASE)\n", - "train_features = qa_processor.preprocess(\n", + "train_dataloader = qa_processor.preprocess(\n", " train_dataset, \n", + " batch_size=PER_GPU_BATCH_SIZE,\n", + " num_gpus=NUM_GPUS,\n", " is_training=True,\n", " max_question_length=MAX_QUESTION_LENGTH,\n", " max_seq_length=MAX_SEQ_LENGTH,\n", " doc_stride=DOC_STRIDE\n", ")\n", "\n", - "dev_features = qa_processor.preprocess(\n", + "dev_dataloader = qa_processor.preprocess(\n", " dev_dataset, \n", + " batch_size=PER_GPU_BATCH_SIZE,\n", + " num_gpus=NUM_GPUS,\n", " is_training=False,\n", " max_question_length=MAX_QUESTION_LENGTH,\n", " max_seq_length=MAX_SEQ_LENGTH,\n", @@ -616,10 +621,9 @@ "outputs": [], "source": [ "with Timer() as t:\n", - " qa_extractor.fit(train_dataset=train_features,\n", + " qa_extractor.fit(train_dataloader,\n", " num_epochs=NUM_EPOCHS,\n", " learning_rate=LEARNING_RATE,\n", - " per_gpu_batch_size=PER_GPU_BATCH_SIZE,\n", " gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,\n", " seed=RANDOM_SEED,\n", " cache_model=True)\n", @@ -648,7 +652,7 @@ } ], "source": [ - "qa_results = qa_extractor.predict(dev_features, per_gpu_batch_size=PER_GPU_BATCH_SIZE)" + "qa_results = qa_extractor.predict(dev_dataloader)" ] }, { @@ -824,9 +828,9 @@ "metadata": { "celltoolbar": "Tags", "kernelspec": { - "display_name": "Python [default]", + "display_name": "nlp_gpu", "language": "python", - "name": "python3" + "name": "nlp_gpu" }, "language_info": { "codemirror_mode": { @@ -838,7 +842,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.5.5" + "version": "3.6.8" } }, "nbformat": 4, diff --git a/tests/unit/test_models_transformers_question_answering.py b/tests/unit/test_models_transformers_question_answering.py index bbd3d71..daa0c76 100644 --- 
a/tests/unit/test_models_transformers_question_answering.py +++ b/tests/unit/test_models_transformers_question_answering.py @@ -3,7 +3,7 @@ import pytest import os -from utils_nlp.dataset.pytorch import QADataset +from utils_nlp.models.transformers.datasets import QADataset from utils_nlp.models.transformers.question_answering import ( QAProcessor, AnswerExtractor, @@ -11,6 +11,11 @@ from utils_nlp.models.transformers.question_answering import ( CACHED_FEATURES_TEST_FILE, ) +import torch + +NUM_GPUS = max(1, torch.cuda.device_count()) +BATCH_SIZE = 8 + @pytest.fixture() def qa_test_data(qa_test_df, tmp): @@ -61,6 +66,8 @@ def qa_test_data(qa_test_df, tmp): qa_processor_bert = QAProcessor() train_features_bert = qa_processor_bert.preprocess( train_dataset, + batch_size=BATCH_SIZE, + num_gpus=NUM_GPUS, is_training=True, max_question_length=16, max_seq_length=64, @@ -70,6 +77,8 @@ def qa_test_data(qa_test_df, tmp): test_features_bert = qa_processor_bert.preprocess( test_dataset, + batch_size=BATCH_SIZE, + num_gpus=NUM_GPUS, is_training=False, max_question_length=16, max_seq_length=64, @@ -80,6 +89,8 @@ def qa_test_data(qa_test_df, tmp): qa_processor_xlnet = QAProcessor(model_name="xlnet-base-cased") train_features_xlnet = qa_processor_xlnet.preprocess( train_dataset, + batch_size=BATCH_SIZE, + num_gpus=NUM_GPUS, is_training=True, max_question_length=16, max_seq_length=64, @@ -89,6 +100,8 @@ def qa_test_data(qa_test_df, tmp): test_features_xlnet = qa_processor_xlnet.preprocess( test_dataset, + batch_size=BATCH_SIZE, + num_gpus=NUM_GPUS, is_training=False, max_question_length=16, max_seq_length=64, @@ -99,6 +112,8 @@ def qa_test_data(qa_test_df, tmp): qa_processor_distilbert = QAProcessor(model_name="distilbert-base-uncased") train_features_distilbert = qa_processor_distilbert.preprocess( train_dataset, + batch_size=BATCH_SIZE, + num_gpus=NUM_GPUS, is_training=True, max_question_length=16, max_seq_length=64, @@ -108,6 +123,8 @@ def qa_test_data(qa_test_df, tmp): test_features_distilbert = qa_processor_distilbert.preprocess( test_dataset, + batch_size=BATCH_SIZE, + num_gpus=NUM_GPUS, is_training=False, max_question_length=16, max_seq_length=64, @@ -157,9 +174,7 @@ def test_QAProcessor(qa_test_data, tmp): def test_AnswerExtractor(qa_test_data, tmp): # test bert qa_extractor_bert = AnswerExtractor(cache_dir=tmp) - qa_extractor_bert.fit( - qa_test_data["train_features_bert"], cache_model=True, per_gpu_batch_size=8 - ) + qa_extractor_bert.fit(qa_test_data["train_features_bert"], cache_model=True) # test saving fine-tuned model model_output_dir = os.path.join(tmp, "fine_tuned") @@ -170,15 +185,11 @@ def test_AnswerExtractor(qa_test_data, tmp): qa_extractor_from_cache.predict(qa_test_data["test_features_bert"]) qa_extractor_xlnet = AnswerExtractor(model_name="xlnet-base-cased", cache_dir=tmp) - qa_extractor_xlnet.fit( - qa_test_data["train_features_xlnet"], cache_model=False, per_gpu_batch_size=8 - ) + qa_extractor_xlnet.fit(qa_test_data["train_features_xlnet"], cache_model=False) qa_extractor_xlnet.predict(qa_test_data["test_features_xlnet"]) qa_extractor_distilbert = AnswerExtractor(model_name="distilbert-base-uncased", cache_dir=tmp) - qa_extractor_distilbert.fit( - qa_test_data["train_features_distilbert"], cache_model=False, per_gpu_batch_size=8 - ) + qa_extractor_distilbert.fit(qa_test_data["train_features_distilbert"], cache_model=False) qa_extractor_distilbert.predict(qa_test_data["test_features_distilbert"]) diff --git a/utils_nlp/models/transformers/question_answering.py 
b/utils_nlp/models/transformers/question_answering.py index 41ace6b..c340d36 100644 --- a/utils_nlp/models/transformers/question_answering.py +++ b/utils_nlp/models/transformers/question_answering.py @@ -26,7 +26,8 @@ import math import jsonlines import torch -from torch.utils.data import TensorDataset, SequentialSampler, DataLoader +from torch.utils.data import TensorDataset, SequentialSampler, DataLoader, RandomSampler +from torch.utils.data.distributed import DistributedSampler from transformers.tokenization_bert import BasicTokenizer, whitespace_tokenize from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BertForQuestionAnswering @@ -40,11 +41,7 @@ from transformers.modeling_distilbert import ( ) from utils_nlp.common.pytorch_utils import get_device -from utils_nlp.models.transformers.common import ( - MAX_SEQ_LEN, - TOKENIZER_CLASS, - Transformer, -) +from utils_nlp.models.transformers.common import MAX_SEQ_LEN, TOKENIZER_CLASS, Transformer MODEL_CLASS = {} MODEL_CLASS.update({k: BertForQuestionAnswering for k in BERT_PRETRAINED_MODEL_ARCHIVE_MAP}) @@ -146,6 +143,9 @@ class QAProcessor: self, qa_dataset, is_training, + batch_size=32, + num_gpus=None, + distributed=False, max_question_length=64, max_seq_length=MAX_SEQ_LEN, doc_stride=128, @@ -243,37 +243,42 @@ class QAProcessor: examples_writer.write_all(qa_examples_json) features_writer.write_all(features_json) - # TODO: maybe generalize the following code - input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) - input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long) - segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long) - cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long) - p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float) - - if is_training: - start_positions = torch.tensor( - [f.start_position for f in features], dtype=torch.long - ) - end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long) - qa_dataset = TensorDataset( - input_ids, - input_mask, - segment_ids, - start_positions, - end_positions, - cls_index, - p_mask, - ) - else: - unique_id_all = torch.tensor(unique_id_all, dtype=torch.long) - qa_dataset = TensorDataset( - input_ids, input_mask, segment_ids, cls_index, p_mask, unique_id_all - ) - logger.info("QA examples are saved to {}".format(examples_file)) logger.info("QA features are saved to {}".format(features_file)) - return qa_dataset + # TODO: maybe generalize the following code + input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) + input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long) + segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long) + cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long) + p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float) + + if is_training: + start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long) + end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long) + qa_dataset = TensorDataset( + input_ids, + input_mask, + segment_ids, + start_positions, + end_positions, + cls_index, + p_mask, + ) + else: + unique_id_all = torch.tensor(unique_id_all, dtype=torch.long) + qa_dataset = TensorDataset( + input_ids, input_mask, segment_ids, cls_index, p_mask, unique_id_all + ) + + if num_gpus is not None: + batch_size = batch_size * max(1, num_gpus) + if distributed: + 
sampler = DistributedSampler(qa_dataset)
+        else:
+            sampler = RandomSampler(qa_dataset) if is_training else SequentialSampler(qa_dataset)
+
+        return DataLoader(qa_dataset, sampler=sampler, batch_size=batch_size)
 
     def postprocess(
         self,
@@ -469,9 +474,8 @@
 
     def fit(
         self,
-        train_dataset,
+        train_dataloader,
         num_gpus=None,
-        per_gpu_batch_size=8,
         num_epochs=1,
         learning_rate=5e-5,
         max_grad_norm=1.0,
@@ -491,12 +495,10 @@
         Fine-tune pre-trained transformer models for question answering.
 
         Args:
-            train_dataset (QADataset): Training dataset of type
-                :class:`utils_nlp.dataset.pytorch.QADataset`.
+            train_dataloader (DataLoader): DataLoader for the training data.
             num_gpus (int, optional): The number of GPUs to use. If None, all available GPUs will
                 be used. If set to 0 or GPUs are not available, CPU device will
                 be used. Defaults to None.
-            per_gpu_batch_size (int, optional): Training batch size on each GPU. Defaults to 8.
             num_epochs (int, optional): Number of training epochs. Defaults to 1.
             learning_rate (float, optional): Learning rate of the AdamW optimizer. Defaults to
                 5e-5.
@@ -530,14 +532,13 @@
         self.model.to(device)
 
         super().fine_tune(
-            train_dataset=train_dataset,
+            train_dataloader=train_dataloader,
             get_inputs=QAProcessor.get_inputs,
             device=device,
             max_steps=max_steps,
             num_train_epochs=num_epochs,
             max_grad_norm=max_grad_norm,
             gradient_accumulation_steps=gradient_accumulation_steps,
-            per_gpu_train_batch_size=per_gpu_batch_size,
             n_gpu=num_gpus,
             weight_decay=weight_decay,
             learning_rate=learning_rate,
@@ -552,22 +553,13 @@
         if cache_model:
             self.save_model()
 
-    def predict(
-        self,
-        test_dataset,
-        per_gpu_batch_size=16,
-        num_gpus=None,
-        local_rank=-1,
-        verbose=True,
-    ):
+    def predict(self, test_dataloader, num_gpus=None, local_rank=-1, verbose=True):
         """
         Predicts answer start and end logits.
 
         Args:
-            test_dataset (QADataset): Testing dataset of type
-                :class:`utils_nlp.dataset.pytorch.QADataset`.
-            per_gpu_batch_size (int, optional): Testing batch size on each GPU. Defaults to 16.
+            test_dataloader (DataLoader): DataLoader for the testing data.
             num_gpus (int, optional): The number of GPUs to use. If None, all available GPUs will
                 be used. If set to 0 or GPUs are not available, CPU device will
                 be used. Defaults to None.
@@ -583,16 +575,12 @@
             return tensor.detach().cpu().tolist()
 
         device, num_gpus = get_device(num_gpus=num_gpus, local_rank=local_rank)
-        batch_size = per_gpu_batch_size * max(1, num_gpus)
 
         self.model.to(device)
 
         # score
         self.model.eval()
 
-        sampler = SequentialSampler(test_dataset)
-        test_dataloader = DataLoader(test_dataset, sampler=sampler, batch_size=batch_size)
-
         all_results = []
         for batch in tqdm(test_dataloader, desc="Evaluating", disable=not verbose):
             batch = tuple(t.to(device) for t in batch)
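
A minimal usage sketch of the refactored question-answering API, built only from the signatures visible in this diff: `QAProcessor.preprocess` now owns batching and returns a PyTorch DataLoader, and `AnswerExtractor.fit`/`predict` consume DataLoaders directly, so `per_gpu_batch_size` disappears from both. The model name, hyperparameter values, cache path, and the pre-built `train_dataset`/`dev_dataset` objects below are illustrative assumptions, not part of the patch.

# Hypothetical usage sketch, assuming train_dataset and dev_dataset are
# QADataset objects prepared as in the notebook (construction elided).
# Model name, hyperparameters, and cache_dir are illustrative only.
import torch

from utils_nlp.models.transformers.question_answering import (
    QAProcessor,
    AnswerExtractor,
)

NUM_GPUS = torch.cuda.device_count()
PER_GPU_BATCH_SIZE = 4

qa_processor = QAProcessor(model_name="distilbert-base-uncased", to_lower=True)

# preprocess scales the batch size by max(1, num_gpus) and picks a
# RandomSampler for training data or a SequentialSampler for evaluation
# data (a DistributedSampler when distributed=True).
train_dataloader = qa_processor.preprocess(
    train_dataset,
    batch_size=PER_GPU_BATCH_SIZE,
    num_gpus=NUM_GPUS,
    is_training=True,
    max_question_length=64,
    max_seq_length=384,
    doc_stride=128,
)
dev_dataloader = qa_processor.preprocess(
    dev_dataset,
    batch_size=PER_GPU_BATCH_SIZE,
    num_gpus=NUM_GPUS,
    is_training=False,
    max_question_length=64,
    max_seq_length=384,
    doc_stride=128,
)

# fit and predict take the DataLoaders directly; batching was decided above.
qa_extractor = AnswerExtractor(model_name="distilbert-base-uncased", cache_dir="./temp")
qa_extractor.fit(
    train_dataloader,
    num_epochs=1,
    learning_rate=5e-5,
    gradient_accumulation_steps=1,
    cache_model=True,
)
qa_results = qa_extractor.predict(dev_dataloader)

One consequence of this design worth noting: because the effective batch size is fixed when the DataLoader is created, the `num_gpus` passed to `preprocess` should match the device configuration later used by `fit`/`predict`.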