CDL/NLCD/SSL4EO: allow selection of classes (#1392)

* CDL/NLCD/SSL4EO: allow selection of classes

* 0 is already 0

* Get NLCD tests to pass

* Search recursively for NLCD files

* Update CDL

* Update SSL4EO-L Benchmark

* Passing tests

* Test SSL4EO-L Benchmark

* Test CDL

* Test NLCD

* Mypy fix

* Remove debugging code
This commit is contained in:
Adam J. Stewart 2023-06-04 10:21:05 -05:00 коммит произвёл GitHub
Родитель 3d436d3057
Коммит 9e57f27818
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
21 изменённых файлов: 316 добавлений и 420 удалений

Просмотреть файл

@ -14,7 +14,7 @@ module:
datamodule:
_target_: torchgeo.datamodules.SSL4EOLBenchmarkDataModule
root: "tests/data/ssl4eo_benchmark_landsat"
input_sensor: "tm_toa"
mask_product: "cdl"
sensor: "tm_toa"
product: "cdl"
batch_size: 2
num_workers: 0

Просмотреть файл

@ -14,7 +14,7 @@ module:
datamodule:
_target_: torchgeo.datamodules.SSL4EOLBenchmarkDataModule
root: "tests/data/ssl4eo_benchmark_landsat"
input_sensor: "etm_sr"
mask_product: "nlcd"
sensor: "etm_sr"
product: "nlcd"
batch_size: 2
num_workers: 0

Двоичные данные
tests/data/cdl/2020_30m_cdls/2020_30m_cdls.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/cdl/2020_30m_cdls/2020_30m_cdls.tif.ovr Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/cdl/2021_30m_cdls/2021_30m_cdls.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/cdl/2021_30m_cdls/2021_30m_cdls.tif.ovr Normal file

Двоичный файл не отображается.

0
tests/data/fire_risk/data.py Normal file → Executable file
Просмотреть файл

0
tests/data/nlcd/data.py Normal file → Executable file
Просмотреть файл

Просмотреть файл

0
tests/data/skippd/data.py Normal file → Executable file
Просмотреть файл

0
tests/data/spacenet/data.py Normal file → Executable file
Просмотреть файл

0
tests/data/ssl4eo_benchmark_landsat/data.py Normal file → Executable file
Просмотреть файл

0
tests/data/sustainbench_crop_yield/data.py Normal file → Executable file
Просмотреть файл

0
tests/data/vhr10/data.py Normal file → Executable file
Просмотреть файл

0
tests/data/western_usa_live_fuel_moisture/data.py Normal file → Executable file
Просмотреть файл

Просмотреть файл

@ -51,6 +51,14 @@ class TestCDL:
assert isinstance(x["crs"], CRS)
assert isinstance(x["mask"], torch.Tensor)
def test_classes(self) -> None:
root = os.path.join("tests", "data", "cdl")
classes = list(CDL.cmap.keys())[:5]
ds = CDL(root, years=[2021], classes=classes)
sample = ds[ds.bounds]
mask = sample["mask"]
assert mask.max() < len(classes)
def test_and(self, dataset: CDL) -> None:
ds = dataset & dataset
assert isinstance(ds, IntersectionDataset)
@ -82,6 +90,13 @@ class TestCDL:
):
CDL(str(tmp_path), years=[1996])
def test_invalid_classes(self) -> None:
with pytest.raises(AssertionError):
CDL(classes=[-1])
with pytest.raises(AssertionError):
CDL(classes=[11])
def test_plot(self, dataset: CDL) -> None:
query = dataset.bounds
x = dataset[query]

Просмотреть файл

@ -52,6 +52,14 @@ class TestNLCD:
assert isinstance(x["crs"], CRS)
assert isinstance(x["mask"], torch.Tensor)
def test_classes(self) -> None:
root = os.path.join("tests", "data", "nlcd")
classes = list(NLCD.cmap.keys())[:5]
ds = NLCD(root, years=[2019], classes=classes)
sample = ds[ds.bounds]
mask = sample["mask"]
assert mask.max() < len(classes)
def test_and(self, dataset: NLCD) -> None:
ds = dataset & dataset
assert isinstance(ds, IntersectionDataset)
@ -78,6 +86,13 @@ class TestNLCD:
):
NLCD(str(tmp_path), years=[1996])
def test_invalid_classes(self) -> None:
with pytest.raises(AssertionError):
NLCD(classes=[-1])
with pytest.raises(AssertionError):
NLCD(classes=[11])
def test_plot(self, dataset: NLCD) -> None:
query = dataset.bounds
x = dataset[query]

Просмотреть файл

@ -16,7 +16,7 @@ from pytest import MonkeyPatch
from torch.utils.data import ConcatDataset
import torchgeo.datasets.utils
from torchgeo.datasets import SSL4EOLBenchmark
from torchgeo.datasets import CDL, NLCD, RasterDataset, SSL4EOLBenchmark
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@ -42,7 +42,7 @@ class TestSSL4EOLBenchmark:
url = os.path.join("tests", "data", "ssl4eo_benchmark_landsat", "{}.tar.gz")
monkeypatch.setattr(SSL4EOLBenchmark, "url", url)
input_sensor, mask_product, split = request.param
sensor, product, split = request.param
monkeypatch.setattr(
SSL4EOLBenchmark, "split_percentages", [1 / 3, 1 / 3, 1 / 3]
)
@ -75,8 +75,8 @@ class TestSSL4EOLBenchmark:
transforms = nn.Identity()
return SSL4EOLBenchmark(
root=root,
input_sensor=input_sensor,
mask_product=mask_product,
sensor=sensor,
product=product,
split=split,
transforms=transforms,
download=True,
@ -89,17 +89,33 @@ class TestSSL4EOLBenchmark:
assert isinstance(x["image"], torch.Tensor)
assert isinstance(x["mask"], torch.Tensor)
@pytest.mark.parametrize("product,base_class", [("nlcd", NLCD), ("cdl", CDL)])
def test_classes(self, product: str, base_class: RasterDataset) -> None:
root = os.path.join("tests", "data", "ssl4eo_benchmark_landsat")
classes = list(base_class.cmap.keys())[:5]
ds = SSL4EOLBenchmark(root, product=product, classes=classes)
sample = ds[0]
mask = sample["mask"]
assert mask.max() < len(classes)
def test_invalid_split(self) -> None:
with pytest.raises(AssertionError):
SSL4EOLBenchmark(split="foo")
def test_invalid_input_sensor(self) -> None:
def test_invalid_sensor(self) -> None:
with pytest.raises(AssertionError):
SSL4EOLBenchmark(input_sensor="foo")
SSL4EOLBenchmark(sensor="foo")
def test_invalid_mask_product(self) -> None:
def test_invalid_product(self) -> None:
with pytest.raises(AssertionError):
SSL4EOLBenchmark(mask_product="foo")
SSL4EOLBenchmark(product="foo")
def test_invalid_classes(self) -> None:
with pytest.raises(AssertionError):
SSL4EOLBenchmark(classes=[-1])
with pytest.raises(AssertionError):
SSL4EOLBenchmark(classes=[11])
def test_add(self, dataset: SSL4EOLBenchmark) -> None:
ds = dataset + dataset
@ -108,8 +124,8 @@ class TestSSL4EOLBenchmark:
def test_already_extracted(self, dataset: SSL4EOLBenchmark) -> None:
SSL4EOLBenchmark(
root=dataset.root,
input_sensor=dataset.input_sensor,
mask_product=dataset.mask_product,
sensor=dataset.sensor,
product=dataset.product,
download=True,
)

Просмотреть файл

@ -8,8 +8,7 @@ import os
from typing import Any, Callable, Optional
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap
import torch
from rasterio.crs import CRS
from .geo import RasterDataset
@ -67,277 +66,140 @@ class CDL(RasterDataset):
}
cmap = {
0: (0, 0, 0, 0),
0: (0, 0, 0, 255),
1: (255, 211, 0, 255),
2: (255, 38, 38, 255),
3: (0, 168, 228, 255),
4: (255, 158, 11, 255),
5: (38, 112, 0, 255),
2: (255, 37, 37, 255),
3: (0, 168, 226, 255),
4: (255, 158, 9, 255),
5: (37, 111, 0, 255),
6: (255, 255, 0, 255),
7: (112, 165, 0, 255),
8: (0, 175, 75, 255),
9: (221, 165, 11, 255),
10: (221, 165, 11, 255),
11: (126, 211, 255, 255),
12: (226, 0, 124, 255),
13: (137, 98, 84, 255),
14: (216, 181, 107, 255),
15: (165, 112, 0, 255),
16: (214, 158, 188, 255),
17: (112, 112, 0, 255),
18: (172, 0, 124, 255),
19: (160, 89, 137, 255),
20: (112, 0, 73, 255),
21: (214, 158, 188, 255),
22: (209, 255, 0, 255),
23: (126, 153, 255, 255),
24: (214, 214, 0, 255),
25: (209, 255, 0, 255),
26: (0, 175, 75, 255),
27: (255, 165, 226, 255),
28: (165, 242, 140, 255),
29: (0, 175, 75, 255),
30: (214, 158, 188, 255),
31: (168, 0, 228, 255),
32: (165, 0, 0, 255),
33: (112, 38, 0, 255),
34: (0, 175, 75, 255),
35: (177, 126, 255, 255),
36: (112, 38, 0, 255),
37: (255, 102, 102, 255),
38: (255, 102, 102, 255),
39: (255, 204, 102, 255),
40: (255, 102, 102, 255),
41: (0, 175, 75, 255),
42: (0, 221, 175, 255),
43: (84, 255, 0, 255),
44: (242, 163, 119, 255),
45: (255, 102, 102, 255),
46: (0, 175, 75, 255),
47: (126, 211, 255, 255),
48: (232, 191, 255, 255),
49: (175, 255, 221, 255),
50: (0, 175, 75, 255),
51: (191, 191, 119, 255),
52: (147, 204, 147, 255),
53: (198, 214, 158, 255),
54: (204, 191, 163, 255),
55: (255, 0, 255, 255),
56: (255, 142, 170, 255),
57: (186, 0, 79, 255),
58: (112, 68, 137, 255),
59: (0, 119, 119, 255),
60: (177, 154, 112, 255),
61: (255, 255, 126, 255),
62: (181, 112, 91, 255),
63: (0, 165, 130, 255),
64: (233, 214, 175, 255),
65: (177, 154, 112, 255),
66: (242, 242, 242, 255),
67: (154, 154, 154, 255),
68: (75, 112, 163, 255),
69: (126, 177, 177, 255),
70: (232, 255, 191, 255),
71: (0, 255, 255, 255),
72: (75, 112, 163, 255),
73: (211, 226, 249, 255),
74: (154, 154, 154, 255),
75: (154, 154, 154, 255),
76: (154, 154, 154, 255),
77: (154, 154, 154, 255),
78: (204, 191, 163, 255),
79: (147, 204, 147, 255),
80: (147, 204, 147, 255),
81: (147, 204, 147, 255),
82: (198, 214, 158, 255),
83: (232, 255, 191, 255),
84: (126, 177, 177, 255),
85: (126, 177, 177, 255),
86: (0, 255, 140, 255),
87: (214, 158, 188, 255),
88: (255, 102, 102, 255),
89: (255, 102, 102, 255),
90: (255, 102, 102, 255),
91: (255, 102, 102, 255),
92: (255, 142, 170, 255),
93: (51, 73, 51, 255),
94: (228, 112, 38, 255),
95: (255, 102, 102, 255),
96: (255, 102, 102, 255),
97: (102, 153, 76, 255),
98: (255, 102, 102, 255),
99: (177, 154, 112, 255),
100: (255, 142, 170, 255),
101: (255, 102, 102, 255),
102: (255, 142, 170, 255),
103: (255, 102, 102, 255),
104: (255, 102, 102, 255),
105: (255, 142, 170, 255),
106: (0, 175, 75, 255),
107: (255, 211, 0, 255),
108: (255, 211, 0, 255),
109: (255, 102, 102, 255),
110: (255, 210, 0, 255),
111: (255, 102, 102, 255),
112: (137, 98, 84, 255),
113: (255, 102, 102, 255),
114: (255, 38, 38, 255),
115: (226, 0, 124, 255),
116: (255, 158, 11, 255),
117: (255, 158, 11, 255),
118: (165, 112, 0, 255),
119: (255, 211, 0, 255),
120: (165, 112, 0, 255),
121: (38, 112, 0, 255),
122: (38, 112, 0, 255),
123: (255, 211, 0, 255),
124: (0, 0, 153, 255),
125: (255, 102, 102, 255),
126: (255, 102, 102, 255),
127: (255, 102, 102, 255),
128: (255, 102, 102, 255),
129: (255, 102, 102, 255),
130: (255, 102, 102, 255),
131: (255, 102, 102, 255),
132: (255, 102, 102, 255),
133: (38, 112, 0, 255),
}
ordinal_label_map = {
0: 0,
1: 1,
2: 2,
3: 3,
4: 4,
5: 5,
6: 6,
10: 7,
11: 8,
12: 9,
13: 10,
14: 11,
21: 12,
22: 13,
23: 14,
24: 15,
25: 16,
26: 17,
27: 18,
28: 19,
29: 20,
30: 21,
31: 22,
32: 23,
33: 24,
34: 25,
35: 26,
36: 27,
37: 28,
38: 29,
39: 30,
41: 31,
42: 32,
43: 33,
44: 34,
45: 35,
46: 36,
47: 37,
48: 38,
49: 39,
50: 40,
51: 41,
52: 42,
53: 43,
54: 44,
55: 45,
56: 46,
57: 47,
58: 48,
59: 49,
60: 50,
61: 51,
63: 52,
64: 53,
65: 54,
66: 55,
67: 56,
68: 57,
69: 58,
70: 59,
71: 60,
72: 61,
74: 62,
75: 63,
76: 64,
77: 65,
81: 66,
82: 67,
83: 68,
87: 69,
88: 70,
92: 71,
111: 72,
112: 73,
121: 74,
122: 75,
123: 76,
124: 77,
131: 78,
141: 79,
142: 80,
143: 81,
152: 82,
176: 83,
190: 84,
195: 85,
204: 86,
205: 87,
206: 88,
207: 89,
208: 90,
209: 91,
210: 92,
211: 93,
212: 94,
213: 95,
214: 96,
215: 97,
216: 98,
217: 99,
218: 100,
219: 101,
220: 102,
221: 103,
222: 104,
223: 105,
224: 106,
225: 107,
226: 108,
227: 109,
228: 110,
229: 111,
230: 112,
231: 113,
232: 114,
233: 115,
234: 116,
235: 117,
236: 118,
237: 119,
238: 120,
239: 121,
240: 122,
241: 123,
242: 124,
243: 125,
244: 126,
245: 127,
246: 128,
247: 129,
248: 130,
249: 131,
250: 132,
254: 133,
10: (111, 166, 0, 255),
11: (0, 175, 73, 255),
12: (222, 166, 9, 255),
13: (222, 166, 9, 255),
14: (124, 211, 255, 255),
21: (226, 0, 124, 255),
22: (137, 96, 83, 255),
23: (217, 181, 107, 255),
24: (166, 111, 0, 255),
25: (213, 158, 188, 255),
26: (111, 111, 0, 255),
27: (171, 0, 124, 255),
28: (160, 88, 137, 255),
29: (111, 0, 73, 255),
30: (213, 158, 188, 255),
31: (209, 255, 0, 255),
32: (124, 153, 255, 255),
33: (213, 213, 0, 255),
34: (209, 255, 0, 255),
35: (0, 175, 73, 255),
36: (255, 166, 226, 255),
37: (166, 241, 139, 255),
38: (0, 175, 73, 255),
39: (213, 158, 188, 255),
41: (168, 0, 226, 255),
42: (166, 0, 0, 255),
43: (111, 37, 0, 255),
44: (0, 175, 73, 255),
45: (175, 124, 255, 255),
46: (111, 37, 0, 255),
47: (255, 102, 102, 255),
48: (255, 102, 102, 255),
49: (255, 204, 102, 255),
50: (255, 102, 102, 255),
51: (0, 175, 73, 255),
52: (0, 222, 175, 255),
53: (83, 255, 0, 255),
54: (241, 162, 120, 255),
55: (255, 102, 102, 255),
56: (0, 175, 73, 255),
57: (124, 211, 255, 255),
58: (232, 190, 255, 255),
59: (175, 255, 222, 255),
60: (0, 175, 73, 255),
61: (190, 190, 120, 255),
63: (147, 204, 147, 255),
64: (198, 213, 158, 255),
65: (204, 190, 162, 255),
66: (255, 0, 255, 255),
67: (255, 143, 171, 255),
68: (185, 0, 79, 255),
69: (111, 69, 137, 255),
70: (0, 120, 120, 255),
71: (175, 153, 111, 255),
72: (255, 255, 124, 255),
74: (181, 111, 92, 255),
75: (0, 166, 130, 255),
76: (232, 213, 175, 255),
77: (175, 153, 111, 255),
81: (241, 241, 241, 255),
82: (153, 153, 153, 255),
83: (73, 111, 162, 255),
87: (124, 175, 175, 255),
88: (232, 255, 190, 255),
92: (0, 255, 255, 255),
111: (73, 111, 162, 255),
112: (211, 226, 249, 255),
121: (153, 153, 153, 255),
122: (153, 153, 153, 255),
123: (153, 153, 153, 255),
124: (153, 153, 153, 255),
131: (204, 190, 162, 255),
141: (147, 204, 147, 255),
142: (147, 204, 147, 255),
143: (147, 204, 147, 255),
152: (198, 213, 158, 255),
176: (232, 255, 190, 255),
190: (124, 175, 175, 255),
195: (124, 175, 175, 255),
204: (0, 255, 139, 255),
205: (213, 158, 188, 255),
206: (255, 102, 102, 255),
207: (255, 102, 102, 255),
208: (255, 102, 102, 255),
209: (255, 102, 102, 255),
210: (255, 143, 171, 255),
211: (51, 73, 51, 255),
212: (226, 111, 37, 255),
213: (255, 102, 102, 255),
214: (255, 102, 102, 255),
215: (102, 153, 77, 255),
216: (255, 102, 102, 255),
217: (175, 153, 111, 255),
218: (255, 143, 171, 255),
219: (255, 102, 102, 255),
220: (255, 143, 171, 255),
221: (255, 102, 102, 255),
222: (255, 102, 102, 255),
223: (255, 143, 171, 255),
224: (0, 175, 73, 255),
225: (255, 211, 0, 255),
226: (255, 211, 0, 255),
227: (255, 102, 102, 255),
228: (255, 211, 0, 255),
229: (255, 102, 102, 255),
230: (137, 96, 83, 255),
231: (255, 102, 102, 255),
232: (255, 37, 37, 255),
233: (226, 0, 124, 255),
234: (255, 158, 9, 255),
235: (255, 158, 9, 255),
236: (166, 111, 0, 255),
237: (255, 211, 0, 255),
238: (166, 111, 0, 255),
239: (37, 111, 0, 255),
240: (37, 111, 0, 255),
241: (255, 211, 0, 255),
242: (0, 0, 153, 255),
243: (255, 102, 102, 255),
244: (255, 102, 102, 255),
245: (255, 102, 102, 255),
246: (255, 102, 102, 255),
247: (255, 102, 102, 255),
248: (255, 102, 102, 255),
249: (255, 102, 102, 255),
250: (255, 102, 102, 255),
254: (37, 111, 0, 255),
}
def __init__(
@ -346,6 +208,7 @@ class CDL(RasterDataset):
crs: Optional[CRS] = None,
res: Optional[float] = None,
years: list[int] = [2022],
classes: list[int] = list(cmap.keys()),
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
cache: bool = True,
download: bool = False,
@ -360,6 +223,8 @@ class CDL(RasterDataset):
res: resolution of the dataset in units of CRS
(defaults to the resolution of the first file found)
years: list of years for which to use cdl layer
classes: list of classes to include, the rest will be mapped to 0
(defaults to all classes)
transforms: a function/transform that takes an input sample
and returns a transformed version
cache: if True, cache file handle to speed up repeated sampling
@ -367,25 +232,39 @@ class CDL(RasterDataset):
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
AssertionError: if ``years`` or ``classes`` are invalid
FileNotFoundError: if no files are found in ``root``
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
.. versionadded:: 0.5
The *years* parameter.
The *years* and *classes* parameters.
"""
assert set(years).issubset(self.md5s.keys()), (
assert set(years) <= self.md5s.keys(), (
"CDL data product only exists for the following years: "
f"{list(self.md5s.keys())}."
)
self.years = years
assert (
set(classes) <= self.cmap.keys()
), f"Only the following classes are valid: {list(self.cmap.keys())}."
assert 0 in classes, "Classes must include the background class: 0"
self.root = root
self.years = years
self.classes = classes
self.download = download
self.checksum = checksum
self.ordinal_map = torch.zeros(max(self.cmap.keys()) + 1, dtype=self.dtype)
self.ordinal_cmap = torch.zeros((len(self.classes), 4), dtype=torch.uint8)
self._verify()
super().__init__(root, crs, res, transforms=transforms, cache=cache)
# Map chosen classes to ordinal numbers, all others mapped to background class
for v, k in enumerate(self.classes):
self.ordinal_map[k] = v
self.ordinal_cmap[v] = torch.tensor(self.cmap[k])
def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
"""Retrieve mask and metadata indexed by query.
@ -399,13 +278,7 @@ class CDL(RasterDataset):
IndexError: if query is not found in the index
"""
sample = super().__getitem__(query)
mask = sample["mask"]
for k, v in self.ordinal_label_map.items():
mask[mask == k] = v
sample["mask"] = mask
sample["mask"] = self.ordinal_map[sample["mask"]]
return sample
def _verify(self) -> None:
@ -489,33 +362,26 @@ class CDL(RasterDataset):
Method now takes a sample dict, not a Tensor. Additionally, possible to
show subplot titles and/or use a custom suptitle.
"""
mask = sample["mask"].squeeze().numpy()
mask = sample["mask"].squeeze()
ncols = 1
showing_predictions = "prediction" in sample
if showing_predictions:
pred = sample["prediction"].squeeze().numpy()
pred = sample["prediction"].squeeze()
ncols = 2
kwargs = {
"cmap": ListedColormap(np.array(list(self.cmap.values())) / 255),
"vmin": 0,
"vmax": len(self.cmap) - 1,
"interpolation": "none",
}
fig, axs = plt.subplots(
nrows=1, ncols=ncols, figsize=(ncols * 4, 4), squeeze=False
)
axs[0, 0].imshow(mask, **kwargs)
axs[0, 0].imshow(self.ordinal_cmap[mask], interpolation="none")
axs[0, 0].axis("off")
if show_titles:
axs[0, 0].set_title("Mask")
if showing_predictions:
axs[0, 1].imshow(pred, **kwargs)
axs[0, 1].imshow(self.ordinal_cmap[pred], interpolation="none")
axs[0, 1].axis("off")
if show_titles:
axs[0, 1].set_title("Prediction")

Просмотреть файл

@ -3,12 +3,12 @@
"""NLCD dataset."""
import glob
import os
from typing import Any, Callable, Optional
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap
import torch
from rasterio.crs import CRS
from .geo import RasterDataset
@ -83,44 +83,24 @@ class NLCD(RasterDataset):
2019: "82851c3f8105763b01c83b4a9e6f3961",
}
ordinal_label_map = {
0: 0,
11: 1,
12: 2,
21: 3,
22: 4,
23: 5,
24: 6,
31: 7,
41: 8,
42: 9,
43: 10,
52: 11,
71: 12,
81: 13,
82: 14,
90: 15,
95: 16,
}
cmap = {
0: (0, 0, 0, 255),
1: (70, 107, 159, 255),
2: (209, 222, 248, 255),
3: (222, 197, 197, 255),
4: (217, 146, 130, 255),
5: (235, 0, 0, 255),
6: (171, 0, 0, 255),
7: (179, 172, 159, 255),
8: (104, 171, 95, 255),
9: (28, 95, 44, 255),
10: (181, 197, 143, 255),
11: (204, 184, 121, 255),
12: (223, 223, 194, 255),
13: (220, 217, 57, 255),
14: (171, 108, 40, 255),
15: (184, 217, 235, 255),
16: (108, 159, 184, 255),
0: (0, 0, 0, 0),
11: (70, 107, 159, 255),
12: (209, 222, 248, 255),
21: (222, 197, 197, 255),
22: (217, 146, 130, 255),
23: (235, 0, 0, 255),
24: (171, 0, 0, 255),
31: (179, 172, 159, 255),
41: (104, 171, 95, 255),
42: (28, 95, 44, 255),
43: (181, 197, 143, 255),
52: (204, 184, 121, 255),
71: (223, 223, 194, 255),
81: (220, 217, 57, 255),
82: (171, 108, 40, 255),
90: (184, 217, 235, 255),
95: (108, 159, 184, 255),
}
def __init__(
@ -129,6 +109,7 @@ class NLCD(RasterDataset):
crs: Optional[CRS] = None,
res: Optional[float] = None,
years: list[int] = [2019],
classes: list[int] = list(cmap.keys()),
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
cache: bool = True,
download: bool = False,
@ -143,6 +124,8 @@ class NLCD(RasterDataset):
res: resolution of the dataset in units of CRS
(defaults to the resolution of the first file found)
years: list of years for which to use nlcd layer
classes: list of classes to include, the rest will be mapped to 0
(defaults to all classes)
transforms: a function/transform that takes an input sample
and returns a transformed version
cache: if True, cache file handle to speed up repeated sampling
@ -150,23 +133,36 @@ class NLCD(RasterDataset):
checksum: if True, check the MD5 after downloading files (may be slow)
Raises:
AssertionError: if ``years`` or ``classes`` are invalid
FileNotFoundError: if no files are found in ``root``
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
AssertionError: if ``year`` is invalid
"""
assert set(years).issubset(self.md5s.keys()), (
assert set(years) <= self.md5s.keys(), (
"NLCD data product only exists for the following years: "
f"{list(self.md5s.keys())}."
)
self.years = years
assert (
set(classes) <= self.cmap.keys()
), f"Only the following classes are valid: {list(self.cmap.keys())}."
assert 0 in classes, "Classes must include the background class: 0"
self.root = root
self.years = years
self.classes = classes
self.download = download
self.checksum = checksum
self.ordinal_map = torch.zeros(max(self.cmap.keys()) + 1, dtype=self.dtype)
self.ordinal_cmap = torch.zeros((len(self.classes), 4), dtype=torch.uint8)
self._verify()
super().__init__(root, crs, res, transforms=transforms, cache=cache)
# Map chosen classes to ordinal numbers, all others mapped to background class
for v, k in enumerate(self.classes):
self.ordinal_map[k] = v
self.ordinal_cmap[v] = torch.tensor(self.cmap[k])
def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
"""Retrieve mask and metadata indexed by query.
@ -180,13 +176,7 @@ class NLCD(RasterDataset):
IndexError: if query is not found in the index
"""
sample = super().__getitem__(query)
mask = sample["mask"]
for k, v in self.ordinal_label_map.items():
mask[mask == k] = v
sample["mask"] = mask
sample["mask"] = self.ordinal_map[sample["mask"]]
return sample
def _verify(self) -> None:
@ -199,9 +189,8 @@ class NLCD(RasterDataset):
exists = []
for year in self.years:
filename_year = self.filename_glob.replace("*", str(year))
dirname_year = filename_year.split(".")[0]
pathname = os.path.join(self.root, dirname_year, filename_year)
if os.path.exists(pathname):
pathname = os.path.join(self.root, "**", filename_year)
if glob.glob(pathname, recursive=True):
exists.append(True)
else:
exists.append(False)
@ -212,10 +201,9 @@ class NLCD(RasterDataset):
# Check if the zip files have already been downloaded
exists = []
for year in self.years:
pathname = os.path.join(
self.root, self.zipfile_glob.replace("*", str(year))
)
if os.path.exists(pathname):
zipfile_year = self.zipfile_glob.replace("*", str(year))
pathname = os.path.join(self.root, "**", zipfile_year)
if glob.glob(pathname, recursive=True):
exists.append(True)
self._extract()
else:
@ -249,8 +237,8 @@ class NLCD(RasterDataset):
"""Extract the dataset."""
for year in self.years:
zipfile_name = self.zipfile_glob.replace("*", str(year))
pathname = os.path.join(self.root, zipfile_name)
extract_archive(pathname, self.root)
pathname = os.path.join(self.root, "**", zipfile_name)
extract_archive(glob.glob(pathname, recursive=True)[0], self.root)
def plot(
self,
@ -268,33 +256,26 @@ class NLCD(RasterDataset):
Returns:
a matplotlib Figure with the rendered sample
"""
mask = sample["mask"].squeeze().numpy()
mask = sample["mask"].squeeze()
ncols = 1
showing_predictions = "prediction" in sample
if showing_predictions:
pred = sample["prediction"].squeeze().numpy()
pred = sample["prediction"].squeeze()
ncols = 2
kwargs = {
"cmap": ListedColormap(np.array(list(self.cmap.values())) / 255),
"vmin": 0,
"vmax": len(self.cmap) - 1,
"interpolation": "none",
}
fig, axs = plt.subplots(
nrows=1, ncols=ncols, figsize=(ncols * 4, 4), squeeze=False
)
axs[0, 0].imshow(mask, **kwargs)
axs[0, 0].imshow(self.ordinal_cmap[mask], interpolation="none")
axs[0, 0].axis("off")
if show_titles:
axs[0, 0].set_title("Mask")
if showing_predictions:
axs[0, 1].imshow(pred, **kwargs)
axs[0, 1].imshow(self.ordinal_cmap[pred], interpolation="none")
axs[0, 1].axis("off")
if show_titles:
axs[0, 1].set_title("Prediction")

Просмотреть файл

@ -11,7 +11,6 @@ import matplotlib.pyplot as plt
import numpy as np
import rasterio
import torch
from matplotlib.colors import ListedColormap
from torch import Tensor
from .cdl import CDL
@ -45,8 +44,8 @@ class SSL4EOLBenchmark(NonGeoDataset):
url = "https://huggingface.co/datasets/torchgeo/{}/resolve/main/{}.tar.gz"
valid_input_sensors = ["tm_toa", "etm_toa", "etm_sr", "oli_tirs_toa", "oli_sr"]
valid_mask_products = ["cdl", "nlcd"]
valid_sensors = ["tm_toa", "etm_toa", "etm_sr", "oli_tirs_toa", "oli_sr"]
valid_products = ["cdl", "nlcd"]
valid_splits = ["train", "val", "test"]
image_root = "ssl4eo_l_{}_benchmark"
@ -98,16 +97,15 @@ class SSL4EOLBenchmark(NonGeoDataset):
split_percentages = [0.7, 0.15, 0.15]
ordinal_label_map = {"nlcd": NLCD.ordinal_label_map, "cdl": CDL.ordinal_label_map}
cmaps = {"nlcd": NLCD.cmap, "cdl": CDL.cmap}
def __init__(
self,
root: str = "data",
input_sensor: str = "oli_sr",
mask_product: str = "cdl",
sensor: str = "oli_sr",
product: str = "cdl",
split: str = "train",
classes: Optional[list[int]] = None,
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
download: bool = False,
checksum: bool = False,
@ -116,9 +114,11 @@ class SSL4EOLBenchmark(NonGeoDataset):
Args:
root: root directory where dataset can be found
input_sensor: one of ['etm_toa', 'etm_sr', 'oli_tirs_toa, 'oli_sr']
mask_product: mask target one of ['cdl', 'nlcd']
sensor: one of ['etm_toa', 'etm_sr', 'oli_tirs_toa, 'oli_sr']
product: mask target, one of ['cdl', 'nlcd']
split: dataset split, one of ['train', 'val', 'test']
classes: list of classes to include, the rest will be mapped to 0
(defaults to all classes for the chosen product)
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
download: if True, download dataset and store it in the root directory
@ -126,29 +126,39 @@ class SSL4EOLBenchmark(NonGeoDataset):
Raises:
AssertionError: if any arguments are invalid
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
"""
assert (
input_sensor in self.valid_input_sensors
), f"Only supports one of {self.valid_input_sensors}, but found {input_sensor}."
self.input_sensor = input_sensor
sensor in self.valid_sensors
), f"Only supports one of {self.valid_sensors}, but found {sensor}."
self.sensor = sensor
assert (
mask_product in self.valid_mask_products
), f"Only supports one of {self.valid_mask_products}, but found {mask_product}."
self.mask_product = mask_product
product in self.valid_products
), f"Only supports one of {self.valid_products}, but found {product}."
self.product = product
assert (
split in self.valid_splits
), f"Only supports one of {self.valid_splits}, but found {split}."
self.split = split
self.cmap = self.cmaps[product]
if classes is None:
classes = list(self.cmap.keys())
assert (
set(classes) <= self.cmap.keys()
), f"Only the following classes are valid: {list(self.cmap.keys())}."
assert 0 in classes, "Classes must include the background class: 0"
self.root = root
self.classes = classes
self.transforms = transforms
self.download = download
self.checksum = checksum
self.img_dir_name = self.image_root.format(self.input_sensor)
self.mask_dir_name = self.mask_dir_dict[self.input_sensor].format(
self.mask_product
)
self.cmap = self.cmaps[self.mask_product]
self.ordinal_map = torch.zeros(max(self.cmap.keys()) + 1, dtype=torch.long)
self.ordinal_cmap = torch.zeros((len(self.classes), 4), dtype=torch.uint8)
self.img_dir_name = self.image_root.format(self.sensor)
self.mask_dir_name = self.mask_dir_dict[self.sensor].format(self.product)
self._verify()
@ -169,6 +179,11 @@ class SSL4EOLBenchmark(NonGeoDataset):
self.sample_collection = [self.sample_collection[idx] for idx in split_indices]
# Map chosen classes to ordinal numbers, all others mapped to background class
for v, k in enumerate(self.classes):
self.ordinal_map[k] = v
self.ordinal_cmap[v] = torch.tensor(self.cmap[k])
def _verify(self) -> None:
"""Verify the integrity of the dataset.
@ -183,7 +198,7 @@ class SSL4EOLBenchmark(NonGeoDataset):
self.root,
self.mask_dir_name,
"**",
f"{self.mask_product}_{self.year_dict[self.input_sensor]}.tif",
f"{self.product}_{self.year_dict[self.sensor]}.tif",
)
exists.append(bool(glob.glob(mask_pathname, recursive=True)))
@ -219,13 +234,13 @@ class SSL4EOLBenchmark(NonGeoDataset):
download_url(
self.url.format(self.img_dir_name, self.img_dir_name),
self.root,
md5=self.img_md5s[self.input_sensor] if self.checksum else None,
md5=self.img_md5s[self.sensor] if self.checksum else None,
)
# download mask
download_url(
self.url.format(self.mask_dir_name, self.mask_dir_name),
self.root,
md5=self.mask_md5s[self.input_sensor.split("_")[0]][self.mask_product]
md5=self.mask_md5s[self.sensor.split("_")[0]][self.product]
if self.checksum
else None,
)
@ -277,8 +292,7 @@ class SSL4EOLBenchmark(NonGeoDataset):
sample_collection: list[tuple[str, str]] = []
for img_path in img_paths:
mask_path = img_path.replace(self.img_dir_name, self.mask_dir_name).replace(
"all_bands.tif",
f"{self.mask_product}_{self.year_dict[self.input_sensor]}.tif",
"all_bands.tif", f"{self.product}_{self.year_dict[self.sensor]}.tif"
)
sample_collection.append((img_path, mask_path))
return sample_collection
@ -293,8 +307,8 @@ class SSL4EOLBenchmark(NonGeoDataset):
image
"""
with rasterio.open(path) as src:
image = src.read().astype(np.float32)
return torch.from_numpy(image)
image = torch.from_numpy(src.read()).float()
return image
def _load_mask(self, path: str) -> Tensor:
"""Load the mask.
@ -306,12 +320,8 @@ class SSL4EOLBenchmark(NonGeoDataset):
mask
"""
with rasterio.open(path) as src:
mask = src.read()
for k, v in self.ordinal_label_map[self.mask_product].items():
mask[mask == k] = v
return torch.from_numpy(mask).long()
mask = torch.from_numpy(src.read()).long()
return self.ordinal_map[mask]
def plot(
self,
@ -330,34 +340,27 @@ class SSL4EOLBenchmark(NonGeoDataset):
a matplotlib Figure with the rendered sample
"""
ncols = 2
image = sample["image"][self.rgb_indices[self.input_sensor]].permute(1, 2, 0)
image = image.numpy() / 255
image = sample["image"][self.rgb_indices[self.sensor]].permute(1, 2, 0)
image = image / 255
mask = sample["mask"].squeeze(0)
showing_predictions = "prediction" in sample
if showing_predictions:
prediction_mask = sample["prediction"].squeeze(0).numpy()
pred = sample["prediction"].squeeze(0)
ncols = 3
kwargs = {
"cmap": ListedColormap(np.array(list(self.cmap.values())) / 255),
"vmin": 0,
"vmax": len(self.cmap) - 1,
"interpolation": "none",
}
fig, ax = plt.subplots(ncols=ncols, figsize=(4 * ncols, 4))
ax[0].imshow(image)
ax[0].axis("off")
ax[1].imshow(mask, **kwargs)
ax[1].imshow(self.ordinal_cmap[mask], interpolation="none")
ax[1].axis("off")
if show_titles:
ax[0].set_title("Image")
ax[1].set_title("Mask")
if showing_predictions:
ax[2].imshow(prediction_mask, **kwargs)
ax[2].imshow(self.ordinal_cmap[pred], interpolation="none")
if show_titles:
ax[2].set_title("Prediction")