All-No_pareto functional again

Shital Shah 2023-01-21 02:41:49 -08:00
Parent fd1fb0ab44
Commit 096b564983
24 changed files with 280 additions and 511 deletions

View file

@@ -214,7 +214,7 @@ class ApexUtils:
         else:
             return val

-    def _get_optim(self, multi_optim:MultiOptim)->Optimizer:
+    def _get_one_optim(self, multi_optim:MultiOptim)->Optimizer:
         assert len(multi_optim)==1, \
             'Mixed precision is only supported for one optimizer' \
             f' but {len(multi_optim)} optimizers were supplied'
@@ -234,7 +234,10 @@ class ApexUtils:
     def step(self, multi_optim:MultiOptim)->None:
         if self.is_mixed():
-            self._scaler.step(self._get_optim(multi_optim)) # pyright: ignore[reportOptionalMemberAccess]
+            # self._scaler.unscale_ will be called automatically if it isn't called yet from grad clipping
+            # https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler.step
+            for optim_shed in multi_optim:
+                self._scaler.step(optim_shed.optim) # pyright: ignore[reportOptionalMemberAccess]
             self._scaler.update() # pyright: ignore[reportOptionalMemberAccess]
         else:
             multi_optim.step()
@@ -249,12 +252,13 @@ class ApexUtils:
         model = model.to(self.device)

         # scale LR
-        optim = self._get_optim(multi_optim)
         if self.is_dist() and self._scale_lr:
-            lr = ml_utils.get_optim_lr(optim)
-            scaled_lr = lr * self.world_size / float(batch_size)
-            ml_utils.set_optim_lr(optim, scaled_lr)
-            self._log_info({'lr_scaled': True, 'old_lr': lr, 'new_lr': scaled_lr})
+            for optim_shed in multi_optim:
+                optim = optim_shed.optim
+                lr = ml_utils.get_optim_lr(optim)
+                scaled_lr = lr * self.world_size / float(batch_size)
+                ml_utils.set_optim_lr(optim, scaled_lr)
+                self._log_info({'lr_scaled': True, 'old_lr': lr, 'new_lr': scaled_lr})

         if self.is_dist():
             model = DistributedDataParallel(model, device_ids=[self._gpu], output_device=self._gpu)
@@ -264,8 +268,8 @@ class ApexUtils:
     def clip_grad(self, clip:float, model:nn.Module, multi_optim:MultiOptim)->None:
         if clip > 0.0:
             if self.is_mixed():
-                optim = self._get_optim(multi_optim)
-                self._scaler.unscale_(optim) # pyright: ignore[reportOptionalMemberAccess]
+                # https://pytorch.org/docs/stable/notes/amp_examples.html#working-with-multiple-models-losses-and-optimizers
+                self._scaler.unscale_(multi_optim[0].optim) # pyright: ignore[reportOptionalMemberAccess]
                 nn.utils.clip_grad_norm_(model.parameters(), clip)
             else:
                 nn.utils.clip_grad_norm_(model.parameters(), clip)
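For context, the step() and clip_grad() changes above follow the multi-optimizer GradScaler recipe from the PyTorch AMP documentation linked in the comments. A minimal sketch of that recipe, with illustrative names (model, optimizers, loss, clip) rather than the commit's own classes:

# Hedged sketch (not part of the commit): the standard torch.cuda.amp pattern
# the new code follows when several optimizers share one GradScaler.
import torch
from torch import nn

def amp_step(model: nn.Module, optimizers, scaler: torch.cuda.amp.GradScaler,
             loss: torch.Tensor, clip: float = 0.0) -> None:
    scaler.scale(loss).backward()           # backward on the scaled loss
    if clip > 0.0:
        # unscale_ may be called at most once per optimizer per iteration before
        # clipping; GradScaler.step() then skips the redundant internal unscale
        for optim in optimizers:
            scaler.unscale_(optim)
        nn.utils.clip_grad_norm_(model.parameters(), clip)
    for optim in optimizers:
        scaler.step(optim)                  # steps only if gradients are finite
    scaler.update()                         # exactly one update per iteration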

View file

@@ -20,8 +20,8 @@ from archai.supergraph.nas.model import Model
 from archai.supergraph.utils import ml_utils
 from archai.supergraph.utils.checkpoint import CheckPoint
 from archai.supergraph.datasets import data
-from archai.common.logger import Logger
-logger = Logger(source=__name__)
+from archai.common.common import logger
 from archai.supergraph.algos.darts.bilevel_optimizer import BilevelOptimizer

 class BilevelArchTrainer(ArchTrainer):

View file

@@ -12,8 +12,8 @@ from torch.optim.optimizer import Optimizer
 from archai.common.config import Config
 from archai.common import utils
 from archai.supergraph.nas.model import Model
-from archai.common.logger import Logger
-logger = Logger(source=__name__)
+from archai.common.common import logger
 from archai.common.utils import zip_eq
 from archai.supergraph.utils import ml_utils

View file

@@ -12,8 +12,8 @@ from torch.optim.optimizer import Optimizer
 from archai.common.config import Config
 from archai.common import utils
 from archai.supergraph.nas.model import Model
-from archai.common.logger import Logger
-logger = Logger(source=__name__)
+from archai.common.common import logger
 from archai.common.utils import zip_eq
 from archai.supergraph.utils import ml_utils

View file

@@ -19,8 +19,8 @@ from archai.common import utils
 from archai.supergraph.nas.model import Model
 from archai.supergraph.utils import ml_utils
 from archai.supergraph.utils.checkpoint import CheckPoint
-from archai.common.logger import Logger
-logger = Logger(source=__name__)
+from archai.common.common import logger
 from archai.supergraph.utils.multi_optim import MultiOptim, OptimSched

 class DidartsArchTrainer(ArchTrainer):

View file

@@ -10,8 +10,8 @@ from torch import nn
 import numpy as np
 from archai.common.common import get_conf
-from archai.common.logger import Logger
-logger = Logger(source=__name__)
+from archai.common.common import logger
 from archai.supergraph.datasets.data import get_data
 from archai.supergraph.nas.model import Model
 from archai.supergraph.nas.cell import Cell

View file

@@ -14,8 +14,8 @@ import os
 from archai.common.common import get_conf
 from archai.common.common import get_expdir
-from archai.common.logger import Logger
-logger = Logger(source=__name__)
+from archai.common.common import logger
 from archai.supergraph.datasets.data import get_data
 from archai.supergraph.nas.model import Model
 from archai.supergraph.nas.cell import Cell

View file

@@ -20,8 +20,8 @@ from archai.common import utils
 from archai.supergraph.nas.model import Model
 from archai.supergraph.utils import ml_utils
 from archai.supergraph.utils.checkpoint import CheckPoint
-from archai.common.logger import Logger
-logger = Logger(source=__name__)
+from archai.common.common import logger
 from archai.common.common import get_conf
 from archai.supergraph.algos.gumbelsoftmax.gs_op import GsOp

View file

@@ -11,8 +11,8 @@ import os
 from archai.common.common import get_conf
 from archai.common.common import get_expdir
-from archai.common.logger import Logger
-logger = Logger(source=__name__)
+from archai.common.common import logger
 from archai.supergraph.datasets.data import get_data
 from archai.supergraph.nas.model import Model
 from archai.supergraph.nas.cell import Cell

View file

@@ -17,8 +17,8 @@ from archai.supergraph.utils import ml_utils
 from archai.supergraph.utils.trainer import Trainer
 from archai.common.config import Config
-from archai.common.logger import Logger
-logger = Logger(source=__name__)
+from archai.common.common import logger
 from archai.supergraph.datasets import data
 from archai.supergraph.nas.model_desc import ModelDesc
 from archai.supergraph.nas.model_desc_builder import ModelDescBuilder

View file

@@ -11,8 +11,8 @@ from overrides import overrides
 from torch.utils.data.dataloader import DataLoader
-from archai.common.logger import Logger
-logger = Logger(source=__name__)
+from archai.common.common import logger
 from archai.common.config import Config
 from archai.supergraph.nas.model_desc_builder import ModelDescBuilder

View file

@@ -77,12 +77,9 @@ class PetridishOp(Op):
         'avg_pool_3x3',
         'skip_connect', # identity
         'sep_conv_3x3',
-        #'sep_conv_5x5',
+        'sep_conv_5x5',
         'dil_conv_3x3',
-        #'dil_conv_5x5',
-        'mbconv_r3',
-        'mbconv_r2',
-        'mbconv_r1',
+        'dil_conv_5x5',
         'none' # this must be at the end so top1 doesn't chose it
     ]
@@ -204,3 +201,4 @@ class PetridishOp(Op):
         # we store alphas in list so Pytorch don't register them
         self._alphas = list(self.arch_params().paramlist_by_kind('alphas'))
         assert len(self._alphas)==1

View file

@@ -18,8 +18,8 @@ import matplotlib.pyplot as plt
 from archai.supergraph.nas.model_desc import ConvMacroParams, CellDesc, CellType, OpDesc, \
                                              EdgeDesc, TensorShape, TensorShapes, NodeDesc, ModelDesc
 from archai.supergraph.utils.metrics import Metrics
-from archai.common.logger import Logger
-logger = Logger(source=__name__)
+from archai.common.common import logger
 from archai.common import utils

 class JobStage(Enum):

View file

@@ -25,8 +25,8 @@ from torch.utils.data.dataloader import DataLoader
 import yaml
 from archai.common import common
-from archai.common.logger import Logger
-logger = Logger(source=__name__)
+from archai.common.common import logger
 from archai.common.common import CommonState
 from archai.supergraph.utils.checkpoint import CheckPoint
 from archai.common.config import Config

View file

@@ -21,8 +21,8 @@ from archai.supergraph.nas.model import Model
 from archai.supergraph.nas.model_desc import CellType
 from archai.supergraph.utils import ml_utils
 from archai.supergraph.utils.checkpoint import CheckPoint
-from archai.common.logger import Logger
-logger = Logger(source=__name__)
+from archai.common.common import logger
 from archai.supergraph.datasets import data
 from archai.common.common import get_conf
 from archai.supergraph.algos.xnas.xnas_op import XnasOp

View file

@@ -11,9 +11,9 @@ from torchvision.transforms import transforms
 from torch.utils.data.dataset import Dataset
 from torchvision.datasets.utils import check_integrity, download_url
 from archai.common.utils import download_and_extract_tar, extract_tar
-from archai.common.logger import Logger
-logger = Logger(source=__name__)
+from archai.common.common import logger

View file

@@ -4,19 +4,18 @@ from torch import nn
 from torch.nn import DataParallel
 # from torchvision import models

-from archai.supergraph.nas.models.resnet import ResNet
-from archai.supergraph.nas.models.pyramidnet import PyramidNet
-from archai.supergraph.nas.models.shakeshake.shake_resnet import ShakeResNet
-from archai.supergraph.nas.models.wideresnet import WideResNet
-from archai.supergraph.nas.models.shakeshake.shake_resnext import ShakeResNeXt
-from archai.supergraph.nas.models.mobilenetv2 import *
-from archai.supergraph.nas.models.resnet_cifar10 import *
-from archai.supergraph.nas.models.vgg import *
-from archai.supergraph.nas.models.densenet import *
-from archai.supergraph.nas.models.resnet_orig import *
-from archai.supergraph.nas.models.googlenet import *
-from archai.supergraph.nas.models.inception import *
+from .pyramidnet import PyramidNet
+from .shakeshake.shake_resnet import ShakeResNet
+from .wideresnet import WideResNet
+from .shakeshake.shake_resnext import ShakeResNeXt
+from .mobilenetv2 import *
+from .resnet import *
+from .vgg import *
+from .densenet import *
+from .resnet_orig import *
+from .googlenet import *
+from .inception import *

 def get_model(conf, num_class=10):

View file

@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 import math

-from archai.supergraph.nas.models.shakedrop import ShakeDrop
+from .shakedrop import ShakeDrop

 def conv3x3(in_planes, out_planes, stride=1):

View file

@@ -1,31 +1,44 @@
-# Original code: https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
+import torch
 import torch.nn as nn
-import math
+import os
+
+__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
+           'resnet152', 'resnext50_32x4d', 'resnext101_32x8d']

-def conv3x3(in_planes, out_planes, stride=1):
-    "3x3 convolution with padding"
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+    """3x3 convolution with padding"""
     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
-                     padding=1, bias=False)
+                     padding=dilation, groups=groups, bias=False, dilation=dilation)
+
+def conv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

 class BasicBlock(nn.Module):
     expansion = 1

-    def __init__(self, inplanes, planes, stride=1, downsample=None):
+    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+                 base_width=64, dilation=1, norm_layer=None):
         super(BasicBlock, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        if groups != 1 or base_width != 64:
+            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+        if dilation > 1:
+            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
         self.conv1 = conv3x3(inplanes, planes, stride)
-        self.bn1 = nn.BatchNorm2d(planes)
-        self.conv2 = conv3x3(planes, planes)
-        self.bn2 = nn.BatchNorm2d(planes)
+        self.bn1 = norm_layer(planes)
         self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = norm_layer(planes)
         self.downsample = downsample
         self.stride = stride

     def forward(self, x):
-        residual = x
+        identity = x

         out = self.conv1(x)
         out = self.bn1(out)
@@ -35,9 +48,9 @@ class BasicBlock(nn.Module):
         out = self.bn2(out)

         if self.downsample is not None:
-            residual = self.downsample(x)
+            identity = self.downsample(x)

-        out += residual
+        out += identity
         out = self.relu(out)

         return out
@@ -46,22 +59,25 @@ class BasicBlock(nn.Module):
 class Bottleneck(nn.Module):
     expansion = 4

-    def __init__(self, inplanes, planes, stride=1, downsample=None):
+    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+                 base_width=64, dilation=1, norm_layer=None):
         super(Bottleneck, self).__init__()
-        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
-        self.bn1 = nn.BatchNorm2d(planes)
-        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
-        self.bn2 = nn.BatchNorm2d(planes)
-        self.conv3 = nn.Conv2d(planes, planes * Bottleneck.expansion, kernel_size=1, bias=False)
-        self.bn3 = nn.BatchNorm2d(planes * Bottleneck.expansion)
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        width = int(planes * (base_width / 64.)) * groups
+        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
+        self.conv1 = conv1x1(inplanes, width)
+        self.bn1 = norm_layer(width)
+        self.conv2 = conv3x3(width, width, stride, groups, dilation)
+        self.bn2 = norm_layer(width)
+        self.conv3 = conv1x1(width, planes * self.expansion)
+        self.bn3 = norm_layer(planes * self.expansion)
         self.relu = nn.ReLU(inplace=True)
         self.downsample = downsample
         self.stride = stride

     def forward(self, x):
-        residual = x
+        identity = x

         out = self.conv1(x)
         out = self.bn1(out)
@@ -73,108 +89,199 @@ class Bottleneck(nn.Module):
         out = self.conv3(out)
         out = self.bn3(out)

-        if self.downsample is not None:
-            residual = self.downsample(x)
-
-        out += residual
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
         out = self.relu(out)

         return out

 class ResNet(nn.Module):
-    def __init__(self, dataset, depth, n_classes, bottleneck=False):
+    def __init__(self, block, layers, num_classes=10, zero_init_residual=False,
+                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
+                 norm_layer=None):
         super(ResNet, self).__init__()
-        self.dataset = dataset
-        if self.dataset.startswith('cifar'):
-            self.inplanes = 16
-            #logger.info(bottleneck)
-            if bottleneck == True:
-                n = int((depth - 2) / 9)
-                block = Bottleneck
-            else:
-                n = int((depth - 2) / 6)
-                block = BasicBlock
-
-            self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
-            self.bn1 = nn.BatchNorm2d(self.inplanes)
-            self.relu = nn.ReLU(inplace=True)
-            self.layer1 = self._make_layer(block, 16, n)
-            self.layer2 = self._make_layer(block, 32, n, stride=2)
-            self.layer3 = self._make_layer(block, 64, n, stride=2)
-            # self.avgpool = nn.AvgPool2d(8)
-            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
-            self.fc = nn.Linear(64 * block.expansion, n_classes)
-
-        elif dataset == 'imagenet':
-            blocks ={18: BasicBlock, 34: BasicBlock, 50: Bottleneck, 101: Bottleneck, 152: Bottleneck, 200: Bottleneck}
-            layers ={18: [2, 2, 2, 2], 34: [3, 4, 6, 3], 50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3], 200: [3, 24, 36, 3]}
-            assert layers[depth], 'invalid detph for ResNet (depth should be one of 18, 34, 50, 101, 152, and 200)'
-
-            self.inplanes = 64
-            self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
-            self.bn1 = nn.BatchNorm2d(64)
-            self.relu = nn.ReLU(inplace=True)
-            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
-            self.layer1 = self._make_layer(blocks[depth], 64, layers[depth][0])
-            self.layer2 = self._make_layer(blocks[depth], 128, layers[depth][1], stride=2)
-            self.layer3 = self._make_layer(blocks[depth], 256, layers[depth][2], stride=2)
-            self.layer4 = self._make_layer(blocks[depth], 512, layers[depth][3], stride=2)
-            # self.avgpool = nn.AvgPool2d(7)
-            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
-            self.fc = nn.Linear(512 * blocks[depth].expansion, n_classes)
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        self._norm_layer = norm_layer
+
+        self.inplanes = 64
+        self.dilation = 1
+        if replace_stride_with_dilation is None:
+            # each element in the tuple indicates if we should replace
+            # the 2x2 stride with a dilated convolution instead
+            replace_stride_with_dilation = [False, False, False]
+        if len(replace_stride_with_dilation) != 3:
+            raise ValueError("replace_stride_with_dilation should be None "
+                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+        self.groups = groups
+        self.base_width = width_per_group
+
+        ## CIFAR10: kernel_size 7 -> 3, stride 2 -> 1, padding 3->1
+        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
+        ## END
+
+        self.bn1 = norm_layer(self.inplanes)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
+                                       dilate=replace_stride_with_dilation[0])
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
+                                       dilate=replace_stride_with_dilation[1])
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
+                                       dilate=replace_stride_with_dilation[2])
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.fc = nn.Linear(512 * block.expansion, num_classes)

         for m in self.modules():
             if isinstance(m, nn.Conv2d):
-                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
-                m.weight.data.normal_(0, math.sqrt(2. / n))
-            elif isinstance(m, nn.BatchNorm2d):
-                m.weight.data.fill_(1)
-                m.bias.data.zero_()
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)

-    def _make_layer(self, block, planes, blocks, stride=1):
+        # Zero-initialize the last BN in each residual branch,
+        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+        if zero_init_residual:
+            for m in self.modules():
+                if isinstance(m, Bottleneck):
+                    nn.init.constant_(m.bn3.weight, 0)
+                elif isinstance(m, BasicBlock):
+                    nn.init.constant_(m.bn2.weight, 0)
+
+    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+        norm_layer = self._norm_layer
         downsample = None
+        previous_dilation = self.dilation
+        if dilate:
+            self.dilation *= stride
+            stride = 1
         if stride != 1 or self.inplanes != planes * block.expansion:
             downsample = nn.Sequential(
-                nn.Conv2d(self.inplanes, planes * block.expansion,
-                          kernel_size=1, stride=stride, bias=False),
-                nn.BatchNorm2d(planes * block.expansion),
+                conv1x1(self.inplanes, planes * block.expansion, stride),
+                norm_layer(planes * block.expansion),
             )

         layers = []
-        layers.append(block(self.inplanes, planes, stride, downsample))
+        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
+                            self.base_width, previous_dilation, norm_layer))
         self.inplanes = planes * block.expansion
-        for i in range(1, blocks):
-            layers.append(block(self.inplanes, planes))
+        for _ in range(1, blocks):
+            layers.append(block(self.inplanes, planes, groups=self.groups,
+                                base_width=self.base_width, dilation=self.dilation,
+                                norm_layer=norm_layer))

         return nn.Sequential(*layers)

     def forward(self, x):
-        if self.dataset == 'cifar10' or self.dataset == 'cifar100':
-            x = self.conv1(x)
-            x = self.bn1(x)
-            x = self.relu(x)
-            x = self.layer1(x)
-            x = self.layer2(x)
-            x = self.layer3(x)
-            x = self.avgpool(x)
-            x = x.view(x.size(0), -1)
-            x = self.fc(x)
-
-        elif self.dataset == 'imagenet':
-            x = self.conv1(x)
-            x = self.bn1(x)
-            x = self.relu(x)
-            x = self.maxpool(x)
-            x = self.layer1(x)
-            x = self.layer2(x)
-            x = self.layer3(x)
-            x = self.layer4(x)
-            x = self.avgpool(x)
-            x = x.view(x.size(0), -1)
-            x = self.fc(x)
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        x = self.avgpool(x)
+        x = x.reshape(x.size(0), -1)
+        x = self.fc(x)

         return x
+
+def _resnet(arch, block, layers, pretrained, progress, device, **kwargs):
+    model = ResNet(block, layers, **kwargs)
+    if pretrained:
+        script_dir = os.path.dirname(__file__)
+        state_dict = torch.load(script_dir + '/state_dicts/'+arch+'.pt', map_location=device)
+        model.load_state_dict(state_dict)
+    return model
+
+def resnet18(pretrained=False, progress=True, device='cpu', **kwargs):
+    """Constructs a ResNet-18 model.
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, device,
+                   **kwargs)
+
+def resnet34(pretrained=False, progress=True, device='cpu', **kwargs):
+    """Constructs a ResNet-34 model.
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, device,
+                   **kwargs)
+
+def resnet50(pretrained=False, progress=True, device='cpu', **kwargs):
+    """Constructs a ResNet-50 model.
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, device,
+                   **kwargs)
+
+def resnet101(pretrained=False, progress=True, device='cpu', **kwargs):
+    """Constructs a ResNet-101 model.
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, device,
+                   **kwargs)
+
+def resnet152(pretrained=False, progress=True, device='cpu', **kwargs):
+    """Constructs a ResNet-152 model.
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, device,
+                   **kwargs)
+
+def resnext50_32x4d(pretrained=False, progress=True, device='cpu', **kwargs):
+    """Constructs a ResNeXt-50 32x4d model.
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    kwargs['groups'] = 32
+    kwargs['width_per_group'] = 4
+    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
+                   pretrained, progress, device, **kwargs)
+
+def resnext101_32x8d(pretrained=False, progress=True, device='cpu', **kwargs):
+    """Constructs a ResNeXt-101 32x8d model.
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    kwargs['groups'] = 32
+    kwargs['width_per_group'] = 8
+    return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
+                   pretrained, progress, device, **kwargs)
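The replacement file keeps the torchvision-style factory interface but defaults to 10 classes and a CIFAR-sized 3x3/stride-1 stem. A hedged usage sketch; the import path is an assumption based on the Evaluater change later in this commit, not something the diff states explicitly:

# Hedged usage sketch, not part of the commit. The module path
# 'archai.supergraph.models.resnet' is assumed from the Evaluater mapping below.
import torch
from archai.supergraph.models.resnet import resnet18

model = resnet18(pretrained=False)            # num_classes defaults to 10
logits = model(torch.randn(2, 3, 32, 32))     # CIFAR-sized 32x32 inputs work directly
print(logits.shape)                           # torch.Size([2, 10])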

View file

@@ -1,287 +0,0 @@
import torch
import torch.nn as nn
import os
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
'resnet152', 'resnext50_32x4d', 'resnext101_32x8d']
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=dilation, groups=groups, bias=False, dilation=dilation)
def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
super(BasicBlock, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
if groups != 1 or base_width != 64:
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
if dilation > 1:
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = norm_layer(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = norm_layer(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
super(Bottleneck, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
width = int(planes * (base_width / 64.)) * groups
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv1x1(inplanes, width)
self.bn1 = norm_layer(width)
self.conv2 = conv3x3(width, width, stride, groups, dilation)
self.bn2 = norm_layer(width)
self.conv3 = conv1x1(width, planes * self.expansion)
self.bn3 = norm_layer(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class ResNet(nn.Module):
def __init__(self, block, layers, num_classes=10, zero_init_residual=False,
groups=1, width_per_group=64, replace_stride_with_dilation=None,
norm_layer=None):
super(ResNet, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self._norm_layer = norm_layer
self.inplanes = 64
self.dilation = 1
if replace_stride_with_dilation is None:
# each element in the tuple indicates if we should replace
# the 2x2 stride with a dilated convolution instead
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError("replace_stride_with_dilation should be None "
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
self.groups = groups
self.base_width = width_per_group
## CIFAR10: kernel_size 7 -> 3, stride 2 -> 1, padding 3->1
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
## END
self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
dilate=replace_stride_with_dilation[0])
self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
dilate=replace_stride_with_dilation[1])
self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
dilate=replace_stride_with_dilation[2])
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512 * block.expansion, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
if zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
nn.init.constant_(m.bn3.weight, 0)
elif isinstance(m, BasicBlock):
nn.init.constant_(m.bn2.weight, 0)
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
norm_layer = self._norm_layer
downsample = None
previous_dilation = self.dilation
if dilate:
self.dilation *= stride
stride = 1
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
conv1x1(self.inplanes, planes * block.expansion, stride),
norm_layer(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
self.base_width, previous_dilation, norm_layer))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes, groups=self.groups,
base_width=self.base_width, dilation=self.dilation,
norm_layer=norm_layer))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = x.reshape(x.size(0), -1)
x = self.fc(x)
return x
def _resnet(arch, block, layers, pretrained, progress, device, **kwargs):
model = ResNet(block, layers, **kwargs)
if pretrained:
script_dir = os.path.dirname(__file__)
state_dict = torch.load(script_dir + '/state_dicts/'+arch+'.pt', map_location=device)
model.load_state_dict(state_dict)
return model
def resnet18(pretrained=False, progress=True, device='cpu', **kwargs):
"""Constructs a ResNet-18 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, device,
**kwargs)
def resnet34(pretrained=False, progress=True, device='cpu', **kwargs):
"""Constructs a ResNet-34 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, device,
**kwargs)
def resnet50(pretrained=False, progress=True, device='cpu', **kwargs):
"""Constructs a ResNet-50 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, device,
**kwargs)
def resnet101(pretrained=False, progress=True, device='cpu', **kwargs):
"""Constructs a ResNet-101 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, device,
**kwargs)
def resnet152(pretrained=False, progress=True, device='cpu', **kwargs):
"""Constructs a ResNet-152 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, device,
**kwargs)
def resnext50_32x4d(pretrained=False, progress=True, device='cpu', **kwargs):
"""Constructs a ResNeXt-50 32x4d model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['groups'] = 32
kwargs['width_per_group'] = 4
return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
pretrained, progress, device, **kwargs)
def resnext101_32x8d(pretrained=False, progress=True, device='cpu', **kwargs):
"""Constructs a ResNeXt-101 32x8d model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['groups'] = 32
kwargs['width_per_group'] = 8
return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
pretrained, progress, device, **kwargs)

View file

@@ -5,7 +5,7 @@ import math
 import torch.nn as nn
 import torch.nn.functional as F

-from archai.supergraph.nas.models.shakeshake.shakeshake import ShakeShake, Shortcut
+from .shakeshake import ShakeShake, Shortcut

 class ShakeBlock(nn.Module):

View file

@@ -5,7 +5,7 @@ import math
 import torch.nn as nn
 import torch.nn.functional as F

-from archai.supergraph.nas.models.shakeshake.shakeshake import ShakeShake, Shortcut
+from .shakeshake import ShakeShake, Shortcut

 class ShakeBottleNeck(nn.Module):

View file

@@ -95,9 +95,9 @@ class Evaluater(EnforceOverrides):
         # TODO: below detection code is too week, need to improve, possibly encode image size in yaml and use that instead
         if dataset_name.startswith('cifar'):
             if function_name.startswith('res'): # support resnext as well
-                module_name = 'archai.cifar10_models.resnet'
+                module_name = 'archai.supergraph.models.resnet'
             elif function_name.startswith('dense'):
-                module_name = 'archai.cifar10_models.densenet'
+                module_name = 'archai.supergraph.models.densenet'
         elif dataset_name.startswith('imagenet') or dataset_name.startswith('sport8'):
             module_name = 'torchvision.models'
         if not module_name:
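These module/function name strings are presumably resolved dynamically into a model constructor; a hedged sketch of that kind of lookup (not the project's actual Evaluater code, names are illustrative):

# Hedged sketch, not archai's Evaluater implementation: resolving a
# module/function name pair like the ones above into a model factory.
import importlib

def create_model(module_name: str, function_name: str, **kwargs):
    module = importlib.import_module(module_name)   # e.g. 'archai.supergraph.models.resnet'
    factory = getattr(module, function_name)        # e.g. 'resnet18'
    return factory(**kwargs)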

View file

@@ -3,8 +3,6 @@ __include__: 'darts.yaml' # defaults are loaded from this file
 common:
   #yaml_log: False
   apex:
-    enabled: False # global switch to disable everything apex
-    distributed_enabled: False # enable/disable distributed mode
   ray:
     enabled: True # initialize ray. Note: ray cannot be used if apex distributed is enabled
     local_mode: False # if True then ray runs in serial mode
@@ -12,50 +10,31 @@ common:
 nas:
   eval:
     final_desc_foldername: '$expdir/model_desc_gallery' #
-    source_desc_foldername: '$expdir/model_desc_gallery'
     model_desc:
       n_reductions: 2 # number of reductions to be applied
-      n_cells: 10 # number of max cells, for pareto frontier, we use cell_count_scale to multiply cells and limit by n_cells
-      aux_weight: 0.0 # weight for loss from auxiliary towers in test time arch
+      n_cells: 20 # number of max cells, for pareto frontier, we use cell_count_scale to multiply cells and limit by n_cells
+      aux_weight: 0.4 # weight for loss from auxiliary towers in test time arch
       num_edges_to_sample: 2 # number of edges each node will take inputs from
-      aux_tower_stride: 3
       model_stems:
-        ops: ['stem_conv3x3_s2', 'stem_conv3x3_s2']
-        init_node_ch: 32 # num of input/output channels for nodes in 1st cell
-        stem_multiplier: 1 # output channels multiplier for the stem
+        init_node_ch: 36 # num of input/output channels for nodes in 1st cell
      cell:
        n_nodes: 5 # number of nodes in a cell if template desc is not provided
        cell_post_op: 'proj_channels'
    petridish:
      cell_count_scale: 1.0 # for eval first multiply number of cells used in search by this factor, limit to n_cells
    trainer:
-      aux_weight: 0.0
-      epochs: 1500
-      batch_chunks: 1
-      validation:
-        batch_chunks: 1
-      optimizer:
-        lr: 0.033
-    loader:
-      cutout: 6 # cutout length, use cutout augmentation when > 0
-      load_train: True # load train split of dataset
-      train_batch: 32
-      test_batch: 32
-      img_size: 16
-      aug: 'autoaug_cifar10'
-      # dataset:
-      #   max_batches: 32
+      epochs: 600
  search:
    final_desc_foldername: '$expdir/model_desc_gallery' # the gallery of models that eval will train from scratch
    petridish:
      convex_hull_eps: 0.025 # tolerance
-      max_madd: 20000000 # if any parent model reaches this many multiply-additions then the search is terminated or it reaches maximum number of parent pool size
+      max_madd: 200000000 # if any parent model reaches this many multiply-additions then the search is terminated or it reaches maximum number of parent pool size
      max_hull_points: 100 # if the pool of parent models reaches this size then search is terminated or if it reaches max multiply-adds
      checkpoints_foldername: '$expdir/petridish_search_checkpoints'
-      search_iters: 4
    pareto:
-      max_cells: 10
-      max_reductions: 2
+      max_cells: 8
+      max_reductions: 3
      max_nodes: 3
      enabled: True # if false then there will only be one seed model. if true a number of seed models with different number of cells, reductions and nodes will be used to initialize the search. this provides more coverage of the frontier.
    model_desc:
@@ -63,52 +42,21 @@ nas:
      n_reductions: 1
      num_edges_to_sample: 2 # number of edges each node will take inputs from
      cell:
-        n_nodes: 1 # also used as min nodes to get combinations for seeding pareto
+        n_nodes: 1
        cell_post_op: 'proj_channels'
-      model_stems:
-        ops: ['stem_conv3x3_s2', 'stem_conv3x3_s2']
-        stem_multiplier: 1 # output channels multiplier for the stem
-        init_node_ch: 32 # num of input/output channels for nodes in 1st cell
    seed_train:
      trainer:
        epochs: 80 # number of epochs model will be trained before search
-        optimizer:
-          lr: 0.033
-        batch_chunks: 1
-        validation:
-          batch_chunks: 1
      loader:
-        cutout: 6
-        train_batch: 32
-        test_batch: 32
-        img_size: 16
-        aug: ''
-        # dataset:
-        #   max_batches: 32
+        train_batch: 128
    post_train:
      trainer:
        epochs: 80 # number of epochs model will be trained after search
-        optimizer:
-          lr: 0.033
      loader:
-        train_batch: 32
-        cutout: 6
-        test_batch: 32
-        img_size: 16
-        aug: ''
-        # dataset:
-        #   max_batches: 32
+        train_batch: 96
    trainer:
      l1_alphas: 0.001 # as per paper
      epochs: 80 # number of epochs model will be trained during search
-      optimizer:
-        lr: 0.033
    loader:
-      train_batch: 32
-      val_ratio: 0.2 #split portion for train set, 0 to 1
-      cutout: 6
-      test_batch: 32
-      img_size: 16
-      aug: ''
-      # dataset:
-      #   max_batches: 32
+      train_batch: 96
+      val_ratio: 0.2 #split portion for test set, 0 to 1
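This file begins with __include__: 'darts.yaml', so the keys shown here override the DARTS defaults. A hedged sketch of how such an include-plus-override YAML typically composes (illustrative only, not archai's Config implementation; include paths are resolved naively here):

# Hedged sketch, not archai's Config class: deep-merging an '__include__' base
# file with the overriding file, child keys winning over the base.
import yaml

def deep_merge(base: dict, override: dict) -> dict:
    out = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(out.get(key), dict):
            out[key] = deep_merge(out[key], value)
        else:
            out[key] = value
    return out

def load_config(path: str) -> dict:
    with open(path, 'r') as f:
        conf = yaml.safe_load(f)
    include = conf.pop('__include__', None)
    if include:
        conf = deep_merge(load_config(include), conf)   # overrides win over base
    return conf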