This commit is contained in:
Said Bleik 2019-06-14 15:11:17 -04:00
Родитель fc12c867af
Коммит b0ead86bf2
3 изменённых файлов: 78 добавлений и 73 удалений

Просмотреть файл

@ -13,7 +13,7 @@
},
{
"cell_type": "code",
"execution_count": 43,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
@ -23,7 +23,7 @@
"import pandas as pd\n",
"from sklearn.metrics import classification_report\n",
"from sklearn.model_selection import train_test_split\n",
"from utils_nlp.dataset.url_utils import maybe_download, extract_zip\n",
"from utils_nlp.dataset.dac import load_pandas_df\n",
"from utils_nlp.eval.classification import eval_classification\n",
"from utils_nlp.bert.sequence_classification import BERTSequenceClassifier\n",
"from utils_nlp.bert.common import Language, Tokenizer\n",
@ -45,13 +45,10 @@
},
{
"cell_type": "code",
"execution_count": 48,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"URL = (\"https://data.mendeley.com/datasets/v524p5dhpj/2\" \n",
" \"/files/91cb8398-9451-43af-88fc-041a0956ae2d/\"\n",
" \"arabic_dataset_classifiction.csv.zip\")\n",
"DATA_FOLDER = \"../../../temp\"\n",
"BERT_CACHE_DIR = \"../../../temp\"\n",
"LANGUAGE = Language.MULTILINGUAL\n",
@ -72,20 +69,16 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"zip_file = URL.split[\"/\"][-1]\n",
"csv_file = zip_file.replace(\".zip\", \"\")\n",
"maybe_download(URL, filename=zip_file, work_directory=DATA_FOLDER)\n",
"extract_zip(file_path=os.path.join(DATA_FOLDER, zip_file), dest_path=DATA_FOLDER)\n",
"df = pd.read_csv(os.path.join(DATA_FOLDER, csv_file))"
"df = load_pandas_df(DATA_FOLDER)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 4,
"metadata": {},
"outputs": [
{
@ -152,7 +145,7 @@
"4 تزال صناعة الجلود في المغرب تتبع الطريقة التقل... 0"
]
},
"execution_count": 21,
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
@ -163,7 +156,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
@ -181,27 +174,27 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"4 34899\n",
"3 15283\n",
"1 12637\n",
"2 10718\n",
"0 10259\n",
"4 46522\n",
"3 20505\n",
"1 16728\n",
"2 14235\n",
"0 13738\n",
"Name: targe, dtype: int64"
]
},
"execution_count": 22,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df_train[label_col].value_counts()"
"df[label_col].value_counts()"
]
},
{
@ -213,7 +206,7 @@
},
{
"cell_type": "code",
"execution_count": 32,
"execution_count": 7,
"metadata": {},
"outputs": [
{
@ -274,7 +267,7 @@
"4 sports"
]
},
"execution_count": 32,
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
@ -295,7 +288,7 @@
},
{
"cell_type": "code",
"execution_count": 33,
"execution_count": 8,
"metadata": {},
"outputs": [
{
@ -329,7 +322,7 @@
},
{
"cell_type": "code",
"execution_count": 36,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@ -353,7 +346,7 @@
},
{
"cell_type": "code",
"execution_count": 37,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@ -375,7 +368,7 @@
},
{
"cell_type": "code",
"execution_count": 38,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@ -394,7 +387,7 @@
},
{
"cell_type": "code",
"execution_count": 39,
"execution_count": null,
"metadata": {
"scrolled": true
},
@ -410,17 +403,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"epoch:1/1; batch:1->262/2618; loss:1.655931\n",
"epoch:1/1; batch:263->524/2618; loss:0.129833\n",
"epoch:1/1; batch:525->786/2618; loss:0.295053\n",
"epoch:1/1; batch:787->1048/2618; loss:0.043921\n",
"epoch:1/1; batch:1049->1310/2618; loss:0.156879\n",
"epoch:1/1; batch:1311->1572/2618; loss:0.168521\n",
"epoch:1/1; batch:1573->1834/2618; loss:0.217612\n",
"epoch:1/1; batch:1835->2096/2618; loss:0.314651\n",
"epoch:1/1; batch:2097->2358/2618; loss:0.065314\n",
"epoch:1/1; batch:2359->2618/2618; loss:0.088071\n",
"[Training time: 1.414 hrs]\n"
"epoch:1/1; batch:1->262/2618; loss:1.649290\n"
]
}
],
@ -448,17 +431,9 @@
},
{
"cell_type": "code",
"execution_count": 40,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"27936it [08:36, 53.07it/s] \n"
]
}
],
"outputs": [],
"source": [
"preds = classifier.predict(\n",
" token_ids=tokens_test, input_mask=mask_test, num_gpus=NUM_GPUS, batch_size=BATCH_SIZE\n",
@ -475,31 +450,19 @@
},
{
"cell_type": "code",
"execution_count": 46,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" culture 0.91 0.96 0.94 3479\n",
" diverse 0.92 0.98 0.95 4091\n",
" economy 0.92 0.85 0.88 3517\n",
" politics 0.91 0.89 0.90 5222\n",
" sports 0.99 0.99 0.99 11623\n",
"\n",
" accuracy 0.95 27932\n",
" macro avg 0.93 0.93 0.93 27932\n",
"weighted avg 0.95 0.95 0.95 27932\n",
"\n"
]
}
],
"outputs": [],
"source": [
"print(classification_report(df_test[label_col], preds, target_names=labels))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {

Просмотреть файл

@ -31,7 +31,7 @@ class BERTSequenceClassifier:
Defaults to ".".
"""
if num_labels < 2:
raise Exception("Number of labels should be at least 2.")
raise ValueError("Number of labels should be at least 2.")
self.language = language
self.num_labels = num_labels

42
utils_nlp/dataset/dac.py Normal file
Просмотреть файл

@ -0,0 +1,42 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""Dataset for Arabic Classification utils
https://data.mendeley.com/datasets/v524p5dhpj/2
Mohamed, BINIZ (2018), DataSet for Arabic Classification, Mendeley Data, v2
paper link: ("https://www.mendeley.com/catalogue/
arabic-text-classification-using-deep-learning-technics/")
"""
import os
import pandas as pd
from utils_nlp.dataset.url_utils import extract_zip, maybe_download
URL = (
"https://data.mendeley.com/datasets/v524p5dhpj/2"
"/files/91cb8398-9451-43af-88fc-041a0956ae2d/"
"arabic_dataset_classifiction.csv.zip"
)
def load_pandas_df(local_cache_path=None):
"""Downloads and extracts the dataset files
Args:
local_cache_path ([type], optional): [description]. Defaults to None.
file_split (str, optional): The subset to load.
One of: {"train", "dev_matched", "dev_mismatched"}
Defaults to "train".
Returns:
pd.DataFrame: pandas DataFrame containing the specified
MultiNLI subset.
"""
zip_file = URL.split("/")[-1]
csv_file_path = os.path.join(
local_cache_path, zip_file.replace(".zip", "")
)
maybe_download(URL, zip_file, local_cache_path)
if not os.path.exists(csv_file_path):
extract_zip(file_path=csv_file_path, dest_path=local_cache_path)
return pd.read_csv(csv_file_path)