More progress on local search with DARTS.

This commit is contained in:
Debadeepta Dey 2021-11-03 15:55:05 -07:00 коммит произвёл Gustavo Rosa
Родитель 7169c9d36f
Коммит 9f34cfa8ea
4 изменённых файлов: 30 добавлений и 2 удалений

8
.vscode/launch.json поставляемый
Просмотреть файл

@ -572,6 +572,14 @@
"console": "integratedTerminal",
"args": ["--full", "--algos", "local_search_natsbench_tss_reg", "--datasets", "cifar10", "--no-eval"]
},
{
"name": "Local Search DARTS Reg",
"type": "python",
"request": "launch",
"program": "${cwd}/scripts/main.py",
"console": "integratedTerminal",
"args": ["--full", "--algos", "local_search_darts_reg", "--datasets", "cifar10", "--no-eval"]
},
{
"name": "Analysis Aggregate",
"type": "python",

Просмотреть файл

@ -8,6 +8,7 @@ from archai.nas.model import Model
from archai.nas.arch_meta import ArchWithMetaData
from archai.nas.discrete_search_space import DiscreteSearchSpace
from archai.nas.model_desc import CellDesc, ModelDesc, CellType
from archai.common.common import get_conf
class DiscreteSearchSpaceDARTS(DiscreteSearchSpace):
@ -82,6 +83,9 @@ class DiscreteSearchSpaceDARTS(DiscreteSearchSpace):
conf_model_desc:Config,
seed:Optional[int]=None)->ArchWithMetaData:
''' Uniform random sample an architecture '''
config = get_conf()
model_desc = self.random_model_desc_builder.build(conf_model_desc, seed=seed)
model = Model(model_desc, affine=True)
meta_data = {'archid': self.arch_counter}

Просмотреть файл

@ -0,0 +1,15 @@
__include__: 'darts.yaml' # just use darts defaults
nas:
search:
max_num_models: 300
model_desc:
num_edges_to_sample: 2 # number of edges each node will take input from for random model construction
eval:
model_desc:
n_cells: 8
trainer:
epochs: 1

Просмотреть файл

@ -36,7 +36,7 @@ from archai.algos.random_darts.random_dartsspace_far_exp_runner import RandomDar
from archai.algos.local_search_natsbench.local_natsbench_tss_far_exp_runner import LocalNatsbenchTssFarExpRunner
from archai.algos.local_search_natsbench.local_search_natsbench_tss_fear_exp_runner import LocalSearchNatsbenchTSSFearExpRunner
from archai.algos.local_search_natsbench.local_search_natsbench_tss_reg_exp_runner import LocalSearchNatsbenchTSSRegExpRunner
from archai.algos.local_search_darts.local_search_darts_reg_exp_runner import LocalSearchDartsRegExpRunner
def main():
runner_types:Dict[str, Type[ExperimentRunner]] = {
@ -69,7 +69,8 @@ def main():
'random_dartsspace_far': RandomDartsSpaceFarExpRunner,
'local_natsbench_tss_far': LocalNatsbenchTssFarExpRunner,
'local_search_natsbench_tss_reg': LocalSearchNatsbenchTSSRegExpRunner,
'local_search_natsbench_tss_fear': LocalSearchNatsbenchTSSFearExpRunner
'local_search_natsbench_tss_fear': LocalSearchNatsbenchTSSFearExpRunner,
'local_search_darts_reg': LocalSearchDartsRegExpRunner
}
parser = argparse.ArgumentParser(description='NAS E2E Runs')