Updated NER utils and notebook to use the new Transformer class

Ke Huang 2019-11-15 16:34:04 -05:00
Parent a9016ecb40
Commit 6e19905a13
4 changed files with 140 additions and 70 deletions

View file

@ -77,12 +77,13 @@
"import sys\n",
"import os\n",
"import scrapbook as sb\n",
"import torch\n",
"\n",
"from tempfile import TemporaryDirectory\n",
"from utils_nlp.dataset import wikigold\n",
"from utils_nlp.common.timer import Timer\n",
"from seqeval.metrics import classification_report\n",
"from utils_nlp.models.transformers.named_entity_recognition import *"
"from utils_nlp.models.transformers.named_entity_recognition import TokenClassifier"
]
},
{
@ -95,7 +96,11 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"tags": [
"parameters"
]
},
"outputs": [],
"source": [
"# fraction of the dataset used for testing\n",
@ -126,8 +131,8 @@
"torch.manual_seed(RANDOM_SEED)\n",
"\n",
"# model configurations\n",
"MODEL_NAME = \"bert-base-uncased\"\n",
"DO_LOWER_CASE = True\n",
"MODEL_NAME = \"bert-base-cased\"\n",
"DO_LOWER_CASE = False\n",
"MAX_SEQ_LENGTH = 200\n",
"TRAILING_PIECE_TAG = \"X\"\n",
"DEVICE = \"cuda\"\n",
@ -155,7 +160,7 @@
"metadata": {},
"outputs": [],
"source": [
"train_dataset, test_dataset, label_map = wikigold.load_dataset(\n",
"train_dataloader, test_dataloader, label_map, test_dataset = wikigold.load_dataset(\n",
" local_path=DATA_PATH,\n",
" test_fraction=TEST_DATA_FRACTION,\n",
" random_seed=RANDOM_SEED,\n",
@ -165,7 +170,9 @@
" to_lower=DO_LOWER_CASE,\n",
" cache_dir=CACHE_DIR,\n",
" max_len=MAX_SEQ_LENGTH,\n",
" trailing_piece_tag=TRAILING_PIECE_TAG\n",
" trailing_piece_tag=TRAILING_PIECE_TAG,\n",
" batch_size=BATCH_SIZE,\n",
" num_gpus=None\n",
")"
]
},
@ -194,10 +201,8 @@
"# Fine tune the model using the training dataset\n",
"with Timer() as t:\n",
" model.fit(\n",
" train_dataset=train_dataset,\n",
" device=DEVICE,\n",
" train_dataloader=train_dataloader,\n",
" num_epochs=NUM_TRAIN_EPOCHS,\n",
" batch_size=BATCH_SIZE,\n",
" num_gpus=None,\n",
" local_rank=-1,\n",
" weight_decay=0.0,\n",
@ -205,7 +210,7 @@
" adam_epsilon=1e-8,\n",
" warmup_steps=0,\n",
" verbose=True,\n",
" seed=RANDOM_SEED,\n",
" seed=RANDOM_SEED\n",
" )\n",
"\n",
"print(\"Training time : {:.3f} hrs\".format(t.interval / 3600))\n"
@ -228,11 +233,9 @@
"source": [
"with Timer() as t:\n",
" preds = model.predict(\n",
" eval_dataset=test_dataset,\n",
" device=DEVICE,\n",
" batch_size=BATCH_SIZE,\n",
" local_rank=-1,\n",
" verbose=True,\n",
" eval_dataloader=test_dataloader,\n",
" num_gpus=None,\n",
" verbose=True\n",
" )\n",
"\n",
"print(\"Prediction time : {:.3f} hrs\".format(t.interval / 3600))"
@ -251,16 +254,7 @@
"metadata": {},
"outputs": [],
"source": [
"label_id2str = {v: k for k, v in label_map.items()}\n",
"true_ids = test_dataset.tensors[3].data.numpy()\n",
"true_labels = []\n",
"\n",
"for sentence in true_ids:\n",
" token_labels = []\n",
" for token_id in sentence:\n",
" token_labels.append(label_id2str[token_id])\n",
" \n",
" true_labels.append(token_labels)"
"true_labels = model.get_true_test_labels(label_map=label_map, dataset=test_dataset)"
]
},
{
@ -314,9 +308,9 @@
"metadata": {
"celltoolbar": "Tags",
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3.6 - AzureML",
"language": "python",
"name": "python3"
"name": "python3-azureml"
},
"language_info": {
"codemirror_mode": {

View file

@ -16,6 +16,7 @@ def test_ner_wikigold_bert(notebooks, tmp):
notebook_path,
OUTPUT_NOTEBOOK,
parameters={
"DATA_PATH": tmp,
"CACHE_DIR": tmp
},
kernel_name=KERNEL_NAME,
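The parameters dictionary above is injected into the notebook cell tagged "parameters" (the metadata change in the notebook diff adds that tag). A rough sketch of what the test driver effectively does, assuming papermill; the notebook and output paths below are placeholders.

import tempfile
import papermill as pm

tmp = tempfile.mkdtemp()
pm.execute_notebook(
    "ner_wikigold_transformer.ipynb",                 # placeholder notebook path
    "output.ipynb",                                   # executed copy is written here
    parameters={"DATA_PATH": tmp, "CACHE_DIR": tmp},  # overrides the tagged parameter cell
    kernel_name="python3",                            # the test passes KERNEL_NAME instead
)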

View file

@ -89,7 +89,9 @@ def load_dataset(
to_lower=True,
cache_dir=TemporaryDirectory().name,
max_len=MAX_SEQ_LEN,
trailing_piece_tag="X"
trailing_piece_tag="X",
batch_size=32,
num_gpus=None
):
"""
Load the wikigold dataset and split into training and testing datasets.
@ -120,26 +122,18 @@ def load_dataset(
For example, "criticize" is broken into "critic" and "##ize", "critic"
preserves its original label and "##ize" is labeled as trailing_piece_tag.
Default value is "X".
batch_size (int, optional): The batch size for training and testing.
Defaults to 32.
num_gpus (int, optional): The number of GPUs to use. If None, all available GPUs will be used.
Defaults to None.
Returns:
tuple. The tuple contains three elements.
train_dataset (TensorDataset): A TensorDataset containing the following four tensors.
1. input_ids_all: Tensor. Each sublist contains numerical values,
i.e. token ids, corresponding to the tokens in the input
text data.
2. input_mask_all: Tensor. Each sublist contains the attention
mask of the input token id list, 1 for input tokens and 0 for
padded tokens, so that padded tokens are not attended to.
3. trailing_token_mask_all: Tensor. Each sublist is
a boolean list, True for the first word piece of each
original word, False for the trailing word pieces,
e.g. "##ize". This mask is useful for removing the
predictions on trailing word pieces, so that each
original word in the input text has a unique predicted
label.
4. label_ids_all: Tensor, each sublist contains token labels of
a input sentence/paragraph, if labels is provided. If the
`labels` argument is not provided, it will not return this tensor.
tuple. The tuple contains four elements.
train_dataloader (DataLoader): A PyTorch DataLoader instance for training.
test_dataloader (DataLoader): A PyTorch DataLoader instance for testing.
label_map (dict): A dictionary object to map a label (str) to an ID (int).
test_dataset (TensorDataset): A TensorDataset containing the following four tensors.
1. input_ids_all: Tensor. Each sublist contains numerical values,
@ -158,8 +152,6 @@ def load_dataset(
4. label_ids_all: Tensor, each sublist contains token labels of
an input sentence/paragraph, if labels is provided. If the
`labels` argument is not provided, it will not return this tensor.
label_map (dict): A dictionary object to map a label (str) to an ID (int).
"""
train_df, test_df = load_train_test_dfs(
@ -214,4 +206,20 @@ def load_dataset(
trailing_piece_tag=trailing_piece_tag
)
return (train_dataset, test_dataset, label_map)
train_dataloader = processor.create_dataloader_from_dataset(
train_dataset,
shuffle=True,
batch_size=batch_size,
num_gpus=num_gpus,
distributed=False
)
test_dataloader = processor.create_dataloader_from_dataset(
test_dataset,
shuffle=False,
batch_size=batch_size,
num_gpus=num_gpus,
distributed=False
)
return (train_dataloader, test_dataloader, label_map, test_dataset)
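As a standalone illustration of what the new dataloaders do, a sketch of the sampler and batch-size logic behind create_dataloader_from_dataset, using a dummy TensorDataset instead of the wikigold features.

import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset

# Dummy stand-in for the tokenized features (shapes are illustrative only).
dataset = TensorDataset(torch.zeros(8, 200, dtype=torch.long),
                        torch.ones(8, 200, dtype=torch.long))

num_gpus = torch.cuda.device_count()    # num_gpus=None in the utils means "use all available"
batch_size = 32 * max(1, num_gpus)      # effective batch size scales with the GPU count

# Training shuffles; evaluation keeps order so predictions line up with the true labels.
train_loader = DataLoader(dataset, sampler=RandomSampler(dataset), batch_size=batch_size)
test_loader = DataLoader(dataset, sampler=SequentialSampler(dataset), batch_size=batch_size)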

View file

@ -2,14 +2,17 @@
# Licensed under the MIT License.
import logging
from collections import Iterable
import numpy as np
import torch
import torch.nn as nn
from collections import Iterable
from torch.utils.data import TensorDataset
from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BertForTokenClassification
from utils_nlp.common.pytorch_utils import get_device
from utils_nlp.models.transformers.common import MAX_SEQ_LEN, TOKENIZER_CLASS, Transformer
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
TC_MODEL_CLASS = {k: BertForTokenClassification for k in BERT_PRETRAINED_MODEL_ARCHIVE_MAP}
@ -241,6 +244,27 @@ class TokenClassificationProcessor:
)
return td
def create_dataloader_from_dataset(
self,
dataset,
shuffle=False,
batch_size=32,
num_gpus=None,
distributed=False
):
if num_gpus is None:
num_gpus = torch.cuda.device_count()
batch_size = batch_size * max(1, num_gpus)
if distributed:
sampler = DistributedSampler(dataset)
else:
sampler = RandomSampler(dataset) if shuffle else SequentialSampler(dataset)
return DataLoader(dataset, sampler=sampler, batch_size=batch_size)
class TokenClassifier(Transformer):
"""
@ -269,27 +293,24 @@ class TokenClassifier(Transformer):
def fit(
self,
train_dataset,
train_dataloader,
num_epochs=1,
batch_size=32,
num_gpus=None,
local_rank=-1,
weight_decay=0.0,
learning_rate=5e-5,
adam_epsilon=1e-8,
warmup_steps=0,
verbose=False,
verbose=True,
seed=None,
):
"""
Fit the TokenClassifier model using the given training dataset.
Args:
train_dataset (Dataset): Dataset for training.
train_dataloader (DataLoader): DataLoader instance for training.
num_epochs (int, optional): Number of training epochs.
Defaults to 1.
batch_size (int, optional): Training batch size.
Defaults to 32.
num_gpus (int, optional): The number of GPUs to use. If None, all available GPUs will
be used. If set to 0 or GPUs are not available, CPU device will
be used. Defaults to None.
@ -310,15 +331,17 @@ class TokenClassifier(Transformer):
"""
device, num_gpus = get_device(num_gpus=num_gpus, local_rank=local_rank)
self.model.to(device)
if isinstance(self.model, nn.DataParallel):
self.model.module.to(device)
else:
self.model.to(device)
super().fine_tune(
train_dataset=train_dataset,
train_dataloader=train_dataloader,
get_inputs=TokenClassificationProcessor.get_inputs,
device=device,
n_gpu=num_gpus,
num_train_epochs=num_epochs,
per_gpu_train_batch_size=batch_size,
weight_decay=weight_decay,
learning_rate=learning_rate,
adam_epsilon=adam_epsilon,
@ -327,19 +350,20 @@ class TokenClassifier(Transformer):
seed=seed,
)
def predict(self, eval_dataset, batch_size=32, num_gpus=None, local_rank=-1, verbose=False):
def predict(
self,
eval_dataloader,
num_gpus=None,
verbose=True
):
"""
Test on an evaluation dataset and get the token label predictions.
Args:
eval_dataset (TensorDataset): A TensorDataset for evaluation.
batch_size (int, optional): The batch size for evaluation.
Defaults to 32.
num_gpus (int, optional): The number of GPUs to use. If None, all available GPUs will
be used. If set to 0 or GPUs are not available, CPU device will
be used. Defaults to None.
local_rank (int, optional): Whether need to do distributed training.
Defaults to -1, no distributed training.
verbose (bool, optional): Verbose mode.
Defaults to True.
@ -351,15 +375,17 @@ class TokenClassifier(Transformer):
"""
device, num_gpus = get_device(num_gpus=num_gpus, local_rank=-1)
if isinstance(self.model, nn.DataParallel):
self.model.module.to(device)
else:
self.model.to(device)
preds = list(
super().predict(
eval_dataset=eval_dataset,
eval_dataloader=eval_dataloader,
get_inputs=TokenClassificationProcessor.get_inputs,
device=device,
per_gpu_eval_batch_size=batch_size,
n_gpu=num_gpus,
local_rank=local_rank,
verbose=verbose,
verbose=verbose
)
)
preds_np = np.concatenate(preds)
@ -374,6 +400,7 @@ class TokenClassifier(Transformer):
The shape of the ndarray is [number_of_examples, sequence_length, number_of_labels].
label_map (dict): A dictionary object to map a label (str) to an ID (int).
dataset (TensorDataset): The TensorDataset for evaluation.
dataset (Dataset): The test Dataset instance.
Returns:
list: A list of lists. The size of the returned list is the number of testing samples.
@ -411,3 +438,43 @@ class TokenClassifier(Transformer):
one_sample.append(label_id2str[label_id])
labels.append(one_sample)
return labels
def get_true_test_labels(self, label_map, dataset):
"""
Get the true testing label values.
Args:
label_map (dict): A dictionary object to map a label (str) to an ID (int).
dataset (TensorDataset): The test dataset to extract the true labels from.
Returns:
list: A list of lists. The size of the returned list is the number of testing samples.
Each sublist contains the true label of each token in the corresponding sample.
"""
num_samples = len(dataset.tensors[0])
label_id2str = {v: k for k, v in label_map.items()}
attention_mask_all = dataset.tensors[1].data.numpy()
trailing_mask_all = dataset.tensors[2].data.numpy()
label_ids_all = dataset.tensors[3].data.numpy()
seq_len = len(trailing_mask_all[0])
labels = []
for idx in range(num_samples):
attention_mask = attention_mask_all[idx]
trailing_mask = trailing_mask_all[idx]
label_ids = label_ids_all[idx]
one_sample = []
for sid in range(seq_len):
if attention_mask[sid] == 0:
break
if not trailing_mask[sid]:
continue
label_id = label_ids[sid]
one_sample.append(label_id2str[label_id])
labels.append(one_sample)
return labels
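To make the masking logic in get_true_test_labels concrete, a small self-contained sketch with made-up tensors: the attention mask stops the scan at padding, and the trailing-piece mask drops "##"-style word pieces so each original word keeps exactly one label.

import torch
from torch.utils.data import TensorDataset

label_map = {"O": 0, "I-PER": 1, "X": 2}               # "X" is the trailing-piece tag
label_id2str = {v: k for k, v in label_map.items()}

# One toy sentence: four real word pieces followed by two padding positions (made-up values).
input_ids      = torch.tensor([[101, 7632, 102, 103, 0, 0]])
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0]])
trailing_mask  = torch.tensor([[1, 1, 0, 1, 0, 0]])    # third piece is a "##" continuation
label_ids      = torch.tensor([[0, 1, 2, 0, 0, 0]])
dataset = TensorDataset(input_ids, attention_mask, trailing_mask, label_ids)

true_labels = []
for att, trail, labs in zip(dataset.tensors[1], dataset.tensors[2], dataset.tensors[3]):
    sample = []
    for a, t, l in zip(att.tolist(), trail.tolist(), labs.tolist()):
        if a == 0:      # padding reached, stop scanning this sample
            break
        if not t:       # skip trailing word pieces so each word keeps one label
            continue
        sample.append(label_id2str[l])
    true_labels.append(sample)

print(true_labels)      # [['O', 'I-PER', 'O']]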