minor updates to seq classification
This commit is contained in:
Родитель
61b66a57aa
Коммит
ee9134d96f
|
@ -13,7 +13,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -34,20 +34,20 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"DATA_FOLDER = \"../../../.../temp\"\n",
|
||||
"DATA_FOLDER = \"./temp\"\n",
|
||||
"TRAIN_FILE = \"yahoo_answers_csv/train.csv\"\n",
|
||||
"TEST_FILE = \"yahoo_answers_csv/test.csv\"\n",
|
||||
"BERT_CACHE_DIR = \"../../../temp\"\n",
|
||||
"MAX_LEN = 300\n",
|
||||
"BERT_CACHE_DIR = \"./temp\"\n",
|
||||
"MAX_LEN = 250\n",
|
||||
"BATCH_SIZE = 16\n",
|
||||
"USE_GPU = True\n",
|
||||
"NUM_GPUS = 2\n",
|
||||
"NUM_EPOCHS = 1\n",
|
||||
"NUM_ROWS_TRAIN = 10000 # number of training examples to read\n",
|
||||
"NUM_ROWS_TEST = 10000 # number of test examples to read"
|
||||
"NUM_ROWS_TRAIN = 50000 # number of training examples to read\n",
|
||||
"NUM_ROWS_TEST = 20000 # number of test examples to read"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -59,10 +59,12 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"if not os.path.exists(DATA_FOLDER):\n",
|
||||
" os.mkdir(DATA_FOLDER)\n",
|
||||
"ya_dataset.download(DATA_FOLDER)"
|
||||
]
|
||||
},
|
||||
|
@ -75,7 +77,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -105,19 +107,44 @@
|
|||
"## Tokenize and Preprocess"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Before training, we tokenize the text documents and convert them to lists of tokens. The following cell instantiates a BERT tokenizer for the given language and uses it to tokenize the text of the training and test sets."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tokenizer = Tokenizer(Language.ENGLISH, to_lower=False, cache_dir=BERT_CACHE_DIR)\n",
|
||||
"tokenizer = Tokenizer(Language.ENGLISH, to_lower=True, cache_dir=BERT_CACHE_DIR)\n",
|
||||
"\n",
|
||||
"# tokenize\n",
|
||||
"tokens_train = tokenizer.tokenize(text_train)\n",
|
||||
"tokens_test = tokenizer.tokenize(text_test)\n",
|
||||
"tokens_test = tokenizer.tokenize(text_test)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"In addition, we perform the following preprocessing steps in the cell below:\n",
|
||||
"- Convert the tokens into token indices corresponding to the BERT tokenizer's vocabulary\n",
|
||||
"- Add sentence markers\n",
|
||||
"- Pad or truncate the token lists to the specified max length\n",
|
||||
"\n",
|
||||
"# get BERT-format tokens (padded and truncated)\n",
|
||||
"*See the original [implementation](https://github.com/google-research/bert/blob/master/run_classifier.py) for more information on BERT's input format.*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tokens_train, mask_train = tokenizer.preprocess_classification_tokens(\n",
|
||||
" tokens_train, MAX_LEN\n",
|
||||
")\n",
|
||||
|
@ -135,7 +162,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -153,9 +180,36 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"execution_count": 8,
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"t_total value of -1 results in schedule not being applied\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"epoch:1/1; batch:1->313/3125; loss:2.469508\n",
|
||||
"epoch:1/1; batch:314->626/3125; loss:1.179081\n",
|
||||
"epoch:1/1; batch:627->939/3125; loss:0.677443\n",
|
||||
"epoch:1/1; batch:940->1252/3125; loss:1.689727\n",
|
||||
"epoch:1/1; batch:1253->1565/3125; loss:0.781167\n",
|
||||
"epoch:1/1; batch:1566->1878/3125; loss:1.036024\n",
|
||||
"epoch:1/1; batch:1879->2191/3125; loss:0.909294\n",
|
||||
"epoch:1/1; batch:2192->2504/3125; loss:0.441344\n",
|
||||
"epoch:1/1; batch:2505->2817/3125; loss:0.823389\n",
|
||||
"epoch:1/1; batch:2818->3130/3125; loss:1.036229\n",
|
||||
"[Training time: 1.132 hrs]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# train\n",
|
||||
"with Timer() as t:\n",
|
||||
|
@ -163,12 +217,11 @@
|
|||
" token_ids=tokens_train,\n",
|
||||
" input_mask=mask_train,\n",
|
||||
" labels=labels_train, \n",
|
||||
" use_gpu=USE_GPU, \n",
|
||||
" num_gpus=NUM_GPUS, \n",
|
||||
" num_epochs=NUM_EPOCHS,\n",
|
||||
" batch_size=BATCH_SIZE, \n",
|
||||
" verbose=True,\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" ) \n",
|
||||
"print(\"[Training time: {:.3f} hrs]\".format(t.interval / 3600))"
|
||||
]
|
||||
},
|
||||
|
@ -181,12 +234,20 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 20000/20000 [08:00<00:00, 41.85it/s]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"preds = classifier.predict(\n",
|
||||
" token_ids=tokens_test, input_mask=mask_test, use_gpu=False, batch_size=BATCH_SIZE\n",
|
||||
" token_ids=tokens_test, input_mask=mask_test, num_gpus=NUM_GPUS, batch_size=BATCH_SIZE\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
|
@ -194,16 +255,134 @@
|
|||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Evaluate Results"
|
||||
"## Evaluate Results\n",
|
||||
"Finally, we compute the overall accuracy and the per-class precision, recall, and F1 scores of the predictions on the test set."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
" accuracy: 0.6564\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>precision</th>\n",
|
||||
" <th>recall</th>\n",
|
||||
" <th>f1</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>0.592506</td>\n",
|
||||
" <td>0.497053</td>\n",
|
||||
" <td>0.540598</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>0.749070</td>\n",
|
||||
" <td>0.673518</td>\n",
|
||||
" <td>0.709288</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>0.789308</td>\n",
|
||||
" <td>0.680955</td>\n",
|
||||
" <td>0.731139</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>0.561592</td>\n",
|
||||
" <td>0.440535</td>\n",
|
||||
" <td>0.493752</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>0.854772</td>\n",
|
||||
" <td>0.789272</td>\n",
|
||||
" <td>0.820717</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>5</th>\n",
|
||||
" <td>0.885998</td>\n",
|
||||
" <td>0.847659</td>\n",
|
||||
" <td>0.866404</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>6</th>\n",
|
||||
" <td>0.425440</td>\n",
|
||||
" <td>0.687416</td>\n",
|
||||
" <td>0.525592</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>7</th>\n",
|
||||
" <td>0.756364</td>\n",
|
||||
" <td>0.700337</td>\n",
|
||||
" <td>0.727273</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>8</th>\n",
|
||||
" <td>0.826006</td>\n",
|
||||
" <td>0.485432</td>\n",
|
||||
" <td>0.611496</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>9</th>\n",
|
||||
" <td>0.756186</td>\n",
|
||||
" <td>0.731039</td>\n",
|
||||
" <td>0.743400</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" precision recall f1\n",
|
||||
"0 0.592506 0.497053 0.540598\n",
|
||||
"1 0.749070 0.673518 0.709288\n",
|
||||
"2 0.789308 0.680955 0.731139\n",
|
||||
"3 0.561592 0.440535 0.493752\n",
|
||||
"4 0.854772 0.789272 0.820717\n",
|
||||
"5 0.885998 0.847659 0.866404\n",
|
||||
"6 0.425440 0.687416 0.525592\n",
|
||||
"7 0.756364 0.700337 0.727273\n",
|
||||
"8 0.826006 0.485432 0.611496\n",
|
||||
"9 0.756186 0.731039 0.743400"
|
||||
]
|
||||
},
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# eval metrics\n",
|
||||
"accuracy = accuracy_score(labels_test, preds)\n",
|
||||
"precision = precision_score(labels_test, preds, average=None)\n",
|
||||
"recall = recall_score(labels_test, preds, average=None)\n",
|
||||
|
@ -212,6 +391,13 @@
|
|||
"print(\"\\n accuracy: {}\".format(accuracy))\n",
|
||||
"pd.DataFrame({\"precision\": precision, \"recall\": recall, \"f1\": f1})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
|
|
@ -115,6 +115,7 @@ class SequenceClassifier:
|
|||
)
|
||||
|
||||
opt.zero_grad()
|
||||
|
||||
y_h = self.model(
|
||||
input_ids=x_batch,
|
||||
token_type_ids=None,
|
||||
|
@ -128,7 +129,7 @@ class SequenceClassifier:
|
|||
if verbose:
|
||||
if i % ((num_batches // 10) + 1) == 0:
|
||||
print(
|
||||
"epoch:{}/{}; batch:{}->{}/{}; loss:{}".format(
|
||||
"epoch:{}/{}; batch:{}->{}/{}; loss:{:.6f}".format(
|
||||
epoch + 1,
|
||||
num_epochs,
|
||||
i + 1,
|
||||
|
@ -137,6 +138,9 @@ class SequenceClassifier:
|
|||
loss.data,
|
||||
)
|
||||
)
|
||||
# empty cache
|
||||
del [x_batch, y_batch, mask_batch]
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, token_ids, input_mask, num_gpus=1, batch_size=32):
|
||||
"""Scores the given dataset and returns the predicted classes.
|
||||
|
|
|
@ -15,7 +15,7 @@ def download(dir_path):
|
|||
"""Downloads and extracts the dataset files"""
|
||||
file_name = URL.split("/")[-1]
|
||||
maybe_download(URL, file_name, dir_path)
|
||||
extract_tar(os.path.join(dir_path, file_name))
|
||||
extract_tar(os.path.join(dir_path, file_name), dir_path)
|
||||
|
||||
|
||||
def read_data(data_file, nrows=None):
|
||||
|
|
|
@ -11,7 +11,7 @@ def get_device(device="gpu"):
|
|||
Args:
|
||||
device (str, optional): Device string: "cpu" or "gpu". Defaults to "gpu".
|
||||
Returns:
|
||||
torch.device: A PyTorch device: cpu or gpu.
|
||||
torch.device: A PyTorch device (cpu or gpu).
|
||||
"""
|
||||
if device == "gpu":
|
||||
if torch.cuda.is_available():
|
||||
|
|
Загрузка…
Ссылка в новой задаче