Remove unused var args from masked_language_model.py.

Sergii Dymchenko 2021-04-08 13:25:25 -07:00, committed by Pengcheng He
Parent 31fe03f7dc
Commit 7ec3d8620c
1 changed file with 2 additions and 2 deletions


@@ -33,7 +33,7 @@ class EnhancedMaskDecoder(torch.nn.Module):
     self.config = config
     self.lm_head = BertLMPredictionHead(config, vocab_size)
 
-  def forward(self, ctx_layers, ebd_weight, target_ids, input_ids, input_mask, z_states, attention_mask, encoder, relative_pos=None, *wargs, **kwargs):
+  def forward(self, ctx_layers, ebd_weight, target_ids, input_ids, input_mask, z_states, attention_mask, encoder, relative_pos=None):
     mlm_ctx_layers = self.emd_context_layer(ctx_layers, z_states, attention_mask, encoder, target_ids, input_ids, input_mask, relative_pos=relative_pos)
     loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
     lm_loss = torch.tensor(0).to(ctx_layers[-1])
@@ -92,7 +92,7 @@ class MaskedLanguageModel(NNModule):
     self.lm_predictions = EnhancedMaskDecoder(self.deberta.config, self.deberta.embeddings.word_embeddings.weight.size(0))
     self.apply(self.init_weights)
 
-  def forward(self, input_ids, input_mask=None, labels=None, position_ids=None, attention_mask=None, **kwargs):
+  def forward(self, input_ids, input_mask=None, labels=None, position_ids=None, attention_mask=None):
     device = list(self.parameters())[0].device
     input_ids = input_ids.to(device)
     input_mask = input_mask.to(device)
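
Dropping the catch-all parameters is a behavioral tightening, not just cleanup: an unused `*wargs, **kwargs` (the former itself likely a typo for `*args`) makes `forward` silently accept and discard misspelled or unsupported keyword arguments, while the narrowed signatures make such calls fail immediately. A minimal sketch of the difference, using hypothetical toy modules rather than the repo's classes:

import torch

class LooseDecoder(torch.nn.Module):
  # Old-style signature: unused catch-alls absorb any extra arguments.
  def forward(self, x, relative_pos=None, *args, **kwargs):
    return x

class StrictDecoder(torch.nn.Module):
  # New-style signature: only the arguments that are actually used.
  def forward(self, x, relative_pos=None):
    return x

x = torch.zeros(1)
LooseDecoder()(x, relatve_pos=3)     # typo silently swallowed; relative_pos stays None
try:
  StrictDecoder()(x, relatve_pos=3)  # same typo now raises TypeError
except TypeError as e:
  print('caught:', e)

With the old signature the misspelled keyword vanishes into **kwargs and the model runs with the default value, a class of bug that is hard to track down; with the new one the caller gets an error at once.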