зеркало из https://github.com/microsoft/archai.git
Added functionality to bypass freezing completely for stage 2 of natsbench tss training.
This commit is contained in:
Родитель
fb1bef8af6
Коммит
c3bdb2125a
|
@ -254,7 +254,7 @@
|
|||
"program": "${cwd}/scripts/main.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": ["--full", "--algos", "proxynas_natsbench_sss_space", "--datasets", "cifar10"],
|
||||
"justMyCode": false
|
||||
"justMyCode": true
|
||||
},
|
||||
{
|
||||
"name": "Proxynas-Phased-FT-Natsbench-Space-Full",
|
||||
|
@ -727,7 +727,7 @@
|
|||
"request": "launch",
|
||||
"program": "${cwd}/scripts/reports/fear_analysis/analysis_freeze_natsbench_sss.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": ["--results-dir", "F:\\archaiphilly\\phillytools\\nb_sss_c4_ft_fb256_ftlr0.1_fte10_ct256_ftt0.6",
|
||||
"args": ["--results-dir", "F:\\archaiphilly\\phillytools\\nb_sss_r0.2_ft_fb256_ftlr0.1_fte10_ct256_ftt0.6",
|
||||
"--out-dir", "F:\\archai_experiment_reports"]
|
||||
},
|
||||
{
|
||||
|
|
|
@ -38,26 +38,13 @@ class FreezeTrainer(ArchTrainer, EnforceOverrides):
|
|||
super().pre_fit(data_loaders)
|
||||
|
||||
# freeze everything other than the last layer
|
||||
self._freeze_but_last_layer()
|
||||
if not self.conf_train['bypass_freeze']:
|
||||
logger.info('no freezing!')
|
||||
self._freeze_but_last_layer()
|
||||
|
||||
|
||||
def _freeze_but_last_layer(self) -> None:
|
||||
|
||||
# # Freezing via module names
|
||||
# for module in self.model.modules():
|
||||
# module.requires_grad = False
|
||||
|
||||
# # Unfreeze only some
|
||||
# for name, module in self.model.named_modules():
|
||||
# for identifier in self.conf_train['identifiers_to_unfreeze']:
|
||||
# if identifier in name:
|
||||
# print('we are hitting this')
|
||||
# module.requires_grad = True
|
||||
|
||||
# for name, module in self.model.named_modules():
|
||||
# if module.requires_grad:
|
||||
# logger.info(f'{name} requires grad')
|
||||
|
||||
# Do it via parameters
|
||||
for param in self.model.parameters():
|
||||
param.requires_grad = False
|
||||
|
|
|
@ -19,11 +19,7 @@ nas:
|
|||
naswotrain:
|
||||
train_batch: 256 # batch size for computing trainingless score
|
||||
freeze_loader:
|
||||
<<<<<<< HEAD
|
||||
train_batch: 1024 # batch size for freeze training. 2048 works reliably on V100 with cell13 onwards unfrozen
|
||||
=======
|
||||
train_batch: 256 # batch size for freeze training. 2048 works reliably on V100 with cell13 onwards unfrozen
|
||||
>>>>>>> ea99f933 (Sanity checked zero cost on synthetic cifar10.)
|
||||
trainer:
|
||||
plotsdir: ''
|
||||
use_val: False
|
||||
|
@ -56,6 +52,7 @@ nas:
|
|||
|
||||
freeze_trainer:
|
||||
plotsdir: ''
|
||||
bypass_freeze: True # if True, nothing is frozen and identifiers_to_unfreeze has no effect.
|
||||
identifiers_to_unfreeze: ['classifier', 'lastact', 'cells.16', 'cells.15', 'cells.14', 'cells.13'] # last few layer names in natsbench: lastact, lastact.0, lastact.1: BN-Relu, global_pooling: global avg. pooling (doesn't get exposed as a named param though), classifier: linear layer
|
||||
apex:
|
||||
_copy: '/common/apex'
|
||||
|
@ -65,7 +62,7 @@ nas:
|
|||
l1_alphas: 0.0 # weight to be applied to sum(abs(alphas)) to loss term
|
||||
logger_freq: 1000 # after every N updates dump loss and other metrics in logger
|
||||
title: 'eval_train'
|
||||
epochs: 10
|
||||
epochs: 5
|
||||
batch_chunks: 1 # split batch into these many chunks and accumulate gradients so we can support GPUs with lower RAM
|
||||
lossfn:
|
||||
type: 'CrossEntropyLoss'
|
||||
|
|
|
@ -52,7 +52,7 @@ nas:
|
|||
plotsdir: ''
|
||||
use_ratio: True
|
||||
layer_basenames: ['classifier', 'lastact', 'cells.4', 'cells.3', 'cells.2', 'cells.1'] # used when use_ratio is True. Important: must be listed in order from output layer to input layer.
|
||||
desired_ratio_unfreeze: 0.1 # used when use_ratio is True
|
||||
desired_ratio_unfreeze: 1.0 # used when use_ratio is True
|
||||
identifiers_to_unfreeze: ['classifier', 'lastact', 'cells.4', 'cells.3', 'cells.2', 'cells.1'] # used when use_ratio is False
|
||||
apex:
|
||||
_copy: '/common/apex'
|
||||
|
|
|
@ -4,6 +4,12 @@ exp_folder: 'F:\\archai_experiment_reports'
|
|||
natsbench_sss_cifar10:
|
||||
freezetrain:
|
||||
nb_sss_c4_ft_fb256_ftlr0.1_fte10_ct256_ftt0.6: 'nb_sss_c4_ft_fb256_ftlr0.1_fte10_ct256_ftt0.6'
|
||||
nb_sss_c3_ft_fb256_ftlr0.1_fte10_ct256_ftt0.6: 'nb_sss_c3_ft_fb256_ftlr0.1_fte10_ct256_ftt0.6'
|
||||
nb_sss_c2_ft_fb256_ftlr0.1_fte10_ct256_ftt0.6: 'nb_sss_c2_ft_fb256_ftlr0.1_fte10_ct256_ftt0.6'
|
||||
nb_sss_c1_ft_fb256_ftlr0.1_fte10_ct256_ftt0.6: 'nb_sss_c1_ft_fb256_ftlr0.1_fte10_ct256_ftt0.6'
|
||||
|
||||
nb_sss_r0.1_ft_fb256_ftlr0.1_fte10_ct256_ftt0.6: 'nb_sss_r0.1_ft_fb256_ftlr0.1_fte10_ct256_ftt0.6'
|
||||
nb_sss_r0.2_ft_fb256_ftlr0.1_fte10_ct256_ftt0.6: 'nb_sss_r0.2_ft_fb256_ftlr0.1_fte10_ct256_ftt0.6'
|
||||
|
||||
shortreg:
|
||||
nb_sss_reg_b256_e5: 'nb_sss_reg_b256_e5'
|
||||
|
|
Загрузка…
Ссылка в новой задаче