зеркало из https://github.com/microsoft/archai.git
Added jobs for DARTS FEAR experiments.
This commit is contained in:
Родитель
563d0a744f
Коммит
cc1f417171
|
@ -689,7 +689,7 @@
|
|||
"request": "launch",
|
||||
"program": "${cwd}/scripts/reports/fear_analysis/analysis_create_darts_space_benchmark.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": ["--results-dir", "F:\\archaiphilly\\phillytools\\ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6",
|
||||
"args": ["--results-dir", "F:\\archaiphilly\\phillytools\\ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_fixed",
|
||||
"--out-dir", "F:\\archai_experiment_reports"]
|
||||
},
|
||||
{
|
||||
|
|
|
@ -66,9 +66,10 @@ class FreezeDartsSpaceExperimentRunner(ExperimentRunner):
|
|||
conf_eval['checkpoint']['filename'] = '$expdir/freeze_checkpoint.pth'
|
||||
|
||||
logger.pushd('freeze_evaluate')
|
||||
freeze_evaler = FreezeDartsSpaceEvaluater()
|
||||
conf_eval_freeze = deepcopy(conf_eval)
|
||||
freeze_eval_result = freeze_evaler.evaluate(conf_eval_freeze, model_desc_builder=self.model_desc_builder())
|
||||
if conf_eval['trainer']['train_fear']:
|
||||
freeze_evaler = FreezeDartsSpaceEvaluater()
|
||||
conf_eval_freeze = deepcopy(conf_eval)
|
||||
freeze_eval_result = freeze_evaler.evaluate(conf_eval_freeze, model_desc_builder=self.model_desc_builder())
|
||||
logger.popd()
|
||||
|
||||
# NOTE: Not returning freeze eval results to meet signature contract
|
||||
|
|
|
@ -19,12 +19,14 @@ nas:
|
|||
trainer:
|
||||
use_val: False
|
||||
plotsdir: ''
|
||||
epochs: 100
|
||||
epochs: 20
|
||||
top1_acc_threshold: 0.60 # after some accuracy we will shift into training only the last 'n' layers
|
||||
train_regular: True # if False the full regular training of the architecture will be bypassed
|
||||
train_fear: False
|
||||
|
||||
freeze_trainer:
|
||||
plotsdir: ''
|
||||
bypass_freeze: False
|
||||
identifiers_to_unfreeze: ['logits_op._op', 'cells.7', 'cells.6'] # last few layer names in DARTS space are 'logits_op._op': Linear, 'cells.19': prefix for all cell 19 parameters
|
||||
apex:
|
||||
_copy: '/common/apex'
|
||||
|
|
|
@ -70,7 +70,7 @@ def main():
|
|||
# a = parse_a_job(job_dir)
|
||||
|
||||
# parallel parsing of yaml logs
|
||||
num_workers = 9
|
||||
num_workers = 8
|
||||
with Pool(num_workers) as p:
|
||||
a = p.map(parse_a_job, job_dirs)
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче