diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst
index b5a33461a..0bdf71dd2 100644
--- a/docs/api/datasets.rst
+++ b/docs/api/datasets.rst
@@ -118,6 +118,11 @@ GID-15 (Gaofen Image Dataset)
.. autoclass:: GID15
+IDTReeS
+^^^^^^^
+
+.. autoclass:: IDTReeS
+
LandCover.ai (Land Cover from Aerial Imagery)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
diff --git a/environment.yml b/environment.yml
index aef403329..17cea448a 100644
--- a/environment.yml
+++ b/environment.yml
@@ -1,6 +1,7 @@
name: torchgeo
channels:
- conda-forge
+ - open3d-admin
dependencies:
- cudatoolkit
- einops
@@ -23,11 +24,14 @@ dependencies:
- isort[colors]>=5.8
- jupyterlab
- kornia>=0.5.4
+ - laspy>=2.0.0
- mypy>=0.900
- nbmake>=0.1
- nbsphinx>=0.8.5
- omegaconf>=2.1
+ - open3d>=0.11.2
- opencv-python
+ - pandas>=0.19.1
- pillow>=2.9
- pydocstyle[toml]>=6.1
- pytest>=6
diff --git a/setup.cfg b/setup.cfg
index f665c337f..ebcf9b3d2 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -65,7 +65,14 @@ include = torchgeo*
# Optional dataset requirements
datasets =
h5py
+ # loading .las point clouds (idtrees) laspy 2+ required for Python 3.6+ support
+ laspy>=2.0.0
+ # open3d will add add support for python 3.9 in pypi in v0.14. v0.11.2 last version for tests to pass
+ # https://github.com/isl-org/Open3D/issues/1550
+ open3d>=0.11.2;python_version<'3.9'
opencv-python
+ # pandas 0.19.1+ required for python 3.6 support
+ pandas>=0.19.1
pycocotools
# radiant-mlhub 0.2.1+ required for api_key bugfix:
# https://github.com/radiantearth/radiant-mlhub/pull/48
diff --git a/tests/data/README.md b/tests/data/README.md
index 35cca49b9..3884cbf38 100644
--- a/tests/data/README.md
+++ b/tests/data/README.md
@@ -91,3 +91,28 @@ masks = np.random.randint(low=0, high=num_classes, size=(1, 1)).astype(np.uint8)
f.create_dataset("images", data=images)
f.create_dataset("masks", data=masks)
f.close()
+```
+
+### LAS Point Cloud files
+
+```python
+import laspy
+
+num_points = 4
+
+las = laspy.read("0.las")
+las.points = las.points[:num_points]
+
+points = np.random.randint(low=0, high=100, size=(num_points,), dtype=las.x.dtype)
+las.x = points
+las.y = points
+las.z = points
+
+if hasattr(las, "red"):
+ colors = np.random.randint(low=0, high=10, size=(num_points,), dtype=las.red.dtype)
+ las.red = colors
+ las.green = colors
+ las.blue = colors
+
+las.write("0.las")
+```
diff --git a/tests/data/idtrees/IDTREES_competition_test_v2.zip b/tests/data/idtrees/IDTREES_competition_test_v2.zip
new file mode 100644
index 000000000..26870fef1
Binary files /dev/null and b/tests/data/idtrees/IDTREES_competition_test_v2.zip differ
diff --git a/tests/data/idtrees/IDTREES_competition_train_v2.zip b/tests/data/idtrees/IDTREES_competition_train_v2.zip
new file mode 100644
index 000000000..1385991f4
Binary files /dev/null and b/tests/data/idtrees/IDTREES_competition_train_v2.zip differ
diff --git a/tests/data/idtrees/task1/RemoteSensing/CHM/MLBS_4.tif b/tests/data/idtrees/task1/RemoteSensing/CHM/MLBS_4.tif
new file mode 100644
index 000000000..22369e56f
Binary files /dev/null and b/tests/data/idtrees/task1/RemoteSensing/CHM/MLBS_4.tif differ
diff --git a/tests/data/idtrees/task1/RemoteSensing/CHM/OSBS_11.tif b/tests/data/idtrees/task1/RemoteSensing/CHM/OSBS_11.tif
new file mode 100644
index 000000000..21dc8b0cd
Binary files /dev/null and b/tests/data/idtrees/task1/RemoteSensing/CHM/OSBS_11.tif differ
diff --git a/tests/data/idtrees/task1/RemoteSensing/CHM/TALL_1.tif b/tests/data/idtrees/task1/RemoteSensing/CHM/TALL_1.tif
new file mode 100644
index 000000000..8e3f03ef0
Binary files /dev/null and b/tests/data/idtrees/task1/RemoteSensing/CHM/TALL_1.tif differ
diff --git a/tests/data/idtrees/task1/RemoteSensing/HSI/MLBS_4.tif b/tests/data/idtrees/task1/RemoteSensing/HSI/MLBS_4.tif
new file mode 100644
index 000000000..5a2dd4408
Binary files /dev/null and b/tests/data/idtrees/task1/RemoteSensing/HSI/MLBS_4.tif differ
diff --git a/tests/data/idtrees/task1/RemoteSensing/HSI/OSBS_11.tif b/tests/data/idtrees/task1/RemoteSensing/HSI/OSBS_11.tif
new file mode 100644
index 000000000..39f61a017
Binary files /dev/null and b/tests/data/idtrees/task1/RemoteSensing/HSI/OSBS_11.tif differ
diff --git a/tests/data/idtrees/task1/RemoteSensing/HSI/TALL_1.tif b/tests/data/idtrees/task1/RemoteSensing/HSI/TALL_1.tif
new file mode 100644
index 000000000..e03b7bf76
Binary files /dev/null and b/tests/data/idtrees/task1/RemoteSensing/HSI/TALL_1.tif differ
diff --git a/tests/data/idtrees/task1/RemoteSensing/LAS/MLBS_4.las b/tests/data/idtrees/task1/RemoteSensing/LAS/MLBS_4.las
new file mode 100644
index 000000000..f185afc11
Binary files /dev/null and b/tests/data/idtrees/task1/RemoteSensing/LAS/MLBS_4.las differ
diff --git a/tests/data/idtrees/task1/RemoteSensing/LAS/OSBS_11.las b/tests/data/idtrees/task1/RemoteSensing/LAS/OSBS_11.las
new file mode 100644
index 000000000..19f0159ca
Binary files /dev/null and b/tests/data/idtrees/task1/RemoteSensing/LAS/OSBS_11.las differ
diff --git a/tests/data/idtrees/task1/RemoteSensing/LAS/TALL_1.las b/tests/data/idtrees/task1/RemoteSensing/LAS/TALL_1.las
new file mode 100644
index 000000000..723efb46d
Binary files /dev/null and b/tests/data/idtrees/task1/RemoteSensing/LAS/TALL_1.las differ
diff --git a/tests/data/idtrees/task1/RemoteSensing/RGB/MLBS_4.tif b/tests/data/idtrees/task1/RemoteSensing/RGB/MLBS_4.tif
new file mode 100644
index 000000000..d6b3d6dcb
Binary files /dev/null and b/tests/data/idtrees/task1/RemoteSensing/RGB/MLBS_4.tif differ
diff --git a/tests/data/idtrees/task1/RemoteSensing/RGB/OSBS_11.tif b/tests/data/idtrees/task1/RemoteSensing/RGB/OSBS_11.tif
new file mode 100644
index 000000000..7525657c0
Binary files /dev/null and b/tests/data/idtrees/task1/RemoteSensing/RGB/OSBS_11.tif differ
diff --git a/tests/data/idtrees/task1/RemoteSensing/RGB/TALL_1.tif b/tests/data/idtrees/task1/RemoteSensing/RGB/TALL_1.tif
new file mode 100644
index 000000000..d5692e445
Binary files /dev/null and b/tests/data/idtrees/task1/RemoteSensing/RGB/TALL_1.tif differ
diff --git a/tests/data/idtrees/task2/ITC/test_MLBS.dbf b/tests/data/idtrees/task2/ITC/test_MLBS.dbf
new file mode 100644
index 000000000..e7d06021b
Binary files /dev/null and b/tests/data/idtrees/task2/ITC/test_MLBS.dbf differ
diff --git a/tests/data/idtrees/task2/ITC/test_MLBS.shp b/tests/data/idtrees/task2/ITC/test_MLBS.shp
new file mode 100644
index 000000000..45d90f102
Binary files /dev/null and b/tests/data/idtrees/task2/ITC/test_MLBS.shp differ
diff --git a/tests/data/idtrees/task2/ITC/test_MLBS.shx b/tests/data/idtrees/task2/ITC/test_MLBS.shx
new file mode 100644
index 000000000..f9e0f0ae9
Binary files /dev/null and b/tests/data/idtrees/task2/ITC/test_MLBS.shx differ
diff --git a/tests/data/idtrees/task2/ITC/test_OSBS.dbf b/tests/data/idtrees/task2/ITC/test_OSBS.dbf
new file mode 100644
index 000000000..1f64986d8
Binary files /dev/null and b/tests/data/idtrees/task2/ITC/test_OSBS.dbf differ
diff --git a/tests/data/idtrees/task2/ITC/test_OSBS.shp b/tests/data/idtrees/task2/ITC/test_OSBS.shp
new file mode 100644
index 000000000..fb4ce2f42
Binary files /dev/null and b/tests/data/idtrees/task2/ITC/test_OSBS.shp differ
diff --git a/tests/data/idtrees/task2/ITC/test_OSBS.shx b/tests/data/idtrees/task2/ITC/test_OSBS.shx
new file mode 100644
index 000000000..afa9cfce3
Binary files /dev/null and b/tests/data/idtrees/task2/ITC/test_OSBS.shx differ
diff --git a/tests/data/idtrees/task2/ITC/test_TALL.dbf b/tests/data/idtrees/task2/ITC/test_TALL.dbf
new file mode 100644
index 000000000..c8e498acb
Binary files /dev/null and b/tests/data/idtrees/task2/ITC/test_TALL.dbf differ
diff --git a/tests/data/idtrees/task2/ITC/test_TALL.shp b/tests/data/idtrees/task2/ITC/test_TALL.shp
new file mode 100644
index 000000000..573fc771e
Binary files /dev/null and b/tests/data/idtrees/task2/ITC/test_TALL.shp differ
diff --git a/tests/data/idtrees/task2/ITC/test_TALL.shx b/tests/data/idtrees/task2/ITC/test_TALL.shx
new file mode 100644
index 000000000..5dccd1df8
Binary files /dev/null and b/tests/data/idtrees/task2/ITC/test_TALL.shx differ
diff --git a/tests/data/idtrees/task2/RemoteSensing/CHM/MLBS_1.tif b/tests/data/idtrees/task2/RemoteSensing/CHM/MLBS_1.tif
new file mode 100644
index 000000000..5c9f2b201
Binary files /dev/null and b/tests/data/idtrees/task2/RemoteSensing/CHM/MLBS_1.tif differ
diff --git a/tests/data/idtrees/task2/RemoteSensing/CHM/OSBS_15.tif b/tests/data/idtrees/task2/RemoteSensing/CHM/OSBS_15.tif
new file mode 100644
index 000000000..6af65d610
Binary files /dev/null and b/tests/data/idtrees/task2/RemoteSensing/CHM/OSBS_15.tif differ
diff --git a/tests/data/idtrees/task2/RemoteSensing/CHM/TALL_2.tif b/tests/data/idtrees/task2/RemoteSensing/CHM/TALL_2.tif
new file mode 100644
index 000000000..5f973638c
Binary files /dev/null and b/tests/data/idtrees/task2/RemoteSensing/CHM/TALL_2.tif differ
diff --git a/tests/data/idtrees/task2/RemoteSensing/HSI/MLBS_1.tif b/tests/data/idtrees/task2/RemoteSensing/HSI/MLBS_1.tif
new file mode 100644
index 000000000..2976621cc
Binary files /dev/null and b/tests/data/idtrees/task2/RemoteSensing/HSI/MLBS_1.tif differ
diff --git a/tests/data/idtrees/task2/RemoteSensing/HSI/OSBS_15.tif b/tests/data/idtrees/task2/RemoteSensing/HSI/OSBS_15.tif
new file mode 100644
index 000000000..592c7f7fa
Binary files /dev/null and b/tests/data/idtrees/task2/RemoteSensing/HSI/OSBS_15.tif differ
diff --git a/tests/data/idtrees/task2/RemoteSensing/HSI/TALL_2.tif b/tests/data/idtrees/task2/RemoteSensing/HSI/TALL_2.tif
new file mode 100644
index 000000000..db0f88bd5
Binary files /dev/null and b/tests/data/idtrees/task2/RemoteSensing/HSI/TALL_2.tif differ
diff --git a/tests/data/idtrees/task2/RemoteSensing/LAS/MLBS_1.las b/tests/data/idtrees/task2/RemoteSensing/LAS/MLBS_1.las
new file mode 100644
index 000000000..8ba3b83b9
Binary files /dev/null and b/tests/data/idtrees/task2/RemoteSensing/LAS/MLBS_1.las differ
diff --git a/tests/data/idtrees/task2/RemoteSensing/LAS/OSBS_15.las b/tests/data/idtrees/task2/RemoteSensing/LAS/OSBS_15.las
new file mode 100644
index 000000000..9edbe69c2
Binary files /dev/null and b/tests/data/idtrees/task2/RemoteSensing/LAS/OSBS_15.las differ
diff --git a/tests/data/idtrees/task2/RemoteSensing/LAS/TALL_2.las b/tests/data/idtrees/task2/RemoteSensing/LAS/TALL_2.las
new file mode 100644
index 000000000..f26b1000a
Binary files /dev/null and b/tests/data/idtrees/task2/RemoteSensing/LAS/TALL_2.las differ
diff --git a/tests/data/idtrees/task2/RemoteSensing/RGB/MLBS_1.tif b/tests/data/idtrees/task2/RemoteSensing/RGB/MLBS_1.tif
new file mode 100644
index 000000000..f3ff31fb2
Binary files /dev/null and b/tests/data/idtrees/task2/RemoteSensing/RGB/MLBS_1.tif differ
diff --git a/tests/data/idtrees/task2/RemoteSensing/RGB/OSBS_15.tif b/tests/data/idtrees/task2/RemoteSensing/RGB/OSBS_15.tif
new file mode 100644
index 000000000..1635be545
Binary files /dev/null and b/tests/data/idtrees/task2/RemoteSensing/RGB/OSBS_15.tif differ
diff --git a/tests/data/idtrees/task2/RemoteSensing/RGB/TALL_2.tif b/tests/data/idtrees/task2/RemoteSensing/RGB/TALL_2.tif
new file mode 100644
index 000000000..9a07eaff9
Binary files /dev/null and b/tests/data/idtrees/task2/RemoteSensing/RGB/TALL_2.tif differ
diff --git a/tests/data/idtrees/train/ITC/train_MLBS.dbf b/tests/data/idtrees/train/ITC/train_MLBS.dbf
new file mode 100644
index 000000000..d88f8c20c
Binary files /dev/null and b/tests/data/idtrees/train/ITC/train_MLBS.dbf differ
diff --git a/tests/data/idtrees/train/ITC/train_MLBS.shp b/tests/data/idtrees/train/ITC/train_MLBS.shp
new file mode 100644
index 000000000..070be5da1
Binary files /dev/null and b/tests/data/idtrees/train/ITC/train_MLBS.shp differ
diff --git a/tests/data/idtrees/train/ITC/train_MLBS.shx b/tests/data/idtrees/train/ITC/train_MLBS.shx
new file mode 100644
index 000000000..7157402b3
Binary files /dev/null and b/tests/data/idtrees/train/ITC/train_MLBS.shx differ
diff --git a/tests/data/idtrees/train/ITC/train_OSBS.dbf b/tests/data/idtrees/train/ITC/train_OSBS.dbf
new file mode 100644
index 000000000..ad8161f5c
Binary files /dev/null and b/tests/data/idtrees/train/ITC/train_OSBS.dbf differ
diff --git a/tests/data/idtrees/train/ITC/train_OSBS.shp b/tests/data/idtrees/train/ITC/train_OSBS.shp
new file mode 100644
index 000000000..9ee65c3e2
Binary files /dev/null and b/tests/data/idtrees/train/ITC/train_OSBS.shp differ
diff --git a/tests/data/idtrees/train/ITC/train_OSBS.shx b/tests/data/idtrees/train/ITC/train_OSBS.shx
new file mode 100644
index 000000000..c514c8544
Binary files /dev/null and b/tests/data/idtrees/train/ITC/train_OSBS.shx differ
diff --git a/tests/data/idtrees/train/RemoteSensing/CHM/MLBS_1.tif b/tests/data/idtrees/train/RemoteSensing/CHM/MLBS_1.tif
new file mode 100644
index 000000000..bfa936c9e
Binary files /dev/null and b/tests/data/idtrees/train/RemoteSensing/CHM/MLBS_1.tif differ
diff --git a/tests/data/idtrees/train/RemoteSensing/CHM/OSBS_1.tif b/tests/data/idtrees/train/RemoteSensing/CHM/OSBS_1.tif
new file mode 100644
index 000000000..a072fc25e
Binary files /dev/null and b/tests/data/idtrees/train/RemoteSensing/CHM/OSBS_1.tif differ
diff --git a/tests/data/idtrees/train/RemoteSensing/CHM/OSBS_39.tif b/tests/data/idtrees/train/RemoteSensing/CHM/OSBS_39.tif
new file mode 100644
index 000000000..191d9a16a
Binary files /dev/null and b/tests/data/idtrees/train/RemoteSensing/CHM/OSBS_39.tif differ
diff --git a/tests/data/idtrees/train/RemoteSensing/HSI/MLBS_1.tif b/tests/data/idtrees/train/RemoteSensing/HSI/MLBS_1.tif
new file mode 100644
index 000000000..0fad804da
Binary files /dev/null and b/tests/data/idtrees/train/RemoteSensing/HSI/MLBS_1.tif differ
diff --git a/tests/data/idtrees/train/RemoteSensing/HSI/OSBS_1.tif b/tests/data/idtrees/train/RemoteSensing/HSI/OSBS_1.tif
new file mode 100644
index 000000000..6b78adbbd
Binary files /dev/null and b/tests/data/idtrees/train/RemoteSensing/HSI/OSBS_1.tif differ
diff --git a/tests/data/idtrees/train/RemoteSensing/HSI/OSBS_39.tif b/tests/data/idtrees/train/RemoteSensing/HSI/OSBS_39.tif
new file mode 100644
index 000000000..03ea53f1b
Binary files /dev/null and b/tests/data/idtrees/train/RemoteSensing/HSI/OSBS_39.tif differ
diff --git a/tests/data/idtrees/train/RemoteSensing/LAS/MLBS_1.las b/tests/data/idtrees/train/RemoteSensing/LAS/MLBS_1.las
new file mode 100644
index 000000000..4960fe63c
Binary files /dev/null and b/tests/data/idtrees/train/RemoteSensing/LAS/MLBS_1.las differ
diff --git a/tests/data/idtrees/train/RemoteSensing/LAS/OSBS_1.las b/tests/data/idtrees/train/RemoteSensing/LAS/OSBS_1.las
new file mode 100644
index 000000000..494a1ad82
Binary files /dev/null and b/tests/data/idtrees/train/RemoteSensing/LAS/OSBS_1.las differ
diff --git a/tests/data/idtrees/train/RemoteSensing/LAS/OSBS_39.las b/tests/data/idtrees/train/RemoteSensing/LAS/OSBS_39.las
new file mode 100644
index 000000000..760476898
Binary files /dev/null and b/tests/data/idtrees/train/RemoteSensing/LAS/OSBS_39.las differ
diff --git a/tests/data/idtrees/train/RemoteSensing/RGB/MLBS_1.tif b/tests/data/idtrees/train/RemoteSensing/RGB/MLBS_1.tif
new file mode 100644
index 000000000..064644fc9
Binary files /dev/null and b/tests/data/idtrees/train/RemoteSensing/RGB/MLBS_1.tif differ
diff --git a/tests/data/idtrees/train/RemoteSensing/RGB/OSBS_1.tif b/tests/data/idtrees/train/RemoteSensing/RGB/OSBS_1.tif
new file mode 100644
index 000000000..b499dc59e
Binary files /dev/null and b/tests/data/idtrees/train/RemoteSensing/RGB/OSBS_1.tif differ
diff --git a/tests/data/idtrees/train/RemoteSensing/RGB/OSBS_39.tif b/tests/data/idtrees/train/RemoteSensing/RGB/OSBS_39.tif
new file mode 100644
index 000000000..f4eb86a0b
Binary files /dev/null and b/tests/data/idtrees/train/RemoteSensing/RGB/OSBS_39.tif differ
diff --git a/tests/datasets/test_idtrees.py b/tests/datasets/test_idtrees.py
new file mode 100644
index 000000000..9f39e47e6
--- /dev/null
+++ b/tests/datasets/test_idtrees.py
@@ -0,0 +1,162 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import builtins
+import glob
+import os
+import shutil
+import sys
+from pathlib import Path
+from typing import Any, Generator
+
+import matplotlib.pyplot as plt
+import pytest
+import torch
+import torch.nn as nn
+from _pytest.fixtures import SubRequest
+from _pytest.monkeypatch import MonkeyPatch
+
+import torchgeo.datasets.utils
+from torchgeo.datasets import IDTReeS
+
+
+def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
+ shutil.copy(url, root)
+
+
+class TestIDTReeS:
+ @pytest.fixture(params=zip(["train", "test", "test"], ["task1", "task1", "task2"]))
+ def dataset(
+ self,
+ monkeypatch: Generator[MonkeyPatch, None, None],
+ tmp_path: Path,
+ request: SubRequest,
+ ) -> IDTReeS:
+ pytest.importorskip("pandas")
+ pytest.importorskip("laspy")
+ monkeypatch.setattr( # type: ignore[attr-defined]
+ torchgeo.datasets.idtrees, "download_url", download_url
+ )
+ data_dir = os.path.join("tests", "data", "idtrees")
+ metadata = {
+ "train": {
+ "url": os.path.join(data_dir, "IDTREES_competition_train_v2.zip"),
+ "md5": "5ddfa76240b4bb6b4a7861d1d31c299c",
+ "filename": "IDTREES_competition_train_v2.zip",
+ },
+ "test": {
+ "url": os.path.join(data_dir, "IDTREES_competition_test_v2.zip"),
+ "md5": "b108931c84a70f2a38a8234290131c9b",
+ "filename": "IDTREES_competition_test_v2.zip",
+ },
+ }
+ split, task = request.param
+ monkeypatch.setattr(IDTReeS, "metadata", metadata) # type: ignore[attr-defined]
+ root = str(tmp_path)
+ transforms = nn.Identity() # type: ignore[attr-defined]
+ return IDTReeS(root, split, task, transforms, download=True, checksum=True)
+
+ @pytest.fixture(params=["pandas", "laspy", "open3d"])
+ def mock_missing_module(
+ self, monkeypatch: Generator[MonkeyPatch, None, None], request: SubRequest
+ ) -> str:
+ import_orig = builtins.__import__
+ package = str(request.param)
+
+ def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
+ if name == package:
+ raise ImportError()
+ return import_orig(name, *args, **kwargs)
+
+ monkeypatch.setattr( # type: ignore[attr-defined]
+ builtins, "__import__", mocked_import
+ )
+ return package
+
+ def test_getitem(self, dataset: IDTReeS) -> None:
+ x = dataset[0]
+ assert isinstance(x, dict)
+ assert isinstance(x["image"], torch.Tensor)
+ assert isinstance(x["chm"], torch.Tensor)
+ assert isinstance(x["hsi"], torch.Tensor)
+ assert isinstance(x["las"], torch.Tensor)
+ assert x["image"].shape == (3, 200, 200)
+ assert x["chm"].shape == (1, 200, 200)
+ assert x["hsi"].shape == (369, 200, 200)
+ assert x["las"].ndim == 2
+ assert x["las"].shape[0] == 3
+
+ if "label" in x:
+ assert isinstance(x["label"], torch.Tensor)
+ if "boxes" in x:
+ assert isinstance(x["boxes"], torch.Tensor)
+ if x["boxes"].ndim != 1:
+ assert x["boxes"].ndim == 2
+ assert x["boxes"].shape[-1] == 4
+
+ def test_len(self, dataset: IDTReeS) -> None:
+ assert len(dataset) == 3
+
+ def test_already_downloaded(self, dataset: IDTReeS) -> None:
+ IDTReeS(root=dataset.root, download=True)
+
+ def test_not_downloaded(self, tmp_path: Path) -> None:
+ err = "Dataset not found in `root` directory and `download=False`, "
+ "either specify a different `root` directory or use `download=True` "
+ "to automaticaly download the dataset."
+ with pytest.raises(RuntimeError, match=err):
+ IDTReeS(str(tmp_path))
+
+ def test_not_extracted(self, tmp_path: Path) -> None:
+ pathname = os.path.join("tests", "data", "idtrees", "*.zip")
+ root = str(tmp_path)
+ for zipfile in glob.iglob(pathname):
+ shutil.copy(zipfile, root)
+ IDTReeS(root)
+
+ def test_mock_missing_module(
+ self, dataset: IDTReeS, mock_missing_module: str
+ ) -> None:
+ package = mock_missing_module
+
+ if package in ["pandas", "laspy"]:
+ with pytest.raises(
+ ImportError,
+ match=f"{package} is not installed and is required to use this dataset",
+ ):
+ IDTReeS(dataset.root, download=True, checksum=True)
+ else:
+ with pytest.raises(
+ ImportError,
+ match=f"{package} is not installed and is required to use this dataset",
+ ):
+ dataset.plot_las(0)
+
+ def test_plot(self, dataset: IDTReeS) -> None:
+ x = dataset[0].copy()
+ dataset.plot(x, suptitle="Test")
+ plt.close()
+ dataset.plot(x, show_titles=False)
+ plt.close()
+
+ if "boxes" in x:
+ x["prediction_boxes"] = x["boxes"]
+ dataset.plot(x, show_titles=True)
+ plt.close()
+ if "label" in x:
+ x["prediction_label"] = x["label"]
+ dataset.plot(x, show_titles=False)
+ plt.close()
+
+ @pytest.mark.skipif(
+ sys.platform in ["darwin", "win32"],
+ reason="segmentation fault on macOS and windows",
+ )
+ def test_plot_las(self, dataset: IDTReeS) -> None:
+ pytest.importorskip("open3d")
+ vis = dataset.plot_las(index=0, colormap="BrBG")
+ vis.close()
+ vis = dataset.plot_las(index=0, colormap=None)
+ vis.close()
+ vis = dataset.plot_las(index=1, colormap=None)
+ vis.close()
diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py
index ccd5cefc2..4e29a516e 100644
--- a/torchgeo/datasets/__init__.py
+++ b/torchgeo/datasets/__init__.py
@@ -37,6 +37,7 @@ from .geo import (
VisionDataset,
)
from .gid15 import GID15
+from .idtrees import IDTReeS
from .landcoverai import LandCoverAI, LandCoverAIDataModule
from .landsat import (
Landsat,
@@ -115,6 +116,7 @@ __all__ = (
"EuroSAT",
"EuroSATDataModule",
"GID15",
+ "IDTReeS",
"LandCoverAI",
"LandCoverAIDataModule",
"LEVIRCDPlus",
diff --git a/torchgeo/datasets/idtrees.py b/torchgeo/datasets/idtrees.py
new file mode 100644
index 000000000..aff6ac82d
--- /dev/null
+++ b/torchgeo/datasets/idtrees.py
@@ -0,0 +1,553 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""IDTReeS dataset."""
+
+import glob
+import os
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import fiona
+import matplotlib.pyplot as plt
+import numpy as np
+import rasterio
+import torch
+from rasterio.enums import Resampling
+from torch import Tensor
+from torchvision.utils import draw_bounding_boxes
+
+from .geo import VisionDataset
+from .utils import download_url, extract_archive
+
+
+class IDTReeS(VisionDataset):
+ """IDTReeS dataset.
+
+ The `IDTReeS `_
+ dataset is a dataset for tree crown detection.
+
+ Dataset features:
+
+ * RGB Image, Canopy Height Model (CHM), Hyperspectral Image (HSI), LiDAR Point Cloud
+ * Remote sensing and field data generated by the
+ `National Ecological Observatory Network (NEON) `_
+ * 0.1 - 1m resolution imagery
+ * Task 1 - object detection (tree crown delination)
+ * Task 2 - object classification (species classification)
+ * Train set contains 85 images
+ * Test set (task 1) contains 153 images
+ * Test set (task 2) contains 353 images and tree crown polygons
+
+ Dataset format:
+
+ * optical - three-channel RGB 200x200 geotiff
+ * canopy height model - one-channel 20x20 geotiff
+ * hyperspectral - 369-channel 20x20 geotiff
+ * point cloud - Nx3 LAS file (.las), some files contain RGB colors per point
+ * shapely files (.shp) containing polygons
+ * csv file containing species labels and other metadata for each polygon
+
+ Dataset classes:
+
+ 0. ACPE
+ 1. ACRU
+ 2. ACSA3
+ 3. AMLA
+ 4. BETUL
+ 5. CAGL8
+ 6. CATO6
+ 7. FAGR
+ 8. GOLA
+ 9. LITU
+ 10. LYLU3
+ 11. MAGNO
+ 12. NYBI
+ 13. NYSY
+ 14. OXYDE
+ 15. PEPA37
+ 16. PIEL
+ 17. PIPA2
+ 18. PINUS
+ 19. PITA
+ 20. PRSE2
+ 21. QUAL
+ 22. QUCO2
+ 23. QUGE2
+ 24. QUHE2
+ 25. QULA2
+ 26. QULA3
+ 27. QUMO4
+ 28. QUNI
+ 29. QURU
+ 30. QUERC
+ 31. ROPS
+ 32. TSCA
+
+ If you use this dataset in your research, please cite the following paper:
+
+ * https://doi.org/10.7717/peerj.5843
+
+ .. versionadded:: 0.2
+ """
+
+ classes = {
+ "ACPE": "Acer pensylvanicum L.",
+ "ACRU": "Acer rubrum L.",
+ "ACSA3": "Acer saccharum Marshall",
+ "AMLA": "Amelanchier laevis Wiegand",
+ "BETUL": "Betula sp.",
+ "CAGL8": "Carya glabra (Mill.) Sweet",
+ "CATO6": "Carya tomentosa (Lam.) Nutt.",
+ "FAGR": "Fagus grandifolia Ehrh.",
+ "GOLA": "Gordonia lasianthus (L.) Ellis",
+ "LITU": "Liriodendron tulipifera L.",
+ "LYLU3": "Lyonia lucida (Lam.) K. Koch",
+ "MAGNO": "Magnolia sp.",
+ "NYBI": "Nyssa biflora Walter",
+ "NYSY": "Nyssa sylvatica Marshall",
+ "OXYDE": "Oxydendrum sp.",
+ "PEPA37": "Persea palustris (Raf.) Sarg.",
+ "PIEL": "Pinus elliottii Engelm.",
+ "PIPA2": "Pinus palustris Mill.",
+ "PINUS": "Pinus sp.",
+ "PITA": "Pinus taeda L.",
+ "PRSE2": "Prunus serotina Ehrh.",
+ "QUAL": "Quercus alba L.",
+ "QUCO2": "Quercus coccinea",
+ "QUGE2": "Quercus geminata Small",
+ "QUHE2": "Quercus hemisphaerica W. Bartram ex Willd.",
+ "QULA2": "Quercus laevis Walter",
+ "QULA3": "Quercus laurifolia Michx.",
+ "QUMO4": "Quercus montana Willd.",
+ "QUNI": "Quercus nigra L.",
+ "QURU": "Quercus rubra L.",
+ "QUERC": "Quercus sp.",
+ "ROPS": "Robinia pseudoacacia L.",
+ "TSCA": "Tsuga canadensis (L.) Carriere",
+ }
+ metadata = {
+ "train": {
+ "url": "https://zenodo.org/record/3934932/files/IDTREES_competition_train_v2.zip?download=1", # noqa: E501
+ "md5": "5ddfa76240b4bb6b4a7861d1d31c299c",
+ "filename": "IDTREES_competition_train_v2.zip",
+ },
+ "test": {
+ "url": "https://zenodo.org/record/3934932/files/IDTREES_competition_test_v2.zip?download=1", # noqa: E501
+ "md5": "b108931c84a70f2a38a8234290131c9b",
+ "filename": "IDTREES_competition_test_v2.zip",
+ },
+ }
+ directories = {"train": ["train"], "test": ["task1", "task2"]}
+ image_size = (200, 200)
+
+ def __init__(
+ self,
+ root: str = "data",
+ split: str = "train",
+ task: str = "task1",
+ transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
+ download: bool = False,
+ checksum: bool = False,
+ ) -> None:
+ """Initialize a new IDTReeS dataset instance.
+
+ Args:
+ root: root directory where dataset can be found
+ split: one of "train" or "test"
+ task: 'task1' for detection, 'task2' for detection + classification
+ (only relevant for split='test')
+ transforms: a function/transform that takes input sample and its target as
+ entry and returns a transformed version
+ download: if True, download dataset and store it in the root directory
+ checksum: if True, check the MD5 of the downloaded files (may be slow)
+
+ Raises:
+ ImportError: if laspy or pandas are are not installed
+ """
+ assert split in ["train", "test"]
+ assert task in ["task1", "task2"]
+ self.root = root
+ self.split = split
+ self.task = task
+ self.transforms = transforms
+ self.download = download
+ self.checksum = checksum
+ self.class2idx = {c: i for i, c in enumerate(self.classes)}
+ self.idx2class = {i: c for i, c in enumerate(self.classes)}
+ self.num_classes = len(self.classes)
+ self._verify()
+
+ try:
+ import pandas as pd # noqa: F401
+ except ImportError:
+ raise ImportError(
+ "pandas is not installed and is required to use this dataset"
+ )
+ try:
+ import laspy # noqa: F401
+ except ImportError:
+ raise ImportError(
+ "laspy is not installed and is required to use this dataset"
+ )
+
+ self.images, self.geometries, self.labels = self._load(root)
+
+ def __getitem__(self, index: int) -> Dict[str, Tensor]:
+ """Return an index within the dataset.
+
+ Args:
+ index: index to return
+
+ Returns:
+ data and label at that index
+ """
+ path = self.images[index]
+ image = self._load_image(path).to(torch.uint8) # type:ignore[attr-defined]
+ hsi = self._load_image(path.replace("RGB", "HSI"))
+ chm = self._load_image(path.replace("RGB", "CHM"))
+ las = self._load_las(path.replace("RGB", "LAS").replace(".tif", ".las"))
+ sample = {"image": image, "hsi": hsi, "chm": chm, "las": las}
+
+ if self.split == "test":
+ if self.task == "task2":
+ sample["boxes"] = self._load_boxes(path)
+ else:
+ sample["boxes"] = self._load_boxes(path)
+ sample["label"] = self._load_target(path)
+
+ if self.transforms is not None:
+ sample = self.transforms(sample)
+
+ return sample
+
+ def __len__(self) -> int:
+ """Return the number of data points in the dataset.
+
+ Returns:
+ length of the dataset
+ """
+ return len(self.images)
+
+ def _load_image(self, path: str) -> Tensor:
+ """Load a tiff file.
+
+ Args:
+ path: path to .tif file
+
+ Returns:
+ the image
+ """
+ with rasterio.open(path) as f:
+ array = f.read(out_shape=self.image_size, resampling=Resampling.bilinear)
+ tensor: Tensor = torch.from_numpy(array) # type: ignore[attr-defined]
+ return tensor
+
+ def _load_las(self, path: str) -> Tensor:
+ """Load a single point cloud.
+
+ Args:
+ path: path to .las file
+
+ Returns:
+ the point cloud
+ """
+ import laspy
+
+ las = laspy.read(path)
+ array = np.stack([las.x, las.y, las.z], axis=0)
+ tensor: Tensor = torch.from_numpy(array) # type: ignore[attr-defined]
+ return tensor
+
+ def _load_boxes(self, path: str) -> Tensor:
+ """Load object bounding boxes.
+
+ Args:
+ path: path to .tif file
+
+ Returns:
+ the bounding boxes
+ """
+ base_path = os.path.basename(path)
+
+ # Find object ids and geometries
+ if self.split == "train":
+ indices = self.labels["rsFile"] == base_path
+ ids = self.labels[indices]["id"].tolist()
+ geoms = [self.geometries[i]["geometry"]["coordinates"][0][:4] for i in ids]
+ # Test set - Task 2 has no mapping csv. Mapping is inside of geometry
+ else:
+ ids = [
+ k
+ for k, v in self.geometries.items()
+ if v["properties"]["plotID"] == base_path
+ ]
+ geoms = [self.geometries[i]["geometry"]["coordinates"][0][:4] for i in ids]
+
+ # Convert to pixel coords
+ boxes = []
+ with rasterio.open(path) as f:
+ for geom in geoms:
+ coords = [f.index(x, y) for x, y in geom]
+ xmin = min([coord[0] for coord in coords])
+ xmax = max([coord[0] for coord in coords])
+ ymin = min([coord[1] for coord in coords])
+ ymax = max([coord[1] for coord in coords])
+ boxes.append([xmin, ymin, xmax, ymax])
+
+ tensor: Tensor = torch.tensor(boxes) # type: ignore[attr-defined]
+ return tensor
+
+ def _load_target(self, path: str) -> Tensor:
+ """Load target label for a single sample.
+
+ Args:
+ path: path to image
+
+ Returns:
+ the label
+ """
+ # Find indices for objects in the image
+ base_path = os.path.basename(path)
+ indices = self.labels["rsFile"] == base_path
+
+ # Load object labels
+ classes = self.labels[indices]["taxonID"].tolist()
+ labels = [self.class2idx[c] for c in classes]
+ tensor: Tensor = torch.tensor(labels) # type: ignore[attr-defined]
+ return tensor
+
+ def _load(self, root: str) -> Tuple[List[str], Dict[int, Dict[str, Any]], Any]:
+ """Load files, geometries, and labels.
+
+ Args:
+ root: root directory
+
+ Returns:
+ the image path, geometries, and labels
+ """
+ import pandas as pd
+
+ if self.split == "train":
+ directory = os.path.join(root, self.directories[self.split][0])
+ labels: pd.DataFrame = self._load_labels(directory)
+ geoms = self._load_geometries(directory)
+ else:
+ directory = os.path.join(root, self.task)
+ if self.task == "task1":
+ geoms = None # type: ignore[assignment]
+ labels = None
+ else:
+ geoms = self._load_geometries(directory)
+ labels = None
+
+ images = glob.glob(os.path.join(directory, "RemoteSensing", "RGB", "*.tif"))
+
+ return images, geoms, labels
+
+ def _load_labels(self, directory: str) -> Any:
+ """Load the csv files containing the labels.
+
+ Args:
+ directory: directory containing csv files
+
+ Returns:
+ a pandas DataFrame containing the labels for each image
+ """
+ import pandas as pd
+
+ path_mapping = os.path.join(directory, "Field", "itc_rsFile.csv")
+ path_labels = os.path.join(directory, "Field", "train_data.csv")
+ df_mapping = pd.read_csv(path_mapping)
+ df_labels = pd.read_csv(path_labels)
+ df_mapping = df_mapping.set_index("indvdID", drop=True)
+ df_labels = df_labels.set_index("indvdID", drop=True)
+ df = df_labels.join(df_mapping, on="indvdID")
+ df = df.drop_duplicates()
+ df.reset_index()
+ return df
+
+ def _load_geometries(self, directory: str) -> Dict[int, Dict[str, Any]]:
+ """Load the shape files containing the geometries.
+
+ Args:
+ directory: directory containing .shp files
+
+ Returns:
+ a dict containing the geometries for each object
+ """
+ filepaths = glob.glob(os.path.join(directory, "ITC", "*.shp"))
+
+ features: Dict[int, Dict[str, Any]] = {}
+ for path in filepaths:
+ with fiona.open(path) as src:
+ for i, feature in enumerate(src):
+ if self.split == "train":
+ features[feature["properties"]["id"]] = feature
+ # Test set task 2 has no id
+ else:
+ features[i] = feature
+ return features
+
+ def _verify(self) -> None:
+ """Verify the integrity of the dataset.
+
+ Raises:
+ RuntimeError: if ``download=False`` but dataset is missing or checksum fails
+ """
+ url = self.metadata[self.split]["url"]
+ md5 = self.metadata[self.split]["md5"]
+ filename = self.metadata[self.split]["filename"]
+ directories = self.directories[self.split]
+
+ # Check if the files already exist
+ exists = [
+ os.path.exists(os.path.join(self.root, directory))
+ for directory in directories
+ ]
+ if all(exists):
+ return
+
+ # Check if zip file already exists (if so then extract)
+ filepath = os.path.join(self.root, filename)
+ if os.path.exists(filepath):
+ extract_archive(filepath)
+ return
+
+ # Check if the user requested to download the dataset
+ if not self.download:
+ raise RuntimeError(
+ "Dataset not found in `root` directory and `download=False`, "
+ "either specify a different `root` directory or use `download=True` "
+ "to automaticaly download the dataset."
+ )
+
+ # Download and extract the dataset
+ download_url(
+ url, self.root, filename=filename, md5=md5 if self.checksum else None
+ )
+ filepath = os.path.join(self.root, filename)
+ extract_archive(filepath)
+
+ def plot(
+ self,
+ sample: Dict[str, Tensor],
+ show_titles: bool = True,
+ suptitle: Optional[str] = None,
+ hsi_indices: Tuple[int, int, int] = (0, 1, 2),
+ ) -> plt.Figure:
+ """Plot a sample from the dataset.
+
+ Args:
+ sample: a sample returned by :meth:`__getitem__`
+ show_titles: flag indicating whether to show titles above each panel
+ suptitle: optional string to use as a suptitle
+ hsi_indices: tuple of indices to create HSI false color image
+
+ Returns:
+ a matplotlib Figure with the rendered sample
+ """
+ assert len(hsi_indices) == 3
+
+ def normalize(x: Tensor) -> Tensor:
+ return (x - x.min()) / (x.max() - x.min())
+
+ ncols = 3
+
+ hsi = normalize(sample["hsi"][hsi_indices, :, :]).permute((1, 2, 0)).numpy()
+ chm = normalize(sample["chm"]).permute((1, 2, 0)).numpy()
+
+ if "boxes" in sample:
+ labels = (
+ [self.idx2class[int(i)] for i in sample["label"]]
+ if "label" in sample
+ else None
+ )
+ image = draw_bounding_boxes(
+ image=sample["image"], boxes=sample["boxes"], labels=labels
+ )
+ image = image.permute((1, 2, 0)).numpy()
+ else:
+ image = sample["image"].permute((1, 2, 0)).numpy()
+
+ if "prediction_boxes" in sample:
+ ncols += 1
+ labels = (
+ [self.idx2class[int(i)] for i in sample["prediction_label"]]
+ if "prediction_label" in sample
+ else None
+ )
+ preds = draw_bounding_boxes(
+ image=sample["image"], boxes=sample["prediction_boxes"], labels=labels
+ )
+ preds = preds.permute((1, 2, 0)).numpy()
+
+ fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10))
+ axs[0].imshow(image)
+ axs[0].axis("off")
+ axs[1].imshow(hsi)
+ axs[1].axis("off")
+ axs[2].imshow(chm)
+ axs[2].axis("off")
+ if ncols > 3:
+ axs[3].imshow(preds)
+ axs[3].axis("off")
+
+ if show_titles:
+ axs[0].set_title("Ground Truth")
+ axs[1].set_title("Hyperspectral False Color Image")
+ axs[2].set_title("Canopy Height Model")
+ if ncols > 3:
+ axs[3].set_title("Predictions")
+
+ if suptitle is not None:
+ plt.suptitle(suptitle)
+
+ return fig
+
+ def plot_las(self, index: int, colormap: Optional[str] = None) -> Any:
+ """Plot a sample point cloud at the index.
+
+ Args:
+ index: index to plot
+ colormap: a valid matplotlib colormap
+
+ Returns:
+ a open3d.visualizer.Visualizer object. Use
+ Visualizer.run() to display
+
+ Raises:
+ ImportError: if open3d is not installed
+ """
+ try:
+ import open3d # noqa: F401
+ except ImportError:
+ raise ImportError(
+ "open3d is not installed and is required to use this dataset"
+ )
+ import laspy
+
+ path = self.images[index]
+ path = path.replace("RGB", "LAS").replace(".tif", ".las")
+ las = laspy.read(path)
+ points = np.stack([las.x, las.y, las.z], axis=0).transpose((1, 0))
+
+ if colormap:
+ cm = plt.cm.get_cmap(colormap)
+ norm = plt.Normalize()
+ colors = cm(norm(points[:, 2]))[:, :3]
+ else:
+ # Some point cloud files have no color->points mapping
+ if hasattr(las, "red"):
+ colors = np.stack([las.red, las.green, las.blue], axis=0)
+ colors = colors.transpose((1, 0)) / 65535
+ # Default to no colormap if no colors exist in las file
+ else:
+ colors = np.zeros_like(points)
+
+ pcd = open3d.geometry.PointCloud()
+ pcd.points = open3d.utility.Vector3dVector(points)
+ pcd.colors = open3d.utility.Vector3dVector(colors)
+ vis = open3d.visualization.Visualizer()
+ vis.create_window()
+ vis.add_geometry(pcd)
+ return vis