зеркало из https://github.com/microsoft/archai.git
Getting ready to run FEAR on DARTS space without any freeze.
This commit is contained in:
Родитель
66d56ce796
Коммит
7ebad4ff40
|
@ -921,6 +921,15 @@
|
|||
"args": ["--nb301-logs-dir", "C:\\Users\\dedey\\dataroot\\nasbench301\\nasbench301_full_data\\nb_301_v13_lc_iclr_final\\rs",
|
||||
"--out-dir", "F:\\archai_experiment_reports"]
|
||||
},
|
||||
{
|
||||
"name": "Analysis Simulate FEAR on Nasbench301",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "${cwd}/scripts/reports/fear_analysis/simulate_fear_on_nb301.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": ["--nb301-logs-dir", "C:\\Users\\dedey\\dataroot\\nasbench301\\nasbench301_full_data\\nb_301_v13_lc_iclr_final\\rs",
|
||||
"--out-dir", "F:\\archai_experiment_reports"]
|
||||
},
|
||||
{
|
||||
"name": "CurrentFile",
|
||||
"type": "python",
|
||||
|
|
|
@ -39,8 +39,18 @@ class FreezeTrainer(ArchTrainer, EnforceOverrides):
|
|||
|
||||
# freeze everything other than the last layer
|
||||
if not self.conf_train['bypass_freeze']:
|
||||
logger.info('no freezing!')
|
||||
# addup parameters which are not frozen
|
||||
num_frozen_params = 0
|
||||
for l in model_stats.layer_stats:
|
||||
for identifier in self.conf_train['identifiers_to_unfreeze']:
|
||||
if identifier in l.name:
|
||||
num_frozen_params += l.parameters
|
||||
ratio_unfrozen = num_frozen_params / model_stats.parameters
|
||||
logger.info(f'unfrozen parameters ratio {ratio_unfrozen}')
|
||||
|
||||
self._freeze_but_last_layer()
|
||||
else:
|
||||
logger.info(f'Bypassing freezing!')
|
||||
|
||||
|
||||
def _freeze_but_last_layer(self) -> None:
|
||||
|
|
|
@ -91,6 +91,7 @@ nas:
|
|||
dataset:
|
||||
_copy: '/dataset'
|
||||
freeze_trainer:
|
||||
bypass_freeze: True
|
||||
identifiers_to_unfreeze: ['logits_op._op', 'cells.7', 'cells.6'] # last few layer names in DARTS space are 'logits_op._op': Linear, 'cells.19': prefix for all cell 19 parameters
|
||||
apex:
|
||||
_copy: '/common/apex'
|
||||
|
@ -125,7 +126,7 @@ nas:
|
|||
lossfn:
|
||||
type: 'CrossEntropyLoss'
|
||||
trainer_full:
|
||||
top1_acc_threshold: 0.6
|
||||
top1_acc_threshold: 0.1
|
||||
use_val: False
|
||||
apex:
|
||||
_copy: '/common/apex'
|
||||
|
|
|
@ -14,7 +14,7 @@ import statistics
|
|||
def find_train_thresh_epochs(train_acc:List[float], train_thresh:float)->int:
|
||||
for i, t in enumerate(train_acc):
|
||||
if t >= train_thresh:
|
||||
return i
|
||||
return i + 1
|
||||
|
||||
|
||||
def main():
|
||||
|
@ -24,13 +24,14 @@ def main():
|
|||
args, extra_args = parser.parse_known_args()
|
||||
|
||||
train_thresh = 60.0
|
||||
post_thresh_epochs = 5
|
||||
post_thresh_epochs = 10
|
||||
|
||||
all_test_acc = []
|
||||
all_fear_end_acc = []
|
||||
all_fear_time = []
|
||||
|
||||
all_reg_train_acc = defaultdict(list)
|
||||
all_reg_train_time_per_epoch = defaultdict(list)
|
||||
|
||||
# collect all the json file names in the log dir recursively
|
||||
for root, dir, files in os.walk(args.nb301_logs_dir):
|
||||
|
@ -61,6 +62,7 @@ def main():
|
|||
# evaluation baseline
|
||||
for epoch_num, train_acc in enumerate(log_data['learning_curves']['Train/train_accuracy']):
|
||||
all_reg_train_acc[epoch_num].append(train_acc)
|
||||
all_reg_train_time_per_epoch[epoch_num].append((epoch_num + 1) * per_epoch_time)
|
||||
|
||||
|
||||
|
||||
|
@ -69,14 +71,17 @@ def main():
|
|||
print(f'FEAR avg time: {statistics.mean(all_fear_time)}')
|
||||
|
||||
spes_train_acc_vs_epoch = {}
|
||||
avg_time_train_acc_vs_epoch = {}
|
||||
for epoch_num, train_accs_epoch in all_reg_train_acc.items():
|
||||
if len(train_accs_epoch) != len(all_test_acc):
|
||||
continue
|
||||
this_spe, _ = spearmanr(all_test_acc, train_accs_epoch)
|
||||
spes_train_acc_vs_epoch[epoch_num] = this_spe
|
||||
avg_time_train_acc_vs_epoch[epoch_num] = statistics.mean(all_reg_train_time_per_epoch[epoch_num])
|
||||
|
||||
for epoch_num, spe in spes_train_acc_vs_epoch.items():
|
||||
print(f'Epoch {epoch_num}, spearman {spe}')
|
||||
avg_time = avg_time_train_acc_vs_epoch[epoch_num]
|
||||
print(f'Epoch {epoch_num}, spearman {spe}, avg. time: {avg_time} seconds')
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -4,8 +4,6 @@ exp_folder: 'F:\\archai_experiment_reports'
|
|||
darts_cifar10:
|
||||
freezetrain:
|
||||
ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_ftonly: 'ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_ftonly'
|
||||
<<<<<<< HEAD
|
||||
=======
|
||||
ft_dt_fb96_ftlr0.025_fte15_ct96_ftt0.6_ftonly: 'ft_dt_fb96_ftlr0.025_fte15_ct96_ftt0.6_ftonly'
|
||||
ft_dt_fb96_ftlr0.025_fte20_ct96_ftt0.6_ftonly: 'ft_dt_fb96_ftlr0.025_fte20_ct96_ftt0.6_ftonly'
|
||||
ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_c5_ftonly: 'ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_c5_ftonly'
|
||||
|
@ -13,7 +11,6 @@ darts_cifar10:
|
|||
ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_c3_ftonly: 'ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_c3_ftonly'
|
||||
ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_nofreeze_ftonly: 'ft_dt_fb96_ftlr0.025_fte10_ct96_ftt0.6_nofreeze_ftonly'
|
||||
|
||||
>>>>>>> eaaf94d5 (Added analysis script for ranking in nb301.)
|
||||
|
||||
shortreg:
|
||||
dt_reg_b96_e5: 'dt_reg_b96_e5'
|
||||
|
|
Загрузка…
Ссылка в новой задаче