Make default_data_collator more flexible and deprecate old behavior (#5060)

* Make default_data_collator more flexible

* Accept tensors for all features

* Document code

* Refactor

* Formatting
Sylvain Gugger 2020-06-17 15:24:51 -04:00 committed by GitHub
Parent 5e06963394
Commit 20fa828984
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 50 additions and 16 deletions

View file

@@ -33,31 +33,34 @@ def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Tensor]:
     # have the same attributes.
     # So we will look at the first element as a proxy for what attributes exist
     # on the whole batch.
+    if not isinstance(features[0], dict):
+        features = [vars(f) for f in features]
+
     first = features[0]
+    batch = {}
 
     # Special handling for labels.
     # Ensure that tensor is created with the correct type
     # (it should be automatically the case, but let's make sure of it.)
-    if hasattr(first, "label") and first.label is not None:
-        if type(first.label) is int:
-            labels = torch.tensor([f.label for f in features], dtype=torch.long)
+    if "label" in first:
+        dtype = torch.long if type(first["label"]) is int else torch.float
+        batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
+    elif "label_ids" in first:
+        if isinstance(first["label_ids"], torch.Tensor):
+            batch["labels"] = torch.stack([f["label_ids"] for f in features])
         else:
-            labels = torch.tensor([f.label for f in features], dtype=torch.float)
-        batch = {"labels": labels}
-    elif hasattr(first, "label_ids") and first.label_ids is not None:
-        if type(first.label_ids[0]) is int:
-            labels = torch.tensor([f.label_ids for f in features], dtype=torch.long)
-        else:
-            labels = torch.tensor([f.label_ids for f in features], dtype=torch.float)
-        batch = {"labels": labels}
-    else:
-        batch = {}
+            dtype = torch.long if type(first["label_ids"][0]) is int else torch.float
+            batch["labels"] = torch.tensor([f["label_ids"] for f in features], dtype=dtype)
 
-    # Handling of all other possible attributes.
+    # Handling of all other possible keys.
     # Again, we will use the first element to figure out which key/values are not None for this model.
-    for k, v in vars(first).items():
+    for k, v in first.items():
         if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
-            batch[k] = torch.tensor([getattr(f, k) for f in features], dtype=torch.long)
+            if isinstance(v, torch.Tensor):
+                batch[k] = torch.stack([f[k] for f in features])
+            else:
+                batch[k] = torch.tensor([f[k] for f in features], dtype=torch.long)
 
     return batch
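
As a quick illustration of the new behavior, here is a minimal sketch (assuming a transformers build that includes this commit and exports default_data_collator at the top level): features can now be plain dicts, and values that are already tensors are stacked instead of being passed through torch.tensor again.

import torch
from transformers import default_data_collator

# Dict features: an int label plus inputs that are already tensors.
features = [{"label": i, "input_ids": torch.tensor([0, 1, 2, 3])} for i in range(4)]
batch = default_data_collator(features)

print(batch["labels"])           # tensor([0, 1, 2, 3]); int labels give dtype torch.long
print(batch["input_ids"].shape)  # torch.Size([4, 4]); tensor values are stacked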

View file

@@ -4,6 +4,7 @@ import os
 import random
 import re
 import shutil
+import warnings
 from contextlib import contextmanager
 from pathlib import Path
 from typing import Callable, Dict, List, Optional, Tuple

@@ -205,6 +206,15 @@ class Trainer:
             # Set an xla_device flag on the model's config.
             # We'll find a more elegant way and not need to do this in the future.
             self.model.config.xla_device = True
+        if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
+            self.data_collator = self.data_collator.collate_batch
+            warnings.warn(
+                (
+                    "The `data_collator` should now be a simple callable (function, class with `__call__`), classes "
+                    + "with a `collate_batch` are deprecated and won't be supported in a future version."
+                ),
+                FutureWarning,
+            )
 
     def get_train_dataloader(self) -> DataLoader:
         if self.train_dataset is None:
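
For downstream code, this shim means an object that only exposes collate_batch keeps working for now: Trainer rebinds data_collator to that bound method and emits a FutureWarning. The forward-compatible fix is simply to make the collator callable. A sketch with hypothetical user-defined classes (not part of the library):

import torch

class LegacyCollator:
    # Deprecated style: Trainer will rebind self.data_collator to this
    # bound method and emit a FutureWarning.
    def collate_batch(self, features):
        return {"input_ids": torch.stack([f["input_ids"] for f in features])}

class UpdatedCollator:
    # New style: the collator object is itself a plain callable.
    def __call__(self, features):
        return {"input_ids": torch.stack([f["input_ids"] for f in features])}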

View file

@@ -24,6 +24,27 @@ PATH_SAMPLE_TEXT = "./tests/fixtures/sample_text.txt"
 @require_torch
 class DataCollatorIntegrationTest(unittest.TestCase):
+    def test_default_with_dict(self):
+        features = [{"labels": i, "inputs": [0, 1, 2, 3, 4, 5]} for i in range(8)]
+        batch = default_data_collator(features)
+        self.assertTrue(batch["labels"].equal(torch.tensor(list(range(8)))))
+        self.assertEqual(batch["labels"].dtype, torch.long)
+        self.assertEqual(batch["inputs"].shape, torch.Size([8, 6]))
+
+        # With label_ids
+        features = [{"label_ids": [0, 1, 2], "inputs": [0, 1, 2, 3, 4, 5]} for i in range(8)]
+        batch = default_data_collator(features)
+        self.assertTrue(batch["labels"].equal(torch.tensor([[0, 1, 2]] * 8)))
+        self.assertEqual(batch["labels"].dtype, torch.long)
+        self.assertEqual(batch["inputs"].shape, torch.Size([8, 6]))
+
+        # Features can already be tensors
+        features = [{"labels": i, "inputs": torch.randint(10, [10])} for i in range(8)]
+        batch = default_data_collator(features)
+        self.assertTrue(batch["labels"].equal(torch.tensor(list(range(8)))))
+        self.assertEqual(batch["labels"].dtype, torch.long)
+        self.assertEqual(batch["inputs"].shape, torch.Size([8, 10]))
+
     def test_default_classification(self):
         MODEL_ID = "bert-base-cased-finetuned-mrpc"
         tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
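
A small companion example (not part of the commit's tests) covering the branch the new tests leave implicit: labels that are not ints are collated as floats by the same dtype inference.

import torch
from transformers import default_data_collator

features = [{"label": 0.5 * i, "inputs": [0, 1, 2]} for i in range(4)]
batch = default_data_collator(features)
print(batch["labels"].dtype)  # torch.float32 (float labels select torch.float)
print(batch["inputs"].dtype)  # torch.int64 (non-label values default to torch.long)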