Mirror of https://github.com/microsoft/DeBERTa.git
Add documentation
This commit is contained in:
Parent: 009cc44b8c
Commit: 2e3e748a65
@@ -32,7 +32,7 @@ def create_model(args, num_labels, model_class_fn):
# Prepare model
rank = getattr(args, 'rank', 0)
init_model = args.init_model if rank<1 else None
model = model_class_fn(init_model, args.bert_config, num_labels=num_labels, \
model = model_class_fn(init_model, args.model_config, num_labels=num_labels, \
drop_out=args.cls_drop_out, \
pre_trained = args.pre_trained)
if args.fp16:
@@ -379,7 +379,7 @@ def build_argument_parser():
type=str,
help="The model state file used to initialize the model weights.")

parser.add_argument('--bert_config',
parser.add_argument('--model_config',
type=str,
help="The config file of bert model.")
@@ -18,3 +18,4 @@ from .disentangled_attention import *
from .ops import *
from .bert import *
from .gpt2_tokenizer import GPT2Tokenizer
from .config import *
@@ -39,9 +39,10 @@ def linear_act(x):
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish, "tanh": torch.nn.functional.tanh, "linear": linear_act, 'sigmoid': torch.sigmoid}

class BertLayerNorm(nn.Module):
def __init__(self, size, eps=1e-12):
"""Construct a layernorm module in the TF style (epsilon inside the square root).
"""LayerNorm module in the TF style (epsilon inside the square root).
"""

def __init__(self, size, eps=1e-12):
super().__init__()
self.weight = nn.Parameter(torch.ones(size))
self.bias = nn.Parameter(torch.zeros(size))
@@ -139,6 +140,8 @@ class BertLayer(nn.Module):
return layer_output

class BertEncoder(nn.Module):
""" Modified BertEncoder with relative position bias support
"""
def __init__(self, config):
super().__init__()
layer = BertLayer(config)
@@ -40,33 +40,41 @@ class AbsModelConfig(object):
return json.dumps(self.__dict__, indent=2, sort_keys=True, default=_json_default) + "\n"

class ModelConfig(AbsModelConfig):
"""Configuration class to store the configuration of a `BertModel`.
"""Configuration class to store the configuration of a :class:`~DeBERTa.deberta.DeBERTa` model.

Attributes:
hidden_size (int): Size of the encoder layers and the pooler layer, default: `768`.
num_hidden_layers (int): Number of hidden layers in the Transformer encoder, default: `12`.
num_attention_heads (int): Number of attention heads for each attention layer in
the Transformer encoder, default: `12`.
intermediate_size (int): The size of the "intermediate" (i.e., feed-forward)
layer in the Transformer encoder, default: `3072`.
hidden_act (str): The non-linear activation function (function or string) in the
encoder and pooler. If string, "gelu", "relu" and "swish" are supported, default: `gelu`.
hidden_dropout_prob (float): The dropout probability for all fully connected
layers in the embeddings, encoder, and pooler, default: `0.1`.
attention_probs_dropout_prob (float): The dropout ratio for the attention
probabilities, default: `0.1`.
max_position_embeddings (int): The maximum sequence length that this model might
ever be used with. Typically set this to something large just in case
(e.g., 512 or 1024 or 2048), default: `512`.
type_vocab_size (int): The vocabulary size of the `token_type_ids` passed into
`DeBERTa` model, default: `-1`.
initializer_range (int): The stddev of the _normal_initializer for
initializing all weight matrices, default: `0.02`.
relative_attention (:obj:`bool`): Whether to use relative position encoding, default: `False`.
max_relative_positions (int): The range of relative positions [`-max_position_embeddings`, `max_position_embeddings`], default: -1, use the same value as `max_position_embeddings`.
padding_idx (int): The value used to pad input_ids, default: `0`.
position_biased_input (:obj:`bool`): Whether to add absolute position embedding to content embedding, default: `True`.
pos_att_type (:obj:`str`): The type of relative position attention, it can be a combination of [`p2c`, `c2p`, `p2p`], e.g. "p2c", "p2c|c2p", "p2c|c2p|p2p", default: "None".


"""
def __init__(self):
"""Constructs ModelConfig.

Args:
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`.
hidden_size: Size of the encoder layers and the pooler layer.
num_hidden_layers: Number of hidden layers in the Transformer encoder.
num_attention_heads: Number of attention heads for each attention layer in
the Transformer encoder.
intermediate_size: The size of the "intermediate" (i.e., feed-forward)
layer in the Transformer encoder.
hidden_act: The non-linear activation function (function or string) in the
encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
hidden_dropout_prob: The dropout probability for all fully connected
layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob: The dropout ratio for the attention
probabilities.
max_position_embeddings: The maximum sequence length that this model might
ever be used with. Typically set this to something large just in case
(e.g., 512 or 1024 or 2048).
type_vocab_size: The vocabulary size of the `token_type_ids` passed into
`BertModel`.
initializer_range: The stddev of the truncated_normal_initializer for
initializing all weight matrices.
"""

self.hidden_size = 768
self.num_hidden_layers = 12
self.num_attention_heads = 12
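For reference, a minimal sketch of how a config following the schema above might be loaded and tweaked. The local `model_config.json` path is hypothetical; `ModelConfig.from_json_file` is the loader used elsewhere in this commit (see `nnmodule.py` below).

```python
from DeBERTa.deberta import ModelConfig

# Hypothetical local path; the JSON keys mirror the attributes documented above
config = ModelConfig.from_json_file('model_config.json')
config.relative_attention = True
config.pos_att_type = 'p2c|c2p'
print(config)  # AbsModelConfig renders the fields back as JSON (see the context line above)
```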
@@ -23,38 +23,17 @@ class DeBERTa(torch.nn.Module):
""" DeBERTa encoder
This module is composed of the input embedding layer with stacked transformer layers with disentangled attention.

Params:
`config`: A model config class instance with the configuration to build a new model. The schema is similar to BertConfig, for more details, please refer `config.py`
`pre_trained`: The pre-trained DeBERTa model, it can be a physical path of a pre-trained DeBERTa model or a released configuration, e.g. base, large, base_mnli, large_mnli
Parameters:
config:
A model config class instance with the configuration to build a new model. The schema is similar to `BertConfig`, \
for more details, please refer :class:`~DeBERTa.deberta.ModelConfig`

Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
with the word token indices in the vocabulary
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
a `sentence B` token (see BERT paper for more details).
`attention_mask`: an optional parameter for input mask or attention mask.
- If it's an input mask, then it will be torch.LongTensor of shape [batch_size, sequence_length] with indices
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
input sequence length in the current batch. It's the mask that we typically use for attention when
a batch has varying length sentences.
- If it's an attention mask then it will be torch.LongTensor of shape [batch_size, sequence_length, sequence_length]. In this case, it's a mask indicating which tokens in the sequence should be attended by other tokens in the sequence.
`output_all_encoded_layers`: whether to output results of all encoder layers, default: True

Outputs:
The output of the stacked transformer layers if `output_all_encoded_layers=True`, else
the last layer of stacked transformer layers

Example usage:
```python
# Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
attention_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
bert = DeBERTa(pre_trained='base')
encoder_layers = bert(input_ids, attention_mask=attention_mask)
```
pre_trained:
The pre-trained DeBERTa model, it can be a physical path of a pre-trained DeBERTa model or a released configuration, \
i.e. [**base, large, base_mnli, large_mnli**]

"""

def __init__(self, config=None, pre_trained=None):
super().__init__()
if config:
@@ -82,6 +61,54 @@ class DeBERTa(torch.nn.Module):
self.apply_state(state)

def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True, position_ids = None, return_att = False):
"""
Args:
input_ids:
a torch.LongTensor of shape [batch_size, sequence_length] \
with the word token indices in the vocabulary

token_type_ids:
an optional torch.LongTensor of shape [batch_size, sequence_length] with the token \
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to \
a `sentence B` token (see BERT paper for more details).

attention_mask:
an optional parameter for input mask or attention mask.

- If it's an input mask, then it will be torch.LongTensor of shape [batch_size, sequence_length] with indices \
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max \
input sequence length in the current batch. It's the mask that we typically use for attention when \
a batch has varying length sentences.

- If it's an attention mask then it will be torch.LongTensor of shape [batch_size, sequence_length, sequence_length]. \
In this case, it's a mask indicating which tokens in the sequence should be attended by other tokens in the sequence.

output_all_encoded_layers:
whether to output results of all encoder layers, default: True

Returns:

- The output of the stacked transformer layers if `output_all_encoded_layers=True`, else \
the last layer of stacked transformer layers

- Attention matrix of self-attention layers if `return_att=True`


Example::

# Batch of WordPiece token ids.
# Each sample was padded with zeros to the maximum length of the batch
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
# Mask of valid input ids
attention_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])

# DeBERTa model initialized with pretrained base model
bert = DeBERTa(pre_trained='base')

encoder_layers = bert(input_ids, attention_mask=attention_mask)

"""

if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
if token_type_ids is None:
@@ -113,6 +140,13 @@ class DeBERTa(torch.nn.Module):
return encoded_layers

def apply_state(self, state = None):
""" Load state from a previously loaded model state dictionary.

Args:
state (:obj:`dict`, optional): State dictionary as the state returned by torch.module.state_dict(), default: `None`. \
If it's `None`, then it will use the pre-trained state loaded via the constructor to re-initialize \
the `DeBERTa` model
"""
if self.pre_trained is None and state is None:
return
if state is None:
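As an illustration of `apply_state`, a small sketch. The checkpoint path is hypothetical and assumed to hold a state dict saved with `torch.save(model.state_dict(), ...)`.

```python
import torch
from DeBERTa import deberta

model = deberta.DeBERTa(pre_trained='base')  # downloads the released "base" weights

# With no argument, re-initialize from the pre-trained state loaded by the constructor
model.apply_state()

# Or load weights saved locally (hypothetical path)
state = torch.load('my_deberta_state.bin', map_location='cpu')
model.apply_state(state)
```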
@@ -19,6 +19,22 @@ from .ops import *
__all__=['build_relative_position', 'DisentangledSelfAttention']

def build_relative_position(query_size, key_size):
""" Build relative position according to the query and key

We assume the absolute position of query :math:`P_q` ranges from (0, query_size) and the absolute position of key :math:`P_k` ranges from (0, key_size),
The relative positions from query to key are

:math:`R_{q \\rightarrow k} = P_q - P_k`

Args:
query_size (int): the length of query
key_size (int): the length of key

Return:
:obj:`torch.LongTensor`: A tensor with shape [1, query_size, key_size]

"""

q_ids = np.arange(0, query_size)
k_ids = np.arange(0, key_size)
rel_pos_ids = q_ids[:, None] - np.tile(k_ids, (q_ids.shape[0],1))
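To make the formula concrete, a quick sketch of the same computation in plain numpy (illustrative only; the function itself returns a `torch.LongTensor` of shape [1, query_size, key_size] as documented above).

```python
import numpy as np

query_size, key_size = 3, 4
q_ids = np.arange(query_size)
k_ids = np.arange(key_size)
rel_pos = q_ids[:, None] - np.tile(k_ids, (query_size, 1))
# rel_pos[i, j] = i - j, e.g.
# [[ 0 -1 -2 -3]
#  [ 1  0 -1 -2]
#  [ 2  1  0 -1]]
```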
@@ -27,8 +43,15 @@ def build_relative_position(query_size, key_size):
rel_pos_ids = rel_pos_ids.unsqueeze(0)
return rel_pos_ids


class DisentangledSelfAttention(torch.nn.Module):
""" Disentangled self-attention module

Parameters:
config (:obj:`str`):
A model config class instance with the configuration to build a new model. The schema is similar to `BertConfig`, \
for more details, please refer :class:`~DeBERTa.deberta.ModelConfig`

"""
def __init__(self, config):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0:
@@ -69,6 +92,29 @@ class DisentangledSelfAttention(torch.nn.Module):
return x.permute(0, 2, 1, 3)

def forward(self, hidden_states, attention_mask, return_att=False, query_states=None, relative_pos=None, rel_embeddings=None):
""" Call the module

Args:
hidden_states (:obj:`torch.FloatTensor`):
Input states to the module, usually the output from the previous layer, it will be the Q, K and V in `Attention(Q,K,V)`

attention_mask (:obj:`torch.ByteTensor`):
An attention mask matrix of shape [`B`, `N`, `N`] where `B` is the batch size, `N` is the maximum sequence length in which element [i,j] = `1` means the `i` th token in the input can attend to the `j` th token.

return_att (:obj:`bool`, optional):
Whether to return the attention matrix.

query_states (:obj:`torch.FloatTensor`, optional):
The `Q` state in `Attention(Q,K,V)`.

relative_pos (:obj:`torch.LongTensor`):
The relative position encoding between the tokens in the sequence. It's of shape [`B`, `N`, `N`] with values ranging in [`-max_relative_positions`, `max_relative_positions`].

rel_embeddings (:obj:`torch.FloatTensor`):
The embedding of relative distances. It's a tensor of shape [:math:`2 \\times \\text{max_relative_positions}`, `hidden_size`].


"""
if query_states is None:
qp = self.in_proj(hidden_states) #.split(self.all_head_size, dim=-1)
query_layer,key_layer,value_layer = self.transpose_for_scores(qp).chunk(3, dim=-1)
@@ -19,6 +19,32 @@ from .cache_utils import load_vocab
__all__ = ['GPT2Tokenizer']

class GPT2Tokenizer(object):
""" A wrapper of the GPT2 tokenizer with a similar interface to the BERT tokenizer

Args:

vocab_file (:obj:`str`, optional):
The local path of the vocabulary package or the release name of the vocabulary in `DeBERTa GitHub releases <https://github.com/microsoft/DeBERTa/releases>`_, \
e.g. "bpe_encoder", default: `None`.

If it's `None`, then it will download the vocabulary in the latest release from GitHub. The vocabulary file is a \
state dictionary with three items, "dict_map", "vocab", "encoder" which correspond to three files used in `RoBERTa`, i.e. `dict.txt`, `vocab.txt` and `encoder.json`. \

The differences between our wrapped GPT2 tokenizer and the RoBERTa wrapped tokenizer are:

- Special tokens: unlike `RoBERTa`, which uses `<s>`, `</s>` as the `start` token and `end` token of a sentence, we use `[CLS]` and `[SEP]` as the `start` and `end`\
token of the input sentence, which is the same as `BERT`.

- We remapped the token ids in our dictionary with regard to the new special tokens, `[PAD]` => 0, `[CLS]` => 1, `[SEP]` => 2, `[UNK]` => 3, `[MASK]` => 50264

do_lower_case (:obj:`bool`, optional):
Whether to convert inputs to lower case. **Not used in GPT2 tokenizer**.

special_tokens (:obj:`list`, optional):
List of special tokens to be added to the end of the vocabulary.


"""
def __init__(self, vocab_file=None, do_lower_case=True, special_tokens=None):
pad='[PAD]'
eos='[SEP]'
@@ -48,14 +74,54 @@ class GPT2Tokenizer(object):
self.ids_to_tokens = self.symbols

def tokenize(self, text):
""" Convert an input text to tokens.

Args:

text (:obj:`str`): input text to be tokenized.

Returns:
A list of byte tokens where each token represents the byte id in the GPT2 byte dictionary

Example::

>>> tokenizer = GPT2Tokenizer()
>>> text = "Hello world!"
>>> tokens = tokenizer.tokenize(text)
>>> print(tokens)
['15496', '995', '0']

"""
bpe = self._encode(text)

return [t for t in bpe.split(' ') if t]

def convert_tokens_to_ids(self, tokens):
""" Convert a list of tokens to ids.

Args:

tokens (:obj:`list<str>`): list of tokens

Returns:

List of ids
"""

return [self.vocab[t] for t in tokens]

def convert_ids_to_tokens(self, ids):
""" Convert a list of ids to tokens.

Args:

ids (:obj:`list<int>`): list of ids

Returns:

List of tokens
"""

tokens = []
for i in ids:
tokens.append(self.ids_to_tokens[i])
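A quick round trip through the methods documented above and the `decode` method below (a sketch; with no `vocab_file` the tokenizer downloads the released `bpe_encoder` vocabulary, as described in the class docstring).

```python
from DeBERTa import deberta

tokenizer = deberta.GPT2Tokenizer()
tokens = tokenizer.tokenize("Hello world!")        # ['15496', '995', '0']
ids = tokenizer.convert_tokens_to_ids(tokens)
tokens_back = tokenizer.convert_ids_to_tokens(ids)
print(tokenizer.decode(tokens_back))               # 'Hello world!'
```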
@@ -65,9 +131,40 @@ class GPT2Tokenizer(object):
return self.bpe.split_to_words(text)

def decode(self, tokens):
""" Decode a list of tokens to a text string.

Args:

tokens (:obj:`list<str>`): list of tokens.

Returns:

Text string corresponding to the input tokens.

Example::

>>> tokenizer = GPT2Tokenizer()
>>> text = "Hello world!"
>>> tokens = tokenizer.tokenize(text)
>>> print(tokens)
['15496', '995', '0']

>>> tokenizer.decode(tokens)
'Hello world!'

"""
return self.bpe.decode([int(t) for t in tokens if t not in self.special_tokens])

def add_special_token(self, token):
"""Adds a special token to the dictionary.

Args:
token (:obj:`str`): The new token/word to be added to the vocabulary.

Returns:
The id of the new token in the vocabulary.

"""
self.special_tokens.append(token)
return self.add_symbol(token)

@@ -93,7 +190,16 @@ class GPT2Tokenizer(object):
return self.bpe.decode(map(int, x.split()))

def add_symbol(self, word, n=1):
"""Adds a word to the dictionary"""
"""Adds a word to the dictionary.

Args:
word (:obj:`str`): The new token/word to be added to the vocabulary.
n (int, optional): The frequency of the word.

Returns:
The id of the new word.

"""
if word in self.indices:
idx = self.indices[word]
self.count[idx] = self.count[idx] + n
@@ -10,10 +10,16 @@ from .cache_utils import load_model_state
from ..utils import get_logger
logger = get_logger()

__all__ = ['NNModule']

class NNModule(nn.Module):
""" An abstract class to handle weights initialization and
""" An abstract class to handle weights initialization and \
a simple interface for downloading and loading pretrained models.

Args:

config (:obj:`~DeBERTa.deberta.ModelConfig`): The model config for the module

"""

def __init__(self, config, *inputs, **kwargs):
@@ -21,7 +27,25 @@ class NNModule(nn.Module):
self.config = config

def init_weights(self, module):
""" Initialize the weights.
""" Apply Gaussian(mean=0, std=`config.initializer_range`) initialization to the module.

Args:

module (:obj:`torch.nn.Module`): The module to apply the initialization.

Example::

class MyModule(NNModule):
def __init__(self, config):
# Add construction instructions
self.bert = DeBERTa(config)

# Add other modules
...

# Apply initialization
self.apply(self.init_weights)

"""
if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
@@ -29,18 +53,46 @@ class NNModule(nn.Module):
module.bias.data.zero_()

@classmethod
def load_model(cls, model_path, bert_config=None, tag=None, no_cache=False, cache_dir=None , *inputs, **kwargs):
"""
Instantiate a NNModule from a pre-trained model file.
def load_model(cls, model_path, model_config=None, tag=None, no_cache=False, cache_dir=None , *inputs, **kwargs):
""" Instantiate a sub-class of NNModule from a pre-trained model file.

Args:

model_path (:obj:`str`): Path or name of the pre-trained model which can be either:

- The path of a pre-trained model

- The pre-trained DeBERTa model name in `DeBERTa GitHub releases <https://github.com/microsoft/DeBERTa/releases>`_, i.e. [**base, base_mnli, large, large_mnli**].

If `model_path` is `None` or `-`, then the method will create a new sub-class without initializing from pre-trained models.

model_config (:obj:`str`): The path of the model config file. If it's `None`, then the method will try to find the config in this order:

1. ['config'] in the model state dictionary.

2. `model_config.json` beside the `model_path`.

If it fails to find a config, the method will fail.

tag (:obj:`str`, optional): The release tag of DeBERTa, default: `None`.

no_cache (:obj:`bool`, optional): Disable local cache of downloaded models, default: `False`.

cache_dir (:obj:`str`, optional): The cache directory used to save the downloaded models, default: `None`. If it's `None`, then the models will be saved at `$HOME/.~DeBERTa`

Return:

:obj:`NNModule` : The sub-class object.

"""
# Load config
if bert_config:
config = ModelConfig.from_json_file(bert_config)
if model_config:
config = ModelConfig.from_json_file(model_config)
else:
config = None
model_config = None
model_state = None
if model_path.strip() == '-' or model_path.strip()=='':
if model_path and model_path.strip() == '-' or model_path.strip()=='':
model_path = None
try:
model_state, model_config = load_model_state(model_path, tag=tag, no_cache=no_cache, cache_dir=cache_dir)
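For illustration, a hedged sketch of how a task model derived from `NNModule` might use `load_model`. The `MyClassifier` class, its constructor arguments, and the assumption that extra `**kwargs` are forwarded to the constructor are hypothetical, inferred from the docstring above rather than taken from the repository.

```python
import torch.nn as nn
from DeBERTa.deberta import NNModule, DeBERTa

class MyClassifier(NNModule):
  def __init__(self, config, num_labels=2):
    super().__init__(config)
    self.bert = DeBERTa(config)
    self.classifier = nn.Linear(config.hidden_size, num_labels)
    self.apply(self.init_weights)

# 'base' resolves to a released checkpoint; the config is assumed to come from
# the package ('config' entry in the state dict or model_config.json beside it)
model = MyClassifier.load_model('base', num_labels=2)
```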
@@ -19,8 +19,31 @@ else:
__all__ = ['StableDropout', 'MaskedLayerNorm', 'XSoftmax']

class XSoftmax(torch.autograd.Function):
""" Masked Softmax which is optimized for saving memory

Args:

input (:obj:`torch.tensor`): The input tensor that softmax will be applied to.
mask (:obj:`torch.IntTensor`): The mask matrix where 0 indicates that the element will be ignored in the softmax calculation.
dim (int): The dimension along which softmax will be applied.

Example::

import torch
from DeBERTa.deberta import XSoftmax
# Make a tensor
x = torch.randn([4,20,100])
# Create a mask
mask = (x>0).int()
y = XSoftmax.apply(x, mask, dim=-1)

"""

@staticmethod
def forward(self, input, mask, dim):
"""
"""

self.dim = dim
if version.Version(torch.__version__) >= version.Version('1.2.0a'):
rmask = (1-mask).bool()
@@ -35,6 +58,9 @@ class XSoftmax(torch.autograd.Function):

@staticmethod
def backward(self, grad_output):
"""
"""

output, = self.saved_tensors
inputGrad = _softmax_backward_data(grad_output, output, self.dim, output)
return inputGrad, None, None
@@ -88,6 +114,14 @@ class XDropout(torch.autograd.Function):
return mask, dropout

class StableDropout(torch.nn.Module):
""" Optimized dropout module for stabilizing the training

Args:

drop_prob (float): the dropout probability

"""

def __init__(self, drop_prob):
super().__init__()
self.drop_prob = drop_prob
@@ -95,6 +129,14 @@ class StableDropout(torch.nn.Module):
self.context_stack = None

def forward(self, x):
""" Call the module

Args:

x (:obj:`torch.tensor`): The input tensor to apply dropout


"""
if self.training and self.drop_prob>0:
return XDropout.apply(x, self.get_context())
return x
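A small usage sketch of the module documented above (illustrative only): in training mode a dropout mask is applied through `XDropout`, while in eval mode, or with `drop_prob=0`, the input is returned unchanged.

```python
import torch
from DeBERTa.deberta import StableDropout

drop = StableDropout(0.1)
x = torch.randn(2, 8, 768)

drop.train()
y_train = drop(x)   # dropout applied via XDropout

drop.eval()
y_eval = drop(x)    # returns x unchanged
assert torch.equal(y_eval, x)
```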
@@ -123,6 +165,22 @@ class StableDropout(torch.nn.Module):
return self.drop_prob

def MaskedLayerNorm(layerNorm, input, mask = None):
""" Masked LayerNorm which will apply the mask over the output of LayerNorm to avoid inaccurate updates to the LayerNorm module.

Args:
layernorm (:obj:`~DeBERTa.deberta.BertLayerNorm`): LayerNorm module or function
input (:obj:`torch.tensor`): The input tensor
mask (:obj:`torch.IntTensor`): The mask to be applied to the output of LayerNorm, where `0` indicates the output of that element will be ignored, i.e. set to `0`

Example::

# Create a tensor b x n x d
x = torch.randn([1,10,100])
m = torch.tensor([[1,1,1,0,0,0,0,0,0,0]], dtype=torch.int)
LayerNorm = DeBERTa.deberta.BertLayerNorm(100)
y = MaskedLayerNorm(LayerNorm, x, m)

"""
output = layerNorm(input).to(input)
if mask is None:
return output
@@ -12,17 +12,49 @@ import json
from .bert import ACT2FN
from .ops import StableDropout

__all__ = ['PoolConfig', 'ContextPooler']

class PoolConfig(object):
"""Configuration class to store the configuration of `attention pool layer`.
"""Configuration class to store the configuration of the `pool layer`.

Parameters:

config (:class:`~DeBERTa.deberta.ModelConfig`): The model config. The fields of the pool config will be initialized with the `pooling` field in the model config.

Attributes:

hidden_size (int): Size of the encoder layers and the pooler layer, default: `768`.

dropout (float): The dropout rate applied on the output of the `[CLS]` token.

hidden_act (:obj:`str`): The activation function of the projection layer, it can be one of ['gelu', 'tanh'].

Example::

# Here is the content of an example model config file in json format

{
"hidden_size": 768,
"num_hidden_layers": 12,
"num_attention_heads": 12,
"intermediate_size": 3072,
...
"pooling": {
"hidden_size": 768,
"hidden_act": "gelu",
"dropout": 0.1
}
}

"""
def __init__(self, model_config):
def __init__(self, config):
"""Constructs PoolConfig.

Params:
`model_config`: the config of the model. The fields of the pool config will be initialized with the 'pooling' field in the model config.
Args:
`config`: the config of the model. The fields of the pool config will be initialized with the 'pooling' field in the model config.
"""
pool_config = getattr(model_config, 'pooling', model_config)
self.hidden_size = getattr(pool_config, 'hidden_size', model_config.hidden_size)
pool_config = getattr(config, 'pooling', config)
self.hidden_size = getattr(pool_config, 'hidden_size', config.hidden_size)
self.dropout = getattr(pool_config, 'dropout', 0)
self.hidden_act = getattr(pool_config, 'hidden_act', 'gelu')

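A short sketch of how the constructor above behaves (the config path is hypothetical; when the model config has no `pooling` section, the `getattr` calls fall back to the model-level fields and the defaults shown above).

```python
from DeBERTa.deberta import ModelConfig, PoolConfig

config = ModelConfig.from_json_file('model_config.json')  # hypothetical path
pool_config = PoolConfig(config)
print(pool_config.hidden_size, pool_config.dropout, pool_config.hidden_act)
```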
README.md (19 changes)
@@ -20,6 +20,8 @@ Our pre-trained models are packaged into zipped files. You can download them fro

# Try the code

Read our [documentation](https://deberta.readthedocs.io/en/latest/)

## Requirements
- Linux system, e.g. Ubuntu 18.04LTS
- CUDA 10.0
@@ -77,6 +79,23 @@ class MyModel(torch.nn.Module):
# 2. Change your tokenizer with the tokenizer built into DeBERTa
from DeBERTa import deberta
tokenizer = deberta.GPT2Tokenizer()
# We apply the same schema of special tokens as BERT, e.g. [CLS], [SEP], [MASK]
max_seq_len = 512
tokens = tokenizer.tokenize('Examples input text of DeBERTa')
# Truncate long sequence (leave room for [CLS] and [SEP])
tokens = tokens[:max_seq_len - 2]
# Add special tokens to the `tokens`
tokens = ['[CLS]'] + tokens + ['[SEP]']
input_ids = tokenizer.convert_tokens_to_ids(tokens)
input_mask = [1]*len(input_ids)
# padding
paddings = max_seq_len-len(input_ids)
input_ids = input_ids + [0]*paddings
input_mask = input_mask + [0]*paddings
features = {
'input_ids': torch.tensor(input_ids, dtype=torch.int),
'input_mask': torch.tensor(input_mask, dtype=torch.int)
}

```

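A possible continuation of the README snippet above, feeding the prepared features into the DeBERTa encoder. This is a sketch based on the `DeBERTa` docstring earlier in this commit; the batch dimension is added with `unsqueeze(0)` and the ids are cast to long because the docstring expects a `torch.LongTensor`.

```python
from DeBERTa import deberta

model = deberta.DeBERTa(pre_trained='base')
model.eval()
input_ids = features['input_ids'].unsqueeze(0).long()
input_mask = features['input_mask'].unsqueeze(0).long()
encoder_layers = model(input_ids, attention_mask=input_mask)
```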
VERSION (2 changes)
@@ -1 +1 @@
0.1.3
0.1.4
@@ -0,0 +1,20 @@
# Minimal makefile for Sphinx documentation
#

# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = source
BUILDDIR = build

# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
@@ -0,0 +1,35 @@
@ECHO OFF

pushd %~dp0

REM Command file for Sphinx documentation

if "%SPHINXBUILD%" == "" (
set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=source
set BUILDDIR=build

if "%1" == "" goto help

%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
echo.
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
echo.installed, then set the SPHINXBUILD environment variable to point
echo.to the full path of the 'sphinx-build' executable. Alternatively you
echo.may add the Sphinx directory to PATH.
echo.
echo.If you don't have Sphinx installed, grab it from
echo.http://sphinx-doc.org/
exit /b 1
)

%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end

:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%

:end
popd
@@ -0,0 +1,17 @@
recommonmark
sphinx
sphinx-markdown-tables
sphinx-rtd-theme
nltk
spacy
numpy
pytest
regex
scipy
sklearn
torch==1.3.0
torchvision==0.3.0
tqdm
ujson
seqeval
psutil
@@ -0,0 +1,184 @@
# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
# list see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html

# -- Path setup --------------------------------------------------------------

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import os
import sys
src_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../'))
sys.path.insert(0, src_dir)


# -- Project information -----------------------------------------------------

project = u'DeBERTa'
copyright = u'2020, Microsoft'
author = u'Pengcheng He'


# The short X.Y version
version=u''
# The full version, including alpha/beta/rc tags
with open(os.path.join(src_dir, 'VERSION'), encoding='utf-8') as fs:
ver = fs.readline().strip()
release = ver

# -- General configuration ---------------------------------------------------

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.coverage',
'sphinx.ext.napoleon',
'recommonmark',
'sphinx.ext.viewcode',
'sphinx_markdown_tables'
]

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']

# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
source_suffix = ['.rst', '.md']
# source_suffix = '.rst'

# The master toctree document.
master_doc = 'index'

# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
language = None

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = [u'_build', 'Thumbs.db', '.DS_Store']

# The name of the Pygments (syntax highlighting) style to use.
pygments_style = None


# -- Options for HTML output -------------------------------------------------

# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'sphinx_rtd_theme'

# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
#
html_theme_options = {
'analytics_id': 'UA-83738774-2'
}

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']

# Custom sidebar templates, must be a dictionary that maps document names
# to template names.
#
# The default sidebars (for documents that don't match any pattern) are
# defined by theme itself. Builtin themes are using these templates by
# default: ``['localtoc.html', 'relations.html', 'sourcelink.html',
# 'searchbox.html']``.
#
# html_sidebars = {}

# This must be the name of an image file (path relative to the configuration
# directory) that is the favicon of the docs. Modern browsers use this as
# the icon for tabs, windows and bookmarks. It should be a Windows-style
# icon file (.ico).
html_favicon = 'favicon.ico'


# -- Options for HTMLHelp output ---------------------------------------------

# Output file base name for HTML help builder.
htmlhelp_basename = 'debertadoc'


# -- Options for LaTeX output ------------------------------------------------

latex_elements = {
# The paper size ('letterpaper' or 'a4paper').
#
# 'papersize': 'letterpaper',

# The font size ('10pt', '11pt' or '12pt').
#
# 'pointsize': '10pt',

# Additional stuff for the LaTeX preamble.
#
# 'preamble': '',

# Latex figure (float) alignment
#
# 'figure_align': 'htbp',
}

# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
(master_doc, 'deberta.tex', u'DeBERTa Documentation',
u'Microsoft', 'manual'),
]


# -- Options for manual page output ------------------------------------------

# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
(master_doc, 'deberta', u'DeBERTa Documentation',
[author], 1)
]


# -- Options for Texinfo output ----------------------------------------------

# Grouping the document tree into Texinfo files. List of tuples
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
(master_doc, 'deberta', u'DeBERTa Documentation',
author, 'deberta', 'One line description of project.',
'Miscellaneous'),
]


# -- Options for Epub output -------------------------------------------------

# Bibliographic Dublin Core info.
epub_title = project

# The unique identifier of the text. This can be a ISBN number
# or the project homepage.
#
# epub_identifier = ''

# A unique identification for the text.
#
# epub_uid = ''

# A list of files that should not be packed into the epub file.
epub_exclude_files = ['search.html']
@@ -0,0 +1,20 @@
.. DeBERTa documentation master file, created by
   sphinx-quickstart on Wed Jun 17 19:34:55 2020.
   You can adapt this file completely to your liking, but it should at least
   contain the root `toctree` directive.

Welcome to DeBERTa's documentation!
===================================

.. toctree::
   :maxdepth: 2
   :caption: Contents:

   modules/deberta

Indices and tables
==================

* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`
@@ -0,0 +1,76 @@
DeBERTa
------------------------------

DeBERTa Model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: DeBERTa.deberta.DeBERTa
   :members:

NNModule
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: DeBERTa.deberta.NNModule
   :members:

DisentangledSelfAttention
~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: DeBERTa.deberta.DisentangledSelfAttention
   :members:

.. autofunction:: DeBERTa.deberta.build_relative_position

ContextPooler
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: DeBERTa.deberta.ContextPooler
   :members:

BertEncoder
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: DeBERTa.deberta.BertEncoder
   :members:

BertLayerNorm
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: DeBERTa.deberta.BertLayerNorm
   :members:

XSoftmax
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: DeBERTa.deberta.XSoftmax

.. :members:

StableDropout
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: DeBERTa.deberta.StableDropout
   :members:

MaskedLayerNorm
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: DeBERTa.deberta.MaskedLayerNorm

GPT2Tokenizer
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: DeBERTa.deberta.GPT2Tokenizer
   :members:

ModelConfig
~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: DeBERTa.deberta.ModelConfig
   :members:

PoolConfig
~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: DeBERTa.deberta.PoolConfig
   :members:
@@ -111,7 +111,7 @@ fi
parameters="--task_name $Task $CMD \
--data_dir $Data \
--init_model $MODEL \
--bert_config $OUTPUT/model_config.json \
--model_config $OUTPUT/model_config.json \
--max_seq_length 512 \
--eval_batch_size 128 \
--predict_batch_size 128 \
setup.py (6 changes)
@@ -16,12 +16,17 @@ with open('VERSION') as fs:
with open('requirements.txt') as fs:
requirements = [l.strip() for l in fs]

extras = {}
extras["docs"] = ["recommonmark", "sphinx", "sphinx-markdown-tables", "sphinx-rtd-theme"]

setuptools.setup(
name="DeBERTa",
version=version,
author="penhe",
author_email="penhe@microsoft.com",
description="Decoding enhanced BERT with Disentangled Attention",
keywords="NLP deep learning transformer pytorch Attention BERT RoBERTa DeBERTa",
license="MIT",
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/microsoft/DeBERTa",
@@ -33,4 +38,5 @@ setuptools.setup(
"Operating System :: OS Independent",
],
python_requires='>=3.6',
extras_require=extras,
install_requires=requirements)