зеркало из https://github.com/microsoft/archai.git
resnet manual training, add .contiguous() for PyTorch 1.7
This commit is contained in:
Родитель
e06c016b25
Коммит
a95f2627b7
|
@ -172,7 +172,7 @@ def common_init(config_filepath: Optional[str]=None,
|
|||
|
||||
_create_sysinfo(conf)
|
||||
|
||||
# create a[ex to know distributed processing paramters
|
||||
# create apex to know distributed processing parameters
|
||||
conf_apex = get_conf_common(conf)['apex']
|
||||
apex = ApexUtils(conf_apex, logger=logger)
|
||||
|
||||
|
|
|
@ -207,7 +207,7 @@ def accuracy(output, target, topk=(1,)):
|
|||
|
||||
res = []
|
||||
for k in topk:
|
||||
correct_k = correct[:k].view(-1).float().sum(0)
|
||||
correct_k = correct[:k].contiguous().view(-1).float().sum(0)
|
||||
res.append(correct_k.mul_(1.0 / batch_size))
|
||||
|
||||
return res
|
||||
|
|
|
@ -41,6 +41,8 @@ class Trainer(EnforceOverrides):
|
|||
self._validation_freq = 0 if conf_validation is None else conf_validation['freq']
|
||||
# endregion
|
||||
|
||||
logger.pushd(self._title + '__init__')
|
||||
|
||||
self._apex = ApexUtils(conf_apex, logger)
|
||||
|
||||
self._checkpoint = checkpoint
|
||||
|
@ -59,6 +61,8 @@ class Trainer(EnforceOverrides):
|
|||
|
||||
self._start_epoch = -1 # nothing is started yet
|
||||
|
||||
logger.popd()
|
||||
|
||||
def fit(self, train_dl:DataLoader, val_dl:Optional[DataLoader])->Metrics:
|
||||
logger.pushd(self._title)
|
||||
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
__include__: 'darts.yaml' # just use darts defaults
|
||||
|
||||
common:
|
||||
experiment_name: 'MyResnet'
|
||||
|
||||
nas:
|
||||
eval:
|
||||
loader:
|
||||
train_batch: 128
|
||||
test_batch: 4096
|
||||
cutout: 0
|
||||
drop_path_prob: 0.0
|
||||
grad_clip: 0.0
|
||||
aux_weight: 0.0
|
||||
trainer:
|
||||
epochs: 10
|
||||
logger_freq: 1
|
||||
|
|
@ -12,19 +12,8 @@ from archai.common.common import logger, common_init
|
|||
from archai.datasets import data
|
||||
|
||||
def train_test(conf_eval:Config):
|
||||
# region conf vars
|
||||
conf_loader = conf_eval['loader']
|
||||
conf_trainer = conf_eval['trainer']
|
||||
# endregion
|
||||
|
||||
conf_trainer['validation']['freq']=1
|
||||
conf_trainer['epochs'] = 10
|
||||
conf_loader['train_batch'] = 128
|
||||
conf_loader['test_batch'] = 4096
|
||||
conf_loader['cutout'] = 0
|
||||
conf_trainer['drop_path_prob'] = 0.0
|
||||
conf_trainer['grad_clip'] = 0.0
|
||||
conf_trainer['aux_weight'] = 0.0
|
||||
|
||||
Net = cifar10_models.resnet34
|
||||
model = Net().to(torch.device('cuda', 0))
|
||||
|
@ -38,8 +27,7 @@ def train_test(conf_eval:Config):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
conf = common_init(config_filepath='confs/algos/darts.yaml',
|
||||
param_args=['--common.experiment_name', 'restnet_test'])
|
||||
conf = common_init(config_filepath='confs/algos/resnet.yaml')
|
||||
|
||||
conf_eval = conf['nas']['eval']
|
||||
|
Загрузка…
Ссылка в новой задаче