imagenet support for handcrafted models

Shital Shah 2020-04-19 21:29:24 -07:00
Parent 8a104c53aa
Commit cec9c16780
6 changed files: 62 additions and 31 deletions

.vscode/launch.json (vendored)
View file

@@ -163,7 +163,7 @@
             "request": "launch",
             "program": "${cwd}/scripts/main.py",
             "console": "integratedTerminal",
-            "args": ["--no-search", "--algos", "manual"]
+            "args": ["--no-search", "--algos", "manual", "--datasets", "imagenet"]
         },
         {
             "name": "Resnet-Full",

View file

@@ -42,6 +42,11 @@ class Trainer(EnforceOverrides):
                         if conf_validation else None
         self._metrics:Optional[Metrics] = None
         self._amp = Amp(self._apex)
+
+        self._droppath_module = self._get_droppath_module()
+        if self._droppath_module is None and self._drop_path_prob > 0.0:
+            logger.warn({'droppath_module': None})
+
         self._start_epoch = -1 # nothing is started yet

     def fit(self, train_dl:DataLoader, val_dl:Optional[DataLoader])->Metrics:
@@ -208,8 +213,8 @@ class Trainer(EnforceOverrides):
             logits, aux_logits = self.model(x), None
         tupled_out = isinstance(logits, Tuple) and len(logits) >=2
-        if self._aux_weight:
-            assert tupled_out, "aux_logits cannot be None unless aux tower is disabled"
+        # if self._aux_weight: # TODO: some other way to validate?
+        #     assert tupled_out, "aux_logits cannot be None unless aux tower is disabled"
         if tupled_out: # then we are using model created by desc
             logits, aux_logits = logits[0], logits[1]
         loss = self.compute_loss(self._lossfn, x, y, logits,
@@ -239,8 +244,16 @@ class Trainer(EnforceOverrides):
             loss += aux_weight * lossfn(aux_logits, y)
         return loss

+    def _get_droppath_module(self)->Optional[nn.Module]:
+        m = self.model
+        if hasattr(self.model, 'module'): # for data parallel model
+            m = self.model.module
+        if hasattr(m, 'drop_path_prob'):
+            return m
+        return None
+
     def _set_drop_path(self, epoch:int, epochs:int)->None:
-        if self._drop_path_prob:
+        if self._drop_path_prob and self._droppath_module is not None:
             drop_prob = self._drop_path_prob * epoch / epochs
             # set value as property in model (it will be used by forward())
             # this is necessary when using DataParallel(model)
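
The new _get_droppath_module exists because nn.DataParallel wraps the user's network and exposes it as a .module attribute, so drop_path_prob must be set on the inner module for forward() to see it. A minimal sketch of the pattern (the Net class and the probability value are hypothetical, for illustration only):

    import torch.nn as nn

    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.drop_path_prob = 0.0  # read inside forward() on every call

        def forward(self, x):
            return x  # placeholder body

    model = nn.DataParallel(Net())
    # setting the attribute on the wrapper would not reach Net.forward();
    # unwrap first, which is exactly what _get_droppath_module does
    m = model.module if hasattr(model, 'module') else model
    if hasattr(m, 'drop_path_prob'):
        m.drop_path_prob = 0.1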

View file

@@ -51,8 +51,23 @@ def eval_arch(conf_eval:Config, cell_builder:Optional[CellBuilder]):
     logger.popd()

+def _default_module_name(dataset_name:str, function_name:str)->str:
+    module_name = ''
+    if dataset_name.startswith('cifar'):
+        if function_name.startswith('res'): # support resnext as well
+            module_name = 'archai.cifar10_models.resnet'
+        elif function_name.startswith('dense'):
+            module_name = 'archai.cifar10_models.densenet'
+    elif dataset_name.startswith('imagenet'):
+        module_name = 'torchvision.models'
+    if not module_name:
+        raise NotImplementedError(f'Cannot get default module for {function_name} and dataset {dataset_name} because it is not supported yet')
+    return module_name
+
 def create_model(conf_eval:Config, device)->nn.Module:
     # region conf vars
+    dataset_name = conf_eval['loader']['dataset']['name']
     final_desc_filename = conf_eval['final_desc_filename']
     final_model_factory = conf_eval['final_model_factory']
     full_desc_filename = conf_eval['full_desc_filename']
@@ -65,13 +80,8 @@ def create_model(conf_eval:Config, device)->nn.Module:
     if len(splitted) > 1:
         module_name = splitted[0]
-    else: # to support lazyness while submitting scripts, we do bit of unnecessory smarts
-        if function_name.startswith('res'): # support resnext as well
-            module_name = 'archai.cifar10_models.resnet'
-        elif function_name.startswith('dense'):
-            module_name = 'archai.cifar10_models.densenet'
-        else:
-            module_name = ''
+    else:
+        module_name = _default_module_name(dataset_name, function_name)

     module = importlib.import_module(module_name) if module_name else sys.modules[__name__]
     function = getattr(module, function_name)
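
Net effect: final_model_factory may now be a bare function name, with the module inferred from the dataset. A simplified sketch of the resolution path (illustrative only; the cifar branch here collapses the resnet/densenet dispatch that _default_module_name performs):

    import importlib

    def resolve_factory(final_model_factory: str, dataset_name: str):
        # 'pkg.mod.fn' -> ('pkg.mod', 'fn'); bare 'fn' -> infer module from dataset
        module_name, _, function_name = final_model_factory.rpartition('.')
        if not module_name:
            if dataset_name.startswith('imagenet'):
                module_name = 'torchvision.models'
            elif dataset_name.startswith('cifar'):
                module_name = 'archai.cifar10_models.resnet'  # resnet-style names
            else:
                raise NotImplementedError(dataset_name)
        return getattr(importlib.import_module(module_name), function_name)

    model_fn = resolve_factory('resnet18', 'imagenet')  # -> torchvision.models.resnet18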

View file

@@ -3,7 +3,7 @@ __include__: 'darts.yaml' # just use darts defaults

 nas:
   eval:
-    final_model_factory: 'archai.cifar10_models.resnet.resnet18'
+    final_model_factory: 'resnet18'

     #darts loader/trainer
     loader:

View file

@@ -2,7 +2,7 @@
 common:
   seed: 0.0

-imagenet_dataset:
+dataset_eval:
   name: 'imagenet'
   dataroot: '~/dataroot' #torchvision data folder
   n_classes: 1000

@@ -21,11 +21,11 @@ nas:
       model_stem1_op: 'stem_conv3x3_s4s2'
       model_post_op: 'pool_avg2d7x7'
     dataset:
-      _copy: '/imagenet_dataset'
+      _copy: '/dataset_eval'
   loader:
     batch: 128
     dataset:
-      _copy: '/imagenet_dataset'
+      _copy: '/dataset_eval'
   trainer:
     aux_weight: 0.4 # weight for loss from auxiliary towers in test time arch
     drop_path_prob: 0.0 # probability that given edge will be dropped
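
Renaming imagenet_dataset to dataset_eval matters beyond the _copy targets: the install script below scans for the standard keys dataset/dataset_search/dataset_eval. The _copy directive substitutes a node with a copy of the node at the given path; a minimal sketch of that behavior (simplified; the real Config class in archai may handle more cases, such as relative paths or per-key overrides):

    def resolve_copies(root: dict, node=None):
        # replace {'_copy': '/a/b'} nodes with a copy of the node at that path
        node = root if node is None else node
        for key, val in list(node.items()):
            if isinstance(val, dict):
                if list(val.keys()) == ['_copy']:
                    target = root
                    for part in val['_copy'].strip('/').split('/'):
                        target = target[part]
                    node[key] = dict(target)
                else:
                    resolve_copies(root, val)
        return root

    conf = {'dataset_eval': {'name': 'imagenet', 'n_classes': 1000},
            'loader': {'dataset': {'_copy': '/dataset_eval'}}}
    resolve_copies(conf)
    assert conf['loader']['dataset']['name'] == 'imagenet'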

View file

@@ -29,6 +29,25 @@ def _create_ram_disk(req_ram:int, path:str)->bool:
     # print('RAM disk is not created because not enough memory')
     # return False

+def untar_dataset(pt_data_dir:str, conf_data:Config, dataroot:str)->None:
+    storage_name = conf_data['storage_name']
+    tar_filepath = os.path.join(pt_data_dir, storage_name + '.tar')
+    if not os.path.isfile(tar_filepath):
+        raise RuntimeError(f'Tar file for dataset at {tar_filepath} was not found')
+    tar_size = pathlib.Path(tar_filepath).stat().st_size
+    print('tar_filepath:', tar_filepath, 'tar_size:', tar_size)
+
+    local_dataroot = utils.full_path(dataroot)
+    print('local_dataroot:', local_dataroot)
+    _create_ram_disk(tar_size, local_dataroot)
+    # os.makedirs(local_dataroot, exist_ok=True)
+
+    utils.exec_shell_command(f'tar -xf "{tar_filepath}" -C "{local_dataroot}"')
+    print(f'dataset copied from {tar_filepath} to {local_dataroot} successfully')
+
 def main():
     parser = argparse.ArgumentParser(description='Archai data install')
     parser.add_argument('--dataroot', type=str, default='~/dataroot',
@@ -46,23 +65,12 @@ def main():
     conf_data_filepath = f'confs/datasets/{args.dataset}.yaml'
     print('conf_data_filepath:', conf_data_filepath)

-    conf_data = Config(config_filepath=conf_data_filepath)['dataset']
-    storage_name = conf_data['storage_name']
-    tar_filepath = os.path.join(pt_data_dir, storage_name + '.tar')
-    if not os.path.isfile(tar_filepath):
-        raise RuntimeError(f'Tar file for dataset at {tar_filepath} was not found')
-    tar_size = pathlib.Path(tar_filepath).stat().st_size
-    print('tar_filepath:', tar_filepath, 'tar_size:', tar_size)
-    local_dataroot = utils.full_path(args.dataroot)
-    print('local_dataroot:', local_dataroot)
-    _create_ram_disk(tar_size, local_dataroot)
-    # os.makedirs(local_dataroot, exist_ok=True)
-    utils.exec_shell_command(f'tar -xf "{tar_filepath}" -C "{local_dataroot}"')
-    print(f'dataset copied from {tar_filepath} to {local_dataroot} sucessfully')
+    conf = Config(config_filepath=conf_data_filepath)
+    for dataset_key in ['dataset', 'dataset_search', 'dataset_eval']:
+        if dataset_key in conf:
+            conf_data = conf[dataset_key]
+            untar_dataset(pt_data_dir, conf_data, args.dataroot)
 if __name__ == '__main__':
     main()
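
For context, _create_ram_disk (whose body is elided above) prepares a RAM-backed dataroot so the tar extraction lands in memory. A hedged sketch of such a helper, assuming Linux tmpfs, sudo access, and psutil for the free-memory check (none of which this diff confirms):

    import os
    import subprocess

    import psutil  # assumption: used only to check available memory

    def create_ram_disk(req_bytes: int, path: str) -> bool:
        # refuse if the tar would not comfortably fit in free RAM
        if psutil.virtual_memory().available < 2 * req_bytes:
            print('RAM disk is not created because not enough memory')
            return False
        os.makedirs(path, exist_ok=True)
        # mount a tmpfs sized to the request at the target path (Linux, needs sudo)
        subprocess.run(['sudo', 'mount', '-t', 'tmpfs',
                        '-o', f'size={req_bytes}', 'tmpfs', path], check=True)
        return True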