Mirror of https://github.com/microsoft/nni.git

NAS benchmark (stage 2) - space and evaluator (#5380)

Parent: 13028280ae
Commit: 1f6aedc48f
@@ -7,7 +7,7 @@ torch == 1.13.1+cpu ; sys_platform != "darwin"
 torch == 1.13.1 ; sys_platform == "darwin"
 torchvision == 0.14.1+cpu ; sys_platform != "darwin"
 torchvision == 0.14.1 ; sys_platform == "darwin"
-pytorch-lightning >= 1.6.1
+pytorch-lightning >= 1.6.1, < 2.0
 torchmetrics
 lightgbm
 onnx

@@ -4,7 +4,7 @@
 tensorflow
 torch == 1.13.1+cu117
 torchvision == 0.14.1+cu117
-pytorch-lightning >= 1.6.1
+pytorch-lightning >= 1.6.1, < 2.0
 
 # for full-test-compression
 -f https://download.openmmlab.com/mmcv/dist/cu117/torch1.13/index.html

@@ -8,9 +8,9 @@ import shutil
 import subprocess
 import tempfile
 from pathlib import Path
+from shutil import which
 from typing import Optional
 
-import requests
 import tqdm
 
 __all__ = ['NNI_BLOB', 'load_or_download_file', 'upload_file', 'nni_cache_home']

@@ -55,25 +55,44 @@ def load_or_download_file(local_path: str, download_url: str, download: bool = F
     elif download:
         _logger.info('"%s" does not exist. Downloading "%s"', local_path, download_url)
 
-        # Follow download implementation in torchvision:
-        # We deliberately save it in a temp file and move it after
-        # download is complete. This prevents a local working checkpoint
-        # being overridden by a broken download.
         dst_dir = Path(local_path).parent
         dst_dir.mkdir(exist_ok=True, parents=True)
 
-        f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)
-        r = requests.get(download_url, stream=True)
-        total_length: Optional[str] = r.headers.get('content-length')
-        assert total_length is not None, f'Content length is not found in the response of {download_url}'
-        with tqdm.tqdm(total=int(total_length), disable=not progress,
-                       unit='B', unit_scale=True, unit_divisor=1024) as pbar:
-            for chunk in r.iter_content(8192):
-                f.write(chunk)
-                sha256.update(chunk)
-                pbar.update(len(chunk))
-        f.flush()
-        f.close()
+        if which('azcopy') is not None:
+            output_level = []
+            if not progress:
+                output_level = ['--output-level', 'quiet']
+            subprocess.run(['azcopy', 'copy', download_url, local_path] + output_level, check=True)
+
+            # Update hash as a verification
+            with Path(local_path).open('rb') as fr:
+                while True:
+                    chunk = fr.read(8192)
+                    if len(chunk) == 0:
+                        break
+                    sha256.update(chunk)
+
+        else:
+            _logger.info('azcopy is not installed. Fall back to use requests.')
+
+            import requests
+
+            # Follow download implementation in torchvision:
+            # We deliberately save it in a temp file and move it after
+            # download is complete. This prevents a local working checkpoint
+            # being overridden by a broken download.
+            f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)
+            r = requests.get(download_url, stream=True)
+            total_length: Optional[str] = r.headers.get('content-length')
+            assert total_length is not None, f'Content length is not found in the response of {download_url}'
+            with tqdm.tqdm(total=int(total_length), disable=not progress,
+                           unit='B', unit_scale=True, unit_divisor=1024) as pbar:
+                for chunk in r.iter_content(8192):
+                    f.write(chunk)
+                    sha256.update(chunk)
+                    pbar.update(len(chunk))
+            f.flush()
+            f.close()
     else:
         raise FileNotFoundError(
             'Download is not enabled, and file does not exist: {}. Please set download=True.'.format(local_path)

@@ -81,7 +100,8 @@ def load_or_download_file(local_path: str, download_url: str, download: bool = F
 
     digest = sha256.hexdigest()
     if not digest.startswith(hash_prefix):
-        raise RuntimeError('Invalid hash value (expected "{}", got "{}")'.format(hash_prefix, digest))
+        raise RuntimeError(f'Invalid hash value (expected "{hash_prefix}", got "{digest}") for {local_path}. '
+                           'Please delete the file and try re-downloading.')
 
     if f is not None:
         shutil.move(f.name, local_path)

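A minimal sketch (separate from the diff) of how this helper might be invoked; the import path and both arguments below are assumptions, only the names in __all__ are confirmed by this diff:

    from nni.common.blob_utils import load_or_download_file   # assumed module path

    # azcopy is preferred when installed; otherwise the requests-based fallback above runs.
    load_or_download_file(
        '/tmp/cache/nasbench201.db',                               # hypothetical local path
        'https://nni.blob.core.windows.net/cache/nasbench201.db',  # hypothetical blob URL
        download=True,
        progress=True,
    )
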
@@ -7,6 +7,6 @@ try:
 except ImportError:
     warnings.warn('peewee is not installed. Please install it to use NAS benchmarks.')
 
-# from .evaluator import *
-# from .space import *
+from .evaluator import *
+from .space import *
 from .utils import load_benchmark, download_benchmark

@@ -0,0 +1,239 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

__all__ = ['BenchmarkEvaluator', 'NasBench101Benchmark', 'NasBench201Benchmark']

import itertools
import logging
import random
import warnings
from typing import Any, cast

import nni
from nni.mutable import Sample, Mutable, LabeledMutable, Categorical, CategoricalMultiple, label_scope
from nni.nas.evaluator import Evaluator
from nni.nas.space import ExecutableModelSpace

from .space import SlimBenchmarkSpace

_logger = logging.getLogger(__name__)


def _report_intermediates_and_final(query_result: list[Any], metric: str, query: str, scale: float = 1.) -> float:
    """Convert benchmark results from the database to results reported to NNI.

    Utility function for :meth:`BenchmarkEvaluator.evaluate`.
    """
    if not query_result:
        raise ValueError('Invalid query. Results from benchmark is empty: ' + query)
    if len(query_result) > 1:
        query_result = random.choice(query_result)
    else:
        query_result = query_result[0]
    query_dict = cast(dict, query_result)
    for i in query_dict.get('intermediates', []):
        if i[metric] is not None:
            nni.report_intermediate_result(i[metric] * scale)
    nni.report_final_result(query_dict[metric] * scale)
    return query_dict[metric]
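
# Illustrative shape of one query-result row, inferred from the usage above
# (the metric key depends on what is queried):
#
#   {'valid_acc': 91.2, 'intermediates': [{'valid_acc': 45.0}, {'valid_acc': 60.1}, ...]}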


def _common_label_scope(labels: list[str]) -> str:
    """Find the longest common prefix of all labels.

    The prefix must end with ``/``.
    """
    if not labels:
        return ''
    for i in range(len(labels[0]) - 1, -1, -1):
        if labels[0][i] == '/' and all(s.startswith(labels[0][:i + 1]) for s in labels):
            return labels[0][:i + 1]
    return ''
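
# For instance (illustrative, not part of the commit):
#   _common_label_scope(['model/cell/op1', 'model/cell/op2'])  ->  'model/cell/'
#   _common_label_scope(['model/op1', 'net/op2'])              ->  ''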


def _strip_common_label_scope(sample: Sample) -> Sample:
    """Strip the common label scope from the sample."""
    scope_name = _common_label_scope(list(sample))
    if not scope_name:
        return sample
    return {k[len(scope_name):]: v for k, v in sample.items()}


def _sorted_dict(sample: Sample) -> Sample:
    """Sort the keys of a dict."""
    return dict(sorted(sample.items()))


class BenchmarkEvaluator(Evaluator):
    """A special kind of evaluator that does not run real training, but queries a database."""

    @classmethod
    def default_space(cls) -> SlimBenchmarkSpace:
        """Return the default search space benchmarked by this evaluator.

        Subclasses should override this.
        """
        raise NotImplementedError()

    def validate_space(self, space: Mutable) -> dict[str, LabeledMutable]:
        """Validate the search space, raising an exception if it is invalid. Return the validated space.

        By default, it cross-checks against :meth:`default_space` and returns the default space.
        Differences in common scope names are ignored.

        The default implementation should work for most cases,
        but subclasses can override this method for looser or tighter validation.
        """
        current_space = space.simplify()

        scope_name = _common_label_scope(list(current_space))
        if not scope_name:
            default_space = self.default_space()
        else:
            with label_scope(scope_name.rstrip('/')):
                default_space = self.default_space()

        if SlimBenchmarkSpace(current_space) != default_space:
            raise ValueError(f'Expect space to be {default_space}, got {current_space}')

        return current_space

    def evaluate(self, sample: Sample) -> Any:
        """:meth:`evaluate` receives a sample and returns a float score.
        It also reports intermediate and final results through the NNI trial API.

        Necessary format conversions and database queries should be done in this method.

        It is the main interface of this class. Subclasses should override this.
        """
        raise NotImplementedError()

    def _execute(self, model: ExecutableModelSpace) -> Any:
        """Execute the model with the sample."""
        from .space import BenchmarkModelSpace
        if not isinstance(model, BenchmarkModelSpace):
            warnings.warn('It would be better to use BenchmarkModelSpace for benchmarking to avoid '
                          'unnecessary overhead and silent mistakes.')
        if model.sample is None:
            raise ValueError('Model cannot be evaluated because it has not been sampled yet.')

        return self.evaluate(model.sample)
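
# A sketch (not part of the commit) of the minimal shape of a subclass;
# `query_my_benchmark` is a hypothetical lookup helper:
#
#     class MyBenchmark(BenchmarkEvaluator):
#         @classmethod
#         def default_space(cls) -> SlimBenchmarkSpace:
#             return SlimBenchmarkSpace({'op': Categorical(['a', 'b'], label='op')})
#
#         def evaluate(self, sample: Sample) -> Any:
#             sample = _sorted_dict(_strip_common_label_scope(sample))
#             return _report_intermediates_and_final(
#                 list(query_my_benchmark(sample)), 'valid_acc', str(sample), .01)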


class NasBench101Benchmark(BenchmarkEvaluator):
    """Benchmark evaluator for NAS-Bench-101.

    Parameters
    ----------
    num_epochs
        Queried ``num_epochs``.
    metric
        Queried metric.
    include_intermediates
        Whether to report intermediate results.

    See Also
    --------
    nni.nas.benchmark.nasbench101.query_nb101_trial_stats
    nni.nas.benchmark.nasbench101.Nb101TrialConfig
    """

    def __init__(self, num_epochs: int = 108, metric: str = 'valid_acc', include_intermediates: bool = False) -> None:
        super().__init__()
        self.metric = metric
        self.include_intermediates = include_intermediates
        self.num_epochs = num_epochs

    @classmethod
    def default_space(cls) -> SlimBenchmarkSpace:
        from nni.nas.hub.pytorch.modules.nasbench101 import NasBench101CellConstraint, NasBench101Cell
        op_candidates = ['conv3x3-bn-relu', 'conv1x1-bn-relu', 'maxpool3x3']

        # For readability, expand it here.
        num_nodes = NasBench101Cell._num_nodes_discrete(7)
        # num_nodes = Categorical([2, 3, 4, 5, 6, 7], label='num_nodes')
        ops = [
            Categorical(op_candidates, label='op1'),
            Categorical(op_candidates, label='op2'),
            Categorical(op_candidates, label='op3'),
            Categorical(op_candidates, label='op4'),
            Categorical(op_candidates, label='op5')
        ]
        inputs = [
            CategoricalMultiple([0], n_chosen=None, label='input1'),
            CategoricalMultiple([0, 1], n_chosen=None, label='input2'),
            CategoricalMultiple([0, 1, 2], n_chosen=None, label='input3'),
            CategoricalMultiple([0, 1, 2, 3], n_chosen=None, label='input4'),
            CategoricalMultiple([0, 1, 2, 3, 4], n_chosen=None, label='input5'),
            CategoricalMultiple([0, 1, 2, 3, 4, 5], n_chosen=None, label='input6')
        ]
        constraint = NasBench101CellConstraint(9, num_nodes, ops, inputs)

        return SlimBenchmarkSpace({
            mutable.label: mutable for mutable in itertools.chain([num_nodes], ops, inputs, [constraint])
        })

    def evaluate(self, sample: Sample) -> Any:
        sample = _sorted_dict(_strip_common_label_scope(sample))
        _logger.debug('NasBench101 sample submitted to query: %s', sample)

        from nni.nas.benchmark.nasbench101 import query_nb101_trial_stats
        query = query_nb101_trial_stats(sample['final'], self.num_epochs, include_intermediates=self.include_intermediates)
        return _report_intermediates_and_final(
            list(query), self.metric, str(sample), .01
        )


class NasBench201Benchmark(BenchmarkEvaluator):
    """Benchmark evaluator for NAS-Bench-201.

    Parameters
    ----------
    num_epochs
        Queried ``num_epochs``.
    dataset
        Queried ``dataset``.
    metric
        Queried metric.
    include_intermediates
        Whether to report intermediate results.

    See Also
    --------
    nni.nas.benchmark.nasbench201.query_nb201_trial_stats
    nni.nas.benchmark.nasbench201.Nb201TrialConfig
    """

    def __init__(self, num_epochs: int = 200, dataset: str = 'cifar100', metric: str = 'valid_acc',
                 include_intermediates: bool = False) -> None:
        super().__init__()
        self.metric = metric
        self.dataset = dataset
        self.include_intermediates = include_intermediates
        self.num_epochs = num_epochs

    @classmethod
    def default_space(cls) -> SlimBenchmarkSpace:
        operations = ['none', 'skip_connect', 'conv_1x1', 'conv_3x3', 'avg_pool_3x3']
        ops = [
            Categorical(operations, label='0_1'),
            Categorical(operations, label='0_2'),
            Categorical(operations, label='1_2'),
            Categorical(operations, label='0_3'),
            Categorical(operations, label='1_3'),
            Categorical(operations, label='2_3')
        ]
        return SlimBenchmarkSpace({op.label: op for op in ops})

    def evaluate(self, sample: Sample) -> Any:
        sample = _sorted_dict(_strip_common_label_scope(sample))
        _logger.debug('NasBench201 sample submitted to query: %s', sample)

        from nni.nas.benchmark.nasbench201 import query_nb201_trial_stats
        query = query_nb201_trial_stats(sample, self.num_epochs, self.dataset, include_intermediates=self.include_intermediates)
        return _report_intermediates_and_final(
            list(query), self.metric, str(sample), .01
        )

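A minimal sketch (separate from the diff) of driving an evaluator by hand, assuming the benchmark database is already downloaded; that `random(memo=...)` fills a dict with labeled choices is an assumption about the nni.mutable API:

    evaluator = NasBench201Benchmark(num_epochs=200, dataset='cifar100')
    sample = {}
    evaluator.default_space().random(memo=sample)  # assumed API: fill `sample` with labeled choices
    score = evaluator.evaluate(sample)             # also reports results through the NNI trial API
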
@@ -0,0 +1,91 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

__all__ = ['BenchmarkModelSpace', 'SlimBenchmarkSpace']

import logging
from typing import TYPE_CHECKING, overload

from nni.mutable import MutableDict
from nni.nas.space import RawFormatModelSpace, BaseModelSpace

if TYPE_CHECKING:
    from .evaluator import BenchmarkEvaluator

_logger = logging.getLogger(__name__)


class SlimBenchmarkSpace(BaseModelSpace, MutableDict):
    """Example model space without deep learning frameworks.

    When constructing this, the dict should've been already simplified and validated.

    It could look like::

        {
            'layer1': nni.choice('layer1', ['a', 'b', 'c']),
            'layer2': nni.choice('layer2', ['d', 'e', 'f']),
        }
    """


class BenchmarkModelSpace(RawFormatModelSpace):
    """
    Model space that is specialized for benchmarking.

    We recommend using this model space for benchmarking, for its validation and efficiency.

    Parameters
    ----------
    model_space
        If not provided, it will be set to the default model space of the evaluator.
    evaluator
        Evaluator that will be used to benchmark the space.

    Examples
    --------
    Can be either::

        BenchmarkModelSpace(evaluator)

    or::

        BenchmarkModelSpace(pytorch_model_space, evaluator)

    In the case where the model space is provided, it will be validated by the evaluator and must be a match.
    """

    @overload
    def __init__(self, model_space: BenchmarkEvaluator):
        ...

    @overload
    def __init__(self, model_space: BaseModelSpace):
        ...

    @overload
    def __init__(self, model_space: None, evaluator: BenchmarkEvaluator):
        ...

    def __init__(self, model_space: BaseModelSpace | BenchmarkEvaluator | None, evaluator: BenchmarkEvaluator | None = None):
        from .evaluator import BenchmarkEvaluator

        if isinstance(model_space, BenchmarkEvaluator):
            assert evaluator is None
            evaluator = model_space
            model_space = None

        if not isinstance(evaluator, BenchmarkEvaluator):
            raise ValueError(f'Expect evaluator to be BenchmarkEvaluator, got {evaluator}')
        if model_space is None:
            _logger.info('Model space is not set. Using default model space from evaluator: %s', evaluator)
            model_space = evaluator.default_space()
        else:
            evaluator.validate_space(model_space)

        super().__init__(model_space, evaluator)

    def executable_model(self):
        raise RuntimeError(f'{self.__class__.__name__} is not executable. Please use `sample` instead.')

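A short sketch (separate from the diff) of the evaluator-only construction path described in the docstring above:

    from nni.nas.benchmark import NasBench201Benchmark, BenchmarkModelSpace

    space = BenchmarkModelSpace(NasBench201Benchmark())  # default space taken from the evaluator
    model = space.random()                               # a frozen model carrying only the sample
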
@@ -235,6 +235,7 @@ class RegularizedEvolution(Strategy):
         if event.model in self._running_models:
             self._running_models.remove(event.model)
         if event.model.metric is not None:
             _logger.info('[Metric] %f Sample: %s', event.model.metric, event.model.sample)
             # Even if it fails, as long as it has a metric, we add it to the population.
+            assert event.model.sample is not None
             self._population.append(Individual(event.model.sample, event.model.metric))

@@ -40,6 +40,11 @@ stages:
 
   - template: templates/download-test-data.yml
 
+  - script: |
+      cd test
+      python algo/nas/benchmark/prepare.py
+    displayName: Prepare NAS benchmark
+
   - script: |
      cd test
      python -m pytest algo/nas
@@ -47,7 +52,7 @@ stages:
 
 - job: windows
   pool: nni-it-1es-windows
-  timeoutInMinutes: 90
+  timeoutInMinutes: 120
 
   steps:
   - template: templates/check-gpu-status.yml

@@ -0,0 +1,19 @@
import sys
import pytest

import nni
from nni.nas.benchmark import download_benchmark

def prepare_benchmark():
    for benchmark in ['nasbench101', 'nasbench201']:
        download_benchmark(benchmark)

@pytest.fixture(autouse=True)
def reset_cached_parameter():
    if sys.platform != 'linux':
        pytest.skip('Benchmark tests are too slow on Windows.')
    nni.trial._params = None
    nni.trial.overwrite_intermediate_seq(0)

if __name__ == '__main__':
    prepare_benchmark()

@@ -0,0 +1,52 @@
from nni.nas.benchmark import *
from nni.nas.execution import SequentialExecutionEngine
from nni.nas.strategy import *

from nni.nas.hub.pytorch import NasBench101, NasBench201

from .prepare import *

# TODO: tune RL and make it work

def test_nasbench101_with_rl():
    pytorch_space = NasBench101()
    benchmark = NasBench101Benchmark()
    exec_space = BenchmarkModelSpace.from_model(pytorch_space, benchmark)

    engine = SequentialExecutionEngine(max_model_count=200)
    strategy = PolicyBasedRL(reward_for_invalid=0)
    strategy(exec_space, engine)
    assert list(strategy.list_models(sort=True, limit=1))[0].metric > 0.94


def test_nasbench201_with_rl():
    pytorch_space = NasBench201()
    benchmark = NasBench201Benchmark()
    exec_space = BenchmarkModelSpace.from_model(pytorch_space, benchmark)

    engine = SequentialExecutionEngine(max_model_count=200)
    strategy = PolicyBasedRL()
    strategy(exec_space, engine)
    assert list(strategy.list_models(sort=True, limit=1))[0].metric > 0.7


def test_nasbench101_with_evo():
    pytorch_space = NasBench101()
    benchmark = NasBench101Benchmark()
    exec_space = BenchmarkModelSpace.from_model(pytorch_space, benchmark)

    engine = SequentialExecutionEngine(max_model_count=200)
    strategy = RegularizedEvolution(population_size=50, sample_size=25)
    strategy(exec_space, engine)
    assert list(strategy.list_models(sort=True, limit=1))[0].metric > 0.945


def test_nasbench201_with_evo():
    pytorch_space = NasBench201()
    benchmark = NasBench201Benchmark()
    exec_space = BenchmarkModelSpace.from_model(pytorch_space, benchmark)

    engine = SequentialExecutionEngine(max_model_count=200)
    strategy = RegularizedEvolution(population_size=50, sample_size=25)
    strategy(exec_space, engine)
    assert list(strategy.list_models(sort=True, limit=1))[0].metric > 0.73

@@ -0,0 +1,56 @@
from nni.mutable import SampleValidationError
from nni.nas.benchmark import *

from nni.nas.hub.pytorch import NasBench101, NasBench201

from .prepare import *


def test_nasbench101():
    benchmark = NasBench101Benchmark()
    exec_space = BenchmarkModelSpace(benchmark)
    model = exec_space.default()
    with benchmark.mock_runtime(model):
        model.execute()
    assert 0 < model.metric < 1

    good = bad = 0
    for _ in range(30):
        try:
            model = exec_space.random()
            with benchmark.mock_runtime(model):
                model.execute()
            assert 0 < model.metric < 1
            good += 1
        except SampleValidationError:
            bad += 1
    assert good > 0 and bad > 0

    pytorch_space = NasBench101()
    exec_space = BenchmarkModelSpace.from_model(pytorch_space, benchmark)
    model = exec_space.default()
    with benchmark.mock_runtime(model):
        model.execute()
    assert 0 < model.metric < 1


def test_nasbench201():
    benchmark = NasBench201Benchmark()
    exec_space = BenchmarkModelSpace(benchmark)
    model = exec_space.default()
    with benchmark.mock_runtime(model):
        model.execute()
    assert 0 < model.metric < 1

    for _ in range(30):
        model = exec_space.random()
        with benchmark.mock_runtime(model):
            model.execute()
        assert 0 < model.metric < 1

    pytorch_space = NasBench201()
    exec_space = BenchmarkModelSpace.from_model(pytorch_space, benchmark)
    model = exec_space.random()
    with benchmark.mock_runtime(model):
        model.execute()
    assert 0 < model.metric < 1