зеркало из https://github.com/microsoft/archai.git
More progress on local search with DARTS.
This commit is contained in:
Родитель
7169c9d36f
Коммит
9f34cfa8ea
|
@ -572,6 +572,14 @@
|
|||
"console": "integratedTerminal",
|
||||
"args": ["--full", "--algos", "local_search_natsbench_tss_reg", "--datasets", "cifar10", "--no-eval"]
|
||||
},
|
||||
{
|
||||
"name": "Local Search DARTS Reg",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "${cwd}/scripts/main.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": ["--full", "--algos", "local_search_darts_reg", "--datasets", "cifar10", "--no-eval"]
|
||||
},
|
||||
{
|
||||
"name": "Analysis Aggregate",
|
||||
"type": "python",
|
||||
|
|
|
@ -8,6 +8,7 @@ from archai.nas.model import Model
|
|||
from archai.nas.arch_meta import ArchWithMetaData
|
||||
from archai.nas.discrete_search_space import DiscreteSearchSpace
|
||||
from archai.nas.model_desc import CellDesc, ModelDesc, CellType
|
||||
from archai.common.common import get_conf
|
||||
|
||||
|
||||
class DiscreteSearchSpaceDARTS(DiscreteSearchSpace):
|
||||
|
@ -82,6 +83,9 @@ class DiscreteSearchSpaceDARTS(DiscreteSearchSpace):
|
|||
conf_model_desc:Config,
|
||||
seed:Optional[int]=None)->ArchWithMetaData:
|
||||
''' Uniform random sample an architecture '''
|
||||
config = get_conf()
|
||||
|
||||
|
||||
model_desc = self.random_model_desc_builder.build(conf_model_desc, seed=seed)
|
||||
model = Model(model_desc, affine=True)
|
||||
meta_data = {'archid': self.arch_counter}
|
||||
|
|
|
@ -0,0 +1,15 @@
|
|||
__include__: 'darts.yaml' # just use darts defaults
|
||||
|
||||
|
||||
nas:
|
||||
search:
|
||||
max_num_models: 300
|
||||
model_desc:
|
||||
num_edges_to_sample: 2 # number of edges each node will take input from for random model construction
|
||||
|
||||
eval:
|
||||
model_desc:
|
||||
n_cells: 8
|
||||
trainer:
|
||||
epochs: 1
|
||||
|
|
@ -36,7 +36,7 @@ from archai.algos.random_darts.random_dartsspace_far_exp_runner import RandomDar
|
|||
from archai.algos.local_search_natsbench.local_natsbench_tss_far_exp_runner import LocalNatsbenchTssFarExpRunner
|
||||
from archai.algos.local_search_natsbench.local_search_natsbench_tss_fear_exp_runner import LocalSearchNatsbenchTSSFearExpRunner
|
||||
from archai.algos.local_search_natsbench.local_search_natsbench_tss_reg_exp_runner import LocalSearchNatsbenchTSSRegExpRunner
|
||||
|
||||
from archai.algos.local_search_darts.local_search_darts_reg_exp_runner import LocalSearchDartsRegExpRunner
|
||||
|
||||
def main():
|
||||
runner_types:Dict[str, Type[ExperimentRunner]] = {
|
||||
|
@ -69,7 +69,8 @@ def main():
|
|||
'random_dartsspace_far': RandomDartsSpaceFarExpRunner,
|
||||
'local_natsbench_tss_far': LocalNatsbenchTssFarExpRunner,
|
||||
'local_search_natsbench_tss_reg': LocalSearchNatsbenchTSSRegExpRunner,
|
||||
'local_search_natsbench_tss_fear': LocalSearchNatsbenchTSSFearExpRunner
|
||||
'local_search_natsbench_tss_fear': LocalSearchNatsbenchTSSFearExpRunner,
|
||||
'local_search_darts_reg': LocalSearchDartsRegExpRunner
|
||||
}
|
||||
|
||||
parser = argparse.ArgumentParser(description='NAS E2E Runs')
|
||||
|
|
Загрузка…
Ссылка в новой задаче