resnet manual training, add .contiguous() for PyTorch 1.7

This commit is contained in:
Shital Shah 2021-01-08 08:14:35 -08:00 коммит произвёл Gustavo Rosa
Родитель e06c016b25
Коммит a95f2627b7
5 изменённых файлов: 25 добавлений и 15 удалений

Просмотреть файл

@ -172,7 +172,7 @@ def common_init(config_filepath: Optional[str]=None,
_create_sysinfo(conf)
# create a[ex to know distributed processing paramters
# create apex to know distributed processing paramters
conf_apex = get_conf_common(conf)['apex']
apex = ApexUtils(conf_apex, logger=logger)

Просмотреть файл

@ -207,7 +207,7 @@ def accuracy(output, target, topk=(1,)):
res = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0)
correct_k = correct[:k].contiguous().view(-1).float().sum(0)
res.append(correct_k.mul_(1.0 / batch_size))
return res

Просмотреть файл

@ -41,6 +41,8 @@ class Trainer(EnforceOverrides):
self._validation_freq = 0 if conf_validation is None else conf_validation['freq']
# endregion
logger.pushd(self._title + '__init__')
self._apex = ApexUtils(conf_apex, logger)
self._checkpoint = checkpoint
@ -59,6 +61,8 @@ class Trainer(EnforceOverrides):
self._start_epoch = -1 # nothing is started yet
logger.popd()
def fit(self, train_dl:DataLoader, val_dl:Optional[DataLoader])->Metrics:
logger.pushd(self._title)

18
confs/algos/resnet.yaml Normal file
Просмотреть файл

@ -0,0 +1,18 @@
__include__: 'darts.yaml' # just use darts defaults
common:
experiment_name: 'MyResnet'
nas:
eval:
loader:
train_batch: 128
test_batch: 4096
cutout: 0
drop_path_prob: 0.0
grad_clip: 0.0
aux_weight: 0.0
trainer:
epochs: 10
logger_freq: 1

Просмотреть файл

@ -12,19 +12,8 @@ from archai.common.common import logger, common_init
from archai.datasets import data
def train_test(conf_eval:Config):
# region conf vars
conf_loader = conf_eval['loader']
conf_trainer = conf_eval['trainer']
# endregion
conf_trainer['validation']['freq']=1
conf_trainer['epochs'] = 10
conf_loader['train_batch'] = 128
conf_loader['test_batch'] = 4096
conf_loader['cutout'] = 0
conf_trainer['drop_path_prob'] = 0.0
conf_trainer['grad_clip'] = 0.0
conf_trainer['aux_weight'] = 0.0
Net = cifar10_models.resnet34
model = Net().to(torch.device('cuda', 0))
@ -38,8 +27,7 @@ def train_test(conf_eval:Config):
if __name__ == '__main__':
conf = common_init(config_filepath='confs/algos/darts.yaml',
param_args=['--common.experiment_name', 'restnet_test'])
conf = common_init(config_filepath='confs/algos/resnet.yaml')
conf_eval = conf['nas']['eval']