зеркало из https://github.com/microsoft/archai.git
distributed and mixed precision enablement
This commit is contained in:
Родитель
cec9c16780
Коммит
ee99587b14
|
@ -15,13 +15,13 @@ from archai.nas.arch_trainer import ArchTrainer
|
|||
from archai.common import utils, ml_utils
|
||||
from archai.nas.model import Model
|
||||
from archai.common.checkpoint import CheckPoint
|
||||
from archai.common.common import logger
|
||||
from archai.common.common import logger, get_device
|
||||
from .bilevel_optimizer import BilevelOptimizer
|
||||
|
||||
class BilevelArchTrainer(ArchTrainer):
|
||||
def __init__(self, conf_train: Config, model: Model, device,
|
||||
def __init__(self, conf_train: Config, model: Model,
|
||||
checkpoint:Optional[CheckPoint]) -> None:
|
||||
super().__init__(conf_train, model, device, checkpoint)
|
||||
super().__init__(conf_train, model, checkpoint)
|
||||
|
||||
self._conf_w_optim = conf_train['optimizer']
|
||||
self._conf_w_lossfn = conf_train['lossfn']
|
||||
|
@ -36,7 +36,7 @@ class BilevelArchTrainer(ArchTrainer):
|
|||
assert val_dl is not None
|
||||
w_momentum = self._conf_w_optim['momentum']
|
||||
w_decay = self._conf_w_optim['decay']
|
||||
lossfn = ml_utils.get_lossfn(self._conf_w_lossfn).to(self.device)
|
||||
lossfn = ml_utils.get_lossfn(self._conf_w_lossfn).to(get_device())
|
||||
|
||||
self._bilevel_optim = BilevelOptimizer(self._conf_alpha_optim, w_momentum,
|
||||
w_decay, self.model, lossfn)
|
||||
|
@ -71,8 +71,7 @@ class BilevelArchTrainer(ArchTrainer):
|
|||
self._valid_iter = iter(self._val_dl)
|
||||
x_val, y_val = next(self._valid_iter)
|
||||
|
||||
x_val, y_val = x_val.to(self.device), y_val.to(
|
||||
self.device, non_blocking=True)
|
||||
x_val, y_val = x_val.to(get_device()), y_val.to(get_device(), non_blocking=True)
|
||||
|
||||
# update alphas
|
||||
self._bilevel_optim.step(x, y, x_val, y_val, super().get_optimizer())
|
||||
|
|
|
@ -9,7 +9,7 @@ from torch.optim.optimizer import Optimizer
|
|||
from archai.common.config import Config
|
||||
from archai.common import utils, ml_utils
|
||||
from archai.nas.model import Model
|
||||
from archai.common.common import logger
|
||||
from archai.common.common import logger, get_device
|
||||
|
||||
class BilevelOptimizer:
|
||||
def __init__(self, conf_alpha_optim:Config, w_momentum: float, w_decay: float,
|
||||
|
@ -22,7 +22,7 @@ class BilevelOptimizer:
|
|||
# create a copy of model which we will use
|
||||
# to compute grads for alphas without disturbing
|
||||
# original weights
|
||||
self._vmodel = copy.deepcopy(model)
|
||||
self._vmodel = copy.deepcopy(model).to(get_device())
|
||||
# this is the optimizer to optimize alphas parameter
|
||||
self._alpha_optim = ml_utils.create_optimizer(conf_alpha_optim, model.alphas())
|
||||
|
||||
|
|
|
@ -19,9 +19,9 @@ from archai.common.common import logger
|
|||
|
||||
|
||||
class GsArchTrainer(ArchTrainer):
|
||||
def __init__(self, conf_train: Config, model: Model, device,
|
||||
def __init__(self, conf_train: Config, model: Model,
|
||||
checkpoint:Optional[CheckPoint]) -> None:
|
||||
super().__init__(conf_train, model, device, checkpoint)
|
||||
super().__init__(conf_train, model, checkpoint)
|
||||
|
||||
self._conf_w_optim = conf_train['optimizer']
|
||||
# self._conf_w_lossfn = conf_train['lossfn']
|
||||
|
|
|
@ -15,13 +15,13 @@ from archai.nas.arch_trainer import ArchTrainer
|
|||
from archai.common import utils, ml_utils
|
||||
from archai.nas.model import Model
|
||||
from archai.common.checkpoint import CheckPoint
|
||||
from archai.common.common import logger
|
||||
from archai.common.common import logger, get_device
|
||||
|
||||
|
||||
class XnasArchTrainer(ArchTrainer):
|
||||
def __init__(self, conf_train: Config, model: Model, device,
|
||||
def __init__(self, conf_train: Config, model: Model,
|
||||
checkpoint:Optional[CheckPoint]) -> None:
|
||||
super().__init__(conf_train, model, device, checkpoint)
|
||||
super().__init__(conf_train, model, checkpoint)
|
||||
|
||||
self._conf_w_optim = conf_train['optimizer']
|
||||
self._conf_w_lossfn = conf_train['lossfn']
|
||||
|
@ -39,7 +39,7 @@ class XnasArchTrainer(ArchTrainer):
|
|||
# optimizers, schedulers needs to be recreated for each fit call
|
||||
# as they have state
|
||||
assert val_dl is not None
|
||||
lossfn = ml_utils.get_lossfn(self._conf_w_lossfn).to(self.device)
|
||||
lossfn = ml_utils.get_lossfn(self._conf_w_lossfn).to(get_device())
|
||||
|
||||
self._xnas_optim = _XnasOptimizer(self._conf_alpha_optim, self.model, lossfn)
|
||||
|
||||
|
@ -74,8 +74,7 @@ class XnasArchTrainer(ArchTrainer):
|
|||
self._valid_iter = iter(self._val_dl)
|
||||
x_val, y_val = next(self._valid_iter)
|
||||
|
||||
x_val, y_val = x_val.to(self.device), y_val.to(
|
||||
self.device, non_blocking=True)
|
||||
x_val, y_val = x_val.to(get_device()), y_val.to(get_device(), non_blocking=True)
|
||||
|
||||
# update alphas
|
||||
self._xnas_optim.step(x, y, x_val, y_val)
|
||||
|
|
|
@ -1,47 +1,182 @@
|
|||
from typing import Tuple
|
||||
from typing import Optional, Tuple, List
|
||||
import os
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
from torch.optim.optimizer import Optimizer
|
||||
from torch import Tensor, nn
|
||||
from torch import nn
|
||||
from torch.backends import cudnn
|
||||
import torch.distributed as dist
|
||||
|
||||
from .common import logger
|
||||
from archai.common.config import Config
|
||||
|
||||
class Amp:
|
||||
from archai.common import ml_utils, utils
|
||||
from archai.common.ordereddict_logger import OrderedDictLogger
|
||||
|
||||
class ApexUtils:
|
||||
_warning_shown = False
|
||||
def __init__(self, use_amp:bool)->None:
|
||||
self._use_amp = use_amp
|
||||
self._amp = None
|
||||
|
||||
if self._use_amp:
|
||||
try:
|
||||
from apex import amp
|
||||
self._amp = amp
|
||||
logger.warn({'apex': True})
|
||||
except ModuleNotFoundError:
|
||||
if not Amp._warning_shown:
|
||||
logger.warn({'apex': False})
|
||||
Amp._warning_shown = True
|
||||
self._amp = None
|
||||
else:
|
||||
pass # do not disable if already enabled as other callers may be using it
|
||||
def __init__(self, distdir:Optional[str], apex_config:Config)->None:
|
||||
logger = self._create_init_logger(distdir)
|
||||
|
||||
def available(self)->bool:
|
||||
# region conf vars
|
||||
self._enabled = apex_config['enabled'] # global switch to disable anything apex
|
||||
self._opt_level = apex_config['opt_level'] # optimization level for mixed precision
|
||||
self._bn_fp32 = apex_config['bn_fp32'] # keep BN in fp32
|
||||
self._loss_scale = apex_config['loss_scale'] # loss scaling mode for mixed prec
|
||||
self._sync_bn = apex_config['sync_bn'] # should be replace BNs with sync BNs for distributed model
|
||||
self._distributed = apex_config['distributed'] # enable/disable distributed mode
|
||||
self._scale_lr = apex_config['scale_lr'] # enable/disable distributed mode
|
||||
self._min_world_size = apex_config['min_world_size'] # allows to confirm we are indeed in distributed setting
|
||||
conf_gpu_ids = apex_config['gpus']
|
||||
# endregion
|
||||
|
||||
self.gpu_ids = [int(i) for i in conf_gpu_ids.split(',') if i]
|
||||
self._amp, self._ddp = None, None
|
||||
self._gpu = self.gpu_ids[0] if len(self.gpu_ids) else 0 # which GPU to use, we will use only 1 GPU
|
||||
self._world_size = 1 # total number of processes in distributed run
|
||||
self.local_rank = 0
|
||||
self.global_rank = 0
|
||||
|
||||
logger.info({'apex_config': apex_config.to_dict()})
|
||||
logger.info({'torch.distributed_is_available': dist.is_available()})
|
||||
if dist.is_available():
|
||||
logger.info({'gloo_available': dist.is_gloo_available(),
|
||||
'mpi_available': dist.is_mpi_available(),
|
||||
'nccl_available': dist.is_nccl_available()})
|
||||
|
||||
if self._enabled:
|
||||
# init enable mixed precision
|
||||
assert cudnn.enabled, "Amp requires cudnn backend to be enabled."
|
||||
from apex import amp
|
||||
self._amp = amp
|
||||
|
||||
# enable distributed processing
|
||||
if self._distributed:
|
||||
assert dist.is_available() # distributed module is available
|
||||
assert dist.is_nccl_available()
|
||||
dist.init_process_group(backend='nccl', init_method='env://')
|
||||
assert dist.is_initialized()
|
||||
|
||||
if 'WORLD_SIZE' in os.environ:
|
||||
self._world_size = int(os.environ['WORLD_SIZE'])
|
||||
assert dist.get_world_size() == self._world_size
|
||||
else:
|
||||
raise RuntimeError('WORLD_SIZE must be set by distributed launcher when distributed mode is enabled')
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--local-rank', type=int, help='local-rank must be supplied by torch distributed launcher!')
|
||||
args, extra_args = parser.parse_known_args()
|
||||
|
||||
self.local_rank = args.local_rank
|
||||
self.global_rank = dist.get_rank()
|
||||
|
||||
from apex import parallel
|
||||
self._ddp = parallel
|
||||
assert self.local_rank < torch.cuda.device_count()
|
||||
self._gpu = self.local_rank # reset to default assignment for rank
|
||||
# remap if GPU IDs are specified
|
||||
if len(self.gpu_ids):
|
||||
assert len(self.gpu_ids) > self.local_rank
|
||||
self._gpu = self.gpu_ids[self.local_rank]
|
||||
|
||||
assert self._world_size >= 1
|
||||
assert not self._min_world_size or self._world_size >= self._min_world_size
|
||||
assert self.local_rank >= 0 and self.local_rank < self._world_size
|
||||
assert self.global_rank >= 0 and self.global_rank < self._world_size
|
||||
|
||||
assert self._gpu < torch.cuda.device_count()
|
||||
torch.cuda.set_device(self._gpu)
|
||||
self.device = torch.device(f'cuda:{self._gpu}')
|
||||
|
||||
logger.info({'amp_available': self._amp is not None, 'distributed_available': self._distributed is not None})
|
||||
logger.info({'distributed': self._distributed, 'world_size': self._world_size,
|
||||
'gpu': self._gpu, 'gpu_ids':self.gpu_ids, 'local_rank': self.local_rank})
|
||||
|
||||
logger.info({'dist_initialized': dist.is_initialized() if dist.is_available() else False})
|
||||
|
||||
logger.close()
|
||||
|
||||
|
||||
def _create_init_logger(self, distdir:Optional[str])->OrderedDictLogger:
|
||||
# create PID specific logger to support many distributed processes
|
||||
init_log_filepath, yaml_log_filepath = None, None
|
||||
if distdir:
|
||||
init_log_filepath = os.path.join(utils.full_path(distdir),
|
||||
'apex_' + str(os.getpid()) + '.log')
|
||||
yaml_log_filepath = os.path.join(utils.full_path(distdir),
|
||||
'apex_' + str(os.getpid()) + '.yaml')
|
||||
|
||||
sys_logger = utils.create_logger(filepath=init_log_filepath)
|
||||
if not init_log_filepath:
|
||||
sys_logger.warn('logdir not specified, no logs will be created or any models saved')
|
||||
|
||||
logger = OrderedDictLogger(yaml_log_filepath, sys_logger)
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
def set_replica_logger(self, logger:OrderedDictLogger)->None:
|
||||
# To avoid circular dependency we don't reference logger in common.
|
||||
# Furthermore, each replica has its own logger but sharing same exp directory.
|
||||
# We can't create replica specific logger at time of init so this is set later.
|
||||
self.logger = logger
|
||||
|
||||
def amp_available(self)->bool:
|
||||
return self._amp is not None
|
||||
def dist_available(self)->bool:
|
||||
return self._ddp is not None
|
||||
def is_master(self)->bool:
|
||||
return self.global_rank == 0
|
||||
|
||||
def backward(self, loss:Tensor, optim:Optimizer)->None:
|
||||
def sync_dist(self)->None:
|
||||
if self._distributed:
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def reduce_tensor(self, tensor:torch.Tensor):
|
||||
if self._distributed:
|
||||
rt = tensor.clone()
|
||||
torch.dist.all_reduce(rt, op=torch.dist.reduce_op.SUM)
|
||||
rt /= self._world_size
|
||||
return rt
|
||||
else:
|
||||
return tensor.data
|
||||
|
||||
def backward(self, loss:torch.Tensor, optim:Optimizer)->None:
|
||||
if self._amp:
|
||||
with self._amp.scale_loss(loss, optim) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
else:
|
||||
loss.backward()
|
||||
|
||||
def to_amp(self, model:nn.Module, optim:Optimizer, opt_level="O2",
|
||||
keep_batchnorm_fp32=True, loss_scale="dynamic")\
|
||||
def to_amp(self, model:nn.Module, optim:Optimizer, batch_size:int)\
|
||||
->Tuple[nn.Module, Optimizer]:
|
||||
# conver BNs to sync BNs in distributed mode
|
||||
if self._ddp and self._sync_bn:
|
||||
model = self._ddp.convert_syncbn_model(model)
|
||||
self.logger.info({'BNs_converted': True})
|
||||
|
||||
model = model.to(self.device)
|
||||
|
||||
if self._amp:
|
||||
# scale LR
|
||||
if self._scale_lr:
|
||||
lr = ml_utils.get_optim_lr(optim)
|
||||
scaled_lr = lr * self._world_size / float(batch_size)
|
||||
ml_utils.set_optim_lr(optim, scaled_lr)
|
||||
self.logger.info({'lr_scaled': True, 'old_lr': lr, 'new_lr': scaled_lr})
|
||||
|
||||
model, optim = self._amp.initialize(
|
||||
model, optim, opt_level=opt_level,
|
||||
keep_batchnorm_fp32=keep_batchnorm_fp32, loss_scale=loss_scale
|
||||
model, optim, opt_level=self._opt_level,
|
||||
keep_batchnorm_fp32=self._bn_fp32, loss_scale=self._loss_scale
|
||||
)
|
||||
|
||||
if self._ddp:
|
||||
# By default, apex.parallel.DistributedDataParallel overlaps communication with
|
||||
# computation in the backward pass.
|
||||
# delay_allreduce delays all communication to the end of the backward pass.
|
||||
model = self._ddp.DistributedDataParallel(model, delay_allreduce=True)
|
||||
|
||||
return model, optim
|
||||
|
||||
def clip_grad(self, clip:float, model:nn.Module, optim:Optimizer)->None:
|
||||
|
@ -59,12 +194,10 @@ class Amp:
|
|||
|
||||
def load_state_dict(self, state_dict):
|
||||
if self._amp:
|
||||
if state_dict is None:
|
||||
raise RuntimeError('checkpoint state_dict is None but Nvidia apex (amp) is enabled')
|
||||
self._amp.load_state_dict()
|
||||
else:
|
||||
if state_dict is not None:
|
||||
raise RuntimeError('checkpoint state_dict is not None but Nvidia apex (amp) is not enabled')
|
||||
raise RuntimeError('checkpoint state_dict is not None but Nvidia apex (amp) is not ')
|
||||
else:
|
||||
pass
|
||||
|
||||
|
|
|
@ -4,17 +4,20 @@ import os
|
|||
from typing import List, Iterable, Union, Optional, Tuple
|
||||
import atexit
|
||||
import subprocess
|
||||
|
||||
import datetime
|
||||
import yaml
|
||||
import sys
|
||||
|
||||
import torch
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
import torch.backends.cudnn as cudnn
|
||||
import psutil
|
||||
|
||||
from .config import Config
|
||||
from .stopwatch import StopWatch
|
||||
from . import utils
|
||||
from .ordereddict_logger import OrderedDictLogger
|
||||
from .apex_utils import ApexUtils
|
||||
|
||||
class SummaryWriterDummy:
|
||||
def __init__(self, log_dir):
|
||||
|
@ -28,11 +31,21 @@ class SummaryWriterDummy:
|
|||
SummaryWriterAny = Union[SummaryWriterDummy, SummaryWriter]
|
||||
logger = OrderedDictLogger(None, None)
|
||||
_tb_writer: SummaryWriterAny = None
|
||||
_apex_utils = None
|
||||
_atexit_reg = False # is hook for atexit registered?
|
||||
|
||||
def get_conf()->Config:
|
||||
return Config.get()
|
||||
|
||||
def get_device():
|
||||
global _apex_utils
|
||||
return _apex_utils.device
|
||||
|
||||
def get_apex_utils()->ApexUtils:
|
||||
global _apex_utils
|
||||
assert _apex_utils
|
||||
return _apex_utils
|
||||
|
||||
def get_conf_common()->Config:
|
||||
return get_conf()['common']
|
||||
|
||||
|
@ -85,30 +98,51 @@ def _setup_pt(param_args: list)->Tuple[str,str, list]:
|
|||
|
||||
# initializes random number gen, debugging etc
|
||||
def common_init(config_filepath: Optional[str]=None,
|
||||
param_args: list = [],
|
||||
log_level=logging.INFO, is_master=True, use_args=True) \
|
||||
-> Config:
|
||||
param_args: list = [], log_level=logging.INFO, use_args=True)->Config:
|
||||
|
||||
pt_data_dir, pt_output_dir, param_args = _setup_pt(param_args)
|
||||
# get cloud dirs if any
|
||||
pt_data_dir, pt_output_dir, param_overrides = _setup_pt(param_args)
|
||||
|
||||
# init config
|
||||
conf = Config(config_filepath=config_filepath,
|
||||
param_args=param_args,
|
||||
param_args=param_overrides,
|
||||
use_args=use_args)
|
||||
Config.set(conf)
|
||||
|
||||
sw = StopWatch()
|
||||
StopWatch.set(sw)
|
||||
|
||||
expdir = _setup_dirs()
|
||||
_setup_logger()
|
||||
# create experiment dir
|
||||
_setup_dirs()
|
||||
|
||||
# validate and log dirs
|
||||
expdir = get_expdir()
|
||||
assert not pt_output_dir or not expdir.startswith(utils.full_path('~/logdir'))
|
||||
logger.info({'expdir': expdir,
|
||||
'PT_DATA_DIR': pt_data_dir, 'PT_OUTPUT_DIR': pt_output_dir})
|
||||
|
||||
# set up amp, apex, mixed-prec, distributed training stubs
|
||||
_setup_apex()
|
||||
# create global logger
|
||||
_setup_logger()
|
||||
# init GPU settings
|
||||
_setup_gpus()
|
||||
# create info file for current system
|
||||
_create_sysinfo(conf)
|
||||
|
||||
if expdir:
|
||||
# setup tensorboard
|
||||
global _tb_writer
|
||||
_tb_writer = _create_tb_writer(get_apex_utils().is_master())
|
||||
|
||||
# create hooks to execute code when script exits
|
||||
global _atexit_reg
|
||||
if not _atexit_reg:
|
||||
atexit.register(on_app_exit)
|
||||
_atexit_reg = True
|
||||
|
||||
return conf
|
||||
|
||||
def _create_sysinfo(conf:Config)->None:
|
||||
expdir = get_expdir()
|
||||
|
||||
if expdir and not utils.is_debugging():
|
||||
# copy net config to experiment folder for reference
|
||||
with open(expdir_abspath('config_used.yaml'), 'w') as f:
|
||||
yaml.dump(conf.to_dict(), f)
|
||||
|
@ -118,16 +152,6 @@ def common_init(config_filepath: Optional[str]=None,
|
|||
stdout=subprocess.PIPE, stderr=subprocess.PIPE,
|
||||
shell=True)
|
||||
|
||||
global _tb_writer
|
||||
_tb_writer = _create_tb_writer(is_master)
|
||||
|
||||
global _atexit_reg
|
||||
if not _atexit_reg:
|
||||
atexit.register(on_app_exit)
|
||||
_atexit_reg = True
|
||||
|
||||
return conf
|
||||
|
||||
def expdir_abspath(path:str, create=False)->str:
|
||||
"""Returns full path for given relative path within experiment directory."""
|
||||
return utils.full_path(os.path.join('$expdir',path), create=create)
|
||||
|
@ -161,45 +185,74 @@ def _setup_dirs()->Optional[str]:
|
|||
logdir = utils.full_path(logdir)
|
||||
expdir = os.path.join(logdir, experiment_name)
|
||||
os.makedirs(expdir, exist_ok=True)
|
||||
|
||||
# directory for non-master replica logs
|
||||
distdir = os.path.join(expdir, 'dist')
|
||||
os.makedirs(distdir, exist_ok=True)
|
||||
else:
|
||||
raise RuntimeError('The logdir setting must be specified for the output directory in yaml')
|
||||
|
||||
# update conf so everyone gets expanded full paths from here on
|
||||
conf_common['logdir'], conf_data['dataroot'], conf_common['expdir'] = \
|
||||
logdir, dataroot, expdir
|
||||
|
||||
# set environment variable so it can be referenced in paths used in config
|
||||
os.environ['expdir'] = expdir
|
||||
os.environ['logdir'] = conf_common['logdir'] = logdir
|
||||
os.environ['dataroot'] = conf_data['dataroot'] = dataroot
|
||||
os.environ['expdir'] = conf_common['expdir'] = expdir
|
||||
os.environ['distdir'] = conf_common['distdir'] = distdir
|
||||
|
||||
return expdir
|
||||
|
||||
def _setup_logger():
|
||||
global logger
|
||||
logger.close() # close any previous instances
|
||||
|
||||
experiment_name = get_experiment_name()
|
||||
conf_common = get_conf_common()
|
||||
expdir = conf_common['expdir']
|
||||
distdir = conf_common['distdir']
|
||||
global_rank = get_apex_utils().global_rank
|
||||
|
||||
# file where logger would log messages
|
||||
sys_log_filepath = expdir_abspath('logs.log')
|
||||
sys_logger = utils.setup_logging(filepath=sys_log_filepath, name=experiment_name)
|
||||
if get_apex_utils().is_master():
|
||||
sys_log_filepath = utils.full_path(os.path.join(expdir, 'logs.log'))
|
||||
logs_yaml_filepath = utils.full_path(os.path.join(expdir, 'logs.yaml'))
|
||||
experiment_name = get_experiment_name()
|
||||
enable_stdout = True
|
||||
else:
|
||||
sys_log_filepath = utils.full_path(os.path.join(distdir, f'logs_{global_rank}.log'))
|
||||
logs_yaml_filepath = utils.full_path(os.path.join(distdir, f'logs_{global_rank}.yaml'))
|
||||
experiment_name = get_experiment_name() + '_' + str(global_rank)
|
||||
enable_stdout = False
|
||||
print('No stdout logging for replica {global_rank}')
|
||||
|
||||
sys_logger = utils.create_logger(filepath=sys_log_filepath,
|
||||
name=experiment_name,
|
||||
enable_stdout=enable_stdout)
|
||||
if not sys_log_filepath:
|
||||
sys_logger.warn(
|
||||
'logdir not specified, no logs will be created or any models saved')
|
||||
|
||||
# We need to create ApexUtils before we have logger. Now that we have logger
|
||||
# lets give it to ApexUtils
|
||||
get_apex_utils().set_replica_logger(logger)
|
||||
|
||||
# reset to new file path
|
||||
logs_yaml_filepath = expdir_abspath('logs.yaml')
|
||||
logger.reset(logs_yaml_filepath, sys_logger)
|
||||
logger.info({
|
||||
'datetime:': datetime.datetime.now(),
|
||||
'command_line': ' '.join(sys.argv[1:]),
|
||||
'logger_global_rank': global_rank,
|
||||
'logger_enable_stdout': enable_stdout,
|
||||
'sys_log_filepath': sys_log_filepath
|
||||
})
|
||||
|
||||
def _setup_apex():
|
||||
conf_common = get_conf_common()
|
||||
distdir = conf_common['distdir']
|
||||
|
||||
global _apex_utils
|
||||
_apex_utils = ApexUtils(distdir, conf_common['apex'])
|
||||
|
||||
def _setup_gpus():
|
||||
conf_common = get_conf_common()
|
||||
|
||||
if conf_common['gpus'] is not None:
|
||||
csv = str(conf_common['gpus'])
|
||||
#os.environ['CUDA_VISIBLE_DEVICES'] = str(conf_common['gpus'])
|
||||
torch.cuda.set_device(int(csv.split(',')[0]))
|
||||
logger.info({'gpu_ids': conf_common['gpus']})
|
||||
# alternative: torch.cuda.set_device(config.gpus[0])
|
||||
|
||||
utils.setup_cuda(conf_common['seed'])
|
||||
|
||||
if conf_common['detect_anomaly']:
|
||||
|
@ -207,12 +260,17 @@ def _setup_gpus():
|
|||
torch.autograd.set_detect_anomaly(True)
|
||||
|
||||
logger.info({'gpu_names': utils.cuda_device_names(),
|
||||
'gpu_count': torch.cuda.device_count(),
|
||||
'CUDA_VISIBLE_DEVICES': os.environ['CUDA_VISIBLE_DEVICES']
|
||||
if 'CUDA_VISIBLE_DEVICES' in os.environ else 'NotSet',
|
||||
'cudnn.enabled': cudnn.enabled,
|
||||
'cudnn.benchmark': cudnn.benchmark,
|
||||
'cudnn.deterministic': cudnn.deterministic
|
||||
'cudnn.deterministic': cudnn.deterministic,
|
||||
'cudnn.version': cudnn.version()
|
||||
})
|
||||
logger.info({'memory': str(psutil.virtual_memory())})
|
||||
logger.info({'CPUs': str(psutil.cpu_count())})
|
||||
|
||||
|
||||
# gpu_usage = os.popen(
|
||||
# 'nvidia-smi --query-gpu=memory.total,memory.used --format=csv,nounits,noheader'
|
||||
|
|
|
@ -55,7 +55,10 @@ def create_optimizer(conf_opt:Config, params)->Optimizer:
|
|||
def get_optim_lr(optimizer:Optimizer)->float:
|
||||
for param_group in optimizer.param_groups:
|
||||
return param_group['lr']
|
||||
raise RuntimeError('optimizer did not had any param_group named lr!')
|
||||
|
||||
def set_optim_lr(optimizer:Optimizer, lr:float)->None:
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = lr
|
||||
|
||||
def ensure_pytorch_ver(min_ver:str, error_msg:str)->bool:
|
||||
tv = torch.__version__.split('.')
|
||||
|
|
|
@ -10,6 +10,8 @@ import pathlib
|
|||
|
||||
TItems = Union[Mapping, str]
|
||||
|
||||
# do not reference common or otherwise we will have circular deps
|
||||
|
||||
def _fmt(val:Any)->str:
|
||||
if isinstance(val, float):
|
||||
return f'{val:.4g}'
|
||||
|
|
|
@ -9,20 +9,19 @@ from overrides import EnforceOverrides
|
|||
from .metrics import Metrics
|
||||
from .config import Config
|
||||
from . import utils, ml_utils
|
||||
from .common import logger
|
||||
from .common import logger, get_device
|
||||
|
||||
class Tester(EnforceOverrides):
|
||||
"""Evaluate model on given data
|
||||
"""
|
||||
|
||||
def __init__(self, conf_eval:Config, model:nn.Module, device)->None:
|
||||
def __init__(self, conf_eval:Config, model:nn.Module)->None:
|
||||
self._title = conf_eval['title']
|
||||
self._logger_freq = conf_eval['logger_freq']
|
||||
conf_lossfn = conf_eval['lossfn']
|
||||
|
||||
self.model = model
|
||||
self.device = device
|
||||
self._lossfn = ml_utils.get_lossfn(conf_lossfn).to(device)
|
||||
self._lossfn = ml_utils.get_lossfn(conf_lossfn).to(get_device())
|
||||
self._metrics = None
|
||||
|
||||
def test(self, test_dl: DataLoader)->Metrics:
|
||||
|
@ -45,7 +44,7 @@ class Tester(EnforceOverrides):
|
|||
|
||||
with torch.no_grad(), logger.pushd('steps'):
|
||||
for step, (x, y) in enumerate(test_dl):
|
||||
x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
|
||||
x, y = x.to(get_device(), non_blocking=True), y.to(get_device(), non_blocking=True)
|
||||
|
||||
assert not self.model.training # derived class might alter the mode
|
||||
logger.pushd(step)
|
||||
|
|
|
@ -11,13 +11,12 @@ from .metrics import Metrics
|
|||
from .tester import Tester
|
||||
from .config import Config
|
||||
from . import utils, ml_utils
|
||||
from ..common.common import logger
|
||||
from ..common.common import logger, get_device, get_apex_utils
|
||||
from ..common.checkpoint import CheckPoint
|
||||
from .apex_utils import Amp
|
||||
|
||||
|
||||
class Trainer(EnforceOverrides):
|
||||
def __init__(self, conf_train:Config, model:nn.Module, device,
|
||||
def __init__(self, conf_train:Config, model:nn.Module,
|
||||
checkpoint:Optional[CheckPoint])->None:
|
||||
# region config vars
|
||||
conf_lossfn = conf_train['lossfn']
|
||||
|
@ -36,12 +35,11 @@ class Trainer(EnforceOverrides):
|
|||
|
||||
self._checkpoint = checkpoint
|
||||
self.model = model
|
||||
self.device = device
|
||||
self._lossfn = ml_utils.get_lossfn(conf_lossfn).to(device)
|
||||
self._tester = Tester(conf_validation, model, device) \
|
||||
|
||||
self._lossfn = ml_utils.get_lossfn(conf_lossfn)
|
||||
self._tester = Tester(conf_validation, model) \
|
||||
if conf_validation else None
|
||||
self._metrics:Optional[Metrics] = None
|
||||
self._amp = Amp(self._apex)
|
||||
|
||||
self._droppath_module = self._get_droppath_module()
|
||||
if self._droppath_module is None and self._drop_path_prob > 0.0:
|
||||
|
@ -60,18 +58,22 @@ class Trainer(EnforceOverrides):
|
|||
# create scheduler for optim before applying amp
|
||||
self._sched, self._sched_on_epoch = self._create_scheduler(optim, len(train_dl))
|
||||
# before checkpoint restore, convert to amp
|
||||
# TODO: see if original model gets lost after to_amp?
|
||||
self.model, self._optim = self._amp.to_amp(self.model, optim)
|
||||
self.model, self._optim = get_apex_utils().to_amp(self.model, optim,
|
||||
batch_size=train_dl.batch_size)
|
||||
|
||||
self._lossfn = self._lossfn.to(get_device())
|
||||
|
||||
self.pre_fit(train_dl, val_dl)
|
||||
|
||||
# we need to restore checkpoint after all objects are created because
|
||||
# restoring checkpoint requires load_state_dict calls on these objects
|
||||
self._start_epoch = 0
|
||||
# do we have a checkpoint
|
||||
checkpoint_avail = self._checkpoint is not None
|
||||
checkpoint_val = checkpoint_avail and 'trainer' in self._checkpoint
|
||||
resumed = False
|
||||
if checkpoint_val:
|
||||
# restore checkpoint
|
||||
resumed = True
|
||||
self.restore_checkpoint()
|
||||
elif checkpoint_avail: # TODO: bad checkpoint?
|
||||
|
@ -152,8 +154,8 @@ class Trainer(EnforceOverrides):
|
|||
self._metrics.post_epoch(val_metrics, lr=self._optim.param_groups[0]['lr'])
|
||||
|
||||
# checkpoint if enabled with given freq or if this is the last epoch
|
||||
if self._checkpoint is not None and self._checkpoint.freq > 0 and \
|
||||
(self._metrics.epochs() % self._checkpoint.freq == 0 or \
|
||||
if self._checkpoint is not None and get_apex_utils().is_master() and \
|
||||
self._checkpoint.freq > 0 and (self._metrics.epochs() % self._checkpoint.freq == 0 or \
|
||||
self._metrics.epochs() >= self._epochs):
|
||||
self._checkpoint.new()
|
||||
self.update_checkpoint(self._checkpoint)
|
||||
|
@ -174,7 +176,7 @@ class Trainer(EnforceOverrides):
|
|||
|
||||
self._metrics.load_state_dict(state['metrics'])
|
||||
assert self._metrics.epochs() == last_epoch+1
|
||||
self._amp.load_state_dict(state['amp'])
|
||||
get_apex_utils().load_state_dict(state['amp'])
|
||||
self.model.load_state_dict(state['model'])
|
||||
self._optim.load_state_dict(state['optim'])
|
||||
if self._sched:
|
||||
|
@ -192,7 +194,7 @@ class Trainer(EnforceOverrides):
|
|||
'model': self.model.state_dict(),
|
||||
'optim': self._optim.state_dict(),
|
||||
'sched': self._sched.state_dict() if self._sched else None,
|
||||
'amp': self._amp.state_dict()
|
||||
'amp': get_apex_utils().state_dict()
|
||||
}
|
||||
self._checkpoint['trainer'] = state
|
||||
|
||||
|
@ -202,7 +204,7 @@ class Trainer(EnforceOverrides):
|
|||
|
||||
logger.pushd('steps')
|
||||
for step, (x, y) in enumerate(train_dl):
|
||||
x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
|
||||
x, y = x.to(get_device(), non_blocking=True), y.to(get_device(), non_blocking=True)
|
||||
|
||||
logger.pushd(step)
|
||||
assert self.model.training # derived class might alter the mode
|
||||
|
@ -220,10 +222,10 @@ class Trainer(EnforceOverrides):
|
|||
loss = self.compute_loss(self._lossfn, x, y, logits,
|
||||
self._aux_weight, aux_logits)
|
||||
|
||||
self._amp.backward(loss, self._optim)
|
||||
get_apex_utils().backward(loss, self._optim)
|
||||
|
||||
# TODO: original darts clips alphas as well but pt.darts doesn't
|
||||
self._amp.clip_grad(self._grad_clip, self.model, self._optim)
|
||||
get_apex_utils().clip_grad(self._grad_clip, self.model, self._optim)
|
||||
|
||||
self._optim.step()
|
||||
if self._sched and not self._sched_on_epoch:
|
||||
|
|
|
@ -118,8 +118,10 @@ def zero_file(filepath)->None:
|
|||
"""Creates or truncates existing file"""
|
||||
open(filepath, 'w').close()
|
||||
|
||||
def setup_logging(filepath:Optional[str]=None,
|
||||
name:Optional[str]=None, level=logging.INFO)->logging.Logger:
|
||||
def create_logger(filepath:Optional[str]=None,
|
||||
name:Optional[str]=None,
|
||||
level=logging.INFO,
|
||||
enable_stdout=True)->logging.Logger:
|
||||
logger = logging.getLogger()
|
||||
|
||||
# close current handlers
|
||||
|
@ -128,10 +130,13 @@ def setup_logging(filepath:Optional[str]=None,
|
|||
logger.removeHandler(handler)
|
||||
|
||||
logger.setLevel(level)
|
||||
ch = logging.StreamHandler()
|
||||
ch.setLevel(level)
|
||||
ch.setFormatter(logging.Formatter('%(asctime)s %(message)s', '%H:%M'))
|
||||
logger.addHandler(ch)
|
||||
|
||||
if enable_stdout:
|
||||
ch = logging.StreamHandler()
|
||||
ch.setLevel(level)
|
||||
ch.setFormatter(logging.Formatter('%(asctime)s %(message)s', '%H:%M'))
|
||||
logger.addHandler(ch)
|
||||
|
||||
logger.propagate = False # otherwise root logger prints things again
|
||||
|
||||
if filepath:
|
||||
|
@ -209,7 +214,8 @@ def setup_cuda(seed):
|
|||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
#torch.cuda.manual_seed_all(seed)
|
||||
cudnn.benchmark = True
|
||||
cudnn.benchmark = True # set to false if deterministic
|
||||
torch.set_printoptions(precision=10)
|
||||
#cudnn.deterministic = False
|
||||
# torch.cuda.empty_cache()
|
||||
# torch.cuda.synchronize()
|
||||
|
@ -221,3 +227,4 @@ def exec_shell_command(command:str, print_command=True)->None:
|
|||
if print_command:
|
||||
print(command)
|
||||
subprocess.run(command, shell=True, check=True)
|
||||
|
||||
|
|
|
@ -339,7 +339,7 @@ def _eval_tta(conf, augment, reporter):
|
|||
|
||||
loaders = []
|
||||
for _ in range(augment['num_policy']):
|
||||
tl, validloader, tl2, _ = get_dataloaders(augment['dataroot'], ds_name,
|
||||
tl, validloader, tl2 = get_dataloaders(augment['dataroot'], ds_name,
|
||||
, aug, cutout,
|
||||
load_train=True, load_test=True,
|
||||
val_ratio=val_ratio, val_fold=val_fold, n_workers=n_workers)
|
||||
|
|
|
@ -142,7 +142,7 @@ def train_and_eval(conf, val_ratio, val_fold, save_path, only_eval,
|
|||
reporter = lambda **kwargs: 0
|
||||
|
||||
# get dataloaders with transformations and splits applied
|
||||
train_dl, valid_dl, test_dl, trainsampler = get_dataloaders(ds_name,
|
||||
train_dl, valid_dl, test_dl = get_dataloaders(ds_name,
|
||||
batch_size, dataroot, aug, cutout,
|
||||
load_train=True, load_test=True, val_ratio=val_ratio, val_fold=val_fold,
|
||||
horovod=horovod, n_workers=n_workers, max_batches=max_batches)
|
||||
|
|
|
@ -2,7 +2,7 @@ from typing import List, Tuple, Union, Optional
|
|||
|
||||
import os
|
||||
import sys
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
|
@ -21,8 +21,7 @@ from ..common.common import utils
|
|||
from archai.datasets.dataset_provider import DatasetProvider, get_provider_type
|
||||
from ..common.config import Config
|
||||
from .limit_dataset import LimitDataset, DatasetLike
|
||||
|
||||
|
||||
from .distributed_stratified_sampler import DistributedStratifiedSampler
|
||||
|
||||
def get_data(conf_loader:Config)\
|
||||
-> Tuple[Optional[DataLoader], Optional[DataLoader], Optional[DataLoader]]:
|
||||
|
@ -35,7 +34,6 @@ def get_data(conf_loader:Config)\
|
|||
cutout = conf_loader['cutout']
|
||||
val_ratio = conf_loader['val_ratio']
|
||||
val_fold = conf_loader['val_fold']
|
||||
horovod = conf_loader['horovod']
|
||||
load_train = conf_loader['load_train']
|
||||
train_batch = conf_loader['train_batch']
|
||||
train_workers = conf_loader['train_workers']
|
||||
|
@ -46,11 +44,11 @@ def get_data(conf_loader:Config)\
|
|||
|
||||
ds_provider = create_dataset_provider(conf_data)
|
||||
|
||||
train_dl, val_dl, test_dl, *_ = get_dataloaders(ds_provider,
|
||||
train_dl, val_dl, test_dl = get_dataloaders(ds_provider,
|
||||
load_train=load_train, train_batch_size=train_batch,
|
||||
load_test=load_test, test_batch_size=test_batch,
|
||||
aug=aug, cutout=cutout, val_ratio=val_ratio, val_fold=val_fold,
|
||||
train_workers=train_workers, test_workers=test_workers, horovod=horovod,
|
||||
train_workers=train_workers, test_workers=test_workers,
|
||||
max_batches=max_batches)
|
||||
|
||||
assert train_dl is not None
|
||||
|
@ -71,27 +69,24 @@ def get_dataloaders(ds_provider:DatasetProvider,
|
|||
load_test:bool, test_batch_size:int,
|
||||
aug, cutout:int, val_ratio:float, val_fold=0,
|
||||
train_workers:Optional[int]=None, test_workers:Optional[int]=None,
|
||||
horovod=False, target_lb=-1, max_batches:int=-1) \
|
||||
-> Tuple[Optional[DataLoader], Optional[DataLoader],
|
||||
Optional[DataLoader], Optional[Sampler]]:
|
||||
target_lb=-1, max_batches:int=-1) \
|
||||
-> Tuple[Optional[DataLoader], Optional[DataLoader], Optional[DataLoader]]:
|
||||
|
||||
# if debugging in vscode, workers > 0 gets termination
|
||||
if utils.is_debugging():
|
||||
train_workers = test_workers = 0
|
||||
logger.warn({'debugger': True})
|
||||
if train_workers is None:
|
||||
train_workers = torch.cuda.device_count() * 4
|
||||
train_workers = 4
|
||||
if test_workers is None:
|
||||
test_workers = torch.cuda.device_count() * 4
|
||||
test_workers = 4
|
||||
logger.info({'train_workers': train_workers, 'test_workers':test_workers})
|
||||
|
||||
transform_train, transform_test = ds_provider.get_transforms()
|
||||
add_named_augs(transform_train, aug, cutout)
|
||||
|
||||
trainset, testset = _get_datasets(ds_provider,
|
||||
load_train, load_test, transform_train, transform_test,
|
||||
train_max_size=max_batches*train_batch_size,
|
||||
test_max_size=max_batches*test_batch_size)
|
||||
load_train, load_test, transform_train, transform_test)
|
||||
|
||||
# TODO: below will never get executed, set_preaug does not exist in PyTorch
|
||||
# if total_aug is not None and augs is not None:
|
||||
|
@ -100,37 +95,52 @@ def get_dataloaders(ds_provider:DatasetProvider,
|
|||
|
||||
trainloader, validloader, testloader, train_sampler = None, None, None, None
|
||||
|
||||
max_train_fold = max_batches*train_batch_size if max_batches else None
|
||||
max_test_fold = max_batches*test_batch_size if max_batches else None
|
||||
logger.info({'val_ratio': val_ratio, 'max_batches': max_batches,
|
||||
'max_train_fold': max_train_fold, 'max_test_fold': max_test_fold})
|
||||
|
||||
if trainset:
|
||||
# sample validation set from trainset if cv_ratio > 0
|
||||
train_sampler, valid_sampler = _get_train_sampler(val_ratio, val_fold,
|
||||
trainset, horovod, target_lb)
|
||||
train_sampler, valid_sampler = _get_sampler(trainset, val_ratio=val_ratio,
|
||||
shuffle=True,
|
||||
max_items=max_train_fold)
|
||||
|
||||
# shuffle is performed by sampler at each epoch
|
||||
trainloader = DataLoader(trainset,
|
||||
batch_size=train_batch_size, shuffle=True if train_sampler is None else False,
|
||||
num_workers=train_workers, pin_memory=True,
|
||||
batch_size=train_batch_size, shuffle=False,
|
||||
num_workers=round((1-val_ratio)*train_workers),
|
||||
pin_memory=True,
|
||||
sampler=train_sampler, drop_last=False) # TODO: original paper has this True
|
||||
if train_sampler is not None:
|
||||
|
||||
if val_ratio > 0.0:
|
||||
validloader = DataLoader(trainset,
|
||||
batch_size=train_batch_size, shuffle=False,
|
||||
num_workers=train_workers, pin_memory=True, #TODO: set n_workers per ratio?
|
||||
num_workers=round(val_ratio*train_workers), # if val_ratio = 0.5, then both sets re same
|
||||
pin_memory=True, #TODO: set n_workers per ratio?
|
||||
sampler=valid_sampler, drop_last=False)
|
||||
# else validloader is left as None
|
||||
if testset:
|
||||
test_sampler, _ = _get_sampler(testset, val_ratio=0.0,
|
||||
shuffle=False,
|
||||
max_items=max_test_fold)
|
||||
testloader = DataLoader(testset,
|
||||
batch_size=test_batch_size, shuffle=False,
|
||||
num_workers=test_workers, pin_memory=True,
|
||||
sampler=None, drop_last=False
|
||||
num_workers=test_workers,
|
||||
pin_memory=True,
|
||||
sampler=test_sampler, drop_last=False
|
||||
)
|
||||
|
||||
assert val_ratio > 0.0 or validloader is None
|
||||
|
||||
logger.info({
|
||||
'train_batch_size': train_batch_size, 'test_batch_size': test_batch_size,
|
||||
'train_batches': len(trainloader) if trainloader is not None else None,
|
||||
'val_batches': len(validloader) if validloader is not None else None,
|
||||
'test_batches': len(testloader) if testloader is not None else None
|
||||
})
|
||||
|
||||
# we have to return train_sampler because of horovod
|
||||
return trainloader, validloader, testloader, train_sampler
|
||||
return trainloader, validloader, testloader
|
||||
|
||||
|
||||
class SubsetSampler(Sampler):
|
||||
|
@ -150,75 +160,26 @@ class SubsetSampler(Sampler):
|
|||
return len(self.indices)
|
||||
|
||||
def _get_datasets(ds_provider:DatasetProvider, load_train:bool, load_test:bool,
|
||||
transform_train, transform_test, train_max_size:int, test_max_size:int)\
|
||||
transform_train, transform_test)\
|
||||
->Tuple[DatasetLike, DatasetLike]:
|
||||
trainset, testset = ds_provider.get_datasets(load_train, load_test,
|
||||
transform_train, transform_test)
|
||||
|
||||
if train_max_size > 0:
|
||||
logger.warn({'train_max_size': train_max_size})
|
||||
trainset = LimitDataset(trainset, train_max_size)
|
||||
if test_max_size > 0:
|
||||
logger.warn({'test_max_size': test_max_size})
|
||||
testset = LimitDataset(testset, test_max_size)
|
||||
|
||||
return trainset, testset
|
||||
|
||||
# target_lb allows to filter dataset for a specific class, not used
|
||||
def _get_train_sampler(val_ratio:float, val_fold:int, trainset, horovod,
|
||||
target_lb:int=-1)->Tuple[Optional[Sampler], Sampler]:
|
||||
"""Splits train set into train, validation sets, stratified rand sampling.
|
||||
|
||||
Arguments:
|
||||
val_ratio {float} -- % of data to put in valid set
|
||||
val_fold {int} -- Total of 5 folds are created, val_fold specifies which
|
||||
one to use
|
||||
target_lb {int} -- If >= 0 then trainset is filtered for only that
|
||||
target class ID
|
||||
"""
|
||||
assert val_fold >= 0
|
||||
|
||||
train_sampler, valid_sampler = None, None
|
||||
logger.info({'val_ratio': val_ratio})
|
||||
if val_ratio > 0.0: # if val_ratio is not specified then sampler is empty
|
||||
"""stratified shuffle val_ratio will yield return total of n_splits,
|
||||
each val_ratio containing tuple of train and valid set with valid set
|
||||
size portion = val_ratio, while samples for each class having same
|
||||
proportions as original dataset"""
|
||||
def _get_sampler(dataset:Dataset, val_ratio:Optional[float], shuffle:bool,
|
||||
max_items:Optional[int])->Tuple[Sampler, Optional[Sampler]]:
|
||||
# we cannot not shuffle just for train or just val because of in distributed mode both must come from same shrad
|
||||
train_sampler = DistributedStratifiedSampler(dataset,
|
||||
val_ratio=val_ratio, is_val=False, shuffle=shuffle,
|
||||
max_items=max_items)
|
||||
valid_sampler = DistributedStratifiedSampler(dataset,
|
||||
val_ratio=val_ratio, is_val=True, shuffle=shuffle,
|
||||
max_items=max_items) \
|
||||
if val_ratio is not None else None
|
||||
|
||||
|
||||
# TODO: random_state should be None so np.random is used
|
||||
# TODO: keep hardcoded n_splits=5?
|
||||
sss = StratifiedShuffleSplit(n_splits=5, test_size=val_ratio,
|
||||
random_state=0)
|
||||
sss = sss.split(list(range(len(trainset))), trainset.targets)
|
||||
|
||||
# we have 5 plits, but will select only one of them by val_fold
|
||||
for _ in range(val_fold + 1):
|
||||
train_idx, valid_idx = next(sss)
|
||||
|
||||
if target_lb >= 0:
|
||||
train_idx = [i for i in train_idx if trainset.targets[i] == target_lb]
|
||||
valid_idx = [i for i in valid_idx if trainset.targets[i] == target_lb]
|
||||
|
||||
# NOTE: we apply random sampler for validation set as well because
|
||||
# this set is used for training alphas for darts
|
||||
train_sampler = SubsetRandomSampler(train_idx)
|
||||
valid_sampler = SubsetRandomSampler(valid_idx)
|
||||
|
||||
if horovod: # train sampler for horovod
|
||||
import horovod.torch as hvd
|
||||
train_sampler = DistributedSampler(
|
||||
train_sampler, num_replicas=hvd.size(), rank=hvd.rank())
|
||||
else:
|
||||
# this means no sampling, validation set would be empty
|
||||
valid_sampler = SubsetSampler([])
|
||||
|
||||
if horovod: # train sampler for horovod
|
||||
import horovod.torch as hvd
|
||||
train_sampler = DistributedSampler(
|
||||
valid_sampler, num_replicas=hvd.size(), rank=hvd.rank())
|
||||
# else train_sampler is None
|
||||
return train_sampler, valid_sampler
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,170 @@
|
|||
import math
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Sampler
|
||||
import torch.distributed as dist
|
||||
from torch.utils.data.dataset import Dataset
|
||||
|
||||
import numpy as np
|
||||
from sklearn.model_selection import StratifiedKFold, StratifiedShuffleSplit
|
||||
|
||||
class DistributedStratifiedSampler(Sampler):
|
||||
def __init__(self, dataset:Dataset, num_replicas:Optional[int]=None,
|
||||
rank:Optional[int]=None, shuffle=True,
|
||||
val_ratio=0.0, is_val=False, auto_epoch=True,
|
||||
max_items:Optional[int]=None):
|
||||
"""Performs stratified sampling of dataset for each replica in the distributed as well as non-distributed setting. If validation split is needed then yet another stratified sampling within replica's split is performed to further obtain the train/validation splits.
|
||||
|
||||
This sampler works in distributed as well as non-distributed setting with no panelty in either mode and is replacement for built-in torch.util.data.DistributedSampler. In distributed setting, many instances of the same code runs as process known as replicas. Each replica has sequential number assigned by the launcher, starting from 0 to uniquely identify it. This is known as global rank or simply rank. The number of replicas is known as the world size. For non-distributed setting, world_size=1 and rank=0.
|
||||
|
||||
To perform stratified sampling we need labels. This sampler assumes that labels for each datapoint is available in dataset.targets property which should be array like containing as many values as length of the dataset. This is availalble already for many popular datasets such as cifar and, with newer PyTorch versions, ImageFolder as well as DatasetFolder. If you are using custom dataset, you can usually create this property with one line of code such as `dataset.targets = [yi for _, yi in dataset]`.
|
||||
|
||||
Generally, to do distributed sampling, each replica must shuffle with same seed as all other replicas with every epoch and then chose some subset of dataset for itself. Traditionally, we use epoch number as seed for shuffling for each replica. However, this then requires that training code calls sampler.set_epoch(epoch) to set seed at every epoch. This breaks many training code which doesn't have access to the sampler. However, this is unnecessory as well. Sampler knows when each new iteration is requested, so it can automatically increment epoch and make call to set_epoch by itself. This is what this sampler does. This makes things transparent to users in distributed or non-distributed setting. If you don't want this behaviour then pass auto_epoch=False.
|
||||
|
||||
Arguments:
|
||||
dataset -- PyTorch dataset like object
|
||||
|
||||
Keyword Arguments:
|
||||
num_replicas -- Total number of replicas running in distributed setting, if None then auto-detect, 1 for non distributed setting (default: {None})
|
||||
rank -- Global rank of this replica, if None then auto-detect, 0 for non distributed setting (default: {None})
|
||||
shuffle {bool} -- If True then suffle at every epoch (default: {True})
|
||||
val_ratio {float} -- If you want to create validation split then set to > 0 (default: {0.0})
|
||||
is_val {bool} -- If True then validation split is returned set to val_ratio otherwise main split is returned (default: {False})
|
||||
auto_epoch {bool} -- if True then automatically count epoch for each new iteration eliminating the need to call set_epoch() in distributed setting (default: {True})
|
||||
max_items -- if not None then dataset will be trimmed to these many items for each replica (useful to test on smaller dataset)
|
||||
"""
|
||||
|
||||
|
||||
# cifar10 amd DatasetFolder has this attribute, for others it may be easy to add from outside
|
||||
assert hasattr(dataset, 'targets') and dataset.targets is not None, 'dataset needs to have targets attribute to work with this sampler'
|
||||
|
||||
if num_replicas is None:
|
||||
if not dist.is_available():
|
||||
num_replicas = 1
|
||||
else:
|
||||
num_replicas = dist.get_world_size()
|
||||
if rank is None:
|
||||
if not dist.is_available():
|
||||
rank = 0
|
||||
else:
|
||||
rank = dist.get_rank()
|
||||
|
||||
assert num_replicas >= 1
|
||||
assert rank >= 0 and rank < num_replicas
|
||||
assert val_ratio < 1.0 and val_ratio >= 0.0
|
||||
|
||||
self.dataset = dataset
|
||||
self.num_replicas = num_replicas
|
||||
self.rank = rank
|
||||
self.epoch = -1
|
||||
self.auto_epoch = auto_epoch
|
||||
self.shuffle = shuffle
|
||||
self.data_len = len(self.dataset)
|
||||
self.max_items = max_items
|
||||
assert self.data_len == len(dataset.targets)
|
||||
self.val_ratio = val_ratio
|
||||
self.is_val = is_val
|
||||
|
||||
# computing duplications we needs
|
||||
self.replica_len_full = int(math.ceil(float(self.data_len)/self.num_replicas))
|
||||
self.total_size = self.replica_len_full * self.num_replicas
|
||||
assert self.total_size >= self.data_len
|
||||
|
||||
if self.max_items:
|
||||
self.replica_len = min(self.replica_len_full, self.max_items)
|
||||
|
||||
self.main_split_len = int(math.floor(self.replica_len*(1-val_ratio)))
|
||||
self.val_split_len = self.replica_len - self.main_split_len
|
||||
self._len = self.val_split_len if self.is_val else self.main_split_len
|
||||
|
||||
|
||||
def __iter__(self):
|
||||
# every time new iterator is requested, we increase epoch count
|
||||
# assuming all replicas do same thing, this avoids need to call set_epoch
|
||||
if self.auto_epoch:
|
||||
self.epoch += 1
|
||||
|
||||
# get shuffled indices, dataset is extended if needed to divide equally
|
||||
# between replicas
|
||||
indices, targets = self._indices()
|
||||
|
||||
# get the fold which we will assign to current replica
|
||||
indices, targets = self._replica_fold(indices, targets)
|
||||
|
||||
indices, targets = self._limit(indices, targets, self.max_items)
|
||||
|
||||
# split current replica's fold between train and val
|
||||
# return indices depending on if we are val or train split
|
||||
indices, _ = self._split(indices, targets, self.val_split_len, self.is_val)
|
||||
assert len(indices) == self._len
|
||||
|
||||
return iter(indices)
|
||||
|
||||
def _replica_fold(self, indices:np.ndarray, targets:np.ndarray)\
|
||||
->Tuple[np.ndarray, np.ndarray]:
|
||||
|
||||
if self.num_replicas > 1:
|
||||
replica_fold_idxs = None
|
||||
rfolder = StratifiedKFold(n_splits=self.num_replicas, shuffle=False)
|
||||
folds = rfolder.split(indices, targets)
|
||||
for _ in range(self.rank + 1):
|
||||
other_fold_idxs, replica_fold_idxs = next(folds)
|
||||
|
||||
assert replica_fold_idxs is not None and \
|
||||
len(replica_fold_idxs)==self.replica_len_full
|
||||
|
||||
return indices[replica_fold_idxs], targets[replica_fold_idxs]
|
||||
else:
|
||||
assert self.num_replicas == 1
|
||||
return indices, targets
|
||||
|
||||
|
||||
def _indices(self)->Tuple[np.ndarray, np.ndarray]:
|
||||
# deterministically shuffle based on epoch
|
||||
g = torch.Generator()
|
||||
g.manual_seed(self.epoch)
|
||||
if self.shuffle:
|
||||
indices = torch.randperm(self.data_len, generator=g).numpy()
|
||||
else:
|
||||
indices = np.arange(self.data_len)
|
||||
|
||||
# add extra samples to make it evenly divisible
|
||||
# this is neccesory because we have __len__ which must return same
|
||||
# number consistently
|
||||
if self.total_size > self.data_len:
|
||||
indices = np.append(indices, indices[:(self.total_size - self.data_len)])
|
||||
else:
|
||||
assert self.total_size == self.data_len, 'total_size cannot be less than dataset size!'
|
||||
|
||||
targets = np.array(list(self.dataset.targets[i] for i in indices))
|
||||
assert len(indices) == self.total_size
|
||||
|
||||
return indices, targets
|
||||
|
||||
def _limit(self, indices:np.ndarray, targets:np.ndarray, max_items:Optional[int])\
|
||||
->Tuple[np.ndarray, np.ndarray]:
|
||||
if max_items:
|
||||
return self._split(indices, targets, max_items, True)
|
||||
return indices, targets
|
||||
|
||||
def _split(self, indices:np.ndarray, targets:np.ndarray, test_size:int,
|
||||
return_test_split:bool)->Tuple[np.ndarray, np.ndarray]:
|
||||
if test_size:
|
||||
assert isinstance(test_size, int) # othewise next call assumes ratio instead of count
|
||||
vfolder = StratifiedShuffleSplit(n_splits=1,
|
||||
test_size=test_size,
|
||||
random_state=self.epoch)
|
||||
vfolder = vfolder.split(indices, targets)
|
||||
train_idx, valid_idx = next(vfolder)
|
||||
|
||||
idxs = valid_idx if return_test_split else train_idx
|
||||
return indices[idxs], targets[idxs]
|
||||
else:
|
||||
return indices, targets
|
||||
|
||||
def __len__(self):
|
||||
return self._len
|
||||
|
||||
def set_epoch(self, epoch):
|
||||
self.epoch = epoch
|
|
@ -58,7 +58,8 @@ class ImagenetProvider(DatasetProvider):
|
|||
transform_train, transform_test = None, None
|
||||
|
||||
transform_train = transforms.Compose([
|
||||
transforms.RandomResizedCrop(224, scale=(0.08, 1.0),
|
||||
transforms.RandomResizedCrop(224,
|
||||
scale=(0.08, 1.0), # TODO: these two params are normally not specified
|
||||
interpolation=Image.BICUBIC),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.ColorJitter(
|
||||
|
|
|
@ -20,9 +20,9 @@ from ..common.checkpoint import CheckPoint
|
|||
TArchTrainer = Optional[Type['ArchTrainer']]
|
||||
|
||||
class ArchTrainer(Trainer, EnforceOverrides):
|
||||
def __init__(self, conf_train: Config, model: Model, device,
|
||||
def __init__(self, conf_train: Config, model: Model,
|
||||
checkpoint:Optional[CheckPoint]) -> None:
|
||||
super().__init__(conf_train, model, device, checkpoint)
|
||||
super().__init__(conf_train, model, checkpoint)
|
||||
|
||||
self._l1_alphas = conf_train['l1_alphas']
|
||||
self._plotsdir = conf_train['plotsdir']
|
||||
|
|
|
@ -26,19 +26,17 @@ def eval_arch(conf_eval:Config, cell_builder:Optional[CellBuilder]):
|
|||
conf_train = conf_eval['trainer']
|
||||
# endregion
|
||||
|
||||
device = torch.device(conf_eval['device'])
|
||||
|
||||
if cell_builder:
|
||||
cell_builder.register_ops()
|
||||
|
||||
model = create_model(conf_eval, device)
|
||||
model = create_model(conf_eval)
|
||||
|
||||
# get data
|
||||
train_dl, _, test_dl = data.get_data(conf_loader)
|
||||
assert train_dl is not None and test_dl is not None
|
||||
|
||||
checkpoint = nas_utils.create_checkpoint(conf_checkpoint, resume)
|
||||
trainer = Trainer(conf_train, model, device, checkpoint)
|
||||
trainer = Trainer(conf_train, model, checkpoint)
|
||||
train_metrics = trainer.fit(train_dl, test_dl)
|
||||
train_metrics.save(metric_filename)
|
||||
|
||||
|
@ -65,7 +63,7 @@ def _default_module_name(dataset_name:str, function_name:str)->str:
|
|||
raise NotImplementedError(f'Cannot get default module for {function_name} and dataset {dataset_name} because it is not supported yet')
|
||||
return module_name
|
||||
|
||||
def create_model(conf_eval:Config, device)->nn.Module:
|
||||
def create_model(conf_eval:Config)->nn.Module:
|
||||
# region conf vars
|
||||
dataset_name = conf_eval['loader']['dataset']['name']
|
||||
final_desc_filename = conf_eval['final_desc_filename']
|
||||
|
@ -86,7 +84,6 @@ def create_model(conf_eval:Config, device)->nn.Module:
|
|||
module = importlib.import_module(module_name) if module_name else sys.modules[__name__]
|
||||
function = getattr(module, function_name)
|
||||
model = function()
|
||||
model = nas_utils.to_device(model, device)
|
||||
|
||||
logger.info({'model_factory':True,
|
||||
'module_name': module_name,
|
||||
|
@ -97,8 +94,7 @@ def create_model(conf_eval:Config, device)->nn.Module:
|
|||
template_model_desc = ModelDesc.load(final_desc_filename)
|
||||
|
||||
model = nas_utils.model_from_conf(full_desc_filename,
|
||||
conf_model_desc, device,
|
||||
affine=True, droppath=True,
|
||||
conf_model_desc, affine=True, droppath=True,
|
||||
template_model_desc=template_model_desc)
|
||||
|
||||
logger.info({'model_factory':False,
|
||||
|
|
|
@ -38,25 +38,17 @@ def create_checkpoint(conf_checkpoint:Config, resume:bool)->Optional[CheckPoint]
|
|||
return checkpoint
|
||||
|
||||
def model_from_conf(full_desc_filename:str, conf_model_desc: Config,
|
||||
device, affine:bool, droppath:bool,
|
||||
template_model_desc:ModelDesc)->Model:
|
||||
affine:bool, droppath:bool, template_model_desc:ModelDesc)->Model:
|
||||
"""Creates model given desc config and template"""
|
||||
# create model
|
||||
model_desc = create_macro_desc(conf_model_desc, template_model_desc)
|
||||
# save model that we would eval for reference
|
||||
model_desc.save(full_desc_filename)
|
||||
|
||||
return model_from_desc(model_desc, device, droppath=droppath, affine=affine)
|
||||
return model_from_desc(model_desc, droppath=droppath, affine=affine)
|
||||
|
||||
def model_from_desc(model_desc, device, droppath:bool, affine:bool)->Model:
|
||||
def model_from_desc(model_desc, droppath:bool, affine:bool)->Model:
|
||||
model = Model(model_desc, droppath=droppath, affine=affine)
|
||||
return to_device(model, device) # type: ignore
|
||||
|
||||
def to_device(model:nn.Module, device)->nn.Module:
|
||||
# TODO: enable DataParallel
|
||||
# if data_parallel:
|
||||
# model = nn.DataParallel(model).to(device)
|
||||
# else:
|
||||
model = model.to(device)
|
||||
return model
|
||||
|
||||
|
||||
|
|
|
@ -110,7 +110,6 @@ class Search:
|
|||
pareto_summary_filename = conf_pareto['summary_filename']
|
||||
# endregion
|
||||
|
||||
self.device = torch.device(conf_search['device'])
|
||||
self.cell_builder = cell_builder
|
||||
self.trainer_class = trainer_class
|
||||
self._data_cache = {}
|
||||
|
@ -317,14 +316,15 @@ class Search:
|
|||
# nothing to pretrain, save time
|
||||
metrics_stats = MetricsStats(model_desc, None, None)
|
||||
else:
|
||||
model = nas_utils.model_from_desc(model_desc, self.device,
|
||||
droppath=drop_path_prob>0.0, affine=True)
|
||||
model = nas_utils.model_from_desc(model_desc,
|
||||
droppath=drop_path_prob>0.0,
|
||||
affine=True)
|
||||
|
||||
# get data
|
||||
train_dl, val_dl = self.get_data(conf_loader)
|
||||
assert train_dl is not None
|
||||
|
||||
trainer = Trainer(conf_trainer, model, self.device, checkpoint=None)
|
||||
trainer = Trainer(conf_trainer, model, checkpoint=None)
|
||||
train_metrics = trainer.fit(train_dl, val_dl)
|
||||
|
||||
metrics_stats = Search._create_metrics_stats(model, train_metrics)
|
||||
|
@ -346,16 +346,16 @@ class Search:
|
|||
nas_utils.build_cell(model_desc, self.cell_builder, search_iter)
|
||||
|
||||
if self.trainer_class:
|
||||
model = nas_utils.model_from_desc(model_desc, self.device,
|
||||
droppath=False, affine=False)
|
||||
model = nas_utils.model_from_desc(model_desc,
|
||||
droppath=False,
|
||||
affine=False)
|
||||
|
||||
# get data
|
||||
train_dl, val_dl = self.get_data(self.conf_loader)
|
||||
assert train_dl is not None
|
||||
|
||||
# search arch
|
||||
arch_trainer = self.trainer_class(self.conf_train, model, self.device,
|
||||
checkpoint=None)
|
||||
arch_trainer = self.trainer_class(self.conf_train, model, checkpoint=None)
|
||||
train_metrics = arch_trainer.fit(train_dl, val_dl)
|
||||
|
||||
metrics_stats = Search._create_metrics_stats(model, train_metrics)
|
||||
|
|
|
@ -11,7 +11,7 @@ from .wideresnet import WideResNet
|
|||
from .shakeshake.shake_resnext import ShakeResNeXt
|
||||
|
||||
|
||||
def get_model(conf, num_class=10, data_parallel=True):
|
||||
def get_model(conf, num_class=10):
|
||||
name = conf['type']
|
||||
|
||||
if name == 'resnet50':
|
||||
|
@ -40,15 +40,6 @@ def get_model(conf, num_class=10, data_parallel=True):
|
|||
else:
|
||||
raise NameError('no model named, %s' % name)
|
||||
|
||||
if data_parallel:
|
||||
model = model.cuda()
|
||||
model = DataParallel(model)
|
||||
else:
|
||||
import horovod.torch as hvd
|
||||
device = torch.device('cuda', hvd.local_rank())
|
||||
model = model.to(device)
|
||||
return model
|
||||
|
||||
|
||||
def num_class(dataset):
|
||||
return {
|
||||
|
|
|
@ -7,8 +7,6 @@ common:
|
|||
seed: 2.0
|
||||
tb_enable: False # if True then TensorBoard logging is enabled (may impact perf)
|
||||
tb_dir: '$expdir/tb' # path where tensorboard logs would be stored
|
||||
horovod: False
|
||||
device: 'cuda'
|
||||
checkpoint:
|
||||
filename: '$expdir/checkpoint.pth'
|
||||
freq: 10
|
||||
|
@ -19,7 +17,17 @@ common:
|
|||
# otherwise it should something like host:6379. Make sure to run on head node:
|
||||
# "ray start --head --redis-port=6379"
|
||||
redis: null
|
||||
gpus: 0 # use GPU IDs specified here (comma separated), if null then use all GPUs
|
||||
apex:
|
||||
gpus: '' # use GPU IDs specified here (comma separated), if '' then use all GPUs
|
||||
enabled: False # global switch to disable anything apex
|
||||
opt_level: 'O1' # optimization level for mixed precision
|
||||
bn_fp32: True # keep BN in fp32
|
||||
loss_scale: None # loss scaling mode for mixed prec
|
||||
sync_bn: False # should be replace BNs with sync BNs for distributed model
|
||||
distributed: False # enable/disable distributed mode
|
||||
scale_lr: True # enable/disable distributed mode
|
||||
min_world_size: 0 # allows to confirm we are indeed in distributed setting
|
||||
|
||||
|
||||
smoke_test: False
|
||||
only_eval: False
|
||||
|
@ -39,7 +47,6 @@ nas:
|
|||
|
||||
metric_filename: '$expdir/eval_train_metrics.yaml'
|
||||
model_filename: '$expdir/model.pt' # file to which trained model will be saved
|
||||
device: '_copy: common/device'
|
||||
data_parallel: False
|
||||
checkpoint:
|
||||
_copy: 'common/checkpoint'
|
||||
|
@ -74,7 +81,6 @@ nas:
|
|||
val_ratio: 0.0 #split portion for test set, 0 to 1
|
||||
val_fold: 0 #Fold number to use (0 to 4)
|
||||
cv_num: 5 # total number of folds available
|
||||
horovod: '_copy: common/horovod'
|
||||
dataset:
|
||||
_copy: '/dataset'
|
||||
trainer:
|
||||
|
@ -114,7 +120,6 @@ nas:
|
|||
full_desc_filename: '$expdir/full_model_desc.yaml' # arch before it was finalized
|
||||
final_desc_filename: '$expdir/final_model_desc.yaml' # final arch is saved in this file
|
||||
metrics_dir: '$expdir/models/{reductions}/{cells}/{nodes}/{search_iter}' # where metrics and model stats would be saved from each pareto iteration
|
||||
device: '_copy: common/device'
|
||||
seed_train:
|
||||
trainer:
|
||||
_copy: 'nas/eval/trainer'
|
||||
|
@ -176,7 +181,6 @@ nas:
|
|||
val_ratio: 0.5 #split portion for test set, 0 to 1
|
||||
val_fold: 0 #Fold number to use (0 to 4)
|
||||
cv_num: 5 # total number of folds available
|
||||
horovod: '_copy: common/horovod'
|
||||
dataset:
|
||||
_copy: '/dataset'
|
||||
trainer:
|
||||
|
@ -233,7 +237,6 @@ autoaug:
|
|||
val_ratio: 0.4 #split portion for test set, 0 to 1
|
||||
val_fold: 0 #Fold number to use (0 to 4)
|
||||
cv_num: 5 # total number of folds available
|
||||
horovod: '_copy: common/horovod'
|
||||
dataset:
|
||||
_copy: '/dataset'
|
||||
optimizer:
|
||||
|
|
|
@ -9,7 +9,7 @@ from torch_testbed.timing import MeasureTime, print_all_timings, print_timing, g
|
|||
from torch_testbed.dataloader_dali import cifar10_dataloaders
|
||||
|
||||
|
||||
utils.setup_logging()
|
||||
utils.create_logger()
|
||||
utils.setup_cuda(42)
|
||||
|
||||
batch_size = 512
|
||||
|
|
|
@ -14,12 +14,12 @@ conf_model_desc = conf_eval['model_desc']
|
|||
|
||||
conf_model_desc['n_cells'] = 14
|
||||
template_model_desc = ModelDesc.load('$expdir/final_model_desc.yaml')
|
||||
model_desc = create_macro_desc(conf_model_desc, True, template_model_desc)
|
||||
model_desc = create_macro_desc(conf_model_desc, template_model_desc)
|
||||
|
||||
mb = PetridishCellBuilder()
|
||||
mb.register_ops()
|
||||
model = Model(model_desc, droppath=False, affine=False)
|
||||
#model.cuda()
|
||||
|
||||
summary(model, [64, 3, 32, 32])
|
||||
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ import numpy as np
|
|||
from torch_testbed import utils, cifar10_models
|
||||
from torch_testbed.timing import MeasureTime, print_all_timings, print_timing, get_timing
|
||||
|
||||
utils.setup_logging()
|
||||
utils.create_logger()
|
||||
utils.setup_cuda(42)
|
||||
|
||||
batch_size = 512
|
||||
|
|
|
@ -5,7 +5,7 @@ from archai import cifar10_models
|
|||
|
||||
from archai.common.trainer import Trainer
|
||||
from archai.common.config import Config
|
||||
from archai.common.common import logger, common_init
|
||||
from archai.common.common import logger, common_init, get_device
|
||||
from archai.datasets import data
|
||||
|
||||
def train_test(conf_eval:Config):
|
||||
|
@ -23,15 +23,14 @@ def train_test(conf_eval:Config):
|
|||
conf_trainer['grad_clip'] = 0.0
|
||||
conf_trainer['aux_weight'] = 0.0
|
||||
|
||||
device = torch.device(conf_eval['device'])
|
||||
Net = cifar10_models.resnet34
|
||||
model = Net().to(device)
|
||||
model = Net().to(get_device())
|
||||
|
||||
# get data
|
||||
train_dl, _, test_dl = data.get_data(conf_loader)
|
||||
assert train_dl is not None and test_dl is not None
|
||||
|
||||
trainer = Trainer(conf_trainer, model, device, None)
|
||||
trainer = Trainer(conf_trainer, model, None)
|
||||
trainer.fit(train_dl, test_dl)
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,86 @@
|
|||
from archai.datasets.distributed_stratified_sampler import DistributedStratifiedSampler
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
import numpy as np
|
||||
import math
|
||||
import time
|
||||
from collections import Counter
|
||||
import random
|
||||
|
||||
class ListDataset(Dataset):
|
||||
def __init__(self, x, y, transform=None):
|
||||
self.x = x
|
||||
self.targets = self.y = np.array(y)
|
||||
self.transform = transform
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.x[index], self.y[index]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.x)
|
||||
|
||||
def _dist_no_val(rep_count:int, data_len=1000, labels_len=2, val_ratio=0.0):
|
||||
x = np.random.randint(-data_len, data_len, data_len)
|
||||
labels = np.array(range(labels_len))
|
||||
y = np.repeat(labels, math.ceil(float(data_len)/labels_len))[:data_len]
|
||||
np.random.shuffle(y)
|
||||
dataset = ListDataset(x, y)
|
||||
|
||||
train_samplers, val_samplers = [], []
|
||||
for i in range(rep_count):
|
||||
train_samplers.append(DistributedStratifiedSampler(dataset,
|
||||
num_replicas=rep_count,
|
||||
rank=i,
|
||||
val_ratio=val_ratio,
|
||||
is_val=False))
|
||||
val_samplers.append(DistributedStratifiedSampler(dataset,
|
||||
num_replicas=rep_count,
|
||||
rank=i,
|
||||
val_ratio=val_ratio,
|
||||
is_val=True))
|
||||
tl = [list(iter(s)) for s in train_samplers]
|
||||
vl = [list(iter(s)) for s in val_samplers]
|
||||
|
||||
l = [tli+vli for tli, vli in zip(tl,vl)] # combile train val
|
||||
all_len = sum((len(li) for li in l))
|
||||
u = set(i for li in l for i in li)
|
||||
|
||||
# verify stratification
|
||||
for vli, tli in zip(vl, tl):
|
||||
vlic = Counter(dataset.targets[vli])
|
||||
assert len(vlic.keys()) == labels_len
|
||||
assert max(vlic.values())-min(vlic.values()) <=2
|
||||
tlic = Counter(dataset.targets[tli])
|
||||
assert len(tlic.keys()) == labels_len
|
||||
assert max(tlic.values())-min(tlic.values()) <=2
|
||||
|
||||
# below means all indices are equally divided between shards
|
||||
assert len(set((len(li) for li in l)))==1 # all shards equal
|
||||
assert all((len(li)>=len(dataset)/rep_count for li in l))
|
||||
assert all((len(li)<=len(dataset)/rep_count+1 for li in l))
|
||||
assert min(u)==0
|
||||
assert max(u)==len(x)-1
|
||||
assert len(u)==len(x)
|
||||
assert all((float(len(vli))/(len(vli)+len(tli))>=val_ratio for vli, tli in zip(vl, tl)))
|
||||
assert all(((len(vli)-1.0)/(len(vli)+len(tli))<=val_ratio for vli, tli in zip(vl, tl)))
|
||||
assert all((len(set(vli).union(tli))==len(vli+tli) for vli, tli in zip(vl, tl)))
|
||||
assert all_len <= math.ceil(len(x)/rep_count)*rep_count
|
||||
|
||||
def test_combinations():
|
||||
st = time.time()
|
||||
labels_len = 2
|
||||
combs = 0
|
||||
random.seed(0)
|
||||
for data_len in (100, 1001, 17777):
|
||||
max_rep = int(math.sqrt(data_len)*3)
|
||||
for rep_count in range(1, max_rep, max(1, max_rep//17)):
|
||||
for val_num in range(0, random.randint(0,5)):
|
||||
combs += 1
|
||||
val_ratio = val_num/11.0 # good to have prime numbers
|
||||
if math.floor(val_ratio*data_len/rep_count) >= labels_len:
|
||||
_dist_no_val(rep_count=rep_count, val_ratio=val_ratio, data_len=data_len, labels_len=labels_len)
|
||||
elapsed = time.time()-st
|
||||
print('elapsed', elapsed, 'combs', combs)
|
||||
|
||||
_dist_no_val(1, 100, val_ratio=0.1)
|
||||
test_combinations()
|
Загрузка…
Ссылка в новой задаче