From 7ec3d8620c48b5d4a31a4985382aa2885eae7816 Mon Sep 17 00:00:00 2001
From: Sergii Dymchenko
Date: Thu, 8 Apr 2021 13:25:25 -0700
Subject: [PATCH] Remove unused var args from masked_language_model.py.

---
 DeBERTa/apps/models/masked_language_model.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/DeBERTa/apps/models/masked_language_model.py b/DeBERTa/apps/models/masked_language_model.py
index 890df36..65ab8fc 100644
--- a/DeBERTa/apps/models/masked_language_model.py
+++ b/DeBERTa/apps/models/masked_language_model.py
@@ -33,7 +33,7 @@ class EnhancedMaskDecoder(torch.nn.Module):
     self.config = config
     self.lm_head = BertLMPredictionHead(config, vocab_size)
 
-  def forward(self, ctx_layers, ebd_weight, target_ids, input_ids, input_mask, z_states, attention_mask, encoder, relative_pos=None, *wargs, **kwargs):
+  def forward(self, ctx_layers, ebd_weight, target_ids, input_ids, input_mask, z_states, attention_mask, encoder, relative_pos=None):
     mlm_ctx_layers = self.emd_context_layer(ctx_layers, z_states, attention_mask, encoder, target_ids, input_ids, input_mask, relative_pos=relative_pos)
     loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
     lm_loss = torch.tensor(0).to(ctx_layers[-1])
@@ -92,7 +92,7 @@ class MaskedLanguageModel(NNModule):
     self.lm_predictions = EnhancedMaskDecoder(self.deberta.config, self.deberta.embeddings.word_embeddings.weight.size(0))
     self.apply(self.init_weights)
 
-  def forward(self, input_ids, input_mask=None, labels=None, position_ids=None, attention_mask=None):
+  def forward(self, input_ids, input_mask=None, labels=None, position_ids=None, attention_mask=None):
     device = list(self.parameters())[0].device
     input_ids = input_ids.to(device)
     input_mask = input_mask.to(device)
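
Note on the behavioral effect of this change (not part of the patch): dropping `*wargs, **kwargs` from a `forward` signature means stray arguments now raise a TypeError at the call site instead of being silently swallowed, so caller bugs surface immediately. A minimal sketch with a hypothetical stand-in module, not the actual DeBERTa classes:

import torch

class Decoder(torch.nn.Module):
    # Stand-in for a forward() whose signature, after this patch, lists
    # exactly the parameters the method actually uses.
    def forward(self, hidden, relative_pos=None):
        return hidden

decoder = Decoder()
x = torch.zeros(2, 4)

decoder(x)  # fine: matches the declared signature

try:
    # With the old `*wargs, **kwargs` catch-all, this typo'd keyword
    # would have been accepted and silently ignored.
    decoder(x, relativ_pos=torch.zeros(2, 2))
except TypeError as e:
    print(e)  # unexpected keyword argument 'relativ_pos'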