Added functionality to bypass freezing completely for stage 2 of natsbench tss training.

Debadeepta Dey 2021-08-01 20:08:54 -07:00 committed by Gustavo Rosa
Parent fb1bef8af6
Commit c3bdb2125a
5 changed files with 14 additions and 24 deletions

.vscode/launch.json vendored
View file

@@ -254,7 +254,7 @@
"program": "${cwd}/scripts/main.py",
"console": "integratedTerminal",
"args": ["--full", "--algos", "proxynas_natsbench_sss_space", "--datasets", "cifar10"],
-"justMyCode": false
+"justMyCode": true
},
{
"name": "Proxynas-Phased-FT-Natsbench-Space-Full",
@@ -727,7 +727,7 @@
"request": "launch",
"program": "${cwd}/scripts/reports/fear_analysis/analysis_freeze_natsbench_sss.py",
"console": "integratedTerminal",
-"args": ["--results-dir", "F:\\archaiphilly\\phillytools\\nb_sss_c4_ft_fb256_ftlr0.1_fte10_ct256_ftt0.6",
+"args": ["--results-dir", "F:\\archaiphilly\\phillytools\\nb_sss_r0.2_ft_fb256_ftlr0.1_fte10_ct256_ftt0.6",
"--out-dir", "F:\\archai_experiment_reports"]
},
{

View file

@@ -38,26 +38,13 @@ class FreezeTrainer(ArchTrainer, EnforceOverrides):
        super().pre_fit(data_loaders)
        # freeze everything other than the last layer
-        self._freeze_but_last_layer()
+        if not self.conf_train['bypass_freeze']:
+            self._freeze_but_last_layer()
+        else:
+            logger.info('no freezing!')

    def _freeze_but_last_layer(self) -> None:
-        # # Freezing via module names
-        # for module in self.model.modules():
-        #     module.requires_grad = False
-        # # Unfreeze only some
-        # for name, module in self.model.named_modules():
-        #     for identifier in self.conf_train['identifiers_to_unfreeze']:
-        #         if identifier in name:
-        #             print('we are hitting this')
-        #             module.requires_grad = True
-        # for name, module in self.model.named_modules():
-        #     if module.requires_grad:
-        #         logger.info(f'{name} requires grad')
        # Do it via parameters
        for param in self.model.parameters():
            param.requires_grad = False
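
For context on the trainer change above: when bypass_freeze is false, freeze training disables gradients for every parameter, and the identifiers_to_unfreeze list in the config below is meant to re-enable only the last few layers. A minimal standalone sketch of that behavior, assuming a plain PyTorch module; the helper name and the selective-unfreeze loop are illustrative, not the exact trainer code:

import torch.nn as nn

def freeze_but_last_layer(model: nn.Module, identifiers_to_unfreeze) -> None:
    # Disable gradients everywhere first (mirrors the loop kept in the trainer above).
    for param in model.parameters():
        param.requires_grad = False
    # Re-enable gradients only for parameters whose names contain one of the
    # configured identifiers, e.g. 'classifier', 'lastact', 'cells.16'.
    for name, param in model.named_parameters():
        if any(ident in name for ident in identifiers_to_unfreeze):
            param.requires_grad = True

With bypass_freeze: True the trainer skips this step entirely, so stage 2 fine-tunes all parameters.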

View file

@@ -19,11 +19,7 @@ nas:
naswotrain:
train_batch: 256 # batch size for computing trainingless score
freeze_loader:
-<<<<<<< HEAD
-train_batch: 1024 # batch size for freeze training. 2048 works reliably on V100 with cell13 onwards unfrozen
-=======
train_batch: 256 # batch size for freeze training. 2048 works reliably on V100 with cell13 onwards unfrozen
->>>>>>> ea99f933 (Sanity checked zero cost on synthetic cifar10.)
trainer:
plotsdir: ''
use_val: False
@@ -56,6 +52,7 @@ nas:
freeze_trainer:
plotsdir: ''
+bypass_freeze: True # if true will not freeze anything. identifiers_to_unfreeze has no effect.
identifiers_to_unfreeze: ['classifier', 'lastact', 'cells.16', 'cells.15', 'cells.14', 'cells.13'] # last few layer names in natsbench: lastact, lastact.0, lastact.1: BN-Relu, global_pooling: global avg. pooling (doesn't get exposed as a named param though), classifier: linear layer
apex:
_copy: '/common/apex'
@@ -65,7 +62,7 @@ nas:
l1_alphas: 0.0 # weight to be applied to sum(abs(alphas)) to loss term
logger_freq: 1000 # after every N updates dump loss and other metrics in logger
title: 'eval_train'
-epochs: 10
+epochs: 5
batch_chunks: 1 # split batch into these many chunks and accumulate gradients so we can support GPUs with lower RAM
lossfn:
type: 'CrossEntropyLoss'
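
As a quick illustration of how the new bypass_freeze key might be consumed (a sketch only: the file name and the nesting under nas are assumptions, and it uses plain PyYAML rather than the repo's own config machinery):

import yaml

# Hypothetical config path; the real experiment configs live in the repo's confs/ tree.
with open('freeze_natsbench_sss_space.yaml') as f:
    conf = yaml.safe_load(f)

freeze_conf = conf['nas']['eval']['freeze_trainer']  # assumed nesting
if freeze_conf['bypass_freeze']:
    print('stage 2 trains the whole network; identifiers_to_unfreeze is ignored')
else:
    print('freezing all but:', freeze_conf['identifiers_to_unfreeze'])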

View file

@@ -52,7 +52,7 @@ nas:
plotsdir: ''
use_ratio: True
layer_basenames: ['classifier', 'lastact', 'cells.4', 'cells.3', 'cells.2', 'cells.1'] # used when use_ratio is True. Important that this is from output layer to input layer order.
-desired_ratio_unfreeze: 0.1 # used when use_ratio is True
+desired_ratio_unfreeze: 1.0 # used when use_ratio is True
identifiers_to_unfreeze: ['classifier', 'lastact', 'cells.4', 'cells.3', 'cells.2', 'cells.1'] # used when use_ratio is False
apex:
_copy: '/common/apex'
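
The desired_ratio_unfreeze change (0.1 to 1.0) makes the ratio-based path unfreeze every layer listed in layer_basenames. One plausible reading of how the ratio maps to concrete layers, based only on the config comment (basenames ordered from output toward input); the function below is a sketch, not the repository's implementation:

import math

def pick_layers_to_unfreeze(layer_basenames, desired_ratio):
    # Take a prefix of the list, so layers closest to the output are unfrozen first.
    n = math.ceil(desired_ratio * len(layer_basenames))
    return layer_basenames[:n]

names = ['classifier', 'lastact', 'cells.4', 'cells.3', 'cells.2', 'cells.1']
print(pick_layers_to_unfreeze(names, 0.1))  # ['classifier']
print(pick_layers_to_unfreeze(names, 1.0))  # all six entries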

View file

@@ -4,6 +4,12 @@ exp_folder: 'F:\\archai_experiment_reports'
natsbench_sss_cifar10:
freezetrain:
nb_sss_c4_ft_fb256_ftlr0.1_fte10_ct256_ftt0.6: 'nb_sss_c4_ft_fb256_ftlr0.1_fte10_ct256_ftt0.6'
+nb_sss_c3_ft_fb256_ftlr0.1_fte10_ct256_ftt0.6: 'nb_sss_c3_ft_fb256_ftlr0.1_fte10_ct256_ftt0.6'
+nb_sss_c2_ft_fb256_ftlr0.1_fte10_ct256_ftt0.6: 'nb_sss_c2_ft_fb256_ftlr0.1_fte10_ct256_ftt0.6'
+nb_sss_c1_ft_fb256_ftlr0.1_fte10_ct256_ftt0.6: 'nb_sss_c1_ft_fb256_ftlr0.1_fte10_ct256_ftt0.6'
+nb_sss_r0.1_ft_fb256_ftlr0.1_fte10_ct256_ftt0.6: 'nb_sss_r0.1_ft_fb256_ftlr0.1_fte10_ct256_ftt0.6'
+nb_sss_r0.2_ft_fb256_ftlr0.1_fte10_ct256_ftt0.6: 'nb_sss_r0.2_ft_fb256_ftlr0.1_fte10_ct256_ftt0.6'
shortreg:
nb_sss_reg_b256_e5: 'nb_sss_reg_b256_e5'
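
The new entries above extend the experiment-key to results-folder map that the analysis/report scripts read. A tiny consumption sketch, assuming the nesting shown above; the file name here is a placeholder:

import os
import yaml

with open('experiment_dirs.yaml') as f:  # placeholder name for the mapping file above
    conf = yaml.safe_load(f)

exp_folder = conf['exp_folder']
for key, folder in conf['natsbench_sss_cifar10']['freezetrain'].items():
    print(key, '->', os.path.join(exp_folder, folder))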