зеркало из https://github.com/microsoft/archai.git
Added new algo manual_freeze which can train a handcrafted network with freezetrainer.
This commit is contained in:
Родитель
4efc420bbe
Коммит
9a54ffb594
|
@ -247,13 +247,29 @@
|
|||
"args": ["--aml_secrets_filepath", "/home/dedey/aml_secrets/aml_secrets_msrlabspvc1.yaml", "--algo", "darts", "--full"]
|
||||
},
|
||||
{
|
||||
"name": "Manual-E2E-Toy",
|
||||
"name": "Manual-Toy",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "${cwd}/scripts/main.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": ["--algos", "manual"]
|
||||
},
|
||||
{
|
||||
"name": "ManualFreeze-Toy",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "${cwd}/scripts/main.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": ["--algos", "manual_freeze"]
|
||||
},
|
||||
{
|
||||
"name": "ManualFreeze-Full",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "${cwd}/scripts/main.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": ["--algos", "manual_freeze", "--full"]
|
||||
},
|
||||
{
|
||||
"name": "TrainAug resnet50 cocob cifar10",
|
||||
"type": "python",
|
||||
|
@ -280,7 +296,7 @@
|
|||
"request": "launch",
|
||||
"program": "${cwd}/scripts/reports/exprep.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": ["--results-dir", "C:\\Users\\dedey\\Documents\\archaiphilly\\phillytools\\proxynas_initial",
|
||||
"args": ["--results-dir", "C:\\Users\\dedey\\Documents\\archaiphilly\\phillytools\\proxynas_unfreeze_logits",
|
||||
"--out-dir", "C:\\Users\\dedey\\archai_experiment_reports"]
|
||||
},
|
||||
{
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
|
@ -0,0 +1,85 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from typing import Optional
|
||||
import importlib
|
||||
import sys
|
||||
import string
|
||||
import os
|
||||
|
||||
from overrides import overrides
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from overrides import overrides, EnforceOverrides
|
||||
|
||||
from archai.common.trainer import Trainer
|
||||
from archai.common.config import Config
|
||||
from archai.common.common import logger
|
||||
from archai.datasets import data
|
||||
from archai.nas.model_desc import ModelDesc
|
||||
from archai.nas.model_desc_builder import ModelDescBuilder
|
||||
from archai.nas import nas_utils
|
||||
from archai.common import ml_utils, utils
|
||||
from archai.common.metrics import EpochMetrics, Metrics
|
||||
from archai.nas.model import Model
|
||||
from archai.common.checkpoint import CheckPoint
|
||||
from archai.nas.evaluater import Evaluater
|
||||
from archai.algos.proxynas.freeze_trainer import FreezeTrainer
|
||||
|
||||
|
||||
|
||||
class ManualFreezeEvaluater(Evaluater):
|
||||
@overrides
|
||||
def create_model(self, conf_eval:Config, model_desc_builder:ModelDescBuilder,
|
||||
final_desc_filename=None, full_desc_filename=None)->nn.Module:
|
||||
# region conf vars
|
||||
dataset_name = conf_eval['loader']['dataset']['name']
|
||||
|
||||
# if explicitly passed in then don't get from conf
|
||||
if not final_desc_filename:
|
||||
final_desc_filename = conf_eval['final_desc_filename']
|
||||
model_factory_spec = conf_eval['model_factory_spec']
|
||||
# endregion
|
||||
|
||||
assert model_factory_spec
|
||||
|
||||
return self._model_from_factory(model_factory_spec, dataset_name)
|
||||
|
||||
def _model_from_factory(self, model_factory_spec:str, dataset_name:str)->Model:
|
||||
splitted = model_factory_spec.rsplit('.', 1)
|
||||
function_name = splitted[-1]
|
||||
|
||||
if len(splitted) > 1:
|
||||
module_name = splitted[0]
|
||||
else:
|
||||
module_name = self._default_module_name(dataset_name, function_name)
|
||||
|
||||
module = importlib.import_module(module_name) if module_name else sys.modules[__name__]
|
||||
function = getattr(module, function_name)
|
||||
model = function()
|
||||
|
||||
logger.info({'model_factory':True,
|
||||
'module_name': module_name,
|
||||
'function_name': function_name,
|
||||
'params': ml_utils.param_size(model)})
|
||||
|
||||
return model
|
||||
|
||||
@overrides
|
||||
def train_model(self, conf_train:Config, model:nn.Module,
|
||||
checkpoint:Optional[CheckPoint])->Metrics:
|
||||
conf_loader = conf_train['loader']
|
||||
conf_train = conf_train['trainer']
|
||||
|
||||
# TODO: this will not be needed after precise freeze stopping is implemented
|
||||
# but reasonable for now
|
||||
conf_train['epochs'] = conf_train['proxynas']['freeze_epochs']
|
||||
|
||||
# get data
|
||||
train_dl, test_dl = self.get_data(conf_loader)
|
||||
|
||||
trainer = FreezeTrainer(conf_train, model, checkpoint)
|
||||
train_metrics = trainer.fit(train_dl, test_dl)
|
||||
return train_metrics
|
|
@ -0,0 +1,74 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
|
||||
from typing import Optional, Type
|
||||
|
||||
from overrides import overrides
|
||||
|
||||
from archai.common.config import Config
|
||||
from archai.nas import nas_utils
|
||||
from archai.nas.exp_runner import ExperimentRunner
|
||||
from archai.nas.arch_trainer import ArchTrainer, TArchTrainer
|
||||
from archai.nas.model_desc_builder import ModelDescBuilder
|
||||
from archai.nas.evaluater import EvalResult
|
||||
from .manual_freeze_evaluater import ManualFreezeEvaluater
|
||||
from .manual_freeze_searcher import ManualFreezeSearcher
|
||||
|
||||
from archai.common.common import get_expdir, logger
|
||||
|
||||
|
||||
|
||||
class ManualFreezeExperimentRunner(ExperimentRunner):
|
||||
"""Runs manually designed models such as resnet"""
|
||||
|
||||
@overrides
|
||||
def model_desc_builder(self)->Optional[ModelDescBuilder]:
|
||||
return None
|
||||
|
||||
@overrides
|
||||
def trainer_class(self)->TArchTrainer:
|
||||
return None # no search trainer
|
||||
|
||||
@overrides
|
||||
def searcher(self)->ManualFreezeSearcher:
|
||||
return ManualFreezeSearcher()
|
||||
|
||||
@overrides
|
||||
def evaluater(self)->ManualFreezeEvaluater:
|
||||
return ManualFreezeEvaluater()
|
||||
|
||||
@overrides
|
||||
def copy_search_to_eval(self)->None:
|
||||
pass
|
||||
|
||||
@overrides
|
||||
def run_eval(self, conf_eval:Config)->EvalResult:
|
||||
# regular evaluation of the architecture
|
||||
reg_eval_result = None
|
||||
if conf_eval['trainer']['proxynas']['train_regular']:
|
||||
evaler = self.evaluater()
|
||||
reg_eval_result = evaler.evaluate(conf_eval, model_desc_builder=self.model_desc_builder())
|
||||
|
||||
# change relevant parts of conv_eval to ensure that freeze_evaler
|
||||
# doesn't resume from checkpoints created by evaler and saves
|
||||
# models to different names as well
|
||||
conf_eval['full_desc_filename'] = '$expdir/freeze_full_model_desc.yaml'
|
||||
conf_eval['metric_filename'] = '$expdir/freeze_eval_train_metrics.yaml'
|
||||
conf_eval['model_filename'] = '$expdir/freeze_model.pt'
|
||||
|
||||
if conf_eval['checkpoint'] is not None:
|
||||
conf_eval['checkpoint']['filename'] = '$expdir/freeze_checkpoint.pth'
|
||||
|
||||
logger.pushd('freeze_evaluate')
|
||||
freeze_evaler = ManualFreezeEvaluater()
|
||||
freeze_eval_result = freeze_evaler.evaluate(conf_eval, model_desc_builder=self.model_desc_builder())
|
||||
logger.popd()
|
||||
|
||||
# NOTE: Not returning freeze eval results to meet signature contract
|
||||
# but it seems like we don't need to anyways as everything we need is
|
||||
# logged to disk
|
||||
if reg_eval_result is not None:
|
||||
return reg_eval_result
|
||||
else:
|
||||
return freeze_eval_result
|
|
@ -0,0 +1,33 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from typing import Iterator, Mapping, Type, Optional, Tuple, List
|
||||
import math
|
||||
import copy
|
||||
import random
|
||||
import os
|
||||
|
||||
from overrides import overrides
|
||||
|
||||
from torch.utils.data.dataloader import DataLoader
|
||||
|
||||
from archai.common.common import logger
|
||||
|
||||
from archai.common.config import Config
|
||||
from archai.nas.model_desc_builder import ModelDescBuilder
|
||||
from archai.nas.arch_trainer import TArchTrainer
|
||||
from archai.common.trainer import Trainer
|
||||
from archai.nas.model_desc import CellType, ModelDesc
|
||||
from archai.datasets import data
|
||||
from archai.nas.model import Model
|
||||
from archai.common.metrics import EpochMetrics, Metrics
|
||||
from archai.common import utils
|
||||
from archai.nas.finalizers import Finalizers
|
||||
from archai.nas.searcher import Searcher, SearchResult
|
||||
|
||||
class ManualFreezeSearcher(Searcher):
|
||||
@overrides
|
||||
def search(self, conf_search:Config, model_desc_builder:Optional[ModelDescBuilder],
|
||||
trainer_class:TArchTrainer, finalizers:Finalizers)->SearchResult:
|
||||
# for manual search, we already have a model so no search result are returned
|
||||
return SearchResult(None, None, None)
|
|
@ -21,6 +21,7 @@ from archai.nas.arch_trainer import ArchTrainer
|
|||
from archai.common.trainer import Trainer
|
||||
from archai.nas.vis_model_desc import draw_model_desc
|
||||
from archai.common.checkpoint import CheckPoint
|
||||
from archai.common.ml_utils import set_optim_lr
|
||||
|
||||
TFreezeTrainer = Optional[Type['FreezeTrainer']]
|
||||
|
||||
|
@ -51,20 +52,20 @@ class FreezeTrainer(ArchTrainer, EnforceOverrides):
|
|||
# freeze everything other than the last layer
|
||||
self.freeze_but_last_layer()
|
||||
|
||||
# reset optimizer
|
||||
del self._multi_optim
|
||||
# # reset optimizer
|
||||
# del self._multi_optim
|
||||
|
||||
self.conf_optim['lr'] = self.conf_train['proxynas']['freeze_lr']
|
||||
self.conf_optim['decay'] = self.conf_train['proxynas']['freeze_decay']
|
||||
self.conf_optim['momentum'] = self.conf_train['proxynas']['freeze_momentum']
|
||||
self.conf_sched = Config()
|
||||
self._aux_weight = self.conf_train['proxynas']['aux_weight']
|
||||
# self.conf_optim['lr'] = self.conf_train['proxynas']['freeze_lr']
|
||||
# self.conf_optim['decay'] = self.conf_train['proxynas']['freeze_decay']
|
||||
# self.conf_optim['momentum'] = self.conf_train['proxynas']['freeze_momentum']
|
||||
# self.conf_sched = Config()
|
||||
# self._aux_weight = self.conf_train['proxynas']['aux_weight']
|
||||
|
||||
self.model.zero_grad()
|
||||
self._multi_optim = self.create_multi_optim(len(train_dl))
|
||||
# before checkpoint restore, convert to amp
|
||||
self.model = self._apex.to_amp(self.model, self._multi_optim,
|
||||
batch_size=train_dl.batch_size)
|
||||
# self.model.zero_grad()
|
||||
# self._multi_optim = self.create_multi_optim(len(train_dl))
|
||||
# # before checkpoint restore, convert to amp
|
||||
# self.model = self._apex.to_amp(self.model, self._multi_optim,
|
||||
# batch_size=train_dl.batch_size)
|
||||
|
||||
self._in_freeze_mode = True
|
||||
self._epoch_freeze_started = self._metrics.epochs()
|
||||
|
@ -83,7 +84,9 @@ class FreezeTrainer(ArchTrainer, EnforceOverrides):
|
|||
param.requires_grad = False
|
||||
|
||||
for name, param in self.model.named_parameters():
|
||||
if 'logits_op._op' in name:
|
||||
# TODO: Make the layer names to be updated a config value
|
||||
# 'logits_op._op'
|
||||
if 'fc' in name:
|
||||
param.requires_grad = True
|
||||
|
||||
for name, param in self.model.named_parameters():
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
__include__: 'darts.yaml' # just use darts defaults
|
||||
|
||||
|
||||
nas:
|
||||
eval:
|
||||
model_factory_spec: 'resnet18'
|
||||
|
||||
#darts loader/trainer
|
||||
loader:
|
||||
train_batch: 128 #96
|
||||
cutout: 0
|
||||
trainer:
|
||||
plotsdir: ''
|
||||
aux_weight: 0.0
|
||||
grad_clip: 0.0
|
||||
drop_path_prob: 0.0 # probability that given edge will be dropped
|
||||
epochs: 200
|
||||
optimizer:
|
||||
type: 'sgd'
|
||||
lr: 0.0333 #0.025 # init learning rate
|
||||
decay: 3.0e-4 # pytorch default is 0.0
|
||||
momentum: 0.9 # pytorch default is 0.0
|
||||
nesterov: False # pytorch default is False
|
||||
warmup: null
|
||||
lr_schedule:
|
||||
type: 'cosine'
|
||||
min_lr: 0.001 # min learning rate to be set in eta_min param of scheduler
|
||||
proxynas:
|
||||
val_top1_acc_threshold: 0.05 # after some accuracy we will shift into training only the last layer
|
||||
freeze_epochs: 200
|
||||
freeze_lr: 0.001
|
||||
freeze_decay: 0.0
|
||||
freeze_momentum: 0.0
|
||||
train_regular: False
|
||||
aux_weight: 0.0 # disable auxiliary loss part during finetuning
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
# in toy mode, load the confif for algo and then override with common settings for toy mode
|
||||
# any additional algo specific toy mode settings will go in this file
|
||||
__include__: ['manual_freeze.yaml', 'toy_common.yaml']
|
|
@ -13,7 +13,7 @@ nas:
|
|||
plotsdir: ''
|
||||
epochs: 600
|
||||
proxynas:
|
||||
val_top1_acc_threshold: 0.6 # after some accuracy we will shift into training only the last layer
|
||||
val_top1_acc_threshold: 0.60 # after some accuracy we will shift into training only the last layer
|
||||
freeze_epochs: 200
|
||||
freeze_lr: 0.001
|
||||
freeze_decay: 0.0
|
||||
|
|
|
@ -15,6 +15,7 @@ from archai.algos.gumbelsoftmax.gs_exp_runner import GsExperimentRunner
|
|||
from archai.algos.divnas.divnas_exp_runner import DivnasExperimentRunner
|
||||
from archai.algos.didarts.didarts_exp_runner import DiDartsExperimentRunner
|
||||
from archai.algos.proxynas.freeze_experiment_runner import FreezeExperimentRunner
|
||||
from archai.algos.manual_freeze.manual_freeze_exp_runner import ManualFreezeExperimentRunner
|
||||
|
||||
|
||||
def main():
|
||||
|
@ -24,6 +25,7 @@ def main():
|
|||
'xnas': XnasExperimentRunner,
|
||||
'random': RandomExperimentRunner,
|
||||
'manual': ManualExperimentRunner,
|
||||
'manual_freeze': ManualFreezeExperimentRunner,
|
||||
'gs': GsExperimentRunner,
|
||||
'divnas': DivnasExperimentRunner,
|
||||
'didarts': DiDartsExperimentRunner,
|
||||
|
|
Загрузка…
Ссылка в новой задаче