Feat: add functionality to download MNLI preprocessed TSV data.

Leverages the NYU Jiant toolkit's preprocessed TSV data source.
Parent: 150909f39e
Commit: 2091c3895e
@@ -22,6 +22,9 @@ from utils_nlp.models.transformers.common import MAX_SEQ_LEN
 from utils_nlp.models.transformers.sequence_classification import Processor
 
 URL = "http://www.nyu.edu/projects/bowman/multinli/multinli_1.0.zip"
+
+# Source - https://github.com/nyu-mll/jiant/blob/master/scripts/download_glue_data.py
+URL_JIANT_MNLI_TSV = "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FMNLI.zip?alt=media&token=50329ea1-e339-40e2-809c-10c40afff3ce"
 DATA_FILES = {
     "train": "multinli_1.0/multinli_1.0_train.jsonl",
     "dev_matched": "multinli_1.0/multinli_1.0_dev_matched.jsonl",
@@ -29,7 +32,9 @@ DATA_FILES = {
 }
 
 
-def download_file_and_extract(local_cache_path: str = ".", file_split: str = "train") -> None:
+def download_file_and_extract(
+    local_cache_path: str = ".", file_split: str = "train"
+) -> None:
     """Download and extract the dataset files
 
     Args:
@@ -46,6 +51,31 @@ def download_file_and_extract(local_cache_path: str = ".", file_split: str = "tr
     extract_zip(os.path.join(local_cache_path, file_name), local_cache_path)
 
 
+def download_tsv_files_and_extract(local_cache_path: str = ".") -> None:
+    """Download and extract the MNLI dataset files in tsv format from NYU Jiant.
+    Downloads both the original and the tsv-formatted data.
+
+    Args:
+        local_cache_path (str, optional): Directory to cache files to. Defaults to the current working directory.
+
+    Returns:
+        None.
+    """
+    try:
+        folder_name = "MNLI"
+        file_name = f"{folder_name}.zip"
+        maybe_download(URL_JIANT_MNLI_TSV, file_name, local_cache_path)
+        if not os.path.exists(os.path.join(local_cache_path, folder_name)):
+            extract_zip(os.path.join(local_cache_path, file_name), local_cache_path)
+
+        # Clean up zip download
+        if os.path.exists(os.path.join(local_cache_path, file_name)):
+            os.remove(os.path.join(local_cache_path, file_name))
+    except IOError as e:
+        raise e
+    print("Downloaded file to: ", os.path.join(local_cache_path, folder_name))
+
+
 def load_pandas_df(local_cache_path=".", file_split="train"):
     """Loads extracted dataset into pandas
     Args:
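For orientation, here is a minimal usage sketch of the new helper. The import path utils_nlp.dataset.multinli and the GLUE-style file names inside the extracted MNLI folder (e.g. train.tsv) are assumptions, not confirmed by this diff:

import os

import pandas as pd

from utils_nlp.dataset.multinli import download_tsv_files_and_extract  # assumed module path

cache_dir = "./temp"
os.makedirs(cache_dir, exist_ok=True)

# Downloads MNLI.zip from the jiant mirror, extracts it to ./temp/MNLI,
# and removes the zip afterwards (per the function body above).
download_tsv_files_and_extract(local_cache_path=cache_dir)

# GLUE-style archives ship tab-separated files; the exact file name is an
# assumption based on the standard GLUE MNLI layout.
train_df = pd.read_csv(
    os.path.join(cache_dir, "MNLI", "train.tsv"), sep="\t", quoting=3  # csv.QUOTE_NONE
)
print(train_df.head())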
@@ -61,10 +91,18 @@ def load_pandas_df(local_cache_path=".", file_split="train"):
         download_file_and_extract(local_cache_path, file_split)
     except Exception as e:
         raise e
-    return pd.read_json(os.path.join(local_cache_path, DATA_FILES[file_split]), lines=True)
+    return pd.read_json(
+        os.path.join(local_cache_path, DATA_FILES[file_split]), lines=True
+    )
 
 
-def get_generator(local_cache_path=".", file_split="train", block_size=10e6, batch_size=10e6, num_batches=None):
+def get_generator(
+    local_cache_path=".",
+    file_split="train",
+    block_size=10e6,
+    batch_size=10e6,
+    num_batches=None,
+):
     """ Returns an extracted dataset as a random batch generator that
     yields pandas dataframes.
     Args:
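By way of comparison with the TSV path, a sketch of the existing JSONL loader this hunk reformats (import path assumed as above):

from utils_nlp.dataset.multinli import load_pandas_df  # assumed module path

# Downloads/extracts multinli_1.0.zip if needed, then reads the jsonl split
# listed in DATA_FILES into a dataframe.
dev_df = load_pandas_df(local_cache_path="./temp", file_split="dev_matched")

# Columns follow the MNLI jsonl schema, e.g. sentence1, sentence2, gold_label.
print(dev_df.shape)
print(dev_df[["sentence1", "sentence2", "gold_label"]].head())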
@@ -84,9 +122,13 @@ def get_generator(local_cache_path=".", file_split="train", block_size=10e6, bat
     except Exception as e:
         raise e
 
-    loader = DaskJSONLoader(os.path.join(local_cache_path, DATA_FILES[file_split]), block_size=block_size)
+    loader = DaskJSONLoader(
+        os.path.join(local_cache_path, DATA_FILES[file_split]), block_size=block_size
+    )
 
-    return loader.get_sequential_batches(batch_size=int(batch_size), num_batches=num_batches)
+    return loader.get_sequential_batches(
+        batch_size=int(batch_size), num_batches=num_batches
+    )
 
 
 def load_tc_dataset(
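A sketch of consuming the generator returned here; num_batches and the byte sizes are illustrative values:

from utils_nlp.dataset.multinli import get_generator  # assumed module path

# Stream the training split through Dask in ~10 MB blocks and take two
# sequential pandas batches.
batch_gen = get_generator(
    local_cache_path="./temp",
    file_split="train",
    block_size=10e6,
    batch_size=10e6,
    num_batches=2,
)
for batch_df in batch_gen:
    print(len(batch_df))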
@@ -161,17 +203,23 @@ def load_tc_dataset(
     label_encoder.fit(all_df[label_col])
 
     if test_fraction < 0 or test_fraction >= 1.0:
-        logging.warning("Invalid test fraction value: {}, changed to 0.25".format(test_fraction))
+        logging.warning(
+            "Invalid test fraction value: {}, changed to 0.25".format(test_fraction)
+        )
         test_fraction = 0.25
 
-    train_df, test_df = train_test_split(all_df, train_size=(1.0 - test_fraction), random_state=random_seed)
+    train_df, test_df = train_test_split(
+        all_df, train_size=(1.0 - test_fraction), random_state=random_seed
+    )
 
     if train_sample_ratio > 1.0:
         train_sample_ratio = 1.0
         logging.warning("Setting the training sample ratio to 1.0")
     elif train_sample_ratio < 0:
         logging.error("Invalid training sample ratio: {}".format(train_sample_ratio))
-        raise ValueError("Invalid training sample ratio: {}".format(train_sample_ratio))
+        raise ValueError(
+            "Invalid training sample ratio: {}".format(train_sample_ratio)
+        )
 
     if test_sample_ratio > 1.0:
         test_sample_ratio = 1.0
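The validation above clamps or rejects the sample ratios before the split; a small self-contained illustration of the train_test_split call that follows (data and values illustrative):

import pandas as pd
from sklearn.model_selection import train_test_split

all_df = pd.DataFrame({"text": list("abcdefgh"), "label": [0, 1] * 4})
test_fraction = 0.25

# Mirrors the call in the hunk: 75/25 split, reproducible via random_state.
train_df, test_df = train_test_split(
    all_df, train_size=(1.0 - test_fraction), random_state=42
)
print(len(train_df), len(test_df))  # 6 2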
@@ -195,12 +243,16 @@ def load_tc_dataset(
     train_dataset = processor.dataset_from_dataframe(
         df=train_df, text_col=text_col, label_col=label_col, max_len=max_len,
     )
-    train_dataloader = dataloader_from_dataset(train_dataset, batch_size=batch_size, num_gpus=num_gpus, shuffle=True)
+    train_dataloader = dataloader_from_dataset(
+        train_dataset, batch_size=batch_size, num_gpus=num_gpus, shuffle=True
+    )
 
     test_dataset = processor.dataset_from_dataframe(
         df=test_df, text_col=text_col, label_col=label_col, max_len=max_len,
     )
-    test_dataloader = dataloader_from_dataset(test_dataset, batch_size=batch_size, num_gpus=num_gpus, shuffle=False)
+    test_dataloader = dataloader_from_dataset(
+        test_dataset, batch_size=batch_size, num_gpus=num_gpus, shuffle=False
+    )
 
     return (train_dataloader, test_dataloader, label_encoder, test_labels)
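End to end, the tuple returned by load_tc_dataset feeds directly into a fine-tuning loop. A hedged sketch; keyword arguments other than those visible in the hunks above (test_fraction, batch_size, num_gpus) are assumptions:

from utils_nlp.dataset.multinli import load_tc_dataset  # assumed module path

train_dataloader, test_dataloader, label_encoder, test_labels = load_tc_dataset(
    local_cache_path="./temp",  # assumed parameter name
    test_fraction=0.25,
    batch_size=32,
    num_gpus=1,
)

# label_encoder is the fitted encoder from the hunk above; test_labels holds
# the held-out encoded labels for evaluation.
first_batch = next(iter(train_dataloader))
print(type(first_batch), len(test_labels))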