[Fix] text-classification PL example (#6027)
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
Родитель
eb2bd8d6eb
Коммит
ffceef2042
|
@ -73,7 +73,7 @@ class BaseTransformer(pl.LightningModule):
|
||||||
# self.save_hyperparameters()
|
# self.save_hyperparameters()
|
||||||
# can also expand arguments into trainer signature for easier reading
|
# can also expand arguments into trainer signature for easier reading
|
||||||
|
|
||||||
self.hparams = hparams
|
self.save_hyperparameters(hparams)
|
||||||
self.step_count = 0
|
self.step_count = 0
|
||||||
self.output_dir = Path(self.hparams.output_dir)
|
self.output_dir = Path(self.hparams.output_dir)
|
||||||
cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None
|
cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None
|
||||||
|
@ -245,7 +245,7 @@ class BaseTransformer(pl.LightningModule):
|
||||||
|
|
||||||
class LoggingCallback(pl.Callback):
|
class LoggingCallback(pl.Callback):
|
||||||
def on_batch_end(self, trainer, pl_module):
|
def on_batch_end(self, trainer, pl_module):
|
||||||
lrs = {f"lr_group_{i}": lr for i, lr in enumerate(self.lr_scheduler.get_lr())}
|
lrs = {f"lr_group_{i}": param["lr"] for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups)}
|
||||||
pl_module.logger.log_metrics(lrs)
|
pl_module.logger.log_metrics(lrs)
|
||||||
|
|
||||||
def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
|
def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
|
||||||
|
@ -278,6 +278,10 @@ def add_generic_args(parser, root_dir) -> None:
|
||||||
help="The output directory where the model predictions and checkpoints will be written.",
|
help="The output directory where the model predictions and checkpoints will be written.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--gpus", default=0, type=int, help="The number of GPUs allocated for this, it is by default 0 meaning none",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--fp16",
|
"--fp16",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
|
@ -291,7 +295,7 @@ def add_generic_args(parser, root_dir) -> None:
|
||||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||||
"See details at https://nvidia.github.io/apex/amp.html",
|
"See details at https://nvidia.github.io/apex/amp.html",
|
||||||
)
|
)
|
||||||
parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int, default=0)
|
parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int)
|
||||||
parser.add_argument("--max_grad_norm", dest="gradient_clip_val", default=1.0, type=float, help="Max gradient norm")
|
parser.add_argument("--max_grad_norm", dest="gradient_clip_val", default=1.0, type=float, help="Max gradient norm")
|
||||||
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
||||||
parser.add_argument("--do_predict", action="store_true", help="Whether to run predictions on the test set.")
|
parser.add_argument("--do_predict", action="store_true", help="Whether to run predictions on the test set.")
|
||||||
|
|
|
@ -23,7 +23,7 @@ mkdir -p $OUTPUT_DIR
|
||||||
# Add parent directory to python path to access lightning_base.py
|
# Add parent directory to python path to access lightning_base.py
|
||||||
export PYTHONPATH="../":"${PYTHONPATH}"
|
export PYTHONPATH="../":"${PYTHONPATH}"
|
||||||
|
|
||||||
python3 run_pl_glue.py --data_dir $DATA_DIR \
|
python3 run_pl_glue.py --gpus 1 --data_dir $DATA_DIR \
|
||||||
--task $TASK \
|
--task $TASK \
|
||||||
--model_name_or_path $BERT_MODEL \
|
--model_name_or_path $BERT_MODEL \
|
||||||
--output_dir $OUTPUT_DIR \
|
--output_dir $OUTPUT_DIR \
|
||||||
|
|
|
@ -3,6 +3,7 @@ import glob
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
from argparse import Namespace
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
@ -24,6 +25,8 @@ class GLUETransformer(BaseTransformer):
|
||||||
mode = "sequence-classification"
|
mode = "sequence-classification"
|
||||||
|
|
||||||
def __init__(self, hparams):
|
def __init__(self, hparams):
|
||||||
|
if type(hparams) == dict:
|
||||||
|
hparams = Namespace(**hparams)
|
||||||
hparams.glue_output_mode = glue_output_modes[hparams.task]
|
hparams.glue_output_mode = glue_output_modes[hparams.task]
|
||||||
num_labels = glue_tasks_num_labels[hparams.task]
|
num_labels = glue_tasks_num_labels[hparams.task]
|
||||||
|
|
||||||
|
@ -41,7 +44,8 @@ class GLUETransformer(BaseTransformer):
|
||||||
outputs = self(**inputs)
|
outputs = self(**inputs)
|
||||||
loss = outputs[0]
|
loss = outputs[0]
|
||||||
|
|
||||||
tensorboard_logs = {"loss": loss, "rate": self.lr_scheduler.get_last_lr()[-1]}
|
# tensorboard_logs = {"loss": loss, "rate": self.lr_scheduler.get_last_lr()[-1]}
|
||||||
|
tensorboard_logs = {"loss": loss}
|
||||||
return {"loss": loss, "log": tensorboard_logs}
|
return {"loss": loss, "log": tensorboard_logs}
|
||||||
|
|
||||||
def prepare_data(self):
|
def prepare_data(self):
|
||||||
|
@ -71,7 +75,7 @@ class GLUETransformer(BaseTransformer):
|
||||||
logger.info("Saving features into cached file %s", cached_features_file)
|
logger.info("Saving features into cached file %s", cached_features_file)
|
||||||
torch.save(features, cached_features_file)
|
torch.save(features, cached_features_file)
|
||||||
|
|
||||||
def load_dataset(self, mode, batch_size):
|
def get_dataloader(self, mode: int, batch_size: int, shuffle: bool) -> DataLoader:
|
||||||
"Load datasets. Called after prepare data."
|
"Load datasets. Called after prepare data."
|
||||||
|
|
||||||
# We test on dev set to compare to benchmarks without having to submit to GLUE server
|
# We test on dev set to compare to benchmarks without having to submit to GLUE server
|
||||||
|
|
Загрузка…
Ссылка в новой задаче