Add Masked Language Model task

Pengcheng He 2021-03-29 19:16:15 -04:00 committed by Pengcheng He
Parent 4cfa08e7c8
Commit c7147a038e
15 changed files with 600 additions and 52 deletions

View File

@ -2,3 +2,4 @@ from .ner import *
from .multi_choice import *
from .sequence_classification import *
from .record_qa import *
from .masked_language_model import *

View File

@ -0,0 +1,125 @@
#
# Author: penhe@microsoft.com
# Date: 04/25/2019
#
""" Masked Language model for representation learning
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from concurrent.futures import ThreadPoolExecutor
import csv
import os
import json
import random
import time
from tqdm import tqdm, trange
import numpy as np
import torch
import torch.nn as nn
import pdb
from collections.abc import Mapping
from copy import copy
from ...deberta import *
__all__ = ['MaskedLanguageModel']
class EnhancedMaskDecoder(torch.nn.Module):
def __init__(self, config, vocab_size):
super().__init__()
self.config = config
self.lm_head = BertLMPredictionHead(config, vocab_size)
def forward(self, ctx_layers, ebd_weight, target_ids, input_ids, input_mask, z_states, attention_mask, encoder, relative_pos=None, *wargs, **kwargs):
mlm_ctx_layers = self.emd_context_layer(ctx_layers, z_states, attention_mask, encoder, target_ids, input_ids, input_mask, relative_pos=relative_pos)
loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
lm_loss = torch.tensor(0).to(ctx_layers[-1])
arlm_loss = torch.tensor(0).to(ctx_layers[-1])
ctx_layer = mlm_ctx_layers[-1]
lm_logits = self.lm_head(ctx_layer, ebd_weight).float()
lm_logits = lm_logits.view(-1, lm_logits.size(-1))
lm_labels = target_ids.view(-1)
label_index = (target_ids.view(-1)>0).nonzero().view(-1)
lm_labels = lm_labels.index_select(0, label_index)
lm_loss = loss_fct(lm_logits, lm_labels.long())
return lm_logits, lm_labels, lm_loss
def emd_context_layer(self, encoder_layers, z_states, attention_mask, encoder, target_ids, input_ids, input_mask, relative_pos=None):
if attention_mask.dim()<=2:
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
att_mask = extended_attention_mask.byte()
attention_mask = att_mask*att_mask.squeeze(-2).unsqueeze(-1)
elif attention_mask.dim()==3:
attention_mask = attention_mask.unsqueeze(1)
target_mask = target_ids>0
hidden_states = encoder_layers[-2]
layers = [encoder.layer[-1] for _ in range(2)]
z_states += hidden_states
query_mask = attention_mask
query_states = z_states
outputs = []
rel_embeddings = encoder.get_rel_embedding()
for layer in layers:
# TODO: pass relative pos ids
output = layer(hidden_states, query_mask, return_att=False, query_states = query_states, relative_pos=relative_pos, rel_embeddings = rel_embeddings)
query_states = output
outputs.append(query_states)
_mask_index = (target_ids>0).view(-1).nonzero().view(-1)
def flatten_states(q_states):
q_states = q_states.view((-1, q_states.size(-1)))
q_states = q_states.index_select(0, _mask_index)
return q_states
return [flatten_states(q) for q in outputs]
class MaskedLanguageModel(NNModule):
""" Masked language model with DeBERTa
"""
def __init__(self, config, *wargs, **kwargs):
super().__init__(config)
self.deberta = DeBERTa(config)
self.max_relative_positions = getattr(config, 'max_relative_positions', -1)
self.position_buckets = getattr(config, 'position_buckets', -1)
if self.max_relative_positions <1:
self.max_relative_positions = config.max_position_embeddings
self.lm_predictions = EnhancedMaskDecoder(self.deberta.config, self.deberta.embeddings.word_embeddings.weight.size(0))
self.apply(self.init_weights)
def forward(self, input_ids, input_mask=None, labels=None, position_ids=None, attention_mask=None, **kwargs):
device = list(self.parameters())[0].device
input_ids = input_ids.to(device)
input_mask = input_mask.to(device)
type_ids = None
lm_labels = labels.to(device) if labels is not None else None
if attention_mask is not None:
attention_mask = attention_mask.to(device)
else:
attention_mask = input_mask
encoder_output = self.deberta(input_ids, input_mask, type_ids, output_all_encoded_layers=True, position_ids = position_ids)
encoder_layers = encoder_output['hidden_states']
z_states = encoder_output['position_embeddings']
ctx_layer = encoder_layers[-1]
lm_loss = torch.tensor(0).to(ctx_layer).float()
lm_logits = None
label_inputs = None
if lm_labels is not None:
ebd_weight = self.deberta.embeddings.word_embeddings.weight
label_index = (lm_labels.view(-1) > 0).nonzero()
label_inputs = torch.gather(input_ids.view(-1), 0, label_index.view(-1))
if label_index.size(0)>0:
(lm_logits, lm_labels, lm_loss) = self.lm_predictions(encoder_layers, ebd_weight, lm_labels, input_ids, input_mask, z_states, attention_mask, self.deberta.encoder)
return {
'logits' : lm_logits,
'labels' : lm_labels,
'loss' : lm_loss.float(),
}
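Note on the loss computation above: only masked positions contribute to the MLM loss. A label of 0 means "not a prediction target", and logits and labels are filtered down to the masked positions before CrossEntropyLoss is applied. A standalone sketch of that indexing pattern with toy shapes (plain PyTorch, illustrative only, not part of this commit):

import torch

# toy batch: 2 sequences of length 4, vocabulary of 10 tokens
lm_logits = torch.randn(2, 4, 10)
target_ids = torch.tensor([[0, 7, 0, 3],
                           [0, 0, 5, 0]])   # 0 = position was not masked

flat_logits = lm_logits.view(-1, lm_logits.size(-1))      # (8, 10)
flat_labels = target_ids.view(-1)                         # (8,)
label_index = (flat_labels > 0).nonzero().view(-1)        # indices of the 3 masked positions
masked_logits = flat_logits.index_select(0, label_index)  # (3, 10)
masked_labels = flat_labels.index_select(0, label_index)  # (3,)

loss = torch.nn.CrossEntropyLoss()(masked_logits, masked_labels.long())
print(masked_labels.tolist(), float(loss))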

View File

@ -45,9 +45,9 @@ class MultiChoiceModel(NNModule):
position_ids = position_ids.view([-1, position_ids.size(-1)])
if input_mask is not None:
input_mask = input_mask.view([-1, input_mask.size(-1)])
encoder_layers = self.deberta(input_ids, token_type_ids=type_ids, attention_mask=input_mask,
outputs = self.deberta(input_ids, token_type_ids=type_ids, attention_mask=input_mask,
position_ids=position_ids, output_all_encoded_layers=True)
hidden_states = encoder_layers[-1]
hidden_states = outputs['hidden_states'][-1]
logits = self.classifier(self.dropout(self.pooler(hidden_states)))
logits = logits.float().squeeze(-1)
logits = logits.view([-1, num_opts])
@ -57,7 +57,10 @@ class MultiChoiceModel(NNModule):
loss_fn = CrossEntropyLoss()
loss = loss_fn(logits, labels)
return (logits, loss)
return {
'logits' : logits,
'loss' : loss
}
def _pre_load_hook(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):

View File

@ -31,8 +31,9 @@ class NERModel(NNModule):
self.apply(self.init_weights)
def forward(self, input_ids, type_ids=None, input_mask=None, labels=None, position_ids=None, **kwargs):
encoder_layers = self.bert(input_ids, token_type_ids=type_ids, attention_mask=input_mask, \
outputs = self.bert(input_ids, token_type_ids=type_ids, attention_mask=input_mask, \
position_ids=position_ids, output_all_encoded_layers=True)
encoder_layers = outputs['hidden_states']
cls = encoder_layers[-1]
cls = self.proj(cls)
cls = ACT2FN['gelu'](cls)
@ -47,4 +48,7 @@ class NERModel(NNModule):
loss_fn = CrossEntropyLoss()
loss = loss_fn(valid_logits, valid_labels)
return (logits, loss)
return {
'logits' : logits,
'loss' : loss
}

View File

@ -32,8 +32,9 @@ class ReCoRDQAModel(NNModule):
self.deberta.apply_state()
def forward(self, input_ids, entity_indice, type_ids=None, input_mask=None, labels=None, position_ids=None, placeholder=None, **kwargs):
encoder_layers = self.deberta(input_ids, attention_mask=input_mask, token_type_ids=type_ids,\
outputs = self.deberta(input_ids, attention_mask=input_mask, token_type_ids=type_ids,\
position_ids=position_ids, output_all_encoded_layers=True)
encoder_layers = outputs['hidden_states']
# bxexsp
entity_mask = entity_indice>0
tokens = encoder_layers[-1]
@ -66,4 +67,6 @@ class ReCoRDQAModel(NNModule):
loss_fn = BCEWithLogitsLoss()
loss = loss_fn(sp_logits, labels.to(sp_logits))
return (logits, loss)
return {
'logits': logits,
'loss': loss }

View File

@ -41,8 +41,9 @@ class SequenceClassificationModel(NNModule):
self.deberta.apply_state()
def forward(self, input_ids, type_ids=None, input_mask=None, labels=None, position_ids=None, **kwargs):
encoder_layers = self.deberta(input_ids, attention_mask=input_mask, token_type_ids=type_ids,
outputs = self.deberta(input_ids, attention_mask=input_mask, token_type_ids=type_ids,
position_ids=position_ids, output_all_encoded_layers=True)
encoder_layers = outputs['hidden_states']
pooled_output = self.pooler(encoder_layers[-1])
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
@ -69,6 +70,10 @@ class SequenceClassificationModel(NNModule):
label_confidence = 1
loss = -((log_softmax(logits)*labels).sum(-1)*label_confidence).mean()
return {
'logits' : logits,
'loss' : loss
}
return (logits,loss)
def _pre_load_hook(self, state_dict, prefix, local_metadata, strict,

View File

@ -26,6 +26,8 @@ from ..utils import *
from ..utils import xtqdm as tqdm
from .tasks import load_tasks,get_task
import pdb
import LASER
from LASER.training import DistributedTrainer, initialize_distributed, batch_to, set_random_seed,kill_children
from LASER.data import DistributedBatchSampler, SequentialSampler, BatchSampler, AsyncDataLoader
@ -58,7 +60,8 @@ def train_model(args, model, device, train_data, eval_data):
return eval_metric
def loss_fn(trainer, model, data):
_, loss = model(**data)
output = model(**data)
loss = output['loss']
return loss.mean(), data['input_ids'].size(0)
trainer = DistributedTrainer(args, args.output_dir, model, device, data_fn, loss_fn = loss_fn, eval_fn = eval_fn, dump_interval = args.dump_interval)
@ -67,9 +70,12 @@ def train_model(args, model, device, train_data, eval_data):
def merge_distributed(data_list, max_len=None):
merged = []
def gather(data):
data_chunks = [torch.zeros_like(data) for _ in range(args.world_size)]
torch.distributed.all_gather(data_chunks, data)
torch.cuda.synchronize()
data_size = [torch.zeros(data.dim(), dtype=torch.int).to(data.device) for _ in range(args.world_size)]
torch.distributed.all_gather(data_size, torch.tensor(data.size()).to(data_size[0]))
data_chunks = [torch.zeros(tuple(s.cpu().numpy())).to(data) for s in data_size]
data_chunks[data.device.index] = data
for i,_chunk in enumerate(data_chunks):
torch.distributed.broadcast(_chunk, src=i)
return data_chunks
for data in data_list:
@ -82,12 +88,12 @@ def merge_distributed(data_list, max_len=None):
data_chunks.append(data_)
merged.append(data_chunks)
else:
data_chunks = gather(data)
merged.extend(data_chunks)
_chunks = gather(data)
merged.extend(_chunks)
else:
merged.append(data)
if not isinstance(merged[0], Sequence):
merged = torch.cat(merged)
merged = torch.cat([m.cpu() for m in merged])
if max_len is not None:
return merged[:max_len]
else:
@ -95,7 +101,7 @@ def merge_distributed(data_list, max_len=None):
else:
data_list=[]
for d in zip(*merged):
data = torch.cat(d)
data = torch.cat([x.cpu() for x in d])
if max_len is not None:
data = data[:max_len]
data_list.append(data)
@ -163,8 +169,13 @@ def run_eval(args, model, device, eval_data, prefix=None, tag=None, steps=None):
for batch in tqdm(AsyncDataLoader(eval_dataloader), ncols=80, desc='Evaluating: {}'.format(prefix), disable=no_tqdm):
batch = batch_to(batch, device)
with torch.no_grad():
logits, tmp_eval_loss = model(**batch)
label_ids = batch['labels'].to(device)
output = model(**batch)
logits = output['logits'].detach()
tmp_eval_loss = output['loss'].detach()
if 'labels' in output:
label_ids = output['labels'].detach().to(device)
else:
label_ids = batch['labels'].to(device)
predicts.append(logits)
labels.append(label_ids)
eval_loss += tmp_eval_loss.mean().item()
@ -198,8 +209,9 @@ def run_predict(args, model, device, eval_data, prefix=None):
for batch in tqdm(AsyncDataLoader(eval_dataloader), ncols=80, desc='Evaluating: {}'.format(prefix), disable=args.rank>0):
batch = batch_to(batch, device)
with torch.no_grad():
logits, _ = model(**batch)
predicts.append(logits)
output = model(**batch)
logits = output['logits']
predicts.append(logits)
predicts = merge_distributed(predicts, len(eval_item.data))
if args.rank<=0:
predict_fn = eval_item.predict_fn
@ -238,7 +250,7 @@ def main(args):
logger.info(" Prediction batch size = %d", args.predict_batch_size)
if args.do_train:
train_data = task.train_data(max_seq_len=args.max_seq_length, mask_gen = None, debug=args.debug)
train_data = task.train_data(max_seq_len=args.max_seq_length, debug=args.debug)
model_class_fn = task.get_model_class_fn()
model = create_model(args, len(label_list), model_class_fn)
if args.do_train:

View File

@ -0,0 +1,231 @@
#
# Author: penhe@microsoft.com
# Date: 01/25/2019
#
from glob import glob
from collections import OrderedDict,defaultdict
from collections.abc import Sequence
from bisect import bisect
import copy
import math
from scipy.special import softmax
import numpy as np
import pdb
import os
import sys
import csv
import random
import torch
import re
import ujson as json
from .metrics import *
from .task import EvalData, Task
from .task_registry import register_task
from ...utils import xtqdm as tqdm
from ...data import ExampleInstance, ExampleSet, DynamicDataset,example_to_feature
from ...data.example import _truncate_segments
from ...data.example import *
from ...utils import get_logger
from ..models import MaskedLanguageModel
logger=get_logger()
__all__ = ["MLMTask"]
class NGramMaskGenerator:
"""
Mask contiguous ngram spans of tokens, following the span-masking utilities in XLNet:
https://github.com/zihangdai/xlnet/blob/0b642d14dd8aec7f1e1ecbf7d6942d5faa6be1f0/data_utils.py
"""
def __init__(self, tokenizer, mask_lm_prob=0.15, max_seq_len=512, max_preds_per_seq=None, max_gram = 3, keep_prob = 0.1, mask_prob=0.8, **kwargs):
self.tokenizer = tokenizer
self.mask_lm_prob = mask_lm_prob
self.keep_prob = keep_prob
self.mask_prob = mask_prob
assert self.mask_prob+self.keep_prob<=1, f'The sum of the probability of using [MASK] ({mask_prob}) and the probability of keeping the original token ({keep_prob}) must not exceed 1'
self.max_preds_per_seq = max_preds_per_seq
if max_preds_per_seq is None:
self.max_preds_per_seq = math.ceil(max_seq_len*mask_lm_prob /10)*10
self.max_gram = max_gram
self.mask_window = int(1/mask_lm_prob) # sample one ngram from each window-sized span of candidate tokens
self.vocab_words = list(tokenizer.vocab.keys())
def mask_tokens(self, tokens, rng, **kwargs):
special_tokens = ['[MASK]', '[CLS]', '[SEP]', '[PAD]', '[UNK]'] # + self.tokenizer.tokenize(' ')
indices = [i for i in range(len(tokens)) if tokens[i] not in special_tokens]
ngrams = np.arange(1, self.max_gram + 1, dtype=np.int64)
pvals = 1. / np.arange(1, self.max_gram + 1)
pvals /= pvals.sum(keepdims=True)
unigrams = []
for id in indices:
if len(unigrams)>=1 and self.tokenizer.part_of_whole_word(tokens[id]):
unigrams[-1].append(id)
else:
unigrams.append([id])
num_to_predict = min(self.max_preds_per_seq, max(1, int(round(len(tokens) * self.mask_lm_prob))))
mask_len = 0
offset = 0
mask_grams = np.array([False]*len(unigrams))
while offset < len(unigrams):
n = self._choice(rng, ngrams, p=pvals)
ctx_size = min(n*self.mask_window, len(unigrams)-offset)
m = rng.randint(0, ctx_size-1)
s = offset + m
e = min(offset+m+n, len(unigrams))
offset = max(offset+ctx_size, e)
mask_grams[s:e] = True
target_labels = [None]*len(tokens)
w_cnt = 0
for m,word in zip(mask_grams, unigrams):
if m:
for idx in word:
label = self._mask_token(idx, tokens, rng, self.mask_prob, self.keep_prob)
target_labels[idx] = label
w_cnt += 1
if w_cnt >= num_to_predict:
break
target_labels = [self.tokenizer.vocab[x] if x else 0 for x in target_labels]
return tokens, target_labels
def _choice(self, rng, data, p):
cul = np.cumsum(p)
x = rng.random()*cul[-1]
id = bisect(cul, x)
return data[id]
def _mask_token(self, idx, tokens, rng, mask_prob, keep_prob):
label = tokens[idx]
mask = '[MASK]'
rand = rng.random()
if rand < mask_prob:
new_label = mask
elif rand < mask_prob+keep_prob:
new_label = label
else:
new_label = rng.choice(self.vocab_words)
tokens[idx] = new_label
return label
@register_task(name="MLM", desc="Masked language model pretraining task")
class MLMTask(Task):
def __init__(self, data_dir, tokenizer, args, **kwargs):
super().__init__(tokenizer, args, **kwargs)
self.data_dir = data_dir
self.mask_gen = NGramMaskGenerator(tokenizer, max_gram=self.args.max_ngram)
def train_data(self, max_seq_len=512, **kwargs):
data = self.load_data(os.path.join(self.data_dir, 'train.txt'))
examples = ExampleSet(data)
if self.args.num_training_steps is None:
dataset_size = len(examples)
else:
dataset_size = self.args.num_training_steps*self.args.train_batch_size
return DynamicDataset(examples, feature_fn = self.get_feature_fn(max_seq_len=max_seq_len, mask_gen=self.mask_gen), \
dataset_size = dataset_size, shuffle=True, **kwargs)
def get_labels(self):
return list(self.tokenizer.vocab.values())
def eval_data(self, max_seq_len=512, **kwargs):
ds = [
self._data('dev', 'valid.txt', 'dev'),
]
for d in ds:
_size = len(d.data)
d.data = DynamicDataset(d.data, feature_fn = self.get_feature_fn(max_seq_len=max_seq_len, mask_gen=self.mask_gen), dataset_size = _size, **kwargs)
return ds
def test_data(self, max_seq_len=512, **kwargs):
"""See base class."""
raise NotImplementedError('This method is not implemented yet.')
def _data(self, name, path, type_name = 'dev', ignore_metric=False):
if isinstance(path, str):
path = [path]
data = []
for p in path:
input_src = os.path.join(self.data_dir, p)
assert os.path.exists(input_src), f"{input_src} doesn't exist"
data.extend(self.load_data(input_src))
predict_fn = self.get_predict_fn()
examples = ExampleSet(data)
return EvalData(name, examples,
metrics_fn = self.get_metrics_fn(), predict_fn = predict_fn, ignore_metric=ignore_metric, critial_metrics=['accuracy'])
def get_metrics_fn(self):
"""Calcuate metrics based on prediction results"""
def metrics_fn(logits, labels):
preds = np.argmax(logits, axis=-1)
acc = (preds==labels).sum()/len(labels)
metrics = OrderedDict(accuracy= acc)
return metrics
return metrics_fn
def load_data(self, path):
examples = []
with open(path, encoding='utf-8') as fs:
for l in fs:
if len(l) > 1:
example = ExampleInstance(segments=[l])
examples.append(example)
return examples
def get_feature_fn(self, max_seq_len = 512, mask_gen = None):
def _example_to_feature(example, rng=None, ext_params=None, **kwargs):
return self.example_to_feature(self.tokenizer, example, max_seq_len = max_seq_len, \
rng = rng, mask_generator = mask_gen, ext_params = ext_params, **kwargs)
return _example_to_feature
def example_to_feature(self, tokenizer, example, max_seq_len=512, rng=None, mask_generator = None, ext_params=None, **kwargs):
if not rng:
rng = random
max_num_tokens = max_seq_len - 2
segments = [ example.segments[0].strip().split() ]
segments = _truncate_segments(segments, max_num_tokens, rng)
_tokens = ['[CLS]'] + segments[0] + ['[SEP]']
if mask_generator:
tokens, lm_labels = mask_generator.mask_tokens(_tokens, rng)
token_ids = tokenizer.convert_tokens_to_ids(tokens)
features = OrderedDict(input_ids = token_ids,
position_ids = list(range(len(token_ids))),
input_mask = [1]*len(token_ids),
labels = lm_labels)
for f in features:
features[f] = torch.tensor(features[f] + [0]*(max_seq_len - len(token_ids)), dtype=torch.int)
return features
def get_model_class_fn(self):
def partial_class(*wargs, **kwargs):
return MaskedLanguageModel.load_model(*wargs, **kwargs)
return partial_class
@classmethod
def add_arguments(cls, parser):
"""Add task specific arguments
e.g. parser.add_argument('--data_dir', type=str, help='The path of data directory.')
"""
parser.add_argument('--max_ngram', type=int, default=1, help='Maximum ngram sampling span')
parser.add_argument('--num_training_steps', type=int, default=None, help='Maximum number of pre-training steps')
def test_MLM():
from ...deberta import tokenizers,load_vocab
import pdb
vocab_path, vocab_type = load_vocab(vocab_path = None, vocab_type = 'spm', pretrained_id = 'xlarge-v2')
tokenizer = tokenizers[vocab_type](vocab_path)
mask_gen = NGramMaskGenerator(tokenizer, max_gram=1)
mlm = MLMTask('/mnt/penhe/data/wiki103/spm', tokenizer, None)
train_data = mlm.train_data(mask_gen = mask_gen)
pdb.set_trace()
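For orientation, a self-contained sketch of how the NGramMaskGenerator above is driven. TinyTokenizer is a made-up stand-in that exposes only the two members the generator actually touches (vocab and part_of_whole_word); real runs use the SentencePiece tokenizer loaded in test_MLM, and NGramMaskGenerator is assumed importable from the file above:

import random

class TinyTokenizer:
    """Hypothetical stand-in; real code uses DeBERTa's SPM tokenizer."""
    def __init__(self, words):
        specials = ['[PAD]', '[CLS]', '[SEP]', '[UNK]', '[MASK]']
        self.vocab = {w: i for i, w in enumerate(specials + words)}
    def part_of_whole_word(self, token):
        # SentencePiece continuation pieces would return True; treat everything as a whole word here
        return False

tok = TinyTokenizer(['the', 'cat', 'sat', 'on', 'mat'])
mask_gen = NGramMaskGenerator(tok, mask_lm_prob=0.15, max_seq_len=16, max_gram=1)
tokens = ['[CLS]', 'the', 'cat', 'sat', 'on', 'the', 'mat', '[SEP]']
masked, labels = mask_gen.mask_tokens(tokens, random.Random(0))
print(masked)  # one token replaced by [MASK] (or kept / randomly replaced)
print(labels)  # vocab id of the original token at masked positions, 0 elsewhere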

View File

@ -21,7 +21,7 @@ from .ops import *
from .disentangled_attention import *
from .da_utils import *
__all__ = ['BertEncoder', 'BertEmbeddings', 'ACT2FN', 'LayerNorm']
__all__ = ['BertEncoder', 'BertEmbeddings', 'ACT2FN', 'LayerNorm', 'BertLMPredictionHead']
class BertSelfOutput(nn.Module):
def __init__(self, config):
@ -184,7 +184,7 @@ class BertEncoder(nn.Module):
relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos)
all_encoder_layers = []
att_matrixs = []
att_matrices = []
if isinstance(hidden_states, Sequence):
next_kv = hidden_states[0]
else:
@ -209,15 +209,15 @@ class BertEncoder(nn.Module):
if output_all_encoded_layers:
all_encoder_layers.append(output_states)
if return_att:
att_matrixs.append(att_m)
att_matrices.append(att_m)
if not output_all_encoded_layers:
all_encoder_layers.append(output_states)
if return_att:
att_matrixs.append(att_m)
if return_att:
return (all_encoder_layers, att_matrixs)
else:
return all_encoder_layers
att_matrices.append(att_m)
return {
'hidden_states': all_encoder_layers,
'attention_matrices': att_matrices
}
class BertEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings.
@ -229,10 +229,7 @@ class BertEmbeddings(nn.Module):
self.word_embeddings = nn.Embedding(config.vocab_size, self.embedding_size, padding_idx = padding_idx)
self.position_biased_input = getattr(config, 'position_biased_input', True)
if not self.position_biased_input:
self.position_embeddings = None
else:
self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.embedding_size)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.embedding_size)
if config.type_vocab_size>0:
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, self.embedding_size)
@ -253,14 +250,9 @@ class BertEmbeddings(nn.Module):
token_type_ids = torch.zeros_like(input_ids)
words_embeddings = self.word_embeddings(input_ids)
if self.position_embeddings is not None:
position_embeddings = self.position_embeddings(position_ids.long())
else:
position_embeddings = torch.zeros_like(words_embeddings)
position_embeddings = self.position_embeddings(position_ids.long())
embeddings = words_embeddings
if self.position_biased_input:
embeddings += position_embeddings
if self.config.type_vocab_size>0:
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings += token_type_embeddings
@ -270,4 +262,28 @@ class BertEmbeddings(nn.Module):
embeddings = MaskedLayerNorm(self.LayerNorm, embeddings, mask)
embeddings = self.dropout(embeddings)
return embeddings
return {
'embeddings': embeddings,
'position_embeddings': position_embeddings}
class BertLMPredictionHead(nn.Module):
def __init__(self, config, vocab_size):
super().__init__()
self.embedding_size = getattr(config, 'embedding_size', config.hidden_size)
self.dense = nn.Linear(config.hidden_size, self.embedding_size)
self.transform_act_fn = ACT2FN[config.hidden_act] \
if isinstance(config.hidden_act, str) else config.hidden_act
self.LayerNorm = LayerNorm(self.embedding_size, config.layer_norm_eps, elementwise_affine=True)
self.bias = nn.Parameter(torch.zeros(vocab_size))
def forward(self, hidden_states, embedding_weight):
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
# b x s x d
hidden_states = MaskedLayerNorm(self.LayerNorm, hidden_states)
# b x s x v
logits = torch.matmul(hidden_states, embedding_weight.t().to(hidden_states)) + self.bias
return logits
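The prediction head above deliberately takes the word embedding matrix as a forward argument instead of owning its own output projection, so the MLM logits are tied to the input embeddings and only the bias vector adds parameters. A toy shape check of that tied projection (standalone sketch; it skips the dense/activation/LayerNorm transform and assumes embedding_size == hidden_size):

import torch

batch, seq, hidden, vocab = 2, 8, 16, 100
hidden_states = torch.randn(batch, seq, hidden)   # decoder output, b x s x d
embedding_weight = torch.randn(vocab, hidden)     # shared word embedding table, v x d
bias = torch.zeros(vocab)

logits = torch.matmul(hidden_states, embedding_weight.t()) + bias  # b x s x v
assert logits.shape == (batch, seq, vocab)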

View File

@ -110,19 +110,13 @@ class DeBERTa(torch.nn.Module):
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids)
embedding_output = self.embeddings(input_ids.to(torch.long), token_type_ids.to(torch.long), position_ids, attention_mask)
encoded_layers = self.encoder(embedding_output,
ebd_output = self.embeddings(input_ids.to(torch.long), token_type_ids.to(torch.long), position_ids, attention_mask)
embedding_output = ebd_output['embeddings']
encoder_output = self.encoder(embedding_output,
attention_mask,
output_all_encoded_layers=output_all_encoded_layers, return_att = return_att)
if return_att:
encoded_layers, att_matrixs = encoded_layers
if not output_all_encoded_layers:
encoded_layers = encoded_layers[-1:]
if return_att:
return encoded_layers, att_matrixs
return encoded_layers
encoder_output.update(ebd_output)
return encoder_output
def apply_state(self, state = None):
""" Load state from previous loaded model state dictionary.

View File

@ -92,8 +92,9 @@ class NNModule(nn.Module):
config = None
model_config = None
model_state = None
if model_path and model_path.strip() == '-' or model_path.strip()=='':
if (model_path is not None) and (model_path.strip() == '-' or model_path.strip()==''):
model_path = None
try:
model_state, model_config = load_model_state(model_path, tag=tag, no_cache=no_cache, cache_dir=cache_dir)
except Exception as exp:

View File

@ -0,0 +1,69 @@
#!/bin/bash
SCRIPT=$(readlink -f "$0")
SCRIPT_DIR=$(dirname "$SCRIPT")
cd $SCRIPT_DIR
cache_dir=/tmp/DeBERTa/MLM/
function setup_wiki_data(){
task=$1
mkdir -p $cache_dir
if [[ ! -e $cache_dir/spm.model ]]; then
wget -q https://huggingface.co/microsoft/deberta-v2-xxlarge/resolve/main/spm.model -O $cache_dir/spm.model
fi
if [[ ! -e $cache_dir/wiki103.zip ]]; then
wget -q https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip -O $cache_dir/wiki103.zip
unzip -j $cache_dir/wiki103.zip -d $cache_dir/wiki103
mkdir -p $cache_dir/wiki103/spm
python ./prepare_data.py -i $cache_dir/wiki103/wiki.train.tokens -o $cache_dir/wiki103/spm/train.txt
python ./prepare_data.py -i $cache_dir/wiki103/wiki.valid.tokens -o $cache_dir/wiki103/spm/valid.txt
python ./prepare_data.py -i $cache_dir/wiki103/wiki.test.tokens -o $cache_dir/wiki103/spm/test.txt
fi
}
setup_wiki_data
Task=MLM
init=$1
tag=$init
case ${init,,} in
xlarge-v2)
parameters=" --num_train_epochs 1 \
--model_config xlarge.json \
--warmup 1000 \
--learning_rate 1e-4 \
--train_batch_size 32 \
--max_ngram 3 \
--fp16 True "
;;
xxlarge-v2)
parameters=" --num_train_epochs 1 \
--warmup 1000 \
--model_config xxlarge.json \
--learning_rate 1e-4 \
--train_batch_size 32 \
--max_ngram 3 \
--fp16 True "
;;
*)
echo "usage $0 <Pretrained model configuration>"
echo "Supported configurations"
echo "xlarge-v2 - Pretrained DeBERTa v2 model with 900M parameters (24 layers, 1536 hidden size)"
echo "xxlarge-v2 - Pretrained DeBERTa v2 model with 1.5B parameters (48 layers, 1536 hidden size)"
exit 0
;;
esac
data_dir=$cache_dir/wiki103/spm
python -m DeBERTa.apps.run --model_config config.json \
--tag $tag \
--do_train \
--num_training_steps 10000 \
--max_seq_len 512 \
--task_name $Task \
--data_dir $data_dir \
--vocab_path $cache_dir/spm.model \
--vocab_type spm \
--output_dir /tmp/ttonly/$tag/$Task $parameters

View File

@ -0,0 +1,36 @@
# coding: utf-8
from DeBERTa import deberta
import sys
import argparse
from tqdm import tqdm
def tokenize_data(input, output=None):
p,t=deberta.load_vocab(vocab_path=None, vocab_type='spm', pretrained_id='xlarge-v2')
tokenizer=deberta.tokenizers[t](p)
if output is None:
output=input + '.spm'
all_tokens = []
with open(input, encoding = 'utf-8') as fs:
for l in tqdm(fs, ncols=80, desc='Loading'):
if len(l) > 0:
tokens = tokenizer.tokenize(l)
else:
tokens = []
all_tokens.extend(tokens)
print(f'Loaded {len(all_tokens)} tokens from {input}')
lines = 0
with open(output, 'w', encoding = 'utf-8') as wfs:
idx = 0
while idx < len(all_tokens):
wfs.write(' '.join(all_tokens[idx:idx+510]) + '\n')
idx += 510
lines += 1
print(f'Saved {lines} lines to {output}')
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--input', required=True, help='The input data path')
parser.add_argument('-o', '--output', default=None, help='The output data path')
args = parser.parse_args()
tokenize_data(args.input, args.output)

View File

@ -0,0 +1,24 @@
{
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 1536,
"initializer_range": 0.02,
"intermediate_size": 6144,
"max_position_embeddings": 512,
"relative_attention": true,
"position_buckets": 256,
"norm_rel_ebd": "layer_norm",
"share_att_key": true,
"pos_att_type": "p2c|c2p",
"layer_norm_eps": 1e-7,
"conv_kernel_size": 3,
"conv_act": "gelu",
"max_relative_positions": -1,
"position_biased_input": false,
"num_attention_heads": 24,
"attention_head_size": 64,
"num_hidden_layers": 24,
"type_vocab_size": 0,
"vocab_size": 128100
}

View File

@ -0,0 +1,24 @@
{
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 1536,
"initializer_range": 0.02,
"intermediate_size": 6144,
"max_position_embeddings": 512,
"relative_attention": true,
"position_buckets": 256,
"norm_rel_ebd": "layer_norm",
"share_att_key": true,
"pos_att_type": "p2c|c2p",
"layer_norm_eps": 1e-7,
"conv_kernel_size": 3,
"conv_act": "gelu",
"max_relative_positions": -1,
"position_biased_input": false,
"num_attention_heads": 24,
"attention_head_size": 64,
"num_hidden_layers": 48,
"type_vocab_size": 0,
"vocab_size": 128100
}
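The two model configs above are identical except for num_hidden_layers: 24 for xlarge-v2 versus 48 for xxlarge-v2. A small sanity-check sketch, assuming the files are saved as xlarge.json and xxlarge.json, the names passed to --model_config in the training script:

import json

with open('xlarge.json') as f:
    xl = json.load(f)
with open('xxlarge.json') as f:
    xxl = json.load(f)

# report every key whose value differs between the two configs
diff = {k: (xl[k], xxl[k]) for k in xl if xl[k] != xxl.get(k)}
print(diff)  # expected: {'num_hidden_layers': (24, 48)}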