Коммит
a514025f5d
|
@ -0,0 +1,519 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"*Copyright (c) Microsoft Corporation. All rights reserved.*\n",
|
||||
"\n",
|
||||
"*Licensed under the MIT License.*\n",
|
||||
"\n",
|
||||
"# Classification of Arabic News Articles using BERT"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import sys\n",
|
||||
"sys.path.append(\"../../\")\n",
|
||||
"import os\n",
|
||||
"import pandas as pd\n",
|
||||
"from sklearn.metrics import classification_report\n",
|
||||
"from sklearn.model_selection import train_test_split\n",
|
||||
"from utils_nlp.dataset.dac import load_pandas_df\n",
|
||||
"from utils_nlp.eval.classification import eval_classification\n",
|
||||
"from utils_nlp.bert.sequence_classification import BERTSequenceClassifier\n",
|
||||
"from utils_nlp.bert.common import Language, Tokenizer\n",
|
||||
"from utils_nlp.common.timer import Timer\n",
|
||||
"import torch\n",
|
||||
"import torch.nn as nn\n",
|
||||
"import numpy as np"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Introduction\n",
|
||||
"In this notebook, we fine-tune and evaluate a pretrained [BERT](https://arxiv.org/abs/1810.04805) model on an Arabic dataset of news articles. The [dataset](https://data.mendeley.com/datasets/v524p5dhpj/2) includes articles from 3 different newspapers, and the articles are categorized into 5 classes: *sports, politics, culture, economy and diverse*. The data is described in more detail in this [paper](http://article.nadiapub.com/IJGDC/vol11_no9/9.pdf).\n",
|
||||
"\n",
|
||||
"We use a [sequence classifier](../../utils_nlp/bert/sequence_classification.py) that wraps [Hugging Face's PyTorch implementation](https://github.com/huggingface/pytorch-pretrained-BERT) of Google's [BERT](https://github.com/google-research/bert). The classifier loads a pretrained [multilingual BERT model](https://github.com/google-research/bert/blob/master/multilingual.md) that was trained on 104 languages, including Arabic."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"DATA_FOLDER = \"../../../temp\"\n",
|
||||
"BERT_CACHE_DIR = \"../../../temp\"\n",
|
||||
"LANGUAGE = Language.MULTILINGUAL\n",
|
||||
"MAX_LEN = 200\n",
|
||||
"BATCH_SIZE = 32\n",
|
||||
"NUM_GPUS = 2\n",
|
||||
"NUM_EPOCHS = 1\n",
|
||||
"TRAIN_SIZE = 0.75"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Read Dataset\n",
|
||||
"We start by loading the data. The following lines also download the file if it doesn't exist, and extract the csv file into the specified data folder."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"df = load_pandas_df(DATA_FOLDER)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>text</th>\n",
|
||||
" <th>targe</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>بين أستوديوهات ورزازات وصحراء مرزوكة وآثار ولي...</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>قررت النجمة الأمريكية أوبرا وينفري ألا يقتصر ع...</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>أخبارنا المغربية الوزاني تصوير الشملالي ألهب ا...</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>اخبارنا المغربية قال ابراهيم الراشدي محامي سعد...</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>تزال صناعة الجلود في المغرب تتبع الطريقة التقل...</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" text targe\n",
|
||||
"0 بين أستوديوهات ورزازات وصحراء مرزوكة وآثار ولي... 0\n",
|
||||
"1 قررت النجمة الأمريكية أوبرا وينفري ألا يقتصر ع... 0\n",
|
||||
"2 أخبارنا المغربية الوزاني تصوير الشملالي ألهب ا... 0\n",
|
||||
"3 اخبارنا المغربية قال ابراهيم الراشدي محامي سعد... 0\n",
|
||||
"4 تزال صناعة الجلود في المغرب تتبع الطريقة التقل... 0"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"df.head()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# set the text and label columns\n",
|
||||
"text_col = df.columns[0]\n",
|
||||
"label_col = df.columns[1]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Inspect the distribution of labels:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"4 46522\n",
|
||||
"3 20505\n",
|
||||
"1 16728\n",
|
||||
"2 14235\n",
|
||||
"0 13738\n",
|
||||
"Name: targe, dtype: int64"
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"df[label_col].value_counts()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We compare the counts with those presented in the author's [paper](http://article.nadiapub.com/IJGDC/vol11_no9/9.pdf), and infer the following label mapping:\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>label</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>culture</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>diverse</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>economy</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>politics</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>sports</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" label\n",
|
||||
"0 culture\n",
|
||||
"1 diverse\n",
|
||||
"2 economy\n",
|
||||
"3 politics\n",
|
||||
"4 sports"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# ordered list of labels\n",
|
||||
"labels = [\"culture\", \"diverse\", \"economy\", \"politics\", \"sports\"]\n",
|
||||
"num_labels = len(labels)\n",
|
||||
"pd.DataFrame({\"label\": labels})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Next, we split the data for training and testing:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Number of training examples: 83796\n",
|
||||
"Number of testing examples: 27932\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"df_train, df_test = train_test_split(df, train_size = TRAIN_SIZE, random_state=0)\n",
|
||||
"print(\"Number of training examples: {}\".format(df_train.shape[0]))\n",
|
||||
"print(\"Number of testing examples: {}\".format(df_test.shape[0]))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Tokenize and Preprocess"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Before training, we tokenize the text documents and convert them to lists of tokens. The following steps instantiate a BERT tokenizer given the language, and tokenize the text of the training and testing sets."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tokenizer = Tokenizer(LANGUAGE, cache_dir=BERT_CACHE_DIR)\n",
|
||||
"tokens_train = tokenizer.tokenize(df_train[text_col].astype(str))\n",
|
||||
"tokens_test = tokenizer.tokenize(df_test[text_col].astype(str))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"In addition, we perform the following preprocessing steps in the cell below:\n",
|
||||
"- Convert the tokens into token indices corresponding to the BERT tokenizer's vocabulary\n",
|
||||
"- Add the special tokens [CLS] and [SEP] to mark the beginning and end of a sentence\n",
|
||||
"- Pad or truncate the token lists to the specified max length\n",
|
||||
"- Return mask lists that indicate paddings' positions\n",
|
||||
"\n",
|
||||
"*See the original [implementation](https://github.com/google-research/bert/blob/master/run_classifier.py) for more information on BERT's input format.*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tokens_train, mask_train = tokenizer.preprocess_classification_tokens(\n",
|
||||
" tokens_train, MAX_LEN\n",
|
||||
")\n",
|
||||
"tokens_test, mask_test = tokenizer.preprocess_classification_tokens(\n",
|
||||
" tokens_test, MAX_LEN\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Create Model\n",
|
||||
"Next, we create a sequence classifier that loads a pre-trained BERT model, given the language and number of labels."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"classifier = BERTSequenceClassifier(\n",
|
||||
" language=LANGUAGE, num_labels=num_labels, cache_dir=BERT_CACHE_DIR\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Train\n",
|
||||
"We train the classifier using the training examples. This involves fine-tuning the BERT Transformer and learning a linear classification layer on top of that:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"t_total value of -1 results in schedule not being applied\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"epoch:1/1; batch:1->262/2618; loss:1.632097\n",
|
||||
"epoch:1/1; batch:263->524/2618; loss:0.402912\n",
|
||||
"epoch:1/1; batch:525->786/2618; loss:0.324510\n",
|
||||
"epoch:1/1; batch:787->1048/2618; loss:0.477946\n",
|
||||
"epoch:1/1; batch:1049->1310/2618; loss:0.333729\n",
|
||||
"epoch:1/1; batch:1311->1572/2618; loss:0.021917\n",
|
||||
"epoch:1/1; batch:1573->1834/2618; loss:0.031262\n",
|
||||
"epoch:1/1; batch:1835->2096/2618; loss:0.264172\n",
|
||||
"epoch:1/1; batch:2097->2358/2618; loss:0.034074\n",
|
||||
"epoch:1/1; batch:2359->2618/2618; loss:0.033827\n",
|
||||
"[Training time: 1.400 hrs]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"with Timer() as t:\n",
|
||||
" classifier.fit(\n",
|
||||
" token_ids=tokens_train,\n",
|
||||
" input_mask=mask_train,\n",
|
||||
" labels=list(df_train[label_col]), \n",
|
||||
" num_gpus=NUM_GPUS, \n",
|
||||
" num_epochs=NUM_EPOCHS,\n",
|
||||
" batch_size=BATCH_SIZE, \n",
|
||||
" verbose=True,\n",
|
||||
" ) \n",
|
||||
"print(\"[Training time: {:.3f} hrs]\".format(t.interval / 3600))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Score\n",
|
||||
"We score the test set using the trained classifier:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"27936it [08:35, 56.23it/s] \n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"preds = classifier.predict(\n",
|
||||
" token_ids=tokens_test, input_mask=mask_test, num_gpus=NUM_GPUS, batch_size=BATCH_SIZE\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Evaluate Results\n",
|
||||
"Finally, we compute the accuracy, precision, recall, and F1 metrics of the evaluation on the test set."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" precision recall f1-score support\n",
|
||||
"\n",
|
||||
" culture 0.96 0.93 0.94 3479\n",
|
||||
" diverse 0.92 0.98 0.95 4091\n",
|
||||
" economy 0.91 0.88 0.89 3517\n",
|
||||
" politics 0.90 0.90 0.90 5222\n",
|
||||
" sports 0.99 0.99 0.99 11623\n",
|
||||
"\n",
|
||||
" accuracy 0.95 27932\n",
|
||||
" macro avg 0.94 0.94 0.94 27932\n",
|
||||
"weighted avg 0.95 0.95 0.95 27932\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(classification_report(df_test[label_col], preds, target_names=labels))"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.6.8"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
|
@ -21,7 +21,7 @@
|
|||
"sys.path.append(\"../../\")\n",
|
||||
"import os\n",
|
||||
"import pandas as pd\n",
|
||||
"from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score\n",
|
||||
"from sklearn.metrics import classification_report\n",
|
||||
"from sklearn.preprocessing import LabelEncoder\n",
|
||||
"from sklearn.model_selection import train_test_split\n",
|
||||
"from utils_nlp.dataset.multinli import load_pandas_df\n",
|
||||
|
@ -46,12 +46,12 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"DATA_FOLDER = \"./temp\"\n",
|
||||
"BERT_CACHE_DIR = \"./temp\"\n",
|
||||
"DATA_FOLDER = \"../../../temp\"\n",
|
||||
"BERT_CACHE_DIR = \"../../../temp\"\n",
|
||||
"LANGUAGE = Language.ENGLISH\n",
|
||||
"TO_LOWER = True\n",
|
||||
"MAX_LEN = 150\n",
|
||||
|
@ -77,7 +77,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -87,7 +87,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
|
@ -154,7 +154,7 @@
|
|||
"13 travel Thebes held onto power until the 12th Dynasty,..."
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
|
@ -172,7 +172,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
|
@ -186,7 +186,7 @@
|
|||
"Name: genre, dtype: int64"
|
||||
]
|
||||
},
|
||||
"execution_count": 11,
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
|
@ -204,7 +204,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -221,7 +221,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
|
@ -256,7 +256,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -272,7 +272,7 @@
|
|||
"source": [
|
||||
"In addition, we perform the following preprocessing steps in the cell below:\n",
|
||||
"- Convert the tokens into token indices corresponding to the BERT tokenizer's vocabulary\n",
|
||||
"- Add sentence markers\n",
|
||||
"- Add the special tokens [CLS] and [SEP] to mark the beginning and end of a sentence\n",
|
||||
"- Pad or truncate the token lists to the specified max length\n",
|
||||
"- Return mask lists that indicate paddings' positions\n",
|
||||
"\n",
|
||||
|
@ -281,7 +281,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -303,7 +303,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -322,7 +322,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"execution_count": 11,
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
|
@ -338,17 +338,17 @@
|
|||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"epoch:1/1; batch:1->246/2454; loss:1.631739\n",
|
||||
"epoch:1/1; batch:247->492/2454; loss:0.427608\n",
|
||||
"epoch:1/1; batch:493->738/2454; loss:0.255493\n",
|
||||
"epoch:1/1; batch:739->984/2454; loss:0.286230\n",
|
||||
"epoch:1/1; batch:985->1230/2454; loss:0.375268\n",
|
||||
"epoch:1/1; batch:1231->1476/2454; loss:0.146290\n",
|
||||
"epoch:1/1; batch:1477->1722/2454; loss:0.092100\n",
|
||||
"epoch:1/1; batch:1723->1968/2454; loss:0.009405\n",
|
||||
"epoch:1/1; batch:1969->2214/2454; loss:0.038235\n",
|
||||
"epoch:1/1; batch:2215->2460/2454; loss:0.098216\n",
|
||||
"[Training time: 0.981 hrs]\n"
|
||||
"epoch:1/1; batch:1->246/2454; loss:1.824086\n",
|
||||
"epoch:1/1; batch:247->492/2454; loss:0.446337\n",
|
||||
"epoch:1/1; batch:493->738/2454; loss:0.298814\n",
|
||||
"epoch:1/1; batch:739->984/2454; loss:0.265785\n",
|
||||
"epoch:1/1; batch:985->1230/2454; loss:0.101790\n",
|
||||
"epoch:1/1; batch:1231->1476/2454; loss:0.251120\n",
|
||||
"epoch:1/1; batch:1477->1722/2454; loss:0.040894\n",
|
||||
"epoch:1/1; batch:1723->1968/2454; loss:0.038339\n",
|
||||
"epoch:1/1; batch:1969->2214/2454; loss:0.021586\n",
|
||||
"epoch:1/1; batch:2215->2454/2454; loss:0.130719\n",
|
||||
"[Training time: 0.980 hrs]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
@ -376,14 +376,14 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"52384it [11:54, 88.50it/s] \n"
|
||||
"52384it [11:54, 88.97it/s] \n"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
@ -403,112 +403,36 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" precision recall f1-score support\n",
|
||||
"\n",
|
||||
" accuracy: 0.938273\n"
|
||||
" fiction 0.90 0.94 0.92 10275\n",
|
||||
" government 0.97 0.93 0.95 10292\n",
|
||||
" slate 0.88 0.85 0.87 10277\n",
|
||||
" telephone 0.99 1.00 0.99 11205\n",
|
||||
" travel 0.95 0.97 0.96 10311\n",
|
||||
"\n",
|
||||
" accuracy 0.94 52360\n",
|
||||
" macro avg 0.94 0.94 0.94 52360\n",
|
||||
"weighted avg 0.94 0.94 0.94 52360\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>label</th>\n",
|
||||
" <th>precision</th>\n",
|
||||
" <th>recall</th>\n",
|
||||
" <th>f1</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>fiction</td>\n",
|
||||
" <td>0.917004</td>\n",
|
||||
" <td>0.925839</td>\n",
|
||||
" <td>0.921401</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>government</td>\n",
|
||||
" <td>0.961477</td>\n",
|
||||
" <td>0.928780</td>\n",
|
||||
" <td>0.944845</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>slate</td>\n",
|
||||
" <td>0.875161</td>\n",
|
||||
" <td>0.861535</td>\n",
|
||||
" <td>0.868295</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>telephone</td>\n",
|
||||
" <td>0.989105</td>\n",
|
||||
" <td>0.996609</td>\n",
|
||||
" <td>0.992843</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>travel</td>\n",
|
||||
" <td>0.943405</td>\n",
|
||||
" <td>0.973232</td>\n",
|
||||
" <td>0.958087</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" label precision recall f1\n",
|
||||
"0 fiction 0.917004 0.925839 0.921401\n",
|
||||
"1 government 0.961477 0.928780 0.944845\n",
|
||||
"2 slate 0.875161 0.861535 0.868295\n",
|
||||
"3 telephone 0.989105 0.996609 0.992843\n",
|
||||
"4 travel 0.943405 0.973232 0.958087"
|
||||
]
|
||||
},
|
||||
"execution_count": 21,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"accuracy = accuracy_score(labels_test, preds)\n",
|
||||
"precision = precision_score(labels_test, preds, average=None)\n",
|
||||
"recall = recall_score(labels_test, preds, average=None)\n",
|
||||
"f1 = f1_score(labels_test, preds, average=None)\n",
|
||||
"\n",
|
||||
"print(\"\\n accuracy: {:.6f}\".format(accuracy))\n",
|
||||
"pd.DataFrame({\"label\": label_encoder.classes_, \"precision\": precision, \"recall\": recall, \"f1\": f1})"
|
||||
"print(classification_report(labels_test, preds, target_names=label_encoder.classes_))"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3.5",
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
|
@ -522,7 +446,7 @@
|
|||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.5.5"
|
||||
"version": "3.6.8"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
|
|
@ -1,424 +0,0 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"*Copyright (c) Microsoft Corporation. All rights reserved.*\n",
|
||||
"\n",
|
||||
"*Licensed under the MIT License.*\n",
|
||||
"\n",
|
||||
"# Text Classification of Yahoo Answers using BERT\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import sys\n",
|
||||
"sys.path.append(\"../../\")\n",
|
||||
"import os\n",
|
||||
"import pandas as pd\n",
|
||||
"from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score\n",
|
||||
"import utils_nlp.dataset.yahoo_answers as ya_dataset\n",
|
||||
"from utils_nlp.eval.classification import eval_classification\n",
|
||||
"from utils_nlp.bert.sequence_classification import BERTSequenceClassifier\n",
|
||||
"from utils_nlp.bert.common import Language, Tokenizer\n",
|
||||
"from utils_nlp.common.timer import Timer\n",
|
||||
"import torch\n",
|
||||
"import torch.nn as nn\n",
|
||||
"import numpy as np"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"DATA_FOLDER = \"./temp\"\n",
|
||||
"TRAIN_FILE = \"yahoo_answers_csv/train.csv\"\n",
|
||||
"TEST_FILE = \"yahoo_answers_csv/test.csv\"\n",
|
||||
"BERT_CACHE_DIR = \"./temp\"\n",
|
||||
"MAX_LEN = 250\n",
|
||||
"BATCH_SIZE = 16\n",
|
||||
"NUM_GPUS = 2\n",
|
||||
"NUM_EPOCHS = 1\n",
|
||||
"NUM_ROWS_TRAIN = 50000 # number of training examples to read\n",
|
||||
"NUM_ROWS_TEST = 20000 # number of test examples to read"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Download Dataset"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"if not os.path.exists(DATA_FOLDER):\n",
|
||||
" os.mkdir(DATA_FOLDER)\n",
|
||||
"ya_dataset.download(DATA_FOLDER)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Read Dataset"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# read data\n",
|
||||
"df_train = ya_dataset.read_data(\n",
|
||||
" os.path.join(DATA_FOLDER, TRAIN_FILE), nrows=NUM_ROWS_TRAIN\n",
|
||||
")\n",
|
||||
"df_test = ya_dataset.read_data(\n",
|
||||
" os.path.join(DATA_FOLDER, TEST_FILE), nrows=NUM_ROWS_TEST\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# get labels\n",
|
||||
"labels_train = ya_dataset.get_labels(df_train)\n",
|
||||
"labels_test = ya_dataset.get_labels(df_test)\n",
|
||||
"\n",
|
||||
"num_labels = len(np.unique(labels_train))\n",
|
||||
"\n",
|
||||
"# get text\n",
|
||||
"text_train = ya_dataset.get_text(df_train)\n",
|
||||
"text_test = ya_dataset.get_text(df_test)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Tokenize and Preprocess"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Before training, we tokenize the text documents and convert them to lists of tokens. The following steps instantiate a BERT tokenizer given the language, and tokenize the text of the training and test sets."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tokenizer = Tokenizer(Language.ENGLISH, to_lower=True, cache_dir=BERT_CACHE_DIR)\n",
|
||||
"\n",
|
||||
"# tokenize\n",
|
||||
"tokens_train = tokenizer.tokenize(text_train)\n",
|
||||
"tokens_test = tokenizer.tokenize(text_test)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"In addition, we perform the following preprocessing steps in the cell below:\n",
|
||||
"- Convert the tokens into token indices corresponding to the BERT tokenizer's vocabulary\n",
|
||||
"- Add sentence markers\n",
|
||||
"- Pad or truncate the token lists to the specified max length\n",
|
||||
"\n",
|
||||
"*See the original [implementation](https://github.com/google-research/bert/blob/master/run_classifier.py) for more information on BERT's input format.*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tokens_train, mask_train = tokenizer.preprocess_classification_tokens(\n",
|
||||
" tokens_train, MAX_LEN\n",
|
||||
")\n",
|
||||
"tokens_test, mask_test = tokenizer.preprocess_classification_tokens(\n",
|
||||
" tokens_test, MAX_LEN\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Create Model\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"classifier = BERTSequenceClassifier(\n",
|
||||
" language=Language.ENGLISH, num_labels=num_labels, cache_dir=BERT_CACHE_DIR\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Train"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"t_total value of -1 results in schedule not being applied\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"epoch:1/1; batch:1->313/3125; loss:2.469508\n",
|
||||
"epoch:1/1; batch:314->626/3125; loss:1.179081\n",
|
||||
"epoch:1/1; batch:627->939/3125; loss:0.677443\n",
|
||||
"epoch:1/1; batch:940->1252/3125; loss:1.689727\n",
|
||||
"epoch:1/1; batch:1253->1565/3125; loss:0.781167\n",
|
||||
"epoch:1/1; batch:1566->1878/3125; loss:1.036024\n",
|
||||
"epoch:1/1; batch:1879->2191/3125; loss:0.909294\n",
|
||||
"epoch:1/1; batch:2192->2504/3125; loss:0.441344\n",
|
||||
"epoch:1/1; batch:2505->2817/3125; loss:0.823389\n",
|
||||
"epoch:1/1; batch:2818->3130/3125; loss:1.036229\n",
|
||||
"[Training time: 1.132 hrs]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# train\n",
|
||||
"with Timer() as t:\n",
|
||||
" classifier.fit(\n",
|
||||
" token_ids=tokens_train,\n",
|
||||
" input_mask=mask_train,\n",
|
||||
" labels=labels_train, \n",
|
||||
" num_gpus=NUM_GPUS, \n",
|
||||
" num_epochs=NUM_EPOCHS,\n",
|
||||
" batch_size=BATCH_SIZE, \n",
|
||||
" verbose=True,\n",
|
||||
" ) \n",
|
||||
"print(\"[Training time: {:.3f} hrs]\".format(t.interval / 3600))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Score Test Set"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 20000/20000 [08:00<00:00, 41.85it/s]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"preds = classifier.predict(\n",
|
||||
" token_ids=tokens_test, input_mask=mask_test, num_gpus=NUM_GPUS, batch_size=BATCH_SIZE\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Evaluate Results\n",
|
||||
"Finally, we compute the accuracy, precision, recall, and F1 metrics of the evaluation on the test set."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
" accuracy: 0.6564\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>precision</th>\n",
|
||||
" <th>recall</th>\n",
|
||||
" <th>f1</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>0.592506</td>\n",
|
||||
" <td>0.497053</td>\n",
|
||||
" <td>0.540598</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>0.749070</td>\n",
|
||||
" <td>0.673518</td>\n",
|
||||
" <td>0.709288</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>0.789308</td>\n",
|
||||
" <td>0.680955</td>\n",
|
||||
" <td>0.731139</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>0.561592</td>\n",
|
||||
" <td>0.440535</td>\n",
|
||||
" <td>0.493752</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>0.854772</td>\n",
|
||||
" <td>0.789272</td>\n",
|
||||
" <td>0.820717</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>5</th>\n",
|
||||
" <td>0.885998</td>\n",
|
||||
" <td>0.847659</td>\n",
|
||||
" <td>0.866404</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>6</th>\n",
|
||||
" <td>0.425440</td>\n",
|
||||
" <td>0.687416</td>\n",
|
||||
" <td>0.525592</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>7</th>\n",
|
||||
" <td>0.756364</td>\n",
|
||||
" <td>0.700337</td>\n",
|
||||
" <td>0.727273</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>8</th>\n",
|
||||
" <td>0.826006</td>\n",
|
||||
" <td>0.485432</td>\n",
|
||||
" <td>0.611496</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>9</th>\n",
|
||||
" <td>0.756186</td>\n",
|
||||
" <td>0.731039</td>\n",
|
||||
" <td>0.743400</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" precision recall f1\n",
|
||||
"0 0.592506 0.497053 0.540598\n",
|
||||
"1 0.749070 0.673518 0.709288\n",
|
||||
"2 0.789308 0.680955 0.731139\n",
|
||||
"3 0.561592 0.440535 0.493752\n",
|
||||
"4 0.854772 0.789272 0.820717\n",
|
||||
"5 0.885998 0.847659 0.866404\n",
|
||||
"6 0.425440 0.687416 0.525592\n",
|
||||
"7 0.756364 0.700337 0.727273\n",
|
||||
"8 0.826006 0.485432 0.611496\n",
|
||||
"9 0.756186 0.731039 0.743400"
|
||||
]
|
||||
},
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"accuracy = accuracy_score(labels_test, preds)\n",
|
||||
"precision = precision_score(labels_test, preds, average=None)\n",
|
||||
"recall = recall_score(labels_test, preds, average=None)\n",
|
||||
"f1 = f1_score(labels_test, preds, average=None)\n",
|
||||
"\n",
|
||||
"print(\"\\n accuracy: {}\".format(accuracy))\n",
|
||||
"pd.DataFrame({\"precision\": precision, \"recall\": recall, \"f1\": f1})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3.5",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.5.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
|
@ -46,13 +46,11 @@ class Tokenizer:
|
|||
self.language = language
|
||||
|
||||
def tokenize(self, text):
|
||||
"""Uses a BERT tokenizer
|
||||
|
||||
"""Tokenizes a list of documents using a BERT tokenizer
|
||||
Args:
|
||||
text (list): [description]
|
||||
|
||||
text (list(str)): list of text documents.
|
||||
Returns:
|
||||
[list]: [description]
|
||||
[list(str)]: list of token lists.
|
||||
"""
|
||||
tokens = [self.tokenizer.tokenize(x) for x in text]
|
||||
return tokens
|
||||
|
|
|
@ -1,13 +1,18 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# This script reuses some code from
|
||||
# https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/examples/run_classifier.py
|
||||
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from pytorch_pretrained_bert.modeling import BertForSequenceClassification
|
||||
from pytorch_pretrained_bert.optimization import BertAdam
|
||||
from tqdm import tqdm
|
||||
|
||||
from utils_nlp.bert.common import Language
|
||||
from utils_nlp.pytorch.device_utils import get_device, move_to_device
|
||||
|
||||
|
@ -26,7 +31,7 @@ class BERTSequenceClassifier:
|
|||
Defaults to ".".
|
||||
"""
|
||||
if num_labels < 2:
|
||||
raise Exception("Number of labels should be at least 2.")
|
||||
raise ValueError("Number of labels should be at least 2.")
|
||||
|
||||
self.language = language
|
||||
self.num_labels = num_labels
|
||||
|
@ -135,7 +140,7 @@ class BERTSequenceClassifier:
|
|||
epoch + 1,
|
||||
num_epochs,
|
||||
i + 1,
|
||||
i + 1 + (num_batches // 10),
|
||||
min(i + 1 + num_batches // 10, num_batches),
|
||||
num_batches,
|
||||
loss.data,
|
||||
)
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Dataset for Arabic Classification utils
|
||||
https://data.mendeley.com/datasets/v524p5dhpj/2
|
||||
Mohamed, BINIZ (2018), “DataSet for Arabic Classification”, Mendeley Data, v2
|
||||
paper link: ("https://www.mendeley.com/catalogue/
|
||||
arabic-text-classification-using-deep-learning-technics/")
|
||||
"""
|
||||
|
||||
import os
|
||||
import pandas as pd
|
||||
from utils_nlp.dataset.url_utils import extract_zip, maybe_download
|
||||
|
||||
URL = (
|
||||
"https://data.mendeley.com/datasets/v524p5dhpj/2"
|
||||
"/files/91cb8398-9451-43af-88fc-041a0956ae2d/"
|
||||
"arabic_dataset_classifiction.csv.zip"
|
||||
)
|
||||
|
||||
|
||||
def load_pandas_df(local_cache_path=None):
|
||||
"""Downloads and extracts the dataset files
|
||||
Args:
|
||||
local_cache_path ([type], optional): [description]. Defaults to None.
|
||||
Returns:
|
||||
pd.DataFrame: pandas DataFrame containing the loaded dataset.
|
||||
"""
|
||||
zip_file = URL.split("/")[-1]
|
||||
csv_file_path = os.path.join(
|
||||
local_cache_path, zip_file.replace(".zip", "")
|
||||
)
|
||||
|
||||
maybe_download(URL, zip_file, local_cache_path)
|
||||
|
||||
if not os.path.exists(csv_file_path):
|
||||
extract_zip(file_path=csv_file_path, dest_path=local_cache_path)
|
||||
return pd.read_csv(csv_file_path)
|
|
@ -7,7 +7,6 @@ https://www.nyu.edu/projects/bowman/multinli/
|
|||
|
||||
import os
|
||||
import pandas as pd
|
||||
import requests
|
||||
from utils_nlp.dataset.url_utils import extract_zip, maybe_download
|
||||
|
||||
URL = "http://www.nyu.edu/projects/bowman/multinli/multinli_1.0.zip"
|
||||
|
@ -19,21 +18,20 @@ DATA_FILES = {
|
|||
|
||||
|
||||
def load_pandas_df(local_cache_path=None, file_split="train"):
|
||||
"""Downloads and extracts the dataset files
|
||||
"""Downloads and extracts the dataset files
|
||||
Args:
|
||||
local_cache_path ([type], optional): [description]. Defaults to None.
|
||||
file_split (str, optional): The subset to load.
|
||||
One of: {"train", "dev_matched", "dev_mismatched"}
|
||||
Defaults to "train".
|
||||
One of: {"train", "dev_matched", "dev_mismatched"}
|
||||
Defaults to "train".
|
||||
Returns:
|
||||
pd.DataFrame: pandas DataFrame containing the specified MultiNLI subset.
|
||||
pd.DataFrame: pandas DataFrame containing the specified
|
||||
MultiNLI subset.
|
||||
"""
|
||||
|
||||
file_name = URL.split("/")[-1]
|
||||
if not os.path.exists(os.path.join(local_cache_path, file_name)):
|
||||
response = requests.get(URL)
|
||||
with open(os.path.join(local_cache_path, file_name), "wb") as f:
|
||||
f.write(response.content)
|
||||
maybe_download(URL, file_name, local_cache_path)
|
||||
|
||||
if not os.path.exists(
|
||||
os.path.join(local_cache_path, DATA_FILES[file_split])
|
||||
):
|
||||
|
|
|
@ -1,14 +1,16 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
import requests
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import tarfile
|
||||
import zipfile
|
||||
from contextlib import contextmanager
|
||||
from tempfile import TemporaryDirectory
|
||||
from tqdm import tqdm
|
||||
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
@ -17,13 +19,12 @@ def maybe_download(
|
|||
url, filename=None, work_directory=".", expected_bytes=None
|
||||
):
|
||||
"""Download a file if it is not already downloaded.
|
||||
|
||||
|
||||
Args:
|
||||
filename (str): File name.
|
||||
work_directory (str): Working directory.
|
||||
url (str): URL of the file to download.
|
||||
expected_bytes (int): Expected file size in bytes.
|
||||
|
||||
Returns:
|
||||
str: File path of the file downloaded.
|
||||
"""
|
||||
|
@ -80,7 +81,7 @@ def extract_zip(file_path, dest_path="."):
|
|||
raise IOError("File doesn't exist")
|
||||
if not os.path.exists(dest_path):
|
||||
raise IOError("Destination directory doesn't exist")
|
||||
with ZipFile(file_path) as z:
|
||||
with zipfile.ZipFile(file_path) as z:
|
||||
z.extractall(path=dest_path)
|
||||
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче