зеркало из https://github.com/microsoft/torchgeo.git
Add IDTReeS dataset (#201)
* add IDTReeS dataset * dataset loads data now * add optional laspy and pandas dependencies * fixed docs failing * format * refactor verify and resample chm/hsi to 200x200 * add open3d optional dep * overhaul * temporarily remove open3d install bc their pypi is broken * mypy fixes * fixes per suggestions * general cleanup * test passing * add min version for laspy and pandas * add open3d dependency * add open3d to mypy tests * add hard install for python 3.9 open3d to actions * attempt #2 * I think I got it now * updated tests.yaml * make open3d dep require python<3.9 * open3d has issues with macos python 3.6 * same for 3.7 * skip open3d plot test for macos * formatting * skip open3d plot test for windows * update per suggestions * update test data readme for las files * updated per suggestions * more changes per suggestions * last change per suggestion * Grammar fix in pandas dep requirement comment Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
Родитель
fcbd1ab6c3
Коммит
0434f3c1ce
|
@ -118,6 +118,11 @@ GID-15 (Gaofen Image Dataset)
|
|||
|
||||
.. autoclass:: GID15
|
||||
|
||||
IDTReeS
|
||||
^^^^^^^
|
||||
|
||||
.. autoclass:: IDTReeS
|
||||
|
||||
LandCover.ai (Land Cover from Aerial Imagery)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
name: torchgeo
|
||||
channels:
|
||||
- conda-forge
|
||||
- open3d-admin
|
||||
dependencies:
|
||||
- cudatoolkit
|
||||
- einops
|
||||
|
@ -23,11 +24,14 @@ dependencies:
|
|||
- isort[colors]>=5.8
|
||||
- jupyterlab
|
||||
- kornia>=0.5.4
|
||||
- laspy>=2.0.0
|
||||
- mypy>=0.900
|
||||
- nbmake>=0.1
|
||||
- nbsphinx>=0.8.5
|
||||
- omegaconf>=2.1
|
||||
- open3d>=0.11.2
|
||||
- opencv-python
|
||||
- pandas>=0.19.1
|
||||
- pillow>=2.9
|
||||
- pydocstyle[toml]>=6.1
|
||||
- pytest>=6
|
||||
|
|
|
@ -65,7 +65,14 @@ include = torchgeo*
|
|||
# Optional dataset requirements
|
||||
datasets =
|
||||
h5py
|
||||
# loading .las point clouds (idtrees) laspy 2+ required for Python 3.6+ support
|
||||
laspy>=2.0.0
|
||||
# open3d will add add support for python 3.9 in pypi in v0.14. v0.11.2 last version for tests to pass
|
||||
# https://github.com/isl-org/Open3D/issues/1550
|
||||
open3d>=0.11.2;python_version<'3.9'
|
||||
opencv-python
|
||||
# pandas 0.19.1+ required for python 3.6 support
|
||||
pandas>=0.19.1
|
||||
pycocotools
|
||||
# radiant-mlhub 0.2.1+ required for api_key bugfix:
|
||||
# https://github.com/radiantearth/radiant-mlhub/pull/48
|
||||
|
|
|
@ -91,3 +91,28 @@ masks = np.random.randint(low=0, high=num_classes, size=(1, 1)).astype(np.uint8)
|
|||
f.create_dataset("images", data=images)
|
||||
f.create_dataset("masks", data=masks)
|
||||
f.close()
|
||||
```
|
||||
|
||||
### LAS Point Cloud files
|
||||
|
||||
```python
|
||||
import laspy
|
||||
|
||||
num_points = 4
|
||||
|
||||
las = laspy.read("0.las")
|
||||
las.points = las.points[:num_points]
|
||||
|
||||
points = np.random.randint(low=0, high=100, size=(num_points,), dtype=las.x.dtype)
|
||||
las.x = points
|
||||
las.y = points
|
||||
las.z = points
|
||||
|
||||
if hasattr(las, "red"):
|
||||
colors = np.random.randint(low=0, high=10, size=(num_points,), dtype=las.red.dtype)
|
||||
las.red = colors
|
||||
las.green = colors
|
||||
las.blue = colors
|
||||
|
||||
las.write("0.las")
|
||||
```
|
||||
|
|
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
|
@ -0,0 +1,162 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import builtins
|
||||
import glob
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Generator
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from _pytest.fixtures import SubRequest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import IDTReeS
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
shutil.copy(url, root)
|
||||
|
||||
|
||||
class TestIDTReeS:
|
||||
@pytest.fixture(params=zip(["train", "test", "test"], ["task1", "task1", "task2"]))
|
||||
def dataset(
|
||||
self,
|
||||
monkeypatch: Generator[MonkeyPatch, None, None],
|
||||
tmp_path: Path,
|
||||
request: SubRequest,
|
||||
) -> IDTReeS:
|
||||
pytest.importorskip("pandas")
|
||||
pytest.importorskip("laspy")
|
||||
monkeypatch.setattr( # type: ignore[attr-defined]
|
||||
torchgeo.datasets.idtrees, "download_url", download_url
|
||||
)
|
||||
data_dir = os.path.join("tests", "data", "idtrees")
|
||||
metadata = {
|
||||
"train": {
|
||||
"url": os.path.join(data_dir, "IDTREES_competition_train_v2.zip"),
|
||||
"md5": "5ddfa76240b4bb6b4a7861d1d31c299c",
|
||||
"filename": "IDTREES_competition_train_v2.zip",
|
||||
},
|
||||
"test": {
|
||||
"url": os.path.join(data_dir, "IDTREES_competition_test_v2.zip"),
|
||||
"md5": "b108931c84a70f2a38a8234290131c9b",
|
||||
"filename": "IDTREES_competition_test_v2.zip",
|
||||
},
|
||||
}
|
||||
split, task = request.param
|
||||
monkeypatch.setattr(IDTReeS, "metadata", metadata) # type: ignore[attr-defined]
|
||||
root = str(tmp_path)
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return IDTReeS(root, split, task, transforms, download=True, checksum=True)
|
||||
|
||||
@pytest.fixture(params=["pandas", "laspy", "open3d"])
|
||||
def mock_missing_module(
|
||||
self, monkeypatch: Generator[MonkeyPatch, None, None], request: SubRequest
|
||||
) -> str:
|
||||
import_orig = builtins.__import__
|
||||
package = str(request.param)
|
||||
|
||||
def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
|
||||
if name == package:
|
||||
raise ImportError()
|
||||
return import_orig(name, *args, **kwargs)
|
||||
|
||||
monkeypatch.setattr( # type: ignore[attr-defined]
|
||||
builtins, "__import__", mocked_import
|
||||
)
|
||||
return package
|
||||
|
||||
def test_getitem(self, dataset: IDTReeS) -> None:
|
||||
x = dataset[0]
|
||||
assert isinstance(x, dict)
|
||||
assert isinstance(x["image"], torch.Tensor)
|
||||
assert isinstance(x["chm"], torch.Tensor)
|
||||
assert isinstance(x["hsi"], torch.Tensor)
|
||||
assert isinstance(x["las"], torch.Tensor)
|
||||
assert x["image"].shape == (3, 200, 200)
|
||||
assert x["chm"].shape == (1, 200, 200)
|
||||
assert x["hsi"].shape == (369, 200, 200)
|
||||
assert x["las"].ndim == 2
|
||||
assert x["las"].shape[0] == 3
|
||||
|
||||
if "label" in x:
|
||||
assert isinstance(x["label"], torch.Tensor)
|
||||
if "boxes" in x:
|
||||
assert isinstance(x["boxes"], torch.Tensor)
|
||||
if x["boxes"].ndim != 1:
|
||||
assert x["boxes"].ndim == 2
|
||||
assert x["boxes"].shape[-1] == 4
|
||||
|
||||
def test_len(self, dataset: IDTReeS) -> None:
|
||||
assert len(dataset) == 3
|
||||
|
||||
def test_already_downloaded(self, dataset: IDTReeS) -> None:
|
||||
IDTReeS(root=dataset.root, download=True)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
err = "Dataset not found in `root` directory and `download=False`, "
|
||||
"either specify a different `root` directory or use `download=True` "
|
||||
"to automaticaly download the dataset."
|
||||
with pytest.raises(RuntimeError, match=err):
|
||||
IDTReeS(str(tmp_path))
|
||||
|
||||
def test_not_extracted(self, tmp_path: Path) -> None:
|
||||
pathname = os.path.join("tests", "data", "idtrees", "*.zip")
|
||||
root = str(tmp_path)
|
||||
for zipfile in glob.iglob(pathname):
|
||||
shutil.copy(zipfile, root)
|
||||
IDTReeS(root)
|
||||
|
||||
def test_mock_missing_module(
|
||||
self, dataset: IDTReeS, mock_missing_module: str
|
||||
) -> None:
|
||||
package = mock_missing_module
|
||||
|
||||
if package in ["pandas", "laspy"]:
|
||||
with pytest.raises(
|
||||
ImportError,
|
||||
match=f"{package} is not installed and is required to use this dataset",
|
||||
):
|
||||
IDTReeS(dataset.root, download=True, checksum=True)
|
||||
else:
|
||||
with pytest.raises(
|
||||
ImportError,
|
||||
match=f"{package} is not installed and is required to use this dataset",
|
||||
):
|
||||
dataset.plot_las(0)
|
||||
|
||||
def test_plot(self, dataset: IDTReeS) -> None:
|
||||
x = dataset[0].copy()
|
||||
dataset.plot(x, suptitle="Test")
|
||||
plt.close()
|
||||
dataset.plot(x, show_titles=False)
|
||||
plt.close()
|
||||
|
||||
if "boxes" in x:
|
||||
x["prediction_boxes"] = x["boxes"]
|
||||
dataset.plot(x, show_titles=True)
|
||||
plt.close()
|
||||
if "label" in x:
|
||||
x["prediction_label"] = x["label"]
|
||||
dataset.plot(x, show_titles=False)
|
||||
plt.close()
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform in ["darwin", "win32"],
|
||||
reason="segmentation fault on macOS and windows",
|
||||
)
|
||||
def test_plot_las(self, dataset: IDTReeS) -> None:
|
||||
pytest.importorskip("open3d")
|
||||
vis = dataset.plot_las(index=0, colormap="BrBG")
|
||||
vis.close()
|
||||
vis = dataset.plot_las(index=0, colormap=None)
|
||||
vis.close()
|
||||
vis = dataset.plot_las(index=1, colormap=None)
|
||||
vis.close()
|
|
@ -37,6 +37,7 @@ from .geo import (
|
|||
VisionDataset,
|
||||
)
|
||||
from .gid15 import GID15
|
||||
from .idtrees import IDTReeS
|
||||
from .landcoverai import LandCoverAI, LandCoverAIDataModule
|
||||
from .landsat import (
|
||||
Landsat,
|
||||
|
@ -115,6 +116,7 @@ __all__ = (
|
|||
"EuroSAT",
|
||||
"EuroSATDataModule",
|
||||
"GID15",
|
||||
"IDTReeS",
|
||||
"LandCoverAI",
|
||||
"LandCoverAIDataModule",
|
||||
"LEVIRCDPlus",
|
||||
|
|
|
@ -0,0 +1,553 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""IDTReeS dataset."""
|
||||
|
||||
import glob
|
||||
import os
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import fiona
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import rasterio
|
||||
import torch
|
||||
from rasterio.enums import Resampling
|
||||
from torch import Tensor
|
||||
from torchvision.utils import draw_bounding_boxes
|
||||
|
||||
from .geo import VisionDataset
|
||||
from .utils import download_url, extract_archive
|
||||
|
||||
|
||||
class IDTReeS(VisionDataset):
|
||||
"""IDTReeS dataset.
|
||||
|
||||
The `IDTReeS <https://idtrees.org/competition/>`_
|
||||
dataset is a dataset for tree crown detection.
|
||||
|
||||
Dataset features:
|
||||
|
||||
* RGB Image, Canopy Height Model (CHM), Hyperspectral Image (HSI), LiDAR Point Cloud
|
||||
* Remote sensing and field data generated by the
|
||||
`National Ecological Observatory Network (NEON) <https://data.neonscience.org/>`_
|
||||
* 0.1 - 1m resolution imagery
|
||||
* Task 1 - object detection (tree crown delination)
|
||||
* Task 2 - object classification (species classification)
|
||||
* Train set contains 85 images
|
||||
* Test set (task 1) contains 153 images
|
||||
* Test set (task 2) contains 353 images and tree crown polygons
|
||||
|
||||
Dataset format:
|
||||
|
||||
* optical - three-channel RGB 200x200 geotiff
|
||||
* canopy height model - one-channel 20x20 geotiff
|
||||
* hyperspectral - 369-channel 20x20 geotiff
|
||||
* point cloud - Nx3 LAS file (.las), some files contain RGB colors per point
|
||||
* shapely files (.shp) containing polygons
|
||||
* csv file containing species labels and other metadata for each polygon
|
||||
|
||||
Dataset classes:
|
||||
|
||||
0. ACPE
|
||||
1. ACRU
|
||||
2. ACSA3
|
||||
3. AMLA
|
||||
4. BETUL
|
||||
5. CAGL8
|
||||
6. CATO6
|
||||
7. FAGR
|
||||
8. GOLA
|
||||
9. LITU
|
||||
10. LYLU3
|
||||
11. MAGNO
|
||||
12. NYBI
|
||||
13. NYSY
|
||||
14. OXYDE
|
||||
15. PEPA37
|
||||
16. PIEL
|
||||
17. PIPA2
|
||||
18. PINUS
|
||||
19. PITA
|
||||
20. PRSE2
|
||||
21. QUAL
|
||||
22. QUCO2
|
||||
23. QUGE2
|
||||
24. QUHE2
|
||||
25. QULA2
|
||||
26. QULA3
|
||||
27. QUMO4
|
||||
28. QUNI
|
||||
29. QURU
|
||||
30. QUERC
|
||||
31. ROPS
|
||||
32. TSCA
|
||||
|
||||
If you use this dataset in your research, please cite the following paper:
|
||||
|
||||
* https://doi.org/10.7717/peerj.5843
|
||||
|
||||
.. versionadded:: 0.2
|
||||
"""
|
||||
|
||||
classes = {
|
||||
"ACPE": "Acer pensylvanicum L.",
|
||||
"ACRU": "Acer rubrum L.",
|
||||
"ACSA3": "Acer saccharum Marshall",
|
||||
"AMLA": "Amelanchier laevis Wiegand",
|
||||
"BETUL": "Betula sp.",
|
||||
"CAGL8": "Carya glabra (Mill.) Sweet",
|
||||
"CATO6": "Carya tomentosa (Lam.) Nutt.",
|
||||
"FAGR": "Fagus grandifolia Ehrh.",
|
||||
"GOLA": "Gordonia lasianthus (L.) Ellis",
|
||||
"LITU": "Liriodendron tulipifera L.",
|
||||
"LYLU3": "Lyonia lucida (Lam.) K. Koch",
|
||||
"MAGNO": "Magnolia sp.",
|
||||
"NYBI": "Nyssa biflora Walter",
|
||||
"NYSY": "Nyssa sylvatica Marshall",
|
||||
"OXYDE": "Oxydendrum sp.",
|
||||
"PEPA37": "Persea palustris (Raf.) Sarg.",
|
||||
"PIEL": "Pinus elliottii Engelm.",
|
||||
"PIPA2": "Pinus palustris Mill.",
|
||||
"PINUS": "Pinus sp.",
|
||||
"PITA": "Pinus taeda L.",
|
||||
"PRSE2": "Prunus serotina Ehrh.",
|
||||
"QUAL": "Quercus alba L.",
|
||||
"QUCO2": "Quercus coccinea",
|
||||
"QUGE2": "Quercus geminata Small",
|
||||
"QUHE2": "Quercus hemisphaerica W. Bartram ex Willd.",
|
||||
"QULA2": "Quercus laevis Walter",
|
||||
"QULA3": "Quercus laurifolia Michx.",
|
||||
"QUMO4": "Quercus montana Willd.",
|
||||
"QUNI": "Quercus nigra L.",
|
||||
"QURU": "Quercus rubra L.",
|
||||
"QUERC": "Quercus sp.",
|
||||
"ROPS": "Robinia pseudoacacia L.",
|
||||
"TSCA": "Tsuga canadensis (L.) Carriere",
|
||||
}
|
||||
metadata = {
|
||||
"train": {
|
||||
"url": "https://zenodo.org/record/3934932/files/IDTREES_competition_train_v2.zip?download=1", # noqa: E501
|
||||
"md5": "5ddfa76240b4bb6b4a7861d1d31c299c",
|
||||
"filename": "IDTREES_competition_train_v2.zip",
|
||||
},
|
||||
"test": {
|
||||
"url": "https://zenodo.org/record/3934932/files/IDTREES_competition_test_v2.zip?download=1", # noqa: E501
|
||||
"md5": "b108931c84a70f2a38a8234290131c9b",
|
||||
"filename": "IDTREES_competition_test_v2.zip",
|
||||
},
|
||||
}
|
||||
directories = {"train": ["train"], "test": ["task1", "task2"]}
|
||||
image_size = (200, 200)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root: str = "data",
|
||||
split: str = "train",
|
||||
task: str = "task1",
|
||||
transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
|
||||
download: bool = False,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
"""Initialize a new IDTReeS dataset instance.
|
||||
|
||||
Args:
|
||||
root: root directory where dataset can be found
|
||||
split: one of "train" or "test"
|
||||
task: 'task1' for detection, 'task2' for detection + classification
|
||||
(only relevant for split='test')
|
||||
transforms: a function/transform that takes input sample and its target as
|
||||
entry and returns a transformed version
|
||||
download: if True, download dataset and store it in the root directory
|
||||
checksum: if True, check the MD5 of the downloaded files (may be slow)
|
||||
|
||||
Raises:
|
||||
ImportError: if laspy or pandas are are not installed
|
||||
"""
|
||||
assert split in ["train", "test"]
|
||||
assert task in ["task1", "task2"]
|
||||
self.root = root
|
||||
self.split = split
|
||||
self.task = task
|
||||
self.transforms = transforms
|
||||
self.download = download
|
||||
self.checksum = checksum
|
||||
self.class2idx = {c: i for i, c in enumerate(self.classes)}
|
||||
self.idx2class = {i: c for i, c in enumerate(self.classes)}
|
||||
self.num_classes = len(self.classes)
|
||||
self._verify()
|
||||
|
||||
try:
|
||||
import pandas as pd # noqa: F401
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"pandas is not installed and is required to use this dataset"
|
||||
)
|
||||
try:
|
||||
import laspy # noqa: F401
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"laspy is not installed and is required to use this dataset"
|
||||
)
|
||||
|
||||
self.images, self.geometries, self.labels = self._load(root)
|
||||
|
||||
def __getitem__(self, index: int) -> Dict[str, Tensor]:
|
||||
"""Return an index within the dataset.
|
||||
|
||||
Args:
|
||||
index: index to return
|
||||
|
||||
Returns:
|
||||
data and label at that index
|
||||
"""
|
||||
path = self.images[index]
|
||||
image = self._load_image(path).to(torch.uint8) # type:ignore[attr-defined]
|
||||
hsi = self._load_image(path.replace("RGB", "HSI"))
|
||||
chm = self._load_image(path.replace("RGB", "CHM"))
|
||||
las = self._load_las(path.replace("RGB", "LAS").replace(".tif", ".las"))
|
||||
sample = {"image": image, "hsi": hsi, "chm": chm, "las": las}
|
||||
|
||||
if self.split == "test":
|
||||
if self.task == "task2":
|
||||
sample["boxes"] = self._load_boxes(path)
|
||||
else:
|
||||
sample["boxes"] = self._load_boxes(path)
|
||||
sample["label"] = self._load_target(path)
|
||||
|
||||
if self.transforms is not None:
|
||||
sample = self.transforms(sample)
|
||||
|
||||
return sample
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of data points in the dataset.
|
||||
|
||||
Returns:
|
||||
length of the dataset
|
||||
"""
|
||||
return len(self.images)
|
||||
|
||||
def _load_image(self, path: str) -> Tensor:
|
||||
"""Load a tiff file.
|
||||
|
||||
Args:
|
||||
path: path to .tif file
|
||||
|
||||
Returns:
|
||||
the image
|
||||
"""
|
||||
with rasterio.open(path) as f:
|
||||
array = f.read(out_shape=self.image_size, resampling=Resampling.bilinear)
|
||||
tensor: Tensor = torch.from_numpy(array) # type: ignore[attr-defined]
|
||||
return tensor
|
||||
|
||||
def _load_las(self, path: str) -> Tensor:
|
||||
"""Load a single point cloud.
|
||||
|
||||
Args:
|
||||
path: path to .las file
|
||||
|
||||
Returns:
|
||||
the point cloud
|
||||
"""
|
||||
import laspy
|
||||
|
||||
las = laspy.read(path)
|
||||
array = np.stack([las.x, las.y, las.z], axis=0)
|
||||
tensor: Tensor = torch.from_numpy(array) # type: ignore[attr-defined]
|
||||
return tensor
|
||||
|
||||
def _load_boxes(self, path: str) -> Tensor:
|
||||
"""Load object bounding boxes.
|
||||
|
||||
Args:
|
||||
path: path to .tif file
|
||||
|
||||
Returns:
|
||||
the bounding boxes
|
||||
"""
|
||||
base_path = os.path.basename(path)
|
||||
|
||||
# Find object ids and geometries
|
||||
if self.split == "train":
|
||||
indices = self.labels["rsFile"] == base_path
|
||||
ids = self.labels[indices]["id"].tolist()
|
||||
geoms = [self.geometries[i]["geometry"]["coordinates"][0][:4] for i in ids]
|
||||
# Test set - Task 2 has no mapping csv. Mapping is inside of geometry
|
||||
else:
|
||||
ids = [
|
||||
k
|
||||
for k, v in self.geometries.items()
|
||||
if v["properties"]["plotID"] == base_path
|
||||
]
|
||||
geoms = [self.geometries[i]["geometry"]["coordinates"][0][:4] for i in ids]
|
||||
|
||||
# Convert to pixel coords
|
||||
boxes = []
|
||||
with rasterio.open(path) as f:
|
||||
for geom in geoms:
|
||||
coords = [f.index(x, y) for x, y in geom]
|
||||
xmin = min([coord[0] for coord in coords])
|
||||
xmax = max([coord[0] for coord in coords])
|
||||
ymin = min([coord[1] for coord in coords])
|
||||
ymax = max([coord[1] for coord in coords])
|
||||
boxes.append([xmin, ymin, xmax, ymax])
|
||||
|
||||
tensor: Tensor = torch.tensor(boxes) # type: ignore[attr-defined]
|
||||
return tensor
|
||||
|
||||
def _load_target(self, path: str) -> Tensor:
|
||||
"""Load target label for a single sample.
|
||||
|
||||
Args:
|
||||
path: path to image
|
||||
|
||||
Returns:
|
||||
the label
|
||||
"""
|
||||
# Find indices for objects in the image
|
||||
base_path = os.path.basename(path)
|
||||
indices = self.labels["rsFile"] == base_path
|
||||
|
||||
# Load object labels
|
||||
classes = self.labels[indices]["taxonID"].tolist()
|
||||
labels = [self.class2idx[c] for c in classes]
|
||||
tensor: Tensor = torch.tensor(labels) # type: ignore[attr-defined]
|
||||
return tensor
|
||||
|
||||
def _load(self, root: str) -> Tuple[List[str], Dict[int, Dict[str, Any]], Any]:
|
||||
"""Load files, geometries, and labels.
|
||||
|
||||
Args:
|
||||
root: root directory
|
||||
|
||||
Returns:
|
||||
the image path, geometries, and labels
|
||||
"""
|
||||
import pandas as pd
|
||||
|
||||
if self.split == "train":
|
||||
directory = os.path.join(root, self.directories[self.split][0])
|
||||
labels: pd.DataFrame = self._load_labels(directory)
|
||||
geoms = self._load_geometries(directory)
|
||||
else:
|
||||
directory = os.path.join(root, self.task)
|
||||
if self.task == "task1":
|
||||
geoms = None # type: ignore[assignment]
|
||||
labels = None
|
||||
else:
|
||||
geoms = self._load_geometries(directory)
|
||||
labels = None
|
||||
|
||||
images = glob.glob(os.path.join(directory, "RemoteSensing", "RGB", "*.tif"))
|
||||
|
||||
return images, geoms, labels
|
||||
|
||||
def _load_labels(self, directory: str) -> Any:
|
||||
"""Load the csv files containing the labels.
|
||||
|
||||
Args:
|
||||
directory: directory containing csv files
|
||||
|
||||
Returns:
|
||||
a pandas DataFrame containing the labels for each image
|
||||
"""
|
||||
import pandas as pd
|
||||
|
||||
path_mapping = os.path.join(directory, "Field", "itc_rsFile.csv")
|
||||
path_labels = os.path.join(directory, "Field", "train_data.csv")
|
||||
df_mapping = pd.read_csv(path_mapping)
|
||||
df_labels = pd.read_csv(path_labels)
|
||||
df_mapping = df_mapping.set_index("indvdID", drop=True)
|
||||
df_labels = df_labels.set_index("indvdID", drop=True)
|
||||
df = df_labels.join(df_mapping, on="indvdID")
|
||||
df = df.drop_duplicates()
|
||||
df.reset_index()
|
||||
return df
|
||||
|
||||
def _load_geometries(self, directory: str) -> Dict[int, Dict[str, Any]]:
|
||||
"""Load the shape files containing the geometries.
|
||||
|
||||
Args:
|
||||
directory: directory containing .shp files
|
||||
|
||||
Returns:
|
||||
a dict containing the geometries for each object
|
||||
"""
|
||||
filepaths = glob.glob(os.path.join(directory, "ITC", "*.shp"))
|
||||
|
||||
features: Dict[int, Dict[str, Any]] = {}
|
||||
for path in filepaths:
|
||||
with fiona.open(path) as src:
|
||||
for i, feature in enumerate(src):
|
||||
if self.split == "train":
|
||||
features[feature["properties"]["id"]] = feature
|
||||
# Test set task 2 has no id
|
||||
else:
|
||||
features[i] = feature
|
||||
return features
|
||||
|
||||
def _verify(self) -> None:
|
||||
"""Verify the integrity of the dataset.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
|
||||
"""
|
||||
url = self.metadata[self.split]["url"]
|
||||
md5 = self.metadata[self.split]["md5"]
|
||||
filename = self.metadata[self.split]["filename"]
|
||||
directories = self.directories[self.split]
|
||||
|
||||
# Check if the files already exist
|
||||
exists = [
|
||||
os.path.exists(os.path.join(self.root, directory))
|
||||
for directory in directories
|
||||
]
|
||||
if all(exists):
|
||||
return
|
||||
|
||||
# Check if zip file already exists (if so then extract)
|
||||
filepath = os.path.join(self.root, filename)
|
||||
if os.path.exists(filepath):
|
||||
extract_archive(filepath)
|
||||
return
|
||||
|
||||
# Check if the user requested to download the dataset
|
||||
if not self.download:
|
||||
raise RuntimeError(
|
||||
"Dataset not found in `root` directory and `download=False`, "
|
||||
"either specify a different `root` directory or use `download=True` "
|
||||
"to automaticaly download the dataset."
|
||||
)
|
||||
|
||||
# Download and extract the dataset
|
||||
download_url(
|
||||
url, self.root, filename=filename, md5=md5 if self.checksum else None
|
||||
)
|
||||
filepath = os.path.join(self.root, filename)
|
||||
extract_archive(filepath)
|
||||
|
||||
def plot(
|
||||
self,
|
||||
sample: Dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
hsi_indices: Tuple[int, int, int] = (0, 1, 2),
|
||||
) -> plt.Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
sample: a sample returned by :meth:`__getitem__`
|
||||
show_titles: flag indicating whether to show titles above each panel
|
||||
suptitle: optional string to use as a suptitle
|
||||
hsi_indices: tuple of indices to create HSI false color image
|
||||
|
||||
Returns:
|
||||
a matplotlib Figure with the rendered sample
|
||||
"""
|
||||
assert len(hsi_indices) == 3
|
||||
|
||||
def normalize(x: Tensor) -> Tensor:
|
||||
return (x - x.min()) / (x.max() - x.min())
|
||||
|
||||
ncols = 3
|
||||
|
||||
hsi = normalize(sample["hsi"][hsi_indices, :, :]).permute((1, 2, 0)).numpy()
|
||||
chm = normalize(sample["chm"]).permute((1, 2, 0)).numpy()
|
||||
|
||||
if "boxes" in sample:
|
||||
labels = (
|
||||
[self.idx2class[int(i)] for i in sample["label"]]
|
||||
if "label" in sample
|
||||
else None
|
||||
)
|
||||
image = draw_bounding_boxes(
|
||||
image=sample["image"], boxes=sample["boxes"], labels=labels
|
||||
)
|
||||
image = image.permute((1, 2, 0)).numpy()
|
||||
else:
|
||||
image = sample["image"].permute((1, 2, 0)).numpy()
|
||||
|
||||
if "prediction_boxes" in sample:
|
||||
ncols += 1
|
||||
labels = (
|
||||
[self.idx2class[int(i)] for i in sample["prediction_label"]]
|
||||
if "prediction_label" in sample
|
||||
else None
|
||||
)
|
||||
preds = draw_bounding_boxes(
|
||||
image=sample["image"], boxes=sample["prediction_boxes"], labels=labels
|
||||
)
|
||||
preds = preds.permute((1, 2, 0)).numpy()
|
||||
|
||||
fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10))
|
||||
axs[0].imshow(image)
|
||||
axs[0].axis("off")
|
||||
axs[1].imshow(hsi)
|
||||
axs[1].axis("off")
|
||||
axs[2].imshow(chm)
|
||||
axs[2].axis("off")
|
||||
if ncols > 3:
|
||||
axs[3].imshow(preds)
|
||||
axs[3].axis("off")
|
||||
|
||||
if show_titles:
|
||||
axs[0].set_title("Ground Truth")
|
||||
axs[1].set_title("Hyperspectral False Color Image")
|
||||
axs[2].set_title("Canopy Height Model")
|
||||
if ncols > 3:
|
||||
axs[3].set_title("Predictions")
|
||||
|
||||
if suptitle is not None:
|
||||
plt.suptitle(suptitle)
|
||||
|
||||
return fig
|
||||
|
||||
def plot_las(self, index: int, colormap: Optional[str] = None) -> Any:
|
||||
"""Plot a sample point cloud at the index.
|
||||
|
||||
Args:
|
||||
index: index to plot
|
||||
colormap: a valid matplotlib colormap
|
||||
|
||||
Returns:
|
||||
a open3d.visualizer.Visualizer object. Use
|
||||
Visualizer.run() to display
|
||||
|
||||
Raises:
|
||||
ImportError: if open3d is not installed
|
||||
"""
|
||||
try:
|
||||
import open3d # noqa: F401
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"open3d is not installed and is required to use this dataset"
|
||||
)
|
||||
import laspy
|
||||
|
||||
path = self.images[index]
|
||||
path = path.replace("RGB", "LAS").replace(".tif", ".las")
|
||||
las = laspy.read(path)
|
||||
points = np.stack([las.x, las.y, las.z], axis=0).transpose((1, 0))
|
||||
|
||||
if colormap:
|
||||
cm = plt.cm.get_cmap(colormap)
|
||||
norm = plt.Normalize()
|
||||
colors = cm(norm(points[:, 2]))[:, :3]
|
||||
else:
|
||||
# Some point cloud files have no color->points mapping
|
||||
if hasattr(las, "red"):
|
||||
colors = np.stack([las.red, las.green, las.blue], axis=0)
|
||||
colors = colors.transpose((1, 0)) / 65535
|
||||
# Default to no colormap if no colors exist in las file
|
||||
else:
|
||||
colors = np.zeros_like(points)
|
||||
|
||||
pcd = open3d.geometry.PointCloud()
|
||||
pcd.points = open3d.utility.Vector3dVector(points)
|
||||
pcd.colors = open3d.utility.Vector3dVector(colors)
|
||||
vis = open3d.visualization.Visualizer()
|
||||
vis.create_window()
|
||||
vis.add_geometry(pcd)
|
||||
return vis
|
Загрузка…
Ссылка в новой задаче