NAS strategy (stage 2) - utils and bruteforce strategy (#5367)

This commit is contained in:
Yuge Zhang 2023-03-03 09:49:45 +08:00 коммит произвёл GitHub
Родитель 808ac9a3fc
Коммит 4b6c0be045
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
3 изменённых файлов: 353 добавлений и 147 удалений

Просмотреть файл

@ -1,138 +1,235 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import copy
import itertools
from __future__ import annotations
__all__ = ['GridSearch', 'Random']
import logging
import random
import time
from typing import Any, Dict, List, Sequence, Optional
import warnings
from typing import Any, Iterable
from nni.nas.execution import submit_models, query_available_resources, budget_exhausted
from nni.nas.mutable import InvalidMutation, Sampler
from .base import BaseStrategy
from .utils import dry_run_for_search_space, get_targeted_model, filter_model
from numpy.random import RandomState
from nni.mutable import Sample, SampleValidationError
from nni.nas.space import MutationSampler, ExecutableModelSpace, Mutator
from .base import Strategy
from .utils import DeduplicationHelper, RetrySamplingHelper
_logger = logging.getLogger(__name__)
def grid_generator(search_space: Dict[Any, List[Any]], shuffle=True):
    """Yield every combination of candidates in ``search_space`` as a dict.

    Parameters
    ----------
    search_space
        Mapping from a key to the list of candidate values for that key.
    shuffle
        When true, each candidate list is shuffled (on a deep copy, so the
        caller's lists are left untouched) before the product is enumerated.

    Yields
    ------
    dict
        One ``{key: chosen_value}`` mapping per element of the Cartesian product.
    """
    names = list(search_space.keys())
    # Work on a private copy: shuffling happens in place.
    candidate_lists = copy.deepcopy(list(search_space.values()))
    if shuffle:
        for candidates in candidate_lists:
            random.shuffle(candidates)
    for combination in itertools.product(*candidate_lists):
        yield dict(zip(names, combination))
def random_generator(search_space: Dict[Any, List[Any]], dedup=True, retries=500):
    """Yield randomly drawn configurations from ``search_space``.

    Parameters
    ----------
    search_space
        Mapping from a key to the list of candidate values for that key.
    dedup
        When true, a configuration that was already yielded is redrawn.
    retries
        Maximum number of draws per configuration. With ``dedup`` on, the
        generator stops once this many consecutive draws were all duplicates.

    Yields
    ------
    dict
        One ``{key: chosen_value}`` mapping per draw.
    """
    names = list(search_space.keys())
    seen = set()
    candidate_lists = copy.deepcopy(list(search_space.values()))
    while True:
        selected: Optional[Sequence[int]] = None
        for attempt in range(retries):
            selected = [random.choice(candidates) for candidates in candidate_lists]
            if not dedup:
                # Without dedup the very first draw is always accepted.
                break
            selected = tuple(selected)
            if selected not in seen:
                seen.add(selected)
                break
            if attempt + 1 == retries:
                # Last allowed draw was a duplicate as well: give up.
                _logger.debug('Random generation has run out of patience. There is nothing to search. Exiting.')
                return
        assert selected is not None, 'Retry attempts exhausted.'
        yield dict(zip(names, selected))
class GridSearch(BaseStrategy):
class GridSearch(Strategy):
"""
Traverse the search space and try all the possible combinations one by one.
Parameters
----------
shuffle : bool
Shuffle the order in a candidate list, so that they are tried in a random order. Default: true.
shuffle
Shuffle the order in a candidate list, so that they are tried in a random order.
Currently, the implementation is a pseudo-random shuffle, which only shuffles the order of every 100 candidates.
seed
Random seed.
"""
def __init__(self, shuffle=True):
self._polling_interval = 2.
_shuffle_buffer_size: int = 100
_granularity_patience: int = 3 # stop increasing granularity after this many times of no new sample found
def __init__(self, *, shuffle: bool = True, seed: int | None = None, dedup: bool = True):
super().__init__()
self.shuffle = shuffle
def run(self, base_model, applied_mutators):
search_space = dry_run_for_search_space(base_model, applied_mutators)
for sample in grid_generator(search_space, shuffle=self.shuffle):
_logger.debug('New model created. Waiting for resource. %s', str(sample))
while query_available_resources() <= 0:
if budget_exhausted():
# Internal only:
# Do not try the same configuration twice.
# Turning it off might result in duplications when the space has an infinite size,
# or the strategy tries to resume from a checkpoint,
# but might improve memory efficiency in extreme cases.
self._dedup = DeduplicationHelper() if dedup else None
self._granularity = 1
self._granularity_processed: int | None = None
self._no_sample_found_counter = 0
self._random_state = RandomState(seed)
def extra_repr(self) -> str:
    """Return the ``shuffle`` flag and whether deduplication is enabled, formatted for the repr."""
    shuffle_flag = self.shuffle
    dedup_enabled = self._dedup is not None
    return f'shuffle={shuffle_flag}, dedup={dedup_enabled}'
def _grid_generator(self, model_space: ExecutableModelSpace) -> Iterable[ExecutableModelSpace]:
    """Yield models from ``model_space.grid`` at increasing granularity.

    Models whose sample is rejected by the dedup helper (when one is set) are skipped.
    After each full pass over the grid:

    * if nothing new was yielded, the no-sample counter grows and generation stops
      once it reaches ``_granularity_patience``;
    * otherwise the counter is reset;
    * a space reported as finite by ``_space_validation`` stops after that pass;
    * otherwise ``_granularity`` is incremented and the grid is walked again.
    """
    if self._no_sample_found_counter >= self._granularity_patience:
        _logger.info('Patience already run out (%d > %d). Nothing to search.',
                     self._no_sample_found_counter, self._granularity_patience)
        return

    is_finite = self._space_validation(model_space)

    while True:
        yielded_any = False
        for candidate in model_space.grid(granularity=self._granularity):
            is_new = self._dedup is None or self._dedup.dedup(candidate.sample)
            if is_new:
                yielded_any = True
                yield candidate

        if yielded_any:
            self._no_sample_found_counter = 0
        else:
            self._no_sample_found_counter += 1
            _logger.info('No new sample found when granularity is %d. Current patience: %d.',
                         self._granularity, self._no_sample_found_counter)
            if self._no_sample_found_counter >= self._granularity_patience:
                _logger.info('No new sample found for %d times. Stop increasing granularity.',
                             self._granularity_patience)
                break

        if is_finite:
            _logger.info('Space is finite. Grid generation is complete.')
            break

        self._granularity += 1
        _logger.info('Space is infinite. Increasing granularity to %d.', self._granularity)
def _run(self) -> None:
generator = self._grid_generator(self.model_space)
if self.shuffle:
# Shuffle the order of every `_shuffle_buffer_size` candidates.
shuffle_buffer = []
generator_running = True
# Already generated does not mean already submitted.
# We need to keep track of the granularity actually processed,
# to avoid skipping granularities when resuming.
self._granularity_processed = self._granularity
while generator_running:
should_submit = False
try:
next_model = next(generator)
shuffle_buffer.append(next_model)
if len(shuffle_buffer) == self._shuffle_buffer_size:
should_submit = True
except StopIteration:
# Submit the final models.
should_submit = True
generator_running = False
if should_submit:
# Submit models and clear the shuffle buffer.
self._random_state.shuffle(shuffle_buffer)
for model in shuffle_buffer:
if not self.wait_for_resource():
_logger.info('Budget exhausted, but search space is not exhausted.')
return
self.engine.submit_models(model)
shuffle_buffer = []
# Update granularity processed.
self._granularity_processed = self._granularity
else:
# Keep this in a separate branch because it's very simple.
for model in generator:
if not self.wait_for_resource():
_logger.info('Budget exhausted, but search space is not exhausted.')
return
time.sleep(self._polling_interval)
submit_models(get_targeted_model(base_model, applied_mutators, sample))
self.engine.submit_models(model)
def _space_validation(self, model_space: ExecutableModelSpace) -> bool:
    """Check whether the space is supported by grid search.

    Every mutable of the simplified space is enumerated twice, at granularity 1 and at
    granularity ``1 + _granularity_patience``. If any mutable produces a different
    number of grid points, the space is treated as infinite.

    Returns
    -------
    bool
        True if the space is finite, false if it's not.
        Any error raised by a mutable's ``grid`` call propagates to the caller.
    """
    coarse = 1
    fine = 1 + self._granularity_patience
    for mutable in model_space.simplify().values():
        # This call will raise an error if grid is not implemented.
        coarse_count = len(list(mutable.grid(granularity=coarse)))
        fine_count = len(list(mutable.grid(granularity=fine)))
        if coarse_count != fine_count:
            return False
    return True
class _RandomSampler(Sampler):
    """Sampler that resolves every choice with :func:`random.choice`."""

    def choice(self, candidates, mutator, model, index):
        """Return one element of ``candidates``; ``mutator``, ``model`` and ``index`` are not used."""
        picked = random.choice(candidates)
        return picked
def load_state_dict(self, state_dict: dict) -> None:
    """Restore granularity, the no-sample counter, the random state and (if enabled) the dedup history.

    Parameters
    ----------
    state_dict
        Mapping with the keys ``granularity``, ``no_sample_found_counter`` and
        ``random_state``. It is also handed as-is to the dedup helper when one is set.
    """
    self._granularity = state_dict['granularity']
    self._no_sample_found_counter = state_dict['no_sample_found_counter']
    self._random_state.set_state(state_dict['random_state'])
    _logger.info('Grid search will resume from granularity %d.', self._granularity)

    if self._dedup is None:
        _logger.info('Grid search would possibly yield duplicate samples since dedup is turned off.')
    else:
        self._dedup.load_state_dict(state_dict)
def state_dict(self) -> dict:
result = {'random_state': self._random_state.get_state()}
if self._granularity_processed is None:
result.update(granularity=self._granularity, no_sample_found_counter=self._no_sample_found_counter)
else:
result.update(granularity=self._granularity_processed, no_sample_found_counter=0)
class Random(BaseStrategy):
if self._dedup is not None:
result.update(self._dedup.state_dict())
return result
class Random(Strategy):
"""
Random search on the search space.
Parameters
----------
variational : bool
Do not dry run to get the full search space. Used when the search space has variational size or candidates. Default: false.
dedup : bool
Do not try the same configuration twice. When variational is true, deduplication is not supported. Default: true.
model_filter: Callable[[Model], bool]
Feed the model and return a bool. This will filter the models in search space and select which to submit.
dedup
Do not try the same configuration twice.
seed
Random seed.
"""
def __init__(self, variational=False, dedup=True, model_filter=None):
    """Store the search options and create the random sampler.

    Parameters
    ----------
    variational
        Stored on ``self.variational``.
    dedup
        Stored on ``self.dedup``. Cannot be combined with ``variational``.
    model_filter
        Stored on ``self.filter``.

    Raises
    ------
    ValueError
        If both ``variational`` and ``dedup`` are truthy.
    """
    self.variational = variational
    self.dedup = dedup
    if variational:
        if dedup:
            raise ValueError('Dedup is not supported in variational mode.')
    self.random_sampler = _RandomSampler()
    # Seconds to sleep between two resource polls.
    self._polling_interval = 2.
    self.filter = model_filter
_duplicate_retry = 500
def run(self, base_model, applied_mutators):
    """Generate random models from ``base_model`` and submit them until the budget is exhausted.

    In fixed size mode (``self.variational`` false) the search space is obtained from a dry
    run, samples come from ``random_generator``, and each sample is turned into a model with
    ``get_targeted_model``. Samples raising ``InvalidMutation`` are skipped with a warning.

    In variational mode every mutator is bound to a single ``_RandomSampler`` and the
    mutators are applied in order to ``base_model`` whenever resources are available.

    In both modes a model is only submitted when ``filter_model(self.filter, model)`` is truthy.
    """
    if not self.variational:
        _logger.info('Random search running in fixed size mode. Dedup: %s.', 'on' if self.dedup else 'off')
        space = dry_run_for_search_space(base_model, applied_mutators)
        for sample in random_generator(space, dedup=self.dedup):
            _logger.debug('New model created. Waiting for resource. %s', str(sample))
            # Block until a resource frees up, or stop entirely once the budget is gone.
            while query_available_resources() <= 0:
                if budget_exhausted():
                    return
                time.sleep(self._polling_interval)
                _logger.debug('Still waiting for resource.')
            try:
                candidate = get_targeted_model(base_model, applied_mutators, sample)
                if filter_model(self.filter, candidate):
                    _logger.debug('Submitting model: %s', candidate)
                    submit_models(candidate)
            except InvalidMutation as e:
                _logger.warning(f'Invalid mutation: {e}. Skip.')
        return

    _logger.info('Random search running in variational mode.')
    shared_sampler = _RandomSampler()
    for mutator in applied_mutators:
        mutator.bind_sampler(shared_sampler)
    while True:
        avail_resource = query_available_resources()
        if avail_resource > 0:
            candidate = base_model
            for mutator in applied_mutators:
                candidate = mutator.apply(candidate)
            _logger.debug('New model created. Applied mutators are: %s', str(applied_mutators))
            if filter_model(self.filter, candidate):
                submit_models(candidate)
            continue
        if budget_exhausted():
            break
        time.sleep(self._polling_interval)
def __init__(self, *, dedup: bool = True, seed: int | None = None, **kwargs):
    """Set up the dedup helper, the retry helper and the random state.

    Parameters
    ----------
    dedup
        When true, a ``DeduplicationHelper(raise_on_dup=True)`` is created; otherwise none is used.
    seed
        Passed to ``RandomState``.
    **kwargs
        Accepted for backward compatibility only. ``variational`` and ``model_filter``
        trigger a ``DeprecationWarning``; no keyword in ``kwargs`` is otherwise used.
    """
    super().__init__()

    deprecated_options = ('variational', 'model_filter')
    if any(option in kwargs for option in deprecated_options):
        warnings.warn('Variational and model filter are no longer supported in random search and will be removed in future releases.',
                      DeprecationWarning)

    if dedup:
        self._dedup_helper = DeduplicationHelper(raise_on_dup=True)
    else:
        self._dedup_helper = None
    self._retry_helper = RetrySamplingHelper(self._duplicate_retry)
    self._random_state = RandomState(seed)
def extra_repr(self) -> str:
    """Return whether deduplication is enabled, formatted for the repr."""
    dedup_enabled = self._dedup_helper is not None
    return f'dedup={dedup_enabled}'
def random(self, model_space: ExecutableModelSpace) -> ExecutableModelSpace:
    """Generate a random model from the space.

    The sample recorded by ``model_space.random`` (through its ``memo`` argument) is passed to
    the dedup helper when one is set; whatever that helper raises propagates to the caller.
    """
    recorded: Sample = {}
    drawn = model_space.random(random_state=self._random_state, memo=recorded)
    helper = self._dedup_helper
    if helper is not None:
        helper.dedup(recorded)
    return drawn
def _run(self) -> None:
    """Keep sampling and submitting models until sampling fails or ``wait_for_resource`` returns false."""
    while True:
        # Random search needs retry to:
        # 1. Generate a new sample when dedup is on.
        # 2. Retry when the sample is invalid.
        candidate = self._retry_helper.retry(self.random, self.model_space)
        if candidate is None:
            _logger.info('Random generation has run out of patience. There is nothing to search. Exiting.')
            return
        if self.wait_for_resource():
            self.engine.submit_models(candidate)
        else:
            return
def load_state_dict(self, state_dict: dict) -> None:
    """Restore the random state from ``state_dict['random_state']`` and, if dedup is on, the dedup history."""
    self._random_state.set_state(state_dict['random_state'])
    helper = self._dedup_helper
    if helper is None:
        return
    helper.load_state_dict(state_dict)
def state_dict(self) -> dict:
    """Return the random state under ``'random_state'`` merged with the dedup helper's state, if any."""
    if self._dedup_helper is None:
        dedup_part = {}
    else:
        dedup_part = self._dedup_helper.state_dict()
    snapshot = {'random_state': self._random_state.get_state()}
    # Dedup entries are merged last, same precedence as unpacking them after 'random_state'.
    snapshot.update(dedup_part)
    return snapshot

Просмотреть файл

@ -1,58 +1,113 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import collections
from __future__ import annotations
__all__ = ['DeduplicationHelper', 'DuplicationError', 'RetrySamplingHelper']
import logging
from typing import Dict, Any, List
from nni.nas.execution.common import Model
from nni.nas.mutable import Mutator, Sampler
from typing import Any, Type, TypeVar, Callable
from nni.mutable import SampleValidationError
_logger = logging.getLogger(__name__)
T = TypeVar('T')
class _FixedSampler(Sampler):
def _to_hashable(obj):
    """Recursively convert ``obj`` into a hashable value so it can be stored in a set.

    * A dict becomes a frozenset of ``(key, converted_value)`` pairs.
    * A list or a tuple becomes a tuple of converted elements. Tuples are walked too,
      because a tuple that contains a list or a dict is itself unhashable.
    * Anything else is returned unchanged and must already be hashable.
    """
    if isinstance(obj, dict):
        return frozenset((k, _to_hashable(v)) for k, v in obj.items())
    if isinstance(obj, (list, tuple)):
        return tuple(_to_hashable(v) for v in obj)
    return obj
class DuplicationError(SampleValidationError):
"""Exception raised when a sample is duplicated."""
def __init__(self, sample):
self.sample = sample
def choice(self, candidates, mutator, model, index):
return self.sample[(mutator, index)]
super().__init__(f'Duplicated sample found: {sample}')
def dry_run_for_search_space(model: Model, mutators: List[Mutator]) -> Dict[Any, List[Any]]:
    """Dry-run every mutator in order and collect the candidates each one records.

    Each mutator's ``dry_run`` receives the model returned by the previous one.

    Returns
    -------
    OrderedDict
        Maps ``(mutator, i)`` to the i-th candidate list recorded by that mutator.
    """
    space = collections.OrderedDict()
    for mutator in mutators:
        candidate_lists, model = mutator.dry_run(model)
        space.update(((mutator, position), candidates)
                     for position, candidates in enumerate(candidate_lists))
    return space
class DeduplicationHelper:
"""Helper class to deduplicate samples.
def dry_run_for_formatted_search_space(model: Model, mutators: List[Mutator]) -> Dict[Any, Dict[Any, Any]]:
    """Dry-run every mutator in order and describe the recorded candidates as ``choice`` entries.

    A mutator that records exactly one candidate list is keyed by its ``label``;
    otherwise each list is keyed by ``f'{label}_{i}'``.

    Returns
    -------
    OrderedDict
        Maps each key to ``{'_type': 'choice', '_value': candidates}``.
    """
    space = collections.OrderedDict()
    for mutator in mutators:
        recorded, model = mutator.dry_run(model)
        if len(recorded) == 1:
            named = [(mutator.label, recorded[0])]
        else:
            named = [(f'{mutator.label}_{i}', candidates) for i, candidates in enumerate(recorded)]
        for key, candidates in named:
            space[key] = {'_type': 'choice', '_value': candidates}
    return space
Different from the deduplication on the HPO side,
this class simply checks if a sample has been tried before, and does nothing else.
"""
def get_targeted_model(base_model: Model, mutators: List[Mutator], sample: dict) -> Model:
    """Apply ``mutators`` in order to ``base_model``, with every choice answered from ``sample``.

    A single ``_FixedSampler`` built from ``sample`` is bound to each mutator before it is applied.
    """
    fixed_sampler = _FixedSampler(sample)
    current = base_model
    for mutator in mutators:
        bound = mutator.bind_sampler(fixed_sampler)
        current = bound.apply(current)
    return current
def __init__(self, raise_on_dup: bool = False):
    """Start with an empty history.

    Parameters
    ----------
    raise_on_dup
        Stored on ``self._raise_on_dup``.
    """
    self._raise_on_dup = raise_on_dup
    self._history = set()
def dedup(self, sample: Any) -> bool:
"""
If the new sample has not been seen before, it will be added to the history and return True.
Otherwise, return False directly.
def filter_model(model_filter, ir_model):
if model_filter is not None:
_logger.debug(f'Check if model satisfies constraints.')
if model_filter(ir_model):
_logger.debug(f'Model satisfied. Submit the model.')
return True
else:
_logger.debug(f'Model unsatisfied. Discard the model.')
If raise_on_dup is true, a :class:`DuplicationError` will be raised instead of returning False.
"""
sample = _to_hashable(sample)
if sample in self._history:
_logger.debug('Duplicated sample found: %s', sample)
if self._raise_on_dup:
raise DuplicationError(sample)
return False
else:
self._history.add(sample)
return True
def remove(self, sample: Any) -> None:
    """
    Remove a sample from the history.

    The sample is converted with ``_to_hashable`` first. ``set.remove`` is used,
    so a sample that is not in the history raises ``KeyError``.
    """
    key = _to_hashable(sample)
    self._history.remove(key)
def reset(self):
    """Forget every recorded sample by replacing the history with a new empty set."""
    fresh_history = set()
    self._history = fresh_history
def state_dict(self):
    """Return the history as a list under the key ``'dedup_history'``."""
    history_items = list(self._history)
    return {'dedup_history': history_items}
def load_state_dict(self, state_dict):
    """Replace the history with a set built from ``state_dict['dedup_history']``."""
    saved_items = state_dict['dedup_history']
    self._history = set(saved_items)
class RetrySamplingHelper:
    """Helper class to retry a function until it succeeds.

    Typical use case is to retry random sampling until a non-duplicate / valid sample is found.

    Parameters
    ----------
    retries
        Number of retries.
    exception_types
        Exception types to catch. Either a single exception class or a tuple of classes
        (it is used directly in an ``except`` clause). Other exceptions propagate immediately.
    raise_last
        Whether to raise the last exception if all retries failed.
    """

    def __init__(self,
                 retries: int = 500,
                 exception_types: Type[Exception] | tuple[Type[Exception], ...] = (SampleValidationError,),
                 raise_last: bool = False):
        self.retries = retries
        self.exception_types = exception_types
        self.raise_last = raise_last

    def retry(self, func: Callable[..., T], *args, **kwargs) -> T | None:
        """Call ``func(*args, **kwargs)`` up to ``retries`` times and return its first successful result.

        Returns
        -------
        The return value of ``func``, or ``None`` when every attempt raised one of
        ``exception_types`` and ``raise_last`` is false.

        Raises
        ------
        Exception
            The exception of the final attempt, if ``raise_last`` is true.
            Any exception not listed in ``exception_types`` is never caught.
        """
        for retry in range(self.retries):
            try:
                return func(*args, **kwargs)
            except self.exception_types as e:
                # Only log at a few milestones to avoid flooding the log.
                if retry in [0, 10, 100, 1000]:
                    _logger.debug('Sampling failed. %d retries so far. Exception caught: %r', retry, e)
                if retry >= self.retries - 1 and self.raise_last:
                    _logger.warning('Sampling failed after %d retries. Giving up and raising the last exception.', self.retries)
                    raise

        _logger.warning('Sampling failed after %d retries. Giving up and returning None.', self.retries)
        return None

Просмотреть файл

@ -0,0 +1,54 @@
import pytest
from nni.mutable import *
from nni.nas.strategy.utils import *
def test_dedup():
    """Exercise DeduplicationHelper: scalar and nested samples, remove, reset, and raise_on_dup."""

    def nested_sample():
        # A fresh but equal nested structure on every call.
        return {'a': 1, 'b': {'c': 3, 'd': [2, 3, 4]}}

    helper = DeduplicationHelper()

    # Plain hashable samples.
    assert helper.dedup(1)
    assert helper.dedup(2)
    assert not helper.dedup(1)

    # Nested dict / list samples compare by content.
    assert helper.dedup(nested_sample())
    assert not helper.dedup(nested_sample())

    # A removed sample counts as new again.
    helper.remove(nested_sample())
    assert helper.dedup(nested_sample())

    # Reset wipes the history.
    helper.reset()
    assert helper.dedup(1)

    # With raise_on_dup, a duplicate raises instead of returning False.
    strict_helper = DeduplicationHelper(True)
    assert strict_helper.dedup(1)
    with pytest.raises(DuplicationError):
        strict_helper.dedup(1)
def test_retry():
    """Exercise RetrySamplingHelper: success, uncaught exception type, eventual success, giving up, raise_last."""

    def fails_four_times():
        # Build a callable that raises KeyError on its first four calls, then returns 1.
        calls = 0

        def attempt():
            nonlocal calls
            calls += 1
            if calls < 5:
                raise KeyError()
            return 1

        return attempt

    default_helper = RetrySamplingHelper(10)
    assert default_helper.retry(lambda: 1) == 1

    # KeyError is not among the default exception types, so it propagates.
    with pytest.raises(KeyError):
        default_helper.retry(fails_four_times())

    # Enough retries: the fifth call succeeds.
    patient_helper = RetrySamplingHelper(10, KeyError)
    assert patient_helper.retry(fails_four_times()) == 1

    # Too few retries: gives up and returns None.
    impatient_helper = RetrySamplingHelper(3, KeyError)
    assert impatient_helper.retry(fails_four_times()) is None

    # Too few retries with raise_last: the last exception is re-raised.
    raising_helper = RetrySamplingHelper(3, KeyError, raise_last=True)
    with pytest.raises(KeyError):
        raising_helper.retry(fails_four_times())