Mirror of https://github.com/microsoft/archai.git
create nasbench101 model by index, var exp
This commit is contained in:
Parent: 5e70aa8a51
Commit: 7396e4f56c
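Summary of the change: Nasbench101Dataset gains a create_model(index) method that builds a trainable PyTorch module for any benchmark entry; the NASBench-101 experiment script is reworked to carve a validation split out of the CIFAR-10 train set and retrain a handful of benchmark models selected by index (the "var exp" in the title); and the CIFAR testbed trainer picks up the same train/val refactoring plus autopep8-style formatting.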
@@ -96,10 +96,13 @@ import logging
 import numpy as np
 from numpy.lib.function_base import average
 
+from torch import nn
+
 from archai.common import utils
 from . import config
 from . import model_metrics_pb2
 from . import model_spec as _model_spec
+from . import model_builder
 
 # Bring ModelSpec to top-level for convenience. See lib/model_spec.py.
 ModelSpec = _model_spec.ModelSpec
@@ -304,6 +307,13 @@ class Nasbench101Dataset(object):
 
         return data
 
+    def create_model(self, index:int, device=None,
+                     stem_out_channels=128, num_stacks=3, num_modules_per_stack=3, num_labels=10)->nn.Module:
+        data = self[index]
+        adj, ops = data['module_adjacency'], data['module_operations']
+        return model_builder.build(adj, ops, device=device,
+                                   stem_out_channels=stem_out_channels, num_stacks=num_stacks,
+                                   num_modules_per_stack=num_modules_per_stack, num_labels=num_labels)
+
     def is_valid(self, desc_matrix:List[List[int]], vertex_ops:List[str]):
         """Checks the validity of the model_spec.
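The new create_model above turns a stored benchmark record into a trainable module. A minimal usage sketch (the dataset path follows the pattern used later in this commit, and the index 42 is arbitrary):

from archai.algos.nasbench101.nasbench101_dataset import Nasbench101Dataset

nsds = Nasbench101Dataset('nasbench_ds/nasbench_only108.tfrecord.pkl')
net = nsds.create_model(42)   # nn.Module built from entry 42's adjacency matrix and op list
perf_data = nsds[42]          # recorded benchmark metrics for the same architecture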
@@ -1,11 +1,13 @@
 import argparse
-from typing import List, Mapping, Tuple, Any
+import math
+from typing import List, Mapping, Optional, Tuple, Any
 import os
 import logging
 import numpy as np
 import time
 
 import torch
 from torch import nn
 from torch.optim.optimizer import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler
 from torch.nn.modules.loss import _Loss
@@ -17,91 +19,47 @@ import yaml
 
 from archai.common import utils
-from archai import cifar10_models
-from archai.datasets.list_dataset import ListDataset
+from archai.algos.nasbench101.nasbench101_dataset import Nasbench101Dataset
 
 
-def train_test(datadir:str, expdir:str,
-        exp_name:str, exp_desc:str, epochs:int, model_name:str,
-        train_batch_size:int, loader_workers:int, seed, half:bool,
-        test_batch_size:int, cutout:int)->Tuple[List[Mapping], int]:
-
-    # dirs
-    utils.setup_cuda(seed)
-    device = torch.device('cuda')
-
-    datadir = utils.full_path(datadir)
-    os.makedirs(datadir, exist_ok=True)
-
-    utils.create_logger(filepath=os.path.join(expdir, 'logs.log'))
-
-    # log config for reference
-    logging.info(f'exp_name="{exp_name}", exp_desc="{exp_desc}"')
-    logging.info(f'model_name="{model_name}", seed={seed}, epochs={epochs}')
-    logging.info(f'half={half}, cutout={cutout}, train_batch_size={train_batch_size}')
-    logging.info(f'datadir="{datadir}"')
-    logging.info(f'expdir="{expdir}"')
-
-    model_class = getattr(cifar10_models, model_name)
-    net = model_class()
-    logging.info(f'param_size_m={param_size(net):.1e}')
-    net = net.to(device)
-
-    crit = torch.nn.CrossEntropyLoss().to(device)
-
-    optim, sched, sched_on_epoch, batch_size = optim_sched(net)
-    logging.info(f'train_batch_size={train_batch_size}, batch_size={batch_size}')
-    logging.info(f'sched_on_epoch={sched_on_epoch}, sched={str(sched)}')
-
-    # load data just before train start so any errors so far is not delayed
-    train_dl, test_dl = cifar10_dataloaders(datadir,
-        train_batch_size=batch_size, test_batch_size=test_batch_size,
-        train_num_workers=loader_workers, test_num_workers=loader_workers,
-        cutout=cutout)
-
-    metrics = train(epochs, train_dl, test_dl, net, device, crit, optim,
-        sched, sched_on_epoch, half, False)
-
-    with open(os.path.join(expdir, 'metrics.yaml'), 'w') as f:
-        yaml.dump(metrics, f)
-
-    return metrics, train_batch_size
-
-
-def train(epochs, train_dl, test_dl, net, device, crit, optim,
-          sched, sched_on_epoch, half, quiet)->List[Mapping]:
-    if half:
-        net.half()
-        crit.half()
+def train(epochs, train_dl, val_dal, net, device, crit, optim,
+          sched, sched_on_epoch, half, quiet) -> List[Mapping]:
     train_acc, test_acc = 0.0, 0.0
     metrics = []
     for epoch in range(epochs):
         lr = optim.param_groups[0]['lr']
         train_acc, loss = train_epoch(epoch, net, train_dl, device, crit, optim,
-            sched, sched_on_epoch, half)
-        test_acc = test(net, test_dl, device, half)
-        metrics.append({'test_top1':test_acc, 'train_top1':train_acc, 'lr':lr,
+                                      sched, sched_on_epoch, half)
+
+        val_acc = test(net, val_dal, device,
+                       half) if val_dal is not None else math.nan
+        metrics.append({'val_top1': val_acc, 'train_top1': train_acc, 'lr': lr,
                         'epoch': epoch, 'train_loss': loss})
         if not quiet:
-            logging.info(f'train_epoch={epoch}, test_top1={test_acc},'
+            logging.info(f'train_epoch={epoch}, val_top1={val_acc},'
                          f' train_top1={train_acc}, lr={lr:.4g}')
     return metrics
 
 
 def optim_sched(net):
-    lr, momentum, weight_decay = 0.1, 0.9, 1.0e-4
-    optim = torch.optim.SGD(net.parameters(),
-        lr, momentum=momentum, weight_decay=weight_decay)
-    logging.info(f'lr={lr}, momentum={momentum}, weight_decay={weight_decay}')
+    lr, momentum, weight_decay = 0.1, 0.9, 1.0e-4
+    optim = torch.optim.SGD(net.parameters(),
+                            lr, momentum=momentum, weight_decay=weight_decay)
+    logging.info(f'lr={lr}, momentum={momentum}, weight_decay={weight_decay}')
 
-    sched = torch.optim.lr_scheduler.MultiStepLR(optim,
-        milestones=[100, 150, 200, 400, 600]) # resnet original paper
-    sched_on_epoch = True
+    sched = torch.optim.lr_scheduler.MultiStepLR(optim,
+        milestones=[100, 150, 200, 400, 600])  # resnet original paper
+    sched_on_epoch = True
 
-    return optim, sched, sched_on_epoch, 128
+    logging.info(f'sched_on_epoch={sched_on_epoch}, sched={str(sched)}')
+
+    return optim, sched, sched_on_epoch
 
 
-def cifar10_dataloaders(datadir:str, train_batch_size=128, test_batch_size=4096,
-        cutout=0, train_num_workers=-1, test_num_workers=-1)\
-        ->Tuple[DataLoader, DataLoader]:
+def get_data(datadir: str, train_batch_size=128, test_batch_size=4096,
+             cutout=0, train_num_workers=-1, test_num_workers=-1,
+             val_percent=10.0)\
+        -> Tuple[DataLoader, Optional[DataLoader], DataLoader]:
     if utils.is_debugging():
         train_num_workers = test_num_workers = 0
         logging.info('debugger=true, num_workers=0')
@@ -112,23 +70,40 @@ def cifar10_dataloaders(datadir:str, train_batch_size=128, test_batch_size=4096,
     train_transform = cifar10_transform(aug=True, cutout=cutout)
     trainset = torchvision.datasets.CIFAR10(root=datadir, train=True,
-        download=True, transform=train_transform)
+                                            download=True, transform=train_transform)
+
+    val_len = int(len(trainset) * val_percent / 100.0)
+    train_len = len(trainset) - val_len
+
+    valset = None
+    if val_len:
+        trainset, valset = torch.utils.data.random_split(
+            trainset, [train_len, val_len])
 
     train_dl = torch.utils.data.DataLoader(trainset, batch_size=train_batch_size,
-        shuffle=True, num_workers=train_num_workers, pin_memory=True)
+                                           shuffle=True, num_workers=train_num_workers, pin_memory=True)
+
+    if valset is not None:
+        val_dl = torch.utils.data.DataLoader(valset, batch_size=test_batch_size,
+                                             shuffle=False, num_workers=test_num_workers, pin_memory=True)
+    else:
+        val_dl = None
 
     test_transform = cifar10_transform(aug=False, cutout=0)
     testset = torchvision.datasets.CIFAR10(root=datadir, train=False,
-        download=True, transform=test_transform)
+                                           download=True, transform=test_transform)
     test_dl = torch.utils.data.DataLoader(testset, batch_size=test_batch_size,
-        shuffle=False, num_workers=test_num_workers, pin_memory=True)
+                                          shuffle=False, num_workers=test_num_workers, pin_memory=True)
 
-    return train_dl, test_dl
+    logging.info(
+        f'train_len={train_len}, val_len={val_len}, test_len={len(testset)}')
+
+    return train_dl, val_dl, test_dl
 
 
 def train_epoch(epoch, net, train_dl, device, crit, optim,
-        sched, sched_on_epoch, half)->Tuple[float, float]:
+                sched, sched_on_epoch, half) -> Tuple[float, float]:
     correct, total, loss_total = 0, 0, 0.0
-    ds = ListDataset(train_dl)
     net.train()
     for batch_idx, (inputs, targets) in enumerate(train_dl):
         inputs = inputs.to(device, non_blocking=True)
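For reference, with CIFAR-10's 50,000 training images the default val_percent=10.0 works out to:

val_len = int(50000 * 10.0 / 100.0)   # 5,000 images held out for validation
train_len = 50000 - val_len           # 45,000 images left for training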
@@ -149,9 +124,9 @@ def train_epoch(epoch, net, train_dl, device, crit, optim,
     return 100.0*correct/total, loss_total
 
 
-def train_step(net:torch.nn.Module,
-        crit:_Loss, optim:Optimizer, sched:_LRScheduler, sched_on_epoch:bool,
-        inputs:torch.Tensor, targets:torch.Tensor)->Tuple[torch.Tensor, float]:
+def train_step(net: nn.Module,
+               crit: _Loss, optim: Optimizer, sched: _LRScheduler, sched_on_epoch: bool,
+               inputs: torch.Tensor, targets: torch.Tensor) -> Tuple[torch.Tensor, float]:
     outputs = net(inputs)
 
     loss = crit(outputs, targets)
@@ -163,7 +138,8 @@ def train_step(net:torch.nn.Module,
         sched.step()
     return outputs, loss.item()
 
-def test(net, test_dl, device, half)->float:
+
+def test(net, test_dl, device, half) -> float:
     correct, total = 0, 0
     net.eval()
     with torch.no_grad():
@@ -181,12 +157,13 @@ def test(net, test_dl, device, half)->float:
     return 100.0*correct/total
 
 
-def param_size(model:torch.nn.Module)->int:
+def param_size(model: torch.nn.Module) -> int:
     """count all parameters excluding auxiliary"""
-    return sum(v.numel() for name, v in model.named_parameters() \
-        if "auxiliary" not in name)
+    return sum(v.numel() for name, v in model.named_parameters()
+               if "auxiliary" not in name)
 
-def cifar10_transform(aug:bool, cutout=0):
+
+def cifar10_transform(aug: bool, cutout=0):
     MEAN = [0.49139968, 0.48215827, 0.44653124]
     STD = [0.24703233, 0.24348505, 0.26158768]
@@ -202,7 +179,7 @@ def cifar10_transform(aug:bool, cutout=0):
     ]
     transf = aug_transf + transf
 
-    if cutout > 0: # must be after normalization
+    if cutout > 0:  # must be after normalization
         transf += [CutoutDefault(cutout)]
 
     return transforms.Compose(transf)
@@ -212,6 +189,7 @@ class CutoutDefault:
     """
     Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py
     """
+
     def __init__(self, length):
         self.length = length
@@ -233,29 +211,70 @@ class CutoutDefault:
         return img
 
 
+def log_metrics(expdir: str, filename: str, metrics, test_acc: float, args, perf_data: dict) -> None:
+    print('filename:', f'test_acc: {test_acc}', metrics[-1])
+    results = [
+        ('test_acc', test_acc),
+        ('nasbenc101_test_acc', perf_data['avg_final_test_accuracy']),
+        ('val_acc', metrics[-1]['val_top1']),
+        ('epochs', args.epochs),
+        ('train_batch_size', args.train_batch_size),
+        ('test_batch_size', args.test_batch_size),
+        ('model_name', args.model_name),
+        ('exp_name', args.experiment_name),
+        ('exp_desc', args.experiment_description),
+        ('seed', args.seed),
+        ('devices', utils.cuda_device_names()),
+        ('half', args.half),
+        ('cutout', args.cutout),
+        ('train_acc', metrics[-1]['train_top1']),
+        ('loader_workers', args.loader_workers),
+        ('date', str(time.time())),
+    ]
+    utils.append_csv_file(os.path.join(expdir, f'{filename}.tsv'), results)
+    with open(os.path.join(expdir, f'{filename}_metrics.yaml'), 'w') as f:
+        yaml.dump(metrics, f)
+    with open(os.path.join(expdir, f'{filename}_nasbench101.yaml'), 'w') as f:
+        yaml.dump(perf_data, f)
+
+
+def create_crit(device, half):
+    crit = nn.CrossEntropyLoss().to(device)
+    if half:
+        crit.half()
+    return crit
+
+
+def create_model(nsds, index, device, half) -> nn.Module:
+    net = nsds.create_model(index)
+    logging.info(f'param_size_m={param_size(net):.1e}')
+    net = net.to(device)
+    if half:
+        net.half()
+    return net
 
 
 def main():
     parser = argparse.ArgumentParser(description='Pytorch cifar training')
     parser.add_argument('--experiment-name', '-n', default='train_pytorch')
     parser.add_argument('--experiment-description', '-d',
                         default='Train cifar usin pure PyTorch code')
-    parser.add_argument('--epochs', '-e', type=int, default=300)
+    parser.add_argument('--epochs', '-e', type=int, default=1)
     parser.add_argument('--model-name', '-m', default='resnet34')
-    parser.add_argument('--train-batch', '-b', type=int, default=-1)
-    parser.add_argument('--test-batch', type=int, default=4096)
+    parser.add_argument('--device', default='',
+                        help='"cuda" or "cpu" or "" in which case use cuda if available')
+    parser.add_argument('--train-batch-size', '-b', type=int, default=128)
+    parser.add_argument('--test-batch-size', type=int, default=4096)
     parser.add_argument('--seed', '-s', type=float, default=42)
-    parser.add_argument('--half', type=lambda x:x.lower()=='true',
+    parser.add_argument('--half', type=lambda x: x.lower() == 'true',
                         nargs='?', const=True, default=False)
     parser.add_argument('--cutout', type=int, default=0)
-    parser.add_argument('--loader', default='torch', help='torch or dali')
 
     parser.add_argument('--datadir', default='',
                         help='where to find dataset files, default is ~/torchvision_data_dir')
     parser.add_argument('--outdir', default='',
                         help='where to put results, default is ~/logdir')
 
-    parser.add_argument('--loader-workers', type=int, default=-1, help='number of thread/workers for data loader (-1 means auto)')
-    parser.add_argument('--optim-sched', '-os', default='darts',
-                        help='Optimizer and scheduler provider')
+    parser.add_argument('--loader-workers', type=int, default=-1,
+                        help='number of thread/workers for data loader (-1 means auto)')
 
     args = parser.parse_args()
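The factory helpers added above keep half-precision handling in one place, converting both the model and the loss so fp16 activations meet fp16 parameters. A sketch of how they compose (the device and index are illustrative):

device = torch.device('cuda')
net = create_model(nsds, 4, device, True)   # builds benchmark entry 4, moves it to device, then .half()
crit = create_crit(device, True)            # CrossEntropyLoss converted to half to match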
@@ -264,41 +283,51 @@ def main():
     if not args.outdir:
         args.outdir = os.environ.get('PT_OUTPUT_DIR', '')
     if not args.outdir:
-        args.outdir = os.path.join('~/logdir', 'cifar_testbed', args.experiment_name)
+        args.outdir = os.path.join(
+            '~/logdir', 'cifar_testbed', args.experiment_name)
 
     expdir = utils.full_path(args.outdir)
     os.makedirs(expdir, exist_ok=True)
 
-    metrics, train_batch_size = train_test(datadir=args.datadir, expdir=expdir,
-        exp_name=args.experiment_name,
-        exp_desc=args.experiment_description,
-        epochs=args.epochs, model_name=args.model_name,
-        train_batch_size=args.train_batch, loader_workers=args.loader_workers,
-        seed=args.seed, half=args.half, test_batch_size=args.test_batch,
-        cutout=args.cutout)
+    utils.setup_cuda(args.seed)
+    datadir = utils.full_path(args.datadir)
+    os.makedirs(datadir, exist_ok=True)
 
-    print(metrics[-1])
+    utils.create_logger(filepath=os.path.join(expdir, 'logs.log'))
 
-    results = [
-        ('test_acc', metrics[-1]['test_top1']),
-        ('epochs', args.epochs),
-        ('train_batch_size', train_batch_size),
-        ('test_batch_size', args.test_batch),
-        ('model_name', args.model_name),
-        ('exp_name', args.experiment_name),
-        ('exp_desc', args.experiment_description),
-        ('seed', args.seed),
-        ('devices', utils.cuda_device_names()),
-        ('half', args.half),
-        ('cutout', args.cutout),
-        ('optim_sched', args.optim_sched),
-        ('train_acc', metrics[-1]['train_top1']),
-        ('loader', args.loader),
-        ('loader_workers', args.loader_workers),
-        ('date', str(time.time())),
-    ]
+    # log config for reference
+    logging.info(
+        f'exp_name="{args.experiment_name}", exp_desc="{args.experiment_description}"')
+    logging.info(
+        f'model_name="{args.model_name}", seed={args.seed}, epochs={args.epochs}')
+    logging.info(f'half={args.half}, cutout={args.cutout}')
+    logging.info(f'datadir="{datadir}"')
+    logging.info(f'expdir="{expdir}"')
+    logging.info(f'train_batch_size={args.train_batch_size}')
 
-    utils.append_csv_file(os.path.join(expdir, 'results.tsv'), results)
+    if args.device:
+        device = torch.device(args.device)
+    else:
+        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+    nsds = Nasbench101Dataset(
+        os.path.join(args.datadir, 'nasbench_ds', 'nasbench_only108.tfrecord.pkl'))
+
+    # load data just before train start so any errors so far is not delayed
+    train_dl, val_dl, test_dl = get_data(datadir=datadir,
+                                         train_batch_size=args.train_batch_size, test_batch_size=args.test_batch_size,
+                                         train_num_workers=args.loader_workers, test_num_workers=args.loader_workers,
+                                         cutout=args.cutout)
+
+    for model_id in [4, 400, 4000, 40000, 400000]:
+        perf_data = nsds[model_id]
+        net = create_model(nsds, model_id, device, args.half)
+        crit = create_crit(device, args.half)
+        optim, sched, sched_on_epoch = optim_sched(net)
+        train_metrics = train(perf_data['epochs'], train_dl, val_dl, net, device, crit, optim,
+                              sched, sched_on_epoch, args.half, False)
+        test_acc = test(net, test_dl, device, args.half)
+        log_metrics(expdir, f'metrics_{model_id}', train_metrics, test_acc, args, perf_data)
 
 
 if __name__ == '__main__':
-    main()
+    main()
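The loop above is the variance experiment from the commit title: five architectures spread across the benchmark index range are retrained locally for the epoch budget recorded in their perf_data, and log_metrics writes the local test accuracy alongside the benchmark's avg_final_test_accuracy so the two can be compared per model. Each run leaves a metrics_<model_id> row in the .tsv plus metrics_<model_id>_metrics.yaml and metrics_<model_id>_nasbench101.yaml files under expdir.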
@@ -27,6 +27,9 @@ def main():
     best = nsds[len(nsds)-1]
     print('best', best)
 
+    # create model by index
+    model = nsds.create_model(42)
+    print(model)
 
 if __name__ == '__main__':
     main()
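Labeling nsds[len(nsds)-1] as 'best' suggests entries are ordered by recorded accuracy, so create_model(42) simply demonstrates building an arbitrary low-index architecture; any index in range(len(nsds)) should work the same way.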
@@ -22,22 +22,24 @@ from archai import cifar10_models
 
 
 def train(epochs, train_dl, val_dal, net, device, crit, optim,
-          sched, sched_on_epoch, half, quiet)->List[Mapping]:
+          sched, sched_on_epoch, half, quiet) -> List[Mapping]:
     train_acc, test_acc = 0.0, 0.0
     metrics = []
     for epoch in range(epochs):
         lr = optim.param_groups[0]['lr']
         train_acc, loss = train_epoch(epoch, net, train_dl, device, crit, optim,
-            sched, sched_on_epoch, half)
+                                      sched, sched_on_epoch, half)
 
-        val_acc = test(net, val_dal, device, half) if val_dal is not None else math.nan
-        metrics.append({'val_top1':val_acc, 'train_top1':train_acc, 'lr':lr,
+        val_acc = test(net, val_dal, device,
+                       half) if val_dal is not None else math.nan
+        metrics.append({'val_top1': val_acc, 'train_top1': train_acc, 'lr': lr,
                         'epoch': epoch, 'train_loss': loss})
         if not quiet:
             logging.info(f'train_epoch={epoch}, val_top1={val_acc},'
                          f' train_top1={train_acc}, lr={lr:.4g}')
     return metrics
 
 
 def optim_sched(net):
     lr, momentum, weight_decay = 0.1, 0.9, 1.0e-4
     optim = torch.optim.SGD(net.parameters(),
@@ -45,7 +47,7 @@ def optim_sched(net):
     logging.info(f'lr={lr}, momentum={momentum}, weight_decay={weight_decay}')
 
     sched = torch.optim.lr_scheduler.MultiStepLR(optim,
-        milestones=[100, 150, 200, 400, 600]) # resnet original paper
+        milestones=[100, 150, 200, 400, 600])  # resnet original paper
     sched_on_epoch = True
 
     logging.info(f'sched_on_epoch={sched_on_epoch}, sched={str(sched)}')
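MultiStepLR multiplies the learning rate by its default gamma of 0.1 at each milestone, so with the base lr of 0.1 the schedule works out to (later steps are only reached when training runs long enough):

# epoch:  0-99   100-149  150-199  200-399  400-599  600+
# lr:     0.1    0.01     0.001    1e-4     1e-5     1e-6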
@@ -53,10 +55,10 @@ def optim_sched(net):
     return optim, sched, sched_on_epoch
 
 
-def get_data(datadir:str, train_batch_size=128, test_batch_size=4096,
-        cutout=0, train_num_workers=-1, test_num_workers=-1,
-        val_percent=10.0)\
-        ->Tuple[DataLoader, Optional[DataLoader], DataLoader]:
+def get_data(datadir: str, train_batch_size=128, test_batch_size=4096,
+             cutout=0, train_num_workers=-1, test_num_workers=-1,
+             val_percent=10.0)\
+        -> Tuple[DataLoader, Optional[DataLoader], DataLoader]:
     if utils.is_debugging():
         train_num_workers = test_num_workers = 0
         logging.info('debugger=true, num_workers=0')
@@ -67,37 +69,39 @@ def get_data(datadir:str, train_batch_size=128, test_batch_size=4096,
     train_transform = cifar10_transform(aug=True, cutout=cutout)
     trainset = torchvision.datasets.CIFAR10(root=datadir, train=True,
-        download=True, transform=train_transform)
+                                            download=True, transform=train_transform)
 
     val_len = int(len(trainset) * val_percent / 100.0)
     train_len = len(trainset) - val_len
 
     valset = None
     if val_len:
-        trainset, valset = torch.utils.data.random_split(trainset, [train_len, val_len])
+        trainset, valset = torch.utils.data.random_split(
+            trainset, [train_len, val_len])
 
     train_dl = torch.utils.data.DataLoader(trainset, batch_size=train_batch_size,
-        shuffle=True, num_workers=train_num_workers, pin_memory=True)
+                                           shuffle=True, num_workers=train_num_workers, pin_memory=True)
 
     if valset is not None:
         val_dl = torch.utils.data.DataLoader(valset, batch_size=test_batch_size,
-            shuffle=False, num_workers=test_num_workers, pin_memory=True)
+                                             shuffle=False, num_workers=test_num_workers, pin_memory=True)
     else:
         val_dl = None
 
     test_transform = cifar10_transform(aug=False, cutout=0)
     testset = torchvision.datasets.CIFAR10(root=datadir, train=False,
-        download=True, transform=test_transform)
+                                           download=True, transform=test_transform)
     test_dl = torch.utils.data.DataLoader(testset, batch_size=test_batch_size,
-        shuffle=False, num_workers=test_num_workers, pin_memory=True)
+                                          shuffle=False, num_workers=test_num_workers, pin_memory=True)
 
-    logging.info(f'train_len={train_len}, val_len={val_len}, test_len={len(testset)}')
+    logging.info(
+        f'train_len={train_len}, val_len={val_len}, test_len={len(testset)}')
 
     return train_dl, val_dl, test_dl
 
 
 def train_epoch(epoch, net, train_dl, device, crit, optim,
-        sched, sched_on_epoch, half)->Tuple[float, float]:
+                sched, sched_on_epoch, half) -> Tuple[float, float]:
     correct, total, loss_total = 0, 0, 0.0
     net.train()
     for batch_idx, (inputs, targets) in enumerate(train_dl):
@@ -119,9 +123,9 @@ def train_epoch(epoch, net, train_dl, device, crit, optim,
     return 100.0*correct/total, loss_total
 
 
-def train_step(net:nn.Module,
-        crit:_Loss, optim:Optimizer, sched:_LRScheduler, sched_on_epoch:bool,
-        inputs:torch.Tensor, targets:torch.Tensor)->Tuple[torch.Tensor, float]:
+def train_step(net: nn.Module,
+               crit: _Loss, optim: Optimizer, sched: _LRScheduler, sched_on_epoch: bool,
+               inputs: torch.Tensor, targets: torch.Tensor) -> Tuple[torch.Tensor, float]:
     outputs = net(inputs)
 
     loss = crit(outputs, targets)
@@ -133,7 +137,8 @@ def train_step(net:nn.Module,
         sched.step()
     return outputs, loss.item()
 
-def test(net, test_dl, device, half)->float:
+
+def test(net, test_dl, device, half) -> float:
     correct, total = 0, 0
     net.eval()
     with torch.no_grad():
@@ -151,12 +156,13 @@ def test(net, test_dl, device, half)->float:
     return 100.0*correct/total
 
 
-def param_size(model:torch.nn.Module)->int:
+def param_size(model: torch.nn.Module) -> int:
     """count all parameters excluding auxiliary"""
-    return sum(v.numel() for name, v in model.named_parameters() \
-        if "auxiliary" not in name)
+    return sum(v.numel() for name, v in model.named_parameters()
+               if "auxiliary" not in name)
 
-def cifar10_transform(aug:bool, cutout=0):
+
+def cifar10_transform(aug: bool, cutout=0):
     MEAN = [0.49139968, 0.48215827, 0.44653124]
     STD = [0.24703233, 0.24348505, 0.26158768]
@@ -172,7 +178,7 @@ def cifar10_transform(aug:bool, cutout=0):
     ]
     transf = aug_transf + transf
 
-    if cutout > 0: # must be after normalization
+    if cutout > 0:  # must be after normalization
         transf += [CutoutDefault(cutout)]
 
     return transforms.Compose(transf)
@@ -182,6 +188,7 @@ class CutoutDefault:
     """
     Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py
     """
+
     def __init__(self, length):
         self.length = length
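As the comment in cifar10_transform notes, cutout must run after normalization because CutoutDefault zeroes a square patch of the already-normalized tensor. A typical call for CIFAR-scale training (the 16-pixel length is a common choice, not a value this commit sets):

train_transform = cifar10_transform(aug=True, cutout=16)   # erase one 16x16 patch per image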
@@ -202,7 +209,8 @@ class CutoutDefault:
         img *= mask
         return img
 
-def log_metrics(expdir:str, filename:str, metrics, test_acc:float, args)->None:
+
+def log_metrics(expdir: str, filename: str, metrics, test_acc: float, args) -> None:
     print('filename:', f'test_acc: {test_acc}', metrics[-1])
     results = [
         ('test_acc', test_acc),
@@ -225,13 +233,15 @@ def log_metrics(expdir:str, filename:str, metrics, test_acc:float, args)->None:
     with open(os.path.join(expdir, f'{filename}.yaml'), 'w') as f:
         yaml.dump(metrics, f)
 
 
 def create_crit(device, half):
     crit = nn.CrossEntropyLoss().to(device)
     if half:
         crit.half()
     return crit
 
-def create_model(model_name, device, half)->nn.Module:
+
+def create_model(model_name, device, half) -> nn.Module:
     model_class = getattr(cifar10_models, model_name)
     net = model_class()
     logging.info(f'param_size_m={param_size(net):.1e}')
@@ -240,6 +250,7 @@ def create_model(model_name, device, half)->nn.Module:
         net.half()
     return net
 
+
 def main():
     parser = argparse.ArgumentParser(description='Pytorch cifar training')
     parser.add_argument('--experiment-name', '-n', default='train_pytorch')
@@ -247,11 +258,12 @@ def main():
                         default='Train cifar usin pure PyTorch code')
     parser.add_argument('--epochs', '-e', type=int, default=1)
     parser.add_argument('--model-name', '-m', default='resnet34')
-    parser.add_argument('--device', default='', help='"cuda" or "cpu" or "" in which case use cuda if available')
+    parser.add_argument('--device', default='',
+                        help='"cuda" or "cpu" or "" in which case use cuda if available')
     parser.add_argument('--train-batch-size', '-b', type=int, default=128)
    parser.add_argument('--test-batch-size', type=int, default=4096)
     parser.add_argument('--seed', '-s', type=float, default=42)
-    parser.add_argument('--half', type=lambda x:x.lower()=='true',
+    parser.add_argument('--half', type=lambda x: x.lower() == 'true',
                         nargs='?', const=True, default=False)
     parser.add_argument('--cutout', type=int, default=0)
@@ -260,7 +272,8 @@ def main():
     parser.add_argument('--outdir', default='',
                         help='where to put results, default is ~/logdir')
 
-    parser.add_argument('--loader-workers', type=int, default=-1, help='number of thread/workers for data loader (-1 means auto)')
+    parser.add_argument('--loader-workers', type=int, default=-1,
+                        help='number of thread/workers for data loader (-1 means auto)')
 
     args = parser.parse_args()
@@ -269,7 +282,8 @@ def main():
     if not args.outdir:
         args.outdir = os.environ.get('PT_OUTPUT_DIR', '')
     if not args.outdir:
-        args.outdir = os.path.join('~/logdir', 'cifar_testbed', args.experiment_name)
+        args.outdir = os.path.join(
+            '~/logdir', 'cifar_testbed', args.experiment_name)
 
     expdir = utils.full_path(args.outdir)
     os.makedirs(expdir, exist_ok=True)
@@ -281,8 +295,10 @@ def main():
     utils.create_logger(filepath=os.path.join(expdir, 'logs.log'))
 
     # log config for reference
-    logging.info(f'exp_name="{args.experiment_name}", exp_desc="{args.experiment_description}"')
-    logging.info(f'model_name="{args.model_name}", seed={args.seed}, epochs={args.epochs}')
+    logging.info(
+        f'exp_name="{args.experiment_name}", exp_desc="{args.experiment_description}"')
+    logging.info(
+        f'model_name="{args.model_name}", seed={args.seed}, epochs={args.epochs}')
     logging.info(f'half={args.half}, cutout={args.cutout}')
     logging.info(f'datadir="{datadir}"')
     logging.info(f'expdir="{expdir}"')
@@ -299,14 +315,15 @@ def main():
     # load data just before train start so any errors so far is not delayed
     train_dl, val_dl, test_dl = get_data(datadir=datadir,
-        train_batch_size=args.train_batch_size, test_batch_size=args.test_batch_size,
-        train_num_workers=args.loader_workers, test_num_workers=args.loader_workers,
-        cutout=args.cutout)
+                                         train_batch_size=args.train_batch_size, test_batch_size=args.test_batch_size,
+                                         train_num_workers=args.loader_workers, test_num_workers=args.loader_workers,
+                                         cutout=args.cutout)
 
     train_metrics = train(args.epochs, train_dl, val_dl, net, device, crit, optim,
-        sched, sched_on_epoch, args.half, False)
+                          sched, sched_on_epoch, args.half, False)
     test_acc = test(net, test_dl, device, args.half)
     log_metrics(expdir, 'train_metrics', train_metrics, test_acc, args)
 
 
 if __name__ == '__main__':
-    main()
+    main()