diff --git a/DeBERTa/apps/run.py b/DeBERTa/apps/run.py index 4c32277..4781e03 100644 --- a/DeBERTa/apps/run.py +++ b/DeBERTa/apps/run.py @@ -78,7 +78,8 @@ def train_model(args, model, device, train_data, eval_data): t_logits = None if args.vat_lambda>0: def pert_logits_fn(model, **data): - logits,_ = model(**data) + o = model(**data) + logits = o['logits'] if isinstance(logits, Sequence): logits = logits[-1] return logits