* refactor folder names

* define python logger correctly

* refactor rich printing method

* specify model parameters explicitly

* specify datamodule params explicitly

* refactor

* update readme
This commit is contained in:
Łukasz Zalewski 2021-03-25 01:31:52 +01:00 коммит произвёл GitHub
Родитель fca999f527
Коммит 381b5844ed
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
18 изменённых файлов: 133 добавлений и 168 удалений

Просмотреть файл

@ -30,9 +30,3 @@ repos:
# other flags: https://pycqa.github.io/isort/docs/configuration/options/
args: [--profile, black, --skip, src/train.py, --skip, run.py, --filter-files]
# files: "src/.*"
# MyPy (static type checking)
# - repo: https://github.com/pre-commit/mirrors-mypy
# rev: v0.790
# hooks:
# - id: mypy

Просмотреть файл

@ -88,11 +88,11 @@ It makes your code neatly organized and provides lots of useful features, like a
The directory structure of new project looks like this:
```
├── configs <- Hydra configuration files
│ ├── trainer <- Configurations of Lightning trainers
│ ├── datamodule <- Configurations of Lightning datamodules
│ ├── model <- Configurations of Lightning models
│ ├── datamodule <- Configurations of Lightning datamodules
│ ├── callbacks <- Configurations of Lightning callbacks
│ ├── logger <- Configurations of Lightning loggers
│ ├── trainer <- Configurations of Lightning trainers
│ ├── optimizer <- Configurations of optimizers
│ ├── experiment <- Configurations of experiments
│ │
@ -109,11 +109,10 @@ The directory structure of new project looks like this:
├── src
│ ├── architectures <- PyTorch model architectures
│ ├── callbacks <- PyTorch Lightning callbacks
│ ├── datamodules <- PyTorch Lightning datamodules
│ ├── datasets <- PyTorch datasets
│ ├── models <- PyTorch Lightning models
│ ├── transforms <- Data transformations
│ ├── pl_callbacks <- PyTorch Lightning callbacks
│ ├── pl_datamodules <- PyTorch Lightning datamodules
│ ├── pl_models <- PyTorch Lightning models
│ ├── utils <- Utility scripts
│ │ ├── inference_example.py <- Example of inference with trained model
│ │ └── template_utils.py <- Some extra template utilities
@ -167,7 +166,7 @@ python run.py trainer.max_epochs=20 optimizer.lr=1e-4
```
> *You can also add new parameters with `+` sign.*
```yaml
python run.py +trainer.new_param="uwu"
python run.py +model.new_param="uwu"
```
@ -214,7 +213,7 @@ wandb:
# link to wandb dashboard should appear in the terminal
python run.py logger=wandb
```
> **Click [here](https://wandb.ai/hobglob/template-dashboard/) to see example wandb dashboard generated with this template.**
> ***Click [here](https://wandb.ai/hobglob/template-dashboard/) to see example wandb dashboard generated with this template.***
</details>
@ -222,7 +221,7 @@ python run.py logger=wandb
<details>
<summary>Train model with chosen experiment config</summary>
> Experiment configurations are placed in folder `configs/experiment/`.
> *Experiment configurations are placed in [configs/experiment/](configs/experiment/).*
```yaml
python run.py +experiment=exp_example_simple
```
@ -234,7 +233,7 @@ python run.py +experiment=exp_example_simple
<summary>Attach some callbacks to run</summary>
> *Callbacks can be used for things such as as model checkpointing, early stopping and [many more](https://pytorch-lightning.readthedocs.io/en/latest/extensions/callbacks.html#built-in-callbacks).<br>
Callbacks configurations are placed in `configs/callbacks/`.*
Callbacks configurations are placed in [configs/callbacks/](configs/callbacks/).*
```yaml
python run.py callbacks=default_callbacks
```
@ -393,10 +392,10 @@ defaults:
- logger: null # set logger here or use command line (e.g. `python run.py logger=wandb`)
# path to original working directory (that `run.py` was executed from in command line)
# hydra hijacks working directory by changing it to the current log directory,
# so it's useful to have path to original working directory as a special variable
# read more here: https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory
# path to original working directory
# hydra hijacks working directory by changing it to the current log directory
# so it's useful to have this path as a special variable
# learn more here: https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory
work_dir: ${hydra:runtime.cwd}
@ -421,8 +420,8 @@ disable_warnings: False
disable_lightning_logs: False
# output paths for hydra logs
hydra:
# output paths for hydra logs
run:
dir: logs/runs/${now:%Y-%m-%d}/${now:%H-%M-%S}
sweep:
@ -499,24 +498,21 @@ trainer:
gradient_clip_val: 0.5
model:
_target_: src.models.mnist_model.LitModelMNIST
_target_: src.pl_models.mnist_model.MNISTLitModel
input_size: 784
lin1_size: 256
dropout1: 0.30
lin2_size: 256
dropout2: 0.25
lin3_size: 128
dropout3: 0.20
output_size: 10
optimizer:
_target_: torch.optim.Adam
lr: 0.001
eps: 1e-08
weight_decay: 0
weight_decay: 0.0005
datamodule:
_target_: src.datamodules.mnist_datamodule.MNISTDataModule
_target_: src.pl_datamodules.mnist_datamodule.MNISTDataModule
data_dir: ${data_dir}
batch_size: 64
train_val_test_split: [55_000, 5_000, 10_000]
@ -533,9 +529,9 @@ logger:
<br>
### Workflow
1. Write your PyTorch Lightning model (see [mnist_model.py](src/models/mnist_model.py) for example)
2. Write your PyTorch Lightning datamodule (see [mnist_datamodule.py](src/datamodules/mnist_datamodule.py) for example)
3. Write your experiment config, containing paths to your model and datamodule (see [configs/experiment](configs/experiment) for examples)
1. Write your PyTorch Lightning model (see [mnist_model.py](src/pl_models/mnist_model.py) for example)
2. Write your PyTorch Lightning datamodule (see [mnist_datamodule.py](src/pl_datamodules/mnist_datamodule.py) for example)
3. Write your experiment config, containing paths to your model and datamodule (see [configs/experiment](configs/experiment/) for examples)
4. Run training with chosen experiment config:<br>
```yaml
python run.py +experiment=experiment_name
@ -595,7 +591,7 @@ These tools help you keep track of hyperparameters and output metrics and allow
You can use many of them at once (see [configs/logger/many_loggers.yaml](configs/logger/many_loggers.yaml) for example).<br>
You can also write your own logger.<br>
Lightning provides convenient method for logging custom metrics from inside LightningModule. Read the docs [here](https://pytorch-lightning.readthedocs.io/en/latest/extensions/logging.html#automatic-logging) or take a look at [MNIST example](src/models/mnist_model.py).
Lightning provides convenient method for logging custom metrics from inside LightningModule. Read the docs [here](https://pytorch-lightning.readthedocs.io/en/latest/extensions/logging.html#automatic-logging) or take a look at [MNIST example](src/pl_models/mnist_model.py).
<br><br>
@ -607,9 +603,9 @@ Take a look at [inference_example.py](src/utils/inference_example.py).
### Callbacks
Template contains example callbacks for better Weights&Biases integration (see [wandb_callbacks.py](src/callbacks/wandb_callbacks.py)).<br>
Template contains example callbacks for better Weights&Biases integration (see [wandb_callbacks.py](src/pl_callbacks/wandb_callbacks.py)).<br>
To support reproducibility: *UploadCodeToWandbAsArtifact*, *UploadCheckpointsToWandbAsArtifact*, *WatchModelWithWandb*.<br>
To provide examples of logging custom visualisations with callbacks only: *LogConfusionMatrixToWandb*, *LogF1PrecisionRecallHeatmapToWandb*.<br>
To provide examples of logging custom visualisations with callbacks only: *LogConfusionMatrixToWandb*, *LogF1PrecRecHeatmapToWandb*.<br>
<br><br>
@ -711,8 +707,8 @@ pip install git+git://github.com/YourGithubName/your-repo-name.git --upgrade
```
So any file can be easily imported into any other file like so:
```python
from project_name.models.mnist_model import LitModelMNIST
from project_name.datamodules.mnist_datamodule import MNISTDataModule
from project_name.pl_models.mnist_model import MNISTLitModel
from project_name.pl_datamodules.mnist_datamodule import MNISTDataModule
```
<br>
@ -839,22 +835,24 @@ pip install -r requirements.txt
Train model with default configuration
```yaml
# default
python run.py
# train on CPU
python run.py trainer.gpus=0
# train on GPU
python run.py trainer.gpus=1
```
Train model with chosen experiment configuration
Train model with chosen experiment configuration from [configs/experiment/](configs/experiment/)
```yaml
# experiment configurations are placed in folder `configs/experiment/`
python run.py +experiment=exp_example_simple
python run.py +experiment=experiment_name
```
You can override any parameter from command line like this
```yaml
python run.py trainer.max_epochs=20 optimizer.lr=0.0005
python run.py trainer.max_epochs=20 datamodule.batch_size=64
```
Train on GPU
```yaml
python run.py trainer.gpus=1
```
<br>

Просмотреть файл

@ -1,7 +1,7 @@
model_checkpoint:
_target_: pytorch_lightning.callbacks.ModelCheckpoint
monitor: "val/acc" # name of the logged metric which determines when model is improving
save_top_k: 2 # save k best models (determined by above metric)
save_top_k: 1 # save k best models (determined by above metric)
save_last: True # additionaly always save model from last epoch
mode: "max" # can be "max" or "min"
verbose: False

Просмотреть файл

@ -2,26 +2,26 @@ defaults:
- default_callbacks.yaml
upload_code_to_wandb_as_artifact:
_target_: src.callbacks.wandb_callbacks.UploadCodeToWandbAsArtifact
code_dir: ${work_dir}
upload_ckpts_to_wandb_as_artifact:
_target_: src.callbacks.wandb_callbacks.UploadCheckpointsToWandbAsArtifact
ckpt_dir: "checkpoints/"
upload_best_only: True
watch_model_with_wandb:
_target_: src.callbacks.wandb_callbacks.WatchModelWithWandb
_target_: src.pl_callbacks.wandb_callbacks.WatchModelWithWandb
log: "all"
log_freq: 100
upload_code_to_wandb_as_artifact:
_target_: src.pl_callbacks.wandb_callbacks.UploadCodeToWandbAsArtifact
code_dir: ${work_dir}
upload_ckpts_to_wandb_as_artifact:
_target_: src.pl_callbacks.wandb_callbacks.UploadCheckpointsToWandbAsArtifact
ckpt_dir: "checkpoints/"
upload_best_only: True
save_f1_precision_recall_heatmap_to_wandb:
_target_: src.callbacks.wandb_callbacks.LogF1PrecRecHeatmapToWandb
_target_: src.pl_callbacks.wandb_callbacks.LogF1PrecRecHeatmapToWandb
save_confusion_matrix_to_wandb:
_target_: src.callbacks.wandb_callbacks.LogConfusionMatrixToWandb
_target_: src.pl_callbacks.wandb_callbacks.LogConfusionMatrixToWandb

Просмотреть файл

@ -14,36 +14,30 @@ defaults:
# - override hydra/job_logging: colorlog
# path to original working directory (that `run.py` was executed from in command line)
# path to original working directory
# hydra hijacks working directory by changing it to the current log directory,
# so it's useful to have path to original work dir as a special variable
# read more here: https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory
# so it's useful to have this path as a special variable
# learn more here: https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory
work_dir: ${hydra:runtime.cwd}
# path to folder with data
data_dir: ${work_dir}/data/
# use `python run.py debug=true` for easy debugging!
# this will run 1 train, val and test loop with only 1 batch
# equivalent to running `python run.py trainer.fast_dev_run=true`
# (this is placed here just for easier access from command line)
debug: False
# pretty print config at the start of the run using Rich library
print_config: True
# disable python warnings if they annoy you
disable_warnings: True
# disable lightning logs if they annoy you
disable_lightning_logs: False
hydra:
# output paths for hydra logs
run:

Просмотреть файл

@ -1,4 +1,4 @@
_target_: src.datamodules.mnist_datamodule.MNISTDataModule
_target_: src.pl_datamodules.mnist_datamodule.MNISTDataModule
data_dir: ${data_dir} # data_dir is specified in config.yaml
batch_size: 64

Просмотреть файл

@ -28,18 +28,15 @@ trainer:
# resume_from_checkpoint: ${work_dir}/last.ckpt
model:
_target_: src.models.mnist_model.LitModelMNIST
_target_: src.pl_models.mnist_model.MNISTLitModel
optimizer: adam
lr: 0.001
weight_decay: 0.00005
architecture: SimpleDenseNet
input_size: 784
lin1_size: 256
dropout1: 0.30
lin2_size: 256
dropout2: 0.25
lin3_size: 128
dropout3: 0.20
output_size: 10
optimizer:
@ -48,7 +45,7 @@ optimizer:
weight_decay: 0
datamodule:
_target_: src.datamodules.mnist_datamodule.MNISTDataModule
_target_: src.pl_datamodules.mnist_datamodule.MNISTDataModule
data_dir: ${data_dir}
batch_size: 64
train_val_test_split: [55_000, 5_000, 10_000]
@ -67,7 +64,7 @@ callbacks:
early_stopping:
_target_: pytorch_lightning.callbacks.EarlyStopping
monitor: "val/acc"
patience: 100
patience: 10
mode: "max"
logger:

Просмотреть файл

@ -6,7 +6,7 @@
defaults:
- override /trainer: default_trainer.yaml # choose trainer from 'configs/trainer/' folder or set to null
- override /model: mnist_model.yaml # choose model from 'configs/model/' folder or set to null
- override /optimizer: adam.yaml # choose model from 'configs/model/' folder or set to null
- override /optimizer: adam.yaml # choose model from 'configs/model/' folder or set to null
- override /datamodule: mnist_datamodule.yaml # choose datamodule from 'configs/datamodule/' folder or set to null
- override /callbacks: default_callbacks.yaml # choose callback set from 'configs/callbacks/' folder or set to null
- override /logger: null # choose logger from 'configs/logger/' folder or set to null

Просмотреть файл

@ -1,10 +1,7 @@
_target_: src.models.mnist_model.LitModelMNIST
_target_: src.pl_models.mnist_model.MNISTLitModel
input_size: 784
lin1_size: 256
dropout1: 0.30
lin2_size: 256
dropout2: 0.25
lin3_size: 128
dropout3: 0.20
lin3_size: 256
output_size: 10

4
run.py
Просмотреть файл

@ -1,5 +1,4 @@
import hydra
from hydra.utils import log
from omegaconf import DictConfig
@ -7,7 +6,7 @@ from omegaconf import DictConfig
def main(config: DictConfig):
# Imports should be nested inside @hydra.main to optimize tab completion
# Learn more here: https://github.com/facebookresearch/hydra/issues/934
# Read more here: https://github.com/facebookresearch/hydra/issues/934
import dotenv
from src.train import train
from src.utils import template_utils
@ -25,7 +24,6 @@ def main(config: DictConfig):
# Pretty print config using Rich library
if config.get("print_config"):
log.info(f"Pretty printing config with Rich! <{config.print_config=}>")
template_utils.print_config(config, resolve=True)
# Train model

Просмотреть файл

@ -9,23 +9,19 @@ class SimpleDenseNet(nn.Module):
nn.Linear(hparams["input_size"], hparams["lin1_size"]),
nn.BatchNorm1d(hparams["lin1_size"]),
nn.ReLU(),
nn.Dropout(p=hparams["dropout1"]),
nn.Linear(hparams["lin1_size"], hparams["lin2_size"]),
nn.BatchNorm1d(hparams["lin2_size"]),
nn.ReLU(),
nn.Dropout(p=hparams["dropout2"]),
nn.Linear(hparams["lin2_size"], hparams["lin3_size"]),
nn.BatchNorm1d(hparams["lin3_size"]),
nn.ReLU(),
nn.Dropout(p=hparams["dropout3"]),
nn.Linear(hparams["lin3_size"], hparams["output_size"]),
)
def forward(self, x):
batch_size, channels, width, height = x.size()
# mnist images are (1, 28, 28) (channels, width, height)
# (b, 1, 28, 28) -> (b, 1*28*28)
# (batch, 1, width, height) -> (batch, 1*width*height)
x = x.view(batch_size, -1)
return self.model(x)

Просмотреть файл

@ -7,15 +7,19 @@ import seaborn as sn
import torch
import wandb
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.loggers import LoggerCollection, WandbLogger
from sklearn import metrics
from sklearn.metrics import f1_score, precision_score, recall_score
def get_wandb_logger(trainer: Trainer) -> WandbLogger:
for logger in trainer.logger:
if isinstance(logger, WandbLogger):
return logger
if isinstance(trainer.logger, WandbLogger):
return trainer.logger
if isinstance(trainer.logger, LoggerCollection):
for logger in trainer.logger:
if isinstance(logger, WandbLogger):
return logger
raise Exception(
"You are using wandb related callback,"

Просмотреть файл

@ -1,4 +1,4 @@
from typing import Optional, Sequence
from typing import Optional
from pytorch_lightning import LightningDataModule
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
@ -10,21 +10,29 @@ class MNISTDataModule(LightningDataModule):
"""
Example of LightningDataModule for MNIST dataset.
A DataModule standardizes the training, val, test splits, data preparation and transforms.
A DataModule standardizes the train, val, test splits, data preparation and transforms.
The main advantage is consistent data splits, data preparation and transforms across models.
Read the docs:
https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
"""
def __init__(self, *args, **kwargs):
def __init__(
self,
data_dir="data/",
batch_size=64,
train_val_test_split=(55_000, 5_000, 10_000),
num_workers=0,
pin_memory=False,
**kwargs,
):
super().__init__()
self.data_dir = kwargs["data_dir"]
self.batch_size = kwargs["batch_size"]
self.train_val_test_split = kwargs["train_val_test_split"]
self.num_workers = kwargs["num_workers"]
self.pin_memory = kwargs["pin_memory"]
self.data_dir = data_dir
self.batch_size = batch_size
self.train_val_test_split = train_val_test_split
self.num_workers = num_workers
self.pin_memory = pin_memory
self.transforms = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]

Просмотреть файл

@ -1,16 +1,15 @@
from typing import Any, Dict, List, Sequence, Tuple, Union
import hydra
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from pytorch_lightning import LightningModule
from pytorch_lightning.metrics.classification import Accuracy
from torch.optim import Optimizer
from src.architectures.simple_dense_net import SimpleDenseNet
class LitModelMNIST(pl.LightningModule):
class MNISTLitModel(LightningModule):
"""
Example of LightningModule for MNIST classification.
@ -25,14 +24,23 @@ class LitModelMNIST(pl.LightningModule):
https://pytorch-lightning.readthedocs.io/en/latest/lightning_module.html
"""
def __init__(self, *args, **kwargs):
def __init__(
self,
optimizer,
input_size=784,
lin1_size=256,
lin2_size=256,
lin3_size=256,
output_size=10,
**kwargs
):
super().__init__()
# this line ensures params passed to LightningModule will be saved to ckpt
# it also allows to access params with 'self.hparams' attribute
self.save_hyperparameters()
self.architecture = SimpleDenseNet(hparams=self.hparams)
self.model = SimpleDenseNet(hparams=self.hparams)
# loss function
self.criterion = torch.nn.CrossEntropyLoss()
@ -51,7 +59,7 @@ class LitModelMNIST(pl.LightningModule):
}
def forward(self, x) -> torch.Tensor:
return self.architecture(x)
return self.model(x)
def step(self, batch) -> Dict[str, torch.Tensor]:
x, y = batch
@ -63,7 +71,7 @@ class LitModelMNIST(pl.LightningModule):
def training_step(self, batch: Any, batch_idx: int) -> Dict[str, torch.Tensor]:
loss, preds, targets = self.step(batch)
# log train metrics to your loggers!
# log train metrics
acc = self.train_accuracy(preds, targets)
self.log("train/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
self.log("train/acc", acc, on_step=False, on_epoch=True, prog_bar=True)
@ -76,19 +84,17 @@ class LitModelMNIST(pl.LightningModule):
def validation_step(self, batch: Any, batch_idx: int) -> Dict[str, torch.Tensor]:
loss, preds, targets = self.step(batch)
# log val metrics to your loggers!
# log val metrics
acc = self.val_accuracy(preds, targets)
self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
self.log("val/acc", acc, on_step=False, on_epoch=True, prog_bar=True)
# we can return here dict with any tensors
# and then read it in some callback or in validation_epoch_end() below
return {"loss": loss, "preds": preds, "targets": targets}
def test_step(self, batch: Any, batch_idx: int) -> Dict[str, torch.Tensor]:
loss, preds, targets = self.step(batch)
# log test metrics to your loggers!
# log test metrics
acc = self.test_accuracy(preds, targets)
self.log("test/loss", loss, on_step=False, on_epoch=True)
self.log("test/acc", acc, on_step=False, on_epoch=True)
@ -113,6 +119,10 @@ class LitModelMNIST(pl.LightningModule):
self.log("val/acc_best", max(self.metric_hist["val/acc"]), prog_bar=False)
self.log("val/loss_best", min(self.metric_hist["val/loss"]), prog_bar=False)
# [OPTIONAL METHOD]
def test_epoch_end(self, outputs: List[Any]) -> None:
pass
def configure_optimizers(
self,
) -> Union[Optimizer, Tuple[Sequence[Optimizer], Sequence[Any]]]:

Просмотреть файл

@ -1,19 +1,17 @@
# lightning imports
import logging
from typing import List, Optional
from pytorch_lightning import LightningModule, LightningDataModule, Callback, Trainer
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning import seed_everything
# hydra imports
from omegaconf import DictConfig
from hydra.utils import log
import hydra
from omegaconf import DictConfig
# normal imports
from typing import List, Optional
# src imports
from src.utils import template_utils
log = logging.getLogger(__name__)
def train(config: DictConfig) -> Optional[float]:
"""Contains training pipeline.

Просмотреть файл

@ -1,13 +0,0 @@
from torchvision import transforms
mnist_train_transforms = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)
mnist_test_transforms = transforms.Compose(
[
transforms.ToTensor(),
transforms.Resize((28, 28)),
transforms.Normalize((0.1307,), (0.3081,)),
]
)

Просмотреть файл

@ -1,7 +1,7 @@
from PIL import Image
from torchvision import transforms
from src.models.mnist_model import LitModelMNIST
from src.transforms import mnist_transforms
from src.pl_models.mnist_model import MNISTLitModel
def predict():
@ -16,7 +16,7 @@ def predict():
# load model from checkpoint
# model __init__ parameters will be loaded from ckpt automatically
# you can also pass some parameter explicitly to override it
trained_model = LitModelMNIST.load_from_checkpoint(checkpoint_path=CKPT_PATH)
trained_model = MNISTLitModel.load_from_checkpoint(checkpoint_path=CKPT_PATH)
# print model hyperparameters
print(trained_model.hparams)
@ -30,7 +30,14 @@ def predict():
# img = Image.open("data/example_img.png").convert("RGB") # convert to RGB
# preprocess
img = mnist_transforms.mnist_test_transforms(img)
mnist_transforms = transforms.Compose(
[
transforms.ToTensor(),
transforms.Resize((28, 28)),
transforms.Normalize((0.1307,), (0.3081,)),
]
)
img = mnist_transforms(img)
img = img.reshape((1, *img.size())) # reshape to form batch of size 1
# inference

Просмотреть файл

@ -4,13 +4,14 @@ from typing import List, Sequence
import pytorch_lightning as pl
import wandb
from hydra.utils import log
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.loggers.wandb import WandbLogger
from rich import print
from rich.syntax import Syntax
from rich.tree import Tree
log = logging.getLogger(__name__)
def extras(config: DictConfig) -> None:
"""A couple of optional utilities, controlled by main config file.
@ -22,10 +23,10 @@ def extras(config: DictConfig) -> None:
config (DictConfig): [description]
"""
# make it possible to add new keys to config
# Enable adding new keys to config
OmegaConf.set_struct(config, False)
# fix double logging bug (this will be removed when lightning releases patch)
# Fix double logging bug (this will be removed when lightning releases patch)
pl_logger = logging.getLogger("lightning")
pl_logger.propagate = False
@ -56,7 +57,7 @@ def extras(config: DictConfig) -> None:
if config.datamodule.get("num_workers"):
config.datamodule.num_workers = 0
# disable adding new keys to config
# Disable adding new keys to config
OmegaConf.set_struct(config, True)
@ -71,10 +72,6 @@ def print_config(
"logger",
"seed",
),
extra_depth_fields: Sequence[str] = (
"callbacks",
"logger",
),
resolve: bool = True,
) -> None:
"""Prints content of DictConfig using Rich library and its tree structure.
@ -83,41 +80,21 @@ def print_config(
config (DictConfig): Config.
fields (Sequence[str], optional): Determines which main fields from config will be printed
and in what order.
extra_depth_fields (Sequence[str], optional): Fields which should be printed with extra tree depth.
resolve (bool, optional): Whether to resolve reference fields of DictConfig.
"""
# TODO print main config path and experiment config path
# print(f"Main config path: [link file://{directory}]{directory}")
# TODO refactor the whole method
style = "dim"
tree = Tree(f":gear: CONFIG", style=style, guide_style=style)
for field in fields:
branch = tree.add(field, style=style, guide_style=style)
config_section = config.get(field)
branch_content = str(config_section)
if isinstance(config_section, DictConfig):
branch_content = OmegaConf.to_yaml(config_section, resolve=resolve)
if not config_section:
# raise Exception(f"Field {field} not found in config!")
branch.add("None")
continue
if field in extra_depth_fields:
for nested_field in config_section:
nested_config_section = config_section[nested_field]
nested_branch = branch.add(nested_field, style=style, guide_style=style)
cfg_str = OmegaConf.to_yaml(nested_config_section, resolve=resolve)
nested_branch.add(Syntax(cfg_str, "yaml"))
else:
if isinstance(config_section, DictConfig):
cfg_str = OmegaConf.to_yaml(config_section, resolve=resolve)
else:
cfg_str = str(config_section)
branch.add(Syntax(cfg_str, "yaml"))
branch.add(Syntax(branch_content, "yaml"))
print(tree)