diff --git a/.vscode/launch.json b/.vscode/launch.json
index 19f56094..89fd528a 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -163,7 +163,7 @@
             "request": "launch",
             "program": "${cwd}/scripts/main.py",
             "console": "integratedTerminal",
-            "args": ["--no-search", "--algos", "manual"]
+            "args": ["--no-search", "--algos", "manual", "--datasets", "imagenet"]
         },
         {
             "name": "Resnet-Full",
diff --git a/archai/common/trainer.py b/archai/common/trainer.py
index 5f986758..5962cbc6 100644
--- a/archai/common/trainer.py
+++ b/archai/common/trainer.py
@@ -42,6 +42,11 @@ class Trainer(EnforceOverrides):
                         if conf_validation else None
         self._metrics:Optional[Metrics] = None
         self._amp = Amp(self._apex)
+
+        self._droppath_module = self._get_droppath_module()
+        if self._droppath_module is None and self._drop_path_prob > 0.0:
+            logger.warn({'droppath_module': None})
+
         self._start_epoch = -1 # nothing is started yet

     def fit(self, train_dl:DataLoader, val_dl:Optional[DataLoader])->Metrics:
@@ -208,8 +213,8 @@ class Trainer(EnforceOverrides):
             logits, aux_logits = self.model(x), None
             tupled_out = isinstance(logits, Tuple) and len(logits) >=2
-            if self._aux_weight:
-                assert tupled_out, "aux_logits cannot be None unless aux tower is disabled"
+            # if self._aux_weight: # TODO: some other way to validate?
+            #     assert tupled_out, "aux_logits cannot be None unless aux tower is disabled"
             if tupled_out: # then we are using model created by desc
                 logits, aux_logits = logits[0], logits[1]
             loss = self.compute_loss(self._lossfn, x, y, logits,
@@ -239,8 +244,16 @@ class Trainer(EnforceOverrides):
             loss += aux_weight * lossfn(aux_logits, y)
         return loss

+    def _get_droppath_module(self)->Optional[nn.Module]:
+        m = self.model
+        if hasattr(self.model, 'module'): # for data parallel model
+            m = self.model.module
+        if hasattr(m, 'drop_path_prob'):
+            return m
+        return None
+
     def _set_drop_path(self, epoch:int, epochs:int)->None:
-        if self._drop_path_prob:
+        if self._drop_path_prob and self._droppath_module is not None:
             drop_prob = self._drop_path_prob * epoch / epochs
             # set value as property in model (it will be used by forward())
             # this is necessory when using DataParallel(model)
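The trainer change above caches the module that owns `drop_path_prob` once at construction time, because `nn.DataParallel` proxies `forward()` but not arbitrary attributes of the wrapped model. Below is a minimal, self-contained sketch of that lookup pattern; `TinyNet` and `get_droppath_module` are illustrative stand-ins, not archai's `Trainer`:

```python
from torch import nn

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)
        self.drop_path_prob = 0.0  # plain attribute read by forward() during training

    def forward(self, x):
        return self.fc(x)

def get_droppath_module(model: nn.Module):
    # mirrors Trainer._get_droppath_module: unwrap DataParallel if present,
    # then check whether the underlying module exposes drop_path_prob
    m = model.module if hasattr(model, 'module') else model
    return m if hasattr(m, 'drop_path_prob') else None

model = nn.DataParallel(TinyNet())
# model.drop_path_prob would raise AttributeError: DataParallel only forwards
# parameters/buffers/submodules, so the unwrap step is required
target = get_droppath_module(model)
assert target is not None
target.drop_path_prob = 0.2 * 5 / 10  # linear schedule: drop_path_prob * epoch / epochs
print(target.drop_path_prob)  # 0.1
```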
diff --git a/archai/nas/evaluate.py b/archai/nas/evaluate.py
index 6c1c9ca3..e5ca8bc8 100644
--- a/archai/nas/evaluate.py
+++ b/archai/nas/evaluate.py
@@ -51,8 +51,23 @@ def eval_arch(conf_eval:Config, cell_builder:Optional[CellBuilder]):

     logger.popd()

+
+def _default_module_name(dataset_name:str, function_name:str)->str:
+    module_name = ''
+    if dataset_name.startswith('cifar'):
+        if function_name.startswith('res'): # support resnext as well
+            module_name = 'archai.cifar10_models.resnet'
+        elif function_name.startswith('dense'):
+            module_name = 'archai.cifar10_models.densenet'
+    elif dataset_name.startswith('imagenet'):
+        module_name = 'torchvision.models'
+    if not module_name:
+        raise NotImplementedError(f'Cannot get default module for {function_name} and dataset {dataset_name} because it is not supported yet')
+    return module_name
+
 def create_model(conf_eval:Config, device)->nn.Module:
     # region conf vars
+    dataset_name = conf_eval['loader']['dataset']['name']
     final_desc_filename = conf_eval['final_desc_filename']
     final_model_factory = conf_eval['final_model_factory']
     full_desc_filename = conf_eval['full_desc_filename']
@@ -65,13 +80,8 @@ def create_model(conf_eval:Config, device)->nn.Module:

     if len(splitted) > 1:
         module_name = splitted[0]
-    else: # to support lazyness while submitting scripts, we do bit of unnecessory smarts
-        if function_name.startswith('res'): # support resnext as well
-            module_name = 'archai.cifar10_models.resnet'
-        elif function_name.startswith('dense'):
-            module_name = 'archai.cifar10_models.densenet'
-        else:
-            module_name = ''
+    else:
+        module_name = _default_module_name(dataset_name, function_name)

     module = importlib.import_module(module_name) if module_name else sys.modules[__name__]
     function = getattr(module, function_name)
diff --git a/confs/algos/manual.yaml b/confs/algos/manual.yaml
index dcd1a4e8..731a28d5 100644
--- a/confs/algos/manual.yaml
+++ b/confs/algos/manual.yaml
@@ -3,7 +3,7 @@ __include__: 'darts.yaml' # just use darts defaults

 nas:
   eval:
-    final_model_factory: 'archai.cifar10_models.resnet.resnet18'
+    final_model_factory: 'resnet18'

     #darts loader/trainer
     loader:
diff --git a/confs/datasets/imagenet.yaml b/confs/datasets/imagenet.yaml
index ac88e09e..8ace806f 100644
--- a/confs/datasets/imagenet.yaml
+++ b/confs/datasets/imagenet.yaml
@@ -2,7 +2,7 @@
 common:
   seed: 0.0

-imagenet_dataset:
+dataset_eval:
   name: 'imagenet'
   dataroot: '~/dataroot' #torchvision data folder
   n_classes: 1000
@@ -21,11 +21,11 @@ nas:
       model_stem1_op: 'stem_conv3x3_s4s2'
       model_post_op: 'pool_avg2d7x7'
       dataset:
-        _copy: '/imagenet_dataset'
+        _copy: '/dataset_eval'
     loader:
       batch: 128
       dataset:
-        _copy: '/imagenet_dataset'
+        _copy: '/dataset_eval'
     trainer:
       aux_weight: 0.4 # weight for loss from auxiliary towers in test time arch
       drop_path_prob: 0.0 # probability that given edge will be dropped
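With `manual.yaml` now specifying just `'resnet18'`, the bare-vs-dotted factory resolution in `create_model` becomes dataset-aware through `_default_module_name`. The sketch below isolates that resolution logic; it follows the diff's module paths but is a simplified stand-in for the archai code, and the commented usage requires torchvision to be installed:

```python
import importlib

def resolve_factory(factory_spec: str, dataset_name: str):
    # a dotted spec ('torchvision.models.resnet18') is used verbatim;
    # a bare name ('resnet18') gets a dataset-dependent default module
    splitted = factory_spec.rsplit('.', 1)
    if len(splitted) > 1:
        module_name, function_name = splitted
    else:
        function_name = factory_spec
        if dataset_name.startswith('cifar') and function_name.startswith('res'):
            module_name = 'archai.cifar10_models.resnet'    # covers resnext too
        elif dataset_name.startswith('cifar') and function_name.startswith('dense'):
            module_name = 'archai.cifar10_models.densenet'
        elif dataset_name.startswith('imagenet'):
            module_name = 'torchvision.models'
        else:
            raise NotImplementedError(
                f'No default module for {function_name} and dataset {dataset_name}')
    return getattr(importlib.import_module(module_name), function_name)

# e.g. manual.yaml's factory combined with the imagenet dataset:
# model_fn = resolve_factory('resnet18', 'imagenet')  # -> torchvision.models.resnet18
# model = model_fn(num_classes=1000)
```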
diff --git a/scripts/datasets/pt_install.py b/scripts/datasets/pt_install.py
index da278a00..fc0f78e7 100644
--- a/scripts/datasets/pt_install.py
+++ b/scripts/datasets/pt_install.py
@@ -29,6 +29,25 @@ def _create_ram_disk(req_ram:int, path:str)->bool:
    # print('RAM disk is not created because not enough memory')
    # return False

+def untar_dataset(pt_data_dir:str, conf_data:Config, dataroot:str)->None:
+    storage_name = conf_data['storage_name']
+    tar_filepath = os.path.join(pt_data_dir, storage_name + '.tar')
+    if not os.path.isfile(tar_filepath):
+        raise RuntimeError(f'Tar file for dataset at {tar_filepath} was not found')
+
+    tar_size = pathlib.Path(tar_filepath).stat().st_size
+    print('tar_filepath:', tar_filepath, 'tar_size:', tar_size)
+
+    local_dataroot = utils.full_path(dataroot)
+    print('local_dataroot:', local_dataroot)
+    _create_ram_disk(tar_size, local_dataroot)
+    # os.makedirs(local_dataroot, exist_ok=True)
+
+    utils.exec_shell_command(f'tar -xf "{tar_filepath}" -C "{local_dataroot}"')
+
+    print(f'dataset copied from {tar_filepath} to {local_dataroot} successfully')
+
+
 def main():
     parser = argparse.ArgumentParser(description='Archai data install')
     parser.add_argument('--dataroot', type=str, default='~/dataroot',
@@ -46,23 +65,12 @@ def main():
     conf_data_filepath = f'confs/datasets/{args.dataset}.yaml'
     print('conf_data_filepath:', conf_data_filepath)

-    conf_data = Config(config_filepath=conf_data_filepath)['dataset']
-    storage_name = conf_data['storage_name']
-    tar_filepath = os.path.join(pt_data_dir, storage_name + '.tar')
-    if not os.path.isfile(tar_filepath):
-        raise RuntimeError(f'Tar file for dataset at {tar_filepath} was not found')
+    conf = Config(config_filepath=conf_data_filepath)
+    for dataset_key in ['dataset', 'dataset_search', 'dataset_eval']:
+        if dataset_key in conf:
+            conf_data = conf[dataset_key]
+            untar_dataset(pt_data_dir, conf_data, args.dataroot)

-    tar_size = pathlib.Path(tar_filepath).stat().st_size
-    print('tar_filepath:', tar_filepath, 'tar_size:', tar_size)
-
-    local_dataroot = utils.full_path(args.dataroot)
-    print('local_dataroot:', local_dataroot)
-    _create_ram_disk(tar_size, local_dataroot)
-    # os.makedirs(local_dataroot, exist_ok=True)
-
-    utils.exec_shell_command(f'tar -xf "{tar_filepath}" -C "{local_dataroot}"')
-
-    print(f'dataset copied from {tar_filepath} to {local_dataroot} sucessfully')

 if __name__ == '__main__':
     main()
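The install-script refactor lets one dataset YAML carry distinct search and eval datasets (as `imagenet.yaml` now does with `dataset_eval`), and each key that is present gets untarred independently. Below is a self-contained approximation of that loop, with a plain dict standing in for archai's `Config`, Python's `tarfile` standing in for the shelled-out `tar -xf`, and made-up sample paths and `storage_name` values:

```python
import os
import tarfile

def untar_dataset(pt_data_dir: str, conf_data: dict, dataroot: str) -> None:
    # same contract as the new untar_dataset in the diff, minus the RAM disk
    tar_filepath = os.path.join(pt_data_dir, conf_data['storage_name'] + '.tar')
    if not os.path.isfile(tar_filepath):
        raise RuntimeError(f'Tar file for dataset at {tar_filepath} was not found')
    with tarfile.open(tar_filepath) as tf:
        tf.extractall(os.path.expanduser(dataroot))

conf = {  # as if parsed from a dataset YAML; values here are invented
    'common': {'seed': 0.0},
    'dataset_eval': {'name': 'imagenet', 'storage_name': 'imagenet'},
}
for dataset_key in ['dataset', 'dataset_search', 'dataset_eval']:
    if dataset_key in conf:  # only untar the sections the YAML actually defines
        untar_dataset('/var/pt_data', conf[dataset_key], '~/dataroot')
```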