Mirror of https://github.com/microsoft/archai.git

sync plain pytorch script, add grad clip, cosine sched

This commit is contained in:
Parent: 7396e4f56c
Commit: 6f4f3c5342
@@ -28,7 +28,7 @@ nas:
   aux_weight: 0.0
   drop_path_prob: 0.0 # probability that given edge will be dropped
   grad_clip: 5.0 # grads above this value is clipped
-  epochs: 100
+  epochs: 108
   optimizer:
     type: 'sgd'
     lr: 0.025 # init learning rate
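For context: this config hunk bumps the trainer budget from 100 to 108 epochs, presumably to match the 108-epoch full budget of NASBench-101, which the synced script below trains against. A hedged sketch of how these values could reach training code — the 'darts.yaml' filename and the flat nesting are illustrative assumptions, and archai uses its own Config class rather than raw yaml.safe_load:

    # Illustrative only: load the fragment above and hand grad_clip/epochs
    # to a trainer. Real archai config loading and nesting differ.
    import yaml

    with open('darts.yaml') as f:      # assumed filename
        conf = yaml.safe_load(f)

    trainer = conf['nas']              # section shown in the hunk above
    print(trainer['epochs'])           # 108 after this commit
    print(trainer['grad_clip'])        # 5.0 -> max gradient norm for clipping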
@@ -23,13 +23,13 @@ from archai.algos.nasbench101.nasbench101_dataset import Nasbench101Dataset


 def train(epochs, train_dl, val_dal, net, device, crit, optim,
-          sched, sched_on_epoch, half, quiet) -> List[Mapping]:
+          sched, sched_on_epoch, half, quiet, grad_clip:float) -> List[Mapping]:
     train_acc, test_acc = 0.0, 0.0
     metrics = []
     for epoch in range(epochs):
         lr = optim.param_groups[0]['lr']
         train_acc, loss = train_epoch(epoch, net, train_dl, device, crit, optim,
-                                      sched, sched_on_epoch, half)
+                                      sched, sched_on_epoch, half, grad_clip)

         val_acc = test(net, val_dal, device,
                        half) if val_dal is not None else math.nan
@@ -41,7 +41,7 @@ def train(epochs, train_dl, val_dal, net, device, crit, optim,
     return metrics


-def optim_sched(net):
+def optim_sched_orig(net, epochs):
     lr, momentum, weight_decay = 0.1, 0.9, 1.0e-4
     optim = torch.optim.SGD(net.parameters(),
                             lr, momentum=momentum, weight_decay=weight_decay)
@@ -55,6 +55,19 @@ def optim_sched(net):

     return optim, sched, sched_on_epoch


+def optim_sched_cosine(net, epochs):
+    lr, momentum, weight_decay = 0.025, 0.9, 1.0e-4
+    optim = torch.optim.SGD(net.parameters(),
+                            lr, momentum=momentum, weight_decay=weight_decay)
+    logging.info(f'lr={lr}, momentum={momentum}, weight_decay={weight_decay}')
+
+    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, epochs)
+    sched_on_epoch = True
+
+    logging.info(f'sched_on_epoch={sched_on_epoch}, sched={str(sched)}')
+
+    return optim, sched, sched_on_epoch
+
 def get_data(datadir: str, train_batch_size=128, test_batch_size=4096,
              cutout=0, train_num_workers=-1, test_num_workers=-1,
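The new optim_sched_cosine follows the standard cosine annealing rule lr_t = eta_min + (lr_0 - eta_min) * (1 + cos(pi * t / T_max)) / 2, with eta_min = 0 and T_max = epochs. A standalone sketch (toy stand-in model, not the repo's code) of the LR trajectory it sets up:

    # Minimal sketch of the schedule optim_sched_cosine builds: LR starts at
    # 0.025 and decays toward 0 over T_max scheduler steps.
    import torch

    net = torch.nn.Linear(10, 10)  # stand-in for the real model
    optim = torch.optim.SGD(net.parameters(), lr=0.025,
                            momentum=0.9, weight_decay=1.0e-4)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=10)

    for epoch in range(10):
        print(epoch, optim.param_groups[0]['lr'])  # 0.025 at epoch 0, -> ~0
        sched.step()  # sched_on_epoch=True: one scheduler step per epoch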
@@ -102,7 +115,7 @@ def get_data(datadir: str, train_batch_size=128, test_batch_size=4096,


 def train_epoch(epoch, net, train_dl, device, crit, optim,
-                sched, sched_on_epoch, half) -> Tuple[float, float]:
+                sched, sched_on_epoch, half, grad_clip:float) -> Tuple[float, float]:
     correct, total, loss_total = 0, 0, 0.0
     net.train()
     for batch_idx, (inputs, targets) in enumerate(train_dl):
@@ -113,7 +126,7 @@ def train_epoch(epoch, net, train_dl, device, crit, optim,
             inputs = inputs.half()

         outputs, loss = train_step(net, crit, optim, sched, sched_on_epoch,
-                                   inputs, targets)
+                                   inputs, targets, grad_clip)
         loss_total += loss

         _, predicted = outputs.max(1)
@@ -126,12 +139,13 @@ def train_epoch(epoch, net, train_dl, device, crit, optim,

 def train_step(net: nn.Module,
                crit: _Loss, optim: Optimizer, sched: _LRScheduler, sched_on_epoch: bool,
-               inputs: torch.Tensor, targets: torch.Tensor) -> Tuple[torch.Tensor, float]:
+               inputs: torch.Tensor, targets: torch.Tensor, grad_clip:float) -> Tuple[torch.Tensor, float]:
     outputs = net(inputs)

     loss = crit(outputs, targets)
     optim.zero_grad()
     loss.backward()
+    nn.utils.clip_grad_norm_(net.parameters(), grad_clip)

     optim.step()
     if sched and not sched_on_epoch:
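The inserted clip_grad_norm_ call sits between loss.backward() and optim.step(), which is the required ordering: it rescales the already-computed gradients in place so their global L2 norm is at most grad_clip, and returns the pre-clip norm. A self-contained sketch of the semantics:

    # Minimal sketch: inflate a loss so gradients are large, then clip them.
    import torch
    import torch.nn as nn

    net = nn.Linear(4, 2)                              # stand-in model
    loss = net(torch.randn(8, 4)).pow(2).mean() * 1e3  # big loss -> big grads
    loss.backward()

    total_norm = nn.utils.clip_grad_norm_(net.parameters(), max_norm=5.0)
    print(float(total_norm))  # pre-clip global L2 norm (likely >> 5.0)
    # gradients are now rescaled in place so their global norm is <= 5.0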
@@ -267,6 +281,7 @@ def main():
     parser.add_argument('--half', type=lambda x: x.lower() == 'true',
                         nargs='?', const=True, default=False)
     parser.add_argument('--cutout', type=int, default=0)
+    parser.add_argument('--grad-clip', type=float, default=5.0)

     parser.add_argument('--datadir', default='',
                         help='where to find dataset files, default is ~/torchvision_data_dir')
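Note that argparse maps the '--grad-clip' flag to the attribute args.grad_clip (hyphens become underscores in the default dest), which is how the training call below reads it:

    # Sketch of the flag-to-attribute mapping used by main().
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--grad-clip', type=float, default=5.0)
    args = parser.parse_args(['--grad-clip', '3.0'])
    print(args.grad_clip)  # 3.0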
@@ -321,11 +336,14 @@ def main():

     for model_id in [4, 400, 4000, 40000, 400000]:
         perf_data = nsds[model_id]
+        epochs = perf_data['epochs']
+

         net = create_model(nsds, model_id, device, args.half)
         crit = create_crit(device, args.half)
-        optim, sched, sched_on_epoch = optim_sched(net)
-        train_metrics = train(perf_data['epochs'], train_dl, val_dl, net, device, crit, optim,
-                              sched, sched_on_epoch, args.half, False)
+        optim, sched, sched_on_epoch = optim_sched_cosine(net, epochs)
+        train_metrics = train(epochs, train_dl, val_dl, net, device, crit, optim,
+                              sched, sched_on_epoch, args.half, False, grad_clip=args.grad_clip)
         test_acc = test(net, test_dl, device, args.half)
         log_metrics(expdir, f'metrics_{model_id}', train_metrics, test_acc, args, perf_data)
@@ -22,13 +22,13 @@ from archai import cifar10_models


 def train(epochs, train_dl, val_dal, net, device, crit, optim,
-          sched, sched_on_epoch, half, quiet) -> List[Mapping]:
+          sched, sched_on_epoch, half, quiet, grad_clip:float) -> List[Mapping]:
     train_acc, test_acc = 0.0, 0.0
     metrics = []
     for epoch in range(epochs):
         lr = optim.param_groups[0]['lr']
         train_acc, loss = train_epoch(epoch, net, train_dl, device, crit, optim,
-                                      sched, sched_on_epoch, half)
+                                      sched, sched_on_epoch, half, grad_clip)

         val_acc = test(net, val_dal, device,
                        half) if val_dal is not None else math.nan
@@ -40,7 +40,7 @@ def train(epochs, train_dl, val_dal, net, device, crit, optim,
     return metrics


-def optim_sched(net):
+def optim_sched_orig(net, epochs):
     lr, momentum, weight_decay = 0.1, 0.9, 1.0e-4
     optim = torch.optim.SGD(net.parameters(),
                             lr, momentum=momentum, weight_decay=weight_decay)
@@ -54,6 +54,19 @@ def optim_sched(net):

     return optim, sched, sched_on_epoch


+def optim_sched_cosine(net, epochs):
+    lr, momentum, weight_decay = 0.025, 0.9, 1.0e-4
+    optim = torch.optim.SGD(net.parameters(),
+                            lr, momentum=momentum, weight_decay=weight_decay)
+    logging.info(f'lr={lr}, momentum={momentum}, weight_decay={weight_decay}')
+
+    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, epochs)
+    sched_on_epoch = True
+
+    logging.info(f'sched_on_epoch={sched_on_epoch}, sched={str(sched)}')
+
+    return optim, sched, sched_on_epoch
+
 def get_data(datadir: str, train_batch_size=128, test_batch_size=4096,
              cutout=0, train_num_workers=-1, test_num_workers=-1,
@@ -101,7 +114,7 @@ def get_data(datadir: str, train_batch_size=128, test_batch_size=4096,


 def train_epoch(epoch, net, train_dl, device, crit, optim,
-                sched, sched_on_epoch, half) -> Tuple[float, float]:
+                sched, sched_on_epoch, half, grad_clip:float) -> Tuple[float, float]:
     correct, total, loss_total = 0, 0, 0.0
     net.train()
     for batch_idx, (inputs, targets) in enumerate(train_dl):
@@ -112,7 +125,7 @@ def train_epoch(epoch, net, train_dl, device, crit, optim,
             inputs = inputs.half()

         outputs, loss = train_step(net, crit, optim, sched, sched_on_epoch,
-                                   inputs, targets)
+                                   inputs, targets, grad_clip)
         loss_total += loss

         _, predicted = outputs.max(1)
@@ -125,12 +138,13 @@ def train_epoch(epoch, net, train_dl, device, crit, optim,

 def train_step(net: nn.Module,
                crit: _Loss, optim: Optimizer, sched: _LRScheduler, sched_on_epoch: bool,
-               inputs: torch.Tensor, targets: torch.Tensor) -> Tuple[torch.Tensor, float]:
+               inputs: torch.Tensor, targets: torch.Tensor, grad_clip:float) -> Tuple[torch.Tensor, float]:
     outputs = net(inputs)

     loss = crit(outputs, targets)
     optim.zero_grad()
     loss.backward()
+    nn.utils.clip_grad_norm_(net.parameters(), grad_clip)

     optim.step()
     if sched and not sched_on_epoch:
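For reference, the sched_on_epoch convention visible at the end of this hunk: train_step() advances the scheduler after every batch when sched_on_epoch is False, while train() advances it once per epoch when True (as optim_sched_cosine sets). A toy sketch of the per-batch case (illustrative names only):

    # Per-batch scheduler stepping, as done by train_step() when
    # sched_on_epoch is False.
    import torch

    net = torch.nn.Linear(2, 2)                       # stand-in model
    optim = torch.optim.SGD(net.parameters(), lr=0.1)
    sched = torch.optim.lr_scheduler.StepLR(optim, step_size=1, gamma=0.9)
    sched_on_epoch = False

    for batch_idx in range(3):
        optim.step()                      # after loss.backward() in real code
        if sched and not sched_on_epoch:  # step once per batch
            sched.step()

    print(optim.param_groups[0]['lr'])    # 0.1 * 0.9**3 ~= 0.0729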
@@ -266,6 +280,7 @@ def main():
     parser.add_argument('--half', type=lambda x: x.lower() == 'true',
                         nargs='?', const=True, default=False)
     parser.add_argument('--cutout', type=int, default=0)
+    parser.add_argument('--grad-clip', type=float, default=5.0)

     parser.add_argument('--datadir', default='',
                         help='where to find dataset files, default is ~/torchvision_data_dir')
@@ -309,18 +324,22 @@ def main():
     else:
         device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

-    net = create_model(args.model_name, device, args.half)
-    crit = create_crit(device, args.half)
-    optim, sched, sched_on_epoch = optim_sched(net)
-
     # load data just before train start so any errors so far is not delayed
     train_dl, val_dl, test_dl = get_data(datadir=datadir,
                                          train_batch_size=args.train_batch_size, test_batch_size=args.test_batch_size,
                                          train_num_workers=args.loader_workers, test_num_workers=args.loader_workers,
                                          cutout=args.cutout)

-    train_metrics = train(args.epochs, train_dl, val_dl, net, device, crit, optim,
-                          sched, sched_on_epoch, args.half, False)
+    epochs = args.epochs
+
+    net = create_model(args.model_name, device, args.half)
+    crit = create_crit(device, args.half)
+    optim, sched, sched_on_epoch = optim_sched_orig(net, epochs)
+
+    train_metrics = train(epochs, train_dl, val_dl, net, device, crit, optim,
+                          sched, sched_on_epoch, args.half, False, grad_clip=args.grad_clip)
     test_acc = test(net, test_dl, device, args.half)
     log_metrics(expdir, 'train_metrics', train_metrics, test_acc, args)