зеркало из https://github.com/microsoft/archai.git
Added classes for zerocost on darcyflow.
This commit is contained in:
Родитель
2d51c30390
Коммит
37c619cfe0
|
@ -392,6 +392,14 @@
|
||||||
"console": "integratedTerminal",
|
"console": "integratedTerminal",
|
||||||
"args": ["--full", "--algos", "zerocost_darts_space_constant_random", "--datasets", "cifar10"]
|
"args": ["--full", "--algos", "zerocost_darts_space_constant_random", "--datasets", "cifar10"]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"name": "ZeroCost-Darts-Space-Constant-Random-Darcyflow-Full",
|
||||||
|
"type": "python",
|
||||||
|
"request": "launch",
|
||||||
|
"program": "${cwd}/scripts/main.py",
|
||||||
|
"console": "integratedTerminal",
|
||||||
|
"args": ["--full", "--algos", "zerocost_darts_space_constant_random_darcyflow", "--datasets", "darcyflow"]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"name": "Natsbench-Regular-Eval-Full",
|
"name": "Natsbench-Regular-Eval-Full",
|
||||||
"type": "python",
|
"type": "python",
|
||||||
|
|
|
@ -0,0 +1,59 @@
|
||||||
|
# Copyright (c) Microsoft Corporation.
|
||||||
|
# Licensed under the MIT license.
|
||||||
|
|
||||||
|
|
||||||
|
from typing import Optional, Type
|
||||||
|
from copy import deepcopy
|
||||||
|
import os
|
||||||
|
|
||||||
|
from overrides import overrides
|
||||||
|
|
||||||
|
from archai.common.config import Config
|
||||||
|
from archai.nas import nas_utils
|
||||||
|
from archai.common import utils
|
||||||
|
from archai.nas.exp_runner import ExperimentRunner
|
||||||
|
from archai.nas.arch_trainer import ArchTrainer, TArchTrainer
|
||||||
|
from archai.nas.model_desc_builder import ModelDescBuilder
|
||||||
|
from archai.nas.evaluater import EvalResult
|
||||||
|
from archai.common.common import get_expdir, logger
|
||||||
|
from archai.algos.proxynas.freeze_manual_searcher import ManualFreezeSearcher
|
||||||
|
from archai.algos.zero_cost_measures.zero_cost_darts_space_constant_random_evaluator import ZeroCostDartsSpaceConstantRandomEvaluator
|
||||||
|
from archai.algos.random_sample_darts_space.random_model_desc_builder import RandomModelDescBuilder
|
||||||
|
|
||||||
|
from nats_bench import create
|
||||||
|
|
||||||
|
class ZeroCostDartsSpaceConstantRandomDarcyFlowExpRunner(ExperimentRunner):
|
||||||
|
"""Runs zero cost on architectures from DARTS search space
|
||||||
|
which are randomly sampled in a reproducible way"""
|
||||||
|
|
||||||
|
@overrides
|
||||||
|
def model_desc_builder(self)->RandomModelDescBuilder:
|
||||||
|
return RandomModelDescBuilder()
|
||||||
|
|
||||||
|
@overrides
|
||||||
|
def trainer_class(self)->TArchTrainer:
|
||||||
|
return None # no search trainer
|
||||||
|
|
||||||
|
@overrides
|
||||||
|
def searcher(self)->ManualFreezeSearcher:
|
||||||
|
return ManualFreezeSearcher() # no searcher basically
|
||||||
|
|
||||||
|
@overrides
|
||||||
|
def copy_search_to_eval(self)->None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@overrides
|
||||||
|
def run_eval(self, conf_eval:Config)->EvalResult:
|
||||||
|
# without training architecture evaluation score
|
||||||
|
# ---------------------------------------
|
||||||
|
logger.pushd('zerocost_evaluate')
|
||||||
|
zerocost_evaler = ZeroCostDartsSpaceConstantRandomEvaluator()
|
||||||
|
conf_eval_zerocost = deepcopy(conf_eval)
|
||||||
|
|
||||||
|
if conf_eval_zerocost['checkpoint'] is not None:
|
||||||
|
conf_eval_zerocost['checkpoint']['filename'] = '$expdir/zerocost_checkpoint.pth'
|
||||||
|
|
||||||
|
zerocost_eval_result = zerocost_evaler.evaluate(conf_eval_zerocost, model_desc_builder=self.model_desc_builder())
|
||||||
|
logger.popd()
|
||||||
|
|
||||||
|
return zerocost_eval_result
|
|
@ -99,8 +99,7 @@ class DarcyflowTrainer(ArchTrainer, EnforceOverrides):
|
||||||
|
|
||||||
# darcyflow specific line
|
# darcyflow specific line
|
||||||
logits_c = logits_c.squeeze()
|
logits_c = logits_c.squeeze()
|
||||||
# WARNING, DEBUG: Making code run through for now
|
# decode
|
||||||
# this is missing all the y's decoding
|
|
||||||
yc_decoded = self.y_normalizer.decode(yc)
|
yc_decoded = self.y_normalizer.decode(yc)
|
||||||
logits_decoded = self.y_normalizer.decode(logits_c)
|
logits_decoded = self.y_normalizer.decode(logits_c)
|
||||||
|
|
||||||
|
@ -112,7 +111,6 @@ class DarcyflowTrainer(ArchTrainer, EnforceOverrides):
|
||||||
loss_sum += loss_c.item() * len(logits_c)
|
loss_sum += loss_c.item() * len(logits_c)
|
||||||
loss_count += len(logits_c)
|
loss_count += len(logits_c)
|
||||||
logits_chunks.append(logits_c.detach().cpu())
|
logits_chunks.append(logits_c.detach().cpu())
|
||||||
# logger.info(f"Loss {loss_c/loss_count}")
|
|
||||||
|
|
||||||
# TODO: original darts clips alphas as well but pt.darts doesn't
|
# TODO: original darts clips alphas as well but pt.darts doesn't
|
||||||
self._apex.clip_grad(self._grad_clip, self.model, self._multi_optim)
|
self._apex.clip_grad(self._grad_clip, self.model, self._multi_optim)
|
||||||
|
|
|
@ -0,0 +1,35 @@
|
||||||
|
__include__: 'darts.yaml' # just use darts defaults
|
||||||
|
|
||||||
|
nas:
|
||||||
|
search:
|
||||||
|
model_desc:
|
||||||
|
num_edges_to_sample: 2 # number of edges each node will take input from
|
||||||
|
|
||||||
|
eval:
|
||||||
|
dartsspace:
|
||||||
|
arch_index: 66
|
||||||
|
model_desc:
|
||||||
|
aux_weight: False # AuxTower class assumes specific input size hence breaks with many datasets.
|
||||||
|
num_edges_to_sample: 2
|
||||||
|
n_cells: 8
|
||||||
|
n_reductions: 0 # for dense task we want output at the same resolution as input
|
||||||
|
model_post_op: 'identity' # no pooling at the end as we want dense output
|
||||||
|
logits_op: 'proj_channels_no_bn' # reduce channels to 1 and keep same resolution as input. we don't really have 'logits' here as we are doing dense classification.
|
||||||
|
loader:
|
||||||
|
aug: ''
|
||||||
|
cutout: -1 # cutout length, use cutout augmentation when > 0
|
||||||
|
val_ratio: 0.0
|
||||||
|
train_batch: 4
|
||||||
|
use_sampler: False # distributed stratified sampler is not compatible with dense tasks
|
||||||
|
trainer:
|
||||||
|
drop_path_prob: 0.2
|
||||||
|
use_val: False
|
||||||
|
plotsdir: ''
|
||||||
|
epochs: 150
|
||||||
|
lossfn:
|
||||||
|
type: 'L2Loss'
|
||||||
|
lr_schedule:
|
||||||
|
min_lr: 0.0
|
||||||
|
validation:
|
||||||
|
lossfn:
|
||||||
|
type: 'L2Loss'
|
|
@ -31,6 +31,7 @@ from archai.algos.zero_cost_measures.zero_cost_natsbench_experiment_runner impor
|
||||||
from archai.algos.zero_cost_measures.zero_cost_natsbench_conditional_experiment_runner import ZeroCostConditionalNatsbenchExperimentRunner
|
from archai.algos.zero_cost_measures.zero_cost_natsbench_conditional_experiment_runner import ZeroCostConditionalNatsbenchExperimentRunner
|
||||||
from archai.algos.zero_cost_measures.zero_cost_natsbench_epochs_experiment_runner import ZeroCostNatsbenchEpochsExperimentRunner
|
from archai.algos.zero_cost_measures.zero_cost_natsbench_epochs_experiment_runner import ZeroCostNatsbenchEpochsExperimentRunner
|
||||||
from archai.algos.zero_cost_measures.zero_cost_darts_space_constant_random_experiment_runner import ZeroCostDartsSpaceConstantRandomExperimentRunner
|
from archai.algos.zero_cost_measures.zero_cost_darts_space_constant_random_experiment_runner import ZeroCostDartsSpaceConstantRandomExperimentRunner
|
||||||
|
from archai.algos.zero_cost_measures.zero_cost_darts_space_constant_random_darcyflow_exprunner import ZeroCostDartsSpaceConstantRandomDarcyFlowExpRunner
|
||||||
from archai.algos.random_sample_natsbench.random_natsbench_tss_far_exp_runner import RandomNatsbenchTssFarExpRunner
|
from archai.algos.random_sample_natsbench.random_natsbench_tss_far_exp_runner import RandomNatsbenchTssFarExpRunner
|
||||||
from archai.algos.random_sample_natsbench.random_natsbench_tss_far_post_exp_runner import RandomNatsbenchTssFarPostExpRunner
|
from archai.algos.random_sample_natsbench.random_natsbench_tss_far_post_exp_runner import RandomNatsbenchTssFarPostExpRunner
|
||||||
from archai.algos.random_sample_natsbench.random_natsbench_tss_reg_exp_runner import RandomNatsbenchTssRegExpRunner
|
from archai.algos.random_sample_natsbench.random_natsbench_tss_reg_exp_runner import RandomNatsbenchTssRegExpRunner
|
||||||
|
@ -64,6 +65,7 @@ def main():
|
||||||
'zerocost_conditional_natsbench_space': ZeroCostConditionalNatsbenchExperimentRunner,
|
'zerocost_conditional_natsbench_space': ZeroCostConditionalNatsbenchExperimentRunner,
|
||||||
'zerocost_natsbench_epochs_space': ZeroCostNatsbenchEpochsExperimentRunner,
|
'zerocost_natsbench_epochs_space': ZeroCostNatsbenchEpochsExperimentRunner,
|
||||||
'zerocost_darts_space_constant_random': ZeroCostDartsSpaceConstantRandomExperimentRunner,
|
'zerocost_darts_space_constant_random': ZeroCostDartsSpaceConstantRandomExperimentRunner,
|
||||||
|
'zerocost_darts_space_constant_random_darcyflow': ZeroCostDartsSpaceConstantRandomDarcyFlowExpRunner,
|
||||||
'natsbench_regular_eval': NatsbenchRegularExperimentRunner,
|
'natsbench_regular_eval': NatsbenchRegularExperimentRunner,
|
||||||
'natsbench_sss_regular_eval': NatsbenchSSSRegularExperimentRunner,
|
'natsbench_sss_regular_eval': NatsbenchSSSRegularExperimentRunner,
|
||||||
'nb101_regular_eval': Nb101RegularExperimentRunner,
|
'nb101_regular_eval': Nb101RegularExperimentRunner,
|
||||||
|
|
Загрузка…
Ссылка в новой задаче