зеркало из https://github.com/microsoft/archai.git
Changed the proxynas natsbench space to optionally use either the train or the validation top-1 accuracy as the termination threshold.
This commit is contained in:
Родитель
d7c4b382e6
Коммит
407c02871c
|
@ -228,7 +228,7 @@
|
||||||
"request": "launch",
|
"request": "launch",
|
||||||
"program": "${cwd}/scripts/main.py",
|
"program": "${cwd}/scripts/main.py",
|
||||||
"console": "integratedTerminal",
|
"console": "integratedTerminal",
|
||||||
"args": ["--full", "--algos", "proxynas_natsbench_space", "--datasets", "synthetic_cifar10"]
|
"args": ["--full", "--algos", "proxynas_natsbench_space", "--datasets", "ImageNet16-120"]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "Proxynas-Natsbench-Space-Toy",
|
"name": "Proxynas-Natsbench-Space-Toy",
|
||||||
|
|
|
@ -34,17 +34,21 @@ class ConditionalTrainer(ArchTrainer, EnforceOverrides):
|
||||||
super().__init__(conf_train, model, checkpoint)
|
super().__init__(conf_train, model, checkpoint)
|
||||||
|
|
||||||
# region config vars specific to freeze trainer
|
# region config vars specific to freeze trainer
|
||||||
self._train_top1_acc_threshold = conf_train['train_top1_acc_threshold']
|
self._top1_acc_threshold = conf_train['top1_acc_threshold']
|
||||||
|
self._use_val = conf_train['use_val']
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def _should_terminate(self):
|
def _should_terminate(self):
|
||||||
# if current validation accuracy is above threshold
|
# if the best top-1 accuracy so far (validation when use_val is set, otherwise train) has reached the threshold
|
||||||
# terminate training
|
# terminate training
|
||||||
best_train_top1_avg = self._metrics.best_train_top1()
|
if self._use_val:
|
||||||
|
best_top1_avg = self._metrics.best_val_top1()
|
||||||
|
else:
|
||||||
|
best_top1_avg = self._metrics.best_train_top1()
|
||||||
|
|
||||||
if best_train_top1_avg >= self._train_top1_acc_threshold:
|
if best_top1_avg >= self._top1_acc_threshold:
|
||||||
logger.info(f'terminating at {best_train_top1_avg}')
|
logger.info(f'terminating at {best_top1_avg}')
|
||||||
logger.info('----------terminating regular training---------')
|
logger.info('----------terminating regular training---------')
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -13,6 +13,7 @@ nas:
|
||||||
model_desc:
|
model_desc:
|
||||||
num_edges_to_sample: 2
|
num_edges_to_sample: 2
|
||||||
loader:
|
loader:
|
||||||
|
val_ratio: 0.2
|
||||||
train_batch: 256
|
train_batch: 256
|
||||||
aug: '' # in natsbench paper they use random flip and crop, which are part of the regular transforms
|
aug: '' # in natsbench paper they use random flip and crop, which are part of the regular transforms
|
||||||
naswotrain:
|
naswotrain:
|
||||||
|
@ -21,7 +22,8 @@ nas:
|
||||||
train_batch: 1024 # batch size for freeze training. 2048 works reliably on V100 with cell13 onwards unfrozen
|
train_batch: 1024 # batch size for freeze training. 2048 works reliably on V100 with cell13 onwards unfrozen
|
||||||
trainer:
|
trainer:
|
||||||
plotsdir: ''
|
plotsdir: ''
|
||||||
train_top1_acc_threshold: 0.15 # after some accuracy we will shift into training only the last 'n' layers
|
use_val: True
|
||||||
|
top1_acc_threshold: 0.2 # after some accuracy we will shift into training only the last 'n' layers
|
||||||
apex:
|
apex:
|
||||||
_copy: '/common/apex'
|
_copy: '/common/apex'
|
||||||
aux_weight: '_copy: /nas/eval/model_desc/aux_weight'
|
aux_weight: '_copy: /nas/eval/model_desc/aux_weight'
|
||||||
|
|
Загрузка…
Ссылка в новой задаче