Add implementation of SiFT (Scale-invariant Fine-Tuning), which is a
variant of adversarial training.
This commit is contained in:
Pengcheng He 2021-02-25 17:43:48 -05:00 коммит произвёл Pengcheng He
Родитель 7ec3d8620c
Коммит b6da4de7ab
6 изменённых файлов: 273 добавлений и 14 удалений

Просмотреть файл

@ -24,6 +24,7 @@ import json
from torch.utils.data import DataLoader
from ..utils import *
from ..utils import xtqdm as tqdm
from ..sift import AdversarialLearner,hook_sift_layer
from .tasks import load_tasks,get_task
import pdb
@ -63,8 +64,28 @@ def train_model(args, model, device, train_data, eval_data):
output = model(**data)
loss = output['loss']
return loss.mean(), data['input_ids'].size(0)
adv_modules = hook_sift_layer(model, hidden_size=model.config.hidden_size, learning_rate=args.vat_learning_rate, init_perturbation=args.vat_init_perturbation)
adv = AdversarialLearner(model, adv_modules)
def adv_loss_fn(trainer, model, data):
logits, loss = model(**data)
if isinstance(logits, Sequence):
logits = logits[-1]
v_teacher = []
trainer = DistributedTrainer(args, args.output_dir, model, device, data_fn, loss_fn = loss_fn, eval_fn = eval_fn, dump_interval = args.dump_interval)
t_logits = None
if args.vat_lambda>0:
def pert_logits_fn(model, **data):
logits,_ = model(**data)
if isinstance(logits, Sequence):
logits = logits[-1]
return logits
loss += adv.loss(logits, pert_logits_fn, loss_fn = args.vat_loss_fn, **data)*args.vat_lambda
return loss.mean(), data['input_ids'].size(0)
trainer = DistributedTrainer(args, args.output_dir, model, device, data_fn, loss_fn = adv_loss_fn, eval_fn = eval_fn, dump_interval = args.dump_interval)
trainer.train()
def merge_distributed(data_list, max_len=None):
@ -387,6 +408,28 @@ def build_argument_parser():
default=None,
type=str,
help="The path of the vocabulary")
parser.add_argument('--vat_lambda',
default=0,
type=float,
help="The weight of adversarial training loss.")
parser.add_argument('--vat_learning_rate',
default=1e-4,
type=float,
help="The learning rate used to update pertubation")
parser.add_argument('--vat_init_perturbation',
default=1e-2,
type=float,
help="The initialization for pertubation")
parser.add_argument('--vat_loss_fn',
default="symmetric-kl",
type=str,
help="The loss function used to calculate adversarial loss. It can be one of symmetric-kl, kl or mse.")
return parser
if __name__ == "__main__":

Просмотреть файл

@ -131,13 +131,15 @@ class DeBERTa(torch.nn.Module):
if state is None:
state, config = load_model_state(self.pre_trained)
self.config = config
prefix = ''
for k in state:
if 'embeddings.' in k:
if not k.startswith('embeddings.'):
prefix = k[:k.index('embeddings.')]
break
def key_match(key, s):
c = [k for k in s if key in k]
assert len(c)==1, c
return c[0]
current = self.state_dict()
for c in current.keys():
current[c] = state[key_match(c, state.keys())]
self.load_state_dict(current)
missing_keys = []
unexpected_keys = []
error_msgs = []
self._load_from_state_dict(state, prefix = prefix, local_metadata=None, strict=True, missing_keys=missing_keys, unexpected_keys=unexpected_keys, error_msgs=error_msgs)

Просмотреть файл

@ -107,7 +107,7 @@ class NNModule(nn.Module):
'num_attention_heads',
'num_hidden_layers',
'vocab_size',
'max_position_embeddings']:
'max_position_embeddings'] or (k not in model_config.__dict__) or (model_config.__dict__[k] < 0):
model_config.__dict__[k] = config.__dict__[k]
if model_config is not None:
config = copy.copy(model_config)

1
DeBERTa/sift/__init__.py Normal file
Просмотреть файл

@ -0,0 +1 @@
from .sift import *

209
DeBERTa/sift/sift.py Normal file
Просмотреть файл

@ -0,0 +1,209 @@
# Copyright (c) Microsoft, Inc. 2020
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# Author: penhe@microsoft.com
# Date: 01/25/2021
#
import torch
import torch.nn.functional as F
__all__ = ['PerturbationLayer', 'AdversarialLearner', 'hook_sift_layer']
class PerturbationLayer(torch.nn.Module):
  """A layer that normalizes its input and, in adversarial mode, adds a learnable
  perturbation `delta` on top of the normalized input.

  In normal mode `forward` just applies a non-affine LayerNorm and caches the result
  in `self.input`. In adversarial mode `forward` returns the cached normalized input
  plus `delta`; `delta` is initialized randomly and then updated by a gradient step
  in `update_delta`.
  """
  def __init__(self, hidden_size, learning_rate=1e-4, init_perturbation=1e-2):
    super().__init__()
    # Step size used by update_delta when moving delta along its gradient.
    self.learning_rate = learning_rate
    # Std of the initial random delta (also bounds its clamp range, see update_delta).
    self.init_perturbation = init_perturbation
    self.delta = None
    # elementwise_affine=False: pure normalization, no learned scale/shift.
    self.LayerNorm = torch.nn.LayerNorm(hidden_size, 1e-7, elementwise_affine=False)
    self.adversarial_mode = False

  def adversarial_(self, adversarial = True):
    """Toggle adversarial mode; leaving it discards the current delta."""
    self.adversarial_mode = adversarial
    if not adversarial:
      self.delta = None

  def forward(self, input):
    if not self.adversarial_mode:
      # Cache the normalized input; adversarial passes perturb this cached value
      # rather than re-normalizing (so both passes see the same base input).
      self.input = self.LayerNorm(input)
      return self.input
    else:
      if self.delta is None:
        # First adversarial pass: create a random delta that requires grad.
        self.update_delta(requires_grad=True)
      return self.perturbated_input

  def update_delta(self, requires_grad = False):
    """Initialize or gradient-step the perturbation.

    Returns False (leaving delta unchanged) when the accumulated gradient is
    NaN/Inf; True otherwise. No-op outside adversarial mode.
    """
    if not self.adversarial_mode:
      return True
    if self.delta is None:
      # Fresh random delta, clamped to +/- 2*init_perturbation.
      delta = torch.clamp(self.input.new(self.input.size()).normal_(0, self.init_perturbation).float(), -2*self.init_perturbation, 2*self.init_perturbation)
    else:
      grad = self.delta.grad
      self.delta.grad = None
      delta = self.delta
      norm = grad.norm()
      if torch.isnan(norm) or torch.isinf(norm):
        # Bad gradient: keep the previous delta and signal failure.
        return False
      eps = self.learning_rate
      with torch.no_grad():
        # Ascent step, normalized by the per-row max |grad| (1e-6 avoids div-by-zero).
        delta = delta + eps*grad/(1e-6 + grad.abs().max(-1, keepdim=True)[0])
    # Detach from the old graph; re-enable grad only when the caller asks for it.
    self.delta = delta.float().detach().requires_grad_(requires_grad)
    # Perturb the cached normalized input; .to(...) round-trip keeps input's dtype/device.
    self.perturbated_input = (self.input.to(delta).detach() + self.delta).to(self.input)
    return True
def hook_sift_layer(model, hidden_size, learning_rate=1e-4, init_perturbation=1e-2, target_module = 'embeddings.LayerNorm'):
  """
  Hook the SiFT perturbation layer into an existing model. With this method, you can apply
  adversarial training without changing the existing model implementation.

  Params:
    `model`: The model instance to apply adversarial training to.
    `hidden_size`: The dimension of the perturbed embedding.
    `learning_rate`: The learning rate used to update the perturbation.
    `init_perturbation`: The initial range of the perturbation.
    `target_module`: The module to apply the perturbation to. It can be the name of a
    sub-module of the model or a sub-module instance.
    The perturbation layer will be inserted before the sub-module.

  Outputs:
    The list of perturbation layers that were hooked.
  """
  if isinstance(target_module, str):
    # Substring match against qualified module names, e.g. 'embeddings.LayerNorm'.
    _modules = [k for n,k in model.named_modules() if target_module in n]
  else:
    assert isinstance(target_module, torch.nn.Module), f'{type(target_module)} is not an instance of torch.nn.Module'
    _modules = [target_module]
  adv_modules = []
  for m in _modules:
    adv = PerturbationLayer(hidden_size, learning_rate, init_perturbation)
    # BUG FIX: bind `adv` as a default argument. A plain closure is late-bound,
    # so when several modules match, every hook would have ended up calling the
    # *last* PerturbationLayer created by this loop.
    def adv_hook(module, inputs, adv=adv):
      return adv(inputs[0])
    # Drop any sift hook installed by a previous call so hooks don't stack.
    for h in list(m._forward_pre_hooks.keys()):
      if m._forward_pre_hooks[h].__name__ == 'adv_hook':
        del m._forward_pre_hooks[h]
    m.register_forward_pre_hook(adv_hook)
    adv_modules.append(adv)
  return adv_modules
class AdversarialLearner:
  """ Adversarial Learner
  This class is the helper class for adversarial training.

  Params:
    `model`: The model instance to apply adversarial training
    `perturbation_modules`: The sub modules in the model that will generate perturbations. If it's `None`,
    the constructor will detect sub-modules of type `PerturbationLayer` in the model.

  Example usage:
  ```python
  # Create DeBERTa model
  adv_modules = hook_sift_layer(model, hidden_size=768)
  adv = AdversarialLearner(model, adv_modules)
  def logits_fn(model, *wargs, **kwargs):
    logits,_ = model(*wargs, **kwargs)
    return logits
  logits,loss = model(**data)

  loss = loss + adv.loss(logits, logits_fn, **data)
  # Other steps are the same as general training.
  ```
  """
  def __init__(self, model, adv_modules=None):
    if adv_modules is None:
      # Auto-detect perturbation layers already installed in the model.
      self.adv_modules = [m for m in model.modules() if isinstance(m, PerturbationLayer)]
    else:
      self.adv_modules = adv_modules
    self.parameters = [p for p in model.parameters()]
    self.model = model

  def loss(self, target, logits_fn, loss_fn = 'symmetric-kl', *wargs, **kwargs):
    """
    Calculate the adversarial loss based on the given logits function and loss function.

    Inputs:
    `target`: the logits from original inputs.
    `logits_fn`: the function that produces logits based on perturbed inputs. E.g.,
    ```python
    def logits_fn(model, *wargs, **kwargs):
      logits = model(*wargs, **kwargs)
      return logits
    ```
    `loss_fn`: the function that calculates the loss from perturbed logits and target logits.
      - If it's a string, it can be one of the pre-built loss functions, i.e. kl, symmetric-kl, mse.
      - If it's a function, it will be called to calculate the loss; the signature of the function is,
      ```python
      def loss_fn(source_logits, target_logits):
        # Calculate the loss
        return loss
      ```
    `*wargs`: the positional arguments that will be passed to the model
    `**kwargs`: the key-word arguments that will be passed to the model

    Outputs:
      The loss based on perturbed inputs.
    """
    self.prepare()
    if isinstance(loss_fn, str):
      loss_fn = perturbation_loss_fns[loss_fn]
    # Pass 1: forward with the random initial perturbation; backward propagates
    # only into the deltas because prepare() froze all model parameters.
    pert_logits = logits_fn(self.model, *wargs, **kwargs)
    pert_loss = loss_fn(pert_logits, target.detach()).sum()
    pert_loss.backward()
    # Ascent step on each perturbation layer (update_delta returns False and
    # keeps the old delta when the gradient is NaN/Inf).
    for m in self.adv_modules:
      ok = m.update_delta(True)
    # Unfreeze model parameters before pass 2 so the returned loss is
    # differentiable w.r.t. the model.
    for r,p in zip(self.prev, self.parameters):
      p.requires_grad_(r)
    pert_logits = logits_fn(self.model, *wargs, **kwargs)
    # BUG FIX: this used to hard-code `symmetric_kl`, silently ignoring the
    # `loss_fn` argument ('kl'/'mse', exposed via --vat_loss_fn) for the
    # returned loss even though the docstring promises it is honored.
    pert_loss = loss_fn(pert_logits, target)
    self.cleanup()
    return pert_loss.mean()

  def prepare(self):
    """Freeze model parameters and switch perturbation layers to adversarial mode."""
    self.prev = [p.requires_grad for p in self.parameters]
    for p in self.parameters:
      p.requires_grad_(False)
    for m in self.adv_modules:
      m.adversarial_(True)

  def cleanup(self):
    """Restore parameter requires_grad flags and leave adversarial mode (idempotent)."""
    for r,p in zip(self.prev, self.parameters):
      p.requires_grad_(r)
    for m in self.adv_modules:
      m.adversarial_(False)
def symmetric_kl(logits, target):
  """Symmetric KL divergence between the two logit distributions.

  Computed as the sum of the two cross-entropy terms, each with the probability
  side detached so gradients flow only through the log-probability factors.
  Returns a per-row loss over the flattened (-1, vocab) view.
  """
  flat_src = logits.view(-1, logits.size(-1)).float()
  flat_tgt = target.view(-1, target.size(-1)).float()
  log_p_src = F.log_softmax(flat_src, -1)
  log_p_tgt = F.log_softmax(flat_tgt, -1)
  p_tgt = log_p_tgt.exp().detach()
  p_src = log_p_src.exp().detach()
  forward_ce = (p_tgt * (-log_p_src)).sum(-1)   # H(p_tgt, p_src)
  backward_ce = (p_src * (-log_p_tgt)).sum(-1)  # H(p_src, p_tgt)
  return forward_ce + backward_ce
def kl(logits, target):
  """One-directional KL-style loss: cross entropy of the source log-probs
  under the (detached) target distribution, per row of the flattened view."""
  flat_src = logits.view(-1, logits.size(-1)).float()
  flat_tgt = target.view(-1, target.size(-1)).float()
  log_p_src = F.log_softmax(flat_src, -1)
  p_tgt = F.log_softmax(flat_tgt.detach(), -1).exp()
  return (p_tgt * (-log_p_src)).sum(-1)  # Cross Entropy
def mse(logits, target):
  """Mean squared error between the flattened source and target logits (scalar)."""
  flat_src = logits.view(-1, logits.size(-1)).float()
  flat_tgt = target.view(-1, target.size(-1)).float()
  return F.mse_loss(flat_src.view(-1), flat_tgt.view(-1))
# Registry mapping the string names accepted by AdversarialLearner.loss
# (its `loss_fn` argument, exposed on the CLI as --vat_loss_fn) to the
# corresponding loss functions defined above.
perturbation_loss_fns = {
  'symmetric-kl': symmetric_kl,
  'kl': kl,
  'mse': mse
}

Просмотреть файл

@ -21,9 +21,13 @@ init=$1
tag=$init
case ${init,,} in
base)
parameters=" --num_train_epochs 3 \
parameters=" --num_train_epochs 6 \
--vat_lambda 5 \
--vat_learning_rate 1e-4 \
--vat_init_perturbation 1e-2 \
--fp16 True \
--warmup 1000 \
--learning_rate 2e-5 \
--learning_rate 1.5e-5 \
--train_batch_size 64 \
--cls_drop_out 0.1 "
;;
@ -78,4 +82,4 @@ python -m DeBERTa.apps.run --model_config config.json \
--task_name $Task \
--data_dir $cache_dir/glue_tasks/$Task \
--init_model $init \
--output_dir /tmp/ttonly/$tag/$task $parameters
--output_dir /tmp/ttonly/$tag/${task}_v2 $parameters