Implemented training that freezes all but the last layer. To be tested.

This commit is contained in:
Debadeepta Dey 2020-11-25 10:32:33 -08:00 committed by Gustavo Rosa
Parent 82a65ad075
Commit 08e4d75ab6
5 changed files with 154 additions and 1 deletion
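At its core the commit implements a standard PyTorch freezing pattern: once a validation metric crosses a threshold, turn off gradients everywhere except the final classifier. A minimal standalone sketch of that pattern follows; the toy model and names are illustrative, not archai's API:

import torch
from torch import nn

# toy network whose final layer is named 'logits', mirroring the
# naming assumption the commit relies on
model = nn.Sequential()
model.add_module('stem', nn.Linear(32, 64))
model.add_module('act', nn.ReLU())
model.add_module('logits', nn.Linear(64, 10))

def freeze_but_last_layer(model: nn.Module) -> None:
    # freeze everything...
    for param in model.parameters():
        param.requires_grad = False
    # ...then re-enable any parameter whose name contains 'logits'
    for name, param in model.named_parameters():
        if 'logits' in name:
            param.requires_grad = True

freeze_but_last_layer(model)

# the usual idiom: hand only still-trainable parameters to the optimizer
optimizer = torch.optim.SGD(
    (p for p in model.parameters() if p.requires_grad), lr=0.01)

# sanity check: only logits.weight and logits.bias remain trainable
print([name for name, p in model.named_parameters() if p.requires_grad])

This mirrors what FreezeTrainer.freeze_but_last_layer does in the third file below.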

View file


@@ -0,0 +1,45 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Optional

from torch import nn

from overrides import overrides, EnforceOverrides

from archai.common.config import Config
from archai.common.metrics import Metrics
from archai.common.checkpoint import CheckPoint
from archai.nas.evaluater import Evaluater

from .freeze_trainer import FreezeTrainer

class FreezeEvaluator(Evaluater, EnforceOverrides):
    @overrides
    def train_model(self, conf_train:Config, model:nn.Module,
                    checkpoint:Optional[CheckPoint])->Metrics:
        conf_loader = conf_train['loader']
        conf_train = conf_train['trainer']

        # get data
        train_dl, test_dl = self.get_data(conf_loader)

        # FreezeTrainer freezes all but the last layer mid-training
        trainer = FreezeTrainer(conf_train, model, checkpoint)
        train_metrics = trainer.fit(train_dl, test_dl)
        return train_metrics
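Since FreezeEvaluator overrides only train_model, everything else (building the model from its description, checkpointing, final testing) is inherited from Evaluater. A hypothetical direct invocation, reusing the same evaluate call that run_eval makes in the next file; conf_eval is assumed to be an already-loaded archai Config:

# hypothetical driver; conf_eval is assumed to be a loaded archai Config
from archai.algos.random.random_model_desc_builder import RandomModelDescBuilder

evaler = FreezeEvaluator()
eval_result = evaler.evaluate(conf_eval,
                              model_desc_builder=RandomModelDescBuilder())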

View file

@@ -0,0 +1,41 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from overrides import overrides

from archai.common.config import Config
from archai.nas.exp_runner import ExperimentRunner
from archai.nas.arch_trainer import TArchTrainer
from archai.nas.evaluater import EvalResult

from archai.algos.random.random_model_desc_builder import RandomModelDescBuilder
from .freeze_evaluator import FreezeEvaluator

class FreezeExperimentRunner(ExperimentRunner):
    @overrides
    def model_desc_builder(self)->RandomModelDescBuilder:
        return RandomModelDescBuilder()

    @overrides
    def trainer_class(self)->TArchTrainer:
        # no search-phase trainer is needed; this runner only evaluates
        return None

    @overrides
    def run_eval(self, conf_eval:Config)->EvalResult:
        # regular evaluation of the architecture
        evaler = self.evaluater()
        reg_eval_result = evaler.evaluate(conf_eval,
                                          model_desc_builder=self.model_desc_builder())

        # evaluation of the same architecture with freeze training
        freeze_evaler = FreezeEvaluator()
        freeze_eval_result = freeze_evaler.evaluate(conf_eval,
                                                    model_desc_builder=self.model_desc_builder())

        # NOTE: not returning the freeze eval result, but it seems we don't
        # need to anyway since metrics get logged to disk
        return reg_eval_result
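A sketch of how this runner might be launched; ExperimentRunner's constructor and run() signatures are assumptions based on the surrounding codebase, not shown in this diff, and the config path is hypothetical:

# hypothetical launcher; constructor/run() signatures are assumptions
runner = FreezeExperimentRunner('confs/algos/my_freeze.yaml',   # hypothetical config
                                base_name='freeze_test', clean_expdir=True)
runner.run(search=False, eval=True)  # search is skipped: trainer_class() returns None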

View file

@@ -0,0 +1,67 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Optional, Type

from torch import nn
from torch.utils.data import DataLoader

from overrides import overrides, EnforceOverrides

from archai.common.config import Config
from archai.common.checkpoint import CheckPoint
from archai.nas.arch_trainer import ArchTrainer

TFreezeTrainer = Optional[Type['FreezeTrainer']]

class FreezeTrainer(ArchTrainer, EnforceOverrides):
    def __init__(self, conf_train: Config, model: nn.Module,
                 checkpoint:Optional[CheckPoint]) -> None:
        super().__init__(conf_train, model, checkpoint)

        # region config vars specific to freeze trainer
        self.conf_train = conf_train
        self._val_top1_acc = conf_train['val_top1_acc_threshold']
        # endregion

    @overrides
    def post_epoch(self, train_dl: DataLoader, val_dl: Optional[DataLoader]) -> None:
        super().post_epoch(train_dl, val_dl)

        # if current validation accuracy is above the threshold,
        # freeze everything other than the last layer
        best_val_top1_avg = self._metrics.best_val_top1()
        if best_val_top1_avg >= self._val_top1_acc:
            self.freeze_but_last_layer()

    def freeze_but_last_layer(self) -> None:
        # first freeze all parameters
        for param in self.model.parameters():
            param.requires_grad = False

        # NOTE: assumption here is that the last layer has the word
        # 'logits' in its name string,
        # e.g. logits_op._op.weight, logits_op._op.bias
        # e.g. _aux_towers.13.logits_op.weight, _aux_towers.13.logits_op.bias
        # TODO: confirm from Shital that this is good!
        for name, param in self.model.named_parameters():
            if 'logits' in name:
                param.requires_grad = True
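Because the 'logits' substring match above is explicitly flagged as an assumption, a quick standalone check (not part of this commit) can confirm what actually stays trainable after the freeze:

from torch import nn

def report_trainable(model: nn.Module) -> None:
    # list parameters that survived the freeze, with a size summary
    trainable = [(name, p.numel()) for name, p in model.named_parameters()
                 if p.requires_grad]
    total = sum(p.numel() for p in model.parameters())
    kept = sum(count for _, count in trainable)
    print(f'trainable params: {kept}/{total}')
    for name, count in trainable:
        print(f'  {name}: {count}')

# usage after FreezeTrainer.freeze_but_last_layer() has run:
#   report_trainable(trainer.model)
# expect only names containing 'logits' to appear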

View file

@@ -196,7 +196,7 @@ class Trainer(EnforceOverrides):
     if self._metrics.epochs() % self._validation_freq == 0 or \
        self._metrics.epochs() >= self._epochs: # last epoch
-        # these asserts makes sure train and val are not ovrlapiing
+        # these asserts make sure train and val are not overlapping
         # assert train_dl.sampler.epoch == val_dl.sampler.epoch
         # tidx = list(train_dl.sampler)
         # vidx = list(val_dl.sampler)
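For reference, the condition in this hunk runs validation every _validation_freq-th epoch and unconditionally on the last one. Assuming epochs() counts completed epochs starting at 1, a two-line check shows which epochs validate (values illustrative):

validation_freq, total_epochs = 5, 12  # illustrative values
validated = [e for e in range(1, total_epochs + 1)
             if e % validation_freq == 0 or e >= total_epochs]
print(validated)  # -> [5, 10, 12]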