Added BERT prefix to classifier names and some minor docstring updates.

Hong Lu 2019-06-07 17:08:29 -04:00
Parent 4e7ac8adc1
Commit 049ddf6442
4 changed files with 83 additions and 66 deletions

View file

@@ -45,7 +45,7 @@
"if nlp_path not in sys.path:\n",
" sys.path.insert(0, nlp_path)\n",
"\n",
"from utils_nlp.bert.token_classification import BertTokenClassifier, postprocess_token_labels\n",
"from utils_nlp.bert.token_classification import BERTTokenClassifier, postprocess_token_labels\n",
"from utils_nlp.bert.common import Language, Tokenizer\n",
"from utils_nlp.dataset.wikigold import download, read_data, get_train_test_data, get_unique_labels"
]
@@ -112,7 +112,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 3,
"metadata": {
"scrolled": true
},
@@ -128,10 +128,10 @@
"['O', 'I-LOC', 'I-MISC', 'I-PER', 'I-ORG', 'X']\n",
"\n",
"Sample sentence: \n",
"His revolt was fueled by grievances stemming from Spanish tributes and abuses , and his belief in self-government , that the administration and leadership of the Roman Catholic Church and government in the Ilocos Region ( which at this time did not include Pangasinan ) should be led by trained Ilocano officials .\n",
"-DOCSTART-\n",
"\n",
"Sample sentence labels: \n",
"['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'I-MISC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'I-MISC', 'I-MISC', 'I-MISC', 'O', 'O', 'O', 'O', 'I-LOC', 'I-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'I-MISC', 'O', 'O']\n",
"['O']\n",
"\n"
]
}
@@ -214,13 +214,13 @@
"train_token_ids, train_input_mask, train_trailing_token_mask, train_label_ids = \\\n",
" tokenizer.preprocess_ner_tokens(text=train_text,\n",
" label_map=label_map,\n",
" max_seq_length=max_seq_length,\n",
" max_len=max_seq_length,\n",
" labels=train_labels,\n",
" trailing_piece_tag=\"X\")\n",
"test_token_ids, test_input_mask, test_trailing_token_mask, test_label_ids = \\\n",
" tokenizer.preprocess_ner_tokens(text=test_text,\n",
" label_map=label_map,\n",
" max_seq_length=max_seq_length,\n",
" max_len=max_seq_length,\n",
" labels=test_labels,\n",
" trailing_piece_tag=\"X\")"
]
@@ -238,7 +238,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 7,
"metadata": {
"scrolled": true
},
@@ -293,7 +293,7 @@
},
"outputs": [],
"source": [
"token_classifier = BertTokenClassifier(language=language,\n",
"token_classifier = BERTTokenClassifier(language=language,\n",
" num_labels=len(label_list),\n",
" cache_dir=cache_dir)"
]
@@ -335,8 +335,8 @@
"\n",
"Iteration: 41%|████▏ | 24/58 [00:31<00:44, 1.30s/it]\u001b[A\n",
"Iteration: 41%|████▏ | 24/58 [00:49<00:44, 1.30s/it]\u001b[A\n",
"Iteration: 81%|████████ | 47/58 [01:01<00:14, 1.31s/it]\u001b[A\n",
"Epoch: 20%|██ | 1/5 [01:15<05:02, 75.53s/it]0s/it]\u001b[A\n",
"Iteration: 83%|████████▎ | 48/58 [01:02<00:12, 1.30s/it]\u001b[A\n",
"Epoch: 20%|██ | 1/5 [01:14<04:59, 74.89s/it]9s/it]\u001b[A\n",
"Iteration: 0%| | 0/58 [00:00<?, ?it/s]\u001b[A"
]
},
@@ -352,11 +352,11 @@
"output_type": "stream",
"text": [
"\n",
"Iteration: 40%|███▉ | 23/58 [00:30<00:45, 1.31s/it]\u001b[A\n",
"Iteration: 40%|███▉ | 23/58 [00:44<00:45, 1.31s/it]\u001b[A\n",
"Iteration: 79%|███████▉ | 46/58 [01:00<00:15, 1.32s/it]\u001b[A\n",
"Iteration: 79%|███████▉ | 46/58 [01:14<00:15, 1.32s/it]\u001b[A\n",
"Epoch: 40%|████ | 2/5 [02:31<03:47, 75.71s/it]1s/it]\u001b[A\n",
"Iteration: 41%|████▏ | 24/58 [00:31<00:44, 1.30s/it]\u001b[A\n",
"Iteration: 41%|████▏ | 24/58 [00:45<00:44, 1.30s/it]\u001b[A\n",
"Iteration: 83%|████████▎ | 48/58 [01:02<00:13, 1.30s/it]\u001b[A\n",
"Iteration: 83%|████████▎ | 48/58 [01:15<00:13, 1.30s/it]\u001b[A\n",
"Epoch: 40%|████ | 2/5 [02:30<03:44, 74.96s/it]0s/it]\u001b[A\n",
"Iteration: 0%| | 0/58 [00:00<?, ?it/s]\u001b[A"
]
},
@@ -373,9 +373,9 @@
"text": [
"\n",
"Iteration: 40%|███▉ | 23/58 [00:30<00:46, 1.32s/it]\u001b[A\n",
"Iteration: 40%|███▉ | 23/58 [00:48<00:46, 1.32s/it]\u001b[A\n",
"Iteration: 40%|███▉ | 23/58 [00:49<00:46, 1.32s/it]\u001b[A\n",
"Iteration: 79%|███████▉ | 46/58 [01:00<00:15, 1.32s/it]\u001b[A\n",
"Epoch: 60%|██████ | 3/5 [03:47<02:31, 75.84s/it]1s/it]\u001b[A\n",
"Epoch: 60%|██████ | 3/5 [03:46<02:30, 75.28s/it]1s/it]\u001b[A\n",
"Iteration: 0%| | 0/58 [00:00<?, ?it/s]\u001b[A"
]
},
@@ -391,11 +391,11 @@
"output_type": "stream",
"text": [
"\n",
"Iteration: 40%|███▉ | 23/58 [00:30<00:46, 1.33s/it]\u001b[A\n",
"Iteration: 40%|███▉ | 23/58 [00:42<00:46, 1.33s/it]\u001b[A\n",
"Iteration: 79%|███████▉ | 46/58 [01:00<00:15, 1.33s/it]\u001b[A\n",
"Iteration: 79%|███████▉ | 46/58 [01:12<00:15, 1.33s/it]\u001b[A\n",
"Epoch: 80%|████████ | 4/5 [05:04<01:16, 76.02s/it]2s/it]\u001b[A\n",
"Iteration: 40%|███▉ | 23/58 [00:30<00:46, 1.32s/it]\u001b[A\n",
"Iteration: 40%|███▉ | 23/58 [00:43<00:46, 1.32s/it]\u001b[A\n",
"Iteration: 79%|███████▉ | 46/58 [01:00<00:15, 1.32s/it]\u001b[A\n",
"Iteration: 79%|███████▉ | 46/58 [01:13<00:15, 1.32s/it]\u001b[A\n",
"Epoch: 80%|████████ | 4/5 [05:02<01:15, 75.57s/it]1s/it]\u001b[A\n",
"Iteration: 0%| | 0/58 [00:00<?, ?it/s]\u001b[A"
]
},
@@ -412,10 +412,9 @@
"text": [
"\n",
"Iteration: 40%|███▉ | 23/58 [00:30<00:46, 1.32s/it]\u001b[A\n",
"Iteration: 40%|███▉ | 23/58 [00:45<00:46, 1.32s/it]\u001b[A\n",
"Iteration: 79%|███████▉ | 46/58 [01:01<00:15, 1.33s/it]\u001b[A\n",
"Iteration: 79%|███████▉ | 46/58 [01:15<00:15, 1.33s/it]\u001b[A\n",
"Epoch: 100%|██████████| 5/5 [06:20<00:00, 76.16s/it]2s/it]\u001b[A"
"Iteration: 40%|███▉ | 23/58 [00:47<00:46, 1.32s/it]\u001b[A\n",
"Iteration: 79%|███████▉ | 46/58 [01:00<00:15, 1.32s/it]\u001b[A\n",
"Epoch: 100%|██████████| 5/5 [06:18<00:00, 75.79s/it]2s/it]\u001b[A"
]
},
{
@@ -456,6 +455,21 @@
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r",
"Iteration: 0%| | 0/58 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Warning: Only 1 CUDA device is available. Data parallelism is not possible.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
@@ -506,11 +520,11 @@
"text": [
" precision recall f1-score support\n",
"\n",
" LOC 0.91 0.84 0.88 558\n",
" X 0.97 0.98 0.98 2060\n",
" PER 0.90 0.95 0.93 569\n",
" ORG 0.70 0.82 0.76 553\n",
" MISC 0.74 0.71 0.72 399\n",
" PER 0.90 0.95 0.93 569\n",
" X 0.97 0.98 0.98 2060\n",
" LOC 0.91 0.84 0.88 558\n",
"\n",
"micro avg 0.89 0.91 0.90 4139\n",
"macro avg 0.90 0.91 0.90 4139\n",
@@ -548,9 +562,9 @@
"text": [
" precision recall f1-score support\n",
"\n",
" LOC 0.90 0.85 0.88 507\n",
" PER 0.89 0.94 0.92 460\n",
" ORG 0.65 0.81 0.72 440\n",
" LOC 0.90 0.85 0.88 507\n",
" MISC 0.67 0.68 0.68 328\n",
"\n",
"micro avg 0.77 0.83 0.80 1735\n",

View file

@@ -93,14 +93,15 @@ class Tokenizer:
def preprocess_ner_tokens(
self,
text,
max_seq_length=BERT_MAX_LEN,
max_len=BERT_MAX_LEN,
labels=None,
label_map=None,
trailing_piece_tag="X",
):
"""
Preprocesses input tokens, involving the following steps
1. Convert input token to token ids
Preprocesses input text, involving the following steps
0. Tokenize input text.
1. Convert string tokens to token ids.
2. Convert input labels to label ids, if labels and label_map are
provided.
3. If a word is tokenized into multiple pieces of tokens by the
@@ -111,7 +112,7 @@
Args:
text (list): List of input sentences/paragraphs.
max_seq_length (int, optional): Maximum length of the list of
max_len (int, optional): Maximum length of the list of
tokens. Lists longer than this are truncated and shorter
ones are padded with "O"s. Default value is BERT_MAX_LEN=512.
labels (list, optional): List of token label lists. Default
@@ -119,7 +120,7 @@
label_map (dict, optional): Dictionary for mapping original token
labels (which may be string type) to integers. Default value
is None.
trailing_piece_tag (str, optional): Tags used to label trailing
trailing_piece_tag (str, optional): Tag used to label trailing
word pieces. For example, "playing" is broken into "play"
and "##ing", "play" preserves its original label and "##ing"
is labeled as trailing_piece_tag. Default value is "X".
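For reference (not part of the diff), a minimal usage sketch of the renamed max_len argument. The sample sentence, its labels, and the Tokenizer constructor arguments are illustrative assumptions; the label set matches the notebook above.

# Illustrative sketch only: word-level NER preprocessing with the renamed max_len argument.
from utils_nlp.bert.common import Language, Tokenizer

sample_text = ["Spain won the championship ."]           # one whitespace-tokenized sentence
sample_labels = [["I-LOC", "O", "O", "O", "O"]]           # one label per whitespace token
label_map = {"O": 0, "I-LOC": 1, "I-MISC": 2, "I-PER": 3, "I-ORG": 4, "X": 5}

tokenizer = Tokenizer(language=Language.ENGLISH)          # constructor arguments assumed

token_ids, input_mask, trailing_token_mask, label_ids = \
    tokenizer.preprocess_ner_tokens(text=sample_text,
                                    label_map=label_map,
                                    max_len=128,
                                    labels=sample_labels,
                                    trailing_piece_tag="X")
# Trailing word pieces (e.g. "##ing" when "playing" is split into "play" and "##ing")
# receive the trailing_piece_tag "X"; sequences are truncated or padded to max_len.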
@@ -144,13 +145,13 @@
each sublist contains token labels of a input
sentence/paragraph, if labels is provided.
"""
if max_seq_length > BERT_MAX_LEN:
if max_len > BERT_MAX_LEN:
warnings.warn(
"setting max_len to max allowed tokens: {}".format(
BERT_MAX_LEN
)
)
max_seq_length = BERT_MAX_LEN
max_len = BERT_MAX_LEN
label_available = True
if labels is None:
@@ -173,9 +174,9 @@
new_labels.append(tag)
tokens.append(sub_word)
if len(tokens) > max_seq_length:
tokens = tokens[:max_seq_length]
new_labels = new_labels[:max_seq_length]
if len(tokens) > max_len:
tokens = tokens[:max_len]
new_labels = new_labels[:max_len]
input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
@@ -184,8 +185,8 @@
input_mask = [1.0] * len(input_ids)
# Zero-pad up to the max sequence length.
padding = [0.0] * (max_seq_length - len(input_ids))
label_padding = ["O"] * (max_seq_length - len(input_ids))
padding = [0.0] * (max_len - len(input_ids))
label_padding = ["O"] * (max_len - len(input_ids))
input_ids += padding
input_mask += padding

View file

@@ -8,11 +8,11 @@ import torch.nn as nn
from pytorch_pretrained_bert.modeling import BertForSequenceClassification
from pytorch_pretrained_bert.optimization import BertAdam
from tqdm import tqdm
from utils_nlp.bert.common import BERT_MAX_LEN, Language
from utils_nlp.bert.common import Language
from utils_nlp.pytorch.device_utils import get_device, move_to_device
class SequenceClassifier:
class BERTSequenceClassifier:
"""BERT-based sequence classifier"""
def __init__(self, language=Language.ENGLISH, num_labels=2, cache_dir="."):
@@ -20,9 +20,10 @@ class SequenceClassifier:
Args:
language (Language, optional): The pretrained model's language.
Defaults to Language.ENGLISH.
num_labels (int, optional): The number of unique labels in the training data.
Defaults to 2.
cache_dir (str, optional): Location of BERT's cache directory. Defaults to ".".
num_labels (int, optional): The number of unique labels in the
training data. Defaults to 2.
cache_dir (str, optional): Location of BERT's cache directory.
Defaults to ".".
"""
if num_labels < 2:
raise Exception("Number of labels should be at least 2.")
@@ -54,14 +55,15 @@ class SequenceClassifier:
labels (list): List of training labels.
device (str, optional): Device used for training ("cpu" or "gpu").
Defaults to "gpu".
num_gpus (int, optional): The number of gpus to use.
If None is specified, all available GPUs will be used.
Defaults to None.
num_epochs (int, optional): Number of training epochs. Defaults to 1.
num_gpus (int, optional): The number of gpus to use.
If None is specified, all available GPUs
will be used. Defaults to None.
num_epochs (int, optional): Number of training epochs.
Defaults to 1.
batch_size (int, optional): Training batch size. Defaults to 32.
lr (float): Learning rate of the Adam optimizer. Defaults to 2e-5.
verbose (bool, optional): If True, shows the training progress and loss values.
Defaults to True.
verbose (bool, optional): If True, shows the training progress and
loss values. Defaults to True.
"""
device = get_device("cpu" if num_gpus == 0 else "gpu")
@@ -142,14 +144,14 @@ class SequenceClassifier:
del [x_batch, y_batch, mask_batch]
torch.cuda.empty_cache()
def predict(self, token_ids, input_mask, num_gpus=1, batch_size=32):
def predict(self, token_ids, input_mask, num_gpus=None, batch_size=32):
"""Scores the given dataset and returns the predicted classes.
Args:
token_ids (list): List of training token lists.
input_mask (list): List of input mask lists.
num_gpus (int, optional): The number of gpus to use.
If None is specified, all available GPUs will be used.
Defaults to 1.
num_gpus (int, optional): The number of gpus to use.
If None is specified, all available GPUs
will be used. Defaults to None.
batch_size (int, optional): Scoring batch size. Defaults to 32.
Returns:
[ndarray]: Predicted classes.
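For reference (not part of the diff), a minimal end-to-end sketch with the renamed BERTSequenceClassifier and the num_gpus=None default. The module path and the preprocessed inputs (train_token_ids, train_input_mask, train_labels, test_token_ids, test_input_mask) are assumptions; fit's leading arguments are passed positionally in the order given by its docstring.

# Illustrative sketch only; import path and input variables are assumed.
from utils_nlp.bert.common import Language
from utils_nlp.bert.sequence_classification import BERTSequenceClassifier

classifier = BERTSequenceClassifier(language=Language.ENGLISH,
                                    num_labels=2,
                                    cache_dir=".")

# num_gpus=None (also the new default for predict) uses all available GPUs.
classifier.fit(train_token_ids, train_input_mask, train_labels,
               num_gpus=None,
               num_epochs=1,
               batch_size=32,
               lr=2e-5,
               verbose=True)

predictions = classifier.predict(token_ids=test_token_ids,
                                 input_mask=test_input_mask,
                                 num_gpus=None,
                                 batch_size=32)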

View file

@@ -19,7 +19,7 @@ from .common import Language, create_data_loader
from utils_nlp.pytorch.device_utils import get_device, move_to_device
class BertTokenClassifier:
class BERTTokenClassifier:
"""BERT-based token classifier."""
def __init__(self, language=Language.ENGLISH, num_labels=2, cache_dir="."):
@@ -123,12 +123,12 @@ class BertTokenClassifier:
labels (list): List of lists, each sublist contains numerical
token labels of an input sentence/paragraph.
num_gpus (int, optional): The number of GPUs to use.
If None, all available GPUs will be used. Defaults to 1.
If None, all available GPUs will be used. Defaults to None.
num_epochs (int, optional): Number of training epochs.
Defaults to 1.
batch_size (int, optional): Training batch size. Defaults to 32.
learning_rate (float, optional): learning rate of the BertAdam
optimizer. Defaults to 5e-5.
optimizer. Defaults to 2e-5.
warmup_proportion (float, optional): Proportion of training to
perform linear learning rate warmup for. E.g., 0.1 = 10% of
training. Defaults to None.
@@ -184,7 +184,7 @@ class BertTokenClassifier:
tr_loss += loss.item()
nb_tr_steps += 1
# Update paramters based on the current gradient
# Update parameters based on current gradients
optimizer.step()
# Reset parameter gradients to zero
optimizer.zero_grad()
@@ -193,7 +193,7 @@ class BertTokenClassifier:
print("Train loss: {}".format(train_loss))
def predict(
self, token_ids, input_mask, labels=None, batch_size=32, num_gpus=1
self, token_ids, input_mask, labels=None, batch_size=32, num_gpus=None
):
"""
Predict token labels on the testing data.
@@ -206,13 +206,13 @@ class BertTokenClassifier:
the attention mask of the input token list, 1 for input
tokens and 0 for padded tokens, so that padded tokens are
not attended to.
labels (list, optional): List of lists, each sublist contains
labels (list, optional): List of lists. Each sublist contains
numerical token labels of an input sentence/paragraph.
If provided, it's used to compute the evaluation loss.
Default value is None.
batch_size (int, optional): Training batch size. Defaults to 32.
batch_size (int, optional): Testing batch size. Defaults to 32.
num_gpus (int, optional): The number of GPUs to use.
If None, all available GPUs will be used. Defaults to 1.
If None, all available GPUs will be used. Defaults to None.
Returns:
list: List of lists of predicted token labels.
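For reference (not part of the diff), a minimal sketch with the renamed BERTTokenClassifier and the num_gpus=None default. The inputs token_ids, input_mask, and label_ids are the illustrative outputs of the preprocess_ner_tokens sketch earlier; keyword names follow the docstrings in this commit, and fit's leading arguments are passed positionally.

# Illustrative sketch only; inputs come from the preprocess_ner_tokens sketch above.
from utils_nlp.bert.common import Language
from utils_nlp.bert.token_classification import BERTTokenClassifier

label_list = ["O", "I-LOC", "I-MISC", "I-PER", "I-ORG", "X"]

token_classifier = BERTTokenClassifier(language=Language.ENGLISH,
                                       num_labels=len(label_list),
                                       cache_dir=".")

token_classifier.fit(token_ids, input_mask, label_ids,
                     num_gpus=None,        # use all available GPUs (new default)
                     num_epochs=1,
                     batch_size=32)

# Passing labels to predict is optional; when provided, the evaluation loss is reported.
predicted_labels = token_classifier.predict(token_ids=token_ids,
                                            input_mask=input_mask,
                                            labels=label_ids,
                                            batch_size=32,
                                            num_gpus=None)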
@@ -317,7 +317,7 @@ def postprocess_token_labels(
]
if remove_trailing_word_pieces and trailing_token_mask:
# Remove hte padded values in trailing_token_mask first
# Remove the padded values in trailing_token_mask first
token_mask_no_padding = [
[token for token, padding in zip(t_mask, p_mask) if padding == 1]
for t_mask, p_mask in zip(trailing_token_mask, input_mask)