From f2668bcf136a5fc5d5b95ed7f1155082428594dd Mon Sep 17 00:00:00 2001
From: Debadeepta Dey
Date: Thu, 2 Dec 2021 12:26:01 -0800
Subject: [PATCH] Added ninapro dataset provider.

---
 .vscode/launch.json                           |  2 +-
 archai/datasets/__init__.py                   |  3 +-
 archai/datasets/providers/ninapro_provider.py | 90 +++++++++++++++++++
 confs/datasets/ninapro.yaml                   |  8 ++
 4 files changed, 101 insertions(+), 2 deletions(-)
 create mode 100644 archai/datasets/providers/ninapro_provider.py
 create mode 100644 confs/datasets/ninapro.yaml

diff --git a/.vscode/launch.json b/.vscode/launch.json
index 052b2cd9..c4a4ad4c 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -220,7 +220,7 @@
             "request": "launch",
             "program": "${cwd}/scripts/main.py",
             "console": "integratedTerminal",
-            "args": ["--full", "--algos", "darts_space_constant_random_archs", "--datasets", "sphericalcifar100"]
+            "args": ["--full", "--algos", "darts_space_constant_random_archs", "--datasets", "ninapro"]
         },
         {
             "name": "Proxynas-Darts-Space-Full",
diff --git a/archai/datasets/__init__.py b/archai/datasets/__init__.py
index 317fc0ca..8087999a 100644
--- a/archai/datasets/__init__.py
+++ b/archai/datasets/__init__.py
@@ -14,4 +14,5 @@ from .providers.flower102_provider import Flower102Provider
 from .providers.imagenet16120_provider import ImageNet16120Provider
 from .providers.synthetic_cifar10_provider import SyntheticCifar10Provider
 from .providers.intel_image_provider import IntelImageProvider
-from .providers.spherical_cifar100_provider import SphericalCifar100Provider
\ No newline at end of file
+from .providers.spherical_cifar100_provider import SphericalCifar100Provider
+from .providers.ninapro_provider import NinaproProvider
\ No newline at end of file
diff --git a/archai/datasets/providers/ninapro_provider.py b/archai/datasets/providers/ninapro_provider.py
new file mode 100644
index 00000000..b4a980c0
--- /dev/null
+++ b/archai/datasets/providers/ninapro_provider.py
@@ -0,0 +1,90 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+from typing import List, Tuple, Union, Optional
+import os
+import gzip
+import pickle
+import numpy as np
+
+from overrides import overrides, EnforceOverrides
+import torch
+import torch.utils.data as data_utils
+from torch.utils.data.dataset import Dataset
+
+import torchvision
+from torchvision.transforms import transforms
+
+from archai.datasets.dataset_provider import DatasetProvider, register_dataset_provider, TrainTestDatasets
+from archai.common.config import Config
+from archai.common import utils
+
+
+def load_ninapro_data(path, train=True):
+    ''' Modified from
+    https://github.com/rtu715/NAS-Bench-360/blob/ba7ff6bd0762073d1ce49207b95245c5c742b567/backbone/data_utils/load_data.py#L396 '''
+
+    trainset = load_ninapro(path, 'train')
+    valset = load_ninapro(path, 'val')
+    testset = load_ninapro(path, 'test')
+
+    if train:
+        return trainset, valset, testset
+    else:
+        targets = torch.cat((trainset.targets, valset.targets))
+        trainset = data_utils.ConcatDataset([trainset, valset])
+        trainset.targets = targets # for compatibility with stratified sampler
+
+        return trainset, None, testset
+
+
+def load_ninapro(path, whichset):
+    ''' Modified from
+    https://github.com/rtu715/NAS-Bench-360/blob/ba7ff6bd0762073d1ce49207b95245c5c742b567/backbone/data_utils/load_data.py#L396 '''
+
+    data_str = 'ninapro_' + whichset + '.npy'
+    label_str = 'label_' + whichset + '.npy'
+
+    data = np.load(os.path.join(path, data_str),
+                   encoding="bytes", allow_pickle=True)
+    labels = np.load(os.path.join(path, label_str), encoding="bytes", allow_pickle=True)
+
+    data = np.transpose(data, (0, 2, 1))
+    data = data[:, None, :, :]
+    data = torch.from_numpy(data.astype(np.float32))
+    labels = torch.from_numpy(labels.astype(np.int64))
+
+    all_data = data_utils.TensorDataset(data, labels)
+    all_data.targets = labels # for compatibility with stratified data sampler
+    return all_data
+
+
+
+class NinaproProvider(DatasetProvider):
+    def __init__(self, conf_dataset:Config):
+        super().__init__(conf_dataset)
+        self._dataroot = utils.full_path(conf_dataset['dataroot'])
+
+    @overrides
+    def get_datasets(self, load_train:bool, load_test:bool,
+                     transform_train, transform_test)->TrainTestDatasets:
+        trainset, testset = None, None
+
+        path_to_data = os.path.join(self._dataroot, 'ninapro')
+
+        # load the dataset but without any validation split
+        trainset, _, testset = load_ninapro_data(path_to_data, train=False)
+
+        return trainset, testset
+
+    @overrides
+    def get_transforms(self)->tuple:
+        # return empty transforms since we have preprocessed data
+        train_transf = []
+        test_transf = []
+
+        train_transform = transforms.Compose(train_transf)
+        test_transform = transforms.Compose(test_transf)
+        return train_transform, test_transform
+
+register_dataset_provider('ninapro', NinaproProvider)
\ No newline at end of file
diff --git a/confs/datasets/ninapro.yaml b/confs/datasets/ninapro.yaml
new file mode 100644
index 00000000..e9771f9b
--- /dev/null
+++ b/confs/datasets/ninapro.yaml
@@ -0,0 +1,8 @@
+__include__: './dataroot.yaml' # default dataset settings are for cifar
+
+dataset:
+  name: 'ninapro'
+  n_classes: 18
+  channels: 1 # number of channels in image
+  max_batches: -1 # if >= 0 then only these many batches are generated (useful for debugging)
+  storage_name: 'ninapro' # name of folder or tar file to copy from cloud storage
\ No newline at end of file