{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Multi-lingual Inference on XNLI Dataset using BERT"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Summary\n",
|
|
"In this notebook, we demostrate using the [Multi-lingual BERT model](https://github.com/google-research/bert/blob/master/multilingual.md) to do language inference in Chinese and Hindi. We use the [XNLI](https://github.com/facebookresearch/XNLI) dataset and the task is to classify sentence pairs into three classes: contradiction, entailment, and neutral. \n",
|
|
"The figure below shows how [BERT](https://arxiv.org/abs/1810.04805) classifies sentence pairs. It concatenates the tokens in each sentence pairs and separates the sentences by the [SEP] token. A [CLS] token is prepended to the token list and used as the aggregate sequence representation for the classification task.\n",
|
|
"<img src=\"https://nlpbp.blob.core.windows.net/images/bert_two_sentence.PNG\">"
|
|
]
|
|
},
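{
"cell_type": "markdown",
"metadata": {},
"source": [
"*The next cell is a small illustrative sketch, not part of the original pipeline: with made-up tokens, it shows how a premise/hypothesis pair is laid out as a single BERT input sequence. The actual tokenization and preprocessing are performed later with the `Tokenizer` class.*"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch only: BERT sentence-pair input layout, using made-up tokens.\n",
"tokens_premise = [\"the\", \"cat\", \"sat\"]\n",
"tokens_hypothesis = [\"a\", \"cat\", \"is\", \"sitting\"]\n",
"\n",
"# [CLS] starts the sequence; a [SEP] token closes each sentence.\n",
"input_tokens = [\"[CLS]\"] + tokens_premise + [\"[SEP]\"] + tokens_hypothesis + [\"[SEP]\"]\n",
"# Token type (segment) ids: 0 for [CLS], the premise, and its [SEP]; 1 for the rest.\n",
"token_type_ids = [0] * (len(tokens_premise) + 2) + [1] * (len(tokens_hypothesis) + 1)\n",
"\n",
"print(input_tokens)\n",
"print(token_type_ids)"
]
},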
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"import sys\n",
"import os\n",
"import random\n",
"import numpy as np\n",
"from sklearn.metrics import classification_report\n",
"from sklearn.preprocessing import LabelEncoder\n",
"\n",
"import torch\n",
"\n",
"nlp_path = os.path.abspath('../../')\n",
"if nlp_path not in sys.path:\n",
" sys.path.insert(0, nlp_path)\n",
"\n",
"from utils_nlp.models.bert.sequence_classification import BERTSequenceClassifier\n",
"from utils_nlp.models.bert.common import Language, Tokenizer\n",
"from utils_nlp.dataset.xnli import load_pandas_df\n",
"from utils_nlp.common.timer import Timer"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Configurations\n",
"Note that the running time shown in this notebook are on a Standard_NC12 Azure Deep Learning Virtual Machine with two NVIDIA Tesla K80 GPUs. If you want to run through the notebook quickly, you can change the `TRAIN_DATA_USED_PERCENT` to a small number, e.g. 0.01. "
|
|
]
|
|
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"TRAIN_DATA_USED_PERCENT = 1.0\n",
"\n",
"# set random seeds\n",
"RANDOM_SEED = 42\n",
"random.seed(RANDOM_SEED)\n",
"np.random.seed(RANDOM_SEED)\n",
"torch.manual_seed(RANDOM_SEED)\n",
"num_cuda_devices = torch.cuda.device_count()\n",
"if num_cuda_devices > 1:\n",
" torch.cuda.manual_seed_all(RANDOM_SEED)\n",
"\n",
"# model configurations\n",
"LANGUAGE_CHINESE = Language.CHINESE\n",
"LANGUAGE_MULTI = Language.MULTILINGUAL\n",
"TO_LOWER = True\n",
"MAX_SEQ_LENGTH = 128\n",
"\n",
"# training configurations\n",
"NUM_GPUS = 2\n",
"BATCH_SIZE = 32\n",
"NUM_EPOCHS = 2\n",
"\n",
"# optimizer configurations\n",
"LEARNING_RATE = 5e-5\n",
"WARMUP_PROPORTION = 0.1\n",
"\n",
"# data configurations\n",
"TEXT_COL = \"text\"\n",
"LABEL_COL = \"label\"\n",
"\n",
"CACHE_DIR = \"./temp\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load Data\n",
"The XNLI dataset comes in two zip files: \n",
"* XNLI-1.0.zip: dev and test datasets in 15 languages. The original English data was translated into the other languages by human translators. \n",
"* XNLI-MT-1.0.zip: training dataset in 15 languages. This dataset consists of machine translations of the [MultiNLI](https://www.nyu.edu/projects/bowman/multinli/) dataset. It also contains English translations of the dev and test datasets, but these are not used in this notebook. \n",
"\n",
"The `load_pandas_df` function downloads and extracts the zip files if they don't already exist in `local_cache_path`, and returns the data subset specified by `file_split` and `language`."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"train_df_chinese = load_pandas_df(local_cache_path=CACHE_DIR, file_split=\"train\", language=\"zh\")\n",
"dev_df_chinese = load_pandas_df(local_cache_path=CACHE_DIR, file_split=\"dev\", language=\"zh\")\n",
"test_df_chinese = load_pandas_df(local_cache_path=CACHE_DIR, file_split=\"test\", language=\"zh\")\n",
"\n",
"train_df_hindi = load_pandas_df(local_cache_path=CACHE_DIR, file_split=\"train\", language=\"hi\")\n",
"dev_df_hindi = load_pandas_df(local_cache_path=CACHE_DIR, file_split=\"dev\", language=\"hi\")\n",
"test_df_hindi = load_pandas_df(local_cache_path=CACHE_DIR, file_split=\"test\", language=\"hi\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Chinese training dataset size: 392702\n",
"Chinese dev dataset size: 2490\n",
"Chinese test dataset size: 5010\n",
"\n",
"Hindi training dataset size: 392702\n",
"Hindi dev dataset size: 2490\n",
"Hindi test dataset size: 5010\n",
"\n",
" text label\n",
"0 (从 概念 上 看 , 奶油 收入 有 两 个 基本 方面 产品 和 地理 ., 产品 和 ... neutral\n",
"1 (你 知道 在 这个 季节 , 我 猜 在 你 的 水平 你 把 他们 丢到 下 一个 水平... entailment\n",
"2 (我们 的 一个 号码 会 非常 详细 地 执行 你 的 指示, 我 团队 的 一个 成员 ... entailment\n",
"3 (你 怎么 知道 的 ? 所有 这些 都 是 他们 的 信息 ., 这些 信息 属于 他们 .) entailment\n",
"4 (是 啊 , 我 告诉 你 , 如果 你 去 买 一些 网球鞋 , 我 可以 看到 为什么 ... neutral\n",
" text label\n",
"0 (Conceptually क ् रीम एंजलिस में दो मूल आयाम ह... neutral\n",
"1 (आप मौसम के दौरान जानते हैं और मैं अपने स ् तर... entailment\n",
"2 (हमारे एक नंबर में से एक आपके निर ् देशों को म... entailment\n",
"3 (आप कैसे जानते हैं ? ये सब उनकी जानकारी फिर से... entailment\n",
"4 (हाँ मैं आपको बताता हूँ कि अगर आप उन टेनिस जूत... neutral\n"
]
}
],
"source": [
"print(\"Chinese training dataset size: {}\".format(train_df_chinese.shape[0]))\n",
"print(\"Chinese dev dataset size: {}\".format(dev_df_chinese.shape[0]))\n",
"print(\"Chinese test dataset size: {}\".format(test_df_chinese.shape[0]))\n",
"print()\n",
"print(\"Hindi training dataset size: {}\".format(train_df_hindi.shape[0]))\n",
"print(\"Hindi dev dataset size: {}\".format(dev_df_hindi.shape[0]))\n",
"print(\"Hindi test dataset size: {}\".format(test_df_hindi.shape[0]))\n",
"print()\n",
"print(train_df_chinese.head())\n",
"print(train_df_hindi.head())"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"train_data_used_count = round(TRAIN_DATA_USED_PERCENT * train_df_chinese.shape[0])\n",
|
|
"train_df_chinese = train_df_chinese.loc[:train_data_used_count]\n",
|
|
"train_df_hindi = train_df_hindi.loc[:train_data_used_count]"
|
|
]
|
|
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Language Inference on Chinese\n",
"For the Chinese dataset, we use the `bert-base-chinese` model, which was pretrained on Chinese data only. The `bert-base-multilingual-cased` model can also be used for Chinese, but its accuracy is 3% lower."
]
},
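{
"cell_type": "markdown",
"metadata": {},
"source": [
"*As a hypothetical, not-executed sketch of that alternative: switching to the multilingual model would only change the `Language` value passed to the tokenizer and classifier, while the rest of the workflow below stays the same.*"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical alternative (kept commented out): use the multilingual model on the Chinese data.\n",
"# Note that num_labels_chinese is only defined further below, and constructing these objects\n",
"# would download the multilingual weights.\n",
"#\n",
"# tokenizer_chinese_multi = Tokenizer(LANGUAGE_MULTI, cache_dir=CACHE_DIR)\n",
"# classifier_chinese_multi = BERTSequenceClassifier(language=LANGUAGE_MULTI,\n",
"#                                                   num_labels=num_labels_chinese,\n",
"#                                                   cache_dir=CACHE_DIR)"
]
},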
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Tokenize and Preprocess\n",
"Before training, we tokenize the sentence pair texts, converting each text into a list of tokens. The following cell instantiates a BERT tokenizer for the given language and tokenizes the text of the training and test sets."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 392702/392702 [02:26<00:00, 2682.67it/s]\n",
"100%|██████████| 5010/5010 [00:01<00:00, 3122.04it/s]\n"
]
}
],
"source": [
"tokenizer_chinese = Tokenizer(LANGUAGE_CHINESE, to_lower=TO_LOWER, cache_dir=CACHE_DIR)\n",
"\n",
"train_tokens_chinese = tokenizer_chinese.tokenize(train_df_chinese[TEXT_COL])\n",
"test_tokens_chinese = tokenizer_chinese.tokenize(test_df_chinese[TEXT_COL])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In addition, we perform the following preprocessing steps in the cell below:\n",
"\n",
"* Convert the tokens into token indices corresponding to the BERT tokenizer's vocabulary\n",
"* Add the special tokens [CLS] (at the start of each sequence) and [SEP] (at the end of each sentence)\n",
"* Pad or truncate the token lists to the specified max length\n",
"* Return attention mask lists that indicate the positions of the padding tokens\n",
"* Return token type id lists that indicate which sentence the tokens belong to\n",
"\n",
"*See the original [implementation](https://github.com/google-research/bert/blob/master/run_classifier.py) for more information on BERT's input format. A short sketch after the next cell inspects these preprocessed outputs for a single example.*"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"train_token_ids_chinese, train_input_mask_chinese, train_token_type_ids_chinese = \\\n",
" tokenizer_chinese.preprocess_classification_tokens(train_tokens_chinese, max_len=MAX_SEQ_LENGTH)\n",
"test_token_ids_chinese, test_input_mask_chinese, test_token_type_ids_chinese = \\\n",
" tokenizer_chinese.preprocess_classification_tokens(test_tokens_chinese, max_len=MAX_SEQ_LENGTH)"
]
},
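{
"cell_type": "markdown",
"metadata": {},
"source": [
"*A minimal inspection sketch, assuming the preprocessing outputs are list-like with one entry per example, each of length `MAX_SEQ_LENGTH`: looking at the first test example helps verify the token ids, padding mask, and token type ids described above.*"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: inspect the preprocessed inputs of the first test example.\n",
"# Assumes each output is list-like, with one entry per example of length MAX_SEQ_LENGTH.\n",
"print(\"sequence length:\", len(test_token_ids_chinese[0]))\n",
"print(\"token ids :\", list(test_token_ids_chinese[0][:20]))\n",
"print(\"input mask :\", list(test_input_mask_chinese[0][:20]))\n",
"print(\"token type ids:\", list(test_token_type_ids_chinese[0][:20]))"
]
},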
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"label_encoder_chinese = LabelEncoder()\n",
"train_labels_chinese = label_encoder_chinese.fit_transform(train_df_chinese[LABEL_COL])\n",
"num_labels_chinese = len(np.unique(train_labels_chinese))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Create Classifier"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"classifier_chinese = BERTSequenceClassifier(language=LANGUAGE_CHINESE,\n",
" num_labels=num_labels_chinese,\n",
" cache_dir=CACHE_DIR)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Train Classifier"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch:1/2; batch:1->1228/12271; loss:1.194384\n",
"epoch:1/2; batch:1229->2456/12271; loss:0.863067\n",
"epoch:1/2; batch:2457->3684/12271; loss:0.781256\n",
"epoch:1/2; batch:3685->4912/12271; loss:1.067413\n",
"epoch:1/2; batch:4913->6140/12271; loss:0.599279\n",
"epoch:1/2; batch:6141->7368/12271; loss:0.471488\n",
"epoch:1/2; batch:7369->8596/12271; loss:0.572327\n",
"epoch:1/2; batch:8597->9824/12271; loss:0.689093\n",
"epoch:1/2; batch:9825->11052/12271; loss:0.651702\n",
"epoch:1/2; batch:11053->12271/12271; loss:0.431085\n",
"epoch:2/2; batch:1->1228/12271; loss:0.255859\n",
"epoch:2/2; batch:1229->2456/12271; loss:0.434052\n",
"epoch:2/2; batch:2457->3684/12271; loss:0.433569\n",
"epoch:2/2; batch:3685->4912/12271; loss:0.405915\n",
"epoch:2/2; batch:4913->6140/12271; loss:0.636128\n",
"epoch:2/2; batch:6141->7368/12271; loss:0.416685\n",
"epoch:2/2; batch:7369->8596/12271; loss:0.265789\n",
"epoch:2/2; batch:8597->9824/12271; loss:0.328964\n",
"epoch:2/2; batch:9825->11052/12271; loss:0.436310\n",
"epoch:2/2; batch:11053->12271/12271; loss:0.374193\n",
"Training time : 8.050 hrs\n"
]
}
],
"source": [
"with Timer() as t:\n",
" classifier_chinese.fit(token_ids=train_token_ids_chinese,\n",
" input_mask=train_input_mask_chinese,\n",
" token_type_ids=train_token_type_ids_chinese,\n",
" labels=train_labels_chinese,\n",
" num_gpus=NUM_GPUS,\n",
" num_epochs=NUM_EPOCHS,\n",
" batch_size=BATCH_SIZE,\n",
" lr=LEARNING_RATE,\n",
" warmup_proportion=WARMUP_PROPORTION)\n",
"print(\"Training time : {:.3f} hrs\".format(t.interval / 3600))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Predict on Test Data"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"5024it [00:54, 101.88it/s] "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Prediction time : 0.015 hrs\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"with Timer() as t:\n",
" predictions_chinese = classifier_chinese.predict(token_ids=test_token_ids_chinese,\n",
" input_mask=test_input_mask_chinese,\n",
" token_type_ids=test_token_type_ids_chinese,\n",
" batch_size=BATCH_SIZE)\n",
"print(\"Prediction time : {:.3f} hrs\".format(t.interval / 3600))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Evaluate"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
"contradiction 0.81 0.84 0.82 1670\n",
" entailment 0.84 0.68 0.76 1670\n",
" neutral 0.70 0.80 0.74 1670\n",
"\n",
" accuracy 0.77 5010\n",
" macro avg 0.78 0.77 0.77 5010\n",
" weighted avg 0.78 0.77 0.77 5010\n",
"\n"
]
}
],
"source": [
"predictions_chinese = label_encoder_chinese.inverse_transform(predictions_chinese)\n",
"print(classification_report(test_df_chinese[LABEL_COL], predictions_chinese))"
]
},
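{
"cell_type": "markdown",
"metadata": {},
"source": [
"*The next cell is a hedged sketch of single-pair inference that reuses the objects created above. It assumes `tokenize` accepts a plain list of (premise, hypothesis) pairs, mirroring how it is applied to the DataFrame text column; the example sentences are made up and the resulting prediction is illustrative only.*"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: classify a single custom Chinese sentence pair with the trained model.\n",
"# Assumption: tokenize() accepts any iterable of (premise, hypothesis) pairs.\n",
"custom_pairs = [(\"今天 天气 很 好 , 我们 去 公园 散步 .\", \"今天 天气 不错 .\")]\n",
"\n",
"custom_tokens = tokenizer_chinese.tokenize(custom_pairs)\n",
"custom_token_ids, custom_input_mask, custom_token_type_ids = \\\n",
"    tokenizer_chinese.preprocess_classification_tokens(custom_tokens, max_len=MAX_SEQ_LENGTH)\n",
"\n",
"custom_predictions = classifier_chinese.predict(token_ids=custom_token_ids,\n",
"                                                input_mask=custom_input_mask,\n",
"                                                token_type_ids=custom_token_type_ids,\n",
"                                                batch_size=1)\n",
"print(label_encoder_chinese.inverse_transform(custom_predictions))"
]
},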
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Language Inference on Hindi\n",
"For Hindi and all other languages except Chinese, we use the `bert-base-multilingual-cased` model. \n",
"The preprocessing, model training, and prediction steps are the same as for the Chinese data, except for the underlying tokenizer and BERT model used."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Tokenize and Preprocess"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 392702/392702 [03:48<00:00, 1719.84it/s]\n",
"100%|██████████| 5010/5010 [00:02<00:00, 1916.46it/s]\n"
]
}
],
"source": [
"tokenizer_multi = Tokenizer(LANGUAGE_MULTI, cache_dir=CACHE_DIR)\n",
"\n",
"train_tokens_hindi = tokenizer_multi.tokenize(train_df_hindi[TEXT_COL])\n",
"test_tokens_hindi = tokenizer_multi.tokenize(test_df_hindi[TEXT_COL])\n",
"\n",
"train_token_ids_hindi, train_input_mask_hindi, train_token_type_ids_hindi = \\\n",
" tokenizer_multi.preprocess_classification_tokens(train_tokens_hindi, max_len=MAX_SEQ_LENGTH)\n",
"test_token_ids_hindi, test_input_mask_hindi, test_token_type_ids_hindi = \\\n",
" tokenizer_multi.preprocess_classification_tokens(test_tokens_hindi, max_len=MAX_SEQ_LENGTH)\n",
"\n",
"label_encoder_hindi = LabelEncoder()\n",
"train_labels_hindi = label_encoder_hindi.fit_transform(train_df_hindi[LABEL_COL])\n",
"num_labels_hindi = len(np.unique(train_labels_hindi))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Create and Train Classifier"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch:1/2; batch:1->1228/12271; loss:1.091754\n",
"epoch:1/2; batch:1229->2456/12271; loss:0.992931\n",
"epoch:1/2; batch:2457->3684/12271; loss:1.045146\n",
"epoch:1/2; batch:3685->4912/12271; loss:0.799912\n",
"epoch:1/2; batch:4913->6140/12271; loss:0.815425\n",
"epoch:1/2; batch:6141->7368/12271; loss:0.564856\n",
"epoch:1/2; batch:7369->8596/12271; loss:0.726981\n",
"epoch:1/2; batch:8597->9824/12271; loss:0.764087\n",
"epoch:1/2; batch:9825->11052/12271; loss:0.964115\n",
"epoch:1/2; batch:11053->12271/12271; loss:0.502252\n",
"epoch:2/2; batch:1->1228/12271; loss:0.601600\n",
"epoch:2/2; batch:1229->2456/12271; loss:0.695099\n",
"epoch:2/2; batch:2457->3684/12271; loss:0.419610\n",
"epoch:2/2; batch:3685->4912/12271; loss:0.603106\n",
"epoch:2/2; batch:4913->6140/12271; loss:0.705180\n",
"epoch:2/2; batch:6141->7368/12271; loss:0.493404\n",
"epoch:2/2; batch:7369->8596/12271; loss:0.864921\n",
"epoch:2/2; batch:8597->9824/12271; loss:0.518601\n",
"epoch:2/2; batch:9825->11052/12271; loss:0.395920\n",
"epoch:2/2; batch:11053->12271/12271; loss:0.685858\n",
"Training time : 9.520 hrs\n"
]
}
],
"source": [
"classifier_multi = BERTSequenceClassifier(language=LANGUAGE_MULTI,\n",
" num_labels=num_labels_hindi,\n",
" cache_dir=CACHE_DIR)\n",
"with Timer() as t:\n",
" classifier_multi.fit(token_ids=train_token_ids_hindi,\n",
" input_mask=train_input_mask_hindi,\n",
" token_type_ids=train_token_type_ids_hindi,\n",
" labels=train_labels_hindi,\n",
" num_gpus=NUM_GPUS,\n",
" num_epochs=NUM_EPOCHS,\n",
" batch_size=BATCH_SIZE,\n",
" lr=LEARNING_RATE,\n",
" warmup_proportion=WARMUP_PROPORTION)\n",
"print(\"Training time : {:.3f} hrs\".format(t.interval / 3600))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Predict and Evaluate"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"5024it [01:02, 87.10it/s] "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Prediction time : 0.017 hrs\n",
" precision recall f1-score support\n",
"\n",
"contradiction 0.69 0.72 0.70 1670\n",
" entailment 0.74 0.51 0.60 1670\n",
" neutral 0.58 0.74 0.65 1670\n",
"\n",
" accuracy 0.65 5010\n",
" macro avg 0.67 0.65 0.65 5010\n",
" weighted avg 0.67 0.65 0.65 5010\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"with Timer() as t:\n",
" predictions_hindi = classifier_multi.predict(token_ids=test_token_ids_hindi,\n",
" input_mask=test_input_mask_hindi,\n",
" token_type_ids=test_token_type_ids_hindi,\n",
" batch_size=BATCH_SIZE)\n",
"print(\"Prediction time : {:.3f} hrs\".format(t.interval / 3600))\n",
"predictions_hindi = label_encoder_hindi.inverse_transform(predictions_hindi)\n",
"print(classification_report(test_df_hindi[LABEL_COL], predictions_hindi))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}