Merge pull request #99 from microsoft/bleik

ar TC example
This commit is contained in:
Said Bleik 2019-06-14 16:21:03 -04:00 коммит произвёл GitHub
Родитель f6c06c395c 0929c37d56
Коммит a514025f5d
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
8 изменённых файлов: 625 добавлений и 566 удалений

Просмотреть файл

@ -0,0 +1,519 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Copyright (c) Microsoft Corporation. All rights reserved.*\n",
"\n",
"*Licensed under the MIT License.*\n",
"\n",
"# Classification of Arabic News Articles using BERT"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"sys.path.append(\"../../\")\n",
"import os\n",
"import pandas as pd\n",
"from sklearn.metrics import classification_report\n",
"from sklearn.model_selection import train_test_split\n",
"from utils_nlp.dataset.dac import load_pandas_df\n",
"from utils_nlp.eval.classification import eval_classification\n",
"from utils_nlp.bert.sequence_classification import BERTSequenceClassifier\n",
"from utils_nlp.bert.common import Language, Tokenizer\n",
"from utils_nlp.common.timer import Timer\n",
"import torch\n",
"import torch.nn as nn\n",
"import numpy as np"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Introduction\n",
"In this notebook, we fine-tune and evaluate a pretrained [BERT](https://arxiv.org/abs/1810.04805) model on an Arabic dataset of news articles. The [dataset](https://data.mendeley.com/datasets/v524p5dhpj/2) includes articles from 3 different newspapers, and the articles are categorized into 5 classes: *sports, politics, culture, economy and diverse*. The data is described in more detail in this [paper](http://article.nadiapub.com/IJGDC/vol11_no9/9.pdf).\n",
"\n",
"We use a [sequence classifier](../../utils_nlp/bert/sequence_classification.py) that wraps [Hugging Face's PyTorch implementation](https://github.com/huggingface/pytorch-pretrained-BERT) of Google's [BERT](https://github.com/google-research/bert). The classifier loads a pretrained [multilingual BERT model](https://github.com/google-research/bert/blob/master/multilingual.md) that was trained on 104 languages, including Arabic."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"DATA_FOLDER = \"../../../temp\"\n",
"BERT_CACHE_DIR = \"../../../temp\"\n",
"LANGUAGE = Language.MULTILINGUAL\n",
"MAX_LEN = 200\n",
"BATCH_SIZE = 32\n",
"NUM_GPUS = 2\n",
"NUM_EPOCHS = 1\n",
"TRAIN_SIZE = 0.75"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Read Dataset\n",
"We start by loading the data. The following lines also download the file if it doesn't exist, and extract the csv file into the specified data folder."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"df = load_pandas_df(DATA_FOLDER)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>text</th>\n",
" <th>targe</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>بين أستوديوهات ورزازات وصحراء مرزوكة وآثار ولي...</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>قررت النجمة الأمريكية أوبرا وينفري ألا يقتصر ع...</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>أخبارنا المغربية الوزاني تصوير الشملالي ألهب ا...</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>اخبارنا المغربية قال ابراهيم الراشدي محامي سعد...</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>تزال صناعة الجلود في المغرب تتبع الطريقة التقل...</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" text targe\n",
"0 بين أستوديوهات ورزازات وصحراء مرزوكة وآثار ولي... 0\n",
"1 قررت النجمة الأمريكية أوبرا وينفري ألا يقتصر ع... 0\n",
"2 أخبارنا المغربية الوزاني تصوير الشملالي ألهب ا... 0\n",
"3 اخبارنا المغربية قال ابراهيم الراشدي محامي سعد... 0\n",
"4 تزال صناعة الجلود في المغرب تتبع الطريقة التقل... 0"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# set the text and label columns\n",
"text_col = df.columns[0]\n",
"label_col = df.columns[1]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Inspect the distribution of labels:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"4 46522\n",
"3 20505\n",
"1 16728\n",
"2 14235\n",
"0 13738\n",
"Name: targe, dtype: int64"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df[label_col].value_counts()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We compare the counts with those presented in the author's [paper](http://article.nadiapub.com/IJGDC/vol11_no9/9.pdf), and infer the following label mapping:\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>label</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>culture</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>diverse</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>economy</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>politics</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>sports</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" label\n",
"0 culture\n",
"1 diverse\n",
"2 economy\n",
"3 politics\n",
"4 sports"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# ordered list of labels\n",
"labels = [\"culture\", \"diverse\", \"economy\", \"politics\", \"sports\"]\n",
"num_labels = len(labels)\n",
"pd.DataFrame({\"label\": labels})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we split the data for training and testing:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of training examples: 83796\n",
"Number of testing examples: 27932\n"
]
}
],
"source": [
"df_train, df_test = train_test_split(df, train_size = TRAIN_SIZE, random_state=0)\n",
"print(\"Number of training examples: {}\".format(df_train.shape[0]))\n",
"print(\"Number of testing examples: {}\".format(df_test.shape[0]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tokenize and Preprocess"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before training, we tokenize the text documents and convert them to lists of tokens. The following steps instantiate a BERT tokenizer given the language, and tokenize the text of the training and testing sets."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"tokenizer = Tokenizer(LANGUAGE, cache_dir=BERT_CACHE_DIR)\n",
"tokens_train = tokenizer.tokenize(df_train[text_col].astype(str))\n",
"tokens_test = tokenizer.tokenize(df_test[text_col].astype(str))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In addition, we perform the following preprocessing steps in the cell below:\n",
"- Convert the tokens into token indices corresponding to the BERT tokenizer's vocabulary\n",
"- Add the special tokens [CLS] and [SEP] to mark the beginning and end of a sentence\n",
"- Pad or truncate the token lists to the specified max length\n",
"- Return mask lists that indicate paddings' positions\n",
"\n",
"*See the original [implementation](https://github.com/google-research/bert/blob/master/run_classifier.py) for more information on BERT's input format.*"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"tokens_train, mask_train = tokenizer.preprocess_classification_tokens(\n",
" tokens_train, MAX_LEN\n",
")\n",
"tokens_test, mask_test = tokenizer.preprocess_classification_tokens(\n",
" tokens_test, MAX_LEN\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create Model\n",
"Next, we create a sequence classifier that loads a pre-trained BERT model, given the language and number of labels."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"classifier = BERTSequenceClassifier(\n",
" language=LANGUAGE, num_labels=num_labels, cache_dir=BERT_CACHE_DIR\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train\n",
"We train the classifier using the training examples. This involves fine-tuning the BERT Transformer and learning a linear classification layer on top of that:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"t_total value of -1 results in schedule not being applied\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch:1/1; batch:1->262/2618; loss:1.632097\n",
"epoch:1/1; batch:263->524/2618; loss:0.402912\n",
"epoch:1/1; batch:525->786/2618; loss:0.324510\n",
"epoch:1/1; batch:787->1048/2618; loss:0.477946\n",
"epoch:1/1; batch:1049->1310/2618; loss:0.333729\n",
"epoch:1/1; batch:1311->1572/2618; loss:0.021917\n",
"epoch:1/1; batch:1573->1834/2618; loss:0.031262\n",
"epoch:1/1; batch:1835->2096/2618; loss:0.264172\n",
"epoch:1/1; batch:2097->2358/2618; loss:0.034074\n",
"epoch:1/1; batch:2359->2618/2618; loss:0.033827\n",
"[Training time: 1.400 hrs]\n"
]
}
],
"source": [
"with Timer() as t:\n",
" classifier.fit(\n",
" token_ids=tokens_train,\n",
" input_mask=mask_train,\n",
" labels=list(df_train[label_col]), \n",
" num_gpus=NUM_GPUS, \n",
" num_epochs=NUM_EPOCHS,\n",
" batch_size=BATCH_SIZE, \n",
" verbose=True,\n",
" ) \n",
"print(\"[Training time: {:.3f} hrs]\".format(t.interval / 3600))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Score\n",
"We score the test set using the trained classifier:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"27936it [08:35, 56.23it/s] \n"
]
}
],
"source": [
"preds = classifier.predict(\n",
" token_ids=tokens_test, input_mask=mask_test, num_gpus=NUM_GPUS, batch_size=BATCH_SIZE\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Evaluate Results\n",
"Finally, we compute the accuracy, precision, recall, and F1 metrics of the evaluation on the test set."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" culture 0.96 0.93 0.94 3479\n",
" diverse 0.92 0.98 0.95 4091\n",
" economy 0.91 0.88 0.89 3517\n",
" politics 0.90 0.90 0.90 5222\n",
" sports 0.99 0.99 0.99 11623\n",
"\n",
" accuracy 0.95 27932\n",
" macro avg 0.94 0.94 0.94 27932\n",
"weighted avg 0.95 0.95 0.95 27932\n",
"\n"
]
}
],
"source": [
"print(classification_report(df_test[label_col], preds, target_names=labels))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

Просмотреть файл

@ -21,7 +21,7 @@
"sys.path.append(\"../../\")\n",
"import os\n",
"import pandas as pd\n",
"from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score\n",
"from sklearn.metrics import classification_report\n",
"from sklearn.preprocessing import LabelEncoder\n",
"from sklearn.model_selection import train_test_split\n",
"from utils_nlp.dataset.multinli import load_pandas_df\n",
@ -46,12 +46,12 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"DATA_FOLDER = \"./temp\"\n",
"BERT_CACHE_DIR = \"./temp\"\n",
"DATA_FOLDER = \"../../../temp\"\n",
"BERT_CACHE_DIR = \"../../../temp\"\n",
"LANGUAGE = Language.ENGLISH\n",
"TO_LOWER = True\n",
"MAX_LEN = 150\n",
@ -77,7 +77,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
@ -87,7 +87,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 4,
"metadata": {},
"outputs": [
{
@ -154,7 +154,7 @@
"13 travel Thebes held onto power until the 12th Dynasty,..."
]
},
"execution_count": 8,
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
@ -172,7 +172,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 5,
"metadata": {},
"outputs": [
{
@ -186,7 +186,7 @@
"Name: genre, dtype: int64"
]
},
"execution_count": 11,
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
@ -204,7 +204,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
@ -221,7 +221,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 7,
"metadata": {},
"outputs": [
{
@ -256,7 +256,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
@ -272,7 +272,7 @@
"source": [
"In addition, we perform the following preprocessing steps in the cell below:\n",
"- Convert the tokens into token indices corresponding to the BERT tokenizer's vocabulary\n",
"- Add sentence markers\n",
"- Add the special tokens [CLS] and [SEP] to mark the beginning and end of a sentence\n",
"- Pad or truncate the token lists to the specified max length\n",
"- Return mask lists that indicate paddings' positions\n",
"\n",
@ -281,7 +281,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
@ -303,7 +303,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
@ -322,7 +322,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 11,
"metadata": {
"scrolled": true
},
@ -338,17 +338,17 @@
"name": "stdout",
"output_type": "stream",
"text": [
"epoch:1/1; batch:1->246/2454; loss:1.631739\n",
"epoch:1/1; batch:247->492/2454; loss:0.427608\n",
"epoch:1/1; batch:493->738/2454; loss:0.255493\n",
"epoch:1/1; batch:739->984/2454; loss:0.286230\n",
"epoch:1/1; batch:985->1230/2454; loss:0.375268\n",
"epoch:1/1; batch:1231->1476/2454; loss:0.146290\n",
"epoch:1/1; batch:1477->1722/2454; loss:0.092100\n",
"epoch:1/1; batch:1723->1968/2454; loss:0.009405\n",
"epoch:1/1; batch:1969->2214/2454; loss:0.038235\n",
"epoch:1/1; batch:2215->2460/2454; loss:0.098216\n",
"[Training time: 0.981 hrs]\n"
"epoch:1/1; batch:1->246/2454; loss:1.824086\n",
"epoch:1/1; batch:247->492/2454; loss:0.446337\n",
"epoch:1/1; batch:493->738/2454; loss:0.298814\n",
"epoch:1/1; batch:739->984/2454; loss:0.265785\n",
"epoch:1/1; batch:985->1230/2454; loss:0.101790\n",
"epoch:1/1; batch:1231->1476/2454; loss:0.251120\n",
"epoch:1/1; batch:1477->1722/2454; loss:0.040894\n",
"epoch:1/1; batch:1723->1968/2454; loss:0.038339\n",
"epoch:1/1; batch:1969->2214/2454; loss:0.021586\n",
"epoch:1/1; batch:2215->2454/2454; loss:0.130719\n",
"[Training time: 0.980 hrs]\n"
]
}
],
@ -376,14 +376,14 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"52384it [11:54, 88.50it/s] \n"
"52384it [11:54, 88.97it/s] \n"
]
}
],
@ -403,112 +403,36 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" accuracy: 0.938273\n"
" fiction 0.90 0.94 0.92 10275\n",
" government 0.97 0.93 0.95 10292\n",
" slate 0.88 0.85 0.87 10277\n",
" telephone 0.99 1.00 0.99 11205\n",
" travel 0.95 0.97 0.96 10311\n",
"\n",
" accuracy 0.94 52360\n",
" macro avg 0.94 0.94 0.94 52360\n",
"weighted avg 0.94 0.94 0.94 52360\n",
"\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>label</th>\n",
" <th>precision</th>\n",
" <th>recall</th>\n",
" <th>f1</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>fiction</td>\n",
" <td>0.917004</td>\n",
" <td>0.925839</td>\n",
" <td>0.921401</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>government</td>\n",
" <td>0.961477</td>\n",
" <td>0.928780</td>\n",
" <td>0.944845</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>slate</td>\n",
" <td>0.875161</td>\n",
" <td>0.861535</td>\n",
" <td>0.868295</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>telephone</td>\n",
" <td>0.989105</td>\n",
" <td>0.996609</td>\n",
" <td>0.992843</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>travel</td>\n",
" <td>0.943405</td>\n",
" <td>0.973232</td>\n",
" <td>0.958087</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" label precision recall f1\n",
"0 fiction 0.917004 0.925839 0.921401\n",
"1 government 0.961477 0.928780 0.944845\n",
"2 slate 0.875161 0.861535 0.868295\n",
"3 telephone 0.989105 0.996609 0.992843\n",
"4 travel 0.943405 0.973232 0.958087"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"accuracy = accuracy_score(labels_test, preds)\n",
"precision = precision_score(labels_test, preds, average=None)\n",
"recall = recall_score(labels_test, preds, average=None)\n",
"f1 = f1_score(labels_test, preds, average=None)\n",
"\n",
"print(\"\\n accuracy: {:.6f}\".format(accuracy))\n",
"pd.DataFrame({\"label\": label_encoder.classes_, \"precision\": precision, \"recall\": recall, \"f1\": f1})"
"print(classification_report(labels_test, preds, target_names=label_encoder.classes_))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.5",
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
@ -522,7 +446,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.5"
"version": "3.6.8"
}
},
"nbformat": 4,

Просмотреть файл

@ -1,424 +0,0 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Copyright (c) Microsoft Corporation. All rights reserved.*\n",
"\n",
"*Licensed under the MIT License.*\n",
"\n",
"# Text Classification of Yahoo Answers using BERT\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"sys.path.append(\"../../\")\n",
"import os\n",
"import pandas as pd\n",
"from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score\n",
"import utils_nlp.dataset.yahoo_answers as ya_dataset\n",
"from utils_nlp.eval.classification import eval_classification\n",
"from utils_nlp.bert.sequence_classification import BERTSequenceClassifier\n",
"from utils_nlp.bert.common import Language, Tokenizer\n",
"from utils_nlp.common.timer import Timer\n",
"import torch\n",
"import torch.nn as nn\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"DATA_FOLDER = \"./temp\"\n",
"TRAIN_FILE = \"yahoo_answers_csv/train.csv\"\n",
"TEST_FILE = \"yahoo_answers_csv/test.csv\"\n",
"BERT_CACHE_DIR = \"./temp\"\n",
"MAX_LEN = 250\n",
"BATCH_SIZE = 16\n",
"NUM_GPUS = 2\n",
"NUM_EPOCHS = 1\n",
"NUM_ROWS_TRAIN = 50000 # number of training examples to read\n",
"NUM_ROWS_TEST = 20000 # number of test examples to read"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Download Dataset"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"if not os.path.exists(DATA_FOLDER):\n",
" os.mkdir(DATA_FOLDER)\n",
"ya_dataset.download(DATA_FOLDER)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Read Dataset"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# read data\n",
"df_train = ya_dataset.read_data(\n",
" os.path.join(DATA_FOLDER, TRAIN_FILE), nrows=NUM_ROWS_TRAIN\n",
")\n",
"df_test = ya_dataset.read_data(\n",
" os.path.join(DATA_FOLDER, TEST_FILE), nrows=NUM_ROWS_TEST\n",
")\n",
"\n",
"# get labels\n",
"labels_train = ya_dataset.get_labels(df_train)\n",
"labels_test = ya_dataset.get_labels(df_test)\n",
"\n",
"num_labels = len(np.unique(labels_train))\n",
"\n",
"# get text\n",
"text_train = ya_dataset.get_text(df_train)\n",
"text_test = ya_dataset.get_text(df_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tokenize and Preprocess"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before training, we tokenize the text documents and convert them to lists of tokens. The following steps instantiate a BERT tokenizer given the language, and tokenize the text of the training and test sets."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"tokenizer = Tokenizer(Language.ENGLISH, to_lower=True, cache_dir=BERT_CACHE_DIR)\n",
"\n",
"# tokenize\n",
"tokens_train = tokenizer.tokenize(text_train)\n",
"tokens_test = tokenizer.tokenize(text_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In addition, we perform the following preprocessing steps in the cell below:\n",
"- Convert the tokens into token indices corresponding to the BERT tokenizer's vocabulary\n",
"- Add sentence markers\n",
"- Pad or truncate the token lists to the specified max length\n",
"\n",
"*See the original [implementation](https://github.com/google-research/bert/blob/master/run_classifier.py) for more information on BERT's input format.*"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"tokens_train, mask_train = tokenizer.preprocess_classification_tokens(\n",
" tokens_train, MAX_LEN\n",
")\n",
"tokens_test, mask_test = tokenizer.preprocess_classification_tokens(\n",
" tokens_test, MAX_LEN\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create Model\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"classifier = BERTSequenceClassifier(\n",
" language=Language.ENGLISH, num_labels=num_labels, cache_dir=BERT_CACHE_DIR\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"t_total value of -1 results in schedule not being applied\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch:1/1; batch:1->313/3125; loss:2.469508\n",
"epoch:1/1; batch:314->626/3125; loss:1.179081\n",
"epoch:1/1; batch:627->939/3125; loss:0.677443\n",
"epoch:1/1; batch:940->1252/3125; loss:1.689727\n",
"epoch:1/1; batch:1253->1565/3125; loss:0.781167\n",
"epoch:1/1; batch:1566->1878/3125; loss:1.036024\n",
"epoch:1/1; batch:1879->2191/3125; loss:0.909294\n",
"epoch:1/1; batch:2192->2504/3125; loss:0.441344\n",
"epoch:1/1; batch:2505->2817/3125; loss:0.823389\n",
"epoch:1/1; batch:2818->3130/3125; loss:1.036229\n",
"[Training time: 1.132 hrs]\n"
]
}
],
"source": [
"# train\n",
"with Timer() as t:\n",
" classifier.fit(\n",
" token_ids=tokens_train,\n",
" input_mask=mask_train,\n",
" labels=labels_train, \n",
" num_gpus=NUM_GPUS, \n",
" num_epochs=NUM_EPOCHS,\n",
" batch_size=BATCH_SIZE, \n",
" verbose=True,\n",
" ) \n",
"print(\"[Training time: {:.3f} hrs]\".format(t.interval / 3600))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Score Test Set"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 20000/20000 [08:00<00:00, 41.85it/s]\n"
]
}
],
"source": [
"preds = classifier.predict(\n",
" token_ids=tokens_test, input_mask=mask_test, num_gpus=NUM_GPUS, batch_size=BATCH_SIZE\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Evaluate Results\n",
"Finally, we compute the accuracy, precision, recall, and F1 metrics of the evaluation on the test set."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
" accuracy: 0.6564\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>precision</th>\n",
" <th>recall</th>\n",
" <th>f1</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0.592506</td>\n",
" <td>0.497053</td>\n",
" <td>0.540598</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0.749070</td>\n",
" <td>0.673518</td>\n",
" <td>0.709288</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0.789308</td>\n",
" <td>0.680955</td>\n",
" <td>0.731139</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0.561592</td>\n",
" <td>0.440535</td>\n",
" <td>0.493752</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0.854772</td>\n",
" <td>0.789272</td>\n",
" <td>0.820717</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>0.885998</td>\n",
" <td>0.847659</td>\n",
" <td>0.866404</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>0.425440</td>\n",
" <td>0.687416</td>\n",
" <td>0.525592</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>0.756364</td>\n",
" <td>0.700337</td>\n",
" <td>0.727273</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>0.826006</td>\n",
" <td>0.485432</td>\n",
" <td>0.611496</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>0.756186</td>\n",
" <td>0.731039</td>\n",
" <td>0.743400</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" precision recall f1\n",
"0 0.592506 0.497053 0.540598\n",
"1 0.749070 0.673518 0.709288\n",
"2 0.789308 0.680955 0.731139\n",
"3 0.561592 0.440535 0.493752\n",
"4 0.854772 0.789272 0.820717\n",
"5 0.885998 0.847659 0.866404\n",
"6 0.425440 0.687416 0.525592\n",
"7 0.756364 0.700337 0.727273\n",
"8 0.826006 0.485432 0.611496\n",
"9 0.756186 0.731039 0.743400"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"accuracy = accuracy_score(labels_test, preds)\n",
"precision = precision_score(labels_test, preds, average=None)\n",
"recall = recall_score(labels_test, preds, average=None)\n",
"f1 = f1_score(labels_test, preds, average=None)\n",
"\n",
"print(\"\\n accuracy: {}\".format(accuracy))\n",
"pd.DataFrame({\"precision\": precision, \"recall\": recall, \"f1\": f1})"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.5",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

Просмотреть файл

@ -46,13 +46,11 @@ class Tokenizer:
self.language = language
def tokenize(self, text):
"""Uses a BERT tokenizer
"""Tokenizes a list of documents using a BERT tokenizer
Args:
text (list): [description]
text (list(str)): list of text documents.
Returns:
[list]: [description]
[list(str)]: list of token lists.
"""
tokens = [self.tokenizer.tokenize(x) for x in text]
return tokens

Просмотреть файл

@ -1,13 +1,18 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# This script reuses some code from
# https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/examples/run_classifier.py
import random
import numpy as np
import torch
import torch.nn as nn
from pytorch_pretrained_bert.modeling import BertForSequenceClassification
from pytorch_pretrained_bert.optimization import BertAdam
from tqdm import tqdm
from utils_nlp.bert.common import Language
from utils_nlp.pytorch.device_utils import get_device, move_to_device
@ -26,7 +31,7 @@ class BERTSequenceClassifier:
Defaults to ".".
"""
if num_labels < 2:
raise Exception("Number of labels should be at least 2.")
raise ValueError("Number of labels should be at least 2.")
self.language = language
self.num_labels = num_labels
@ -135,7 +140,7 @@ class BERTSequenceClassifier:
epoch + 1,
num_epochs,
i + 1,
i + 1 + (num_batches // 10),
min(i + 1 + num_batches // 10, num_batches),
num_batches,
loss.data,
)

38
utils_nlp/dataset/dac.py Normal file
Просмотреть файл

@ -0,0 +1,38 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""Dataset for Arabic Classification utils
https://data.mendeley.com/datasets/v524p5dhpj/2
Mohamed, BINIZ (2018), DataSet for Arabic Classification, Mendeley Data, v2
paper link: ("https://www.mendeley.com/catalogue/
arabic-text-classification-using-deep-learning-technics/")
"""
import os
import pandas as pd
from utils_nlp.dataset.url_utils import extract_zip, maybe_download
URL = (
"https://data.mendeley.com/datasets/v524p5dhpj/2"
"/files/91cb8398-9451-43af-88fc-041a0956ae2d/"
"arabic_dataset_classifiction.csv.zip"
)
def load_pandas_df(local_cache_path=None):
"""Downloads and extracts the dataset files
Args:
local_cache_path ([type], optional): [description]. Defaults to None.
Returns:
pd.DataFrame: pandas DataFrame containing the loaded dataset.
"""
zip_file = URL.split("/")[-1]
csv_file_path = os.path.join(
local_cache_path, zip_file.replace(".zip", "")
)
maybe_download(URL, zip_file, local_cache_path)
if not os.path.exists(csv_file_path):
extract_zip(file_path=csv_file_path, dest_path=local_cache_path)
return pd.read_csv(csv_file_path)

Просмотреть файл

@ -7,7 +7,6 @@ https://www.nyu.edu/projects/bowman/multinli/
import os
import pandas as pd
import requests
from utils_nlp.dataset.url_utils import extract_zip, maybe_download
URL = "http://www.nyu.edu/projects/bowman/multinli/multinli_1.0.zip"
@ -19,21 +18,20 @@ DATA_FILES = {
def load_pandas_df(local_cache_path=None, file_split="train"):
"""Downloads and extracts the dataset files
"""Downloads and extracts the dataset files
Args:
local_cache_path ([type], optional): [description]. Defaults to None.
file_split (str, optional): The subset to load.
One of: {"train", "dev_matched", "dev_mismatched"}
Defaults to "train".
One of: {"train", "dev_matched", "dev_mismatched"}
Defaults to "train".
Returns:
pd.DataFrame: pandas DataFrame containing the specified MultiNLI subset.
pd.DataFrame: pandas DataFrame containing the specified
MultiNLI subset.
"""
file_name = URL.split("/")[-1]
if not os.path.exists(os.path.join(local_cache_path, file_name)):
response = requests.get(URL)
with open(os.path.join(local_cache_path, file_name), "wb") as f:
f.write(response.content)
maybe_download(URL, file_name, local_cache_path)
if not os.path.exists(
os.path.join(local_cache_path, DATA_FILES[file_split])
):

Просмотреть файл

@ -1,14 +1,16 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import os
import requests
import logging
import math
import os
import tarfile
import zipfile
from contextlib import contextmanager
from tempfile import TemporaryDirectory
from tqdm import tqdm
import requests
from tqdm import tqdm
log = logging.getLogger(__name__)
@ -17,13 +19,12 @@ def maybe_download(
url, filename=None, work_directory=".", expected_bytes=None
):
"""Download a file if it is not already downloaded.
Args:
filename (str): File name.
work_directory (str): Working directory.
url (str): URL of the file to download.
expected_bytes (int): Expected file size in bytes.
Returns:
str: File path of the file downloaded.
"""
@ -80,7 +81,7 @@ def extract_zip(file_path, dest_path="."):
raise IOError("File doesn't exist")
if not os.path.exists(dest_path):
raise IOError("Destination directory doesn't exist")
with ZipFile(file_path) as z:
with zipfile.ZipFile(file_path) as z:
z.extractall(path=dest_path)