yaml redirects: support for relative paths, accumulated gradients for smaller GPUs, yaml tests

This commit is contained in:
Shital Shah 2020-04-25 21:02:45 -07:00
Parent 543a69d5c8
Commit 33df398b51
11 changed files: 320 additions and 100 deletions

View file

@ -71,7 +71,7 @@ class Config(UserDict):
# replace _copy paths
# HACK: using file names to detect root config, maybe there should be a flag?
if resolve_redirects:
yaml_utils.resolve_all(self, self)
yaml_utils.resolve_all(self)
self.config_filepath = config_filepath

View file

@ -13,10 +13,11 @@ from .common import logger
from archai.common.apex_utils import ApexUtils
class Tester(EnforceOverrides):
def __init__(self, conf_eval:Config, model:nn.Module, apex:ApexUtils)->None:
self._title = conf_eval['title']
self._logger_freq = conf_eval['logger_freq']
conf_lossfn = conf_eval['lossfn']
def __init__(self, conf_val:Config, model:nn.Module, apex:ApexUtils)->None:
self._title = conf_val['title']
self._logger_freq = conf_val['logger_freq']
conf_lossfn = conf_val['lossfn']
self._batch_chunks = conf_val['batch_chunks']
self._apex = apex
self.model = model
@ -43,18 +44,34 @@ class Tester(EnforceOverrides):
with torch.no_grad(), logger.pushd('steps'):
for step, (x, y) in enumerate(test_dl):
x, y = x.to(self._apex.device, non_blocking=True), y.to(self._apex.device, non_blocking=True)
assert not self.model.training # derived class might alter the mode
# derived class might alter the mode through pre/post hooks
assert not self.model.training
logger.pushd(step)
self._pre_step(x, y, self._metrics)
logits = self.model(x)
tupled_out = isinstance(logits, Tuple) and len(logits) >=2
if tupled_out:
logits = logits[0]
loss = self._lossfn(logits, y)
self._post_step(x, y, logits, loss, steps, self._metrics)
# divide batch into chunks if needed so it fits in GPU RAM
if self._batch_chunks > 1:
x_chunks, y_chunks = torch.chunk(x, self._batch_chunks), torch.chunk(y, self._batch_chunks)
else:
x_chunks, y_chunks = ((x,), (y,))
loss_chunks, logits_chunks = [],[]
for xc, yc in zip(x_chunks, y_chunks):
xc, yc = xc.to(self.get_device(), non_blocking=True), yc.to(self.get_device(), non_blocking=True)
logits_c = self.model(xc)
tupled_out = isinstance(logits_c, Tuple) and len(logits_c) >=2
if tupled_out:
logits_c = logits_c[0]
loss_c = self._lossfn(logits_c, yc)
loss_chunks.append(loss_c.cpu())
logits_chunks.append(logits_c.cpu())
self._post_step(x, y,
torch.cat(logits_chunks), torch.cat(loss_chunks),
steps, self._metrics)
# TODO: we possibly need to sync so all replicas are up to date
self._apex.sync_devices()
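
The pattern above in plain form: a minimal, standalone sketch of eval-time batch chunking (not from the repo; eval_chunked and its arguments are illustrative). Each chunk is moved to the GPU, run through the model, and the outputs are gathered back on the CPU so only one chunk is resident at a time.

import torch
from torch import nn

@torch.no_grad()
def eval_chunked(model: nn.Module, lossfn, x: torch.Tensor, y: torch.Tensor,
                 batch_chunks: int, device: torch.device):
    # split the batch so only one chunk has to fit in GPU RAM at a time
    if batch_chunks > 1:
        x_chunks, y_chunks = torch.chunk(x, batch_chunks), torch.chunk(y, batch_chunks)
    else:
        x_chunks, y_chunks = (x,), (y,)
    logits_chunks, loss_chunks = [], []
    for xc, yc in zip(x_chunks, y_chunks):
        xc, yc = xc.to(device, non_blocking=True), yc.to(device, non_blocking=True)
        logits_c = model(xc)
        loss_chunks.append(lossfn(logits_c, yc).cpu())
        logits_chunks.append(logits_c.cpu())
    # per-chunk losses are 0-dim tensors, so stack them; logits keep a batch dim
    return torch.cat(logits_chunks), torch.stack(loss_chunks)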

View file

@ -1,6 +1,6 @@
from typing import Callable, Tuple, Optional
from torch import nn, Tensor
from torch import nn, Tensor, torch
from torch.optim.optimizer import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
@ -29,6 +29,7 @@ class Trainer(EnforceOverrides):
self._epochs = conf_train['epochs']
self._conf_optim = conf_train['optimizer']
self._conf_sched = conf_train['lr_schedule']
self._batch_chunks = conf_train['batch_chunks']
conf_validation = conf_train['validation']
conf_apex = conf_train['apex']
self._validation_freq = 0 if conf_validation is None else conf_validation['freq']
@ -40,6 +41,8 @@ class Trainer(EnforceOverrides):
self.model = model
self._lossfn = ml_utils.get_lossfn(conf_lossfn)
# using separate apex for Tester is not possible because we must use
# same distributed model as Trainer and hence they must share apex
self._tester = Tester(conf_validation, model, self._apex) \
if conf_validation else None
self._metrics:Optional[Metrics] = None
@ -212,8 +215,6 @@ class Trainer(EnforceOverrides):
logger.pushd('steps')
for step, (x, y) in enumerate(train_dl):
x, y = x.to(self.get_device(), non_blocking=True), y.to(self.get_device(), non_blocking=True)
logger.pushd(step)
assert self.model.training # derived class might alter the mode
@ -221,16 +222,29 @@ class Trainer(EnforceOverrides):
self._optim.zero_grad()
logits, aux_logits = self.model(x), None
tupled_out = isinstance(logits, Tuple) and len(logits) >=2
# if self._aux_weight: # TODO: some other way to validate?
# assert tupled_out, "aux_logits cannot be None unless aux tower is disabled"
if tupled_out: # then we are using model created by desc
logits, aux_logits = logits[0], logits[1]
loss = self.compute_loss(self._lossfn, x, y, logits,
self._aux_weight, aux_logits)
# divide batch into chunks if needed so it fits in GPU RAM
if self._batch_chunks > 1:
x_chunks, y_chunks = torch.chunk(x, self._batch_chunks), torch.chunk(y, self._batch_chunks)
else:
x_chunks, y_chunks = ((x,), (y,))
self._apex.backward(loss, self._optim)
loss_chunks, logits_chunks = [],[]
for xc, yc in zip(x_chunks, y_chunks):
xc, yc = xc.to(self.get_device(), non_blocking=True), yc.to(self.get_device(), non_blocking=True)
logits_c, aux_logits = self.model(xc), None
tupled_out = isinstance(logits_c, Tuple) and len(logits_c) >=2
# if self._aux_weight: # TODO: some other way to validate?
# assert tupled_out, "aux_logits cannot be None unless aux tower is disabled"
if tupled_out: # then we are using model created by desc
logits_c, aux_logits = logits_c[0], logits_c[1]
loss_c = self.compute_loss(self._lossfn, yc, logits_c,
self._aux_weight, aux_logits)
self._apex.backward(loss_c, self._optim)
loss_chunks.append(loss_c.cpu())
logits_chunks.append(logits_c.cpu())
# TODO: original darts clips alphas as well but pt.darts doesn't
self._apex.clip_grad(self._grad_clip, self.model, self._optim)
@ -243,7 +257,9 @@ class Trainer(EnforceOverrides):
if self._sched and not self._sched_on_epoch:
self._sched.step()
self.post_step(x, y, logits, loss, steps)
self.post_step(x, y,
torch.cat(logits_chunks), torch.cat(loss_chunks),
steps)
logger.popd()
# end of step
@ -252,8 +268,7 @@ class Trainer(EnforceOverrides):
self._sched.step()
logger.popd()
def compute_loss(self, lossfn:Callable,
x:Tensor, y:Tensor, logits:Tensor,
def compute_loss(self, lossfn:Callable, y:Tensor, logits:Tensor,
aux_weight:float, aux_logits:Optional[Tensor])->Tensor:
loss = lossfn(logits, y)
if aux_weight > 0.0 and aux_logits is not None:
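
The Trainer side adds gradient accumulation on top of the same chunking: zero_grad runs once per batch, backward runs once per chunk so gradients add up, and a single optimizer step follows. A rough standalone sketch under assumed names (train_step_chunked is not part of the codebase):

import torch
from torch import nn

def train_step_chunked(model: nn.Module, lossfn, optim: torch.optim.Optimizer,
                       x: torch.Tensor, y: torch.Tensor,
                       batch_chunks: int, device: torch.device) -> torch.Tensor:
    optim.zero_grad()
    if batch_chunks > 1:
        x_chunks, y_chunks = torch.chunk(x, batch_chunks), torch.chunk(y, batch_chunks)
    else:
        x_chunks, y_chunks = (x,), (y,)
    losses = []
    for xc, yc in zip(x_chunks, y_chunks):
        xc, yc = xc.to(device, non_blocking=True), yc.to(device, non_blocking=True)
        loss_c = lossfn(model(xc), yc)
        # gradients from each chunk accumulate in .grad; dividing loss_c by
        # batch_chunks would reproduce the full-batch mean gradient, while the
        # diff above accumulates the unscaled per-chunk losses
        loss_c.backward()
        losses.append(loss_c.detach().cpu())
    optim.step()  # one optimizer step for the whole original batch
    return torch.stack(losses)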

View file

@ -92,7 +92,7 @@ def deep_comp(o1:Any, o2:Any)->bool:
if o1 is not None and o2 is not None:
# if both are dictionaries, we will compare each key
if isinstance(o1, dict) and isinstance(o2, dict):
for k in set().union(o1.keys() ,o2.keys()):
for k in set().union(o1.keys(), o2.keys()):
if k in o1 and k in o2:
if not deep_comp(o1[k], o2[k]):
return False

View file

@ -1,27 +1,43 @@
from typing import Mapping, MutableMapping, Any, Optional
_PREFIX_NODE = '_copy'
_PREFIX_PATH = '_copy:'
_PREFIX_NODE = '_copy' # for copy node content command (must be dict)
_PREFIX_PATH = '_copy:' # for copy node value command (must be scalar)
def resolve_all(root_d:MutableMapping, cur:MutableMapping):
def resolve_all(root_d:MutableMapping):
_resolve_all(root_d, root_d, '/', set())
def _resolve_all(root_d:MutableMapping, cur:MutableMapping, cur_path:str, prev_paths:set):
assert is_proper_path(cur_path)
if cur_path in prev_paths:
return # else we get into infinite recursion
prev_paths.add(cur_path)
# if cur dict has '_copy' node with path in it
child_path = cur.get(_PREFIX_NODE, None)
if child_path and isinstance(child_path, str):
# resolve this path to get source dict
child_d = _resolve_path(root_d, child_path)
child_d = _resolve_path(root_d, _rel2full_path(cur_path, child_path), prev_paths)
# we expect target path to point to dict so we can merge its keys
if not isinstance(child_d, Mapping):
raise RuntimeError(f'Path "{child_path}" should be a dictionary but instead it is "{child_d}"')
# replace keys that have not been overridden
_merge_source(child_d, cur)
# remove command key
del cur[_PREFIX_NODE]
for k in cur.keys():
# if this key needs path resolution, get target and replace the value
rpath = _req_resolve(cur[k])
if rpath:
cur[k] = _resolve_path(root_d, rpath)
cur[k] = _resolve_path(root_d,
_rel2full_path(_join_path(cur_path, k), rpath), prev_paths)
# if replaced value is again dictionary, recurse on it
if isinstance(cur[k], MutableMapping):
resolve_all(root_d, cur[k])
_resolve_all(root_d, cur[k], _join_path(cur_path, k), prev_paths)
def _merge_source(source:Mapping, dest:MutableMapping)->None:
# for anything that source has but dest doesn't, just do copy
@ -38,46 +54,96 @@ def _merge_source(source:Mapping, dest:MutableMapping)->None:
# else at least dest value is not dict and should not be overridden
def _req_resolve(v:Any)->Optional[str]:
"""If the value is actually a path we need resolve then return that path or return None"""
if isinstance(v, str) and v.startswith(_PREFIX_PATH):
return v[len(_PREFIX_PATH):]
# we will almost always have space after _copy command
return v[len(_PREFIX_PATH):].strip()
return None
def _resolve_path(root_d:MutableMapping, path:str)->Any:
assert path # otherwise we end up returning root for null paths
d = root_d
cur_path = '' # maintained for debugging
def _join_path(path1:str, path2:str):
mid = 1 if path1.endswith('/') else 0
mid += 1 if path2.startswith('/') else 0
parts = path.strip().split('/')
# if path starts with '/' remove the first part
if len(parts)>0 and parts[0]=='':
parts = parts[1:]
if len(parts)>0 and parts[-1]=='':
parts = parts[:-1]
# only 3 possibilities
if mid==0:
res = path1 + '/' + path2
elif mid==1:
res = path1 + path2
else:
res = path1[:-1] + path2
return _norm_ended(res)
def _norm_ended(path:str)->str:
if len(path) > 1 and path.endswith('/'):
path = path[:-1]
return path
def is_proper_path(path:str)->bool:
return path.startswith('/') and (len(path)==1 or not path.endswith('/'))
def _rel2full_path(cwd:str, rel_path:str)->str:
"""Given current directory and path, we return abolute path. For example,
cwd='/a/b/c' and rel_path='../d/e' should return '/a/b/d/e'. Note that rel_path
can hold absolute path in which case it will start with '/'
"""
assert len(cwd) > 0 and cwd.startswith('/'), 'cwd must be absolute path'
rel_parts = rel_path.split('/')
if rel_path.startswith('/'):
cwd_parts = [] # rel_path is absolute path so ignore cwd
else:
cwd_parts = cwd.split('/')
full_parts = cwd_parts + rel_parts
final = []
for i in range(len(full_parts)):
part = full_parts[i].strip()
if not part or part == '.': # remove blank strings and single dots
continue
if part == '..':
if len(final):
final.pop()
else:
raise RuntimeError(f'cannot create abs path for cwd={cwd} and rel_path={rel_path}')
else:
final.append(part)
final = '/' + '/'.join(final) # should work even when final is empty
assert '..' not in final and is_proper_path(final) # make sure the algorithm indeed worked
return final
def _resolve_path(root_d:MutableMapping, path:str, prev_paths:set)->Any:
"""For given path returns value or node from root_d"""
assert is_proper_path(path)
# traverse path in root dict hierarchy
for part in parts:
# make sure current node is dict so we can "cd" into it
if isinstance(d, Mapping):
# if path doesn't exist in current dir, see if there are any copy commands here
if part not in d:
resolve_all(root_d, d)
# if path does exist but is a string with a copy command then resolve it first
else:
rpath = _req_resolve(d[part])
if rpath:
d[part] = _resolve_path(root_d, rpath)
# else resolution already done
cur_path = '/' # path at each iteration of for loop
d = root_d
for part in path.split('/'):
if not part:
continue # there will be blank vals at start
# at this point we should have dict to "cd" into otherwise its an error
if isinstance(d, Mapping) and part in d:
# "cd" into child node
d = d[part]
cur_path += '/' + part
# For each part, we need to be able to find the key in the dict but some dicts may not
# be fully resolved yet. For last key, d will be either dict or other value.
if isinstance(d, Mapping):
# for this section, make sure everything is resolved
# before we probe for the key
_resolve_all(root_d, d, cur_path, prev_paths)
if part in d:
# "cd" into child node
d = d[part]
cur_path = _join_path(cur_path, part)
else:
raise RuntimeError(f'Path {path} could not be found in specified dictionary at "{part}"')
else:
raise KeyError(f'Path {path} cannot be resolved because {part} in {cur_path} does not exist or is not a dictionary')
raise KeyError(f'Path "{path}" cannot be resolved because "{cur_path}" is not a dictionary so "{part}" cannot exist in it')
# last child is our answer
rpath = _req_resolve(d)
if rpath:
d = _resolve_path(root_d, rpath)
d = _resolve_path(root_d, _rel2full_path(cur_path, rpath), prev_paths)
return d
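
A small usage sketch of the relative-path support added here (standalone, not part of the commit; the keys are made up). A '_copy' value may now be relative to the key that holds it, with '..' walking up one level, while paths starting with '/' remain absolute:

import yaml
from archai.common import yaml_utils

conf = yaml.safe_load("""
defaults:
  gpus: 2
job:
  gpus: '_copy: /defaults/gpus'    # absolute path, resolved from the root
  backup_gpus: '_copy: ../gpus'    # relative path, resolves to the sibling key 'gpus'
""")
yaml_utils.resolve_all(conf)
assert conf['job']['gpus'] == 2 and conf['job']['backup_gpus'] == 2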

View file

@ -31,7 +31,7 @@ class ArchTrainer(Trainer, EnforceOverrides):
def compute_loss(self, lossfn: Callable,
x: Tensor, y: Tensor, logits: Tensor,
aux_weight: float, aux_logits: Optional[Tensor]) -> Tensor:
loss = super().compute_loss(lossfn, x, y, logits,
loss = super().compute_loss(lossfn, y, logits,
aux_weight, aux_logits)
# add L1 alpha regularization
if self._l1_alphas > 0.0:

View file

@ -32,7 +32,7 @@ common:
scale_lr: True # enable/disable distributed mode
min_world_size: 0 # allows to confirm we are indeed in distributed setting
detect_anomaly: False # if True, PyTorch code will run 6X slower
seed: '_copy: common/seed'
seed: '_copy: /common/seed'
smoke_test: False
only_eval: False
@ -54,8 +54,8 @@ nas:
model_filename: '$expdir/model.pt' # file to which trained model will be saved
data_parallel: False
checkpoint:
_copy: 'common/checkpoint'
resume: '_copy: common/resume'
_copy: '/common/checkpoint'
resume: '_copy: /common/resume'
model_desc:
dataset:
_copy: '/dataset'
@ -90,15 +90,16 @@ nas:
_copy: '/dataset'
trainer:
apex:
_copy: 'common/apex'
_copy: '/common/apex'
enabled: False
aux_weight: '_copy: nas/eval/model_desc/aux_weight'
aux_weight: '_copy: /nas/eval/model_desc/aux_weight'
drop_path_prob: 0.2 # probability that given edge will be dropped
grad_clip: 5.0 # grads above this value are clipped
l1_alphas: 0.0 # weight to be applied to sum(abs(alphas)) to loss term
logger_freq: 1000 # after every N updates dump loss and other metrics in logger
title: 'eval_train'
epochs: 600
batch_chunks: 1 # split batch into this many chunks and accumulate gradients so we can support GPUs with less RAM
lossfn:
type: 'CrossEntropyLoss'
optimizer:
@ -116,6 +117,7 @@ nas:
epochs: 0 # 0 disables warmup
validation:
title: 'eval_test'
batch_chunks: '_copy: ../batch_chunks' # split batch into this many chunks and accumulate gradients so we can support GPUs with less RAM
logger_freq: 0
freq: 1 # perform validation only every N epochs
lossfn:
@ -124,15 +126,15 @@ nas:
search:
data_parallel: False
checkpoint:
_copy: 'common/checkpoint'
resume: '_copy: common/resume'
_copy: '/common/checkpoint'
resume: '_copy: /common/resume'
search_iters: 1
full_desc_filename: '$expdir/full_model_desc.yaml' # arch before it was finalized
final_desc_filename: '$expdir/final_model_desc.yaml' # final arch is saved in this file
metrics_dir: '$expdir/models/{reductions}/{cells}/{nodes}/{search_iter}' # where metrics and model stats would be saved from each pareto iteration
seed_train:
trainer:
_copy: 'nas/eval/trainer'
_copy: '/nas/eval/trainer'
apex:
enabled: False
title: 'seed_train'
@ -140,12 +142,12 @@ nas:
aux_weight: 0.0
drop_path_prob: 0.0
loader:
_copy: 'nas/eval/loader'
_copy: '/nas/eval/loader'
train_batch: 128
val_ratio: 0.1 #split portion for test set, 0 to 1
post_train:
trainer:
_copy: 'nas/eval/trainer'
_copy: '/nas/eval/trainer'
apex:
enabled: False
title: 'post_train'
@ -153,7 +155,7 @@ nas:
aux_weight: 0.0
drop_path_prob: 0.0
loader:
_copy: 'nas/eval/loader'
_copy: '/nas/eval/loader'
train_batch: 128
val_ratio: 0.1 #split portion for test set, 0 to 1
pareto:
@ -199,13 +201,14 @@ nas:
_copy: '/dataset'
trainer:
apex:
_copy: 'common/apex'
aux_weight: '_copy: nas/search/model_desc/aux_weight'
_copy: '/common/apex'
aux_weight: '_copy: /nas/search/model_desc/aux_weight'
drop_path_prob: 0.0 # probability that given edge will be dropped
grad_clip: 5.0 # grads above this value are clipped
logger_freq: 1000 # after every N updates dump loss and other metrics in logger
title: 'arch_train'
epochs: 50
batch_chunks: 1 # split batch into this many chunks and accumulate gradients so we can support GPUs with less RAM
# additional vals for the derived class
plotsdir: '' # empty string means no plots, otherwise plots are generated for each epoch in this dir
l1_alphas: 0.0 # weight to be applied to sum(abs(alphas)) to loss term
@ -231,6 +234,7 @@ nas:
validation:
title: 'search_val'
logger_freq: 0
batch_chunks: '_copy: ../batch_chunks' # split batch into this many chunks and accumulate gradients so we can support GPUs with less RAM
freq: 1 # perform validation only every N epochs
lossfn:
type: 'CrossEntropyLoss'
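
For intuition on the new batch_chunks knob (illustrative numbers, not defaults from this config): the loader batch stays the logical unit for the optimizer, only the per-chunk slice has to fit on the GPU.

train_batch, batch_chunks = 96, 4          # assumed values for illustration
per_chunk = train_batch // batch_chunks    # 24 samples resident on the GPU at a time
assert per_chunk * batch_chunks == train_batch
# the optimizer still takes exactly one step per 96-sample batch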

View file

@ -17,27 +17,27 @@ nas:
trainer:
epochs: 0 # number of epochs model will be trained before search
loader:
train_batch: '_copy: common/toy_mode/train_batch'
test_batch: '_copy: common/toy_mode/test_batch'
train_batch: '_copy: /common/toy_mode/train_batch'
test_batch: '_copy: /common/toy_mode/test_batch'
dataset:
max_batches: '_copy: common/toy_mode/max_batches'
max_batches: '_copy: /common/toy_mode/max_batches'
post_train:
trainer:
epochs: 1
loader:
train_batch: '_copy: common/toy_mode/train_batch'
test_batch: '_copy: common/toy_mode/test_batch'
train_batch: '_copy: /common/toy_mode/train_batch'
test_batch: '_copy: /common/toy_mode/test_batch'
dataset:
max_batches: '_copy: common/toy_mode/max_batches'
max_batches: '_copy: /common/toy_mode/max_batches'
model_desc:
n_reductions: 1 # number of reductions to be applied
n_cells: 3 # number of cells
n_nodes: 2 # number of nodes in a cell
loader:
train_batch: '_copy: common/toy_mode/train_batch'
test_batch: '_copy: common/toy_mode/test_batch'
train_batch: '_copy: /common/toy_mode/train_batch'
test_batch: '_copy: /common/toy_mode/test_batch'
dataset:
max_batches: '_copy: common/toy_mode/max_batches'
max_batches: '_copy: /common/toy_mode/max_batches'
trainer:
epochs: 1
logger_freq: 1
@ -51,10 +51,10 @@ nas:
n_nodes: 4 # number of nodes in a cell
n_reductions: 2 # number of reductions to be applied
loader:
train_batch: '_copy: common/toy_mode/train_batch'
test_batch: '_copy: common/toy_mode/test_batch'
train_batch: '_copy: /common/toy_mode/train_batch'
test_batch: '_copy: /common/toy_mode/test_batch'
dataset:
max_batches: '_copy: common/toy_mode/max_batches'
max_batches: '_copy: /common/toy_mode/max_batches'
trainer:
epochs: 1
logger_freq: 1

View file

@ -304,11 +304,11 @@ child calls. The big downside of this is that if any of the children runs much longer
than the others then we can't checkpoint within that specific child.
## Yaml design
Copy node value using '_copy: path/to/node'.
Copy node value using '_copy: /path/to/node'.
- target will always be scalar
- source could be dict or scalar
- recursive replacements
- this will replace above string path with target value
Insert node children using _copy: path/to/node
Insert node children using _copy: /path/to/node
- content of source node is copied
- rest of the child overrides
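
A quick sketch showing both forms together (illustrative keys, resolved with yaml_utils.resolve_all as in the tests added by this commit):

import yaml
from archai.common import yaml_utils

conf = yaml.safe_load("""
common:
  dataset:
    name: 'cifar10'
    classes: 10
eval:
  dataset:
    _copy: '/common/dataset'        # insert node children: keys of the source are merged in
    classes: 100                    # keys given locally override the copied ones
  dataset_name: '_copy: /common/dataset/name'   # copy node value: the string is replaced
""")
yaml_utils.resolve_all(conf)
assert conf['eval']['dataset'] == {'name': 'cifar10', 'classes': 100}
assert conf['eval']['dataset_name'] == 'cifar10'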

View file

@ -11,20 +11,20 @@ common:
logging: True
autoaug:
dataset: '_copy: nas/train/dataset'
logging: '_copy: nas/logging'
dataset: '_copy: /nas/train/dataset'
logging: '_copy: /nas/logging'
nas:
train:
dataset:
_copy: 'common/dataset'
_copy: '/common/dataset'
classes: 4
cifar:
workers: 0
logging: '_copy: common/logging'
logging: '_copy: /common/logging'
"""
d = yaml.safe_load(s)
print(d)
resolve_all(d, d)
print(d)
resolve_all(d)
print(yaml.dump(d))

tests/yaml_resolve_test.py (new file, 118 lines)

View file

@ -0,0 +1,118 @@
import yaml
from archai.common import yaml_utils
from archai.common import utils
def test_yaml1():
input="""
l1:
l1c1: 5
l2:
l2c2: 6
l3:
l3c2: 7
l3c3: 8
n1:
_copy: '/l1'
n2:
n22:
_copy: '/l1/l2/l3'
n3: '_copy: /n1/l1c1'
n4:
n4c1: 9
n4c2: '_copy: ../n4c1'
n4c3: '_copy: ../../n2/n22'
n5: '_copy: ./../n1/l1c1'
"""
expected = """
l1:
l1c1: 5
l2: &id001
l2c2: 6
l3:
l3c2: 7
l3c3: 8
n1:
l1c1: 5
l2: *id001
n2:
n22: &id002
l3c2: 7
l3c3: 8
n3: 5
n4:
n4c1: 9
n4c2: 9
n4c3: *id002
n5: 5
"""
d_input = yaml.safe_load(input)
yaml_utils.resolve_all(d_input)
#print(yaml.dump(d))
d_expected = yaml.safe_load(expected)
assert utils.deep_comp(d_input, d_expected)
def test_yaml2():
input="""
common:
dataset:
name: 'd1'
classes: 3
cifar:
limit: -1
logging: True
autoaug:
dataset: '_copy: /nas/train/dataset'
logging: '_copy: /nas/logging'
nas:
train:
dataset:
_copy: '/common/dataset'
classes: 4
cifar:
workers: 0
logging: '_copy: /common/logging'
"""
expected="""
autoaug:
dataset: &id001
cifar:
limit: -1
workers: 0
classes: 4
name: d1
logging: true
common:
dataset:
cifar:
limit: -1
classes: 3
name: d1
logging: true
nas:
logging: true
train:
dataset: *id001
"""
d_input = yaml.safe_load(input)
#print(d)
yaml_utils.resolve_all(d_input)
d_expected = yaml.safe_load(expected)
assert utils.deep_comp(d_input, d_expected)
test_yaml1()
test_yaml2()