Added lots of sanity checks but freeze training is still degrading perf!

This commit is contained in:
Debadeepta Dey 2020-12-01 19:22:39 -08:00 committed by Gustavo Rosa
Parent e5bf0312e5
Commit bb80965127
3 changed files with 18 additions and 13 deletions

View file

@@ -31,7 +31,7 @@ class FreezeTrainer(ArchTrainer, EnforceOverrides):
super().__init__(conf_train, model, checkpoint)
# region config vars specific to freeze trainer
self._val_top1_acc = conf_train['val_top1_acc_threshold']
self._val_top1_acc = conf_train['proxynas']['val_top1_acc_threshold']
self._in_freeze_mode = False
# endregion
@@ -55,25 +55,23 @@ class FreezeTrainer(ArchTrainer, EnforceOverrides):
del self._multi_optim
self.conf_optim['lr'] = self.conf_train['proxynas']['freeze_lr']
self.conf_optim['decay'] = self.conf_train['proxynas']['freeze_decay']
self.conf_optim['momentum'] = self.conf_train['proxynas']['freeze_momentum']
self.conf_sched = Config()
self._aux_weight = self.conf_train['proxynas']['aux_weight']
self.model.zero_grad()
self._multi_optim = self.create_multi_optim(len(train_dl))
# before checkpoint restore, convert to amp
self.model = self._apex.to_amp(self.model, self._multi_optim,
batch_size=train_dl.batch_size)
self._in_freeze_mode = True
self._epoch_freeze_started = self._metrics.epochs()
self._max_epochs = self._epoch_freeze_started + self.conf_train['proxynas']['freeze_epochs']
logger.info('-----------Entered freeze training-----------------')
# TODO: Implement precise stopping at epoch
# if self._in_freeze_mode and self.metrics.epochs() > self._max_epochs:
# break
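
For reference, a minimal sketch of the freeze-phase optimizer rebuild, assuming plain PyTorch SGD in place of Archai's MultiOptim/apex wrappers (the function name is hypothetical):

import torch

def rebuild_optimizer_for_freeze(model: torch.nn.Module,
                                 freeze_lr: float = 0.001,
                                 freeze_decay: float = 0.0,
                                 freeze_momentum: float = 0.0) -> torch.optim.Optimizer:
    # clear any gradients accumulated during the warm-up phase
    model.zero_grad()
    # optimize only the parameters that are still trainable (the unfrozen head)
    trainable = [p for p in model.parameters() if p.requires_grad]
    return torch.optim.SGD(trainable, lr=freeze_lr,
                           momentum=freeze_momentum, weight_decay=freeze_decay)
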
def freeze_but_last_layer(self) -> None:
@@ -82,11 +80,16 @@ class FreezeTrainer(ArchTrainer, EnforceOverrides):
# layer has the word 'logits' in the name string
# e.g. logits_op._op.weight, logits_op._op.bias
# e.g. _aux_towers.13.logits_op.weight, _aux_towers.13.logits_op.bias
# TODO: confirm from Shital that this is good!
for name, param in self.model.named_parameters():
param.requires_grad = False
for name, param in self.model.named_parameters():
if 'logits' not in name:
param.requires_grad = False
if 'logits_op._op' in name:
param.requires_grad = True
for name, param in self.model.named_parameters():
if param.requires_grad:
logger.info(f'{name} requires grad')
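
A minimal standalone sketch of the same freezing logic, assuming only that the classifier's parameter names contain 'logits_op._op' as the comments above state:

import torch

def freeze_but_last_layer(model: torch.nn.Module) -> None:
    # freeze everything by default, then re-enable only the final logits op
    for name, param in model.named_parameters():
        param.requires_grad = 'logits_op._op' in name
    # log what remains trainable as a sanity check
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f'{name} requires grad')
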

View file

@@ -91,7 +91,7 @@ class Evaluater(EnforceOverrides):
def _default_module_name(self, dataset_name:str, function_name:str)->str:
"""Select PyTorch pre-defined network to support manual mode"""
module_name = ''
# TODO: below detection code is too week, need to improve, possibly encode image size in yaml and use that instead
# TODO: below detection code is too weak, need to improve, possibly encode image size in yaml and use that instead
if dataset_name.startswith('cifar'):
if function_name.startswith('res'): # support resnext as well
module_name = 'archai.cifar10_models.resnet'
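
A minimal sketch of how a module name chosen this way could be resolved into a model, assuming the selected module exposes a constructor named after function_name (e.g. resnet18); this mirrors rather than reproduces the Evaluater code:

import importlib

def build_model(dataset_name: str, function_name: str):
    # crude name-based detection, as the TODO above notes
    if dataset_name.startswith('cifar') and function_name.startswith('res'):
        module_name = 'archai.cifar10_models.resnet'
    else:
        raise ValueError(f'no default module for {dataset_name}/{function_name}')
    module = importlib.import_module(module_name)
    return getattr(module, function_name)()
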

View file

@@ -11,10 +11,12 @@ nas:
num_edges_to_sample: 2
trainer:
plotsdir: ''
val_top1_acc_threshold: 0.1 # after some accuracy we will shift into training only the last layer
epochs: 600
proxynas:
val_top1_acc_threshold: 0.05 # after some accuracy we will shift into training only the last layer
freeze_epochs: 10
freeze_lr: 0.001
freeze_decay: 0.0
freeze_momentum: 0.0
train_regular: False
aux_weight: 0.0 # disable auxiliary loss part during finetuning
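
A minimal sketch of reading the new nested proxynas keys, assuming the YAML above is loaded with PyYAML and that trainer sits directly under nas ('freeze.yaml' is a hypothetical file name):

import yaml

with open('freeze.yaml') as f:
    conf_train = yaml.safe_load(f)['nas']['trainer']

proxynas = conf_train['proxynas']
threshold = proxynas['val_top1_acc_threshold']  # 0.05: accuracy at which to switch to last-layer training
freeze_epochs = proxynas['freeze_epochs']       # 10 additional epochs after the switch
freeze_lr = proxynas['freeze_lr']               # 0.001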