зеркало из https://github.com/microsoft/archai.git
imagenet support for handcrafted models
This commit is contained in:
Родитель
8a104c53aa
Коммит
cec9c16780
|
@ -163,7 +163,7 @@
|
||||||
"request": "launch",
|
"request": "launch",
|
||||||
"program": "${cwd}/scripts/main.py",
|
"program": "${cwd}/scripts/main.py",
|
||||||
"console": "integratedTerminal",
|
"console": "integratedTerminal",
|
||||||
"args": ["--no-search", "--algos", "manual"]
|
"args": ["--no-search", "--algos", "manual", "--datasets", "imagenet"]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "Resnet-Full",
|
"name": "Resnet-Full",
|
||||||
|
|
|
@ -42,6 +42,11 @@ class Trainer(EnforceOverrides):
|
||||||
if conf_validation else None
|
if conf_validation else None
|
||||||
self._metrics:Optional[Metrics] = None
|
self._metrics:Optional[Metrics] = None
|
||||||
self._amp = Amp(self._apex)
|
self._amp = Amp(self._apex)
|
||||||
|
|
||||||
|
self._droppath_module = self._get_droppath_module()
|
||||||
|
if self._droppath_module is None and self._drop_path_prob > 0.0:
|
||||||
|
logger.warn({'droppath_module': None})
|
||||||
|
|
||||||
self._start_epoch = -1 # nothing is started yet
|
self._start_epoch = -1 # nothing is started yet
|
||||||
|
|
||||||
def fit(self, train_dl:DataLoader, val_dl:Optional[DataLoader])->Metrics:
|
def fit(self, train_dl:DataLoader, val_dl:Optional[DataLoader])->Metrics:
|
||||||
|
@ -208,8 +213,8 @@ class Trainer(EnforceOverrides):
|
||||||
|
|
||||||
logits, aux_logits = self.model(x), None
|
logits, aux_logits = self.model(x), None
|
||||||
tupled_out = isinstance(logits, Tuple) and len(logits) >=2
|
tupled_out = isinstance(logits, Tuple) and len(logits) >=2
|
||||||
if self._aux_weight:
|
# if self._aux_weight: # TODO: some other way to validate?
|
||||||
assert tupled_out, "aux_logits cannot be None unless aux tower is disabled"
|
# assert tupled_out, "aux_logits cannot be None unless aux tower is disabled"
|
||||||
if tupled_out: # then we are using model created by desc
|
if tupled_out: # then we are using model created by desc
|
||||||
logits, aux_logits = logits[0], logits[1]
|
logits, aux_logits = logits[0], logits[1]
|
||||||
loss = self.compute_loss(self._lossfn, x, y, logits,
|
loss = self.compute_loss(self._lossfn, x, y, logits,
|
||||||
|
@ -239,8 +244,16 @@ class Trainer(EnforceOverrides):
|
||||||
loss += aux_weight * lossfn(aux_logits, y)
|
loss += aux_weight * lossfn(aux_logits, y)
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
def _get_droppath_module(self)->Optional[nn.Module]:
|
||||||
|
m = self.model
|
||||||
|
if hasattr(self.model, 'module'): # for data parallel model
|
||||||
|
m = self.model.module
|
||||||
|
if hasattr(m, 'drop_path_prob'):
|
||||||
|
return m
|
||||||
|
return None
|
||||||
|
|
||||||
def _set_drop_path(self, epoch:int, epochs:int)->None:
|
def _set_drop_path(self, epoch:int, epochs:int)->None:
|
||||||
if self._drop_path_prob:
|
if self._drop_path_prob and self._droppath_module is not None:
|
||||||
drop_prob = self._drop_path_prob * epoch / epochs
|
drop_prob = self._drop_path_prob * epoch / epochs
|
||||||
# set value as property in model (it will be used by forward())
|
# set value as property in model (it will be used by forward())
|
||||||
# this is necessory when using DataParallel(model)
|
# this is necessory when using DataParallel(model)
|
||||||
|
|
|
@ -51,8 +51,23 @@ def eval_arch(conf_eval:Config, cell_builder:Optional[CellBuilder]):
|
||||||
|
|
||||||
logger.popd()
|
logger.popd()
|
||||||
|
|
||||||
|
|
||||||
|
def _default_module_name(dataset_name:str, function_name:str)->str:
|
||||||
|
module_name = ''
|
||||||
|
if dataset_name.startswith('cifar'):
|
||||||
|
if function_name.startswith('res'): # support resnext as well
|
||||||
|
module_name = 'archai.cifar10_models.resnet'
|
||||||
|
elif function_name.startswith('dense'):
|
||||||
|
module_name = 'archai.cifar10_models.densenet'
|
||||||
|
elif dataset_name.startswith('imagenet'):
|
||||||
|
module_name = 'torchvision.models'
|
||||||
|
if not module_name:
|
||||||
|
raise NotImplementedError(f'Cannot get default module for {function_name} and dataset {dataset_name} because it is not supported yet')
|
||||||
|
return module_name
|
||||||
|
|
||||||
def create_model(conf_eval:Config, device)->nn.Module:
|
def create_model(conf_eval:Config, device)->nn.Module:
|
||||||
# region conf vars
|
# region conf vars
|
||||||
|
dataset_name = conf_eval['loader']['dataset']['name']
|
||||||
final_desc_filename = conf_eval['final_desc_filename']
|
final_desc_filename = conf_eval['final_desc_filename']
|
||||||
final_model_factory = conf_eval['final_model_factory']
|
final_model_factory = conf_eval['final_model_factory']
|
||||||
full_desc_filename = conf_eval['full_desc_filename']
|
full_desc_filename = conf_eval['full_desc_filename']
|
||||||
|
@ -65,13 +80,8 @@ def create_model(conf_eval:Config, device)->nn.Module:
|
||||||
|
|
||||||
if len(splitted) > 1:
|
if len(splitted) > 1:
|
||||||
module_name = splitted[0]
|
module_name = splitted[0]
|
||||||
else: # to support lazyness while submitting scripts, we do bit of unnecessory smarts
|
else:
|
||||||
if function_name.startswith('res'): # support resnext as well
|
module_name = _default_module_name(dataset_name, function_name)
|
||||||
module_name = 'archai.cifar10_models.resnet'
|
|
||||||
elif function_name.startswith('dense'):
|
|
||||||
module_name = 'archai.cifar10_models.densenet'
|
|
||||||
else:
|
|
||||||
module_name = ''
|
|
||||||
|
|
||||||
module = importlib.import_module(module_name) if module_name else sys.modules[__name__]
|
module = importlib.import_module(module_name) if module_name else sys.modules[__name__]
|
||||||
function = getattr(module, function_name)
|
function = getattr(module, function_name)
|
||||||
|
|
|
@ -3,7 +3,7 @@ __include__: 'darts.yaml' # just use darts defaults
|
||||||
|
|
||||||
nas:
|
nas:
|
||||||
eval:
|
eval:
|
||||||
final_model_factory: 'archai.cifar10_models.resnet.resnet18'
|
final_model_factory: 'resnet18'
|
||||||
|
|
||||||
#darts loader/trainer
|
#darts loader/trainer
|
||||||
loader:
|
loader:
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
common:
|
common:
|
||||||
seed: 0.0
|
seed: 0.0
|
||||||
|
|
||||||
imagenet_dataset:
|
dataset_eval:
|
||||||
name: 'imagenet'
|
name: 'imagenet'
|
||||||
dataroot: '~/dataroot' #torchvision data folder
|
dataroot: '~/dataroot' #torchvision data folder
|
||||||
n_classes: 1000
|
n_classes: 1000
|
||||||
|
@ -21,11 +21,11 @@ nas:
|
||||||
model_stem1_op: 'stem_conv3x3_s4s2'
|
model_stem1_op: 'stem_conv3x3_s4s2'
|
||||||
model_post_op: 'pool_avg2d7x7'
|
model_post_op: 'pool_avg2d7x7'
|
||||||
dataset:
|
dataset:
|
||||||
_copy: '/imagenet_dataset'
|
_copy: '/dataset_eval'
|
||||||
loader:
|
loader:
|
||||||
batch: 128
|
batch: 128
|
||||||
dataset:
|
dataset:
|
||||||
_copy: '/imagenet_dataset'
|
_copy: '/dataset_eval'
|
||||||
trainer:
|
trainer:
|
||||||
aux_weight: 0.4 # weight for loss from auxiliary towers in test time arch
|
aux_weight: 0.4 # weight for loss from auxiliary towers in test time arch
|
||||||
drop_path_prob: 0.0 # probability that given edge will be dropped
|
drop_path_prob: 0.0 # probability that given edge will be dropped
|
||||||
|
|
|
@ -29,6 +29,25 @@ def _create_ram_disk(req_ram:int, path:str)->bool:
|
||||||
# print('RAM disk is not created because not enough memory')
|
# print('RAM disk is not created because not enough memory')
|
||||||
# return False
|
# return False
|
||||||
|
|
||||||
|
def untar_dataset(pt_data_dir:str, conf_data:Config, dataroot:str)->None:
|
||||||
|
storage_name = conf_data['storage_name']
|
||||||
|
tar_filepath = os.path.join(pt_data_dir, storage_name + '.tar')
|
||||||
|
if not os.path.isfile(tar_filepath):
|
||||||
|
raise RuntimeError(f'Tar file for dataset at {tar_filepath} was not found')
|
||||||
|
|
||||||
|
tar_size = pathlib.Path(tar_filepath).stat().st_size
|
||||||
|
print('tar_filepath:', tar_filepath, 'tar_size:', tar_size)
|
||||||
|
|
||||||
|
local_dataroot = utils.full_path(dataroot)
|
||||||
|
print('local_dataroot:', local_dataroot)
|
||||||
|
_create_ram_disk(tar_size, local_dataroot)
|
||||||
|
# os.makedirs(local_dataroot, exist_ok=True)
|
||||||
|
|
||||||
|
utils.exec_shell_command(f'tar -xf "{tar_filepath}" -C "{local_dataroot}"')
|
||||||
|
|
||||||
|
print(f'dataset copied from {tar_filepath} to {local_dataroot} sucessfully')
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(description='Archai data install')
|
parser = argparse.ArgumentParser(description='Archai data install')
|
||||||
parser.add_argument('--dataroot', type=str, default='~/dataroot',
|
parser.add_argument('--dataroot', type=str, default='~/dataroot',
|
||||||
|
@ -46,23 +65,12 @@ def main():
|
||||||
conf_data_filepath = f'confs/datasets/{args.dataset}.yaml'
|
conf_data_filepath = f'confs/datasets/{args.dataset}.yaml'
|
||||||
print('conf_data_filepath:', conf_data_filepath)
|
print('conf_data_filepath:', conf_data_filepath)
|
||||||
|
|
||||||
conf_data = Config(config_filepath=conf_data_filepath)['dataset']
|
conf = Config(config_filepath=conf_data_filepath)
|
||||||
storage_name = conf_data['storage_name']
|
for dataset_key in ['dataset', 'dataset_search', 'dataset_eval']:
|
||||||
tar_filepath = os.path.join(pt_data_dir, storage_name + '.tar')
|
if dataset_key in conf:
|
||||||
if not os.path.isfile(tar_filepath):
|
conf_data = conf[dataset_key]
|
||||||
raise RuntimeError(f'Tar file for dataset at {tar_filepath} was not found')
|
untar_dataset(pt_data_dir, conf_data, args.dataroot)
|
||||||
|
|
||||||
tar_size = pathlib.Path(tar_filepath).stat().st_size
|
|
||||||
print('tar_filepath:', tar_filepath, 'tar_size:', tar_size)
|
|
||||||
|
|
||||||
local_dataroot = utils.full_path(args.dataroot)
|
|
||||||
print('local_dataroot:', local_dataroot)
|
|
||||||
_create_ram_disk(tar_size, local_dataroot)
|
|
||||||
# os.makedirs(local_dataroot, exist_ok=True)
|
|
||||||
|
|
||||||
utils.exec_shell_command(f'tar -xf "{tar_filepath}" -C "{local_dataroot}"')
|
|
||||||
|
|
||||||
print(f'dataset copied from {tar_filepath} to {local_dataroot} sucessfully')
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
|
Загрузка…
Ссылка в новой задаче