Getting ready to run FEAR on DARTS space without any freeze.

This commit is contained in:
Debadeepta Dey 2021-10-17 12:01:01 -07:00 коммит произвёл Gustavo Rosa
Родитель 66d56ce796
Коммит 7ebad4ff40
5 изменённых файлов: 31 добавлений и 9 удалений

9
.vscode/launch.json поставляемый
Просмотреть файл

@ -921,6 +921,15 @@
"args": ["--nb301-logs-dir", "C:\\Users\\dedey\\dataroot\\nasbench301\\nasbench301_full_data\\nb_301_v13_lc_iclr_final\\rs",
"--out-dir", "F:\\archai_experiment_reports"]
},
{
"name": "Analysis Simulate FEAR on Nasbench301",
"type": "python",
"request": "launch",
"program": "${cwd}/scripts/reports/fear_analysis/simulate_fear_on_nb301.py",
"console": "integratedTerminal",
"args": ["--nb301-logs-dir", "C:\\Users\\dedey\\dataroot\\nasbench301\\nasbench301_full_data\\nb_301_v13_lc_iclr_final\\rs",
"--out-dir", "F:\\archai_experiment_reports"]
},
{
"name": "CurrentFile",
"type": "python",

Просмотреть файл

@ -39,8 +39,18 @@ class FreezeTrainer(ArchTrainer, EnforceOverrides):
# freeze everything other than the last layer
if not self.conf_train['bypass_freeze']:
logger.info('no freezing!')
# addup parameters which are not frozen
num_frozen_params = 0
for l in model_stats.layer_stats:
for identifier in self.conf_train['identifiers_to_unfreeze']:
if identifier in l.name:
num_frozen_params += l.parameters
ratio_unfrozen = num_frozen_params / model_stats.parameters
logger.info(f'unfrozen parameters ratio {ratio_unfrozen}')
self._freeze_but_last_layer()
else:
logger.info(f'Bypassing freezing!')
def _freeze_but_last_layer(self) -> None:

Просмотреть файл

@ -91,6 +91,7 @@ nas:
dataset:
_copy: '/dataset'
freeze_trainer:
bypass_freeze: True
identifiers_to_unfreeze: ['logits_op._op', 'cells.7', 'cells.6'] # last few layer names in DARTS space are 'logits_op._op': Linear, 'cells.19': prefix for all cell 19 parameters
apex:
_copy: '/common/apex'
@ -125,7 +126,7 @@ nas:
lossfn:
type: 'CrossEntropyLoss'
trainer_full:
top1_acc_threshold: 0.6
top1_acc_threshold: 0.1
use_val: False
apex:
_copy: '/common/apex'

Просмотреть файл

@ -14,7 +14,7 @@ import statistics
def find_train_thresh_epochs(train_acc:List[float], train_thresh:float)->int:
for i, t in enumerate(train_acc):
if t >= train_thresh:
return i
return i + 1
def main():
@ -24,13 +24,14 @@ def main():
args, extra_args = parser.parse_known_args()
train_thresh = 60.0
post_thresh_epochs = 5
post_thresh_epochs = 10
all_test_acc = []
all_fear_end_acc = []
all_fear_time = []
all_reg_train_acc = defaultdict(list)
all_reg_train_time_per_epoch = defaultdict(list)
# collect all the json file names in the log dir recursively
for root, dir, files in os.walk(args.nb301_logs_dir):
@ -61,6 +62,7 @@ def main():
# evaluation baseline
for epoch_num, train_acc in enumerate(log_data['learning_curves']['Train/train_accuracy']):
all_reg_train_acc[epoch_num].append(train_acc)
all_reg_train_time_per_epoch[epoch_num].append((epoch_num + 1) * per_epoch_time)
@ -69,14 +71,17 @@ def main():
print(f'FEAR avg time: {statistics.mean(all_fear_time)}')
spes_train_acc_vs_epoch = {}
avg_time_train_acc_vs_epoch = {}
for epoch_num, train_accs_epoch in all_reg_train_acc.items():
if len(train_accs_epoch) != len(all_test_acc):
continue
this_spe, _ = spearmanr(all_test_acc, train_accs_epoch)
spes_train_acc_vs_epoch[epoch_num] = this_spe
avg_time_train_acc_vs_epoch[epoch_num] = statistics.mean(all_reg_train_time_per_epoch[epoch_num])
for epoch_num, spe in spes_train_acc_vs_epoch.items():
print(f'Epoch {epoch_num}, spearman {spe}')
avg_time = avg_time_train_acc_vs_epoch[epoch_num]
print(f'Epoch {epoch_num}, spearman {spe}, avg. time: {avg_time} seconds')

Просмотреть файл

@ -4,8 +4,6 @@ exp_folder: 'F:\\archai_experiment_reports'
darts_cifar10:
freezetrain:
ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_ftonly: 'ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_ftonly'
<<<<<<< HEAD
=======
ft_dt_fb96_ftlr0.025_fte15_ct96_ftt0.6_ftonly: 'ft_dt_fb96_ftlr0.025_fte15_ct96_ftt0.6_ftonly'
ft_dt_fb96_ftlr0.025_fte20_ct96_ftt0.6_ftonly: 'ft_dt_fb96_ftlr0.025_fte20_ct96_ftt0.6_ftonly'
ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_c5_ftonly: 'ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_c5_ftonly'
@ -13,7 +11,6 @@ darts_cifar10:
ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_c3_ftonly: 'ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_c3_ftonly'
ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_nofreeze_ftonly: 'ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_nofreeze_ftonly'
>>>>>>> eaaf94d5 (Added analysis script for ranking in nb301.)
shortreg:
dt_reg_b96_e5: 'dt_reg_b96_e5'