Fix some randomness in evolutionary pareto search not coming from given seed. (#225)

Add unit tests to cover this.
Chris Lovett 2023-04-20 22:37:39 -07:00 committed by GitHub
Parent e225885676
Commit 4fc5a6d068
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 113 additions and 31 deletions
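The gist of the change, as the self.rng.sample and self.rng.shuffle lines in the diff below show, is to route every stochastic call through a single Random instance built from the search's seed instead of the module-level random functions. A minimal sketch of that pattern follows; SeededSearcher and its method names are illustrative only, not the actual Archai classes:

    # Minimal sketch of the seeding pattern applied in this commit.
    # Only the rng/seed idea mirrors the diff; names here are hypothetical.
    from random import Random
    from typing import List

    class SeededSearcher:
        def __init__(self, seed: int = 1):
            # One RNG owned by the searcher, so every stochastic step depends only on `seed`.
            self.rng = Random(seed)

        def pick_crossover_pairs(self, parents: List[str], num_crossovers: int) -> List[List[str]]:
            # self.rng.sample instead of random.sample
            return [self.rng.sample(parents, 2) for _ in range(num_crossovers)]

        def select_unseen(self, current_pop: List[str], max_unseen: int) -> List[str]:
            # self.rng.shuffle instead of random.shuffle
            self.rng.shuffle(current_pop)
            return current_pop[:max_unseen]

With that in place, two searches constructed with the same seed pick the same crossover pairs and the same population order, which is what the new tests check by comparing the archid lists from two identical runs.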

View file

@@ -118,6 +118,7 @@ class EvolutionParetoSearch(Searcher):
            _, valid_indices = self.so.validate_constraints(sample)
            valid_sample += [sample[i] for i in valid_indices]
            nb_tries += 1

        return valid_sample[:num_models]
@@ -143,6 +144,7 @@ class EvolutionParetoSearch(Searcher):
            nb_tries = 0

            while len(candidates) < mutations_per_parent and nb_tries < patience:
                nb_tries += 1
                mutated_model = self.search_space.mutate(p)
                mutated_model.metadata["parent"] = p.archid
@@ -152,7 +154,6 @@ class EvolutionParetoSearch(Searcher):
                if mutated_model.archid not in self.seen_archs:
                    mutated_model.metadata["generation"] = self.iter_num
                    candidates[mutated_model.archid] = mutated_model
                nb_tries += 1

            mutations.update(candidates)

        return list(mutations.values())
@@ -176,7 +177,7 @@ class EvolutionParetoSearch(Searcher):
        children, children_ids = [], set()

        if len(parents) >= 2:
            pairs = [random.sample(parents, 2) for _ in range(num_crossovers)]
            pairs = [self.rng.sample(parents, 2) for _ in range(num_crossovers)]
            for p1, p2 in pairs:
                child = self.search_space.crossover([p1, p2])
                nb_tries = 0
@@ -215,7 +216,7 @@ class EvolutionParetoSearch(Searcher):
        """
        random.shuffle(current_pop)
        self.rng.shuffle(current_pop)

        return current_pop[: self.max_unseen_population]

    @overrides
@@ -241,6 +242,8 @@ class EvolutionParetoSearch(Searcher):
        logger.info(f"Calculating search objectives {list(self.so.objective_names)} for {len(unseen_pop)} models ...")

        results = self.so.eval_all_objs(unseen_pop)
        if len(results) == 0:
            raise Exception("Search is finding no valid models")

        self.search_state.add_iteration_results(
            unseen_pop,
            results,

View file

@@ -3,27 +3,96 @@
import os
from typing import Optional
from random import Random

import pytest
from overrides import overrides

from archai.discrete_search.algos.evolution_pareto import EvolutionParetoSearch
from archai.discrete_search.api.search_objectives import SearchObjectives
from archai.discrete_search.api.archai_model import ArchaiModel
from archai.discrete_search.api.model_evaluator import ModelEvaluator
from archai.discrete_search.search_spaces.config import (
    ArchParamTree, ConfigSearchSpace, DiscreteChoice,
)


class DummyEvaluator(ModelEvaluator):
    def __init__(self, rng: Random):
        self.dummy = True
        self.rng = rng

    @overrides
    def evaluate(self, model: ArchaiModel, budget: Optional[float] = None) -> float:
        return self.rng.random()


@pytest.fixture(scope="session")
def output_dir(tmp_path_factory):
    return tmp_path_factory.mktemp("out")


@pytest.fixture
def tree_c2():
    c = {
        'p1': DiscreteChoice(list([False, True])),
        'p2': DiscreteChoice(list([False, True]))
    }
    return c
def test_evolution_pareto(output_dir, search_space, search_objectives):
    algo = EvolutionParetoSearch(search_space, search_objectives, output_dir, num_iters=3, init_num_models=5)
    cache = []
    for _ in range(2):
        algo = EvolutionParetoSearch(search_space, search_objectives, output_dir, num_iters=3, init_num_models=5, seed=42)
        search_space.rng = algo.rng
        search_results = algo.search()
        assert len(os.listdir(output_dir)) > 0

        df = search_results.get_search_state_df()
        assert all(0 <= x <= 0.4 for x in df["Random1"].tolist())

        all_models = [m for iter_r in search_results.results for m in iter_r["models"]]

        # Checks if all registered models satisfy constraints
        _, valid_models = search_objectives.validate_constraints(all_models)
        assert len(valid_models) == len(all_models)

        cache += [[m.archid for m in all_models]]

    # make sure the archid's returned are repeatable so that search jobs can be restartable.
    assert cache[0] == cache[1]
def test_evolution_pareto_tree_search(output_dir, tree_c2):
    tree = ArchParamTree(tree_c2)

    def use_arch(c):
        if c.pick('p1'):
            return

        if c.pick('p2'):
            return

    seed = 42
    cache = []
    for _ in range(2):
        search_objectives = SearchObjectives()
        search_objectives.add_objective(
            'Dummy',
            DummyEvaluator(Random(seed)),
            higher_is_better=False,
            compute_intensive=False)

        search_space = ConfigSearchSpace(use_arch, tree, seed=seed)
        algo = EvolutionParetoSearch(search_space, search_objectives, output_dir, num_iters=3, init_num_models=5, seed=seed, save_pareto_model_weights=False)
        search_results = algo.search()
        assert len(os.listdir(output_dir)) > 0

        all_models = [m for iter_r in search_results.results for m in iter_r["models"]]
        cache += [[m.archid for m in all_models]]

    # make sure the archid's returned are repeatable so that search jobs can be restartable.
    assert cache[0] == cache[1]

View file

@@ -50,7 +50,7 @@ def tree_c2():

def test_param_sharing(rng, tree_c1):
    tree = ArchParamTree(tree_c1)

    for _ in range(10):
        config = tree.sample_config(rng)
        p1 = config.pick('param1')
@@ -68,10 +68,10 @@ def test_repeat_config_share(rng, tree_c1):
    for _ in range(10):
        config = tree.sample_config(rng)

        for param_block in config.pick('param_list'):
            par4 = param_block.pick('param4')

            assert len(set(
                p.pick('constant') for p in par4
            )) == 1
@@ -93,35 +93,47 @@ def test_ss(rng, tree_c2, tmp_path_factory):
    tmp_path = tmp_path_factory.mktemp('test_ss')
    tree = ArchParamTree(tree_c2)

    def use_arch(c):
        if c.pick('p1'):
            return

        if c.pick('p2'):
            return

    cache = []
    for _ in range(2):
        ids = []
        ss = ConfigSearchSpace(use_arch, tree, seed=1)
        m = ss.random_sample()
        ids += [m.archid]

        ss.save_arch(m, tmp_path / 'arch.json')
        m2 = ss.load_arch(tmp_path / 'arch.json')
        assert m.archid == m2.archid

        m3 = ss.mutate(m)
        ids += [m3.archid]
        m4 = ss.crossover([m3, m2])
        ids += [m4.archid]

        cache += [ids]

    # make sure the archid's returned are repeatable so that search jobs can be restartable.
    assert cache[0] == cache[1]
def test_ss_archid(rng, tree_c2):
    tree = ArchParamTree(tree_c2)

    def use_arch(c):
        if c.pick('p1'):
            return

        if c.pick('p2'):
            return

    ss = ConfigSearchSpace(use_arch, tree, seed=1)
    archids = set()
@@ -130,5 +142,3 @@ def test_ss_archid(rng, tree_c2):
        archids.add(config.archid)

    assert len(archids) == 3  # Will fail with probability approx 1/2^100