Mirror of https://github.com/microsoft/DeBERTa.git

1. Fix DeBERTa for tracing.
2. Remove torch version dependencies.
3. Support saving GPT2Tokenizer locally.

Parent: a0e332fe61
Commit: ae38d6fccc
@@ -33,7 +33,7 @@ class SequenceClassificationModel(NNModule):
     self.pooler = ContextPooler(pool_config)
     output_dim = self.pooler.output_dim()

-    self.classifier = nn.Linear(output_dim, num_labels)
+    self.classifier = torch.nn.Linear(output_dim, num_labels)
     drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
     self.dropout = StableDropout(drop_out)
     self.apply(self.init_weights)
@@ -49,7 +49,7 @@ class SequenceClassificationModel(NNModule):
     if labels is not None:
       if self.num_labels ==1:
         # regression task
-        loss_fn = nn.MSELoss()
+        loss_fn = torch.nn.MSELoss()
         logits=logits.view(-1).to(labels.dtype)
         loss = loss_fn(logits, labels.view(-1))
       elif labels.dim()==1 or labels.size(-1)==1:
@@ -159,8 +159,8 @@ class BertEncoder(nn.Module):
   def get_attention_mask(self, attention_mask):
     if attention_mask.dim()<=2:
       extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
-      att_mask = extended_attention_mask.byte()
-      attention_mask = att_mask*att_mask.squeeze(-2).unsqueeze(-1)
+      attention_mask = extended_attention_mask*extended_attention_mask.squeeze(-2).unsqueeze(-1)
+      attention_mask = attention_mask.byte()
     elif attention_mask.dim()==3:
       attention_mask = attention_mask.unsqueeze(1)
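For a 0/1 padding mask the reordering above is equivalent: multiplying first and casting to byte afterwards yields the same pairwise mask as casting first, and the change appears to be for the benefit of the tracing/ONNX export path this commit targets. A small shape sketch (illustrative, not from the repository):

    import torch

    attention_mask = torch.tensor([[1, 1, 1, 0]])               # [batch, seq_len] padding mask
    extended = attention_mask.unsqueeze(1).unsqueeze(2)          # [batch, 1, 1, seq_len]
    pairwise = extended * extended.squeeze(-2).unsqueeze(-1)     # [batch, 1, seq_len, seq_len]
    pairwise = pairwise.byte()
    # pairwise[b, 0, i, j] == 1 only when both token i and token j are non-padding.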
@@ -169,7 +169,7 @@ class BertEncoder(nn.Module):
   def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None):
     if self.relative_attention and relative_pos is None:
       q = query_states.size(-2) if query_states is not None else hidden_states.size(-2)
-      relative_pos = build_relative_position(q, hidden_states.size(-2))
+      relative_pos = build_relative_position(q, hidden_states.size(-2), hidden_states.device)
     return relative_pos

   def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True, return_att=False, query_states = None, relative_pos=None):
@@ -11,14 +11,14 @@
 Disentangled SelfAttention module
 """

 import numpy as np
 import torch
 import math
 from .ops import *

 __all__=['build_relative_position', 'DisentangledSelfAttention']

-def build_relative_position(query_size, key_size):
+def build_relative_position(query_size, key_size, device):
   """ Build relative position according to the query and key

   We assume the absolute position of query :math:`P_q` is range from (0, query_size) and the absolute position of key :math:`P_k` is range from (0, key_size),
@@ -35,14 +35,29 @@ def build_relative_position(query_size, key_size):

   """

-  q_ids = np.arange(0, query_size)
-  k_ids = np.arange(0, key_size)
-  rel_pos_ids = q_ids[:, None] - np.tile(k_ids, (q_ids.shape[0],1))
-  rel_pos_ids = torch.tensor(rel_pos_ids, dtype=torch.long)
+  q_ids = torch.arange(query_size, dtype=torch.long, device=device)
+  k_ids = torch.arange(key_size, dtype=torch.long, device=device)
+  rel_pos_ids = q_ids[:, None] - k_ids.view(1, -1).repeat(query_size, 1)
   rel_pos_ids = rel_pos_ids[:query_size, :]
   rel_pos_ids = rel_pos_ids.unsqueeze(0)
   return rel_pos_ids


+@torch.jit.script
+def c2p_dynamic_expand(c2p_pos, query_layer, relative_pos):
+  return c2p_pos.expand([query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)])
+
+
+@torch.jit.script
+def p2c_dynamic_expand(c2p_pos, query_layer, key_layer):
+  return c2p_pos.expand([query_layer.size(0), query_layer.size(1), key_layer.size(-2), key_layer.size(-2)])
+
+
+@torch.jit.script
+def pos_dynamic_expand(pos_index, p2c_att, key_layer):
+  return pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2)))
+

 class DisentangledSelfAttention(torch.nn.Module):
   """ Disentangled self-attention module
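Two things change here for tracing: build_relative_position now creates its index tensors directly on the caller's device instead of going through numpy, and the index-expansion expressions used later in the attention bias are wrapped in @torch.jit.script helpers. When a model goes through torch.jit.trace, Python-level .size() values can be baked into the graph as constants taken from the example input; a scripted helper keeps the shape computation symbolic. A minimal sketch of the idea (names and shapes are made up):

    import torch

    @torch.jit.script
    def dynamic_expand(index, scores):
      # inside TorchScript the sizes are computed at run time, not frozen by tracing
      return index.expand([scores.size(0), scores.size(1), scores.size(2), index.size(-1)])

    def eager_expand(index, scores):
      # under torch.jit.trace these .size() calls may be recorded as constants
      # taken from the example input
      return index.expand([scores.size(0), scores.size(1), scores.size(2), index.size(-1)])

    scores = torch.randn(2, 4, 8, 16)
    index = torch.zeros(1, 1, 1, 6, dtype=torch.long)
    assert dynamic_expand(index, scores).shape == eager_expand(index, scores).shape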
@@ -176,7 +191,7 @@ class DisentangledSelfAttention(torch.nn.Module):
   def disentangled_att_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor):
     if relative_pos is None:
       q = query_layer.size(-2)
-      relative_pos = build_relative_position(q, key_layer.size(-2))
+      relative_pos = build_relative_position(q, key_layer.size(-2), query_layer.device)
     if relative_pos.dim()==2:
       relative_pos = relative_pos.unsqueeze(0).unsqueeze(0)
     elif relative_pos.dim()==3:
@@ -201,14 +216,14 @@ class DisentangledSelfAttention(torch.nn.Module):
     if 'c2p' in self.pos_att_type:
       c2p_att = torch.matmul(query_layer, pos_key_layer.transpose(-1, -2))
       c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span*2-1)
-      c2p_att = torch.gather(c2p_att, dim=-1, index=c2p_pos.expand([query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)]))
+      c2p_att = torch.gather(c2p_att, dim=-1, index=c2p_dynamic_expand(c2p_pos, query_layer, relative_pos))
       score += c2p_att

     # position->content
     if 'p2c' in self.pos_att_type or 'p2p' in self.pos_att_type:
       pos_query_layer /= math.sqrt(pos_query_layer.size(-1)*scale_factor)
       if query_layer.size(-2) != key_layer.size(-2):
-        r_pos = build_relative_position(key_layer.size(-2), key_layer.size(-2)).to(query_layer.device)
+        r_pos = build_relative_position(key_layer.size(-2), key_layer.size(-2), query_layer.device)
       else:
         r_pos = relative_pos
       p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span*2-1)
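For reference, the gather that the c2p helper feeds: scores over relative-position buckets are indexed, per (query, key) pair, by the clamped bucket id, and the index only needs to be broadcast over the batch and head dimensions. A toy illustration with made-up shapes:

    import torch

    scores = torch.randn(1, 2, 4, 6)                      # [batch, heads, query, 2*att_span]
    c2p_pos = torch.randint(0, 6, (1, 1, 4, 4))           # bucket id per (query, key) pair
    index = c2p_pos.expand(1, 2, 4, 4)                    # broadcast over batch and heads
    c2p_att = torch.gather(scores, dim=-1, index=index)   # [batch, heads, query, key]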
@@ -217,9 +232,9 @@ class DisentangledSelfAttention(torch.nn.Module):

     if 'p2c' in self.pos_att_type:
       p2c_att = torch.matmul(key_layer, pos_query_layer.transpose(-1, -2))
-      p2c_att = torch.gather(p2c_att, dim=-1, index=p2c_pos.expand([query_layer.size(0), query_layer.size(1), key_layer.size(-2), key_layer.size(-2)])).transpose(-1,-2)
+      p2c_att = torch.gather(p2c_att, dim=-1, index=p2c_dynamic_expand(p2c_pos, query_layer, key_layer)).transpose(-1,-2)
       if query_layer.size(-2) != key_layer.size(-2):
-        p2c_att = torch.gather(p2c_att, dim=-2, index=pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2))))
+        p2c_att = torch.gather(p2c_att, dim=-2, index=pos_dynamic_expand(pos_index, p2c_att, key_layer))
       score += p2c_att

     return score
@@ -46,25 +46,26 @@ class GPT2Tokenizer(object):

   """
   def __init__(self, vocab_file=None, do_lower_case=True, special_tokens=None):
-    pad='[PAD]'
-    eos='[SEP]'
-    unk='[UNK]'
-    bos='[CLS]'
+    self.pad_token='[PAD]'
+    self.sep_token='[SEP]'
+    self.unk_token='[UNK]'
+    self.cls_token='[CLS]'

     self.symbols = []
     self.count = []
     self.indices = {}
-    self.add_symbol(pad)
-    self.add_symbol(bos)
-    self.add_symbol(eos)
-    self.add_symbol(unk)
+    self.pad_token_id = self.add_symbol(self.pad_token)
+    self.cls_token_id = self.add_symbol(self.cls_token)
+    self.sep_token_id = self.add_symbol(self.sep_token)
+    self.unk_token_id = self.add_symbol(self.unk_token)

-    gpt2_encoder = load_vocab(vocab_file)
-    self.bpe = get_encoder(gpt2_encoder['encoder'], gpt2_encoder['vocab'])
-    for w,n in gpt2_encoder['dict_map']:
+    self.gpt2_encoder = load_vocab(vocab_file)
+    self.bpe = get_encoder(self.gpt2_encoder['encoder'], self.gpt2_encoder['vocab'])
+    for w,n in self.gpt2_encoder['dict_map']:
       self.add_symbol(w, n)

-    self.mask_id = self.add_symbol('[MASK]')
+    self.mask_token='[MASK]'
+    self.mask_id = self.add_symbol(self.mask_token)
     self.special_tokens = ['[MASK]', '[SEP]', '[PAD]', '[UNK]', '[CLS]']
     if special_tokens is not None:
       for t in special_tokens:
@@ -211,3 +212,5 @@ class GPT2Tokenizer(object):
     self.count.append(n)
     return idx

+  def save_pretrained(self, path: str):
+    torch.save(self.gpt2_encoder, path)
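save_pretrained simply serializes the same gpt2_encoder dict that __init__ loads via load_vocab, so a tokenizer saved this way should be reloadable by pointing the constructor at the saved file (assuming load_vocab reads torch-serialized files, which the symmetric torch.save here suggests). A hypothetical round trip with illustrative paths:

    tokenizer = GPT2Tokenizer(vocab_file='bpe_encoder.bin')
    tokenizer.save_pretrained('/tmp/bpe_encoder.bin')
    reloaded = GPT2Tokenizer(vocab_file='/tmp/bpe_encoder.bin')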
@@ -10,6 +10,7 @@
 import math
 from packaging import version
 import torch
+from ..utils.jit_tracing import traceable

 if version.Version(torch.__version__) >= version.Version('1.0.0'):
   from torch import _softmax_backward_data as _softmax_backward_data
@@ -18,6 +19,7 @@ else:

 __all__ = ['StableDropout', 'MaskedLayerNorm', 'XSoftmax']

+@traceable
 class XSoftmax(torch.autograd.Function):
   """ Masked Softmax which is optimized for saving memory
@@ -46,9 +48,9 @@ class XSoftmax(torch.autograd.Function):

     self.dim = dim
     if version.Version(torch.__version__) >= version.Version('1.2.0a'):
-      rmask = (1-mask).bool()
+      rmask = ~(mask.bool())
     else:
-      rmask = (1-mask).byte()
+      rmask = (1-mask).byte() # This line is not supported by Onnx tracing.

     output = input.masked_fill(rmask, float('-inf'))
     output = torch.softmax(output, self.dim)
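For a 0/1 mask the two expressions produce the same boolean reverse mask; the rewrite only avoids the 1-mask arithmetic that, per the in-line comment, the ONNX tracer cannot handle on this path. A quick equivalence check (illustrative):

    import torch

    mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.uint8)
    assert torch.equal(~(mask.bool()), (1 - mask).bool())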
@@ -72,10 +74,32 @@ class DropoutContext(object):
     self.scale = 1
     self.reuse_mask = True

+def get_mask(input, local_context):
+  if not isinstance(local_context, DropoutContext):
+    dropout = local_context
+    mask = None
+  else:
+    dropout = local_context.dropout
+    dropout *= local_context.scale
+    mask = local_context.mask if local_context.reuse_mask else None
+
+  if dropout>0 and mask is None:
+    if version.Version(torch.__version__) >= version.Version('1.2.0a'):
+      mask=(1-torch.empty_like(input).bernoulli_(1-dropout)).bool()
+    else:
+      mask=(1-torch.empty_like(input).bernoulli_(1-dropout)).byte()
+
+  if isinstance(local_context, DropoutContext):
+    if local_context.mask is None:
+      local_context.mask = mask
+
+  return mask, dropout
+
+@traceable
 class XDropout(torch.autograd.Function):
   @staticmethod
   def forward(ctx, input, local_ctx):
-    mask, dropout = XDropout.get_mask(input, local_ctx)
+    mask, dropout = get_mask(input, local_ctx)
     ctx.scale=1.0/(1-dropout)
     if dropout>0:
       ctx.save_for_backward(mask)
@@ -91,28 +115,6 @@ class XDropout(torch.autograd.Function):
     else:
       return grad_output, None

-  @staticmethod
-  def get_mask(input, local_context):
-    if not isinstance(local_context, DropoutContext):
-      dropout = local_context
-      mask = None
-    else:
-      dropout = local_context.dropout
-      dropout *= local_context.scale
-      mask = local_context.mask if local_context.reuse_mask else None
-
-    if dropout>0 and mask is None:
-      if version.Version(torch.__version__) >= version.Version('1.2.0a'):
-        mask=(1-torch.empty_like(input).bernoulli_(1-dropout)).bool()
-      else:
-        mask=(1-torch.empty_like(input).bernoulli_(1-dropout)).byte()
-
-    if isinstance(local_context, DropoutContext):
-      if local_context.mask is None:
-        local_context.mask = mask
-
-    return mask, dropout
-
 class StableDropout(torch.nn.Module):
   """ Optimized dropout module for stabilizing the training
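Moving get_mask out of XDropout looks necessary for the @traceable wrapper introduced in this commit: after decoration the name XDropout is rebound to the wrapper class, which carries no get_mask attribute, so the old XDropout.get_mask(...) call inside forward would no longer resolve. As a free function it still supports the mask-reuse behaviour of DropoutContext, sketched below (assuming DropoutContext initializes mask to None and dropout to 0, as the surrounding source suggests):

    import torch

    ctx = DropoutContext()        # reuse_mask is True by default
    ctx.dropout = 0.1
    x = torch.randn(4, 8)
    mask1, p = get_mask(x, ctx)   # first call samples a mask and caches it on ctx
    mask2, _ = get_mask(x, ctx)   # later calls return the cached mask
    assert mask1 is mask2 and p == 0.1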
@@ -0,0 +1,41 @@
+"""
+Logging util
+@Author: penhe@microsoft.com
+"""
+
+""" Utils for torch jit tracing customer operators/functions
+"""
+import os
+
+def traceable(cls):
+  """ Decorator over customer functions
+  There is an issue for tracing customer python torch Function, using this decorator to work around it.
+  e.g.
+    @traceable
+    class MyOp(torch.autograd.Function):
+      xxx
+  """
+  class _Function(object):
+    @staticmethod
+    def apply(*args):
+      jit_trace = (os.getenv('JIT_TRACE', 'False').lower() == 'true')
+      if jit_trace:
+        return cls.forward(_Function, *args)
+      else:
+        return cls.apply(*args)
+
+    @staticmethod
+    def save_for_backward(*args):
+      pass
+
+  return _Function
+
+class TraceMode():
+  """ Trace context used when tracing modules contains customer operators/Functions
+  """
+  def __enter__(self):
+    os.environ['JIT_TRACE'] = 'True'
+    return self
+
+  def __exit__(self, exp_value, exp_type, trace):
+    del os.environ['JIT_TRACE']
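A minimal usage sketch for the new utilities (the Square op below is made up for illustration): outside TraceMode the decorated class behaves like a normal autograd Function, while inside TraceMode the JIT_TRACE flag makes apply() call forward() directly with the stub context, which torch.jit.trace can record. The traced graph is intended for inference/export, since backward is bypassed in that mode.

    import torch

    @traceable
    class Square(torch.autograd.Function):
      @staticmethod
      def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

      @staticmethod
      def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        return 2 * x * grad_output

    def f(x):
      return Square.apply(x)

    y = f(torch.randn(3, requires_grad=True))      # normal autograd path
    with TraceMode():
      traced = torch.jit.trace(f, torch.randn(3))  # forward-only, traceable path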
@@ -5,9 +5,10 @@ pytest
 regex
 scipy
 sklearn
-torch==1.3.0
-torchvision==0.3.0
+torch
+torchvision
 tqdm
 ujson
 seqeval
 psutil
+GitPython