* add IDTReeS dataset

* dataset loads data now

* add optional laspy and pandas dependencies

* fixed docs failing

* format

* refactor verify and resample chm/hsi to 200x200

* add open3d optional dep

* overhaul

* temporarily remove open3d install bc their pypi is broken

* mypy fixes

* fixes per suggestions

* general cleanup

* test passing

* add min version for laspy and pandas

* add open3d dependency

* add open3d to mypy tests

* add hard install for python 3.9 open3d to actions

* attempt #2

* I think I got it now

* updated tests.yaml

* make open3d dep require python<3.9

* open3d has issues with macos python 3.6

* same for 3.7

* skip open3d plot test for macos

* formatting

* skip open3d plot test for windows

* update per suggestions

* update test data readme for las files

* updated per suggestions

* more changes per suggestions

* last change per suggestion

* Grammar fix in pandas dep requirement comment

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
isaac 2021-12-05 16:38:50 -06:00 committed by GitHub
Parent fcbd1ab6c3
Commit 0434f3c1ce
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
60 changed files: 758 additions and 0 deletions


@ -118,6 +118,11 @@ GID-15 (Gaofen Image Dataset)
.. autoclass:: GID15
IDTReeS
^^^^^^^
.. autoclass:: IDTReeS
LandCover.ai (Land Cover from Aerial Imagery)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^


@ -1,6 +1,7 @@
name: torchgeo
channels:
- conda-forge
- open3d-admin
dependencies:
- cudatoolkit
- einops
@ -23,11 +24,14 @@ dependencies:
- isort[colors]>=5.8
- jupyterlab
- kornia>=0.5.4
- laspy>=2.0.0
- mypy>=0.900
- nbmake>=0.1
- nbsphinx>=0.8.5
- omegaconf>=2.1
- open3d>=0.11.2
- opencv-python
- pandas>=0.19.1
- pillow>=2.9
- pydocstyle[toml]>=6.1
- pytest>=6


@ -65,7 +65,14 @@ include = torchgeo*
# Optional dataset requirements
datasets =
h5py
# laspy 2+ required for Python 3.6+ support; used to load .las point clouds (idtrees)
laspy>=2.0.0
# open3d will add support for Python 3.9 on PyPI in v0.14; v0.11.2 is the last version for which tests pass
# https://github.com/isl-org/Open3D/issues/1550
open3d>=0.11.2;python_version<'3.9'
opencv-python
# pandas 0.19.1+ required for python 3.6 support
pandas>=0.19.1
pycocotools
# radiant-mlhub 0.2.1+ required for api_key bugfix:
# https://github.com/radiantearth/radiant-mlhub/pull/48
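Since these packages are optional, torchgeo imports them lazily inside the datasets that need them and raises a readable error otherwise. A minimal sketch of that guarded-import pattern (the `require` helper below is illustrative, not part of torchgeo):

```python
import importlib
from types import ModuleType


# Illustrative helper (not part of torchgeo): import an optional dependency,
# failing with the same style of message the IDTReeS dataset uses.
def require(package: str) -> ModuleType:
    try:
        return importlib.import_module(package)
    except ImportError:
        raise ImportError(
            f"{package} is not installed and is required to use this dataset"
        )


laspy = require("laspy")  # only needed when .las point clouds are loaded
```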


@ -91,3 +91,28 @@ masks = np.random.randint(low=0, high=num_classes, size=(1, 1)).astype(np.uint8)
f.create_dataset("images", data=images)
f.create_dataset("masks", data=masks)
f.close()
```
### LAS Point Cloud files
```python
import laspy
import numpy as np

num_points = 4

# Start from an existing .las file and keep only a handful of points
las = laspy.read("0.las")
las.points = las.points[:num_points]

# Overwrite the coordinates with small random values
points = np.random.randint(low=0, high=100, size=(num_points,), dtype=las.x.dtype)
las.x = points
las.y = points
las.z = points

# Some point formats also carry per-point RGB colors
if hasattr(las, "red"):
    colors = np.random.randint(low=0, high=10, size=(num_points,), dtype=las.red.dtype)
    las.red = colors
    las.green = colors
    las.blue = colors

las.write("0.las")
```
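As a sanity check, the truncated fixture can be read back to confirm it only contains the fake points; a minimal sketch, assuming laspy 2.x and the `0.las` written above:

```python
import laspy
import numpy as np

# Read the generated test file back and verify the point count and coordinates.
las = laspy.read("0.las")
assert len(las.points) == 4
print(np.asarray(las.x), np.asarray(las.y), np.asarray(las.z))
```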

Binary files added (new test data; contents not shown in the diff view):

tests/data/idtrees/IDTREES_competition_test_v2.zip
tests/data/idtrees/IDTREES_competition_train_v2.zip
tests/data/idtrees/task1/RemoteSensing/CHM/MLBS_4.tif
tests/data/idtrees/task1/RemoteSensing/CHM/OSBS_11.tif
tests/data/idtrees/task1/RemoteSensing/CHM/TALL_1.tif
tests/data/idtrees/task1/RemoteSensing/HSI/MLBS_4.tif
tests/data/idtrees/task1/RemoteSensing/HSI/OSBS_11.tif
tests/data/idtrees/task1/RemoteSensing/HSI/TALL_1.tif
tests/data/idtrees/task1/RemoteSensing/LAS/MLBS_4.las
tests/data/idtrees/task1/RemoteSensing/LAS/OSBS_11.las
tests/data/idtrees/task1/RemoteSensing/LAS/TALL_1.las
tests/data/idtrees/task1/RemoteSensing/RGB/MLBS_4.tif
tests/data/idtrees/task1/RemoteSensing/RGB/OSBS_11.tif
tests/data/idtrees/task1/RemoteSensing/RGB/TALL_1.tif
tests/data/idtrees/task2/ITC/test_MLBS.dbf
tests/data/idtrees/task2/ITC/test_MLBS.shp
tests/data/idtrees/task2/ITC/test_MLBS.shx
tests/data/idtrees/task2/ITC/test_OSBS.dbf
tests/data/idtrees/task2/ITC/test_OSBS.shp
tests/data/idtrees/task2/ITC/test_OSBS.shx
tests/data/idtrees/task2/ITC/test_TALL.dbf
tests/data/idtrees/task2/ITC/test_TALL.shp
tests/data/idtrees/task2/ITC/test_TALL.shx
tests/data/idtrees/task2/RemoteSensing/CHM/MLBS_1.tif
tests/data/idtrees/task2/RemoteSensing/CHM/OSBS_15.tif
tests/data/idtrees/task2/RemoteSensing/CHM/TALL_2.tif
tests/data/idtrees/task2/RemoteSensing/HSI/MLBS_1.tif
tests/data/idtrees/task2/RemoteSensing/HSI/OSBS_15.tif
tests/data/idtrees/task2/RemoteSensing/HSI/TALL_2.tif
tests/data/idtrees/task2/RemoteSensing/LAS/MLBS_1.las
tests/data/idtrees/task2/RemoteSensing/LAS/OSBS_15.las
tests/data/idtrees/task2/RemoteSensing/LAS/TALL_2.las
tests/data/idtrees/task2/RemoteSensing/RGB/MLBS_1.tif
tests/data/idtrees/task2/RemoteSensing/RGB/OSBS_15.tif
tests/data/idtrees/task2/RemoteSensing/RGB/TALL_2.tif
tests/data/idtrees/train/ITC/train_MLBS.dbf
tests/data/idtrees/train/ITC/train_MLBS.shp
tests/data/idtrees/train/ITC/train_MLBS.shx
tests/data/idtrees/train/ITC/train_OSBS.dbf
tests/data/idtrees/train/ITC/train_OSBS.shp
tests/data/idtrees/train/ITC/train_OSBS.shx
tests/data/idtrees/train/RemoteSensing/CHM/MLBS_1.tif
tests/data/idtrees/train/RemoteSensing/CHM/OSBS_1.tif
tests/data/idtrees/train/RemoteSensing/CHM/OSBS_39.tif
tests/data/idtrees/train/RemoteSensing/HSI/MLBS_1.tif
tests/data/idtrees/train/RemoteSensing/HSI/OSBS_1.tif
tests/data/idtrees/train/RemoteSensing/HSI/OSBS_39.tif
tests/data/idtrees/train/RemoteSensing/LAS/MLBS_1.las
tests/data/idtrees/train/RemoteSensing/LAS/OSBS_1.las
tests/data/idtrees/train/RemoteSensing/LAS/OSBS_39.las
tests/data/idtrees/train/RemoteSensing/RGB/MLBS_1.tif
tests/data/idtrees/train/RemoteSensing/RGB/OSBS_1.tif
tests/data/idtrees/train/RemoteSensing/RGB/OSBS_39.tif


@ -0,0 +1,162 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import builtins
import glob
import os
import shutil
import sys
from pathlib import Path
from typing import Any, Generator
import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
import torchgeo.datasets.utils
from torchgeo.datasets import IDTReeS
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
class TestIDTReeS:
@pytest.fixture(params=zip(["train", "test", "test"], ["task1", "task1", "task2"]))
def dataset(
self,
monkeypatch: Generator[MonkeyPatch, None, None],
tmp_path: Path,
request: SubRequest,
) -> IDTReeS:
pytest.importorskip("pandas")
pytest.importorskip("laspy")
monkeypatch.setattr( # type: ignore[attr-defined]
torchgeo.datasets.idtrees, "download_url", download_url
)
data_dir = os.path.join("tests", "data", "idtrees")
metadata = {
"train": {
"url": os.path.join(data_dir, "IDTREES_competition_train_v2.zip"),
"md5": "5ddfa76240b4bb6b4a7861d1d31c299c",
"filename": "IDTREES_competition_train_v2.zip",
},
"test": {
"url": os.path.join(data_dir, "IDTREES_competition_test_v2.zip"),
"md5": "b108931c84a70f2a38a8234290131c9b",
"filename": "IDTREES_competition_test_v2.zip",
},
}
split, task = request.param
monkeypatch.setattr(IDTReeS, "metadata", metadata) # type: ignore[attr-defined]
root = str(tmp_path)
transforms = nn.Identity() # type: ignore[attr-defined]
return IDTReeS(root, split, task, transforms, download=True, checksum=True)
@pytest.fixture(params=["pandas", "laspy", "open3d"])
def mock_missing_module(
self, monkeypatch: Generator[MonkeyPatch, None, None], request: SubRequest
) -> str:
import_orig = builtins.__import__
package = str(request.param)
def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
if name == package:
raise ImportError()
return import_orig(name, *args, **kwargs)
monkeypatch.setattr( # type: ignore[attr-defined]
builtins, "__import__", mocked_import
)
return package
def test_getitem(self, dataset: IDTReeS) -> None:
x = dataset[0]
assert isinstance(x, dict)
assert isinstance(x["image"], torch.Tensor)
assert isinstance(x["chm"], torch.Tensor)
assert isinstance(x["hsi"], torch.Tensor)
assert isinstance(x["las"], torch.Tensor)
assert x["image"].shape == (3, 200, 200)
assert x["chm"].shape == (1, 200, 200)
assert x["hsi"].shape == (369, 200, 200)
assert x["las"].ndim == 2
assert x["las"].shape[0] == 3
if "label" in x:
assert isinstance(x["label"], torch.Tensor)
if "boxes" in x:
assert isinstance(x["boxes"], torch.Tensor)
if x["boxes"].ndim != 1:
assert x["boxes"].ndim == 2
assert x["boxes"].shape[-1] == 4
def test_len(self, dataset: IDTReeS) -> None:
assert len(dataset) == 3
def test_already_downloaded(self, dataset: IDTReeS) -> None:
IDTReeS(root=dataset.root, download=True)
def test_not_downloaded(self, tmp_path: Path) -> None:
err = "Dataset not found in `root` directory and `download=False`, "
"either specify a different `root` directory or use `download=True` "
"to automatically download the dataset."
with pytest.raises(RuntimeError, match=err):
IDTReeS(str(tmp_path))
def test_not_extracted(self, tmp_path: Path) -> None:
pathname = os.path.join("tests", "data", "idtrees", "*.zip")
root = str(tmp_path)
for zipfile in glob.iglob(pathname):
shutil.copy(zipfile, root)
IDTReeS(root)
def test_mock_missing_module(
self, dataset: IDTReeS, mock_missing_module: str
) -> None:
package = mock_missing_module
if package in ["pandas", "laspy"]:
with pytest.raises(
ImportError,
match=f"{package} is not installed and is required to use this dataset",
):
IDTReeS(dataset.root, download=True, checksum=True)
else:
with pytest.raises(
ImportError,
match=f"{package} is not installed and is required to use this dataset",
):
dataset.plot_las(0)
def test_plot(self, dataset: IDTReeS) -> None:
x = dataset[0].copy()
dataset.plot(x, suptitle="Test")
plt.close()
dataset.plot(x, show_titles=False)
plt.close()
if "boxes" in x:
x["prediction_boxes"] = x["boxes"]
dataset.plot(x, show_titles=True)
plt.close()
if "label" in x:
x["prediction_label"] = x["label"]
dataset.plot(x, show_titles=False)
plt.close()
@pytest.mark.skipif(
sys.platform in ["darwin", "win32"],
reason="segmentation fault on macOS and windows",
)
def test_plot_las(self, dataset: IDTReeS) -> None:
pytest.importorskip("open3d")
vis = dataset.plot_las(index=0, colormap="BrBG")
vis.close()
vis = dataset.plot_las(index=0, colormap=None)
vis.close()
vis = dataset.plot_las(index=1, colormap=None)
vis.close()


@ -37,6 +37,7 @@ from .geo import (
VisionDataset,
)
from .gid15 import GID15
from .idtrees import IDTReeS
from .landcoverai import LandCoverAI, LandCoverAIDataModule
from .landsat import (
Landsat,
@ -115,6 +116,7 @@ __all__ = (
"EuroSAT",
"EuroSATDataModule",
"GID15",
"IDTReeS",
"LandCoverAI",
"LandCoverAIDataModule",
"LEVIRCDPlus",


@ -0,0 +1,553 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""IDTReeS dataset."""
import glob
import os
from typing import Any, Callable, Dict, List, Optional, Tuple
import fiona
import matplotlib.pyplot as plt
import numpy as np
import rasterio
import torch
from rasterio.enums import Resampling
from torch import Tensor
from torchvision.utils import draw_bounding_boxes
from .geo import VisionDataset
from .utils import download_url, extract_archive
class IDTReeS(VisionDataset):
"""IDTReeS dataset.
The `IDTReeS <https://idtrees.org/competition/>`_
dataset is a dataset for tree crown detection.
Dataset features:
* RGB Image, Canopy Height Model (CHM), Hyperspectral Image (HSI), LiDAR Point Cloud
* Remote sensing and field data generated by the
`National Ecological Observatory Network (NEON) <https://data.neonscience.org/>`_
* 0.1 - 1m resolution imagery
* Task 1 - object detection (tree crown delineation)
* Task 2 - object classification (species classification)
* Train set contains 85 images
* Test set (task 1) contains 153 images
* Test set (task 2) contains 353 images and tree crown polygons
Dataset format:
* optical - three-channel RGB 200x200 geotiff
* canopy height model - one-channel 20x20 geotiff
* hyperspectral - 369-channel 20x20 geotiff
* point cloud - Nx3 LAS file (.las), some files contain RGB colors per point
* shapefiles (.shp) containing polygons
* csv file containing species labels and other metadata for each polygon
Dataset classes:
0. ACPE
1. ACRU
2. ACSA3
3. AMLA
4. BETUL
5. CAGL8
6. CATO6
7. FAGR
8. GOLA
9. LITU
10. LYLU3
11. MAGNO
12. NYBI
13. NYSY
14. OXYDE
15. PEPA37
16. PIEL
17. PIPA2
18. PINUS
19. PITA
20. PRSE2
21. QUAL
22. QUCO2
23. QUGE2
24. QUHE2
25. QULA2
26. QULA3
27. QUMO4
28. QUNI
29. QURU
30. QUERC
31. ROPS
32. TSCA
If you use this dataset in your research, please cite the following paper:
* https://doi.org/10.7717/peerj.5843
.. versionadded:: 0.2
"""
classes = {
"ACPE": "Acer pensylvanicum L.",
"ACRU": "Acer rubrum L.",
"ACSA3": "Acer saccharum Marshall",
"AMLA": "Amelanchier laevis Wiegand",
"BETUL": "Betula sp.",
"CAGL8": "Carya glabra (Mill.) Sweet",
"CATO6": "Carya tomentosa (Lam.) Nutt.",
"FAGR": "Fagus grandifolia Ehrh.",
"GOLA": "Gordonia lasianthus (L.) Ellis",
"LITU": "Liriodendron tulipifera L.",
"LYLU3": "Lyonia lucida (Lam.) K. Koch",
"MAGNO": "Magnolia sp.",
"NYBI": "Nyssa biflora Walter",
"NYSY": "Nyssa sylvatica Marshall",
"OXYDE": "Oxydendrum sp.",
"PEPA37": "Persea palustris (Raf.) Sarg.",
"PIEL": "Pinus elliottii Engelm.",
"PIPA2": "Pinus palustris Mill.",
"PINUS": "Pinus sp.",
"PITA": "Pinus taeda L.",
"PRSE2": "Prunus serotina Ehrh.",
"QUAL": "Quercus alba L.",
"QUCO2": "Quercus coccinea",
"QUGE2": "Quercus geminata Small",
"QUHE2": "Quercus hemisphaerica W. Bartram ex Willd.",
"QULA2": "Quercus laevis Walter",
"QULA3": "Quercus laurifolia Michx.",
"QUMO4": "Quercus montana Willd.",
"QUNI": "Quercus nigra L.",
"QURU": "Quercus rubra L.",
"QUERC": "Quercus sp.",
"ROPS": "Robinia pseudoacacia L.",
"TSCA": "Tsuga canadensis (L.) Carriere",
}
metadata = {
"train": {
"url": "https://zenodo.org/record/3934932/files/IDTREES_competition_train_v2.zip?download=1", # noqa: E501
"md5": "5ddfa76240b4bb6b4a7861d1d31c299c",
"filename": "IDTREES_competition_train_v2.zip",
},
"test": {
"url": "https://zenodo.org/record/3934932/files/IDTREES_competition_test_v2.zip?download=1", # noqa: E501
"md5": "b108931c84a70f2a38a8234290131c9b",
"filename": "IDTREES_competition_test_v2.zip",
},
}
directories = {"train": ["train"], "test": ["task1", "task2"]}
image_size = (200, 200)
def __init__(
self,
root: str = "data",
split: str = "train",
task: str = "task1",
transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
download: bool = False,
checksum: bool = False,
) -> None:
"""Initialize a new IDTReeS dataset instance.
Args:
root: root directory where dataset can be found
split: one of "train" or "test"
task: 'task1' for detection, 'task2' for detection + classification
(only relevant for split='test')
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
download: if True, download dataset and store it in the root directory
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
ImportError: if laspy or pandas are not installed
"""
assert split in ["train", "test"]
assert task in ["task1", "task2"]
self.root = root
self.split = split
self.task = task
self.transforms = transforms
self.download = download
self.checksum = checksum
self.class2idx = {c: i for i, c in enumerate(self.classes)}
self.idx2class = {i: c for i, c in enumerate(self.classes)}
self.num_classes = len(self.classes)
self._verify()
try:
import pandas as pd # noqa: F401
except ImportError:
raise ImportError(
"pandas is not installed and is required to use this dataset"
)
try:
import laspy # noqa: F401
except ImportError:
raise ImportError(
"laspy is not installed and is required to use this dataset"
)
self.images, self.geometries, self.labels = self._load(root)
def __getitem__(self, index: int) -> Dict[str, Tensor]:
"""Return an index within the dataset.
Args:
index: index to return
Returns:
data and label at that index
"""
path = self.images[index]
image = self._load_image(path).to(torch.uint8) # type:ignore[attr-defined]
hsi = self._load_image(path.replace("RGB", "HSI"))
chm = self._load_image(path.replace("RGB", "CHM"))
las = self._load_las(path.replace("RGB", "LAS").replace(".tif", ".las"))
sample = {"image": image, "hsi": hsi, "chm": chm, "las": las}
if self.split == "test":
if self.task == "task2":
sample["boxes"] = self._load_boxes(path)
else:
sample["boxes"] = self._load_boxes(path)
sample["label"] = self._load_target(path)
if self.transforms is not None:
sample = self.transforms(sample)
return sample
def __len__(self) -> int:
"""Return the number of data points in the dataset.
Returns:
length of the dataset
"""
return len(self.images)
def _load_image(self, path: str) -> Tensor:
"""Load a tiff file.
Args:
path: path to .tif file
Returns:
the image
"""
with rasterio.open(path) as f:
array = f.read(out_shape=self.image_size, resampling=Resampling.bilinear)
tensor: Tensor = torch.from_numpy(array) # type: ignore[attr-defined]
return tensor
def _load_las(self, path: str) -> Tensor:
"""Load a single point cloud.
Args:
path: path to .las file
Returns:
the point cloud
"""
import laspy
las = laspy.read(path)
array = np.stack([las.x, las.y, las.z], axis=0)
tensor: Tensor = torch.from_numpy(array) # type: ignore[attr-defined]
return tensor
def _load_boxes(self, path: str) -> Tensor:
"""Load object bounding boxes.
Args:
path: path to .tif file
Returns:
the bounding boxes
"""
base_path = os.path.basename(path)
# Find object ids and geometries
if self.split == "train":
indices = self.labels["rsFile"] == base_path
ids = self.labels[indices]["id"].tolist()
geoms = [self.geometries[i]["geometry"]["coordinates"][0][:4] for i in ids]
# Test set task 2 has no mapping csv; the mapping is stored in the geometry properties
else:
ids = [
k
for k, v in self.geometries.items()
if v["properties"]["plotID"] == base_path
]
geoms = [self.geometries[i]["geometry"]["coordinates"][0][:4] for i in ids]
# Convert to pixel coords
boxes = []
with rasterio.open(path) as f:
for geom in geoms:
coords = [f.index(x, y) for x, y in geom]
xmin = min([coord[0] for coord in coords])
xmax = max([coord[0] for coord in coords])
ymin = min([coord[1] for coord in coords])
ymax = max([coord[1] for coord in coords])
boxes.append([xmin, ymin, xmax, ymax])
tensor: Tensor = torch.tensor(boxes) # type: ignore[attr-defined]
return tensor
def _load_target(self, path: str) -> Tensor:
"""Load target label for a single sample.
Args:
path: path to image
Returns:
the label
"""
# Find indices for objects in the image
base_path = os.path.basename(path)
indices = self.labels["rsFile"] == base_path
# Load object labels
classes = self.labels[indices]["taxonID"].tolist()
labels = [self.class2idx[c] for c in classes]
tensor: Tensor = torch.tensor(labels) # type: ignore[attr-defined]
return tensor
def _load(self, root: str) -> Tuple[List[str], Dict[int, Dict[str, Any]], Any]:
"""Load files, geometries, and labels.
Args:
root: root directory
Returns:
the image path, geometries, and labels
"""
import pandas as pd
if self.split == "train":
directory = os.path.join(root, self.directories[self.split][0])
labels: pd.DataFrame = self._load_labels(directory)
geoms = self._load_geometries(directory)
else:
directory = os.path.join(root, self.task)
if self.task == "task1":
geoms = None # type: ignore[assignment]
labels = None
else:
geoms = self._load_geometries(directory)
labels = None
images = glob.glob(os.path.join(directory, "RemoteSensing", "RGB", "*.tif"))
return images, geoms, labels
def _load_labels(self, directory: str) -> Any:
"""Load the csv files containing the labels.
Args:
directory: directory containing csv files
Returns:
a pandas DataFrame containing the labels for each image
"""
import pandas as pd
path_mapping = os.path.join(directory, "Field", "itc_rsFile.csv")
path_labels = os.path.join(directory, "Field", "train_data.csv")
df_mapping = pd.read_csv(path_mapping)
df_labels = pd.read_csv(path_labels)
df_mapping = df_mapping.set_index("indvdID", drop=True)
df_labels = df_labels.set_index("indvdID", drop=True)
df = df_labels.join(df_mapping, on="indvdID")
df = df.drop_duplicates()
df.reset_index()
return df
def _load_geometries(self, directory: str) -> Dict[int, Dict[str, Any]]:
"""Load the shape files containing the geometries.
Args:
directory: directory containing .shp files
Returns:
a dict containing the geometries for each object
"""
filepaths = glob.glob(os.path.join(directory, "ITC", "*.shp"))
features: Dict[int, Dict[str, Any]] = {}
for path in filepaths:
with fiona.open(path) as src:
for i, feature in enumerate(src):
if self.split == "train":
features[feature["properties"]["id"]] = feature
# Test set task 2 has no id
else:
features[i] = feature
return features
def _verify(self) -> None:
"""Verify the integrity of the dataset.
Raises:
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
"""
url = self.metadata[self.split]["url"]
md5 = self.metadata[self.split]["md5"]
filename = self.metadata[self.split]["filename"]
directories = self.directories[self.split]
# Check if the files already exist
exists = [
os.path.exists(os.path.join(self.root, directory))
for directory in directories
]
if all(exists):
return
# Check if zip file already exists (if so then extract)
filepath = os.path.join(self.root, filename)
if os.path.exists(filepath):
extract_archive(filepath)
return
# Check if the user requested to download the dataset
if not self.download:
raise RuntimeError(
"Dataset not found in `root` directory and `download=False`, "
"either specify a different `root` directory or use `download=True` "
"to automatically download the dataset."
)
# Download and extract the dataset
download_url(
url, self.root, filename=filename, md5=md5 if self.checksum else None
)
filepath = os.path.join(self.root, filename)
extract_archive(filepath)
def plot(
self,
sample: Dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
hsi_indices: Tuple[int, int, int] = (0, 1, 2),
) -> plt.Figure:
"""Plot a sample from the dataset.
Args:
sample: a sample returned by :meth:`__getitem__`
show_titles: flag indicating whether to show titles above each panel
suptitle: optional string to use as a suptitle
hsi_indices: tuple of indices to create HSI false color image
Returns:
a matplotlib Figure with the rendered sample
"""
assert len(hsi_indices) == 3
def normalize(x: Tensor) -> Tensor:
return (x - x.min()) / (x.max() - x.min())
ncols = 3
hsi = normalize(sample["hsi"][hsi_indices, :, :]).permute((1, 2, 0)).numpy()
chm = normalize(sample["chm"]).permute((1, 2, 0)).numpy()
if "boxes" in sample:
labels = (
[self.idx2class[int(i)] for i in sample["label"]]
if "label" in sample
else None
)
image = draw_bounding_boxes(
image=sample["image"], boxes=sample["boxes"], labels=labels
)
image = image.permute((1, 2, 0)).numpy()
else:
image = sample["image"].permute((1, 2, 0)).numpy()
if "prediction_boxes" in sample:
ncols += 1
labels = (
[self.idx2class[int(i)] for i in sample["prediction_label"]]
if "prediction_label" in sample
else None
)
preds = draw_bounding_boxes(
image=sample["image"], boxes=sample["prediction_boxes"], labels=labels
)
preds = preds.permute((1, 2, 0)).numpy()
fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10))
axs[0].imshow(image)
axs[0].axis("off")
axs[1].imshow(hsi)
axs[1].axis("off")
axs[2].imshow(chm)
axs[2].axis("off")
if ncols > 3:
axs[3].imshow(preds)
axs[3].axis("off")
if show_titles:
axs[0].set_title("Ground Truth")
axs[1].set_title("Hyperspectral False Color Image")
axs[2].set_title("Canopy Height Model")
if ncols > 3:
axs[3].set_title("Predictions")
if suptitle is not None:
plt.suptitle(suptitle)
return fig
def plot_las(self, index: int, colormap: Optional[str] = None) -> Any:
"""Plot a sample point cloud at the index.
Args:
index: index to plot
colormap: a valid matplotlib colormap
Returns:
an open3d.visualization.Visualizer object. Use
Visualizer.run() to display
Raises:
ImportError: if open3d is not installed
"""
try:
import open3d # noqa: F401
except ImportError:
raise ImportError(
"open3d is not installed and is required to use this dataset"
)
import laspy
path = self.images[index]
path = path.replace("RGB", "LAS").replace(".tif", ".las")
las = laspy.read(path)
points = np.stack([las.x, las.y, las.z], axis=0).transpose((1, 0))
if colormap:
cm = plt.cm.get_cmap(colormap)
norm = plt.Normalize()
colors = cm(norm(points[:, 2]))[:, :3]
else:
# Some point cloud files have no color->points mapping
if hasattr(las, "red"):
colors = np.stack([las.red, las.green, las.blue], axis=0)
colors = colors.transpose((1, 0)) / 65535
# Default to no colormap if no colors exist in las file
else:
colors = np.zeros_like(points)
pcd = open3d.geometry.PointCloud()
pcd.points = open3d.utility.Vector3dVector(points)
pcd.colors = open3d.utility.Vector3dVector(colors)
vis = open3d.visualization.Visualizer()
vis.create_window()
vis.add_geometry(pcd)
return vis
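For reference, a short usage sketch of the new dataset, assuming the archives are already available under `root` (or `download=True`) and that the optional laspy, pandas, and open3d dependencies are installed; the sample keys mirror `__getitem__` above:

```python
from torchgeo.datasets import IDTReeS

# Minimal usage sketch; paths and shapes follow the docstrings and tests above.
ds = IDTReeS(root="data", split="train", task="task1")
sample = ds[0]
print(sample["image"].shape)  # (3, 200, 200) RGB
print(sample["hsi"].shape)    # (369, 200, 200) hyperspectral, resampled to 200x200
print(sample["chm"].shape)    # (1, 200, 200) canopy height model, resampled to 200x200
print(sample["las"].shape)    # (3, N) LiDAR point cloud
print(sample["boxes"].shape, sample["label"].shape)  # tree crown boxes and species labels

ds.plot(sample, suptitle="IDTReeS sample")

# Point cloud visualization requires the optional open3d dependency.
vis = ds.plot_las(index=0, colormap="BrBG")
vis.run()  # opens an interactive window
```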