зеркало из https://github.com/microsoft/archai.git
Changed proxynas natsbench space to optionally use train or val error thresholds.
This commit is contained in:
Родитель
d7c4b382e6
Коммит
407c02871c
|
@ -228,7 +228,7 @@
|
|||
"request": "launch",
|
||||
"program": "${cwd}/scripts/main.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": ["--full", "--algos", "proxynas_natsbench_space", "--datasets", "synthetic_cifar10"]
|
||||
"args": ["--full", "--algos", "proxynas_natsbench_space", "--datasets", "ImageNet16-120"]
|
||||
},
|
||||
{
|
||||
"name": "Proxynas-Natsbench-Space-Toy",
|
||||
|
|
|
@ -34,17 +34,21 @@ class ConditionalTrainer(ArchTrainer, EnforceOverrides):
|
|||
super().__init__(conf_train, model, checkpoint)
|
||||
|
||||
# region config vars specific to freeze trainer
|
||||
self._train_top1_acc_threshold = conf_train['train_top1_acc_threshold']
|
||||
self._top1_acc_threshold = conf_train['top1_acc_threshold']
|
||||
self._use_val = conf_train['use_val']
|
||||
# endregion
|
||||
|
||||
@overrides
|
||||
def _should_terminate(self):
|
||||
# if the best top-1 accuracy so far (validation when use_val is set, otherwise train) has reached the threshold
|
||||
# terminate training
|
||||
best_train_top1_avg = self._metrics.best_train_top1()
|
||||
if self._use_val:
|
||||
best_top1_avg = self._metrics.best_val_top1()
|
||||
else:
|
||||
best_top1_avg = self._metrics.best_train_top1()
|
||||
|
||||
if best_train_top1_avg >= self._train_top1_acc_threshold:
|
||||
logger.info(f'terminating at {best_train_top1_avg}')
|
||||
if best_top1_avg >= self._top1_acc_threshold:
|
||||
logger.info(f'terminating at {best_top1_avg}')
|
||||
logger.info('----------terminating regular training---------')
|
||||
return True
|
||||
else:
|
||||
|
|
|
@ -13,6 +13,7 @@ nas:
|
|||
model_desc:
|
||||
num_edges_to_sample: 2
|
||||
loader:
|
||||
val_ratio: 0.2
|
||||
train_batch: 256
|
||||
aug: '' # in natsbench paper they use random flip and crop, which are part of the regular transforms
|
||||
naswotrain:
|
||||
|
@ -21,7 +22,8 @@ nas:
|
|||
train_batch: 1024 # batch size for freeze training. 2048 works reliably on V100 with cell13 onwards unfrozen
|
||||
trainer:
|
||||
plotsdir: ''
|
||||
train_top1_acc_threshold: 0.15 # after some accuracy we will shift into training only the last 'n' layers
|
||||
use_val: True
|
||||
top1_acc_threshold: 0.2 # once best top-1 accuracy (val if use_val is True, else train) reaches this value we will shift into training only the last 'n' layers
|
||||
apex:
|
||||
_copy: '/common/apex'
|
||||
aux_weight: '_copy: /nas/eval/model_desc/aux_weight'
|
||||
|
|
Загрузка…
Ссылка в новой задаче