зеркало из https://github.com/microsoft/torchgeo.git
CDL/NLCD/SSL4EO: allow selection of classes (#1392)
* CDL/NLCD/SSL4EO: allow selection of classes * 0 is already 0 * Get NLCD tests to pass * Search recursively for NLCD files * Update CDL * Update SSL4EO-L Benchmark * Passing tests * Test SSL4EO-L Benchmark * Test CDL * Test NLCD * Mypy fix * Remove debugging code
This commit is contained in:
Родитель
3d436d3057
Коммит
9e57f27818
|
@ -14,7 +14,7 @@ module:
|
|||
datamodule:
|
||||
_target_: torchgeo.datamodules.SSL4EOLBenchmarkDataModule
|
||||
root: "tests/data/ssl4eo_benchmark_landsat"
|
||||
input_sensor: "tm_toa"
|
||||
mask_product: "cdl"
|
||||
sensor: "tm_toa"
|
||||
product: "cdl"
|
||||
batch_size: 2
|
||||
num_workers: 0
|
||||
|
|
|
@ -14,7 +14,7 @@ module:
|
|||
datamodule:
|
||||
_target_: torchgeo.datamodules.SSL4EOLBenchmarkDataModule
|
||||
root: "tests/data/ssl4eo_benchmark_landsat"
|
||||
input_sensor: "etm_sr"
|
||||
mask_product: "nlcd"
|
||||
sensor: "etm_sr"
|
||||
product: "nlcd"
|
||||
batch_size: 2
|
||||
num_workers: 0
|
||||
|
|
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
|
@ -51,6 +51,14 @@ class TestCDL:
|
|||
assert isinstance(x["crs"], CRS)
|
||||
assert isinstance(x["mask"], torch.Tensor)
|
||||
|
||||
def test_classes(self) -> None:
|
||||
root = os.path.join("tests", "data", "cdl")
|
||||
classes = list(CDL.cmap.keys())[:5]
|
||||
ds = CDL(root, years=[2021], classes=classes)
|
||||
sample = ds[ds.bounds]
|
||||
mask = sample["mask"]
|
||||
assert mask.max() < len(classes)
|
||||
|
||||
def test_and(self, dataset: CDL) -> None:
|
||||
ds = dataset & dataset
|
||||
assert isinstance(ds, IntersectionDataset)
|
||||
|
@ -82,6 +90,13 @@ class TestCDL:
|
|||
):
|
||||
CDL(str(tmp_path), years=[1996])
|
||||
|
||||
def test_invalid_classes(self) -> None:
|
||||
with pytest.raises(AssertionError):
|
||||
CDL(classes=[-1])
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
CDL(classes=[11])
|
||||
|
||||
def test_plot(self, dataset: CDL) -> None:
|
||||
query = dataset.bounds
|
||||
x = dataset[query]
|
||||
|
|
|
@ -52,6 +52,14 @@ class TestNLCD:
|
|||
assert isinstance(x["crs"], CRS)
|
||||
assert isinstance(x["mask"], torch.Tensor)
|
||||
|
||||
def test_classes(self) -> None:
|
||||
root = os.path.join("tests", "data", "nlcd")
|
||||
classes = list(NLCD.cmap.keys())[:5]
|
||||
ds = NLCD(root, years=[2019], classes=classes)
|
||||
sample = ds[ds.bounds]
|
||||
mask = sample["mask"]
|
||||
assert mask.max() < len(classes)
|
||||
|
||||
def test_and(self, dataset: NLCD) -> None:
|
||||
ds = dataset & dataset
|
||||
assert isinstance(ds, IntersectionDataset)
|
||||
|
@ -78,6 +86,13 @@ class TestNLCD:
|
|||
):
|
||||
NLCD(str(tmp_path), years=[1996])
|
||||
|
||||
def test_invalid_classes(self) -> None:
|
||||
with pytest.raises(AssertionError):
|
||||
NLCD(classes=[-1])
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
NLCD(classes=[11])
|
||||
|
||||
def test_plot(self, dataset: NLCD) -> None:
|
||||
query = dataset.bounds
|
||||
x = dataset[query]
|
||||
|
|
|
@ -16,7 +16,7 @@ from pytest import MonkeyPatch
|
|||
from torch.utils.data import ConcatDataset
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import SSL4EOLBenchmark
|
||||
from torchgeo.datasets import CDL, NLCD, RasterDataset, SSL4EOLBenchmark
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
|
@ -42,7 +42,7 @@ class TestSSL4EOLBenchmark:
|
|||
url = os.path.join("tests", "data", "ssl4eo_benchmark_landsat", "{}.tar.gz")
|
||||
monkeypatch.setattr(SSL4EOLBenchmark, "url", url)
|
||||
|
||||
input_sensor, mask_product, split = request.param
|
||||
sensor, product, split = request.param
|
||||
monkeypatch.setattr(
|
||||
SSL4EOLBenchmark, "split_percentages", [1 / 3, 1 / 3, 1 / 3]
|
||||
)
|
||||
|
@ -75,8 +75,8 @@ class TestSSL4EOLBenchmark:
|
|||
transforms = nn.Identity()
|
||||
return SSL4EOLBenchmark(
|
||||
root=root,
|
||||
input_sensor=input_sensor,
|
||||
mask_product=mask_product,
|
||||
sensor=sensor,
|
||||
product=product,
|
||||
split=split,
|
||||
transforms=transforms,
|
||||
download=True,
|
||||
|
@ -89,17 +89,33 @@ class TestSSL4EOLBenchmark:
|
|||
assert isinstance(x["image"], torch.Tensor)
|
||||
assert isinstance(x["mask"], torch.Tensor)
|
||||
|
||||
@pytest.mark.parametrize("product,base_class", [("nlcd", NLCD), ("cdl", CDL)])
|
||||
def test_classes(self, product: str, base_class: RasterDataset) -> None:
|
||||
root = os.path.join("tests", "data", "ssl4eo_benchmark_landsat")
|
||||
classes = list(base_class.cmap.keys())[:5]
|
||||
ds = SSL4EOLBenchmark(root, product=product, classes=classes)
|
||||
sample = ds[0]
|
||||
mask = sample["mask"]
|
||||
assert mask.max() < len(classes)
|
||||
|
||||
def test_invalid_split(self) -> None:
|
||||
with pytest.raises(AssertionError):
|
||||
SSL4EOLBenchmark(split="foo")
|
||||
|
||||
def test_invalid_input_sensor(self) -> None:
|
||||
def test_invalid_sensor(self) -> None:
|
||||
with pytest.raises(AssertionError):
|
||||
SSL4EOLBenchmark(input_sensor="foo")
|
||||
SSL4EOLBenchmark(sensor="foo")
|
||||
|
||||
def test_invalid_mask_product(self) -> None:
|
||||
def test_invalid_product(self) -> None:
|
||||
with pytest.raises(AssertionError):
|
||||
SSL4EOLBenchmark(mask_product="foo")
|
||||
SSL4EOLBenchmark(product="foo")
|
||||
|
||||
def test_invalid_classes(self) -> None:
|
||||
with pytest.raises(AssertionError):
|
||||
SSL4EOLBenchmark(classes=[-1])
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
SSL4EOLBenchmark(classes=[11])
|
||||
|
||||
def test_add(self, dataset: SSL4EOLBenchmark) -> None:
|
||||
ds = dataset + dataset
|
||||
|
@ -108,8 +124,8 @@ class TestSSL4EOLBenchmark:
|
|||
def test_already_extracted(self, dataset: SSL4EOLBenchmark) -> None:
|
||||
SSL4EOLBenchmark(
|
||||
root=dataset.root,
|
||||
input_sensor=dataset.input_sensor,
|
||||
mask_product=dataset.mask_product,
|
||||
sensor=dataset.sensor,
|
||||
product=dataset.product,
|
||||
download=True,
|
||||
)
|
||||
|
||||
|
|
|
@ -8,8 +8,7 @@ import os
|
|||
from typing import Any, Callable, Optional
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from matplotlib.colors import ListedColormap
|
||||
import torch
|
||||
from rasterio.crs import CRS
|
||||
|
||||
from .geo import RasterDataset
|
||||
|
@ -67,277 +66,140 @@ class CDL(RasterDataset):
|
|||
}
|
||||
|
||||
cmap = {
|
||||
0: (0, 0, 0, 0),
|
||||
0: (0, 0, 0, 255),
|
||||
1: (255, 211, 0, 255),
|
||||
2: (255, 38, 38, 255),
|
||||
3: (0, 168, 228, 255),
|
||||
4: (255, 158, 11, 255),
|
||||
5: (38, 112, 0, 255),
|
||||
2: (255, 37, 37, 255),
|
||||
3: (0, 168, 226, 255),
|
||||
4: (255, 158, 9, 255),
|
||||
5: (37, 111, 0, 255),
|
||||
6: (255, 255, 0, 255),
|
||||
7: (112, 165, 0, 255),
|
||||
8: (0, 175, 75, 255),
|
||||
9: (221, 165, 11, 255),
|
||||
10: (221, 165, 11, 255),
|
||||
11: (126, 211, 255, 255),
|
||||
12: (226, 0, 124, 255),
|
||||
13: (137, 98, 84, 255),
|
||||
14: (216, 181, 107, 255),
|
||||
15: (165, 112, 0, 255),
|
||||
16: (214, 158, 188, 255),
|
||||
17: (112, 112, 0, 255),
|
||||
18: (172, 0, 124, 255),
|
||||
19: (160, 89, 137, 255),
|
||||
20: (112, 0, 73, 255),
|
||||
21: (214, 158, 188, 255),
|
||||
22: (209, 255, 0, 255),
|
||||
23: (126, 153, 255, 255),
|
||||
24: (214, 214, 0, 255),
|
||||
25: (209, 255, 0, 255),
|
||||
26: (0, 175, 75, 255),
|
||||
27: (255, 165, 226, 255),
|
||||
28: (165, 242, 140, 255),
|
||||
29: (0, 175, 75, 255),
|
||||
30: (214, 158, 188, 255),
|
||||
31: (168, 0, 228, 255),
|
||||
32: (165, 0, 0, 255),
|
||||
33: (112, 38, 0, 255),
|
||||
34: (0, 175, 75, 255),
|
||||
35: (177, 126, 255, 255),
|
||||
36: (112, 38, 0, 255),
|
||||
37: (255, 102, 102, 255),
|
||||
38: (255, 102, 102, 255),
|
||||
39: (255, 204, 102, 255),
|
||||
40: (255, 102, 102, 255),
|
||||
41: (0, 175, 75, 255),
|
||||
42: (0, 221, 175, 255),
|
||||
43: (84, 255, 0, 255),
|
||||
44: (242, 163, 119, 255),
|
||||
45: (255, 102, 102, 255),
|
||||
46: (0, 175, 75, 255),
|
||||
47: (126, 211, 255, 255),
|
||||
48: (232, 191, 255, 255),
|
||||
49: (175, 255, 221, 255),
|
||||
50: (0, 175, 75, 255),
|
||||
51: (191, 191, 119, 255),
|
||||
52: (147, 204, 147, 255),
|
||||
53: (198, 214, 158, 255),
|
||||
54: (204, 191, 163, 255),
|
||||
55: (255, 0, 255, 255),
|
||||
56: (255, 142, 170, 255),
|
||||
57: (186, 0, 79, 255),
|
||||
58: (112, 68, 137, 255),
|
||||
59: (0, 119, 119, 255),
|
||||
60: (177, 154, 112, 255),
|
||||
61: (255, 255, 126, 255),
|
||||
62: (181, 112, 91, 255),
|
||||
63: (0, 165, 130, 255),
|
||||
64: (233, 214, 175, 255),
|
||||
65: (177, 154, 112, 255),
|
||||
66: (242, 242, 242, 255),
|
||||
67: (154, 154, 154, 255),
|
||||
68: (75, 112, 163, 255),
|
||||
69: (126, 177, 177, 255),
|
||||
70: (232, 255, 191, 255),
|
||||
71: (0, 255, 255, 255),
|
||||
72: (75, 112, 163, 255),
|
||||
73: (211, 226, 249, 255),
|
||||
74: (154, 154, 154, 255),
|
||||
75: (154, 154, 154, 255),
|
||||
76: (154, 154, 154, 255),
|
||||
77: (154, 154, 154, 255),
|
||||
78: (204, 191, 163, 255),
|
||||
79: (147, 204, 147, 255),
|
||||
80: (147, 204, 147, 255),
|
||||
81: (147, 204, 147, 255),
|
||||
82: (198, 214, 158, 255),
|
||||
83: (232, 255, 191, 255),
|
||||
84: (126, 177, 177, 255),
|
||||
85: (126, 177, 177, 255),
|
||||
86: (0, 255, 140, 255),
|
||||
87: (214, 158, 188, 255),
|
||||
88: (255, 102, 102, 255),
|
||||
89: (255, 102, 102, 255),
|
||||
90: (255, 102, 102, 255),
|
||||
91: (255, 102, 102, 255),
|
||||
92: (255, 142, 170, 255),
|
||||
93: (51, 73, 51, 255),
|
||||
94: (228, 112, 38, 255),
|
||||
95: (255, 102, 102, 255),
|
||||
96: (255, 102, 102, 255),
|
||||
97: (102, 153, 76, 255),
|
||||
98: (255, 102, 102, 255),
|
||||
99: (177, 154, 112, 255),
|
||||
100: (255, 142, 170, 255),
|
||||
101: (255, 102, 102, 255),
|
||||
102: (255, 142, 170, 255),
|
||||
103: (255, 102, 102, 255),
|
||||
104: (255, 102, 102, 255),
|
||||
105: (255, 142, 170, 255),
|
||||
106: (0, 175, 75, 255),
|
||||
107: (255, 211, 0, 255),
|
||||
108: (255, 211, 0, 255),
|
||||
109: (255, 102, 102, 255),
|
||||
110: (255, 210, 0, 255),
|
||||
111: (255, 102, 102, 255),
|
||||
112: (137, 98, 84, 255),
|
||||
113: (255, 102, 102, 255),
|
||||
114: (255, 38, 38, 255),
|
||||
115: (226, 0, 124, 255),
|
||||
116: (255, 158, 11, 255),
|
||||
117: (255, 158, 11, 255),
|
||||
118: (165, 112, 0, 255),
|
||||
119: (255, 211, 0, 255),
|
||||
120: (165, 112, 0, 255),
|
||||
121: (38, 112, 0, 255),
|
||||
122: (38, 112, 0, 255),
|
||||
123: (255, 211, 0, 255),
|
||||
124: (0, 0, 153, 255),
|
||||
125: (255, 102, 102, 255),
|
||||
126: (255, 102, 102, 255),
|
||||
127: (255, 102, 102, 255),
|
||||
128: (255, 102, 102, 255),
|
||||
129: (255, 102, 102, 255),
|
||||
130: (255, 102, 102, 255),
|
||||
131: (255, 102, 102, 255),
|
||||
132: (255, 102, 102, 255),
|
||||
133: (38, 112, 0, 255),
|
||||
}
|
||||
|
||||
ordinal_label_map = {
|
||||
0: 0,
|
||||
1: 1,
|
||||
2: 2,
|
||||
3: 3,
|
||||
4: 4,
|
||||
5: 5,
|
||||
6: 6,
|
||||
10: 7,
|
||||
11: 8,
|
||||
12: 9,
|
||||
13: 10,
|
||||
14: 11,
|
||||
21: 12,
|
||||
22: 13,
|
||||
23: 14,
|
||||
24: 15,
|
||||
25: 16,
|
||||
26: 17,
|
||||
27: 18,
|
||||
28: 19,
|
||||
29: 20,
|
||||
30: 21,
|
||||
31: 22,
|
||||
32: 23,
|
||||
33: 24,
|
||||
34: 25,
|
||||
35: 26,
|
||||
36: 27,
|
||||
37: 28,
|
||||
38: 29,
|
||||
39: 30,
|
||||
41: 31,
|
||||
42: 32,
|
||||
43: 33,
|
||||
44: 34,
|
||||
45: 35,
|
||||
46: 36,
|
||||
47: 37,
|
||||
48: 38,
|
||||
49: 39,
|
||||
50: 40,
|
||||
51: 41,
|
||||
52: 42,
|
||||
53: 43,
|
||||
54: 44,
|
||||
55: 45,
|
||||
56: 46,
|
||||
57: 47,
|
||||
58: 48,
|
||||
59: 49,
|
||||
60: 50,
|
||||
61: 51,
|
||||
63: 52,
|
||||
64: 53,
|
||||
65: 54,
|
||||
66: 55,
|
||||
67: 56,
|
||||
68: 57,
|
||||
69: 58,
|
||||
70: 59,
|
||||
71: 60,
|
||||
72: 61,
|
||||
74: 62,
|
||||
75: 63,
|
||||
76: 64,
|
||||
77: 65,
|
||||
81: 66,
|
||||
82: 67,
|
||||
83: 68,
|
||||
87: 69,
|
||||
88: 70,
|
||||
92: 71,
|
||||
111: 72,
|
||||
112: 73,
|
||||
121: 74,
|
||||
122: 75,
|
||||
123: 76,
|
||||
124: 77,
|
||||
131: 78,
|
||||
141: 79,
|
||||
142: 80,
|
||||
143: 81,
|
||||
152: 82,
|
||||
176: 83,
|
||||
190: 84,
|
||||
195: 85,
|
||||
204: 86,
|
||||
205: 87,
|
||||
206: 88,
|
||||
207: 89,
|
||||
208: 90,
|
||||
209: 91,
|
||||
210: 92,
|
||||
211: 93,
|
||||
212: 94,
|
||||
213: 95,
|
||||
214: 96,
|
||||
215: 97,
|
||||
216: 98,
|
||||
217: 99,
|
||||
218: 100,
|
||||
219: 101,
|
||||
220: 102,
|
||||
221: 103,
|
||||
222: 104,
|
||||
223: 105,
|
||||
224: 106,
|
||||
225: 107,
|
||||
226: 108,
|
||||
227: 109,
|
||||
228: 110,
|
||||
229: 111,
|
||||
230: 112,
|
||||
231: 113,
|
||||
232: 114,
|
||||
233: 115,
|
||||
234: 116,
|
||||
235: 117,
|
||||
236: 118,
|
||||
237: 119,
|
||||
238: 120,
|
||||
239: 121,
|
||||
240: 122,
|
||||
241: 123,
|
||||
242: 124,
|
||||
243: 125,
|
||||
244: 126,
|
||||
245: 127,
|
||||
246: 128,
|
||||
247: 129,
|
||||
248: 130,
|
||||
249: 131,
|
||||
250: 132,
|
||||
254: 133,
|
||||
10: (111, 166, 0, 255),
|
||||
11: (0, 175, 73, 255),
|
||||
12: (222, 166, 9, 255),
|
||||
13: (222, 166, 9, 255),
|
||||
14: (124, 211, 255, 255),
|
||||
21: (226, 0, 124, 255),
|
||||
22: (137, 96, 83, 255),
|
||||
23: (217, 181, 107, 255),
|
||||
24: (166, 111, 0, 255),
|
||||
25: (213, 158, 188, 255),
|
||||
26: (111, 111, 0, 255),
|
||||
27: (171, 0, 124, 255),
|
||||
28: (160, 88, 137, 255),
|
||||
29: (111, 0, 73, 255),
|
||||
30: (213, 158, 188, 255),
|
||||
31: (209, 255, 0, 255),
|
||||
32: (124, 153, 255, 255),
|
||||
33: (213, 213, 0, 255),
|
||||
34: (209, 255, 0, 255),
|
||||
35: (0, 175, 73, 255),
|
||||
36: (255, 166, 226, 255),
|
||||
37: (166, 241, 139, 255),
|
||||
38: (0, 175, 73, 255),
|
||||
39: (213, 158, 188, 255),
|
||||
41: (168, 0, 226, 255),
|
||||
42: (166, 0, 0, 255),
|
||||
43: (111, 37, 0, 255),
|
||||
44: (0, 175, 73, 255),
|
||||
45: (175, 124, 255, 255),
|
||||
46: (111, 37, 0, 255),
|
||||
47: (255, 102, 102, 255),
|
||||
48: (255, 102, 102, 255),
|
||||
49: (255, 204, 102, 255),
|
||||
50: (255, 102, 102, 255),
|
||||
51: (0, 175, 73, 255),
|
||||
52: (0, 222, 175, 255),
|
||||
53: (83, 255, 0, 255),
|
||||
54: (241, 162, 120, 255),
|
||||
55: (255, 102, 102, 255),
|
||||
56: (0, 175, 73, 255),
|
||||
57: (124, 211, 255, 255),
|
||||
58: (232, 190, 255, 255),
|
||||
59: (175, 255, 222, 255),
|
||||
60: (0, 175, 73, 255),
|
||||
61: (190, 190, 120, 255),
|
||||
63: (147, 204, 147, 255),
|
||||
64: (198, 213, 158, 255),
|
||||
65: (204, 190, 162, 255),
|
||||
66: (255, 0, 255, 255),
|
||||
67: (255, 143, 171, 255),
|
||||
68: (185, 0, 79, 255),
|
||||
69: (111, 69, 137, 255),
|
||||
70: (0, 120, 120, 255),
|
||||
71: (175, 153, 111, 255),
|
||||
72: (255, 255, 124, 255),
|
||||
74: (181, 111, 92, 255),
|
||||
75: (0, 166, 130, 255),
|
||||
76: (232, 213, 175, 255),
|
||||
77: (175, 153, 111, 255),
|
||||
81: (241, 241, 241, 255),
|
||||
82: (153, 153, 153, 255),
|
||||
83: (73, 111, 162, 255),
|
||||
87: (124, 175, 175, 255),
|
||||
88: (232, 255, 190, 255),
|
||||
92: (0, 255, 255, 255),
|
||||
111: (73, 111, 162, 255),
|
||||
112: (211, 226, 249, 255),
|
||||
121: (153, 153, 153, 255),
|
||||
122: (153, 153, 153, 255),
|
||||
123: (153, 153, 153, 255),
|
||||
124: (153, 153, 153, 255),
|
||||
131: (204, 190, 162, 255),
|
||||
141: (147, 204, 147, 255),
|
||||
142: (147, 204, 147, 255),
|
||||
143: (147, 204, 147, 255),
|
||||
152: (198, 213, 158, 255),
|
||||
176: (232, 255, 190, 255),
|
||||
190: (124, 175, 175, 255),
|
||||
195: (124, 175, 175, 255),
|
||||
204: (0, 255, 139, 255),
|
||||
205: (213, 158, 188, 255),
|
||||
206: (255, 102, 102, 255),
|
||||
207: (255, 102, 102, 255),
|
||||
208: (255, 102, 102, 255),
|
||||
209: (255, 102, 102, 255),
|
||||
210: (255, 143, 171, 255),
|
||||
211: (51, 73, 51, 255),
|
||||
212: (226, 111, 37, 255),
|
||||
213: (255, 102, 102, 255),
|
||||
214: (255, 102, 102, 255),
|
||||
215: (102, 153, 77, 255),
|
||||
216: (255, 102, 102, 255),
|
||||
217: (175, 153, 111, 255),
|
||||
218: (255, 143, 171, 255),
|
||||
219: (255, 102, 102, 255),
|
||||
220: (255, 143, 171, 255),
|
||||
221: (255, 102, 102, 255),
|
||||
222: (255, 102, 102, 255),
|
||||
223: (255, 143, 171, 255),
|
||||
224: (0, 175, 73, 255),
|
||||
225: (255, 211, 0, 255),
|
||||
226: (255, 211, 0, 255),
|
||||
227: (255, 102, 102, 255),
|
||||
228: (255, 211, 0, 255),
|
||||
229: (255, 102, 102, 255),
|
||||
230: (137, 96, 83, 255),
|
||||
231: (255, 102, 102, 255),
|
||||
232: (255, 37, 37, 255),
|
||||
233: (226, 0, 124, 255),
|
||||
234: (255, 158, 9, 255),
|
||||
235: (255, 158, 9, 255),
|
||||
236: (166, 111, 0, 255),
|
||||
237: (255, 211, 0, 255),
|
||||
238: (166, 111, 0, 255),
|
||||
239: (37, 111, 0, 255),
|
||||
240: (37, 111, 0, 255),
|
||||
241: (255, 211, 0, 255),
|
||||
242: (0, 0, 153, 255),
|
||||
243: (255, 102, 102, 255),
|
||||
244: (255, 102, 102, 255),
|
||||
245: (255, 102, 102, 255),
|
||||
246: (255, 102, 102, 255),
|
||||
247: (255, 102, 102, 255),
|
||||
248: (255, 102, 102, 255),
|
||||
249: (255, 102, 102, 255),
|
||||
250: (255, 102, 102, 255),
|
||||
254: (37, 111, 0, 255),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
|
@ -346,6 +208,7 @@ class CDL(RasterDataset):
|
|||
crs: Optional[CRS] = None,
|
||||
res: Optional[float] = None,
|
||||
years: list[int] = [2022],
|
||||
classes: list[int] = list(cmap.keys()),
|
||||
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
|
||||
cache: bool = True,
|
||||
download: bool = False,
|
||||
|
@ -360,6 +223,8 @@ class CDL(RasterDataset):
|
|||
res: resolution of the dataset in units of CRS
|
||||
(defaults to the resolution of the first file found)
|
||||
years: list of years for which to use cdl layer
|
||||
classes: list of classes to include, the rest will be mapped to 0
|
||||
(defaults to all classes)
|
||||
transforms: a function/transform that takes an input sample
|
||||
and returns a transformed version
|
||||
cache: if True, cache file handle to speed up repeated sampling
|
||||
|
@ -367,25 +232,39 @@ class CDL(RasterDataset):
|
|||
checksum: if True, check the MD5 of the downloaded files (may be slow)
|
||||
|
||||
Raises:
|
||||
AssertionError: if ``years`` or ``classes`` are invalid
|
||||
FileNotFoundError: if no files are found in ``root``
|
||||
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
|
||||
|
||||
.. versionadded:: 0.5
|
||||
The *years* parameter.
|
||||
The *years* and *classes* parameters.
|
||||
"""
|
||||
assert set(years).issubset(self.md5s.keys()), (
|
||||
assert set(years) <= self.md5s.keys(), (
|
||||
"CDL data product only exists for the following years: "
|
||||
f"{list(self.md5s.keys())}."
|
||||
)
|
||||
self.years = years
|
||||
assert (
|
||||
set(classes) <= self.cmap.keys()
|
||||
), f"Only the following classes are valid: {list(self.cmap.keys())}."
|
||||
assert 0 in classes, "Classes must include the background class: 0"
|
||||
|
||||
self.root = root
|
||||
self.years = years
|
||||
self.classes = classes
|
||||
self.download = download
|
||||
self.checksum = checksum
|
||||
self.ordinal_map = torch.zeros(max(self.cmap.keys()) + 1, dtype=self.dtype)
|
||||
self.ordinal_cmap = torch.zeros((len(self.classes), 4), dtype=torch.uint8)
|
||||
|
||||
self._verify()
|
||||
|
||||
super().__init__(root, crs, res, transforms=transforms, cache=cache)
|
||||
|
||||
# Map chosen classes to ordinal numbers, all others mapped to background class
|
||||
for v, k in enumerate(self.classes):
|
||||
self.ordinal_map[k] = v
|
||||
self.ordinal_cmap[v] = torch.tensor(self.cmap[k])
|
||||
|
||||
def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
|
||||
"""Retrieve mask and metadata indexed by query.
|
||||
|
||||
|
@ -399,13 +278,7 @@ class CDL(RasterDataset):
|
|||
IndexError: if query is not found in the index
|
||||
"""
|
||||
sample = super().__getitem__(query)
|
||||
|
||||
mask = sample["mask"]
|
||||
for k, v in self.ordinal_label_map.items():
|
||||
mask[mask == k] = v
|
||||
|
||||
sample["mask"] = mask
|
||||
|
||||
sample["mask"] = self.ordinal_map[sample["mask"]]
|
||||
return sample
|
||||
|
||||
def _verify(self) -> None:
|
||||
|
@ -489,33 +362,26 @@ class CDL(RasterDataset):
|
|||
Method now takes a sample dict, not a Tensor. Additionally, possible to
|
||||
show subplot titles and/or use a custom suptitle.
|
||||
"""
|
||||
mask = sample["mask"].squeeze().numpy()
|
||||
mask = sample["mask"].squeeze()
|
||||
ncols = 1
|
||||
|
||||
showing_predictions = "prediction" in sample
|
||||
if showing_predictions:
|
||||
pred = sample["prediction"].squeeze().numpy()
|
||||
pred = sample["prediction"].squeeze()
|
||||
ncols = 2
|
||||
|
||||
kwargs = {
|
||||
"cmap": ListedColormap(np.array(list(self.cmap.values())) / 255),
|
||||
"vmin": 0,
|
||||
"vmax": len(self.cmap) - 1,
|
||||
"interpolation": "none",
|
||||
}
|
||||
|
||||
fig, axs = plt.subplots(
|
||||
nrows=1, ncols=ncols, figsize=(ncols * 4, 4), squeeze=False
|
||||
)
|
||||
|
||||
axs[0, 0].imshow(mask, **kwargs)
|
||||
axs[0, 0].imshow(self.ordinal_cmap[mask], interpolation="none")
|
||||
axs[0, 0].axis("off")
|
||||
|
||||
if show_titles:
|
||||
axs[0, 0].set_title("Mask")
|
||||
|
||||
if showing_predictions:
|
||||
axs[0, 1].imshow(pred, **kwargs)
|
||||
axs[0, 1].imshow(self.ordinal_cmap[pred], interpolation="none")
|
||||
axs[0, 1].axis("off")
|
||||
if show_titles:
|
||||
axs[0, 1].set_title("Prediction")
|
||||
|
|
|
@ -3,12 +3,12 @@
|
|||
|
||||
"""NLCD dataset."""
|
||||
|
||||
import glob
|
||||
import os
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from matplotlib.colors import ListedColormap
|
||||
import torch
|
||||
from rasterio.crs import CRS
|
||||
|
||||
from .geo import RasterDataset
|
||||
|
@ -83,44 +83,24 @@ class NLCD(RasterDataset):
|
|||
2019: "82851c3f8105763b01c83b4a9e6f3961",
|
||||
}
|
||||
|
||||
ordinal_label_map = {
|
||||
0: 0,
|
||||
11: 1,
|
||||
12: 2,
|
||||
21: 3,
|
||||
22: 4,
|
||||
23: 5,
|
||||
24: 6,
|
||||
31: 7,
|
||||
41: 8,
|
||||
42: 9,
|
||||
43: 10,
|
||||
52: 11,
|
||||
71: 12,
|
||||
81: 13,
|
||||
82: 14,
|
||||
90: 15,
|
||||
95: 16,
|
||||
}
|
||||
|
||||
cmap = {
|
||||
0: (0, 0, 0, 255),
|
||||
1: (70, 107, 159, 255),
|
||||
2: (209, 222, 248, 255),
|
||||
3: (222, 197, 197, 255),
|
||||
4: (217, 146, 130, 255),
|
||||
5: (235, 0, 0, 255),
|
||||
6: (171, 0, 0, 255),
|
||||
7: (179, 172, 159, 255),
|
||||
8: (104, 171, 95, 255),
|
||||
9: (28, 95, 44, 255),
|
||||
10: (181, 197, 143, 255),
|
||||
11: (204, 184, 121, 255),
|
||||
12: (223, 223, 194, 255),
|
||||
13: (220, 217, 57, 255),
|
||||
14: (171, 108, 40, 255),
|
||||
15: (184, 217, 235, 255),
|
||||
16: (108, 159, 184, 255),
|
||||
0: (0, 0, 0, 0),
|
||||
11: (70, 107, 159, 255),
|
||||
12: (209, 222, 248, 255),
|
||||
21: (222, 197, 197, 255),
|
||||
22: (217, 146, 130, 255),
|
||||
23: (235, 0, 0, 255),
|
||||
24: (171, 0, 0, 255),
|
||||
31: (179, 172, 159, 255),
|
||||
41: (104, 171, 95, 255),
|
||||
42: (28, 95, 44, 255),
|
||||
43: (181, 197, 143, 255),
|
||||
52: (204, 184, 121, 255),
|
||||
71: (223, 223, 194, 255),
|
||||
81: (220, 217, 57, 255),
|
||||
82: (171, 108, 40, 255),
|
||||
90: (184, 217, 235, 255),
|
||||
95: (108, 159, 184, 255),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
|
@ -129,6 +109,7 @@ class NLCD(RasterDataset):
|
|||
crs: Optional[CRS] = None,
|
||||
res: Optional[float] = None,
|
||||
years: list[int] = [2019],
|
||||
classes: list[int] = list(cmap.keys()),
|
||||
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
|
||||
cache: bool = True,
|
||||
download: bool = False,
|
||||
|
@ -143,6 +124,8 @@ class NLCD(RasterDataset):
|
|||
res: resolution of the dataset in units of CRS
|
||||
(defaults to the resolution of the first file found)
|
||||
years: list of years for which to use nlcd layer
|
||||
classes: list of classes to include, the rest will be mapped to 0
|
||||
(defaults to all classes)
|
||||
transforms: a function/transform that takes an input sample
|
||||
and returns a transformed version
|
||||
cache: if True, cache file handle to speed up repeated sampling
|
||||
|
@ -150,23 +133,36 @@ class NLCD(RasterDataset):
|
|||
checksum: if True, check the MD5 after downloading files (may be slow)
|
||||
|
||||
Raises:
|
||||
AssertionError: if ``years`` or ``classes`` are invalid
|
||||
FileNotFoundError: if no files are found in ``root``
|
||||
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
|
||||
AssertionError: if ``year`` is invalid
|
||||
"""
|
||||
assert set(years).issubset(self.md5s.keys()), (
|
||||
assert set(years) <= self.md5s.keys(), (
|
||||
"NLCD data product only exists for the following years: "
|
||||
f"{list(self.md5s.keys())}."
|
||||
)
|
||||
self.years = years
|
||||
assert (
|
||||
set(classes) <= self.cmap.keys()
|
||||
), f"Only the following classes are valid: {list(self.cmap.keys())}."
|
||||
assert 0 in classes, "Classes must include the background class: 0"
|
||||
|
||||
self.root = root
|
||||
self.years = years
|
||||
self.classes = classes
|
||||
self.download = download
|
||||
self.checksum = checksum
|
||||
self.ordinal_map = torch.zeros(max(self.cmap.keys()) + 1, dtype=self.dtype)
|
||||
self.ordinal_cmap = torch.zeros((len(self.classes), 4), dtype=torch.uint8)
|
||||
|
||||
self._verify()
|
||||
|
||||
super().__init__(root, crs, res, transforms=transforms, cache=cache)
|
||||
|
||||
# Map chosen classes to ordinal numbers, all others mapped to background class
|
||||
for v, k in enumerate(self.classes):
|
||||
self.ordinal_map[k] = v
|
||||
self.ordinal_cmap[v] = torch.tensor(self.cmap[k])
|
||||
|
||||
def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
|
||||
"""Retrieve mask and metadata indexed by query.
|
||||
|
||||
|
@ -180,13 +176,7 @@ class NLCD(RasterDataset):
|
|||
IndexError: if query is not found in the index
|
||||
"""
|
||||
sample = super().__getitem__(query)
|
||||
|
||||
mask = sample["mask"]
|
||||
for k, v in self.ordinal_label_map.items():
|
||||
mask[mask == k] = v
|
||||
|
||||
sample["mask"] = mask
|
||||
|
||||
sample["mask"] = self.ordinal_map[sample["mask"]]
|
||||
return sample
|
||||
|
||||
def _verify(self) -> None:
|
||||
|
@ -199,9 +189,8 @@ class NLCD(RasterDataset):
|
|||
exists = []
|
||||
for year in self.years:
|
||||
filename_year = self.filename_glob.replace("*", str(year))
|
||||
dirname_year = filename_year.split(".")[0]
|
||||
pathname = os.path.join(self.root, dirname_year, filename_year)
|
||||
if os.path.exists(pathname):
|
||||
pathname = os.path.join(self.root, "**", filename_year)
|
||||
if glob.glob(pathname, recursive=True):
|
||||
exists.append(True)
|
||||
else:
|
||||
exists.append(False)
|
||||
|
@ -212,10 +201,9 @@ class NLCD(RasterDataset):
|
|||
# Check if the zip files have already been downloaded
|
||||
exists = []
|
||||
for year in self.years:
|
||||
pathname = os.path.join(
|
||||
self.root, self.zipfile_glob.replace("*", str(year))
|
||||
)
|
||||
if os.path.exists(pathname):
|
||||
zipfile_year = self.zipfile_glob.replace("*", str(year))
|
||||
pathname = os.path.join(self.root, "**", zipfile_year)
|
||||
if glob.glob(pathname, recursive=True):
|
||||
exists.append(True)
|
||||
self._extract()
|
||||
else:
|
||||
|
@ -249,8 +237,8 @@ class NLCD(RasterDataset):
|
|||
"""Extract the dataset."""
|
||||
for year in self.years:
|
||||
zipfile_name = self.zipfile_glob.replace("*", str(year))
|
||||
pathname = os.path.join(self.root, zipfile_name)
|
||||
extract_archive(pathname, self.root)
|
||||
pathname = os.path.join(self.root, "**", zipfile_name)
|
||||
extract_archive(glob.glob(pathname, recursive=True)[0], self.root)
|
||||
|
||||
def plot(
|
||||
self,
|
||||
|
@ -268,33 +256,26 @@ class NLCD(RasterDataset):
|
|||
Returns:
|
||||
a matplotlib Figure with the rendered sample
|
||||
"""
|
||||
mask = sample["mask"].squeeze().numpy()
|
||||
mask = sample["mask"].squeeze()
|
||||
ncols = 1
|
||||
|
||||
showing_predictions = "prediction" in sample
|
||||
if showing_predictions:
|
||||
pred = sample["prediction"].squeeze().numpy()
|
||||
pred = sample["prediction"].squeeze()
|
||||
ncols = 2
|
||||
|
||||
kwargs = {
|
||||
"cmap": ListedColormap(np.array(list(self.cmap.values())) / 255),
|
||||
"vmin": 0,
|
||||
"vmax": len(self.cmap) - 1,
|
||||
"interpolation": "none",
|
||||
}
|
||||
|
||||
fig, axs = plt.subplots(
|
||||
nrows=1, ncols=ncols, figsize=(ncols * 4, 4), squeeze=False
|
||||
)
|
||||
|
||||
axs[0, 0].imshow(mask, **kwargs)
|
||||
axs[0, 0].imshow(self.ordinal_cmap[mask], interpolation="none")
|
||||
axs[0, 0].axis("off")
|
||||
|
||||
if show_titles:
|
||||
axs[0, 0].set_title("Mask")
|
||||
|
||||
if showing_predictions:
|
||||
axs[0, 1].imshow(pred, **kwargs)
|
||||
axs[0, 1].imshow(self.ordinal_cmap[pred], interpolation="none")
|
||||
axs[0, 1].axis("off")
|
||||
if show_titles:
|
||||
axs[0, 1].set_title("Prediction")
|
||||
|
|
|
@ -11,7 +11,6 @@ import matplotlib.pyplot as plt
|
|||
import numpy as np
|
||||
import rasterio
|
||||
import torch
|
||||
from matplotlib.colors import ListedColormap
|
||||
from torch import Tensor
|
||||
|
||||
from .cdl import CDL
|
||||
|
@ -45,8 +44,8 @@ class SSL4EOLBenchmark(NonGeoDataset):
|
|||
|
||||
url = "https://huggingface.co/datasets/torchgeo/{}/resolve/main/{}.tar.gz"
|
||||
|
||||
valid_input_sensors = ["tm_toa", "etm_toa", "etm_sr", "oli_tirs_toa", "oli_sr"]
|
||||
valid_mask_products = ["cdl", "nlcd"]
|
||||
valid_sensors = ["tm_toa", "etm_toa", "etm_sr", "oli_tirs_toa", "oli_sr"]
|
||||
valid_products = ["cdl", "nlcd"]
|
||||
valid_splits = ["train", "val", "test"]
|
||||
|
||||
image_root = "ssl4eo_l_{}_benchmark"
|
||||
|
@ -98,16 +97,15 @@ class SSL4EOLBenchmark(NonGeoDataset):
|
|||
|
||||
split_percentages = [0.7, 0.15, 0.15]
|
||||
|
||||
ordinal_label_map = {"nlcd": NLCD.ordinal_label_map, "cdl": CDL.ordinal_label_map}
|
||||
|
||||
cmaps = {"nlcd": NLCD.cmap, "cdl": CDL.cmap}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root: str = "data",
|
||||
input_sensor: str = "oli_sr",
|
||||
mask_product: str = "cdl",
|
||||
sensor: str = "oli_sr",
|
||||
product: str = "cdl",
|
||||
split: str = "train",
|
||||
classes: Optional[list[int]] = None,
|
||||
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
|
||||
download: bool = False,
|
||||
checksum: bool = False,
|
||||
|
@ -116,9 +114,11 @@ class SSL4EOLBenchmark(NonGeoDataset):
|
|||
|
||||
Args:
|
||||
root: root directory where dataset can be found
|
||||
input_sensor: one of ['etm_toa', 'etm_sr', 'oli_tirs_toa, 'oli_sr']
|
||||
mask_product: mask target one of ['cdl', 'nlcd']
|
||||
sensor: one of ['etm_toa', 'etm_sr', 'oli_tirs_toa, 'oli_sr']
|
||||
product: mask target, one of ['cdl', 'nlcd']
|
||||
split: dataset split, one of ['train', 'val', 'test']
|
||||
classes: list of classes to include, the rest will be mapped to 0
|
||||
(defaults to all classes for the chosen product)
|
||||
transforms: a function/transform that takes input sample and its target as
|
||||
entry and returns a transformed version
|
||||
download: if True, download dataset and store it in the root directory
|
||||
|
@ -126,29 +126,39 @@ class SSL4EOLBenchmark(NonGeoDataset):
|
|||
|
||||
Raises:
|
||||
AssertionError: if any arguments are invalid
|
||||
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
|
||||
"""
|
||||
assert (
|
||||
input_sensor in self.valid_input_sensors
|
||||
), f"Only supports one of {self.valid_input_sensors}, but found {input_sensor}."
|
||||
self.input_sensor = input_sensor
|
||||
sensor in self.valid_sensors
|
||||
), f"Only supports one of {self.valid_sensors}, but found {sensor}."
|
||||
self.sensor = sensor
|
||||
assert (
|
||||
mask_product in self.valid_mask_products
|
||||
), f"Only supports one of {self.valid_mask_products}, but found {mask_product}."
|
||||
self.mask_product = mask_product
|
||||
product in self.valid_products
|
||||
), f"Only supports one of {self.valid_products}, but found {product}."
|
||||
self.product = product
|
||||
assert (
|
||||
split in self.valid_splits
|
||||
), f"Only supports one of {self.valid_splits}, but found {split}."
|
||||
self.split = split
|
||||
|
||||
self.cmap = self.cmaps[product]
|
||||
if classes is None:
|
||||
classes = list(self.cmap.keys())
|
||||
|
||||
assert (
|
||||
set(classes) <= self.cmap.keys()
|
||||
), f"Only the following classes are valid: {list(self.cmap.keys())}."
|
||||
assert 0 in classes, "Classes must include the background class: 0"
|
||||
|
||||
self.root = root
|
||||
self.classes = classes
|
||||
self.transforms = transforms
|
||||
self.download = download
|
||||
self.checksum = checksum
|
||||
self.img_dir_name = self.image_root.format(self.input_sensor)
|
||||
self.mask_dir_name = self.mask_dir_dict[self.input_sensor].format(
|
||||
self.mask_product
|
||||
)
|
||||
self.cmap = self.cmaps[self.mask_product]
|
||||
self.ordinal_map = torch.zeros(max(self.cmap.keys()) + 1, dtype=torch.long)
|
||||
self.ordinal_cmap = torch.zeros((len(self.classes), 4), dtype=torch.uint8)
|
||||
self.img_dir_name = self.image_root.format(self.sensor)
|
||||
self.mask_dir_name = self.mask_dir_dict[self.sensor].format(self.product)
|
||||
|
||||
self._verify()
|
||||
|
||||
|
@ -169,6 +179,11 @@ class SSL4EOLBenchmark(NonGeoDataset):
|
|||
|
||||
self.sample_collection = [self.sample_collection[idx] for idx in split_indices]
|
||||
|
||||
# Map chosen classes to ordinal numbers, all others mapped to background class
|
||||
for v, k in enumerate(self.classes):
|
||||
self.ordinal_map[k] = v
|
||||
self.ordinal_cmap[v] = torch.tensor(self.cmap[k])
|
||||
|
||||
def _verify(self) -> None:
|
||||
"""Verify the integrity of the dataset.
|
||||
|
||||
|
@ -183,7 +198,7 @@ class SSL4EOLBenchmark(NonGeoDataset):
|
|||
self.root,
|
||||
self.mask_dir_name,
|
||||
"**",
|
||||
f"{self.mask_product}_{self.year_dict[self.input_sensor]}.tif",
|
||||
f"{self.product}_{self.year_dict[self.sensor]}.tif",
|
||||
)
|
||||
exists.append(bool(glob.glob(mask_pathname, recursive=True)))
|
||||
|
||||
|
@ -219,13 +234,13 @@ class SSL4EOLBenchmark(NonGeoDataset):
|
|||
download_url(
|
||||
self.url.format(self.img_dir_name, self.img_dir_name),
|
||||
self.root,
|
||||
md5=self.img_md5s[self.input_sensor] if self.checksum else None,
|
||||
md5=self.img_md5s[self.sensor] if self.checksum else None,
|
||||
)
|
||||
# download mask
|
||||
download_url(
|
||||
self.url.format(self.mask_dir_name, self.mask_dir_name),
|
||||
self.root,
|
||||
md5=self.mask_md5s[self.input_sensor.split("_")[0]][self.mask_product]
|
||||
md5=self.mask_md5s[self.sensor.split("_")[0]][self.product]
|
||||
if self.checksum
|
||||
else None,
|
||||
)
|
||||
|
@ -277,8 +292,7 @@ class SSL4EOLBenchmark(NonGeoDataset):
|
|||
sample_collection: list[tuple[str, str]] = []
|
||||
for img_path in img_paths:
|
||||
mask_path = img_path.replace(self.img_dir_name, self.mask_dir_name).replace(
|
||||
"all_bands.tif",
|
||||
f"{self.mask_product}_{self.year_dict[self.input_sensor]}.tif",
|
||||
"all_bands.tif", f"{self.product}_{self.year_dict[self.sensor]}.tif"
|
||||
)
|
||||
sample_collection.append((img_path, mask_path))
|
||||
return sample_collection
|
||||
|
@ -293,8 +307,8 @@ class SSL4EOLBenchmark(NonGeoDataset):
|
|||
image
|
||||
"""
|
||||
with rasterio.open(path) as src:
|
||||
image = src.read().astype(np.float32)
|
||||
return torch.from_numpy(image)
|
||||
image = torch.from_numpy(src.read()).float()
|
||||
return image
|
||||
|
||||
def _load_mask(self, path: str) -> Tensor:
|
||||
"""Load the mask.
|
||||
|
@ -306,12 +320,8 @@ class SSL4EOLBenchmark(NonGeoDataset):
|
|||
mask
|
||||
"""
|
||||
with rasterio.open(path) as src:
|
||||
mask = src.read()
|
||||
|
||||
for k, v in self.ordinal_label_map[self.mask_product].items():
|
||||
mask[mask == k] = v
|
||||
|
||||
return torch.from_numpy(mask).long()
|
||||
mask = torch.from_numpy(src.read()).long()
|
||||
return self.ordinal_map[mask]
|
||||
|
||||
def plot(
|
||||
self,
|
||||
|
@ -330,34 +340,27 @@ class SSL4EOLBenchmark(NonGeoDataset):
|
|||
a matplotlib Figure with the rendered sample
|
||||
"""
|
||||
ncols = 2
|
||||
image = sample["image"][self.rgb_indices[self.input_sensor]].permute(1, 2, 0)
|
||||
image = image.numpy() / 255
|
||||
image = sample["image"][self.rgb_indices[self.sensor]].permute(1, 2, 0)
|
||||
image = image / 255
|
||||
|
||||
mask = sample["mask"].squeeze(0)
|
||||
|
||||
showing_predictions = "prediction" in sample
|
||||
if showing_predictions:
|
||||
prediction_mask = sample["prediction"].squeeze(0).numpy()
|
||||
pred = sample["prediction"].squeeze(0)
|
||||
ncols = 3
|
||||
|
||||
kwargs = {
|
||||
"cmap": ListedColormap(np.array(list(self.cmap.values())) / 255),
|
||||
"vmin": 0,
|
||||
"vmax": len(self.cmap) - 1,
|
||||
"interpolation": "none",
|
||||
}
|
||||
|
||||
fig, ax = plt.subplots(ncols=ncols, figsize=(4 * ncols, 4))
|
||||
ax[0].imshow(image)
|
||||
ax[0].axis("off")
|
||||
ax[1].imshow(mask, **kwargs)
|
||||
ax[1].imshow(self.ordinal_cmap[mask], interpolation="none")
|
||||
ax[1].axis("off")
|
||||
if show_titles:
|
||||
ax[0].set_title("Image")
|
||||
ax[1].set_title("Mask")
|
||||
|
||||
if showing_predictions:
|
||||
ax[2].imshow(prediction_mask, **kwargs)
|
||||
ax[2].imshow(self.ordinal_cmap[pred], interpolation="none")
|
||||
if show_titles:
|
||||
ax[2].set_title("Prediction")
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче