Rwanda Field Boundary: radiant mlhub -> source cooperative (#2118)

This commit is contained in:
Adam J. Stewart 2024-07-10 17:40:23 +02:00 коммит произвёл GitHub
Родитель 9df08d0ff5
Коммит 61635cd084
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
55 изменённых файлов: 82 добавлений и 289 удалений

113
tests/data/rwanda_field_boundary/data.py Normal file → Executable file
Просмотреть файл

@ -3,99 +3,46 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import hashlib
import os
import shutil
import numpy as np
import rasterio
from rasterio.crs import CRS
from rasterio.transform import Affine
dates = ('2021_03', '2021_04', '2021_08', '2021_10', '2021_11', '2021_12')
all_bands = ('B01', 'B02', 'B03', 'B04')
SIZE = 32
NUM_SAMPLES = 5
DTYPE = np.uint16
NUM_SAMPLES = 1
np.random.seed(0)
profile = {
'driver': 'GTiff',
'dtype': DTYPE,
'width': SIZE,
'height': SIZE,
'count': 1,
'crs': CRS.from_epsg(3857),
'transform': Affine(
4.77731426716, 0.0, 3374518.037700199, 0.0, -4.77731426716, -168438.54642526805
),
}
Z = np.random.randint(np.iinfo(DTYPE).max, size=(SIZE, SIZE), dtype=DTYPE)
def create_mask(fn: str) -> None:
profile = {
'driver': 'GTiff',
'dtype': 'uint8',
'nodata': 0.0,
'width': SIZE,
'height': SIZE,
'count': 1,
'crs': 'epsg:3857',
'compress': 'lzw',
'predictor': 2,
'transform': rasterio.Affine(10.0, 0.0, 0.0, 0.0, -10.0, 0.0),
'blockysize': 32,
'tiled': False,
'interleave': 'band',
}
with rasterio.open(fn, 'w', **profile) as f:
f.write(np.random.randint(0, 2, size=(SIZE, SIZE), dtype=np.uint8), 1)
for sample in range(NUM_SAMPLES):
for split in ['train', 'test']:
for date in dates:
path = os.path.join('source', split, date)
os.makedirs(path, exist_ok=True)
for band in all_bands:
file = os.path.join(path, f'{sample:02}_{band}.tif')
with rasterio.open(file, 'w', **profile) as src:
src.write(Z, 1)
def create_img(fn: str) -> None:
profile = {
'driver': 'GTiff',
'dtype': 'uint16',
'nodata': 0.0,
'width': SIZE,
'height': SIZE,
'count': 1,
'crs': 'epsg:3857',
'compress': 'lzw',
'predictor': 2,
'blockysize': 16,
'transform': rasterio.Affine(10.0, 0.0, 0.0, 0.0, -10.0, 0.0),
'tiled': False,
'interleave': 'band',
}
with rasterio.open(fn, 'w', **profile) as f:
f.write(np.random.randint(0, 2, size=(SIZE, SIZE), dtype=np.uint16), 1)
if __name__ == '__main__':
# Train and test images
for split in ('train', 'test'):
for i in range(NUM_SAMPLES):
for date in dates:
directory = os.path.join(
f'nasa_rwanda_field_boundary_competition_source_{split}',
f'nasa_rwanda_field_boundary_competition_source_{split}_{i:02d}_{date}', # noqa: E501
)
os.makedirs(directory, exist_ok=True)
for band in all_bands:
create_img(os.path.join(directory, f'{band}.tif'))
# Create collections.json, this isn't used by the dataset but is checked to
# exist
with open(
f'nasa_rwanda_field_boundary_competition_source_{split}/collections.json',
'w',
) as f:
f.write('Not used')
# Train labels
for i in range(NUM_SAMPLES):
directory = os.path.join(
'nasa_rwanda_field_boundary_competition_labels_train',
f'nasa_rwanda_field_boundary_competition_labels_train_{i:02d}',
)
os.makedirs(directory, exist_ok=True)
create_mask(os.path.join(directory, 'raster_labels.tif'))
# Create directories and compute checksums
for filename in [
'nasa_rwanda_field_boundary_competition_source_train',
'nasa_rwanda_field_boundary_competition_source_test',
'nasa_rwanda_field_boundary_competition_labels_train',
]:
shutil.make_archive(filename, 'gztar', '.', filename)
# Compute checksums
with open(f'{filename}.tar.gz', 'rb') as f:
md5 = hashlib.md5(f.read()).hexdigest()
print(f'{filename}: {md5}')
path = os.path.join('labels', 'train')
os.makedirs(path, exist_ok=True)
file = os.path.join(path, f'{sample:02}.tif')
with rasterio.open(file, 'w', **profile) as src:
src.write(Z, 1)

Двоичные данные
tests/data/rwanda_field_boundary/labels/train/00.tif Normal file

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичные данные
tests/data/rwanda_field_boundary/source/test/2021_03/00_B01.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/rwanda_field_boundary/source/test/2021_03/00_B02.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/rwanda_field_boundary/source/test/2021_03/00_B03.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/rwanda_field_boundary/source/test/2021_03/00_B04.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/rwanda_field_boundary/source/test/2021_04/00_B01.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/rwanda_field_boundary/source/test/2021_04/00_B02.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/rwanda_field_boundary/source/test/2021_04/00_B03.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/rwanda_field_boundary/source/test/2021_04/00_B04.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/rwanda_field_boundary/source/test/2021_08/00_B01.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/rwanda_field_boundary/source/test/2021_08/00_B02.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/rwanda_field_boundary/source/test/2021_08/00_B03.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/rwanda_field_boundary/source/test/2021_08/00_B04.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/rwanda_field_boundary/source/test/2021_10/00_B01.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/rwanda_field_boundary/source/test/2021_10/00_B02.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/rwanda_field_boundary/source/test/2021_10/00_B03.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/rwanda_field_boundary/source/test/2021_10/00_B04.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/rwanda_field_boundary/source/test/2021_11/00_B01.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/rwanda_field_boundary/source/test/2021_11/00_B02.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/rwanda_field_boundary/source/test/2021_11/00_B03.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/rwanda_field_boundary/source/test/2021_11/00_B04.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/rwanda_field_boundary/source/test/2021_12/00_B01.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/rwanda_field_boundary/source/test/2021_12/00_B02.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/rwanda_field_boundary/source/test/2021_12/00_B03.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/rwanda_field_boundary/source/test/2021_12/00_B04.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/rwanda_field_boundary/source/train/2021_03/00_B01.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/rwanda_field_boundary/source/train/2021_03/00_B02.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/rwanda_field_boundary/source/train/2021_03/00_B03.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/rwanda_field_boundary/source/train/2021_03/00_B04.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/rwanda_field_boundary/source/train/2021_04/00_B01.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/rwanda_field_boundary/source/train/2021_04/00_B02.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/rwanda_field_boundary/source/train/2021_04/00_B03.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/rwanda_field_boundary/source/train/2021_04/00_B04.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/rwanda_field_boundary/source/train/2021_08/00_B01.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/rwanda_field_boundary/source/train/2021_08/00_B02.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/rwanda_field_boundary/source/train/2021_08/00_B03.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/rwanda_field_boundary/source/train/2021_08/00_B04.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/rwanda_field_boundary/source/train/2021_10/00_B01.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/rwanda_field_boundary/source/train/2021_10/00_B02.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/rwanda_field_boundary/source/train/2021_10/00_B03.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/rwanda_field_boundary/source/train/2021_10/00_B04.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/rwanda_field_boundary/source/train/2021_11/00_B01.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/rwanda_field_boundary/source/train/2021_11/00_B02.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/rwanda_field_boundary/source/train/2021_11/00_B03.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/rwanda_field_boundary/source/train/2021_11/00_B04.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/rwanda_field_boundary/source/train/2021_12/00_B01.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/rwanda_field_boundary/source/train/2021_12/00_B02.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/rwanda_field_boundary/source/train/2021_12/00_B03.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/rwanda_field_boundary/source/train/2021_12/00_B04.tif Normal file

Двоичный файл не отображается.

Просмотреть файл

@ -1,9 +1,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import glob
import os
import shutil
from pathlib import Path
import matplotlib.pyplot as plt
@ -19,45 +17,26 @@ from torchgeo.datasets import (
RGBBandsMissingError,
RwandaFieldBoundary,
)
class Collection:
def download(self, output_dir: str, **kwargs: str) -> None:
glob_path = os.path.join('tests', 'data', 'rwanda_field_boundary', '*.tar.gz')
for tarball in glob.iglob(glob_path):
shutil.copy(tarball, output_dir)
def fetch(dataset_id: str, **kwargs: str) -> Collection:
return Collection()
from torchgeo.datasets.utils import Executable
class TestRwandaFieldBoundary:
@pytest.fixture(params=['train', 'test'])
def dataset(
self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
self,
azcopy: Executable,
monkeypatch: MonkeyPatch,
tmp_path: Path,
request: SubRequest,
) -> RwandaFieldBoundary:
radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3')
monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch)
monkeypatch.setattr(
RwandaFieldBoundary, 'number_of_patches_per_split', {'train': 5, 'test': 5}
)
monkeypatch.setattr(
RwandaFieldBoundary,
'md5s',
{
'train_images': 'af9395e2e49deefebb35fa65fa378ba3',
'test_images': 'd104bb82323a39e7c3b3b7dd0156f550',
'train_labels': '6cceaf16a141cf73179253a783e7d51b',
},
)
url = os.path.join('tests', 'data', 'rwanda_field_boundary')
monkeypatch.setattr(RwandaFieldBoundary, 'url', url)
monkeypatch.setattr(RwandaFieldBoundary, 'splits', {'train': 1, 'test': 1})
root = str(tmp_path)
split = request.param
transforms = nn.Identity()
return RwandaFieldBoundary(
root, split, transforms=transforms, api_key='', download=True, checksum=True
)
return RwandaFieldBoundary(root, split, transforms=transforms, download=True)
def test_getitem(self, dataset: RwandaFieldBoundary) -> None:
x = dataset[0]
@ -69,23 +48,12 @@ class TestRwandaFieldBoundary:
assert 'mask' not in x
def test_len(self, dataset: RwandaFieldBoundary) -> None:
assert len(dataset) == 5
assert len(dataset) == 1
def test_add(self, dataset: RwandaFieldBoundary) -> None:
ds = dataset + dataset
assert isinstance(ds, ConcatDataset)
assert len(ds) == 10
def test_needs_extraction(self, tmp_path: Path) -> None:
root = str(tmp_path)
for fn in [
'nasa_rwanda_field_boundary_competition_source_train.tar.gz',
'nasa_rwanda_field_boundary_competition_source_test.tar.gz',
'nasa_rwanda_field_boundary_competition_labels_train.tar.gz',
]:
url = os.path.join('tests', 'data', 'rwanda_field_boundary', fn)
shutil.copy(url, root)
RwandaFieldBoundary(root, checksum=False)
assert len(ds) == 2
def test_already_downloaded(self, dataset: RwandaFieldBoundary) -> None:
RwandaFieldBoundary(root=dataset.root)
@ -94,35 +62,8 @@ class TestRwandaFieldBoundary:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
RwandaFieldBoundary(str(tmp_path))
def test_corrupted(self, tmp_path: Path) -> None:
for fn in [
'nasa_rwanda_field_boundary_competition_source_train.tar.gz',
'nasa_rwanda_field_boundary_competition_source_test.tar.gz',
'nasa_rwanda_field_boundary_competition_labels_train.tar.gz',
]:
with open(os.path.join(tmp_path, fn), 'w') as f:
f.write('bad')
with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'):
RwandaFieldBoundary(root=str(tmp_path), checksum=True)
def test_failed_download(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> None:
radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3')
monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch)
monkeypatch.setattr(
RwandaFieldBoundary,
'md5s',
{'train_images': 'bad', 'test_images': 'bad', 'train_labels': 'bad'},
)
root = str(tmp_path)
with pytest.raises(RuntimeError, match='Dataset not found or corrupted.'):
RwandaFieldBoundary(root, 'train', api_key='', download=True, checksum=True)
def test_no_api_key(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match='Must provide an API key to download'):
RwandaFieldBoundary(str(tmp_path), api_key=None, download=True)
def test_invalid_bands(self) -> None:
with pytest.raises(ValueError, match='is an invalid band name.'):
with pytest.raises(AssertionError):
RwandaFieldBoundary(bands=('foo', 'bar'))
def test_plot(self, dataset: RwandaFieldBoundary) -> None:

Просмотреть файл

@ -3,6 +3,7 @@
"""Rwanda Field Boundary Competition dataset."""
import glob
import os
from collections.abc import Callable, Sequence
@ -16,11 +17,11 @@ from torch import Tensor
from .errors import DatasetNotFoundError, RGBBandsMissingError
from .geo import NonGeoDataset
from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive
from .utils import which
class RwandaFieldBoundary(NonGeoDataset):
r"""Rwanda Field Boundary Competition dataset.
"""Rwanda Field Boundary Competition dataset.
This dataset contains field boundaries for smallholder farms in eastern Rwanda.
The Nasa Harvest program funded a team of annotators from TaQadam to label Planet
@ -46,40 +47,20 @@ class RwandaFieldBoundary(NonGeoDataset):
This dataset requires the following additional library to be installed:
* `radiant-mlhub <https://pypi.org/project/radiant-mlhub/>`_ to download the
imagery and labels from the Radiant Earth MLHub
* `azcopy <https://github.com/Azure/azure-storage-azcopy>`_: to download the
dataset from Source Cooperative.
.. versionadded:: 0.5
"""
dataset_id = 'nasa_rwanda_field_boundary_competition'
collection_ids = [
'nasa_rwanda_field_boundary_competition_source_train',
'nasa_rwanda_field_boundary_competition_labels_train',
'nasa_rwanda_field_boundary_competition_source_test',
]
number_of_patches_per_split = {'train': 57, 'test': 13}
filenames = {
'train_images': 'nasa_rwanda_field_boundary_competition_source_train.tar.gz',
'test_images': 'nasa_rwanda_field_boundary_competition_source_test.tar.gz',
'train_labels': 'nasa_rwanda_field_boundary_competition_labels_train.tar.gz',
}
md5s = {
'train_images': '1f9ec08038218e67e11f82a86849b333',
'test_images': '17bb0e56eedde2e7a43c57aa908dc125',
'train_labels': '10e4eb761523c57b6d3bdf9394004f5f',
}
url = 'https://radiantearth.blob.core.windows.net/mlhub/nasa_rwanda_field_boundary_competition'
splits = {'train': 57, 'test': 13}
dates = ('2021_03', '2021_04', '2021_08', '2021_10', '2021_11', '2021_12')
all_bands = ('B01', 'B02', 'B03', 'B04')
rgb_bands = ('B03', 'B02', 'B01')
classes = ['No field-boundary', 'Field-boundary']
splits = ['train', 'test']
def __init__(
self,
root: str = 'data',
@ -87,8 +68,6 @@ class RwandaFieldBoundary(NonGeoDataset):
bands: Sequence[str] = all_bands,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False,
api_key: str | None = None,
checksum: bool = False,
) -> None:
"""Initialize a new RwandaFieldBoundary instance.
@ -99,49 +78,29 @@ class RwandaFieldBoundary(NonGeoDataset):
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
download: if True, download dataset and store it in the root directory
api_key: a RadiantEarth MLHub API key to use for downloading the dataset
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
AssertionError: If *split* or *bands* are invalid.
DatasetNotFoundError: If dataset is not found and *download* is False.
"""
self._validate_bands(bands)
assert split in self.splits
if download and api_key is None:
raise RuntimeError('Must provide an API key to download the dataset')
assert set(bands) <= set(self.all_bands)
self.root = root
self.split = split
self.bands = bands
self.transforms = transforms
self.split = split
self.download = download
self.api_key = api_key
self.checksum = checksum
self._verify()
self.image_filenames: list[list[list[str]]] = []
self.mask_filenames: list[str] = []
for i in range(self.number_of_patches_per_split[split]):
dates = []
for date in self.dates:
patch = []
for band in self.bands:
fn = os.path.join(
self.root,
f'nasa_rwanda_field_boundary_competition_source_{split}',
f'nasa_rwanda_field_boundary_competition_source_{split}_{i:02d}_{date}', # noqa: E501
f'{band}.tif',
)
patch.append(fn)
dates.append(patch)
self.image_filenames.append(dates)
self.mask_filenames.append(
os.path.join(
self.root,
f'nasa_rwanda_field_boundary_competition_labels_{split}',
f'nasa_rwanda_field_boundary_competition_labels_{split}_{i:02d}',
'raster_labels.tif',
)
)
def __len__(self) -> int:
"""Return the number of chips in the dataset.
Returns:
length of the dataset
"""
return self.splits[self.split]
def __getitem__(self, index: int) -> dict[str, Tensor]:
"""Return an index within the dataset.
@ -150,83 +109,34 @@ class RwandaFieldBoundary(NonGeoDataset):
index: index to return
Returns:
a dict containing image, mask, transform, crs, and metadata at index.
a dict containing image and mask at index.
"""
img_fns = self.image_filenames[index]
mask_fn = self.mask_filenames[index]
imgs = []
for date_fns in img_fns:
bands = []
for band_fn in date_fns:
with rasterio.open(band_fn) as f:
bands.append(f.read(1).astype(np.int32))
imgs.append(bands)
img = torch.from_numpy(np.array(imgs))
sample = {'image': img}
images = []
for date in self.dates:
patches = []
for band in self.bands:
path = os.path.join(self.root, 'source', self.split, date)
with rasterio.open(os.path.join(path, f'{index:02}_{band}.tif')) as src:
patches.append(src.read(1).astype(np.float32))
images.append(patches)
sample = {'image': torch.from_numpy(np.array(images))}
if self.split == 'train':
with rasterio.open(mask_fn) as f:
mask = f.read(1)
mask = torch.from_numpy(mask)
sample['mask'] = mask
path = os.path.join(self.root, 'labels', self.split)
with rasterio.open(os.path.join(path, f'{index:02}.tif')) as src:
sample['mask'] = torch.from_numpy(src.read(1).astype(np.int64))
if self.transforms is not None:
sample = self.transforms(sample)
return sample
def __len__(self) -> int:
"""Return the number of chips in the dataset.
Returns:
length of the dataset
"""
return len(self.image_filenames)
def _validate_bands(self, bands: Sequence[str]) -> None:
"""Validate list of bands.
Args:
bands: user-provided sequence of bands to load
Raises:
ValueError: if an invalid band name is provided
"""
for band in bands:
if band not in self.all_bands:
raise ValueError(f"'{band}' is an invalid band name.")
def _verify(self) -> None:
"""Verify the integrity of the dataset."""
# Check if the subdirectories already exist and have the correct number of files
checks = []
for split, num_patches in self.number_of_patches_per_split.items():
path = os.path.join(
self.root, f'nasa_rwanda_field_boundary_competition_source_{split}'
)
if os.path.exists(path):
num_files = len(os.listdir(path))
# 6 dates + 1 collection.json file
checks.append(num_files == (num_patches * 6) + 1)
else:
checks.append(False)
if all(checks):
return
# Check if tar file already exists (if so then extract)
have_all_files = True
for group in ['train_images', 'train_labels', 'test_images']:
filepath = os.path.join(self.root, self.filenames[group])
if os.path.exists(filepath):
if self.checksum and not check_integrity(filepath, self.md5s[group]):
raise RuntimeError('Dataset found, but corrupted.')
extract_archive(filepath)
else:
have_all_files = False
if have_all_files:
path = os.path.join(self.root, 'source', self.split, '*', '*.tif')
expected = len(self.dates) * self.splits[self.split] * len(self.all_bands)
if len(glob.glob(path)) == expected:
return
# Check if the user requested to download the dataset
@ -237,15 +147,10 @@ class RwandaFieldBoundary(NonGeoDataset):
self._download()
def _download(self) -> None:
"""Download the dataset and extract it."""
for collection_id in self.collection_ids:
download_radiant_mlhub_collection(collection_id, self.root, self.api_key)
for group in ['train_images', 'train_labels', 'test_images']:
filepath = os.path.join(self.root, self.filenames[group])
if self.checksum and not check_integrity(filepath, self.md5s[group]):
raise RuntimeError('Dataset not found or corrupted.')
extract_archive(filepath, self.root)
"""Download the dataset."""
os.makedirs(self.root, exist_ok=True)
azcopy = which('azcopy')
azcopy('sync', self.url, self.root, '--recursive=true')
def plot(
self,