Mirror of https://github.com/microsoft/archai.git
yaml redirects support for relative path, accumulated gradients for smaller GPUs, yaml tests
This commit is contained in:
Parent
543a69d5c8
Commit
33df398b51
@@ -71,7 +71,7 @@ class Config(UserDict):
         # replace _copy paths
         # HACK: using file names to detect root config, may be there should be a flag?
         if resolve_redirects:
-            yaml_utils.resolve_all(self, self)
+            yaml_utils.resolve_all(self)

         self.config_filepath = config_filepath
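With this change the config's redirect resolution takes only the root mapping and walks the tree itself. A minimal sketch of calling the resolver directly on a loaded YAML dict; the keys are invented for illustration and archai.common.yaml_utils is assumed to be importable at this revision:

import yaml
from archai.common import yaml_utils

conf = yaml.safe_load("""
common:
  seed: 2
nas:
  seed: '_copy: /common/seed'
""")
yaml_utils.resolve_all(conf)       # resolves in place
assert conf['nas']['seed'] == 2    # the '_copy' string was replaced by the target value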
@@ -13,10 +13,11 @@ from .common import logger
 from archai.common.apex_utils import ApexUtils

 class Tester(EnforceOverrides):
-    def __init__(self, conf_eval:Config, model:nn.Module, apex:ApexUtils)->None:
-        self._title = conf_eval['title']
-        self._logger_freq = conf_eval['logger_freq']
-        conf_lossfn = conf_eval['lossfn']
+    def __init__(self, conf_val:Config, model:nn.Module, apex:ApexUtils)->None:
+        self._title = conf_val['title']
+        self._logger_freq = conf_val['logger_freq']
+        conf_lossfn = conf_val['lossfn']
+        self._batch_chunks = conf_val['batch_chunks']

         self._apex = apex
         self.model = model
@@ -43,18 +44,34 @@ class Tester(EnforceOverrides):

         with torch.no_grad(), logger.pushd('steps'):
             for step, (x, y) in enumerate(test_dl):
-                x, y = x.to(self._apex.device, non_blocking=True), y.to(self._apex.device, non_blocking=True)
-
-                assert not self.model.training # derived class might alter the mode
+                # derived class might alter the mode through pre/post hooks
+                assert not self.model.training
                 logger.pushd(step)

                 self._pre_step(x, y, self._metrics)
-                logits = self.model(x)
-                tupled_out = isinstance(logits, Tuple) and len(logits) >=2
-                if tupled_out:
-                    logits = logits[0]
-                loss = self._lossfn(logits, y)
-                self._post_step(x, y, logits, loss, steps, self._metrics)
+
+                # divide batch in to chunks if needed so it fits in GPU RAM
+                if self._batch_chunks > 1:
+                    x_chunks, y_chunks = torch.chunk(x, self._batch_chunks), torch.chunk(y, self._batch_chunks)
+                else:
+                    x_chunks, y_chunks = ((x,), (y,))
+
+                loss_chunks, logits_chunks = [],[]
+                for xc, yc in zip(x_chunks, y_chunks):
+                    xc, yc = xc.to(self.get_device(), non_blocking=True), yc.to(self.get_device(), non_blocking=True)
+
+                    logits_c = self.model(x)
+                    tupled_out = isinstance(logits_c, Tuple) and len(logits_c) >=2
+                    if tupled_out:
+                        logits_c = logits_c[0]
+                    loss_c = self._lossfn(logits_c, y)
+
+                    loss_chunks.append(loss_c.cpu())
+                    logits_chunks.append(logits_c.cpu())
+
+                self._post_step(x, y,
+                                torch.cat(logits_chunks), torch.cat(loss_chunks),
+                                steps, self._metrics)

                 # TODO: we possibly need to sync so all replicas are upto date
                 self._apex.sync_devices()
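The chunked evaluation path above splits each loaded batch so only one chunk is resident on the GPU at a time, then collects the per-chunk outputs on the CPU for the post-step hook. A standalone sketch of that pattern, independent of the archai Tester and its Apex wrapper; the function and argument names are illustrative:

import torch

@torch.no_grad()
def eval_in_chunks(model, x, y, lossfn, batch_chunks, device):
    """Evaluate one loaded batch in smaller chunks so only one chunk
    occupies GPU memory at a time."""
    if batch_chunks > 1:
        x_chunks, y_chunks = torch.chunk(x, batch_chunks), torch.chunk(y, batch_chunks)
    else:
        x_chunks, y_chunks = (x,), (y,)

    logits_chunks, loss_sum, n = [], 0.0, 0
    for xc, yc in zip(x_chunks, y_chunks):
        xc, yc = xc.to(device, non_blocking=True), yc.to(device, non_blocking=True)
        logits_c = model(xc)
        # weight each chunk's mean loss by its size so the overall mean
        # equals what a single full-batch forward would have given
        loss_sum += lossfn(logits_c, yc).item() * len(xc)
        n += len(xc)
        logits_chunks.append(logits_c.cpu())

    return torch.cat(logits_chunks), loss_sum / n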
@@ -1,6 +1,6 @@
 from typing import Callable, Tuple, Optional

-from torch import nn, Tensor
+from torch import nn, Tensor, torch
 from torch.optim.optimizer import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import DataLoader
@@ -29,6 +29,7 @@ class Trainer(EnforceOverrides):
         self._epochs = conf_train['epochs']
         self._conf_optim = conf_train['optimizer']
         self._conf_sched = conf_train['lr_schedule']
+        self._batch_chunks = conf_train['batch_chunks']
         conf_validation = conf_train['validation']
         conf_apex = conf_train['apex']
         self._validation_freq = 0 if conf_validation is None else conf_validation['freq']
@@ -40,6 +41,8 @@ class Trainer(EnforceOverrides):
         self.model = model

         self._lossfn = ml_utils.get_lossfn(conf_lossfn)
+        # using separate apex for Tester is not possible because we must use
+        # same distributed model as Trainer and hence they must share apex
         self._tester = Tester(conf_validation, model, self._apex) \
                         if conf_validation else None
         self._metrics:Optional[Metrics] = None
@@ -212,8 +215,6 @@ class Trainer(EnforceOverrides):

         logger.pushd('steps')
         for step, (x, y) in enumerate(train_dl):
-            x, y = x.to(self.get_device(), non_blocking=True), y.to(self.get_device(), non_blocking=True)
-
             logger.pushd(step)
             assert self.model.training # derived class might alter the mode

@@ -221,16 +222,29 @@ class Trainer(EnforceOverrides):

             self._optim.zero_grad()

-            logits, aux_logits = self.model(x), None
-            tupled_out = isinstance(logits, Tuple) and len(logits) >=2
-            # if self._aux_weight: # TODO: some other way to validate?
-            #   assert tupled_out, "aux_logits cannot be None unless aux tower is disabled"
-            if tupled_out: # then we are using model created by desc
-                logits, aux_logits = logits[0], logits[1]
-            loss = self.compute_loss(self._lossfn, x, y, logits,
-                                     self._aux_weight, aux_logits)
+            # divide batch in to chunks if needed so it fits in GPU RAM
+            if self._batch_chunks > 1:
+                x_chunks, y_chunks = torch.chunk(x, self._batch_chunks), torch.chunk(y, self._batch_chunks)
+            else:
+                x_chunks, y_chunks = ((x,), (y,))

-            self._apex.backward(loss, self._optim)
+            loss_chunks, logits_chunks = [],[]
+            for xc, yc in zip(x_chunks, y_chunks):
+                xc, yc = xc.to(self.get_device(), non_blocking=True), yc.to(self.get_device(), non_blocking=True)
+
+                logits_c, aux_logits = self.model(xc), None
+                tupled_out = isinstance(logits_c, Tuple) and len(logits_c) >=2
+                # if self._aux_weight: # TODO: some other way to validate?
+                #   assert tupled_out, "aux_logits cannot be None unless aux tower is disabled"
+                if tupled_out: # then we are using model created by desc
+                    logits_c, aux_logits = logits_c[0], logits_c[1]
+                loss_c = self.compute_loss(self._lossfn, yc, logits_c,
+                                           self._aux_weight, aux_logits)
+
+                self._apex.backward(loss_c, self._optim)
+
+                loss_chunks.append(loss_c.cpu())
+                logits_chunks.append(logits_c.cpu())

             # TODO: original darts clips alphas as well but pt.darts doesn't
             self._apex.clip_grad(self._grad_clip, self.model, self._optim)
@@ -243,7 +257,9 @@ class Trainer(EnforceOverrides):
             if self._sched and not self._sched_on_epoch:
                 self._sched.step()

-            self.post_step(x, y, logits, loss, steps)
+            self.post_step(x, y,
+                           torch.cat(logits_chunks), torch.cat(loss_chunks),
+                           steps)
             logger.popd()

         # end of step
@@ -252,8 +268,7 @@ class Trainer(EnforceOverrides):
             self._sched.step()
         logger.popd()

-    def compute_loss(self, lossfn:Callable,
-                     x:Tensor, y:Tensor, logits:Tensor,
+    def compute_loss(self, lossfn:Callable, y:Tensor, logits:Tensor,
                      aux_weight:float, aux_logits:Optional[Tensor])->Tensor:
         loss = lossfn(logits, y)
         if aux_weight > 0.0 and aux_logits is not None:
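On the training side the same split is used, but backward runs once per chunk so gradients accumulate before the single optimizer step. A minimal standalone sketch of the accumulation pattern in plain PyTorch, without the Apex wrapper the Trainer uses; the per-chunk loss is scaled here so the summed gradient matches a full-batch mean-loss backward, assuming roughly equal chunk sizes:

import torch

def train_step_in_chunks(model, x, y, lossfn, optimizer, batch_chunks, device):
    optimizer.zero_grad()
    if batch_chunks > 1:
        x_chunks, y_chunks = torch.chunk(x, batch_chunks), torch.chunk(y, batch_chunks)
    else:
        x_chunks, y_chunks = (x,), (y,)

    total_loss = 0.0
    for xc, yc in zip(x_chunks, y_chunks):
        xc, yc = xc.to(device, non_blocking=True), yc.to(device, non_blocking=True)
        loss_c = lossfn(model(xc), yc) / len(x_chunks)  # scale so grads sum to the full-batch mean
        loss_c.backward()           # grads accumulate across chunks
        total_loss += loss_c.item()

    optimizer.step()                # one update per loaded batch, as before
    return total_loss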
@@ -92,7 +92,7 @@ def deep_comp(o1:Any, o2:Any)->bool:
     if o1 is not None and o2 is not None:
         # if both are dictionaries, we will compare each key
         if isinstance(o1, dict) and isinstance(o2, dict):
-            for k in set().union(o1.keys() ,o2.keys()):
+            for k in set().union(o1.keys(), o2.keys()):
                 if k in o1 and k in o2:
                     if not deep_comp(o1[k], o2[k]):
                         return False
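deep_comp recursively compares two nested dicts and is what the new YAML tests at the end of this commit use to check the resolved config against the expected one. A tiny usage sketch, assuming the usual semantics that it returns True only when every leaf matches:

from archai.common.utils import deep_comp

assert deep_comp({'a': {'b': 1}}, {'a': {'b': 1}})      # identical nested content
assert not deep_comp({'a': {'b': 1}}, {'a': {'b': 2}})  # a differing leaf value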
@@ -1,27 +1,43 @@

 from typing import Mapping, MutableMapping, Any, Optional

-_PREFIX_NODE = '_copy'
-_PREFIX_PATH = '_copy:'
+_PREFIX_NODE = '_copy' # for copy node content command (must be dict)
+_PREFIX_PATH = '_copy:' # for copy node value command (must be scaler)


-def resolve_all(root_d:MutableMapping, cur:MutableMapping):
+def resolve_all(root_d:MutableMapping):
+    _resolve_all(root_d, root_d, '/', set())
+
+def _resolve_all(root_d:MutableMapping, cur:MutableMapping, cur_path:str, prev_paths:set):
+    assert is_proper_path(cur_path)
+
+    if cur_path in prev_paths:
+        return # else we get in to infinite recursion
+    prev_paths.add(cur_path)
+
     # if cur dict has '_copy' node with path in it
     child_path = cur.get(_PREFIX_NODE, None)
     if child_path and isinstance(child_path, str):
         # resolve this path to get source dict
-        child_d = _resolve_path(root_d, child_path)
+        child_d = _resolve_path(root_d, _rel2full_path(cur_path, child_path), prev_paths)
         # we expect target path to point to dict so we can merge its keys
         if not isinstance(child_d, Mapping):
             raise RuntimeError(f'Path "{child_path}" should be dictionary but its instead "{child_d}"')
         # replace keys that have not been overriden
         _merge_source(child_d, cur)
         # remove command key
         del cur[_PREFIX_NODE]

     for k in cur.keys():
         # if this key needs path resolution, get target and replace the value
         rpath = _req_resolve(cur[k])
         if rpath:
-            cur[k] = _resolve_path(root_d, rpath)
+            cur[k] = _resolve_path(root_d,
+                _rel2full_path(_join_path(cur_path, k), rpath), prev_paths)
         # if replaced value is again dictionary, recurse on it
         if isinstance(cur[k], MutableMapping):
-            resolve_all(root_d, cur[k])
+            _resolve_all(root_d, cur[k], _join_path(cur_path, k), prev_paths)

 def _merge_source(source:Mapping, dest:MutableMapping)->None:
     # for anything that source has but dest doesn't, just do copy
@@ -38,46 +54,96 @@ def _merge_source(source:Mapping, dest:MutableMapping)->None:
     # else at least dest value is not dict and should not be overriden

 def _req_resolve(v:Any)->Optional[str]:
     """If the value is actually a path we need resolve then return that path or return None"""
     if isinstance(v, str) and v.startswith(_PREFIX_PATH):
-        return v[len(_PREFIX_PATH):]
+        # we will almost always have space after _copy command
+        return v[len(_PREFIX_PATH):].strip()
     return None

-def _resolve_path(root_d:MutableMapping, path:str)->Any:
-    assert path # otherwise we end up returning root for null paths
-    d = root_d
-    cur_path = '' # maintained for debugging
+def _join_path(path1:str, path2:str):
+    mid = 1 if path1.endswith('/') else 0
+    mid += 1 if path2.startswith('/') else 0

-    parts = path.strip().split('/')
-    # if path starts with '/' remove the first part
-    if len(parts)>0 and parts[0]=='':
-        parts = parts[1:]
-    if len(parts)>0 and parts[-1]=='':
-        parts = parts[:-1]
+    # only 3 possibilities
+    if mid==0:
+        res = path1 + '/' + path2
+    elif mid==1:
+        res = path1 + path2
+    else:
+        res = path1[:-1] + path2
+
+    return _norm_ended(res)
+
+def _norm_ended(path:str)->str:
+    if len(path) > 1 and path.endswith('/'):
+        path = path[:-1]
+    return path
+
+def is_proper_path(path:str)->bool:
+    return path.startswith('/') and (len(path)==1 or not path.endswith('/'))
+
+def _rel2full_path(cwd:str, rel_path:str)->str:
+    """Given current directory and path, we return abolute path. For example,
+    cwd='/a/b/c' and rel_path='../d/e' should return '/a/b/d/e'. Note that rel_path
+    can hold absolute path in which case it will start with '/'
+    """
+    assert len(cwd) > 0 and cwd.startswith('/'), 'cwd must be absolute path'
+
+    rel_parts = rel_path.split('/')
+    if rel_path.startswith('/'):
+        cwd_parts = [] # rel_path is absolute path so ignore cwd
+    else:
+        cwd_parts = cwd.split('/')
+    full_parts = cwd_parts + rel_parts
+
+    final = []
+    for i in range(len(full_parts)):
+        part = full_parts[i].strip()
+        if not part or part == '.': # remove blank strings and single dots
+            continue
+        if part == '..':
+            if len(final):
+                final.pop()
+            else:
+                raise RuntimeError(f'cannot create abs path for cwd={cwd} and rel_path={rel_path}')
+        else:
+            final.append(part)
+
+    final = '/' + '/'.join(final) # should work even when final is empty
+    assert not '..' in final and is_proper_path(final) # make algo indeed worked
+    return final
+
+
+def _resolve_path(root_d:MutableMapping, path:str, prev_paths:set)->Any:
+    """For given path returns value or node from root_d"""
+
+    assert is_proper_path(path)

-    # traverse path in root dict hierarchy
-    for part in parts:
-        # make sure current node is dict so we can "cd" into it
-        if isinstance(d, Mapping):
-            # if path doesn't exit in current dir, see if there are any copy commands here
-            if part not in d:
-                resolve_all(root_d, d)
-            # if path do exist but is string with copy command and then resolve it first
-            else:
-                rpath = _req_resolve(d[part])
-                if rpath:
-                    d[part] = _resolve_path(root_d, rpath)
-                # else resolution already done
+    cur_path = '/' # path at each iteration of for loop
+    d = root_d
+    for part in path.split('/'):
+        if not part:
+            continue # there will be blank vals at start
+
-            # at this point we should have dict to "cd" into otherwise its an error
-            if isinstance(d, Mapping) and part in d:
-                # "cd" into child node
-                d = d[part]
-                cur_path += '/' + part
+        # For each part, we need to be able find key in dict but some dics may not
+        # be fully resolved yet. For last key, d will be either dict or other value.
+        if isinstance(d, Mapping):
+            # for this section, make sure everything is resolved
+            # before we prob for the key
+            _resolve_all(root_d, d, cur_path, prev_paths)
+
+            if part in d:
+                # "cd" into child node
+                d = d[part]
+                cur_path = _join_path(cur_path, part)
+            else:
+                raise RuntimeError(f'Path {path} could not be found in specified dictionary at "{part}"')
         else:
-            raise KeyError(f'Path {path} cannot be resolved because {part} in {cur_path} does not exist or is not a dictionary')
+            raise KeyError(f'Path "{path}" cannot be resolved because "{cur_path}" is not a dictionary so "{part}" cannot exist in it')

     # last child is our answer
     rpath = _req_resolve(d)
     if rpath:
-        d = _resolve_path(root_d, rpath)
+        d = _resolve_path(root_d, _rel2full_path(cur_path, rpath), prev_paths)
     return d
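The new relative-path support interprets a `_copy` value against the path of the entry that holds it: '..' climbs one level per segment, mirroring the `_rel2full_path` docstring where cwd '/a/b/c' plus '../d/e' yields '/a/b/d/e'. A small sketch modelled on the tests added in this commit; the keys are invented and archai.common.yaml_utils is assumed importable at this revision:

import yaml
from archai.common import yaml_utils

d = yaml.safe_load("""
defaults:
  lr: 0.025
trainer:
  lr: '_copy: /defaults/lr'    # absolute path, resolved from the root
  warmup_lr: '_copy: ../lr'    # relative: '..' drops the key itself, 'lr' then picks the sibling
""")
yaml_utils.resolve_all(d)
assert d['trainer']['lr'] == 0.025
assert d['trainer']['warmup_lr'] == 0.025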
@@ -31,7 +31,7 @@ class ArchTrainer(Trainer, EnforceOverrides):
     def compute_loss(self, lossfn: Callable,
                      x: Tensor, y: Tensor, logits: Tensor,
                      aux_weight: float, aux_logits: Optional[Tensor]) -> Tensor:
-        loss = super().compute_loss(lossfn, x, y, logits,
+        loss = super().compute_loss(lossfn, y, logits,
                                     aux_weight, aux_logits)
         # add L1 alpha regularization
         if self._l1_alphas > 0.0:
@@ -32,7 +32,7 @@ common:
     scale_lr: True # enable/disable distributed mode
     min_world_size: 0 # allows to confirm we are indeed in distributed setting
     detect_anomaly: False # if True, PyTorch code will run 6X slower
-    seed: '_copy: common/seed'
+    seed: '_copy: /common/seed'

   smoke_test: False
   only_eval: False
@@ -54,8 +54,8 @@ nas:
     model_filename: '$expdir/model.pt' # file to which trained model will be saved
     data_parallel: False
     checkpoint:
-      _copy: 'common/checkpoint'
-    resume: '_copy: common/resume'
+      _copy: '/common/checkpoint'
+    resume: '_copy: /common/resume'
     model_desc:
       dataset:
         _copy: '/dataset'
@@ -90,15 +90,16 @@ nas:
         _copy: '/dataset'
     trainer:
       apex:
-        _copy: 'common/apex'
+        _copy: '/common/apex'
         enabled: False
-      aux_weight: '_copy: nas/eval/model_desc/aux_weight'
+      aux_weight: '_copy: /nas/eval/model_desc/aux_weight'
       drop_path_prob: 0.2 # probability that given edge will be dropped
       grad_clip: 5.0 # grads above this value is clipped
       l1_alphas: 0.0 # weight to be applied to sum(abs(alphas)) to loss term
       logger_freq: 1000 # after every N updates dump loss and other metrics in logger
       title: 'eval_train'
       epochs: 600
+      batch_chunks: 1 # split batch into these many chunks and accumulate gradients so we can support GPUs with lower RAM
       lossfn:
         type: 'CrossEntropyLoss'
       optimizer:
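batch_chunks only changes how much of a loaded batch sits on the GPU per forward/backward pass; the optimizer still performs one update per loaded batch. A rough sketch of the arithmetic, with the batch size chosen purely for illustration:

train_batch = 96                          # illustrative value, not taken from this config
batch_chunks = 4
chunk_size = train_batch // batch_chunks  # 24 samples resident on the GPU per forward/backward
# gradients from the 4 chunks accumulate, so the update still reflects all 96 samples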
@@ -116,6 +117,7 @@ nas:
           epochs: 0 # 0 disables warmup
       validation:
         title: 'eval_test'
+        batch_chunks: '_copy: ../batch_chunks' # split batch into these many chunks and accumulate gradients so we can support GPUs with lower RAM
         logger_freq: 0
         freq: 1 # perform validation only every N epochs
         lossfn:
@@ -124,15 +126,15 @@ nas:
   search:
     data_parallel: False
     checkpoint:
-      _copy: 'common/checkpoint'
-    resume: '_copy: common/resume'
+      _copy: '/common/checkpoint'
+    resume: '_copy: /common/resume'
     search_iters: 1
     full_desc_filename: '$expdir/full_model_desc.yaml' # arch before it was finalized
     final_desc_filename: '$expdir/final_model_desc.yaml' # final arch is saved in this file
     metrics_dir: '$expdir/models/{reductions}/{cells}/{nodes}/{search_iter}' # where metrics and model stats would be saved from each pareto iteration
     seed_train:
       trainer:
-        _copy: 'nas/eval/trainer'
+        _copy: '/nas/eval/trainer'
         apex:
           enabled: False
         title: 'seed_train'
@@ -140,12 +142,12 @@ nas:
         aux_weight: 0.0
         drop_path_prob: 0.0
       loader:
-        _copy: 'nas/eval/loader'
+        _copy: '/nas/eval/loader'
         train_batch: 128
         val_ratio: 0.1 #split portion for test set, 0 to 1
     post_train:
       trainer:
-        _copy: 'nas/eval/trainer'
+        _copy: '/nas/eval/trainer'
         apex:
           enabled: False
         title: 'post_train'
@@ -153,7 +155,7 @@ nas:
         aux_weight: 0.0
         drop_path_prob: 0.0
       loader:
-        _copy: 'nas/eval/loader'
+        _copy: '/nas/eval/loader'
         train_batch: 128
         val_ratio: 0.1 #split portion for test set, 0 to 1
     pareto:
@@ -199,13 +201,14 @@ nas:
         _copy: '/dataset'
     trainer:
       apex:
-        _copy: 'common/apex'
-      aux_weight: '_copy: nas/search/model_desc/aux_weight'
+        _copy: '/common/apex'
+      aux_weight: '_copy: /nas/search/model_desc/aux_weight'
       drop_path_prob: 0.0 # probability that given edge will be dropped
       grad_clip: 5.0 # grads above this value is clipped
       logger_freq: 1000 # after every N updates dump loss and other metrics in logger
       title: 'arch_train'
       epochs: 50
+      batch_chunks: 1 # split batch into these many chunks and accumulate gradients so we can support GPUs with lower RAM
       # additional vals for the derived class
       plotsdir: '' #empty string means no plots, other wise plots are generated for each epoch in this dir
       l1_alphas: 0.0 # weight to be applied to sum(abs(alphas)) to loss term
@@ -231,6 +234,7 @@ nas:
       validation:
         title: 'search_val'
         logger_freq: 0
+        batch_chunks: '_copy: ../batch_chunks' # split batch into these many chunks and accumulate gradients so we can support GPUs with lower RAM
        freq: 1 # perform validation only every N epochs
         lossfn:
           type: 'CrossEntropyLoss'
@@ -17,27 +17,27 @@ nas:
       trainer:
         epochs: 0 # number of epochs model will be trained before search
       loader:
-        train_batch: '_copy: common/toy_mode/train_batch'
-        test_batch: '_copy: common/toy_mode/test_batch'
+        train_batch: '_copy: /common/toy_mode/train_batch'
+        test_batch: '_copy: /common/toy_mode/test_batch'
         dataset:
-          max_batches: '_copy: common/toy_mode/max_batches'
+          max_batches: '_copy: /common/toy_mode/max_batches'
     post_train:
       trainer:
         epochs: 1
       loader:
-        train_batch: '_copy: common/toy_mode/train_batch'
-        test_batch: '_copy: common/toy_mode/test_batch'
+        train_batch: '_copy: /common/toy_mode/train_batch'
+        test_batch: '_copy: /common/toy_mode/test_batch'
         dataset:
-          max_batches: '_copy: common/toy_mode/max_batches'
+          max_batches: '_copy: /common/toy_mode/max_batches'
     model_desc:
       n_reductions: 1 # number of reductions to be applied
       n_cells: 3 # number of cells
       n_nodes: 2 # number of nodes in a cell
     loader:
-      train_batch: '_copy: common/toy_mode/train_batch'
-      test_batch: '_copy: common/toy_mode/test_batch'
+      train_batch: '_copy: /common/toy_mode/train_batch'
+      test_batch: '_copy: /common/toy_mode/test_batch'
       dataset:
-        max_batches: '_copy: common/toy_mode/max_batches'
+        max_batches: '_copy: /common/toy_mode/max_batches'
     trainer:
       epochs: 1
       logger_freq: 1

@@ -51,10 +51,10 @@ nas:
       n_nodes: 4 # number of nodes in a cell
       n_reductions: 2 # number of reductions to be applied
     loader:
-      train_batch: '_copy: common/toy_mode/train_batch'
-      test_batch: '_copy: common/toy_mode/test_batch'
+      train_batch: '_copy: /common/toy_mode/train_batch'
+      test_batch: '_copy: /common/toy_mode/test_batch'
       dataset:
-        max_batches: '_copy: common/toy_mode/max_batches'
+        max_batches: '_copy: /common/toy_mode/max_batches'
     trainer:
       epochs: 1
       logger_freq: 1
@@ -304,11 +304,11 @@ child calls. The big downside of this is that if any of the child is long runnin
 than any other then we can't checkpoint within that specific child.

 ## Yaml design
-Copy node value using '_copy :path/to/node'.
+Copy node value using '_copy :/path/to/node'.
 - target will always be scaler
 - source could be dict or scaler
 - recursive replacements
 - this will replace above string path with target value
-Insert node childs using _copy: path/to/node
+Insert node childs using _copy: /path/to/node
 - content of source node is copied
 - rest of the child overrides
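In practice the two commands behave as follows; an illustrative snippet with invented keys, assuming archai.common.yaml_utils from this commit:

import yaml
from archai.common import yaml_utils

d = yaml.safe_load("""
base:
  lr: 0.1
  epochs: 10
exp:
  _copy: '/base'               # node content copy: pulls in lr and epochs ...
  epochs: 50                   # ... but keys already present here win
best_lr: '_copy: /base/lr'     # node value copy: the string is replaced by 0.1
""")
yaml_utils.resolve_all(d)
assert d['exp'] == {'lr': 0.1, 'epochs': 50}
assert d['best_lr'] == 0.1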
@@ -11,20 +11,20 @@ common:
   logging: True

 autoaug:
-  dataset: '_copy: nas/train/dataset'
-  logging: '_copy: nas/logging'
+  dataset: '_copy: /nas/train/dataset'
+  logging: '_copy: /nas/logging'

 nas:
   train:
     dataset:
-      _copy: 'common/dataset'
+      _copy: '/common/dataset'
       classes: 4
       cifar:
         workers: 0
-  logging: '_copy: common/logging'
+  logging: '_copy: /common/logging'
 """

 d = yaml.safe_load(s)
 print(d)
-resolve_all(d, d)
-print(d)
+resolve_all(d)
+print(yaml.dump(d))
@@ -0,0 +1,118 @@
import yaml

from archai.common import yaml_utils
from archai.common import utils

def test_yaml1():
    input="""
    l1:
        l1c1: 5
        l2:
            l2c2: 6
            l3:
                l3c2: 7
                l3c3: 8

    n1:
        _copy: '/l1'

    n2:
        n22:
            _copy: '/l1/l2/l3'

    n3: '_copy: /n1/l1c1'

    n4:
        n4c1: 9
        n4c2: '_copy: ../n4c1'
        n4c3: '_copy: ../../n2/n22'


    n5: '_copy: ./../n1/l1c1'
    """

    expected = """
    l1:
        l1c1: 5
        l2: &id001
            l2c2: 6
            l3:
                l3c2: 7
                l3c3: 8
    n1:
        l1c1: 5
        l2: *id001
    n2:
        n22: &id002
            l3c2: 7
            l3c3: 8
    n3: 5
    n4:
        n4c1: 9
        n4c2: 9
        n4c3: *id002
    n5: 5
    """

    d_input = yaml.safe_load(input)
    yaml_utils.resolve_all(d_input)
    #print(yaml.dump(d))
    d_expected = yaml.safe_load(expected)

    utils.deep_comp(d_input, d_expected)

def test_yaml2():
    input="""
    common:
        dataset:
            name: 'd1'
            classes: 3
            cifar:
                limit: -1
        logging: True

    autoaug:
        dataset: '_copy: /nas/train/dataset'
        logging: '_copy: /nas/logging'

    nas:
        train:
            dataset:
                _copy: '/common/dataset'
                classes: 4
                cifar:
                    workers: 0
        logging: '_copy: /common/logging'
    """

    expected="""
    autoaug:
        dataset: &id001
            cifar:
                limit: -1
                workers: 0
            classes: 4
            name: d1
        logging: true
    common:
        dataset:
            cifar:
                limit: -1
            classes: 3
            name: d1
        logging: true
    nas:
        logging: true
        train:
            dataset: *id001
    """

    d_input = yaml.safe_load(input)
    #print(d)
    yaml_utils.resolve_all(d_input)
    d_expected = yaml.safe_load(expected)

    assert utils.deep_comp(d_input, d_expected)

test_yaml1()
test_yaml2()