added dataset utils
Parent: fc12c867af
Commit: b0ead86bf2
@@ -13,7 +13,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 43,
+   "execution_count": 1,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -23,7 +23,7 @@
     "import pandas as pd\n",
     "from sklearn.metrics import classification_report\n",
     "from sklearn.model_selection import train_test_split\n",
-    "from utils_nlp.dataset.url_utils import maybe_download, extract_zip\n",
+    "from utils_nlp.dataset.dac import load_pandas_df\n",
     "from utils_nlp.eval.classification import eval_classification\n",
     "from utils_nlp.bert.sequence_classification import BERTSequenceClassifier\n",
     "from utils_nlp.bert.common import Language, Tokenizer\n",
@@ -45,13 +45,10 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 48,
+   "execution_count": 2,
    "metadata": {},
    "outputs": [],
    "source": [
-    "URL = (\"https://data.mendeley.com/datasets/v524p5dhpj/2\" \n",
-    "       \"/files/91cb8398-9451-43af-88fc-041a0956ae2d/\"\n",
-    "       \"arabic_dataset_classifiction.csv.zip\")\n",
     "DATA_FOLDER = \"../../../temp\"\n",
     "BERT_CACHE_DIR = \"../../../temp\"\n",
     "LANGUAGE = Language.MULTILINGUAL\n",
@@ -72,20 +69,16 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 13,
+   "execution_count": 3,
    "metadata": {},
    "outputs": [],
    "source": [
-    "zip_file = URL.split[\"/\"][-1]\n",
-    "csv_file = zip_file.replace(\".zip\", \"\")\n",
-    "maybe_download(URL, filename=zip_file, work_directory=DATA_FOLDER)\n",
-    "extract_zip(file_path=os.path.join(DATA_FOLDER, zip_file), dest_path=DATA_FOLDER)\n",
-    "df = pd.read_csv(os.path.join(DATA_FOLDER, csv_file))"
+    "df = load_pandas_df(DATA_FOLDER)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 21,
+   "execution_count": 4,
    "metadata": {},
    "outputs": [
     {
@@ -152,7 +145,7 @@
        "4  تزال صناعة الجلود في المغرب تتبع الطريقة التقل...  0"
       ]
      },
-     "execution_count": 21,
+     "execution_count": 4,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -163,7 +156,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 5,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -181,27 +174,27 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 22,
+   "execution_count": 6,
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "4    34899\n",
-       "3    15283\n",
-       "1    12637\n",
-       "2    10718\n",
-       "0    10259\n",
+       "4    46522\n",
+       "3    20505\n",
+       "1    16728\n",
+       "2    14235\n",
+       "0    13738\n",
        "Name: targe, dtype: int64"
       ]
      },
-     "execution_count": 22,
+     "execution_count": 6,
      "metadata": {},
      "output_type": "execute_result"
     }
    ],
    "source": [
-    "df_train[label_col].value_counts()"
+    "df[label_col].value_counts()"
    ]
   },
   {
@@ -213,7 +206,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 32,
+   "execution_count": 7,
    "metadata": {},
    "outputs": [
     {
@@ -274,7 +267,7 @@
        "4    sports"
       ]
      },
-     "execution_count": 32,
+     "execution_count": 7,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -295,7 +288,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 33,
+   "execution_count": 8,
    "metadata": {},
    "outputs": [
     {
@@ -329,7 +322,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 36,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -353,7 +346,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 37,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -375,7 +368,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 38,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -394,7 +387,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 39,
+   "execution_count": null,
    "metadata": {
     "scrolled": true
    },
@@ -410,17 +403,7 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "epoch:1/1; batch:1->262/2618; loss:1.655931\n",
-      "epoch:1/1; batch:263->524/2618; loss:0.129833\n",
-      "epoch:1/1; batch:525->786/2618; loss:0.295053\n",
-      "epoch:1/1; batch:787->1048/2618; loss:0.043921\n",
-      "epoch:1/1; batch:1049->1310/2618; loss:0.156879\n",
-      "epoch:1/1; batch:1311->1572/2618; loss:0.168521\n",
-      "epoch:1/1; batch:1573->1834/2618; loss:0.217612\n",
-      "epoch:1/1; batch:1835->2096/2618; loss:0.314651\n",
-      "epoch:1/1; batch:2097->2358/2618; loss:0.065314\n",
-      "epoch:1/1; batch:2359->2618/2618; loss:0.088071\n",
-      "[Training time: 1.414 hrs]\n"
+      "epoch:1/1; batch:1->262/2618; loss:1.649290\n"
      ]
     }
    ],
@@ -448,17 +431,9 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 40,
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "27936it [08:36, 53.07it/s] \n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "preds = classifier.predict(\n",
     "    token_ids=tokens_test, input_mask=mask_test, num_gpus=NUM_GPUS, batch_size=BATCH_SIZE\n",
@@ -475,31 +450,19 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 46,
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "              precision    recall  f1-score   support\n",
-      "\n",
-      "     culture       0.91      0.96      0.94      3479\n",
-      "     diverse       0.92      0.98      0.95      4091\n",
-      "     economy       0.92      0.85      0.88      3517\n",
-      "    politics       0.91      0.89      0.90      5222\n",
-      "      sports       0.99      0.99      0.99     11623\n",
-      "\n",
-      "    accuracy                           0.95     27932\n",
-      "   macro avg       0.93      0.93      0.93     27932\n",
-      "weighted avg       0.95      0.95      0.95     27932\n",
-      "\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "print(classification_report(df_test[label_col], preds, target_names=labels))"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
  ],
  "metadata": {
@@ -31,7 +31,7 @@ class BERTSequenceClassifier:
                 Defaults to ".".
         """
         if num_labels < 2:
-            raise Exception("Number of labels should be at least 2.")
+            raise ValueError("Number of labels should be at least 2.")
 
         self.language = language
         self.num_labels = num_labels
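With the bare Exception replaced by a ValueError, callers can handle the invalid-argument case specifically. A minimal sketch, assuming the constructor signature BERTSequenceClassifier(language, num_labels, cache_dir) implied by the hunk above and by the notebook:

from utils_nlp.bert.common import Language
from utils_nlp.bert.sequence_classification import BERTSequenceClassifier

try:
    # num_labels=1 is below the minimum of 2 enforced in __init__
    classifier = BERTSequenceClassifier(
        language=Language.MULTILINGUAL, num_labels=1, cache_dir="."
    )
except ValueError as err:
    # a specific except clause no longer swallows unrelated failures,
    # as catching the old bare Exception would have
    print(f"invalid configuration: {err}")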
@@ -0,0 +1,45 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Utilities for the Dataset for Arabic Classification (DAC):
+https://data.mendeley.com/datasets/v524p5dhpj/2
+Mohamed, BINIZ (2018), “DataSet for Arabic Classification”, Mendeley Data, v2
+paper link: ("https://www.mendeley.com/catalogue/
+arabic-text-classification-using-deep-learning-technics/")
+"""
+
+import os
+import pandas as pd
+from utils_nlp.dataset.url_utils import extract_zip, maybe_download
+
+URL = (
+    "https://data.mendeley.com/datasets/v524p5dhpj/2"
+    "/files/91cb8398-9451-43af-88fc-041a0956ae2d/"
+    "arabic_dataset_classifiction.csv.zip"
+)
+
+
+def load_pandas_df(local_cache_path="."):
+    """Downloads and extracts the dataset files, then loads the csv.
+
+    Args:
+        local_cache_path (str, optional): Directory used to cache the
+            downloaded zip and the extracted csv. Defaults to ".".
+
+    Returns:
+        pd.DataFrame: pandas DataFrame containing the full dataset.
+    """
+    zip_file = URL.split("/")[-1]
+    csv_file_path = os.path.join(
+        local_cache_path, zip_file.replace(".zip", "")
+    )
+
+    maybe_download(URL, zip_file, local_cache_path)
+
+    if not os.path.exists(csv_file_path):
+        # extract the downloaded zip, not the csv it produces
+        extract_zip(
+            file_path=os.path.join(local_cache_path, zip_file),
+            dest_path=local_cache_path,
+        )
+    return pd.read_csv(csv_file_path)
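A short usage sketch for the new loader, mirroring the updated notebook cell (the cache path is illustrative; the label column name "targe" comes from the csv itself, as the value_counts output above shows):

from utils_nlp.dataset.dac import load_pandas_df

# first call downloads the zip and extracts the csv into the cache folder;
# subsequent calls reuse the cached csv
df = load_pandas_df(local_cache_path="../../../temp")

print(df.shape)
print(df["targe"].value_counts())  # distribution over the 5 class labels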