Model output test (#6155)
* Use return_dict=True in all tests * Formatting
This commit is contained in:
Родитель
86caab1e0b
Коммит
d951c14ae4
|
@ -273,6 +273,7 @@ class EncoderDecoderModel(PreTrainedModel):
|
|||
attention_mask=attention_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
head_mask=head_mask,
|
||||
return_dict=False,
|
||||
**kwargs_encoder,
|
||||
)
|
||||
|
||||
|
@ -287,6 +288,7 @@ class EncoderDecoderModel(PreTrainedModel):
|
|||
encoder_attention_mask=attention_mask,
|
||||
head_mask=decoder_head_mask,
|
||||
labels=labels,
|
||||
return_dict=False,
|
||||
**kwargs_decoder,
|
||||
)
|
||||
|
||||
|
|
|
@ -688,16 +688,15 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
|
|||
lm_logits = self.lm_head(hidden_states)
|
||||
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)
|
||||
|
||||
lm_loss = None
|
||||
lm_loss, mc_loss = None, None
|
||||
if mc_labels is not None:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
lm_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
|
||||
mc_loss = None
|
||||
mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
|
||||
if labels is not None:
|
||||
shift_logits = lm_logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
loss_fct = CrossEntropyLoss()
|
||||
mc_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
||||
lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
output = (lm_logits, mc_logits) + transformer_outputs[1:]
|
||||
|
|
|
@ -2386,6 +2386,7 @@ class ReformerForSequenceClassification(ReformerPreTrainedModel):
|
|||
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
|
||||
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = self.reformer(
|
||||
input_ids,
|
||||
|
|
|
@ -121,6 +121,7 @@ class XxxModelTester:
|
|||
max_position_embeddings=self.max_position_embeddings,
|
||||
type_vocab_size=self.type_vocab_size,
|
||||
initializer_range=self.initializer_range,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
@ -134,18 +135,13 @@ class XxxModelTester:
|
|||
model = XxxModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
|
||||
sequence_output, pooled_output = model(input_ids, token_type_ids=token_type_ids)
|
||||
sequence_output, pooled_output = model(input_ids)
|
||||
|
||||
result = {
|
||||
"sequence_output": sequence_output,
|
||||
"pooled_output": pooled_output,
|
||||
}
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
|
||||
result = model(input_ids, token_type_ids=token_type_ids)
|
||||
result = model(input_ids)
|
||||
self.parent.assertListEqual(
|
||||
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
)
|
||||
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
|
||||
self.parent.assertListEqual(list(result["pooler_output"].size()), [self.batch_size, self.hidden_size])
|
||||
|
||||
def create_and_check_xxx_for_masked_lm(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
@ -153,16 +149,10 @@ class XxxModelTester:
|
|||
model = XxxForMaskedLM(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, prediction_scores = model(
|
||||
result = model(
|
||||
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, masked_lm_labels=token_labels
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"prediction_scores": prediction_scores,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
)
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size])
|
||||
self.check_loss_output(result)
|
||||
|
||||
def create_and_check_xxx_for_question_answering(
|
||||
|
@ -171,18 +161,13 @@ class XxxModelTester:
|
|||
model = XxxForQuestionAnswering(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, start_logits, end_logits = model(
|
||||
result = model(
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
start_positions=sequence_labels,
|
||||
end_positions=sequence_labels,
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"start_logits": start_logits,
|
||||
"end_logits": end_logits,
|
||||
}
|
||||
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
|
||||
self.check_loss_output(result)
|
||||
|
@ -194,13 +179,7 @@ class XxxModelTester:
|
|||
model = XxxForSequenceClassification(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, logits = model(
|
||||
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
}
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
|
||||
self.check_loss_output(result)
|
||||
|
||||
|
@ -211,11 +190,7 @@ class XxxModelTester:
|
|||
model = XxxForTokenClassification(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
}
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels])
|
||||
self.check_loss_output(result)
|
||||
|
||||
|
|
|
@ -98,6 +98,7 @@ class AlbertModelTester:
|
|||
type_vocab_size=self.type_vocab_size,
|
||||
initializer_range=self.initializer_range,
|
||||
num_hidden_groups=self.num_hidden_groups,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
@ -111,18 +112,13 @@ class AlbertModelTester:
|
|||
model = AlbertModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
|
||||
sequence_output, pooled_output = model(input_ids, token_type_ids=token_type_ids)
|
||||
sequence_output, pooled_output = model(input_ids)
|
||||
|
||||
result = {
|
||||
"sequence_output": sequence_output,
|
||||
"pooled_output": pooled_output,
|
||||
}
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
|
||||
result = model(input_ids, token_type_ids=token_type_ids)
|
||||
result = model(input_ids)
|
||||
self.parent.assertListEqual(
|
||||
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
)
|
||||
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
|
||||
self.parent.assertListEqual(list(result["pooler_output"].size()), [self.batch_size, self.hidden_size])
|
||||
|
||||
def create_and_check_albert_for_pretraining(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
@ -130,22 +126,17 @@ class AlbertModelTester:
|
|||
model = AlbertForPreTraining(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, prediction_scores, sop_scores = model(
|
||||
result = model(
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
labels=token_labels,
|
||||
sentence_order_label=sequence_labels,
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"prediction_scores": prediction_scores,
|
||||
"sop_scores": sop_scores,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
list(result["prediction_logits"].size()), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
)
|
||||
self.parent.assertListEqual(list(result["sop_scores"].size()), [self.batch_size, config.num_labels])
|
||||
self.parent.assertListEqual(list(result["sop_logits"].size()), [self.batch_size, config.num_labels])
|
||||
self.check_loss_output(result)
|
||||
|
||||
def create_and_check_albert_for_masked_lm(
|
||||
|
@ -154,16 +145,8 @@ class AlbertModelTester:
|
|||
model = AlbertForMaskedLM(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, prediction_scores = model(
|
||||
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"prediction_scores": prediction_scores,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
)
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size])
|
||||
self.check_loss_output(result)
|
||||
|
||||
def create_and_check_albert_for_question_answering(
|
||||
|
@ -172,18 +155,13 @@ class AlbertModelTester:
|
|||
model = AlbertForQuestionAnswering(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, start_logits, end_logits = model(
|
||||
result = model(
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
start_positions=sequence_labels,
|
||||
end_positions=sequence_labels,
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"start_logits": start_logits,
|
||||
"end_logits": end_logits,
|
||||
}
|
||||
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
|
||||
self.check_loss_output(result)
|
||||
|
@ -195,13 +173,7 @@ class AlbertModelTester:
|
|||
model = AlbertForSequenceClassification(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, logits = model(
|
||||
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
}
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
|
||||
self.check_loss_output(result)
|
||||
|
||||
|
@ -212,11 +184,7 @@ class AlbertModelTester:
|
|||
model = AlbertForTokenClassification(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
}
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels])
|
||||
self.check_loss_output(result)
|
||||
|
||||
|
@ -230,16 +198,12 @@ class AlbertModelTester:
|
|||
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
loss, logits = model(
|
||||
result = model(
|
||||
multiple_choice_inputs_ids,
|
||||
attention_mask=multiple_choice_input_mask,
|
||||
token_type_ids=multiple_choice_token_type_ids,
|
||||
labels=choice_labels,
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
}
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices])
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
|
|
|
@ -238,6 +238,7 @@ class BartHeadTests(unittest.TestCase):
|
|||
eos_token_id=2,
|
||||
pad_token_id=1,
|
||||
bos_token_id=0,
|
||||
return_dict=True,
|
||||
)
|
||||
return config, input_ids, batch_size
|
||||
|
||||
|
@ -247,24 +248,20 @@ class BartHeadTests(unittest.TestCase):
|
|||
model = BartForSequenceClassification(config)
|
||||
model.to(torch_device)
|
||||
outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, labels=labels)
|
||||
logits = outputs[1]
|
||||
expected_shape = torch.Size((batch_size, config.num_labels))
|
||||
self.assertEqual(logits.shape, expected_shape)
|
||||
loss = outputs[0]
|
||||
self.assertIsInstance(loss.item(), float)
|
||||
self.assertEqual(outputs["logits"].shape, expected_shape)
|
||||
self.assertIsInstance(outputs["loss"].item(), float)
|
||||
|
||||
def test_question_answering_forward(self):
|
||||
config, input_ids, batch_size = self._get_config_and_data()
|
||||
sequence_labels = ids_tensor([batch_size], 2).to(torch_device)
|
||||
model = BartForQuestionAnswering(config)
|
||||
model.to(torch_device)
|
||||
loss, start_logits, end_logits, _ = model(
|
||||
input_ids=input_ids, start_positions=sequence_labels, end_positions=sequence_labels,
|
||||
)
|
||||
outputs = model(input_ids=input_ids, start_positions=sequence_labels, end_positions=sequence_labels,)
|
||||
|
||||
self.assertEqual(start_logits.shape, input_ids.shape)
|
||||
self.assertEqual(end_logits.shape, input_ids.shape)
|
||||
self.assertIsInstance(loss.item(), float)
|
||||
self.assertEqual(outputs["start_logits"].shape, input_ids.shape)
|
||||
self.assertEqual(outputs["end_logits"].shape, input_ids.shape)
|
||||
self.assertIsInstance(outputs["loss"].item(), float)
|
||||
|
||||
@timeout_decorator.timeout(1)
|
||||
def test_lm_forward(self):
|
||||
|
@ -272,10 +269,10 @@ class BartHeadTests(unittest.TestCase):
|
|||
lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size).to(torch_device)
|
||||
lm_model = BartForConditionalGeneration(config)
|
||||
lm_model.to(torch_device)
|
||||
loss, logits, enc_features = lm_model(input_ids=input_ids, labels=lm_labels)
|
||||
outputs = lm_model(input_ids=input_ids, labels=lm_labels)
|
||||
expected_shape = (batch_size, input_ids.shape[1], config.vocab_size)
|
||||
self.assertEqual(logits.shape, expected_shape)
|
||||
self.assertIsInstance(loss.item(), float)
|
||||
self.assertEqual(outputs["logits"].shape, expected_shape)
|
||||
self.assertIsInstance(outputs["loss"].item(), float)
|
||||
|
||||
def test_lm_uneven_forward(self):
|
||||
config = BartConfig(
|
||||
|
@ -288,13 +285,14 @@ class BartHeadTests(unittest.TestCase):
|
|||
encoder_ffn_dim=8,
|
||||
decoder_ffn_dim=8,
|
||||
max_position_embeddings=48,
|
||||
return_dict=True,
|
||||
)
|
||||
lm_model = BartForConditionalGeneration(config).to(torch_device)
|
||||
context = torch.Tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]]).long().to(torch_device)
|
||||
summary = torch.Tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]).long().to(torch_device)
|
||||
loss, logits, enc_features = lm_model(input_ids=context, decoder_input_ids=summary, labels=summary)
|
||||
outputs = lm_model(input_ids=context, decoder_input_ids=summary, labels=summary)
|
||||
expected_shape = (*summary.shape, config.vocab_size)
|
||||
self.assertEqual(logits.shape, expected_shape)
|
||||
self.assertEqual(outputs["logits"].shape, expected_shape)
|
||||
|
||||
def test_generate_beam_search(self):
|
||||
input_ids = torch.Tensor([[71, 82, 2], [68, 34, 2]]).long().to(torch_device)
|
||||
|
|
|
@ -120,6 +120,7 @@ class BertModelTester:
|
|||
type_vocab_size=self.type_vocab_size,
|
||||
is_decoder=False,
|
||||
initializer_range=self.initializer_range,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
@ -160,18 +161,13 @@ class BertModelTester:
|
|||
model = BertModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
|
||||
sequence_output, pooled_output = model(input_ids, token_type_ids=token_type_ids)
|
||||
sequence_output, pooled_output = model(input_ids)
|
||||
|
||||
result = {
|
||||
"sequence_output": sequence_output,
|
||||
"pooled_output": pooled_output,
|
||||
}
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
|
||||
result = model(input_ids, token_type_ids=token_type_ids)
|
||||
result = model(input_ids)
|
||||
self.parent.assertListEqual(
|
||||
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
)
|
||||
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
|
||||
self.parent.assertListEqual(list(result["pooler_output"].size()), [self.batch_size, self.hidden_size])
|
||||
|
||||
def create_and_check_bert_model_as_decoder(
|
||||
self,
|
||||
|
@ -188,29 +184,24 @@ class BertModelTester:
|
|||
model = BertModel(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
sequence_output, pooled_output = model(
|
||||
result = model(
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
)
|
||||
sequence_output, pooled_output = model(
|
||||
result = model(
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
|
||||
|
||||
result = {
|
||||
"sequence_output": sequence_output,
|
||||
"pooled_output": pooled_output,
|
||||
}
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
|
||||
self.parent.assertListEqual(
|
||||
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
)
|
||||
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
|
||||
self.parent.assertListEqual(list(result["pooler_output"].size()), [self.batch_size, self.hidden_size])
|
||||
|
||||
def create_and_check_bert_for_causal_lm(
|
||||
self,
|
||||
|
@ -227,16 +218,8 @@ class BertModelTester:
|
|||
model = BertLMHeadModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, prediction_scores = model(
|
||||
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"prediction_scores": prediction_scores,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
)
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size])
|
||||
self.check_loss_output(result)
|
||||
|
||||
def create_and_check_bert_for_masked_lm(
|
||||
|
@ -245,16 +228,8 @@ class BertModelTester:
|
|||
model = BertForMaskedLM(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, prediction_scores = model(
|
||||
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"prediction_scores": prediction_scores,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
)
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size])
|
||||
self.check_loss_output(result)
|
||||
|
||||
def create_and_check_bert_model_for_causal_lm_as_decoder(
|
||||
|
@ -272,7 +247,7 @@ class BertModelTester:
|
|||
model = BertLMHeadModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, prediction_scores = model(
|
||||
result = model(
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
|
@ -280,20 +255,14 @@ class BertModelTester:
|
|||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
)
|
||||
loss, prediction_scores = model(
|
||||
result = model(
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
labels=token_labels,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"prediction_scores": prediction_scores,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
)
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size])
|
||||
self.check_loss_output(result)
|
||||
|
||||
def create_and_check_bert_for_next_sequence_prediction(
|
||||
|
@ -302,14 +271,10 @@ class BertModelTester:
|
|||
model = BertForNextSentencePrediction(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, seq_relationship_score = model(
|
||||
result = model(
|
||||
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, next_sentence_label=sequence_labels,
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"seq_relationship_score": seq_relationship_score,
|
||||
}
|
||||
self.parent.assertListEqual(list(result["seq_relationship_score"].size()), [self.batch_size, 2])
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, 2])
|
||||
self.check_loss_output(result)
|
||||
|
||||
def create_and_check_bert_for_pretraining(
|
||||
|
@ -318,22 +283,17 @@ class BertModelTester:
|
|||
model = BertForPreTraining(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, prediction_scores, seq_relationship_score = model(
|
||||
result = model(
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
labels=token_labels,
|
||||
next_sentence_label=sequence_labels,
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"prediction_scores": prediction_scores,
|
||||
"seq_relationship_score": seq_relationship_score,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
list(result["prediction_logits"].size()), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
)
|
||||
self.parent.assertListEqual(list(result["seq_relationship_score"].size()), [self.batch_size, 2])
|
||||
self.parent.assertListEqual(list(result["seq_relationship_logits"].size()), [self.batch_size, 2])
|
||||
self.check_loss_output(result)
|
||||
|
||||
def create_and_check_bert_for_question_answering(
|
||||
|
@ -342,18 +302,13 @@ class BertModelTester:
|
|||
model = BertForQuestionAnswering(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, start_logits, end_logits = model(
|
||||
result = model(
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
start_positions=sequence_labels,
|
||||
end_positions=sequence_labels,
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"start_logits": start_logits,
|
||||
"end_logits": end_logits,
|
||||
}
|
||||
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
|
||||
self.check_loss_output(result)
|
||||
|
@ -365,13 +320,7 @@ class BertModelTester:
|
|||
model = BertForSequenceClassification(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, logits = model(
|
||||
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
}
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
|
||||
self.check_loss_output(result)
|
||||
|
||||
|
@ -382,11 +331,7 @@ class BertModelTester:
|
|||
model = BertForTokenClassification(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
}
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels])
|
||||
self.check_loss_output(result)
|
||||
|
||||
|
@ -400,16 +345,12 @@ class BertModelTester:
|
|||
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
loss, logits = model(
|
||||
result = model(
|
||||
multiple_choice_inputs_ids,
|
||||
attention_mask=multiple_choice_input_mask,
|
||||
token_type_ids=multiple_choice_token_type_ids,
|
||||
labels=choice_labels,
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
}
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices])
|
||||
self.check_loss_output(result)
|
||||
|
||||
|
|
|
@ -28,13 +28,13 @@ if is_torch_available():
|
|||
class CamembertModelIntegrationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_output_embeds_base_model(self):
|
||||
model = CamembertModel.from_pretrained("camembert-base")
|
||||
model = CamembertModel.from_pretrained("camembert-base", return_dict=True)
|
||||
model.to(torch_device)
|
||||
|
||||
input_ids = torch.tensor(
|
||||
[[5, 121, 11, 660, 16, 730, 25543, 110, 83, 6]], device=torch_device, dtype=torch.long,
|
||||
) # J'aime le camembert !
|
||||
output = model(input_ids)[0]
|
||||
output = model(input_ids)["last_hidden_state"]
|
||||
expected_shape = torch.Size((1, 10, 768))
|
||||
self.assertEqual(output.shape, expected_shape)
|
||||
# compare the actual values for a slice.
|
||||
|
|
|
@ -74,7 +74,6 @@ class ModelTesterMixin:
|
|||
|
||||
def test_save_load(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.return_dict = True
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
|
|
|
@ -88,9 +88,10 @@ class CTRLModelTester:
|
|||
# hidden_dropout_prob=self.hidden_dropout_prob,
|
||||
# attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||
n_positions=self.max_position_embeddings,
|
||||
n_ctx=self.max_position_embeddings
|
||||
n_ctx=self.max_position_embeddings,
|
||||
# type_vocab_size=self.type_vocab_size,
|
||||
# initializer_range=self.initializer_range
|
||||
# initializer_range=self.initializer_range,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
|
||||
|
@ -117,29 +118,20 @@ class CTRLModelTester:
|
|||
|
||||
model(input_ids, token_type_ids=token_type_ids, head_mask=head_mask)
|
||||
model(input_ids, token_type_ids=token_type_ids)
|
||||
sequence_output, presents = model(input_ids)
|
||||
|
||||
result = {
|
||||
"sequence_output": sequence_output,
|
||||
"presents": presents,
|
||||
}
|
||||
result = model(input_ids)
|
||||
self.parent.assertListEqual(
|
||||
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
)
|
||||
self.parent.assertEqual(len(result["presents"]), config.n_layer)
|
||||
self.parent.assertEqual(len(result["past_key_values"]), config.n_layer)
|
||||
|
||||
def create_and_check_lm_head_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
|
||||
model = CTRLLMHeadModel(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
loss, lm_logits, _ = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
|
||||
|
||||
result = {"loss": loss, "lm_logits": lm_logits}
|
||||
result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
|
||||
self.parent.assertListEqual(list(result["loss"].size()), [])
|
||||
self.parent.assertListEqual(
|
||||
list(result["lm_logits"].size()), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
)
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size])
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
|
|
|
@ -110,6 +110,7 @@ if is_torch_available():
|
|||
attention_dropout=self.attention_probs_dropout_prob,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
initializer_range=self.initializer_range,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
return config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
@ -123,14 +124,10 @@ if is_torch_available():
|
|||
model = DistilBertModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
(sequence_output,) = model(input_ids, input_mask)
|
||||
(sequence_output,) = model(input_ids)
|
||||
|
||||
result = {
|
||||
"sequence_output": sequence_output,
|
||||
}
|
||||
result = model(input_ids, input_mask)
|
||||
result = model(input_ids)
|
||||
self.parent.assertListEqual(
|
||||
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
)
|
||||
|
||||
def create_and_check_distilbert_for_masked_lm(
|
||||
|
@ -139,13 +136,9 @@ if is_torch_available():
|
|||
model = DistilBertForMaskedLM(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, prediction_scores = model(input_ids, attention_mask=input_mask, labels=token_labels)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"prediction_scores": prediction_scores,
|
||||
}
|
||||
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
|
||||
self.parent.assertListEqual(
|
||||
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
)
|
||||
self.check_loss_output(result)
|
||||
|
||||
|
@ -155,14 +148,9 @@ if is_torch_available():
|
|||
model = DistilBertForQuestionAnswering(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, start_logits, end_logits = model(
|
||||
result = model(
|
||||
input_ids, attention_mask=input_mask, start_positions=sequence_labels, end_positions=sequence_labels
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"start_logits": start_logits,
|
||||
"end_logits": end_logits,
|
||||
}
|
||||
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
|
||||
self.check_loss_output(result)
|
||||
|
@ -174,11 +162,7 @@ if is_torch_available():
|
|||
model = DistilBertForSequenceClassification(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, logits = model(input_ids, attention_mask=input_mask, labels=sequence_labels)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
}
|
||||
result = model(input_ids, attention_mask=input_mask, labels=sequence_labels)
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
|
||||
self.check_loss_output(result)
|
||||
|
||||
|
@ -190,11 +174,7 @@ if is_torch_available():
|
|||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
loss, logits = model(input_ids, attention_mask=input_mask, labels=token_labels)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
}
|
||||
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
|
||||
self.parent.assertListEqual(
|
||||
list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels]
|
||||
)
|
||||
|
@ -209,13 +189,9 @@ if is_torch_available():
|
|||
model.eval()
|
||||
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
loss, logits = model(
|
||||
result = model(
|
||||
multiple_choice_inputs_ids, attention_mask=multiple_choice_input_mask, labels=choice_labels,
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
}
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices])
|
||||
self.check_loss_output(result)
|
||||
|
||||
|
|
|
@ -115,6 +115,7 @@ class DPRModelTester:
|
|||
type_vocab_size=self.type_vocab_size,
|
||||
is_decoder=False,
|
||||
initializer_range=self.initializer_range,
|
||||
return_dict=True,
|
||||
)
|
||||
config = DPRConfig(projection_dim=self.projection_dim, **config.to_dict())
|
||||
|
||||
|
@ -126,15 +127,11 @@ class DPRModelTester:
|
|||
model = DPRContextEncoder(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
embeddings = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)[0]
|
||||
embeddings = model(input_ids, token_type_ids=token_type_ids)[0]
|
||||
embeddings = model(input_ids)[0]
|
||||
|
||||
result = {
|
||||
"embeddings": embeddings,
|
||||
}
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
|
||||
result = model(input_ids, token_type_ids=token_type_ids)
|
||||
result = model(input_ids)
|
||||
self.parent.assertListEqual(
|
||||
list(result["embeddings"].size()), [self.batch_size, self.projection_dim or self.hidden_size]
|
||||
list(result["pooler_output"].size()), [self.batch_size, self.projection_dim or self.hidden_size]
|
||||
)
|
||||
|
||||
def create_and_check_dpr_question_encoder(
|
||||
|
@ -143,15 +140,11 @@ class DPRModelTester:
|
|||
model = DPRQuestionEncoder(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
embeddings = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)[0]
|
||||
embeddings = model(input_ids, token_type_ids=token_type_ids)[0]
|
||||
embeddings = model(input_ids)[0]
|
||||
|
||||
result = {
|
||||
"embeddings": embeddings,
|
||||
}
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
|
||||
result = model(input_ids, token_type_ids=token_type_ids)
|
||||
result = model(input_ids)
|
||||
self.parent.assertListEqual(
|
||||
list(result["embeddings"].size()), [self.batch_size, self.projection_dim or self.hidden_size]
|
||||
list(result["pooler_output"].size()), [self.batch_size, self.projection_dim or self.hidden_size]
|
||||
)
|
||||
|
||||
def create_and_check_dpr_reader(
|
||||
|
@ -160,12 +153,7 @@ class DPRModelTester:
|
|||
model = DPRReader(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
start_logits, end_logits, relevance_logits, *_ = model(input_ids, attention_mask=input_mask,)
|
||||
result = {
|
||||
"relevance_logits": relevance_logits,
|
||||
"start_logits": start_logits,
|
||||
"end_logits": end_logits,
|
||||
}
|
||||
result = model(input_ids, attention_mask=input_mask,)
|
||||
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(list(result["relevance_logits"].size()), [self.batch_size])
|
||||
|
|
|
@ -97,6 +97,7 @@ class ElectraModelTester:
|
|||
type_vocab_size=self.type_vocab_size,
|
||||
is_decoder=False,
|
||||
initializer_range=self.initializer_range,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
return (
|
||||
|
@ -127,15 +128,11 @@ class ElectraModelTester:
|
|||
model = ElectraModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
(sequence_output,) = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
|
||||
(sequence_output,) = model(input_ids, token_type_ids=token_type_ids)
|
||||
(sequence_output,) = model(input_ids)
|
||||
|
||||
result = {
|
||||
"sequence_output": sequence_output,
|
||||
}
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
|
||||
result = model(input_ids, token_type_ids=token_type_ids)
|
||||
result = model(input_ids)
|
||||
self.parent.assertListEqual(
|
||||
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
)
|
||||
|
||||
def create_and_check_electra_for_masked_lm(
|
||||
|
@ -152,16 +149,8 @@ class ElectraModelTester:
|
|||
model = ElectraForMaskedLM(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, prediction_scores = model(
|
||||
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"prediction_scores": prediction_scores,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
)
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size])
|
||||
self.check_loss_output(result)
|
||||
|
||||
def create_and_check_electra_for_token_classification(
|
||||
|
@ -179,11 +168,7 @@ class ElectraModelTester:
|
|||
model = ElectraForTokenClassification(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
}
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels])
|
||||
self.check_loss_output(result)
|
||||
|
||||
|
@ -202,13 +187,7 @@ class ElectraModelTester:
|
|||
model = ElectraForPreTraining(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, logits = model(
|
||||
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=fake_token_labels
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
}
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=fake_token_labels)
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length])
|
||||
self.check_loss_output(result)
|
||||
|
||||
|
@ -227,13 +206,7 @@ class ElectraModelTester:
|
|||
model = ElectraForSequenceClassification(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, logits = model(
|
||||
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
}
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
|
||||
self.check_loss_output(result)
|
||||
|
||||
|
@ -251,18 +224,13 @@ class ElectraModelTester:
|
|||
model = ElectraForQuestionAnswering(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, start_logits, end_logits = model(
|
||||
result = model(
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
start_positions=sequence_labels,
|
||||
end_positions=sequence_labels,
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"start_logits": start_logits,
|
||||
"end_logits": end_logits,
|
||||
}
|
||||
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
|
||||
self.check_loss_output(result)
|
||||
|
@ -285,16 +253,12 @@ class ElectraModelTester:
|
|||
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
loss, logits = model(
|
||||
result = model(
|
||||
multiple_choice_inputs_ids,
|
||||
attention_mask=multiple_choice_input_mask,
|
||||
token_type_ids=multiple_choice_token_type_ids,
|
||||
labels=choice_labels,
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
}
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices])
|
||||
self.check_loss_output(result)
|
||||
|
||||
|
|
|
@ -110,6 +110,7 @@ class FlaubertModelTester(object):
|
|||
initializer_range=self.initializer_range,
|
||||
summary_type=self.summary_type,
|
||||
use_proj=self.use_proj,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
return (
|
||||
|
@ -142,15 +143,11 @@ class FlaubertModelTester(object):
|
|||
model = FlaubertModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
outputs = model(input_ids, lengths=input_lengths, langs=token_type_ids)
|
||||
outputs = model(input_ids, langs=token_type_ids)
|
||||
outputs = model(input_ids)
|
||||
sequence_output = outputs[0]
|
||||
result = {
|
||||
"sequence_output": sequence_output,
|
||||
}
|
||||
result = model(input_ids, lengths=input_lengths, langs=token_type_ids)
|
||||
result = model(input_ids, langs=token_type_ids)
|
||||
result = model(input_ids)
|
||||
self.parent.assertListEqual(
|
||||
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
)
|
||||
|
||||
def create_and_check_flaubert_lm_head(
|
||||
|
@ -169,13 +166,7 @@ class FlaubertModelTester(object):
|
|||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
loss, logits = model(input_ids, token_type_ids=token_type_ids, labels=token_labels)
|
||||
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
}
|
||||
|
||||
result = model(input_ids, token_type_ids=token_type_ids, labels=token_labels)
|
||||
self.parent.assertListEqual(list(result["loss"].size()), [])
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size])
|
||||
|
||||
|
@ -195,16 +186,9 @@ class FlaubertModelTester(object):
|
|||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
outputs = model(input_ids)
|
||||
result = model(input_ids)
|
||||
|
||||
outputs = model(input_ids, start_positions=sequence_labels, end_positions=sequence_labels)
|
||||
loss, start_logits, end_logits = outputs
|
||||
|
||||
result = {
|
||||
"loss": loss,
|
||||
"start_logits": start_logits,
|
||||
"end_logits": end_logits,
|
||||
}
|
||||
result = model(input_ids, start_positions=sequence_labels, end_positions=sequence_labels)
|
||||
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
|
||||
self.check_loss_output(result)
|
||||
|
@ -225,10 +209,9 @@ class FlaubertModelTester(object):
|
|||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
outputs = model(input_ids)
|
||||
start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits = outputs
|
||||
result = model(input_ids)
|
||||
|
||||
outputs = model(
|
||||
result_with_labels = model(
|
||||
input_ids,
|
||||
start_positions=sequence_labels,
|
||||
end_positions=sequence_labels,
|
||||
|
@ -237,7 +220,7 @@ class FlaubertModelTester(object):
|
|||
p_mask=input_mask,
|
||||
)
|
||||
|
||||
outputs = model(
|
||||
result_with_labels = model(
|
||||
input_ids,
|
||||
start_positions=sequence_labels,
|
||||
end_positions=sequence_labels,
|
||||
|
@ -245,22 +228,13 @@ class FlaubertModelTester(object):
|
|||
is_impossible=is_impossible_labels,
|
||||
)
|
||||
|
||||
(total_loss,) = outputs
|
||||
(total_loss,) = result_with_labels.to_tuple()
|
||||
|
||||
outputs = model(input_ids, start_positions=sequence_labels, end_positions=sequence_labels)
|
||||
result_with_labels = model(input_ids, start_positions=sequence_labels, end_positions=sequence_labels)
|
||||
|
||||
(total_loss,) = outputs
|
||||
(total_loss,) = result_with_labels.to_tuple()
|
||||
|
||||
result = {
|
||||
"loss": total_loss,
|
||||
"start_top_log_probs": start_top_log_probs,
|
||||
"start_top_index": start_top_index,
|
||||
"end_top_log_probs": end_top_log_probs,
|
||||
"end_top_index": end_top_index,
|
||||
"cls_logits": cls_logits,
|
||||
}
|
||||
|
||||
self.parent.assertListEqual(list(result["loss"].size()), [])
|
||||
self.parent.assertListEqual(list(result_with_labels["loss"].size()), [])
|
||||
self.parent.assertListEqual(
|
||||
list(result["start_top_log_probs"].size()), [self.batch_size, model.config.start_n_top]
|
||||
)
|
||||
|
@ -292,13 +266,8 @@ class FlaubertModelTester(object):
|
|||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
(logits,) = model(input_ids)
|
||||
loss, logits = model(input_ids, labels=sequence_labels)
|
||||
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
}
|
||||
result = model(input_ids)
|
||||
result = model(input_ids, labels=sequence_labels)
|
||||
|
||||
self.parent.assertListEqual(list(result["loss"].size()), [])
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.type_sequence_label_size])
|
||||
|
@ -320,11 +289,7 @@ class FlaubertModelTester(object):
|
|||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
loss, logits = model(input_ids, attention_mask=input_mask, labels=token_labels)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
}
|
||||
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels])
|
||||
self.check_loss_output(result)
|
||||
|
||||
|
@ -347,16 +312,12 @@ class FlaubertModelTester(object):
|
|||
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
loss, logits = model(
|
||||
result = model(
|
||||
multiple_choice_inputs_ids,
|
||||
attention_mask=multiple_choice_input_mask,
|
||||
token_type_ids=multiple_choice_token_type_ids,
|
||||
labels=choice_labels,
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
}
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices])
|
||||
self.check_loss_output(result)
|
||||
|
||||
|
|
|
@ -122,9 +122,10 @@ class GPT2ModelTester:
|
|||
n_positions=self.max_position_embeddings,
|
||||
n_ctx=self.max_position_embeddings,
|
||||
# type_vocab_size=self.type_vocab_size,
|
||||
# initializer_range=self.initializer_range
|
||||
# initializer_range=self.initializer_range,
|
||||
bos_token_id=self.bos_token_id,
|
||||
eos_token_id=self.eos_token_id,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
|
||||
|
@ -149,18 +150,14 @@ class GPT2ModelTester:
|
|||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
model(input_ids, token_type_ids=token_type_ids, head_mask=head_mask)
|
||||
model(input_ids, token_type_ids=token_type_ids)
|
||||
sequence_output, presents = model(input_ids)
|
||||
result = model(input_ids, token_type_ids=token_type_ids, head_mask=head_mask)
|
||||
result = model(input_ids, token_type_ids=token_type_ids)
|
||||
result = model(input_ids)
|
||||
|
||||
result = {
|
||||
"sequence_output": sequence_output,
|
||||
"presents": presents,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size],
|
||||
list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size],
|
||||
)
|
||||
self.parent.assertEqual(len(result["presents"]), config.n_layer)
|
||||
self.parent.assertEqual(len(result["past_key_values"]), config.n_layer)
|
||||
|
||||
def create_and_check_gpt2_model_past(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
|
||||
model = GPT2Model(config=config)
|
||||
|
@ -175,7 +172,7 @@ class GPT2ModelTester:
|
|||
self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
|
||||
self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
|
||||
|
||||
output, past = outputs
|
||||
output, past = outputs.to_tuple()
|
||||
|
||||
# create hypothetical next token and extent to next_input_ids
|
||||
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
|
||||
|
@ -185,8 +182,8 @@ class GPT2ModelTester:
|
|||
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||
next_token_type_ids = torch.cat([token_type_ids, next_token_types], dim=-1)
|
||||
|
||||
output_from_no_past, _ = model(next_input_ids, token_type_ids=next_token_type_ids)
|
||||
output_from_past, _ = model(next_tokens, token_type_ids=next_token_types, past=past)
|
||||
output_from_no_past = model(next_input_ids, token_type_ids=next_token_type_ids)["last_hidden_state"]
|
||||
output_from_past = model(next_tokens, token_type_ids=next_token_types, past=past)["last_hidden_state"]
|
||||
|
||||
# select random slice
|
||||
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||
|
@ -209,7 +206,7 @@ class GPT2ModelTester:
|
|||
attn_mask[:, half_seq_length:] = 0
|
||||
|
||||
# first forward pass
|
||||
output, past = model(input_ids, attention_mask=attn_mask)
|
||||
output, past = model(input_ids, attention_mask=attn_mask).to_tuple()
|
||||
|
||||
# create hypothetical next token and extent to next_input_ids
|
||||
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
|
||||
|
@ -226,8 +223,8 @@ class GPT2ModelTester:
|
|||
)
|
||||
|
||||
# get two different outputs
|
||||
output_from_no_past, _ = model(next_input_ids, attention_mask=attn_mask)
|
||||
output_from_past, _ = model(next_tokens, past=past, attention_mask=attn_mask)
|
||||
output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
|
||||
output_from_past = model(next_tokens, past=past, attention_mask=attn_mask)["last_hidden_state"]
|
||||
|
||||
# select random slice
|
||||
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||
|
@ -242,13 +239,10 @@ class GPT2ModelTester:
|
|||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
loss, lm_logits, _ = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
|
||||
|
||||
result = {"loss": loss, "lm_logits": lm_logits}
|
||||
|
||||
result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
|
||||
self.parent.assertListEqual(list(result["loss"].size()), [])
|
||||
self.parent.assertListEqual(
|
||||
list(result["lm_logits"].size()), [self.batch_size, self.seq_length, self.vocab_size],
|
||||
list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size],
|
||||
)
|
||||
|
||||
def create_and_check_double_lm_head_model(
|
||||
|
@ -270,11 +264,8 @@ class GPT2ModelTester:
|
|||
"labels": multiple_choice_inputs_ids,
|
||||
}
|
||||
|
||||
loss, lm_logits, mc_logits, _ = model(**inputs)
|
||||
|
||||
result = {"loss": loss, "lm_logits": lm_logits, "mc_logits": mc_logits}
|
||||
|
||||
self.parent.assertListEqual(list(result["loss"].size()), [])
|
||||
result = model(**inputs)
|
||||
self.parent.assertListEqual(list(result["lm_loss"].size()), [])
|
||||
self.parent.assertListEqual(
|
||||
list(result["lm_logits"].size()), [self.batch_size, self.num_choices, self.seq_length, self.vocab_size],
|
||||
)
|
||||
|
|
|
@ -108,6 +108,7 @@ class LongformerModelTester:
|
|||
type_vocab_size=self.type_vocab_size,
|
||||
initializer_range=self.initializer_range,
|
||||
attention_window=self.attention_window,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
@ -123,8 +124,8 @@ class LongformerModelTester:
|
|||
model.eval()
|
||||
|
||||
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
|
||||
output_with_mask = model(input_ids, attention_mask=attention_mask)[0]
|
||||
output_without_mask = model(input_ids)[0]
|
||||
output_with_mask = model(input_ids, attention_mask=attention_mask)["last_hidden_state"]
|
||||
output_without_mask = model(input_ids)["last_hidden_state"]
|
||||
self.parent.assertTrue(torch.allclose(output_with_mask[0, 0, :5], output_without_mask[0, 0, :5], atol=1e-4))
|
||||
|
||||
def create_and_check_longformer_model(
|
||||
|
@ -133,18 +134,13 @@ class LongformerModelTester:
|
|||
model = LongformerModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
|
||||
sequence_output, pooled_output = model(input_ids, token_type_ids=token_type_ids)
|
||||
sequence_output, pooled_output = model(input_ids)
|
||||
|
||||
result = {
|
||||
"sequence_output": sequence_output,
|
||||
"pooled_output": pooled_output,
|
||||
}
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
|
||||
result = model(input_ids, token_type_ids=token_type_ids)
|
||||
result = model(input_ids)
|
||||
self.parent.assertListEqual(
|
||||
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
)
|
||||
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
|
||||
self.parent.assertListEqual(list(result["pooler_output"].size()), [self.batch_size, self.hidden_size])
|
||||
|
||||
def create_and_check_longformer_model_with_global_attention_mask(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
@ -156,25 +152,19 @@ class LongformerModelTester:
|
|||
global_attention_mask[:, input_mask.shape[-1] // 2] = 0
|
||||
global_attention_mask = global_attention_mask.to(torch_device)
|
||||
|
||||
sequence_output, pooled_output = model(
|
||||
result = model(
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
global_attention_mask=global_attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
)
|
||||
sequence_output, pooled_output = model(
|
||||
input_ids, token_type_ids=token_type_ids, global_attention_mask=global_attention_mask
|
||||
)
|
||||
sequence_output, pooled_output = model(input_ids, global_attention_mask=global_attention_mask)
|
||||
result = model(input_ids, token_type_ids=token_type_ids, global_attention_mask=global_attention_mask)
|
||||
result = model(input_ids, global_attention_mask=global_attention_mask)
|
||||
|
||||
result = {
|
||||
"sequence_output": sequence_output,
|
||||
"pooled_output": pooled_output,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
)
|
||||
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
|
||||
self.parent.assertListEqual(list(result["pooler_output"].size()), [self.batch_size, self.hidden_size])
|
||||
|
||||
def create_and_check_longformer_for_masked_lm(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
@ -182,16 +172,8 @@ class LongformerModelTester:
|
|||
model = LongformerForMaskedLM(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, prediction_scores = model(
|
||||
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"prediction_scores": prediction_scores,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
)
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size])
|
||||
self.check_loss_output(result)
|
||||
|
||||
def create_and_check_longformer_for_question_answering(
|
||||
|
@ -200,7 +182,7 @@ class LongformerModelTester:
|
|||
model = LongformerForQuestionAnswering(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, start_logits, end_logits = model(
|
||||
result = model(
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
global_attention_mask=input_mask,
|
||||
|
@ -208,11 +190,6 @@ class LongformerModelTester:
|
|||
start_positions=sequence_labels,
|
||||
end_positions=sequence_labels,
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"start_logits": start_logits,
|
||||
"end_logits": end_logits,
|
||||
}
|
||||
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
|
||||
self.check_loss_output(result)
|
||||
|
@ -224,13 +201,7 @@ class LongformerModelTester:
|
|||
model = LongformerForSequenceClassification(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, logits = model(
|
||||
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
}
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
|
||||
self.check_loss_output(result)
|
||||
|
||||
|
@ -241,11 +212,7 @@ class LongformerModelTester:
|
|||
model = LongformerForTokenClassification(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
}
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels])
|
||||
self.check_loss_output(result)
|
||||
|
||||
|
@ -260,17 +227,13 @@ class LongformerModelTester:
|
|||
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
loss, logits = model(
|
||||
result = model(
|
||||
multiple_choice_inputs_ids,
|
||||
attention_mask=multiple_choice_input_mask,
|
||||
global_attention_mask=multiple_choice_input_mask,
|
||||
token_type_ids=multiple_choice_token_type_ids,
|
||||
labels=choice_labels,
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
}
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices])
|
||||
self.check_loss_output(result)
|
||||
|
||||
|
|
|
@ -114,13 +114,14 @@ class MBartEnroIntegrationTest(AbstractMBartIntegrationTest):
|
|||
decoder_ffn_dim=32,
|
||||
max_position_embeddings=48,
|
||||
add_final_layer_norm=True,
|
||||
return_dict=True,
|
||||
)
|
||||
lm_model = BartForConditionalGeneration(config).to(torch_device)
|
||||
context = torch.Tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]]).long().to(torch_device)
|
||||
summary = torch.Tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]).long().to(torch_device)
|
||||
loss, logits, enc_features = lm_model(input_ids=context, decoder_input_ids=summary, labels=summary)
|
||||
result = lm_model(input_ids=context, decoder_input_ids=summary, labels=summary)
|
||||
expected_shape = (*summary.shape, config.vocab_size)
|
||||
self.assertEqual(logits.shape, expected_shape)
|
||||
self.assertEqual(result["logits"].shape, expected_shape)
|
||||
|
||||
|
||||
@require_torch
|
||||
|
|
|
@ -122,6 +122,7 @@ class MobileBertModelTester:
|
|||
type_vocab_size=self.type_vocab_size,
|
||||
is_decoder=False,
|
||||
initializer_range=self.initializer_range,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
@ -162,18 +163,14 @@ class MobileBertModelTester:
|
|||
model = MobileBertModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
|
||||
sequence_output, pooled_output = model(input_ids, token_type_ids=token_type_ids)
|
||||
sequence_output, pooled_output = model(input_ids)
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
|
||||
result = model(input_ids, token_type_ids=token_type_ids)
|
||||
result = model(input_ids)
|
||||
|
||||
result = {
|
||||
"sequence_output": sequence_output,
|
||||
"pooled_output": pooled_output,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
)
|
||||
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
|
||||
self.parent.assertListEqual(list(result["pooler_output"].size()), [self.batch_size, self.hidden_size])
|
||||
|
||||
def create_and_check_mobilebert_model_as_decoder(
|
||||
self,
|
||||
|
@ -190,29 +187,25 @@ class MobileBertModelTester:
|
|||
model = MobileBertModel(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
sequence_output, pooled_output = model(
|
||||
result = model(
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
)
|
||||
sequence_output, pooled_output = model(
|
||||
result = model(
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
|
||||
|
||||
result = {
|
||||
"sequence_output": sequence_output,
|
||||
"pooled_output": pooled_output,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
)
|
||||
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
|
||||
self.parent.assertListEqual(list(result["pooler_output"].size()), [self.batch_size, self.hidden_size])
|
||||
|
||||
def create_and_check_mobilebert_for_masked_lm(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
@ -220,16 +213,8 @@ class MobileBertModelTester:
|
|||
model = MobileBertForMaskedLM(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, prediction_scores = model(
|
||||
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"prediction_scores": prediction_scores,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
)
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size])
|
||||
self.check_loss_output(result)
|
||||
|
||||
def create_and_check_mobilebert_for_next_sequence_prediction(
|
||||
|
@ -238,14 +223,10 @@ class MobileBertModelTester:
|
|||
model = MobileBertForNextSentencePrediction(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, seq_relationship_score = model(
|
||||
result = model(
|
||||
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, next_sentence_label=sequence_labels,
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"seq_relationship_score": seq_relationship_score,
|
||||
}
|
||||
self.parent.assertListEqual(list(result["seq_relationship_score"].size()), [self.batch_size, 2])
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, 2])
|
||||
self.check_loss_output(result)
|
||||
|
||||
def create_and_check_mobilebert_for_pretraining(
|
||||
|
@ -254,22 +235,17 @@ class MobileBertModelTester:
|
|||
model = MobileBertForPreTraining(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, prediction_scores, seq_relationship_score = model(
|
||||
result = model(
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
labels=token_labels,
|
||||
next_sentence_label=sequence_labels,
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"prediction_scores": prediction_scores,
|
||||
"seq_relationship_score": seq_relationship_score,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
list(result["prediction_logits"].size()), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
)
|
||||
self.parent.assertListEqual(list(result["seq_relationship_score"].size()), [self.batch_size, 2])
|
||||
self.parent.assertListEqual(list(result["seq_relationship_logits"].size()), [self.batch_size, 2])
|
||||
self.check_loss_output(result)
|
||||
|
||||
def create_and_check_mobilebert_for_question_answering(
|
||||
|
@ -278,18 +254,13 @@ class MobileBertModelTester:
|
|||
model = MobileBertForQuestionAnswering(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, start_logits, end_logits = model(
|
||||
result = model(
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
start_positions=sequence_labels,
|
||||
end_positions=sequence_labels,
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"start_logits": start_logits,
|
||||
"end_logits": end_logits,
|
||||
}
|
||||
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
|
||||
self.check_loss_output(result)
|
||||
|
@ -301,13 +272,7 @@ class MobileBertModelTester:
|
|||
model = MobileBertForSequenceClassification(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, logits = model(
|
||||
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
}
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
|
||||
self.check_loss_output(result)
|
||||
|
||||
|
@ -318,11 +283,7 @@ class MobileBertModelTester:
|
|||
model = MobileBertForTokenClassification(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
}
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels])
|
||||
self.check_loss_output(result)
|
||||
|
||||
|
@ -336,16 +297,12 @@ class MobileBertModelTester:
|
|||
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
loss, logits = model(
|
||||
result = model(
|
||||
multiple_choice_inputs_ids,
|
||||
attention_mask=multiple_choice_input_mask,
|
||||
token_type_ids=multiple_choice_token_type_ids,
|
||||
labels=choice_labels,
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
}
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices])
|
||||
self.check_loss_output(result)
|
||||
|
||||
|
|
|
@ -85,9 +85,10 @@ class OpenAIGPTModelTester:
|
|||
# hidden_dropout_prob=self.hidden_dropout_prob,
|
||||
# attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||
n_positions=self.max_position_embeddings,
|
||||
n_ctx=self.max_position_embeddings
|
||||
n_ctx=self.max_position_embeddings,
|
||||
# type_vocab_size=self.type_vocab_size,
|
||||
# initializer_range=self.initializer_range
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
|
||||
|
@ -110,13 +111,12 @@ class OpenAIGPTModelTester:
|
|||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
model(input_ids, token_type_ids=token_type_ids, head_mask=head_mask)
|
||||
model(input_ids, token_type_ids=token_type_ids)
|
||||
(sequence_output,) = model(input_ids)
|
||||
result = model(input_ids, token_type_ids=token_type_ids, head_mask=head_mask)
|
||||
result = model(input_ids, token_type_ids=token_type_ids)
|
||||
result = model(input_ids)
|
||||
|
||||
result = {"sequence_output": sequence_output}
|
||||
self.parent.assertListEqual(
|
||||
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size],
|
||||
list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size],
|
||||
)
|
||||
|
||||
def create_and_check_lm_head_model(self, config, input_ids, head_mask, token_type_ids, *args):
|
||||
|
@ -124,13 +124,10 @@ class OpenAIGPTModelTester:
|
|||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
loss, lm_logits = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
|
||||
|
||||
result = {"loss": loss, "lm_logits": lm_logits}
|
||||
|
||||
result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
|
||||
self.parent.assertListEqual(list(result["loss"].size()), [])
|
||||
self.parent.assertListEqual(
|
||||
list(result["lm_logits"].size()), [self.batch_size, self.seq_length, self.vocab_size],
|
||||
list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size],
|
||||
)
|
||||
|
||||
def create_and_check_double_lm_head_model(self, config, input_ids, head_mask, token_type_ids, *args):
|
||||
|
@ -138,11 +135,8 @@ class OpenAIGPTModelTester:
|
|||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
loss, lm_logits, mc_logits = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
|
||||
|
||||
result = {"loss": loss, "lm_logits": lm_logits}
|
||||
|
||||
self.parent.assertListEqual(list(result["loss"].size()), [])
|
||||
result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
|
||||
self.parent.assertListEqual(list(result["lm_loss"].size()), [])
|
||||
self.parent.assertListEqual(
|
||||
list(result["lm_logits"].size()), [self.batch_size, self.seq_length, self.vocab_size],
|
||||
)
|
||||
|
|
|
@ -165,6 +165,7 @@ class ReformerModelTester:
|
|||
attn_layers=self.attn_layers,
|
||||
pad_token_id=self.pad_token_id,
|
||||
hash_seed=self.hash_seed,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
return (
|
||||
|
@ -181,15 +182,12 @@ class ReformerModelTester:
|
|||
model = ReformerModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
sequence_output, _ = model(input_ids, attention_mask=input_mask)
|
||||
sequence_output, _ = model(input_ids)
|
||||
result = model(input_ids, attention_mask=input_mask)
|
||||
result = model(input_ids)
|
||||
|
||||
result = {
|
||||
"sequence_output": sequence_output,
|
||||
}
|
||||
# 2 * hidden_size because we use reversible resnet layers
|
||||
self.parent.assertListEqual(
|
||||
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, 2 * self.hidden_size],
|
||||
list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, 2 * self.hidden_size],
|
||||
)
|
||||
|
||||
def create_and_check_reformer_model_with_lm_backward(self, config, input_ids, input_mask, choice_labels):
|
||||
|
@ -198,7 +196,7 @@ class ReformerModelTester:
|
|||
model = ReformerForMaskedLM(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss = model(input_ids, attention_mask=input_mask, labels=input_ids)[0]
|
||||
loss = model(input_ids, attention_mask=input_mask, labels=input_ids)["loss"]
|
||||
loss.backward()
|
||||
|
||||
def create_and_check_reformer_with_lm(self, config, input_ids, input_mask, choice_labels):
|
||||
|
@ -207,13 +205,9 @@ class ReformerModelTester:
|
|||
model = ReformerModelWithLMHead(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, prediction_scores, _ = model(input_ids, attention_mask=input_mask, labels=input_ids)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"prediction_scores": prediction_scores,
|
||||
}
|
||||
result = model(input_ids, attention_mask=input_mask, labels=input_ids)
|
||||
self.parent.assertListEqual(
|
||||
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size],
|
||||
list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size],
|
||||
)
|
||||
self.check_loss_output(result)
|
||||
|
||||
|
@ -222,13 +216,9 @@ class ReformerModelTester:
|
|||
model = ReformerForMaskedLM(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, prediction_scores = model(input_ids, attention_mask=input_mask, labels=input_ids)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"prediction_scores": prediction_scores,
|
||||
}
|
||||
result = model(input_ids, attention_mask=input_mask, labels=input_ids)
|
||||
self.parent.assertListEqual(
|
||||
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size],
|
||||
list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size],
|
||||
)
|
||||
self.check_loss_output(result)
|
||||
|
||||
|
@ -325,7 +315,7 @@ class ReformerModelTester:
|
|||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
hidden_states_with_chunk = model(input_ids, attention_mask=input_mask)[0]
|
||||
hidden_states_with_chunk = model(input_ids, attention_mask=input_mask)["last_hidden_state"]
|
||||
self.parent.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-3))
|
||||
|
||||
def create_and_check_reformer_feed_backward_chunking(self, config, input_ids, input_mask, choice_labels):
|
||||
|
@ -408,7 +398,7 @@ class ReformerModelTester:
|
|||
model.to(torch_device)
|
||||
model.half()
|
||||
model.eval()
|
||||
output = model(input_ids, attention_mask=input_mask)[0]
|
||||
output = model(input_ids, attention_mask=input_mask)["last_input_state"]
|
||||
self.parent.assertFalse(torch.isnan(output).any().item())
|
||||
|
||||
def create_and_check_reformer_model_generate(self, config, input_ids, input_mask, choice_labels):
|
||||
|
@ -444,21 +434,16 @@ class ReformerModelTester:
|
|||
model = ReformerForMaskedLM(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
output_logits = model(input_ids, attention_mask=input_mask)[0]
|
||||
output_logits = model(input_ids, attention_mask=input_mask)["logits"]
|
||||
self.parent.assertTrue(output_logits.shape[1] == input_ids.shape[-1])
|
||||
|
||||
def create_and_check_reformer_for_question_answering(self, config, input_ids, input_mask, choice_labels):
|
||||
model = ReformerForQuestionAnswering(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, start_logits, end_logits = model(
|
||||
result = model(
|
||||
input_ids, attention_mask=input_mask, start_positions=choice_labels, end_positions=choice_labels,
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"start_logits": start_logits,
|
||||
"end_logits": end_logits,
|
||||
}
|
||||
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
|
||||
self.check_loss_output(result)
|
||||
|
@ -474,11 +459,11 @@ class ReformerModelTester:
|
|||
input_ids_second = input_ids[:, -1:]
|
||||
|
||||
# return saved cache
|
||||
_, past_buckets_states = model(input_ids_first, use_cache=True)
|
||||
past_buckets_states = model(input_ids_first, use_cache=True)["past_buckets_states"]
|
||||
|
||||
# calculate last output with and without cache
|
||||
outputs_with_cache, _ = model(input_ids_second, past_buckets_states=past_buckets_states, use_cache=True)
|
||||
outputs_without_cache = model(input_ids)[0][:, -1]
|
||||
outputs_with_cache = model(input_ids_second, past_buckets_states=past_buckets_states, use_cache=True)["logits"]
|
||||
outputs_without_cache = model(input_ids)["logits"][:, -1]
|
||||
|
||||
# select random slice idx
|
||||
random_slice_idx = torch.randint(outputs_without_cache.shape[-1], (1, 1), device=torch_device).item()
|
||||
|
@ -504,11 +489,7 @@ class ReformerModelTester:
|
|||
model = ReformerForSequenceClassification(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, logits = model(input_ids, attention_mask=input_mask, labels=sequence_labels)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
}
|
||||
result = model(input_ids, attention_mask=input_mask, labels=sequence_labels)
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
|
||||
self.check_loss_output(result)
|
||||
|
||||
|
|
|
@ -96,6 +96,7 @@ class RobertaModelTester:
|
|||
max_position_embeddings=self.max_position_embeddings,
|
||||
type_vocab_size=self.type_vocab_size,
|
||||
initializer_range=self.initializer_range,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
@ -109,18 +110,14 @@ class RobertaModelTester:
|
|||
model = RobertaModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
|
||||
sequence_output, pooled_output = model(input_ids, token_type_ids=token_type_ids)
|
||||
sequence_output, pooled_output = model(input_ids)
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
|
||||
result = model(input_ids, token_type_ids=token_type_ids)
|
||||
result = model(input_ids)
|
||||
|
||||
result = {
|
||||
"sequence_output": sequence_output,
|
||||
"pooled_output": pooled_output,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
)
|
||||
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
|
||||
self.parent.assertListEqual(list(result["pooler_output"].size()), [self.batch_size, self.hidden_size])
|
||||
|
||||
def create_and_check_roberta_for_masked_lm(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
@ -128,16 +125,8 @@ class RobertaModelTester:
|
|||
model = RobertaForMaskedLM(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, prediction_scores = model(
|
||||
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"prediction_scores": prediction_scores,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
)
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size])
|
||||
self.check_loss_output(result)
|
||||
|
||||
def create_and_check_roberta_for_token_classification(
|
||||
|
@ -147,11 +136,7 @@ class RobertaModelTester:
|
|||
model = RobertaForTokenClassification(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
}
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels])
|
||||
self.check_loss_output(result)
|
||||
|
||||
|
@ -165,16 +150,12 @@ class RobertaModelTester:
|
|||
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
loss, logits = model(
|
||||
result = model(
|
||||
multiple_choice_inputs_ids,
|
||||
attention_mask=multiple_choice_input_mask,
|
||||
token_type_ids=multiple_choice_token_type_ids,
|
||||
labels=choice_labels,
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
}
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices])
|
||||
self.check_loss_output(result)
|
||||
|
||||
|
@ -184,18 +165,13 @@ class RobertaModelTester:
|
|||
model = RobertaForQuestionAnswering(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, start_logits, end_logits = model(
|
||||
result = model(
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
start_positions=sequence_labels,
|
||||
end_positions=sequence_labels,
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"start_logits": start_logits,
|
||||
"end_logits": end_logits,
|
||||
}
|
||||
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
|
||||
self.check_loss_output(result)
|
||||
|
|
|
@ -83,6 +83,7 @@ class T5ModelTester:
|
|||
bos_token_id=self.pad_token_id,
|
||||
pad_token_id=self.pad_token_id,
|
||||
decoder_start_token_id=self.decoder_start_token_id,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
return (
|
||||
|
@ -136,13 +137,17 @@ class T5ModelTester:
|
|||
model = T5Model(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
decoder_output, decoder_past, encoder_output = model(
|
||||
result = model(
|
||||
input_ids=input_ids,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
)
|
||||
decoder_output, decoder_past, encoder_output = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
|
||||
result = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
|
||||
decoder_output = result["last_hidden_state"]
|
||||
decoder_past = result["decoder_past_key_values"]
|
||||
encoder_output = result["encoder_last_hidden_state"]
|
||||
|
||||
self.parent.assertEqual(encoder_output.size(), (self.batch_size, self.encoder_seq_length, self.hidden_size))
|
||||
self.parent.assertEqual(decoder_output.size(), (self.batch_size, self.decoder_seq_length, self.hidden_size))
|
||||
self.parent.assertEqual(len(decoder_past), 2)
|
||||
|
@ -162,10 +167,9 @@ class T5ModelTester:
|
|||
decoder_attention_mask=decoder_attention_mask,
|
||||
labels=lm_labels,
|
||||
)
|
||||
loss, prediction_scores, _, _ = outputs
|
||||
self.parent.assertEqual(len(outputs), 4)
|
||||
self.parent.assertEqual(prediction_scores.size(), (self.batch_size, self.decoder_seq_length, self.vocab_size))
|
||||
self.parent.assertEqual(loss.size(), ())
|
||||
self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, self.decoder_seq_length, self.vocab_size))
|
||||
self.parent.assertEqual(outputs["loss"].size(), ())
|
||||
|
||||
def create_and_check_t5_decoder_model_past(
|
||||
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
|
||||
|
@ -179,7 +183,7 @@ class T5ModelTester:
|
|||
self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
|
||||
self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
|
||||
|
||||
output, past_key_value_states = outputs
|
||||
output, past_key_value_states = outputs.to_tuple()
|
||||
|
||||
# create hypothetical next token and extent to next_input_ids
|
||||
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
|
||||
|
@ -187,8 +191,8 @@ class T5ModelTester:
|
|||
# append to next input_ids and
|
||||
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||
|
||||
output_from_no_past = model(next_input_ids)[0]
|
||||
output_from_past = model(next_tokens, past_key_value_states=past_key_value_states)[0]
|
||||
output_from_no_past = model(next_input_ids)["last_hidden_state"]
|
||||
output_from_past = model(next_tokens, past_key_value_states=past_key_value_states)["last_hidden_state"]
|
||||
|
||||
# select random slice
|
||||
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||
|
@ -212,7 +216,7 @@ class T5ModelTester:
|
|||
attn_mask[:, half_seq_length:] = 0
|
||||
|
||||
# first forward pass
|
||||
output, past_key_value_states = model(input_ids, attention_mask=attn_mask, use_cache=True)
|
||||
output, past_key_value_states = model(input_ids, attention_mask=attn_mask, use_cache=True).to_tuple()
|
||||
|
||||
# create hypothetical next token and extent to next_input_ids
|
||||
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
|
||||
|
@ -229,8 +233,10 @@ class T5ModelTester:
|
|||
)
|
||||
|
||||
# get two different outputs
|
||||
output_from_no_past = model(next_input_ids, attention_mask=attn_mask)[0]
|
||||
output_from_past = model(next_tokens, past_key_value_states=past_key_value_states, attention_mask=attn_mask)[0]
|
||||
output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
|
||||
output_from_past = model(next_tokens, past_key_value_states=past_key_value_states, attention_mask=attn_mask)[
|
||||
"last_hidden_state"
|
||||
]
|
||||
|
||||
# select random slice
|
||||
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||
|
@ -256,7 +262,7 @@ class T5ModelTester:
|
|||
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
|
||||
):
|
||||
model = T5Model(config=config).to(torch_device).half().eval()
|
||||
output = model(input_ids, decoder_input_ids=input_ids, attention_mask=attention_mask)[0]
|
||||
output = model(input_ids, decoder_input_ids=input_ids, attention_mask=attention_mask)["last_hidden_state"]
|
||||
self.parent.assertFalse(torch.isnan(output).any().item())
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
|
|
|
@ -75,6 +75,7 @@ class TransfoXLModelTester:
|
|||
div_val=self.div_val,
|
||||
n_layer=self.num_hidden_layers,
|
||||
eos_token_id=self.eos_token_id,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
return (config, input_ids_1, input_ids_2, lm_labels)
|
||||
|
@ -88,13 +89,13 @@ class TransfoXLModelTester:
|
|||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
hidden_states_1, mems_1 = model(input_ids_1)
|
||||
hidden_states_2, mems_2 = model(input_ids_2, mems_1)
|
||||
outputs1 = model(input_ids_1)
|
||||
outputs2 = model(input_ids_2, outputs1["mems"])
|
||||
outputs = {
|
||||
"hidden_states_1": hidden_states_1,
|
||||
"mems_1": mems_1,
|
||||
"hidden_states_2": hidden_states_2,
|
||||
"mems_2": mems_2,
|
||||
"hidden_states_1": outputs1["last_hidden_state"],
|
||||
"mems_1": outputs1["mems"],
|
||||
"hidden_states_2": outputs2["last_hidden_state"],
|
||||
"mems_2": outputs2["mems"],
|
||||
}
|
||||
return outputs
|
||||
|
||||
|
@ -119,17 +120,17 @@ class TransfoXLModelTester:
|
|||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
lm_logits_1, mems_1 = model(input_ids_1)
|
||||
loss_1, _, mems_1 = model(input_ids_1, labels=lm_labels)
|
||||
lm_logits_2, mems_2 = model(input_ids_2, mems=mems_1)
|
||||
loss_2, _, mems_2 = model(input_ids_2, labels=lm_labels, mems=mems_1)
|
||||
lm_logits_1 = model(input_ids_1)["prediction_scores"]
|
||||
outputs1 = model(input_ids_1, labels=lm_labels)
|
||||
lm_logits_2 = model(input_ids_2, mems=outputs1["mems"])["prediction_scores"]
|
||||
outputs2 = model(input_ids_2, labels=lm_labels, mems=outputs1["mems"])
|
||||
|
||||
outputs = {
|
||||
"loss_1": loss_1,
|
||||
"mems_1": mems_1,
|
||||
"loss_1": outputs1["losses"],
|
||||
"mems_1": outputs1["mems"],
|
||||
"lm_logits_1": lm_logits_1,
|
||||
"loss_2": loss_2,
|
||||
"mems_2": mems_2,
|
||||
"loss_2": outputs2["losses"],
|
||||
"mems_2": outputs2["mems"],
|
||||
"lm_logits_2": lm_logits_2,
|
||||
}
|
||||
return outputs
|
||||
|
|
|
@ -113,6 +113,7 @@ class XLMModelTester:
|
|||
use_proj=self.use_proj,
|
||||
num_labels=self.num_labels,
|
||||
bos_token_id=self.bos_token_id,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
return (
|
||||
|
@ -145,15 +146,11 @@ class XLMModelTester:
|
|||
model = XLMModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
outputs = model(input_ids, lengths=input_lengths, langs=token_type_ids)
|
||||
outputs = model(input_ids, langs=token_type_ids)
|
||||
outputs = model(input_ids)
|
||||
sequence_output = outputs[0]
|
||||
result = {
|
||||
"sequence_output": sequence_output,
|
||||
}
|
||||
result = model(input_ids, lengths=input_lengths, langs=token_type_ids)
|
||||
result = model(input_ids, langs=token_type_ids)
|
||||
result = model(input_ids)
|
||||
self.parent.assertListEqual(
|
||||
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
)
|
||||
|
||||
def create_and_check_xlm_lm_head(
|
||||
|
@ -172,13 +169,7 @@ class XLMModelTester:
|
|||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
loss, logits = model(input_ids, token_type_ids=token_type_ids, labels=token_labels)
|
||||
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
}
|
||||
|
||||
result = model(input_ids, token_type_ids=token_type_ids, labels=token_labels)
|
||||
self.parent.assertListEqual(list(result["loss"].size()), [])
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size])
|
||||
|
||||
|
@ -201,13 +192,7 @@ class XLMModelTester:
|
|||
outputs = model(input_ids)
|
||||
|
||||
outputs = model(input_ids, start_positions=sequence_labels, end_positions=sequence_labels)
|
||||
loss, start_logits, end_logits = outputs
|
||||
|
||||
result = {
|
||||
"loss": loss,
|
||||
"start_logits": start_logits,
|
||||
"end_logits": end_logits,
|
||||
}
|
||||
result = outputs
|
||||
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
|
||||
self.check_loss_output(result)
|
||||
|
@ -228,10 +213,9 @@ class XLMModelTester:
|
|||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
outputs = model(input_ids)
|
||||
start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits = outputs
|
||||
result = model(input_ids)
|
||||
|
||||
outputs = model(
|
||||
result_with_labels = model(
|
||||
input_ids,
|
||||
start_positions=sequence_labels,
|
||||
end_positions=sequence_labels,
|
||||
|
@ -240,7 +224,7 @@ class XLMModelTester:
|
|||
p_mask=input_mask,
|
||||
)
|
||||
|
||||
outputs = model(
|
||||
result_with_labels = model(
|
||||
input_ids,
|
||||
start_positions=sequence_labels,
|
||||
end_positions=sequence_labels,
|
||||
|
@ -248,22 +232,13 @@ class XLMModelTester:
|
|||
is_impossible=is_impossible_labels,
|
||||
)
|
||||
|
||||
(total_loss,) = outputs
|
||||
(total_loss,) = result_with_labels.to_tuple()
|
||||
|
||||
outputs = model(input_ids, start_positions=sequence_labels, end_positions=sequence_labels)
|
||||
result_with_labels = model(input_ids, start_positions=sequence_labels, end_positions=sequence_labels)
|
||||
|
||||
(total_loss,) = outputs
|
||||
(total_loss,) = result_with_labels.to_tuple()
|
||||
|
||||
result = {
|
||||
"loss": total_loss,
|
||||
"start_top_log_probs": start_top_log_probs,
|
||||
"start_top_index": start_top_index,
|
||||
"end_top_log_probs": end_top_log_probs,
|
||||
"end_top_index": end_top_index,
|
||||
"cls_logits": cls_logits,
|
||||
}
|
||||
|
||||
self.parent.assertListEqual(list(result["loss"].size()), [])
|
||||
self.parent.assertListEqual(list(result_with_labels["loss"].size()), [])
|
||||
self.parent.assertListEqual(
|
||||
list(result["start_top_log_probs"].size()), [self.batch_size, model.config.start_n_top]
|
||||
)
|
||||
|
@ -295,14 +270,8 @@ class XLMModelTester:
|
|||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
(logits,) = model(input_ids)
|
||||
loss, logits = model(input_ids, labels=sequence_labels)
|
||||
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
}
|
||||
|
||||
result = model(input_ids)
|
||||
result = model(input_ids, labels=sequence_labels)
|
||||
self.parent.assertListEqual(list(result["loss"].size()), [])
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.type_sequence_label_size])
|
||||
|
||||
|
@ -323,11 +292,7 @@ class XLMModelTester:
|
|||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
loss, logits = model(input_ids, attention_mask=input_mask, labels=token_labels)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
}
|
||||
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels])
|
||||
self.check_loss_output(result)
|
||||
|
||||
|
@ -350,16 +315,12 @@ class XLMModelTester:
|
|||
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
loss, logits = model(
|
||||
result = model(
|
||||
multiple_choice_inputs_ids,
|
||||
attention_mask=multiple_choice_input_mask,
|
||||
token_type_ids=multiple_choice_token_type_ids,
|
||||
labels=choice_labels,
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
}
|
||||
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices])
|
||||
self.check_loss_output(result)
|
||||
|
||||
|
|
|
@ -28,7 +28,7 @@ if is_torch_available():
|
|||
class XLMRobertaModelIntegrationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_xlm_roberta_base(self):
|
||||
model = XLMRobertaModel.from_pretrained("xlm-roberta-base")
|
||||
model = XLMRobertaModel.from_pretrained("xlm-roberta-base", return_dict=True)
|
||||
input_ids = torch.tensor([[0, 581, 10269, 83, 99942, 136, 60742, 23, 70, 80583, 18276, 2]])
|
||||
# The dog is cute and lives in the garden house
|
||||
|
||||
|
@ -40,14 +40,14 @@ class XLMRobertaModelIntegrationTest(unittest.TestCase):
|
|||
# xlmr.eval()
|
||||
# expected_output_values_last_dim = xlmr.extract_features(input_ids[0])[:, :, -1]
|
||||
|
||||
output = model(input_ids)[0].detach()
|
||||
output = model(input_ids)["last_hidden_state"].detach()
|
||||
self.assertEqual(output.shape, expected_output_shape)
|
||||
# compare the actual values for a slice of last dim
|
||||
self.assertTrue(torch.allclose(output[:, :, -1], expected_output_values_last_dim, atol=1e-3))
|
||||
|
||||
@slow
|
||||
def test_xlm_roberta_large(self):
|
||||
model = XLMRobertaModel.from_pretrained("xlm-roberta-large")
|
||||
model = XLMRobertaModel.from_pretrained("xlm-roberta-large", return_dict=True)
|
||||
input_ids = torch.tensor([[0, 581, 10269, 83, 99942, 136, 60742, 23, 70, 80583, 18276, 2]])
|
||||
# The dog is cute and lives in the garden house
|
||||
|
||||
|
@ -59,7 +59,7 @@ class XLMRobertaModelIntegrationTest(unittest.TestCase):
|
|||
# xlmr.eval()
|
||||
# expected_output_values_last_dim = xlmr.extract_features(input_ids[0])[:, :, -1]
|
||||
|
||||
output = model(input_ids)[0].detach()
|
||||
output = model(input_ids)["last_hidden_state"].detach()
|
||||
self.assertEqual(output.shape, expected_output_shape)
|
||||
# compare the actual values for a slice of last dim
|
||||
self.assertTrue(torch.allclose(output[:, :, -1], expected_output_values_last_dim, atol=1e-3))
|
||||
|
|
|
@ -137,6 +137,7 @@ class XLNetModelTester:
|
|||
bos_token_id=self.bos_token_id,
|
||||
pad_token_id=self.pad_token_id,
|
||||
eos_token_id=self.eos_token_id,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
return (
|
||||
|
@ -177,15 +178,10 @@ class XLNetModelTester:
|
|||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
_, _ = model(input_ids_1, input_mask=input_mask)
|
||||
_, _ = model(input_ids_1, attention_mask=input_mask)
|
||||
_, _ = model(input_ids_1, token_type_ids=segment_ids)
|
||||
outputs, mems_1 = model(input_ids_1)
|
||||
|
||||
result = {
|
||||
"mems_1": mems_1,
|
||||
"outputs": outputs,
|
||||
}
|
||||
result = model(input_ids_1, input_mask=input_mask)
|
||||
result = model(input_ids_1, attention_mask=input_mask)
|
||||
result = model(input_ids_1, token_type_ids=segment_ids)
|
||||
result = model(input_ids_1)
|
||||
|
||||
config.mem_len = 0
|
||||
model = XLNetModel(config)
|
||||
|
@ -195,10 +191,10 @@ class XLNetModelTester:
|
|||
self.parent.assertEqual(len(base_model_output), 2)
|
||||
|
||||
self.parent.assertListEqual(
|
||||
list(result["outputs"].size()), [self.batch_size, self.seq_length, self.hidden_size],
|
||||
list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, self.hidden_size],
|
||||
)
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems_1"]),
|
||||
list(list(mem.size()) for mem in result["mems"]),
|
||||
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
|
||||
)
|
||||
|
||||
|
@ -233,7 +229,7 @@ class XLNetModelTester:
|
|||
self.parent.assertTrue(len(outputs_cache) == len(outputs_conf))
|
||||
self.parent.assertTrue(len(outputs_cache) == len(outputs_no_cache) + 1)
|
||||
|
||||
output, mems = outputs_cache
|
||||
output, mems = outputs_cache.to_tuple()
|
||||
|
||||
# create hypothetical next token and extent to next_input_ids
|
||||
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
|
||||
|
@ -253,8 +249,8 @@ class XLNetModelTester:
|
|||
single_mask = torch.ones(input_ids_1.shape[0], 1, 1, dtype=torch.float, device=torch_device)
|
||||
|
||||
# second forward pass
|
||||
output_from_no_past, _ = model(next_input_ids, perm_mask=causal_mask)
|
||||
output_from_past, _ = model(next_tokens, mems=mems, perm_mask=single_mask)
|
||||
output_from_no_past = model(next_input_ids, perm_mask=causal_mask)["last_hidden_state"]
|
||||
output_from_past = model(next_tokens, mems=mems, perm_mask=single_mask)["last_hidden_state"]
|
||||
|
||||
# select random slice
|
||||
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||
|
@ -283,7 +279,7 @@ class XLNetModelTester:
|
|||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
_, _, attentions = model(input_ids_1, target_mapping=target_mapping, output_attentions=True)
|
||||
attentions = model(input_ids_1, target_mapping=target_mapping, output_attentions=True)["attentions"]
|
||||
|
||||
self.parent.assertEqual(len(attentions), config.n_layer)
|
||||
self.parent.assertIsInstance(attentions[0], tuple)
|
||||
|
@ -309,36 +305,27 @@ class XLNetModelTester:
|
|||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
loss_1, all_logits_1, mems_1 = model(input_ids_1, token_type_ids=segment_ids, labels=lm_labels)
|
||||
result1 = model(input_ids_1, token_type_ids=segment_ids, labels=lm_labels)
|
||||
|
||||
loss_2, all_logits_2, mems_2 = model(input_ids_2, token_type_ids=segment_ids, labels=lm_labels, mems=mems_1)
|
||||
result2 = model(input_ids_2, token_type_ids=segment_ids, labels=lm_labels, mems=result1["mems"])
|
||||
|
||||
logits, _ = model(input_ids_q, perm_mask=perm_mask, target_mapping=target_mapping)
|
||||
_ = model(input_ids_q, perm_mask=perm_mask, target_mapping=target_mapping)
|
||||
|
||||
result = {
|
||||
"loss_1": loss_1,
|
||||
"mems_1": mems_1,
|
||||
"all_logits_1": all_logits_1,
|
||||
"loss_2": loss_2,
|
||||
"mems_2": mems_2,
|
||||
"all_logits_2": all_logits_2,
|
||||
}
|
||||
|
||||
self.parent.assertListEqual(list(result["loss_1"].size()), [])
|
||||
self.parent.assertListEqual(list(result1["loss"].size()), [])
|
||||
self.parent.assertListEqual(
|
||||
list(result["all_logits_1"].size()), [self.batch_size, self.seq_length, self.vocab_size],
|
||||
list(result1["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size],
|
||||
)
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems_1"]),
|
||||
list(list(mem.size()) for mem in result1["mems"]),
|
||||
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
|
||||
)
|
||||
|
||||
self.parent.assertListEqual(list(result["loss_2"].size()), [])
|
||||
self.parent.assertListEqual(list(result2["loss"].size()), [])
|
||||
self.parent.assertListEqual(
|
||||
list(result["all_logits_2"].size()), [self.batch_size, self.seq_length, self.vocab_size],
|
||||
list(result2["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size],
|
||||
)
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems_2"]),
|
||||
list(list(mem.size()) for mem in result2["mems"]),
|
||||
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
|
||||
)
|
||||
|
||||
|
@ -361,10 +348,9 @@ class XLNetModelTester:
|
|||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
outputs = model(input_ids_1)
|
||||
(start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems,) = outputs
|
||||
result = model(input_ids_1)
|
||||
|
||||
outputs = model(
|
||||
result_with_labels = model(
|
||||
input_ids_1,
|
||||
start_positions=sequence_labels,
|
||||
end_positions=sequence_labels,
|
||||
|
@ -373,7 +359,7 @@ class XLNetModelTester:
|
|||
p_mask=input_mask,
|
||||
)
|
||||
|
||||
outputs = model(
|
||||
result_with_labels = model(
|
||||
input_ids_1,
|
||||
start_positions=sequence_labels,
|
||||
end_positions=sequence_labels,
|
||||
|
@ -381,23 +367,13 @@ class XLNetModelTester:
|
|||
is_impossible=is_impossible_labels,
|
||||
)
|
||||
|
||||
total_loss, mems = outputs
|
||||
total_loss, mems = result_with_labels.to_tuple()
|
||||
|
||||
outputs = model(input_ids_1, start_positions=sequence_labels, end_positions=sequence_labels,)
|
||||
result_with_labels = model(input_ids_1, start_positions=sequence_labels, end_positions=sequence_labels,)
|
||||
|
||||
total_loss, mems = outputs
|
||||
total_loss, mems = result_with_labels.to_tuple()
|
||||
|
||||
result = {
|
||||
"loss": total_loss,
|
||||
"start_top_log_probs": start_top_log_probs,
|
||||
"start_top_index": start_top_index,
|
||||
"end_top_log_probs": end_top_log_probs,
|
||||
"end_top_index": end_top_index,
|
||||
"cls_logits": cls_logits,
|
||||
"mems": mems,
|
||||
}
|
||||
|
||||
self.parent.assertListEqual(list(result["loss"].size()), [])
|
||||
self.parent.assertListEqual(list(result_with_labels["loss"].size()), [])
|
||||
self.parent.assertListEqual(
|
||||
list(result["start_top_log_probs"].size()), [self.batch_size, model.config.start_n_top],
|
||||
)
|
||||
|
@ -436,21 +412,15 @@ class XLNetModelTester:
|
|||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
logits, mems_1 = model(input_ids_1)
|
||||
loss, logits, mems_1 = model(input_ids_1, labels=token_labels)
|
||||
|
||||
result = {
|
||||
"loss": loss,
|
||||
"mems_1": mems_1,
|
||||
"logits": logits,
|
||||
}
|
||||
result = model(input_ids_1)
|
||||
result = model(input_ids_1, labels=token_labels)
|
||||
|
||||
self.parent.assertListEqual(list(result["loss"].size()), [])
|
||||
self.parent.assertListEqual(
|
||||
list(result["logits"].size()), [self.batch_size, self.seq_length, self.type_sequence_label_size],
|
||||
)
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems_1"]),
|
||||
list(list(mem.size()) for mem in result["mems"]),
|
||||
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
|
||||
)
|
||||
|
||||
|
@ -473,21 +443,15 @@ class XLNetModelTester:
|
|||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
logits, mems_1 = model(input_ids_1)
|
||||
loss, logits, mems_1 = model(input_ids_1, labels=sequence_labels)
|
||||
|
||||
result = {
|
||||
"loss": loss,
|
||||
"mems_1": mems_1,
|
||||
"logits": logits,
|
||||
}
|
||||
result = model(input_ids_1)
|
||||
result = model(input_ids_1, labels=sequence_labels)
|
||||
|
||||
self.parent.assertListEqual(list(result["loss"].size()), [])
|
||||
self.parent.assertListEqual(
|
||||
list(result["logits"].size()), [self.batch_size, self.type_sequence_label_size],
|
||||
)
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems_1"]),
|
||||
list(list(mem.size()) for mem in result["mems"]),
|
||||
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
|
||||
)
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче