зеркало из https://github.com/microsoft/torchgeo.git
Add VisionDataset base class
This commit is contained in:
Родитель
9e58dc6a63
Коммит
39ea1be875
|
@ -1,6 +1,6 @@
|
||||||
from .cowc import COWCCounting, COWCDetection
|
from .cowc import COWCCounting, COWCDetection
|
||||||
from .cv4a_kenya_crop_type import CV4AKenyaCropType
|
from .cv4a_kenya_crop_type import CV4AKenyaCropType
|
||||||
from .geo import GeoDataset
|
from .geo import GeoDataset, VisionDataset
|
||||||
from .landcoverai import LandCoverAI
|
from .landcoverai import LandCoverAI
|
||||||
from .nwpu import VHR10
|
from .nwpu import VHR10
|
||||||
|
|
||||||
|
@ -11,4 +11,5 @@ __all__ = (
|
||||||
"GeoDataset",
|
"GeoDataset",
|
||||||
"LandCoverAI",
|
"LandCoverAI",
|
||||||
"VHR10",
|
"VHR10",
|
||||||
|
"VisionDataset",
|
||||||
)
|
)
|
||||||
|
|
|
@ -22,15 +22,16 @@ import bz2
|
||||||
import csv
|
import csv
|
||||||
import os
|
import os
|
||||||
import tarfile
|
import tarfile
|
||||||
from typing import Any, Callable, List, Optional, Tuple
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from torchvision.datasets import VisionDataset
|
|
||||||
from torchvision.datasets.utils import (
|
from torchvision.datasets.utils import (
|
||||||
check_integrity,
|
check_integrity,
|
||||||
download_url,
|
download_url,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from .geo import VisionDataset
|
||||||
|
|
||||||
|
|
||||||
class _COWC(VisionDataset, abc.ABC):
|
class _COWC(VisionDataset, abc.ABC):
|
||||||
"""Abstract base class for all COWC datasets."""
|
"""Abstract base class for all COWC datasets."""
|
||||||
|
@ -69,9 +70,7 @@ class _COWC(VisionDataset, abc.ABC):
|
||||||
self,
|
self,
|
||||||
root: str = "data",
|
root: str = "data",
|
||||||
split: str = "train",
|
split: str = "train",
|
||||||
transform: Optional[Callable[[Image.Image], Any]] = None,
|
transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
|
||||||
target_transform: Optional[Callable[[int], Any]] = None,
|
|
||||||
transforms: Optional[Callable[[Image.Image, int], Tuple[Any, Any]]] = None,
|
|
||||||
download: bool = False,
|
download: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize a new COWC dataset instance.
|
"""Initialize a new COWC dataset instance.
|
||||||
|
@ -79,10 +78,6 @@ class _COWC(VisionDataset, abc.ABC):
|
||||||
Parameters:
|
Parameters:
|
||||||
root: root directory where dataset can be found
|
root: root directory where dataset can be found
|
||||||
split: one of "train" or "test"
|
split: one of "train" or "test"
|
||||||
transform: a function/transform that takes in a PIL image and returns a
|
|
||||||
transformed version
|
|
||||||
target_transform: a function/transform that takes in the target and
|
|
||||||
transforms it
|
|
||||||
transforms: a function/transform that takes input sample and its target as
|
transforms: a function/transform that takes input sample and its target as
|
||||||
entry and returns a transformed version
|
entry and returns a transformed version
|
||||||
download: if True, download dataset and store it in the root directory
|
download: if True, download dataset and store it in the root directory
|
||||||
|
@ -94,7 +89,8 @@ class _COWC(VisionDataset, abc.ABC):
|
||||||
"""
|
"""
|
||||||
assert split in ["train", "test"]
|
assert split in ["train", "test"]
|
||||||
|
|
||||||
super().__init__(root, transforms, transform, target_transform)
|
self.root = root
|
||||||
|
self.transforms = transforms
|
||||||
|
|
||||||
if download:
|
if download:
|
||||||
self.download()
|
self.download()
|
||||||
|
@ -116,7 +112,7 @@ class _COWC(VisionDataset, abc.ABC):
|
||||||
self.images.append(row[0])
|
self.images.append(row[0])
|
||||||
self.targets.append(row[1])
|
self.targets.append(row[1])
|
||||||
|
|
||||||
def __getitem__(self, index: int) -> Tuple[Any, Any]:
|
def __getitem__(self, index: int) -> Dict[str, Any]:
|
||||||
"""Return an index within the dataset.
|
"""Return an index within the dataset.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
|
@ -125,13 +121,15 @@ class _COWC(VisionDataset, abc.ABC):
|
||||||
Returns:
|
Returns:
|
||||||
data and label at that index
|
data and label at that index
|
||||||
"""
|
"""
|
||||||
image = self._load_image(index)
|
sample = {
|
||||||
target = int(self.targets[index])
|
"image": self._load_image(index),
|
||||||
|
"label": int(self.targets[index]),
|
||||||
|
}
|
||||||
|
|
||||||
if self.transforms is not None:
|
if self.transforms is not None:
|
||||||
image, target = self.transforms(image, target)
|
sample = self.transforms(sample)
|
||||||
|
|
||||||
return image, target
|
return sample
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
"""Return the number of data points in the dataset.
|
"""Return the number of data points in the dataset.
|
||||||
|
|
|
@ -104,9 +104,7 @@ class CV4AKenyaCropType(GeoDataset):
|
||||||
chip_size: int = 256,
|
chip_size: int = 256,
|
||||||
stride: int = 128,
|
stride: int = 128,
|
||||||
bands: Tuple[str, ...] = band_names,
|
bands: Tuple[str, ...] = band_names,
|
||||||
transforms: Optional[
|
transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
|
||||||
Callable[[np.ndarray, np.ndarray], Tuple[Any, Any]]
|
|
||||||
] = None,
|
|
||||||
download: bool = False,
|
download: bool = False,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
|
@ -182,16 +180,18 @@ class CV4AKenyaCropType(GeoDataset):
|
||||||
labels = labels[y : y + self.chip_size, x : x + self.chip_size]
|
labels = labels[y : y + self.chip_size, x : x + self.chip_size]
|
||||||
field_ids = field_ids[y : y + self.chip_size, x : x + self.chip_size]
|
field_ids = field_ids[y : y + self.chip_size, x : x + self.chip_size]
|
||||||
|
|
||||||
if self.transforms is not None:
|
sample = {
|
||||||
img, labels = self.transforms(img, labels)
|
"image": img,
|
||||||
|
"mask": labels,
|
||||||
return {
|
|
||||||
"img": img,
|
|
||||||
"labels": labels,
|
|
||||||
"field_ids": field_ids,
|
"field_ids": field_ids,
|
||||||
"metadata": (tile_index, y, x),
|
"metadata": (tile_index, y, x),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if self.transforms is not None:
|
||||||
|
sample = self.transforms(sample)
|
||||||
|
|
||||||
|
return sample
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
"""Return the number of chips in the dataset.
|
"""Return the number of chips in the dataset.
|
||||||
|
|
||||||
|
|
|
@ -53,3 +53,42 @@ class GeoDataset(Dataset[Dict[str, Any]], abc.ABC):
|
||||||
{self.__class__.__name__} Dataset
|
{self.__class__.__name__} Dataset
|
||||||
type: GeoDataset
|
type: GeoDataset
|
||||||
size: {len(self)}"""
|
size: {len(self)}"""
|
||||||
|
|
||||||
|
|
||||||
|
class VisionDataset(Dataset[Dict[str, Any]], abc.ABC):
|
||||||
|
"""Abstract base class for datasets lacking geospatial information.
|
||||||
|
|
||||||
|
This base class is designed for datasets with pre-defined image chips.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def __getitem__(self, index: int) -> Dict[str, Any]:
|
||||||
|
"""Return an index within the dataset.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
index: index to return
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
data and labels at that index
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def __len__(self) -> int:
|
||||||
|
"""Return the length of the dataset.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
length of the dataset
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
"""Return the informal string representation of the object.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
informal string representation
|
||||||
|
"""
|
||||||
|
return f"""\
|
||||||
|
{self.__class__.__name__} Dataset
|
||||||
|
type: VisionDataset
|
||||||
|
size: {len(self)}"""
|
||||||
|
|
|
@ -1,10 +1,10 @@
|
||||||
import os
|
import os
|
||||||
from typing import Any, Callable, Optional, Tuple
|
from typing import Any, Callable, Dict, Optional
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from torchvision.datasets import VisionDataset
|
|
||||||
from torchvision.datasets.utils import check_integrity, download_and_extract_archive
|
from torchvision.datasets.utils import check_integrity, download_and_extract_archive
|
||||||
|
|
||||||
|
from .geo import VisionDataset
|
||||||
from .utils import working_dir
|
from .utils import working_dir
|
||||||
|
|
||||||
|
|
||||||
|
@ -57,9 +57,7 @@ class LandCoverAI(VisionDataset):
|
||||||
self,
|
self,
|
||||||
root: str = "data",
|
root: str = "data",
|
||||||
split: str = "train",
|
split: str = "train",
|
||||||
transform: Optional[Callable[[Image.Image], Any]] = None,
|
transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
|
||||||
target_transform: Optional[Callable[[Image.Image], Any]] = None,
|
|
||||||
transforms: Optional[Callable[[Image.Image, Image.Image], Any]] = None,
|
|
||||||
download: bool = False,
|
download: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize a new LandCover.ai dataset instance.
|
"""Initialize a new LandCover.ai dataset instance.
|
||||||
|
@ -67,10 +65,6 @@ class LandCoverAI(VisionDataset):
|
||||||
Parameters:
|
Parameters:
|
||||||
root: root directory where dataset can be found
|
root: root directory where dataset can be found
|
||||||
split: one of "train", "val", or "test"
|
split: one of "train", "val", or "test"
|
||||||
transform: a function/transform that takes in a PIL image and returns a
|
|
||||||
transformed version
|
|
||||||
target_transform: a function/transform that takes in the target and
|
|
||||||
transforms it
|
|
||||||
transforms: a function/transform that takes input sample and its target as
|
transforms: a function/transform that takes input sample and its target as
|
||||||
entry and returns a transformed version
|
entry and returns a transformed version
|
||||||
download: if True, download dataset and store it in the root directory
|
download: if True, download dataset and store it in the root directory
|
||||||
|
@ -82,7 +76,8 @@ class LandCoverAI(VisionDataset):
|
||||||
"""
|
"""
|
||||||
assert split in ["train", "val", "test"]
|
assert split in ["train", "val", "test"]
|
||||||
|
|
||||||
super().__init__(root, transforms, transform, target_transform)
|
self.root = root
|
||||||
|
self.transforms = transforms
|
||||||
|
|
||||||
if download:
|
if download:
|
||||||
self.download()
|
self.download()
|
||||||
|
@ -96,7 +91,7 @@ class LandCoverAI(VisionDataset):
|
||||||
with open(os.path.join(self.root, self.base_folder, split + ".txt")) as f:
|
with open(os.path.join(self.root, self.base_folder, split + ".txt")) as f:
|
||||||
self.ids = f.readlines()
|
self.ids = f.readlines()
|
||||||
|
|
||||||
def __getitem__(self, index: int) -> Tuple[Any, Any]:
|
def __getitem__(self, index: int) -> Dict[str, Any]:
|
||||||
"""Return an index within the dataset.
|
"""Return an index within the dataset.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
|
@ -106,13 +101,15 @@ class LandCoverAI(VisionDataset):
|
||||||
data and label at that index
|
data and label at that index
|
||||||
"""
|
"""
|
||||||
id_ = self.ids[index].rstrip()
|
id_ = self.ids[index].rstrip()
|
||||||
image = self._load_image(id_)
|
sample = {
|
||||||
target = self._load_target(id_)
|
"image": self._load_image(id_),
|
||||||
|
"mask": self._load_target(id_),
|
||||||
|
}
|
||||||
|
|
||||||
if self.transforms is not None:
|
if self.transforms is not None:
|
||||||
image, target = self.transforms(image, target)
|
sample = self.transforms(sample)
|
||||||
|
|
||||||
return image, target
|
return sample
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
"""Return the number of data points in the dataset.
|
"""Return the number of data points in the dataset.
|
||||||
|
|
|
@ -1,14 +1,15 @@
|
||||||
import os
|
import os
|
||||||
from typing import Any, Callable, Dict, Optional, Tuple
|
from typing import Any, Callable, Dict, Optional
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from torchvision.datasets import VisionDataset
|
|
||||||
from torchvision.datasets.utils import (
|
from torchvision.datasets.utils import (
|
||||||
check_integrity,
|
check_integrity,
|
||||||
download_file_from_google_drive,
|
download_file_from_google_drive,
|
||||||
download_url,
|
download_url,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from .geo import VisionDataset
|
||||||
|
|
||||||
|
|
||||||
class VHR10(VisionDataset):
|
class VHR10(VisionDataset):
|
||||||
"""Northwestern Polytechnical University (NWPU) very-high-resolution ten-class
|
"""Northwestern Polytechnical University (NWPU) very-high-resolution ten-class
|
||||||
|
@ -74,11 +75,7 @@ class VHR10(VisionDataset):
|
||||||
self,
|
self,
|
||||||
root: str = "data",
|
root: str = "data",
|
||||||
split: str = "positive",
|
split: str = "positive",
|
||||||
transform: Optional[Callable[[Image.Image], Any]] = None,
|
transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
|
||||||
target_transform: Optional[Callable[[Dict[str, Any]], Any]] = None,
|
|
||||||
transforms: Optional[
|
|
||||||
Callable[[Image.Image, Dict[str, Any]], Tuple[Any, Any]]
|
|
||||||
] = None,
|
|
||||||
download: bool = False,
|
download: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize a new VHR-10 dataset instance.
|
"""Initialize a new VHR-10 dataset instance.
|
||||||
|
@ -86,10 +83,6 @@ class VHR10(VisionDataset):
|
||||||
Parameters:
|
Parameters:
|
||||||
root: root directory where dataset can be found
|
root: root directory where dataset can be found
|
||||||
split: one of "postive" or "negative"
|
split: one of "postive" or "negative"
|
||||||
transform: a function/transform that takes in a PIL image and returns a
|
|
||||||
transformed version
|
|
||||||
target_transform: a function/transform that takes in the target and
|
|
||||||
transforms it
|
|
||||||
transforms: a function/transform that takes input sample and its target as
|
transforms: a function/transform that takes input sample and its target as
|
||||||
entry and returns a transformed version
|
entry and returns a transformed version
|
||||||
download: if True, download dataset and store it in the root directory
|
download: if True, download dataset and store it in the root directory
|
||||||
|
@ -101,8 +94,9 @@ class VHR10(VisionDataset):
|
||||||
"""
|
"""
|
||||||
assert split in ["positive", "negative"]
|
assert split in ["positive", "negative"]
|
||||||
|
|
||||||
super().__init__(root, transforms, transform, target_transform)
|
self.root = root
|
||||||
self.split = split
|
self.split = split
|
||||||
|
self.transforms = transforms
|
||||||
|
|
||||||
if download:
|
if download:
|
||||||
self.download()
|
self.download()
|
||||||
|
@ -126,7 +120,7 @@ class VHR10(VisionDataset):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def __getitem__(self, index: int) -> Tuple[Any, Any]:
|
def __getitem__(self, index: int) -> Dict[str, Any]:
|
||||||
"""Return an index within the dataset.
|
"""Return an index within the dataset.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
|
@ -136,13 +130,15 @@ class VHR10(VisionDataset):
|
||||||
data and label at that index
|
data and label at that index
|
||||||
"""
|
"""
|
||||||
id_ = index % len(self) + 1
|
id_ = index % len(self) + 1
|
||||||
image = self._load_image(id_)
|
sample = {
|
||||||
target = self._load_target(id_)
|
"image": self._load_image(id_),
|
||||||
|
"label": self._load_target(id_),
|
||||||
|
}
|
||||||
|
|
||||||
if self.transforms is not None:
|
if self.transforms is not None:
|
||||||
image, target = self.transforms(image, target)
|
sample = self.transforms(sample)
|
||||||
|
|
||||||
return image, target
|
return sample
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
"""Return the number of data points in the dataset.
|
"""Return the number of data points in the dataset.
|
||||||
|
|
Загрузка…
Ссылка в новой задаче