This commit is contained in:
Shital Shah 2020-05-06 01:15:00 -07:00
Parent 8ff553fd36
Commit 356af62d47
2 changed files: 19 additions and 14 deletions

View file

@@ -5,13 +5,14 @@ import os
 from overrides import EnforceOverrides
-from .cell_builder import CellBuilder
-from .arch_trainer import TArchTrainer
-from ..common.common import common_init
-from ..common import utils
-from ..common.config import Config
-from . import evaluate
-from .search import Search
+from archai.nas.cell_builder import CellBuilder
+from archai.nas.arch_trainer import TArchTrainer
+from archai.common.common import common_init
+from archai.common import utils
+from archai.common.config import Config
+from archai.nas import evaluate
+from archai.nas.search import Search
+from archai.nas.finalizers import Finalizers
 class ExperimentRunner(ABC, EnforceOverrides):
@@ -29,8 +30,9 @@ class ExperimentRunner(ABC, EnforceOverrides):
     def _run_search(self, conf_search:Config)->None:
         cell_builder = self.cell_builder()
         trainer_class = self.trainer_class()
+        finalizers = self.finalizers()
-        search = Search(conf_search, cell_builder, trainer_class)
+        search = Search(conf_search, cell_builder, trainer_class, finalizers)
         search.generate_pareto()

     def _init(self, suffix:str)->Config:
@@ -88,4 +90,7 @@ class ExperimentRunner(ABC, EnforceOverrides):
     @abstractmethod
     def trainer_class(self)->TArchTrainer:
-        pass
+        pass
+
+    def finalizers(self)->Finalizers:
+        return Finalizers()
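The net effect in this file: imports switch from relative to absolute `archai.*` paths, and `ExperimentRunner` gains a `finalizers()` hook that defaults to `Finalizers()` and is passed into `Search` by `_run_search`. Below is a minimal sketch of how a concrete runner could override the new hook; it assumes the module path `archai.nas.experiment_runner`, and `MyFinalizers`/`MyExperimentRunner` are hypothetical names, not part of this commit.

# Illustrative sketch only (not from this commit): a concrete runner that
# overrides the new finalizers() hook.
from overrides import overrides

from archai.nas.cell_builder import CellBuilder
from archai.nas.arch_trainer import TArchTrainer, ArchTrainer
from archai.nas.finalizers import Finalizers
from archai.nas.experiment_runner import ExperimentRunner   # assumed module path

class MyFinalizers(Finalizers):
    # Hypothetical custom finalization strategy; base behavior is reused here.
    pass

class MyExperimentRunner(ExperimentRunner):
    @overrides
    def cell_builder(self)->CellBuilder:
        return CellBuilder()        # assumed: the default cell builder suffices

    @overrides
    def trainer_class(self)->TArchTrainer:
        return ArchTrainer          # assumed: plain architecture trainer

    @overrides
    def finalizers(self)->Finalizers:
        # New extension point added by this commit: _run_search() passes
        # this object into Search instead of Search creating its own.
        return MyFinalizers()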

View file

@@ -87,7 +87,7 @@ class SearchResult:
 class Search:
     def __init__(self, conf_search:Config, cell_builder:Optional[CellBuilder],
-                 trainer_class:TArchTrainer) -> None:
+                 trainer_class:TArchTrainer, finalizers:Finalizers) -> None:
         # region config vars
         conf_checkpoint = conf_search['checkpoint']
         resume = conf_search['resume']
@@ -113,6 +113,7 @@ class Search:
         self.cell_builder = cell_builder
         self.trainer_class = trainer_class
+        self.finalizers = finalizers
         self._data_cache = {}
         self._parito_filepath = utils.full_path(pareto_summary_filename)
         self._checkpoint = nas_utils.create_checkpoint(conf_checkpoint, resume)
@@ -330,14 +331,13 @@ class Search:
         trainer = Trainer(conf_trainer, model, checkpoint=None)
         train_metrics = trainer.fit(train_dl, val_dl)
-        metrics_stats = Search._create_metrics_stats(model, train_metrics)
+        metrics_stats = Search._create_metrics_stats(model, train_metrics, self.finalizers)
         logger.popd()
         return metrics_stats

     @staticmethod
-    def _create_metrics_stats(model:Model, train_metrics:Metrics)->MetricsStats:
-        finalizers = Finalizers()
+    def _create_metrics_stats(model:Model, train_metrics:Metrics, finalizers:Finalizers)->MetricsStats:
         finalized = finalizers.finalize_model(model, restore_device=False)
         # model stats is doing some hooks so do it last
         model_stats = tw.ModelStats(model, [1,3,32,32],# TODO: remove this hard coding
@@ -362,7 +362,7 @@ class Search:
             arch_trainer = self.trainer_class(self.conf_train, model, checkpoint=None)
             train_metrics = arch_trainer.fit(train_dl, val_dl)
-            metrics_stats = Search._create_metrics_stats(model, train_metrics)
+            metrics_stats = Search._create_metrics_stats(model, train_metrics, self.finalizers)
             found_desc = metrics_stats.model_desc
         else: # if no trainer needed, for example, for random search
             found_desc = model_desc
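In `Search`, the `Finalizers` instance is now injected through the constructor and threaded into `_create_metrics_stats`, rather than being constructed inside that static method, so the caller controls the finalization strategy. A hedged sketch of the resulting call pattern follows; the config path and key names are assumptions, and `None` is used for the optional cell builder and trainer, which the diff's "random search" branch suggests is allowed.

# Hedged sketch (not from this commit) of constructing Search with an
# explicitly injected Finalizers instance.
from archai.common.config import Config
from archai.nas.finalizers import Finalizers
from archai.nas.search import Search          # assumed module path

conf = Config('confs/my_search.yaml')         # hypothetical config file
conf_search = conf['nas']['search']           # hypothetical section layout

# Finalizers is supplied by the caller instead of being created inside
# _create_metrics_stats(), so an alternative strategy can be swapped in here.
search = Search(conf_search, cell_builder=None, trainer_class=None,
                finalizers=Finalizers())
search.generate_pareto()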