* format files

* remove best metric score callback

* update readme
This commit is contained in:
Łukasz Zalewski 2021-03-15 17:24:56 +01:00 коммит произвёл GitHub
Родитель d571a4e2c2
Коммит 1f326bb5e9
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
7 изменённых файлов: 18 добавлений и 69 удалений

Просмотреть файл

@ -28,7 +28,7 @@ repos:
- id: isort
# profiles: https://pycqa.github.io/isort/docs/configuration/profiles/
# other flags: https://pycqa.github.io/isort/docs/configuration/options/
args: [--profile, black, --skip, src/train.py, --filter-files]
args: [--profile, black, --skip, src/train.py, --skip, run.py, --filter-files]
# files: "src/.*"
# MyPy (static type checking)

Просмотреть файл

@ -571,23 +571,17 @@ Lightning provides convenient method for logging custom metrics from inside Ligh
### Inference
Template contains simple example of loading model from checkpoint and running predictions.
Template contains simple example of loading model from checkpoint and running predictions.<br>
Take a look at [inference_example.py](src/utils/inference_example.py).
<br><br>
### Callbacks
Template contains example callbacks for better Weights&Biases integration (see [wandb_callbacks.py](src/callbacks/wandb_callbacks.py)).<br>
To support reproducibility: *UploadCodeToWandbAsArtifact*, *UploadCheckpointsToWandbAsArtifact*, *WatchModelWithWandb*.<br>
To provide examples of logging custom visualisations with callbacks only: *LogConfusionMatrixToWandb*, *LogF1PrecisionRecallHeatmapToWandb*.<br>
<br><br>
To support reproducibility:
- UploadCodeToWandbAsArtifact
- UploadCheckpointsToWandbAsArtifact
- WatchModelWithWandb
To provide examples of logging custom visualisations with callbacks only:
- LogConfusionMatrixToWandb
- LogF1PrecisionRecallHeatmapToWandb
<br>
## Best Practices

5
run.py
Просмотреть файл

@ -1,4 +1,5 @@
import hydra
from hydra.utils import log
from omegaconf import DictConfig
@ -10,7 +11,7 @@ def main(config: DictConfig):
import dotenv
from src.train import train
from src.utils import template_utils
# load environment variables from `.env` file
dotenv.load_dotenv(dotenv_path=".env", override=True)
@ -24,7 +25,7 @@ def main(config: DictConfig):
# Pretty print config using Rich library
if config.get("print_config"):
hydra.utils.log.info(f"Pretty printing config with Rich! <{config.print_config=}>")
log.info(f"Pretty printing config with Rich! <{config.print_config=}>")
template_utils.print_config(config, resolve=True)
# Train model

Просмотреть файл

@ -26,7 +26,7 @@ def get_wandb_logger(trainer: pl.Trainer) -> WandbLogger:
class UploadCodeToWandbAsArtifact(Callback):
"""Upload all *.py files to wandb as an artifact at the beginning of the run."""
"""Upload all *.py files to wandb as an artifact, at the beginning of the run."""
def __init__(self, code_dir: str):
self.code_dir = code_dir
@ -43,7 +43,7 @@ class UploadCodeToWandbAsArtifact(Callback):
class UploadCheckpointsToWandbAsArtifact(Callback):
"""Upload experiment checkpoints to wandb as an artifact at the end of training."""
"""Upload checkpoints to wandb as an artifact, at the end of training."""
def __init__(self, ckpt_dir: str = "checkpoints/", upload_best_only: bool = False):
self.ckpt_dir = ckpt_dir
@ -82,7 +82,6 @@ class LogF1PrecisionRecallHeatmapToWandb(Callback):
"""
Generate f1, precision and recall heatmap from validation step outputs.
Expects validation step to return predictions and targets.
Works only for single label classification!
"""
def __init__(self, class_names: List[str] = None):
@ -136,7 +135,6 @@ class LogConfusionMatrixToWandb(Callback):
"""
Generate Confusion Matrix.
Expects validation step to return predictions and targets.
Works only for single label classification!
"""
def __init__(self, class_names: List[str] = None):
@ -180,51 +178,3 @@ class LogConfusionMatrixToWandb(Callback):
self.preds = []
self.targets = []
''' BUGGED :(
class LogBestMetricScoresToWandb(Callback):
"""
Store in wandb:
- max train acc
- min train loss
- max val acc
- min val loss
Useful for comparing runs in table views, as wandb doesn't currently support column aggregation.
"""
def __init__(self):
self.train_loss_best = None
self.train_acc_best = None
self.val_loss_best = None
self.val_acc_best = None
self.ready = False
def on_sanity_check_end(self, trainer, pl_module):
"""Start executing this callback only after all validation sanity checks end."""
self.ready = True
def on_epoch_end(self, trainer, pl_module):
if self.ready:
logger = get_wandb_logger(trainer=trainer)
experiment = logger.experiment
metrics = trainer.callback_metrics
if not self.train_loss_best or metrics["train/loss"] < self.train_loss_best:
self.train_loss_best = metrics["train_loss"]
if not self.train_acc_best or metrics["train/acc"] > self.train_acc_best:
self.train_acc_best = metrics["train/acc"]
if not self.val_loss_best or metrics["val/loss"] < self.val_loss_best:
self.val_loss_best = metrics["val/loss"]
if not self.val_acc_best or metrics["val/acc"] > self.val_acc_best:
self.val_acc_best = metrics["val/acc"]
experiment.log({"train/loss_best": self.train_loss_best}, commit=False)
experiment.log({"train/acc_best": self.train_acc_best}, commit=False)
experiment.log({"val/loss_best": self.val_loss_best}, commit=False)
experiment.log({"val/acc_best": self.val_acc_best}, commit=False)
'''

Просмотреть файл

@ -5,5 +5,9 @@ mnist_train_transforms = transforms.Compose(
)
mnist_test_transforms = transforms.Compose(
[transforms.ToTensor(), transforms.Resize((28, 28)), transforms.Normalize((0.1307,), (0.3081,))]
[
transforms.ToTensor(),
transforms.Resize((28, 28)),
transforms.Normalize((0.1307,), (0.3081,)),
]
)

Просмотреть файл

@ -17,7 +17,7 @@ def predict():
# model __init__ parameters will be loaded from ckpt automatically
# you can also pass some parameter explicitly to override it
trained_model = LitModelMNIST.load_from_checkpoint(checkpoint_path=CKPT_PATH)
# print model hyperparameters
print(trained_model.hparams)

Просмотреть файл

@ -89,9 +89,9 @@ def print_config(
# TODO print main config path and experiment config path
# print(f"Main config path: [link file://{directory}]{directory}")
# TODO refactor the whole method
style = "dim"
tree = Tree(f":gear: CONFIG", style=style, guide_style=style)