зеркало из https://github.com/microsoft/archai.git
More progress on freeze training. Testing and debugging underway.
This commit is contained in:
Родитель
08e4d75ab6
Коммит
7e99526376
|
@ -206,6 +206,22 @@
|
|||
"console": "integratedTerminal",
|
||||
"args": ["--algos", "random"]
|
||||
},
|
||||
{
|
||||
"name": "Proxynas-Full",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "${cwd}/scripts/main.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": ["--full", "--algos", "proxynas"]
|
||||
},
|
||||
{
|
||||
"name": "Proxynas-Toy",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "${cwd}/scripts/main.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": ["--algos", "proxynas"]
|
||||
},
|
||||
{
|
||||
"name": "Resnet-Toy",
|
||||
"type": "python",
|
||||
|
|
|
@ -12,6 +12,8 @@ from archai.nas.exp_runner import ExperimentRunner
|
|||
from archai.nas.arch_trainer import ArchTrainer, TArchTrainer
|
||||
from archai.nas.evaluater import Evaluater, EvalResult
|
||||
|
||||
from archai.common.common import get_expdir, logger
|
||||
|
||||
from archai.algos.random.random_model_desc_builder import RandomModelDescBuilder
|
||||
from .freeze_evaluator import FreezeEvaluator
|
||||
|
||||
|
@ -30,8 +32,10 @@ class FreezeExperimentRunner(ExperimentRunner):
|
|||
evaler = self.evaluater()
|
||||
reg_eval_result = evaler.evaluate(conf_eval, model_desc_builder=self.model_desc_builder())
|
||||
|
||||
logger.pushd('freeze_evaluate')
|
||||
freeze_evaler = FreezeEvaluator()
|
||||
freeze_eval_result = freeze_evaler.evaluate(conf_eval, model_desc_builder=self.model_desc_builder())
|
||||
logger.popd()
|
||||
|
||||
# NOTE: Not returning freeze eval results
|
||||
# but it seems like we don't need to anyways as things get logged to disk
|
||||
|
|
|
@ -36,7 +36,7 @@ class FreezeTrainer(ArchTrainer, EnforceOverrides):
|
|||
|
||||
@overrides
|
||||
def post_epoch(self, train_dl: DataLoader, val_dl: Optional[DataLoader]) -> None:
|
||||
super()._post_epoch(train_dl, val_dl)
|
||||
super().post_epoch(train_dl, val_dl)
|
||||
|
||||
# if current validation accuracy is above
|
||||
# freeze everything other than the last layer
|
||||
|
|
|
@ -132,7 +132,7 @@ class Trainer(EnforceOverrides):
|
|||
|
||||
# optimizers, schedulers needs to be recreated for each fit call
|
||||
# as they have state specific to each run
|
||||
optim = self.create_optimizer(self.conf_optim, self.model.parameters())
|
||||
optim = self.create_optimizer(self.conf_optim, filter(lambda p: p.requires_grad, self.model.parameters()))
|
||||
# create scheduler for optim before applying amp
|
||||
sched, sched_on_epoch = self.create_scheduler(self.conf_sched, optim, train_len)
|
||||
|
||||
|
|
|
@ -0,0 +1,14 @@
|
|||
__include__: 'darts.yaml' # just use darts defaults
|
||||
|
||||
|
||||
nas:
|
||||
search:
|
||||
model_desc:
|
||||
num_edges_to_sample: 2 # number of edges each node will take input from
|
||||
|
||||
eval:
|
||||
model_desc:
|
||||
num_edges_to_sample: 2
|
||||
trainer:
|
||||
plotsdir: ''
|
||||
val_top1_acc_threshold: 0.6 # after 60% accuracy we will shift into training only the last layer
|
|
@ -0,0 +1,3 @@
|
|||
# in toy mode, load the config for algo and then override with common settings for toy mode
|
||||
# any additional algo specific toy mode settings will go in this file
|
||||
__include__: ['proxynas.yaml', 'toy_common.yaml']
|
|
@ -14,6 +14,7 @@ from archai.algos.xnas.xnas_exp_runner import XnasExperimentRunner
|
|||
from archai.algos.gumbelsoftmax.gs_exp_runner import GsExperimentRunner
|
||||
from archai.algos.divnas.divnas_exp_runner import DivnasExperimentRunner
|
||||
from archai.algos.didarts.didarts_exp_runner import DiDartsExperimentRunner
|
||||
from archai.algos.proxynas.freeze_experiment_runner import FreezeExperimentRunner
|
||||
|
||||
|
||||
def main():
|
||||
|
@ -25,12 +26,13 @@ def main():
|
|||
'manual': ManualExperimentRunner,
|
||||
'gs': GsExperimentRunner,
|
||||
'divnas': DivnasExperimentRunner,
|
||||
'didarts': DiDartsExperimentRunner
|
||||
'didarts': DiDartsExperimentRunner,
|
||||
'proxynas': FreezeExperimentRunner
|
||||
}
|
||||
|
||||
parser = argparse.ArgumentParser(description='NAS E2E Runs')
|
||||
parser.add_argument('--algos', type=str, default='darts,xnas,random,didarts,petridish,gs,manual,divnas',
|
||||
help='NAS algos to run, seperated by comma')
|
||||
parser.add_argument('--algos', type=str, default='darts,xnas,random,didarts,petridish,gs,manual,divnas,proxynas',
|
||||
help='NAS algos to run, separated by comma')
|
||||
parser.add_argument('--datasets', type=str, default='cifar10',
|
||||
help='datasets to use, separated by comma')
|
||||
parser.add_argument('--full', type=lambda x:x.lower()=='true',
|
||||
|
|
Загрузка…
Ссылка в новой задаче