separate out losses into module, cuda synchronize, proper warmup handling, default amp at O2, add BN weight decay, NVidia benchmark config

Shital Shah 2020-04-22 06:35:43 -07:00
Parent 1cdcb8d7ba
Commit 63b3ec140d
9 changed files with 208 additions and 109 deletions

View file

@ -138,7 +138,7 @@ class ApexUtils:
def is_master(self)->bool:
return self.global_rank == 0
def sync_dist(self)->None:
def sync_devices(self)->None:
if self._distributed:
torch.cuda.synchronize()
@ -149,7 +149,7 @@ class ApexUtils:
rt /= self._world_size
return rt
else:
return tensor.data
return tensor
def backward(self, loss:torch.Tensor, optim:Optimizer)->None:
if self._amp:

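For context, a minimal sketch (not part of the diff) of the reduce-then-average pattern that the truncated reduce code above appears to implement; the function name and signature here are illustrative. The renamed sync_devices presumably wraps torch.cuda.synchronize() so pending GPU work completes before such cross-rank reductions or timing reads.

import torch
import torch.distributed as dist

def reduce_mean(tensor: torch.Tensor, world_size: int) -> torch.Tensor:
    # Sum the tensor across all ranks, then divide by world size to get the mean.
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= world_size
    return rt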
View file

@ -129,6 +129,8 @@ class Config(UserDict):
else:
return 1 # path not found, ignore this
def get_val(self, key, default_val):
return super().get(key, default_val)
@staticmethod
def set(instance:'Config')->None:

archai/common/ml_losses.py (new file, 107 lines)
View file

@ -0,0 +1,107 @@
import torch
import torch.backends.cudnn as cudnn
from torch import nn
from torch.optim import lr_scheduler, SGD, Adam
from warmup_scheduler import GradualWarmupScheduler
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.optimizer import Optimizer
from torch.nn.modules.loss import _WeightedLoss, _Loss
import torch.nn.functional as F
# TODO: replace this with SmoothCrossEntropyLoss class
# def cross_entropy_smooth(input: torch.Tensor, target, size_average=True, label_smoothing=0.1):
# y = torch.eye(10).to(input.device)
# lb_oh = y[target]
# target = lb_oh * (1 - label_smoothing) + 0.5 * label_smoothing
# logsoftmax = nn.LogSoftmax()
# if size_average:
# return torch.mean(torch.sum(-target * logsoftmax(input), dim=1))
# else:
# return torch.sum(torch.sum(-target * logsoftmax(input), dim=1))
class SmoothCrossEntropyLoss(_WeightedLoss):
def __init__(self, weight=None, reduction='mean', smoothing=0.0):
super().__init__(weight=weight, reduction=reduction)
self.smoothing = smoothing
self.weight = weight
self.reduction = reduction
@staticmethod
def _smooth_one_hot(targets:torch.Tensor, n_classes:int, smoothing=0.0):
assert 0 <= smoothing < 1
with torch.no_grad():
# For label smoothing we replace the 1-hot target vector with, e.g., a 0.9-hot vector:
# create an empty (batch, n_classes) tensor, fill it with smoothing/(n_classes-1),
# then scatter 1-smoothing into the position where the 1 would have gone.
targets = torch.empty(size=(targets.size(0), n_classes), device=targets.device) \
.fill_(smoothing /(n_classes-1)) \
.scatter_(1, targets.data.unsqueeze(1), 1.-smoothing)
return targets
def forward(self, inputs, targets):
targets = SmoothCrossEntropyLoss._smooth_one_hot(targets, inputs.size(-1),
self.smoothing)
lsm = F.log_softmax(inputs, -1)
if self.weight is not None: # to support weighted targets
lsm = lsm * self.weight.unsqueeze(0)
loss = -(targets * lsm).sum(-1)
if self.reduction == 'sum':
loss = loss.sum()
elif self.reduction == 'mean':
loss = loss.mean()
return loss
# Credits: https://github.com/NVIDIA/DeepLearningExamples/blob/342d2e7649b9a478f35ea45a069a4c7e6b1497b8/PyTorch/Classification/ConvNets/main.py#L350
class NLLMultiLabelSmooth(nn.Module):
"""According to NVidia code, this should be used with mixup?"""
def __init__(self, smoothing = 0.0):
super(NLLMultiLabelSmooth, self).__init__()
self.confidence = 1.0 - smoothing
self.smoothing = smoothing
def forward(self, x, target):
if self.training:
x = x.float()
target = target.float()
logprobs = torch.nn.functional.log_softmax(x, dim = -1)
nll_loss = -logprobs * target
nll_loss = nll_loss.sum(-1)
smooth_loss = -logprobs.mean(dim=-1)
loss = self.confidence * nll_loss + self.smoothing * smooth_loss
return loss.mean()
else:
return torch.nn.functional.cross_entropy(x, target)
class LabelSmoothing(nn.Module):
"""
NLL loss with label smoothing.
"""
def __init__(self, smoothing=0.0):
"""
Constructor for the LabelSmoothing module.
:param smoothing: label smoothing factor
"""
super(LabelSmoothing, self).__init__()
self.confidence = 1.0 - smoothing
self.smoothing = smoothing
def forward(self, x, target):
logprobs = torch.nn.functional.log_softmax(x, dim=-1)
nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
nll_loss = nll_loss.squeeze(1)
smooth_loss = -logprobs.mean(dim=-1)
loss = self.confidence * nll_loss + self.smoothing * smooth_loss
return loss.mean()
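A minimal usage sketch (not part of the diff) of the two smoothing losses above, assuming the module is importable from the archai/common/ml_losses.py path shown in the file header. Note the two classes spread the smoothing mass slightly differently (over the other n-1 classes vs. over all n classes), so their values are close but not identical.

import torch
from archai.common.ml_losses import SmoothCrossEntropyLoss, LabelSmoothing

logits = torch.randn(4, 10)            # batch of 4, 10 classes
targets = torch.randint(0, 10, (4,))   # integer class labels

loss_a = SmoothCrossEntropyLoss(smoothing=0.1)(logits, targets)
loss_b = LabelSmoothing(smoothing=0.1)(logits, targets)
print(loss_a.item(), loss_b.item())    # similar values, not exactly equal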

View file

@ -1,56 +1,63 @@
from typing import Iterable, Type, MutableMapping, Mapping, Any, Optional, Tuple, List, Sequence
import numpy as np
import logging
import csv
from collections import OrderedDict
import sys
import os
import pathlib
import time
import math
import torch
import torch.backends.cudnn as cudnn
from torch import nn
from torch import isnan, nn
from torch.optim import lr_scheduler, SGD, Adam
from warmup_scheduler import GradualWarmupScheduler
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.optimizer import Optimizer
from torch.nn.modules.loss import _WeightedLoss, _Loss
import torch.nn.functional as F
from torchvision.datasets.utils import check_integrity, download_url
from torch.utils.model_zoo import tqdm
import yaml
import runstats
import statopt
from .config import Config
from .cocob import CocobBackprop
from torch.utils.data.dataloader import DataLoader
from .ml_losses import SmoothCrossEntropyLoss
from .common import logger
def create_optimizer(conf_opt:Config, params)->Optimizer:
if conf_opt['type'] == 'sgd':
optim_type = conf_opt['type']
lr = conf_opt.get_val('lr', math.nan)
decay = conf_opt.get_val('decay', math.nan)
decay_bn = conf_opt.get_val('decay_bn', math.nan) # some optim may not support weight decay
logger.info({'optim_type': optim_type, 'lr':lr, 'decay':decay, 'decay_bn': decay_bn})
if not isnan(decay_bn):
bn_params = [v for n, v in params if 'bn' in n]
rest_params = [v for n, v in params if not 'bn' in n]
params = [{
'params': bn_params,
'weight_decay': 0
}, {
'params': rest_params,
'weight_decay': decay
}]
if optim_type == 'sgd':
return SGD(
params,
lr=conf_opt['lr'],
lr=lr,
momentum=conf_opt['momentum'],
weight_decay=conf_opt['decay'],
weight_decay=decay,
nesterov=conf_opt['nesterov']
)
elif conf_opt['type'] == 'adam':
elif optim_type == 'adam':
return Adam(params,
lr=conf_opt['lr'],
lr=lr,
betas=conf_opt['betas'],
weight_decay=conf_opt['decay'])
elif conf_opt['type'] == 'cocob':
weight_decay=decay)
elif optim_type == 'cocob':
return CocobBackprop(params,
alpha=conf_opt['alpha'])
elif conf_opt['type'] == 'salsa':
elif optim_type == 'salsa':
return statopt.SALSA(params,
alpha=conf_opt['alpha'])
else:
raise ValueError('invalid optimizer type=%s' % conf_opt['type'])
raise ValueError('invalid optimizer type=%s' % optim_type)
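A hedged sketch (not part of the diff) of the parameter-group split that the decay_bn branch above performs; the toy model and values are illustrative. When the split is active, create_optimizer expects named parameters, matches BN parameters by 'bn' appearing in the name, and gives that group weight_decay=0 while the rest keep the regular decay.

import torch.nn as nn
from torch.optim import SGD

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3)
        self.bn = nn.BatchNorm2d(8)     # 'bn' in the parameter name is what the filter keys on

model = Toy()
named = list(model.named_parameters())  # [('conv.weight', ...), ('bn.weight', ...), ...]
bn_params = [p for n, p in named if 'bn' in n]
rest_params = [p for n, p in named if 'bn' not in n]
optim = SGD([{'params': bn_params, 'weight_decay': 0.0},      # BN: no weight decay
             {'params': rest_params, 'weight_decay': 3.0e-4}], # everything else: regular decay
            lr=0.1, momentum=0.9, nesterov=False)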
def get_optim_lr(optimizer:Optimizer)->float:
for param_group in optimizer.param_groups:
@ -80,22 +87,25 @@ def create_lr_scheduler(conf_lrs:Config, epochs:int, optimizer:Optimizer,
# epoch_or_step - apply every epoch or every step
scheduler, epoch_or_step = None, True # by default sched step on epoch
# TODO: adjust max epochs for warmup?
# if conf_lrs.get('warmup', None):
# epochs -= conf_lrs['warmup']['epochs']
conf_warmup = conf_lrs.get_val('warmup', None)
warmup_epochs = 0
if conf_warmup is not None and 'epochs' in conf_warmup:
warmup_epochs = conf_warmup['epochs']
if conf_lrs is not None:
lr_scheduler_type = conf_lrs['type'] # TODO: default should be none?
if lr_scheduler_type == 'cosine':
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs,
scheduler = lr_scheduler.CosineAnnealingLR(optimizer,
T_max=epochs-warmup_epochs,
eta_min=conf_lrs['min_lr'])
elif lr_scheduler_type == 'multi_step':
scheduler = lr_scheduler.MultiStepLR(optimizer,
milestones=conf_lrs['milestones'],
gamma=conf_lrs['gamma'])
elif lr_scheduler_type == 'pyramid':
scheduler = _adjust_learning_rate_pyramid(optimizer, epochs,
scheduler = _adjust_learning_rate_pyramid(optimizer,
epochs-warmup_epochs,
get_optim_lr(optimizer))
elif lr_scheduler_type == 'step':
decay_period = conf_lrs['decay_period']
@ -107,7 +117,8 @@ def create_lr_scheduler(conf_lrs:Config, epochs:int, optimizer:Optimizer,
max_lr = conf_lrs['max_lr']
epoch_or_step = False
scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=max_lr,
epochs=epochs, steps_per_epoch=steps_per_epoch,
epochs=epochs-warmup_epochs,
steps_per_epoch=steps_per_epoch,
) # TODO: other params
elif not lr_scheduler_type:
scheduler = None
@ -115,16 +126,17 @@ def create_lr_scheduler(conf_lrs:Config, epochs:int, optimizer:Optimizer,
raise ValueError('invalid lr_scheduler=%s' % lr_scheduler_type)
# select warmup for LR schedule
if conf_lrs.get('warmup', None):
if conf_lrs.get_val('warmup', None):
scheduler = GradualWarmupScheduler(
optimizer,
multiplier=conf_lrs['warmup']['multiplier'],
multiplier=conf_lrs['warmup'].get_val('multiplier', 1.0),
total_epoch=conf_lrs['warmup']['epochs'],
after_scheduler=scheduler
)
return scheduler, epoch_or_step
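A sketch (not part of the diff) of the warmup wiring above: GradualWarmupScheduler ramps the LR from roughly 0 up to the base LR (with multiplier=1.0) over total_epoch epochs and then hands over to the wrapped scheduler, whose T_max was already reduced by the warmup epochs. The values mirror the commented NVidia benchmark config further down; the toy model is illustrative.

import torch.nn as nn
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR
from warmup_scheduler import GradualWarmupScheduler

epochs, warmup_epochs = 90, 8
model = nn.Linear(4, 4)                      # toy model
optim = SGD(model.parameters(), lr=2.048, momentum=0.875)
cosine = CosineAnnealingLR(optim, T_max=epochs - warmup_epochs, eta_min=0.0)
sched = GradualWarmupScheduler(optim, multiplier=1.0, total_epoch=warmup_epochs,
                               after_scheduler=cosine)
for epoch in range(epochs):
    # ... one training epoch ...
    sched.step()                             # stepping once per epoch (epoch_or_step=True)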
def _adjust_learning_rate_pyramid(optimizer, max_epoch:int, base_lr:float):
def _internal_adjust_learning_rate_pyramid(epoch):
"""Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
@ -134,54 +146,6 @@ def _adjust_learning_rate_pyramid(optimizer, max_epoch:int, base_lr:float):
return lr_scheduler.LambdaLR(optimizer, _internal_adjust_learning_rate_pyramid)
# TODO: replace this with SmoothCrossEntropyLoss class
# def cross_entropy_smooth(input: torch.Tensor, target, size_average=True, label_smoothing=0.1):
# y = torch.eye(10).to(input.device)
# lb_oh = y[target]
# target = lb_oh * (1 - label_smoothing) + 0.5 * label_smoothing
# logsoftmax = nn.LogSoftmax()
# if size_average:
# return torch.mean(torch.sum(-target * logsoftmax(input), dim=1))
# else:
# return torch.sum(torch.sum(-target * logsoftmax(input), dim=1))
class SmoothCrossEntropyLoss(_WeightedLoss):
def __init__(self, weight=None, reduction='mean', smoothing=0.0):
super().__init__(weight=weight, reduction=reduction)
self.smoothing = smoothing
self.weight = weight
self.reduction = reduction
@staticmethod
def _smooth_one_hot(targets:torch.Tensor, n_classes:int, smoothing=0.0):
assert 0 <= smoothing < 1
with torch.no_grad():
# For label smoothing we replace the 1-hot target vector with, e.g., a 0.9-hot vector:
# create an empty (batch, n_classes) tensor, fill it with smoothing/(n_classes-1),
# then scatter 1-smoothing into the position where the 1 would have gone.
targets = torch.empty(size=(targets.size(0), n_classes), device=targets.device) \
.fill_(smoothing /(n_classes-1)) \
.scatter_(1, targets.data.unsqueeze(1), 1.-smoothing)
return targets
def forward(self, inputs, targets):
targets = SmoothCrossEntropyLoss._smooth_one_hot(targets, inputs.size(-1),
self.smoothing)
lsm = F.log_softmax(inputs, -1)
if self.weight is not None: # to support weighted targets
lsm = lsm * self.weight.unsqueeze(0)
loss = -(targets * lsm).sum(-1)
if self.reduction == 'sum':
loss = loss.sum()
elif self.reduction == 'mean':
loss = loss.mean()
return loss
def get_lossfn(conf_lossfn:Config)->_Loss:
type = conf_lossfn['type']
if type == 'CrossEntropyLoss':

View file

@ -228,12 +228,17 @@ class Trainer(EnforceOverrides):
get_apex_utils().clip_grad(self._grad_clip, self.model, self._optim)
self._optim.step()
get_apex_utils().sync_devices()
if self._sched and not self._sched_on_epoch:
self._sched.step()
self.post_step(x, y, logits, loss, steps)
logger.popd()
# end of step
if self._sched and self._sched_on_epoch:
self._sched.step()
logger.popd()
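The sync_devices() call after optim.step() presumably ensures pending GPU work has finished before post_step metrics and timings are read; a generic sketch of that pattern (not the Trainer code):

import time
import torch

def timed(fn):
    # CUDA kernels launch asynchronously, so synchronize before reading the clock.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.time()
    result = fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return result, time.time() - start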

View file

@ -77,7 +77,7 @@ def get_dataloaders(ds_provider:DatasetProvider,
train_workers = test_workers = 0
logger.warn({'debugger': True})
if train_workers is None:
train_workers = 4
train_workers = 4 # following NVidia DeepLearningExamples
if test_workers is None:
test_workers = 4
logger.info({'train_workers': train_workers, 'test_workers':test_workers})

View file

@ -20,15 +20,14 @@ common:
apex:
gpus: '' # use GPU IDs specified here (comma separated), if '' then use all GPUs
enabled: False # global switch to disable anything apex
opt_level: 'O1' # optimization level for mixed precision
opt_level: 'O2' # optimization level for mixed precision
bn_fp32: True # keep BN in fp32
loss_scale: None # loss scaling mode for mixed prec
loss_scale: "dynamic" # loss scaling mode for mixed prec, must be string reprenting floar ot "dynamic"
sync_bn: False # should be replace BNs with sync BNs for distributed model
distributed: False # enable/disable distributed mode
scale_lr: True # enable/disable distributed mode
min_world_size: 0 # allows to confirm we are indeed in distributed setting
smoke_test: False
only_eval: False
resume: True
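For reference, a hedged sketch (not the actual ApexUtils code) of how these knobs typically map onto NVIDIA apex's amp.initialize; keep_batchnorm_fp32 and loss_scale are real apex arguments, while the toy model and values are illustrative.

import torch.nn as nn
from torch.optim import SGD
from apex import amp                        # NVIDIA apex, optional dependency

model = nn.Linear(4, 4).cuda()              # O2 mixed precision requires CUDA
optimizer = SGD(model.parameters(), lr=0.1)
model, optimizer = amp.initialize(
    model, optimizer,
    opt_level='O2',                         # apex.opt_level above
    keep_batchnorm_fp32=True,               # apex.bn_fp32
    loss_scale='dynamic')                   # apex.loss_scale: "dynamic" or a float given as a string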
@ -74,10 +73,10 @@ nas:
cutout: 16 # cutout length, use cutout augmentation when > 0
load_train: True # load train split of dataset
train_batch: 96
train_workers: null # if null then gpu_count*4
train_workers: null # if null then 4
load_test: True # load test split of dataset
test_batch: 1024
test_workers: null # if null then gpu_count*4
test_workers: null # if null then 4
val_ratio: 0.0 #split portion for test set, 0 to 1
val_fold: 0 #Fold number to use (0 to 4)
cv_num: 5 # total number of folds available
@ -100,10 +99,13 @@ nas:
decay: 3.0e-4 # pytorch default is 0.0
momentum: 0.9 # pytorch default is 0.0
nesterov: False # pytorch default is False
warmup: null
decay_bn: .NaN # if NaN then same as decay otherwise apply different decay to BN layers
lr_schedule:
type: 'cosine'
min_lr: 0.001 # min learning rate to be set in eta_min param of scheduler
warmup: null # increases LR from 0 to current over the specified epochs and then hands over to the main scheduler
# multiplier: 1 # end warmup at this multiple of LR
# epochs: 1
validation:
title: 'eval_test'
logger_freq: 0
@ -202,15 +204,17 @@ nas:
decay: 3.0e-4
momentum: 0.9 # pytorch default is 0
nesterov: False
warmup: null
decay_bn: .NaN # if NaN then same as decay otherwise apply different decay to BN layers
alpha_optimizer:
type: 'adam'
lr: 3.0e-4
decay: 1.0e-3
betas: [0.5, 0.999]
decay_bn: .NaN # if NaN then same as decay otherwise apply different decay to BN layers
lr_schedule:
type: 'cosine'
min_lr: 0.001 # min learning rate, this will be used in eta_min param of scheduler
warmup: null
validation:
title: 'search_val'
logger_freq: 0
@ -246,11 +250,9 @@ autoaug:
momentum: 0.9 # pytorch default is 0.0
nesterov: False # pytorch default is False
clip: 5.0 # grads above this value are clipped # TODO: Why is this also in trainer?
warmup:
null
# multiplier: 2
# epochs: 3
decay_bn: .NaN # if NaN then same as decay otherwise apply different decay to BN layers
#betas: [0.9, 0.999] # PyTorch default betas for Adam
lr_schedule:
type: 'cosine'
min_lr: 0.0 # min learning rate, this will be used in eta_min param of scheduler
warmup: null

View file

@ -1,6 +1,8 @@
common:
seed: 0.0
apex:
loss_scale: "128" # loss scaling mode for mixed prec, must be string reprenting floar ot "dynamic"
dataset_eval:
name: 'imagenet'
@ -22,6 +24,8 @@ nas:
model_post_op: 'pool_avg2d7x7'
dataset:
_copy: '/dataset_eval'
# darts setup
loader:
batch: 128
dataset:
@ -39,4 +43,31 @@ nas:
lr_schedule:
type: 'step'
decay_period: 1 # epochs between two learning rate decays
gamma: 0.97 # learning rate decay
gamma: 0.97 # learning rate decay
# NVidia benchmark setup DGX1_RN50_AMP_90E.sh
# Enable amp and distributed 8 GPUs in apex section
# loader:
# batch: 256
# train_workers: 5
# test_workers: 5
# dataset:
# _copy: '/dataset_eval'
# trainer:
# aux_weight: 0.0 # weight for loss from auxiliary towers in test time arch
# drop_path_prob: 0.0 # probability that given edge will be dropped
# epochs: 90
# lossfn: # TODO: this is perhaps reversed for test/train?
# type: 'CrossEntropyLabelSmooth'
# smoothing: 0.1 # label smoothing
# optimizer:
# lr: 2.048 # init learning rate
# decay: 3.05e-5
# decay_bn: 0.0 # if NaN then same as decay otherwise apply different decay to BN layers
# momentum: 0.875 # pytorch default is 0.0
# lr_schedule:
# type: 'cosine'
# min_lr: 0.0 # min learning rate to be set in eta_min param of scheduler
# warmup: # increases LR from 0 to current over the specified epochs and then hands over to the main scheduler
# multiplier: 1
# epochs: 8

View file

@ -1,22 +1,10 @@
import yaml

# Small scratch script: checks how PyYAML handles anchors/aliases and .NaN.
y = """
d: &d
  f: 21
  g: 31
d1:
  f: 21
  g: 31
c:
  d: *d
a: .NaN
"""
d = yaml.load(y, Loader=yaml.Loader)
print(d)
print(d['d'] == d['c']['d'])   # alias resolves to the same mapping -> True
print(d['d1'] == d['c']['d'])  # equal by value -> True
print(type(d['a']))            # .NaN parses to a Python float (nan)