зеркало из https://github.com/microsoft/archai.git
Fix bugs in dist stratified sampler, apex_install
This commit is contained in:
Родитель
5a4ef14edb
Коммит
7a44b9fe37
|
@ -32,7 +32,7 @@ class DistributedStratifiedSampler(Sampler):
|
||||||
val_ratio {float} -- If you want to create validation split then set to > 0 (default: {0.0})
|
val_ratio {float} -- If you want to create validation split then set to > 0 (default: {0.0})
|
||||||
is_val {bool} -- If True then validation split is returned set to val_ratio otherwise main split is returned (default: {False})
|
is_val {bool} -- If True then validation split is returned set to val_ratio otherwise main split is returned (default: {False})
|
||||||
auto_epoch {bool} -- if True then automatically count epoch for each new iteration eliminating the need to call set_epoch() in distributed setting (default: {True})
|
auto_epoch {bool} -- if True then automatically count epoch for each new iteration eliminating the need to call set_epoch() in distributed setting (default: {True})
|
||||||
max_items -- if not None then dataset will be trimmed to these many items for each replica (useful to test on smaller dataset)
|
max_items -- if >= 0 then dataset will be trimmed to these many items for each replica (useful to test on smaller dataset)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@ -61,17 +61,17 @@ class DistributedStratifiedSampler(Sampler):
|
||||||
self.auto_epoch = auto_epoch
|
self.auto_epoch = auto_epoch
|
||||||
self.shuffle = shuffle
|
self.shuffle = shuffle
|
||||||
self.data_len = len(self.dataset)
|
self.data_len = len(self.dataset)
|
||||||
self.max_items = max_items
|
self.max_items = max_items if max_items is not None and max_items >= 0 else None
|
||||||
assert self.data_len == len(dataset.targets)
|
assert self.data_len == len(dataset.targets)
|
||||||
self.val_ratio = val_ratio
|
self.val_ratio = val_ratio
|
||||||
self.is_val = is_val
|
self.is_val = is_val
|
||||||
|
|
||||||
# computing duplications we needs
|
# computing duplications we needs
|
||||||
self.replica_len_full = int(math.ceil(float(self.data_len)/self.num_replicas))
|
self.replica_len = self.replica_len_full = int(math.ceil(float(self.data_len)/self.num_replicas))
|
||||||
self.total_size = self.replica_len_full * self.num_replicas
|
self.total_size = self.replica_len_full * self.num_replicas
|
||||||
assert self.total_size >= self.data_len
|
assert self.total_size >= self.data_len
|
||||||
|
|
||||||
if self.max_items:
|
if self.max_items is not None:
|
||||||
self.replica_len = min(self.replica_len_full, self.max_items)
|
self.replica_len = min(self.replica_len_full, self.max_items)
|
||||||
|
|
||||||
self.main_split_len = int(math.floor(self.replica_len*(1-val_ratio)))
|
self.main_split_len = int(math.floor(self.replica_len*(1-val_ratio)))
|
||||||
|
@ -144,7 +144,7 @@ class DistributedStratifiedSampler(Sampler):
|
||||||
|
|
||||||
def _limit(self, indices:np.ndarray, targets:np.ndarray, max_items:Optional[int])\
|
def _limit(self, indices:np.ndarray, targets:np.ndarray, max_items:Optional[int])\
|
||||||
->Tuple[np.ndarray, np.ndarray]:
|
->Tuple[np.ndarray, np.ndarray]:
|
||||||
if max_items:
|
if max_items is not None:
|
||||||
return self._split(indices, targets, max_items, True)
|
return self._split(indices, targets, max_items, True)
|
||||||
return indices, targets
|
return indices, targets
|
||||||
|
|
||||||
|
|
|
@ -2,6 +2,8 @@
|
||||||
#fail if any errors
|
#fail if any errors
|
||||||
set -e
|
set -e
|
||||||
|
|
||||||
|
bash "${BASH_SOURCE%/*}/apex_install.sh"
|
||||||
|
|
||||||
nvidia-smi --list-gpus
|
nvidia-smi --list-gpus
|
||||||
|
|
||||||
gpu_count=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
|
gpu_count=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
|
||||||
|
@ -13,4 +15,4 @@ echo "*****************************************"
|
||||||
set -e -o xtrace
|
set -e -o xtrace
|
||||||
|
|
||||||
|
|
||||||
python -m torch.distributed.launch --nproc_per_node=$gpu_count scripts/main.py $*
|
python -m torch.distributed.launch --nproc_per_node=$gpu_count scripts/main.py --nas.eval.trainer.apex.enabled True $*
|
Загрузка…
Ссылка в новой задаче