Mirror of https://github.com/microsoft/DeBERTa.git

Add Masked Language Model task

Parent: 4cfa08e7c8
Commit: c7147a038e
@@ -2,3 +2,4 @@ from .ner import *
from .multi_choice import *
from .sequence_classification import *
from .record_qa import *
from .masked_language_model import *
@@ -0,0 +1,125 @@
#
# Author: penhe@microsoft.com
# Date: 04/25/2019
#
""" Masked Language model for representation learning
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from concurrent.futures import ThreadPoolExecutor

import csv
import os
import json
import random
import time
from tqdm import tqdm, trange

import numpy as np
import torch
import torch.nn as nn
import pdb
from collections.abc import Mapping
from copy import copy
from ...deberta import *

__all__ = ['MaskedLanguageModel']

class EnhancedMaskDecoder(torch.nn.Module):
  def __init__(self, config, vocab_size):
    super().__init__()
    self.config = config
    self.lm_head = BertLMPredictionHead(config, vocab_size)

  def forward(self, ctx_layers, ebd_weight, target_ids, input_ids, input_mask, z_states, attention_mask, encoder, relative_pos=None, *wargs, **kwargs):
    mlm_ctx_layers = self.emd_context_layer(ctx_layers, z_states, attention_mask, encoder, target_ids, input_ids, input_mask, relative_pos=relative_pos)
    loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
    lm_loss = torch.tensor(0).to(ctx_layers[-1])
    arlm_loss = torch.tensor(0).to(ctx_layers[-1])
    ctx_layer = mlm_ctx_layers[-1]
    lm_logits = self.lm_head(ctx_layer, ebd_weight).float()
    lm_logits = lm_logits.view(-1, lm_logits.size(-1))
    lm_labels = target_ids.view(-1)
    label_index = (target_ids.view(-1)>0).nonzero().view(-1)
    lm_labels = lm_labels.index_select(0, label_index)
    lm_loss = loss_fct(lm_logits, lm_labels.long())
    return lm_logits, lm_labels, lm_loss

  def emd_context_layer(self, encoder_layers, z_states, attention_mask, encoder, target_ids, input_ids, input_mask, relative_pos=None):
    if attention_mask.dim()<=2:
      extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
      att_mask = extended_attention_mask.byte()
      attention_mask = att_mask*att_mask.squeeze(-2).unsqueeze(-1)
    elif attention_mask.dim()==3:
      attention_mask = attention_mask.unsqueeze(1)
    target_mask = target_ids>0
    hidden_states = encoder_layers[-2]
    layers = [encoder.layer[-1] for _ in range(2)]

    z_states += hidden_states
    query_mask = attention_mask
    query_states = z_states
    outputs = []
    rel_embeddings = encoder.get_rel_embedding()

    for layer in layers:
      # TODO: pass relative pos ids
      output = layer(hidden_states, query_mask, return_att=False, query_states = query_states, relative_pos=relative_pos, rel_embeddings = rel_embeddings)
      query_states = output
      outputs.append(query_states)

    _mask_index = (target_ids>0).view(-1).nonzero().view(-1)
    def flatten_states(q_states):
      q_states = q_states.view((-1, q_states.size(-1)))
      q_states = q_states.index_select(0, _mask_index)
      return q_states

    return [flatten_states(q) for q in outputs]

class MaskedLanguageModel(NNModule):
  """ Masked language model with DeBERTa
  """
  def __init__(self, config, *wargs, **kwargs):
    super().__init__(config)
    self.deberta = DeBERTa(config)

    self.max_relative_positions = getattr(config, 'max_relative_positions', -1)
    self.position_buckets = getattr(config, 'position_buckets', -1)
    if self.max_relative_positions <1:
      self.max_relative_positions = config.max_position_embeddings
    self.lm_predictions = EnhancedMaskDecoder(self.deberta.config, self.deberta.embeddings.word_embeddings.weight.size(0))
    self.apply(self.init_weights)

  def forward(self, input_ids, input_mask=None, labels=None, position_ids=None, attention_mask=None, **kwargs):
    device = list(self.parameters())[0].device
    input_ids = input_ids.to(device)
    input_mask = input_mask.to(device)
    type_ids = None
    lm_labels = labels.to(device)
    if attention_mask is not None:
      attention_mask = attention_mask.to(device)
    else:
      attention_mask = input_mask

    encoder_output = self.deberta(input_ids, input_mask, type_ids, output_all_encoded_layers=True, position_ids = position_ids)
    encoder_layers = encoder_output['hidden_states']
    z_states = encoder_output['position_embeddings']
    ctx_layer = encoder_layers[-1]
    lm_loss = torch.tensor(0).to(ctx_layer).float()
    lm_logits = None
    label_inputs = None
    if lm_labels is not None:
      ebd_weight = self.deberta.embeddings.word_embeddings.weight

      label_index = (lm_labels.view(-1) > 0).nonzero()
      label_inputs = torch.gather(input_ids.view(-1), 0, label_index.view(-1))
      if label_index.size(0)>0:
        (lm_logits, lm_labels, lm_loss) = self.lm_predictions(encoder_layers, ebd_weight, lm_labels, input_ids, input_mask, z_states, attention_mask, self.deberta.encoder)

    return {
      'logits' : lm_logits,
      'labels' : lm_labels,
      'loss' : lm_loss.float(),
      }
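
Editorial note, not part of the diff: the decoder above scores only the masked positions, selecting them with `label_index` before the cross-entropy. A minimal standalone sketch of that selection, with assumed tensor shapes, looks like this:

```python
# Sketch only (assumed shapes): logits [batch, seq, vocab]; target_ids [batch, seq] with 0 at unmasked positions.
import torch

def masked_lm_loss(logits, target_ids):
  flat_logits = logits.view(-1, logits.size(-1))       # [batch*seq, vocab]
  flat_labels = target_ids.view(-1)                    # [batch*seq]
  label_index = (flat_labels > 0).nonzero().view(-1)   # indices of masked tokens only
  if label_index.numel() == 0:
    return torch.tensor(0.0, device=logits.device)
  loss_fct = torch.nn.CrossEntropyLoss()
  return loss_fct(flat_logits.index_select(0, label_index),
                  flat_labels.index_select(0, label_index).long())
```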

@@ -45,9 +45,9 @@ class MultiChoiceModel(NNModule):
    position_ids = position_ids.view([-1, position_ids.size(-1)])
    if input_mask is not None:
      input_mask = input_mask.view([-1, input_mask.size(-1)])
    encoder_layers = self.deberta(input_ids, token_type_ids=type_ids, attention_mask=input_mask,
    outputs = self.deberta(input_ids, token_type_ids=type_ids, attention_mask=input_mask,
      position_ids=position_ids, output_all_encoded_layers=True)
    hidden_states = encoder_layers[-1]
    hidden_states = outputs['hidden_states'][-1]
    logits = self.classifier(self.dropout(self.pooler(hidden_states)))
    logits = logits.float().squeeze(-1)
    logits = logits.view([-1, num_opts])

@@ -57,7 +57,10 @@ class MultiChoiceModel(NNModule):
      loss_fn = CrossEntropyLoss()
      loss = loss_fn(logits, labels)

    return (logits, loss)
    return {
      'logits' : logits,
      'loss' : loss
      }

  def _pre_load_hook(self, state_dict, prefix, local_metadata, strict,
      missing_keys, unexpected_keys, error_msgs):

@@ -31,8 +31,9 @@ class NERModel(NNModule):
    self.apply(self.init_weights)

  def forward(self, input_ids, type_ids=None, input_mask=None, labels=None, position_ids=None, **kwargs):
    encoder_layers = self.bert(input_ids, token_type_ids=type_ids, attention_mask=input_mask, \
    outputs = self.bert(input_ids, token_type_ids=type_ids, attention_mask=input_mask, \
        position_ids=position_ids, output_all_encoded_layers=True)
    encoder_layers = outputs['hidden_states']
    cls = encoder_layers[-1]
    cls = self.proj(cls)
    cls = ACT2FN['gelu'](cls)

@@ -47,4 +48,7 @@ class NERModel(NNModule):
      loss_fn = CrossEntropyLoss()
      loss = loss_fn(valid_logits, valid_labels)

    return (logits, loss)
    return {
      'logits' : logits,
      'loss' : loss
      }

@@ -32,8 +32,9 @@ class ReCoRDQAModel(NNModule):
    self.deberta.apply_state()

  def forward(self, input_ids, entity_indice, type_ids=None, input_mask=None, labels=None, position_ids=None, placeholder=None, **kwargs):
    encoder_layers = self.deberta(input_ids, attention_mask=input_mask, token_type_ids=type_ids,\
    outputs = self.deberta(input_ids, attention_mask=input_mask, token_type_ids=type_ids,\
        position_ids=position_ids, output_all_encoded_layers=True)
    encoder_layers = outputs['hidden_states']
    # bxexsp
    entity_mask = entity_indice>0
    tokens = encoder_layers[-1]

@@ -66,4 +67,6 @@ class ReCoRDQAModel(NNModule):
      loss_fn = BCEWithLogitsLoss()
      loss = loss_fn(sp_logits, labels.to(sp_logits))

    return (logits, loss)
    return {
      'logits': logits,
      'loss': loss }

@@ -41,8 +41,9 @@ class SequenceClassificationModel(NNModule):
    self.deberta.apply_state()

  def forward(self, input_ids, type_ids=None, input_mask=None, labels=None, position_ids=None, **kwargs):
    encoder_layers = self.deberta(input_ids, attention_mask=input_mask, token_type_ids=type_ids,
    outputs = self.deberta(input_ids, attention_mask=input_mask, token_type_ids=type_ids,
        position_ids=position_ids, output_all_encoded_layers=True)
    encoder_layers = outputs['hidden_states']
    pooled_output = self.pooler(encoder_layers[-1])
    pooled_output = self.dropout(pooled_output)
    logits = self.classifier(pooled_output)

@@ -69,6 +70,10 @@ class SequenceClassificationModel(NNModule):
          label_confidence = 1
        loss = -((log_softmax(logits)*labels).sum(-1)*label_confidence).mean()

    return {
      'logits' : logits,
      'loss' : loss
      }
    return (logits,loss)

  def _pre_load_hook(self, state_dict, prefix, local_metadata, strict,

@@ -26,6 +26,8 @@ from ..utils import *
from ..utils import xtqdm as tqdm
from .tasks import load_tasks,get_task

import pdb

import LASER
from LASER.training import DistributedTrainer, initialize_distributed, batch_to, set_random_seed,kill_children
from LASER.data import DistributedBatchSampler, SequentialSampler, BatchSampler, AsyncDataLoader

@@ -58,7 +60,8 @@ def train_model(args, model, device, train_data, eval_data):
    return eval_metric

  def loss_fn(trainer, model, data):
    _, loss = model(**data)
    output = model(**data)
    loss = output['loss']
    return loss.mean(), data['input_ids'].size(0)

  trainer = DistributedTrainer(args, args.output_dir, model, device, data_fn, loss_fn = loss_fn, eval_fn = eval_fn, dump_interval = args.dump_interval)

@@ -67,9 +70,12 @@ def train_model(args, model, device, train_data, eval_data):
def merge_distributed(data_list, max_len=None):
  merged = []
  def gather(data):
    data_chunks = [torch.zeros_like(data) for _ in range(args.world_size)]
    torch.distributed.all_gather(data_chunks, data)
    torch.cuda.synchronize()
    data_size = [torch.zeros(data.dim(), dtype=torch.int).to(data.device) for _ in range(args.world_size)]
    torch.distributed.all_gather(data_size, torch.tensor(data.size()).to(data_size[0]))
    data_chunks = [torch.zeros(tuple(s.cpu().numpy())).to(data) for s in data_size]
    data_chunks[data.device.index] = data
    for i,_chunk in enumerate(data_chunks):
      torch.distributed.broadcast(_chunk, src=i)
    return data_chunks

  for data in data_list:

@@ -82,12 +88,12 @@ def merge_distributed(data_list, max_len=None):
          data_chunks.append(data_)
        merged.append(data_chunks)
      else:
        data_chunks = gather(data)
        merged.extend(data_chunks)
        _chunks = gather(data)
        merged.extend(_chunks)
    else:
      merged.append(data)
  if not isinstance(merged[0], Sequence):
    merged = torch.cat(merged)
    merged = torch.cat([m.cpu() for m in merged])
    if max_len is not None:
      return merged[:max_len]
    else:

@@ -95,7 +101,7 @@ def merge_distributed(data_list, max_len=None):
  else:
    data_list=[]
    for d in zip(*merged):
      data = torch.cat(d)
      data = torch.cat([x.cpu() for x in d])
      if max_len is not None:
        data = data[:max_len]
      data_list.append(data)

@@ -163,8 +169,13 @@ def run_eval(args, model, device, eval_data, prefix=None, tag=None, steps=None):
  for batch in tqdm(AsyncDataLoader(eval_dataloader), ncols=80, desc='Evaluating: {}'.format(prefix), disable=no_tqdm):
    batch = batch_to(batch, device)
    with torch.no_grad():
      logits, tmp_eval_loss = model(**batch)
    label_ids = batch['labels'].to(device)
      output = model(**batch)
      logits = output['logits'].detach()
      tmp_eval_loss = output['loss'].detach()
      if 'labels' in output:
        label_ids = output['labels'].detach().to(device)
      else:
        label_ids = batch['labels'].to(device)
    predicts.append(logits)
    labels.append(label_ids)
    eval_loss += tmp_eval_loss.mean().item()

@@ -198,8 +209,9 @@ def run_predict(args, model, device, eval_data, prefix=None):
  for batch in tqdm(AsyncDataLoader(eval_dataloader), ncols=80, desc='Evaluating: {}'.format(prefix), disable=args.rank>0):
    batch = batch_to(batch, device)
    with torch.no_grad():
      logits, _ = model(**batch)
      predicts.append(logits)
      output = model(**batch)
      logits = output['logits']
      predicts.append(logits)
  predicts = merge_distributed(predicts, len(eval_item.data))
  if args.rank<=0:
    predict_fn = eval_item.predict_fn

@@ -238,7 +250,7 @@ def main(args):
  logger.info(" Prediction batch size = %d", args.predict_batch_size)

  if args.do_train:
    train_data = task.train_data(max_seq_len=args.max_seq_length, mask_gen = None, debug=args.debug)
    train_data = task.train_data(max_seq_len=args.max_seq_length, debug=args.debug)
  model_class_fn = task.get_model_class_fn()
  model = create_model(args, len(label_list), model_class_fn)
  if args.do_train:

@@ -0,0 +1,231 @@
#
# Author: penhe@microsoft.com
# Date: 01/25/2019
#

from glob import glob
from collections import OrderedDict,defaultdict,Sequence
from bisect import bisect
import copy
import math
from scipy.special import softmax
import numpy as np
import pdb
import os
import sys
import csv

import random
import torch
import re
import ujson as json
from .metrics import *
from .task import EvalData, Task
from .task_registry import register_task
from ...utils import xtqdm as tqdm
from ...data import ExampleInstance, ExampleSet, DynamicDataset,example_to_feature
from ...data.example import _truncate_segments
from ...data.example import *
from ...utils import get_logger
from ..models import MaskedLanguageModel

logger=get_logger()

__all__ = ["MLMTask"]

class NGramMaskGenerator:
  """
  Mask ngram tokens
  https://github.com/zihangdai/xlnet/blob/0b642d14dd8aec7f1e1ecbf7d6942d5faa6be1f0/data_utils.py
  """
  def __init__(self, tokenizer, mask_lm_prob=0.15, max_seq_len=512, max_preds_per_seq=None, max_gram = 3, keep_prob = 0.1, mask_prob=0.8, **kwargs):
    self.tokenizer = tokenizer
    self.mask_lm_prob = mask_lm_prob
    self.keep_prob = keep_prob
    self.mask_prob = mask_prob
    assert self.mask_prob+self.keep_prob<=1, f'The prob of using [MASK]({mask_prob}) and the prob of using original token({keep_prob}) should between [0,1]'
    self.max_preds_per_seq = max_preds_per_seq
    if max_preds_per_seq is None:
      self.max_preds_per_seq = math.ceil(max_seq_len*mask_lm_prob /10)*10

    self.max_gram = max_gram
    self.mask_window = int(1/mask_lm_prob) # make ngrams per window sized context
    self.vocab_words = list(tokenizer.vocab.keys())

  def mask_tokens(self, tokens, rng, **kwargs):
    special_tokens = ['[MASK]', '[CLS]', '[SEP]', '[PAD]', '[UNK]'] # + self.tokenizer.tokenize(' ')
    indices = [i for i in range(len(tokens)) if tokens[i] not in special_tokens]
    ngrams = np.arange(1, self.max_gram + 1, dtype=np.int64)
    pvals = 1. / np.arange(1, self.max_gram + 1)
    pvals /= pvals.sum(keepdims=True)

    unigrams = []
    for id in indices:
      if len(unigrams)>=1 and self.tokenizer.part_of_whole_word(tokens[id]):
        unigrams[-1].append(id)
      else:
        unigrams.append([id])

    num_to_predict = min(self.max_preds_per_seq, max(1, int(round(len(tokens) * self.mask_lm_prob))))
    mask_len = 0
    offset = 0
    mask_grams = np.array([False]*len(unigrams))
    while offset < len(unigrams):
      n = self._choice(rng, ngrams, p=pvals)
      ctx_size = min(n*self.mask_window, len(unigrams)-offset)
      m = rng.randint(0, ctx_size-1)
      s = offset + m
      e = min(offset+m+n, len(unigrams))
      offset = max(offset+ctx_size, e)
      mask_grams[s:e] = True

    target_labels = [None]*len(tokens)
    w_cnt = 0
    for m,word in zip(mask_grams, unigrams):
      if m:
        for idx in word:
          label = self._mask_token(idx, tokens, rng, self.mask_prob, self.keep_prob)
          target_labels[idx] = label
          w_cnt += 1
        if w_cnt >= num_to_predict:
          break

    target_labels = [self.tokenizer.vocab[x] if x else 0 for x in target_labels]
    return tokens, target_labels

  def _choice(self, rng, data, p):
    cul = np.cumsum(p)
    x = rng.random()*cul[-1]
    id = bisect(cul, x)
    return data[id]

  def _mask_token(self, idx, tokens, rng, mask_prob, keep_prob):
    label = tokens[idx]
    mask = '[MASK]'
    rand = rng.random()
    if rand < mask_prob:
      new_label = mask
    elif rand < mask_prob+keep_prob:
      new_label = label
    else:
      new_label = rng.choice(self.vocab_words)

    tokens[idx] = new_label

    return label

@register_task(name="MLM", desc="Masked language model pretraining task")
class MLMTask(Task):
  def __init__(self, data_dir, tokenizer, args, **kwargs):
    super().__init__(tokenizer, args, **kwargs)
    self.data_dir = data_dir
    self.mask_gen = NGramMaskGenerator(tokenizer, max_gram=self.args.max_ngram)

  def train_data(self, max_seq_len=512, **kwargs):
    data = self.load_data(os.path.join(self.data_dir, 'train.txt'))
    examples = ExampleSet(data)
    if self.args.num_training_steps is None:
      dataset_size = len(examples)
    else:
      dataset_size = self.args.num_training_steps*self.args.train_batch_size
    return DynamicDataset(examples, feature_fn = self.get_feature_fn(max_seq_len=max_seq_len, mask_gen=self.mask_gen), \
        dataset_size = dataset_size, shuffle=True, **kwargs)

  def get_labels(self):
    return list(self.tokenizer.vocab.values())

  def eval_data(self, max_seq_len=512, **kwargs):
    ds = [
      self._data('dev', 'valid.txt', 'dev'),
      ]

    for d in ds:
      _size = len(d.data)
      d.data = DynamicDataset(d.data, feature_fn = self.get_feature_fn(max_seq_len=max_seq_len, mask_gen=self.mask_gen), dataset_size = _size, **kwargs)
    return ds

  def test_data(self, max_seq_len=512, **kwargs):
    """See base class."""
    raise NotImplemented('This method is not implemented yet.')

  def _data(self, name, path, type_name = 'dev', ignore_metric=False):
    if isinstance(path, str):
      path = [path]
    data = []
    for p in path:
      input_src = os.path.join(self.data_dir, p)
      assert os.path.exists(input_src), f"{input_src} doesn't exists"
      data.extend(self.load_data(input_src))

    predict_fn = self.get_predict_fn()
    examples = ExampleSet(data)
    return EvalData(name, examples,
      metrics_fn = self.get_metrics_fn(), predict_fn = predict_fn, ignore_metric=ignore_metric, critial_metrics=['accuracy'])

  def get_metrics_fn(self):
    """Calcuate metrics based on prediction results"""
    def metrics_fn(logits, labels):
      preds = np.argmax(logits, axis=-1)
      acc = (preds==labels).sum()/len(labels)
      metrics = OrderedDict(accuracy= acc)
      return metrics
    return metrics_fn

  def load_data(self, path):
    examples = []
    with open(path, encoding='utf-8') as fs:
      for l in fs:
        if len(l) > 1:
          example = ExampleInstance(segments=[l])
          examples.append(example)
    return examples

  def get_feature_fn(self, max_seq_len = 512, mask_gen = None):
    def _example_to_feature(example, rng=None, ext_params=None, **kwargs):
      return self.example_to_feature(self.tokenizer, example, max_seq_len = max_seq_len, \
          rng = rng, mask_generator = mask_gen, ext_params = ext_params, **kwargs)
    return _example_to_feature

  def example_to_feature(self, tokenizer, example, max_seq_len=512, rng=None, mask_generator = None, ext_params=None, **kwargs):
    if not rng:
      rng = random
    max_num_tokens = max_seq_len - 2

    segments = [ example.segments[0].strip().split() ]
    segments = _truncate_segments(segments, max_num_tokens, rng)
    _tokens = ['[CLS]'] + segments[0] + ['[SEP]']
    if mask_generator:
      tokens, lm_labels = mask_generator.mask_tokens(_tokens, rng)
    token_ids = tokenizer.convert_tokens_to_ids(tokens)

    features = OrderedDict(input_ids = token_ids,
      position_ids = list(range(len(token_ids))),
      input_mask = [1]*len(token_ids),
      labels = lm_labels)

    for f in features:
      features[f] = torch.tensor(features[f] + [0]*(max_seq_len - len(token_ids)), dtype=torch.int)
    return features

  def get_model_class_fn(self):
    def partial_class(*wargs, **kwargs):
      return MaskedLanguageModel.load_model(*wargs, **kwargs)
    return partial_class

  @classmethod
  def add_arguments(cls, parser):
    """Add task specific arguments
      e.g. parser.add_argument('--data_dir', type=str, help='The path of data directory.')
    """
    parser.add_argument('--max_ngram', type=int, default=1, help='Maxium ngram sampling span')
    parser.add_argument('--num_training_steps', type=int, default=None, help='Maxium pre-training steps')

def test_MLM():
  from ...deberta import tokenizers,load_vocab
  import pdb
  vocab_path, vocab_type = load_vocab(vocab_path = None, vocab_type = 'spm', pretrained_id = 'xlarge-v2')
  tokenizer = tokenizers[vocab_type](vocab_path)
  mask_gen = NGramMaskGenerator(tokenizer, max_gram=1)
  mlm = MLMTask('/mnt/penhe/data/wiki103/spm', tokenizer, None)
  train_data = mlm.train_data(mask_gen = mask_gen)
  pdb.set_trace()
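
For illustration only (not in the commit): a toy driver for `NGramMaskGenerator`. The `ToyTokenizer` below is a hypothetical stand-in that provides only the `vocab` dict and `part_of_whole_word` method the generator actually touches; a real run would use the SPM tokenizer loaded in `test_MLM` above.

```python
# Illustrative sketch, assuming NGramMaskGenerator as defined above.
import random

class ToyTokenizer:
  """Hypothetical WordPiece-style tokenizer exposing only what the mask generator needs."""
  def __init__(self, words):
    self.vocab = {'[PAD]': 0, '[CLS]': 1, '[SEP]': 2, '[UNK]': 3, '[MASK]': 4}
    self.vocab.update({w: i + 5 for i, w in enumerate(words)})
  def part_of_whole_word(self, token):
    return token.startswith('##')  # continuation pieces extend the previous word

tokenizer = ToyTokenizer(['the', 'quick', 'brown', 'fox', 'jump', '##s'])
mask_gen = NGramMaskGenerator(tokenizer, mask_lm_prob=0.3, max_gram=2)
rng = random.Random(0)
tokens = ['[CLS]', 'the', 'quick', 'brown', 'fox', 'jump', '##s', '[SEP]']
masked, labels = mask_gen.mask_tokens(list(tokens), rng)
# 'masked' keeps the sequence length but rewrites roughly mask_lm_prob of the word positions
# ([MASK] with prob 0.8, kept with 0.1, random vocab word with 0.1); 'labels' holds the
# original vocab id at each masked position and 0 everywhere else.
```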

@@ -21,7 +21,7 @@ from .ops import *
from .disentangled_attention import *
from .da_utils import *

__all__ = ['BertEncoder', 'BertEmbeddings', 'ACT2FN', 'LayerNorm']
__all__ = ['BertEncoder', 'BertEmbeddings', 'ACT2FN', 'LayerNorm', 'BertLMPredictionHead']

class BertSelfOutput(nn.Module):
  def __init__(self, config):

@@ -184,7 +184,7 @@ class BertEncoder(nn.Module):
    relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos)

    all_encoder_layers = []
    att_matrixs = []
    att_matrices = []
    if isinstance(hidden_states, Sequence):
      next_kv = hidden_states[0]
    else:

@@ -209,15 +209,15 @@ class BertEncoder(nn.Module):
      if output_all_encoded_layers:
        all_encoder_layers.append(output_states)
      if return_att:
        att_matrixs.append(att_m)
        att_matrices.append(att_m)
    if not output_all_encoded_layers:
      all_encoder_layers.append(output_states)
      if return_att:
        att_matrixs.append(att_m)
    if return_att:
      return (all_encoder_layers, att_matrixs)
    else:
      return all_encoder_layers
        att_matrices.append(att_m)
    return {
      'hidden_states': all_encoder_layers,
      'attention_matrices': att_matrices
      }

class BertEmbeddings(nn.Module):
  """Construct the embeddings from word, position and token_type embeddings.

@@ -229,10 +229,7 @@ class BertEmbeddings(nn.Module):
    self.word_embeddings = nn.Embedding(config.vocab_size, self.embedding_size, padding_idx = padding_idx)

    self.position_biased_input = getattr(config, 'position_biased_input', True)
    if not self.position_biased_input:
      self.position_embeddings = None
    else:
      self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.embedding_size)
    self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.embedding_size)

    if config.type_vocab_size>0:
      self.token_type_embeddings = nn.Embedding(config.type_vocab_size, self.embedding_size)

@@ -253,14 +250,9 @@ class BertEmbeddings(nn.Module):
      token_type_ids = torch.zeros_like(input_ids)

    words_embeddings = self.word_embeddings(input_ids)
    if self.position_embeddings is not None:
      position_embeddings = self.position_embeddings(position_ids.long())
    else:
      position_embeddings = torch.zeros_like(words_embeddings)
    position_embeddings = self.position_embeddings(position_ids.long())

    embeddings = words_embeddings
    if self.position_biased_input:
      embeddings += position_embeddings
    if self.config.type_vocab_size>0:
      token_type_embeddings = self.token_type_embeddings(token_type_ids)
      embeddings += token_type_embeddings

@@ -270,4 +262,28 @@ class BertEmbeddings(nn.Module):

    embeddings = MaskedLayerNorm(self.LayerNorm, embeddings, mask)
    embeddings = self.dropout(embeddings)
    return embeddings
    return {
      'embeddings': embeddings,
      'position_embeddings': position_embeddings}

class BertLMPredictionHead(nn.Module):
  def __init__(self, config, vocab_size):
    super().__init__()
    self.embedding_size = getattr(config, 'embedding_size', config.hidden_size)
    self.dense = nn.Linear(config.hidden_size, self.embedding_size)
    self.transform_act_fn = ACT2FN[config.hidden_act] \
      if isinstance(config.hidden_act, str) else config.hidden_act

    self.LayerNorm = LayerNorm(self.embedding_size, config.layer_norm_eps, elementwise_affine=True)

    self.bias = nn.Parameter(torch.zeros(vocab_size))

  def forward(self, hidden_states, embeding_weight):
    hidden_states = self.dense(hidden_states)
    hidden_states = self.transform_act_fn(hidden_states)
    # b x s x d
    hidden_states = MaskedLayerNorm(self.LayerNorm, hidden_states)

    # b x s x v
    logits = torch.matmul(hidden_states, embeding_weight.t().to(hidden_states)) + self.bias
    return logits
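
As a reading aid (not part of the diff): `BertLMPredictionHead` ties its output projection to the input word embeddings, i.e. hidden states are transformed to the embedding size and scored against the shared embedding matrix plus a bias. A minimal sketch, with sizes that are assumptions and LayerNorm omitted:

```python
# Sketch only; sizes are assumptions, LayerNorm omitted for brevity.
import torch

hidden = torch.randn(2, 128, 1536)        # encoder output: [batch, seq, hidden_size]
ebd_weight = torch.randn(128100, 1536)    # word_embeddings.weight: [vocab, embedding_size]
dense = torch.nn.Linear(1536, 1536)       # hidden_size -> embedding_size transform
bias = torch.zeros(128100)

h = torch.nn.functional.gelu(dense(hidden))
logits = torch.matmul(h, ebd_weight.t()) + bias   # [batch, seq, vocab], tied to the embeddings
```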

@@ -110,19 +110,13 @@ class DeBERTa(torch.nn.Module):
    if token_type_ids is None:
      token_type_ids = torch.zeros_like(input_ids)

    embedding_output = self.embeddings(input_ids.to(torch.long), token_type_ids.to(torch.long), position_ids, attention_mask)
    encoded_layers = self.encoder(embedding_output,
    ebd_output = self.embeddings(input_ids.to(torch.long), token_type_ids.to(torch.long), position_ids, attention_mask)
    embedding_output = ebd_output['embeddings']
    encoder_output = self.encoder(embedding_output,
        attention_mask,
        output_all_encoded_layers=output_all_encoded_layers, return_att = return_att)
    if return_att:
      encoded_layers, att_matrixs = encoded_layers

    if not output_all_encoded_layers:
      encoded_layers = encoded_layers[-1:]

    if return_att:
      return encoded_layers, att_matrixs
    return encoded_layers
    encoder_output.update(ebd_output)
    return encoder_output

  def apply_state(self, state = None):
    """ Load state from previous loaded model state dictionary.

@@ -92,8 +92,9 @@ class NNModule(nn.Module):
    config = None
    model_config = None
    model_state = None
    if model_path and model_path.strip() == '-' or model_path.strip()=='':
    if (model_path is not None) and (model_path.strip() == '-' or model_path.strip()==''):
      model_path = None

    try:
      model_state, model_config = load_model_state(model_path, tag=tag, no_cache=no_cache, cache_dir=cache_dir)
    except Exception as exp:

@@ -0,0 +1,69 @@
#!/bin/bash
SCRIPT=$(readlink -f "$0")
SCRIPT_DIR=$(dirname "$SCRIPT")
cd $SCRIPT_DIR

cache_dir=/tmp/DeBERTa/MLM/

function setup_wiki_data(){
  task=$1
  mkdir -p $cache_dir
  if [[ ! -e $cache_dir/spm.model ]]; then
    wget -q https://huggingface.co/microsoft/deberta-v2-xxlarge/resolve/main/spm.model -O $cache_dir/spm.model
  fi

  if [[ ! -e $cache_dir/wiki103.zip ]]; then
    wget -q https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip -O $cache_dir/wiki103.zip
    unzip -j $cache_dir/wiki103.zip -d $cache_dir/wiki103
    mkdir -p $cache_dir/wiki103/spm
    python ./prepare_data.py -i $cache_dir/wiki103/wiki.train.tokens -o $cache_dir/wiki103/spm/train.txt
    python ./prepare_data.py -i $cache_dir/wiki103/wiki.valid.tokens -o $cache_dir/wiki103/spm/valid.txt
    python ./prepare_data.py -i $cache_dir/wiki103/wiki.test.tokens -o $cache_dir/wiki103/spm/test.txt
  fi
}

setup_wiki_data

Task=MLM

init=$1
tag=$init
case ${init,,} in
  xlarge-v2)
    parameters=" --num_train_epochs 1 \
      --model_config xlarge.json \
      --warmup 1000 \
      --learning_rate 1e-4 \
      --train_batch_size 32 \
      --max_ngram 3 \
      --fp16 True "
    ;;
  xxlarge-v2)
    parameters=" --num_train_epochs 1 \
      --warmup 1000 \
      --model_config xxlarge.json \
      --learning_rate 1e-4 \
      --train_batch_size 32 \
      --max_ngram 3 \
      --fp16 True "
    ;;
  *)
    echo "usage $0 <Pretrained model configuration>"
    echo "Supported configurations"
    echo "xlarge-v2 - Pretrained DeBERTa v2 model with 900M parameters (24 layers, 1536 hidden size)"
    echo "xxlarge-v2 - Pretrained DeBERTa v2 model with 1.5B parameters (48 layers, 1536 hidden size)"
    exit 0
    ;;
esac

data_dir=$cache_dir/wiki103/spm
python -m DeBERTa.apps.run --model_config config.json \
  --tag $tag \
  --do_train \
  --num_training_steps 10000 \
  --max_seq_len 512 \
  --task_name $Task \
  --data_dir $data_dir \
  --vocab_path $cache_dir/spm.model \
  --vocab_type spm \
  --output_dir /tmp/ttonly/$tag/$task $parameters

@@ -0,0 +1,36 @@
# coding: utf-8
from DeBERTa import deberta
import sys
import argparse
from tqdm import tqdm

def tokenize_data(input, output=None):
  p,t=deberta.load_vocab(vocab_path=None, vocab_type='gpm', pretrained_id='xlarge-v2')
  tokenizer=deberta.tokenizers[t](p)
  if output is None:
    output=input + '.gpm'
  all_tokens = []
  with open(input, encoding = 'utf-8') as fs:
    for l in tqdm(fs, ncols=80, desc='Loading'):
      if len(l) > 0:
        tokens = tokenizer.tokenize(l)
      else:
        tokens = []
      all_tokens.extend(tokens)

  print(f'Loaded {len(all_tokens)} tokens from {input}')
  lines = 0
  with open(output, 'w', encoding = 'utf-8') as wfs:
    idx = 0
    while idx < len(all_tokens):
      wfs.write(' '.join(all_tokens[idx:idx+510]) + '\n')
      idx += 510
      lines += 1

  print(f'Saved {lines} lines to {output}')

parser = argparse.ArgumentParser()
parser.add_argument('-i', '--input', required=True, help='The input data path')
parser.add_argument('-o', '--output', default=None, help='The output data path')
args = parser.parse_args()
tokenize_data(args.input, args.output)

@@ -0,0 +1,24 @@
{
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 1536,
  "initializer_range": 0.02,
  "intermediate_size": 6144,
  "max_position_embeddings": 512,
  "relative_attention": true,
  "position_buckets": 256,
  "norm_rel_ebd": "layer_norm",
  "share_att_key": true,
  "pos_att_type": "p2c|c2p",
  "layer_norm_eps": 1e-7,
  "conv_kernel_size": 3,
  "conv_act": "gelu",
  "max_relative_positions": -1,
  "position_biased_input": false,
  "num_attention_heads": 24,
  "attention_head_size": 64,
  "num_hidden_layers": 24,
  "type_vocab_size": 0,
  "vocab_size": 128100
}

@@ -0,0 +1,24 @@
{
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 1536,
  "initializer_range": 0.02,
  "intermediate_size": 6144,
  "max_position_embeddings": 512,
  "relative_attention": true,
  "position_buckets": 256,
  "norm_rel_ebd": "layer_norm",
  "share_att_key": true,
  "pos_att_type": "p2c|c2p",
  "layer_norm_eps": 1e-7,
  "conv_kernel_size": 3,
  "conv_act": "gelu",
  "max_relative_positions": -1,
  "position_biased_input": false,
  "num_attention_heads": 24,
  "attention_head_size": 64,
  "num_hidden_layers": 48,
  "type_vocab_size": 0,
  "vocab_size": 128100
}