# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import math
import random
import time
from collections import Counter

import numpy as np
from torch.utils.data import Dataset

from archai.common.config import Config
from archai.datasets import data
from archai.datasets.distributed_stratified_sampler import DistributedStratifiedSampler

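# The checks below exercise DistributedStratifiedSampler along three axes:
# even sharding of indices across replicas, per-label stratification of each
# split, and disjointness of the train and validation indices.
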
class ListDataset(Dataset):
    """Minimal in-memory dataset serving (x, y) pairs for the sampler tests."""

    def __init__(self, x, y, transform=None):
        self.x = x
        # mirror torchvision datasets by exposing `targets`; the checks below read it
        self.targets = self.y = np.array(y)
        self.transform = transform

    def __getitem__(self, index):
        return self.x[index], self.y[index]

    def __len__(self):
        return len(self.x)

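# Worked example of the label construction used throughout the tests below:
# with data_len=5 and labels_len=2, np.repeat([0, 1], math.ceil(5/2))[:5]
# yields [0, 0, 0, 1, 1] before shuffling, i.e. labels spread as evenly as
# the dataset size allows.
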
def _dist_no_val(rep_count: int, data_len=1000, labels_len=2, val_ratio=0.0):
    # random data with labels spread (almost) evenly across the dataset
    x = np.random.randint(-data_len, data_len, data_len)
    labels = np.array(range(labels_len))
    y = np.repeat(labels, math.ceil(float(data_len)/labels_len))[:data_len]
    np.random.shuffle(y)
    dataset = ListDataset(x, y)

    # build one train sampler and one val sampler per replica (shard)
    train_samplers, val_samplers = [], []
    for i in range(rep_count):
        train_samplers.append(DistributedStratifiedSampler(dataset,
                                                           num_replicas=rep_count,
                                                           rank=i,
                                                           val_ratio=val_ratio,
                                                           is_val=False))
        val_samplers.append(DistributedStratifiedSampler(dataset,
                                                         num_replicas=rep_count,
                                                         rank=i,
                                                         val_ratio=val_ratio,
                                                         is_val=True))
    tl = [list(iter(s)) for s in train_samplers]
    vl = [list(iter(s)) for s in val_samplers]

    combined = [tli + vli for tli, vli in zip(tl, vl)]  # combine train and val per shard
    all_len = sum(len(li) for li in combined)
    u = set(i for li in combined for i in li)

    # verify stratification: every label appears in both splits of every shard
    # with near-equal counts (assumes the val split is large enough to hold
    # every label, which the caller guarantees)
    for vli, tli in zip(vl, tl):
        vlic = Counter(dataset.targets[vli])
        assert len(vlic.keys()) == labels_len
        assert max(vlic.values()) - min(vlic.values()) <= 2
        tlic = Counter(dataset.targets[tli])
        assert len(tlic.keys()) == labels_len
        assert max(tlic.values()) - min(tlic.values()) <= 2

    # below means all indices are equally divided between shards
    assert len(set(len(li) for li in combined)) == 1  # all shards equal
    assert all(len(li) >= len(dataset)/rep_count for li in combined)
    assert all(len(li) <= len(dataset)/rep_count + 1 for li in combined)
    # every index appears at least once across shards
    assert min(u) == 0
    assert max(u) == len(x) - 1
    assert len(u) == len(x)
    # the val fraction of each shard brackets the requested ratio
    assert all(float(len(vli))/(len(vli)+len(tli)) >= val_ratio for vli, tli in zip(vl, tl))
    assert all((len(vli)-1.0)/(len(vli)+len(tli)) <= val_ratio for vli, tli in zip(vl, tl))
    # train and val never share an index within a shard
    assert all(len(set(vli).union(tli)) == len(vli+tli) for vli, tli in zip(vl, tl))
    assert all_len <= math.ceil(len(x)/rep_count)*rep_count

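# Worked example of the bounds asserted above: with data_len=10 and
# rep_count=3, each shard must hold exactly ceil(10/3)=4 indices, so the
# combined length is 12 <= ceil(10/3)*3; two indices are repeated as padding
# while all 10 distinct indices still appear at least once.
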
def test_combinations():
    st = time.time()
    labels_len = 2
    combs = 0
    random.seed(0)
    # sweep over dataset sizes, shard counts and val ratios
    for data_len in (100, 1001, 17777):
        max_rep = int(math.sqrt(data_len)*3)
        for rep_count in range(1, max_rep, max(1, max_rep//17)):
            for val_num in range(0, random.randint(0, 5)):
                combs += 1
                val_ratio = val_num/11.0  # prime denominator yields non-round ratios
                # only run when the val split is large enough to hold every label
                if math.floor(val_ratio*data_len/rep_count) >= labels_len:
                    _dist_no_val(rep_count=rep_count, val_ratio=val_ratio,
                                 data_len=data_len, labels_len=labels_len)
    elapsed = time.time()-st
    print('elapsed', elapsed, 'combs', combs)

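# Note: only test_combinations carries the test_ prefix, so it is the one
# check pytest collects automatically; exclusion_test and _dist_no_val are
# invoked from the __main__ block at the end, and imagenet_test is left as
# a manual smoke test.
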
def imagenet_test():
    # manual smoke test: requires the ImageNet config and data to be present,
    # so it is not invoked by default below
    conf = Config('confs/algos/darts.yaml;confs/datasets/imagenet.yaml')
    conf_loader = conf['nas']['eval']['loader']
    data_loaders = data.get_data(conf_loader)
    assert data_loaders is not None

def exclusion_test(data_len=32, labels_len=2, val_ratio=0.5):
    # deterministic x values so train/val overlap is easy to detect
    x = np.array(range(data_len))
    labels = np.array(range(labels_len))
    y = np.repeat(labels, math.ceil(float(data_len)/labels_len))[:data_len]
    np.random.shuffle(y)
    dataset = ListDataset(x, y)

    train_sampler = DistributedStratifiedSampler(dataset,
                                                 val_ratio=val_ratio, is_val=False, shuffle=True,
                                                 max_items=-1, world_size=1, rank=0)
    valid_sampler = DistributedStratifiedSampler(dataset,
                                                 val_ratio=val_ratio, is_val=True, shuffle=True,
                                                 max_items=-1, world_size=1, rank=0)
    tidx = list(train_sampler)
    vidx = list(valid_sampler)

    # the default 32-item/0.5-ratio split yields exactly 16 train and 16 val indices
    assert len(tidx) == len(vidx) == 16
    # train and val index sets must be disjoint
    assert all(ti not in vidx for ti in tidx)

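# Usage sketch (illustrative only, not executed here): in training code,
# samplers like the ones above are typically handed to a standard torch
# DataLoader:
#
#   from torch.utils.data import DataLoader
#   train_dl = DataLoader(dataset, batch_size=8, sampler=train_sampler)
#   valid_dl = DataLoader(dataset, batch_size=8, sampler=valid_sampler)
#
# (when a sampler is supplied, DataLoader's own shuffle flag must stay False)
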
if __name__ == '__main__':
    exclusion_test()
    _dist_no_val(1, 100, val_ratio=0.1)
    test_combinations()