This commit is contained in:
Adam J. Stewart 2021-07-16 20:20:13 +00:00
Родитель 53fb4b28b9
Коммит d7f2df061f
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: C66C0675661156FC
5 изменённых файлов: 54 добавлений и 2 удалений

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Просмотреть файл

@ -0,0 +1,40 @@
import os
from pathlib import Path
import pytest
import torch
from rasterio.crs import CRS
from torchgeo.datasets import BoundingBox, Landsat8, ZipDataset
from torchgeo.transforms import Identity
class TestLandsat8:
@pytest.fixture
def dataset(self) -> Landsat8:
root = os.path.join("tests", "data")
bands = ["B4", "B3", "B2"]
transforms = Identity()
return Landsat8(root, bands=bands, transforms=transforms)
def test_getitem(self, dataset: Landsat8) -> None:
print(dataset.bounds.intersects(dataset.bounds))
x = dataset[dataset.bounds]
assert isinstance(x, dict)
assert isinstance(x["crs"], CRS)
assert isinstance(x["image"], torch.Tensor)
def test_add(self, dataset: Landsat8) -> None:
ds = dataset + dataset
assert isinstance(ds, ZipDataset)
def test_no_data(self, tmp_path: Path) -> None:
with pytest.raises(FileNotFoundError, match="No Landsat data was found in "):
Landsat8(str(tmp_path))
def test_invalid_query(self, dataset: Landsat8) -> None:
query = BoundingBox(0, 0, 0, 0, 0, 0)
with pytest.raises(
IndexError, match="query: .* is not within bounds of the index:"
):
dataset[query]

Просмотреть файл

@ -58,6 +58,9 @@ class Landsat(GeoDataset, abc.ABC):
bands: bands to return (defaults to all bands)
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
Raises:
FileNotFoundError: if no files are found in ``root``
"""
self.root = root
self.crs = crs
@ -66,7 +69,8 @@ class Landsat(GeoDataset, abc.ABC):
# Create an R-tree to index the dataset
self.index = Index(interleaved=False, properties=Property(dimension=3))
fileglob = os.path.join(root, self.base_folder, f"**_{self.bands[0]}.TIF")
path = os.path.join(root, self.base_folder)
fileglob = os.path.join(path, f"**_{self.bands[0]}.TIF")
for i, filename in enumerate(glob.iglob(fileglob)):
# https://www.usgs.gov/faqs/what-naming-convention-landsat-collections-level-1-scenes
# https://www.usgs.gov/faqs/what-naming-convention-landsat-collection-2-level-1-and-level-2-scenes
@ -78,6 +82,9 @@ class Landsat(GeoDataset, abc.ABC):
coords = (minx, maxx, miny, maxy, timestamp, timestamp)
self.index.insert(i, coords, filename)
if "filename" not in locals():
raise FileNotFoundError(f"No Landsat data was found in '{path}'")
def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
"""Retrieve image and metadata indexed by query.
@ -113,11 +120,16 @@ class Landsat(GeoDataset, abc.ABC):
data_list.append(image)
image = np.concatenate(data_list) # type: ignore[no-untyped-call]
image = image.astype(np.int32)
return {
sample = {
"image": torch.tensor(image), # type: ignore[attr-defined]
"crs": self.crs,
}
if self.transforms is not None:
sample = self.transforms(sample)
return sample
class Landsat8(Landsat):
"""Landsat 8-9 Operational Land Imager (OLI) and Thermal Infrared Sensor (TIRS)."""