Mirror of https://github.com/microsoft/archai.git
support general paths in yaml
This commit is contained in:
Parent: af1d639c6e
Commit: c3e97d378b
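In brief: helpers like expdir_abspath/expdir_filepath, which could only resolve names relative to the experiment directory, give way to utils.full_path, which expands environment variables and `~`, so the yaml can now hold general paths. _setup_dirs exports the experiment directory as the `expdir` environment variable, which yaml paths reference as `$expdir`. A minimal sketch of the mechanism (the directory value below is hypothetical; full_path mirrors the new archai.common.utils.full_path from this diff):

    import os

    def full_path(path: str, create: bool = False) -> str:
        # mirrors archai.common.utils.full_path after this commit
        assert path
        path = os.path.abspath(os.path.expanduser(os.path.expandvars(path)))
        if create:
            os.makedirs(path, exist_ok=True)
        return path

    os.environ['expdir'] = '/tmp/logdir/my_experiment'  # set by _setup_dirs() in practice
    print(full_path('$expdir/checkpoint.pth'))
    # -> /tmp/logdir/my_experiment/checkpoint.pth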
@@ -6,7 +6,7 @@ import os
 import torch

 from .config import Config
-from .common import expdir_abspath
+from . import utils

 _CallbackType = Callable #[['CheckPoint', *kargs: Any, **kwargs: Any], None]
 class CheckPoint(UserDict):
@@ -22,7 +22,7 @@ class CheckPoint(UserDict):
         super().__init__()

         # region config vars
-        self.filepath = expdir_abspath(conf_checkpoint['filename'])
+        self.filepath = utils.full_path(conf_checkpoint['filename'])
         self.freq = conf_checkpoint['freq']
         # endregion

@@ -110,10 +110,10 @@ def common_init(config_filepath: Optional[str]=None,
     if expdir:
         # copy net config to experiment folder for reference
-        with open(os.path.join(expdir, 'full_config.yaml'), 'w') as f:
-            yaml.dump(conf, f)
+        with open(expdir_abspath('config_used.yaml'), 'w') as f:
+            yaml.dump(conf.to_dict(), f)
         if not utils.is_debugging():
-            sysinfo_filepath = os.path.join(expdir, 'sysinfo.txt')
+            sysinfo_filepath = expdir_abspath('sysinfo.txt')
             subprocess.Popen([f'./sysinfo.sh "{expdir}" > "{sysinfo_filepath}"'],
                              stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                              shell=True)
@@ -128,44 +128,23 @@ def common_init(config_filepath: Optional[str]=None,
     return conf

-def expdir_abspath(subpath:Optional[str], ensure_exists=False)->Optional[str]:
-    """Returns full path for given relative path within experiment directory.
-    If experiment directory is not setup then None is returned.
-    """
-
-    expdir = get_expdir()
-    if not expdir or not subpath:
-        return None
-    if subpath:
-        expdir = os.path.join(expdir, subpath)
-        if ensure_exists:
-            os.makedirs(expdir, exist_ok=True)
-
-    return expdir
-
-def expdir_filepath(filename:str, subdir:List[str]=[], ensure_path=True)\
-        ->Optional[str]:
-    if len(subdir):
-        filepath = expdir_abspath(os.path.join(*subdir), ensure_path)
-        if filepath:
-            filepath = os.path.join(filepath, filename)
-    else:
-        filepath = expdir_abspath(filename)
-    return filepath
+def expdir_abspath(path:str, create=False)->str:
+    """Returns full path for given relative path within experiment directory."""
+    return utils.full_path(os.path.join('$expdir',path), create=create)

 def _create_tb_writer(is_master=True)-> SummaryWriterAny:
     conf_common = get_conf_common()

-    tbdir, conf_enable_tb = expdir_abspath('tb'), conf_common['enable_tb']
-    enable_tb = conf_enable_tb and is_master and tbdir is not None and len(tbdir) > 0
+    tb_dir, conf_enable_tb = utils.full_path(conf_common['tb_dir']), conf_common['tb_enable']
+    tb_enable = conf_enable_tb and is_master and tb_dir is not None and len(tb_dir) > 0

     logger.info({'conf_enable_tb': conf_enable_tb,
-                 'enable_tb': enable_tb,
-                 'tbdir': tbdir})
+                 'tb_enable': tb_enable,
+                 'tb_dir': tb_dir})

-    WriterClass = SummaryWriter if enable_tb else SummaryWriterDummy
+    WriterClass = SummaryWriter if tb_enable else SummaryWriterDummy

-    return WriterClass(log_dir=tbdir)
+    return WriterClass(log_dir=tb_dir)

 def _setup_dirs()->Optional[str]:
     conf_common = get_conf_common()
@@ -179,17 +158,17 @@ def _setup_dirs()->Optional[str]:
     # make sure logdir and expdir exists
     logdir = conf_common['logdir']
     if logdir:
-        logdir = utils.full_path(os.path.expandvars(logdir))
+        logdir = utils.full_path(logdir)
         expdir = os.path.join(logdir, experiment_name)
         os.makedirs(expdir, exist_ok=True)
     else:
-        expdir = ''
+        raise RuntimeError('The logdir setting must be specified for the output directory in yaml')

     # update conf so everyone gets expanded full paths from here on
     conf_common['logdir'], conf_data['dataroot'], conf_common['expdir'] = \
         logdir, dataroot, expdir

-    # set environment variable so it can be referenced in paths
+    # set environment variable so it can be referenced in paths used in config
     os.environ['expdir'] = expdir

     return expdir
@@ -203,8 +182,10 @@ def _setup_logger():
     if not sys_log_filepath:
         sys_logger.warn(
             'logdir not specified, no logs will be created or any models saved')

     global logger
-    logger.reset(expdir_abspath('logs.yaml'), sys_logger)
+    logs_yaml_filepath = expdir_abspath('logs.yaml')
+    logger.reset(logs_yaml_filepath, sys_logger)

 def _setup_gpus():
     conf_common = get_conf_common()

@@ -11,7 +11,7 @@ from torch import Tensor
 import yaml

 from . import utils, ml_utils
-from .common import logger, get_tb_writer, expdir_abspath
+from .common import logger, get_tb_writer

 class Metrics:
     """Record top1, top5, loss metrics, track best so far.
@@ -153,13 +153,11 @@ class Metrics:
         # simply convert current object to dictionary
         utils.load_state_dict(self, state_dict)

-    def save(self, filename:str)->Optional[str]:
-        save_path = expdir_abspath(filename)
-        if save_path:
-            if not save_path.endswith('.yaml'):
-                save_path += '.yaml'
-            pathlib.Path(save_path).write_text(yaml.dump(self))
-        return save_path
+    def save(self, filepath:str)->Optional[str]:
+        if filepath:
+            filepath = utils.full_path(filepath)
+            pathlib.Path(filepath).write_text(yaml.dump(self))
+        return filepath

     def epochs(self)->int:
         """Returns epochs recorded so far"""

@@ -105,10 +105,14 @@ def deep_comp(o1:Any, o2:Any)->bool:
 def is_debugging()->bool:
     return 'pydevd' in sys.modules # works for vscode

-def full_path(path:str)->str:
-    path = os.path.expandvars(path)
-    path = os.path.expanduser(path)
-    return os.path.abspath(path)
+def full_path(path:str, create=False)->str:
+    assert path
+    path = os.path.abspath(
+        os.path.expanduser(
+            os.path.expandvars(path)))
+    if create:
+        os.makedirs(path, exist_ok=True)
+    return path

 def zero_file(filepath)->None:
     """Creates or truncates existing file"""

@@ -10,7 +10,7 @@ from torch.optim.lr_scheduler import _LRScheduler
 from overrides import overrides, EnforceOverrides

 from ..common.config import Config
-from ..common import common
+from ..common import common, utils
 from ..nas.model import Model
 from ..nas.model_desc import ModelDesc
 from ..common.trainer import Trainer
@@ -25,7 +25,7 @@ class ArchTrainer(Trainer, EnforceOverrides):
         super().__init__(conf_train, model, device, checkpoint, aux_tower=True)

         self._l1_alphas = conf_train['l1_alphas']
-        self._plotsdir = common.expdir_abspath(conf_train['plotsdir'], True)
+        self._plotsdir = conf_train['plotsdir']

     @overrides
     def compute_loss(self, lossfn: Callable,
@@ -55,7 +55,9 @@ class ArchTrainer(Trainer, EnforceOverrides):
         is_best = is_best or best_train==train_metrics.cur_epoch()
         if is_best:
             # log model_desc as a image
-            plot_filepath = os.path.join(
-                self._plotsdir, "EP{train_metrics.epoch:03d}")
-            draw_model_desc(self.model.finalize(), plot_filepath+"-normal",
-                            caption=f"Epoch {train_metrics.epochs()-1}")
+            plot_filepath = utils.full_path(os.path.join(
+                self._plotsdir,
+                f"EP{train_metrics.cur_epoch().index:03d}"),
+                create=True)
+            draw_model_desc(self.model.finalize(), filepath=plot_filepath,
+                            caption=f"Epoch {train_metrics.cur_epoch().index}")

@@ -15,7 +15,8 @@ def eval_arch(conf_eval:Config, cell_builder:Optional[CellBuilder]):

     # region conf vars
     conf_loader = conf_eval['loader']
-    save_filename = conf_eval['save_filename']
+    model_filename = conf_eval['model_filename']
+    metric_filename = conf_eval['metric_filename']
     conf_model_desc = conf_eval['model_desc']
     conf_checkpoint = conf_eval['checkpoint']
     resume = conf_eval['resume']
@@ -45,10 +46,10 @@ def eval_arch(conf_eval:Config, cell_builder:Optional[CellBuilder]):
     checkpoint = nas_utils.create_checkpoint(conf_checkpoint, resume)
     trainer = Trainer(conf_train, model, device, checkpoint, aux_tower=True)
     train_metrics = trainer.fit(train_dl, test_dl)
-    train_metrics.save('eval_train_metrics')
+    train_metrics.save(metric_filename)

     # save model
-    save_path = model.save(save_filename)
+    save_path = model.save(model_filename)
     logger.info({'model_save_path': save_path})
     logger.popd()

@@ -7,7 +7,8 @@ from overrides import EnforceOverrides

 from .cell_builder import CellBuilder
 from .arch_trainer import TArchTrainer
-from ..common.common import common_init, expdir_abspath
+from ..common.common import common_init
+from ..common import utils
 from ..common.config import Config
 from . import evaluate
 from .search import Search
@@ -54,13 +55,13 @@ class ExperimentRunner(ABC, EnforceOverrides):
     def _copy_final_desc(self, search_conf)->Tuple[Config, Config]:
         # get desc file path that search has produced
         search_desc_filename = search_conf['nas']['search']['final_desc_filename']
-        search_desc_filepath = expdir_abspath(search_desc_filename)
+        search_desc_filepath = utils.full_path(search_desc_filename)
         assert search_desc_filepath and os.path.exists(search_desc_filepath)

         # get file path that eval would need
         eval_conf = self._init('eval')
         eval_desc_filename = eval_conf['nas']['eval']['final_desc_filename']
-        eval_desc_filepath = expdir_abspath(eval_desc_filename)
+        eval_desc_filepath = utils.full_path(eval_desc_filename)
         assert eval_desc_filepath
         shutil.copy2(search_desc_filepath, eval_desc_filepath)

@@ -13,7 +13,7 @@ from overrides import overrides
 from .cell import Cell
 from .operations import Op, DropPath_
 from .model_desc import ModelDesc, AuxTowerDesc, CellDesc
-from ..common.common import logger, expdir_filepath
+from ..common.common import logger
 from ..common import utils, ml_utils

 class Model(nn.Module):
@@ -136,11 +136,11 @@ class Model(nn.Module):
         if isinstance(module, DropPath_):
             module.p = p

-    def save(self, filename:str, subdir:List[str]=[])->Optional[str]:
-        save_path = expdir_filepath(filename, subdir)
-        if save_path:
-            ml_utils.save_model(self, save_path)
-        return save_path
+    def save(self, filepath:str)->Optional[str]:
+        if filepath:
+            filepath = utils.full_path(filepath)
+            ml_utils.save_model(self, filepath)
+        return filepath


 class AuxTower(nn.Module):

@@ -7,7 +7,7 @@ import copy

 import yaml

-from ..common.common import expdir_abspath, expdir_filepath
+from archai.common import utils


 """
@@ -317,22 +317,20 @@ class ModelDesc:
             op_desc = getattr(self, attr)
             op_desc.load_state_dict(attr_state_dict)

-    def save(self, filename:str, subdir:List[str]=[])->Optional[str]:
-        yaml_filepath = expdir_filepath(filename, subdir)
-        if yaml_filepath:
-            if not yaml_filepath.endswith('.yaml'):
-                yaml_filepath += '.yaml'
+    def save(self, filename:str)->Optional[str]:
+        if filename:
+            filename = utils.full_path(filename)

             # clear so PyTorch state is not saved in yaml
             state_dict = self.state_dict(clear=True)
-            pt_filepath = ModelDesc._pt_filepath(yaml_filepath)
+            pt_filepath = ModelDesc._pt_filepath(filename)
             torch.save(state_dict, pt_filepath)
             # save yaml
-            pathlib.Path(yaml_filepath).write_text(yaml.dump(self))
+            pathlib.Path(filename).write_text(yaml.dump(self))
             # restore state
             self.load_state_dict(state_dict)

-        return yaml_filepath
+        return filename

     @staticmethod
     def _pt_filepath(desc_filepath:str)->str:
@@ -340,8 +338,8 @@ class ModelDesc:
         return str(pathlib.Path(desc_filepath).with_suffix('.pth'))

     @staticmethod
-    def load(yaml_filename:str)->'ModelDesc':
-        yaml_filepath = expdir_abspath(yaml_filename)
+    def load(filename:str)->'ModelDesc':
+        yaml_filepath = utils.full_path(filename)
         if not yaml_filepath or not os.path.exists(yaml_filepath):
             raise RuntimeError("Model description file is not found."
                 "Typically this file should be generated from the search."

@@ -8,7 +8,7 @@ import tensorwatch as tw
 from torch.utils.data.dataloader import DataLoader
 import yaml

-from archai.common.common import expdir_filepath, logger
+from archai.common.common import logger
 from archai.common.checkpoint import CheckPoint
 from archai.common.config import Config
 from .cell_builder import CellBuilder
@@ -20,6 +20,7 @@ from archai.datasets import data
 from .model import Model
 from archai.common.metrics import EpochMetrics, Metrics
 from archai.common import utils
+import os

 class MetricsStats:
     """Holds model statistics and training metrics for given description"""
@@ -94,6 +95,7 @@ class Search:
         self.conf_train = conf_search['trainer']
         self.final_desc_filename = conf_search['final_desc_filename']
         self.full_desc_filename = conf_search['full_desc_filename']
+        self.metrics_dir = conf_search['metrics_dir']
         self.conf_presearch_train = conf_search['seed_train']
         self.conf_postsearch_train = conf_search['post_train']
         conf_pareto = conf_search['pareto']
@@ -105,13 +107,14 @@ class Search:
         self.max_nodes = conf_pareto['max_nodes']
         self.search_iters = conf_search['search_iters']
         self.pareto_enabled = conf_pareto['enabled']
+        pareto_summary_filename = conf_pareto['summary_filename']
         # endregion

         self.device = torch.device(conf_search['device'])
         self.cell_builder = cell_builder
         self.trainer_class = trainer_class
         self._data_cache = {}
-        self._parito_filepath = expdir_filepath('perito.tsv')
+        self._parito_filepath = utils.full_path(pareto_summary_filename)
         self._checkpoint = nas_utils.create_checkpoint(conf_checkpoint, resume)

         logger.info({'pareto_enabled': self.pareto_enabled,
@@ -236,22 +239,23 @@ class Search:
         """Save the model and metric info into a log file"""

         # construct path where we will save
-        subdir = ['models', str(reductions), str(cells), str(nodes),
-                  str(search_iter)]
+        subdir = utils.full_path(self.metrics_dir.format(**vars()), create=True)

         # save metric_infi
-        metrics_stats_filepath = expdir_filepath('metrics_stats.yaml', subdir,
-                                                 ensure_path=True)
+        metrics_stats_filepath = os.path.join(subdir, 'metrics_stats.yaml')
         if metrics_stats_filepath:
             with open(metrics_stats_filepath, 'w') as f:
                 yaml.dump(metrics_stats, f)

         # save just metrics separately
-        metrics_filepath = expdir_filepath('metrics.yaml', subdir)
+        metrics_filepath = os.path.join(subdir, 'metrics.yaml')
         if metrics_filepath:
             with open(metrics_filepath, 'w') as f:
                 yaml.dump(metrics_stats.train_metrics, f)

         logger.info({'metrics_stats_filepath': metrics_stats_filepath,
                      'metrics_filepath': metrics_filepath})

         # append key info in root pareto data
         if self._parito_filepath:
             train_top1 = val_top1 = train_epoch = val_epoch = math.nan

@@ -6,7 +6,7 @@ from typing import Union, List, Tuple, Optional
 from .model_desc import CellDesc, CellType, ModelDesc
 from ..common.utils import first_or_default

-def draw_model_desc(model_desc:ModelDesc, file_path:str=None, caption:str=None,
+def draw_model_desc(model_desc:ModelDesc, filepath:str=None, caption:str=None,
                     render=True)->Tuple[Optional[Digraph],Optional[Digraph]]:
     normal_cell_desc = first_or_default((c for c in model_desc.cell_descs() \
         if c.cell_type == CellType.Regular), None)
@@ -14,16 +14,18 @@ def draw_model_desc(model_desc:ModelDesc, file_path:str=None, caption:str=None,
     reduced_cell_desc = first_or_default((c for c in model_desc.cell_descs() \
         if c.cell_type == CellType.Reduction), None)

-    g_normal = draw_cell_desc(normal_cell_desc, file_path, caption, render) \
-        if normal_cell_desc is not None else None
-    g_reduct = draw_cell_desc(reduced_cell_desc, file_path, caption, render) \
-        if reduced_cell_desc is not None else None
+    g_normal = draw_cell_desc(normal_cell_desc,
+                              filepath+'-normal.png' if filepath else None,
+                              caption) if normal_cell_desc is not None else None
+    g_reduct = draw_cell_desc(reduced_cell_desc,
+                              filepath+'-reduced.png' if filepath else None,
+                              caption) if reduced_cell_desc is not None else None

     return g_normal, g_reduct

-def draw_cell_desc(cell_desc:CellDesc, file_path:str=None, caption:str=None,
-                   render=True)->Digraph:
-    """ make DAG plot and optionally save to file_path as .png """
+def draw_cell_desc(cell_desc:CellDesc, filepath:str=None, caption:str=None
+                   )->Digraph:
+    """ make DAG plot and optionally save to filepath as .png """

     edge_attr = {
         'fontsize': '20',
@@ -79,6 +81,6 @@ def draw_cell_desc(cell_desc:CellDesc, file_path:str=None, caption:str=None,
     if caption:
         g.attr(label=caption, overlap='false', fontsize='20', fontname='times')

-    if render:
-        g.render(file_path, view=False)
+    if filepath:
+        g.render(filepath, view=False)
     return g

@@ -5,11 +5,12 @@ common:
   experiment_desc: 'throwaway'
   logdir: '~/logdir'
   seed: 2.0
-  enable_tb: False # if True then TensorBoard logging is enabled (may impact perf)
+  tb_enable: False # if True then TensorBoard logging is enabled (may impact perf)
+  tb_dir: '$expdir/tb' # path where tensorboard logs would be stored
   horovod: False
   device: 'cuda'
   checkpoint:
-    filename: 'checkpoint.pth'
+    filename: '$expdir/checkpoint.pth'
     freq: 10
   detect_anomaly: False # if True, PyTorch code will run 6X slower
   # TODO: workers setting
@@ -28,9 +29,10 @@ dataset: {} # default dataset settings comes from __include__ on the top

 nas:
   eval:
-    full_desc_filename: 'full_model_desc.yaml' # model desc used for building model for evaluation
-    final_desc_filename: 'final_model_desc.yaml' # model desc used as template to construct cells
-    save_filename: 'model.pt' # file to which trained model will be saved
+    full_desc_filename: '$expdir/full_model_desc.yaml' # model desc used for building model for evaluation
+    final_desc_filename: '$expdir/final_model_desc.yaml' # model desc used as template to construct cells
+    metric_filename: '$expdir/eval_train_metrics.yaml'
+    model_filename: '$expdir/model.pt' # file to which trained model will be saved
     device: '_copy: common/device'
     data_parallel: False
     checkpoint:
@@ -100,8 +102,9 @@ nas:
       _copy: 'common/checkpoint'
     resume: '_copy: common/resume'
     search_iters: 1
-    full_desc_filename: 'full_model_desc.yaml' # arch before it was finalized
-    final_desc_filename: 'final_model_desc.yaml' # final arch is saved in this file
+    full_desc_filename: '$expdir/full_model_desc.yaml' # arch before it was finalized
+    final_desc_filename: '$expdir/final_model_desc.yaml' # final arch is saved in this file
+    metrics_dir: '$expdir/models/{reductions}/{cells}/{nodes}/{search_iter}' # where metrics and model stats would be saved from each pareto iteration
     device: '_copy: common/device'
     seed_train:
       trainer:
@@ -131,6 +134,7 @@ nas:
       max_reductions: 2
      max_nodes: 4
       enabled: False
+      summary_filename: '$expdir/perito.tsv' # for each iteration of macro, we fave model and perf summary
     model_desc:
       _copy: 'nas/eval/model_desc'
       init_ch_out: 16 # num of output channels for the first cell
@@ -162,7 +166,7 @@ nas:
       title: 'search_train'
       epochs: 50
       # additional vals for the derived class
-      plotsdir: '' # use default subfolder in logdir
+      plotsdir: '' #empty string means no plots, other wise plots are generated for each epoch in this dir
       l1_alphas: 0.0 # weight to be applied to sum(abs(alphas)) to loss term
       lossfn:
         type: 'CrossEntropyLoss'

@@ -1,11 +1,13 @@

-from archai.common.common import common_init, expdir_abspath
+from archai.common.common import common_init
+from archai.common import utils
+import os

 def get_filepath(suffix):
     conf = common_init(config_filepath='confs/algos/darts.yaml',
                        param_args=['--common.experiment_name', 'test_basename' + f'_{suffix}'
                                    ])
-    return expdir_abspath('somefile.txt')
+    return utils.full_path(os.path.join('$expdir' ,'somefile.txt'))

 print(get_filepath('search'))
 print(get_filepath('eval'))

@@ -13,7 +13,7 @@ conf_eval = conf['nas']['eval']
 conf_model_desc = conf_eval['model_desc']

 conf_model_desc['n_cells'] = 14
-template_model_desc = ModelDesc.load('final_model_desc.yaml')
+template_model_desc = ModelDesc.load('$expdir/final_model_desc.yaml')
 model_desc = create_macro_desc(conf_model_desc, True, template_model_desc)

 mb = PetridishCellBuilder()