sync plain pytorch script, add grad clip, cosine sched

This commit is contained in:
Shital Shah 2021-01-16 01:27:18 -08:00 committed by Gustavo Rosa
Parent 7396e4f56c
Commit 6f4f3c5342
3 changed files with 59 additions and 22 deletions

View file

@@ -28,7 +28,7 @@ nas:
   aux_weight: 0.0
   drop_path_prob: 0.0 # probability that a given edge will be dropped
   grad_clip: 5.0 # grads above this value are clipped
-  epochs: 100
+  epochs: 108
   optimizer:
     type: 'sgd'
     lr: 0.025 # init learning rate
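Here grad_clip is the maximum gradient norm handed to clip_grad_norm_ in the training step (see the hunks below), and the epoch budget moves from 100 to 108, the full training budget used for NASBench-101 models. A minimal standalone sketch of reading such a fragment; the exact nesting of this block inside the full config file is not shown, so the structure below is an assumption, and the real archai config system additionally resolves includes and command-line overrides:

    import yaml  # pip install pyyaml

    # Assumed, simplified stand-in for the config fragment above.
    conf_text = """
    nas:
      grad_clip: 5.0   # max L2 norm of the gradients
      epochs: 108
      optimizer:
        type: 'sgd'
        lr: 0.025
    """
    conf = yaml.safe_load(conf_text)
    print(conf['nas']['grad_clip'], conf['nas']['epochs'])  # 5.0 108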

View file

@@ -23,13 +23,13 @@ from archai.algos.nasbench101.nasbench101_dataset import Nasbench101Dataset
 
 def train(epochs, train_dl, val_dal, net, device, crit, optim,
-          sched, sched_on_epoch, half, quiet) -> List[Mapping]:
+          sched, sched_on_epoch, half, quiet, grad_clip: float) -> List[Mapping]:
     train_acc, test_acc = 0.0, 0.0
     metrics = []
     for epoch in range(epochs):
         lr = optim.param_groups[0]['lr']
         train_acc, loss = train_epoch(epoch, net, train_dl, device, crit, optim,
-                                      sched, sched_on_epoch, half)
+                                      sched, sched_on_epoch, half, grad_clip)
         val_acc = test(net, val_dal, device,
                        half) if val_dal is not None else math.nan
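The new grad_clip argument is threaded from train() down through train_epoch() to train_step(). Alongside it travels sched_on_epoch, which controls where the scheduler advances; a runnable sketch of the two placements, consistent with the `if sched and not sched_on_epoch:` guard visible in the train_step hunk further down (the per-epoch call site inside train() falls outside this excerpt, so its exact position is an assumption):

    import torch
    import torch.nn as nn

    net = nn.Linear(4, 2)
    optim = torch.optim.SGD(net.parameters(), lr=0.025)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=5)
    sched_on_epoch = True                      # what optim_sched_cosine returns

    for epoch in range(5):
        for _ in range(3):                     # stand-in for iterating train_dl
            loss = nn.functional.mse_loss(net(torch.randn(8, 4)), torch.zeros(8, 2))
            optim.zero_grad()
            loss.backward()
            optim.step()
            if sched and not sched_on_epoch:
                sched.step()                   # per-batch decay
        if sched and sched_on_epoch:
            sched.step()                       # per-epoch decay
        print(epoch, optim.param_groups[0]['lr'])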
@@ -41,7 +41,7 @@ def train(epochs, train_dl, val_dal, net, device, crit, optim,
     return metrics
 
 
-def optim_sched(net):
+def optim_sched_orig(net, epochs):
     lr, momentum, weight_decay = 0.1, 0.9, 1.0e-4
     optim = torch.optim.SGD(net.parameters(),
                             lr, momentum=momentum, weight_decay=weight_decay)
@@ -55,6 +55,19 @@ def optim_sched(net):
     return optim, sched, sched_on_epoch
 
 
+def optim_sched_cosine(net, epochs):
+    lr, momentum, weight_decay = 0.025, 0.9, 1.0e-4
+    optim = torch.optim.SGD(net.parameters(),
+                            lr, momentum=momentum, weight_decay=weight_decay)
+    logging.info(f'lr={lr}, momentum={momentum}, weight_decay={weight_decay}')
+
+    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, epochs)
+    sched_on_epoch = True
+    logging.info(f'sched_on_epoch={sched_on_epoch}, sched={str(sched)}')
+
+    return optim, sched, sched_on_epoch
+
+
 def get_data(datadir: str, train_batch_size=128, test_batch_size=4096,
              cutout=0, train_num_workers=-1, test_num_workers=-1,
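optim_sched_cosine pairs SGD at lr 0.025 (matching the config above) with CosineAnnealingLR over the full epoch budget. With PyTorch's default eta_min of 0, the schedule follows lr_t = 0.5 * lr * (1 + cos(pi * t / epochs)), decaying smoothly to zero at the end of training. A small sketch tracing the decay against that closed form:

    import math
    import torch

    lr, epochs = 0.025, 108
    net = torch.nn.Linear(4, 2)
    optim = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, epochs)

    for t in range(epochs):
        if t % 27 == 0:
            # closed form with the default eta_min = 0:
            #   lr_t = 0.5 * lr * (1 + cos(pi * t / epochs))
            print(t, optim.param_groups[0]['lr'],
                  0.5 * lr * (1 + math.cos(math.pi * t / epochs)))
        optim.step()   # optimizer steps before the scheduler (PyTorch >= 1.1 order)
        sched.step()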
@@ -102,7 +115,7 @@ def get_data(datadir: str, train_batch_size=128, test_batch_size=4096,
 
 def train_epoch(epoch, net, train_dl, device, crit, optim,
-                sched, sched_on_epoch, half) -> Tuple[float, float]:
+                sched, sched_on_epoch, half, grad_clip: float) -> Tuple[float, float]:
     correct, total, loss_total = 0, 0, 0.0
     net.train()
     for batch_idx, (inputs, targets) in enumerate(train_dl):
@@ -113,7 +126,7 @@ def train_epoch(epoch, net, train_dl, device, crit, optim,
             inputs = inputs.half()
 
         outputs, loss = train_step(net, crit, optim, sched, sched_on_epoch,
-                                   inputs, targets)
+                                   inputs, targets, grad_clip)
         loss_total += loss
 
         _, predicted = outputs.max(1)
@@ -126,12 +139,13 @@ def train_epoch(epoch, net, train_dl, device, crit, optim,
 
 def train_step(net: nn.Module,
                crit: _Loss, optim: Optimizer, sched: _LRScheduler, sched_on_epoch: bool,
-               inputs: torch.Tensor, targets: torch.Tensor) -> Tuple[torch.Tensor, float]:
+               inputs: torch.Tensor, targets: torch.Tensor, grad_clip: float) -> Tuple[torch.Tensor, float]:
     outputs = net(inputs)
 
     loss = crit(outputs, targets)
     optim.zero_grad()
     loss.backward()
+    nn.utils.clip_grad_norm_(net.parameters(), grad_clip)
     optim.step()
 
     if sched and not sched_on_epoch:
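The one functional addition inside train_step is clip_grad_norm_, placed after backward() and before step() so the optimizer applies the rescaled gradients. It rescales all parameter gradients together so their global L2 norm is at most grad_clip, and returns the norm computed before clipping, which is handy for logging. A self-contained sketch of the same ordering:

    import torch
    import torch.nn as nn

    net = nn.Linear(4, 2)
    crit = nn.CrossEntropyLoss()
    optim = torch.optim.SGD(net.parameters(), lr=0.025)

    inputs = torch.randn(8, 4)
    targets = torch.randint(0, 2, (8,))

    loss = crit(net(inputs), targets)
    optim.zero_grad()
    loss.backward()
    # Rescale gradients so their global L2 norm is at most 5.0;
    # the return value is the pre-clip total norm.
    total_norm = nn.utils.clip_grad_norm_(net.parameters(), 5.0)
    optim.step()
    print(f'grad norm before clipping: {total_norm:.4f}')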
@@ -267,6 +281,7 @@ def main():
     parser.add_argument('--half', type=lambda x: x.lower() == 'true',
                         nargs='?', const=True, default=False)
     parser.add_argument('--cutout', type=int, default=0)
+    parser.add_argument('--grad-clip', type=float, default=5.0)
 
     parser.add_argument('--datadir', default='',
                         help='where to find dataset files, default is ~/torchvision_data_dir')
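argparse converts the dash in --grad-clip to an underscore, which is why main() below reads args.grad_clip. A tiny standalone check of that mapping:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--grad-clip', type=float, default=5.0)

    # '--grad-clip' on the command line becomes the attribute 'grad_clip'
    args = parser.parse_args(['--grad-clip', '3.0'])
    print(args.grad_clip)  # 3.0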
@@ -321,11 +336,14 @@ def main():
     for model_id in [4, 400, 4000, 40000, 400000]:
         perf_data = nsds[model_id]
+        epochs = perf_data['epochs']
+
         net = create_model(nsds, model_id, device, args.half)
         crit = create_crit(device, args.half)
-        optim, sched, sched_on_epoch = optim_sched(net)
-        train_metrics = train(perf_data['epochs'], train_dl, val_dl, net, device, crit, optim,
-                              sched, sched_on_epoch, args.half, False)
+        optim, sched, sched_on_epoch = optim_sched_cosine(net, epochs)
+        train_metrics = train(epochs, train_dl, val_dl, net, device, crit, optim,
+                              sched, sched_on_epoch, args.half, False, grad_clip=args.grad_clip)
         test_acc = test(net, test_dl, device, args.half)
         log_metrics(expdir, f'metrics_{model_id}', train_metrics, test_acc, args, perf_data)
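This loop retrains five model ids spread across NASBench-101 and logs the reproduced metrics next to the benchmark's recorded perf_data. A hypothetical post-processing step comparing the two; only the 'epochs' key of perf_data is visible in the hunk above, so the accuracy key name here is an assumption:

    from typing import Mapping

    def compare(model_id: int, reproduced: float, perf_data: Mapping) -> None:
        # 'final_test_accuracy' is an assumed key name, not confirmed by the diff.
        recorded = perf_data.get('final_test_accuracy')
        if recorded is not None:
            print(f'model {model_id}: reproduced={reproduced:.4f}, '
                  f'recorded={recorded:.4f}, gap={reproduced - recorded:+.4f}')

    compare(4, 0.91, {'epochs': 108, 'final_test_accuracy': 0.92})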

View file

@@ -22,13 +22,13 @@ from archai import cifar10_models
 
 def train(epochs, train_dl, val_dal, net, device, crit, optim,
-          sched, sched_on_epoch, half, quiet) -> List[Mapping]:
+          sched, sched_on_epoch, half, quiet, grad_clip: float) -> List[Mapping]:
     train_acc, test_acc = 0.0, 0.0
     metrics = []
     for epoch in range(epochs):
         lr = optim.param_groups[0]['lr']
         train_acc, loss = train_epoch(epoch, net, train_dl, device, crit, optim,
-                                      sched, sched_on_epoch, half)
+                                      sched, sched_on_epoch, half, grad_clip)
         val_acc = test(net, val_dal, device,
                        half) if val_dal is not None else math.nan
@@ -40,7 +40,7 @@ def train(epochs, train_dl, val_dal, net, device, crit, optim,
     return metrics
 
 
-def optim_sched(net):
+def optim_sched_orig(net, epochs):
     lr, momentum, weight_decay = 0.1, 0.9, 1.0e-4
     optim = torch.optim.SGD(net.parameters(),
                             lr, momentum=momentum, weight_decay=weight_decay)
@@ -54,6 +54,19 @@ def optim_sched(net):
     return optim, sched, sched_on_epoch
 
 
+def optim_sched_cosine(net, epochs):
+    lr, momentum, weight_decay = 0.025, 0.9, 1.0e-4
+    optim = torch.optim.SGD(net.parameters(),
+                            lr, momentum=momentum, weight_decay=weight_decay)
+    logging.info(f'lr={lr}, momentum={momentum}, weight_decay={weight_decay}')
+
+    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, epochs)
+    sched_on_epoch = True
+    logging.info(f'sched_on_epoch={sched_on_epoch}, sched={str(sched)}')
+
+    return optim, sched, sched_on_epoch
+
+
 def get_data(datadir: str, train_batch_size=128, test_batch_size=4096,
              cutout=0, train_num_workers=-1, test_num_workers=-1,
@@ -101,7 +114,7 @@ def get_data(datadir: str, train_batch_size=128, test_batch_size=4096,
 
 def train_epoch(epoch, net, train_dl, device, crit, optim,
-                sched, sched_on_epoch, half) -> Tuple[float, float]:
+                sched, sched_on_epoch, half, grad_clip: float) -> Tuple[float, float]:
     correct, total, loss_total = 0, 0, 0.0
     net.train()
     for batch_idx, (inputs, targets) in enumerate(train_dl):
@@ -112,7 +125,7 @@ def train_epoch(epoch, net, train_dl, device, crit, optim,
             inputs = inputs.half()
 
         outputs, loss = train_step(net, crit, optim, sched, sched_on_epoch,
-                                   inputs, targets)
+                                   inputs, targets, grad_clip)
         loss_total += loss
 
         _, predicted = outputs.max(1)
@@ -125,12 +138,13 @@ def train_epoch(epoch, net, train_dl, device, crit, optim,
 
 def train_step(net: nn.Module,
                crit: _Loss, optim: Optimizer, sched: _LRScheduler, sched_on_epoch: bool,
-               inputs: torch.Tensor, targets: torch.Tensor) -> Tuple[torch.Tensor, float]:
+               inputs: torch.Tensor, targets: torch.Tensor, grad_clip: float) -> Tuple[torch.Tensor, float]:
     outputs = net(inputs)
 
     loss = crit(outputs, targets)
     optim.zero_grad()
     loss.backward()
+    nn.utils.clip_grad_norm_(net.parameters(), grad_clip)
     optim.step()
 
     if sched and not sched_on_epoch:
@@ -266,6 +280,7 @@ def main():
     parser.add_argument('--half', type=lambda x: x.lower() == 'true',
                         nargs='?', const=True, default=False)
     parser.add_argument('--cutout', type=int, default=0)
+    parser.add_argument('--grad-clip', type=float, default=5.0)
 
     parser.add_argument('--datadir', default='',
                         help='where to find dataset files, default is ~/torchvision_data_dir')
@@ -309,18 +324,22 @@ def main():
     else:
         device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
-    net = create_model(args.model_name, device, args.half)
-    crit = create_crit(device, args.half)
-    optim, sched, sched_on_epoch = optim_sched(net)
 
     # load data just before train start so any errors so far are not delayed
     train_dl, val_dl, test_dl = get_data(datadir=datadir,
                                          train_batch_size=args.train_batch_size, test_batch_size=args.test_batch_size,
                                          train_num_workers=args.loader_workers, test_num_workers=args.loader_workers,
                                          cutout=args.cutout)
-    train_metrics = train(args.epochs, train_dl, val_dl, net, device, crit, optim,
-                          sched, sched_on_epoch, args.half, False)
+    epochs = args.epochs
+    net = create_model(args.model_name, device, args.half)
+    crit = create_crit(device, args.half)
+    optim, sched, sched_on_epoch = optim_sched_orig(net, epochs)
+
+    train_metrics = train(epochs, train_dl, val_dl, net, device, crit, optim,
+                          sched, sched_on_epoch, args.half, False, grad_clip=args.grad_clip)
 
     test_acc = test(net, test_dl, device, args.half)
     log_metrics(expdir, 'train_metrics', train_metrics, test_acc, args)
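Both scripts also accept --half, which converts the model, the criterion, and each batch to fp16 (the `inputs = inputs.half()` line in the train_epoch hunk above); with it enabled, the gradients being clipped are fp16 as well. A standalone sketch of that pattern, guarded for CPU where fp16 matmul support is limited:

    import torch
    import torch.nn as nn

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    net = nn.Linear(4, 2).to(device)
    inputs = torch.randn(8, 4, device=device)
    if device.type == 'cuda':
        net = net.half()          # fp16 weights, as with --half true
        inputs = inputs.half()    # mirrors 'inputs = inputs.half()' above

    outputs = net(inputs)
    print(outputs.dtype)          # torch.float16 on CUDA, torch.float32 on CPU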