1. Add document for DeBERTa pre-training

2. Add RTD task head
3. Merge LASER with DeBERTa
This commit is contained in:
Pengcheng He 2021-11-19 20:21:06 -05:00 коммит произвёл Pengcheng He
Родитель e59f09fbd4
Коммит 994f643ec2
29 изменённых файлов: 1870 добавлений и 121 удалений

51
DeBERTa/apps/_utils.py Normal file
Просмотреть файл

@ -0,0 +1,51 @@
import torch
from collections import OrderedDict, Mapping, Sequence
def merge_distributed(data_list, max_len=None):
if torch.distributed.is_initialized() and torch.distributed.get_world_size()>1:
world_size = torch.distributed.get_world_size()
else:
world_size = 1
merged = []
def gather(data):
data_size = [torch.zeros(data.dim(), dtype=torch.int).to(data.device) for _ in range(world_size)]
torch.distributed.all_gather(data_size, torch.tensor(data.size()).to(data_size[0]))
data_chunks = [torch.zeros(tuple(s.cpu().numpy())).to(data) for s in data_size]
data_chunks[data.device.index] = data
for i,_chunk in enumerate(data_chunks):
torch.distributed.broadcast(_chunk, src=i)
return data_chunks
for data in data_list:
if torch.distributed.is_initialized() and torch.distributed.get_world_size()>1:
if isinstance(data, Sequence):
data_chunks = []
for d in data:
chunks_ = gather(d)
data_ = torch.cat(chunks_)
data_chunks.append(data_)
merged.append(data_chunks)
else:
_chunks = gather(data)
merged.extend(_chunks)
else:
merged.append(data)
return join_chunks(merged, max_len)
def join_chunks(chunks, max_len=None):
if not isinstance(chunks[0], Sequence):
merged = torch.cat([m.cpu() for m in chunks])
if max_len is not None:
return merged[:max_len]
else:
return merged
else:
data_list=[]
for d in zip(*chunks):
data = torch.cat([x.cpu() for x in d])
if max_len is not None:
data = data[:max_len]
data_list.append(data)
return data_list

Просмотреть файл

@ -3,3 +3,4 @@ from .multi_choice import *
from .sequence_classification import *
from .record_qa import *
from .masked_language_model import *
from .replaced_token_detection_model import *

Просмотреть файл

@ -0,0 +1,97 @@
#
# Author: penhe@microsoft.com
# Date: 04/25/2021
#
""" Replaced token detection model for representation learning
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from concurrent.futures import ThreadPoolExecutor
import csv
import os
import json
import random
import time
from tqdm import tqdm, trange
import numpy as np
import torch
import torch.nn as nn
import pdb
from collections.abc import Mapping
from copy import copy
from ...deberta import *
__all__ = ['LMMaskPredictionHead', 'ReplacedTokenDetectionModel']
class LMMaskPredictionHead(nn.Module):
""" Replaced token prediction head
"""
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.transform_act_fn = ACT2FN[config.hidden_act] \
if isinstance(config.hidden_act, str) else config.hidden_act
self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
self.classifer = nn.Linear(config.hidden_size, 1)
def forward(self, hidden_states, input_ids, input_mask, lm_labels=None):
# b x d
ctx_states = hidden_states[:,0,:]
seq_states = self.LayerNorm(ctx_states.unsqueeze(-2) + hidden_states)
seq_states = self.dense(seq_states)
seq_states = self.transform_act_fn(seq_states)
# b x max_len
logits = self.classifer(seq_states).squeeze(-1)
mask_loss = torch.tensor(0).to(logits).float()
mask_labels = None
if lm_labels is not None:
mask_logits = logits.view(-1)
_input_mask = input_mask.view(-1).to(mask_logits)
input_idx = (_input_mask>0).nonzero().view(-1)
mask_labels = ((lm_labels>0) & (lm_labels!=input_ids)).view(-1)
mask_labels = torch.gather(mask_labels.to(mask_logits), 0, input_idx)
mask_loss_fn = torch.nn.BCEWithLogitsLoss(reduction='none')
mask_logits = torch.gather(mask_logits, 0, input_idx).float()
mask_loss = mask_loss_fn(mask_logits, mask_labels)
return logits, mask_labels, mask_loss
class ReplacedTokenDetectionModel(NNModule):
""" RTD with DeBERTa
"""
def __init__(self, config, *wargs, **kwargs):
super().__init__(config)
self.deberta = DeBERTa(config)
self.max_relative_positions = getattr(config, 'max_relative_positions', -1)
self.position_buckets = getattr(config, 'position_buckets', -1)
if self.max_relative_positions <1:
self.max_relative_positions = config.max_position_embeddings
self.mask_predictions = LMMaskPredictionHead(self.deberta.config)
self.apply(self.init_weights)
def forward(self, input_ids, input_mask=None, labels=None, position_ids=None, attention_mask=None):
device = list(self.parameters())[0].device
input_ids = input_ids.to(device)
input_mask = input_mask.to(device)
type_ids = None
lm_labels = labels.to(device)
if attention_mask is not None:
attention_mask = attention_mask.to(device)
else:
attention_mask = input_mask
encoder_output = self.deberta(input_ids, input_mask, type_ids, output_all_encoded_layers=True, position_ids = position_ids)
encoder_layers = encoder_output['hidden_states']
ctx_layer = encoder_layers[-1]
(mask_logits, mask_labels, mask_loss) = self.mask_predictions(encoder_layers[-1], input_ids, input_mask, lm_labels)
return {
'logits' : mask_logits,
'labels' : mask_labels,
'loss' : mask_loss.float(),
}

Просмотреть файл

@ -21,17 +21,20 @@ import numpy as np
import math
import torch
import json
import shutil
from torch.utils.data import DataLoader
from ..utils import *
from ..utils import xtqdm as tqdm
from ..sift import AdversarialLearner,hook_sift_layer
from .tasks import load_tasks,get_task
from ._utils import merge_distributed, join_chunks
import pdb
import LASER
from LASER.training import DistributedTrainer, initialize_distributed, batch_to, set_random_seed,kill_children
from LASER.data import DistributedBatchSampler, SequentialSampler, BatchSampler, AsyncDataLoader
from ..training import DistributedTrainer, initialize_distributed, batch_to, set_random_seed,kill_children
from ..data import DistributedBatchSampler, SequentialSampler, BatchSampler, AsyncDataLoader
from ..training import get_args as get_training_args
from ..optims import get_args as get_optims_args
def create_model(args, num_labels, model_class_fn):
# Prepare model
@ -46,7 +49,7 @@ def create_model(args, num_labels, model_class_fn):
logger.info(f'Total parameters: {sum([p.numel() for p in model.parameters()])}')
return model
def train_model(args, model, device, train_data, eval_data, run_eval_fn):
def train_model(args, model, device, train_data, eval_data, run_eval_fn, train_fn=None, loss_fn=None):
total_examples = len(train_data)
num_train_steps = int(len(train_data)*args.num_train_epochs / args.train_batch_size)
logger.info(" Training batch size = %d", args.train_batch_size)
@ -60,11 +63,12 @@ def train_model(args, model, device, train_data, eval_data, run_eval_fn):
eval_metric = np.mean([v[0] for k,v in results.items() if 'train' not in k])
return eval_metric
def loss_fn(trainer, model, data):
def _loss_fn(trainer, model, data):
output = model(**data)
loss = output['loss']
return loss.mean(), data['input_ids'].size(0)
def _train_fn(args, model, device, data_fn, eval_fn, loss_fn):
adv_modules = hook_sift_layer(model, hidden_size=model.config.hidden_size, learning_rate=args.vat_learning_rate, init_perturbation=args.vat_init_perturbation)
adv = AdversarialLearner(model, adv_modules)
def adv_loss_fn(trainer, model, data):
@ -88,48 +92,16 @@ def train_model(args, model, device, train_data, eval_data, run_eval_fn):
return loss.mean(), data['input_ids'].size(0)
trainer = DistributedTrainer(args, args.output_dir, model, device, data_fn, loss_fn = adv_loss_fn, eval_fn = eval_fn, dump_interval = args.dump_interval)
if loss_fn is None:
loss_fn = adv_loss_fn
trainer = DistributedTrainer(args, args.output_dir, model, device, data_fn, loss_fn = loss_fn, eval_fn = eval_fn, dump_interval = args.dump_interval)
trainer.train()
def merge_distributed(data_list, max_len=None):
merged = []
def gather(data):
data_size = [torch.zeros(data.dim(), dtype=torch.int).to(data.device) for _ in range(args.world_size)]
torch.distributed.all_gather(data_size, torch.tensor(data.size()).to(data_size[0]))
data_chunks = [torch.zeros(tuple(s.cpu().numpy())).to(data) for s in data_size]
data_chunks[data.device.index] = data
for i,_chunk in enumerate(data_chunks):
torch.distributed.broadcast(_chunk, src=i)
return data_chunks
if train_fn is None:
train_fn = _train_fn
for data in data_list:
if torch.distributed.is_initialized() and torch.distributed.get_world_size()>1:
if isinstance(data, Sequence):
data_chunks = []
for d in data:
chunks_ = gather(d)
data_ = torch.cat(chunks_)
data_chunks.append(data_)
merged.append(data_chunks)
else:
_chunks = gather(data)
merged.extend(_chunks)
else:
merged.append(data)
if not isinstance(merged[0], Sequence):
merged = torch.cat([m.cpu() for m in merged])
if max_len is not None:
return merged[:max_len]
else:
return merged
else:
data_list=[]
for d in zip(*merged):
data = torch.cat([x.cpu() for x in d])
if max_len is not None:
data = data[:max_len]
data_list.append(data)
return data_list
train_fn(args, model, device, data_fn = data_fn, eval_fn = eval_fn, loss_fn = loss_fn)
def calc_metrics(predicts, labels, eval_loss, eval_item, eval_results, args, name, prefix, steps, tag):
tb_metrics = OrderedDict()
@ -280,12 +252,14 @@ def main(args):
if args.do_train:
with open(os.path.join(args.output_dir, 'model_config.json'), 'w', encoding='utf-8') as fs:
fs.write(model.config.to_json_string() + '\n')
shutil.copy(args.vocab_path, args.output_dir)
logger.info("Model config {}".format(model.config))
device = initialize_distributed(args)
if not isinstance(device, torch.device):
return 0
model.to(device)
run_eval_fn = task.run_eval_fn()
run_eval_fn = task.get_eval_fn()
loss_fn = task.get_loss_fn(args)
if run_eval_fn is None:
run_eval_fn = run_eval
@ -293,7 +267,8 @@ def main(args):
run_eval(args, model, device, eval_data, prefix=args.tag)
if args.do_train:
train_model(args, model, device, train_data, eval_data, run_eval_fn)
train_fn = task.get_train_fn(args, model)
train_model(args, model, device, train_data, eval_data, run_eval_fn, loss_fn=loss_fn, train_fn = train_fn)
if args.do_predict:
run_predict(args, model, device, test_data, prefix=args.tag)
@ -317,7 +292,7 @@ class LoadTaskAction(argparse.Action):
type(self)._registered = True
def build_argument_parser():
parser = argparse.ArgumentParser(parents=[LASER.optims.get_args(), LASER.training.get_args()], formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser = argparse.ArgumentParser(parents=[get_optims_args(), get_training_args()], formatter_class=argparse.ArgumentDefaultsHelpFormatter)
## Required parameters
parser.add_argument("--task_dir",

Просмотреть файл

@ -19,15 +19,19 @@ import random
import torch
import re
import ujson as json
from torch.utils.data import DataLoader
from .metrics import *
from .task import EvalData, Task
from .task_registry import register_task
from ...utils import xtqdm as tqdm
from ...training import DistributedTrainer, batch_to
from ...data import DistributedBatchSampler, SequentialSampler, BatchSampler, AsyncDataLoader
from ...data import ExampleInstance, ExampleSet, DynamicDataset,example_to_feature
from ...data.example import _truncate_segments
from ...data.example import *
from ...utils import get_logger
from ..models import MaskedLanguageModel
from .._utils import merge_distributed, join_chunks
logger=get_logger()
@ -38,7 +42,7 @@ class NGramMaskGenerator:
Mask ngram tokens
https://github.com/zihangdai/xlnet/blob/0b642d14dd8aec7f1e1ecbf7d6942d5faa6be1f0/data_utils.py
"""
def __init__(self, tokenizer, mask_lm_prob=0.15, max_seq_len=512, max_preds_per_seq=None, max_gram = 3, keep_prob = 0.1, mask_prob=0.8, **kwargs):
def __init__(self, tokenizer, mask_lm_prob=0.15, max_seq_len=512, max_preds_per_seq=None, max_gram = 1, keep_prob = 0.1, mask_prob=0.8, **kwargs):
self.tokenizer = tokenizer
self.mask_lm_prob = mask_lm_prob
self.keep_prob = keep_prob
@ -61,7 +65,7 @@ class NGramMaskGenerator:
unigrams = []
for id in indices:
if len(unigrams)>=1 and self.tokenizer.part_of_whole_word(tokens[id]):
if self.max_gram>1 and len(unigrams)>=1 and self.tokenizer.part_of_whole_word(tokens[id]):
unigrams[-1].append(id)
else:
unigrams.append([id])
@ -165,23 +169,9 @@ dataset_size = dataset_size, shuffle=True, **kwargs)
def get_metrics_fn(self):
"""Calcuate metrics based on prediction results"""
def metrics_fn(logits, labels):
preds = np.argmax(logits, axis=-1)
preds = logits
acc = (preds==labels).sum()/len(labels)
metrics = OrderedDict(accuracy= acc)
logits = torch.tensor(logits).cuda()
labels = torch.tensor(labels).cuda().long()
chk = 1024
off = 0
loss_fn = torch.nn.CrossEntropyLoss(reduction="none")
losses = []
while off<labels.size(0):
loss = loss_fn(logits[off:off+chk, :], labels[off:off+chk])
losses.append(loss)
off += chk
loss = torch.cat(losses).mean()
ppl = loss.exp().cpu().item()
metrics['PPL'] = ppl
return metrics
return metrics_fn
@ -221,6 +211,64 @@ dataset_size = dataset_size, shuffle=True, **kwargs)
features[f] = torch.tensor(features[f] + [0]*(max_seq_len - len(token_ids)), dtype=torch.int)
return features
def get_eval_fn(self):
def eval_fn(args, model, device, eval_data, prefix=None, tag=None, steps=None):
# Run prediction for full data
prefix = f'{tag}_{prefix}' if tag is not None else prefix
eval_results=OrderedDict()
eval_metric=0
no_tqdm = (True if os.getenv('NO_TQDM', '0')!='0' else False) or args.rank>0
for eval_item in eval_data:
name = eval_item.name
eval_sampler = SequentialSampler(len(eval_item.data))
batch_sampler = BatchSampler(eval_sampler, args.eval_batch_size)
batch_sampler = DistributedBatchSampler(batch_sampler, rank=args.rank, world_size=args.world_size)
eval_dataloader = DataLoader(eval_item.data, batch_sampler=batch_sampler, num_workers=args.workers)
model.eval()
eval_loss, eval_accuracy = 0, 0
nb_eval_steps, nb_eval_examples = 0, 0
predicts=[]
labels=[]
for batch in tqdm(AsyncDataLoader(eval_dataloader), ncols=80, desc='Evaluating: {}'.format(prefix), disable=no_tqdm):
batch = batch_to(batch, device)
with torch.no_grad():
output = model(**batch)
logits = output['logits'].detach().argmax(dim=-1)
tmp_eval_loss = output['loss'].detach()
if 'labels' in output:
label_ids = output['labels'].detach().to(device)
else:
label_ids = batch['labels'].to(device)
predicts.append(logits)
labels.append(label_ids)
eval_loss += tmp_eval_loss.mean()
input_ids = batch['input_ids']
nb_eval_examples += input_ids.size(0)
nb_eval_steps += 1
eval_loss = eval_loss / nb_eval_steps
predicts = merge_distributed(predicts)
labels = merge_distributed(labels)
result=OrderedDict()
metrics_fn = eval_item.metrics_fn
metrics = metrics_fn(predicts.numpy(), labels.numpy())
result.update(metrics)
result['perplexity'] = torch.exp(eval_loss).item()
critial_metrics = set(metrics.keys()) if eval_item.critial_metrics is None or len(eval_item.critial_metrics)==0 else eval_item.critial_metrics
eval_metric = np.mean([v for k,v in metrics.items() if k in critial_metrics])
result['eval_loss'] = eval_loss.item()
result['eval_metric'] = eval_metric
result['eval_samples'] = len(labels)
if args.rank<=0:
logger.info("***** Eval results-{}-{} *****".format(name, prefix))
for key in sorted(result.keys()):
logger.info(" %s = %s", key, str(result[key]))
eval_results[name]=(eval_metric, predicts, labels)
return eval_results
return eval_fn
def get_model_class_fn(self):
def partial_class(*wargs, **kwargs):
return MaskedLanguageModel.load_model(*wargs, **kwargs)

Просмотреть файл

@ -64,10 +64,16 @@ class Task():
label_dict = {l:i for i,l in enumerate(self.get_labels())}
return label_dict[labelstr] if labelstr in label_dict else -1
def run_eval_fn(self):
def get_train_fn(self, *args, **kwargs):
return None
def run_pred_fn(self):
def get_eval_fn(self, *args, **kwargs):
return None
def get_pred_fn(self, *args, **kwargs):
return None
def get_loss_fn(self, *args, **kwargs):
return None
def get_metrics_fn(self):

Просмотреть файл

@ -1,3 +1,5 @@
from .example import ExampleInstance,ExampleSet,example_to_feature
from .dataloader import SequentialDataLoader
from .dynamic_dataset import *
from .data_sampler import *
from .async_data import *

Просмотреть файл

@ -0,0 +1,38 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# Author: Pengcheng He (penhe@microsoft.com)
# Date: 05/15/2019
#
from queue import Queue,Empty
from threading import Thread
class AsyncDataLoader(object):
def __init__(self, dataloader, buffer_size=100):
self.buffer_size = buffer_size
self.dataloader = dataloader
def __iter__(self):
queue = Queue(self.buffer_size)
dl=iter(self.dataloader)
def _worker():
while True:
try:
queue.put(next(dl))
except StopIteration:
break
queue.put(None)
t=Thread(target=_worker)
t.start()
while True:
d = queue.get()
if d is None:
break
yield d
del t
del queue
def __len__(self):
return len(self.dataloader)

Просмотреть файл

@ -0,0 +1,76 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# Author: Pengcheng He (penhe@microsoft.com)
# Date: 05/15/2019
#
import os
import numpy as np
import math
import sys
from torch.utils.data import Sampler
__all__=['BatchSampler', 'DistributedBatchSampler', 'RandomSampler', 'SequentialSampler']
class BatchSampler(Sampler):
def __init__(self, sampler, batch_size):
self.sampler = sampler
self.batch_size = batch_size
def __iter__(self):
batch = []
for idx in self.sampler:
batch.append(idx)
if len(batch)==self.batch_size:
yield batch
batch = []
if len(batch)>0:
yield batch
def __len__(self):
return (len(self.sampler) + self.batch_size - 1)//self.batch_size
class DistributedBatchSampler(Sampler):
def __init__(self, sampler, rank=0, world_size = 1, drop_last = False):
self.sampler = sampler
self.rank = rank
self.world_size = world_size
self.drop_last = drop_last
def __iter__(self):
for b in self.sampler:
if len(b)%self.world_size != 0:
if self.drop_last:
break
else:
b.extend([b[0] for _ in range(self.world_size-len(b)%self.world_size)])
chunk_size = len(b)//self.world_size
yield b[self.rank*chunk_size:(self.rank+1)*chunk_size]
def __len__(self):
return len(self.sampler)
class RandomSampler(Sampler):
def __init__(self, total_samples:int, data_seed:int = 0):
self.indices = np.array(np.arange(total_samples))
self.rng = np.random.RandomState(data_seed)
def __iter__(self):
self.rng.shuffle(self.indices)
for i in self.indices:
yield i
def __len__(self):
return len(self.indices)
class SequentialSampler(Sampler):
def __init__(self, total_samples:int):
self.indices = np.array(np.arange(total_samples))
def __iter__(self):
for i in self.indices:
yield i
def __len__(self):
return len(self.indices)

16
DeBERTa/optims/__init__.py Executable file
Просмотреть файл

@ -0,0 +1,16 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# Author: Pengcheng He (penhe@microsoft.com)
# Date: 05/15/2019
#
""" optimizers
"""
from .xadam import XAdam
from .fp16_optimizer import *
from .lr_schedulers import SCHEDULES
from .args import get_args

100
DeBERTa/optims/args.py Normal file
Просмотреть файл

@ -0,0 +1,100 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# Author: Pengcheng He (penhe@microsoft.com)
# Date: 05/15/2019
#
""" Arguments for optimizer
"""
import argparse
from ..utils import boolean_string
__all__ = ['get_args']
def get_args():
parser=argparse.ArgumentParser(add_help=False, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
group = parser.add_argument_group(title='Optimizer', description='Parameters for the distributed optimizer')
group.add_argument('--fp16',
default=False,
type=boolean_string,
help="Whether to use 16-bit float precision instead of 32-bit")
group.add_argument('--loss_scale',
type=float, default=16384,
help='Loss scaling, positive power of 2 values can improve fp16 convergence.')
group.add_argument('--scale_steps',
type=int, default=250,
help='The steps to wait to increase the loss scale.')
group.add_argument('--lookahead_k',
default=-1,
type=int,
help="lookahead k parameter")
group.add_argument('--lookahead_alpha',
default=0.5,
type=float,
help="lookahead alpha parameter")
group.add_argument('--with_radam',
default=False,
type=boolean_string,
help="whether to use RAdam")
group.add_argument('--opt_type',
type=str.lower,
default='adam',
choices=['adam', 'admax'],
help="The optimizer to be used.")
group.add_argument("--warmup_proportion",
default=0.1,
type=float,
help="Proportion of training to perform linear learning rate warmup for. "
"E.g., 0.1 = 10%% of training.")
group.add_argument("--lr_schedule_ends",
default=0,
type=float,
help="The ended learning rate scale for learning rate scheduling")
group.add_argument("--lr_schedule",
default='warmup_linear',
type=str,
help="The learning rate scheduler used for traning. " +
"E.g. warmup_linear, warmup_linear_shift, warmup_cosine, warmup_constant. Default, warmup_linear")
group.add_argument("--max_grad_norm",
default=1,
type=float,
help="The clip threshold of global gradient norm")
group.add_argument("--learning_rate",
default=5e-5,
type=float,
help="The initial learning rate for Adam.")
group.add_argument("--epsilon",
default=1e-6,
type=float,
help="epsilon setting for Adam.")
group.add_argument("--adam_beta1",
default=0.9,
type=float,
help="The beta1 parameter for Adam.")
group.add_argument("--adam_beta2",
default=0.999,
type=float,
help="The beta2 parameter for Adam.")
group.add_argument('--weight_decay',
type=float,
default=0.01,
help="The weight decay rate")
return parser

293
DeBERTa/optims/fp16_optimizer.py Executable file
Просмотреть файл

@ -0,0 +1,293 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# Author: Pengcheng He (penhe@microsoft.com)
# Date: 05/15/2019
#
""" FP16 optimizer wrapper
"""
from collections import defaultdict
import numpy as np
import math
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
import ctypes
from ..utils import get_logger,boolean_string
logger=get_logger()
__all__ = ['Fp16Optimizer', 'ExpLossScaler', 'get_world_size']
def get_world_size():
try:
wd = dist.get_world_size()
return wd
except:
return 1
def fused_norm(input):
return torch.norm(input, p=2, dtype=torch.float32)
class OptParameter(torch.Tensor):
def __new__(cls, data, out_data=None, grad=None, name=None):
param = torch.Tensor._make_subclass(cls, data)
param._xgrad = grad
param.out_data = out_data
param._name = name
return param
@property
def name(self):
return self._name
@property
def grad(self):
return self._xgrad
@grad.setter
def grad(self, grad):
self._xgrad = grad
class Fp16Optimizer(object):
def __init__(self, param_groups, optimizer_fn, loss_scaler=None, grad_clip_norm = 1.0, lookahead_k = -1, lookahead_alpha = 0.5, rank=-1, distributed=False):
# all parameters should on the same device
groups = []
original_groups = []
self.rank = rank
self.distributed = distributed
if self.rank<0:
self.distributed = False
for group in param_groups:
if 'offset' not in group:
group['offset'] = None
if ('rank' not in group) or (not self.distributed):
group['rank'] = -1
assert group['offset'] is None, f"{group['names']}: {group['offset']}"
group_rank = group['rank']
params = group['params'] # parameter
if len(params) > 1:
flattened_params = _flatten_dense_tensors([p.data for p in params])
unflattend_params = _unflatten_dense_tensors(flattened_params, [p.data for p in params])
for uf,p in zip(unflattend_params, params):
p.data = uf
else:
flattened_params = params[0].data.view(-1)
if group['offset'] is not None:
start, length = group['offset']
flattened_params = flattened_params.narrow(0, start, length)
if params[0].dtype==torch.half:
if self.rank == group_rank or (not self.distributed):
master_params = flattened_params.clone().to(torch.float).detach_().to(flattened_params.device)
else:
master_params = flattened_params.clone().to(torch.float).detach_().cpu()
group['params'] = [OptParameter(master_params, flattened_params, name='master')]
else:
group['params'] = [OptParameter(flattened_params, None, name='master')]
o_group = defaultdict(list)
o_group['names'] = group['names']
o_group['params'] = params
o_group['rank'] = group_rank
o_group['offset'] = group['offset']
group['names'] = ['master']
original_groups.append(o_group)
groups.append(group)
self.param_groups = groups
self.loss_scaler = loss_scaler
self.optimizer = optimizer_fn(self.param_groups)
self.original_param_groups = original_groups
self.max_grad_norm = grad_clip_norm
self.lookahead_k = lookahead_k
self.lookahead_alpha = lookahead_alpha
def backward(self, loss):
if self.loss_scaler:
loss_scale, loss, step_loss = self.loss_scaler.scale(loss)
else:
loss_scale = 1
step_loss = loss.item()
loss.backward()
return loss_scale, step_loss
def step(self, lr_scale, loss_scale = 1):
grad_scale = self._grad_scale(loss_scale)
if grad_scale is None or math.isinf(grad_scale):
self.loss_scaler.update(False)
return False
if self.lookahead_k > 0:
for p in self.param_groups:
if 'la_count' not in p:
# init
#make old copy
p['la_count'] = 0
p['slow_params'] = [x.data.detach().clone().requires_grad_(False) for x in p['params']]
self.optimizer.step(grad_scale, lr_scale)
if self.lookahead_k > 0:
for p in self.param_groups:
p['la_count'] += 1
if p['la_count'] == self.lookahead_k:
p['la_count'] = 0
for s,f in zip(p['slow_params'], p['params']):
s.mul_(1-self.lookahead_alpha)
s.add_(f.data.detach()*self.lookahead_alpha)
f.data.copy_(s, non_blocking=True)
if hasattr(f, 'out_data') and f.out_data is not None:
f.out_data.copy_(f.data, non_blocking=True)
if self.loss_scaler:
self.loss_scaler.update(True)
return True
def zero_grad(self):
for group, o_group in zip(self.param_groups, self.original_param_groups):
for p in group['params']:
p.grad = None
for p in o_group['params']:
p.grad = None
def _grad_scale(self, loss_scale = 1):
named_params = {}
named_grads = {}
for g in self.original_param_groups:
for n,p in zip(g['names'], g['params']):
named_params[n] = p
named_grads[n] = p.grad if p.grad is not None else torch.zeros_like(p.data)
wd = get_world_size()
def _reduce(group):
grads = [named_grads[n] for n in group]
if len(grads)>1:
flattened_grads = _flatten_dense_tensors(grads)
else:
flattened_grads = grads[0],view(-1)
if wd > 1:
flattened_grads /= wd
handle = dist.all_reduce(flattened_grads, async_op=True)
else:
handle = None
return flattened_grads, handle
def _process_grad(group, flattened_grads, max_grad, norm):
grads = [named_grads[n] for n in group]
norm = norm.to(flattened_grads.device)
norm = norm + fused_norm(flattened_grads)**2
if len(grads) > 1:
unflattend_grads = _unflatten_dense_tensors(flattened_grads, grads)
else:
unflattend_grads = [flattened_grads]
for n,ug in zip(group, unflattend_grads):
named_grads[n] = ug #.to(named_params[n].data)
return max_grad, norm
group_size = 0
group = []
max_size = 32*1024*1024
norm = torch.zeros(1, dtype=torch.float)
max_grad = 0
all_grads = []
for name in sorted(named_params.keys()):
group.append(name)
group_size += named_params[name].data.numel()
if group_size>=max_size:
flatten, handle = _reduce(group)
all_grads.append([handle, flatten, group])
group = []
group_size = 0
if group_size>0:
flatten, handle = _reduce(group)
all_grads.append([handle, flatten, group])
group = []
group_size = 0
for h,fg,group in all_grads:
if h is not None:
h.wait()
max_grad, norm = _process_grad(group, fg, max_grad, norm)
norm = norm**0.5
if torch.isnan(norm) or torch.isinf(norm) :#in ['-inf', 'inf', 'nan']:
return None
scaled_norm = norm.detach().item()/loss_scale
grad_scale = loss_scale
if self.max_grad_norm>0:
scale = norm/(loss_scale*self.max_grad_norm)
if scale>1:
grad_scale *= scale
for group, o_g in zip(self.param_groups, self.original_param_groups):
grads = [named_grads[n] for n in o_g['names']]
if len(grads) > 1:
flattened_grads = _flatten_dense_tensors(grads)
else:
flattened_grads = grads[0].view(-1)
if group['offset'] is not None:
start, length = group['offset']
flattened_grads = flattened_grads.narrow(0, start, length)
if group['rank'] == self.rank or (not self.distributed):
group['params'][0].grad = flattened_grads
return grad_scale
class ExpLossScaler:
def __init__(self, init_scale=2**16, scale_interval=1000):
self.cur_scale = init_scale
self.scale_interval = scale_interval
self.invalid_cnt = 0
self.last_scale = 0
self.steps = 0
self.down_scale_smooth = 0
def scale(self, loss):
assert self.cur_scale > 0, self.init_scale
step_loss = loss.float().detach().item()
if step_loss != 0 and math.isfinite(step_loss):
loss_scale = self.cur_scale
else:
loss_scale = 1
loss = loss.float()*loss_scale
return (loss_scale, loss, step_loss)
def update(self, is_valid = True):
if not is_valid:
self.invalid_cnt += 1
if self.invalid_cnt>self.down_scale_smooth:
self.cur_scale /= 2
self.cur_scale = max(self.cur_scale, 1)
self.last_scale = self.steps
else:
self.invalid_cnt = 0
if self.steps - self.last_scale>self.scale_interval:
self.cur_scale *= 2
self.last_scale = self.steps
self.steps += 1
def state_dict(self):
state = defaultdict(float)
state['steps'] = self.steps
state['invalid_cnt'] = self.invalid_cnt
state['cur_scale'] = self.cur_scale
state['last_scale'] = self.last_scale
return state
def load_state_dict(self, state):
self.steps = state['steps']
self.invalid_cnt = state['invalid_cnt']
self.cur_scale = state['cur_scale']
self.last_scale = state['last_scale']

63
DeBERTa/optims/lr_schedulers.py Executable file
Просмотреть файл

@ -0,0 +1,63 @@
""" Learning rate schedulers
"""
import math
import torch
from torch.optim import Optimizer
from torch.nn.utils import clip_grad_norm_
def warmup_cosine(step, total, warmup=0.002, ends = 0):
x = step/total
x = x-int(x)
if x < warmup:
return x/warmup
return 0.5 * (1.0 + math.cos(math.pi * x))
def warmup_constant(step, total, warmup=0.002, ends = 0):
x = step/total
x = x-int(x)
if x < warmup:
return x/warmup
return 1.0
def warmup_linear(step, total, warmup=0.002, ends = 0):
x = step/total
x = x-int(x)
if x < warmup:
return x/warmup
return (1-ends)*(1.0 - x) + ends
def warmup_linear_cosine(step, total, warmup=0.002, ends = 0):
x = step/total
x = x-int(x)
if x < warmup:
return x/warmup
return (1-ends)*max(0.5*(1+math.cos(math.pi*(x-warmup)/(1-warmup))), 0) + ends
def warmup_cyclic_linear_cosine(step, total, warmup=0.002, ends = 0):
x = step/total
if x < warmup:
return x/warmup
total = total - int(total*warmup)
step = step - int(total*warmup)
n_epoch = 4
period = total//n_epoch
k = step//period
s = 1-k/n_epoch + 1/(2*n_epoch)*(math.pow(-1, k)*math.cos(math.pi*step/period)-1)
return (1-ends)*max(s, 0) + ends
def warmup_linear_shift(step, total, warmup=0.002, ends = 0):
x = step/total
x = x-int(x)
if x < warmup:
return x/warmup
return (1-ends)*(1.0 - (x-warmup)/(1-warmup)) + ends
SCHEDULES = {
'warmup_cosine':warmup_cosine,
'warmup_constant':warmup_constant,
'warmup_linear':warmup_linear,
'warmup_linear_cosine':warmup_linear_cosine,
'warmup_cyclic_linear_cosine':warmup_cyclic_linear_cosine,
'warmup_linear_shift':warmup_linear_shift,
}

217
DeBERTa/optims/xadam.py Executable file
Просмотреть файл

@ -0,0 +1,217 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# Author: Pengcheng He (penhe@microsoft.com)
# Date: 05/15/2019
#
""" Optimizer
"""
import math
import torch
from torch.optim import Optimizer
from torch.nn.utils import clip_grad_norm_
from torch import distributed as dist
import pdb
from .lr_schedulers import SCHEDULES
from ..utils import get_logger
def adamw(data,
out_data,
next_m,
next_v,
grad,
lr,
beta1,
beta2,
eps,
grad_scale, #combined_scale, g = g/scale
step,
eps_mode = 1, #self.eps_mode, esp inside sqrt:0, outside: 1, only update with momentum: 2
bias_correction = 0,
weight_decay = 0):
if bias_correction > 0:
lr *= bias_correction
beta1_ = 1 - beta1
beta2_ = 1 - beta2
grad = grad.float()
if grad_scale != 1:
grad *= 1/grad_scale
grad = grad.to(next_m)
next_m.mul_(beta1).add_(beta1_, grad)
# admax
admax = eps_mode>>4
eps_mode = eps_mode&0xF
if admax > 0:
torch.max(next_v.mul_(beta2), grad.abs().to(next_v), out=next_v)
update = next_m/(next_v+eps)
else:
next_v.mul_(beta2).addcmul_(beta2_, grad, grad)
if eps_mode == 0:
update = (next_m)*(next_v+eps).rsqrt()
elif eps_mode == 1:
update = (next_m)/(next_v.sqrt()+eps)
else: #=2
update = next_m.clone()
if weight_decay>0:
update.add_(weight_decay, data)
data.add_(-lr, update)
if (out_data is not None) and len(out_data)>0:
out_data.copy_(data)
class XAdam(Optimizer):
"""Implements optimized version of Adam algorithm with weight decay fix.
Params:
lr: learning rate
warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
t_total: total number of training steps for the learning
rate schedule, -1 means constant learning rate. Default: -1
schedule: schedule to use for the warmup (see above). Default: 'warmup_linear'
b1: Adams b1. Default: 0.9
b2: Adams b2. Default: 0.999
e: Adams epsilon. Default: 1e-6
weight_decay_rate: Weight decay. Default: 0.01
max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
with_radam: Whether to enable radam. Default: False
radam_th: RAdam threshold for tractable variance. Default: 4
opt_type: The type of optimizer, [adam, admax], default: adam
"""
def __init__(self, params, lr, warmup=-1, t_total=-1, schedule='warmup_linear',
b1=0.9, b2=0.999, e=1e-8, weight_decay_rate=0.01,
lr_ends = 0,
max_grad_norm = 1.0,
with_radam = False,
radam_th = 4,
opt_type=None,
rank = -1):
if not lr >= 0.0:
raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
if schedule not in SCHEDULES:
raise ValueError("Invalid schedule parameter: {}".format(schedule))
if not 0.0 <= warmup < 1.0 and not warmup == -1:
raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
if not 0.0 <= b1 < 1.0:
raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
if not 0.0 <= b2 < 1.0:
raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
if not e >= 0.0:
raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
self.defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
b1=b1, b2=b2, e=e, weight_decay_rate=weight_decay_rate,
lr_ends = lr_ends,
max_grad_norm=max_grad_norm,
with_radam = with_radam, radam_th = radam_th)
self.opt_type = opt_type.lower() if opt_type is not None else ""
self.rank = rank
super().__init__(params, self.defaults)
def step(self, grad_scale = 1, lr_scale = 1):
"""Performs a single optimization step.
Arguments:
grad_scale: divid grad by grad_scale
lr_scale: scale learning rate by bs_scale
"""
if 'global_step' not in self.state:
self.state['global_step'] = 0
for group in self.param_groups:
lr_sch = self.get_group_lr_sch(group, self.state['global_step'])
if group['rank'] == self.rank or group['rank']<0 or self.rank<0:
for param in group['params']:
self.update_param(group, param, grad_scale, lr_scale)
self.state['global_step'] += 1
self.last_grad_scale = grad_scale
handels = []
for group in self.param_groups:
if group['rank']>=0 and self.rank>=0:
# sync
for param in group['params']:
out_p = param.out_data if hasattr(param, 'out_data') and (param.out_data is not None) else None
if out_p is not None:
h = torch.distributed.broadcast(out_p, group['rank'], async_op=True)
else:
h = torch.distributed.broadcast(param.data, group['rank'], async_op=True)
handels.append(h)
for h in handels:
if h is not None:
h.wait()
return lr_sch
def get_group_lr_sch(self, group, steps):
if group['t_total'] > 0:
schedule_fct = SCHEDULES[group['schedule']]
lr_scheduled = schedule_fct(steps, group['t_total'], group['warmup'], group['lr_ends'])
else:
lr_scheduled = 1
return lr_scheduled
def update_param(self, group, param, grad_scale, lr_scale):
grad = param.grad
if grad.is_sparse:
raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
state = self.get_state(param)
lr_sch = self.get_group_lr_sch(group, state['step'])
lr = group['lr'] * lr_scale *lr_sch
next_m, next_v = state['next_m'], state['next_v']
beta1, beta2 = group['b1'], group['b2']
state['step'] += 1
# Support for RAdam
t = (state['step']-1) + 1
eps_mode = 1
if group['with_radam']:
rou_ = 2/(1-beta2) - 1
rou_t = rou_ - 2*t/(beta2**-t - 1)
bias_c = 1/(1-beta1**t)
if rou_t > group['radam_th']:
bias_c *= math.sqrt(1 - beta2**t)
bias_c *= math.sqrt(((rou_t - 4)*(rou_t - 2)*rou_)/((rou_ - 4)*(rou_ - 2)*rou_t))
else:
eps_mode = 2
bias_c = 0
lr *= bias_c
if self.opt_type == 'admax':
eps_mode |= 0x10
with torch.cuda.device(param.device.index):
out_p = param.out_data if hasattr(param, 'out_data') and (param.out_data is not None) else None
if out_p is None or out_p.dtype != grad.dtype:
out_p = torch.tensor([], dtype=torch.float).to(param.data)
weight_decay = group['weight_decay_rate']
adamw(param.data,
out_p,
next_m,
next_v,
grad,
lr,
beta1,
beta2,
group['e'],
grad_scale, #combined_scale, g = g/scale
state['step'],
eps_mode, #self.eps_mode, esp inside sqrt:0, outside: 1, only update with momentum: 2
0, #bias_correction,
weight_decay)
out_p = param.out_data if hasattr(param, 'out_data') and (param.out_data is not None) else None
if out_p is not None and out_p.dtype != grad.dtype:
out_p.copy_(param.data)
def get_state(self, param):
state = self.state[param]
# State initialization
if len(state) == 0:
state['step'] = 0
state['next_m'] = torch.zeros_like(param.data)
state['next_v'] = torch.zeros_like(param.data)
return state

4
DeBERTa/training/__init__.py Executable file
Просмотреть файл

@ -0,0 +1,4 @@
from .trainer import DistributedTrainer, set_random_seed
from .args import get_args
from .dist_launcher import initialize_distributed,kill_children
from ._utils import batch_to,batch_apply

16
DeBERTa/training/_utils.py Executable file
Просмотреть файл

@ -0,0 +1,16 @@
import torch
from collections import Sequence, Mapping
def batch_apply(batch, fn):
if isinstance(batch, torch.Tensor):
return fn(batch)
elif isinstance(batch, Sequence):
return [batch_apply(x, fn) for x in batch]
elif isinstance(batch, Mapping):
return {x:batch_apply(batch[x], fn) for x in batch}
else:
raise NotImplementedError(f'Type of {type(batch)} are not supported in batch_apply')
def batch_to(batch, device):
return batch_apply(batch, lambda x: x.to(device))

72
DeBERTa/training/args.py Normal file
Просмотреть файл

@ -0,0 +1,72 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# Author: Pengcheng He (penhe@microsoft.com)
# Date: 05/15/2019
#
import argparse
from ..utils import boolean_string
__all__ = ['get_args']
def get_args():
parser=argparse.ArgumentParser(add_help=False, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
group = parser.add_argument_group(title='Trainer', description='Parameters for the distributed trainer')
group.add_argument('--accumulative_update',
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.")
group.add_argument("--dump_interval",
default=1000,
type=int,
help="Interval steps for generating checkpoint.")
group.add_argument("--local_rank",
type=int,
default=-1,
help="local_rank for distributed training on gpus")
group.add_argument('--workers',
type=int,
default=2,
help="The workers to load data.")
group.add_argument("--num_train_epochs",
default=3.0,
type=float,
help="Total number of training epochs to perform.")
group.add_argument('--seed',
type=int,
default=1234,
help="random seed for initialization")
group.add_argument("--train_batch_size",
default=64,
type=int,
help="Total batch size for training.")
group.add_argument("--world_size",
type=int,
default=-1,
help="[Internal] The world size of distributed training. Internal usage only!! To the world size of the program, you need to use environment. 'WORLD_SIZE'")
group.add_argument("--rank",
type=int,
default=-1,
help="[Internal] The rank id of current process. Internal usage only!! To the rank of the program, you need to use environment. 'RANK'")
group.add_argument("--master_ip",
type=str,
default=None,
help="[Internal] The ip address of master node. Internal usage only!! To the master IP of the program, you need to use environment. 'MASTER_ADDR'")
group.add_argument("--master_port",
type=str,
default=None,
help="[Internal] The port of master node. Internal usage only!! To the master IP of the program, you need to use environment. 'MASTER_PORT'")
return parser

163
DeBERTa/training/dist_launcher.py Executable file
Просмотреть файл

@ -0,0 +1,163 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# Author: Pengcheng He (penhe@microsoft.com)
# Date: 05/15/2019
#
import os
import time
import pdb
import signal
import torch
from multiprocessing import Process,Pool
from collections import defaultdict
import sys
import psutil
from ..utils import set_logger, get_logger
logger = get_logger()
def kill_children(proc=None, recursive = True):
if proc is None:
proc = psutil.Process()
_children = proc.children(recursive=False)
for c in _children:
try:
if recursive:
kill_children(c, recursive=recursive)
os.kill(c.pid, signal.SIGKILL)
except:
pass
for c in _children:
try:
c.wait(1)
except:
pass
def gc(i):
return torch.cuda.device_count()
def get_ngpu():
with Pool(1) as p:
return p.map(gc, range(1))[0]
def _setup_distributed_group(args):
"""Initialize torch.distributed."""
torch.backends.cudnn.enabled = False
if args.world_size == 1:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
set_logger(args.task_name, os.path.join(args.output_dir, f'training_{args.task_name}_{args.rank}.log'), rank=args.rank, verbose=1 if args.local_rank==0 else 0)
device_id = args.rank % args.n_gpu
if args.local_rank >= 0:
device_id = args.local_rank
device = torch.device("cuda", device_id)
init_method = 'tcp://'
init_method += args.master_ip + ':' + args.master_port
distributed_backend = getattr(args, 'distributed_backend', 'nccl')
torch.distributed.init_process_group(
backend=distributed_backend,
world_size=args.world_size, rank=args.rank,
init_method=init_method)
torch.cuda.set_device(device)
n_gpu = torch.cuda.device_count()
logger.info("device=%s, n_gpu=%d, distributed training=%r, world_size=%d", device, n_gpu, bool(args.world_size != 1), args.world_size)
return device
def _get_world_size(args):
world_size = int(os.getenv("WORLD_SIZE", '1'))
if not hasattr(args, 'n_gpu') or args.n_gpu is None:
n_gpu = get_ngpu()
return n_gpu * world_size
def initialize_distributed(args, join=True):
args.world_size = int(os.getenv("WORLD_SIZE", '1'))
args.rank = int(os.getenv('RANK', '0'))
args.master_ip = os.getenv('MASTER_ADDR', 'localhost')
args.master_port = os.getenv('MASTER_PORT', '17006')
if args.world_size == 1:
args.rank = 0
args.master_ip = 'localhost'
if not hasattr(args, 'n_gpu') or args.n_gpu is None:
args.n_gpu = get_ngpu()
args.node_rank = args.rank
args.world_size = args.n_gpu * args.world_size
seed = args.seed
is_child = False
if args.world_size>1:
children = []
for r in range(args.n_gpu):
args.rank = r + args.n_gpu*args.node_rank
args.local_rank = r
args.seed = seed + args.rank
child = os.fork()
if child>0:
children.append(child)
else:
signal.signal(signal.SIGINT, signal.SIG_IGN)
is_child = True
break
else:
is_child = True
if is_child:
return _setup_distributed_group(args)
else:
if join:
try:
for c in children:
cid, ccode = os.waitpid(0,0)
logger.debug(f'Worker {c} done with code {ccode}')
if ccode != 0:
logger.error(f'Worker {c} : {cid} failed with code {ccode}')
kill_children()
raise ValueError(f'Job failed. {cid}:{ccode}')
except (KeyboardInterrupt, SystemExit):
logger.warning('Keybord interrupt by user. Terminate all processes')
kill_children(None)
return children
def test_dist_launch():
def test_functions(args):
global logger
set_logger(args.task_name, os.path.join(args.output_dir, f'training_{args.task_name}_{args.node_rank}.log'), rank=args.rank)
logger.info(args)
class Args:
def __init__(self):
pass
def __repr__(self):
return str(self.__dict__)
args = Args()
args.task_name = 'test'
args.seed = 0
args.n_gpu = None
args.no_cuda=False
args.output_dir = '/tmp'
distributed_launch(args, test_functions, (args,))
def test_init_dist():
class Args:
def __init__(self):
pass
def __repr__(self):
return str(self.__dict__)
args = Args()
args.task_name = 'test'
args.seed = 0
args.n_gpu = None
args.no_cuda=False
args.output_dir = '/tmp'
device = initialize_distributed(args)
if isinstance(device, torch.device):
return 0
else:
return 1

Просмотреть файл

@ -0,0 +1,170 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# Author: Pengcheng He (penhe@microsoft.com)
# Date: 05/15/2019
#
from collections import defaultdict
import numpy as np
import pdb
from functools import cmp_to_key
import torch
import re
from ..optims import Fp16Optimizer,XAdam,ExpLossScaler,get_world_size
from ..utils import get_logger
logger=get_logger()
def xadam_factory(args, training_steps=None):
def optimizer_fn(param_groups, max_grad_norm=None):
with_radam = getattr(args, 'with_radam', False)
opt_type = getattr(args, 'opt_type', None)
optimizer = XAdam(param_groups,
lr=args.learning_rate,
b1=args.adam_beta1,
b2=args.adam_beta2,
lr_ends=args.lr_schedule_ends,
e=args.epsilon,
warmup=args.warmup_proportion if args.warmup_proportion<1 else args.warmup_proportion/training_steps,
t_total=training_steps,
schedule=args.lr_schedule,
max_grad_norm = args.max_grad_norm if max_grad_norm is None else max_grad_norm,
weight_decay_rate = args.weight_decay,
with_radam = with_radam,
opt_type = opt_type,
rank = args.rank)
return optimizer
return optimizer_fn
def create_xoptimizer(model, args, num_train_steps=None, no_decay=['bias', 'LayerNorm.weight']):
if args.fp16:
loss_scaler = ExpLossScaler(scale_interval = args.scale_steps, init_scale=args.loss_scale)
else:
loss_scaler = None
distributed_optimizer = getattr(args, 'distributed_optimizer', True)
max_distributed_groups = getattr(args, 'max_distributed_groups', 1000000)
world_size = get_world_size()
if world_size<=1:
distributed_optimizer = False
_no_decay = [x.strip() for x in getattr(args, 'no_decay', '').split('|') if len(x.strip())>0]
if len(_no_decay)>0:
no_decay = _no_decay
opt_fn = xadam_factory(args, num_train_steps)
named_params = list(model.named_parameters())
param_size = [p.numel() for n,p in named_params]
type_groups = defaultdict(list)
if distributed_optimizer:
num_groups = min(world_size, max_distributed_groups)
max_group_size = (sum(param_size)+num_groups-1)//num_groups
#max_group_size = max(64*1024*1024, max_group_size)
#max_group_size = max_group_size//2
max_group_size = (max_group_size//32)*32
group_sizes = [0 for _ in range(num_groups)]
group_ranks = [g*(world_size//num_groups) for g in range(num_groups)]
else:
# TODO: Fix inconsistent results with different group size
max_group_size = max(64*1024*1024, max(param_size))
num_groups = (sum(param_size)+max_group_size-1)//max_group_size
group_sizes = [0 for _ in range(num_groups)]
def get_smallest_group(group_sizes):
return np.argmin([g+i/10000 for i,g in enumerate(group_sizes)])
def chunk_into_pieces(param, max_size):
num_chunks = param.numel()//max_size
if num_chunks<2:
return [param], [None]
flat = param.view(-1)
chunks=[]
offsets = []
for i in range(num_chunks-1):
chunks.append(flat.narrow(0, i*max_size, max_size))
offsets.append([i*max_size, max_size])
i += 1
chunks.append(flat.narrow(0, i*max_size, flat.size(0)-i*max_size))
offsets.append([i*max_size, flat.size(0)-i*max_size])
assert sum([c.numel() for c in chunks])==param.numel(), f'{param.numel()}: {offsets}'
return chunks, offsets
def param_cmp(x,y):
n1,p1 = x
n2,p2 = y
if p1.numel() == p2.numel():
if n1<n2:
return -1
elif n1>n2:
return 1
else:
return 0
else:
return p1.numel() - p2.numel()
def add_group(param_groups, group, group_id):
if distributed_optimizer:
group['rank'] = group_ranks[group_id]
param_groups.append(group.copy())
group['params'] = []
group['names'] = []
group['offset'] = None
return get_smallest_group(group_sizes),group
hard_reset = getattr(args, 'hard_reset', False)
group_id = 0
for n,p in named_params:
key = ''
if any(re.search(nd,n) for nd in no_decay):
key += f'{str(p.dtype)}-nd'
else:
key += f'{str(p.dtype)}-d'
type_groups[key].append((n,p))
param_groups = []
for key, params in type_groups.items():
wd_theta = 0
weight_decay = args.weight_decay
_hard_reset = False
if key.endswith('-nd'):
weight_decay = 0
else:
_hard_reset = hard_reset
group = dict(params=[],
weight_decay_rate=weight_decay,
wd_theta = wd_theta,
hard_reset = hard_reset,
names=[],
offset=None)
params = sorted(params, key=cmp_to_key(param_cmp))
for (n,p) in params:
if p.numel() >= max_group_size:
if len(group['params'])>0:
group_id,group = add_group(param_groups, group, group_id)
chunks, offsets = chunk_into_pieces(p, max_group_size)
for chk, off in zip(chunks, offsets):
group['params'].append(p)
group['names'].append(n)
group['offset'] = off
group_sizes[group_id] += chk.numel()
group_id,group = add_group(param_groups, group, group_id)
else:
group['params'].append(p)
group['names'].append(n)
group['offset'] = None
group_sizes[group_id] += p.numel()
if group_sizes[group_id]>=max_group_size:
group_id,group = add_group(param_groups, group, group_id)
if len(group['params'])>0:
group_id,group = add_group(param_groups, group, group_id)
lookahead_k = getattr(args, 'lookahead_k', -1)
lookahead_alpha = getattr(args, 'lookahead_alpha', 0.5)
optimizer = Fp16Optimizer(param_groups, opt_fn, loss_scaler, args.max_grad_norm, lookahead_k = lookahead_k,\
lookahead_alpha = lookahead_alpha, rank=args.rank, distributed=distributed_optimizer)
return optimizer

242
DeBERTa/training/trainer.py Executable file
Просмотреть файл

@ -0,0 +1,242 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# Author: Pengcheng He (penhe@microsoft.com)
# Date: 05/15/2019
#
import os
import torch
import random
import time
import numpy as np
import pdb
from collections import defaultdict, Mapping, Sequence, OrderedDict
from torch.utils.data import DataLoader
from ..data import BatchSampler, DistributedBatchSampler,RandomSampler,SequentialSampler, AsyncDataLoader
from ..utils import get_logger
logger = get_logger()
from .dist_launcher import get_ngpu
from .optimizer_utils import create_xoptimizer
from ._utils import batch_to
__all__ = ['DistributedTrainer', 'set_random_seed']
def set_random_seed(seed, cpu_only=False):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
n_gpu = get_ngpu()
if n_gpu > 0 and not cpu_only:
torch.cuda.manual_seed_all(seed)
class TrainerState:
def __init__(self, training_steps, name=None):
self.__dict__ = defaultdict(float)
self.loss = 0.0
self.examples = 0
self.steps = 0
self._last_report_step = 0
self.epochs = 0
self.next_batch = 0
self.num_training_steps = training_steps
self._last_report_time = time.time()
self.best_steps = 0
self.best_metric = -1e9
self.name = name
self.run_id = None
def update_step(self, loss, examples, loss_scale):
self.examples += examples
self.loss += loss
self.steps += 1
self.next_batch += 1
self.loss_scale = loss_scale
def report_state(self):
if self.steps <= self._last_report_step:
return
end = time.time()
start = self._last_report_time
if self.name is not None:
tag = f'[{self.name}]'
else:
tag = None
logger.info('{}[{:0.1f}%][{:0.2f}h] Steps={}, loss={}, examples={}, loss_scale={:0.1f}, {:0.1f}s'.format(tag, 100*self.steps/self.num_training_steps, \
(self.num_training_steps - self.steps)*(start-end)/((self.steps-self._last_report_step)*3600), self.steps, self.loss/self.steps, self.examples, self.loss_scale, end-start))
self._last_report_time = end
self._last_report_step = self.steps
class DistributedTrainer:
def __init__(self, args, output_dir, model, device, data_fn, loss_fn=None, optimizer_fn=None, eval_fn=None, init_fn=None, update_fn=None, dump_interval = 10000, name=None, **kwargs):
"""
data_fn return tuples (training_dataset, training_steps, train_sampler, batch_scheduler), training_dataset is required
loss_fn return the loss of current mini-batch and the size of the batch
optimizer_fn return the created optimizer
eval_fn return metrics for model selection
"""
self.__dict__.update(kwargs)
self.args = args
self.device = device
self.eval_fn = eval_fn
self.accumulative_update = 1
if hasattr(args, 'accumulative_update'):
self.accumulative_update = args.accumulative_update
train_data, training_steps, train_sampler = data_fn(self)
self.train_data = train_data
self.train_sampler = train_sampler if train_sampler is not None else RandomSampler(len(train_data))
self.training_epochs = int(getattr(args, 'num_train_epochs', 1))
if training_steps is None:
training_steps = getattr(args, 'training_steps', (len(training_data) + self.args.train_batch_size-1)//self.args.train_batch_size*self.training_epochs)
self.training_steps = training_steps
self.output_dir = output_dir
self.init_fn = init_fn
self.trainer_state = TrainerState(self.training_steps, name = name)
self.dump_interval = dump_interval
self.model = self._setup_model(args, model)
self.post_loss_fn = None
def _opt_fn(trainer, model, training_steps):
return create_xoptimizer(model, args, num_train_steps = training_steps)
optimizer_fn = optimizer_fn if optimizer_fn is not None else _opt_fn
self.optimizer = optimizer_fn(self, model, training_steps)
def _loss_fn(trainer, model, batch):
_,loss = model(**batch)
batch_size = batch['input_ids'].size(0)
return loss.mean(), batch_size
self.loss_fn = loss_fn if loss_fn is not None else _loss_fn
self.initialized = False
self.update_fn = update_fn
def initialize(self):
set_random_seed(self.args.seed)
if self.args.world_size>1:
torch.distributed.barrier()
self.initialized = True
def train(self):
if not self.initialized:
self.initialize()
rank = self.args.rank
world_size = self.args.world_size
for n_epoch in range(self.trainer_state.epochs, self.training_epochs):
batch_sampler = BatchSampler(self.train_sampler, self.args.train_batch_size)
batch_sampler = DistributedBatchSampler(batch_sampler, rank = rank, world_size = world_size)
batch_sampler.next = self.trainer_state.next_batch
num_workers = getattr(self.args, 'workers', 2)
train_dataloader = DataLoader(self.train_data, batch_sampler=batch_sampler, num_workers=num_workers, worker_init_fn=self.init_fn, pin_memory=True)
torch.cuda.empty_cache()
for step, batch in enumerate(AsyncDataLoader(train_dataloader, 100)):
if self.trainer_state.steps >= self.training_steps:
break
bs_scale = 1
batch = batch_to(batch, self.device)
self._train_step(batch, bs_scale)
# Save model
self.trainer_state.epochs += 1
self.trainer_state.next_batch = 0
self.trainer_state.report_state()
self._eval_model()
def save_model(self, args, checkpoint_dir, chk_postfix, model, optimizer):
save_path= os.path.join(checkpoint_dir, f'pytorch.model-{chk_postfix}.bin')
if hasattr(model, 'module'):
model_state = OrderedDict([(n,p) for n,p in model.module.state_dict().items()])
else:
model_state = OrderedDict([(n,p) for n,p in model.state_dict().items()])
if args.rank < 1:
torch.save(model_state, save_path)
return save_path
def _eval_model(self, with_checkpoint=True):
if with_checkpoint:
checkpoint_dir = getattr(self.args, 'checkpoint_dir', None)
checkpoint_dir = checkpoint_dir if checkpoint_dir is not None else self.output_dir
chk_postfix = f'{self.trainer_state.steps:06}'
self.save_model(self.args, checkpoint_dir, chk_postfix, self.model, self.optimizer)
_metric = self.trainer_state.best_metric
_steps = self.trainer_state.best_steps
if self.eval_fn is not None:
metric = self.eval_fn(self, self.model, self.device, tag=f'{self.trainer_state.steps:06}-{self.training_steps}')
if metric > _metric:
_metric = metric
_steps = self.trainer_state.steps
logger.info(f'Best metric: {_metric}@{_steps}')
self.trainer_state.best_metric, self.trainer_state.best_steps = _metric, _steps
def _train_step(self, data, bs_scale):
self.model.train()
go_next=False
def split(batch, parts):
sub_batches = [{} for _ in range(parts)]
for k in batch.keys():
b = batch[k].size(0)
s = (b + parts - 1)//parts
v = batch[k].split(s)
for i,z in enumerate(v):
sub_batches[i][k]=z
chunks = [b for b in sub_batches if len(b)>0]
return chunks
if self.accumulative_update>1:
data_chunks = split(data, self.accumulative_update)
else:
data_chunks = [data]
while not go_next:
step_loss = 0
batch_size = 0
self.optimizer.zero_grad()
forward_outputs = []
for i, sub in enumerate(data_chunks):
output = self.loss_fn(self, self.model, sub)
if isinstance(output, dict):
loss, sub_size = output['loss'], output['batch_size']
else:
loss, sub_size = output
forward_outputs.append(output)
loss = loss/len(data_chunks)
if i == 0:
loss_scale, _loss = self.optimizer.backward(loss)
else:
_loss = loss.float().detach().item()
loss = loss.float() * loss_scale
loss.backward()
step_loss += _loss
batch_size += sub_size
if not self.optimizer.step(bs_scale, loss_scale):
self.optimizer.zero_grad()
continue
go_next = True
self.trainer_state.update_step(step_loss, batch_size , loss_scale)
if self.update_fn is not None:
self.update_fn(self, self.model, loss_scale)
self.optimizer.zero_grad()
if self.post_loss_fn is not None:
self.post_loss_fn(forward_outputs)
if self.trainer_state.steps%100 == 0:
self.trainer_state.report_state()
if self.trainer_state.steps%self.dump_interval == 0:
self._eval_model()
def _setup_model(self, args, model):
if args.world_size > 1:
for p in model.parameters():
torch.distributed.broadcast(p.data, 0)
torch.cuda.synchronize()
return model

Просмотреть файл

@ -201,7 +201,7 @@ And here are the results from the Base model
|MNLI base| `experiments/glue/mnli.sh base`| 88.8/88.5 +/-0.2| 1.5h|
#### Fine-tuning on NLU tasks
### Fine-tuning on NLU tasks
We present the dev results on SQuAD 1.1/2.0 and several GLUE benchmark tasks.
@ -219,7 +219,7 @@ We present the dev results on SQuAD 1.1/2.0 and several GLUE benchmark tasks.
|[DeBERTa-V3-Base](https://huggingface.co/microsoft/deberta-v3-base)|-/-|88.4/85.4|90.6/90.7|-|-|-| -| -|- |- |
|[DeBERTa-V3-Small](https://huggingface.co/microsoft/deberta-v3-base)|-/-|82.9/80.4|88.2/87.9|-|-|-| -| -|- |- |
#### Fine-tuning on XNLI
### Fine-tuning on XNLI
We present the dev results on XNLI with zero-shot crosslingual transfer setting, i.e. training with english data only, test on other languages.
@ -232,6 +232,10 @@ We present the dev results on XNLI with zero-shot crosslingual transfer setting,
#### Notes.
- <sup>1</sup> Following RoBERTa, for RTE, MRPC, STS-B, we fine-tune the tasks based on [DeBERTa-Large-MNLI](https://huggingface.co/microsoft/deberta-large-mnli), [DeBERTa-XLarge-MNLI](https://huggingface.co/microsoft/deberta-xlarge-mnli), [DeBERTa-V2-XLarge-MNLI](https://huggingface.co/microsoft/deberta-v2-xlarge-mnli), [DeBERTa-V2-XXLarge-MNLI](https://huggingface.co/microsoft/deberta-v2-xxlarge-mnli). The results of SST-2/QQP/QNLI/SQuADv2 will also be slightly improved when start from MNLI fine-tuned models, however, we only report the numbers fine-tuned from pretrained base models for those 4 tasks.
### Pre-training with MLM and RTD objectives
To pre-train DeBERTa with MLM and RTD objectives, please check [`experiments/language_models`](experiments/language_model)
## Contacts
@ -260,17 +264,5 @@ url={https://openreview.net/forum?id=XPZIaotutsD}
}
```
# Contributing
This project welcomes contributions and suggestions. Most contributions require you to agree to a
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
When you submit a pull request, a CLA bot will automatically determine whether you need to provide
a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
provided by the bot. You will only need to do this once across all repos using our CLA.
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.

Просмотреть файл

@ -20,6 +20,19 @@ setup_glue_data $Task
init=$1
tag=$init
case ${init,,} in
bert-xsmall)
init=/tmp/ttonly/bert-xsmall/discriminator/pytorch.model-1000000.bin
vocab_type=spm
vocab_path=/tmp/ttonly/bert-xsmall/discriminator/spm.model
parameters=" --num_train_epochs 3 \
--fp16 True \
--warmup 1500 \
--learning_rate 1e-4 \
--vocab_type $vocab_type \
--vocab_path $vocab_path \
--train_batch_size 64 \
--cls_drop_out 0.1 "
;;
deberta-v3-small)
parameters=" --num_train_epochs 2 \
--fp16 False \
@ -144,11 +157,13 @@ case ${init,,} in
;;
esac
export MASTER_PORT=12456
python -m DeBERTa.apps.run --model_config config.json \
--tag $tag \
--do_train \
--max_seq_len 256 \
--eval_batch_size 256 \
--dump_interval 1000 \
--task_name $Task \
--data_dir $cache_dir/glue_tasks/$Task \
--init_model $init \

Просмотреть файл

@ -0,0 +1,57 @@
# Pre-train an efficient transformer language model for natual language understanding
## Data
We use [wiki103](https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip) data as example, which is publicly available. It contains three text files, `train.txt`, `valid.txt` and `text.txt`. We use `train.txt` to train the model and `valid.txt` to evalute the intermeidate checkpoints. We first need to run `prepara_data.py` to tokenize these text files. We concatenate all documents into a single text and split it into lines of tokens, while each line has at most 510 token (2 tokens are left to special tokens `[CLS]` and `[SEP]`).
## Pre-training with Masked Language Modeling task
Run `mlm.sh` to train a bert like model which uses MLM as the pre-training task. For example,
`mlm.sh bert-base` will train a bert base model which uses absolute position encoding
`mlm.sh deberta-base` will train a deberta base model which uses **Disentangled Attention**
## Pre-training with Replaced Token Detection task
Coming soon...
## Distributed training
To train with multiple node, you need to specify three environment variables,
`WORLD_SIZE` - Total nodes that are used for the training
`MASTER_ADDR` - The IP address or host name of the master node
`MASTER_PORT` - The port of the master node
`RANK` - The rank of current node
For example, to run train a model with 2 nodes,
- On **node0**,
``` bash
export WORLD_SIZE=2
export MASTER_ADDR=node0
export MASTER_PORT=7488
export RANK=0
./rtd.sh deberta-v3-xsmall
```
- On **node1**,
``` bash
export WORLD_SIZE=2
export MASTER_ADDR=node0
export MASTER_PORT=7488
export RANK=1
./rtd.sh deberta-v3-xsmall
```
## Model config options
- `relative_attention` Whether to used relative attention
- `pos_att_type` Relative position encoding type
- P2C Postion to content attention
- C2P Content to position attention

Просмотреть файл

@ -0,0 +1,22 @@
{
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"max_position_embeddings": 512,
"relative_attention": true,
"position_buckets": 256,
"norm_rel_ebd": "layer_norm",
"share_att_key": true,
"pos_att_type": "p2c|c2p",
"layer_norm_eps": 1e-7,
"max_relative_positions": -1,
"position_biased_input": false,
"num_attention_heads": 12,
"attention_head_size": 64,
"num_hidden_layers": 12,
"type_vocab_size": 0,
"vocab_size": 128100
}

Просмотреть файл

@ -5,20 +5,23 @@ cd $SCRIPT_DIR
cache_dir=/tmp/DeBERTa/MLM/
max_seq_length=512
data_dir=$cache_dir/wiki103/spm_$max_seq_length
function setup_wiki_data(){
task=$1
mkdir -p $cache_dir
if [[ ! -e $cache_dir/spm.model ]]; then
wget -q https://huggingface.co/microsoft/deberta-v2-xxlarge/resolve/main/spm.model -O $cache_dir/spm.model
wget -q https://huggingface.co/microsoft/deberta-v3-base/resolve/main/spm.model -O $cache_dir/spm.model
fi
if [[ ! -e $cache_dir/wiki103.zip ]]; then
if [[ ! -e $data_dir/test.txt ]]; then
wget -q https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip -O $cache_dir/wiki103.zip
unzip -j $cache_dir/wiki103.zip -d $cache_dir/wiki103
mkdir -p $cache_dir/wiki103/spm
python ./prepare_data.py -i $cache_dir/wiki103/wiki.train.tokens -o $cache_dir/wiki103/spm/train.txt
python ./prepare_data.py -i $cache_dir/wiki103/wiki.valid.tokens -o $cache_dir/wiki103/spm/valid.txt
python ./prepare_data.py -i $cache_dir/wiki103/wiki.test.tokens -o $cache_dir/wiki103/spm/test.txt
mkdir -p $data_dir
python ./prepare_data.py -i $cache_dir/wiki103/wiki.train.tokens -o $data_dir/train.txt --max_seq_length $max_seq_length
python ./prepare_data.py -i $cache_dir/wiki103/wiki.valid.tokens -o $data_dir/valid.txt --max_seq_length $max_seq_length
python ./prepare_data.py -i $cache_dir/wiki103/wiki.test.tokens -o $data_dir/test.txt --max_seq_length $max_seq_length
fi
}
@ -36,11 +39,20 @@ case ${init,,} in
--learning_rate 1e-4 \
--train_batch_size 256 \
--max_ngram 1 \
--fp16 True "
;;
deberta-base)
parameters=" --num_train_epochs 1 \
--model_config deberta_base.json \
--warmup 10000 \
--learning_rate 1e-4 \
--train_batch_size 256 \
--max_ngram 3 \
--fp16 True "
;;
xlarge-v2)
parameters=" --num_train_epochs 1 \
--model_config xlarge.json \
--model_config deberta_xlarge.json \
--warmup 1000 \
--learning_rate 1e-4 \
--train_batch_size 32 \
@ -50,7 +62,7 @@ case ${init,,} in
xxlarge-v2)
parameters=" --num_train_epochs 1 \
--warmup 1000 \
--model_config xxlarge.json \
--model_config deberta_xxlarge.json \
--learning_rate 1e-4 \
--train_batch_size 32 \
--max_ngram 3 \
@ -59,18 +71,19 @@ case ${init,,} in
*)
echo "usage $0 <Pretrained model configuration>"
echo "Supported configurations"
echo "bert-base - Pretrained a bert base model with DeBERTa vocabulary (12 layers, 768 hidden size, 128k vocabulary size)"
echo "deberta-base - Pretrained a deberta base model (12 layers, 768 hidden size, 128k vocabulary size)"
echo "xlarge-v2 - Pretrained DeBERTa v2 model with 900M parameters (24 layers, 1536 hidden size)"
echo "xxlarge-v2 - Pretrained DeBERTa v2 model with 1.5B parameters (48 layers, 1536 hidden size)"
exit 0
;;
esac
data_dir=$cache_dir/wiki103/spm
python -m DeBERTa.apps.run --model_config config.json \
--tag $tag \
--do_train \
--num_training_steps 1000000 \
--max_seq_len 512 \
--max_seq_len $max_seq_length \
--dump 10000 \
--task_name $Task \
--data_dir $data_dir \

Просмотреть файл

@ -4,8 +4,8 @@ import sys
import argparse
from tqdm import tqdm
def tokenize_data(input, output=None):
p,t=deberta.load_vocab(vocab_path=None, vocab_type='spm', pretrained_id='xlarge-v2')
def tokenize_data(input, output=None, max_seq_length=512):
p,t=deberta.load_vocab(vocab_path=None, vocab_type='spm', pretrained_id='deberta-v3-base')
tokenizer=deberta.tokenizers[t](p)
if output is None:
output=input + '.spm'
@ -23,8 +23,8 @@ def tokenize_data(input, output=None):
with open(output, 'w', encoding = 'utf-8') as wfs:
idx = 0
while idx < len(all_tokens):
wfs.write(' '.join(all_tokens[idx:idx+510]) + '\n')
idx += 510
wfs.write(' '.join(all_tokens[idx:idx+max_seq_length-2]) + '\n')
idx += (max_seq_length - 2)
lines += 1
print(f'Saved {lines} lines to {output}')
@ -32,5 +32,6 @@ def tokenize_data(input, output=None):
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--input', required=True, help='The input data path')
parser.add_argument('-o', '--output', default=None, help='The output data path')
parser.add_argument('--max_seq_length', type=int, default=512, help='Maxium sequence length of inputs')
args = parser.parse_args()
tokenize_data(args.input, args.output)
tokenize_data(args.input, args.output, args.max_seq_length)

Просмотреть файл

@ -10,7 +10,6 @@ ujson
seqeval
psutil
sentencepiece
laser
#GitPython
torch
#torchvision