зеркало из https://github.com/microsoft/nni.git
NAS strategy (stage 2) - utils and bruteforce strategy (#5367)
This commit is contained in:
Родитель
808ac9a3fc
Коммит
4b6c0be045
|
@ -1,138 +1,235 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import copy
|
||||
import itertools
|
||||
from __future__ import annotations
|
||||
|
||||
__all__ = ['GridSearch', 'Random']
|
||||
|
||||
import logging
|
||||
import random
|
||||
import time
|
||||
from typing import Any, Dict, List, Sequence, Optional
|
||||
import warnings
|
||||
from typing import Any, Iterable
|
||||
|
||||
from nni.nas.execution import submit_models, query_available_resources, budget_exhausted
|
||||
from nni.nas.mutable import InvalidMutation, Sampler
|
||||
from .base import BaseStrategy
|
||||
from .utils import dry_run_for_search_space, get_targeted_model, filter_model
|
||||
from numpy.random import RandomState
|
||||
|
||||
from nni.mutable import Sample, SampleValidationError
|
||||
from nni.nas.space import MutationSampler, ExecutableModelSpace, Mutator
|
||||
|
||||
from .base import Strategy
|
||||
from .utils import DeduplicationHelper, RetrySamplingHelper
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def grid_generator(search_space: Dict[Any, List[Any]], shuffle=True):
    """Yield every combination of candidates in ``search_space`` as a dict.

    Parameters
    ----------
    search_space
        Mapping from a key to the list of candidate values for that key.
    shuffle
        If true, each candidate list is shuffled with the global ``random`` module
        before the Cartesian product is taken. The lists are deep-copied first,
        so the caller's lists are never reordered.

    Yields
    ------
    dict
        One ``{key: chosen_value}`` mapping per element of the Cartesian product.
    """
    names = list(search_space.keys())
    candidate_lists = copy.deepcopy(list(search_space.values()))
    if shuffle:
        for candidates in candidate_lists:
            random.shuffle(candidates)
    for combination in itertools.product(*candidate_lists):
        yield dict(zip(names, combination))
|
||||
|
||||
|
||||
def random_generator(search_space: Dict[Any, List[Any]], dedup=True, retries=500):
    """Endlessly yield random combinations drawn from ``search_space``.

    Parameters
    ----------
    search_space
        Mapping from a key to the list of candidate values for that key.
    dedup
        If true, a combination is never yielded twice. Every yielded combination
        is remembered in a set, so the candidates must be hashable in this mode.
    retries
        Number of random draws attempted for one new combination when ``dedup`` is on.
        If all of them hit already-seen combinations, the generator stops.

    Yields
    ------
    dict
        ``{key: randomly_chosen_value}`` for every key of ``search_space``.
    """
    names = list(search_space.keys())
    seen = set()
    pools = copy.deepcopy(list(search_space.values()))
    while True:
        picked = None
        for attempt in range(1, retries + 1):
            picked = [random.choice(pool) for pool in pools]
            if not dedup:
                # Without dedup the very first draw is always accepted.
                break
            picked = tuple(picked)
            if picked not in seen:
                seen.add(picked)
                break
            if attempt == retries:
                _logger.debug('Random generation has run out of patience. There is nothing to search. Exiting.')
                return
        # ``picked`` can only still be None here when ``retries`` is not positive.
        assert picked is not None, 'Retry attempts exhausted.'
        yield dict(zip(names, picked))
|
||||
|
||||
|
||||
class GridSearch(BaseStrategy):
|
||||
class GridSearch(Strategy):
|
||||
"""
|
||||
Traverse the search space and try all the possible combinations one by one.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
shuffle : bool
|
||||
Shuffle the order in a candidate list, so that they are tried in a random order. Default: true.
|
||||
shuffle
|
||||
Shuffle the order in a candidate list, so that they are tried in a random order.
|
||||
Currently, the implementation is a pseudo-random shuffle, which only shuffles the order of every 100 candidates.
|
||||
seed
|
||||
Random seed.
|
||||
"""
|
||||
|
||||
def __init__(self, shuffle=True):
|
||||
self._polling_interval = 2.
|
||||
_shuffle_buffer_size: int = 100
|
||||
_granularity_patience: int = 3 # stop increasing granularity after this many times of no new sample found
|
||||
|
||||
def __init__(self, *, shuffle: bool = True, seed: int | None = None, dedup: bool = True):
|
||||
super().__init__()
|
||||
|
||||
self.shuffle = shuffle
|
||||
|
||||
def run(self, base_model, applied_mutators):
|
||||
search_space = dry_run_for_search_space(base_model, applied_mutators)
|
||||
for sample in grid_generator(search_space, shuffle=self.shuffle):
|
||||
_logger.debug('New model created. Waiting for resource. %s', str(sample))
|
||||
while query_available_resources() <= 0:
|
||||
if budget_exhausted():
|
||||
# Internal only:
|
||||
# Do not try the same configuration twice.
|
||||
# Turning it off might result in duplications when the space has an infinite size,
|
||||
# or the strategy tries to resume from a checkpoint,
|
||||
# but might improve memory efficiency in extreme cases.
|
||||
self._dedup = DeduplicationHelper() if dedup else None
|
||||
self._granularity = 1
|
||||
self._granularity_processed: int | None = None
|
||||
self._no_sample_found_counter = 0
|
||||
self._random_state = RandomState(seed)
|
||||
|
||||
def extra_repr(self) -> str:
    """Return the configuration as ``'shuffle=<bool>, dedup=<bool>'``.

    ``dedup`` is reported as true whenever a deduplication helper is attached.
    """
    dedup_enabled = self._dedup is not None
    return f'shuffle={self.shuffle}, dedup={dedup_enabled}'
|
||||
|
||||
def _grid_generator(self, model_space: ExecutableModelSpace) -> Iterable[ExecutableModelSpace]:
    """Yield models from ``model_space.grid()``, refining the granularity as needed.

    Each pass iterates ``model_space.grid(granularity=self._granularity)`` and yields
    every model that the deduplication helper (if any) has not seen before.
    After a pass:

    * if nothing new was yielded, the no-sample counter is increased, and generation
      stops once it reaches ``_granularity_patience``; otherwise the counter is reset;
    * if ``_space_validation`` reported the space as finite, generation stops;
    * otherwise the granularity is increased by one and another pass starts.

    Nothing is generated at all if the counter has already reached the patience on entry.
    """
    if self._no_sample_found_counter >= self._granularity_patience:
        _logger.info('Patience already run out (%d > %d). Nothing to search.',
                     self._no_sample_found_counter, self._granularity_patience)
        return

    is_finite = self._space_validation(model_space)

    while True:
        yielded_any = False
        for candidate in model_space.grid(granularity=self._granularity):
            # With dedup off everything passes; otherwise the helper decides.
            is_new = self._dedup is None or self._dedup.dedup(candidate.sample)
            if is_new:
                yielded_any = True
                yield candidate

        if yielded_any:
            self._no_sample_found_counter = 0
        else:
            self._no_sample_found_counter += 1
            _logger.info('No new sample found when granularity is %d. Current patience: %d.',
                         self._granularity, self._no_sample_found_counter)
            if self._no_sample_found_counter >= self._granularity_patience:
                _logger.info('No new sample found for %d times. Stop increasing granularity.',
                             self._granularity_patience)
                break

        if is_finite:
            _logger.info('Space is finite. Grid generation is complete.')
            break

        self._granularity += 1
        _logger.info('Space is infinite. Increasing granularity to %d.', self._granularity)
|
||||
|
||||
def _run(self) -> None:
|
||||
generator = self._grid_generator(self.model_space)
|
||||
|
||||
if self.shuffle:
|
||||
# Shuffle the order of every `_shuffle_buffer_size` candidates.
|
||||
shuffle_buffer = []
|
||||
generator_running = True
|
||||
|
||||
# Already generated does not mean already submitted.
|
||||
# We need to keep track of the granularity actually processed,
|
||||
# to avoid skipping granularities when resuming.
|
||||
self._granularity_processed = self._granularity
|
||||
|
||||
while generator_running:
|
||||
should_submit = False
|
||||
try:
|
||||
next_model = next(generator)
|
||||
shuffle_buffer.append(next_model)
|
||||
if len(shuffle_buffer) == self._shuffle_buffer_size:
|
||||
should_submit = True
|
||||
except StopIteration:
|
||||
# Submit the final models.
|
||||
should_submit = True
|
||||
generator_running = False
|
||||
|
||||
if should_submit:
|
||||
# Submit models and clear the shuffle buffer.
|
||||
self._random_state.shuffle(shuffle_buffer)
|
||||
for model in shuffle_buffer:
|
||||
if not self.wait_for_resource():
|
||||
_logger.info('Budget exhausted, but search space is not exhausted.')
|
||||
return
|
||||
self.engine.submit_models(model)
|
||||
shuffle_buffer = []
|
||||
|
||||
# Update granularity processed.
|
||||
self._granularity_processed = self._granularity
|
||||
|
||||
else:
|
||||
# Keep this in a separate branch because it's very simple.
|
||||
for model in generator:
|
||||
if not self.wait_for_resource():
|
||||
_logger.info('Budget exhausted, but search space is not exhausted.')
|
||||
return
|
||||
time.sleep(self._polling_interval)
|
||||
submit_models(get_targeted_model(base_model, applied_mutators, sample))
|
||||
self.engine.submit_models(model)
|
||||
|
||||
def _space_validation(self, model_space: ExecutableModelSpace) -> bool:
    """Check whether the space is supported by grid search.

    Every mutable of ``model_space.simplify()`` is expanded twice: at granularity 1 and
    at granularity ``1 + _granularity_patience``. If the two grids of any mutable differ
    in length, the space is treated as infinite.

    Return true if the space is finite, false if it's not.
    Any error raised by a mutable's ``grid()`` (e.g. when it is not implemented) propagates.
    """
    for mutable in model_space.simplify().values():
        coarse_count = len(list(mutable.grid(granularity=1)))
        fine_count = len(list(mutable.grid(granularity=1 + self._granularity_patience)))
        if coarse_count != fine_count:
            return False
    return True
|
||||
|
||||
class _RandomSampler(Sampler):
    """Sampler that picks one of the offered candidates uniformly at random."""

    def choice(self, candidates, mutator, model, index):
        # ``mutator``, ``model`` and ``index`` are accepted but not used for the decision.
        picked = random.choice(candidates)
        return picked
|
||||
def load_state_dict(self, state_dict: dict) -> None:
    """Restore granularity, patience counter, random state and dedup history.

    Reads the keys ``'granularity'``, ``'no_sample_found_counter'`` and ``'random_state'``
    from ``state_dict``. The same dict is then handed to the deduplication helper
    when one is attached.
    """
    self._granularity = state_dict['granularity']
    self._no_sample_found_counter = state_dict['no_sample_found_counter']
    self._random_state.set_state(state_dict['random_state'])
    _logger.info('Grid search will resume from granularity %d.', self._granularity)

    if self._dedup is None:
        # No history to restore, so already-tried samples may come up again.
        _logger.info('Grid search would possibly yield duplicate samples since dedup is turned off.')
        return
    self._dedup.load_state_dict(state_dict)
|
||||
|
||||
def state_dict(self) -> dict:
|
||||
result = {'random_state': self._random_state.get_state()}
|
||||
if self._granularity_processed is None:
|
||||
result.update(granularity=self._granularity, no_sample_found_counter=self._no_sample_found_counter)
|
||||
else:
|
||||
result.update(granularity=self._granularity_processed, no_sample_found_counter=0)
|
||||
|
||||
class Random(BaseStrategy):
|
||||
if self._dedup is not None:
|
||||
result.update(self._dedup.state_dict())
|
||||
return result
|
||||
|
||||
class Random(Strategy):
|
||||
"""
|
||||
Random search on the search space.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
variational : bool
|
||||
Do not dry run to get the full search space. Used when the search space has variational size or candidates. Default: false.
|
||||
dedup : bool
|
||||
Do not try the same configuration twice. When variational is true, deduplication is not supported. Default: true.
|
||||
model_filter: Callable[[Model], bool]
|
||||
Feed the model and return a bool. This will filter the models in search space and select which to submit.
|
||||
dedup
|
||||
Do not try the same configuration twice.
|
||||
seed
|
||||
Random seed.
|
||||
"""
|
||||
|
||||
def __init__(self, variational=False, dedup=True, model_filter=None):
    """Store the search options and create the random sampler.

    Raises
    ------
    ValueError
        If both ``variational`` and ``dedup`` are true.
    """
    self.variational = variational
    self.dedup = dedup
    # The two options are mutually exclusive; both are already stored when this fires.
    if variational and dedup:
        raise ValueError('Dedup is not supported in variational mode.')

    self.random_sampler = _RandomSampler()
    self._polling_interval = 2.  # seconds slept between resource polls
    self.filter = model_filter
|
||||
_duplicate_retry = 500
|
||||
|
||||
def run(self, base_model, applied_mutators):
    """Keep submitting randomly sampled models until the budget is exhausted.

    Fixed-size mode (``self.variational`` false): dry-run the mutators to get the search
    space, draw samples from ``random_generator``, wait for a free resource before each
    one, then build the model and submit it if it passes the filter. A sample that raises
    ``InvalidMutation`` is skipped with a warning.

    Variational mode: bind one ``_RandomSampler`` to every mutator, then, whenever a
    resource is available, apply all mutators to ``base_model`` and submit the result
    if it passes the filter.
    """
    if not self.variational:
        _logger.info('Random search running in fixed size mode. Dedup: %s.', 'on' if self.dedup else 'off')
        search_space = dry_run_for_search_space(base_model, applied_mutators)
        for sample in random_generator(search_space, dedup=self.dedup):
            _logger.debug('New model created. Waiting for resource. %s', str(sample))
            while query_available_resources() <= 0:
                if budget_exhausted():
                    return
                time.sleep(self._polling_interval)
                _logger.debug('Still waiting for resource.')
            try:
                candidate = get_targeted_model(base_model, applied_mutators, sample)
                if filter_model(self.filter, candidate):
                    _logger.debug('Submitting model: %s', candidate)
                    submit_models(candidate)
            except InvalidMutation as e:
                _logger.warning(f'Invalid mutation: {e}. Skip.')
        return

    _logger.info('Random search running in variational mode.')
    sampler = _RandomSampler()
    for mutator in applied_mutators:
        mutator.bind_sampler(sampler)

    while True:
        if query_available_resources() > 0:
            candidate = base_model
            for mutator in applied_mutators:
                candidate = mutator.apply(candidate)
            _logger.debug('New model created. Applied mutators are: %s', str(applied_mutators))
            if filter_model(self.filter, candidate):
                submit_models(candidate)
        elif budget_exhausted():
            break
        else:
            time.sleep(self._polling_interval)
|
||||
def __init__(self, *, dedup: bool = True, seed: int | None = None, **kwargs):
    """Set up the dedup helper, the retry helper and the random state.

    ``variational`` and ``model_filter`` are still accepted through ``**kwargs``
    but only trigger a ``DeprecationWarning``; their values are not used.
    """
    super().__init__()

    uses_removed_option = any(name in kwargs for name in ('variational', 'model_filter'))
    if uses_removed_option:
        warnings.warn('Variational and model filter are no longer supported in random search and will be removed in future releases.',
                      DeprecationWarning)

    if dedup:
        # A duplicate raises, so the retry helper can catch it and sample again.
        self._dedup_helper = DeduplicationHelper(raise_on_dup=True)
    else:
        self._dedup_helper = None
    self._retry_helper = RetrySamplingHelper(self._duplicate_retry)

    self._random_state = RandomState(seed)
|
||||
|
||||
def extra_repr(self) -> str:
    """Return ``'dedup=<bool>'``, true whenever a deduplication helper is attached."""
    dedup_enabled = self._dedup_helper is not None
    return f'dedup={dedup_enabled}'
|
||||
|
||||
def random(self, model_space: ExecutableModelSpace) -> ExecutableModelSpace:
    """Generate a random model from the space.

    The sample filled in by ``model_space.random`` (through its ``memo`` argument) is
    passed to the deduplication helper when one is attached. Whatever that helper
    raises propagates to the caller.
    """
    memo: Sample = {}
    model = model_space.random(random_state=self._random_state, memo=memo)
    helper = self._dedup_helper
    if helper is not None:
        helper.dedup(memo)
    return model
|
||||
|
||||
def _run(self) -> None:
    """Sample and submit models until sampling fails or no resource can be obtained.

    Sampling goes through the retry helper. A ``None`` result means no model could be
    produced within the allowed retries, and the search ends. The search also ends as
    soon as ``wait_for_resource()`` returns a false value.
    """
    while True:
        # Random search needs retry to:
        # 1. Generate new when dedup is on.
        # 2. Retry when the sample is invalid.
        candidate = self._retry_helper.retry(self.random, self.model_space)
        if candidate is None:
            _logger.info('Random generation has run out of patience. There is nothing to search. Exiting.')
            return

        if self.wait_for_resource():
            self.engine.submit_models(candidate)
        else:
            return
|
||||
|
||||
def load_state_dict(self, state_dict: dict) -> None:
    """Restore the random state and, when dedup is on, the dedup history.

    ``state_dict['random_state']`` is fed to the random state. The whole dict is then
    handed to the deduplication helper when one is attached.
    """
    saved_random_state = state_dict['random_state']
    self._random_state.set_state(saved_random_state)

    helper = self._dedup_helper
    if helper is not None:
        helper.load_state_dict(state_dict)
|
||||
|
||||
def state_dict(self) -> dict:
    """Return the random state merged with the dedup helper's state (if any).

    The result always has a ``'random_state'`` key; the entries of the deduplication
    helper's own state dict are added after it.
    """
    helper = self._dedup_helper
    dedup_state = {} if helper is None else helper.state_dict()

    snapshot = {'random_state': self._random_state.get_state()}
    snapshot.update(dedup_state)
    return snapshot
|
||||
|
|
|
@ -1,58 +1,113 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import collections
|
||||
from __future__ import annotations
|
||||
|
||||
__all__ = ['DeduplicationHelper', 'DuplicationError', 'RetrySamplingHelper']
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, List
|
||||
from nni.nas.execution.common import Model
|
||||
from nni.nas.mutable import Mutator, Sampler
|
||||
from typing import Any, Type, TypeVar, Callable
|
||||
|
||||
from nni.mutable import SampleValidationError
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
class _FixedSampler(Sampler):
|
||||
|
||||
def _to_hashable(obj):
    """Recursively convert ``obj`` into a hashable equivalent, so it can be saved in a set.

    * ``dict`` becomes a ``frozenset`` of ``(key, converted_value)`` pairs,
      so key order does not matter.
    * ``list`` and ``tuple`` become a ``tuple`` of converted elements. Tuples are
      converted too because a tuple holding a list or dict is not hashable.
    * ``set`` and ``frozenset`` become a ``frozenset``.
    * Anything else is returned unchanged and must already be hashable.
    """
    if isinstance(obj, dict):
        return frozenset((k, _to_hashable(v)) for k, v in obj.items())
    if isinstance(obj, (list, tuple)):
        return tuple(_to_hashable(v) for v in obj)
    if isinstance(obj, (set, frozenset)):
        # Members of a set are hashable already; only the container needs freezing.
        return frozenset(obj)
    return obj
|
||||
|
||||
|
||||
class DuplicationError(SampleValidationError):
|
||||
"""Exception raised when a sample is duplicated."""
|
||||
def __init__(self, sample):
|
||||
self.sample = sample
|
||||
|
||||
def choice(self, candidates, mutator, model, index):
|
||||
return self.sample[(mutator, index)]
|
||||
super().__init__(f'Duplicated sample found: {sample}')
|
||||
|
||||
|
||||
def dry_run_for_search_space(model: Model, mutators: List[Mutator]) -> Dict[Any, List[Any]]:
    """Dry-run the mutators in order and collect the candidates each one records.

    Every mutator is dry-run on the model returned by the previous mutator's dry run.

    Returns
    -------
    An ordered dict mapping ``(mutator, i)`` to the ``i``-th candidate list
    recorded by that mutator.
    """
    space = collections.OrderedDict()
    current = model
    for mutator in mutators:
        recorded, current = mutator.dry_run(current)
        space.update(
            ((mutator, position), candidates)
            for position, candidates in enumerate(recorded)
        )
    return space
|
||||
class DeduplicationHelper:
|
||||
"""Helper class to deduplicate samples.
|
||||
|
||||
def dry_run_for_formatted_search_space(model: Model, mutators: List[Mutator]) -> Dict[Any, Dict[Any, Any]]:
    """Dry-run the mutators and describe every recorded choice as a ``choice`` entry.

    Each entry has the form ``{'_type': 'choice', '_value': candidates}``.
    A mutator that recorded exactly one candidate list is keyed by ``mutator.label``;
    otherwise its lists are keyed by ``'<label>_<index>'``.
    """
    space = collections.OrderedDict()
    current = model
    for mutator in mutators:
        recorded, current = mutator.dry_run(current)
        if len(recorded) == 1:
            labels = [mutator.label]
        else:
            labels = [f'{mutator.label}_{position}' for position in range(len(recorded))]
        for label, candidates in zip(labels, recorded):
            space[label] = {'_type': 'choice', '_value': candidates}
    return space
|
||||
Different from the deduplication on the HPO side,
|
||||
this class simply checks if a sample has been tried before, and does nothing else.
|
||||
"""
|
||||
|
||||
def get_targeted_model(base_model: Model, mutators: List[Mutator], sample: dict) -> Model:
    """Apply all mutators to ``base_model`` with their choices fixed by ``sample``.

    One ``_FixedSampler`` built from ``sample`` is bound to every mutator in turn,
    and each mutator is applied to the result of the previous one.
    """
    fixed_sampler = _FixedSampler(sample)
    current = base_model
    for mutator in mutators:
        bound = mutator.bind_sampler(fixed_sampler)
        current = bound.apply(current)
    return current
|
||||
def __init__(self, raise_on_dup: bool = False):
    """Start with an empty history.

    Parameters
    ----------
    raise_on_dup
        Stored as ``_raise_on_dup``; selects raising over returning false on a duplicate.
    """
    self._raise_on_dup = raise_on_dup
    # Hashable forms of every sample seen so far.
    self._history = set()
|
||||
|
||||
def dedup(self, sample: Any) -> bool:
|
||||
"""
|
||||
If the new sample has not been seen before, it will be added to the history and return True.
|
||||
Otherwise, return False directly.
|
||||
|
||||
def filter_model(model_filter, ir_model):
|
||||
if model_filter is not None:
|
||||
_logger.debug(f'Check if model satisfies constraints.')
|
||||
if model_filter(ir_model):
|
||||
_logger.debug(f'Model satisfied. Submit the model.')
|
||||
return True
|
||||
else:
|
||||
_logger.debug(f'Model unsatisfied. Discard the model.')
|
||||
If raise_on_dup is true, a :class:`DuplicationError` will be raised instead of returning False.
|
||||
"""
|
||||
sample = _to_hashable(sample)
|
||||
if sample in self._history:
|
||||
_logger.debug('Duplicated sample found: %s', sample)
|
||||
if self._raise_on_dup:
|
||||
raise DuplicationError(sample)
|
||||
return False
|
||||
else:
|
||||
self._history.add(sample)
|
||||
return True
|
||||
|
||||
def remove(self, sample: Any) -> None:
    """
    Remove a sample from the history.

    The sample is converted to its hashable form first.
    ``KeyError`` is raised if it is not in the history.
    """
    key = _to_hashable(sample)
    self._history.remove(key)
|
||||
|
||||
def reset(self):
    """Forget every sample seen so far by replacing the history with a new empty set."""
    empty_history = set()
    self._history = empty_history
|
||||
|
||||
def state_dict(self):
    """Return the history as ``{'dedup_history': [...]}``, with the set copied into a list."""
    history_snapshot = list(self._history)
    return {'dedup_history': history_snapshot}
|
||||
|
||||
def load_state_dict(self, state_dict):
    """Replace the history with the entries stored under ``state_dict['dedup_history']``."""
    saved_entries = state_dict['dedup_history']
    self._history = set(saved_entries)
|
||||
|
||||
|
||||
class RetrySamplingHelper:
    """Helper class to retry a function until it succeeds.

    Typical use case is to retry random sampling until a non-duplicate / valid sample is found.

    Parameters
    ----------
    retries
        Maximum number of times the function is called.
    exception_types
        Exception type, or tuple of exception types, to catch and retry on.
        Any other exception propagates immediately.
    raise_last
        Whether to raise the last exception if all retries failed.
        If false, ``None`` is returned instead.
    """

    def __init__(self,
                 retries: int = 500,
                 exception_types: Type[Exception] | tuple[Type[Exception], ...] = (SampleValidationError,),
                 raise_last: bool = False):
        self.retries = retries
        self.exception_types = exception_types
        self.raise_last = raise_last

    def retry(self, func: Callable[..., T], *args, **kwargs) -> T | None:
        """Call ``func(*args, **kwargs)`` until it returns without a caught exception.

        Returns the first successful result, or ``None`` when every attempt failed
        and ``raise_last`` is false.
        """
        for retry in range(self.retries):
            try:
                return func(*args, **kwargs)
            except self.exception_types as e:
                # Only log at a few milestones to avoid flooding the log.
                if retry in (0, 10, 100, 1000):
                    _logger.debug('Sampling failed. %d retries so far. Exception caught: %r', retry, e)
                if retry >= self.retries - 1 and self.raise_last:
                    _logger.warning('Sampling failed after %d retries. Giving up and raising the last exception.', self.retries)
                    raise

        _logger.warning('Sampling failed after %d retries. Giving up and returning None.', self.retries)
        return None
|
||||
|
|
|
@ -0,0 +1,54 @@
|
|||
import pytest
|
||||
|
||||
from nni.mutable import *
|
||||
from nni.nas.strategy.utils import *
|
||||
|
||||
|
||||
def test_dedup():
    """DeduplicationHelper remembers samples, supports remove/reset, and can raise on duplicates."""

    def nested():
        # A fresh object on every call, like the literal repeated in each assertion.
        return {'a': 1, 'b': {'c': 3, 'd': [2, 3, 4]}}

    helper = DeduplicationHelper()
    assert helper.dedup(1)
    assert helper.dedup(2)
    assert not helper.dedup(1)

    # Nested dicts and lists are compared by content.
    assert helper.dedup(nested())
    assert not helper.dedup(nested())

    # A removed sample counts as new again.
    helper.remove(nested())
    assert helper.dedup(nested())

    helper.reset()
    assert helper.dedup(1)

    strict = DeduplicationHelper(True)
    assert strict.dedup(1)
    with pytest.raises(DuplicationError):
        strict.dedup(1)
|
||||
|
||||
|
||||
def test_retry():
    """RetrySamplingHelper retries only on the configured exception types."""

    class FailsFourTimes:
        """Callable that raises ``KeyError`` on its first four calls, then returns 1."""

        def __init__(self):
            self._count = 0

        def __call__(self):
            self._count += 1
            if self._count < 5:
                raise KeyError()
            return 1

    default_helper = RetrySamplingHelper(10)
    assert default_helper.retry(lambda: 1) == 1

    # KeyError is not among the default exception types, so it propagates.
    with pytest.raises(KeyError):
        default_helper.retry(FailsFourTimes())

    # Ten attempts are enough to get past the four failures.
    patient = RetrySamplingHelper(10, KeyError)
    assert patient.retry(FailsFourTimes()) == 1

    # Three attempts are not enough: None by default ...
    impatient = RetrySamplingHelper(3, KeyError)
    assert impatient.retry(FailsFourTimes()) is None

    # ... or the last exception when raise_last is set.
    impatient_raising = RetrySamplingHelper(3, KeyError, raise_last=True)
    with pytest.raises(KeyError):
        impatient_raising.retry(FailsFourTimes())
|
||||
|
Загрузка…
Ссылка в новой задаче