diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst
index eefb9849f..cd6085fbf 100644
--- a/docs/api/datasets.rst
+++ b/docs/api/datasets.rst
@@ -399,6 +399,11 @@ Rwanda Field Boundary
.. autoclass:: RwandaFieldBoundary
+SatlasPretrain
+^^^^^^^^^^^^^^
+
+.. autoclass:: SatlasPretrain
+
Seasonal Contrast
^^^^^^^^^^^^^^^^^
diff --git a/docs/api/datasets/non_geo_datasets.csv b/docs/api/datasets/non_geo_datasets.csv
index 560869403..3efd05f92 100644
--- a/docs/api/datasets/non_geo_datasets.csv
+++ b/docs/api/datasets/non_geo_datasets.csv
@@ -11,7 +11,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands
`Kenya Crop Type`_,S,Sentinel-2,"CC-BY-SA-4.0","4,688",7,"3,035x2,016",10,MSI
`DeepGlobe Land Cover`_,S,DigitalGlobe +Vivid,-,803,7,"2,448x2,448",0.5,RGB
`DFC2022`_,S,Aerial,"CC-BY-4.0","3,981",15,"2,000x2,000",0.5,RGB
-`Digital Typhoon`_,"C, R",Himawari,"CC-BY-4.0","189,364",8,512,5000,Infrared
+`Digital Typhoon`_,"C, R",Himawari,"CC-BY-4.0","189,364",8,512,5000,Infrared
`ETCI2021 Flood Detection`_,S,Sentinel-1,-,"66,810",2,256x256,5--20,SAR
`EuroSAT`_,C,Sentinel-2,"MIT","27,000",10,64x64,10,MSI
`FAIR1M`_,OD,Gaofen/Google Earth,"CC-BY-NC-SA-3.0","15,000",37,"1,024x1,024",0.3--0.8,RGB
@@ -38,6 +38,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands
`ReforesTree`_,"OD, R",Aerial,"CC-BY-4.0",100,6,"4,000x4,000",0.02,RGB
`RESISC45`_,C,Google Earth,-,"31,500",45,256x256,0.2--30,RGB
`Rwanda Field Boundary`_,S,Planetscope,"NICFI AND CC-BY-4.0",70,2,256x256,4.7,RGB + NIR
+`SatlasPretrain`_,"C, R, S, I, OD","NAIP, Landsat, Sentinel",ESA AND CC0-1.0 AND ODbL-1.0 AND CC-BY-4.0,302M,137,512,0.6--30,"SAR, MSI"
`Seasonal Contrast`_,T,Sentinel-2,"CC-BY-4.0",100K--1M,-,264x264,10,MSI
`SeasoNet`_,S,Sentinel-2,"CC-BY-4.0","1,759,830",33,120x120,10,MSI
`SEN12MS`_,S,"Sentinel-1/2, MODIS","CC-BY-4.0","180,662",33,256x256,10,"SAR, MSI"
diff --git a/tests/data/satlas/data.py b/tests/data/satlas/data.py
new file mode 100755
index 000000000..1661f6e42
--- /dev/null
+++ b/tests/data/satlas/data.py
@@ -0,0 +1,111 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import json
+import os
+import shutil
+
+from PIL import Image
+
+SIZE = 32
+landsat_size = {
+ 'b1': SIZE // 2,
+ 'b2': SIZE // 2,
+ 'b3': SIZE // 2,
+ 'b4': SIZE // 2,
+ 'b5': SIZE // 2,
+ 'b6': SIZE // 2,
+ 'b7': SIZE // 2,
+ 'b8': SIZE,
+ 'b9': SIZE // 2,
+ 'b10': SIZE // 2,
+ 'b11': SIZE // 4,
+ 'b12': SIZE // 4,
+}
+
+index = [[7149, 3246], [1234, 5678]]
+good_images = [
+ [7149, 3246, '2022-03'],
+ [1234, 5678, '2022-03'],
+ [7149, 3246, 'm_3808245_se_17_1_20110801'],
+ [1234, 5678, 'm_3808245_se_17_1_20110801'],
+ [7149, 3246, '2022-01'],
+ [1234, 5678, '2022-01'],
+ [7149, 3246, 'S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235'],
+ [1234, 5678, 'S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235'],
+]
+times = {
+ '2022-03': '2022-03-01T00:00:00+00:00',
+ 'm_3808245_se_17_1_20110801': '2011-08-01T12:00:00+00:00',
+ '2022-01': '2022-01-01T00:00:00+00:00',
+ 'S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235': '2022-03-09T06:02:35+00:00',
+}
+
+FILENAME_HIERARCHY = dict[str, 'FILENAME_HIERARCHY'] | list[str]
+filenames: FILENAME_HIERARCHY = {
+ 'landsat': {'2022-03': list(f'b{i}' for i in range(1, 12))},
+ 'naip': {'m_3808245_se_17_1_20110801': ['tci', 'ir']},
+ 'sentinel1': {'2022-01': ['vh', 'vv']},
+ 'sentinel2': {
+ 'S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235': [
+ 'tci',
+ 'b05',
+ 'b06',
+ 'b07',
+ 'b08',
+ 'b11',
+ 'b12',
+ ]
+ },
+}
+
+
+def create_files(path: str) -> None:
+ os.makedirs(path, exist_ok=True)
+ for col, row in index:
+ band = os.path.basename(path)
+ mode = 'RGB' if band == 'tci' else 'L'
+ size = SIZE
+ if 'landsat' in path:
+ size = landsat_size[band]
+ img = Image.new(mode, (size, size))
+ img.save(os.path.join(path, f'{col}_{row}.png'))
+
+
+def create_directory(directory: str, hierarchy: FILENAME_HIERARCHY) -> None:
+ if isinstance(hierarchy, dict):
+ # Recursive case
+ for key, value in hierarchy.items():
+ path = os.path.join(directory, key)
+ create_directory(path, value)
+ else:
+ # Base case
+ for value in hierarchy:
+ path = os.path.join(directory, value)
+ create_files(path)
+
+
+if __name__ == '__main__':
+ create_directory('.', filenames)
+
+ col, row = index[0]
+ path = os.path.join('static', f'{col}_{row}')
+ os.makedirs(path, exist_ok=True)
+ img = Image.new('L', (SIZE, SIZE))
+ img.save(os.path.join(path, 'land_cover.png'))
+
+ os.makedirs('metadata', exist_ok=True)
+ with open(os.path.join('metadata', 'train_lowres.json'), 'w') as f:
+ json.dump(index, f)
+
+ with open(os.path.join('metadata', 'good_images_lowres_all.json'), 'w') as f:
+ json.dump(good_images, f)
+
+ with open(os.path.join('metadata', 'image_times.json'), 'w') as f:
+ json.dump(times, f)
+
+ for path in os.listdir('.'):
+ if os.path.isdir(path):
+ shutil.make_archive(path, 'tar', '.', path)
diff --git a/tests/data/satlas/landsat.tar b/tests/data/satlas/landsat.tar
new file mode 100644
index 000000000..f21ba5980
Binary files /dev/null and b/tests/data/satlas/landsat.tar differ
diff --git a/tests/data/satlas/landsat/2022-03/b1/1234_5678.png b/tests/data/satlas/landsat/2022-03/b1/1234_5678.png
new file mode 100644
index 000000000..1c5ee4e26
Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b1/1234_5678.png differ
diff --git a/tests/data/satlas/landsat/2022-03/b1/7149_3246.png b/tests/data/satlas/landsat/2022-03/b1/7149_3246.png
new file mode 100644
index 000000000..1c5ee4e26
Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b1/7149_3246.png differ
diff --git a/tests/data/satlas/landsat/2022-03/b10/1234_5678.png b/tests/data/satlas/landsat/2022-03/b10/1234_5678.png
new file mode 100644
index 000000000..1c5ee4e26
Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b10/1234_5678.png differ
diff --git a/tests/data/satlas/landsat/2022-03/b10/7149_3246.png b/tests/data/satlas/landsat/2022-03/b10/7149_3246.png
new file mode 100644
index 000000000..1c5ee4e26
Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b10/7149_3246.png differ
diff --git a/tests/data/satlas/landsat/2022-03/b11/1234_5678.png b/tests/data/satlas/landsat/2022-03/b11/1234_5678.png
new file mode 100644
index 000000000..a7ff273b8
Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b11/1234_5678.png differ
diff --git a/tests/data/satlas/landsat/2022-03/b11/7149_3246.png b/tests/data/satlas/landsat/2022-03/b11/7149_3246.png
new file mode 100644
index 000000000..a7ff273b8
Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b11/7149_3246.png differ
diff --git a/tests/data/satlas/landsat/2022-03/b2/1234_5678.png b/tests/data/satlas/landsat/2022-03/b2/1234_5678.png
new file mode 100644
index 000000000..1c5ee4e26
Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b2/1234_5678.png differ
diff --git a/tests/data/satlas/landsat/2022-03/b2/7149_3246.png b/tests/data/satlas/landsat/2022-03/b2/7149_3246.png
new file mode 100644
index 000000000..1c5ee4e26
Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b2/7149_3246.png differ
diff --git a/tests/data/satlas/landsat/2022-03/b3/1234_5678.png b/tests/data/satlas/landsat/2022-03/b3/1234_5678.png
new file mode 100644
index 000000000..1c5ee4e26
Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b3/1234_5678.png differ
diff --git a/tests/data/satlas/landsat/2022-03/b3/7149_3246.png b/tests/data/satlas/landsat/2022-03/b3/7149_3246.png
new file mode 100644
index 000000000..1c5ee4e26
Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b3/7149_3246.png differ
diff --git a/tests/data/satlas/landsat/2022-03/b4/1234_5678.png b/tests/data/satlas/landsat/2022-03/b4/1234_5678.png
new file mode 100644
index 000000000..1c5ee4e26
Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b4/1234_5678.png differ
diff --git a/tests/data/satlas/landsat/2022-03/b4/7149_3246.png b/tests/data/satlas/landsat/2022-03/b4/7149_3246.png
new file mode 100644
index 000000000..1c5ee4e26
Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b4/7149_3246.png differ
diff --git a/tests/data/satlas/landsat/2022-03/b5/1234_5678.png b/tests/data/satlas/landsat/2022-03/b5/1234_5678.png
new file mode 100644
index 000000000..1c5ee4e26
Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b5/1234_5678.png differ
diff --git a/tests/data/satlas/landsat/2022-03/b5/7149_3246.png b/tests/data/satlas/landsat/2022-03/b5/7149_3246.png
new file mode 100644
index 000000000..1c5ee4e26
Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b5/7149_3246.png differ
diff --git a/tests/data/satlas/landsat/2022-03/b6/1234_5678.png b/tests/data/satlas/landsat/2022-03/b6/1234_5678.png
new file mode 100644
index 000000000..1c5ee4e26
Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b6/1234_5678.png differ
diff --git a/tests/data/satlas/landsat/2022-03/b6/7149_3246.png b/tests/data/satlas/landsat/2022-03/b6/7149_3246.png
new file mode 100644
index 000000000..1c5ee4e26
Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b6/7149_3246.png differ
diff --git a/tests/data/satlas/landsat/2022-03/b7/1234_5678.png b/tests/data/satlas/landsat/2022-03/b7/1234_5678.png
new file mode 100644
index 000000000..1c5ee4e26
Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b7/1234_5678.png differ
diff --git a/tests/data/satlas/landsat/2022-03/b7/7149_3246.png b/tests/data/satlas/landsat/2022-03/b7/7149_3246.png
new file mode 100644
index 000000000..1c5ee4e26
Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b7/7149_3246.png differ
diff --git a/tests/data/satlas/landsat/2022-03/b8/1234_5678.png b/tests/data/satlas/landsat/2022-03/b8/1234_5678.png
new file mode 100644
index 000000000..c1620c855
Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b8/1234_5678.png differ
diff --git a/tests/data/satlas/landsat/2022-03/b8/7149_3246.png b/tests/data/satlas/landsat/2022-03/b8/7149_3246.png
new file mode 100644
index 000000000..c1620c855
Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b8/7149_3246.png differ
diff --git a/tests/data/satlas/landsat/2022-03/b9/1234_5678.png b/tests/data/satlas/landsat/2022-03/b9/1234_5678.png
new file mode 100644
index 000000000..1c5ee4e26
Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b9/1234_5678.png differ
diff --git a/tests/data/satlas/landsat/2022-03/b9/7149_3246.png b/tests/data/satlas/landsat/2022-03/b9/7149_3246.png
new file mode 100644
index 000000000..1c5ee4e26
Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b9/7149_3246.png differ
diff --git a/tests/data/satlas/metadata.tar b/tests/data/satlas/metadata.tar
new file mode 100644
index 000000000..da55bab05
Binary files /dev/null and b/tests/data/satlas/metadata.tar differ
diff --git a/tests/data/satlas/metadata/good_images_lowres_all.json b/tests/data/satlas/metadata/good_images_lowres_all.json
new file mode 100644
index 000000000..32f868783
--- /dev/null
+++ b/tests/data/satlas/metadata/good_images_lowres_all.json
@@ -0,0 +1 @@
+[[7149, 3246, "2022-03"], [1234, 5678, "2022-03"], [7149, 3246, "m_3808245_se_17_1_20110801"], [1234, 5678, "m_3808245_se_17_1_20110801"], [7149, 3246, "2022-01"], [1234, 5678, "2022-01"], [7149, 3246, "S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235"], [1234, 5678, "S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235"]]
\ No newline at end of file
diff --git a/tests/data/satlas/metadata/image_times.json b/tests/data/satlas/metadata/image_times.json
new file mode 100644
index 000000000..9028902e0
--- /dev/null
+++ b/tests/data/satlas/metadata/image_times.json
@@ -0,0 +1 @@
+{"2022-03": "2022-03-01T00:00:00+00:00", "m_3808245_se_17_1_20110801": "2011-08-01T12:00:00+00:00", "2022-01": "2022-01-01T00:00:00+00:00", "S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235": "2022-03-09T06:02:35+00:00"}
\ No newline at end of file
diff --git a/tests/data/satlas/metadata/train_lowres.json b/tests/data/satlas/metadata/train_lowres.json
new file mode 100644
index 000000000..af40dbffd
--- /dev/null
+++ b/tests/data/satlas/metadata/train_lowres.json
@@ -0,0 +1 @@
+[[7149, 3246], [1234, 5678]]
\ No newline at end of file
diff --git a/tests/data/satlas/naip.tar b/tests/data/satlas/naip.tar
new file mode 100644
index 000000000..c77db0c06
Binary files /dev/null and b/tests/data/satlas/naip.tar differ
diff --git a/tests/data/satlas/naip/m_3808245_se_17_1_20110801/ir/1234_5678.png b/tests/data/satlas/naip/m_3808245_se_17_1_20110801/ir/1234_5678.png
new file mode 100644
index 000000000..c1620c855
Binary files /dev/null and b/tests/data/satlas/naip/m_3808245_se_17_1_20110801/ir/1234_5678.png differ
diff --git a/tests/data/satlas/naip/m_3808245_se_17_1_20110801/ir/7149_3246.png b/tests/data/satlas/naip/m_3808245_se_17_1_20110801/ir/7149_3246.png
new file mode 100644
index 000000000..c1620c855
Binary files /dev/null and b/tests/data/satlas/naip/m_3808245_se_17_1_20110801/ir/7149_3246.png differ
diff --git a/tests/data/satlas/naip/m_3808245_se_17_1_20110801/tci/1234_5678.png b/tests/data/satlas/naip/m_3808245_se_17_1_20110801/tci/1234_5678.png
new file mode 100644
index 000000000..1655bc2ca
Binary files /dev/null and b/tests/data/satlas/naip/m_3808245_se_17_1_20110801/tci/1234_5678.png differ
diff --git a/tests/data/satlas/naip/m_3808245_se_17_1_20110801/tci/7149_3246.png b/tests/data/satlas/naip/m_3808245_se_17_1_20110801/tci/7149_3246.png
new file mode 100644
index 000000000..1655bc2ca
Binary files /dev/null and b/tests/data/satlas/naip/m_3808245_se_17_1_20110801/tci/7149_3246.png differ
diff --git a/tests/data/satlas/sentinel1.tar b/tests/data/satlas/sentinel1.tar
new file mode 100644
index 000000000..755851304
Binary files /dev/null and b/tests/data/satlas/sentinel1.tar differ
diff --git a/tests/data/satlas/sentinel1/2022-01/vh/1234_5678.png b/tests/data/satlas/sentinel1/2022-01/vh/1234_5678.png
new file mode 100644
index 000000000..c1620c855
Binary files /dev/null and b/tests/data/satlas/sentinel1/2022-01/vh/1234_5678.png differ
diff --git a/tests/data/satlas/sentinel1/2022-01/vh/7149_3246.png b/tests/data/satlas/sentinel1/2022-01/vh/7149_3246.png
new file mode 100644
index 000000000..c1620c855
Binary files /dev/null and b/tests/data/satlas/sentinel1/2022-01/vh/7149_3246.png differ
diff --git a/tests/data/satlas/sentinel1/2022-01/vv/1234_5678.png b/tests/data/satlas/sentinel1/2022-01/vv/1234_5678.png
new file mode 100644
index 000000000..c1620c855
Binary files /dev/null and b/tests/data/satlas/sentinel1/2022-01/vv/1234_5678.png differ
diff --git a/tests/data/satlas/sentinel1/2022-01/vv/7149_3246.png b/tests/data/satlas/sentinel1/2022-01/vv/7149_3246.png
new file mode 100644
index 000000000..c1620c855
Binary files /dev/null and b/tests/data/satlas/sentinel1/2022-01/vv/7149_3246.png differ
diff --git a/tests/data/satlas/sentinel2.tar b/tests/data/satlas/sentinel2.tar
new file mode 100644
index 000000000..aa3122a90
Binary files /dev/null and b/tests/data/satlas/sentinel2.tar differ
diff --git a/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b05/1234_5678.png b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b05/1234_5678.png
new file mode 100644
index 000000000..c1620c855
Binary files /dev/null and b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b05/1234_5678.png differ
diff --git a/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b05/7149_3246.png b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b05/7149_3246.png
new file mode 100644
index 000000000..c1620c855
Binary files /dev/null and b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b05/7149_3246.png differ
diff --git a/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b06/1234_5678.png b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b06/1234_5678.png
new file mode 100644
index 000000000..c1620c855
Binary files /dev/null and b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b06/1234_5678.png differ
diff --git a/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b06/7149_3246.png b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b06/7149_3246.png
new file mode 100644
index 000000000..c1620c855
Binary files /dev/null and b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b06/7149_3246.png differ
diff --git a/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b07/1234_5678.png b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b07/1234_5678.png
new file mode 100644
index 000000000..c1620c855
Binary files /dev/null and b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b07/1234_5678.png differ
diff --git a/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b07/7149_3246.png b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b07/7149_3246.png
new file mode 100644
index 000000000..c1620c855
Binary files /dev/null and b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b07/7149_3246.png differ
diff --git a/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b08/1234_5678.png b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b08/1234_5678.png
new file mode 100644
index 000000000..c1620c855
Binary files /dev/null and b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b08/1234_5678.png differ
diff --git a/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b08/7149_3246.png b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b08/7149_3246.png
new file mode 100644
index 000000000..c1620c855
Binary files /dev/null and b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b08/7149_3246.png differ
diff --git a/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b11/1234_5678.png b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b11/1234_5678.png
new file mode 100644
index 000000000..c1620c855
Binary files /dev/null and b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b11/1234_5678.png differ
diff --git a/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b11/7149_3246.png b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b11/7149_3246.png
new file mode 100644
index 000000000..c1620c855
Binary files /dev/null and b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b11/7149_3246.png differ
diff --git a/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b12/1234_5678.png b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b12/1234_5678.png
new file mode 100644
index 000000000..c1620c855
Binary files /dev/null and b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b12/1234_5678.png differ
diff --git a/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b12/7149_3246.png b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b12/7149_3246.png
new file mode 100644
index 000000000..c1620c855
Binary files /dev/null and b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b12/7149_3246.png differ
diff --git a/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/tci/1234_5678.png b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/tci/1234_5678.png
new file mode 100644
index 000000000..1655bc2ca
Binary files /dev/null and b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/tci/1234_5678.png differ
diff --git a/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/tci/7149_3246.png b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/tci/7149_3246.png
new file mode 100644
index 000000000..1655bc2ca
Binary files /dev/null and b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/tci/7149_3246.png differ
diff --git a/tests/data/satlas/static.tar b/tests/data/satlas/static.tar
new file mode 100644
index 000000000..decccd5ca
Binary files /dev/null and b/tests/data/satlas/static.tar differ
diff --git a/tests/data/satlas/static/7149_3246/land_cover.png b/tests/data/satlas/static/7149_3246/land_cover.png
new file mode 100644
index 000000000..c1620c855
Binary files /dev/null and b/tests/data/satlas/static/7149_3246/land_cover.png differ
diff --git a/tests/datasets/test_satlas.py b/tests/datasets/test_satlas.py
new file mode 100644
index 000000000..7c10f55bd
--- /dev/null
+++ b/tests/datasets/test_satlas.py
@@ -0,0 +1,59 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import os
+import shutil
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import pytest
+import torch.nn as nn
+from pytest import MonkeyPatch
+from torch import Tensor
+
+from torchgeo.datasets import DatasetNotFoundError, SatlasPretrain
+from torchgeo.datasets.utils import Executable
+
+
+class TestSatlasPretrain:
+ @pytest.fixture
+ def dataset(
+ self, aws: Executable, monkeypatch: MonkeyPatch, tmp_path: Path
+ ) -> SatlasPretrain:
+ url = os.path.join('tests', 'data', 'satlas', '')
+ monkeypatch.setattr(SatlasPretrain, 'url', url)
+ images = ('landsat', 'naip', 'sentinel1', 'sentinel2')
+ products = (*images, 'static', 'metadata')
+ tarballs = {product: (f'{product}.tar',) for product in products}
+ monkeypatch.setattr(SatlasPretrain, 'tarballs', tarballs)
+ transforms = nn.Identity()
+ return SatlasPretrain(
+ tmp_path, images=images, transforms=transforms, download=True
+ )
+
+ @pytest.mark.parametrize('index', [0, 1])
+ def test_getitem(self, dataset: SatlasPretrain, index: int) -> None:
+ x = dataset[index]
+ assert isinstance(x, dict)
+ for image in dataset.images:
+ assert isinstance(x[f'image_{image}'], Tensor)
+ assert isinstance(x[f'time_{image}'], Tensor)
+ for label in dataset.labels:
+ assert isinstance(x[f'mask_{label}'], Tensor)
+
+ def test_len(self, dataset: SatlasPretrain) -> None:
+ assert len(dataset) == 2
+
+ def test_already_downloaded(self, dataset: SatlasPretrain) -> None:
+ shutil.rmtree(os.path.join(dataset.root, 'landsat'))
+ SatlasPretrain(root=dataset.root, download=True)
+
+ def test_not_downloaded(self, tmp_path: Path) -> None:
+ with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
+ SatlasPretrain(tmp_path)
+
+ def test_plot(self, dataset: SatlasPretrain) -> None:
+ x = dataset[0]
+ x['prediction_land_cover'] = x['mask_land_cover']
+ dataset.plot(x, suptitle='Test')
+ plt.close()
diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py
index 663b08e7c..981cf2b26 100644
--- a/torchgeo/datasets/__init__.py
+++ b/torchgeo/datasets/__init__.py
@@ -99,6 +99,7 @@ from .quakeset import QuakeSet
from .reforestree import ReforesTree
from .resisc45 import RESISC45
from .rwanda_field_boundary import RwandaFieldBoundary
+from .satlas import SatlasPretrain
from .seasonet import SeasoNet
from .seco import SeasonalContrastS2
from .sen12ms import SEN12MS
@@ -244,6 +245,7 @@ __all__ = (
'RESISC45',
'ReforesTree',
'RwandaFieldBoundary',
+ 'SatlasPretrain',
'SeasonalContrastS2',
'SeasoNet',
'SEN12MS',
diff --git a/torchgeo/datasets/satlas.py b/torchgeo/datasets/satlas.py
new file mode 100644
index 000000000..f52a7fef9
--- /dev/null
+++ b/torchgeo/datasets/satlas.py
@@ -0,0 +1,770 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""SatlasPretrain dataset."""
+
+import os
+from collections.abc import Callable, Iterable
+from typing import ClassVar, TypedDict
+
+import numpy as np
+import pandas as pd
+import torch
+from einops import rearrange
+from matplotlib import pyplot as plt
+from matplotlib.figure import Figure
+from PIL import Image
+from torch import Tensor
+
+from .errors import DatasetNotFoundError
+from .geo import NonGeoDataset
+from .utils import Path, check_integrity, extract_archive, which
+
+
+class _Task(TypedDict, total=False):
+ BackgroundInvalid: bool
+ categories: list[str]
+ colors: list[list[int]]
+ type: str
+
+
+# https://github.com/allenai/satlas/blob/main/satlas/model/dataset.py
+TASKS: dict[str, _Task] = {
+ 'polyline_bin_segment': {
+ 'type': 'bin_segment',
+ 'categories': [
+ 'airport_runway',
+ 'airport_taxiway',
+ 'raceway',
+ 'road',
+ 'railway',
+ 'river',
+ ],
+ 'colors': [
+ [255, 255, 255], # (white) airport_runway
+ [192, 192, 192], # (light grey) airport_taxiway
+ [160, 82, 45], # (sienna) raceway
+ [255, 255, 255], # (white) road
+ [144, 238, 144], # (light green) railway
+ [0, 0, 255], # (blue) river
+ ],
+ },
+ 'bin_segment': {
+ 'type': 'bin_segment',
+ 'categories': [
+ 'aquafarm',
+ 'lock',
+ 'dam',
+ 'solar_farm',
+ 'power_plant',
+ 'gas_station',
+ 'park',
+ 'parking_garage',
+ 'parking_lot',
+ 'landfill',
+ 'quarry',
+ 'stadium',
+ 'airport',
+ 'airport_runway',
+ 'airport_taxiway',
+ 'airport_apron',
+ 'airport_hangar',
+ 'airstrip',
+ 'airport_terminal',
+ 'ski_resort',
+ 'theme_park',
+ 'storage_tank',
+ 'silo',
+ 'track',
+ 'raceway',
+ 'wastewater_plant',
+ 'road',
+ 'railway',
+ 'river',
+ 'water_park',
+ 'pier',
+ 'water_tower',
+ 'street_lamp',
+ 'traffic_signals',
+ 'power_tower',
+ 'power_substation',
+ 'building',
+ 'bridge',
+ 'road_motorway',
+ 'road_trunk',
+ 'road_primary',
+ 'road_secondary',
+ 'road_tertiary',
+ 'road_residential',
+ 'road_service',
+ 'road_track',
+ 'road_pedestrian',
+ ],
+ 'colors': [
+ [32, 178, 170], # (light sea green) aquafarm
+ [0, 255, 255], # (cyan) lock
+ [173, 216, 230], # (light blue) dam
+ [255, 0, 255], # (magenta) solar farm
+ [255, 165, 0], # (orange) power plant
+ [128, 128, 0], # (olive) gas station
+ [0, 255, 0], # (green) park
+ [47, 79, 79], # (dark slate gray) parking garage
+ [128, 0, 0], # (maroon) parking lot
+ [165, 42, 42], # (brown) landfill
+ [128, 128, 128], # (grey) quarry
+ [255, 215, 0], # (gold) stadium
+ [255, 105, 180], # (pink) airport
+ [255, 255, 255], # (white) airport_runway
+ [192, 192, 192], # (light grey) airport_taxiway
+ [128, 0, 128], # (purple) airport_apron
+ [0, 128, 0], # (dark green) airport_hangar
+ [248, 248, 255], # (ghost white) airstrip
+ [240, 230, 140], # (khaki) airport_terminal
+ [192, 192, 192], # (silver) ski_resort
+ [0, 96, 0], # (dark green) theme_park
+ [95, 158, 160], # (cadet blue) storage_tank
+ [205, 133, 63], # (peru) silo
+ [154, 205, 50], # (yellow green) track
+ [160, 82, 45], # (sienna) raceway
+ [218, 112, 214], # (orchid) wastewater_plant
+ [255, 255, 255], # (white) road
+ [144, 238, 144], # (light green) railway
+ [0, 0, 255], # (blue) river
+ [255, 240, 245], # (lavender blush) water_park
+ [65, 105, 225], # (royal blue) pier
+ [238, 130, 238], # (violet) water_tower
+ [75, 0, 130], # (indigo) street_lamp
+ [233, 150, 122], # (dark salmon) traffic_signals
+ [255, 255, 0], # (yellow) power_tower
+ [255, 255, 0], # (yellow) power_substation
+ [255, 0, 0], # (red) building
+ [64, 64, 64], # (dark grey) bridge
+ [255, 255, 255], # (white) road_motorway
+ [255, 255, 255], # (white) road_trunk
+ [255, 255, 255], # (white) road_primary
+ [255, 255, 255], # (white) road_secondary
+ [255, 255, 255], # (white) road_tertiary
+ [255, 255, 255], # (white) road_residential
+ [255, 255, 255], # (white) road_service
+ [255, 255, 255], # (white) road_track
+ [255, 255, 255], # (white) road_pedestrian
+ ],
+ },
+ 'land_cover': {
+ 'type': 'segment',
+ 'BackgroundInvalid': True,
+ 'categories': [
+ 'background',
+ 'water',
+ 'developed',
+ 'tree',
+ 'shrub',
+ 'grass',
+ 'crop',
+ 'bare',
+ 'snow',
+ 'wetland',
+ 'mangroves',
+ 'moss',
+ ],
+ 'colors': [
+ [0, 0, 0], # unknown
+ [0, 0, 255], # (blue) water
+ [255, 0, 0], # (red) developed
+ [0, 192, 0], # (dark green) tree
+ [200, 170, 120], # (brown) shrub
+ [0, 255, 0], # (green) grass
+ [255, 255, 0], # (yellow) crop
+ [128, 128, 128], # (grey) bare
+ [255, 255, 255], # (white) snow
+ [0, 255, 255], # (cyan) wetland
+ [255, 0, 255], # (pink) mangroves
+ [128, 0, 128], # (purple) moss
+ ],
+ },
+ 'tree_cover': {'type': 'regress', 'BackgroundInvalid': True},
+ 'crop_type': {
+ 'type': 'segment',
+ 'BackgroundInvalid': True,
+ 'categories': [
+ 'invalid',
+ 'rice',
+ 'grape',
+ 'corn',
+ 'sugarcane',
+ 'tea',
+ 'hop',
+ 'wheat',
+ 'soy',
+ 'barley',
+ 'oats',
+ 'rye',
+ 'cassava',
+ 'potato',
+ 'sunflower',
+ 'asparagus',
+ 'coffee',
+ ],
+ 'colors': [
+ [0, 0, 0], # unknown
+ [0, 0, 255], # (blue) rice
+ [255, 0, 0], # (red) grape
+ [255, 255, 0], # (yellow) corn
+ [0, 255, 0], # (green) sugarcane
+ [128, 0, 128], # (purple) tea
+ [255, 0, 255], # (pink) hop
+ [0, 128, 0], # (dark green) wheat
+ [255, 255, 255], # (white) soy
+ [128, 128, 128], # (grey) barley
+ [165, 42, 42], # (brown) oats
+ [0, 255, 255], # (cyan) rye
+ [128, 0, 0], # (maroon) cassava
+ [173, 216, 230], # (light blue) potato
+ [128, 128, 0], # (olive) sunflower
+ [0, 128, 0], # (dark green) asparagus
+ [92, 64, 51], # (dark brown) coffee
+ ],
+ },
+ 'point': {
+ 'type': 'detect',
+ 'categories': [
+ 'background',
+ 'wind_turbine',
+ 'lighthouse',
+ 'mineshaft',
+ 'aerialway_pylon',
+ 'helipad',
+ 'fountain',
+ 'toll_booth',
+ 'chimney',
+ 'communications_tower',
+ 'flagpole',
+ 'petroleum_well',
+ 'water_tower',
+ 'offshore_wind_turbine',
+ 'offshore_platform',
+ 'power_tower',
+ ],
+ 'colors': [
+ [0, 0, 0],
+ [0, 255, 255], # (cyan) wind_turbine
+ [0, 255, 0], # (green) lighthouse
+ [255, 255, 0], # (yellow) mineshaft
+ [0, 0, 255], # (blue) pylon
+ [173, 216, 230], # (light blue) helipad
+ [128, 0, 128], # (purple) fountain
+ [255, 255, 255], # (white) toll_booth
+ [0, 128, 0], # (dark green) chimney
+ [128, 128, 128], # (grey) communications_tower
+ [165, 42, 42], # (brown) flagpole
+ [128, 0, 0], # (maroon) petroleum_well
+ [255, 165, 0], # (orange) water_tower
+ [255, 255, 0], # (yellow) offshore_wind_turbine
+ [255, 0, 0], # (red) offshore_platform
+ [255, 0, 255], # (magenta) power_tower
+ ],
+ },
+ 'rooftop_solar_panel': {
+ 'type': 'detect',
+ 'categories': ['background', 'rooftop_solar_panel'],
+ 'colors': [
+ [0, 0, 0],
+ [255, 255, 0], # (yellow) rooftop_solar_panel
+ ],
+ },
+ 'building': {
+ 'type': 'instance',
+ 'categories': ['background', 'ms_building'],
+ 'colors': [
+ [0, 0, 0],
+ [255, 255, 0], # (yellow) building
+ ],
+ },
+ 'polygon': {
+ 'type': 'instance',
+ 'categories': [
+ 'background',
+ 'aquafarm',
+ 'lock',
+ 'dam',
+ 'solar_farm',
+ 'power_plant',
+ 'gas_station',
+ 'park',
+ 'parking_garage',
+ 'parking_lot',
+ 'landfill',
+ 'quarry',
+ 'stadium',
+ 'airport',
+ 'airport_apron',
+ 'airport_hangar',
+ 'airport_terminal',
+ 'ski_resort',
+ 'theme_park',
+ 'storage_tank',
+ 'silo',
+ 'track',
+ 'wastewater_plant',
+ 'power_substation',
+ 'pier',
+ 'crop',
+ 'water_park',
+ ],
+ 'colors': [
+ [0, 0, 0],
+ [255, 255, 0], # (yellow) aquafarm
+ [0, 255, 255], # (cyan) lock
+ [0, 255, 0], # (green) dam
+ [0, 0, 255], # (blue) solar_farm
+ [255, 0, 0], # (red) power_plant
+ [128, 0, 128], # (purple) gas_station
+ [255, 255, 255], # (white) park
+ [0, 128, 0], # (dark green) parking_garage
+ [128, 128, 128], # (grey) parking_lot
+ [165, 42, 42], # (brown) landfill
+ [128, 0, 0], # (maroon) quarry
+ [255, 165, 0], # (orange) stadium
+ [255, 105, 180], # (pink) airport
+ [192, 192, 192], # (silver) airport_apron
+ [173, 216, 230], # (light blue) airport_hangar
+ [32, 178, 170], # (light sea green) airport_terminal
+ [255, 0, 255], # (magenta) ski_resort
+ [128, 128, 0], # (olive) theme_park
+ [47, 79, 79], # (dark slate gray) storage_tank
+ [255, 215, 0], # (gold) silo
+ [192, 192, 192], # (light grey) track
+ [240, 230, 140], # (khaki) wastewater_plant
+ [154, 205, 50], # (yellow green) power_substation
+ [255, 165, 0], # (orange) pier
+ [0, 192, 0], # (middle green) crop
+ [0, 192, 0], # (middle green) water_park
+ ],
+ },
+ 'wildfire': {
+ 'type': 'bin_segment',
+ 'categories': ['fire_retardant', 'burned'],
+ 'colors': [
+ [255, 0, 0], # (red) fire retardant
+ [128, 128, 128], # (grey) burned area
+ ],
+ },
+ 'smoke': {'type': 'classification', 'categories': ['no', 'partial', 'yes']},
+ 'snow': {'type': 'classification', 'categories': ['no', 'partial', 'yes']},
+ 'dem': {'type': 'regress', 'BackgroundInvalid': True},
+ 'airplane': {
+ 'type': 'detect',
+ 'categories': ['background', 'airplane'],
+ 'colors': [
+ [0, 0, 0], # (black) background
+ [255, 0, 0], # (red) airplane
+ ],
+ },
+ 'vessel': {
+ 'type': 'detect',
+ 'categories': ['background', 'vessel'],
+ 'colors': [
+ [0, 0, 0], # (black) background
+ [255, 0, 0], # (red) vessel
+ ],
+ },
+ 'water_event': {
+ 'type': 'segment',
+ 'BackgroundInvalid': True,
+ 'categories': ['invalid', 'background', 'water_event'],
+ 'colors': [
+ [0, 0, 0], # (black) invalid
+ [0, 255, 0], # (green) background
+ [0, 0, 255], # (blue) water_event
+ ],
+ },
+ 'park_sport': {
+ 'type': 'classification',
+ 'categories': [
+ 'american_football',
+ 'badminton',
+ 'baseball',
+ 'basketball',
+ 'cricket',
+ 'rugby',
+ 'soccer',
+ 'tennis',
+ 'volleyball',
+ ],
+ },
+ 'park_type': {
+ 'type': 'classification',
+ 'categories': ['park', 'pitch', 'golf_course', 'cemetery'],
+ },
+ 'power_plant_type': {
+ 'type': 'classification',
+ 'categories': ['oil', 'nuclear', 'coal', 'gas'],
+ },
+ 'quarry_resource': {
+ 'type': 'classification',
+ 'categories': ['sand', 'gravel', 'clay', 'coal', 'peat'],
+ },
+ 'track_sport': {
+ 'type': 'classification',
+ 'categories': ['running', 'cycling', 'horse'],
+ },
+ 'road_type': {
+ 'type': 'classification',
+ 'categories': [
+ 'motorway',
+ 'trunk',
+ 'primary',
+ 'secondary',
+ 'tertiary',
+ 'residential',
+ 'service',
+ 'track',
+ 'pedestrian',
+ ],
+ },
+ 'cloud': {
+ 'type': 'bin_segment',
+ 'categories': ['background', 'cloud', 'shadow'],
+ 'colors': [
+ [0, 255, 0], # (green) not clouds or shadows
+ [255, 255, 255], # (white) clouds
+ [128, 128, 128], # (grey) shadows
+ ],
+ 'BackgroundInvalid': True,
+ },
+ 'flood': {
+ 'type': 'bin_segment',
+ 'categories': ['background', 'water'],
+ 'colors': [
+ [0, 255, 0], # (green) background
+ [0, 0, 255], # (blue) water
+ ],
+ 'BackgroundInvalid': True,
+ },
+}
+
+
+class SatlasPretrain(NonGeoDataset):
+ """SatlasPretrain dataset.
+
+ `SatlasPretrain `_ is a large-scale pre-training
+ dataset for tasks that involve understanding satellite images. Regularly-updated
+ satellite data is publicly available for much of the Earth through sources such as
+ Sentinel-2 and NAIP, and can inform numerous applications from tackling illegal
+ deforestation to monitoring marine infrastructure. However, developing automatic
+ computer vision systems to parse these images requires a huge amount of manual
+ labeling of training data. By combining over 30 TB of satellite images with 137
+ label categories, SatlasPretrain serves as an effective pre-training dataset that
+ greatly reduces the effort needed to develop robust models for downstream satellite
+ image applications.
+
+ Reference implementation:
+
+ * https://github.com/allenai/satlas/blob/main/satlas/model/dataset.py
+
+ If you use this dataset in your research, please cite the following paper:
+
+ * https://doi.org/10.48550/arXiv.2211.15660
+
+ .. versionadded:: 0.7
+
+ .. note::
+ This dataset requires the following additional library to be installed:
+
+ * `AWS CLI `_: to download the dataset from AWS.
+ """
+
+ # https://github.com/allenai/satlas/blob/main/satlaspretrain_urls.txt
+ url = 's3://ai2-public-datasets/satlas/'
+ tarballs: ClassVar[dict[str, tuple[str, ...]]] = {
+ 'landsat': ('satlas-dataset-v1-landsat.tar',),
+ 'naip': (
+ 'satlas-dataset-v1-naip-2011.tar',
+ 'satlas-dataset-v1-naip-2012.tar',
+ 'satlas-dataset-v1-naip-2013.tar',
+ 'satlas-dataset-v1-naip-2014.tar',
+ 'satlas-dataset-v1-naip-2015.tar',
+ 'satlas-dataset-v1-naip-2016.tar',
+ 'satlas-dataset-v1-naip-2017.tar',
+ 'satlas-dataset-v1-naip-2018.tar',
+ 'satlas-dataset-v1-naip-2019.tar',
+ 'satlas-dataset-v1-naip-2020.tar',
+ ),
+ 'sentinel1': ('satlas-dataset-v1-sentinel1-new.tar',),
+ 'sentinel2': (
+ 'satlas-dataset-v1-sentinel2-a.tar',
+ 'satlas-dataset-v1-sentinel2-b.tar',
+ ),
+ 'static': ('satlas-dataset-v1-labels-static.tar',),
+ 'dynamic': ('satlas-dataset-v1-labels-dynamic.tar',),
+ 'metadata': ('satlas-dataset-v1-metadata.tar',),
+ }
+ md5s: ClassVar[dict[str, tuple[str, ...]]] = {
+ 'landsat': ('89ea5e8974826c071908392827780a06',),
+ 'naip': (
+ '523736842994861054f04b97c4d90bfb',
+ '636b9a3b08be0e40d098cb7b5e655b57',
+ '69e2b1052b1d2d465322a24cf7207a16',
+ '38999aea424d403ad60e1398443636aa',
+ '97f4855072a8a406a4bfbe94c5f7311c',
+ '9ba3c626b23e6d26749a323eaedc7c0a',
+ 'e4aba3d198dedfe1524a9338e85794aa',
+ '74191a36d841b0b9b5d5cbae9a92ad71',
+ '55b110cc6f734bf88793306d49f1c415',
+ '97fc8414334987c59593d574f112a77e',
+ ),
+ 'sentinel1': ('3d88a0a10df6ab0aa50db2ba4c475048',),
+ 'sentinel2': (
+ '7e1c6a1e322807fb11df8c0c062545ca',
+ '6636b8ecf2fff1d6723ecfef55a4876d',
+ ),
+ 'static': ('4e38c2573bc78cf1f0d7267e432cb42c',),
+ 'dynamic': ('4503ae687948e7d2cb7ade0083f77a8a',),
+ 'metadata': ('6b9ac5a4f9a1ee88a271d28f12854607',),
+ }
+
+ # NOTE: 'tci' is RGB (b04-b02), not BGR (b02-b04)
+ bands: ClassVar[dict[str, tuple[str, ...]]] = {
+ 'landsat': tuple(f'b{i}' for i in range(1, 12)),
+ 'naip': ('tci', 'ir'),
+ 'sentinel1': ('vh', 'vv'),
+ 'sentinel2': ('tci', 'b05', 'b06', 'b07', 'b08', 'b11', 'b12'),
+ }
+
+ chip_size = 512
+
+ def __init__(
+ self,
+ root: Path = 'data',
+ split: str = 'train_lowres',
+ good_images: str = 'good_images_lowres_all',
+ image_times: str = 'image_times',
+ images: Iterable[str] = ('sentinel1', 'sentinel2', 'landsat'),
+ labels: Iterable[str] = ('land_cover',),
+ transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
+ download: bool = False,
+ checksum: bool = False,
+ ) -> None:
+ """Initialize a new SatlasPretrain instance.
+
+ Args:
+ root: Root directory where dataset can be found.
+ split: Metadata split to load.
+ good_images: Metadata mapping between col/row and directory.
+ image_times: Metadata mapping between directory and ISO time.
+ images: List of image products.
+ labels: List of label products.
+ transforms: A function/transform that takes input sample and its target as
+ entry and returns a transformed version.
+ download: If True, download dataset and store it in the root directory.
+ checksum: If True, check the MD5 of the downloaded files (may be slow).
+
+ Raises:
+ AssertionError: If *images* is invalid.
+ DatasetNotFoundError: If dataset is not found and *download* is False.
+ """
+ assert set(images) <= set(self.bands.keys())
+
+ self.root = root
+ self.images = images
+ self.labels = labels
+ self.transforms = transforms
+ self.download = download
+ self.checksum = checksum
+
+ self._verify()
+
+ # Read metadata files
+ self.split = pd.read_json(
+ os.path.join(root, 'metadata', f'{split}.json'), typ='frame'
+ )
+ self.good_images = pd.read_json(
+ os.path.join(root, 'metadata', f'{good_images}.json'), typ='frame'
+ )
+ self.image_times = pd.read_json(
+ os.path.join(root, 'metadata', f'{image_times}.json'), typ='series'
+ )
+
+ self.split.columns = ['col', 'row']
+ self.good_images.columns = ['col', 'row', 'directory']
+ self.good_images = self.good_images.groupby(['col', 'row'])
+
+ def __len__(self) -> int:
+ """Return the number of locations in the dataset.
+
+ Returns:
+ Length of the dataset
+ """
+ return len(self.split)
+
+ def __getitem__(self, index: int) -> dict[str, Tensor]:
+ """Return an index within the dataset.
+
+ Args:
+ index: Index to return.
+
+ Returns:
+ Data and label at that index.
+ """
+ col, row = self.split.iloc[index]
+ directories = self.good_images.get_group((col, row))['directory']
+
+ sample: dict[str, Tensor] = {}
+
+ for image in self.images:
+ self._load_image(sample, image, col, row, directories)
+
+ for label in self.labels:
+ self._load_label(sample, label, col, row)
+
+ if self.transforms is not None:
+ sample = self.transforms(sample)
+
+ return sample
+
+ def _load_image(
+ self,
+ sample: dict[str, Tensor],
+ image: str,
+ col: int,
+ row: int,
+ directories: pd.Series,
+ ) -> None:
+ """Load a single image.
+
+ Args:
+ sample: Dataset sample to populate.
+ image: Image product.
+ col: Web Mercator column.
+ row: Web Mercator row.
+ directories: Directories that may contain the image.
+ """
+ # Moved in PIL 9.1.0
+ try:
+ resample = Image.Resampling.BILINEAR
+ except AttributeError:
+ resample = Image.BILINEAR # type: ignore[attr-defined]
+
+ # Find directories that match image product
+ good_directories: list[str] = []
+ for directory in directories:
+ path = os.path.join(self.root, image, directory)
+ if os.path.isdir(path):
+ good_directories.append(directory)
+
+ # Choose a random timestamp
+ idx = torch.randint(len(good_directories), (1,))
+ directory = good_directories[idx]
+ time = self.image_times[directory].timestamp()
+ sample[f'time_{image}'] = torch.tensor(time)
+
+ # Load all bands
+ channels = []
+ for band in self.bands[image]:
+ path = os.path.join(self.root, image, directory, band, f'{col}_{row}.png')
+ with Image.open(path) as img:
+ img = img.resize((self.chip_size, self.chip_size), resample=resample)
+ array = np.atleast_3d(np.array(img, dtype=np.float32))
+ channels.append(torch.tensor(array))
+ raster = rearrange(torch.cat(channels, dim=-1), 'h w c -> c h w')
+ sample[f'image_{image}'] = raster
+
+ def _load_label(
+ self, sample: dict[str, Tensor], label: str, col: int, row: int
+ ) -> None:
+ """Load a single label.
+
+ Args:
+ sample: Dataset sample to populate.
+ label: Label product.
+ col: Web Mercator column.
+ row: Web Mercator row.
+ """
+ path = os.path.join(self.root, 'static', f'{col}_{row}', f'{label}.png')
+ if os.path.isfile(path):
+ with Image.open(path) as img:
+ raster = torch.tensor(np.array(img, dtype=np.int64))
+ else:
+ raster = torch.zeros(self.chip_size, self.chip_size, dtype=torch.long)
+ sample[f'mask_{label}'] = raster
+
+ def _verify(self) -> None:
+ """Verify the integrity of the dataset."""
+ products = [*self.images, 'metadata']
+ if self.labels:
+ products.append('static')
+
+ for product in products:
+ # Check if the extracted directory already exists
+ if os.path.isdir(os.path.join(self.root, product)):
+ continue
+
+ tarballs = self.tarballs[product]
+ md5s = self.md5s[product]
+ for tarball, md5 in zip(tarballs, md5s):
+ path = os.path.join(self.root, tarball)
+
+ # Check if the tarball has already been downloaded
+ if os.path.isfile(path):
+ extract_archive(path)
+ continue
+
+ # Check if the user requested to download the dataset
+ if not self.download:
+ raise DatasetNotFoundError(self)
+
+ # Download and extract the tarball
+ aws = which('aws')
+ aws('s3', 'cp', self.url + tarball, self.root)
+ check_integrity(path, md5 if self.checksum else None)
+ extract_archive(path)
+
+ def plot(
+ self,
+ sample: dict[str, Tensor],
+ show_titles: bool = True,
+ suptitle: str | None = None,
+ ) -> Figure:
+ """Plot a sample from the dataset.
+
+ Args:
+ sample: A sample returned by :meth:`__getitem__`.
+ show_titles: Flag indicating whether to show titles above each panel.
+ suptitle: Optional string to use as a suptitle.
+
+ Returns:
+ A matplotlib Figure with the rendered sample.
+ """
+ images = []
+ titles = []
+ for key, value in sample.items():
+ match key.split('_', 1):
+ case ['image', 'landsat']:
+ images.append(rearrange(value[[3, 2, 1]], 'c h w -> h w c') / 255)
+ titles.append('Landsat 8/9')
+ case ['image', 'naip']:
+ images.append(rearrange(value[:3], 'c h w -> h w c') / 255)
+ titles.append('NAIP')
+ case ['image', 'sentinel1']:
+ images.extend([value[0] / 255, value[1] / 255])
+ titles.extend(['Sentinel-1 VH', 'Sentinel-1 VV'])
+ case ['image', 'sentinel2']:
+ images.append(rearrange(value[:3], 'c h w -> h w c') / 255)
+ titles.append('Sentinel-2')
+ case ['mask' | 'prediction', label]:
+ cmap = torch.tensor(TASKS[label]['colors'])
+ images.append(cmap[value])
+ titles.append(label.replace('_', ' ').capitalize())
+
+ fig, ax = plt.subplots(ncols=len(images), squeeze=False)
+ for i, (image, title) in enumerate(zip(images, titles)):
+ ax[0, i].imshow(image)
+ ax[0, i].axis('off')
+
+ if show_titles:
+ ax[0, i].set_title(title)
+
+ if suptitle is not None:
+ fig.suptitle(suptitle)
+
+ return fig