Better filtering of the model outputs in Trainer (#8633)
* Better filtering of the model outputs in Trainer * Fix examples tests * Add test for Lysandre
This commit is contained in:
Родитель
f2e07e7272
Коммит
4208f496ee
|
@ -1,4 +1,4 @@
|
|||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
@ -153,7 +153,11 @@ class Seq2SeqTrainer(Trainer):
|
|||
return loss
|
||||
|
||||
def prediction_step(
|
||||
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool
|
||||
self,
|
||||
model: nn.Module,
|
||||
inputs: Dict[str, Union[torch.Tensor, Any]],
|
||||
prediction_loss_only: bool,
|
||||
ignore_keys: Optional[List[str]] = None,
|
||||
) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
"""
|
||||
Perform an evaluation step on :obj:`model` using obj:`inputs`.
|
||||
|
|
|
@ -43,6 +43,8 @@ class PretrainedConfig(object):
|
|||
- **is_composition** (:obj:`bool`): Whether the config class is composed of multiple sub-configs. In this case
|
||||
the config has to be initialized from two or more configs of type :class:`~transformers.PretrainedConfig`
|
||||
like: :class:`~transformers.EncoderDecoderConfig` or :class:`~RagConfig`.
|
||||
- **keys_to_ignore_at_inference** (:obj:`List[str]`): A list of keys to ignore by default when looking at
|
||||
dictionary outputs of the model during inference.
|
||||
|
||||
Args:
|
||||
name_or_path (:obj:`str`, `optional`, defaults to :obj:`""`):
|
||||
|
|
|
@ -110,6 +110,7 @@ class BartConfig(PretrainedConfig):
|
|||
:obj:`True` for `bart-large-cnn`.
|
||||
"""
|
||||
model_type = "bart"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
|
@ -77,6 +77,7 @@ class CTRLConfig(PretrainedConfig):
|
|||
"""
|
||||
|
||||
model_type = "ctrl"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
|
@ -120,6 +120,7 @@ class GPT2Config(PretrainedConfig):
|
|||
"""
|
||||
|
||||
model_type = "gpt2"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
|
@ -97,3 +97,4 @@ class MarianConfig(BartConfig):
|
|||
"""
|
||||
|
||||
model_type = "marian"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
|
|
@ -102,3 +102,4 @@ class MBartConfig(BartConfig):
|
|||
"""
|
||||
|
||||
model_type = "mbart"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
|
|
@ -62,6 +62,7 @@ class MT5Config(PretrainedConfig):
|
|||
Type of feed forward layer to be used. Should be one of :obj:`"relu"` or :obj:`"gated-gelu"`.
|
||||
"""
|
||||
model_type = "mt5"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
|
@ -141,4 +141,5 @@ class PegasusConfig(BartConfig):
|
|||
"""
|
||||
|
||||
model_type = "pegasus"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
# The implementation of the config object is in BartConfig
|
||||
|
|
|
@ -92,6 +92,7 @@ class ProphetNetConfig(PretrainedConfig):
|
|||
smoothing is performed.
|
||||
"""
|
||||
model_type = "prophetnet"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
|
@ -153,6 +153,7 @@ class ReformerConfig(PretrainedConfig):
|
|||
>>> configuration = model.config
|
||||
"""
|
||||
model_type = "reformer"
|
||||
keys_to_ignore_at_inference = ["past_buckets_states"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
|
@ -71,6 +71,7 @@ class T5Config(PretrainedConfig):
|
|||
the :obj:`"gated-gelu"` feed forward projection. Original T5 uses :obj:`"relu"`.
|
||||
"""
|
||||
model_type = "t5"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
|
@ -105,6 +105,7 @@ class TransfoXLConfig(PretrainedConfig):
|
|||
"""
|
||||
|
||||
model_type = "transfo-xl"
|
||||
keys_to_ignore_at_inference = ["mems"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
|
@ -128,6 +128,7 @@ class XLNetConfig(PretrainedConfig):
|
|||
"""
|
||||
|
||||
model_type = "xlnet"
|
||||
keys_to_ignore_at_inference = ["mems"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
|
@ -1098,10 +1098,11 @@ class Trainer:
|
|||
"""
|
||||
outputs = model(**inputs)
|
||||
# Save past state if it exists
|
||||
# TODO: this needs to be fixed and made cleaner later.
|
||||
if self.args.past_index >= 0:
|
||||
self._past = outputs[self.args.past_index]
|
||||
# We don't use .loss here since the model may return tuples instead of ModelOutput.
|
||||
return outputs[0]
|
||||
return outputs["loss"] if isinstance(outputs, dict) else outputs[0]
|
||||
|
||||
def is_local_process_zero(self) -> bool:
|
||||
"""
|
||||
|
@ -1220,7 +1221,9 @@ class Trainer:
|
|||
logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint))
|
||||
shutil.rmtree(checkpoint)
|
||||
|
||||
def evaluate(self, eval_dataset: Optional[Dataset] = None) -> Dict[str, float]:
|
||||
def evaluate(
|
||||
self, eval_dataset: Optional[Dataset] = None, ignore_keys: Optional[List[str]] = None
|
||||
) -> Dict[str, float]:
|
||||
"""
|
||||
Run evaluation and returns metrics.
|
||||
|
||||
|
@ -1234,6 +1237,9 @@ class Trainer:
|
|||
Pass a dataset if you wish to override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`,
|
||||
columns not accepted by the ``model.forward()`` method are automatically removed. It must implement the
|
||||
:obj:`__len__` method.
|
||||
ignore_keys (:obj:`Lst[str]`, `optional`):
|
||||
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
|
||||
gathering predictions.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
|
||||
|
@ -1250,6 +1256,7 @@ class Trainer:
|
|||
# No point gathering the predictions if there are no metrics, otherwise we defer to
|
||||
# self.args.prediction_loss_only
|
||||
prediction_loss_only=True if self.compute_metrics is None else None,
|
||||
ignore_keys=ignore_keys,
|
||||
)
|
||||
|
||||
self.log(output.metrics)
|
||||
|
@ -1261,7 +1268,7 @@ class Trainer:
|
|||
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)
|
||||
return output.metrics
|
||||
|
||||
def predict(self, test_dataset: Dataset) -> PredictionOutput:
|
||||
def predict(self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None) -> PredictionOutput:
|
||||
"""
|
||||
Run prediction and returns predictions and potential metrics.
|
||||
|
||||
|
@ -1272,6 +1279,9 @@ class Trainer:
|
|||
test_dataset (:obj:`Dataset`):
|
||||
Dataset to run the predictions on. If it is an :obj:`datasets.Dataset`, columns not accepted by the
|
||||
``model.forward()`` method are automatically removed. Has to implement the method :obj:`__len__`
|
||||
ignore_keys (:obj:`Lst[str]`, `optional`):
|
||||
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
|
||||
gathering predictions.
|
||||
|
||||
.. note::
|
||||
|
||||
|
@ -1291,10 +1301,14 @@ class Trainer:
|
|||
|
||||
test_dataloader = self.get_test_dataloader(test_dataset)
|
||||
|
||||
return self.prediction_loop(test_dataloader, description="Prediction")
|
||||
return self.prediction_loop(test_dataloader, description="Prediction", ignore_keys=ignore_keys)
|
||||
|
||||
def prediction_loop(
|
||||
self, dataloader: DataLoader, description: str, prediction_loss_only: Optional[bool] = None
|
||||
self,
|
||||
dataloader: DataLoader,
|
||||
description: str,
|
||||
prediction_loss_only: Optional[bool] = None,
|
||||
ignore_keys: Optional[List[str]] = None,
|
||||
) -> PredictionOutput:
|
||||
"""
|
||||
Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`.
|
||||
|
@ -1346,7 +1360,7 @@ class Trainer:
|
|||
self.callback_handler.eval_dataloader = dataloader
|
||||
|
||||
for step, inputs in enumerate(dataloader):
|
||||
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
|
||||
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
|
||||
if loss is not None:
|
||||
losses = loss.repeat(batch_size)
|
||||
losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
|
||||
|
@ -1410,7 +1424,11 @@ class Trainer:
|
|||
return nested_numpify(tensors)
|
||||
|
||||
def prediction_step(
|
||||
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool
|
||||
self,
|
||||
model: nn.Module,
|
||||
inputs: Dict[str, Union[torch.Tensor, Any]],
|
||||
prediction_loss_only: bool,
|
||||
ignore_keys: Optional[List[str]] = None,
|
||||
) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
"""
|
||||
Perform an evaluation step on :obj:`model` using obj:`inputs`.
|
||||
|
@ -1427,6 +1445,9 @@ class Trainer:
|
|||
argument :obj:`labels`. Check your model's documentation for all accepted arguments.
|
||||
prediction_loss_only (:obj:`bool`):
|
||||
Whether or not to return the loss only.
|
||||
ignore_keys (:obj:`Lst[str]`, `optional`):
|
||||
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
|
||||
gathering predictions.
|
||||
|
||||
Return:
|
||||
Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
|
||||
|
@ -1434,6 +1455,11 @@ class Trainer:
|
|||
"""
|
||||
has_labels = all(inputs.get(k) is not None for k in self.label_names)
|
||||
inputs = self._prepare_inputs(inputs)
|
||||
if ignore_keys is None:
|
||||
if hasattr(self.model, "config"):
|
||||
ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
|
||||
else:
|
||||
ignore_keys = []
|
||||
|
||||
with torch.no_grad():
|
||||
if self.args.fp16 and _use_native_amp:
|
||||
|
@ -1442,16 +1468,21 @@ class Trainer:
|
|||
else:
|
||||
outputs = model(**inputs)
|
||||
if has_labels:
|
||||
loss = outputs[0].mean().detach()
|
||||
logits = outputs[1:]
|
||||
if isinstance(outputs, dict):
|
||||
loss = outputs["loss"].mean().detach()
|
||||
logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
|
||||
else:
|
||||
loss = outputs[0].mean().detach()
|
||||
logits = outputs[1:]
|
||||
else:
|
||||
loss = None
|
||||
# Slicing so we get a tuple even if `outputs` is a `ModelOutput`.
|
||||
logits = outputs[:]
|
||||
if isinstance(outputs, dict):
|
||||
logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
|
||||
else:
|
||||
logits = outputs
|
||||
# TODO: this needs to be fixed and made cleaner later.
|
||||
if self.args.past_index >= 0:
|
||||
self._past = outputs[self.args.past_index if has_labels else self.args.past_index - 1]
|
||||
# Remove the past from the logits.
|
||||
logits = logits[: self.args.past_index - 1] + logits[self.args.past_index :]
|
||||
|
||||
if prediction_loss_only:
|
||||
return (loss, None, None)
|
||||
|
|
|
@ -44,6 +44,8 @@ if is_torch_available():
|
|||
DataCollatorForLanguageModeling,
|
||||
GlueDataset,
|
||||
GlueDataTrainingArguments,
|
||||
GPT2Config,
|
||||
GPT2LMHeadModel,
|
||||
LineByLineTextDataset,
|
||||
PreTrainedModel,
|
||||
TextDataset,
|
||||
|
@ -73,6 +75,18 @@ class RegressionDataset:
|
|||
return result
|
||||
|
||||
|
||||
class RepeatDataset:
|
||||
def __init__(self, x, length=64):
|
||||
self.x = x
|
||||
self.length = length
|
||||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
|
||||
def __getitem__(self, i):
|
||||
return {"input_ids": self.x, "labels": self.x}
|
||||
|
||||
|
||||
class DynamicShapesDataset:
|
||||
def __init__(self, length=64, seed=42, batch_size=8):
|
||||
self.length = length
|
||||
|
@ -136,6 +150,20 @@ if is_torch_available():
|
|||
loss = torch.nn.functional.mse_loss(y, labels)
|
||||
return (loss, y, y) if self.double_output else (loss, y)
|
||||
|
||||
class RegressionDictModel(torch.nn.Module):
|
||||
def __init__(self, a=0, b=0):
|
||||
super().__init__()
|
||||
self.a = torch.nn.Parameter(torch.tensor(a).float())
|
||||
self.b = torch.nn.Parameter(torch.tensor(b).float())
|
||||
self.config = None
|
||||
|
||||
def forward(self, input_x=None, labels=None, **kwargs):
|
||||
y = input_x * self.a + self.b
|
||||
result = {"output": y}
|
||||
if labels is not None:
|
||||
result["loss"] = torch.nn.functional.mse_loss(y, labels)
|
||||
return result
|
||||
|
||||
class RegressionPreTrainedModel(PreTrainedModel):
|
||||
config_class = RegressionModelConfig
|
||||
base_model_prefix = "regression"
|
||||
|
@ -236,6 +264,33 @@ class TrainerIntegrationTest(unittest.TestCase):
|
|||
metrics = trainer.evaluate()
|
||||
self.assertEqual(metrics[metric], best_value)
|
||||
|
||||
def test_trainer_works_with_dict(self):
|
||||
# Edge case because Apex with mode O2 will change our models to return dicts. This test checks it doesn't break
|
||||
# anything.
|
||||
train_dataset = RegressionDataset()
|
||||
eval_dataset = RegressionDataset()
|
||||
model = RegressionDictModel()
|
||||
args = TrainingArguments("./regression")
|
||||
trainer = Trainer(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset)
|
||||
trainer.train()
|
||||
_ = trainer.evaluate()
|
||||
_ = trainer.predict(eval_dataset)
|
||||
|
||||
def test_evaluation_with_keys_to_drop(self):
|
||||
config = GPT2Config(vocab_size=100, n_positions=128, n_ctx=128, n_embd=32, n_layer=3, n_head=4)
|
||||
tiny_gpt2 = GPT2LMHeadModel(config)
|
||||
x = torch.randint(0, 100, (128,))
|
||||
eval_dataset = RepeatDataset(x)
|
||||
args = TrainingArguments("./test")
|
||||
trainer = Trainer(tiny_gpt2, args, eval_dataset=eval_dataset)
|
||||
# By default the past_key_values are removed
|
||||
result = trainer.predict(eval_dataset)
|
||||
self.assertTrue(isinstance(result.predictions, np.ndarray))
|
||||
# We can still get them by setting ignore_keys to []
|
||||
result = trainer.predict(eval_dataset, ignore_keys=[])
|
||||
self.assertTrue(isinstance(result.predictions, tuple))
|
||||
self.assertEqual(len(result.predictions), 2)
|
||||
|
||||
def test_training_arguments_are_left_untouched(self):
|
||||
trainer = get_regression_trainer()
|
||||
trainer.train()
|
||||
|
|
Загрузка…
Ссылка в новой задаче