Merge pull request #113 from microsoft/hlu/two_sequence_utils_and_XNLI_notebook
Hlu/two sequence utils and xnli notebook
This commit is contained in:
Коммит
35fc04c383
|
@ -0,0 +1,581 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Multi-lingual Inference on XNLI Dataset using BERT"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Summary\n",
|
||||
"In this notebook, we demostrate using the [Multi-lingual BERT model](https://github.com/google-research/bert/blob/master/multilingual.md) to do language inference in Chinese and Hindi. We use the [XNLI](https://github.com/facebookresearch/XNLI) dataset and the task is to classify sentence pairs into three classes: contradiction, entailment, and neutral. \n",
|
||||
"The figure below shows how [BERT](https://arxiv.org/abs/1810.04805) classifies sentence pairs. It concatenates the tokens in each sentence pairs and separates the sentences by the [SEP] token. A [CLS] token is prepended to the token list and used as the aggregate sequence representation for the classification task.\n",
|
||||
"<img src=\"https://nlpbp.blob.core.windows.net/images/bert_two_sentence.PNG\">"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {
|
||||
"scrolled": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import sys\n",
|
||||
"import os\n",
|
||||
"import random\n",
|
||||
"import numpy as np\n",
|
||||
"from sklearn.metrics import classification_report\n",
|
||||
"from sklearn.preprocessing import LabelEncoder\n",
|
||||
"\n",
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"nlp_path = os.path.abspath('../../')\n",
|
||||
"if nlp_path not in sys.path:\n",
|
||||
" sys.path.insert(0, nlp_path)\n",
|
||||
"\n",
|
||||
"from utils_nlp.bert.sequence_classification import BERTSequenceClassifier\n",
|
||||
"from utils_nlp.bert.common import Language, Tokenizer\n",
|
||||
"from utils_nlp.dataset.xnli import load_pandas_df\n",
|
||||
"from utils_nlp.common.timer import Timer"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Configurations\n",
|
||||
"Note that the running time shown in this notebook are on a Standard_NC12 Azure Deep Learning Virtual Machine with two NVIDIA Tesla K80 GPUs. If you want to run through the notebook quickly, you can change the `TRAIN_DATA_USED_PERCENT` to a small number, e.g. 0.01. "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"TRAIN_DATA_USED_PERCENT = 1.0\n",
|
||||
"\n",
|
||||
"# set random seeds\n",
|
||||
"RANDOM_SEED = 42\n",
|
||||
"random.seed(RANDOM_SEED)\n",
|
||||
"np.random.seed(RANDOM_SEED)\n",
|
||||
"torch.manual_seed(RANDOM_SEED)\n",
|
||||
"num_cuda_devices = torch.cuda.device_count()\n",
|
||||
"if num_cuda_devices > 1:\n",
|
||||
" torch.cuda.manual_seed_all(RANDOM_SEED)\n",
|
||||
"\n",
|
||||
"# model configurations\n",
|
||||
"LANGUAGE_CHINESE = Language.CHINESE\n",
|
||||
"LANGUAGE_MULTI = Language.MULTILINGUAL\n",
|
||||
"TO_LOWER = True\n",
|
||||
"MAX_SEQ_LENGTH = 128\n",
|
||||
"\n",
|
||||
"# training configurations\n",
|
||||
"NUM_GPUS = 2\n",
|
||||
"BATCH_SIZE = 32\n",
|
||||
"NUM_EPOCHS = 2\n",
|
||||
"\n",
|
||||
"# optimizer configurations\n",
|
||||
"LEARNING_RATE= 5e-5\n",
|
||||
"WARMUP_PROPORTION= 0.1\n",
|
||||
"\n",
|
||||
"# data configurations\n",
|
||||
"TEXT_COL = \"text\"\n",
|
||||
"LABEL_COL = \"label\"\n",
|
||||
"\n",
|
||||
"CACHE_DIR = \"./temp\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Load Data\n",
|
||||
"The XNLI dataset comes in two zip files: \n",
|
||||
"* XNLI-1.0.zip: dev and test datasets in 15 languages. The original English data was translated into other languages by human translators. \n",
|
||||
"* XNLI-MT-1.0.zip: training dataset in 15 languages. This dataset is machine translations of the [MultiNLI](https://www.nyu.edu/projects/bowman/multinli/) dataset. It also contains English translations of the dev and test datasets, but not used in this notebook. \n",
|
||||
"\n",
|
||||
"The `load_pandas_df` function downloads and extracts the zip files if they don't already exist in `local_cache_path` and returns the data subset specified by `file_split` and `language`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_df_chinese = load_pandas_df(local_cache_path=CACHE_DIR, file_split=\"train\", language=\"zh\")\n",
|
||||
"dev_df_chinese = load_pandas_df(local_cache_path=CACHE_DIR, file_split=\"dev\", language=\"zh\")\n",
|
||||
"test_df_chinese = load_pandas_df(local_cache_path=CACHE_DIR, file_split=\"test\", language=\"zh\")\n",
|
||||
"\n",
|
||||
"train_df_hindi = load_pandas_df(local_cache_path=CACHE_DIR, file_split=\"train\", language=\"hi\")\n",
|
||||
"dev_df_hindi = load_pandas_df(local_cache_path=CACHE_DIR, file_split=\"dev\", language=\"hi\")\n",
|
||||
"test_df_hindi = load_pandas_df(local_cache_path=CACHE_DIR, file_split=\"test\", language=\"hi\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Chinese training dataset size: 392702\n",
|
||||
"Chinese dev dataset size: 2490\n",
|
||||
"Chinese test dataset size: 5010\n",
|
||||
"\n",
|
||||
"Hindi training dataset size: 392702\n",
|
||||
"Hindi dev dataset size: 2490\n",
|
||||
"Hindi test dataset size: 5010\n",
|
||||
"\n",
|
||||
" text label\n",
|
||||
"0 (从 概念 上 看 , 奶油 收入 有 两 个 基本 方面 产品 和 地理 ., 产品 和 ... neutral\n",
|
||||
"1 (你 知道 在 这个 季节 , 我 猜 在 你 的 水平 你 把 他们 丢到 下 一个 水平... entailment\n",
|
||||
"2 (我们 的 一个 号码 会 非常 详细 地 执行 你 的 指示, 我 团队 的 一个 成员 ... entailment\n",
|
||||
"3 (你 怎么 知道 的 ? 所有 这些 都 是 他们 的 信息 ., 这些 信息 属于 他们 .) entailment\n",
|
||||
"4 (是 啊 , 我 告诉 你 , 如果 你 去 买 一些 网球鞋 , 我 可以 看到 为什么 ... neutral\n",
|
||||
" text label\n",
|
||||
"0 (Conceptually क ् रीम एंजलिस में दो मूल आयाम ह... neutral\n",
|
||||
"1 (आप मौसम के दौरान जानते हैं और मैं अपने स ् तर... entailment\n",
|
||||
"2 (हमारे एक नंबर में से एक आपके निर ् देशों को म... entailment\n",
|
||||
"3 (आप कैसे जानते हैं ? ये सब उनकी जानकारी फिर से... entailment\n",
|
||||
"4 (हाँ मैं आपको बताता हूँ कि अगर आप उन टेनिस जूत... neutral\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(\"Chinese training dataset size: {}\".format(train_df_chinese.shape[0]))\n",
|
||||
"print(\"Chinese dev dataset size: {}\".format(dev_df_chinese.shape[0]))\n",
|
||||
"print(\"Chinese test dataset size: {}\".format(test_df_chinese.shape[0]))\n",
|
||||
"print()\n",
|
||||
"print(\"Hindi training dataset size: {}\".format(train_df_hindi.shape[0]))\n",
|
||||
"print(\"Hindi dev dataset size: {}\".format(dev_df_hindi.shape[0]))\n",
|
||||
"print(\"Hindi test dataset size: {}\".format(test_df_hindi.shape[0]))\n",
|
||||
"print()\n",
|
||||
"print(train_df_chinese.head())\n",
|
||||
"print(train_df_hindi.head())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_data_used_count = round(TRAIN_DATA_USED_PERCENT * train_df_chinese.shape[0])\n",
|
||||
"train_df_chinese = train_df_chinese.loc[:train_data_used_count]\n",
|
||||
"train_df_hindi = train_df_hindi.loc[:train_data_used_count]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Language Inference on Chinese\n",
|
||||
"For Chinese dataset, we use the `bert-base-chinese` model which was pretrained on Chinese dataset only. The `bert-base-multilingual-cased` model can also be used on Chinese, but the accuracy is 3% lower."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Tokenize and Preprocess\n",
|
||||
"Before training, we tokenize the sentence texts and convert them to lists of tokens. The following steps instantiate a BERT tokenizer given the language, and tokenize the text of the training and testing sets."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 392702/392702 [02:26<00:00, 2682.67it/s]\n",
|
||||
"100%|██████████| 5010/5010 [00:01<00:00, 3122.04it/s]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"tokenizer_chinese = Tokenizer(LANGUAGE_CHINESE, to_lower=TO_LOWER, cache_dir=CACHE_DIR)\n",
|
||||
"\n",
|
||||
"train_tokens_chinese = tokenizer_chinese.tokenize(train_df_chinese[TEXT_COL])\n",
|
||||
"test_tokens_chinese= tokenizer_chinese.tokenize(test_df_chinese[TEXT_COL])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"In addition, we perform the following preprocessing steps in the cell below:\n",
|
||||
"\n",
|
||||
"* Convert the tokens into token indices corresponding to the BERT tokenizer's vocabulary\n",
|
||||
"* Add the special tokens [CLS] and [SEP] to mark the beginning and end of a sentence\n",
|
||||
"* Pad or truncate the token lists to the specified max length\n",
|
||||
"* Return mask lists that indicate paddings' positions\n",
|
||||
"* Return token type id lists that indicate which sentence the tokens belong to\n",
|
||||
"\n",
|
||||
"*See the original [implementation](https://github.com/google-research/bert/blob/master/run_classifier.py) for more information on BERT's input format.*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_token_ids_chinese, train_input_mask_chinese, train_token_type_ids_chinese = \\\n",
|
||||
" tokenizer_chinese.preprocess_classification_tokens(train_tokens_chinese, max_len=MAX_SEQ_LENGTH)\n",
|
||||
"test_token_ids_chinese, test_input_mask_chinese, test_token_type_ids_chinese = \\\n",
|
||||
" tokenizer_chinese.preprocess_classification_tokens(test_tokens_chinese, max_len=MAX_SEQ_LENGTH)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"label_encoder_chinese = LabelEncoder()\n",
|
||||
"train_labels_chinese = label_encoder_chinese.fit_transform(train_df_chinese[LABEL_COL])\n",
|
||||
"num_labels_chinese = len(np.unique(train_labels_chinese))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Create Classifier"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"classifier_chinese = BERTSequenceClassifier(language=LANGUAGE_CHINESE,\n",
|
||||
" num_labels=num_labels_chinese,\n",
|
||||
" cache_dir=CACHE_DIR)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Train Classifier"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"epoch:1/2; batch:1->1228/12271; loss:1.194384\n",
|
||||
"epoch:1/2; batch:1229->2456/12271; loss:0.863067\n",
|
||||
"epoch:1/2; batch:2457->3684/12271; loss:0.781256\n",
|
||||
"epoch:1/2; batch:3685->4912/12271; loss:1.067413\n",
|
||||
"epoch:1/2; batch:4913->6140/12271; loss:0.599279\n",
|
||||
"epoch:1/2; batch:6141->7368/12271; loss:0.471488\n",
|
||||
"epoch:1/2; batch:7369->8596/12271; loss:0.572327\n",
|
||||
"epoch:1/2; batch:8597->9824/12271; loss:0.689093\n",
|
||||
"epoch:1/2; batch:9825->11052/12271; loss:0.651702\n",
|
||||
"epoch:1/2; batch:11053->12271/12271; loss:0.431085\n",
|
||||
"epoch:2/2; batch:1->1228/12271; loss:0.255859\n",
|
||||
"epoch:2/2; batch:1229->2456/12271; loss:0.434052\n",
|
||||
"epoch:2/2; batch:2457->3684/12271; loss:0.433569\n",
|
||||
"epoch:2/2; batch:3685->4912/12271; loss:0.405915\n",
|
||||
"epoch:2/2; batch:4913->6140/12271; loss:0.636128\n",
|
||||
"epoch:2/2; batch:6141->7368/12271; loss:0.416685\n",
|
||||
"epoch:2/2; batch:7369->8596/12271; loss:0.265789\n",
|
||||
"epoch:2/2; batch:8597->9824/12271; loss:0.328964\n",
|
||||
"epoch:2/2; batch:9825->11052/12271; loss:0.436310\n",
|
||||
"epoch:2/2; batch:11053->12271/12271; loss:0.374193\n",
|
||||
"Training time : 8.050 hrs\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"with Timer() as t:\n",
|
||||
" classifier_chinese.fit(token_ids=train_token_ids_chinese,\n",
|
||||
" input_mask=train_input_mask_chinese,\n",
|
||||
" token_type_ids=train_token_type_ids_chinese,\n",
|
||||
" labels=train_labels_chinese,\n",
|
||||
" num_gpus=NUM_GPUS,\n",
|
||||
" num_epochs=NUM_EPOCHS,\n",
|
||||
" batch_size=BATCH_SIZE,\n",
|
||||
" lr=LEARNING_RATE,\n",
|
||||
" warmup_proportion=WARMUP_PROPORTION)\n",
|
||||
"print(\"Training time : {:.3f} hrs\".format(t.interval / 3600))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Predict on Test Data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"5024it [00:54, 101.88it/s] "
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Prediction time : 0.015 hrs\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"with Timer() as t:\n",
|
||||
" predictions_chinese = classifier_chinese.predict(token_ids=test_token_ids_chinese,\n",
|
||||
" input_mask=test_input_mask_chinese,\n",
|
||||
" token_type_ids=test_token_type_ids_chinese,\n",
|
||||
" batch_size=BATCH_SIZE)\n",
|
||||
"print(\"Prediction time : {:.3f} hrs\".format(t.interval / 3600))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Evaluate"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" precision recall f1-score support\n",
|
||||
"\n",
|
||||
"contradiction 0.81 0.84 0.82 1670\n",
|
||||
" entailment 0.84 0.68 0.76 1670\n",
|
||||
" neutral 0.70 0.80 0.74 1670\n",
|
||||
"\n",
|
||||
" accuracy 0.77 5010\n",
|
||||
" macro avg 0.78 0.77 0.77 5010\n",
|
||||
" weighted avg 0.78 0.77 0.77 5010\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"predictions_chinese = label_encoder_chinese.inverse_transform(predictions_chinese)\n",
|
||||
"print(classification_report(test_df_chinese[LABEL_COL], predictions_chinese))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Language Inference on Hindi\n",
|
||||
"For Hindi and all other languages except Chinese, we use the `bert-base-multilingual-cased` model. \n",
|
||||
"The preprocesing, model training, and prediction steps are the same as on Chinese data, except for the underlying tokenizer and BERT model used"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Tokenize and Preprocess"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 392702/392702 [03:48<00:00, 1719.84it/s]\n",
|
||||
"100%|██████████| 5010/5010 [00:02<00:00, 1916.46it/s]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"tokenizer_multi = Tokenizer(LANGUAGE_MULTI, cache_dir=CACHE_DIR)\n",
|
||||
"\n",
|
||||
"train_tokens_hindi = tokenizer_multi.tokenize(train_df_hindi[TEXT_COL])\n",
|
||||
"test_tokens_hindi= tokenizer_multi.tokenize(test_df_hindi[TEXT_COL])\n",
|
||||
"\n",
|
||||
"train_token_ids_hindi, train_input_mask_hindi, train_token_type_ids_hindi = \\\n",
|
||||
" tokenizer_multi.preprocess_classification_tokens(train_tokens_hindi, max_len=MAX_SEQ_LENGTH)\n",
|
||||
"test_token_ids_hindi, test_input_mask_hindi, test_token_type_ids_hindi = \\\n",
|
||||
" tokenizer_multi.preprocess_classification_tokens(test_tokens_hindi, max_len=MAX_SEQ_LENGTH)\n",
|
||||
"\n",
|
||||
"label_encoder_hindi = LabelEncoder()\n",
|
||||
"train_labels_hindi = label_encoder_hindi.fit_transform(train_df_hindi[LABEL_COL])\n",
|
||||
"num_labels_hindi = len(np.unique(train_labels_hindi))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Create and Train Classifier"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"epoch:1/2; batch:1->1228/12271; loss:1.091754\n",
|
||||
"epoch:1/2; batch:1229->2456/12271; loss:0.992931\n",
|
||||
"epoch:1/2; batch:2457->3684/12271; loss:1.045146\n",
|
||||
"epoch:1/2; batch:3685->4912/12271; loss:0.799912\n",
|
||||
"epoch:1/2; batch:4913->6140/12271; loss:0.815425\n",
|
||||
"epoch:1/2; batch:6141->7368/12271; loss:0.564856\n",
|
||||
"epoch:1/2; batch:7369->8596/12271; loss:0.726981\n",
|
||||
"epoch:1/2; batch:8597->9824/12271; loss:0.764087\n",
|
||||
"epoch:1/2; batch:9825->11052/12271; loss:0.964115\n",
|
||||
"epoch:1/2; batch:11053->12271/12271; loss:0.502252\n",
|
||||
"epoch:2/2; batch:1->1228/12271; loss:0.601600\n",
|
||||
"epoch:2/2; batch:1229->2456/12271; loss:0.695099\n",
|
||||
"epoch:2/2; batch:2457->3684/12271; loss:0.419610\n",
|
||||
"epoch:2/2; batch:3685->4912/12271; loss:0.603106\n",
|
||||
"epoch:2/2; batch:4913->6140/12271; loss:0.705180\n",
|
||||
"epoch:2/2; batch:6141->7368/12271; loss:0.493404\n",
|
||||
"epoch:2/2; batch:7369->8596/12271; loss:0.864921\n",
|
||||
"epoch:2/2; batch:8597->9824/12271; loss:0.518601\n",
|
||||
"epoch:2/2; batch:9825->11052/12271; loss:0.395920\n",
|
||||
"epoch:2/2; batch:11053->12271/12271; loss:0.685858\n",
|
||||
"Training time : 9.520 hrs\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"classifier_multi = BERTSequenceClassifier(language=LANGUAGE_MULTI,\n",
|
||||
" num_labels=num_labels_hindi,\n",
|
||||
" cache_dir=CACHE_DIR)\n",
|
||||
"with Timer() as t:\n",
|
||||
" classifier_multi.fit(token_ids=train_token_ids_hindi,\n",
|
||||
" input_mask=train_input_mask_hindi,\n",
|
||||
" token_type_ids=train_token_type_ids_hindi,\n",
|
||||
" labels=train_labels_hindi,\n",
|
||||
" num_gpus=NUM_GPUS,\n",
|
||||
" num_epochs=NUM_EPOCHS,\n",
|
||||
" batch_size=BATCH_SIZE,\n",
|
||||
" lr=LEARNING_RATE,\n",
|
||||
" warmup_proportion=WARMUP_PROPORTION)\n",
|
||||
"print(\"Training time : {:.3f} hrs\".format(t.interval / 3600))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Predict and Evaluate"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"5024it [01:02, 87.10it/s] "
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Prediction time : 0.017 hrs\n",
|
||||
" precision recall f1-score support\n",
|
||||
"\n",
|
||||
"contradiction 0.69 0.72 0.70 1670\n",
|
||||
" entailment 0.74 0.51 0.60 1670\n",
|
||||
" neutral 0.58 0.74 0.65 1670\n",
|
||||
"\n",
|
||||
" accuracy 0.65 5010\n",
|
||||
" macro avg 0.67 0.65 0.65 5010\n",
|
||||
" weighted avg 0.67 0.65 0.65 5010\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"with Timer() as t:\n",
|
||||
" predictions_hindi = classifier_multi.predict(token_ids=test_token_ids_hindi,\n",
|
||||
" input_mask=test_input_mask_hindi,\n",
|
||||
" token_type_ids=test_token_type_ids_hindi,\n",
|
||||
" batch_size=BATCH_SIZE)\n",
|
||||
"print(\"Prediction time : {:.3f} hrs\".format(t.interval / 3600))\n",
|
||||
"predictions_hindi= label_encoder_hindi.inverse_transform(predictions_hindi)\n",
|
||||
"print(classification_report(test_df_hindi[LABEL_COL], predictions_hindi))"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "pytorch",
|
||||
"language": "python",
|
||||
"name": "pytorch"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.6.8"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
|
@ -340,6 +340,7 @@
|
|||
"- Add the special tokens [CLS] and [SEP] to mark the beginning and end of a sentence\n",
|
||||
"- Pad or truncate the token lists to the specified max length\n",
|
||||
"- Return mask lists that indicate paddings' positions\n",
|
||||
"- Return token type id lists that indicate which sentence the tokens belong to (not needed for one-sequence classification)\n",
|
||||
"\n",
|
||||
"*See the original [implementation](https://github.com/google-research/bert/blob/master/run_classifier.py) for more information on BERT's input format.*"
|
||||
]
|
||||
|
@ -350,10 +351,10 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tokens_train, mask_train = tokenizer.preprocess_classification_tokens(\n",
|
||||
"tokens_train, mask_train, _ = tokenizer.preprocess_classification_tokens(\n",
|
||||
" tokens_train, MAX_LEN\n",
|
||||
")\n",
|
||||
"tokens_test, mask_test = tokenizer.preprocess_classification_tokens(\n",
|
||||
"tokens_test, mask_test, _ = tokenizer.preprocess_classification_tokens(\n",
|
||||
" tokens_test, MAX_LEN\n",
|
||||
")"
|
||||
]
|
||||
|
@ -511,7 +512,7 @@
|
|||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.6.8"
|
||||
"version": "3.6.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
|
|
@ -275,6 +275,7 @@
|
|||
"- Add the special tokens [CLS] and [SEP] to mark the beginning and end of a sentence\n",
|
||||
"- Pad or truncate the token lists to the specified max length\n",
|
||||
"- Return mask lists that indicate paddings' positions\n",
|
||||
"- Return token type id lists that indicate which sentence the tokens belong to (not needed for one-sequence classification)\n",
|
||||
"\n",
|
||||
"*See the original [implementation](https://github.com/google-research/bert/blob/master/run_classifier.py) for more information on BERT's input format.*"
|
||||
]
|
||||
|
@ -285,10 +286,10 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tokens_train, mask_train = tokenizer.preprocess_classification_tokens(\n",
|
||||
"tokens_train, mask_train, _ = tokenizer.preprocess_classification_tokens(\n",
|
||||
" tokens_train, MAX_LEN\n",
|
||||
")\n",
|
||||
"tokens_test, mask_test = tokenizer.preprocess_classification_tokens(\n",
|
||||
"tokens_test, mask_test, _ = tokenizer.preprocess_classification_tokens(\n",
|
||||
" tokens_test, MAX_LEN\n",
|
||||
")"
|
||||
]
|
||||
|
@ -446,7 +447,7 @@
|
|||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.6.8"
|
||||
"version": "3.6.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
|
|
@ -1,10 +1,16 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# This script reuses some code from
|
||||
# https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/examples
|
||||
# /run_classifier.py
|
||||
|
||||
|
||||
from pytorch_pretrained_bert.tokenization import BertTokenizer
|
||||
from enum import Enum
|
||||
import warnings
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from torch.utils.data import (
|
||||
DataLoader,
|
||||
|
@ -47,27 +53,38 @@ class Tokenizer:
|
|||
def tokenize(self, text):
|
||||
"""Tokenizes a list of documents using a BERT tokenizer
|
||||
Args:
|
||||
text (list(str)): list of text documents.
|
||||
text (list): List of strings (one sequence) or
|
||||
tuples (two sequences).
|
||||
|
||||
Returns:
|
||||
[list(str)]: list of token lists.
|
||||
[list]: List of lists. Each sublist contains WordPiece tokens
|
||||
of the input sequence(s).
|
||||
"""
|
||||
tokens = [self.tokenizer.tokenize(x) for x in text]
|
||||
return tokens
|
||||
if isinstance(text[0], str):
|
||||
return [self.tokenizer.tokenize(x) for x in tqdm(text)]
|
||||
else:
|
||||
return [
|
||||
[self.tokenizer.tokenize(x) for x in sentences]
|
||||
for sentences in tqdm(text)
|
||||
]
|
||||
|
||||
def preprocess_classification_tokens(self, tokens, max_len=BERT_MAX_LEN):
|
||||
"""Preprocessing of input tokens:
|
||||
- add BERT sentence markers ([CLS] and [SEP])
|
||||
- map tokens to indices
|
||||
- map tokens to token indices in the BERT vocabulary
|
||||
- pad and truncate sequences
|
||||
- create an input_mask
|
||||
- create token type ids, aka. segment ids
|
||||
Args:
|
||||
tokens (list): List of tokens to preprocess.
|
||||
tokens (list): List of token lists to preprocess.
|
||||
max_len (int, optional): Maximum number of tokens
|
||||
(documents will be truncated or padded).
|
||||
Defaults to 512.
|
||||
Returns:
|
||||
list of preprocesssed token lists
|
||||
list of input mask lists
|
||||
tuple: A tuple containing the following three lists
|
||||
list of preprocesssed token lists
|
||||
list of input mask lists
|
||||
list of token type id lists
|
||||
"""
|
||||
if max_len > BERT_MAX_LEN:
|
||||
print(
|
||||
|
@ -77,15 +94,67 @@ class Tokenizer:
|
|||
)
|
||||
max_len = BERT_MAX_LEN
|
||||
|
||||
# truncate and add BERT sentence markers
|
||||
tokens = [["[CLS]"] + x[0 : max_len - 2] + ["[SEP]"] for x in tokens]
|
||||
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
|
||||
"""Truncates a sequence pair in place to the maximum length."""
|
||||
# This is a simple heuristic which will always truncate the longer
|
||||
# sequence one token at a time. This makes more sense than
|
||||
# truncating an equal percent of tokens from each, since if one
|
||||
# sequence is very short then each token that's truncated likely
|
||||
# contains more information than a longer sequence.
|
||||
while True:
|
||||
total_length = len(tokens_a) + len(tokens_b)
|
||||
if total_length <= max_length:
|
||||
break
|
||||
if len(tokens_a) > len(tokens_b):
|
||||
tokens_a.pop()
|
||||
else:
|
||||
tokens_b.pop()
|
||||
|
||||
tokens_a.append("[SEP]")
|
||||
tokens_b.append("[SEP]")
|
||||
|
||||
return [tokens_a, tokens_b]
|
||||
|
||||
if isinstance(tokens[0], str):
|
||||
tokens = [x[0 : max_len - 2] + ["[SEP]"] for x in tokens]
|
||||
token_type_ids = None
|
||||
else:
|
||||
# print(tokens[:2])
|
||||
# get tokens for each sentence [[t00, t01, ...] [t10, t11,... ]]
|
||||
tokens = [
|
||||
_truncate_seq_pair(sentence[0], sentence[1], max_len - 3)
|
||||
for sentence in tokens
|
||||
]
|
||||
|
||||
# construct token_type_ids
|
||||
# [[0, 0, 0, 0, ... 0, 1, 1, 1, ... 1], [0, 0, 0, ..., 1, 1, ]
|
||||
token_type_ids = [
|
||||
[[i] * len(sentence) for i, sentence in enumerate(example)]
|
||||
for example in tokens
|
||||
]
|
||||
# merge sentences
|
||||
tokens = [
|
||||
[token for sentence in example for token in sentence]
|
||||
for example in tokens
|
||||
]
|
||||
# prefix with [0] for [CLS]
|
||||
token_type_ids = [
|
||||
[0] + [i for sentence in example for i in sentence]
|
||||
for example in token_type_ids
|
||||
]
|
||||
# pad sequence
|
||||
token_type_ids = [
|
||||
x + [0] * (max_len - len(x)) for x in token_type_ids
|
||||
]
|
||||
|
||||
tokens = [["[CLS]"] + x for x in tokens]
|
||||
# convert tokens to indices
|
||||
tokens = [self.tokenizer.convert_tokens_to_ids(x) for x in tokens]
|
||||
# pad sequence
|
||||
tokens = [x + [0] * (max_len - len(x)) for x in tokens]
|
||||
# create input mask
|
||||
input_mask = [[min(1, x) for x in y] for y in tokens]
|
||||
return tokens, input_mask
|
||||
return tokens, input_mask, token_type_ids
|
||||
|
||||
def preprocess_ner_tokens(
|
||||
self,
|
||||
|
|
|
@ -47,10 +47,12 @@ class BERTSequenceClassifier:
|
|||
token_ids,
|
||||
input_mask,
|
||||
labels,
|
||||
token_type_ids=None,
|
||||
num_gpus=None,
|
||||
num_epochs=1,
|
||||
batch_size=32,
|
||||
lr=2e-5,
|
||||
warmup_proportion=None,
|
||||
verbose=True,
|
||||
):
|
||||
"""Fine-tunes the BERT classifier using the given training data.
|
||||
|
@ -58,8 +60,10 @@ class BERTSequenceClassifier:
|
|||
token_ids (list): List of training token id lists.
|
||||
input_mask (list): List of input mask lists.
|
||||
labels (list): List of training labels.
|
||||
device (str, optional): Device used for training ("cpu" or "gpu").
|
||||
Defaults to "gpu".
|
||||
token_type_ids (list, optional): List of lists. Each sublist
|
||||
contains segment ids indicating if the token belongs to
|
||||
the first sentence(0) or second sentence(1). Only needed
|
||||
for two-sentence tasks.
|
||||
num_gpus (int, optional): The number of gpus to use.
|
||||
If None is specified, all available GPUs
|
||||
will be used. Defaults to None.
|
||||
|
@ -67,6 +71,9 @@ class BERTSequenceClassifier:
|
|||
Defaults to 1.
|
||||
batch_size (int, optional): Training batch size. Defaults to 32.
|
||||
lr (float): Learning rate of the Adam optimizer. Defaults to 2e-5.
|
||||
warmup_proportion (float, optional): Proportion of training to
|
||||
perform linear learning rate warmup for. E.g., 0.1 = 10% of
|
||||
training. Defaults to None.
|
||||
verbose (bool, optional): If True, shows the training progress and
|
||||
loss values. Defaults to True.
|
||||
"""
|
||||
|
@ -95,16 +102,27 @@ class BERTSequenceClassifier:
|
|||
},
|
||||
]
|
||||
|
||||
opt = BertAdam(optimizer_grouped_parameters, lr=lr)
|
||||
num_examples = len(token_ids)
|
||||
num_batches = int(num_examples / batch_size)
|
||||
num_train_optimization_steps = num_batches * num_epochs
|
||||
|
||||
if warmup_proportion is None:
|
||||
opt = BertAdam(optimizer_grouped_parameters, lr=lr)
|
||||
else:
|
||||
opt = BertAdam(
|
||||
optimizer_grouped_parameters,
|
||||
lr=lr,
|
||||
t_total=num_train_optimization_steps,
|
||||
warmup=warmup_proportion,
|
||||
)
|
||||
|
||||
# define loss function
|
||||
loss_func = nn.CrossEntropyLoss().to(device)
|
||||
|
||||
# train
|
||||
self.model.train() # training mode
|
||||
num_examples = len(token_ids)
|
||||
num_batches = int(num_examples / batch_size)
|
||||
|
||||
token_type_ids_batch = None
|
||||
for epoch in range(num_epochs):
|
||||
for i in range(num_batches):
|
||||
|
||||
|
@ -121,11 +139,18 @@ class BERTSequenceClassifier:
|
|||
input_mask[start:end], dtype=torch.long, device=device
|
||||
)
|
||||
|
||||
if token_type_ids is not None:
|
||||
token_type_ids_batch = torch.tensor(
|
||||
token_type_ids[start:end],
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
|
||||
opt.zero_grad()
|
||||
|
||||
y_h = self.model(
|
||||
input_ids=x_batch,
|
||||
token_type_ids=None,
|
||||
token_type_ids=token_type_ids_batch,
|
||||
attention_mask=mask_batch,
|
||||
labels=None,
|
||||
)
|
||||
|
@ -146,14 +171,25 @@ class BERTSequenceClassifier:
|
|||
)
|
||||
)
|
||||
# empty cache
|
||||
del [x_batch, y_batch, mask_batch]
|
||||
del [x_batch, y_batch, mask_batch, token_type_ids_batch]
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, token_ids, input_mask, num_gpus=None, batch_size=32):
|
||||
def predict(
|
||||
self,
|
||||
token_ids,
|
||||
input_mask,
|
||||
token_type_ids=None,
|
||||
num_gpus=None,
|
||||
batch_size=32,
|
||||
):
|
||||
"""Scores the given dataset and returns the predicted classes.
|
||||
Args:
|
||||
token_ids (list): List of training token lists.
|
||||
input_mask (list): List of input mask lists.
|
||||
token_type_ids (list, optional): List of lists. Each sublist
|
||||
contains segment ids indicating if the token belongs to
|
||||
the first sentence(0) or second sentence(1). Only needed
|
||||
for two-sentence tasks.
|
||||
num_gpus (int, optional): The number of gpus to use.
|
||||
If None is specified, all available GPUs
|
||||
will be used. Defaults to None.
|
||||
|
@ -178,10 +214,17 @@ class BERTSequenceClassifier:
|
|||
mask_batch = torch.tensor(
|
||||
mask_batch, dtype=torch.long, device=device
|
||||
)
|
||||
token_type_ids_batch = None
|
||||
if token_type_ids is not None:
|
||||
token_type_ids_batch = torch.tensor(
|
||||
token_type_ids[i : i + batch_size],
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
with torch.no_grad():
|
||||
p_batch = self.model(
|
||||
input_ids=x_batch,
|
||||
token_type_ids=None,
|
||||
token_type_ids=token_type_ids_batch,
|
||||
attention_mask=mask_batch,
|
||||
labels=None,
|
||||
)
|
||||
|
|
|
@ -22,7 +22,8 @@ def to_lowercase_all(df):
|
|||
|
||||
def to_lowercase(df, column_names=[]):
|
||||
"""
|
||||
This function transforms strings of the column names in the dataframe passed to lowercase
|
||||
This function transforms strings of the column names in the dataframe
|
||||
passed to lowercase
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): Raw dataframe with some text columns.
|
||||
|
@ -46,18 +47,18 @@ def to_spacy_tokens(
|
|||
token_cols=["sentence1_tokens", "sentence2_tokens"],
|
||||
):
|
||||
"""
|
||||
This function tokenizes the sentence pairs using spaCy, defaulting to the
|
||||
spaCy en_core_web_sm model
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): Dataframe with columns sentence_cols to tokenize.
|
||||
sentence_cols (list, optional): Column names of the raw sentence pairs.
|
||||
token_cols (list, optional): Column names for the tokenized sentences.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: Dataframe with new columns token_cols, each containing
|
||||
a list of tokens for their respective sentences.
|
||||
"""
|
||||
This function tokenizes the sentence pairs using spaCy, defaulting to the
|
||||
spaCy en_core_web_sm model
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): Dataframe with columns sentence_cols to tokenize.
|
||||
sentence_cols (list, optional): Column names of the raw sentence pairs.
|
||||
token_cols (list, optional): Column names for the tokenized sentences.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: Dataframe with new columns token_cols, each containing
|
||||
a list of tokens for their respective sentences.
|
||||
"""
|
||||
nlp = spacy.load("en_core_web_sm")
|
||||
text_df = df[sentence_cols]
|
||||
nlp_df = text_df.applymap(lambda x: nlp(x))
|
||||
|
@ -77,21 +78,22 @@ def rm_spacy_stopwords(
|
|||
custom_stopwords=[],
|
||||
):
|
||||
"""
|
||||
This function tokenizes the sentence pairs using spaCy and remove stopwords,
|
||||
defaulting to the spaCy en_core_web_sm model
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): Dataframe with columns sentence_cols to tokenize.
|
||||
sentence_cols (list, optional): Column names for the raw sentence pairs.
|
||||
stop_cols (list, optional): Column names for the tokenized sentences
|
||||
without stop words.
|
||||
custom_stopwords (list of str, optional): List of custom stopwords to
|
||||
register with the spaCy model.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: Dataframe with new columns stop_cols, each containing a
|
||||
list of tokens for their respective sentences.
|
||||
"""
|
||||
This function tokenizes the sentence pairs using spaCy and remove
|
||||
stopwords, defaulting to the spaCy en_core_web_sm model
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): Dataframe with columns sentence_cols to tokenize.
|
||||
sentence_cols (list, optional): Column names for the raw sentence
|
||||
pairs.
|
||||
stop_cols (list, optional): Column names for the tokenized sentences
|
||||
without stop words.
|
||||
custom_stopwords (list of str, optional): List of custom stopwords to
|
||||
register with the spaCy model.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: Dataframe with new columns stop_cols, each containing a
|
||||
list of tokens for their respective sentences.
|
||||
"""
|
||||
nlp = spacy.load("en_core_web_sm")
|
||||
if len(custom_stopwords) > 0:
|
||||
for csw in custom_stopwords:
|
||||
|
@ -160,3 +162,13 @@ def rm_nltk_stopwords(
|
|||
|
||||
stop_df.columns = stop_cols
|
||||
return pd.concat([df, stop_df], axis=1)
|
||||
|
||||
|
||||
def convert_to_unicode(input_text):
|
||||
"""Converts intput_text to Unicode. Input must be utf-8."""
|
||||
if isinstance(input_text, str):
|
||||
return input_text
|
||||
elif isinstance(input_text, bytes):
|
||||
return input_text.decode("utf-8", "ignore")
|
||||
else:
|
||||
raise TypeError("Unsupported string type: %s" % (type(input_text)))
|
||||
|
|
|
@ -10,37 +10,86 @@ import os
|
|||
import pandas as pd
|
||||
|
||||
from utils_nlp.dataset.url_utils import extract_zip, maybe_download
|
||||
from utils_nlp.dataset.preprocess import convert_to_unicode
|
||||
|
||||
URL = "https://www.nyu.edu/projects/bowman/xnli/XNLI-1.0.zip"
|
||||
|
||||
DATA_FILES = {
|
||||
"dev": "XNLI-1.0/xnli.dev.jsonl",
|
||||
"test": "XNLI-1.0/xnli.test.jsonl",
|
||||
}
|
||||
URL_XNLI = "https://www.nyu.edu/projects/bowman/xnli/XNLI-1.0.zip"
|
||||
URL_XNLI_MT = "https://www.nyu.edu/projects/bowman/xnli/XNLI-MT-1.0.zip"
|
||||
|
||||
|
||||
def load_pandas_df(local_cache_path=None, file_split="dev"):
|
||||
"""Downloads and extracts the dataset files
|
||||
def load_pandas_df(local_cache_path="./", file_split="dev", language="zh"):
|
||||
"""Downloads and extracts the dataset files.
|
||||
|
||||
Args:
|
||||
local_cache_path ([type], optional): [description].
|
||||
Defaults to None.
|
||||
local_cache_path (str, optional): Path to store the data.
|
||||
Defaults to "./".
|
||||
file_split (str, optional): The subset to load.
|
||||
One of: {"dev", "test"}
|
||||
Defaults to "train".
|
||||
One of: {"train", "dev", "test"}
|
||||
Defaults to "dev".
|
||||
language (str, optional): language subset to read.
|
||||
One of: {"en", "fr", "es", "de", "el", "bg", "ru",
|
||||
"tr", "ar", "vi", "th", "zh", "hi", "sw", "ur"}
|
||||
Defaults to "zh" (Chinese).
|
||||
Returns:
|
||||
pd.DataFrame: pandas DataFrame containing the specified
|
||||
XNLI subset.
|
||||
"""
|
||||
|
||||
file_name = URL.split("/")[-1]
|
||||
maybe_download(URL, file_name, local_cache_path)
|
||||
if file_split in ("dev", "test"):
|
||||
url = URL_XNLI
|
||||
sentence_1_index = 6
|
||||
sentence_2_index = 7
|
||||
label_index = 1
|
||||
|
||||
if not os.path.exists(
|
||||
os.path.join(local_cache_path, DATA_FILES[file_split])
|
||||
):
|
||||
extract_zip(
|
||||
os.path.join(local_cache_path, file_name), local_cache_path
|
||||
zip_file_name = url.split("/")[-1]
|
||||
folder_name = ".".join(zip_file_name.split(".")[:-1])
|
||||
file_name = folder_name + "/" + ".".join(["xnli", file_split, "tsv"])
|
||||
elif file_split == "train":
|
||||
url = URL_XNLI_MT
|
||||
sentence_1_index = 0
|
||||
sentence_2_index = 1
|
||||
label_index = 2
|
||||
|
||||
zip_file_name = url.split("/")[-1]
|
||||
folder_name = ".".join(zip_file_name.split(".")[:-1])
|
||||
file_name = (
|
||||
folder_name
|
||||
+ "/multinli/"
|
||||
+ ".".join(["multinli", file_split, language, "tsv"])
|
||||
)
|
||||
return pd.read_json(
|
||||
os.path.join(local_cache_path, DATA_FILES[file_split]), lines=True
|
||||
)
|
||||
|
||||
maybe_download(url, zip_file_name, local_cache_path)
|
||||
|
||||
if not os.path.exists(os.path.join(local_cache_path, folder_name)):
|
||||
extract_zip(
|
||||
os.path.join(local_cache_path, zip_file_name), local_cache_path
|
||||
)
|
||||
|
||||
with open(
|
||||
os.path.join(local_cache_path, file_name), "r", encoding="utf-8"
|
||||
) as f:
|
||||
lines = f.read().splitlines()
|
||||
|
||||
line_list = [line.split("\t") for line in lines]
|
||||
# Remove the column name row
|
||||
line_list.pop(0)
|
||||
if file_split != "train":
|
||||
line_list = [line for line in line_list if line[0] == language]
|
||||
|
||||
label_list = [convert_to_unicode(line[label_index]) for line in line_list]
|
||||
old_contradict_label = convert_to_unicode("contradictory")
|
||||
new_contradict_label = convert_to_unicode("contradiction")
|
||||
label_list = [
|
||||
new_contradict_label if label == old_contradict_label else label
|
||||
for label in label_list
|
||||
]
|
||||
text_list = [
|
||||
(
|
||||
convert_to_unicode(line[sentence_1_index]),
|
||||
convert_to_unicode(line[sentence_2_index]),
|
||||
)
|
||||
for line in line_list
|
||||
]
|
||||
|
||||
df = pd.DataFrame({"text": text_list, "label": label_list})
|
||||
|
||||
return df
|
||||
|
|
Загрузка…
Ссылка в новой задаче