SSL4EO-L: add download support (#1424)

* SSL4EO-L: add download support

* Placate pydocstyle

* Meaning of root changed

* Missing one

* Incremental tarball concatenation

* Placate black

* Add download times
This commit is contained in:
Adam J. Stewart 2023-06-20 11:54:10 -05:00 коммит произвёл GitHub
Родитель a040c88fd1
Коммит 38be536b86
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
93 изменённых файлов: 234 добавлений и 29 удалений

Просмотреть файл

@ -8,7 +8,7 @@ module:
datamodule:
_target_: torchgeo.datamodules.SSL4EOLDataModule
root: "tests/data/ssl4eo/l/tm_toa"
root: "tests/data/ssl4eo/l"
split: "tm_toa"
seasons: 1
batch_size: 2

Просмотреть файл

@ -8,7 +8,7 @@ module:
datamodule:
_target_: torchgeo.datamodules.SSL4EOLDataModule
root: "tests/data/ssl4eo/l/etm_sr"
root: "tests/data/ssl4eo/l"
split: "etm_sr"
seasons: 2
batch_size: 2

Просмотреть файл

@ -10,7 +10,7 @@ module:
datamodule:
_target_: torchgeo.datamodules.SSL4EOLDataModule
root: "tests/data/ssl4eo/l/etm_toa"
root: "tests/data/ssl4eo/l"
split: "etm_toa"
seasons: 1
batch_size: 2

Просмотреть файл

@ -13,7 +13,7 @@ module:
datamodule:
_target_: torchgeo.datamodules.SSL4EOLDataModule
root: "tests/data/ssl4eo/l/oli_tirs_toa"
root: "tests/data/ssl4eo/l"
split: "oli_tirs_toa"
seasons: 2
batch_size: 2

Просмотреть файл

@ -11,7 +11,7 @@ module:
datamodule:
_target_: torchgeo.datamodules.SSL4EOLDataModule
root: "tests/data/ssl4eo/l/oli_sr"
root: "tests/data/ssl4eo/l"
split: "oli_sr"
seasons: 1
batch_size: 2

Просмотреть файл

@ -11,7 +11,7 @@ module:
datamodule:
_target_: torchgeo.datamodules.SSL4EOLDataModule
root: "tests/data/ssl4eo/l/tm_toa"
root: "tests/data/ssl4eo/l"
split: "tm_toa"
seasons: 2
batch_size: 2

Просмотреть файл

@ -14,13 +14,14 @@ from rasterio import Affine
from rasterio.crs import CRS
SIZE = 36
CHUNK_SIZE = 2**12
np.random.seed(0)
FILENAME_HIERARCHY = Union[dict[str, "FILENAME_HIERARCHY"], list[str]]
filenames: FILENAME_HIERARCHY = {
"tm_toa": {
"ssl4eo_l_tm_toa": {
"0000002": {
"LT05_172034_20010526": ["all_bands.tif"],
"LT05_172034_20020310": ["all_bands.tif"],
@ -34,7 +35,7 @@ filenames: FILENAME_HIERARCHY = {
"LT5_223084_20020923": ["all_bands.tif"],
},
},
"etm_sr": {
"ssl4eo_l_etm_toa": {
"0000002": {
"LE07_172034_20010526": ["all_bands.tif"],
"LE07_172034_20020310": ["all_bands.tif"],
@ -48,7 +49,7 @@ filenames: FILENAME_HIERARCHY = {
"LE07_223084_20020923": ["all_bands.tif"],
},
},
"etm_toa": {
"ssl4eo_l_etm_sr": {
"0000002": {
"LE07_172034_20010526": ["all_bands.tif"],
"LE07_172034_20020310": ["all_bands.tif"],
@ -62,7 +63,7 @@ filenames: FILENAME_HIERARCHY = {
"LE07_223084_20020923": ["all_bands.tif"],
},
},
"oli_tirs_toa": {
"ssl4eo_l_oli_tirs_toa": {
"0000002": {
"LC08_172034_20210306": ["all_bands.tif"],
"LC08_172034_20210829": ["all_bands.tif"],
@ -76,7 +77,7 @@ filenames: FILENAME_HIERARCHY = {
"LC08_223084_20221211": ["all_bands.tif"],
},
},
"oli_sr": {
"ssl4eo_l_oli_sr": {
"0000002": {
"LC08_172034_20210306": ["all_bands.tif"],
"LC08_172034_20210829": ["all_bands.tif"],
@ -92,7 +93,13 @@ filenames: FILENAME_HIERARCHY = {
},
}
num_bands = {"tm_toa": 7, "etm_sr": 6, "etm_toa": 9, "oli_tirs_toa": 11, "oli_sr": 7}
num_bands = {
"ssl4eo_l_tm_toa": 7,
"ssl4eo_l_etm_toa": 9,
"ssl4eo_l_etm_sr": 6,
"ssl4eo_l_oli_tirs_toa": 11,
"ssl4eo_l_oli_sr": 7,
}
def create_file(path: str) -> None:
@ -141,10 +148,25 @@ if __name__ == "__main__":
directories = filenames.keys()
for directory in directories:
# Create tarballs
# Create tarball
shutil.make_archive(directory, "gztar", ".", directory)
# Split tarball
path = f"{directory}.tar.gz"
paths = []
with open(path, "rb") as f:
suffix = "a"
while chunk := f.read(CHUNK_SIZE):
split = f"{path}a{suffix}"
with open(split, "wb") as g:
g.write(chunk)
suffix = chr(ord(suffix) + 1)
paths.append(split)
os.remove(path)
# Compute checksums
with open(f"{directory}.tar.gz", "rb") as f:
md5 = hashlib.md5(f.read()).hexdigest()
print(directory, md5)
for path in paths:
with open(path, "rb") as f:
md5 = hashlib.md5(f.read()).hexdigest()
print(path, md5)

Двоичные данные
tests/data/ssl4eo/l/etm_sr.tar.gz

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичные данные
tests/data/ssl4eo/l/etm_toa.tar.gz

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичные данные
tests/data/ssl4eo/l/oli_sr.tar.gz

Двоичный файл не отображается.

Двоичные данные
tests/data/ssl4eo/l/oli_tirs_toa.tar.gz

Двоичный файл не отображается.

Двоичные данные
tests/data/ssl4eo/l/ssl4eo_l_etm_sr.tar.gzaa Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/ssl4eo/l/ssl4eo_l_etm_sr.tar.gzab Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/ssl4eo/l/ssl4eo_l_etm_sr.tar.gzac Normal file

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичные данные
tests/data/ssl4eo/l/ssl4eo_l_etm_toa.tar.gzaa Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/ssl4eo/l/ssl4eo_l_etm_toa.tar.gzab Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/ssl4eo/l/ssl4eo_l_etm_toa.tar.gzac Normal file

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичные данные
tests/data/ssl4eo/l/ssl4eo_l_oli_sr.tar.gzaa Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/ssl4eo/l/ssl4eo_l_oli_sr.tar.gzab Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/ssl4eo/l/ssl4eo_l_oli_sr.tar.gzac Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/ssl4eo/l/ssl4eo_l_oli_tirs_toa.tar.gzaa Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/ssl4eo/l/ssl4eo_l_oli_tirs_toa.tar.gzab Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/ssl4eo/l/ssl4eo_l_oli_tirs_toa.tar.gzac Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/ssl4eo/l/ssl4eo_l_tm_toa.tar.gzaa Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/ssl4eo/l/ssl4eo_l_tm_toa.tar.gzab Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/ssl4eo/l/ssl4eo_l_tm_toa.tar.gzac Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/ssl4eo/l/tm_toa.tar.gz

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Просмотреть файл

@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import glob
import os
import shutil
from pathlib import Path
@ -13,16 +14,57 @@ from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
from torch.utils.data import ConcatDataset
import torchgeo
from torchgeo.datasets import SSL4EOL, SSL4EOS12
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
class TestSSL4EOL:
@pytest.fixture(params=zip(SSL4EOL.metadata.keys(), [1, 1, 2, 2, 4]))
def dataset(self, request: SubRequest) -> SSL4EOL:
def dataset(
self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
) -> SSL4EOL:
monkeypatch.setattr(torchgeo.datasets.ssl4eo, "download_url", download_url)
url = os.path.join("tests", "data", "ssl4eo", "l", "ssl4eo_l_{0}.tar.gz{1}")
monkeypatch.setattr(SSL4EOL, "url", url)
checksums = {
"tm_toa": {
"aa": "010b9d72b476e0e30741c17725f84e5c",
"ab": "39171bd7bca8a56a8cb339a0f88da9d3",
"ac": "3cfc407ce3f4f4d6e3c5fdb457bb87da",
},
"etm_toa": {
"aa": "87e47278f5a30acd3b696b6daaa4713b",
"ab": "59295e1816e08a5acd3a18ae56b6f32e",
"ac": "f3ff76eb6987501000228ce15684e09f",
},
"etm_sr": {
"aa": "fd61a4154eafaeb350dbb01a2551a818",
"ab": "0c3117bc7682ba9ffdc6871e6c364b36",
"ac": "93d3385e47de4578878ca5c4fa6a628d",
},
"oli_tirs_toa": {
"aa": "defb9e91a73b145b2dbe347649bded06",
"ab": "97f7edaa4e288fc14ec7581dccea766f",
"ac": "7472fad9929a0dc96ccf4dc6c804b92f",
},
"oli_sr": {
"aa": "8fd3aa6b581d024299f44457956faa05",
"ab": "7eb4d761ce1afd89cae9c6142ca17882",
"ac": "a3210da9fcc71e3a4efde71c30d78c59",
},
}
monkeypatch.setattr(SSL4EOL, "checksums", checksums)
root = str(tmp_path)
split, seasons = request.param
root = os.path.join("tests", "data", "ssl4eo", "l", split)
transforms = nn.Identity()
return SSL4EOL(root, split, seasons, transforms)
return SSL4EOL(root, split, seasons, transforms, download=True, checksum=True)
def test_getitem(self, dataset: SSL4EOL) -> None:
x = dataset[0]
@ -41,6 +83,20 @@ class TestSSL4EOL:
assert isinstance(ds, ConcatDataset)
assert len(ds) == 2 * 2
def test_already_extracted(self, dataset: SSL4EOL) -> None:
SSL4EOL(dataset.root, dataset.split, dataset.seasons)
def test_already_downloaded(self, dataset: SSL4EOL, tmp_path: Path) -> None:
pathname = os.path.join("tests", "data", "ssl4eo", "l", "*.tar.gz*")
root = str(tmp_path)
for tarfile in glob.iglob(pathname):
shutil.copy(tarfile, root)
SSL4EOL(root)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found"):
SSL4EOL(str(tmp_path))
def test_invalid_split(self) -> None:
with pytest.raises(AssertionError):
SSL4EOL(split="foo")

Просмотреть файл

@ -3,6 +3,7 @@
"""Self-Supervised Learning for Earth Observation."""
import glob
import os
import random
from typing import Callable, Optional, TypedDict
@ -14,7 +15,7 @@ import torch
from torch import Tensor
from .geo import NonGeoDataset
from .utils import check_integrity, extract_archive
from .utils import check_integrity, download_url, extract_archive
class SSL4EO(NonGeoDataset):
@ -33,8 +34,8 @@ class SSL4EOL(NonGeoDataset):
Landsat version of SSL4EO.
The dataset consists of a parallel corpus (same locations for all splits, same dates
for SR/TOA) for the following sensors:
The dataset consists of a parallel corpus (same locations and dates for SR/TOA)
for the following sensors:
.. list-table::
:widths: 10 10 10 10 10
@ -50,8 +51,8 @@ class SSL4EOL(NonGeoDataset):
- TOA
- 7
- `GEE <https://developers.google.com/earth-engine/datasets/catalog/LANDSAT_LT05_C02_T1_TOA>`__
* - Landsat 4--7
- TM
* - Landsat 7
- ETM+
- SR
- 6
- `GEE <https://developers.google.com/earth-engine/datasets/catalog/LANDSAT_LT05_C02_T1_L2>`__
@ -77,6 +78,13 @@ class SSL4EOL(NonGeoDataset):
* Resampled to 30 m resolution (7920 x 7920 m)
* Single multispectral GeoTIFF file
.. note::
Each split is 300--400 GB and requires 3x that to concatenate and extract
tarballs. Tarballs can be safely deleted after extraction to save space.
The dataset takes about 1.5 hrs to download and checksum and another 3 hrs
to extract.
.. versionadded:: 0.5
""" # noqa: E501
@ -86,40 +94,103 @@ class SSL4EOL(NonGeoDataset):
metadata: dict[str, _Metadata] = {
"tm_toa": {"num_bands": 7, "rgb_bands": [2, 1, 0]},
"etm_sr": {"num_bands": 6, "rgb_bands": [2, 1, 0]},
"etm_toa": {"num_bands": 9, "rgb_bands": [2, 1, 0]},
"etm_sr": {"num_bands": 6, "rgb_bands": [2, 1, 0]},
"oli_tirs_toa": {"num_bands": 11, "rgb_bands": [3, 2, 1]},
"oli_sr": {"num_bands": 7, "rgb_bands": [3, 2, 1]},
}
url = "https://hf.co/datasets/torchgeo/ssl4eo_l/resolve/main/{0}/ssl4eo_l_{0}.tar.gz{1}" # noqa: E501
checksums = {
"tm_toa": {
"aa": "553795b8d73aa253445b1e67c5b81f11",
"ab": "e9e0739b5171b37d16086cb89ab370e8",
"ac": "6cb27189f6abe500c67343bfcab2432c",
"ad": "15a885d4f544d0c1849523f689e27402",
"ae": "35523336bf9f8132f38ff86413dcd6dc",
"af": "fa1108436034e6222d153586861f663b",
"ag": "d5c91301c115c00acaf01ceb3b78c0fe",
},
"etm_toa": {
"aa": "587c3efc7d0a0c493dfb36139d91ccdf",
"ab": "ec34f33face893d2d8fd152496e1df05",
"ac": "947acc2c6bc3c1d1415ac92bab695380",
"ad": "e31273dec921e187f5c0dc73af5b6102",
"ae": "43390a47d138593095e9a6775ae7dc75",
"af": "082881464ca6dcbaa585f72de1ac14fd",
"ag": "de2511aaebd640bd5e5404c40d7494cb",
"ah": "124c5fbcda6871f27524ae59480dabc5",
"ai": "12b5f94824b7f102df30a63b1139fc57",
},
"etm_sr": {
"aa": "baa36a9b8e42e234bb44ab4046f8f2ac",
"ab": "9fb0f948c76154caabe086d2d0008fdf",
"ac": "99a55367178373805d357a096d68e418",
"ad": "59d53a643b9e28911246d4609744ef25",
"ae": "7abfcfc57528cb9c619c66ee307a2cc9",
"af": "bb23cf26cc9fe156e7a68589ec69f43e",
"ag": "97347e5a81d24c93cf33d99bb46a5b91",
},
"oli_tirs_toa": {
"aa": "4711369b861c856ebfadbc861e928d3a",
"ab": "660a96cda1caf54df837c4b3c6c703f6",
"ac": "c9b6a1117916ba318ac3e310447c60dc",
"ad": "b8502e9e92d4a7765a287d21d7c9146c",
"ae": "5c11c14cfe45f78de4f6d6faf03f3146",
"af": "5b0ed3901be1000137ddd3a6d58d5109",
"ag": "a3b6734f8fe6763dcf311c9464a05d5b",
"ah": "5e55f92e3238a8ab3e471be041f8111b",
"ai": "e20617f73d0232a0c0472ce336d4c92f",
},
"oli_sr": {
"aa": "ca338511c9da4dcbfddda28b38ca9e0a",
"ab": "7f4100aa9791156958dccf1bb2a88ae0",
"ac": "6b0f18be2b63ba9da194cc7886dbbc01",
"ad": "57efbcc894d8da8c4975c29437d8b775",
"ae": "2594a0a856897f3f5a902c830186872d",
"af": "a03839311a2b3dc17dfb9fb9bc4f9751",
"ag": "6a329d8fd9fdd591e400ab20f9d11dea",
},
}
def __init__(
self,
root: str = "data",
split: str = "oli_sr",
seasons: int = 1,
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
download: bool = False,
checksum: bool = False,
) -> None:
"""Initialize a new SSL4EOL instance.
Args:
root: root directory where dataset can be found
split: one of ['tm_toa', 'etm_sr', 'etm_toa', 'oli_tirs_toa', 'oli_sr']
split: one of ['tm_toa', 'etm_toa', 'etm_sr', 'oli_tirs_toa', 'oli_sr']
seasons: number of seasonal patches to sample per location, 1--4
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
download: if True, download dataset and store it in the root directory
checksum: if True, check the MD5 after downloading files (may be slow)
Raises:
AssertionError: if ``split`` argument is invalid
AssertionError: if any arguments are invalid
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
"""
assert split in self.metadata
assert seasons in range(1, 5)
self.root = root
self.subdir = os.path.join(root, f"ssl4eo_l_{split}")
self.split = split
self.seasons = seasons
self.transforms = transforms
self.download = download
self.checksum = checksum
self.scenes = sorted(os.listdir(root))
self._verify()
self.scenes = sorted(os.listdir(self.subdir))
def __getitem__(self, index: int) -> dict[str, Tensor]:
"""Return an index within the dataset.
@ -130,7 +201,7 @@ class SSL4EOL(NonGeoDataset):
Returns:
image sample
"""
root = os.path.join(self.root, self.scenes[index])
root = os.path.join(self.subdir, self.scenes[index])
subdirs = os.listdir(root)
subdirs = random.sample(subdirs, self.seasons)
@ -157,6 +228,62 @@ class SSL4EOL(NonGeoDataset):
"""
return len(self.scenes)
def _verify(self) -> None:
"""Verify the integrity of the dataset.
Raises:
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
"""
# Check if the extracted files already exist
path = os.path.join(self.subdir, "00000*", "*", "all_bands.tif")
if glob.glob(path):
return
# Check if the tar.gz files have already been downloaded
exists = []
for suffix in self.checksums[self.split]:
path = self.subdir + f".tar.gz{suffix}"
exists.append(os.path.exists(path))
if all(exists):
self._extract()
return
# Check if the user requested to download the dataset
if not self.download:
raise RuntimeError(
f"Dataset not found in `root={self.root}` and `download=False`, "
"either specify a different `root` directory or use `download=True` "
"to automatically download the dataset."
)
# Download the dataset
self._download()
self._extract()
def _download(self) -> None:
"""Download the dataset."""
for suffix, md5 in self.checksums[self.split].items():
download_url(
self.url.format(self.split, suffix),
self.root,
md5=md5 if self.checksum else None,
)
def _extract(self) -> None:
"""Extract the dataset."""
# Concatenate all tarballs together
chunk_size = 2**15 # same as torchvision
path = self.subdir + ".tar.gz"
with open(path, "wb") as f:
for suffix in self.checksums[self.split]:
with open(path + suffix, "rb") as g:
while chunk := g.read(chunk_size):
f.write(chunk)
# Extract the concatenated tarball
extract_archive(path)
def plot(
self,
sample: dict[str, Tensor],