From f51188cbe74195c14c5b3e2e8f10c2f435f9751a Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Fri, 12 Feb 2021 17:18:21 +0530 Subject: [PATCH] [examples/run_s2s] remove task_specific_params and update rouge computation (#10133) * fix rouge metrics and task specific params * fix typo * round metrics * typo * remove task_specific_params --- examples/seq2seq/run_seq2seq.py | 63 +++++++++++++++++++++++---------- 1 file changed, 44 insertions(+), 19 deletions(-) diff --git a/examples/seq2seq/run_seq2seq.py b/examples/seq2seq/run_seq2seq.py index 70a635cc5..6a2ec50d8 100755 --- a/examples/seq2seq/run_seq2seq.py +++ b/examples/seq2seq/run_seq2seq.py @@ -25,10 +25,12 @@ import sys from dataclasses import dataclass, field from typing import Optional +import nltk # Here to have a nice missing dependency error message early on import numpy as np from datasets import load_dataset, load_metric import transformers +from filelock import FileLock from transformers import ( AutoConfig, AutoModelForSeq2SeqLM, @@ -44,6 +46,10 @@ from transformers import ( from transformers.trainer_utils import get_last_checkpoint, is_main_process +with FileLock(".lock") as lock: + nltk.download("punkt", quiet=True) + + logger = logging.getLogger(__name__) @@ -110,10 +116,22 @@ class DataTrainingArguments: default=None, metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."}, ) - train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."}) + train_file: Optional[str] = field( + default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."} + ) validation_file: Optional[str] = field( default=None, - metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, + metadata={ + "help": "An optional input evaluation data file to evaluate the metrics (rouge/sacreblue) on " + "(a jsonlines or csv file)." + }, + ) + test_file: Optional[str] = field( + default=None, + metadata={ + "help": "An optional input test data file to evaluate the metrics (rouge/sacreblue) on " + "(a jsonlines or csv file)." + }, ) overwrite_cache: bool = field( default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} @@ -298,6 +316,9 @@ def main(): if data_args.validation_file is not None: data_files["validation"] = data_args.validation_file extension = data_args.validation_file.split(".")[-1] + if data_args.test_file is not None: + data_files["test"] = data_args.test_file + extension = data_args.test_file.split(".")[-1] datasets = load_dataset(extension, data_files=data_files) # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at # https://huggingface.co/docs/datasets/loading_datasets.html. @@ -335,15 +356,7 @@ def main(): if model.config.decoder_start_token_id is None: raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined") - # Get the default prefix if None is passed. - if data_args.source_prefix is None: - task_specific_params = model.config.task_specific_params - if task_specific_params is not None: - prefix = task_specific_params.get("prefix", "") - else: - prefix = "" - else: - prefix = data_args.source_prefix + prefix = data_args.source_prefix if data_args.source_prefix is not None else "" # Preprocessing the datasets. # We need to tokenize inputs and targets. @@ -487,6 +500,19 @@ def main(): metric_name = "rouge" if data_args.task.startswith("summarization") else "sacrebleu" metric = load_metric(metric_name) + def postprocess_text(preds, labels): + preds = [pred.strip() for pred in preds] + labels = [label.strip() for label in labels] + + # rougeLSum expects newline after each sentence + if metric_name == "rouge": + preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds] + labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels] + else: # sacrebleu + labels = [[label] for label in labels] + + return preds, labels + def compute_metrics(eval_preds): preds, labels = eval_preds if isinstance(preds, tuple): @@ -498,22 +524,19 @@ def main(): decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) # Some simple post-processing - decoded_preds = [pred.strip() for pred in decoded_preds] - decoded_labels = [label.strip() for label in decoded_labels] - if metric_name == "sacrebleu": - decoded_labels = [[label] for label in decoded_labels] + decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels) - result = metric.compute(predictions=decoded_preds, references=decoded_labels) - - # Extract a few results from ROUGE if metric_name == "rouge": + result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True) + # Extract a few results from ROUGE result = {key: value.mid.fmeasure * 100 for key, value in result.items()} else: + result = metric.compute(predictions=decoded_preds, references=decoded_labels) result = {"bleu": result["score"]} prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds] result["gen_len"] = np.mean(prediction_lens) - + result = {k: round(v, 4) for k, v in result.items()} return result # Initialize our Trainer @@ -555,6 +578,7 @@ def main(): logger.info("*** Evaluate ***") results = trainer.evaluate(max_length=data_args.val_max_target_length, num_beams=data_args.num_beams) + results = {k: round(v, 4) for k, v in results.items()} output_eval_file = os.path.join(training_args.output_dir, "eval_results_seq2seq.txt") if trainer.is_world_process_zero(): @@ -574,6 +598,7 @@ def main(): num_beams=data_args.num_beams, ) test_metrics = test_results.metrics + test_metrics["test_loss"] = round(test_metrics["test_loss"], 4) output_test_result_file = os.path.join(training_args.output_dir, "test_results_seq2seq.txt") if trainer.is_world_process_zero():