Fix bugs in dist stratified sampler, apex_install

Shital Shah 2020-04-23 15:54:14 -07:00
Parent 5a4ef14edb
Commit 7a44b9fe37
3 changed files with 8 additions and 6 deletions


@@ -32,7 +32,7 @@ class DistributedStratifiedSampler(Sampler):
 val_ratio {float} -- If you want to create validation split then set to > 0 (default: {0.0})
 is_val {bool} -- If True then validation split is returned set to val_ratio otherwise main split is returned (default: {False})
 auto_epoch {bool} -- if True then automatically count epoch for each new iteration eliminating the need to call set_epoch() in distributed setting (default: {True})
-max_items -- if not None then dataset will be trimmed to these many items for each replica (useful to test on smaller dataset)
+max_items -- if >= 0 then dataset will be trimmed to these many items for each replica (useful to test on smaller dataset)
 """
@@ -61,17 +61,17 @@ class DistributedStratifiedSampler(Sampler):
 self.auto_epoch = auto_epoch
 self.shuffle = shuffle
 self.data_len = len(self.dataset)
-self.max_items = max_items
+self.max_items = max_items if max_items is not None and max_items >= 0 else None
 assert self.data_len == len(dataset.targets)
 self.val_ratio = val_ratio
 self.is_val = is_val
 # compute the duplications we need
-self.replica_len_full = int(math.ceil(float(self.data_len)/self.num_replicas))
+self.replica_len = self.replica_len_full = int(math.ceil(float(self.data_len)/self.num_replicas))
 self.total_size = self.replica_len_full * self.num_replicas
 assert self.total_size >= self.data_len
-if self.max_items:
+if self.max_items is not None:
 self.replica_len = min(self.replica_len_full, self.max_items)
 self.main_split_len = int(math.floor(self.replica_len*(1-val_ratio)))
@@ -144,7 +144,7 @@ class DistributedStratifiedSampler(Sampler):
 def _limit(self, indices:np.ndarray, targets:np.ndarray, max_items:Optional[int])\
     ->Tuple[np.ndarray, np.ndarray]:
-    if max_items:
+    if max_items is not None:
         return self._split(indices, targets, max_items, True)
     return indices, targets
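The recurring fix in this file is replacing truthiness checks with explicit `is not None` checks, so that `max_items=0` is honored instead of silently ignored. A minimal standalone sketch of that difference (illustrative only; `trim_buggy` and `trim_fixed` are hypothetical names, not project code):

```python
# Illustrative sketch of the truthiness bug fixed above (not project code).
# With a plain `if max_items:` check, max_items=0 ("trim to zero items") is
# indistinguishable from max_items=None ("don't trim at all").
from typing import List, Optional

def trim_buggy(items: List[int], max_items: Optional[int]) -> List[int]:
    if max_items:                  # 0 is falsy, so the limit is skipped
        return items[:max_items]
    return items

def trim_fixed(items: List[int], max_items: Optional[int]) -> List[int]:
    if max_items is not None:      # only None means "no limit"
        return items[:max_items]
    return items

data = list(range(5))
print(trim_buggy(data, 0))   # [0, 1, 2, 3, 4]  -- bug: limit of 0 ignored
print(trim_fixed(data, 0))   # []               -- limit of 0 respected
```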

tools/apex_install.sh → scripts/apex_install.sh (renamed; Executable file → Normal file; 0 lines changed)
Third changed file:

@@ -2,6 +2,8 @@
 #fail if any errors
 set -e
+bash "${BASH_SOURCE%/*}/apex_install.sh"
 nvidia-smi --list-gpus
 gpu_count=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
@@ -13,4 +15,4 @@ echo "*****************************************"
 set -e -o xtrace
-python -m torch.distributed.launch --nproc_per_node=$gpu_count scripts/main.py $*
+python -m torch.distributed.launch --nproc_per_node=$gpu_count scripts/main.py --nas.eval.trainer.apex.enabled True $*
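For reference, the GPU count that the script derives from `nvidia-smi ... | wc -l` (and then passes to `--nproc_per_node`) can be reproduced in Python. A minimal sketch, not part of this commit, assuming `nvidia-smi` is on PATH and PyTorch is installed:

```python
# Sketch (not part of this commit): two ways to obtain the GPU count that the
# shell script computes via `nvidia-smi --query-gpu=name --format=csv,noheader | wc -l`.
import subprocess

import torch

# 1) Same approach as the script: one GPU name per line, count the lines.
out = subprocess.run(
    ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"],
    capture_output=True, text=True, check=True,
)
gpu_count_smi = len([line for line in out.stdout.splitlines() if line.strip()])

# 2) PyTorch's own view of visible devices (respects CUDA_VISIBLE_DEVICES).
gpu_count_torch = torch.cuda.device_count()

print(gpu_count_smi, gpu_count_torch)
```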