From 027fdeac5a1373fdccf5283b6c04edb71810c6e1 Mon Sep 17 00:00:00 2001
From: Debadeepta Dey
Date: Thu, 21 Apr 2022 17:00:57 -0700
Subject: [PATCH] The pipeline runs through nominally! But lots of work to be done still.

---
 .../evolution_pareto_search_segmentation.py   | 26 +++++++++++++------
 .../segmentation_trainer.py                   |  9 ++++---
 .../discrete_search_space_segmentation.py     | 13 +++++++---
 .../evolution_pareto_search_segmentation.yaml |  2 +-
 4 files changed, 35 insertions(+), 15 deletions(-)

diff --git a/archai/algos/evolution_pareto_image_seg/evolution_pareto_search_segmentation.py b/archai/algos/evolution_pareto_image_seg/evolution_pareto_search_segmentation.py
index 8c24c04d..d1ee54cd 100644
--- a/archai/algos/evolution_pareto_image_seg/evolution_pareto_search_segmentation.py
+++ b/archai/algos/evolution_pareto_image_seg/evolution_pareto_search_segmentation.py
@@ -88,6 +88,10 @@ class EvolutionParetoSearchSegmentation(EvolutionParetoSearch):
 
     def _evaluate(self, arch:ArchWithMetaData)->float:
 
+        # DEBUG: simulate architecture evaluation
+        f1 = random.random()
+        return f1
+
         # see if we have visited this arch before
         if arch.metadata['archid'] in self.eval_cache:
             logger.info(f"{arch.metadata['archid']} is in cache! Returning from cache.")
@@ -95,31 +99,37 @@ class EvolutionParetoSearchSegmentation(EvolutionParetoSearch):
 
         # if not in cache actually evaluate it
         # -------------------------------------
-        logger.pushd(f"regular_training_{arch.metadata['archid']}")
 
         # train
        # TODO: how do we set the number of epochs it will train for?
         dataset_dir = os.path.join(self.dataroot, 'face_synthetics')
 
-        trainer = SegmentationTrainer(arch.arch, dataset_dir=dataset_dir, val_size=2000, gpus=1)
+        # TODO: most of these should come from conf
+        # TODO: batch size 16 has lr 2e-4. can we increase batch size? what lr?
+        trainer = SegmentationTrainer(arch.arch,
+                                      dataset_dir=dataset_dir,
+                                      max_steps=100,
+                                      val_size=2000,
+                                      img_size=256,
+                                      augmentation='none',
+                                      batch_size=64,
+                                      lr=8e-4,
+                                      criterion_name='ce',
+                                      gpus=1,
+                                      seed=42)
         trainer.fit(run_path=utils.full_path(get_expdir()))
 
         # validate
         val_dl = trainer.val_dataloader
         outputs = []
-
         with torch.no_grad():
             for bi, b in enumerate(tqdm(val_dl)):
                 b['image'] = b['image'].to('cuda')
                 b['mask'] = b['mask'].to('cuda')
+                trainer.model.to('cuda')
                 outputs.append(trainer.model.validation_step(b, bi))
 
         results = trainer.model.shared_epoch_end(outputs, stage='validation')
-        logger.popd()
-
-        # # DEBUG: simulate architecture evaluation
-        # f1 = random.random()
-
         f1 = results['validation_overall_f1']
 
         return f1
diff --git a/archai/algos/evolution_pareto_image_seg/segmentation_trainer.py b/archai/algos/evolution_pareto_image_seg/segmentation_trainer.py
index 2710cd32..df5fdbc6 100644
--- a/archai/algos/evolution_pareto_image_seg/segmentation_trainer.py
+++ b/archai/algos/evolution_pareto_image_seg/segmentation_trainer.py
@@ -119,7 +119,8 @@ class LightningModelWrapper(pl.LightningModule):
         results = get_custom_overall_metrics(tp, fp, fn, tn, stage=stage)
         results[f'{stage}_loss'] = avg_loss
 
-        self.log_dict(results, sync_dist=True)
+        # TODO: enabling this causes error in lightning
+        # self.log_dict(results, sync_dist=True)
         return results
 
     def configure_optimizers(self):
@@ -173,7 +174,8 @@ class LightningModelWrapper(pl.LightningModule):
 
 class SegmentationTrainer():
     def __init__(self, model: SegmentationNasModel, dataset_dir: str,
-                 max_steps: int = 12_000, val_size: int = 2000, img_size: int = 256,
+                 max_steps: int = 12000, val_size: int = 2000,
+                 val_interval: int = 1000, img_size: int = 256,
                  augmentation: str = 'none', batch_size: int = 16, lr: float = 2e-4,
                  criterion_name: str = 'ce', gpus: int = 1, seed: int = 1):
 
@@ -195,6 +197,7 @@ class SegmentationTrainer():
             exponential_decay_lr=True, img_size=img_size)
         self.img_size = img_size
         self.gpus = gpus
+        self.val_interval = val_interval
 
     def get_training_callbacks(self, run_dir: Path) -> List[pl.callbacks.Callback]:
         return [pl.callbacks.ModelCheckpoint(
@@ -211,7 +214,7 @@ class SegmentationTrainer():
             max_steps=self.max_steps,
             default_root_dir=run_path,
             gpus=self.gpus,
-            val_check_interval=1_200,
+            val_check_interval=self.val_interval,
             callbacks=self.get_training_callbacks(run_path)
         )
 
diff --git a/archai/search_spaces/discrete_search_spaces/segmentation_search_spaces/discrete_search_space_segmentation.py b/archai/search_spaces/discrete_search_spaces/segmentation_search_spaces/discrete_search_space_segmentation.py
index 1d718049..c1322a94 100644
--- a/archai/search_spaces/discrete_search_spaces/segmentation_search_spaces/discrete_search_space_segmentation.py
+++ b/archai/search_spaces/discrete_search_spaces/segmentation_search_spaces/discrete_search_space_segmentation.py
@@ -4,6 +4,8 @@ from overrides.overrides import overrides
 import copy
 import uuid
 
+import torch
+
 from archai.nas.arch_meta import ArchWithMetaData
 from archai.nas.discrete_search_space import DiscreteSearchSpace
 
@@ -50,7 +52,7 @@ class DiscreteSearchSpaceSegmentation(DiscreteSearchSpace):
         # and change its operator at random
         # and its input sources
         # WARNING: this can result in some nodes left hanging
-        chosen_node_idx = random.randint(1, len(graph))
+        chosen_node_idx = random.randint(1, len(graph)-1)
         node = graph[chosen_node_idx]
         node['op'] = random.choice(self.operations)
         # choose up to k inputs from previous nodes
@@ -61,7 +63,7 @@ class DiscreteSearchSpaceSegmentation(DiscreteSearchSpace):
 
         # now go through every node in the graph (except output node)
         # and make sure it is being used as input in some node after it
-        for i, node in enumerate(graph):
+        for i, node in enumerate(graph[:-1]):
             this_name = node['name']
             orphan = True
             # test whether not orphan
@@ -70,17 +72,22 @@ class DiscreteSearchSpaceSegmentation(DiscreteSearchSpace):
                     orphan = False
             if orphan:
                 # choose a forward node to connect it with
-                chosen_forward_idx = random.randint(i+1, len(graph))
+                chosen_forward_idx = random.randint(i+1, len(graph)-1)
                 graph[chosen_forward_idx]['inputs'].append(this_name)
 
         # compile the model
         model = SegmentationNasModel.from_config(graph, channels_per_scale)
+        # TODO: these should come from config or elsewhere
+        # such that they are not hardcoded in here
+        out_shape = model.validate_forward(torch.randn(1, 3, 256, 256)).shape
+        assert out_shape == torch.Size([1, 19, 256, 256])
 
         extradata = {
             'datasetname': self.datasetname,
             'graph': graph,
             'channels_per_scale': channels_per_scale,
             'archid': uuid.uuid4(), #TODO: need to replace with a string of the graph
         }
+
         arch_meta = ArchWithMetaData(model, extradata)
         return [arch_meta]
diff --git a/confs/algos/evolution_pareto_search_segmentation.yaml b/confs/algos/evolution_pareto_search_segmentation.yaml
index 3dbce268..420f1cfa 100644
--- a/confs/algos/evolution_pareto_search_segmentation.yaml
+++ b/confs/algos/evolution_pareto_search_segmentation.yaml
@@ -44,7 +44,7 @@ dataset: {} # default dataset settings comes from __include__ on the top
 
 nas:
   search:
-    init_num_models: 20 # initial random models to seed the search
+    init_num_models: 3 # initial random models to seed the search
     num_iters: 3 # number of pareto frontier search iterations
     num_random_mix: 20 # how many random models to add to the parent mixture
     use_benchmark: True
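
Note on the archid TODO in random_sample(): _evaluate() keys its eval cache on arch.metadata['archid'], so a fresh uuid.uuid4() per sample means an identical graph can never be recognized as a cache hit. Below is a minimal sketch of a deterministic id derived from the graph; the helper name graph_to_archid and the exact fields hashed are assumptions, not part of this patch.

import hashlib
import json

def graph_to_archid(graph: list, channels_per_scale: dict) -> str:
    # Hypothetical helper: serialize the sampled graph and channel config in a
    # stable order and hash it, so identical architectures get identical ids.
    canonical = json.dumps({'graph': graph, 'channels_per_scale': channels_per_scale},
                           sort_keys=True, default=str)
    return hashlib.sha1(canonical.encode('utf-8')).hexdigest()[:16]

# usage inside random_sample(), in place of the uuid.uuid4() line:
# extradata['archid'] = graph_to_archid(graph, channels_per_scale)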
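Note on the "most of these should come from conf" TODO in _evaluate(): one option is a small builder that reads the hyperparameters from a trainer section of the search config. The conf_trainer keys below, and the idea of adding such a section to evolution_pareto_search_segmentation.yaml, are assumptions rather than existing config entries.

from typing import Any, Dict

from archai.algos.evolution_pareto_image_seg.segmentation_trainer import SegmentationTrainer

def build_trainer(model, dataset_dir: str, conf_trainer: Dict[str, Any]) -> SegmentationTrainer:
    # Hypothetical helper: pull every value that is currently hard-coded in
    # _evaluate() out of a config dict so experiments can change them in yaml.
    return SegmentationTrainer(model,
                               dataset_dir=dataset_dir,
                               max_steps=conf_trainer['max_steps'],
                               val_size=conf_trainer['val_size'],
                               val_interval=conf_trainer['val_interval'],
                               img_size=conf_trainer['img_size'],
                               augmentation=conf_trainer['augmentation'],
                               batch_size=conf_trainer['batch_size'],
                               lr=conf_trainer['lr'],
                               criterion_name=conf_trainer['criterion_name'],
                               gpus=conf_trainer['gpus'],
                               seed=conf_trainer['seed'])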
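Note on the validation loop in _evaluate(): the patch calls trainer.model.to('cuda') inside the batch loop, which repeats the (idempotent) device transfer on every iteration. A possible follow-up, not part of this patch, is to move the model once before iterating; everything else stays as in the diff.

# hypothetical rearrangement of the validation block in _evaluate()
trainer.model.to('cuda')
trainer.model.eval()  # optional: run dropout/batchnorm in eval mode

outputs = []
with torch.no_grad():
    for bi, b in enumerate(tqdm(val_dl)):
        b['image'] = b['image'].to('cuda')
        b['mask'] = b['mask'].to('cuda')
        outputs.append(trainer.model.validation_step(b, bi))

results = trainer.model.shared_epoch_end(outputs, stage='validation')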