зеркало из https://github.com/microsoft/DeBERTa.git
Add logits
This commit is contained in:
Родитель
0e7e72fc58
Коммит
14bb78d123
|
@ -78,7 +78,8 @@ def train_model(args, model, device, train_data, eval_data):
|
|||
t_logits = None
|
||||
if args.vat_lambda>0:
|
||||
def pert_logits_fn(model, **data):
|
||||
logits,_ = model(**data)
|
||||
o = model(**data)
|
||||
logits = o['logits']
|
||||
if isinstance(logits, Sequence):
|
||||
logits = logits[-1]
|
||||
return logits
|
||||
|
|
Загрузка…
Ссылка в новой задаче