diff --git a/scenarios/entailment/entailment_xnli_bert_azureml.ipynb b/scenarios/entailment/entailment_xnli_bert_azureml.ipynb new file mode 100644 index 0000000..c3d8345 --- /dev/null +++ b/scenarios/entailment/entailment_xnli_bert_azureml.ipynb @@ -0,0 +1,531 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Natural Language Inference on XNLI Dataset using BERT with Azure Machine Learning\n", + "\n", + "## 1. Summary\n", + "In this notebook, we demonstrate using the BERT model to do language inference in English. We use the [XNLI](https://github.com/facebookresearch/XNLI) dataset and the task is to classify sentence pairs into three classes: contradiction, entailment, and neutral. \n", + "The figure below shows how [BERT](https://arxiv.org/abs/1810.04805) classifies sentence pairs. It concatenates the tokens in each sentence pair and separates the sentences by the [SEP] token. A [CLS] token is prepended to the token list and used as the aggregate sequence representation for the classification task.\n", + "\n", + "\n", + "Azure Machine Learning features highlighted in the notebook:\n", + "\n", + "- Distributed training with Horovod" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# Imports\n", + "\n", + "import sys\n", + "\n", + "sys.path.append(\"../..\")\n", + "\n", + "import os\n", + "import shutil\n", + "import torch\n", + "import json\n", + "import pandas as pd\n", + "\n", + "import azureml.core\n", + "from azureml.train.dnn import PyTorch\n", + "from azureml.core.runconfig import MpiConfiguration\n", + "from azureml.core import Experiment\n", + "from azureml.widgets import RunDetails\n", + "from azureml.core.compute import AmlCompute, ComputeTarget\n", + "from azureml.core.compute_target import ComputeTargetException\n", + "from utils_nlp.azureml.azureml_utils import get_or_create_workspace, get_output_files" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "tags": [ + "parameters" + ] + }, + "outputs": [], + "source": [ + "# Parameters\n", + "\n", + "DEBUG = True\n", + "NODE_COUNT = 4\n", + "NUM_PROCESS = 1\n", + "DATA_PERCENT_USED = 1.0\n", + "\n", + "config_path = (\n", + " \"./.azureml\"\n", + ") # Path to the directory containing config.json with azureml credentials\n", + "\n", + "# Azure resources\n", + "subscription_id = \"YOUR_SUBSCRIPTION_ID\"\n", + "resource_group = \"YOUR_RESOURCE_GROUP_NAME\" \n", + "workspace_name = \"YOUR_WORKSPACE_NAME\" \n", + "workspace_region = \"YOUR_WORKSPACE_REGION\" # e.g. eastus, eastus2.\n", + "cluster_name = \"gpu-entail\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. AzureML Setup\n", + "\n", + "### 2.1 Link to or create a Workspace\n", + "\n", + "First, go through the [Configuration](https://github.com/Azure/MachineLearningNotebooks/blob/master/configuration.ipynb) notebook to install the Azure Machine Learning Python SDK and create an Azure ML `Workspace`. 
This will create a config.json file containing the values needed below to create a workspace.\n", + "\n", + "**Note**: you do not need to fill in these values if you have a config.json file in the same folder as this notebook." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "ws = get_or_create_workspace(\n", + " config_path=config_path,\n", + " subscription_id=subscription_id,\n", + " resource_group=resource_group,\n", + " workspace_name=workspace_name,\n", + " workspace_region=workspace_region,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\n", + " \"Workspace name: \" + ws.name,\n", + " \"Azure region: \" + ws.location,\n", + " \"Subscription id: \" + ws.subscription_id,\n", + " \"Resource group: \" + ws.resource_group,\n", + " sep=\"\\n\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.2 Link AmlCompute Compute Target\n", + "\n", + "We need to link a [compute target](https://docs.microsoft.com/azure/machine-learning/service/concept-azure-machine-learning-architecture#compute-target) for training our model (see [compute options](https://docs.microsoft.com/en-us/azure/machine-learning/service/how-to-set-up-training-targets#supported-compute-targets) for an explanation of the different options). We will use an [AmlCompute](https://docs.microsoft.com/azure/machine-learning/service/how-to-set-up-training-targets#amlcompute) target and link to an existing target (if cluster_name exists) or create a new STANDARD_NC6 GPU cluster in this example. Creating a new AmlCompute target takes approximately 5 minutes. \n", + "\n", + "As with other Azure services, there are limits on certain resources (e.g. AmlCompute) associated with the Azure Machine Learning service. Please read [this article](https://docs.microsoft.com/en-us/azure/machine-learning/service/how-to-manage-quotas) on the default limits and how to request more quota."
+ ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found compute target: gpu-entail\n", + "{'currentNodeCount': 0, 'targetNodeCount': 0, 'nodeStateCounts': {'preparingNodeCount': 0, 'runningNodeCount': 0, 'idleNodeCount': 0, 'unusableNodeCount': 0, 'leavingNodeCount': 0, 'preemptedNodeCount': 0}, 'allocationState': 'Steady', 'allocationStateTransitionTime': '2019-08-03T13:43:20.068000+00:00', 'errors': None, 'creationTime': '2019-07-27T02:14:46.127092+00:00', 'modifiedTime': '2019-07-27T02:15:07.181277+00:00', 'provisioningState': 'Succeeded', 'provisioningStateTransitionTime': None, 'scaleSettings': {'minNodeCount': 0, 'maxNodeCount': 4, 'nodeIdleTimeBeforeScaleDown': 'PT120S'}, 'vmPriority': 'Dedicated', 'vmSize': 'STANDARD_NC6S_V2'}\n" + ] + } + ], + "source": [ + "try:\n", + " compute_target = ComputeTarget(workspace=ws, name=cluster_name)\n", + " print(\"Found compute target: {}\".format(cluster_name))\n", + "except ComputeTargetException:\n", + " print(\"Creating new compute target: {}\".format(cluster_name))\n", + " compute_config = AmlCompute.provisioning_configuration(\n", + " vm_size=\"STANDARD_NC6\", max_nodes=1\n", + " )\n", + " compute_target = ComputeTarget.create(ws, cluster_name, compute_config)\n", + " compute_target.wait_for_completion(show_output=True)\n", + "\n", + "\n", + "print(compute_target.get_status().serialize())" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'./entail_utils\\\\utils_nlp'" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "project_dir = \"./entail_utils\"\n", + "if DEBUG and os.path.exists(project_dir):\n", + " shutil.rmtree(project_dir)\n", + "shutil.copytree(\"../../utils_nlp\", os.path.join(project_dir, \"utils_nlp\"))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. 
Prepare Training Script" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Writing ./entail_utils/train.py\n" + ] + } + ], + "source": [ + "%%writefile $project_dir/train.py\n", + "import horovod.torch as hvd\n", + "import torch\n", + "import numpy as np\n", + "import time\n", + "import argparse\n", + "from utils_nlp.common.timer import Timer\n", + "from utils_nlp.dataset.xnli_torch_dataset import XnliDataset\n", + "from utils_nlp.models.bert.common import Language\n", + "from utils_nlp.models.bert.sequence_classification_distributed import (\n", + " BERTSequenceClassifier,\n", + ")\n", + "from sklearn.metrics import classification_report\n", + "\n", + "print(\"Torch version:\", torch.__version__)\n", + "\n", + "hvd.init()\n", + "\n", + "LANGUAGE_ENGLISH = \"en\"\n", + "TRAIN_FILE_SPLIT = \"train\"\n", + "TEST_FILE_SPLIT = \"test\"\n", + "TO_LOWERCASE = True\n", + "PRETRAINED_BERT_LNG = Language.ENGLISH\n", + "LEARNING_RATE = 5e-5\n", + "WARMUP_PROPORTION = 0.1\n", + "BATCH_SIZE = 32\n", + "NUM_GPUS = 1\n", + "OUTPUT_DIR = \"./outputs/\"\n", + "LABELS = [\"contradiction\", \"entailment\", \"neutral\"]\n", + "\n", + "## each machine gets it's own copy of data\n", + "CACHE_DIR = \"./xnli-%d\" % hvd.rank()\n", + "\n", + "parser = argparse.ArgumentParser()\n", + "# Training settings\n", + "parser.add_argument(\n", + " \"--seed\", type=int, default=42, metavar=\"S\", help=\"random seed (default: 42)\"\n", + ")\n", + "parser.add_argument(\n", + " \"--epochs\", type=int, default=2, metavar=\"S\", help=\"random seed (default: 2)\"\n", + ")\n", + "parser.add_argument(\n", + " \"--no-cuda\", action=\"store_true\", default=False, help=\"disables CUDA training\"\n", + ")\n", + "parser.add_argument(\n", + " \"--data_percent_used\",\n", + " type=float,\n", + " default=1.0,\n", + " metavar=\"S\",\n", + " help=\"data percent used (default: 1.0)\",\n", + ")\n", + "\n", + "args = parser.parse_args()\n", + "args.cuda = not args.no_cuda and torch.cuda.is_available()\n", + "\n", + "\"\"\"\n", + "Note: For example, you have 4 nodes and 4 GPUs each node, so you spawn 16 workers. 
\n", + "Every worker will have a rank [0, 15], and every worker will have a local_rank [0, 3]\n", + "\"\"\"\n", + "if args.cuda:\n", + " torch.cuda.set_device(hvd.local_rank())\n", + " torch.cuda.manual_seed(args.seed)\n", + "\n", + "# num_workers - this is equal to number of gpus per machine\n", + "kwargs = {\"num_workers\": NUM_GPUS, \"pin_memory\": True} if args.cuda else {}\n", + "\n", + "train_dataset = XnliDataset(\n", + " file_split=TRAIN_FILE_SPLIT,\n", + " cache_dir=CACHE_DIR,\n", + " language=LANGUAGE_ENGLISH,\n", + " to_lowercase=TO_LOWERCASE,\n", + " tok_language=PRETRAINED_BERT_LNG,\n", + " data_percent_used=args.data_percent_used,\n", + ")\n", + "\n", + "\n", + "# set the label_encoder for evaluation\n", + "label_encoder = train_dataset.label_encoder\n", + "num_labels = len(np.unique(train_dataset.labels))\n", + "\n", + "# Train\n", + "classifier = BERTSequenceClassifier(\n", + " language=Language.ENGLISH,\n", + " num_labels=num_labels,\n", + " cache_dir=CACHE_DIR,\n", + " use_distributed=True,\n", + ")\n", + "\n", + "\n", + "train_loader = classifier.create_data_loader(\n", + " train_dataset, BATCH_SIZE, mode=\"train\", **kwargs\n", + ")\n", + "\n", + "\n", + "num_samples = len(train_loader.dataset)\n", + "num_batches = int(num_samples / BATCH_SIZE)\n", + "num_train_optimization_steps = num_batches * args.epochs\n", + "optimizer = classifier.create_optimizer(\n", + " num_train_optimization_steps, lr=LEARNING_RATE, warmup_proportion=WARMUP_PROPORTION\n", + ")\n", + "\n", + "with Timer() as t:\n", + " for epoch in range(1, args.epochs + 1):\n", + "\n", + " # to allow data shuffling for DistributedSampler\n", + " train_loader.sampler.set_epoch(epoch)\n", + "\n", + " # epoch and num_epochs is passed in the fit function to print loss at regular batch intervals\n", + " classifier.fit(\n", + " train_loader,\n", + " epoch=epoch,\n", + " num_epochs=args.epochs,\n", + " bert_optimizer=optimizer,\n", + " num_gpus=NUM_GPUS,\n", + " )\n", + "\n", + "#if machine has multiple gpus then run predictions on only on 1 gpu since test_dataset is small.\n", + "if hvd.rank() == 0:\n", + " NUM_GPUS = 1\n", + " \n", + " test_dataset = XnliDataset(\n", + " file_split=TEST_FILE_SPLIT,\n", + " cache_dir=CACHE_DIR,\n", + " language=LANGUAGE_ENGLISH,\n", + " to_lowercase=TO_LOWERCASE,\n", + " tok_language=PRETRAINED_BERT_LNG,\n", + " )\n", + "\n", + " test_loader = classifier.create_data_loader(test_dataset, mode=\"test\")\n", + "\n", + " # predict\n", + " predictions, pred_labels = classifier.predict(test_loader, NUM_GPUS)\n", + "\n", + " predictions = label_encoder.inverse_transform(predictions)\n", + "\n", + " # Evaluate\n", + " results = classification_report(\n", + " pred_labels, predictions, target_names=LABELS, output_dict=True\n", + " )\n", + "\n", + " result_file = os.path.join(OUTPUT_DIR, \"results.json\")\n", + " with open(result_file, \"w+\") as fp:\n", + " json.dump(results, fp)\n", + "\n", + " # save model\n", + " classifier.save_model()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Create a PyTorch Estimator\n", + "\n", + "BERT is built on PyTorch, so we will use the AzureML SDK's PyTorch estimator to easily submit PyTorch training jobs for both single-node and distributed runs. For more information on the PyTorch estimator, see [How to Train Pytorch Models on AzureML](https://docs.microsoft.com/azure/machine-learning/service/how-to-train-pytorch). 
" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "mpiConfig = MpiConfiguration()\n", + "mpiConfig.process_count_per_node = NUM_PROCESS\n", + "\n", + "script_params = {\n", + " '--data_percent_used': DATA_PERCENT_USED\n", + "}\n", + "\n", + "est = PyTorch(\n", + " source_directory=project_dir,\n", + " compute_target=compute_target,\n", + " entry_script=\"train.py\",\n", + " script_params = script_params,\n", + " node_count=NODE_COUNT,\n", + " distributed_training=mpiConfig,\n", + " use_gpu=True,\n", + " framework_version=\"1.0\",\n", + " conda_packages=[\"scikit-learn=0.20.3\", \"numpy\", \"spacy\", \"nltk\"],\n", + " pip_packages=[\"pandas\", \"seqeval[gpu]\", \"pytorch-pretrained-bert\"],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. Create Experiment and Submit a Job\n", + "Submit the estimator object to run your experiment. Results can be monitored using a Jupyter widget. The widget and run are asynchronous and update every 10-15 seconds until job completion.\n", + "\n", + "**Note**: The experiment takes ~4 hours with 2 NC24 nodes and ~7hours with 4 NC6 nodes. The overhead is due to the communication time between nodes. " + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "experiment = Experiment(ws, name=\"NLP-Entailment-BERT\")\n", + "run = experiment.submit(est)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c8e7a44fa8804e95b21eea74d7694b1e", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "_UserRunWidget(widget_settings={'childWidgetDisplay': 'popup', 'send_telemetry': False, 'log_level': 'INFO', '…" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "RunDetails(run).show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Since the above cell is an async call, the below cell is a blocking call to stop the cells below it to execute." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "run.wait_for_completion()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 6. Analyze Results\n", + "\n", + "Download result.json from portal and open to view results. 
" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Downloading file outputs/results.json to ./outputs\\results.json...\n" + ] + } + ], + "source": [ + "file_names = [\"outputs/results.json\"]\n", + "get_output_files(run, \"./outputs\", file_names=file_names)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " f1-score precision recall support\n", + "contradiction 0.838749 0.859296 0.819162 1670.0\n", + "entailment 0.817280 0.877663 0.764671 1670.0\n", + "neutral 0.777870 0.719817 0.846108 1670.0\n", + "micro avg 0.809980 0.809980 0.809980 5010.0\n", + "macro avg 0.811300 0.818925 0.809980 5010.0\n", + "weighted avg 0.811300 0.818925 0.809980 5010.0\n" + ] + } + ], + "source": [ + "with open(\"outputs/results.json\", \"r\") as handle:\n", + " parsed = json.load(handle)\n", + " print(pd.DataFrame.from_dict(parsed).transpose())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "celltoolbar": "Tags", + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.8" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/scenarios/text_classification/tc_bert_azureml.ipynb b/scenarios/text_classification/tc_bert_azureml.ipynb index 009983b..988ced8 100644 --- a/scenarios/text_classification/tc_bert_azureml.ipynb +++ b/scenarios/text_classification/tc_bert_azureml.ipynb @@ -32,7 +32,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 29, "metadata": {}, "outputs": [ { @@ -85,7 +85,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 30, "metadata": { "tags": [ "parameters" @@ -135,7 +135,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 31, "metadata": { "scrolled": false }, @@ -164,7 +164,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 32, "metadata": {}, "outputs": [ { @@ -172,7 +172,7 @@ "output_type": "stream", "text": [ "Found existing compute target.\n", - "{'currentNodeCount': 2, 'targetNodeCount': 2, 'nodeStateCounts': {'preparingNodeCount': 0, 'runningNodeCount': 0, 'idleNodeCount': 2, 'unusableNodeCount': 0, 'leavingNodeCount': 0, 'preemptedNodeCount': 0}, 'allocationState': 'Steady', 'allocationStateTransitionTime': '2019-07-31T22:29:42.732000+00:00', 'errors': None, 'creationTime': '2019-07-25T04:16:20.598768+00:00', 'modifiedTime': '2019-07-25T04:16:36.486727+00:00', 'provisioningState': 'Succeeded', 'provisioningStateTransitionTime': None, 'scaleSettings': {'minNodeCount': 2, 'maxNodeCount': 10, 'nodeIdleTimeBeforeScaleDown': 'PT120S'}, 'vmPriority': 'Dedicated', 'vmSize': 'STANDARD_NC12'}\n" + "{'currentNodeCount': 0, 'targetNodeCount': 0, 'nodeStateCounts': {'preparingNodeCount': 0, 'runningNodeCount': 0, 'idleNodeCount': 0, 'unusableNodeCount': 0, 'leavingNodeCount': 0, 'preemptedNodeCount': 0}, 'allocationState': 'Steady', 'allocationStateTransitionTime': '2019-08-11T08:53:18.284000+00:00', 'errors': None, 'creationTime': '2019-07-25T04:16:20.598768+00:00', 'modifiedTime': 
'2019-08-05T06:40:12.292030+00:00', 'provisioningState': 'Succeeded', 'provisioningStateTransitionTime': None, 'scaleSettings': {'minNodeCount': 0, 'maxNodeCount': 10, 'nodeIdleTimeBeforeScaleDown': 'PT120S'}, 'vmPriority': 'Dedicated', 'vmSize': 'STANDARD_NC12'}\n" ] } ], @@ -213,7 +213,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 33, "metadata": {}, "outputs": [], "source": [ @@ -234,7 +234,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 34, "metadata": { "scrolled": true }, @@ -271,9 +271,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 35, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "$AZUREML_DATAREFERENCE_9609849b541244d396d06017b5729edb" + ] + }, + "execution_count": 35, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "ds = ws.get_default_datastore()\n", "ds.upload(src_dir=TRAIN_FOLDER, target_path=\"mnli_data/train\", overwrite=True, show_progress=False)\n", @@ -282,7 +293,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 36, "metadata": {}, "outputs": [], "source": [ @@ -299,7 +310,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 37, "metadata": { "scrolled": true }, @@ -404,7 +415,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 38, "metadata": {}, "outputs": [ { @@ -413,7 +424,7 @@ "'../../utils_nlp/models/bert/preprocess.py'" ] }, - "execution_count": 28, + "execution_count": 38, "metadata": {}, "output_type": "execute_result" } @@ -432,7 +443,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 39, "metadata": {}, "outputs": [], "source": [ @@ -461,7 +472,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 40, "metadata": {}, "outputs": [], "source": [ @@ -543,7 +554,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 41, "metadata": {}, "outputs": [ { @@ -564,13 +575,14 @@ "import logging\n", "import os\n", "import torch\n", + "\n", "from sklearn.metrics import classification_report\n", "\n", - "from utils_nlp.models.bert.common import Language\n", - "from utils_nlp.models.bert.sequence_classification_distributed import (\n", - " BERTSequenceDistClassifier,\n", - ")\n", "from utils_nlp.common.timer import Timer\n", + "from utils_nlp.models.bert.common import Language, get_dataset_multiple_files\n", + "from utils_nlp.models.bert.sequence_classification_distributed import (\n", + " BERTSequenceClassifier,\n", + ")\n", "\n", "BATCH_SIZE = 32\n", "NUM_GPUS = 2\n", @@ -602,46 +614,64 @@ "# Handle square brackets from train list\n", "train_files[0] = train_files[0][1:]\n", "train_files[len(train_files) - 1] = train_files[len(train_files) - 1][:-1]\n", + "train_dataset = get_dataset_multiple_files(train_files)\n", "\n", "# Handle square brackets from test list\n", "test_files[0] = test_files[0][1:]\n", "test_files[len(test_files) - 1] = test_files[len(test_files) - 1][:-1]\n", + "test_dataset = get_dataset_multiple_files(test_files)\n", "\n", "# Train\n", - "classifier = BERTSequenceDistClassifier(\n", - " language=Language.ENGLISH, num_labels=len(LABELS)\n", + "classifier = BERTSequenceClassifier(\n", + " language=Language.ENGLISH, num_labels=len(LABELS), use_distributed=True\n", ")\n", + "\n", + "# Create data loaders.\n", + "kwargs = (\n", + " {\"num_workers\": NUM_GPUS, \"pin_memory\": True} if torch.cuda.is_available() else {}\n", + ")\n", + "train_data_loader = 
classifier.create_data_loader(\n", + " train_dataset, batch_size=BATCH_SIZE, **kwargs\n", + ")\n", + "test_data_loader = classifier.create_data_loader(\n", + " test_dataset, batch_size=BATCH_SIZE, mode=\"test\", **kwargs\n", + ")\n", + "\n", + "# Create optimizer\n", + "num_examples = len(train_dataset)\n", + "num_batches = int(num_examples / BATCH_SIZE)\n", + "num_train_optimization_steps = num_batches * NUM_EPOCHS\n", + "optimizer = classifier.create_optimizer(num_train_optimization_steps)\n", + "\n", "with Timer() as t:\n", - " classifier.fit(\n", - " train_files,\n", - " num_gpus=NUM_GPUS,\n", - " num_epochs=NUM_EPOCHS,\n", - " batch_size=BATCH_SIZE,\n", - " verbose=True,\n", - " )\n", + " for epoch in range(1, NUM_EPOCHS + 1):\n", + " train_data_loader.sampler.set_epoch(epoch)\n", + " classifier.fit(\n", + " train_data_loader,\n", + " epoch=epoch,\n", + " bert_optimizer=optimizer,\n", + " num_gpus=NUM_GPUS,\n", + " num_epochs=NUM_EPOCHS,\n", + " )\n", "\n", "# Predict\n", - "preds, labels_test = classifier.predict(\n", - " test_files, num_gpus=NUM_GPUS, batch_size=BATCH_SIZE\n", - ")\n", + "preds, labels_test = classifier.predict(test_data_loader, num_gpus=NUM_GPUS)\n", "\n", + "# Evaluate\n", "results = classification_report(\n", " labels_test, preds, target_names=LABELS, output_dict=True\n", ")\n", "\n", "# Write out results.\n", + "classifier.save_model()\n", "result_file = os.path.join(OUTPUT_DIR, \"results.json\")\n", "with open(result_file, \"w+\") as fp:\n", - " json.dump(results, fp)\n", - "\n", - "# Save model\n", - "model_file = os.path.join(OUTPUT_DIR, \"model.pt\")\n", - "torch.save(classifier.model.state_dict(), model_file)" + " json.dump(results, fp)" ] }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 42, "metadata": {}, "outputs": [ { @@ -650,7 +680,7 @@ "'../../utils_nlp/models/bert/train.py'" ] }, - "execution_count": 32, + "execution_count": 42, "metadata": {}, "output_type": "execute_result" } @@ -675,15 +705,14 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 43, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "WARNING - framework_version is not specified, defaulting to version 1.1.\n", - "WARNING - 'process_count_per_node' parameter will be deprecated. 
Please use it as part of 'distributed_training' parameter.\n" + "WARNING - framework_version is not specified, defaulting to version 1.1.\n" ] } ], @@ -692,8 +721,7 @@ " compute_target=compute_target,\n", " entry_script='utils_nlp/models/bert/train.py',\n", " node_count= NODE_COUNT,\n", - " distributed_training=MpiConfiguration(),\n", - " process_count_per_node=2,\n", + " distributed_training= MpiConfiguration(),\n", " use_gpu=True,\n", " conda_packages=['scikit-learn=0.20.3', 'numpy>=1.16.0', 'pandas'],\n", " pip_packages=[\"tqdm==4.31.1\",\"pytorch-pretrained-bert>=0.6\"]\n", @@ -702,7 +730,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 44, "metadata": {}, "outputs": [], "source": [ @@ -742,7 +770,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 46, "metadata": { "scrolled": false }, @@ -750,7 +778,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "48df85f533834264a8a8b65a57d60d59", + "model_id": "060659321062486694c0acbb0184eeed", "version_major": 2, "version_minor": 0 }, @@ -768,7 +796,7 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 47, "metadata": {}, "outputs": [], "source": [ @@ -797,7 +825,7 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 49, "metadata": { "scrolled": false }, @@ -807,19 +835,20 @@ "output_type": "stream", "text": [ "Downloading file outputs/results.json to ./outputs\\results.json...\n", - "Downloading file outputs/model.pt to ./outputs\\model.pt...\n" + "Downloading file outputs/bert-large-uncased to ./outputs\\bert-large-uncased...\n", + "Downloading file outputs/bert_config.json to ./outputs\\bert_config.json...\n" ] } ], "source": [ "step_run = pipeline_run.find_step_run(\"Estimator-Train\")[0]\n", - "file_names = ['outputs/results.json', 'outputs/model.pt']\n", + "file_names = ['outputs/results.json', 'outputs/bert-large-uncased', 'outputs/bert_config.json' ]\n", "azureml_utils.get_output_files(step_run, './outputs', file_names=file_names)" ] }, { "cell_type": "code", - "execution_count": 40, + "execution_count": 50, "metadata": {}, "outputs": [ { @@ -827,14 +856,14 @@ "output_type": "stream", "text": [ " f1-score precision recall support\n", - "telephone 0.920217 0.897281 0.944356 629.0\n", - "government 0.967905 0.979487 0.956594 599.0\n", - "travel 0.856683 0.900169 0.817204 651.0\n", - "slate 0.991093 0.991896 0.990291 618.0\n", - "fiction 0.936434 0.906907 0.967949 624.0\n", - "micro avg 0.933996 0.933996 0.933996 3121.0\n", - "macro avg 0.934466 0.935148 0.935279 3121.0\n", - "weighted avg 0.933394 0.934321 0.933996 3121.0\n" + "telephone 0.904130 0.843191 0.974563 629.0\n", + "government 0.955857 0.972366 0.939900 599.0\n", + "travel 0.839966 0.935849 0.761905 651.0\n", + "slate 0.986411 0.974724 0.998382 618.0\n", + "fiction 0.938871 0.918712 0.959936 624.0\n", + "micro avg 0.925344 0.925344 0.925344 3121.0\n", + "macro avg 0.925047 0.928968 0.926937 3121.0\n", + "weighted avg 0.923913 0.928455 0.925344 3121.0\n" ] } ], @@ -869,7 +898,7 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": 51, "metadata": {}, "outputs": [], "source": [ diff --git a/tests/conftest.py b/tests/conftest.py index e0eec91..7b55d35 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -69,6 +69,9 @@ def notebooks(): "entailment_multinli_bert": os.path.join( folder_notebooks, "entailment", "entailment_multinli_bert.ipynb" ), + "entailment_bert_azureml": os.path.join( + folder_notebooks, "entailment", 
"entailment_xnli_bert_azureml.ipynb" + ), "tc_bert_azureml": os.path.join( folder_notebooks, "text_classification", "tc_bert_azureml.ipynb" ), diff --git a/tests/integration/test_notebooks_entailment.py b/tests/integration/test_notebooks_entailment.py index c8e11b6..9e62279 100644 --- a/tests/integration/test_notebooks_entailment.py +++ b/tests/integration/test_notebooks_entailment.py @@ -3,8 +3,12 @@ import pytest import papermill as pm +import os +import json +import shutil from tests.notebooks_common import OUTPUT_NOTEBOOK, KERNEL_NAME +ABS_TOL = 0.1 @pytest.mark.gpu @pytest.mark.integration @@ -20,3 +24,30 @@ def test_entailment_multinli_bert(notebooks): }, kernel_name=KERNEL_NAME, ) + + +@pytest.mark.integration +@pytest.mark.azureml +def test_entailment_bert_azureml(notebooks, + subscription_id, + resource_group, + workspace_name, + workspace_region, + cluster_name): + notebook_path = notebooks["entailment_bert_azureml"] + pm.execute_notebook(notebook_path, + OUTPUT_NOTEBOOK, + parameters={'DATA_PERCENT_USED': 0.0025, + "subscription_id": subscription_id, + "resource_group": resource_group, + "workspace_name": workspace_name, + "workspace_region": workspace_region, + "cluster_name": cluster_name}, + kernel_name=KERNEL_NAME,) + + with open("outputs/results.json", "r") as handle: + result_dict = json.load(handle) + assert result_dict["weighted avg"]["f1-score"] == pytest.approx(0.2, abs=ABS_TOL) + + if os.path.exists("outputs"): + shutil.rmtree("outputs") diff --git a/utils_nlp/dataset/xnli_torch_dataset.py b/utils_nlp/dataset/xnli_torch_dataset.py new file mode 100644 index 0000000..62b3c38 --- /dev/null +++ b/utils_nlp/dataset/xnli_torch_dataset.py @@ -0,0 +1,119 @@ +import numpy as np +import torch +from utils_nlp.models.bert.common import Language, Tokenizer +from torch.utils import data +from utils_nlp.dataset.xnli import load_pandas_df +from sklearn.preprocessing import LabelEncoder + +MAX_SEQ_LENGTH = 128 +TEXT_COL = "text" +LABEL_COL = "label" +DATA_PERCENT_USED = 1.0 +TRAIN_FILE_SPLIT = "train" +TEST_FILE_SPLIT = "test" +VALIDATION_FILE_SPLIT = "dev" +CACHE_DIR = "./" +LANGUAGE_ENGLISH = "en" +TO_LOWER_CASE = False +TOK_ENGLISH = Language.ENGLISH +VALID_FILE_SPLIT = [TRAIN_FILE_SPLIT, VALIDATION_FILE_SPLIT, TEST_FILE_SPLIT] + + +def _load_pandas_df(cache_dir, file_split, language, data_percent_used): + df = load_pandas_df(local_cache_path=cache_dir, file_split=file_split, language=language) + data_used_count = round(data_percent_used * df.shape[0]) + df = df.loc[:data_used_count] + return df + + +def _tokenize(tok_language, to_lowercase, cache_dir, df): + print("Create a tokenizer...") + tokenizer = Tokenizer(language=tok_language, to_lower=to_lowercase, cache_dir=cache_dir) + tokens = tokenizer.tokenize(df[TEXT_COL]) + + print("Tokenize and preprocess text...") + # tokenize + token_ids, input_mask, token_type_ids = tokenizer.preprocess_classification_tokens( + tokens, max_len=MAX_SEQ_LENGTH + ) + return token_ids, input_mask, token_type_ids + + +def _fit_train_labels(df): + label_encoder = LabelEncoder() + train_labels = label_encoder.fit_transform(df[LABEL_COL]) + train_labels = np.array(train_labels) + return label_encoder, train_labels + + +class XnliDataset(data.Dataset): + def __init__( + self, + file_split=TRAIN_FILE_SPLIT, + cache_dir=CACHE_DIR, + language=LANGUAGE_ENGLISH, + to_lowercase=TO_LOWER_CASE, + tok_language=TOK_ENGLISH, + data_percent_used=DATA_PERCENT_USED, + ): + """ + Load the dataset here + Args: + file_split (str, optional):The subset to load. 
+ One of: {"train", "dev", "test"} + Defaults to "train". + cache_dir (str, optional):Path to store the data. + Defaults to "./". + language(str):Language required to load which xnli file (eg - "en", "zh") + to_lowercase(bool):flag to convert samples in dataset to lowercase + tok_language(Language, optional): language (Language, optional): The pretrained model's language. + Defaults to Language.ENGLISH. + data_percent_used(float, optional): Data used to create Torch Dataset.Defaults to "1.0" which is 100% data + """ + if file_split not in VALID_FILE_SPLIT: + raise ValueError("The file split is not part of ", VALID_FILE_SPLIT) + + self.file_split = file_split + self.cache_dir = cache_dir + self.language = language + self.to_lowercase = to_lowercase + self.tok_language = tok_language + self.data_percent_used = data_percent_used + + df = _load_pandas_df(self.cache_dir, self.file_split, self.language, self.data_percent_used) + + self.df = df + + token_ids, input_mask, token_type_ids = _tokenize( + tok_language, to_lowercase, cache_dir, self.df + ) + + self.token_ids = token_ids + self.input_mask = input_mask + self.token_type_ids = token_type_ids + + if file_split == TRAIN_FILE_SPLIT: + label_encoder, train_labels = _fit_train_labels(self.df) + self.label_encoder = label_encoder + self.labels = train_labels + else: + # use the label_encoder passed when you create the test/validate dataset + self.labels = self.df[LABEL_COL] + + def __len__(self): + """ Denotes the total number of samples """ + return len(self.df) + + def __getitem__(self, index): + """ Generates one sample of data """ + token_ids = self.token_ids[index] + input_mask = self.input_mask[index] + token_type_ids = self.token_type_ids[index] + labels = self.labels[index] + + return { + "token_ids": torch.tensor(token_ids, dtype=torch.long), + "input_mask": torch.tensor(input_mask, dtype=torch.long), + "token_type_ids": torch.tensor(token_type_ids, dtype=torch.long), + "labels": labels, + } diff --git a/utils_nlp/models/bert/common.py b/utils_nlp/models/bert/common.py index f9db1e4..232357a 100644 --- a/utils_nlp/models/bert/common.py +++ b/utils_nlp/models/bert/common.py @@ -418,7 +418,8 @@ def create_data_loader( class TextDataset(Dataset): """ Characterizes a dataset for PyTorch which can be used to load a file containing multiple rows - where each row is a training example. + where each row is a training example. The format of each line in the file is assumed to be + tokens, mask and label. """ def __init__(self, filename): @@ -457,11 +458,13 @@ class TextDataset(Dataset): tokens = self._cast(row[0][1:-1].split(",")) mask = self._cast(row[1][1:-1].split(",")) - return ( - torch.tensor(tokens, dtype=torch.long), - torch.tensor(mask, dtype=torch.long), - torch.tensor(int(row[2]), dtype=torch.long), - ) + data = { + "token_ids": torch.tensor(tokens, dtype=torch.long), + "input_mask": torch.tensor(mask, dtype=torch.long), + "labels": torch.tensor(int(row[2]), dtype=torch.long), + } + + return data def get_dataset_multiple_files(files): diff --git a/utils_nlp/models/bert/sequence_classification_distributed.py b/utils_nlp/models/bert/sequence_classification_distributed.py index 8494b93..5c3faa0 100644 --- a/utils_nlp/models/bert/sequence_classification_distributed.py +++ b/utils_nlp/models/bert/sequence_classification_distributed.py @@ -1,44 +1,41 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. 
-import logging -import horovod.torch as hvd +# This script reuses some code from +# https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/examples/run_classifier.py + +import os +import warnings + import numpy as np import torch.nn as nn -from torch.utils.data import TensorDataset -import torch.utils.data.distributed +import torch.utils.data from pytorch_pretrained_bert.modeling import BertForSequenceClassification from pytorch_pretrained_bert.optimization import BertAdam from tqdm import tqdm +from utils_nlp.common.pytorch_utils import get_device, move_to_device from utils_nlp.models.bert.common import Language -from utils_nlp.models.bert.common import get_dataset_multiple_files -from utils_nlp.common.pytorch_utils import get_device, move_to_device - -logger = logging.getLogger(__name__) -hvd.init() -torch.manual_seed(42) - -if torch.cuda.is_available(): - # Horovod: pin GPU to local rank. - torch.cuda.set_device(hvd.local_rank()) - torch.cuda.manual_seed(42) +try: + import horovod.torch as hvd +except ImportError: + warnings.warn("No Horovod found! Can't do distributed training.") -class BERTSequenceDistClassifier: - """Distributed BERT-based sequence classifier""" +class BERTSequenceClassifier: + """BERT-based sequence classifier""" - def __init__(self, language=Language.ENGLISH, num_labels=2, cache_dir="."): - """Initializes the classifier and the underlying pretrained model. + def __init__( + self, language=Language.ENGLISH, num_labels=2, cache_dir=".", use_distributed=False + ): + + """ Args: - language (Language, optional): The pretrained model's language. - Defaults to Language.ENGLISH. - num_labels (int, optional): The number of unique labels in the - training data. Defaults to 2. - cache_dir (str, optional): Location of BERT's cache directory. - Defaults to ".". + language: Language passed to pre-trained BERT model to pick the appropriate model + num_labels: number of unique labels in train dataset + cache_dir: cache_dir to load pre-trained BERT model. Defaults to "." """ if num_labels < 2: raise ValueError("Number of labels should be at least 2.") @@ -46,280 +43,280 @@ class BERTSequenceDistClassifier: self.language = language self.num_labels = num_labels self.cache_dir = cache_dir - self.kwargs = ( - {"num_workers": 1, "pin_memory": True} - if torch.cuda.is_available() - else {} - ) + self.use_distributed = use_distributed # create classifier self.model = BertForSequenceClassification.from_pretrained( - language.value, num_labels=num_labels + language.value, cache_dir=cache_dir, num_labels=num_labels ) - def fit( - self, - token_ids, - input_mask, - labels, - token_type_ids=None, - input_files, - num_gpus=1, - num_epochs=1, - batch_size=32, - lr=2e-5, - warmup_proportion=None, - verbose=True, - fp16_allreduce=False, - ): - """fine-tunes the bert classifier using the given training data. - - args: - input_files(list, required): list of paths to the training data files. - token_ids (list): List of training token id lists. - input_mask (list): List of input mask lists. - labels (list): List of training labels. - token_type_ids (list, optional): List of lists. Each sublist - contains segment ids indicating if the token belongs to - the first sentence(0) or second sentence(1). Only needed - for two-sentence tasks. - num_gpus (int, optional): the number of gpus to use. - if none is specified, all available gpus - will be used. defaults to none. - num_epochs (int, optional): number of training epochs. - defaults to 1. 
- batch_size (int, optional): training batch size. defaults to 32. - lr (float): learning rate of the adam optimizer. defaults to 2e-5. - warmup_proportion (float, optional): proportion of training to - perform linear learning rate warmup for. e.g., 0.1 = 10% of - training. defaults to none. - verbose (bool, optional): if true, shows the training progress and - loss values. defaults to true. - fp16_allreduce(bool, optional)L if true, use fp16 compression during allreduce - """ - - if input_files is not None: - train_dataset = get_dataset_multiple_files(input_files) - else: - token_ids_tensor = torch.tensor(token_ids, dtype=torch.long) - input_mask_tensor = torch.tensor(input_mask, dtype=torch.long) - labels_tensor = torch.tensor(labels, dtype=torch.long) - - if token_type_ids: - token_type_ids_tensor = torch.tensor( - token_type_ids, dtype=torch.long - ) - train_dataset = TensorDataset( - token_ids_tensor, - input_mask_tensor, - token_type_ids_tensor, - labels_tensor, - ) - else: - train_dataset = TensorDataset( - token_ids_tensor, input_mask_tensor, labels_tensor - ) - - train_sampler = torch.utils.data.distributed.DistributedSampler( - train_dataset, num_replicas=hvd.size(), rank=hvd.rank() - ) - train_loader = torch.utils.data.DataLoader( - train_dataset, - batch_size=batch_size, - sampler=train_sampler, - **self.kwargs - ) - - device = get_device() - self.model.cuda() - - hvd.broadcast_parameters(self.model.state_dict(), root_rank=0) - # hvd.broadcast_optimizer_state(optimizer, root_rank=0) - - # define loss function - loss_func = nn.CrossEntropyLoss().to(device) - # define optimizer and model parameters param_optimizer = list(self.model.named_parameters()) no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"] optimizer_grouped_parameters = [ { - "params": [ - p - for n, p in param_optimizer - if not any(nd in n for nd in no_decay) - ], + "params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], "weight_decay": 0.01, }, - { - "params": [ - p - for n, p in param_optimizer - if any(nd in n for nd in no_decay) - ] - }, + {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)]}, ] + self.optimizer_params = optimizer_grouped_parameters + self.name_parameters = self.model.named_parameters() + self.state_dict = self.model.state_dict() - num_examples = len(train_dataset) - num_batches = int(num_examples / batch_size) - num_train_optimization_steps = num_batches * num_epochs + if use_distributed: + hvd.init() + if torch.cuda.is_available(): + torch.cuda.set_device(hvd.local_rank()) + else: + warnings.warn("No GPU available! Using CPU.") + + def create_optimizer( + self, num_train_optimization_steps, lr=2e-5, fp16_allreduce=False, warmup_proportion=None + ): + + """ + Method to create an BERT Optimizer based on the inputs from the user. + + Args: + num_train_optimization_steps(int): Number of optimization steps. + lr (float): learning rate of the adam optimizer. defaults to 2e-5. + warmup_proportion (float, optional): proportion of training to + perform linear learning rate warmup for. e.g., 0.1 = 10% of + training. defaults to none. + fp16_allreduce(bool, optional)L if true, use fp16 compression during allreduce + + Returns: + pytorch_pretrained_bert.optimization.BertAdam : A BertAdam optimizer with user + specified config. 
+ + """ + if self.use_distributed: + lr = lr * hvd.size() if warmup_proportion is None: - optimizer = BertAdam( - optimizer_grouped_parameters, lr=lr * hvd.size() - ) + optimizer = BertAdam(self.optimizer_params, lr=lr) else: optimizer = BertAdam( - optimizer_grouped_parameters, - lr=lr * hvd.size(), + self.optimizer_params, + lr=lr, t_total=num_train_optimization_steps, warmup=warmup_proportion, ) - # Horovod: (optional) compression algorithm. - compression = ( - hvd.Compression.fp16 if fp16_allreduce else hvd.Compression.none - ) + if self.use_distributed: + compression = hvd.Compression.fp16 if fp16_allreduce else hvd.Compression.none + optimizer = hvd.DistributedOptimizer( + optimizer, named_parameters=self.model.named_parameters(), compression=compression + ) - # Horovod: wrap optimizer with DistributedOptimizer. - optimizer = hvd.DistributedOptimizer( - optimizer, - named_parameters=self.model.named_parameters(), - compression=compression, - ) + return optimizer - # Horovod: set epoch to sampler for shuffling. - for epoch in range(num_epochs): - self.model.train() - train_sampler.set_epoch(epoch) - for batch_idx, batch in enumerate(train_loader): - - if token_type_ids: - x_batch, mask_batch, token_type_ids_batch, y_batch = tuple( - t.to(device) for t in batch - ) - else: - token_type_ids_batch = None - x_batch, mask_batch, y_batch = tuple( - t.to(device) for t in batch - ) - - optimizer.zero_grad() - - output = self.model( - input_ids=x_batch, attention_mask=mask_batch, labels=None - ) - - loss = loss_func(output, y_batch).mean() - loss.backward() - optimizer.step() - if verbose and (batch_idx % ((num_batches // 10) + 1)) == 0: - # Horovod: use train_sampler to determine the number of examples in - # this worker's partition. - print( - "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( - epoch, - batch_idx * len(x_batch), - len(train_sampler), - 100.0 * batch_idx / len(train_loader), - loss.item(), - ) - ) - - # empty cache - torch.cuda.empty_cache() - - def predict( - self, - input_files = None, - token_ids, - input_mask, - token_type_ids=None, - input_files, num_gpus=1, batch_size=32, probabilities=False - ): - """Scores the given set of train files and returns the predicted classes. + def create_data_loader(self, dataset, batch_size=32, mode="train", **kwargs): + """ + Method to create a data loader for a given Tensor dataset. Args: - input_files(list, required): list of paths to the test data files. - token_ids (list): List of training token lists. - input_mask (list): List of input mask lists. - token_type_ids (list, optional): List of lists. Each sublist - contains segment ids indicating if the token belongs to - the first sentence(0) or second sentence(1). Only needed - for two-sentence tasks. + mode(str): Mode for creating data loader. Could be train or test. + dataset(torch.utils.data.Dataset): A Tensor dataset. + batch_size(int): Batch size. + + Returns: + torch.utils.data.DataLoader: A torch data loader to the given dataset. + + """ + + if mode == "test": + sampler = torch.utils.data.sampler.SequentialSampler(dataset) + elif self.use_distributed: + sampler = torch.utils.data.distributed.DistributedSampler( + dataset, num_replicas=hvd.size(), rank=hvd.rank() + ) + else: + sampler = torch.utils.data.RandomSampler(dataset) + + data_loader = torch.utils.data.DataLoader( + dataset, batch_size=batch_size, sampler=sampler, **kwargs + ) + + return data_loader + + def save_model(self): + """ + Method to save the trained model. + #ToDo: Works for English Language now. 
Multiple language support needs to be added. + + """ + # Save the model to the outputs directory for capture + output_dir = "outputs" + os.makedirs(output_dir, exist_ok=True) + + # Save a trained model, configuration and tokenizer + model_to_save = self.model.module if hasattr(self.model, "module") else self.model + + # If we save using the predefined names, we can load using `from_pretrained` + output_model_file = "outputs/bert-large-uncased" + output_config_file = "outputs/bert_config.json" + + torch.save(model_to_save.state_dict(), output_model_file) + model_to_save.config.to_json_file(output_config_file) + + def fit( + self, + train_loader, + epoch, + bert_optimizer=None, + num_epochs=1, + num_gpus=0, + lr=2e-5, + warmup_proportion=None, + fp16_allreduce=False, + num_train_optimization_steps=10, + ): + """ + Method to fine-tune the bert classifier using the given training data + + Args: + train_loader(torch.DataLoader): Torch Dataloader created from Torch Dataset + epoch(int): Current epoch number of training. + bert_optimizer(optimizer): optimizer can be BERTAdam for local and Dsitributed if Horovod + num_epochs(int): the number of epochs to run + num_gpus(int): the number of gpus + lr (float): learning rate of the adam optimizer. defaults to 2e-5. + warmup_proportion (float, optional): proportion of training to + perform linear learning rate warmup for. e.g., 0.1 = 10% of + training. defaults to none. + fp16_allreduce(bool): if true, use fp16 compression during allreduce + num_train_optimization_steps: number of steps the optimizer should take. + """ + + device = get_device("cpu" if num_gpus == 0 else "gpu") + + if device: + self.model.cuda() + + if bert_optimizer is None: + bert_optimizer = self.create_optimizer( + num_train_optimization_steps=num_train_optimization_steps, + lr=lr, + warmup_proportion=warmup_proportion, + fp16_allreduce=fp16_allreduce, + ) + + if self.use_distributed: + hvd.broadcast_parameters(self.model.state_dict(), root_rank=0) + + loss_func = nn.CrossEntropyLoss().to(device) + + # train + self.model.train() # training mode + + token_type_ids_batch = None + + num_print = 1000 + for batch_idx, data in enumerate(train_loader): + + x_batch = data["token_ids"] + x_batch = x_batch.cuda() + + y_batch = data["labels"] + y_batch = y_batch.cuda() + + mask_batch = data["input_mask"] + mask_batch = mask_batch.cuda() + + if "token_type_ids" in data and data["token_type_ids"] is not None: + token_type_ids_batch = data["token_type_ids"] + token_type_ids_batch = token_type_ids_batch.cuda() + + bert_optimizer.zero_grad() + + y_h = self.model( + input_ids=x_batch, + token_type_ids=token_type_ids_batch, + attention_mask=mask_batch, + labels=None, + ) + + loss = loss_func(y_h, y_batch).mean() + loss.backward() + + bert_optimizer.synchronize() + bert_optimizer.step() + + if batch_idx % num_print == 0: + print( + "Train Epoch: {}/{} ({:.0f}%) \t Batch:{} \tLoss: {:.6f}".format( + epoch, + num_epochs, + 100.0 * batch_idx / len(train_loader), + batch_idx + 1, + loss.item(), + ) + ) + + del [x_batch, y_batch, mask_batch, token_type_ids_batch] + torch.cuda.empty_cache() + + def predict(self, test_loader, num_gpus=None, probabilities=False): + """ + + Method to predict the results on the test loader. Only evaluates for non distributed + workload on the head node in a distributed setup. + + Args: + test_loader(torch Dataloader): Torch Dataloader created from Torch Dataset num_gpus (int, optional): The number of gpus to use. If None is specified, all available GPUs will be used. 
Defaults to None. - batch_size (int, optional): Scoring batch size. Defaults to 32. probabilities (bool, optional): If True, the predicted probability distribution is also returned. Defaults to False. + Returns: 1darray, dict(1darray, 1darray, ndarray): Predicted classes and target labels or a dictionary with classes, target labels, probabilities) if probabilities is True. """ - - if input_files is not None: - test_dataset = get_dataset_multiple_files(input_files) - - else: - token_ids_tensor = torch.tensor(token_ids, dtype=torch.long) - input_mask_tensor = torch.tensor(input_mask, dtype=torch.long) - - if token_type_ids: - token_type_ids_tensor = torch.tensor( - token_type_ids, dtype=torch.long - ) - test_dataset = TensorDataset( - token_ids_tensor, input_mask_tensor, token_type_ids_tensor - ) - else: - test_dataset = TensorDataset(token_ids_tensor, input_mask_tensor) - - # Horovod: use DistributedSampler to partition the test data. - test_sampler = torch.utils.data.sampler.SequentialSampler(test_dataset) - - test_loader = torch.utils.data.DataLoader( - test_dataset, - batch_size=batch_size, - sampler=test_sampler, - **self.kwargs - ) - - device = get_device() + device = get_device("cpu" if num_gpus == 0 else "gpu") self.model = move_to_device(self.model, device, num_gpus) + + # score self.model.eval() + preds = [] - labels_test = [] + test_labels = [] + for i, data in enumerate(tqdm(test_loader, desc="Iteration")): + x_batch = data["token_ids"] + x_batch = x_batch.cuda() - with tqdm(total=len(test_loader)) as pbar: - for i, (tokens, mask, target) in enumerate(test_loader): - if torch.cuda.is_available(): - tokens, mask, target = ( - tokens.cuda(), - mask.cuda(), - target.cuda(), - ) + mask_batch = data["input_mask"] + mask_batch = mask_batch.cuda() - with torch.no_grad(): - p_batch = self.model( - input_ids=tokens, attention_mask=mask, labels=None - ) - preds.append(p_batch.cpu()) - labels_test.append(target.cpu()) - if i % batch_size == 0: - pbar.update(batch_size) + y_batch = data["labels"] + + token_type_ids_batch = None + if "token_type_ids" in data and data["token_type_ids"] is not None: + token_type_ids_batch = data["token_type_ids"] + token_type_ids_batch = token_type_ids_batch.cuda() + + with torch.no_grad(): + p_batch = self.model( + input_ids=x_batch, + token_type_ids=token_type_ids_batch, + attention_mask=mask_batch, + labels=None, + ) + preds.append(p_batch.cpu()) + test_labels.append(y_batch) preds = np.concatenate(preds) - labels_test = np.concatenate(labels_test) + test_labels = np.concatenate(test_labels) if probabilities: return { "Predictions": preds.argmax(axis=1), - "Target": labels_test, - "classes probabilities": nn.Softmax(dim=1)( - torch.Tensor(preds) - ).numpy(), + "Target": test_labels, + "classes probabilities": nn.Softmax(dim=1)(torch.Tensor(preds)).numpy(), } else: - return preds.argmax(axis=1), labels_test + return preds.argmax(axis=1), test_labels
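As a quick check of the artifacts that save_model() writes (outputs/bert-large-uncased and outputs/bert_config.json), the snippet below is a minimal sketch of reloading them locally with pytorch-pretrained-bert; it is not part of the patch. It assumes both files have been downloaded next to results.json, as in the text-classification notebook, and that num_labels matches the training run (3 for XNLI, 5 for the MNLI genre task).

import torch
from pytorch_pretrained_bert.modeling import BertConfig, BertForSequenceClassification

# Rebuild the model from the saved config, then load the fine-tuned weights on CPU.
config = BertConfig.from_json_file("outputs/bert_config.json")
model = BertForSequenceClassification(config, num_labels=3)
model.load_state_dict(torch.load("outputs/bert-large-uncased", map_location="cpu"))
model.eval()  # ready for local inference or inspection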