diff --git a/archai/cifar10_models/README.md b/archai/cifar10_models/README.md
index f3365605..5a285e61 100644
--- a/archai/cifar10_models/README.md
+++ b/archai/cifar10_models/README.md
@@ -1,3 +1,3 @@
 # Credits
 
-Models in this folder are as-is from https://github.com/huyvnphan/PyTorch-CIFAR10.
\ No newline at end of file
+Models in this folder are used as-is from https://github.com/huyvnphan/PyTorch-CIFAR10 and from Yerlan Idelbayev.
\ No newline at end of file
diff --git a/scripts/plain_models/cifar_resnet/resnet.py b/archai/cifar10_models/resnet_paper.py
similarity index 100%
rename from scripts/plain_models/cifar_resnet/resnet.py
rename to archai/cifar10_models/resnet_paper.py
diff --git a/archai/common/apex_utils.py b/archai/common/apex_utils.py
index 69c02412..c8f88b8e 100644
--- a/archai/common/apex_utils.py
+++ b/archai/common/apex_utils.py
@@ -116,7 +116,7 @@ class ApexUtils:
 
     def _setup_gpus(self, seed:float, detect_anomaly:bool):
-        utils.setup_cuda(seed, self.local_rank)
+        utils.setup_cuda(seed, local_rank=self.local_rank)
 
         torch.autograd.set_detect_anomaly(detect_anomaly)
         self._log_info({'set_detect_anomaly': detect_anomaly,
diff --git a/archai/common/timing.py b/archai/common/timing.py
index be018d1f..9fb46d1a 100644
--- a/archai/common/timing.py
+++ b/archai/common/timing.py
@@ -84,6 +84,13 @@ def MeasureTime(f_py=None, no_print=True, disable_gc=False, name:Optional[str]=N
     return _decorator(f_py) if callable(f_py) else _decorator
 
 class MeasureBlockTime:
+    """
+    Example:
+        with MeasureBlockTime('my_func') as t:
+            my_func()
+        print(t.elapsed)
+
+    """
     def __init__(self, name:str, no_print=True, disable_gc=False):
         self.name = name
         self.no_print = no_print
diff --git a/archai/common/utils.py b/archai/common/utils.py
index c8b3d914..472137e0 100644
--- a/archai/common/utils.py
+++ b/archai/common/utils.py
@@ -228,7 +228,7 @@ def download_and_extract_tar(url, download_root, extract_root=None, filename=Non
     extract_tar(os.path.join(download_root, filename), extract_root, **kwargs)
 
 
-def setup_cuda(seed:Union[float, int], local_rank:int):
+def setup_cuda(seed:Union[float, int], local_rank:int=0):
     seed = int(seed) + local_rank
     # setup cuda
     cudnn.enabled = True
diff --git a/scripts/perf/mixed_ops.py b/scripts/perf/mixed_ops.py
index 31b9c528..47956340 100644
--- a/scripts/perf/mixed_ops.py
+++ b/scripts/perf/mixed_ops.py
@@ -9,7 +9,7 @@ import torch.backends.cudnn as cudnn
 import numpy as np
 
 
-utils.setup_cuda(2, 0)
+utils.setup_cuda(2, local_rank=0)
 
 device = torch.device('cuda')
diff --git a/scripts/perf/model_dl_test.py b/scripts/perf/model_dl_test.py
index bb48b0c2..36ae49e7 100644
--- a/scripts/perf/model_dl_test.py
+++ b/scripts/perf/model_dl_test.py
@@ -13,7 +13,7 @@ from torch_testbed.dataloader_dali import cifar10_dataloaders
 
 utils.create_logger()
 
-utils.setup_cuda(42, 0)
+utils.setup_cuda(42, local_rank=0)
 
 batch_size = 512
 half = True
diff --git a/scripts/perf/model_test.py b/scripts/perf/model_test.py
index 8894d3c3..1200ba1f 100644
--- a/scripts/perf/model_test.py
+++ b/scripts/perf/model_test.py
@@ -8,7 +8,7 @@ from torch_testbed import utils, cifar10_models
 from torch_testbed.timing import MeasureTime, print_all_timings, print_timing, get_timing
 
 utils.create_logger()
-utils.setup_cuda(42, 0)
+utils.setup_cuda(42, local_rank=0)
 
 batch_size = 512
 half = True
diff --git a/scripts/plain_models/train_resnet.py b/scripts/plain_models/cifar_resnet/train_archai.py
similarity index 100%
rename from scripts/plain_models/train_resnet.py
rename to scripts/plain_models/cifar_resnet/train_archai.py
diff --git a/scripts/plain_models/cifar_resnet/train_pytorch.py b/scripts/plain_models/cifar_resnet/train_pytorch.py
new file mode 100644
index 00000000..477d9b67
--- /dev/null
+++ b/scripts/plain_models/cifar_resnet/train_pytorch.py
@@ -0,0 +1,302 @@
+import argparse
+from typing import List, Mapping, Tuple, Any
+import os
+import logging
+import numpy as np
+import time
+
+import torch
+from torch.optim.optimizer import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
+from torch.nn.modules.loss import _Loss
+from torch.utils.data import DataLoader
+import torchvision
+import torchvision.transforms as transforms
+
+import yaml
+
+from archai.common import utils
+from archai import cifar10_models
+
+
+def train_test(datadir:str, expdir:str,
+               exp_name:str, exp_desc:str, epochs:int, model_name:str,
+               train_batch_size:int, loader_workers:int, seed, half:bool,
+               test_batch_size:int, cutout:int)->Tuple[List[Mapping], int]:
+
+    # dirs
+    utils.setup_cuda(seed)
+    device = torch.device('cuda')
+
+    datadir = utils.full_path(datadir)
+    os.makedirs(datadir, exist_ok=True)
+
+    utils.create_logger(filepath=os.path.join(expdir, 'logs.log'))
+
+    # log config for reference
+    logging.info(f'exp_name="{exp_name}", exp_desc="{exp_desc}"')
+    logging.info(f'model_name="{model_name}", seed={seed}, epochs={epochs}')
+    logging.info(f'half={half}, cutout={cutout}, train_batch_size={train_batch_size}')
+    logging.info(f'datadir="{datadir}"')
+    logging.info(f'expdir="{expdir}"')
+
+    model_class = getattr(cifar10_models, model_name)
+    net = model_class()
+    logging.info(f'param_size_m={param_size(net):.1e}')
+    net = net.to(device)
+
+    crit = torch.nn.CrossEntropyLoss().to(device)
+
+    optim, sched, sched_on_epoch, batch_size = optim_sched(net)
+    logging.info(f'train_batch_size={train_batch_size}, batch_size={batch_size}')
+    logging.info(f'sched_on_epoch={sched_on_epoch}, sched={str(sched)}')
+
+    # load data just before training starts so that any earlier errors are not delayed
+    train_dl, test_dl = cifar10_dataloaders(datadir,
+        train_batch_size=batch_size, test_batch_size=test_batch_size,
+        train_num_workers=loader_workers, test_num_workers=loader_workers,
+        cutout=cutout)
+
+    metrics = train(epochs, train_dl, test_dl, net, device, crit, optim,
+                    sched, sched_on_epoch, half, False)
+
+    with open(os.path.join(expdir, 'metrics.yaml'), 'w') as f:
+        yaml.dump(metrics, f)
+
+    return metrics, train_batch_size
+
+
+def train(epochs, train_dl, test_dl, net, device, crit, optim,
+          sched, sched_on_epoch, half, quiet)->List[Mapping]:
+    if half:
+        net.half()
+        crit.half()
+    train_acc, test_acc = 0.0, 0.0
+    metrics = []
+    for epoch in range(epochs):
+        lr = optim.param_groups[0]['lr']
+        train_acc, loss = train_epoch(epoch, net, train_dl, device, crit, optim,
+                                      sched, sched_on_epoch, half)
+        test_acc = test(net, test_dl, device, half)
+        metrics.append({'test_top1':test_acc, 'train_top1':train_acc, 'lr':lr,
+                        'epoch': epoch, 'train_loss': loss})
+        if not quiet:
+            logging.info(f'train_epoch={epoch}, test_top1={test_acc},'
+                         f' train_top1={train_acc}, lr={lr:.4g}')
+    return metrics
+
+
+def optim_sched(net):
+    lr, momentum, weight_decay = 0.1, 0.9, 1.0e-4
+    optim = torch.optim.SGD(net.parameters(),
+                            lr, momentum=momentum, weight_decay=weight_decay)
+    logging.info(f'lr={lr}, momentum={momentum}, weight_decay={weight_decay}')
+
+    sched = torch.optim.lr_scheduler.MultiStepLR(optim,
+        milestones=[100, 150, 200, 400, 600]) # resnet original paper
+    sched_on_epoch = True
+
+    return optim, sched, sched_on_epoch, 128
+
+
+def cifar10_dataloaders(datadir:str, train_batch_size=128, test_batch_size=4096,
+                        cutout=0, train_num_workers=-1, test_num_workers=-1)\
+                            ->Tuple[DataLoader, DataLoader]:
+    if utils.is_debugging():
+        train_num_workers = test_num_workers = 0
+        logging.info('debugger=true, num_workers=0')
+    if train_num_workers <= -1:
+        train_num_workers = torch.cuda.device_count()*4
+    if test_num_workers <= -1:
+        test_num_workers = torch.cuda.device_count()*4
+
+    train_transform = cifar10_transform(aug=True, cutout=cutout)
+    trainset = torchvision.datasets.CIFAR10(root=datadir, train=True,
+        download=True, transform=train_transform)
+    train_dl = torch.utils.data.DataLoader(trainset, batch_size=train_batch_size,
+        shuffle=True, num_workers=train_num_workers, pin_memory=True)
+
+    test_transform = cifar10_transform(aug=False, cutout=0)
+    testset = torchvision.datasets.CIFAR10(root=datadir, train=False,
+        download=True, transform=test_transform)
+    test_dl = torch.utils.data.DataLoader(testset, batch_size=test_batch_size,
+        shuffle=False, num_workers=test_num_workers, pin_memory=True)
+
+    return train_dl, test_dl
+
+
+def train_epoch(epoch, net, train_dl, device, crit, optim,
+                sched, sched_on_epoch, half)->Tuple[float, float]:
+    correct, total, loss_total = 0, 0, 0.0
+    net.train()
+    for batch_idx, (inputs, targets) in enumerate(train_dl):
+        inputs = inputs.to(device, non_blocking=True)
+        targets = targets.to(device, non_blocking=True)
+
+        if half:
+            inputs = inputs.half()
+
+        outputs, loss = train_step(net, crit, optim, sched, sched_on_epoch,
+                                   inputs, targets)
+        loss_total += loss
+
+        _, predicted = outputs.max(1)
+        total += targets.size(0)
+        correct += predicted.eq(targets).sum().item()
+    if sched and sched_on_epoch:
+        sched.step()
+    return 100.0*correct/total, loss_total
+
+
+def train_step(net:torch.nn.Module,
+               crit:_Loss, optim:Optimizer, sched:_LRScheduler, sched_on_epoch:bool,
+               inputs:torch.Tensor, targets:torch.Tensor)->Tuple[torch.Tensor, float]:
+    outputs = net(inputs)
+
+    loss = crit(outputs, targets)
+    optim.zero_grad()
+    loss.backward()
+
+    optim.step()
+    if sched and not sched_on_epoch:
+        sched.step()
+    return outputs, loss.item()
+
+
+def test(net, test_dl, device, half)->float:
+    correct, total = 0, 0
+    net.eval()
+    with torch.no_grad():
+        for batch_idx, (inputs, targets) in enumerate(test_dl):
+            inputs = inputs.to(device, non_blocking=False)
+            targets = targets.to(device)
+
+            if half:
+                inputs = inputs.half()
+
+            outputs = net(inputs)
+            _, predicted = outputs.max(1)
+            total += targets.size(0)
+            correct += predicted.eq(targets).sum().item()
+    return 100.0*correct/total
+
+
+def param_size(model:torch.nn.Module)->int:
+    """count all parameters excluding auxiliary"""
+    return sum(v.numel() for name, v in model.named_parameters() \
+               if "auxiliary" not in name)
+
+
+def cifar10_transform(aug:bool, cutout=0):
+    MEAN = [0.49139968, 0.48215827, 0.44653124]
+    STD = [0.24703233, 0.24348505, 0.26158768]
+
+    transf = [
+        transforms.ToTensor(),
+        transforms.Normalize(MEAN, STD)
+    ]
+
+    if aug:
+        aug_transf = [
+            transforms.RandomCrop(32, padding=4),
+            transforms.RandomHorizontalFlip()
+        ]
+        transf = aug_transf + transf
+
+    if cutout > 0: # must be after normalization
+        transf += [CutoutDefault(cutout)]
+
+    return transforms.Compose(transf)
+
+
+class CutoutDefault:
+    """
+    Reference: https://github.com/quark0/darts/blob/master/cnn/utils.py
+    """
+    def __init__(self, length):
+        self.length = length
+
+    def __call__(self, img):
+        h, w = img.size(1), img.size(2)
+        mask = np.ones((h, w), np.float32)
+        y = np.random.randint(h)
+        x = np.random.randint(w)
+
+        y1 = np.clip(y - self.length // 2, 0, h)
+        y2 = np.clip(y + self.length // 2, 0, h)
+        x1 = np.clip(x - self.length // 2, 0, w)
+        x2 = np.clip(x + self.length // 2, 0, w)
+
+        mask[y1: y2, x1: x2] = 0.
+        mask = torch.from_numpy(mask)
+        mask = mask.expand_as(img)
+        img *= mask
+        return img
+
+
+def main():
+    parser = argparse.ArgumentParser(description='PyTorch CIFAR-10 training')
+    parser.add_argument('--experiment-name', '-n', default='train_pytorch')
+    parser.add_argument('--experiment-description', '-d',
+                        default='Train CIFAR-10 using pure PyTorch code')
+    parser.add_argument('--epochs', '-e', type=int, default=300)
+    parser.add_argument('--model-name', '-m', default='resnet34')
+    parser.add_argument('--train-batch', '-b', type=int, default=-1)
+    parser.add_argument('--test-batch', type=int, default=4096)
+    parser.add_argument('--seed', '-s', type=float, default=42)
+    parser.add_argument('--half', type=lambda x:x.lower()=='true',
+                        nargs='?', const=True, default=False)
+    parser.add_argument('--cutout', type=int, default=0)
+    parser.add_argument('--loader', default='torch', help='torch or dali')
+
+    parser.add_argument('--datadir', default='',
+                        help='where to find dataset files, default is ~/torchvision_data_dir')
+    parser.add_argument('--outdir', default='',
+                        help='where to put results, default is ~/logdir')
+
+    parser.add_argument('--loader-workers', type=int, default=-1,
+                        help='number of threads/workers for data loader (-1 means auto)')
+    parser.add_argument('--optim-sched', '-os', default='darts',
+                        help='Optimizer and scheduler provider')
+
+    args = parser.parse_args()
+
+    if not args.datadir:
+        args.datadir = os.environ.get('PT_DATA_DIR', '') or '~/dataroot'
+    if not args.outdir:
+        args.outdir = os.environ.get('PT_OUTPUT_DIR', '')
+    if not args.outdir:
+        args.outdir = os.path.join('~/logdir', 'cifar_testbed', args.experiment_name)
+
+    expdir = utils.full_path(args.outdir)
+    os.makedirs(expdir, exist_ok=True)
+
+    metrics, train_batch_size = train_test(datadir=args.datadir, expdir=expdir,
+        exp_name=args.experiment_name,
+        exp_desc=args.experiment_description,
+        epochs=args.epochs, model_name=args.model_name,
+        train_batch_size=args.train_batch, loader_workers=args.loader_workers,
+        seed=args.seed, half=args.half, test_batch_size=args.test_batch,
+        cutout=args.cutout)
+
+    print(metrics[-1])
+
+    results = [
+        ('test_acc', metrics[-1]['test_top1']),
+        ('epochs', args.epochs),
+        ('train_batch_size', train_batch_size),
+        ('test_batch_size', args.test_batch),
+        ('model_name', args.model_name),
+        ('exp_name', args.experiment_name),
+        ('exp_desc', args.experiment_description),
+        ('seed', args.seed),
+        ('devices', utils.cuda_device_names()),
+        ('half', args.half),
+        ('cutout', args.cutout),
+        ('optim_sched', args.optim_sched),
+        ('train_acc', metrics[-1]['train_top1']),
+        ('loader', args.loader),
+        ('loader_workers', args.loader_workers),
+        ('date', str(time.time())),
+    ]
+
+    utils.append_csv_file(os.path.join(expdir, 'results.tsv'), results)
+
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
diff --git a/scripts/plain_models/cifar_resnet/trainer.py b/scripts/plain_models/cifar_resnet/train_pytorch1.py
similarity index 73%
rename from scripts/plain_models/cifar_resnet/trainer.py
rename to scripts/plain_models/cifar_resnet/train_pytorch1.py
index 0e055f11..60f7b406 100644
--- a/scripts/plain_models/cifar_resnet/trainer.py
+++ b/scripts/plain_models/cifar_resnet/train_pytorch1.py
@@ -11,55 +11,48 @@
 import torch.optim
 import torch.utils.data
 import torchvision.transforms as transforms
 import torchvision.datasets as datasets
-import resnet
-model_names = sorted(name for name in resnet.__dict__
-    if name.islower() and not name.startswith("__")
-    and name.startswith("resnet")
-    and callable(resnet.__dict__[name]))
-
-print(model_names)
-
-parser = argparse.ArgumentParser(description='Propert ResNets for CIFAR10 in pytorch')
-parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet32',
-                    choices=model_names,
-                    help='model architecture: ' + ' | '.join(model_names) +
-                    ' (default: resnet32)')
-parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
-                    help='number of data loading workers (default: 4)')
-parser.add_argument('--epochs', default=200, type=int, metavar='N',
-                    help='number of total epochs to run')
-parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
-                    help='manual epoch number (useful on restarts)')
-parser.add_argument('-b', '--batch-size', default=128, type=int,
-                    metavar='N', help='mini-batch size (default: 128)')
-parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
-                    metavar='LR', help='initial learning rate')
-parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
-                    help='momentum')
-parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
-                    metavar='W', help='weight decay (default: 1e-4)')
-parser.add_argument('--print-freq', '-p', default=50, type=int,
-                    metavar='N', help='print frequency (default: 50)')
-parser.add_argument('--resume', default='', type=str, metavar='PATH',
-                    help='path to latest checkpoint (default: none)')
-parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
-                    help='evaluate model on validation set')
-parser.add_argument('--pretrained', dest='pretrained', action='store_true',
-                    help='use pre-trained model')
-parser.add_argument('--half', dest='half', action='store_true',
-                    help='use half-precision(16-bit) ')
-parser.add_argument('--save-dir', dest='save_dir',
-                    help='The directory used to save the trained models',
-                    default='save_temp', type=str)
-parser.add_argument('--save-every', dest='save_every',
-                    help='Saves checkpoints at every specified number of epochs',
-                    type=int, default=10)
-best_prec1 = 0
 
 def main():
-    global args, best_prec1
+    parser = argparse.ArgumentParser(description='Proper ResNets for CIFAR10 in PyTorch')
+    parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet32',
+                        choices=model_names,
+                        help='model architecture: ' + ' | '.join(model_names) +
+                        ' (default: resnet32)')
+    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
+                        help='number of data loading workers (default: 4)')
+    parser.add_argument('--epochs', default=200, type=int, metavar='N',
+                        help='number of total epochs to run')
+    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
+                        help='manual epoch number (useful on restarts)')
+    parser.add_argument('-b', '--batch-size', default=128, type=int,
+                        metavar='N', help='mini-batch size (default: 128)')
+    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
+                        metavar='LR', help='initial learning rate')
+    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
+                        help='momentum')
+    parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
+                        metavar='W', help='weight decay (default: 1e-4)')
+    parser.add_argument('--print-freq', '-p', default=50, type=int,
+                        metavar='N', help='print frequency (default: 50)')
+    parser.add_argument('--resume', default='', type=str, metavar='PATH',
+                        help='path to latest checkpoint (default: none)')
+    parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
+                        help='evaluate model on validation set')
+    parser.add_argument('--pretrained', dest='pretrained', action='store_true',
+                        help='use pre-trained model')
+    parser.add_argument('--half', dest='half', action='store_true',
+                        help='use half-precision (16-bit)')
+    parser.add_argument('--save-dir', dest='save_dir',
+                        help='The directory used to save the trained models',
+                        default='save_temp', type=str)
+    parser.add_argument('--save-every', dest='save_every',
+                        help='Saves checkpoints at every specified number of epochs',
+                        type=int, default=10)
+    best_prec1 = 0
+
+    args = parser.parse_args()