зеркало из https://github.com/Azure/nlp-samples.git
This commit is contained in:
Родитель
27ddb9e95c
Коммит
703d67142a
|
@ -0,0 +1,2 @@
|
|||
Portions of the source code are based on the [transformers](https://github.com/huggingface/transformers) project, which is licensed under Apache 2.0.
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
{
|
||||
"Registrations": [
|
||||
{
|
||||
"component": {
|
||||
"type": "git",
|
||||
"git": {
|
||||
"commitHash": "9a0a8c1c6f4f2f0c80ff07d36713a3ada785eec5",
|
||||
"repositoryUrl": "https://github.com/huggingface/transformers.git"
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"Version": 1
|
||||
}
|
|
@ -0,0 +1,49 @@
|
|||
|
||||
# ONNX Runtime Training Module for PyTorch
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# Usage
|
||||
# Build: docker build -f Dockerfile.ort-cu102-cudnn7-devel-ubuntu18.04 -t [image-name] .
|
||||
# Run: docker run -it --gpus all --name [run-name] [image-name]:latest /bin/bash
|
||||
# Example:
|
||||
# docker build -f Dockerfile.ort-cu102-cudnn7-devel-ubuntu18.04 -t ort.cu102 .
|
||||
# docker run -it --gpus all --name my-experiments ort.cu102:latest /bin/bash
|
||||
|
||||
# CUDA development image for building sources
|
||||
FROM nvidia/cuda:10.2-cudnn7-devel-ubuntu18.04 as builder
|
||||
|
||||
# Install and update tools to minimize security vulnerabilities
|
||||
RUN apt-get update
|
||||
RUN apt-get install -y software-properties-common wget apt-utils patchelf git libprotobuf-dev protobuf-compiler cmake
|
||||
RUN unattended-upgrade
|
||||
RUN apt-get autoremove -y
|
||||
|
||||
# Python and pip
|
||||
RUN update-alternatives --install /usr/bin/python python /usr/bin/python3 1
|
||||
RUN apt-get install -y python3-pip
|
||||
RUN update-alternatives --install /usr/bin/pip pip /usr/bin/pip3 1
|
||||
RUN pip install --upgrade pip
|
||||
|
||||
# PyTorch
|
||||
RUN pip install onnx ninja
|
||||
RUN pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
|
||||
|
||||
# ORT Module
|
||||
RUN pip install onnxruntime-training==1.8.0
|
||||
RUN pip install torch-ort
|
||||
|
||||
WORKDIR /stage
|
||||
|
||||
#Install huggingface transformers
|
||||
RUN cd /stage && git clone https://github.com/microsoft/huggingface-transformers.git &&\
|
||||
cd huggingface-transformers &&\
|
||||
git checkout raviskolli/ort_t5 &&\
|
||||
pip install -e .
|
||||
|
||||
# Install AzureML support and commonly used packages.
|
||||
RUN pip install azureml-defaults wget fairscale
|
||||
RUN pip install sacrebleu datasets deepspeed
|
||||
RUN pip install scipy sklearn accelerate
|
||||
RUN pip install sentencepiece protobuf
|
||||
RUN pip install azureml-mlflow mlflow
|
|
@ -0,0 +1,13 @@
|
|||
name: rinna-gpt2
|
||||
channels:
|
||||
- conda-forge
|
||||
dependencies:
|
||||
- python=3.6
|
||||
- pip
|
||||
- pip:
|
||||
- transformers
|
||||
- scikit-learn
|
||||
- torch
|
||||
- ipywidgets
|
||||
- sentencepiece
|
||||
- azureml-core
|
|
@ -0,0 +1,348 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"# rinna GPT-2 モデルの Fine Tuning\n",
|
||||
"HuggingFace の transformers ライブラリを用いて [rinna gpt-2](https://huggingface.co/rinna/japanese-gpt2-medium) モデルの Fine Tuning を行います。"
|
||||
],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## 事前準備\n",
|
||||
"必要なライブラリをインポートします。"
|
||||
],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"from azureml.core import Experiment, Workspace, Environment\n",
|
||||
"from azureml.core.compute import ComputeTarget\n",
|
||||
"from azureml.core import ScriptRunConfig\n",
|
||||
"from azureml.core.runconfig import PyTorchConfiguration\n",
|
||||
"\n",
|
||||
"import os\n",
|
||||
"os.makedirs('src', exist_ok=True)"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"Azure ML Workspace へ接続します。"
|
||||
],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"ws = Workspace.from_config()"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"実験 Experiment の名称"
|
||||
],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"model_experiment = Experiment(ws, name=\"rinna-gpt2-exp\")"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"分散学習の設定"
|
||||
],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"distr_config = PyTorchConfiguration(process_count=1, node_count=1)"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"環境 Environment の設定"
|
||||
],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"hf_ort_env = Environment.from_dockerfile(name='rinna-docker-env', dockerfile='Dockerfile')"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"学習コードの準備"
|
||||
],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"%%writefile src/train.py\n",
|
||||
"#!/usr/bin/env python\n",
|
||||
"# coding=utf-8\n",
|
||||
"# Copyright 2020 The HuggingFace Inc. team. All rights reserved.\n",
|
||||
"#\n",
|
||||
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
|
||||
"# you may not use this file except in compliance with the License.\n",
|
||||
"# You may obtain a copy of the License at\n",
|
||||
"#\n",
|
||||
"# http://www.apache.org/licenses/LICENSE-2.0\n",
|
||||
"#\n",
|
||||
"# Unless required by applicable law or agreed to in writing, software\n",
|
||||
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
|
||||
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
|
||||
"# See the License for the specific language governing permissions and\n",
|
||||
"# limitations under the License.\n",
|
||||
"\n",
|
||||
"import io\n",
|
||||
"import sys\n",
|
||||
"from azureml.core import Run\n",
|
||||
"import argparse\n",
|
||||
"import mlflow\n",
|
||||
"from datasets import load_dataset\n",
|
||||
"from transformers import (AutoModelForCausalLM,\n",
|
||||
" DataCollatorForLanguageModeling, T5Tokenizer,\n",
|
||||
" TextDataset, Trainer, TrainerCallback,\n",
|
||||
" TrainingArguments, default_data_collator)\n",
|
||||
"\n",
|
||||
"# 日本語対応\n",
|
||||
"sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')\n",
|
||||
"\n",
|
||||
"# 引数\n",
|
||||
"parser = argparse.ArgumentParser()\n",
|
||||
"\n",
|
||||
"parser.add_argument('--max_steps', type=int, default=100)\n",
|
||||
"parser.add_argument('--output_dir', type=str)\n",
|
||||
"parser.add_argument('--model_name_or_path', default='rinna/japanese-gpt2-medium')\n",
|
||||
"\n",
|
||||
"args = parser.parse_args()\n",
|
||||
"\n",
|
||||
"# Azure ML 事前準備\n",
|
||||
"run = Run.get_context()\n",
|
||||
"ws = run.experiment.workspace\n",
|
||||
"\n",
|
||||
"# mlflow trackinr uri の設定\n",
|
||||
"mlflow.set_tracking_uri(ws.get_mlflow_tracking_uri())\n",
|
||||
"\n",
|
||||
"# tokenizer, model オブジェクトのロード\n",
|
||||
"tokenizer = T5Tokenizer.from_pretrained(\"rinna/japanese-gpt2-medium\", do_lower_case=True)\n",
|
||||
"model = AutoModelForCausalLM.from_pretrained(\"rinna/japanese-gpt2-medium\")\n",
|
||||
"model.resize_token_embeddings(len(tokenizer))\n",
|
||||
"\n",
|
||||
"# データセット\n",
|
||||
"train_path = 'train.txt'\n",
|
||||
"test_path = 'test.txt'\n",
|
||||
"\n",
|
||||
"train_dataset = TextDataset(tokenizer=tokenizer, file_path=train_path, block_size=512)\n",
|
||||
"eval_dataset = TextDataset(tokenizer=tokenizer, file_path=test_path, block_size=512)\n",
|
||||
"data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)\n",
|
||||
"\n",
|
||||
"# mlflow でログを取るための callback クラス\n",
|
||||
"class MyCallback(TrainerCallback):\n",
|
||||
" def __init__(self, azureml_run=None):\n",
|
||||
" self.mlflow = mlflow\n",
|
||||
"\n",
|
||||
" def on_log(self, args, state, control, logs=None, **kwargs):\n",
|
||||
" if state.is_world_process_zero:\n",
|
||||
" for k, v in logs.items():\n",
|
||||
" if isinstance(v, (int, float)):\n",
|
||||
" self.mlflow.log_metric(k, v, step=state.global_step)\n",
|
||||
"\n",
|
||||
"# Trainer 引数\n",
|
||||
"training_args = TrainingArguments(\n",
|
||||
" output_dir=\"./outputs\", \n",
|
||||
" overwrite_output_dir=True, \n",
|
||||
" max_steps=args.max_steps,\n",
|
||||
" per_device_train_batch_size=1,\n",
|
||||
" per_device_eval_batch_size=1,\n",
|
||||
" do_train=True,\n",
|
||||
" do_eval=True,\n",
|
||||
" evaluation_strategy=\"steps\",\n",
|
||||
" eval_steps=50,\n",
|
||||
" fp16=True,\n",
|
||||
" report_to=[\"none\"],\n",
|
||||
" ort=True,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# モデル学習の設定\n",
|
||||
"trainer = Trainer(\n",
|
||||
" model=model,\n",
|
||||
" args=training_args,\n",
|
||||
" data_collator=data_collator,\n",
|
||||
" train_dataset=train_dataset,\n",
|
||||
" eval_dataset=eval_dataset,\n",
|
||||
" callbacks=[MyCallback]\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# モデル学習開始\n",
|
||||
"trainer.train()\n",
|
||||
"\n",
|
||||
"# モデルの保存\n",
|
||||
"trainer.save_model()\n",
|
||||
"\n",
|
||||
"# モデルの検証\n",
|
||||
"tokenizer = T5Tokenizer.from_pretrained(\"rinna/japanese-gpt2-medium\", do_lower_case=True)\n",
|
||||
"model = AutoModelForCausalLM.from_pretrained(\"./outputs\")\n",
|
||||
"\n",
|
||||
"input = tokenizer.encode(\"仕事\", return_tensors=\"pt\")\n",
|
||||
"output = model.generate(input, do_sample=True, max_length=100, num_return_sequences=100)\n",
|
||||
"print(tokenizer.batch_decode(output))\n",
|
||||
"\n",
|
||||
"input = tokenizer.encode(\"料理\", return_tensors=\"pt\")\n",
|
||||
"output = model.generate(input, do_sample=True, max_length=100, num_return_sequences=100)\n",
|
||||
"print(tokenizer.batch_decode(output))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"input = tokenizer.encode(\"握手をしたら、\", return_tensors=\"pt\")\n",
|
||||
"output = model.generate(input, do_sample=True, max_length=100, num_return_sequences=100)\n",
|
||||
"print(tokenizer.batch_decode(output))"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"スクリプトの引数の定義"
|
||||
],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"script_params = ['--max_steps', 100, '--output_dir', './outputs', '--model_name_or_path', 'rinna/japanese-gpt2-medium']"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## モデル学習\n",
|
||||
"`ScriptRunConfig` を用いて Azure Machine Learning Compute Cluster 上で学習ができるように設定します。"
|
||||
],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"model_run_config = ScriptRunConfig(\n",
|
||||
" source_directory='./src',\n",
|
||||
" script='./train.py',\n",
|
||||
" arguments=script_params,\n",
|
||||
" compute_target=ComputeTarget(workspace=ws, name=\"gpuinstance\"),\n",
|
||||
" environment=hf_ort_env,\n",
|
||||
" distributed_job_config=distr_config)"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"モデル学習の開始"
|
||||
],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"run = model_experiment.submit(model_run_config)\n",
|
||||
"run"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"run.wait_for_completion(show_output=True)"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## モデルテスト\n",
|
||||
"ローカル環境でモデルの推論を行います。Run の outputs フォルダのモデルファイルをダウンロード & ロードして利用します。"
|
||||
],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"run_test = ws.get_run(run.id)\n",
|
||||
"run_test.run.download_files(prefix='outputs/models/', output_directory='./')"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"from transformers import T5Tokenizer, AutoModelForCausalLM\n",
|
||||
"\n",
|
||||
"tokenizer = T5Tokenizer.from_pretrained(\"outputs/models/\", do_lower_case=True)\n",
|
||||
"model = AutoModelForCausalLM.from_pretrained(\"outputs/models/\")"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"input = tokenizer.encode(\"こんにちは、\", return_tensors=\"pt\")\n",
|
||||
"output = model.generate(input, do_sample=True, max_length=100, num_return_sequences=10)\n",
|
||||
"print(tokenizer.batch_decode(output))"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"orig_nbformat": 4,
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
|
@ -0,0 +1,119 @@
|
|||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import io
|
||||
import sys
|
||||
from azureml.core import Run
|
||||
import argparse
|
||||
import mlflow
|
||||
from datasets import load_dataset
|
||||
from transformers import (AutoModelForCausalLM,
|
||||
DataCollatorForLanguageModeling, T5Tokenizer,
|
||||
TextDataset, Trainer, TrainerCallback,
|
||||
TrainingArguments, default_data_collator)
|
||||
|
||||
# 日本語対応
|
||||
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
|
||||
|
||||
# 引数
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument('--max_steps', type=int, default=100)
|
||||
parser.add_argument('--output_dir', type=str)
|
||||
parser.add_argument('--model_name_or_path', default='rinna/japanese-gpt2-medium')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Azure ML 事前準備
|
||||
run = Run.get_context()
|
||||
ws = run.experiment.workspace
|
||||
|
||||
# mlflow trackinr uri の設定
|
||||
mlflow.set_tracking_uri(ws.get_mlflow_tracking_uri())
|
||||
|
||||
# tokenizer, model オブジェクトのロード
|
||||
tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-gpt2-medium", do_lower_case=True)
|
||||
model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt2-medium")
|
||||
model.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
# データセット
|
||||
train_path = 'train.txt'
|
||||
test_path = 'test.txt'
|
||||
|
||||
train_dataset = TextDataset(tokenizer=tokenizer, file_path=train_path, block_size=512)
|
||||
eval_dataset = TextDataset(tokenizer=tokenizer, file_path=test_path, block_size=512)
|
||||
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
||||
|
||||
# mlflow でログを取るための callback クラス
|
||||
class MyCallback(TrainerCallback):
|
||||
def __init__(self, azureml_run=None):
|
||||
self.mlflow = mlflow
|
||||
|
||||
def on_log(self, args, state, control, logs=None, **kwargs):
|
||||
if state.is_world_process_zero:
|
||||
for k, v in logs.items():
|
||||
if isinstance(v, (int, float)):
|
||||
self.mlflow.log_metric(k, v, step=state.global_step)
|
||||
|
||||
# Trainer 引数
|
||||
training_args = TrainingArguments(
|
||||
output_dir="./outputs",
|
||||
overwrite_output_dir=True,
|
||||
max_steps=args.max_steps,
|
||||
per_device_train_batch_size=1,
|
||||
per_device_eval_batch_size=1,
|
||||
do_train=True,
|
||||
do_eval=True,
|
||||
evaluation_strategy="steps",
|
||||
eval_steps=50,
|
||||
fp16=True,
|
||||
report_to=["none"],
|
||||
ort=True,
|
||||
)
|
||||
|
||||
|
||||
# モデル学習の設定
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
data_collator=data_collator,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
callbacks=[MyCallback]
|
||||
)
|
||||
|
||||
# モデル学習開始
|
||||
trainer.train()
|
||||
|
||||
# モデルの保存
|
||||
trainer.save_model()
|
||||
|
||||
# モデルの検証
|
||||
tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-gpt2-medium", do_lower_case=True)
|
||||
model = AutoModelForCausalLM.from_pretrained("./outputs")
|
||||
|
||||
input = tokenizer.encode("仕事", return_tensors="pt")
|
||||
output = model.generate(input, do_sample=True, max_length=100, num_return_sequences=100)
|
||||
print(tokenizer.batch_decode(output))
|
||||
|
||||
input = tokenizer.encode("料理", return_tensors="pt")
|
||||
output = model.generate(input, do_sample=True, max_length=100, num_return_sequences=100)
|
||||
print(tokenizer.batch_decode(output))
|
||||
|
||||
|
||||
input = tokenizer.encode("握手をしたら、", return_tensors="pt")
|
||||
output = model.generate(input, do_sample=True, max_length=100, num_return_sequences=100)
|
||||
print(tokenizer.batch_decode(output))
|
Загрузка…
Ссылка в новой задаче