Said Bleik 2019-06-13 22:26:29 -04:00
Parent 274615b396
Commit 51c22b9607
4 changed files with 31 additions and 111 deletions

View file

@@ -13,7 +13,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 43,
"metadata": {},
"outputs": [],
"source": [
@@ -21,7 +21,7 @@
"sys.path.append(\"../../\")\n",
"import os\n",
"import pandas as pd\n",
"from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score\n",
"from sklearn.metrics import classification_report\n",
"from sklearn.model_selection import train_test_split\n",
"from utils_nlp.dataset.url_utils import maybe_download, extract_zip\n",
"from utils_nlp.eval.classification import eval_classification\n",
@@ -45,7 +45,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 48,
"metadata": {},
"outputs": [],
"source": [
@@ -76,8 +76,8 @@
"metadata": {},
"outputs": [],
"source": [
"zip_file = \"dac.zip\"\n",
"csv_file = \"arabic_dataset_classifiction.csv\"\n",
"zip_file = URL.split[\"/\"][-1]\n",
"csv_file = zip_file.replace(\".zip\", \"\")\n",
"maybe_download(URL, filename=zip_file, work_directory=DATA_FOLDER)\n",
"extract_zip(file_path=os.path.join(DATA_FOLDER, zip_file), dest_path=DATA_FOLDER)\n",
"df = pd.read_csv(os.path.join(DATA_FOLDER, csv_file))"
@@ -344,7 +344,7 @@
"source": [
"In addition, we perform the following preprocessing steps in the cell below:\n",
"- Convert the tokens into token indices corresponding to the BERT tokenizer's vocabulary\n",
"- Add sentence markers\n",
"- Add the special tokens [CLS] and [SEP] to mark the beginning and end of a sentence\n",
"- Pad or truncate the token lists to the specified max length\n",
"- Return mask lists that indicate paddings' positions\n",
"\n",
@@ -419,7 +419,7 @@
"epoch:1/1; batch:1573->1834/2618; loss:0.217612\n",
"epoch:1/1; batch:1835->2096/2618; loss:0.314651\n",
"epoch:1/1; batch:2097->2358/2618; loss:0.065314\n",
"epoch:1/1; batch:2359->2620/2618; loss:0.088071\n",
"epoch:1/1; batch:2359->2618/2618; loss:0.088071\n",
"[Training time: 1.414 hrs]\n"
]
}
@@ -475,106 +475,30 @@
},
{
"cell_type": "code",
"execution_count": 41,
"execution_count": 46,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" accuracy: 0.946656\n"
" culture 0.91 0.96 0.94 3479\n",
" diverse 0.92 0.98 0.95 4091\n",
" economy 0.92 0.85 0.88 3517\n",
" politics 0.91 0.89 0.90 5222\n",
" sports 0.99 0.99 0.99 11623\n",
"\n",
" accuracy 0.95 27932\n",
" macro avg 0.93 0.93 0.93 27932\n",
"weighted avg 0.95 0.95 0.95 27932\n",
"\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>label</th>\n",
" <th>precision</th>\n",
" <th>recall</th>\n",
" <th>f1</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>culture</td>\n",
" <td>0.911389</td>\n",
" <td>0.963783</td>\n",
" <td>0.936854</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>diverse</td>\n",
" <td>0.923255</td>\n",
" <td>0.976289</td>\n",
" <td>0.949032</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>economy</td>\n",
" <td>0.918719</td>\n",
" <td>0.848450</td>\n",
" <td>0.882188</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>politics</td>\n",
" <td>0.911776</td>\n",
" <td>0.886633</td>\n",
" <td>0.899029</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>sports</td>\n",
" <td>0.989656</td>\n",
" <td>0.987783</td>\n",
" <td>0.988719</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" label precision recall f1\n",
"0 culture 0.911389 0.963783 0.936854\n",
"1 diverse 0.923255 0.976289 0.949032\n",
"2 economy 0.918719 0.848450 0.882188\n",
"3 politics 0.911776 0.886633 0.899029\n",
"4 sports 0.989656 0.987783 0.988719"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"accuracy = accuracy_score(df_test[label_col], preds)\n",
"precision = precision_score(df_test[label_col], preds, average=None)\n",
"recall = recall_score(df_test[label_col], preds, average=None)\n",
"f1 = f1_score(df_test[label_col], preds, average=None)\n",
"\n",
"print(\"\\n accuracy: {:.6f}\".format(accuracy))\n",
"pd.DataFrame({\"label\": labels, \"precision\": precision, \"recall\": recall, \"f1\": f1})"
"print(classification_report(df_test[label_col], preds, target_names=labels))"
]
}
],
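The new cell replaces the separate accuracy/precision/recall/f1 calls and the results DataFrame with a single classification_report call. A toy example of the same call shape, with made-up labels standing in for df_test[label_col] and preds:

from sklearn.metrics import classification_report

y_true = ["sports", "economy", "sports", "culture"]
y_pred = ["sports", "culture", "sports", "culture"]

# target_names must line up with the sorted set of labels present in the data.
print(classification_report(y_true, y_pred, target_names=["culture", "economy", "sports"]))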

View file

@@ -46,13 +46,11 @@ class Tokenizer:
self.language = language
def tokenize(self, text):
"""Uses a BERT tokenizer
"""Tokenizes a list of documents using a BERT tokenizer
Args:
text (list): [description]
text (list(str)): list of text documents.
Returns:
[list]: [description]
[list(str)]: list of token lists.
"""
tokens = [self.tokenizer.tokenize(x) for x in text]
return tokens
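A hypothetical call against the wrapper above; the constructor argument is an assumption, since only the tokenize() method appears in this hunk:

# Illustrative usage only; the real constructor signature is not shown here.
tokenizer = Tokenizer(language="english")
docs = ["the quick brown fox", "bert splits rare words into subword pieces"]
token_lists = tokenizer.tokenize(docs)  # one list of WordPiece tokens per document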

View file

@@ -140,7 +140,7 @@ class BERTSequenceClassifier:
epoch + 1,
num_epochs,
i + 1,
(i + 1 + (num_batches // 10)) % num_batches,
min(i + 1 + num_batches // 10, num_batches),
num_batches,
loss.data,
)
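A standalone check, with made-up index values, of why the min() form is safer for the progress message: near the end of an epoch the modulo expression wraps past the last batch, while min() clamps the end-of-window index, consistent with the corrected "2618/2618" line in the training log above.

num_batches = 2618
window = num_batches // 10                  # 261 batches per progress window
i = 2358                                    # zero-based index of batch 2359

old_end = (i + 1 + window) % num_batches    # 2620 % 2618 -> 2, wraps past the end
new_end = min(i + 1 + window, num_batches)  # -> 2618, clamped to the last batch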

View file

@@ -7,7 +7,6 @@ https://www.nyu.edu/projects/bowman/multinli/
import os
import pandas as pd
import requests
from utils_nlp.dataset.url_utils import extract_zip, maybe_download
URL = "http://www.nyu.edu/projects/bowman/multinli/multinli_1.0.zip"
@@ -19,21 +18,20 @@ DATA_FILES = {
def load_pandas_df(local_cache_path=None, file_split="train"):
"""Downloads and extracts the dataset files
"""Downloads and extracts the dataset files
Args:
local_cache_path (str, optional): Path to download the dataset files to. Defaults to None.
file_split (str, optional): The subset to load.
One of: {"train", "dev_matched", "dev_mismatched"}
Defaults to "train".
One of: {"train", "dev_matched", "dev_mismatched"}
Defaults to "train".
Returns:
pd.DataFrame: pandas DataFrame containing the specified MultiNLI subset.
pd.DataFrame: pandas DataFrame containing the specified
MultiNLI subset.
"""
file_name = URL.split("/")[-1]
if not os.path.exists(os.path.join(local_cache_path, file_name)):
response = requests.get(URL)
with open(os.path.join(local_cache_path, file_name), "wb") as f:
f.write(response.content)
maybe_download(URL, file_name, local_cache_path)
if not os.path.exists(
os.path.join(local_cache_path, DATA_FILES[file_split])
):