Resolved conflict and merged staging.
Commit dfb8553c5b
@@ -20,14 +20,15 @@ steps:
- bash: |
    conda remove -q -n nlp --all -y
    conda env create -f environment.yml
    python tools/generate_conda_file.py --gpu
    conda env create -n nlp_gpu -f nlp_gpu.yaml
    conda env list
    source activate nlp
    source activate nlp_gpu
  displayName: 'Creating Conda Environment with dependencies'

- bash: |
    source activate nlp
    python -m ipykernel install --user --name nlp --display-name "nlp"
    source activate nlp_gpu
    python -m ipykernel install --user --name nlp_gpu --display-name "nlp_gpu"
    # Commenting out pytest since it contains bunch of tests from other project which are not applicable.
    # But keeping the line here to show how to run it once tests relevant to this project are added
    # pytest --junitxml=junit/test-unitttest.xml #not running any tests for now
README.md — 19 changed lines

@@ -6,22 +6,27 @@
# NLP Best Practices

This repository will provide examples and best practices for building NLP systems, provided as Jupyter notebooks and utility functions.
This repository contains examples and best practices for building NLP systems, provided as Jupyter notebooks and utility functions. The focus of the repository is on state-of-the-art methods and common scenarios that are popular among researchers and practitioners working on problems involving text and language.

The following section includes a list of the available scenarios. Each scenario is demonstrated in one or more Jupyter notebook examples that make use of the core code base of models and utilities.

## Scenarios

| Scenario | Applications | Languages | Models |
|---| ------------------------ | -------------------------------------------- | ------------------- |
| Text Classification | Sentiment Analysis <br> Topic Classification | English | BERT, fastText |
| Named Entity Recognition | | English | BERT |
| Sentence Encoding | Sentence Similarity | English |
|[Text Classification](scenarios/text_classification) |Topic Classification|en, zh, ar|BERT|
|[Named Entity Recognition](scenarios/named_entity_recognition) |Wikipedia NER | en, zh |BERT|
|[Sentence Similarity](scenarios/sentence_similarity) |STS Benchmark |en|Representation: TF-IDF, Word Embeddings, Doc Embeddings<br>Metrics: Cosine Similarity, Word Mover's Distance|
|[Embeddings](scenarios/embeddings)| Custom Embeddings Training|en|Word2Vec<br>fastText<br>GloVe|

## Planning etc documents

All feature planning is done via projects, milestones, and issues in this Github repository.
## Planning
All feature planning is done via projects, milestones, and issues in this repository.

## Getting Started
To get started, navigate to the [Setup Guide](SETUP.md), where you'll find instructions on how to setup your environment and dependencies.

## Contributing
This project welcomes contributions and suggestions. Before contributing, please see our [contribution guidelines](CONTRIBUTING.md).
SETUP.md — 10 changed lines

@@ -1,4 +1,4 @@
# Setup guide
# Setup Guide

This document describes how to setup all the dependencies to run the notebooks in this repository.

@@ -16,12 +16,12 @@ For training at scale, operationalization or hyperparameter tuning, it is recomm
* [Register the conda environment in the DSVM JupyterHub](#register-the-conda-environment-in--the-dsvm-jupyterhub)

## Compute environments
## Compute Environments

Depending on the type of NLP system and the notebook that needs to be run, there are different computational requirements. Currently, this repository supports **Python CPU** and **Python GPU**.

## Setup guide for Local or DSVM
## Setup Guide for Local or DSVM

### Requirements

@@ -29,7 +29,7 @@ Depending on the type of NLP system and the notebook that needs to be run, there
* Anaconda with Python version >= 3.6.
  * This is pre-installed on Azure DSVM such that one can run the following steps directly. To setup on your local machine, [Miniconda](https://docs.conda.io/en/latest/miniconda.html) is a quick way to get started.

### Dependencies setup
### Dependencies Setup

We provide a script, [generate_conda_file.py](tools/generate_conda_file.py), to generate a conda-environment yaml file

@@ -57,7 +57,7 @@ Assuming that you have a GPU machine, to install the Python GPU environment, whi
</details>

### Register the conda environment in the DSVM JupyterHub
### Register Conda Environment in DSVM JupyterHub

We can register our created conda environment to appear as a kernel in the Jupyter notebooks.
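Note (illustrative, not part of this commit): the kernel registration that SETUP.md and the pipeline perform with `python -m ipykernel install ...` can also be driven from Python. A minimal sketch, assuming ipykernel is installed in the active environment and using the `nlp` environment name from this guide:

```python
# Minimal sketch: register the active conda environment as a Jupyter kernel.
# Run from inside the activated environment (e.g. `source activate nlp`).
from ipykernel.kernelspec import install

# Roughly equivalent to:
#   python -m ipykernel install --user --name nlp --display-name "Python (nlp)"
dest = install(user=True, kernel_name="nlp", display_name="Python (nlp)")
print("Kernel spec written to:", dest)
```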
@ -44,7 +44,7 @@
|
|||
"import sys\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"sys.path.append(\"../../../\")\n",
|
||||
"sys.path.append(\"../../\")\n",
|
||||
"\n",
|
||||
"import pandas as pd\n",
|
||||
"\n",
|
||||
|
@ -52,7 +52,7 @@
|
|||
"from utils_nlp.dataset.preprocess import to_spacy_tokens\n",
|
||||
"from utils_nlp.dataset.url_utils import maybe_download, download_path\n",
|
||||
"\n",
|
||||
"INSTALLER_PATH = '../../../data/'\n",
|
||||
"INSTALLER_PATH = '../../data/'\n",
|
||||
"print(\"System version: {}\".format(sys.version))"
|
||||
]
|
||||
},
|
||||
|
@ -72,13 +72,13 @@
|
|||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Data downloaded to ../../../data/MSRParaphraseCorpus.msi \n"
|
||||
"Data downloaded to ../../data/MSRParaphraseCorpus.msi\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"url = \"https://download.microsoft.com/download/D/4/6/D46FF87A-F6B9-4252-AA8B\" \\\n",
|
||||
" \"-3604ED519838/MSRParaphraseCorpus.msi \"\n",
|
||||
" \"-3604ED519838/MSRParaphraseCorpus.msi\"\n",
|
||||
"data_path = maybe_download(url, work_directory=INSTALLER_PATH)\n",
|
||||
"print(\"Data downloaded to {}\".format(data_path)) "
|
||||
]
|
||||
|
@ -99,7 +99,7 @@
|
|||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"The Windows Installer for Mircosoft Paraphrase Corpus has been downloaded at ../../../data/MSRParaphraseCorpus.msi \n",
|
||||
"The Windows Installer for Mircosoft Paraphrase Corpus has been downloaded at ../../data/MSRParaphraseCorpus.msi\n",
|
||||
"Please install and provide the installed directory. Thanks! \n",
|
||||
"C:\\MSRParaphraseCorpus\n",
|
||||
"Dataset successfully installed at C:\\MSRParaphraseCorpus\n"
|
||||
|
@ -383,7 +383,7 @@
|
|||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"The Windows Installer for Mircosoft Paraphrase Corpus has been downloaded at C:\\Projects\\NLP-BP\\NLP\\data\\MSRParaphraseCorpus.msi \n",
|
||||
"The Windows Installer for Mircosoft Paraphrase Corpus has been downloaded at E:\\Projects\\NLP-BP\\temp\\nlp\\data\\MSRParaphraseCorpus.msi \n",
|
||||
"\n",
|
||||
"Please install and provide the installed directory. Thanks! \n",
|
||||
"C:\\MSRParaphraseCorpus\n"
|
||||
|
@ -479,9 +479,9 @@
|
|||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python [conda env:nlp]",
|
||||
"display_name": "nlp",
|
||||
"language": "python",
|
||||
"name": "conda-env-nlp-py"
|
||||
"name": "nlp"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
|
|
|
@ -46,7 +46,7 @@
|
|||
"source": [
|
||||
"import sys\n",
|
||||
"\n",
|
||||
"sys.path.append(\"../../../\")\n",
|
||||
"sys.path.append(\"../../\")\n",
|
||||
"import shutil\n",
|
||||
"\n",
|
||||
"import os\n",
|
||||
|
@ -62,7 +62,7 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"BASE_DATA_PATH = \"../../../data\""
|
||||
"BASE_DATA_PATH = \"../../data\""
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -88,7 +88,15 @@
|
|||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 92.3k/92.3k [01:01<00:00, 1.50kKB/s]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# defaults to txt\n",
|
||||
"train = snli.load_pandas_df(BASE_DATA_PATH, file_split=\"train\")\n",
|
||||
|
@ -430,16 +438,7 @@
|
|||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[nltk_data] Downloading package punkt to /Users/caseyhong/nltk_data...\n",
|
||||
"[nltk_data] Package punkt is already up-to-date!\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_tok = to_nltk_tokens(to_lowercase_all(train))"
|
||||
]
|
||||
|
@ -575,20 +574,7 @@
|
|||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[nltk_data] Downloading package punkt to /Users/caseyhong/nltk_data...\n",
|
||||
"[nltk_data] Package punkt is already up-to-date!\n",
|
||||
"[nltk_data] Downloading package punkt to /Users/caseyhong/nltk_data...\n",
|
||||
"[nltk_data] Package punkt is already up-to-date!\n",
|
||||
"[nltk_data] Downloading package punkt to /Users/caseyhong/nltk_data...\n",
|
||||
"[nltk_data] Package punkt is already up-to-date!\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train = snli.load_pandas_df(BASE_DATA_PATH, file_split=\"train\")\n",
|
||||
"dev = snli.load_pandas_df(BASE_DATA_PATH, file_split=\"dev\")\n",
|
||||
|
@ -658,9 +644,9 @@
|
|||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"display_name": "Python (nlp_cpu)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
"name": "nlp_cpu"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
|
@ -672,7 +658,7 @@
|
|||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.5.5"
|
||||
"version": "3.6.8"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
|
|
@ -11,7 +11,7 @@
|
|||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Rather than use pre-trained embeddings (as we did in the baseline_deep_dive notebook), we can train word embeddings using our own dataset. In this notebook, we demonstrate the training process for producing word embeddings using the word2vec, GloVe, and fastText models. We'll utilize the STS Benchmark dataset for this task. "
|
||||
"Rather than use pre-trained embeddings (as we did in the sentence similarity baseline_deep_dive [notebook](../sentence_similarity/baseline_deep_dive.ipynb)), we can train word embeddings using our own dataset. In this notebook, we demonstrate the training process for producing word embeddings using the word2vec, GloVe, and fastText models. We'll utilize the STS Benchmark dataset for this task. "
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -890,7 +890,7 @@
|
|||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.5.5"
|
||||
"version": "3.6.8"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
|
|
@ -0,0 +1,519 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"*Copyright (c) Microsoft Corporation. All rights reserved.*\n",
|
||||
"\n",
|
||||
"*Licensed under the MIT License.*\n",
|
||||
"\n",
|
||||
"# Classification of Arabic News Articles using BERT"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import sys\n",
|
||||
"sys.path.append(\"../../\")\n",
|
||||
"import os\n",
|
||||
"import pandas as pd\n",
|
||||
"from sklearn.metrics import classification_report\n",
|
||||
"from sklearn.model_selection import train_test_split\n",
|
||||
"from utils_nlp.dataset.dac import load_pandas_df\n",
|
||||
"from utils_nlp.eval.classification import eval_classification\n",
|
||||
"from utils_nlp.bert.sequence_classification import BERTSequenceClassifier\n",
|
||||
"from utils_nlp.bert.common import Language, Tokenizer\n",
|
||||
"from utils_nlp.common.timer import Timer\n",
|
||||
"import torch\n",
|
||||
"import torch.nn as nn\n",
|
||||
"import numpy as np"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Introduction\n",
|
||||
"In this notebook, we fine-tune and evaluate a pretrained [BERT](https://arxiv.org/abs/1810.04805) model on an Arabic dataset of news articles. The [dataset](https://data.mendeley.com/datasets/v524p5dhpj/2) includes articles from 3 different newspapers, and the articles are categorized into 5 classes: *sports, politics, culture, economy and diverse*. The data is described in more detail in this [paper](http://article.nadiapub.com/IJGDC/vol11_no9/9.pdf).\n",
|
||||
"\n",
|
||||
"We use a [sequence classifier](../../utils_nlp/bert/sequence_classification.py) that wraps [Hugging Face's PyTorch implementation](https://github.com/huggingface/pytorch-pretrained-BERT) of Google's [BERT](https://github.com/google-research/bert). The classifier loads a pretrained [multilingual BERT model](https://github.com/google-research/bert/blob/master/multilingual.md) that was trained on 104 languages, including Arabic."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"DATA_FOLDER = \"../../../temp\"\n",
|
||||
"BERT_CACHE_DIR = \"../../../temp\"\n",
|
||||
"LANGUAGE = Language.MULTILINGUAL\n",
|
||||
"MAX_LEN = 200\n",
|
||||
"BATCH_SIZE = 32\n",
|
||||
"NUM_GPUS = 2\n",
|
||||
"NUM_EPOCHS = 1\n",
|
||||
"TRAIN_SIZE = 0.75"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Read Dataset\n",
|
||||
"We start by loading the data. The following lines also download the file if it doesn't exist, and extract the csv file into the specified data folder."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"df = load_pandas_df(DATA_FOLDER)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>text</th>\n",
|
||||
" <th>targe</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>بين أستوديوهات ورزازات وصحراء مرزوكة وآثار ولي...</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>قررت النجمة الأمريكية أوبرا وينفري ألا يقتصر ع...</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>أخبارنا المغربية الوزاني تصوير الشملالي ألهب ا...</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>اخبارنا المغربية قال ابراهيم الراشدي محامي سعد...</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>تزال صناعة الجلود في المغرب تتبع الطريقة التقل...</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" text targe\n",
|
||||
"0 بين أستوديوهات ورزازات وصحراء مرزوكة وآثار ولي... 0\n",
|
||||
"1 قررت النجمة الأمريكية أوبرا وينفري ألا يقتصر ع... 0\n",
|
||||
"2 أخبارنا المغربية الوزاني تصوير الشملالي ألهب ا... 0\n",
|
||||
"3 اخبارنا المغربية قال ابراهيم الراشدي محامي سعد... 0\n",
|
||||
"4 تزال صناعة الجلود في المغرب تتبع الطريقة التقل... 0"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"df.head()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# set the text and label columns\n",
|
||||
"text_col = df.columns[0]\n",
|
||||
"label_col = df.columns[1]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Inspect the distribution of labels:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"4 46522\n",
|
||||
"3 20505\n",
|
||||
"1 16728\n",
|
||||
"2 14235\n",
|
||||
"0 13738\n",
|
||||
"Name: targe, dtype: int64"
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"df[label_col].value_counts()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We compare the counts with those presented in the author's [paper](http://article.nadiapub.com/IJGDC/vol11_no9/9.pdf), and infer the following label mapping:\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>label</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>culture</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>diverse</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>economy</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>politics</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>sports</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" label\n",
|
||||
"0 culture\n",
|
||||
"1 diverse\n",
|
||||
"2 economy\n",
|
||||
"3 politics\n",
|
||||
"4 sports"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# ordered list of labels\n",
|
||||
"labels = [\"culture\", \"diverse\", \"economy\", \"politics\", \"sports\"]\n",
|
||||
"num_labels = len(labels)\n",
|
||||
"pd.DataFrame({\"label\": labels})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Next, we split the data for training and testing:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Number of training examples: 83796\n",
|
||||
"Number of testing examples: 27932\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"df_train, df_test = train_test_split(df, train_size = TRAIN_SIZE, random_state=0)\n",
|
||||
"print(\"Number of training examples: {}\".format(df_train.shape[0]))\n",
|
||||
"print(\"Number of testing examples: {}\".format(df_test.shape[0]))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Tokenize and Preprocess"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Before training, we tokenize the text documents and convert them to lists of tokens. The following steps instantiate a BERT tokenizer given the language, and tokenize the text of the training and testing sets."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tokenizer = Tokenizer(LANGUAGE, cache_dir=BERT_CACHE_DIR)\n",
|
||||
"tokens_train = tokenizer.tokenize(df_train[text_col].astype(str))\n",
|
||||
"tokens_test = tokenizer.tokenize(df_test[text_col].astype(str))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"In addition, we perform the following preprocessing steps in the cell below:\n",
|
||||
"- Convert the tokens into token indices corresponding to the BERT tokenizer's vocabulary\n",
|
||||
"- Add the special tokens [CLS] and [SEP] to mark the beginning and end of a sentence\n",
|
||||
"- Pad or truncate the token lists to the specified max length\n",
|
||||
"- Return mask lists that indicate paddings' positions\n",
|
||||
"\n",
|
||||
"*See the original [implementation](https://github.com/google-research/bert/blob/master/run_classifier.py) for more information on BERT's input format.*"
|
||||
]
|
||||
},
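Note (illustrative, not part of this notebook): the preprocessing steps listed in the cell above can be approximated with a small standalone sketch. The toy vocabulary and helper below are hypothetical; the repository's `Tokenizer.preprocess_classification_tokens` works against BERT's real WordPiece vocabulary.

```python
# Conceptual sketch of BERT-style classification preprocessing (assumed behavior).
def preprocess_for_classification(token_lists, vocab, max_len=200):
    """Turn lists of tokens into fixed-length id lists plus padding masks."""
    ids_batch, mask_batch = [], []
    for tokens in token_lists:
        # Reserve two positions for the special [CLS] and [SEP] tokens.
        tokens = ["[CLS]"] + tokens[: max_len - 2] + ["[SEP]"]
        # Map tokens to vocabulary ids (unknown tokens fall back to [UNK]).
        ids = [vocab.get(t, vocab["[UNK]"]) for t in tokens]
        # Pad with zeros up to max_len; the mask marks real tokens with 1.
        pad = max_len - len(ids)
        mask_batch.append([1] * len(ids) + [0] * pad)
        ids_batch.append(ids + [0] * pad)
    return ids_batch, mask_batch

# Toy example with a hypothetical vocabulary.
vocab = {"[PAD]": 0, "[UNK]": 1, "[CLS]": 2, "[SEP]": 3, "hello": 4, "world": 5}
ids, masks = preprocess_for_classification([["hello", "world"]], vocab, max_len=8)
print(ids)    # [[2, 4, 5, 3, 0, 0, 0, 0]]
print(masks)  # [[1, 1, 1, 1, 0, 0, 0, 0]]
```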
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tokens_train, mask_train = tokenizer.preprocess_classification_tokens(\n",
|
||||
" tokens_train, MAX_LEN\n",
|
||||
")\n",
|
||||
"tokens_test, mask_test = tokenizer.preprocess_classification_tokens(\n",
|
||||
" tokens_test, MAX_LEN\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Create Model\n",
|
||||
"Next, we create a sequence classifier that loads a pre-trained BERT model, given the language and number of labels."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"classifier = BERTSequenceClassifier(\n",
|
||||
" language=LANGUAGE, num_labels=num_labels, cache_dir=BERT_CACHE_DIR\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Train\n",
|
||||
"We train the classifier using the training examples. This involves fine-tuning the BERT Transformer and learning a linear classification layer on top of that:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"t_total value of -1 results in schedule not being applied\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"epoch:1/1; batch:1->262/2618; loss:1.632097\n",
|
||||
"epoch:1/1; batch:263->524/2618; loss:0.402912\n",
|
||||
"epoch:1/1; batch:525->786/2618; loss:0.324510\n",
|
||||
"epoch:1/1; batch:787->1048/2618; loss:0.477946\n",
|
||||
"epoch:1/1; batch:1049->1310/2618; loss:0.333729\n",
|
||||
"epoch:1/1; batch:1311->1572/2618; loss:0.021917\n",
|
||||
"epoch:1/1; batch:1573->1834/2618; loss:0.031262\n",
|
||||
"epoch:1/1; batch:1835->2096/2618; loss:0.264172\n",
|
||||
"epoch:1/1; batch:2097->2358/2618; loss:0.034074\n",
|
||||
"epoch:1/1; batch:2359->2618/2618; loss:0.033827\n",
|
||||
"[Training time: 1.400 hrs]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"with Timer() as t:\n",
|
||||
" classifier.fit(\n",
|
||||
" token_ids=tokens_train,\n",
|
||||
" input_mask=mask_train,\n",
|
||||
" labels=list(df_train[label_col]), \n",
|
||||
" num_gpus=NUM_GPUS, \n",
|
||||
" num_epochs=NUM_EPOCHS,\n",
|
||||
" batch_size=BATCH_SIZE, \n",
|
||||
" verbose=True,\n",
|
||||
" ) \n",
|
||||
"print(\"[Training time: {:.3f} hrs]\".format(t.interval / 3600))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Score\n",
|
||||
"We score the test set using the trained classifier:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"27936it [08:35, 56.23it/s] \n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"preds = classifier.predict(\n",
|
||||
" token_ids=tokens_test, input_mask=mask_test, num_gpus=NUM_GPUS, batch_size=BATCH_SIZE\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Evaluate Results\n",
|
||||
"Finally, we compute the accuracy, precision, recall, and F1 metrics of the evaluation on the test set."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" precision recall f1-score support\n",
|
||||
"\n",
|
||||
" culture 0.96 0.93 0.94 3479\n",
|
||||
" diverse 0.92 0.98 0.95 4091\n",
|
||||
" economy 0.91 0.88 0.89 3517\n",
|
||||
" politics 0.90 0.90 0.90 5222\n",
|
||||
" sports 0.99 0.99 0.99 11623\n",
|
||||
"\n",
|
||||
" accuracy 0.95 27932\n",
|
||||
" macro avg 0.94 0.94 0.94 27932\n",
|
||||
"weighted avg 0.95 0.95 0.95 27932\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(classification_report(df_test[label_col], preds, target_names=labels))"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.6.8"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
|
@ -21,7 +21,7 @@
|
|||
"sys.path.append(\"../../\")\n",
|
||||
"import os\n",
|
||||
"import pandas as pd\n",
|
||||
"from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score\n",
|
||||
"from sklearn.metrics import classification_report\n",
|
||||
"from sklearn.preprocessing import LabelEncoder\n",
|
||||
"from sklearn.model_selection import train_test_split\n",
|
||||
"from utils_nlp.dataset.multinli import load_pandas_df\n",
|
||||
|
@ -46,12 +46,12 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"DATA_FOLDER = \"./temp\"\n",
|
||||
"BERT_CACHE_DIR = \"./temp\"\n",
|
||||
"DATA_FOLDER = \"../../../temp\"\n",
|
||||
"BERT_CACHE_DIR = \"../../../temp\"\n",
|
||||
"LANGUAGE = Language.ENGLISH\n",
|
||||
"TO_LOWER = True\n",
|
||||
"MAX_LEN = 150\n",
|
||||
|
@ -77,7 +77,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -87,7 +87,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
|
@ -154,7 +154,7 @@
|
|||
"13 travel Thebes held onto power until the 12th Dynasty,..."
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
|
@ -172,7 +172,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
|
@ -186,7 +186,7 @@
|
|||
"Name: genre, dtype: int64"
|
||||
]
|
||||
},
|
||||
"execution_count": 11,
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
|
@ -204,7 +204,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -221,7 +221,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
|
@ -256,7 +256,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -272,7 +272,7 @@
|
|||
"source": [
|
||||
"In addition, we perform the following preprocessing steps in the cell below:\n",
|
||||
"- Convert the tokens into token indices corresponding to the BERT tokenizer's vocabulary\n",
|
||||
"- Add sentence markers\n",
|
||||
"- Add the special tokens [CLS] and [SEP] to mark the beginning and end of a sentence\n",
|
||||
"- Pad or truncate the token lists to the specified max length\n",
|
||||
"- Return mask lists that indicate paddings' positions\n",
|
||||
"\n",
|
||||
|
@ -281,7 +281,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -303,7 +303,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -322,7 +322,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"execution_count": 11,
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
|
@ -338,17 +338,17 @@
|
|||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"epoch:1/1; batch:1->246/2454; loss:1.631739\n",
|
||||
"epoch:1/1; batch:247->492/2454; loss:0.427608\n",
|
||||
"epoch:1/1; batch:493->738/2454; loss:0.255493\n",
|
||||
"epoch:1/1; batch:739->984/2454; loss:0.286230\n",
|
||||
"epoch:1/1; batch:985->1230/2454; loss:0.375268\n",
|
||||
"epoch:1/1; batch:1231->1476/2454; loss:0.146290\n",
|
||||
"epoch:1/1; batch:1477->1722/2454; loss:0.092100\n",
|
||||
"epoch:1/1; batch:1723->1968/2454; loss:0.009405\n",
|
||||
"epoch:1/1; batch:1969->2214/2454; loss:0.038235\n",
|
||||
"epoch:1/1; batch:2215->2460/2454; loss:0.098216\n",
|
||||
"[Training time: 0.981 hrs]\n"
|
||||
"epoch:1/1; batch:1->246/2454; loss:1.824086\n",
|
||||
"epoch:1/1; batch:247->492/2454; loss:0.446337\n",
|
||||
"epoch:1/1; batch:493->738/2454; loss:0.298814\n",
|
||||
"epoch:1/1; batch:739->984/2454; loss:0.265785\n",
|
||||
"epoch:1/1; batch:985->1230/2454; loss:0.101790\n",
|
||||
"epoch:1/1; batch:1231->1476/2454; loss:0.251120\n",
|
||||
"epoch:1/1; batch:1477->1722/2454; loss:0.040894\n",
|
||||
"epoch:1/1; batch:1723->1968/2454; loss:0.038339\n",
|
||||
"epoch:1/1; batch:1969->2214/2454; loss:0.021586\n",
|
||||
"epoch:1/1; batch:2215->2454/2454; loss:0.130719\n",
|
||||
"[Training time: 0.980 hrs]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
@ -376,14 +376,14 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"52384it [11:54, 88.50it/s] \n"
|
||||
"52384it [11:54, 88.97it/s] \n"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
@ -403,112 +403,36 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" precision recall f1-score support\n",
|
||||
"\n",
|
||||
" accuracy: 0.938273\n"
|
||||
" fiction 0.90 0.94 0.92 10275\n",
|
||||
" government 0.97 0.93 0.95 10292\n",
|
||||
" slate 0.88 0.85 0.87 10277\n",
|
||||
" telephone 0.99 1.00 0.99 11205\n",
|
||||
" travel 0.95 0.97 0.96 10311\n",
|
||||
"\n",
|
||||
" accuracy 0.94 52360\n",
|
||||
" macro avg 0.94 0.94 0.94 52360\n",
|
||||
"weighted avg 0.94 0.94 0.94 52360\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>label</th>\n",
|
||||
" <th>precision</th>\n",
|
||||
" <th>recall</th>\n",
|
||||
" <th>f1</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>fiction</td>\n",
|
||||
" <td>0.917004</td>\n",
|
||||
" <td>0.925839</td>\n",
|
||||
" <td>0.921401</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>government</td>\n",
|
||||
" <td>0.961477</td>\n",
|
||||
" <td>0.928780</td>\n",
|
||||
" <td>0.944845</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>slate</td>\n",
|
||||
" <td>0.875161</td>\n",
|
||||
" <td>0.861535</td>\n",
|
||||
" <td>0.868295</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>telephone</td>\n",
|
||||
" <td>0.989105</td>\n",
|
||||
" <td>0.996609</td>\n",
|
||||
" <td>0.992843</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>travel</td>\n",
|
||||
" <td>0.943405</td>\n",
|
||||
" <td>0.973232</td>\n",
|
||||
" <td>0.958087</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" label precision recall f1\n",
|
||||
"0 fiction 0.917004 0.925839 0.921401\n",
|
||||
"1 government 0.961477 0.928780 0.944845\n",
|
||||
"2 slate 0.875161 0.861535 0.868295\n",
|
||||
"3 telephone 0.989105 0.996609 0.992843\n",
|
||||
"4 travel 0.943405 0.973232 0.958087"
|
||||
]
|
||||
},
|
||||
"execution_count": 21,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"accuracy = accuracy_score(labels_test, preds)\n",
|
||||
"precision = precision_score(labels_test, preds, average=None)\n",
|
||||
"recall = recall_score(labels_test, preds, average=None)\n",
|
||||
"f1 = f1_score(labels_test, preds, average=None)\n",
|
||||
"\n",
|
||||
"print(\"\\n accuracy: {:.6f}\".format(accuracy))\n",
|
||||
"pd.DataFrame({\"label\": label_encoder.classes_, \"precision\": precision, \"recall\": recall, \"f1\": f1})"
|
||||
"print(classification_report(labels_test, preds, target_names=label_encoder.classes_))"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3.5",
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
|
@ -522,7 +446,7 @@
|
|||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.5.5"
|
||||
"version": "3.6.8"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
|
|
@ -1,424 +0,0 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"*Copyright (c) Microsoft Corporation. All rights reserved.*\n",
|
||||
"\n",
|
||||
"*Licensed under the MIT License.*\n",
|
||||
"\n",
|
||||
"# Text Classification of Yahoo Answers using BERT\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import sys\n",
|
||||
"sys.path.append(\"../../\")\n",
|
||||
"import os\n",
|
||||
"import pandas as pd\n",
|
||||
"from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score\n",
|
||||
"import utils_nlp.dataset.yahoo_answers as ya_dataset\n",
|
||||
"from utils_nlp.eval.classification import eval_classification\n",
|
||||
"from utils_nlp.bert.sequence_classification import BERTSequenceClassifier\n",
|
||||
"from utils_nlp.bert.common import Language, Tokenizer\n",
|
||||
"from utils_nlp.common.timer import Timer\n",
|
||||
"import torch\n",
|
||||
"import torch.nn as nn\n",
|
||||
"import numpy as np"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"DATA_FOLDER = \"./temp\"\n",
|
||||
"TRAIN_FILE = \"yahoo_answers_csv/train.csv\"\n",
|
||||
"TEST_FILE = \"yahoo_answers_csv/test.csv\"\n",
|
||||
"BERT_CACHE_DIR = \"./temp\"\n",
|
||||
"MAX_LEN = 250\n",
|
||||
"BATCH_SIZE = 16\n",
|
||||
"NUM_GPUS = 2\n",
|
||||
"NUM_EPOCHS = 1\n",
|
||||
"NUM_ROWS_TRAIN = 50000 # number of training examples to read\n",
|
||||
"NUM_ROWS_TEST = 20000 # number of test examples to read"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Download Dataset"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"if not os.path.exists(DATA_FOLDER):\n",
|
||||
" os.mkdir(DATA_FOLDER)\n",
|
||||
"ya_dataset.download(DATA_FOLDER)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Read Dataset"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# read data\n",
|
||||
"df_train = ya_dataset.read_data(\n",
|
||||
" os.path.join(DATA_FOLDER, TRAIN_FILE), nrows=NUM_ROWS_TRAIN\n",
|
||||
")\n",
|
||||
"df_test = ya_dataset.read_data(\n",
|
||||
" os.path.join(DATA_FOLDER, TEST_FILE), nrows=NUM_ROWS_TEST\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# get labels\n",
|
||||
"labels_train = ya_dataset.get_labels(df_train)\n",
|
||||
"labels_test = ya_dataset.get_labels(df_test)\n",
|
||||
"\n",
|
||||
"num_labels = len(np.unique(labels_train))\n",
|
||||
"\n",
|
||||
"# get text\n",
|
||||
"text_train = ya_dataset.get_text(df_train)\n",
|
||||
"text_test = ya_dataset.get_text(df_test)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Tokenize and Preprocess"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Before training, we tokenize the text documents and convert them to lists of tokens. The following steps instantiate a BERT tokenizer given the language, and tokenize the text of the training and test sets."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tokenizer = Tokenizer(Language.ENGLISH, to_lower=True, cache_dir=BERT_CACHE_DIR)\n",
|
||||
"\n",
|
||||
"# tokenize\n",
|
||||
"tokens_train = tokenizer.tokenize(text_train)\n",
|
||||
"tokens_test = tokenizer.tokenize(text_test)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"In addition, we perform the following preprocessing steps in the cell below:\n",
|
||||
"- Convert the tokens into token indices corresponding to the BERT tokenizer's vocabulary\n",
|
||||
"- Add sentence markers\n",
|
||||
"- Pad or truncate the token lists to the specified max length\n",
|
||||
"\n",
|
||||
"*See the original [implementation](https://github.com/google-research/bert/blob/master/run_classifier.py) for more information on BERT's input format.*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tokens_train, mask_train = tokenizer.preprocess_classification_tokens(\n",
|
||||
" tokens_train, MAX_LEN\n",
|
||||
")\n",
|
||||
"tokens_test, mask_test = tokenizer.preprocess_classification_tokens(\n",
|
||||
" tokens_test, MAX_LEN\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Create Model\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"classifier = BERTSequenceClassifier(\n",
|
||||
" language=Language.ENGLISH, num_labels=num_labels, cache_dir=BERT_CACHE_DIR\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Train"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"t_total value of -1 results in schedule not being applied\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"epoch:1/1; batch:1->313/3125; loss:2.469508\n",
|
||||
"epoch:1/1; batch:314->626/3125; loss:1.179081\n",
|
||||
"epoch:1/1; batch:627->939/3125; loss:0.677443\n",
|
||||
"epoch:1/1; batch:940->1252/3125; loss:1.689727\n",
|
||||
"epoch:1/1; batch:1253->1565/3125; loss:0.781167\n",
|
||||
"epoch:1/1; batch:1566->1878/3125; loss:1.036024\n",
|
||||
"epoch:1/1; batch:1879->2191/3125; loss:0.909294\n",
|
||||
"epoch:1/1; batch:2192->2504/3125; loss:0.441344\n",
|
||||
"epoch:1/1; batch:2505->2817/3125; loss:0.823389\n",
|
||||
"epoch:1/1; batch:2818->3130/3125; loss:1.036229\n",
|
||||
"[Training time: 1.132 hrs]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# train\n",
|
||||
"with Timer() as t:\n",
|
||||
" classifier.fit(\n",
|
||||
" token_ids=tokens_train,\n",
|
||||
" input_mask=mask_train,\n",
|
||||
" labels=labels_train, \n",
|
||||
" num_gpus=NUM_GPUS, \n",
|
||||
" num_epochs=NUM_EPOCHS,\n",
|
||||
" batch_size=BATCH_SIZE, \n",
|
||||
" verbose=True,\n",
|
||||
" ) \n",
|
||||
"print(\"[Training time: {:.3f} hrs]\".format(t.interval / 3600))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Score Test Set"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 20000/20000 [08:00<00:00, 41.85it/s]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"preds = classifier.predict(\n",
|
||||
" token_ids=tokens_test, input_mask=mask_test, num_gpus=NUM_GPUS, batch_size=BATCH_SIZE\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Evaluate Results\n",
|
||||
"Finally, we compute the accuracy, precision, recall, and F1 metrics of the evaluation on the test set."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
" accuracy: 0.6564\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>precision</th>\n",
|
||||
" <th>recall</th>\n",
|
||||
" <th>f1</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>0.592506</td>\n",
|
||||
" <td>0.497053</td>\n",
|
||||
" <td>0.540598</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>0.749070</td>\n",
|
||||
" <td>0.673518</td>\n",
|
||||
" <td>0.709288</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>0.789308</td>\n",
|
||||
" <td>0.680955</td>\n",
|
||||
" <td>0.731139</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>0.561592</td>\n",
|
||||
" <td>0.440535</td>\n",
|
||||
" <td>0.493752</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>0.854772</td>\n",
|
||||
" <td>0.789272</td>\n",
|
||||
" <td>0.820717</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>5</th>\n",
|
||||
" <td>0.885998</td>\n",
|
||||
" <td>0.847659</td>\n",
|
||||
" <td>0.866404</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>6</th>\n",
|
||||
" <td>0.425440</td>\n",
|
||||
" <td>0.687416</td>\n",
|
||||
" <td>0.525592</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>7</th>\n",
|
||||
" <td>0.756364</td>\n",
|
||||
" <td>0.700337</td>\n",
|
||||
" <td>0.727273</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>8</th>\n",
|
||||
" <td>0.826006</td>\n",
|
||||
" <td>0.485432</td>\n",
|
||||
" <td>0.611496</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>9</th>\n",
|
||||
" <td>0.756186</td>\n",
|
||||
" <td>0.731039</td>\n",
|
||||
" <td>0.743400</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" precision recall f1\n",
|
||||
"0 0.592506 0.497053 0.540598\n",
|
||||
"1 0.749070 0.673518 0.709288\n",
|
||||
"2 0.789308 0.680955 0.731139\n",
|
||||
"3 0.561592 0.440535 0.493752\n",
|
||||
"4 0.854772 0.789272 0.820717\n",
|
||||
"5 0.885998 0.847659 0.866404\n",
|
||||
"6 0.425440 0.687416 0.525592\n",
|
||||
"7 0.756364 0.700337 0.727273\n",
|
||||
"8 0.826006 0.485432 0.611496\n",
|
||||
"9 0.756186 0.731039 0.743400"
|
||||
]
|
||||
},
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"accuracy = accuracy_score(labels_test, preds)\n",
|
||||
"precision = precision_score(labels_test, preds, average=None)\n",
|
||||
"recall = recall_score(labels_test, preds, average=None)\n",
|
||||
"f1 = f1_score(labels_test, preds, average=None)\n",
|
||||
"\n",
|
||||
"print(\"\\n accuracy: {}\".format(accuracy))\n",
|
||||
"pd.DataFrame({\"precision\": precision, \"recall\": recall, \"f1\": f1})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3.5",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.5.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
|
@@ -13,7 +13,6 @@
import argparse
import textwrap
from sys import platform


HELP_MSG = """

@@ -25,13 +24,15 @@ $ conda env update -f {conda_env}.yaml
To register the conda environment in Jupyter:
$ conda activate {conda_env}
$ python -m ipykernel install --user --name {conda_env} --display-name "Python ({conda_env})"
$ python -m ipykernel install --user --name {conda_env} \
    --display-name "Python ({conda_env})"
"""

CHANNELS = ["defaults", "conda-forge", "pytorch"]

CONDA_BASE = {
    "python": "python==3.6.8",
    "pip": "pip>=19.1.1",
    "gitpython": "gitpython>=2.1.8",
    "ipykernel": "ipykernel>=4.6.1",
    "jupyter": "jupyter>=1.0.0",

@@ -53,7 +54,9 @@ CONDA_GPU = {
}

PIP_BASE = {
    "azureml-sdk[notebooks,tensorboard]": "azureml-sdk[notebooks,tensorboard]==1.0.33",
    "azureml-sdk[notebooks,tensorboard]": (
        "azureml-sdk[notebooks,tensorboard]==1.0.33"
    ),
    "azureml-dataprep": "azureml-dataprep==1.1.4",
    "black": "black>=18.6b4",
    "papermill": "papermill==0.18.2",

@@ -62,15 +65,18 @@ PIP_BASE = {
    "pyemd": "pyemd==0.5.1",
    "ipywebrtc": "ipywebrtc==0.4.3",
    "pre-commit": "pre-commit>=1.14.4",
    "spacy-models": "https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.1.0/en_core_web_sm-2.1.0.tar.gz",
    "spacy": "spacy>=2.1.4",
    "spacy-models": (
        "https://github.com/explosion/spacy-models/releases/download/"
        "en_core_web_sm-2.1.0/en_core_web_sm-2.1.0.tar.gz"
    ),
    "gensim": "gensim>=3.7.0",
    "nltk": "nltk>=3.4",
    "pytorch-pretrained-bert": "pytorch-pretrained-bert>=0.6",
    "horovod": "horovod>=0.16.1",
    "seqeval": "seqeval>=0.0.12",
}

PIP_GPU = {}
PIP_GPU = {"horovod": "horovod>=0.16.1"}


if __name__ == "__main__":

@@ -78,7 +84,8 @@ if __name__ == "__main__":
        description=textwrap.dedent(
            """
    This script generates a conda file for different environments.
    Plain python is the default, but flags can be used to support GPU functionality"""
    Plain python is the default,
    but flags can be used to support GPU functionality."""
        ),
        epilog=HELP_MSG,
        formatter_class=argparse.RawDescriptionHelpFormatter,
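Note (illustrative): the dictionaries edited above (`CONDA_BASE`, `PIP_BASE`, `PIP_GPU`) are what the generate_conda_file.py script referenced in SETUP.md serializes into a conda environment file. A rough sketch of that idea, with abbreviated dictionaries and assuming PyYAML is available; this is not the script's actual implementation:

```python
# Sketch: assemble a conda environment definition from dependency dictionaries
# (abbreviated, hypothetical helper; assumes the pyyaml package is installed).
import yaml

CHANNELS = ["defaults", "conda-forge", "pytorch"]
CONDA_BASE = {"python": "python==3.6.8", "pip": "pip>=19.1.1"}
PIP_BASE = {"nltk": "nltk>=3.4", "gensim": "gensim>=3.7.0"}
PIP_GPU = {"horovod": "horovod>=0.16.1"}

def make_env(name="nlp_cpu", gpu=False):
    # Merge pip dependencies, optionally adding the GPU-only ones.
    pip_deps = dict(PIP_BASE, **(PIP_GPU if gpu else {}))
    return {
        "name": name,
        "channels": CHANNELS,
        "dependencies": list(CONDA_BASE.values()) + [{"pip": list(pip_deps.values())}],
    }

with open("nlp_cpu.yaml", "w") as f:
    yaml.safe_dump(make_env(), f, default_flow_style=False)
```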
@@ -45,13 +45,11 @@ class Tokenizer:
        self.language = language

    def tokenize(self, text):
        """Uses a BERT tokenizer

        """Tokenizes a list of documents using a BERT tokenizer
        Args:
            text (list): [description]

            text (list(str)): list of text documents.
        Returns:
            [list]: [description]
            [list(str)]: list of token lists.
        """
        tokens = [self.tokenizer.tokenize(x) for x in text]
        return tokens
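Note (illustrative): per the revised docstring, `Tokenizer.tokenize` takes a list of text documents and returns one token list per document. A usage sketch mirroring the notebooks in this commit (the relative path and cache directory are examples only):

```python
# Usage sketch for the BERT tokenizer wrapper (illustrative paths and constants).
import sys
sys.path.append("../../")  # repo root, relative to a scenario notebook

from utils_nlp.bert.common import Language, Tokenizer

tokenizer = Tokenizer(Language.ENGLISH, to_lower=True, cache_dir="./temp")
docs = ["NLP best practices.", "BERT tokenizes text into word pieces."]
tokens = tokenizer.tokenize(docs)  # list of token lists, one per document
print(tokens[1][:5])
```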
@@ -1,13 +1,18 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

# This script reuses some code from
# https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/examples/run_classifier.py

import random

import numpy as np
import torch
import torch.nn as nn
from pytorch_pretrained_bert.modeling import BertForSequenceClassification
from pytorch_pretrained_bert.optimization import BertAdam
from tqdm import tqdm

from utils_nlp.bert.common import Language
from utils_nlp.pytorch.device_utils import get_device, move_to_device


@@ -26,7 +31,7 @@ class BERTSequenceClassifier:
            Defaults to ".".
        """
        if num_labels < 2:
            raise Exception("Number of labels should be at least 2.")
            raise ValueError("Number of labels should be at least 2.")

        self.language = language
        self.num_labels = num_labels

@@ -135,7 +140,7 @@ class BERTSequenceClassifier:
                        epoch + 1,
                        num_epochs,
                        i + 1,
                        i + 1 + (num_batches // 10),
                        min(i + 1 + num_batches // 10, num_batches),
                        num_batches,
                        loss.data,
                    )
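Note (illustrative): a condensed sketch of how the notebooks in this commit drive `BERTSequenceClassifier` end to end; the two-document dataset and the constants are placeholders taken from the notebook examples, not a definitive recipe:

```python
# Condensed fit/predict sketch for BERTSequenceClassifier (placeholder data and constants).
from utils_nlp.bert.common import Language, Tokenizer
from utils_nlp.bert.sequence_classification import BERTSequenceClassifier

LANGUAGE, MAX_LEN, BATCH_SIZE, NUM_GPUS, NUM_EPOCHS = Language.ENGLISH, 150, 32, 1, 1

texts = ["a first document", "a second document"]
labels = [0, 1]  # num_labels must be at least 2, otherwise a ValueError is raised

tokenizer = Tokenizer(LANGUAGE, cache_dir="./temp")
tokens = tokenizer.tokenize(texts)
token_ids, mask = tokenizer.preprocess_classification_tokens(tokens, MAX_LEN)

classifier = BERTSequenceClassifier(language=LANGUAGE, num_labels=2, cache_dir="./temp")
classifier.fit(token_ids=token_ids, input_mask=mask, labels=labels,
               num_gpus=NUM_GPUS, num_epochs=NUM_EPOCHS,
               batch_size=BATCH_SIZE, verbose=True)
preds = classifier.predict(token_ids=token_ids, input_mask=mask,
                           num_gpus=NUM_GPUS, batch_size=BATCH_SIZE)
```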
@@ -0,0 +1,3 @@
import nltk

nltk.download("punkt", quiet=True)
@@ -0,0 +1,38 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Dataset for Arabic Classification utils
https://data.mendeley.com/datasets/v524p5dhpj/2
Mohamed, BINIZ (2018), “DataSet for Arabic Classification”, Mendeley Data, v2
paper link: ("https://www.mendeley.com/catalogue/
arabic-text-classification-using-deep-learning-technics/")
"""

import os
import pandas as pd
from utils_nlp.dataset.url_utils import extract_zip, maybe_download

URL = (
    "https://data.mendeley.com/datasets/v524p5dhpj/2"
    "/files/91cb8398-9451-43af-88fc-041a0956ae2d/"
    "arabic_dataset_classifiction.csv.zip"
)


def load_pandas_df(local_cache_path=None):
    """Downloads and extracts the dataset files
    Args:
        local_cache_path ([type], optional): [description]. Defaults to None.
    Returns:
        pd.DataFrame: pandas DataFrame containing the loaded dataset.
    """
    zip_file = URL.split("/")[-1]
    csv_file_path = os.path.join(
        local_cache_path, zip_file.replace(".zip", "")
    )

    maybe_download(URL, zip_file, local_cache_path)

    if not os.path.exists(csv_file_path):
        extract_zip(file_path=csv_file_path, dest_path=local_cache_path)
    return pd.read_csv(csv_file_path)
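Note (illustrative): usage sketch for the new DAC loader; the cache folder is the one used in the Arabic classification notebook above, and the zip is downloaded and extracted on the first call:

```python
# Usage sketch for the Arabic news (DAC) dataset loader (illustrative cache path).
from utils_nlp.dataset.dac import load_pandas_df

df = load_pandas_df(local_cache_path="../../../temp")  # downloads + extracts on first call
print(df.shape)
print(df.columns.tolist())  # per the notebook, a text column and a label column named "targe"
```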
@@ -7,7 +7,6 @@ https://www.nyu.edu/projects/bowman/multinli/
import os
import pandas as pd
import requests
from utils_nlp.dataset.url_utils import extract_zip, maybe_download

URL = "http://www.nyu.edu/projects/bowman/multinli/multinli_1.0.zip"

@@ -19,21 +18,20 @@ DATA_FILES = {

def load_pandas_df(local_cache_path=None, file_split="train"):
    """Downloads and extracts the dataset files
    """Downloads and extracts the dataset files
    Args:
        local_cache_path ([type], optional): [description]. Defaults to None.
        file_split (str, optional): The subset to load.
            One of: {"train", "dev_matched", "dev_mismatched"}
            Defaults to "train".
        One of: {"train", "dev_matched", "dev_mismatched"}
        Defaults to "train".
    Returns:
        pd.DataFrame: pandas DataFrame containing the specified MultiNLI subset.
        pd.DataFrame: pandas DataFrame containing the specified
            MultiNLI subset.
    """

    file_name = URL.split("/")[-1]
    if not os.path.exists(os.path.join(local_cache_path, file_name)):
        response = requests.get(URL)
        with open(os.path.join(local_cache_path, file_name), "wb") as f:
            f.write(response.content)
    maybe_download(URL, file_name, local_cache_path)

    if not os.path.exists(
        os.path.join(local_cache_path, DATA_FILES[file_split])
    ):
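Note (illustrative): usage sketch for the MultiNLI loader cleaned up above; the cache path mirrors the notebooks, and `genre` is the column the classification notebook uses as the label:

```python
# Usage sketch for the MultiNLI loader (illustrative cache path).
from utils_nlp.dataset.multinli import load_pandas_df

train_df = load_pandas_df(local_cache_path="../../../temp", file_split="train")
dev_df = load_pandas_df(local_cache_path="../../../temp", file_split="dev_matched")
print(train_df["genre"].value_counts())  # genre serves as the label in the notebook above
```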
@@ -122,12 +122,8 @@ def to_nltk_tokens(
        pd.DataFrame: Dataframe with new columns token_cols, each containing a
            list of tokens for their respective sentences.
    """

    nltk.download("punkt")
    text_df = df[sentence_cols]
    tok_df = text_df.applymap(
        lambda sentence: nltk.word_tokenize(sentence)
    )
    tok_df = text_df.applymap(lambda sentence: nltk.word_tokenize(sentence))
    tok_df.columns = token_cols
    tokenized = pd.concat([df, tok_df], axis=1)
    return tokenized

@@ -158,11 +154,9 @@ def rm_nltk_stopwords(
    nltk.download("stopwords")
    stop_words = tuple(stopwords.words("english"))
    text_df = df[sentence_cols]
    stop_df = (
        text_df
        .applymap(lambda sentence: nltk.word_tokenize(sentence))
        .applymap(lambda l: [word for word in l if word not in stop_words])
    )
    stop_df = text_df.applymap(
        lambda sentence: nltk.word_tokenize(sentence)
    ).applymap(lambda l: [word for word in l if word not in stop_words])

    stop_df.columns = stop_cols
    return pd.concat([df, stop_df], axis=1)
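Note (illustrative): the helpers above apply NLTK tokenization column-wise over a DataFrame. A usage sketch with a toy DataFrame; the column names passed via `sentence_cols`, `token_cols`, and `stop_cols` are hypothetical examples, not the functions' documented defaults:

```python
# Usage sketch for the NLTK-based preprocessing helpers (toy data, hypothetical column names).
import pandas as pd
from utils_nlp.dataset.preprocess import to_nltk_tokens, rm_nltk_stopwords

df = pd.DataFrame({
    "sentence1": ["A person is training his horse.", "Children are smiling."],
    "sentence2": ["A person is outdoors, on a horse.", "The kids are frowning."],
})

tokenized = to_nltk_tokens(df, sentence_cols=["sentence1", "sentence2"],
                           token_cols=["sentence1_tokens", "sentence2_tokens"])
no_stop = rm_nltk_stopwords(df, sentence_cols=["sentence1", "sentence2"],
                            stop_cols=["sentence1_nostop", "sentence2_nostop"])
print(tokenized.columns.tolist())
```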
@@ -1,14 +1,16 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
import requests
import logging
import math
import os
import tarfile
import zipfile
from contextlib import contextmanager
from tempfile import TemporaryDirectory
from tqdm import tqdm

import requests
from tqdm import tqdm

log = logging.getLogger(__name__)


@@ -17,13 +19,12 @@ def maybe_download(
    url, filename=None, work_directory=".", expected_bytes=None
):
    """Download a file if it is not already downloaded.


    Args:
        filename (str): File name.
        work_directory (str): Working directory.
        url (str): URL of the file to download.
        expected_bytes (int): Expected file size in bytes.

    Returns:
        str: File path of the file downloaded.
    """

@@ -80,7 +81,7 @@ def extract_zip(file_path, dest_path="."):
        raise IOError("File doesn't exist")
    if not os.path.exists(dest_path):
        raise IOError("Destination directory doesn't exist")
    with ZipFile(file_path) as z:
    with zipfile.ZipFile(file_path) as z:
        z.extractall(path=dest_path)
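Note (illustrative): usage sketch for the download utilities tidied above, using the MultiNLI URL that appears elsewhere in this commit; paths are examples only:

```python
# Usage sketch for the download/extract helpers (example URL and paths from this commit).
import os
from utils_nlp.dataset.url_utils import maybe_download, extract_zip

URL = "http://www.nyu.edu/projects/bowman/multinli/multinli_1.0.zip"
work_dir = "../../../temp"

zip_path = maybe_download(URL, filename="multinli_1.0.zip", work_directory=work_dir)
extract_zip(file_path=zip_path, dest_path=work_dir)  # raises IOError if either path is missing
print(os.listdir(work_dir))
```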