зеркало из https://github.com/microsoft/archai.git
Added classes for zerocost on darcyflow.
This commit is contained in:
Родитель
2d51c30390
Коммит
37c619cfe0
|
@ -392,6 +392,14 @@
|
|||
"console": "integratedTerminal",
|
||||
"args": ["--full", "--algos", "zerocost_darts_space_constant_random", "--datasets", "cifar10"]
|
||||
},
|
||||
{
|
||||
"name": "ZeroCost-Darts-Space-Constant-Random-Darcyflow-Full",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "${cwd}/scripts/main.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": ["--full", "--algos", "zerocost_darts_space_constant_random_darcyflow", "--datasets", "darcyflow"]
|
||||
},
|
||||
{
|
||||
"name": "Natsbench-Regular-Eval-Full",
|
||||
"type": "python",
|
||||
|
|
|
@ -0,0 +1,59 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
|
||||
from typing import Optional, Type
|
||||
from copy import deepcopy
|
||||
import os
|
||||
|
||||
from overrides import overrides
|
||||
|
||||
from archai.common.config import Config
|
||||
from archai.nas import nas_utils
|
||||
from archai.common import utils
|
||||
from archai.nas.exp_runner import ExperimentRunner
|
||||
from archai.nas.arch_trainer import ArchTrainer, TArchTrainer
|
||||
from archai.nas.model_desc_builder import ModelDescBuilder
|
||||
from archai.nas.evaluater import EvalResult
|
||||
from archai.common.common import get_expdir, logger
|
||||
from archai.algos.proxynas.freeze_manual_searcher import ManualFreezeSearcher
|
||||
from archai.algos.zero_cost_measures.zero_cost_darts_space_constant_random_evaluator import ZeroCostDartsSpaceConstantRandomEvaluator
|
||||
from archai.algos.random_sample_darts_space.random_model_desc_builder import RandomModelDescBuilder
|
||||
|
||||
from nats_bench import create
|
||||
|
||||
class ZeroCostDartsSpaceConstantRandomDarcyFlowExpRunner(ExperimentRunner):
|
||||
"""Runs zero cost on architectures from DARTS search space
|
||||
which are randomly sampled in a reproducible way"""
|
||||
|
||||
@overrides
|
||||
def model_desc_builder(self)->RandomModelDescBuilder:
|
||||
return RandomModelDescBuilder()
|
||||
|
||||
@overrides
|
||||
def trainer_class(self)->TArchTrainer:
|
||||
return None # no search trainer
|
||||
|
||||
@overrides
|
||||
def searcher(self)->ManualFreezeSearcher:
|
||||
return ManualFreezeSearcher() # no searcher basically
|
||||
|
||||
@overrides
|
||||
def copy_search_to_eval(self)->None:
|
||||
pass
|
||||
|
||||
@overrides
|
||||
def run_eval(self, conf_eval:Config)->EvalResult:
|
||||
# without training architecture evaluation score
|
||||
# ---------------------------------------
|
||||
logger.pushd('zerocost_evaluate')
|
||||
zerocost_evaler = ZeroCostDartsSpaceConstantRandomEvaluator()
|
||||
conf_eval_zerocost = deepcopy(conf_eval)
|
||||
|
||||
if conf_eval_zerocost['checkpoint'] is not None:
|
||||
conf_eval_zerocost['checkpoint']['filename'] = '$expdir/zerocost_checkpoint.pth'
|
||||
|
||||
zerocost_eval_result = zerocost_evaler.evaluate(conf_eval_zerocost, model_desc_builder=self.model_desc_builder())
|
||||
logger.popd()
|
||||
|
||||
return zerocost_eval_result
|
|
@ -99,8 +99,7 @@ class DarcyflowTrainer(ArchTrainer, EnforceOverrides):
|
|||
|
||||
# darcyflow specific line
|
||||
logits_c = logits_c.squeeze()
|
||||
# WARNING, DEBUG: Making code run through for now
|
||||
# this is missing all the y's decoding
|
||||
# decode
|
||||
yc_decoded = self.y_normalizer.decode(yc)
|
||||
logits_decoded = self.y_normalizer.decode(logits_c)
|
||||
|
||||
|
@ -112,7 +111,6 @@ class DarcyflowTrainer(ArchTrainer, EnforceOverrides):
|
|||
loss_sum += loss_c.item() * len(logits_c)
|
||||
loss_count += len(logits_c)
|
||||
logits_chunks.append(logits_c.detach().cpu())
|
||||
# logger.info(f"Loss {loss_c/loss_count}")
|
||||
|
||||
# TODO: original darts clips alphas as well but pt.darts doesn't
|
||||
self._apex.clip_grad(self._grad_clip, self.model, self._multi_optim)
|
||||
|
|
|
@ -0,0 +1,35 @@
|
|||
__include__: 'darts.yaml' # just use darts defaults
|
||||
|
||||
nas:
|
||||
search:
|
||||
model_desc:
|
||||
num_edges_to_sample: 2 # number of edges each node will take input from
|
||||
|
||||
eval:
|
||||
dartsspace:
|
||||
arch_index: 66
|
||||
model_desc:
|
||||
aux_weight: False # AuxTower class assumes specific input size hence breaks with many datasets.
|
||||
num_edges_to_sample: 2
|
||||
n_cells: 8
|
||||
n_reductions: 0 # for dense task we want output at the same resolution as input
|
||||
model_post_op: 'identity' # no pooling at the end as we want dense output
|
||||
logits_op: 'proj_channels_no_bn' # reduce channels to 1 and keep same resolution as input. we don't really have 'logits' here as we are doing dense classification.
|
||||
loader:
|
||||
aug: ''
|
||||
cutout: -1 # cutout length, use cutout augmentation when > 0
|
||||
val_ratio: 0.0
|
||||
train_batch: 4
|
||||
use_sampler: False # distributed stratified sampler is not compatible with dense tasks
|
||||
trainer:
|
||||
drop_path_prob: 0.2
|
||||
use_val: False
|
||||
plotsdir: ''
|
||||
epochs: 150
|
||||
lossfn:
|
||||
type: 'L2Loss'
|
||||
lr_schedule:
|
||||
min_lr: 0.0
|
||||
validation:
|
||||
lossfn:
|
||||
type: 'L2Loss'
|
|
@ -31,6 +31,7 @@ from archai.algos.zero_cost_measures.zero_cost_natsbench_experiment_runner impor
|
|||
from archai.algos.zero_cost_measures.zero_cost_natsbench_conditional_experiment_runner import ZeroCostConditionalNatsbenchExperimentRunner
|
||||
from archai.algos.zero_cost_measures.zero_cost_natsbench_epochs_experiment_runner import ZeroCostNatsbenchEpochsExperimentRunner
|
||||
from archai.algos.zero_cost_measures.zero_cost_darts_space_constant_random_experiment_runner import ZeroCostDartsSpaceConstantRandomExperimentRunner
|
||||
from archai.algos.zero_cost_measures.zero_cost_darts_space_constant_random_darcyflow_exprunner import ZeroCostDartsSpaceConstantRandomDarcyFlowExpRunner
|
||||
from archai.algos.random_sample_natsbench.random_natsbench_tss_far_exp_runner import RandomNatsbenchTssFarExpRunner
|
||||
from archai.algos.random_sample_natsbench.random_natsbench_tss_far_post_exp_runner import RandomNatsbenchTssFarPostExpRunner
|
||||
from archai.algos.random_sample_natsbench.random_natsbench_tss_reg_exp_runner import RandomNatsbenchTssRegExpRunner
|
||||
|
@ -64,6 +65,7 @@ def main():
|
|||
'zerocost_conditional_natsbench_space': ZeroCostConditionalNatsbenchExperimentRunner,
|
||||
'zerocost_natsbench_epochs_space': ZeroCostNatsbenchEpochsExperimentRunner,
|
||||
'zerocost_darts_space_constant_random': ZeroCostDartsSpaceConstantRandomExperimentRunner,
|
||||
'zerocost_darts_space_constant_random_darcyflow': ZeroCostDartsSpaceConstantRandomDarcyFlowExpRunner,
|
||||
'natsbench_regular_eval': NatsbenchRegularExperimentRunner,
|
||||
'natsbench_sss_regular_eval': NatsbenchSSSRegularExperimentRunner,
|
||||
'nb101_regular_eval': Nb101RegularExperimentRunner,
|
||||
|
|
Загрузка…
Ссылка в новой задаче