Seq2SeqDataset: avoid passing src_lang everywhere (#7470)
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
Parent: 08939cfdf7
Commit: c031d01023
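The diff below stops threading src_lang / tgt_lang / add_prefix_space through dedicated dataset attributes: whatever tokenizer-specific options the caller supplies are captured once as **dataset_kwargs and forwarded in a single place. A hedged sketch of the resulting call pattern (MBART_TINY and make_test_data_dir come from the test module; the language codes mirror the new test and are purely illustrative):

    # Sketch of the intended usage after this change, not library documentation.
    tokenizer = AutoTokenizer.from_pretrained(MBART_TINY)
    train_dataset = Seq2SeqDataset(
        tokenizer,
        data_dir=make_test_data_dir(),
        type_path="train",
        max_source_length=4,
        max_target_length=8,
        src_lang="EN",   # captured by **dataset_kwargs ...
        tgt_lang="FR",   # ... and forwarded to prepare_seq2seq_batch by collate_fn
    )
    assert "src_lang" in train_dataset.dataset_kwargs and "tgt_lang" in train_dataset.dataset_kwargs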
@@ -185,3 +185,36 @@ def test_distributed_sortish_sampler_splits_indices_between_procs():
     ids1 = set(DistributedSortishSampler(ds, 256, num_replicas=2, rank=0, add_extra_examples=False))
     ids2 = set(DistributedSortishSampler(ds, 256, num_replicas=2, rank=1, add_extra_examples=False))
     assert ids1.intersection(ids2) == set()
+
+
+@pytest.mark.parametrize(
+    "tok_name",
+    [
+        MBART_TINY,
+        MARIAN_TINY,
+        T5_TINY,
+        BART_TINY,
+        PEGASUS_XSUM,
+    ],
+)
+def test_dataset_kwargs(tok_name):
+    tokenizer = AutoTokenizer.from_pretrained(tok_name)
+    if tok_name == MBART_TINY:
+        train_dataset = Seq2SeqDataset(
+            tokenizer,
+            data_dir=make_test_data_dir(),
+            type_path="train",
+            max_source_length=4,
+            max_target_length=8,
+            src_lang="EN",
+            tgt_lang="FR",
+        )
+        kwargs = train_dataset.dataset_kwargs
+        assert "src_lang" in kwargs and "tgt_lang" in kwargs
+    else:
+        train_dataset = Seq2SeqDataset(
+            tokenizer, data_dir=make_test_data_dir(), type_path="train", max_source_length=4, max_target_length=8
+        )
+        kwargs = train_dataset.dataset_kwargs
+        assert "add_prefix_space" not in kwargs if tok_name != BART_TINY else "add_prefix_space" in kwargs
+        assert len(kwargs) == 1 if tok_name == BART_TINY else len(kwargs) == 0
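For readers skimming the new test: only the BART tokenizer is expected to inject an extra keyword, so the two conditional asserts in the else branch reduce to the following equivalent, more explicit restatement (illustrative, not additional test code):

    # tok_name is one of MARIAN_TINY, T5_TINY, BART_TINY, PEGASUS_XSUM here (MBART took the if branch).
    if tok_name == BART_TINY:
        assert "add_prefix_space" in kwargs and len(kwargs) == 1
    else:
        assert "add_prefix_space" not in kwargs and len(kwargs) == 0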
@@ -52,19 +52,6 @@ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
     return loss, nll_loss


-def encode_line(tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"):
-    """Only used by LegacyDataset"""
-    extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
-    return tokenizer(
-        [line],
-        max_length=max_length,
-        padding="max_length" if pad_to_max_length else None,
-        truncation=True,
-        return_tensors=return_tensors,
-        **extra_kw,
-    )
-
-
 def lmap(f: Callable, x: Iterable) -> List:
     """list(map(f, x))"""
     return list(map(f, x))
@@ -97,9 +84,8 @@ class AbstractSeq2SeqDataset(Dataset):
         max_target_length,
         type_path="train",
         n_obs=None,
-        src_lang=None,
-        tgt_lang=None,
         prefix="",
+        **dataset_kwargs
     ):
         super().__init__()
         self.src_file = Path(data_dir).joinpath(type_path + ".source")
@@ -120,9 +106,8 @@ class AbstractSeq2SeqDataset(Dataset):
         if n_obs is not None:
             self.src_lens = self.src_lens[:n_obs]
         self.pad_token_id = self.tokenizer.pad_token_id
-        self.src_lang = src_lang
-        self.tgt_lang = tgt_lang
-        self.add_prefix_space = isinstance(self.tokenizer, BartTokenizer)
+        self.dataset_kwargs = dataset_kwargs
+        dataset_kwargs.update({"add_prefix_space": True} if isinstance(self.tokenizer, BartTokenizer) else {})

     def __len__(self):
         return len(self.src_lens)
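The constructor now stores the caller's extra keyword arguments and, for BART tokenizers, folds add_prefix_space into the same dict. A minimal self-contained sketch of this capture-and-augment pattern (a simplified stand-in class, not the library code):

    # Illustrative only: mimics how **dataset_kwargs is captured and augmented in __init__.
    class KwargsHolder:
        def __init__(self, tokenizer_is_bart: bool, **dataset_kwargs):
            self.dataset_kwargs = dataset_kwargs
            # Same effect as dataset_kwargs.update({"add_prefix_space": True} if isinstance(...) else {})
            if tokenizer_is_bart:
                self.dataset_kwargs["add_prefix_space"] = True

    print(KwargsHolder(False, src_lang="EN", tgt_lang="FR").dataset_kwargs)  # {'src_lang': 'EN', 'tgt_lang': 'FR'}
    print(KwargsHolder(True).dataset_kwargs)                                 # {'add_prefix_space': True}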
@@ -182,8 +167,8 @@ class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
         tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
         assert source_line, f"empty source line for index {index}"
         assert tgt_line, f"empty tgt line for index {index}"
-        source_inputs = encode_line(self.tokenizer, source_line, self.max_source_length)
-        target_inputs = encode_line(self.tokenizer, tgt_line, self.max_target_length)
+        source_inputs = self.encode_line(self.tokenizer, source_line, self.max_source_length)
+        target_inputs = self.encode_line(self.tokenizer, tgt_line, self.max_target_length)

         source_ids = source_inputs["input_ids"].squeeze()
         target_ids = target_inputs["input_ids"].squeeze()
@@ -194,6 +179,17 @@ class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
             "labels": target_ids,
         }

+    def encode_line(self, tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"):
+        """Only used by LegacyDataset"""
+        return tokenizer(
+            [line],
+            max_length=max_length,
+            padding="max_length" if pad_to_max_length else None,
+            truncation=True,
+            return_tensors=return_tensors,
+            **self.dataset_kwargs,
+        )
+
     def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
         input_ids = torch.stack([x["input_ids"] for x in batch])
         masks = torch.stack([x["attention_mask"] for x in batch])
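encode_line is now an instance method so the stored self.dataset_kwargs can be spliced into the tokenizer call. For a BART dataset, where dataset_kwargs is {"add_prefix_space": True}, one call is roughly equivalent to the following hedged expansion (names taken from the surrounding code, values illustrative):

    # Rough expansion of self.encode_line(self.tokenizer, source_line, self.max_source_length)
    source_inputs = tokenizer(
        [source_line],
        max_length=max_source_length,
        padding="max_length",      # pad_to_max_length defaults to True
        truncation=True,
        return_tensors="pt",
        add_prefix_space=True,     # came from self.dataset_kwargs
    )
    # source_inputs["input_ids"] is a (1, max_source_length) tensor, squeezed later in __getitem__.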
@@ -224,13 +220,11 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset):
         """Call prepare_seq2seq_batch."""
         batch_encoding: Dict[str, torch.Tensor] = self.tokenizer.prepare_seq2seq_batch(
             [x["src_texts"] for x in batch],
-            src_lang=self.src_lang,
             tgt_texts=[x["tgt_texts"] for x in batch],
-            tgt_lang=self.tgt_lang,
             max_length=self.max_source_length,
             max_target_length=self.max_target_length,
             return_tensors="pt",
-            add_prefix_space=self.add_prefix_space,
+            **self.dataset_kwargs,
         ).data
         batch_encoding["ids"] = torch.tensor([x["id"] for x in batch])
         return batch_encoding
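With this last change, collate_fn no longer needs per-attribute plumbing: language codes and add_prefix_space reach prepare_seq2seq_batch through the single **self.dataset_kwargs expansion. Roughly what the call expands to for an mBART dataset built with src_lang="EN", tgt_lang="FR" (texts and lengths are placeholders, not the literal call site):

    # Hedged expansion of the collate_fn call above.
    batch_encoding = tokenizer.prepare_seq2seq_batch(
        ["a source sentence"],
        tgt_texts=["a target sentence"],
        max_length=4,
        max_target_length=8,
        return_tensors="pt",
        src_lang="EN",   # from dataset_kwargs
        tgt_lang="FR",   # from dataset_kwargs
    ).data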