This commit is contained in:
Adam J. Stewart 2021-06-11 17:16:34 +00:00
Родитель 36611a7c6a
Коммит 9e58dc6a63
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: C66C0675661156FC
3 изменённых файлов: 69 добавлений и 10 удалений

Просмотреть файл

@ -1,6 +1,14 @@
from .cowc import COWCCounting, COWCDetection
from .cv4a_kenya_crop_type import CV4AKenyaCropType
from .geo import GeoDataset
from .landcoverai import LandCoverAI
from .nwpu import VHR10
__all__ = ("COWCCounting", "COWCDetection", "CV4AKenyaCropType", "LandCoverAI", "VHR10")
__all__ = (
"COWCCounting",
"COWCDetection",
"CV4AKenyaCropType",
"GeoDataset",
"LandCoverAI",
"VHR10",
)

Просмотреть файл

@ -5,11 +5,12 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
import numpy as np
from PIL import Image
from torchvision.datasets import VisionDataset
from torchvision.datasets.utils import check_integrity, extract_archive
from .geo import GeoDataset
class CV4AKenyaCropType(VisionDataset):
class CV4AKenyaCropType(GeoDataset):
"""CV4A Kenya Crop Type dataset.
Used in a competition in the Computer Vision for Agriculture (CV4A) workshop in
@ -103,8 +104,6 @@ class CV4AKenyaCropType(VisionDataset):
chip_size: int = 256,
stride: int = 128,
bands: Tuple[str, ...] = band_names,
transform: Optional[Callable[[np.ndarray], Any]] = None,
target_transform: Optional[Callable[[np.ndarray], Any]] = None,
transforms: Optional[
Callable[[np.ndarray, np.ndarray], Tuple[Any, Any]]
] = None,
@ -120,10 +119,6 @@ class CV4AKenyaCropType(VisionDataset):
stride: spacing between chips, if less than chip_size, then there
will be overlap between chips
bands: the subset of bands to load
transform: a function/transform that takes in a numpy array and returns a
transformed version
target_transform: a function/transform that takes in the target and
transforms it
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
download: if True, download dataset and store it in the root directory
@ -136,7 +131,8 @@ class CV4AKenyaCropType(VisionDataset):
"""
self._validate_bands(bands)
super().__init__(root, transforms, transform, target_transform)
self.root = root
self.transforms = transforms
self.verbose = verbose
if download:

55
torchgeo/datasets/geo.py Normal file
Просмотреть файл

@ -0,0 +1,55 @@
import abc
from typing import Any, Dict
from torch.utils.data import Dataset
class GeoDataset(Dataset[Dict[str, Any]], abc.ABC):
"""Abstract base class for datasets containing geospatial information.
Geospatial information includes things like:
* latitude, longitude
* time
* coordinate reference systems (CRS)
These kind of datasets are special because they can be combined. For example:
* Combine Landsat8 and CDL to train a model for crop classification
* Combine Sentinel2 and Chesapeake to train a model for land cover mapping
This isn't true for VisionDataset, where the lack of geospatial information
prohibits swapping image sources or target labels.
"""
@abc.abstractmethod
def __getitem__(self, index: int) -> Dict[str, Any]:
"""Return an index within the dataset.
Parameters:
index: index to return
Returns:
data and labels at that index
"""
pass
@abc.abstractmethod
def __len__(self) -> int:
"""Return the length of the dataset.
Returns:
length of the dataset
"""
pass
def __str__(self) -> str:
"""Return the informal string representation of the object.
Returns:
informal string representation
"""
return f"""\
{self.__class__.__name__} Dataset
type: GeoDataset
size: {len(self)}"""