Mirror of https://github.com/microsoft/archai.git
Train and freeze all but last layer implemented. To be tested.
This commit is contained in:
Parent
82a65ad075
Commit
08e4d75ab6
|
@ -0,0 +1,45 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from archai.nas.evaluater import Evaluater
|
||||
from typing import Optional, Tuple
|
||||
import importlib
|
||||
import sys
|
||||
import string
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.utils.data.dataloader import DataLoader
|
||||
|
||||
from overrides import overrides, EnforceOverrides
|
||||
|
||||
from archai.common.trainer import Trainer
|
||||
from archai.common.config import Config
|
||||
from archai.common.common import logger
|
||||
from archai.datasets import data
|
||||
from archai.nas.model_desc import ModelDesc
|
||||
from archai.nas.model_desc_builder import ModelDescBuilder
|
||||
from archai.nas import nas_utils
|
||||
from archai.common import ml_utils, utils
|
||||
from archai.common.metrics import EpochMetrics, Metrics
|
||||
from archai.nas.model import Model
|
||||
from archai.common.checkpoint import CheckPoint
|
||||
|
||||
from .freeze_trainer import FreezeTrainer
|
||||
|
||||
|
||||
class FreezeEvaluator(Evaluater, EnforceOverrides):
    """Evaluater variant that trains the model with a FreezeTrainer, which
    freezes all but the last layer once a validation-accuracy threshold is
    reached (see freeze_trainer.py)."""

    @overrides
    def train_model(self, conf_train:Config, model:nn.Module,
                    checkpoint:Optional[CheckPoint])->Metrics:
        """Train *model* and return its training metrics.

        conf_train holds both the 'loader' and 'trainer' sub-configs.
        """
        loader_conf = conf_train['loader']
        trainer_conf = conf_train['trainer']

        # build train/test data loaders from the loader config
        train_dl, test_dl = self.get_data(loader_conf)

        # FreezeTrainer handles the freezing logic during fit()
        trainer = FreezeTrainer(trainer_conf, model, checkpoint)
        return trainer.fit(train_dl, test_dl)
|
||||
|
|
@ -0,0 +1,41 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from archai.nas.evaluater import EvalResult
|
||||
from typing import Type
|
||||
|
||||
from overrides import overrides
|
||||
|
||||
from archai.common.config import Config
|
||||
from archai.nas import nas_utils
|
||||
from archai.nas.exp_runner import ExperimentRunner
|
||||
from archai.nas.arch_trainer import ArchTrainer, TArchTrainer
|
||||
from archai.nas.evaluater import Evaluater, EvalResult
|
||||
|
||||
from archai.algos.random.random_model_desc_builder import RandomModelDescBuilder
|
||||
from .freeze_evaluator import FreezeEvaluator
|
||||
|
||||
class FreezeExperimentRunner(ExperimentRunner):
    """Experiment runner that evaluates an architecture twice: once with the
    regular Evaluater and once with the FreezeEvaluator."""

    @overrides
    def model_desc_builder(self)->RandomModelDescBuilder:
        # architectures come from the random model-desc builder
        return RandomModelDescBuilder()

    @overrides
    def trainer_class(self)->TArchTrainer:
        # no architecture-search phase, hence no arch trainer
        return None

    @overrides
    def run_eval(self, conf_eval:Config)->EvalResult:
        """Run the regular evaluation, then the freeze evaluation, and
        return only the regular result."""
        # regular evaluation of the architecture
        reg_eval_result = self.evaluater().evaluate(
            conf_eval, model_desc_builder=self.model_desc_builder())

        # freeze-training evaluation of the same architecture
        FreezeEvaluator().evaluate(
            conf_eval, model_desc_builder=self.model_desc_builder())

        # NOTE: Not returning freeze eval results
        # but it seems like we don't need to anyways as things get logged to disk
        return reg_eval_result
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,67 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from typing import Optional, Callable, Type
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from torch import nn, Tensor
|
||||
from torch.optim.optimizer import Optimizer
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
|
||||
from overrides import overrides, EnforceOverrides
|
||||
|
||||
from archai.common.config import Config
|
||||
from archai.common import common, utils
|
||||
from archai.nas.model import Model
|
||||
from archai.nas.model_desc import ModelDesc
|
||||
from archai.nas.arch_trainer import ArchTrainer
|
||||
from archai.common.trainer import Trainer
|
||||
from archai.nas.vis_model_desc import draw_model_desc
|
||||
from archai.common.checkpoint import CheckPoint
|
||||
|
||||
TFreezeTrainer = Optional[Type['FreezeTrainer']]


class FreezeTrainer(ArchTrainer, EnforceOverrides):
    """ArchTrainer that, once best validation top-1 accuracy reaches a
    configured threshold, freezes every parameter except the final
    ('logits') layer so only the classifier head keeps training."""

    def __init__(self, conf_train: Config, model: nn.Module,
                 checkpoint:Optional[CheckPoint]) -> None:
        super().__init__(conf_train, model, checkpoint)

        # region config vars specific to freeze trainer
        self.conf_train = conf_train
        # validation top-1 accuracy at which freezing kicks in
        self._val_top1_acc = conf_train['val_top1_acc_threshold']
        # endregion

    @overrides
    def post_epoch(self, train_dl: DataLoader, val_dl: Optional[DataLoader]) -> None:
        # BUG FIX: was super()._post_epoch(...); the hook declared (and
        # asserted by @overrides) on the parent is post_epoch, so the
        # parent-chain call must use the same name.
        super().post_epoch(train_dl, val_dl)

        # if current validation accuracy is above threshold,
        # freeze everything other than the last layer
        # NOTE(review): assumes best_val_top1() returns a comparable number
        # even before any validation pass — confirm against Metrics.
        best_val_top1_avg = self._metrics.best_val_top1()

        if best_val_top1_avg >= self._val_top1_acc:
            # freeze everything other than the last layer
            self.freeze_but_last_layer()

    def freeze_but_last_layer(self) -> None:
        """Disable gradients on all parameters, then re-enable them only on
        parameters whose name contains 'logits'. Idempotent, so calling it
        on every epoch after the threshold is harmless."""
        # first freeze all parameters
        for param in self.model.parameters():
            param.requires_grad = False

        # NOTE: assumption here is that the last
        # layer has the word 'logits' in the name string
        # e.g. logits_op._op.weight, logits_op._op.bias
        # e.g. _aux_towers.13.logits_op.weight, _aux_towers.13.logits_op.bias
        # TODO: confirm from Shital that this is good!
        for name, param in self.model.named_parameters():
            if 'logits' in name:
                param.requires_grad = True
|
||||
|
||||
|
||||
|
|
@ -196,7 +196,7 @@ class Trainer(EnforceOverrides):
|
|||
if self._metrics.epochs() % self._validation_freq == 0 or \
|
||||
self._metrics.epochs() >= self._epochs: # last epoch
|
||||
|
||||
# these asserts makes sure train and val are not ovrlapiing
|
||||
# these asserts makes sure train and val are not overlapping
|
||||
# assert train_dl.sampler.epoch == val_dl.sampler.epoch
|
||||
# tidx = list(train_dl.sampler)
|
||||
# vidx = list(val_dl.sampler)
|
||||
|
|
Loading…
Open link in a new issue