Mirror of https://github.com/microsoft/torchgeo.git
Add download_radiant_mlhub_collection (#152)
* Add download_radiant_mlhub_collection
* Add tests
* Rename download_radiant_mlhub -> download_radiant_mlhub_dataset
* Update test_utils.py
* Update test_utils.py
* Update test_utils.py

Co-authored-by: Caleb Robinson <calebrob6@gmail.com>
Parent: 459524fedc
Commit: 30a082d33c
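In short, this commit renames the dataset-level helper and adds a collection-level counterpart. A minimal sketch of the resulting API (the IDs below are illustrative placeholders, not taken from this diff):

    from torchgeo.datasets.utils import (
        download_radiant_mlhub_collection,
        download_radiant_mlhub_dataset,
    )

    # Download a whole dataset (every collection it contains); illustrative ID.
    download_radiant_mlhub_dataset("nasa_tropical_storm_competition", "data")

    # Download a single collection within a dataset; illustrative ID.
    download_radiant_mlhub_collection("ref_african_crops_kenya_02_labels", "data")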
@@ -23,7 +23,7 @@ class Dataset:
             shutil.copy(tarball, output_dir)
 
 
-def fetch(collection_id: str, **kwargs: str) -> Dataset:
+def fetch(dataset_id: str, **kwargs: str) -> Dataset:
     return Dataset()
 
 
@@ -25,7 +25,7 @@ class Dataset:
             shutil.copy(tarball, output_dir)
 
 
-def fetch(collection_id: str, **kwargs: str) -> Dataset:
+def fetch(dataset_id: str, **kwargs: str) -> Dataset:
     return Dataset()
 
 
@@ -19,17 +19,17 @@ TEST_DATA_DIR = "tests/data/spacenet"
 
 
 class Dataset:
-    def __init__(self, collection_id: str) -> None:
-        self.collection_id = collection_id
+    def __init__(self, dataset_id: str) -> None:
+        self.dataset_id = dataset_id
 
     def download(self, output_dir: str, **kwargs: str) -> None:
-        glob_path = os.path.join(TEST_DATA_DIR, self.collection_id, "*.tar.gz")
+        glob_path = os.path.join(TEST_DATA_DIR, self.dataset_id, "*.tar.gz")
         for tarball in glob.iglob(glob_path):
             shutil.copy(tarball, output_dir)
 
 
-def fetch(collection_id: str, **kwargs: str) -> Dataset:
-    return Dataset(collection_id)
+def fetch_dataset(dataset_id: str, **kwargs: str) -> Dataset:
+    return Dataset(dataset_id)
 
 
 class TestSpaceNet1:
@@ -42,7 +42,7 @@ class TestSpaceNet1:
     ) -> SpaceNet1:
         radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.2.1")
         monkeypatch.setattr(  # type: ignore[attr-defined]
-            radiant_mlhub.Dataset, "fetch", fetch
+            radiant_mlhub.Dataset, "fetch", fetch_dataset
         )
         test_md5 = "829652022c2df4511ee4ae05bc290250"
         monkeypatch.setattr(SpaceNet1, "md5", test_md5)  # type: ignore[attr-defined]
@@ -23,7 +23,8 @@ from torchgeo.datasets.utils import (
     collate_dict,
     disambiguate_timestamp,
     download_and_extract_archive,
-    download_radiant_mlhub,
+    download_radiant_mlhub_collection,
+    download_radiant_mlhub_dataset,
     extract_archive,
     working_dir,
 )
@@ -52,10 +53,23 @@ class Dataset:
             shutil.copy(tarball, output_dir)
 
 
-def fetch(collection_id: str, **kwargs: str) -> Dataset:
+class Collection:
+    def download(self, output_dir: str, **kwargs: str) -> None:
+        glob_path = os.path.join(
+            "tests", "data", "ref_african_crops_kenya_02", "*.tar.gz"
+        )
+        for tarball in glob.iglob(glob_path):
+            shutil.copy(tarball, output_dir)
+
+
+def fetch_dataset(dataset_id: str, **kwargs: str) -> Dataset:
     return Dataset()
 
 
+def fetch_collection(collection_id: str, **kwargs: str) -> Collection:
+    return Collection()
+
+
 def download_url(url: str, root: str, *args: str) -> None:
     shutil.copy(url, root)
@@ -110,14 +124,24 @@ def test_download_and_extract_archive(
     )
 
 
-def test_download_radiant_mlhub(
+def test_download_radiant_mlhub_dataset(
     tmp_path: Path, monkeypatch: Generator[MonkeyPatch, None, None]
 ) -> None:
     radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.2.1")
     monkeypatch.setattr(  # type: ignore[attr-defined]
-        radiant_mlhub.Dataset, "fetch", fetch
+        radiant_mlhub.Dataset, "fetch", fetch_dataset
     )
-    download_radiant_mlhub("", str(tmp_path))
+    download_radiant_mlhub_dataset("", str(tmp_path))
+
+
+def test_download_radiant_mlhub_collection(
+    tmp_path: Path, monkeypatch: Generator[MonkeyPatch, None, None]
+) -> None:
+    radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.2.1")
+    monkeypatch.setattr(  # type: ignore[attr-defined]
+        radiant_mlhub.Collection, "fetch", fetch_collection
+    )
+    download_radiant_mlhub_collection("", str(tmp_path))
 
 
 def test_missing_radiant_mlhub(mock_missing_module: None) -> None:
@@ -125,7 +149,14 @@ def test_missing_radiant_mlhub(mock_missing_module: None) -> None:
         ImportError,
         match="radiant_mlhub is not installed and is required to download this dataset",
     ):
-        download_radiant_mlhub("", "")
+        download_radiant_mlhub_dataset("", "")
+
+    with pytest.raises(
+        ImportError,
+        match="radiant_mlhub is not installed and is required to download this"
+        + " collection",
+    ):
+        download_radiant_mlhub_collection("", "")
 
 
 class TestBoundingBox:
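As background, the tests above rely on a simple stubbing pattern: monkeypatch swaps radiant_mlhub's network-bound fetch for a local factory so no real download ever happens. A minimal sketch of that pattern (FakeCollection and the lambda are illustrative, not part of this diff):

    from pathlib import Path

    import pytest
    from _pytest.monkeypatch import MonkeyPatch

    from torchgeo.datasets.utils import download_radiant_mlhub_collection


    class FakeCollection:
        def download(self, output_dir: str, **kwargs: str) -> None:
            pass  # a real stub would copy fixture tarballs into output_dir


    def test_sketch(tmp_path: Path, monkeypatch: MonkeyPatch) -> None:
        radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.2.1")
        # Replace the network-bound classmethod with a local factory.
        monkeypatch.setattr(
            radiant_mlhub.Collection, "fetch", lambda *args, **kwargs: FakeCollection()
        )
        download_radiant_mlhub_collection("any-id", str(tmp_path))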
@@ -15,7 +15,7 @@ import torch
 from torch import Tensor
 
 from .geo import VisionDataset
-from .utils import check_integrity, download_radiant_mlhub, extract_archive
+from .utils import check_integrity, download_radiant_mlhub_dataset, extract_archive
 
 
 # TODO: read geospatial information from stac.json files
@@ -407,7 +407,7 @@ class BeninSmallHolderCashews(VisionDataset):
             print("Files already downloaded and verified")
             return
 
-        download_radiant_mlhub(self.dataset_id, self.root, api_key)
+        download_radiant_mlhub_dataset(self.dataset_id, self.root, api_key)
 
         image_archive_path = os.path.join(self.root, self.image_meta["filename"])
         target_archive_path = os.path.join(self.root, self.target_meta["filename"])
@@ -14,7 +14,7 @@ from PIL import Image
 from torch import Tensor
 
 from .geo import VisionDataset
-from .utils import check_integrity, download_radiant_mlhub, extract_archive
+from .utils import check_integrity, download_radiant_mlhub_dataset, extract_archive
 
 
 # TODO: read geospatial information from stac.json files
@@ -396,7 +396,7 @@ class CV4AKenyaCropType(VisionDataset):
             print("Files already downloaded and verified")
             return
 
-        download_radiant_mlhub(self.dataset_id, self.root, api_key)
+        download_radiant_mlhub_dataset(self.dataset_id, self.root, api_key)
 
         image_archive_path = os.path.join(self.root, self.image_meta["filename"])
         target_archive_path = os.path.join(self.root, self.target_meta["filename"])
@@ -14,7 +14,7 @@ from PIL import Image
 from torch import Tensor
 
 from .geo import VisionDataset
-from .utils import check_integrity, download_radiant_mlhub, extract_archive
+from .utils import check_integrity, download_radiant_mlhub_dataset, extract_archive
 
 
 class TropicalCycloneWindEstimation(VisionDataset):
@@ -201,7 +201,7 @@ class TropicalCycloneWindEstimation(VisionDataset):
             print("Files already downloaded and verified")
             return
 
-        download_radiant_mlhub(self.collection_id, self.root, api_key)
+        download_radiant_mlhub_dataset(self.collection_id, self.root, api_key)
 
         for split, resources in self.md5s.items():
             for resource_type in resources:
@@ -18,7 +18,7 @@ from torch import Tensor
 from torchgeo.datasets.geo import VisionDataset
 from torchgeo.datasets.utils import (
     check_integrity,
-    download_radiant_mlhub,
+    download_radiant_mlhub_dataset,
     extract_archive,
 )
 
@@ -227,7 +227,7 @@ class SpaceNet1(VisionDataset):
             print("Files already downloaded")
             return
 
-        download_radiant_mlhub(self.dataset_id, self.root, api_key)
+        download_radiant_mlhub_dataset(self.dataset_id, self.root, api_key)
         archive_path = os.path.join(self.root, self.foldername + ".tar.gz")
         if (
             self.checksum
@@ -125,7 +125,7 @@ def download_and_extract_archive(
     extract_archive(archive, extract_root)
 
 
-def download_radiant_mlhub(
+def download_radiant_mlhub_dataset(
     dataset_id: str, download_root: str, api_key: Optional[str] = None
 ) -> None:
     """Download a dataset from Radiant Earth.
@@ -148,6 +148,29 @@ def download_radiant_mlhub(
     dataset.download(output_dir=download_root, api_key=api_key)
 
 
+def download_radiant_mlhub_collection(
+    collection_id: str, download_root: str, api_key: Optional[str] = None
+) -> None:
+    """Download a collection from Radiant Earth.
+
+    Args:
+        collection_id: the ID of the collection to fetch
+        download_root: directory to download to
+        api_key: the API key to use for all requests from the session. Can also be
+            passed in via the ``MLHUB_API_KEY`` environment variable, or configured in
+            ``~/.mlhub/profiles``.
+    """
+    try:
+        import radiant_mlhub
+    except ImportError:
+        raise ImportError(
+            "radiant_mlhub is not installed and is required to download this collection"
+        )
+
+    collection = radiant_mlhub.Collection.fetch(collection_id, api_key=api_key)
+    collection.download(output_dir=download_root, api_key=api_key)
+
+
 class BoundingBox(Tuple[float, float, float, float, float, float]):
     """Data class for indexing spatiotemporal data.
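For quick reference, calling the new helper end to end looks like the following sketch (the collection ID and API key are placeholders; per the docstring, the key can instead come from the MLHUB_API_KEY environment variable or ~/.mlhub/profiles):

    from torchgeo.datasets.utils import download_radiant_mlhub_collection

    download_radiant_mlhub_collection(
        "ref_african_crops_kenya_02_labels",  # illustrative collection ID
        download_root="data",
        api_key="YOUR_MLHUB_API_KEY",  # placeholder
    )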