[examples/run_s2s] remove task_specific_params and update rouge computation (#10133)
* fix rouge metrics and task specific params
* fix typo
* round metrics
* typo
* remove task_specific_params
Parent: 31245775e5
Commit: f51188cbe7
@@ -25,10 +25,12 @@ import sys
 from dataclasses import dataclass, field
 from typing import Optional
 
+import nltk  # Here to have a nice missing dependency error message early on
 import numpy as np
 from datasets import load_dataset, load_metric
 
 import transformers
+from filelock import FileLock
 from transformers import (
     AutoConfig,
     AutoModelForSeq2SeqLM,
@@ -44,6 +46,10 @@ from transformers import (
 from transformers.trainer_utils import get_last_checkpoint, is_main_process
 
 
+with FileLock(".lock") as lock:
+    nltk.download("punkt", quiet=True)
+
+
 logger = logging.getLogger(__name__)
 
 
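The FileLock guard exists because several training processes (e.g. workers launched through torch.distributed) may import this script at the same time, and only one of them should download the punkt tokenizer data. A minimal sketch of the pattern, assuming `nltk` and `filelock` are installed:

    import nltk
    from filelock import FileLock

    # One process acquires the lock and downloads; the others block until it
    # is released, then find the data already on disk.
    with FileLock(".lock"):
        nltk.download("punkt", quiet=True)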
@@ -110,10 +116,22 @@ class DataTrainingArguments:
         default=None,
         metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
     )
-    train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
+    train_file: Optional[str] = field(
+        default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."}
+    )
     validation_file: Optional[str] = field(
         default=None,
-        metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
+        metadata={
+            "help": "An optional input evaluation data file to evaluate the metrics (rouge/sacrebleu) on "
+            "(a jsonlines or csv file)."
+        },
+    )
+    test_file: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "An optional input test data file to evaluate the metrics (rouge/sacrebleu) on "
+            "(a jsonlines or csv file)."
+        },
     )
     overwrite_cache: bool = field(
         default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
@@ -298,6 +316,9 @@ def main():
         if data_args.validation_file is not None:
             data_files["validation"] = data_args.validation_file
             extension = data_args.validation_file.split(".")[-1]
+        if data_args.test_file is not None:
+            data_files["test"] = data_args.test_file
+            extension = data_args.test_file.split(".")[-1]
         datasets = load_dataset(extension, data_files=data_files)
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
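With the new `--test_file` argument, the extension of the last file seen doubles as the `datasets` loader name. A sketch of the resulting call (file names are hypothetical):

    from datasets import load_dataset

    data_files = {
        "train": "train.json",       # --train_file
        "validation": "val.json",    # --validation_file
        "test": "test.json",         # --test_file, newly supported here
    }
    # extension = "json" is derived from the file name and selects the loader
    datasets = load_dataset("json", data_files=data_files)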
@@ -335,15 +356,7 @@ def main():
     if model.config.decoder_start_token_id is None:
         raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
 
-    # Get the default prefix if None is passed.
-    if data_args.source_prefix is None:
-        task_specific_params = model.config.task_specific_params
-        if task_specific_params is not None:
-            prefix = task_specific_params.get("prefix", "")
-        else:
-            prefix = ""
-    else:
-        prefix = data_args.source_prefix
+    prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
 
     # Preprocessing the datasets.
     # We need to tokenize inputs and targets.
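The nine-line branch that consulted `model.config.task_specific_params` collapses into one conditional expression. The two agree whenever `--source_prefix` is given; they differ only when it is omitted for a checkpoint that ships a task prefix (e.g. T5's "summarize: "), which now yields an empty prefix. A small sketch of the equivalence (helper names are illustrative only):

    def old_prefix(source_prefix, task_specific_params):
        # behavior removed by this commit
        if source_prefix is None:
            if task_specific_params is not None:
                return task_specific_params.get("prefix", "")
            return ""
        return source_prefix

    def new_prefix(source_prefix):
        return source_prefix if source_prefix is not None else ""

    assert old_prefix("summarize: ", None) == new_prefix("summarize: ")
    # Only divergence: config-supplied prefixes are no longer picked up,
    # so T5-style models now need an explicit --source_prefix.
    assert old_prefix(None, {"prefix": "summarize: "}) == "summarize: "
    assert new_prefix(None) == ""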
@@ -487,6 +500,19 @@ def main():
     metric_name = "rouge" if data_args.task.startswith("summarization") else "sacrebleu"
     metric = load_metric(metric_name)
 
+    def postprocess_text(preds, labels):
+        preds = [pred.strip() for pred in preds]
+        labels = [label.strip() for label in labels]
+
+        # rougeLSum expects newline after each sentence
+        if metric_name == "rouge":
+            preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
+            labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
+        else:  # sacrebleu
+            labels = [[label] for label in labels]
+
+        return preds, labels
+
     def compute_metrics(eval_preds):
         preds, labels = eval_preds
         if isinstance(preds, tuple):
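The newline handling matters because rougeLsum is computed sentence by sentence, so each sentence must sit on its own line. A quick illustration of what `postprocess_text` does to a prediction (requires the punkt data downloaded above):

    import nltk

    pred = "The cat sat on the mat. It then purred loudly."
    print("\n".join(nltk.sent_tokenize(pred)))
    # The cat sat on the mat.
    # It then purred loudly.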
@@ -498,22 +524,19 @@ def main():
         decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
 
         # Some simple post-processing
-        decoded_preds = [pred.strip() for pred in decoded_preds]
-        decoded_labels = [label.strip() for label in decoded_labels]
+        decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
 
-        if metric_name == "sacrebleu":
-            decoded_labels = [[label] for label in decoded_labels]
-
-        result = metric.compute(predictions=decoded_preds, references=decoded_labels)
-
-        # Extract a few results from ROUGE
         if metric_name == "rouge":
+            result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
+            # Extract a few results from ROUGE
             result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
         else:
+            result = metric.compute(predictions=decoded_preds, references=decoded_labels)
             result = {"bleu": result["score"]}
 
         prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
         result["gen_len"] = np.mean(prediction_lens)
+        result = {k: round(v, 4) for k, v in result.items()}
         return result
 
     # Initialize our Trainer
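The `value.mid.fmeasure * 100` extraction works because `load_metric("rouge")` returns, for each ROUGE variant, an aggregate with low/mid/high score objects carrying precision, recall, and fmeasure. A minimal sketch (the perfect scores shown are just the degenerate identical pred/ref case):

    from datasets import load_metric

    rouge = load_metric("rouge")
    result = rouge.compute(
        predictions=["the cat sat on the mat"],
        references=["the cat sat on the mat"],
        use_stemmer=True,
    )
    # Each value is an AggregateScore; keep the mid (median) F1 as a percentage.
    print({k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()})
    # {'rouge1': 100.0, 'rouge2': 100.0, 'rougeL': 100.0, 'rougeLsum': 100.0}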
@@ -555,6 +578,7 @@ def main():
         logger.info("*** Evaluate ***")
 
         results = trainer.evaluate(max_length=data_args.val_max_target_length, num_beams=data_args.num_beams)
+        results = {k: round(v, 4) for k, v in results.items()}
 
         output_eval_file = os.path.join(training_args.output_dir, "eval_results_seq2seq.txt")
         if trainer.is_world_process_zero():
@@ -574,6 +598,7 @@ def main():
             num_beams=data_args.num_beams,
         )
         test_metrics = test_results.metrics
+        test_metrics["test_loss"] = round(test_metrics["test_loss"], 4)
 
         output_test_result_file = os.path.join(training_args.output_dir, "test_results_seq2seq.txt")
         if trainer.is_world_process_zero():