Remove unnecessary code for loading ImageNet
This commit is contained in:
Родитель
6082675145
Коммит
adee718328
|
@ -15,8 +15,6 @@ import torch
|
|||
# make sure your val directory is preprocessed to look like the train directory, e.g. by running this script
|
||||
# https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh
|
||||
IMAGENET_LOC_ENV = "IMAGENET_DIR"
|
||||
IMAGENET_ON_PHILLY_DIR = "IMAGENET_DIR_PHILLY"
|
||||
IMAGENET_ON_AZURE_ENV = "IMAGENET_DIR_AZURE"
|
||||
|
||||
# list of all datasets
|
||||
DATASETS = ["imagenet", "imagenet32", "cifar10"]
|
||||
|
@ -25,10 +23,7 @@ DATASETS = ["imagenet", "imagenet32", "cifar10"]
|
|||
def get_dataset(dataset: str, split: str) -> Dataset:
|
||||
"""Return the dataset as a PyTorch Dataset object"""
|
||||
if dataset == "imagenet":
|
||||
if "PT_DATA_DIR" in os.environ or IMAGENET_ON_AZURE_ENV in os.environ: #running on Philly
|
||||
return _imagenet_remote(split)
|
||||
else:
|
||||
return _imagenet(split)
|
||||
return _imagenet(split)
|
||||
|
||||
elif dataset == "imagenet32":
|
||||
return _imagenet32(split)
|
||||
|
@ -84,32 +79,6 @@ def _cifar10(split: str) -> Dataset:
|
|||
else:
|
||||
raise Exception("Unknown split name.")
|
||||
|
||||
def _imagenet_remote(split: str) -> Dataset:
|
||||
if IMAGENET_ON_AZURE_ENV in os.environ:
|
||||
dir = os.environ[IMAGENET_ON_AZURE_ENV]
|
||||
else:
|
||||
dir = os.environ[IMAGENET_ON_PHILLY_DIR]
|
||||
|
||||
trainpath = os.path.join(dir, 'train.zip')
|
||||
train_map = os.path.join(dir, 'train_map.txt')
|
||||
valpath = os.path.join(dir, 'val.zip')
|
||||
val_map = os.path.join(dir, 'val_map.txt')
|
||||
|
||||
if split == "train":
|
||||
return ZipData(trainpath, train_map,
|
||||
transforms.Compose([
|
||||
transforms.RandomResizedCrop(224),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.ToTensor(),
|
||||
]))
|
||||
elif split == "test":
|
||||
return ZipData(valpath, val_map,
|
||||
transforms.Compose([
|
||||
transforms.Resize(256),
|
||||
transforms.CenterCrop(224),
|
||||
transforms.ToTensor(),
|
||||
]))
|
||||
|
||||
|
||||
def _imagenet(split: str) -> Dataset:
|
||||
if not IMAGENET_LOC_ENV in os.environ:
|
||||
|
|
Загрузка…
Ссылка в новой задаче