Dev (#99)
* format files
* remove best metric score callback
* update readme
Parent: d571a4e2c2
Commit: 1f326bb5e9
.pre-commit-config.yaml

@@ -28,7 +28,7 @@ repos:
       - id: isort
         # profiles: https://pycqa.github.io/isort/docs/configuration/profiles/
         # other flags: https://pycqa.github.io/isort/docs/configuration/options/
-        args: [--profile, black, --skip, src/train.py, --filter-files]
+        args: [--profile, black, --skip, src/train.py, --skip, run.py, --filter-files]
         # files: "src/.*"
 
   # MyPy (static type checking)
README.md (14 changes)
@@ -571,23 +571,17 @@ Lightning provides convenient method for logging custom metrics from inside Ligh
 
 
 ### Inference
-Template contains simple example of loading model from checkpoint and running predictions.
+Template contains simple example of loading model from checkpoint and running predictions.<br>
 Take a look at [inference_example.py](src/utils/inference_example.py).
 <br><br>
 
 
 ### Callbacks
 Template contains example callbacks for better Weights&Biases integration (see [wandb_callbacks.py](src/callbacks/wandb_callbacks.py)).<br>
-
-To support reproducibility:
-- UploadCodeToWandbAsArtifact
-- UploadCheckpointsToWandbAsArtifact
-- WatchModelWithWandb
-
-To provide examples of logging custom visualisations with callbacks only:
-- LogConfusionMatrixToWandb
-- LogF1PrecisionRecallHeatmapToWandb
-<br>
+To support reproducibility: *UploadCodeToWandbAsArtifact*, *UploadCheckpointsToWandbAsArtifact*, *WatchModelWithWandb*.<br>
+To provide examples of logging custom visualisations with callbacks only: *LogConfusionMatrixToWandb*, *LogF1PrecisionRecallHeatmapToWandb*.<br>
+<br><br>
 
 
 ## Best Practices
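Side note (not part of the diff): the callbacks named above are plain Lightning callbacks, so using them means passing instances to the Trainer. A minimal sketch, assuming a wandb-enabled Trainer; the project name and the WatchModelWithWandb defaults are illustrative, only the code_dir/ckpt_dir parameters come from the wandb_callbacks.py hunks further down.

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from src.callbacks.wandb_callbacks import (
    UploadCheckpointsToWandbAsArtifact,
    UploadCodeToWandbAsArtifact,
    WatchModelWithWandb,
)

trainer = pl.Trainer(
    logger=WandbLogger(project="mnist-template"),  # project name is illustrative
    callbacks=[
        UploadCodeToWandbAsArtifact(code_dir="src/"),
        UploadCheckpointsToWandbAsArtifact(ckpt_dir="checkpoints/", upload_best_only=True),
        WatchModelWithWandb(),  # assumed to work with its defaults
    ],
)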
run.py (5 changes)
@@ -1,4 +1,5 @@
 import hydra
+from hydra.utils import log
 from omegaconf import DictConfig
 
 
@@ -10,7 +11,7 @@ def main(config: DictConfig):
     import dotenv
     from src.train import train
     from src.utils import template_utils
 
     # load environment variables from `.env` file
     dotenv.load_dotenv(dotenv_path=".env", override=True)
@@ -24,7 +25,7 @@ def main(config: DictConfig):
 
     # Pretty print config using Rich library
     if config.get("print_config"):
-        hydra.utils.log.info(f"Pretty printing config with Rich! <{config.print_config=}>")
+        log.info(f"Pretty printing config with Rich! <{config.print_config=}>")
         template_utils.print_config(config, resolve=True)
 
     # Train model
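For context (not part of the diff): the hunks above live inside Hydra's decorated entrypoint. A minimal sketch of that shape; the config_path and config_name values are assumptions, not taken from this commit.

import hydra
from hydra.utils import log
from omegaconf import DictConfig


@hydra.main(config_path="configs/", config_name="config.yaml")  # paths are assumed
def main(config: DictConfig):
    # same guard as in the hunk above: only pretty-print when enabled in config
    if config.get("print_config"):
        log.info(f"Pretty printing config with Rich! <{config.print_config=}>")


if __name__ == "__main__":
    main()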
src/callbacks/wandb_callbacks.py
@@ -26,7 +26,7 @@ def get_wandb_logger(trainer: pl.Trainer) -> WandbLogger:
 
 
 class UploadCodeToWandbAsArtifact(Callback):
-    """Upload all *.py files to wandb as an artifact at the beginning of the run."""
+    """Upload all *.py files to wandb as an artifact, at the beginning of the run."""
 
     def __init__(self, code_dir: str):
         self.code_dir = code_dir
@@ -43,7 +43,7 @@ class UploadCodeToWandbAsArtifact(Callback):
 
 
 class UploadCheckpointsToWandbAsArtifact(Callback):
-    """Upload experiment checkpoints to wandb as an artifact at the end of training."""
+    """Upload checkpoints to wandb as an artifact, at the end of training."""
 
     def __init__(self, ckpt_dir: str = "checkpoints/", upload_best_only: bool = False):
         self.ckpt_dir = ckpt_dir
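The bodies of these two callbacks are not shown in this diff. For orientation, a hypothetical sketch of what the code-upload variant can look like, built only from the public wandb Artifact API and the Lightning Callback hooks; the class name, artifact name, and globbing are illustrative.

import glob

import wandb
from pytorch_lightning import Callback


class CodeArtifactSketch(Callback):
    """Illustrative only: upload all *.py files under code_dir when training starts."""

    def __init__(self, code_dir: str):
        self.code_dir = code_dir

    def on_train_start(self, trainer, pl_module):
        artifact = wandb.Artifact("project-source", type="code")
        for path in glob.glob(f"{self.code_dir}/**/*.py", recursive=True):
            artifact.add_file(path)
        # works when the attached logger is a WandbLogger, whose
        # .experiment attribute is the underlying wandb run
        trainer.logger.experiment.log_artifact(artifact)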
@@ -82,7 +82,6 @@ class LogF1PrecisionRecallHeatmapToWandb(Callback):
-    """
-    Generate f1, precision and recall heatmap from validation step outputs.
+    """Generate f1, precision and recall heatmap from validation step outputs.
     Expects validation step to return predictions and targets.
     Works only for single label classification!
     """
 
     def __init__(self, class_names: List[str] = None):
@@ -136,7 +135,6 @@ class LogConfusionMatrixToWandb(Callback):
-    """
-    Generate Confusion Matrix.
+    """Generate Confusion Matrix.
     Expects validation step to return predictions and targets.
     Works only for single label classification!
     """
 
     def __init__(self, class_names: List[str] = None):
@@ -180,51 +178,3 @@ class LogConfusionMatrixToWandb(Callback):
 
         self.preds = []
         self.targets = []
-
-
-''' BUGGED :(
-class LogBestMetricScoresToWandb(Callback):
-    """
-    Store in wandb:
-    - max train acc
-    - min train loss
-    - max val acc
-    - min val loss
-    Useful for comparing runs in table views, as wandb doesn't currently support column aggregation.
-    """
-
-    def __init__(self):
-        self.train_loss_best = None
-        self.train_acc_best = None
-        self.val_loss_best = None
-        self.val_acc_best = None
-        self.ready = False
-
-    def on_sanity_check_end(self, trainer, pl_module):
-        """Start executing this callback only after all validation sanity checks end."""
-        self.ready = True
-
-    def on_epoch_end(self, trainer, pl_module):
-        if self.ready:
-            logger = get_wandb_logger(trainer=trainer)
-            experiment = logger.experiment
-
-            metrics = trainer.callback_metrics
-
-            if not self.train_loss_best or metrics["train/loss"] < self.train_loss_best:
-                self.train_loss_best = metrics["train_loss"]
-
-            if not self.train_acc_best or metrics["train/acc"] > self.train_acc_best:
-                self.train_acc_best = metrics["train/acc"]
-
-            if not self.val_loss_best or metrics["val/loss"] < self.val_loss_best:
-                self.val_loss_best = metrics["val/loss"]
-
-            if not self.val_acc_best or metrics["val/acc"] > self.val_acc_best:
-                self.val_acc_best = metrics["val/acc"]
-
-            experiment.log({"train/loss_best": self.train_loss_best}, commit=False)
-            experiment.log({"train/acc_best": self.train_acc_best}, commit=False)
-            experiment.log({"val/loss_best": self.val_loss_best}, commit=False)
-            experiment.log({"val/acc_best": self.val_acc_best}, commit=False)
-'''
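Worth noting (not part of the commit): one visible bug in the removed block is the inconsistent metric key, metrics["train_loss"] on the assignment where every other access uses metrics["train/loss"]. A hypothetical corrected sketch of the same best-score-tracking idea, with consistent keys:

def update_best_scores(best: dict, metrics: dict) -> dict:
    """Track running min/max per metric; `best` maps metric name to best value so far."""
    for key, better in [("train/loss", min), ("train/acc", max),
                        ("val/loss", min), ("val/acc", max)]:
        if key in metrics:
            best[key] = metrics[key] if key not in best else better(best[key], metrics[key])
    return best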
@@ -5,5 +5,9 @@ mnist_train_transforms = transforms.Compose(
 )
 
 mnist_test_transforms = transforms.Compose(
-    [transforms.ToTensor(), transforms.Resize((28, 28)), transforms.Normalize((0.1307,), (0.3081,))]
+    [
+        transforms.ToTensor(),
+        transforms.Resize((28, 28)),
+        transforms.Normalize((0.1307,), (0.3081,)),
+    ]
 )
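The multi-line form matches Black's one-element-per-line (magic trailing comma) style, consistent with this commit's "format files" message. Usage sketch (assumed, not shown in this commit): the composed pipeline is passed to a dataset, which applies it per sample.

from torchvision.datasets import MNIST

test_set = MNIST(root="data/", train=False, download=True, transform=mnist_test_transforms)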
src/utils/inference_example.py
@@ -17,7 +17,7 @@ def predict():
     # model __init__ parameters will be loaded from ckpt automatically
     # you can also pass some parameter explicitly to override it
     trained_model = LitModelMNIST.load_from_checkpoint(checkpoint_path=CKPT_PATH)
 
     # print model hyperparameters
     print(trained_model.hparams)
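As the comment in this hunk says, load_from_checkpoint restores the saved __init__ parameters but lets you override them: extra keyword arguments are forwarded to the model constructor in place of the stored hparams. The parameter name below is hypothetical, for illustration only.

# `lr` is an illustrative hyperparameter name, not taken from this repo
trained_model = LitModelMNIST.load_from_checkpoint(checkpoint_path=CKPT_PATH, lr=0.0005)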
src/utils/template_utils.py
@@ -89,9 +89,9 @@ def print_config(
 
     # TODO print main config path and experiment config path
     # print(f"Main config path: [link file://{directory}]{directory}")
 
     # TODO refactor the whole method
 
     style = "dim"
     tree = Tree(f":gear: CONFIG", style=style, guide_style=style)
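The rest of print_config is not shown in this diff. A minimal sketch of how such a Rich tree is typically filled and rendered, assuming an OmegaConf DictConfig; the branch layout is illustrative, not the repo's actual implementation.

import rich
from omegaconf import DictConfig, OmegaConf
from rich.tree import Tree


def print_config_sketch(config: DictConfig, resolve: bool = True):
    style = "dim"
    tree = Tree(":gear: CONFIG", style=style, guide_style=style)
    for section in config:
        # one branch per top-level config group, with its YAML dump as the leaf
        branch = tree.add(str(section), style=style, guide_style=style)
        branch.add(OmegaConf.to_yaml(config[section], resolve=resolve))
    rich.print(tree)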