diff --git a/scenarios/text_classification/tc_mnli_bert.ipynb b/scenarios/text_classification/tc_mnli_bert.ipynb
new file mode 100644
index 0000000..1f7b3fa
--- /dev/null
+++ b/scenarios/text_classification/tc_mnli_bert.ipynb
@@ -0,0 +1,530 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "*Copyright (c) Microsoft Corporation. All rights reserved.*\n",
+ "\n",
+ "*Licensed under the MIT License.*\n",
+ "\n",
+ "# Text Classification of MultiNLI Sentences using BERT"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import sys\n",
+ "sys.path.append(\"../../\")\n",
+ "import os\n",
+ "import pandas as pd\n",
+ "from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score\n",
+ "from sklearn.preprocessing import LabelEncoder\n",
+ "from sklearn.model_selection import train_test_split\n",
+ "from utils_nlp.dataset.multinli import load_pandas_df\n",
+ "from utils_nlp.eval.classification import eval_classification\n",
+ "from utils_nlp.bert.sequence_classification import SequenceClassifier\n",
+ "from utils_nlp.bert.common import Language, Tokenizer\n",
+ "from utils_nlp.common.timer import Timer\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import numpy as np"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Introduction\n",
+ "In this notebook, we fine-tune and evaluate a pretrained [BERT](https://arxiv.org/abs/1810.04805) model on a subset of the [MultiNLI](https://www.nyu.edu/projects/bowman/multinli/) dataset.\n",
+ "\n",
+ "We use a [sequence classifier](../../utils_nlp/bert/sequence_classification.py) that wraps [Hugging Face's PyTorch implementation](https://github.com/huggingface/pytorch-pretrained-BERT) of Google's [BERT](https://github.com/google-research/bert)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "DATA_FOLDER = \"./temp\"\n",
+ "BERT_CACHE_DIR = \"./temp\"\n",
+ "LANGUAGE = Language.ENGLISH\n",
+ "TO_LOWER = True\n",
+ "MAX_LEN = 150\n",
+ "BATCH_SIZE = 32\n",
+ "NUM_GPUS = 2\n",
+ "NUM_EPOCHS = 1\n",
+ "TRAIN_SIZE = 0.6\n",
+ "LABEL_COL = \"genre\"\n",
+ "TEXT_COL = \"sentence1\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Read Dataset\n",
+ "We start by loading a subset of the data. The following function also downloads and extracts the files, if they don't exist in the data folder.\n",
+ "\n",
+ "The MultiNLI dataset is mainly used for natural language inference (NLI) tasks, where the inputs are sentence pairs and the labels are entailment indicators. The sentence pairs are also classified into *genres* that allow for more coverage and better evaluation of NLI models.\n",
+ "\n",
+ "For our classification task, we use the first sentence only as the text input, and the corresponding genre as the label. We select the examples corresponding to one of the entailment labels (*neutral* in this case) to avoid duplicate rows, as the sentences are not unique, whereas the sentence pairs are."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "df = load_pandas_df(DATA_FOLDER, \"train\")\n",
+ "df = df[df[\"gold_label\"]==\"neutral\"] # get unique sentences"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " genre | \n",
+ " sentence1 | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " government | \n",
+ " Conceptually cream skimming has two basic dime... | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " telephone | \n",
+ " yeah i tell you what though if you go price so... | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " travel | \n",
+ " But a few Christian mosaics survive above the ... | \n",
+ "
\n",
+ " \n",
+ " 12 | \n",
+ " slate | \n",
+ " It's not that the questions they asked weren't... | \n",
+ "
\n",
+ " \n",
+ " 13 | \n",
+ " travel | \n",
+ " Thebes held onto power until the 12th Dynasty,... | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " genre sentence1\n",
+ "0 government Conceptually cream skimming has two basic dime...\n",
+ "4 telephone yeah i tell you what though if you go price so...\n",
+ "6 travel But a few Christian mosaics survive above the ...\n",
+ "12 slate It's not that the questions they asked weren't...\n",
+ "13 travel Thebes held onto power until the 12th Dynasty,..."
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df[[LABEL_COL, TEXT_COL]].head()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The examples in the dataset are grouped into 5 genres:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "telephone 27783\n",
+ "government 25784\n",
+ "travel 25783\n",
+ "fiction 25782\n",
+ "slate 25768\n",
+ "Name: genre, dtype: int64"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df[LABEL_COL].value_counts()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We split the data for training and testing, and encode the class labels:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# split\n",
+ "df_train, df_test = train_test_split(df, train_size = TRAIN_SIZE, random_state=0)\n",
+ "\n",
+ "# encode labels\n",
+ "label_encoder = LabelEncoder()\n",
+ "labels_train = label_encoder.fit_transform(df_train[LABEL_COL])\n",
+ "labels_test = label_encoder.transform(df_test[LABEL_COL])\n",
+ "\n",
+ "num_labels = len(np.unique(labels_train))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of unique labels: 5\n",
+ "Number of training examples: 78540\n",
+ "Number of testing examples: 52360\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(\"Number of unique labels: {}\".format(num_labels))\n",
+ "print(\"Number of training examples: {}\".format(df_train.shape[0]))\n",
+ "print(\"Number of testing examples: {}\".format(df_test.shape[0]))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Tokenize and Preprocess"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Before training, we tokenize the text documents and convert them to lists of tokens. The following steps instantiate a BERT tokenizer given the language, and tokenize the text of the training and testing sets."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tokenizer = Tokenizer(LANGUAGE, to_lower=TO_LOWER, cache_dir=BERT_CACHE_DIR)\n",
+ "\n",
+ "tokens_train = tokenizer.tokenize(df_train[TEXT_COL])\n",
+ "tokens_test = tokenizer.tokenize(df_test[TEXT_COL])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In addition, we perform the following preprocessing steps in the cell below:\n",
+ "- Convert the tokens into token indices corresponding to the BERT tokenizer's vocabulary\n",
+ "- Add sentence markers\n",
+ "- Pad or truncate the token lists to the specified max length\n",
+ "- Return mask lists that indicate paddings' positions\n",
+ "\n",
+ "*See the original [implementation](https://github.com/google-research/bert/blob/master/run_classifier.py) for more information on BERT's input format.*"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tokens_train, mask_train = tokenizer.preprocess_classification_tokens(\n",
+ " tokens_train, MAX_LEN\n",
+ ")\n",
+ "tokens_test, mask_test = tokenizer.preprocess_classification_tokens(\n",
+ " tokens_test, MAX_LEN\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Create Model\n",
+ "Next, we create a sequence classifier that loads a pre-trained BERT model, given the language and number of labels."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "classifier = SequenceClassifier(\n",
+ " language=Language.ENGLISH, num_labels=num_labels, cache_dir=BERT_CACHE_DIR\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Train\n",
+ "We train the classifier using the training examples. This involves fine-tuning the BERT Transformer and learning a linear classification layer on top of that:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "t_total value of -1 results in schedule not being applied\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "epoch:1/1; batch:1->246/2454; loss:1.631739\n",
+ "epoch:1/1; batch:247->492/2454; loss:0.427608\n",
+ "epoch:1/1; batch:493->738/2454; loss:0.255493\n",
+ "epoch:1/1; batch:739->984/2454; loss:0.286230\n",
+ "epoch:1/1; batch:985->1230/2454; loss:0.375268\n",
+ "epoch:1/1; batch:1231->1476/2454; loss:0.146290\n",
+ "epoch:1/1; batch:1477->1722/2454; loss:0.092100\n",
+ "epoch:1/1; batch:1723->1968/2454; loss:0.009405\n",
+ "epoch:1/1; batch:1969->2214/2454; loss:0.038235\n",
+ "epoch:1/1; batch:2215->2460/2454; loss:0.098216\n",
+ "[Training time: 0.981 hrs]\n"
+ ]
+ }
+ ],
+ "source": [
+ "with Timer() as t:\n",
+ " classifier.fit(\n",
+ " token_ids=tokens_train,\n",
+ " input_mask=mask_train,\n",
+ " labels=labels_train, \n",
+ " num_gpus=NUM_GPUS, \n",
+ " num_epochs=NUM_EPOCHS,\n",
+ " batch_size=BATCH_SIZE, \n",
+ " verbose=True,\n",
+ " ) \n",
+ "print(\"[Training time: {:.3f} hrs]\".format(t.interval / 3600))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Score\n",
+ "We score the test set using the trained classifier:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "52384it [11:54, 88.50it/s] \n"
+ ]
+ }
+ ],
+ "source": [
+ "preds = classifier.predict(\n",
+ " token_ids=tokens_test, input_mask=mask_test, num_gpus=NUM_GPUS, batch_size=BATCH_SIZE\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Evaluate Results\n",
+ "Finally, we compute the accuracy, precision, recall, and F1 metrics of the evaluation on the test set."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ " accuracy: 0.938273\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " label | \n",
+ " precision | \n",
+ " recall | \n",
+ " f1 | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " fiction | \n",
+ " 0.917004 | \n",
+ " 0.925839 | \n",
+ " 0.921401 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " government | \n",
+ " 0.961477 | \n",
+ " 0.928780 | \n",
+ " 0.944845 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " slate | \n",
+ " 0.875161 | \n",
+ " 0.861535 | \n",
+ " 0.868295 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " telephone | \n",
+ " 0.989105 | \n",
+ " 0.996609 | \n",
+ " 0.992843 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " travel | \n",
+ " 0.943405 | \n",
+ " 0.973232 | \n",
+ " 0.958087 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " label precision recall f1\n",
+ "0 fiction 0.917004 0.925839 0.921401\n",
+ "1 government 0.961477 0.928780 0.944845\n",
+ "2 slate 0.875161 0.861535 0.868295\n",
+ "3 telephone 0.989105 0.996609 0.992843\n",
+ "4 travel 0.943405 0.973232 0.958087"
+ ]
+ },
+ "execution_count": 21,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "accuracy = accuracy_score(labels_test, preds)\n",
+ "precision = precision_score(labels_test, preds, average=None)\n",
+ "recall = recall_score(labels_test, preds, average=None)\n",
+ "f1 = f1_score(labels_test, preds, average=None)\n",
+ "\n",
+ "print(\"\\n accuracy: {:.6f}\".format(accuracy))\n",
+ "pd.DataFrame({\"label\": label_encoder.classes_, \"precision\": precision, \"recall\": recall, \"f1\": f1})"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.8"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/utils_nlp/dataset/imdb.py b/utils_nlp/dataset/imdb.py
deleted file mode 100644
index 9e29b3f..0000000
--- a/utils_nlp/dataset/imdb.py
+++ /dev/null
@@ -1,29 +0,0 @@
-# Copyright (c) Microsoft Corporation. All rights reserved.
-# Licensed under the MIT License.
-
-"""IMDB dataset utils"""
-
-import os
-import pandas as pd
-from utils_nlp.dataset.url_utils import maybe_download, extract_tar
-
-
-URL = "https://s3.amazonaws.com/fast-ai-nlp/imdb.tgz"
-
-
-def download(dir_path):
- """Downloads and extracts the dataset files"""
- file_name = URL.split("/")[-1]
- maybe_download(URL, file_name, dir_path)
- extract_tar(os.path.join(dir_path, file_name))
-
-
-def get_df(dir_path, label):
- """Returns a pandas dataframe given a path,
- and appends the provided label"""
- text = []
- for doc_file in os.listdir(dir_path):
- with open(os.path.join(dir_path, doc_file)) as f:
- text.append(f.read())
- labels = [label] * len(text)
- return pd.DataFrame({"text": text, "label": labels})
diff --git a/utils_nlp/dataset/multinli.py b/utils_nlp/dataset/multinli.py
new file mode 100644
index 0000000..ae5f400
--- /dev/null
+++ b/utils_nlp/dataset/multinli.py
@@ -0,0 +1,45 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""MultiNLI dataset utils
+https://www.nyu.edu/projects/bowman/multinli/
+"""
+
+import os
+import pandas as pd
+import requests
+from utils_nlp.dataset.url_utils import extract_zip, maybe_download
+
+URL = "http://www.nyu.edu/projects/bowman/multinli/multinli_1.0.zip"
+DATA_FILES = {
+ "train": "multinli_1.0/multinli_1.0_train.jsonl",
+ "dev_matched": "multinli_1.0/multinli_1.0_dev_matched.jsonl",
+ "dev_mismatched": "multinli_1.0/multinli_1.0_dev_mismatched.jsonl",
+}
+
+
+def load_pandas_df(local_cache_path=None, file_split="train"):
+ """Downloads and extracts the dataset files
+ Args:
+ local_cache_path ([type], optional): [description]. Defaults to None.
+ file_split (str, optional): The subset to load.
+ One of: {"train", "dev_matched", "dev_mismatched"}
+ Defaults to "train".
+ Returns:
+ pd.DataFrame: pandas DataFrame containing the specified MultiNLI subset.
+ """
+
+ file_name = URL.split("/")[-1]
+ if not os.path.exists(os.path.join(local_cache_path, file_name)):
+ response = requests.get(URL)
+ with open(os.path.join(local_cache_path, file_name), "wb") as f:
+ f.write(response.content)
+ if not os.path.exists(
+ os.path.join(local_cache_path, DATA_FILES[file_split])
+ ):
+ extract_zip(
+ os.path.join(local_cache_path, file_name), local_cache_path
+ )
+ return pd.read_json(
+ os.path.join(local_cache_path, DATA_FILES[file_split]), lines=True
+ )
diff --git a/utils_nlp/dataset/url_utils.py b/utils_nlp/dataset/url_utils.py
index 892cbad..2120023 100644
--- a/utils_nlp/dataset/url_utils.py
+++ b/utils_nlp/dataset/url_utils.py
@@ -4,6 +4,7 @@
import os
from urllib.request import urlretrieve
import tarfile
+from zipfile import ZipFile
import logging
from contextlib import contextmanager
from tempfile import TemporaryDirectory
@@ -60,19 +61,32 @@ def maybe_download(
return filepath
-def extract_tar(file_path, extract_to="."):
+def extract_tar(file_path, dest_path="."):
"""Extracts all contents of a tar archive file.
Args:
file_path (str): Path of file to extract.
- extract_to (str, optional): Destination directory. Defaults to ".".
+ dest_path (str, optional): Destination directory. Defaults to ".".
"""
if not os.path.exists(file_path):
raise IOError("File doesn't exist")
- if not os.path.exists(extract_to):
+ if not os.path.exists(dest_path):
raise IOError("Destination directory doesn't exist")
- tar = tarfile.open(file_path)
- tar.extractall(path=extract_to)
- tar.close()
+ with tarfile.open(file_path) as t:
+ t.extractall(path=dest_path)
+
+
+def extract_zip(file_path, dest_path="."):
+ """Extracts all contents of a zip archive file.
+ Args:
+ file_path (str): Path of file to extract.
+ dest_path (str, optional): Destination directory. Defaults to ".".
+ """
+ if not os.path.exists(file_path):
+ raise IOError("File doesn't exist")
+ if not os.path.exists(dest_path):
+ raise IOError("Destination directory doesn't exist")
+ with ZipFile(file_path) as z:
+ z.extractall(path=dest_path)
@contextmanager
diff --git a/utils_nlp/dataset/xnli.py b/utils_nlp/dataset/xnli.py
new file mode 100644
index 0000000..37cb550
--- /dev/null
+++ b/utils_nlp/dataset/xnli.py
@@ -0,0 +1,44 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""XNLI dataset utils
+https://www.nyu.edu/projects/bowman/xnli/
+"""
+
+import os
+import pandas as pd
+import requests
+from utils_nlp.dataset.url_utils import extract_zip, maybe_download
+
+
+URL = "https://www.nyu.edu/projects/bowman/xnli/XNLI-1.0.zip"
+
+DATA_FILES = {"dev": "XNLI-1.0/xnli.dev.jsonl", "test": "XNLI-1.0/xnli.test.jsonl"}
+
+
+def load_pandas_df(local_cache_path=None, file_split="train"):
+ """Downloads and extracts the dataset files
+ Args:
+ local_cache_path ([type], optional): [description]. Defaults to None.
+ file_split (str, optional): The subset to load.
+ One of: {"dev", "test"}
+ Defaults to "train".
+ Returns:
+ pd.DataFrame: pandas DataFrame containing the specified XNLI subset.
+ """
+
+ file_name = URL.split("/")[-1]
+ if not os.path.exists(os.path.join(local_cache_path, file_name)):
+ response = requests.get(URL)
+ with open(os.path.join(local_cache_path, file_name), "wb") as f:
+ f.write(response.content)
+ if not os.path.exists(
+ os.path.join(local_cache_path, DATA_FILES[file_split])
+ ):
+ extract_zip(
+ os.path.join(local_cache_path, file_name), local_cache_path
+ )
+ return pd.read_json(
+ os.path.join(local_cache_path, DATA_FILES[file_split]), lines=True
+ )
+