зеркало из https://github.com/microsoft/torchgeo.git
Finished draft version of the CV4A Kenya Crop Type Dataset
This commit is contained in:
Родитель
fb34d1906e
Коммит
4e4750df32
|
@ -1,5 +1,5 @@
|
|||
from .cv4a_kenya_crop_type import CV4AKenyaCropType
|
||||
from .landcoverai import LandCoverAI
|
||||
from .nwpu import VHR10
|
||||
from .cv4a_kenya_crop_type import CV4AKenyaCropType
|
||||
|
||||
__all__ = ("LandCoverAI", "VHR10", "CV4AKenyaCropType")
|
||||
__all__ = ("CV4AKenyaCropType", "LandCoverAI", "VHR10")
|
||||
|
|
|
@ -1,25 +1,34 @@
|
|||
import os
|
||||
from typing import Any, Callable, Optional, Tuple
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, List
|
||||
from functools import lru_cache
|
||||
|
||||
import numpy as np
|
||||
|
||||
from PIL import Image
|
||||
from torchvision.datasets import VisionDataset
|
||||
from torchvision.datasets.utils import check_integrity, download_and_extract_archive
|
||||
|
||||
from .utils import working_dir
|
||||
from torchvision.datasets.utils import check_integrity
|
||||
|
||||
|
||||
class CV4AKenyaCropType(VisionDataset):
|
||||
"""CV4A Kenya Crop Type dataset.
|
||||
|
||||
Used in a competition in the Computer Vision for Agriculture (CV4A) workshop in ICLR 2020.
|
||||
See the competition website <https://zindi.africa/competitions/iclr-workshop-challenge-2-radiant-earth-computer-vision-for-crop-recognition>.
|
||||
Used in a competition in the Computer Vision for Agriculture (CV4A) workshop in
|
||||
ICLR 2020. See the competition website <https://zindi.africa/competitions/iclr-workshop-challenge-2-radiant-earth-computer-vision-for-crop-recognition>.
|
||||
|
||||
Consists of 4 tiles of Sentinel 2 imagery from 13 different points in time.
|
||||
|
||||
Each tile has:
|
||||
* 13 multi-band observations throughout the growing season. Each observation includes 12 bands from Sentinel-2 L2A product, and a cloud probability layer. The twelve bands are [B01, B02, B03, B04, B05, B06, B07, B08, B8A, B09, B11, B12] (refer to Sentinel-2 documentation for more information about the bands). The cloud probability layer is a product of the Sentinel-2 atmospheric correction algorithm (Sen2Cor) and provides an estimated cloud probability (0-100%) per pixel. All of the bands are mapped to a common 10 m spatial resolution grid.
|
||||
* 13 multi-band observations throughout the growing season. Each observation
|
||||
includes 12 bands from Sentinel-2 L2A product, and a cloud probability layer.
|
||||
The twelve bands are [B01, B02, B03, B04, B05, B06, B07, B08, B8A,
|
||||
B09, B11, B12] (refer to Sentinel-2 documentation for more information about
|
||||
the bands). The cloud probability layer is a product of the
|
||||
Sentinel-2 atmospheric correction algorithm (Sen2Cor) and provides an estimated
|
||||
cloud probability (0-100%) per pixel. All of the bands are mapped to a common
|
||||
10 m spatial resolution grid.
|
||||
* A raster layer indicating the crop ID for the fields in the training set.
|
||||
* A raster layer indicating field IDs for the fields (both training and test sets). Fields with a crop ID 0 are the test fields.
|
||||
* A raster layer indicating field IDs for the fields (both training and test sets).
|
||||
Fields with a crop ID 0 are the test fields.
|
||||
|
||||
There are 3,286 fields in the train set and 1,402 fields in the test set.
|
||||
|
||||
|
@ -41,21 +50,68 @@ class CV4AKenyaCropType(VisionDataset):
|
|||
"md5": "93949abd0ae82ba564f5a933cefd8215",
|
||||
}
|
||||
|
||||
tile_names = [
|
||||
"ref_african_crops_kenya_02_tile_00",
|
||||
"ref_african_crops_kenya_02_tile_01",
|
||||
"ref_african_crops_kenya_02_tile_02",
|
||||
"ref_african_crops_kenya_02_tile_03",
|
||||
]
|
||||
dates = [
|
||||
"20190606",
|
||||
"20190701",
|
||||
"20190706",
|
||||
"20190711",
|
||||
"20190721",
|
||||
"20190805",
|
||||
"20190815",
|
||||
"20190825",
|
||||
"20190909",
|
||||
"20190919",
|
||||
"20190924",
|
||||
"20191004",
|
||||
"20191103",
|
||||
]
|
||||
band_names = (
|
||||
"B01",
|
||||
"B02",
|
||||
"B03",
|
||||
"B04",
|
||||
"B05",
|
||||
"B06",
|
||||
"B07",
|
||||
"B08",
|
||||
"B8A",
|
||||
"B09",
|
||||
"B11",
|
||||
"B12",
|
||||
"CLD",
|
||||
)
|
||||
|
||||
# Same for all tiles
|
||||
tile_height = 3035
|
||||
tile_width = 2016
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root: str = "data",
|
||||
split: str = "train",
|
||||
chip_size: int = 256,
|
||||
stride: int = 128,
|
||||
bands: Optional[Tuple[str]] = None,
|
||||
transform: Optional[Callable[[Image.Image], Any]] = None,
|
||||
target_transform: Optional[Callable[[Image.Image], Any]] = None,
|
||||
transforms: Optional[Callable[[Image.Image, Image.Image], Any]] = None,
|
||||
download: bool = False,
|
||||
api_key: Optional[str] = None
|
||||
api_key: Optional[str] = None,
|
||||
verbose: bool = False
|
||||
) -> None:
|
||||
"""Initialize a new CV4A Kenya Crop Type dataset instance.
|
||||
"""Initialize a new CV4A Kenya Crop Type Dataset instance.
|
||||
|
||||
Parameters:
|
||||
root: root directory where dataset can be found
|
||||
split: one of "train" or "test"
|
||||
chip_size (int): size of chips
|
||||
stride (int): spacing between chips, if less than chip_size, then there
|
||||
will be overlap between chips
|
||||
bands (tuple): the subset of bands to load
|
||||
transform: a function/transform that takes in a PIL image and returns a
|
||||
transformed version
|
||||
target_transform: a function/transform that takes in the target and
|
||||
|
@ -65,13 +121,15 @@ class CV4AKenyaCropType(VisionDataset):
|
|||
download: if True, download dataset and store it in the root directory
|
||||
api_key: a RadiantEarth MLHub API key to use for downloading the dataset
|
||||
"""
|
||||
assert split in ["train", "test"]
|
||||
|
||||
super().__init__(root, transforms, transform, target_transform)
|
||||
self.verbose = verbose
|
||||
|
||||
if download:
|
||||
if api_key is None:
|
||||
raise RuntimeError("You must pass an MLHub API key if download=True. See https://www.mlhub.earth/ to register for API access.")
|
||||
raise RuntimeError(
|
||||
"You must pass an MLHub API key if download=True. "
|
||||
+ "See https://www.mlhub.earth/ to register for API access."
|
||||
)
|
||||
else:
|
||||
self.download(api_key)
|
||||
|
||||
|
@ -81,37 +139,155 @@ class CV4AKenyaCropType(VisionDataset):
|
|||
+ "You can use download=True to download it"
|
||||
)
|
||||
|
||||
# Calculate the indices that we will use over all tiles
|
||||
self.bands = self._validate_bands(bands)
|
||||
self.chip_size = chip_size
|
||||
self.chips_metadata = []
|
||||
for tile_index in range(len(self.tile_names)):
|
||||
for y in list(range(0, self.tile_height - self.chip_size, stride)) + [
|
||||
self.tile_height - self.chip_size
|
||||
]:
|
||||
for x in list(range(0, self.tile_width - self.chip_size, stride)) + [
|
||||
self.tile_width - self.chip_size
|
||||
]:
|
||||
self.chips_metadata.append((tile_index, y, x))
|
||||
|
||||
def __getitem__(self, index: int) -> Tuple[Any, Any]:
|
||||
def __getitem__(self, index: int) -> Dict:
|
||||
"""Return an index within the dataset.
|
||||
|
||||
TODO: See the following link for example loading code from competition organizers: https://github.com/radiantearth/mlhub-tutorials/blob/main/notebooks/2020%20CV4A%20Crop%20Type%20Challenge/cv4a-crop-challenge-load-data.ipynb
|
||||
|
||||
Parameters:
|
||||
index: index to return
|
||||
|
||||
Returns:
|
||||
data and label at that index
|
||||
data, label, and field ids at that index
|
||||
"""
|
||||
raise NotImplementedError("") # TODO: Implement after discussion about how to handle tile datasets
|
||||
assert index < len(self)
|
||||
|
||||
tile_index, y, x = self.chips_metadata[index]
|
||||
tile_name = self.tile_names[tile_index]
|
||||
|
||||
img = self._load_all_image_tiles(tile_name, self.bands)
|
||||
labels, field_ids = self._load_label_tile(tile_name)
|
||||
|
||||
img = img[:, :, y : y + self.chip_size, x : x + self.chip_size]
|
||||
labels = labels[y : y + self.chip_size, x : x + self.chip_size]
|
||||
field_ids = field_ids[y : y + self.chip_size, x : x + self.chip_size]
|
||||
|
||||
return {
|
||||
"img": img,
|
||||
"labels": labels,
|
||||
"field_ids": field_ids,
|
||||
"metadata": (tile_index, y, x)
|
||||
}
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of data points in the dataset.
|
||||
"""Return the number of chips in the dataset.
|
||||
|
||||
Returns:
|
||||
length of the dataset
|
||||
"""
|
||||
raise NotImplementedError("") # TODO: Implement after discussion about how to handle tile datasets
|
||||
return len(self.chips_metadata)
|
||||
|
||||
def _load_image(self, id_: str) -> Image.Image:
|
||||
"""
|
||||
"""
|
||||
raise NotImplementedError("") # TODO: Implement after discussion about how to handle tile datasets
|
||||
@lru_cache
|
||||
def _load_label_tile(self, tile_name_: str) -> Tuple[np.ndarray, np.ndarray]:
|
||||
""" Loads a single _tile_ of labels and field_ids"""
|
||||
assert tile_name_ in self.tile_names
|
||||
|
||||
def _load_target(self, id_: str) -> Image.Image:
|
||||
if self.verbose:
|
||||
print(f"Loading labels/field_ids for {tile_name_}")
|
||||
|
||||
labels = np.array(
|
||||
Image.open(
|
||||
os.path.join(
|
||||
self.root,
|
||||
self.base_folder,
|
||||
"ref_african_crops_kenya_02_labels",
|
||||
tile_name_ + "_label",
|
||||
"labels.tif",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
field_ids = np.array(
|
||||
Image.open(
|
||||
os.path.join(
|
||||
self.root,
|
||||
self.base_folder,
|
||||
"ref_african_crops_kenya_02_labels",
|
||||
tile_name_ + "_label",
|
||||
"field_ids.tif",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
return (labels, field_ids)
|
||||
|
||||
def _validate_bands(self, bands: Optional[Tuple[str]]) -> Tuple[str]:
|
||||
""" Routine for validating a list of bands / filling in a default value """
|
||||
|
||||
if bands is None:
|
||||
return self.band_names
|
||||
else:
|
||||
for band in bands:
|
||||
if band not in self.band_names:
|
||||
raise ValueError(f"'{band}' is an invalid band name.")
|
||||
return bands
|
||||
|
||||
@lru_cache
|
||||
def _load_all_image_tiles(
|
||||
self, tile_name_: str, bands: Optional[Tuple[str]] = None
|
||||
) -> np.ndarray:
|
||||
""" Load all the imagery (across time) for a single _tile_. Optionally allows
|
||||
for subsetting of the bands that are loaded.
|
||||
|
||||
Returns
|
||||
imagery of shape (13, number of bands, 3035, 2016) where 13 is the number of
|
||||
points in time, 3035 is the tile height, and 2016 is the tile width.
|
||||
"""
|
||||
"""
|
||||
raise NotImplementedError("") # TODO: Implement after discussion about how to handle tile datasets
|
||||
assert tile_name_ in self.tile_names
|
||||
bands = self._validate_bands(bands)
|
||||
|
||||
if self.verbose:
|
||||
print(f"Loading all imagery for {tile_name_}")
|
||||
|
||||
img = np.zeros(
|
||||
(len(self.dates), len(bands), self.tile_height, self.tile_width),
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
for date_index, date in enumerate(self.dates):
|
||||
img[date_index] = self._load_single_image_tile(tile_name_, date, bands)
|
||||
|
||||
return img
|
||||
|
||||
@lru_cache
|
||||
def _load_single_image_tile(
|
||||
self, tile_name_: str, date_: str, bands: Optional[Tuple[str]] = None
|
||||
) -> np.ndarray:
|
||||
""" Loads the imagery for a single tile for a single date. Optionally allows
|
||||
for subsetting of the bands that are loaded."""
|
||||
assert tile_name_ in self.tile_names
|
||||
assert date_ in self.dates
|
||||
bands = self._validate_bands(bands)
|
||||
|
||||
if self.verbose:
|
||||
print(f"Loading imagery for {tile_name_} at {date_}")
|
||||
|
||||
img = np.zeros(
|
||||
(len(bands), self.tile_height, self.tile_width), dtype=np.float32
|
||||
)
|
||||
for band_index, band_name in enumerate(bands):
|
||||
img_fn = os.path.join(
|
||||
self.root,
|
||||
self.base_folder,
|
||||
"ref_african_crops_kenya_02_source",
|
||||
f"{tile_name_}_{date_}",
|
||||
f"{band_name}.tif",
|
||||
)
|
||||
band_img = np.array(Image.open(img_fn))
|
||||
img[band_index] = band_img
|
||||
|
||||
return img
|
||||
|
||||
def _check_integrity(self) -> bool:
|
||||
"""Check integrity of dataset.
|
||||
|
@ -131,7 +307,34 @@ class CV4AKenyaCropType(VisionDataset):
|
|||
|
||||
return images and targets
|
||||
|
||||
def download(self, api_key) -> None:
|
||||
def get_splits(self) -> Tuple[List[int], List[int]]:
|
||||
""" Gets the field_ids for the train/test splits from the dataset directory
|
||||
|
||||
Returns:
|
||||
list of training field_ids and list of testing field_ids
|
||||
"""
|
||||
|
||||
train_field_ids = []
|
||||
test_field_ids = []
|
||||
splits_fn = os.path.join(
|
||||
self.root,
|
||||
self.base_folder,
|
||||
"ref_african_crops_kenya_02_labels",
|
||||
"_common",
|
||||
"field_train_test_ids.csv"
|
||||
)
|
||||
|
||||
with open(splits_fn, "r") as f:
|
||||
lines = f.read().strip().split("\n")
|
||||
for line in lines[1:]: # we skip the first line as it is a header
|
||||
parts = line.split(",")
|
||||
train_field_ids.append(int(parts[0]))
|
||||
if parts[1] != "":
|
||||
test_field_ids.append(int(parts[1]))
|
||||
|
||||
return train_field_ids, test_field_ids
|
||||
|
||||
def download(self, api_key: str) -> None:
|
||||
"""Download the dataset and extract it.
|
||||
|
||||
Parameters:
|
||||
|
@ -142,25 +345,28 @@ class CV4AKenyaCropType(VisionDataset):
|
|||
print("Files already downloaded and verified")
|
||||
return
|
||||
|
||||
|
||||
# Download from MLHub and check integrity
|
||||
import radiant_mlhub # To download from MLHub, could probably use `requests` instead
|
||||
import radiant_mlhub # To download from MLHub, could use `requests` instead
|
||||
|
||||
dataset = radiant_mlhub.Dataset.fetch('ref_african_crops_kenya_02', api_key=api_key)
|
||||
dataset = radiant_mlhub.Dataset.fetch(
|
||||
"ref_african_crops_kenya_02", api_key=api_key
|
||||
)
|
||||
dataset.download(
|
||||
output_dir=os.path.join(self.root, self.base_folder),
|
||||
api_key=api_key
|
||||
) # NOTE: Will not work with library versions < 0.2.1
|
||||
output_dir=os.path.join(self.root, self.base_folder), api_key=api_key
|
||||
) # NOTE: Will not work with library versions < 0.2.1
|
||||
|
||||
if not self._check_integrity():
|
||||
raise RuntimeError("Dataset files not found or corrupted.")
|
||||
|
||||
|
||||
# Extract archives
|
||||
import tarfile # To extract .tar.gz archives
|
||||
import tarfile # To extract .tar.gz archives
|
||||
|
||||
image_archive_path = os.path.join(self.root, self.base_folder, self.image_meta["filename"])
|
||||
target_archive_path = os.path.join(self.root, self.base_folder, self.target_meta["filename"])
|
||||
image_archive_path = os.path.join(
|
||||
self.root, self.base_folder, self.image_meta["filename"]
|
||||
)
|
||||
target_archive_path = os.path.join(
|
||||
self.root, self.base_folder, self.target_meta["filename"]
|
||||
)
|
||||
for fn in [image_archive_path, target_archive_path]:
|
||||
with tarfile.open(fn) as tfile:
|
||||
tfile.extractall(path=os.path.join(self.root, self.base_folder))
|
||||
tfile.extractall(path=os.path.join(self.root, self.base_folder))
|
||||
|
|
Загрузка…
Ссылка в новой задаче