Mirror of https://github.com/microsoft/archai.git

imagenet support for handcrafted models

This commit is contained in:
Parent: 8a104c53aa
Commit: cec9c16780
@@ -163,7 +163,7 @@
         "request": "launch",
         "program": "${cwd}/scripts/main.py",
         "console": "integratedTerminal",
-        "args": ["--no-search", "--algos", "manual"]
+        "args": ["--no-search", "--algos", "manual", "--datasets", "imagenet"]
     },
     {
         "name": "Resnet-Full",
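For reference, the launch entry above is equivalent to invoking the script directly:

    python scripts/main.py --no-search --algos manual --datasets imagenet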
@@ -42,6 +42,11 @@ class Trainer(EnforceOverrides):
                         if conf_validation else None
         self._metrics:Optional[Metrics] = None
         self._amp = Amp(self._apex)
+
+        self._droppath_module = self._get_droppath_module()
+        if self._droppath_module is None and self._drop_path_prob > 0.0:
+            logger.warn({'droppath_module': None})
+
         self._start_epoch = -1 # nothing is started yet

     def fit(self, train_dl:DataLoader, val_dl:Optional[DataLoader])->Metrics:
@@ -208,8 +213,8 @@ class Trainer(EnforceOverrides):
         logits, aux_logits = self.model(x), None
         tupled_out = isinstance(logits, Tuple) and len(logits) >= 2
-        if self._aux_weight:
-            assert tupled_out, "aux_logits cannot be None unless aux tower is disabled"
+        # if self._aux_weight: # TODO: some other way to validate?
+        #     assert tupled_out, "aux_logits cannot be None unless aux tower is disabled"
         if tupled_out: # then we are using model created by desc
             logits, aux_logits = logits[0], logits[1]
         loss = self.compute_loss(self._lossfn, x, y, logits,
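For context on the tupled output: when the model was built from a description with an aux tower, forward() returns (logits, aux_logits) and the auxiliary term is folded into the loss in the next hunk. A minimal sketch of that combination (names mirror the trainer's locals; this is an illustration, not the trainer's full compute_loss):

    def combined_loss(lossfn, logits, aux_logits, y, aux_weight):
        # main classification loss on the final logits
        loss = lossfn(logits, y)
        # add the weighted auxiliary-tower loss when one was produced
        if aux_weight and aux_logits is not None:
            loss += aux_weight * lossfn(aux_logits, y)
        return loss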
@@ -239,8 +244,16 @@ class Trainer(EnforceOverrides):
             loss += aux_weight * lossfn(aux_logits, y)
         return loss

+    def _get_droppath_module(self)->Optional[nn.Module]:
+        m = self.model
+        if hasattr(self.model, 'module'): # for data parallel model
+            m = self.model.module
+        if hasattr(m, 'drop_path_prob'):
+            return m
+        return None
+
     def _set_drop_path(self, epoch:int, epochs:int)->None:
-        if self._drop_path_prob:
+        if self._drop_path_prob and self._droppath_module is not None:
             drop_prob = self._drop_path_prob * epoch / epochs
             # set value as property in model (it will be used by forward())
             # this is necessory when using DataParallel(model)
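The drop path schedule is a linear ramp: zero at epoch 0, approaching drop_path_prob by the final epoch. A minimal sketch of the attribute handshake that _get_droppath_module relies on, with hypothetical names (ToyModel is illustrative, not an archai class):

    import torch.nn as nn

    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.drop_path_prob = 0.0  # the attribute _get_droppath_module looks for

    model = nn.DataParallel(ToyModel())
    m = model.module if hasattr(model, 'module') else model  # unwrap DataParallel

    drop_path_prob, epochs = 0.2, 100
    for epoch in range(epochs):
        m.drop_path_prob = drop_path_prob * epoch / epochs  # 0.0 ramping toward 0.2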
@@ -51,8 +51,23 @@ def eval_arch(conf_eval:Config, cell_builder:Optional[CellBuilder]):

     logger.popd()


+def _default_module_name(dataset_name:str, function_name:str)->str:
+    module_name = ''
+    if dataset_name.startswith('cifar'):
+        if function_name.startswith('res'): # support resnext as well
+            module_name = 'archai.cifar10_models.resnet'
+        elif function_name.startswith('dense'):
+            module_name = 'archai.cifar10_models.densenet'
+    elif dataset_name.startswith('imagenet'):
+        module_name = 'torchvision.models'
+    if not module_name:
+        raise NotImplementedError(f'Cannot get default module for {function_name} and dataset {dataset_name} because it is not supported yet')
+    return module_name
+
 def create_model(conf_eval:Config, device)->nn.Module:
     # region conf vars
     dataset_name = conf_eval['loader']['dataset']['name']
     final_desc_filename = conf_eval['final_desc_filename']
     final_model_factory = conf_eval['final_model_factory']
     full_desc_filename = conf_eval['full_desc_filename']
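The lookup keys off dataset and function-name prefixes, so for example:

    _default_module_name('cifar10', 'resnet34')      # -> 'archai.cifar10_models.resnet'
    _default_module_name('cifar100', 'densenet121')  # -> 'archai.cifar10_models.densenet'
    _default_module_name('imagenet', 'resnet18')     # -> 'torchvision.models'
    _default_module_name('svhn', 'resnet18')         # raises NotImplementedError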
@@ -65,13 +80,8 @@ def create_model(conf_eval:Config, device)->nn.Module:

     if len(splitted) > 1:
         module_name = splitted[0]
-    else: # to support lazyness while submitting scripts, we do bit of unnecessory smarts
-        if function_name.startswith('res'): # support resnext as well
-            module_name = 'archai.cifar10_models.resnet'
-        elif function_name.startswith('dense'):
-            module_name = 'archai.cifar10_models.densenet'
-        else:
-            module_name = ''
+    else:
+        module_name = _default_module_name(dataset_name, function_name)

     module = importlib.import_module(module_name) if module_name else sys.modules[__name__]
     function = getattr(module, function_name)
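The import-and-getattr pattern then yields the model constructor; a minimal standalone sketch of the same resolution (assuming torchvision is installed):

    import importlib

    module = importlib.import_module('torchvision.models')
    function = getattr(module, 'resnet18')   # torchvision.models.resnet18
    model = function()                       # build the network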
@@ -3,7 +3,7 @@ __include__: 'darts.yaml' # just use darts defaults

 nas:
   eval:
-    final_model_factory: 'archai.cifar10_models.resnet.resnet18'
+    final_model_factory: 'resnet18'

     #darts loader/trainer
     loader:
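With the dataset-aware default lookup added above, the bare 'resnet18' factory name resolves to archai.cifar10_models.resnet.resnet18 on cifar runs and to torchvision.models.resnet18 on imagenet runs, so this config no longer hard-codes the cifar module path.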
@@ -2,7 +2,7 @@
 common:
   seed: 0.0

-imagenet_dataset:
+dataset_eval:
   name: 'imagenet'
   dataroot: '~/dataroot' #torchvision data folder
   n_classes: 1000
@@ -21,11 +21,11 @@ nas:
       model_stem1_op: 'stem_conv3x3_s4s2'
       model_post_op: 'pool_avg2d7x7'
       dataset:
-        _copy: '/imagenet_dataset'
+        _copy: '/dataset_eval'
     loader:
       batch: 128
       dataset:
-        _copy: '/imagenet_dataset'
+        _copy: '/dataset_eval'
     trainer:
       aux_weight: 0.4 # weight for loss from auxiliary towers in test time arch
       drop_path_prob: 0.0 # probability that given edge will be dropped
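The _copy entries reference the top-level dataset_eval node renamed in the previous hunk, keeping the model's dataset and the loader's dataset in sync. A rough sketch of how such an absolute-path reference could be resolved (illustrative only, not archai's actual Config implementation):

    def resolve_copy(root: dict, path: str):
        # '/dataset_eval' -> root['dataset_eval']; deeper paths walk key by key
        node = root
        for key in path.strip('/').split('/'):
            node = node[key]
        return node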
@@ -29,6 +29,25 @@ def _create_ram_disk(req_ram:int, path:str)->bool:
     # print('RAM disk is not created because not enough memory')
     # return False

+def untar_dataset(pt_data_dir:str, conf_data:Config, dataroot:str)->None:
+    storage_name = conf_data['storage_name']
+    tar_filepath = os.path.join(pt_data_dir, storage_name + '.tar')
+    if not os.path.isfile(tar_filepath):
+        raise RuntimeError(f'Tar file for dataset at {tar_filepath} was not found')
+
+    tar_size = pathlib.Path(tar_filepath).stat().st_size
+    print('tar_filepath:', tar_filepath, 'tar_size:', tar_size)
+
+    local_dataroot = utils.full_path(dataroot)
+    print('local_dataroot:', local_dataroot)
+    _create_ram_disk(tar_size, local_dataroot)
+    # os.makedirs(local_dataroot, exist_ok=True)
+
+    utils.exec_shell_command(f'tar -xf "{tar_filepath}" -C "{local_dataroot}"')
+
+    print(f'dataset copied from {tar_filepath} to {local_dataroot} sucessfully')
+
+
 def main():
     parser = argparse.ArgumentParser(description='Archai data install')
     parser.add_argument('--dataroot', type=str, default='~/dataroot',
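untar_dataset factors out the extraction steps that main() previously inlined (removed in the next hunk). A hypothetical call, assuming the .tar files live under ~/dataroot_pt (both paths here are illustrative):

    conf = Config(config_filepath='confs/datasets/imagenet.yaml')
    untar_dataset('~/dataroot_pt', conf['dataset_eval'], '~/dataroot')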
@@ -46,23 +65,12 @@ def main():
     conf_data_filepath = f'confs/datasets/{args.dataset}.yaml'
     print('conf_data_filepath:', conf_data_filepath)

-    conf_data = Config(config_filepath=conf_data_filepath)['dataset']
-    storage_name = conf_data['storage_name']
-    tar_filepath = os.path.join(pt_data_dir, storage_name + '.tar')
-    if not os.path.isfile(tar_filepath):
-        raise RuntimeError(f'Tar file for dataset at {tar_filepath} was not found')
-
-    tar_size = pathlib.Path(tar_filepath).stat().st_size
-    print('tar_filepath:', tar_filepath, 'tar_size:', tar_size)
-
-    local_dataroot = utils.full_path(args.dataroot)
-    print('local_dataroot:', local_dataroot)
-    _create_ram_disk(tar_size, local_dataroot)
-    # os.makedirs(local_dataroot, exist_ok=True)
-
-    utils.exec_shell_command(f'tar -xf "{tar_filepath}" -C "{local_dataroot}"')
-
-    print(f'dataset copied from {tar_filepath} to {local_dataroot} sucessfully')
+    conf = Config(config_filepath=conf_data_filepath)
+    for dataset_key in ['dataset', 'dataset_search', 'dataset_eval']:
+        if dataset_key in conf:
+            conf_data = conf[dataset_key]
+            untar_dataset(pt_data_dir, conf_data, args.dataroot)

 if __name__ == '__main__':
     main()
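Taken together, the installer now reads every dataset section present in confs/datasets/<name>.yaml (dataset, dataset_search, dataset_eval) and untars each one, so a single run with --dataset imagenet stages all copies the config defines.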