This commit is contained in:
Adam J. Stewart 2021-06-07 16:10:09 +00:00
Родитель 28026fdfe4
Коммит 7b06347424
5 изменённых файлов: 113 добавлений и 73 удалений

3
.gitignore поставляемый
Просмотреть файл

@ -133,6 +133,3 @@ dmypy.json
# Pyre type checker
.pyre/
# Visual studio code
.vscode/

Просмотреть файл

@ -5,7 +5,7 @@ mypy
opencv-python
Pillow
pycocotools
radiant-mlhub
radiant-mlhub>=0.2.1
rarfile
torch
torchvision

Просмотреть файл

@ -31,7 +31,7 @@ packages = find:
[options.extras_require]
cv4akenyacroptype =
radiant-mlhub
radiant-mlhub>=0.2.1
landcoverai =
opencv-python
vhr10 =

Просмотреть файл

@ -8,7 +8,7 @@ spack:
- py-mypy
- py-pillow-simd
- py-pycocotools
- py-radiant-mlhub
- "py-radiant-mlhub@0.2.1:"
- py-rarfile
- py-torch
- py-torchvision

Просмотреть файл

@ -1,3 +1,4 @@
import csv
from functools import lru_cache
import os
from typing import Any, Callable, Dict, List, Optional, Tuple
@ -5,7 +6,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
import numpy as np
from PIL import Image
from torchvision.datasets import VisionDataset
from torchvision.datasets.utils import check_integrity
from torchvision.datasets.utils import check_integrity, extract_archive
class CV4AKenyaCropType(VisionDataset):
@ -18,17 +19,18 @@ class CV4AKenyaCropType(VisionDataset):
Consists of 4 tiles of Sentinel 2 imagery from 13 different points in time.
Each tile has:
* 13 multi-band observations throughout the growing season. Each observation
includes 12 bands from Sentinel-2 L2A product, and a cloud probability layer.
The twelve bands are [B01, B02, B03, B04, B05, B06, B07, B08, B8A,
B09, B11, B12] (refer to Sentinel-2 documentation for more information about
the bands). The cloud probability layer is a product of the
Sentinel-2 atmospheric correction algorithm (Sen2Cor) and provides an estimated
cloud probability (0-100%) per pixel. All of the bands are mapped to a common
10 m spatial resolution grid.
includes 12 bands from Sentinel-2 L2A product, and a cloud probability layer.
The twelve bands are [B01, B02, B03, B04, B05, B06, B07, B08, B8A,
B09, B11, B12] (refer to Sentinel-2 documentation for more information about
the bands). The cloud probability layer is a product of the
Sentinel-2 atmospheric correction algorithm (Sen2Cor) and provides an estimated
cloud probability (0-100%) per pixel. All of the bands are mapped to a common
10 m spatial resolution grid.
* A raster layer indicating the crop ID for the fields in the training set.
* A raster layer indicating field IDs for the fields (both training and test sets).
Fields with a crop ID 0 are the test fields.
Fields with a crop ID 0 are the test fields.
There are 3,286 fields in the train set and 1,402 fields in the test set.
@ -96,10 +98,10 @@ class CV4AKenyaCropType(VisionDataset):
root: str = "data",
chip_size: int = 256,
stride: int = 128,
bands: Optional[Tuple[str, ...]] = None,
transform: Optional[Callable[[Image.Image], Any]] = None,
target_transform: Optional[Callable[[Image.Image], Any]] = None,
transforms: Optional[Callable[[Image.Image, Image.Image], Any]] = None,
bands: Tuple[str, ...] = band_names,
transform: Optional[Callable[[np.ndarray], Any]] = None,
target_transform: Optional[Callable[[np.ndarray], Any]] = None,
transforms: Optional[Callable[[np.ndarray, np.ndarray], Any]] = None,
download: bool = False,
api_key: Optional[str] = None,
verbose: bool = False,
@ -108,11 +110,11 @@ class CV4AKenyaCropType(VisionDataset):
Parameters:
root: root directory where dataset can be found
chip_size (int): size of chips
stride (int): spacing between chips, if less than chip_size, then there
chip_size: size of chips
stride: spacing between chips; if less than chip_size, there
will be overlap between chips
bands (tuple): the subset of bands to load
transform: a function/transform that takes in a PIL image and returns a
bands: the subset of bands to load
transform: a function/transform that takes in a numpy array and returns a
transformed version
target_transform: a function/transform that takes in the target and
transforms it
@ -120,7 +122,14 @@ class CV4AKenyaCropType(VisionDataset):
entry and returns a transformed version
download: if True, download dataset and store it in the root directory
api_key: a RadiantEarth MLHub API key to use for downloading the dataset
verbose: if True, print messages when new tiles are loaded
Raises:
RuntimeError: if download=True but api_key=None, or download=False but
dataset is missing or checksum fails
"""
self._validate_bands(bands)
super().__init__(root, transforms, transform, target_transform)
self.verbose = verbose
@ -140,7 +149,7 @@ class CV4AKenyaCropType(VisionDataset):
)
# Calculate the indices that we will use over all tiles
self.bands = self._validate_bands(bands)
self.bands = bands
self.chip_size = chip_size
self.chips_metadata = []
for tile_index in range(len(self.tile_names)):
@ -159,10 +168,8 @@ class CV4AKenyaCropType(VisionDataset):
index: index to return
Returns:
data, label, and field ids at that index
data, labels, field ids, and metadata at that index
"""
assert index < len(self)
tile_index, y, x = self.chips_metadata[index]
tile_name = self.tile_names[tile_index]
@ -173,6 +180,9 @@ class CV4AKenyaCropType(VisionDataset):
labels = labels[y : y + self.chip_size, x : x + self.chip_size]
field_ids = field_ids[y : y + self.chip_size, x : x + self.chip_size]
if self.transforms is not None:
img, labels = self.transforms(img, labels)
return {
"img": img,
"labels": labels,
@ -189,12 +199,22 @@ class CV4AKenyaCropType(VisionDataset):
return len(self.chips_metadata)
@lru_cache
def _load_label_tile(self, tile_name_: str) -> Tuple[np.ndarray, np.ndarray]:
"""Loads a single _tile_ of labels and field_ids"""
assert tile_name_ in self.tile_names
def _load_label_tile(self, tile_name: str) -> Tuple[np.ndarray, np.ndarray]:
"""Loads a single _tile_ of labels and field_ids.
Parameters:
tile_name: name of tile to load
Returns:
tuple of labels and field ids
Raises:
AssertionError: if tile_name is invalid
"""
assert tile_name in self.tile_names
if self.verbose:
print(f"Loading labels/field_ids for {tile_name_}")
print(f"Loading labels/field_ids for {tile_name}")
labels = np.array(
Image.open(
@ -202,7 +222,7 @@ class CV4AKenyaCropType(VisionDataset):
self.root,
self.base_folder,
"ref_african_crops_kenya_02_labels",
tile_name_ + "_label",
tile_name + "_label",
"labels.tif",
)
)
@ -214,7 +234,7 @@ class CV4AKenyaCropType(VisionDataset):
self.root,
self.base_folder,
"ref_african_crops_kenya_02_labels",
tile_name_ + "_label",
tile_name + "_label",
"field_ids.tif",
)
)
@ -222,34 +242,44 @@ class CV4AKenyaCropType(VisionDataset):
return (labels, field_ids)
def _validate_bands(self, bands: Optional[Tuple[str, ...]]) -> Tuple[str, ...]:
"""Routine for validating a list of bands / filling in a default value"""
def _validate_bands(self, bands: Tuple[str, ...]) -> None:
"""Validate list of bands.
if bands is None:
return self.band_names
else:
assert isinstance(bands, tuple), "The list of bands must be a tuple"
for band in bands:
if band not in self.band_names:
raise ValueError(f"'{band}' is an invalid band name.")
return bands
Parameters:
bands: user-provided tuple of bands to load
Raises:
AssertionError: if bands is not a tuple
ValueError: if an invalid band name is provided
"""
assert isinstance(bands, tuple), "The list of bands must be a tuple"
for band in bands:
if band not in self.band_names:
raise ValueError(f"'{band}' is an invalid band name.")
@lru_cache
def _load_all_image_tiles(
self, tile_name_: str, bands: Optional[Tuple[str, ...]] = None
self, tile_name: str, bands: Tuple[str, ...] = band_names
) -> np.ndarray:
"""Load all the imagery (across time) for a single _tile_. Optionally allows
for subsetting of the bands that are loaded.
Parameters:
tile_name: name of tile to load
bands: tuple of bands to load
Returns:
imagery of shape (13, number of bands, 3035, 2016) where 13 is the number of
points in time, 3035 is the tile height, and 2016 is the tile width.
points in time, 3035 is the tile height, and 2016 is the tile width
Raises:
AssertionError: if tile_name is invalid
"""
assert tile_name_ in self.tile_names
bands = self._validate_bands(bands)
assert tile_name in self.tile_names
if self.verbose:
print(f"Loading all imagery for {tile_name_}")
print(f"Loading all imagery for {tile_name}")
img = np.zeros(
(len(self.dates), len(bands), self.tile_height, self.tile_width),
@ -257,32 +287,43 @@ class CV4AKenyaCropType(VisionDataset):
)
for date_index, date in enumerate(self.dates):
img[date_index] = self._load_single_image_tile(tile_name_, date, bands)
img[date_index] = self._load_single_image_tile(tile_name, date, self.bands)
return img
@lru_cache
def _load_single_image_tile(
self, tile_name_: str, date_: str, bands: Optional[Tuple[str, ...]] = None
self, tile_name: str, date: str, bands: Tuple[str, ...]
) -> np.ndarray:
"""Loads the imagery for a single tile for a single date. Optionally allows
for subsetting of the bands that are loaded."""
assert tile_name_ in self.tile_names
assert date_ in self.dates
bands = self._validate_bands(bands)
"""Load the imagery for a single tile for a single date. Optionally allows
for subsetting of the bands that are loaded.
Parameters:
tile_name: name of tile to load
date: date of tile to load
bands: bands to load
Returns:
array containing a single image tile
Raises:
AssertionError: if tile_name or date is invalid
"""
assert tile_name in self.tile_names
assert date in self.dates
if self.verbose:
print(f"Loading imagery for {tile_name_} at {date_}")
print(f"Loading imagery for {tile_name} at {date}")
img = np.zeros(
(len(bands), self.tile_height, self.tile_width), dtype=np.float32
)
for band_index, band_name in enumerate(bands):
for band_index, band_name in enumerate(self.bands):
img_fn = os.path.join(
self.root,
self.base_folder,
"ref_african_crops_kenya_02_source",
f"{tile_name_}_{date_}",
f"{tile_name}_{date}",
f"{band_name}.tif",
)
band_img = np.array(Image.open(img_fn))
@ -325,13 +366,16 @@ class CV4AKenyaCropType(VisionDataset):
"field_train_test_ids.csv",
)
with open(splits_fn, "r") as f:
lines = f.read().strip().split("\n")
for line in lines[1:]: # we skip the first line as it is a header
parts = line.split(",")
train_field_ids.append(int(parts[0]))
if parts[1] != "":
test_field_ids.append(int(parts[1]))
with open(splits_fn, newline="") as f:
reader = csv.reader(f)
# Skip header row
next(reader)
for row in reader:
train_field_ids.append(int(row[0]))
if row[1]:
test_field_ids.append(int(row[1]))
return train_field_ids, test_field_ids
@ -340,28 +384,28 @@ class CV4AKenyaCropType(VisionDataset):
Parameters:
api_key: a RadiantEarth MLHub API key to use for downloading the dataset
Raises:
RuntimeError: if download doesn't work correctly or checksums don't match
"""
if self._check_integrity():
print("Files already downloaded and verified")
return
# Download from MLHub and check integrity
import radiant_mlhub # To download from MLHub, could use `requests` instead
# Must be installed to download from MLHub
import radiant_mlhub
dataset = radiant_mlhub.Dataset.fetch(
"ref_african_crops_kenya_02", api_key=api_key
)
dataset.download(
output_dir=os.path.join(self.root, self.base_folder), api_key=api_key
) # NOTE: Will not work with library versions < 0.2.1
)
if not self._check_integrity():
raise RuntimeError("Dataset files not found or corrupted.")
# Extract archives
import tarfile # To extract .tar.gz archives
image_archive_path = os.path.join(
self.root, self.base_folder, self.image_meta["filename"]
)
@ -369,5 +413,4 @@ class CV4AKenyaCropType(VisionDataset):
self.root, self.base_folder, self.target_meta["filename"]
)
for fn in [image_archive_path, target_archive_path]:
with tarfile.open(fn) as tfile:
tfile.extractall(path=os.path.join(self.root, self.base_folder))
extract_archive(fn, os.path.join(self.root, self.base_folder))