зеркало из https://github.com/microsoft/archai.git
plugable finalizers
This commit is contained in:
Родитель
8ff553fd36
Коммит
356af62d47
|
@ -5,13 +5,14 @@ import os
|
|||
|
||||
from overrides import EnforceOverrides
|
||||
|
||||
from .cell_builder import CellBuilder
|
||||
from .arch_trainer import TArchTrainer
|
||||
from ..common.common import common_init
|
||||
from ..common import utils
|
||||
from ..common.config import Config
|
||||
from . import evaluate
|
||||
from .search import Search
|
||||
from archai.nas.cell_builder import CellBuilder
|
||||
from archai.nas.arch_trainer import TArchTrainer
|
||||
from archai.common.common import common_init
|
||||
from archai.common import utils
|
||||
from archai.common.config import Config
|
||||
from archai.nas import evaluate
|
||||
from archai.nas.search import Search
|
||||
from archai.nas.finalizers import Finalizers
|
||||
|
||||
|
||||
class ExperimentRunner(ABC, EnforceOverrides):
|
||||
|
@ -29,8 +30,9 @@ class ExperimentRunner(ABC, EnforceOverrides):
|
|||
def _run_search(self, conf_search:Config)->None:
|
||||
cell_builder = self.cell_builder()
|
||||
trainer_class = self.trainer_class()
|
||||
finalizers = self.finalizers()
|
||||
|
||||
search = Search(conf_search, cell_builder, trainer_class)
|
||||
search = Search(conf_search, cell_builder, trainer_class, finalizers)
|
||||
search.generate_pareto()
|
||||
|
||||
def _init(self, suffix:str)->Config:
|
||||
|
@ -88,4 +90,7 @@ class ExperimentRunner(ABC, EnforceOverrides):
|
|||
|
||||
@abstractmethod
|
||||
def trainer_class(self)->TArchTrainer:
|
||||
pass
|
||||
pass
|
||||
|
||||
def finalizers(self)->Finalizers:
|
||||
return Finalizers()
|
|
@ -87,7 +87,7 @@ class SearchResult:
|
|||
|
||||
class Search:
|
||||
def __init__(self, conf_search:Config, cell_builder:Optional[CellBuilder],
|
||||
trainer_class:TArchTrainer) -> None:
|
||||
trainer_class:TArchTrainer, finalizers:Finalizers) -> None:
|
||||
# region config vars
|
||||
conf_checkpoint = conf_search['checkpoint']
|
||||
resume = conf_search['resume']
|
||||
|
@ -113,6 +113,7 @@ class Search:
|
|||
|
||||
self.cell_builder = cell_builder
|
||||
self.trainer_class = trainer_class
|
||||
self.finalizers = finalizers
|
||||
self._data_cache = {}
|
||||
self._parito_filepath = utils.full_path(pareto_summary_filename)
|
||||
self._checkpoint = nas_utils.create_checkpoint(conf_checkpoint, resume)
|
||||
|
@ -330,14 +331,13 @@ class Search:
|
|||
trainer = Trainer(conf_trainer, model, checkpoint=None)
|
||||
train_metrics = trainer.fit(train_dl, val_dl)
|
||||
|
||||
metrics_stats = Search._create_metrics_stats(model, train_metrics)
|
||||
metrics_stats = Search._create_metrics_stats(model, train_metrics, self.finalizers)
|
||||
|
||||
logger.popd()
|
||||
return metrics_stats
|
||||
|
||||
@staticmethod
|
||||
def _create_metrics_stats(model:Model, train_metrics:Metrics)->MetricsStats:
|
||||
finalizers = Finalizers()
|
||||
def _create_metrics_stats(model:Model, train_metrics:Metrics, finalizers:Finalizers)->MetricsStats:
|
||||
finalized = finalizers.finalize_model(model, restore_device=False)
|
||||
# model stats is doing some hooks so do it last
|
||||
model_stats = tw.ModelStats(model, [1,3,32,32],# TODO: remove this hard coding
|
||||
|
@ -362,7 +362,7 @@ class Search:
|
|||
arch_trainer = self.trainer_class(self.conf_train, model, checkpoint=None)
|
||||
train_metrics = arch_trainer.fit(train_dl, val_dl)
|
||||
|
||||
metrics_stats = Search._create_metrics_stats(model, train_metrics)
|
||||
metrics_stats = Search._create_metrics_stats(model, train_metrics, self.finalizers)
|
||||
found_desc = metrics_stats.model_desc
|
||||
else: # if no trainer needed, for example, for random search
|
||||
found_desc = model_desc
|
||||
|
|
Загрузка…
Ссылка в новой задаче