From 7a44b9fe3785a957e5692bd605169ae9270c6c5d Mon Sep 17 00:00:00 2001
From: Shital Shah
Date: Thu, 23 Apr 2020 15:54:14 -0700
Subject: [PATCH] Fix bugs in dist stratified sampler, apex_install

Make DistributedStratifiedSampler treat max_items=0 as a real limit
(the old truthy check `if self.max_items:` silently disabled trimming
for 0; negative values now map to None) and initialize replica_len so
it is defined even when no limit is set. Also move apex_install.sh
from tools/ to scripts/, run it from dist_main.sh before launching,
and enable apex for the eval trainer.
---
 archai/datasets/distributed_stratified_sampler.py | 10 +++++-----
 {tools => scripts}/apex_install.sh                |  0
 scripts/dist_main.sh                              |  4 +++-
 3 files changed, 8 insertions(+), 6 deletions(-)
 rename {tools => scripts}/apex_install.sh (100%)
 mode change 100755 => 100644

diff --git a/archai/datasets/distributed_stratified_sampler.py b/archai/datasets/distributed_stratified_sampler.py
index 900a6857..2f88a8cf 100644
--- a/archai/datasets/distributed_stratified_sampler.py
+++ b/archai/datasets/distributed_stratified_sampler.py
@@ -32,7 +32,7 @@ class DistributedStratifiedSampler(Sampler):
             val_ratio {float} -- If you want to create validation split then set to > 0 (default: {0.0})
             is_val {bool} -- If True then validation split is returned set to val_ratio otherwise main split is returned (default: {False})
             auto_epoch {bool} -- if True then automatically count epoch for each new iteration eliminating the need to call set_epoch() in distributed setting (default: {True})
-            max_items -- if not None then dataset will be trimmed to these many items for each replica (useful to test on smaller dataset)
+            max_items -- if >= 0 then dataset will be trimmed to these many items for each replica (useful to test on smaller dataset)
         """
 
 
@@ -61,17 +61,17 @@ class DistributedStratifiedSampler(Sampler):
         self.auto_epoch = auto_epoch
         self.shuffle = shuffle
         self.data_len = len(self.dataset)
-        self.max_items = max_items
+        self.max_items = max_items if max_items is not None and max_items >= 0 else None
         assert self.data_len == len(dataset.targets)
         self.val_ratio = val_ratio
         self.is_val = is_val
 
         # computing duplications we needs
-        self.replica_len_full = int(math.ceil(float(self.data_len)/self.num_replicas))
+        self.replica_len = self.replica_len_full = int(math.ceil(float(self.data_len)/self.num_replicas))
         self.total_size = self.replica_len_full * self.num_replicas
         assert self.total_size >= self.data_len
 
-        if self.max_items:
+        if self.max_items is not None:
             self.replica_len = min(self.replica_len_full, self.max_items)
 
         self.main_split_len = int(math.floor(self.replica_len*(1-val_ratio)))
@@ -144,7 +144,7 @@ class DistributedStratifiedSampler(Sampler):
 
     def _limit(self, indices:np.ndarray, targets:np.ndarray, max_items:Optional[int])\
         ->Tuple[np.ndarray, np.ndarray]:
-        if max_items:
+        if max_items is not None:
             return self._split(indices, targets, max_items, True)
         return indices, targets
 
diff --git a/tools/apex_install.sh b/scripts/apex_install.sh
old mode 100755
new mode 100644
similarity index 100%
rename from tools/apex_install.sh
rename to scripts/apex_install.sh
diff --git a/scripts/dist_main.sh b/scripts/dist_main.sh
index 723a55ff..25beb266 100644
--- a/scripts/dist_main.sh
+++ b/scripts/dist_main.sh
@@ -2,6 +2,8 @@
 #fail if any errors
 set -e
 
+bash "${BASH_SOURCE%/*}/apex_install.sh"
+
 nvidia-smi --list-gpus
 gpu_count=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
 
@@ -13,4 +15,4 @@ echo "*****************************************"
 
 set -e -o xtrace
 
-python -m torch.distributed.launch --nproc_per_node=$gpu_count scripts/main.py $*
\ No newline at end of file
+python -m torch.distributed.launch --nproc_per_node=$gpu_count scripts/main.py --nas.eval.trainer.apex.enabled True $*
\ No newline at end of file
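
A minimal, self-contained sketch (hypothetical helper names, not code from
the repo) of the truthiness bug the sampler hunks fix: `if max_items:`
treats a requested limit of 0 the same as None, so trimming was silently
skipped; the patch narrows the check to `is not None` and maps negative
values to None up front, so 0 now means "an empty replica" rather than
"no limit".

    from typing import Optional

    def limit_len_buggy(replica_len: int, max_items: Optional[int]) -> int:
        if max_items:                       # bug: 0 is falsy, so the limit is dropped
            return min(replica_len, max_items)
        return replica_len

    def limit_len_fixed(replica_len: int, max_items: Optional[int]) -> int:
        if max_items is not None:           # only None disables the limit
            return min(replica_len, max_items)
        return replica_len

    assert limit_len_buggy(50, 0) == 50     # old behavior: limit of 0 ignored
    assert limit_len_fixed(50, 0) == 0      # fixed: limit of 0 honored
    assert limit_len_fixed(50, None) == 50  # None still means "no limit"

The same pattern applies to the `_limit` helper, which previously skipped
`_split` whenever max_items was falsy rather than only when it was None.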