зеркало из https://github.com/microsoft/archai.git
Refactored code a bit for better organization.
This commit is contained in:
Родитель
b9f45f158f
Коммит
1620f0eafc
|
@ -214,13 +214,21 @@
|
|||
"console": "integratedTerminal",
|
||||
"args": ["--algos", "random"]
|
||||
},
|
||||
{
|
||||
"name": "Darts Space Constant Random Archs",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "${cwd}/scripts/main.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": ["--full", "--algos", "darts_space_constant_random_archs", "--datasets", "sphericalcifar100"]
|
||||
},
|
||||
{
|
||||
"name": "Proxynas-Darts-Space-Full",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "${cwd}/scripts/main.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": ["--full", "--algos", "proxynas_darts_space", "--datasets", "sphericalcifar100"]
|
||||
"args": ["--full", "--algos", "proxynas_darts_space", "--datasets", "cifar10"]
|
||||
},
|
||||
{
|
||||
"name": "Proxynas-Darts-Space-Toy",
|
||||
|
|
|
@ -11,7 +11,7 @@ from archai.nas.model_desc import ConvMacroParams, CellDesc, CellType, OpDesc, \
|
|||
from archai.nas.model_desc_builder import ModelDescBuilder
|
||||
from archai.nas.operations import MultiOp, Op
|
||||
from archai.common.config import Config
|
||||
from archai.algos.random.random_model_desc_builder import RandomModelDescBuilder
|
||||
from archai.algos.random_sample_darts_space.random_model_desc_builder import RandomModelDescBuilder
|
||||
|
||||
from .petridish_op import PetridishOp, TempIdentityOp
|
||||
|
||||
|
|
|
@ -29,8 +29,8 @@ from archai.common.checkpoint import CheckPoint
|
|||
from archai.nas.evaluater import Evaluater
|
||||
from archai.algos.proxynas.freeze_trainer import FreezeTrainer
|
||||
from archai.algos.proxynas.conditional_trainer import ConditionalTrainer
|
||||
from archai.algos.proxynas.constant_darts_space_sampler import ConstantDartsSpaceSampler
|
||||
from archai.algos.random.random_model_desc_builder import RandomModelDescBuilder
|
||||
from archai.algos.random_sample_darts_space.constant_darts_space_sampler import ConstantDartsSpaceSampler
|
||||
from archai.algos.random_sample_darts_space.random_model_desc_builder import RandomModelDescBuilder
|
||||
|
||||
class FreezeDartsSpaceEvaluater(Evaluater):
|
||||
@overrides
|
||||
|
|
|
@ -15,9 +15,9 @@ from archai.nas.evaluater import Evaluater, EvalResult
|
|||
|
||||
from archai.common.common import get_expdir, logger
|
||||
|
||||
from archai.algos.random.random_model_desc_builder import RandomModelDescBuilder
|
||||
from archai.algos.random_sample_darts_space.random_model_desc_builder import RandomModelDescBuilder
|
||||
from archai.algos.proxynas.freeze_manual_searcher import ManualFreezeSearcher
|
||||
from .darts_space_evaluater import DartsSpaceEvaluater
|
||||
from ..random_sample_darts_space.darts_space_evaluater import DartsSpaceEvaluater
|
||||
from .freeze_darts_space_evaluater import FreezeDartsSpaceEvaluater
|
||||
|
||||
class FreezeDartsSpaceExperimentRunner(ExperimentRunner):
|
||||
|
|
|
@ -1,9 +1,5 @@
|
|||
import random
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class ConstantDartsSpaceSampler():
|
||||
''' Always returns the same set of random seeds to be used for
|
||||
reproducible sampling of architectures from the DARTS search space'''
|
|
@ -0,0 +1,52 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from archai.nas.evaluater import EvalResult
|
||||
from typing import Type
|
||||
from copy import deepcopy
|
||||
|
||||
from overrides import overrides
|
||||
|
||||
from archai.common.config import Config
|
||||
from archai.nas import nas_utils
|
||||
from archai.nas.exp_runner import ExperimentRunner
|
||||
from archai.nas.arch_trainer import ArchTrainer, TArchTrainer
|
||||
from archai.nas.evaluater import Evaluater, EvalResult
|
||||
|
||||
from archai.common.common import get_expdir, logger
|
||||
|
||||
from archai.algos.random_sample_darts_space.random_model_desc_builder import RandomModelDescBuilder
|
||||
from archai.algos.proxynas.freeze_manual_searcher import ManualFreezeSearcher
|
||||
from .darts_space_evaluater import DartsSpaceEvaluater
|
||||
|
||||
class DartsSpaceConstantRandomArchsExperimentRunner(ExperimentRunner):
|
||||
''' Samples a reproducible random architecture from DARTS search space and freeze trains it '''
|
||||
|
||||
@overrides
|
||||
def model_desc_builder(self)->RandomModelDescBuilder:
|
||||
return RandomModelDescBuilder()
|
||||
|
||||
@overrides
|
||||
def trainer_class(self)->TArchTrainer:
|
||||
return None
|
||||
|
||||
@overrides
|
||||
def searcher(self)->ManualFreezeSearcher:
|
||||
return ManualFreezeSearcher() # no searcher basically
|
||||
|
||||
@overrides
|
||||
def copy_search_to_eval(self)->None:
|
||||
pass
|
||||
|
||||
@overrides
|
||||
def run_eval(self, conf_eval:Config)->EvalResult:
|
||||
# regular evaluation of the architecture
|
||||
# this is expensive
|
||||
# --------------------------------------
|
||||
logger.pushd('regular_evaluate')
|
||||
evaler = DartsSpaceEvaluater()
|
||||
conf_eval_reg = deepcopy(conf_eval)
|
||||
reg_eval_result = evaler.evaluate(conf_eval_reg, model_desc_builder=self.model_desc_builder())
|
||||
logger.popd()
|
||||
|
||||
return reg_eval_result
|
|
@ -30,8 +30,8 @@ from archai.common.checkpoint import CheckPoint
|
|||
from archai.nas.evaluater import Evaluater
|
||||
from archai.algos.proxynas.freeze_trainer import FreezeTrainer
|
||||
from archai.algos.proxynas.conditional_trainer import ConditionalTrainer
|
||||
from archai.algos.proxynas.constant_darts_space_sampler import ConstantDartsSpaceSampler
|
||||
from archai.algos.random.random_model_desc_builder import RandomModelDescBuilder
|
||||
from archai.algos.random_sample_darts_space.constant_darts_space_sampler import ConstantDartsSpaceSampler
|
||||
from archai.algos.random_sample_darts_space.random_model_desc_builder import RandomModelDescBuilder
|
||||
|
||||
class DartsSpaceEvaluater(Evaluater):
|
||||
@overrides
|
||||
|
@ -64,7 +64,8 @@ class DartsSpaceEvaluater(Evaluater):
|
|||
seed=random_seed_for_model_construction)
|
||||
|
||||
# convert from ModelDesc to Genotype for use with nasbench301
|
||||
# NOTE: this is just showing how to potentially connect with 301.
|
||||
# NOTE: this is just showing how to potentially connect with 301.
|
||||
# This is for future use.
|
||||
genotype = create_nb301_genotype_from_desc(model_desc)
|
||||
print(genotype)
|
||||
|
|
@ -12,7 +12,7 @@ from archai.nas.searcher import Searcher, SearchResult
|
|||
from archai.nas.finalizers import Finalizers
|
||||
from archai.nas.random_finalizers import RandomFinalizers
|
||||
from archai.nas.model_desc_builder import ModelDescBuilder
|
||||
from archai.algos.random_natsbench.random_natsbench_tss_far_searcher import RandomNatsbenchTssFarSearcher
|
||||
from archai.algos.random_sample_natsbench.random_natsbench_tss_far_searcher import RandomNatsbenchTssFarSearcher
|
||||
|
||||
|
||||
class RandomNatsbenchTssFarExpRunner(ExperimentRunner):
|
|
@ -12,7 +12,7 @@ from archai.nas.searcher import Searcher, SearchResult
|
|||
from archai.nas.finalizers import Finalizers
|
||||
from archai.nas.random_finalizers import RandomFinalizers
|
||||
from archai.nas.model_desc_builder import ModelDescBuilder
|
||||
from archai.algos.random_natsbench.random_natsbench_tss_far_post_searcher import RandomNatsbenchTssFarPostSearcher
|
||||
from archai.algos.random_sample_natsbench.random_natsbench_tss_far_post_searcher import RandomNatsbenchTssFarPostSearcher
|
||||
|
||||
|
||||
class RandomNatsbenchTssFarPostExpRunner(ExperimentRunner):
|
|
@ -12,7 +12,7 @@ from archai.nas.searcher import Searcher, SearchResult
|
|||
from archai.nas.finalizers import Finalizers
|
||||
from archai.nas.random_finalizers import RandomFinalizers
|
||||
from archai.nas.model_desc_builder import ModelDescBuilder
|
||||
from archai.algos.random_natsbench.random_natsbench_tss_reg_searcher import RandomNatsbenchTssRegSearcher
|
||||
from archai.algos.random_sample_natsbench.random_natsbench_tss_reg_searcher import RandomNatsbenchTssRegSearcher
|
||||
|
||||
|
||||
class RandomNatsbenchTssRegExpRunner(ExperimentRunner):
|
|
@ -12,8 +12,8 @@ from archai.nas.searcher import Searcher, SearchResult
|
|||
from archai.nas.finalizers import Finalizers
|
||||
from archai.nas.random_finalizers import RandomFinalizers
|
||||
from archai.nas.model_desc_builder import ModelDescBuilder
|
||||
from archai.algos.random.random_model_desc_builder import RandomModelDescBuilder
|
||||
from archai.algos.random_darts.random_dartsspace_far_searcher import RandomDartsSpaceFarSearcher
|
||||
from archai.algos.random_sample_darts_space.random_model_desc_builder import RandomModelDescBuilder
|
||||
from archai.algos.random_search_darts_space.random_dartsspace_far_searcher import RandomDartsSpaceFarSearcher
|
||||
|
||||
class RandomDartsSpaceFarExpRunner(ExperimentRunner):
|
||||
''' Runs random search using FastArchRank on DARTS search space '''
|
|
@ -18,7 +18,7 @@ from archai.nas.finalizers import Finalizers
|
|||
from archai.nas.model import Model
|
||||
from archai.algos.proxynas.conditional_trainer import ConditionalTrainer
|
||||
from archai.algos.proxynas.freeze_trainer import FreezeTrainer
|
||||
from archai.algos.random.random_model_desc_builder import RandomModelDescBuilder
|
||||
from archai.algos.random_sample_darts_space.random_model_desc_builder import RandomModelDescBuilder
|
||||
|
||||
|
||||
class RandomDartsSpaceFarSearcher(Searcher):
|
|
@ -12,8 +12,8 @@ from archai.nas.searcher import Searcher, SearchResult
|
|||
from archai.nas.finalizers import Finalizers
|
||||
from archai.nas.random_finalizers import RandomFinalizers
|
||||
from archai.nas.model_desc_builder import ModelDescBuilder
|
||||
from archai.algos.random.random_model_desc_builder import RandomModelDescBuilder
|
||||
from archai.algos.random_darts.random_dartsspace_reg_searcher import RandomDartsSpaceRegSearcher
|
||||
from archai.algos.random_sample_darts_space.random_model_desc_builder import RandomModelDescBuilder
|
||||
from archai.algos.random_search_darts_space.random_dartsspace_reg_searcher import RandomDartsSpaceRegSearcher
|
||||
|
||||
class RandomDartsSpaceRegExpRunner(ExperimentRunner):
|
||||
''' Runs random search using FastArchRank on DARTS search space '''
|
|
@ -16,7 +16,7 @@ from archai.common.trainer import Trainer
|
|||
from archai.common import utils
|
||||
from archai.nas.finalizers import Finalizers
|
||||
from archai.nas.model import Model
|
||||
from archai.algos.random.random_model_desc_builder import RandomModelDescBuilder
|
||||
from archai.algos.random_sample_darts_space.random_model_desc_builder import RandomModelDescBuilder
|
||||
|
||||
|
||||
class RandomDartsSpaceRegSearcher(Searcher):
|
|
@ -21,7 +21,7 @@ from archai.common import utils
|
|||
|
||||
|
||||
def load_spherical_data(path, val_split=0.0):
|
||||
''' Copied and modified from
|
||||
''' Modified from
|
||||
https://github.com/rtu715/NAS-Bench-360/blob/main/backbone/utils_pt.py '''
|
||||
|
||||
data_file = os.path.join(path, 's2_cifar100.gz')
|
||||
|
@ -30,11 +30,11 @@ def load_spherical_data(path, val_split=0.0):
|
|||
|
||||
train_data = torch.from_numpy(
|
||||
dataset["train"]["images"][:, None, :, :].astype(np.float32))
|
||||
train_data = torch.squeeze(train_data)
|
||||
train_labels = torch.from_numpy(
|
||||
dataset["train"]["labels"].astype(np.int64))
|
||||
|
||||
all_train_dataset = data_utils.TensorDataset(train_data, train_labels)
|
||||
print(len(all_train_dataset))
|
||||
if val_split == 0.0:
|
||||
val_dataset = None
|
||||
train_dataset = all_train_dataset
|
||||
|
@ -46,9 +46,9 @@ def load_spherical_data(path, val_split=0.0):
|
|||
val_dataset = data_utils.TensorDataset(train_data[ntrain:], train_labels[ntrain:])
|
||||
val_dataset.targets = train_labels[ntrain:] # compatibility with stratified sampler
|
||||
|
||||
|
||||
test_data = torch.from_numpy(
|
||||
dataset["test"]["images"][:, None, :, :].astype(np.float32))
|
||||
test_data = torch.squeeze(test_data)
|
||||
test_labels = torch.from_numpy(
|
||||
dataset["test"]["labels"].astype(np.int64))
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ import random
|
|||
from typing import List, Optional
|
||||
from overrides.overrides import overrides
|
||||
|
||||
from archai.algos.random.random_model_desc_builder import RandomModelDescBuilder
|
||||
from archai.algos.random_sample_darts_space.random_model_desc_builder import RandomModelDescBuilder
|
||||
from archai.common.config import Config
|
||||
from archai.nas.model import Model
|
||||
from archai.nas.arch_meta import ArchWithMetaData
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
__include__: 'darts.yaml' # just use darts defaults
|
||||
|
||||
nas:
|
||||
search:
|
||||
model_desc:
|
||||
num_edges_to_sample: 2 # number of edges each node will take input from
|
||||
|
||||
eval:
|
||||
dartsspace:
|
||||
arch_index: 66
|
||||
model_desc:
|
||||
num_edges_to_sample: 2
|
||||
n_cells: 8
|
||||
loader:
|
||||
val_ratio: 0.0
|
||||
train_batch: 96
|
||||
freeze_loader:
|
||||
train_batch: 96 # batch size for freeze training
|
||||
trainer:
|
||||
use_val: False
|
||||
plotsdir: ''
|
||||
epochs: 100
|
||||
|
||||
|
|
@ -8,12 +8,13 @@ from archai.common import utils
|
|||
from archai.nas.exp_runner import ExperimentRunner
|
||||
from archai.algos.darts.darts_exp_runner import DartsExperimentRunner
|
||||
from archai.algos.petridish.petridish_exp_runner import PetridishExperimentRunner
|
||||
from archai.algos.random.random_exp_runner import RandomExperimentRunner
|
||||
from archai.algos.random_sample_darts_space.random_exp_runner import RandomExperimentRunner
|
||||
from archai.algos.manual.manual_exp_runner import ManualExperimentRunner
|
||||
from archai.algos.xnas.xnas_exp_runner import XnasExperimentRunner
|
||||
from archai.algos.gumbelsoftmax.gs_exp_runner import GsExperimentRunner
|
||||
from archai.algos.divnas.divnas_exp_runner import DivnasExperimentRunner
|
||||
from archai.algos.didarts.didarts_exp_runner import DiDartsExperimentRunner
|
||||
from archai.algos.random_sample_darts_space.darts_space_constant_random_archs_exp_runner import DartsSpaceConstantRandomArchsExperimentRunner
|
||||
from archai.algos.proxynas.freeze_darts_space_experiment_runner import FreezeDartsSpaceExperimentRunner
|
||||
from archai.algos.proxynas.freeze_natsbench_experiment_runner import FreezeNatsbenchExperimentRunner
|
||||
from archai.algos.proxynas.freeze_natsbench_sss_experiment_runner import FreezeNatsbenchSSSExperimentRunner
|
||||
|
@ -28,11 +29,11 @@ from archai.algos.proxynas.freezeaddon_nasbench101_experiment_runner import Free
|
|||
from archai.algos.zero_cost_measures.zero_cost_natsbench_experiment_runner import ZeroCostNatsbenchExperimentRunner
|
||||
from archai.algos.zero_cost_measures.zero_cost_natsbench_conditional_experiment_runner import ZeroCostConditionalNatsbenchExperimentRunner
|
||||
from archai.algos.zero_cost_measures.zero_cost_natsbench_epochs_experiment_runner import ZeroCostNatsbenchEpochsExperimentRunner
|
||||
from archai.algos.random_natsbench.random_natsbench_tss_far_exp_runner import RandomNatsbenchTssFarExpRunner
|
||||
from archai.algos.random_natsbench.random_natsbench_tss_far_post_exp_runner import RandomNatsbenchTssFarPostExpRunner
|
||||
from archai.algos.random_natsbench.random_natsbench_tss_reg_exp_runner import RandomNatsbenchTssRegExpRunner
|
||||
from archai.algos.random_darts.random_dartsspace_reg_exp_runner import RandomDartsSpaceRegExpRunner
|
||||
from archai.algos.random_darts.random_dartsspace_far_exp_runner import RandomDartsSpaceFarExpRunner
|
||||
from archai.algos.random_sample_natsbench.random_natsbench_tss_far_exp_runner import RandomNatsbenchTssFarExpRunner
|
||||
from archai.algos.random_sample_natsbench.random_natsbench_tss_far_post_exp_runner import RandomNatsbenchTssFarPostExpRunner
|
||||
from archai.algos.random_sample_natsbench.random_natsbench_tss_reg_exp_runner import RandomNatsbenchTssRegExpRunner
|
||||
from archai.algos.random_search_darts_space.random_dartsspace_reg_exp_runner import RandomDartsSpaceRegExpRunner
|
||||
from archai.algos.random_search_darts_space.random_dartsspace_far_exp_runner import RandomDartsSpaceFarExpRunner
|
||||
from archai.algos.local_search_natsbench.local_natsbench_tss_far_exp_runner import LocalNatsbenchTssFarExpRunner
|
||||
from archai.algos.local_search_natsbench.local_search_natsbench_tss_fear_exp_runner import LocalSearchNatsbenchTSSFearExpRunner
|
||||
from archai.algos.local_search_natsbench.local_search_natsbench_tss_reg_exp_runner import LocalSearchNatsbenchTSSRegExpRunner
|
||||
|
@ -48,6 +49,7 @@ def main():
|
|||
'gs': GsExperimentRunner,
|
||||
'divnas': DivnasExperimentRunner,
|
||||
'didarts': DiDartsExperimentRunner,
|
||||
'darts_space_constant_random_archs': DartsSpaceConstantRandomArchsExperimentRunner,
|
||||
'proxynas_darts_space': FreezeDartsSpaceExperimentRunner,
|
||||
'proxynas_natsbench_space': FreezeNatsbenchExperimentRunner,
|
||||
'proxynas_natsbench_sss_space': FreezeNatsbenchSSSExperimentRunner,
|
||||
|
@ -77,11 +79,12 @@ def main():
|
|||
parser.add_argument('--algos', type=str, default='''darts,
|
||||
xnas,
|
||||
random,
|
||||
didarts,
|
||||
didarts,
|
||||
petridish,
|
||||
gs,
|
||||
manual,
|
||||
divnas,
|
||||
darts_space_constant_random_archs,
|
||||
proxynas_manual,
|
||||
proxynas_darts_space,
|
||||
proxynas_natsbench_space,
|
||||
|
|
Загрузка…
Ссылка в новой задаче