minor fixes
This commit is contained in:
Родитель
274615b396
Коммит
51c22b9607
|
@ -13,7 +13,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": 43,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -21,7 +21,7 @@
|
|||
"sys.path.append(\"../../\")\n",
|
||||
"import os\n",
|
||||
"import pandas as pd\n",
|
||||
"from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score\n",
|
||||
"from sklearn.metrics import classification_report\n",
|
||||
"from sklearn.model_selection import train_test_split\n",
|
||||
"from utils_nlp.dataset.url_utils import maybe_download, extract_zip\n",
|
||||
"from utils_nlp.eval.classification import eval_classification\n",
|
||||
|
@ -45,7 +45,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": 48,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -76,8 +76,8 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"zip_file = \"dac.zip\"\n",
|
||||
"csv_file = \"arabic_dataset_classifiction.csv\"\n",
|
||||
"zip_file = URL.split[\"/\"][-1]\n",
|
||||
"csv_file = zip_file.replace(\".zip\", \"\")\n",
|
||||
"maybe_download(URL, filename=zip_file, work_directory=DATA_FOLDER)\n",
|
||||
"extract_zip(file_path=os.path.join(DATA_FOLDER, zip_file), dest_path=DATA_FOLDER)\n",
|
||||
"df = pd.read_csv(os.path.join(DATA_FOLDER, csv_file))"
|
||||
|
@ -344,7 +344,7 @@
|
|||
"source": [
|
||||
"In addition, we perform the following preprocessing steps in the cell below:\n",
|
||||
"- Convert the tokens into token indices corresponding to the BERT tokenizer's vocabulary\n",
|
||||
"- Add sentence markers\n",
|
||||
"- Add the special tokens [CLS] and [SEP] to mark the beginning and end of a sentence\n",
|
||||
"- Pad or truncate the token lists to the specified max length\n",
|
||||
"- Return mask lists that indicate paddings' positions\n",
|
||||
"\n",
|
||||
|
@ -419,7 +419,7 @@
|
|||
"epoch:1/1; batch:1573->1834/2618; loss:0.217612\n",
|
||||
"epoch:1/1; batch:1835->2096/2618; loss:0.314651\n",
|
||||
"epoch:1/1; batch:2097->2358/2618; loss:0.065314\n",
|
||||
"epoch:1/1; batch:2359->2620/2618; loss:0.088071\n",
|
||||
"epoch:1/1; batch:2359->2618/2618; loss:0.088071\n",
|
||||
"[Training time: 1.414 hrs]\n"
|
||||
]
|
||||
}
|
||||
|
@ -475,106 +475,30 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 41,
|
||||
"execution_count": 46,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" precision recall f1-score support\n",
|
||||
"\n",
|
||||
" accuracy: 0.946656\n"
|
||||
" culture 0.91 0.96 0.94 3479\n",
|
||||
" diverse 0.92 0.98 0.95 4091\n",
|
||||
" economy 0.92 0.85 0.88 3517\n",
|
||||
" politics 0.91 0.89 0.90 5222\n",
|
||||
" sports 0.99 0.99 0.99 11623\n",
|
||||
"\n",
|
||||
" accuracy 0.95 27932\n",
|
||||
" macro avg 0.93 0.93 0.93 27932\n",
|
||||
"weighted avg 0.95 0.95 0.95 27932\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>label</th>\n",
|
||||
" <th>precision</th>\n",
|
||||
" <th>recall</th>\n",
|
||||
" <th>f1</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>culture</td>\n",
|
||||
" <td>0.911389</td>\n",
|
||||
" <td>0.963783</td>\n",
|
||||
" <td>0.936854</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>diverse</td>\n",
|
||||
" <td>0.923255</td>\n",
|
||||
" <td>0.976289</td>\n",
|
||||
" <td>0.949032</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>economy</td>\n",
|
||||
" <td>0.918719</td>\n",
|
||||
" <td>0.848450</td>\n",
|
||||
" <td>0.882188</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>politics</td>\n",
|
||||
" <td>0.911776</td>\n",
|
||||
" <td>0.886633</td>\n",
|
||||
" <td>0.899029</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>sports</td>\n",
|
||||
" <td>0.989656</td>\n",
|
||||
" <td>0.987783</td>\n",
|
||||
" <td>0.988719</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" label precision recall f1\n",
|
||||
"0 culture 0.911389 0.963783 0.936854\n",
|
||||
"1 diverse 0.923255 0.976289 0.949032\n",
|
||||
"2 economy 0.918719 0.848450 0.882188\n",
|
||||
"3 politics 0.911776 0.886633 0.899029\n",
|
||||
"4 sports 0.989656 0.987783 0.988719"
|
||||
]
|
||||
},
|
||||
"execution_count": 41,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"accuracy = accuracy_score(df_test[label_col], preds)\n",
|
||||
"precision = precision_score(df_test[label_col], preds, average=None)\n",
|
||||
"recall = recall_score(df_test[label_col], preds, average=None)\n",
|
||||
"f1 = f1_score(df_test[label_col], preds, average=None)\n",
|
||||
"\n",
|
||||
"print(\"\\n accuracy: {:.6f}\".format(accuracy))\n",
|
||||
"pd.DataFrame({\"label\": labels, \"precision\": precision, \"recall\": recall, \"f1\": f1})"
|
||||
"print(classification_report(df_test[label_col], preds, target_names=labels))"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
|
|
@ -46,13 +46,11 @@ class Tokenizer:
|
|||
self.language = language
|
||||
|
||||
def tokenize(self, text):
|
||||
"""Uses a BERT tokenizer
|
||||
|
||||
"""Tokenizes a list of documents using a BERT tokenizer
|
||||
Args:
|
||||
text (list): [description]
|
||||
|
||||
text (list(str)): list of text documents.
|
||||
Returns:
|
||||
[list]: [description]
|
||||
[list(str)]: list of token lists.
|
||||
"""
|
||||
tokens = [self.tokenizer.tokenize(x) for x in text]
|
||||
return tokens
|
||||
|
|
|
@ -140,7 +140,7 @@ class BERTSequenceClassifier:
|
|||
epoch + 1,
|
||||
num_epochs,
|
||||
i + 1,
|
||||
(i + 1 + (num_batches // 10)) % num_batches,
|
||||
min(i + 1 + num_batches // 10, num_batches),
|
||||
num_batches,
|
||||
loss.data,
|
||||
)
|
||||
|
|
|
@ -7,7 +7,6 @@ https://www.nyu.edu/projects/bowman/multinli/
|
|||
|
||||
import os
|
||||
import pandas as pd
|
||||
import requests
|
||||
from utils_nlp.dataset.url_utils import extract_zip, maybe_download
|
||||
|
||||
URL = "http://www.nyu.edu/projects/bowman/multinli/multinli_1.0.zip"
|
||||
|
@ -19,21 +18,20 @@ DATA_FILES = {
|
|||
|
||||
|
||||
def load_pandas_df(local_cache_path=None, file_split="train"):
|
||||
"""Downloads and extracts the dataset files
|
||||
"""Downloads and extracts the dataset files
|
||||
Args:
|
||||
local_cache_path ([type], optional): [description]. Defaults to None.
|
||||
file_split (str, optional): The subset to load.
|
||||
One of: {"train", "dev_matched", "dev_mismatched"}
|
||||
Defaults to "train".
|
||||
One of: {"train", "dev_matched", "dev_mismatched"}
|
||||
Defaults to "train".
|
||||
Returns:
|
||||
pd.DataFrame: pandas DataFrame containing the specified MultiNLI subset.
|
||||
pd.DataFrame: pandas DataFrame containing the specified
|
||||
MultiNLI subset.
|
||||
"""
|
||||
|
||||
file_name = URL.split("/")[-1]
|
||||
if not os.path.exists(os.path.join(local_cache_path, file_name)):
|
||||
response = requests.get(URL)
|
||||
with open(os.path.join(local_cache_path, file_name), "wb") as f:
|
||||
f.write(response.content)
|
||||
maybe_download(URL, file_name, local_cache_path)
|
||||
|
||||
if not os.path.exists(
|
||||
os.path.join(local_cache_path, DATA_FILES[file_split])
|
||||
):
|
||||
|
|
Загрузка…
Ссылка в новой задаче