diff --git a/DeBERTa/apps/train.py b/DeBERTa/apps/train.py index dffe31e..e9218d1 100644 --- a/DeBERTa/apps/train.py +++ b/DeBERTa/apps/train.py @@ -32,7 +32,7 @@ def create_model(args, num_labels, model_class_fn): # Prepare model rank = getattr(args, 'rank', 0) init_model = args.init_model if rank<1 else None - model = model_class_fn(init_model, args.bert_config, num_labels=num_labels, \ + model = model_class_fn(init_model, args.model_config, num_labels=num_labels, \ drop_out=args.cls_drop_out, \ pre_trained = args.pre_trained) if args.fp16: @@ -379,7 +379,7 @@ def build_argument_parser(): type=str, help="The model state file used to initialize the model weights.") - parser.add_argument('--bert_config', + parser.add_argument('--model_config', type=str, help="The config file of bert model.") diff --git a/DeBERTa/deberta/__init__.py b/DeBERTa/deberta/__init__.py index 9c682f2..87d22dd 100644 --- a/DeBERTa/deberta/__init__.py +++ b/DeBERTa/deberta/__init__.py @@ -18,3 +18,4 @@ from .disentangled_attention import * from .ops import * from .bert import * from .gpt2_tokenizer import GPT2Tokenizer +from .config import * diff --git a/DeBERTa/deberta/bert.py b/DeBERTa/deberta/bert.py index 1dedf57..69b0c91 100644 --- a/DeBERTa/deberta/bert.py +++ b/DeBERTa/deberta/bert.py @@ -39,9 +39,10 @@ def linear_act(x): ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish, "tanh": torch.nn.functional.tanh, "linear": linear_act, 'sigmoid': torch.sigmoid} class BertLayerNorm(nn.Module): + """LayerNorm module in the TF style (epsilon inside the square root). + """ + def __init__(self, size, eps=1e-12): - """Construct a layernorm module in the TF style (epsilon inside the square root). - """ super().__init__() self.weight = nn.Parameter(torch.ones(size)) self.bias = nn.Parameter(torch.zeros(size)) @@ -139,6 +140,8 @@ class BertLayer(nn.Module): return layer_output class BertEncoder(nn.Module): + """ Modified BertEncoder with relative position bias support + """ def __init__(self, config): super().__init__() layer = BertLayer(config) diff --git a/DeBERTa/deberta/config.py b/DeBERTa/deberta/config.py index 7b82d8f..11f23aa 100644 --- a/DeBERTa/deberta/config.py +++ b/DeBERTa/deberta/config.py @@ -40,33 +40,41 @@ class AbsModelConfig(object): return json.dumps(self.__dict__, indent=2, sort_keys=True, default=_json_default) + "\n" class ModelConfig(AbsModelConfig): - """Configuration class to store the configuration of a `BertModel`. + """Configuration class to store the configuration of a :class:`~DeBERTa.deberta.DeBERTa` model. + + Attributes: + hidden_size (int): Size of the encoder layers and the pooler layer, default: `768`. + num_hidden_layers (int): Number of hidden layers in the Transformer encoder, default: `12`. + num_attention_heads (int): Number of attention heads for each attention layer in + the Transformer encoder, default: `12`. + intermediate_size (int): The size of the "intermediate" (i.e., feed-forward) + layer in the Transformer encoder, default: `3072`. + hidden_act (str): The non-linear activation function (function or string) in the + encoder and pooler. If string, "gelu", "relu" and "swish" are supported, default: `gelu`. + hidden_dropout_prob (float): The dropout probabilitiy for all fully connected + layers in the embeddings, encoder, and pooler, default: `0.1`. + attention_probs_dropout_prob (float): The dropout ratio for the attention + probabilities, default: `0.1`. + max_position_embeddings (int): The maximum sequence length that this model might + ever be used with. 
Typically set this to something large just in case + (e.g., 512 or 1024 or 2048), default: `512`. + type_vocab_size (int): The vocabulary size of the `token_type_ids` passed into + `DeBERTa` model, default: `-1`. + initializer_range (int): The sttdev of the _normal_initializer for + initializing all weight matrices, default: `0.02`. + relative_attention (:obj:`bool`): Whether use relative position encoding, default: `False`. + max_relative_positions (int): The range of relative positions [`-max_position_embeddings`, `max_position_embeddings`], default: -1, use the same value as `max_position_embeddings`. + padding_idx (int): The value used to pad input_ids, default: `0`. + position_biased_input (:obj:`bool`): Whether add absolute position embedding to content embedding, default: `True`. + pos_att_type (:obj:`str`): The type of relative position attention, it can be a combination of [`p2c`, `c2p`, `p2p`], e.g. "p2c", "p2c|c2p", "p2c|c2p|p2p"., default: "None". + + """ def __init__(self): """Constructs ModelConfig. - Args: - vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`. - hidden_size: Size of the encoder layers and the pooler layer. - num_hidden_layers: Number of hidden layers in the Transformer encoder. - num_attention_heads: Number of attention heads for each attention layer in - the Transformer encoder. - intermediate_size: The size of the "intermediate" (i.e., feed-forward) - layer in the Transformer encoder. - hidden_act: The non-linear activation function (function or string) in the - encoder and pooler. If string, "gelu", "relu" and "swish" are supported. - hidden_dropout_prob: The dropout probabilitiy for all fully connected - layers in the embeddings, encoder, and pooler. - attention_probs_dropout_prob: The dropout ratio for the attention - probabilities. - max_position_embeddings: The maximum sequence length that this model might - ever be used with. Typically set this to something large just in case - (e.g., 512 or 1024 or 2048). - type_vocab_size: The vocabulary size of the `token_type_ids` passed into - `BertModel`. - initializer_range: The sttdev of the truncated_normal_initializer for - initializing all weight matrices. """ + self.hidden_size = 768 self.num_hidden_layers = 12 self.num_attention_heads = 12 diff --git a/DeBERTa/deberta/deberta.py b/DeBERTa/deberta/deberta.py index d6b0848..3b9491f 100644 --- a/DeBERTa/deberta/deberta.py +++ b/DeBERTa/deberta/deberta.py @@ -23,38 +23,17 @@ class DeBERTa(torch.nn.Module): """ DeBERTa encoder This module is composed of the input embedding layer with stacked transformer layers with disentangled attention. - Params: - `config`: A model config class instance with the configuration to build a new model. The schema is similar to BertConfig, for more details, please refer `config.py` - `pre_trained`: The pre-trained DeBERTa model, it can be a physical path of a pre-trained DeBERTa model or a released configuration, e.g. base, large, base_mnli, large_mnli + Parameters: + config: + A model config class instance with the configuration to build a new model. The schema is similar to `BertConfig`, \ + for more details, please refer :class:`~DeBERTa.deberta.ModelConfig` - Inputs: - `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] - with the word token indices in the vocabulary - `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token - types indices selected in [0, 1]. 
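
The `config` and `pre_trained` parameters documented above give two ways to build the encoder. A minimal sketch of both follows; the `model_config.json` path is a hypothetical local file and the field overrides are purely illustrative, not recommended settings:

```python
from DeBERTa import deberta

# Load a model configuration from JSON (the same helper NNModule.load_model uses);
# 'model_config.json' is a hypothetical local path.
config = deberta.ModelConfig.from_json_file('model_config.json')
config.relative_attention = True   # fields can still be adjusted before building the model
config.pos_att_type = 'p2c|c2p'    # content-to-position and position-to-content attention
model = deberta.DeBERTa(config=config)

# Or skip the config and initialize from a released checkpoint, as in the original example:
model = deberta.DeBERTa(pre_trained='base')   # one of base, large, base_mnli, large_mnli
```
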
Type 0 corresponds to a `sentence A` and type 1 corresponds to - a `sentence B` token (see BERT paper for more details). - `attention_mask`: an optional parameter for input mask or attention mask. - - If it's an input mask, then it will be torch.LongTensor of shape [batch_size, sequence_length] with indices - selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max - input sequence length in the current batch. It's the mask that we typically use for attention when - a batch has varying length sentences. - - If it's an attention mask then if will be torch.LongTensor of shape [batch_size, sequence_length, sequence_length]. In this case, it's a mask indicate which tokens in the sequence should be attended by other tokens in the sequence. - `output_all_encoded_layers`: whether to output results of all encoder layers, default, True - - Outputs: - The output of the stacked transformer layers if `output_all_encoded_layers=True`, else - the last layer of stacked transformer layers - - Example usage: - ```python - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) - attention_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) - bert = DeBERTa(pre_trained='base') - encoder_layers = bert(input_ids, attention_mask=attention_mask) - ``` + pre_trained: + The pre-trained DeBERTa model, it can be a physical path of a pre-trained DeBERTa model or a released configurations, \ + i.e. [**base, large, base_mnli, large_mnli**] """ + def __init__(self, config=None, pre_trained=None): super().__init__() if config: @@ -82,6 +61,54 @@ class DeBERTa(torch.nn.Module): self.apply_state(state) def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True, position_ids = None, return_att = False): + """ + Args: + input_ids: + a torch.LongTensor of shape [batch_size, sequence_length] \ + with the word token indices in the vocabulary + + token_type_ids: + an optional torch.LongTensor of shape [batch_size, sequence_length] with the token \ + types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to \ + a `sentence B` token (see BERT paper for more details). + + attention_mask: + an optional parameter for input mask or attention mask. + + - If it's an input mask, then it will be torch.LongTensor of shape [batch_size, sequence_length] with indices \ + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max \ + input sequence length in the current batch. It's the mask that we typically use for attention when \ + a batch has varying length sentences. + + - If it's an attention mask then it will be torch.LongTensor of shape [batch_size, sequence_length, sequence_length]. \ + In this case, it's a mask indicate which tokens in the sequence should be attended by other tokens in the sequence. + + output_all_encoded_layers: + whether to output results of all encoder layers, default, True + + Returns: + + - The output of the stacked transformer layers if `output_all_encoded_layers=True`, else \ + the last layer of stacked transformer layers + + - Attention matrix of self-attention layers if `return_att=True` + + + Example:: + + # Batch of wordPiece token ids. 
+ # Each sample was padded with zero to the maxium length of the batch + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + # Mask of valid input ids + attention_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) + + # DeBERTa model initialized with pretrained base model + bert = DeBERTa(pre_trained='base') + + encoder_layers = bert(input_ids, attention_mask=attention_mask) + + """ + if attention_mask is None: attention_mask = torch.ones_like(input_ids) if token_type_ids is None: @@ -113,6 +140,13 @@ class DeBERTa(torch.nn.Module): return encoded_layers def apply_state(self, state = None): + """ Load state from previous loaded model state dictionary. + + Args: + state (:obj:`dict`, optional): State dictionary as the state returned by torch.module.state_dict(), default: `None`. \ + If it's `None`, then will use the pre-trained state loaded via the constructor to re-initialize \ + the `DeBERTa` model + """ if self.pre_trained is None and state is None: return if state is None: diff --git a/DeBERTa/deberta/disentangled_attention.py b/DeBERTa/deberta/disentangled_attention.py index 457f47e..24e323b 100644 --- a/DeBERTa/deberta/disentangled_attention.py +++ b/DeBERTa/deberta/disentangled_attention.py @@ -19,6 +19,22 @@ from .ops import * __all__=['build_relative_position', 'DisentangledSelfAttention'] def build_relative_position(query_size, key_size): + """ Build relative position according to the query and key + + We assume the absolute position of query :math:`P_q` is range from (0, query_size) and the absolute position of key :math:`P_k` is range from (0, key_size), + The relative positions from query to key is + + :math:`R_{q \\rightarrow k} = P_q - P_k` + + Args: + query_size (int): the length of query + key_size (int): the length of key + + Return: + :obj:`torch.LongTensor`: A tensor with shape [1, query_size, key_size] + + """ + q_ids = np.arange(0, query_size) k_ids = np.arange(0, key_size) rel_pos_ids = q_ids[:, None] - np.tile(k_ids, (q_ids.shape[0],1)) @@ -27,8 +43,15 @@ def build_relative_position(query_size, key_size): rel_pos_ids = rel_pos_ids.unsqueeze(0) return rel_pos_ids - class DisentangledSelfAttention(torch.nn.Module): + """ Disentangled self-attention module + + Parameters: + config (:obj:`str`): + A model config class instance with the configuration to build a new model. The schema is similar to `BertConfig`, \ + for more details, please refer :class:`~DeBERTa.deberta.ModelConfig` + + """ def __init__(self, config): super().__init__() if config.hidden_size % config.num_attention_heads != 0: @@ -69,6 +92,29 @@ class DisentangledSelfAttention(torch.nn.Module): return x.permute(0, 2, 1, 3) def forward(self, hidden_states, attention_mask, return_att=False, query_states=None, relative_pos=None, rel_embeddings=None): + """ Call the module + + Args: + hidden_states (:obj:`torch.FloatTensor`): + Input states to the module usally the output from previous layer, it will be the Q,K and V in `Attention(Q,K,V)` + + attention_mask (:obj:`torch.ByteTensor`): + An attention mask matrix of shape [`B`, `N`, `N`] where `B` is the batch size, `N` is the maxium sequence length in which element [i,j] = `1` means the `i` th token in the input can attend to the `j` th token. + + return_att (:obj:`bool`, optional): + Whether return the attention maxitrix. + + query_states (:obj:`torch.FloatTensor`, optional): + The `Q` state in `Attention(Q,K,V)`. + + relative_pos (:obj:`torch.LongTensor`): + The relative position encoding between the tokens in the sequence. 
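
Such encodings are produced by `build_relative_position` above. To make :math:`R_{q \rightarrow k} = P_q - P_k` concrete, the short sketch below calls it for a 3-token query and a 4-token key (sizes chosen arbitrarily); the commented output follows directly from the formula and the documented return shape:

```python
from DeBERTa import deberta

rel_pos = deberta.build_relative_position(query_size=3, key_size=4)
print(rel_pos.shape)   # torch.Size([1, 3, 4])
print(rel_pos[0])
# Element [i, j] equals P_q - P_k = i - j:
# tensor([[ 0, -1, -2, -3],
#         [ 1,  0, -1, -2],
#         [ 2,  1,  0, -1]])
```

The `relative_pos` argument passed to `forward` carries offsets of this form for the whole batch.
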
It's of shape [`B`, `N`, `N`] with values ranging in [`-max_relative_positions`, `max_relative_positions`]. + + rel_embeddings (:obj:`torch.FloatTensor`): + The embedding of relative distances. It's a tensor of shape [:math:`2 \\times \\text{max_relative_positions}`, `hidden_size`]. + + + """ if query_states is None: qp = self.in_proj(hidden_states) #.split(self.all_head_size, dim=-1) query_layer,key_layer,value_layer = self.transpose_for_scores(qp).chunk(3, dim=-1) diff --git a/DeBERTa/deberta/gpt2_tokenizer.py b/DeBERTa/deberta/gpt2_tokenizer.py index feaa701..d2c6fad 100644 --- a/DeBERTa/deberta/gpt2_tokenizer.py +++ b/DeBERTa/deberta/gpt2_tokenizer.py @@ -19,6 +19,32 @@ from .cache_utils import load_vocab __all__ = ['GPT2Tokenizer'] class GPT2Tokenizer(object): + """ A wrapper of GPT2 tokenizer with similar interface as BERT tokenizer + + Args: + + vocab_file (:obj:`str`, optional): + The local path of vocabulary package or the release name of vocabulary in `DeBERTa GitHub releases `_, \ + e.g. "bpe_encoder", default: `None`. + + If it's `None`, then it will download the vocabulary in the latest release from GitHub. The vocabulary file is a \ + state dictionary with three items, "dict_map", "vocab", "encoder" which correspond to three files used in `RoBERTa`, i.e. `dict.txt`, `vocab.txt` and `encoder.json`. \ + + The difference between our wrapped GPT2 tokenizer and RoBERTa wrapped tokenizer are, + + - Special tokens, unlike `RoBERTa` which use ``, `` as the `start` token and `end` token of a sentence. We use `[CLS]` and `[SEP]` as the `start` and `end`\ + token of input sentence which is the same as `BERT`. + + - We remapped the token ids in our dictionary with regarding to the new special tokens, `[PAD]` => 0, `[CLS]` => 1, `[SEP]` => 2, `[UNK]` => 3, `[MASK]` => 50264 + + do_lower_case (:obj:`bool`, optional): + Whether to convert inputs to lower case. **Not used in GPT2 tokenizer**. + + special_tokens (:obj:`list`, optional): + List of special tokens to be added to the end of the vocabulary. + + + """ def __init__(self, vocab_file=None, do_lower_case=True, special_tokens=None): pad='[PAD]' eos='[SEP]' @@ -48,14 +74,54 @@ class GPT2Tokenizer(object): self.ids_to_tokens = self.symbols def tokenize(self, text): + """ Convert an input text to tokens. + + Args: + + text (:obj:`str`): input text to be tokenized. + + Returns: + A list of byte tokens where each token represent the byte id in GPT2 byte dictionary + + Example:: + + >>> tokenizer = GPT2Tokenizer() + >>> text = "Hello world!" + >>> tokens = tokenizer.tokenize(text) + >>> print(tokens) + ['15496', '995', '0'] + + """ bpe = self._encode(text) return [t for t in bpe.split(' ') if t] def convert_tokens_to_ids(self, tokens): + """ Convert list of tokens to ids. + + Args: + + tokens (:obj:`list`): list of tokens + + Returns: + + List of ids + """ + return [self.vocab[t] for t in tokens] def convert_ids_to_tokens(self, ids): + """ Convert list of ids to tokens. + + Args: + + ids (:obj:`list`): list of ids + + Returns: + + List of tokens + """ + tokens = [] for i in ids: tokens.append(self.ids_to_tokens[i]) @@ -65,9 +131,40 @@ class GPT2Tokenizer(object): return self.bpe.split_to_words(text) def decode(self, tokens): + """ Decode list of tokens to text strings. + + Args: + + tokens (:obj:`list`): list of tokens. + + Returns: + + Text string corresponds to the input tokens. + + Example:: + + >>> tokenizer = GPT2Tokenizer() + >>> text = "Hello world!" 
+ >>> tokens = tokenizer.tokenize(text) + >>> print(tokens) + ['15496', '995', '0'] + + >>> tokenizer.decode(tokens) + 'Hello world!' + + """ return self.bpe.decode([int(t) for t in tokens if t not in self.special_tokens]) def add_special_token(self, token): + """Adds a special token to the dictionary. + + Args: + token (:obj:`str`): Tthe new token/word to be added to the vocabulary. + + Returns: + The id of new token in the vocabulary. + + """ self.special_tokens.append(token) return self.add_symbol(token) @@ -93,7 +190,16 @@ class GPT2Tokenizer(object): return self.bpe.decode(map(int, x.split())) def add_symbol(self, word, n=1): - """Adds a word to the dictionary""" + """Adds a word to the dictionary. + + Args: + word (:obj:`str`): Tthe new token/word to be added to the vocabulary. + n (int, optional): The frequency of the word. + + Returns: + The id of the new word. + + """ if word in self.indices: idx = self.indices[word] self.count[idx] = self.count[idx] + n diff --git a/DeBERTa/deberta/nnmodule.py b/DeBERTa/deberta/nnmodule.py index 4a3570e..09b33b0 100644 --- a/DeBERTa/deberta/nnmodule.py +++ b/DeBERTa/deberta/nnmodule.py @@ -10,10 +10,16 @@ from .cache_utils import load_model_state from ..utils import get_logger logger = get_logger() +__all__ = ['NNModule'] class NNModule(nn.Module): - """ An abstract class to handle weights initialization and + """ An abstract class to handle weights initialization and \ a simple interface for dowloading and loading pretrained models. + + Args: + + config (:obj:`~DeBERTa.deberta.ModelConfig`): The model config to the module + """ def __init__(self, config, *inputs, **kwargs): @@ -21,7 +27,25 @@ class NNModule(nn.Module): self.config = config def init_weights(self, module): - """ Initialize the weights. + """ Apply Gaussian(mean=0, std=`config.initializer_range`) initialization to the module. + + Args: + + module (:obj:`torch.nn.Module`): The module to apply the initialization. + + Example:: + + class MyModule(NNModule): + def __init__(self, config): + # Add construction instructions + self.bert = DeBERTa(config) + + # Add other modules + ... + + # Apply initialization + self.apply(self.init_weights) + """ if isinstance(module, (nn.Linear, nn.Embedding)): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) @@ -29,18 +53,46 @@ class NNModule(nn.Module): module.bias.data.zero_() @classmethod - def load_model(cls, model_path, bert_config=None, tag=None, no_cache=False, cache_dir=None , *inputs, **kwargs): - """ - Instantiate a NNModule from a pre-trained model file. + def load_model(cls, model_path, model_config=None, tag=None, no_cache=False, cache_dir=None , *inputs, **kwargs): + """ Instantiate a sub-class of NNModule from a pre-trained model file. + + Args: + + model_path (:obj:`str`): Path or name of the pre-trained model which can be either, + + - The path of pre-trained model + + - The pre-trained DeBERTa model name in `DeBERTa GitHub releases `_, i.e. [**base, base_mnli, large, large_mnli**]. + + If `model_path` is `None` or `-`, then the method will create a new sub-class without initialing from pre-trained models. + + model_config (:obj:`str`): The path of model config file. If it's `None`, then the method will try to find the the config in order: + + 1. ['config'] in the model state dictionary. + + 2. `model_config.json` aside the `model_path`. + + If it failed to find a config the method will fail. + + tag (:obj:`str`, optional): The release tag of DeBERTa, default: `None`. 
+ + no_cache (:obj:`bool`, optional): Disable local cache of downloaded models, default: `False`. + + cache_dir (:obj:`str`, optional): The cache directory used to save the downloaded models, default: `None`. If it's `None`, then the models will be saved at `$HOME/.~DeBERTa` + + Return: + + :obj:`NNModule` : The sub-class object. + """ # Load config - if bert_config: - config = ModelConfig.from_json_file(bert_config) + if model_config: + config = ModelConfig.from_json_file(model_config) else: config = None model_config = None model_state = None - if model_path.strip() == '-' or model_path.strip()=='': + if model_path and model_path.strip() == '-' or model_path.strip()=='': model_path = None try: model_state, model_config = load_model_state(model_path, tag=tag, no_cache=no_cache, cache_dir=cache_dir) diff --git a/DeBERTa/deberta/ops.py b/DeBERTa/deberta/ops.py index 4f53509..30f2ba9 100644 --- a/DeBERTa/deberta/ops.py +++ b/DeBERTa/deberta/ops.py @@ -19,8 +19,31 @@ else: __all__ = ['StableDropout', 'MaskedLayerNorm', 'XSoftmax'] class XSoftmax(torch.autograd.Function): + """ Masked Softmax which is optimized for saving memory + + Args: + + input (:obj:`torch.tensor`): The input tensor that will apply softmax. + mask (:obj:`torch.IntTensor`): The mask matrix where 0 indicate that element will be ignored in the softmax caculation. + dim (int): The dimenssion that will apply softmax. + + Example:: + + import torch + from DeBERTa.deberta import XSoftmax + # Make a tensor + x = torch.randn([4,20,100]) + # Create a mask + mask = (x>0).int() + y = XSoftmax.apply(x, mask, dim=-1) + + """ + @staticmethod def forward(self, input, mask, dim): + """ + """ + self.dim = dim if version.Version(torch.__version__) >= version.Version('1.2.0a'): rmask = (1-mask).bool() @@ -35,6 +58,9 @@ class XSoftmax(torch.autograd.Function): @staticmethod def backward(self, grad_output): + """ + """ + output, = self.saved_tensors inputGrad = _softmax_backward_data(grad_output, output, self.dim, output) return inputGrad, None, None @@ -88,6 +114,14 @@ class XDropout(torch.autograd.Function): return mask, dropout class StableDropout(torch.nn.Module): + """ Optimized dropout module for stabilizing the training + + Args: + + drop_prob (float): the dropout probabilities + + """ + def __init__(self, drop_prob): super().__init__() self.drop_prob = drop_prob @@ -95,6 +129,14 @@ class StableDropout(torch.nn.Module): self.context_stack = None def forward(self, x): + """ Call the module + + Args: + + x (:obj:`torch.tensor`): The input tensor to apply dropout + + + """ if self.training and self.drop_prob>0: return XDropout.apply(x, self.get_context()) return x @@ -123,6 +165,22 @@ class StableDropout(torch.nn.Module): return self.drop_prob def MaskedLayerNorm(layerNorm, input, mask = None): + """ Masked LayerNorm which will apply mask over the output of LayerNorm to avoid inaccurate updatings to the LayerNorm module. + + Args: + layernorm (:obj:`~DeBERTa.deberta.BertLayerNorm`): LayerNorm module or function + input (:obj:`torch.tensor`): The input tensor + mask (:obj:`torch.IntTensor`): The mask to applied on the output of LayerNorm where `0` indicate the output of that element will be ignored, i.e. 
set to `0` + + Example:: + + # Create a tensor b x n x d + x = torch.randn([1,10,100]) + m = torch.tensor([[1,1,1,0,0,0,0,0,0,0]], dtype=torch.int) + LayerNorm = DeBERTa.deberta.BertLayerNorm(100) + y = MaskedLayerNorm(LayerNorm, x, m) + + """ output = layerNorm(input).to(input) if mask is None: return output diff --git a/DeBERTa/deberta/pooling.py b/DeBERTa/deberta/pooling.py index c2a9ecc..edd6423 100644 --- a/DeBERTa/deberta/pooling.py +++ b/DeBERTa/deberta/pooling.py @@ -12,17 +12,49 @@ import json from .bert import ACT2FN from .ops import StableDropout +__all__ = ['PoolConfig', 'ContextPooler'] + class PoolConfig(object): - """Configuration class to store the configuration of `attention pool layer`. + """Configuration class to store the configuration of `pool layer`. + + Parameters: + + config (:class:`~DeBERTa.deberta.ModelConfig`): The model config. The field of pool config will be initalized with the `pooling` field in model config. + + Attributes: + + hidden_size (int): Size of the encoder layers and the pooler layer, default: `768`. + + dropout (float): The dropout rate applied on the output of `[CLS]` token, + + hidden_act (:obj:`str`): The activation function of the projection layer, it can be one of ['gelu', 'tanh']. + + Example:: + + # Here is the content of an exmple model config file in json format + + { + "hidden_size": 768, + "num_hidden_layers" 12, + "num_attention_heads": 12, + "intermediate_size": 3072, + ... + "pooling": { + "hidden_size": 768, + "hidden_act": "gelu", + "dropout": 0.1 + } + } + """ - def __init__(self, model_config): + def __init__(self, config): """Constructs PoolConfig. - Params: - `model_config`: the config of the model. The field of pool config will be initalized with the 'pooling' field in model config. + Args: + `config`: the config of the model. The field of pool config will be initalized with the 'pooling' field in model config. """ - pool_config = getattr(model_config, 'pooling', model_config) - self.hidden_size = getattr(pool_config, 'hidden_size', model_config.hidden_size) + pool_config = getattr(config, 'pooling', config) + self.hidden_size = getattr(pool_config, 'hidden_size', config.hidden_size) self.dropout = getattr(pool_config, 'dropout', 0) self.hidden_act = getattr(pool_config, 'hidden_act', 'gelu') diff --git a/README.md b/README.md index 2088404..d2ca42f 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,8 @@ Our pre-trained models are packaged into zipped files. You can download them fro # Try the code +Read our [documentation](https://deberta.readthedocs.io/en/latest/) + ## Requirements - Linux system, e.g. Ubuntu 18.04LTS - CUDA 10.0 @@ -77,6 +79,23 @@ class MyModel(torch.nn.Module): # 2. Change your tokenizer with the the tokenizer built in DeBERta from DeBERTa import deberta tokenizer = deberta.GPT2Tokenizer() +# We apply the same schema of special tokens as BERT, e.g. 
[CLS], [SEP], [MASK] +max_seq_len = 512 +tokens = tokenizer.tokenize('Examples input text of DeBERTa') +# Truncate long sequence +tokens = tokens[:max_seq_len] +# Add special tokens to the `tokens` +tokens = ['[CLS]'] + tokens + ['[SEP]'] +input_ids = tokenizer.convert_tokens_to_ids(tokens) +input_mask = [1]*len(input_ids) +# padding +paddings = max_seq_len-len(input_ids) +input_ids = input_ids + [0]*paddings +input_mask = input_mask + [0]*paddings +features = { +'input_ids': torch.tensor(input_ids, dtype=torch.int), +'input_mask': torch.tensor(input_mask, dtype=torch.int) +} ``` diff --git a/VERSION b/VERSION index b1e80bb..845639e 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.1.3 +0.1.4 diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 0000000..d0c3cbf --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = source +BUILDDIR = build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 0000000..6247f7e --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=source +set BUILDDIR=build + +if "%1" == "" goto help + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.http://sphinx-doc.org/ + exit /b 1 +) + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 0000000..79e8106 --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,17 @@ +recommonmark +sphinx +sphinx-markdown-tables +sphinx-rtd-theme +nltk +spacy +numpy +pytest +regex +scipy +sklearn +torch==1.3.0 +torchvision==0.3.0 +tqdm +ujson +seqeval +psutil diff --git a/docs/source/conf.py b/docs/source/conf.py new file mode 100644 index 0000000..8132a37 --- /dev/null +++ b/docs/source/conf.py @@ -0,0 +1,184 @@ +# Configuration file for the Sphinx documentation builder. +# +# This file only contains a selection of the most common options. For a full +# list see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Path setup -------------------------------------------------------------- + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. 
+# +import os +import sys +src_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')) +sys.path.insert(0, src_dir) + + +# -- Project information ----------------------------------------------------- + +project = u'DeBERTa' +copyright = u'2020, Microsoft' +author = u'Pengcheng He' + + +# The short X.Y version +version=u'' +# The full version, including alpha/beta/rc tags +with open(os.path.join(src_dir, 'VERSION'), encoding='utf-8') as fs: + ver = fs.readline().strip() +release = ver + +# -- General configuration --------------------------------------------------- + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + 'sphinx.ext.autodoc', + 'sphinx.ext.coverage', + 'sphinx.ext.napoleon', + 'recommonmark', + 'sphinx.ext.viewcode', + 'sphinx_markdown_tables' +] + +# Add any paths that contain templates here, relative to this directory. +templates_path = ['_templates'] + +# The suffix(es) of source filenames. +# You can specify multiple suffix as a list of string: +# +source_suffix = ['.rst', '.md'] +# source_suffix = '.rst' + +# The master toctree document. +master_doc = 'index' + +# The language for content autogenerated by Sphinx. Refer to documentation +# for a list of supported languages. +# +# This is also used if you do content translation via gettext catalogs. +# Usually you set "language" from the command line for these cases. +language = None + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This pattern also affects html_static_path and html_extra_path. +exclude_patterns = [u'_build', 'Thumbs.db', '.DS_Store'] + +# The name of the Pygments (syntax highlighting) style to use. +pygments_style = None + + +# -- Options for HTML output ------------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +html_theme = 'sphinx_rtd_theme' + +# Theme options are theme-specific and customize the look and feel of a theme +# further. For a list of options available for each theme, see the +# documentation. +# +html_theme_options = { + 'analytics_id': 'UA-83738774-2' +} + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ['_static'] + +# Custom sidebar templates, must be a dictionary that maps document names +# to template names. +# +# The default sidebars (for documents that don't match any pattern) are +# defined by theme itself. Builtin themes are using these templates by +# default: ``['localtoc.html', 'relations.html', 'sourcelink.html', +# 'searchbox.html']``. +# +# html_sidebars = {} + +# This must be the name of an image file (path relative to the configuration +# directory) that is the favicon of the docs. Modern browsers use this as +# the icon for tabs, windows and bookmarks. It should be a Windows-style +# icon file (.ico). +html_favicon = 'favicon.ico' + + +# -- Options for HTMLHelp output --------------------------------------------- + +# Output file base name for HTML help builder. +htmlhelp_basename = 'debertadoc' + + +# -- Options for LaTeX output ------------------------------------------------ + +latex_elements = { + # The paper size ('letterpaper' or 'a4paper'). 
+ # + # 'papersize': 'letterpaper', + + # The font size ('10pt', '11pt' or '12pt'). + # + # 'pointsize': '10pt', + + # Additional stuff for the LaTeX preamble. + # + # 'preamble': '', + + # Latex figure (float) alignment + # + # 'figure_align': 'htbp', +} + +# Grouping the document tree into LaTeX files. List of tuples +# (source start file, target name, title, +# author, documentclass [howto, manual, or own class]). +latex_documents = [ + (master_doc, 'deberta.tex', u'DeBERTa Documentation', + u'Microsoft', 'manual'), +] + + +# -- Options for manual page output ------------------------------------------ + +# One entry per manual page. List of tuples +# (source start file, name, description, authors, manual section). +man_pages = [ + (master_doc, 'deberta', u'DeBERTa Documentation', + [author], 1) +] + + +# -- Options for Texinfo output ---------------------------------------------- + +# Grouping the document tree into Texinfo files. List of tuples +# (source start file, target name, title, author, +# dir menu entry, description, category) +texinfo_documents = [ + (master_doc, 'deberta', u'DeBERTa Documentation', + author, 'deberta', 'One line description of project.', + 'Miscellaneous'), +] + + +# -- Options for Epub output ------------------------------------------------- + +# Bibliographic Dublin Core info. +epub_title = project + +# The unique identifier of the text. This can be a ISBN number +# or the project homepage. +# +# epub_identifier = '' + +# A unique identification for the text. +# +# epub_uid = '' + +# A list of files that should not be packed into the epub file. +epub_exclude_files = ['search.html'] diff --git a/docs/source/index.rst b/docs/source/index.rst new file mode 100644 index 0000000..c5f4a02 --- /dev/null +++ b/docs/source/index.rst @@ -0,0 +1,20 @@ +.. DeBERTa documentation master file, created by + sphinx-quickstart on Wed Jun 17 19:34:55 2020. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + +Welcome to DeBERTa's documentation! +=================================== + +.. toctree:: + :maxdepth: 2 + :caption: Contents: + + modules/deberta + +Indices and tables +================== + +* :ref:`genindex` +* :ref:`modindex` +* :ref:`search` diff --git a/docs/source/modules/deberta.rst b/docs/source/modules/deberta.rst new file mode 100644 index 0000000..55ae9b9 --- /dev/null +++ b/docs/source/modules/deberta.rst @@ -0,0 +1,76 @@ +DeBERTa +------------------------------ + +DeBERTa Model +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: DeBERTa.deberta.DeBERTa + :members: + +NNModule +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: DeBERTa.deberta.NNModule + :members: + +DisentangledSelfAttention +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: DeBERTa.deberta.DisentangledSelfAttention + :members: + +.. autofunction:: DeBERTa.deberta.build_relative_position + +ContextPooler +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: DeBERTa.deberta.ContextPooler + :members: + +BertEncoder +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: DeBERTa.deberta.BertEncoder + :members: + +BertLayerNorm +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: DeBERTa.deberta.BertLayerNorm + :members: + +XSoftmax +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: DeBERTa.deberta.XSoftmax + +.. :members: + +StableDropout +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: DeBERTa.deberta.StableDropout + :members: + +MaskedLayerNorm +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. 
autofunction:: DeBERTa.deberta.MaskedLayerNorm + +GPT2Tokenizer +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: DeBERTa.deberta.GPT2Tokenizer + :members: + +ModelConfig +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: DeBERTa.deberta.ModelConfig + :members: + +PoolConfig +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: DeBERTa.deberta.PoolConfig + :members: diff --git a/experiments/utils/train.sh b/experiments/utils/train.sh index fe2ee9d..31c00e0 100755 --- a/experiments/utils/train.sh +++ b/experiments/utils/train.sh @@ -111,7 +111,7 @@ fi parameters="--task_name $Task $CMD \ --data_dir $Data \ --init_model $MODEL \ - --bert_config $OUTPUT/model_config.json \ + --model_config $OUTPUT/model_config.json \ --max_seq_length 512 \ --eval_batch_size 128 \ --predict_batch_size 128 \ diff --git a/setup.py b/setup.py index da3cdc3..d346714 100644 --- a/setup.py +++ b/setup.py @@ -16,12 +16,17 @@ with open('VERSION') as fs: with open('requirements.txt') as fs: requirements = [l.strip() for l in fs] +extras = {} +extras["docs"] = ["recommonmark", "sphinx", "sphinx-markdown-tables", "sphinx-rtd-theme"] + setuptools.setup( name="DeBERTa", version=version, author="penhe", author_email="penhe@microsoft.com", description="Decoding enhanced BERT with Disentangled Attention", + keywords="NLP deep learning transformer pytorch Attention BERT RoBERTa DeBERTa", + license="MIT", long_description=long_description, long_description_content_type="text/markdown", url="https://github.com/microsoft/DeBERTa", @@ -33,4 +38,5 @@ setuptools.setup( "Operating System :: OS Independent", ], python_requires='>=3.6', + extras_require=extras, install_requires=requirements)
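
Because `bert_config` is renamed to `model_config` in both the training script and `NNModule.load_model`, downstream callers need the matching one-line update. A minimal sketch follows; the `MyClassifier` head and the file paths are hypothetical and only illustrate the renamed keyword:

```python
import torch
from DeBERTa import deberta

class MyClassifier(deberta.NNModule):
    """Hypothetical task model, shown only to illustrate the renamed keyword argument."""
    def __init__(self, config, num_labels=2):
        super().__init__(config)
        self.bert = deberta.DeBERTa(config=config)
        self.classifier = torch.nn.Linear(config.hidden_size, num_labels)
        self.apply(self.init_weights)

    def forward(self, input_ids, input_mask=None):
        # Take the last encoder layer and classify on the [CLS] position
        encoder_layers = self.bert(input_ids, attention_mask=input_mask)
        return self.classifier(encoder_layers[-1][:, 0, :])

# Previously: MyClassifier.load_model(path, bert_config='model_config.json')
model = MyClassifier.load_model('pytorch.model.bin',            # hypothetical checkpoint path
                                model_config='model_config.json')
```

The new `docs` extra in `setup.py` groups the Sphinx dependencies so they can be pulled in on demand (e.g. `pip install -e .[docs]` from a checkout) without touching the core requirements.
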