зеркало из https://github.com/microsoft/torchgeo.git
Add ZipDataset class
This commit is contained in:
Родитель
eacb5685f3
Коммит
1fbe5a04f6
|
@ -1,5 +1,5 @@
|
|||
import abc
|
||||
from typing import Any, Dict
|
||||
from typing import Any, Dict, Iterable
|
||||
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
@ -43,6 +43,17 @@ class GeoDataset(Dataset[Dict[str, Any]], abc.ABC):
|
|||
"""
|
||||
pass
|
||||
|
||||
def __add__(self, other: "GeoDataset") -> "ZipDataset": # type: ignore[override]
|
||||
"""Merge two GeoDatasets.
|
||||
|
||||
Parameters:
|
||||
other: another dataset
|
||||
|
||||
Returns:
|
||||
a single dataset
|
||||
"""
|
||||
return ZipDataset([self, other])
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return the informal string representation of the object.
|
||||
|
||||
|
@ -92,3 +103,56 @@ class VisionDataset(Dataset[Dict[str, Any]], abc.ABC):
|
|||
{self.__class__.__name__} Dataset
|
||||
type: VisionDataset
|
||||
size: {len(self)}"""
|
||||
|
||||
|
||||
class ZipDataset(GeoDataset):
|
||||
"""Dataset for merging two or more GeoDatasets.
|
||||
|
||||
For example, this allows you to combine an image source like Landsat8 with a target
|
||||
label like CDL.
|
||||
"""
|
||||
|
||||
def __init__(self, datasets: Iterable[GeoDataset]) -> None:
|
||||
"""Initialize a new Dataset instance.
|
||||
|
||||
Parameters:
|
||||
datasets: list of datasets to merge
|
||||
"""
|
||||
for ds in datasets:
|
||||
assert isinstance(ds, GeoDataset), "ZipDataset only supports GeoDatasets"
|
||||
|
||||
self.datasets = datasets
|
||||
|
||||
def __getitem__(self, index: int) -> Dict[str, Any]:
|
||||
"""Return an index within the dataset.
|
||||
|
||||
Parameters:
|
||||
index: index to return
|
||||
|
||||
Returns:
|
||||
data and labels at that index
|
||||
"""
|
||||
sample = {}
|
||||
for ds in self.datasets:
|
||||
sample.update(ds[index])
|
||||
return sample
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the length of the dataset.
|
||||
|
||||
Returns:
|
||||
length of the dataset
|
||||
"""
|
||||
# TODO: figure out how to handle this
|
||||
pass
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return the informal string representation of the object.
|
||||
|
||||
Returns:
|
||||
informal string representation
|
||||
"""
|
||||
return f"""\
|
||||
{self.__class__.__name__} Dataset
|
||||
type: ZipDataset
|
||||
size: {len(self)}"""
|
||||
|
|
Загрузка…
Ссылка в новой задаче