зеркало из https://github.com/microsoft/archai.git
Added zero cost evaluation on darts constant random architectures.
This commit is contained in:
Родитель
c029c8c717
Коммит
0bb248e20d
|
@ -220,7 +220,7 @@
|
|||
"request": "launch",
|
||||
"program": "${cwd}/scripts/main.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": ["--full", "--algos", "darts_space_constant_random_archs", "--datasets", "ninapro"]
|
||||
"args": ["--full", "--algos", "darts_space_constant_random_archs", "--datasets", "cifar100"]
|
||||
},
|
||||
{
|
||||
"name": "Proxynas-Darts-Space-Full",
|
||||
|
@ -376,6 +376,14 @@
|
|||
"console": "integratedTerminal",
|
||||
"args": ["--algos", "zerocost_natsbench_epochs_space"]
|
||||
},
|
||||
{
|
||||
"name": "ZeroCost-Darts-Space-Constant-Random-Full",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "${cwd}/scripts/main.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": ["--full", "--algos", "zerocost_darts_space_constant_random", "--datasets", "cifar10"]
|
||||
},
|
||||
{
|
||||
"name": "Natsbench-Regular-Eval-Full",
|
||||
"type": "python",
|
||||
|
|
|
@ -0,0 +1,91 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from typing import Optional
|
||||
import importlib
|
||||
import sys
|
||||
import string
|
||||
import os
|
||||
|
||||
from overrides import overrides
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from overrides import overrides, EnforceOverrides
|
||||
|
||||
from archai.common.trainer import Trainer
|
||||
from archai.common.config import Config
|
||||
from archai.common.common import logger
|
||||
from archai.datasets import data
|
||||
from archai.nas.model_desc import ModelDesc
|
||||
from archai.nas.model_desc_builder import ModelDescBuilder
|
||||
from archai.nas import nas_utils
|
||||
from archai.common import ml_utils, utils
|
||||
from archai.common.metrics import EpochMetrics, Metrics
|
||||
from archai.nas.model import Model
|
||||
from archai.common.checkpoint import CheckPoint
|
||||
from archai.nas.evaluater import EvalResult, Evaluater
|
||||
from archai.algos.proxynas.freeze_trainer import FreezeTrainer
|
||||
from archai.algos.random_sample_darts_space.constant_darts_space_sampler import ConstantDartsSpaceSampler
|
||||
from archai.algos.random_sample_darts_space.random_model_desc_builder import RandomModelDescBuilder
|
||||
|
||||
from .zero_cost_trainer import ZeroCostTrainer
|
||||
|
||||
class ZeroCostDartsSpaceConstantRandomEvaluator(Evaluater):
|
||||
@overrides
|
||||
def create_model(self, conf_eval:Config, model_desc_builder:RandomModelDescBuilder,
|
||||
final_desc_filename=None, full_desc_filename=None)->nn.Module:
|
||||
|
||||
assert model_desc_builder is not None, 'DartsSpaceEvaluater requires model_desc_builder'
|
||||
assert final_desc_filename is None, 'DartsSpaceEvaluater creates its own model desc based on arch index'
|
||||
assert type(model_desc_builder) == RandomModelDescBuilder, 'DartsSpaceEvaluater requires RandomModelDescBuilder'
|
||||
|
||||
# region conf vars
|
||||
# if not explicitly passed in then get from conf
|
||||
if not final_desc_filename:
|
||||
final_desc_filename = conf_eval['final_desc_filename']
|
||||
full_desc_filename = conf_eval['full_desc_filename']
|
||||
conf_model_desc = conf_eval['model_desc']
|
||||
arch_index = conf_eval['dartsspace']['arch_index']
|
||||
# endregion
|
||||
|
||||
assert arch_index >= 0
|
||||
|
||||
# get random seed from constant sampler
|
||||
# to get deterministic model creation
|
||||
self.constant_sampler = ConstantDartsSpaceSampler()
|
||||
random_seed_for_model_construction = self.constant_sampler.get_archid(arch_index)
|
||||
|
||||
# we don't load template model desc file from disk
|
||||
# as we are creating model based on arch_index
|
||||
model_desc = model_desc_builder.build(conf_model_desc,
|
||||
seed=random_seed_for_model_construction)
|
||||
|
||||
# save desc for reference
|
||||
model_desc.save(full_desc_filename)
|
||||
|
||||
model = self.model_from_desc(model_desc)
|
||||
|
||||
logger.info({'model_factory':False,
|
||||
'cells_len':len(model.desc.cell_descs()),
|
||||
'init_node_ch': conf_model_desc['model_stems']['init_node_ch'],
|
||||
'n_cells': conf_model_desc['n_cells'],
|
||||
'n_reductions': conf_model_desc['n_reductions'],
|
||||
'n_nodes': conf_model_desc['cell']['n_nodes']})
|
||||
|
||||
return model
|
||||
|
||||
@overrides
|
||||
def train_model(self, conf_train:Config, model:nn.Module,
|
||||
checkpoint:Optional[CheckPoint])->Metrics:
|
||||
conf_loader = conf_train['loader']
|
||||
conf_train = conf_train['trainer']
|
||||
|
||||
# get data
|
||||
data_loaders = self.get_data(conf_loader)
|
||||
|
||||
trainer = ZeroCostTrainer(conf_train, model, checkpoint)
|
||||
train_metrics = trainer.fit(data_loaders, self.num_classes)
|
||||
return train_metrics
|
||||
|
|
@ -0,0 +1,58 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
|
||||
from typing import Optional, Type
|
||||
from copy import deepcopy
|
||||
import os
|
||||
|
||||
from overrides import overrides
|
||||
|
||||
from archai.common.config import Config
|
||||
from archai.nas import nas_utils
|
||||
from archai.common import utils
|
||||
from archai.nas.exp_runner import ExperimentRunner
|
||||
from archai.nas.arch_trainer import ArchTrainer, TArchTrainer
|
||||
from archai.nas.model_desc_builder import ModelDescBuilder
|
||||
from archai.nas.evaluater import EvalResult
|
||||
from archai.common.common import get_expdir, logger
|
||||
from archai.algos.proxynas.freeze_manual_searcher import ManualFreezeSearcher
|
||||
from archai.algos.zero_cost_measures.zero_cost_darts_space_constant_random_evaluator import ZeroCostDartsSpaceConstantRandomEvaluator
|
||||
|
||||
from nats_bench import create
|
||||
|
||||
class ZeroCostDartsSpaceConstantRandomExperimentRunner(ExperimentRunner):
|
||||
"""Runs zero cost on architectures from DARTS search space
|
||||
which are randomly sampled in a reproducible way"""
|
||||
|
||||
@overrides
|
||||
def model_desc_builder(self)->Optional[ModelDescBuilder]:
|
||||
return None
|
||||
|
||||
@overrides
|
||||
def trainer_class(self)->TArchTrainer:
|
||||
return None # no search trainer
|
||||
|
||||
@overrides
|
||||
def searcher(self)->ManualFreezeSearcher:
|
||||
return ManualFreezeSearcher() # no searcher basically
|
||||
|
||||
@overrides
|
||||
def copy_search_to_eval(self)->None:
|
||||
pass
|
||||
|
||||
@overrides
|
||||
def run_eval(self, conf_eval:Config)->EvalResult:
|
||||
# without training architecture evaluation score
|
||||
# ---------------------------------------
|
||||
logger.pushd('zerocost_evaluate')
|
||||
zerocost_evaler = ZeroCostDartsSpaceConstantRandomEvaluator()
|
||||
conf_eval_zerocost = deepcopy(conf_eval)
|
||||
|
||||
if conf_eval_zerocost['checkpoint'] is not None:
|
||||
conf_eval_zerocost['checkpoint']['filename'] = '$expdir/zerocost_checkpoint.pth'
|
||||
|
||||
zerocost_eval_result = zerocost_evaler.evaluate(conf_eval_zerocost, model_desc_builder=self.model_desc_builder())
|
||||
logger.popd()
|
||||
|
||||
return zerocost_eval_result
|
|
@ -0,0 +1,23 @@
|
|||
__include__: 'darts.yaml' # just use darts defaults
|
||||
|
||||
nas:
|
||||
search:
|
||||
model_desc:
|
||||
num_edges_to_sample: 2 # number of edges each node will take input from
|
||||
|
||||
eval:
|
||||
dartsspace:
|
||||
arch_index: 66
|
||||
model_desc:
|
||||
aux_weight: False # AuxTower class assumes specific input size hence breaks with many datasets.
|
||||
num_edges_to_sample: 2
|
||||
n_cells: 8
|
||||
loader:
|
||||
aug: ''
|
||||
cutout: -1 # cutout length, use cutout augmentation when > 0
|
||||
val_ratio: 0.0
|
||||
train_batch: 96
|
||||
trainer:
|
||||
use_val: False
|
||||
plotsdir: ''
|
||||
epochs: 100
|
|
@ -29,6 +29,7 @@ from archai.algos.proxynas.freezeaddon_nasbench101_experiment_runner import Free
|
|||
from archai.algos.zero_cost_measures.zero_cost_natsbench_experiment_runner import ZeroCostNatsbenchExperimentRunner
|
||||
from archai.algos.zero_cost_measures.zero_cost_natsbench_conditional_experiment_runner import ZeroCostConditionalNatsbenchExperimentRunner
|
||||
from archai.algos.zero_cost_measures.zero_cost_natsbench_epochs_experiment_runner import ZeroCostNatsbenchEpochsExperimentRunner
|
||||
from archai.algos.zero_cost_measures.zero_cost_darts_space_constant_random_experiment_runner import ZeroCostDartsSpaceConstantRandomExperimentRunner
|
||||
from archai.algos.random_sample_natsbench.random_natsbench_tss_far_exp_runner import RandomNatsbenchTssFarExpRunner
|
||||
from archai.algos.random_sample_natsbench.random_natsbench_tss_far_post_exp_runner import RandomNatsbenchTssFarPostExpRunner
|
||||
from archai.algos.random_sample_natsbench.random_natsbench_tss_reg_exp_runner import RandomNatsbenchTssRegExpRunner
|
||||
|
@ -39,6 +40,7 @@ from archai.algos.local_search_natsbench.local_search_natsbench_tss_fear_exp_run
|
|||
from archai.algos.local_search_natsbench.local_search_natsbench_tss_reg_exp_runner import LocalSearchNatsbenchTSSRegExpRunner
|
||||
from archai.algos.local_search_darts.local_search_darts_reg_exp_runner import LocalSearchDartsRegExpRunner
|
||||
|
||||
|
||||
def main():
|
||||
runner_types:Dict[str, Type[ExperimentRunner]] = {
|
||||
'darts': DartsExperimentRunner,
|
||||
|
@ -59,6 +61,7 @@ def main():
|
|||
'zerocost_natsbench_space': ZeroCostNatsbenchExperimentRunner,
|
||||
'zerocost_conditional_natsbench_space': ZeroCostConditionalNatsbenchExperimentRunner,
|
||||
'zerocost_natsbench_epochs_space': ZeroCostNatsbenchEpochsExperimentRunner,
|
||||
'zerocost_darts_space_constant_random': ZeroCostDartsSpaceConstantRandomExperimentRunner,
|
||||
'natsbench_regular_eval': NatsbenchRegularExperimentRunner,
|
||||
'natsbench_sss_regular_eval': NatsbenchSSSRegularExperimentRunner,
|
||||
'nb101_regular_eval': Nb101RegularExperimentRunner,
|
||||
|
|
Загрузка…
Ссылка в новой задаче