Changed proxynas natsbench space to optionally use train or val error thresholds.

This commit is contained in:
Debadeepta Dey 2021-03-25 10:25:23 -07:00 коммит произвёл Gustavo Rosa
Родитель d7c4b382e6
Коммит 407c02871c
3 изменённых файлов: 12 добавлений и 6 удалений

2
.vscode/launch.json поставляемый
Просмотреть файл

@ -228,7 +228,7 @@
"request": "launch", "request": "launch",
"program": "${cwd}/scripts/main.py", "program": "${cwd}/scripts/main.py",
"console": "integratedTerminal", "console": "integratedTerminal",
"args": ["--full", "--algos", "proxynas_natsbench_space", "--datasets", "synthetic_cifar10"] "args": ["--full", "--algos", "proxynas_natsbench_space", "--datasets", "ImageNet16-120"]
}, },
{ {
"name": "Proxynas-Natsbench-Space-Toy", "name": "Proxynas-Natsbench-Space-Toy",

Просмотреть файл

@ -34,17 +34,21 @@ class ConditionalTrainer(ArchTrainer, EnforceOverrides):
super().__init__(conf_train, model, checkpoint) super().__init__(conf_train, model, checkpoint)
# region config vars specific to freeze trainer # region config vars specific to freeze trainer
self._train_top1_acc_threshold = conf_train['train_top1_acc_threshold'] self._top1_acc_threshold = conf_train['top1_acc_threshold']
self._use_val = conf_train['use_val']
# endregion # endregion
@overrides @overrides
def _should_terminate(self): def _should_terminate(self):
# if current validation accuracy is above threshold # if current validation accuracy is above threshold
# terminate training # terminate training
best_train_top1_avg = self._metrics.best_train_top1() if self._use_val:
best_top1_avg = self._metrics.best_val_top1()
else:
best_top1_avg = self._metrics.best_train_top1()
if best_train_top1_avg >= self._train_top1_acc_threshold: if best_top1_avg >= self._top1_acc_threshold:
logger.info(f'terminating at {best_train_top1_avg}') logger.info(f'terminating at {best_top1_avg}')
logger.info('----------terminating regular training---------') logger.info('----------terminating regular training---------')
return True return True
else: else:

Просмотреть файл

@ -13,6 +13,7 @@ nas:
model_desc: model_desc:
num_edges_to_sample: 2 num_edges_to_sample: 2
loader: loader:
val_ratio: 0.2
train_batch: 256 train_batch: 256
aug: '' # in natsbench paper they use random flip and crop, which are part of the regular transforms aug: '' # in natsbench paper they use random flip and crop, which are part of the regular transforms
naswotrain: naswotrain:
@ -21,7 +22,8 @@ nas:
train_batch: 1024 # batch size for freeze training. 2048 works reliably on V100 with cell13 onwards unfrozen train_batch: 1024 # batch size for freeze training. 2048 works reliably on V100 with cell13 onwards unfrozen
trainer: trainer:
plotsdir: '' plotsdir: ''
train_top1_acc_threshold: 0.15 # after some accuracy we will shift into training only the last 'n' layers use_val: True
top1_acc_threshold: 0.2 # after some accuracy we will shift into training only the last 'n' layers
apex: apex:
_copy: '/common/apex' _copy: '/common/apex'
aux_weight: '_copy: /nas/eval/model_desc/aux_weight' aux_weight: '_copy: /nas/eval/model_desc/aux_weight'