Mirror of https://github.com/microsoft/archai.git
Freeze training on proxynas search space is working properly.
Parent: 2dd7cb63db
Commit: 27e25f672d
@@ -55,4 +55,8 @@ class TinyNetwork(nn.Module):
        out = out.view(out.size(0), -1)
        logits = self.classifier(out)

        return out, logits
        # archai trainer class expects output to be
        # logits, aux_logits
        # WARNING: what does this break?
        #return out, logits
        return logits, None
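The return-value change exists because the archai trainer unpacks the model output as (logits, aux_logits). A minimal sketch of that contract, using a hypothetical TinyHead stand-in rather than the actual NATS-bench TinyNetwork:

import torch
import torch.nn as nn

class TinyHead(nn.Module):
    """Hypothetical stand-in for the tail of the NATS-bench TinyNetwork."""
    def __init__(self, feat_dim: int = 64, num_classes: int = 10):
        super().__init__()
        self.classifier = nn.Linear(feat_dim, num_classes)

    def forward(self, feats: torch.Tensor):
        out = feats.view(feats.size(0), -1)
        logits = self.classifier(out)
        # Trainer contract: (logits, aux_logits); no auxiliary head here.
        return logits, None

model = TinyHead()
logits, aux_logits = model(torch.randn(4, 64))   # trainer-style unpacking
loss = nn.functional.cross_entropy(logits, torch.randint(0, 10, (4,)))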
@@ -42,34 +42,31 @@ class FreezeTrainer(ArchTrainer, EnforceOverrides):

    def _freeze_but_last_layer(self) -> None:

        # Freezing via module names
        for module in self.model.modules():
            module.requires_grad = False
        # # Freezing via module names
        # for module in self.model.modules():
        #     module.requires_grad = False

        # Unfreeze only some
        for name, module in self.model.named_modules():
            for identifier in self.conf_train['identifiers_to_unfreeze']:
                if identifier in name:
                    module.requires_grad = True

        for name, module in self.model.named_modules():
            if module.requires_grad:
                logger.info(f'{name} requires grad')

        # NOTE: freezing via named_parameters() doesn't expose all parameters? Check with Shital.
        # for name, param in self.model.named_parameters():
        #     param.requires_grad = False

        # for name, param in self.model.named_parameters():
        #     # TODO: Make the layer names to be updated a config value
        #     # 'fc' for resnet18
        #     # 'logits_op._op' for darts search space
        #     for identifier in self.conf_train['proxynas']['identifiers_to_unfreeze']:
        # # Unfreeze only some
        # for name, module in self.model.named_modules():
        #     for identifier in self.conf_train['identifiers_to_unfreeze']:
        #         if identifier in name:
        #             param.requires_grad = True
        #             print('we are hitting this')
        #             module.requires_grad = True

        # for name, param in self.model.named_parameters():
        #     if param.requires_grad:
        # for name, module in self.model.named_modules():
        #     if module.requires_grad:
        #         logger.info(f'{name} requires grad')

        # Do it via parameters
        # NOTE: freezing via named_parameters() doesn't expose all parameters? Check with Shital.
        for param in self.model.parameters():
            param.requires_grad = False

        for name, param in self.model.named_parameters():
            for identifier in self.conf_train['identifiers_to_unfreeze']:
                if identifier in name:
                    param.requires_grad = True

        for name, param in self.model.named_parameters():
            if param.requires_grad:
                logger.info(f'{name} requires grad')
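For background on why the freezing moved from modules to parameters: requires_grad lives on torch.nn.Parameter tensors, so module.requires_grad = False merely sets an unrelated Python attribute on the module object and freezes nothing. A self-contained sketch of the freeze-all-then-unfreeze-by-name pattern used above, with a toy Sequential model standing in for the search-space network (module names here are illustrative):

import torch.nn as nn

def freeze_but_identifiers(model: nn.Module, identifiers_to_unfreeze) -> None:
    """Freeze every parameter, then re-enable those whose name contains
    any of the given identifiers (substring match, as in _freeze_but_last_layer)."""
    for param in model.parameters():
        param.requires_grad = False
    for name, param in model.named_parameters():
        if any(ident in name for ident in identifiers_to_unfreeze):
            param.requires_grad = True

# Toy model standing in for a NATS-bench cell network.
model = nn.Sequential()
model.add_module('stem', nn.Conv2d(3, 16, 3, padding=1))
model.add_module('lastact', nn.Sequential(nn.BatchNorm2d(16), nn.ReLU()))
model.add_module('global_pooling', nn.AdaptiveAvgPool2d(1))
model.add_module('flatten', nn.Flatten())
model.add_module('classifier', nn.Linear(16, 10))

freeze_but_identifiers(model, ['classifier', 'lastact', 'global_pooling'])
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, 'requires grad')   # lastact.0.* and classifier.* only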
@@ -71,6 +71,10 @@ class Trainer(EnforceOverrides):

        self._metrics = Metrics(self._title, self._apex, logger_freq=self._logger_freq)

        # NOTE: critical that pre_fit is called before creating optimizers
        # as otherwise FreezeTrainer does not work correctly
        self.pre_fit(train_dl, val_dl)

        # create optimizers and schedulers
        self._multi_optim = self.create_multi_optim(len(data_loaders.train_dl))
        # before checkpoint restore, convert to amp
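The ordering comment above matters because an optimizer's parameter groups are fixed when it is constructed: any filtering on requires_grad has to run after pre_fit has frozen the model, otherwise frozen weights end up inside the optimizer's groups. A minimal sketch of the intended order, with build_optimizer as a hypothetical helper rather than archai's actual create_multi_optim:

import torch
import torch.nn as nn

def build_optimizer(model: nn.Module, lr: float = 0.025) -> torch.optim.Optimizer:
    # Hand the optimizer only the parameters that survived freezing.
    trainable = [p for p in model.parameters() if p.requires_grad]
    return torch.optim.SGD(trainable, lr=lr, momentum=0.9, weight_decay=3e-4)

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))

# 1) pre_fit-style step: freeze everything, then unfreeze the last layer.
for p in model.parameters():
    p.requires_grad = False
for p in model[-1].parameters():
    p.requires_grad = True

# 2) Only now build the optimizer, so frozen weights never enter its groups.
optimizer = build_optimizer(model)
assert all(p.requires_grad
           for group in optimizer.param_groups
           for p in group['params'])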
@@ -13,10 +13,11 @@ nas:
    model_desc:
      num_edges_to_sample: 2
    loader:
      train_batch: 96
      naswotrain:
        train_batch: 128 # batch size for computing trainingless score
      freeze_loader:
        train_batch: 1024 # batch size for freeze training
        train_batch: 2048 # batch size for freeze training
    trainer:
      plotsdir: ''
      epochs: 600
@@ -24,7 +25,7 @@ nas:
    freeze_trainer:
      plotsdir: ''
      identifiers_to_unfreeze: ['classifier'] # last few layer names in natsbench: lastact, lastact.0, lastact.1: BN-Relu, global_pooling: global avg. pooling, classifier: linear layer
      identifiers_to_unfreeze: ['classifier', 'lastact', 'global_pooling'] # last few layer names in natsbench: lastact, lastact.0, lastact.1: BN-Relu, global_pooling: global avg. pooling, classifier: linear layer
      apex:
        _copy: '/common/apex'
      aux_weight: 0.0 # very important that this is 0.0 for freeze training
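The aux_weight note ties back to the TinyNetwork change above: the model now returns None in place of aux_logits, so any non-zero auxiliary weight would reference a head that does not exist. A small sketch of the usual auxiliary-loss combination (illustrative only, not archai's exact loss code):

import torch
import torch.nn.functional as F

def combined_loss(logits, aux_logits, targets, aux_weight: float):
    loss = F.cross_entropy(logits, targets)
    if aux_weight > 0.0:
        # Would fail when aux_logits is None, hence aux_weight: 0.0
        # for freeze training on the natsbench search space.
        loss = loss + aux_weight * F.cross_entropy(aux_logits, targets)
    return loss

logits = torch.randn(4, 10)
targets = torch.randint(0, 10, (4,))
loss = combined_loss(logits, aux_logits=None, targets=targets, aux_weight=0.0)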