зеркало из https://github.com/microsoft/torchgeo.git
Add VisionDataset base class
This commit is contained in:
Родитель
9e58dc6a63
Коммит
39ea1be875
|
@ -1,6 +1,6 @@
|
|||
from .cowc import COWCCounting, COWCDetection
|
||||
from .cv4a_kenya_crop_type import CV4AKenyaCropType
|
||||
from .geo import GeoDataset
|
||||
from .geo import GeoDataset, VisionDataset
|
||||
from .landcoverai import LandCoverAI
|
||||
from .nwpu import VHR10
|
||||
|
||||
|
@ -11,4 +11,5 @@ __all__ = (
|
|||
"GeoDataset",
|
||||
"LandCoverAI",
|
||||
"VHR10",
|
||||
"VisionDataset",
|
||||
)
|
||||
|
|
|
@ -22,15 +22,16 @@ import bz2
|
|||
import csv
|
||||
import os
|
||||
import tarfile
|
||||
from typing import Any, Callable, List, Optional, Tuple
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
from PIL import Image
|
||||
from torchvision.datasets import VisionDataset
|
||||
from torchvision.datasets.utils import (
|
||||
check_integrity,
|
||||
download_url,
|
||||
)
|
||||
|
||||
from .geo import VisionDataset
|
||||
|
||||
|
||||
class _COWC(VisionDataset, abc.ABC):
|
||||
"""Abstract base class for all COWC datasets."""
|
||||
|
@ -69,9 +70,7 @@ class _COWC(VisionDataset, abc.ABC):
|
|||
self,
|
||||
root: str = "data",
|
||||
split: str = "train",
|
||||
transform: Optional[Callable[[Image.Image], Any]] = None,
|
||||
target_transform: Optional[Callable[[int], Any]] = None,
|
||||
transforms: Optional[Callable[[Image.Image, int], Tuple[Any, Any]]] = None,
|
||||
transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
|
||||
download: bool = False,
|
||||
) -> None:
|
||||
"""Initialize a new COWC dataset instance.
|
||||
|
@ -79,10 +78,6 @@ class _COWC(VisionDataset, abc.ABC):
|
|||
Parameters:
|
||||
root: root directory where dataset can be found
|
||||
split: one of "train" or "test"
|
||||
transform: a function/transform that takes in a PIL image and returns a
|
||||
transformed version
|
||||
target_transform: a function/transform that takes in the target and
|
||||
transforms it
|
||||
transforms: a function/transform that takes input sample and its target as
|
||||
entry and returns a transformed version
|
||||
download: if True, download dataset and store it in the root directory
|
||||
|
@ -94,7 +89,8 @@ class _COWC(VisionDataset, abc.ABC):
|
|||
"""
|
||||
assert split in ["train", "test"]
|
||||
|
||||
super().__init__(root, transforms, transform, target_transform)
|
||||
self.root = root
|
||||
self.transforms = transforms
|
||||
|
||||
if download:
|
||||
self.download()
|
||||
|
@ -116,7 +112,7 @@ class _COWC(VisionDataset, abc.ABC):
|
|||
self.images.append(row[0])
|
||||
self.targets.append(row[1])
|
||||
|
||||
def __getitem__(self, index: int) -> Tuple[Any, Any]:
|
||||
def __getitem__(self, index: int) -> Dict[str, Any]:
|
||||
"""Return an index within the dataset.
|
||||
|
||||
Parameters:
|
||||
|
@ -125,13 +121,15 @@ class _COWC(VisionDataset, abc.ABC):
|
|||
Returns:
|
||||
data and label at that index
|
||||
"""
|
||||
image = self._load_image(index)
|
||||
target = int(self.targets[index])
|
||||
sample = {
|
||||
"image": self._load_image(index),
|
||||
"label": int(self.targets[index]),
|
||||
}
|
||||
|
||||
if self.transforms is not None:
|
||||
image, target = self.transforms(image, target)
|
||||
sample = self.transforms(sample)
|
||||
|
||||
return image, target
|
||||
return sample
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of data points in the dataset.
|
||||
|
|
|
@ -104,9 +104,7 @@ class CV4AKenyaCropType(GeoDataset):
|
|||
chip_size: int = 256,
|
||||
stride: int = 128,
|
||||
bands: Tuple[str, ...] = band_names,
|
||||
transforms: Optional[
|
||||
Callable[[np.ndarray, np.ndarray], Tuple[Any, Any]]
|
||||
] = None,
|
||||
transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
|
||||
download: bool = False,
|
||||
api_key: Optional[str] = None,
|
||||
verbose: bool = False,
|
||||
|
@ -182,16 +180,18 @@ class CV4AKenyaCropType(GeoDataset):
|
|||
labels = labels[y : y + self.chip_size, x : x + self.chip_size]
|
||||
field_ids = field_ids[y : y + self.chip_size, x : x + self.chip_size]
|
||||
|
||||
if self.transforms is not None:
|
||||
img, labels = self.transforms(img, labels)
|
||||
|
||||
return {
|
||||
"img": img,
|
||||
"labels": labels,
|
||||
sample = {
|
||||
"image": img,
|
||||
"mask": labels,
|
||||
"field_ids": field_ids,
|
||||
"metadata": (tile_index, y, x),
|
||||
}
|
||||
|
||||
if self.transforms is not None:
|
||||
sample = self.transforms(sample)
|
||||
|
||||
return sample
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of chips in the dataset.
|
||||
|
||||
|
|
|
@ -53,3 +53,42 @@ class GeoDataset(Dataset[Dict[str, Any]], abc.ABC):
|
|||
{self.__class__.__name__} Dataset
|
||||
type: GeoDataset
|
||||
size: {len(self)}"""
|
||||
|
||||
|
||||
class VisionDataset(Dataset[Dict[str, Any]], abc.ABC):
    """Abstract base class for datasets lacking geospatial information.

    This base class is designed for datasets with pre-defined image chips.

    Concrete subclasses must implement ``__getitem__`` and ``__len__``;
    ``__str__`` is provided here and reports the subclass name together with
    the dataset length.
    """

    @abc.abstractmethod
    def __getitem__(self, index: int) -> Dict[str, Any]:
        """Return the sample stored at the given index.

        Parameters:
            index: index of the sample to return

        Returns:
            dictionary holding the data and labels at that index
        """

    @abc.abstractmethod
    def __len__(self) -> int:
        """Return the length of the dataset.

        Returns:
            length of the dataset
        """

    def __str__(self) -> str:
        """Return the informal string representation of the object.

        Returns:
            informal string representation
        """
        # The trailing backslash after the opening quotes suppresses the
        # leading newline so the text starts with the class name.
        # NOTE(review): the diff view this block was recovered from strips
        # leading whitespace, so any indentation of the "type:"/"size:" lines
        # inside the literal could not be seen -- confirm against the
        # repository before relying on the exact layout.
        return f"""\
{self.__class__.__name__} Dataset
type: VisionDataset
size: {len(self)}"""
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
import os
|
||||
from typing import Any, Callable, Optional, Tuple
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
from PIL import Image
|
||||
from torchvision.datasets import VisionDataset
|
||||
from torchvision.datasets.utils import check_integrity, download_and_extract_archive
|
||||
|
||||
from .geo import VisionDataset
|
||||
from .utils import working_dir
|
||||
|
||||
|
||||
|
@ -57,9 +57,7 @@ class LandCoverAI(VisionDataset):
|
|||
self,
|
||||
root: str = "data",
|
||||
split: str = "train",
|
||||
transform: Optional[Callable[[Image.Image], Any]] = None,
|
||||
target_transform: Optional[Callable[[Image.Image], Any]] = None,
|
||||
transforms: Optional[Callable[[Image.Image, Image.Image], Any]] = None,
|
||||
transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
|
||||
download: bool = False,
|
||||
) -> None:
|
||||
"""Initialize a new LandCover.ai dataset instance.
|
||||
|
@ -67,10 +65,6 @@ class LandCoverAI(VisionDataset):
|
|||
Parameters:
|
||||
root: root directory where dataset can be found
|
||||
split: one of "train", "val", or "test"
|
||||
transform: a function/transform that takes in a PIL image and returns a
|
||||
transformed version
|
||||
target_transform: a function/transform that takes in the target and
|
||||
transforms it
|
||||
transforms: a function/transform that takes input sample and its target as
|
||||
entry and returns a transformed version
|
||||
download: if True, download dataset and store it in the root directory
|
||||
|
@ -82,7 +76,8 @@ class LandCoverAI(VisionDataset):
|
|||
"""
|
||||
assert split in ["train", "val", "test"]
|
||||
|
||||
super().__init__(root, transforms, transform, target_transform)
|
||||
self.root = root
|
||||
self.transforms = transforms
|
||||
|
||||
if download:
|
||||
self.download()
|
||||
|
@ -96,7 +91,7 @@ class LandCoverAI(VisionDataset):
|
|||
with open(os.path.join(self.root, self.base_folder, split + ".txt")) as f:
|
||||
self.ids = f.readlines()
|
||||
|
||||
def __getitem__(self, index: int) -> Tuple[Any, Any]:
|
||||
def __getitem__(self, index: int) -> Dict[str, Any]:
|
||||
"""Return an index within the dataset.
|
||||
|
||||
Parameters:
|
||||
|
@ -106,13 +101,15 @@ class LandCoverAI(VisionDataset):
|
|||
data and label at that index
|
||||
"""
|
||||
id_ = self.ids[index].rstrip()
|
||||
image = self._load_image(id_)
|
||||
target = self._load_target(id_)
|
||||
sample = {
|
||||
"image": self._load_image(id_),
|
||||
"mask": self._load_target(id_),
|
||||
}
|
||||
|
||||
if self.transforms is not None:
|
||||
image, target = self.transforms(image, target)
|
||||
sample = self.transforms(sample)
|
||||
|
||||
return image, target
|
||||
return sample
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of data points in the dataset.
|
||||
|
|
|
@ -1,14 +1,15 @@
|
|||
import os
|
||||
from typing import Any, Callable, Dict, Optional, Tuple
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
from PIL import Image
|
||||
from torchvision.datasets import VisionDataset
|
||||
from torchvision.datasets.utils import (
|
||||
check_integrity,
|
||||
download_file_from_google_drive,
|
||||
download_url,
|
||||
)
|
||||
|
||||
from .geo import VisionDataset
|
||||
|
||||
|
||||
class VHR10(VisionDataset):
|
||||
"""Northwestern Polytechnical University (NWPU) very-high-resolution ten-class
|
||||
|
@ -74,11 +75,7 @@ class VHR10(VisionDataset):
|
|||
self,
|
||||
root: str = "data",
|
||||
split: str = "positive",
|
||||
transform: Optional[Callable[[Image.Image], Any]] = None,
|
||||
target_transform: Optional[Callable[[Dict[str, Any]], Any]] = None,
|
||||
transforms: Optional[
|
||||
Callable[[Image.Image, Dict[str, Any]], Tuple[Any, Any]]
|
||||
] = None,
|
||||
transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
|
||||
download: bool = False,
|
||||
) -> None:
|
||||
"""Initialize a new VHR-10 dataset instance.
|
||||
|
@ -86,10 +83,6 @@ class VHR10(VisionDataset):
|
|||
Parameters:
|
||||
root: root directory where dataset can be found
|
||||
split: one of "positive" or "negative"
|
||||
transform: a function/transform that takes in a PIL image and returns a
|
||||
transformed version
|
||||
target_transform: a function/transform that takes in the target and
|
||||
transforms it
|
||||
transforms: a function/transform that takes input sample and its target as
|
||||
entry and returns a transformed version
|
||||
download: if True, download dataset and store it in the root directory
|
||||
|
@ -101,8 +94,9 @@ class VHR10(VisionDataset):
|
|||
"""
|
||||
assert split in ["positive", "negative"]
|
||||
|
||||
super().__init__(root, transforms, transform, target_transform)
|
||||
self.root = root
|
||||
self.split = split
|
||||
self.transforms = transforms
|
||||
|
||||
if download:
|
||||
self.download()
|
||||
|
@ -126,7 +120,7 @@ class VHR10(VisionDataset):
|
|||
)
|
||||
)
|
||||
|
||||
def __getitem__(self, index: int) -> Tuple[Any, Any]:
|
||||
def __getitem__(self, index: int) -> Dict[str, Any]:
|
||||
"""Return an index within the dataset.
|
||||
|
||||
Parameters:
|
||||
|
@ -136,13 +130,15 @@ class VHR10(VisionDataset):
|
|||
data and label at that index
|
||||
"""
|
||||
id_ = index % len(self) + 1
|
||||
image = self._load_image(id_)
|
||||
target = self._load_target(id_)
|
||||
sample = {
|
||||
"image": self._load_image(id_),
|
||||
"label": self._load_target(id_),
|
||||
}
|
||||
|
||||
if self.transforms is not None:
|
||||
image, target = self.transforms(image, target)
|
||||
sample = self.transforms(sample)
|
||||
|
||||
return image, target
|
||||
return sample
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of data points in the dataset.
|
||||
|
|
Загрузка…
Ссылка в новой задаче