updated preprocessing, utils, classification

Said Bleik 2019-05-21 16:45:23 -04:00
Parent 49bb116474
Commit 63e546ab3c
6 changed files with 233 additions and 276 deletions

View file

@@ -24,7 +24,8 @@
"import pandas as pd\n",
"import utils_nlp.dataset.yahoo_answers as ya_dataset\n",
"from utils_nlp.eval.classification import eval_classification\n",
"from utils_nlp.classification.bert import BERTSequenceClassifier, Language\n",
"from utils_nlp.bert.sequence_classification import SequenceClassifier\n",
"from utils_nlp.bert.common import Language, Tokenizer\n",
"import torch\n",
"import torch.nn as nn\n",
"import numpy as np"
@@ -45,7 +46,7 @@
"DEVICE = \"gpu\"\n",
"UPDATE_EMBEDDINGS = False\n",
"NUM_EPOCHS = 1\n",
"NUM_ROWS_TRAIN = 10000 # number of training examples to read\n",
"NUM_ROWS_TRAIN = 15000 # number of training examples to read\n",
"NUM_ROWS_TEST = 10000 # number of test examples to read"
]
},
@@ -58,26 +59,23 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"if (not os.path.isfile(os.path.join(DATA_FOLDER, TRAIN_FILE))) or (\n",
" not os.path.isfile(os.path.join(DATA_FOLDER, TEST_FILE))\n",
"):\n",
" ya_dataset.download(DATA_FOLDER)"
"ya_dataset.download(DATA_FOLDER)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Read and Preprocess Dataset"
"## Read Dataset"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -96,6 +94,30 @@
"text_test = ya_dataset.get_text(df_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tokenize and Preprocess"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tokenizer = Tokenizer(Language.ENGLISH, to_lower=False, cache_dir=BERT_CACHE_DIR)\n",
"\n",
"# tokenize\n",
"tokens_train = tokenizer.tokenize(text_train)\n",
"tokens_test = tokenizer.tokenize(text_test)\n",
"\n",
"# get BERT-format tokens (padded and truncated)\n",
"tokens_train, mask_train = tokenizer.preprocess_classification_tokens(tokens_train, MAX_LEN)\n",
"tokens_test, mask_test = tokenizer.preprocess_classification_tokens(tokens_test, MAX_LEN)"
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -105,41 +127,32 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"classifier = BERTSequenceClassifier(\n",
"classifier = SequenceClassifier(\n",
" language=Language.ENGLISH, num_labels=num_labels, cache_dir=BERT_CACHE_DIR\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch:1/1; batch:1/312; loss:2.334407091140747\n",
"epoch:1/1; batch:97/312; loss:1.4464187622070312\n",
"epoch:1/1; batch:193/312; loss:0.4434894919395447\n",
"epoch:1/1; batch:289/312; loss:0.7371340394020081\n"
]
}
],
"outputs": [],
"source": [
"# train\n",
"classifier.fit(text=text_train,\n",
" labels=labels_train,\n",
" max_len=MAX_LEN,\n",
" device=DEVICE,\n",
" use_multiple_gpus=True,\n",
" num_epochs=NUM_EPOCHS,\n",
" batch_size=BATCH_SIZE,\n",
" verbose=True)"
"classifier.fit(\n",
" tokens=tokens_train,\n",
" input_mask=mask_train,\n",
" labels=labels_train, \n",
" device=DEVICE,\n",
" use_multiple_gpus=True,\n",
" num_epochs=NUM_EPOCHS,\n",
" batch_size=BATCH_SIZE,\n",
" verbose=True,\n",
")"
]
},
{
@@ -155,7 +168,7 @@
"metadata": {},
"outputs": [],
"source": [
"preds = classifier.predict(text=text_test, device=\"gpu\", batch_size=BATCH_SIZE)"
"preds = classifier.predict(tokens=tokens_test, input_mask=mask_test, device=DEVICE, batch_size=BATCH_SIZE)"
]
},
{
@@ -167,127 +180,9 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
" accuracy: 0.6608\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>precision</th>\n",
" <th>recall</th>\n",
" <th>f1</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0.5724</td>\n",
" <td>0.3835</td>\n",
" <td>0.4593</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0.6680</td>\n",
" <td>0.7676</td>\n",
" <td>0.7143</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0.6569</td>\n",
" <td>0.8488</td>\n",
" <td>0.7406</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0.4837</td>\n",
" <td>0.4521</td>\n",
" <td>0.4673</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0.7724</td>\n",
" <td>0.8596</td>\n",
" <td>0.8136</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>0.8457</td>\n",
" <td>0.8317</td>\n",
" <td>0.8387</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>0.5804</td>\n",
" <td>0.5214</td>\n",
" <td>0.5493</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>0.6691</td>\n",
" <td>0.7032</td>\n",
" <td>0.6858</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>0.6826</td>\n",
" <td>0.7130</td>\n",
" <td>0.6974</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>0.7459</td>\n",
" <td>0.6761</td>\n",
" <td>0.7093</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" precision recall f1\n",
"0 0.5724 0.3835 0.4593\n",
"1 0.6680 0.7676 0.7143\n",
"2 0.6569 0.8488 0.7406\n",
"3 0.4837 0.4521 0.4673\n",
"4 0.7724 0.8596 0.8136\n",
"5 0.8457 0.8317 0.8387\n",
"6 0.5804 0.5214 0.5493\n",
"7 0.6691 0.7032 0.6858\n",
"8 0.6826 0.7130 0.6974\n",
"9 0.7459 0.6761 0.7093"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"# eval\n",
"eval_results = eval_classification(labels_test, preds)\n",
@@ -319,7 +214,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
"version": "3.6.8"
}
},
"nbformat": 4,

View file

@@ -1,12 +1,70 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
from pytorch_pretrained_bert.tokenization import BertTokenizer
from enum import Enum
# Max supported sequence length
BERT_MAX_LEN = 512
class Language(Enum):
"""An enumeration of the supported languages."""
ENGLISH = "bert-base-uncased"
ENGLISHCASED = "bert-base-cased"
ENGLISHLARGE = "bert-large-uncased"
ENGLISHLARGECASED = "bert-large-cased"
CHINESE = "bert-base-chinese"
MULTILINGUAL = "bert-base-multilingual-cased"
class Tokenizer:
def __init__(
self, language=Language.ENGLISH, to_lower=False, cache_dir="."
):
"""Initializes the tokenizer and the underlying pretrained tokenizer.
Args:
language (Language, optional): The pretrained model's language.
Defaults to Language.ENGLISH.
cache_dir (str, optional): Location of BERT's cache directory. Defaults to ".".
"""
self.tokenizer = BertTokenizer.from_pretrained(
language.value, do_lower_case=to_lower, cache_dir=cache_dir
)
self.language = language
def tokenize(self, text):
"""Tokenizes each document in a list of text documents using the underlying BertTokenizer."""
tokens = [self.tokenizer.tokenize(x) for x in text]
return tokens
def preprocess_classification_tokens(self, tokens, max_len):
"""Preprocessing of input tokens:
- add BERT sentence markers ([CLS] and [SEP])
- map tokens to indices
- pad and truncate sequences
- create an input_mask
Args:
tokens (list): List of token lists to preprocess.
max_len (int): Maximum length of the sequence.
Returns:
list of preprocessed token id lists
list of input mask lists
"""
if max_len > BERT_MAX_LEN:
print(
"setting max_len to max allowed tokens: {}".format(
BERT_MAX_LEN
)
)
max_len = BERT_MAX_LEN
# truncate and add BERT sentence markers
tokens = [["[CLS]"] + x[0 : max_len - 2] + ["[SEP]"] for x in tokens]
# convert tokens to indices
tokens = [self.tokenizer.convert_tokens_to_ids(x) for x in tokens]
# pad sequence
tokens = [x + [0] * (max_len - len(x)) for x in tokens]
# create input mask
input_mask = [[min(1, x) for x in y] for y in tokens]
return tokens, input_mask
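A tiny illustration of how the Tokenizer above is meant to be used; the token ids in the comments are placeholders rather than real vocabulary ids.

tokenizer = Tokenizer(Language.ENGLISH, to_lower=True, cache_dir=".")
tokens = tokenizer.tokenize(["hello bert"])  # [["hello", "bert"]]
ids, mask = tokenizer.preprocess_classification_tokens(tokens, max_len=6)
# ids[0]  -> [id("[CLS]"), id("hello"), id("bert"), id("[SEP]"), 0, 0]  (padded to max_len)
# mask[0] -> [1, 1, 1, 1, 0, 0]  (1 for real tokens, 0 for padding)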

View file

@@ -3,22 +3,20 @@
import random
import numpy as np
from tqdm import tqdm
import torch.nn as nn
import torch
from pytorch_pretrained_bert.modeling import BertForSequenceClassification
from pytorch_pretrained_bert.optimization import BertAdam
from pytorch_pretrained_bert.tokenization import BertTokenizer
from utils_nlp.pytorch.device import get_device
from utils_nlp.bert.common import Language
from utils_nlp.bert.common import Language, BERT_MAX_LEN
class BERTSequenceClassifier:
class SequenceClassifier:
"""BERT-based sequence classifier"""
BERT_MAX_LEN = 512
def __init__(self, language=Language.ENGLISH, num_labels=2, cache_dir="."):
"""Initializes the classifier and the underlying pretrained model and tokenizer.
"""Initializes the classifier and the underlying pretrained model.
Args:
language (Language, optional): The pretrained model's language.
Defaults to Language.ENGLISH.
@@ -30,39 +28,26 @@ class BERTSequenceClassifier:
raise Exception("Number of labels should be at least 2.")
self.language = language
self.num_labels = num_labels
self._max_len = BERTSequenceClassifier.BERT_MAX_LEN
self.num_labels = num_labels
self.cache_dir = cache_dir
# create tokenizer
self.tokenizer = BertTokenizer.from_pretrained(
language.value, do_lower_case=False, cache_dir=cache_dir
)
# create classifier
self.model = BertForSequenceClassification.from_pretrained(
self._model = BertForSequenceClassification.from_pretrained(
language.value, cache_dir=cache_dir, num_labels=num_labels
)
self._is_trained = False
def _get_tokens(self, text):
tokens = [
self.tokenizer.tokenize(x)[0 : self._max_len - 2] for x in text
]
tokens = [["[CLS]"] + x + ["[SEP]"] for x in tokens]
tokens = [self.tokenizer.convert_tokens_to_ids(x) for x in tokens]
tokens = [x + [0] * (self._max_len - len(x)) for x in tokens]
input_mask = [[min(1, x) for x in y] for y in tokens]
return tokens, input_mask
def get_model(self):
"""Returns the underlying PyTorch model
BertForSequenceClassification or DataParallel (when multiple GPUs are used)."""
return self.model
return self._model
def fit(
self,
text,
labels,
max_len=512,
tokens,
input_mask,
labels,
device="gpu",
use_multiple_gpus=True,
num_epochs=1,
@@ -71,7 +56,8 @@
):
"""Fine-tunes the BERT classifier using the given training data.
Args:
text (list): List of training text documents.
tokens (list): List of training token lists.
input_mask (list): List of input mask lists.
labels (list): List of training labels.
max_len (int, optional): Maximum number of tokens
(documents will be truncated or padded).
@@ -86,26 +72,13 @@
Defaults to True.
"""
device = get_device(device)
self.model.to(device)
if max_len > BERTSequenceClassifier.BERT_MAX_LEN:
print(
"setting max_len to max allowed tokens: {}".format(
BERTSequenceClassifier.BERT_MAX_LEN
)
)
max_len = BERTSequenceClassifier.BERT_MAX_LEN
self._max_len = max_len
# tokenize, truncate, and pad
tokens, input_mask = self._get_tokens(text=text)
self._model.to(device)
# define loss function
loss_func = nn.CrossEntropyLoss().to(device)
# define optimizer and model parameters
param_optimizer = list(self.model.named_parameters())
param_optimizer = list(self._model.named_parameters())
no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
@@ -127,77 +100,79 @@
]
opt = BertAdam(optimizer_grouped_parameters, lr=2e-5)
n_gpus = 0
if use_multiple_gpus:
if torch.cuda.device_count() > 1:
self.model = nn.DataParallel(self.model)
n_gpus = torch.cuda.device_count()
if n_gpus > 1:
if not isinstance(self._model, nn.DataParallel):
self._model = nn.DataParallel(self._model)
# train
self.model.train() # training mode
self._model.train() # training mode
num_examples = len(tokens)
num_batches = int(num_examples / batch_size)
for epoch in range(num_epochs):
for i in range(num_batches):
# get random batch
start = int(random.random() * num_examples)
end = start + batch_size
x_batch = torch.tensor(
tokens[start:end], dtype=torch.long, device=device
)
y_batch = torch.tensor(
labels[start:end], dtype=torch.long, device=device
)
mask_batch = torch.tensor(
input_mask[start:end], dtype=torch.long, device=device
)
opt.zero_grad()
y_h = self.model(
input_ids=x_batch,
token_type_ids=None,
attention_mask=mask_batch,
labels=None,
)
with tqdm(total=num_epochs * num_batches) as pbar:
for epoch in range(num_epochs):
for i in range(num_batches):
# get random batch
start = int(random.random() * num_examples)
end = start + batch_size
x_batch = torch.tensor(
tokens[start:end], dtype=torch.long, device=device
)
y_batch = torch.tensor(
labels[start:end], dtype=torch.long, device=device
)
mask_batch = torch.tensor(
input_mask[start:end], dtype=torch.long, device=device
)
loss = loss_func(y_h, y_batch)
loss.backward()
opt.step()
opt.zero_grad()
y_h = self._model(
input_ids=x_batch,
token_type_ids=None,
attention_mask=mask_batch,
labels=None,
)
if verbose:
if i % (2 * batch_size) == 0:
print(
"epoch:{}/{}; batch:{}/{}; loss:{}".format(
epoch + 1,
num_epochs,
i + 1,
num_batches,
loss.data,
loss = loss_func(y_h, y_batch)
if n_gpus > 1:
loss = loss.mean()
loss.backward()
opt.step()
pbar.update()
if verbose:
if i % (batch_size) == 0:
print(
"epoch:{}/{}; batch:{}/{}; loss:{}".format(
epoch + 1,
num_epochs,
i + 1,
num_batches,
loss.data,
)
)
)
self._is_trained = True
def predict(self, text, device="gpu", batch_size=32):
def predict(self, tokens, input_mask, device="gpu", batch_size=32):
"""Scores the given dataset and returns the predicted classes.
Args:
text (list): List of text documents to score.
tokens (list): List of token lists to score.
input_mask (list): List of input mask lists.
device (str, optional): Device used for scoring ("cpu" or "gpu"). Defaults to "gpu".
batch_size (int, optional): Scoring batch size. Defaults to 32.
Returns:
[ndarray]: Predicted classes.
"""
if not self._is_trained:
raise Exception("Please train model before scoring")
# tokenize, truncate, and pad
tokens, input_mask = self._get_tokens(text=text)
device = get_device(device)
self.model.to(device)
self._model.to(device)
# score
self.model.eval()
self._model.eval()
preds = []
for i in range(0, len(tokens), batch_size):
for i in tqdm(range(0, len(tokens), batch_size)):
x_batch = tokens[i : i + batch_size]
mask_batch = input_mask[i : i + batch_size]
x_batch = torch.tensor(x_batch, dtype=torch.long, device=device)
@@ -205,7 +180,7 @@ class BERTSequenceClassifier:
mask_batch, dtype=torch.long, device=device
)
with torch.no_grad():
p_batch = self.model(
p_batch = self._model(
input_ids=x_batch,
token_type_ids=None,
attention_mask=mask_batch,

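For reference, the fit loop above draws contiguous mini-batches that start at a random offset on each step instead of iterating over a shuffled epoch. A minimal standalone sketch of that sampling scheme (names are illustrative; the start index is clamped here so every slice is full-sized, which the loop above does not guarantee):

import random

def iter_random_batches(tokens, input_mask, labels, batch_size, num_batches):
    # yield contiguous mini-batches starting at a random offset
    n = len(tokens)
    for _ in range(num_batches):
        start = random.randint(0, max(0, n - batch_size))
        end = start + batch_size
        yield tokens[start:end], input_mask[start:end], labels[start:end]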
utils_nlp/dataset/imdb.py (new file, 29 lines)
View file

@@ -0,0 +1,29 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""IMDB dataset utils"""
import os
import pandas as pd
from utils_nlp.dataset.url_utils import maybe_download, extract_tar
URL = "https://s3.amazonaws.com/fast-ai-nlp/imdb.tgz"
def download(dir_path):
"""Downloads and extracts the dataset files"""
file_name = URL.split("/")[-1]
maybe_download(URL, file_name, dir_path)
extract_tar(os.path.join(dir_path, file_name), dir_path)
def get_df(dir_path, label):
"""Reads all text files in a directory into a pandas dataframe
and assigns the provided label to every row"""
text = []
for doc_file in os.listdir(dir_path):
with open(os.path.join(dir_path, doc_file)) as f:
text.append(f.read())
labels = [label] * len(text)
return pd.DataFrame({"text": text, "label": labels})
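A hypothetical usage of the IMDB helpers above; the train/pos and train/neg folder layout matches the fast.ai imdb.tgz archive, but the exact paths are an assumption.

import os
import pandas as pd
from utils_nlp.dataset import imdb

DATA_FOLDER = "./data"  # illustrative location
imdb.download(DATA_FOLDER)
df_train = pd.concat(
    [
        imdb.get_df(os.path.join(DATA_FOLDER, "imdb", "train", "pos"), label=1),
        imdb.get_df(os.path.join(DATA_FOLDER, "imdb", "train", "neg"), label=0),
    ]
).sample(frac=1, random_state=42)  # shuffle the concatenated rows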

View file

@@ -3,6 +3,7 @@
import os
from urllib.request import urlretrieve
import tarfile
import logging
from contextlib import contextmanager
from tempfile import TemporaryDirectory
@@ -59,6 +60,21 @@ def maybe_download(
return filepath
def extract_tar(file_path, extract_to="."):
"""Extracts all contents of a tar archive file.
Args:
file_path (str): Path of file to extract.
extract_to (str, optional): Destination directory. Defaults to ".".
"""
if not os.path.exists(file_path):
raise IOError("File doesn't exist")
if not os.path.exists(extract_to):
raise IOError("Destination directory doesn't exist")
tar = tarfile.open(file_path)
tar.extractall(path=extract_to)
tar.close()
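A hypothetical usage of the two helpers above; the URL and paths are illustrative.

data_dir = "./data"
archive_path = maybe_download("https://example.com/dataset.tgz", "dataset.tgz", data_dir)
extract_tar(archive_path, extract_to=data_dir)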
@contextmanager
def download_path(path):
tmp_dir = TemporaryDirectory()

View file

@@ -1,28 +1,28 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""Yahoo! Answers dataset utils"""
import os
import pandas as pd
import numpy as np
import random
import pickle
import urllib3
import tarfile
import io
from utils_nlp.dataset.url_utils import maybe_download, extract_tar
URL = "https://s3.amazonaws.com/fast-ai-nlp/yahoo_answers_csv.tgz"
def download(dir_path):
con = urllib3.PoolManager()
resp = con.request("GET", URL)
tar = tarfile.open(fileobj=io.BytesIO(resp.data))
tar.extractall(path=dir_path)
tar.close()
"""Downloads and extracts the dataset files"""
file_name = URL.split("/")[-1]
maybe_download(URL, file_name, dir_path)
extract_tar(os.path.join(dir_path, file_name), dir_path)
def read_data(data_file, nrows=None):
return pd.read_csv(data_file, header=None, nrows=nrows)
def clean_data(df):
def get_text(df):
df.fillna("", inplace=True)
text = df.iloc[:, 1] + " " + df.iloc[:, 2] + " " + df.iloc[:, 3]
text = text.str.replace(r"[^A-Za-z ]", "").str.lower()
@@ -33,19 +33,3 @@ def clean_data(df):
def get_labels(df):
return list(df[0] - 1)
def get_batch_rnd(X, input_mask, y, n, batch_size):
i = int(random.random() * n)
X_b = X[i : i + batch_size]
y_b = y[i : i + batch_size]
m_b = input_mask[i : i + batch_size]
return X_b, m_b, y_b
def get_batch_by_idx(X, input_mask, y, i, batch_size):
X_b = X[i : i + batch_size]
y_b = y[i : i + batch_size]
m_b = input_mask[i : i + batch_size]
return X_b, m_b, y_b
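A toy illustration of the Yahoo! Answers helpers above; the row values are made up, the real CSV has no header, and its columns are class, question title, question content, and best answer. The expected outputs in the comments are approximate.

import pandas as pd

df = pd.DataFrame([[5, "Why is the sky blue?", "Just curious.", "Rayleigh scattering!"]])
text = get_text(df)      # roughly ["why is the sky blue just curious rayleigh scattering"]
labels = get_labels(df)  # [4], since classes are shifted from 1-10 to 0-9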