Refactored code a bit for better organization.

This commit is contained in:
Debadeepta Dey 2021-12-01 15:40:11 -08:00 коммит произвёл Gustavo Rosa
Родитель b9f45f158f
Коммит 1620f0eafc
26 изменённых файлов: 117 добавлений и 33 удалений

10
.vscode/launch.json поставляемый
Просмотреть файл

@ -214,13 +214,21 @@
"console": "integratedTerminal",
"args": ["--algos", "random"]
},
{
"name": "Darts Space Constant Random Archs",
"type": "python",
"request": "launch",
"program": "${cwd}/scripts/main.py",
"console": "integratedTerminal",
"args": ["--full", "--algos", "darts_space_constant_random_archs", "--datasets", "sphericalcifar100"]
},
{
"name": "Proxynas-Darts-Space-Full",
"type": "python",
"request": "launch",
"program": "${cwd}/scripts/main.py",
"console": "integratedTerminal",
"args": ["--full", "--algos", "proxynas_darts_space", "--datasets", "sphericalcifar100"]
"args": ["--full", "--algos", "proxynas_darts_space", "--datasets", "cifar10"]
},
{
"name": "Proxynas-Darts-Space-Toy",

Просмотреть файл

@ -11,7 +11,7 @@ from archai.nas.model_desc import ConvMacroParams, CellDesc, CellType, OpDesc, \
from archai.nas.model_desc_builder import ModelDescBuilder
from archai.nas.operations import MultiOp, Op
from archai.common.config import Config
from archai.algos.random.random_model_desc_builder import RandomModelDescBuilder
from archai.algos.random_sample_darts_space.random_model_desc_builder import RandomModelDescBuilder
from .petridish_op import PetridishOp, TempIdentityOp

Просмотреть файл

@ -29,8 +29,8 @@ from archai.common.checkpoint import CheckPoint
from archai.nas.evaluater import Evaluater
from archai.algos.proxynas.freeze_trainer import FreezeTrainer
from archai.algos.proxynas.conditional_trainer import ConditionalTrainer
from archai.algos.proxynas.constant_darts_space_sampler import ConstantDartsSpaceSampler
from archai.algos.random.random_model_desc_builder import RandomModelDescBuilder
from archai.algos.random_sample_darts_space.constant_darts_space_sampler import ConstantDartsSpaceSampler
from archai.algos.random_sample_darts_space.random_model_desc_builder import RandomModelDescBuilder
class FreezeDartsSpaceEvaluater(Evaluater):
@overrides

Просмотреть файл

@ -15,9 +15,9 @@ from archai.nas.evaluater import Evaluater, EvalResult
from archai.common.common import get_expdir, logger
from archai.algos.random.random_model_desc_builder import RandomModelDescBuilder
from archai.algos.random_sample_darts_space.random_model_desc_builder import RandomModelDescBuilder
from archai.algos.proxynas.freeze_manual_searcher import ManualFreezeSearcher
from .darts_space_evaluater import DartsSpaceEvaluater
from ..random_sample_darts_space.darts_space_evaluater import DartsSpaceEvaluater
from .freeze_darts_space_evaluater import FreezeDartsSpaceEvaluater
class FreezeDartsSpaceExperimentRunner(ExperimentRunner):

Просмотреть файл

@ -1,9 +1,5 @@
import random
class ConstantDartsSpaceSampler():
''' Always returns the same set of random seeds to be used for
reproducible sampling of architectures from the DARTS search space'''

Просмотреть файл

@ -0,0 +1,52 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from archai.nas.evaluater import EvalResult
from typing import Type
from copy import deepcopy
from overrides import overrides
from archai.common.config import Config
from archai.nas import nas_utils
from archai.nas.exp_runner import ExperimentRunner
from archai.nas.arch_trainer import ArchTrainer, TArchTrainer
from archai.nas.evaluater import Evaluater, EvalResult
from archai.common.common import get_expdir, logger
from archai.algos.random_sample_darts_space.random_model_desc_builder import RandomModelDescBuilder
from archai.algos.proxynas.freeze_manual_searcher import ManualFreezeSearcher
from .darts_space_evaluater import DartsSpaceEvaluater
class DartsSpaceConstantRandomArchsExperimentRunner(ExperimentRunner):
''' Samples a reproducible random architecture from DARTS search space and freeze trains it '''
@overrides
def model_desc_builder(self)->RandomModelDescBuilder:
return RandomModelDescBuilder()
@overrides
def trainer_class(self)->TArchTrainer:
return None
@overrides
def searcher(self)->ManualFreezeSearcher:
return ManualFreezeSearcher() # no searcher basically
@overrides
def copy_search_to_eval(self)->None:
pass
@overrides
def run_eval(self, conf_eval:Config)->EvalResult:
# regular evaluation of the architecture
# this is expensive
# --------------------------------------
logger.pushd('regular_evaluate')
evaler = DartsSpaceEvaluater()
conf_eval_reg = deepcopy(conf_eval)
reg_eval_result = evaler.evaluate(conf_eval_reg, model_desc_builder=self.model_desc_builder())
logger.popd()
return reg_eval_result

Просмотреть файл

@ -30,8 +30,8 @@ from archai.common.checkpoint import CheckPoint
from archai.nas.evaluater import Evaluater
from archai.algos.proxynas.freeze_trainer import FreezeTrainer
from archai.algos.proxynas.conditional_trainer import ConditionalTrainer
from archai.algos.proxynas.constant_darts_space_sampler import ConstantDartsSpaceSampler
from archai.algos.random.random_model_desc_builder import RandomModelDescBuilder
from archai.algos.random_sample_darts_space.constant_darts_space_sampler import ConstantDartsSpaceSampler
from archai.algos.random_sample_darts_space.random_model_desc_builder import RandomModelDescBuilder
class DartsSpaceEvaluater(Evaluater):
@overrides
@ -64,7 +64,8 @@ class DartsSpaceEvaluater(Evaluater):
seed=random_seed_for_model_construction)
# convert from ModelDesc to Genotype for use with nasbench301
# NOTE: this is just showing how to potentially connect with 301.
# NOTE: this is just showing how to potentially connect with 301.
# This is for future use.
genotype = create_nb301_genotype_from_desc(model_desc)
print(genotype)

Просмотреть файл

@ -12,7 +12,7 @@ from archai.nas.searcher import Searcher, SearchResult
from archai.nas.finalizers import Finalizers
from archai.nas.random_finalizers import RandomFinalizers
from archai.nas.model_desc_builder import ModelDescBuilder
from archai.algos.random_natsbench.random_natsbench_tss_far_searcher import RandomNatsbenchTssFarSearcher
from archai.algos.random_sample_natsbench.random_natsbench_tss_far_searcher import RandomNatsbenchTssFarSearcher
class RandomNatsbenchTssFarExpRunner(ExperimentRunner):

Просмотреть файл

@ -12,7 +12,7 @@ from archai.nas.searcher import Searcher, SearchResult
from archai.nas.finalizers import Finalizers
from archai.nas.random_finalizers import RandomFinalizers
from archai.nas.model_desc_builder import ModelDescBuilder
from archai.algos.random_natsbench.random_natsbench_tss_far_post_searcher import RandomNatsbenchTssFarPostSearcher
from archai.algos.random_sample_natsbench.random_natsbench_tss_far_post_searcher import RandomNatsbenchTssFarPostSearcher
class RandomNatsbenchTssFarPostExpRunner(ExperimentRunner):

Просмотреть файл

@ -12,7 +12,7 @@ from archai.nas.searcher import Searcher, SearchResult
from archai.nas.finalizers import Finalizers
from archai.nas.random_finalizers import RandomFinalizers
from archai.nas.model_desc_builder import ModelDescBuilder
from archai.algos.random_natsbench.random_natsbench_tss_reg_searcher import RandomNatsbenchTssRegSearcher
from archai.algos.random_sample_natsbench.random_natsbench_tss_reg_searcher import RandomNatsbenchTssRegSearcher
class RandomNatsbenchTssRegExpRunner(ExperimentRunner):

Просмотреть файл

@ -12,8 +12,8 @@ from archai.nas.searcher import Searcher, SearchResult
from archai.nas.finalizers import Finalizers
from archai.nas.random_finalizers import RandomFinalizers
from archai.nas.model_desc_builder import ModelDescBuilder
from archai.algos.random.random_model_desc_builder import RandomModelDescBuilder
from archai.algos.random_darts.random_dartsspace_far_searcher import RandomDartsSpaceFarSearcher
from archai.algos.random_sample_darts_space.random_model_desc_builder import RandomModelDescBuilder
from archai.algos.random_search_darts_space.random_dartsspace_far_searcher import RandomDartsSpaceFarSearcher
class RandomDartsSpaceFarExpRunner(ExperimentRunner):
''' Runs random search using FastArchRank on DARTS search space '''

Просмотреть файл

@ -18,7 +18,7 @@ from archai.nas.finalizers import Finalizers
from archai.nas.model import Model
from archai.algos.proxynas.conditional_trainer import ConditionalTrainer
from archai.algos.proxynas.freeze_trainer import FreezeTrainer
from archai.algos.random.random_model_desc_builder import RandomModelDescBuilder
from archai.algos.random_sample_darts_space.random_model_desc_builder import RandomModelDescBuilder
class RandomDartsSpaceFarSearcher(Searcher):

Просмотреть файл

@ -12,8 +12,8 @@ from archai.nas.searcher import Searcher, SearchResult
from archai.nas.finalizers import Finalizers
from archai.nas.random_finalizers import RandomFinalizers
from archai.nas.model_desc_builder import ModelDescBuilder
from archai.algos.random.random_model_desc_builder import RandomModelDescBuilder
from archai.algos.random_darts.random_dartsspace_reg_searcher import RandomDartsSpaceRegSearcher
from archai.algos.random_sample_darts_space.random_model_desc_builder import RandomModelDescBuilder
from archai.algos.random_search_darts_space.random_dartsspace_reg_searcher import RandomDartsSpaceRegSearcher
class RandomDartsSpaceRegExpRunner(ExperimentRunner):
''' Runs random search using FastArchRank on DARTS search space '''

Просмотреть файл

@ -16,7 +16,7 @@ from archai.common.trainer import Trainer
from archai.common import utils
from archai.nas.finalizers import Finalizers
from archai.nas.model import Model
from archai.algos.random.random_model_desc_builder import RandomModelDescBuilder
from archai.algos.random_sample_darts_space.random_model_desc_builder import RandomModelDescBuilder
class RandomDartsSpaceRegSearcher(Searcher):

Просмотреть файл

@ -21,7 +21,7 @@ from archai.common import utils
def load_spherical_data(path, val_split=0.0):
''' Copied and modified from
''' Modified from
https://github.com/rtu715/NAS-Bench-360/blob/main/backbone/utils_pt.py '''
data_file = os.path.join(path, 's2_cifar100.gz')
@ -30,11 +30,11 @@ def load_spherical_data(path, val_split=0.0):
train_data = torch.from_numpy(
dataset["train"]["images"][:, None, :, :].astype(np.float32))
train_data = torch.squeeze(train_data)
train_labels = torch.from_numpy(
dataset["train"]["labels"].astype(np.int64))
all_train_dataset = data_utils.TensorDataset(train_data, train_labels)
print(len(all_train_dataset))
if val_split == 0.0:
val_dataset = None
train_dataset = all_train_dataset
@ -46,9 +46,9 @@ def load_spherical_data(path, val_split=0.0):
val_dataset = data_utils.TensorDataset(train_data[ntrain:], train_labels[ntrain:])
val_dataset.targets = train_labels[ntrain:] # compatibility with stratified sampler
test_data = torch.from_numpy(
dataset["test"]["images"][:, None, :, :].astype(np.float32))
test_data = torch.squeeze(test_data)
test_labels = torch.from_numpy(
dataset["test"]["labels"].astype(np.int64))

Просмотреть файл

@ -3,7 +3,7 @@ import random
from typing import List, Optional
from overrides.overrides import overrides
from archai.algos.random.random_model_desc_builder import RandomModelDescBuilder
from archai.algos.random_sample_darts_space.random_model_desc_builder import RandomModelDescBuilder
from archai.common.config import Config
from archai.nas.model import Model
from archai.nas.arch_meta import ArchWithMetaData

Просмотреть файл

@ -0,0 +1,24 @@
__include__: 'darts.yaml' # just use darts defaults
nas:
search:
model_desc:
num_edges_to_sample: 2 # number of edges each node will take input from
eval:
dartsspace:
arch_index: 66
model_desc:
num_edges_to_sample: 2
n_cells: 8
loader:
val_ratio: 0.0
train_batch: 96
freeze_loader:
train_batch: 96 # batch size for freeze training
trainer:
use_val: False
plotsdir: ''
epochs: 100

Просмотреть файл

@ -8,12 +8,13 @@ from archai.common import utils
from archai.nas.exp_runner import ExperimentRunner
from archai.algos.darts.darts_exp_runner import DartsExperimentRunner
from archai.algos.petridish.petridish_exp_runner import PetridishExperimentRunner
from archai.algos.random.random_exp_runner import RandomExperimentRunner
from archai.algos.random_sample_darts_space.random_exp_runner import RandomExperimentRunner
from archai.algos.manual.manual_exp_runner import ManualExperimentRunner
from archai.algos.xnas.xnas_exp_runner import XnasExperimentRunner
from archai.algos.gumbelsoftmax.gs_exp_runner import GsExperimentRunner
from archai.algos.divnas.divnas_exp_runner import DivnasExperimentRunner
from archai.algos.didarts.didarts_exp_runner import DiDartsExperimentRunner
from archai.algos.random_sample_darts_space.darts_space_constant_random_archs_exp_runner import DartsSpaceConstantRandomArchsExperimentRunner
from archai.algos.proxynas.freeze_darts_space_experiment_runner import FreezeDartsSpaceExperimentRunner
from archai.algos.proxynas.freeze_natsbench_experiment_runner import FreezeNatsbenchExperimentRunner
from archai.algos.proxynas.freeze_natsbench_sss_experiment_runner import FreezeNatsbenchSSSExperimentRunner
@ -28,11 +29,11 @@ from archai.algos.proxynas.freezeaddon_nasbench101_experiment_runner import Free
from archai.algos.zero_cost_measures.zero_cost_natsbench_experiment_runner import ZeroCostNatsbenchExperimentRunner
from archai.algos.zero_cost_measures.zero_cost_natsbench_conditional_experiment_runner import ZeroCostConditionalNatsbenchExperimentRunner
from archai.algos.zero_cost_measures.zero_cost_natsbench_epochs_experiment_runner import ZeroCostNatsbenchEpochsExperimentRunner
from archai.algos.random_natsbench.random_natsbench_tss_far_exp_runner import RandomNatsbenchTssFarExpRunner
from archai.algos.random_natsbench.random_natsbench_tss_far_post_exp_runner import RandomNatsbenchTssFarPostExpRunner
from archai.algos.random_natsbench.random_natsbench_tss_reg_exp_runner import RandomNatsbenchTssRegExpRunner
from archai.algos.random_darts.random_dartsspace_reg_exp_runner import RandomDartsSpaceRegExpRunner
from archai.algos.random_darts.random_dartsspace_far_exp_runner import RandomDartsSpaceFarExpRunner
from archai.algos.random_sample_natsbench.random_natsbench_tss_far_exp_runner import RandomNatsbenchTssFarExpRunner
from archai.algos.random_sample_natsbench.random_natsbench_tss_far_post_exp_runner import RandomNatsbenchTssFarPostExpRunner
from archai.algos.random_sample_natsbench.random_natsbench_tss_reg_exp_runner import RandomNatsbenchTssRegExpRunner
from archai.algos.random_search_darts_space.random_dartsspace_reg_exp_runner import RandomDartsSpaceRegExpRunner
from archai.algos.random_search_darts_space.random_dartsspace_far_exp_runner import RandomDartsSpaceFarExpRunner
from archai.algos.local_search_natsbench.local_natsbench_tss_far_exp_runner import LocalNatsbenchTssFarExpRunner
from archai.algos.local_search_natsbench.local_search_natsbench_tss_fear_exp_runner import LocalSearchNatsbenchTSSFearExpRunner
from archai.algos.local_search_natsbench.local_search_natsbench_tss_reg_exp_runner import LocalSearchNatsbenchTSSRegExpRunner
@ -48,6 +49,7 @@ def main():
'gs': GsExperimentRunner,
'divnas': DivnasExperimentRunner,
'didarts': DiDartsExperimentRunner,
'darts_space_constant_random_archs': DartsSpaceConstantRandomArchsExperimentRunner,
'proxynas_darts_space': FreezeDartsSpaceExperimentRunner,
'proxynas_natsbench_space': FreezeNatsbenchExperimentRunner,
'proxynas_natsbench_sss_space': FreezeNatsbenchSSSExperimentRunner,
@ -77,11 +79,12 @@ def main():
parser.add_argument('--algos', type=str, default='''darts,
xnas,
random,
didarts,
didarts,
petridish,
gs,
manual,
divnas,
darts_space_constant_random_archs,
proxynas_manual,
proxynas_darts_space,
proxynas_natsbench_space,