зеркало из https://github.com/microsoft/archai.git
All-No_pareto functional again
This commit is contained in:
Родитель
fd1fb0ab44
Коммит
096b564983
|
@ -214,7 +214,7 @@ class ApexUtils:
|
||||||
else:
|
else:
|
||||||
return val
|
return val
|
||||||
|
|
||||||
def _get_optim(self, multi_optim:MultiOptim)->Optimizer:
|
def _get_one_optim(self, multi_optim:MultiOptim)->Optimizer:
|
||||||
assert len(multi_optim)==1, \
|
assert len(multi_optim)==1, \
|
||||||
'Mixed precision is only supported for one optimizer' \
|
'Mixed precision is only supported for one optimizer' \
|
||||||
f' but {len(multi_optim)} optimizers were supplied'
|
f' but {len(multi_optim)} optimizers were supplied'
|
||||||
|
@ -234,7 +234,10 @@ class ApexUtils:
|
||||||
|
|
||||||
def step(self, multi_optim:MultiOptim)->None:
|
def step(self, multi_optim:MultiOptim)->None:
|
||||||
if self.is_mixed():
|
if self.is_mixed():
|
||||||
self._scaler.step(self._get_optim(multi_optim)) # pyright: ignore[reportOptionalMemberAccess]
|
# self._scaler.unscale_ will be called automatically if it isn't called yet from grad clipping
|
||||||
|
# https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler.step
|
||||||
|
for optim_shed in multi_optim:
|
||||||
|
self._scaler.step(optim_shed.optim) # pyright: ignore[reportOptionalMemberAccess]
|
||||||
self._scaler.update() # pyright: ignore[reportOptionalMemberAccess]
|
self._scaler.update() # pyright: ignore[reportOptionalMemberAccess]
|
||||||
else:
|
else:
|
||||||
multi_optim.step()
|
multi_optim.step()
|
||||||
|
@ -249,12 +252,13 @@ class ApexUtils:
|
||||||
model = model.to(self.device)
|
model = model.to(self.device)
|
||||||
|
|
||||||
# scale LR
|
# scale LR
|
||||||
optim = self._get_optim(multi_optim)
|
|
||||||
if self.is_dist() and self._scale_lr:
|
if self.is_dist() and self._scale_lr:
|
||||||
lr = ml_utils.get_optim_lr(optim)
|
for optim_shed in multi_optim:
|
||||||
scaled_lr = lr * self.world_size / float(batch_size)
|
optim = optim_shed.optim
|
||||||
ml_utils.set_optim_lr(optim, scaled_lr)
|
lr = ml_utils.get_optim_lr(optim)
|
||||||
self._log_info({'lr_scaled': True, 'old_lr': lr, 'new_lr': scaled_lr})
|
scaled_lr = lr * self.world_size / float(batch_size)
|
||||||
|
ml_utils.set_optim_lr(optim, scaled_lr)
|
||||||
|
self._log_info({'lr_scaled': True, 'old_lr': lr, 'new_lr': scaled_lr})
|
||||||
|
|
||||||
if self.is_dist():
|
if self.is_dist():
|
||||||
model = DistributedDataParallel(model, device_ids=[self._gpu], output_device=self._gpu)
|
model = DistributedDataParallel(model, device_ids=[self._gpu], output_device=self._gpu)
|
||||||
|
@ -264,8 +268,8 @@ class ApexUtils:
|
||||||
def clip_grad(self, clip:float, model:nn.Module, multi_optim:MultiOptim)->None:
|
def clip_grad(self, clip:float, model:nn.Module, multi_optim:MultiOptim)->None:
|
||||||
if clip > 0.0:
|
if clip > 0.0:
|
||||||
if self.is_mixed():
|
if self.is_mixed():
|
||||||
optim = self._get_optim(multi_optim)
|
# https://pytorch.org/docs/stable/notes/amp_examples.html#working-with-multiple-models-losses-and-optimizers
|
||||||
self._scaler.unscale_(optim) # pyright: ignore[reportOptionalMemberAccess]
|
self._scaler.unscale_(multi_optim[0].optim) # pyright: ignore[reportOptionalMemberAccess]
|
||||||
nn.utils.clip_grad_norm_(model.parameters(), clip)
|
nn.utils.clip_grad_norm_(model.parameters(), clip)
|
||||||
else:
|
else:
|
||||||
nn.utils.clip_grad_norm_(model.parameters(), clip)
|
nn.utils.clip_grad_norm_(model.parameters(), clip)
|
||||||
|
|
|
@ -20,8 +20,8 @@ from archai.supergraph.nas.model import Model
|
||||||
from archai.supergraph.utils import ml_utils
|
from archai.supergraph.utils import ml_utils
|
||||||
from archai.supergraph.utils.checkpoint import CheckPoint
|
from archai.supergraph.utils.checkpoint import CheckPoint
|
||||||
from archai.supergraph.datasets import data
|
from archai.supergraph.datasets import data
|
||||||
from archai.common.logger import Logger
|
|
||||||
logger = Logger(source=__name__)
|
from archai.common.common import logger
|
||||||
from archai.supergraph.algos.darts.bilevel_optimizer import BilevelOptimizer
|
from archai.supergraph.algos.darts.bilevel_optimizer import BilevelOptimizer
|
||||||
|
|
||||||
class BilevelArchTrainer(ArchTrainer):
|
class BilevelArchTrainer(ArchTrainer):
|
||||||
|
|
|
@ -12,8 +12,8 @@ from torch.optim.optimizer import Optimizer
|
||||||
from archai.common.config import Config
|
from archai.common.config import Config
|
||||||
from archai.common import utils
|
from archai.common import utils
|
||||||
from archai.supergraph.nas.model import Model
|
from archai.supergraph.nas.model import Model
|
||||||
from archai.common.logger import Logger
|
|
||||||
logger = Logger(source=__name__)
|
from archai.common.common import logger
|
||||||
from archai.common.utils import zip_eq
|
from archai.common.utils import zip_eq
|
||||||
from archai.supergraph.utils import ml_utils
|
from archai.supergraph.utils import ml_utils
|
||||||
|
|
||||||
|
|
|
@ -12,8 +12,8 @@ from torch.optim.optimizer import Optimizer
|
||||||
from archai.common.config import Config
|
from archai.common.config import Config
|
||||||
from archai.common import utils
|
from archai.common import utils
|
||||||
from archai.supergraph.nas.model import Model
|
from archai.supergraph.nas.model import Model
|
||||||
from archai.common.logger import Logger
|
|
||||||
logger = Logger(source=__name__)
|
from archai.common.common import logger
|
||||||
from archai.common.utils import zip_eq
|
from archai.common.utils import zip_eq
|
||||||
from archai.supergraph.utils import ml_utils
|
from archai.supergraph.utils import ml_utils
|
||||||
|
|
||||||
|
|
|
@ -19,8 +19,8 @@ from archai.common import utils
|
||||||
from archai.supergraph.nas.model import Model
|
from archai.supergraph.nas.model import Model
|
||||||
from archai.supergraph.utils import ml_utils
|
from archai.supergraph.utils import ml_utils
|
||||||
from archai.supergraph.utils.checkpoint import CheckPoint
|
from archai.supergraph.utils.checkpoint import CheckPoint
|
||||||
from archai.common.logger import Logger
|
|
||||||
logger = Logger(source=__name__)
|
from archai.common.common import logger
|
||||||
from archai.supergraph.utils.multi_optim import MultiOptim, OptimSched
|
from archai.supergraph.utils.multi_optim import MultiOptim, OptimSched
|
||||||
|
|
||||||
class DidartsArchTrainer(ArchTrainer):
|
class DidartsArchTrainer(ArchTrainer):
|
||||||
|
|
|
@ -10,8 +10,8 @@ from torch import nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from archai.common.common import get_conf
|
from archai.common.common import get_conf
|
||||||
from archai.common.logger import Logger
|
|
||||||
logger = Logger(source=__name__)
|
from archai.common.common import logger
|
||||||
from archai.supergraph.datasets.data import get_data
|
from archai.supergraph.datasets.data import get_data
|
||||||
from archai.supergraph.nas.model import Model
|
from archai.supergraph.nas.model import Model
|
||||||
from archai.supergraph.nas.cell import Cell
|
from archai.supergraph.nas.cell import Cell
|
||||||
|
|
|
@ -14,8 +14,8 @@ import os
|
||||||
|
|
||||||
from archai.common.common import get_conf
|
from archai.common.common import get_conf
|
||||||
from archai.common.common import get_expdir
|
from archai.common.common import get_expdir
|
||||||
from archai.common.logger import Logger
|
|
||||||
logger = Logger(source=__name__)
|
from archai.common.common import logger
|
||||||
from archai.supergraph.datasets.data import get_data
|
from archai.supergraph.datasets.data import get_data
|
||||||
from archai.supergraph.nas.model import Model
|
from archai.supergraph.nas.model import Model
|
||||||
from archai.supergraph.nas.cell import Cell
|
from archai.supergraph.nas.cell import Cell
|
||||||
|
|
|
@ -20,8 +20,8 @@ from archai.common import utils
|
||||||
from archai.supergraph.nas.model import Model
|
from archai.supergraph.nas.model import Model
|
||||||
from archai.supergraph.utils import ml_utils
|
from archai.supergraph.utils import ml_utils
|
||||||
from archai.supergraph.utils.checkpoint import CheckPoint
|
from archai.supergraph.utils.checkpoint import CheckPoint
|
||||||
from archai.common.logger import Logger
|
|
||||||
logger = Logger(source=__name__)
|
from archai.common.common import logger
|
||||||
from archai.common.common import get_conf
|
from archai.common.common import get_conf
|
||||||
from archai.supergraph.algos.gumbelsoftmax.gs_op import GsOp
|
from archai.supergraph.algos.gumbelsoftmax.gs_op import GsOp
|
||||||
|
|
||||||
|
|
|
@ -11,8 +11,8 @@ import os
|
||||||
|
|
||||||
from archai.common.common import get_conf
|
from archai.common.common import get_conf
|
||||||
from archai.common.common import get_expdir
|
from archai.common.common import get_expdir
|
||||||
from archai.common.logger import Logger
|
|
||||||
logger = Logger(source=__name__)
|
from archai.common.common import logger
|
||||||
from archai.supergraph.datasets.data import get_data
|
from archai.supergraph.datasets.data import get_data
|
||||||
from archai.supergraph.nas.model import Model
|
from archai.supergraph.nas.model import Model
|
||||||
from archai.supergraph.nas.cell import Cell
|
from archai.supergraph.nas.cell import Cell
|
||||||
|
|
|
@ -17,8 +17,8 @@ from archai.supergraph.utils import ml_utils
|
||||||
|
|
||||||
from archai.supergraph.utils.trainer import Trainer
|
from archai.supergraph.utils.trainer import Trainer
|
||||||
from archai.common.config import Config
|
from archai.common.config import Config
|
||||||
from archai.common.logger import Logger
|
|
||||||
logger = Logger(source=__name__)
|
from archai.common.common import logger
|
||||||
from archai.supergraph.datasets import data
|
from archai.supergraph.datasets import data
|
||||||
from archai.supergraph.nas.model_desc import ModelDesc
|
from archai.supergraph.nas.model_desc import ModelDesc
|
||||||
from archai.supergraph.nas.model_desc_builder import ModelDescBuilder
|
from archai.supergraph.nas.model_desc_builder import ModelDescBuilder
|
||||||
|
|
|
@ -11,8 +11,8 @@ from overrides import overrides
|
||||||
|
|
||||||
from torch.utils.data.dataloader import DataLoader
|
from torch.utils.data.dataloader import DataLoader
|
||||||
|
|
||||||
from archai.common.logger import Logger
|
|
||||||
logger = Logger(source=__name__)
|
from archai.common.common import logger
|
||||||
|
|
||||||
from archai.common.config import Config
|
from archai.common.config import Config
|
||||||
from archai.supergraph.nas.model_desc_builder import ModelDescBuilder
|
from archai.supergraph.nas.model_desc_builder import ModelDescBuilder
|
||||||
|
|
|
@ -77,12 +77,9 @@ class PetridishOp(Op):
|
||||||
'avg_pool_3x3',
|
'avg_pool_3x3',
|
||||||
'skip_connect', # identity
|
'skip_connect', # identity
|
||||||
'sep_conv_3x3',
|
'sep_conv_3x3',
|
||||||
#'sep_conv_5x5',
|
'sep_conv_5x5',
|
||||||
'dil_conv_3x3',
|
'dil_conv_3x3',
|
||||||
#'dil_conv_5x5',
|
'dil_conv_5x5',
|
||||||
'mbconv_r3',
|
|
||||||
'mbconv_r2',
|
|
||||||
'mbconv_r1',
|
|
||||||
'none' # this must be at the end so top1 doesn't chose it
|
'none' # this must be at the end so top1 doesn't chose it
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -204,3 +201,4 @@ class PetridishOp(Op):
|
||||||
# we store alphas in list so Pytorch don't register them
|
# we store alphas in list so Pytorch don't register them
|
||||||
self._alphas = list(self.arch_params().paramlist_by_kind('alphas'))
|
self._alphas = list(self.arch_params().paramlist_by_kind('alphas'))
|
||||||
assert len(self._alphas)==1
|
assert len(self._alphas)==1
|
||||||
|
|
||||||
|
|
|
@ -18,8 +18,8 @@ import matplotlib.pyplot as plt
|
||||||
from archai.supergraph.nas.model_desc import ConvMacroParams, CellDesc, CellType, OpDesc, \
|
from archai.supergraph.nas.model_desc import ConvMacroParams, CellDesc, CellType, OpDesc, \
|
||||||
EdgeDesc, TensorShape, TensorShapes, NodeDesc, ModelDesc
|
EdgeDesc, TensorShape, TensorShapes, NodeDesc, ModelDesc
|
||||||
from archai.supergraph.utils.metrics import Metrics
|
from archai.supergraph.utils.metrics import Metrics
|
||||||
from archai.common.logger import Logger
|
|
||||||
logger = Logger(source=__name__)
|
from archai.common.common import logger
|
||||||
from archai.common import utils
|
from archai.common import utils
|
||||||
|
|
||||||
class JobStage(Enum):
|
class JobStage(Enum):
|
||||||
|
|
|
@ -25,8 +25,8 @@ from torch.utils.data.dataloader import DataLoader
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from archai.common import common
|
from archai.common import common
|
||||||
from archai.common.logger import Logger
|
|
||||||
logger = Logger(source=__name__),
|
from archai.common.common import logger
|
||||||
from archai.common.common import CommonState
|
from archai.common.common import CommonState
|
||||||
from archai.supergraph.utils.checkpoint import CheckPoint
|
from archai.supergraph.utils.checkpoint import CheckPoint
|
||||||
from archai.common.config import Config
|
from archai.common.config import Config
|
||||||
|
|
|
@ -21,8 +21,8 @@ from archai.supergraph.nas.model import Model
|
||||||
from archai.supergraph.nas.model_desc import CellType
|
from archai.supergraph.nas.model_desc import CellType
|
||||||
from archai.supergraph.utils import ml_utils
|
from archai.supergraph.utils import ml_utils
|
||||||
from archai.supergraph.utils.checkpoint import CheckPoint
|
from archai.supergraph.utils.checkpoint import CheckPoint
|
||||||
from archai.common.logger import Logger
|
|
||||||
logger = Logger(source=__name__)
|
from archai.common.common import logger
|
||||||
from archai.supergraph.datasets import data
|
from archai.supergraph.datasets import data
|
||||||
from archai.common.common import get_conf
|
from archai.common.common import get_conf
|
||||||
from archai.supergraph.algos.xnas.xnas_op import XnasOp
|
from archai.supergraph.algos.xnas.xnas_op import XnasOp
|
||||||
|
|
|
@ -11,9 +11,9 @@ from torchvision.transforms import transforms
|
||||||
from torch.utils.data.dataset import Dataset
|
from torch.utils.data.dataset import Dataset
|
||||||
from torchvision.datasets.utils import check_integrity, download_url
|
from torchvision.datasets.utils import check_integrity, download_url
|
||||||
from archai.common.utils import download_and_extract_tar, extract_tar
|
from archai.common.utils import download_and_extract_tar, extract_tar
|
||||||
from archai.common.logger import Logger
|
|
||||||
|
|
||||||
logger = Logger(source=__name__)
|
|
||||||
|
from archai.common.common import logger
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -4,19 +4,18 @@ from torch import nn
|
||||||
from torch.nn import DataParallel
|
from torch.nn import DataParallel
|
||||||
# from torchvision import models
|
# from torchvision import models
|
||||||
|
|
||||||
from archai.supergraph.nas.models.resnet import ResNet
|
from .pyramidnet import PyramidNet
|
||||||
from archai.supergraph.nas.models.pyramidnet import PyramidNet
|
from .shakeshake.shake_resnet import ShakeResNet
|
||||||
from archai.supergraph.nas.models.shakeshake.shake_resnet import ShakeResNet
|
from .wideresnet import WideResNet
|
||||||
from archai.supergraph.nas.models.wideresnet import WideResNet
|
from .shakeshake.shake_resnext import ShakeResNeXt
|
||||||
from archai.supergraph.nas.models.shakeshake.shake_resnext import ShakeResNeXt
|
|
||||||
|
|
||||||
from archai.supergraph.nas.models.mobilenetv2 import *
|
from .mobilenetv2 import *
|
||||||
from archai.supergraph.nas.models.resnet_cifar10 import *
|
from .resnet import *
|
||||||
from archai.supergraph.nas.models.vgg import *
|
from .vgg import *
|
||||||
from archai.supergraph.nas.models.densenet import *
|
from .densenet import *
|
||||||
from archai.supergraph.nas.models.resnet_orig import *
|
from .resnet_orig import *
|
||||||
from archai.supergraph.nas.models.googlenet import *
|
from .googlenet import *
|
||||||
from archai.supergraph.nas.models.inception import *
|
from .inception import *
|
||||||
|
|
||||||
|
|
||||||
def get_model(conf, num_class=10):
|
def get_model(conf, num_class=10):
|
||||||
|
|
|
@ -2,7 +2,7 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import math
|
import math
|
||||||
|
|
||||||
from archai.supergraph.nas.models.shakedrop import ShakeDrop
|
from .shakedrop import ShakeDrop
|
||||||
|
|
||||||
|
|
||||||
def conv3x3(in_planes, out_planes, stride=1):
|
def conv3x3(in_planes, out_planes, stride=1):
|
||||||
|
|
|
@ -1,31 +1,44 @@
|
||||||
# Original code: https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
|
import torch
|
||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import math
|
import os
|
||||||
|
|
||||||
|
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
|
||||||
|
'resnet152', 'resnext50_32x4d', 'resnext101_32x8d']
|
||||||
|
|
||||||
def conv3x3(in_planes, out_planes, stride=1):
|
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
||||||
"3x3 convolution with padding"
|
"""3x3 convolution with padding"""
|
||||||
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
||||||
padding=1, bias=False)
|
padding=dilation, groups=groups, bias=False, dilation=dilation)
|
||||||
|
|
||||||
|
|
||||||
|
def conv1x1(in_planes, out_planes, stride=1):
|
||||||
|
"""1x1 convolution"""
|
||||||
|
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
||||||
|
|
||||||
|
|
||||||
class BasicBlock(nn.Module):
|
class BasicBlock(nn.Module):
|
||||||
expansion = 1
|
expansion = 1
|
||||||
|
|
||||||
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
|
||||||
|
base_width=64, dilation=1, norm_layer=None):
|
||||||
super(BasicBlock, self).__init__()
|
super(BasicBlock, self).__init__()
|
||||||
|
if norm_layer is None:
|
||||||
|
norm_layer = nn.BatchNorm2d
|
||||||
|
if groups != 1 or base_width != 64:
|
||||||
|
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
|
||||||
|
if dilation > 1:
|
||||||
|
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
|
||||||
|
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
|
||||||
self.conv1 = conv3x3(inplanes, planes, stride)
|
self.conv1 = conv3x3(inplanes, planes, stride)
|
||||||
self.bn1 = nn.BatchNorm2d(planes)
|
self.bn1 = norm_layer(planes)
|
||||||
self.conv2 = conv3x3(planes, planes)
|
|
||||||
self.bn2 = nn.BatchNorm2d(planes)
|
|
||||||
self.relu = nn.ReLU(inplace=True)
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
self.conv2 = conv3x3(planes, planes)
|
||||||
|
self.bn2 = norm_layer(planes)
|
||||||
self.downsample = downsample
|
self.downsample = downsample
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
residual = x
|
identity = x
|
||||||
|
|
||||||
out = self.conv1(x)
|
out = self.conv1(x)
|
||||||
out = self.bn1(out)
|
out = self.bn1(out)
|
||||||
|
@ -35,9 +48,9 @@ class BasicBlock(nn.Module):
|
||||||
out = self.bn2(out)
|
out = self.bn2(out)
|
||||||
|
|
||||||
if self.downsample is not None:
|
if self.downsample is not None:
|
||||||
residual = self.downsample(x)
|
identity = self.downsample(x)
|
||||||
|
|
||||||
out += residual
|
out += identity
|
||||||
out = self.relu(out)
|
out = self.relu(out)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
@ -46,22 +59,25 @@ class BasicBlock(nn.Module):
|
||||||
class Bottleneck(nn.Module):
|
class Bottleneck(nn.Module):
|
||||||
expansion = 4
|
expansion = 4
|
||||||
|
|
||||||
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
|
||||||
|
base_width=64, dilation=1, norm_layer=None):
|
||||||
super(Bottleneck, self).__init__()
|
super(Bottleneck, self).__init__()
|
||||||
|
if norm_layer is None:
|
||||||
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
|
norm_layer = nn.BatchNorm2d
|
||||||
self.bn1 = nn.BatchNorm2d(planes)
|
width = int(planes * (base_width / 64.)) * groups
|
||||||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
|
||||||
self.bn2 = nn.BatchNorm2d(planes)
|
self.conv1 = conv1x1(inplanes, width)
|
||||||
self.conv3 = nn.Conv2d(planes, planes * Bottleneck.expansion, kernel_size=1, bias=False)
|
self.bn1 = norm_layer(width)
|
||||||
self.bn3 = nn.BatchNorm2d(planes * Bottleneck.expansion)
|
self.conv2 = conv3x3(width, width, stride, groups, dilation)
|
||||||
|
self.bn2 = norm_layer(width)
|
||||||
|
self.conv3 = conv1x1(width, planes * self.expansion)
|
||||||
|
self.bn3 = norm_layer(planes * self.expansion)
|
||||||
self.relu = nn.ReLU(inplace=True)
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
self.downsample = downsample
|
self.downsample = downsample
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
residual = x
|
identity = x
|
||||||
|
|
||||||
out = self.conv1(x)
|
out = self.conv1(x)
|
||||||
out = self.bn1(out)
|
out = self.bn1(out)
|
||||||
|
@ -73,108 +89,199 @@ class Bottleneck(nn.Module):
|
||||||
|
|
||||||
out = self.conv3(out)
|
out = self.conv3(out)
|
||||||
out = self.bn3(out)
|
out = self.bn3(out)
|
||||||
if self.downsample is not None:
|
|
||||||
residual = self.downsample(x)
|
|
||||||
|
|
||||||
out += residual
|
if self.downsample is not None:
|
||||||
|
identity = self.downsample(x)
|
||||||
|
|
||||||
|
out += identity
|
||||||
out = self.relu(out)
|
out = self.relu(out)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
class ResNet(nn.Module):
|
class ResNet(nn.Module):
|
||||||
def __init__(self, dataset, depth, n_classes, bottleneck=False):
|
|
||||||
|
def __init__(self, block, layers, num_classes=10, zero_init_residual=False,
|
||||||
|
groups=1, width_per_group=64, replace_stride_with_dilation=None,
|
||||||
|
norm_layer=None):
|
||||||
super(ResNet, self).__init__()
|
super(ResNet, self).__init__()
|
||||||
self.dataset = dataset
|
if norm_layer is None:
|
||||||
if self.dataset.startswith('cifar'):
|
norm_layer = nn.BatchNorm2d
|
||||||
self.inplanes = 16
|
self._norm_layer = norm_layer
|
||||||
#logger.info(bottleneck)
|
|
||||||
if bottleneck == True:
|
|
||||||
n = int((depth - 2) / 9)
|
|
||||||
block = Bottleneck
|
|
||||||
else:
|
|
||||||
n = int((depth - 2) / 6)
|
|
||||||
block = BasicBlock
|
|
||||||
|
|
||||||
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
|
self.inplanes = 64
|
||||||
self.bn1 = nn.BatchNorm2d(self.inplanes)
|
self.dilation = 1
|
||||||
self.relu = nn.ReLU(inplace=True)
|
if replace_stride_with_dilation is None:
|
||||||
self.layer1 = self._make_layer(block, 16, n)
|
# each element in the tuple indicates if we should replace
|
||||||
self.layer2 = self._make_layer(block, 32, n, stride=2)
|
# the 2x2 stride with a dilated convolution instead
|
||||||
self.layer3 = self._make_layer(block, 64, n, stride=2)
|
replace_stride_with_dilation = [False, False, False]
|
||||||
# self.avgpool = nn.AvgPool2d(8)
|
if len(replace_stride_with_dilation) != 3:
|
||||||
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
raise ValueError("replace_stride_with_dilation should be None "
|
||||||
self.fc = nn.Linear(64 * block.expansion, n_classes)
|
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
|
||||||
|
self.groups = groups
|
||||||
|
self.base_width = width_per_group
|
||||||
|
|
||||||
elif dataset == 'imagenet':
|
## CIFAR10: kernel_size 7 -> 3, stride 2 -> 1, padding 3->1
|
||||||
blocks ={18: BasicBlock, 34: BasicBlock, 50: Bottleneck, 101: Bottleneck, 152: Bottleneck, 200: Bottleneck}
|
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
|
||||||
layers ={18: [2, 2, 2, 2], 34: [3, 4, 6, 3], 50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3], 200: [3, 24, 36, 3]}
|
## END
|
||||||
assert layers[depth], 'invalid detph for ResNet (depth should be one of 18, 34, 50, 101, 152, and 200)'
|
|
||||||
|
|
||||||
self.inplanes = 64
|
self.bn1 = norm_layer(self.inplanes)
|
||||||
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
|
self.relu = nn.ReLU(inplace=True)
|
||||||
self.bn1 = nn.BatchNorm2d(64)
|
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||||
self.relu = nn.ReLU(inplace=True)
|
self.layer1 = self._make_layer(block, 64, layers[0])
|
||||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
|
||||||
self.layer1 = self._make_layer(blocks[depth], 64, layers[depth][0])
|
dilate=replace_stride_with_dilation[0])
|
||||||
self.layer2 = self._make_layer(blocks[depth], 128, layers[depth][1], stride=2)
|
self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
|
||||||
self.layer3 = self._make_layer(blocks[depth], 256, layers[depth][2], stride=2)
|
dilate=replace_stride_with_dilation[1])
|
||||||
self.layer4 = self._make_layer(blocks[depth], 512, layers[depth][3], stride=2)
|
self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
|
||||||
# self.avgpool = nn.AvgPool2d(7)
|
dilate=replace_stride_with_dilation[2])
|
||||||
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
||||||
self.fc = nn.Linear(512 * blocks[depth].expansion, n_classes)
|
self.fc = nn.Linear(512 * block.expansion, num_classes)
|
||||||
|
|
||||||
for m in self.modules():
|
for m in self.modules():
|
||||||
if isinstance(m, nn.Conv2d):
|
if isinstance(m, nn.Conv2d):
|
||||||
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||||
m.weight.data.normal_(0, math.sqrt(2. / n))
|
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||||
elif isinstance(m, nn.BatchNorm2d):
|
nn.init.constant_(m.weight, 1)
|
||||||
m.weight.data.fill_(1)
|
nn.init.constant_(m.bias, 0)
|
||||||
m.bias.data.zero_()
|
|
||||||
|
|
||||||
def _make_layer(self, block, planes, blocks, stride=1):
|
# Zero-initialize the last BN in each residual branch,
|
||||||
|
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
|
||||||
|
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
|
||||||
|
if zero_init_residual:
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, Bottleneck):
|
||||||
|
nn.init.constant_(m.bn3.weight, 0)
|
||||||
|
elif isinstance(m, BasicBlock):
|
||||||
|
nn.init.constant_(m.bn2.weight, 0)
|
||||||
|
|
||||||
|
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
|
||||||
|
norm_layer = self._norm_layer
|
||||||
downsample = None
|
downsample = None
|
||||||
|
previous_dilation = self.dilation
|
||||||
|
if dilate:
|
||||||
|
self.dilation *= stride
|
||||||
|
stride = 1
|
||||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||||
downsample = nn.Sequential(
|
downsample = nn.Sequential(
|
||||||
nn.Conv2d(self.inplanes, planes * block.expansion,
|
conv1x1(self.inplanes, planes * block.expansion, stride),
|
||||||
kernel_size=1, stride=stride, bias=False),
|
norm_layer(planes * block.expansion),
|
||||||
nn.BatchNorm2d(planes * block.expansion),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
layers = []
|
layers = []
|
||||||
layers.append(block(self.inplanes, planes, stride, downsample))
|
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
|
||||||
|
self.base_width, previous_dilation, norm_layer))
|
||||||
self.inplanes = planes * block.expansion
|
self.inplanes = planes * block.expansion
|
||||||
for i in range(1, blocks):
|
for _ in range(1, blocks):
|
||||||
layers.append(block(self.inplanes, planes))
|
layers.append(block(self.inplanes, planes, groups=self.groups,
|
||||||
|
base_width=self.base_width, dilation=self.dilation,
|
||||||
|
norm_layer=norm_layer))
|
||||||
|
|
||||||
return nn.Sequential(*layers)
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
if self.dataset == 'cifar10' or self.dataset == 'cifar100':
|
x = self.conv1(x)
|
||||||
x = self.conv1(x)
|
x = self.bn1(x)
|
||||||
x = self.bn1(x)
|
x = self.relu(x)
|
||||||
x = self.relu(x)
|
x = self.maxpool(x)
|
||||||
|
|
||||||
x = self.layer1(x)
|
x = self.layer1(x)
|
||||||
x = self.layer2(x)
|
x = self.layer2(x)
|
||||||
x = self.layer3(x)
|
x = self.layer3(x)
|
||||||
|
x = self.layer4(x)
|
||||||
|
|
||||||
x = self.avgpool(x)
|
x = self.avgpool(x)
|
||||||
x = x.view(x.size(0), -1)
|
x = x.reshape(x.size(0), -1)
|
||||||
x = self.fc(x)
|
x = self.fc(x)
|
||||||
|
|
||||||
elif self.dataset == 'imagenet':
|
|
||||||
x = self.conv1(x)
|
|
||||||
x = self.bn1(x)
|
|
||||||
x = self.relu(x)
|
|
||||||
x = self.maxpool(x)
|
|
||||||
|
|
||||||
x = self.layer1(x)
|
|
||||||
x = self.layer2(x)
|
|
||||||
x = self.layer3(x)
|
|
||||||
x = self.layer4(x)
|
|
||||||
|
|
||||||
x = self.avgpool(x)
|
|
||||||
x = x.view(x.size(0), -1)
|
|
||||||
x = self.fc(x)
|
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def _resnet(arch, block, layers, pretrained, progress, device, **kwargs):
|
||||||
|
model = ResNet(block, layers, **kwargs)
|
||||||
|
if pretrained:
|
||||||
|
script_dir = os.path.dirname(__file__)
|
||||||
|
state_dict = torch.load(script_dir + '/state_dicts/'+arch+'.pt', map_location=device)
|
||||||
|
model.load_state_dict(state_dict)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def resnet18(pretrained=False, progress=True, device='cpu', **kwargs):
|
||||||
|
"""Constructs a ResNet-18 model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||||
|
progress (bool): If True, displays a progress bar of the download to stderr
|
||||||
|
"""
|
||||||
|
return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, device,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def resnet34(pretrained=False, progress=True, device='cpu', **kwargs):
|
||||||
|
"""Constructs a ResNet-34 model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||||
|
progress (bool): If True, displays a progress bar of the download to stderr
|
||||||
|
"""
|
||||||
|
return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, device,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def resnet50(pretrained=False, progress=True, device='cpu', **kwargs):
|
||||||
|
"""Constructs a ResNet-50 model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||||
|
progress (bool): If True, displays a progress bar of the download to stderr
|
||||||
|
"""
|
||||||
|
return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, device,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def resnet101(pretrained=False, progress=True, device='cpu', **kwargs):
|
||||||
|
"""Constructs a ResNet-101 model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||||
|
progress (bool): If True, displays a progress bar of the download to stderr
|
||||||
|
"""
|
||||||
|
return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, device,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def resnet152(pretrained=False, progress=True, device='cpu', **kwargs):
|
||||||
|
"""Constructs a ResNet-152 model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||||
|
progress (bool): If True, displays a progress bar of the download to stderr
|
||||||
|
"""
|
||||||
|
return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, device,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def resnext50_32x4d(pretrained=False, progress=True, device='cpu', **kwargs):
|
||||||
|
"""Constructs a ResNeXt-50 32x4d model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||||
|
progress (bool): If True, displays a progress bar of the download to stderr
|
||||||
|
"""
|
||||||
|
kwargs['groups'] = 32
|
||||||
|
kwargs['width_per_group'] = 4
|
||||||
|
return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
|
||||||
|
pretrained, progress, device, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def resnext101_32x8d(pretrained=False, progress=True, device='cpu', **kwargs):
|
||||||
|
"""Constructs a ResNeXt-101 32x8d model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||||
|
progress (bool): If True, displays a progress bar of the download to stderr
|
||||||
|
"""
|
||||||
|
kwargs['groups'] = 32
|
||||||
|
kwargs['width_per_group'] = 8
|
||||||
|
return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
|
||||||
|
pretrained, progress, device, **kwargs)
|
||||||
|
|
|
@ -1,287 +0,0 @@
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import os
|
|
||||||
|
|
||||||
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
|
|
||||||
'resnet152', 'resnext50_32x4d', 'resnext101_32x8d']
|
|
||||||
|
|
||||||
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
|
||||||
"""3x3 convolution with padding"""
|
|
||||||
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
|
||||||
padding=dilation, groups=groups, bias=False, dilation=dilation)
|
|
||||||
|
|
||||||
|
|
||||||
def conv1x1(in_planes, out_planes, stride=1):
|
|
||||||
"""1x1 convolution"""
|
|
||||||
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
|
||||||
|
|
||||||
|
|
||||||
class BasicBlock(nn.Module):
|
|
||||||
expansion = 1
|
|
||||||
|
|
||||||
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
|
|
||||||
base_width=64, dilation=1, norm_layer=None):
|
|
||||||
super(BasicBlock, self).__init__()
|
|
||||||
if norm_layer is None:
|
|
||||||
norm_layer = nn.BatchNorm2d
|
|
||||||
if groups != 1 or base_width != 64:
|
|
||||||
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
|
|
||||||
if dilation > 1:
|
|
||||||
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
|
|
||||||
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
|
|
||||||
self.conv1 = conv3x3(inplanes, planes, stride)
|
|
||||||
self.bn1 = norm_layer(planes)
|
|
||||||
self.relu = nn.ReLU(inplace=True)
|
|
||||||
self.conv2 = conv3x3(planes, planes)
|
|
||||||
self.bn2 = norm_layer(planes)
|
|
||||||
self.downsample = downsample
|
|
||||||
self.stride = stride
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
identity = x
|
|
||||||
|
|
||||||
out = self.conv1(x)
|
|
||||||
out = self.bn1(out)
|
|
||||||
out = self.relu(out)
|
|
||||||
|
|
||||||
out = self.conv2(out)
|
|
||||||
out = self.bn2(out)
|
|
||||||
|
|
||||||
if self.downsample is not None:
|
|
||||||
identity = self.downsample(x)
|
|
||||||
|
|
||||||
out += identity
|
|
||||||
out = self.relu(out)
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class Bottleneck(nn.Module):
|
|
||||||
expansion = 4
|
|
||||||
|
|
||||||
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
|
|
||||||
base_width=64, dilation=1, norm_layer=None):
|
|
||||||
super(Bottleneck, self).__init__()
|
|
||||||
if norm_layer is None:
|
|
||||||
norm_layer = nn.BatchNorm2d
|
|
||||||
width = int(planes * (base_width / 64.)) * groups
|
|
||||||
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
|
|
||||||
self.conv1 = conv1x1(inplanes, width)
|
|
||||||
self.bn1 = norm_layer(width)
|
|
||||||
self.conv2 = conv3x3(width, width, stride, groups, dilation)
|
|
||||||
self.bn2 = norm_layer(width)
|
|
||||||
self.conv3 = conv1x1(width, planes * self.expansion)
|
|
||||||
self.bn3 = norm_layer(planes * self.expansion)
|
|
||||||
self.relu = nn.ReLU(inplace=True)
|
|
||||||
self.downsample = downsample
|
|
||||||
self.stride = stride
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
identity = x
|
|
||||||
|
|
||||||
out = self.conv1(x)
|
|
||||||
out = self.bn1(out)
|
|
||||||
out = self.relu(out)
|
|
||||||
|
|
||||||
out = self.conv2(out)
|
|
||||||
out = self.bn2(out)
|
|
||||||
out = self.relu(out)
|
|
||||||
|
|
||||||
out = self.conv3(out)
|
|
||||||
out = self.bn3(out)
|
|
||||||
|
|
||||||
if self.downsample is not None:
|
|
||||||
identity = self.downsample(x)
|
|
||||||
|
|
||||||
out += identity
|
|
||||||
out = self.relu(out)
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class ResNet(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, block, layers, num_classes=10, zero_init_residual=False,
|
|
||||||
groups=1, width_per_group=64, replace_stride_with_dilation=None,
|
|
||||||
norm_layer=None):
|
|
||||||
super(ResNet, self).__init__()
|
|
||||||
if norm_layer is None:
|
|
||||||
norm_layer = nn.BatchNorm2d
|
|
||||||
self._norm_layer = norm_layer
|
|
||||||
|
|
||||||
self.inplanes = 64
|
|
||||||
self.dilation = 1
|
|
||||||
if replace_stride_with_dilation is None:
|
|
||||||
# each element in the tuple indicates if we should replace
|
|
||||||
# the 2x2 stride with a dilated convolution instead
|
|
||||||
replace_stride_with_dilation = [False, False, False]
|
|
||||||
if len(replace_stride_with_dilation) != 3:
|
|
||||||
raise ValueError("replace_stride_with_dilation should be None "
|
|
||||||
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
|
|
||||||
self.groups = groups
|
|
||||||
self.base_width = width_per_group
|
|
||||||
|
|
||||||
## CIFAR10: kernel_size 7 -> 3, stride 2 -> 1, padding 3->1
|
|
||||||
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
|
|
||||||
## END
|
|
||||||
|
|
||||||
self.bn1 = norm_layer(self.inplanes)
|
|
||||||
self.relu = nn.ReLU(inplace=True)
|
|
||||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
|
||||||
self.layer1 = self._make_layer(block, 64, layers[0])
|
|
||||||
self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
|
|
||||||
dilate=replace_stride_with_dilation[0])
|
|
||||||
self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
|
|
||||||
dilate=replace_stride_with_dilation[1])
|
|
||||||
self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
|
|
||||||
dilate=replace_stride_with_dilation[2])
|
|
||||||
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
|
||||||
self.fc = nn.Linear(512 * block.expansion, num_classes)
|
|
||||||
|
|
||||||
for m in self.modules():
|
|
||||||
if isinstance(m, nn.Conv2d):
|
|
||||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
|
||||||
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
|
||||||
nn.init.constant_(m.weight, 1)
|
|
||||||
nn.init.constant_(m.bias, 0)
|
|
||||||
|
|
||||||
# Zero-initialize the last BN in each residual branch,
|
|
||||||
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
|
|
||||||
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
|
|
||||||
if zero_init_residual:
|
|
||||||
for m in self.modules():
|
|
||||||
if isinstance(m, Bottleneck):
|
|
||||||
nn.init.constant_(m.bn3.weight, 0)
|
|
||||||
elif isinstance(m, BasicBlock):
|
|
||||||
nn.init.constant_(m.bn2.weight, 0)
|
|
||||||
|
|
||||||
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
|
|
||||||
norm_layer = self._norm_layer
|
|
||||||
downsample = None
|
|
||||||
previous_dilation = self.dilation
|
|
||||||
if dilate:
|
|
||||||
self.dilation *= stride
|
|
||||||
stride = 1
|
|
||||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
|
||||||
downsample = nn.Sequential(
|
|
||||||
conv1x1(self.inplanes, planes * block.expansion, stride),
|
|
||||||
norm_layer(planes * block.expansion),
|
|
||||||
)
|
|
||||||
|
|
||||||
layers = []
|
|
||||||
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
|
|
||||||
self.base_width, previous_dilation, norm_layer))
|
|
||||||
self.inplanes = planes * block.expansion
|
|
||||||
for _ in range(1, blocks):
|
|
||||||
layers.append(block(self.inplanes, planes, groups=self.groups,
|
|
||||||
base_width=self.base_width, dilation=self.dilation,
|
|
||||||
norm_layer=norm_layer))
|
|
||||||
|
|
||||||
return nn.Sequential(*layers)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.conv1(x)
|
|
||||||
x = self.bn1(x)
|
|
||||||
x = self.relu(x)
|
|
||||||
x = self.maxpool(x)
|
|
||||||
|
|
||||||
x = self.layer1(x)
|
|
||||||
x = self.layer2(x)
|
|
||||||
x = self.layer3(x)
|
|
||||||
x = self.layer4(x)
|
|
||||||
|
|
||||||
x = self.avgpool(x)
|
|
||||||
x = x.reshape(x.size(0), -1)
|
|
||||||
x = self.fc(x)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def _resnet(arch, block, layers, pretrained, progress, device, **kwargs):
|
|
||||||
model = ResNet(block, layers, **kwargs)
|
|
||||||
if pretrained:
|
|
||||||
script_dir = os.path.dirname(__file__)
|
|
||||||
state_dict = torch.load(script_dir + '/state_dicts/'+arch+'.pt', map_location=device)
|
|
||||||
model.load_state_dict(state_dict)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def resnet18(pretrained=False, progress=True, device='cpu', **kwargs):
|
|
||||||
"""Constructs a ResNet-18 model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
|
||||||
progress (bool): If True, displays a progress bar of the download to stderr
|
|
||||||
"""
|
|
||||||
return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, device,
|
|
||||||
**kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def resnet34(pretrained=False, progress=True, device='cpu', **kwargs):
|
|
||||||
"""Constructs a ResNet-34 model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
|
||||||
progress (bool): If True, displays a progress bar of the download to stderr
|
|
||||||
"""
|
|
||||||
return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, device,
|
|
||||||
**kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def resnet50(pretrained=False, progress=True, device='cpu', **kwargs):
|
|
||||||
"""Constructs a ResNet-50 model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
|
||||||
progress (bool): If True, displays a progress bar of the download to stderr
|
|
||||||
"""
|
|
||||||
return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, device,
|
|
||||||
**kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def resnet101(pretrained=False, progress=True, device='cpu', **kwargs):
|
|
||||||
"""Constructs a ResNet-101 model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
|
||||||
progress (bool): If True, displays a progress bar of the download to stderr
|
|
||||||
"""
|
|
||||||
return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, device,
|
|
||||||
**kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def resnet152(pretrained=False, progress=True, device='cpu', **kwargs):
|
|
||||||
"""Constructs a ResNet-152 model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
|
||||||
progress (bool): If True, displays a progress bar of the download to stderr
|
|
||||||
"""
|
|
||||||
return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, device,
|
|
||||||
**kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def resnext50_32x4d(pretrained=False, progress=True, device='cpu', **kwargs):
|
|
||||||
"""Constructs a ResNeXt-50 32x4d model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
|
||||||
progress (bool): If True, displays a progress bar of the download to stderr
|
|
||||||
"""
|
|
||||||
kwargs['groups'] = 32
|
|
||||||
kwargs['width_per_group'] = 4
|
|
||||||
return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
|
|
||||||
pretrained, progress, device, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def resnext101_32x8d(pretrained=False, progress=True, device='cpu', **kwargs):
|
|
||||||
"""Constructs a ResNeXt-101 32x8d model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
|
||||||
progress (bool): If True, displays a progress bar of the download to stderr
|
|
||||||
"""
|
|
||||||
kwargs['groups'] = 32
|
|
||||||
kwargs['width_per_group'] = 8
|
|
||||||
return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
|
|
||||||
pretrained, progress, device, **kwargs)
|
|
|
@ -5,7 +5,7 @@ import math
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from archai.supergraph.nas.models.shakeshake.shakeshake import ShakeShake, Shortcut
|
from .shakeshake import ShakeShake, Shortcut
|
||||||
|
|
||||||
|
|
||||||
class ShakeBlock(nn.Module):
|
class ShakeBlock(nn.Module):
|
||||||
|
|
|
@ -5,7 +5,7 @@ import math
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from archai.supergraph.nas.models.shakeshake.shakeshake import ShakeShake, Shortcut
|
from .shakeshake import ShakeShake, Shortcut
|
||||||
|
|
||||||
|
|
||||||
class ShakeBottleNeck(nn.Module):
|
class ShakeBottleNeck(nn.Module):
|
||||||
|
|
|
@ -95,9 +95,9 @@ class Evaluater(EnforceOverrides):
|
||||||
# TODO: below detection code is too week, need to improve, possibly encode image size in yaml and use that instead
|
# TODO: below detection code is too week, need to improve, possibly encode image size in yaml and use that instead
|
||||||
if dataset_name.startswith('cifar'):
|
if dataset_name.startswith('cifar'):
|
||||||
if function_name.startswith('res'): # support resnext as well
|
if function_name.startswith('res'): # support resnext as well
|
||||||
module_name = 'archai.cifar10_models.resnet'
|
module_name = 'archai.supergraph.models.resnet'
|
||||||
elif function_name.startswith('dense'):
|
elif function_name.startswith('dense'):
|
||||||
module_name = 'archai.cifar10_models.densenet'
|
module_name = 'archai.supergraph.models.densenet'
|
||||||
elif dataset_name.startswith('imagenet') or dataset_name.startswith('sport8'):
|
elif dataset_name.startswith('imagenet') or dataset_name.startswith('sport8'):
|
||||||
module_name = 'torchvision.models'
|
module_name = 'torchvision.models'
|
||||||
if not module_name:
|
if not module_name:
|
||||||
|
|
|
@ -3,8 +3,6 @@ __include__: 'darts.yaml' # defaults are loaded from this file
|
||||||
common:
|
common:
|
||||||
#yaml_log: False
|
#yaml_log: False
|
||||||
apex:
|
apex:
|
||||||
enabled: False # global switch to disable everything apex
|
|
||||||
distributed_enabled: False # enable/disable distributed mode
|
|
||||||
ray:
|
ray:
|
||||||
enabled: True # initialize ray. Note: ray cannot be used if apex distributed is enabled
|
enabled: True # initialize ray. Note: ray cannot be used if apex distributed is enabled
|
||||||
local_mode: False # if True then ray runs in serial mode
|
local_mode: False # if True then ray runs in serial mode
|
||||||
|
@ -12,50 +10,31 @@ common:
|
||||||
nas:
|
nas:
|
||||||
eval:
|
eval:
|
||||||
final_desc_foldername: '$expdir/model_desc_gallery' #
|
final_desc_foldername: '$expdir/model_desc_gallery' #
|
||||||
source_desc_foldername: '$expdir/model_desc_gallery'
|
|
||||||
model_desc:
|
model_desc:
|
||||||
n_reductions: 2 # number of reductions to be applied
|
n_reductions: 2 # number of reductions to be applied
|
||||||
n_cells: 10 # number of max cells, for pareto frontier, we use cell_count_scale to multiply cells and limit by n_cells
|
n_cells: 20 # number of max cells, for pareto frontier, we use cell_count_scale to multiply cells and limit by n_cells
|
||||||
aux_weight: 0.0 # weight for loss from auxiliary towers in test time arch
|
aux_weight: 0.4 # weight for loss from auxiliary towers in test time arch
|
||||||
num_edges_to_sample: 2 # number of edges each node will take inputs from
|
num_edges_to_sample: 2 # number of edges each node will take inputs from
|
||||||
aux_tower_stride: 3
|
|
||||||
model_stems:
|
model_stems:
|
||||||
ops: ['stem_conv3x3_s2', 'stem_conv3x3_s2']
|
init_node_ch: 36 # num of input/output channels for nodes in 1st cell
|
||||||
init_node_ch: 32 # num of input/output channels for nodes in 1st cell
|
|
||||||
stem_multiplier: 1 # output channels multiplier for the stem
|
|
||||||
cell:
|
cell:
|
||||||
n_nodes: 5 # number of nodes in a cell if template desc is not provided
|
n_nodes: 5 # number of nodes in a cell if template desc is not provided
|
||||||
cell_post_op: 'proj_channels'
|
cell_post_op: 'proj_channels'
|
||||||
petridish:
|
petridish:
|
||||||
cell_count_scale: 1.0 # for eval first multiply number of cells used in search by this factor, limit to n_cells
|
cell_count_scale: 1.0 # for eval first multiply number of cells used in search by this factor, limit to n_cells
|
||||||
trainer:
|
trainer:
|
||||||
aux_weight: 0.0
|
epochs: 600
|
||||||
epochs: 1500
|
|
||||||
batch_chunks: 1
|
|
||||||
validation:
|
|
||||||
batch_chunks: 1
|
|
||||||
optimizer:
|
|
||||||
lr: 0.033
|
|
||||||
loader:
|
|
||||||
cutout: 6 # cutout length, use cutout augmentation when > 0
|
|
||||||
load_train: True # load train split of dataset
|
|
||||||
train_batch: 32
|
|
||||||
test_batch: 32
|
|
||||||
img_size: 16
|
|
||||||
aug: 'autoaug_cifar10'
|
|
||||||
# dataset:
|
|
||||||
# max_batches: 32
|
|
||||||
|
|
||||||
search:
|
search:
|
||||||
final_desc_foldername: '$expdir/model_desc_gallery' # the gallery of models that eval will train from scratch
|
final_desc_foldername: '$expdir/model_desc_gallery' # the gallery of models that eval will train from scratch
|
||||||
petridish:
|
petridish:
|
||||||
convex_hull_eps: 0.025 # tolerance
|
convex_hull_eps: 0.025 # tolerance
|
||||||
max_madd: 20000000 # if any parent model reaches this many multiply-additions then the search is terminated or it reaches maximum number of parent pool size
|
max_madd: 200000000 # if any parent model reaches this many multiply-additions then the search is terminated or it reaches maximum number of parent pool size
|
||||||
max_hull_points: 100 # if the pool of parent models reaches this size then search is terminated or if it reaches max multiply-adds
|
max_hull_points: 100 # if the pool of parent models reaches this size then search is terminated or if it reaches max multiply-adds
|
||||||
checkpoints_foldername: '$expdir/petridish_search_checkpoints'
|
checkpoints_foldername: '$expdir/petridish_search_checkpoints'
|
||||||
|
search_iters: 4
|
||||||
pareto:
|
pareto:
|
||||||
max_cells: 10
|
max_cells: 8
|
||||||
max_reductions: 2
|
max_reductions: 3
|
||||||
max_nodes: 3
|
max_nodes: 3
|
||||||
enabled: True # if false then there will only be one seed model. if true a number of seed models with different number of cells, reductions and nodes will be used to initialize the search. this provides more coverage of the frontier.
|
enabled: True # if false then there will only be one seed model. if true a number of seed models with different number of cells, reductions and nodes will be used to initialize the search. this provides more coverage of the frontier.
|
||||||
model_desc:
|
model_desc:
|
||||||
|
@ -63,52 +42,21 @@ nas:
|
||||||
n_reductions: 1
|
n_reductions: 1
|
||||||
num_edges_to_sample: 2 # number of edges each node will take inputs from
|
num_edges_to_sample: 2 # number of edges each node will take inputs from
|
||||||
cell:
|
cell:
|
||||||
n_nodes: 1 # also used as min nodes to get combinations for seeding pareto
|
n_nodes: 1
|
||||||
cell_post_op: 'proj_channels'
|
cell_post_op: 'proj_channels'
|
||||||
model_stems:
|
|
||||||
ops: ['stem_conv3x3_s2', 'stem_conv3x3_s2']
|
|
||||||
stem_multiplier: 1 # output channels multiplier for the stem
|
|
||||||
init_node_ch: 32 # num of input/output channels for nodes in 1st cell
|
|
||||||
seed_train:
|
seed_train:
|
||||||
trainer:
|
trainer:
|
||||||
epochs: 80 # number of epochs model will be trained before search
|
epochs: 80 # number of epochs model will be trained before search
|
||||||
optimizer:
|
|
||||||
lr: 0.033
|
|
||||||
batch_chunks: 1
|
|
||||||
validation:
|
|
||||||
batch_chunks: 1
|
|
||||||
loader:
|
loader:
|
||||||
cutout: 6
|
train_batch: 128
|
||||||
train_batch: 32
|
|
||||||
test_batch: 32
|
|
||||||
img_size: 16
|
|
||||||
aug: ''
|
|
||||||
# dataset:
|
|
||||||
# max_batches: 32
|
|
||||||
post_train:
|
post_train:
|
||||||
trainer:
|
trainer:
|
||||||
epochs: 80 # number of epochs model will be trained after search
|
epochs: 80 # number of epochs model will be trained after search
|
||||||
optimizer:
|
|
||||||
lr: 0.033
|
|
||||||
loader:
|
loader:
|
||||||
train_batch: 32
|
train_batch: 96
|
||||||
cutout: 6
|
|
||||||
test_batch: 32
|
|
||||||
img_size: 16
|
|
||||||
aug: ''
|
|
||||||
# dataset:
|
|
||||||
# max_batches: 32
|
|
||||||
trainer:
|
trainer:
|
||||||
l1_alphas: 0.001 # as per paper
|
l1_alphas: 0.001 # as per paper
|
||||||
epochs: 80 # number of epochs model will be trained during search
|
epochs: 80 # number of epochs model will be trained during search
|
||||||
optimizer:
|
|
||||||
lr: 0.033
|
|
||||||
loader:
|
loader:
|
||||||
train_batch: 32
|
train_batch: 96
|
||||||
val_ratio: 0.2 #split portion for train set, 0 to 1
|
val_ratio: 0.2 #split portion for test set, 0 to 1
|
||||||
cutout: 6
|
|
||||||
test_batch: 32
|
|
||||||
img_size: 16
|
|
||||||
aug: ''
|
|
||||||
# dataset:
|
|
||||||
# max_batches: 32
|
|
Загрузка…
Ссылка в новой задаче