Make default_data_collator more flexible and deprecate old behavior (#5060)
* Make default_data_collator more flexible * Accept tensors for all features * Document code * Refactor * Formatting
This commit is contained in:
Родитель
5e06963394
Коммит
20fa828984
|
@ -33,31 +33,34 @@ def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Ten
|
|||
# have the same attributes.
|
||||
# So we will look at the first element as a proxy for what attributes exist
|
||||
# on the whole batch.
|
||||
if not isinstance(features[0], dict):
|
||||
features = [vars(f) for f in features]
|
||||
|
||||
first = features[0]
|
||||
batch = {}
|
||||
|
||||
# Special handling for labels.
|
||||
# Ensure that tensor is created with the correct type
|
||||
# (it should be automatically the case, but let's make sure of it.)
|
||||
if hasattr(first, "label") and first.label is not None:
|
||||
if type(first.label) is int:
|
||||
labels = torch.tensor([f.label for f in features], dtype=torch.long)
|
||||
if "label" in first:
|
||||
dtype = torch.long if type(first["label"]) is int else torch.float
|
||||
batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
|
||||
elif "label_ids" in first:
|
||||
if isinstance(first["label_ids"], torch.Tensor):
|
||||
batch["labels"] = torch.stack([f["label_ids"] for f in features])
|
||||
else:
|
||||
labels = torch.tensor([f.label for f in features], dtype=torch.float)
|
||||
batch = {"labels": labels}
|
||||
elif hasattr(first, "label_ids") and first.label_ids is not None:
|
||||
if type(first.label_ids[0]) is int:
|
||||
labels = torch.tensor([f.label_ids for f in features], dtype=torch.long)
|
||||
else:
|
||||
labels = torch.tensor([f.label_ids for f in features], dtype=torch.float)
|
||||
batch = {"labels": labels}
|
||||
else:
|
||||
batch = {}
|
||||
dtype = torch.long if type(first["label_ids"][0]) is int else torch.float
|
||||
batch["labels"] = torch.tensor([f["label_ids"] for f in features], dtype=dtype)
|
||||
|
||||
# Handling of all other possible attributes.
|
||||
# Handling of all other possible keys.
|
||||
# Again, we will use the first element to figure out which key/values are not None for this model.
|
||||
for k, v in vars(first).items():
|
||||
for k, v in first.items():
|
||||
if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
|
||||
batch[k] = torch.tensor([getattr(f, k) for f in features], dtype=torch.long)
|
||||
if isinstance(v, torch.Tensor):
|
||||
batch[k] = torch.stack([f[k] for f in features])
|
||||
else:
|
||||
batch[k] = torch.tensor([f[k] for f in features], dtype=torch.long)
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@ import os
|
|||
import random
|
||||
import re
|
||||
import shutil
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, List, Optional, Tuple
|
||||
|
@ -205,6 +206,15 @@ class Trainer:
|
|||
# Set an xla_device flag on the model's config.
|
||||
# We'll find a more elegant and not need to do this in the future.
|
||||
self.model.config.xla_device = True
|
||||
if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
|
||||
self.data_collator = self.data_collator.collate_batch
|
||||
warnings.warn(
|
||||
(
|
||||
"The `data_collator` should now be a simple callable (function, class with `__call__`), classes "
|
||||
+ "with a `collate_batch` are deprecated and won't be supported in a future version."
|
||||
),
|
||||
FutureWarning,
|
||||
)
|
||||
|
||||
def get_train_dataloader(self) -> DataLoader:
|
||||
if self.train_dataset is None:
|
||||
|
|
|
@ -24,6 +24,27 @@ PATH_SAMPLE_TEXT = "./tests/fixtures/sample_text.txt"
|
|||
|
||||
@require_torch
|
||||
class DataCollatorIntegrationTest(unittest.TestCase):
|
||||
def test_default_with_dict(self):
|
||||
features = [{"labels": i, "inputs": [0, 1, 2, 3, 4, 5]} for i in range(8)]
|
||||
batch = default_data_collator(features)
|
||||
self.assertTrue(batch["labels"].equal(torch.tensor(list(range(8)))))
|
||||
self.assertEqual(batch["labels"].dtype, torch.long)
|
||||
self.assertEqual(batch["inputs"].shape, torch.Size([8, 6]))
|
||||
|
||||
# With label_ids
|
||||
features = [{"label_ids": [0, 1, 2], "inputs": [0, 1, 2, 3, 4, 5]} for i in range(8)]
|
||||
batch = default_data_collator(features)
|
||||
self.assertTrue(batch["labels"].equal(torch.tensor([[0, 1, 2]] * 8)))
|
||||
self.assertEqual(batch["labels"].dtype, torch.long)
|
||||
self.assertEqual(batch["inputs"].shape, torch.Size([8, 6]))
|
||||
|
||||
# Features can already be tensors
|
||||
features = [{"labels": i, "inputs": torch.randint(10, [10])} for i in range(8)]
|
||||
batch = default_data_collator(features)
|
||||
self.assertTrue(batch["labels"].equal(torch.tensor(list(range(8)))))
|
||||
self.assertEqual(batch["labels"].dtype, torch.long)
|
||||
self.assertEqual(batch["inputs"].shape, torch.Size([8, 10]))
|
||||
|
||||
def test_default_classification(self):
|
||||
MODEL_ID = "bert-base-cased-finetuned-mrpc"
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
||||
|
|
Загрузка…
Ссылка в новой задаче