зеркало из https://github.com/microsoft/archai.git
Fix some randomness in evolutionary pareto search not coming from given seed. (#225)
Add unit tests to cover this.
This commit is contained in:
Родитель
e225885676
Коммит
4fc5a6d068
|
@ -118,6 +118,7 @@ class EvolutionParetoSearch(Searcher):
|
|||
|
||||
_, valid_indices = self.so.validate_constraints(sample)
|
||||
valid_sample += [sample[i] for i in valid_indices]
|
||||
nb_tries += 1
|
||||
|
||||
return valid_sample[:num_models]
|
||||
|
||||
|
@ -143,6 +144,7 @@ class EvolutionParetoSearch(Searcher):
|
|||
nb_tries = 0
|
||||
|
||||
while len(candidates) < mutations_per_parent and nb_tries < patience:
|
||||
nb_tries += 1
|
||||
mutated_model = self.search_space.mutate(p)
|
||||
mutated_model.metadata["parent"] = p.archid
|
||||
|
||||
|
@ -152,7 +154,6 @@ class EvolutionParetoSearch(Searcher):
|
|||
if mutated_model.archid not in self.seen_archs:
|
||||
mutated_model.metadata["generation"] = self.iter_num
|
||||
candidates[mutated_model.archid] = mutated_model
|
||||
nb_tries += 1
|
||||
mutations.update(candidates)
|
||||
|
||||
return list(mutations.values())
|
||||
|
@ -176,7 +177,7 @@ class EvolutionParetoSearch(Searcher):
|
|||
children, children_ids = [], set()
|
||||
|
||||
if len(parents) >= 2:
|
||||
pairs = [random.sample(parents, 2) for _ in range(num_crossovers)]
|
||||
pairs = [self.rng.sample(parents, 2) for _ in range(num_crossovers)]
|
||||
for p1, p2 in pairs:
|
||||
child = self.search_space.crossover([p1, p2])
|
||||
nb_tries = 0
|
||||
|
@ -215,7 +216,7 @@ class EvolutionParetoSearch(Searcher):
|
|||
|
||||
"""
|
||||
|
||||
random.shuffle(current_pop)
|
||||
self.rng.shuffle(current_pop)
|
||||
return current_pop[: self.max_unseen_population]
|
||||
|
||||
@overrides
|
||||
|
@ -241,6 +242,8 @@ class EvolutionParetoSearch(Searcher):
|
|||
logger.info(f"Calculating search objectives {list(self.so.objective_names)} for {len(unseen_pop)} models ...")
|
||||
|
||||
results = self.so.eval_all_objs(unseen_pop)
|
||||
if len(results) == 0:
|
||||
raise Exception("Search is finding no valid models")
|
||||
self.search_state.add_iteration_results(
|
||||
unseen_pop,
|
||||
results,
|
||||
|
|
|
@ -3,27 +3,96 @@
|
|||
|
||||
import os
|
||||
|
||||
from typing import Optional
|
||||
from random import Random
|
||||
import pytest
|
||||
|
||||
from overrides import overrides
|
||||
from archai.discrete_search.algos.evolution_pareto import EvolutionParetoSearch
|
||||
from archai.discrete_search.api.search_objectives import SearchObjectives
|
||||
from archai.discrete_search.api.archai_model import ArchaiModel
|
||||
from archai.discrete_search.api.model_evaluator import ModelEvaluator
|
||||
from archai.discrete_search.search_spaces.config import (
|
||||
ArchParamTree, ConfigSearchSpace, DiscreteChoice,
|
||||
)
|
||||
|
||||
|
||||
class DummyEvaluator(ModelEvaluator):
    """Evaluator stub that scores every model with a draw from a supplied RNG.

    The only source of randomness is the ``Random`` instance handed to the
    constructor, so evaluators built from identically seeded generators
    produce the same sequence of scores.
    """

    def __init__(self, rng: Random):
        # Generator that `evaluate` draws its scores from.
        self.rng = rng
        self.dummy = True

    @overrides
    def evaluate(self, model: ArchaiModel, budget: Optional[float] = None) -> float:
        """Return the next float from the RNG; `model` and `budget` are not consulted."""
        score = self.rng.random()
        return score
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def output_dir(tmp_path_factory):
    """Session-scoped fixture: a temporary directory created via ``mktemp("out")``."""
    out_path = tmp_path_factory.mktemp("out")
    return out_path
|
||||
|
||||
@pytest.fixture
def tree_c2():
    """Parameter spec with two independent False/True choices, 'p1' and 'p2'."""
    # Each entry gets its own DiscreteChoice built from its own fresh list.
    return {
        name: DiscreteChoice([False, True])
        for name in ('p1', 'p2')
    }
|
||||
|
||||
|
||||
def test_evolution_pareto(output_dir, search_space, search_objectives):
|
||||
algo = EvolutionParetoSearch(search_space, search_objectives, output_dir, num_iters=3, init_num_models=5)
|
||||
cache = []
|
||||
for _ in range(2):
|
||||
algo = EvolutionParetoSearch(search_space, search_objectives, output_dir, num_iters=3, init_num_models=5, seed=42)
|
||||
search_space.rng = algo.rng
|
||||
|
||||
search_results = algo.search()
|
||||
assert len(os.listdir(output_dir)) > 0
|
||||
search_results = algo.search()
|
||||
assert len(os.listdir(output_dir)) > 0
|
||||
|
||||
df = search_results.get_search_state_df()
|
||||
assert all(0 <= x <= 0.4 for x in df["Random1"].tolist())
|
||||
df = search_results.get_search_state_df()
|
||||
assert all(0 <= x <= 0.4 for x in df["Random1"].tolist())
|
||||
|
||||
all_models = [m for iter_r in search_results.results for m in iter_r["models"]]
|
||||
all_models = [m for iter_r in search_results.results for m in iter_r["models"]]
|
||||
|
||||
# Checks if all registered models satisfy constraints
|
||||
_, valid_models = search_objectives.validate_constraints(all_models)
|
||||
assert len(valid_models) == len(all_models)
|
||||
# Checks if all registered models satisfy constraints
|
||||
_, valid_models = search_objectives.validate_constraints(all_models)
|
||||
assert len(valid_models) == len(all_models)
|
||||
|
||||
cache += [[m.archid for m in all_models]]
|
||||
|
||||
# make sure the archid's returned are repeatable so that search jobs can be restartable.
|
||||
assert cache[0] == cache[1]
|
||||
|
||||
|
||||
def test_evolution_pareto_tree_search(output_dir, tree_c2):
    """Check that a seeded evolutionary pareto search over a config tree is repeatable.

    The search is run twice from scratch. Each run builds its own objectives,
    evaluator RNG, search space and searcher, all from the same seed. The
    archids of every model recorded across all iterations must come back in
    the same order both times.
    """
    tree = ArchParamTree(tree_c2)
    seed = 42

    def use_arch(c):
        # 'p2' is only picked when 'p1' comes back falsy.
        if not c.pick('p1'):
            c.pick('p2')

    def run_search():
        """Run one full search and return the archids of all recorded models."""
        objectives = SearchObjectives()
        objectives.add_objective(
            'Dummy',
            DummyEvaluator(Random(seed)),
            higher_is_better=False,
            compute_intensive=False,
        )
        space = ConfigSearchSpace(use_arch, tree, seed=seed)
        searcher = EvolutionParetoSearch(
            space,
            objectives,
            output_dir,
            num_iters=3,
            init_num_models=5,
            seed=seed,
            save_pareto_model_weights=False,
        )

        results = searcher.search()
        assert len(os.listdir(output_dir)) > 0

        return [
            model.archid
            for iteration in results.results
            for model in iteration["models"]
        ]

    archids_per_run = [run_search() for _ in range(2)]

    # The returned archids must be repeatable so that search jobs can be restarted.
    assert archids_per_run[0] == archids_per_run[1]
|
||||
|
|
|
@ -50,7 +50,7 @@ def tree_c2():
|
|||
|
||||
def test_param_sharing(rng, tree_c1):
|
||||
tree = ArchParamTree(tree_c1)
|
||||
|
||||
|
||||
for _ in range(10):
|
||||
config = tree.sample_config(rng)
|
||||
p1 = config.pick('param1')
|
||||
|
@ -68,10 +68,10 @@ def test_repeat_config_share(rng, tree_c1):
|
|||
|
||||
for _ in range(10):
|
||||
config = tree.sample_config(rng)
|
||||
|
||||
|
||||
for param_block in config.pick('param_list'):
|
||||
par4 = param_block.pick('param4')
|
||||
|
||||
|
||||
assert len(set(
|
||||
p.pick('constant') for p in par4
|
||||
)) == 1
|
||||
|
@ -93,35 +93,47 @@ def test_ss(rng, tree_c2, tmp_path_factory):
|
|||
tmp_path = tmp_path_factory.mktemp('test_ss')
|
||||
|
||||
tree = ArchParamTree(tree_c2)
|
||||
|
||||
|
||||
def use_arch(c):
|
||||
if c.pick('p1'):
|
||||
return
|
||||
|
||||
|
||||
if c.pick('p2'):
|
||||
return
|
||||
|
||||
ss = ConfigSearchSpace(use_arch, tree, seed=1)
|
||||
m = ss.random_sample()
|
||||
ss.save_arch(m, tmp_path / 'arch.json')
|
||||
|
||||
m2 = ss.load_arch(tmp_path / 'arch.json')
|
||||
assert m.archid == m2.archid
|
||||
cache = []
|
||||
for _ in range(2):
|
||||
ids = []
|
||||
|
||||
m3 = ss.mutate(m)
|
||||
m4 = ss.crossover([m3, m2])
|
||||
ss = ConfigSearchSpace(use_arch, tree, seed=1)
|
||||
m = ss.random_sample()
|
||||
ids += [m.archid]
|
||||
ss.save_arch(m, tmp_path / 'arch.json')
|
||||
|
||||
m2 = ss.load_arch(tmp_path / 'arch.json')
|
||||
assert m.archid == m2.archid
|
||||
|
||||
m3 = ss.mutate(m)
|
||||
|
||||
ids += [m3.archid]
|
||||
m4 = ss.crossover([m3, m2])
|
||||
ids += [m4.archid]
|
||||
cache += [ids]
|
||||
|
||||
# make sure the archid's returned are repeatable so that search jobs can be restartable.
|
||||
assert cache[0] == cache[1]
|
||||
|
||||
|
||||
def test_ss_archid(rng, tree_c2):
|
||||
tree = ArchParamTree(tree_c2)
|
||||
|
||||
|
||||
def use_arch(c):
|
||||
if c.pick('p1'):
|
||||
return
|
||||
|
||||
|
||||
if c.pick('p2'):
|
||||
return
|
||||
|
||||
|
||||
ss = ConfigSearchSpace(use_arch, tree, seed=1)
|
||||
archids = set()
|
||||
|
||||
|
@ -130,5 +142,3 @@ def test_ss_archid(rng, tree_c2):
|
|||
archids.add(config.archid)
|
||||
|
||||
assert len(archids) == 3 # Will fail with probability approx 1/2^100
|
||||
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче