Updated utils of XNLI dataset.
This commit is contained in:
Parent
593bb4eb5d
Commit
ed3415b320
|
@ -6,41 +6,79 @@ https://www.nyu.edu/projects/bowman/xnli/
|
|||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from utils_nlp.dataset.url_utils import extract_zip, maybe_download
|
||||
from utils_nlp.dataset.preprocess import convert_to_unicode
|
||||
|
||||
# Archive with the human-annotated dev/test splits (all languages in one file).
URL_XNLI = "https://www.nyu.edu/projects/bowman/xnli/XNLI-1.0.zip"
# Archive with the machine-translated train split (one file per language).
URL_XNLI_MT = "https://www.nyu.edu/projects/bowman/xnli/XNLI-MT-1.0.zip"

# Deprecated alias kept for backward compatibility with older callers.
URL = URL_XNLI

# Relative paths of the JSONL files inside the XNLI-1.0 archive.
DATA_FILES = {
    "dev": "XNLI-1.0/xnli.dev.jsonl",
    "test": "XNLI-1.0/xnli.test.jsonl",
}
|
||||
|
||||
|
||||
def load_pandas_df(local_cache_path="./", file_split="dev", language="zh"):
    """Downloads and extracts the XNLI dataset files.

    Args:
        local_cache_path (str, optional): Path to store the data.
            Defaults to "./".
        file_split (str, optional): The subset to load.
            One of: {"train", "dev", "test"}
            Defaults to "dev".
        language (str, optional): language subset to read.
            One of: {"en", "fr", "es", "de", "el", "bg", "ru",
            "tr", "ar", "vi", "th", "zh", "hi", "sw", "ur"}
            Defaults to "zh" (Chinese).

    Returns:
        pd.DataFrame: pandas DataFrame containing the specified
            XNLI subset, with a "text" column of
            (premise, hypothesis) tuples and a "label" column.

    Raises:
        ValueError: If file_split is not one of
            {"train", "dev", "test"}.
    """
    if file_split in ("dev", "test"):
        # dev/test live in the main XNLI archive; in those TSVs the
        # language code is column 0, the label column 1, and the two
        # sentences columns 6 and 7.
        url = URL_XNLI
        sentence_1_index = 6
        sentence_2_index = 7
        label_index = 1

        zip_file_name = url.split("/")[-1]
        folder_name = ".".join(zip_file_name.split(".")[:-1])
        file_name = folder_name + "/" + ".".join(["xnli", file_split, "tsv"])
    elif file_split == "train":
        # train lives in the machine-translated archive, one TSV per
        # language; sentences are columns 0/1 and the label column 2.
        url = URL_XNLI_MT
        sentence_1_index = 0
        sentence_2_index = 1
        label_index = 2

        zip_file_name = url.split("/")[-1]
        folder_name = ".".join(zip_file_name.split(".")[:-1])
        file_name = (
            folder_name
            + "/multinli/"
            + ".".join(["multinli", file_split, language, "tsv"])
        )
    else:
        # Fail fast with a clear message instead of hitting a NameError
        # on `url` further down.
        raise ValueError(
            'file_split must be one of {"train", "dev", "test"}, got: '
            + repr(file_split)
        )

    maybe_download(url, zip_file_name, local_cache_path)

    # Extract only once; the extracted folder's presence marks completion.
    if not os.path.exists(os.path.join(local_cache_path, folder_name)):
        extract_zip(
            os.path.join(local_cache_path, zip_file_name), local_cache_path
        )

    with open(
        os.path.join(local_cache_path, file_name), "r", encoding="utf-8"
    ) as f:
        lines = f.read().splitlines()

    line_list = [line.split("\t") for line in lines]
    # Remove the column name row
    line_list.pop(0)
    if file_split != "train":
        # dev/test files contain every language; keep only the one requested
        line_list = [line for line in line_list if line[0] == language]

    label_list = [convert_to_unicode(line[label_index]) for line in line_list]
    # The MT train files spell the label "contradictory" where dev/test
    # use "contradiction"; normalize to the latter.
    old_contradict_label = convert_to_unicode("contradictory")
    new_contradict_label = convert_to_unicode("contradiction")
    label_list = [
        new_contradict_label if label == old_contradict_label else label
        for label in label_list
    ]
    text_list = [
        (
            convert_to_unicode(line[sentence_1_index]),
            convert_to_unicode(line[sentence_2_index]),
        )
        for line in line_list
    ]

    df = pd.DataFrame({"text": text_list, "label": label_list})

    return df


if __name__ == "__main__":
    load_pandas_df()
|
Loading…
Link in new issue