Support for radiant-mlhub 0.5+ (#1102)

* update datasets and tests to support radiant-mlhub>=0.5

* add test coverage for nasa_marine_debris corrupted cases

* style fixes

* Correct return type in test_nasa_marine_debris.py

* Update setup.cfg to limit radiant-mlhub version

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>

* update radiant-mlhub upper bound to <0.6

* Update environment.yml to not upper bound radiant-mlhub

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>

---------

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
SpontaneousDuck 2023-02-17 15:30:40 -05:00 committed by Caleb Robinson
Parent 0bcb6914a0
Commit 424751ab52
13 changed files with 93 additions and 46 deletions
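
In short: radiant-mlhub 0.5 changed how dataset downloads work (radiantearth/radiant-mlhub#104), so the datasets below stop making one `download_radiant_mlhub_dataset` call per dataset and instead list their STAC collections and fetch each with `download_radiant_mlhub_collection`. A minimal before/after sketch using the NASA Marine Debris identifiers from the diffs; the standalone script framing and the `root` path are illustrative, not part of this commit:

    from torchgeo.datasets.utils import (
        download_radiant_mlhub_collection,
        download_radiant_mlhub_dataset,
    )

    root = "data/nasa_marine_debris"  # hypothetical download directory
    api_key = None  # falls back to the configured MLHub profile if unset

    # Before (radiant-mlhub <0.5): one call fetched the whole dataset.
    download_radiant_mlhub_dataset("nasa_marine_debris", root, api_key)

    # After: each dataset enumerates its collections and fetches them
    # one at a time, matching the new radiant-mlhub download behavior.
    for collection_id in ["nasa_marine_debris_source", "nasa_marine_debris_labels"]:
        download_radiant_mlhub_collection(collection_id, root, api_key)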

.github/dependabot.yml

@@ -9,9 +9,6 @@ updates:
     schedule:
       interval: "daily"
     ignore:
-      # radiant-mlhub 0.5+ changed download behavior:
-      # https://github.com/radiantearth/radiant-mlhub/pull/104
-      - dependency-name: "radiant-mlhub"
       # setuptools releases new versions almost daily
       - dependency-name: "setuptools"
         update-types: ["version-update:semver-patch"]

environment.yml

@@ -36,6 +36,7 @@ dependencies:
     - pytest-cov>=2.4
     - pytorch-lightning>=1.5.1
     - git+https://github.com/pytorch/pytorch_sphinx_theme
     - pyupgrade>=2.4
+    - radiant-mlhub>=0.2.1
     - rtree>=1
     - scikit-image>=0.18

setup.cfg

@@ -84,13 +84,11 @@ datasets =
     pycocotools>=2.0.1,<3
     # pyvista 0.25.2 required for wheels
     pyvista>=0.25.2,<0.39
-    # radiant-mlhub 0.2.1+ required for api_key bugfix:
-    # https://github.com/radiantearth/radiant-mlhub/pull/48
+    # radiant-mlhub 0.5+ changed download behavior:
+    # https://github.com/radiantearth/radiant-mlhub/pull/104
-    radiant-mlhub>=0.2.1,<0.5
-    # rarfile 3+ required for correct Rar file detection
-    rarfile>=3,<5
+    radiant-mlhub>=0.2.1,<0.6
+    # rarfile 4+ required for wheels
+    rarfile>=4,<5
     # scikit-image 0.18+ required for numpy 1.17+ compatibility
     # https://github.com/scikit-image/scikit-image/issues/3655
     scikit-image>=0.18,<0.20

tests/datasets/test_benin_cashews.py

@@ -16,15 +16,15 @@ from torch.utils.data import ConcatDataset
 from torchgeo.datasets import BeninSmallHolderCashews
 
 
-class Dataset:
+class Collection:
     def download(self, output_dir: str, **kwargs: str) -> None:
         glob_path = os.path.join("tests", "data", "ts_cashew_benin", "*.tar.gz")
         for tarball in glob.iglob(glob_path):
             shutil.copy(tarball, output_dir)
 
 
-def fetch(dataset_id: str, **kwargs: str) -> Dataset:
-    return Dataset()
+def fetch(dataset_id: str, **kwargs: str) -> Collection:
+    return Collection()
 
 
 class TestBeninSmallHolderCashews:
@@ -33,7 +33,7 @@ class TestBeninSmallHolderCashews:
         self, monkeypatch: MonkeyPatch, tmp_path: Path
     ) -> BeninSmallHolderCashews:
         radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.2.1")
-        monkeypatch.setattr(radiant_mlhub.Dataset, "fetch", fetch)
+        monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch)
         source_md5 = "255efff0f03bc6322470949a09bc76db"
         labels_md5 = "ed2195d93ca6822d48eb02bc3e81c127"
         monkeypatch.setitem(BeninSmallHolderCashews.image_meta, "md5", source_md5)
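
Every test module in this PR uses the same offline stub: `radiant_mlhub.Collection.fetch` is monkeypatched to return an object whose `download` copies canned tarballs from `tests/data` instead of touching the network. Condensed into one hypothetical, self-contained test (names like `FakeCollection` and `fake_fetch` are mine, not the diff's):

    import glob
    import os
    import shutil
    from pathlib import Path

    import pytest
    from _pytest.monkeypatch import MonkeyPatch


    class FakeCollection:
        # Mimics the download interface of radiant_mlhub.Collection.
        def download(self, output_dir: str, **kwargs: str) -> None:
            for tarball in glob.iglob(os.path.join("tests", "data", "*.tar.gz")):
                shutil.copy(tarball, output_dir)


    def fake_fetch(collection_id: str, **kwargs: str) -> FakeCollection:
        return FakeCollection()


    def test_offline_download(monkeypatch: MonkeyPatch, tmp_path: Path) -> None:
        radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.5")
        monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fake_fetch)
        # Any dataset constructed with download=True now "downloads" locally.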

tests/datasets/test_cloud_cover.py

@@ -15,7 +15,7 @@ from _pytest.monkeypatch import MonkeyPatch
 from torchgeo.datasets import CloudCoverDetection
 
 
-class Dataset:
+class Collection:
     def download(self, output_dir: str, **kwargs: str) -> None:
         glob_path = os.path.join(
             "tests", "data", "ref_cloud_cover_detection_challenge_v1", "*.tar.gz"
@@ -24,15 +24,15 @@ class Dataset:
             shutil.copy(tarball, output_dir)
 
 
-def fetch(dataset_id: str, **kwargs: str) -> Dataset:
-    return Dataset()
+def fetch(dataset_id: str, **kwargs: str) -> Collection:
+    return Collection()
 
 
 class TestCloudCoverDetection:
     @pytest.fixture
     def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> CloudCoverDetection:
         radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.2.1")
-        monkeypatch.setattr(radiant_mlhub.Dataset, "fetch", fetch)
+        monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch)
         test_image_meta = {
             "filename": "ref_cloud_cover_detection_challenge_v1_test_source.tar.gz",

tests/datasets/test_cv4a_kenya_crop_type.py

@@ -16,7 +16,7 @@ from torch.utils.data import ConcatDataset
 from torchgeo.datasets import CV4AKenyaCropType
 
 
-class Dataset:
+class Collection:
     def download(self, output_dir: str, **kwargs: str) -> None:
         glob_path = os.path.join(
             "tests", "data", "ref_african_crops_kenya_02", "*.tar.gz"
@@ -25,15 +25,15 @@ class Dataset:
             shutil.copy(tarball, output_dir)
 
 
-def fetch(dataset_id: str, **kwargs: str) -> Dataset:
-    return Dataset()
+def fetch(dataset_id: str, **kwargs: str) -> Collection:
+    return Collection()
 
 
 class TestCV4AKenyaCropType:
     @pytest.fixture
     def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> CV4AKenyaCropType:
         radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.2.1")
-        monkeypatch.setattr(radiant_mlhub.Dataset, "fetch", fetch)
+        monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch)
         source_md5 = "7f4dcb3f33743dddd73f453176308bfb"
         labels_md5 = "95fc59f1d94a85ec00931d4d1280bec9"
         monkeypatch.setitem(CV4AKenyaCropType.image_meta, "md5", source_md5)

tests/datasets/test_cyclone.py

@@ -17,14 +17,14 @@ from torch.utils.data import ConcatDataset
 from torchgeo.datasets import TropicalCyclone
 
 
-class Dataset:
+class Collection:
     def download(self, output_dir: str, **kwargs: str) -> None:
         for tarball in glob.iglob(os.path.join("tests", "data", "cyclone", "*.tar.gz")):
             shutil.copy(tarball, output_dir)
 
 
-def fetch(collection_id: str, **kwargs: str) -> Dataset:
-    return Dataset()
+def fetch(collection_id: str, **kwargs: str) -> Collection:
+    return Collection()
 
 
 class TestTropicalCyclone:
@@ -33,7 +33,7 @@ class TestTropicalCyclone:
         self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
     ) -> TropicalCyclone:
         radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.2.1")
-        monkeypatch.setattr(radiant_mlhub.Dataset, "fetch", fetch)
+        monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch)
         md5s = {
             "train": {
                 "source": "2b818e0a0873728dabf52c7054a0ce4c",

tests/datasets/test_nasa_marine_debris.py

@@ -15,23 +15,35 @@ from _pytest.monkeypatch import MonkeyPatch
 from torchgeo.datasets import NASAMarineDebris
 
 
-class Dataset:
+class Collection:
     def download(self, output_dir: str, **kwargs: str) -> None:
         glob_path = os.path.join("tests", "data", "nasa_marine_debris", "*.tar.gz")
         for tarball in glob.iglob(glob_path):
             shutil.copy(tarball, output_dir)
 
 
-def fetch(dataset_id: str, **kwargs: str) -> Dataset:
-    return Dataset()
+def fetch(collection_id: str, **kwargs: str) -> Collection:
+    return Collection()
+
+
+class Collection_corrupted:
+    def download(self, output_dir: str, **kwargs: str) -> None:
+        filenames = NASAMarineDebris.filenames
+        for filename in filenames:
+            with open(os.path.join(output_dir, filename), "w") as f:
+                f.write("bad")
+
+
+def fetch_corrupted(collection_id: str, **kwargs: str) -> Collection_corrupted:
+    return Collection_corrupted()
 
 
 class TestNASAMarineDebris:
     @pytest.fixture()
     def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> NASAMarineDebris:
         radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.2.1")
-        monkeypatch.setattr(radiant_mlhub.Dataset, "fetch", fetch)
-        md5s = ["fe8698d1e68b3f24f0b86b04419a797d", "d8084f5a72778349e07ac90ec1e1d990"]
+        monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch)
+        md5s = ["6f4f0d2313323950e45bf3fc0c09b5de", "540cf1cf4fd2c13b609d0355abe955d7"]
         monkeypatch.setattr(NASAMarineDebris, "md5s", md5s)
         root = str(tmp_path)
         transforms = nn.Identity()
@@ -58,9 +70,25 @@ class TestNASAMarineDebris:
     ) -> None:
         shutil.rmtree(dataset.root)
         os.makedirs(str(tmp_path), exist_ok=True)
-        Dataset().download(output_dir=str(tmp_path))
+        Collection().download(output_dir=str(tmp_path))
         NASAMarineDebris(root=str(tmp_path), download=False)
 
+    def test_corrupted_previously_downloaded(self, tmp_path: Path) -> None:
+        filenames = NASAMarineDebris.filenames
+        for filename in filenames:
+            with open(os.path.join(tmp_path, filename), "w") as f:
+                f.write("bad")
+        with pytest.raises(RuntimeError, match="Dataset checksum mismatch."):
+            NASAMarineDebris(root=str(tmp_path), download=False, checksum=True)
+
+    def test_corrupted_new_download(
+        self, tmp_path: Path, monkeypatch: MonkeyPatch
+    ) -> None:
+        with pytest.raises(RuntimeError, match="Dataset checksum mismatch."):
+            radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.2.1")
+            monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch_corrupted)
+            NASAMarineDebris(root=str(tmp_path), download=True, checksum=True)
+
     def test_not_downloaded(self, tmp_path: Path) -> None:
         err = "Dataset not found in `root` directory and `download=False`, "
         "either specify a different `root` directory or use `download=True` "

torchgeo/datasets/benin_cashews.py

@@ -17,7 +17,7 @@ from rasterio.crs import CRS
 from torch import Tensor
 
 from .geo import NonGeoDataset
-from .utils import check_integrity, download_radiant_mlhub_dataset, extract_archive
+from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive
 
 # TODO: read geospatial information from stac.json files
@@ -56,6 +56,7 @@ class BeninSmallHolderCashews(NonGeoDataset):
     """
 
     dataset_id = "ts_cashew_benin"
+    collection_ids = ["ts_cashew_benin_source", "ts_cashew_benin_labels"]
     image_meta = {
         "filename": "ts_cashew_benin_source.tar.gz",
         "md5": "957272c86e518a925a4e0d90dab4f92d",
@@ -416,7 +417,8 @@ class BeninSmallHolderCashews(NonGeoDataset):
             print("Files already downloaded and verified")
             return
 
-        download_radiant_mlhub_dataset(self.dataset_id, self.root, api_key)
+        for collection_id in self.collection_ids:
+            download_radiant_mlhub_collection(collection_id, self.root, api_key)
 
         image_archive_path = os.path.join(self.root, self.image_meta["filename"])
         target_archive_path = os.path.join(self.root, self.target_meta["filename"])

torchgeo/datasets/cloud_cover.py

@@ -14,7 +14,7 @@ import torch
 from torch import Tensor
 
 from .geo import NonGeoDataset
-from .utils import check_integrity, download_radiant_mlhub_dataset, extract_archive
+from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive
 
 # TODO: read geospatial information from stac.json files
@@ -54,7 +54,12 @@ class CloudCoverDetection(NonGeoDataset):
     .. versionadded:: 0.4
     """
 
-    dataset_id = "ref_cloud_cover_detection_challenge_v1"
+    collection_ids = [
+        "ref_cloud_cover_detection_challenge_v1_train_source",
+        "ref_cloud_cover_detection_challenge_v1_train_labels",
+        "ref_cloud_cover_detection_challenge_v1_test_source",
+        "ref_cloud_cover_detection_challenge_v1_test_labels",
+    ]
     image_meta = {
         "train": {
@@ -332,7 +337,8 @@ class CloudCoverDetection(NonGeoDataset):
             print("Files already downloaded and verified")
             return
 
-        download_radiant_mlhub_dataset(self.dataset_id, self.root, api_key)
+        for collection_id in self.collection_ids:
+            download_radiant_mlhub_collection(collection_id, self.root, api_key)
 
         image_archive_path = os.path.join(
             self.root, self.image_meta[self.split]["filename"]

torchgeo/datasets/cv4a_kenya_crop_type.py

@@ -15,7 +15,7 @@ from PIL import Image
 from torch import Tensor
 
 from .geo import NonGeoDataset
-from .utils import check_integrity, download_radiant_mlhub_dataset, extract_archive
+from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive
 
 # TODO: read geospatial information from stac.json files
@@ -56,7 +56,10 @@ class CV4AKenyaCropType(NonGeoDataset):
        imagery and labels from the Radiant Earth MLHub
     """
 
-    dataset_id = "ref_african_crops_kenya_02"
+    collection_ids = [
+        "ref_african_crops_kenya_02_labels",
+        "ref_african_crops_kenya_02_source",
+    ]
     image_meta = {
         "filename": "ref_african_crops_kenya_02_source.tar.gz",
         "md5": "9c2004782f6dc83abb1bf45ba4d0da46",
@@ -394,7 +397,8 @@ class CV4AKenyaCropType(NonGeoDataset):
             print("Files already downloaded and verified")
             return
 
-        download_radiant_mlhub_dataset(self.dataset_id, self.root, api_key)
+        for collection_id in self.collection_ids:
+            download_radiant_mlhub_collection(collection_id, self.root, api_key)
 
         image_archive_path = os.path.join(self.root, self.image_meta["filename"])
         target_archive_path = os.path.join(self.root, self.target_meta["filename"])

torchgeo/datasets/cyclone.py

@@ -15,7 +15,7 @@ from PIL import Image
 from torch import Tensor
 
 from .geo import NonGeoDataset
-from .utils import check_integrity, download_radiant_mlhub_dataset, extract_archive
+from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive
 
 
 class TropicalCyclone(NonGeoDataset):
@@ -45,6 +45,12 @@ class TropicalCyclone(NonGeoDataset):
     """
 
     collection_id = "nasa_tropical_storm_competition"
+    collection_ids = [
+        "nasa_tropical_storm_competition_train_source",
+        "nasa_tropical_storm_competition_test_source",
+        "nasa_tropical_storm_competition_train_labels",
+        "nasa_tropical_storm_competition_test_labels",
+    ]
     md5s = {
         "train": {
             "source": "97e913667a398704ea8d28196d91dad6",
@@ -207,7 +213,8 @@ class TropicalCyclone(NonGeoDataset):
             print("Files already downloaded and verified")
             return
 
-        download_radiant_mlhub_dataset(self.collection_id, self.root, api_key)
+        for collection_id in self.collection_ids:
+            download_radiant_mlhub_collection(collection_id, self.root, api_key)
 
         for split, resources in self.md5s.items():
             for resource_type in resources:

torchgeo/datasets/nasa_marine_debris.py

@@ -14,7 +14,7 @@ from torch import Tensor
 from torchvision.utils import draw_bounding_boxes
 
 from .geo import NonGeoDataset
-from .utils import download_radiant_mlhub_dataset, extract_archive
+from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive
 
 
 class NASAMarineDebris(NonGeoDataset):
@@ -51,7 +51,7 @@ class NASAMarineDebris(NonGeoDataset):
     .. versionadded:: 0.2
     """
 
-    dataset_id = "nasa_marine_debris"
+    collection_ids = ["nasa_marine_debris_source", "nasa_marine_debris_labels"]
     directories = ["nasa_marine_debris_source", "nasa_marine_debris_labels"]
     filenames = ["nasa_marine_debris_source.tar.gz", "nasa_marine_debris_labels.tar.gz"]
     md5s = ["fe8698d1e68b3f24f0b86b04419a797d", "d8084f5a72778349e07ac90ec1e1d990"]
@@ -189,9 +189,11 @@ class NASAMarineDebris(NonGeoDataset):
         # Check if zip file already exists (if so then extract)
         exists = []
-        for filename in self.filenames:
+        for filename, md5 in zip(self.filenames, self.md5s):
             filepath = os.path.join(self.root, filename)
             if os.path.exists(filepath):
+                if self.checksum and not check_integrity(filepath, md5):
+                    raise RuntimeError("Dataset checksum mismatch.")
                 exists.append(True)
                 extract_archive(filepath)
             else:
@@ -208,11 +210,13 @@ class NASAMarineDebris(NonGeoDataset):
                 "to automatically download the dataset."
             )
 
-        # TODO: need a checksum check in here post downloading
         # Download and extract the dataset
-        download_radiant_mlhub_dataset(self.dataset_id, self.root, self.api_key)
-        for filename in self.filenames:
+        for collection_id in self.collection_ids:
+            download_radiant_mlhub_collection(collection_id, self.root, self.api_key)
+        for filename, md5 in zip(self.filenames, self.md5s):
             filepath = os.path.join(self.root, filename)
+            if self.checksum and not check_integrity(filepath, md5):
+                raise RuntimeError("Dataset checksum mismatch.")
             extract_archive(filepath)
 
     def plot(
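
The checksum logic added above reduces to: hash each archive with MD5 before extracting it, whether it was already on disk or freshly downloaded. A rough standalone equivalent of what `check_integrity` provides (helper names `md5_matches` and `verify_archives` are mine; TorchGeo reuses torchvision's `check_integrity` rather than rolling its own):

    import hashlib
    import os


    def md5_matches(filepath: str, md5: str, chunk_size: int = 1 << 20) -> bool:
        # Stream the file so multi-GB tarballs never sit in memory at once.
        digest = hashlib.md5()
        with open(filepath, "rb") as f:
            for chunk in iter(lambda: f.read(chunk_size), b""):
                digest.update(chunk)
        return digest.hexdigest() == md5


    def verify_archives(root: str, filenames: list, md5s: list, checksum: bool) -> None:
        # Mirrors the loop in _verify/_download: pair each archive with its
        # expected digest and fail loudly on any mismatch.
        for filename, md5 in zip(filenames, md5s):
            filepath = os.path.join(root, filename)
            if checksum and not md5_matches(filepath, md5):
                raise RuntimeError("Dataset checksum mismatch.")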