Add `examples/multiple-choice/run_swag_no_trainer.py` (#10934)
* Initial commit * Another bunch of updates * make style quliaty + delete debug arg from bash script * Use compue_metrics func * Do a few fixes * Add copyright * Fix typos
This commit is contained in:
Родитель
ae6b6963ad
Коммит
5057213bcc
|
@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
-->
|
||||
|
||||
## Multiple Choice
|
||||
# Multiple Choice
|
||||
|
||||
Based on the script [`run_swag.py`]().
|
||||
|
||||
|
@ -41,6 +41,73 @@ eval_acc = 0.8338998300509847
|
|||
eval_loss = 0.44457291918821606
|
||||
```
|
||||
|
||||
## PyTorch version, no Trainer
|
||||
|
||||
Based on the script [run_ner_no_trainer.py](https://github.com/huggingface/transformers/blob/master/examples/multiple-choice/run_swag_no_trainer.py).
|
||||
|
||||
Like `run_swag.py`, this script allows you to fine-tune any of the models on the [hub](https://huggingface.co/models) (as long as its architecture as a `ForMultipleChoice` version in the library) on
|
||||
the SWAG dataset or your own data in a csv or a JSON file. The main difference is that this
|
||||
script exposes the bare training loop, to allow you to quickly experiment and add any customization you would like.
|
||||
|
||||
It offers less options than the script with `Trainer` (but you can easily change the options for the optimizer
|
||||
or the dataloaders directly in the script) but still run in a distributed setup, on TPU and supports mixed precision by
|
||||
the mean of the [🤗 `Accelerate`](https://github.com/huggingface/accelerate) library. You can use the script normally
|
||||
after installing it:
|
||||
|
||||
```bash
|
||||
pip install accelerate
|
||||
```
|
||||
|
||||
then
|
||||
|
||||
```bash
|
||||
export DATASET_NAME=swag
|
||||
|
||||
python run_swag_no_trainer.py \
|
||||
--model_name_or_path bert-base-cased \
|
||||
--dataset_name $DATASET_NAME \
|
||||
--max_seq_length 128 \
|
||||
--per_device_train_batch_size 32 \
|
||||
--learning_rate 2e-5 \
|
||||
--num_train_epochs 3 \
|
||||
--output_dir /tmp/$DATASET_NAME/
|
||||
```
|
||||
|
||||
You can then use your usual launchers to run in it in a distributed environment, but the easiest way is to run
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
and reply to the questions asked. Then
|
||||
|
||||
```bash
|
||||
accelerate test
|
||||
```
|
||||
|
||||
that will check everything is ready for training. Finally, you can launch training with
|
||||
|
||||
```bash
|
||||
export DATASET_NAME=swag
|
||||
|
||||
accelerate launch run_swag_no_trainer.py \
|
||||
--model_name_or_path bert-base-cased \
|
||||
--dataset_name $DATASET_NAME \
|
||||
--max_seq_length 128 \
|
||||
--per_device_train_batch_size 32 \
|
||||
--learning_rate 2e-5 \
|
||||
--num_train_epochs 3 \
|
||||
--output_dir /tmp/$DATASET_NAME/
|
||||
```
|
||||
|
||||
This command is the same and will work for:
|
||||
|
||||
- a CPU-only setup
|
||||
- a setup with one GPU
|
||||
- a distributed training with several GPUs (single or multi node)
|
||||
- a training on TPUs
|
||||
|
||||
Note that this library is in alpha release so your feedback is more than welcome if you encounter any problem using it.
|
||||
|
||||
## Tensorflow
|
||||
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
accelerate launch run_swag_no_trainer.py \
|
||||
--model_name_or_path bert-base-uncased \
|
||||
--dataset_name swag \
|
||||
--output_dir /tmp/test-swag-no-trainer \
|
||||
--pad_to_max_length
|
|
@ -0,0 +1,488 @@
|
|||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Fine-tuning a 🤗 Transformers model on multiple choice relying on the accelerate library without using a Trainer.
|
||||
"""
|
||||
# You can also adapt this script on your own multiple choice task. Pointers for this are left as comments.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
from datasets import load_dataset, load_metric
|
||||
from torch.utils.data.dataloader import DataLoader
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
import transformers
|
||||
from accelerate import Accelerator
|
||||
from transformers import (
|
||||
CONFIG_MAPPING,
|
||||
MODEL_MAPPING,
|
||||
AdamW,
|
||||
AutoConfig,
|
||||
AutoModelForMultipleChoice,
|
||||
AutoTokenizer,
|
||||
PreTrainedTokenizerBase,
|
||||
SchedulerType,
|
||||
default_data_collator,
|
||||
get_scheduler,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.file_utils import PaddingStrategy
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# You should update this to your particular problem to have better documentation of `model_type`
|
||||
MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
|
||||
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Finetune a transformers model on a text classification task")
|
||||
parser.add_argument(
|
||||
"--dataset_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The name of the dataset to use (via the datasets library).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset_config_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The configuration name of the dataset to use (via the datasets library).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_file", type=str, default=None, help="A csv or a json file containing the training data."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_length",
|
||||
type=int,
|
||||
default=128,
|
||||
help=(
|
||||
"The maximum total input sequence length after tokenization. Sequences longer than this will be truncated,"
|
||||
" sequences shorter will be padded if `--pad_to_max_lengh` is passed."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pad_to_max_length",
|
||||
action="store_true",
|
||||
help="If passed, pad all samples to `max_length`. Otherwise, dynamic padding is used.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
type=str,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||
required=True,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Pretrained config name or path if not the same as model_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_slow_tokenizer",
|
||||
action="store_true",
|
||||
help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--per_device_train_batch_size",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Batch size (per device) for the training dataloader.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--per_device_eval_batch_size",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Batch size (per device) for the evaluation dataloader.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
type=float,
|
||||
default=5e-5,
|
||||
help="Initial learning rate (after the potential warmup period) to use.",
|
||||
)
|
||||
parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
|
||||
parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
|
||||
parser.add_argument(
|
||||
"--max_train_steps",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_scheduler_type",
|
||||
type=SchedulerType,
|
||||
default="linear",
|
||||
help="The scheduler type to use.",
|
||||
choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
|
||||
)
|
||||
parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
|
||||
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Model type to use if training from scratch.",
|
||||
choices=MODEL_TYPES,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--debug",
|
||||
action="store_true",
|
||||
help="Activate debug mode and run training only with a subset of data.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
if args.output_dir is not None:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
return args
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataCollatorForMultipleChoice:
|
||||
"""
|
||||
Data collator that will dynamically pad the inputs for multiple choice received.
|
||||
|
||||
Args:
|
||||
tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
|
||||
The tokenizer used for encoding the data.
|
||||
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`True`):
|
||||
Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
|
||||
among:
|
||||
|
||||
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
|
||||
sequence if provided).
|
||||
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
|
||||
maximum acceptable input length for the model if that argument is not provided.
|
||||
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
|
||||
different lengths).
|
||||
max_length (:obj:`int`, `optional`):
|
||||
Maximum length of the returned list and optionally padding length (see above).
|
||||
pad_to_multiple_of (:obj:`int`, `optional`):
|
||||
If set will pad the sequence to a multiple of the provided value.
|
||||
|
||||
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
|
||||
7.5 (Volta).
|
||||
"""
|
||||
|
||||
tokenizer: PreTrainedTokenizerBase
|
||||
padding: Union[bool, str, PaddingStrategy] = True
|
||||
max_length: Optional[int] = None
|
||||
pad_to_multiple_of: Optional[int] = None
|
||||
|
||||
def __call__(self, features):
|
||||
label_name = "label" if "label" in features[0].keys() else "labels"
|
||||
labels = [feature.pop(label_name) for feature in features]
|
||||
batch_size = len(features)
|
||||
num_choices = len(features[0]["input_ids"])
|
||||
flattened_features = [
|
||||
[{k: v[i] for k, v in feature.items()} for i in range(num_choices)] for feature in features
|
||||
]
|
||||
flattened_features = sum(flattened_features, [])
|
||||
|
||||
batch = self.tokenizer.pad(
|
||||
flattened_features,
|
||||
padding=self.padding,
|
||||
max_length=self.max_length,
|
||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
# Un-flatten
|
||||
batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
|
||||
# Add back labels
|
||||
batch["labels"] = torch.tensor(labels, dtype=torch.int64)
|
||||
return batch
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
|
||||
accelerator = Accelerator()
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
)
|
||||
logger.info(accelerator.state)
|
||||
|
||||
# Setup logging, we only want one process per machine to log things on the screen.
|
||||
# accelerator.is_local_main_process is only True for one process per machine.
|
||||
logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
|
||||
if accelerator.is_local_main_process:
|
||||
datasets.utils.logging.set_verbosity_warning()
|
||||
transformers.utils.logging.set_verbosity_info()
|
||||
else:
|
||||
datasets.utils.logging.set_verbosity_error()
|
||||
transformers.utils.logging.set_verbosity_error()
|
||||
|
||||
# If passed along, set the training seed now.
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed)
|
||||
|
||||
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
|
||||
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
|
||||
# (the dataset will be downloaded automatically from the datasets Hub).
|
||||
#
|
||||
# For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
|
||||
# 'text' is found. You can easily tweak this behavior (see below).
|
||||
#
|
||||
# In distributed training, the load_dataset function guarantee that only one local process can concurrently
|
||||
# download the dataset.
|
||||
if args.dataset_name is not None:
|
||||
# Downloading and loading a dataset from the hub.
|
||||
raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name)
|
||||
else:
|
||||
data_files = {}
|
||||
if args.train_file is not None:
|
||||
data_files["train"] = args.train_file
|
||||
if args.validation_file is not None:
|
||||
data_files["validation"] = args.validation_file
|
||||
extension = args.train_file.split(".")[-1]
|
||||
raw_datasets = load_dataset(extension, data_files=data_files)
|
||||
# Trim a number of training examples
|
||||
if args.debug:
|
||||
for split in raw_datasets.keys():
|
||||
raw_datasets[split] = raw_datasets[split].select(range(100))
|
||||
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
||||
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
||||
|
||||
if raw_datasets["train"] is not None:
|
||||
column_names = raw_datasets["train"].column_names
|
||||
else:
|
||||
column_names = raw_datasets["validation"].column_names
|
||||
|
||||
# When using your own dataset or a different dataset from swag, you will probably need to change this.
|
||||
ending_names = [f"ending{i}" for i in range(4)]
|
||||
context_name = "sent1"
|
||||
question_header_name = "sent2"
|
||||
label_column_name = "label" if "label" in column_names else "labels"
|
||||
|
||||
# Load pretrained model and tokenizer
|
||||
#
|
||||
# In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
|
||||
# download model & vocab.
|
||||
if args.config_name:
|
||||
config = AutoConfig.from_pretrained(args.model_name_or_path)
|
||||
elif args.model_name_or_path:
|
||||
config = AutoConfig.from_pretrained(args.model_name_or_path)
|
||||
else:
|
||||
config = CONFIG_MAPPING[args.model_type]()
|
||||
logger.warning("You are instantiating a new config instance from scratch.")
|
||||
|
||||
if args.tokenizer_name:
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=not args.use_slow_tokenizer)
|
||||
elif args.model_name_or_path:
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer)
|
||||
else:
|
||||
raise ValueError(
|
||||
"You are instantiating a new tokenizer from scratch. This is not supported by this script."
|
||||
"You can do it from another script, save it, and load it from here, using --tokenizer_name."
|
||||
)
|
||||
|
||||
if args.model_name_or_path:
|
||||
model = AutoModelForMultipleChoice.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||
config=config,
|
||||
)
|
||||
else:
|
||||
logger.info("Training new model from scratch")
|
||||
model = AutoModelForMultipleChoice.from_config(config)
|
||||
|
||||
model.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
# Preprocessing the datasets.
|
||||
# First we tokenize all the texts.
|
||||
padding = "max_length" if args.pad_to_max_length else False
|
||||
|
||||
def preprocess_function(examples):
|
||||
first_sentences = [[context] * 4 for context in examples[context_name]]
|
||||
question_headers = examples[question_header_name]
|
||||
second_sentences = [
|
||||
[f"{header} {examples[end][i]}" for end in ending_names] for i, header in enumerate(question_headers)
|
||||
]
|
||||
labels = examples[label_column_name]
|
||||
|
||||
# Flatten out
|
||||
first_sentences = sum(first_sentences, [])
|
||||
second_sentences = sum(second_sentences, [])
|
||||
|
||||
# Tokenize
|
||||
tokenized_examples = tokenizer(
|
||||
first_sentences,
|
||||
second_sentences,
|
||||
max_length=args.max_length,
|
||||
padding=padding,
|
||||
truncation=True,
|
||||
)
|
||||
# Un-flatten
|
||||
tokenized_inputs = {k: [v[i : i + 4] for i in range(0, len(v), 4)] for k, v in tokenized_examples.items()}
|
||||
tokenized_inputs["labels"] = labels
|
||||
return tokenized_inputs
|
||||
|
||||
processed_datasets = raw_datasets.map(
|
||||
preprocess_function, batched=True, remove_columns=raw_datasets["train"].column_names
|
||||
)
|
||||
|
||||
train_dataset = processed_datasets["train"]
|
||||
eval_dataset = processed_datasets["validation"]
|
||||
|
||||
# Log a few random samples from the training set:
|
||||
for index in random.sample(range(len(train_dataset)), 3):
|
||||
logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
|
||||
|
||||
# DataLoaders creation:
|
||||
if args.pad_to_max_length:
|
||||
# If padding was already done ot max length, we use the default data collator that will just convert everything
|
||||
# to tensors.
|
||||
data_collator = default_data_collator
|
||||
else:
|
||||
# Otherwise, `DataCollatorWithPadding` will apply dynamic padding for us (by padding to the maximum length of
|
||||
# the samples passed). When using mixed precision, we add `pad_to_multiple_of=8` to pad all tensors to multiple
|
||||
# of 8s, which will enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta).
|
||||
data_collator = DataCollatorForMultipleChoice(
|
||||
tokenizer, pad_to_multiple_of=(8 if accelerator.use_fp16 else None)
|
||||
)
|
||||
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size
|
||||
)
|
||||
eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size)
|
||||
|
||||
# Optimizer
|
||||
# Split weights in two groups, one with weight decay and the other not.
|
||||
no_decay = ["bias", "LayerNorm.weight"]
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
||||
"weight_decay": args.weight_decay,
|
||||
},
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
|
||||
"weight_decay": 0.0,
|
||||
},
|
||||
]
|
||||
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
|
||||
|
||||
# Use the device given by the `accelerator` object.
|
||||
device = accelerator.device
|
||||
model.to(device)
|
||||
|
||||
# Prepare everything with our `accelerator`.
|
||||
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
|
||||
model, optimizer, train_dataloader, eval_dataloader
|
||||
)
|
||||
|
||||
# Note -> the training dataloader needs to be prepared before we grab his length below (cause its length will be
|
||||
# shorter in multiprocess)
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if args.max_train_steps is None:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
else:
|
||||
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
name=args.lr_scheduler_type,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=args.num_warmup_steps,
|
||||
num_training_steps=args.max_train_steps,
|
||||
)
|
||||
|
||||
# Metrics
|
||||
metric = load_metric("accuracy")
|
||||
|
||||
# Train!
|
||||
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(f" Num examples = {len(train_dataset)}")
|
||||
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
||||
logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
|
||||
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
||||
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
||||
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
||||
# Only show the progress bar once on each machine.
|
||||
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
|
||||
completed_steps = 0
|
||||
|
||||
for epoch in range(args.num_train_epochs):
|
||||
model.train()
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
loss = loss / args.gradient_accumulation_steps
|
||||
accelerator.backward(loss)
|
||||
if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
progress_bar.update(1)
|
||||
completed_steps += 1
|
||||
|
||||
if completed_steps >= args.max_train_steps:
|
||||
break
|
||||
|
||||
model.eval()
|
||||
for step, batch in enumerate(eval_dataloader):
|
||||
with torch.no_grad():
|
||||
outputs = model(**batch)
|
||||
predictions = outputs.logits.argmax(dim=-1)
|
||||
metric.add_batch(
|
||||
predictions=accelerator.gather(predictions),
|
||||
references=accelerator.gather(batch["labels"]),
|
||||
)
|
||||
|
||||
eval_metric = metric.compute()
|
||||
accelerator.print(f"epoch {epoch}: {eval_metric}")
|
||||
|
||||
if args.output_dir is not None:
|
||||
accelerator.wait_for_everyone()
|
||||
unwrapped_model = accelerator.unwrap_model(model)
|
||||
unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Загрузка…
Ссылка в новой задаче