# Mirror of https://github.com/microsoft/esvit.git (202 lines, 7.8 KiB, Python)
"""
Inspired by and partially adapted from
https://github.com/zhmiao/OpenLongTailRecognition-OLTR/blob/master/data/ClassAwareSampler.py
"""
# Standard library
import logging
import math
import random

# Third-party
import torch.distributed as dist
from torch.utils.data.sampler import Sampler
class CycleIter:
    """Endless iterator that cycles over *data*, optionally reshuffling each pass.

    The order within one pass is fixed; when ``shuffle`` is true the list is
    reshuffled (using the global ``random`` state) every time a full cycle
    completes — including right before the very first element is yielded.
    """

    def __init__(self, data, shuffle=True):
        self.data_list = list(data)
        self.length = len(self.data_list)
        # Start one step before the end so the first __next__ wraps to index 0
        # (and triggers the first shuffle when enabled).
        self.i = self.length - 1
        self.shuffle = shuffle

    def __iter__(self):
        return self

    def __next__(self):
        # Fix: an empty sequence must terminate cleanly instead of raising
        # IndexError on the first __next__ call.
        if self.length == 0:
            raise StopIteration

        self.i += 1

        if self.i == self.length:
            self.i = 0
            if self.shuffle:
                random.shuffle(self.data_list)

        return self.data_list[self.i]
class ClassAwareTargetSizeSampler(Sampler):
    """Distributed sampler drawing a class-balanced subset of ~``target_size`` samples.

    The dataset must expose ``num_classes()`` and ``get_class_boundaries()``,
    the latter returning one ``(low, high)`` index pair per class. Boundaries
    are treated as half-open intervals ``[low, high)`` — consistent with the
    ``range(low, high)`` draw below and with ClassAwareDistributedSampler.

    Args:
        dataset: dataset providing the class info described above.
        target_size: desired sample count; <= 0 means "use the whole dataset".
        num_replicas: number of distributed processes (defaults to world size).
        rank: rank of the current process (defaults to ``dist.get_rank()``).
        shuffle: shuffle within classes at construction and across epochs.
    """

    def __init__(self, dataset, target_size, num_replicas=None, rank=None, shuffle=True):
        super().__init__(dataset)
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()

        self.dataset = dataset
        assert len(dataset) >= target_size, "Dataset size %d is smaller than target size %d." % (
            len(dataset), target_size)

        if hasattr(self.dataset, "num_classes"):
            self.num_classes = self.dataset.num_classes()
        else:
            raise RuntimeError("Dataset does not provide the info of class number!")

        if hasattr(self.dataset, "get_class_boundaries"):
            self.class_boundaries = self.dataset.get_class_boundaries()
        else:
            raise RuntimeError("Dataset does not provide the info of class boundary!")
        assert len(self.class_boundaries) == self.num_classes, "Number of classes and boundaries not consistent!"

        if target_size > 0:
            # Try to evenly draw the same number of data samples from each class.
            self.num_samples_cls = int(math.ceil(target_size * 1.0 / self.num_classes))
            # Fix: boundaries are half-open [low, high) (see the range(low, high)
            # draw below and the other samplers in this module), so a class holds
            # high - low samples — the original counted high - low + 1, which
            # overestimated every class by one.
            class_sizes = [(high - low) for low, high in self.class_boundaries]
            self.target_size = sum(min(self.num_samples_cls, s) for s in class_sizes)

            # Increase the per-class quota until we reach the requested size.
            while self.target_size < target_size:
                self.num_samples_cls += 1
                self.target_size = sum(min(self.num_samples_cls, s) for s in class_sizes)
        else:
            # Sample all available data.
            self.target_size = len(dataset)
            self.num_samples_cls = max((high - low) for low, high in self.class_boundaries)

        self.epoch = 0
        self.rank = rank
        self.num_replicas = num_replicas
        self.num_samples = int(math.ceil(self.target_size * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle

        # Draw the per-class index pool once, deterministically for epoch 0
        # (every replica seeds the same way and gets the same pool).
        random.seed(self.epoch)
        self.indices = []
        for j in range(self.num_classes):
            lower_bound, upper_bound = self.class_boundaries[j]
            indices_ = list(range(lower_bound, upper_bound))
            if self.shuffle:
                random.shuffle(indices_)
            self.indices += indices_[:min(self.num_samples_cls, len(indices_))]

        # Add extra samples to make the pool evenly divisible across replicas.
        self.indices += self.indices[:(self.total_size - len(self.indices))]
        assert len(self.indices) == self.total_size
        logging.info("[ClassAwareTargetSizeSampler] sampled data size: %d" % self.total_size)

    def __iter__(self):
        # Deterministically shuffle based on epoch.
        # Fix: honor the shuffle flag — the original reshuffled unconditionally,
        # ignoring shuffle=False.
        if self.shuffle:
            random.seed(self.epoch)
            random.shuffle(self.indices)

        # Subsample: each replica takes its own strided slice of the pool.
        indices = self.indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        # Called by the training loop so each epoch reshuffles differently.
        self.epoch = epoch
class ClassAwareDistributedSampler(Sampler):
    """Distributed sampler that rebalances classes with a fixed per-class quota.

    Each epoch yields ``num_samples_cls`` indices for every class by cycling
    (with optional reshuffling) through each class's ``[low, high)`` index
    range, then splits the combined list evenly across replicas.
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, num_samples_cls=1000):
        super().__init__(dataset)
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()

        self.dataset = dataset
        # Guard clauses: the dataset must expose the class metadata we rely on.
        if not hasattr(self.dataset, "num_classes"):
            raise RuntimeError("Dataset does not provide the info of class number!")
        self.num_classes = self.dataset.num_classes()
        if not hasattr(self.dataset, "get_class_boundaries"):
            raise RuntimeError("Dataset does not provide the info of class boundary!")
        self.class_boundaries = self.dataset.get_class_boundaries()

        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples_cls = num_samples_cls
        # Per-replica share of the num_samples_cls * num_classes draw, rounded up.
        self.num_samples = int(math.ceil(self.num_samples_cls * self.num_classes * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle

    def __iter__(self):
        # Seed the global RNG so every replica draws the identical sequence
        # for a given epoch.
        random.seed(self.epoch)

        class_cycle = CycleIter(range(self.num_classes), shuffle=self.shuffle)
        sample_cycles = [CycleIter(range(low, high), shuffle=self.shuffle)
                         for low, high in self.class_boundaries]

        # Interleave classes: one draw per class per round, for
        # num_samples_cls rounds.
        indices = [next(sample_cycles[next(class_cycle)])
                   for _ in range(self.num_samples_cls * self.num_classes)]

        # Pad with leading elements so the list divides evenly across replicas.
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # Each replica takes its own strided slice.
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
class ClassAwareAverageSampler(ClassAwareDistributedSampler):
    """Class-aware sampler whose per-class quota is the mean class size."""

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
        # Guard clause: we need the class boundaries to compute the mean size.
        if not hasattr(dataset, "get_class_boundaries"):
            raise RuntimeError("Dataset does not provide the info of class boundary!")

        boundaries = dataset.get_class_boundaries()
        sizes = [high - low for low, high in boundaries]
        average_size = int(sum(sizes) * 1.0 / len(sizes))
        logging.info("[ClassAwareAverageSampler] Average class size: %d", average_size)

        super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle,
                         num_samples_cls=average_size)
class ClassAwareMedianSampler(ClassAwareDistributedSampler):
    """Class-aware sampler whose per-class quota is the median class size."""

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
        # Guard clause: we need the class boundaries to compute the median size.
        if not hasattr(dataset, "get_class_boundaries"):
            raise RuntimeError("Dataset does not provide the info of class boundary!")

        boundaries = dataset.get_class_boundaries()
        sizes = sorted(high - low for low, high in boundaries)
        # Upper median (index len // 2), matching the original index choice.
        median_size = sizes[len(sizes) // 2]
        logging.info("[ClassAwareMedianSampler] Median class size: %d", median_size)

        super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle,
                         num_samples_cls=median_size)