Mirror of https://github.com/microsoft/archai.git

move other resnet models to resnet_paper, make local_rank param default, pure PyTorch code for resnet for comparison

Parent: 6151424d7c
Commit: d4092dae83
@@ -1,3 +1,3 @@
 # Credits
 
-Models in this folder are as-is from https://github.com/huyvnphan/PyTorch-CIFAR10.
+Models in this folder are as-is from https://github.com/huyvnphan/PyTorch-CIFAR10 and from Yerlan Idelbayev.
@@ -116,7 +116,7 @@ class ApexUtils:
 
     def _setup_gpus(self, seed:float, detect_anomaly:bool):
-        utils.setup_cuda(seed, self.local_rank)
+        utils.setup_cuda(seed, local_rank=self.local_rank)
 
         torch.autograd.set_detect_anomaly(detect_anomaly)
         self._log_info({'set_detect_anomaly': detect_anomaly,
@@ -84,6 +84,13 @@ def MeasureTime(f_py=None, no_print=True, disable_gc=False, name:Optional[str]=N
     return _decorator(f_py) if callable(f_py) else _decorator
 
 class MeasureBlockTime:
+    """
+    Example:
+        with MeasureBlockTime('my_func') as t:
+            my_func()
+        print(t.elapsed)
+
+    """
     def __init__(self, name:str, no_print=True, disable_gc=False):
         self.name = name
         self.no_print = no_print
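For reference, the decorator and the new context manager cover the two timing patterns. A minimal sketch follows, assuming MeasureBlockTime lives alongside MeasureTime in torch_testbed.timing (as the script imports below suggest); the bare-decorator form follows from the _decorator dispatch above, t.elapsed comes from the class docstring, and the timed function is purely illustrative.

from torch_testbed.timing import MeasureTime, MeasureBlockTime

@MeasureTime  # bare decorator form; MeasureTime(...) with keyword arguments also works per the dispatch above
def load_batch():
    return sum(range(100000))  # stand-in workload for illustration only

load_batch()

with MeasureBlockTime('load_batch_block') as t:  # context-manager form added in this commit
    load_batch()
print(t.elapsed)  # elapsed time for the block, as in the class docstring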
@@ -228,7 +228,7 @@ def download_and_extract_tar(url, download_root, extract_root=None, filename=Non
     extract_tar(os.path.join(download_root, filename), extract_root, **kwargs)
 
-def setup_cuda(seed:Union[float, int], local_rank:int):
+def setup_cuda(seed:Union[float, int], local_rank:int=0):
     seed = int(seed) + local_rank
     # setup cuda
     cudnn.enabled = True
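With the default added above, single-process callers can omit local_rank. A minimal sketch of the equivalent call forms, with the import path taken from the new training script added later in this commit (from archai.common import utils):

from archai.common import utils

# local_rank now defaults to 0, so on a single GPU these two calls behave identically:
# both seed with int(seed) + local_rank and enable cudnn, per setup_cuda above.
utils.setup_cuda(42)
utils.setup_cuda(42, local_rank=0)

# Distributed workers keep passing their rank so each process derives a distinct seed,
# as ApexUtils._setup_gpus does with local_rank=self.local_rank.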
@@ -9,7 +9,7 @@ import torch.backends.cudnn as cudnn
 import numpy as np
 
-utils.setup_cuda(2, 0)
+utils.setup_cuda(2, local_rank=0)
 
 device = torch.device('cuda')
@@ -13,7 +13,7 @@ from torch_testbed.dataloader_dali import cifar10_dataloaders
 
 utils.create_logger()
-utils.setup_cuda(42, 0)
+utils.setup_cuda(42, local_rank=0)
 
 batch_size = 512
 half = True
@@ -8,7 +8,7 @@ from torch_testbed import utils, cifar10_models
 from torch_testbed.timing import MeasureTime, print_all_timings, print_timing, get_timing
 
 utils.create_logger()
-utils.setup_cuda(42, 0)
+utils.setup_cuda(42, local_rank=0)
 
 batch_size = 512
 half = True
@@ -0,0 +1,302 @@
import argparse
from typing import List, Mapping, Tuple, Any
import os
import logging
import numpy as np
import time

import torch
from torch.optim.optimizer import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.nn.modules.loss import _Loss
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms

import yaml

from archai.common import utils
from archai import cifar10_models


def train_test(datadir:str, expdir:str,
               exp_name:str, exp_desc:str, epochs:int, model_name:str,
               train_batch_size:int, loader_workers:int, seed, half:bool,
               test_batch_size:int, cutout:int)->Tuple[List[Mapping], int]:

    # dirs
    utils.setup_cuda(seed)
    device = torch.device('cuda')

    datadir = utils.full_path(datadir)
    os.makedirs(datadir, exist_ok=True)

    utils.create_logger(filepath=os.path.join(expdir, 'logs.log'))

    # log config for reference
    logging.info(f'exp_name="{exp_name}", exp_desc="{exp_desc}"')
    logging.info(f'model_name="{model_name}", seed={seed}, epochs={epochs}')
    logging.info(f'half={half}, cutout={cutout}, train_batch_size={train_batch_size}')
    logging.info(f'datadir="{datadir}"')
    logging.info(f'expdir="{expdir}"')

    model_class = getattr(cifar10_models, model_name)
    net = model_class()
    logging.info(f'param_size_m={param_size(net):.1e}')
    net = net.to(device)

    crit = torch.nn.CrossEntropyLoss().to(device)

    optim, sched, sched_on_epoch, batch_size = optim_sched(net)
    logging.info(f'train_batch_size={train_batch_size}, batch_size={batch_size}')
    logging.info(f'sched_on_epoch={sched_on_epoch}, sched={str(sched)}')

    # load data just before training starts so that any earlier errors are not delayed
    train_dl, test_dl = cifar10_dataloaders(datadir,
        train_batch_size=batch_size, test_batch_size=test_batch_size,
        train_num_workers=loader_workers, test_num_workers=loader_workers,
        cutout=cutout)

    metrics = train(epochs, train_dl, test_dl, net, device, crit, optim,
                    sched, sched_on_epoch, half, False)

    with open(os.path.join(expdir, 'metrics.yaml'), 'w') as f:
        yaml.dump(metrics, f)

    return metrics, train_batch_size


def train(epochs, train_dl, test_dl, net, device, crit, optim,
          sched, sched_on_epoch, half, quiet)->List[Mapping]:
    if half:
        net.half()
        crit.half()
    train_acc, test_acc = 0.0, 0.0
    metrics = []
    for epoch in range(epochs):
        lr = optim.param_groups[0]['lr']
        train_acc, loss = train_epoch(epoch, net, train_dl, device, crit, optim,
                                      sched, sched_on_epoch, half)
        test_acc = test(net, test_dl, device, half)
        metrics.append({'test_top1':test_acc, 'train_top1':train_acc, 'lr':lr,
                        'epoch': epoch, 'train_loss': loss})
        if not quiet:
            logging.info(f'train_epoch={epoch}, test_top1={test_acc},'
                         f' train_top1={train_acc}, lr={lr:.4g}')
    return metrics

def optim_sched(net):
    lr, momentum, weight_decay = 0.1, 0.9, 1.0e-4
    optim = torch.optim.SGD(net.parameters(),
                            lr, momentum=momentum, weight_decay=weight_decay)
    logging.info(f'lr={lr}, momentum={momentum}, weight_decay={weight_decay}')

    sched = torch.optim.lr_scheduler.MultiStepLR(optim,
        milestones=[100, 150, 200, 400, 600]) # resnet original paper
    sched_on_epoch = True

    return optim, sched, sched_on_epoch, 128


def cifar10_dataloaders(datadir:str, train_batch_size=128, test_batch_size=4096,
                        cutout=0, train_num_workers=-1, test_num_workers=-1)\
                            ->Tuple[DataLoader, DataLoader]:
    if utils.is_debugging():
        train_num_workers = test_num_workers = 0
        logging.info('debugger=true, num_workers=0')
    if train_num_workers <= -1:
        train_num_workers = torch.cuda.device_count()*4
    if test_num_workers <= -1:
        test_num_workers = torch.cuda.device_count()*4

    train_transform = cifar10_transform(aug=True, cutout=cutout)
    trainset = torchvision.datasets.CIFAR10(root=datadir, train=True,
        download=True, transform=train_transform)
    train_dl = torch.utils.data.DataLoader(trainset, batch_size=train_batch_size,
        shuffle=True, num_workers=train_num_workers, pin_memory=True)

    test_transform = cifar10_transform(aug=False, cutout=0)
    testset = torchvision.datasets.CIFAR10(root=datadir, train=False,
        download=True, transform=test_transform)
    test_dl = torch.utils.data.DataLoader(testset, batch_size=test_batch_size,
        shuffle=False, num_workers=test_num_workers, pin_memory=True)

    return train_dl, test_dl


def train_epoch(epoch, net, train_dl, device, crit, optim,
                sched, sched_on_epoch, half)->Tuple[float, float]:
    correct, total, loss_total = 0, 0, 0.0
    net.train()
    for batch_idx, (inputs, targets) in enumerate(train_dl):
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        if half:
            inputs = inputs.half()

        outputs, loss = train_step(net, crit, optim, sched, sched_on_epoch,
                                   inputs, targets)
        loss_total += loss

        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
    if sched and sched_on_epoch:
        sched.step()
    return 100.0*correct/total, loss_total


def train_step(net:torch.nn.Module,
               crit:_Loss, optim:Optimizer, sched:_LRScheduler, sched_on_epoch:bool,
               inputs:torch.Tensor, targets:torch.Tensor)->Tuple[torch.Tensor, float]:
    outputs = net(inputs)

    loss = crit(outputs, targets)
    optim.zero_grad()
    loss.backward()

    optim.step()
    if sched and not sched_on_epoch:
        sched.step()
    return outputs, loss.item()

def test(net, test_dl, device, half)->float:
    correct, total = 0, 0
    net.eval()
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_dl):
            inputs = inputs.to(device, non_blocking=False)
            targets = targets.to(device)

            if half:
                inputs = inputs.half()

            outputs = net(inputs)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    return 100.0*correct/total


def param_size(model:torch.nn.Module)->int:
    """count all parameters excluding auxiliary"""
    return sum(v.numel() for name, v in model.named_parameters() \
               if "auxiliary" not in name)

def cifar10_transform(aug:bool, cutout=0):
    MEAN = [0.49139968, 0.48215827, 0.44653124]
    STD = [0.24703233, 0.24348505, 0.26158768]

    transf = [
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD)
    ]

    if aug:
        aug_transf = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip()
        ]
        transf = aug_transf + transf

    if cutout > 0: # must be after normalization
        transf += [CutoutDefault(cutout)]

    return transforms.Compose(transf)


class CutoutDefault:
    """
    Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py
    """
    def __init__(self, length):
        self.length = length

    def __call__(self, img):
        h, w = img.size(1), img.size(2)
        mask = np.ones((h, w), np.float32)
        y = np.random.randint(h)
        x = np.random.randint(w)

        y1 = np.clip(y - self.length // 2, 0, h)
        y2 = np.clip(y + self.length // 2, 0, h)
        x1 = np.clip(x - self.length // 2, 0, w)
        x2 = np.clip(x + self.length // 2, 0, w)

        mask[y1: y2, x1: x2] = 0.
        mask = torch.from_numpy(mask)
        mask = mask.expand_as(img)
        img *= mask
        return img


def main():
    parser = argparse.ArgumentParser(description='Pytorch cifar training')
    parser.add_argument('--experiment-name', '-n', default='train_pytorch')
    parser.add_argument('--experiment-description', '-d',
                        default='Train cifar using pure PyTorch code')
    parser.add_argument('--epochs', '-e', type=int, default=300)
    parser.add_argument('--model-name', '-m', default='resnet34')
    parser.add_argument('--train-batch', '-b', type=int, default=-1)
    parser.add_argument('--test-batch', type=int, default=4096)
    parser.add_argument('--seed', '-s', type=float, default=42)
    parser.add_argument('--half', type=lambda x:x.lower()=='true',
                        nargs='?', const=True, default=False)
    parser.add_argument('--cutout', type=int, default=0)
    parser.add_argument('--loader', default='torch', help='torch or dali')

    parser.add_argument('--datadir', default='',
                        help='where to find dataset files, default is ~/torchvision_data_dir')
    parser.add_argument('--outdir', default='',
                        help='where to put results, default is ~/logdir')

    parser.add_argument('--loader-workers', type=int, default=-1, help='number of thread/workers for data loader (-1 means auto)')
    parser.add_argument('--optim-sched', '-os', default='darts',
                        help='Optimizer and scheduler provider')

    args = parser.parse_args()

    if not args.datadir:
        args.datadir = os.environ.get('PT_DATA_DIR', '') or '~/dataroot'
    if not args.outdir:
        args.outdir = os.environ.get('PT_OUTPUT_DIR', '')
    if not args.outdir:
        args.outdir = os.path.join('~/logdir', 'cifar_testbed', args.experiment_name)

    expdir = utils.full_path(args.outdir)
    os.makedirs(expdir, exist_ok=True)

    metrics, train_batch_size = train_test(datadir=args.datadir, expdir=expdir,
        exp_name=args.experiment_name,
        exp_desc=args.experiment_description,
        epochs=args.epochs, model_name=args.model_name,
        train_batch_size=args.train_batch, loader_workers=args.loader_workers,
        seed=args.seed, half=args.half, test_batch_size=args.test_batch,
        cutout=args.cutout)

    print(metrics[-1])

    results = [
        ('test_acc', metrics[-1]['test_top1']),
        ('epochs', args.epochs),
        ('train_batch_size', train_batch_size),
        ('test_batch_size', args.test_batch),
        ('model_name', args.model_name),
        ('exp_name', args.experiment_name),
        ('exp_desc', args.experiment_description),
        ('seed', args.seed),
        ('devices', utils.cuda_device_names()),
        ('half', args.half),
        ('cutout', args.cutout),
        ('optim_sched', args.optim_sched),
        ('train_acc', metrics[-1]['train_top1']),
        ('loader', args.loader),
        ('loader_workers', args.loader_workers),
        ('date', str(time.time())),
    ]

    utils.append_csv_file(os.path.join(expdir, 'results.tsv'), results)


if __name__ == '__main__':
    main()
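The pieces above can also be driven without the command line. A sketch of calling train_test directly, using only the parser defaults and environment-variable fallbacks defined in main() above; it assumes it runs in the same module so os, utils and train_test are already in scope:

# Equivalent of running the script with no arguments: every value below mirrors the
# argparse defaults and the PT_DATA_DIR / PT_OUTPUT_DIR fallbacks in main().
expdir = utils.full_path(os.path.join('~/logdir', 'cifar_testbed', 'train_pytorch'))
os.makedirs(expdir, exist_ok=True)

metrics, train_batch_size = train_test(datadir='~/dataroot', expdir=expdir,
    exp_name='train_pytorch', exp_desc='Train cifar using pure PyTorch code',
    epochs=300, model_name='resnet34',
    train_batch_size=-1, loader_workers=-1, seed=42, half=False,
    test_batch_size=4096, cutout=0)

print(metrics[-1]['test_top1'])  # final-epoch top-1 test accuracy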
@@ -11,55 +11,48 @@ import torch.optim
 import torch.utils.data
 import torchvision.transforms as transforms
 import torchvision.datasets as datasets
 import resnet
 
 model_names = sorted(name for name in resnet.__dict__
     if name.islower() and not name.startswith("__")
                      and name.startswith("resnet")
                      and callable(resnet.__dict__[name]))
 
 print(model_names)
 
-parser = argparse.ArgumentParser(description='Propert ResNets for CIFAR10 in pytorch')
-parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet32',
-                    choices=model_names,
-                    help='model architecture: ' + ' | '.join(model_names) +
-                    ' (default: resnet32)')
-parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
-                    help='number of data loading workers (default: 4)')
-parser.add_argument('--epochs', default=200, type=int, metavar='N',
-                    help='number of total epochs to run')
-parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
-                    help='manual epoch number (useful on restarts)')
-parser.add_argument('-b', '--batch-size', default=128, type=int,
-                    metavar='N', help='mini-batch size (default: 128)')
-parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
-                    metavar='LR', help='initial learning rate')
-parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
-                    help='momentum')
-parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
-                    metavar='W', help='weight decay (default: 1e-4)')
-parser.add_argument('--print-freq', '-p', default=50, type=int,
-                    metavar='N', help='print frequency (default: 50)')
-parser.add_argument('--resume', default='', type=str, metavar='PATH',
-                    help='path to latest checkpoint (default: none)')
-parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
-                    help='evaluate model on validation set')
-parser.add_argument('--pretrained', dest='pretrained', action='store_true',
-                    help='use pre-trained model')
-parser.add_argument('--half', dest='half', action='store_true',
-                    help='use half-precision(16-bit) ')
-parser.add_argument('--save-dir', dest='save_dir',
-                    help='The directory used to save the trained models',
-                    default='save_temp', type=str)
-parser.add_argument('--save-every', dest='save_every',
-                    help='Saves checkpoints at every specified number of epochs',
-                    type=int, default=10)
-best_prec1 = 0
 
 
 def main():
     global args, best_prec1
+    parser = argparse.ArgumentParser(description='Propert ResNets for CIFAR10 in pytorch')
+    parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet32',
+                        choices=model_names,
+                        help='model architecture: ' + ' | '.join(model_names) +
+                        ' (default: resnet32)')
+    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
+                        help='number of data loading workers (default: 4)')
+    parser.add_argument('--epochs', default=200, type=int, metavar='N',
+                        help='number of total epochs to run')
+    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
+                        help='manual epoch number (useful on restarts)')
+    parser.add_argument('-b', '--batch-size', default=128, type=int,
+                        metavar='N', help='mini-batch size (default: 128)')
+    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
+                        metavar='LR', help='initial learning rate')
+    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
+                        help='momentum')
+    parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
+                        metavar='W', help='weight decay (default: 1e-4)')
+    parser.add_argument('--print-freq', '-p', default=50, type=int,
+                        metavar='N', help='print frequency (default: 50)')
+    parser.add_argument('--resume', default='', type=str, metavar='PATH',
+                        help='path to latest checkpoint (default: none)')
+    parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
+                        help='evaluate model on validation set')
+    parser.add_argument('--pretrained', dest='pretrained', action='store_true',
+                        help='use pre-trained model')
+    parser.add_argument('--half', dest='half', action='store_true',
+                        help='use half-precision(16-bit) ')
+    parser.add_argument('--save-dir', dest='save_dir',
+                        help='The directory used to save the trained models',
+                        default='save_temp', type=str)
+    parser.add_argument('--save-every', dest='save_every',
+                        help='Saves checkpoints at every specified number of epochs',
+                        type=int, default=10)
+    best_prec1 = 0
 
     args = parser.parse_args()