create nasbench101 model by index, var exp

Shital Shah 2021-01-16 00:41:47 -08:00 committed by Gustavo Rosa
Parent 5e70aa8a51
Commit 7396e4f56c
4 changed files with 220 additions and 161 deletions

View file

@@ -96,10 +96,13 @@ import logging
import numpy as np
from numpy.lib.function_base import average
from torch import nn
from archai.common import utils
from . import config
from . import model_metrics_pb2
from . import model_spec as _model_spec
from . import model_builder
# Bring ModelSpec to top-level for convenience. See lib/model_spec.py.
ModelSpec = _model_spec.ModelSpec
@@ -304,6 +307,13 @@ class Nasbench101Dataset(object):
return data
def create_model(self, index:int, device=None,
stem_out_channels=128, num_stacks=3, num_modules_per_stack=3, num_labels=10)->nn.Module:
data = self[index]
adj, ops = data['module_adjacency'], data['module_operations']
return model_builder.build(adj, ops, device=device,
stem_out_channels=stem_out_channels, num_stacks=num_stacks,
num_modules_per_stack=num_modules_per_stack, num_labels=num_labels)
def is_valid(self, desc_matrix:List[List[int]], vertex_ops:List[str]):
"""Checks the validity of the model_spec.

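For context, a minimal sketch of exercising the new create_model API above; the pickle path mirrors the one used later in this commit, and the index and input shape are illustrative, not part of the change:

    # Sketch only: path and index are illustrative.
    import torch
    from archai.algos.nasbench101.nasbench101_dataset import Nasbench101Dataset

    nsds = Nasbench101Dataset('nasbench_ds/nasbench_only108.tfrecord.pkl')
    net = nsds.create_model(42)            # builds from entry 42's adjacency + ops
    out = net(torch.randn(1, 3, 32, 32))   # CIFAR-10 sized input
    print(out.shape)                       # expected: torch.Size([1, 10]) with num_labels=10
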
View file

@@ -1,11 +1,13 @@
import argparse
from typing import List, Mapping, Tuple, Any
import math
from typing import List, Mapping, Optional, Tuple, Any
import os
import logging
import numpy as np
import time
import torch
from torch import nn
from torch.optim.optimizer import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.nn.modules.loss import _Loss
@@ -17,91 +19,47 @@ import yaml
from archai.common import utils
from archai import cifar10_models
from archai.datasets.list_dataset import ListDataset
from archai.algos.nasbench101.nasbench101_dataset import Nasbench101Dataset
def train_test(datadir:str, expdir:str,
exp_name:str, exp_desc:str, epochs:int, model_name:str,
train_batch_size:int, loader_workers:int, seed, half:bool,
test_batch_size:int, cutout:int)->Tuple[List[Mapping], int]:
# dirs
utils.setup_cuda(seed)
device = torch.device('cuda')
datadir = utils.full_path(datadir)
os.makedirs(datadir, exist_ok=True)
utils.create_logger(filepath=os.path.join(expdir, 'logs.log'))
# log config for reference
logging.info(f'exp_name="{exp_name}", exp_desc="{exp_desc}"')
logging.info(f'model_name="{model_name}", seed={seed}, epochs={epochs}')
logging.info(f'half={half}, cutout={cutout}, train_batch_size={train_batch_size}')
logging.info(f'datadir="{datadir}"')
logging.info(f'expdir="{expdir}"')
model_class = getattr(cifar10_models, model_name)
net = model_class()
logging.info(f'param_size_m={param_size(net):.1e}')
net = net.to(device)
crit = torch.nn.CrossEntropyLoss().to(device)
optim, sched, sched_on_epoch, batch_size = optim_sched(net)
logging.info(f'train_batch_size={train_batch_size}, batch_size={batch_size}')
logging.info(f'sched_on_epoch={sched_on_epoch}, sched={str(sched)}')
# load data just before training starts so any errors so far are not delayed
train_dl, test_dl = cifar10_dataloaders(datadir,
train_batch_size=batch_size, test_batch_size=test_batch_size,
train_num_workers=loader_workers, test_num_workers=loader_workers,
cutout=cutout)
metrics = train(epochs, train_dl, test_dl, net, device, crit, optim,
sched, sched_on_epoch, half, False)
with open(os.path.join(expdir, 'metrics.yaml'), 'w') as f:
yaml.dump(metrics, f)
return metrics, train_batch_size
def train(epochs, train_dl, test_dl, net, device, crit, optim,
sched, sched_on_epoch, half, quiet)->List[Mapping]:
if half:
net.half()
crit.half()
def train(epochs, train_dl, val_dl, net, device, crit, optim,
sched, sched_on_epoch, half, quiet) -> List[Mapping]:
train_acc, test_acc = 0.0, 0.0
metrics = []
for epoch in range(epochs):
lr = optim.param_groups[0]['lr']
train_acc, loss = train_epoch(epoch, net, train_dl, device, crit, optim,
sched, sched_on_epoch, half)
test_acc = test(net, test_dl, device, half)
metrics.append({'test_top1':test_acc, 'train_top1':train_acc, 'lr':lr,
sched, sched_on_epoch, half)
val_acc = test(net, val_dl, device,
half) if val_dl is not None else math.nan
metrics.append({'val_top1': val_acc, 'train_top1': train_acc, 'lr': lr,
'epoch': epoch, 'train_loss': loss})
if not quiet:
logging.info(f'train_epoch={epoch}, test_top1={test_acc},'
logging.info(f'train_epoch={epoch}, val_top1={val_acc},'
f' train_top1={train_acc}, lr={lr:.4g}')
return metrics
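
One detail worth flagging in the new loop: when no validation loader is passed, val_top1 is recorded as math.nan, which PyYAML serializes as .nan in the metrics file. An illustration (not in the commit):

    import math, yaml
    print(yaml.dump([{'val_top1': math.nan, 'epoch': 0}]))
    # - epoch: 0
    #   val_top1: .nan
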
def optim_sched(net):
lr, momentum, weight_decay = 0.1, 0.9, 1.0e-4
optim = torch.optim.SGD(net.parameters(),
lr, momentum=momentum, weight_decay=weight_decay)
logging.info(f'lr={lr}, momentum={momentum}, weight_decay={weight_decay}')
lr, momentum, weight_decay = 0.1, 0.9, 1.0e-4
optim = torch.optim.SGD(net.parameters(),
lr, momentum=momentum, weight_decay=weight_decay)
logging.info(f'lr={lr}, momentum={momentum}, weight_decay={weight_decay}')
sched = torch.optim.lr_scheduler.MultiStepLR(optim,
milestones=[100, 150, 200, 400, 600]) # resnet original paper
sched_on_epoch = True
sched = torch.optim.lr_scheduler.MultiStepLR(optim,
milestones=[100, 150, 200, 400, 600]) # resnet original paper
sched_on_epoch = True
return optim, sched, sched_on_epoch, 128
logging.info(f'sched_on_epoch={sched_on_epoch}, sched={str(sched)}')
return optim, sched, sched_on_epoch
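
The third element of the returned tuple drives where the scheduler is stepped; a sketch of the assumed contract, with per-batch stepping in train_step and per-epoch stepping in the epoch loop (the stand-in model is illustrative):

    import torch
    from torch import nn

    net = nn.Linear(32 * 32 * 3, 10)     # stand-in model for illustration
    optim, sched, sched_on_epoch = optim_sched(net)
    for epoch in range(2):
        # per-batch work happens in train_step, which only steps the
        # scheduler when sched_on_epoch is False
        if sched is not None and sched_on_epoch:
            sched.step()                 # MultiStepLR decays lr 10x at the milestones
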
def cifar10_dataloaders(datadir:str, train_batch_size=128, test_batch_size=4096,
cutout=0, train_num_workers=-1, test_num_workers=-1)\
->Tuple[DataLoader, DataLoader]:
def get_data(datadir: str, train_batch_size=128, test_batch_size=4096,
cutout=0, train_num_workers=-1, test_num_workers=-1,
val_percent=10.0)\
-> Tuple[DataLoader, Optional[DataLoader], DataLoader]:
if utils.is_debugging():
train_num_workers = test_num_workers = 0
logging.info('debugger=true, num_workers=0')
@@ -112,23 +70,40 @@ def cifar10_dataloaders(datadir:str, train_batch_size=128, test_batch_size=4096,
train_transform = cifar10_transform(aug=True, cutout=cutout)
trainset = torchvision.datasets.CIFAR10(root=datadir, train=True,
download=True, transform=train_transform)
download=True, transform=train_transform)
val_len = int(len(trainset) * val_percent / 100.0)
train_len = len(trainset) - val_len
valset = None
if val_len:
trainset, valset = torch.utils.data.random_split(
trainset, [train_len, val_len])
train_dl = torch.utils.data.DataLoader(trainset, batch_size=train_batch_size,
shuffle=True, num_workers=train_num_workers, pin_memory=True)
shuffle=True, num_workers=train_num_workers, pin_memory=True)
if valset is not None:
val_dl = torch.utils.data.DataLoader(valset, batch_size=test_batch_size,
shuffle=False, num_workers=test_num_workers, pin_memory=True)
else:
val_dl = None
test_transform = cifar10_transform(aug=False, cutout=0)
testset = torchvision.datasets.CIFAR10(root=datadir, train=False,
download=True, transform=test_transform)
download=True, transform=test_transform)
test_dl = torch.utils.data.DataLoader(testset, batch_size=test_batch_size,
shuffle=False, num_workers=test_num_workers, pin_memory=True)
shuffle=False, num_workers=test_num_workers, pin_memory=True)
return train_dl, test_dl
logging.info(
f'train_len={train_len}, val_len={val_len}, test_len={len(testset)}')
return train_dl, val_dl, test_dl
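
A worked example of the split arithmetic above for CIFAR-10's 50,000 training images; note also that because random_split runs after the transform is attached, the validation subset sees the augmented train transform too:

    # val_percent=10.0 on CIFAR-10 (illustrative arithmetic):
    val_len = int(50_000 * 10.0 / 100.0)   # 5000
    train_len = 50_000 - val_len           # 45000
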
def train_epoch(epoch, net, train_dl, device, crit, optim,
sched, sched_on_epoch, half)->Tuple[float, float]:
sched, sched_on_epoch, half) -> Tuple[float, float]:
correct, total, loss_total = 0, 0, 0.0
ds = ListDataset(train_dl)
net.train()
for batch_idx, (inputs, targets) in enumerate(train_dl):
inputs = inputs.to(device, non_blocking=True)
@@ -149,9 +124,9 @@ def train_epoch(epoch, net, train_dl, device, crit, optim,
return 100.0*correct/total, loss_total
def train_step(net:torch.nn.Module,
crit:_Loss, optim:Optimizer, sched:_LRScheduler, sched_on_epoch:bool,
inputs:torch.Tensor, targets:torch.Tensor)->Tuple[torch.Tensor, float]:
def train_step(net: nn.Module,
crit: _Loss, optim: Optimizer, sched: _LRScheduler, sched_on_epoch: bool,
inputs: torch.Tensor, targets: torch.Tensor) -> Tuple[torch.Tensor, float]:
outputs = net(inputs)
loss = crit(outputs, targets)
@@ -163,7 +138,8 @@ def train_step(net:torch.nn.Module,
sched.step()
return outputs, loss.item()
def test(net, test_dl, device, half)->float:
def test(net, test_dl, device, half) -> float:
correct, total = 0, 0
net.eval()
with torch.no_grad():
@@ -181,12 +157,13 @@ def test(net, test_dl, device, half)->float:
return 100.0*correct/total
def param_size(model:torch.nn.Module)->int:
def param_size(model: torch.nn.Module) -> int:
"""count all parameters excluding auxiliary"""
return sum(v.numel() for name, v in model.named_parameters() \
if "auxiliary" not in name)
return sum(v.numel() for name, v in model.named_parameters()
if "auxiliary" not in name)
def cifar10_transform(aug:bool, cutout=0):
def cifar10_transform(aug: bool, cutout=0):
MEAN = [0.49139968, 0.48215827, 0.44653124]
STD = [0.24703233, 0.24348505, 0.26158768]
@@ -202,7 +179,7 @@ def cifar10_transform(aug:bool, cutout=0):
]
transf = aug_transf + transf
if cutout > 0: # must be after normalization
if cutout > 0: # must be after normalization
transf += [CutoutDefault(cutout)]
return transforms.Compose(transf)
@@ -212,6 +189,7 @@ class CutoutDefault:
"""
Reference: https://github.com/quark0/darts/blob/master/cnn/utils.py
"""
def __init__(self, length):
self.length = length
@@ -233,29 +211,70 @@ class CutoutDefault:
return img
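
A hedged sketch of where CutoutDefault sits in the pipeline built by cifar10_transform; the 16px hole length is an illustrative choice, not taken from this commit:

    # Must come after Normalize: zeros in the normalized tensor correspond to
    # the per-channel mean of the original image.
    from torchvision import transforms
    MEAN = [0.49139968, 0.48215827, 0.44653124]
    STD = [0.24703233, 0.24348505, 0.26158768]
    transf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD),
        CutoutDefault(16),  # illustrative length
    ])
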
def log_metrics(expdir: str, filename: str, metrics, test_acc: float, args, perf_data:dict) -> None:
print(f'{filename}:', f'test_acc: {test_acc}', metrics[-1])
results = [
('test_acc', test_acc),
('nasbench101_test_acc', perf_data['avg_final_test_accuracy']),
('val_acc', metrics[-1]['val_top1']),
('epochs', args.epochs),
('train_batch_size', args.train_batch_size),
('test_batch_size', args.test_batch_size),
('model_name', args.model_name),
('exp_name', args.experiment_name),
('exp_desc', args.experiment_description),
('seed', args.seed),
('devices', utils.cuda_device_names()),
('half', args.half),
('cutout', args.cutout),
('train_acc', metrics[-1]['train_top1']),
('loader_workers', args.loader_workers),
('date', str(time.time())),
]
utils.append_csv_file(os.path.join(expdir, f'{filename}.tsv'), results)
with open(os.path.join(expdir, f'{filename}_metrics.yaml'), 'w') as f:
yaml.dump(metrics, f)
with open(os.path.join(expdir, f'{filename}_nasbench101.yaml'), 'w') as f:
yaml.dump(perf_data, f)
def create_crit(device, half):
crit = nn.CrossEntropyLoss().to(device)
if half:
crit.half()
return crit
def create_model(nsds, index, device, half) -> nn.Module:
net = nsds.create_model(index)
logging.info(f'param_size_m={param_size(net):.1e}')
net = net.to(device)
if half:
net.half()
return net
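
Since net.half() converts only parameters and buffers, the batches fed to the network must be converted to match; a sketch of the assumed per-batch pattern inside the train/test loops (not shown in this hunk):

    import torch
    half = True
    device = torch.device('cpu')
    inputs = torch.randn(8, 3, 32, 32).to(device, non_blocking=True)
    if half:
        inputs = inputs.half()  # targets stay int64 for CrossEntropyLoss
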
def main():
parser = argparse.ArgumentParser(description='PyTorch CIFAR training')
parser.add_argument('--experiment-name', '-n', default='train_pytorch')
parser.add_argument('--experiment-description', '-d',
default='Train CIFAR using pure PyTorch code')
parser.add_argument('--epochs', '-e', type=int, default=300)
parser.add_argument('--epochs', '-e', type=int, default=1)
parser.add_argument('--model-name', '-m', default='resnet34')
parser.add_argument('--train-batch', '-b', type=int, default=-1)
parser.add_argument('--test-batch', type=int, default=4096)
parser.add_argument('--device', default='',
help='"cuda" or "cpu" or "" in which case use cuda if available')
parser.add_argument('--train-batch-size', '-b', type=int, default=128)
parser.add_argument('--test-batch-size', type=int, default=4096)
parser.add_argument('--seed', '-s', type=float, default=42)
parser.add_argument('--half', type=lambda x:x.lower()=='true',
parser.add_argument('--half', type=lambda x: x.lower() == 'true',
nargs='?', const=True, default=False)
parser.add_argument('--cutout', type=int, default=0)
parser.add_argument('--loader', default='torch', help='torch or dali')
parser.add_argument('--datadir', default='',
help='where to find dataset files, default is ~/torchvision_data_dir')
parser.add_argument('--outdir', default='',
help='where to put results, default is ~/logdir')
parser.add_argument('--loader-workers', type=int, default=-1, help='number of thread/workers for data loader (-1 means auto)')
parser.add_argument('--optim-sched', '-os', default='darts',
help='Optimizer and scheduler provider')
parser.add_argument('--loader-workers', type=int, default=-1,
help='number of thread/workers for data loader (-1 means auto)')
args = parser.parse_args()
@@ -264,41 +283,51 @@ def main():
if not args.outdir:
args.outdir = os.environ.get('PT_OUTPUT_DIR', '')
if not args.outdir:
args.outdir = os.path.join('~/logdir', 'cifar_testbed', args.experiment_name)
args.outdir = os.path.join(
'~/logdir', 'cifar_testbed', args.experiment_name)
expdir = utils.full_path(args.outdir)
os.makedirs(expdir, exist_ok=True)
metrics, train_batch_size = train_test(datadir=args.datadir, expdir=expdir,
exp_name=args.experiment_name,
exp_desc=args.experiment_description,
epochs=args.epochs, model_name=args.model_name,
train_batch_size=args.train_batch, loader_workers=args.loader_workers,
seed=args.seed, half=args.half, test_batch_size=args.test_batch,
cutout=args.cutout)
utils.setup_cuda(args.seed)
datadir = utils.full_path(args.datadir)
os.makedirs(datadir, exist_ok=True)
print(metrics[-1])
utils.create_logger(filepath=os.path.join(expdir, 'logs.log'))
results = [
('test_acc', metrics[-1]['test_top1']),
('epochs', args.epochs),
('train_batch_size', train_batch_size),
('test_batch_size', args.test_batch),
('model_name', args.model_name),
('exp_name', args.experiment_name),
('exp_desc', args.experiment_description),
('seed', args.seed),
('devices', utils.cuda_device_names()),
('half', args.half),
('cutout', args.cutout),
('optim_sched', args.optim_sched),
('train_acc', metrics[-1]['train_top1']),
('loader', args.loader),
('loader_workers', args.loader_workers),
('date', str(time.time())),
]
# log config for reference
logging.info(
f'exp_name="{args.experiment_name}", exp_desc="{args.experiment_description}"')
logging.info(
f'model_name="{args.model_name}", seed={args.seed}, epochs={args.epochs}')
logging.info(f'half={args.half}, cutout={args.cutout}')
logging.info(f'datadir="{datadir}"')
logging.info(f'expdir="{expdir}"')
logging.info(f'train_batch_size={args.train_batch_size}')
utils.append_csv_file(os.path.join(expdir, 'results.tsv'), results)
if args.device:
device = torch.device(args.device)
else:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
nsds = Nasbench101Dataset(
os.path.join(args.datadir, 'nasbench_ds', 'nasbench_only108.tfrecord.pkl'))
# load data just before training starts so any errors so far are not delayed
train_dl, val_dl, test_dl = get_data(datadir=datadir,
train_batch_size=args.train_batch_size, test_batch_size=args.test_batch_size,
train_num_workers=args.loader_workers, test_num_workers=args.loader_workers,
cutout=args.cutout)
for model_id in [4, 400, 4000, 40000, 400000]:
perf_data = nsds[model_id]
net = create_model(nsds, model_id, device, args.half)
crit = create_crit(device, args.half)
optim, sched, sched_on_epoch = optim_sched(net)
train_metrics = train(perf_data['epochs'], train_dl, val_dl, net, device, crit, optim,
sched, sched_on_epoch, args.half, False)
test_acc = test(net, test_dl, device, args.half)
log_metrics(expdir, f'metrics_{model_id}', train_metrics, test_acc, args, perf_data)
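
For reference, a hedged sketch of the perf_data record consumed in this loop; only the keys this commit actually reads are certain, and the comments on their contents are assumptions about the nasbench_only108 build:

    perf_data = nsds[4]
    print(perf_data['module_adjacency'])         # cell DAG adjacency (up to 7x7, upper triangular)
    print(perf_data['module_operations'])        # e.g. ['input', 'conv3x3-bn-relu', ..., 'output']
    print(perf_data['epochs'])                   # 108 for the only108 subset
    print(perf_data['avg_final_test_accuracy'])  # benchmark-reported accuracy to compare against
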
if __name__ == '__main__':
main()
main()

View file

@@ -27,6 +27,9 @@ def main():
best = nsds[len(nsds)-1]
print('best', best)
# create model by index
model = nsds.create_model(42)
print(model)
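
A quick sanity check that could follow here, assuming the built model expects CIFAR-10 shaped input:

    import torch
    out = model(torch.randn(2, 3, 32, 32))
    print(out.shape)  # expected: torch.Size([2, 10])
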
if __name__ == '__main__':
main()

View file

@@ -22,22 +22,24 @@ from archai import cifar10_models
def train(epochs, train_dl, val_dl, net, device, crit, optim,
sched, sched_on_epoch, half, quiet)->List[Mapping]:
sched, sched_on_epoch, half, quiet) -> List[Mapping]:
train_acc, test_acc = 0.0, 0.0
metrics = []
for epoch in range(epochs):
lr = optim.param_groups[0]['lr']
train_acc, loss = train_epoch(epoch, net, train_dl, device, crit, optim,
sched, sched_on_epoch, half)
sched, sched_on_epoch, half)
val_acc = test(net, val_dl, device, half) if val_dl is not None else math.nan
metrics.append({'val_top1':val_acc, 'train_top1':train_acc, 'lr':lr,
val_acc = test(net, val_dl, device,
half) if val_dl is not None else math.nan
metrics.append({'val_top1': val_acc, 'train_top1': train_acc, 'lr': lr,
'epoch': epoch, 'train_loss': loss})
if not quiet:
logging.info(f'train_epoch={epoch}, val_top1={val_acc},'
f' train_top1={train_acc}, lr={lr:.4g}')
return metrics
def optim_sched(net):
lr, momentum, weight_decay = 0.1, 0.9, 1.0e-4
optim = torch.optim.SGD(net.parameters(),
@@ -45,7 +47,7 @@ def optim_sched(net):
logging.info(f'lr={lr}, momentum={momentum}, weight_decay={weight_decay}')
sched = torch.optim.lr_scheduler.MultiStepLR(optim,
milestones=[100, 150, 200, 400, 600]) # resnet original paper
milestones=[100, 150, 200, 400, 600]) # resnet original paper
sched_on_epoch = True
logging.info(f'sched_on_epoch={sched_on_epoch}, sched={str(sched)}')
@@ -53,10 +55,10 @@ def optim_sched(net):
return optim, sched, sched_on_epoch
def get_data(datadir:str, train_batch_size=128, test_batch_size=4096,
cutout=0, train_num_workers=-1, test_num_workers=-1,
val_percent=10.0)\
->Tuple[DataLoader, Optional[DataLoader], DataLoader]:
def get_data(datadir: str, train_batch_size=128, test_batch_size=4096,
cutout=0, train_num_workers=-1, test_num_workers=-1,
val_percent=10.0)\
-> Tuple[DataLoader, Optional[DataLoader], DataLoader]:
if utils.is_debugging():
train_num_workers = test_num_workers = 0
logging.info('debugger=true, num_workers=0')
@@ -67,37 +69,39 @@ def get_data(datadir:str, train_batch_size=128, test_batch_size=4096,
train_transform = cifar10_transform(aug=True, cutout=cutout)
trainset = torchvision.datasets.CIFAR10(root=datadir, train=True,
download=True, transform=train_transform)
download=True, transform=train_transform)
val_len = int(len(trainset) * val_percent / 100.0)
train_len = len(trainset) - val_len
valset = None
if val_len:
trainset, valset = torch.utils.data.random_split(trainset, [train_len, val_len])
trainset, valset = torch.utils.data.random_split(
trainset, [train_len, val_len])
train_dl = torch.utils.data.DataLoader(trainset, batch_size=train_batch_size,
shuffle=True, num_workers=train_num_workers, pin_memory=True)
shuffle=True, num_workers=train_num_workers, pin_memory=True)
if valset is not None:
val_dl = torch.utils.data.DataLoader(valset, batch_size=test_batch_size,
shuffle=False, num_workers=test_num_workers, pin_memory=True)
shuffle=False, num_workers=test_num_workers, pin_memory=True)
else:
val_dl = None
test_transform = cifar10_transform(aug=False, cutout=0)
testset = torchvision.datasets.CIFAR10(root=datadir, train=False,
download=True, transform=test_transform)
download=True, transform=test_transform)
test_dl = torch.utils.data.DataLoader(testset, batch_size=test_batch_size,
shuffle=False, num_workers=test_num_workers, pin_memory=True)
shuffle=False, num_workers=test_num_workers, pin_memory=True)
logging.info(f'train_len={train_len}, val_len={val_len}, test_len={len(testset)}')
logging.info(
f'train_len={train_len}, val_len={val_len}, test_len={len(testset)}')
return train_dl, val_dl, test_dl
def train_epoch(epoch, net, train_dl, device, crit, optim,
sched, sched_on_epoch, half)->Tuple[float, float]:
sched, sched_on_epoch, half) -> Tuple[float, float]:
correct, total, loss_total = 0, 0, 0.0
net.train()
for batch_idx, (inputs, targets) in enumerate(train_dl):
@@ -119,9 +123,9 @@ def train_epoch(epoch, net, train_dl, device, crit, optim,
return 100.0*correct/total, loss_total
def train_step(net:nn.Module,
crit:_Loss, optim:Optimizer, sched:_LRScheduler, sched_on_epoch:bool,
inputs:torch.Tensor, targets:torch.Tensor)->Tuple[torch.Tensor, float]:
def train_step(net: nn.Module,
crit: _Loss, optim: Optimizer, sched: _LRScheduler, sched_on_epoch: bool,
inputs: torch.Tensor, targets: torch.Tensor) -> Tuple[torch.Tensor, float]:
outputs = net(inputs)
loss = crit(outputs, targets)
@@ -133,7 +137,8 @@ def train_step(net:nn.Module,
sched.step()
return outputs, loss.item()
def test(net, test_dl, device, half)->float:
def test(net, test_dl, device, half) -> float:
correct, total = 0, 0
net.eval()
with torch.no_grad():
@@ -151,12 +156,13 @@ def test(net, test_dl, device, half)->float:
return 100.0*correct/total
def param_size(model:torch.nn.Module)->int:
def param_size(model: torch.nn.Module) -> int:
"""count all parameters excluding auxiliary"""
return sum(v.numel() for name, v in model.named_parameters() \
if "auxiliary" not in name)
return sum(v.numel() for name, v in model.named_parameters()
if "auxiliary" not in name)
def cifar10_transform(aug:bool, cutout=0):
def cifar10_transform(aug: bool, cutout=0):
MEAN = [0.49139968, 0.48215827, 0.44653124]
STD = [0.24703233, 0.24348505, 0.26158768]
@@ -172,7 +178,7 @@ def cifar10_transform(aug:bool, cutout=0):
]
transf = aug_transf + transf
if cutout > 0: # must be after normalization
if cutout > 0: # must be after normalization
transf += [CutoutDefault(cutout)]
return transforms.Compose(transf)
@@ -182,6 +188,7 @@ class CutoutDefault:
"""
Reference: https://github.com/quark0/darts/blob/master/cnn/utils.py
"""
def __init__(self, length):
self.length = length
@@ -202,7 +209,8 @@ class CutoutDefault:
img *= mask
return img
def log_metrics(expdir:str, filename:str, metrics, test_acc:float, args)->None:
def log_metrics(expdir: str, filename: str, metrics, test_acc: float, args) -> None:
print(f'{filename}:', f'test_acc: {test_acc}', metrics[-1])
results = [
('test_acc', test_acc),
@@ -225,13 +233,15 @@ def log_metrics(expdir:str, filename:str, metrics, test_acc:float, args)->None:
with open(os.path.join(expdir, f'{filename}.yaml'), 'w') as f:
yaml.dump(metrics, f)
def create_crit(device, half):
crit = nn.CrossEntropyLoss().to(device)
if half:
crit.half()
return crit
def create_model(model_name, device, half)->nn.Module:
def create_model(model_name, device, half) -> nn.Module:
model_class = getattr(cifar10_models, model_name)
net = model_class()
logging.info(f'param_size_m={param_size(net):.1e}')
@@ -240,6 +250,7 @@ def create_model(model_name, device, half)->nn.Module:
net.half()
return net
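
For comparison with the NAS-bench variant, a minimal sketch of this name-based factory; 'resnet34' is the script's default and is assumed to exist in archai.cifar10_models:

    import torch
    net = create_model('resnet34', torch.device('cpu'), half=False)
    print(f'{param_size(net)/1e6:.1f}M parameters')  # auxiliary heads excluded
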
def main():
parser = argparse.ArgumentParser(description='PyTorch CIFAR training')
parser.add_argument('--experiment-name', '-n', default='train_pytorch')
@@ -247,11 +258,12 @@ def main():
default='Train CIFAR using pure PyTorch code')
parser.add_argument('--epochs', '-e', type=int, default=1)
parser.add_argument('--model-name', '-m', default='resnet34')
parser.add_argument('--device', default='', help='"cuda" or "cpu" or "" in which case use cuda if available')
parser.add_argument('--device', default='',
help='"cuda" or "cpu" or "" in which case use cuda if available')
parser.add_argument('--train-batch-size', '-b', type=int, default=128)
parser.add_argument('--test-batch-size', type=int, default=4096)
parser.add_argument('--seed', '-s', type=float, default=42)
parser.add_argument('--half', type=lambda x:x.lower()=='true',
parser.add_argument('--half', type=lambda x: x.lower() == 'true',
nargs='?', const=True, default=False)
parser.add_argument('--cutout', type=int, default=0)
@@ -260,7 +272,8 @@ def main():
parser.add_argument('--outdir', default='',
help='where to put results, default is ~/logdir')
parser.add_argument('--loader-workers', type=int, default=-1, help='number of thread/workers for data loader (-1 means auto)')
parser.add_argument('--loader-workers', type=int, default=-1,
help='number of thread/workers for data loader (-1 means auto)')
args = parser.parse_args()
@@ -269,7 +282,8 @@ def main():
if not args.outdir:
args.outdir = os.environ.get('PT_OUTPUT_DIR', '')
if not args.outdir:
args.outdir = os.path.join('~/logdir', 'cifar_testbed', args.experiment_name)
args.outdir = os.path.join(
'~/logdir', 'cifar_testbed', args.experiment_name)
expdir = utils.full_path(args.outdir)
os.makedirs(expdir, exist_ok=True)
@@ -281,8 +295,10 @@ def main():
utils.create_logger(filepath=os.path.join(expdir, 'logs.log'))
# log config for reference
logging.info(f'exp_name="{args.experiment_name}", exp_desc="{args.experiment_description}"')
logging.info(f'model_name="{args.model_name}", seed={args.seed}, epochs={args.epochs}')
logging.info(
f'exp_name="{args.experiment_name}", exp_desc="{args.experiment_description}"')
logging.info(
f'model_name="{args.model_name}", seed={args.seed}, epochs={args.epochs}')
logging.info(f'half={args.half}, cutout={args.cutout}')
logging.info(f'datadir="{datadir}"')
logging.info(f'expdir="{expdir}"')
@@ -299,14 +315,15 @@ def main():
# load data just before training starts so any errors so far are not delayed
train_dl, val_dl, test_dl = get_data(datadir=datadir,
train_batch_size=args.train_batch_size, test_batch_size=args.test_batch_size,
train_num_workers=args.loader_workers, test_num_workers=args.loader_workers,
cutout=args.cutout)
train_batch_size=args.train_batch_size, test_batch_size=args.test_batch_size,
train_num_workers=args.loader_workers, test_num_workers=args.loader_workers,
cutout=args.cutout)
train_metrics = train(args.epochs, train_dl, val_dl, net, device, crit, optim,
sched, sched_on_epoch, args.half, False)
sched, sched_on_epoch, args.half, False)
test_acc = test(net, test_dl, device, args.half)
log_metrics(expdir, 'train_metrics', train_metrics, test_acc, args)
if __name__ == '__main__':
main()
main()