зеркало из https://github.com/Azure/nlp-samples.git
bert-train fine tune#2
This commit is contained in:
Родитель
a7328f36bf
Коммит
79cc9ccc89
|
@ -8,6 +8,7 @@ This repo is a collection of NLP samples on Azure, especially for Japanese.
|
|||
|---------|-----------|
|
||||
|[rinna GPT-2 train](./examples/rinna-gpt2-train)|Fine-Tune rinna GPT-2 model with ONNX Runtime|
|
||||
|[rinna GPT-2 predict](./examples/rinna-gpt2-predict)|Convert rinna GPT-2 Model to ONNX with Quantization|
|
||||
|[BERT Fine Tuning](./examples/bert-train)|BERT Train Fine Tuning|
|
||||
|
||||
|
||||
## Contributing
|
||||
|
|
|
@ -8,3 +8,8 @@ Livedoor News コーパスのクラス分類を題材にした BERT モデルの
|
|||
|---------|-----------|
|
||||
|[BERT Fine Tuning (Local)](./local/)| PyTorch Lightnings|
|
||||
<!-- |[BERT Fine Tuning on Azure ML Compute Clusters](./azureml/)| Azure ML CLI 2.0 + PyTorch Lightnings| -->
|
||||
|
||||
|
||||
## Reference
|
||||
- [PyTorch Lightning](https://pytorch-lightning.readthedocs.io/en/latest/) - PyTorch Lightning document.
|
||||
- [azureml-pl-sample](https://github.com/ShuntaIto/azureml-pl-sample) - Sample code for PyTorch Lightning on AzureML.
|
|
@ -2,46 +2,62 @@
|
|||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# BERT Classifictaion Fine Tuning with PyTorch Lightning\n",
|
||||
"ローカル環境で BERT のファインチューニングを行います。"
|
||||
],
|
||||
"metadata": {}
|
||||
"ローカル環境で BERT のファインチューニングを行います。PyTorch Lightning を利用します。"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 0. 事前準備\n",
|
||||
"ターミナルから学習データと Python 環境のセットアップを行います。\n",
|
||||
"### Data\n",
|
||||
"ターミナルで以下のコマンドを実行し、Livedoor ニュースのコーパスデータの前処理を実施します。\n",
|
||||
"```bash\n",
|
||||
"python utils/livedoor-dataprep.py\n",
|
||||
"````"
|
||||
],
|
||||
"metadata": {}
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Python 環境の準備\n",
|
||||
"ターミナルで以下のコマンドを実行し conda 環境を構築してください。\n",
|
||||
"\n",
|
||||
"```bash\n",
|
||||
"$ conda env create --file bert_finetune_local.yml \n",
|
||||
"conda env create --file bert_finetune_local.yml\n",
|
||||
"#ipython kernel install --user --name=bert_finetune_local --display-name=bert_finetune_local # Jupyter を利用する際に必要\n",
|
||||
"```"
|
||||
],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 1. ライブラリのインポート"
|
||||
],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {
|
||||
"gather": {
|
||||
"logged": 1619766078354
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Global seed set to 1234\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
|
@ -56,54 +72,39 @@
|
|||
"pl.seed_everything(1234)\n",
|
||||
"torch.manual_seed(1234)\n",
|
||||
"np.random.seed(1234)"
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stderr",
|
||||
"text": [
|
||||
"Global seed set to 1234\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"gather": {
|
||||
"logged": 1619766078354
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"source": [
|
||||
"# GPU が利用可能か確認\n",
|
||||
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {
|
||||
"gather": {
|
||||
"logged": 1619766078417
|
||||
}
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# GPU が利用可能か確認\n",
|
||||
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 2. データ前処理"
|
||||
],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"source": [
|
||||
"df = pd.read_csv(\"./data/processed/livedoor.tsv\", delimiter='\\t')\n",
|
||||
"df = df.dropna()\n",
|
||||
"df.head()"
|
||||
],
|
||||
"metadata": {
|
||||
"gather": {
|
||||
"logged": 1619766078727
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "execute_result",
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
|
@ -180,75 +181,95 @@
|
|||
"4 livedoor-homme "
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"execution_count": 3
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"gather": {
|
||||
"logged": 1619766078727
|
||||
}
|
||||
}
|
||||
"source": [
|
||||
"df = pd.read_csv(\"./data/processed/livedoor.tsv\", delimiter='\\t')\n",
|
||||
"df = df.dropna()\n",
|
||||
"df.head()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"source": [
|
||||
"X_train, X_test = train_test_split(df, test_size=0.2, stratify=df['label_index'])\n",
|
||||
"X_train.reset_index(drop=True, inplace=True)\n",
|
||||
"X_test.reset_index(drop=True, inplace=True)"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {
|
||||
"gather": {
|
||||
"logged": 1619766078843
|
||||
}
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"X_train, X_test = train_test_split(df, test_size=0.2, stratify=df['label_index'])\n",
|
||||
"X_train.reset_index(drop=True, inplace=True)\n",
|
||||
"X_test.reset_index(drop=True, inplace=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"X_train.to_csv(\"./data/processed/livedoor-train.tsv\", sep='\\t', index=False)\n",
|
||||
"X_test.to_csv(\"./data/processed/livedoor-test.tsv\", sep='\\t', index=False)"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"source": [
|
||||
"train_dataset = datasets.LivedoorDataset(X_train)\n",
|
||||
"train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {
|
||||
"gather": {
|
||||
"logged": 1619766081789
|
||||
}
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_dataset = datasets.LivedoorDataset(X_train)\n",
|
||||
"train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"test_dataset = datasets.LivedoorDataset(X_train)\n",
|
||||
"test_loader = DataLoader(test_dataset, batch_size=16, shuffle=True)"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 3. モデル学習"
|
||||
],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {
|
||||
"gather": {
|
||||
"logged": 1619768011572
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Some weights of the model checkpoint at cl-tohoku/bert-base-japanese-whole-word-masking were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias']\n",
|
||||
"- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
|
||||
"- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
|
||||
"GPU available: True, used: True\n",
|
||||
"TPU available: False, using: 0 TPU cores\n",
|
||||
"IPU available: False, using: 0 IPUs\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model = models.LitBert()\n",
|
||||
"\n",
|
||||
|
@ -261,38 +282,16 @@
|
|||
"\n",
|
||||
"model.to(device)\n",
|
||||
"trainer = pl.Trainer(gpus=1, default_root_dir='pl-model', max_epochs=5)"
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stderr",
|
||||
"text": [
|
||||
"Some weights of the model checkpoint at cl-tohoku/bert-base-japanese-whole-word-masking were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias']\n",
|
||||
"- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
|
||||
"- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
|
||||
"GPU available: True, used: True\n",
|
||||
"TPU available: False, using: 0 TPU cores\n",
|
||||
"IPU available: False, using: 0 IPUs\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"gather": {
|
||||
"logged": 1619768011572
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"source": [
|
||||
"# モデル学習開始\n",
|
||||
"trainer.fit(model, train_dataloader=train_loader, val_dataloaders=test_loader)"
|
||||
],
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/anaconda/envs/bert_finetune_local/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py:531: LightningDeprecationWarning: `trainer.fit(train_dataloader)` is deprecated in v1.4 and will be removed in v1.6. Use `trainer.fit(train_dataloaders)` instead. HINT: added 's'\n",
|
||||
" \"`trainer.fit(train_dataloader)` is deprecated in v1.4 and will be removed in v1.6.\"\n",
|
||||
|
@ -309,22 +308,22 @@
|
|||
]
|
||||
},
|
||||
{
|
||||
"output_type": "display_data",
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "80614176df2a4d8d8ae1f9387c57574d",
|
||||
"version_major": 2,
|
||||
"version_minor": 0,
|
||||
"model_id": "80614176df2a4d8d8ae1f9387c57574d"
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Validation sanity check: 0it [00:00, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {}
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/anaconda/envs/bert_finetune_local/lib/python3.6/site-packages/pytorch_lightning/trainer/data_loading.py:377: UserWarning: Your val_dataloader has `shuffle=True`, it is best practice to turn this off for val/test/predict dataloaders.\n",
|
||||
" f\"Your {mode}_dataloader has `shuffle=True`, it is best practice to turn\"\n",
|
||||
|
@ -336,111 +335,110 @@
|
|||
]
|
||||
},
|
||||
{
|
||||
"output_type": "display_data",
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "430abc63d3674ed58559de21b4563a97",
|
||||
"version_major": 2,
|
||||
"version_minor": 0,
|
||||
"model_id": "430abc63d3674ed58559de21b4563a97"
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Training: -1it [00:00, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {}
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"output_type": "display_data",
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "dd3171cfad284bfeba8ac7f7d92064da",
|
||||
"version_major": 2,
|
||||
"version_minor": 0,
|
||||
"model_id": "dd3171cfad284bfeba8ac7f7d92064da"
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Validating: 0it [00:00, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {}
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"output_type": "display_data",
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "bcc999d271224f429f19bc82e6acfaa3",
|
||||
"version_major": 2,
|
||||
"version_minor": 0,
|
||||
"model_id": "bcc999d271224f429f19bc82e6acfaa3"
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Validating: 0it [00:00, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {}
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"output_type": "display_data",
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "97b315a5eca14e719e04d00d80399a5d",
|
||||
"version_major": 2,
|
||||
"version_minor": 0,
|
||||
"model_id": "97b315a5eca14e719e04d00d80399a5d"
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Validating: 0it [00:00, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {}
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"output_type": "display_data",
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "b5862bbcaa3a4c6b810981d4ddbb7828",
|
||||
"version_major": 2,
|
||||
"version_minor": 0,
|
||||
"model_id": "b5862bbcaa3a4c6b810981d4ddbb7828"
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Validating: 0it [00:00, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {}
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"output_type": "display_data",
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "a21cac1101eb4b039fcc0220b24e10b0",
|
||||
"version_major": 2,
|
||||
"version_minor": 0,
|
||||
"model_id": "a21cac1101eb4b039fcc0220b24e10b0"
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Validating: 0it [00:00, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {}
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"metadata": {}
|
||||
"source": [
|
||||
"# モデル学習開始\n",
|
||||
"trainer.fit(model, train_dataloader=train_loader, val_dataloaders=test_loader)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 4. モデル検証"
|
||||
],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"source": [
|
||||
"# モデル検証\n",
|
||||
"result = trainer.test(model, test_dataloaders=test_loader)\n",
|
||||
"print(result)"
|
||||
],
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/anaconda/envs/bert_finetune_local/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py:679: LightningDeprecationWarning: `trainer.test(test_dataloaders)` is deprecated in v1.4 and will be removed in v1.6. Use `trainer.test(dataloaders)` instead.\n",
|
||||
" \"`trainer.test(test_dataloaders)` is deprecated in v1.4 and will be removed in v1.6.\"\n",
|
||||
|
@ -450,50 +448,58 @@
|
|||
]
|
||||
},
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {}
|
||||
"source": [
|
||||
"# モデル検証\n",
|
||||
"result = trainer.test(model, test_dataloaders=test_loader)\n",
|
||||
"print(result)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# モデル保存\n",
|
||||
"trainer.save_checkpoint(\"./model/bert-livedoor.ckpt\")"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Tensorboard の起動\n",
|
||||
"ターミナルから実行します。\n",
|
||||
"```bash\n",
|
||||
"tensorboard --logdir pl-model/lightning_logs\n",
|
||||
"```"
|
||||
],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [],
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"interpreter": {
|
||||
"hash": "cb548e6cc0dfcc93fff3180785f33996431c026a0645ee6605bd3f7e301d3c90"
|
||||
},
|
||||
"kernel_info": {
|
||||
"name": "python38-azureml"
|
||||
},
|
||||
"kernelspec": {
|
||||
"name": "python3",
|
||||
"display_name": "Python 3.6.13 64-bit ('bert_finetune_local': conda)"
|
||||
"display_name": "Python 3.6.13 64-bit ('bert_finetune_local': conda)",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
|
@ -516,11 +522,8 @@
|
|||
},
|
||||
"nteract": {
|
||||
"version": "nteract-front-end@1.0.0"
|
||||
},
|
||||
"interpreter": {
|
||||
"hash": "cb548e6cc0dfcc93fff3180785f33996431c026a0645ee6605bd3f7e301d3c90"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче