Commit caa566dfd8
.pre-commit-config.yaml
@@ -0,0 +1,45 @@
+default_language_version:
+    python: python3.8
+
+repos:
+    - repo: https://github.com/pre-commit/pre-commit-hooks
+      rev: v3.2.0
+      # supported hooks: https://pre-commit.com/hooks.html
+      hooks:
+          - id: end-of-file-fixer
+          - id: trailing-whitespace
+          - id: double-quote-string-fixer
+          - id: debug-statements
+          - id: detect-private-key
+          - id: check-yaml
+          - id: check-added-large-files
+          # - id: check-merge-conflict
+
+    # - repo: https://github.com/psf/black
+    #   rev: 20.8b1
+    #   hooks:
+    #       - id: black
+
+    # - repo: https://github.com/PyCQA/isort
+    #   rev: 5.7.0
+    #   hooks:
+    #       - id: isort
+    #         # profiles: https://pycqa.github.io/isort/docs/configuration/profiles/
+    #         # other flags: https://pycqa.github.io/isort/docs/configuration/options/
+    #         args: [--profile, black]
+
+    # - repo: https://github.com/pre-commit/mirrors-yapf
+    #   rev: v0.30.0
+    #   hooks:
+    #       - id: yapf
+    #         args: [--parallel, --in-place]
+
+    # - repo: https://gitlab.com/pycqa/flake8
+    #   rev: 3.7.9
+    #   hooks:
+    #       - id: flake8
+
+    # - repo: https://github.com/pre-commit/mirrors-mypy
+    #   rev: v0.790
+    #   hooks:
+    #       - id: mypy
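Worth noting: the double-quote-string-fixer hook is what drives the large quote churn later in this commit. It rewrites double-quoted Python string literals to single quotes whenever the swap needs no extra escaping. A minimal before/after sketch (hypothetical snippet, not from the repo):

# before the hook runs
greeting = "hello"   # will be rewritten
quote = "it's fine"  # left alone: converting would require escaping a quote

# after `pre-commit run double-quote-string-fixer --all-files`
greeting = 'hello'
quote = "it's fine"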
configs/config.yaml
@@ -15,7 +15,7 @@ defaults:

 # path to original working directory (that `train.py` was executed from in command line)
 # hydra hijacks working directory by changing it to the current log directory,
-# so it's useful to have path to original working directory as a special variable
+# so it's useful to have path to original work dir as a special variable
 # read more here: https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory
 work_dir: ${hydra:runtime.cwd}

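Since Hydra chdirs into the run's log directory before user code executes, anything that needs the launch directory has to recover it explicitly. A small sketch of the two recovery routes, using the `hydra.utils` helpers this commit already imports in its utils module (the `data/mnist` path is illustrative):

import hydra
from hydra.utils import get_original_cwd, to_absolute_path
from omegaconf import DictConfig


@hydra.main(config_path='configs/', config_name='config.yaml')
def main(config: DictConfig):
    # Hydra has already chdir'd into logs/runs/<timestamp> at this point.
    launch_dir = get_original_cwd()             # directory train.py was run from
    data_path = to_absolute_path('data/mnist')  # resolved against launch_dir

    # The config route used above: work_dir: ${hydra:runtime.cwd}
    print(launch_dir, data_path, config.work_dir)


if __name__ == '__main__':
    main()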
@@ -28,6 +28,14 @@ data_dir: ${work_dir}/data/
 print_config: True


+# disable python warnings if they annoy you
+disable_warnings: True
+
+
+# disable lightning logs if they annoy you
+disable_lightning_logs: False
+
+
 # output paths for hydra logs
 hydra:
     run:
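Both toggles are consumed by train.py later in this commit; standalone, the two switches amount to:

import logging
import warnings

# what disable_warnings: True switches on
warnings.filterwarnings('ignore')

# what disable_lightning_logs: True switches on ('lightning' is the
# logger name pytorch_lightning logs under in the versions this targets)
logging.getLogger('lightning').setLevel(logging.ERROR)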
@@ -35,3 +43,10 @@ hydra:
     sweep:
         dir: logs/multiruns/${now:%Y-%m-%d_%H-%M-%S}
         subdir: ${hydra.job.num}
+
+    job:
+        env_set:
+            # currently there are some issues with running sweeps alongside wandb
+            # https://github.com/wandb/client/issues/1314
+            # this env variable fixes that
+            WANDB_START_METHOD: thread
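Setting the variable through `hydra.job.env_set` keeps the workaround inside the config; the plain-Python equivalent, for anyone launching outside Hydra, would be:

import os

# must be set before wandb initializes its run process
os.environ['WANDB_START_METHOD'] = 'thread'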
configs/trainer/default_trainer.yaml
@@ -3,7 +3,7 @@ gpus: 0 # set -1 to train on all GPUs available, set 0 to train on CPU only
 min_epochs: 1
 max_epochs: 10
 gradient_clip_val: 0.5
-num_sanity_val_steps: 3
+num_sanity_val_steps: 2
 progress_bar_refresh_rate: 20
 weights_summary: null
 default_root_dir: "lightning_logs/"
requirements.txt
@@ -17,10 +17,9 @@ wandb>=0.10.20
 # tensorboard

 # --------- linters --------- #
-# black>=20.8b1
-# flake8>=3.8.4
-# pylint>=2.7.1
-# isort>=5.7.0
+pre-commit
+# isort
+# black

 # --------- others --------- #
 rich>=9.12.3
src/utils/template_utils.py
@@ -1,11 +1,11 @@
-# pytorch lightning imports
+# lightning imports
 import pytorch_lightning as pl

 # hydra imports
 from omegaconf import DictConfig, OmegaConf
 from hydra.utils import get_original_cwd, to_absolute_path

-# loggers
+# logger imports
 import wandb
 from pytorch_lightning.loggers.wandb import WandbLogger

@@ -15,9 +15,9 @@ from pytorch_lightning.loggers.wandb import WandbLogger
 # from pytorch_lightning.loggers.tensorboard import TensorBoardLogger

 # rich imports
-from rich import print
 from rich.syntax import Syntax
 from rich.tree import Tree
+from rich import print

 # normal imports
 from typing import List
@@ -34,41 +34,41 @@ def print_config(config: DictConfig):
     # directory = to_absolute_path("configs/config.yaml")
     # print(f"Main config path: [link file://{directory}]{directory}")

-    style = "dim"
+    style = 'dim'

-    tree = Tree(f":gear: FULL HYDRA CONFIG", style=style, guide_style=style)
+    tree = Tree(f':gear: HYDRA CONFIG', style=style, guide_style=style)

-    trainer = OmegaConf.to_yaml(config["trainer"], resolve=True)
-    trainer_branch = tree.add("Trainer", style=style, guide_style=style)
-    trainer_branch.add(Syntax(trainer, "yaml"))
+    trainer = OmegaConf.to_yaml(config['trainer'], resolve=True)
+    trainer_branch = tree.add('Trainer', style=style, guide_style=style)
+    trainer_branch.add(Syntax(trainer, 'yaml'))

-    model = OmegaConf.to_yaml(config["model"], resolve=True)
-    model_branch = tree.add("Model", style=style, guide_style=style)
-    model_branch.add(Syntax(model, "yaml"))
+    model = OmegaConf.to_yaml(config['model'], resolve=True)
+    model_branch = tree.add('Model', style=style, guide_style=style)
+    model_branch.add(Syntax(model, 'yaml'))

-    datamodule = OmegaConf.to_yaml(config["datamodule"], resolve=True)
-    datamodule_branch = tree.add("Datamodule", style=style, guide_style=style)
-    datamodule_branch.add(Syntax(datamodule, "yaml"))
+    datamodule = OmegaConf.to_yaml(config['datamodule'], resolve=True)
+    datamodule_branch = tree.add('Datamodule', style=style, guide_style=style)
+    datamodule_branch.add(Syntax(datamodule, 'yaml'))

-    callbacks_branch = tree.add("Callbacks", style=style, guide_style=style)
-    if "callbacks" in config:
-        for cb_name, cb_conf in config["callbacks"].items():
+    callbacks_branch = tree.add('Callbacks', style=style, guide_style=style)
+    if 'callbacks' in config:
+        for cb_name, cb_conf in config['callbacks'].items():
             cb = callbacks_branch.add(cb_name, style=style, guide_style=style)
-            cb.add(Syntax(OmegaConf.to_yaml(cb_conf, resolve=True), "yaml"))
+            cb.add(Syntax(OmegaConf.to_yaml(cb_conf, resolve=True), 'yaml'))
     else:
-        callbacks_branch.add("None")
+        callbacks_branch.add('None')

-    logger_branch = tree.add("Logger", style=style, guide_style=style)
-    if "logger" in config:
-        for lg_name, lg_conf in config["logger"].items():
+    logger_branch = tree.add('Logger', style=style, guide_style=style)
+    if 'logger' in config:
+        for lg_name, lg_conf in config['logger'].items():
             lg = logger_branch.add(lg_name, style=style, guide_style=style)
-            lg.add(Syntax(OmegaConf.to_yaml(lg_conf, resolve=True), "yaml"))
+            lg.add(Syntax(OmegaConf.to_yaml(lg_conf, resolve=True), 'yaml'))
     else:
-        logger_branch.add("None")
+        logger_branch.add('None')

-    seed = config.get("seed", "None")
-    seed_branch = tree.add(f"Seed", style=style, guide_style=style)
-    seed_branch.add(str(seed) + "\n")
+    seed = config.get('seed', 'None')
+    seed_branch = tree.add(f'Seed', style=style, guide_style=style)
+    seed_branch.add(str(seed) + '\n')

     print(tree)

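For orientation, a stripped-down sketch of the pattern `print_config` builds on: a `rich` Tree with one `Syntax`-highlighted YAML branch per config section (the config contents here are made up):

from omegaconf import OmegaConf
from rich import print
from rich.syntax import Syntax
from rich.tree import Tree

config = OmegaConf.create({'trainer': {'gpus': 0, 'max_epochs': 10}})

style = 'dim'
tree = Tree(':gear: HYDRA CONFIG', style=style, guide_style=style)

# one branch per config section, rendered as highlighted YAML
branch = tree.add('Trainer', style=style, guide_style=style)
branch.add(Syntax(OmegaConf.to_yaml(config['trainer'], resolve=True), 'yaml'))

print(tree)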
@@ -95,32 +95,32 @@ def log_hparams_to_all_loggers(
     hparams = {}

     # save all params of model, datamodule and trainer
-    hparams.update(config["model"])
-    hparams.update(config["datamodule"])
-    hparams.update(config["trainer"])
-    hparams.pop("_target_")
+    hparams.update(config['model'])
+    hparams.update(config['datamodule'])
+    hparams.update(config['trainer'])
+    hparams.pop('_target_')

     # save seed
-    hparams["seed"] = config.get("seed", "None")
+    hparams['seed'] = config.get('seed', 'None')

     # save targets
-    hparams["_class_model"] = config["model"]["_target_"]
-    hparams["_class_datamodule"] = config["datamodule"]["_target_"]
+    hparams['_class_model'] = config['model']['_target_']
+    hparams['_class_datamodule'] = config['datamodule']['_target_']

     # save sizes of each dataset
-    if hasattr(datamodule, "data_train") and datamodule.data_train:
-        hparams["train_size"] = len(datamodule.data_train)
-    if hasattr(datamodule, "data_val") and datamodule.data_val:
-        hparams["val_size"] = len(datamodule.data_val)
-    if hasattr(datamodule, "data_test") and datamodule.data_test:
-        hparams["test_size"] = len(datamodule.data_test)
+    if hasattr(datamodule, 'data_train') and datamodule.data_train:
+        hparams['train_size'] = len(datamodule.data_train)
+    if hasattr(datamodule, 'data_val') and datamodule.data_val:
+        hparams['val_size'] = len(datamodule.data_val)
+    if hasattr(datamodule, 'data_test') and datamodule.data_test:
+        hparams['test_size'] = len(datamodule.data_test)

     # save number of model parameters
-    hparams["#params_total"] = sum(p.numel() for p in model.parameters())
-    hparams["#params_trainable"] = sum(
+    hparams['#params_total'] = sum(p.numel() for p in model.parameters())
+    hparams['#params_trainable'] = sum(
         p.numel() for p in model.parameters() if p.requires_grad
     )
-    hparams["#params_not_trainable"] = sum(
+    hparams['#params_not_trainable'] = sum(
         p.numel() for p in model.parameters() if not p.requires_grad
     )

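The parameter counts logged above use a standard PyTorch idiom; a self-contained sketch with a toy stand-in for the LightningModule:

import torch.nn as nn

model = nn.Linear(in_features=28 * 28, out_features=10)

total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)

# 784 weights per output unit plus 10 biases: 7850 total, all trainable here
print(total, trainable, total - trainable)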
tests/logger_tests.sh
@@ -1,21 +0,0 @@
-# TESTS FOR DIFFERENT LOGGERS
-# TO EXECUTE:
-# bash tests/logger_tests.sh
-
-# conda activate testenv
-
-# Test CSV logger
-echo TEST 1
-python train.py logger=csv_logger trainer.min_epochs=3 trainer.max_epochs=3 trainer.gpus=1
-
-# # Test Weights&Biases logger
-echo TEST 2
-python train.py logger=wandb logger.wandb.project="env_tests" trainer.min_epochs=10 trainer.max_epochs=10 trainer.gpus=1
-
-# Test TensorBoard logger
-echo TEST 3
-python train.py logger=tensorboard trainer.min_epochs=10 trainer.max_epochs=10 trainer.gpus=1
-
-# Test many loggers at once
-echo TEST 4
-python train.py logger=many_loggers trainer.min_epochs=10 trainer.max_epochs=10 trainer.gpus=1
tests/quick_tests.sh
@@ -1,44 +0,0 @@
-# THESE ARE JUST A COUPLE OF QUICK EXPERIMENTS TO TEST IF YOUR MODEL DOESN'T CRASH UNDER DIFFERENT CONDITIONS
-# TO EXECUTE:
-# bash tests/quick_tests.sh
-
-# conda activate testenv
-
-export PYTHONWARNINGS="ignore"
-
-print_test_name() {
-    termwidth="$(tput cols)"
-    padding="$(printf '%0.1s' ={1..500})"
-    printf '\e[33m%*.*s %s %*.*s\n\e[0m' 0 "$(((termwidth-2-${#1})/2))" "$padding" "$1" 0 "$(((termwidth-1-${#1})/2))" "$padding"
-}
-
-
-# Test for CPU
-print_test_name "TEST 1"
-python train.py trainer.gpus=0 trainer.max_epochs=1 print_config=false
-
-# Test for GPU
-print_test_name "TEST 2"
-python train.py trainer.gpus=1 trainer.max_epochs=1 print_config=false
-
-# Test multiple workers and cuda pinned memory
-print_test_name "TEST 3"
-python train.py trainer.gpus=1 trainer.max_epochs=2 print_config=false \
-    datamodule.num_workers=4 datamodule.pin_memory=True
-
-# Test all experiment configs
-print_test_name "TEST 4"
-python train.py -m '+experiment=glob(*)' trainer.gpus=1 trainer.max_epochs=3 print_config=false
-
-# Test with debug trainer
-print_test_name "TEST 5"
-python train.py trainer=debug_trainer print_config=false
-
-# Overfit to 10 batches
-print_test_name "TEST 6"
-python train.py trainer.min_epochs=20 trainer.max_epochs=20 +trainer.overfit_batches=10 print_config=false
-
-# Test default hydra sweep over hyperparameters (runs 4 different combinations for 1 epoch)
-print_test_name "TEST 7"
-python train.py -m datamodule.batch_size=32,64 model.lr=0.001,0.003 print_config=false \
-    trainer.gpus=1 trainer.max_epochs=1
@@ -0,0 +1,97 @@
+#!/bin/bash
+# These are just a couple of quick experiments to test if your model doesn't crash under different conditions
+
+# To execute:
+# bash tests/quick_tests.sh
+
+# Method for printing test name
+echo() {
+    termwidth="$(tput cols)"
+    padding="$(printf '%0.1s' ={1..500})"
+    printf '\e[33m%*.*s %s %*.*s\n\e[0m' 0 "$(((termwidth-2-${#1})/2))" "$padding" "$1" 0 "$(((termwidth-1-${#1})/2))" "$padding"
+}
+
+# Make python hide warnings
+export PYTHONWARNINGS="ignore"
+
+
+# Test fast_dev_run (runs for 1 train, 1 val and 1 test batch)
+echo "TEST 1"
+python train.py +trainer.fast_dev_run=True \
+    print_config=false
+
+# Overfit to 10 batches
+echo "TEST 2"
+python train.py +trainer.overfit_batches=10 \
+    trainer.min_epochs=20 trainer.max_epochs=20 \
+    print_config=false
+
+# Test 1 epoch on CPU
+echo "TEST 3"
+python train.py trainer.gpus=0 trainer.max_epochs=1 \
+    print_config=false
+
+# Test 1 epoch on GPU
+echo "TEST 4"
+python train.py trainer.gpus=1 trainer.max_epochs=1 \
+    print_config=false
+
+# Test on 25% of data
+echo "TEST 5"
+python train.py trainer.max_epochs=1 \
+    +trainer.limit_train_batches=0.25 +trainer.limit_val_batches=0.25 +trainer.limit_test_batches=0.25 \
+    print_config=false
+
+# Test on 15 train batches, 10 val batches, 5 test batches
+echo "TEST 6"
+python train.py trainer.max_epochs=1 \
+    +trainer.limit_train_batches=15 +trainer.limit_val_batches=10 +trainer.limit_test_batches=5 \
+    print_config=false
+
+# Test all experiment configs
+echo "TEST 7"
+python train.py -m '+experiment=glob(*)' trainer.gpus=1 trainer.max_epochs=2 \
+    print_config=false
+
+# Test default hydra sweep over hyperparameters (runs 4 different combinations with fast_dev_run)
+echo "TEST 8"
+python train.py -m datamodule.batch_size=32,64 model.lr=0.001,0.003 \
+    +trainer.fast_dev_run=True \
+    print_config=false
+
+# Test multiple workers and cuda pinned memory
+echo "TEST 9"
+python train.py trainer.gpus=1 trainer.max_epochs=2 \
+    datamodule.num_workers=4 datamodule.pin_memory=True \
+    print_config=false
+
+# Test 16 bit precision
+echo "TEST 10"
+python train.py trainer.gpus=1 trainer.max_epochs=1 +trainer.precision=16 \
+    print_config=false
+
+# Test gradient accumulation
+echo "TEST 11"
+python train.py trainer.gpus=1 trainer.max_epochs=1 +trainer.accumulate_grad_batches=10 \
+    print_config=false
+
+# Test running validation loop twice per epoch
+echo "TEST 12"
+python train.py trainer.gpus=1 trainer.max_epochs=2 +trainer.val_check_interval=0.5 \
+    print_config=false
+
+# Test CSV logger (5 epochs)
+echo "TEST 13"
+python train.py logger=csv_logger trainer.min_epochs=5 trainer.max_epochs=5 trainer.gpus=1 \
+    print_config=false
+
+# Test TensorBoard logger (5 epochs)
+echo "TEST 14"
+python train.py logger=tensorboard trainer.min_epochs=5 trainer.max_epochs=5 trainer.gpus=1 \
+    print_config=false
+
+# Test mixed-precision training
+echo "TEST 15"
+python train.py trainer.gpus=1 trainer.max_epochs=3 \
+    +trainer.amp_backend='apex' +trainer.amp_level='O2' \
+    print_config=false
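Each override in these smoke tests maps onto a `pytorch_lightning.Trainer` constructor argument once Hydra instantiates the trainer; a hedged sketch of the correspondence (a real test would set only the flags it needs):

from pytorch_lightning import Trainer

# what `+trainer.fast_dev_run=True` expands to:
trainer = Trainer(fast_dev_run=True)  # 1 train/val/test batch, then exit

# the remaining overrides map the same way, e.g.:
#   Trainer(overfit_batches=10)          # refit the same 10 batches
#   Trainer(limit_train_batches=0.25)    # fraction, or an int batch count
#   Trainer(gpus=1, precision=16)        # 16-bit mixed precision (needs GPU)
#   Trainer(accumulate_grad_batches=10)  # 10x effective batch size
#   Trainer(val_check_interval=0.5)      # validate twice per training epoch
print(trainer.fast_dev_run)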
tests/sweep_tests.sh
@@ -1,18 +1,20 @@
-# TESTS FOR HYPERPARAMETER SWEEPS
-# TO EXECUTE:
+#!/bin/bash
+# Test hyperparameter sweeps
+
+# To execute:
 # bash tests/sweep_tests.sh

-# conda activate testenv
-
-
-# currently there are some issues with running sweeps alongside wandb
-# https://github.com/wandb/client/issues/1314
-# this env variable fixes that
-export WANDB_START_METHOD=thread
-
+echo() {
+    termwidth="$(tput cols)"
+    padding="$(printf '%0.1s' ={1..500})"
+    printf '\e[33m%*.*s %s %*.*s\n\e[0m' 0 "$(((termwidth-2-${#1})/2))" "$padding" "$1" 0 "$(((termwidth-1-${#1})/2))" "$padding"
+}
+
+# Make python hide warnings
+export PYTHONWARNINGS="ignore"

 # Test default hydra sweep with wandb logging
-echo TEST 1
+echo "TEST 1"
 python train.py -m datamodule.batch_size=64,128 model.lr=0.001,0.003 \
     +experiment=exp_example_simple \
     trainer.gpus=1 trainer.max_epochs=2 seed=12345 \
@@ -20,7 +22,7 @@ datamodule.num_workers=12 datamodule.pin_memory=True \
 logger=wandb logger.wandb.project="env_tests" logger.wandb.group="DefaultSweep_MNIST_SimpleDenseNet"

 # Test optuna sweep with wandb logging
-echo TEST 2
+echo "TEST 2"
 python train.py -m --config-name config_optuna.yaml \
     +experiment=exp_example_simple \
     trainer.gpus=1 trainer.max_epochs=5 seed=12345 \
train.py
@@ -1,4 +1,4 @@
-# pytorch lightning imports
+# lightning imports
 from pytorch_lightning import LightningModule, LightningDataModule, Callback, Trainer
 from pytorch_lightning.loggers import LightningLoggerBase
 from pytorch_lightning import seed_everything
@@ -9,6 +9,8 @@ import hydra

 # normal imports
 from typing import List
+import warnings
+import logging

 # src imports
 from src.utils import template_utils as utils
@@ -16,39 +18,47 @@ from src.utils import template_utils as utils

 def train(config: DictConfig):

+    # Disable python warnings
+    if config['disable_warnings']:
+        warnings.filterwarnings('ignore')
+
+    # Disable PyTorch Lightning logs
+    if config['disable_lightning_logs']:
+        logging.getLogger('lightning').setLevel(logging.ERROR)
+
     # Pretty print config using Rich library
-    if config["print_config"]:
+    if config['print_config']:
         utils.print_config(config)

     # Set seed for random number generators in pytorch, numpy and python.random
-    if "seed" in config:
-        seed_everything(config["seed"])
+    if 'seed' in config:
+        seed_everything(config['seed'])

     # Init PyTorch Lightning model ⚡
-    model: LightningModule = hydra.utils.instantiate(config["model"])
+    model: LightningModule = hydra.utils.instantiate(config['model'])

     # Init PyTorch Lightning datamodule ⚡
-    datamodule: LightningDataModule = hydra.utils.instantiate(config["datamodule"])
+    datamodule: LightningDataModule = hydra.utils.instantiate(config['datamodule'])
     datamodule.prepare_data()
     datamodule.setup()

-    # Init PyTorch Lightning callbacks ⚡
+    # Init PyTorch Lightning callbacks
     callbacks: List[Callback] = []
-    if "callbacks" in config:
-        for _, cb_conf in config["callbacks"].items():
-            if "_target_" in cb_conf:
+    if 'callbacks' in config:
+        for _, cb_conf in config['callbacks'].items():
+            if '_target_' in cb_conf:
                 callbacks.append(hydra.utils.instantiate(cb_conf))

-    # Init PyTorch Lightning loggers ⚡
+    # Init PyTorch Lightning loggers
     logger: List[LightningLoggerBase] = []
-    if "logger" in config:
-        for _, lg_conf in config["logger"].items():
-            if "_target_" in lg_conf:
+    if 'logger' in config:
+        for _, lg_conf in config['logger'].items():
+            if '_target_' in lg_conf:
                 logger.append(hydra.utils.instantiate(lg_conf))

     # Init PyTorch Lightning trainer ⚡
     trainer: Trainer = hydra.utils.instantiate(
-        config["trainer"], callbacks=callbacks, logger=logger
+        config['trainer'], callbacks=callbacks, logger=logger
     )

     # Send some parameters from config to all lightning loggers
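The `hydra.utils.instantiate` calls above construct whatever class the `_target_` key of each config node points at, passing the remaining keys as kwargs. A minimal sketch of the mechanism, with `torch.nn.Linear` standing in for the template's model class:

import hydra.utils
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    '_target_': 'torch.nn.Linear',  # any importable callable works
    'in_features': 784,
    'out_features': 10,
})

# equivalent to torch.nn.Linear(in_features=784, out_features=10)
layer = hydra.utils.instantiate(cfg)
print(layer)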
@@ -61,7 +71,7 @@ def train(config: DictConfig):
         logger=logger,
     )

-    # Train the model
+    # Train the model ⚡
    trainer.fit(model=model, datamodule=datamodule)

     # Evaluate model on test set after training
@@ -77,16 +87,16 @@ def train(config: DictConfig):
         logger=logger,
     )

-    # Return best achieved metric score for optuna
-    optimized_metric = config.get("optimized_metric", None)
+    # Return metric score for optuna optimization
+    optimized_metric = config.get('optimized_metric', None)
     if optimized_metric:
         return trainer.callback_metrics[optimized_metric]


-@hydra.main(config_path="configs/", config_name="config.yaml")
+@hydra.main(config_path='configs/', config_name='config.yaml')
 def main(config: DictConfig):
     return train(config)


-if __name__ == "__main__":
+if __name__ == '__main__':
     main()