Added data collator for permutation (XLNet) language modeling and related calls (#5522)

* Added data collator for XLNet language modeling and related calls

  Added DataCollatorForXLNetLanguageModeling in data/data_collator.py to generate the inputs needed for language modeling training with XLNetLMHeadModel. Also added the related arguments, logic and calls in examples/language-modeling/run_language_modeling.py.

  Resolves: #4739, #2008 (partially)

* Changed name to `DataCollatorForPermutationLanguageModeling`

  Changed the name of `DataCollatorForXLNetLanguageModeling` to the more general `DataCollatorForPermutationLanguageModeling`. Removed the `--mlm` flag requirement for the new collator and defined a separate `--plm_probability` flag for its use. CTRL uses a CLM loss just like GPT and GPT-2, so it should work out of the box with this script (provided `past` is handled the same way as `mems` is for XLNet). Changed calls and imports accordingly.

* Added detailed comments, changed variable names

  Added more detailed comments to `DataCollatorForPermutationLanguageModeling` in `data/data_collator.py` to explain how it works. Also cleaned up variable names and made them more informative.

* Added tests for new data collator

  Added tests in `tests/test_trainer.py` for DataCollatorForPermutationLanguageModeling, based on those for DataCollatorForLanguageModeling. A specific test checks that odd-length sequences are rejected.

* Fixed styling issues
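For orientation, the sketch below shows how the new collator is meant to be called on its own, outside the example script. The constructor arguments and the returned keys come from the code added in this commit; the xlnet-base-cased tokenizer (the same one the new test uses) and the toy token-id tensors are only illustrative.

import torch
from transformers import AutoTokenizer, DataCollatorForPermutationLanguageModeling

tokenizer = AutoTokenizer.from_pretrained("xlnet-base-cased")
collator = DataCollatorForPermutationLanguageModeling(
    tokenizer=tokenizer, plm_probability=1 / 6, max_span_length=5
)

# Sequence lengths must be even (the collator raises a ValueError otherwise);
# two toy examples of length 8, with ids chosen to avoid XLNet's special tokens.
examples = [torch.randint(10, 1000, (8,)), torch.randint(10, 1000, (8,))]
batch = collator(examples)

print(batch["input_ids"].shape)       # torch.Size([2, 8])
print(batch["perm_mask"].shape)       # torch.Size([2, 8, 8])
print(batch["target_mapping"].shape)  # torch.Size([2, 8, 8])
print(batch["labels"].shape)          # torch.Size([2, 8])

These four tensors are the inputs the commit message refers to as "necessary inputs for language modeling training with XLNetLMHeadModel".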
This commit is contained in:
Parent: 1d2332861f
Commit: 3dcb748e31
@@ -14,9 +14,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
-Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, BERT, RoBERTa).
-GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa are fine-tuned
-using a masked language modeling (MLM) loss.
+Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, CTRL, BERT, RoBERTa, XLNet).
+GPT, GPT-2 and CTRL are fine-tuned using a causal language modeling (CLM) loss. BERT and RoBERTa are fine-tuned
+using a masked language modeling (MLM) loss. XLNet is fine-tuned using a permutation language modeling (PLM) loss.
 """
@@ -33,6 +33,7 @@ from transformers import (
     AutoModelWithLMHead,
     AutoTokenizer,
     DataCollatorForLanguageModeling,
+    DataCollatorForPermutationLanguageModeling,
     HfArgumentParser,
     LineByLineTextDataset,
     PreTrainedTokenizer,
@@ -101,6 +102,15 @@ class DataTrainingArguments:
     mlm_probability: float = field(
         default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
     )
+    plm_probability: float = field(
+        default=1 / 6,
+        metadata={
+            "help": "Ratio of length of a span of masked tokens to surrounding context length for permutation language modeling."
+        },
+    )
+    max_span_length: int = field(
+        default=5, metadata={"help": "Maximum length of a span of masked tokens for permutation language modeling."}
+    )
 
     block_size: int = field(
         default=-1,
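A quick illustration of what the new defaults mean in practice: for every masked span, the collator added further down in `data/data_collator.py` reserves a surrounding context of `context_length = int(span_length / plm_probability)` tokens, so with the default `plm_probability = 1/6` roughly one token in six is masked no matter which span length gets sampled. The snippet below (plain Python, purely illustrative) just tabulates that relationship:

plm_probability = 1 / 6
max_span_length = 5

# For each possible span length, the collator reserves a context of
# int(span_length / plm_probability) tokens and masks span_length of them.
for span_length in range(1, max_span_length + 1):
    context_length = int(span_length / plm_probability)
    print(span_length, context_length, span_length / context_length)
# 1 6 0.1666..., 2 12 0.1666..., ..., 5 30 0.1666...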
@@ -207,8 +217,8 @@ def main():
 
     if config.model_type in ["bert", "roberta", "distilbert", "camembert"] and not data_args.mlm:
         raise ValueError(
-            "BERT and RoBERTa-like models do not have LM heads but masked LM heads. They must be run using the --mlm "
-            "flag (masked language modeling)."
+            "BERT and RoBERTa-like models do not have LM heads but masked LM heads. They must be run using the"
+            "--mlm flag (masked language modeling)."
         )
 
     if data_args.block_size <= 0:
@@ -221,9 +231,14 @@ def main():
 
     train_dataset = get_dataset(data_args, tokenizer=tokenizer) if training_args.do_train else None
     eval_dataset = get_dataset(data_args, tokenizer=tokenizer, evaluate=True) if training_args.do_eval else None
-    data_collator = DataCollatorForLanguageModeling(
-        tokenizer=tokenizer, mlm=data_args.mlm, mlm_probability=data_args.mlm_probability
-    )
+    if config.model_type == "xlnet":
+        data_collator = DataCollatorForPermutationLanguageModeling(
+            tokenizer=tokenizer, plm_probability=data_args.plm_probability, max_span_length=data_args.max_span_length,
+        )
+    else:
+        data_collator = DataCollatorForLanguageModeling(
+            tokenizer=tokenizer, mlm=data_args.mlm, mlm_probability=data_args.mlm_probability
+        )
 
     # Initialize our Trainer
     trainer = Trainer(
@@ -400,7 +400,12 @@ if is_torch_available():
 
     # Trainer
     from .trainer import Trainer, torch_distributed_zero_first
-    from .data.data_collator import default_data_collator, DataCollator, DataCollatorForLanguageModeling
+    from .data.data_collator import (
+        default_data_collator,
+        DataCollator,
+        DataCollatorForLanguageModeling,
+        DataCollatorForPermutationLanguageModeling,
+    )
     from .data.datasets import GlueDataset, TextDataset, LineByLineTextDataset, GlueDataTrainingArguments
 
     # Benchmarks
@@ -21,8 +21,8 @@ def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Ten
     Very simple data collator that:
     - simply collates batches of dict-like objects
     - Performs special handling for potential keys named:
-        - `label`: handles a single value (int or float) per object
-        - `label_ids`: handles a list of values per object
+        - ``label``: handles a single value (int or float) per object
+        - ``label_ids``: handles a list of values per object
     - does not do any additional preprocessing
 
     i.e., Property names of the input object will be used as corresponding inputs to the model.
@@ -134,3 +134,126 @@ class DataCollatorForLanguageModeling:
         # The rest of the time (10% of the time) we keep the masked input tokens unchanged
         return inputs, labels
+
+
+@dataclass
+class DataCollatorForPermutationLanguageModeling:
+    """
+    Data collator used for permutation language modeling.
+    - collates batches of tensors, honoring their tokenizer's pad_token
+    - preprocesses batches for permutation language modeling with procedures specific to XLNet
+    """
+
+    tokenizer: PreTrainedTokenizer
+    plm_probability: float = 1 / 6
+    max_span_length: int = 5  # maximum length of a span of masked tokens
+
+    def __call__(self, examples: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
+        batch = self._tensorize_batch(examples)
+        inputs, perm_mask, target_mapping, labels = self.mask_tokens(batch)
+        return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}
+
+    def _tensorize_batch(self, examples: List[torch.Tensor]) -> torch.Tensor:
+        length_of_first = examples[0].size(0)
+        are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
+        if are_tensors_same_length:
+            return torch.stack(examples, dim=0)
+        else:
+            if self.tokenizer._pad_token is None:
+                raise ValueError(
+                    "You are attempting to pad samples but the tokenizer you are using"
+                    f" ({self.tokenizer.__class__.__name__}) does not have one."
+                )
+            return pad_sequence(examples, batch_first=True, padding_value=self.tokenizer.pad_token_id)
+
+    def mask_tokens(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        The masked tokens to be predicted for a particular sequence are determined by the following algorithm:
+            0. Start from the beginning of the sequence by setting ``cur_len = 0`` (number of tokens processed so far).
+            1. Sample a ``span_length`` from the interval ``[1, max_span_length]`` (length of span of tokens to be masked)
+            2. Reserve a context of length ``context_length = span_length / plm_probability`` to surround span to be masked
+            3. Sample a starting point ``start_index`` from the interval ``[cur_len, cur_len + context_length - span_length]`` and mask tokens ``start_index:start_index + span_length``
+            4. Set ``cur_len = cur_len + context_length``. If ``cur_len < max_len`` (i.e. there are tokens remaining in the sequence to be processed), repeat from Step 1.
+        """
+
+        if self.tokenizer.mask_token is None:
+            raise ValueError(
+                "This tokenizer does not have a mask token which is necessary for permutation language modeling. Please add a mask token if you want to use this tokenizer."
+            )
+
+        if inputs.size(1) % 2 != 0:
+            raise ValueError(
+                "This collator requires that sequence lengths be even to create a leakage-free perm_mask. Please see relevant comments in source code for details."
+            )
+
+        labels = inputs.clone()
+        # Creating the mask and target_mapping tensors
+        masked_indices = torch.full(labels.shape, 0, dtype=torch.bool)
+        target_mapping = torch.zeros((labels.size(0), labels.size(1), labels.size(1)), dtype=torch.float32)
+
+        for i in range(labels.size(0)):
+            # Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
+            cur_len = 0
+            max_len = labels.size(1)
+
+            while cur_len < max_len:
+                # Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
+                span_length = torch.randint(1, self.max_span_length + 1, (1,)).item()
+                # Reserve a context of length `context_length = span_length / plm_probability` to surround the span to be masked
+                context_length = int(span_length / self.plm_probability)
+                # Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length - span_length]` and mask tokens `start_index:start_index + span_length`
+                start_index = cur_len + torch.randint(context_length - span_length + 1, (1,)).item()
+                masked_indices[i, start_index : start_index + span_length] = 1
+                # Set `cur_len = cur_len + context_length`
+                cur_len += context_length
+
+            # Since we're replacing non-masked tokens with -100 in the labels tensor instead of skipping them altogether,
+            # the i-th predict corresponds to the i-th token.
+            target_mapping[i] = torch.eye(labels.size(1))
+
+        special_tokens_mask = torch.tensor(
+            [self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()],
+            dtype=torch.bool,
+        )
+        masked_indices.masked_fill_(special_tokens_mask, value=0.0)
+        if self.tokenizer._pad_token is not None:
+            padding_mask = labels.eq(self.tokenizer.pad_token_id)
+            masked_indices.masked_fill_(padding_mask, value=0.0)
+
+        # Mask indicating non-functional tokens, where functional tokens are [SEP], [CLS], padding, etc.
+        non_func_mask = ~(padding_mask & special_tokens_mask)
+
+        inputs[masked_indices] = self.tokenizer.mask_token_id
+        labels[~masked_indices] = -100  # We only compute loss on masked tokens
+
+        perm_mask = torch.zeros((labels.size(0), labels.size(1), labels.size(1)), dtype=torch.float32)
+
+        for i in range(labels.size(0)):
+            # Generate permutation indices i.e. sample a random factorisation order for the sequence. This will
+            # determine which tokens a given token can attend to (encoded in `perm_mask`).
+            # Note: Length of token sequence being permuted has to be less than or equal to reused sequence length
+            # (see documentation for `mems`), otherwise information may leak through due to reuse. In this implementation,
+            # we assume that reused length is half of sequence length and permutation length is equal to reused length.
+            # This requires that the sequence length be even.
+
+            # Create a linear factorisation order
+            perm_index = torch.arange(labels.size(1))
+            # Split this into two halves, assuming that half the sequence is reused each time
+            perm_index = perm_index.reshape((-1, labels.size(1) // 2)).transpose(0, 1)
+            # Permute the two halves such that they do not cross over
+            perm_index = perm_index[torch.randperm(labels.size(1) // 2)]
+            # Flatten this out into the desired permuted factorisation order
+            perm_index = torch.flatten(perm_index.transpose(0, 1))
+            # Set the permutation indices of non-masked (non-functional) tokens to the
+            # smallest index (-1) so that:
+            # (1) They can be seen by all other positions
+            # (2) They cannot see masked positions, so there won't be information leak
+            perm_index.masked_fill_(~masked_indices[i] & non_func_mask[i], -1)
+            # The logic for whether the i-th token can attend on the j-th token based on the factorisation order:
+            # 0 (can attend): If perm_index[i] > perm_index[j] or j is neither masked nor a functional token
+            # 1 (cannot attend): If perm_index[i] <= perm_index[j] and j is either masked or a functional token
+            perm_mask[i] = (
+                perm_index.reshape((labels.size(1), 1)) <= perm_index.reshape((1, labels.size(1)))
+            ) & masked_indices[i]
+
+        return inputs, perm_mask, target_mapping, labels
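To make the semantics of the returned tensors concrete, here is a small, illustrative sanity check (not part of the commit). It assumes the xlnet-base-cased tokenizer used by the new test, an arbitrary even sequence length of 16, and random token ids chosen to avoid XLNet's special ids:

import torch
from transformers import AutoTokenizer, DataCollatorForPermutationLanguageModeling

tokenizer = AutoTokenizer.from_pretrained("xlnet-base-cased")
collator = DataCollatorForPermutationLanguageModeling(tokenizer=tokenizer)

seq_len = 16  # must be even, per the check in mask_tokens
batch = collator([torch.randint(10, 1000, (seq_len,))])

inputs, labels = batch["input_ids"][0], batch["labels"][0]
masked = labels.ne(-100)  # loss is only computed on masked positions

# Masked positions are replaced by the mask token in `input_ids`.
assert torch.all(inputs[masked] == tokenizer.mask_token_id)
# `target_mapping` is an identity matrix: the i-th prediction scores the i-th position.
assert torch.equal(batch["target_mapping"][0], torch.eye(seq_len))
# In `perm_mask`, a value of 1.0 at [i, j] means position i cannot attend to position j.
print(batch["perm_mask"][0, masked].sum(dim=-1))

The identity target_mapping reflects the design choice noted in the collator's comments: non-masked tokens are kept and labelled -100 rather than dropped, so predictions stay aligned with token positions.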
@@ -12,6 +12,7 @@ if is_torch_available():
         AutoModelForSequenceClassification,
         default_data_collator,
         DataCollatorForLanguageModeling,
+        DataCollatorForPermutationLanguageModeling,
         GlueDataset,
         GlueDataTrainingArguments,
         TextDataset,
@@ -123,6 +124,34 @@ class DataCollatorIntegrationTest(unittest.TestCase):
         self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
         self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))
 
+    def test_plm(self):
+        tokenizer = AutoTokenizer.from_pretrained("xlnet-base-cased")
+        data_collator = DataCollatorForPermutationLanguageModeling(tokenizer)
+        # ^ permutation lm
+
+        dataset = LineByLineTextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512)
+        examples = [dataset[i] for i in range(len(dataset))]
+        batch = data_collator(examples)
+        self.assertIsInstance(batch, dict)
+        self.assertEqual(batch["input_ids"].shape, torch.Size((31, 112)))
+        self.assertEqual(batch["perm_mask"].shape, torch.Size((31, 112, 112)))
+        self.assertEqual(batch["target_mapping"].shape, torch.Size((31, 112, 112)))
+        self.assertEqual(batch["labels"].shape, torch.Size((31, 112)))
+
+        dataset = TextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512, overwrite_cache=True)
+        examples = [dataset[i] for i in range(len(dataset))]
+        batch = data_collator(examples)
+        self.assertIsInstance(batch, dict)
+        self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
+        self.assertEqual(batch["perm_mask"].shape, torch.Size((2, 512, 512)))
+        self.assertEqual(batch["target_mapping"].shape, torch.Size((2, 512, 512)))
+        self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))
+
+        example = [torch.randint(5, [5])]
+        with self.assertRaises(ValueError):
+            # Expect error due to odd sequence length
+            data_collator(example)
+
 
 @require_torch
 class TrainerIntegrationTest(unittest.TestCase):