Adam J. Stewart 2021-06-11 19:41:12 +00:00
Parent 9e58dc6a63
Commit 39ea1be875
No key found matching this signature
GPG key ID: C66C0675661156FC
6 changed files: 88 additions and 57 deletions

View File: __init__.py

@@ -1,6 +1,6 @@
from .cowc import COWCCounting, COWCDetection
from .cv4a_kenya_crop_type import CV4AKenyaCropType
-from .geo import GeoDataset
+from .geo import GeoDataset, VisionDataset
from .landcoverai import LandCoverAI
from .nwpu import VHR10
@@ -11,4 +11,5 @@ __all__ = (
"GeoDataset",
"LandCoverAI",
"VHR10",
+"VisionDataset",
)

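The only change in this file re-exports the new VisionDataset base class from the package __init__. A minimal sketch of the resulting import, assuming the package is importable as torchgeo.datasets (the diff itself only shows relative imports):

# Hypothetical top-level import enabled by this hunk; GeoDataset and
# VisionDataset are sibling abstract base classes defined in the geo module.
from torchgeo.datasets import GeoDataset, VisionDataset
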
View File: cowc.py

@@ -22,15 +22,16 @@ import bz2
import csv
import os
import tarfile
-from typing import Any, Callable, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional
from PIL import Image
-from torchvision.datasets import VisionDataset
from torchvision.datasets.utils import (
check_integrity,
download_url,
)
+from .geo import VisionDataset
class _COWC(VisionDataset, abc.ABC):
"""Abstract base class for all COWC datasets."""
@@ -69,9 +70,7 @@ class _COWC(VisionDataset, abc.ABC):
self,
root: str = "data",
split: str = "train",
-transform: Optional[Callable[[Image.Image], Any]] = None,
-target_transform: Optional[Callable[[int], Any]] = None,
-transforms: Optional[Callable[[Image.Image, int], Tuple[Any, Any]]] = None,
+transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
download: bool = False,
) -> None:
"""Initialize a new COWC dataset instance.
@@ -79,10 +78,6 @@ class _COWC(VisionDataset, abc.ABC):
Parameters:
root: root directory where dataset can be found
split: one of "train" or "test"
-transform: a function/transform that takes in a PIL image and returns a
-transformed version
-target_transform: a function/transform that takes in the target and
-transforms it
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
download: if True, download dataset and store it in the root directory
@@ -94,7 +89,8 @@ class _COWC(VisionDataset, abc.ABC):
"""
assert split in ["train", "test"]
-super().__init__(root, transforms, transform, target_transform)
+self.root = root
+self.transforms = transforms
if download:
self.download()
@@ -116,7 +112,7 @@ class _COWC(VisionDataset, abc.ABC):
self.images.append(row[0])
self.targets.append(row[1])
-def __getitem__(self, index: int) -> Tuple[Any, Any]:
+def __getitem__(self, index: int) -> Dict[str, Any]:
"""Return an index within the dataset.
Parameters:
@@ -125,13 +121,15 @@ class _COWC(VisionDataset, abc.ABC):
Returns:
data and label at that index
"""
-image = self._load_image(index)
-target = int(self.targets[index])
+sample = {
+"image": self._load_image(index),
+"label": int(self.targets[index]),
+}
if self.transforms is not None:
-image, target = self.transforms(image, target)
+sample = self.transforms(sample)
-return image, target
+return sample
def __len__(self) -> int:
"""Return the number of data points in the dataset.

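With this change, _COWC subclasses such as COWCCounting build a single sample dict ({"image": ..., "label": ...}) and pass it to one transforms callable instead of separate transform/target_transform functions. A minimal sketch of a compatible transform, assuming _load_image returns a PIL image, that the import path is torchgeo.datasets, and that the data already exists under the given root:

import numpy as np
import torch
from typing import Any, Dict

from torchgeo.datasets import COWCCounting  # assumed import path

def sample_to_tensor(sample: Dict[str, Any]) -> Dict[str, Any]:
    # The transform receives the whole sample dict and must return a dict,
    # matching the new Callable[[Dict[str, Any]], Dict[str, Any]] signature.
    sample["image"] = torch.from_numpy(np.array(sample["image"]))
    sample["label"] = torch.tensor(sample["label"])
    return sample

dataset = COWCCounting(root="data", split="train", transforms=sample_to_tensor)
sample = dataset[0]
image, label = sample["image"], sample["label"]
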
View File: cv4a_kenya_crop_type.py

@@ -104,9 +104,7 @@ class CV4AKenyaCropType(GeoDataset):
chip_size: int = 256,
stride: int = 128,
bands: Tuple[str, ...] = band_names,
-transforms: Optional[
-Callable[[np.ndarray, np.ndarray], Tuple[Any, Any]]
-] = None,
+transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
download: bool = False,
api_key: Optional[str] = None,
verbose: bool = False,
@@ -182,16 +180,18 @@ class CV4AKenyaCropType(GeoDataset):
labels = labels[y : y + self.chip_size, x : x + self.chip_size]
field_ids = field_ids[y : y + self.chip_size, x : x + self.chip_size]
-if self.transforms is not None:
-img, labels = self.transforms(img, labels)
-return {
-"img": img,
-"labels": labels,
+sample = {
+"image": img,
+"mask": labels,
"field_ids": field_ids,
"metadata": (tile_index, y, x),
}
+if self.transforms is not None:
+sample = self.transforms(sample)
+return sample
def __len__(self) -> int:
"""Return the number of chips in the dataset.

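Besides moving the transforms call after the dict is assembled (so a transform can see every key at once), this hunk renames the per-chip keys "img" and "labels" to the standardized "image" and "mask". A hedged usage sketch, assuming the import path and root argument shown below and that the dataset files already exist on disk (download defaults to False):

from torchgeo.datasets import CV4AKenyaCropType  # assumed import path

ds = CV4AKenyaCropType(root="data")  # root argument assumed from the surrounding code
sample = ds[0]

image = sample["image"]          # previously under the key "img"
mask = sample["mask"]            # previously under the key "labels"
field_ids = sample["field_ids"]  # unchanged
tile_index, y, x = sample["metadata"]
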
View File: geo.py

@@ -53,3 +53,42 @@ class GeoDataset(Dataset[Dict[str, Any]], abc.ABC):
{self.__class__.__name__} Dataset
type: GeoDataset
size: {len(self)}"""
+class VisionDataset(Dataset[Dict[str, Any]], abc.ABC):
+"""Abstract base class for datasets lacking geospatial information.
+This base class is designed for datasets with pre-defined image chips.
+"""
+@abc.abstractmethod
+def __getitem__(self, index: int) -> Dict[str, Any]:
+"""Return an index within the dataset.
+Parameters:
+index: index to return
+Returns:
+data and labels at that index
+"""
+pass
+@abc.abstractmethod
+def __len__(self) -> int:
+"""Return the length of the dataset.
+Returns:
+length of the dataset
+"""
+pass
+def __str__(self) -> str:
+"""Return the informal string representation of the object.
+Returns:
+informal string representation
+"""
+return f"""\
+{self.__class__.__name__} Dataset
+type: VisionDataset
+size: {len(self)}"""

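The new VisionDataset ABC only requires __getitem__ and __len__; __str__ is inherited. A minimal sketch of a concrete subclass using a made-up in-memory dataset, with the import path assumed to be torchgeo.datasets:

from typing import Any, Dict, List

from torchgeo.datasets import VisionDataset  # assumed import path

class InMemoryChips(VisionDataset):
    """Toy subclass: implementing the two abstract methods is all that is required."""

    def __init__(self, chips: List[Dict[str, Any]]) -> None:
        self.chips = chips

    def __getitem__(self, index: int) -> Dict[str, Any]:
        return self.chips[index]

    def __len__(self) -> int:
        return len(self.chips)

ds = InMemoryChips([{"image": i, "label": i % 2} for i in range(4)])
print(ds)  # the inherited __str__ reports the class name, "type: VisionDataset", and size 4
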
View File: landcoverai.py

@@ -1,10 +1,10 @@
import os
-from typing import Any, Callable, Optional, Tuple
+from typing import Any, Callable, Dict, Optional
from PIL import Image
-from torchvision.datasets import VisionDataset
from torchvision.datasets.utils import check_integrity, download_and_extract_archive
+from .geo import VisionDataset
from .utils import working_dir
@@ -57,9 +57,7 @@ class LandCoverAI(VisionDataset):
self,
root: str = "data",
split: str = "train",
-transform: Optional[Callable[[Image.Image], Any]] = None,
-target_transform: Optional[Callable[[Image.Image], Any]] = None,
-transforms: Optional[Callable[[Image.Image, Image.Image], Any]] = None,
+transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
download: bool = False,
) -> None:
"""Initialize a new LandCover.ai dataset instance.
@@ -67,10 +65,6 @@ class LandCoverAI(VisionDataset):
Parameters:
root: root directory where dataset can be found
split: one of "train", "val", or "test"
-transform: a function/transform that takes in a PIL image and returns a
-transformed version
-target_transform: a function/transform that takes in the target and
-transforms it
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
download: if True, download dataset and store it in the root directory
@@ -82,7 +76,8 @@ class LandCoverAI(VisionDataset):
"""
assert split in ["train", "val", "test"]
-super().__init__(root, transforms, transform, target_transform)
+self.root = root
+self.transforms = transforms
if download:
self.download()
@@ -96,7 +91,7 @@ class LandCoverAI(VisionDataset):
with open(os.path.join(self.root, self.base_folder, split + ".txt")) as f:
self.ids = f.readlines()
-def __getitem__(self, index: int) -> Tuple[Any, Any]:
+def __getitem__(self, index: int) -> Dict[str, Any]:
"""Return an index within the dataset.
Parameters:
@@ -106,13 +101,15 @@
data and label at that index
"""
id_ = self.ids[index].rstrip()
-image = self._load_image(id_)
-target = self._load_target(id_)
+sample = {
+"image": self._load_image(id_),
+"mask": self._load_target(id_),
+}
if self.transforms is not None:
-image, target = self.transforms(image, target)
+sample = self.transforms(sample)
-return image, target
+return sample
def __len__(self) -> int:
"""Return the number of data points in the dataset.

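One benefit of the single dict-to-dict transforms callable is joint augmentation: image and mask travel together, so spatial transforms stay aligned. A minimal sketch, assuming _load_image and _load_target return PIL images, that the import path is torchgeo.datasets, and that the data already exists under "data":

import random
from typing import Any, Dict

from PIL import Image
from torchgeo.datasets import LandCoverAI  # assumed import path

def joint_hflip(sample: Dict[str, Any]) -> Dict[str, Any]:
    # Flip image and mask together so they remain pixel-aligned.
    if random.random() < 0.5:
        sample["image"] = sample["image"].transpose(Image.FLIP_LEFT_RIGHT)
        sample["mask"] = sample["mask"].transpose(Image.FLIP_LEFT_RIGHT)
    return sample

ds = LandCoverAI(root="data", split="train", transforms=joint_hflip)
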
View File: nwpu.py

@@ -1,14 +1,15 @@
import os
-from typing import Any, Callable, Dict, Optional, Tuple
+from typing import Any, Callable, Dict, Optional
from PIL import Image
-from torchvision.datasets import VisionDataset
from torchvision.datasets.utils import (
check_integrity,
download_file_from_google_drive,
download_url,
)
+from .geo import VisionDataset
class VHR10(VisionDataset):
"""Northwestern Polytechnical University (NWPU) very-high-resolution ten-class
@@ -74,11 +75,7 @@ class VHR10(VisionDataset):
self,
root: str = "data",
split: str = "positive",
-transform: Optional[Callable[[Image.Image], Any]] = None,
-target_transform: Optional[Callable[[Dict[str, Any]], Any]] = None,
-transforms: Optional[
-Callable[[Image.Image, Dict[str, Any]], Tuple[Any, Any]]
-] = None,
+transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
download: bool = False,
) -> None:
"""Initialize a new VHR-10 dataset instance.
@@ -86,10 +83,6 @@ class VHR10(VisionDataset):
Parameters:
root: root directory where dataset can be found
split: one of "postive" or "negative"
-transform: a function/transform that takes in a PIL image and returns a
-transformed version
-target_transform: a function/transform that takes in the target and
-transforms it
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
download: if True, download dataset and store it in the root directory
@@ -101,8 +94,9 @@ class VHR10(VisionDataset):
"""
assert split in ["positive", "negative"]
-super().__init__(root, transforms, transform, target_transform)
+self.root = root
self.split = split
+self.transforms = transforms
if download:
self.download()
@@ -126,7 +120,7 @@ class VHR10(VisionDataset):
)
)
-def __getitem__(self, index: int) -> Tuple[Any, Any]:
+def __getitem__(self, index: int) -> Dict[str, Any]:
"""Return an index within the dataset.
Parameters:
@@ -136,13 +130,15 @@ class VHR10(VisionDataset):
data and label at that index
"""
id_ = index % len(self) + 1
-image = self._load_image(id_)
-target = self._load_target(id_)
+sample = {
+"image": self._load_image(id_),
+"label": self._load_target(id_),
+}
if self.transforms is not None:
-image, target = self.transforms(image, target)
+sample = self.transforms(sample)
-return image, target
+return sample
def __len__(self) -> int:
"""Return the number of data points in the dataset.