Keita Onabuta 2021-10-13 13:52:33 +00:00
Parent a7328f36bf
Commit 79cc9ccc89
3 changed files with 154 additions and 145 deletions


@ -8,6 +8,7 @@ This repo is a collection of NLP samples on Azure, especially for Japanese.
|---------|-----------|
|[rinna GPT-2 train](./examples/rinna-gpt2-train)|Fine-Tune rinna GPT-2 model with ONNX Runtime|
|[rinna GPT-2 predict](./examples/rinna-gpt2-predict)|Convert rinna GPT-2 Model to ONNX with Quantization|
|[BERT Fine Tuning](./examples/bert-train)|Fine-Tune BERT with PyTorch Lightning|
## Contributing


@ -8,3 +8,8 @@ Using Livedoor News corpus classification as the task, the BERT model's
|---------|-----------|
|[BERT Fine Tuning (Local)](./local/)| PyTorch Lightning|
<!-- |[BERT Fine Tuning on Azure ML Compute Clusters](./azureml/)| Azure ML CLI 2.0 + PyTorch Lightning| -->
## Reference
- [PyTorch Lightning](https://pytorch-lightning.readthedocs.io/en/latest/) - PyTorch Lightning documentation.
- [azureml-pl-sample](https://github.com/ShuntaIto/azureml-pl-sample) - Sample code for PyTorch Lightning on AzureML.


@ -2,46 +2,62 @@
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# BERT Classification Fine Tuning with PyTorch Lightning\n",
"Fine-tune BERT in a local environment."
],
"metadata": {}
"Fine-tune BERT in a local environment, using PyTorch Lightning."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 0. Prerequisites\n",
"Set up the training data and the Python environment from a terminal.\n",
"### Data\n",
"Run the following command in a terminal to preprocess the Livedoor News corpus data.\n",
"```bash\n",
"python utils/livedoor-dataprep.py\n",
"````"
],
"metadata": {}
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Preparing the Python environment\n",
"Run the following command in a terminal to create the conda environment.\n",
"\n",
"```bash\n",
"$ conda env create --file bert_finetune_local.yml \n",
"conda env create --file bert_finetune_local.yml\n",
"#ipython kernel install --user --name=bert_finetune_local --display-name=bert_finetune_local # required when using Jupyter\n",
"```"
],
"metadata": {}
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Import libraries"
],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"gather": {
"logged": 1619766078354
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Global seed set to 1234\n"
]
}
],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
@ -56,54 +72,39 @@
"pl.seed_everything(1234)\n",
"torch.manual_seed(1234)\n",
"np.random.seed(1234)"
],
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"Global seed set to 1234\n"
]
}
],
"metadata": {
"gather": {
"logged": 1619766078354
}
}
]
},
{
"cell_type": "code",
"execution_count": 2,
"source": [
"# Check whether a GPU is available\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
],
"outputs": [],
"metadata": {
"gather": {
"logged": 1619766078417
}
}
},
"outputs": [],
"source": [
"# Check whether a GPU is available\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Data preprocessing"
],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": 3,
"source": [
"df = pd.read_csv(\"./data/processed/livedoor.tsv\", delimiter='\\t')\n",
"df = df.dropna()\n",
"df.head()"
],
"metadata": {
"gather": {
"logged": 1619766078727
}
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
@ -180,75 +181,95 @@
"4 livedoor-homme "
]
},
"execution_count": 3,
"metadata": {},
"execution_count": 3
"output_type": "execute_result"
}
],
"metadata": {
"gather": {
"logged": 1619766078727
}
}
"source": [
"df = pd.read_csv(\"./data/processed/livedoor.tsv\", delimiter='\\t')\n",
"df = df.dropna()\n",
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"source": [
"X_train, X_test = train_test_split(df, test_size=0.2, stratify=df['label_index'])\n",
"X_train.reset_index(drop=True, inplace=True)\n",
"X_test.reset_index(drop=True, inplace=True)"
],
"outputs": [],
"metadata": {
"gather": {
"logged": 1619766078843
}
}
},
"outputs": [],
"source": [
"X_train, X_test = train_test_split(df, test_size=0.2, stratify=df['label_index'])\n",
"X_train.reset_index(drop=True, inplace=True)\n",
"X_test.reset_index(drop=True, inplace=True)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"X_train.to_csv(\"./data/processed/livedoor-train.tsv\", sep='\\t', index=False)\n",
"X_test.to_csv(\"./data/processed/livedoor-test.tsv\", sep='\\t', index=False)"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": 6,
"source": [
"train_dataset = datasets.LivedoorDataset(X_train)\n",
"train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)"
],
"outputs": [],
"metadata": {
"gather": {
"logged": 1619766081789
}
}
},
"outputs": [],
"source": [
"train_dataset = datasets.LivedoorDataset(X_train)\n",
"train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"test_dataset = datasets.LivedoorDataset(X_test)\n",
"test_loader = DataLoader(test_dataset, batch_size=16, shuffle=True)"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Model training"
],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"gather": {
"logged": 1619768011572
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of the model checkpoint at cl-tohoku/bert-base-japanese-whole-word-masking were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias']\n",
"- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
"GPU available: True, used: True\n",
"TPU available: False, using: 0 TPU cores\n",
"IPU available: False, using: 0 IPUs\n"
]
}
],
"source": [
"model = models.LitBert()\n",
"\n",
@ -261,38 +282,16 @@
"\n",
"model.to(device)\n",
"trainer = pl.Trainer(gpus=1, default_root_dir='pl-model', max_epochs=5)"
],
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"Some weights of the model checkpoint at cl-tohoku/bert-base-japanese-whole-word-masking were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias']\n",
"- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
"GPU available: True, used: True\n",
"TPU available: False, using: 0 TPU cores\n",
"IPU available: False, using: 0 IPUs\n"
]
}
],
"metadata": {
"gather": {
"logged": 1619768011572
}
}
]
},
{
"cell_type": "code",
"execution_count": 9,
"source": [
"# Start model training\n",
"trainer.fit(model, train_dataloader=train_loader, val_dataloaders=test_loader)"
],
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"output_type": "stream",
"text": [
"/anaconda/envs/bert_finetune_local/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py:531: LightningDeprecationWarning: `trainer.fit(train_dataloader)` is deprecated in v1.4 and will be removed in v1.6. Use `trainer.fit(train_dataloaders)` instead. HINT: added 's'\n",
" \"`trainer.fit(train_dataloader)` is deprecated in v1.4 and will be removed in v1.6.\"\n",
@ -309,22 +308,22 @@
]
},
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "80614176df2a4d8d8ae1f9387c57574d",
"version_major": 2,
"version_minor": 0,
"model_id": "80614176df2a4d8d8ae1f9387c57574d"
"version_minor": 0
},
"text/plain": [
"Validation sanity check: 0it [00:00, ?it/s]"
]
},
"metadata": {}
"metadata": {},
"output_type": "display_data"
},
{
"output_type": "stream",
"name": "stderr",
"output_type": "stream",
"text": [
"/anaconda/envs/bert_finetune_local/lib/python3.6/site-packages/pytorch_lightning/trainer/data_loading.py:377: UserWarning: Your val_dataloader has `shuffle=True`, it is best practice to turn this off for val/test/predict dataloaders.\n",
" f\"Your {mode}_dataloader has `shuffle=True`, it is best practice to turn\"\n",
@ -336,111 +335,110 @@
]
},
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "430abc63d3674ed58559de21b4563a97",
"version_major": 2,
"version_minor": 0,
"model_id": "430abc63d3674ed58559de21b4563a97"
"version_minor": 0
},
"text/plain": [
"Training: -1it [00:00, ?it/s]"
]
},
"metadata": {}
"metadata": {},
"output_type": "display_data"
},
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "dd3171cfad284bfeba8ac7f7d92064da",
"version_major": 2,
"version_minor": 0,
"model_id": "dd3171cfad284bfeba8ac7f7d92064da"
"version_minor": 0
},
"text/plain": [
"Validating: 0it [00:00, ?it/s]"
]
},
"metadata": {}
"metadata": {},
"output_type": "display_data"
},
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "bcc999d271224f429f19bc82e6acfaa3",
"version_major": 2,
"version_minor": 0,
"model_id": "bcc999d271224f429f19bc82e6acfaa3"
"version_minor": 0
},
"text/plain": [
"Validating: 0it [00:00, ?it/s]"
]
},
"metadata": {}
"metadata": {},
"output_type": "display_data"
},
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "97b315a5eca14e719e04d00d80399a5d",
"version_major": 2,
"version_minor": 0,
"model_id": "97b315a5eca14e719e04d00d80399a5d"
"version_minor": 0
},
"text/plain": [
"Validating: 0it [00:00, ?it/s]"
]
},
"metadata": {}
"metadata": {},
"output_type": "display_data"
},
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b5862bbcaa3a4c6b810981d4ddbb7828",
"version_major": 2,
"version_minor": 0,
"model_id": "b5862bbcaa3a4c6b810981d4ddbb7828"
"version_minor": 0
},
"text/plain": [
"Validating: 0it [00:00, ?it/s]"
]
},
"metadata": {}
"metadata": {},
"output_type": "display_data"
},
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a21cac1101eb4b039fcc0220b24e10b0",
"version_major": 2,
"version_minor": 0,
"model_id": "a21cac1101eb4b039fcc0220b24e10b0"
"version_minor": 0
},
"text/plain": [
"Validating: 0it [00:00, ?it/s]"
]
},
"metadata": {}
"metadata": {},
"output_type": "display_data"
}
],
"metadata": {}
"source": [
"# Start model training\n",
"trainer.fit(model, train_dataloader=train_loader, val_dataloaders=test_loader)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. Model evaluation"
],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": 10,
"source": [
"# Evaluate the model\n",
"result = trainer.test(model, test_dataloaders=test_loader)\n",
"print(result)"
],
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"output_type": "stream",
"text": [
"/anaconda/envs/bert_finetune_local/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py:679: LightningDeprecationWarning: `trainer.test(test_dataloaders)` is deprecated in v1.4 and will be removed in v1.6. Use `trainer.test(dataloaders)` instead.\n",
" \"`trainer.test(test_dataloaders)` is deprecated in v1.4 and will be removed in v1.6.\"\n",
@ -450,50 +448,58 @@
]
},
{
"output_type": "stream",
"name": "stdout",
"output_type": "stream",
"text": [
"[]\n"
]
}
],
"metadata": {}
"source": [
"# Evaluate the model\n",
"result = trainer.test(model, test_dataloaders=test_loader)\n",
"print(result)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# Save the model\n",
"trainer.save_checkpoint(\"./model/bert-livedoor.ckpt\")"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Launching TensorBoard\n",
"Run this from a terminal.\n",
"```bash\n",
"tensorboard --logdir pl-model/lightning_logs\n",
"```"
],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"source": [],
"metadata": {},
"outputs": [],
"metadata": {}
"source": []
}
],
"metadata": {
"interpreter": {
"hash": "cb548e6cc0dfcc93fff3180785f33996431c026a0645ee6605bd3f7e301d3c90"
},
"kernel_info": {
"name": "python38-azureml"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3.6.13 64-bit ('bert_finetune_local': conda)"
"display_name": "Python 3.6.13 64-bit ('bert_finetune_local': conda)",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
@ -516,11 +522,8 @@
},
"nteract": {
"version": "nteract-front-end@1.0.0"
},
"interpreter": {
"hash": "cb548e6cc0dfcc93fff3180785f33996431c026a0645ee6605bd3f7e301d3c90"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
}
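
Most of the notebook hunks in this commit only reorder JSON keys (for example `metadata` moving ahead of `source`, or `name` ahead of `output_type`), which is consistent with an editor that writes keys in alphabetical order. The following is a minimal sketch, assuming only Python's standard `json` module, of normalizing key order before committing so that such reordering does not show up as a change. `normalize_notebook` is a hypothetical helper for illustration; it is not part of this repository or of Jupyter.

```python
import json


def normalize_notebook(text: str) -> str:
    """Re-serialize notebook JSON with alphabetically sorted keys.

    Hypothetical helper: it changes only key order and whitespace,
    never the values stored in the notebook.
    """
    notebook = json.loads(text)
    return json.dumps(notebook, sort_keys=True, indent=1, ensure_ascii=False) + "\n"


# Two serializations of the same markdown cell that differ only in key order,
# mirroring the old and new sides of the hunks in this commit.
old_cell = '{"cell_type": "markdown", "source": ["## 1. Import libraries"], "metadata": {}}'
new_cell = '{"cell_type": "markdown", "metadata": {}, "source": ["## 1. Import libraries"]}'

# After normalization both sides serialize to the same text.
same = normalize_notebook(old_cell) == normalize_notebook(new_cell)
print(same)
```

Running a step like this on both sides before diffing leaves only substantive edits visible, such as the corrected closing code fence and the revised introduction cell.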