Mirror of https://github.com/microsoft/DeBERTa.git
Add implementation of SiFT (Scale-invariant Fine-Tuning), which is a
variant of adversarial training.
This commit is contained in:
Parent
7ec3d8620c
Commit
b6da4de7ab
@@ -24,6 +24,7 @@ import json
from torch.utils.data import DataLoader
from ..utils import *
from ..utils import xtqdm as tqdm
from ..sift import AdversarialLearner,hook_sift_layer
from .tasks import load_tasks,get_task

import pdb
@@ -63,8 +64,28 @@ def train_model(args, model, device, train_data, eval_data):
    output = model(**data)
    loss = output['loss']
    return loss.mean(), data['input_ids'].size(0)

  adv_modules = hook_sift_layer(model, hidden_size=model.config.hidden_size, learning_rate=args.vat_learning_rate, init_perturbation=args.vat_init_perturbation)
  adv = AdversarialLearner(model, adv_modules)
  def adv_loss_fn(trainer, model, data):
    logits, loss = model(**data)
    if isinstance(logits, Sequence):
      logits = logits[-1]
    v_teacher = []

  trainer = DistributedTrainer(args, args.output_dir, model, device, data_fn, loss_fn = loss_fn, eval_fn = eval_fn, dump_interval = args.dump_interval)
    t_logits = None
    if args.vat_lambda>0:
      def pert_logits_fn(model, **data):
        logits,_ = model(**data)
        if isinstance(logits, Sequence):
          logits = logits[-1]
        return logits

      loss += adv.loss(logits, pert_logits_fn, loss_fn = args.vat_loss_fn, **data)*args.vat_lambda

    return loss.mean(), data['input_ids'].size(0)

  trainer = DistributedTrainer(args, args.output_dir, model, device, data_fn, loss_fn = adv_loss_fn, eval_fn = eval_fn, dump_interval = args.dump_interval)
  trainer.train()

def merge_distributed(data_list, max_len=None):
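Read as a whole, the hunk above replaces the plain task loss with a SiFT-augmented one: perturbation layers are hooked onto the model, and the trainer's loss function adds `adv.loss(...)` weighted by `--vat_lambda`. The following condensed sketch shows the resulting pattern in one piece (the unused `v_teacher`/`t_logits` placeholders are omitted); `Sequence`, `DistributedTrainer`, `data_fn`, and `eval_fn` are assumed to be defined elsewhere in `run.py`, as in the original file.

```python
# Condensed, illustrative sketch of the change above -- not a verbatim copy of the commit.
adv_modules = hook_sift_layer(model,
                              hidden_size=model.config.hidden_size,
                              learning_rate=args.vat_learning_rate,
                              init_perturbation=args.vat_init_perturbation)
adv = AdversarialLearner(model, adv_modules)

def adv_loss_fn(trainer, model, data):
  logits, loss = model(**data)
  if isinstance(logits, Sequence):      # some heads return a list of logits; keep the last one
    logits = logits[-1]

  if args.vat_lambda > 0:
    def pert_logits_fn(model, **data):
      logits, _ = model(**data)
      if isinstance(logits, Sequence):
        logits = logits[-1]
      return logits

    # adversarial term computed on perturbed inputs, weighted by --vat_lambda
    loss = loss + adv.loss(logits, pert_logits_fn, loss_fn=args.vat_loss_fn, **data) * args.vat_lambda

  return loss.mean(), data['input_ids'].size(0)

trainer = DistributedTrainer(args, args.output_dir, model, device, data_fn,
                             loss_fn=adv_loss_fn, eval_fn=eval_fn,
                             dump_interval=args.dump_interval)
trainer.train()
```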
@@ -387,6 +408,28 @@ def build_argument_parser():
        default=None,
        type=str,
        help="The path of the vocabulary")

  parser.add_argument('--vat_lambda',
        default=0,
        type=float,
        help="The weight of the adversarial training loss.")

  parser.add_argument('--vat_learning_rate',
        default=1e-4,
        type=float,
        help="The learning rate used to update the perturbation.")

  parser.add_argument('--vat_init_perturbation',
        default=1e-2,
        type=float,
        help="The initialization range for the perturbation.")

  parser.add_argument('--vat_loss_fn',
        default="symmetric-kl",
        type=str,
        help="The loss function used to calculate the adversarial loss. It can be one of symmetric-kl, kl or mse.")


  return parser

if __name__ == "__main__":
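A note on the defaults just added: SiFT stays disabled unless `--vat_lambda` is set, since `train_model` only adds the adversarial term when `args.vat_lambda > 0`, and `--vat_loss_fn` defaults to the `symmetric-kl` entry of the loss registry in `sift.py`. A minimal, illustrative check (assuming only that `build_argument_parser` from this file is importable):

```python
# Illustrative sketch; build_argument_parser() is the function this hunk extends.
parser = build_argument_parser()

# Adversarial training is off by default; symmetric KL is the default divergence.
assert parser.get_default('vat_lambda') == 0
assert parser.get_default('vat_loss_fn') == 'symmetric-kl'  # key into perturbation_loss_fns in sift.py
```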
@@ -131,13 +131,15 @@ class DeBERTa(torch.nn.Module):
    if state is None:
      state, config = load_model_state(self.pre_trained)
      self.config = config

    prefix = ''
    for k in state:
      if 'embeddings.' in k:
        if not k.startswith('embeddings.'):
          prefix = k[:k.index('embeddings.')]
        break

    def key_match(key, s):
      c = [k for k in s if key in k]
      assert len(c)==1, c
      return c[0]
    current = self.state_dict()
    for c in current.keys():
      current[c] = state[key_match(c, state.keys())]
    self.load_state_dict(current)

    missing_keys = []
    unexpected_keys = []
    error_msgs = []
    self._load_from_state_dict(state, prefix = prefix, local_metadata=None, strict=True, missing_keys=missing_keys, unexpected_keys=unexpected_keys, error_msgs=error_msgs)
@@ -107,7 +107,7 @@ class NNModule(nn.Module):
            'num_attention_heads',
            'num_hidden_layers',
            'vocab_size',
            'max_position_embeddings']:
            'max_position_embeddings'] or (k not in model_config.__dict__) or (model_config.__dict__[k] < 0):
            model_config.__dict__[k] = config.__dict__[k]
      if model_config is not None:
        config = copy.copy(model_config)
@@ -0,0 +1 @@
from .sift import *
@@ -0,0 +1,209 @@
# Copyright (c) Microsoft, Inc. 2020
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# Author: penhe@microsoft.com
# Date: 01/25/2021
#

import torch
import torch.nn.functional as F

__all__ = ['PerturbationLayer', 'AdversarialLearner', 'hook_sift_layer']

class PerturbationLayer(torch.nn.Module):
  def __init__(self, hidden_size, learning_rate=1e-4, init_perturbation=1e-2):
    super().__init__()
    self.learning_rate = learning_rate
    self.init_perturbation = init_perturbation
    self.delta = None
    self.LayerNorm = torch.nn.LayerNorm(hidden_size, 1e-7, elementwise_affine=False)
    self.adversarial_mode = False

  def adversarial_(self, adversarial = True):
    self.adversarial_mode = adversarial
    if not adversarial:
      self.delta = None

  def forward(self, input):
    if not self.adversarial_mode:
      # Normal mode: normalize the input and cache it for later perturbation.
      self.input = self.LayerNorm(input)
      return self.input
    else:
      if self.delta is None:
        self.update_delta(requires_grad=True)
      return self.perturbated_input

  def update_delta(self, requires_grad = False):
    if not self.adversarial_mode:
      return True
    if self.delta is None:
      # Initialize the perturbation with Gaussian noise clamped to +/- 2*init_perturbation.
      delta = torch.clamp(self.input.new(self.input.size()).normal_(0, self.init_perturbation).float(), -2*self.init_perturbation, 2*self.init_perturbation)
    else:
      # Take a gradient ascent step on the perturbation, normalized by the per-position max gradient.
      grad = self.delta.grad
      self.delta.grad = None
      delta = self.delta
      norm = grad.norm()
      if torch.isnan(norm) or torch.isinf(norm):
        return False
      eps = self.learning_rate
      with torch.no_grad():
        delta = delta + eps*grad/(1e-6 + grad.abs().max(-1, keepdim=True)[0])
    self.delta = delta.float().detach().requires_grad_(requires_grad)
    self.perturbated_input = (self.input.to(delta).detach() + self.delta).to(self.input)
    return True

def hook_sift_layer(model, hidden_size, learning_rate=1e-4, init_perturbation=1e-2, target_module = 'embeddings.LayerNorm'):
  """
  Hook the SiFT perturbation layer into an existing model. With this method, you can apply adversarial training
  without changing the existing model implementation.

  Params:
    `model`: The model instance to apply adversarial training to
    `hidden_size`: The dimension of the perturbed embedding
    `learning_rate`: The learning rate used to update the perturbation
    `init_perturbation`: The initial range of the perturbation
    `target_module`: The module to apply the perturbation to. It can be the name of a sub-module of the model or a sub-module instance.
    The perturbation layer will be inserted before the sub-module.

  Outputs:
    The perturbation layers.

  """
  if isinstance(target_module, str):
    _modules = [k for n,k in model.named_modules() if target_module in n]
  else:
    assert isinstance(target_module, torch.nn.Module), f'{type(target_module)} is not an instance of torch.nn.Module'
    _modules = [target_module]
  adv_modules = []
  for m in _modules:
    adv = PerturbationLayer(hidden_size, learning_rate, init_perturbation)
    def adv_hook(module, inputs):
      return adv(inputs[0])
    # Remove any previously registered SiFT hooks before adding a new one.
    for h in list(m._forward_pre_hooks.keys()):
      if m._forward_pre_hooks[h].__name__ == 'adv_hook':
        del m._forward_pre_hooks[h]
    m.register_forward_pre_hook(adv_hook)
    adv_modules.append(adv)
  return adv_modules

class AdversarialLearner:
  """ Adversarial Learner
  This is a helper class for adversarial training.

  Params:
    `model`: The model instance to apply adversarial training to
    `adv_modules`: The sub-modules in the model that will generate perturbations. If it's `None`,
    the constructor will detect sub-modules of type `PerturbationLayer` in the model.

  Example usage:
  ```python
  # Create DeBERTa model
  adv_modules = hook_sift_layer(model, hidden_size=768)
  adv = AdversarialLearner(model, adv_modules)
  def logits_fn(model, *wargs, **kwargs):
    logits,_ = model(*wargs, **kwargs)
    return logits
  logits,loss = model(**data)

  loss = loss + adv.loss(logits, logits_fn, **data)
  # The other steps are the same as in regular training.
  ```

  """
  def __init__(self, model, adv_modules=None):
    if adv_modules is None:
      self.adv_modules = [m for m in model.modules() if isinstance(m, PerturbationLayer)]
    else:
      self.adv_modules = adv_modules
    self.parameters = [p for p in model.parameters()]
    self.model = model

  def loss(self, target, logits_fn, loss_fn = 'symmetric-kl', *wargs, **kwargs):
    """
    Calculate the adversarial loss based on the given logits function and loss function.
    Inputs:
    `target`: the logits from the original inputs.
    `logits_fn`: the function that produces logits based on perturbed inputs. E.g.,
    ```python
    def logits_fn(model, *wargs, **kwargs):
      logits = model(*wargs, **kwargs)
      return logits
    ```
    `loss_fn`: the function that calculates the loss from the perturbed logits and the target logits.
      - If it's a string, it must be one of the pre-built loss functions, i.e. symmetric-kl, kl or mse.
      - If it's a function, it will be called to calculate the loss; the signature of the function will be,
    ```python
    def loss_fn(source_logits, target_logits):
      # Calculate the loss
      return loss
    ```
    `*wargs`: the positional arguments that will be passed to the model
    `**kwargs`: the keyword arguments that will be passed to the model
    Outputs:
    The loss based on perturbed inputs.
    """
    self.prepare()
    if isinstance(loss_fn, str):
      loss_fn = perturbation_loss_fns[loss_fn]
    # First pass: compute the loss on perturbed inputs and backpropagate it to get the perturbation gradient.
    pert_logits = logits_fn(self.model, *wargs, **kwargs)
    pert_loss = loss_fn(pert_logits, target.detach()).sum()
    pert_loss.backward()
    # Update the perturbation with one ascent step, then re-evaluate against the target logits.
    for m in self.adv_modules:
      ok = m.update_delta(True)

    for r,p in zip(self.prev, self.parameters):
      p.requires_grad_(r)
    pert_logits = logits_fn(self.model, *wargs, **kwargs)
    pert_loss = symmetric_kl(pert_logits, target)

    self.cleanup()
    return pert_loss.mean()

  def prepare(self):
    # Freeze model parameters so the first backward pass only produces gradients for the perturbations.
    self.prev = [p.requires_grad for p in self.parameters]
    for p in self.parameters:
      p.requires_grad_(False)
    for m in self.adv_modules:
      m.adversarial_(True)

  def cleanup(self):
    for r,p in zip(self.prev, self.parameters):
      p.requires_grad_(r)

    for m in self.adv_modules:
      m.adversarial_(False)

def symmetric_kl(logits, target):
  logit_stu = logits.view(-1, logits.size(-1)).float()
  logit_tea = target.view(-1, target.size(-1)).float()
  logprob_stu = F.log_softmax(logit_stu, -1)
  logprob_tea = F.log_softmax(logit_tea, -1)
  prob_tea = logprob_tea.exp().detach()
  prob_stu = logprob_stu.exp().detach()
  floss = ((prob_tea*(-logprob_stu)).sum(-1)) # Cross Entropy
  bloss = ((prob_stu*(-logprob_tea)).sum(-1)) # Cross Entropy
  loss = floss + bloss
  return loss

def kl(logits, target):
  logit_stu = logits.view(-1, logits.size(-1)).float()
  logit_tea = target.view(-1, target.size(-1)).float()
  logprob_stu = F.log_softmax(logit_stu, -1)
  logprob_tea = F.log_softmax(logit_tea.detach(), -1)
  prob_tea = logprob_tea.exp()
  loss = ((prob_tea*(-logprob_stu)).sum(-1)) # Cross Entropy
  return loss

def mse(logits, target):
  logit_stu = logits.view(-1, logits.size(-1)).float()
  logit_tea = target.view(-1, target.size(-1)).float()
  return F.mse_loss(logit_stu.view(-1),logit_tea.view(-1))

perturbation_loss_fns = {
  'symmetric-kl': symmetric_kl,
  'kl': kl,
  'mse': mse
}
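The docstrings above leave two extension points open: `target_module` may be a module name or a module instance, and `loss_fn` may be a registered name or a custom callable. Below is a hedged, self-contained sketch that hooks SiFT onto a toy classifier using the default `target_module` name and a custom callable `loss_fn`; `ToyModel`, its module names, and the random batch are invented for illustration, while `hook_sift_layer` and `AdversarialLearner` are the APIs added in this file.

```python
import torch
import torch.nn.functional as F
from DeBERTa.sift import hook_sift_layer, AdversarialLearner  # package path added by this commit

class ToyModel(torch.nn.Module):
  """Hypothetical classifier exposing an 'embeddings.LayerNorm' sub-module name."""
  def __init__(self, vocab=100, hidden=32, classes=3):
    super().__init__()
    self.embed = torch.nn.Embedding(vocab, hidden)
    self.embeddings = torch.nn.ModuleDict({'LayerNorm': torch.nn.LayerNorm(hidden)})
    self.classifier = torch.nn.Linear(hidden, classes)

  def forward(self, input_ids, labels=None):
    h = self.embeddings['LayerNorm'](self.embed(input_ids))
    logits = self.classifier(h.mean(dim=1))
    loss = F.cross_entropy(logits, labels) if labels is not None else None
    return logits, loss

model = ToyModel()
# A PerturbationLayer is hooked in front of every module whose name contains 'embeddings.LayerNorm'.
adv_modules = hook_sift_layer(model, hidden_size=32, target_module='embeddings.LayerNorm')
adv = AdversarialLearner(model, adv_modules)

def logits_fn(model, **data):
  logits, _ = model(**data)
  return logits

def custom_mse(source_logits, target_logits):
  # custom callable loss_fn with the (source_logits, target_logits) signature from the docstring
  return F.mse_loss(source_logits.view(-1), target_logits.view(-1))

data = {'input_ids': torch.randint(0, 100, (4, 16)),
        'labels': torch.randint(0, 3, (4,))}
logits, loss = model(**data)                                   # regular forward pass first
loss = loss + adv.loss(logits, logits_fn, loss_fn=custom_mse, **data)
loss.backward()
```

The regular forward pass has to run before `adv.loss`, because the perturbation layer caches its normalized input during that pass and perturbs the cached value afterwards.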
@@ -21,9 +21,13 @@ init=$1
tag=$init
case ${init,,} in
  base)
  parameters=" --num_train_epochs 3 \
  parameters=" --num_train_epochs 6 \
  --vat_lambda 5 \
  --vat_learning_rate 1e-4 \
  --vat_init_perturbation 1e-2 \
  --fp16 True \
  --warmup 1000 \
  --learning_rate 2e-5 \
  --learning_rate 1.5e-5 \
  --train_batch_size 64 \
  --cls_drop_out 0.1 "
  ;;

@@ -78,4 +82,4 @@ python -m DeBERTa.apps.run --model_config config.json \
  --task_name $Task \
  --data_dir $cache_dir/glue_tasks/$Task \
  --init_model $init \
  --output_dir /tmp/ttonly/$tag/$task $parameters
  --output_dir /tmp/ttonly/$tag/${task}_v2 $parameters