This commit is contained in:
Adam J. Stewart 2021-06-11 22:48:10 +00:00
Родитель eacb5685f3
Коммит 1fbe5a04f6
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: C66C0675661156FC
1 изменённых файлов: 65 добавлений и 1 удалений

Просмотреть файл

@ -1,5 +1,5 @@
import abc
from typing import Any, Dict
from typing import Any, Dict, Iterable
from torch.utils.data import Dataset
@ -43,6 +43,17 @@ class GeoDataset(Dataset[Dict[str, Any]], abc.ABC):
"""
pass
def __add__(self, other: "GeoDataset") -> "ZipDataset":  # type: ignore[override]
    """Combine this dataset with another one via the ``+`` operator.

    Parameters:
        other: the dataset to pair with this one

    Returns:
        a ZipDataset wrapping both datasets, with this dataset first
    """
    merged = ZipDataset([self, other])
    return merged
def __str__(self) -> str:
"""Return the informal string representation of the object.
@ -92,3 +103,56 @@ class VisionDataset(Dataset[Dict[str, Any]], abc.ABC):
{self.__class__.__name__} Dataset
type: VisionDataset
size: {len(self)}"""
class ZipDataset(GeoDataset):
    """Dataset for merging two or more GeoDatasets.

    For example, this allows you to combine an image source like Landsat8 with a target
    label like CDL.
    """

    def __init__(self, datasets: Iterable[GeoDataset]) -> None:
        """Initialize a new Dataset instance.

        Parameters:
            datasets: datasets to merge; any iterable is accepted and is copied
                into a list

        Raises:
            AssertionError: if any element is not a GeoDataset
        """
        # Materialize up front: a one-shot iterable (e.g. a generator) would be
        # exhausted by the validation loop below, leaving nothing to index later.
        datasets = list(datasets)
        for ds in datasets:
            # Kept as an assert so callers continue to see AssertionError.
            assert isinstance(ds, GeoDataset), "ZipDataset only supports GeoDatasets"
        self.datasets = datasets

    def __getitem__(self, index: int) -> Dict[str, Any]:
        """Return the merged sample at an index.

        Each dataset is indexed with the same ``index`` and the resulting
        dictionaries are merged in order, so on a key collision the value from the
        later dataset wins.

        Parameters:
            index: index to return

        Returns:
            data and labels at that index
        """
        sample: Dict[str, Any] = {}
        for ds in self.datasets:
            sample.update(ds[index])
        return sample

    def __len__(self) -> int:
        """Return the length of the dataset.

        Returns:
            the length of the shortest merged dataset, or 0 if there are none
        """
        # __getitem__ applies one integer index to every dataset, so only indices
        # valid for all of them are usable; this mirrors how zip() truncates.
        # NOTE(review): confirm shortest-wins is the intended policy for datasets
        # of unequal length.
        return min((len(ds) for ds in self.datasets), default=0)

    def __str__(self) -> str:
        """Return the informal string representation of the object.

        Returns:
            informal string representation
        """
        return f"""\
{self.__class__.__name__} Dataset
type: ZipDataset
size: {len(self)}"""