зеркало из https://github.com/microsoft/torchgeo.git
CV4A Kenya Crop Type: radiant mlhub -> source cooperative (#2090)
This commit is contained in:
Родитель
83cad6017c
Коммит
32aa3492ca
|
@ -0,0 +1,5 @@
|
|||
train,test
|
||||
1,2
|
||||
3,4
|
||||
5
|
||||
6
|
|
|
@ -0,0 +1,51 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
# Every synthetic raster is a SIZE x SIZE image; band rasters use DTYPE.
DTYPE = np.float32
SIZE = 2

# Fixed seed so that regenerating the fixtures reproduces the same pixel values.
np.random.seed(0)

all_bands = (
    'B01',
    'B02',
    'B03',
    'B04',
    'B05',
    'B06',
    'B07',
    'B08',
    'B8A',
    'B09',
    'B11',
    'B12',
    'CLD',
)

for tile in range(1):
    directory = os.path.join('data', str(tile))
    os.makedirs(directory, exist_ok=True)

    # Per-tile field-id raster (int32).
    arr = np.random.randint(np.iinfo(np.int32).max, size=(SIZE, SIZE), dtype=np.int32)
    img = Image.fromarray(arr)
    img.save(os.path.join(directory, f'{tile}_field_id.tif'))

    # Per-tile label raster (uint8).
    arr = np.random.randint(np.iinfo(np.uint8).max, size=(SIZE, SIZE), dtype=np.uint8)
    img = Image.fromarray(arr)
    img.save(os.path.join(directory, f'{tile}_label.tif'))

    for date in ['20190606']:
        # Derive the per-date folder from the tile folder on every iteration.
        # Rebinding ``directory`` here would nest each additional date inside
        # the previous date's folder (data/<tile>/<date1>/<date2>/...).
        date_directory = os.path.join(directory, date)
        os.makedirs(date_directory, exist_ok=True)

        # One float32 raster per band for this tile and date.
        for band in all_bands:
            arr = np.random.rand(SIZE, SIZE).astype(DTYPE) * np.finfo(DTYPE).max
            img = Image.fromarray(arr)
            img.save(os.path.join(date_directory, f'{tile}_{band}_{date}.tif'))
|
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
|
@ -1,9 +1,7 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import glob
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
@ -18,44 +16,23 @@ from torchgeo.datasets import (
|
|||
DatasetNotFoundError,
|
||||
RGBBandsMissingError,
|
||||
)
|
||||
|
||||
|
||||
class Collection:
    """Offline stand-in for a downloadable collection, used by the tests."""

    def download(self, output_dir: str, **kwargs: str) -> None:
        """Copy the local fixture tarballs into ``output_dir``.

        Args:
            output_dir: directory that receives the copied archives
            **kwargs: accepted for signature compatibility and ignored
        """
        fixture_dir = os.path.join('tests', 'data', 'ref_african_crops_kenya_02')
        pattern = os.path.join(fixture_dir, '*.tar.gz')
        for archive in glob.iglob(pattern):
            shutil.copy(archive, output_dir)
|
||||
|
||||
|
||||
def fetch(dataset_id: str, **kwargs: str) -> Collection:
    """Return a fresh stub ``Collection``.

    Both ``dataset_id`` and ``kwargs`` are accepted and ignored.
    """
    stub = Collection()
    return stub
|
||||
from torchgeo.datasets.utils import Executable
|
||||
|
||||
|
||||
class TestCV4AKenyaCropType:
|
||||
@pytest.fixture
|
||||
def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> CV4AKenyaCropType:
|
||||
radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3')
|
||||
monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch)
|
||||
source_md5 = '7f4dcb3f33743dddd73f453176308bfb'
|
||||
labels_md5 = '95fc59f1d94a85ec00931d4d1280bec9'
|
||||
monkeypatch.setitem(CV4AKenyaCropType.image_meta, 'md5', source_md5)
|
||||
monkeypatch.setitem(CV4AKenyaCropType.target_meta, 'md5', labels_md5)
|
||||
monkeypatch.setattr(
|
||||
CV4AKenyaCropType, 'tile_names', ['ref_african_crops_kenya_02_tile_00']
|
||||
)
|
||||
def dataset(
|
||||
self, azcopy: Executable, monkeypatch: MonkeyPatch, tmp_path: Path
|
||||
) -> CV4AKenyaCropType:
|
||||
url = os.path.join('tests', 'data', 'cv4a_kenya_crop_type')
|
||||
monkeypatch.setattr(CV4AKenyaCropType, 'url', url)
|
||||
monkeypatch.setattr(CV4AKenyaCropType, 'tiles', list(map(str, range(1))))
|
||||
monkeypatch.setattr(CV4AKenyaCropType, 'dates', ['20190606'])
|
||||
monkeypatch.setattr(CV4AKenyaCropType, 'tile_height', 2)
|
||||
monkeypatch.setattr(CV4AKenyaCropType, 'tile_width', 2)
|
||||
root = str(tmp_path)
|
||||
transforms = nn.Identity()
|
||||
return CV4AKenyaCropType(
|
||||
root,
|
||||
transforms=transforms,
|
||||
download=True,
|
||||
api_key='',
|
||||
checksum=True,
|
||||
verbose=True,
|
||||
)
|
||||
return CV4AKenyaCropType(root, transforms=transforms, download=True)
|
||||
|
||||
def test_getitem(self, dataset: CV4AKenyaCropType) -> None:
|
||||
x = dataset[0]
|
||||
|
@ -66,60 +43,34 @@ class TestCV4AKenyaCropType:
|
|||
assert isinstance(x['y'], torch.Tensor)
|
||||
|
||||
def test_len(self, dataset: CV4AKenyaCropType) -> None:
|
||||
assert len(dataset) == 345
|
||||
assert len(dataset) == 1
|
||||
|
||||
def test_add(self, dataset: CV4AKenyaCropType) -> None:
|
||||
ds = dataset + dataset
|
||||
assert isinstance(ds, ConcatDataset)
|
||||
assert len(ds) == 690
|
||||
|
||||
def test_get_splits(self, dataset: CV4AKenyaCropType) -> None:
    """Check the types, sizes and membership of the train/test field-id splits."""
    splits = dataset.get_splits()
    train_field_ids, test_field_ids = splits

    # Both halves of the split come back as plain lists.
    assert isinstance(train_field_ids, list)
    assert isinstance(test_field_ids, list)

    # Expected sizes for the fixture data.
    assert len(train_field_ids) == 18
    assert len(test_field_ids) == 9

    # A known training id must appear only in the training list ...
    train_only = 336
    assert train_only in train_field_ids
    assert train_only not in test_field_ids

    # ... and a known test id only in the test list.
    test_only = 4793
    assert test_only in test_field_ids
    assert test_only not in train_field_ids
|
||||
assert len(ds) == 2
|
||||
|
||||
def test_already_downloaded(self, dataset: CV4AKenyaCropType) -> None:
|
||||
CV4AKenyaCropType(root=dataset.root, download=True, api_key='')
|
||||
CV4AKenyaCropType(root=dataset.root, download=True)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
    """Constructing the dataset on an empty root must raise DatasetNotFoundError."""
    empty_root = str(tmp_path)
    with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
        CV4AKenyaCropType(empty_root)
|
||||
|
||||
def test_invalid_tile(self, dataset: CV4AKenyaCropType) -> None:
    """Each tile loader must reject an unknown tile name with AssertionError."""
    bad_tile = 'foo'
    band_subset = ('B01', 'B02')

    # Label / field-id loader.
    with pytest.raises(AssertionError):
        dataset._load_label_tile(bad_tile)

    # Loader for the imagery across all dates.
    with pytest.raises(AssertionError):
        dataset._load_all_image_tiles(bad_tile, band_subset)

    # Loader for the imagery of a single date.
    with pytest.raises(AssertionError):
        dataset._load_single_image_tile(bad_tile, '20190606', band_subset)
|
||||
|
||||
def test_invalid_bands(self) -> None:
    """A non-tuple band selection and unknown band names must both be rejected."""
    not_a_tuple = ['B01', 'B02']
    with pytest.raises(AssertionError):
        CV4AKenyaCropType(bands=not_a_tuple)  # type: ignore[arg-type]

    unknown_names = ('foo', 'bar')
    with pytest.raises(ValueError, match='is an invalid band name.'):
        CV4AKenyaCropType(bands=unknown_names)
|
||||
|
||||
def test_plot(self, dataset: CV4AKenyaCropType) -> None:
|
||||
dataset.plot(dataset[0], time_step=0, suptitle='Test')
|
||||
plt.close()
|
||||
|
||||
sample = dataset[0]
|
||||
dataset.plot(sample, time_step=0, suptitle='Test')
|
||||
plt.close()
|
||||
sample['prediction'] = sample['mask'].clone()
|
||||
dataset.plot(sample, time_step=0, suptitle='Pred')
|
||||
plt.close()
|
||||
|
||||
def test_plot_rgb(self, dataset: CV4AKenyaCropType) -> None:
|
||||
dataset = CV4AKenyaCropType(root=dataset.root, bands=tuple(['B01']))
|
||||
with pytest.raises(
|
||||
RGBBandsMissingError, match='Dataset does not contain some of the RGB bands'
|
||||
):
|
||||
dataset.plot(dataset[0], time_step=0, suptitle='Single Band')
|
||||
match = 'Dataset does not contain some of the RGB bands'
|
||||
with pytest.raises(RGBBandsMissingError, match=match):
|
||||
dataset.plot(dataset[0])
|
||||
|
|
|
@ -3,9 +3,8 @@
|
|||
|
||||
"""CV4A Kenya Crop Type dataset."""
|
||||
|
||||
import csv
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Callable, Sequence
|
||||
from functools import lru_cache
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
@ -17,16 +16,23 @@ from torch import Tensor
|
|||
|
||||
from .errors import DatasetNotFoundError, RGBBandsMissingError
|
||||
from .geo import NonGeoDataset
|
||||
from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive
|
||||
from .utils import which
|
||||
|
||||
|
||||
# TODO: read geospatial information from stac.json files
|
||||
class CV4AKenyaCropType(NonGeoDataset):
|
||||
"""CV4A Kenya Crop Type dataset.
|
||||
"""CV4A Kenya Crop Type Competition dataset.
|
||||
|
||||
Used in a competition in the Computer NonGeo for Agriculture (CV4A) workshop in
|
||||
ICLR 2020. See `this website <https://mlhub.earth/10.34911/rdnt.dw605x>`__
|
||||
for dataset details.
|
||||
The `CV4A Kenya Crop Type Competition
|
||||
<https://beta.source.coop/repositories/radiantearth/african-crops-kenya-02/>`__
|
||||
dataset was produced as part of the Crop Type Detection competition at the
|
||||
Computer Vision for Agriculture (CV4A) Workshop at the ICLR 2020 conference.
|
||||
The objective of the competition was to create a machine learning model to
|
||||
classify fields by crop type from images collected during the growing season
|
||||
by the Sentinel-2 satellites.
|
||||
|
||||
See the `dataset documentation
|
||||
<https://data.source.coop/radiantearth/african-crops-kenya-02/Documentation.pdf>`__
|
||||
for details.
|
||||
|
||||
Consists of 4 tiles of Sentinel 2 imagery from 13 different points in time.
|
||||
|
||||
|
@ -54,29 +60,12 @@ class CV4AKenyaCropType(NonGeoDataset):
|
|||
|
||||
This dataset requires the following additional library to be installed:
|
||||
|
||||
* `radiant-mlhub <https://pypi.org/project/radiant-mlhub/>`_ to download the
|
||||
imagery and labels from the Radiant Earth MLHub
|
||||
* `azcopy <https://github.com/Azure/azure-storage-azcopy>`_: to download the
|
||||
dataset from Source Cooperative.
|
||||
"""
|
||||
|
||||
collection_ids = [
|
||||
'ref_african_crops_kenya_02_labels',
|
||||
'ref_african_crops_kenya_02_source',
|
||||
]
|
||||
image_meta = {
|
||||
'filename': 'ref_african_crops_kenya_02_source.tar.gz',
|
||||
'md5': '9c2004782f6dc83abb1bf45ba4d0da46',
|
||||
}
|
||||
target_meta = {
|
||||
'filename': 'ref_african_crops_kenya_02_labels.tar.gz',
|
||||
'md5': '93949abd0ae82ba564f5a933cefd8215',
|
||||
}
|
||||
|
||||
tile_names = [
|
||||
'ref_african_crops_kenya_02_tile_00',
|
||||
'ref_african_crops_kenya_02_tile_01',
|
||||
'ref_african_crops_kenya_02_tile_02',
|
||||
'ref_african_crops_kenya_02_tile_03',
|
||||
]
|
||||
url = 'https://radiantearth.blob.core.windows.net/mlhub/kenya-crop-challenge'
|
||||
tiles = list(map(str, range(4)))
|
||||
dates = [
|
||||
'20190606',
|
||||
'20190701',
|
||||
|
@ -92,7 +81,7 @@ class CV4AKenyaCropType(NonGeoDataset):
|
|||
'20191004',
|
||||
'20191103',
|
||||
]
|
||||
band_names = (
|
||||
all_bands = (
|
||||
'B01',
|
||||
'B02',
|
||||
'B03',
|
||||
|
@ -107,7 +96,6 @@ class CV4AKenyaCropType(NonGeoDataset):
|
|||
'B12',
|
||||
'CLD',
|
||||
)
|
||||
|
||||
rgb_bands = ['B04', 'B03', 'B02']
|
||||
|
||||
# Same for all tiles
|
||||
|
@ -119,12 +107,9 @@ class CV4AKenyaCropType(NonGeoDataset):
|
|||
root: str = 'data',
|
||||
chip_size: int = 256,
|
||||
stride: int = 128,
|
||||
bands: tuple[str, ...] = band_names,
|
||||
bands: Sequence[str] = all_bands,
|
||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||
download: bool = False,
|
||||
api_key: str | None = None,
|
||||
checksum: bool = False,
|
||||
verbose: bool = False,
|
||||
) -> None:
|
||||
"""Initialize a new CV4A Kenya Crop Type Dataset instance.
|
||||
|
||||
|
@ -137,32 +122,25 @@ class CV4AKenyaCropType(NonGeoDataset):
|
|||
transforms: a function/transform that takes input sample and its target as
|
||||
entry and returns a transformed version
|
||||
download: if True, download dataset and store it in the root directory
|
||||
api_key: a RadiantEarth MLHub API key to use for downloading the dataset
|
||||
checksum: if True, check the MD5 of the downloaded files (may be slow)
|
||||
verbose: if True, print messages when new tiles are loaded
|
||||
|
||||
Raises:
|
||||
AssertionError: If *bands* are invalid.
|
||||
DatasetNotFoundError: If dataset is not found and *download* is False.
|
||||
"""
|
||||
self._validate_bands(bands)
|
||||
assert set(bands) <= set(self.all_bands)
|
||||
|
||||
self.root = root
|
||||
self.chip_size = chip_size
|
||||
self.stride = stride
|
||||
self.bands = bands
|
||||
self.transforms = transforms
|
||||
self.checksum = checksum
|
||||
self.verbose = verbose
|
||||
self.download = download
|
||||
|
||||
if download:
|
||||
self._download(api_key)
|
||||
|
||||
if not self._check_integrity():
|
||||
raise DatasetNotFoundError(self)
|
||||
self._verify()
|
||||
|
||||
# Calculate the indices that we will use over all tiles
|
||||
self.chips_metadata = []
|
||||
for tile_index in range(len(self.tile_names)):
|
||||
for tile_index in range(len(self.tiles)):
|
||||
for y in list(range(0, self.tile_height - self.chip_size, stride)) + [
|
||||
self.tile_height - self.chip_size
|
||||
]:
|
||||
|
@ -181,10 +159,10 @@ class CV4AKenyaCropType(NonGeoDataset):
|
|||
data, labels, field ids, and metadata at that index
|
||||
"""
|
||||
tile_index, y, x = self.chips_metadata[index]
|
||||
tile_name = self.tile_names[tile_index]
|
||||
tile = self.tiles[tile_index]
|
||||
|
||||
img = self._load_all_image_tiles(tile_name, self.bands)
|
||||
labels, field_ids = self._load_label_tile(tile_name)
|
||||
img = self._load_all_image_tiles(tile)
|
||||
labels, field_ids = self._load_label_tile(tile)
|
||||
|
||||
img = img[:, :, y : y + self.chip_size, x : x + self.chip_size]
|
||||
labels = labels[y : y + self.chip_size, x : x + self.chip_size]
|
||||
|
@ -213,193 +191,94 @@ class CV4AKenyaCropType(NonGeoDataset):
|
|||
return len(self.chips_metadata)
|
||||
|
||||
@lru_cache(maxsize=128)
|
||||
def _load_label_tile(self, tile_name: str) -> tuple[Tensor, Tensor]:
|
||||
def _load_label_tile(self, tile: str) -> tuple[Tensor, Tensor]:
|
||||
"""Load a single _tile_ of labels and field_ids.
|
||||
|
||||
Args:
|
||||
tile_name: name of tile to load
|
||||
tile: name of tile to load
|
||||
|
||||
Returns:
|
||||
tuple of labels and field ids
|
||||
|
||||
Raises:
|
||||
AssertionError: if ``tile_name`` is invalid
|
||||
"""
|
||||
assert tile_name in self.tile_names
|
||||
directory = os.path.join(self.root, 'data', tile)
|
||||
|
||||
if self.verbose:
|
||||
print(f'Loading labels/field_ids for {tile_name}')
|
||||
|
||||
directory = os.path.join(
|
||||
self.root, 'ref_african_crops_kenya_02_labels', tile_name + '_label'
|
||||
)
|
||||
|
||||
with Image.open(os.path.join(directory, 'labels.tif')) as img:
|
||||
with Image.open(os.path.join(directory, f'{tile}_label.tif')) as img:
|
||||
array: np.typing.NDArray[np.int_] = np.array(img)
|
||||
labels = torch.from_numpy(array)
|
||||
|
||||
with Image.open(os.path.join(directory, 'field_ids.tif')) as img:
|
||||
with Image.open(os.path.join(directory, f'{tile}_field_id.tif')) as img:
|
||||
array = np.array(img)
|
||||
field_ids = torch.from_numpy(array)
|
||||
|
||||
return (labels, field_ids)
|
||||
|
||||
def _validate_bands(self, bands: tuple[str, ...]) -> None:
|
||||
"""Validate list of bands.
|
||||
|
||||
Args:
|
||||
bands: user-provided tuple of bands to load
|
||||
|
||||
Raises:
|
||||
AssertionError: if ``bands`` is not a tuple
|
||||
ValueError: if an invalid band name is provided
|
||||
"""
|
||||
assert isinstance(bands, tuple), 'The list of bands must be a tuple'
|
||||
for band in bands:
|
||||
if band not in self.band_names:
|
||||
raise ValueError(f"'{band}' is an invalid band name.")
|
||||
return labels, field_ids
|
||||
|
||||
@lru_cache(maxsize=128)
|
||||
def _load_all_image_tiles(
|
||||
self, tile_name: str, bands: tuple[str, ...] = band_names
|
||||
) -> Tensor:
|
||||
def _load_all_image_tiles(self, tile: str) -> Tensor:
|
||||
"""Load all the imagery (across time) for a single _tile_.
|
||||
|
||||
Optionally allows for subsetting of the bands that are loaded.
|
||||
|
||||
Args:
|
||||
tile_name: name of tile to load
|
||||
bands: tuple of bands to load
|
||||
tile: name of tile to load
|
||||
|
||||
Returns:
|
||||
imagery of shape (13, number of bands, 3035, 2016) where 13 is the number of
|
||||
points in time, 3035 is the tile height, and 2016 is the tile width
|
||||
|
||||
Raises:
|
||||
AssertionError: if ``tile_name`` is invalid
|
||||
points in time, 3035 is the tile height, and 2016 is the tile width
|
||||
"""
|
||||
assert tile_name in self.tile_names
|
||||
|
||||
if self.verbose:
|
||||
print(f'Loading all imagery for {tile_name}')
|
||||
|
||||
img = torch.zeros(
|
||||
len(self.dates),
|
||||
len(bands),
|
||||
len(self.bands),
|
||||
self.tile_height,
|
||||
self.tile_width,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
for date_index, date in enumerate(self.dates):
|
||||
img[date_index] = self._load_single_image_tile(tile_name, date, self.bands)
|
||||
img[date_index] = self._load_single_image_tile(tile, date)
|
||||
|
||||
return img
|
||||
|
||||
@lru_cache(maxsize=128)
|
||||
def _load_single_image_tile(
|
||||
self, tile_name: str, date: str, bands: tuple[str, ...]
|
||||
) -> Tensor:
|
||||
def _load_single_image_tile(self, tile: str, date: str) -> Tensor:
|
||||
"""Load the imagery for a single tile for a single date.
|
||||
|
||||
Optionally allows for subsetting of the bands that are loaded.
|
||||
|
||||
Args:
|
||||
tile_name: name of tile to load
|
||||
tile: name of tile to load
|
||||
date: date of tile to load
|
||||
bands: bands to load
|
||||
|
||||
Returns:
|
||||
array containing a single image tile
|
||||
|
||||
Raises:
|
||||
AssertionError: if ``tile_name`` or ``date`` is invalid
|
||||
"""
|
||||
assert tile_name in self.tile_names
|
||||
assert date in self.dates
|
||||
|
||||
if self.verbose:
|
||||
print(f'Loading imagery for {tile_name} at {date}')
|
||||
|
||||
directory = os.path.join(self.root, 'data', tile, date)
|
||||
img = torch.zeros(
|
||||
len(bands), self.tile_height, self.tile_width, dtype=torch.float32
|
||||
len(self.bands), self.tile_height, self.tile_width, dtype=torch.float32
|
||||
)
|
||||
for band_index, band_name in enumerate(self.bands):
|
||||
filepath = os.path.join(
|
||||
self.root,
|
||||
'ref_african_crops_kenya_02_source',
|
||||
f'{tile_name}_{date}',
|
||||
f'{band_name}.tif',
|
||||
)
|
||||
filepath = os.path.join(directory, f'{tile}_{band_name}_{date}.tif')
|
||||
with Image.open(filepath) as band_img:
|
||||
array: np.typing.NDArray[np.int_] = np.array(band_img)
|
||||
img[band_index] = torch.from_numpy(array)
|
||||
|
||||
return img
|
||||
|
||||
def _check_integrity(self) -> bool:
    """Check the image and target archives under ``self.root``.

    The MD5 of each archive is only compared when ``self.checksum`` is set.

    Returns:
        True if dataset files are found and/or MD5s match, else False
    """
    results = []
    # Both archives are always checked, images first, then targets.
    for meta in (self.image_meta, self.target_meta):
        archive = os.path.join(self.root, meta['filename'])
        expected_md5 = meta['md5'] if self.checksum else None
        results.append(check_integrity(archive, expected_md5))

    images_ok, targets_ok = results
    return images_ok and targets_ok
|
||||
|
||||
def get_splits(self) -> tuple[list[int], list[int]]:
|
||||
"""Get the field_ids for the train/test splits from the dataset directory.
|
||||
|
||||
Returns:
|
||||
list of training field_ids and list of testing field_ids
|
||||
"""
|
||||
train_field_ids = []
|
||||
test_field_ids = []
|
||||
splits_fn = os.path.join(
|
||||
self.root,
|
||||
'ref_african_crops_kenya_02_labels',
|
||||
'_common',
|
||||
'field_train_test_ids.csv',
|
||||
)
|
||||
|
||||
with open(splits_fn, newline='') as f:
|
||||
reader = csv.reader(f)
|
||||
|
||||
# Skip header row
|
||||
next(reader)
|
||||
|
||||
for row in reader:
|
||||
train_field_ids.append(int(row[0]))
|
||||
if row[1]:
|
||||
test_field_ids.append(int(row[1]))
|
||||
|
||||
return train_field_ids, test_field_ids
|
||||
|
||||
def _download(self, api_key: str | None = None) -> None:
|
||||
"""Download the dataset and extract it.
|
||||
|
||||
Args:
|
||||
api_key: a RadiantEarth MLHub API key to use for downloading the dataset
|
||||
"""
|
||||
if self._check_integrity():
|
||||
print('Files already downloaded and verified')
|
||||
def _verify(self) -> None:
|
||||
"""Verify the integrity of the dataset."""
|
||||
# Check if the files already exist
|
||||
if os.path.exists(os.path.join(self.root, 'FieldIds.csv')):
|
||||
return
|
||||
|
||||
for collection_id in self.collection_ids:
|
||||
download_radiant_mlhub_collection(collection_id, self.root, api_key)
|
||||
# Check if the user requested to download the dataset
|
||||
if not self.download:
|
||||
raise DatasetNotFoundError(self)
|
||||
|
||||
image_archive_path = os.path.join(self.root, self.image_meta['filename'])
|
||||
target_archive_path = os.path.join(self.root, self.target_meta['filename'])
|
||||
for fn in [image_archive_path, target_archive_path]:
|
||||
extract_archive(fn, self.root)
|
||||
# Download the dataset
|
||||
self._download()
|
||||
|
||||
def _download(self) -> None:
    """Sync the dataset from ``self.url`` into ``self.root`` using azcopy."""
    os.makedirs(self.root, exist_ok=True)
    executable = which('azcopy')
    sync_args = ('sync', self.url, self.root, '--recursive=true')
    executable(*sync_args)
|
||||
|
||||
def plot(
|
||||
self,
|
||||
|
@ -439,13 +318,7 @@ class CV4AKenyaCropType(NonGeoDataset):
|
|||
|
||||
image, mask = sample['image'], sample['mask']
|
||||
|
||||
assert time_step <= image.shape[0] - 1, (
|
||||
'The specified time step'
|
||||
f' does not exist, image only contains {image.shape[0]} time'
|
||||
' instances.'
|
||||
)
|
||||
|
||||
image = image[time_step, rgb_indices, :, :]
|
||||
image = image[time_step, rgb_indices]
|
||||
|
||||
fig, axs = plt.subplots(nrows=1, ncols=n_cols, figsize=(10, n_cols * 5))
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче