Add download_radiant_mlhub_collection (#152)

* Add download_radiant_mlhub_collection

* Add tests

* Rename download_radiant_mlhub -> download_radiant_mlhub_dataset

* Update test_utils.py

* Update test_utils.py

* Update test_utils.py

Co-authored-by: Caleb Robinson <calebrob6@gmail.com>
This commit is contained in:
Ashwin Nair 2021-09-20 11:23:42 +04:00 коммит произвёл GitHub
Родитель 459524fedc
Коммит 30a082d33c
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
9 изменённых файлов: 77 добавлений и 23 удалений

Просмотреть файл

@@ -23,7 +23,7 @@ class Dataset:
shutil.copy(tarball, output_dir)
def fetch(collection_id: str, **kwargs: str) -> Dataset:
def fetch(dataset_id: str, **kwargs: str) -> Dataset:
return Dataset()

Просмотреть файл

@@ -25,7 +25,7 @@ class Dataset:
shutil.copy(tarball, output_dir)
def fetch(collection_id: str, **kwargs: str) -> Dataset:
def fetch(dataset_id: str, **kwargs: str) -> Dataset:
return Dataset()

Просмотреть файл

@@ -19,17 +19,17 @@ TEST_DATA_DIR = "tests/data/spacenet"
class Dataset:
def __init__(self, collection_id: str) -> None:
self.collection_id = collection_id
def __init__(self, dataset_id: str) -> None:
self.dataset_id = dataset_id
def download(self, output_dir: str, **kwargs: str) -> None:
glob_path = os.path.join(TEST_DATA_DIR, self.collection_id, "*.tar.gz")
glob_path = os.path.join(TEST_DATA_DIR, self.dataset_id, "*.tar.gz")
for tarball in glob.iglob(glob_path):
shutil.copy(tarball, output_dir)
def fetch(collection_id: str, **kwargs: str) -> Dataset:
return Dataset(collection_id)
def fetch_dataset(dataset_id: str, **kwargs: str) -> Dataset:
return Dataset(dataset_id)
class TestSpaceNet1:
@@ -42,7 +42,7 @@ class TestSpaceNet1:
) -> SpaceNet1:
radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.2.1")
monkeypatch.setattr( # type: ignore[attr-defined]
radiant_mlhub.Dataset, "fetch", fetch
radiant_mlhub.Dataset, "fetch", fetch_dataset
)
test_md5 = "829652022c2df4511ee4ae05bc290250"
monkeypatch.setattr(SpaceNet1, "md5", test_md5) # type: ignore[attr-defined]

Просмотреть файл

@@ -23,7 +23,8 @@ from torchgeo.datasets.utils import (
collate_dict,
disambiguate_timestamp,
download_and_extract_archive,
download_radiant_mlhub,
download_radiant_mlhub_collection,
download_radiant_mlhub_dataset,
extract_archive,
working_dir,
)
@@ -52,10 +53,23 @@ class Dataset:
shutil.copy(tarball, output_dir)
def fetch(collection_id: str, **kwargs: str) -> Dataset:
class Collection:
    """Mock of ``radiant_mlhub.Collection`` that "downloads" by copying
    local test fixtures instead of contacting the Radiant MLHub API."""

    def download(self, output_dir: str, **kwargs: str) -> None:
        """Copy every test tarball for the kenya_02 collection into *output_dir*.

        Args:
            output_dir: destination directory for the copied archives
            **kwargs: ignored; accepted for signature compatibility with the
                real API (e.g. ``api_key``)
        """
        glob_path = os.path.join(
            "tests", "data", "ref_african_crops_kenya_02", "*.tar.gz"
        )
        for tarball in glob.iglob(glob_path):
            shutil.copy(tarball, output_dir)
def fetch_dataset(dataset_id: str, **kwargs: str) -> Dataset:
    """Mock of ``radiant_mlhub.Dataset.fetch``.

    Ignores *dataset_id* and any keyword arguments (e.g. ``api_key``) and
    returns a stub ``Dataset`` defined earlier in this test module.
    """
    return Dataset()
def fetch_collection(collection_id: str, **kwargs: str) -> Collection:
    """Mock of ``radiant_mlhub.Collection.fetch``.

    Ignores *collection_id* and any keyword arguments (e.g. ``api_key``) and
    returns a stub ``Collection`` defined earlier in this test module.
    """
    return Collection()
def download_url(url: str, root: str, *args: str) -> None:
    """Mock of a torchvision-style ``download_url``.

    Treats *url* as a local filesystem path and copies it to *root* instead
    of performing a network download.

    Args:
        url: source path of the file to "download"
        root: destination directory (or destination path) for the copy
        *args: ignored; accepted for signature compatibility
    """
    shutil.copy(url, root)
@@ -110,14 +124,24 @@ def test_download_and_extract_archive(
)
def test_download_radiant_mlhub(
def test_download_radiant_mlhub_dataset(
tmp_path: Path, monkeypatch: Generator[MonkeyPatch, None, None]
) -> None:
radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.2.1")
monkeypatch.setattr( # type: ignore[attr-defined]
radiant_mlhub.Dataset, "fetch", fetch
radiant_mlhub.Dataset, "fetch", fetch_dataset
)
download_radiant_mlhub("", str(tmp_path))
download_radiant_mlhub_dataset("", str(tmp_path))
def test_download_radiant_mlhub_collection(
    tmp_path: Path, monkeypatch: Generator[MonkeyPatch, None, None]
) -> None:
    """download_radiant_mlhub_collection should delegate to
    ``radiant_mlhub.Collection.fetch(...).download(...)``."""
    # Skip when the optional radiant_mlhub dependency is absent or too old.
    radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.2.1")
    # Replace the real fetch with the local Collection stub so no network I/O occurs.
    monkeypatch.setattr(  # type: ignore[attr-defined]
        radiant_mlhub.Collection, "fetch", fetch_collection
    )
    download_radiant_mlhub_collection("", str(tmp_path))
def test_missing_radiant_mlhub(mock_missing_module: None) -> None:
@@ -125,7 +149,14 @@ def test_missing_radiant_mlhub(mock_missing_module: None) -> None:
ImportError,
match="radiant_mlhub is not installed and is required to download this dataset",
):
download_radiant_mlhub("", "")
download_radiant_mlhub_dataset("", "")
with pytest.raises(
ImportError,
match="radiant_mlhub is not installed and is required to download this"
+ " collection",
):
download_radiant_mlhub_collection("", "")
class TestBoundingBox:

Просмотреть файл

@@ -15,7 +15,7 @@ import torch
from torch import Tensor
from .geo import VisionDataset
from .utils import check_integrity, download_radiant_mlhub, extract_archive
from .utils import check_integrity, download_radiant_mlhub_dataset, extract_archive
# TODO: read geospatial information from stac.json files
@@ -407,7 +407,7 @@ class BeninSmallHolderCashews(VisionDataset):
print("Files already downloaded and verified")
return
download_radiant_mlhub(self.dataset_id, self.root, api_key)
download_radiant_mlhub_dataset(self.dataset_id, self.root, api_key)
image_archive_path = os.path.join(self.root, self.image_meta["filename"])
target_archive_path = os.path.join(self.root, self.target_meta["filename"])

Просмотреть файл

@@ -14,7 +14,7 @@ from PIL import Image
from torch import Tensor
from .geo import VisionDataset
from .utils import check_integrity, download_radiant_mlhub, extract_archive
from .utils import check_integrity, download_radiant_mlhub_dataset, extract_archive
# TODO: read geospatial information from stac.json files
@@ -396,7 +396,7 @@ class CV4AKenyaCropType(VisionDataset):
print("Files already downloaded and verified")
return
download_radiant_mlhub(self.dataset_id, self.root, api_key)
download_radiant_mlhub_dataset(self.dataset_id, self.root, api_key)
image_archive_path = os.path.join(self.root, self.image_meta["filename"])
target_archive_path = os.path.join(self.root, self.target_meta["filename"])

Просмотреть файл

@@ -14,7 +14,7 @@ from PIL import Image
from torch import Tensor
from .geo import VisionDataset
from .utils import check_integrity, download_radiant_mlhub, extract_archive
from .utils import check_integrity, download_radiant_mlhub_dataset, extract_archive
class TropicalCycloneWindEstimation(VisionDataset):
@@ -201,7 +201,7 @@ class TropicalCycloneWindEstimation(VisionDataset):
print("Files already downloaded and verified")
return
download_radiant_mlhub(self.collection_id, self.root, api_key)
download_radiant_mlhub_dataset(self.collection_id, self.root, api_key)
for split, resources in self.md5s.items():
for resource_type in resources:

Просмотреть файл

@@ -18,7 +18,7 @@ from torch import Tensor
from torchgeo.datasets.geo import VisionDataset
from torchgeo.datasets.utils import (
check_integrity,
download_radiant_mlhub,
download_radiant_mlhub_dataset,
extract_archive,
)
@@ -227,7 +227,7 @@ class SpaceNet1(VisionDataset):
print("Files already downloaded")
return
download_radiant_mlhub(self.dataset_id, self.root, api_key)
download_radiant_mlhub_dataset(self.dataset_id, self.root, api_key)
archive_path = os.path.join(self.root, self.foldername + ".tar.gz")
if (
self.checksum

Просмотреть файл

@@ -125,7 +125,7 @@ def download_and_extract_archive(
extract_archive(archive, extract_root)
def download_radiant_mlhub(
def download_radiant_mlhub_dataset(
dataset_id: str, download_root: str, api_key: Optional[str] = None
) -> None:
"""Download a dataset from Radiant Earth.
@@ -148,6 +148,29 @@
dataset.download(output_dir=download_root, api_key=api_key)
def download_radiant_mlhub_collection(
    collection_id: str, download_root: str, api_key: Optional[str] = None
) -> None:
    """Download a collection from Radiant Earth.

    Args:
        collection_id: the ID of the collection to fetch
        download_root: directory to download to
        api_key: the API key to use for all requests from the session. Can also be
            passed in via the ``MLHUB_API_KEY`` environment variable, or configured in
            ``~/.mlhub/profiles``.

    Raises:
        ImportError: if the optional ``radiant_mlhub`` package is not installed
    """
    try:
        # Imported lazily so that radiant_mlhub stays an optional dependency.
        import radiant_mlhub
    except ImportError:
        raise ImportError(
            "radiant_mlhub is not installed and is required to download this collection"
        )

    collection = radiant_mlhub.Collection.fetch(collection_id, api_key=api_key)
    collection.download(output_dir=download_root, api_key=api_key)
class BoundingBox(Tuple[float, float, float, float, float, float]):
"""Data class for indexing spatiotemporal data.