зеркало из https://github.com/microsoft/torchgeo.git
Address review comments
This commit is contained in:
Родитель
28026fdfe4
Коммит
7b06347424
|
@ -133,6 +133,3 @@ dmypy.json
|
|||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# Visual studio code
|
||||
.vscode/
|
|
@ -5,7 +5,7 @@ mypy
|
|||
opencv-python
|
||||
Pillow
|
||||
pycocotools
|
||||
radiant-mlhub
|
||||
radiant-mlhub>=0.2.1
|
||||
rarfile
|
||||
torch
|
||||
torchvision
|
||||
|
|
|
@ -31,7 +31,7 @@ packages = find:
|
|||
|
||||
[options.extras_require]
|
||||
cv4akenyacroptype =
|
||||
radiant-mlhub
|
||||
radiant-mlhub>=0.2.1
|
||||
landcoverai =
|
||||
opencv-python
|
||||
vhr10 =
|
||||
|
|
|
@ -8,7 +8,7 @@ spack:
|
|||
- py-mypy
|
||||
- py-pillow-simd
|
||||
- py-pycocotools
|
||||
- py-radiant-mlhub
|
||||
- "py-radiant-mlhub@0.2.1:"
|
||||
- py-rarfile
|
||||
- py-torch
|
||||
- py-torchvision
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import csv
|
||||
from functools import lru_cache
|
||||
import os
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
@ -5,7 +6,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
|
|||
import numpy as np
|
||||
from PIL import Image
|
||||
from torchvision.datasets import VisionDataset
|
||||
from torchvision.datasets.utils import check_integrity
|
||||
from torchvision.datasets.utils import check_integrity, extract_archive
|
||||
|
||||
|
||||
class CV4AKenyaCropType(VisionDataset):
|
||||
|
@ -18,17 +19,18 @@ class CV4AKenyaCropType(VisionDataset):
|
|||
Consists of 4 tiles of Sentinel 2 imagery from 13 different points in time.
|
||||
|
||||
Each tile has:
|
||||
|
||||
* 13 multi-band observations throughout the growing season. Each observation
|
||||
includes 12 bands from Sentinel-2 L2A product, and a cloud probability layer.
|
||||
The twelve bands are [B01, B02, B03, B04, B05, B06, B07, B08, B8A,
|
||||
B09, B11, B12] (refer to Sentinel-2 documentation for more information about
|
||||
the bands). The cloud probability layer is a product of the
|
||||
Sentinel-2 atmospheric correction algorithm (Sen2Cor) and provides an estimated
|
||||
cloud probability (0-100%) per pixel. All of the bands are mapped to a common
|
||||
10 m spatial resolution grid.
|
||||
includes 12 bands from Sentinel-2 L2A product, and a cloud probability layer.
|
||||
The twelve bands are [B01, B02, B03, B04, B05, B06, B07, B08, B8A,
|
||||
B09, B11, B12] (refer to Sentinel-2 documentation for more information about
|
||||
the bands). The cloud probability layer is a product of the
|
||||
Sentinel-2 atmospheric correction algorithm (Sen2Cor) and provides an estimated
|
||||
cloud probability (0-100%) per pixel. All of the bands are mapped to a common
|
||||
10 m spatial resolution grid.
|
||||
* A raster layer indicating the crop ID for the fields in the training set.
|
||||
* A raster layer indicating field IDs for the fields (both training and test sets).
|
||||
Fields with a crop ID 0 are the test fields.
|
||||
Fields with a crop ID 0 are the test fields.
|
||||
|
||||
There are 3,286 fields in the train set and 1,402 fields in the test set.
|
||||
|
||||
|
@ -96,10 +98,10 @@ class CV4AKenyaCropType(VisionDataset):
|
|||
root: str = "data",
|
||||
chip_size: int = 256,
|
||||
stride: int = 128,
|
||||
bands: Optional[Tuple[str, ...]] = None,
|
||||
transform: Optional[Callable[[Image.Image], Any]] = None,
|
||||
target_transform: Optional[Callable[[Image.Image], Any]] = None,
|
||||
transforms: Optional[Callable[[Image.Image, Image.Image], Any]] = None,
|
||||
bands: Tuple[str, ...] = band_names,
|
||||
transform: Optional[Callable[[np.ndarray], Any]] = None,
|
||||
target_transform: Optional[Callable[[np.ndarray], Any]] = None,
|
||||
transforms: Optional[Callable[[np.ndarray, np.ndarray], Any]] = None,
|
||||
download: bool = False,
|
||||
api_key: Optional[str] = None,
|
||||
verbose: bool = False,
|
||||
|
@ -108,11 +110,11 @@ class CV4AKenyaCropType(VisionDataset):
|
|||
|
||||
Parameters:
|
||||
root: root directory where dataset can be found
|
||||
chip_size (int): size of chips
|
||||
stride (int): spacing between chips, if less than chip_size, then there
|
||||
chip_size: size of chips
|
||||
stride: spacing between chips, if less than chip_size, then there
|
||||
will be overlap between chips
|
||||
bands (tuple): the subset of bands to load
|
||||
transform: a function/transform that takes in a PIL image and returns a
|
||||
bands: the subset of bands to load
|
||||
transform: a function/transform that takes in a numpy array and returns a
|
||||
transformed version
|
||||
target_transform: a function/transform that takes in the target and
|
||||
transforms it
|
||||
|
@ -120,7 +122,14 @@ class CV4AKenyaCropType(VisionDataset):
|
|||
entry and returns a transformed version
|
||||
download: if True, download dataset and store it in the root directory
|
||||
api_key: a RadiantEarth MLHub API key to use for downloading the dataset
|
||||
verbose: if True, print messages when new tiles are loaded
|
||||
|
||||
Raises:
|
||||
RuntimeError: if download=True but api_key=None, or download=False but
|
||||
dataset is missing or checksum fails
|
||||
"""
|
||||
self._validate_bands(bands)
|
||||
|
||||
super().__init__(root, transforms, transform, target_transform)
|
||||
self.verbose = verbose
|
||||
|
||||
|
@ -140,7 +149,7 @@ class CV4AKenyaCropType(VisionDataset):
|
|||
)
|
||||
|
||||
# Calculate the indices that we will use over all tiles
|
||||
self.bands = self._validate_bands(bands)
|
||||
self.bands = bands
|
||||
self.chip_size = chip_size
|
||||
self.chips_metadata = []
|
||||
for tile_index in range(len(self.tile_names)):
|
||||
|
@ -159,10 +168,8 @@ class CV4AKenyaCropType(VisionDataset):
|
|||
index: index to return
|
||||
|
||||
Returns:
|
||||
data, label, and field ids at that index
|
||||
data, labels, field ids, and metadata at that index
|
||||
"""
|
||||
assert index < len(self)
|
||||
|
||||
tile_index, y, x = self.chips_metadata[index]
|
||||
tile_name = self.tile_names[tile_index]
|
||||
|
||||
|
@ -173,6 +180,9 @@ class CV4AKenyaCropType(VisionDataset):
|
|||
labels = labels[y : y + self.chip_size, x : x + self.chip_size]
|
||||
field_ids = field_ids[y : y + self.chip_size, x : x + self.chip_size]
|
||||
|
||||
if self.transforms is not None:
|
||||
img, labels = self.transforms(img, labels)
|
||||
|
||||
return {
|
||||
"img": img,
|
||||
"labels": labels,
|
||||
|
@ -189,12 +199,22 @@ class CV4AKenyaCropType(VisionDataset):
|
|||
return len(self.chips_metadata)
|
||||
|
||||
@lru_cache
|
||||
def _load_label_tile(self, tile_name_: str) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Loads a single _tile_ of labels and field_ids"""
|
||||
assert tile_name_ in self.tile_names
|
||||
def _load_label_tile(self, tile_name: str) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Loads a single _tile_ of labels and field_ids.
|
||||
|
||||
Parameters:
|
||||
tile_name: name of tile to load
|
||||
|
||||
Returns:
|
||||
tuple of labels and field ids
|
||||
|
||||
Raises:
|
||||
AssertionError: if tile_name is invalid
|
||||
"""
|
||||
assert tile_name in self.tile_names
|
||||
|
||||
if self.verbose:
|
||||
print(f"Loading labels/field_ids for {tile_name_}")
|
||||
print(f"Loading labels/field_ids for {tile_name}")
|
||||
|
||||
labels = np.array(
|
||||
Image.open(
|
||||
|
@ -202,7 +222,7 @@ class CV4AKenyaCropType(VisionDataset):
|
|||
self.root,
|
||||
self.base_folder,
|
||||
"ref_african_crops_kenya_02_labels",
|
||||
tile_name_ + "_label",
|
||||
tile_name + "_label",
|
||||
"labels.tif",
|
||||
)
|
||||
)
|
||||
|
@ -214,7 +234,7 @@ class CV4AKenyaCropType(VisionDataset):
|
|||
self.root,
|
||||
self.base_folder,
|
||||
"ref_african_crops_kenya_02_labels",
|
||||
tile_name_ + "_label",
|
||||
tile_name + "_label",
|
||||
"field_ids.tif",
|
||||
)
|
||||
)
|
||||
|
@ -222,34 +242,44 @@ class CV4AKenyaCropType(VisionDataset):
|
|||
|
||||
return (labels, field_ids)
|
||||
|
||||
def _validate_bands(self, bands: Optional[Tuple[str, ...]]) -> Tuple[str, ...]:
|
||||
"""Routine for validating a list of bands / filling in a default value"""
|
||||
def _validate_bands(self, bands: Tuple[str, ...]) -> None:
|
||||
"""Validate list of bands.
|
||||
|
||||
if bands is None:
|
||||
return self.band_names
|
||||
else:
|
||||
assert isinstance(bands, tuple), "The list of bands must be a tuple"
|
||||
for band in bands:
|
||||
if band not in self.band_names:
|
||||
raise ValueError(f"'{band}' is an invalid band name.")
|
||||
return bands
|
||||
Parameters:
|
||||
bands: user-provided tuple of bands to load
|
||||
|
||||
Raises:
|
||||
AssertionError: if bands is not a tuple
|
||||
ValueError: if an invalid band name is provided
|
||||
"""
|
||||
|
||||
assert isinstance(bands, tuple), "The list of bands must be a tuple"
|
||||
for band in bands:
|
||||
if band not in self.band_names:
|
||||
raise ValueError(f"'{band}' is an invalid band name.")
|
||||
|
||||
@lru_cache
|
||||
def _load_all_image_tiles(
|
||||
self, tile_name_: str, bands: Optional[Tuple[str, ...]] = None
|
||||
self, tile_name: str, bands: Tuple[str, ...] = band_names
|
||||
) -> np.ndarray:
|
||||
"""Load all the imagery (across time) for a single _tile_. Optionally allows
|
||||
for subsetting of the bands that are loaded.
|
||||
|
||||
Parameters:
|
||||
tile_name: name of tile to load
|
||||
bands: tuple of bands to load
|
||||
|
||||
Returns:
|
||||
imagery of shape (13, number of bands, 3035, 2016) where 13 is the number of
|
||||
points in time, 3035 is the tile height, and 2016 is the tile width.
|
||||
points in time, 3035 is the tile height, and 2016 is the tile width
|
||||
|
||||
Raises:
|
||||
AssertionError: if tile_name is invalid
|
||||
"""
|
||||
assert tile_name_ in self.tile_names
|
||||
bands = self._validate_bands(bands)
|
||||
assert tile_name in self.tile_names
|
||||
|
||||
if self.verbose:
|
||||
print(f"Loading all imagery for {tile_name_}")
|
||||
print(f"Loading all imagery for {tile_name}")
|
||||
|
||||
img = np.zeros(
|
||||
(len(self.dates), len(bands), self.tile_height, self.tile_width),
|
||||
|
@ -257,32 +287,43 @@ class CV4AKenyaCropType(VisionDataset):
|
|||
)
|
||||
|
||||
for date_index, date in enumerate(self.dates):
|
||||
img[date_index] = self._load_single_image_tile(tile_name_, date, bands)
|
||||
img[date_index] = self._load_single_image_tile(tile_name, date, self.bands)
|
||||
|
||||
return img
|
||||
|
||||
@lru_cache
|
||||
def _load_single_image_tile(
|
||||
self, tile_name_: str, date_: str, bands: Optional[Tuple[str, ...]] = None
|
||||
self, tile_name: str, date: str, bands: Tuple[str, ...]
|
||||
) -> np.ndarray:
|
||||
"""Loads the imagery for a single tile for a single date. Optionally allows
|
||||
for subsetting of the bands that are loaded."""
|
||||
assert tile_name_ in self.tile_names
|
||||
assert date_ in self.dates
|
||||
bands = self._validate_bands(bands)
|
||||
"""Load the imagery for a single tile for a single date. Optionally allows
|
||||
for subsetting of the bands that are loaded.
|
||||
|
||||
Parameters:
|
||||
tile_name: name of tile to load
|
||||
date: date of tile to load
|
||||
bands: bands to load
|
||||
|
||||
Returns:
|
||||
array containing a single image tile
|
||||
|
||||
Raises:
|
||||
AssertionError: if tile_name or date is invalid
|
||||
"""
|
||||
assert tile_name in self.tile_names
|
||||
assert date in self.dates
|
||||
|
||||
if self.verbose:
|
||||
print(f"Loading imagery for {tile_name_} at {date_}")
|
||||
print(f"Loading imagery for {tile_name} at {date}")
|
||||
|
||||
img = np.zeros(
|
||||
(len(bands), self.tile_height, self.tile_width), dtype=np.float32
|
||||
)
|
||||
for band_index, band_name in enumerate(bands):
|
||||
for band_index, band_name in enumerate(self.bands):
|
||||
img_fn = os.path.join(
|
||||
self.root,
|
||||
self.base_folder,
|
||||
"ref_african_crops_kenya_02_source",
|
||||
f"{tile_name_}_{date_}",
|
||||
f"{tile_name}_{date}",
|
||||
f"{band_name}.tif",
|
||||
)
|
||||
band_img = np.array(Image.open(img_fn))
|
||||
|
@ -325,13 +366,16 @@ class CV4AKenyaCropType(VisionDataset):
|
|||
"field_train_test_ids.csv",
|
||||
)
|
||||
|
||||
with open(splits_fn, "r") as f:
|
||||
lines = f.read().strip().split("\n")
|
||||
for line in lines[1:]: # we skip the first line as it is a header
|
||||
parts = line.split(",")
|
||||
train_field_ids.append(int(parts[0]))
|
||||
if parts[1] != "":
|
||||
test_field_ids.append(int(parts[1]))
|
||||
with open(splits_fn, newline="") as f:
|
||||
reader = csv.reader(f)
|
||||
|
||||
# Skip header row
|
||||
next(reader)
|
||||
|
||||
for row in reader:
|
||||
train_field_ids.append(int(row[0]))
|
||||
if row[1]:
|
||||
test_field_ids.append(int(row[1]))
|
||||
|
||||
return train_field_ids, test_field_ids
|
||||
|
||||
|
@ -340,28 +384,28 @@ class CV4AKenyaCropType(VisionDataset):
|
|||
|
||||
Parameters:
|
||||
api_key: a RadiantEarth MLHub API key to use for downloading the dataset
|
||||
|
||||
Raises:
|
||||
RuntimeError: if download doesn't work correctly or checksums don't match
|
||||
"""
|
||||
|
||||
if self._check_integrity():
|
||||
print("Files already downloaded and verified")
|
||||
return
|
||||
|
||||
# Download from MLHub and check integrity
|
||||
import radiant_mlhub # To download from MLHub, could use `requests` instead
|
||||
# Must be installed to download from MLHub
|
||||
import radiant_mlhub
|
||||
|
||||
dataset = radiant_mlhub.Dataset.fetch(
|
||||
"ref_african_crops_kenya_02", api_key=api_key
|
||||
)
|
||||
dataset.download(
|
||||
output_dir=os.path.join(self.root, self.base_folder), api_key=api_key
|
||||
) # NOTE: Will not work with library versions < 0.2.1
|
||||
)
|
||||
|
||||
if not self._check_integrity():
|
||||
raise RuntimeError("Dataset files not found or corrupted.")
|
||||
|
||||
# Extract archives
|
||||
import tarfile # To extract .tar.gz archives
|
||||
|
||||
image_archive_path = os.path.join(
|
||||
self.root, self.base_folder, self.image_meta["filename"]
|
||||
)
|
||||
|
@ -369,5 +413,4 @@ class CV4AKenyaCropType(VisionDataset):
|
|||
self.root, self.base_folder, self.target_meta["filename"]
|
||||
)
|
||||
for fn in [image_archive_path, target_archive_path]:
|
||||
with tarfile.open(fn) as tfile:
|
||||
tfile.extractall(path=os.path.join(self.root, self.base_folder))
|
||||
extract_archive(fn, os.path.join(self.root, self.base_folder))
|
||||
|
|
Загрузка…
Ссылка в новой задаче