Mirror of https://github.com/microsoft/archai.git
Added lots of sanity checks but freeze training is still degrading perf!
Parent: e5bf0312e5
Commit: bb80965127
@@ -31,7 +31,7 @@ class FreezeTrainer(ArchTrainer, EnforceOverrides):
         super().__init__(conf_train, model, checkpoint)

         # region config vars specific to freeze trainer
-        self._val_top1_acc = conf_train['val_top1_acc_threshold']
+        self._val_top1_acc = conf_train['proxynas']['val_top1_acc_threshold']
         self._in_freeze_mode = False
         # endregion
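A minimal sketch of how such a threshold can gate the switch into freeze training. The hook name and call site below are hypothetical assumptions, not archai's actual API; only the attribute and method names mirror the diff above.

```python
# Illustrative sketch only: gate freeze training on validation top-1 accuracy.
# 'maybe_enter_freeze_mode' is a hypothetical hook; the attribute names
# (_in_freeze_mode, _val_top1_acc) mirror the diff above.
def maybe_enter_freeze_mode(trainer, val_top1: float) -> None:
    if not trainer._in_freeze_mode and val_top1 >= trainer._val_top1_acc:
        trainer.freeze_but_last_layer()   # freeze all but the final layer(s)
        trainer._in_freeze_mode = True    # from here on, train only the head
```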
@@ -55,25 +55,23 @@ class FreezeTrainer(ArchTrainer, EnforceOverrides):
         del self._multi_optim

         self.conf_optim['lr'] = self.conf_train['proxynas']['freeze_lr']
         self.conf_optim['decay'] = self.conf_train['proxynas']['freeze_decay']
         self.conf_optim['momentum'] = self.conf_train['proxynas']['freeze_momentum']
         self.conf_sched = Config()
         self._aux_weight = self.conf_train['proxynas']['aux_weight']

         self.model.zero_grad()
         self._multi_optim = self.create_multi_optim(len(train_dl))

         # before checkpoint restore, convert to amp
         self.model = self._apex.to_amp(self.model, self._multi_optim,
                                        batch_size=train_dl.batch_size)

         self._in_freeze_mode = True
         self._epoch_freeze_started = self._metrics.epochs()
         self._max_epochs = self._epoch_freeze_started + self.conf_train['proxynas']['freeze_epochs']
         logger.info('-----------Entered freeze training-----------------')

         # TODO: Implement precise stopping at epoch
         # if self._in_freeze_mode and self.metrics.epochs() > self._max_epochs:
         #     break
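The hunk above discards the warm-up optimizer and rebuilds it with the freeze-phase hyperparameters. In bare PyTorch, without archai's multi-optim and apex wrappers, the same step might look like the sketch below; the function name and the choice of SGD are assumptions, not what archai's `create_multi_optim` necessarily does.

```python
import torch

# Hypothetical bare-PyTorch equivalent of the optimizer rebuild above:
# hand only the still-trainable (unfrozen) parameters to a fresh optimizer
# configured with the freeze-phase hyperparameters.
def rebuild_optimizer_for_freeze(model: torch.nn.Module,
                                 freeze_lr: float = 0.001,
                                 freeze_momentum: float = 0.0,
                                 freeze_decay: float = 0.0) -> torch.optim.Optimizer:
    trainable = [p for p in model.parameters() if p.requires_grad]
    return torch.optim.SGD(trainable, lr=freeze_lr,
                           momentum=freeze_momentum,
                           weight_decay=freeze_decay)
```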
@@ -82,11 +80,16 @@ class FreezeTrainer(ArchTrainer, EnforceOverrides):
     def freeze_but_last_layer(self) -> None:
         # layer has the word 'logits' in the name string
         # e.g. logits_op._op.weight, logits_op._op.bias
         # e.g. _aux_towers.13.logits_op.weight, _aux_towers.13.logits_op.bias
         # TODO: confirm from Shital that this is good!

         for name, param in self.model.named_parameters():
             param.requires_grad = False

         for name, param in self.model.named_parameters():
             if not 'logits' in name:
                 param.requires_grad = False
             if 'logits_op._op' in name:
                 param.requires_grad = True

         for name, param in self.model.named_parameters():
             if param.requires_grad:
                 logger.info(f'{name} requires grad')
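The freeze-then-selectively-unfreeze pattern above is plain PyTorch. A self-contained toy version, with an ad-hoc `nn.Sequential` standing in for the NAS network and its final `Linear` playing the role of `logits_op`, behaves like this:

```python
import torch.nn as nn

# Toy stand-in for the NAS model; the final Linear is the "logits" head.
model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(),
    nn.Linear(16, 10),
)

# Freeze everything, then re-enable only the classifier head.
for param in model.parameters():
    param.requires_grad = False
for param in model[-1].parameters():
    param.requires_grad = True

# Sanity check mirroring the logging loop in the diff above.
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
assert trainable == ['4.weight', '4.bias']
```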
@@ -91,7 +91,7 @@ class Evaluater(EnforceOverrides):
     def _default_module_name(self, dataset_name:str, function_name:str)->str:
         """Select PyTorch pre-defined network to support manual mode"""
         module_name = ''
-        # TODO: below detection code is too week, need to improve, possibly encode image size in yaml and use that instead
+        # TODO: below detection code is too weak, need to improve, possibly encode image size in yaml and use that instead
         if dataset_name.startswith('cifar'):
             if function_name.startswith('res'): # support resnext as well
                 module_name = 'archai.cifar10_models.resnet'
@@ -11,10 +11,12 @@ nas:
     num_edges_to_sample: 2
   trainer:
     plotsdir: ''
-    val_top1_acc_threshold: 0.1 # after some accuracy we will shift into training only the last layer
     epochs: 600
     proxynas:
+      val_top1_acc_threshold: 0.05 # after some accuracy we will shift into training only the last layer
+      freeze_epochs: 10
       freeze_lr: 0.001
       freeze_decay: 0.0
       freeze_momentum: 0.0
       train_regular: False
+      aux_weight: 0.0 # disable auxiliary loss part during finetuning
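One way to sanity-check the new nesting; archai reads this file through its own Config class, so plain PyYAML is used here purely for illustration, with the subtree values copied from the diff above.

```python
import yaml

# The trainer subtree as it stands after this commit.
conf_train = yaml.safe_load("""
plotsdir: ''
epochs: 600
proxynas:
  val_top1_acc_threshold: 0.05
  freeze_epochs: 10
  freeze_lr: 0.001
  freeze_decay: 0.0
  freeze_momentum: 0.0
  train_regular: False
  aux_weight: 0.0
""")

# The threshold now lives under 'proxynas', matching the lookup change
# in FreezeTrainer.__init__ above.
assert conf_train['proxynas']['val_top1_acc_threshold'] == 0.05
```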