CV4A Kenya Crop Type: radiant mlhub -> source cooperative (#2090)

Adam J. Stewart 2024-07-10 17:39:31 +02:00 committed by GitHub
Parent 83cad6017c
Commit 32aa3492ca
No known key found for this signature
GPG key ID: B5690EEEBB952194
21 changed files with 134 additions and 254 deletions
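For orientation: this change replaces the radiant_mlhub download path with an azcopy sync from Source Cooperative, while the Python API stays the same apart from the dropped api_key, checksum, and verbose arguments. A minimal usage sketch (illustrative only; assumes torchgeo with this patch installed and the azcopy CLI on PATH):

from torchgeo.datasets import CV4AKenyaCropType

# First use downloads the tiles via `azcopy sync <url> <root> --recursive=true`.
ds = CV4AKenyaCropType(root='data/cv4a', bands=('B04', 'B03', 'B02'), download=True)
sample = ds[0]                # dict holding image, mask, field ids, and chip metadata
print(sample['image'].shape)  # (num_dates, num_bands, chip_size, chip_size)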

View file

@@ -0,0 +1,5 @@
train,test
1,2
3,4
5
6
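The fixture above mimics FieldIds.csv: the first column holds training field ids and the second, optionally empty, column holds test field ids. A minimal parsing sketch (illustrative only; the dataset class itself no longer exposes a get_splits helper after this commit):

import csv

train_ids, test_ids = [], []
with open('FieldIds.csv', newline='') as f:
    reader = csv.reader(f)
    next(reader)  # skip the train,test header row
    for row in reader:
        if row and row[0]:
            train_ids.append(int(row[0]))
        if len(row) > 1 and row[1]:
            test_ids.append(int(row[1]))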

View file

@@ -0,0 +1,51 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os

import numpy as np
from PIL import Image

DTYPE = np.float32
SIZE = 2

np.random.seed(0)

all_bands = (
    'B01',
    'B02',
    'B03',
    'B04',
    'B05',
    'B06',
    'B07',
    'B08',
    'B8A',
    'B09',
    'B11',
    'B12',
    'CLD',
)

for tile in range(1):
    directory = os.path.join('data', str(tile))
    os.makedirs(directory, exist_ok=True)

    arr = np.random.randint(np.iinfo(np.int32).max, size=(SIZE, SIZE), dtype=np.int32)
    img = Image.fromarray(arr)
    img.save(os.path.join(directory, f'{tile}_field_id.tif'))

    arr = np.random.randint(np.iinfo(np.uint8).max, size=(SIZE, SIZE), dtype=np.uint8)
    img = Image.fromarray(arr)
    img.save(os.path.join(directory, f'{tile}_label.tif'))

    for date in ['20190606']:
        directory = os.path.join(directory, date)
        os.makedirs(directory, exist_ok=True)

        for band in all_bands:
            arr = np.random.rand(SIZE, SIZE).astype(DTYPE) * np.finfo(DTYPE).max
            img = Image.fromarray(arr)
            img.save(os.path.join(directory, f'{tile}_{band}_{date}.tif'))
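When run from tests/data/cv4a_kenya_crop_type (path assumed from the fixture files listed below), this script writes the dummy rasters that the updated test fixture points azcopy at. A quick, purely illustrative check of the resulting layout:

import glob

# One field-id raster, one label raster, and 13 band rasters for the single date 20190606.
print(sorted(glob.glob('data/0/*.tif')))        # ['data/0/0_field_id.tif', 'data/0/0_label.tif']
print(len(glob.glob('data/0/20190606/*.tif')))  # 13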

Binary data
tests/data/cv4a_kenya_crop_type/data/0/0_field_id.tif Normal file

Binary file not shown.

Binary data
tests/data/cv4a_kenya_crop_type/data/0/0_label.tif Normal file

Binary file not shown.

Binary data
tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B01_20190606.tif Normal file

Binary file not shown.

Binary data
tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B02_20190606.tif Normal file

Binary file not shown.

Binary data
tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B03_20190606.tif Normal file

Binary file not shown.

Binary data
tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B04_20190606.tif Normal file

Binary file not shown.

Binary data
tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B05_20190606.tif Normal file

Binary file not shown.

Binary data
tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B06_20190606.tif Normal file

Binary file not shown.

Binary data
tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B07_20190606.tif Normal file

Binary file not shown.

Binary data
tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B08_20190606.tif Normal file

Binary file not shown.

Binary data
tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B09_20190606.tif Normal file

Binary file not shown.

Binary data
tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B11_20190606.tif Normal file

Binary file not shown.

Binary data
tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B12_20190606.tif Normal file

Binary file not shown.

Binary data
tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B8A_20190606.tif Normal file

Binary file not shown.

Binary data
tests/data/cv4a_kenya_crop_type/data/0/20190606/0_CLD_20190606.tif Normal file

Binary file not shown.

Binary file not shown.

Binary file not shown.

View file

@@ -1,9 +1,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import glob
import os
import shutil
from pathlib import Path
import matplotlib.pyplot as plt
@@ -18,44 +16,23 @@ from torchgeo.datasets import (
    DatasetNotFoundError,
    RGBBandsMissingError,
)
class Collection:
    def download(self, output_dir: str, **kwargs: str) -> None:
        glob_path = os.path.join(
            'tests', 'data', 'ref_african_crops_kenya_02', '*.tar.gz'
        )
        for tarball in glob.iglob(glob_path):
            shutil.copy(tarball, output_dir)
def fetch(dataset_id: str, **kwargs: str) -> Collection:
    return Collection()
from torchgeo.datasets.utils import Executable
class TestCV4AKenyaCropType:
    @pytest.fixture
    def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> CV4AKenyaCropType:
        radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3')
        monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch)
        source_md5 = '7f4dcb3f33743dddd73f453176308bfb'
        labels_md5 = '95fc59f1d94a85ec00931d4d1280bec9'
        monkeypatch.setitem(CV4AKenyaCropType.image_meta, 'md5', source_md5)
        monkeypatch.setitem(CV4AKenyaCropType.target_meta, 'md5', labels_md5)
        monkeypatch.setattr(
            CV4AKenyaCropType, 'tile_names', ['ref_african_crops_kenya_02_tile_00']
        )
    def dataset(
        self, azcopy: Executable, monkeypatch: MonkeyPatch, tmp_path: Path
    ) -> CV4AKenyaCropType:
        url = os.path.join('tests', 'data', 'cv4a_kenya_crop_type')
        monkeypatch.setattr(CV4AKenyaCropType, 'url', url)
        monkeypatch.setattr(CV4AKenyaCropType, 'tiles', list(map(str, range(1))))
        monkeypatch.setattr(CV4AKenyaCropType, 'dates', ['20190606'])
        monkeypatch.setattr(CV4AKenyaCropType, 'tile_height', 2)
        monkeypatch.setattr(CV4AKenyaCropType, 'tile_width', 2)
        root = str(tmp_path)
        transforms = nn.Identity()
        return CV4AKenyaCropType(
            root,
            transforms=transforms,
            download=True,
            api_key='',
            checksum=True,
            verbose=True,
        )
        return CV4AKenyaCropType(root, transforms=transforms, download=True)
    def test_getitem(self, dataset: CV4AKenyaCropType) -> None:
        x = dataset[0]
@@ -66,60 +43,34 @@ class TestCV4AKenyaCropType:
        assert isinstance(x['y'], torch.Tensor)
    def test_len(self, dataset: CV4AKenyaCropType) -> None:
        assert len(dataset) == 345
        assert len(dataset) == 1
    def test_add(self, dataset: CV4AKenyaCropType) -> None:
        ds = dataset + dataset
        assert isinstance(ds, ConcatDataset)
        assert len(ds) == 690
    def test_get_splits(self, dataset: CV4AKenyaCropType) -> None:
        train_field_ids, test_field_ids = dataset.get_splits()
        assert isinstance(train_field_ids, list)
        assert isinstance(test_field_ids, list)
        assert len(train_field_ids) == 18
        assert len(test_field_ids) == 9
        assert 336 in train_field_ids
        assert 336 not in test_field_ids
        assert 4793 in test_field_ids
        assert 4793 not in train_field_ids
        assert len(ds) == 2
    def test_already_downloaded(self, dataset: CV4AKenyaCropType) -> None:
        CV4AKenyaCropType(root=dataset.root, download=True, api_key='')
        CV4AKenyaCropType(root=dataset.root, download=True)
    def test_not_downloaded(self, tmp_path: Path) -> None:
        with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
            CV4AKenyaCropType(str(tmp_path))
    def test_invalid_tile(self, dataset: CV4AKenyaCropType) -> None:
        with pytest.raises(AssertionError):
            dataset._load_label_tile('foo')
        with pytest.raises(AssertionError):
            dataset._load_all_image_tiles('foo', ('B01', 'B02'))
        with pytest.raises(AssertionError):
            dataset._load_single_image_tile('foo', '20190606', ('B01', 'B02'))
    def test_invalid_bands(self) -> None:
        with pytest.raises(AssertionError):
            CV4AKenyaCropType(bands=['B01', 'B02'])  # type: ignore[arg-type]
        with pytest.raises(ValueError, match='is an invalid band name.'):
            CV4AKenyaCropType(bands=('foo', 'bar'))
    def test_plot(self, dataset: CV4AKenyaCropType) -> None:
        dataset.plot(dataset[0], time_step=0, suptitle='Test')
        plt.close()
        sample = dataset[0]
        dataset.plot(sample, time_step=0, suptitle='Test')
        plt.close()
        sample['prediction'] = sample['mask'].clone()
        dataset.plot(sample, time_step=0, suptitle='Pred')
        plt.close()
    def test_plot_rgb(self, dataset: CV4AKenyaCropType) -> None:
        dataset = CV4AKenyaCropType(root=dataset.root, bands=tuple(['B01']))
        with pytest.raises(
            RGBBandsMissingError, match='Dataset does not contain some of the RGB bands'
        ):
            dataset.plot(dataset[0], time_step=0, suptitle='Single Band')
        match = 'Dataset does not contain some of the RGB bands'
        with pytest.raises(RGBBandsMissingError, match=match):
            dataset.plot(dataset[0])

View file

@@ -3,9 +3,8 @@
"""CV4A Kenya Crop Type dataset."""
import csv
import os
from collections.abc import Callable
from collections.abc import Callable, Sequence
from functools import lru_cache
import matplotlib.pyplot as plt
@@ -17,16 +16,23 @@ from torch import Tensor
from .errors import DatasetNotFoundError, RGBBandsMissingError
from .geo import NonGeoDataset
from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive
from .utils import which
# TODO: read geospatial information from stac.json files
class CV4AKenyaCropType(NonGeoDataset):
"""CV4A Kenya Crop Type dataset.
"""CV4A Kenya Crop Type Competition dataset.
Used in a competition in the Computer NonGeo for Agriculture (CV4A) workshop in
ICLR 2020. See `this website <https://mlhub.earth/10.34911/rdnt.dw605x>`__
for dataset details.
The `CV4A Kenya Crop Type Competition
<https://beta.source.coop/repositories/radiantearth/african-crops-kenya-02/>`__
dataset was produced as part of the Crop Type Detection competition at the
Computer Vision for Agriculture (CV4A) Workshop at the ICLR 2020 conference.
The objective of the competition was to create a machine learning model to
classify fields by crop type from images collected during the growing season
by the Sentinel-2 satellites.
See the `dataset documentation
<https://data.source.coop/radiantearth/african-crops-kenya-02/Documentation.pdf>`__
for details.
Consists of 4 tiles of Sentinel 2 imagery from 13 different points in time.
@@ -54,29 +60,12 @@ class CV4AKenyaCropType(NonGeoDataset):
    This dataset requires the following additional library to be installed:
    * `radiant-mlhub <https://pypi.org/project/radiant-mlhub/>`_ to download the
      imagery and labels from the Radiant Earth MLHub
    * `azcopy <https://github.com/Azure/azure-storage-azcopy>`_: to download the
      dataset from Source Cooperative.
    """
    collection_ids = [
        'ref_african_crops_kenya_02_labels',
        'ref_african_crops_kenya_02_source',
    ]
    image_meta = {
        'filename': 'ref_african_crops_kenya_02_source.tar.gz',
        'md5': '9c2004782f6dc83abb1bf45ba4d0da46',
    }
    target_meta = {
        'filename': 'ref_african_crops_kenya_02_labels.tar.gz',
        'md5': '93949abd0ae82ba564f5a933cefd8215',
    }
    tile_names = [
        'ref_african_crops_kenya_02_tile_00',
        'ref_african_crops_kenya_02_tile_01',
        'ref_african_crops_kenya_02_tile_02',
        'ref_african_crops_kenya_02_tile_03',
    ]
    url = 'https://radiantearth.blob.core.windows.net/mlhub/kenya-crop-challenge'
    tiles = list(map(str, range(4)))
    dates = [
        '20190606',
        '20190701',
@@ -92,7 +81,7 @@ class CV4AKenyaCropType(NonGeoDataset):
        '20191004',
        '20191103',
    ]
    band_names = (
    all_bands = (
        'B01',
        'B02',
        'B03',
@@ -107,7 +96,6 @@ class CV4AKenyaCropType(NonGeoDataset):
        'B12',
        'CLD',
    )
    rgb_bands = ['B04', 'B03', 'B02']
    # Same for all tiles
@@ -119,12 +107,9 @@ class CV4AKenyaCropType(NonGeoDataset):
        root: str = 'data',
        chip_size: int = 256,
        stride: int = 128,
        bands: tuple[str, ...] = band_names,
        bands: Sequence[str] = all_bands,
        transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
        download: bool = False,
        api_key: str | None = None,
        checksum: bool = False,
        verbose: bool = False,
    ) -> None:
        """Initialize a new CV4A Kenya Crop Type Dataset instance.
@@ -137,32 +122,25 @@ class CV4AKenyaCropType(NonGeoDataset):
            transforms: a function/transform that takes input sample and its target as
                entry and returns a transformed version
            download: if True, download dataset and store it in the root directory
            api_key: a RadiantEarth MLHub API key to use for downloading the dataset
            checksum: if True, check the MD5 of the downloaded files (may be slow)
            verbose: if True, print messages when new tiles are loaded
        Raises:
            AssertionError: If *bands* are invalid.
            DatasetNotFoundError: If dataset is not found and *download* is False.
        """
        self._validate_bands(bands)
        assert set(bands) <= set(self.all_bands)
        self.root = root
        self.chip_size = chip_size
        self.stride = stride
        self.bands = bands
        self.transforms = transforms
        self.checksum = checksum
        self.verbose = verbose
        self.download = download
        if download:
            self._download(api_key)
        if not self._check_integrity():
            raise DatasetNotFoundError(self)
        self._verify()
        # Calculate the indices that we will use over all tiles
        self.chips_metadata = []
        for tile_index in range(len(self.tile_names)):
        for tile_index in range(len(self.tiles)):
            for y in list(range(0, self.tile_height - self.chip_size, stride)) + [
                self.tile_height - self.chip_size
            ]:
@@ -181,10 +159,10 @@ class CV4AKenyaCropType(NonGeoDataset):
            data, labels, field ids, and metadata at that index
        """
        tile_index, y, x = self.chips_metadata[index]
        tile_name = self.tile_names[tile_index]
        tile = self.tiles[tile_index]
        img = self._load_all_image_tiles(tile_name, self.bands)
        labels, field_ids = self._load_label_tile(tile_name)
        img = self._load_all_image_tiles(tile)
        labels, field_ids = self._load_label_tile(tile)
        img = img[:, :, y : y + self.chip_size, x : x + self.chip_size]
        labels = labels[y : y + self.chip_size, x : x + self.chip_size]
@@ -213,193 +191,94 @@ class CV4AKenyaCropType(NonGeoDataset):
        return len(self.chips_metadata)
    @lru_cache(maxsize=128)
    def _load_label_tile(self, tile_name: str) -> tuple[Tensor, Tensor]:
    def _load_label_tile(self, tile: str) -> tuple[Tensor, Tensor]:
        """Load a single _tile_ of labels and field_ids.
        Args:
            tile_name: name of tile to load
            tile: name of tile to load
        Returns:
            tuple of labels and field ids
        Raises:
            AssertionError: if ``tile_name`` is invalid
        """
        assert tile_name in self.tile_names
        directory = os.path.join(self.root, 'data', tile)
        if self.verbose:
            print(f'Loading labels/field_ids for {tile_name}')
        directory = os.path.join(
            self.root, 'ref_african_crops_kenya_02_labels', tile_name + '_label'
        )
        with Image.open(os.path.join(directory, 'labels.tif')) as img:
        with Image.open(os.path.join(directory, f'{tile}_label.tif')) as img:
            array: np.typing.NDArray[np.int_] = np.array(img)
            labels = torch.from_numpy(array)
        with Image.open(os.path.join(directory, 'field_ids.tif')) as img:
        with Image.open(os.path.join(directory, f'{tile}_field_id.tif')) as img:
            array = np.array(img)
            field_ids = torch.from_numpy(array)
        return (labels, field_ids)
    def _validate_bands(self, bands: tuple[str, ...]) -> None:
        """Validate list of bands.
        Args:
            bands: user-provided tuple of bands to load
        Raises:
            AssertionError: if ``bands`` is not a tuple
            ValueError: if an invalid band name is provided
        """
        assert isinstance(bands, tuple), 'The list of bands must be a tuple'
        for band in bands:
            if band not in self.band_names:
                raise ValueError(f"'{band}' is an invalid band name.")
        return labels, field_ids
    @lru_cache(maxsize=128)
    def _load_all_image_tiles(
        self, tile_name: str, bands: tuple[str, ...] = band_names
    ) -> Tensor:
    def _load_all_image_tiles(self, tile: str) -> Tensor:
        """Load all the imagery (across time) for a single _tile_.
        Optionally allows for subsetting of the bands that are loaded.
        Args:
            tile_name: name of tile to load
            bands: tuple of bands to load
            tile: name of tile to load
        Returns:
            imagery of shape (13, number of bands, 3035, 2016) where 13 is the number of
            points in time, 3035 is the tile height, and 2016 is the tile width
        Raises:
            AssertionError: if ``tile_name`` is invalid
            points in time, 3035 is the tile height, and 2016 is the tile width
        """
        assert tile_name in self.tile_names
        if self.verbose:
            print(f'Loading all imagery for {tile_name}')
        img = torch.zeros(
            len(self.dates),
            len(bands),
            len(self.bands),
            self.tile_height,
            self.tile_width,
            dtype=torch.float32,
        )
        for date_index, date in enumerate(self.dates):
            img[date_index] = self._load_single_image_tile(tile_name, date, self.bands)
            img[date_index] = self._load_single_image_tile(tile, date)
        return img
    @lru_cache(maxsize=128)
    def _load_single_image_tile(
        self, tile_name: str, date: str, bands: tuple[str, ...]
    ) -> Tensor:
    def _load_single_image_tile(self, tile: str, date: str) -> Tensor:
        """Load the imagery for a single tile for a single date.
        Optionally allows for subsetting of the bands that are loaded.
        Args:
            tile_name: name of tile to load
            tile: name of tile to load
            date: date of tile to load
            bands: bands to load
        Returns:
            array containing a single image tile
        Raises:
            AssertionError: if ``tile_name`` or ``date`` is invalid
        """
        assert tile_name in self.tile_names
        assert date in self.dates
        if self.verbose:
            print(f'Loading imagery for {tile_name} at {date}')
        directory = os.path.join(self.root, 'data', tile, date)
        img = torch.zeros(
            len(bands), self.tile_height, self.tile_width, dtype=torch.float32
            len(self.bands), self.tile_height, self.tile_width, dtype=torch.float32
        )
        for band_index, band_name in enumerate(self.bands):
            filepath = os.path.join(
                self.root,
                'ref_african_crops_kenya_02_source',
                f'{tile_name}_{date}',
                f'{band_name}.tif',
            )
            filepath = os.path.join(directory, f'{tile}_{band_name}_{date}.tif')
            with Image.open(filepath) as band_img:
                array: np.typing.NDArray[np.int_] = np.array(band_img)
                img[band_index] = torch.from_numpy(array)
        return img
    def _check_integrity(self) -> bool:
        """Check integrity of dataset.
        Returns:
            True if dataset files are found and/or MD5s match, else False
        """
        images: bool = check_integrity(
            os.path.join(self.root, self.image_meta['filename']),
            self.image_meta['md5'] if self.checksum else None,
        )
        targets: bool = check_integrity(
            os.path.join(self.root, self.target_meta['filename']),
            self.target_meta['md5'] if self.checksum else None,
        )
        return images and targets
    def get_splits(self) -> tuple[list[int], list[int]]:
        """Get the field_ids for the train/test splits from the dataset directory.
        Returns:
            list of training field_ids and list of testing field_ids
        """
        train_field_ids = []
        test_field_ids = []
        splits_fn = os.path.join(
            self.root,
            'ref_african_crops_kenya_02_labels',
            '_common',
            'field_train_test_ids.csv',
        )
        with open(splits_fn, newline='') as f:
            reader = csv.reader(f)
            # Skip header row
            next(reader)
            for row in reader:
                train_field_ids.append(int(row[0]))
                if row[1]:
                    test_field_ids.append(int(row[1]))
        return train_field_ids, test_field_ids
    def _download(self, api_key: str | None = None) -> None:
        """Download the dataset and extract it.
        Args:
            api_key: a RadiantEarth MLHub API key to use for downloading the dataset
        """
        if self._check_integrity():
            print('Files already downloaded and verified')
    def _verify(self) -> None:
        """Verify the integrity of the dataset."""
        # Check if the files already exist
        if os.path.exists(os.path.join(self.root, 'FieldIds.csv')):
            return
        for collection_id in self.collection_ids:
            download_radiant_mlhub_collection(collection_id, self.root, api_key)
        # Check if the user requested to download the dataset
        if not self.download:
            raise DatasetNotFoundError(self)
        image_archive_path = os.path.join(self.root, self.image_meta['filename'])
        target_archive_path = os.path.join(self.root, self.target_meta['filename'])
        for fn in [image_archive_path, target_archive_path]:
            extract_archive(fn, self.root)
        # Download the dataset
        self._download()
    def _download(self) -> None:
        """Download the dataset."""
        os.makedirs(self.root, exist_ok=True)
        azcopy = which('azcopy')
        azcopy('sync', self.url, self.root, '--recursive=true')
    def plot(
        self,
@@ -439,13 +318,7 @@ class CV4AKenyaCropType(NonGeoDataset):
        image, mask = sample['image'], sample['mask']
        assert time_step <= image.shape[0] - 1, (
            'The specified time step'
            f' does not exist, image only contains {image.shape[0]} time'
            ' instances.'
        )
        image = image[time_step, rgb_indices, :, :]
        image = image[time_step, rgb_indices]
        fig, axs = plt.subplots(nrows=1, ncols=n_cols, figsize=(10, n_cols * 5))
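Taken together, the reworked dataset can be exercised end to end roughly as follows (a sketch only: the root path is illustrative and the initial download requires the azcopy CLI on PATH):

import matplotlib.pyplot as plt
from torchgeo.datasets import CV4AKenyaCropType

ds = CV4AKenyaCropType(root='data/cv4a', download=True)  # runs azcopy sync on first use
sample = ds[0]
ds.plot(sample, time_step=0, suptitle='CV4A Kenya Crop Type')  # RGB composite at the first date
plt.show()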