Remove unnecessary code for loading ImageNet

This commit is contained in:
hadisalman 2020-05-13 01:46:37 +00:00
Родитель 6082675145
Коммит adee718328
1 изменённых файлов: 1 добавлений и 32 удалений

Просмотреть файл

@ -15,8 +15,6 @@ import torch
# make sure your val directory is preprocessed to look like the train directory, e.g. by running this script
# https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh
IMAGENET_LOC_ENV = "IMAGENET_DIR"
IMAGENET_ON_PHILLY_DIR = "IMAGENET_DIR_PHILLY"
IMAGENET_ON_AZURE_ENV = "IMAGENET_DIR_AZURE"
# list of all datasets
DATASETS = ["imagenet", "imagenet32", "cifar10"]
@ -25,10 +23,7 @@ DATASETS = ["imagenet", "imagenet32", "cifar10"]
def get_dataset(dataset: str, split: str) -> Dataset:
"""Return the dataset as a PyTorch Dataset object"""
if dataset == "imagenet":
if "PT_DATA_DIR" in os.environ or IMAGENET_ON_AZURE_ENV in os.environ: #running on Philly
return _imagenet_remote(split)
else:
return _imagenet(split)
return _imagenet(split)
elif dataset == "imagenet32":
return _imagenet32(split)
@ -84,32 +79,6 @@ def _cifar10(split: str) -> Dataset:
else:
raise Exception("Unknown split name.")
def _imagenet_remote(split: str) -> Dataset:
if IMAGENET_ON_AZURE_ENV in os.environ:
dir = os.environ[IMAGENET_ON_AZURE_ENV]
else:
dir = os.environ[IMAGENET_ON_PHILLY_DIR]
trainpath = os.path.join(dir, 'train.zip')
train_map = os.path.join(dir, 'train_map.txt')
valpath = os.path.join(dir, 'val.zip')
val_map = os.path.join(dir, 'val_map.txt')
if split == "train":
return ZipData(trainpath, train_map,
transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
]))
elif split == "test":
return ZipData(valpath, val_map,
transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
]))
def _imagenet(split: str) -> Dataset:
if not IMAGENET_LOC_ENV in os.environ: