Mirror of https://github.com/microsoft/archai.git

Commit a6c5073520 (parent 980f5b7b18):
remove imports for warmup_scheduler, add apex support in loader, allow None val_ratio, num_replicas -> world_size, partial test case for imagenet

@@ -2,9 +2,6 @@ import torch
 import torch.backends.cudnn as cudnn
 from torch import nn
 from torch.optim import lr_scheduler, SGD, Adam
-from warmup_scheduler import GradualWarmupScheduler
 from torch.optim.lr_scheduler import _LRScheduler
 from torch.optim.optimizer import Optimizer
 from torch.nn.modules.loss import _WeightedLoss, _Loss
 import torch.nn.functional as F

@@ -12,7 +12,6 @@ import torch
 import torch.backends.cudnn as cudnn
 from torch import nn
 from torch.optim import lr_scheduler, SGD, Adam
-from warmup_scheduler import GradualWarmupScheduler
 from torch.optim.lr_scheduler import _LRScheduler
 from torch.optim.optimizer import Optimizer
 from torch.nn.modules.loss import _WeightedLoss, _Loss

@@ -17,7 +17,7 @@ from torch.utils.data.distributed import DistributedSampler

 from .augmentation import add_named_augs
 from ..common.common import logger
-from ..common.common import utils
+from ..common import utils, apex_utils
 from archai.datasets.dataset_provider import DatasetProvider, get_provider_type
 from ..common.config import Config
 from .limit_dataset import LimitDataset, DatasetLike

@@ -25,6 +25,9 @@ from .distributed_stratified_sampler import DistributedStratifiedSampler

 def get_data(conf_loader:Config)\
         -> Tuple[Optional[DataLoader], Optional[DataLoader], Optional[DataLoader]]:
+
+    logger.pushd('data')
+
     # region conf vars
     # dataset
     conf_data = conf_loader['dataset']

@@ -40,18 +43,24 @@ def get_data(conf_loader:Config)\
     load_test = conf_loader['load_test']
     test_batch = conf_loader['test_batch']
     test_workers = conf_loader['test_workers']
+    conf_apex = conf_loader['apex']
     # endregion

     ds_provider = create_dataset_provider(conf_data)

+    apex = apex_utils.ApexUtils(conf_apex, logger)
+
     train_dl, val_dl, test_dl = get_dataloaders(ds_provider,
         load_train=load_train, train_batch_size=train_batch,
         load_test=load_test, test_batch_size=test_batch,
         aug=aug, cutout=cutout, val_ratio=val_ratio, val_fold=val_fold,
         train_workers=train_workers, test_workers=test_workers,
-        max_batches=max_batches)
+        max_batches=max_batches, apex=apex)

     assert train_dl is not None

+    logger.popd()
+
     return train_dl, val_dl, test_dl
+
 def create_dataset_provider(conf_data:Config)->DatasetProvider:

@@ -67,20 +76,25 @@ def create_dataset_provider(conf_data:Config)->DatasetProvider:
 def get_dataloaders(ds_provider:DatasetProvider,
     load_train:bool, train_batch_size:int,
     load_test:bool, test_batch_size:int,
-    aug, cutout:int, val_ratio:float, val_fold=0,
-    train_workers:Optional[int]=None, test_workers:Optional[int]=None,
+    aug, cutout:int, val_ratio:float, apex:apex_utils.ApexUtils,
+    val_fold=0, train_workers:Optional[int]=None, test_workers:Optional[int]=None,
     target_lb=-1, max_batches:int=-1) \
         -> Tuple[Optional[DataLoader], Optional[DataLoader], Optional[DataLoader]]:

     # if debugging in vscode, workers > 0 gets termination
+    default_workers = 4
     if utils.is_debugging():
         train_workers = test_workers = 0
         logger.warn({'debugger': True})
     if train_workers is None:
-        train_workers = 4 # following NVidia DeepLearningExamples
+        train_workers = default_workers # following NVidia DeepLearningExamples
     if test_workers is None:
-        test_workers = 4
-    logger.info({'train_workers': train_workers, 'test_workers':test_workers})
+        test_workers = default_workers
+
+    train_workers = round((1-val_ratio)*train_workers)
+    val_workers = round(val_ratio*train_workers)
+    logger.info({'train_workers': train_workers, 'val_workers': val_workers,
+        'test_workers':test_workers})

     transform_train, transform_test = ds_provider.get_transforms()
     add_named_augs(transform_train, aug, cutout)

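Note on the hunk above: after the defaults are applied, the new code splits the worker budget between the train and validation loaders, and val_workers is computed from the already-reduced train_workers. A small illustrative calculation (not part of the commit):

# Worker-split arithmetic exactly as written above, for a few val_ratio values.
default_workers = 4

for val_ratio in (0.0, 0.1, 0.5):
    train_workers = default_workers
    train_workers = round((1 - val_ratio) * train_workers)
    val_workers = round(val_ratio * train_workers)  # uses the reduced train_workers
    print(val_ratio, train_workers, val_workers)    # -> 0.0: 4/0, 0.1: 4/0, 0.5: 2/1
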
@@ -103,27 +117,40 @@ def get_dataloaders(ds_provider:DatasetProvider,
     if trainset:
         # sample validation set from trainset if cv_ratio > 0
         train_sampler, valid_sampler = _get_sampler(trainset, val_ratio=val_ratio,
-            shuffle=True,
+            shuffle=True, apex=apex,
             max_items=max_train_fold)
+        logger.info({'train_sampler_world_size':train_sampler.world_size,
+            'train_sampler_rank':train_sampler.rank,
+            'train_sampler_len': len(train_sampler)})
+        if valid_sampler:
+            logger.info({'valid_sampler_world_size':valid_sampler.world_size,
+                'valid_sampler_rank':valid_sampler.rank,
+                'valid_sampler_len': len(valid_sampler)
+                })
+
         # shuffle is performed by sampler at each epoch
         trainloader = DataLoader(trainset,
             batch_size=train_batch_size, shuffle=False,
-            num_workers=round((1-val_ratio)*train_workers),
+            num_workers=train_workers,
             pin_memory=True,
             sampler=train_sampler, drop_last=False) # TODO: original paper has this True

         if val_ratio > 0.0:
             validloader = DataLoader(trainset,
                 batch_size=train_batch_size, shuffle=False,
-                num_workers=round(val_ratio*train_workers), # if val_ratio = 0.5, then both sets re same
-                pin_memory=True, #TODO: set n_workers per ratio?
+                num_workers=val_workers,
+                pin_memory=True,
                 sampler=valid_sampler, drop_last=False)
         # else validloader is left as None
     if testset:
-        test_sampler, _ = _get_sampler(testset, val_ratio=0.0,
-            shuffle=False,
+        test_sampler, test_val_sampler = _get_sampler(testset, val_ratio=None,
+            shuffle=False, apex=apex,
             max_items=max_test_fold)
+        logger.info({'test_sampler_world_size':test_sampler.world_size,
+            'test_sampler_rank':test_sampler.rank,
+            'test_sampler_len': len(test_sampler)})
+        assert test_val_sampler is None

         testloader = DataLoader(testset,
             batch_size=test_batch_size, shuffle=False,
             num_workers=test_workers,

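As the "# shuffle is performed by sampler at each epoch" comment above says, the DataLoaders keep shuffle=False and delegate ordering to the sampler; PyTorch requires this, since DataLoader rejects shuffle=True together with an explicit sampler. A generic, self-contained sketch of the pattern using stock PyTorch (torch's own DistributedSampler, not the archai one; all values are made up):

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

ds = TensorDataset(torch.arange(8).float().unsqueeze(1), torch.arange(8) % 2)
# num_replicas/rank given explicitly so no process group is needed for the demo
sampler = DistributedSampler(ds, num_replicas=1, rank=0, shuffle=True)

# shuffle stays False: the sampler owns the permutation
loader = DataLoader(ds, batch_size=4, shuffle=False, sampler=sampler, drop_last=False)

for epoch in range(2):
    sampler.set_epoch(epoch)      # re-seed the per-epoch permutation
    for x, y in loader:
        pass
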
@@ -169,17 +196,21 @@ def _get_datasets(ds_provider:DatasetProvider, load_train:bool, load_test:bool,

 # target_lb allows to filter dataset for a specific class, not used
 def _get_sampler(dataset:Dataset, val_ratio:Optional[float], shuffle:bool,
-                 max_items:Optional[int])->Tuple[Sampler, Optional[Sampler]]:
+                 max_items:Optional[int], apex:apex_utils.ApexUtils)\
+                 ->Tuple[DistributedStratifiedSampler, Optional[DistributedStratifiedSampler]]:
+
+    world_size, global_rank = apex.world_size, apex.global_rank
+
     # we cannot not shuffle just for train or just val because of in distributed mode both must come from same shrad
     train_sampler = DistributedStratifiedSampler(dataset,
         val_ratio=val_ratio, is_val=False, shuffle=shuffle,
-        max_items=max_items)
+        max_items=max_items, world_size=world_size, rank=global_rank)

     valid_sampler = DistributedStratifiedSampler(dataset,
         val_ratio=val_ratio, is_val=True, shuffle=shuffle,
-        max_items=max_items) \
+        max_items=max_items, world_size=world_size, rank=global_rank) \
         if val_ratio is not None else None

     return train_sampler, valid_sampler

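A minimal sketch of the reworked _get_sampler wiring, assuming only what the diff shows: the sampler needs a dataset with a targets attribute, and world_size/rank now come from an apex-like object (here plain integers stand in for apex.world_size and apex.global_rank). The toy dataset below is hypothetical.

import numpy as np
from torch.utils.data import Dataset
from archai.datasets.distributed_stratified_sampler import DistributedStratifiedSampler

class ToyDataset(Dataset):
    # the sampler asserts that `targets` exists (see its __init__ below)
    def __init__(self, n=100, n_classes=4):
        self.targets = list(np.arange(n) % n_classes)
    def __len__(self):
        return len(self.targets)
    def __getitem__(self, i):
        return i, self.targets[i]

dataset = ToyDataset()
world_size, global_rank = 1, 0   # stand-ins for apex.world_size, apex.global_rank

train_sampler = DistributedStratifiedSampler(dataset,
    val_ratio=0.1, is_val=False, shuffle=True,
    max_items=None, world_size=world_size, rank=global_rank)
valid_sampler = DistributedStratifiedSampler(dataset,
    val_ratio=0.1, is_val=True, shuffle=True,
    max_items=None, world_size=world_size, rank=global_rank)
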
@@ -10,9 +10,9 @@ import numpy as np
 from sklearn.model_selection import StratifiedKFold, StratifiedShuffleSplit

 class DistributedStratifiedSampler(Sampler):
-    def __init__(self, dataset:Dataset, num_replicas:Optional[int]=None,
+    def __init__(self, dataset:Dataset, world_size:Optional[int]=None,
                  rank:Optional[int]=None, shuffle=True,
-                 val_ratio=0.0, is_val=False, auto_epoch=True,
+                 val_ratio:Optional[float]=0.0, is_val=False, auto_epoch=True,
                  max_items:Optional[int]=None):
         """Performs stratified sampling of dataset for each replica in the distributed as well as non-distributed setting. If validation split is needed then yet another stratified sampling within replica's split is performed to further obtain the train/validation splits.

@@ -26,7 +26,7 @@ class DistributedStratifiedSampler(Sampler):
         dataset -- PyTorch dataset like object

         Keyword Arguments:
-            num_replicas -- Total number of replicas running in distributed setting, if None then auto-detect, 1 for non distributed setting (default: {None})
+            world_size -- Total number of replicas running in distributed setting, if None then auto-detect, 1 for non distributed setting (default: {None})
             rank -- Global rank of this replica, if None then auto-detect, 0 for non distributed setting (default: {None})
             shuffle {bool} -- If True then suffle at every epoch (default: {True})
             val_ratio {float} -- If you want to create validation split then set to > 0 (default: {0.0})

@@ -39,23 +39,25 @@ class DistributedStratifiedSampler(Sampler):
         # cifar10 amd DatasetFolder has this attribute, for others it may be easy to add from outside
         assert hasattr(dataset, 'targets') and dataset.targets is not None, 'dataset needs to have targets attribute to work with this sampler'

-        if num_replicas is None:
+        if world_size is None:
             if dist.is_available() and dist.is_initialized():
-                num_replicas = dist.get_world_size()
+                world_size = dist.get_world_size()
             else:
-                num_replicas = 1
+                world_size = 1
         if rank is None:
             if dist.is_available() and dist.is_initialized():
                 rank = dist.get_rank()
             else:
                 rank = 0
+        if val_ratio is None:
+            val_ratio = 0.0

-        assert num_replicas >= 1
-        assert rank >= 0 and rank < num_replicas
+        assert world_size >= 1
+        assert rank >= 0 and rank < world_size
         assert val_ratio < 1.0 and val_ratio >= 0.0

         self.dataset = dataset
-        self.num_replicas = num_replicas
+        self.world_size = world_size
         self.rank = rank
         self.epoch = -1
         self.auto_epoch = auto_epoch

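The auto-detection above follows the usual torch.distributed fallback (explicit argument, then an initialized process group, else single-process defaults), and the new "val_ratio is None" branch is where the commit's "allow None val_ratio" lands. A standalone restatement of the detection helper (illustrative, not the sampler's own code):

import torch.distributed as dist

def detect_world_size_and_rank(world_size=None, rank=None):
    # explicit args win; otherwise ask an initialized process group; else 1/0
    if world_size is None:
        world_size = dist.get_world_size() if dist.is_available() and dist.is_initialized() else 1
    if rank is None:
        rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
    return world_size, rank

print(detect_world_size_and_rank())   # (1, 0) in a plain single-process run
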
@@ -67,8 +69,8 @@ class DistributedStratifiedSampler(Sampler):
         self.is_val = is_val

         # computing duplications we needs
-        self.replica_len = self.replica_len_full = int(math.ceil(float(self.data_len)/self.num_replicas))
-        self.total_size = self.replica_len_full * self.num_replicas
+        self.replica_len = self.replica_len_full = int(math.ceil(float(self.data_len)/self.world_size))
+        self.total_size = self.replica_len_full * self.world_size
         assert self.total_size >= self.data_len

         if self.max_items is not None:

@@ -104,9 +106,9 @@ class DistributedStratifiedSampler(Sampler):
     def _replica_fold(self, indices:np.ndarray, targets:np.ndarray)\
             ->Tuple[np.ndarray, np.ndarray]:

-        if self.num_replicas > 1:
+        if self.world_size > 1:
             replica_fold_idxs = None
-            rfolder = StratifiedKFold(n_splits=self.num_replicas, shuffle=False)
+            rfolder = StratifiedKFold(n_splits=self.world_size, shuffle=False)
             folds = rfolder.split(indices, targets)
             for _ in range(self.rank + 1):
                 other_fold_idxs, replica_fold_idxs = next(folds)

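The per-replica split above relies on scikit-learn's StratifiedKFold: with n_splits=world_size and shuffle=False, each rank walks the fold iterator rank + 1 times and keeps the second (held-out) side of its fold. A toy, self-contained illustration of the same selection (values are made up):

import numpy as np
from sklearn.model_selection import StratifiedKFold

world_size, rank = 2, 1                 # pretend this process is replica 1 of 2
indices = np.arange(20)
targets = np.array([0, 1] * 10)         # two balanced classes

rfolder = StratifiedKFold(n_splits=world_size, shuffle=False)
folds = rfolder.split(indices, targets)
for _ in range(rank + 1):               # advance to this replica's fold
    other_fold_idxs, replica_fold_idxs = next(folds)

print(len(replica_fold_idxs), np.bincount(targets[replica_fold_idxs]))  # 10 [5 5]
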
@@ -116,7 +118,7 @@ class DistributedStratifiedSampler(Sampler):

             return indices[replica_fold_idxs], targets[replica_fold_idxs]
         else:
-            assert self.num_replicas == 1
+            assert self.world_size == 1
             return indices, targets

@@ -294,7 +294,9 @@ class Search:
             ])

     def get_data(self, conf_loader:Config)->Tuple[Optional[DataLoader], Optional[DataLoader]]:
         # first get from cache
         train_ds, val_ds = self._data_cache.get(id(conf_loader), (None, None))
         # if not found in cache then create
         if train_ds is None:
             train_ds, val_ds, _ = data.get_data(conf_loader)
             self._data_cache[id(conf_loader)] = (train_ds, val_ds)

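Search.get_data above memoizes the loaders per config section, keyed by id(conf_loader); a minimal standalone sketch of that caching pattern (names below are hypothetical, and note that id() keys are only meaningful while the original config object stays alive):

_data_cache = {}

def cached_get_data(conf_loader, create_fn):
    # create_fn is expected to return (train_dl, val_dl, test_dl), like data.get_data
    train_dl, val_dl = _data_cache.get(id(conf_loader), (None, None))
    if train_dl is None:
        train_dl, val_dl, _ = create_fn(conf_loader)
        _data_cache[id(conf_loader)] = (train_dl, val_dl)
    return train_dl, val_dl
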
@@ -21,7 +21,7 @@ common:
   # "ray start --head --redis-port=6379"
   redis: null
   apex: # this is overriden in search and eval individually
-    enabled: False # global switch to disable anything apex
+    enabled: False # global switch to disable everything apex
     distributed_enabled: True # enable/disable distributed mode
     mixed_prec_enabled: True # switch to disable amp mixed precision
     gpus: '' # use GPU IDs specified here (comma separated), if '' then use all GPUs

@@ -75,6 +75,8 @@ nas:
       n_cells: 20 # number of cells
       aux_weight: 0.4 # weight for loss from auxiliary towers in test time arch
     loader:
+      apex:
+        _copy: '../../trainer/apex'
       aug: '' # additional augmentations to use
       cutout: 16 # cutout length, use cutout augmentation when > 0
       load_train: True # load train split of dataset

@@ -91,7 +93,6 @@ nas:
     trainer:
       apex:
         _copy: '/common/apex'
-        enabled: False
       aux_weight: '_copy: /nas/eval/model_desc/aux_weight'
       drop_path_prob: 0.2 # probability that given edge will be dropped
       grad_clip: 5.0 # grads above this value is clipped

@@ -135,8 +136,6 @@ nas:
     seed_train:
       trainer:
         _copy: '/nas/eval/trainer'
-        apex:
-          enabled: False
         title: 'seed_train'
         epochs: 0 # number of epochs model will be trained before search
         aux_weight: 0.0

@@ -148,8 +147,6 @@ nas:
     post_train:
       trainer:
         _copy: '/nas/eval/trainer'
-        apex:
-          enabled: False
         title: 'post_train'
         epochs: 0 # number of epochs model will be trained after search
         aux_weight: 0.0

@@ -186,6 +183,8 @@ nas:
       n_cells: 8 # number of cells
       aux_weight: 0.0 # weight for loss from auxiliary towers in test time arch
     loader:
+      apex:
+        _copy: '../../trainer/apex'
       aug: '' # additional augmentations to use
       cutout: 0 # cutout length, use cutout augmentation when > 0
       load_train: True # load train split of dataset

@@ -246,6 +245,8 @@ autoaug:
   num_search: 200
   num_result_per_cv: 10 # after conducting N trials, we will chose the results of top num_result_per_cv
   loader:
+    apex:
+      _copy: '/common/apex'
     aug: '' # additional augmentations to use
     cutout: 16 # cutout length, use cutout augmentation when > 0
     epochs: 50

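The YAML hunks above give every loader section an apex node filled in through the _copy directive (a reference to another node in the same config tree). A hedged sketch of consuming the resolved config, based only on accesses this commit itself makes (the Config construction in the new imagenet test and the keys read in get_data); anything beyond those is an assumption:

from archai.common.config import Config

# same constructor call as imagenet_test below; assumes the yaml files exist locally
conf = Config('confs/algos/darts.yaml;confs/datasets/imagenet.yaml')
conf_loader = conf['nas']['eval']['loader']   # same path the new test uses
conf_apex = conf_loader['apex']               # present once the `_copy: '../../trainer/apex'` node resolves
print(conf_loader['cutout'], conf_apex)
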
@@ -1,12 +1,17 @@
-from archai.datasets.distributed_stratified_sampler import DistributedStratifiedSampler
-import torch
-from torch.utils.data import Dataset
 from archai.common.config import Config
 import numpy as np
 import math
 import time
 from collections import Counter
 import random
+
+import torch
+from torch.utils.data import Dataset
+
+from archai.datasets.distributed_stratified_sampler import DistributedStratifiedSampler
+from archai.datasets import data
+from archai.common import common
+
 class ListDataset(Dataset):
     def __init__(self, x, y, transform=None):
         self.x = x

@@ -82,5 +87,11 @@ def test_combinations():
     elapsed = time.time()-st
     print('elapsed', elapsed, 'combs', combs)

+def imagenet_test():
+    conf = Config('confs/algos/darts.yaml;confs/datasets/imagenet.yaml',)
+    conf_loader = conf['nas']['eval']['loader']
+    dl_train, *_ = data.get_data(conf_loader)
+
+
 _dist_no_val(1, 100, val_ratio=0.1)
 test_combinations()