diff --git a/archai/algos/nasbench101/nasbench101_dataset.py b/archai/algos/nasbench101/nasbench101_dataset.py
index 3e404551..d0a6bb1a 100644
--- a/archai/algos/nasbench101/nasbench101_dataset.py
+++ b/archai/algos/nasbench101/nasbench101_dataset.py
@@ -96,10 +96,13 @@ import logging
 import numpy as np
 from numpy.lib.function_base import average
+from torch import nn
+
 from archai.common import utils
 from . import config
 from . import model_metrics_pb2
 from . import model_spec as _model_spec
+from . import model_builder
 
 # Bring ModelSpec to top-level for convenience. See lib/model_spec.py.
 ModelSpec = _model_spec.ModelSpec
@@ -304,6 +307,13 @@ class Nasbench101Dataset(object):
 
         return data
 
+    def create_model(self, index:int, device=None,
+            stem_out_channels=128, num_stacks=3, num_modules_per_stack=3, num_labels=10)->nn.Module:
+        data = self[index]
+        adj, ops = data['module_adjacency'], data['module_operations']
+        return model_builder.build(adj, ops, device=device,
+                                   stem_out_channels=stem_out_channels, num_stacks=num_stacks,
+                                   num_modules_per_stack=num_modules_per_stack, num_labels=num_labels)
 
     def is_valid(self, desc_matrix:List[List[int]], vertex_ops:List[str]):
         """Checks the validity of the model_spec.
diff --git a/scripts/experiments/loss_var/nasbench101_var.py b/scripts/experiments/loss_var/nasbench101_var.py
index 8d81a598..52e09c46 100644
--- a/scripts/experiments/loss_var/nasbench101_var.py
+++ b/scripts/experiments/loss_var/nasbench101_var.py
@@ -1,11 +1,13 @@
 import argparse
-from typing import List, Mapping, Tuple, Any
+import math
+from typing import List, Mapping, Optional, Tuple, Any
 import os
 import logging
 import numpy as np
 import time
 import torch
+from torch import nn
 from torch.optim.optimizer import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler
 from torch.nn.modules.loss import _Loss
@@ -17,91 +19,47 @@ import yaml
 
 from archai.common import utils
 from archai import cifar10_models
-from archai.datasets.list_dataset import ListDataset
+from archai.algos.nasbench101.nasbench101_dataset import Nasbench101Dataset
 
 
-def train_test(datadir:str, expdir:str,
-               exp_name:str, exp_desc:str, epochs:int, model_name:str,
-               train_batch_size:int, loader_workers:int, seed, half:bool,
-               test_batch_size:int, cutout:int)->Tuple[List[Mapping], int]:
-
-    # dirs
-    utils.setup_cuda(seed)
-    device = torch.device('cuda')
-
-    datadir = utils.full_path(datadir)
-    os.makedirs(datadir, exist_ok=True)
-
-    utils.create_logger(filepath=os.path.join(expdir, 'logs.log'))
-
-    # log config for reference
-    logging.info(f'exp_name="{exp_name}", exp_desc="{exp_desc}"')
-    logging.info(f'model_name="{model_name}", seed={seed}, epochs={epochs}')
-    logging.info(f'half={half}, cutout={cutout}, train_batch_size={train_batch_size}')
-    logging.info(f'datadir="{datadir}"')
-    logging.info(f'expdir="{expdir}"')
-
-    model_class = getattr(cifar10_models, model_name)
-    net = model_class()
-    logging.info(f'param_size_m={param_size(net):.1e}')
-    net = net.to(device)
-
-    crit = torch.nn.CrossEntropyLoss().to(device)
-
-    optim, sched, sched_on_epoch, batch_size = optim_sched(net)
-    logging.info(f'train_batch_size={train_batch_size}, batch_size={batch_size}')
-    logging.info(f'sched_on_epoch={sched_on_epoch}, sched={str(sched)}')
-
-    # load data just before train start so any errors so far is not delayed
-    train_dl, test_dl = cifar10_dataloaders(datadir,
-        train_batch_size=batch_size, test_batch_size=test_batch_size,
-        train_num_workers=loader_workers, test_num_workers=loader_workers,
-        cutout=cutout)
-
-    metrics = train(epochs, train_dl, test_dl, net, device, crit, optim,
-                    sched, sched_on_epoch, half, False)
-
-    with open(os.path.join(expdir, 'metrics.yaml'), 'w') as f:
-        yaml.dump(metrics, f)
-
-    return metrics, train_batch_size
-
-
-def train(epochs, train_dl, test_dl, net, device, crit, optim,
-          sched, sched_on_epoch, half, quiet)->List[Mapping]:
-    if half:
-        net.half()
-        crit.half()
+def train(epochs, train_dl, val_dal, net, device, crit, optim,
+          sched, sched_on_epoch, half, quiet) -> List[Mapping]:
     train_acc, test_acc = 0.0, 0.0
     metrics = []
     for epoch in range(epochs):
         lr = optim.param_groups[0]['lr']
         train_acc, loss = train_epoch(epoch, net, train_dl, device, crit, optim,
-                                      sched, sched_on_epoch, half)
-        test_acc = test(net, test_dl, device, half)
-        metrics.append({'test_top1':test_acc, 'train_top1':train_acc, 'lr':lr,
+                                      sched, sched_on_epoch, half)
+
+        val_acc = test(net, val_dal, device,
+                       half) if val_dal is not None else math.nan
+        metrics.append({'val_top1': val_acc, 'train_top1': train_acc, 'lr': lr,
                         'epoch': epoch, 'train_loss': loss})
         if not quiet:
-            logging.info(f'train_epoch={epoch}, test_top1={test_acc},'
+            logging.info(f'train_epoch={epoch}, val_top1={val_acc},'
                          f' train_top1={train_acc}, lr={lr:.4g}')
     return metrics
 
+
 def optim_sched(net):
-    lr, momentum, weight_decay = 0.1, 0.9, 1.0e-4
-    optim = torch.optim.SGD(net.parameters(),
-        lr, momentum=momentum, weight_decay=weight_decay)
-    logging.info(f'lr={lr}, momentum={momentum}, weight_decay={weight_decay}')
+    lr, momentum, weight_decay = 0.1, 0.9, 1.0e-4
+    optim = torch.optim.SGD(net.parameters(),
+                            lr, momentum=momentum, weight_decay=weight_decay)
+    logging.info(f'lr={lr}, momentum={momentum}, weight_decay={weight_decay}')
 
-    sched = torch.optim.lr_scheduler.MultiStepLR(optim,
-        milestones=[100, 150, 200, 400, 600]) # resnet original paper
-    sched_on_epoch = True
+    sched = torch.optim.lr_scheduler.MultiStepLR(optim,
+        milestones=[100, 150, 200, 400, 600])  # resnet original paper
+    sched_on_epoch = True
 
-    return optim, sched, sched_on_epoch, 128
+    logging.info(f'sched_on_epoch={sched_on_epoch}, sched={str(sched)}')
+
+    return optim, sched, sched_on_epoch
 
 
-def cifar10_dataloaders(datadir:str, train_batch_size=128, test_batch_size=4096,
-                        cutout=0, train_num_workers=-1, test_num_workers=-1)\
-                            ->Tuple[DataLoader, DataLoader]:
+def get_data(datadir: str, train_batch_size=128, test_batch_size=4096,
+             cutout=0, train_num_workers=-1, test_num_workers=-1,
+             val_percent=10.0)\
+        -> Tuple[DataLoader, Optional[DataLoader], DataLoader]:
     if utils.is_debugging():
         train_num_workers = test_num_workers = 0
         logging.info('debugger=true, num_workers=0')
@@ -112,23 +70,40 @@ def cifar10_dataloaders(datadir:str, train_batch_size=128, test_batch_size=4096,
 
     train_transform = cifar10_transform(aug=True, cutout=cutout)
     trainset = torchvision.datasets.CIFAR10(root=datadir, train=True,
-        download=True, transform=train_transform)
+                                            download=True, transform=train_transform)
+
+    val_len = int(len(trainset) * val_percent / 100.0)
+    train_len = len(trainset) - val_len
+
+    valset = None
+    if val_len:
+        trainset, valset = torch.utils.data.random_split(
+            trainset, [train_len, val_len])
+
     train_dl = torch.utils.data.DataLoader(trainset, batch_size=train_batch_size,
-        shuffle=True, num_workers=train_num_workers, pin_memory=True)
+                                           shuffle=True, num_workers=train_num_workers, pin_memory=True)
+
+    if valset is not None:
+        val_dl = torch.utils.data.DataLoader(valset, batch_size=test_batch_size,
+                                             shuffle=False, num_workers=test_num_workers, pin_memory=True)
+    else:
+        val_dl = None
 
     test_transform = cifar10_transform(aug=False, cutout=0)
     testset = torchvision.datasets.CIFAR10(root=datadir, train=False,
-        download=True, transform=test_transform)
+                                           download=True, transform=test_transform)
     test_dl = torch.utils.data.DataLoader(testset, batch_size=test_batch_size,
-        shuffle=False, num_workers=test_num_workers, pin_memory=True)
+                                          shuffle=False, num_workers=test_num_workers, pin_memory=True)
 
-    return train_dl, test_dl
+    logging.info(
+        f'train_len={train_len}, val_len={val_len}, test_len={len(testset)}')
+
+    return train_dl, val_dl, test_dl
 
 
 def train_epoch(epoch, net, train_dl, device, crit, optim,
-                sched, sched_on_epoch, half)->Tuple[float, float]:
+                sched, sched_on_epoch, half) -> Tuple[float, float]:
     correct, total, loss_total = 0, 0, 0.0
-    ds = ListDataset(train_dl)
     net.train()
     for batch_idx, (inputs, targets) in enumerate(train_dl):
         inputs = inputs.to(device, non_blocking=True)
@@ -149,9 +124,9 @@ def train_epoch(epoch, net, train_dl, device, crit, optim,
     return 100.0*correct/total, loss_total
 
 
-def train_step(net:torch.nn.Module,
-               crit:_Loss, optim:Optimizer, sched:_LRScheduler, sched_on_epoch:bool,
-               inputs:torch.Tensor, targets:torch.Tensor)->Tuple[torch.Tensor, float]:
+def train_step(net: nn.Module,
+               crit: _Loss, optim: Optimizer, sched: _LRScheduler, sched_on_epoch: bool,
+               inputs: torch.Tensor, targets: torch.Tensor) -> Tuple[torch.Tensor, float]:
     outputs = net(inputs)
     loss = crit(outputs, targets)
@@ -163,7 +138,8 @@ def train_step(net:torch.nn.Module,
         sched.step()
     return outputs, loss.item()
 
-def test(net, test_dl, device, half)->float:
+
+def test(net, test_dl, device, half) -> float:
     correct, total = 0, 0
     net.eval()
     with torch.no_grad():
@@ -181,12 +157,13 @@ def test(net, test_dl, device, half)->float:
     return 100.0*correct/total
 
 
-def param_size(model:torch.nn.Module)->int:
+def param_size(model: torch.nn.Module) -> int:
     """count all parameters excluding auxiliary"""
-    return sum(v.numel() for name, v in model.named_parameters() \
-        if "auxiliary" not in name)
+    return sum(v.numel() for name, v in model.named_parameters()
+               if "auxiliary" not in name)
 
-def cifar10_transform(aug:bool, cutout=0):
+
+def cifar10_transform(aug: bool, cutout=0):
     MEAN = [0.49139968, 0.48215827, 0.44653124]
     STD = [0.24703233, 0.24348505, 0.26158768]
@@ -202,7 +179,7 @@ def cifar10_transform(aug:bool, cutout=0):
     ]
 
     transf = aug_transf + transf
-    if cutout > 0: # must be after normalization
+    if cutout > 0:  # must be after normalization
         transf += [CutoutDefault(cutout)]
     return transforms.Compose(transf)
 
@@ -212,6 +189,7 @@ class CutoutDefault:
     """
     Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py
     """
+
     def __init__(self, length):
         self.length = length
 
@@ -233,29 +211,70 @@ class CutoutDefault:
         return img
 
 
+def log_metrics(expdir: str, filename: str, metrics, test_acc: float, args, perf_data:dict) -> None:
+    print('filename:', f'test_acc: {test_acc}', metrics[-1])
+    results = [
+        ('test_acc', test_acc),
+        ('nasbench101_test_acc', perf_data['avg_final_test_accuracy']),
+        ('val_acc', metrics[-1]['val_top1']),
+        ('epochs', args.epochs),
+        ('train_batch_size', args.train_batch_size),
+        ('test_batch_size', args.test_batch_size),
+        ('model_name', args.model_name),
+        ('exp_name', args.experiment_name),
+        ('exp_desc', args.experiment_description),
+        ('seed', args.seed),
+        ('devices', utils.cuda_device_names()),
+        ('half', args.half),
+        ('cutout', args.cutout),
+        ('train_acc', metrics[-1]['train_top1']),
+        ('loader_workers', args.loader_workers),
+        ('date', str(time.time())),
+    ]
+    utils.append_csv_file(os.path.join(expdir, f'{filename}.tsv'), results)
+    with open(os.path.join(expdir, f'{filename}_metrics.yaml'), 'w') as f:
+        yaml.dump(metrics, f)
+    with open(os.path.join(expdir, f'{filename}_nasbench101.yaml'), 'w') as f:
+        yaml.dump(perf_data, f)
+
+
+def create_crit(device, half):
+    crit = nn.CrossEntropyLoss().to(device)
+    if half:
+        crit.half()
+    return crit
+
+
+def create_model(nsds, index, device, half) -> nn.Module:
+    net = nsds.create_model(index)
+    logging.info(f'param_size_m={param_size(net):.1e}')
+    net = net.to(device)
+    if half:
+        net.half()
+    return net
+
+
 def main():
     parser = argparse.ArgumentParser(description='Pytorch cifar training')
     parser.add_argument('--experiment-name', '-n', default='train_pytorch')
     parser.add_argument('--experiment-description', '-d',
                         default='Train cifar usin pure PyTorch code')
-    parser.add_argument('--epochs', '-e', type=int, default=300)
+    parser.add_argument('--epochs', '-e', type=int, default=1)
     parser.add_argument('--model-name', '-m', default='resnet34')
-    parser.add_argument('--train-batch', '-b', type=int, default=-1)
-    parser.add_argument('--test-batch', type=int, default=4096)
+    parser.add_argument('--device', default='',
+                        help='"cuda" or "cpu" or "" in which case use cuda if available')
+    parser.add_argument('--train-batch-size', '-b', type=int, default=128)
+    parser.add_argument('--test-batch-size', type=int, default=4096)
     parser.add_argument('--seed', '-s', type=float, default=42)
-    parser.add_argument('--half', type=lambda x:x.lower()=='true',
+    parser.add_argument('--half', type=lambda x: x.lower() == 'true',
                         nargs='?', const=True, default=False)
     parser.add_argument('--cutout', type=int, default=0)
-    parser.add_argument('--loader', default='torch', help='torch or dali')
     parser.add_argument('--datadir', default='',
                         help='where to find dataset files, default is ~/torchvision_data_dir')
     parser.add_argument('--outdir', default='',
                         help='where to put results, default is ~/logdir')
-    parser.add_argument('--loader-workers', type=int, default=-1, help='number of thread/workers for data loader (-1 means auto)')
-    parser.add_argument('--optim-sched', '-os', default='darts',
-                        help='Optimizer and scheduler provider')
+    parser.add_argument('--loader-workers', type=int, default=-1,
+                        help='number of thread/workers for data loader (-1 means auto)')
 
     args = parser.parse_args()
 
@@ -264,41 +283,51 @@ def main():
     if not args.outdir:
         args.outdir = os.environ.get('PT_OUTPUT_DIR', '')
     if not args.outdir:
-        args.outdir = os.path.join('~/logdir', 'cifar_testbed', args.experiment_name)
+        args.outdir = os.path.join(
+            '~/logdir', 'cifar_testbed', args.experiment_name)
 
     expdir = utils.full_path(args.outdir)
     os.makedirs(expdir, exist_ok=True)
 
-    metrics, train_batch_size = train_test(datadir=args.datadir, expdir=expdir,
-                       exp_name=args.experiment_name,
-                       exp_desc=args.experiment_description,
-                       epochs=args.epochs, model_name=args.model_name,
-                       train_batch_size=args.train_batch, loader_workers=args.loader_workers,
-                       seed=args.seed, half=args.half, test_batch_size=args.test_batch,
-                       cutout=args.cutout)
+    utils.setup_cuda(args.seed)
+    datadir = utils.full_path(args.datadir)
+    os.makedirs(datadir, exist_ok=True)
 
-    print(metrics[-1])
+    utils.create_logger(filepath=os.path.join(expdir, 'logs.log'))
 
-    results = [
-        ('test_acc', metrics[-1]['test_top1']),
-        ('epochs', args.epochs),
-        ('train_batch_size', train_batch_size),
-        ('test_batch_size', args.test_batch),
-        ('model_name', args.model_name),
-        ('exp_name', args.experiment_name),
-        ('exp_desc', args.experiment_description),
-        ('seed', args.seed),
-        ('devices', utils.cuda_device_names()),
-        ('half', args.half),
-        ('cutout', args.cutout),
-        ('optim_sched', args.optim_sched),
-        ('train_acc', metrics[-1]['train_top1']),
-        ('loader', args.loader),
-        ('loader_workers', args.loader_workers),
-        ('date', str(time.time())),
-    ]
+    # log config for reference
+    logging.info(
+        f'exp_name="{args.experiment_name}", exp_desc="{args.experiment_description}"')
+    logging.info(
+        f'model_name="{args.model_name}", seed={args.seed}, epochs={args.epochs}')
+    logging.info(f'half={args.half}, cutout={args.cutout}')
+    logging.info(f'datadir="{datadir}"')
+    logging.info(f'expdir="{expdir}"')
+    logging.info(f'train_batch_size={args.train_batch_size}')
 
-    utils.append_csv_file(os.path.join(expdir, 'results.tsv'), results)
+    if args.device:
+        device = torch.device(args.device)
+    else:
+        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+    nsds = Nasbench101Dataset(
+        os.path.join(args.datadir, 'nasbench_ds', 'nasbench_only108.tfrecord.pkl'))
+
+    # load data just before train start so any errors so far is not delayed
+    train_dl, val_dl, test_dl = get_data(datadir=datadir,
+        train_batch_size=args.train_batch_size, test_batch_size=args.test_batch_size,
+        train_num_workers=args.loader_workers, test_num_workers=args.loader_workers,
+        cutout=args.cutout)
+
+    for model_id in [4, 400, 4000, 40000, 400000]:
+        perf_data = nsds[model_id]
+        net = create_model(nsds, model_id, device, args.half)
+        crit = create_crit(device, args.half)
+        optim, sched, sched_on_epoch = optim_sched(net)
+        train_metrics = train(perf_data['epochs'], train_dl, val_dl, net, device, crit, optim,
+                              sched, sched_on_epoch, args.half, False)
+        test_acc = test(net, test_dl, device, args.half)
+        log_metrics(expdir, f'metrics_{model_id}', train_metrics, test_acc, args, perf_data)
 
 
 if __name__ == '__main__':
-    main()
\ No newline at end of file
+    main()
diff --git a/scripts/nasbench101/query_test.py b/scripts/nasbench101/query_test.py
index 1ae5fc52..01319bef 100644
--- a/scripts/nasbench101/query_test.py
+++ b/scripts/nasbench101/query_test.py
@@ -27,6 +27,9 @@ def main():
     best = nsds[len(nsds)-1]
     print('best', best)
 
+    # create model by index
+    model = nsds.create_model(42)
+    print(model)
 
 if __name__ == '__main__':
     main()
\ No newline at end of file
diff --git a/scripts/plain_models/cifar_resnet/train_pytorch.py b/scripts/plain_models/cifar_resnet/train_pytorch.py
index 209dd234..ef7a7cd1 100644
--- a/scripts/plain_models/cifar_resnet/train_pytorch.py
+++ b/scripts/plain_models/cifar_resnet/train_pytorch.py
@@ -22,22 +22,24 @@ from archai import cifar10_models
 
 
 def train(epochs, train_dl, val_dal, net, device, crit, optim,
-          sched, sched_on_epoch, half, quiet)->List[Mapping]:
+          sched, sched_on_epoch, half, quiet) -> List[Mapping]:
     train_acc, test_acc = 0.0, 0.0
     metrics = []
     for epoch in range(epochs):
         lr = optim.param_groups[0]['lr']
         train_acc, loss = train_epoch(epoch, net, train_dl, device, crit, optim,
-                                      sched, sched_on_epoch, half)
+                                      sched, sched_on_epoch, half)
 
-        val_acc = test(net, val_dal, device, half) if val_dal is not None else math.nan
-        metrics.append({'val_top1':val_acc, 'train_top1':train_acc, 'lr':lr,
+        val_acc = test(net, val_dal, device,
+                       half) if val_dal is not None else math.nan
+        metrics.append({'val_top1': val_acc, 'train_top1': train_acc, 'lr': lr,
                         'epoch': epoch, 'train_loss': loss})
         if not quiet:
             logging.info(f'train_epoch={epoch}, val_top1={val_acc},'
                          f' train_top1={train_acc}, lr={lr:.4g}')
     return metrics
 
+
 def optim_sched(net):
     lr, momentum, weight_decay = 0.1, 0.9, 1.0e-4
     optim = torch.optim.SGD(net.parameters(),
@@ -45,7 +47,7 @@ def optim_sched(net):
     logging.info(f'lr={lr}, momentum={momentum}, weight_decay={weight_decay}')
 
     sched = torch.optim.lr_scheduler.MultiStepLR(optim,
-        milestones=[100, 150, 200, 400, 600]) # resnet original paper
+        milestones=[100, 150, 200, 400, 600])  # resnet original paper
     sched_on_epoch = True
 
     logging.info(f'sched_on_epoch={sched_on_epoch}, sched={str(sched)}')
@@ -53,10 +55,10 @@ def optim_sched(net):
     return optim, sched, sched_on_epoch
 
 
-def get_data(datadir:str, train_batch_size=128, test_batch_size=4096,
-             cutout=0, train_num_workers=-1, test_num_workers=-1,
-             val_percent=10.0)\
-                 ->Tuple[DataLoader, Optional[DataLoader], DataLoader]:
+def get_data(datadir: str, train_batch_size=128, test_batch_size=4096,
+             cutout=0, train_num_workers=-1, test_num_workers=-1,
+             val_percent=10.0)\
+        -> Tuple[DataLoader, Optional[DataLoader], DataLoader]:
     if utils.is_debugging():
         train_num_workers = test_num_workers = 0
         logging.info('debugger=true, num_workers=0')
@@ -67,37 +69,39 @@ def get_data(datadir:str, train_batch_size=128, test_batch_size=4096,
 
     train_transform = cifar10_transform(aug=True, cutout=cutout)
     trainset = torchvision.datasets.CIFAR10(root=datadir, train=True,
-        download=True, transform=train_transform)
+                                            download=True, transform=train_transform)
 
     val_len = int(len(trainset) * val_percent / 100.0)
     train_len = len(trainset) - val_len
 
     valset = None
     if val_len:
-        trainset, valset = torch.utils.data.random_split(trainset, [train_len, val_len])
+        trainset, valset = torch.utils.data.random_split(
+            trainset, [train_len, val_len])
 
     train_dl = torch.utils.data.DataLoader(trainset, batch_size=train_batch_size,
-        shuffle=True, num_workers=train_num_workers, pin_memory=True)
+                                           shuffle=True, num_workers=train_num_workers, pin_memory=True)
 
     if valset is not None:
         val_dl = torch.utils.data.DataLoader(valset, batch_size=test_batch_size,
-            shuffle=False, num_workers=test_num_workers, pin_memory=True)
+                                             shuffle=False, num_workers=test_num_workers, pin_memory=True)
     else:
         val_dl = None
 
     test_transform = cifar10_transform(aug=False, cutout=0)
     testset = torchvision.datasets.CIFAR10(root=datadir, train=False,
-        download=True, transform=test_transform)
+                                           download=True, transform=test_transform)
     test_dl = torch.utils.data.DataLoader(testset, batch_size=test_batch_size,
-        shuffle=False, num_workers=test_num_workers, pin_memory=True)
+                                          shuffle=False, num_workers=test_num_workers, pin_memory=True)
 
-    logging.info(f'train_len={train_len}, val_len={val_len}, test_len={len(testset)}')
+    logging.info(
+        f'train_len={train_len}, val_len={val_len}, test_len={len(testset)}')
 
     return train_dl, val_dl, test_dl
 
 
 def train_epoch(epoch, net, train_dl, device, crit, optim,
-                sched, sched_on_epoch, half)->Tuple[float, float]:
+                sched, sched_on_epoch, half) -> Tuple[float, float]:
     correct, total, loss_total = 0, 0, 0.0
     net.train()
     for batch_idx, (inputs, targets) in enumerate(train_dl):
@@ -119,9 +123,9 @@ def train_epoch(epoch, net, train_dl, device, crit, optim,
     return 100.0*correct/total, loss_total
 
 
-def train_step(net:nn.Module,
-               crit:_Loss, optim:Optimizer, sched:_LRScheduler, sched_on_epoch:bool,
-               inputs:torch.Tensor, targets:torch.Tensor)->Tuple[torch.Tensor, float]:
+def train_step(net: nn.Module,
+               crit: _Loss, optim: Optimizer, sched: _LRScheduler, sched_on_epoch: bool,
+               inputs: torch.Tensor, targets: torch.Tensor) -> Tuple[torch.Tensor, float]:
     outputs = net(inputs)
     loss = crit(outputs, targets)
@@ -133,7 +137,8 @@ def train_step(net:nn.Module,
         sched.step()
     return outputs, loss.item()
 
-def test(net, test_dl, device, half)->float:
+
+def test(net, test_dl, device, half) -> float:
     correct, total = 0, 0
     net.eval()
     with torch.no_grad():
@@ -151,12 +156,13 @@ def test(net, test_dl, device, half)->float:
     return 100.0*correct/total
 
 
-def param_size(model:torch.nn.Module)->int:
+def param_size(model: torch.nn.Module) -> int:
     """count all parameters excluding auxiliary"""
-    return sum(v.numel() for name, v in model.named_parameters() \
-        if "auxiliary" not in name)
+    return sum(v.numel() for name, v in model.named_parameters()
+               if "auxiliary" not in name)
 
-def cifar10_transform(aug:bool, cutout=0):
+
+def cifar10_transform(aug: bool, cutout=0):
     MEAN = [0.49139968, 0.48215827, 0.44653124]
     STD = [0.24703233, 0.24348505, 0.26158768]
@@ -172,7 +178,7 @@ def cifar10_transform(aug:bool, cutout=0):
     ]
 
     transf = aug_transf + transf
-    if cutout > 0: # must be after normalization
+    if cutout > 0:  # must be after normalization
         transf += [CutoutDefault(cutout)]
     return transforms.Compose(transf)
 
@@ -182,6 +188,7 @@ class CutoutDefault:
     """
     Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py
     """
+
     def __init__(self, length):
         self.length = length
 
@@ -202,7 +209,8 @@ class CutoutDefault:
         img *= mask
         return img
 
-def log_metrics(expdir:str, filename:str, metrics, test_acc:float, args)->None:
+
+def log_metrics(expdir: str, filename: str, metrics, test_acc: float, args) -> None:
     print('filename:', f'test_acc: {test_acc}', metrics[-1])
     results = [
         ('test_acc', test_acc),
@@ -225,13 +233,15 @@ def log_metrics(expdir:str, filename:str, metrics, test_acc:float, args)->None:
     with open(os.path.join(expdir, f'{filename}.yaml'), 'w') as f:
         yaml.dump(metrics, f)
 
+
 def create_crit(device, half):
     crit = nn.CrossEntropyLoss().to(device)
     if half:
         crit.half()
     return crit
 
-def create_model(model_name, device, half)->nn.Module:
+
+def create_model(model_name, device, half) -> nn.Module:
     model_class = getattr(cifar10_models, model_name)
     net = model_class()
     logging.info(f'param_size_m={param_size(net):.1e}')
@@ -240,6 +250,7 @@ def create_model(model_name, device, half)->nn.Module:
         net.half()
     return net
 
+
 def main():
     parser = argparse.ArgumentParser(description='Pytorch cifar training')
     parser.add_argument('--experiment-name', '-n', default='train_pytorch')
@@ -247,11 +258,12 @@ def main():
                         default='Train cifar usin pure PyTorch code')
     parser.add_argument('--epochs', '-e', type=int, default=1)
     parser.add_argument('--model-name', '-m', default='resnet34')
-    parser.add_argument('--device', default='', help='"cuda" or "cpu" or "" in which case use cuda if available')
+    parser.add_argument('--device', default='',
+                        help='"cuda" or "cpu" or "" in which case use cuda if available')
     parser.add_argument('--train-batch-size', '-b', type=int, default=128)
     parser.add_argument('--test-batch-size', type=int, default=4096)
     parser.add_argument('--seed', '-s', type=float, default=42)
-    parser.add_argument('--half', type=lambda x:x.lower()=='true',
+    parser.add_argument('--half', type=lambda x: x.lower() == 'true',
                         nargs='?', const=True, default=False)
     parser.add_argument('--cutout', type=int, default=0)
@@ -260,7 +272,8 @@ def main():
     parser.add_argument('--outdir', default='',
                         help='where to put results, default is ~/logdir')
 
-    parser.add_argument('--loader-workers', type=int, default=-1, help='number of thread/workers for data loader (-1 means auto)')
+    parser.add_argument('--loader-workers', type=int, default=-1,
+                        help='number of thread/workers for data loader (-1 means auto)')
 
     args = parser.parse_args()
 
@@ -269,7 +282,8 @@ def main():
     if not args.outdir:
         args.outdir = os.environ.get('PT_OUTPUT_DIR', '')
     if not args.outdir:
-        args.outdir = os.path.join('~/logdir', 'cifar_testbed', args.experiment_name)
+        args.outdir = os.path.join(
+            '~/logdir', 'cifar_testbed', args.experiment_name)
 
     expdir = utils.full_path(args.outdir)
     os.makedirs(expdir, exist_ok=True)
@@ -281,8 +295,10 @@ def main():
     utils.create_logger(filepath=os.path.join(expdir, 'logs.log'))
 
     # log config for reference
-    logging.info(f'exp_name="{args.experiment_name}", exp_desc="{args.experiment_description}"')
-    logging.info(f'model_name="{args.model_name}", seed={args.seed}, epochs={args.epochs}')
+    logging.info(
+        f'exp_name="{args.experiment_name}", exp_desc="{args.experiment_description}"')
+    logging.info(
+        f'model_name="{args.model_name}", seed={args.seed}, epochs={args.epochs}')
     logging.info(f'half={args.half}, cutout={args.cutout}')
     logging.info(f'datadir="{datadir}"')
     logging.info(f'expdir="{expdir}"')
@@ -299,14 +315,15 @@ def main():
 
     # load data just before train start so any errors so far is not delayed
     train_dl, val_dl, test_dl = get_data(datadir=datadir,
-        train_batch_size=args.train_batch_size, test_batch_size=args.test_batch_size,
-        train_num_workers=args.loader_workers, test_num_workers=args.loader_workers,
-        cutout=args.cutout)
+                                         train_batch_size=args.train_batch_size, test_batch_size=args.test_batch_size,
+                                         train_num_workers=args.loader_workers, test_num_workers=args.loader_workers,
+                                         cutout=args.cutout)
 
     train_metrics = train(args.epochs, train_dl, val_dl, net, device, crit, optim,
-        sched, sched_on_epoch, args.half, False)
+                          sched, sched_on_epoch, args.half, False)
     test_acc = test(net, test_dl, device, args.half)
     log_metrics(expdir, 'train_metrics', train_metrics, test_acc, args)
 
+
 if __name__ == '__main__':
-    main()
\ No newline at end of file
+    main()
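
For reference, a minimal usage sketch of the Nasbench101Dataset.create_model() API introduced above. This is an illustration rather than part of the patch: the dataset pickle path and record index mirror the values used in nasbench101_var.py and query_test.py, and the data directory location is an assumption about the local setup.

import os
import torch

from archai.algos.nasbench101.nasbench101_dataset import Nasbench101Dataset

# Assumed local data directory; matches the default mentioned in the --datadir help text.
datadir = os.path.expanduser('~/torchvision_data_dir')
nsds = Nasbench101Dataset(
    os.path.join(datadir, 'nasbench_ds', 'nasbench_only108.tfrecord.pkl'))

# Look up a benchmark record and build the corresponding PyTorch model.
perf_data = nsds[42]    # adjacency, operations and NAS-Bench-101 metrics for this architecture
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = nsds.create_model(42, device=device)  # defaults: stem_out_channels=128, num_stacks=3,
                                            # num_modules_per_stack=3, num_labels=10

print(perf_data['avg_final_test_accuracy'])
print(net)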