Mirror of https://github.com/microsoft/DeBERTa.git
1. Add document for DeBERTa pre-training
2. Add RTD task head
3. Merge LASER with DeBERTa
Parent: e59f09fbd4
Commit: 994f643ec2
@ -0,0 +1,51 @@
import torch
from collections import OrderedDict, Mapping, Sequence

def merge_distributed(data_list, max_len=None):
  if torch.distributed.is_initialized() and torch.distributed.get_world_size()>1:
    world_size = torch.distributed.get_world_size()
  else:
    world_size = 1
  merged = []
  def gather(data):
    data_size = [torch.zeros(data.dim(), dtype=torch.int).to(data.device) for _ in range(world_size)]
    torch.distributed.all_gather(data_size, torch.tensor(data.size()).to(data_size[0]))
    data_chunks = [torch.zeros(tuple(s.cpu().numpy())).to(data) for s in data_size]
    data_chunks[data.device.index] = data
    for i,_chunk in enumerate(data_chunks):
      torch.distributed.broadcast(_chunk, src=i)
    return data_chunks

  for data in data_list:
    if torch.distributed.is_initialized() and torch.distributed.get_world_size()>1:
      if isinstance(data, Sequence):
        data_chunks = []
        for d in data:
          chunks_ = gather(d)
          data_ = torch.cat(chunks_)
          data_chunks.append(data_)
        merged.append(data_chunks)
      else:
        _chunks = gather(data)
        merged.extend(_chunks)
    else:
      merged.append(data)

  return join_chunks(merged, max_len)

def join_chunks(chunks, max_len=None):
  if not isinstance(chunks[0], Sequence):
    merged = torch.cat([m.cpu() for m in chunks])
    if max_len is not None:
      return merged[:max_len]
    else:
      return merged
  else:
    data_list=[]
    for d in zip(*chunks):
      data = torch.cat([x.cpu() for x in d])
      if max_len is not None:
        data = data[:max_len]
      data_list.append(data)
    return data_list
|
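A minimal sketch of how these helpers behave in a single-process run (no torch.distributed process group), assuming they are importable from this module, e.g. DeBERTa.apps._utils as the relative imports later in this commit suggest; tensors are illustrative only:

    import torch
    from DeBERTa.apps._utils import merge_distributed, join_chunks  # assumed module path

    # Per-batch prediction chunks collected during evaluation (toy tensors).
    chunks = [torch.arange(0, 4), torch.arange(4, 8), torch.arange(8, 12)]

    # Without a distributed process group, merge_distributed falls through to
    # join_chunks, which concatenates on the CPU and optionally truncates.
    merged = merge_distributed(chunks, max_len=10)
    print(merged.shape)  # torch.Size([10])

    # When each element is itself a sequence (e.g. (logits, labels) pairs),
    # join_chunks concatenates position-wise and returns a list of tensors.
    paired = [(torch.zeros(2, 3), torch.ones(2)), (torch.zeros(2, 3), torch.ones(2))]
    logits, labels = join_chunks(paired)
    print(logits.shape, labels.shape)  # torch.Size([4, 3]) torch.Size([4])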
@ -3,3 +3,4 @@ from .multi_choice import *
from .sequence_classification import *
from .record_qa import *
from .masked_language_model import *
from .replaced_token_detection_model import *
@ -0,0 +1,97 @@
|
|||
#
|
||||
# Author: penhe@microsoft.com
|
||||
# Date: 04/25/2021
|
||||
#
|
||||
""" Replaced token detection model for representation learning
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import csv
|
||||
import os
|
||||
import json
|
||||
import random
|
||||
import time
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import pdb
|
||||
from collections.abc import Mapping
|
||||
from copy import copy
|
||||
from ...deberta import *
|
||||
|
||||
__all__ = ['LMMaskPredictionHead', 'ReplacedTokenDetectionModel']
|
||||
|
||||
class LMMaskPredictionHead(nn.Module):
|
||||
""" Replaced token prediction head
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.transform_act_fn = ACT2FN[config.hidden_act] \
|
||||
if isinstance(config.hidden_act, str) else config.hidden_act
|
||||
self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
|
||||
self.classifer = nn.Linear(config.hidden_size, 1)
|
||||
|
||||
def forward(self, hidden_states, input_ids, input_mask, lm_labels=None):
|
||||
# b x d
|
||||
ctx_states = hidden_states[:,0,:]
|
||||
seq_states = self.LayerNorm(ctx_states.unsqueeze(-2) + hidden_states)
|
||||
seq_states = self.dense(seq_states)
|
||||
seq_states = self.transform_act_fn(seq_states)
|
||||
|
||||
# b x max_len
|
||||
logits = self.classifer(seq_states).squeeze(-1)
|
||||
mask_loss = torch.tensor(0).to(logits).float()
|
||||
mask_labels = None
|
||||
if lm_labels is not None:
|
||||
mask_logits = logits.view(-1)
|
||||
_input_mask = input_mask.view(-1).to(mask_logits)
|
||||
input_idx = (_input_mask>0).nonzero().view(-1)
|
||||
mask_labels = ((lm_labels>0) & (lm_labels!=input_ids)).view(-1)
|
||||
mask_labels = torch.gather(mask_labels.to(mask_logits), 0, input_idx)
|
||||
mask_loss_fn = torch.nn.BCEWithLogitsLoss(reduction='none')
|
||||
mask_logits = torch.gather(mask_logits, 0, input_idx).float()
|
||||
mask_loss = mask_loss_fn(mask_logits, mask_labels)
|
||||
return logits, mask_labels, mask_loss
|
||||
|
||||
class ReplacedTokenDetectionModel(NNModule):
|
||||
""" RTD with DeBERTa
|
||||
"""
|
||||
def __init__(self, config, *wargs, **kwargs):
|
||||
super().__init__(config)
|
||||
self.deberta = DeBERTa(config)
|
||||
|
||||
self.max_relative_positions = getattr(config, 'max_relative_positions', -1)
|
||||
self.position_buckets = getattr(config, 'position_buckets', -1)
|
||||
if self.max_relative_positions <1:
|
||||
self.max_relative_positions = config.max_position_embeddings
|
||||
self.mask_predictions = LMMaskPredictionHead(self.deberta.config)
|
||||
self.apply(self.init_weights)
|
||||
|
||||
def forward(self, input_ids, input_mask=None, labels=None, position_ids=None, attention_mask=None):
|
||||
device = list(self.parameters())[0].device
|
||||
input_ids = input_ids.to(device)
|
||||
input_mask = input_mask.to(device)
|
||||
type_ids = None
|
||||
lm_labels = labels.to(device)
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = input_mask
|
||||
|
||||
encoder_output = self.deberta(input_ids, input_mask, type_ids, output_all_encoded_layers=True, position_ids = position_ids)
|
||||
encoder_layers = encoder_output['hidden_states']
|
||||
ctx_layer = encoder_layers[-1]
|
||||
(mask_logits, mask_labels, mask_loss) = self.mask_predictions(encoder_layers[-1], input_ids, input_mask, lm_labels)
|
||||
|
||||
return {
|
||||
'logits' : mask_logits,
|
||||
'labels' : mask_labels,
|
||||
'loss' : mask_loss.float(),
|
||||
}
|
|
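A small sketch of the labeling rule inside LMMaskPredictionHead.forward: a position counts as replaced when it carries a label (>0) that differs from the token actually present in the input, and the binary loss is taken only over positions selected by the input mask. Toy tensors, not real vocabulary ids:

    import torch

    input_ids  = torch.tensor([[101, 2054, 9999, 2003, 102]])  # 9999: token swapped in by a generator
    lm_labels  = torch.tensor([[0,   0,    2307, 0,    0]])    # original token at the corrupted position
    input_mask = torch.ones_like(input_ids)

    # Same rule as in the head above: label set and different from the input token.
    mask_labels = ((lm_labels > 0) & (lm_labels != input_ids)).view(-1).float()

    logits = torch.randn(input_ids.numel())          # fake per-position detection scores
    valid  = (input_mask.view(-1) > 0).nonzero().view(-1)

    loss_fn = torch.nn.BCEWithLogitsLoss(reduction='none')
    mask_loss = loss_fn(torch.gather(logits, 0, valid),
                        torch.gather(mask_labels, 0, valid))
    print(mask_loss.shape)  # one loss term per unmasked position, as in the head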
@ -21,17 +21,20 @@ import numpy as np
|
|||
import math
|
||||
import torch
|
||||
import json
|
||||
import shutil
|
||||
from torch.utils.data import DataLoader
|
||||
from ..utils import *
|
||||
from ..utils import xtqdm as tqdm
|
||||
from ..sift import AdversarialLearner,hook_sift_layer
|
||||
from .tasks import load_tasks,get_task
|
||||
from ._utils import merge_distributed, join_chunks
|
||||
|
||||
import pdb
|
||||
|
||||
import LASER
|
||||
from LASER.training import DistributedTrainer, initialize_distributed, batch_to, set_random_seed,kill_children
|
||||
from LASER.data import DistributedBatchSampler, SequentialSampler, BatchSampler, AsyncDataLoader
|
||||
from ..training import DistributedTrainer, initialize_distributed, batch_to, set_random_seed,kill_children
|
||||
from ..data import DistributedBatchSampler, SequentialSampler, BatchSampler, AsyncDataLoader
|
||||
from ..training import get_args as get_training_args
|
||||
from ..optims import get_args as get_optims_args
|
||||
|
||||
def create_model(args, num_labels, model_class_fn):
|
||||
# Prepare model
|
||||
|
@ -46,7 +49,7 @@ def create_model(args, num_labels, model_class_fn):
|
|||
logger.info(f'Total parameters: {sum([p.numel() for p in model.parameters()])}')
|
||||
return model
|
||||
|
||||
def train_model(args, model, device, train_data, eval_data, run_eval_fn):
|
||||
def train_model(args, model, device, train_data, eval_data, run_eval_fn, train_fn=None, loss_fn=None):
|
||||
total_examples = len(train_data)
|
||||
num_train_steps = int(len(train_data)*args.num_train_epochs / args.train_batch_size)
|
||||
logger.info(" Training batch size = %d", args.train_batch_size)
|
||||
|
@ -60,76 +63,45 @@ def train_model(args, model, device, train_data, eval_data, run_eval_fn):
|
|||
eval_metric = np.mean([v[0] for k,v in results.items() if 'train' not in k])
|
||||
return eval_metric
|
||||
|
||||
def loss_fn(trainer, model, data):
|
||||
def _loss_fn(trainer, model, data):
|
||||
output = model(**data)
|
||||
loss = output['loss']
|
||||
return loss.mean(), data['input_ids'].size(0)
|
||||
|
||||
adv_modules = hook_sift_layer(model, hidden_size=model.config.hidden_size, learning_rate=args.vat_learning_rate, init_perturbation=args.vat_init_perturbation)
|
||||
adv = AdversarialLearner(model, adv_modules)
|
||||
def adv_loss_fn(trainer, model, data):
|
||||
output = model(**data)
|
||||
logits = output['logits']
|
||||
loss = output['loss']
|
||||
if isinstance(logits, Sequence):
|
||||
logits = logits[-1]
|
||||
v_teacher = []
|
||||
def _train_fn(args, model, device, data_fn, eval_fn, loss_fn):
|
||||
adv_modules = hook_sift_layer(model, hidden_size=model.config.hidden_size, learning_rate=args.vat_learning_rate, init_perturbation=args.vat_init_perturbation)
|
||||
adv = AdversarialLearner(model, adv_modules)
|
||||
def adv_loss_fn(trainer, model, data):
|
||||
output = model(**data)
|
||||
logits = output['logits']
|
||||
loss = output['loss']
|
||||
if isinstance(logits, Sequence):
|
||||
logits = logits[-1]
|
||||
v_teacher = []
|
||||
|
||||
t_logits = None
|
||||
if args.vat_lambda>0:
|
||||
def pert_logits_fn(model, **data):
|
||||
o = model(**data)
|
||||
logits = o['logits']
|
||||
if isinstance(logits, Sequence):
|
||||
logits = logits[-1]
|
||||
return logits
|
||||
|
||||
loss += adv.loss(logits, pert_logits_fn, loss_fn = args.vat_loss_fn, **data)*args.vat_lambda
|
||||
|
||||
return loss.mean(), data['input_ids'].size(0)
|
||||
|
||||
if loss_fn is None:
|
||||
loss_fn = adv_loss_fn
|
||||
|
||||
trainer = DistributedTrainer(args, args.output_dir, model, device, data_fn, loss_fn = loss_fn, eval_fn = eval_fn, dump_interval = args.dump_interval)
|
||||
trainer.train()
|
||||
|
||||
t_logits = None
|
||||
if args.vat_lambda>0:
|
||||
def pert_logits_fn(model, **data):
|
||||
o = model(**data)
|
||||
logits = o['logits']
|
||||
if isinstance(logits, Sequence):
|
||||
logits = logits[-1]
|
||||
return logits
|
||||
if train_fn is None:
|
||||
train_fn = _train_fn
|
||||
|
||||
loss += adv.loss(logits, pert_logits_fn, loss_fn = args.vat_loss_fn, **data)*args.vat_lambda
|
||||
|
||||
return loss.mean(), data['input_ids'].size(0)
|
||||
|
||||
trainer = DistributedTrainer(args, args.output_dir, model, device, data_fn, loss_fn = adv_loss_fn, eval_fn = eval_fn, dump_interval = args.dump_interval)
|
||||
trainer.train()
|
||||
|
||||
def merge_distributed(data_list, max_len=None):
|
||||
merged = []
|
||||
def gather(data):
|
||||
data_size = [torch.zeros(data.dim(), dtype=torch.int).to(data.device) for _ in range(args.world_size)]
|
||||
torch.distributed.all_gather(data_size, torch.tensor(data.size()).to(data_size[0]))
|
||||
data_chunks = [torch.zeros(tuple(s.cpu().numpy())).to(data) for s in data_size]
|
||||
data_chunks[data.device.index] = data
|
||||
for i,_chunk in enumerate(data_chunks):
|
||||
torch.distributed.broadcast(_chunk, src=i)
|
||||
return data_chunks
|
||||
|
||||
for data in data_list:
|
||||
if torch.distributed.is_initialized() and torch.distributed.get_world_size()>1:
|
||||
if isinstance(data, Sequence):
|
||||
data_chunks = []
|
||||
for d in data:
|
||||
chunks_ = gather(d)
|
||||
data_ = torch.cat(chunks_)
|
||||
data_chunks.append(data_)
|
||||
merged.append(data_chunks)
|
||||
else:
|
||||
_chunks = gather(data)
|
||||
merged.extend(_chunks)
|
||||
else:
|
||||
merged.append(data)
|
||||
if not isinstance(merged[0], Sequence):
|
||||
merged = torch.cat([m.cpu() for m in merged])
|
||||
if max_len is not None:
|
||||
return merged[:max_len]
|
||||
else:
|
||||
return merged
|
||||
else:
|
||||
data_list=[]
|
||||
for d in zip(*merged):
|
||||
data = torch.cat([x.cpu() for x in d])
|
||||
if max_len is not None:
|
||||
data = data[:max_len]
|
||||
data_list.append(data)
|
||||
return data_list
|
||||
train_fn(args, model, device, data_fn = data_fn, eval_fn = eval_fn, loss_fn = loss_fn)
|
||||
|
||||
def calc_metrics(predicts, labels, eval_loss, eval_item, eval_results, args, name, prefix, steps, tag):
|
||||
tb_metrics = OrderedDict()
|
||||
|
@ -280,12 +252,14 @@ def main(args):
|
|||
if args.do_train:
|
||||
with open(os.path.join(args.output_dir, 'model_config.json'), 'w', encoding='utf-8') as fs:
|
||||
fs.write(model.config.to_json_string() + '\n')
|
||||
shutil.copy(args.vocab_path, args.output_dir)
|
||||
logger.info("Model config {}".format(model.config))
|
||||
device = initialize_distributed(args)
|
||||
if not isinstance(device, torch.device):
|
||||
return 0
|
||||
model.to(device)
|
||||
run_eval_fn = task.run_eval_fn()
|
||||
run_eval_fn = task.get_eval_fn()
|
||||
loss_fn = task.get_loss_fn(args)
|
||||
if run_eval_fn is None:
|
||||
run_eval_fn = run_eval
|
||||
|
||||
|
@ -293,7 +267,8 @@ def main(args):
|
|||
run_eval(args, model, device, eval_data, prefix=args.tag)
|
||||
|
||||
if args.do_train:
|
||||
train_model(args, model, device, train_data, eval_data, run_eval_fn)
|
||||
train_fn = task.get_train_fn(args, model)
|
||||
train_model(args, model, device, train_data, eval_data, run_eval_fn, loss_fn=loss_fn, train_fn = train_fn)
|
||||
|
||||
if args.do_predict:
|
||||
run_predict(args, model, device, test_data, prefix=args.tag)
|
||||
|
@ -317,7 +292,7 @@ class LoadTaskAction(argparse.Action):
|
|||
type(self)._registered = True
|
||||
|
||||
def build_argument_parser():
|
||||
parser = argparse.ArgumentParser(parents=[LASER.optims.get_args(), LASER.training.get_args()], formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser = argparse.ArgumentParser(parents=[get_optims_args(), get_training_args()], formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
|
||||
## Required parameters
|
||||
parser.add_argument("--task_dir",
|
||||
|
|
|
@ -19,15 +19,19 @@ import random
|
|||
import torch
|
||||
import re
|
||||
import ujson as json
|
||||
from torch.utils.data import DataLoader
|
||||
from .metrics import *
|
||||
from .task import EvalData, Task
|
||||
from .task_registry import register_task
|
||||
from ...utils import xtqdm as tqdm
|
||||
from ...training import DistributedTrainer, batch_to
|
||||
from ...data import DistributedBatchSampler, SequentialSampler, BatchSampler, AsyncDataLoader
|
||||
from ...data import ExampleInstance, ExampleSet, DynamicDataset,example_to_feature
|
||||
from ...data.example import _truncate_segments
|
||||
from ...data.example import *
|
||||
from ...utils import get_logger
|
||||
from ..models import MaskedLanguageModel
|
||||
from .._utils import merge_distributed, join_chunks
|
||||
|
||||
logger=get_logger()
|
||||
|
||||
|
@ -38,7 +42,7 @@ class NGramMaskGenerator:
|
|||
Mask ngram tokens
|
||||
https://github.com/zihangdai/xlnet/blob/0b642d14dd8aec7f1e1ecbf7d6942d5faa6be1f0/data_utils.py
|
||||
"""
|
||||
def __init__(self, tokenizer, mask_lm_prob=0.15, max_seq_len=512, max_preds_per_seq=None, max_gram = 3, keep_prob = 0.1, mask_prob=0.8, **kwargs):
|
||||
def __init__(self, tokenizer, mask_lm_prob=0.15, max_seq_len=512, max_preds_per_seq=None, max_gram = 1, keep_prob = 0.1, mask_prob=0.8, **kwargs):
|
||||
self.tokenizer = tokenizer
|
||||
self.mask_lm_prob = mask_lm_prob
|
||||
self.keep_prob = keep_prob
|
||||
|
@ -61,7 +65,7 @@ class NGramMaskGenerator:
|
|||
|
||||
unigrams = []
|
||||
for id in indices:
|
||||
if len(unigrams)>=1 and self.tokenizer.part_of_whole_word(tokens[id]):
|
||||
if self.max_gram>1 and len(unigrams)>=1 and self.tokenizer.part_of_whole_word(tokens[id]):
|
||||
unigrams[-1].append(id)
|
||||
else:
|
||||
unigrams.append([id])
|
||||
|
@ -165,23 +169,9 @@ dataset_size = dataset_size, shuffle=True, **kwargs)
|
|||
def get_metrics_fn(self):
|
||||
"""Calcuate metrics based on prediction results"""
|
||||
def metrics_fn(logits, labels):
|
||||
preds = np.argmax(logits, axis=-1)
|
||||
preds = logits
|
||||
acc = (preds==labels).sum()/len(labels)
|
||||
metrics = OrderedDict(accuracy= acc)
|
||||
|
||||
logits = torch.tensor(logits).cuda()
|
||||
labels = torch.tensor(labels).cuda().long()
|
||||
chk = 1024
|
||||
off = 0
|
||||
loss_fn = torch.nn.CrossEntropyLoss(reduction="none")
|
||||
losses = []
|
||||
while off<labels.size(0):
|
||||
loss = loss_fn(logits[off:off+chk, :], labels[off:off+chk])
|
||||
losses.append(loss)
|
||||
off += chk
|
||||
loss = torch.cat(losses).mean()
|
||||
ppl = loss.exp().cpu().item()
|
||||
metrics['PPL'] = ppl
|
||||
return metrics
|
||||
return metrics_fn
|
||||
|
||||
|
@ -221,6 +211,64 @@ dataset_size = dataset_size, shuffle=True, **kwargs)
|
|||
features[f] = torch.tensor(features[f] + [0]*(max_seq_len - len(token_ids)), dtype=torch.int)
|
||||
return features
|
||||
|
||||
def get_eval_fn(self):
|
||||
def eval_fn(args, model, device, eval_data, prefix=None, tag=None, steps=None):
|
||||
# Run prediction for full data
|
||||
prefix = f'{tag}_{prefix}' if tag is not None else prefix
|
||||
eval_results=OrderedDict()
|
||||
eval_metric=0
|
||||
no_tqdm = (True if os.getenv('NO_TQDM', '0')!='0' else False) or args.rank>0
|
||||
for eval_item in eval_data:
|
||||
name = eval_item.name
|
||||
eval_sampler = SequentialSampler(len(eval_item.data))
|
||||
batch_sampler = BatchSampler(eval_sampler, args.eval_batch_size)
|
||||
batch_sampler = DistributedBatchSampler(batch_sampler, rank=args.rank, world_size=args.world_size)
|
||||
eval_dataloader = DataLoader(eval_item.data, batch_sampler=batch_sampler, num_workers=args.workers)
|
||||
model.eval()
|
||||
eval_loss, eval_accuracy = 0, 0
|
||||
nb_eval_steps, nb_eval_examples = 0, 0
|
||||
predicts=[]
|
||||
labels=[]
|
||||
for batch in tqdm(AsyncDataLoader(eval_dataloader), ncols=80, desc='Evaluating: {}'.format(prefix), disable=no_tqdm):
|
||||
batch = batch_to(batch, device)
|
||||
with torch.no_grad():
|
||||
output = model(**batch)
|
||||
logits = output['logits'].detach().argmax(dim=-1)
|
||||
tmp_eval_loss = output['loss'].detach()
|
||||
if 'labels' in output:
|
||||
label_ids = output['labels'].detach().to(device)
|
||||
else:
|
||||
label_ids = batch['labels'].to(device)
|
||||
predicts.append(logits)
|
||||
labels.append(label_ids)
|
||||
eval_loss += tmp_eval_loss.mean()
|
||||
input_ids = batch['input_ids']
|
||||
nb_eval_examples += input_ids.size(0)
|
||||
nb_eval_steps += 1
|
||||
|
||||
eval_loss = eval_loss / nb_eval_steps
|
||||
predicts = merge_distributed(predicts)
|
||||
labels = merge_distributed(labels)
|
||||
|
||||
result=OrderedDict()
|
||||
metrics_fn = eval_item.metrics_fn
|
||||
metrics = metrics_fn(predicts.numpy(), labels.numpy())
|
||||
result.update(metrics)
|
||||
result['perplexity'] = torch.exp(eval_loss).item()
|
||||
critial_metrics = set(metrics.keys()) if eval_item.critial_metrics is None or len(eval_item.critial_metrics)==0 else eval_item.critial_metrics
|
||||
eval_metric = np.mean([v for k,v in metrics.items() if k in critial_metrics])
|
||||
result['eval_loss'] = eval_loss.item()
|
||||
result['eval_metric'] = eval_metric
|
||||
result['eval_samples'] = len(labels)
|
||||
if args.rank<=0:
|
||||
logger.info("***** Eval results-{}-{} *****".format(name, prefix))
|
||||
for key in sorted(result.keys()):
|
||||
logger.info(" %s = %s", key, str(result[key]))
|
||||
eval_results[name]=(eval_metric, predicts, labels)
|
||||
|
||||
return eval_results
|
||||
return eval_fn
|
||||
|
||||
def get_model_class_fn(self):
|
||||
def partial_class(*wargs, **kwargs):
|
||||
return MaskedLanguageModel.load_model(*wargs, **kwargs)
|
||||
|
|
|
@ -64,10 +64,16 @@ class Task():
|
|||
label_dict = {l:i for i,l in enumerate(self.get_labels())}
|
||||
return label_dict[labelstr] if labelstr in label_dict else -1
|
||||
|
||||
def run_eval_fn(self):
|
||||
def get_train_fn(self, *args, **kwargs):
|
||||
return None
|
||||
|
||||
def run_pred_fn(self):
|
||||
def get_eval_fn(self, *args, **kwargs):
|
||||
return None
|
||||
|
||||
def get_pred_fn(self, *args, **kwargs):
|
||||
return None
|
||||
|
||||
def get_loss_fn(self, *args, **kwargs):
|
||||
return None
|
||||
|
||||
def get_metrics_fn(self):
|
||||
|
|
|
@ -1,3 +1,5 @@
from .example import ExampleInstance,ExampleSet,example_to_feature
from .dataloader import SequentialDataLoader
from .dynamic_dataset import *
from .data_sampler import *
from .async_data import *
@ -0,0 +1,38 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# Author: Pengcheng He (penhe@microsoft.com)
# Date: 05/15/2019
#

from queue import Queue,Empty
from threading import Thread

class AsyncDataLoader(object):
  def __init__(self, dataloader, buffer_size=100):
    self.buffer_size = buffer_size
    self.dataloader = dataloader

  def __iter__(self):
    queue = Queue(self.buffer_size)
    dl = iter(self.dataloader)
    def _worker():
      while True:
        try:
          queue.put(next(dl))
        except StopIteration:
          break
      queue.put(None)
    t = Thread(target=_worker)
    t.start()
    while True:
      d = queue.get()
      if d is None:
        break
      yield d
    del t
    del queue

  def __len__(self):
    return len(self.dataloader)
|
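A brief usage sketch (toy dataset): AsyncDataLoader only prefetches batches from the wrapped loader on a background thread, so contents and order are unchanged. The import path is assumed from the data __init__ above:

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from DeBERTa.data import AsyncDataLoader  # assumed export path

    loader = DataLoader(TensorDataset(torch.arange(100).view(50, 2)), batch_size=8)

    # Batches are produced by a worker thread and buffered in a bounded queue.
    for (batch,) in AsyncDataLoader(loader, buffer_size=4):
      pass  # consume batches exactly as with the plain DataLoader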
@ -0,0 +1,76 @@
|
|||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
#
|
||||
# Author: Pengcheng He (penhe@microsoft.com)
|
||||
# Date: 05/15/2019
|
||||
#
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import math
|
||||
import sys
|
||||
from torch.utils.data import Sampler
|
||||
|
||||
__all__=['BatchSampler', 'DistributedBatchSampler', 'RandomSampler', 'SequentialSampler']
|
||||
class BatchSampler(Sampler):
|
||||
def __init__(self, sampler, batch_size):
|
||||
self.sampler = sampler
|
||||
self.batch_size = batch_size
|
||||
|
||||
def __iter__(self):
|
||||
batch = []
|
||||
for idx in self.sampler:
|
||||
batch.append(idx)
|
||||
if len(batch)==self.batch_size:
|
||||
yield batch
|
||||
batch = []
|
||||
if len(batch)>0:
|
||||
yield batch
|
||||
|
||||
def __len__(self):
|
||||
return (len(self.sampler) + self.batch_size - 1)//self.batch_size
|
||||
|
||||
class DistributedBatchSampler(Sampler):
|
||||
def __init__(self, sampler, rank=0, world_size = 1, drop_last = False):
|
||||
self.sampler = sampler
|
||||
self.rank = rank
|
||||
self.world_size = world_size
|
||||
self.drop_last = drop_last
|
||||
|
||||
def __iter__(self):
|
||||
for b in self.sampler:
|
||||
if len(b)%self.world_size != 0:
|
||||
if self.drop_last:
|
||||
break
|
||||
else:
|
||||
b.extend([b[0] for _ in range(self.world_size-len(b)%self.world_size)])
|
||||
chunk_size = len(b)//self.world_size
|
||||
yield b[self.rank*chunk_size:(self.rank+1)*chunk_size]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.sampler)
|
||||
|
||||
class RandomSampler(Sampler):
|
||||
def __init__(self, total_samples:int, data_seed:int = 0):
|
||||
self.indices = np.array(np.arange(total_samples))
|
||||
self.rng = np.random.RandomState(data_seed)
|
||||
|
||||
def __iter__(self):
|
||||
self.rng.shuffle(self.indices)
|
||||
for i in self.indices:
|
||||
yield i
|
||||
|
||||
def __len__(self):
|
||||
return len(self.indices)
|
||||
|
||||
class SequentialSampler(Sampler):
|
||||
def __init__(self, total_samples:int):
|
||||
self.indices = np.array(np.arange(total_samples))
|
||||
|
||||
def __iter__(self):
|
||||
for i in self.indices:
|
||||
yield i
|
||||
|
||||
def __len__(self):
|
||||
return len(self.indices)
|
|
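These samplers compose the same way the evaluation code added in this commit uses them: a SequentialSampler (or RandomSampler) feeds a BatchSampler, whose batches are then sliced per rank by DistributedBatchSampler. A single-process sketch (rank 0, world size 1), with the import path assumed from the data package __init__ above:

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from DeBERTa.data import SequentialSampler, BatchSampler, DistributedBatchSampler  # assumed export path

    dataset = TensorDataset(torch.arange(20))
    sampler = SequentialSampler(len(dataset))
    batches = BatchSampler(sampler, batch_size=4)
    batches = DistributedBatchSampler(batches, rank=0, world_size=1)
    loader  = DataLoader(dataset, batch_sampler=batches)
    print([b[0].tolist() for b in loader])  # five batches of four consecutive indices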
@ -0,0 +1,16 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# Author: Pengcheng He (penhe@microsoft.com)
# Date: 05/15/2019
#

""" optimizers
"""

from .xadam import XAdam
from .fp16_optimizer import *
from .lr_schedulers import SCHEDULES
from .args import get_args
@ -0,0 +1,100 @@
|
|||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
#
|
||||
# Author: Pengcheng He (penhe@microsoft.com)
|
||||
# Date: 05/15/2019
|
||||
#
|
||||
|
||||
""" Arguments for optimizer
|
||||
"""
|
||||
import argparse
|
||||
from ..utils import boolean_string
|
||||
|
||||
__all__ = ['get_args']
|
||||
def get_args():
|
||||
parser=argparse.ArgumentParser(add_help=False, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
group = parser.add_argument_group(title='Optimizer', description='Parameters for the distributed optimizer')
|
||||
group.add_argument('--fp16',
|
||||
default=False,
|
||||
type=boolean_string,
|
||||
help="Whether to use 16-bit float precision instead of 32-bit")
|
||||
|
||||
group.add_argument('--loss_scale',
|
||||
type=float, default=16384,
|
||||
help='Loss scaling, positive power of 2 values can improve fp16 convergence.')
|
||||
|
||||
group.add_argument('--scale_steps',
|
||||
type=int, default=250,
|
||||
help='The steps to wait to increase the loss scale.')
|
||||
|
||||
group.add_argument('--lookahead_k',
|
||||
default=-1,
|
||||
type=int,
|
||||
help="lookahead k parameter")
|
||||
|
||||
group.add_argument('--lookahead_alpha',
|
||||
default=0.5,
|
||||
type=float,
|
||||
help="lookahead alpha parameter")
|
||||
|
||||
group.add_argument('--with_radam',
|
||||
default=False,
|
||||
type=boolean_string,
|
||||
help="whether to use RAdam")
|
||||
|
||||
group.add_argument('--opt_type',
|
||||
type=str.lower,
|
||||
default='adam',
|
||||
choices=['adam', 'admax'],
|
||||
help="The optimizer to be used.")
|
||||
|
||||
group.add_argument("--warmup_proportion",
|
||||
default=0.1,
|
||||
type=float,
|
||||
help="Proportion of training to perform linear learning rate warmup for. "
|
||||
"E.g., 0.1 = 10%% of training.")
|
||||
|
||||
group.add_argument("--lr_schedule_ends",
|
||||
default=0,
|
||||
type=float,
|
||||
help="The ended learning rate scale for learning rate scheduling")
|
||||
|
||||
group.add_argument("--lr_schedule",
|
||||
default='warmup_linear',
|
||||
type=str,
|
||||
help="The learning rate scheduler used for traning. " +
|
||||
"E.g. warmup_linear, warmup_linear_shift, warmup_cosine, warmup_constant. Default, warmup_linear")
|
||||
|
||||
group.add_argument("--max_grad_norm",
|
||||
default=1,
|
||||
type=float,
|
||||
help="The clip threshold of global gradient norm")
|
||||
|
||||
group.add_argument("--learning_rate",
|
||||
default=5e-5,
|
||||
type=float,
|
||||
help="The initial learning rate for Adam.")
|
||||
|
||||
group.add_argument("--epsilon",
|
||||
default=1e-6,
|
||||
type=float,
|
||||
help="epsilon setting for Adam.")
|
||||
|
||||
group.add_argument("--adam_beta1",
|
||||
default=0.9,
|
||||
type=float,
|
||||
help="The beta1 parameter for Adam.")
|
||||
|
||||
group.add_argument("--adam_beta2",
|
||||
default=0.999,
|
||||
type=float,
|
||||
help="The beta2 parameter for Adam.")
|
||||
|
||||
group.add_argument('--weight_decay',
|
||||
type=float,
|
||||
default=0.01,
|
||||
help="The weight decay rate")
|
||||
|
||||
return parser
@ -0,0 +1,293 @@
|
|||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
#
|
||||
# Author: Pengcheng He (penhe@microsoft.com)
|
||||
# Date: 05/15/2019
|
||||
#
|
||||
|
||||
""" FP16 optimizer wrapper
|
||||
"""
|
||||
|
||||
from collections import defaultdict
|
||||
import numpy as np
|
||||
import math
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
||||
import ctypes
|
||||
|
||||
from ..utils import get_logger,boolean_string
|
||||
logger=get_logger()
|
||||
|
||||
__all__ = ['Fp16Optimizer', 'ExpLossScaler', 'get_world_size']
|
||||
|
||||
def get_world_size():
|
||||
try:
|
||||
wd = dist.get_world_size()
|
||||
return wd
|
||||
except:
|
||||
return 1
|
||||
|
||||
def fused_norm(input):
|
||||
return torch.norm(input, p=2, dtype=torch.float32)
|
||||
|
||||
class OptParameter(torch.Tensor):
|
||||
def __new__(cls, data, out_data=None, grad=None, name=None):
|
||||
param = torch.Tensor._make_subclass(cls, data)
|
||||
param._xgrad = grad
|
||||
param.out_data = out_data
|
||||
param._name = name
|
||||
return param
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def grad(self):
|
||||
return self._xgrad
|
||||
|
||||
@grad.setter
|
||||
def grad(self, grad):
|
||||
self._xgrad = grad
|
||||
|
||||
class Fp16Optimizer(object):
|
||||
def __init__(self, param_groups, optimizer_fn, loss_scaler=None, grad_clip_norm = 1.0, lookahead_k = -1, lookahead_alpha = 0.5, rank=-1, distributed=False):
|
||||
# all parameters should on the same device
|
||||
groups = []
|
||||
original_groups = []
|
||||
self.rank = rank
|
||||
self.distributed = distributed
|
||||
if self.rank<0:
|
||||
self.distributed = False
|
||||
for group in param_groups:
|
||||
if 'offset' not in group:
|
||||
group['offset'] = None
|
||||
if ('rank' not in group) or (not self.distributed):
|
||||
group['rank'] = -1
|
||||
assert group['offset'] is None, f"{group['names']}: {group['offset']}"
|
||||
group_rank = group['rank']
|
||||
params = group['params'] # parameter
|
||||
if len(params) > 1:
|
||||
flattened_params = _flatten_dense_tensors([p.data for p in params])
|
||||
unflattend_params = _unflatten_dense_tensors(flattened_params, [p.data for p in params])
|
||||
for uf,p in zip(unflattend_params, params):
|
||||
p.data = uf
|
||||
else:
|
||||
flattened_params = params[0].data.view(-1)
|
||||
if group['offset'] is not None:
|
||||
start, length = group['offset']
|
||||
flattened_params = flattened_params.narrow(0, start, length)
|
||||
|
||||
if params[0].dtype==torch.half:
|
||||
if self.rank == group_rank or (not self.distributed):
|
||||
master_params = flattened_params.clone().to(torch.float).detach_().to(flattened_params.device)
|
||||
else:
|
||||
master_params = flattened_params.clone().to(torch.float).detach_().cpu()
|
||||
group['params'] = [OptParameter(master_params, flattened_params, name='master')]
|
||||
else:
|
||||
group['params'] = [OptParameter(flattened_params, None, name='master')]
|
||||
|
||||
o_group = defaultdict(list)
|
||||
o_group['names'] = group['names']
|
||||
o_group['params'] = params
|
||||
o_group['rank'] = group_rank
|
||||
o_group['offset'] = group['offset']
|
||||
|
||||
group['names'] = ['master']
|
||||
|
||||
original_groups.append(o_group)
|
||||
groups.append(group)
|
||||
self.param_groups = groups
|
||||
self.loss_scaler = loss_scaler
|
||||
self.optimizer = optimizer_fn(self.param_groups)
|
||||
self.original_param_groups = original_groups
|
||||
self.max_grad_norm = grad_clip_norm
|
||||
self.lookahead_k = lookahead_k
|
||||
self.lookahead_alpha = lookahead_alpha
|
||||
|
||||
def backward(self, loss):
|
||||
if self.loss_scaler:
|
||||
loss_scale, loss, step_loss = self.loss_scaler.scale(loss)
|
||||
else:
|
||||
loss_scale = 1
|
||||
step_loss = loss.item()
|
||||
|
||||
loss.backward()
|
||||
return loss_scale, step_loss
|
||||
|
||||
def step(self, lr_scale, loss_scale = 1):
|
||||
grad_scale = self._grad_scale(loss_scale)
|
||||
if grad_scale is None or math.isinf(grad_scale):
|
||||
self.loss_scaler.update(False)
|
||||
return False
|
||||
|
||||
if self.lookahead_k > 0:
|
||||
for p in self.param_groups:
|
||||
if 'la_count' not in p:
|
||||
# init
|
||||
#make old copy
|
||||
p['la_count'] = 0
|
||||
p['slow_params'] = [x.data.detach().clone().requires_grad_(False) for x in p['params']]
|
||||
self.optimizer.step(grad_scale, lr_scale)
|
||||
if self.lookahead_k > 0:
|
||||
for p in self.param_groups:
|
||||
p['la_count'] += 1
|
||||
if p['la_count'] == self.lookahead_k:
|
||||
p['la_count'] = 0
|
||||
for s,f in zip(p['slow_params'], p['params']):
|
||||
s.mul_(1-self.lookahead_alpha)
|
||||
s.add_(f.data.detach()*self.lookahead_alpha)
|
||||
f.data.copy_(s, non_blocking=True)
|
||||
if hasattr(f, 'out_data') and f.out_data is not None:
|
||||
f.out_data.copy_(f.data, non_blocking=True)
|
||||
|
||||
if self.loss_scaler:
|
||||
self.loss_scaler.update(True)
|
||||
return True
|
||||
|
||||
def zero_grad(self):
|
||||
for group, o_group in zip(self.param_groups, self.original_param_groups):
|
||||
for p in group['params']:
|
||||
p.grad = None
|
||||
for p in o_group['params']:
|
||||
p.grad = None
|
||||
|
||||
def _grad_scale(self, loss_scale = 1):
|
||||
named_params = {}
|
||||
named_grads = {}
|
||||
for g in self.original_param_groups:
|
||||
for n,p in zip(g['names'], g['params']):
|
||||
named_params[n] = p
|
||||
named_grads[n] = p.grad if p.grad is not None else torch.zeros_like(p.data)
|
||||
|
||||
wd = get_world_size()
|
||||
def _reduce(group):
|
||||
grads = [named_grads[n] for n in group]
|
||||
if len(grads)>1:
|
||||
flattened_grads = _flatten_dense_tensors(grads)
|
||||
else:
|
||||
flattened_grads = grads[0].view(-1)
|
||||
|
||||
if wd > 1:
|
||||
flattened_grads /= wd
|
||||
handle = dist.all_reduce(flattened_grads, async_op=True)
|
||||
else:
|
||||
handle = None
|
||||
return flattened_grads, handle
|
||||
|
||||
def _process_grad(group, flattened_grads, max_grad, norm):
|
||||
grads = [named_grads[n] for n in group]
|
||||
norm = norm.to(flattened_grads.device)
|
||||
norm = norm + fused_norm(flattened_grads)**2
|
||||
|
||||
if len(grads) > 1:
|
||||
unflattend_grads = _unflatten_dense_tensors(flattened_grads, grads)
|
||||
else:
|
||||
unflattend_grads = [flattened_grads]
|
||||
|
||||
for n,ug in zip(group, unflattend_grads):
|
||||
named_grads[n] = ug #.to(named_params[n].data)
|
||||
|
||||
return max_grad, norm
|
||||
|
||||
group_size = 0
|
||||
group = []
|
||||
max_size = 32*1024*1024
|
||||
norm = torch.zeros(1, dtype=torch.float)
|
||||
max_grad = 0
|
||||
|
||||
all_grads = []
|
||||
for name in sorted(named_params.keys()):
|
||||
group.append(name)
|
||||
group_size += named_params[name].data.numel()
|
||||
if group_size>=max_size:
|
||||
flatten, handle = _reduce(group)
|
||||
all_grads.append([handle, flatten, group])
|
||||
group = []
|
||||
group_size = 0
|
||||
if group_size>0:
|
||||
flatten, handle = _reduce(group)
|
||||
all_grads.append([handle, flatten, group])
|
||||
group = []
|
||||
group_size = 0
|
||||
|
||||
for h,fg,group in all_grads:
|
||||
if h is not None:
|
||||
h.wait()
|
||||
max_grad, norm = _process_grad(group, fg, max_grad, norm)
|
||||
|
||||
norm = norm**0.5
|
||||
if torch.isnan(norm) or torch.isinf(norm) :#in ['-inf', 'inf', 'nan']:
|
||||
return None
|
||||
|
||||
scaled_norm = norm.detach().item()/loss_scale
|
||||
grad_scale = loss_scale
|
||||
|
||||
if self.max_grad_norm>0:
|
||||
scale = norm/(loss_scale*self.max_grad_norm)
|
||||
if scale>1:
|
||||
grad_scale *= scale
|
||||
|
||||
for group, o_g in zip(self.param_groups, self.original_param_groups):
|
||||
grads = [named_grads[n] for n in o_g['names']]
|
||||
|
||||
if len(grads) > 1:
|
||||
flattened_grads = _flatten_dense_tensors(grads)
|
||||
else:
|
||||
flattened_grads = grads[0].view(-1)
|
||||
if group['offset'] is not None:
|
||||
start, length = group['offset']
|
||||
flattened_grads = flattened_grads.narrow(0, start, length)
|
||||
if group['rank'] == self.rank or (not self.distributed):
|
||||
group['params'][0].grad = flattened_grads
|
||||
|
||||
return grad_scale
|
||||
|
||||
class ExpLossScaler:
|
||||
def __init__(self, init_scale=2**16, scale_interval=1000):
|
||||
self.cur_scale = init_scale
|
||||
self.scale_interval = scale_interval
|
||||
self.invalid_cnt = 0
|
||||
self.last_scale = 0
|
||||
self.steps = 0
|
||||
self.down_scale_smooth = 0
|
||||
|
||||
def scale(self, loss):
|
||||
assert self.cur_scale > 0, self.init_scale
|
||||
step_loss = loss.float().detach().item()
|
||||
if step_loss != 0 and math.isfinite(step_loss):
|
||||
loss_scale = self.cur_scale
|
||||
else:
|
||||
loss_scale = 1
|
||||
loss = loss.float()*loss_scale
|
||||
return (loss_scale, loss, step_loss)
|
||||
|
||||
def update(self, is_valid = True):
|
||||
if not is_valid:
|
||||
self.invalid_cnt += 1
|
||||
if self.invalid_cnt>self.down_scale_smooth:
|
||||
self.cur_scale /= 2
|
||||
self.cur_scale = max(self.cur_scale, 1)
|
||||
self.last_scale = self.steps
|
||||
else:
|
||||
self.invalid_cnt = 0
|
||||
if self.steps - self.last_scale>self.scale_interval:
|
||||
self.cur_scale *= 2
|
||||
self.last_scale = self.steps
|
||||
self.steps += 1
|
||||
|
||||
def state_dict(self):
|
||||
state = defaultdict(float)
|
||||
state['steps'] = self.steps
|
||||
state['invalid_cnt'] = self.invalid_cnt
|
||||
state['cur_scale'] = self.cur_scale
|
||||
state['last_scale'] = self.last_scale
|
||||
return state
|
||||
|
||||
def load_state_dict(self, state):
|
||||
self.steps = state['steps']
|
||||
self.invalid_cnt = state['invalid_cnt']
|
||||
self.cur_scale = state['cur_scale']
|
||||
self.last_scale = state['last_scale']
|
|
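A small sketch of the dynamic loss-scaling behaviour of ExpLossScaler on its own (Fp16Optimizer.backward/step drive exactly this scale/update cycle). The import path is assumed; ExpLossScaler is re-exported through the optims __init__ above:

    import torch
    from DeBERTa.optims import ExpLossScaler  # assumed export path

    scaler = ExpLossScaler(init_scale=2**15, scale_interval=4)
    loss = torch.tensor(2.5)

    loss_scale, scaled_loss, step_loss = scaler.scale(loss)
    print(loss_scale, scaled_loss.item(), step_loss)  # 32768 81920.0 2.5

    scaler.update(is_valid=False)  # overflow detected: the scale is halved
    scaler.update(is_valid=True)   # valid steps grow it back only after scale_interval steps
    print(scaler.cur_scale)        # 16384.0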
@ -0,0 +1,63 @@
|
|||
""" Learning rate schedulers
|
||||
"""
|
||||
|
||||
import math
|
||||
import torch
|
||||
from torch.optim import Optimizer
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
|
||||
def warmup_cosine(step, total, warmup=0.002, ends = 0):
|
||||
x = step/total
|
||||
x = x-int(x)
|
||||
if x < warmup:
|
||||
return x/warmup
|
||||
return 0.5 * (1.0 + math.cos(math.pi * x))
|
||||
|
||||
def warmup_constant(step, total, warmup=0.002, ends = 0):
|
||||
x = step/total
|
||||
x = x-int(x)
|
||||
if x < warmup:
|
||||
return x/warmup
|
||||
return 1.0
|
||||
|
||||
def warmup_linear(step, total, warmup=0.002, ends = 0):
|
||||
x = step/total
|
||||
x = x-int(x)
|
||||
if x < warmup:
|
||||
return x/warmup
|
||||
return (1-ends)*(1.0 - x) + ends
|
||||
|
||||
def warmup_linear_cosine(step, total, warmup=0.002, ends = 0):
|
||||
x = step/total
|
||||
x = x-int(x)
|
||||
if x < warmup:
|
||||
return x/warmup
|
||||
return (1-ends)*max(0.5*(1+math.cos(math.pi*(x-warmup)/(1-warmup))), 0) + ends
|
||||
|
||||
def warmup_cyclic_linear_cosine(step, total, warmup=0.002, ends = 0):
|
||||
x = step/total
|
||||
if x < warmup:
|
||||
return x/warmup
|
||||
total = total - int(total*warmup)
|
||||
step = step - int(total*warmup)
|
||||
n_epoch = 4
|
||||
period = total//n_epoch
|
||||
k = step//period
|
||||
s = 1-k/n_epoch + 1/(2*n_epoch)*(math.pow(-1, k)*math.cos(math.pi*step/period)-1)
|
||||
return (1-ends)*max(s, 0) + ends
|
||||
|
||||
def warmup_linear_shift(step, total, warmup=0.002, ends = 0):
|
||||
x = step/total
|
||||
x = x-int(x)
|
||||
if x < warmup:
|
||||
return x/warmup
|
||||
return (1-ends)*(1.0 - (x-warmup)/(1-warmup)) + ends
|
||||
|
||||
SCHEDULES = {
|
||||
'warmup_cosine':warmup_cosine,
|
||||
'warmup_constant':warmup_constant,
|
||||
'warmup_linear':warmup_linear,
|
||||
'warmup_linear_cosine':warmup_linear_cosine,
|
||||
'warmup_cyclic_linear_cosine':warmup_cyclic_linear_cosine,
|
||||
'warmup_linear_shift':warmup_linear_shift,
|
||||
}
|
|
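A quick way to see the shape of one of these schedules, e.g. warmup_linear with 10% warmup over 1000 steps (each value is the multiplier applied to the base learning rate); SCHEDULES is assumed to be importable as re-exported by the optims __init__ above:

    from DeBERTa.optims import SCHEDULES  # assumed export path

    for step in [0, 50, 100, 500, 999]:
      print(step, round(SCHEDULES['warmup_linear'](step, 1000, warmup=0.1), 3))
    # -> 0.0, 0.5, 0.9, 0.5, 0.001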
@ -0,0 +1,217 @@
|
|||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
#
|
||||
# Author: Pengcheng He (penhe@microsoft.com)
|
||||
# Date: 05/15/2019
|
||||
#
|
||||
|
||||
""" Optimizer
|
||||
"""
|
||||
|
||||
import math
|
||||
import torch
|
||||
from torch.optim import Optimizer
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
from torch import distributed as dist
|
||||
import pdb
|
||||
from .lr_schedulers import SCHEDULES
|
||||
from ..utils import get_logger
|
||||
|
||||
def adamw(data,
|
||||
out_data,
|
||||
next_m,
|
||||
next_v,
|
||||
grad,
|
||||
lr,
|
||||
beta1,
|
||||
beta2,
|
||||
eps,
|
||||
grad_scale, #combined_scale, g = g/scale
|
||||
step,
|
||||
eps_mode = 1, #self.eps_mode, esp inside sqrt:0, outside: 1, only update with momentum: 2
|
||||
bias_correction = 0,
|
||||
weight_decay = 0):
|
||||
if bias_correction > 0:
|
||||
lr *= bias_correction
|
||||
beta1_ = 1 - beta1
|
||||
beta2_ = 1 - beta2
|
||||
grad = grad.float()
|
||||
if grad_scale != 1:
|
||||
grad *= 1/grad_scale
|
||||
grad = grad.to(next_m)
|
||||
next_m.mul_(beta1).add_(beta1_, grad)
|
||||
# admax
|
||||
admax = eps_mode>>4
|
||||
eps_mode = eps_mode&0xF
|
||||
if admax > 0:
|
||||
torch.max(next_v.mul_(beta2), grad.abs().to(next_v), out=next_v)
|
||||
update = next_m/(next_v+eps)
|
||||
else:
|
||||
next_v.mul_(beta2).addcmul_(beta2_, grad, grad)
|
||||
if eps_mode == 0:
|
||||
update = (next_m)*(next_v+eps).rsqrt()
|
||||
elif eps_mode == 1:
|
||||
update = (next_m)/(next_v.sqrt()+eps)
|
||||
else: #=2
|
||||
update = next_m.clone()
|
||||
|
||||
if weight_decay>0:
|
||||
update.add_(weight_decay, data)
|
||||
|
||||
data.add_(-lr, update)
|
||||
if (out_data is not None) and len(out_data)>0:
|
||||
out_data.copy_(data)
|
||||
|
||||
class XAdam(Optimizer):
|
||||
"""Implements optimized version of Adam algorithm with weight decay fix.
|
||||
Params:
|
||||
lr: learning rate
|
||||
warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
|
||||
t_total: total number of training steps for the learning
|
||||
rate schedule, -1 means constant learning rate. Default: -1
|
||||
schedule: schedule to use for the warmup (see above). Default: 'warmup_linear'
|
||||
b1: Adams b1. Default: 0.9
|
||||
b2: Adams b2. Default: 0.999
|
||||
e: Adams epsilon. Default: 1e-6
|
||||
weight_decay_rate: Weight decay. Default: 0.01
|
||||
max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
|
||||
with_radam: Whether to enable radam. Default: False
|
||||
radam_th: RAdam threshold for tractable variance. Default: 4
|
||||
opt_type: The type of optimizer, [adam, admax], default: adam
|
||||
"""
|
||||
def __init__(self, params, lr, warmup=-1, t_total=-1, schedule='warmup_linear',
|
||||
b1=0.9, b2=0.999, e=1e-8, weight_decay_rate=0.01,
|
||||
lr_ends = 0,
|
||||
max_grad_norm = 1.0,
|
||||
with_radam = False,
|
||||
radam_th = 4,
|
||||
opt_type=None,
|
||||
rank = -1):
|
||||
if not lr >= 0.0:
|
||||
raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
|
||||
if schedule not in SCHEDULES:
|
||||
raise ValueError("Invalid schedule parameter: {}".format(schedule))
|
||||
if not 0.0 <= warmup < 1.0 and not warmup == -1:
|
||||
raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
|
||||
if not 0.0 <= b1 < 1.0:
|
||||
raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
|
||||
if not 0.0 <= b2 < 1.0:
|
||||
raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
|
||||
if not e >= 0.0:
|
||||
raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
|
||||
self.defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
|
||||
b1=b1, b2=b2, e=e, weight_decay_rate=weight_decay_rate,
|
||||
lr_ends = lr_ends,
|
||||
max_grad_norm=max_grad_norm,
|
||||
with_radam = with_radam, radam_th = radam_th)
|
||||
self.opt_type = opt_type.lower() if opt_type is not None else ""
|
||||
self.rank = rank
|
||||
super().__init__(params, self.defaults)
|
||||
|
||||
def step(self, grad_scale = 1, lr_scale = 1):
|
||||
"""Performs a single optimization step.
|
||||
|
||||
Arguments:
|
||||
grad_scale: divide grad by grad_scale
|
||||
lr_scale: scale learning rate by bs_scale
|
||||
"""
|
||||
if 'global_step' not in self.state:
|
||||
self.state['global_step'] = 0
|
||||
for group in self.param_groups:
|
||||
lr_sch = self.get_group_lr_sch(group, self.state['global_step'])
|
||||
if group['rank'] == self.rank or group['rank']<0 or self.rank<0:
|
||||
for param in group['params']:
|
||||
self.update_param(group, param, grad_scale, lr_scale)
|
||||
|
||||
self.state['global_step'] += 1
|
||||
self.last_grad_scale = grad_scale
|
||||
handels = []
|
||||
for group in self.param_groups:
|
||||
if group['rank']>=0 and self.rank>=0:
|
||||
# sync
|
||||
for param in group['params']:
|
||||
out_p = param.out_data if hasattr(param, 'out_data') and (param.out_data is not None) else None
|
||||
if out_p is not None:
|
||||
h = torch.distributed.broadcast(out_p, group['rank'], async_op=True)
|
||||
else:
|
||||
h = torch.distributed.broadcast(param.data, group['rank'], async_op=True)
|
||||
handels.append(h)
|
||||
|
||||
for h in handels:
|
||||
if h is not None:
|
||||
h.wait()
|
||||
|
||||
return lr_sch
|
||||
|
||||
def get_group_lr_sch(self, group, steps):
|
||||
if group['t_total'] > 0:
|
||||
schedule_fct = SCHEDULES[group['schedule']]
|
||||
lr_scheduled = schedule_fct(steps, group['t_total'], group['warmup'], group['lr_ends'])
|
||||
else:
|
||||
lr_scheduled = 1
|
||||
return lr_scheduled
|
||||
|
||||
def update_param(self, group, param, grad_scale, lr_scale):
|
||||
grad = param.grad
|
||||
if grad.is_sparse:
|
||||
raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
|
||||
state = self.get_state(param)
|
||||
lr_sch = self.get_group_lr_sch(group, state['step'])
|
||||
lr = group['lr'] * lr_scale *lr_sch
|
||||
next_m, next_v = state['next_m'], state['next_v']
|
||||
beta1, beta2 = group['b1'], group['b2']
|
||||
state['step'] += 1
|
||||
|
||||
# Support for RAdam
|
||||
t = (state['step']-1) + 1
|
||||
eps_mode = 1
|
||||
if group['with_radam']:
|
||||
rou_ = 2/(1-beta2) - 1
|
||||
rou_t = rou_ - 2*t/(beta2**-t - 1)
|
||||
bias_c = 1/(1-beta1**t)
|
||||
if rou_t > group['radam_th']:
|
||||
bias_c *= math.sqrt(1 - beta2**t)
|
||||
bias_c *= math.sqrt(((rou_t - 4)*(rou_t - 2)*rou_)/((rou_ - 4)*(rou_ - 2)*rou_t))
|
||||
else:
|
||||
eps_mode = 2
|
||||
bias_c = 0
|
||||
lr *= bias_c
|
||||
|
||||
if self.opt_type == 'admax':
|
||||
eps_mode |= 0x10
|
||||
|
||||
with torch.cuda.device(param.device.index):
|
||||
out_p = param.out_data if hasattr(param, 'out_data') and (param.out_data is not None) else None
|
||||
if out_p is None or out_p.dtype != grad.dtype:
|
||||
out_p = torch.tensor([], dtype=torch.float).to(param.data)
|
||||
|
||||
weight_decay = group['weight_decay_rate']
|
||||
|
||||
adamw(param.data,
|
||||
out_p,
|
||||
next_m,
|
||||
next_v,
|
||||
grad,
|
||||
lr,
|
||||
beta1,
|
||||
beta2,
|
||||
group['e'],
|
||||
grad_scale, #combined_scale, g = g/scale
|
||||
state['step'],
|
||||
eps_mode, #self.eps_mode, esp inside sqrt:0, outside: 1, only update with momentum: 2
|
||||
0, #bias_correction,
|
||||
weight_decay)
|
||||
|
||||
out_p = param.out_data if hasattr(param, 'out_data') and (param.out_data is not None) else None
|
||||
if out_p is not None and out_p.dtype != grad.dtype:
|
||||
out_p.copy_(param.data)
|
||||
|
||||
def get_state(self, param):
|
||||
state = self.state[param]
|
||||
# State initialization
|
||||
if len(state) == 0:
|
||||
state['step'] = 0
|
||||
state['next_m'] = torch.zeros_like(param.data)
|
||||
state['next_v'] = torch.zeros_like(param.data)
|
||||
return state
|
|
@ -0,0 +1,4 @@
from .trainer import DistributedTrainer, set_random_seed
from .args import get_args
from .dist_launcher import initialize_distributed,kill_children
from ._utils import batch_to,batch_apply
@ -0,0 +1,16 @@
import torch
from collections import Sequence, Mapping

def batch_apply(batch, fn):
  if isinstance(batch, torch.Tensor):
    return fn(batch)
  elif isinstance(batch, Sequence):
    return [batch_apply(x, fn) for x in batch]
  elif isinstance(batch, Mapping):
    return {x:batch_apply(batch[x], fn) for x in batch}
  else:
    raise NotImplementedError(f'Type of {type(batch)} are not supported in batch_apply')

def batch_to(batch, device):
  return batch_apply(batch, lambda x: x.to(device))
|
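A tiny sketch of batch_apply/batch_to on a nested batch (a dict of tensors plus a list), which is the shape of batches the trainer moves between devices; the import path is assumed from the training __init__ above:

    import torch
    from DeBERTa.training import batch_to, batch_apply  # assumed export path

    batch = {'input_ids': torch.zeros(2, 8, dtype=torch.long),
             'input_mask': torch.ones(2, 8, dtype=torch.long),
             'extras': [torch.zeros(2), torch.ones(2)]}

    moved   = batch_to(batch, torch.device('cpu'))   # recurses through dict and list
    doubled = batch_apply(batch, lambda x: x * 2)    # same traversal with an arbitrary fn
    print(type(moved), type(moved['extras']))        # dict and list structure is preserved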
@ -0,0 +1,72 @@
|
|||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
#
|
||||
# Author: Pengcheng He (penhe@microsoft.com)
|
||||
# Date: 05/15/2019
|
||||
#
|
||||
|
||||
import argparse
|
||||
from ..utils import boolean_string
|
||||
|
||||
__all__ = ['get_args']
|
||||
|
||||
def get_args():
|
||||
parser=argparse.ArgumentParser(add_help=False, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
group = parser.add_argument_group(title='Trainer', description='Parameters for the distributed trainer')
|
||||
group.add_argument('--accumulative_update',
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
||||
|
||||
group.add_argument("--dump_interval",
|
||||
default=1000,
|
||||
type=int,
|
||||
help="Interval steps for generating checkpoint.")
|
||||
|
||||
group.add_argument("--local_rank",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="local_rank for distributed training on gpus")
|
||||
|
||||
group.add_argument('--workers',
|
||||
type=int,
|
||||
default=2,
|
||||
help="The workers to load data.")
|
||||
|
||||
group.add_argument("--num_train_epochs",
|
||||
default=3.0,
|
||||
type=float,
|
||||
help="Total number of training epochs to perform.")
|
||||
|
||||
group.add_argument('--seed',
|
||||
type=int,
|
||||
default=1234,
|
||||
help="random seed for initialization")
|
||||
|
||||
group.add_argument("--train_batch_size",
|
||||
default=64,
|
||||
type=int,
|
||||
help="Total batch size for training.")
|
||||
|
||||
group.add_argument("--world_size",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="[Internal] The world size of distributed training. Internal usage only!! To the world size of the program, you need to use environment. 'WORLD_SIZE'")
|
||||
|
||||
group.add_argument("--rank",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="[Internal] The rank id of current process. Internal usage only!! To the rank of the program, you need to use environment. 'RANK'")
|
||||
|
||||
group.add_argument("--master_ip",
|
||||
type=str,
|
||||
default=None,
|
||||
help="[Internal] The ip address of master node. Internal usage only!! To the master IP of the program, you need to use environment. 'MASTER_ADDR'")
|
||||
|
||||
group.add_argument("--master_port",
|
||||
type=str,
|
||||
default=None,
|
||||
help="[Internal] The port of master node. Internal usage only!! To the master IP of the program, you need to use environment. 'MASTER_PORT'")
|
||||
|
||||
return parser
|
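These trainer arguments are designed to be combined with the optimizer arguments through argparse parents, which is exactly what build_argument_parser does earlier in this commit; a minimal sketch, with the DeBERTa.optims / DeBERTa.training module paths assumed from the __init__ files above:

    import argparse
    from DeBERTa.optims import get_args as get_optims_args
    from DeBERTa.training import get_args as get_training_args

    parser = argparse.ArgumentParser(parents=[get_optims_args(), get_training_args()],
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    args = parser.parse_args(['--train_batch_size', '32', '--learning_rate', '1e-4'])
    print(args.train_batch_size, args.learning_rate, args.fp16)  # 32 0.0001 False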
|
@ -0,0 +1,163 @@
|
|||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
#
|
||||
# Author: Pengcheng He (penhe@microsoft.com)
|
||||
# Date: 05/15/2019
|
||||
#
|
||||
|
||||
import os
|
||||
import time
|
||||
import pdb
|
||||
import signal
|
||||
import torch
|
||||
from multiprocessing import Process,Pool
|
||||
from collections import defaultdict
|
||||
import sys
|
||||
import psutil
|
||||
from ..utils import set_logger, get_logger
|
||||
logger = get_logger()
|
||||
|
||||
def kill_children(proc=None, recursive = True):
|
||||
if proc is None:
|
||||
proc = psutil.Process()
|
||||
_children = proc.children(recursive=False)
|
||||
for c in _children:
|
||||
try:
|
||||
if recursive:
|
||||
kill_children(c, recursive=recursive)
|
||||
os.kill(c.pid, signal.SIGKILL)
|
||||
except:
|
||||
pass
|
||||
|
||||
for c in _children:
|
||||
try:
|
||||
c.wait(1)
|
||||
except:
|
||||
pass
|
||||
|
||||
def gc(i):
|
||||
return torch.cuda.device_count()
|
||||
|
||||
def get_ngpu():
|
||||
with Pool(1) as p:
|
||||
return p.map(gc, range(1))[0]
|
||||
|
||||
def _setup_distributed_group(args):
|
||||
"""Initialize torch.distributed."""
|
||||
|
||||
torch.backends.cudnn.enabled = False
|
||||
if args.world_size == 1:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
else:
|
||||
set_logger(args.task_name, os.path.join(args.output_dir, f'training_{args.task_name}_{args.rank}.log'), rank=args.rank, verbose=1 if args.local_rank==0 else 0)
|
||||
device_id = args.rank % args.n_gpu
|
||||
if args.local_rank >= 0:
|
||||
device_id = args.local_rank
|
||||
device = torch.device("cuda", device_id)
|
||||
init_method = 'tcp://'
|
||||
init_method += args.master_ip + ':' + args.master_port
|
||||
distributed_backend = getattr(args, 'distributed_backend', 'nccl')
|
||||
torch.distributed.init_process_group(
|
||||
backend=distributed_backend,
|
||||
world_size=args.world_size, rank=args.rank,
|
||||
init_method=init_method)
|
||||
torch.cuda.set_device(device)
|
||||
n_gpu = torch.cuda.device_count()
|
||||
logger.info("device=%s, n_gpu=%d, distributed training=%r, world_size=%d", device, n_gpu, bool(args.world_size != 1), args.world_size)
|
||||
return device
|
||||
|
||||
def _get_world_size(args):
|
||||
world_size = int(os.getenv("WORLD_SIZE", '1'))
|
||||
if not hasattr(args, 'n_gpu') or args.n_gpu is None:
|
||||
n_gpu = get_ngpu()
|
||||
return n_gpu * world_size
|
||||
|
||||
def initialize_distributed(args, join=True):
|
||||
args.world_size = int(os.getenv("WORLD_SIZE", '1'))
|
||||
args.rank = int(os.getenv('RANK', '0'))
|
||||
args.master_ip = os.getenv('MASTER_ADDR', 'localhost')
|
||||
args.master_port = os.getenv('MASTER_PORT', '17006')
|
||||
|
||||
if args.world_size == 1:
|
||||
args.rank = 0
|
||||
args.master_ip = 'localhost'
|
||||
|
||||
if not hasattr(args, 'n_gpu') or args.n_gpu is None:
|
||||
args.n_gpu = get_ngpu()
|
||||
|
||||
args.node_rank = args.rank
|
||||
args.world_size = args.n_gpu * args.world_size
|
||||
seed = args.seed
|
||||
is_child = False
|
||||
if args.world_size>1:
|
||||
children = []
|
||||
for r in range(args.n_gpu):
|
||||
args.rank = r + args.n_gpu*args.node_rank
|
||||
args.local_rank = r
|
||||
args.seed = seed + args.rank
|
||||
child = os.fork()
|
||||
if child>0:
|
||||
children.append(child)
|
||||
else:
|
||||
signal.signal(signal.SIGINT, signal.SIG_IGN)
|
||||
is_child = True
|
||||
break
|
||||
else:
|
||||
is_child = True
|
||||
|
||||
if is_child:
|
||||
return _setup_distributed_group(args)
|
||||
else:
|
||||
if join:
|
||||
try:
|
||||
for c in children:
|
||||
cid, ccode = os.waitpid(0,0)
|
||||
logger.debug(f'Worker {c} done with code {ccode}')
|
||||
if ccode != 0:
|
||||
logger.error(f'Worker {c} : {cid} failed with code {ccode}')
|
||||
kill_children()
|
||||
raise ValueError(f'Job failed. {cid}:{ccode}')
|
||||
except (KeyboardInterrupt, SystemExit):
|
||||
logger.warning('Keyboard interrupt by user. Terminate all processes')
|
||||
kill_children(None)
|
||||
return children
|
||||
|
||||
def test_dist_launch():
|
||||
def test_functions(args):
|
||||
global logger
|
||||
set_logger(args.task_name, os.path.join(args.output_dir, f'training_{args.task_name}_{args.node_rank}.log'), rank=args.rank)
|
||||
logger.info(args)
|
||||
|
||||
class Args:
|
||||
def __init__(self):
|
||||
pass
|
||||
def __repr__(self):
|
||||
return str(self.__dict__)
|
||||
|
||||
args = Args()
|
||||
args.task_name = 'test'
|
||||
args.seed = 0
|
||||
args.n_gpu = None
|
||||
args.no_cuda=False
|
||||
args.output_dir = '/tmp'
|
||||
distributed_launch(args, test_functions, (args,))
|
||||
|
||||
def test_init_dist():
|
||||
class Args:
|
||||
def __init__(self):
|
||||
pass
|
||||
def __repr__(self):
|
||||
return str(self.__dict__)
|
||||
|
||||
args = Args()
|
||||
args.task_name = 'test'
|
||||
args.seed = 0
|
||||
args.n_gpu = None
|
||||
args.no_cuda=False
|
||||
args.output_dir = '/tmp'
|
||||
device = initialize_distributed(args)
|
||||
if isinstance(device, torch.device):
|
||||
return 0
|
||||
else:
|
||||
return 1
@ -0,0 +1,170 @@
|
|||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
#
|
||||
# Author: Pengcheng He (penhe@microsoft.com)
|
||||
# Date: 05/15/2019
|
||||
#
|
||||
|
||||
from collections import defaultdict
|
||||
import numpy as np
|
||||
import pdb
|
||||
from functools import cmp_to_key
|
||||
import torch
|
||||
import re
|
||||
from ..optims import Fp16Optimizer,XAdam,ExpLossScaler,get_world_size
|
||||
from ..utils import get_logger
|
||||
logger=get_logger()
|
||||
|
||||
def xadam_factory(args, training_steps=None):
|
||||
def optimizer_fn(param_groups, max_grad_norm=None):
|
||||
with_radam = getattr(args, 'with_radam', False)
|
||||
opt_type = getattr(args, 'opt_type', None)
|
||||
optimizer = XAdam(param_groups,
|
||||
lr=args.learning_rate,
|
||||
b1=args.adam_beta1,
|
||||
b2=args.adam_beta2,
|
||||
lr_ends=args.lr_schedule_ends,
|
||||
e=args.epsilon,
|
||||
warmup=args.warmup_proportion if args.warmup_proportion<1 else args.warmup_proportion/training_steps,
|
||||
t_total=training_steps,
|
||||
schedule=args.lr_schedule,
|
||||
max_grad_norm = args.max_grad_norm if max_grad_norm is None else max_grad_norm,
|
||||
weight_decay_rate = args.weight_decay,
|
||||
with_radam = with_radam,
|
||||
opt_type = opt_type,
|
||||
rank = args.rank)
|
||||
return optimizer
|
||||
|
||||
return optimizer_fn
|
||||
|
||||
def create_xoptimizer(model, args, num_train_steps=None, no_decay=['bias', 'LayerNorm.weight']):
|
||||
if args.fp16:
|
||||
loss_scaler = ExpLossScaler(scale_interval = args.scale_steps, init_scale=args.loss_scale)
|
||||
else:
|
||||
loss_scaler = None
|
||||
|
||||
distributed_optimizer = getattr(args, 'distributed_optimizer', True)
|
||||
max_distributed_groups = getattr(args, 'max_distributed_groups', 1000000)
|
||||
world_size = get_world_size()
|
||||
if world_size<=1:
|
||||
distributed_optimizer = False
|
||||
|
||||
_no_decay = [x.strip() for x in getattr(args, 'no_decay', '').split('|') if len(x.strip())>0]
|
||||
if len(_no_decay)>0:
|
||||
no_decay = _no_decay
|
||||
|
||||
opt_fn = xadam_factory(args, num_train_steps)
|
||||
|
||||
named_params = list(model.named_parameters())
|
||||
param_size = [p.numel() for n,p in named_params]
|
||||
type_groups = defaultdict(list)
|
||||
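# Partition parameters into size-balanced groups. With the distributed optimizer each group
# is assigned an owning rank; parameters larger than max_group_size are split into chunks
# whose (start, length) offsets are recorded on the group.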
if distributed_optimizer:
|
||||
num_groups = min(world_size, max_distributed_groups)
|
||||
max_group_size = (sum(param_size)+num_groups-1)//num_groups
|
||||
#max_group_size = max(64*1024*1024, max_group_size)
|
||||
#max_group_size = max_group_size//2
|
||||
max_group_size = (max_group_size//32)*32
|
||||
group_sizes = [0 for _ in range(num_groups)]
|
||||
group_ranks = [g*(world_size//num_groups) for g in range(num_groups)]
|
||||
else:
|
||||
# TODO: Fix inconsistent results with different group size
|
||||
max_group_size = max(64*1024*1024, max(param_size))
|
||||
num_groups = (sum(param_size)+max_group_size-1)//max_group_size
|
||||
group_sizes = [0 for _ in range(num_groups)]
|
||||
|
||||
def get_smallest_group(group_sizes):
|
||||
return np.argmin([g+i/10000 for i,g in enumerate(group_sizes)])
|
||||
|
||||
def chunk_into_pieces(param, max_size):
|
||||
num_chunks = param.numel()//max_size
|
||||
if num_chunks<2:
|
||||
return [param], [None]
|
||||
|
||||
flat = param.view(-1)
|
||||
chunks=[]
|
||||
offsets = []
|
||||
for i in range(num_chunks-1):
|
||||
chunks.append(flat.narrow(0, i*max_size, max_size))
|
||||
offsets.append([i*max_size, max_size])
|
||||
i += 1
|
||||
chunks.append(flat.narrow(0, i*max_size, flat.size(0)-i*max_size))
|
||||
offsets.append([i*max_size, flat.size(0)-i*max_size])
|
||||
assert sum([c.numel() for c in chunks])==param.numel(), f'{param.numel()}: {offsets}'
|
||||
return chunks, offsets
|
||||
|
||||
def param_cmp(x,y):
|
||||
n1,p1 = x
|
||||
n2,p2 = y
|
||||
if p1.numel() == p2.numel():
|
||||
if n1<n2:
|
||||
return -1
|
||||
elif n1>n2:
|
||||
return 1
|
||||
else:
|
||||
return 0
|
||||
else:
|
||||
return p1.numel() - p2.numel()
|
||||
|
||||
def add_group(param_groups, group, group_id):
|
||||
if distributed_optimizer:
|
||||
group['rank'] = group_ranks[group_id]
|
||||
param_groups.append(group.copy())
|
||||
group['params'] = []
|
||||
group['names'] = []
|
||||
group['offset'] = None
|
||||
return get_smallest_group(group_sizes),group
|
||||
|
||||
hard_reset = getattr(args, 'hard_reset', False)
|
||||
group_id = 0
|
||||
for n,p in named_params:
|
||||
key = ''
|
||||
if any(re.search(nd,n) for nd in no_decay):
|
||||
key += f'{str(p.dtype)}-nd'
|
||||
else:
|
||||
key += f'{str(p.dtype)}-d'
|
||||
type_groups[key].append((n,p))
|
||||
param_groups = []
|
||||
for key, params in type_groups.items():
|
||||
wd_theta = 0
|
||||
weight_decay = args.weight_decay
|
||||
_hard_reset = False
|
||||
if key.endswith('-nd'):
|
||||
weight_decay = 0
|
||||
else:
|
||||
_hard_reset = hard_reset
|
||||
|
||||
group = dict(params=[],
|
||||
weight_decay_rate=weight_decay,
|
||||
wd_theta = wd_theta,
|
||||
hard_reset = _hard_reset,
|
||||
names=[],
|
||||
offset=None)
|
||||
params = sorted(params, key=cmp_to_key(param_cmp))
|
||||
for (n,p) in params:
|
||||
if p.numel() >= max_group_size:
|
||||
if len(group['params'])>0:
|
||||
group_id,group = add_group(param_groups, group, group_id)
|
||||
chunks, offsets = chunk_into_pieces(p, max_group_size)
|
||||
for chk, off in zip(chunks, offsets):
|
||||
group['params'].append(p)
|
||||
group['names'].append(n)
|
||||
group['offset'] = off
|
||||
group_sizes[group_id] += chk.numel()
|
||||
group_id,group = add_group(param_groups, group, group_id)
|
||||
else:
|
||||
group['params'].append(p)
|
||||
group['names'].append(n)
|
||||
group['offset'] = None
|
||||
group_sizes[group_id] += p.numel()
|
||||
if group_sizes[group_id]>=max_group_size:
|
||||
group_id,group = add_group(param_groups, group, group_id)
|
||||
if len(group['params'])>0:
|
||||
group_id,group = add_group(param_groups, group, group_id)
|
||||
|
||||
lookahead_k = getattr(args, 'lookahead_k', -1)
|
||||
lookahead_alpha = getattr(args, 'lookahead_alpha', 0.5)
|
||||
optimizer = Fp16Optimizer(param_groups, opt_fn, loss_scaler, args.max_grad_norm, lookahead_k = lookahead_k,\
|
||||
lookahead_alpha = lookahead_alpha, rank=args.rank, distributed=distributed_optimizer)
|
||||
|
||||
return optimizer
|
|
@ -0,0 +1,242 @@
|
|||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
#
|
||||
# Author: Pengcheng He (penhe@microsoft.com)
|
||||
# Date: 05/15/2019
|
||||
#
|
||||
|
||||
import os
|
||||
import torch
|
||||
import random
|
||||
import time
|
||||
import numpy as np
|
||||
import pdb
|
||||
from collections import defaultdict, OrderedDict
from collections.abc import Mapping, Sequence
|
||||
from torch.utils.data import DataLoader
|
||||
from ..data import BatchSampler, DistributedBatchSampler,RandomSampler,SequentialSampler, AsyncDataLoader
|
||||
from ..utils import get_logger
|
||||
logger = get_logger()
|
||||
|
||||
from .dist_launcher import get_ngpu
|
||||
from .optimizer_utils import create_xoptimizer
|
||||
from ._utils import batch_to
|
||||
|
||||
__all__ = ['DistributedTrainer', 'set_random_seed']
|
||||
|
||||
def set_random_seed(seed, cpu_only=False):
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
n_gpu = get_ngpu()
|
||||
if n_gpu > 0 and not cpu_only:
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
class TrainerState:
|
||||
def __init__(self, training_steps, name=None):
|
||||
self.__dict__ = defaultdict(float)
|
||||
self.loss = 0.0
|
||||
self.examples = 0
|
||||
self.steps = 0
|
||||
self._last_report_step = 0
|
||||
self.epochs = 0
|
||||
self.next_batch = 0
|
||||
self.num_training_steps = training_steps
|
||||
self._last_report_time = time.time()
|
||||
self.best_steps = 0
|
||||
self.best_metric = -1e9
|
||||
self.name = name
|
||||
self.run_id = None
|
||||
|
||||
def update_step(self, loss, examples, loss_scale):
|
||||
self.examples += examples
|
||||
self.loss += loss
|
||||
self.steps += 1
|
||||
self.next_batch += 1
|
||||
self.loss_scale = loss_scale
|
||||
|
||||
def report_state(self):
|
||||
if self.steps <= self._last_report_step:
|
||||
return
|
||||
|
||||
end = time.time()
|
||||
start = self._last_report_time
|
||||
if self.name is not None:
|
||||
tag = f'[{self.name}]'
|
||||
else:
|
||||
tag = ''
logger.info('{}[{:0.1f}%][{:0.2f}h] Steps={}, loss={}, examples={}, loss_scale={:0.1f}, {:0.1f}s'.format(tag, 100*self.steps/self.num_training_steps, \
(self.num_training_steps - self.steps)*(end-start)/((self.steps-self._last_report_step)*3600), self.steps, self.loss/self.steps, self.examples, self.loss_scale, end-start))
|
||||
self._last_report_time = end
|
||||
self._last_report_step = self.steps
|
||||
|
||||
class DistributedTrainer:
|
||||
def __init__(self, args, output_dir, model, device, data_fn, loss_fn=None, optimizer_fn=None, eval_fn=None, init_fn=None, update_fn=None, dump_interval = 10000, name=None, **kwargs):
|
||||
"""
|
||||
data_fn returns a tuple (training_dataset, training_steps, train_sampler); training_dataset is required, while training_steps and train_sampler may be None
loss_fn returns the loss of the current mini-batch and the size of the batch
optimizer_fn returns the created optimizer
eval_fn returns the metric used for model selection
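
A minimal usage sketch (illustrative only; `MyDataset` and `my_eval_fn` are hypothetical placeholders, not part of this package):

    def data_fn(trainer):
        ds = MyDataset()
        return ds, None, RandomSampler(len(ds))

    def loss_fn(trainer, model, batch):
        _, loss = model(**batch)
        return loss.mean(), batch['input_ids'].size(0)

    trainer = DistributedTrainer(args, output_dir, model, device, data_fn, loss_fn=loss_fn, eval_fn=my_eval_fn)
    trainer.train()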
|
||||
"""
|
||||
self.__dict__.update(kwargs)
|
||||
self.args = args
|
||||
self.device = device
|
||||
self.eval_fn = eval_fn
|
||||
self.accumulative_update = 1
|
||||
if hasattr(args, 'accumulative_update'):
|
||||
self.accumulative_update = args.accumulative_update
|
||||
|
||||
train_data, training_steps, train_sampler = data_fn(self)
|
||||
self.train_data = train_data
|
||||
self.train_sampler = train_sampler if train_sampler is not None else RandomSampler(len(train_data))
|
||||
self.training_epochs = int(getattr(args, 'num_train_epochs', 1))
|
||||
|
||||
if training_steps is None:
|
||||
training_steps = getattr(args, 'training_steps', (len(train_data) + self.args.train_batch_size-1)//self.args.train_batch_size*self.training_epochs)
|
||||
self.training_steps = training_steps
|
||||
|
||||
self.output_dir = output_dir
|
||||
self.init_fn = init_fn
|
||||
self.trainer_state = TrainerState(self.training_steps, name = name)
|
||||
self.dump_interval = dump_interval
|
||||
|
||||
self.model = self._setup_model(args, model)
|
||||
self.post_loss_fn = None
|
||||
|
||||
def _opt_fn(trainer, model, training_steps):
|
||||
return create_xoptimizer(model, args, num_train_steps = training_steps)
|
||||
optimizer_fn = optimizer_fn if optimizer_fn is not None else _opt_fn
|
||||
|
||||
self.optimizer = optimizer_fn(self, model, training_steps)
|
||||
|
||||
def _loss_fn(trainer, model, batch):
|
||||
_,loss = model(**batch)
|
||||
batch_size = batch['input_ids'].size(0)
|
||||
return loss.mean(), batch_size
|
||||
self.loss_fn = loss_fn if loss_fn is not None else _loss_fn
|
||||
|
||||
self.initialized = False
|
||||
self.update_fn = update_fn
|
||||
|
||||
def initialize(self):
|
||||
set_random_seed(self.args.seed)
|
||||
|
||||
if self.args.world_size>1:
|
||||
torch.distributed.barrier()
|
||||
self.initialized = True
|
||||
|
||||
def train(self):
|
||||
if not self.initialized:
|
||||
self.initialize()
|
||||
|
||||
rank = self.args.rank
|
||||
world_size = self.args.world_size
|
||||
for n_epoch in range(self.trainer_state.epochs, self.training_epochs):
|
||||
batch_sampler = BatchSampler(self.train_sampler, self.args.train_batch_size)
|
||||
batch_sampler = DistributedBatchSampler(batch_sampler, rank = rank, world_size = world_size)
|
||||
batch_sampler.next = self.trainer_state.next_batch
|
||||
num_workers = getattr(self.args, 'workers', 2)
|
||||
train_dataloader = DataLoader(self.train_data, batch_sampler=batch_sampler, num_workers=num_workers, worker_init_fn=self.init_fn, pin_memory=True)
|
||||
torch.cuda.empty_cache()
|
||||
for step, batch in enumerate(AsyncDataLoader(train_dataloader, 100)):
|
||||
if self.trainer_state.steps >= self.training_steps:
|
||||
break
|
||||
bs_scale = 1
|
||||
batch = batch_to(batch, self.device)
|
||||
self._train_step(batch, bs_scale)
|
||||
# Save model
|
||||
self.trainer_state.epochs += 1
|
||||
self.trainer_state.next_batch = 0
|
||||
self.trainer_state.report_state()
|
||||
self._eval_model()
|
||||
|
||||
def save_model(self, args, checkpoint_dir, chk_postfix, model, optimizer):
|
||||
save_path= os.path.join(checkpoint_dir, f'pytorch.model-{chk_postfix}.bin')
|
||||
if hasattr(model, 'module'):
|
||||
model_state = OrderedDict([(n,p) for n,p in model.module.state_dict().items()])
|
||||
else:
|
||||
model_state = OrderedDict([(n,p) for n,p in model.state_dict().items()])
|
||||
if args.rank < 1:
|
||||
torch.save(model_state, save_path)
|
||||
return save_path
|
||||
|
||||
def _eval_model(self, with_checkpoint=True):
|
||||
if with_checkpoint:
|
||||
checkpoint_dir = getattr(self.args, 'checkpoint_dir', None)
|
||||
checkpoint_dir = checkpoint_dir if checkpoint_dir is not None else self.output_dir
|
||||
chk_postfix = f'{self.trainer_state.steps:06}'
|
||||
self.save_model(self.args, checkpoint_dir, chk_postfix, self.model, self.optimizer)
|
||||
_metric = self.trainer_state.best_metric
|
||||
_steps = self.trainer_state.best_steps
|
||||
if self.eval_fn is not None:
|
||||
metric = self.eval_fn(self, self.model, self.device, tag=f'{self.trainer_state.steps:06}-{self.training_steps}')
|
||||
if metric > _metric:
|
||||
_metric = metric
|
||||
_steps = self.trainer_state.steps
|
||||
logger.info(f'Best metric: {_metric}@{_steps}')
|
||||
self.trainer_state.best_metric, self.trainer_state.best_steps = _metric, _steps
|
||||
|
||||
def _train_step(self, data, bs_scale):
|
||||
self.model.train()
|
||||
go_next=False
|
||||
|
||||
def split(batch, parts):
|
||||
sub_batches = [{} for _ in range(parts)]
|
||||
for k in batch.keys():
|
||||
b = batch[k].size(0)
|
||||
s = (b + parts - 1)//parts
|
||||
v = batch[k].split(s)
|
||||
for i,z in enumerate(v):
|
||||
sub_batches[i][k]=z
|
||||
chunks = [b for b in sub_batches if len(b)>0]
|
||||
return chunks
|
||||
|
||||
if self.accumulative_update>1:
|
||||
data_chunks = split(data, self.accumulative_update)
|
||||
else:
|
||||
data_chunks = [data]
|
||||
|
||||
while not go_next:
|
||||
step_loss = 0
|
||||
batch_size = 0
|
||||
self.optimizer.zero_grad()
|
||||
forward_outputs = []
|
||||
for i, sub in enumerate(data_chunks):
|
||||
output = self.loss_fn(self, self.model, sub)
|
||||
if isinstance(output, dict):
|
||||
loss, sub_size = output['loss'], output['batch_size']
|
||||
else:
|
||||
loss, sub_size = output
|
||||
forward_outputs.append(output)
|
||||
loss = loss/len(data_chunks)
|
||||
if i == 0:
|
||||
loss_scale, _loss = self.optimizer.backward(loss)
|
||||
else:
|
||||
_loss = loss.float().detach().item()
|
||||
loss = loss.float() * loss_scale
|
||||
loss.backward()
|
||||
step_loss += _loss
|
||||
batch_size += sub_size
|
||||
if not self.optimizer.step(bs_scale, loss_scale):
|
||||
self.optimizer.zero_grad()
|
||||
continue
|
||||
go_next = True
|
||||
self.trainer_state.update_step(step_loss, batch_size , loss_scale)
|
||||
if self.update_fn is not None:
|
||||
self.update_fn(self, self.model, loss_scale)
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
if self.post_loss_fn is not None:
|
||||
self.post_loss_fn(forward_outputs)
|
||||
|
||||
if self.trainer_state.steps%100 == 0:
|
||||
self.trainer_state.report_state()
|
||||
if self.trainer_state.steps%self.dump_interval == 0:
|
||||
self._eval_model()
|
||||
|
||||
def _setup_model(self, args, model):
|
||||
if args.world_size > 1:
|
||||
for p in model.parameters():
|
||||
torch.distributed.broadcast(p.data, 0)
|
||||
torch.cuda.synchronize()
|
||||
return model
|
README.md
|
@ -201,7 +201,7 @@ And here are the results from the Base model
|
|||
|MNLI base| `experiments/glue/mnli.sh base`| 88.8/88.5 +/-0.2| 1.5h|
|
||||
|
||||
|
||||
#### Fine-tuning on NLU tasks
|
||||
### Fine-tuning on NLU tasks
|
||||
|
||||
We present the dev results on SQuAD 1.1/2.0 and several GLUE benchmark tasks.
|
||||
|
||||
|
@ -219,7 +219,7 @@ We present the dev results on SQuAD 1.1/2.0 and several GLUE benchmark tasks.
|
|||
|[DeBERTa-V3-Base](https://huggingface.co/microsoft/deberta-v3-base)|-/-|88.4/85.4|90.6/90.7|-|-|-| -| -|- |- |
|
||||
|[DeBERTa-V3-Small](https://huggingface.co/microsoft/deberta-v3-small)|-/-|82.9/80.4|88.2/87.9|-|-|-| -| -|- |- |
|
||||
|
||||
#### Fine-tuning on XNLI
|
||||
### Fine-tuning on XNLI
|
||||
|
||||
We present the dev results on XNLI under the zero-shot cross-lingual transfer setting, i.e., training with English data only and testing on other languages.
|
||||
|
||||
|
@ -231,7 +231,11 @@ We present the dev results on XNLI with zero-shot crosslingual transfer setting,
|
|||
--------
|
||||
#### Notes.
|
||||
- <sup>1</sup> Following RoBERTa, for RTE, MRPC and STS-B, we fine-tune the tasks starting from [DeBERTa-Large-MNLI](https://huggingface.co/microsoft/deberta-large-mnli), [DeBERTa-XLarge-MNLI](https://huggingface.co/microsoft/deberta-xlarge-mnli), [DeBERTa-V2-XLarge-MNLI](https://huggingface.co/microsoft/deberta-v2-xlarge-mnli) and [DeBERTa-V2-XXLarge-MNLI](https://huggingface.co/microsoft/deberta-v2-xxlarge-mnli). The results on SST-2/QQP/QNLI/SQuADv2 also improve slightly when starting from the MNLI fine-tuned models; however, we only report the numbers fine-tuned from the pre-trained base models for those 4 tasks.
|
||||
|
||||
|
||||
### Pre-training with MLM and RTD objectives
|
||||
|
||||
To pre-train DeBERTa with MLM and RTD objectives, please check [`experiments/language_model`](experiments/language_model)
|
||||
|
||||
|
||||
## Contacts
|
||||
|
||||
|
@ -260,17 +264,5 @@ url={https://openreview.net/forum?id=XPZIaotutsD}
|
|||
}
|
||||
```
|
||||
|
||||
# Contributing
|
||||
|
||||
This project welcomes contributions and suggestions. Most contributions require you to agree to a
|
||||
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
|
||||
the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
|
||||
|
||||
When you submit a pull request, a CLA bot will automatically determine whether you need to provide
|
||||
a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
|
||||
provided by the bot. You will only need to do this once across all repos using our CLA.
|
||||
|
||||
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
|
||||
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
|
||||
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
|
||||
|
||||
|
|
|
@ -20,6 +20,19 @@ setup_glue_data $Task
|
|||
init=$1
|
||||
tag=$init
|
||||
case ${init,,} in
|
||||
bert-xsmall)
|
||||
init=/tmp/ttonly/bert-xsmall/discriminator/pytorch.model-1000000.bin
|
||||
vocab_type=spm
|
||||
vocab_path=/tmp/ttonly/bert-xsmall/discriminator/spm.model
|
||||
parameters=" --num_train_epochs 3 \
|
||||
--fp16 True \
|
||||
--warmup 1500 \
|
||||
--learning_rate 1e-4 \
|
||||
--vocab_type $vocab_type \
|
||||
--vocab_path $vocab_path \
|
||||
--train_batch_size 64 \
|
||||
--cls_drop_out 0.1 "
|
||||
;;
|
||||
deberta-v3-small)
|
||||
parameters=" --num_train_epochs 2 \
|
||||
--fp16 False \
|
||||
|
@ -144,11 +157,13 @@ case ${init,,} in
|
|||
;;
|
||||
esac
|
||||
|
||||
export MASTER_PORT=12456
|
||||
python -m DeBERTa.apps.run --model_config config.json \
|
||||
--tag $tag \
|
||||
--do_train \
|
||||
--max_seq_len 256 \
|
||||
--eval_batch_size 256 \
|
||||
--dump_interval 1000 \
|
||||
--task_name $Task \
|
||||
--data_dir $cache_dir/glue_tasks/$Task \
|
||||
--init_model $init \
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
# Pre-train an efficient transformer language model for natural language understanding

## Data

We use the [wiki103](https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip) data as an example, which is publicly available. It contains three text files, `train.txt`, `valid.txt` and `test.txt`. We use `train.txt` to train the model and `valid.txt` to evaluate the intermediate checkpoints. We first need to run `prepare_data.py` to tokenize these text files. We concatenate all documents into a single text and split it into lines of tokens, where each line has at most 510 tokens (2 positions are reserved for the special tokens `[CLS]` and `[SEP]`). A sketch of this step is shown below.
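
A concrete sketch of the tokenization step (the paths below are the defaults used by `mlm.sh`, which runs these same commands through its `setup_wiki_data` function; adjust them to your setup):

``` bash
# Tokenize the raw wiki103 files into lines of at most 510 tokens each
python ./prepare_data.py -i /tmp/DeBERTa/MLM/wiki103/wiki.train.tokens -o /tmp/DeBERTa/MLM/wiki103/spm_512/train.txt --max_seq_length 512
python ./prepare_data.py -i /tmp/DeBERTa/MLM/wiki103/wiki.valid.tokens -o /tmp/DeBERTa/MLM/wiki103/spm_512/valid.txt --max_seq_length 512
python ./prepare_data.py -i /tmp/DeBERTa/MLM/wiki103/wiki.test.tokens -o /tmp/DeBERTa/MLM/wiki103/spm_512/test.txt --max_seq_length 512
```
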
## Pre-training with the Masked Language Modeling task

Run `mlm.sh` to train a BERT-like model that uses MLM as the pre-training task. For example,

`mlm.sh bert-base` will train a BERT base model that uses absolute position encodings

`mlm.sh deberta-base` will train a DeBERTa base model that uses **Disentangled Attention**

## Pre-training with the Replaced Token Detection task

Coming soon...

## Distributed training

To train with multiple nodes, you need to specify four environment variables:

`WORLD_SIZE` - The total number of nodes used for the training

`MASTER_ADDR` - The IP address or host name of the master node

`MASTER_PORT` - The port of the master node

`RANK` - The rank of the current node

For example, to train a model with 2 nodes,

- On **node0**,
``` bash
export WORLD_SIZE=2
export MASTER_ADDR=node0
export MASTER_PORT=7488
export RANK=0
./rtd.sh deberta-v3-xsmall
```

- On **node1**,
``` bash
export WORLD_SIZE=2
export MASTER_ADDR=node0
export MASTER_PORT=7488
export RANK=1
./rtd.sh deberta-v3-xsmall
```

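For a single-node run these variables can usually be left unset: the launcher reads `MASTER_ADDR`/`MASTER_PORT` with defaults of `localhost`/`17006` and forks one worker per visible GPU. A sketch, assuming a single machine and the MLM script above:

``` bash
# Single-node pre-training sketch: no WORLD_SIZE/MASTER_ADDR/MASTER_PORT/RANK needed,
# the launcher falls back to localhost and forks one process per local GPU.
./mlm.sh deberta-base
```
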
## Model config options

- `relative_attention` Whether to use relative attention

- `pos_att_type` Relative position encoding type
  - P2C Position-to-content attention
  - C2P Content-to-position attention
|
|
@ -0,0 +1,22 @@
|
|||
{
|
||||
"attention_probs_dropout_prob": 0.1,
|
||||
"hidden_act": "gelu",
|
||||
"hidden_dropout_prob": 0.1,
|
||||
"hidden_size": 768,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 3072,
|
||||
"max_position_embeddings": 512,
|
||||
"relative_attention": true,
|
||||
"position_buckets": 256,
|
||||
"norm_rel_ebd": "layer_norm",
|
||||
"share_att_key": true,
|
||||
"pos_att_type": "p2c|c2p",
|
||||
"layer_norm_eps": 1e-7,
|
||||
"max_relative_positions": -1,
|
||||
"position_biased_input": false,
|
||||
"num_attention_heads": 12,
|
||||
"attention_head_size": 64,
|
||||
"num_hidden_layers": 12,
|
||||
"type_vocab_size": 0,
|
||||
"vocab_size": 128100
|
||||
}
|
|
@ -5,20 +5,23 @@ cd $SCRIPT_DIR
|
|||
|
||||
cache_dir=/tmp/DeBERTa/MLM/
|
||||
|
||||
max_seq_length=512
|
||||
data_dir=$cache_dir/wiki103/spm_$max_seq_length
|
||||
|
||||
function setup_wiki_data(){
|
||||
task=$1
|
||||
mkdir -p $cache_dir
|
||||
if [[ ! -e $cache_dir/spm.model ]]; then
|
||||
wget -q https://huggingface.co/microsoft/deberta-v2-xxlarge/resolve/main/spm.model -O $cache_dir/spm.model
|
||||
wget -q https://huggingface.co/microsoft/deberta-v3-base/resolve/main/spm.model -O $cache_dir/spm.model
|
||||
fi
|
||||
|
||||
if [[ ! -e $cache_dir/wiki103.zip ]]; then
|
||||
if [[ ! -e $data_dir/test.txt ]]; then
|
||||
wget -q https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip -O $cache_dir/wiki103.zip
|
||||
unzip -j $cache_dir/wiki103.zip -d $cache_dir/wiki103
|
||||
mkdir -p $cache_dir/wiki103/spm
|
||||
python ./prepare_data.py -i $cache_dir/wiki103/wiki.train.tokens -o $cache_dir/wiki103/spm/train.txt
|
||||
python ./prepare_data.py -i $cache_dir/wiki103/wiki.valid.tokens -o $cache_dir/wiki103/spm/valid.txt
|
||||
python ./prepare_data.py -i $cache_dir/wiki103/wiki.test.tokens -o $cache_dir/wiki103/spm/test.txt
|
||||
mkdir -p $data_dir
|
||||
python ./prepare_data.py -i $cache_dir/wiki103/wiki.train.tokens -o $data_dir/train.txt --max_seq_length $max_seq_length
|
||||
python ./prepare_data.py -i $cache_dir/wiki103/wiki.valid.tokens -o $data_dir/valid.txt --max_seq_length $max_seq_length
|
||||
python ./prepare_data.py -i $cache_dir/wiki103/wiki.test.tokens -o $data_dir/test.txt --max_seq_length $max_seq_length
|
||||
fi
|
||||
}
|
||||
|
||||
|
@ -36,11 +39,20 @@ case ${init,,} in
|
|||
--learning_rate 1e-4 \
|
||||
--train_batch_size 256 \
|
||||
--max_ngram 1 \
|
||||
--fp16 True "
|
||||
;;
|
||||
deberta-base)
|
||||
parameters=" --num_train_epochs 1 \
|
||||
--model_config deberta_base.json \
|
||||
--warmup 10000 \
|
||||
--learning_rate 1e-4 \
|
||||
--train_batch_size 256 \
|
||||
--max_ngram 3 \
|
||||
--fp16 True "
|
||||
;;
|
||||
xlarge-v2)
|
||||
parameters=" --num_train_epochs 1 \
|
||||
--model_config xlarge.json \
|
||||
--model_config deberta_xlarge.json \
|
||||
--warmup 1000 \
|
||||
--learning_rate 1e-4 \
|
||||
--train_batch_size 32 \
|
||||
|
@ -50,7 +62,7 @@ case ${init,,} in
|
|||
xxlarge-v2)
|
||||
parameters=" --num_train_epochs 1 \
|
||||
--warmup 1000 \
|
||||
--model_config xxlarge.json \
|
||||
--model_config deberta_xxlarge.json \
|
||||
--learning_rate 1e-4 \
|
||||
--train_batch_size 32 \
|
||||
--max_ngram 3 \
|
||||
|
@ -59,18 +71,19 @@ case ${init,,} in
|
|||
*)
|
||||
echo "usage $0 <Pretrained model configuration>"
|
||||
echo "Supported configurations"
|
||||
echo "bert-base - Pretrained a bert base model with DeBERTa vocabulary (12 layers, 768 hidden size, 128k vocabulary size)"
|
||||
echo "deberta-base - Pretrained a deberta base model (12 layers, 768 hidden size, 128k vocabulary size)"
|
||||
echo "xlarge-v2 - Pretrained DeBERTa v2 model with 900M parameters (24 layers, 1536 hidden size)"
|
||||
echo "xxlarge-v2 - Pretrained DeBERTa v2 model with 1.5B parameters (48 layers, 1536 hidden size)"
|
||||
exit 0
|
||||
;;
|
||||
esac
|
||||
|
||||
data_dir=$cache_dir/wiki103/spm
|
||||
python -m DeBERTa.apps.run --model_config config.json \
|
||||
--tag $tag \
|
||||
--do_train \
|
||||
--num_training_steps 1000000 \
|
||||
--max_seq_len 512 \
|
||||
--max_seq_len $max_seq_length \
|
||||
--dump 10000 \
|
||||
--task_name $Task \
|
||||
--data_dir $data_dir \
|
||||
|
|
|
@ -4,8 +4,8 @@ import sys
|
|||
import argparse
|
||||
from tqdm import tqdm
|
||||
|
||||
def tokenize_data(input, output=None):
|
||||
p,t=deberta.load_vocab(vocab_path=None, vocab_type='spm', pretrained_id='xlarge-v2')
|
||||
def tokenize_data(input, output=None, max_seq_length=512):
|
||||
p,t=deberta.load_vocab(vocab_path=None, vocab_type='spm', pretrained_id='deberta-v3-base')
|
||||
tokenizer=deberta.tokenizers[t](p)
|
||||
if output is None:
|
||||
output=input + '.spm'
|
||||
|
@ -23,8 +23,8 @@ def tokenize_data(input, output=None):
|
|||
with open(output, 'w', encoding = 'utf-8') as wfs:
|
||||
idx = 0
|
||||
while idx < len(all_tokens):
|
||||
wfs.write(' '.join(all_tokens[idx:idx+510]) + '\n')
|
||||
idx += 510
|
||||
wfs.write(' '.join(all_tokens[idx:idx+max_seq_length-2]) + '\n')
|
||||
idx += (max_seq_length - 2)
|
||||
lines += 1
|
||||
|
||||
print(f'Saved {lines} lines to {output}')
|
||||
|
@ -32,5 +32,6 @@ def tokenize_data(input, output=None):
|
|||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-i', '--input', required=True, help='The input data path')
|
||||
parser.add_argument('-o', '--output', default=None, help='The output data path')
|
||||
parser.add_argument('--max_seq_length', type=int, default=512, help='Maximum sequence length of inputs')
|
||||
args = parser.parse_args()
|
||||
tokenize_data(args.input, args.output)
|
||||
tokenize_data(args.input, args.output, args.max_seq_length)
|
||||
|
|
|
@ -10,7 +10,6 @@ ujson
|
|||
seqeval
|
||||
psutil
|
||||
sentencepiece
|
||||
laser
|
||||
#GitPython
|
||||
torch
|
||||
#torchvision
|
||||
|
|