Random search with FastArchRank runs through.

This commit is contained in:
Debadeepta Dey 2021-04-13 18:42:56 -07:00 committed by Gustavo Rosa
Parent a7e87ff6f6
Commit 1a8f11a87d
4 changed files with 172 additions and 14 deletions

View file

@@ -58,9 +58,12 @@ class ConditionalTrainer(ArchTrainer, EnforceOverrides):
# terminate if maximum training duration
# threshold is exceeded
if self._max_duration_secs:
if self._metrics.total_training_time >= self._max_duration_secs:
total_train_time = self._metrics.total_training_time()
if total_train_time >= self._max_duration_secs:
logger.info(f'max duration of training time {total_train_time} exceeded')
logger.info('----------terminating regular training---------')
should_terminate = True
return should_terminate
@overrides
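To see why the two related changes in this commit matter (calling total_training_time() before comparing, and the return added to Metrics.total_training_time() further down), here is a minimal self-contained sketch; FakeMetrics is an illustrative stand-in, not the archai Metrics class:

# Illustrative sketch only: FakeMetrics stands in for the real Metrics object.
class FakeMetrics:
    def __init__(self, secs: float):
        self._secs = secs

    def total_training_time(self) -> float:
        # without an explicit return (the bug fixed in Metrics below), this
        # method would return None and the caller's comparison would fail
        return self._secs

def should_terminate(metrics: FakeMetrics, max_duration_secs: float) -> bool:
    # the old code compared the bound method itself against a float, which
    # raises TypeError in Python 3; calling it first gives a comparable number
    total_train_time = metrics.total_training_time()
    return total_train_time >= max_duration_secs

assert should_terminate(FakeMetrics(120.0), max_duration_secs=100.0)
assert not should_terminate(FakeMetrics(50.0), max_duration_secs=100.0)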

View file

@@ -1,4 +1,5 @@
import math as ma
from time import time
from typing import Set
import os
import random
@@ -32,7 +33,6 @@ class RandomNatsbenchTssFarSearcher(Searcher):
# region config vars
max_num_models = conf_search['max_num_models']
ratio_fastest_duration = conf_search['ratio_fastest_duration']
top1_acc_threshold = conf_search['top1_acc_threshold']
dataroot = utils.full_path(conf_search['loader']['dataset']['dataroot'])
dataset_name = conf_search['loader']['dataset']['name']
natsbench_location = os.path.join(dataroot, 'natsbench', conf_search['natsbench']['natsbench_tss_fast'])
@@ -42,14 +42,14 @@ class RandomNatsbenchTssFarSearcher(Searcher):
# endregion
# create the natsbench api
api = create_natsbench_tss_api(natsbench_location, 'tss', fast_mode=True, verbose=True)
api = create_natsbench_tss_api(natsbench_location)
# presample max number of archids without replacement
random_archids = random.sample(range(api), k=max_num_models)
random_archids = random.sample(range(len(api)), k=max_num_models)
best_trains = [(-1, -ma.Inf)]
best_tests = [(-1, -ma.Inf)]
fastest_cond_train = ma.Inf
best_trains = [(-1, -ma.inf)]
best_tests = [(-1, -ma.inf)]
fastest_cond_train = ma.inf
for archid in random_archids:
# get model
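The hunk above fixes two small but fatal bugs: random.sample needs an integer range (range(len(api)), not range(api)), and the math module exposes inf, not Inf. The sampling-and-seeding pattern in isolation, with a stand-in architecture count instead of the real NATS-Bench API, looks like this:

import math as ma
import random

num_archs = 15625        # stand-in for len(api); the real count comes from the NATS-Bench TSS API
max_num_models = 4

# presample archids without replacement; passing the api object itself to range() fails
random_archids = random.sample(range(num_archs), k=max_num_models)

# (archid, accuracy) bests seeded with -inf so the first measured arch always replaces them;
# note that math defines `inf`, while `ma.Inf` raises AttributeError
best_trains = [(-1, -ma.inf)]
best_tests = [(-1, -ma.inf)]
fastest_cond_train = ma.inf
print(random_archids, best_trains, fastest_cond_train)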
@@ -64,14 +64,22 @@ class RandomNatsbenchTssFarSearcher(Searcher):
# starts exceeding fastest time to
# reach threshold by a ratio then early
# terminate it
logger.pushd('conditional training')
logger.pushd(f'conditional_training_{archid}')
data_loaders = self.get_data(conf_loader)
time_allowed = ratio_fastest_duration * fastest_cond_train
cond_trainer = ConditionalTrainer(conf_train, model, checkpoint, time_allowed)
cond_trainer_metrics = cond_trainer.fit(data_loaders)
cond_train_time = cond_trainer_metrics.total_training_time()
if cond_train_time >= time_allowed:
# this arch exceeded time to reach threshold
# cut losses and move to next one
continue
if cond_train_time < fastest_cond_train:
fastest_cond_train = cond_train_time
logger.info(f'fastest condition train till now: {fastest_cond_train} seconds!')
logger.popd()
# if we did not early terminate in conditional
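The lines added above implement the time-budget heuristic: each architecture is conditionally trained for at most ratio_fastest_duration times the fastest conditional-training time observed so far; architectures that exceed that budget are skipped, and faster ones tighten the budget for every later candidate. A schematic stand-alone version (the timings below are made up, not NATS-Bench measurements):

import math as ma

def filter_by_time_budget(cond_train_times: dict, ratio_fastest_duration: float):
    """Schematic version of the early-termination bookkeeping above."""
    fastest_cond_train = ma.inf
    survivors = []
    for archid, cond_train_time in cond_train_times.items():
        time_allowed = ratio_fastest_duration * fastest_cond_train
        if cond_train_time >= time_allowed:
            continue  # exceeded the budget: cut losses, move to the next arch
        if cond_train_time < fastest_cond_train:
            fastest_cond_train = cond_train_time
        survivors.append(archid)
    return survivors, fastest_cond_train

# toy timings in seconds; order matters because the budget tightens as faster archs appear
print(filter_by_time_budget({0: 90.0, 1: 60.0, 2: 80.0, 3: 50.0}, ratio_fastest_duration=1.2))
# -> ([0, 1, 3], 50.0)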
@@ -82,7 +90,7 @@ class RandomNatsbenchTssFarSearcher(Searcher):
conf_loader_freeze = deepcopy(conf_loader)
conf_loader_freeze['train_batch'] = conf_loader['freeze_loader']['train_batch']
logger.pushd('freeze_training')
logger.pushd(f'freeze_training_{archid}')
data_loaders = self.get_data(conf_loader_freeze)
# now just finetune the last few layers
checkpoint = None
@@ -100,7 +108,6 @@ class RandomNatsbenchTssFarSearcher(Searcher):
best_tests.append((archid, this_arch_top1_test))
# dump important things to log
logger.info({'best_trains':best_trains, 'best_tests':best_tests})
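Before freeze training the searcher clones the loader config and swaps in the larger freeze_loader batch size, and after evaluation it appends (archid, top1) pairs that get dumped to the log. Both pieces in isolation, with a made-up config dict and fabricated accuracies:

from copy import deepcopy
import math as ma

# made-up loader config mirroring the keys the searcher reads
conf_loader = {'train_batch': 64, 'freeze_loader': {'train_batch': 1024}}

# clone the config and only override the training batch size for the freeze stage
conf_loader_freeze = deepcopy(conf_loader)
conf_loader_freeze['train_batch'] = conf_loader['freeze_loader']['train_batch']
assert conf_loader['train_batch'] == 64 and conf_loader_freeze['train_batch'] == 1024

# running (archid, top1) pairs, seeded with (-1, -inf) as in the searcher above
best_tests = [(-1, -ma.inf)]
best_tests.append((1234, 0.712))   # fabricated accuracies, for illustration only
best_tests.append((4321, 0.748))

# the best architecture so far is simply the max by accuracy
best_archid, best_top1 = max(best_tests, key=lambda t: t[1])
assert best_archid == 4321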

View file

@@ -256,7 +256,7 @@ class Metrics:
return test_epoch_metrics.top1.avg if test_epoch_metrics is not None else math.nan
def total_training_time(self)->float:
self.run_metrics.total_train_time()
return self.run_metrics.total_train_time()
class Accumulator:
# TODO: replace this with Metrics class

View file

@@ -0,0 +1,148 @@
__include__: "../datasets/cifar10.yaml" # default dataset settings are for cifar

common:
  experiment_name: 'throwaway' # you should supply from command line
  experiment_desc: 'throwaway'
  logdir: '~/logdir'
  log_prefix: 'log' # prefix for log files that will be created (log.log and log.yaml), no log files if ''
  log_level: 20 # logging.INFO
  backup_existing_log_file: False # should we overwrite existing log file without making a copy?
  yaml_log: True # if True, structured logs as yaml are also generated
  seed: 2.0
  tb_enable: False # if True then TensorBoard logging is enabled (may impact perf)
  tb_dir: '$expdir/tb' # path where tensorboard logs would be stored

  checkpoint:
    filename: '$expdir/checkpoint.pth'
    freq: 10

  # redis address of Ray cluster. Use None for single node run
  # otherwise it should be something like host:6379. Make sure to run on head node:
  # "ray start --head --redis-port=6379"
  redis: null

  apex: # this is overridden in search and eval individually
    enabled: False # global switch to disable everything apex
    distributed_enabled: True # enable/disable distributed mode
    mixed_prec_enabled: True # switch to disable amp mixed precision
    gpus: '' # use GPU IDs specified here (comma separated), if '' then use all GPUs
    opt_level: 'O2' # optimization level for mixed precision
    bn_fp32: True # keep BN in fp32
    loss_scale: "dynamic" # loss scaling mode for mixed prec, must be a string representing a float or "dynamic"
    sync_bn: False # whether to replace BNs with sync BNs for distributed model
    scale_lr: True # scale learning rate in distributed mode
    min_world_size: 0 # allows to confirm we are indeed in distributed setting
    detect_anomaly: False # if True, PyTorch code will run 6X slower
    seed: '_copy: /common/seed'
    ray:
      enabled: False # initialize ray. Note: ray cannot be used if apex distributed is enabled
      local_mode: False # if True then ray runs in serial mode

  smoke_test: False
  only_eval: False
  resume: True

dataset: {} # default dataset settings come from the __include__ at the top

nas:
  search:
    max_num_models: 4
    ratio_fastest_duration: 1.2
    natsbench:
      natsbench_tss_fast: 'NATS-tss-v1_0-3ffb9-simple' # folder name in dataroot/natsbench that contains the tss fast mode folder
    finalizer: 'default' # options are 'random' or 'default'
    data_parallel: False
    checkpoint:
      _copy: '/common/checkpoint'
    resume: '_copy: /common/resume'
    full_desc_filename: '$expdir/full_model_desc.yaml' # arch before it was finalized
    final_desc_filename: '$expdir/final_model_desc.yaml' # final arch is saved in this file

    loader:
      apex:
        _copy: '../../trainer/apex'
      aug: '' # additional augmentations to use
      cutout: 0 # cutout length, use cutout augmentation when > 0
      load_train: True # load train split of dataset
      train_batch: 64
      freeze_loader:
        train_batch: 1024 # batch size for freeze training
      train_workers: 4 # if null then gpu_count*4
      test_workers: '_copy: ../train_workers' # if null then 4
      load_test: False # load test split of dataset
      test_batch: 1024
      val_ratio: 0.0 # split portion for validation set, 0 to 1
      val_fold: 0 # fold number to use (0 to 4)
      cv_num: 5 # total number of folds available
      dataset:
        _copy: '/dataset'

    trainer:
      use_val: False
      top1_acc_threshold: 0.1 # after some accuracy we will shift into training only the last 'n' layers
      apex:
        _copy: '/common/apex'
      aux_weight: 0.0
      drop_path_prob: 0.2 # probability that given edge will be dropped
      grad_clip: 5.0 # grads above this value are clipped
      logger_freq: 1000 # after every N updates dump loss and other metrics in logger
      title: 'arch_train'
      epochs: 200
      batch_chunks: 1 # split batch into this many chunks and accumulate gradients so we can support GPUs with lower RAM
      # additional vals for the derived class
      plotsdir: '' # empty string means no plots, otherwise plots are generated for each epoch in this dir
      l1_alphas: 0.0 # weight to be applied to sum(abs(alphas)) to loss term
      lossfn:
        type: 'CrossEntropyLoss'
      optimizer:
        type: 'sgd'
        lr: 0.1 # init learning rate
        decay: 5.0e-4
        momentum: 0.9 # pytorch default is 0
        nesterov: True
        decay_bn: .NaN # if NaN then same as decay otherwise apply different decay to BN layers
      lr_schedule:
        type: 'cosine'
        min_lr: 0.000 # min learning rate, this will be used in eta_min param of scheduler
        warmup: # increases LR from 0 to current in specified epochs and then hands over to main scheduler
          multiplier: 1
          epochs: 0 # 0 disables warmup
      validation:
        title: 'search_val'
        logger_freq: 0
        batch_chunks: '_copy: ../../batch_chunks' # split batch into this many chunks and accumulate gradients so we can support GPUs with lower RAM
        freq: 1 # perform validation only every N epochs
        lossfn:
          type: 'CrossEntropyLoss'

    freeze_trainer:
      plotsdir: ''
      identifiers_to_unfreeze: ['classifier', 'lastact', 'cells.16', 'cells.15', 'cells.14', 'cells.13'] # last few layer names in natsbench: lastact, lastact.0, lastact.1: BN-Relu, global_pooling: global avg. pooling (doesn't get exposed as a named param though), classifier: linear layer
      apex:
        _copy: '/common/apex'
      aux_weight: 0.0 # very important that this is 0.0 for freeze training
      drop_path_prob: 0.0 # very important that this is 0.0 for freeze training
      grad_clip: 5.0 # grads above this value are clipped
      l1_alphas: 0.0 # weight to be applied to sum(abs(alphas)) to loss term
      logger_freq: 1000 # after every N updates dump loss and other metrics in logger
      title: 'eval_train'
      epochs: 5
      batch_chunks: 1 # split batch into this many chunks and accumulate gradients so we can support GPUs with lower RAM
      lossfn:
        type: 'CrossEntropyLoss'
      optimizer:
        type: 'sgd'
        lr: 0.1 # init learning rate
        decay: 5.0e-4 # pytorch default is 0.0
        momentum: 0.9 # pytorch default is 0.0
        nesterov: True # pytorch default is False
        decay_bn: .NaN # if NaN then same as decay otherwise apply different decay to BN layers
      lr_schedule:
        type: 'cosine'
        min_lr: 0.000 # min learning rate to be set in eta_min param of scheduler
        warmup: # increases LR from 0 to current in specified epochs and then hands over to main scheduler
          multiplier: 1
          epochs: 0 # 0 disables warmup
      validation:
        title: 'eval_test'
        batch_chunks: '_copy: ../../batch_chunks' # split batch into this many chunks and accumulate gradients so we can support GPUs with lower RAM
        logger_freq: 0
        freq: 1 # perform validation only every N epochs
        lossfn:
          type: 'CrossEntropyLoss'
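The new config is consumed through archai's own config machinery, which resolves __include__, the '_copy:' references, and $expdir expansion; none of that is reproduced here. Purely as a structural sanity check, the nesting can be inspected with plain PyYAML (the file path below is an assumption for illustration):

import yaml

# assumed path for illustration; archai-specific features (__include__,
# '_copy:' references, $expdir) are NOT resolved by plain safe_load
with open('confs/algos/random_natsbench_tss_far.yaml') as f:
    conf = yaml.safe_load(f)

search = conf['nas']['search']
print(search['max_num_models'])                           # 4
print(search['ratio_fastest_duration'])                   # 1.2
print(search['trainer']['top1_acc_threshold'])            # 0.1
print(search['loader']['freeze_loader']['train_batch'])   # 1024
print(search['freeze_trainer']['identifiers_to_unfreeze'][:2])  # ['classifier', 'lastact']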