Mirror of https://github.com/microsoft/archai.git
Started testing random search on Darts search space.
Parent: e16d7b1458
Commit: 1cf6843f5d
@@ -507,6 +507,14 @@
             "console": "integratedTerminal",
             "args": ["--full", "--algos", "random_natsbench_tss_reg", "--datasets", "cifar100", "--no-eval"]
         },
+        {
+            "name": "Random Darts Space Reg Search",
+            "type": "python",
+            "request": "launch",
+            "program": "${cwd}/scripts/main.py",
+            "console": "integratedTerminal",
+            "args": ["--full", "--algos", "random_dartsspace_reg", "--no-eval"]
+        },
         {
             "name": "Local Natsbench Tss Far Search",
             "type": "python",
@@ -0,0 +1,54 @@
+from overrides import overrides
+from typing import Optional, Type, Tuple
+
+from archai.nas.exp_runner import ExperimentRunner
+from archai.nas.model_desc_builder import ModelDescBuilder
+from archai.nas.arch_trainer import TArchTrainer
+from archai.common import common
+from archai.common import utils
+from archai.common.config import Config
+from archai.nas.evaluater import Evaluater, EvalResult
+from archai.nas.searcher import Searcher, SearchResult
+from archai.nas.finalizers import Finalizers
+from archai.nas.random_finalizers import RandomFinalizers
+from archai.nas.model_desc_builder import ModelDescBuilder
+from archai.algos.random.random_model_desc_builder import RandomModelDescBuilder
+from archai.algos.random_darts.random_dartsspace_reg_searcher import RandomDartsSpaceRegSearcher
+
+
+class RandomDartsSpaceRegExpRunner(ExperimentRunner):
+    ''' Runs random search using FastArchRank on DARTS search space '''
+
+    @overrides
+    def model_desc_builder(self)->Optional[ModelDescBuilder]:
+        return RandomModelDescBuilder()
+
+    @overrides
+    def trainer_class(self)->TArchTrainer:
+        return None # no search trainer
+
+    @overrides
+    def run_search(self, conf_search:Config)->SearchResult:
+        model_desc_builder = self.model_desc_builder()
+        trainer_class = self.trainer_class()
+        finalizers = self.finalizers()
+        search = self.searcher()
+        return search.search(conf_search, model_desc_builder, trainer_class, finalizers)
+
+    @overrides
+    def run_eval(self, conf_eval:Config)->EvalResult:
+        evaler = self.evaluater()
+        return evaler.evaluate(conf_eval)
+
+    @overrides
+    def searcher(self)->Searcher:
+        return RandomDartsSpaceRegSearcher()
+
+    @overrides
+    def evaluater(self)->Evaluater:
+        return None
+
+    @overrides
+    def copy_search_to_eval(self) -> None:
+        return None
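For orientation, this runner is meant to be driven through the algo registry in scripts/main.py (see the hunks further below). A direct-use sketch would look roughly like the following; the constructor arguments, the config path, and the run() signature are assumptions rather than something shown in this diff:

    from archai.algos.random_darts.random_dartsspace_reg_exp_runner import RandomDartsSpaceRegExpRunner

    # Hypothetical direct invocation; the argument names mirror how other archai
    # ExperimentRunner subclasses are typically constructed, but are not confirmed here.
    runner = RandomDartsSpaceRegExpRunner(
        'confs/algos/random_dartsspace_reg.yaml',   # assumed config file path
        base_name='random_dartsspace_reg')
    runner.run(search=True, eval=False)             # search only, as with --no-eval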
@@ -15,6 +15,7 @@ from archai.nas.arch_trainer import TArchTrainer
 from archai.common.trainer import Trainer
 from archai.common import utils
 from archai.nas.finalizers import Finalizers
+from archai.nas.model import Model
 from archai.algos.random.random_model_desc_builder import RandomModelDescBuilder
 
 
@@ -48,8 +49,7 @@ class RandomDartsSpaceRegSearcher(Searcher):
             # as we are creating model based on seed
             model_desc = model_desc_builder.build(conf_model_desc,
                                                   seed=seed_for_arch_generation)
-            model = self.model_from_desc(model_desc)
-
+            model = Model(model_desc, droppath=True, affine=True)
 
             checkpoint = None
 
@@ -26,7 +26,6 @@ class RandomNatsbenchTssRegExpRunner(ExperimentRunner):
     def trainer_class(self)->TArchTrainer:
         return None # no search trainer
 
-
     @overrides
     def run_search(self, conf_search:Config)->SearchResult:
         search = self.searcher()
@@ -60,7 +60,7 @@ nas:
     resume: '_copy: /common/resume'
     model_desc:
       n_reductions: 2 # number of reductions to be applied
-      n_cells: 8 # number of cells
+      n_cells: 20 # number of cells
       dataset:
         _copy: '/dataset'
       max_final_edges: 2 # max edge that can be in final arch per node
@@ -0,0 +1,156 @@
+__include__: "../datasets/cifar10.yaml" # default dataset settings are for cifar
+
+common:
+  experiment_name: 'throwaway' # you should supply this from the command line
+  experiment_desc: 'throwaway'
+  logdir: '~/logdir'
+  log_prefix: 'log' # prefix for log files that will be created (log.log and log.yaml), no log files if ''
+  log_level: 20 # logging.INFO
+  backup_existing_log_file: False # should we overwrite existing log file without making a copy?
+  yaml_log: True # if True, structured logs as yaml are also generated
+  seed: 2.0
+  tb_enable: False # if True then TensorBoard logging is enabled (may impact perf)
+  tb_dir: '$expdir/tb' # path where tensorboard logs would be stored
+  checkpoint:
+    filename: '$expdir/checkpoint.pth'
+    freq: 10
+
+  # redis address of the Ray cluster. Use None for a single node run
+  # otherwise it should be something like host:6379. Make sure to run on the head node:
+  # "ray start --head --redis-port=6379"
+  redis: null
+  apex: # this is overridden in search and eval individually
+    enabled: False # global switch to disable everything apex
+    distributed_enabled: True # enable/disable distributed mode
+    mixed_prec_enabled: True # switch to disable amp mixed precision
+    gpus: '' # use GPU IDs specified here (comma separated), if '' then use all GPUs
+    opt_level: 'O2' # optimization level for mixed precision
+    bn_fp32: True # keep BN in fp32
+    loss_scale: "dynamic" # loss scaling mode for mixed prec, must be a string representing a float or "dynamic"
+    sync_bn: False # should we replace BNs with sync BNs for distributed mode
+    scale_lr: True # if True, scale learning rate by world size in distributed mode
+    min_world_size: 0 # allows confirming we are indeed in a distributed setting
+    detect_anomaly: False # if True, PyTorch code will run 6X slower
+    seed: '_copy: /common/seed'
+    ray:
+      enabled: False # initialize ray. Note: ray cannot be used if apex distributed is enabled
+      local_mode: False # if True then ray runs in serial mode
+
+  smoke_test: False
+  only_eval: False
+  resume: True
+
+dataset: {} # default dataset settings come from __include__ at the top
+
+nas:
+  search:
+    max_num_models: 5
+    finalizer: 'default' # options are 'random' or 'default'
+    data_parallel: False
+    checkpoint:
+      _copy: '/common/checkpoint'
+    resume: '_copy: /common/resume'
+    search_iters: 1
+    full_desc_filename: '$expdir/full_model_desc.yaml' # arch before it was finalized
+    final_desc_filename: '$expdir/final_model_desc.yaml' # final arch is saved in this file
+    metrics_dir: '$expdir/models/{reductions}/{cells}/{nodes}/{search_iter}' # where metrics and model stats would be saved from each pareto iteration
+    model_desc:
+      n_reductions: 2 # number of reductions to be applied
+      n_cells: 8 # number of cells
+      num_edges_to_sample: 2 # number of incoming edges per node to be randomly sampled
+      dataset:
+        _copy: '/dataset'
+      max_final_edges: 2 # max edges that can be in final arch per node
+      model_post_op: 'pool_adaptive_avg2d'
+      params: {}
+      aux_weight: 0.4 # weight for loss from auxiliary towers in test time arch
+      aux_tower_stride: 3 # stride that aux tower should use, 3 is good for 32x32 images, 2 for imagenet
+      model_stems:
+        ops: ['stem_conv3x3', 'stem_conv3x3']
+        stem_multiplier: 3 # output channels multiplier for the stem
+        init_node_ch: 36 # num of input/output channels for nodes in 1st cell. NOTE: we match that in eval since this is discrete search.
+      cell:
+        n_nodes: 4 # number of nodes in a cell
+        cell_post_op: 'concate_channels'
+    loader:
+      apex:
+        _copy: '../../trainer/apex'
+      aug: '' # additional augmentations to use
+      cutout: 16 # cutout length, use cutout augmentation when > 0
+      load_train: True # load train split of dataset
+      train_batch: 96
+      train_workers: 4 # if null then gpu_count*4
+      test_workers: '_copy: ../train_workers' # if null then 4
+      load_test: True # load test split of dataset
+      test_batch: 1024
+      val_ratio: 0.0 # split portion for the validation set, 0 to 1
+      val_fold: 0 # fold number to use (0 to 4)
+      cv_num: 5 # total number of folds available
+      dataset:
+        _copy: '/dataset'
+    trainer:
+      apex:
+        _copy: '/common/apex'
+      aux_weight: '_copy: /nas/search/model_desc/aux_weight'
+      drop_path_prob: 0.2 # probability that given edge will be dropped
+      grad_clip: 5.0 # grads above this value are clipped
+      logger_freq: 1000 # after every N updates dump loss and other metrics in logger
+      title: 'arch_train'
+      epochs: 1
+      batch_chunks: 1 # split batch into these many chunks and accumulate gradients so we can support GPUs with lower RAM
+      # additional vals for the derived class
+      plotsdir: '' # empty string means no plots, otherwise plots are generated for each epoch in this dir
+      l1_alphas: 0.0 # weight to be applied to sum(abs(alphas)) to loss term
+      lossfn:
+        type: 'CrossEntropyLoss'
+      optimizer:
+        type: 'sgd'
+        lr: 0.025 # init learning rate
+        decay: 3.0e-4
+        momentum: 0.9 # pytorch default is 0
+        nesterov: False
+        decay_bn: .NaN # if NaN then same as decay otherwise apply different decay to BN layers
+      lr_schedule:
+        type: 'cosine'
+        min_lr: 0.001 # min learning rate, this will be used in eta_min param of scheduler
+        warmup: null
+      validation:
+        title: 'search_val'
+        logger_freq: 0
+        batch_chunks: '_copy: ../../batch_chunks' # split batch into these many chunks and accumulate gradients so we can support GPUs with lower RAM
+        freq: 1 # perform validation only every N epochs
+        lossfn:
+          type: 'CrossEntropyLoss'
+    trainer_full:
+      apex:
+        _copy: '/common/apex'
+      aux_weight: '_copy: /nas/search/model_desc/aux_weight'
+      drop_path_prob: 0.2 # probability that given edge will be dropped
+      grad_clip: 5.0 # grads above this value are clipped
+      logger_freq: 1000 # after every N updates dump loss and other metrics in logger
+      title: 'arch_train'
+      epochs: 2
+      batch_chunks: 1 # split batch into these many chunks and accumulate gradients so we can support GPUs with lower RAM
+      # additional vals for the derived class
+      plotsdir: '' # empty string means no plots, otherwise plots are generated for each epoch in this dir
+      l1_alphas: 0.0 # weight to be applied to sum(abs(alphas)) to loss term
+      lossfn:
+        type: 'CrossEntropyLoss'
+      optimizer:
+        type: 'sgd'
+        lr: 0.025 # init learning rate
+        decay: 3.0e-4
+        momentum: 0.9 # pytorch default is 0
+        nesterov: False
+        decay_bn: .NaN # if NaN then same as decay otherwise apply different decay to BN layers
+      lr_schedule:
+        type: 'cosine'
+        min_lr: 0.001 # min learning rate, this will be used in eta_min param of scheduler
+        warmup: null
+      validation:
+        title: 'search_val'
+        logger_freq: 0
+        batch_chunks: '_copy: ../../batch_chunks' # split batch into these many chunks and accumulate gradients so we can support GPUs with lower RAM
+        freq: 1 # perform validation only every N epochs
+        lossfn:
+          type: 'CrossEntropyLoss'
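Aside from the __include__ and '_copy:' directives, which archai's Config class resolves when the file is loaded, the new config is plain YAML, so a quick structural sanity check is possible with PyYAML alone. The file path below is an assumption about where this commit places the config, and the directives are left unresolved in this sketch:

    import yaml

    # Assumed path for the new config added in this commit.
    with open('confs/algos/random_dartsspace_reg.yaml') as f:
        conf = yaml.safe_load(f)

    print(conf['nas']['search']['max_num_models'])          # 5 sampled architectures
    print(conf['nas']['search']['model_desc']['n_cells'])   # 8 cells during search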
@@ -29,6 +29,7 @@ from archai.algos.zero_cost_measures.zero_cost_natsbench_epochs_experiment_runne
 from archai.algos.random_natsbench.random_natsbench_tss_far_exp_runner import RandomNatsbenchTssFarExpRunner
 from archai.algos.random_natsbench.random_natsbench_tss_far_post_exp_runner import RandomNatsbenchTssFarPostExpRunner
 from archai.algos.random_natsbench.random_natsbench_tss_reg_exp_runner import RandomNatsbenchTssRegExpRunner
+from archai.algos.random_darts.random_dartsspace_reg_exp_runner import RandomDartsSpaceRegExpRunner
 from archai.algos.local_search_natsbench.local_natsbench_tss_far_exp_runner import LocalNatsbenchTssFarExpRunner
 
 def main():
@@ -56,6 +57,7 @@ def main():
         'random_natsbench_tss_far': RandomNatsbenchTssFarExpRunner,
         'random_natsbench_tss_far_post': RandomNatsbenchTssFarPostExpRunner,
         'random_natsbench_tss_reg': RandomNatsbenchTssRegExpRunner,
+        'random_dartsspace_reg': RandomDartsSpaceRegExpRunner,
         'local_natsbench_tss_far': LocalNatsbenchTssFarExpRunner
     }
 
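The dictionary above is the registry that maps each value passed via --algos to its ExperimentRunner class; the surrounding argument parsing and dispatch logic is outside this hunk. A small sketch of the lookup, with the registry's variable name assumed rather than taken from this diff:

    from archai.algos.random_natsbench.random_natsbench_tss_reg_exp_runner import RandomNatsbenchTssRegExpRunner
    from archai.algos.random_darts.random_dartsspace_reg_exp_runner import RandomDartsSpaceRegExpRunner

    # 'runner_types' is an assumed name; only entries relevant to this commit are shown.
    runner_types = {
        'random_natsbench_tss_reg': RandomNatsbenchTssRegExpRunner,
        'random_dartsspace_reg': RandomDartsSpaceRegExpRunner,
    }

    algos = 'random_dartsspace_reg'  # value that would arrive via --algos
    for algo in (a.strip() for a in algos.split(',')):
        print(algo, '->', runner_types[algo].__name__)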
@@ -83,6 +85,7 @@ def main():
                                 random_natsbench_tss_far,
                                 random_natsbench_tss_far_post,
                                 random_natsbench_tss_reg,
+                                random_dartsspace_reg,
                                 local_natsbench_tss_far''',
                         help='NAS algos to run, separated by comma')
     parser.add_argument('--datasets', type=str, default='cifar10',