Dev (#107)

* refactor folder names
* define python logger correctly
* refactor rich printing method
* specify model parameters explicitly
* specify datamodule params explicitly
* refactor
* update readme

This commit is contained in:
Parent: fca999f527
Commit: 381b5844ed
@@ -30,9 +30,3 @@ repos:
       # other flags: https://pycqa.github.io/isort/docs/configuration/options/
       args: [--profile, black, --skip, src/train.py, --skip, run.py, --filter-files]
       # files: "src/.*"
-
-  # MyPy (static type checking)
-  # - repo: https://github.com/pre-commit/mirrors-mypy
-  #   rev: v0.790
-  #   hooks:
-  #   - id: mypy
README.md (72 lines changed)
@@ -88,11 +88,11 @@ It makes your code neatly organized and provides lots of useful features, like a
 The directory structure of new project looks like this:
 ```
 ├── configs                 <- Hydra configuration files
-│   ├── trainer             <- Configurations of Lightning trainers
+│   ├── datamodule          <- Configurations of Lightning datamodules
 │   ├── model               <- Configurations of Lightning models
-│   ├── datamodule          <- Configurations of Lightning datamodules
 │   ├── callbacks           <- Configurations of Lightning callbacks
 │   ├── logger              <- Configurations of Lightning loggers
+│   ├── trainer             <- Configurations of Lightning trainers
 │   ├── optimizer           <- Configurations of optimizers
 │   ├── experiment          <- Configurations of experiments
 │   │
@@ -109,11 +109,10 @@ The directory structure of new project looks like this:
 │
 ├── src
 │   ├── architectures       <- PyTorch model architectures
-│   ├── callbacks           <- PyTorch Lightning callbacks
-│   ├── datamodules         <- PyTorch Lightning datamodules
 │   ├── datasets            <- PyTorch datasets
-│   ├── models              <- PyTorch Lightning models
-│   ├── transforms          <- Data transformations
+│   ├── pl_callbacks        <- PyTorch Lightning callbacks
+│   ├── pl_datamodules      <- PyTorch Lightning datamodules
+│   ├── pl_models           <- PyTorch Lightning models
 │   ├── utils               <- Utility scripts
 │   │   ├── inference_example.py    <- Example of inference with trained model
 │   │   └── template_utils.py       <- Some extra template utilities
@@ -167,7 +166,7 @@ python run.py trainer.max_epochs=20 optimizer.lr=1e-4
 ```
 
 > *You can also add new parameters with `+` sign.*
 ```yaml
-python run.py +trainer.new_param="uwu"
+python run.py +model.new_param="uwu"
 ```
 
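For context on the `+` override shown in this README hunk: under the hood it is ordinary OmegaConf struct-mode manipulation. A minimal, hedged sketch (assumes only `omegaconf` is installed; the keys are illustrative):

```python
from omegaconf import OmegaConf, open_dict

cfg = OmegaConf.create({"trainer": {"max_epochs": 10}})

# roughly what `python run.py +model.new_param="uwu"` does:
# add a key that does not exist in the composed config yet
with open_dict(cfg):
    cfg.model = {"new_param": "uwu"}

print(OmegaConf.to_yaml(cfg))
```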
@@ -214,7 +213,7 @@ wandb:
 # link to wandb dashboard should appear in the terminal
 python run.py logger=wandb
 ```
-> **Click [here](https://wandb.ai/hobglob/template-dashboard/) to see example wandb dashboard generated with this template.**
+> ***Click [here](https://wandb.ai/hobglob/template-dashboard/) to see example wandb dashboard generated with this template.***
 
 </details>
 
@@ -222,7 +221,7 @@ python run.py logger=wandb
 <details>
 <summary>Train model with chosen experiment config</summary>
 
-> Experiment configurations are placed in folder `configs/experiment/`.
+> *Experiment configurations are placed in [configs/experiment/](configs/experiment/).*
 ```yaml
 python run.py +experiment=exp_example_simple
 ```
@@ -234,7 +233,7 @@ python run.py +experiment=exp_example_simple
 <summary>Attach some callbacks to run</summary>
 
 > *Callbacks can be used for things such as as model checkpointing, early stopping and [many more](https://pytorch-lightning.readthedocs.io/en/latest/extensions/callbacks.html#built-in-callbacks).<br>
-Callbacks configurations are placed in `configs/callbacks/`.*
+Callbacks configurations are placed in [configs/callbacks/](configs/callbacks/).*
 ```yaml
 python run.py callbacks=default_callbacks
 ```
@@ -393,10 +392,10 @@ defaults:
     - logger: null  # set logger here or use command line (e.g. `python run.py logger=wandb`)
 
 
-# path to original working directory (that `run.py` was executed from in command line)
-# hydra hijacks working directory by changing it to the current log directory,
-# so it's useful to have path to original working directory as a special variable
-# read more here: https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory
+# path to original working directory
+# hydra hijacks working directory by changing it to the current log directory
+# so it's useful to have this path as a special variable
+# learn more here: https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory
 work_dir: ${hydra:runtime.cwd}
 
 
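To make the `work_dir` comment concrete, here is a small hedged sketch of the behaviour it describes (assumes a Hydra 1.0-style `@hydra.main` entry point and a `configs/config.yaml` that defines `work_dir: ${hydra:runtime.cwd}`, as in the hunk above):

```python
import os

import hydra
from omegaconf import DictConfig


@hydra.main(config_path="configs", config_name="config")
def main(config: DictConfig) -> None:
    # hydra has already chdir'd into its per-run log directory at this point
    print("current working directory:", os.getcwd())
    # ${hydra:runtime.cwd} resolves to the directory the script was launched from
    print("original working directory:", config.work_dir)


if __name__ == "__main__":
    main()
```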
@@ -421,8 +420,8 @@ disable_warnings: False
 disable_lightning_logs: False
 
 
-# output paths for hydra logs
 hydra:
+    # output paths for hydra logs
     run:
         dir: logs/runs/${now:%Y-%m-%d}/${now:%H-%M-%S}
     sweep:
@@ -499,24 +498,21 @@ trainer:
     gradient_clip_val: 0.5
 
 model:
-    _target_: src.models.mnist_model.LitModelMNIST
+    _target_: src.pl_models.mnist_model.MNISTLitModel
     input_size: 784
     lin1_size: 256
     dropout1: 0.30
     lin2_size: 256
     dropout2: 0.25
     lin3_size: 128
     dropout3: 0.20
     output_size: 10
 
 optimizer:
     _target_: torch.optim.Adam
     lr: 0.001
     eps: 1e-08
-    weight_decay: 0
+    weight_decay: 0.0005
 
 datamodule:
-    _target_: src.datamodules.mnist_datamodule.MNISTDataModule
+    _target_: src.pl_datamodules.mnist_datamodule.MNISTDataModule
     data_dir: ${data_dir}
     batch_size: 64
     train_val_test_split: [55_000, 5_000, 10_000]
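These `_target_` entries are resolved by Hydra's instantiate utility; a minimal, hedged sketch of that mechanism (the `nn.Linear` stand-in is illustrative, not the template's actual Lightning model):

```python
import torch.nn as nn
from hydra.utils import instantiate
from omegaconf import OmegaConf

optimizer_cfg = OmegaConf.create(
    {"_target_": "torch.optim.Adam", "lr": 0.001, "weight_decay": 0.0005}
)

model = nn.Linear(784, 10)  # stand-in for the Lightning model
optimizer = instantiate(optimizer_cfg, params=model.parameters())
print(type(optimizer))  # <class 'torch.optim.adam.Adam'>
```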
@@ -533,9 +529,9 @@ logger:
 <br>
 
 ### Workflow
-1. Write your PyTorch Lightning model (see [mnist_model.py](src/models/mnist_model.py) for example)
-2. Write your PyTorch Lightning datamodule (see [mnist_datamodule.py](src/datamodules/mnist_datamodule.py) for example)
-3. Write your experiment config, containing paths to your model and datamodule (see [configs/experiment](configs/experiment) for examples)
+1. Write your PyTorch Lightning model (see [mnist_model.py](src/pl_models/mnist_model.py) for example)
+2. Write your PyTorch Lightning datamodule (see [mnist_datamodule.py](src/pl_datamodules/mnist_datamodule.py) for example)
+3. Write your experiment config, containing paths to your model and datamodule (see [configs/experiment](configs/experiment/) for examples)
 4. Run training with chosen experiment config:<br>
 ```yaml
 python run.py +experiment=experiment_name
@@ -595,7 +591,7 @@ These tools help you keep track of hyperparameters and output metrics and allow
 You can use many of them at once (see [configs/logger/many_loggers.yaml](configs/logger/many_loggers.yaml) for example).<br>
 You can also write your own logger.<br>
 
-Lightning provides convenient method for logging custom metrics from inside LightningModule. Read the docs [here](https://pytorch-lightning.readthedocs.io/en/latest/extensions/logging.html#automatic-logging) or take a look at [MNIST example](src/models/mnist_model.py).
+Lightning provides convenient method for logging custom metrics from inside LightningModule. Read the docs [here](https://pytorch-lightning.readthedocs.io/en/latest/extensions/logging.html#automatic-logging) or take a look at [MNIST example](src/pl_models/mnist_model.py).
 <br><br>
 
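The `self.log(...)` pattern referenced here is generic Lightning behaviour; a minimal hedged sketch with a toy module (not the template's MNIST model):

```python
import torch
import torch.nn.functional as F
from pytorch_lightning import LightningModule


class LoggingDemo(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(784, 10)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self.layer(x), y)
        # aggregated per epoch and forwarded to every attached logger
        self.log("train/loss", loss, on_step=False, on_epoch=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```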
@@ -607,9 +603,9 @@ Take a look at [inference_example.py](src/utils/inference_example.py).
 
 
 ### Callbacks
-Template contains example callbacks for better Weights&Biases integration (see [wandb_callbacks.py](src/callbacks/wandb_callbacks.py)).<br>
+Template contains example callbacks for better Weights&Biases integration (see [wandb_callbacks.py](src/pl_callbacks/wandb_callbacks.py)).<br>
 To support reproducibility: *UploadCodeToWandbAsArtifact*, *UploadCheckpointsToWandbAsArtifact*, *WatchModelWithWandb*.<br>
-To provide examples of logging custom visualisations with callbacks only: *LogConfusionMatrixToWandb*, *LogF1PrecisionRecallHeatmapToWandb*.<br>
+To provide examples of logging custom visualisations with callbacks only: *LogConfusionMatrixToWandb*, *LogF1PrecRecHeatmapToWandb*.<br>
 <br><br>
 
@@ -711,8 +707,8 @@ pip install git+git://github.com/YourGithubName/your-repo-name.git --upgrade
 ```
 So any file can be easily imported into any other file like so:
 ```python
-from project_name.models.mnist_model import LitModelMNIST
-from project_name.datamodules.mnist_datamodule import MNISTDataModule
+from project_name.pl_models.mnist_model import MNISTLitModel
+from project_name.pl_datamodules.mnist_datamodule import MNISTDataModule
 ```
 <br>
 
@@ -839,22 +835,24 @@ pip install -r requirements.txt
 
 Train model with default configuration
 ```yaml
+# default
 python run.py
+
+# train on CPU
+python run.py trainer.gpus=0
+
+# train on GPU
+python run.py trainer.gpus=1
 ```
 
-Train model with chosen experiment configuration
+Train model with chosen experiment configuration from [configs/experiment/](configs/experiment/)
 ```yaml
-# experiment configurations are placed in folder `configs/experiment/`
-python run.py +experiment=exp_example_simple
+python run.py +experiment=experiment_name
 ```
 
 You can override any parameter from command line like this
 ```yaml
-python run.py trainer.max_epochs=20 optimizer.lr=0.0005
+python run.py trainer.max_epochs=20 datamodule.batch_size=64
 ```
 
-Train on GPU
-```yaml
-python run.py trainer.gpus=1
-```
 <br>
 
@@ -1,7 +1,7 @@
 model_checkpoint:
     _target_: pytorch_lightning.callbacks.ModelCheckpoint
     monitor: "val/acc"  # name of the logged metric which determines when model is improving
-    save_top_k: 2  # save k best models (determined by above metric)
+    save_top_k: 1  # save k best models (determined by above metric)
     save_last: True  # additionaly always save model from last epoch
    mode: "max"  # can be "max" or "min"
    verbose: False
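For reference, these YAML entries instantiate to plain PyTorch Lightning callbacks; a hedged sketch using values that appear in this diff (the EarlyStopping values come from the experiment config further down):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

checkpoint = ModelCheckpoint(
    monitor="val/acc", save_top_k=1, save_last=True, mode="max", verbose=False
)
early_stop = EarlyStopping(monitor="val/acc", patience=10, mode="max")

trainer = Trainer(callbacks=[checkpoint, early_stop])
```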
@@ -2,26 +2,26 @@ defaults:
     - default_callbacks.yaml
 
 
-upload_code_to_wandb_as_artifact:
-    _target_: src.callbacks.wandb_callbacks.UploadCodeToWandbAsArtifact
-    code_dir: ${work_dir}
-
-
-upload_ckpts_to_wandb_as_artifact:
-    _target_: src.callbacks.wandb_callbacks.UploadCheckpointsToWandbAsArtifact
-    ckpt_dir: "checkpoints/"
-    upload_best_only: True
-
-
 watch_model_with_wandb:
-    _target_: src.callbacks.wandb_callbacks.WatchModelWithWandb
+    _target_: src.pl_callbacks.wandb_callbacks.WatchModelWithWandb
     log: "all"
     log_freq: 100
 
 
+upload_code_to_wandb_as_artifact:
+    _target_: src.pl_callbacks.wandb_callbacks.UploadCodeToWandbAsArtifact
+    code_dir: ${work_dir}
+
+
+upload_ckpts_to_wandb_as_artifact:
+    _target_: src.pl_callbacks.wandb_callbacks.UploadCheckpointsToWandbAsArtifact
+    ckpt_dir: "checkpoints/"
+    upload_best_only: True
+
+
 save_f1_precision_recall_heatmap_to_wandb:
-    _target_: src.callbacks.wandb_callbacks.LogF1PrecRecHeatmapToWandb
+    _target_: src.pl_callbacks.wandb_callbacks.LogF1PrecRecHeatmapToWandb
 
 
 save_confusion_matrix_to_wandb:
-    _target_: src.callbacks.wandb_callbacks.LogConfusionMatrixToWandb
+    _target_: src.pl_callbacks.wandb_callbacks.LogConfusionMatrixToWandb
@@ -14,36 +14,30 @@ defaults:
     # - override hydra/job_logging: colorlog
 
 
-# path to original working directory (that `run.py` was executed from in command line)
+# path to original working directory
 # hydra hijacks working directory by changing it to the current log directory,
-# so it's useful to have path to original work dir as a special variable
-# read more here: https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory
+# so it's useful to have this path as a special variable
+# learn more here: https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory
 work_dir: ${hydra:runtime.cwd}
 
 
 # path to folder with data
 data_dir: ${work_dir}/data/
 
 
 # use `python run.py debug=true` for easy debugging!
 # this will run 1 train, val and test loop with only 1 batch
 # equivalent to running `python run.py trainer.fast_dev_run=true`
 # (this is placed here just for easier access from command line)
 debug: False
 
 
 # pretty print config at the start of the run using Rich library
 print_config: True
 
 
 # disable python warnings if they annoy you
 disable_warnings: True
 
 
 # disable lightning logs if they annoy you
 disable_lightning_logs: False
 
 
 hydra:
     # output paths for hydra logs
     run:
@@ -1,4 +1,4 @@
-_target_: src.datamodules.mnist_datamodule.MNISTDataModule
+_target_: src.pl_datamodules.mnist_datamodule.MNISTDataModule
 
 data_dir: ${data_dir}  # data_dir is specified in config.yaml
 batch_size: 64
@@ -28,18 +28,15 @@ trainer:
     # resume_from_checkpoint: ${work_dir}/last.ckpt
 
 model:
-    _target_: src.models.mnist_model.LitModelMNIST
-    optimizer: adam
-    lr: 0.001
-    weight_decay: 0.00005
-    architecture: SimpleDenseNet
+    _target_: src.pl_models.mnist_model.MNISTLitModel
     input_size: 784
     lin1_size: 256
     dropout1: 0.30
     lin2_size: 256
     dropout2: 0.25
     lin3_size: 128
     dropout3: 0.20
     output_size: 10
 
 optimizer:
@@ -48,7 +45,7 @@ optimizer:
     weight_decay: 0
 
 datamodule:
-    _target_: src.datamodules.mnist_datamodule.MNISTDataModule
+    _target_: src.pl_datamodules.mnist_datamodule.MNISTDataModule
     data_dir: ${data_dir}
     batch_size: 64
     train_val_test_split: [55_000, 5_000, 10_000]
@@ -67,7 +64,7 @@ callbacks:
     early_stopping:
         _target_: pytorch_lightning.callbacks.EarlyStopping
         monitor: "val/acc"
-        patience: 100
+        patience: 10
         mode: "max"
 
 logger:
@@ -6,7 +6,7 @@
 defaults:
     - override /trainer: default_trainer.yaml  # choose trainer from 'configs/trainer/' folder or set to null
     - override /model: mnist_model.yaml  # choose model from 'configs/model/' folder or set to null
-    - override /optimizer: adam.yaml  # choose model from 'configs/model/' folder or set to null
+    - override /optimizer: adam.yaml  # choose model from 'configs/model/' folder or set to null
     - override /datamodule: mnist_datamodule.yaml  # choose datamodule from 'configs/datamodule/' folder or set to null
     - override /callbacks: default_callbacks.yaml  # choose callback set from 'configs/callbacks/' folder or set to null
     - override /logger: null  # choose logger from 'configs/logger/' folder or set to null
@@ -1,10 +1,7 @@
-_target_: src.models.mnist_model.LitModelMNIST
+_target_: src.pl_models.mnist_model.MNISTLitModel
 
 input_size: 784
 lin1_size: 256
-dropout1: 0.30
 lin2_size: 256
-dropout2: 0.25
-lin3_size: 128
-dropout3: 0.20
+lin3_size: 256
 output_size: 10
run.py (4 lines changed)
@@ -1,5 +1,4 @@
 import hydra
-from hydra.utils import log
 from omegaconf import DictConfig
 
 
@@ -7,7 +6,7 @@ from omegaconf import DictConfig
 def main(config: DictConfig):
 
     # Imports should be nested inside @hydra.main to optimize tab completion
-    # Learn more here: https://github.com/facebookresearch/hydra/issues/934
+    # Read more here: https://github.com/facebookresearch/hydra/issues/934
     import dotenv
     from src.train import train
     from src.utils import template_utils
@@ -25,7 +24,6 @@ def main(config: DictConfig):
 
     # Pretty print config using Rich library
     if config.get("print_config"):
-        log.info(f"Pretty printing config with Rich! <{config.print_config=}>")
         template_utils.print_config(config, resolve=True)
 
     # Train model
@@ -9,23 +9,19 @@ class SimpleDenseNet(nn.Module):
             nn.Linear(hparams["input_size"], hparams["lin1_size"]),
             nn.BatchNorm1d(hparams["lin1_size"]),
             nn.ReLU(),
-            nn.Dropout(p=hparams["dropout1"]),
             nn.Linear(hparams["lin1_size"], hparams["lin2_size"]),
             nn.BatchNorm1d(hparams["lin2_size"]),
             nn.ReLU(),
-            nn.Dropout(p=hparams["dropout2"]),
             nn.Linear(hparams["lin2_size"], hparams["lin3_size"]),
             nn.BatchNorm1d(hparams["lin3_size"]),
             nn.ReLU(),
-            nn.Dropout(p=hparams["dropout3"]),
             nn.Linear(hparams["lin3_size"], hparams["output_size"]),
         )
 
     def forward(self, x):
         batch_size, channels, width, height = x.size()
 
-        # mnist images are (1, 28, 28) (channels, width, height)
-        # (b, 1, 28, 28) -> (b, 1*28*28)
+        # (batch, 1, width, height) -> (batch, 1*width*height)
         x = x.view(batch_size, -1)
 
         return self.model(x)
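A quick, hedged usage sketch of the architecture touched above (assumes the template's `src.architectures.simple_dense_net` module; the hparams dict mirrors the config values in this diff, which no longer include dropout):

```python
import torch

from src.architectures.simple_dense_net import SimpleDenseNet

hparams = {
    "input_size": 784,
    "lin1_size": 256,
    "lin2_size": 256,
    "lin3_size": 256,
    "output_size": 10,
}
net = SimpleDenseNet(hparams=hparams)

x = torch.randn(8, 1, 28, 28)  # fake MNIST batch: (batch, channels, width, height)
logits = net(x)
print(logits.shape)  # torch.Size([8, 10])
```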
@@ -7,15 +7,19 @@ import seaborn as sn
 import torch
 import wandb
 from pytorch_lightning import Callback, Trainer
-from pytorch_lightning.loggers import WandbLogger
+from pytorch_lightning.loggers import LoggerCollection, WandbLogger
 from sklearn import metrics
 from sklearn.metrics import f1_score, precision_score, recall_score
 
 
 def get_wandb_logger(trainer: Trainer) -> WandbLogger:
-    for logger in trainer.logger:
-        if isinstance(logger, WandbLogger):
-            return logger
+    if isinstance(trainer.logger, WandbLogger):
+        return trainer.logger
+
+    if isinstance(trainer.logger, LoggerCollection):
+        for logger in trainer.logger:
+            if isinstance(logger, WandbLogger):
+                return logger
 
     raise Exception(
         "You are using wandb related callback,"
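The new `LoggerCollection` branch matters because `trainer.logger` is only iterable when several loggers are configured at once; a small hedged sketch of that situation (logger choices are illustrative):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CSVLogger, LoggerCollection, WandbLogger

# a single logger: trainer.logger is that logger instance
trainer = Trainer(logger=WandbLogger(project="template-tests", offline=True))
print(isinstance(trainer.logger, WandbLogger))  # True

# several loggers at once: trainer.logger becomes a LoggerCollection
trainer = Trainer(
    logger=[CSVLogger("logs/"), WandbLogger(project="template-tests", offline=True)]
)
print(isinstance(trainer.logger, LoggerCollection))  # True
```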
@@ -1,4 +1,4 @@
-from typing import Optional, Sequence
+from typing import Optional
 
 from pytorch_lightning import LightningDataModule
 from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
@@ -10,21 +10,29 @@ class MNISTDataModule(LightningDataModule):
     """
     Example of LightningDataModule for MNIST dataset.
 
-    A DataModule standardizes the training, val, test splits, data preparation and transforms.
+    A DataModule standardizes the train, val, test splits, data preparation and transforms.
     The main advantage is consistent data splits, data preparation and transforms across models.
 
     Read the docs:
         https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
     """
 
-    def __init__(self, *args, **kwargs):
+    def __init__(
+        self,
+        data_dir="data/",
+        batch_size=64,
+        train_val_test_split=(55_000, 5_000, 10_000),
+        num_workers=0,
+        pin_memory=False,
+        **kwargs,
+    ):
         super().__init__()
 
-        self.data_dir = kwargs["data_dir"]
-        self.batch_size = kwargs["batch_size"]
-        self.train_val_test_split = kwargs["train_val_test_split"]
-        self.num_workers = kwargs["num_workers"]
-        self.pin_memory = kwargs["pin_memory"]
+        self.data_dir = data_dir
+        self.batch_size = batch_size
+        self.train_val_test_split = train_val_test_split
+        self.num_workers = num_workers
+        self.pin_memory = pin_memory
 
         self.transforms = transforms.Compose(
             [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
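With the explicit keyword arguments above, the datamodule can now be constructed and exercised directly; a brief hedged sketch (assumes the standard `prepare_data`/`setup` hooks the template implements; downloads MNIST into `data_dir` on first run):

```python
from src.pl_datamodules.mnist_datamodule import MNISTDataModule

dm = MNISTDataModule(data_dir="data/", batch_size=32, num_workers=4, pin_memory=True)
dm.prepare_data()        # download MNIST if it is not there yet
dm.setup()               # build the train/val/test splits
x, y = next(iter(dm.train_dataloader()))
print(x.shape, y.shape)  # torch.Size([32, 1, 28, 28]) torch.Size([32])
```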
@@ -1,16 +1,15 @@
 from typing import Any, Dict, List, Sequence, Tuple, Union
 
-import hydra
-import pytorch_lightning as pl
 import torch
 import torch.nn.functional as F
+from pytorch_lightning import LightningModule
 from pytorch_lightning.metrics.classification import Accuracy
 from torch.optim import Optimizer
 
 from src.architectures.simple_dense_net import SimpleDenseNet
 
 
-class LitModelMNIST(pl.LightningModule):
+class MNISTLitModel(LightningModule):
     """
     Example of LightningModule for MNIST classification.
@@ -25,14 +24,23 @@ class LitModelMNIST(pl.LightningModule):
         https://pytorch-lightning.readthedocs.io/en/latest/lightning_module.html
     """
 
-    def __init__(self, *args, **kwargs):
+    def __init__(
+        self,
+        optimizer,
+        input_size=784,
+        lin1_size=256,
+        lin2_size=256,
+        lin3_size=256,
+        output_size=10,
+        **kwargs
+    ):
         super().__init__()
 
         # this line ensures params passed to LightningModule will be saved to ckpt
         # it also allows to access params with 'self.hparams' attribute
         self.save_hyperparameters()
 
-        self.architecture = SimpleDenseNet(hparams=self.hparams)
+        self.model = SimpleDenseNet(hparams=self.hparams)
 
         # loss function
         self.criterion = torch.nn.CrossEntropyLoss()
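`save_hyperparameters()` is what lets the explicit arguments above round-trip through checkpoints; a generic hedged sketch of that Lightning behaviour (toy module, not the template's model):

```python
import torch.nn as nn
from pytorch_lightning import LightningModule


class TinyModel(LightningModule):
    def __init__(self, input_size=784, output_size=10, **kwargs):
        super().__init__()
        # stores the init arguments in self.hparams and inside every checkpoint
        self.save_hyperparameters()
        self.layer = nn.Linear(self.hparams.input_size, self.hparams.output_size)


model = TinyModel(input_size=784, output_size=10)
print(model.hparams.input_size)  # 784
```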
@@ -51,7 +59,7 @@ class LitModelMNIST(pl.LightningModule):
     }
 
     def forward(self, x) -> torch.Tensor:
-        return self.architecture(x)
+        return self.model(x)
 
     def step(self, batch) -> Dict[str, torch.Tensor]:
         x, y = batch
@@ -63,7 +71,7 @@ class LitModelMNIST(pl.LightningModule):
     def training_step(self, batch: Any, batch_idx: int) -> Dict[str, torch.Tensor]:
         loss, preds, targets = self.step(batch)
 
-        # log train metrics to your loggers!
+        # log train metrics
         acc = self.train_accuracy(preds, targets)
         self.log("train/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
         self.log("train/acc", acc, on_step=False, on_epoch=True, prog_bar=True)
@@ -76,19 +84,17 @@ class LitModelMNIST(pl.LightningModule):
     def validation_step(self, batch: Any, batch_idx: int) -> Dict[str, torch.Tensor]:
         loss, preds, targets = self.step(batch)
 
-        # log val metrics to your loggers!
+        # log val metrics
         acc = self.val_accuracy(preds, targets)
         self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
         self.log("val/acc", acc, on_step=False, on_epoch=True, prog_bar=True)
 
-        # we can return here dict with any tensors
-        # and then read it in some callback or in validation_epoch_end() below
         return {"loss": loss, "preds": preds, "targets": targets}
 
     def test_step(self, batch: Any, batch_idx: int) -> Dict[str, torch.Tensor]:
         loss, preds, targets = self.step(batch)
 
-        # log test metrics to your loggers!
+        # log test metrics
         acc = self.test_accuracy(preds, targets)
         self.log("test/loss", loss, on_step=False, on_epoch=True)
         self.log("test/acc", acc, on_step=False, on_epoch=True)
@@ -113,6 +119,10 @@ class LitModelMNIST(pl.LightningModule):
         self.log("val/acc_best", max(self.metric_hist["val/acc"]), prog_bar=False)
         self.log("val/loss_best", min(self.metric_hist["val/loss"]), prog_bar=False)
 
+    # [OPTIONAL METHOD]
+    def test_epoch_end(self, outputs: List[Any]) -> None:
+        pass
+
     def configure_optimizers(
         self,
     ) -> Union[Optimizer, Tuple[Sequence[Optimizer], Sequence[Any]]]:
src/train.py (14 lines changed)
@@ -1,19 +1,17 @@
-# lightning imports
+import logging
+from typing import List, Optional
+
 from pytorch_lightning import LightningModule, LightningDataModule, Callback, Trainer
 from pytorch_lightning.loggers import LightningLoggerBase
 from pytorch_lightning import seed_everything
 
-# hydra imports
-from omegaconf import DictConfig
-from hydra.utils import log
 import hydra
+from omegaconf import DictConfig
 
-# normal imports
-from typing import List, Optional
-
-# src imports
 from src.utils import template_utils
 
+log = logging.getLogger(__name__)
+
 
 def train(config: DictConfig) -> Optional[float]:
     """Contains training pipeline.
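The switch away from `hydra.utils.log` follows the standard one-logger-per-module pattern; a minimal hedged sketch of that pattern on its own (standard library only):

```python
import logging

# module-level logger named after the module; picked up by Hydra's logging config at runtime
log = logging.getLogger(__name__)


def do_work() -> None:
    log.info("Starting training...")


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    do_work()
```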
@@ -1,13 +0,0 @@
-from torchvision import transforms
-
-mnist_train_transforms = transforms.Compose(
-    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
-)
-
-mnist_test_transforms = transforms.Compose(
-    [
-        transforms.ToTensor(),
-        transforms.Resize((28, 28)),
-        transforms.Normalize((0.1307,), (0.3081,)),
-    ]
-)
@@ -1,7 +1,7 @@
 from PIL import Image
+from torchvision import transforms
 
-from src.models.mnist_model import LitModelMNIST
-from src.transforms import mnist_transforms
+from src.pl_models.mnist_model import MNISTLitModel
 
 
 def predict():
@@ -16,7 +16,7 @@ def predict():
     # load model from checkpoint
     # model __init__ parameters will be loaded from ckpt automatically
     # you can also pass some parameter explicitly to override it
-    trained_model = LitModelMNIST.load_from_checkpoint(checkpoint_path=CKPT_PATH)
+    trained_model = MNISTLitModel.load_from_checkpoint(checkpoint_path=CKPT_PATH)
 
     # print model hyperparameters
     print(trained_model.hparams)
@@ -30,7 +30,14 @@ def predict():
     # img = Image.open("data/example_img.png").convert("RGB")  # convert to RGB
 
     # preprocess
-    img = mnist_transforms.mnist_test_transforms(img)
+    mnist_transforms = transforms.Compose(
+        [
+            transforms.ToTensor(),
+            transforms.Resize((28, 28)),
+            transforms.Normalize((0.1307,), (0.3081,)),
+        ]
+    )
+    img = mnist_transforms(img)
     img = img.reshape((1, *img.size()))  # reshape to form batch of size 1
 
     # inference
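Putting the two inference hunks together, loading and running the checkpointed model looks roughly like this (hedged sketch; the checkpoint path is a placeholder and the random tensor stands in for a preprocessed image):

```python
import torch

from src.pl_models.mnist_model import MNISTLitModel

CKPT_PATH = "path/to/checkpoints/last.ckpt"  # placeholder, point this at a real checkpoint

model = MNISTLitModel.load_from_checkpoint(checkpoint_path=CKPT_PATH)
model.eval()

with torch.no_grad():
    logits = model(torch.randn(1, 1, 28, 28))  # batch of one 28x28 "image"
print(logits.argmax(dim=1))
```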
@@ -4,13 +4,14 @@ from typing import List, Sequence
 import pytorch_lightning as pl
 import wandb
-from hydra.utils import log
 from omegaconf import DictConfig, OmegaConf
 from pytorch_lightning.loggers.wandb import WandbLogger
 from rich import print
 from rich.syntax import Syntax
 from rich.tree import Tree
 
+log = logging.getLogger(__name__)
+
 
 def extras(config: DictConfig) -> None:
     """A couple of optional utilities, controlled by main config file.
@@ -22,10 +23,10 @@ def extras(config: DictConfig) -> None:
         config (DictConfig): [description]
     """
 
-    # make it possible to add new keys to config
+    # Enable adding new keys to config
    OmegaConf.set_struct(config, False)
 
-    # fix double logging bug (this will be removed when lightning releases patch)
+    # Fix double logging bug (this will be removed when lightning releases patch)
     pl_logger = logging.getLogger("lightning")
     pl_logger.propagate = False
@@ -56,7 +57,7 @@ def extras(config: DictConfig) -> None:
     if config.datamodule.get("num_workers"):
         config.datamodule.num_workers = 0
 
-    # disable adding new keys to config
+    # Disable adding new keys to config
     OmegaConf.set_struct(config, True)
 
 
@@ -71,10 +72,6 @@ def print_config(
         "logger",
         "seed",
     ),
-    extra_depth_fields: Sequence[str] = (
-        "callbacks",
-        "logger",
-    ),
     resolve: bool = True,
 ) -> None:
     """Prints content of DictConfig using Rich library and its tree structure.
@@ -83,41 +80,21 @@ def print_config(
         config (DictConfig): Config.
         fields (Sequence[str], optional): Determines which main fields from config will be printed
             and in what order.
-        extra_depth_fields (Sequence[str], optional): Fields which should be printed with extra tree depth.
         resolve (bool, optional): Whether to resolve reference fields of DictConfig.
     """
 
-    # TODO print main config path and experiment config path
-    # print(f"Main config path: [link file://{directory}]{directory}")
-
-    # TODO refactor the whole method
-
     style = "dim"
 
     tree = Tree(f":gear: CONFIG", style=style, guide_style=style)
 
     for field in fields:
         branch = tree.add(field, style=style, guide_style=style)
 
         config_section = config.get(field)
+        branch_content = str(config_section)
+        if isinstance(config_section, DictConfig):
+            branch_content = OmegaConf.to_yaml(config_section, resolve=resolve)
 
-        if not config_section:
-            # raise Exception(f"Field {field} not found in config!")
-            branch.add("None")
-            continue
-
-        if field in extra_depth_fields:
-            for nested_field in config_section:
-                nested_config_section = config_section[nested_field]
-                nested_branch = branch.add(nested_field, style=style, guide_style=style)
-                cfg_str = OmegaConf.to_yaml(nested_config_section, resolve=resolve)
-                nested_branch.add(Syntax(cfg_str, "yaml"))
-        else:
-            if isinstance(config_section, DictConfig):
-                cfg_str = OmegaConf.to_yaml(config_section, resolve=resolve)
-            else:
-                cfg_str = str(config_section)
-            branch.add(Syntax(cfg_str, "yaml"))
+        branch.add(Syntax(branch_content, "yaml"))
 
     print(tree)
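The simplified printing loop reduces to the Rich Tree/Syntax pattern below; a standalone hedged sketch (field names are illustrative, requires `rich` and `omegaconf`):

```python
from omegaconf import DictConfig, OmegaConf
from rich import print
from rich.syntax import Syntax
from rich.tree import Tree

config = OmegaConf.create({"trainer": {"max_epochs": 10}, "seed": 12345})

style = "dim"
tree = Tree(":gear: CONFIG", style=style, guide_style=style)

for field in ("trainer", "seed"):
    branch = tree.add(field, style=style, guide_style=style)
    section = config.get(field)
    content = (
        OmegaConf.to_yaml(section, resolve=True)
        if isinstance(section, DictConfig)
        else str(section)
    )
    branch.add(Syntax(content, "yaml"))

print(tree)
```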