зеркало из https://github.com/microsoft/torchgeo.git
Add GeoDataset base class
This commit is contained in:
Родитель
36611a7c6a
Коммит
9e58dc6a63
|
@ -1,6 +1,14 @@
|
|||
from .cowc import COWCCounting, COWCDetection
|
||||
from .cv4a_kenya_crop_type import CV4AKenyaCropType
|
||||
from .geo import GeoDataset
|
||||
from .landcoverai import LandCoverAI
|
||||
from .nwpu import VHR10
|
||||
|
||||
__all__ = ("COWCCounting", "COWCDetection", "CV4AKenyaCropType", "LandCoverAI", "VHR10")
|
||||
__all__ = (
|
||||
"COWCCounting",
|
||||
"COWCDetection",
|
||||
"CV4AKenyaCropType",
|
||||
"GeoDataset",
|
||||
"LandCoverAI",
|
||||
"VHR10",
|
||||
)
|
||||
|
|
|
@ -5,11 +5,12 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
|
|||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from torchvision.datasets import VisionDataset
|
||||
from torchvision.datasets.utils import check_integrity, extract_archive
|
||||
|
||||
from .geo import GeoDataset
|
||||
|
||||
class CV4AKenyaCropType(VisionDataset):
|
||||
|
||||
class CV4AKenyaCropType(GeoDataset):
|
||||
"""CV4A Kenya Crop Type dataset.
|
||||
|
||||
Used in a competition in the Computer Vision for Agriculture (CV4A) workshop in
|
||||
|
@ -103,8 +104,6 @@ class CV4AKenyaCropType(VisionDataset):
|
|||
chip_size: int = 256,
|
||||
stride: int = 128,
|
||||
bands: Tuple[str, ...] = band_names,
|
||||
transform: Optional[Callable[[np.ndarray], Any]] = None,
|
||||
target_transform: Optional[Callable[[np.ndarray], Any]] = None,
|
||||
transforms: Optional[
|
||||
Callable[[np.ndarray, np.ndarray], Tuple[Any, Any]]
|
||||
] = None,
|
||||
|
@ -120,10 +119,6 @@ class CV4AKenyaCropType(VisionDataset):
|
|||
stride: spacing between chips, if less than chip_size, then there
|
||||
will be overlap between chips
|
||||
bands: the subset of bands to load
|
||||
transform: a function/transform that takes in a numpy array and returns a
|
||||
transformed version
|
||||
target_transform: a function/transform that takes in the target and
|
||||
transforms it
|
||||
transforms: a function/transform that takes input sample and its target as
|
||||
entry and returns a transformed version
|
||||
download: if True, download dataset and store it in the root directory
|
||||
|
@ -136,7 +131,8 @@ class CV4AKenyaCropType(VisionDataset):
|
|||
"""
|
||||
self._validate_bands(bands)
|
||||
|
||||
super().__init__(root, transforms, transform, target_transform)
|
||||
self.root = root
|
||||
self.transforms = transforms
|
||||
self.verbose = verbose
|
||||
|
||||
if download:
|
||||
|
|
|
@ -0,0 +1,55 @@
|
|||
import abc
|
||||
from typing import Any, Dict
|
||||
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
class GeoDataset(Dataset[Dict[str, Any]], abc.ABC):
|
||||
"""Abstract base class for datasets containing geospatial information.
|
||||
|
||||
Geospatial information includes things like:
|
||||
|
||||
* latitude, longitude
|
||||
* time
|
||||
* coordinate reference systems (CRS)
|
||||
|
||||
These kind of datasets are special because they can be combined. For example:
|
||||
|
||||
* Combine Landsat8 and CDL to train a model for crop classification
|
||||
* Combine Sentinel2 and Chesapeake to train a model for land cover mapping
|
||||
|
||||
This isn't true for VisionDataset, where the lack of geospatial information
|
||||
prohibits swapping image sources or target labels.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def __getitem__(self, index: int) -> Dict[str, Any]:
|
||||
"""Return an index within the dataset.
|
||||
|
||||
Parameters:
|
||||
index: index to return
|
||||
|
||||
Returns:
|
||||
data and labels at that index
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def __len__(self) -> int:
|
||||
"""Return the length of the dataset.
|
||||
|
||||
Returns:
|
||||
length of the dataset
|
||||
"""
|
||||
pass
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return the informal string representation of the object.
|
||||
|
||||
Returns:
|
||||
informal string representation
|
||||
"""
|
||||
return f"""\
|
||||
{self.__class__.__name__} Dataset
|
||||
type: GeoDataset
|
||||
size: {len(self)}"""
|
Загрузка…
Ссылка в новой задаче