Added zero cost evaluation on darts constant random architectures.

This commit is contained in:
Debadeepta Dey 2021-12-03 15:37:23 -08:00 коммит произвёл Gustavo Rosa
Родитель c029c8c717
Коммит 0bb248e20d
5 изменённых файлов: 184 добавлений и 1 удалений

10
.vscode/launch.json поставляемый
Просмотреть файл

@ -220,7 +220,7 @@
"request": "launch",
"program": "${cwd}/scripts/main.py",
"console": "integratedTerminal",
"args": ["--full", "--algos", "darts_space_constant_random_archs", "--datasets", "ninapro"]
"args": ["--full", "--algos", "darts_space_constant_random_archs", "--datasets", "cifar100"]
},
{
"name": "Proxynas-Darts-Space-Full",
@ -376,6 +376,14 @@
"console": "integratedTerminal",
"args": ["--algos", "zerocost_natsbench_epochs_space"]
},
{
"name": "ZeroCost-Darts-Space-Constant-Random-Full",
"type": "python",
"request": "launch",
"program": "${cwd}/scripts/main.py",
"console": "integratedTerminal",
"args": ["--full", "--algos", "zerocost_darts_space_constant_random", "--datasets", "cifar10"]
},
{
"name": "Natsbench-Regular-Eval-Full",
"type": "python",

Просмотреть файл

@ -0,0 +1,91 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Optional
import importlib
import sys
import string
import os
from overrides import overrides
import torch
from torch import nn
from overrides import overrides, EnforceOverrides
from archai.common.trainer import Trainer
from archai.common.config import Config
from archai.common.common import logger
from archai.datasets import data
from archai.nas.model_desc import ModelDesc
from archai.nas.model_desc_builder import ModelDescBuilder
from archai.nas import nas_utils
from archai.common import ml_utils, utils
from archai.common.metrics import EpochMetrics, Metrics
from archai.nas.model import Model
from archai.common.checkpoint import CheckPoint
from archai.nas.evaluater import EvalResult, Evaluater
from archai.algos.proxynas.freeze_trainer import FreezeTrainer
from archai.algos.random_sample_darts_space.constant_darts_space_sampler import ConstantDartsSpaceSampler
from archai.algos.random_sample_darts_space.random_model_desc_builder import RandomModelDescBuilder
from .zero_cost_trainer import ZeroCostTrainer
class ZeroCostDartsSpaceConstantRandomEvaluator(Evaluater):
@overrides
def create_model(self, conf_eval:Config, model_desc_builder:RandomModelDescBuilder,
final_desc_filename=None, full_desc_filename=None)->nn.Module:
assert model_desc_builder is not None, 'DartsSpaceEvaluater requires model_desc_builder'
assert final_desc_filename is None, 'DartsSpaceEvaluater creates its own model desc based on arch index'
assert type(model_desc_builder) == RandomModelDescBuilder, 'DartsSpaceEvaluater requires RandomModelDescBuilder'
# region conf vars
# if not explicitly passed in then get from conf
if not final_desc_filename:
final_desc_filename = conf_eval['final_desc_filename']
full_desc_filename = conf_eval['full_desc_filename']
conf_model_desc = conf_eval['model_desc']
arch_index = conf_eval['dartsspace']['arch_index']
# endregion
assert arch_index >= 0
# get random seed from constant sampler
# to get deterministic model creation
self.constant_sampler = ConstantDartsSpaceSampler()
random_seed_for_model_construction = self.constant_sampler.get_archid(arch_index)
# we don't load template model desc file from disk
# as we are creating model based on arch_index
model_desc = model_desc_builder.build(conf_model_desc,
seed=random_seed_for_model_construction)
# save desc for reference
model_desc.save(full_desc_filename)
model = self.model_from_desc(model_desc)
logger.info({'model_factory':False,
'cells_len':len(model.desc.cell_descs()),
'init_node_ch': conf_model_desc['model_stems']['init_node_ch'],
'n_cells': conf_model_desc['n_cells'],
'n_reductions': conf_model_desc['n_reductions'],
'n_nodes': conf_model_desc['cell']['n_nodes']})
return model
@overrides
def train_model(self, conf_train:Config, model:nn.Module,
checkpoint:Optional[CheckPoint])->Metrics:
conf_loader = conf_train['loader']
conf_train = conf_train['trainer']
# get data
data_loaders = self.get_data(conf_loader)
trainer = ZeroCostTrainer(conf_train, model, checkpoint)
train_metrics = trainer.fit(data_loaders, self.num_classes)
return train_metrics

Просмотреть файл

@ -0,0 +1,58 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Optional, Type
from copy import deepcopy
import os
from overrides import overrides
from archai.common.config import Config
from archai.nas import nas_utils
from archai.common import utils
from archai.nas.exp_runner import ExperimentRunner
from archai.nas.arch_trainer import ArchTrainer, TArchTrainer
from archai.nas.model_desc_builder import ModelDescBuilder
from archai.nas.evaluater import EvalResult
from archai.common.common import get_expdir, logger
from archai.algos.proxynas.freeze_manual_searcher import ManualFreezeSearcher
from archai.algos.zero_cost_measures.zero_cost_darts_space_constant_random_evaluator import ZeroCostDartsSpaceConstantRandomEvaluator
from nats_bench import create
class ZeroCostDartsSpaceConstantRandomExperimentRunner(ExperimentRunner):
"""Runs zero cost on architectures from DARTS search space
which are randomly sampled in a reproducible way"""
@overrides
def model_desc_builder(self)->Optional[ModelDescBuilder]:
return None
@overrides
def trainer_class(self)->TArchTrainer:
return None # no search trainer
@overrides
def searcher(self)->ManualFreezeSearcher:
return ManualFreezeSearcher() # no searcher basically
@overrides
def copy_search_to_eval(self)->None:
pass
@overrides
def run_eval(self, conf_eval:Config)->EvalResult:
# without training architecture evaluation score
# ---------------------------------------
logger.pushd('zerocost_evaluate')
zerocost_evaler = ZeroCostDartsSpaceConstantRandomEvaluator()
conf_eval_zerocost = deepcopy(conf_eval)
if conf_eval_zerocost['checkpoint'] is not None:
conf_eval_zerocost['checkpoint']['filename'] = '$expdir/zerocost_checkpoint.pth'
zerocost_eval_result = zerocost_evaler.evaluate(conf_eval_zerocost, model_desc_builder=self.model_desc_builder())
logger.popd()
return zerocost_eval_result

Просмотреть файл

@ -0,0 +1,23 @@
__include__: 'darts.yaml' # just use darts defaults
nas:
search:
model_desc:
num_edges_to_sample: 2 # number of edges each node will take input from
eval:
dartsspace:
arch_index: 66
model_desc:
aux_weight: False # AuxTower class assumes specific input size hence breaks with many datasets.
num_edges_to_sample: 2
n_cells: 8
loader:
aug: ''
cutout: -1 # cutout length, use cutout augmentation when > 0
val_ratio: 0.0
train_batch: 96
trainer:
use_val: False
plotsdir: ''
epochs: 100

Просмотреть файл

@ -29,6 +29,7 @@ from archai.algos.proxynas.freezeaddon_nasbench101_experiment_runner import Free
from archai.algos.zero_cost_measures.zero_cost_natsbench_experiment_runner import ZeroCostNatsbenchExperimentRunner
from archai.algos.zero_cost_measures.zero_cost_natsbench_conditional_experiment_runner import ZeroCostConditionalNatsbenchExperimentRunner
from archai.algos.zero_cost_measures.zero_cost_natsbench_epochs_experiment_runner import ZeroCostNatsbenchEpochsExperimentRunner
from archai.algos.zero_cost_measures.zero_cost_darts_space_constant_random_experiment_runner import ZeroCostDartsSpaceConstantRandomExperimentRunner
from archai.algos.random_sample_natsbench.random_natsbench_tss_far_exp_runner import RandomNatsbenchTssFarExpRunner
from archai.algos.random_sample_natsbench.random_natsbench_tss_far_post_exp_runner import RandomNatsbenchTssFarPostExpRunner
from archai.algos.random_sample_natsbench.random_natsbench_tss_reg_exp_runner import RandomNatsbenchTssRegExpRunner
@ -39,6 +40,7 @@ from archai.algos.local_search_natsbench.local_search_natsbench_tss_fear_exp_run
from archai.algos.local_search_natsbench.local_search_natsbench_tss_reg_exp_runner import LocalSearchNatsbenchTSSRegExpRunner
from archai.algos.local_search_darts.local_search_darts_reg_exp_runner import LocalSearchDartsRegExpRunner
def main():
runner_types:Dict[str, Type[ExperimentRunner]] = {
'darts': DartsExperimentRunner,
@ -59,6 +61,7 @@ def main():
'zerocost_natsbench_space': ZeroCostNatsbenchExperimentRunner,
'zerocost_conditional_natsbench_space': ZeroCostConditionalNatsbenchExperimentRunner,
'zerocost_natsbench_epochs_space': ZeroCostNatsbenchEpochsExperimentRunner,
'zerocost_darts_space_constant_random': ZeroCostDartsSpaceConstantRandomExperimentRunner,
'natsbench_regular_eval': NatsbenchRegularExperimentRunner,
'natsbench_sss_regular_eval': NatsbenchSSSRegularExperimentRunner,
'nb101_regular_eval': Nb101RegularExperimentRunner,