diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index b5a33461a..0bdf71dd2 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -118,6 +118,11 @@ GID-15 (Gaofen Image Dataset) .. autoclass:: GID15 +IDTReeS +^^^^^^^ + +.. autoclass:: IDTReeS + LandCover.ai (Land Cover from Aerial Imagery) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/environment.yml b/environment.yml index aef403329..17cea448a 100644 --- a/environment.yml +++ b/environment.yml @@ -1,6 +1,7 @@ name: torchgeo channels: - conda-forge + - open3d-admin dependencies: - cudatoolkit - einops @@ -23,11 +24,14 @@ dependencies: - isort[colors]>=5.8 - jupyterlab - kornia>=0.5.4 + - laspy>=2.0.0 - mypy>=0.900 - nbmake>=0.1 - nbsphinx>=0.8.5 - omegaconf>=2.1 + - open3d>=0.11.2 - opencv-python + - pandas>=0.19.1 - pillow>=2.9 - pydocstyle[toml]>=6.1 - pytest>=6 diff --git a/setup.cfg b/setup.cfg index f665c337f..ebcf9b3d2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -65,7 +65,14 @@ include = torchgeo* # Optional dataset requirements datasets = h5py + # loading .las point clouds (idtrees) laspy 2+ required for Python 3.6+ support + laspy>=2.0.0 + # open3d will add support for python 3.9 in pypi in v0.14. 
v0.11.2 last version for tests to pass + # https://github.com/isl-org/Open3D/issues/1550 + open3d>=0.11.2;python_version<'3.9' opencv-python + # pandas 0.19.1+ required for python 3.6 support + pandas>=0.19.1 pycocotools # radiant-mlhub 0.2.1+ required for api_key bugfix: # https://github.com/radiantearth/radiant-mlhub/pull/48 diff --git a/tests/data/README.md b/tests/data/README.md index 35cca49b9..3884cbf38 100644 --- a/tests/data/README.md +++ b/tests/data/README.md @@ -91,3 +91,28 @@ masks = np.random.randint(low=0, high=num_classes, size=(1, 1)).astype(np.uint8) f.create_dataset("images", data=images) f.create_dataset("masks", data=masks) f.close() +``` + +### LAS Point Cloud files + +```python +import laspy + +num_points = 4 + +las = laspy.read("0.las") +las.points = las.points[:num_points] + +points = np.random.randint(low=0, high=100, size=(num_points,), dtype=las.x.dtype) +las.x = points +las.y = points +las.z = points + +if hasattr(las, "red"): + colors = np.random.randint(low=0, high=10, size=(num_points,), dtype=las.red.dtype) + las.red = colors + las.green = colors + las.blue = colors + +las.write("0.las") +``` diff --git a/tests/data/idtrees/IDTREES_competition_test_v2.zip b/tests/data/idtrees/IDTREES_competition_test_v2.zip new file mode 100644 index 000000000..26870fef1 Binary files /dev/null and b/tests/data/idtrees/IDTREES_competition_test_v2.zip differ diff --git a/tests/data/idtrees/IDTREES_competition_train_v2.zip b/tests/data/idtrees/IDTREES_competition_train_v2.zip new file mode 100644 index 000000000..1385991f4 Binary files /dev/null and b/tests/data/idtrees/IDTREES_competition_train_v2.zip differ diff --git a/tests/data/idtrees/task1/RemoteSensing/CHM/MLBS_4.tif b/tests/data/idtrees/task1/RemoteSensing/CHM/MLBS_4.tif new file mode 100644 index 000000000..22369e56f Binary files /dev/null and b/tests/data/idtrees/task1/RemoteSensing/CHM/MLBS_4.tif differ diff --git a/tests/data/idtrees/task1/RemoteSensing/CHM/OSBS_11.tif 
b/tests/data/idtrees/task1/RemoteSensing/CHM/OSBS_11.tif new file mode 100644 index 000000000..21dc8b0cd Binary files /dev/null and b/tests/data/idtrees/task1/RemoteSensing/CHM/OSBS_11.tif differ diff --git a/tests/data/idtrees/task1/RemoteSensing/CHM/TALL_1.tif b/tests/data/idtrees/task1/RemoteSensing/CHM/TALL_1.tif new file mode 100644 index 000000000..8e3f03ef0 Binary files /dev/null and b/tests/data/idtrees/task1/RemoteSensing/CHM/TALL_1.tif differ diff --git a/tests/data/idtrees/task1/RemoteSensing/HSI/MLBS_4.tif b/tests/data/idtrees/task1/RemoteSensing/HSI/MLBS_4.tif new file mode 100644 index 000000000..5a2dd4408 Binary files /dev/null and b/tests/data/idtrees/task1/RemoteSensing/HSI/MLBS_4.tif differ diff --git a/tests/data/idtrees/task1/RemoteSensing/HSI/OSBS_11.tif b/tests/data/idtrees/task1/RemoteSensing/HSI/OSBS_11.tif new file mode 100644 index 000000000..39f61a017 Binary files /dev/null and b/tests/data/idtrees/task1/RemoteSensing/HSI/OSBS_11.tif differ diff --git a/tests/data/idtrees/task1/RemoteSensing/HSI/TALL_1.tif b/tests/data/idtrees/task1/RemoteSensing/HSI/TALL_1.tif new file mode 100644 index 000000000..e03b7bf76 Binary files /dev/null and b/tests/data/idtrees/task1/RemoteSensing/HSI/TALL_1.tif differ diff --git a/tests/data/idtrees/task1/RemoteSensing/LAS/MLBS_4.las b/tests/data/idtrees/task1/RemoteSensing/LAS/MLBS_4.las new file mode 100644 index 000000000..f185afc11 Binary files /dev/null and b/tests/data/idtrees/task1/RemoteSensing/LAS/MLBS_4.las differ diff --git a/tests/data/idtrees/task1/RemoteSensing/LAS/OSBS_11.las b/tests/data/idtrees/task1/RemoteSensing/LAS/OSBS_11.las new file mode 100644 index 000000000..19f0159ca Binary files /dev/null and b/tests/data/idtrees/task1/RemoteSensing/LAS/OSBS_11.las differ diff --git a/tests/data/idtrees/task1/RemoteSensing/LAS/TALL_1.las b/tests/data/idtrees/task1/RemoteSensing/LAS/TALL_1.las new file mode 100644 index 000000000..723efb46d Binary files /dev/null and 
b/tests/data/idtrees/task1/RemoteSensing/LAS/TALL_1.las differ diff --git a/tests/data/idtrees/task1/RemoteSensing/RGB/MLBS_4.tif b/tests/data/idtrees/task1/RemoteSensing/RGB/MLBS_4.tif new file mode 100644 index 000000000..d6b3d6dcb Binary files /dev/null and b/tests/data/idtrees/task1/RemoteSensing/RGB/MLBS_4.tif differ diff --git a/tests/data/idtrees/task1/RemoteSensing/RGB/OSBS_11.tif b/tests/data/idtrees/task1/RemoteSensing/RGB/OSBS_11.tif new file mode 100644 index 000000000..7525657c0 Binary files /dev/null and b/tests/data/idtrees/task1/RemoteSensing/RGB/OSBS_11.tif differ diff --git a/tests/data/idtrees/task1/RemoteSensing/RGB/TALL_1.tif b/tests/data/idtrees/task1/RemoteSensing/RGB/TALL_1.tif new file mode 100644 index 000000000..d5692e445 Binary files /dev/null and b/tests/data/idtrees/task1/RemoteSensing/RGB/TALL_1.tif differ diff --git a/tests/data/idtrees/task2/ITC/test_MLBS.dbf b/tests/data/idtrees/task2/ITC/test_MLBS.dbf new file mode 100644 index 000000000..e7d06021b Binary files /dev/null and b/tests/data/idtrees/task2/ITC/test_MLBS.dbf differ diff --git a/tests/data/idtrees/task2/ITC/test_MLBS.shp b/tests/data/idtrees/task2/ITC/test_MLBS.shp new file mode 100644 index 000000000..45d90f102 Binary files /dev/null and b/tests/data/idtrees/task2/ITC/test_MLBS.shp differ diff --git a/tests/data/idtrees/task2/ITC/test_MLBS.shx b/tests/data/idtrees/task2/ITC/test_MLBS.shx new file mode 100644 index 000000000..f9e0f0ae9 Binary files /dev/null and b/tests/data/idtrees/task2/ITC/test_MLBS.shx differ diff --git a/tests/data/idtrees/task2/ITC/test_OSBS.dbf b/tests/data/idtrees/task2/ITC/test_OSBS.dbf new file mode 100644 index 000000000..1f64986d8 Binary files /dev/null and b/tests/data/idtrees/task2/ITC/test_OSBS.dbf differ diff --git a/tests/data/idtrees/task2/ITC/test_OSBS.shp b/tests/data/idtrees/task2/ITC/test_OSBS.shp new file mode 100644 index 000000000..fb4ce2f42 Binary files /dev/null and b/tests/data/idtrees/task2/ITC/test_OSBS.shp differ diff --git 
a/tests/data/idtrees/task2/ITC/test_OSBS.shx b/tests/data/idtrees/task2/ITC/test_OSBS.shx new file mode 100644 index 000000000..afa9cfce3 Binary files /dev/null and b/tests/data/idtrees/task2/ITC/test_OSBS.shx differ diff --git a/tests/data/idtrees/task2/ITC/test_TALL.dbf b/tests/data/idtrees/task2/ITC/test_TALL.dbf new file mode 100644 index 000000000..c8e498acb Binary files /dev/null and b/tests/data/idtrees/task2/ITC/test_TALL.dbf differ diff --git a/tests/data/idtrees/task2/ITC/test_TALL.shp b/tests/data/idtrees/task2/ITC/test_TALL.shp new file mode 100644 index 000000000..573fc771e Binary files /dev/null and b/tests/data/idtrees/task2/ITC/test_TALL.shp differ diff --git a/tests/data/idtrees/task2/ITC/test_TALL.shx b/tests/data/idtrees/task2/ITC/test_TALL.shx new file mode 100644 index 000000000..5dccd1df8 Binary files /dev/null and b/tests/data/idtrees/task2/ITC/test_TALL.shx differ diff --git a/tests/data/idtrees/task2/RemoteSensing/CHM/MLBS_1.tif b/tests/data/idtrees/task2/RemoteSensing/CHM/MLBS_1.tif new file mode 100644 index 000000000..5c9f2b201 Binary files /dev/null and b/tests/data/idtrees/task2/RemoteSensing/CHM/MLBS_1.tif differ diff --git a/tests/data/idtrees/task2/RemoteSensing/CHM/OSBS_15.tif b/tests/data/idtrees/task2/RemoteSensing/CHM/OSBS_15.tif new file mode 100644 index 000000000..6af65d610 Binary files /dev/null and b/tests/data/idtrees/task2/RemoteSensing/CHM/OSBS_15.tif differ diff --git a/tests/data/idtrees/task2/RemoteSensing/CHM/TALL_2.tif b/tests/data/idtrees/task2/RemoteSensing/CHM/TALL_2.tif new file mode 100644 index 000000000..5f973638c Binary files /dev/null and b/tests/data/idtrees/task2/RemoteSensing/CHM/TALL_2.tif differ diff --git a/tests/data/idtrees/task2/RemoteSensing/HSI/MLBS_1.tif b/tests/data/idtrees/task2/RemoteSensing/HSI/MLBS_1.tif new file mode 100644 index 000000000..2976621cc Binary files /dev/null and b/tests/data/idtrees/task2/RemoteSensing/HSI/MLBS_1.tif differ diff --git 
a/tests/data/idtrees/task2/RemoteSensing/HSI/OSBS_15.tif b/tests/data/idtrees/task2/RemoteSensing/HSI/OSBS_15.tif new file mode 100644 index 000000000..592c7f7fa Binary files /dev/null and b/tests/data/idtrees/task2/RemoteSensing/HSI/OSBS_15.tif differ diff --git a/tests/data/idtrees/task2/RemoteSensing/HSI/TALL_2.tif b/tests/data/idtrees/task2/RemoteSensing/HSI/TALL_2.tif new file mode 100644 index 000000000..db0f88bd5 Binary files /dev/null and b/tests/data/idtrees/task2/RemoteSensing/HSI/TALL_2.tif differ diff --git a/tests/data/idtrees/task2/RemoteSensing/LAS/MLBS_1.las b/tests/data/idtrees/task2/RemoteSensing/LAS/MLBS_1.las new file mode 100644 index 000000000..8ba3b83b9 Binary files /dev/null and b/tests/data/idtrees/task2/RemoteSensing/LAS/MLBS_1.las differ diff --git a/tests/data/idtrees/task2/RemoteSensing/LAS/OSBS_15.las b/tests/data/idtrees/task2/RemoteSensing/LAS/OSBS_15.las new file mode 100644 index 000000000..9edbe69c2 Binary files /dev/null and b/tests/data/idtrees/task2/RemoteSensing/LAS/OSBS_15.las differ diff --git a/tests/data/idtrees/task2/RemoteSensing/LAS/TALL_2.las b/tests/data/idtrees/task2/RemoteSensing/LAS/TALL_2.las new file mode 100644 index 000000000..f26b1000a Binary files /dev/null and b/tests/data/idtrees/task2/RemoteSensing/LAS/TALL_2.las differ diff --git a/tests/data/idtrees/task2/RemoteSensing/RGB/MLBS_1.tif b/tests/data/idtrees/task2/RemoteSensing/RGB/MLBS_1.tif new file mode 100644 index 000000000..f3ff31fb2 Binary files /dev/null and b/tests/data/idtrees/task2/RemoteSensing/RGB/MLBS_1.tif differ diff --git a/tests/data/idtrees/task2/RemoteSensing/RGB/OSBS_15.tif b/tests/data/idtrees/task2/RemoteSensing/RGB/OSBS_15.tif new file mode 100644 index 000000000..1635be545 Binary files /dev/null and b/tests/data/idtrees/task2/RemoteSensing/RGB/OSBS_15.tif differ diff --git a/tests/data/idtrees/task2/RemoteSensing/RGB/TALL_2.tif b/tests/data/idtrees/task2/RemoteSensing/RGB/TALL_2.tif new file mode 100644 index 000000000..9a07eaff9 
Binary files /dev/null and b/tests/data/idtrees/task2/RemoteSensing/RGB/TALL_2.tif differ diff --git a/tests/data/idtrees/train/ITC/train_MLBS.dbf b/tests/data/idtrees/train/ITC/train_MLBS.dbf new file mode 100644 index 000000000..d88f8c20c Binary files /dev/null and b/tests/data/idtrees/train/ITC/train_MLBS.dbf differ diff --git a/tests/data/idtrees/train/ITC/train_MLBS.shp b/tests/data/idtrees/train/ITC/train_MLBS.shp new file mode 100644 index 000000000..070be5da1 Binary files /dev/null and b/tests/data/idtrees/train/ITC/train_MLBS.shp differ diff --git a/tests/data/idtrees/train/ITC/train_MLBS.shx b/tests/data/idtrees/train/ITC/train_MLBS.shx new file mode 100644 index 000000000..7157402b3 Binary files /dev/null and b/tests/data/idtrees/train/ITC/train_MLBS.shx differ diff --git a/tests/data/idtrees/train/ITC/train_OSBS.dbf b/tests/data/idtrees/train/ITC/train_OSBS.dbf new file mode 100644 index 000000000..ad8161f5c Binary files /dev/null and b/tests/data/idtrees/train/ITC/train_OSBS.dbf differ diff --git a/tests/data/idtrees/train/ITC/train_OSBS.shp b/tests/data/idtrees/train/ITC/train_OSBS.shp new file mode 100644 index 000000000..9ee65c3e2 Binary files /dev/null and b/tests/data/idtrees/train/ITC/train_OSBS.shp differ diff --git a/tests/data/idtrees/train/ITC/train_OSBS.shx b/tests/data/idtrees/train/ITC/train_OSBS.shx new file mode 100644 index 000000000..c514c8544 Binary files /dev/null and b/tests/data/idtrees/train/ITC/train_OSBS.shx differ diff --git a/tests/data/idtrees/train/RemoteSensing/CHM/MLBS_1.tif b/tests/data/idtrees/train/RemoteSensing/CHM/MLBS_1.tif new file mode 100644 index 000000000..bfa936c9e Binary files /dev/null and b/tests/data/idtrees/train/RemoteSensing/CHM/MLBS_1.tif differ diff --git a/tests/data/idtrees/train/RemoteSensing/CHM/OSBS_1.tif b/tests/data/idtrees/train/RemoteSensing/CHM/OSBS_1.tif new file mode 100644 index 000000000..a072fc25e Binary files /dev/null and b/tests/data/idtrees/train/RemoteSensing/CHM/OSBS_1.tif differ 
diff --git a/tests/data/idtrees/train/RemoteSensing/CHM/OSBS_39.tif b/tests/data/idtrees/train/RemoteSensing/CHM/OSBS_39.tif new file mode 100644 index 000000000..191d9a16a Binary files /dev/null and b/tests/data/idtrees/train/RemoteSensing/CHM/OSBS_39.tif differ diff --git a/tests/data/idtrees/train/RemoteSensing/HSI/MLBS_1.tif b/tests/data/idtrees/train/RemoteSensing/HSI/MLBS_1.tif new file mode 100644 index 000000000..0fad804da Binary files /dev/null and b/tests/data/idtrees/train/RemoteSensing/HSI/MLBS_1.tif differ diff --git a/tests/data/idtrees/train/RemoteSensing/HSI/OSBS_1.tif b/tests/data/idtrees/train/RemoteSensing/HSI/OSBS_1.tif new file mode 100644 index 000000000..6b78adbbd Binary files /dev/null and b/tests/data/idtrees/train/RemoteSensing/HSI/OSBS_1.tif differ diff --git a/tests/data/idtrees/train/RemoteSensing/HSI/OSBS_39.tif b/tests/data/idtrees/train/RemoteSensing/HSI/OSBS_39.tif new file mode 100644 index 000000000..03ea53f1b Binary files /dev/null and b/tests/data/idtrees/train/RemoteSensing/HSI/OSBS_39.tif differ diff --git a/tests/data/idtrees/train/RemoteSensing/LAS/MLBS_1.las b/tests/data/idtrees/train/RemoteSensing/LAS/MLBS_1.las new file mode 100644 index 000000000..4960fe63c Binary files /dev/null and b/tests/data/idtrees/train/RemoteSensing/LAS/MLBS_1.las differ diff --git a/tests/data/idtrees/train/RemoteSensing/LAS/OSBS_1.las b/tests/data/idtrees/train/RemoteSensing/LAS/OSBS_1.las new file mode 100644 index 000000000..494a1ad82 Binary files /dev/null and b/tests/data/idtrees/train/RemoteSensing/LAS/OSBS_1.las differ diff --git a/tests/data/idtrees/train/RemoteSensing/LAS/OSBS_39.las b/tests/data/idtrees/train/RemoteSensing/LAS/OSBS_39.las new file mode 100644 index 000000000..760476898 Binary files /dev/null and b/tests/data/idtrees/train/RemoteSensing/LAS/OSBS_39.las differ diff --git a/tests/data/idtrees/train/RemoteSensing/RGB/MLBS_1.tif b/tests/data/idtrees/train/RemoteSensing/RGB/MLBS_1.tif new file mode 100644 index 
000000000..064644fc9 Binary files /dev/null and b/tests/data/idtrees/train/RemoteSensing/RGB/MLBS_1.tif differ diff --git a/tests/data/idtrees/train/RemoteSensing/RGB/OSBS_1.tif b/tests/data/idtrees/train/RemoteSensing/RGB/OSBS_1.tif new file mode 100644 index 000000000..b499dc59e Binary files /dev/null and b/tests/data/idtrees/train/RemoteSensing/RGB/OSBS_1.tif differ diff --git a/tests/data/idtrees/train/RemoteSensing/RGB/OSBS_39.tif b/tests/data/idtrees/train/RemoteSensing/RGB/OSBS_39.tif new file mode 100644 index 000000000..f4eb86a0b Binary files /dev/null and b/tests/data/idtrees/train/RemoteSensing/RGB/OSBS_39.tif differ diff --git a/tests/datasets/test_idtrees.py b/tests/datasets/test_idtrees.py new file mode 100644 index 000000000..9f39e47e6 --- /dev/null +++ b/tests/datasets/test_idtrees.py @@ -0,0 +1,162 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import builtins +import glob +import os +import shutil +import sys +from pathlib import Path +from typing import Any, Generator + +import matplotlib.pyplot as plt +import pytest +import torch +import torch.nn as nn +from _pytest.fixtures import SubRequest +from _pytest.monkeypatch import MonkeyPatch + +import torchgeo.datasets.utils +from torchgeo.datasets import IDTReeS + + +def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: + shutil.copy(url, root) + + +class TestIDTReeS: + @pytest.fixture(params=zip(["train", "test", "test"], ["task1", "task1", "task2"])) + def dataset( + self, + monkeypatch: Generator[MonkeyPatch, None, None], + tmp_path: Path, + request: SubRequest, + ) -> IDTReeS: + pytest.importorskip("pandas") + pytest.importorskip("laspy") + monkeypatch.setattr( # type: ignore[attr-defined] + torchgeo.datasets.idtrees, "download_url", download_url + ) + data_dir = os.path.join("tests", "data", "idtrees") + metadata = { + "train": { + "url": os.path.join(data_dir, "IDTREES_competition_train_v2.zip"), + "md5": 
"5ddfa76240b4bb6b4a7861d1d31c299c", + "filename": "IDTREES_competition_train_v2.zip", + }, + "test": { + "url": os.path.join(data_dir, "IDTREES_competition_test_v2.zip"), + "md5": "b108931c84a70f2a38a8234290131c9b", + "filename": "IDTREES_competition_test_v2.zip", + }, + } + split, task = request.param + monkeypatch.setattr(IDTReeS, "metadata", metadata) # type: ignore[attr-defined] + root = str(tmp_path) + transforms = nn.Identity() # type: ignore[attr-defined] + return IDTReeS(root, split, task, transforms, download=True, checksum=True) + + @pytest.fixture(params=["pandas", "laspy", "open3d"]) + def mock_missing_module( + self, monkeypatch: Generator[MonkeyPatch, None, None], request: SubRequest + ) -> str: + import_orig = builtins.__import__ + package = str(request.param) + + def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any: + if name == package: + raise ImportError() + return import_orig(name, *args, **kwargs) + + monkeypatch.setattr( # type: ignore[attr-defined] + builtins, "__import__", mocked_import + ) + return package + + def test_getitem(self, dataset: IDTReeS) -> None: + x = dataset[0] + assert isinstance(x, dict) + assert isinstance(x["image"], torch.Tensor) + assert isinstance(x["chm"], torch.Tensor) + assert isinstance(x["hsi"], torch.Tensor) + assert isinstance(x["las"], torch.Tensor) + assert x["image"].shape == (3, 200, 200) + assert x["chm"].shape == (1, 200, 200) + assert x["hsi"].shape == (369, 200, 200) + assert x["las"].ndim == 2 + assert x["las"].shape[0] == 3 + + if "label" in x: + assert isinstance(x["label"], torch.Tensor) + if "boxes" in x: + assert isinstance(x["boxes"], torch.Tensor) + if x["boxes"].ndim != 1: + assert x["boxes"].ndim == 2 + assert x["boxes"].shape[-1] == 4 + + def test_len(self, dataset: IDTReeS) -> None: + assert len(dataset) == 3 + + def test_already_downloaded(self, dataset: IDTReeS) -> None: + IDTReeS(root=dataset.root, download=True) + + def test_not_downloaded(self, tmp_path: Path) -> None: + err = 
"Dataset not found in `root` directory and `download=False`, " + "either specify a different `root` directory or use `download=True` " + "to automaticaly download the dataset." + with pytest.raises(RuntimeError, match=err): + IDTReeS(str(tmp_path)) + + def test_not_extracted(self, tmp_path: Path) -> None: + pathname = os.path.join("tests", "data", "idtrees", "*.zip") + root = str(tmp_path) + for zipfile in glob.iglob(pathname): + shutil.copy(zipfile, root) + IDTReeS(root) + + def test_mock_missing_module( + self, dataset: IDTReeS, mock_missing_module: str + ) -> None: + package = mock_missing_module + + if package in ["pandas", "laspy"]: + with pytest.raises( + ImportError, + match=f"{package} is not installed and is required to use this dataset", + ): + IDTReeS(dataset.root, download=True, checksum=True) + else: + with pytest.raises( + ImportError, + match=f"{package} is not installed and is required to use this dataset", + ): + dataset.plot_las(0) + + def test_plot(self, dataset: IDTReeS) -> None: + x = dataset[0].copy() + dataset.plot(x, suptitle="Test") + plt.close() + dataset.plot(x, show_titles=False) + plt.close() + + if "boxes" in x: + x["prediction_boxes"] = x["boxes"] + dataset.plot(x, show_titles=True) + plt.close() + if "label" in x: + x["prediction_label"] = x["label"] + dataset.plot(x, show_titles=False) + plt.close() + + @pytest.mark.skipif( + sys.platform in ["darwin", "win32"], + reason="segmentation fault on macOS and windows", + ) + def test_plot_las(self, dataset: IDTReeS) -> None: + pytest.importorskip("open3d") + vis = dataset.plot_las(index=0, colormap="BrBG") + vis.close() + vis = dataset.plot_las(index=0, colormap=None) + vis.close() + vis = dataset.plot_las(index=1, colormap=None) + vis.close() diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index ccd5cefc2..4e29a516e 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -37,6 +37,7 @@ from .geo import ( VisionDataset, ) from .gid15 
import GID15 +from .idtrees import IDTReeS from .landcoverai import LandCoverAI, LandCoverAIDataModule from .landsat import ( Landsat, @@ -115,6 +116,7 @@ __all__ = ( "EuroSAT", "EuroSATDataModule", "GID15", + "IDTReeS", "LandCoverAI", "LandCoverAIDataModule", "LEVIRCDPlus", diff --git a/torchgeo/datasets/idtrees.py b/torchgeo/datasets/idtrees.py new file mode 100644 index 000000000..aff6ac82d --- /dev/null +++ b/torchgeo/datasets/idtrees.py @@ -0,0 +1,553 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""IDTReeS dataset.""" + +import glob +import os +from typing import Any, Callable, Dict, List, Optional, Tuple + +import fiona +import matplotlib.pyplot as plt +import numpy as np +import rasterio +import torch +from rasterio.enums import Resampling +from torch import Tensor +from torchvision.utils import draw_bounding_boxes + +from .geo import VisionDataset +from .utils import download_url, extract_archive + + +class IDTReeS(VisionDataset): + """IDTReeS dataset. + + The `IDTReeS `_ + dataset is a dataset for tree crown detection. + + Dataset features: + + * RGB Image, Canopy Height Model (CHM), Hyperspectral Image (HSI), LiDAR Point Cloud + * Remote sensing and field data generated by the + `National Ecological Observatory Network (NEON) `_ + * 0.1 - 1m resolution imagery + * Task 1 - object detection (tree crown delineation) + * Task 2 - object classification (species classification) + * Train set contains 85 images + * Test set (task 1) contains 153 images + * Test set (task 2) contains 353 images and tree crown polygons + + Dataset format: + + * optical - three-channel RGB 200x200 geotiff + * canopy height model - one-channel 20x20 geotiff + * hyperspectral - 369-channel 20x20 geotiff + * point cloud - Nx3 LAS file (.las), some files contain RGB colors per point + * shapefiles (.shp) containing polygons + * csv file containing species labels and other metadata for each polygon + + Dataset classes: + + 0. 
ACPE + 1. ACRU + 2. ACSA3 + 3. AMLA + 4. BETUL + 5. CAGL8 + 6. CATO6 + 7. FAGR + 8. GOLA + 9. LITU + 10. LYLU3 + 11. MAGNO + 12. NYBI + 13. NYSY + 14. OXYDE + 15. PEPA37 + 16. PIEL + 17. PIPA2 + 18. PINUS + 19. PITA + 20. PRSE2 + 21. QUAL + 22. QUCO2 + 23. QUGE2 + 24. QUHE2 + 25. QULA2 + 26. QULA3 + 27. QUMO4 + 28. QUNI + 29. QURU + 30. QUERC + 31. ROPS + 32. TSCA + + If you use this dataset in your research, please cite the following paper: + + * https://doi.org/10.7717/peerj.5843 + + .. versionadded:: 0.2 + """ + + classes = { + "ACPE": "Acer pensylvanicum L.", + "ACRU": "Acer rubrum L.", + "ACSA3": "Acer saccharum Marshall", + "AMLA": "Amelanchier laevis Wiegand", + "BETUL": "Betula sp.", + "CAGL8": "Carya glabra (Mill.) Sweet", + "CATO6": "Carya tomentosa (Lam.) Nutt.", + "FAGR": "Fagus grandifolia Ehrh.", + "GOLA": "Gordonia lasianthus (L.) Ellis", + "LITU": "Liriodendron tulipifera L.", + "LYLU3": "Lyonia lucida (Lam.) K. Koch", + "MAGNO": "Magnolia sp.", + "NYBI": "Nyssa biflora Walter", + "NYSY": "Nyssa sylvatica Marshall", + "OXYDE": "Oxydendrum sp.", + "PEPA37": "Persea palustris (Raf.) Sarg.", + "PIEL": "Pinus elliottii Engelm.", + "PIPA2": "Pinus palustris Mill.", + "PINUS": "Pinus sp.", + "PITA": "Pinus taeda L.", + "PRSE2": "Prunus serotina Ehrh.", + "QUAL": "Quercus alba L.", + "QUCO2": "Quercus coccinea", + "QUGE2": "Quercus geminata Small", + "QUHE2": "Quercus hemisphaerica W. Bartram ex Willd.", + "QULA2": "Quercus laevis Walter", + "QULA3": "Quercus laurifolia Michx.", + "QUMO4": "Quercus montana Willd.", + "QUNI": "Quercus nigra L.", + "QURU": "Quercus rubra L.", + "QUERC": "Quercus sp.", + "ROPS": "Robinia pseudoacacia L.", + "TSCA": "Tsuga canadensis (L.) 
Carriere", + } + metadata = { + "train": { + "url": "https://zenodo.org/record/3934932/files/IDTREES_competition_train_v2.zip?download=1", # noqa: E501 + "md5": "5ddfa76240b4bb6b4a7861d1d31c299c", + "filename": "IDTREES_competition_train_v2.zip", + }, + "test": { + "url": "https://zenodo.org/record/3934932/files/IDTREES_competition_test_v2.zip?download=1", # noqa: E501 + "md5": "b108931c84a70f2a38a8234290131c9b", + "filename": "IDTREES_competition_test_v2.zip", + }, + } + directories = {"train": ["train"], "test": ["task1", "task2"]} + image_size = (200, 200) + + def __init__( + self, + root: str = "data", + split: str = "train", + task: str = "task1", + transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, + download: bool = False, + checksum: bool = False, + ) -> None: + """Initialize a new IDTReeS dataset instance. + + Args: + root: root directory where dataset can be found + split: one of "train" or "test" + task: 'task1' for detection, 'task2' for detection + classification + (only relevant for split='test') + transforms: a function/transform that takes input sample and its target as + entry and returns a transformed version + download: if True, download dataset and store it in the root directory + checksum: if True, check the MD5 of the downloaded files (may be slow) + + Raises: + ImportError: if laspy or pandas are not installed + """ + assert split in ["train", "test"] + assert task in ["task1", "task2"] + self.root = root + self.split = split + self.task = task + self.transforms = transforms + self.download = download + self.checksum = checksum + self.class2idx = {c: i for i, c in enumerate(self.classes)} + self.idx2class = {i: c for i, c in enumerate(self.classes)} + self.num_classes = len(self.classes) + self._verify() + + try: + import pandas as pd # noqa: F401 + except ImportError: + raise ImportError( + "pandas is not installed and is required to use this dataset" + ) + try: + import laspy # noqa: F401 + except 
ImportError: + raise ImportError( + "laspy is not installed and is required to use this dataset" + ) + + self.images, self.geometries, self.labels = self._load(root) + + def __getitem__(self, index: int) -> Dict[str, Tensor]: + """Return an index within the dataset. + + Args: + index: index to return + + Returns: + data and label at that index + """ + path = self.images[index] + image = self._load_image(path).to(torch.uint8) # type:ignore[attr-defined] + hsi = self._load_image(path.replace("RGB", "HSI")) + chm = self._load_image(path.replace("RGB", "CHM")) + las = self._load_las(path.replace("RGB", "LAS").replace(".tif", ".las")) + sample = {"image": image, "hsi": hsi, "chm": chm, "las": las} + + if self.split == "test": + if self.task == "task2": + sample["boxes"] = self._load_boxes(path) + else: + sample["boxes"] = self._load_boxes(path) + sample["label"] = self._load_target(path) + + if self.transforms is not None: + sample = self.transforms(sample) + + return sample + + def __len__(self) -> int: + """Return the number of data points in the dataset. + + Returns: + length of the dataset + """ + return len(self.images) + + def _load_image(self, path: str) -> Tensor: + """Load a tiff file. + + Args: + path: path to .tif file + + Returns: + the image + """ + with rasterio.open(path) as f: + array = f.read(out_shape=self.image_size, resampling=Resampling.bilinear) + tensor: Tensor = torch.from_numpy(array) # type: ignore[attr-defined] + return tensor + + def _load_las(self, path: str) -> Tensor: + """Load a single point cloud. + + Args: + path: path to .las file + + Returns: + the point cloud + """ + import laspy + + las = laspy.read(path) + array = np.stack([las.x, las.y, las.z], axis=0) + tensor: Tensor = torch.from_numpy(array) # type: ignore[attr-defined] + return tensor + + def _load_boxes(self, path: str) -> Tensor: + """Load object bounding boxes. 
+ + Args: + path: path to .tif file + + Returns: + the bounding boxes in pixel (xmin, ymin, xmax, ymax) format + """ + base_path = os.path.basename(path) + + # Find object ids and geometries + if self.split == "train": + indices = self.labels["rsFile"] == base_path + ids = self.labels[indices]["id"].tolist() + geoms = [self.geometries[i]["geometry"]["coordinates"][0][:4] for i in ids] + # Test set - Task 2 has no mapping csv. Mapping is inside of geometry + else: + ids = [ + k + for k, v in self.geometries.items() + if v["properties"]["plotID"] == base_path + ] + geoms = [self.geometries[i]["geometry"]["coordinates"][0][:4] for i in ids] + + # Convert to pixel coords; f.index() returns (row, col), i.e. (y, x) + boxes = [] + with rasterio.open(path) as f: + for geom in geoms: + coords = [f.index(x, y) for x, y in geom] + xmin = min([coord[1] for coord in coords]) + xmax = max([coord[1] for coord in coords]) + ymin = min([coord[0] for coord in coords]) + ymax = max([coord[0] for coord in coords]) + boxes.append([xmin, ymin, xmax, ymax]) + + tensor: Tensor = torch.tensor(boxes) # type: ignore[attr-defined] + return tensor + + def _load_target(self, path: str) -> Tensor: + """Load target label for a single sample. + + Args: + path: path to image + + Returns: + the label + """ + # Find indices for objects in the image + base_path = os.path.basename(path) + indices = self.labels["rsFile"] == base_path + + # Load object labels + classes = self.labels[indices]["taxonID"].tolist() + labels = [self.class2idx[c] for c in classes] + tensor: Tensor = torch.tensor(labels) # type: ignore[attr-defined] + return tensor + + def _load(self, root: str) -> Tuple[List[str], Dict[int, Dict[str, Any]], Any]: + """Load files, geometries, and labels. 
+ + Args: + root: root directory + + Returns: + the sorted image paths, geometries, and labels (None where unavailable) + """ + import pandas as pd + + if self.split == "train": + directory = os.path.join(root, self.directories[self.split][0]) + labels: pd.DataFrame = self._load_labels(directory) + geoms = self._load_geometries(directory) + else: + directory = os.path.join(root, self.task) + if self.task == "task1": + geoms = None # type: ignore[assignment] + labels = None + else: + geoms = self._load_geometries(directory) + labels = None + + pattern = os.path.join(directory, "RemoteSensing", "RGB", "*.tif") + images = sorted(glob.glob(pattern))  # glob order is arbitrary; keep indices stable + return images, geoms, labels + + def _load_labels(self, directory: str) -> Any: + """Load the csv files containing the labels. + + Args: + directory: directory containing csv files + + Returns: + a pandas DataFrame containing the labels for each image + """ + import pandas as pd + + path_mapping = os.path.join(directory, "Field", "itc_rsFile.csv") + path_labels = os.path.join(directory, "Field", "train_data.csv") + df_mapping = pd.read_csv(path_mapping) + df_labels = pd.read_csv(path_labels) + df_mapping = df_mapping.set_index("indvdID", drop=True) + df_labels = df_labels.set_index("indvdID", drop=True) + df = df_labels.join(df_mapping, on="indvdID") + df = df.drop_duplicates() + df = df.reset_index() + return df + + def _load_geometries(self, directory: str) -> Dict[int, Dict[str, Any]]: + """Load the shape files containing the geometries. 
+ + Args: + directory: directory containing .shp files + + Returns: + a dict containing the geometries for each object + """ + filepaths = glob.glob(os.path.join(directory, "ITC", "*.shp")) + + features: Dict[int, Dict[str, Any]] = {} + for path in filepaths: + with fiona.open(path) as src: + for feature in src: + if self.split == "train": + features[feature["properties"]["id"]] = feature + # Test set task 2 has no id: use a running key unique across all files + else: + features[len(features)] = feature + return features + + def _verify(self) -> None: + """Verify the integrity of the dataset. + + Raises: + RuntimeError: if ``download=False`` but dataset is missing or checksum fails + """ + url = self.metadata[self.split]["url"] + md5 = self.metadata[self.split]["md5"] + filename = self.metadata[self.split]["filename"] + directories = self.directories[self.split] + + # Check if the files already exist + exists = [ + os.path.exists(os.path.join(self.root, directory)) + for directory in directories + ] + if all(exists): + return + + # Check if zip file already exists (if so then extract) + filepath = os.path.join(self.root, filename) + if os.path.exists(filepath): + extract_archive(filepath) + return + + # Check if the user requested to download the dataset + if not self.download: + raise RuntimeError( + "Dataset not found in `root` directory and `download=False`, " + "either specify a different `root` directory or use `download=True` " + "to automaticaly download the dataset." + ) + + # Download and extract the dataset + download_url( + url, self.root, filename=filename, md5=md5 if self.checksum else None + ) + filepath = os.path.join(self.root, filename) + extract_archive(filepath) + + def plot( + self, + sample: Dict[str, Tensor], + show_titles: bool = True, + suptitle: Optional[str] = None, + hsi_indices: Tuple[int, int, int] = (0, 1, 2), + ) -> plt.Figure: + """Plot a sample from the dataset. 
+ + Args: + sample: a sample returned by :meth:`__getitem__` + show_titles: flag indicating whether to show titles above each panel + suptitle: optional string to use as a suptitle + hsi_indices: tuple of indices to create HSI false color image + + Returns: + a matplotlib Figure with the rendered sample + """ + assert len(hsi_indices) == 3 + + def normalize(x: Tensor) -> Tensor: + return (x - x.min()) / (x.max() - x.min()) + + ncols = 3 + + hsi = normalize(sample["hsi"][hsi_indices, :, :]).permute((1, 2, 0)).numpy() + chm = normalize(sample["chm"]).permute((1, 2, 0)).numpy() + + if "boxes" in sample: + labels = ( + [self.idx2class[int(i)] for i in sample["label"]] + if "label" in sample + else None + ) + image = draw_bounding_boxes( + image=sample["image"], boxes=sample["boxes"], labels=labels + ) + image = image.permute((1, 2, 0)).numpy() + else: + image = sample["image"].permute((1, 2, 0)).numpy() + + if "prediction_boxes" in sample: + ncols += 1 + labels = ( + [self.idx2class[int(i)] for i in sample["prediction_label"]] + if "prediction_label" in sample + else None + ) + preds = draw_bounding_boxes( + image=sample["image"], boxes=sample["prediction_boxes"], labels=labels + ) + preds = preds.permute((1, 2, 0)).numpy() + + fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10)) + axs[0].imshow(image) + axs[0].axis("off") + axs[1].imshow(hsi) + axs[1].axis("off") + axs[2].imshow(chm) + axs[2].axis("off") + if ncols > 3: + axs[3].imshow(preds) + axs[3].axis("off") + + if show_titles: + axs[0].set_title("Ground Truth") + axs[1].set_title("Hyperspectral False Color Image") + axs[2].set_title("Canopy Height Model") + if ncols > 3: + axs[3].set_title("Predictions") + + if suptitle is not None: + plt.suptitle(suptitle) + + return fig + + def plot_las(self, index: int, colormap: Optional[str] = None) -> Any: + """Plot a sample point cloud at the index. 
+ + Args: + index: index to plot + colormap: a valid matplotlib colormap + + Returns: + a open3d.visualizer.Visualizer object. Use + Visualizer.run() to display + + Raises: + ImportError: if open3d is not installed + """ + try: + import open3d # noqa: F401 + except ImportError: + raise ImportError( + "open3d is not installed and is required to use this dataset" + ) + import laspy + + path = self.images[index] + path = path.replace("RGB", "LAS").replace(".tif", ".las") + las = laspy.read(path) + points = np.stack([las.x, las.y, las.z], axis=0).transpose((1, 0)) + + if colormap: + cm = plt.cm.get_cmap(colormap) + norm = plt.Normalize() + colors = cm(norm(points[:, 2]))[:, :3] + else: + # Some point cloud files have no color->points mapping + if hasattr(las, "red"): + colors = np.stack([las.red, las.green, las.blue], axis=0) + colors = colors.transpose((1, 0)) / 65535 + # Default to no colormap if no colors exist in las file + else: + colors = np.zeros_like(points) + + pcd = open3d.geometry.PointCloud() + pcd.points = open3d.utility.Vector3dVector(points) + pcd.colors = open3d.utility.Vector3dVector(colors) + vis = open3d.visualization.Visualizer() + vis.create_window() + vis.add_geometry(pcd) + return vis