Mirror of https://github.com/microsoft/archai.git
sync plain pytorch script, add grad clip, cosine sched
This commit is contained in:
Parent
7396e4f56c
Commit
6f4f3c5342
@@ -28,7 +28,7 @@ nas:
   aux_weight: 0.0
   drop_path_prob: 0.0 # probability that given edge will be dropped
   grad_clip: 5.0 # grads above this value is clipped
-  epochs: 100
+  epochs: 108
   optimizer:
     type: 'sgd'
     lr: 0.025 # init learning rate

@@ -23,13 +23,13 @@ from archai.algos.nasbench101.nasbench101_dataset import Nasbench101Dataset
 
 
 def train(epochs, train_dl, val_dal, net, device, crit, optim,
-          sched, sched_on_epoch, half, quiet) -> List[Mapping]:
+          sched, sched_on_epoch, half, quiet, grad_clip:float) -> List[Mapping]:
     train_acc, test_acc = 0.0, 0.0
     metrics = []
     for epoch in range(epochs):
         lr = optim.param_groups[0]['lr']
         train_acc, loss = train_epoch(epoch, net, train_dl, device, crit, optim,
-                                      sched, sched_on_epoch, half)
+                                      sched, sched_on_epoch, half, grad_clip)
 
         val_acc = test(net, val_dal, device,
                        half) if val_dal is not None else math.nan
@@ -41,7 +41,7 @@ def train(epochs, train_dl, val_dal, net, device, crit, optim,
     return metrics
 
 
-def optim_sched(net):
+def optim_sched_orig(net, epochs):
     lr, momentum, weight_decay = 0.1, 0.9, 1.0e-4
     optim = torch.optim.SGD(net.parameters(),
                             lr, momentum=momentum, weight_decay=weight_decay)
@@ -55,6 +55,19 @@ def optim_sched(net):
 
     return optim, sched, sched_on_epoch
 
+def optim_sched_cosine(net, epochs):
+    lr, momentum, weight_decay = 0.025, 0.9, 1.0e-4
+    optim = torch.optim.SGD(net.parameters(),
+                            lr, momentum=momentum, weight_decay=weight_decay)
+    logging.info(f'lr={lr}, momentum={momentum}, weight_decay={weight_decay}')
+
+    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, epochs)
+    sched_on_epoch = True
+
+    logging.info(f'sched_on_epoch={sched_on_epoch}, sched={str(sched)}')
+
+    return optim, sched, sched_on_epoch
+
 
 def get_data(datadir: str, train_batch_size=128, test_batch_size=4096,
              cutout=0, train_num_workers=-1, test_num_workers=-1,
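
For reference, a minimal self-contained sketch (not part of this commit; the toy Linear model and the 10-epoch horizon are placeholders) of how CosineAnnealingLR behaves when stepped once per epoch, as optim_sched_cosine does with sched_on_epoch=True: the learning rate decays from its initial value toward eta_min (0 by default) along half a cosine period over T_max epochs.

    import torch

    # toy parameter group; only the optimizer/scheduler interplay matters here
    net = torch.nn.Linear(10, 10)
    optim = torch.optim.SGD(net.parameters(), lr=0.025,
                            momentum=0.9, weight_decay=1.0e-4)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=10)

    for epoch in range(10):
        # ... one epoch of training would run here ...
        optim.step()   # optimizer first, scheduler second (PyTorch >= 1.1 ordering)
        sched.step()   # lr follows a half-cosine from 0.025 down toward 0
        print(epoch, optim.param_groups[0]['lr'])
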
@@ -102,7 +115,7 @@ def get_data(datadir: str, train_batch_size=128, test_batch_size=4096,
 
 
 def train_epoch(epoch, net, train_dl, device, crit, optim,
-                sched, sched_on_epoch, half) -> Tuple[float, float]:
+                sched, sched_on_epoch, half, grad_clip:float) -> Tuple[float, float]:
     correct, total, loss_total = 0, 0, 0.0
     net.train()
     for batch_idx, (inputs, targets) in enumerate(train_dl):
@@ -113,7 +126,7 @@ def train_epoch(epoch, net, train_dl, device, crit, optim,
             inputs = inputs.half()
 
         outputs, loss = train_step(net, crit, optim, sched, sched_on_epoch,
-                                   inputs, targets)
+                                   inputs, targets, grad_clip)
         loss_total += loss
 
         _, predicted = outputs.max(1)
@@ -126,12 +139,13 @@ def train_epoch(epoch, net, train_dl, device, crit, optim,
 
 def train_step(net: nn.Module,
                crit: _Loss, optim: Optimizer, sched: _LRScheduler, sched_on_epoch: bool,
-               inputs: torch.Tensor, targets: torch.Tensor) -> Tuple[torch.Tensor, float]:
+               inputs: torch.Tensor, targets: torch.Tensor, grad_clip:float) -> Tuple[torch.Tensor, float]:
     outputs = net(inputs)
 
     loss = crit(outputs, targets)
     optim.zero_grad()
     loss.backward()
+    nn.utils.clip_grad_norm_(net.parameters(), grad_clip)
 
     optim.step()
     if sched and not sched_on_epoch:
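
For reference, a minimal self-contained sketch (illustrative only, not from the repo; clipped_step and its arguments are hypothetical names) of the clipping pattern added above: nn.utils.clip_grad_norm_ rescales all gradients in place so that their global L2 norm does not exceed grad_clip, and it has to run after loss.backward() and before optim.step().

    import torch
    import torch.nn as nn

    def clipped_step(net: nn.Module, crit, optim,
                     inputs: torch.Tensor, targets: torch.Tensor,
                     grad_clip: float = 5.0):
        # forward pass and loss
        outputs = net(inputs)
        loss = crit(outputs, targets)

        # backward, clip the global grad norm, then update the weights
        optim.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(net.parameters(), grad_clip)
        optim.step()
        return outputs, loss.item()
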
@@ -267,6 +281,7 @@ def main():
     parser.add_argument('--half', type=lambda x: x.lower() == 'true',
                         nargs='?', const=True, default=False)
     parser.add_argument('--cutout', type=int, default=0)
+    parser.add_argument('--grad-clip', type=float, default=5.0)
 
     parser.add_argument('--datadir', default='',
                         help='where to find dataset files, default is ~/torchvision_data_dir')
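
Aside on the argument parsing (a small illustrative sketch, not from the repo): --half uses a lambda type with nargs='?' and const=True, so the bare flag enables half precision while an explicit 'true'/'false' value is also accepted; the new --grad-clip is a plain float that argparse exposes as args.grad_clip.

    import argparse

    parser = argparse.ArgumentParser()
    # bare `--half` yields True via const; `--half false` goes through the lambda
    parser.add_argument('--half', type=lambda x: x.lower() == 'true',
                        nargs='?', const=True, default=False)
    parser.add_argument('--grad-clip', type=float, default=5.0)

    args = parser.parse_args(['--half', '--grad-clip', '1.0'])
    print(args.half, args.grad_clip)  # True 1.0
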
@@ -321,11 +336,14 @@ def main():
 
     for model_id in [4, 400, 4000, 40000, 400000]:
         perf_data = nsds[model_id]
+        epochs = perf_data['epochs']
+
         net = create_model(nsds, model_id, device, args.half)
         crit = create_crit(device, args.half)
-        optim, sched, sched_on_epoch = optim_sched(net)
-        train_metrics = train(perf_data['epochs'], train_dl, val_dl, net, device, crit, optim,
-                              sched, sched_on_epoch, args.half, False)
+        optim, sched, sched_on_epoch = optim_sched_cosine(net, epochs)
+
+        train_metrics = train(epochs, train_dl, val_dl, net, device, crit, optim,
+                              sched, sched_on_epoch, args.half, False, grad_clip=args.grad_clip)
         test_acc = test(net, test_dl, device, args.half)
         log_metrics(expdir, f'metrics_{model_id}', train_metrics, test_acc, args, perf_data)
 

@@ -22,13 +22,13 @@ from archai import cifar10_models
 
 
 def train(epochs, train_dl, val_dal, net, device, crit, optim,
-          sched, sched_on_epoch, half, quiet) -> List[Mapping]:
+          sched, sched_on_epoch, half, quiet, grad_clip:float) -> List[Mapping]:
     train_acc, test_acc = 0.0, 0.0
     metrics = []
     for epoch in range(epochs):
         lr = optim.param_groups[0]['lr']
         train_acc, loss = train_epoch(epoch, net, train_dl, device, crit, optim,
-                                      sched, sched_on_epoch, half)
+                                      sched, sched_on_epoch, half, grad_clip)
 
         val_acc = test(net, val_dal, device,
                        half) if val_dal is not None else math.nan
@@ -40,7 +40,7 @@ def train(epochs, train_dl, val_dal, net, device, crit, optim,
     return metrics
 
 
-def optim_sched(net):
+def optim_sched_orig(net, epochs):
     lr, momentum, weight_decay = 0.1, 0.9, 1.0e-4
     optim = torch.optim.SGD(net.parameters(),
                             lr, momentum=momentum, weight_decay=weight_decay)
@@ -54,6 +54,19 @@ def optim_sched(net):
 
     return optim, sched, sched_on_epoch
 
+def optim_sched_cosine(net, epochs):
+    lr, momentum, weight_decay = 0.025, 0.9, 1.0e-4
+    optim = torch.optim.SGD(net.parameters(),
+                            lr, momentum=momentum, weight_decay=weight_decay)
+    logging.info(f'lr={lr}, momentum={momentum}, weight_decay={weight_decay}')
+
+    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, epochs)
+    sched_on_epoch = True
+
+    logging.info(f'sched_on_epoch={sched_on_epoch}, sched={str(sched)}')
+
+    return optim, sched, sched_on_epoch
+
 
 def get_data(datadir: str, train_batch_size=128, test_batch_size=4096,
              cutout=0, train_num_workers=-1, test_num_workers=-1,
@@ -101,7 +114,7 @@ def get_data(datadir: str, train_batch_size=128, test_batch_size=4096,
 
 
 def train_epoch(epoch, net, train_dl, device, crit, optim,
-                sched, sched_on_epoch, half) -> Tuple[float, float]:
+                sched, sched_on_epoch, half, grad_clip:float) -> Tuple[float, float]:
     correct, total, loss_total = 0, 0, 0.0
     net.train()
     for batch_idx, (inputs, targets) in enumerate(train_dl):
@@ -112,7 +125,7 @@ def train_epoch(epoch, net, train_dl, device, crit, optim,
             inputs = inputs.half()
 
         outputs, loss = train_step(net, crit, optim, sched, sched_on_epoch,
-                                   inputs, targets)
+                                   inputs, targets, grad_clip)
         loss_total += loss
 
         _, predicted = outputs.max(1)
@@ -125,12 +138,13 @@ def train_epoch(epoch, net, train_dl, device, crit, optim,
 
 def train_step(net: nn.Module,
                crit: _Loss, optim: Optimizer, sched: _LRScheduler, sched_on_epoch: bool,
-               inputs: torch.Tensor, targets: torch.Tensor) -> Tuple[torch.Tensor, float]:
+               inputs: torch.Tensor, targets: torch.Tensor, grad_clip:float) -> Tuple[torch.Tensor, float]:
     outputs = net(inputs)
 
     loss = crit(outputs, targets)
     optim.zero_grad()
     loss.backward()
+    nn.utils.clip_grad_norm_(net.parameters(), grad_clip)
 
     optim.step()
     if sched and not sched_on_epoch:
@@ -266,6 +280,7 @@ def main():
     parser.add_argument('--half', type=lambda x: x.lower() == 'true',
                         nargs='?', const=True, default=False)
     parser.add_argument('--cutout', type=int, default=0)
+    parser.add_argument('--grad-clip', type=float, default=5.0)
 
     parser.add_argument('--datadir', default='',
                         help='where to find dataset files, default is ~/torchvision_data_dir')
@@ -309,18 +324,22 @@ def main():
     else:
         device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
-    net = create_model(args.model_name, device, args.half)
-    crit = create_crit(device, args.half)
-    optim, sched, sched_on_epoch = optim_sched(net)
-
     # load data just before train start so any errors so far is not delayed
     train_dl, val_dl, test_dl = get_data(datadir=datadir,
         train_batch_size=args.train_batch_size, test_batch_size=args.test_batch_size,
        train_num_workers=args.loader_workers, test_num_workers=args.loader_workers,
         cutout=args.cutout)
 
-    train_metrics = train(args.epochs, train_dl, val_dl, net, device, crit, optim,
-                          sched, sched_on_epoch, args.half, False)
+    epochs = args.epochs
+
+
+    net = create_model(args.model_name, device, args.half)
+    crit = create_crit(device, args.half)
+    optim, sched, sched_on_epoch = optim_sched_orig(net, epochs)
+
+
+    train_metrics = train(epochs, train_dl, val_dl, net, device, crit, optim,
+                          sched, sched_on_epoch, args.half, False, grad_clip=args.grad_clip)
     test_acc = test(net, test_dl, device, args.half)
     log_metrics(expdir, 'train_metrics', train_metrics, test_acc, args)
 