зеркало из https://github.com/microsoft/archai.git
Added naswotrain_evaluator.py.
This commit is contained in:
Родитель
17875f60be
Коммит
d61e33b753
|
@ -0,0 +1,44 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from archai.nas.evaluater import Evaluater
|
||||
from typing import Optional, Tuple
|
||||
import importlib
|
||||
import sys
|
||||
import string
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.utils.data.dataloader import DataLoader
|
||||
|
||||
from overrides import overrides, EnforceOverrides
|
||||
|
||||
from archai.common.trainer import Trainer
|
||||
from archai.common.config import Config
|
||||
from archai.common.common import logger
|
||||
from archai.datasets import data
|
||||
from archai.nas.model_desc import ModelDesc
|
||||
from archai.nas.model_desc_builder import ModelDescBuilder
|
||||
from archai.nas import nas_utils
|
||||
from archai.common import ml_utils, utils
|
||||
from archai.common.metrics import EpochMetrics, Metrics
|
||||
from archai.nas.model import Model
|
||||
from archai.common.checkpoint import CheckPoint
|
||||
|
||||
from .naswotrain_trainer import NaswotrainTrainer
|
||||
|
||||
|
||||
class NaswotrainEvaluator(Evaluater, EnforceOverrides):
|
||||
@overrides
|
||||
def train_model(self, conf_train:Config, model:nn.Module,
|
||||
checkpoint:Optional[CheckPoint])->Metrics:
|
||||
conf_loader = conf_train['loader']
|
||||
conf_train = conf_train['trainer']
|
||||
|
||||
# get data
|
||||
train_dl, test_dl = self.get_data(conf_loader)
|
||||
|
||||
trainer = NaswotrainTrainer(conf_train, model, checkpoint)
|
||||
train_metrics = trainer.fit(train_dl, test_dl)
|
||||
return train_metrics
|
|
@ -15,6 +15,7 @@ from archai.nas.evaluater import Evaluater, EvalResult
|
|||
from archai.common.common import get_expdir, logger
|
||||
|
||||
from archai.algos.random.random_model_desc_builder import RandomModelDescBuilder
|
||||
from archai.algos.naswotrain.naswotrain_evaluator import NaswotrainEvaluator
|
||||
from .freeze_evaluator import FreezeEvaluator
|
||||
|
||||
class FreezeExperimentRunner(ExperimentRunner):
|
||||
|
@ -28,12 +29,22 @@ class FreezeExperimentRunner(ExperimentRunner):
|
|||
|
||||
@overrides
|
||||
def run_eval(self, conf_eval:Config)->EvalResult:
|
||||
# without training architecture evaluation score
|
||||
# ---------------------------------------
|
||||
logger.pushd('naswotrain_evaluate')
|
||||
naswotrain_evaler = NaswotrainEvaluator()
|
||||
naswotrain_eval_result = naswotrain_evaler.evaluate(conf_eval, model_desc_builder=self.model_desc_builder())
|
||||
logger.popd()
|
||||
|
||||
# regular evaluation of the architecture
|
||||
# --------------------------------------
|
||||
reg_eval_result = None
|
||||
if conf_eval['trainer']['proxynas']['train_regular']:
|
||||
evaler = self.evaluater()
|
||||
reg_eval_result = evaler.evaluate(conf_eval, model_desc_builder=self.model_desc_builder())
|
||||
|
||||
# freeze train evaluation of the architecture
|
||||
# -------------------------------------------
|
||||
# change relevant parts of conv_eval to ensure that freeze_evaler
|
||||
# doesn't resume from checkpoints created by evaler and saves
|
||||
# models to different names as well
|
||||
|
|
Загрузка…
Ссылка в новой задаче