diff --git a/utils_nlp/dataset/snli.py b/utils_nlp/dataset/snli.py index 6c8674d..5a0513d 100644 --- a/utils_nlp/dataset/snli.py +++ b/utils_nlp/dataset/snli.py @@ -139,17 +139,18 @@ def clean_snli(source_file_path, dest_file_path): "label2", "label3", "label4", - "label5" + "label5", ], - axis=1 + axis=1, ) snli_df = snli_df.rename(index=str, columns={"gold_label": "score"}) - snli_df.to_csv(dest_file_path, sep='\t') + snli_df.to_csv(dest_file_path, sep="\t") return snli_df + def load_azureml_df( local_cache_path=None, file_split=DEFAULT_FILE_SPLIT,