{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Multi-lingual Inference on XNLI Dataset using BERT"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Summary\n",
"In this notebook, we demonstrate how to use the [Multi-lingual BERT model](https://github.com/google-research/bert/blob/master/multilingual.md) for natural language inference in Chinese and Hindi. We use the [XNLI](https://github.com/facebookresearch/XNLI) dataset, where the task is to classify sentence pairs into three classes: contradiction, entailment, and neutral. \n",
"The figure below shows how [BERT](https://arxiv.org/abs/1810.04805) classifies sentence pairs. It concatenates the tokens of the two sentences in each pair and separates the sentences with the [SEP] token. A [CLS] token is prepended to the token list, and its final hidden state is used as the aggregate sequence representation for the classification task.\n",
"<img src=\"https://nlpbp.blob.core.windows.net/images/bert_two_sentence.PNG\">"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"import sys\n",
"import os\n",
"import random\n",
"import numpy as np\n",
"from sklearn.metrics import classification_report\n",
"from sklearn.preprocessing import LabelEncoder\n",
"\n",
"import torch\n",
"\n",
"nlp_path = os.path.abspath('../../')\n",
"if nlp_path not in sys.path:\n",
" sys.path.insert(0, nlp_path)\n",
"\n",
"from utils_nlp.models.bert.sequence_classification import BERTSequenceClassifier\n",
"from utils_nlp.models.bert.common import Language, Tokenizer\n",
"from utils_nlp.dataset.xnli import load_pandas_df\n",
"from utils_nlp.common.timer import Timer"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Configurations\n",
"Note that the running times shown in this notebook were measured on a Standard_NC12 Azure Deep Learning Virtual Machine with two NVIDIA Tesla K80 GPUs. To run through the notebook quickly, set `TRAIN_DATA_USED_PERCENT` to a small number, e.g. 0.01. "
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"TRAIN_DATA_USED_PERCENT = 1.0\n",
"\n",
"# set random seeds\n",
"RANDOM_SEED = 42\n",
"random.seed(RANDOM_SEED)\n",
"np.random.seed(RANDOM_SEED)\n",
"torch.manual_seed(RANDOM_SEED)\n",
"num_cuda_devices = torch.cuda.device_count()\n",
"if num_cuda_devices > 0:\n",
"    torch.cuda.manual_seed_all(RANDOM_SEED)\n",
"\n",
"# model configurations\n",
"LANGUAGE_CHINESE = Language.CHINESE\n",
"LANGUAGE_MULTI = Language.MULTILINGUAL\n",
"TO_LOWER = True\n",
"MAX_SEQ_LENGTH = 128\n",
"\n",
"# training configurations\n",
"NUM_GPUS = 2\n",
"BATCH_SIZE = 32\n",
"NUM_EPOCHS = 2\n",
"\n",
"# optimizer configurations\n",
"LEARNING_RATE = 5e-5\n",
"WARMUP_PROPORTION = 0.1\n",
"\n",
"# data configurations\n",
"TEXT_COL = \"text\"\n",
"LABEL_COL = \"label\"\n",
"\n",
"CACHE_DIR = \"./temp\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load Data\n",
"The XNLI dataset comes in two zip files: \n",
"* XNLI-1.0.zip: dev and test datasets in 15 languages. The original English data was translated into the other languages by human translators. \n",
"* XNLI-MT-1.0.zip: training dataset in 15 languages, consisting of machine translations of the [MultiNLI](https://www.nyu.edu/projects/bowman/multinli/) dataset. It also contains English translations of the dev and test datasets, which are not used in this notebook. \n",
"\n",
"The `load_pandas_df` function downloads and extracts the zip files if they don't already exist in `local_cache_path` and returns the data subset specified by `file_split` and `language`."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"train_df_chinese = load_pandas_df(local_cache_path=CACHE_DIR, file_split=\"train\", language=\"zh\")\n",
"dev_df_chinese = load_pandas_df(local_cache_path=CACHE_DIR, file_split=\"dev\", language=\"zh\")\n",
"test_df_chinese = load_pandas_df(local_cache_path=CACHE_DIR, file_split=\"test\", language=\"zh\")\n",
"\n",
"train_df_hindi = load_pandas_df(local_cache_path=CACHE_DIR, file_split=\"train\", language=\"hi\")\n",
"dev_df_hindi = load_pandas_df(local_cache_path=CACHE_DIR, file_split=\"dev\", language=\"hi\")\n",
"test_df_hindi = load_pandas_df(local_cache_path=CACHE_DIR, file_split=\"test\", language=\"hi\")"
]
},
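{
"cell_type": "markdown",
"metadata": {},
"source": [
"The download-and-cache behavior of `load_pandas_df` described above follows a common pattern: fetch an archive only if it is not already in the cache directory, and extract it only if the extracted files are missing. The cell below is a minimal sketch of that pattern using only the Python standard library, assuming that the presence of one marker file inside the extracted tree means extraction already completed. `maybe_download` and `extract_if_missing` are hypothetical helpers, not part of `utils_nlp`, and the actual implementation of `load_pandas_df` may differ. The usage example builds a tiny local zip file instead of downloading XNLI."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import tempfile\n",
"import urllib.request\n",
"import zipfile\n",
"\n",
"\n",
"def maybe_download(url, zip_path):\n",
"    # Hypothetical helper: fetch the archive only if it is not cached locally.\n",
"    if not os.path.exists(zip_path):\n",
"        urllib.request.urlretrieve(url, zip_path)\n",
"    return zip_path\n",
"\n",
"\n",
"def extract_if_missing(zip_path, extract_dir, marker):\n",
"    # Hypothetical helper: extract unless `marker` (a relative path inside the\n",
"    # extracted tree) already exists. Returns True if extraction happened.\n",
"    if os.path.exists(os.path.join(extract_dir, marker)):\n",
"        return False  # cached copy found, skip extraction\n",
"    os.makedirs(extract_dir, exist_ok=True)\n",
"    with zipfile.ZipFile(zip_path) as zf:\n",
"        zf.extractall(extract_dir)\n",
"    return True\n",
"\n",
"\n",
"# Usage with a tiny local archive (no network access needed).\n",
"demo_dir = tempfile.mkdtemp()\n",
"demo_zip = os.path.join(demo_dir, 'demo.zip')\n",
"with zipfile.ZipFile(demo_zip, 'w') as zf:\n",
"    zf.writestr('demo/data.txt', 'placeholder')\n",
"demo_out = os.path.join(demo_dir, 'out')\n",
"first_call = extract_if_missing(demo_zip, demo_out, 'demo/data.txt')\n",
"second_call = extract_if_missing(demo_zip, demo_out, 'demo/data.txt')\n",
"print(first_call, second_call)  # True False"
]
},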
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Chinese training dataset size: 392702\n",
"Chinese dev dataset size: 2490\n",
"Chinese test dataset size: 5010\n",
"\n",
"Hindi training dataset size: 392702\n",
"Hindi dev dataset size: 2490\n",
"Hindi test dataset size: 5010\n",
"\n",
" text label\n",
"0 (从 概念 上 看 , 奶油 收入 有 两 个 基本 方面 产品 和 地理 ., 产品 和 ... neutral\n",
"1 (你 知道 在 这个 季节 , 我 猜 在 你 的 水平 你 把 他们 丢到 下 一个 水平... entailment\n",
"2 (我们 的 一个 号码 会 非常 详细 地 执行 你 的 指示, 我 团队 的 一个 成员 ... entailment\n",
"3 (你 怎么 知道 的 ? 所有 这些 都 是 他们 的 信息 ., 这些 信息 属于 他们 .) entailment\n",
"4 (是 啊 , 我 告诉 你 , 如果 你 去 买 一些 网球鞋 , 我 可以 看到 为什么 ... neutral\n",
" text label\n",
"0 (Conceptually क ् रीम एंजलिस में दो मूल आयाम ह... neutral\n",
"1 (आप मौसम के दौरान जानते हैं और मैं अपने स ् तर... entailment\n",
"2 (हमारे एक नंबर में से एक आपके निर ् देशों को म... entailment\n",
"3 (आप कैसे जानते हैं ? ये सब उनकी जानकारी फिर से... entailment\n",
"4 (हाँ मैं आपको बताता हूँ कि अगर आप उन टेनिस जूत... neutral\n"
]
}
],
"source": [
"print(\"Chinese training dataset size: {}\".format(train_df_chinese.shape[0]))\n",
"print(\"Chinese dev dataset size: {}\".format(dev_df_chinese.shape[0]))\n",
"print(\"Chinese test dataset size: {}\".format(test_df_chinese.shape[0]))\n",
"print()\n",
"print(\"Hindi training dataset size: {}\".format(train_df_hindi.shape[0]))\n",
"print(\"Hindi dev dataset size: {}\".format(dev_df_hindi.shape[0]))\n",
"print(\"Hindi test dataset size: {}\".format(test_df_hindi.shape[0]))\n",
"print()\n",
"print(train_df_chinese.head())\n",
"print(train_df_hindi.head())"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"train_data_used_count = round(TRAIN_DATA_USED_PERCENT * train_df_chinese.shape[0])\n",
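"# iloc is position-based and end-exclusive; loc[:n] is label-based and would also include row n\n",
"train_df_chinese = train_df_chinese.iloc[:train_data_used_count]\n",
"train_df_hindi = train_df_hindi.iloc[:train_data_used_count]"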
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Language Inference on Chinese\n",
"For the Chinese dataset, we use the `bert-base-chinese` model, which was pretrained on Chinese text only. The `bert-base-multilingual-cased` model can also be used on Chinese, but its accuracy on this task is 3% lower."
]
},
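{
"cell_type": "markdown",
"metadata": {},
"source": [
"The Chinese text shown earlier is already segmented into space-separated words, but in the original BERT implementation the basic tokenizer also surrounds every CJK character with whitespace before applying WordPiece, so Chinese is effectively tokenized character by character. The cell below sketches that step, assuming the `Tokenizer` used here behaves like the original BERT tokenizer in this respect. `space_cjk_chars` is an illustrative helper, and for brevity it checks only two CJK Unicode blocks, whereas BERT's own check covers several more ranges."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def is_cjk_char(ch):\n",
"    # Only two of the CJK ranges BERT checks: Unified Ideographs and Extension A.\n",
"    cp = ord(ch)\n",
"    return 0x4E00 <= cp <= 0x9FFF or 0x3400 <= cp <= 0x4DBF\n",
"\n",
"\n",
"def space_cjk_chars(text):\n",
"    # Surround every CJK character with spaces, then split on whitespace.\n",
"    spaced = ''.join(' ' + ch + ' ' if is_cjk_char(ch) else ch for ch in text)\n",
"    return spaced.split()\n",
"\n",
"\n",
"print(space_cjk_chars('奶油 收入'))  # ['奶', '油', '收', '入']\n",
"print(space_cjk_chars('BERT模型'))  # ['BERT', '模', '型']"
]
},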
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Tokenize and Preprocess\n",
"Before training, we tokenize the sentence texts and convert them to lists of tokens. The following cell instantiates a BERT tokenizer for the given language and tokenizes the text of the training and test sets."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 392702/392702 [02:26<00:00, 2682.67it/s]\n",
"100%|██████████| 5010/5010 [00:01<00:00, 3122.04it/s]\n"
]
}
],
"source": [
"tokenizer_chinese = Tokenizer(LANGUAGE_CHINESE, to_lower=TO_LOWER, cache_dir=CACHE_DIR)\n",
"\n",
"train_tokens_chinese = tokenizer_chinese.tokenize(train_df_chinese[TEXT_COL])\n",
"test_tokens_chinese = tokenizer_chinese.tokenize(test_df_chinese[TEXT_COL])"
]
},
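{
"cell_type": "markdown",
"metadata": {},
"source": [
"BERT tokenizers first split text on whitespace and punctuation, then break each word into WordPiece subword units with a greedy longest-match-first search over the model vocabulary. Pieces that continue a word are prefixed with `##`, and a word that cannot be split at all becomes `[UNK]`. The cell below is a simplified sketch of that search with a made-up toy vocabulary (`toy_vocab` and `wordpiece` are illustrative names). The real tokenizer additionally handles lowercasing, punctuation, and a maximum word length."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def wordpiece(word, vocab, unk='[UNK]'):\n",
"    # Greedy longest-match-first split of one word into subword pieces.\n",
"    pieces = []\n",
"    start = 0\n",
"    while start < len(word):\n",
"        end = len(word)\n",
"        match = None\n",
"        while start < end:\n",
"            candidate = word[start:end]\n",
"            if start > 0:\n",
"                candidate = '##' + candidate  # continuation pieces carry the ## prefix\n",
"            if candidate in vocab:\n",
"                match = candidate\n",
"                break\n",
"            end -= 1  # shrink the candidate from the right\n",
"        if match is None:\n",
"            return [unk]  # no piece matches: the whole word becomes [UNK]\n",
"        pieces.append(match)\n",
"        start = end\n",
"    return pieces\n",
"\n",
"\n",
"toy_vocab = {'un', '##aff', '##able', 'play', '##ing'}  # made-up vocabulary\n",
"print(wordpiece('unaffable', toy_vocab))  # ['un', '##aff', '##able']\n",
"print(wordpiece('playing', toy_vocab))    # ['play', '##ing']\n",
"print(wordpiece('xyz', toy_vocab))        # ['[UNK]']"
]
},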
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In addition, we perform the following preprocessing steps in the cell below:\n",
"\n",
"* Convert the tokens into token indices corresponding to the BERT tokenizer's vocabulary\n",
"* Add the special token [CLS] at the beginning of each sequence and a [SEP] token at the end of each sentence\n",
"* Pad or truncate the token lists to the specified max length\n",
"* Return mask lists that indicate the positions of padding tokens\n",
"* Return token type id lists that indicate which sentence the tokens belong to\n",
"\n",
"*See the original [implementation](https://github.com/google-research/bert/blob/master/run_classifier.py) for more information on BERT's input format.*"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"train_token_ids_chinese, train_input_mask_chinese, train_token_type_ids_chinese = \\\n",
" tokenizer_chinese.preprocess_classification_tokens(train_tokens_chinese, max_len=MAX_SEQ_LENGTH)\n",
"test_token_ids_chinese, test_input_mask_chinese, test_token_type_ids_chinese = \\\n",
" tokenizer_chinese.preprocess_classification_tokens(test_tokens_chinese, max_len=MAX_SEQ_LENGTH)"
]
},
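{
"cell_type": "markdown",
"metadata": {},
"source": [
"The preprocessing steps listed above can be sketched in plain Python for a single sentence pair. This is a simplified illustration, not the `utils_nlp` implementation: `encode_pair` and `truncate_pair` are hypothetical helpers, the token-to-id mapping is a made-up toy vocabulary, and the truncation rule (repeatedly dropping the last token of the longer sentence) follows `_truncate_seq_pair` in BERT's original `run_classifier.py`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def truncate_pair(tokens_a, tokens_b, max_tokens):\n",
"    # Drop tokens from the end of the longer list until the pair fits.\n",
"    while len(tokens_a) + len(tokens_b) > max_tokens:\n",
"        if len(tokens_a) > len(tokens_b):\n",
"            tokens_a.pop()\n",
"        else:\n",
"            tokens_b.pop()\n",
"\n",
"\n",
"def encode_pair(tokens_a, tokens_b, vocab, max_len):\n",
"    tokens_a, tokens_b = list(tokens_a), list(tokens_b)  # copy; don't mutate the inputs\n",
"    truncate_pair(tokens_a, tokens_b, max_len - 3)  # reserve room for [CLS] and two [SEP]\n",
"    tokens = ['[CLS]'] + tokens_a + ['[SEP]'] + tokens_b + ['[SEP]']\n",
"    # Segment ids: 0 for [CLS], sentence A and its [SEP]; 1 for sentence B and the final [SEP].\n",
"    token_type_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)\n",
"    token_ids = [vocab.get(tok, vocab['[UNK]']) for tok in tokens]\n",
"    input_mask = [1] * len(token_ids)  # 1 marks a real token\n",
"    padding = max_len - len(token_ids)\n",
"    token_ids += [vocab['[PAD]']] * padding\n",
"    input_mask += [0] * padding  # 0 marks padding\n",
"    token_type_ids += [0] * padding\n",
"    return token_ids, input_mask, token_type_ids\n",
"\n",
"\n",
"# Toy vocabulary; the ids are illustrative only.\n",
"toy_ids = {'[PAD]': 0, '[UNK]': 100, '[CLS]': 101, '[SEP]': 102, '我': 1, '们': 2, '好': 3}\n",
"demo_ids, demo_mask, demo_types = encode_pair(['我', '们'], ['好', '吗'], toy_ids, max_len=10)\n",
"print(demo_ids)    # [101, 1, 2, 102, 3, 100, 102, 0, 0, 0]\n",
"print(demo_mask)   # [1, 1, 1, 1, 1, 1, 1, 0, 0, 0]\n",
"print(demo_types)  # [0, 0, 0, 0, 1, 1, 1, 0, 0, 0]"
]
},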
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"label_encoder_chinese = LabelEncoder()\n",
"train_labels_chinese = label_encoder_chinese.fit_transform(train_df_chinese[LABEL_COL])\n",
"num_labels_chinese = len(np.unique(train_labels_chinese))"
]
},
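{
"cell_type": "markdown",
"metadata": {},
"source": [
"`LabelEncoder` maps each class name to an integer id, assigning ids in sorted order of the class names, and `inverse_transform` maps predicted ids back to names before the classification report is computed. The cell below is a minimal pure-Python sketch of that behavior; `SimpleLabelEncoder` is a hypothetical stand-in for illustration, not a replacement for the scikit-learn class."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class SimpleLabelEncoder:\n",
"    def fit_transform(self, labels):\n",
"        # Like sklearn's LabelEncoder, classes are stored in sorted order.\n",
"        self.classes_ = sorted(set(labels))\n",
"        self._index = {name: i for i, name in enumerate(self.classes_)}\n",
"        return [self._index[name] for name in labels]\n",
"\n",
"    def inverse_transform(self, ids):\n",
"        # Map integer ids back to the original class names.\n",
"        return [self.classes_[i] for i in ids]\n",
"\n",
"\n",
"demo_encoder = SimpleLabelEncoder()\n",
"demo_codes = demo_encoder.fit_transform(['neutral', 'entailment', 'contradiction', 'entailment'])\n",
"print(demo_codes)  # [2, 1, 0, 1]\n",
"print(demo_encoder.inverse_transform([0, 2]))  # ['contradiction', 'neutral']"
]
},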
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Create Classifier"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"classifier_chinese = BERTSequenceClassifier(language=LANGUAGE_CHINESE,\n",
" num_labels=num_labels_chinese,\n",
" cache_dir=CACHE_DIR)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Train Classifier"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch:1/2; batch:1->1228/12271; loss:1.194384\n",
"epoch:1/2; batch:1229->2456/12271; loss:0.863067\n",
"epoch:1/2; batch:2457->3684/12271; loss:0.781256\n",
"epoch:1/2; batch:3685->4912/12271; loss:1.067413\n",
"epoch:1/2; batch:4913->6140/12271; loss:0.599279\n",
"epoch:1/2; batch:6141->7368/12271; loss:0.471488\n",
"epoch:1/2; batch:7369->8596/12271; loss:0.572327\n",
"epoch:1/2; batch:8597->9824/12271; loss:0.689093\n",
"epoch:1/2; batch:9825->11052/12271; loss:0.651702\n",
"epoch:1/2; batch:11053->12271/12271; loss:0.431085\n",
"epoch:2/2; batch:1->1228/12271; loss:0.255859\n",
"epoch:2/2; batch:1229->2456/12271; loss:0.434052\n",
"epoch:2/2; batch:2457->3684/12271; loss:0.433569\n",
"epoch:2/2; batch:3685->4912/12271; loss:0.405915\n",
"epoch:2/2; batch:4913->6140/12271; loss:0.636128\n",
"epoch:2/2; batch:6141->7368/12271; loss:0.416685\n",
"epoch:2/2; batch:7369->8596/12271; loss:0.265789\n",
"epoch:2/2; batch:8597->9824/12271; loss:0.328964\n",
"epoch:2/2; batch:9825->11052/12271; loss:0.436310\n",
"epoch:2/2; batch:11053->12271/12271; loss:0.374193\n",
"Training time : 8.050 hrs\n"
]
}
],
"source": [
"with Timer() as t:\n",
" classifier_chinese.fit(token_ids=train_token_ids_chinese,\n",
" input_mask=train_input_mask_chinese,\n",
" token_type_ids=train_token_type_ids_chinese,\n",
" labels=train_labels_chinese,\n",
" num_gpus=NUM_GPUS,\n",
" num_epochs=NUM_EPOCHS,\n",
" batch_size=BATCH_SIZE,\n",
" lr=LEARNING_RATE,\n",
" warmup_proportion=WARMUP_PROPORTION)\n",
"print(\"Training time : {:.3f} hrs\".format(t.interval / 3600))"
]
},
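{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `lr` and `warmup_proportion` arguments control the learning-rate schedule. A common choice for BERT fine-tuning is to increase the learning rate linearly over the first `warmup_proportion` of optimization steps and then decrease it linearly to zero. The cell below sketches that schedule under this assumption; `linear_warmup_lr` is a hypothetical helper, and the exact schedule used inside `BERTSequenceClassifier.fit` may differ. The total step count uses the 12271 batches per epoch reported in the training log above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def linear_warmup_lr(step, total_steps, warmup_proportion, base_lr):\n",
"    # Learning rate at a 0-based optimizer step: linear warmup, then linear decay to 0.\n",
"    warmup_steps = max(1, int(total_steps * warmup_proportion))\n",
"    if step < warmup_steps:\n",
"        return base_lr * (step + 1) / warmup_steps\n",
"    remaining = max(0, total_steps - step)\n",
"    return base_lr * remaining / max(1, total_steps - warmup_steps)\n",
"\n",
"\n",
"demo_total_steps = 12271 * 2  # batches per epoch from the log above, times 2 epochs\n",
"for demo_step in (0, 1227, 2453, 2454, 12271, 24541):\n",
"    print(demo_step, linear_warmup_lr(demo_step, demo_total_steps, 0.1, 5e-5))"
]
},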
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Predict on Test Data"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"5024it [00:54, 101.88it/s] "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Prediction time : 0.015 hrs\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"with Timer() as t:\n",
" predictions_chinese = classifier_chinese.predict(token_ids=test_token_ids_chinese,\n",
" input_mask=test_input_mask_chinese,\n",
" token_type_ids=test_token_type_ids_chinese,\n",
" batch_size=BATCH_SIZE)\n",
"print(\"Prediction time : {:.3f} hrs\".format(t.interval / 3600))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Evaluate"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
"contradiction 0.81 0.84 0.82 1670\n",
" entailment 0.84 0.68 0.76 1670\n",
" neutral 0.70 0.80 0.74 1670\n",
"\n",
" accuracy 0.77 5010\n",
" macro avg 0.78 0.77 0.77 5010\n",
" weighted avg 0.78 0.77 0.77 5010\n",
"\n"
]
}
],
"source": [
"predictions_chinese = label_encoder_chinese.inverse_transform(predictions_chinese)\n",
"print(classification_report(test_df_chinese[LABEL_COL], predictions_chinese))"
]
},
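{
"cell_type": "markdown",
"metadata": {},
"source": [
"`classification_report` computes, for each class, precision (the fraction of predictions of that class that are correct), recall (the fraction of true examples of that class that were predicted as such), F1 (the harmonic mean of the two), and support (the number of true examples). The cell below is a minimal pure-Python sketch of these per-class metrics plus accuracy and the macro-averaged F1; `per_class_metrics` is an illustrative helper, not scikit-learn's implementation. Because every class in the XNLI test set has the same support (1670), the macro and weighted averages in the report above coincide."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from collections import Counter\n",
"\n",
"\n",
"def per_class_metrics(y_true, y_pred):\n",
"    # Returns ({label: (precision, recall, f1, support)}, accuracy).\n",
"    true_counts = Counter(y_true)\n",
"    pred_counts = Counter(y_pred)\n",
"    correct = Counter(t for t, p in zip(y_true, y_pred) if t == p)\n",
"    metrics = {}\n",
"    for label in sorted(set(y_true) | set(y_pred)):\n",
"        precision = correct[label] / pred_counts[label] if pred_counts[label] else 0.0\n",
"        recall = correct[label] / true_counts[label] if true_counts[label] else 0.0\n",
"        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0\n",
"        metrics[label] = (precision, recall, f1, true_counts[label])\n",
"    accuracy = sum(correct.values()) / len(y_true)\n",
"    return metrics, accuracy\n",
"\n",
"\n",
"demo_true = ['entailment', 'entailment', 'neutral', 'neutral']\n",
"demo_pred = ['entailment', 'neutral', 'neutral', 'neutral']\n",
"demo_metrics, demo_accuracy = per_class_metrics(demo_true, demo_pred)\n",
"demo_macro_f1 = sum(m[2] for m in demo_metrics.values()) / len(demo_metrics)\n",
"print(demo_metrics)\n",
"print(demo_accuracy, demo_macro_f1)"
]
},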
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Language Inference on Hindi\n",
"For Hindi, and for all other languages except Chinese, we use the `bert-base-multilingual-cased` model. \n",
"The preprocessing, model training, and prediction steps are the same as for the Chinese data, except for the underlying tokenizer and BERT model."
]
},
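{
"cell_type": "markdown",
"metadata": {},
"source": [
"A plausible reason to prefer the cased multilingual model for a language like Hindi is text normalization: in the original BERT implementation, uncased models lowercase the text and strip accents by decomposing characters (Unicode NFD) and dropping nonspacing combining marks (category `Mn`). In Devanagari, that also removes signs such as the virama (U+094D), which changes how words are spelled. To our understanding, the multilingual BERT README likewise recommends the cased model for languages with non-Latin alphabets. The cell below sketches the accent-stripping step with Python's `unicodedata` module; `strip_accents` is an illustrative helper modeled on that behavior, not the tokenizer used in this notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import unicodedata\n",
"\n",
"\n",
"def strip_accents(text):\n",
"    # NFD-decompose, then drop nonspacing combining marks (Unicode category Mn).\n",
"    decomposed = unicodedata.normalize('NFD', text)\n",
"    return ''.join(ch for ch in decomposed if unicodedata.category(ch) != 'Mn')\n",
"\n",
"\n",
"print(strip_accents('café'))  # cafe\n",
"# The Hindi word for 'cream', built from code points; U+094D is the virama.\n",
"hindi_word = ''.join(chr(cp) for cp in (0x0915, 0x094D, 0x0930, 0x0940, 0x092E))\n",
"stripped = strip_accents(hindi_word)\n",
"print(hindi_word, '->', stripped)\n",
"print(chr(0x094D) in hindi_word, chr(0x094D) in stripped)  # True False"
]
},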
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Tokenize and Preprocess"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 392702/392702 [03:48<00:00, 1719.84it/s]\n",
"100%|██████████| 5010/5010 [00:02<00:00, 1916.46it/s]\n"
]
}
],
"source": [
"tokenizer_multi = Tokenizer(LANGUAGE_MULTI, cache_dir=CACHE_DIR)\n",
"\n",
"train_tokens_hindi = tokenizer_multi.tokenize(train_df_hindi[TEXT_COL])\n",
"test_tokens_hindi = tokenizer_multi.tokenize(test_df_hindi[TEXT_COL])\n",
"\n",
"train_token_ids_hindi, train_input_mask_hindi, train_token_type_ids_hindi = \\\n",
" tokenizer_multi.preprocess_classification_tokens(train_tokens_hindi, max_len=MAX_SEQ_LENGTH)\n",
"test_token_ids_hindi, test_input_mask_hindi, test_token_type_ids_hindi = \\\n",
" tokenizer_multi.preprocess_classification_tokens(test_tokens_hindi, max_len=MAX_SEQ_LENGTH)\n",
"\n",
"label_encoder_hindi = LabelEncoder()\n",
"train_labels_hindi = label_encoder_hindi.fit_transform(train_df_hindi[LABEL_COL])\n",
"num_labels_hindi = len(np.unique(train_labels_hindi))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Create and Train Classifier"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch:1/2; batch:1->1228/12271; loss:1.091754\n",
"epoch:1/2; batch:1229->2456/12271; loss:0.992931\n",
"epoch:1/2; batch:2457->3684/12271; loss:1.045146\n",
"epoch:1/2; batch:3685->4912/12271; loss:0.799912\n",
"epoch:1/2; batch:4913->6140/12271; loss:0.815425\n",
"epoch:1/2; batch:6141->7368/12271; loss:0.564856\n",
"epoch:1/2; batch:7369->8596/12271; loss:0.726981\n",
"epoch:1/2; batch:8597->9824/12271; loss:0.764087\n",
"epoch:1/2; batch:9825->11052/12271; loss:0.964115\n",
"epoch:1/2; batch:11053->12271/12271; loss:0.502252\n",
"epoch:2/2; batch:1->1228/12271; loss:0.601600\n",
"epoch:2/2; batch:1229->2456/12271; loss:0.695099\n",
"epoch:2/2; batch:2457->3684/12271; loss:0.419610\n",
"epoch:2/2; batch:3685->4912/12271; loss:0.603106\n",
"epoch:2/2; batch:4913->6140/12271; loss:0.705180\n",
"epoch:2/2; batch:6141->7368/12271; loss:0.493404\n",
"epoch:2/2; batch:7369->8596/12271; loss:0.864921\n",
"epoch:2/2; batch:8597->9824/12271; loss:0.518601\n",
"epoch:2/2; batch:9825->11052/12271; loss:0.395920\n",
"epoch:2/2; batch:11053->12271/12271; loss:0.685858\n",
"Training time : 9.520 hrs\n"
]
}
],
"source": [
"classifier_multi = BERTSequenceClassifier(language=LANGUAGE_MULTI,\n",
" num_labels=num_labels_hindi,\n",
" cache_dir=CACHE_DIR)\n",
"with Timer() as t:\n",
" classifier_multi.fit(token_ids=train_token_ids_hindi,\n",
" input_mask=train_input_mask_hindi,\n",
" token_type_ids=train_token_type_ids_hindi,\n",
" labels=train_labels_hindi,\n",
" num_gpus=NUM_GPUS,\n",
" num_epochs=NUM_EPOCHS,\n",
" batch_size=BATCH_SIZE,\n",
" lr=LEARNING_RATE,\n",
" warmup_proportion=WARMUP_PROPORTION)\n",
"print(\"Training time : {:.3f} hrs\".format(t.interval / 3600))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Predict and Evaluate"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"5024it [01:02, 87.10it/s] "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Prediction time : 0.017 hrs\n",
" precision recall f1-score support\n",
"\n",
"contradiction 0.69 0.72 0.70 1670\n",
" entailment 0.74 0.51 0.60 1670\n",
" neutral 0.58 0.74 0.65 1670\n",
"\n",
" accuracy 0.65 5010\n",
" macro avg 0.67 0.65 0.65 5010\n",
" weighted avg 0.67 0.65 0.65 5010\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"with Timer() as t:\n",
" predictions_hindi = classifier_multi.predict(token_ids=test_token_ids_hindi,\n",
" input_mask=test_input_mask_hindi,\n",
" token_type_ids=test_token_type_ids_hindi,\n",
" batch_size=BATCH_SIZE)\n",
"print(\"Prediction time : {:.3f} hrs\".format(t.interval / 3600))\n",
"predictions_hindi = label_encoder_hindi.inverse_transform(predictions_hindi)\n",
"print(classification_report(test_df_hindi[LABEL_COL], predictions_hindi))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}