diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index eefb9849f..cd6085fbf 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -399,6 +399,11 @@ Rwanda Field Boundary .. autoclass:: RwandaFieldBoundary +SatlasPretrain +^^^^^^^^^^^^^^ + +.. autoclass:: SatlasPretrain + Seasonal Contrast ^^^^^^^^^^^^^^^^^ diff --git a/docs/api/datasets/non_geo_datasets.csv b/docs/api/datasets/non_geo_datasets.csv index 560869403..3efd05f92 100644 --- a/docs/api/datasets/non_geo_datasets.csv +++ b/docs/api/datasets/non_geo_datasets.csv @@ -11,7 +11,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands `Kenya Crop Type`_,S,Sentinel-2,"CC-BY-SA-4.0","4,688",7,"3,035x2,016",10,MSI `DeepGlobe Land Cover`_,S,DigitalGlobe +Vivid,-,803,7,"2,448x2,448",0.5,RGB `DFC2022`_,S,Aerial,"CC-BY-4.0","3,981",15,"2,000x2,000",0.5,RGB -`Digital Typhoon`_,"C, R",Himawari,"CC-BY-4.0","189,364",8,512,5000,Infrared +`Digital Typhoon`_,"C, R",Himawari,"CC-BY-4.0","189,364",8,512,5000,Infrared `ETCI2021 Flood Detection`_,S,Sentinel-1,-,"66,810",2,256x256,5--20,SAR `EuroSAT`_,C,Sentinel-2,"MIT","27,000",10,64x64,10,MSI `FAIR1M`_,OD,Gaofen/Google Earth,"CC-BY-NC-SA-3.0","15,000",37,"1,024x1,024",0.3--0.8,RGB @@ -38,6 +38,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands `ReforesTree`_,"OD, R",Aerial,"CC-BY-4.0",100,6,"4,000x4,000",0.02,RGB `RESISC45`_,C,Google Earth,-,"31,500",45,256x256,0.2--30,RGB `Rwanda Field Boundary`_,S,Planetscope,"NICFI AND CC-BY-4.0",70,2,256x256,4.7,RGB + NIR +`SatlasPretrain`_,"C, R, S, I, OD","NAIP, Landsat, Sentinel",ESA AND CC0-1.0 AND ODbL-1.0 AND CC-BY-4.0,302M,137,512,0.6--30,"SAR, MSI" `Seasonal Contrast`_,T,Sentinel-2,"CC-BY-4.0",100K--1M,-,264x264,10,MSI `SeasoNet`_,S,Sentinel-2,"CC-BY-4.0","1,759,830",33,120x120,10,MSI `SEN12MS`_,S,"Sentinel-1/2, MODIS","CC-BY-4.0","180,662",33,256x256,10,"SAR, MSI" diff --git a/tests/data/satlas/data.py b/tests/data/satlas/data.py new file mode 100755 index 000000000..1661f6e42 --- /dev/null +++ b/tests/data/satlas/data.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import json +import os +import shutil + +from PIL import Image + +SIZE = 32 +landsat_size = { + 'b1': SIZE // 2, + 'b2': SIZE // 2, + 'b3': SIZE // 2, + 'b4': SIZE // 2, + 'b5': SIZE // 2, + 'b6': SIZE // 2, + 'b7': SIZE // 2, + 'b8': SIZE, + 'b9': SIZE // 2, + 'b10': SIZE // 2, + 'b11': SIZE // 4, + 'b12': SIZE // 4, +} + +index = [[7149, 3246], [1234, 5678]] +good_images = [ + [7149, 3246, '2022-03'], + [1234, 5678, '2022-03'], + [7149, 3246, 'm_3808245_se_17_1_20110801'], + [1234, 5678, 'm_3808245_se_17_1_20110801'], + [7149, 3246, '2022-01'], + [1234, 5678, '2022-01'], + [7149, 3246, 'S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235'], + [1234, 5678, 'S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235'], +] +times = { + '2022-03': '2022-03-01T00:00:00+00:00', + 'm_3808245_se_17_1_20110801': '2011-08-01T12:00:00+00:00', + '2022-01': '2022-01-01T00:00:00+00:00', + 'S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235': '2022-03-09T06:02:35+00:00', +} + +FILENAME_HIERARCHY = dict[str, 'FILENAME_HIERARCHY'] | list[str] +filenames: FILENAME_HIERARCHY = { + 'landsat': {'2022-03': list(f'b{i}' for i in range(1, 12))}, + 'naip': {'m_3808245_se_17_1_20110801': ['tci', 'ir']}, + 'sentinel1': {'2022-01': ['vh', 'vv']}, + 'sentinel2': { + 'S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235': [ + 'tci', + 'b05', + 'b06', + 'b07', + 'b08', + 'b11', + 'b12', + ] + }, +} + + +def create_files(path: str) -> None: + os.makedirs(path, exist_ok=True) + for col, row in index: + band = os.path.basename(path) + mode = 'RGB' if band == 'tci' else 'L' + size = SIZE + if 'landsat' in path: + size = landsat_size[band] + img = Image.new(mode, (size, size)) + img.save(os.path.join(path, f'{col}_{row}.png')) + + +def create_directory(directory: str, hierarchy: FILENAME_HIERARCHY) -> None: + if isinstance(hierarchy, dict): + # Recursive case + for key, value in hierarchy.items(): + path = os.path.join(directory, key) + create_directory(path, value) + else: + # Base case + for value in hierarchy: + path = os.path.join(directory, value) + create_files(path) + + +if __name__ == '__main__': + create_directory('.', filenames) + + col, row = index[0] + path = os.path.join('static', f'{col}_{row}') + os.makedirs(path, exist_ok=True) + img = Image.new('L', (SIZE, SIZE)) + img.save(os.path.join(path, 'land_cover.png')) + + os.makedirs('metadata', exist_ok=True) + with open(os.path.join('metadata', 'train_lowres.json'), 'w') as f: + json.dump(index, f) + + with open(os.path.join('metadata', 'good_images_lowres_all.json'), 'w') as f: + json.dump(good_images, f) + + with open(os.path.join('metadata', 'image_times.json'), 'w') as f: + json.dump(times, f) + + for path in os.listdir('.'): + if os.path.isdir(path): + shutil.make_archive(path, 'tar', '.', path) diff --git a/tests/data/satlas/landsat.tar b/tests/data/satlas/landsat.tar new file mode 100644 index 000000000..f21ba5980 Binary files /dev/null and b/tests/data/satlas/landsat.tar differ diff --git a/tests/data/satlas/landsat/2022-03/b1/1234_5678.png b/tests/data/satlas/landsat/2022-03/b1/1234_5678.png new file mode 100644 index 000000000..1c5ee4e26 Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b1/1234_5678.png differ diff --git a/tests/data/satlas/landsat/2022-03/b1/7149_3246.png b/tests/data/satlas/landsat/2022-03/b1/7149_3246.png new file mode 100644 index 000000000..1c5ee4e26 Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b1/7149_3246.png differ diff --git a/tests/data/satlas/landsat/2022-03/b10/1234_5678.png b/tests/data/satlas/landsat/2022-03/b10/1234_5678.png new file mode 100644 index 000000000..1c5ee4e26 Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b10/1234_5678.png differ diff --git a/tests/data/satlas/landsat/2022-03/b10/7149_3246.png b/tests/data/satlas/landsat/2022-03/b10/7149_3246.png new file mode 100644 index 000000000..1c5ee4e26 Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b10/7149_3246.png differ diff --git a/tests/data/satlas/landsat/2022-03/b11/1234_5678.png b/tests/data/satlas/landsat/2022-03/b11/1234_5678.png new file mode 100644 index 000000000..a7ff273b8 Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b11/1234_5678.png differ diff --git a/tests/data/satlas/landsat/2022-03/b11/7149_3246.png b/tests/data/satlas/landsat/2022-03/b11/7149_3246.png new file mode 100644 index 000000000..a7ff273b8 Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b11/7149_3246.png differ diff --git a/tests/data/satlas/landsat/2022-03/b2/1234_5678.png b/tests/data/satlas/landsat/2022-03/b2/1234_5678.png new file mode 100644 index 000000000..1c5ee4e26 Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b2/1234_5678.png differ diff --git a/tests/data/satlas/landsat/2022-03/b2/7149_3246.png b/tests/data/satlas/landsat/2022-03/b2/7149_3246.png new file mode 100644 index 000000000..1c5ee4e26 Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b2/7149_3246.png differ diff --git a/tests/data/satlas/landsat/2022-03/b3/1234_5678.png b/tests/data/satlas/landsat/2022-03/b3/1234_5678.png new file mode 100644 index 000000000..1c5ee4e26 Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b3/1234_5678.png differ diff --git a/tests/data/satlas/landsat/2022-03/b3/7149_3246.png b/tests/data/satlas/landsat/2022-03/b3/7149_3246.png new file mode 100644 index 000000000..1c5ee4e26 Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b3/7149_3246.png differ diff --git a/tests/data/satlas/landsat/2022-03/b4/1234_5678.png b/tests/data/satlas/landsat/2022-03/b4/1234_5678.png new file mode 100644 index 000000000..1c5ee4e26 Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b4/1234_5678.png differ diff --git a/tests/data/satlas/landsat/2022-03/b4/7149_3246.png b/tests/data/satlas/landsat/2022-03/b4/7149_3246.png new file mode 100644 index 000000000..1c5ee4e26 Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b4/7149_3246.png differ diff --git a/tests/data/satlas/landsat/2022-03/b5/1234_5678.png b/tests/data/satlas/landsat/2022-03/b5/1234_5678.png new file mode 100644 index 000000000..1c5ee4e26 Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b5/1234_5678.png differ diff --git a/tests/data/satlas/landsat/2022-03/b5/7149_3246.png b/tests/data/satlas/landsat/2022-03/b5/7149_3246.png new file mode 100644 index 000000000..1c5ee4e26 Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b5/7149_3246.png differ diff --git a/tests/data/satlas/landsat/2022-03/b6/1234_5678.png b/tests/data/satlas/landsat/2022-03/b6/1234_5678.png new file mode 100644 index 000000000..1c5ee4e26 Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b6/1234_5678.png differ diff --git a/tests/data/satlas/landsat/2022-03/b6/7149_3246.png b/tests/data/satlas/landsat/2022-03/b6/7149_3246.png new file mode 100644 index 000000000..1c5ee4e26 Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b6/7149_3246.png differ diff --git a/tests/data/satlas/landsat/2022-03/b7/1234_5678.png b/tests/data/satlas/landsat/2022-03/b7/1234_5678.png new file mode 100644 index 000000000..1c5ee4e26 Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b7/1234_5678.png differ diff --git a/tests/data/satlas/landsat/2022-03/b7/7149_3246.png b/tests/data/satlas/landsat/2022-03/b7/7149_3246.png new file mode 100644 index 000000000..1c5ee4e26 Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b7/7149_3246.png differ diff --git a/tests/data/satlas/landsat/2022-03/b8/1234_5678.png b/tests/data/satlas/landsat/2022-03/b8/1234_5678.png new file mode 100644 index 000000000..c1620c855 Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b8/1234_5678.png differ diff --git a/tests/data/satlas/landsat/2022-03/b8/7149_3246.png b/tests/data/satlas/landsat/2022-03/b8/7149_3246.png new file mode 100644 index 000000000..c1620c855 Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b8/7149_3246.png differ diff --git a/tests/data/satlas/landsat/2022-03/b9/1234_5678.png b/tests/data/satlas/landsat/2022-03/b9/1234_5678.png new file mode 100644 index 000000000..1c5ee4e26 Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b9/1234_5678.png differ diff --git a/tests/data/satlas/landsat/2022-03/b9/7149_3246.png b/tests/data/satlas/landsat/2022-03/b9/7149_3246.png new file mode 100644 index 000000000..1c5ee4e26 Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b9/7149_3246.png differ diff --git a/tests/data/satlas/metadata.tar b/tests/data/satlas/metadata.tar new file mode 100644 index 000000000..da55bab05 Binary files /dev/null and b/tests/data/satlas/metadata.tar differ diff --git a/tests/data/satlas/metadata/good_images_lowres_all.json b/tests/data/satlas/metadata/good_images_lowres_all.json new file mode 100644 index 000000000..32f868783 --- /dev/null +++ b/tests/data/satlas/metadata/good_images_lowres_all.json @@ -0,0 +1 @@ +[[7149, 3246, "2022-03"], [1234, 5678, "2022-03"], [7149, 3246, "m_3808245_se_17_1_20110801"], [1234, 5678, "m_3808245_se_17_1_20110801"], [7149, 3246, "2022-01"], [1234, 5678, "2022-01"], [7149, 3246, "S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235"], [1234, 5678, "S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235"]] \ No newline at end of file diff --git a/tests/data/satlas/metadata/image_times.json b/tests/data/satlas/metadata/image_times.json new file mode 100644 index 000000000..9028902e0 --- /dev/null +++ b/tests/data/satlas/metadata/image_times.json @@ -0,0 +1 @@ +{"2022-03": "2022-03-01T00:00:00+00:00", "m_3808245_se_17_1_20110801": "2011-08-01T12:00:00+00:00", "2022-01": "2022-01-01T00:00:00+00:00", "S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235": "2022-03-09T06:02:35+00:00"} \ No newline at end of file diff --git a/tests/data/satlas/metadata/train_lowres.json b/tests/data/satlas/metadata/train_lowres.json new file mode 100644 index 000000000..af40dbffd --- /dev/null +++ b/tests/data/satlas/metadata/train_lowres.json @@ -0,0 +1 @@ +[[7149, 3246], [1234, 5678]] \ No newline at end of file diff --git a/tests/data/satlas/naip.tar b/tests/data/satlas/naip.tar new file mode 100644 index 000000000..c77db0c06 Binary files /dev/null and b/tests/data/satlas/naip.tar differ diff --git a/tests/data/satlas/naip/m_3808245_se_17_1_20110801/ir/1234_5678.png b/tests/data/satlas/naip/m_3808245_se_17_1_20110801/ir/1234_5678.png new file mode 100644 index 000000000..c1620c855 Binary files /dev/null and b/tests/data/satlas/naip/m_3808245_se_17_1_20110801/ir/1234_5678.png differ diff --git a/tests/data/satlas/naip/m_3808245_se_17_1_20110801/ir/7149_3246.png b/tests/data/satlas/naip/m_3808245_se_17_1_20110801/ir/7149_3246.png new file mode 100644 index 000000000..c1620c855 Binary files /dev/null and b/tests/data/satlas/naip/m_3808245_se_17_1_20110801/ir/7149_3246.png differ diff --git a/tests/data/satlas/naip/m_3808245_se_17_1_20110801/tci/1234_5678.png b/tests/data/satlas/naip/m_3808245_se_17_1_20110801/tci/1234_5678.png new file mode 100644 index 000000000..1655bc2ca Binary files /dev/null and b/tests/data/satlas/naip/m_3808245_se_17_1_20110801/tci/1234_5678.png differ diff --git a/tests/data/satlas/naip/m_3808245_se_17_1_20110801/tci/7149_3246.png b/tests/data/satlas/naip/m_3808245_se_17_1_20110801/tci/7149_3246.png new file mode 100644 index 000000000..1655bc2ca Binary files /dev/null and b/tests/data/satlas/naip/m_3808245_se_17_1_20110801/tci/7149_3246.png differ diff --git a/tests/data/satlas/sentinel1.tar b/tests/data/satlas/sentinel1.tar new file mode 100644 index 000000000..755851304 Binary files /dev/null and b/tests/data/satlas/sentinel1.tar differ diff --git a/tests/data/satlas/sentinel1/2022-01/vh/1234_5678.png b/tests/data/satlas/sentinel1/2022-01/vh/1234_5678.png new file mode 100644 index 000000000..c1620c855 Binary files /dev/null and b/tests/data/satlas/sentinel1/2022-01/vh/1234_5678.png differ diff --git a/tests/data/satlas/sentinel1/2022-01/vh/7149_3246.png b/tests/data/satlas/sentinel1/2022-01/vh/7149_3246.png new file mode 100644 index 000000000..c1620c855 Binary files /dev/null and b/tests/data/satlas/sentinel1/2022-01/vh/7149_3246.png differ diff --git a/tests/data/satlas/sentinel1/2022-01/vv/1234_5678.png b/tests/data/satlas/sentinel1/2022-01/vv/1234_5678.png new file mode 100644 index 000000000..c1620c855 Binary files /dev/null and b/tests/data/satlas/sentinel1/2022-01/vv/1234_5678.png differ diff --git a/tests/data/satlas/sentinel1/2022-01/vv/7149_3246.png b/tests/data/satlas/sentinel1/2022-01/vv/7149_3246.png new file mode 100644 index 000000000..c1620c855 Binary files /dev/null and b/tests/data/satlas/sentinel1/2022-01/vv/7149_3246.png differ diff --git a/tests/data/satlas/sentinel2.tar b/tests/data/satlas/sentinel2.tar new file mode 100644 index 000000000..aa3122a90 Binary files /dev/null and b/tests/data/satlas/sentinel2.tar differ diff --git a/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b05/1234_5678.png b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b05/1234_5678.png new file mode 100644 index 000000000..c1620c855 Binary files /dev/null and b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b05/1234_5678.png differ diff --git a/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b05/7149_3246.png b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b05/7149_3246.png new file mode 100644 index 000000000..c1620c855 Binary files /dev/null and b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b05/7149_3246.png differ diff --git a/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b06/1234_5678.png b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b06/1234_5678.png new file mode 100644 index 000000000..c1620c855 Binary files /dev/null and b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b06/1234_5678.png differ diff --git a/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b06/7149_3246.png b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b06/7149_3246.png new file mode 100644 index 000000000..c1620c855 Binary files /dev/null and b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b06/7149_3246.png differ diff --git a/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b07/1234_5678.png b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b07/1234_5678.png new file mode 100644 index 000000000..c1620c855 Binary files /dev/null and b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b07/1234_5678.png differ diff --git a/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b07/7149_3246.png b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b07/7149_3246.png new file mode 100644 index 000000000..c1620c855 Binary files /dev/null and b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b07/7149_3246.png differ diff --git a/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b08/1234_5678.png b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b08/1234_5678.png new file mode 100644 index 000000000..c1620c855 Binary files /dev/null and b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b08/1234_5678.png differ diff --git a/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b08/7149_3246.png b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b08/7149_3246.png new file mode 100644 index 000000000..c1620c855 Binary files /dev/null and b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b08/7149_3246.png differ diff --git a/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b11/1234_5678.png b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b11/1234_5678.png new file mode 100644 index 000000000..c1620c855 Binary files /dev/null and b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b11/1234_5678.png differ diff --git a/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b11/7149_3246.png b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b11/7149_3246.png new file mode 100644 index 000000000..c1620c855 Binary files /dev/null and b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b11/7149_3246.png differ diff --git a/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b12/1234_5678.png b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b12/1234_5678.png new file mode 100644 index 000000000..c1620c855 Binary files /dev/null and b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b12/1234_5678.png differ diff --git a/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b12/7149_3246.png b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b12/7149_3246.png new file mode 100644 index 000000000..c1620c855 Binary files /dev/null and b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b12/7149_3246.png differ diff --git a/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/tci/1234_5678.png b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/tci/1234_5678.png new file mode 100644 index 000000000..1655bc2ca Binary files /dev/null and b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/tci/1234_5678.png differ diff --git a/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/tci/7149_3246.png b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/tci/7149_3246.png new file mode 100644 index 000000000..1655bc2ca Binary files /dev/null and b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/tci/7149_3246.png differ diff --git a/tests/data/satlas/static.tar b/tests/data/satlas/static.tar new file mode 100644 index 000000000..decccd5ca Binary files /dev/null and b/tests/data/satlas/static.tar differ diff --git a/tests/data/satlas/static/7149_3246/land_cover.png b/tests/data/satlas/static/7149_3246/land_cover.png new file mode 100644 index 000000000..c1620c855 Binary files /dev/null and b/tests/data/satlas/static/7149_3246/land_cover.png differ diff --git a/tests/datasets/test_satlas.py b/tests/datasets/test_satlas.py new file mode 100644 index 000000000..7c10f55bd --- /dev/null +++ b/tests/datasets/test_satlas.py @@ -0,0 +1,59 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +import shutil +from pathlib import Path + +import matplotlib.pyplot as plt +import pytest +import torch.nn as nn +from pytest import MonkeyPatch +from torch import Tensor + +from torchgeo.datasets import DatasetNotFoundError, SatlasPretrain +from torchgeo.datasets.utils import Executable + + +class TestSatlasPretrain: + @pytest.fixture + def dataset( + self, aws: Executable, monkeypatch: MonkeyPatch, tmp_path: Path + ) -> SatlasPretrain: + url = os.path.join('tests', 'data', 'satlas', '') + monkeypatch.setattr(SatlasPretrain, 'url', url) + images = ('landsat', 'naip', 'sentinel1', 'sentinel2') + products = (*images, 'static', 'metadata') + tarballs = {product: (f'{product}.tar',) for product in products} + monkeypatch.setattr(SatlasPretrain, 'tarballs', tarballs) + transforms = nn.Identity() + return SatlasPretrain( + tmp_path, images=images, transforms=transforms, download=True + ) + + @pytest.mark.parametrize('index', [0, 1]) + def test_getitem(self, dataset: SatlasPretrain, index: int) -> None: + x = dataset[index] + assert isinstance(x, dict) + for image in dataset.images: + assert isinstance(x[f'image_{image}'], Tensor) + assert isinstance(x[f'time_{image}'], Tensor) + for label in dataset.labels: + assert isinstance(x[f'mask_{label}'], Tensor) + + def test_len(self, dataset: SatlasPretrain) -> None: + assert len(dataset) == 2 + + def test_already_downloaded(self, dataset: SatlasPretrain) -> None: + shutil.rmtree(os.path.join(dataset.root, 'landsat')) + SatlasPretrain(root=dataset.root, download=True) + + def test_not_downloaded(self, tmp_path: Path) -> None: + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): + SatlasPretrain(tmp_path) + + def test_plot(self, dataset: SatlasPretrain) -> None: + x = dataset[0] + x['prediction_land_cover'] = x['mask_land_cover'] + dataset.plot(x, suptitle='Test') + plt.close() diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index 663b08e7c..981cf2b26 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -99,6 +99,7 @@ from .quakeset import QuakeSet from .reforestree import ReforesTree from .resisc45 import RESISC45 from .rwanda_field_boundary import RwandaFieldBoundary +from .satlas import SatlasPretrain from .seasonet import SeasoNet from .seco import SeasonalContrastS2 from .sen12ms import SEN12MS @@ -244,6 +245,7 @@ __all__ = ( 'RESISC45', 'ReforesTree', 'RwandaFieldBoundary', + 'SatlasPretrain', 'SeasonalContrastS2', 'SeasoNet', 'SEN12MS', diff --git a/torchgeo/datasets/satlas.py b/torchgeo/datasets/satlas.py new file mode 100644 index 000000000..f52a7fef9 --- /dev/null +++ b/torchgeo/datasets/satlas.py @@ -0,0 +1,770 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""SatlasPretrain dataset.""" + +import os +from collections.abc import Callable, Iterable +from typing import ClassVar, TypedDict + +import numpy as np +import pandas as pd +import torch +from einops import rearrange +from matplotlib import pyplot as plt +from matplotlib.figure import Figure +from PIL import Image +from torch import Tensor + +from .errors import DatasetNotFoundError +from .geo import NonGeoDataset +from .utils import Path, check_integrity, extract_archive, which + + +class _Task(TypedDict, total=False): + BackgroundInvalid: bool + categories: list[str] + colors: list[list[int]] + type: str + + +# https://github.com/allenai/satlas/blob/main/satlas/model/dataset.py +TASKS: dict[str, _Task] = { + 'polyline_bin_segment': { + 'type': 'bin_segment', + 'categories': [ + 'airport_runway', + 'airport_taxiway', + 'raceway', + 'road', + 'railway', + 'river', + ], + 'colors': [ + [255, 255, 255], # (white) airport_runway + [192, 192, 192], # (light grey) airport_taxiway + [160, 82, 45], # (sienna) raceway + [255, 255, 255], # (white) road + [144, 238, 144], # (light green) railway + [0, 0, 255], # (blue) river + ], + }, + 'bin_segment': { + 'type': 'bin_segment', + 'categories': [ + 'aquafarm', + 'lock', + 'dam', + 'solar_farm', + 'power_plant', + 'gas_station', + 'park', + 'parking_garage', + 'parking_lot', + 'landfill', + 'quarry', + 'stadium', + 'airport', + 'airport_runway', + 'airport_taxiway', + 'airport_apron', + 'airport_hangar', + 'airstrip', + 'airport_terminal', + 'ski_resort', + 'theme_park', + 'storage_tank', + 'silo', + 'track', + 'raceway', + 'wastewater_plant', + 'road', + 'railway', + 'river', + 'water_park', + 'pier', + 'water_tower', + 'street_lamp', + 'traffic_signals', + 'power_tower', + 'power_substation', + 'building', + 'bridge', + 'road_motorway', + 'road_trunk', + 'road_primary', + 'road_secondary', + 'road_tertiary', + 'road_residential', + 'road_service', + 'road_track', + 'road_pedestrian', + ], + 'colors': [ + [32, 178, 170], # (light sea green) aquafarm + [0, 255, 255], # (cyan) lock + [173, 216, 230], # (light blue) dam + [255, 0, 255], # (magenta) solar farm + [255, 165, 0], # (orange) power plant + [128, 128, 0], # (olive) gas station + [0, 255, 0], # (green) park + [47, 79, 79], # (dark slate gray) parking garage + [128, 0, 0], # (maroon) parking lot + [165, 42, 42], # (brown) landfill + [128, 128, 128], # (grey) quarry + [255, 215, 0], # (gold) stadium + [255, 105, 180], # (pink) airport + [255, 255, 255], # (white) airport_runway + [192, 192, 192], # (light grey) airport_taxiway + [128, 0, 128], # (purple) airport_apron + [0, 128, 0], # (dark green) airport_hangar + [248, 248, 255], # (ghost white) airstrip + [240, 230, 140], # (khaki) airport_terminal + [192, 192, 192], # (silver) ski_resort + [0, 96, 0], # (dark green) theme_park + [95, 158, 160], # (cadet blue) storage_tank + [205, 133, 63], # (peru) silo + [154, 205, 50], # (yellow green) track + [160, 82, 45], # (sienna) raceway + [218, 112, 214], # (orchid) wastewater_plant + [255, 255, 255], # (white) road + [144, 238, 144], # (light green) railway + [0, 0, 255], # (blue) river + [255, 240, 245], # (lavender blush) water_park + [65, 105, 225], # (royal blue) pier + [238, 130, 238], # (violet) water_tower + [75, 0, 130], # (indigo) street_lamp + [233, 150, 122], # (dark salmon) traffic_signals + [255, 255, 0], # (yellow) power_tower + [255, 255, 0], # (yellow) power_substation + [255, 0, 0], # (red) building + [64, 64, 64], # (dark grey) bridge + [255, 255, 255], # (white) road_motorway + [255, 255, 255], # (white) road_trunk + [255, 255, 255], # (white) road_primary + [255, 255, 255], # (white) road_secondary + [255, 255, 255], # (white) road_tertiary + [255, 255, 255], # (white) road_residential + [255, 255, 255], # (white) road_service + [255, 255, 255], # (white) road_track + [255, 255, 255], # (white) road_pedestrian + ], + }, + 'land_cover': { + 'type': 'segment', + 'BackgroundInvalid': True, + 'categories': [ + 'background', + 'water', + 'developed', + 'tree', + 'shrub', + 'grass', + 'crop', + 'bare', + 'snow', + 'wetland', + 'mangroves', + 'moss', + ], + 'colors': [ + [0, 0, 0], # unknown + [0, 0, 255], # (blue) water + [255, 0, 0], # (red) developed + [0, 192, 0], # (dark green) tree + [200, 170, 120], # (brown) shrub + [0, 255, 0], # (green) grass + [255, 255, 0], # (yellow) crop + [128, 128, 128], # (grey) bare + [255, 255, 255], # (white) snow + [0, 255, 255], # (cyan) wetland + [255, 0, 255], # (pink) mangroves + [128, 0, 128], # (purple) moss + ], + }, + 'tree_cover': {'type': 'regress', 'BackgroundInvalid': True}, + 'crop_type': { + 'type': 'segment', + 'BackgroundInvalid': True, + 'categories': [ + 'invalid', + 'rice', + 'grape', + 'corn', + 'sugarcane', + 'tea', + 'hop', + 'wheat', + 'soy', + 'barley', + 'oats', + 'rye', + 'cassava', + 'potato', + 'sunflower', + 'asparagus', + 'coffee', + ], + 'colors': [ + [0, 0, 0], # unknown + [0, 0, 255], # (blue) rice + [255, 0, 0], # (red) grape + [255, 255, 0], # (yellow) corn + [0, 255, 0], # (green) sugarcane + [128, 0, 128], # (purple) tea + [255, 0, 255], # (pink) hop + [0, 128, 0], # (dark green) wheat + [255, 255, 255], # (white) soy + [128, 128, 128], # (grey) barley + [165, 42, 42], # (brown) oats + [0, 255, 255], # (cyan) rye + [128, 0, 0], # (maroon) cassava + [173, 216, 230], # (light blue) potato + [128, 128, 0], # (olive) sunflower + [0, 128, 0], # (dark green) asparagus + [92, 64, 51], # (dark brown) coffee + ], + }, + 'point': { + 'type': 'detect', + 'categories': [ + 'background', + 'wind_turbine', + 'lighthouse', + 'mineshaft', + 'aerialway_pylon', + 'helipad', + 'fountain', + 'toll_booth', + 'chimney', + 'communications_tower', + 'flagpole', + 'petroleum_well', + 'water_tower', + 'offshore_wind_turbine', + 'offshore_platform', + 'power_tower', + ], + 'colors': [ + [0, 0, 0], + [0, 255, 255], # (cyan) wind_turbine + [0, 255, 0], # (green) lighthouse + [255, 255, 0], # (yellow) mineshaft + [0, 0, 255], # (blue) pylon + [173, 216, 230], # (light blue) helipad + [128, 0, 128], # (purple) fountain + [255, 255, 255], # (white) toll_booth + [0, 128, 0], # (dark green) chimney + [128, 128, 128], # (grey) communications_tower + [165, 42, 42], # (brown) flagpole + [128, 0, 0], # (maroon) petroleum_well + [255, 165, 0], # (orange) water_tower + [255, 255, 0], # (yellow) offshore_wind_turbine + [255, 0, 0], # (red) offshore_platform + [255, 0, 255], # (magenta) power_tower + ], + }, + 'rooftop_solar_panel': { + 'type': 'detect', + 'categories': ['background', 'rooftop_solar_panel'], + 'colors': [ + [0, 0, 0], + [255, 255, 0], # (yellow) rooftop_solar_panel + ], + }, + 'building': { + 'type': 'instance', + 'categories': ['background', 'ms_building'], + 'colors': [ + [0, 0, 0], + [255, 255, 0], # (yellow) building + ], + }, + 'polygon': { + 'type': 'instance', + 'categories': [ + 'background', + 'aquafarm', + 'lock', + 'dam', + 'solar_farm', + 'power_plant', + 'gas_station', + 'park', + 'parking_garage', + 'parking_lot', + 'landfill', + 'quarry', + 'stadium', + 'airport', + 'airport_apron', + 'airport_hangar', + 'airport_terminal', + 'ski_resort', + 'theme_park', + 'storage_tank', + 'silo', + 'track', + 'wastewater_plant', + 'power_substation', + 'pier', + 'crop', + 'water_park', + ], + 'colors': [ + [0, 0, 0], + [255, 255, 0], # (yellow) aquafarm + [0, 255, 255], # (cyan) lock + [0, 255, 0], # (green) dam + [0, 0, 255], # (blue) solar_farm + [255, 0, 0], # (red) power_plant + [128, 0, 128], # (purple) gas_station + [255, 255, 255], # (white) park + [0, 128, 0], # (dark green) parking_garage + [128, 128, 128], # (grey) parking_lot + [165, 42, 42], # (brown) landfill + [128, 0, 0], # (maroon) quarry + [255, 165, 0], # (orange) stadium + [255, 105, 180], # (pink) airport + [192, 192, 192], # (silver) airport_apron + [173, 216, 230], # (light blue) airport_hangar + [32, 178, 170], # (light sea green) airport_terminal + [255, 0, 255], # (magenta) ski_resort + [128, 128, 0], # (olive) theme_park + [47, 79, 79], # (dark slate gray) storage_tank + [255, 215, 0], # (gold) silo + [192, 192, 192], # (light grey) track + [240, 230, 140], # (khaki) wastewater_plant + [154, 205, 50], # (yellow green) power_substation + [255, 165, 0], # (orange) pier + [0, 192, 0], # (middle green) crop + [0, 192, 0], # (middle green) water_park + ], + }, + 'wildfire': { + 'type': 'bin_segment', + 'categories': ['fire_retardant', 'burned'], + 'colors': [ + [255, 0, 0], # (red) fire retardant + [128, 128, 128], # (grey) burned area + ], + }, + 'smoke': {'type': 'classification', 'categories': ['no', 'partial', 'yes']}, + 'snow': {'type': 'classification', 'categories': ['no', 'partial', 'yes']}, + 'dem': {'type': 'regress', 'BackgroundInvalid': True}, + 'airplane': { + 'type': 'detect', + 'categories': ['background', 'airplane'], + 'colors': [ + [0, 0, 0], # (black) background + [255, 0, 0], # (red) airplane + ], + }, + 'vessel': { + 'type': 'detect', + 'categories': ['background', 'vessel'], + 'colors': [ + [0, 0, 0], # (black) background + [255, 0, 0], # (red) vessel + ], + }, + 'water_event': { + 'type': 'segment', + 'BackgroundInvalid': True, + 'categories': ['invalid', 'background', 'water_event'], + 'colors': [ + [0, 0, 0], # (black) invalid + [0, 255, 0], # (green) background + [0, 0, 255], # (blue) water_event + ], + }, + 'park_sport': { + 'type': 'classification', + 'categories': [ + 'american_football', + 'badminton', + 'baseball', + 'basketball', + 'cricket', + 'rugby', + 'soccer', + 'tennis', + 'volleyball', + ], + }, + 'park_type': { + 'type': 'classification', + 'categories': ['park', 'pitch', 'golf_course', 'cemetery'], + }, + 'power_plant_type': { + 'type': 'classification', + 'categories': ['oil', 'nuclear', 'coal', 'gas'], + }, + 'quarry_resource': { + 'type': 'classification', + 'categories': ['sand', 'gravel', 'clay', 'coal', 'peat'], + }, + 'track_sport': { + 'type': 'classification', + 'categories': ['running', 'cycling', 'horse'], + }, + 'road_type': { + 'type': 'classification', + 'categories': [ + 'motorway', + 'trunk', + 'primary', + 'secondary', + 'tertiary', + 'residential', + 'service', + 'track', + 'pedestrian', + ], + }, + 'cloud': { + 'type': 'bin_segment', + 'categories': ['background', 'cloud', 'shadow'], + 'colors': [ + [0, 255, 0], # (green) not clouds or shadows + [255, 255, 255], # (white) clouds + [128, 128, 128], # (grey) shadows + ], + 'BackgroundInvalid': True, + }, + 'flood': { + 'type': 'bin_segment', + 'categories': ['background', 'water'], + 'colors': [ + [0, 255, 0], # (green) background + [0, 0, 255], # (blue) water + ], + 'BackgroundInvalid': True, + }, +} + + +class SatlasPretrain(NonGeoDataset): + """SatlasPretrain dataset. + + `SatlasPretrain `_ is a large-scale pre-training + dataset for tasks that involve understanding satellite images. Regularly-updated + satellite data is publicly available for much of the Earth through sources such as + Sentinel-2 and NAIP, and can inform numerous applications from tackling illegal + deforestation to monitoring marine infrastructure. However, developing automatic + computer vision systems to parse these images requires a huge amount of manual + labeling of training data. By combining over 30 TB of satellite images with 137 + label categories, SatlasPretrain serves as an effective pre-training dataset that + greatly reduces the effort needed to develop robust models for downstream satellite + image applications. + + Reference implementation: + + * https://github.com/allenai/satlas/blob/main/satlas/model/dataset.py + + If you use this dataset in your research, please cite the following paper: + + * https://doi.org/10.48550/arXiv.2211.15660 + + .. versionadded:: 0.7 + + .. note:: + This dataset requires the following additional library to be installed: + + * `AWS CLI `_: to download the dataset from AWS. + """ + + # https://github.com/allenai/satlas/blob/main/satlaspretrain_urls.txt + url = 's3://ai2-public-datasets/satlas/' + tarballs: ClassVar[dict[str, tuple[str, ...]]] = { + 'landsat': ('satlas-dataset-v1-landsat.tar',), + 'naip': ( + 'satlas-dataset-v1-naip-2011.tar', + 'satlas-dataset-v1-naip-2012.tar', + 'satlas-dataset-v1-naip-2013.tar', + 'satlas-dataset-v1-naip-2014.tar', + 'satlas-dataset-v1-naip-2015.tar', + 'satlas-dataset-v1-naip-2016.tar', + 'satlas-dataset-v1-naip-2017.tar', + 'satlas-dataset-v1-naip-2018.tar', + 'satlas-dataset-v1-naip-2019.tar', + 'satlas-dataset-v1-naip-2020.tar', + ), + 'sentinel1': ('satlas-dataset-v1-sentinel1-new.tar',), + 'sentinel2': ( + 'satlas-dataset-v1-sentinel2-a.tar', + 'satlas-dataset-v1-sentinel2-b.tar', + ), + 'static': ('satlas-dataset-v1-labels-static.tar',), + 'dynamic': ('satlas-dataset-v1-labels-dynamic.tar',), + 'metadata': ('satlas-dataset-v1-metadata.tar',), + } + md5s: ClassVar[dict[str, tuple[str, ...]]] = { + 'landsat': ('89ea5e8974826c071908392827780a06',), + 'naip': ( + '523736842994861054f04b97c4d90bfb', + '636b9a3b08be0e40d098cb7b5e655b57', + '69e2b1052b1d2d465322a24cf7207a16', + '38999aea424d403ad60e1398443636aa', + '97f4855072a8a406a4bfbe94c5f7311c', + '9ba3c626b23e6d26749a323eaedc7c0a', + 'e4aba3d198dedfe1524a9338e85794aa', + '74191a36d841b0b9b5d5cbae9a92ad71', + '55b110cc6f734bf88793306d49f1c415', + '97fc8414334987c59593d574f112a77e', + ), + 'sentinel1': ('3d88a0a10df6ab0aa50db2ba4c475048',), + 'sentinel2': ( + '7e1c6a1e322807fb11df8c0c062545ca', + '6636b8ecf2fff1d6723ecfef55a4876d', + ), + 'static': ('4e38c2573bc78cf1f0d7267e432cb42c',), + 'dynamic': ('4503ae687948e7d2cb7ade0083f77a8a',), + 'metadata': ('6b9ac5a4f9a1ee88a271d28f12854607',), + } + + # NOTE: 'tci' is RGB (b04-b02), not BGR (b02-b04) + bands: ClassVar[dict[str, tuple[str, ...]]] = { + 'landsat': tuple(f'b{i}' for i in range(1, 12)), + 'naip': ('tci', 'ir'), + 'sentinel1': ('vh', 'vv'), + 'sentinel2': ('tci', 'b05', 'b06', 'b07', 'b08', 'b11', 'b12'), + } + + chip_size = 512 + + def __init__( + self, + root: Path = 'data', + split: str = 'train_lowres', + good_images: str = 'good_images_lowres_all', + image_times: str = 'image_times', + images: Iterable[str] = ('sentinel1', 'sentinel2', 'landsat'), + labels: Iterable[str] = ('land_cover',), + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + download: bool = False, + checksum: bool = False, + ) -> None: + """Initialize a new SatlasPretrain instance. + + Args: + root: Root directory where dataset can be found. + split: Metadata split to load. + good_images: Metadata mapping between col/row and directory. + image_times: Metadata mapping between directory and ISO time. + images: List of image products. + labels: List of label products. + transforms: A function/transform that takes input sample and its target as + entry and returns a transformed version. + download: If True, download dataset and store it in the root directory. + checksum: If True, check the MD5 of the downloaded files (may be slow). + + Raises: + AssertionError: If *images* is invalid. + DatasetNotFoundError: If dataset is not found and *download* is False. + """ + assert set(images) <= set(self.bands.keys()) + + self.root = root + self.images = images + self.labels = labels + self.transforms = transforms + self.download = download + self.checksum = checksum + + self._verify() + + # Read metadata files + self.split = pd.read_json( + os.path.join(root, 'metadata', f'{split}.json'), typ='frame' + ) + self.good_images = pd.read_json( + os.path.join(root, 'metadata', f'{good_images}.json'), typ='frame' + ) + self.image_times = pd.read_json( + os.path.join(root, 'metadata', f'{image_times}.json'), typ='series' + ) + + self.split.columns = ['col', 'row'] + self.good_images.columns = ['col', 'row', 'directory'] + self.good_images = self.good_images.groupby(['col', 'row']) + + def __len__(self) -> int: + """Return the number of locations in the dataset. + + Returns: + Length of the dataset + """ + return len(self.split) + + def __getitem__(self, index: int) -> dict[str, Tensor]: + """Return an index within the dataset. + + Args: + index: Index to return. + + Returns: + Data and label at that index. + """ + col, row = self.split.iloc[index] + directories = self.good_images.get_group((col, row))['directory'] + + sample: dict[str, Tensor] = {} + + for image in self.images: + self._load_image(sample, image, col, row, directories) + + for label in self.labels: + self._load_label(sample, label, col, row) + + if self.transforms is not None: + sample = self.transforms(sample) + + return sample + + def _load_image( + self, + sample: dict[str, Tensor], + image: str, + col: int, + row: int, + directories: pd.Series, + ) -> None: + """Load a single image. + + Args: + sample: Dataset sample to populate. + image: Image product. + col: Web Mercator column. + row: Web Mercator row. + directories: Directories that may contain the image. + """ + # Moved in PIL 9.1.0 + try: + resample = Image.Resampling.BILINEAR + except AttributeError: + resample = Image.BILINEAR # type: ignore[attr-defined] + + # Find directories that match image product + good_directories: list[str] = [] + for directory in directories: + path = os.path.join(self.root, image, directory) + if os.path.isdir(path): + good_directories.append(directory) + + # Choose a random timestamp + idx = torch.randint(len(good_directories), (1,)) + directory = good_directories[idx] + time = self.image_times[directory].timestamp() + sample[f'time_{image}'] = torch.tensor(time) + + # Load all bands + channels = [] + for band in self.bands[image]: + path = os.path.join(self.root, image, directory, band, f'{col}_{row}.png') + with Image.open(path) as img: + img = img.resize((self.chip_size, self.chip_size), resample=resample) + array = np.atleast_3d(np.array(img, dtype=np.float32)) + channels.append(torch.tensor(array)) + raster = rearrange(torch.cat(channels, dim=-1), 'h w c -> c h w') + sample[f'image_{image}'] = raster + + def _load_label( + self, sample: dict[str, Tensor], label: str, col: int, row: int + ) -> None: + """Load a single label. + + Args: + sample: Dataset sample to populate. + label: Label product. + col: Web Mercator column. + row: Web Mercator row. + """ + path = os.path.join(self.root, 'static', f'{col}_{row}', f'{label}.png') + if os.path.isfile(path): + with Image.open(path) as img: + raster = torch.tensor(np.array(img, dtype=np.int64)) + else: + raster = torch.zeros(self.chip_size, self.chip_size, dtype=torch.long) + sample[f'mask_{label}'] = raster + + def _verify(self) -> None: + """Verify the integrity of the dataset.""" + products = [*self.images, 'metadata'] + if self.labels: + products.append('static') + + for product in products: + # Check if the extracted directory already exists + if os.path.isdir(os.path.join(self.root, product)): + continue + + tarballs = self.tarballs[product] + md5s = self.md5s[product] + for tarball, md5 in zip(tarballs, md5s): + path = os.path.join(self.root, tarball) + + # Check if the tarball has already been downloaded + if os.path.isfile(path): + extract_archive(path) + continue + + # Check if the user requested to download the dataset + if not self.download: + raise DatasetNotFoundError(self) + + # Download and extract the tarball + aws = which('aws') + aws('s3', 'cp', self.url + tarball, self.root) + check_integrity(path, md5 if self.checksum else None) + extract_archive(path) + + def plot( + self, + sample: dict[str, Tensor], + show_titles: bool = True, + suptitle: str | None = None, + ) -> Figure: + """Plot a sample from the dataset. + + Args: + sample: A sample returned by :meth:`__getitem__`. + show_titles: Flag indicating whether to show titles above each panel. + suptitle: Optional string to use as a suptitle. + + Returns: + A matplotlib Figure with the rendered sample. + """ + images = [] + titles = [] + for key, value in sample.items(): + match key.split('_', 1): + case ['image', 'landsat']: + images.append(rearrange(value[[3, 2, 1]], 'c h w -> h w c') / 255) + titles.append('Landsat 8/9') + case ['image', 'naip']: + images.append(rearrange(value[:3], 'c h w -> h w c') / 255) + titles.append('NAIP') + case ['image', 'sentinel1']: + images.extend([value[0] / 255, value[1] / 255]) + titles.extend(['Sentinel-1 VH', 'Sentinel-1 VV']) + case ['image', 'sentinel2']: + images.append(rearrange(value[:3], 'c h w -> h w c') / 255) + titles.append('Sentinel-2') + case ['mask' | 'prediction', label]: + cmap = torch.tensor(TASKS[label]['colors']) + images.append(cmap[value]) + titles.append(label.replace('_', ' ').capitalize()) + + fig, ax = plt.subplots(ncols=len(images), squeeze=False) + for i, (image, title) in enumerate(zip(images, titles)): + ax[0, i].imshow(image) + ax[0, i].axis('off') + + if show_titles: + ax[0, i].set_title(title) + + if suptitle is not None: + fig.suptitle(suptitle) + + return fig