From 32aa3492ca737fa0f90ea7bfe885753796512d4e Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Wed, 10 Jul 2024 17:39:31 +0200 Subject: [PATCH] CV4A Kenya Crop Type: radiant mlhub -> source cooperative (#2090) --- tests/data/cv4a_kenya_crop_type/FieldIds.csv | 5 + tests/data/cv4a_kenya_crop_type/data.py | 51 ++++ .../data/0/0_field_id.tif | Bin 0 -> 150 bytes .../cv4a_kenya_crop_type/data/0/0_label.tif | Bin 0 -> 126 bytes .../data/0/20190606/0_B01_20190606.tif | Bin 0 -> 150 bytes .../data/0/20190606/0_B02_20190606.tif | Bin 0 -> 150 bytes .../data/0/20190606/0_B03_20190606.tif | Bin 0 -> 150 bytes .../data/0/20190606/0_B04_20190606.tif | Bin 0 -> 150 bytes .../data/0/20190606/0_B05_20190606.tif | Bin 0 -> 150 bytes .../data/0/20190606/0_B06_20190606.tif | Bin 0 -> 150 bytes .../data/0/20190606/0_B07_20190606.tif | Bin 0 -> 150 bytes .../data/0/20190606/0_B08_20190606.tif | Bin 0 -> 150 bytes .../data/0/20190606/0_B09_20190606.tif | Bin 0 -> 150 bytes .../data/0/20190606/0_B11_20190606.tif | Bin 0 -> 150 bytes .../data/0/20190606/0_B12_20190606.tif | Bin 0 -> 150 bytes .../data/0/20190606/0_B8A_20190606.tif | Bin 0 -> 150 bytes .../data/0/20190606/0_CLD_20190606.tif | Bin 0 -> 150 bytes .../ref_african_crops_kenya_02_labels.tar.gz | Bin 603 -> 0 bytes .../ref_african_crops_kenya_02_source.tar.gz | Bin 638 -> 0 bytes tests/datasets/test_cv4a_kenya_crop_type.py | 85 ++---- torchgeo/datasets/cv4a_kenya_crop_type.py | 247 +++++------------- 21 files changed, 134 insertions(+), 254 deletions(-) create mode 100644 tests/data/cv4a_kenya_crop_type/FieldIds.csv create mode 100755 tests/data/cv4a_kenya_crop_type/data.py create mode 100644 tests/data/cv4a_kenya_crop_type/data/0/0_field_id.tif create mode 100644 tests/data/cv4a_kenya_crop_type/data/0/0_label.tif create mode 100644 tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B01_20190606.tif create mode 100644 tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B02_20190606.tif create mode 100644 tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B03_20190606.tif create mode 100644 tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B04_20190606.tif create mode 100644 tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B05_20190606.tif create mode 100644 tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B06_20190606.tif create mode 100644 tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B07_20190606.tif create mode 100644 tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B08_20190606.tif create mode 100644 tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B09_20190606.tif create mode 100644 tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B11_20190606.tif create mode 100644 tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B12_20190606.tif create mode 100644 tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B8A_20190606.tif create mode 100644 tests/data/cv4a_kenya_crop_type/data/0/20190606/0_CLD_20190606.tif delete mode 100644 tests/data/ref_african_crops_kenya_02/ref_african_crops_kenya_02_labels.tar.gz delete mode 100644 tests/data/ref_african_crops_kenya_02/ref_african_crops_kenya_02_source.tar.gz diff --git a/tests/data/cv4a_kenya_crop_type/FieldIds.csv b/tests/data/cv4a_kenya_crop_type/FieldIds.csv new file mode 100644 index 000000000..04ff33b25 --- /dev/null +++ b/tests/data/cv4a_kenya_crop_type/FieldIds.csv @@ -0,0 +1,5 @@ +train,test +1,2 +3,4 +5 +6 diff --git a/tests/data/cv4a_kenya_crop_type/data.py b/tests/data/cv4a_kenya_crop_type/data.py new file mode 100755 index 000000000..e55ffa45b --- /dev/null +++ b/tests/data/cv4a_kenya_crop_type/data.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os + +import numpy as np +from PIL import Image + +DTYPE = np.float32 +SIZE = 2 + +np.random.seed(0) + +all_bands = ( + 'B01', + 'B02', + 'B03', + 'B04', + 'B05', + 'B06', + 'B07', + 'B08', + 'B8A', + 'B09', + 'B11', + 'B12', + 'CLD', +) + +for tile in range(1): + directory = os.path.join('data', str(tile)) + os.makedirs(directory, exist_ok=True) + + arr = np.random.randint(np.iinfo(np.int32).max, size=(SIZE, SIZE), dtype=np.int32) + img = Image.fromarray(arr) + img.save(os.path.join(directory, f'{tile}_field_id.tif')) + + arr = np.random.randint(np.iinfo(np.uint8).max, size=(SIZE, SIZE), dtype=np.uint8) + img = Image.fromarray(arr) + img.save(os.path.join(directory, f'{tile}_label.tif')) + + for date in ['20190606']: + directory = os.path.join(directory, date) + os.makedirs(directory, exist_ok=True) + + for band in all_bands: + arr = np.random.rand(SIZE, SIZE).astype(DTYPE) * np.finfo(DTYPE).max + img = Image.fromarray(arr) + img.save(os.path.join(directory, f'{tile}_{band}_{date}.tif')) diff --git a/tests/data/cv4a_kenya_crop_type/data/0/0_field_id.tif b/tests/data/cv4a_kenya_crop_type/data/0/0_field_id.tif new file mode 100644 index 0000000000000000000000000000000000000000..f72a67720917d26126c56811efe1b39489778170 GIT binary patch literal 150 zcmebD)MDUZU|`^4U|?inU<9(5fS3`9&BVwI7FPg@Geg-Rb!}UMsVsSjXCuLl4w7YaQ9 literal 0 HcmV?d00001 diff --git a/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B11_20190606.tif b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B11_20190606.tif new file mode 100644 index 0000000000000000000000000000000000000000..ea661cbc40e8dd2c86bef11f807269d4f930e1c6 GIT binary patch literal 150 zcmebD)MDUZU|`^4U|?inU<9(5fS3`9&BVwI7FPg@Geg-Rb!fIsb>DRy>QBy+uAi2(tPTJ%Nea>c literal 0 HcmV?d00001 diff --git a/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B8A_20190606.tif b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B8A_20190606.tif new file mode 100644 index 0000000000000000000000000000000000000000..1e3f7ce38b80ede98a370844a4122004aa4ab9c4 GIT binary patch literal 150 zcmebD)MDUZU|`^4U|?inU<9(5fS3`9&BVwI7FPg@Geg-Rb!tbCb>&#aq*FAiAvz!rQ>j5cX326WT literal 0 HcmV?d00001 diff --git a/tests/data/ref_african_crops_kenya_02/ref_african_crops_kenya_02_labels.tar.gz b/tests/data/ref_african_crops_kenya_02/ref_african_crops_kenya_02_labels.tar.gz deleted file mode 100644 index 1c642bf9c73bc230e09d947694c95b88b5201058..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 603 zcmV-h0;K&PiwFP|z0_a;1MQZ-Zxb;T$9+jvQd5aC01HTw*pMmydCpBOY@K0XJ00mo zIyJeVyK5^XCZsBz7#LVs5i9M#z+b^%z}!E9=SzlDM0<%6f=b^nRpDC9EiM{nYO{){XP2dv@F3Two~>M}lo2+e&Xl&>_5OKt=w zZL|2)^&I9!l(pNWcpBwJd>m!PpY!U*uPIrr{$Y1G7!1<|zJs~+w-WT1LMzmN70j-G zFNyp6VUa~i8WwS0gvoxs-OZ2c=Wqi1`f~lHv{U+94(Vb?~CF(!U-!78g)Esr^KRo)T^oJWk;rYJ`_V%8T7T7z4cunGgSGJ1ZnD{QM zS&6S#{JY*-X}>^suM)ps@gGcbTzgoFzpVH_Pd3RG*dNfokD8B8mXOmy@%ho$Y!}?!Z?9fIKTHqO;c@z&cvnumB$inJ;|HR96;2%ibJf35uKwM1 pf%U%vu>Ke0y4OE~AP9mW2!bF8f*=TjAPDk*@*B~w9%=w6005g4F4+J8 diff --git a/tests/data/ref_african_crops_kenya_02/ref_african_crops_kenya_02_source.tar.gz b/tests/data/ref_african_crops_kenya_02/ref_african_crops_kenya_02_source.tar.gz deleted file mode 100644 index f5e0e289137c0f08dff05399b56428412cbdb85b..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 638 zcmV-^0)hP>iwFR9xYS?(1MQhjYZE~j$Dd6~O|f7-s27ERH$k13o!R!(7rhj`>A~Bw zgakr0kN!$0+W4ktyHce_;F&n89xD9ejKkDJ$jNy+Ed|AG<4@I^lrB2_|e2`lV5|CPU~ zzg3om7D`IgKLpFI|DM=)gt!6!dHwhI?@`Ca zYeZ3-BA5B2mba?B9M!J1uh#PG(Tb0s+{cYt`^{Q@s~S@tw`=WBYx(!L>$KtGH+OuU zwyEplp6fW7u2ak9Z(M#lZPA*`KScHGINM(AdKVqHC}-P||K(SRR?2H=QRHLfZ&*Y> zHlohyw7hq#{`>p8Yn_v6*?-kXqIWNgk9R)jJ1*V%^6dHU;pkvAJ{rBI=)w~(iDmnr zu>Dsv|26y%0sQ9y|C|1Aas8Lp`)?Hd4*~q=0sou+Tc-aesr--Oe+b||5BOjA|Iw4j z-1*->|B+f@{T~GQ&jbE9=KsxI_WED?`QIA&9|HK#1O7Msmx?|A>)G=^?EixR|9Qaw z=Kf!D_y5)VKaKZ)L4f}};D6J9?)-0L_5NSN{}8}`9`L{Ezh?WN%>2jm{~*AB9`L`p z|L3m%trD~KzlHxHfd4$;e{=qix%0nLSpNqE{_}wUP5%wof330p4+#9{aoPO`00000 Y00000000000KmV;PjU1X7ywWJ0Qzxb;s5{u diff --git a/tests/datasets/test_cv4a_kenya_crop_type.py b/tests/datasets/test_cv4a_kenya_crop_type.py index ad0e26ed0..34f67036d 100644 --- a/tests/datasets/test_cv4a_kenya_crop_type.py +++ b/tests/datasets/test_cv4a_kenya_crop_type.py @@ -1,9 +1,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import glob import os -import shutil from pathlib import Path import matplotlib.pyplot as plt @@ -18,44 +16,23 @@ from torchgeo.datasets import ( DatasetNotFoundError, RGBBandsMissingError, ) - - -class Collection: - def download(self, output_dir: str, **kwargs: str) -> None: - glob_path = os.path.join( - 'tests', 'data', 'ref_african_crops_kenya_02', '*.tar.gz' - ) - for tarball in glob.iglob(glob_path): - shutil.copy(tarball, output_dir) - - -def fetch(dataset_id: str, **kwargs: str) -> Collection: - return Collection() +from torchgeo.datasets.utils import Executable class TestCV4AKenyaCropType: @pytest.fixture - def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> CV4AKenyaCropType: - radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3') - monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch) - source_md5 = '7f4dcb3f33743dddd73f453176308bfb' - labels_md5 = '95fc59f1d94a85ec00931d4d1280bec9' - monkeypatch.setitem(CV4AKenyaCropType.image_meta, 'md5', source_md5) - monkeypatch.setitem(CV4AKenyaCropType.target_meta, 'md5', labels_md5) - monkeypatch.setattr( - CV4AKenyaCropType, 'tile_names', ['ref_african_crops_kenya_02_tile_00'] - ) + def dataset( + self, azcopy: Executable, monkeypatch: MonkeyPatch, tmp_path: Path + ) -> CV4AKenyaCropType: + url = os.path.join('tests', 'data', 'cv4a_kenya_crop_type') + monkeypatch.setattr(CV4AKenyaCropType, 'url', url) + monkeypatch.setattr(CV4AKenyaCropType, 'tiles', list(map(str, range(1)))) monkeypatch.setattr(CV4AKenyaCropType, 'dates', ['20190606']) + monkeypatch.setattr(CV4AKenyaCropType, 'tile_height', 2) + monkeypatch.setattr(CV4AKenyaCropType, 'tile_width', 2) root = str(tmp_path) transforms = nn.Identity() - return CV4AKenyaCropType( - root, - transforms=transforms, - download=True, - api_key='', - checksum=True, - verbose=True, - ) + return CV4AKenyaCropType(root, transforms=transforms, download=True) def test_getitem(self, dataset: CV4AKenyaCropType) -> None: x = dataset[0] @@ -66,60 +43,34 @@ class TestCV4AKenyaCropType: assert isinstance(x['y'], torch.Tensor) def test_len(self, dataset: CV4AKenyaCropType) -> None: - assert len(dataset) == 345 + assert len(dataset) == 1 def test_add(self, dataset: CV4AKenyaCropType) -> None: ds = dataset + dataset assert isinstance(ds, ConcatDataset) - assert len(ds) == 690 - - def test_get_splits(self, dataset: CV4AKenyaCropType) -> None: - train_field_ids, test_field_ids = dataset.get_splits() - assert isinstance(train_field_ids, list) - assert isinstance(test_field_ids, list) - assert len(train_field_ids) == 18 - assert len(test_field_ids) == 9 - assert 336 in train_field_ids - assert 336 not in test_field_ids - assert 4793 in test_field_ids - assert 4793 not in train_field_ids + assert len(ds) == 2 def test_already_downloaded(self, dataset: CV4AKenyaCropType) -> None: - CV4AKenyaCropType(root=dataset.root, download=True, api_key='') + CV4AKenyaCropType(root=dataset.root, download=True) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): CV4AKenyaCropType(str(tmp_path)) - def test_invalid_tile(self, dataset: CV4AKenyaCropType) -> None: - with pytest.raises(AssertionError): - dataset._load_label_tile('foo') - - with pytest.raises(AssertionError): - dataset._load_all_image_tiles('foo', ('B01', 'B02')) - - with pytest.raises(AssertionError): - dataset._load_single_image_tile('foo', '20190606', ('B01', 'B02')) - def test_invalid_bands(self) -> None: with pytest.raises(AssertionError): - CV4AKenyaCropType(bands=['B01', 'B02']) # type: ignore[arg-type] - - with pytest.raises(ValueError, match='is an invalid band name.'): CV4AKenyaCropType(bands=('foo', 'bar')) def test_plot(self, dataset: CV4AKenyaCropType) -> None: - dataset.plot(dataset[0], time_step=0, suptitle='Test') - plt.close() - sample = dataset[0] + dataset.plot(sample, time_step=0, suptitle='Test') + plt.close() sample['prediction'] = sample['mask'].clone() dataset.plot(sample, time_step=0, suptitle='Pred') plt.close() def test_plot_rgb(self, dataset: CV4AKenyaCropType) -> None: dataset = CV4AKenyaCropType(root=dataset.root, bands=tuple(['B01'])) - with pytest.raises( - RGBBandsMissingError, match='Dataset does not contain some of the RGB bands' - ): - dataset.plot(dataset[0], time_step=0, suptitle='Single Band') + match = 'Dataset does not contain some of the RGB bands' + with pytest.raises(RGBBandsMissingError, match=match): + dataset.plot(dataset[0]) diff --git a/torchgeo/datasets/cv4a_kenya_crop_type.py b/torchgeo/datasets/cv4a_kenya_crop_type.py index a532c1539..feeb6ff0e 100644 --- a/torchgeo/datasets/cv4a_kenya_crop_type.py +++ b/torchgeo/datasets/cv4a_kenya_crop_type.py @@ -3,9 +3,8 @@ """CV4A Kenya Crop Type dataset.""" -import csv import os -from collections.abc import Callable +from collections.abc import Callable, Sequence from functools import lru_cache import matplotlib.pyplot as plt @@ -17,16 +16,23 @@ from torch import Tensor from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import NonGeoDataset -from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive +from .utils import which -# TODO: read geospatial information from stac.json files class CV4AKenyaCropType(NonGeoDataset): - """CV4A Kenya Crop Type dataset. + """CV4A Kenya Crop Type Competition dataset. - Used in a competition in the Computer NonGeo for Agriculture (CV4A) workshop in - ICLR 2020. See `this website `__ - for dataset details. + The `CV4A Kenya Crop Type Competition + `__ + dataset was produced as part of the Crop Type Detection competition at the + Computer Vision for Agriculture (CV4A) Workshop at the ICLR 2020 conference. + The objective of the competition was to create a machine learning model to + classify fields by crop type from images collected during the growing season + by the Sentinel-2 satellites. + + See the `dataset documentation + `__ + for details. Consists of 4 tiles of Sentinel 2 imagery from 13 different points in time. @@ -54,29 +60,12 @@ class CV4AKenyaCropType(NonGeoDataset): This dataset requires the following additional library to be installed: - * `radiant-mlhub `_ to download the - imagery and labels from the Radiant Earth MLHub + * `azcopy `_: to download the + dataset from Source Cooperative. """ - collection_ids = [ - 'ref_african_crops_kenya_02_labels', - 'ref_african_crops_kenya_02_source', - ] - image_meta = { - 'filename': 'ref_african_crops_kenya_02_source.tar.gz', - 'md5': '9c2004782f6dc83abb1bf45ba4d0da46', - } - target_meta = { - 'filename': 'ref_african_crops_kenya_02_labels.tar.gz', - 'md5': '93949abd0ae82ba564f5a933cefd8215', - } - - tile_names = [ - 'ref_african_crops_kenya_02_tile_00', - 'ref_african_crops_kenya_02_tile_01', - 'ref_african_crops_kenya_02_tile_02', - 'ref_african_crops_kenya_02_tile_03', - ] + url = 'https://radiantearth.blob.core.windows.net/mlhub/kenya-crop-challenge' + tiles = list(map(str, range(4))) dates = [ '20190606', '20190701', @@ -92,7 +81,7 @@ class CV4AKenyaCropType(NonGeoDataset): '20191004', '20191103', ] - band_names = ( + all_bands = ( 'B01', 'B02', 'B03', @@ -107,7 +96,6 @@ class CV4AKenyaCropType(NonGeoDataset): 'B12', 'CLD', ) - rgb_bands = ['B04', 'B03', 'B02'] # Same for all tiles @@ -119,12 +107,9 @@ class CV4AKenyaCropType(NonGeoDataset): root: str = 'data', chip_size: int = 256, stride: int = 128, - bands: tuple[str, ...] = band_names, + bands: Sequence[str] = all_bands, transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, - api_key: str | None = None, - checksum: bool = False, - verbose: bool = False, ) -> None: """Initialize a new CV4A Kenya Crop Type Dataset instance. @@ -137,32 +122,25 @@ class CV4AKenyaCropType(NonGeoDataset): transforms: a function/transform that takes input sample and its target as entry and returns a transformed version download: if True, download dataset and store it in the root directory - api_key: a RadiantEarth MLHub API key to use for downloading the dataset - checksum: if True, check the MD5 of the downloaded files (may be slow) - verbose: if True, print messages when new tiles are loaded Raises: + AssertionError: If *bands* are invalid. DatasetNotFoundError: If dataset is not found and *download* is False. """ - self._validate_bands(bands) + assert set(bands) <= set(self.all_bands) self.root = root self.chip_size = chip_size self.stride = stride self.bands = bands self.transforms = transforms - self.checksum = checksum - self.verbose = verbose + self.download = download - if download: - self._download(api_key) - - if not self._check_integrity(): - raise DatasetNotFoundError(self) + self._verify() # Calculate the indices that we will use over all tiles self.chips_metadata = [] - for tile_index in range(len(self.tile_names)): + for tile_index in range(len(self.tiles)): for y in list(range(0, self.tile_height - self.chip_size, stride)) + [ self.tile_height - self.chip_size ]: @@ -181,10 +159,10 @@ class CV4AKenyaCropType(NonGeoDataset): data, labels, field ids, and metadata at that index """ tile_index, y, x = self.chips_metadata[index] - tile_name = self.tile_names[tile_index] + tile = self.tiles[tile_index] - img = self._load_all_image_tiles(tile_name, self.bands) - labels, field_ids = self._load_label_tile(tile_name) + img = self._load_all_image_tiles(tile) + labels, field_ids = self._load_label_tile(tile) img = img[:, :, y : y + self.chip_size, x : x + self.chip_size] labels = labels[y : y + self.chip_size, x : x + self.chip_size] @@ -213,193 +191,94 @@ class CV4AKenyaCropType(NonGeoDataset): return len(self.chips_metadata) @lru_cache(maxsize=128) - def _load_label_tile(self, tile_name: str) -> tuple[Tensor, Tensor]: + def _load_label_tile(self, tile: str) -> tuple[Tensor, Tensor]: """Load a single _tile_ of labels and field_ids. Args: - tile_name: name of tile to load + tile: name of tile to load Returns: tuple of labels and field ids - - Raises: - AssertionError: if ``tile_name`` is invalid """ - assert tile_name in self.tile_names + directory = os.path.join(self.root, 'data', tile) - if self.verbose: - print(f'Loading labels/field_ids for {tile_name}') - - directory = os.path.join( - self.root, 'ref_african_crops_kenya_02_labels', tile_name + '_label' - ) - - with Image.open(os.path.join(directory, 'labels.tif')) as img: + with Image.open(os.path.join(directory, f'{tile}_label.tif')) as img: array: np.typing.NDArray[np.int_] = np.array(img) labels = torch.from_numpy(array) - with Image.open(os.path.join(directory, 'field_ids.tif')) as img: + with Image.open(os.path.join(directory, f'{tile}_field_id.tif')) as img: array = np.array(img) field_ids = torch.from_numpy(array) - return (labels, field_ids) - - def _validate_bands(self, bands: tuple[str, ...]) -> None: - """Validate list of bands. - - Args: - bands: user-provided tuple of bands to load - - Raises: - AssertionError: if ``bands`` is not a tuple - ValueError: if an invalid band name is provided - """ - assert isinstance(bands, tuple), 'The list of bands must be a tuple' - for band in bands: - if band not in self.band_names: - raise ValueError(f"'{band}' is an invalid band name.") + return labels, field_ids @lru_cache(maxsize=128) - def _load_all_image_tiles( - self, tile_name: str, bands: tuple[str, ...] = band_names - ) -> Tensor: + def _load_all_image_tiles(self, tile: str) -> Tensor: """Load all the imagery (across time) for a single _tile_. Optionally allows for subsetting of the bands that are loaded. Args: - tile_name: name of tile to load - bands: tuple of bands to load + tile: name of tile to load Returns: imagery of shape (13, number of bands, 3035, 2016) where 13 is the number of - points in time, 3035 is the tile height, and 2016 is the tile width - - Raises: - AssertionError: if ``tile_name`` is invalid + points in time, 3035 is the tile height, and 2016 is the tile width """ - assert tile_name in self.tile_names - - if self.verbose: - print(f'Loading all imagery for {tile_name}') - img = torch.zeros( len(self.dates), - len(bands), + len(self.bands), self.tile_height, self.tile_width, dtype=torch.float32, ) for date_index, date in enumerate(self.dates): - img[date_index] = self._load_single_image_tile(tile_name, date, self.bands) + img[date_index] = self._load_single_image_tile(tile, date) return img @lru_cache(maxsize=128) - def _load_single_image_tile( - self, tile_name: str, date: str, bands: tuple[str, ...] - ) -> Tensor: + def _load_single_image_tile(self, tile: str, date: str) -> Tensor: """Load the imagery for a single tile for a single date. - Optionally allows for subsetting of the bands that are loaded. - Args: - tile_name: name of tile to load + tile: name of tile to load date: date of tile to load - bands: bands to load Returns: array containing a single image tile - - Raises: - AssertionError: if ``tile_name`` or ``date`` is invalid """ - assert tile_name in self.tile_names - assert date in self.dates - - if self.verbose: - print(f'Loading imagery for {tile_name} at {date}') - + directory = os.path.join(self.root, 'data', tile, date) img = torch.zeros( - len(bands), self.tile_height, self.tile_width, dtype=torch.float32 + len(self.bands), self.tile_height, self.tile_width, dtype=torch.float32 ) for band_index, band_name in enumerate(self.bands): - filepath = os.path.join( - self.root, - 'ref_african_crops_kenya_02_source', - f'{tile_name}_{date}', - f'{band_name}.tif', - ) + filepath = os.path.join(directory, f'{tile}_{band_name}_{date}.tif') with Image.open(filepath) as band_img: array: np.typing.NDArray[np.int_] = np.array(band_img) img[band_index] = torch.from_numpy(array) return img - def _check_integrity(self) -> bool: - """Check integrity of dataset. - - Returns: - True if dataset files are found and/or MD5s match, else False - """ - images: bool = check_integrity( - os.path.join(self.root, self.image_meta['filename']), - self.image_meta['md5'] if self.checksum else None, - ) - - targets: bool = check_integrity( - os.path.join(self.root, self.target_meta['filename']), - self.target_meta['md5'] if self.checksum else None, - ) - - return images and targets - - def get_splits(self) -> tuple[list[int], list[int]]: - """Get the field_ids for the train/test splits from the dataset directory. - - Returns: - list of training field_ids and list of testing field_ids - """ - train_field_ids = [] - test_field_ids = [] - splits_fn = os.path.join( - self.root, - 'ref_african_crops_kenya_02_labels', - '_common', - 'field_train_test_ids.csv', - ) - - with open(splits_fn, newline='') as f: - reader = csv.reader(f) - - # Skip header row - next(reader) - - for row in reader: - train_field_ids.append(int(row[0])) - if row[1]: - test_field_ids.append(int(row[1])) - - return train_field_ids, test_field_ids - - def _download(self, api_key: str | None = None) -> None: - """Download the dataset and extract it. - - Args: - api_key: a RadiantEarth MLHub API key to use for downloading the dataset - """ - if self._check_integrity(): - print('Files already downloaded and verified') + def _verify(self) -> None: + """Verify the integrity of the dataset.""" + # Check if the files already exist + if os.path.exists(os.path.join(self.root, 'FieldIds.csv')): return - for collection_id in self.collection_ids: - download_radiant_mlhub_collection(collection_id, self.root, api_key) + # Check if the user requested to download the dataset + if not self.download: + raise DatasetNotFoundError(self) - image_archive_path = os.path.join(self.root, self.image_meta['filename']) - target_archive_path = os.path.join(self.root, self.target_meta['filename']) - for fn in [image_archive_path, target_archive_path]: - extract_archive(fn, self.root) + # Download the dataset + self._download() + + def _download(self) -> None: + """Download the dataset.""" + os.makedirs(self.root, exist_ok=True) + azcopy = which('azcopy') + azcopy('sync', self.url, self.root, '--recursive=true') def plot( self, @@ -439,13 +318,7 @@ class CV4AKenyaCropType(NonGeoDataset): image, mask = sample['image'], sample['mask'] - assert time_step <= image.shape[0] - 1, ( - 'The specified time step' - f' does not exist, image only contains {image.shape[0]} time' - ' instances.' - ) - - image = image[time_step, rgb_indices, :, :] + image = image[time_step, rgb_indices] fig, axs = plt.subplots(nrows=1, ncols=n_cols, figsize=(10, n_cols * 5))