Added naswotrain_evaluator.py.

This commit is contained in:
Debadeepta Dey 2020-12-11 16:41:30 -08:00 коммит произвёл Gustavo Rosa
Родитель 17875f60be
Коммит d61e33b753
2 изменённых файлов: 55 добавлений и 0 удалений

Просмотреть файл

@ -0,0 +1,44 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from archai.nas.evaluater import Evaluater
from typing import Optional, Tuple
import importlib
import sys
import string
import os
import torch
from torch import nn
from torch.utils.data.dataloader import DataLoader
from overrides import overrides, EnforceOverrides
from archai.common.trainer import Trainer
from archai.common.config import Config
from archai.common.common import logger
from archai.datasets import data
from archai.nas.model_desc import ModelDesc
from archai.nas.model_desc_builder import ModelDescBuilder
from archai.nas import nas_utils
from archai.common import ml_utils, utils
from archai.common.metrics import EpochMetrics, Metrics
from archai.nas.model import Model
from archai.common.checkpoint import CheckPoint
from .naswotrain_trainer import NaswotrainTrainer
class NaswotrainEvaluator(Evaluater, EnforceOverrides):
@overrides
def train_model(self, conf_train:Config, model:nn.Module,
checkpoint:Optional[CheckPoint])->Metrics:
conf_loader = conf_train['loader']
conf_train = conf_train['trainer']
# get data
train_dl, test_dl = self.get_data(conf_loader)
trainer = NaswotrainTrainer(conf_train, model, checkpoint)
train_metrics = trainer.fit(train_dl, test_dl)
return train_metrics

Просмотреть файл

@ -15,6 +15,7 @@ from archai.nas.evaluater import Evaluater, EvalResult
from archai.common.common import get_expdir, logger
from archai.algos.random.random_model_desc_builder import RandomModelDescBuilder
from archai.algos.naswotrain.naswotrain_evaluator import NaswotrainEvaluator
from .freeze_evaluator import FreezeEvaluator
class FreezeExperimentRunner(ExperimentRunner):
@ -28,12 +29,22 @@ class FreezeExperimentRunner(ExperimentRunner):
@overrides
def run_eval(self, conf_eval:Config)->EvalResult:
# without training architecture evaluation score
# ---------------------------------------
logger.pushd('naswotrain_evaluate')
naswotrain_evaler = NaswotrainEvaluator()
naswotrain_eval_result = naswotrain_evaler.evaluate(conf_eval, model_desc_builder=self.model_desc_builder())
logger.popd()
# regular evaluation of the architecture
# --------------------------------------
reg_eval_result = None
if conf_eval['trainer']['proxynas']['train_regular']:
evaler = self.evaluater()
reg_eval_result = evaler.evaluate(conf_eval, model_desc_builder=self.model_desc_builder())
# freeze train evaluation of the architecture
# -------------------------------------------
# change relevant parts of conv_eval to ensure that freeze_evaler
# doesn't resume from checkpoints created by evaler and saves
# models to different names as well