зеркало из https://github.com/microsoft/torchgeo.git
Rwanda Field Boundary: radiant mlhub -> source cooperative (#2118)
This commit is contained in:
Родитель
9df08d0ff5
Коммит
61635cd084
|
@ -3,99 +3,46 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
import shutil
|
||||
|
||||
import numpy as np
|
||||
import rasterio
|
||||
from rasterio.crs import CRS
|
||||
from rasterio.transform import Affine
|
||||
|
||||
dates = ('2021_03', '2021_04', '2021_08', '2021_10', '2021_11', '2021_12')
|
||||
all_bands = ('B01', 'B02', 'B03', 'B04')
|
||||
|
||||
SIZE = 32
|
||||
NUM_SAMPLES = 5
|
||||
DTYPE = np.uint16
|
||||
NUM_SAMPLES = 1
|
||||
np.random.seed(0)
|
||||
|
||||
profile = {
|
||||
'driver': 'GTiff',
|
||||
'dtype': DTYPE,
|
||||
'width': SIZE,
|
||||
'height': SIZE,
|
||||
'count': 1,
|
||||
'crs': CRS.from_epsg(3857),
|
||||
'transform': Affine(
|
||||
4.77731426716, 0.0, 3374518.037700199, 0.0, -4.77731426716, -168438.54642526805
|
||||
),
|
||||
}
|
||||
Z = np.random.randint(np.iinfo(DTYPE).max, size=(SIZE, SIZE), dtype=DTYPE)
|
||||
|
||||
def create_mask(fn: str) -> None:
|
||||
profile = {
|
||||
'driver': 'GTiff',
|
||||
'dtype': 'uint8',
|
||||
'nodata': 0.0,
|
||||
'width': SIZE,
|
||||
'height': SIZE,
|
||||
'count': 1,
|
||||
'crs': 'epsg:3857',
|
||||
'compress': 'lzw',
|
||||
'predictor': 2,
|
||||
'transform': rasterio.Affine(10.0, 0.0, 0.0, 0.0, -10.0, 0.0),
|
||||
'blockysize': 32,
|
||||
'tiled': False,
|
||||
'interleave': 'band',
|
||||
}
|
||||
with rasterio.open(fn, 'w', **profile) as f:
|
||||
f.write(np.random.randint(0, 2, size=(SIZE, SIZE), dtype=np.uint8), 1)
|
||||
for sample in range(NUM_SAMPLES):
|
||||
for split in ['train', 'test']:
|
||||
for date in dates:
|
||||
path = os.path.join('source', split, date)
|
||||
os.makedirs(path, exist_ok=True)
|
||||
for band in all_bands:
|
||||
file = os.path.join(path, f'{sample:02}_{band}.tif')
|
||||
with rasterio.open(file, 'w', **profile) as src:
|
||||
src.write(Z, 1)
|
||||
|
||||
|
||||
def create_img(fn: str) -> None:
|
||||
profile = {
|
||||
'driver': 'GTiff',
|
||||
'dtype': 'uint16',
|
||||
'nodata': 0.0,
|
||||
'width': SIZE,
|
||||
'height': SIZE,
|
||||
'count': 1,
|
||||
'crs': 'epsg:3857',
|
||||
'compress': 'lzw',
|
||||
'predictor': 2,
|
||||
'blockysize': 16,
|
||||
'transform': rasterio.Affine(10.0, 0.0, 0.0, 0.0, -10.0, 0.0),
|
||||
'tiled': False,
|
||||
'interleave': 'band',
|
||||
}
|
||||
with rasterio.open(fn, 'w', **profile) as f:
|
||||
f.write(np.random.randint(0, 2, size=(SIZE, SIZE), dtype=np.uint16), 1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Train and test images
|
||||
for split in ('train', 'test'):
|
||||
for i in range(NUM_SAMPLES):
|
||||
for date in dates:
|
||||
directory = os.path.join(
|
||||
f'nasa_rwanda_field_boundary_competition_source_{split}',
|
||||
f'nasa_rwanda_field_boundary_competition_source_{split}_{i:02d}_{date}', # noqa: E501
|
||||
)
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
for band in all_bands:
|
||||
create_img(os.path.join(directory, f'{band}.tif'))
|
||||
|
||||
# Create collections.json, this isn't used by the dataset but is checked to
|
||||
# exist
|
||||
with open(
|
||||
f'nasa_rwanda_field_boundary_competition_source_{split}/collections.json',
|
||||
'w',
|
||||
) as f:
|
||||
f.write('Not used')
|
||||
|
||||
# Train labels
|
||||
for i in range(NUM_SAMPLES):
|
||||
directory = os.path.join(
|
||||
'nasa_rwanda_field_boundary_competition_labels_train',
|
||||
f'nasa_rwanda_field_boundary_competition_labels_train_{i:02d}',
|
||||
)
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
create_mask(os.path.join(directory, 'raster_labels.tif'))
|
||||
|
||||
# Create directories and compute checksums
|
||||
for filename in [
|
||||
'nasa_rwanda_field_boundary_competition_source_train',
|
||||
'nasa_rwanda_field_boundary_competition_source_test',
|
||||
'nasa_rwanda_field_boundary_competition_labels_train',
|
||||
]:
|
||||
shutil.make_archive(filename, 'gztar', '.', filename)
|
||||
# Compute checksums
|
||||
with open(f'{filename}.tar.gz', 'rb') as f:
|
||||
md5 = hashlib.md5(f.read()).hexdigest()
|
||||
print(f'{filename}: {md5}')
|
||||
path = os.path.join('labels', 'train')
|
||||
os.makedirs(path, exist_ok=True)
|
||||
file = os.path.join(path, f'{sample:02}.tif')
|
||||
with rasterio.open(file, 'w', **profile) as src:
|
||||
src.write(Z, 1)
|
||||
|
|
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
|
@ -1,9 +1,7 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import glob
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
@ -19,45 +17,26 @@ from torchgeo.datasets import (
|
|||
RGBBandsMissingError,
|
||||
RwandaFieldBoundary,
|
||||
)
|
||||
|
||||
|
||||
class Collection:
|
||||
def download(self, output_dir: str, **kwargs: str) -> None:
|
||||
glob_path = os.path.join('tests', 'data', 'rwanda_field_boundary', '*.tar.gz')
|
||||
for tarball in glob.iglob(glob_path):
|
||||
shutil.copy(tarball, output_dir)
|
||||
|
||||
|
||||
def fetch(dataset_id: str, **kwargs: str) -> Collection:
|
||||
return Collection()
|
||||
from torchgeo.datasets.utils import Executable
|
||||
|
||||
|
||||
class TestRwandaFieldBoundary:
|
||||
@pytest.fixture(params=['train', 'test'])
|
||||
def dataset(
|
||||
self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
|
||||
self,
|
||||
azcopy: Executable,
|
||||
monkeypatch: MonkeyPatch,
|
||||
tmp_path: Path,
|
||||
request: SubRequest,
|
||||
) -> RwandaFieldBoundary:
|
||||
radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3')
|
||||
monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch)
|
||||
monkeypatch.setattr(
|
||||
RwandaFieldBoundary, 'number_of_patches_per_split', {'train': 5, 'test': 5}
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
RwandaFieldBoundary,
|
||||
'md5s',
|
||||
{
|
||||
'train_images': 'af9395e2e49deefebb35fa65fa378ba3',
|
||||
'test_images': 'd104bb82323a39e7c3b3b7dd0156f550',
|
||||
'train_labels': '6cceaf16a141cf73179253a783e7d51b',
|
||||
},
|
||||
)
|
||||
url = os.path.join('tests', 'data', 'rwanda_field_boundary')
|
||||
monkeypatch.setattr(RwandaFieldBoundary, 'url', url)
|
||||
monkeypatch.setattr(RwandaFieldBoundary, 'splits', {'train': 1, 'test': 1})
|
||||
|
||||
root = str(tmp_path)
|
||||
split = request.param
|
||||
transforms = nn.Identity()
|
||||
return RwandaFieldBoundary(
|
||||
root, split, transforms=transforms, api_key='', download=True, checksum=True
|
||||
)
|
||||
return RwandaFieldBoundary(root, split, transforms=transforms, download=True)
|
||||
|
||||
def test_getitem(self, dataset: RwandaFieldBoundary) -> None:
|
||||
x = dataset[0]
|
||||
|
@ -69,23 +48,12 @@ class TestRwandaFieldBoundary:
|
|||
assert 'mask' not in x
|
||||
|
||||
def test_len(self, dataset: RwandaFieldBoundary) -> None:
|
||||
assert len(dataset) == 5
|
||||
assert len(dataset) == 1
|
||||
|
||||
def test_add(self, dataset: RwandaFieldBoundary) -> None:
|
||||
ds = dataset + dataset
|
||||
assert isinstance(ds, ConcatDataset)
|
||||
assert len(ds) == 10
|
||||
|
||||
def test_needs_extraction(self, tmp_path: Path) -> None:
|
||||
root = str(tmp_path)
|
||||
for fn in [
|
||||
'nasa_rwanda_field_boundary_competition_source_train.tar.gz',
|
||||
'nasa_rwanda_field_boundary_competition_source_test.tar.gz',
|
||||
'nasa_rwanda_field_boundary_competition_labels_train.tar.gz',
|
||||
]:
|
||||
url = os.path.join('tests', 'data', 'rwanda_field_boundary', fn)
|
||||
shutil.copy(url, root)
|
||||
RwandaFieldBoundary(root, checksum=False)
|
||||
assert len(ds) == 2
|
||||
|
||||
def test_already_downloaded(self, dataset: RwandaFieldBoundary) -> None:
|
||||
RwandaFieldBoundary(root=dataset.root)
|
||||
|
@ -94,35 +62,8 @@ class TestRwandaFieldBoundary:
|
|||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
RwandaFieldBoundary(str(tmp_path))
|
||||
|
||||
def test_corrupted(self, tmp_path: Path) -> None:
|
||||
for fn in [
|
||||
'nasa_rwanda_field_boundary_competition_source_train.tar.gz',
|
||||
'nasa_rwanda_field_boundary_competition_source_test.tar.gz',
|
||||
'nasa_rwanda_field_boundary_competition_labels_train.tar.gz',
|
||||
]:
|
||||
with open(os.path.join(tmp_path, fn), 'w') as f:
|
||||
f.write('bad')
|
||||
with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'):
|
||||
RwandaFieldBoundary(root=str(tmp_path), checksum=True)
|
||||
|
||||
def test_failed_download(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> None:
|
||||
radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3')
|
||||
monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch)
|
||||
monkeypatch.setattr(
|
||||
RwandaFieldBoundary,
|
||||
'md5s',
|
||||
{'train_images': 'bad', 'test_images': 'bad', 'train_labels': 'bad'},
|
||||
)
|
||||
root = str(tmp_path)
|
||||
with pytest.raises(RuntimeError, match='Dataset not found or corrupted.'):
|
||||
RwandaFieldBoundary(root, 'train', api_key='', download=True, checksum=True)
|
||||
|
||||
def test_no_api_key(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match='Must provide an API key to download'):
|
||||
RwandaFieldBoundary(str(tmp_path), api_key=None, download=True)
|
||||
|
||||
def test_invalid_bands(self) -> None:
|
||||
with pytest.raises(ValueError, match='is an invalid band name.'):
|
||||
with pytest.raises(AssertionError):
|
||||
RwandaFieldBoundary(bands=('foo', 'bar'))
|
||||
|
||||
def test_plot(self, dataset: RwandaFieldBoundary) -> None:
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
"""Rwanda Field Boundary Competition dataset."""
|
||||
|
||||
import glob
|
||||
import os
|
||||
from collections.abc import Callable, Sequence
|
||||
|
||||
|
@ -16,11 +17,11 @@ from torch import Tensor
|
|||
|
||||
from .errors import DatasetNotFoundError, RGBBandsMissingError
|
||||
from .geo import NonGeoDataset
|
||||
from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive
|
||||
from .utils import which
|
||||
|
||||
|
||||
class RwandaFieldBoundary(NonGeoDataset):
|
||||
r"""Rwanda Field Boundary Competition dataset.
|
||||
"""Rwanda Field Boundary Competition dataset.
|
||||
|
||||
This dataset contains field boundaries for smallholder farms in eastern Rwanda.
|
||||
The Nasa Harvest program funded a team of annotators from TaQadam to label Planet
|
||||
|
@ -46,40 +47,20 @@ class RwandaFieldBoundary(NonGeoDataset):
|
|||
|
||||
This dataset requires the following additional library to be installed:
|
||||
|
||||
* `radiant-mlhub <https://pypi.org/project/radiant-mlhub/>`_ to download the
|
||||
imagery and labels from the Radiant Earth MLHub
|
||||
* `azcopy <https://github.com/Azure/azure-storage-azcopy>`_: to download the
|
||||
dataset from Source Cooperative.
|
||||
|
||||
.. versionadded:: 0.5
|
||||
"""
|
||||
|
||||
dataset_id = 'nasa_rwanda_field_boundary_competition'
|
||||
collection_ids = [
|
||||
'nasa_rwanda_field_boundary_competition_source_train',
|
||||
'nasa_rwanda_field_boundary_competition_labels_train',
|
||||
'nasa_rwanda_field_boundary_competition_source_test',
|
||||
]
|
||||
number_of_patches_per_split = {'train': 57, 'test': 13}
|
||||
|
||||
filenames = {
|
||||
'train_images': 'nasa_rwanda_field_boundary_competition_source_train.tar.gz',
|
||||
'test_images': 'nasa_rwanda_field_boundary_competition_source_test.tar.gz',
|
||||
'train_labels': 'nasa_rwanda_field_boundary_competition_labels_train.tar.gz',
|
||||
}
|
||||
md5s = {
|
||||
'train_images': '1f9ec08038218e67e11f82a86849b333',
|
||||
'test_images': '17bb0e56eedde2e7a43c57aa908dc125',
|
||||
'train_labels': '10e4eb761523c57b6d3bdf9394004f5f',
|
||||
}
|
||||
url = 'https://radiantearth.blob.core.windows.net/mlhub/nasa_rwanda_field_boundary_competition'
|
||||
|
||||
splits = {'train': 57, 'test': 13}
|
||||
dates = ('2021_03', '2021_04', '2021_08', '2021_10', '2021_11', '2021_12')
|
||||
|
||||
all_bands = ('B01', 'B02', 'B03', 'B04')
|
||||
rgb_bands = ('B03', 'B02', 'B01')
|
||||
|
||||
classes = ['No field-boundary', 'Field-boundary']
|
||||
|
||||
splits = ['train', 'test']
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root: str = 'data',
|
||||
|
@ -87,8 +68,6 @@ class RwandaFieldBoundary(NonGeoDataset):
|
|||
bands: Sequence[str] = all_bands,
|
||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||
download: bool = False,
|
||||
api_key: str | None = None,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
"""Initialize a new RwandaFieldBoundary instance.
|
||||
|
||||
|
@ -99,49 +78,29 @@ class RwandaFieldBoundary(NonGeoDataset):
|
|||
transforms: a function/transform that takes input sample and its target as
|
||||
entry and returns a transformed version
|
||||
download: if True, download dataset and store it in the root directory
|
||||
api_key: a RadiantEarth MLHub API key to use for downloading the dataset
|
||||
checksum: if True, check the MD5 of the downloaded files (may be slow)
|
||||
|
||||
Raises:
|
||||
AssertionError: If *split* or *bands* are invalid.
|
||||
DatasetNotFoundError: If dataset is not found and *download* is False.
|
||||
"""
|
||||
self._validate_bands(bands)
|
||||
assert split in self.splits
|
||||
if download and api_key is None:
|
||||
raise RuntimeError('Must provide an API key to download the dataset')
|
||||
assert set(bands) <= set(self.all_bands)
|
||||
|
||||
self.root = root
|
||||
self.split = split
|
||||
self.bands = bands
|
||||
self.transforms = transforms
|
||||
self.split = split
|
||||
self.download = download
|
||||
self.api_key = api_key
|
||||
self.checksum = checksum
|
||||
|
||||
self._verify()
|
||||
|
||||
self.image_filenames: list[list[list[str]]] = []
|
||||
self.mask_filenames: list[str] = []
|
||||
for i in range(self.number_of_patches_per_split[split]):
|
||||
dates = []
|
||||
for date in self.dates:
|
||||
patch = []
|
||||
for band in self.bands:
|
||||
fn = os.path.join(
|
||||
self.root,
|
||||
f'nasa_rwanda_field_boundary_competition_source_{split}',
|
||||
f'nasa_rwanda_field_boundary_competition_source_{split}_{i:02d}_{date}', # noqa: E501
|
||||
f'{band}.tif',
|
||||
)
|
||||
patch.append(fn)
|
||||
dates.append(patch)
|
||||
self.image_filenames.append(dates)
|
||||
self.mask_filenames.append(
|
||||
os.path.join(
|
||||
self.root,
|
||||
f'nasa_rwanda_field_boundary_competition_labels_{split}',
|
||||
f'nasa_rwanda_field_boundary_competition_labels_{split}_{i:02d}',
|
||||
'raster_labels.tif',
|
||||
)
|
||||
)
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of chips in the dataset.
|
||||
|
||||
Returns:
|
||||
length of the dataset
|
||||
"""
|
||||
return self.splits[self.split]
|
||||
|
||||
def __getitem__(self, index: int) -> dict[str, Tensor]:
|
||||
"""Return an index within the dataset.
|
||||
|
@ -150,83 +109,34 @@ class RwandaFieldBoundary(NonGeoDataset):
|
|||
index: index to return
|
||||
|
||||
Returns:
|
||||
a dict containing image, mask, transform, crs, and metadata at index.
|
||||
a dict containing image and mask at index.
|
||||
"""
|
||||
img_fns = self.image_filenames[index]
|
||||
mask_fn = self.mask_filenames[index]
|
||||
|
||||
imgs = []
|
||||
for date_fns in img_fns:
|
||||
bands = []
|
||||
for band_fn in date_fns:
|
||||
with rasterio.open(band_fn) as f:
|
||||
bands.append(f.read(1).astype(np.int32))
|
||||
imgs.append(bands)
|
||||
img = torch.from_numpy(np.array(imgs))
|
||||
|
||||
sample = {'image': img}
|
||||
images = []
|
||||
for date in self.dates:
|
||||
patches = []
|
||||
for band in self.bands:
|
||||
path = os.path.join(self.root, 'source', self.split, date)
|
||||
with rasterio.open(os.path.join(path, f'{index:02}_{band}.tif')) as src:
|
||||
patches.append(src.read(1).astype(np.float32))
|
||||
images.append(patches)
|
||||
sample = {'image': torch.from_numpy(np.array(images))}
|
||||
|
||||
if self.split == 'train':
|
||||
with rasterio.open(mask_fn) as f:
|
||||
mask = f.read(1)
|
||||
mask = torch.from_numpy(mask)
|
||||
sample['mask'] = mask
|
||||
path = os.path.join(self.root, 'labels', self.split)
|
||||
with rasterio.open(os.path.join(path, f'{index:02}.tif')) as src:
|
||||
sample['mask'] = torch.from_numpy(src.read(1).astype(np.int64))
|
||||
|
||||
if self.transforms is not None:
|
||||
sample = self.transforms(sample)
|
||||
|
||||
return sample
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of chips in the dataset.
|
||||
|
||||
Returns:
|
||||
length of the dataset
|
||||
"""
|
||||
return len(self.image_filenames)
|
||||
|
||||
def _validate_bands(self, bands: Sequence[str]) -> None:
|
||||
"""Validate list of bands.
|
||||
|
||||
Args:
|
||||
bands: user-provided sequence of bands to load
|
||||
|
||||
Raises:
|
||||
ValueError: if an invalid band name is provided
|
||||
"""
|
||||
for band in bands:
|
||||
if band not in self.all_bands:
|
||||
raise ValueError(f"'{band}' is an invalid band name.")
|
||||
|
||||
def _verify(self) -> None:
|
||||
"""Verify the integrity of the dataset."""
|
||||
# Check if the subdirectories already exist and have the correct number of files
|
||||
checks = []
|
||||
for split, num_patches in self.number_of_patches_per_split.items():
|
||||
path = os.path.join(
|
||||
self.root, f'nasa_rwanda_field_boundary_competition_source_{split}'
|
||||
)
|
||||
if os.path.exists(path):
|
||||
num_files = len(os.listdir(path))
|
||||
# 6 dates + 1 collection.json file
|
||||
checks.append(num_files == (num_patches * 6) + 1)
|
||||
else:
|
||||
checks.append(False)
|
||||
|
||||
if all(checks):
|
||||
return
|
||||
|
||||
# Check if tar file already exists (if so then extract)
|
||||
have_all_files = True
|
||||
for group in ['train_images', 'train_labels', 'test_images']:
|
||||
filepath = os.path.join(self.root, self.filenames[group])
|
||||
if os.path.exists(filepath):
|
||||
if self.checksum and not check_integrity(filepath, self.md5s[group]):
|
||||
raise RuntimeError('Dataset found, but corrupted.')
|
||||
extract_archive(filepath)
|
||||
else:
|
||||
have_all_files = False
|
||||
if have_all_files:
|
||||
path = os.path.join(self.root, 'source', self.split, '*', '*.tif')
|
||||
expected = len(self.dates) * self.splits[self.split] * len(self.all_bands)
|
||||
if len(glob.glob(path)) == expected:
|
||||
return
|
||||
|
||||
# Check if the user requested to download the dataset
|
||||
|
@ -237,15 +147,10 @@ class RwandaFieldBoundary(NonGeoDataset):
|
|||
self._download()
|
||||
|
||||
def _download(self) -> None:
|
||||
"""Download the dataset and extract it."""
|
||||
for collection_id in self.collection_ids:
|
||||
download_radiant_mlhub_collection(collection_id, self.root, self.api_key)
|
||||
|
||||
for group in ['train_images', 'train_labels', 'test_images']:
|
||||
filepath = os.path.join(self.root, self.filenames[group])
|
||||
if self.checksum and not check_integrity(filepath, self.md5s[group]):
|
||||
raise RuntimeError('Dataset not found or corrupted.')
|
||||
extract_archive(filepath, self.root)
|
||||
"""Download the dataset."""
|
||||
os.makedirs(self.root, exist_ok=True)
|
||||
azcopy = which('azcopy')
|
||||
azcopy('sync', self.url, self.root, '--recursive=true')
|
||||
|
||||
def plot(
|
||||
self,
|
||||
|
|
Загрузка…
Ссылка в новой задаче