General improvements (#88)
* remove unnecessary linters from pre-commit-config
* hook fix
* small improvements
* fix readme
* implement easier access to debug mode
* fix tests
* Update README.md
* Update README.md

This commit is contained in:
Parent: 5385a2ec98
Commit: 3d301977b3
@@ -12,14 +12,17 @@ repos:
- id: detect-private-key
- id: check-yaml
- id: check-added-large-files
- id: check-merge-conflict

# Black - code formatting
- repo: https://github.com/psf/black
rev: 20.8b1
hooks:
- id: black
# args: [--force-exclude, train.py]
args: [--force-exclude, train.py]

# Isort - import sorting
# - repo: https://github.com/PyCQA/isort
# rev: 5.7.0
# hooks:

@@ -28,17 +31,8 @@ repos:
# # other flags: https://pycqa.github.io/isort/docs/configuration/options/
# args: [--profile, black]

# - repo: https://github.com/pre-commit/mirrors-yapf
# rev: v0.30.0
# hooks:
# - id: yapf
# args: [--parallel, --in-place]

# - repo: https://gitlab.com/pycqa/flake8
# rev: 3.7.9
# hooks:
# - id: flake8

# MyPy - static type checking
# - repo: https://github.com/pre-commit/mirrors-mypy
# rev: v0.790
# hooks:

README.md (83 changed lines)
@@ -22,8 +22,8 @@ You should be able to easily modify behavior in [train.py](train.py) in case you

If you use this template please add <br>
[![](https://shields.io/badge/-Lightning--Hydra--Template-017F2F?style=flat&logo=github&labelColor=001A22)](https://github.com/hobogalaxy/lightning-hydra-template) <br>
to you `README.md`.
[![](https://shields.io/badge/-Lightning--Hydra--Template-017F2F?style=flat&logo=github&labelColor=303030)](https://github.com/hobogalaxy/lightning-hydra-template) <br>
to your `README.md`.
<br>

@@ -33,7 +33,7 @@ to you `README.md`.
- [Main Ideas](#main-ideas)
- [Some Notes](#some-notes)
- [Project Structure](#project-structure)
- [Quick Setup](#quick-setup)
- [Quickstart](#quickstart)
- [Your Superpowers](#your-superpowers)
- [Features](#features)
- [Main Project Configuration](#main-project-configuration)
@@ -139,14 +139,13 @@ The directory structure of new project looks like this:
<br>

## Quick Setup
Install dependencies:
## Quickstart
```yaml
# clone project
git clone https://github.com/hobogalaxy/lightning-hydra-template
cd lightning-hydra-template

# optionally create conda environment
# [OPTIONAL] create conda environment
conda env create -f conda_env_gpu.yaml -n testenv
conda activate testenv

@@ -154,14 +153,16 @@ conda activate testenv
pip install -r requirements.txt
```

When running `python train.py` you should see this:
When running `python train.py` you should see something like this:
<div align="center">

![](https://github.com/hobogalaxy/lightning-hydra-template/blob/resources/teminal.png)

</div>

### Your Superpowers
(click to expand)

<details>
<summary>Override any config parameter from command line</summary>
@@ -171,6 +172,7 @@ python train.py trainer.max_epochs=20 model.lr=0.0005

</details>

<details>
<summary>Train on GPU</summary>

@@ -180,6 +182,7 @@ python train.py trainer.gpus=1

</details>

<details>
<summary>Train model with any logger available in PyTorch Lightning, like <a href="https://wandb.ai/">Weights&Biases</a></summary>

@@ -197,6 +200,7 @@ python train.py logger=wandb

</details>

<details>
<summary>Train model with chosen experiment config</summary>

@@ -207,6 +211,7 @@ python train.py +experiment=exp_example_simple

</details>

<details>
<summary>Execute all experiments from folder</summary>

@@ -217,6 +222,7 @@ python train.py -m '+experiment=glob(*)'

</details>

<details>
<summary>Attach some callbacks to run</summary>

@@ -227,6 +233,31 @@ python train.py callbacks=default_callbacks

</details>

<details>
<summary>Easily debug</summary>

```yaml
# run 1 train, val and test loop, using only 1 batch
python train.py debug=True
```

</details>

<details>
<summary>Resume training from checkpoint</summary>

```yaml
# checkpoint can be either path or URL
# path should be absolute!
python train.py trainer.resume_from_checkpoint="/home/user/X/lightning-hydra-template/logs/runs/2021-02-28/16-50-49/checkpoints/last.ckpt"
# currently loading ckpt in Lightning doesn't resume logger experiment, this should change when v1.3 is released...
```

</details>

<details>
<summary>Create a sweep over some hyperparameters</summary>
@@ -238,32 +269,18 @@ python train.py -m datamodule.batch_size=32,64,128 model.lr=0.001,0.0005

</details>

<details>
<summary>Create a sweep over some hyperparameters with Optuna</summary>

```yaml
# this will run hyperparameter search defined in `configs/config_optuna.yaml`
# over chosen experiment config
python train.py -m --config-name config_optuna.yaml +experiment=exp_example_simple
```

</details>

<details>
<summary>Resume training from checkpoint</summary>
(TODO)
<!--
```yaml
# checkpoint can be either path or URL
# path should be either absolute or prefixed with `${work_dir}/`
# use quotes '' around argument or otherwise $ symbol breaks it
python train.py '+trainer.resume_from_checkpoint=${work_dir}/logs/runs/2021-02-28/16-50-49/checkpoints/last.ckpt'
```
-->

</details>

<br>
@@ -316,6 +333,11 @@ work_dir: ${hydra:runtime.cwd}
data_dir: ${work_dir}/data/

# use `python train.py debug=true` for easy debugging!
# (equivalent to running `python train.py trainer.fast_dev_run=True`)
debug: False

# pretty print config at the start of the run using Rich library
print_config: True
@@ -328,20 +350,27 @@ disable_warnings: False
disable_lightning_logs: False

# output paths for hydra logs
# hydra configuration
hydra:

# output paths for hydra logs
run:
dir: logs/runs/${now:%Y-%m-%d}/${now:%H-%M-%S}
sweep:
dir: logs/multiruns/${now:%Y-%m-%d_%H-%M-%S}
subdir: ${hydra.job.num}

# set your environment variables here
job:
env_set:
ENV_VAR_X: something
```
<br>

## Experiment Configuration
Location: [configs/experiment](configs/experiment)<br>
You can store many experiment configurations in this folder.<br>
You should store all your experiment configurations in this folder.<br>
### Simple Example
```yaml
# to execute this experiment run:
@@ -535,7 +564,7 @@ choosing metric names with '/' for wandb -->
<a href="https://pytorch.org/get-started/locally/"><img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-orange?logo=pytorch"></a>
<a href="https://pytorchlightning.ai/"><img alt="Lightning" src="https://img.shields.io/badge/-Lightning-blueviolet"></a>
<a href="https://hydra.cc/"><img alt="Config: Hydra" src="https://img.shields.io/badge/Config-Hydra-blue"></a>
[![](https://shields.io/badge/-Lightning--Hydra--Template-017F2F?style=flat&logo=github&labelColor=001A22)](https://github.com/hobogalaxy/lightning-hydra-template)
[![](https://shields.io/badge/-Lightning--Hydra--Template-017F2F?style=flat&logo=github&labelColor=303030)](https://github.com/hobogalaxy/lightning-hydra-template)

</div>

@@ -4,6 +4,7 @@ model_checkpoint:
save_top_k: 2 # save k best models (determined by above metric)
save_last: True # additionaly always save model from last epoch
mode: "max" # can be "max" or "min"
verbose: False
dirpath: 'checkpoints/'
filename: '{epoch:02d}'
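
As an aside (not part of the diff): a minimal sketch of how a callback entry like the `model_checkpoint` block above can be turned into an actual PyTorch Lightning callback with Hydra. The `_target_` path and the monitored metric name are assumptions for illustration, not values taken from this commit.

```python
# Hypothetical sketch: build a ModelCheckpoint callback from a Hydra-style config entry.
# The `_target_` path and the "val/acc" metric name are placeholders, not from this diff.
import hydra
from omegaconf import OmegaConf

cb_conf = OmegaConf.create(
    {
        "_target_": "pytorch_lightning.callbacks.ModelCheckpoint",
        "monitor": "val/acc",   # placeholder metric name
        "save_top_k": 2,        # save k best models (by the monitored metric)
        "save_last": True,      # additionally always save the model from the last epoch
        "mode": "max",
        "verbose": False,
        "dirpath": "checkpoints/",
        "filename": "{epoch:02d}",
    }
)

# Returns a ready-to-use ModelCheckpoint instance that can be passed to the Trainer.
checkpoint_callback = hydra.utils.instantiate(cb_conf)
```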

@@ -24,6 +24,12 @@ work_dir: ${hydra:runtime.cwd}
data_dir: ${work_dir}/data/

# use `python train.py debug=true` for easy debugging!
# this will run 1 train, val and test loop with only 1 batch
# (equivalent to running `python train.py trainer.fast_dev_run=True`)
debug: False

# pretty print config at the start of the run using Rich library
print_config: True

@@ -45,8 +51,9 @@ hydra:
subdir: ${hydra.job.num}

job:
# set here your environment variables
env_set:
# currently there are some issues with running sweeps alongside wandb
# https://github.com/wandb/client/issues/1314
# this env variable fixes that
# this env var fixes that
WANDB_START_METHOD: thread

@@ -4,11 +4,11 @@
# python train.py -m --config-name config_optuna.yaml +experiment=exp_example_simple logger=wandb

defaults:
# load everything from main config file
- config.yaml
# load everything from main config file
- config.yaml

# override sweeper to optuna!
- override hydra/sweeper: optuna
# override sweeper to Optuna!
- override hydra/sweeper: optuna

# choose metric which will be optimized by optuna

@@ -0,0 +1,36 @@
# @package _global_

defaults:
# load everything from main config file
- config.yaml

# override sweeper to Ray!
- override hydra/launcher: ray
# - override hydra/launcher: ray_aws

hydra:
launcher:
ray:
remote:
max_retries: 4
# InstanceType: p2.8xlarge
# asd

# sync_up:
# # source dir is relative in this case, assuming you are running from
# # <project_root>/hydra/plugins/hydra_ray_launcher/example
# # absolute path is also supported.
# source_dir: '.'
# # we leave target_dir to be null
# # as a result the files will be synced to a temp dir on remote cluster.
# # the temp dir will be cleaned up after the jobs are done.
# # recommend to leave target_dir to be null if you are syncing code/artifacts to remote cluster so you don't need
# # configure $PYTHONPATH on remote cluster
# include: ['model', '*.py']
# # No need to sync up config files.
# exclude: ['*']
# sync_down:
# include: ['*.pt', '*/']
# # No need to sync down config files.
# exclude: ['*']

@@ -8,3 +8,4 @@ progress_bar_refresh_rate: 10
weights_summary: null
default_root_dir: "lightning_logs/"
fast_dev_run: False
resume_from_checkpoint: null

@@ -8,6 +8,7 @@ pytorch-lightning>=1.2.1
hydra-core==1.1.0.dev3
hydra_colorlog>=1.0.0
hydra-optuna-sweeper>=0.9.0rc2
hydra-ray-launcher==0.1.2

# --------- loggers --------- #
wandb>=0.10.20

@@ -20,6 +21,7 @@ wandb>=0.10.20
pre-commit
# isort
# black
# mypy

# --------- others --------- #
rich>=9.12.3

@@ -5,7 +5,7 @@ from PIL import Image

def predict():
"""
This method is example of inference with a trained model.
This is example of inference with a trained model.
It Loads trained image classification model from checkpoint.
Then it loads example image and predicts its label.
"""
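
As an aside (not part of the diff): a minimal sketch of the inference flow the docstring above describes, assuming an MNIST-style LightningModule; the module path, class name, checkpoint path and image path are placeholders.

```python
# Hypothetical sketch of the inference flow described above; names and paths are placeholders.
import torch
from PIL import Image
from torchvision import transforms

from src.models.mnist_model import LitModelMNIST  # assumed module/class name

# load trained image classification model from checkpoint
model = LitModelMNIST.load_from_checkpoint("path/to/checkpoints/last.ckpt")
model.eval()

# load example image and convert it to a (1, 1, H, W) tensor
img = Image.open("path/to/digit.png").convert("L")
x = transforms.ToTensor()(img).unsqueeze(0)

# predict its label
with torch.no_grad():
    logits = model(x)
print("Predicted label:", logits.argmax(dim=-1).item())
```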

@@ -27,24 +27,24 @@ import logging

def extras(config: DictConfig):
"""A couple of optional utilities, controlled with variables in main config file.
Simply delete those if you don' want them.

"""A couple of optional utilities, controlled by main config file.
- easier access to debug mode
- forcing debug friendly configuration
- disabling warnings
- disabling lightning logs
Args:
config (DictConfig): [description]
"""

# [OPTIONAL] Disable python warnings if <disable_warnings=True>
if config.get("disable_warnings"):
log.info(f"Disabling python warnings! <{config.disable_warnings=}>")
warnings.filterwarnings("ignore")
# make it possible to add new keys to config
OmegaConf.set_struct(config, False)

# [OPTIONAL] Disable Lightning logs if <disable_lightning_logs=True>
if config.get("disable_lightning_logs"):
log.info(f"Disabling lightning logs! {config.disable_lightning_logs=}>")
logging.getLogger("lightning").setLevel(logging.ERROR)
# [OPTIONAL] Set <config.trainer.fast_dev_run=True> if <config.debug=True>
if config.get("debug"):
log.info(f"Running in debug mode! <{config.debug=}>")
config.trainer.fast_dev_run = True

# [OPTIONAL] Force debugger friendly configuration if <trainer.fast_dev_run=True>
# [OPTIONAL] Force debugger friendly configuration if <config.trainer.fast_dev_run=True>
if config.trainer.get("fast_dev_run"):
log.info(
f"Forcing debugger friendly configuration! "
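
For context (not part of the diff): the `OmegaConf.set_struct` calls above toggle OmegaConf's struct mode, which controls whether new keys may be added to the config. A small self-contained illustration:

```python
from omegaconf import OmegaConf

cfg = OmegaConf.create({"trainer": {"gpus": 0}})

OmegaConf.set_struct(cfg, False)  # struct mode off: new keys can be added
cfg.trainer.fast_dev_run = True   # works, key is created on the fly

OmegaConf.set_struct(cfg, True)   # struct mode on: unknown keys are rejected
# cfg.trainer.some_typo = 1       # would now raise an error (key not in struct)
```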

@@ -56,17 +56,26 @@ def extras(config: DictConfig):
if config.datamodule.get("num_workers"):
config.datamodule.num_workers = 0

# [OPTIONAL] Pretty print config using Rich library if <print_config=True>
if config.get("print_config"):
log.info(f"Pretty printing config with Rich! <{config.print_config=}>")
print_config(config)
# [OPTIONAL] Disable python warnings if <config.disable_warnings=True>
if config.get("disable_warnings"):
log.info(f"Disabling python warnings! <{config.disable_warnings=}>")
warnings.filterwarnings("ignore")

# [OPTIONAL] Disable Lightning logs if <config.disable_lightning_logs=True>
if config.get("disable_lightning_logs"):
log.info(f"Disabling lightning logs! {config.disable_lightning_logs=}>")
logging.getLogger("lightning").setLevel(logging.ERROR)

# disable adding new keys to config
OmegaConf.set_struct(config, True)


def print_config(config: DictConfig):
def print_config(config: DictConfig, resolve: bool = True):
"""Prints content of Hydra config using Rich library.

Args:
config (DictConfig): [description]
resolve (bool, optional): Whether to resolve reference fields in Hydra config.
"""

# TODO print main config path and experiment config path
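
For context (not part of the diff): the new `resolve` flag is passed through to `OmegaConf.to_yaml` in the hunks below and controls whether interpolations like `${work_dir}` are expanded when the config is printed. A small illustration:

```python
from omegaconf import OmegaConf

cfg = OmegaConf.create({"work_dir": "/home/user/project", "data_dir": "${work_dir}/data/"})

print(OmegaConf.to_yaml(cfg, resolve=False))  # keeps the interpolation: data_dir: ${work_dir}/data/
print(OmegaConf.to_yaml(cfg, resolve=True))   # expands it: data_dir: /home/user/project/data/
```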

@@ -77,15 +86,15 @@ def print_config(config: DictConfig):

tree = Tree(f":gear: TRAINING CONFIG", style=style, guide_style=style)

trainer = OmegaConf.to_yaml(config["trainer"], resolve=True)
trainer = OmegaConf.to_yaml(config["trainer"], resolve=resolve)
trainer_branch = tree.add("Trainer", style=style, guide_style=style)
trainer_branch.add(Syntax(trainer, "yaml"))

model = OmegaConf.to_yaml(config["model"], resolve=True)
model = OmegaConf.to_yaml(config["model"], resolve=resolve)
model_branch = tree.add("Model", style=style, guide_style=style)
model_branch.add(Syntax(model, "yaml"))

datamodule = OmegaConf.to_yaml(config["datamodule"], resolve=True)
datamodule = OmegaConf.to_yaml(config["datamodule"], resolve=resolve)
datamodule_branch = tree.add("Datamodule", style=style, guide_style=style)
datamodule_branch.add(Syntax(datamodule, "yaml"))

@@ -93,7 +102,7 @@ def print_config(config: DictConfig):
if "callbacks" in config:
for cb_name, cb_conf in config["callbacks"].items():
cb = callbacks_branch.add(cb_name, style=style, guide_style=style)
cb.add(Syntax(OmegaConf.to_yaml(cb_conf, resolve=True), "yaml"))
cb.add(Syntax(OmegaConf.to_yaml(cb_conf, resolve=resolve), "yaml"))
else:
callbacks_branch.add("None")

@@ -101,7 +110,7 @@ def print_config(config: DictConfig):
if "logger" in config:
for lg_name, lg_conf in config["logger"].items():
lg = logger_branch.add(lg_name, style=style, guide_style=style)
lg.add(Syntax(OmegaConf.to_yaml(lg_conf, resolve=True), "yaml"))
lg.add(Syntax(OmegaConf.to_yaml(lg_conf, resolve=resolve), "yaml"))
else:
logger_branch.add("None")

@@ -121,6 +130,7 @@ def log_hparams_to_all_loggers(
logger: List[pl.loggers.LightningLoggerBase],
):
"""This method controls which parameters from Hydra config are saved by Lightning loggers.
It additionaly saves sizes of each dataset and number of trainable model parameters.

Args:
config (DictConfig): [description]
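
As an aside (not part of the diff, which truncates the function body here): a hedged sketch of what saving chosen config sections plus dataset sizes and the trainable-parameter count to every logger can look like. The hparams keys and datamodule attribute names are placeholders, not the template's actual ones.

```python
# Hypothetical sketch only; key names and datamodule attributes are placeholders.
def log_hparams_sketch(config, model, datamodule, logger):
    hparams = {
        "trainer": config["trainer"],
        "model": config["model"],
        "datamodule": config["datamodule"],
        # dataset sizes (attribute names depend on the concrete datamodule)
        "train_size": len(datamodule.data_train),
        "val_size": len(datamodule.data_val),
        "test_size": len(datamodule.data_test),
        # number of trainable model parameters
        "trainable_params": sum(p.numel() for p in model.parameters() if p.requires_grad),
    }
    # send the same dict to every configured logger
    for lg in logger:
        lg.log_hyperparams(hparams)  # available on every LightningLoggerBase
```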

@@ -27,6 +27,12 @@ datamodule.num_workers=4 datamodule.pin_memory=True \
print_config=false

echo "TEST 3"
echo "Train with 16-bit precision (1 epoch)"
python train.py trainer.gpus=-1 trainer.max_epochs=2 \
+trainer.precision=16 \
print_config=false

echo "TEST 4"
echo "Train with mixed-precision (apex, 2 epochs)"
python train.py trainer.gpus=-1 trainer.max_epochs=2 \
+trainer.amp_backend='apex' +trainer.amp_level='O2' \

@@ -1,23 +0,0 @@
#####################################################
# export WANDB_START_METHOD=thread
# python hydra_wandb_test.py -m +some_param=1,2,3,4
#####################################################

# import os, sys
# sys.path.insert(1, os.path.join(sys.path[0], ".."))
# print(os.path.abspath(os.curdir))

import hydra
import wandb

@hydra.main(config_path="../configs/", config_name="config.yaml")
def main(config):
wandb.init(project="env_tests")
wandb.finish()

if __name__ == "__main__":
main()

@@ -16,65 +16,60 @@ export PYTHONWARNINGS="ignore"

echo "TEST 1"
echo "fast_dev_run=true (run 1 train, val and test loop using 1 batch)"
python train.py trainer.fast_dev_run=True \
echo "Debug mode (run 1 train, val and test loop using 1 batch)"
python train.py debug=True \
print_config=false

echo "TEST 2"
echo "Overfit to 10 bathes"
echo "Overfit to 10 batches (10 epochs)"
python train.py +trainer.overfit_batches=10 \
trainer.min_epochs=10 trainer.max_epochs=10 \
print_config=false

echo "TEST 3"
echo "Train 1 epoch on CPU"
echo "Train on CPU (1 epoch)"
python train.py trainer.gpus=0 trainer.max_epochs=1 \
print_config=false

echo "TEST 4"
echo "Train on 25% of data"
echo "Train on 25% of data (1 epoch)"
python train.py trainer.max_epochs=1 \
+trainer.limit_train_batches=0.25 +trainer.limit_val_batches=0.25 +trainer.limit_test_batches=0.25 \
print_config=false

echo "TEST 5"
echo "Train on 15 train batches, 10 val batches, 5 test batches"
echo "Train on 15 train batches, 10 val batches, 5 test batches (1 epoch)"
python train.py trainer.max_epochs=1 \
+trainer.limit_train_batches=15 +trainer.limit_val_batches=10 +trainer.limit_test_batches=5 \
print_config=false

echo "TEST 6"
echo "Run all experiment configs for 2 epochs"
echo "Run all experiment configs (2 epochs)"
python train.py -m '+experiment=glob(*)' trainer.max_epochs=2 \
print_config=false

echo "TEST 7"
echo "Run default hydra sweep (executes 4 different combinations with fast_dev_run=True)"
echo "Run default hydra sweep (executes 4 different combinations in debug mode)"
python train.py -m datamodule.batch_size=32,64 model.lr=0.001,0.003 \
trainer.fast_dev_run=True \
debug=True \
print_config=false

echo "TEST 8"
echo "Run with 16 bit precision"
python train.py trainer.max_epochs=1 +trainer.precision=16 \
print_config=false

echo "TEST 9"
echo "Run with gradient accumulation"
echo "Run with gradient accumulation (1 epoch)"
python train.py trainer.max_epochs=1 +trainer.accumulate_grad_batches=10 \
print_config=false

echo "TEST 10"
echo "Run validation loop twice per epoch"
echo "TEST 9"
echo "Run validation loop twice per epoch (2 epochs)"
python train.py trainer.max_epochs=2 +trainer.val_check_interval=0.5 \
print_config=false

echo "TEST 11"
echo "TEST 10"
echo "Run with CSVLogger (2 epochs)"
python train.py logger=csv trainer.min_epochs=5 trainer.max_epochs=2 \
print_config=false

echo "TEST 12"
echo "TEST 11"
echo "Run with TensorBoardLogger (2 epochs)"
python train.py logger=tensorboard trainer.min_epochs=5 trainer.max_epochs=2 \
print_config=false

train.py (13 changed lines)

@@ -10,8 +10,6 @@ import hydra

# normal imports
from typing import List
import warnings
import logging

# src imports
from src.utils import template_utils as utils

@@ -20,12 +18,18 @@ from src.utils import template_utils as utils
def train(config: DictConfig):

# A couple of optional utilities:
# - disabling warnings,
# - printing config with Rich,
# - easier access to debug mode
# - forcing debug friendly configuration
# - disabling warnings
# - disabling lightning logs
# You can safely get rid of this line if you don't want those
utils.extras(config)

# Pretty print config using Rich library
if config.get("print_config"):
log.info(f"Pretty printing config with Rich! <{config.print_config=}>")
utils.print_config(config, resolve=True)

# Set seed for random number generators in pytorch, numpy and python.random
if "seed" in config:
seed_everything(config.seed)

@@ -101,6 +105,7 @@ def train(config: DictConfig):

@hydra.main(config_path="configs/", config_name="config.yaml")
def main(config: DictConfig):
pass
return train(config)