зеркало из https://github.com/microsoft/DeBERTa.git
Remove unused var args from masked_language_model.py.
This commit is contained in:
Родитель
31fe03f7dc
Коммит
7ec3d8620c
|
@ -33,7 +33,7 @@ class EnhancedMaskDecoder(torch.nn.Module):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.lm_head = BertLMPredictionHead(config, vocab_size)
|
self.lm_head = BertLMPredictionHead(config, vocab_size)
|
||||||
|
|
||||||
def forward(self, ctx_layers, ebd_weight, target_ids, input_ids, input_mask, z_states, attention_mask, encoder, relative_pos=None, *wargs, **kwargs):
|
def forward(self, ctx_layers, ebd_weight, target_ids, input_ids, input_mask, z_states, attention_mask, encoder, relative_pos=None):
|
||||||
mlm_ctx_layers = self.emd_context_layer(ctx_layers, z_states, attention_mask, encoder, target_ids, input_ids, input_mask, relative_pos=relative_pos)
|
mlm_ctx_layers = self.emd_context_layer(ctx_layers, z_states, attention_mask, encoder, target_ids, input_ids, input_mask, relative_pos=relative_pos)
|
||||||
loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
|
loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
|
||||||
lm_loss = torch.tensor(0).to(ctx_layers[-1])
|
lm_loss = torch.tensor(0).to(ctx_layers[-1])
|
||||||
|
@ -92,7 +92,7 @@ class MaskedLanguageModel(NNModule):
|
||||||
self.lm_predictions = EnhancedMaskDecoder(self.deberta.config, self.deberta.embeddings.word_embeddings.weight.size(0))
|
self.lm_predictions = EnhancedMaskDecoder(self.deberta.config, self.deberta.embeddings.word_embeddings.weight.size(0))
|
||||||
self.apply(self.init_weights)
|
self.apply(self.init_weights)
|
||||||
|
|
||||||
def forward(self, input_ids, input_mask=None, labels=None, position_ids=None, attention_mask=None, **kwargs):
|
def forward(self, input_ids, input_mask=None, labels=None, position_ids=None, attention_mask=None):
|
||||||
device = list(self.parameters())[0].device
|
device = list(self.parameters())[0].device
|
||||||
input_ids = input_ids.to(device)
|
input_ids = input_ids.to(device)
|
||||||
input_mask = input_mask.to(device)
|
input_mask = input_mask.to(device)
|
||||||
|
|
Загрузка…
Ссылка в новой задаче