More progress on freeze training. Testing and debugging underway.

This commit is contained in:
Debadeepta Dey 2020-11-25 13:35:28 -08:00 committed by Gustavo Rosa
Parent 08e4d75ab6
Commit 7e99526376
7 changed files: 44 additions and 5 deletions

16
.vscode/launch.json vendored
View file

@ -206,6 +206,22 @@
"console": "integratedTerminal",
"args": ["--algos", "random"]
},
{
"name": "Proxynas-Full",
"type": "python",
"request": "launch",
"program": "${cwd}/scripts/main.py",
"console": "integratedTerminal",
"args": ["--full", "--algos", "proxynas"]
},
{
"name": "Proxynas-Toy",
"type": "python",
"request": "launch",
"program": "${cwd}/scripts/main.py",
"console": "integratedTerminal",
"args": ["--algos", "proxynas"]
},
{
"name": "Resnet-Toy",
"type": "python",

View file

@ -12,6 +12,8 @@ from archai.nas.exp_runner import ExperimentRunner
from archai.nas.arch_trainer import ArchTrainer, TArchTrainer
from archai.nas.evaluater import Evaluater, EvalResult
from archai.common.common import get_expdir, logger
from archai.algos.random.random_model_desc_builder import RandomModelDescBuilder
from .freeze_evaluator import FreezeEvaluator
@ -30,8 +32,10 @@ class FreezeExperimentRunner(ExperimentRunner):
evaler = self.evaluater()
reg_eval_result = evaler.evaluate(conf_eval, model_desc_builder=self.model_desc_builder())
logger.pushd('freeze_evaluate')
freeze_evaler = FreezeEvaluator()
freeze_eval_result = freeze_evaler.evaluate(conf_eval, model_desc_builder=self.model_desc_builder())
logger.popd()
# NOTE: Not returning freeze eval results
# but it seems like we don't need to anyways as things get logged to disk

View file

@ -36,7 +36,7 @@ class FreezeTrainer(ArchTrainer, EnforceOverrides):
@overrides
def post_epoch(self, train_dl: DataLoader, val_dl: Optional[DataLoader]) -> None:
super()._post_epoch(train_dl, val_dl)
super().post_epoch(train_dl, val_dl)
# if current validation accuracy is above
# freeze everything other than the last layer

View file

@ -132,7 +132,7 @@ class Trainer(EnforceOverrides):
# optimizers, schedulers needs to be recreated for each fit call
# as they have state specific to each run
optim = self.create_optimizer(self.conf_optim, self.model.parameters())
optim = self.create_optimizer(self.conf_optim, filter(lambda p: p.requires_grad, self.model.parameters()))
# create scheduler for optim before applying amp
sched, sched_on_epoch = self.create_scheduler(self.conf_sched, optim, train_len)

14
confs/algos/proxynas.yaml Normal file
View file

@ -0,0 +1,14 @@
__include__: 'darts.yaml' # just use darts defaults
nas:
search:
model_desc:
num_edges_to_sample: 2 # number of edges each node will take input from
eval:
model_desc:
num_edges_to_sample: 2
trainer:
plotsdir: ''
val_top1_acc_threshold: 0.6 # after 60% accuracy we will shift into training only the last layer

View file

@ -0,0 +1,3 @@
# in toy mode, load the config for algo and then override with common settings for toy mode
# any additional algo specific toy mode settings will go in this file
__include__: ['proxynas.yaml', 'toy_common.yaml']

View file

@ -14,6 +14,7 @@ from archai.algos.xnas.xnas_exp_runner import XnasExperimentRunner
from archai.algos.gumbelsoftmax.gs_exp_runner import GsExperimentRunner
from archai.algos.divnas.divnas_exp_runner import DivnasExperimentRunner
from archai.algos.didarts.didarts_exp_runner import DiDartsExperimentRunner
from archai.algos.proxynas.freeze_experiment_runner import FreezeExperimentRunner
def main():
@ -25,12 +26,13 @@ def main():
'manual': ManualExperimentRunner,
'gs': GsExperimentRunner,
'divnas': DivnasExperimentRunner,
'didarts': DiDartsExperimentRunner
'didarts': DiDartsExperimentRunner,
'proxynas': FreezeExperimentRunner
}
parser = argparse.ArgumentParser(description='NAS E2E Runs')
parser.add_argument('--algos', type=str, default='darts,xnas,random,didarts,petridish,gs,manual,divnas',
help='NAS algos to run, seperated by comma')
parser.add_argument('--algos', type=str, default='darts,xnas,random,didarts,petridish,gs,manual,divnas,proxynas',
help='NAS algos to run, separated by comma')
parser.add_argument('--datasets', type=str, default='cifar10',
help='datasets to use, separated by comma')
parser.add_argument('--full', type=lambda x:x.lower()=='true',