2. Remove torch version dependencies
3. Support saving GPT2Tokenizer to a local file.
Lu Wang 2020-06-29 11:45:21 -07:00 committed by Pengcheng He
Parent a0e332fe61
Commit ae38d6fccc
7 changed files: 117 additions and 55 deletions

View file

@@ -33,7 +33,7 @@ class SequenceClassificationModel(NNModule):
         self.pooler = ContextPooler(pool_config)
         output_dim = self.pooler.output_dim()
-        self.classifier = nn.Linear(output_dim, num_labels)
+        self.classifier = torch.nn.Linear(output_dim, num_labels)
         drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
         self.dropout = StableDropout(drop_out)
         self.apply(self.init_weights)
@@ -49,7 +49,7 @@ class SequenceClassificationModel(NNModule):
         if labels is not None:
             if self.num_labels ==1:
                 # regression task
-                loss_fn = nn.MSELoss()
+                loss_fn = torch.nn.MSELoss()
                 logits=logits.view(-1).to(labels.dtype)
                 loss = loss_fn(logits, labels.view(-1))
             elif labels.dim()==1 or labels.size(-1)==1:

View file

@@ -159,8 +159,8 @@ class BertEncoder(nn.Module):
     def get_attention_mask(self, attention_mask):
         if attention_mask.dim()<=2:
             extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
-            att_mask = extended_attention_mask.byte()
-            attention_mask = att_mask*att_mask.squeeze(-2).unsqueeze(-1)
+            attention_mask = extended_attention_mask*extended_attention_mask.squeeze(-2).unsqueeze(-1)
+            attention_mask = attention_mask.byte()
         elif attention_mask.dim()==3:
             attention_mask = attention_mask.unsqueeze(1)
@@ -169,7 +169,7 @@ class BertEncoder(nn.Module):
     def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None):
         if self.relative_attention and relative_pos is None:
             q = query_states.size(-2) if query_states is not None else hidden_states.size(-2)
-            relative_pos = build_relative_position(q, hidden_states.size(-2))
+            relative_pos = build_relative_position(q, hidden_states.size(-2), hidden_states.device)
         return relative_pos

     def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True, return_att=False, query_states = None, relative_pos=None):
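
For reference, a minimal sketch of what the reworked get_attention_mask path computes for a 2-D padding mask (standalone illustration; names outside the diff are placeholders):

import torch

def expand_attention_mask(attention_mask):
    # attention_mask: [batch, seq_len], 1 for real tokens, 0 for padding.
    extended = attention_mask.unsqueeze(1).unsqueeze(2)            # [batch, 1, 1, seq_len]
    # Outer product of the mask with itself -> pairwise [batch, 1, seq_len, seq_len] mask;
    # the byte() cast is now applied once, after the multiplication.
    pairwise = extended * extended.squeeze(-2).unsqueeze(-1)
    return pairwise.byte()

print(expand_attention_mask(torch.tensor([[1, 1, 1, 0]])).shape)   # torch.Size([1, 1, 4, 4])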

View file

@@ -11,14 +11,14 @@
   Disentangled SelfAttention module
 """

 import numpy as np
 import torch
 import math
 from .ops import *

 __all__=['build_relative_position', 'DisentangledSelfAttention']

-def build_relative_position(query_size, key_size):
+def build_relative_position(query_size, key_size, device):
     """ Build relative position according to the query and key

     We assume the absolute position of query :math:`P_q` is range from (0, query_size) and the absolute position of key :math:`P_k` is range from (0, key_size),
@@ -35,14 +35,29 @@ def build_relative_position(query_size, key_size):
     """
-    q_ids = np.arange(0, query_size)
-    k_ids = np.arange(0, key_size)
-    rel_pos_ids = q_ids[:, None] - np.tile(k_ids, (q_ids.shape[0],1))
-    rel_pos_ids = torch.tensor(rel_pos_ids, dtype=torch.long)
+    q_ids = torch.arange(query_size, dtype=torch.long, device=device)
+    k_ids = torch.arange(key_size, dtype=torch.long, device=device)
+    rel_pos_ids = q_ids[:, None] - k_ids.view(1, -1).repeat(query_size, 1)
     rel_pos_ids = rel_pos_ids[:query_size, :]
     rel_pos_ids = rel_pos_ids.unsqueeze(0)
     return rel_pos_ids

+@torch.jit.script
+def c2p_dynamic_expand(c2p_pos, query_layer, relative_pos):
+    return c2p_pos.expand([query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)])
+
+@torch.jit.script
+def p2c_dynamic_expand(c2p_pos, query_layer, key_layer):
+    return c2p_pos.expand([query_layer.size(0), query_layer.size(1), key_layer.size(-2), key_layer.size(-2)])
+
+@torch.jit.script
+def pos_dynamic_expand(pos_index, p2c_att, key_layer):
+    return pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2)))
+
 class DisentangledSelfAttention(torch.nn.Module):
     """ Disentangled self-attention module
@@ -176,7 +191,7 @@ class DisentangledSelfAttention(torch.nn.Module):
     def disentangled_att_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor):
         if relative_pos is None:
             q = query_layer.size(-2)
-            relative_pos = build_relative_position(q, key_layer.size(-2))
+            relative_pos = build_relative_position(q, key_layer.size(-2), query_layer.device)
         if relative_pos.dim()==2:
             relative_pos = relative_pos.unsqueeze(0).unsqueeze(0)
         elif relative_pos.dim()==3:
@@ -201,14 +216,14 @@ class DisentangledSelfAttention(torch.nn.Module):
         if 'c2p' in self.pos_att_type:
             c2p_att = torch.matmul(query_layer, pos_key_layer.transpose(-1, -2))
             c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span*2-1)
-            c2p_att = torch.gather(c2p_att, dim=-1, index=c2p_pos.expand([query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)]))
+            c2p_att = torch.gather(c2p_att, dim=-1, index=c2p_dynamic_expand(c2p_pos, query_layer, relative_pos))
             score += c2p_att

         # position->content
         if 'p2c' in self.pos_att_type or 'p2p' in self.pos_att_type:
             pos_query_layer /= math.sqrt(pos_query_layer.size(-1)*scale_factor)
             if query_layer.size(-2) != key_layer.size(-2):
-                r_pos = build_relative_position(key_layer.size(-2), key_layer.size(-2)).to(query_layer.device)
+                r_pos = build_relative_position(key_layer.size(-2), key_layer.size(-2), query_layer.device)
             else:
                 r_pos = relative_pos
             p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span*2-1)
@@ -217,9 +232,9 @@ class DisentangledSelfAttention(torch.nn.Module):
         if 'p2c' in self.pos_att_type:
             p2c_att = torch.matmul(key_layer, pos_query_layer.transpose(-1, -2))
-            p2c_att = torch.gather(p2c_att, dim=-1, index=p2c_pos.expand([query_layer.size(0), query_layer.size(1), key_layer.size(-2), key_layer.size(-2)])).transpose(-1,-2)
+            p2c_att = torch.gather(p2c_att, dim=-1, index=p2c_dynamic_expand(p2c_pos, query_layer, key_layer)).transpose(-1,-2)
             if query_layer.size(-2) != key_layer.size(-2):
-                p2c_att = torch.gather(p2c_att, dim=-2, index=pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2))))
+                p2c_att = torch.gather(p2c_att, dim=-2, index=pos_dynamic_expand(pos_index, p2c_att, key_layer))
             score += p2c_att

         return score
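
To illustrate the device-aware build_relative_position above, a condensed sketch of the new torch-only construction and its output (CPU device chosen here only for the example):

import torch

def build_relative_position(query_size, key_size, device):
    # Relative distance q_i - k_j for every query/key pair, created directly on
    # the target device instead of going through numpy on the host.
    q_ids = torch.arange(query_size, dtype=torch.long, device=device)
    k_ids = torch.arange(key_size, dtype=torch.long, device=device)
    rel_pos_ids = q_ids[:, None] - k_ids.view(1, -1).repeat(query_size, 1)
    return rel_pos_ids.unsqueeze(0)    # [1, query_size, key_size]

print(build_relative_position(3, 3, torch.device('cpu')))
# tensor([[[ 0, -1, -2],
#          [ 1,  0, -1],
#          [ 2,  1,  0]]])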

View file

@@ -46,25 +46,26 @@ class GPT2Tokenizer(object):
     """
     def __init__(self, vocab_file=None, do_lower_case=True, special_tokens=None):
-        pad='[PAD]'
-        eos='[SEP]'
-        unk='[UNK]'
-        bos='[CLS]'
+        self.pad_token='[PAD]'
+        self.sep_token='[SEP]'
+        self.unk_token='[UNK]'
+        self.cls_token='[CLS]'
         self.symbols = []
         self.count = []
         self.indices = {}
-        self.add_symbol(pad)
-        self.add_symbol(bos)
-        self.add_symbol(eos)
-        self.add_symbol(unk)
+        self.pad_token_id = self.add_symbol(self.pad_token)
+        self.cls_token_id = self.add_symbol(self.cls_token)
+        self.sep_token_id = self.add_symbol(self.sep_token)
+        self.unk_token_id = self.add_symbol(self.unk_token)

-        gpt2_encoder = load_vocab(vocab_file)
-        self.bpe = get_encoder(gpt2_encoder['encoder'], gpt2_encoder['vocab'])
-        for w,n in gpt2_encoder['dict_map']:
+        self.gpt2_encoder = load_vocab(vocab_file)
+        self.bpe = get_encoder(self.gpt2_encoder['encoder'], self.gpt2_encoder['vocab'])
+        for w,n in self.gpt2_encoder['dict_map']:
             self.add_symbol(w, n)

-        self.mask_id = self.add_symbol('[MASK]')
+        self.mask_token='[MASK]'
+        self.mask_id = self.add_symbol(self.mask_token)
         self.special_tokens = ['[MASK]', '[SEP]', '[PAD]', '[UNK]', '[CLS]']
         if special_tokens is not None:
             for t in special_tokens:
@@ -211,3 +212,5 @@ class GPT2Tokenizer(object):
         self.count.append(n)
         return idx
+
+    def save_pretrained(self, path: str):
+        torch.save(self.gpt2_encoder, path)
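
The new save_pretrained simply serializes the tokenizer's stored GPT-2 encoder state with torch.save, so the vocabulary can be written to and reloaded from a local file. A usage sketch (file names and import path are illustrative, and it assumes load_vocab reads the same torch-serialized format as the downloaded vocabulary):

from DeBERTa.deberta.gpt2_tokenizer import GPT2Tokenizer  # import path is illustrative

tokenizer = GPT2Tokenizer(vocab_file='bpe_encoder.bin')    # previously downloaded vocab
tokens = tokenizer.tokenize('Hello world!')

# New in this commit: persist the vocabulary locally ...
tokenizer.save_pretrained('local_bpe_encoder.bin')

# ... and rebuild an equivalent tokenizer from the saved file later.
restored = GPT2Tokenizer(vocab_file='local_bpe_encoder.bin')
assert restored.tokenize('Hello world!') == tokens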

View file

@@ -10,6 +10,7 @@
 import math
 from packaging import version
 import torch
+from ..utils.jit_tracing import traceable

 if version.Version(torch.__version__) >= version.Version('1.0.0'):
     from torch import _softmax_backward_data as _softmax_backward_data
@@ -18,6 +19,7 @@ else:
 __all__ = ['StableDropout', 'MaskedLayerNorm', 'XSoftmax']

+@traceable
 class XSoftmax(torch.autograd.Function):
     """ Masked Softmax which is optimized for saving memory
@@ -46,9 +48,9 @@ class XSoftmax(torch.autograd.Function):
         self.dim = dim
         if version.Version(torch.__version__) >= version.Version('1.2.0a'):
-            rmask = (1-mask).bool()
+            rmask = ~(mask.bool())
         else:
-            rmask = (1-mask).byte()
+            rmask = (1-mask).byte() # This line is not supported by Onnx tracing.
         output = input.masked_fill(rmask, float('-inf'))
         output = torch.softmax(output, self.dim)
@@ -72,10 +74,32 @@ class DropoutContext(object):
         self.scale = 1
         self.reuse_mask = True

+def get_mask(input, local_context):
+    if not isinstance(local_context, DropoutContext):
+        dropout = local_context
+        mask = None
+    else:
+        dropout = local_context.dropout
+        dropout *= local_context.scale
+        mask = local_context.mask if local_context.reuse_mask else None
+
+    if dropout>0 and mask is None:
+        if version.Version(torch.__version__) >= version.Version('1.2.0a'):
+            mask=(1-torch.empty_like(input).bernoulli_(1-dropout)).bool()
+        else:
+            mask=(1-torch.empty_like(input).bernoulli_(1-dropout)).byte()
+
+    if isinstance(local_context, DropoutContext):
+        if local_context.mask is None:
+            local_context.mask = mask
+
+    return mask, dropout
+
+@traceable
 class XDropout(torch.autograd.Function):
     @staticmethod
     def forward(ctx, input, local_ctx):
-        mask, dropout = XDropout.get_mask(input, local_ctx)
+        mask, dropout = get_mask(input, local_ctx)
         ctx.scale=1.0/(1-dropout)
         if dropout>0:
             ctx.save_for_backward(mask)
@@ -91,28 +115,6 @@ class XDropout(torch.autograd.Function):
         else:
             return grad_output, None

-    @staticmethod
-    def get_mask(input, local_context):
-        if not isinstance(local_context, DropoutContext):
-            dropout = local_context
-            mask = None
-        else:
-            dropout = local_context.dropout
-            dropout *= local_context.scale
-            mask = local_context.mask if local_context.reuse_mask else None
-
-        if dropout>0 and mask is None:
-            if version.Version(torch.__version__) >= version.Version('1.2.0a'):
-                mask=(1-torch.empty_like(input).bernoulli_(1-dropout)).bool()
-            else:
-                mask=(1-torch.empty_like(input).bernoulli_(1-dropout)).byte()
-
-        if isinstance(local_context, DropoutContext):
-            if local_context.mask is None:
-                local_context.mask = mask
-
-        return mask, dropout
-
 class StableDropout(torch.nn.Module):
     """ Optimized dropout module for stabilizing the training

View file

@@ -0,0 +1,41 @@
+"""
+ Logging util
+ @Author: penhe@microsoft.com
+"""
+""" Utils for torch jit tracing customer operators/functions
+"""
+
+import os
+
+def traceable(cls):
+    """ Decorator over customer functions
+    There is an issue for tracing customer python torch Function, using this decorator to work around it.
+
+    e.g.
+
+    @traceable
+    class MyOp(torch.autograd.Function):
+        xxx
+    """
+    class _Function(object):
+        @staticmethod
+        def apply(*args):
+            jit_trace = (os.getenv('JIT_TRACE', 'False').lower() == 'true')
+            if jit_trace:
+                return cls.forward(_Function, *args)
+            else:
+                return cls.apply(*args)
+
+        @staticmethod
+        def save_for_backward(*args):
+            pass
+
+    return _Function
+
+class TraceMode():
+    """ Trace context used when tracing modules contains customer operators/Functions
+    """
+    def __enter__(self):
+        os.environ['JIT_TRACE'] = 'True'
+        return self
+
+    def __exit__(self, exp_value, exp_type, trace):
+        del os.environ['JIT_TRACE']
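
A hedged usage sketch of the new helpers: inside a TraceMode block JIT_TRACE is set, so a Function wrapped with @traceable (such as XSoftmax above) dispatches straight to its forward, which lets torch.jit.trace record modules that use it. Import paths and the wrapper module below are assumptions for illustration:

import torch
from DeBERTa.deberta.ops import XSoftmax          # illustrative import path
from DeBERTa.utils.jit_tracing import TraceMode   # illustrative import path

class MaskedSoftmax(torch.nn.Module):
    # Minimal wrapper so there is a concrete module to trace.
    def forward(self, logits, mask):
        return XSoftmax.apply(logits, mask, -1)

logits = torch.randn(1, 1, 4, 4)
mask = torch.ones(1, 1, 4, 4, dtype=torch.long)

with TraceMode():
    # torch.jit.trace cannot record autograd.Function.apply; with JIT_TRACE set,
    # the @traceable wrapper calls XSoftmax.forward directly instead.
    traced = torch.jit.trace(MaskedSoftmax(), (logits, mask))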

View file

@@ -5,9 +5,10 @@ pytest
 regex
 scipy
 sklearn
-torch==1.3.0
-torchvision==0.3.0
+torch
+torchvision
 tqdm
 ujson
 seqeval
 psutil
+GitPython